import math

import torch
import torch.nn as nn

from .base import LycorisBaseModule
from ..functional.loha import diff_weight as loha_diff_weight


class LohaModule(LycorisBaseModule):
    name = "loha"
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }
    weight_list = [
        "hada_w1_a",
        "hada_w1_b",
        "hada_w2_a",
        "hada_w2_b",
        "hada_t1",
        "hada_t2",
        "alpha",
        "dora_scale",
    ]
    weight_list_det = ["hada_w1_a"]

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        use_scalar=False,
        rank_dropout_scale=False,
        weight_decompose=False,
        wd_on_out=True,
        bypass_mode=None,
        rs_lora=False,
        use_ggpo=False,  # new argument to enable GGPO
        sigma=0.0,  # new argument: GGPO noise scale (sigma)
        beta=0.0,  # new argument: GGPO constant offset (beta)
        **kwargs,
    ):
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in LoHa algo.")
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.tucker = False
        self.rs_lora = rs_lora

        # Store GGPO parameters
        self.use_ggpo = use_ggpo
        self.sigma = sigma
        self.beta = beta

        w_shape = self.shape
        if self.module_type.startswith("conv"):
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            out_dim = org_module.out_channels
            self.shape = (out_dim, in_dim, *k_size)
            self.tucker = use_tucker and any(i != 1 for i in k_size)
            if self.tucker:
                w_shape = (out_dim, in_dim, *k_size)
            else:
                w_shape = (out_dim, in_dim * torch.tensor(k_size).prod().item())

        if self.tucker:
            self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *w_shape[2:]))
            self.hada_w1_a = nn.Parameter(
                torch.empty(lora_dim, w_shape[0])
            )  # out_dim, 1-mode
            self.hada_w1_b = nn.Parameter(
                torch.empty(lora_dim, w_shape[1])
            )  # in_dim , 2-mode

            self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *w_shape[2:]))
            self.hada_w2_a = nn.Parameter(
                torch.empty(lora_dim, w_shape[0])
            )  # out_dim, 1-mode
            self.hada_w2_b = nn.Parameter(
                torch.empty(lora_dim, w_shape[1])
            )  # in_dim , 2-mode
        else:
            self.hada_w1_a = nn.Parameter(torch.empty(w_shape[0], lora_dim))
            self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, w_shape[1]))

            self.hada_w2_a = nn.Parameter(torch.empty(w_shape[0], lora_dim))
            self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, w_shape[1]))

        self.wd = weight_decompose
        self.wd_on_out = wd_on_out
        if self.wd:
            org_weight = org_module.weight.cpu().clone().float()
            self.dora_norm_dims = org_weight.dim() - 1
            if self.wd_on_out:
                self.dora_scale = nn.Parameter(
                    torch.norm(
                        org_weight.reshape(org_weight.shape[0], -1),
                        dim=1,
                        keepdim=True,
                    ).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims)
                ).float()
            else:
                self.dora_scale = nn.Parameter(
                    torch.norm(
                        org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1),
                        dim=1,
                        keepdim=True,
                    )
                    .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims)
                    .transpose(1, 0)
                ).float()

        if self.dropout:
            print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.")

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha

        r_factor = lora_dim
        if self.rs_lora:
            r_factor = math.sqrt(r_factor)

        self.scale = alpha / r_factor

        self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor)))

        if use_scalar:
            self.scalar = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("scalar", torch.tensor(1.0), persistent=False)
        # Need more experiments on init method
        if self.tucker:
            torch.nn.init.normal_(self.hada_t1, std=0.1)
            torch.nn.init.normal_(self.hada_t2, std=0.1)
        torch.nn.init.normal_(self.hada_w1_b, std=1)
        torch.nn.init.normal_(self.hada_w1_a, std=0.1)
        torch.nn.init.normal_(self.hada_w2_b, std=1)
        if use_scalar:
            torch.nn.init.normal_(self.hada_w2_a, std=0.1)
        else:
            torch.nn.init.constant_(self.hada_w2_a, 0)
        # Register gradient hooks when GGPO mode is enabled
        if self.use_ggpo:
            self.apply_ggpo_grad_hooks()

    def apply_ggpo_grad_hooks(self):
        # Register a hook on every parameter that perturbs its gradient with
        # Gaussian noise scaled by sigma plus a constant beta offset.
        for name, param in self.named_parameters():
            param.register_hook(
                lambda grad, sigma=self.sigma, beta=self.beta: grad
                + torch.randn_like(grad) * sigma
                + beta
            )
  171.  
  172. @classmethod
  173. def make_module_from_state_dict(
  174. cls, lora_name, orig_module, w1a, w1b, w2a, w2b, t1, t2, alpha, dora_scale
  175. ):
  176. module = cls(
  177. lora_name,
  178. orig_module,
  179. 1,
  180. w1b.size(0),
  181. float(alpha),
  182. use_tucker=t1 is not None,
  183. weight_decompose=dora_scale is not None,
  184. )
  185. module.hada_w1_a.copy_(w1a)
  186. module.hada_w1_b.copy_(w1b)
  187. module.hada_w2_a.copy_(w2a)
  188. module.hada_w2_b.copy_(w2b)
  189. if t1 is not None:
  190. module.hada_t1.copy_(t1)
  191. module.hada_t2.copy_(t2)
  192. if dora_scale is not None:
  193. module.dora_scale.copy_(dora_scale)
  194. return module
  195.  
  196. def load_weight_hook(self, module: nn.Module, incompatible_keys):
  197. missing_keys = incompatible_keys.missing_keys
  198. for key in missing_keys:
  199. if "scalar" in key:
  200. del missing_keys[missing_keys.index(key)]
  201. if isinstance(self.scalar, nn.Parameter):
  202. self.scalar.data.copy_(torch.ones_like(self.scalar))
  203. elif getattr(self, "scalar", None) is not None:
  204. self.scalar.copy_(torch.ones_like(self.scalar))
  205. else:
  206. self.register_buffer(
  207. "scalar", torch.ones_like(self.scalar), persistent=False
  208. )
  209.  
  210. def get_weight(self, shape):
  211. scale = torch.tensor(
  212. self.scale, dtype=self.hada_w1_b.dtype, device=self.hada_w1_b.device
  213. )
  214. if self.tucker:
  215. weight = loha_diff_weight(
  216. self.hada_w1_b,
  217. self.hada_w1_a,
  218. self.hada_w2_b,
  219. self.hada_w2_a,
  220. self.hada_t1,
  221. self.hada_t2,
  222. gamma=scale,
  223. )
  224. else:
  225. weight = loha_diff_weight(
  226. self.hada_w1_b,
  227. self.hada_w1_a,
  228. self.hada_w2_b,
  229. self.hada_w2_a,
  230. None,
  231. None,
  232. gamma=scale,
  233. )
  234. if shape is not None:
  235. weight = weight.reshape(shape)
  236. if self.training and self.rank_dropout:
  237. drop = (torch.rand(weight.size(0)) > self.rank_dropout).to(weight.dtype)
  238. drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
  239. if self.rank_dropout_scale:
  240. drop /= drop.mean()
  241. weight *= drop
        # Apply the GGPO perturbation when enabled: Gaussian noise scaled by
        # sigma plus a constant beta offset on the produced delta-weight.
        if self.use_ggpo:
            weight = weight + torch.randn_like(weight) * self.sigma + self.beta
        return weight

    def get_diff_weight(self, multiplier=1, shape=None, device=None):
        scale = self.scale * multiplier
        diff = self.get_weight(shape) * scale
        if device is not None:
            diff = diff.to(device)
        return diff, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
        diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0]
        weight = self.org_weight
        if self.wd:
            merged = self.apply_weight_decompose(weight + diff, multiplier)
        else:
            merged = weight + diff * multiplier
        return merged, None

    def apply_weight_decompose(self, weight, multiplier=1):
        weight = weight.to(self.dora_scale.dtype)
        if self.wd_on_out:
            weight_norm = (
                weight.reshape(weight.shape[0], -1)
                .norm(dim=1)
                .reshape(weight.shape[0], *[1] * self.dora_norm_dims)
            ) + torch.finfo(weight.dtype).eps
        else:
            weight_norm = (
                weight.transpose(0, 1)
                .reshape(weight.shape[1], -1)
                .norm(dim=1, keepdim=True)
                .reshape(weight.shape[1], *[1] * self.dora_norm_dims)
                .transpose(0, 1)
            ) + torch.finfo(weight.dtype).eps

        scale = self.dora_scale.to(weight.device) / weight_norm
        if multiplier != 1:
            scale = multiplier * (scale - 1) + 1

        return weight * scale

    def custom_state_dict(self):
        destination = {}
        destination["alpha"] = self.alpha
        if self.wd:
            destination["dora_scale"] = self.dora_scale
        destination["hada_w1_a"] = self.hada_w1_a * self.scalar
        destination["hada_w1_b"] = self.hada_w1_b
        destination["hada_w2_a"] = self.hada_w2_a
        destination["hada_w2_b"] = self.hada_w2_b
        if self.tucker:
            destination["hada_t1"] = self.hada_t1
            destination["hada_t2"] = self.hada_t2
        return destination

    @torch.no_grad()
    def apply_max_norm(self, max_norm, device=None):
        orig_norm = (self.get_weight(self.shape) * self.scalar).norm()
        norm = torch.clamp(orig_norm, max_norm / 2)
        desired = torch.clamp(norm, max=max_norm)
        ratio = desired.cpu() / norm.cpu()

        scaled = norm != desired
        if scaled:
            self.scalar *= ratio

        return scaled, orig_norm * ratio

    def bypass_forward_diff(self, x, scale=1):
        diff_weight = self.get_weight(self.shape) * self.scalar * scale
        return self.drop(self.op(x, diff_weight, **self.kw_dict))

    def bypass_forward(self, x, scale=1):
        return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.op(
                    x,
                    self.org_module[0].weight.data,
                    (
                        None
                        if self.org_module[0].bias is None
                        else self.org_module[0].bias.data
                    ),
                )
        if self.bypass_mode:
            return self.bypass_forward(x, scale=self.multiplier)
        else:
            diff_weight = self.get_weight(self.shape).to(self.dtype) * self.scalar
            weight = self.org_module[0].weight.data.to(self.dtype)
            if self.wd:
                weight = self.apply_weight_decompose(
                    weight + diff_weight, self.multiplier
                )
            else:
                weight = weight + diff_weight * self.multiplier
            bias = (
                None
                if self.org_module[0].bias is None
                else self.org_module[0].bias.data
            )
            return self.op(x, weight, bias, **self.kw_dict)
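
# --- Usage sketch (not part of the module) ---
# A minimal, hypothetical example of constructing LohaModule directly with the
# GGPO arguments added above. Module name, dimensions, and the sigma/beta values
# are placeholders, not recommendations, and this assumes the upstream
# LycorisBaseModule sets up self.op / self.kw_dict / self.dtype for nn.Linear as
# usual; in practice LyCORIS modules are normally created through the library's
# network wrappers rather than instantiated by hand.
if __name__ == "__main__":
    base = nn.Linear(64, 64)
    loha = LohaModule(
        "example_loha",  # hypothetical module name
        base,
        multiplier=1.0,
        lora_dim=8,
        alpha=4,
        use_ggpo=True,  # registers the gradient-noise hooks defined above
        sigma=1e-3,     # placeholder noise scale
        beta=0.0,       # placeholder constant offset
    )
    # With use_ggpo=True, get_weight() perturbs the produced delta-weight and
    # every backward pass adds noise to the hada parameter gradients.
    out = loha(torch.randn(2, 64))
    out.sum().backward()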