#!/usr/bin/env python

__version__ = '1.0'
__author__  = "Avinash Kak (kak@purdue.edu)"
__date__    = '2010-August-22'
__url__     = 'http://RVL4.ecn.purdue.edu/~kak/distICP/ICP-1.0.html'
__copyright__ = "(C) 2010 Avinash Kak. Python Software Foundation."


import Image
import ImageFilter
import ImageFont
import ImageDraw
import scipy
import scipy.linalg
import ImageTk
import Tkinter
import sys, os.path
import math

#____________________  Support functions _______________________

# calculate Euclidean distance between two points p and q and
# return the distance as a scalar:
def _euclidean(p, q):
    p,q = scipy.array(p), scipy.array(q)
    return scipy.sqrt( scipy.dot(p-q,p-q) )

# calculate Euclidean distance from a point p to the list of
# points in data and return a list of all such distances
def _dist(p, data):
    return [_euclidean(p,q) for q in data]

def _indexAndLeastDistance(distList):
    minVal = min(distList)
    return distList.index(minVal), minVal    
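
# A minimal usage sketch for the three distance helpers above, assuming toy
# inputs (the values shown are what these functions return):
#
#     >>> _euclidean((0,0), (3,4))
#     5.0
#     >>> _dist((0,0), [(3,4), (1,0)])
#     [5.0, 1.0]
#     >>> _indexAndLeastDistance([5.0, 1.0])
#     (1, 1.0)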

def _difference(p,q):
    return p[0]-q[0],p[1]-q[1]

def _mean_coordinates(coords_list):
    '''
    Returns a pair of values in the form of a tuple, the pair
    consisting of the mean values for the x and the y coordinates.
    '''
    mean = sum(x[0] for x in coords_list), sum(x[1] for x in coords_list)
    mean = mean[0]/float(len(coords_list)), mean[1]/float(len(coords_list))
    return mean
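
# A minimal usage sketch for _mean_coordinates(), assuming a toy list of
# pixel coordinates:
#
#     >>> _mean_coordinates([(0,0), (2,4), (4,8)])
#     (2.0, 4.0)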

#---------------------- ICP Class Definition ---------------------

class ICP(object):

    def __init__(self, *args, **kwargs ):
        if args:
            raise ValueError(  
                   '''ICP constructor can only be called
                      with keyword arguments for the following
                      keywords: model_image, data_image, binary_or_color,
                      output_image_size, iterations, error_threshold, 
                      display_step, connectivity_threshold, or
                      save_output_images''')       

        model_image = data_image = binary_or_color = None
        output_image_size = iterations = None
        error_threshold = display_step = connectivity_threshold = None
        debug = save_output_images = None

        if 'model_image' in kwargs: model_image = kwargs.pop('model_image')
        if 'data_image' in kwargs: data_image = kwargs.pop('data_image')
        if 'binary_or_color' in kwargs:
            binary_or_color = kwargs.pop('binary_or_color')
        if 'output_image_size' in kwargs:
            output_image_size = kwargs.pop('output_image_size')
        if 'iterations' in kwargs: iterations = kwargs.pop('iterations')
        if 'error_threshold' in kwargs:
            error_threshold = kwargs.pop('error_threshold')
        if 'display_step' in kwargs: display_step = kwargs.pop('display_step')
        if 'connectivity_threshold' in kwargs:
            connectivity_threshold = kwargs.pop('connectivity_threshold')
        if 'debug' in kwargs: debug = kwargs.pop('debug')
        if 'save_output_images' in kwargs:
            save_output_images = kwargs.pop('save_output_images')
        if len(kwargs) != 0:
            raise ValueError('''You have provided unrecognizable keyword args''')

        if model_image: 
            self.model_im = Image.open(model_image)
        else:
            raise ValueError('''You must specify a model image''')
        if data_image: 
            self.data_im =  Image.open(data_image)
        else:
            raise ValueError('''You must specify a data image''')
        if binary_or_color:
            self.binary_or_color = binary_or_color
        else:
            raise ValueError('''You must specify either "binary" or "color" ''')
        if output_image_size:
            self.output_image_size = output_image_size
        elif self.model_im.size[0] <= 100:
            self.output_image_size = self.model_im.size[0]
        else:
            self.output_image_size = 100
        if iterations:
            self.iterations = iterations
        else:
            self.iterations = 10
        if connectivity_threshold:
            self.connectivity_threshold = connectivity_threshold
        else:
            self.connectivity_threshold = 5
        if error_threshold:
            self.error_threshold = error_threshold
        else:
            self.error_threshold = 0.0
        if display_step:
            self.display_step = display_step
        else:
            if self.iterations > 40:
                self.display_step = 5
            else:
                self.display_step = 2
        if debug:
            self.debug = debug
        else:
            self.debug = 0
        if save_output_images:
            self.save_output_images = save_output_images
        else:
            self.save_output_images = 0
        self.tk_images = []

    def extract_model_pixels(self):
        width, height =  self.model_im.size
        if self.debug: print "model image of width ", width, " and height ", height
        if self.binary_or_color == 'color':
            self.model_im.thumbnail( (self.output_image_size,\
                                      self.output_image_size), Image.ANTIALIAS )
            self.model_edge_im = \
                self.model_im.filter(ImageFilter.FIND_EDGES).convert('L').convert('1')
            image_processing_results = \
                     self.extract_pixels_from_color_image( self.model_edge_im )
            self.model_list = image_processing_results[0]
            self.sparse_model_list = image_processing_results[1]
            if self.debug: 
                print "sparse_model_list: ", self.sparse_model_list
                print "number of pixels selected for ICP: ", len(self.sparse_model_list)
                print "number of all edge pixels: ", len(self.model_list)
            self.sparse_model_im = image_processing_results[2]
        else:
            self.model_im = self.model_im.convert("1")
            for i in range(3,width-3): 
                for j in range(3,height-3):
                    if ( self.model_im.getpixel((i,j)) != 0 ):  
                        self.model_list.append( (i,j) )       
        if self.debug:
            print "model list", self.model_list
            print "Number of pixels in model_list: ", len(self.model_list)
            print "\n\n"


    def extract_data_pixels(self):
        '''
        Since extract_model_pixels() and extract_data_pixels() have the same
        logic, why can't we have just one method called with two different
        arguments?  These exist as separate methods to allow for the future
        possibility that one may want to process the model and the data images
        very differently.  Presumably, the model images will be noiseless and
        will be output by some GIS system.  On the other hand, the data images
        can be expected to be noisy and may suffer from optical and other 
        distortions.
        '''
        if self.binary_or_color == 'color':
            self.data_im.thumbnail( (self.output_image_size,\
                                     self.output_image_size), Image.ANTIALIAS )
            self.data_edge_im = \
                self.data_im.filter(ImageFilter.FIND_EDGES).convert('L').convert('1')
            image_processing_results = \
                  self.extract_pixels_from_color_image( self.data_edge_im )
            self.data_list = image_processing_results[0]
            self.sparse_data_list = image_processing_results[1]
            if self.debug: 
                print "sparse_data_list: ", self.sparse_data_list
                print "number of data pixels selected for ICP: ", \
                                             len(self.sparse_data_list)
                print "number of all model edge pixels: ", len(self.data_list)
            self.sparse_data_im = image_processing_results[2]
        else:
            self.data_im = self.data_im.convert("1")
            (width,height) = self.data_im.size
            for i in range(width): 
                for j in range(height):
                    if self.data_im.getpixel((i,j)) != 0:  
                        self.data_list.append( (i,j) )       
        if self.debug:
            print "data list", self.data_list
            print "Number of pixels in data_list: ", len(self.data_list)
            print "\n\n"


    def extract_pixels_from_color_image(self, image):
        '''
        This very simple routine would need to be either replaced in this
        class or overridden in a subclass of ICP for a more practical 
        approach to the selection of pixels for ICP calculation.  At this 
        time, all we do is choose an edge pixel if at least
        connectivity_threshold of its neighbors in a 5x5 neighborhood (the
        16 pixels at offsets of 1 and 2 in each direction, skipping the
        center row and column) are also edge pixels.  This is done with the
        hope that pixels on strong edges are likely
        to be selected with this criterion.
        '''
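        # A minimal sketch of how this routine is used internally, assuming
        # that a binarized edge image (mode "1"), here called edge_im, is
        # already available:
        #
        #     all_pixels, sparse_pixels, sparse_im = \
        #                        self.extract_pixels_from_color_image(edge_im)
        #
        # all_pixels holds every edge pixel (i,j), sparse_pixels only those
        # edge pixels that pass the connectivity test, and sparse_im is a
        # mode "1" image containing just the sparse pixels.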
        (width,height) = image.size
        im_sparse =  Image.new("1", (width,height), 0)
        sparse_edge_pixels_list = []   
        all_edge_pixels_list = []                
        for i in range(3,width-3):
            for j in range(3,height-3):
                count = 0
                if ( image.getpixel((i,j)) != 0 ):
                    all_edge_pixels_list.append( (i,j) )
                    for k in (-2,-1,1,2):
                        for l in (-2,-1,1,2):
                            if ( image.getpixel((i+k,j+l)) == 255 ):
                                count = count + 1
                if (count >= self.connectivity_threshold): 
                    im_sparse.putpixel((i,j),255)
                    sparse_edge_pixels_list.append( (i,j) )
        return all_edge_pixels_list, sparse_edge_pixels_list, im_sparse

    def move_to_model_origin(self):
        '''
        Since two patterns in a plane, even when one appears to be a
        rotated version of the other, may not be related by a
        Euclidean transform for an arbitrary placement of the origin,
        we will assume that the origin for ICP calculations will be 
        at the "center" of the model image.  Now our goal becomes to
        find an R and a T that will make the data pattern congruent 
        with the model pattern.
        '''
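        # A small sketch of the centering carried out below, assuming a toy
        # model list (the numbers are only illustrative):
        #
        #     model_list       = [(10,10), (12,14)]
        #     model_mean       = (11.0, 12.0)             # via _mean_coordinates()
        #     zero-mean model  = [(-1.0,-2.0), (1.0,2.0)]
        #
        # Note that the data points too are shifted by the *model* mean, so
        # that both point sets end up in the same model-centered frame.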
        if self.binary_or_color == 'color':
            self.model_mean =  \
                (scipy.matrix(list(_mean_coordinates(self.sparse_model_list)))).T
            self.zero_mean_sparse_model_list = [(p[0] - self.model_mean[0,0], \
                         p[1] - self.model_mean[1,0]) for p in self.sparse_model_list]
            self.zero_mean_model_list = [(p[0] - self.model_mean[0,0], \
                         p[1] - self.model_mean[1,0]) for p in self.model_list]
            self.zero_mean_sparse_data_list = [(p[0] - self.model_mean[0,0], \
                            p[1] - self.model_mean[1,0]) for p in self.sparse_data_list]
            self.zero_mean_data_list = [(p[0] - self.model_mean[0,0], \
                                   p[1] - self.model_mean[1,0]) for p in self.data_list]
        else: 
            # We handle this case separately because, for color images, the
            # model mean is calculated from only the sparse set of edge pixels.
            self.model_mean = (scipy.matrix(list(_mean_coordinates(self.model_list)))).T
            self.zero_mean_model_list = [(p[0] - self.model_mean[0,0], \
                                p[1] - self.model_mean[1,0]) for p in self.model_list]
            self.zero_mean_data_list = [(p[0] - self.model_mean[0,0], \
                                p[1] - self.model_mean[1,0]) for p in self.data_list]
        if self.debug:
            print "model mean:\n", self.model_mean
            print "zero mean model list: ", self.zero_mean_model_list
            print "zero mean data list: ", self.zero_mean_data_list
            print "\n\n"

    def construct_A_matrix(self):
        '''
        In the plane whose origin is at the model center, the relationship
        between the model points x_m and the data points x_d is given by

             R . x_d  +  T  =  x_m
        Let the list of the n chosen data points be given by
             A =  [x_d1, x_d2, .....,   x_dn]
        We can now express the relationship between the data and the 
        model points by
             R . A  = B
        where B is the list of the CORRESPONDING model points after 
        we subtract the translation from each:
             B =  [x_m1 - T, x_m2 - T, ......,  x_mn - T] 
        Eventually our goal will be to estimate R from the R.A = B 
        relationship.
        '''
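        # A minimal sketch of the matrix built below, assuming a toy
        # zero-mean data list (values are only illustrative):
        #
        #     data = [(-1.0,-2.0), (1.0,2.0), (3.0,0.0)]
        #     A    = scipy.matrix([[-1.0, 1.0, 3.0],
        #                          [-2.0, 2.0, 0.0]])   # 2 x n, one column per point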
        if self.binary_or_color == 'color':
            data = self.zero_mean_sparse_data_list
        else:
            data = self.zero_mean_data_list
        self.A = scipy.matrix( [ [p[0] for p in data], [p[1] for p in data] ] )

    def construct_AATI_matrix(self):
        '''
        So we want to estimate R from R.A = B.  If we had to construct a 
        one-shot estimate for R (note ICP is iterative and not one-shot),
        we would write R.A=B as 
                 R.A.A^t  =  B.A^t
        and then
                 R =  B . A^t . (A . A^t)^-1
        We will group together what comes after B on the right hand side
        and write
                 AATI  =  A^t . (A . A^t)^-1
        In the ICP implementation, this matrix will remain the same for all
        the iterations.
        '''
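        # A dimension sketch for the matrix computed below, assuming the toy
        # 2 x n matrix A shown in construct_A_matrix():
        #
        #     A        is 2 x n
        #     A.T      is n x 2
        #     A * A.T  is 2 x 2   (invertible whenever the data points span the plane)
        #     AATI = A.T * inv(A * A.T)  is n x 2,  so  B * AATI  is a 2 x 2 estimate of R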
        A = self.A
        self.AATI = A.T * scipy.linalg.inv( A * A.T )

    def initialize(self):
        self.R = scipy.matrix( [[1.0, 0.0],[0.0, 1.0]] )
        self.T = (scipy.matrix([[0.0, 0.0]])).T
        self.model_list = []
        self.data_list = []
        self.sparse_model_list = []
        self.sparse_data_list = []
        self.zero_mean_sparse_model_list = []
        self.zero_mean_sparse_data_list = []
        self.zero_mean_model_list = []
        self.zero_mean_data_list = []

    def setSizeForDisplay(self):
        if self.binary_or_color == 'color':
            self.displayWidth, self.displayHeight = self.data_edge_im.size
        else:
            self.displayWidth, self.displayHeight = self.model_im.size

    def icp(self):
        self.initialize()
        self.extract_model_pixels()
        self.extract_data_pixels()
        self.move_to_model_origin()
        self.construct_A_matrix()
        self.construct_AATI_matrix()
        self.setSizeForDisplay()
        old_error = float('inf')
        iteration = 0
        R,T = self.R, self.T
        if self.binary_or_color == "color":
            model = self.zero_mean_sparse_model_list
            data = self.zero_mean_sparse_data_list
        else:
            model = self.zero_mean_model_list
            data = self.zero_mean_data_list
        while 1:
            print "\n>>>>>>>>>>   STARTING ITERATION ", iteration, " OUT OF ", self.iterations, "\n"
            if iteration == self.iterations: break

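            # Apply the current estimates of R and T to every data point;
            # data_transformed holds the results as (x,y) tuples.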
            data_matrix = [ R * (scipy.matrix( list(data[p]) )).T + T \
                                                for p in range(len(data)) ] 
            data_transformed = [ ( p[0,0], p[1,0] ) for p in data_matrix]

            # For every data point find the closest model point.  The set of such
            # model points will be called the matched_model set of points
            # The following returns a list of pairs, the first the index of 
            # the model point that was found closest to the data point in question,
            # and the second the actual minimal distance
            leastDistMapping = [_indexAndLeastDistance(_dist(p,model)) \
                                                for p in data_transformed]
            matched_model = [model[p[0]] for p in leastDistMapping]
            matched_model_mean = scipy.matrix(list(_mean_coordinates(matched_model))).T
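            # Registration error for this iteration: the average of the
            # closest-point distances found above.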
            error = sum(x[1] for x in leastDistMapping) / len(leastDistMapping)
            print "old_error: ", old_error, "    error: ", error
            print "\n"
            diff_error = abs(old_error - error)
            if (diff_error > self.error_threshold):
                old_error = error
            else:
                self.iterations = iteration
                break
            AATI = self.AATI
            B = scipy.matrix([ [p[0] - T[0,0] for p in matched_model], \
                               [p[1] - T[1,0] for p in matched_model] ])
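            # B * AATI is the least-squares estimate of the rotation that maps
            # the zero-mean data onto the matched model points (from R.A = B);
            # multiplying by R.T expresses it as an incremental update to the
            # current R.  The SVD step that follows projects this generally
            # non-orthonormal 2x2 estimate onto the nearest true rotation, with
            # the determinant check flipping a column of U when needed so the
            # result is a rotation and not a reflection.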
            R_update = B * AATI * R.T
            [U,S,VT] = scipy.linalg.svd(R_update)
            U,VT = scipy.matrix(U), scipy.matrix(VT) 
            deter = scipy.linalg.det(U * VT)
            U[0,1] = U[0,1] * deter
            U[1,1] = U[1,1] * deter
            R_update = U * VT
            R = R_update * R
            print "Rotation:\n", R
            print "\n"
            # Rotate the data for estimating the translation T
            data_matrix2 = [ R * (scipy.matrix( list(data[p]))).T  \
                                                for p in range(len(data)) ] 
            data_transformed2 = [ ( p[0,0], p[1,0] ) for p in data_matrix2]
            data_transformed_mean = \
                  scipy.matrix(list(_mean_coordinates(data_transformed2))).T
            T = matched_model_mean - data_transformed_mean  
            print "Translation:\n", T
            print "\n"
            # Now apply the R,T transformation to the original data for updated image
            # representation of the result at the end of this iteration
            if self.binary_or_color == 'color':
                data_matrix_new = \
                 [ R * (scipy.matrix( list(self.zero_mean_data_list[p]))).T + T \
                                        for p in range(len(self.zero_mean_data_list)) ] 
            else:
                data_matrix_new = [ R * (scipy.matrix( list(data[p]))).T + T \
                                              for p in range(len(data)) ] 
            data_transformed_new = \
                [ ( p[0,0] + self.model_mean[0,0], p[1,0] + self.model_mean[1,0] ) \
                                                       for p in data_matrix_new]
            result_im = Image.new("1", (self.displayWidth,self.displayHeight), 0)
            for p in data_transformed_new:
                x,y = int(p[0]), int(p[1])
                if ( (0 <= x < self.displayWidth) and (0 <= y < self.displayHeight ) ):
                    result_im.putpixel( (x,y), 255 )
            result_im.save( "__result" + str(iteration) + ".jpg")
            iteration = iteration + 1
        self.R,self.T = R,T


    def display_results(self):
        rootWindow = Tkinter.Tk()
        rootWindow.geometry("1200x750+50+50") 
        cellwidth = self.output_image_size
        padding = 10
        padded_cellwidth = cellwidth + 2*padding
        if cellwidth > 80:
            fontsize = 20
        else:
            fontsize = 10
        import os.path
        if os.path.isfile("times.ttf"):
            font = ImageFont.truetype("times.ttf", fontsize)
        elif os.path.exists("/usr/share/fonts/truetype/msttcorefonts/times.ttf"):
            font = ImageFont.truetype( \
                      "/usr/share/fonts/truetype/msttcorefonts/times.ttf", fontsize)
        else:
            print "Unable to find the font file 'times.ttf' needed for displaying the results"
            sys.exit(1)
        textImage1 = Image.new( "F", (self.displayWidth,self.displayHeight), 200 )
        draw = ImageDraw.Draw(textImage1)
        draw.text((10,10), "Model:", font=font)
        textImage2 = Image.new( "F", (self.displayWidth,self.displayHeight), 200 )
        draw = ImageDraw.Draw(textImage2)
        draw.text((10,10), "Data:", font=font)
        textImage3 = Image.new( "F", (self.displayWidth,self.displayHeight), 200 )
        draw = ImageDraw.Draw(textImage3)
        draw.text((10,10), "Results:", font=font)
        if self.debug: print "cell width for display: ", cellwidth, "\n\n"
        if self.binary_or_color == 'color':
            self.tk_images.append(ImageTk.PhotoImage( textImage1 ))
            self.tk_images.append(ImageTk.PhotoImage( self.model_im ))
            self.tk_images.append(ImageTk.PhotoImage( self.model_edge_im ))            
            self.tk_images.append(ImageTk.PhotoImage( self.sparse_model_im ))
            self.tk_images.append(ImageTk.PhotoImage( textImage2 ))
            self.tk_images.append(ImageTk.PhotoImage( self.data_im ))
            self.tk_images.append(ImageTk.PhotoImage( self.data_edge_im ))
            self.tk_images.append(ImageTk.PhotoImage( self.sparse_data_im ))
        else:
            self.tk_images.append(ImageTk.PhotoImage( textImage1 ))
            self.tk_images.append(ImageTk.PhotoImage( self.model_im ))
            self.tk_images.append(ImageTk.PhotoImage( textImage2 ))
            self.tk_images.append(ImageTk.PhotoImage( self.data_im ))
        for i in range(len(self.tk_images)):
            Tkinter.Label(rootWindow,image=self.tk_images[i],\
                    width=cellwidth).grid(row=0,column=i,padx=10,pady=10)
        print "Will display ", self.iterations, "result images\n\n"
        resultImageLabel = ImageTk.PhotoImage(textImage3)
        Tkinter.Label(rootWindow,image=resultImageLabel,\
                    width=cellwidth).grid(row=1,column=0,padx=10,pady=10)
        tkim = [None] * self.iterations
        for i in range(0,self.iterations,self.display_step):
            tkim[i] = ImageTk.PhotoImage( Image.open( "__result" + str(i) + ".jpg" ) )
            j = i / self.display_step
            if ((j+1)*padded_cellwidth < 1200):
                Tkinter.Label(rootWindow,image=tkim[i],\
                    width=cellwidth).grid(row=1,column=j+1,padx=10,pady=10)
            elif ((j+1)*padded_cellwidth < 2400):
                j = j - (1200-padded_cellwidth) / padded_cellwidth
                Tkinter.Label(rootWindow,image=tkim[i],\
                    width=cellwidth).grid(row=2,column=j,padx=10,pady=10)
            elif ((j+1)*padded_cellwidth < 3600):
                j = j - (2400-padded_cellwidth) / padded_cellwidth
                Tkinter.Label(rootWindow,image=tkim[i],\
                    width=cellwidth).grid(row=3,column=j,padx=10,pady=10)
            else: 
                j = j - (3600-padded_cellwidth) / padded_cellwidth
                Tkinter.Label(rootWindow,image=tkim[i],\
                    width=cellwidth).grid(row=4,column=j,padx=10,pady=10)
        Tkinter.mainloop()

    def __del__(self):
        if not self.save_output_images:
            import glob,os
            for filename in glob.glob( '__result*' ): os.unlink(filename)

    @staticmethod
    def gendata():
        '''
        The code here is just the simplest example of synthetic data
        generation for experimenting with ICP.  You can obviously 
        construct more complex model and data images by calling on the
        other shape drawing primitives of the ImageDraw class.  When
        specifying coordinates, note the following

               .----------> positive x
               |
               |
               |        
               V
             positive y

        A line is drawn from the first pair (x,y) coordinates to the
        second pair.
        '''
        s = 100
        data = Image.new( "L", (s,s), 0 )
        draw = ImageDraw.Draw(data)
        draw.line( (0.8*s, 0, 0, s/2), fill=255 )
        draw.line( (0.8*s, 0, s-1, 0.2*s), fill=255 )
        draw.line( (s-1, 0.2*s, 0, s/2), fill=255 )
        del draw
        data.save( "triangleA.jpg" )

        data = Image.new( "L", (s,s), 0 )
        draw = ImageDraw.Draw(data)
        draw.line( (0, s/2, s-1, 0.8*s), fill=255 )
        draw.line( (0, s/2, 0.8*s, s-1), fill=255 )
        draw.line( (0.8*s, s-1, s-1, 0.8*s), fill=255 )
        del draw
        data.save( "triangleB.jpg" )

        theta = 20
        radians = theta * math.pi / 180
        offset = s * scipy.tan(radians)
        print "offset: ", offset
        data = Image.new( "L", (s,s), 0 )
        draw = ImageDraw.Draw(data)
        draw.line( (0,(s/2 + offset/2)) + (s-1, (s/2 - offset/2)), fill=255 )
        del draw
        data.show()
        data.save( "linedataA.jpg" )
        
        theta = 70
        radians = theta * math.pi / 180
        offset = s * scipy.tan(radians)
        print "offset: ", offset
        data = Image.new( "L", (s,s), 0 )
        draw = ImageDraw.Draw(data)
        draw.line( (0,(s/2 + offset/2)) + (s-1, (s/2 - offset/2)), fill=255 )
        del draw
        data.show()
        data.save( "linedataB.jpg" )
        
#------------------------- End of ICP Class Definition  ---------------------------

#-------------------------    Test code follows         ---------------------------

if __name__ == '__main__': 

#    ICP.gendata()

    '''
    icp = ICP( 
               model_image = "triangle1.jpg",
#               model_image = "linemodel.jpg", 
               data_image =  "triangle2.jpg",
#               data_image =  "linedata.jpg",
               binary_or_color = "binary",
               iterations = 40,
               display_step = 1,
               debug = 1 )
    icp.icp()
    icp.display_results()
    '''

    icp = ICP( 
               model_image = "test_images/football_model.jpg",
               data_image = "test_images/football_query.jpg",
               binary_or_color = "color",
               iterations = 40,
               connectivity_threshold = 5,
               output_image_size = 100,
               display_step = 1,
               debug = 1 )

    icp.icp()
    icp.display_results()