Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
from itertools import product
from pathlib import Path

import numpy as np
import rasterio as rio
import torch
from einops import rearrange, repeat
from rasterio.enums import Resampling
from rasterio.vrt import WarpedVRT
from rasterio.windows import Window
from rich.progress import track
- src_dir = Path("")
- dst_img = Path("")
- dtype = np.float32
- device=torch.device("cpu")
- nodata = np.nan
- tile_height = 64
- tile_width = 64
- img_list = sorted(src_dir.glob("*.tif"))
- with rio.open(img_list[0], 'r') as meta_src:
- meta = meta_src.meta.copy()
- meta["driver"] = "GTiff"
- meta["dtype"] = dtype
- meta["predictor"] = 3
- meta["tiled"] = True
- meta["blockxsize"] = 64
- meta["blockysize"] = 64
- meta["compress"] = "LZW"
- meta["num_threads"] = 4
- meta["sparse_ok"] = True
- meta["nodata"] = nodata
- img_window = Window(
- row_off=0,
- col_off=0,
- height=meta["height"],
- width=meta["width"]
- )
- row_markers = range(
- start=0,
- stop=meta["height"],
- step=tile_height
- )
- col_markers = range(
- start=0,
- stop=meta["width"],
- step=tile_width
- )
- with rio.open(dst_img, 'w', **meta) as dst:
- for rs, cs in track(
- sequence=tuple(
- product(
- row_markers,
- col_markers
- ),
- description='♻️'
- )
- ):
- current_window = img_window.intersection(
- Window(
- row_off=rs,
- col_off=cs,
- height=tile_height,
- width=tile_width
- )
- )
- stack = list()
- for img in img_list:
- with rio.open(img, 'r') as src:
- with WarpedVRT(
- src_dataset=src,
- height=meta["height"],
- width=meta["width"],
- resampling=Resampling.nearest
- ) as vrt:
- stack.append(
- vrt.read(
- window=current_window,
- masked=True,
- boundless=False
- ).astype(dtype)
- )
- tile_stack = np.ma.concatenate(stack, axis=0)
- t, h, w = tile_stack.shape
- tile_valid = rearrange(
- tensor=np.logical_not(
- np.any(
- a=tile_stack.mask,
- axis=0,
- keepdims=False
- )
- ),
- pattern="h w -> (h w)"
- )
- tile_stack = rearrange(
- tensor=tile_stack.filled(fill_value=nodata),
- pattern="t h w -> (h w) t 1"
- )
- tile_stack = tile_stack[tile_valid]
- tile_stack = torch.tensor(
- data=tile_stack,
- requires_grad=False,
- device=device
- )
- n = tile_stack.shape[0]
- dt = repeat(
- tensor=torch.arange(
- start=0,
- end=t,
- step=1,
- dtype=tile_stack.dtype,
- device=device,
- requires_grad=False
- ),
- pattern="t -> n t 1",
- n=n
- )
- dt = torch.cat(
- tensors=[
- dt,
- torch.ones_like(dt)
- ],
- dim=-1
- )
- grad_data=(
- torch.atan(
- torch.linalg.lstsq(
- A=dt,
- B=db,
- rcond=None,
- driver=None
- ).solution
- ) / (torch.pi / 2)
- ).detach().cpu().numpy()
- grad_tile = np.full(
- shape=((h * w), 2),
- fill_value=nodata,
- dtype=dtype
- )
- grad_tile[tile_valid] = grad_data
- grad_tile = rearrange(
- tensor=grad_tile,
- pattern="(h w) 2 -> h w 2",
- h=h
- w=w
- )
- dst.write(grad_tile, window=current_window)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement