import torch
import torch.nn as nn


def load_weight_hook(self, module: nn.Module, incompatible_keys):
    missing_keys = incompatible_keys.missing_keys

    # if dora_scale is missing from the checkpoint, create it from the norm
    # of the original weights
    if "dora_scale" in missing_keys:
        missing_keys.remove("dora_scale")
        org_weight = self.org_module[0].weight.detach().clone().float()
        self.dora_norm_dims = org_weight.dim() - 1
        if self.wd_on_out:
            # one norm per output channel: shape (out, 1, ..., 1)
            self.dora_scale = nn.Parameter(
                torch.norm(
                    org_weight.reshape(org_weight.shape[0], -1),
                    dim=1,
                    keepdim=True,
                ).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims)
            )
        else:
            # one norm per input channel: shape (1, in, 1, ..., 1)
            self.dora_scale = nn.Parameter(
                torch.norm(
                    org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1),
                    dim=1,
                    keepdim=True,
                )
                .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims)
                .transpose(1, 0)
            )
        self.wd = True

    # handle scalar: drop matching keys from missing_keys, then reset the
    # value to ones
    for key in list(missing_keys):
        if "scalar" in key:
            missing_keys.remove(key)
    scalar = getattr(self, "scalar", None)  # Parameter, plain tensor, or unset
    if isinstance(scalar, nn.Parameter):
        scalar.data.copy_(torch.ones_like(scalar))
    elif scalar is not None:
        scalar.copy_(torch.ones_like(scalar))
    else:
        # no scalar exists in this branch: the original ones_like(self.scalar)
        # would raise here, so register a fresh 1-element buffer instead
        self.register_buffer("scalar", torch.ones(1), persistent=False)
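
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original paste), assuming a recent
# PyTorch with register_load_state_dict_post_hook. DoraDemo and its attribute
# choices (org_module, wd_on_out, wd, scalar, the placeholder dora_scale) are
# illustrative guesses at the DoRA-style wrapper class this hook comes from,
# not its real definition.
# ---------------------------------------------------------------------------
class DoraDemo(nn.Module):
    def __init__(self, org_module: nn.Linear):
        super().__init__()
        # kept in a plain list so the wrapped layer is not registered as a
        # child module, matching the self.org_module[0] access in the hook
        self.org_module = [org_module]
        self.wd_on_out = True
        self.wd = True
        # placeholder; the hook rebuilds it from the weight norms whenever a
        # checkpoint is missing this key
        self.dora_scale = nn.Parameter(torch.ones(org_module.out_features, 1))
        self.scalar = nn.Parameter(torch.ones(1))
        # PyTorch calls the hook after every load_state_dict and lets it edit
        # incompatible_keys.missing_keys in place
        self.register_load_state_dict_post_hook(self.load_weight_hook)

    # reuse the module-level function above as a bound method
    load_weight_hook = load_weight_hook


demo = DoraDemo(nn.Linear(8, 4))
demo.load_state_dict({}, strict=False)  # checkpoint lacks dora_scale/scalar
print(demo.dora_scale.shape)  # torch.Size([4, 1]): one norm per output row
print(float(demo.scalar))     # 1.0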