fix error in error-handling
[imago.git] / src / ransac.py
1 """RANSAC estimation."""
2
3 import random
4 from math import sqrt
5 import numpy as NP
6
7 # TODO comments
8 # TODO threshold
9
10
11 def points_to_line((x1, y1), (x2, y2)):
12     return (y2 - y1, x1 - x2, x2 * y1 - x1 * y2)
13
14 def filter_near(data, line, distance):
15     a, b, c = line
16     dst = lambda (x, y): abs(a * x + b * y + c) / sqrt(a*a+b*b)
17     is_near = lambda p: dst(p) <= distance
18     return [p for p in data if is_near(p)]
19
20 def least_squares(data):
21     x = NP.matrix([(a, 1) for (a, b) in data])
22     xt = NP.transpose(x)
23     y = NP.matrix([[b] for (a, b) in data])
24     [a,c] = NP.dot(NP.linalg.inv(NP.dot(xt, x)), xt).dot(y).flat
25     return (a, -1, c)
26
27 class Linear_model:
28     def __init__(self, data):
29         self.data = data
30
31     def get(self, sample):
32         if len(sample) == 2:
33             return points_to_line(*sample)
34         else:
35             return least_squares(sample)
36
37     def initial(self):
38         return random.sample(self.data, 2)
39
40 def iterate(model, distance):
41     consensus = 0
42     consensual = model.initial()
43     while (len(consensual) > consensus):
44         consensus = len(consensual)
45         try:
46             estimate = model.get(consensual)
47         except NP.linalg.LinAlgError:
48             pass
49         consensual = filter_near(model.data, estimate, distance)
50     return consensus, estimate, consensual
51         
52 def estimate(data, dist, k, modelClass=Linear_model):
53     model = modelClass(data)
54     best = 0
55     estimate = None
56     consensual = None
57     for i in xrange(0, k):
58         new, new_estimate, new_consensual = iterate(model, dist)
59         if new > best:
60             best = new
61             estimate = new_estimate
62             consensual = new_consensual
63
64     return estimate, consensual
65
66 def ransac_duo(data, dist, k, mk, modelClass=Linear_model):
67     cons = []
68     for i in xrange(mk):
69         model, cons = estimate(set(data) - set(cons), dist, k, modelClass)
70     return (model, cons), estimate(set(data) - set(cons), dist, k, modelClass)
71