Advertisement
Guest User

Untitled

a guest
Jul 22nd, 2018
87
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.92 KB | None | 0 0
  1. #!/usr/bin/env python
  2.  
  3. # test edit
  4. from __future__ import print_function
  5. from itertools import cycle
  6. from matplotlib.offsetbox import (TextArea, DrawingArea, OffsetImage,
  7.                                   AnnotationBbox)
  8.  
  9. import sys
  10. import os
  11. import argparse
  12.  
  13. import ehtim as eh
  14. import matplotlib.pyplot as plt
  15. import networkx as nx
  16. import numpy as np
  17.  
  18. import matplotlib.pyplot as plt
  19. from matplotlib.patches import Circle
  20. from matplotlib.offsetbox import (TextArea, DrawingArea, OffsetImage,
  21.                                   AnnotationBbox)
  22. import glob
  23. from itertools import cycle
  24. import copy
  25.  
  26.  
  27. def image_consistency(imarr, beamparams, metric='nxcorr', blursmall=True, beam_max=1.0, beam_steps=5, savepath=[]):
  28.    
  29.     # get the pixel sizes and fov to compare images at
  30.     (min_psize, max_fov) = get_psize_fov(imarr)
  31.    
  32.     # initialize matrix matrix
  33.     metric_mtx = np.zeros([len(imarr), len(imarr), beam_steps])
  34.    
  35.     # get the different fracsteps
  36.     fracsteps = np.linspace(0,beam_max,beam_steps)
  37.    
  38.     # loop over the different beam sizes
  39.     for fracidx in range(beam_steps):
  40.         #print(fracidx)
  41.         # look at every pair of images and compute their beam convolved metrics
  42.         for i in range(len(imarr)):
  43.             img1 = imarr[i]
  44.             if fracsteps[fracidx]>0:
  45.                 img1 = img1.blur_gauss(beamparams, fracsteps[fracidx])
  46.    
  47.             for j in range(i+1, len(imarr)):
  48.                 img2 = imarr[j]
  49.                 if fracsteps[fracidx]>0:
  50.                     img2 = img2.blur_gauss(beamparams, fracsteps[fracidx])
  51.                
  52.                 #print(j, i, fracidx)
  53.            
  54.                 # compute image comparision under a specified blur_frac
  55.                 (error, im1_pad, im2_shift) = img1.compare_images(img2, metric = [metric], psize = min_psize, target_fov = max_fov, blur_frac=0.0, beamparams=beamparams)
  56.                
  57.                 # if specified save the shifted images used for comparision
  58.                 if savepath:
  59.                     im1_pad.save_fits(savepath + '/' + str(i) + '_' + str(fracidx) + '.fits')
  60.                     im2_shift.save_fits(savepath + '/' + str(j) +  '_' + str(fracidx) + '.fits')
  61.  
  62.                 # save the metric value in a matrix
  63.                 metric_mtx[i,j,fracidx] = error[0]
  64.    
  65.     return (metric_mtx, fracsteps)
  66.    
  67.    
  68.    
  69. # look over an array of images and determine the min pixel size and max fov that can be used consistently across them
  70. def get_psize_fov(imarr):  
  71.     min_psize = 100
  72.     for i in range(0, len(imarr)):
  73.         if i==0:
  74.             max_fov = np.max([imarr[i].psize*imarr[i].xdim, imarr[i].psize*imarr[i].ydim])
  75.             min_psize = imarr[i].psize
  76.         else:
  77.             max_fov = np.max([max_fov, imarr[i].psize*imarr[i].xdim, imarr[i].psize*imarr[i].ydim])
  78.             min_psize = np.min([min_psize, imarr[i].psize])
  79.     return (min_psize, max_fov)
  80.    
  81.  
  82.  
  83. def image_agreements(imarr, beamparams, metric_mtx, fracsteps, cutoff=0.95):
  84.    
  85.     (min_psize, max_fov) = get_psize_fov(imarr)
  86.    
  87.     im_cliques_fraclevels = []
  88.     cliques_fraclevels = []
  89.     for fracidx in range(len(fracsteps)):
  90.         #print(fracidx)
  91.    
  92.         slice_metric_mtx = metric_mtx[:,:,fracidx]
  93.         cuttoffidx = np.where( slice_metric_mtx >= cutoff)
  94.         consistant = zip(*cuttoffidx)
  95.        
  96.         # make graph
  97.         G=nx.Graph()
  98.         for i in range(len(consistant)):
  99.             G.add_edge(consistant[i][0], consistant[i][1])
  100.        
  101.         # find all cliques
  102.         cliques = list(nx.find_cliques(G))
  103.         #print(cliques)
  104.        
  105.         cliques_fraclevels.append(cliques)
  106.        
  107.         im_clique = []
  108.         for c in range(len(cliques)):
  109.             clique = cliques[c]
  110.             im_avg = imarr[clique[0]].blur_gauss(beamparams,fracsteps[fracidx])
  111.            
  112.             for n in range(1,len(clique)):
  113.                 (error, im_avg, im2_shift) = im_avg.compare_images(imarr[clique[n]].blur_gauss(beamparams,fracsteps[fracidx]) , metric = ['xcorr'], psize = min_psize, target_fov = max_fov, blur_frac=0.0,
  114.                          beamparams=beamparams)
  115.                 im_avg.imvec = (im_avg.imvec + im2_shift.imvec ) / 2.0
  116.                
  117.        
  118.             im_clique.append(im_avg.copy())
  119.        
  120.         im_cliques_fraclevels.append(im_clique)
  121.        
  122.     return(cliques_fraclevels, im_cliques_fraclevels)
  123.            
  124.  
  125. def change_cut_off(metric_mtx, fracsteps, imarr, beamparams, cutoff=0.95, zoom=0.1, fov=1):
  126.     (cliques_fraclevels, im_cliques_fraclevels) = image_agreements(imarr, beamparams, metric_mtx, fracsteps, cutoff=cutoff)
  127.     generate_consistency_plot(cliques_fraclevels, im_cliques_fraclevels, metric_mtx=metric_mtx, fracsteps=fracsteps, beamparams=beamparams, zoom=zoom, fov=fov, cutoff=cutoff)
  128.            
  129.  
  130. def generate_consistency_plot(clique_fraclevels, im_clique_fraclevels, zoom=0.1, fov=1, show=True, framesize=(20,10), fracsteps=None, cutoff=None):
  131.  
  132.     fig, ax = plt.subplots(figsize=framesize)
  133.     cycol = cycle('bgrcmk')
  134.  
  135.     x_loc = []
  136.      
  137.     for c, column in enumerate(clique_fraclevels):
  138.         colorc = cycol.next()
  139.         x_loc.append(((20./len(clique_fraclevels))*c))
  140.         for r, row in enumerate(column):
  141.  
  142.             # adding the images
  143.             lenx = len(clique_fraclevels)
  144.             leny = 0
  145.             for li in clique_fraclevels:
  146.                 if len(li) > leny:
  147.                     leny = len(li)
  148.             sample_image = im_clique_fraclevels[c][r].regrid_image(fov*im_clique_fraclevels[c][r].fovx(), 512)
  149.             arr_img = sample_image.imvec.reshape(sample_image.xdim, sample_image.ydim)
  150.             imagebox = OffsetImage(arr_img, zoom=zoom, cmap='afmhot')
  151.            
  152.             imagebox.image.axes = ax
  153.              
  154.        
  155.             ab = AnnotationBbox(imagebox, ((20./lenx)*c,(20./leny)*r),
  156.                                 xycoords='data',
  157.                                 pad=0.0,
  158.                                 arrowprops=None)
  159.  
  160.             ax.add_artist(ab)
  161.  
  162.             # adding the arrows
  163.             if c+1 != len(clique_fraclevels):
  164.                 for a, ro in enumerate(clique_fraclevels[c+1]):
  165.                     if set(row).issubset(ro):
  166.                         px = c+1
  167.                         px = ((20./lenx)*px)
  168.                         py = a
  169.                         py = (20./leny)*py
  170.                         break
  171.  
  172.                 xx = (20./lenx)*c + (8./lenx)
  173.                 yy = (20./leny)*r
  174.                 ax.arrow(   xx, yy,
  175.                             px - xx - (9./lenx), py- yy,  
  176.                             head_width=0.05,
  177.                             head_length=0.1,
  178.                             color=colorc
  179.                         )
  180.             row.sort()
  181.             # adding the text
  182.             txtstring = str(row)
  183.             print ("TEST")
  184.             # print(ab.get_window_extent())
  185.             if len(row) == len(clique_fraclevels[-1][0]):
  186.                 txtstring = '[all]'
  187.  
  188.             # ax.text((20./lenx)*c - (0./lenx), (20./leny)*r  - (10./leny), txtstring, fontsize=6, horizontalalignment='center')
  189.             ax.text((20./lenx)*c,(20./leny)*(r-0.5), txtstring, fontsize=10, horizontalalignment='center', color='black', zorder=1000)
  190.  
  191.     ax.set_xlim(0, 22)
  192.     ax.set_ylim(-10, 22)
  193.  
  194.     ax.set_xticks(x_loc)
  195.     ax.set_xticklabels(fracsteps)
  196.  
  197.  
  198.     ax.set_yticks([])
  199.     ax.set_yticklabels([])
  200.  
  201.     ax.spines['right'].set_visible(False)
  202.     ax.spines['top'].set_visible(False)
  203.     ax.spines['left'].set_visible(False)
  204.     ax.spines['bottom'].set_visible(False)
  205.  
  206.     ax.set_title('Blurred comparison of all images; cutoff={0}, fov (uas)={1}'.format(str(cutoff), str(im_clique_fraclevels[0][0].fovx()/eh.RADPERUAS)))
  207.  
  208.  
  209. #     for item in [fig, ax]:
  210. #         item.patch.set_visible(False)
  211. #     fig.patch.set_visible(False)
  212. #     ax.axis('off')
  213.     if show == True:
  214.         plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement