gt22

Untitled

Jun 12th, 2019
465
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.40 KB | None | 0 0
  1. from big_ol_pile_of_manim_imports import *
  2.  
  3.  
  4. def p(x=0, y=0):
  5.     if type(x) == list or type(x) == tuple or type(x) == np.ndarray:
  6.         return np.array(x)
  7.     return np.array((x, y, 0))
  8.  
  9.  
  10. def near(mobj, edge=ORIGIN):
  11.     return mobj.get_critical_point(edge)
  12.  
  13.  
  14. def comp(x=(0, 0, 0), y=(0, 0, 0)):
  15.     return p(x[0], y[1])
  16.  
  17.  
  18. class MMScene(Scene):
  19.  
  20.     m1 = np.array([[1, 2], [2, 9]])
  21.     m2 = np.array([[2, 4, 3], [6, 8, 3]])
  22.     res_a, res_b, num_steps = None, None, None
  23.     res_data = None
  24.     x, y, res = None, None, None
  25.  
  26.     def construct(self):
  27.         if self.m1.shape[1] != self.m2.shape[0]:
  28.             raise ValueError("Invalid matrices")
  29.  
  30.         self.res_a = self.m1.shape[0]
  31.         self.res_b = self.m2.shape[1]
  32.         self.num_steps = self.m1.shape[1]
  33.  
  34.         self.create_xy_matrices()
  35.         self.wait(0.5)
  36.         self.shift_y_matrix()
  37.         self.create_res_matrix()
  38.         self.wait(0.5)
  39.         self.perform_multiplication()
  40.         self.wait(1)
  41.         self.clean_up()
  42.         self.wait(3)
  43.  
  44.     def create_xy_matrices(self):
  45.         self.x = IntegerMatrix(self.m1)
  46.         self.x.move_to(p(-0.1), RIGHT)
  47.         self.y = IntegerMatrix(self.m2)
  48.         self.y.move_to(p(0.1), LEFT)
  49.         self.play(FadeIn(self.x), FadeIn(self.y))
  50.  
  51.     def shift_y_matrix(self):
  52.         yc = self.y.copy()
  53.         yc.shift(p(y=self.x.get_height() + 0.1))
  54.         self.play(Transform(self.y, yc, lag_ration=0))
  55.         self.remove(self.y)
  56.         self.add(yc)
  57.         self.y = yc
  58.  
  59.     def create_res_matrix(self):
  60.         self.res_data = np.zeros((self.res_a, self.res_b), dtype=int)
  61.         self.res = IntegerMatrix(self.res_data)
  62.         self.res.move_to(comp(near(self.y), near(self.x)))
  63.         self.play(FadeIn(self.res))
  64.  
  65.     def perform_multiplication(self):
  66.         for i in range(self.res_a):
  67.             for j in range(self.res_b):
  68.                 self.calc_res_elem(i, j)
  69.  
  70.     def clean_up(self):
  71.         rc = self.res.copy()
  72.         rc.move_to(ORIGIN)
  73.         self.play(FadeOut(self.x), FadeOut(self.y), Transform(self.res, rc))
  74.         self.remove(self.x, self.y, self.res)
  75.         self.add(rc)
  76.         self.res = rc
  77.  
  78.     def circle_matrix(self, m: Matrix, elem: int, color='red'):
  79.         e = m.get_entries()[elem]
  80.         c = Circle(radius=max(e.get_width(), e.get_height()) / 2 + 0.2, color=color)
  81.         c.move_to(e)
  82.         return c
  83.  
  84.     def calc_res_elem(self, x, y):
  85.         elem = self.res_data.shape[1] * x + y
  86.         res_tex = TexMobject(str(self.res_data[x, y]))
  87.         tex_pos = comp(y=near(self.x, DOWN) - 0.1)
  88.         res_tex.set_color(GREEN)
  89.         res_tex.move_to(tex_pos, UP)
  90.         c = self.circle_matrix(self.res, elem, 'green')
  91.         self.play(ShowCreation(c, run_time=0.5))
  92.         self.play(Write(res_tex, run_time=0.5))
  93.         xcp, ycp = None, None
  94.         for step in range(self.num_steps):
  95.             x_elem = self.m1.shape[1] * x + step
  96.             y_elem = self.m2.shape[1] * step + y
  97.  
  98.             a, b = self.m1[x, step], self.m2[step, y]
  99.             dr = a * b
  100.             self.res_data[x, y] += dr
  101.  
  102.             xcp, ycp = self.calc_step(x_elem, y_elem, xcp, ycp)
  103.  
  104.             res_tex = self.display_multiplication_tex(res_tex, a, b, dr, self.res_data[x, y])
  105.  
  106.             self.res, c = self.update_res_matrix(c, elem)
  107.  
  108.         self.play(*[FadeOut(g) for g in [c, xcp, ycp, res_tex]])
  109.         self.remove(c, xcp, ycp, res_tex)
  110.  
  111.     def update_res_matrix(self, c, elem):
  112.         rm = IntegerMatrix(self.res_data)
  113.         rm.move_to(self.res)
  114.         cc = self.circle_matrix(rm, elem, 'green')
  115.         self.play(Transform(self.res, rm), Transform(c, cc))
  116.         self.remove(self.res, c)
  117.         self.add(rm, cc)
  118.         return rm, cc
  119.  
  120.     def display_multiplication_tex(self, res_tex, a, b, dr, res):
  121.         cur_res = res_tex.get_tex_string()
  122.         rtm = TexMobject(f"_{{}}{cur_res}+{a}_{{}}\\times{b}^{{}}", tex_to_color_map={
  123.             f'{a}_{{}}': RED,
  124.             f'{b}^{{}}': YELLOW,
  125.             f'_{{}}{cur_res}': GREEN
  126.         })
  127.         rtm.move_to(res_tex, UP)
  128.         rta = TexMobject(f"{cur_res}^{{}}+{dr}_{{}}", tex_to_color_map={
  129.             f"{cur_res}^{{}}": GREEN,
  130.             f"{dr}_{{}}": ORANGE
  131.         })
  132.         rta.move_to(res_tex, UP)
  133.         rt = TexMobject(str(res))
  134.         rt.set_color(GREEN)
  135.         rt.move_to(res_tex, UP)
  136.         self.chain_transorm([res_tex, rtm, rta, rt], 0.5, run_time=0.7)
  137.         return rt
  138.  
  139.     def chain_transorm(self, mobjs, wait_time, **kwargs):
  140.         for a, b in zip(mobjs, mobjs[1:]):
  141.             self.play(Transform(a, b, **kwargs))
  142.             self.remove(a)
  143.             self.add(b)
  144.             self.wait(wait_time)
  145.  
  146.     def calc_step(self, x_elem, y_elem, xcp, ycp):
  147.  
  148.         if xcp is None or ycp is None:
  149.             xc = self.circle_matrix(self.x, x_elem)
  150.             yc = self.circle_matrix(self.y, y_elem, 'yellow')
  151.             self.play(*[ShowCreation(a, run_time=0.5) for a in (xc, yc)])
  152.         else:
  153.             xc = self.update_circle(xcp, self.x, x_elem)
  154.             yc = self.update_circle(ycp, self.y, y_elem)
  155.  
  156.             self.play(*[Transform(a, b, run_time=0.7, lag_ratio=0) for a, b in ((xcp, xc), (ycp, yc))])
  157.             self.remove(xcp, ycp)
  158.             self.add(xc, yc)
  159.         return xc, yc
  160.  
  161.     @staticmethod
  162.     def update_circle(c, m, e):
  163.         cc = c.copy()
  164.         cc.move_to(m.get_entries()[e])
  165.         return cc
Advertisement
Add Comment
Please, Sign In to add comment