Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- >>> import numpy as np
- >>> import automatic_differentiation as ad
- >>> x_val = np.random.randn(2, 3)
- >>> w_val = np.random.randn(3, 5)
- >>> x_val
- array([[ 0.66588232, -1.24427652, 0.33879172],
- [-0.26769112, 0.52176526, 0.32342972]])
- >>> w_val
- array([[ 0.45380673, 0.21560083, -0.58899006, 0.59063819, 0.11651024],
- [-1.00214694, 2.86367365, 1.4583315 , 0.08809848, 0.45746665],
- [ 0.20844987, 0.91073141, 0.07208344, -0.93793171, -0.72982744]])
- >>> x = ad.Variable(x_val, name="x")
- >>> w = ad.Variable(w_val, name="w")
- >>> x
- <automatic_differentiation.src.core.computational_graph.Variable object at 0x7fa30695f908>
- >>> w
- <automatic_differentiation.src.core.computational_graph.Variable object at 0x7fa30728a898>
- >>> x()
- array([[ 0.66588232, -1.24427652, 0.33879172],
- [-0.26769112, 0.52176526, 0.32342972]])
- >>> w()
- array([[ 0.45380673, 0.21560083, -0.58899006, 0.59063819, 0.11651024],
- [-1.00214694, 2.86367365, 1.4583315 , 0.08809848, 0.45746665],
- [ 0.20844987, 0.91073141, 0.07208344, -0.93793171, -0.72982744]])
- >>> y = x @ w
- >>> y()
- array([[ 1.61975087, -3.11108885, -2.18234443, -0.03408684, -0.7388924 ],
- [-0.5769466 , 1.73100859, 0.94188804, -0.41549686, -0.02854644]])
- >>> x_val @ w_val
- array([[ 1.61975087, -3.11108885, -2.18234443, -0.03408684, -0.7388924 ],
- [-0.5769466 , 1.73100859, 0.94188804, -0.41549686, -0.02854644]])
- >>> w_grad = ad.grad(y, [w])[0]
- >>> w_grad
- <automatic_differentiation.src.core.ops.Add object at 0x7fa30689f0f0>
- >>> w_grad()
- array([[ 0.3981912 , 0.3981912 , 0.3981912 , 0.3981912 , 0.3981912 ],
- [-0.72251126, -0.72251126, -0.72251126, -0.72251126, -0.72251126],
- [ 0.66222144, 0.66222144, 0.66222144, 0.66222144, 0.66222144]])
- >>> x_val.T @ np.ones_like(x_val @ w_val)
- array([[ 0.3981912 , 0.3981912 , 0.3981912 , 0.3981912 , 0.3981912 ],
- [-0.72251126, -0.72251126, -0.72251126, -0.72251126, -0.72251126],
- [ 0.66222144, 0.66222144, 0.66222144, 0.66222144, 0.66222144]])
- >>>
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement