import pickle
import matplotlib.pyplot as pyplot
from math import sqrt
import random
import sys
import time

import src.linef as linef
import src.gridf as gridf
from src.manual import lines as g_grid

import new_geometry as gm

# Fixed seed so the random.sample alternatives (commented out in the main
# loop below) pick reproducible line pairs from run to run.
random.seed(12345)

def plot_line(line, c):
    """Draw *line* on the current pyplot axes in colour *c*.

    *line* is whatever linef.line_from_angl_dist accepts; it is turned into a
    pair of (x, y) endpoints for a 520x390 frame and drawn as one segment.
    """
    endpoints = linef.line_from_angl_dist(line, (520, 390))
    xs, ys = zip(*endpoints)
    pyplot.plot(xs, ys, color=c)

def dst(point, line):
    """Return the perpendicular distance from *point* to *line*.

    point: an (x, y) pair.
    line: general-form coefficients (a, b, c) of a*x + b*y + c = 0, as
        produced by points_to_line.

    Raises ZeroDivisionError for a degenerate line where a == b == 0
    (e.g. points_to_line called with two identical points).
    """
    # Unpack in the body rather than in the signature: tuple parameters are
    # Python-2-only syntax (PEP 3113). Positional callers are unaffected.
    x, y = point
    a, b, c = line
    return abs(a * x + b * y + c) / sqrt(a * a + b * b)

def points_to_line(p1, p2):
    """Return general-form coefficients (a, b, c) of the line through p1, p2.

    Both points are (x, y) pairs. The result satisfies a*x + b*y + c == 0
    for each of them. Identical points yield (0, 0, 0), a degenerate line
    that dst() cannot handle.
    """
    # Unpack in the body: tuple parameters are Python-2-only (PEP 3113).
    x1, y1 = p1
    x2, y2 = p2
    return (y2 - y1, x1 - x2, x2 * y1 - x1 * y2)

def to_general(line):
    """Convert *line* to general-form (a, b, c) coefficients.

    The line is first turned into two endpoints for a 520x390 frame with
    linef.line_from_angl_dist, then passed through points_to_line.
    """
    first, second = linef.line_from_angl_dist(line, (520, 390))
    return points_to_line(first, second)

def nearest(lines, point):
    """Return the smallest distance from *point* to any line in *lines*.

    *lines* holds general-form (a, b, c) coefficients. Raises ValueError
    when *lines* is empty.
    """
    return min(dst(point, line) for line in lines)

def nearest2(lines, point):
    """Return the smallest distance from *point* to any line in *lines*.

    Unlike nearest(), each entry of *lines* is a pair of points, converted
    with points_to_line before measuring. Raises ValueError when *lines* is
    empty.
    """
    distances = [dst(point, points_to_line(*segment)) for segment in lines]
    return min(distances)

# Frame size as (width, height): passed to linef.line_from_angl_dist, and the
# plots below use the same 0..520 x 0..390 axis limits.
size = (520, 390)

def generate_models(sgrid, lh):
    """Yield candidate grids that contain the two boundary lines in *sgrid*.

    For each fill count f in a fixed list:
      * gm.fill places inner lines between sgrid[0] and sgrid[1]
        (presumably f of them -- confirm against gm.fill);
      * the grid is padded on the left with 17 - f lines via gm.expand_left,
        and that grid is yielded;
      * the window is then slid right 17 - f times (drop the leftmost line,
        append one from gm.expand_right), yielding after each step.
    That gives 18 - f candidates per fill count. Every yielded list is a
    fresh object that this generator does not mutate afterwards.
    """
    fill_counts = (0, 1, 2, 3, 5, 7, 8, 11, 15, 17)
    for fill_count in fill_counts:
        steps = 17 - fill_count
        inner = gm.fill(sgrid[0], sgrid[1], lh, fill_count)
        model = [sgrid[0]] + inner + [sgrid[1]]
        for _ in range(steps):
            model = [gm.expand_left(model, lh)] + model
        yield model
        for _ in range(steps):
            # Slicing builds a new list, so the model just yielded is left intact.
            model = model[1:]
            model.append(gm.expand_right(model, lh))
            yield model

