Advertisement
Guest User

Untitled

a guest
Feb 16th, 2025
209
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 10.06 KB | None | 0 0
  1. import json
  2. import os
  3. import safetensors.torch
  4. import psutil
  5. import time
  6. import logging
  7. import gc
  8. from typing import Dict, Any
  9. from threading import Timer
  10.  
  11. # Configuração de logging
  12. logging.basicConfig(
  13. level=logging.INFO,
  14. format='%(asctime)s - %(levelname)s - %(message)s',
  15. filename='merge_process.log'
  16. )
  17. logger = logging.getLogger(__name__)
  18.  
  19. class TimeoutError(Exception):
  20. pass
  21.  
  22. def timeout_handler(timeout_duration):
  23. def decorator(func):
  24. def wrapper(*args, **kwargs):
  25. timer = Timer(timeout_duration, lambda: (_ for _ in ()).throw(TimeoutError("Operação excedeu o tempo limite")))
  26. try:
  27. timer.start()
  28. result = func(*args, **kwargs)
  29. return result
  30. finally:
  31. timer.cancel()
  32. return wrapper
  33. return decorator
  34.  
  35. def print_memory_usage():
  36. process = psutil.Process(os.getpid())
  37. memory_mb = process.memory_info().rss / 1024 / 1024
  38. logger.info(f"Memória em uso: {memory_mb:.2f} MB")
  39. print(f"Memória em uso: {memory_mb:.2f} MB")
  40.  
  41. def verify_file_integrity(file_path: str) -> bool:
  42. try:
  43. with safetensors.torch.safe_open(file_path, framework="pt", device="cpu") as f:
  44. _ = f.keys()
  45. return True
  46. except Exception as e:
  47. logger.error(f"Erro na verificação de integridade do arquivo {file_path}: {e}")
  48. return False
  49.  
  50. def save_progress(processed_weights: Dict[str, Any], filename: str = 'progress.json'):
  51. try:
  52. progress_data = {
  53. 'processed_weights': list(processed_weights.keys()),
  54. 'timestamp': time.time()
  55. }
  56. with open(filename, 'w') as f:
  57. json.dump(progress_data, f)
  58. logger.info(f"Progresso salvo em {filename} - {len(progress_data['processed_weights'])} tensores processados")
  59. except Exception as e:
  60. logger.error(f"Erro ao salvar progresso: {e}")
  61.  
  62. def estimate_required_memory(index_file_path: str) -> int:
  63. logger.info("Iniciando estimativa de memória necessária")
  64. try:
  65. with open(index_file_path, 'r') as f:
  66. index_data = json.load(f)
  67.  
  68. unique_shards = set(index_data['weight_map'].values())
  69. total_bytes = 0
  70. print("Tamanhos dos shards:")
  71. for shard in unique_shards:
  72. shard_path = os.path.join(os.path.dirname(index_file_path), shard)
  73. try:
  74. size = os.path.getsize(shard_path)
  75. print(f" {shard}: {size / (1024 ** 2):.2f} MB")
  76. logger.info(f"Shard {shard}: {size / (1024 ** 2):.2f} MB")
  77. total_bytes += size
  78. except Exception as e:
  79. logger.error(f"Erro ao obter tamanho do shard {shard}: {e}")
  80. print(f"Erro ao obter o tamanho do shard {shard}: {e}")
  81.  
  82. print(f"Memória total estimada: {total_bytes / (1024 ** 2):.2f} MB")
  83. logger.info(f"Memória total estimada: {total_bytes / (1024 ** 2):.2f} MB")
  84. return total_bytes
  85. except Exception as e:
  86. logger.error(f"Erro na estimativa de memória: {e}")
  87. raise
  88.  
  89. def check_available_swap() -> int:
  90. try:
  91. swap = psutil.swap_memory()
  92. print("\nInformações do swap do sistema:")
  93. print(f" Total: {swap.total / (1024 ** 2):.2f} MB")
  94. print(f" Disponível: {swap.free / (1024 ** 2):.2f} MB")
  95. logger.info(f"Swap total: {swap.total / (1024 ** 2):.2f} MB, Disponível: {swap.free / (1024 ** 2):.2f} MB")
  96. return swap.free
  97. except Exception as e:
  98. logger.error(f"Erro ao verificar memória swap: {e}")
  99. raise
  100.  
  101. def merge_safetensors_shards(index_file_path: str, output_file_path: str, batch_size: int = 5):
  102. logger.info(f"Iniciando mesclagem de shards. Arquivo de índice: {index_file_path}")
  103. print_memory_usage()
  104.  
  105. # Verificar requisitos de memória
  106. required_memory = estimate_required_memory(index_file_path)
  107. available_swap = check_available_swap()
  108.  
  109. if required_memory > available_swap:
  110. warning_msg = "AVISO: Memória swap pode ser insuficiente para a operação"
  111. logger.warning(warning_msg)
  112. print(f"\n{warning_msg}")
  113.  
  114. try:
  115. with open(index_file_path, 'r') as f:
  116. index_data = json.load(f)
  117. except FileNotFoundError as e:
  118. error_msg = f"Arquivo index não encontrado em {index_file_path}"
  119. logger.error(error_msg)
  120. raise FileNotFoundError(error_msg)
  121.  
  122. shard_file_directory = os.path.dirname(index_file_path)
  123. weight_items = list(index_data['weight_map'].items())
  124. processed_weights = {}
  125.  
  126. # Processar em lotes
  127. for i in range(0, len(weight_items), batch_size):
  128. batch = weight_items[i:i + batch_size]
  129. batch_weights = {}
  130. logger.info(f"Processando batch {i//batch_size + 1} de {len(weight_items)//batch_size + 1}")
  131.  
  132. for weight_name, shard_file in batch:
  133. shard_path = os.path.join(shard_file_directory, shard_file)
  134.  
  135. if not verify_file_integrity(shard_path):
  136. logger.error(f"Arquivo corrompido: {shard_path}")
  137. continue
  138.  
  139. @timeout_handler(300)
  140. def load_tensor(file_path, name):
  141. with safetensors.torch.safe_open(file_path, framework="pt", device="cpu") as f:
  142. return f.get_tensor(name)
  143.  
  144. try:
  145. start_time = time.time()
  146. tensor = load_tensor(shard_path, weight_name)
  147. batch_weights[weight_name] = tensor
  148. elapsed = time.time() - start_time
  149. logger.info(f"Peso '{weight_name}' carregado em {elapsed:.2f}s, forma: {tensor.shape}")
  150. print(f"Processado: {weight_name} em {elapsed:.2f}s")
  151. except TimeoutError:
  152. logger.error(f"Timeout ao processar {weight_name}")
  153. continue
  154. except Exception as e:
  155. logger.error(f"Erro ao processar {weight_name}: {e}")
  156. continue
  157.  
  158. # Salvar batch atual
  159. if batch_weights:
  160. try:
  161. temp_output = f"{output_file_path}.part{i}"
  162. safetensors.torch.save_file(batch_weights, temp_output, metadata={"format": "pt"})
  163. logger.info(f"Batch {i//batch_size + 1} salvo em {temp_output}")
  164. save_progress({'processed_weights': list(batch_weights.keys())})
  165. except Exception as e:
  166. logger.error(f"Erro ao salvar batch {i//batch_size + 1}: {e}")
  167. continue
  168.  
  169. del batch_weights
  170. gc.collect()
  171. print_memory_usage()
  172.  
  173. # Mesclar arquivos temporários
  174. try:
  175. logger.info("Mesclando arquivos temporários")
  176. metadata = {"format": "pt"}
  177. all_weights = {}
  178.  
  179. # Lista todos os arquivos temporários
  180. temp_files = []
  181. for i in range(0, len(weight_items), batch_size):
  182. temp_file = f"{output_file_path}.part{i}"
  183. if os.path.exists(temp_file):
  184. temp_files.append(temp_file)
  185.  
  186. print(f"Encontrados {len(temp_files)} arquivos temporários para mesclar")
  187. logger.info(f"Encontrados {len(temp_files)} arquivos temporários para mesclar")
  188.  
  189. # Mescla os arquivos temporários
  190. for temp_file in temp_files:
  191. print(f"Processando {temp_file}")
  192. logger.info(f"Processando {temp_file}")
  193. try:
  194. with safetensors.torch.safe_open(temp_file, framework="pt", device="cpu") as f:
  195. for key in f.keys():
  196. all_weights[key] = f.get_tensor(key)
  197. print(f"Arquivo {temp_file} processado")
  198. logger.info(f"Arquivo {temp_file} processado")
  199. except Exception as e:
  200. print(f"Erro ao processar {temp_file}: {e}")
  201. logger.error(f"Erro ao processar {temp_file}: {e}")
  202. continue
  203.  
  204. # Salva o arquivo final
  205. print("Salvando arquivo final...")
  206. logger.info("Salvando arquivo final...")
  207. safetensors.torch.save_file(all_weights, output_file_path, metadata=metadata)
  208.  
  209. # Remove os arquivos temporários
  210. for temp_file in temp_files:
  211. try:
  212. os.remove(temp_file)
  213. logger.info(f"Arquivo temporário removido: {temp_file}")
  214. except Exception as e:
  215. logger.error(f"Erro ao remover arquivo temporário {temp_file}: {e}")
  216.  
  217. logger.info(f"Mesclagem concluída com sucesso. Arquivo final: {output_file_path}")
  218. print(f"\nShards mesclados com sucesso em {output_file_path}")
  219.  
  220. except Exception as e:
  221. logger.error(f"Erro na mesclagem final: {e}")
  222. print(f"Erro na mesclagem final: {e}")
  223. raise
  224.  
  225. if __name__ == "__main__":
  226. try:
  227. print("Pressione Ctrl+C para interromper o processo a qualquer momento")
  228.  
  229. index_file_path = r"input_folder_here\diffusion_pytorch_model.safetensors.index.json"
  230. output_file_path = r"output_folder_here\merged_diffusion_pytorch_model.safetensors"
  231.  
  232. print("\nIniciando processamento...")
  233. print(f"Arquivo de entrada: {index_file_path}")
  234. print(f"Arquivo de saída: {output_file_path}\n")
  235.  
  236. merge_safetensors_shards(index_file_path, output_file_path)
  237. logger.info("Processo de mesclagem concluído com sucesso")
  238. print("\nProcesso de mesclagem concluído com sucesso.")
  239.  
  240. print("\nPressione Enter para sair...")
  241. input()
  242. except KeyboardInterrupt:
  243. print("\nProcesso interrompido pelo usuário.")
  244. logger.info("Processo interrompido pelo usuário")
  245. except Exception as e:
  246. logger.error(f"Erro fatal no processo de mesclagem: {e}")
  247. print(f"\nErro fatal no processo de mesclagem: {e}")
  248. print("\nPressione Enter para sair...")
  249. input()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement