parameters
[imago.git] / src / ransac.py
index 197fc31..6557ea4 100644 (file)
@@ -7,9 +7,6 @@ import numpy as NP
 # TODO comments
 # TODO threshold
 
-def initial_estimate(data):
-    return random.sample(data, 2)
-
 def points_to_line((x1, y1), (x2, y2)):
     return (y2 - y1, x1 - x2, x2 * y1 - x1 * y2)
 
@@ -26,37 +23,71 @@ def least_squares(data):
     [a,c] = NP.dot(NP.linalg.inv(NP.dot(xt, x)), xt).dot(y).flat
     return (a, -1, c)
 
-def get_model(data):
-    if len(data) == 2:
-        return points_to_line(*data)
-    else:
-        return least_squares(data)
-
-def iterate(data, distance):
-    consensus = 0
-    consensual = initial_estimate(data)
-    while (len(consensual) > consensus):
-        consensus = len(consensual)
-        model = get_model(consensual)
-        consensual = filter_near(data, model, distance)
-    return consensus, model, consensual
+class Linear_model:
+    def __init__(self, data):
+        self.data = data
+
+    def get(self, sample):
+        if len(sample) == 2:
+            return points_to_line(*sample)
+        else:
+            return least_squares(sample)
+
+    def initial(self):
+        return random.sample(self.data, 2)
+
+    def score(self, est, dist):
+        cons = []
+        score = 0
+        a, b, c = est
+        dst = lambda (x, y): abs(a * x + b * y + c) / sqrt(a*a+b*b)
+        for p in self.data:
+            d = dst(p)
+            if d <= dist:
+                cons.append(p)
+            score += min(d, dist)
+        return score, cons
+
+    def remove(self, data):
+        self.data = list(set(self.data) - set(data))
+
+def iterate(model, distance):
+    score = float("inf")
+    consensual = model.initial()
+    estimate = model.get(consensual)
+    new_score, new_consensual = model.score(estimate, distance)
+    if new_consensual != []:
+        while (new_score < score):
+            score, consensual = new_score, new_consensual
+            try:
+                estimate = model.get(consensual)
+                new_score, new_consensual = model.score(estimate, distance)
+            except (NP.linalg.LinAlgError):
+                pass
+    return score, estimate, consensual
         
-def estimate(data, dist, k):
-    best = 0
-    model = None
+def estimate(data, dist, k, modelClass=Linear_model, model=None):
+    if not model:
+        model = modelClass(data)
+    best = float("inf")
+    estimate = None
     consensual = None
     for i in xrange(0, k):
-        new, new_model, new_consensual  = iterate(data, dist)
-        if new > best:
+        new, new_estimate, new_consensual = iterate(model, dist)
+        if new < best:
             best = new
-            model = new_model
+            estimate = new_estimate
             consensual = new_consensual
 
-    return model, consensual
+    return estimate, consensual
 
-def ransac_duo(data, dist, k, mk):
+def ransac_multi(m, data, dist, k, modelClass=Linear_model, model=None):
+    if not model:
+        model = modelClass(data)
+    ests = []
     cons = []
-    for i in xrange(mk):
-        model, cons = estimate(set(data) - set(cons), dist, k)
-    return (model, cons), estimate(set(data) - set(cons), dist, k)
-
+    for i in xrange(m):
+        est, cons_new = estimate(None, dist, k, model=model)
+        model.remove(cons_new)
+        ests.append((est, cons_new))
+    return ests