def score(grid, points, limit, tol=2):
    """Score a candidate *grid* by how many *points* lie close to it.

    grid: sequence of lines, each a pair of (x, y) points.
    points: iterable of (x, y) points.
    limit: if any grid line lies farther than *limit* from the origin (0, 0),
        the grid is rejected and the score is 0.
    tol: a point counts when its distance to the nearest grid line is
        <= tol. The default of 2 matches the previously hard-coded value.

    Returns the number of matching points as an int. Raises ValueError if
    *grid* is empty.
    """
    # Convert each grid line to general form once, instead of once per point
    # (the old nearest2 call) and twice for the origin check (the old
    # `grid + grid`, which only duplicated the list given to max()).
    general = [points_to_line(*line) for line in grid]
    if max(dst((0, 0), line) for line in general) > limit:
        return 0
    return sum(1 for p in points if nearest(general, p) <= tol)

def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file.

    The file is opened in binary mode, as the pickle module requires.
    NOTE: unpickling can execute arbitrary code, so only load trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Edge points; each is unpacked as (x, y) by dst() in the filtering below.
points = _load_pickle('edges.pickle')

lines = _load_pickle('lines.pickle')

r_lines = _load_pickle('r_lines.pickle')

# Debug: scatter every point of r_lines.
#pyplot.scatter(*zip(*sum(r_lines, [])))
#pyplot.show()

# Exactly two line families are expected (the unpacking raises otherwise).
l1, l2 = lines

# General-form (a, b, c) coefficients of every detected line. This is a list,
# not map(): it is scanned once per point below, and a Python 3 map iterator
# would be exhausted after the first point.
lines_general = [to_general(l) for l in sum(lines, [])]
# Edge points within distance 2 of some detected line (used for plotting only).
near_points = [p for p in points if nearest(lines_general, p) <= 2]

def _fit_grid(bounds, axis, limit):
    """Return the best-scoring (score, grid) model between two boundary lines.

    bounds: two lines in the form accepted by linef.line_from_angl_dist.
    axis: 0 to use a vertical middle line (x = const); 1 to use a horizontal
        one (y = const). The constant is that coordinate of the crossing
        point of the two segments joining opposite endpoints of the
        boundary lines.
    limit: forwarded to score(); grids with a line farther than this from
        the origin score 0.
    """
    sgrid = [linef.line_from_angl_dist(l, size) for l in bounds]
    crossing = gm.intersection((sgrid[0][0], sgrid[1][1]),
                               (sgrid[0][1], sgrid[1][0]))[axis]
    if axis == 0:
        middle = ((crossing, 0), (crossing, size[1]))
    else:
        middle = ((0, crossing), (size[0], crossing))
    lh = (gm.intersection(sgrid[0], middle), gm.intersection(sgrid[1], middle))
    return max((score(g, points, limit), g) for g in generate_models(sgrid, lh))


def _plot_grid(grid, color='b'):
    """Draw every line of *grid* (each a pair of points) in *color*."""
    # An explicit loop: a map() used only for its side effects draws nothing
    # under Python 3, where map is lazy.
    for line in grid:
        pyplot.plot(*zip(*line), color=color)


# With the fixed first/last boundary choice, every iteration repeats the same
# fit. The loop only varies when the random.sample alternatives are enabled.
while True:
    t0 = time.time()
    #l1s = random.sample(l1, 2)
    l1s = [l1[0], l1[-1]]
    # Order the boundaries by their second component (presumably the
    # distance of an (angle, distance) line -- confirm in linef).
    l1s.sort(key=lambda l: l[1])
    sc, grid = _fit_grid(l1s, 0, 400)
    _plot_grid(grid)
    #l2s = random.sample(l2, 2)
    l2s = [l2[0], l2[-1]]
    l2s.sort(key=lambda l: l[1])
    sc, grid = _fit_grid(l2s, 1, 530)
    print(time.time() - t0)

    pyplot.scatter(*zip(*near_points))
    _plot_grid(grid)
    for bound in l2s + l1s:
        plot_line(bound, 'r')
    pyplot.xlim(0, size[0])
    pyplot.ylim(0, size[1])
    pyplot.show()

sys.exit()

# NOTE(review): the code below is unreachable, because the `while True` loop
# above never breaks. It is kept as a quick way to draw every detected line
# in green.
for family in (lines[0], lines[1]):
    for line in family:
        plot_line(line, 'g')

pyplot.xlim(0, size[0])
pyplot.ylim(0, size[1])
pyplot.show()
