Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import rasterio as rio
- from pathlib import Path
- from einops import rearrange
- from rich.progress import track
# --- configuration (fill in before running) --------------------------------
src_dir: Path = Path("")  # directory scanned for the input "*.tif" files
dst_img: Path = Path("")  # path of the GeoTIFF that gets written
dtype: torch.dtype = torch.float32  # precision of the stacked data and the fit
device: str = "cpu"  # torch device the computation runs on
nodata: float = torch.nan  # nodata value recorded in the output metadata
# --- load the time series --------------------------------------------------
img_list = sorted(src_dir.glob("*.tif"))
if not img_list:
    # torch.cat([]) would otherwise fail with an opaque RuntimeError.
    raise FileNotFoundError(f"No '*.tif' files found in '{src_dir}'")


def _read_raster(path):
    """Return every band of the raster at *path* as a tensor.

    The dataset is opened in a context manager and closed before returning,
    so reading a long series does not accumulate open file handles.
    """
    with rio.open(path, 'r') as src:
        return torch.tensor(data=src.read(), dtype=dtype, device=device)


# Concatenate along the band axis: each band of each file becomes one time
# step, in file-name order. torch.cat requires all files to share the same
# height and width.
db = torch.cat(
    tensors=[
        _read_raster(img)
        for img in track(sequence=img_list, description='♻️')
    ],
    dim=0,
)
# --- per-pixel linear trend ------------------------------------------------
# `db` is the (t, h, w) stack built above; axis 0 is treated as time.
t, h, w = db.shape
if t < 2:
    raise ValueError(f"A linear trend needs at least 2 time steps, got {t}")

# Design matrix A = [time_index, 1] with shape (t, 2). It is the same for
# every pixel, so a single least-squares solve with one right-hand-side
# column per pixel replaces h*w batched solves and avoids materialising an
# (h*w, t, 2) copy of A.
time_index = torch.arange(
    start=0,
    end=t,
    step=1,
    dtype=dtype,
    device=device,
)
design = torch.stack(tensors=[time_index, torch.ones_like(time_index)], dim=-1)

# B holds one column per pixel: shape (t, h*w).
targets = rearrange(tensor=db, pattern="t h w -> t (h w)")

# solution has shape (2, h*w): row 0 is the slope m, row 1 the intercept c.
coeffs = torch.linalg.lstsq(
    A=design,
    B=targets,
    rcond=None,
    driver=None,
).solution

# atan(x) / (pi / 2) squashes both coefficients into the open interval
# (-1, 1). h and w must be given explicitly: einops cannot split the
# flattened pixel axis back into two axes on its own.
grad_img = rearrange(
    tensor=torch.atan(coeffs) / (torch.pi / 2),
    pattern="c (h w) -> c h w",
    h=h,
    w=w,
).detach().cpu().numpy()
# --- write the result ------------------------------------------------------
# Start from the first input's metadata so the output inherits its
# georeferencing (crs, transform); the dataset is closed right after copying.
with rio.open(img_list[0], 'r') as src:
    meta = src.meta.copy()

meta["count"], meta["height"], meta["width"] = grad_img.shape
meta["driver"] = "GTiff"
meta["dtype"] = grad_img.dtype.name  # dtype by name, e.g. "float32"
# The keys below are forwarded by rasterio to GDAL as GTiff creation options.
meta["predictor"] = 3  # floating-point predictor; matches the float output
meta["tiled"] = True
meta["blockxsize"] = 64
meta["blockysize"] = 64
meta["compress"] = "LZW"
meta["num_threads"] = 4  # threads GDAL may use for compression
meta["sparse_ok"] = True  # allow GDAL to omit empty blocks on disk
meta["nodata"] = nodata

# 1st channel `m` (squashed slope) and 2nd channel `c` (squashed intercept)
with rio.open(dst_img, 'w', **meta) as dst:
    dst.write(grad_img)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement