Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class data_prefetcher:
    """Overlap the host-to-device copy of the next batch with work on the current one.

    Wraps an iterable ``loader`` whose items unpack into exactly two tensors
    (e.g. ``(inputs, targets)``). While the caller works on batch *k*, the copy
    of batch *k+1* to the GPU is issued on a separate CUDA stream. Once the
    loader is exhausted, :meth:`next` returns ``(None, None)``.

    If CUDA is not available, no side stream is created and batches are
    returned unchanged (no device transfer), so the same loop also runs on CPU.

    NOTE(review): ``non_blocking=True`` only overlaps the copy when the source
    tensors live in pinned host memory (e.g. ``DataLoader(pin_memory=True)``);
    confirm against the caller's loader configuration.
    """

    def __init__(self, loader):
        """Start iterating ``loader`` and kick off the transfer of its first batch.

        Args:
            loader: any iterable yielding 2-item batches of tensors.
        """
        self.loader = iter(loader)
        # Side stream for the async copies; None means CPU-only passthrough.
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.preload()

    def preload(self):
        """Fetch the next batch from the loader and start copying it to the GPU.

        Sets ``self.next_data_1`` / ``self.next_data_2``; both are ``None`` once
        the loader is exhausted.
        """
        try:
            self.next_data_1, self.next_data_2 = next(self.loader)
        except StopIteration:
            self.next_data_1 = None
            self.next_data_2 = None
            return
        if self.stream is None:
            return
        # Issue the copies on the side stream so they can run concurrently
        # with whatever the current stream is doing.
        with torch.cuda.stream(self.stream):
            self.next_data_1 = self.next_data_1.cuda(non_blocking=True)
            self.next_data_2 = self.next_data_2.cuda(non_blocking=True)

    def next(self):
        """Return the prefetched batch and start prefetching the one after it.

        Returns:
            ``(data_1, data_2)`` for the next batch, or ``(None, None)`` once
            the loader is exhausted.
        """
        data_1, data_2 = self.next_data_1, self.next_data_2
        if self.stream is not None:
            current = torch.cuda.current_stream()
            # Work queued on the current stream must not start before the
            # side-stream copies have finished.
            current.wait_stream(self.stream)
            # These tensors were allocated on self.stream but will be used on
            # `current`. Tell the caching allocator, otherwise it may reuse
            # their memory for a later side-stream allocation while `current`
            # is still reading them.
            for tensor in (data_1, data_2):
                if isinstance(tensor, torch.Tensor):
                    tensor.record_stream(current)
        self.preload()
        return data_1, data_2
# Consume batches from the prefetcher until it reports exhaustion by handing
# back (None, None); `i` ends up equal to the number of batches processed.
prefetcher = data_prefetcher(data_loader)
i = 0
while True:
    inputs, targets = prefetcher.next()
    if inputs is None:
        break
    # training process
    i = i + 1
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement