preliminary work on generalized RANSAC
[imago.git] / src / ransac.py
1 """RANSAC estimation."""
2
3 import random
4 from math import sqrt
5 import numpy as NP
6
7 # TODO comments
8 # TODO threshold
9
10 def points_to_line((x1, y1), (x2, y2)):
11     return (y2 - y1, x1 - x2, x2 * y1 - x1 * y2)
12
13 def filter_near(data, line, distance):
14     a, b, c = line
15     dst = lambda (x, y): abs(a * x + b * y + c) / sqrt(a*a+b*b)
16     is_near = lambda p: dst(p) <= distance
17     return [p for p in data if is_near(p)]
18
19 def least_squares(data):
20     x = NP.matrix([(a, 1) for (a, b) in data])
21     xt = NP.transpose(x)
22     y = NP.matrix([[b] for (a, b) in data])
23     [a,c] = NP.dot(NP.linalg.inv(NP.dot(xt, x)), xt).dot(y).flat
24     return (a, -1, c)
25
26 class Linear_model:
27     def __init__(self, data):
28         self.data = data
29
30     def get(self, sample):
31         if len(sample) == 2:
32             return points_to_line(*sample)
33         else:
34             return least_squares(sample)
35
36     def initial(self):
37         return random.sample(self.data, 2)
38
39     def score(self, est, dist):
40         cons = []
41         score = 0
42         a, b, c = est
43         dst = lambda (x, y): abs(a * x + b * y + c) / sqrt(a*a+b*b)
44         for p in self.data:
45             d = dst(p)
46             if d <= dist:
47                 cons.append(p)
48             score += min(d, dist)
49         return score, cons
50
51 def iterate(model, distance):
52     score = float("inf")
53     consensual = model.initial()
54     estimate = model.get(consensual)
55     new_score, new_consensual = model.score(estimate, distance)
56     if new_consensual != []:
57         while (new_score < score):
58             score, consensual = new_score, new_consensual
59             try:
60                 estimate = model.get(consensual)
61                 new_score, new_consensual = model.score(estimate, distance)
62             except (NP.linalg.LinAlgError):
63                 pass
64     return score, estimate, consensual
65         
66 def estimate(data, dist, k, modelClass=Linear_model, model=None):
67     if not model:
68         model = modelClass(data)
69     best = float("inf")
70     estimate = None
71     consensual = None
72     for i in xrange(0, k):
73         new, new_estimate, new_consensual = iterate(model, dist)
74         if new < best:
75             best = new
76             estimate = new_estimate
77             consensual = new_consensual
78
79     return estimate, consensual
80
81 def ransac_duo(data, dist, k, mk, modelClass=Linear_model):
82     cons = []
83     for i in xrange(mk):
84         model, cons = estimate(set(data) - set(cons), dist, k, modelClass)
85     return (model, cons), estimate(set(data) - set(cons), dist, k, modelClass)
86
87 def ransac_multi(m, data, dist, k, modelClass=Linear_model, model=None):
88     ests = []
89     cons = []
90     for i in xrange(m):
91         est, cons_new = estimate(None, dist, k, model=model)
92         model.remove(cons_new)
93         ests.append(est)
94     return ests
95
96         
97