未验证 提交 8f26e69f 编写于 作者: T TheBloodthirster 提交者: GitHub

update kmeans (#3489)

Signed-off-by: Nhjp <13606074505@163.com>
上级 80def0d0
......@@ -20,6 +20,7 @@
#include <algorithm>
#include <limits>
#include <faiss/utils/random.h>
#include <faiss/utils/distances.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/FaissHook.h>
#include <faiss/IndexFlat.h>
namespace faiss {
......@@ -258,11 +259,115 @@ int split_clusters (size_t d, size_t k, size_t n,
return nsplit;
}
};
// Process-wide selector for the centroid-initialization strategy that
// Clustering::train_encoded dispatches on. It starts out as KMEANS
// (random-permutation seeding).
// NOTE(review): this is a plain mutable global with no synchronization
// visible here -- confirm it is only changed before training starts.
KmeansType kmeans_type = KmeansType::KMEANS;
/**
 * Classic k-means seeding: fill centroids_index with a random permutation of
 * the point ids [0, nx), so that the caller can take the leading entries as
 * the initial centroids.
 *
 * @param centroids_index    [out] receives the permutation; it is grown to
 *                           nx entries if it is shorter than that
 * @param random_seed        seed handed to rand_perm
 * @param n_input_centroids  unused by this strategy
 * @param d                  unused by this strategy
 * @param k                  unused by this strategy
 * @param nx                 number of points in the dataset
 * @param x_in               unused by this strategy
 *
 * The unused parameters keep the signature parallel to
 * kmeans_plus_plus_algorithm so the caller can dispatch uniformly.
 */
void Clustering::kmeans_algorithm(std::vector<int>& centroids_index, int64_t random_seed,
                                  size_t n_input_centroids, size_t d, size_t k,
                                  idx_t nx, const uint8_t *x_in)
{
    // An empty (or invalid, negative) dataset has nothing to permute.
    if (nx <= 0) {
        return;
    }

    // rand_perm writes nx entries through the raw pointer, so make sure the
    // destination really has room for them instead of trusting the caller.
    const size_t n_points = static_cast<size_t>(nx);
    if (centroids_index.size() < n_points) {
        centroids_index.resize(n_points);
    }

    // centroids with random points from the dataset
    rand_perm (centroids_index.data(), nx, random_seed);
}
void Clustering::kmeans_plus_plus_algorithm(std::vector<int>& centroids_index, int64_t random_seed,
size_t n_input_centroids, size_t d,
size_t k, idx_t nx, const uint8_t *x_in)
{
FAISS_THROW_IF_NOT_MSG (
n_input_centroids == 0,
"Kmeans plus plus only support the provided input centroids number of zero"
);
};
size_t thread_max_num = omp_get_max_threads();
auto x = reinterpret_cast<const float*>(x_in);
// The square of distance to current centroid
std::vector<float> dx_distance(nx, 1.0 / 0.0);
std::vector<float> pre_sum(nx);
// task of each thread when calculate P(x)
std::vector<size_t> task(thread_max_num, nx);
size_t step = (nx + thread_max_num - 1) / thread_max_num;
for (size_t i = 0; i + 1 < thread_max_num; i++) {
task[i] = (i + 1) * step;
}
// Record the centroids that has been calculated
// Input :
// nx : int -> nb of points
// d : size_t -> nb of dimensions
// k : size_t -> nb of centroids
// x : unsigned char -> data : the x[i*d] means the i-th point's d-th value
// Output:
// centroids : array -> the cluster centers
// 1. get the pre-n-input-centroids: if equal to 0,
// then should get the first random start point
RandomGenerator rng (random_seed);
//if (n_input_centroids == 0) {}
size_t first_center;
first_center = static_cast<size_t>(rng.rand_int64() % nx);
centroids_index[0] = first_center;
// 2. use the first few centroids to calculate the next centroid,and already has first random start point
//size_t current_centroids = n_input_centroids == 0 ? 1 : n_input_centroids;
size_t current_centroids = 1;
// For every epoch there is i-th centroids,and we want to calculate the i+1 centroid
for (size_t i = current_centroids; i < k; i++) {
auto last_centroids_data = x + centroids_index[i - 1] * d;
// for every point
#pragma omp parallel for
for (size_t point_it = 0; point_it < nx; point_it++) {
float distance_of_point_and_centroid = 0;
distance_of_point_and_centroid = fvec_L2sqr((x + point_it * d), last_centroids_data, d);
if (distance_of_point_and_centroid < dx_distance[point_it]) {
dx_distance[point_it] = distance_of_point_and_centroid;
}
}
//calculate P(x)
#pragma omp parallel for
for (size_t point_it = 0; point_it < thread_max_num; point_it++) {
size_t left = point_it == 0 ? 0 : task[point_it - 1];
size_t right = task[point_it];
// cout <<"Thread = "<< omp_get_thread_num() <<" left = "<<left<<" right = "<<right << endl;
pre_sum[left] = dx_distance[left];
for (size_t j = left + 1; j < right; j++) {
pre_sum[j] = pre_sum[j - 1] + dx_distance[j];
}
}
float sum = 0.0;
for (size_t point_it = 0; point_it < thread_max_num; point_it++) {
sum += pre_sum[task[point_it] - 1];
}
// the random num is [0,sum]
float choose_centroid_random = rng.rand_double() * sum;
size_t task_i = 0;
for (task_i = 0; task_i < thread_max_num; task_i++) {
auto task_pre_sum = pre_sum[task[task_i] - 1];
if (choose_centroid_random - task_pre_sum <= 0) {
break;
}
choose_centroid_random -= task_pre_sum;
}
size_t left = task_i == 0 ? 0 : task[task_i - 1];
size_t right = task[task_i];
//find the next centroid using Binary search and the left is what we want
while(left < right) {
size_t mid = left + (right - left) / 2;
if (pre_sum[mid] < choose_centroid_random)
left = mid + 1;
else
right = mid;
}
centroids_index[i] = left;
}
}
void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
const Index * codec, Index & index,
......@@ -384,23 +489,31 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
printf("Outer iteration %d / %d\n", redo, nredo);
}
// initialize (remaining) centroids with random points from the dataset
centroids.resize (d * k);
std::vector<int> perm (nx);
rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
{
int64_t random_seed = seed + 1 + redo * 15486557L;
std::vector<int> centroids_index(nx);
if (!codec) {
for (int i = n_input_centroids; i < k ; i++) {
memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
if (KmeansType::KMEANS == kmeans_type) {
//Use classic kmeans algorithm
kmeans_algorithm(centroids_index, random_seed, n_input_centroids, d, k, nx, x_in);
} else if (KmeansType::KMEANS_PLUSPLUS == kmeans_type) {
//Use kmeans++ algorithm
kmeans_plus_plus_algorithm(centroids_index, random_seed, n_input_centroids, d, k, nx, x_in);
}
} else {
for (int i = n_input_centroids; i < k ; i++) {
codec->sa_decode (1, x + perm[i] * line_size, &centroids[i * d]);
centroids.resize(d * k);
if (!codec) {
for (int i = n_input_centroids; i < k; i++) {
memcpy(&centroids[i * d], x + centroids_index[i] * line_size, line_size);
}
} else {
for (int i = n_input_centroids; i < k; i++) {
codec->sa_decode(1, x + centroids_index[i] * line_size, &centroids[i * d]);
}
}
}
post_process_centroids ();
post_process_centroids();
// prepare the index
......
......@@ -15,6 +15,19 @@
namespace faiss {
/**
 * Centroid-initialization strategy, selected through the global
 * `kmeans_type` variable.
 */
enum KmeansType
{
// Seeds come from a random permutation of the dataset
// (Clustering::kmeans_algorithm).
KMEANS,
// Seeds are chosen by k-means++ sampling
// (Clustering::kmeans_plus_plus_algorithm).
KMEANS_PLUSPLUS,
// NOTE(review): no visible code path handles this value -- confirm where
// (or whether) it is implemented.
KMEANS_TWO,
};
// Strategy used to pick the initial centroids during training.
// The definition in the implementation file initializes it to
// KmeansType::KMEANS (not KMEANS_PLUSPLUS).
extern KmeansType kmeans_type;
/** Class for the clustering parameters. Can be passed to the
* constructor of the Clustering object.
......@@ -87,6 +100,24 @@ struct Clustering: ClusteringParameters {
virtual void train (idx_t n, const float * x, faiss::Index & index,
const float *x_weights = nullptr);
/**
 * @brief Classic k-means seeding: fills centroids_index with a random
 *        permutation produced by rand_perm over nx entries.
 *
 * Only centroids_index, random_seed and nx are used by the implementation;
 * the remaining parameters mirror kmeans_plus_plus_algorithm.
 *
 * @param centroids_index [out] centroids index; the implementation writes nx
 *                        entries through centroids_index.data()
 * @param random_seed seed for the random number generator
 * @param n_input_centroids the number of centroids that user input (unused)
 * @param d dimension (unused)
 * @param k number of centroids (unused)
 * @param nx size of data
 * @param x_in data of point (unused)
 */
void kmeans_algorithm(std::vector<int>& centroids_index, int64_t random_seed,
size_t n_input_centroids, size_t d, size_t k,
idx_t nx, const uint8_t *x_in);
/**
 * @brief k-means++ seeding: the first centroid is a random point, each
 *        further one is drawn with probability proportional to its squared
 *        L2 distance to the closest centroid chosen so far.
 *
 * @param centroids_index [out] entries [0, k) receive the selected point ids
 * @param random_seed seed for the random number generator
 * @param n_input_centroids must be 0; any other value raises a Faiss
 *                          exception (FAISS_THROW_IF_NOT_MSG)
 * @param d dimension
 * @param k number of centroids
 * @param nx size of data
 * @param x_in data of point; the implementation reinterprets it as floats
 */
void kmeans_plus_plus_algorithm(std::vector<int>& centroids_index, int64_t random_seed,
size_t n_input_centroids, size_t d, size_t k,
idx_t nx, const uint8_t *x_in);
/** run with encoded vectors
*
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册