refactor do_something in linef
[imago.git] / imago_pack / intrsc.py
index ce89cb2..d577205 100644 (file)
@@ -1,10 +1,11 @@
-"""Imago intersections module"""
+"""Imago intersections module."""
 
 from math import cos, tan, pi
 from operator import itemgetter
 
 import ImageDraw
 
 
 from math import cos, tan, pi
 from operator import itemgetter
 
 import ImageDraw
 
+import filters
 import k_means
 import output
 
 import k_means
 import output
 
@@ -36,10 +37,14 @@ def board(image, lines, show_all, do_something):
                 draw.point((x , y), fill=(120, 255, 120))
         do_something(image_g, "intersections")
 
                 draw.point((x , y), fill=(120, 255, 120))
         do_something(image_g, "intersections")
 
+    image_c = filters.color_enhance(image)
+    if show_all:
+        do_something(image_c, "white balance")
+    
     board_raw = []
     
     for line in intersections:
     board_raw = []
     
     for line in intersections:
-        board_raw.append([stone_color_raw(image, intersection) for intersection in
+        board_raw.append([stone_color_raw(image_c, intersection) for intersection in
                       line])
     board_raw = sum(board_raw, [])
 
                       line])
     board_raw = sum(board_raw, [])
 
@@ -50,12 +55,16 @@ def board(image, lines, show_all, do_something):
     if show_all:
         import matplotlib.pyplot as pyplot
         pyplot.scatter(luma, saturation, 
     if show_all:
         import matplotlib.pyplot as pyplot
         pyplot.scatter(luma, saturation, 
-                       color=[(s[2][0]/255., s[2][1]/255., s[2][2]/255., 1.) 
-                              for s in board_raw])
+                       color=[(s[2][0]/255.,
+                               s[2][1]/255.,
+                               s[2][2]/255., 1.) 
+                                   for s in board_raw])
+        pyplot.xlim(0,1)
+        pyplot.ylim(0,1)
         pyplot.show()
 
     clusters = k_means.cluster(3, 2,zip(zip(luma, saturation), range(len(luma))),
         pyplot.show()
 
     clusters = k_means.cluster(3, 2,zip(zip(luma, saturation), range(len(luma))),
-                               [[0., 0.], [0.5, 0.5], [1., 1.]])
+                               [[0., 0.5], [0.5, 0.5], [1., 0.5]])
    #clusters.sort(key=mean_luma)
 
     if show_all:
    #clusters.sort(key=mean_luma)
 
     if show_all:
@@ -65,6 +74,8 @@ def board(image, lines, show_all, do_something):
                                                  color=(0,1,0,1))
         pyplot.scatter([d[0][0] for d in clusters[2]], [d[0][1] for d in clusters[2]],
                                                  color=(0,0,1,1))
                                                  color=(0,1,0,1))
         pyplot.scatter([d[0][0] for d in clusters[2]], [d[0][1] for d in clusters[2]],
                                                  color=(0,0,1,1))
+        pyplot.xlim(0,1)
+        pyplot.ylim(0,1)
         pyplot.show()
 
     clusters[0] = [(p[1], 'B') for p in clusters[0]]
         pyplot.show()
 
     clusters[0] = [(p[1], 'B') for p in clusters[0]]
@@ -89,6 +100,7 @@ def board(image, lines, show_all, do_something):
     return output.Board(19, board_r)
 
 def mean_luma(cluster):
     return output.Board(19, board_r)
 
 def mean_luma(cluster):
+    """Return mean luma of the *cluster* of points."""
     return sum(c[0][0] for c in cluster) / float(len(cluster))
 
 def intersections_from_angl_dist(lines, size, get_all=True):
     return sum(c[0][0] for c in cluster) / float(len(cluster))
 
 def intersections_from_angl_dist(lines, size, get_all=True):
@@ -109,25 +121,28 @@ def intersections_from_angl_dist(lines, size, get_all=True):
     return intersections
    
 def RGBtoSat(c):
     return intersections
    
 def RGBtoSat(c):
+    """Using the HSI color model."""
     max_diff = max(c) - min(c)
     if max_diff == 0:
         return 0
     else:
     max_diff = max(c) - min(c)
     if max_diff == 0:
         return 0
     else:
-        #TODO simplify this
-        return max_diff / float(255. - abs(max(c) + min(c) - 255))
+        return 1. - ((3. * min(c)) / sum(c)) 
 
 def stone_color_raw(image, (x, y)):
     """Given image and coordinates, return stone color."""
 
 def stone_color_raw(image, (x, y)):
     """Given image and coordinates, return stone color."""
+    size = 3 
     suma = []
     suma = []
-    for i in range(-2, 3):
-        for j in range(-2, 3):
+    t = 0
+    for i in range(-size, size + 1):
+        for j in range(-size, size + 1):
             try:
                 suma.append(image.getpixel((x + i, y + j)))
             try:
                 suma.append(image.getpixel((x + i, y + j)))
+                t += 1
             except IndexError:
                 pass
             except IndexError:
                 pass
-    luma = sum([0.30 * sum(s[0] for s in suma) / 25., 0.59 * sum(s[1] for s in suma) / 25.
-            0.11 * sum(s[2] for s in suma) / 25.]) / 255.
-    saturation = sum(RGBtoSat(s) for s in suma) / (25. * 255.)
-    color = [sum(s[0] for s in suma) / 25., sum(s[1] for s in suma) / 25.,
-             sum(s[2] for s in suma) / 25.]
+    luma = sum([0.30 * sum(s[0] for s in suma) / t, 0.59 * sum(s[1] for s in suma) / t
+            0.11 * sum(s[2] for s in suma) / t]) / 255.
+    saturation = sum(RGBtoSat(s) for s in suma) / t
+    color = [sum(s[0] for s in suma) / t, sum(s[1] for s in suma) / t,
+             sum(s[2] for s in suma) / t]
     return luma, saturation, color
     return luma, saturation, color