color quantization experiment
[imago.git] / src / gridf2.py
index 4f276fd..0988508 100644 (file)
@@ -32,7 +32,10 @@ def nearest2(lines, point):
 
 def generate_models(sgrid, lh):
     for f in [0, 1, 2, 3, 5, 7, 8, 11, 15, 17]:
-        grid = gm.fill(sgrid[0], sgrid[1], lh , f)
+        try:
+            grid = gm.fill(sgrid[0], sgrid[1], lh , f)
+        except ZeroDivisionError:
+            continue
         grid = [sgrid[0]] + grid + [sgrid[1]]
         for s in xrange(17 - f):
             grid = [gm.expand_left(grid, lh)] + grid
@@ -48,11 +51,11 @@ def score(grid, lines, limit):
     ds = map(dsg, grid)
     d = max(map(abs, ds))
     if d > limit:
-        return 999999
+        return float("inf")
     score = 0
     for line in lines:
         s = min(map(lambda g: abs(line[1] - g), ds))
-        s = min(s, 2)
+        s = min(s, 4)
         score += s
 
     return score
@@ -62,6 +65,28 @@ def lines2grid(lines, perp_l):
     f = lambda l: (gm.intersection(b1, l), gm.intersection(b2, l))
     return map(f, lines)
 
+def pertubations(grid, middle_l):
+    corners = [grid[0], grid[-1]]
+    for l in [0, 1]:
+        for c in [0, 1]:
+            for s in [0, 1]:
+                for x in [-1, 1]:
+                    sgrid = corners
+                    sgrid[l] = list(sgrid[l])
+                    sgrid[l][c] = list(sgrid[l][c])
+                    sgrid[l][c][s] += x
+                    try:
+                        middle = middle_l(sgrid)
+                        lh = (gm.intersection(sgrid[0], middle), gm.intersection(sgrid[1], middle))
+                        sgrid = ([sgrid[0]] +
+                             gm.fill(sgrid[0], sgrid[1], lh, 17) +
+                             [sgrid[1]])
+                    except ZeroDivisionError:
+                        continue
+
+                    yield sgrid
+
+
 def test(): 
     import pickle
     import matplotlib.pyplot as pyplot
@@ -84,32 +109,59 @@ def test():
 
     while True:
         t0 = time.time()
-        sc1, gridv = 999999, None
-        for i in range(250):
-            l1s = random.sample(l1, 2)
-            l1s.sort(key=lambda l: l[1])
-            sgrid = map(lambda l:linef.line_from_angl_dist(l, size), l1s) 
-            middle = lambda m: ((m, 0),(m, 390))
-            middle = middle(gm.intersection((sgrid[0][0], sgrid[1][1]), 
-                                            (sgrid[0][1], sgrid[1][0]))[0])
-            lh = (gm.intersection(sgrid[0], middle), gm.intersection(sgrid[1], middle))
-            sc1_n, gridv_n = min(map(lambda g: (score(g, l1, 210), g), generate_models(sgrid, lh)))
-            if sc1_n < sc1:
-                sc1, gridv = sc1_n, gridv_n
-
-        sc2, gridh = 999999, None
-        for i in range(250):
-            l2s = random.sample(l2, 2)
-            l2s.sort(key=lambda l: l[1])
-            sgrid = map(lambda l:linef.line_from_angl_dist(l, size), l2s) 
-            middle = lambda m: ((0, m),(520, m))
-            middle = middle(gm.intersection((sgrid[0][0], sgrid[1][1]), 
-                                            (sgrid[0][1], sgrid[1][0]))[1])
-            lh = (gm.intersection(sgrid[0], middle), gm.intersection(sgrid[1], middle))
-            sc2_n, gridh_n = min(map(lambda g: (score(g, l2, 275), g), generate_models(sgrid, lh)))
-            if sc2_n < sc2:
-                sc2, gridh = sc2_n, gridh_n
-        gridv, gridh = lines2grid(gridv, gridh), lines2grid(gridh, gridv)
+        sc1, gridv = float("inf"), None
+        sc2, gridh = float("inf"), None
+        sc1_n, sc2_n = float("inf"), float("inf")
+        gridv_n, gridh_n = None, None
+        for k in range(50):
+            for i in range(5):
+                l1s = random.sample(l1, 2)
+                l1s.sort(key=lambda l: l[1])
+                sgrid = map(lambda l:linef.line_from_angl_dist(l, size), l1s)
+                middle_l1 = lambda m: ((m, 0),(m, 390))
+                middle_l = lambda sgrid: middle_l1(gm.intersection((sgrid[0][0], sgrid[1][1]), 
+                                                (sgrid[0][1], sgrid[1][0]))[0])
+                middle = middle_l(sgrid)
+                lh = (gm.intersection(sgrid[0], middle), gm.intersection(sgrid[1], middle))
+                sc1_n, gridv_n = min(map(lambda g: (score(g, l1, 210), g), generate_models(sgrid, lh)))
+
+                p = True
+                while p:
+                    p = False
+                    for ng in pertubations(gridv_n, middle_l): # TODO randomize
+                        sc = score(ng, l1, 210)
+                        if sc < sc1_n:
+                            sc1_n, gridv_n = sc, ng
+                            p = True
+
+                if sc1_n < sc1:
+                    sc1, gridv = sc1_n, gridv_n
+
+            for i in range(5):
+                l2s = random.sample(l2, 2)
+                l2s.sort(key=lambda l: l[1])
+                sgrid = map(lambda l:linef.line_from_angl_dist(l, size), l2s) 
+                middle_l1 = lambda m: ((0, m),(520, m))
+                middle_l = lambda sgrid: middle_l1(gm.intersection((sgrid[0][0], sgrid[1][1]), 
+                                                (sgrid[0][1], sgrid[1][0]))[1])
+                middle = middle_l(sgrid)
+                lh = (gm.intersection(sgrid[0], middle), gm.intersection(sgrid[1], middle))
+                sc2_n, gridh_n = min(map(lambda g: (score(g, l2, 275), g), generate_models(sgrid, lh)))
+
+                p = True
+                while p:
+                    p = False
+                    for ng in pertubations(gridh_n, middle_l): # TODO randomize
+                        sc = score(ng, l2, 275)
+                        if sc < sc2_n:
+                            sc2_n, gridh = sc, ng
+                            p = True
+
+                if sc2_n < sc2:
+                    sc2, gridh = sc2_n, gridh_n
+
+            gridv, gridh = lines2grid(gridv, gridh), lines2grid(gridh, gridv)
+
         print time.time() - t0
         print sc1, sc2
 
@@ -129,7 +181,7 @@ def test():
 def find(lines, size, l1, l2, bounds, hough, show_all, do_something, logger):
     logger("finding the grid")
     l1, l2 = lines
-    sc1, gridv = 999999, None
+    sc1, gridv = float("inf"), None
     for i in range(250):
         l1s = random.sample(l1, 2)
         l1s.sort(key=lambda l: l[1])
@@ -142,7 +194,7 @@ def find(lines, size, l1, l2, bounds, hough, show_all, do_something, logger):
         if sc1_n < sc1:
             sc1, gridv = sc1_n, gridv_n
 
-    sc2, gridh = 999999, None
+    sc2, gridh = float("inf"), None
     for i in range(250):
         l2s = random.sample(l2, 2)
         l2s.sort(key=lambda l: l[1])