Advertisement
Guest User

Untitled

a guest
Sep 19th, 2019
81
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.94 KB | None | 0 0
  1. class data_prefetcher():
  2. def __init__(self, loader):
  3. self.loader = iter(loader)
  4. self.stream = torch.cuda.Stream()
  5. self.preload()
  6.  
  7. def preload(self):
  8. try:
  9. self.next_data_1, self.next_data_2 = next(self.loader)
  10. except StopIteration:
  11. self.next_data_1 = None
  12. self.next_data_2 = None
  13. return
  14. with torch.cuda.stream(self.stream):
  15. self.next_data_1 = self.next_data_1.cuda(non_blocking=True)
  16. self.next_data_2 = self.next_data_2.cuda(non_blocking=True)
  17.  
  18. def next(self):
  19. torch.cuda.current_stream().wait_stream(self.stream)
  20. data_1, data_2 = self.next_data_1, self.next_data_2
  21. self.preload()
  22. return data_1, data_2
  23.  
  24. prefetcher = data_prefetcher(data_loader)
  25. inputs, targets = prefetcher.next()
  26. i = 0
  27. while inputs is not None:
  28. #training process
  29. inputs, targets = prefetcher.next()
  30. i +=1
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement