mario119

Data pre-processing_1

Mar 8th, 2022 (edited)
383
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.82 KB | None | 0 0
  1. import pyvista as pv
  2. from pyvista import examples
  3. import pyacvd
  4. import trimesh
  5. from tqdm import tqdm
  6. import torch
  7. import numpy as np
  8. from torch_geometric.datasets import MNISTSuperpixels
  9. from torch_geometric.data import Data
  10. import torch.nn as nn
  11. import torch.nn.functional as F
  12. from torch_geometric.data import DataLoader
  13. from torch_geometric.nn import SplineConv
  14. from torch_scatter import scatter
  15. import os
  16. import os.path as osp
  17. import torch_geometric.transforms as T
  18. import matplotlib.pyplot as plt
  19. import matplotlib.cm as cm
  20. from scipy.spatial import Delaunay
  21.  
  22.  
  23. path = osp.join('.','..','data', 'MNIST')
  24.  
  25. transform = transform = T.Compose([T.Delaunay(), T.FaceToEdge(remove_faces=False)])
  26.  
  27. mnist_dataset = MNISTSuperpixels(path, True, transform=transform)[:10000] # True means it's train dataset
  28.  
  29. class MeshData(Data):
  30. def __init__(self,x=None, y=None, num_nodes=None,
  31. edge_index0=None,edge_attr0=None,cluster0=None,
  32. edge_index1=None,edge_attr1=None,cluster1=None,
  33. edge_index2=None,edge_attr2=None,cluster2=None
  34. ):
  35. super().__init__()
  36. self.x = x
  37. self.y = y
  38. self.num_nodes = num_nodes
  39. self.edge_index0,self.edge_attr0,self.cluster0 = edge_index0,edge_attr0,cluster0
  40. self.edge_index1,self.edge_attr1,self.cluster1 = edge_index1,edge_attr1,cluster1
  41. self.edge_index2,self.edge_attr2,self.cluster2 = edge_index2,edge_attr2,cluster2
  42. def __inc__(self,key,value,*args,**kawrgs):
  43. if key=='x':
  44. return self.x.size(0)
  45. if key=='y':
  46. return self.y.size(0)
  47. if key=='edge_index0':
  48. return self.cluster0.size(0)
  49. if key=='edge_index1':
  50. return self.cluster1.size(0)
  51. if key=='edge_index2':
  52. return self.cluster2.size(0)
  53. if key=='cluster0':
  54. return self.cluster0.max()+1
  55. if key=='cluster1':
  56. return self.cluster1.max()+1
  57. if key=='cluster2':
  58. return self.cluster2.max()+1
  59. else:
  60. return super().__inc__(key,value,*args,**kawrgs)
  61.  
  62. # Only Use for Initial Data Generation
  63. entire_data = pv.MultiBlock()
  64. for i in tqdm(range(len(mnist_dataset))):
  65. points0 = np.vstack((np.array(mnist_dataset[i].pos).T, np.zeros(75))).T
  66. face0 = np.vstack((np.ones(np.array(mnist_dataset[i].face).shape[1])*3, np.array(mnist_dataset[i].face))).T.reshape(-1)
  67. face0 = np.intc(face0) # Face should be shape(-1), which is acceptable as PolyData// Vertices shape = (x, 3)
  68.  
  69. mesh0 = pv.PolyData(points0, face0)
  70. mesh0['x'] = np.array(mnist_dataset[i].x) # store node feature in the mesh data
  71. entire_data.append(mesh0)
  72.  
  73. # Only Use for Initial Data Generation
  74. edge_index_all = []
  75. edge_attr_all = []
  76. cluster_all = []
  77. depth = 3 # Number of coarsened graphs for pooling
  78. mesh_all = []
  79. for i in tqdm(range(len(entire_data))): # cluster1.nclus = n1
  80. for d in range(depth):
  81.  
  82. edge_index = [[], []]
  83. edge_attr = []
  84. cluster = []
  85.  
  86. # for i in tqdm(range(len(entire_data))):
  87. if d==0 :
  88. mesh0 = entire_data[i]
  89. else :
  90. mesh0 = mesh1
  91.  
  92. n0 = mesh0.n_points
  93. reduction_rate = 0.2
  94. n1 = int(reduction_rate*n0)
  95.  
  96. cluster0 = pyacvd.Clustering(mesh0)
  97. cluster0.cluster(n1)
  98. pos = mesh0.points
  99.  
  100. for [p,q] in cluster0._edges.tolist():
  101. edge_index[0].extend([p,q])
  102. edge_index[1].extend([q,p])
  103. edge_attr.extend([pos[q]-pos[p], pos[p]-pos[q]])
  104. edge_attr = np.array(edge_attr)
  105. edge_index_all.append(edge_index)
  106. edge_attr_all.append(edge_attr)
  107. cluster_all.append(cluster0.clusters)
  108. if d !=(depth-1):
  109. mesh1 = cluster0.create_mesh()
  110.  
  111. g = MeshData(x = torch.tensor(entire_data[i]['x'], dtype=torch.float32),
  112. y = torch.tensor(mnist_dataset[i].y),
  113. edge_index0 = torch.tensor(edge_index_all[0],dtype=torch.long),
  114. edge_index1 = torch.tensor(edge_index_all[1],dtype=torch.long),
  115. edge_index2 = torch.tensor(edge_index_all[2],dtype=torch.long),
  116. edge_attr0 = torch.tensor(edge_attr_all[0],dtype=torch.float32),
  117. edge_attr1 = torch.tensor(edge_attr_all[1],dtype=torch.float32),
  118. edge_attr2 = torch.tensor(edge_attr_all[2],dtype=torch.float32),
  119. cluster0 = torch.tensor(cluster_all[0],dtype=torch.long),
  120. cluster1 = torch.tensor(cluster_all[1],dtype=torch.long),
  121. cluster2 = torch.tensor(cluster_all[2],dtype=torch.long),
  122. )
  123. mesh_all.append(g)
  124.  
  125.  
  126. # Normalize edge_attrs and save the all mesh data
  127.  
  128. xmm0,ymm0,zmm0 = [],[],[]
  129. xmm1,ymm1,zmm1 = [],[],[]
  130. xmm2,ymm2,zmm2 = [],[],[]
  131.  
  132. for p in tqdm(mesh_all):
  133. xmm0.append(p.edge_attr0[:,0].max())
  134. ymm0.append(p.edge_attr0[:,1].max())
  135. zmm0.append(p.edge_attr0[:,2].max())
  136.  
  137. xmm1.append(p.edge_attr1[:,0].max())
  138. ymm1.append(p.edge_attr1[:,1].max())
  139. zmm1.append(p.edge_attr1[:,2].max())
  140.  
  141. xmm2.append(p.edge_attr2[:,0].max())
  142. ymm2.append(p.edge_attr2[:,1].max())
  143. zmm2.append(p.edge_attr2[:,2].max())
  144.  
  145. # xyzm0 = np.array([max(xmm0),max(ymm0),max(zmm0)])
  146. # xyzm1 = np.array([max(xmm1),max(ymm1),max(zmm1)])
  147. # xyzm2 = np.array([max(xmm2),max(ymm2),max(zmm2)])
  148.  
  149. # for 2D (x,y)
  150. xyzm0 = np.array([max(xmm0),max(ymm0)])
  151. xyzm1 = np.array([max(xmm1),max(ymm1)])
  152. xyzm2 = np.array([max(xmm2),max(ymm2)])
  153.  
  154. # print(xyzm0,xyzm1,xyzm2)
  155. print(xyzm0, xyzm1, xyzm2)
  156. for p in tqdm(mesh_all):
  157. p.edge_attr0 = 0.05+(p.edge_attr0[:,:2]+xyzm0)/(2.2222222*xyzm0)
  158. p.edge_attr1 = 0.05+(p.edge_attr1[:,:2]+xyzm1)/(2.2222222*xyzm1)
  159. p.edge_attr2 = 0.05+(p.edge_attr2[:,:2]+xyzm2)/(2.2222222*xyzm2)
  160.  
  161.  
  162. torch.save(mesh_all,os.path.join('./','mesh_graph.pt'))
Add Comment
Please, Sign In to add comment