gridf2
[imago.git] / src / ransac.py
1 """RANSAC estimation."""
2
3 import random
4 from math import sqrt
5 import numpy as NP
6
7 # TODO comments
8 # TODO threshold
9
10 def initial_estimate(data):
11     return random.sample(data, 2)
12
13 def points_to_line((x1, y1), (x2, y2)):
14     return (y2 - y1, x1 - x2, x2 * y1 - x1 * y2)
15
16 def filter_near(data, line, distance):
17     a, b, c = line
18     dst = lambda (x, y): abs(a * x + b * y + c) / sqrt(a*a+b*b)
19     is_near = lambda p: dst(p) <= distance
20     return [p for p in data if is_near(p)]
21
22 def least_squares(data):
23     x = NP.matrix([(a, 1) for (a, b) in data])
24     xt = NP.transpose(x)
25     y = NP.matrix([[b] for (a, b) in data])
26     [a,c] = NP.dot(NP.linalg.inv(NP.dot(xt, x)), xt).dot(y).flat
27     return (a, -1, c)
28
29 def get_model(data):
30     if len(data) == 2:
31         return points_to_line(*data)
32     else:
33         return least_squares(data)
34
35 def iterate(data, distance):
36     consensus = 0
37     consensual = initial_estimate(data)
38     while (len(consensual) > consensus):
39         consensus = len(consensual)
40         model = get_model(consensual)
41         consensual = filter_near(data, model, distance)
42     return consensus, model, consensual
43         
44 def estimate(data, dist, k):
45     best = 0
46     model = None
47     consensual = None
48     for i in xrange(0, k):
49         new, new_model, new_consensual  = iterate(data, dist)
50         if new > best:
51             best = new
52             model = new_model
53             consensual = new_consensual
54
55     return model, consensual
56
57 def ransac_duo(data, dist, k, mk):
58     cons = []
59     for i in xrange(mk):
60         model, cons = estimate(set(data) - set(cons), dist, k)
61     return (model, cons), estimate(set(data) - set(cons), dist, k)
62