-    def load_weight_hook(self, module: nn.Module, incompatible_keys):
-        missing_keys = incompatible_keys.missing_keys
-        # if dora_scale is missing from the checkpoint, initialize it from the
-        # norm of the original weights (DoRA weight decomposition)
-        if "dora_scale" in missing_keys:
-            missing_keys.remove("dora_scale")
-            org_weight = self.org_module[0].weight.detach().clone().float()
-            self.dora_norm_dims = org_weight.dim() - 1
-            if self.wd_on_out:
-                # one norm per output channel, shaped to broadcast over the weight
-                self.dora_scale = nn.Parameter(
-                    torch.norm(
-                        org_weight.reshape(org_weight.shape[0], -1),
-                        dim=1,
-                        keepdim=True,
-                    ).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims)
-                )
-            else:
-                # one norm per input channel, transposed back so it broadcasts
-                self.dora_scale = nn.Parameter(
-                    torch.norm(
-                        org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1),
-                        dim=1,
-                        keepdim=True,
-                    )
-                    .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims)
-                    .transpose(1, 0)
-                )
-            self.wd = True
-        # handle missing scalar keys: drop them and reset scalar to ones
-        for key in list(missing_keys):
-            if "scalar" in key:
-                missing_keys.remove(key)
-        if isinstance(self.scalar, nn.Parameter):
-            self.scalar.data.copy_(torch.ones_like(self.scalar))
-        elif getattr(self, "scalar", None) is not None:
-            self.scalar.copy_(torch.ones_like(self.scalar))
-        else:
-            # fallback: (re)register scalar as a non-persistent buffer of ones
-            self.register_buffer(
-                "scalar", torch.ones_like(self.scalar), persistent=False
-            )
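For context, this hook has the signature PyTorch expects from nn.Module.register_load_state_dict_post_hook, which runs after load_state_dict() and receives the missing/unexpected key lists. Below is a minimal sketch of how such a hook is typically wired up; the ScalarWrapper class is illustrative only (it reuses the scalar / org_module naming from the code above but is not part of it), and it keeps just the scalar-repair step.

import torch
import torch.nn as nn


class ScalarWrapper(nn.Module):
    """Illustrative wrapper that owns a `scalar` parameter and repairs it on load."""

    def __init__(self, org_module: nn.Module):
        super().__init__()
        # plain Python list keeps the wrapped layer out of this module's state dict
        self.org_module = [org_module]
        self.scalar = nn.Parameter(torch.zeros(1))
        # PyTorch calls the hook after load_state_dict() with the
        # incompatible-keys result (missing_keys, unexpected_keys)
        self.register_load_state_dict_post_hook(self.load_weight_hook)

    def load_weight_hook(self, module: nn.Module, incompatible_keys):
        # simplified stand-in for the hook above: drop "scalar" from the
        # missing-key list and reset it to ones instead of failing the load
        missing_keys = incompatible_keys.missing_keys
        for key in list(missing_keys):
            if "scalar" in key:
                missing_keys.remove(key)
        self.scalar.data.copy_(torch.ones_like(self.scalar))


wrapper = ScalarWrapper(nn.Linear(4, 4))
result = wrapper.load_state_dict({}, strict=False)  # "scalar" is absent from the dict
print(result.missing_keys)  # hook already removed "scalar" -> []
print(wrapper.scalar)       # reset to tensor([1.])

Because the hook mutates missing_keys in place, checkpoints saved without dora_scale or scalar load cleanly instead of reporting missing keys.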