Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- """Convert generic type name (real) in PyTorch project
- to whatever you want.
- Author: Roger-luo
- """
- import os
- import re
- import sys
- import shutil
- import inspect
class Token(object):
    """One identifier from the torch C sources, split into name fragments.

    The text is cut at underscores and at capitalised words
    (``[A-Z][a-z]+``), and the ``acc`` / ``u`` prefixes are separated
    from ``real``.  The first fragment that, lowercased, appears in
    ``typenames`` is recorded as the token's numerical type together
    with its letter-case pattern and its position in ``names``.

    Attributes:
        str: the original, unmodified text.
        names: list of fragments; ``''.join(names)`` is the current text.
    """

    # Lowercase type names recognised inside an identifier.
    typenames = [
        'real',
        'complex',
        'ntype',
        'double',
        'float',
        'zdouble',
        'zfloat',
        'int',
        'long',
        'short',
    ]

    def __init__(self, text, meta=None):
        """Split ``text`` into fragments and run every ``init_*`` hook.

        ``meta`` is accepted for interface compatibility and is not used.

        Raises:
            ValueError: if the fragment holding the type name is mixed
                case (see ``pattern``).
        """
        self.str = text

        # Split on '_' and on capitalised words; the capture groups keep
        # the separators so that joining the fragments restores the text.
        src = []
        for each in re.split(r'(_)', text):
            src.extend(re.split(r'([A-Z][a-z]+)', each))
        self.names = src

        self.with_prefix('acc')
        self.with_prefix('u')

        # Every member whose name starts with 'init_' is called as an
        # initialisation hook, so subclasses can add their own.
        for name, method in inspect.getmembers(self):
            if name.startswith('init_'):
                method()

    def with_prefix(self, prefix):
        """Split fragments equal to ``prefix + 'real'`` (case-insensitive)
        into the prefix part and the ``real`` part.
        """
        src = []
        for each in self.names:
            if each.lower() == prefix + 'real':
                src.extend([each[:len(prefix)], each[len(prefix):]])
            else:
                src.append(each)
        self.names = src

    @staticmethod
    def pattern(text):
        """Return ``'title'``, ``'upper'`` or ``'lower'`` for ``text``.

        Raises:
            ValueError: if ``text`` matches none of the three patterns.
        """
        if text.istitle():
            return 'title'
        elif text.isupper():
            return 'upper'
        elif text.islower():
            return 'lower'
        else:
            raise ValueError("invalid text")

    def init_dtype(self):
        """Record the first fragment that names a numerical type.

        ``_dtype`` is left unset when no fragment matches.
        """
        for ind, each in enumerate(self.names):
            if each.lower() in self.typenames:
                self._dtype = dict(
                    name=each.lower(),
                    pattern=self.pattern(each),
                    index=ind,
                )
                break

    @property
    def dtype(self):
        """Lowercase type name found in the token, or ``None``."""
        out = getattr(self, '_dtype', None)
        if out is not None:
            return out['name']
        else:
            return out

    @dtype.setter
    def dtype(self, val):
        """Replace the type fragment with ``val``, keeping its letter case.

        Raises:
            ValueError: if the token contains no type name.
            TypeError: if ``val`` (lowercased) is not in ``typenames``.
        """
        # ``_dtype`` only exists when init_dtype() found a type name, so
        # it must be looked up defensively; a plain attribute access
        # would raise AttributeError instead of the ValueError below.
        info = getattr(self, '_dtype', None)
        if info is None:
            raise ValueError("variable name does not have dtype")

        lower = val.lower()
        if lower not in self.typenames:
            raise TypeError("Invalid type")

        # Re-apply the letter-case pattern of the fragment being replaced.
        dtype = lower
        if info['pattern'] == 'title':
            dtype = dtype.title()
        elif info['pattern'] == 'upper':
            dtype = dtype.upper()

        info['name'] = lower
        self.names[info['index']] = dtype

    def __repr__(self):
        dtype = self.dtype if self.dtype is not None else 'none'
        return "TOKEN{" + dtype + "}[" + ''.join(self.names) + "]"

    def __str__(self):
        return ''.join(self.names)
class THTokenName(object):
    """Tokenizer used to change names in torch C source files
    from real to ntype.

    ``rules`` is applied in order; every rule is a regular expression
    whose capture groups keep the matched separator as its own token,
    so joining the tokens always restores the input text.
    """

    rules = [
        re.compile(r'(;)'),  # block
        re.compile(r'(\()'),  # inline parenthesis
        re.compile(r'(\))'),
        re.compile(r'(\[)'),
        re.compile(r'(\])'),
        re.compile(r'({)'),
        re.compile(r'(})'),
        re.compile(r'(<)'),  # CPP/CUDA specifier
        re.compile(r'(>)'),
        re.compile(r'(`)'),  # markdown
        re.compile(r'(\')'),  # python string
        re.compile(r'(")'),
        re.compile(r'(#)'),  # macro
        re.compile(r'(\.)'),  # operator
        re.compile(r'(\\)'),
        re.compile(r'(\w+)(\s*)(\*)'),  # pointers
        re.compile(r'(,)'),  # commas
        re.compile(r'(\s+)'),  # spaces
        re.compile(r'(_)'),  # underlines
    ]

    def manipulate(self, src):
        """Split every string in ``src`` by each rule in turn.

        Args:
            src: iterable of strings.

        Returns:
            The list of pieces after all rules were applied; ``src``
            itself when ``rules`` is empty.
        """
        for rule in self.rules:
            out = []
            for each in src:
                out.extend(re.split(rule, each))
            src = out
        # Return ``src`` rather than ``out``: ``out`` is never bound when
        # ``rules`` is empty, which used to raise UnboundLocalError.
        return src

    def tokenize(self, src):
        """Split source code following ``self.rules`` and then every
        member whose name begins with ``split``, e.g. ``split_block``:

            def split_block(self, src):
                out = []
                for each in src:
                    out.extend(re.split(r'(;)', each))
                return out

        Each such member is called with the current token list and must
        return the new token list.
        """
        tokens = [src]
        tokens = self.manipulate(tokens)
        for name, method in inspect.getmembers(self):
            if name.startswith('split'):
                tokens = method(tokens)
        return tokens

    def split_varname(self, src):
        """Replace, in place, every piece that starts with an ASCII
        letter by a ``Token`` built from it, and return ``src``.
        """
        for ind, each in enumerate(src):
            m = re.match(r'[A-Za-z]+', each)
            if m is not None:
                src[ind] = Token(each)
        return src
