ransac
[imago.git] / imago_pack / ransac.py
1 """RANSAC estimation."""
2
3 import random
4 from math import sqrt
5 import numpy as NP
6
7 def initial_estimate(data):
8     return random.sample(data, 2)
9
10 def points_to_line((x1, y1), (x2, y2)):
11     return (y2 - y1, x1 - x2, x2 * y1 - x1 * y2)
12
13 def filter_near(data, line, distance):
14     a, b, c = line
15     dst = lambda (x, y): abs(a * x + b * y + c) / sqrt(a*a+b*b)
16     is_near = lambda p: dst(p) <= distance
17     return [p for p in data if is_near(p)]
18
19 def least_squares(data):
20     x = NP.matrix([(a, 1) for (a, b) in data])
21     xt = NP.transpose(x)
22     y = NP.matrix([[b] for (a, b) in data])
23     [a,c] = NP.dot(NP.linalg.inv(NP.dot(xt, x)), xt).dot(y).flat
24     return (a, -1, c)
25
26 def get_model(data):
27     if len(data) == 2:
28         return points_to_line(*data)
29     else:
30         return least_squares(data)
31
32 def iterate(data, distance):
33     consensus = 0
34     consensual = initial_estimate(data)
35     while (len(consensual) > consensus):
36         consensus = len(consensual)
37         model = get_model(consensual)
38         consensual = filter_near(data, model, distance)
39     return consensus, model
40         
41 def estimate(data, dist, k):
42     best = 0
43     model = None
44     for i in xrange(0, k):
45         new, new_model = iterate(data, dist)
46         if new > best:
47             best = new
48             model = new_model
49
50     return model
51