vlpap

Gradient descent

Nov 9th, 2020
Python | 3.21 KB
import torch
import numpy as np
import sympy as sp
from scipy.optimize import linprog

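# This script appears to implement a feasible-direction (Zoutendijk-style)
# gradient descent for minimizing f(x) subject to A x <= B and x >= 0:
# steepest descent while strictly inside the feasible region, an LP-based
# direction when a constraint is active, and a step length capped by the
# constraints and by a golden-section line search along the chosen direction.
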
def gradient(f, X):
    # Evaluate f at X and return its gradient via autograd.
    X = X.clone().detach().requires_grad_(True)
    y = f(*X)
    y.backward()
    return X.grad.data

def to_series(X, e):
    # Expand each component x_k of X into the signed pattern e * x_k and
    # concatenate, e.g. for e = [1, -1]: [x_0, -x_0, x_1, -x_1].
    S = [e * i.item() for i in X]
    return torch.cat(S, 0)

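# Direction-finding LP on the active constraints (as encoded by to_series):
# each direction component d_i is split into a non-negative pair (d_i+, d_i-),
# giving the alternating-sign vector [d_0+, d_0-, d_1+, d_1-], and linprog solves
#     min  grad_f . d   s.t.  A_active d <= 0,  sum(d_i+ + d_i-) <= 1,  d+, d- >= 0.
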
def lpt(nabla_f, A):
    l = nabla_f.shape[0]
    e = torch.ones(l)
    e[1::2] = -1
    c = to_series(nabla_f, e)
    a = [np.ones(c.shape[0])]          # normalization row: sum of all pair variables <= 1
    for row in A:
        tmp = to_series(row, e)
        a.append(tmp.detach().cpu().numpy())

    b = np.zeros(A.shape[0] + 1)
    b[0] = 1

    bounds = [(0, None)] * c.shape[0]
    res = linprog(c, A_ub=a, b_ub=b, bounds=bounds, method='revised simplex')
    if res.x is None:
        return torch.zeros(l)
    # Recover d_i = d_i+ - d_i- from the paired LP variables.
    d = torch.tensor(res.x[0::2] - res.x[1::2])
    if torch.norm(d).item() > 0:
        return d
    else:
        return torch.zeros(l)

def isin(A, B):
    # Boolean mask over A: True where an entry of A exactly equals some entry of B.
    return torch.stack([b == A for b in B]).sum(0).bool()

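# get_h builds the one-dimensional restriction h(r) = f(X0 + r*Xi) symbolically
# with sympy and returns it as a plain numeric function of r for the line search.
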
def get_h(f, X0, Xi):
    # Substitute x = X0 + r*Xi into f; .item() keeps plain floats in the symbolic expression.
    r = sp.symbols('r')
    Xir = [x0.item() + xi.item() * r for x0, xi in zip(X0, Xi)]
    h = sp.simplify(f(*Xir))
    if str(h)[0] == '-':
        # Flip the sign if simplification printed the expression with a leading minus.
        return sp.lambdify(r, -h)
    else:
        return sp.lambdify(r, h)

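# gss: golden-section search for the minimizer of a function assumed unimodal
# on [a, b]; the bracket shrinks by the golden ratio each iteration until its
# length drops below eps.
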
def gss(f, a, b, eps=1e-8):
    gr = (np.sqrt(5) - 1) / 2
    lmbda = a + (1 - gr) * (b - a)
    mu = a + gr * (b - a)
    while abs(a - b) > eps:
        if f(lmbda) < f(mu):
            b = mu
        else:
            a = lmbda
        lmbda = a + (1 - gr) * (b - a)
        mu = a + gr * (b - a)
    return (b + a) / 2

def calc_r(f, X0, Xi):
    # One-dimensional minimization of f along the ray X0 + r*Xi.
    h = get_h(f, X0, Xi)
    r = gss(h, -1000, 1000)
    return r

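# Main loop: repeat until the chosen direction E is shorter than `precision`:
#   1) take E = -grad f strictly inside the feasible region, or the LP direction
#      when some constraint A_i x = B_i is active;
#   2) compute alpha_i, the largest step keeping A x <= B, and alpha_j, the
#      largest step keeping x >= 0;
#   3) step by the smallest of alpha_i, alpha_j and the golden-section
#      minimizer alpha_* of f along E.
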
def gradient_descent(f, X0, A, B, precision=10e-4):
    i = 0
    while True:
        ax = torch.sum(A * X0, 1)
        if ax[ax < B].shape[0] == ax.shape[0]:
            # Strictly inside the feasible region: steepest descent direction.
            E = -1 * gradient(f, X0.data)
        elif ax[ax == B].shape[0] > 0:
            # Some constraint is active (detected by exact equality): take the
            # LP-based feasible direction built from the active rows of A.
            mask = isin(ax, B)
            ax_masked = ax[mask]
            A_id = [torch.where(B == v)[0] for v in ax_masked]
            A_id = torch.cat(A_id, 0)
            nabla_f = gradient(f, X0.data)
            E = lpt(nabla_f, A[A_id])
        if torch.norm(E).item() <= precision:
            break

        # alpha_i: largest step along E that keeps A x <= B.
        alpha_is = (B - torch.sum(A * X0, 1)) / torch.sum(A * E, 1)
        ae = torch.sum(A * E, 1)
        ae[ae <= 0] = float('inf')
        alpha_is[torch.isinf(ae)] = float('inf')
        alpha_i = alpha_is.min().item()

        # alpha_j: largest step along E that keeps x >= 0.
        Ejk = E.detach().clone()
        alpha_js = -1 * X0 / Ejk
        Ejk[Ejk >= 0] = float('inf')
        alpha_js[torch.isinf(Ejk)] = float('inf')
        alpha_j = alpha_js.min().item()

        # alpha_*: unconstrained minimizer of f along E (golden-section search).
        alpha_astrsk = calc_r(f, X0, E)
        r = np.min([alpha_i, alpha_j, alpha_astrsk]).item()
        X0 = X0 + r * E
        # print(f' iter {i}: \n E = {E}, X0 = {X0}, alpha_i = {alpha_i}, alpha_j = {alpha_j}, alpha_* = {alpha_astrsk}, r = {r}, Norm = {torch.norm(E).item()}')
        i += 1
    return X0.data, i


# Example problem: minimize f(x1, x2) = (x1-4)^2 + (x2-2)^2
# subject to x1 + x2 <= 3, x1 + 2*x2 <= 4 and x1, x2 >= 0,
# starting from the interior point (1/2, 17/12).
X0 = torch.tensor([1/2, 17/12], dtype=torch.float64)  # float64 to match A and B

A = torch.tensor([[1, 1],
                  [1, 2]], dtype=torch.float64)

B = torch.tensor([3, 4], dtype=torch.float64)

f = lambda x1, x2: (x1-4)**2 + (x2-2)**2

X, i = gradient_descent(f, X0, A, B, 10e-4)

print(f' Minimum = {X} \n f(Minimum) = {f(*X)} \n number of iterations = {i}')
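
# A minimal sanity check, assuming the returned point should remain feasible:
# both A x <= B and x >= 0 ought to hold at X.
ok = (torch.sum(A * X, 1) <= B).all() and (X >= 0).all()
print(f' constraints satisfied: {bool(ok)}')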