class THComplexRename(THTokenName):
    """Copy a source tree and change ``real`` in torch source names
    to another type name.

    Attributes:
        root: absolute path of the source tree.
        target: absolute path of the output tree.
        tname: lowercase type name that replaces ``real``.
        static_src: directory, relative to ``root``, that holds the
            library directories listed in ``c_src_dirs``.
        c_src_dirs: library directory names under ``static_src``.
    """

    static_src = 'torch/lib/'
    c_src_dirs = [
        'TH', 'THC', 'THS', 'THCS', 'THD', 'ATen',
        'THNN', 'THCUNN',
    ]

    def __init__(self, src, target,
                 static_src=None,
                 c_src_dirs=None,
                 tname='ntype',
                 ):
        """Store the paths and options.

        ``static_src`` and ``c_src_dirs`` override the class defaults
        only when they are not ``None``.
        """
        super(THComplexRename, self).__init__()
        self.root = os.path.abspath(src)
        self.target = os.path.abspath(target)
        self.tname = tname
        if static_src is not None:
            self.static_src = static_src
        if c_src_dirs is not None:
            self.c_src_dirs = c_src_dirs

    def rename_src(self, src):
        """Return ``src`` with the type of every ``real`` token set to
        ``self.tname``; all other text is reproduced unchanged.
        """
        tokens = self.tokenize(src)
        for each in tokens:
            if isinstance(each, Token) and each.dtype == 'real':
                each.dtype = self.tname
        return ''.join(str(each) for each in tokens)

    def rename_file(self, path):
        """Read the text file at ``path`` and return its renamed content."""
        with open(path, 'r') as f:
            raw = f.read()
        return self.rename_src(raw)

    def rename_dir(self, path):
        """Rename every file below ``root/path`` into ``target/path``.

        The directory layout is recreated under the target.  Files that
        cannot be decoded as text are copied verbatim.
        """
        src_path = os.path.join(self.root, path)
        target_path = os.path.join(self.target, path)
        # make target directory
        os.makedirs(target_path, exist_ok=True)
        # walk through source directory
        for dirpath, _dirnames, filenames in os.walk(src_path):
            sub_dir_relpath = os.path.relpath(dirpath, src_path)
            target_dir = os.path.join(target_path, sub_dir_relpath)
            os.makedirs(target_dir, exist_ok=True)
            for file in filenames:
                src_file = os.path.join(dirpath, file)
                dst_file = os.path.join(target_dir, file)
                msg = 'processing: %s' % src_file
                print(msg)
                # Transform BEFORE opening the target for writing, so a
                # failure cannot leave a truncated (empty) target file.
                try:
                    renamed = self.rename_file(src_file)
                except UnicodeDecodeError:
                    # Not a text file: keep a byte-identical copy.
                    shutil.copyfile(src_file, dst_file)
                    continue
                with open(dst_file, 'w') as f:
                    f.write(renamed)

    def rename(self):
        """Copy ``root`` to ``target`` and rename the known directories.

        When the target already exists the user is asked on stdin; any
        answer other than ``y`` aborts without changes.
        """
        if os.path.isdir(self.target):
            # Flush so the prompt is visible before blocking on stdin.
            print("Warning: target dir exist\nrewrite?[y/n]:", end='',
                  flush=True)
            if sys.stdin.read(1) == 'y':
                shutil.rmtree(self.target)
            else:
                return
        shutil.copytree(self.root, self.target)
        # libraries
        for each in self.c_src_dirs:
            self.rename_dir(os.path.join(self.static_src, each))
        # torch csrc
        self.rename_dir('torch/csrc')
        # tools
        self.rename_dir('tools')
        # test
        self.rename_dir('test')
        # copy build script
        shutil.copyfile(
            os.path.join(self.root, self.static_src, 'build_libs.sh'),
            os.path.join(self.target, self.static_src, 'build_libs.sh')
        )
if __name__ == '__main__':
    # Rewrite the tree in ./pytorch into ./complex, replacing the generic
    # type name with the lowercase target name given as ``tname``.
    renamer = THComplexRename('pytorch', 'complex', tname='ntype')
    renamer.rename()
    # A single directory can be processed with: renamer.rename_dir('TH')
Add Comment
Please, Sign In to add comment