Guest User

batchinference

a guest
Jul 18th, 2024
7
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.55 KB | None | 0 0
  1. #@title experimental batch inference WITH PROGRESS BAR
  2. import os
  3. import cv2
  4. from IPython.display import Image, display
  5. from concurrent.futures import ThreadPoolExecutor
  6.  
  7. # Paths
  8. source_video_path = '/content/LivePortrait/37secsclip.mp4' #@param {type:"string"}
  9. driving_video_path = '/content/LivePortrait/drivingcut35secs.mp4' #@param {type:"string"}
  10. output_dir = 'output_frames'
  11. animations_dir = '/content/LivePortrait/animations'
  12. last_frames_dir = 'last_frames'
  13. source_frames_dir = 'source_frames'
  14. driving_frames_dir = 'driving_frames'
  15. os.makedirs(output_dir, exist_ok=True)
  16. os.makedirs(animations_dir, exist_ok=True)
  17. os.makedirs(last_frames_dir, exist_ok=True)
  18. os.makedirs(source_frames_dir, exist_ok=True)
  19. os.makedirs(driving_frames_dir, exist_ok=True)
  20.  
  21. # Get the FPS of both videos
  22. source_fps = cv2.VideoCapture(source_video_path).get(cv2.CAP_PROP_FPS)
  23. driving_fps = cv2.VideoCapture(driving_video_path).get(cv2.CAP_PROP_FPS)
  24.  
  25. # Settings
  26. fps_option = 'custom' #@param ["source", "driving", "higher", "lower", "custom"] {allow-input: true}
  27. custom_fps = 15 #@param {type:"number"} # Only used if fps_option is 'custom'
  28. num_workers = 4 #@param {type:"number"}
  29.  
  30. # Determine the frame rate to use
  31. if fps_option == 'source':
  32. fps = source_fps
  33. elif fps_option == 'driving':
  34. fps = driving_fps
  35. elif fps_option == 'higher':
  36. fps = max(source_fps, driving_fps)
  37. elif fps_option == 'lower':
  38. fps = min(source_fps, driving_fps)
  39. elif fps_option == 'custom':
  40. fps = custom_fps
  41. else:
  42. fps = source_fps # Default to source video fps if option is unknown
  43.  
  44. print(f"Using frame rate: {fps} FPS")
  45.  
  46. # Clear directories
  47. for folder in [source_frames_dir, driving_frames_dir, last_frames_dir, output_dir]:
  48. for file in os.listdir(folder):
  49. file_path = os.path.join(folder, file)
  50. if os.path.isfile(file_path) or os.path.islink(file_path):
  51. os.unlink(file_path)
  52. elif os.path.isdir(file_path):
  53. os.rmdir(file_path)
  54.  
  55. # Extract frames from source video using ffmpeg
  56. os.system(f'ffmpeg -i {source_video_path} -vf "fps={fps}" {source_frames_dir}/frame_%04d.png')
  57.  
  58. # Extract frames from driving video using ffmpeg
  59. os.system(f'ffmpeg -i {driving_video_path} -vf "fps={fps}" {driving_frames_dir}/frame_%04d.png')
  60.  
  61. # List extracted frames
  62. source_frames = sorted(os.listdir(source_frames_dir))
  63. driving_frames = sorted(os.listdir(driving_frames_dir))
  64.  
  65. # Determine the minimum length
  66. min_length = min(len(source_frames), len(driving_frames))
  67. print(f"Using {min_length} frames for processing.")
  68.  
  69. # Create 2-frame videos
  70. for i in range(min_length):
  71. frame1_path = os.path.join(driving_frames_dir, driving_frames[0]) # first frame
  72. frame2_path = os.path.join(driving_frames_dir, driving_frames[i])
  73.  
  74. # Create 2-frame video using ffmpeg
  75. two_frame_video_path = os.path.join(output_dir, f'two_frame_video_{i}.mp4')
  76. os.system(f'ffmpeg -y -loop 1 -t 0.5 -i {frame1_path} -loop 1 -t 0.5 -i {frame2_path} '
  77. f'-filter_complex "[0:v][1:v]concat=n=2:v=1:a=0[outv]" -map "[outv]" {two_frame_video_path}')
  78. # print(f"Created 2-frame video: {two_frame_video_path}")
  79.  
  80. # Verify the first frame of the first 2-frame video
  81. two_frame_video_path = os.path.join(output_dir, 'two_frame_video_0.mp4')
  82. first_frame_path = os.path.join(output_dir, 'verify_frame.png')
  83. os.system(f'ffmpeg -i {two_frame_video_path} -vf "select=eq(n\,0)" -q:v 3 {first_frame_path}')
  84.  
  85. # # Display the extracted frame from the first 2-frame video
  86. # if os.path.exists(first_frame_path):
  87. # display(Image(filename=first_frame_path))
  88. # else:
  89. # print("Failed to extract the first frame from the 2-frame video.")
  90.  
  91. # Function to run inference
  92. import os
  93. from concurrent.futures import ThreadPoolExecutor
  94. from tqdm import tqdm
  95.  
  96. # Function to run inference on a batch of frames
  97. def run_batch_inference(batch):
  98. batch_commands = []
  99. for i in batch:
  100. input_image_path = os.path.join(source_frames_dir, source_frames[i])
  101. driving_video_path = os.path.join(output_dir, f'two_frame_video_{i}.mp4')
  102. command = f'python inference.py -s {input_image_path} -d {driving_video_path}'
  103. batch_commands.append(command)
  104.  
  105. # Run the batch inference
  106. os.system(' && '.join(batch_commands))
  107.  
  108. # Define batch size
  109. batch_size = 32 # Adjust this based on your system's capabilities
  110.  
  111. # Create batches
  112. batches = [range(i, min(i + batch_size, min_length)) for i in range(0, min_length, batch_size)]
  113.  
  114. # Calculate total number of batches
  115. total_batches = len(batches)
  116.  
  117. # Run batch inference in parallel with progress bar
  118. with ThreadPoolExecutor(max_workers=num_workers) as executor:
  119. list(tqdm(executor.map(run_batch_inference, batches),
  120. total=total_batches,
  121. desc="Processing batches",
  122. unit="batch"))
  123.  
  124. print("Re-combining Video")
  125. # Save the last frame of each inference video
  126. for i in range(min_length):
  127. input_image_path = os.path.join('source_frames', source_frames[i])
  128. driving_video_path = os.path.join(output_dir, f'two_frame_video_{i}.mp4')
  129.  
  130. # Determine output file path
  131. image_filename = os.path.splitext(os.path.basename(input_image_path))[0]
  132. video_filename = os.path.splitext(os.path.basename(driving_video_path))[0]
  133. output_filename = f"{image_filename}--{video_filename}.mp4"
  134. output_path = f"{animations_dir}/{output_filename}"
  135.  
  136. if not os.path.exists(output_path):
  137. print(f"Output video not found: {output_path}")
  138. continue
  139.  
  140. # Extract the last frame of the output video
  141. last_frame_path = os.path.join(last_frames_dir, f'last_frame_{i}.png')
  142. os.system(f'ffmpeg -sseof -3 -i {output_path} -update 1 -q:v 1 {last_frame_path}')
  143. # print(f"Saved last frame of {output_path} to {last_frame_path}")
  144. # print(f"Saved frame {i}")
  145.  
  146. # Delete the final output video if it already exists
  147. final_video_path = os.path.join(output_dir, 'final_output2.mp4')
  148. if os.path.exists(final_video_path):
  149. os.remove(final_video_path)
  150.  
  151. # Combine the last frames into a final video using ffmpeg
  152. os.system(f'ffmpeg -framerate {fps} -i {last_frames_dir}/last_frame_%d.png -c:v libx264 -pix_fmt yuv420p {final_video_path}')
  153.  
  154. print(f"Final video saved at: {final_video_path}")
  155.  
  156.  
  157. from IPython.display import HTML
  158. from base64 import b64encode
  159.  
  160. # Read the output video file
  161. mp4 = open("output_frames/final_output2.mp4", 'rb').read()
  162. data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
  163.  
  164. # Display the video in HTML
  165. HTML(f"""
  166. <video width=400 controls>
  167. <source src="{data_url}" type="video/mp4">
  168. </video>
  169. """)
  170.  
Add Comment
Please, Sign In to add comment