Advertisement
Abhisek92

Tiled_Least_Square.py

Feb 16th, 2024 (edited)
629
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.95 KB | None | 0 0
  1. import torch
  2. import numpy as np
  3. import rasterio as rio
  4. from pathlib import Path
  5. from einops import rearrange
  6. from itertools import product
  7. from rich.progress import track
  8. from rasterio.vrt import WarpedVRT
  9. from rasterio.windows import Window
  10. from rasterio.enums import Resampling
  11.  
  12. src_dir = Path("")
  13. dst_img = Path("")
  14.  
  15. dtype = np.float32
  16. device=torch.device("cpu")
  17. nodata = np.nan
  18. tile_height = 64
  19. tile_width = 64
  20.  
  21. img_list = sorted(src_dir.glob("*.tif"))
  22.  
  23. with rio.open(img_list[0], 'r') as meta_src:
  24.     meta = meta_src.meta.copy()
  25.  
  26. meta["driver"] = "GTiff"
  27. meta["dtype"] = dtype
  28. meta["predictor"] = 3
  29. meta["tiled"] = True
  30. meta["blockxsize"] = 64
  31. meta["blockysize"] = 64
  32. meta["compress"] = "LZW"
  33. meta["num_threads"] = 4
  34. meta["sparse_ok"] = True
  35. meta["nodata"] = nodata
  36.  
  37. img_window = Window(
  38.     row_off=0,
  39.     col_off=0,
  40.     height=meta["height"],
  41.     width=meta["width"]
  42. )
  43. row_markers = range(
  44.     start=0,
  45.     stop=meta["height"],
  46.     step=tile_height
  47. )
  48. col_markers = range(
  49.     start=0,
  50.     stop=meta["width"],
  51.     step=tile_width
  52. )
  53.  
  54. with rio.open(dst_img, 'w', **meta) as dst:
  55.     for rs, cs in track(
  56.         sequence=tuple(
  57.             product(
  58.                 row_markers,
  59.                 col_markers
  60.             ),
  61.             description='♻️'
  62.         )
  63.     ):
  64.         current_window = img_window.intersection(
  65.             Window(
  66.                 row_off=rs,
  67.                 col_off=cs,
  68.                 height=tile_height,
  69.                 width=tile_width
  70.             )
  71.         )
  72.         stack = list()
  73.         for img in img_list:
  74.             with rio.open(img, 'r') as src:
  75.                 with WarpedVRT(
  76.                     src_dataset=src,
  77.                     height=meta["height"],
  78.                     width=meta["width"],
  79.                     resampling=Resampling.nearest
  80.                 ) as vrt:
  81.                     stack.append(
  82.                         vrt.read(
  83.                             window=current_window,
  84.                             masked=True,
  85.                             boundless=False
  86.                         ).astype(dtype)
  87.                     )
  88.         tile_stack = np.ma.concatenate(stack, axis=0)
  89.         t, h, w = tile_stack.shape
  90.         tile_valid = rearrange(
  91.             tensor=np.logical_not(
  92.                 np.any(
  93.                     a=tile_stack.mask,
  94.                     axis=0,
  95.                     keepdims=False
  96.                 )
  97.             ),
  98.             pattern="h w -> (h w)"
  99.         )
  100.         tile_stack = rearrange(
  101.             tensor=tile_stack.filled(fill_value=nodata),
  102.             pattern="t h w -> (h w) t 1"
  103.         )
  104.         tile_stack = tile_stack[tile_valid]
  105.         tile_stack = torch.tensor(
  106.             data=tile_stack,
  107.             requires_grad=False,
  108.             device=device
  109.         )
  110.         n = tile_stack.shape[0]
  111.         dt = repeat(
  112.             tensor=torch.arange(
  113.                 start=0,
  114.                 end=t,
  115.                 step=1,
  116.                 dtype=tile_stack.dtype,
  117.                 device=device,
  118.                 requires_grad=False
  119.             ),
  120.             pattern="t -> n t 1",
  121.             n=n
  122.         )
  123.         dt = torch.cat(
  124.             tensors=[
  125.                 dt,
  126.                 torch.ones_like(dt)
  127.             ],
  128.             dim=-1
  129.         )
  130.         grad_data=(
  131.             torch.atan(
  132.                 torch.linalg.lstsq(
  133.                     A=dt,
  134.                     B=db,
  135.                     rcond=None,
  136.                     driver=None
  137.                 ).solution
  138.             ) / (torch.pi / 2)
  139.         ).detach().cpu().numpy()
  140.         grad_tile = np.full(
  141.             shape=((h * w), 2),
  142.             fill_value=nodata,
  143.             dtype=dtype
  144.         )
  145.         grad_tile[tile_valid] = grad_data
  146.         grad_tile = rearrange(
  147.             tensor=grad_tile,
  148.             pattern="(h w) 2 -> h w 2",
  149.             h=h
  150.             w=w
  151.         )
  152.         dst.write(grad_tile, window=current_window)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement