Not a member of Pastebin yet? Sign up — it unlocks many cool features!
- import sympy as sp
- from sympy import Symbol
- import numpy as np
- import matplotlib.pyplot as plt
- %matplotlib inline
- plt.rcParams['figure.figsize'] = 7, 7
def lag_sym(data_points):
    """Return the symbolic Lagrange basis polynomials for the given nodes.

    The i-th returned expression is

        L_i(x) = prod_{j != i} (x - x_j) / (x_i - x_j),

    a SymPy expression in ``Symbol('x')`` that equals 1 at node i and 0 at
    every other node.

    Parameters
    ----------
    data_points : iterable of real numbers
        Interpolation nodes x_0, ..., x_{n-1}; a NumPy array or any sequence.

    Returns
    -------
    list
        ``len(data_points)`` SymPy expressions, in node order.

    Raises
    ------
    ValueError
        If two nodes coincide (the basis would divide by zero).
    """
    x = Symbol('x')
    nodes = list(data_points)
    # Exact duplicates make a denominator x_i - x_j exactly zero; fail loudly
    # instead of returning expressions containing complex infinity.
    if len(set(nodes)) != len(nodes):
        raise ValueError("interpolation nodes must be distinct")
    L = []
    for i, xi in enumerate(nodes):
        # Start from a SymPy one rather than the float 1.0 so that even a
        # single-node basis is a SymPy object supporting .evalf() / diff.
        basis = sp.Integer(1)
        for j, xj in enumerate(nodes):
            if j != i:
                basis *= (x - xj) / (xi - xj)
        L.append(basis)
    return L
def fun_eval(F, mesh, symbol=None):
    """Evaluate ``sum(F)`` pointwise on ``mesh``.

    Parameters
    ----------
    F : sequence of SymPy expressions
        Basis functions (e.g. from ``lag_sym`` or ``herm_sym``), each a
        function of a single symbol.
    mesh : sequence of real numbers
        Points at which to evaluate.
    symbol : sympy.Symbol or None, optional
        Symbol whose value is substituted from ``mesh``. Defaults to
        ``Symbol('x')``, the symbol ``lag_sym`` and ``herm_sym`` build with.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(mesh),)`` whose k-th entry is
        ``sum(f(mesh[k]) for f in F)``; all zeros when ``F`` is empty.
    """
    if symbol is None:
        symbol = Symbol('x')
    # Rows index mesh points, columns index basis functions. The reshape keeps
    # the 2-D shape even when F or mesh is empty.
    values = np.array(
        [[float(f.evalf(subs={symbol: point})) for f in F] for point in mesh],
        dtype=float,
    ).reshape(len(mesh), len(F))
    # Sum across the basis functions (axis 1) so there is one value per mesh
    # point, matching the x-axis used when plotting.
    return values.sum(axis=1)
def herm_sym(L, data_points):
    """Return the symbolic Hermite value-basis polynomials on the given nodes.

    For each node x_i this builds

        H_i(x) = (1 - 2 (x - x_i) L_i'(x_i)) * L_i(x)**2.

    When ``L`` is the Lagrange basis on the same nodes, these satisfy
    H_i(x_j) = delta_ij and H_i'(x_j) = 0 for every node x_j.

    Parameters
    ----------
    L : sequence of SymPy expressions
        Lagrange basis on ``data_points`` (as returned by ``lag_sym``), in
        node order, each a function of ``Symbol('x')``.
    data_points : iterable of real numbers
        The interpolation nodes.

    Returns
    -------
    list
        One SymPy expression per node.
    """
    x = Symbol('x')
    H = []
    for i, xi in enumerate(data_points):
        Li = L[i]
        # The slope factor is the derivative evaluated *at the node* (a
        # constant). Multiplying by L_i'(x) instead would break
        # H_i'(x_j) = 0. sp.diff also accepts a plain number (derivative 0).
        slope = sp.diff(Li, x).subs(x, xi)
        H.append((1 - 2 * (x - xi) * slope) * Li**2)
    return H
# For each n, plot the sum of the Lagrange basis on 2n equispaced nodes in
# [-1, 1]. Mathematically that sum is identically 1 (interpolating the
# constant function), so the plot is a sanity check of lag_sym/fun_eval.
mesh_resolution = [5]
show_hermite = False  # set True to also plot the Hermite basis sum on n nodes
plot_mesh = np.linspace(-1, 1, 100)  # evaluation grid; same for every n
for n in mesh_resolution:
    plt.figure()
    plt.title('data points = {}'.format(2 * n))
    data_l = np.linspace(-1, 1, 2 * n)
    lag_fun = lag_sym(data_l)
    lag_eval = fun_eval(lag_fun, plot_mesh)
    plt.plot(plot_mesh, lag_eval, label='Lagrange, {} nodes'.format(2 * n))
    if show_hermite:
        # Only build the n-node Lagrange/Hermite bases when they are plotted.
        data_h = np.linspace(-1, 1, n)
        herm_fun = herm_sym(lag_sym(data_h), data_h)
        herm_eval = fun_eval(herm_fun, plot_mesh)
        plt.plot(plot_mesh, herm_eval, label='Hermite, {} nodes'.format(n))
    plt.legend()
Add Comment
Please sign in to add a comment.