提交 34422235 编写于 作者: L Lars Nieradzik

Improved the algorithm.

上级 a884d1d0
# kmeans-anchor-boxes # kmeans-anchor-boxes
K-Means Clustering with the Intersection over Union (IoU) metric as described in the YOLO9000 paper This repository contains an implementation of k-means clustering with the Intersection over Union (IoU) metric as described in the YOLO9000 paper [1].
## Tests
According to the paper we should get 61.0 avg IoU with 5 clusters and 67.2 avg IoU with 9 clusters on the VOC 2007 data set:
![Table](https://i.imgur.com/DoScgDL.png)
First I tried normal k-means clustering:
![k-means k = 5](https://i.imgur.com/lnHijWm.png)
![k-means k = 9](https://i.imgur.com/w0pePI0.png)
As the plots show the algorithm converges to lower values than expected. To resolve this problem, I changed k-means to not run until convergence. Whenever the values started to drop, the algorithm would start again with different initial means. By doing this for about 50 iterations, an average IoU of about 60 was possible.
However, this didn't seem good enough, because now the algorithm has to run for a long time to find the right values. So I started trying out different initialization methods and variants of k-means clustering. In the end the best results were obtained by just using the median to calculate the new centroids.
![k-medians k = 5](https://i.imgur.com/bxtX4cD.png)
![k-medians k = 9](https://i.imgur.com/ly2OGuj.png)
The end result is about 60.15 for k = 5 and 67.13 for k = 9 on the VOC 2007 data set.
## References
[1] J. Redmon and A. Farhadi, “YOLO9000: Better, Faster, Stronger,” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Jul. 2017.
...@@ -45,41 +45,34 @@ def translate_boxes(boxes): ...@@ -45,41 +45,34 @@ def translate_boxes(boxes):
return np.delete(boxes, [0, 1], axis=1) return np.delete(boxes, [0, 1], axis=1)
def kmeans(boxes, k, iterations=10): def kmeans(boxes, k, dist=np.median):
""" """
Calculates k-means clustering with the Intersection over Union (IoU) metric. Calculates k-means clustering with the Intersection over Union (IoU) metric.
:param boxes: numpy array of shape (r, 2), where r is the number of rows :param boxes: numpy array of shape (r, 2), where r is the number of rows
:param k: number of clusters :param k: number of clusters
:param iterations: number of iterations :param dist: distance function
:return: numpy array of shape (k, 2) :return: numpy array of shape (k, 2)
""" """
rows = boxes.shape[0] rows = boxes.shape[0]
distances = np.empty((rows, k)) distances = np.empty((rows, k))
last_clusters = np.zeros((rows,))
result = [0.0, None] # the Forgy method will fail if the whole array contains the same rows
for i in range(0, iterations): clusters = boxes[np.random.choice(rows, k, replace=False)]
# the Forgy method will fail if the whole array contains the same rows
clusters = boxes[np.random.choice(rows, k, replace=False)]
tmp = [0.0, clusters] while True:
while True: for row in range(rows):
for row in range(rows): distances[row] = 1 - iou(boxes[row], clusters)
distances[row] = 1 - iou(boxes[row], clusters)
nearest_clusters = np.argmin(distances, axis=1) nearest_clusters = np.argmin(distances, axis=1)
for cluster in range(k): if (last_clusters == nearest_clusters).all():
clusters[cluster] = np.mean(boxes[nearest_clusters == cluster], axis=0) break
# improve this for cluster in range(k):
avg = avg_iou(boxes, clusters) clusters[cluster] = dist(boxes[nearest_clusters == cluster], axis=0)
if avg > tmp[0]:
tmp = [avg, clusters]
else:
break
if tmp[0] > result[0]: last_clusters = nearest_clusters
result = tmp
return result[1] return clusters
...@@ -30,14 +30,16 @@ class TestVoc2007(TestCase): ...@@ -30,14 +30,16 @@ class TestVoc2007(TestCase):
def test_kmeans_5(self): def test_kmeans_5(self):
dataset = self.__load_dataset() dataset = self.__load_dataset()
out = kmeans(dataset, 5, iterations=50)
out = kmeans(dataset, 5)
percentage = avg_iou(dataset, out) percentage = avg_iou(dataset, out)
np.testing.assert_almost_equal(percentage, 0.61, decimal=2) np.testing.assert_almost_equal(percentage, 0.61, decimal=2)
def test_kmeans_9(self): def test_kmeans_9(self):
dataset = self.__load_dataset() dataset = self.__load_dataset()
out = kmeans(dataset, 9, iterations=50)
out = kmeans(dataset, 9)
percentage = avg_iou(dataset, out) percentage = avg_iou(dataset, out)
np.testing.assert_almost_equal(percentage, 0.672, decimal=2) np.testing.assert_almost_equal(percentage, 0.672, decimal=2)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册