work on the third variant of gridf
[imago.git] / gridf3.py
1 import pickle
2 import matplotlib.pyplot as pyplot
3 import random
4
5 lines = pickle.load(open('lines.pickle'))
6
7 from src.intrsc import intersections_from_angl_dist
8 import src.linef as linef
9 import src.ransac as ransac
10
11 points = intersections_from_angl_dist(lines, (520, 390))
12
13 pyplot.scatter(*zip(*sum(points, [])))
14
15 def plot_line(line, c):
16     points = linef.line_from_angl_dist(line, (520, 390))
17     pyplot.plot(*zip(*points), color=c)
18
19 def plot_line_g((a, b, c), max_x):
20     find_y = lambda x: - (c + a * x) / b
21     pyplot.plot([0, max_x], [find_y(0), find_y(max_x)], color='b')
22
23 class Diagonal_model:
24     def __init__(self, data):
25         self.data = [p for p in sum(data, []) if p]
26         self.lines = data
27         self.gen = self.initial_g()
28
29     def initial_g(self):
30         l1, l2 = random.sample(self.lines, 2)
31         for i in xrange(len(l1)):
32             for j in xrange(len(l2)):
33                 if i == j:
34                     continue
35                 if l1[i] and l2[j]:
36                     yield (l1[i], l2[j])
37
38     def initial(self):
39         try:
40             return self.gen.next()
41         except StopIteration:
42             self.gen = self.initial_g()
43             return self.gen.next()
44
45     def get(self, sample):
46         if len(sample) == 2:
47             return ransac.points_to_line(*sample)
48         else:
49             return ransac.least_squares(sample)
50
51 def intersection((a1, b1, c1), (a2, b2, c2)):
52     delim = float(a1 * b2 - b1 * a2)
53     x = (b1 * c2 - c1 * b2) / delim
54     y = (c1 * a2 - a1 * c2) / delim
55     return x, y
56
57
58
59 while True:
60     line1, cons = ransac.estimate(points, 2, 800, Diagonal_model)
61     points2 = map(lambda l: [(p if not p in cons else None) for p in l], points)
62     line2, cons2 = ransac.estimate(points2, 2, 800, Diagonal_model)
63     center = intersection(line1, line2)
64     
65
66     plot_line_g(line1, 520)
67     plot_line_g(line2, 520)
68     pyplot.scatter(*zip(*sum(points, [])))
69     pyplot.scatter([center[0]], [center[1]], color='r')
70     pyplot.xlim(0, 520)
71     pyplot.ylim(0, 390)
72     pyplot.show()
73
74 #map(lambda l: plot_line(l, 'g'), sum(lines, []))
75
76 pyplot.show()
77
78