// Distance functor call operator taking two points p1 and p2.
// The visible statements create a temporary `diff` with the same shape as p1
// and return the sum of its element-wise squares.
// NOTE(review): the function header and the statement that fills `diff`
// (presumably p1 - p2) are not visible in this listing -- confirm.
23 #ifndef LIBMARGRET_SRC_KMEANS_INL_HH 24 #define LIBMARGRET_SRC_KMEANS_INL_HH 34 #include <blitz/array.h> 49 const PointT& p1,
const PointT& p2)
const 51 PointT diff(p1.shape());
// Sum of squares of diff; no square root is taken, so the result is a
// squared quantity.
53 return blitz::sum(diff*diff);
// Member-initializer list of the constructor (its header is not visible in
// this listing). Stores k and sets the defaults used by the clustering code:
//   max_iterations_               = 50  (upper bound of the per-run iteration loop)
//   max_restarts_                 = 100 (upper bound of the restart loop)
//   variance_threshold_           = 0   (restarts continue while the best variance is above it)
//   variance_iter_no_improvement_ = 10  (interval, in restarts, of the no-improvement check)
//   all                           = blitz::Range::all(), used to slice whole rows
58 k_(k), max_iterations_(50), max_restarts_(100),
59 variance_threshold_(0),
60 variance_iter_no_improvement_(10),
61 all(
blitz::Range::all())
// Clusters `points` into k_ groups, running up to max_restarts_ attempts and
// keeping the labels of the attempt with the lowest totalVariance().
// Outputs go to `best_means` and `labels`.
// NOTE(review): the start of the signature, the per-restart initialisation,
// the copy into best_means and the return statement are not visible in this
// listing.
75 ArrayPointT &best_means,
76 std::vector<int> &labels)
// Rows of `points` are data points, columns are coordinates.
78 num_points_ = points.extent(0);
79 point_dim_ = points.extent(1);
// Error text for the case of fewer data points than clusters.
// NOTE(review): the guarding condition and the throw are not visible here.
84 s <<
"KMeans: Number of data points " << num_points_
85 <<
" is less than k=" << k_;
// Make points_ refer to the caller's array, then size the working buffers:
// one mean row per cluster, one label per point.
88 points_.reference(points);
90 means_.resize(k_, point_dim_);
91 labels_.resize(num_points_, 0);
98 best_means.resize(k_, point_dim_);
// best_labels is an alias for the caller's `labels` vector.
// NOTE(review): it is resized to k_ although labels are per point
// (labels_ holds num_points_ entries). The assignment from labels_ below
// corrects the size, but if no restart ever improves on the initial
// min_variance the caller is left with a k_-sized vector -- consider
// num_points_ here.
99 std::vector<int> &best_labels = labels;
100 best_labels.resize(k_);
101 double min_variance = std::numeric_limits<double>::max();
102 double last_variance = min_variance;
// Restart loop: stops after max_restarts_ attempts or once the best variance
// is no longer above variance_threshold_.
104 for (
int i = 0; i < max_restarts_ && min_variance > variance_threshold_; ++i)
107 iterateUntilConvergence();
108 double variance = totalVariance();
// Keep the result of this attempt if it beats the best one so far.
109 if (variance < min_variance)
111 min_variance = variance;
113 best_labels = labels_;
// Every variance_iter_no_improvement_ restarts, compare the best variance
// with its value at the previous check.
// NOTE(review): the body of the equality branch is not visible; presumably
// it ends the restart loop early -- confirm.
116 if (i % variance_iter_no_improvement_ == 0)
118 if (last_variance == min_variance)
122 last_variance = min_variance;
// Iteration loop of a single clustering run: performs at most
// max_iterations_ passes and stops as soon as assignLabels() returns false.
// NOTE(review): the function header and any other per-pass statements are
// not visible in this listing; presumably this is iterateUntilConvergence()
// -- confirm.
128 template<
class DataT>
131 for (
int i = 0; i < max_iterations_; ++i)
134 if (!assignLabels())
break;
// Seeds the cluster means from k_ distinct data points chosen with rand().
// Sampled row indices are collected in a std::set, so duplicates are
// rejected and sampling continues until k_ unique indices are available.
// NOTE(review): the enclosing function header and the declaration of the
// row counter `i` used in the copy loop are not visible in this listing.
140 template<
class DataT>
143 std::set<ptrdiff_t> sampled_indices;
144 while (sampled_indices.size() < k_)
// rand() returns values in [0, RAND_MAX]. Dividing by RAND_MAX + 1 (formed
// in double so the addition cannot overflow int) keeps the ratio strictly
// below 1, so the truncated index always lies in [0, num_points_).
// A divisor of RAND_MAX - 1 would let the index reach num_points_ and read
// past the last row of points_.
146 ptrdiff_t rand_i =
static_cast<ptrdiff_t
>(
147 static_cast<double>(rand()) / (static_cast<double>(RAND_MAX) + 1.0) *
148 static_cast<double>(num_points_));
149 sampled_indices.insert(rand_i);
// Copy each sampled point into the next row of means_.
152 for (std::set<ptrdiff_t>::const_iterator it = sampled_indices.begin();
153 it != sampled_indices.end(); ++it, ++i)
154 means_(i, all) = points_(*it, all);
// Recomputes every mean as the average of the points currently labelled
// with it: points are accumulated into the row of means_ selected by their
// label, a per-label counter is incremented, and each row is then divided
// by its count.
// NOTE(review): the function header is not visible in this listing.
157 template<
class DataT>
// NOTE(review): `tt` is never used in the visible code -- candidate for removal.
161 blitz::Array<double, 1> tt(4);
// NOTE(review): no visible statement zeroes means_ or count_points before
// the accumulation below; a line may be missing from this listing -- confirm
// both are reset before use.
162 blitz::Array<DataT, 1> count_points(k_);
164 for (ptrdiff_t i_p = 0; i_p < num_points_; ++i_p)
166 ptrdiff_t mean_idx = labels_[i_p];
167 means_(mean_idx, all) += points_(i_p, all);
168 ++count_points(mean_idx);
// Row-wise division using blitz index placeholders (`i` is presumably a
// blitz::firstIndex declared on a line not visible here).
// NOTE(review): a label with no assigned points divides by a zero count.
171 blitz::secondIndex j;
172 means_ = means_(i, j) / count_points(i);
// Assigns every point the label of a mean chosen by distance and tracks in
// `changed` whether any label differs from its previous value.
// NOTE(review): the function header, the comparison that updates min_dist,
// the statement that sets `changed` and the return are not visible in this
// listing; presumably the nearest mean wins and `changed` is returned --
// confirm.
177 template<
class DataT>
180 bool changed =
false;
181 for (ptrdiff_t point_idx = 0; point_idx < num_points_; ++point_idx)
183 const PointT& point = points_(point_idx, all);
// Distances are doubles (see `dist` below), so the starting sentinel is the
// largest double rather than the largest DataT; with a narrower DataT the
// sentinel could be smaller than an actual distance.
184 double min_dist = std::numeric_limits<double>::max();
// Remember the previous label so a change can be detected afterwards.
185 int old_label = labels_[point_idx];
187 for (ptrdiff_t mean_idx = 0; mean_idx < k_; ++mean_idx)
189 const PointT &mean = means_(mean_idx, all);
190 double dist = (*distance_)(mean, point);
194 labels_[point_idx] = mean_idx;
// Only the first change needs to be noticed.
197 if (!changed && old_label != labels_[point_idx])
// Returns the average, over all points, of the squared value produced by
// the distance functor between a point and the mean of its label.
// NOTE(review): the function header and the declaration of `variance` are
// not visible in this listing.
203 template<
class DataT>
207 for (ptrdiff_t i = 0; i < num_points_; ++i)
209 double dist = (*distance_)(points_(i, all), means_(labels_[i], all));
// NOTE(review): the distance functor visible in this listing already
// returns a sum of squares, so with that functor pow2() yields a
// fourth-power quantity -- confirm this is intended.
210 variance += blitz::pow2(dist);
212 return variance /
static_cast<double>(num_points_);
// Free-function convenience wrapper: forwards `points`, `means` and
// `labels` to cluster() on a local object named `kmeans` and returns its
// result.
// NOTE(review): the function name line, the leading parameter(s) and the
// construction of the local `kmeans` object are not visible in this listing.
215 template<
class DataT>
217 const blitz::Array<DataT, 2 > &points,
218 blitz::Array<DataT, 2 > &means,
219 std::vector<int> &labels
224 return kmeans.
cluster(points, means, labels);
double cluster(const ArrayPointT &points, ArrayPointT &means, std::vector< int > &labels)
_KMeans(const unsigned int k, const Distance &distance)
virtual double operator()(const PointT &p1, const PointT &p2) const
double kmeans(const unsigned int k, const blitz::Array< DataT, 2 > &points, blitz::Array< DataT, 2 > &means, std::vector< int > &labels)