fix makefile
[imago.git] / src / gridf3.py
index 0697b92..693b8ee 100644 (file)
@@ -12,6 +12,9 @@ from geometry import l2ad
 class GridFittingFailedError(Exception):
     pass
 
+class BadGenError(Exception):
+    pass
+
 def plot_line(line, c, size):
     points = linef.line_from_angl_dist(line, size)
     pyplot.plot(*zip(*points), color=c)
@@ -32,12 +35,16 @@ class Diagonal_model:
                 if l1[i] and l2[j]:
                     yield (l1[i], l2[j])
 
+    def remove(self, data):
+        self.data = list(set(self.data) - set(data))
+
     def initial(self):
         try:
-            return self.gen.next()
+            nxt = self.gen.next()
         except StopIteration:
             self.gen = self.initial_g()
-            return self.gen.next()
+            nxt = self.gen.next()
+        return nxt
 
     def get(self, sample):
         if len(sample) == 2:
@@ -50,11 +57,18 @@ class Diagonal_model:
         score = 0
         a, b, c = est
         dst = lambda (x, y): abs(a * x + b * y + c) / sqrt(a*a+b*b)
+        l1 = None
+        l2 = None
         for p in self.data:
             d = dst(p)
             if d <= dist:
                 cons.append(p)
+                if p.l1 == l1 or p.l2 == l2:
+                    return float("inf"), []
+                else:
+                    l1, l2 = p.l1, p.l2
             score += min(d, dist)
+
         return score, cons
 
 def intersection((a1, b1, c1), (a2, b2, c2)):