ConvertUnetOnly

Aug 11th, 2024

import torch
from safetensors.torch import save_file, load_file
from bitsandbytes.nn.modules import Params4bit, QuantState


def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
    if state is None:
        return None

    device = device or state.absmax.device

    # Copy the nested (second-level) quant state used when statistics are compressed
    state2 = (
        QuantState(
            absmax=state.state2.absmax.to(device),
            shape=state.state2.shape,
            code=state.state2.code.to(device),
            blocksize=state.state2.blocksize,
            quant_type=state.state2.quant_type,
            dtype=state.state2.dtype,
        )
        if state.nested
        else None
    )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=state2,
    )


class ForgeParams4bit(Params4bit):
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            # The first move to a CUDA device triggers the actual 4-bit quantization
            return self._quantize(device)
        else:
            n = ForgeParams4bit(
                torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
                requires_grad=self.requires_grad,
                quant_state=copy_quant_state(self.quant_state, device),
                blocksize=self.blocksize,
                compress_statistics=self.compress_statistics,
                quant_type=self.quant_type,
                quant_storage=self.quant_storage,
                bnb_quantized=self.bnb_quantized,
                module=self.module,
            )
            self.module.quant_state = n.quant_state
            return n


# Load the model weights from .safetensors
def load_model_weights(safetensors_path):
    return load_file(safetensors_path)


# Quantize the model weights to NF4
def quantize_model_weights(state_dict, device=torch.device("cuda")):
    quantized_state_dict = {}

    for key, value in state_dict.items():
        if 'weight' in key:  # Assuming 'weight' in the key indicates a tensor to quantize
            # Wrap in ForgeParams4bit so it can be quantized
            quantized_param = ForgeParams4bit(
                value.to(device),
                requires_grad=False,
                compress_statistics=True,
                quant_type='nf4',  # NF4 4-bit quantization
                quant_storage=torch.uint8,
                module=None,  # No module reference in this context
            )

            # Quantize the parameter explicitly
            quantized_param = quantized_param._quantize(device)

            # Only the packed 4-bit tensor is stored; its quant_state
            # (absmax, code, blocksize, ...) is not written to the file.
            quantized_state_dict[key] = quantized_param.data
        else:
            quantized_state_dict[key] = value.to(device)

    return quantized_state_dict


# Save the quantized state dict back to .safetensors
def save_quantized_model_weights(quantized_state_dict, save_path):
    save_file(quantized_state_dict, save_path)


# Example usage
safetensors_path = "flux1-dev-schnell-merge.sft"  # Replace with your actual path
quantized_safetensors_path = "quantized_model.safetensors"  # Replace with your actual path

# Step 1: Load model weights from .safetensors
state_dict = load_model_weights(safetensors_path)

# Step 2: Quantize the model weights
quantized_state_dict = quantize_model_weights(state_dict)

# Step 3: Save the quantized model weights to .safetensors
save_quantized_model_weights(quantized_state_dict, quantized_safetensors_path)
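
# Optional sanity check (a minimal sketch, not part of the original paste):
# reopen the saved file and inspect a few entries. Quantized weight tensors
# should come back with dtype torch.uint8 (two 4-bit values packed per byte),
# while non-weight tensors keep their original dtype and shape.
checked = load_file(quantized_safetensors_path)
for name, tensor in list(checked.items())[:5]:
    print(name, tuple(tensor.shape), tensor.dtype)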