Advertisement
Guest User

Untitled

a guest
Mar 18th, 2019
83
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.85 KB | None | 0 0
  1. dataset = CustomDataset(items)
  2. data_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_PREPROCESS_WORKERS, drop_last=False, collate_fn=filtered_collate_fn)
  3.  
  4. if len(data_loader) == 0:
  5.   print('Nothing to segment in {}', items)
  6.   return
  7.  
  8.  
  9. curr_items = BATCH_SIZE * len(data_loader)
  10. print('---> processing dataloader with {} batches of BATCH SIZE {} each at {}'.format(len(data_loader), BATCH_SIZE, datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S')))
  11.  
  12. st = time.time()
  13. cids, features, meta = [], [], []
  14. # note each of *_batch is a tuple
  15. for batch_idx, (id_batch, img_batch, tag_batch) in enumerate(data_loader):
  16.   try:
  17.     mask_batch = self.model.forward_pass(img_batch)
  18.     ids = [*ids, *list(id_batch)]
  19.     np_masks = [*np_masks, *list(mask_batch)]
  20.     meta = [*meta, *list(meta_batch)]
  21.  
  22. print('--> finished batch processing dataloader with {} batches of BATCH SIZE {} each at {}, Overall time taken : {:.4f} seconds'.format(len(data_loader), BATCH_SIZE, datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S'), time.time() - st))
  23.  
  24.  
  25. checkpoint1 = time.time()
  26.  
  27. filename_pattern = 'chunk-{}.json'.format(get_random_string_with_timestamp())
  28.  
  29.  
  30.  
  31. items = []
  32. for cid, feature, tag in zip(cids, np_masks, meta):
  33.   json_obj = {
  34.                 "cid" : str(cid),
  35.                 "feature" : base64.b64encode(gzip.compress(feature)).decode(), # compress a high dimensional vector
  36.                 "tag" : str(tag)
  37.              }
  38.            
  39. items.append(json.dumps(json_obj))
  40. file_data = '\n'.join(items)
  41.  
  42.  
  43. filename = os.path.join(s3_bucket_name, filename_pattern)
  44. S3_utils.write_data_to_bucket(boto3_s3_client, filename, file_data)
  45.  
  46. print('===> saving inference data to {}, Count {}, overall time taken {:.4f}'.format(filename, len(cids), time.time() - checkpoint))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement