Advertisement
Abhisek92

Least_Square.py

Feb 13th, 2024 (edited)
667
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.73 KB | None | 0 0
  1. import torch
  2. import rasterio as rio
  3. from pathlib import Path
  4. from einops import rearrange
  5. from rich.progress import track
  6.  
  7. src_dir = Path("")
  8. dst_img = Path("")
  9.  
  10. dtype = torch.float32
  11. device="cpu"
  12. nodata = torch.nan
  13.  
  14. img_list = sorted(src_dir.glob("*.tif"))
  15.  
  16. db = torch.cat(
  17.     tensors=[
  18.         torch.tensor(
  19.             data=rio.open(img, 'r').read(),
  20.             dtype=dtype,
  21.             device=device
  22.         )
  23.         for img in track(
  24.             sequence=img_list,
  25.             description='♻️'
  26.         )
  27.     ],
  28.     dim=0
  29. )
  30.  
  31. t, h, w = db.shape
  32.  
  33. db = rearrange(
  34.     tensor=db,
  35.     pattern="t h w -> (h w) t 1"
  36. )
  37.  
  38. dt = repeat(
  39.     tensor=torch.arange(
  40.         start=0,
  41.         end=t,
  42.         step=1,
  43.         dtype=dtype,
  44.         device=device,
  45.         requires_grad=False
  46.     ),
  47.     pattern="t -> n t 1",
  48.     n=(h * w)
  49. )
  50.  
  51. dt = torch.cat(
  52.     tensors=[
  53.         dt,
  54.         torch.ones_like(dt)
  55.     ],
  56.     dim=-1
  57. )
  58.  
  59. grad_img = rearrange(
  60.     tensor=(
  61.         torch.atan(
  62.             torch.linalg.lstsq(
  63.                 A=dt,
  64.                 B=db,
  65.                 rcond=None,
  66.                 driver=None
  67.             ).solution
  68.         ) / (torch.pi / 2)
  69.     ),
  70.     pattern="(h w) c 1 -> c h w"
  71. ).detach().cpu().numpy()
  72.  
  73. meta = rio.open(img_list[0], 'r').meta.copy()
  74. meta["count"], meta["height"], meta["width"] = grad_img.shape
  75. meta["driver"] = "GTiff"
  76. meta["dtype"] = grad_img.dtype
  77. meta["predictor"] = 3
  78. meta["tiled"] = True
  79. meta["blockxsize"] = 64
  80. meta["blockysize"] = 64
  81. meta["compress"] = "LZW"
  82. meta["num_threads"] = 4
  83. meta["sparse_ok"] = True
  84. meta["nodata"] = nodata
  85.  
  86. # 1st channel `m` and 2nd channel `c`
  87. with rio.open(dst_img, 'w', **meta) as dst:
  88.     dst.write(grad_img)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement