# Quantize the weights in a .safetensors checkpoint to 4-bit (NF4) with bitsandbytes.
import torch
import safetensors.torch
from safetensors.torch import save_file, load_file
from bitsandbytes.nn.modules import Params4bit, QuantState

# Return a copy of a QuantState (and its nested state2, if any) with all tensors moved to the given device.
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
    if state is None:
        return None

    device = device or state.absmax.device

    state2 = (
        QuantState(
            absmax=state.state2.absmax.to(device),
            shape=state.state2.shape,
            code=state.state2.code.to(device),
            blocksize=state.state2.blocksize,
            quant_type=state.state2.quant_type,
            dtype=state.state2.dtype,
        )
        if state.nested
        else None
    )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=state2,
    )

class ForgeParams4bit(Params4bit):
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            # The first move to CUDA performs the actual 4-bit quantization.
            return self._quantize(device)
        else:
            # Otherwise just copy the parameter (and its quant state, if any) to the target device.
            n = ForgeParams4bit(
                torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
                requires_grad=self.requires_grad,
                quant_state=copy_quant_state(self.quant_state, device),
                blocksize=self.blocksize,
                compress_statistics=self.compress_statistics,
                quant_type=self.quant_type,
                quant_storage=self.quant_storage,
                bnb_quantized=self.bnb_quantized,
                module=self.module,
            )
            if self.module is not None:
                self.module.quant_state = n.quant_state
            return n
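
# Quick illustration of the class above (a sketch, not part of the conversion
# script; it assumes a CUDA device and a bitsandbytes version that provides
# Params4bit._quantize): the first move to CUDA quantizes the data, later
# moves just copy the already-quantized data and its QuantState.
#
#   w = ForgeParams4bit(
#       torch.randn(64, 64),
#       requires_grad=False,
#       compress_statistics=True,
#       quant_type='nf4',
#       quant_storage=torch.uint8,
#       module=None,
#   )
#   w_cuda = w.to(torch.device("cuda"))  # quantizes; w_cuda.bnb_quantized is True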

# Function to load the model weights from .safetensors
def load_model_weights(safetensors_path):
    return load_file(safetensors_path)

# Function to quantize the model weights
def quantize_model_weights(state_dict, device=torch.device("cuda")):
    quantized_state_dict = {}
    for key, value in state_dict.items():
        if 'weight' in key:  # Assuming 'weight' in the key indicates a layer to quantize
            # Convert to ForgeParams4bit and quantize
            quantized_param = ForgeParams4bit(
                value.to(device),
                requires_grad=False,
                compress_statistics=True,
                quant_type='nf4',  # NF4 4-bit quantization
                quant_storage=torch.uint8,
                module=None  # No module reference in this context
            )
            # Quantizing the parameter explicitly
            quantized_param = quantized_param._quantize(device)
            # Only the packed uint8 weight data is kept; the QuantState itself is not stored here.
            quantized_state_dict[key] = quantized_param.data
        else:
            quantized_state_dict[key] = value.to(device)
    return quantized_state_dict
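
# The function above keeps only the packed uint8 weights, so the result cannot
# be dequantized after reloading. A minimal sketch of one way to also persist
# the quantization statistics, assuming a bitsandbytes version where
# QuantState.as_dict(packed=True) is available; the helper name and the
# "<key>.<state_key>" layout below are illustrative, not a fixed convention.
def quantize_model_weights_with_state(state_dict, device=torch.device("cuda")):
    quantized_state_dict = {}
    for key, value in state_dict.items():
        if 'weight' in key:
            quantized_param = ForgeParams4bit(
                value.to(device),
                requires_grad=False,
                compress_statistics=True,
                quant_type='nf4',
                quant_storage=torch.uint8,
                module=None
            )._quantize(device)
            quantized_state_dict[key] = quantized_param.data
            # Store the QuantState tensors next to the weight so it can be rebuilt later.
            for state_key, state_value in quantized_param.quant_state.as_dict(packed=True).items():
                quantized_state_dict[f"{key}.{state_key}"] = state_value
        else:
            quantized_state_dict[key] = value.to(device)
    return quantized_state_dict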

# Function to save the quantized state dict back to .safetensors
def save_quantized_model_weights(quantized_state_dict, save_path):
    save_file(quantized_state_dict, save_path)

# Example usage
safetensors_path = "flux1-dev-schnell-merge.sft"  # Replace with your actual path
quantized_safetensors_path = "quantized_model.safetensors"  # Replace with your actual path

# Step 1: Load model weights from .safetensors
state_dict = load_model_weights(safetensors_path)

# Step 2: Quantize the model weights
quantized_state_dict = quantize_model_weights(state_dict)

# Step 3: Save the quantized model weights to .safetensors
save_quantized_model_weights(quantized_state_dict, quantized_safetensors_path)
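
# A rough sketch (not part of the original script) of reading weights saved by
# quantize_model_weights_with_state() back into 4-bit parameters. It assumes a
# bitsandbytes version that provides Params4bit.from_prequantized (inherited by
# ForgeParams4bit) and QuantState.from_dict; the helper name and key layout are
# illustrative only.
def load_quantized_weight(flat_state_dict, key, device=torch.device("cuda")):
    prefix = key + "."
    # Collect the quant-state tensors that were stored next to this weight.
    quantized_stats = {k[len(prefix):]: v for k, v in flat_state_dict.items() if k.startswith(prefix)}
    return ForgeParams4bit.from_prequantized(
        data=flat_state_dict[key],
        quantized_stats=quantized_stats,
        requires_grad=False,
        device=device,
    )

# Example (hypothetical key name):
#   w = load_quantized_weight(load_file(quantized_safetensors_path), "some.layer.weight")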