Guest User

with keys

a guest
Aug 12th, 2024
55
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.12 KB | None | 0 0
  import torch
  import safetensors.torch
  from safetensors.torch import save_file, load_file
  from bitsandbytes.nn.modules import Params4bit, QuantState
  5.  
  6.  
  def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
      """Build a new ``QuantState`` from ``state`` with its tensors sent to ``device``.

      Args:
          state: Quantization state to copy. ``None`` is returned unchanged.
          device: Target device. When falsy, ``state.absmax.device`` is used.

      Returns:
          A new ``QuantState`` whose ``absmax`` and ``code`` (and, for nested
          states, ``offset`` and the secondary state's ``absmax``/``code``) are
          the originals passed through ``.to(device)``. The remaining fields are
          handed over as-is. ``None`` when ``state`` is ``None``.
      """
      if state is None:
          return None

      target = device or state.absmax.device

      # The secondary state and the offset are only read for nested states.
      if state.nested:
          inner = state.state2
          nested_state = QuantState(
              absmax=inner.absmax.to(target),
              shape=inner.shape,
              code=inner.code.to(target),
              blocksize=inner.blocksize,
              quant_type=inner.quant_type,
              dtype=inner.dtype,
          )
      else:
          nested_state = None

      if state.nested:
          moved_offset = state.offset.to(target)
      else:
          moved_offset = None

      return QuantState(
          absmax=state.absmax.to(target),
          shape=state.shape,
          code=state.code.to(target),
          blocksize=state.blocksize,
          quant_type=state.quant_type,
          dtype=state.dtype,
          offset=moved_offset,
          state2=nested_state,
      )
  36.  
  37.  
  class ForgeParams4bit(Params4bit):
      """``Params4bit`` subclass whose ``to`` carries the quantization state along."""

      def to(self, *args, **kwargs):
          """Move or convert the parameter, quantizing on first transfer to CUDA.

          Accepts the same arguments as ``torch.Tensor.to``; they are parsed with
          ``torch._C._nn._parse_to``.

          Returns:
              The result of ``self._quantize(device)`` when the target device is
              CUDA and the parameter is not yet marked ``bnb_quantized``.
              Otherwise a new ``ForgeParams4bit`` wrapping the moved data, with
              a copy of ``quant_state`` sent to the same device.
          """
          device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
          if device is not None and device.type == "cuda" and not self.bnb_quantized:
              return self._quantize(device)

          moved = ForgeParams4bit(
              torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
              requires_grad=self.requires_grad,
              quant_state=copy_quant_state(self.quant_state, device),
              blocksize=self.blocksize,
              compress_statistics=self.compress_statistics,
              quant_type=self.quant_type,
              quant_storage=self.quant_storage,
              bnb_quantized=self.bnb_quantized,
              module=self.module,
          )
          # Parameters may be created without an owning module (module=None);
          # only mirror the new quant_state onto the module when there is one.
          if self.module is not None:
              self.module.quant_state = moved.quant_state
          return moved
  57.  
  def load_model_weights(safetensors_path):
      """Load the model weights stored in the .safetensors file at ``safetensors_path``.

      Returns whatever ``load_file`` returns for that path.
      """
      loaded = load_file(safetensors_path)
      return loaded
  61.  
  def quantize_model_weights(
      state_dict,
      device=torch.device("cuda"),
      *,
      quant_type="fp4",
      compress_statistics=True,
  ):
      """Quantize the weight entries of ``state_dict`` to 4 bit.

      Every entry whose key contains the substring ``'weight'`` is wrapped in a
      ``ForgeParams4bit`` and quantized on ``device``. All other entries are
      only moved to ``device``.

      NOTE(review): the key test is a plain substring match, so any key that
      merely contains ``'weight'`` is quantized too — confirm that this suits
      the checkpoint being converted.

      Args:
          state_dict: Mapping of tensor name to tensor.
          device: Device on which quantization runs and on which the returned
              tensors live.
          quant_type: Quantization type handed to ``ForgeParams4bit``
              (previously hard-coded to ``'fp4'``).
          compress_statistics: Forwarded to ``ForgeParams4bit`` (previously
              hard-coded to ``True``).

      Returns:
          A new dict. For each quantized key it holds the quantized data under
          the original key, followed by the entries of
          ``quant_state.as_dict(packed=True)`` stored under ``"<key>.<name>"``.
      """
      quantized_state_dict = {}

      for key, value in state_dict.items():
          if "weight" not in key:
              quantized_state_dict[key] = value.to(device)
              continue

          param = ForgeParams4bit(
              value.to(device),
              requires_grad=False,
              compress_statistics=compress_statistics,
              quant_type=quant_type,
              quant_storage=torch.uint8,
              module=None,  # No module reference in this context
          )

          # Quantize explicitly rather than relying on a later ``.to`` call.
          param = param._quantize(device)
          quantized_state_dict[key] = param.data

          # Store the quantization state next to the weight so it can be restored.
          quant_state = getattr(param, "quant_state", None)
          if quant_state is not None:
              for name, tensor in quant_state.as_dict(packed=True).items():
                  quantized_state_dict[key + "." + name] = tensor

      return quantized_state_dict
  94.  
  def save_quantized_model_weights(quantized_state_dict, save_path):
      """Write ``quantized_state_dict`` to ``save_path`` as a .safetensors file.

      The work is delegated to ``save_file``; nothing is returned.
      """
      save_file(quantized_state_dict, save_path)
  98.  
  # Example usage: load a checkpoint, quantize it, and write the result.
  # Both paths are placeholders; replace them with your actual paths.
  safetensors_path = "flux1-dev-schnell-merge.sft"
  quantized_safetensors_path = "quantized_model.safetensors"

  # 1) Read the original weights from the .safetensors file.
  state_dict = load_model_weights(safetensors_path)

  # 2) Quantize them (uses the function's default device).
  quantized_state_dict = quantize_model_weights(state_dict)

  # 3) Write the quantized weights to the output .safetensors file.
  save_quantized_model_weights(quantized_state_dict, quantized_safetensors_path)
Add Comment
Please, Sign In to add comment