Advertisement
Guest User

Untitled

a guest
Apr 25th, 2017
98
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.86 KB | None | 0 0
  1. # BATCH-NORMALIZED FEATURE MATRICES
  2. dataset = reader.read.format("com.databricks.spark.avro").load("data/collisions_feature_matrices_batch_normalized.avro")
  3. # Read the collisions dataset to obtain meta-information about the tracks.
  4. collisions = reader.read.format("com.databricks.spark.avro").load("data/collisions.avro")
  5.  
  6. def extract_track_types(iterator):
  7.     for row in iterator:
  8.         tracks = row['tracks']
  9.         for t in tracks:
  10.             yield t['track_type']
  11.  
  12. # Obtain the files from which we extracted the collisions.
  13. files = collisions.mapPartitions(extract_track_types).distinct().collect()
  14.  
  15. def construct_output_vector(row):
  16.     collision_id = row['id']
  17.     tracks = row['tracks']
  18.     files = []
  19.     for t in tracks:
  20.         file = t['track_type']
  21.         if file not in files:
  22.             files.append(file)
  23.     # Construct the output vector.
  24.     y = np.zeros(num_types)
  25.     for f in files:
  26.         y[mapping[f]] = 1.0
  27.        
  28.     return Row(**{'id': collision_id, 'y': y.tolist()})
  29.  
  30. # From this, construct a feature vector which represents the track types for every collision-id.
  31. output_vectors = collisions.map(construct_output_vector).toDF()
  32.  
  33. def flatten(row):
  34.     # Obtain the collision-id.
  35.     collision_id = row['collision_id']
  36.     # Obtain the feature matrices, and flatten them.
  37.     m_f = np.asarray(row['front']).flatten()
  38.     m_s = np.asarray(row['side']).flatten()
  39.    
  40.     return Row(**{'collision_id': collision_id, 'front': m_f.tolist(), 'side': m_s.tolist()})
  41.  
  42. training_set = dataset.map(flatten).toDF()
  43. training_set = training_set.join(output_vectors, training_set.collision_id == output_vectors.id)
  44. training_set = training_set.select("collision_id", "front", "side", "y")
  45. training_set.persist(StorageLevel.MEMORY_AND_DISK)
  46.  
  47. training_set.printSchema()
  48.  
  49. print("Number of training samples: " + str(training_set.count()))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement