Advertisement
Guest User

Untitled

a guest
Jul 18th, 2018
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.61 KB | None | 0 0
  1. #!/usr/bin/env python
  2.  
  3. from __future__ import print_function
  4. from Tkinter import *
  5. from tkFileDialog import askopenfilename as selectFILE
  6. from tkFileDialog import askdirectory as selectFOLDER
  7. import tkMessageBox as tkmb
  8. import sys
  9. import glob
  10. import subprocess
  11. from itertools import cycle
  12. import matplotlib
  13. matplotlib.use('TkAgg')
  14. import matplotlib.pyplot as plt
  15. from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2TkAgg
  16. from matplotlib.figure import Figure
  17.  
  18. from itertools import cycle
  19. from matplotlib.offsetbox import (TextArea, DrawingArea, OffsetImage,
  20.                                   AnnotationBbox)
  21.  
  22. import sys
  23. import os
  24. import argparse
  25.  
  26. import ehtim as eh
  27. import matplotlib.pyplot as plt
  28. import networkx as nx
  29. import numpy as np
  30.  
  31. def image_consistancy(imarr, beamparams, metric='nxcorr', blursmall=True, beam_max=1.0, beam_steps=5, savepath=[]):
  32.    
  33.     # get the pixel sizes and fov to compare images at
  34.     (min_psize, max_fov) = get_psize_fov(imarr)
  35.    
  36.     # initialize matrix matrix
  37.     metric_mtx = np.zeros([len(imarr), len(imarr), beam_steps])
  38.    
  39.     # get the different fracsteps
  40.     fracsteps = np.linspace(0,beam_max,beam_steps)
  41.    
  42.     # loop over the different beam sizes
  43.     for fracidx in range(beam_steps):
  44.    
  45.         # look at every pair of images and compute their beam convolved metrics
  46.         for i in range(len(imarr)):
  47.             img1 = imarr[i]
  48.             if fracsteps[fracidx]>0:
  49.                 img1 = img1.blur_gauss(beamparams, fracsteps[fracidx])
  50.    
  51.             for j in range(i+1, len(imarr)):
  52.                 img2 = imarr[j]
  53.                 if fracsteps[fracidx]>0:
  54.                     img2 = img2.blur_gauss(beamparams, fracsteps[fracidx])
  55.                
  56.                 print(j, i, fracidx)
  57.            
  58.                 # compute image comparision under a specified blur_frac
  59.                 (error, im1_pad, im2_shift) = img1.compare_images(img2, metric = [metric], psize = min_psize, target_fov = max_fov, blur_frac=0.0, beamparams=beamparams)
  60.                
  61.                 # if specified save the shifted images used for comparision
  62.                 if savepath:
  63.                     im1_pad.save_fits(savepath + '/' + str(i) + '_' + str(fracidx) + '.fits')
  64.                     im2_shift.save_fits(savepath + '/' + str(j) +  '_' + str(fracidx) + '.fits')
  65.  
  66.                 # save the metric value in a matrix
  67.                 metric_mtx[i,j,fracidx] = error[0]
  68.    
  69.     return (metric_mtx, fracsteps)
  70.    
  71.    
  72.    
  73. # look over an array of images and determine the min pixel size and max fov that can be used consistently across them
  74. def get_psize_fov(imarr):  
  75.     min_psize = 100
  76.     for i in range(0, len(imarr)):
  77.         if i==0:
  78.             max_fov = np.max([imarr[i].psize*imarr[i].xdim, imarr[i].psize*imarr[i].ydim])
  79.             min_psize = imarr[i].psize
  80.         else:
  81.             max_fov = np.max([max_fov, imarr[i].psize*imarr[i].xdim, imarr[i].psize*imarr[i].ydim])
  82.             min_psize = np.min([min_psize, imarr[i].psize])
  83.     return (min_psize, max_fov)
  84.    
  85.  
  86.  
  87. def image_agreements(imarr, beamparams, metric_mtx, fracsteps, cutoff=0.95):
  88.    
  89.     (min_psize, max_fov) = get_psize_fov(imarr)
  90.    
  91.     im_cliques_fraclevels = []
  92.     cliques_fraclevels = []
  93.     for fracidx in range(len(fracsteps)):
  94.         print(fracidx)
  95.    
  96.         slice_metric_mtx = metric_mtx[:,:,fracidx]
  97.         cuttoffidx = np.where( slice_metric_mtx >= cutoff)
  98.         consistant = zip(*cuttoffidx)
  99.        
  100.         # make graph
  101.         G=nx.Graph()
  102.         for i in range(len(consistant)):
  103.             G.add_edge(consistant[i][0], consistant[i][1])
  104.        
  105.         # find all cliques
  106.         cliques = list(nx.find_cliques(G))
  107.         print(cliques)
  108.        
  109.         cliques_fraclevels.append(cliques)
  110.        
  111.         im_clique = []
  112.         for c in range(len(cliques)):
  113.             clique = cliques[c]
  114.             im_avg = imarr[clique[0]].blur_gauss(beamparams, fracsteps[fracidx])
  115.            
  116.             for n in range(1,len(clique)):
  117.                 (error, im_avg, im2_shift) = im_avg.compare_images(imarr[clique[n]].blur_gauss(beamparams, fracsteps[fracidx]), metric = ['xcorr'], psize = min_psize, target_fov = max_fov, blur_frac=0.0,
  118.                          beamparams=beamparams)
  119.                 im_avg.imvec = (im_avg.imvec + im2_shift.imvec ) / 2.0
  120.                
  121.        
  122.             im_clique.append(im_avg.copy())
  123.        
  124.         im_cliques_fraclevels.append(im_clique)
  125.        
  126.     return(cliques_fraclevels, im_cliques_fraclevels)
  127.  
  128.  
  129. def change_cut_off(metric_mtx, fracsteps, imarr, beamparams, master=None, cutoff=0.95):
  130.     (cliques_fraclevels, im_cliques_fraclevels) = image_agreements(imarr, beamparams, metric_mtx, fracsteps, cutoff=cutoff)
  131.     generate_consistency_plot(cliques_fraclevels, im_cliques_fraclevels, metric_mtx=metric_mtx, fracsteps=fracsteps, master=master, beamparams=beamparams)
  132.  
  133.  
  134. # def save_plot()        
  135.  
  136. def generate_consistency_plot(clique_fraclevels, im_clique_fraclevels, metric_mtx=None, fracsteps=None, beamparams=None, imarr=None, show=True, master=None, gui=True):
  137.     # matplotlib aesthetics
  138.     plt.rc('text', usetex=True)
  139.     plt.rc('font', family='serif')
  140.     plt.rcParams.update({'font.size': 16})
  141.     plt.rcParams['axes.linewidth'] = 2 #set the value globally
  142.     plt.rcParams["font.weight"] = "bold"
  143.     # fig, ax = plt.subplots()
  144.     fig = Figure(figsize=(5,5), dpi=100)
  145.     ax = fig.add_subplot(111)
  146.     cycol = cycle('bgrcmk')
  147.  
  148.     for c, column in enumerate(clique_fraclevels):
  149.         colorc = cycol.next()
  150.         for r, row in enumerate(column):
  151.  
  152.             # adding the images
  153.             lenx = len(clique_fraclevels)
  154.             leny = 0
  155.             for li in clique_fraclevels:
  156.                 if len(li) > leny:
  157.                     leny = len(li)
  158.             sample_image = im_clique_fraclevels[c][r]
  159.             arr_img = sample_image.imvec.reshape(sample_image.xdim, sample_image.ydim)
  160.             imagebox = OffsetImage(arr_img, zoom=0.1, cmap='afmhot')
  161.             imagebox.image.axes = ax
  162.  
  163.             ab = AnnotationBbox(imagebox, ((20./lenx)*c,(20./leny)*r),
  164.                                 xycoords='data',
  165.                                 pad=0.0,
  166.                                 arrowprops=None)
  167.  
  168.             ax.add_artist(ab)
  169.  
  170.             # adding the arrows
  171.             if c+1 != len(clique_fraclevels):
  172.                 for a, ro in enumerate(clique_fraclevels[c+1]):
  173.                     if set(row).issubset(ro):
  174.                         px = c+1
  175.                         px = ((20./lenx)*px)
  176.                         py = a
  177.                         py = (20./leny)*py
  178.                         break
  179.  
  180.                 xx = (20./lenx)*c + (8./lenx)
  181.                 yy = (20./leny)*r
  182.                 ax.arrow(   xx, yy,
  183.                             px - xx - (9./lenx), py- yy,  
  184.                             head_width=0.05,
  185.                             head_length=0.1,
  186.                             color=colorc
  187.                         )
  188.             row.sort()
  189.  
  190.             # adding the text
  191.             txtstring = str(row)
  192.             if len(row) == len(clique_fraclevels[-1][0]):
  193.                 txtstring = '[all]'
  194.             ax.text((20./lenx)*c - (0./lenx), (20./leny)*r  - (10./leny), txtstring, fontsize=6, horizontalalignment='center')
  195.  
  196.     ax.set_xlim(0, 22)
  197.     ax.set_ylim(-1, 22)
  198.  
  199.     for item in [fig, ax]:
  200.         item.patch.set_visible(False)
  201.     fig.patch.set_visible(False)
  202.     ax.axis('off')
  203.  
  204.     if show == True:
  205.         if gui==True:
  206.             # plt.savefig("tmp_consis.png")
  207.             try:
  208.                 master.destroy()
  209.             except:
  210.                 pass
  211.             main = Tk()
  212.  
  213.             slider = Scale(main, from_=0, to=100, label='Cutoff (%)')
  214.             slider.pack()
  215.             change_cutoff_BUTTON = Button(main, text='Rerender graph', command=lambda:change_cut_off(metric_mtx, fracsteps, imarr, beamparams, master=main, cutoff=float(slider.get()/100.)))
  216.             change_cutoff_BUTTON.pack()
  217.  
  218.             save_fig_BUTTON = Button(main, text='Save plot', command=lambda:plt.savefig("test_outpout.png"))
  219.             save_fig_BUTTON.pack()
  220.  
  221.             # plt.show()
  222.             canvas = FigureCanvasTkAgg(fig)
  223.             # canvas.show()
  224.             canvas.get_tk_widget().pack(side=BOTTOM, fill=BOTH, expand=True)
  225.  
  226.             # toolbar = NavigationToolbar2TkAgg(canvas)
  227.             # toolbar.update()
  228.             canvas._tkcanvas.pack(side=TOP, fill=BOTH, expand=True)
  229.  
  230.             main.mainloop()
  231.         else:
  232.             plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement