Advertisement
Sam____

Dataloader

Oct 17th, 2022 (edited)
46
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.39 KB | None | 0 0
  1. from pathlib import Path
  2. from typing import Callable, Dict, List, Optional, Tuple
  3.  
  4.  
  5. import nibabel as nib
  6. import numpy as np
  7. import pandas as pd
  8. import torch
  9. from pandas import DataFrame
  10. from torch import Tensor
  11. from torch.utils.data import Dataset
  12. from tqdm import tqdm
  13.  
# All .nii volumes found under this file's directory (recursively), in sorted
# order so indices are stable across runs.
# NOTE(review): this runs file-system I/O at import time — importing this
# module scans the tree and reads the CSV below.
IMGS: List[Path] = sorted(Path(__file__).resolve().parent.rglob("*.nii"))
# Annotations CSV with (at least) FILE_ID and DX_GROUP columns — the file
# name, subject id, and diagnosis for each scan.
ANNOTATIONS: DataFrame = pd.read_csv(Path(__file__).resolve().parent / "NEW_FMRI.csv")
FmriSlice = Tuple[int, int, int, int]  # just a convenience type to save space
  19.  
  20.  
  21. class RandomFmriDataset(Dataset):
  22.     """Just grabs a random patch of size `patch_shape` from a random brain.
  23.    Parameters
  24.    ----------
  25.    patch_shape: Tuple[int, int, int, int]
  26.        The patch size.
  27.    standardize: bool = True
  28.        Whether or not to do intensity normalization before returning the Tensor.
  29.    transform: Optional[Callable] = None
  30.        The transform to apply to the 4D array.
  31.   """
  32.     def __init__(
  33.         self,
  34.         patch_shape: Optional[FmriSlice] = None,
  35.         standardize: bool = True,
  36.         transform: Optional[Callable] = None,
  37.     ) -> None:
  38.         self.annotations = ANNOTATIONS
  39.         self.img_paths = IMGS
  40.         self.labels: List[int] = []
  41.         self.shapes: List[Tuple[int, int, int, int]] = []
  42.         for img in IMGS:  # get the diagnosis, 0 = control, 1 = autism and other info
  43.             file_id = img.stem.replace("_func_minimal", "")
  44.             label_idx = self.annotations["FILE_ID"] == file_id
  45.             self.labels.append(self.annotations.loc[label_idx, "DX_GROUP"])
  46.             self.shapes.append(nib.load(str(img)).shape)  # usually (61, 73, 61, 236)
  47.         self.max_dims = np.max(self.shapes, axis=0)
  48.         self.min_dims = np.min(self.shapes, axis=0)
  49.  
  50.         self.standardize = standardize
  51.         self.transform = transform
  52.  
  53.         # ensure patch shape is valid
  54.         if patch_shape is None:
  55.             smallest_dims = np.min(self.shapes, axis=0)[:-1]  # exclude time dim
  56.             self.patch_shape = (*smallest_dims, 8)
  57.         else:
  58.             if len(patch_shape) != 4:
  59.                 raise ValueError("Patches must be 4D for fMRI")
  60.             for dim, max_dim in zip(patch_shape, self.max_dims):
  61.                 if dim > max_dim:
  62.                     raise ValueError("Patch size too large for data")
  63.             self.patch_shape = patch_shape
  64.  
  65.     def __len__(self) -> int:
  66.         # when generating the random dataloader, the "length" is kind of phoney. You could make the
  67.         # length be anything, e.g. 1000, 4962, or whatever. However, what you set as the length will
  68.         # define the epoch size.
  69.         return len(self.annotations)
  70.  
  71.     def __getitem__(self, index: int) -> Tensor:
  72.         # just return a random patch
  73.         global array_1
  74.         path = np.random.choice(self.img_paths)
  75.         img = nib.load(str(path))
  76.         # going larger than max_idx will put us past the end of the array
  77.         max_idx = np.array(img.shape) - np.array(self.patch_shape) + 1
  78.  
  79.         # Python has a `slice` object which you can use to index into things with the `[]` operator
  80.         # we are going to build the slices we need to index appropriately into our niis with the
  81.         # `.dataobj` trick
  82.         slices = []
  83.         for length, maximum in zip(self.patch_shape, max_idx):
  84.             start = np.random.randint(0, maximum)
  85.             slices.append(slice(start, start + length))
  86.         array = img.dataobj[slices[0], slices[1], slices[2], slices[3]]
  87.  
  88.         if self.standardize:
  89.             array_1 = np.copy(array)
  90.             array_1 -= np.mean(array_1)
  91.             array_1 /= np.std(array_1, ddof=1)
  92.         return torch.Tensor(array_1)
  93.  
  94.     def test_get_item(self) -> None:
  95.         """Just test that the produced slices can't ever go past the end of a brain"""
  96.         for path in self.img_paths:
  97.             img = nib.load(str(path))
  98.             max_idx = np.array(img.shape) - np.array(self.patch_shape) + 1
  99.             max_dims = img.shape
  100.             for length, maximum, max_dim in zip(self.patch_shape, max_idx, max_dims):
  101.                 for start in range(maximum):
  102.                     # array[a:maximum] is to the end
  103.                     assert start + length <= max_dim
  104.                     if start == maximum - 1:
  105.                         assert start + length == max_dim
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement