Abhisek92

SplitTiles.py

Apr 10th, 2021 (edited)
871
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import os
  2. import gdal
  3. import numpy as np
  4. import rasterio as rio
  5. from pathlib import Path
  6. from random import shuffle
  7. from math import floor, ceil
  8. from itertools import product
  9. from rasterio import windows as rio_windows
  10.  
  11. def generate_windows(
  12.     img_height,
  13.     img_width,
  14.     win_height,
  15.     win_width,
  16.     min_hoverlap,
  17.     min_woverlap,
  18.     boundless=False
  19. ):
  20.     hc = ceil((img_height - min_hoverlap) / (win_height - min_hoverlap))
  21.     wc = ceil((img_width - min_woverlap) / (win_width - min_woverlap))
  22.    
  23.    
  24.     h_overlap = ((hc * win_height) - img_height) // (hc - 1)
  25.     w_overlap = ((wc * win_height) - img_width) // (wc - 1)
  26.    
  27.    
  28.     hslack_res = ((hc * win_height) - img_height) % (hc - 1)
  29.     wslack_res = ((wc * win_width) - img_width) % (wc - 1)
  30.    
  31.     dh = win_height - h_overlap
  32.     dw = win_width - w_overlap
  33.    
  34.     row_offsets = np.arange(0, (img_height-h_overlap), dh)
  35.     col_offsets = np.arange(0, (img_width-w_overlap), dw)
  36.    
  37.     if hslack_res > 0:
  38.         row_offsets[-hslack_res:] -= np.arange(1, (hslack_res + 1), 1)
  39.     if wslack_res > 0:
  40.         col_offsets[-wslack_res:] -= np.arange(1, (wslack_res + 1), 1)
  41.    
  42.     row_offsets = row_offsets.tolist()
  43.     col_offsets = col_offsets.tolist()
  44.    
  45.     offsets = product(col_offsets, row_offsets)
  46.    
  47.     indices = product(range(len(col_offsets)), range(len(row_offsets)))
  48.    
  49.     big_window = rio_windows.Window(col_off=0, row_off=0, width=img_width, height=img_height)
  50.    
  51.     for index, (col_off, row_off) in zip(indices, offsets):
  52.         window = rio_windows.Window(
  53.             col_off=col_off,
  54.             row_off=row_off,
  55.             width=win_width,
  56.             height=win_height
  57.         )
  58.         if boundless:
  59.             yield index, window
  60.         else:
  61.             yield index, window.intersection(big_window)
  62.  
  63. def generate_image_tiles(
  64.     img_path,
  65.     win_height,
  66.     win_width,
  67.     min_hoverlap,
  68.     min_woverlap,
  69.     dst_dir,
  70.     boundless=False,
  71.     dst_driver='GTiff',
  72.     dst_ext='tif',
  73.     dst_base=None,
  74. ):
  75.     tile_list = list()
  76.     with rio.open(img_path, 'r') as src:
  77.         meta = src.meta.copy()
  78.         img_height = src.height
  79.         img_width = src.width
  80.         for idx, w in generate_windows(
  81.             img_height=img_height,
  82.             img_width=img_width,
  83.             win_height=win_height,
  84.             win_width=win_width,
  85.             min_hoverlap=min_hoverlap,
  86.             min_woverlap=min_woverlap,
  87.             boundless=False
  88.         ):
  89.             w_arr = src.read(window=w, masked=True)
  90.             if not(np.all(w_arr.mask)):
  91.                 w_transform = rio_windows.transform(
  92.                     window=w,
  93.                     transform=src.transform
  94.                 )
  95.                 meta['count'], meta['height'], meta['width'] = w_arr.shape
  96.                 meta['transform'] = w_transform
  97.                 meta['driver'] = dst_driver
  98.                 if dst_base is None:
  99.                     dst_base = img_path.stem
  100.                 d_dir = Path(dst_dir / 'Tiles')
  101.                 d_dir.mkdir(parents=True, exist_ok=True)
  102.                 dst_path = d_dir / '{}_{}_{}.{}'.format(
  103.                     dst_base, idx[0], idx[1], dst_ext
  104.                 )
  105.                 with rio.open(dst_path, 'w', **meta) as dst:
  106.                     dst.write(w_arr)
  107.                 tile_list.append(dst_path.relative_to(Path(dst_dir)))
  108.     return tile_list
  109.  
  110.  
  111. def split_tiles(tile_list, ratio=(8, 1, 1)):
  112.     denominator = sum(ratio)
  113.     sorted_index = sorted(
  114.         range(len(ratio)),
  115.         key=lambda k: ratio[k],
  116.         reverse=True
  117.     )
  118.    
  119.     part_sizes = list()
  120.     remaining = len(tile_list)
  121.     for  i in sorted_index:
  122.         c = ceil((ratio[i] * len(tile_list)) / denominator)
  123.         if c > remaining:
  124.             c = remaining
  125.         remaining -= c
  126.         part_sizes.append(c)
  127.  
  128.     part_sizes = [
  129.         x for _, x in sorted(
  130.             zip(sorted_index, part_sizes), key=lambda pair: pair[0]
  131.         )
  132.     ]
  133.     assert not(0 in part_sizes), 'Unslovable with given ratio!!'
  134.     shuffle(tile_list)
  135.     parts = list()
  136.     start = 0
  137.     for delta in part_sizes:
  138.         end = start + delta
  139.         parts.append(tile_list[start:end])
  140.         start = end
  141.     return parts
  142.  
  143. def build_vrt(tile_list, dst_path, **vrt_options):
  144.     v_ops = gdal.BuildVRTOptions(**vrt_options)
  145.     gdal.BuildVRT(
  146.         destName=str(dst_path),
  147.         srcDSOrSrcDSTab=[str(t) for t in tile_list],
  148.         options=v_ops
  149.     )
  150.    
  151.  
  152. if __name__ == '__main__':
  153.     img_path = Path('Potsdam_2_10_RGBIR.tif')
  154.     win_height = 600
  155.     win_width = 600
  156.     min_hoverlap = 0
  157.     min_woverlap = 0
  158.     dst_dir = Path('Test')
  159.     im_tiles = generate_image_tiles(
  160.         img_path=img_path,
  161.         win_height=win_height,
  162.         win_width=win_width,
  163.         min_hoverlap=min_hoverlap,
  164.         min_woverlap=min_woverlap,
  165.         dst_dir=dst_dir
  166.     )
  167.    
  168.     wd = Path(os.getcwd())
  169.     os.chdir(dst_dir)
  170.    
  171.     s_parts = split_tiles(im_tiles)
  172.     dpaths = (
  173.         Path('Test/Train.vrt'),
  174.         Path('Test/Validation.vrt'),
  175.         Path('Test/Test.vrt')
  176.     )
  177.     for p, dp in zip(s_parts, dpaths):
  178.         build_vrt(
  179.             tile_list=p,
  180.             dst_path=dp.relative_to(dst_dir),
  181.             resampleAlg='near',
  182.             addAlpha=False
  183.         )
  184.     os.chdir(wd)
  185.  
RAW Paste Data