Guest User

Untitled

a guest
Jan 19th, 2018
110
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 7.31 KB | None | 0 0
  1. """Convert generic type name (real) in PyTorch project
  2. to whatever you want.
  3.  
  4.  
  5. Author: Roger-luo
  6. """
  7.  
  8. import os
  9. import re
  10. import sys
  11. import shutil
  12. import inspect
  13.  
  14.  
  15. class Token(object):
  16. """token name in torch c source.
  17.  
  18. This class will parse the numerical type
  19. for each token name.
  20. """
  21.  
  22. typenames = [
  23. 'real',
  24. 'complex',
  25. 'ntype',
  26. 'double',
  27. 'float',
  28. 'zdouble',
  29. 'zfloat',
  30. 'int',
  31. 'long',
  32. 'short',
  33. ]
  34.  
  35. def __init__(self, text, meta=None):
  36. self.str = text
  37.  
  38. # split by '_' and uppercase
  39. src = []
  40. for each in re.split(r'(_)', text):
  41. src.extend(re.split(r'([A-Z][a-z]+)', each))
  42. self.names = src
  43.  
  44. self.with_prefix('acc')
  45. self.with_prefix('u')
  46.  
  47. for name, method in inspect.getmembers(self):
  48. if name.startswith('init_'):
  49. method()
  50.  
  51. def with_prefix(self, prefix):
  52. src = []
  53. for each in self.names:
  54. if each.lower() == prefix + 'real':
  55. src.extend([each[:len(prefix)], each[len(prefix):]])
  56. else:
  57. src.append(each)
  58. self.names = src
  59.  
  60. @staticmethod
  61. def pattern(text):
  62. if text.istitle():
  63. return 'title'
  64. elif text.isupper():
  65. return 'upper'
  66. elif text.islower():
  67. return 'lower'
  68. else:
  69. raise ValueError("invalid text")
  70.  
  71. def init_dtype(self):
  72. for ind, each in enumerate(self.names):
  73. if each.lower() in self.typenames:
  74. self._dtype = dict(
  75. name=each.lower(),
  76. pattern=self.pattern(each),
  77. index=ind,
  78. )
  79. break
  80.  
  81. @property
  82. def dtype(self):
  83. out = getattr(self, '_dtype', None)
  84. if out is not None:
  85. return out['name']
  86. else:
  87. return out
  88.  
  89. @dtype.setter
  90. def dtype(self, val):
  91. # check if dtype exists
  92. if self._dtype is None:
  93. raise ValueError("variable name does not have dtype")
  94. # check type
  95. lower = val.lower()
  96. if lower not in self.typenames:
  97. raise TypeError("Invalid type")
  98.  
  99. # match pattern
  100. dtype = lower
  101. if self._dtype['pattern'] == 'title':
  102. dtype = dtype.title()
  103. elif self._dtype['pattern'] == 'upper':
  104. dtype = dtype.upper()
  105.  
  106. self._dtype['name'] = lower
  107. self.names[self._dtype['index']] = dtype
  108.  
  109. def __repr__(self):
  110. dtype = self.dtype if self.dtype is not None else 'none'
  111. return "TOKEN{" + dtype + "}[" + ''.join(self.names) + "]"
  112.  
  113. def __str__(self):
  114. return ''.join(self.names)
  115.  
  116.  
  117. class THTokenName(object):
  118. """change names in torch c source
  119. files from real to ntype
  120. """
  121.  
  122. rules = [
  123. re.compile(r'(;)'), # block
  124. re.compile(r'(\()'), # inline parathesis
  125. re.compile(r'(\))'),
  126. re.compile(r'(\[)'),
  127. re.compile(r'(\])'),
  128. re.compile(r'({)'),
  129. re.compile(r'(})'),
  130. re.compile(r'(<)'), # CPP/CUDA specifier
  131. re.compile(r'(>)'),
  132. re.compile(r'(`)'), # markdown
  133. re.compile(r'(\')'), # python string
  134. re.compile(r'(")'),
  135. re.compile(r'(#)'), # macro
  136. re.compile(r'(\.)'), # operator
  137. re.compile(r'(\\)'),
  138. re.compile(r'(\w+)(\s*)(\*)'), # pointers
  139. re.compile(r'(,)'), # commas
  140. re.compile(r'(\s+)'), # spaces
  141. re.compile(r'(_)'), # underlines
  142. ]
  143.  
  144. def manipulate(self, src):
  145. for rule in self.rules:
  146. out = []
  147. for each in src:
  148. out.extend(re.split(rule, each))
  149. src = out
  150. return out
  151.  
  152. def tokenize(self, src):
  153. """split source code following self.rules
  154. and methods begin with split_, e.g split_block
  155.  
  156. def split_block(self, src):
  157. out = []
  158. for each in src:
  159. out.extend(re.split(r'(;)', each))
  160. return out
  161. """
  162. tokens = [src]
  163. tokens = self.manipulate(tokens)
  164. for name, method in inspect.getmembers(self):
  165. if name.startswith('split'):
  166. tokens = method(tokens)
  167. return tokens
  168.  
  169. def split_varname(self, src):
  170. for ind, each in enumerate(src):
  171. m = re.match(r'[A-Za-z]+', each)
  172. if m is not None:
  173. src[ind] = Token(each)
  174. return src
  175.  
  176.  
  177. class THComplexRename(THTokenName):
  178. """change real in torch source names
  179. to num.
  180. """
  181.  
  182. static_src = 'torch/lib/'
  183. c_src_dirs = [
  184. 'TH', 'THC', 'THS', 'THCS', 'THD', 'ATen',
  185. 'THNN', 'THCUNN',
  186. ]
  187.  
  188. def __init__(self, src, target,
  189. static_src=None,
  190. c_src_dirs=None,
  191. tname='ntype',
  192. ):
  193. super(THComplexRename, self).__init__()
  194. self.root = os.path.abspath(src)
  195. self.target = os.path.abspath(target)
  196. self.tname = tname
  197. if static_src is not None:
  198. self.static_src = static_src
  199. if c_src_dirs is not None:
  200. self.c_src_dirs = c_src_dirs
  201.  
  202. def rename_src(self, src):
  203. tokens = self.tokenize(src)
  204. out = []
  205. for each in tokens:
  206. if isinstance(each, Token) and each.dtype == 'real':
  207. each.dtype = self.tname
  208. out.append(each)
  209. return ''.join(str(each) for each in out)
  210.  
  211. def rename_file(self, path):
  212. with open(path, 'r') as f:
  213. raw = f.read()
  214. return self.rename_src(raw)
  215.  
  216. def rename_dir(self, path):
  217. src_path = os.path.join(self.root, path)
  218. target_path = os.path.join(self.target, path)
  219. # make target directory
  220. os.makedirs(target_path, exist_ok=True)
  221.  
  222. # walk through source directory
  223. for dirpath, dirnames, filenames in os.walk(src_path):
  224. sub_dir_relpath = os.path.relpath(dirpath, src_path)
  225. target_dir = os.path.join(target_path, sub_dir_relpath)
  226. os.makedirs(target_dir, exist_ok=True)
  227.  
  228. for file in filenames:
  229. msg = 'processing: %s' % os.path.join(dirpath, file)
  230. print(msg)
  231. with open(os.path.join(target_dir, file), 'w') as f:
  232. f.write(self.rename_file(os.path.join(dirpath, file)))
  233.  
  234. def rename(self):
  235. if os.path.isdir(self.target):
  236. print("Warning: target dir exist\nrewrite?[y/n]:", end='')
  237. if sys.stdin.read(1) == 'y':
  238. shutil.rmtree(self.target)
  239. else:
  240. return
  241.  
  242. shutil.copytree(self.root, self.target)
  243. # libraries
  244. for each in self.c_src_dirs:
  245. self.rename_dir(os.path.join(self.static_src, each))
  246.  
  247. # torch csrc
  248. self.rename_dir('torch/csrc')
  249.  
  250. # tools
  251. self.rename_dir('tools')
  252.  
  253. # test
  254. self.rename_dir('test')
  255.  
  256. # copy build script
  257. shutil.copyfile(
  258. os.path.join(self.root, self.static_src, 'build_libs.sh'),
  259. os.path.join(self.target, self.static_src, 'build_libs.sh')
  260. )
  261.  
  262.  
  263. if __name__ == '__main__':
  264. torch = THComplexRename(
  265. 'pytorch', # source dir
  266. 'complex', # target dir
  267. tname='ntype' # target name in lowercase
  268. )
  269. torch.rename()
  270. # torch.rename_dir('TH')
Add Comment
Please, Sign In to add comment