##########################################################################################
#
# Code to create a TFRecord from videos (droplet videos). The frames are fed into a
# VAE, and the latent variable is stored in the TFRecord (instead of the full frame).
# Thus, for each video the TFRecord will store the latent variable that represents
# each frame.
# The videos must be binarised and 400x400, in the format expected by the VAE.
#
# Author: Juan Manuel Parrilla Gutierrez ([email protected])
#
##########################################################################################

import glob
import os
import random
import re
import shutil
import sys

import cv2
import numpy as np
import tensorflow as tf
import tqdm

from tensorflow.keras import utils

from src.utils.modelutils import load_vae_model


class TFRecordWriter(object):
    """
    Class handling writing the video data (through a VAE) to a TFRecord.
    Please check TFRecordReader to see how to read it back.
    """

    def __init__(self, videopaths, modelpath, dillation,
                 step_size=1, batch_size=32, train_split=0.8):
        """Creates an object that will be used to save the data into a TFRecord.

        Args:
            videopaths: Path to the folder where all the videos are. The videos must be
                binarised, created with the script VideoUtils.py
            modelpath: Path to the folder with the trained VAE (created with train_vae.py)
            dillation: "normal", "reverse" or "no". Check FramesLoader for more info
            step_size: Keep one frame out of every step_size. For example, step_size=3
                keeps 1 out of every 3 frames
            batch_size: Batch size of the data when we run it through the VAE encoder
            train_split: Fraction of the data that goes into the train set (the rest
                goes into the test set)
        """
        self.videopaths = videopaths
        self.modelpath = modelpath
        self.dillation = dillation
        self.step_size = step_size
        self.batch_size = batch_size
        self.train_split = train_split

        self.vae, self.input_dim = load_vae_model(self.modelpath)

        # Because the videos have different lengths, we break the longer ones into
        # sequences of exactly 832 (416 * 2) frames, so that every serialised example
        # has the same shape (832, latent_size). This number may need to be adjusted
        # depending on self.batch_size, and it must match the explicit reshape in
        # TFRecordReader.decode_frames.
        self.seq_length = 416 * 2


    def serialise_to_tfrecords(self):
        """Serialises all the data into one train and one test TFRecord file."""

        # get all the videos and split them into train and test sets
        run_name = (self.videopaths + '/*.avi')
        videos = glob.glob(run_name)
        random.shuffle(videos)
        db_split = int(len(videos) * self.train_split)
        train_set = videos[:db_split]
        test_set = videos[db_split:]

        for ti, tset in enumerate([train_set, test_set]):

            if ti == 0:
                out_path = self.videopaths + "/train.tfrecord"
            else:
                out_path = self.videopaths + "/test.tfrecord"
            writer = tf.io.TFRecordWriter(out_path)

            for video in tqdm.tqdm(tset, unit='F'):
                frames_latent_vectors = self.video2latentvectors(video)
                vector_length = frames_latent_vectors.shape[0]
                recipe = self.get_recipe(video)

                # break down the latent vectors into sequences of length seq_length.
                # For example, a 2000-frame video yields two chunks of 832 frames,
                # and the trailing 336 frames are dropped.
                for seq in range(0, vector_length, self.seq_length):
                    start = seq
                    end = seq + self.seq_length

                    if end > vector_length:
                        break

                    chunk = frames_latent_vectors[start:end]
                    # self.visualise_latent_reconstructions(frames_latent_vectors)
                    serialised_video = self.prepare_TFRecord(chunk, recipe)
                    writer.write(serialised_video)

                # move the processed file to either the train or the test folder
                current_folder = os.path.dirname(video)
                file_name = os.path.basename(video)

                if ti == 0:
                    dest_folder = current_folder + "/train/"
                else:
                    dest_folder = current_folder + "/test/"

                # create the destination folder if it does not exist yet
                os.makedirs(dest_folder, exist_ok=True)
                destination = dest_folder + file_name
                shutil.move(video, destination)

            writer.close()


    def get_recipe(self, videopath):
        """From the filename we can get the recipe. For example:
        octanoic_0_pentanol_0_octanol_9_dep_90_raw_1_bin.mp4
        means 0% octanoic, 0% pentanol, 9% octanol, 90% DEP."""

        # get the file name only, not the full path with folders
        file_name = os.path.basename(videopath)
        # remove the extension
        file_name = os.path.splitext(file_name)[0]
        # find the first four numbers, which are the recipe
        recipe = re.findall(r'\d+', file_name)[:4]
        # transform to int and to numpy
        recipenp = np.array([int(x) for x in recipe])
        # normalise and return
        return recipenp / np.sum(recipenp)
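
    # Worked example for get_recipe (digits taken from its docstring):
    # "octanoic_0_pentanol_0_octanol_9_dep_90_raw_1_bin" contains the digits
    # [0, 0, 9, 90, 1]; the first four form the recipe [0, 0, 9, 90], and
    # normalising by their sum (99) gives roughly [0.0, 0.0, 0.091, 0.909].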

    def video2latentvectors(self, videopath):
        """
        Given a video, it will use the trained VAE to return the latent vector
        of each frame.
        """

        # create a batched dataset from the video frames
        ds = self.video2dataset(videopath)
        # where to store the data as it is generated by the vae
        vectors = []

        # use the VAE encoder to get the latent vectors
        for batch in ds:
            _, _, latent = self.vae.encoder(batch)
            vectors.append(latent)

        # this goes from (n_batches, batch_size, latent_dim)
        # to (n_batches * batch_size, latent_dim)
        return np.concatenate(vectors)


    def video2dataset(self, videopath):
        """
        Given a video, it will return a TF dataset with its frames.
        """

        AUTOTUNE = tf.data.AUTOTUNE

        # Get a numpy array with all the frames
        frames = self.frames_from_video_file(videopath)
        # convert the numpy array into a tf dataset
        dataset = tf.data.Dataset.from_tensor_slices(frames)

        # batch it
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        # preprocess it
        dataset = self.preprocess_dataset(dataset)
        # configure for performance
        dataset = dataset.prefetch(buffer_size=AUTOTUNE)

        return dataset


    def preprocess_dataset(self, dataset):
        # perform the same pre-processing that was used to train the VAE

        normalization_layer = tf.keras.layers.Rescaling(1./255)
        dillation_layer = tf.keras.layers.MaxPool2D(pool_size=5, strides=1, padding='same')

        dataset = dataset.map(lambda x: tf.image.resize(
            x, (self.input_dim[0], self.input_dim[1])))

        if self.dillation == "normal":
            # max-pooling a binary image dilates (grows) the white regions
            dataset = dataset.map(lambda x: dillation_layer(x))
        elif self.dillation == "reverse":
            # dilating the inverted image and inverting back erodes (shrinks) them
            dataset = dataset.map(lambda x: 1 - dillation_layer(1 - x))

        normalized_ds = dataset.map(lambda x: normalization_layer(x))
        return normalized_ds
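
    # A minimal, illustrative sketch (not used by the pipeline) of why the
    # "reverse" branch above, 1 - MaxPool2D(1 - x), erodes a binary image:
    # dilating the inverted image and inverting back shrinks the white regions.
    @staticmethod
    def _erosion_sketch():
        # one 5x5 single-channel binary image with a 3x3 white square inside
        x = tf.constant([[0, 0, 0, 0, 0],
                         [0, 1, 1, 1, 0],
                         [0, 1, 1, 1, 0],
                         [0, 1, 1, 1, 0],
                         [0, 0, 0, 0, 0]], dtype=tf.float32)
        x = tf.reshape(x, (1, 5, 5, 1))
        pool = tf.keras.layers.MaxPool2D(pool_size=3, strides=1, padding='same')
        dilated = pool(x)         # the white square grows to fill the 5x5 image
        eroded = 1 - pool(1 - x)  # the white square shrinks to its centre pixel
        return dilated, eroded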


    def frames_from_video_file(self, videopath):
        """
        Given a video, it will return its frames in a numpy array.
        """

        frames = []
        video_capture = cv2.VideoCapture(videopath)

        while True:

            # read step_size frames and keep only the last one, so that we
            # effectively take one frame every step_size
            for _ in range(self.step_size):
                ret, frame = video_capture.read()
                if not ret:
                    break

            if not ret:
                break

            # the following line would convert it from 0..255 to 0..1,
            # but we apply a normalization layer later on, so it stays commented out
            # frame = tf.image.convert_image_dtype(frame, tf.float32)
            frames.append(frame)

        video_capture.release()

        # OpenCV returns frames in BGR order; reorder the channels to RGB
        return np.array(frames)[..., [2, 1, 0]]


    def prepare_TFRecord(self, frames, recipe):
        # TensorFlow boilerplate to serialise the data into a tf.train.Example,
        # which is what gets written into the TFRecord

        frames_feature = tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[
                tf.io.serialize_tensor(frames).numpy(),
            ])
        )

        recipe_feature = tf.train.Feature(
            float_list=tf.train.FloatList(value=recipe),
        )

        features = tf.train.Features(feature={
            'frames': frames_feature,
            'recipe': recipe_feature
        })

        example = tf.train.Example(features=features)
        return example.SerializeToString()
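
    # A minimal, illustrative round-trip sketch (not used by the pipeline):
    # parse a string produced by prepare_TFRecord back into tensors, mirroring
    # what TFRecordReader.read_tfrecord does below.
    @staticmethod
    def _parse_example_sketch(serialised_example):
        features = tf.io.parse_single_example(serialised_example, {
            "frames": tf.io.FixedLenFeature([], tf.string),
            "recipe": tf.io.FixedLenFeature([4], tf.float32),
        })
        latents = tf.io.parse_tensor(features["frames"], tf.float32)
        return latents, features["recipe"]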


    def visualise_latent_reconstructions(self, latent_vectors):
        """
        Creates images from the generated latent vectors, to check that the
        previous encoding step is correct.
        """

        batch_size = 32
        ds = tf.data.Dataset.from_tensor_slices(latent_vectors)
        ds = ds.batch(batch_size)

        for entry in ds.take(1):
            generated_images = self.vae.decoder(entry)

        for i in range(batch_size):
            img = utils.array_to_img(generated_images[i])
            img.save("writer_img_%03d.png" % (i))


class TFRecordReader(object):
    """
    Class handling reading the TFRecord into a dataset to use for training.
    Please check TFRecordWriter to see how it was saved to disk.
    The TFRecord to be read must have been created with TFRecordWriter.
    """

    def __init__(self, tfrecordfile, batch_size=64):
        self.BATCH_SIZE = batch_size
        self.tfrecordfile = tfrecordfile
        self.AUTOTUNE = tf.data.AUTOTUNE
        self.dataset = self.get_dataset()
        self.dataset_iter = iter(self.dataset)

    def decode_frames(self, frames):
        parsed_data = tf.io.parse_tensor(frames, tf.float32)
        # an explicit static shape is needed to run on TPU; 832 must match
        # TFRecordWriter.seq_length, and 200 is the VAE latent size
        parsed_data = tf.reshape(parsed_data, [832, 200])
        return parsed_data


    def read_tfrecord(self, example):
        TFREC_FORMAT = {
            "frames": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
            "recipe": tf.io.FixedLenFeature([4], tf.float32)
        }
        example = tf.io.parse_single_example(example, TFREC_FORMAT)
        video_latent_vectors = self.decode_frames(example['frames'])
        return video_latent_vectors, example['recipe']


    def load_dataset(self):
        """Loads a TFRecord and uses map to parse it into a dataset.
        Check https://keras.io/examples/keras_recipes/tfrecord/ ("define load methods"):
        this is basically that code with small modifications.

        Returns:
            dataset: The loaded TFRecord as a tf.data.Dataset
        """
        ignore_order = tf.data.Options()
        ignore_order.experimental_deterministic = False  # disable order, increase speed
        dataset = tf.data.TFRecordDataset(
            self.tfrecordfile
        )  # automatically interleaves reads from multiple files
        dataset = dataset.with_options(
            ignore_order
        )  # uses data as soon as it streams in, rather than in its original order
        dataset = dataset.map(
            self.read_tfrecord,
            num_parallel_calls=self.AUTOTUNE
        )
        return dataset


    def get_dataset(self):
        """Loads the TFRecord from the file path, then shuffles the data and
        divides it into batches.
        """
        dataset = self.load_dataset()
        dataset = dataset.shuffle(2048)
        dataset = dataset.prefetch(buffer_size=self.AUTOTUNE)
        dataset = dataset.batch(self.BATCH_SIZE, drop_remainder=True)
        return dataset  # .repeat()
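
    # Typical consumption of the dataset (illustrative): each batch is a tuple
    # of latent sequences with shape (BATCH_SIZE, 832, 200) and recipes with
    # shape (BATCH_SIZE, 4). For example:
    #   reader = TFRecordReader("path/to/train.tfrecord", batch_size=64)
    #   for latents, recipes in reader.dataset:
    #       ...  # feed the batch to a downstream model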

    def visualise_latent_reconstructions_and_recipes(self, vaepath):

        batch_size = 32
        # take one batch and keep the latent vectors: (BATCH_SIZE, 832, latent_size)
        data = next(self.dataset_iter)[0]
        # use only the first video of the batch: (832, latent_size)
        ds = tf.data.Dataset.from_tensor_slices(data.numpy()[0])
        ds = ds.batch(batch_size)  # e.g. 26 batches of (32, latent_size)

        vae, _ = load_vae_model(vaepath)

        for entry in ds.take(1):
            generated_images = vae.decoder(entry)

        for i in range(batch_size):
            img = utils.array_to_img(generated_images[i])
            img.save("reader_img_%03d.png" % (i))


if __name__ == "__main__":

    videopaths = sys.argv[1]
    # tfrecord_file = sys.argv[1]
    modelpath = sys.argv[2]

    tfrecord_writer = TFRecordWriter(videopaths, modelpath, dillation="reverse")
    tfrecord_writer.serialise_to_tfrecords()

    # tfrecord_reader = TFRecordReader(tfrecord_file, batch_size=32)
    # tfrecord_reader.visualise_latent_reconstructions_and_recipes(modelpath)
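
# Usage (illustrative; the actual script name may differ):
#   python tfrecord_utils.py <folder_with_binarised_avi_videos> <trained_vae_folder>
# This writes train.tfrecord and test.tfrecord into the videos folder and moves
# each processed video into a train/ or test/ subfolder.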