Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def computeOpticalFlows(self, targetRgb, sourceRgbs):
    """Run ``self.opticalFlowNet`` between ``targetRgb`` and every source frame.

    Args:
        targetRgb: Passed unchanged as the first argument of
            ``self.opticalFlowNet`` for every frame.
        sourceRgbs: 5-D tensor laid out as ``(MB, seqSize, C, H, W)``; frame
            ``i`` is ``sourceRgbs[:, i]``.

    Returns:
        A list with one tensor per scale returned by ``self.opticalFlowNet``.
        For scale ``k`` the tensor has shape ``(MB, seqSize, *flow_k.shape[1:])``,
        where ``flow_k`` is the network's per-frame output at that scale
        (presumably ``(MB, 2, h, w)`` flow maps -- confirm against the network).
        ``result[k][:, i]`` holds the output for source frame ``i``.

    Raises:
        ValueError: If ``sourceRgbs`` is not 5-D or contains no frames.
    """
    if sourceRgbs.dim() != 5:
        raise ValueError(
            "expected 5-D sourceRgbs (MB, seqSize, C, H, W), "
            f"got shape {tuple(sourceRgbs.shape)}"
        )
    MB, seqSize = sourceRgbs.shape[:2]
    if seqSize == 0:
        raise ValueError("sourceRgbs must contain at least one frame")

    results = None
    for i in range(seqSize):
        flows = self.opticalFlowNet(targetRgb, sourceRgbs[:, i])
        if results is None:
            # Allocate output buffers from the first frame's outputs so they
            # share the network's device and dtype (no global `device`
            # needed) and take the per-scale channel/spatial sizes from it.
            results = [flow.new_zeros((MB, seqSize, *flow.shape[1:])) for flow in flows]
        # Copy into preallocated buffers rather than stacking all frames at
        # the end, so peak memory stays at one frame's outputs plus buffers.
        for scaleResult, flow in zip(results, flows):
            scaleResult[:, i] = flow
    return results
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement