fix grid orientation
[imago.git] / src / gridf2.py
1 from math import sqrt
2 import random
3 import sys
4
5 import linef as linef
6 import gridf as gridf
7 from manual import lines as g_grid
8 from geometry import l2ad
9 import new_geometry as gm
10
11
12 def plot_line(line, c):
13     points = linef.line_from_angl_dist(line, (520, 390))
14     pyplot.plot(*zip(*points), color=c)
15
16 def dst((x, y), (a, b, c)):
17     return abs(a * x + b * y + c) / sqrt(a*a+b*b)
18
19 def points_to_line((x1, y1), (x2, y2)):
20     return (y2 - y1, x1 - x2, x2 * y1 - x1 * y2)
21
22 def to_general(line):
23     points = linef.line_from_angl_dist(line, (520, 390))
24     return points_to_line(*points)
25
26 def nearest(lines, point):
27     return min(map(lambda l: dst(point, l), lines))
28
29 def nearest2(lines, point):
30     return min(map(lambda l: dst(point, points_to_line(*l)), lines))
31
32
33 def generate_models(sgrid, lh):
34     for f in [0, 1, 2, 3, 5, 7, 8, 11, 15, 17]:
35         try:
36             grid = gm.fill(sgrid[0], sgrid[1], lh , f)
37         except ZeroDivisionError:
38             continue
39         grid = [sgrid[0]] + grid + [sgrid[1]]
40         for s in xrange(17 - f):
41             grid = [gm.expand_left(grid, lh)] + grid
42         yield grid
43         for i in xrange(17 - f):
44             grid = grid[1:]
45             grid.append(gm.expand_right(grid, lh))
46             yield grid
47
48 def score(grid, lines, limit):
49     dst = lambda (a, b, c): (a * 260 + b * 195 + c) / sqrt(a*a+b*b)
50     dsg = lambda l: dst(points_to_line(*l))
51     ds = map(dsg, grid)
52     d = max(map(abs, ds))
53     if d > limit:
54         return 999999
55     score = 0
56     for line in lines:
57         s = min(map(lambda g: abs(line[1] - g), ds))
58         s = min(s, 4)
59         score += s
60
61     return score
62
63 def lines2grid(lines, perp_l):
64     b1, b2 = perp_l[0], perp_l[-1]
65     f = lambda l: (gm.intersection(b1, l), gm.intersection(b2, l))
66     return map(f, lines)
67
68 def pertubations(grid, middle_l):
69     corners = [grid[0], grid[-1]]
70     for l in [0, 1]:
71         for c in [0, 1]:
72             for s in [0, 1]:
73                 for x in [-1, 1]:
74                     sgrid = corners
75                     sgrid[l] = list(sgrid[l])
76                     sgrid[l][c] = list(sgrid[l][c])
77                     sgrid[l][c][s] += x
78                     try:
79                         middle = middle_l(sgrid)
80                         lh = (gm.intersection(sgrid[0], middle), gm.intersection(sgrid[1], middle))
81                         sgrid = ([sgrid[0]] +
82                              gm.fill(sgrid[0], sgrid[1], lh, 17) +
83                              [sgrid[1]])
84                     except ZeroDivisionError:
85                         continue
86
87                     yield sgrid
88
89
90 def test(): 
91     import pickle
92     import matplotlib.pyplot as pyplot
93     import time
94
95     size = (520, 390)
96     points = pickle.load(open('edges.pickle'))
97
98     lines = pickle.load(open('lines.pickle'))
99
100     r_lines = pickle.load(open('r_lines.pickle'))
101
102     #pyplot.scatter(*zip(*sum(r_lines, [])))
103     #pyplot.show()
104
105     l1, l2 = lines
106
107     lines_general = map(to_general, sum(lines, []))
108     near_points = [p for p in points if nearest(lines_general, p) <= 2]
109
110     while True:
111         t0 = time.time()
112         sc1, gridv = 999999, None
113         sc2, gridh = 999999, None
114         sc1_n, sc2_n = 999999, 999999
115         gridv_n, gridh_n = None, None
116         for k in range(50):
117             for i in range(5):
118                 l1s = random.sample(l1, 2)
119                 l1s.sort(key=lambda l: l[1])
120                 sgrid = map(lambda l:linef.line_from_angl_dist(l, size), l1s)
121                 middle_l1 = lambda m: ((m, 0),(m, 390))
122                 middle_l = lambda sgrid: middle_l1(gm.intersection((sgrid[0][0], sgrid[1][1]), 
123                                                 (sgrid[0][1], sgrid[1][0]))[0])
124                 middle = middle_l(sgrid)
125                 lh = (gm.intersection(sgrid[0], middle), gm.intersection(sgrid[1], middle))
126                 sc1_n, gridv_n = min(map(lambda g: (score(g, l1, 210), g), generate_models(sgrid, lh)))
127
128                 p = True
129                 while p:
130                     p = False
131                     for ng in pertubations(gridv_n, middle_l): # TODO randomize
132                         sc = score(ng, l1, 210)
133                         if sc < sc1_n:
134                             sc1_n, gridv_n = sc, ng
135                             p = True
136
137                 if sc1_n < sc1:
138                     sc1, gridv = sc1_n, gridv_n
139
140             for i in range(5):
141                 l2s = random.sample(l2, 2)
142                 l2s.sort(key=lambda l: l[1])
143                 sgrid = map(lambda l:linef.line_from_angl_dist(l, size), l2s) 
144                 middle_l1 = lambda m: ((0, m),(520, m))
145                 middle_l = lambda sgrid: middle_l1(gm.intersection((sgrid[0][0], sgrid[1][1]), 
146                                                 (sgrid[0][1], sgrid[1][0]))[1])
147                 middle = middle_l(sgrid)
148                 lh = (gm.intersection(sgrid[0], middle), gm.intersection(sgrid[1], middle))
149                 sc2_n, gridh_n = min(map(lambda g: (score(g, l2, 275), g), generate_models(sgrid, lh)))
150
151                 p = True
152                 while p:
153                     p = False
154                     for ng in pertubations(gridh_n, middle_l): # TODO randomize
155                         sc = score(ng, l2, 275)
156                         if sc < sc2_n:
157                             sc2_n, gridh = sc, ng
158                             p = True
159
160                 if sc2_n < sc2:
161                     sc2, gridh = sc2_n, gridh_n
162
163             gridv, gridh = lines2grid(gridv, gridh), lines2grid(gridh, gridv)
164
165         print time.time() - t0
166         print sc1, sc2
167
168         pyplot.scatter(*zip(*near_points))
169
170         #map(lambda l: plot_line(l, 'g'), l1 + l2)
171         map(lambda l: pyplot.plot(*zip(*l), color='g'), gridv)
172         map(lambda l: pyplot.plot(*zip(*l), color='g'), gridh)
173         #plot_line(l2s[0], 'r')
174         #plot_line(l2s[1], 'r')
175         #plot_line(l1s[0], 'r')
176         #plot_line(l1s[1], 'r')
177         pyplot.xlim(0, 520)
178         pyplot.ylim(0, 390)
179         pyplot.show()
180
181 def find(lines, size, l1, l2, bounds, hough, show_all, do_something, logger):
182     logger("finding the grid")
183     l1, l2 = lines
184     sc1, gridv = 999999, None
185     for i in range(250):
186         l1s = random.sample(l1, 2)
187         l1s.sort(key=lambda l: l[1])
188         sgrid = map(lambda l:linef.line_from_angl_dist(l, size), l1s) 
189         middle = lambda m: ((m, 0),(m, size[1]))
190         middle = middle(gm.intersection((sgrid[0][0], sgrid[1][1]), 
191                                         (sgrid[0][1], sgrid[1][0]))[0])
192         lh = (gm.intersection(sgrid[0], middle), gm.intersection(sgrid[1], middle))
193         sc1_n, gridv_n = min(map(lambda g: (score(g, l1, size[1] / 2 + 15), g), generate_models(sgrid, lh)))
194         if sc1_n < sc1:
195             sc1, gridv = sc1_n, gridv_n
196
197     sc2, gridh = 999999, None
198     for i in range(250):
199         l2s = random.sample(l2, 2)
200         l2s.sort(key=lambda l: l[1])
201         sgrid = map(lambda l:linef.line_from_angl_dist(l, size), l2s) 
202         middle = lambda m: ((0, m),(size[0], m))
203         middle = middle(gm.intersection((sgrid[0][0], sgrid[1][1]), 
204                                         (sgrid[0][1], sgrid[1][0]))[1])
205         lh = (gm.intersection(sgrid[0], middle), gm.intersection(sgrid[1], middle))
206         sc2_n, gridh_n = min(map(lambda g: (score(g, l2, size[0] / 2 + 15), g), generate_models(sgrid, lh)))
207         if sc2_n < sc2:
208             sc2, gridh = sc2_n, gridh_n
209     gridv, gridh = lines2grid(gridv, gridh), lines2grid(gridh, gridv)
210
211     grid = [gridv, gridh]
212     grid_lines = [[l2ad(l, size) for l in grid[0]], 
213                   [l2ad(l, size) for l in grid[1]]]
214
215     return grid, grid_lines