parameters
[imago.git] / src / ransac.py
index 4f0d9a0..6557ea4 100644 (file)
@@ -7,7 +7,6 @@ import numpy as NP
 # TODO comments
 # TODO threshold
 
-
 def points_to_line((x1, y1), (x2, y2)):
     return (y2 - y1, x1 - x2, x2 * y1 - x1 * y2)
 
@@ -37,35 +36,58 @@ class Linear_model:
     def initial(self):
         return random.sample(self.data, 2)
 
+    def score(self, est, dist):
+        cons = []
+        score = 0
+        a, b, c = est
+        dst = lambda (x, y): abs(a * x + b * y + c) / sqrt(a*a+b*b)
+        for p in self.data:
+            d = dst(p)
+            if d <= dist:
+                cons.append(p)
+            score += min(d, dist)
+        return score, cons
+
+    def remove(self, data):
+        self.data = list(set(self.data) - set(data))
+
 def iterate(model, distance):
-    consensus = 0
+    score = float("inf")
     consensual = model.initial()
-    while (len(consensual) > consensus):
-        consensus = len(consensual)
-        try:
-            estimate = model.get(consensual)
-        except NP.linalg.LinAlgError:
-            pass
-        consensual = filter_near(model.data, estimate, distance)
-    return consensus, estimate, consensual
+    estimate = model.get(consensual)
+    new_score, new_consensual = model.score(estimate, distance)
+    if new_consensual != []:
+        while (new_score < score):
+            score, consensual = new_score, new_consensual
+            try:
+                estimate = model.get(consensual)
+                new_score, new_consensual = model.score(estimate, distance)
+            except (NP.linalg.LinAlgError):
+                pass
+    return score, estimate, consensual
         
-def estimate(data, dist, k, modelClass=Linear_model):
-    model = modelClass(data)
-    best = 0
+def estimate(data, dist, k, modelClass=Linear_model, model=None):
+    if not model:
+        model = modelClass(data)
+    best = float("inf")
     estimate = None
     consensual = None
     for i in xrange(0, k):
         new, new_estimate, new_consensual = iterate(model, dist)
-        if new > best:
+        if new < best:
             best = new
             estimate = new_estimate
             consensual = new_consensual
 
     return estimate, consensual
 
-def ransac_duo(data, dist, k, mk, modelClass=Linear_model):
+def ransac_multi(m, data, dist, k, modelClass=Linear_model, model=None):
+    if not model:
+        model = modelClass(data)
+    ests = []
     cons = []
-    for i in xrange(mk):
-        model, cons = estimate(set(data) - set(cons), dist, k, modelClass)
-    return (model, cons), estimate(set(data) - set(cons), dist, k, modelClass)
-
+    for i in xrange(m):
+        est, cons_new = estimate(None, dist, k, model=model)
+        model.remove(cons_new)
+        ests.append((est, cons_new))
+    return ests