Advertisement
Sam____

Dataloader_V2

Dec 21st, 2022
18
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.71 KB | None | 0 0
  1. from pathlib import Path
  2. from typing import Callable, List, Optional, Tuple
  3.  
  4.  
  5. import nibabel as nib
  6. import numpy as np
  7. import pandas as pd
  8. import torch
  9. from pandas import DataFrame
  10. from torch import Tensor
  11. from torch.utils.data import Dataset
  12. import torchvision.transforms as transforms
  13.  
  14. # nii paths
  15. IMGS: List[Path] = sorted(Path(__file__).resolve().parent.rglob("*.nii"))
  16. # Path to a custom csv file with the file name, subject id, and diagnosis
  17. ANNOTATIONS: DataFrame = pd.read_csv(Path(__file__).resolve().parent / "ABIDE_FMRI_Balanced.csv")
  18. FmriSlice = Tuple[int, int, int, int]  # just a convencience type to save space
  19.  
  20.  
  21. class RandomFmriDataset(Dataset):
  22.     def __init__(
  23.         self,
  24.         patch_shape: Optional[FmriSlice] = None,
  25.         standardize: bool = True,
  26.         transform: Optional[Callable] = None,
  27.     ) -> None:
  28.         self.annotations = ANNOTATIONS
  29.         self.img_paths = IMGS
  30.         self.labels: List[int] = []
  31.         self.shapes: List[Tuple[int, int, int, int]] = []
  32.         for img in IMGS:  # get the diagnosis, 0 = control, 1 = autism and other info
  33.             file_id = img.stem.replace("_func_minimal", "")
  34.             label_idx = self.annotations["FILE_ID"] == file_id
  35.             self.labels.append(self.annotations.loc[label_idx, "DX_GROUP"]) # 1 = Autism, 0 = Control
  36.             self.shapes.append(nib.load(str(img)).shape)  # usually (61, 73, 61, 236)
  37.         self.max_dims = np.max(self.shapes, axis=0)
  38.         self.min_dims = np.min(self.shapes, axis=0)
  39.  
  40.         self.standardize = standardize
  41.         self.transform = transforms.Compose([transforms.ToTensor()])
  42.  
  43.         # ensure patch shape is valid
  44.         if patch_shape is None:
  45.             smallest_dims = np.min(self.shapes, axis=0)[:-1]  # exclude time dim
  46.             self.patch_shape = (*smallest_dims, 8)
  47.         else:
  48.             if len(patch_shape) != 4:
  49.                 raise ValueError("Patches must be 4D for fMRI")
  50.             for dim, max_dim in zip(patch_shape, self.max_dims):
  51.                 if dim > max_dim:
  52.                     raise ValueError("Patch size too large for data")
  53.             self.patch_shape = patch_shape
  54.  
  55.     def __len__(self) -> int:
  56.         # when generating the random dataloader, the "length" is kind of phoney. You could make the
  57.         # length be anything, e.g. 1000, 4962, or whatever. However, what you set as the length will
  58.         # define the epoch size.
  59.         return len(self.annotations)
  60.  
  61.     def __getitem__(self, index: int) -> Tensor:
  62.         # just return a random patch
  63.         global array_1
  64.         path = self.img_paths[index]
  65.         img = nib.load(str(path))
  66.         # going larger than max_idx will put us past the end of the array
  67.         max_idx = np.array(img.shape) - np.array(self.patch_shape) + 1
  68.        
  69.         single_label = self.labels[index]
  70.         np_single_label = np.asarray(single_label)
  71.         #label_tensor = torch.Tensor(single_label)
  72.  
  73.         # Python has a `slice` object which you can use to index into things with the `[]` operator
  74.         # we are going to build the slices we need to index appropriately into our niis with the
  75.         # `.dataobj` trick
  76.         slices = []
  77.         for length, maximum in zip(self.patch_shape, max_idx):
  78.             start = np.array(0, maximum)
  79.             slices.append(slice(start, start + length))
  80.         array = img.dataobj[slices[0], slices[1], slices[2], slices[3]]
  81.  
  82.         if self.standardize:
  83.             array_1 = np.copy(array)
  84.             array_1 -= np.mean(array_1)
  85.             array_1 /= np.std(array_1, ddof=1)
  86.        
  87.         array_1 = array_1.transpose((3,2,1,0))
  88.         img_tensor = torch.Tensor(array_1)
  89.        
  90.         return (img_tensor, np_single_label)
  91.  
  92.  
  93.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement