import numpy as np


def fps(points, n_samples, *, start_idx=0):
    """Select ``n_samples`` points by farthest point sampling (FPS).

    Starting from ``points[start_idx]``, each further sample is the point
    whose (Euclidean) distance to its nearest already-selected sample is
    largest. Ties are broken in favour of the lowest index.

    Args:
        points: Array-like of points indexed along the first axis. A point
            may be a scalar (1-D input) or an array of any shape; it is
            flattened to a vector when computing distances.
        n_samples: Number of points to select, ``0 <= n_samples <= len(points)``.
        start_idx: Index of the first sample (keyword-only, default 0).

    Returns:
        ``np.ndarray`` of the selected points in selection order, shape
        ``(n_samples,) + np.shape(points)[1:]``, same dtype as ``points``.

    Raises:
        ValueError: If ``n_samples`` is outside ``[0, len(points)]`` or
            ``start_idx`` is outside ``[0, len(points))``.
    """
    points = np.asarray(points)
    n_points = len(points)
    if not 0 <= n_samples <= n_points:
        raise ValueError(
            f"n_samples must be in [0, {n_points}], got {n_samples}")
    if n_samples == 0:
        return points[:0]
    if not 0 <= start_idx < n_points:
        raise ValueError(
            f"start_idx must be in [0, {n_points}), got {start_idx}")

    # One row per point. float64 avoids wrap-around/overflow when squaring
    # differences of (unsigned) integer inputs, and the reshape lets 1-D
    # (scalar) and multi-axis points share the same distance code.
    flat = points.reshape(n_points, -1).astype(np.float64)

    sample_inds = np.empty(n_samples, dtype=np.intp)
    sample_inds[0] = start_idx
    # Squared distance from each point to its nearest selected sample.
    # Squared distances have the same argmax as true distances, so no sqrt.
    # -inf marks already-selected points so argmax never picks them again.
    min_dists = np.full(n_points, np.inf)
    min_dists[start_idx] = -np.inf

    for i in range(1, n_samples):
        last = flat[sample_inds[i - 1]]
        np.minimum(min_dists, ((flat - last) ** 2).sum(axis=1), out=min_dists)
        # argmax returns the first (lowest-index) maximum on ties.
        nxt = np.argmax(min_dists)
        sample_inds[i] = nxt
        min_dists[nxt] = -np.inf

    return points[sample_inds]
def main(num=300, dim=2, n_samples=30, seed=None):
    """Plot a Gaussian point cloud next to its farthest-point subsample.

    Args:
        num: Number of random points to draw.
        dim: Dimensionality of the points; only the first two coordinates
            are plotted, so ``dim`` must be at least 2.
        n_samples: Number of points kept by ``fps``.
        seed: Optional seed for ``np.random.default_rng`` to make the demo
            reproducible.
    """
    # Imported here so the fps module itself does not require matplotlib.
    import matplotlib.pyplot as plt

    rng = np.random.default_rng(seed)
    x = rng.standard_normal((num, dim))
    ds = fps(x, n_samples)
    print(x.shape)

    # Shared axes make the two panels directly comparable.
    _, (ax_all, ax_fps) = plt.subplots(1, 2, sharex=True, sharey=True)
    ax_all.scatter(x[:, 0], x[:, 1])
    ax_all.set_title('Normal Dist')
    ax_fps.scatter(ds[:, 0], ds[:, 1])
    ax_fps.set_title('FPS Dist')
    plt.show()


if __name__ == "__main__":
    main()