gitignore
[imago.git] / src / ransac.py
index 4908421..cc16d2f 100644 (file)
@@ -52,18 +52,20 @@ def iterate(model, distance):
     score = float("inf")
     consensual = model.initial()
     estimate = model.get(consensual)
-    new_score, consensual = model.score(estimate, distance)
-    while (new_score < score):
-        score = new_score
-        try:
-            estimate = model.get(consensual)
-            new_score, consensual = model.score(estimate, distance)
-        except NP.linalg.LinAlgError:
-            pass
+    new_score, new_consensual = model.score(estimate, distance)
+    if new_consensual != []:
+        while (new_score < score):
+            score, consensual = new_score, new_consensual
+            try:
+                estimate = model.get(consensual)
+                new_score, new_consensual = model.score(estimate, distance)
+            except (NP.linalg.LinAlgError):
+                pass
     return score, estimate, consensual
         
-def estimate(data, dist, k, modelClass=Linear_model):
-    model = modelClass(data)
+def estimate(data, dist, k, modelClass=Linear_model, model=None):
+    if not model:
+        model = modelClass(data)
     best = float("inf")
     estimate = None
     consensual = None
@@ -82,3 +84,14 @@ def ransac_duo(data, dist, k, mk, modelClass=Linear_model):
         model, cons = estimate(set(data) - set(cons), dist, k, modelClass)
     return (model, cons), estimate(set(data) - set(cons), dist, k, modelClass)
 
+def ransac_multi(m, data, dist, k, modelClass=Linear_model, model=None):
+    ests = []
+    cons = []
+    for i in xrange(m):
+        est, cons_new = estimate(None, dist, k, model=model)
+        model.remove(cons_new)
+        ests.append(est)
+    return ests
+
+        
+