#!/usr/bin/env python

"""Go image recognition"""

import sys
import os
import math
import argparse
from operator import itemgetter

try:
    import Image, ImageDraw
except ImportError, msg:
    print >> sys.stderr, msg
    sys.exit(1)

import im_debug
import linef
import manual

def main():
    """Analyse a Go board image and print the position it shows.

    Command-line entry point.  The image named on the command line is
    opened, converted to RGB, shrunk to at most ``-w`` pixels wide, and its
    grid lines are located either automatically (``linef.find_lines``) or
    interactively (``manual.find_lines`` with ``-m``).  Every grid
    intersection is classified with ``stone_color`` and printed, one text
    row per grid line: 'B' black stone, 'W' white stone, '.' empty.

    Returns:
        0 on success; 1 if the image cannot be read or the user quits
        manual grid selection.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('file', metavar='file', nargs=1,
                        help="image to analyse")
    parser.add_argument('-w', type=int, default=640,
                        help="scale image to the specified width before analysis")
    parser.add_argument('-m', '--manual', dest='manual_mode', action='store_true',
                        help="manual grid selection")
    parser.add_argument('-d', '--debug', dest='show_all', action='store_true',
                        help="show every step of the computation")
    parser.add_argument('-s', '--save', dest='saving', action='store_true',
                        help="save images instead of displaying them")
    parser.add_argument('-v', '--verbose', dest='verbose', action='store_true',
                        help="report progress")
    args = parser.parse_args()

    show_all = args.show_all
    verbose = args.verbose
    path = args.file[0]

    try:
        image = Image.open(path)
        # Image.open() is lazy; force decoding here so truncated/corrupt
        # files are reported cleanly instead of failing later.
        image.load()
    except IOError as msg:
        sys.stderr.write("%s\n" % msg)
        return 1
    # stone_color() averages three colour channels and the debug overlay
    # draws with an RGB fill, so normalise every mode ('P', 'L', 'RGBA', ...)
    # to RGB, not just palette images.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    width, height = image.size
    if width > args.w:
        # keep the aspect ratio; never let the height collapse to zero
        new_height = max(1, int(float(args.w) / width * height))
        image = image.resize((args.w, new_height), Image.ANTIALIAS)

    do_something = im_debug.show
    if args.saving:
        # splitext copes with extensions of any length (".jpeg", ".tiff", none)
        stem = os.path.splitext(path)[0]
        do_something = imsave("saved/" + stem + "_" +
                              str(image.size[0]) + "/").save

    if args.manual_mode:
        try:
            lines = manual.find_lines(image)
        except manual.UserQuitError:
            #TODO ask user to try again
            return 1
    else:
        lines = linef.find_lines(image, show_all, do_something, verbose)

    intersections = intersections_from_angl_dist(lines, image.size)

    for row in intersections:
        print(' '.join(stone_color(image, point) for point in row))

    if show_all:
        # the overlay is only needed for display, so build it only then
        overlay = image.copy()
        draw = ImageDraw.Draw(overlay)
        for row in intersections:
            for point in row:
                draw.point(point, fill=(120, 255, 120))
        do_something(overlay, "intersections")

    return 0

def _pixel_brightness(pixel):
    """Return the mean intensity of one PIL pixel value.

    Single-band images give a plain number, which is returned as-is.
    Multi-band pixels are tuples: the first three bands are averaged (so the
    alpha band of 'RGBA' is ignored) and a two-band 'LA' pixel uses its
    first (luminance) band only.
    """
    if not isinstance(pixel, (tuple, list)):
        return float(pixel)
    bands = pixel[:3] if len(pixel) >= 3 else pixel[:1]
    return float(sum(bands)) / len(bands)


def stone_color(image, point):
    """Classify a board point as black stone, white stone or empty.

    The brightness of the 5x5 pixel window centred on ``point`` is averaged
    and compared with fixed thresholds.  Window pixels that fall outside the
    image are skipped AND excluded from the average, so points on the image
    border are not biased towards black.

    Args:
        image: PIL image; must provide ``size`` and ``getpixel``.
        point: ``(x, y)`` pixel coordinates of the grid intersection.

    Returns:
        'B' if the mean brightness is below 55, 'W' if it is 200 or more,
        '.' otherwise (also '.' when no window pixel lies inside the image).
    """
    x, y = point
    width, height = image.size
    total = 0.
    count = 0
    for i in range(-2, 3):
        for j in range(-2, 3):
            px, py = x + i, y + j
            # Explicit bounds check rather than catching IndexError: recent
            # Pillow versions wrap negative coordinates to the opposite edge
            # instead of raising, which would sample the wrong side.
            if 0 <= px < width and 0 <= py < height:
                total += _pixel_brightness(image.getpixel((px, py)))
                count += 1
    if not count:
        return '.'
    mean = total / count
    if mean < 55:
        return 'B'
    elif mean < 200:
        return '.'
    else:
        return 'W'

class imsave():
    """Write successive images to a directory as numbered JPEG files.

    Files are named ``<saving_dir>00.jpg``, ``<saving_dir>01.jpg``, ...;
    ``saving_dir`` is used as a plain string prefix, so it should end with a
    path separator.  The directory is created on first use if missing.
    """

    def __init__(self, saving_dir):
        self.saving_dir = saving_dir   # filename prefix / target directory
        self.saving_num = 0            # index of the next file to write

    def save(self, image, title=''):
        """Save ``image`` as the next numbered JPEG and advance the counter.

        ``title`` is accepted so this method can stand in for a display
        callback that is called with a title; it is ignored here.
        """
        directory = self.saving_dir
        if not os.path.isdir(directory):
            os.makedirs(directory)
        image.save(directory + "%02d.jpg" % self.saving_num, 'JPEG')
        self.saving_num += 1

def combine(image1, image2):
    """Return the coordinates of pixels that are set in both images.

    A pixel counts as set when its value is truthy (non-zero for
    single-band images; note that any tuple from a multi-band image is
    truthy).  Only the area the two images share is scanned, so images of
    different sizes no longer raise IndexError.

    Returns:
        List of ``(x, y)`` tuples ordered by x, then y.
    """
    pixels1 = image1.load()
    pixels2 = image2.load()
    # restrict the scan to the overlap; outside it a pixel cannot be on both
    width = min(image1.size[0], image2.size[0])
    height = min(image1.size[1], image2.size[1])
    return [(x, y)
            for x in range(width)
            for y in range(height)
            if pixels1[x, y] and pixels2[x, y]]

def intersections_from_angl_dist(lines, size):
    """Compute the pixel intersections of two families of grid lines.

    Each line is an ``(angle, dist)`` pair describing the points with
    ``x*sin(angle) - y*cos(angle) == dist``, in coordinates centred on the
    middle of the image.  This is the same line as
    ``y = tan(angle)*x - dist/cos(angle)``, but it avoids dividing by
    ``cos(angle)``.  That division makes ``y`` cancel catastrophically for
    near-vertical lines.

    Args:
        lines: pair ``(family0, family1)`` of sequences of ``(angle, dist)``.
        size: ``(width, height)`` of the image in pixels.

    Returns:
        One list per line of ``lines[1]`` (sorted by dist).  Each list holds
        the integer pixel coordinates ``(x, y)`` where that line crosses the
        lines of ``lines[0]`` (also taken in dist order) inside the image.
        Pairs that are within 0.4 rad of parallel are skipped.
    """
    # same expression as before (integer division on Python 2) so the
    # centre offset is unchanged
    half_w = size[0] / 2
    half_h = size[1] / 2
    # |sin(a2 - a1)| <= sin(0.4) means the lines are within 0.4 rad of
    # parallel, whether the angles differ by ~0 or by ~pi.
    min_sin = math.sin(0.4)
    columns = sorted(lines[0], key=itemgetter(1))  # loop-invariant
    intersections = []
    for (angl1, dist1) in sorted(lines[1], key=itemgetter(1)):
        sin1 = math.sin(angl1)
        cos1 = math.cos(angl1)
        row = []
        for (angl2, dist2) in columns:
            # determinant of the 2x2 system, equal to sin(angl2 - angl1)
            det = math.sin(angl2 - angl1)
            if abs(det) <= min_sin:
                continue
            # Cramer's rule
            x = (dist2 * cos1 - dist1 * math.cos(angl2)) / det
            y = (dist2 * sin1 - dist1 * math.sin(angl2)) / det
            if -half_w < x < half_w and -half_h < y < half_h:
                row.append((int(x + half_w), int(y + half_h)))
        intersections.append(row)
    return intersections

if __name__ == '__main__':
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        # Report on stderr (stdout carries the printed board) and exit with
        # the conventional 128 + SIGINT status, so an aborted run is not
        # mistaken for a successful one (a bare sys.exit() returns 0).
        sys.stderr.write("Interrupted.\n")
        sys.exit(130)
