import os
import glob
import pyproj
import shapely
import shapely.geometry
import shapely.ops
import fiona
import rasterio
import rasterio.mask
import rasterio.merge
import numpy
import pickle


def project_wsg_shape_to_csr(shape, from_crs, to_crs):
    project = lambda x, y: pyproj.transform(
        from_crs,
        to_crs,
        x,
        y
    )
    return shapely.ops.transform(project, shape)
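
# project_wsg_shape_to_csr reprojects a shapely geometry from one CRS to another
# using the old-style pyproj.transform API. Example usage (a sketch; the EPSG
# codes and variable names below are illustrative, not taken from the data):
#   wgs84 = pyproj.Proj(init='epsg:4326')
#   utm35s = pyproj.Proj(init='epsg:32735')
#   utm_shape = project_wsg_shape_to_csr(lonlat_shape, wgs84, utm35s)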

train_shapefile = fiona.open("train/train.shp", "r")
train_shape_crs = pyproj.Proj(train_shapefile.crs)

test_shapefile = fiona.open("test/test.shp", "r")
test_shape_crs = pyproj.Proj(test_shapefile.crs)
#print(shapefile.crs)

# Start by enumerating SAFE products
# TODO: check cloud contamination using s2cloudless
product_groups = {}
train_field_data = {}
train_field_data_r = {}
train_field_data_g = {}
train_field_data_b = {}
test_field_data = {}
test_field_data_r = {}
test_field_data_g = {}
test_field_data_b = {}
for product_fn in glob.glob('*.SAFE'):
    #print(product_fn)
    """
    The compact naming convention is arranged as follows:

    MMM_MSIL1C_YYYYMMDDHHMMSS_Nxxyy_ROOO_Txxxxx_<Product Discriminator>.SAFE

    The products contain two dates.

    The first date (YYYYMMDDHHMMSS) is the datatake sensing time.
    The second date is the "<Product Discriminator>" field, which is 15 characters in length, and is used to
    distinguish between different end user products from the same datatake. Depending on the instance, the
    time in this field can be earlier or slightly later than the datatake sensing time.

    The other components of the filename are:

    MMM: the mission ID (S2A/S2B)
    MSIL1C: denotes the Level-1C product level
    YYYYMMDDHHMMSS: the datatake sensing start time
    Nxxyy: the Processing Baseline number (e.g. N0204)
    ROOO: Relative Orbit number (R001 - R143)
    Txxxxx: Tile Number field
    SAFE: Product Format (Standard Archive Format for Europe)
    """
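    # For example, a hypothetical product name following this convention, such as
    #   S2A_MSIL1C_20190715T081601_N0208_R121_T35JPM_20190715T103605.SAFE,
    # splits on '_' into parts where index 2 is the datatake sensing time
    # ('20190715T081601') and index 5 is the tile number ('T35JPM').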
    # Split the product name into parts
    product_attrs = product_fn.split('_')
    datatake_time = product_attrs[2]
    tile_number = product_attrs[5]
    # Since the shape files provided cover two tiles, group tiles by datatake_time
    if datatake_time in product_groups:
        product_groups[datatake_time].append(product_fn)
    else:
        product_groups[datatake_time] = [product_fn]

# sort the dict in chronological order
product_groups = dict(sorted(product_groups.items()))
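# Sorting by datatake time means that, in the per-field loops below, layers are
# concatenated along axis 2 in chronological order, so each field ends up as a
# (height, width, n_datatakes) time series.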

# Enumerate groups of tiles
for product_group in product_groups:
    print('*** Processing {}..'.format(product_group))
    b2 = [] # all B2 bands for a group, blue
    b3 = [] # all B3 bands for a group, green
    b4 = [] # all B4 bands for a group, red
    b8 = [] # all B8 bands for a group, NIR
    for product_fn in product_groups[product_group]:
        print(' {}'.format(product_fn))
        b2fn = ''
        b3fn = ''
        b4fn = ''
        b8fn = ''
        for bandfn in glob.glob('{}/GRANULE/*/IMG_DATA/*.jp2'.format(product_fn)):
            # Split the band file name
            base = os.path.basename(bandfn)
            band_attrs = os.path.splitext(base)[0].split('_')
            band_type = band_attrs[2] # B01, B02, etc
            if band_type == 'B02':
                b2fn = bandfn
            if band_type == 'B03':
                b3fn = bandfn
            if band_type == 'B04':
                b4fn = bandfn
            if band_type == 'B08':
                b8fn = bandfn

        assert b2fn and b3fn and b4fn and b8fn # should have all four band files
        b2.append(rasterio.open(b2fn))
        b3.append(rasterio.open(b3fn))
        b4.append(rasterio.open(b4fn))
        b8.append(rasterio.open(b8fn))

    print(' Merging bands..')
    # For a group of tiles/products, merge bands from different tiles together
    blue, _ = rasterio.merge.merge(b2)
    green, _ = rasterio.merge.merge(b3)
    red, out_trans = rasterio.merge.merge(b4)
    nir, _ = rasterio.merge.merge(b8)

    # Calculate the NDVI from the merged B8 (NIR) and B4 (red) bands
    print(' Calculating the NDVI..')
    ndvi = (nir.astype(float) - red.astype(float)) / (nir + red)
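    # Note: pixels where (nir + red) == 0 (e.g. nodata borders of the merged
    # mosaic) will divide by zero and leave inf/NaN in ndvi. If that becomes a
    # problem, one option (a sketch, not applied here) is:
    #   with numpy.errstate(divide='ignore', invalid='ignore'):
    #       ndvi = numpy.nan_to_num((nir.astype(float) - red.astype(float)) / (nir + red))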
    # Save the NDVI image for manual analysis later
    print(' Saving the NDVI raster to ndvi/{}.tif..'.format(product_group))
    meta = b4[0].meta.copy()
    meta.update(dtype=rasterio.float64,
                compress='lzw',
                driver='GTiff',
                transform=out_trans,
                height=red.shape[1],
                width=red.shape[2]
                )
    with rasterio.open('ndvi/{}.tif'.format(product_group), 'w', **meta) as dst:
        dst.write(ndvi)

    # scale the 16-bit r, g, b values (0..65535) to 0..1
    red = red.astype(float) / 65535
    green = green.astype(float) / 65535
    blue = blue.astype(float) / 65535

    # Save red, green and blue images as well
    print(' Saving the RGB raster to rgb/{}-r/g/b.tif..'.format(product_group))
    with rasterio.open('rgb/{}-r.tif'.format(product_group), 'w', **meta) as dst:
        dst.write(red)
    with rasterio.open('rgb/{}-g.tif'.format(product_group), 'w', **meta) as dst:
        dst.write(green)
    with rasterio.open('rgb/{}-b.tif'.format(product_group), 'w', **meta) as dst:
        dst.write(blue)

    ndvi_img = rasterio.open('ndvi/{}.tif'.format(product_group))
    #print(' NDVI CRS is', ndvi_img.crs.data)
    ndvi_crs = pyproj.Proj(ndvi_img.crs)

    red_img = rasterio.open('rgb/{}-r.tif'.format(product_group))
    red_crs = pyproj.Proj(red_img.crs)
    green_img = rasterio.open('rgb/{}-g.tif'.format(product_group))
    green_crs = pyproj.Proj(green_img.crs)
    blue_img = rasterio.open('rgb/{}-b.tif'.format(product_group))
    blue_crs = pyproj.Proj(blue_img.crs)

    # Alright, NDVI is ready for the whole region in question
    # Use the shape file to mask out everything except fields
    for field in train_shapefile:
        #print(field['properties']['Field_Id'], field['properties']['Crop_Id_Ne'])
        field_id = field['properties']['Field_Id']
        #print(' Cropping NDVI data for train field #{}'.format(field_id))
        try:
            projected_shape = project_wsg_shape_to_csr(shapely.geometry.shape(field['geometry']),
                                                       train_shape_crs,
                                                       ndvi_crs)
        except Exception as e:
            print(' ', e, ' exception for field #', field_id)
            continue

        #print(projected_shape)
        field_img, field_img_transform = rasterio.mask.mask(ndvi_img, [projected_shape], crop=True)
        field_img_red, _ = rasterio.mask.mask(red_img, [projected_shape], crop=True)
        field_img_green, _ = rasterio.mask.mask(green_img, [projected_shape], crop=True)
        field_img_blue, _ = rasterio.mask.mask(blue_img, [projected_shape], crop=True)
        # remove the first dimension
        field_img = numpy.squeeze(field_img, axis=0)
        field_img_red = numpy.squeeze(field_img_red, axis=0)
        field_img_green = numpy.squeeze(field_img_green, axis=0)
        field_img_blue = numpy.squeeze(field_img_blue, axis=0)
        # add the 3rd dimension
        field_img = numpy.expand_dims(field_img, 2)
        field_img_red = numpy.expand_dims(field_img_red, 2)
        field_img_green = numpy.expand_dims(field_img_green, 2)
        field_img_blue = numpy.expand_dims(field_img_blue, 2)

        if field_id in train_field_data:
            train_field_data[field_id] = numpy.concatenate((train_field_data[field_id], field_img), axis=2)
            train_field_data_r[field_id] = numpy.concatenate((train_field_data_r[field_id], field_img_red), axis=2)
            train_field_data_g[field_id] = numpy.concatenate((train_field_data_g[field_id], field_img_green), axis=2)
            train_field_data_b[field_id] = numpy.concatenate((train_field_data_b[field_id], field_img_blue), axis=2)
        else:
            train_field_data[field_id] = field_img
            train_field_data_r[field_id] = field_img_red
            train_field_data_g[field_id] = field_img_green
            train_field_data_b[field_id] = field_img_blue

    for field in test_shapefile:
        #print(field['properties']['Field_Id'], field['properties']['Crop_Id_Ne'])
        field_id = field['properties']['Field_Id']
        #print(' Cropping NDVI data for test field #{}'.format(field_id))
        try:
            projected_shape = project_wsg_shape_to_csr(shapely.geometry.shape(field['geometry']),
                                                       test_shape_crs,
                                                       ndvi_crs)
        except Exception as e:
            print(' ', e, ' exception for field #', field_id)
            continue

        #print(projected_shape)
        field_img, field_img_transform = rasterio.mask.mask(ndvi_img, [projected_shape], crop=True)
        field_img_red, _ = rasterio.mask.mask(red_img, [projected_shape], crop=True)
        field_img_green, _ = rasterio.mask.mask(green_img, [projected_shape], crop=True)
        field_img_blue, _ = rasterio.mask.mask(blue_img, [projected_shape], crop=True)
        # remove the first dimension
        field_img = numpy.squeeze(field_img, axis=0)
        field_img_red = numpy.squeeze(field_img_red, axis=0)
        field_img_green = numpy.squeeze(field_img_green, axis=0)
        field_img_blue = numpy.squeeze(field_img_blue, axis=0)
        # add the 3rd dimension
        field_img = numpy.expand_dims(field_img, 2)
        field_img_red = numpy.expand_dims(field_img_red, 2)
        field_img_green = numpy.expand_dims(field_img_green, 2)
        field_img_blue = numpy.expand_dims(field_img_blue, 2)

        if field_id in test_field_data:
            test_field_data[field_id] = numpy.concatenate((test_field_data[field_id], field_img), axis=2)
            test_field_data_r[field_id] = numpy.concatenate((test_field_data_r[field_id], field_img_red), axis=2)
            test_field_data_g[field_id] = numpy.concatenate((test_field_data_g[field_id], field_img_green), axis=2)
            test_field_data_b[field_id] = numpy.concatenate((test_field_data_b[field_id], field_img_blue), axis=2)
        else:
            test_field_data[field_id] = field_img
            test_field_data_r[field_id] = field_img_red
            test_field_data_g[field_id] = field_img_green
            test_field_data_b[field_id] = field_img_blue


# save the fields data to file
pickle.dump(train_field_data, open('train/train.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(train_field_data_r, open('train/train-r.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(train_field_data_g, open('train/train-g.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(train_field_data_b, open('train/train-b.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(test_field_data, open('test/test.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(test_field_data_r, open('test/test-r.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(test_field_data_g, open('test/test-g.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(test_field_data_b, open('test/test-b.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
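
# To load the pickled field data back later (sketch):
#   with open('train/train.pkl', 'rb') as f:
#       train_field_data = pickle.load(f)
#   # each value is a (height, width, n_datatakes) NDVI array keyed by Field_Id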