import os
import pickle
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import itertools
from functools import partial
from collections import defaultdict
from rdkit import Chem
from concurrent.futures import ThreadPoolExecutor
def consume_shards(serialize_fns, schema_map, shards):
    column_stage = defaultdict(list)
    for shard in shards:
        for k, v in shard.items():
            column_stage[k].append(v)
    arrays = []
    names = []
    for k, v in column_stage.items():
        new_key = schema_map.get(k)
        if new_key:
            sfn = serialize_fns.get(new_key)[0]
            v = [sfn(x) for x in v]
            names.append(new_key)
        else:
            names.append(k)
        arrays.append(pa.array(v))
    table = pa.Table.from_arrays(arrays, names)
    return table
class Megaset:
    """
    Dataset class for processing and training over datasets that won't fit
    into memory. Uses a memory-mapped parquet file under the hood.

    Batch size is set during initialization. It corresponds to the row group
    size of the underlying parquet file.

    The parquet file is generated if a shard generator is supplied; otherwise
    it is read from disk, one row group (batch) at a time, during iteration of
    the batch_generator.

    The shard generator is required to yield a dictionary per sample
    (datapoint), like so: {'X': 'CC=CC', 'y': 3, 'mycoolfield': 'a'}
    The keys can be anything that would pass as a column name in the parquet
    file.

    During batch reading, a whole row group (batch) is read at once. The
    resulting dictionary has all the keys from the dicts yielded by the shard
    generator, but the values are lists of the originals, i.e.:
    {'X': ['CC=CC', 'CCCCC', ...], 'y': [3, 4, ...], 'mycoolfield': ['a', 'b', ...]}
    """
    def __init__(self,
                 path,
                 batch_size=100,
                 shard_generator=None,
                 num_workers=8):
        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.batch_multiple = 8
        self.column_staging = defaultdict(list)
        self.serialize_fns = dict()
        self.schema_map = dict()
        self.shard_generator = shard_generator
        if self.shard_generator is not None:
            self._infer_schema()
            self._generate_from_shards()
        self.datafile = pq.ParquetFile(self.path)
        if self.shard_generator is None:
            self._load_schema()
    def _infer_schema(self):
        """
        Infers the schema for the parquet file from the first sample
        of the shard generator
        """
        column_schema = []
        shard = next(self.shard_generator())
        for k, v in shard.items():
            if isinstance(v, str):
                column_schema.append(pa.field(k, pa.string()))
            elif isinstance(v, bytes):
                column_schema.append(pa.field(k, pa.binary()))
            elif isinstance(v, (object, np.ndarray)):
                # Everything else (numpy arrays, rdkit mols, ints, ...) falls
                # through to this branch and is pickled into a binary column
                # whose name is prefixed with '__pkl__'.
                new_key = '__pkl__' + k
                column_schema.append(pa.field(new_key, pa.binary()))
                self.schema_map[k] = new_key
                self.serialize_fns[new_key] = (
                    lambda x: pickle.dumps(x, pickle.HIGHEST_PROTOCOL),
                    pickle.loads)
        self.schema = pa.schema(column_schema)
    def _load_schema(self):
        self.schema = self.datafile.metadata.schema
        for col_def in self.schema:
            name = col_def.name
            if name.startswith('__pkl__'):
                self.serialize_fns[name] = (None, pickle.loads)
    def _generate_from_shards(self):
        """
        Writes out the parquet file from the shard generator
        """
        writer = pq.ParquetWriter(self.path, self.schema)
        shard_gen = self.shard_generator()
        num_samples_processed = 0
        while True:
            # slice the generator into up to num_workers chunks of batch_size samples
            work = [
                x for x in [
                    list(itertools.islice(shard_gen, self.batch_size))
                    for _ in range(self.num_workers)
                ] if x
            ]
            if not work:
                break
            partial_consume_shards = partial(
                consume_shards, self.serialize_fns, self.schema_map)
            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                futures = executor.map(partial_consume_shards, work)
                for batch in futures:
                    num_samples_processed += batch.num_rows
                    print(num_samples_processed)
                    writer.write_table(batch)
        # close the writer so the parquet footer/metadata gets written
        writer.close()
    def __len__(self):
        return self.datafile.metadata.num_rows
    def batch_generator(self):
        """
        Returns a generator function that yields a dictionary per batch.
        The keys of the dictionary are the column names of the parquet file;
        the values are lists of length batch_size.
        """

        def batch_generator():
            for row_group in range(self.datafile.num_row_groups):
                rg = self.datafile.read_row_group(row_group).to_pydict()
                if self.serialize_fns:
                    for k, v in self.serialize_fns.items():
                        sfn = v[1]
                        old_key = k[7:]  # strip the '__pkl__' prefix
                        with ThreadPoolExecutor(
                                max_workers=self.num_workers) as executor:
                            rg[old_key] = list(executor.map(sfn, rg[k]))
                        del rg[k]
                yield rg

        return batch_generator
if __name__ == '__main__':

    def pickle_shard_gen():
        # This is an example shard generator function.
        # It should yield a dict per sample whose keys will correspond
        # to columns in the parquet column store.
        csv_file = './smaller.csv'
        with open(csv_file, 'r') as f:
            for line in f:
                smiles = line.strip("\r\n ").split(',')[0]
                mol = Chem.MolFromSmiles(smiles)
                yield {'X': mol}

    # write a megaset to disk using the shard generator to featurize samples
    megaset = Megaset('./p.dat', shard_generator=pickle_shard_gen)
    # load it from disk
    megaset = Megaset('./p.dat')
    # get a batch generator
    bg = megaset.batch_generator()
    for batch in bg():
        # prints a list of mol objects for the batch
        print(batch['X'])
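    # A small usage sketch added for illustration, assuming the './p.dat'
    # file written above: __len__ reports the row count recorded in the
    # parquet metadata.
    print(len(megaset))

    # A minimal sketch of a mixed shard generator, per the class docstring:
    # plain strings become string columns, anything else is pickled into a
    # '__pkl__'-prefixed binary column. 'mixed_shard_gen' and './mixed.dat'
    # are illustrative names, not part of the original example.
    def mixed_shard_gen():
        with open('./smaller.csv', 'r') as f:
            for line in f:
                smiles = line.strip("\r\n ").split(',')[0]
                yield {'X': smiles, 'y': len(smiles)}

    mixed = Megaset('./mixed.dat', shard_generator=mixed_shard_gen)
    mixed_bg = mixed.batch_generator()
    for batch in mixed_bg():
        # 'X' comes back as a list of strings, 'y' as a list of unpickled ints
        print(batch['X'][:3], batch['y'][:3])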
    from ligand_ml.feat.graph_features import ConvMolFeaturizer
    featurizer = ConvMolFeaturizer()

    def numpy_shard_gen():
        # same as above, but this time we're saving numpy arrays
        csv_file = './smaller.csv'
        with open(csv_file, 'r') as f:
            for line in f:
                smiles = line.strip("\r\n ").split(',')[0]
                mol = Chem.MolFromSmiles(smiles)
                conv_mol = featurizer._featurize(mol)
                feature_matrix = conv_mol.get_atom_features()
                yield {'X': feature_matrix}

    megaset = Megaset('./p.dat', shard_generator=numpy_shard_gen)
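    # A minimal read-back sketch, assuming the numpy_shard_gen run above and
    # that each feature_matrix is a 2-D numpy array: '__pkl__' columns are
    # unpickled transparently, so batch['X'] is a list of arrays again.
    megaset = Megaset('./p.dat')
    bg = megaset.batch_generator()
    for batch in bg():
        print([x.shape for x in batch['X']])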