save figures from matplotlib
[imago.git] / imago_pack / intrsc.py
index 45308b1..1d4dd67 100644 (file)
@@ -1,10 +1,11 @@
-"""Imago intersections module"""
+"""Imago intersections module."""
 
 from math import cos, tan, pi
 from operator import itemgetter
 
 import ImageDraw
 
+import filters
 import k_means
 import output
 
@@ -25,6 +26,7 @@ def dst_sort(lines):
 
 def board(image, lines, show_all, do_something):
     """Compute intersections, find stone colors and return board situation."""
+    # TODO refactor show_all, do_something
     lines = [dst_sort(l) for l in lines]
     intersections = intersections_from_angl_dist(lines, image.size)
 
@@ -36,10 +38,14 @@ def board(image, lines, show_all, do_something):
                 draw.point((x , y), fill=(120, 255, 120))
         do_something(image_g, "intersections")
 
+    image_c = filters.color_enhance(image)
+    if show_all:
+        do_something(image_c, "white balance")
+    
     board_raw = []
     
     for line in intersections:
-        board_raw.append([stone_color_raw(image, intersection) for intersection in
+        board_raw.append([stone_color_raw(image_c, intersection) for intersection in
                       line])
     board_raw = sum(board_raw, [])
 
@@ -49,6 +55,8 @@ def board(image, lines, show_all, do_something):
 
     if show_all:
         import matplotlib.pyplot as pyplot
+        import Image
+        fig = pyplot.figure(figsize=(8, 6))
         pyplot.scatter(luma, saturation, 
                        color=[(s[2][0]/255.,
                                s[2][1]/255.,
@@ -56,13 +64,18 @@ def board(image, lines, show_all, do_something):
                                    for s in board_raw])
         pyplot.xlim(0,1)
         pyplot.ylim(0,1)
-        pyplot.show()
+        fig.canvas.draw()
+        size = fig.canvas.get_width_height()
+        buff = fig.canvas.tostring_rgb()
+        image_p = Image.fromstring('RGB', size, buff, 'raw')
+        do_something(image_p, "color distribution")
 
     clusters = k_means.cluster(3, 2,zip(zip(luma, saturation), range(len(luma))),
                                [[0., 0.5], [0.5, 0.5], [1., 0.5]])
    #clusters.sort(key=mean_luma)
 
     if show_all:
+        fig = pyplot.figure(figsize=(8, 6))
         pyplot.scatter([d[0][0] for d in clusters[0]], [d[0][1] for d in clusters[0]],
                                                  color=(1,0,0,1))
         pyplot.scatter([d[0][0] for d in clusters[1]], [d[0][1] for d in clusters[1]],
@@ -71,7 +84,11 @@ def board(image, lines, show_all, do_something):
                                                  color=(0,0,1,1))
         pyplot.xlim(0,1)
         pyplot.ylim(0,1)
-        pyplot.show()
+        fig.canvas.draw()
+        size = fig.canvas.get_width_height()
+        buff = fig.canvas.tostring_rgb()
+        image_p = Image.fromstring('RGB', size, buff, 'raw')
+        do_something(image_p, "color clustering")
 
     clusters[0] = [(p[1], 'B') for p in clusters[0]]
     clusters[1] = [(p[1], '.') for p in clusters[1]]
@@ -95,6 +112,7 @@ def board(image, lines, show_all, do_something):
     return output.Board(19, board_r)
 
 def mean_luma(cluster):
+    """Return mean luma of the ((luma, saturation), index) points in *cluster* (must be non-empty)."""
     return sum(c[0][0] for c in cluster) / float(len(cluster))
 
 def intersections_from_angl_dist(lines, size, get_all=True):