import os
import pickle
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import itertools
from functools import partial
from collections import defaultdict
from rdkit import Chem
from concurrent.futures import ThreadPoolExecutor


def consume_shards(serialize_fns, schema_map, shards):
    """
    Collates a list of per-sample dicts (shards) into a pyarrow Table.
    Columns listed in schema_map are renamed and run through their
    serialization function (e.g. pickled) before being converted to arrays.
    """
    column_stage = defaultdict(list)
    for shard in shards:
        for k, v in shard.items():
            column_stage[k].append(v)
    arrays = []
    names = []
    for k, v in column_stage.items():
        new_key = schema_map.get(k)
        if new_key:
            # serialize_fns maps column name -> (serialize, deserialize)
            sfn = serialize_fns.get(new_key)[0]
            v = [sfn(x) for x in v]
            names.append(new_key)
        else:
            names.append(k)
        arrays.append(pa.array(v))
    table = pa.Table.from_arrays(arrays, names)
    return table
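
# Illustrative sketch: with empty serialize_fns and schema_map, consume_shards
# simply pivots row dicts into columns, e.g.
#   consume_shards({}, {}, [{'X': 'CC=CC', 'y': 3}, {'X': 'CCCCC', 'y': 4}])
# returns a pyarrow Table with columns X = ['CC=CC', 'CCCCC'] and y = [3, 4].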


class Megaset:
    """
    Dataset class for processing and training over datasets that won't fit
    into memory. Uses a memory-mapped parquet file under the hood.

    Batch size is set during initialization. This corresponds to the row group
    size of the underlying parquet file.

    The parquet file is generated if a shard generator is supplied; otherwise
    it is read from disk, one row group (batch) at a time, during iteration of
    the batch_generator.

    The shard generator must yield one dictionary per sample (datapoint),
    like so: { 'X': 'CC=CC', 'y': 3, 'mycoolfield': 'a' }
    The keys can be anything that would pass as a column name in the parquet file.

    During batch reading, a whole row group (batch) is read at once. The
    corresponding dictionary will have all the keys from the dicts yielded by
    the shard generator, but the values will be lists of the originals, i.e.:
    { 'X': ['CC=CC', 'CCCCC', ...], 'y': [3, 4, ...], 'mycoolfield': ['a', 'b', ...] }
    """

    def __init__(self,
                 path,
                 batch_size=100,
                 shard_generator=None,
                 num_workers=8):
        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.batch_multiple = 8
        self.column_staging = defaultdict(list)
        self.serialize_fns = dict()
        self.schema_map = dict()
        self.shard_generator = shard_generator
        if self.shard_generator is not None:
            # Build the parquet file from the shard generator before opening it.
            self._infer_schema()
            self._generate_from_shards()
        self.datafile = pq.ParquetFile(self.path)
        if self.shard_generator is None:
            # Existing file: recover which columns were pickled from the schema.
            self._load_schema()

    def _infer_schema(self):
        """
        Infers the schema for the parquet file from the first sample of the
        shard generator. Strings and bytes map directly to parquet types;
        anything else (e.g. numpy arrays or rdkit Mol objects) is stored
        pickled in a binary column prefixed with '__pkl__'.
        """
        column_schema = []
        shard = next(self.shard_generator())
        for k, v in shard.items():
            if isinstance(v, str):
                column_schema.append(pa.field(k, pa.string()))
            elif isinstance(v, bytes):
                column_schema.append(pa.field(k, pa.binary()))
            elif isinstance(v, (object, np.ndarray)):
                # Catch-all: pickle arbitrary Python objects into a binary column.
                new_key = '__pkl__' + k
                column_schema.append(pa.field(new_key, pa.binary()))
                self.schema_map[k] = new_key
                self.serialize_fns[new_key] = (
                    lambda x: pickle.dumps(x, pickle.HIGHEST_PROTOCOL),
                    pickle.loads)
        self.schema = pa.schema(column_schema)

    def _load_schema(self):
        # Recover the schema from an existing parquet file and register the
        # pickle deserializer for any '__pkl__' columns.
        self.schema = self.datafile.metadata.schema
        for col_def in self.schema:
            name = col_def.name
            if name.startswith('__pkl__'):
                self.serialize_fns[name] = (None, pickle.loads)

    def _generate_from_shards(self):
        """
        Writes out the parquet file from the shard generator.
        """
        writer = pq.ParquetWriter(self.path, self.schema)
        shard_gen = self.shard_generator()
        num_samples_processed = 0
        while True:
            # Pull up to num_workers chunks of batch_size samples each.
            work = [
                x for x in [
                    list(itertools.islice(shard_gen, self.batch_size))
                    for _ in range(self.num_workers)
                ] if x
            ]
            if not work:
                break
            partial_consume_shards = partial(
                consume_shards, self.serialize_fns, self.schema_map)
            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                futures = executor.map(partial_consume_shards, work)
                for batch in futures:
                    num_samples_processed += batch.num_rows
                    print(num_samples_processed)
                    # Each table becomes one row group in the parquet file.
                    writer.write_table(batch)
        # Close the writer so the parquet footer is written.
        writer.close()

    def __len__(self):
        return self.datafile.metadata.num_rows

    def batch_generator(self):
        """
        Returns a generator function that yields a dictionary for each batch.
        The keys of the dictionary are the column names of the parquet file;
        the values are lists of length batch_size.
        """

        def batch_generator():
            for row_group in range(self.datafile.num_row_groups):
                rg = self.datafile.read_row_group(row_group).to_pydict()
                if self.serialize_fns:
                    # Deserialize pickled columns back into Python objects and
                    # restore the original key names (strip the '__pkl__' prefix).
                    for k, v in self.serialize_fns.items():
                        sfn = v[1]
                        old_key = k[7:]  # len('__pkl__') == 7
                        with ThreadPoolExecutor(
                                max_workers=self.num_workers) as executor:
                            rg[old_key] = list(executor.map(sfn, rg[k]))
                        del rg[k]
                yield rg

        return batch_generator


if __name__ == '__main__':

    def pickle_shard_gen():
        # this is an example shard generator function
        # It should yield a dict where the keys will correspond
        # to columns in the parquet column store
        csv_file = './smaller.csv'
        with open(csv_file, 'r') as f:
            for line in f:
                smiles = line.strip("\r\n ").split(',')[0]
                mol = Chem.MolFromSmiles(smiles)
                yield {'X': mol}

    # write a megaset to disk using the shard generator to featurize samples
    megaset = Megaset('./p.dat', shard_generator=pickle_shard_gen)

    # load it from disk
    megaset = Megaset('./p.dat')

    # get a batch generator
    bg = megaset.batch_generator()
    for batch in bg():
        # prints a list of mol objects for the batch
        print(batch['X'])

    from ligand_ml.feat.graph_features import ConvMolFeaturizer
    featurizer = ConvMolFeaturizer()

    def numpy_shard_gen():
        # same as above but this time we're saving numpy arrays
        csv_file = './smaller.csv'
        with open(csv_file, 'r') as f:
            for line in f:
                smiles = line.strip("\r\n ").split(',')[0]
                mol = Chem.MolFromSmiles(smiles)
                conv_mol = featurizer._featurize(mol)
                feature_matrix = conv_mol.get_atom_features()
                yield {'X': feature_matrix}

    megaset = Megaset('./p.dat', shard_generator=numpy_shard_gen)
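
    # Usage sketch: the numpy-array megaset is read back the same way as above;
    # each batch['X'] is a list of per-sample feature matrices, deserialized
    # from the pickled '__pkl__X' column.
    megaset = Megaset('./p.dat')
    bg = megaset.batch_generator()
    for batch in bg():
        print([x.shape for x in batch['X']])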