Pella86

MyImage_class.py

Oct 4th, 2017
186
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 20.32 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Mon Jun 19 16:33:12 2017
  4.  
  5. @author: Mauro
  6.  
  7. This class manages gray scale images. The images are stored as mxn arrays and
  8. the class provides basic processing methods
  9. """
  10.  
  11. #==============================================================================
  12. # # Imports
  13. #==============================================================================
  14.  
  15. # numpy import
  16. import numpy as np
  17.  
  18. # matplotlib import
  19. import matplotlib.image as mpimg
  20. import matplotlib.pyplot as plt
  21.  
  22. # scipy import
  23. from scipy import signal
  24.  
  25. # py imports
  26. from copy import deepcopy
  27.  
  28. #==============================================================================
  29. # # Image Handling class
  30. #==============================================================================
  31.  
  32. class MyImage(object):
  33.     ''' Main class, this will be a standard image object mxn matrix of values'''
  34.    
  35.     # ------ initialization functions ------
  36.    
  37.     def __init__(self, data = np.zeros((5,5))):  
  38.         ''' The image can be initiated with any numpy array, default is a 5x5
  39.        zeroed matrix. The image can be intiated by tuple indicating its size
  40.        (mxn). The image can be initiated by a path to an image.
  41.        The data is stored in the self data folder.
  42.        Usage:
  43.            img = MyImg()
  44.            img = MyImg(np.zeros((512, 512)))
  45.            img = MyImg((512, 512))
  46.            img = MyImg(path/to/picture.png)
  47.        '''
  48.         if type(data) == np.ndarray:
  49.             self.data = data
  50.         elif type(data) == tuple:
  51.             if len(data) == 2:
  52.                 self.data = np.zeros(data)
  53.         elif type(data) == str:
  54.             # shall i check for path being an image?
  55.             self.read_from_file(data)
  56.         else:
  57.             raise ValueError("data type not supported")
  58.            
  59.    
  60.     # ------ debug options ------
  61.    
  62.     def inspect(self, output = True):
  63.         ''' short function that returns the image values: mean,
  64.        standard deviation, max, min and size of image
  65.        if output is True, it prints to the console the string containing the
  66.        formatted value
  67.        '''
  68.         m = np.mean(self.data)
  69.         s = np.std(self.data)
  70.         u = np.max(self.data)
  71.         l = np.min(self.data)
  72.         d = self.data.shape
  73.        
  74.         if output:
  75.             s  = "Mean: {0:.2f} | Std: {1:.2f} | Max: {2:.2f}|Min: {3:.2f} | \
  76.                  Dim: {4[0]}x{4[1]}".format(m, s, u, l, d)
  77.             print(s)
  78.             return s
  79.            
  80.         return (m, s, u, l, d)  
  81.    
  82.     def show_image(self):
  83.         ''' This prepares a image canvas for matlibplot visible in  the
  84.        IPython console.
  85.        '''
  86.         data = deepcopy(self.data)  # copy the data to not modify them
  87.        
  88.         # limit the data between 0 and 1
  89.         mi = np.min(data)
  90.         pospic = data + mi
  91.         m = np.max(pospic)
  92.         npic = pospic / float(m)
  93.         data = 1 - npic
  94.        
  95.         # show the image in greyscale
  96.         plt.imshow(data, cmap = "Greys")    
  97.    
  98.     def get_size(self):
  99.         return (self.data.shape[0], self.data.shape[1])
  100.    
  101.     def get_sizex(self):
  102.         return self.get_size()[0]
  103.    
  104.     def get_sizey(self):
  105.         return self.get_size()[1]
  106.    
  107.     # ------ I/O functions ------
  108.    
  109.     def read_from_file(self, filepathname):
  110.         ''' import image from file using the mpimg utility of matplotlib'''
  111.         # todo warnings about file existing ?
  112.         self.data = mpimg.imread(filepathname)
  113.  
  114.     def save(self, filename):
  115.         ''' saves an image using the pyplot method'''
  116.         plt.imsave(filename, self.data)    
  117.  
  118.  
  119.     # ------ operators overload  ------      
  120.     def __add__(self, rhs):
  121.         ''' sums two images px by px'''
  122.         self.data = self.data + rhs.data
  123.         return MyImage(self.data)
  124.    
  125.     def __truediv__(self, rhs):
  126.         ''' divide image by scalar (myimg / number)'''
  127.         rpic = deepcopy(self.data)
  128.         for x in range(self.data.shape[0]):
  129.             for y in range(self.data.shape[1]):
  130.                 rpic[x][y] = self.data[x][y] / rhs
  131.        
  132.         return MyImage(rpic)
  133.  
  134.     # ------ editing functions ------
  135.     def create_composite_right(self, rhsimage):
  136.         ''' concatenates 2 images on the right'''
  137.         # todo multiple arugments
  138.         # enlarge the array to fit the next image
  139.         self.data = np.concatenate((self.data, rhsimage.data), axis = 1)
  140.    
  141.     def normalize(self):
  142.         ''' normalize the picture values so that the resulting image will have
  143.        mean = 0 and std = 1'''
  144.         m = np.mean(self.data)
  145.         s = np.std(self.data)
  146.         self.data = (self.data - m) / s
  147.  
  148.     def convert2grayscale(self):
  149.         ''' when importing an rgb image is necessary to calculate the
  150.        luminosity and reduce the array from mxnx3 to mxn
  151.        '''
  152.         self.data = np.dot(self.data[...,:3], [0.299, 0.587, 0.114])
  153.    
  154.     def transpose(self):
  155.         ''' transposes the picture from mxn to nxm'''
  156.         self.data.transpose()
  157.        
  158.     def binning(self, n = 1):
  159.         ''' Averages a matrix of 2x2 pixels to one value, effectively reducing
  160.        size by two and averaging the value, giving less noise. n indicates
  161.        how many time the procedure is done
  162.        512x512 bin 1 -> 256x256 bin 2 -> 128128 bin 3 -> ...
  163.        '''
  164.         for i in range(n):
  165.             # initialize resulting image
  166.             rimg = np.zeros( (int(self.data.shape[0] / 2) , int(self.data.shape[1] / 2)))
  167.            
  168.             # calculate for each pixel the corresponding
  169.             # idx rimg = 2*idx srcimg
  170.             for x in range(rimg.shape[0]):
  171.                 for y in range(rimg.shape[1]):
  172.                     a = self.data[x*2]    [y*2]
  173.                     b = self.data[x*2 + 1][y*2]
  174.                     c = self.data[x*2]    [y*2 + 1]
  175.                     d = self.data[x*2 + 1][y*2 + 1]
  176.                     rimg[x,y] =  (a + b + c + d) / 4.0
  177.                    
  178.            
  179.             self.data = rimg
  180.  
  181.     def move(self, dx, dy):
  182.         ''' moves the picture by the dx or dy values. dx dy must be ints'''
  183.         # correction to give dx a right movement if positive
  184.         dx = -dx
  185.        
  186.         # initialize the image
  187.         mpic = np.zeros(self.data.shape)
  188.        
  189.         # get image size
  190.         sizex = mpic.shape[0]
  191.         sizey = mpic.shape[1]
  192.        
  193.         for x in range(sizex):
  194.             for y in range(sizey):
  195.                 xdx = x + dx
  196.                 ydy = y + dy
  197.                 if xdx >= 0 and xdx < sizex and ydy >= 0 and ydy < sizey:
  198.                     mpic[x][y] = self.data[xdx][ydy]
  199.        
  200.         self.data = mpic
  201.    
  202.     def squareit(self, mode = "center"):
  203.         ''' Squares the image. Two methods available
  204.        center: cuts a square in the center of the picture
  205.        left side: cuts a square on top or on left side of the pic
  206.        '''
  207.         if mode == "center":
  208.             lx = self.data.shape[0]
  209.             ly = self.data.shape[1]
  210.            
  211.             if lx > ly:
  212.                 ix = int(lx / 2 - ly / 2)
  213.                 iy = int(lx / 2 + ly / 2)
  214.                 self.data = self.data[ ix : iy , 0 : ly]
  215.             else:
  216.                 ix = int(ly / 2 - lx / 2)
  217.                 iy = int(ly / 2 + lx / 2)
  218.                 self.data = self.data[0 : lx, ix : iy ]            
  219.         if mode == "left side":
  220.             lx = self.data.shape[0]
  221.             ly = self.data.shape[1]
  222.            
  223.             if lx > ly:
  224.                 self.data = self.data[0:ly,0:ly]
  225.             else:
  226.                 self.data = self.data[0:lx,0:lx]
  227.    
  228.     def correlate(self, image):
  229.         ''' scipy correlate function. veri slow, based on convolution'''
  230.         corr = signal.correlate2d(image.data, self.data, boundary='symm', mode='same')
  231.         return Corr(corr)
  232.  
  233.     def limit(self, valmax):
  234.         ''' remaps the values from 0 to valmax'''
  235.         # si potrebbe cambiare da minvalue a value
  236.         mi = self.data.min()
  237.         mi = np.abs(mi)
  238.         pospic = self.data + mi
  239.         m = np.max(pospic)
  240.         npic = pospic / float(m)
  241.         self.data = npic * valmax
  242.    
  243.     def apply_mask(self, mask):
  244.         ''' apply a mask on the picture with a dot product '''
  245.         self.data = self.data * mask.data
  246.        
  247.     def rotate(self, deg, center = (0,0)):
  248.         ''' rotates the image by set degree'''
  249.         #where c is the cosine of the angle, s is the sine of the angle and
  250.         #x0, y0 are used to correctly translate the rotated image.
  251.        
  252.         # size of source image
  253.         src_dimsx = self.data.shape[0]
  254.         src_dimsy = self.data.shape[1]
  255.        
  256.         # get the radians and calculate sin and cos
  257.         rad = np.deg2rad(deg)
  258.         c = np.cos(rad)
  259.         s = np.sin(rad)
  260.        
  261.         # calculate center of image
  262.         cx = center[0] + src_dimsx/2
  263.         cy = center[1] + src_dimsy/2
  264.        
  265.         # factor that moves the index to the center
  266.         x0 = cx - c*cx - s*cx
  267.         y0 = cy - c*cy + s*cy
  268.        
  269.         # initialize destination image
  270.         dest = MyImage(self.data.shape)
  271.         for y in range(src_dimsy):
  272.             for x in range(src_dimsx):
  273.                 # get the source indexes
  274.                 src_x = int(c*x + s*y + x0)
  275.                 src_y = int(-s*x + c*y + y0)
  276.                 if src_y > 0 and src_y < src_dimsy and src_x > 0 and src_x < src_dimsx:
  277.                     #paste the value in the destination image
  278.                     dest.data[x][y] = self.data[src_x][src_y]
  279.                    
  280.         self.data = dest.data
  281.  
  282.     def flip_H(self):
  283.         sizex = self.data.shape[0] - 1
  284.         sizey = self.data.shape[1] - 1
  285.         for x in range(int(sizex / 2)):
  286.             for y in range(sizey):
  287.                 tmp = self.data[x][y]
  288.                
  289.                 self.data[x][y] = self.data[sizex - x][y]
  290.                 self.data[sizex - x][y] = tmp
  291.  
  292.     def flip_V(self):
  293.         sizex = self.data.shape[0] - 1
  294.         sizey = self.data.shape[1] - 1
  295.         for x in range(int(sizex)):
  296.             for y in range(int(sizey / 2)):
  297.                 tmp = self.data[x][y]
  298.                
  299.                 self.data[x][y] = self.data[x][sizey - y]
  300.                 self.data[x][sizey - y] = tmp
  301.  
  302. #==============================================================================
  303. # # Cross correlation image Handling class
  304. #==============================================================================
  305.  
  306. class Corr(MyImage):
  307.     ''' This class provide additional methods in case the picture is a
  308.    correlation picture.
  309.    '''
  310.    
  311.     def find_peak(self, msize = 5):
  312.         ''' finde the pixel with highest value in the image considers a matrix
  313.        of msize x msize, for now works very good even if size is 1.
  314.        returns in a tuple s, x, y. s is the corrrelation coefficient and
  315.        x y are the pixel coordinate of the peak.
  316.        '''
  317.         #' consider a matrix of some pixels
  318.         best = (0,0,0)
  319.         for x in range(self.data.shape[0] - msize):
  320.             for y in range(self.data.shape[1] - msize):
  321.                 # calculate mean of the matrix
  322.                 s = 0
  323.                 for i in range(msize):
  324.                     for j in range(msize):
  325.                         s += self.data[x + i][y + j]
  326.                 s =  s / float(msize)**2
  327.                
  328.                 # assign the best value to best, the return tuple
  329.                 if s > best[0]:
  330.                     best = (s, x, y)
  331.         return best
  332.    
  333.     def find_translation(self, peak):
  334.         ''' converts the peak into the translation needed to overlap completely
  335.        the pictures
  336.        '''
  337.         if type(peak) == int:
  338.             peak = self.find_peak(peak)
  339.        
  340.         #best = self.find_peak(msize)
  341.         peakx = peak[1]
  342.         peaky = peak[2]
  343.        
  344.         dx = -(self.data.shape[0]/2 - peakx)
  345.         dy = self.data.shape[1]/2 - peaky
  346.        
  347.         return int(dx), int(dy)
  348.    
  349.     def show_translation(self, dx, dy):
  350.         ''' prints on the image where the peak is
  351.        usage:
  352.            corr = Corr()
  353.            best = corr.find_peak()
  354.            dx, dy = corr.find_translation(best)
  355.            corr.show_image()
  356.            corr.show_translation(dx, dy)
  357.            plt.show()
  358.        '''
  359.         ody = dx + self.data.shape[0]/2
  360.         odx = self.data.shape[1]/2 - dy
  361.         plt.scatter(odx, ody, s=40, alpha = .5)    
  362.         return odx, ody
  363.  
  364. #==============================================================================
  365. # # Mask image Handling class
  366. #==============================================================================
  367.  
  368. class Mask(MyImage):
  369.     ''' This class manages the creation of masks
  370.    '''
  371.    
  372.     def create_circle_mask(self, radius, smooth):
  373.         ''' creates a smoothed circle with value 1 in the center and zero
  374.        outside radius + smooth, uses a linear interpolation from 0 to 1 in
  375.        r +- smooth.
  376.        '''
  377.         # initialize data array
  378.         dims = self.data.shape
  379.         mask = np.ones(dims)*0.5
  380.         center = (dims[0]/2.0, dims[1]/2.0)
  381.         for i in range(dims[0]):
  382.             for j in range(dims[1]):
  383.                 # if distance from center > r + s = 0, < r - s = 1 else
  384.                 # smooth interpolate
  385.                 dcenter = np.sqrt( (i - center[0])**2 + (j - center[1])**2)
  386.                 if dcenter >= (radius + smooth):
  387.                     mask[i][j] = 0
  388.                 elif dcenter <= (radius - smooth):
  389.                     mask[i][j] = 1
  390.                 else:
  391.                     y = -1*(dcenter - (radius + smooth))/radius
  392.                     mask[i][j] = y
  393.         self.data = mask
  394.        
  395.         # normalize the picture from 0 to 1
  396.         self.limit(1)
  397.         return self.data
  398.    
  399.     def invert(self):
  400.         self.data = 1 - self.data
  401.    
  402.     def bandpass(self, rin, sin, rout, sout):
  403.         ''' To create a band pass two circle images are created, one inverted
  404.        and pasted into dthe other'''
  405.        
  406.         # if radius zero dont create the inner circle
  407.         if rin != 0:
  408.             self.create_circle_mask(rin, sin)
  409.         else:
  410.             self.data = np.zeros(self.data.shape)
  411.        
  412.         # create the outer circle
  413.         bigcircle = deepcopy(self)
  414.         bigcircle.create_circle_mask(rout, sout)
  415.         bigcircle.invert()
  416.        
  417.         # sum the two pictures
  418.         m = (self + bigcircle)
  419.        
  420.         # limit fro 0 to 1 and invert
  421.         m.limit(1)
  422.         m.invert()  
  423.        
  424.         self.data = m.data
  425.    
  426.     def __add__(self, rhs):
  427.         ''' overload of the + operator why is not inherited from MyImage?'''
  428.        
  429.         self.data = self.data + rhs.data
  430.         return Mask(self.data)
  431.    
  432.    
  433. if __name__ == "__main__":
  434.     mypicname = "../../../Lenna.png"
  435.     mypic = MyImage()
  436.     mypic.read_from_file(mypicname)
  437.     mypic.squareit()
  438.     mypic.convert2grayscale()
  439.     mypic.binning(0)
  440.     mypic.normalize()
  441.  
  442.     mypic.show_image()
  443.     plt.show()
  444.    
  445.     mypic.flip_V()
  446.    
  447.     mypic.show_image()
  448.     plt.show()
  449.    
  450.     movpic = deepcopy(mypic)
  451.     movpic.move(40, 0)
  452.     movpic.normalize()
  453.    
  454.     movpic.show_image()
  455.     plt.show()
  456.    
  457.     myrot = deepcopy(mypic)
  458.     myrot.rotate(45, center = (0, 0))
  459.     myrot.normalize()
  460.    
  461.     myrot.show_image()
  462.     plt.show()  
  463.    
  464.    
  465.     # at theoretical level the precision can be to the 100th of degree...
  466. #    angles = [10, 29.999, 30, 30.001]
  467. #    myrots = []
  468. #    for angle in angles:
  469. #        myrot10 = deepcopy(mypic)
  470. #        myrot10.rotate(angle)
  471. #        myrot10.normalize()
  472. #        myrot10.show_image()
  473. #        plt.show()
  474. #        myrots.append(myrot10)
  475. #
  476. #    from ImageFFT_class import ImgFFT
  477. #    myrotft = ImgFFT(myrot)
  478. #    myrotft.ft()
  479. #    smax = 0
  480. #    for i, rot in enumerate(myrots):
  481. #        # find rotation
  482. #        
  483. #        rotft = ImgFFT(rot)
  484. #        rotft.ft()      
  485. #        cc = myrotft.correlate(rotft)
  486. #        cc.show_image()
  487. #        s, dx, dy = cc.find_peak(1)
  488. #        plt.show()
  489. #        print("my angle:", angles[i])
  490. #        print(dx, dy, s)
  491. #    
  492. #    xarr = [i * 10 + 10 for i in range(20)]
  493. #    yarr = [np.rad2deg(1 / float(x)) for x in xarr]    
  494. #        
  495. #    plt.scatter(xarr, yarr)
  496. #    plt.show()
  497.    
  498. #    # test the average
  499. #    from ImageFFT_class import ImgFFT
  500. #    
  501. #    # lena is the template
  502. #    template = deepcopy(mypic)
  503. #    tempft = ImgFFT(template)
  504. #    tempft.ft()
  505. #    
  506. #    # construct a rotation space for the template
  507. #    rotangles = [x for x in range(-10,10,1)]
  508. #    
  509. #    rotationspace = []
  510. ##    for angle in rotangles:
  511. ##        print("rotating template angle:", angle)
  512. ##        temprot = deepcopy(template)
  513. ##        temprot.rotate(angle)
  514. ##        temprotft = ImgFFT(temprot)
  515. ##        temprotft.ft()
  516. ##        rotationspace.append(temprotft)
  517. #    
  518. #    # construct a dataset with randomly moved and rotated images
  519. #    
  520. #    np.random.seed(5)
  521. #    dataset = []
  522. #    datasetangles = []
  523. #    datasetshifts = []
  524. #    
  525. #    datatrans = []
  526. #    
  527. #    print("------------------------------")
  528. #    print("creating dataset")
  529. #    print("------------------------------")
  530. #
  531. #    angle_list = np.arange(-20, 20, 5)    
  532. #    for i, ngle in enumerate(angle_list):
  533. #        image = deepcopy(mypic)
  534. #        anglefirst = False if np.random.randint(0,2) == 0 else True
  535. #        
  536. ##        angle = np.random.randint(-10, 10)
  537. ##        dx = np.random.randint(-50, 50)
  538. ##        dy = np.random.randint(-50, 50)
  539. #
  540. #        anglefirst = True
  541. #        angle = angle_list[i]
  542. #        print(angle)
  543. #        dx = 0
  544. #        dy = 0
  545. #        
  546. #        datatrans.append((anglefirst, angle, dx, dy))
  547. #        
  548. #        if anglefirst:
  549. #            datasetangles.append(angle)
  550. #            image.rotate(angle)
  551. #
  552. #            datasetshifts.append((dx,dy))
  553. #            image.move(dx, dy)
  554. #        else:
  555. #            datasetshifts.append((dx,dy))
  556. #            image.move(dx, dy)
  557. #
  558. #            datasetangles.append(angle)
  559. #            image.rotate(angle)
  560. #        
  561. #        print(datatrans[i])
  562. #          
  563. #        image.show_image()
  564. #        plt.show()
  565. #        dataset.append(image)
  566. #    
  567. #    # for each image test the rotation against the template
  568. #    print("------------------------------")
  569. #    print("Align dataset")
  570. #    print("------------------------------")  
  571. #    
  572. #    def distance(x, y):
  573. #        d = np.sqrt((y-x)**2)
  574. #        return d
  575. #    
  576. #    algimg = []
  577. #    for i, image in enumerate(dataset):
  578. #        print("original transformation: ", datatrans[i])
  579. #        imageft = ImgFFT(image)
  580. #        imageft.ft()
  581. #
  582. ##        smax = 0
  583. ##        idxmax = 0
  584. ##        for idx, temp in enumerate(rotationspace):
  585. ##            corr = temp.correlate(imageft)
  586. ##            s, dx, dy = corr.find_peak(1)
  587. ##            
  588. ##            
  589. ##            
  590. ##            if s > smax:
  591. ##                smax = s
  592. ##                idxmax = idx
  593. ##        
  594. ##        print("angle found",rotangles[idxmax] )
  595. #        
  596. #        rotalgimage = deepcopy(image)
  597. ##        rotalgimage.rotate(-rotangles[idxmax])
  598. #        
  599. #        rotalgimageft = ImgFFT(rotalgimage)
  600. #        rotalgimageft.ft()
  601. #        
  602. #        corr = rotalgimageft.correlate(tempft)
  603. #        
  604. #
  605. #        
  606. #        dx, dy = corr.find_translation(1)
  607. #
  608. #        corr.show_image()
  609. #        corr.show_translation(dx, dy)
  610. #        plt.show()
  611. #        
  612. #        print("shifts:", dx, dy)
  613. #        
  614. #        rotalgimage.move(dx, dy)
  615. #        # distance
  616. ##        print("distance")
  617. ##        print(distance(datatrans[i][1], rotangles[idxmax]),
  618. ##              distance(datatrans[i][2], -dx),
  619. ##              distance(datatrans[i][3], -dy)
  620. ##              )        
  621. ##        
  622. #        rotalgimage.show_image()
  623. #        plt.show()
  624. #        algimg.append(rotalgimage)
  625.        
  626.  
  627.        
  628.        
  629.            
  630.            
  631.            
  632.    
  633.     # then rotate it and test the shifts
  634.    
  635.    
  636.    
  637.     # test correlation
  638. #    cc = mypic.correlate(movpic)
  639. #    
  640. #    cc.show_image()
  641. #    
  642. #    dx, dy = cc.find_translation()
  643. #    cc.show_translation(dx, dy)
  644. #    plt.show()
  645.    
  646. #    mypic.create_composite_right(movpic)
  647. #
  648. #    mypic.show_image()
  649. #    plt.show()
  650. #
Advertisement
Add Comment
Please, Sign In to add comment