Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import json
- import os
- import safetensors.torch
- import psutil
- import time
- import logging
- import gc
- from typing import Dict, Any
- from threading import Timer
- # Configuração de logging
- logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(levelname)s - %(message)s',
- filename='merge_process.log'
- )
- logger = logging.getLogger(__name__)
- class TimeoutError(Exception):
- pass
- def timeout_handler(timeout_duration):
- def decorator(func):
- def wrapper(*args, **kwargs):
- timer = Timer(timeout_duration, lambda: (_ for _ in ()).throw(TimeoutError("Operação excedeu o tempo limite")))
- try:
- timer.start()
- result = func(*args, **kwargs)
- return result
- finally:
- timer.cancel()
- return wrapper
- return decorator
- def print_memory_usage():
- process = psutil.Process(os.getpid())
- memory_mb = process.memory_info().rss / 1024 / 1024
- logger.info(f"Memória em uso: {memory_mb:.2f} MB")
- print(f"Memória em uso: {memory_mb:.2f} MB")
- def verify_file_integrity(file_path: str) -> bool:
- try:
- with safetensors.torch.safe_open(file_path, framework="pt", device="cpu") as f:
- _ = f.keys()
- return True
- except Exception as e:
- logger.error(f"Erro na verificação de integridade do arquivo {file_path}: {e}")
- return False
- def save_progress(processed_weights: Dict[str, Any], filename: str = 'progress.json'):
- try:
- progress_data = {
- 'processed_weights': list(processed_weights.keys()),
- 'timestamp': time.time()
- }
- with open(filename, 'w') as f:
- json.dump(progress_data, f)
- logger.info(f"Progresso salvo em {filename} - {len(progress_data['processed_weights'])} tensores processados")
- except Exception as e:
- logger.error(f"Erro ao salvar progresso: {e}")
- def estimate_required_memory(index_file_path: str) -> int:
- logger.info("Iniciando estimativa de memória necessária")
- try:
- with open(index_file_path, 'r') as f:
- index_data = json.load(f)
- unique_shards = set(index_data['weight_map'].values())
- total_bytes = 0
- print("Tamanhos dos shards:")
- for shard in unique_shards:
- shard_path = os.path.join(os.path.dirname(index_file_path), shard)
- try:
- size = os.path.getsize(shard_path)
- print(f" {shard}: {size / (1024 ** 2):.2f} MB")
- logger.info(f"Shard {shard}: {size / (1024 ** 2):.2f} MB")
- total_bytes += size
- except Exception as e:
- logger.error(f"Erro ao obter tamanho do shard {shard}: {e}")
- print(f"Erro ao obter o tamanho do shard {shard}: {e}")
- print(f"Memória total estimada: {total_bytes / (1024 ** 2):.2f} MB")
- logger.info(f"Memória total estimada: {total_bytes / (1024 ** 2):.2f} MB")
- return total_bytes
- except Exception as e:
- logger.error(f"Erro na estimativa de memória: {e}")
- raise
- def check_available_swap() -> int:
- try:
- swap = psutil.swap_memory()
- print("\nInformações do swap do sistema:")
- print(f" Total: {swap.total / (1024 ** 2):.2f} MB")
- print(f" Disponível: {swap.free / (1024 ** 2):.2f} MB")
- logger.info(f"Swap total: {swap.total / (1024 ** 2):.2f} MB, Disponível: {swap.free / (1024 ** 2):.2f} MB")
- return swap.free
- except Exception as e:
- logger.error(f"Erro ao verificar memória swap: {e}")
- raise
- def merge_safetensors_shards(index_file_path: str, output_file_path: str, batch_size: int = 5):
- logger.info(f"Iniciando mesclagem de shards. Arquivo de índice: {index_file_path}")
- print_memory_usage()
- # Verificar requisitos de memória
- required_memory = estimate_required_memory(index_file_path)
- available_swap = check_available_swap()
- if required_memory > available_swap:
- warning_msg = "AVISO: Memória swap pode ser insuficiente para a operação"
- logger.warning(warning_msg)
- print(f"\n{warning_msg}")
- try:
- with open(index_file_path, 'r') as f:
- index_data = json.load(f)
- except FileNotFoundError as e:
- error_msg = f"Arquivo index não encontrado em {index_file_path}"
- logger.error(error_msg)
- raise FileNotFoundError(error_msg)
- shard_file_directory = os.path.dirname(index_file_path)
- weight_items = list(index_data['weight_map'].items())
- processed_weights = {}
- # Processar em lotes
- for i in range(0, len(weight_items), batch_size):
- batch = weight_items[i:i + batch_size]
- batch_weights = {}
- logger.info(f"Processando batch {i//batch_size + 1} de {len(weight_items)//batch_size + 1}")
- for weight_name, shard_file in batch:
- shard_path = os.path.join(shard_file_directory, shard_file)
- if not verify_file_integrity(shard_path):
- logger.error(f"Arquivo corrompido: {shard_path}")
- continue
- @timeout_handler(300)
- def load_tensor(file_path, name):
- with safetensors.torch.safe_open(file_path, framework="pt", device="cpu") as f:
- return f.get_tensor(name)
- try:
- start_time = time.time()
- tensor = load_tensor(shard_path, weight_name)
- batch_weights[weight_name] = tensor
- elapsed = time.time() - start_time
- logger.info(f"Peso '{weight_name}' carregado em {elapsed:.2f}s, forma: {tensor.shape}")
- print(f"Processado: {weight_name} em {elapsed:.2f}s")
- except TimeoutError:
- logger.error(f"Timeout ao processar {weight_name}")
- continue
- except Exception as e:
- logger.error(f"Erro ao processar {weight_name}: {e}")
- continue
- # Salvar batch atual
- if batch_weights:
- try:
- temp_output = f"{output_file_path}.part{i}"
- safetensors.torch.save_file(batch_weights, temp_output, metadata={"format": "pt"})
- logger.info(f"Batch {i//batch_size + 1} salvo em {temp_output}")
- save_progress({'processed_weights': list(batch_weights.keys())})
- except Exception as e:
- logger.error(f"Erro ao salvar batch {i//batch_size + 1}: {e}")
- continue
- del batch_weights
- gc.collect()
- print_memory_usage()
- # Mesclar arquivos temporários
- try:
- logger.info("Mesclando arquivos temporários")
- metadata = {"format": "pt"}
- all_weights = {}
- # Lista todos os arquivos temporários
- temp_files = []
- for i in range(0, len(weight_items), batch_size):
- temp_file = f"{output_file_path}.part{i}"
- if os.path.exists(temp_file):
- temp_files.append(temp_file)
- print(f"Encontrados {len(temp_files)} arquivos temporários para mesclar")
- logger.info(f"Encontrados {len(temp_files)} arquivos temporários para mesclar")
- # Mescla os arquivos temporários
- for temp_file in temp_files:
- print(f"Processando {temp_file}")
- logger.info(f"Processando {temp_file}")
- try:
- with safetensors.torch.safe_open(temp_file, framework="pt", device="cpu") as f:
- for key in f.keys():
- all_weights[key] = f.get_tensor(key)
- print(f"Arquivo {temp_file} processado")
- logger.info(f"Arquivo {temp_file} processado")
- except Exception as e:
- print(f"Erro ao processar {temp_file}: {e}")
- logger.error(f"Erro ao processar {temp_file}: {e}")
- continue
- # Salva o arquivo final
- print("Salvando arquivo final...")
- logger.info("Salvando arquivo final...")
- safetensors.torch.save_file(all_weights, output_file_path, metadata=metadata)
- # Remove os arquivos temporários
- for temp_file in temp_files:
- try:
- os.remove(temp_file)
- logger.info(f"Arquivo temporário removido: {temp_file}")
- except Exception as e:
- logger.error(f"Erro ao remover arquivo temporário {temp_file}: {e}")
- logger.info(f"Mesclagem concluída com sucesso. Arquivo final: {output_file_path}")
- print(f"\nShards mesclados com sucesso em {output_file_path}")
- except Exception as e:
- logger.error(f"Erro na mesclagem final: {e}")
- print(f"Erro na mesclagem final: {e}")
- raise
- if __name__ == "__main__":
- try:
- print("Pressione Ctrl+C para interromper o processo a qualquer momento")
- index_file_path = r"input_folder_here\diffusion_pytorch_model.safetensors.index.json"
- output_file_path = r"output_folder_here\merged_diffusion_pytorch_model.safetensors"
- print("\nIniciando processamento...")
- print(f"Arquivo de entrada: {index_file_path}")
- print(f"Arquivo de saída: {output_file_path}\n")
- merge_safetensors_shards(index_file_path, output_file_path)
- logger.info("Processo de mesclagem concluído com sucesso")
- print("\nProcesso de mesclagem concluído com sucesso.")
- print("\nPressione Enter para sair...")
- input()
- except KeyboardInterrupt:
- print("\nProcesso interrompido pelo usuário.")
- logger.info("Processo interrompido pelo usuário")
- except Exception as e:
- logger.error(f"Erro fatal no processo de mesclagem: {e}")
- print(f"\nErro fatal no processo de mesclagem: {e}")
- print("\nPressione Enter para sair...")
- input()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement