Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from big_ol_pile_of_manim_imports import *
def p(x=0, y=0):
    """Build a point as a NumPy array.

    Args:
        x: Either a scalar x-coordinate, or an existing point given as a
            list, tuple or ``np.ndarray`` (subclasses included).
        y: Scalar y-coordinate. Ignored when ``x`` is already a sequence.

    Returns:
        ``np.array(x)`` when ``x`` is a sequence, otherwise the array
        ``(x, y, 0)``.
    """
    # isinstance (rather than an exact type comparison) also accepts
    # subclasses such as namedtuples or ndarray subclasses.
    if isinstance(x, (list, tuple, np.ndarray)):
        return np.array(x)
    return np.array((x, y, 0))
def near(mobj, edge=ORIGIN):
    """Return the critical point of ``mobj`` selected by ``edge``.

    Thin convenience wrapper around ``mobj.get_critical_point``; with the
    default ``edge`` of ``ORIGIN`` the call is made with ``ORIGIN``.
    """
    critical_point = mobj.get_critical_point(edge)
    return critical_point
def comp(x=(0, 0, 0), y=(0, 0, 0)):
    """Compose a point from the first entry of ``x`` and second of ``y``.

    Both arguments are indexable points; the result is ``p(x[0], y[1])``,
    i.e. the horizontal coordinate of ``x`` paired with the vertical
    coordinate of ``y``.
    """
    horizontal, vertical = x[0], y[1]
    return p(horizontal, vertical)
class MMScene(Scene):
    """Animate the step-by-step multiplication of two integer matrices.

    The left operand ``m1`` and right operand ``m2`` are shown, ``m2`` is
    shifted upwards, a zero-filled result matrix is placed at the
    intersection of ``m1``'s row band and ``m2``'s column band, and every
    result entry is then accumulated one multiply-add at a time while the
    participating entries are circled.
    """

    # Operands; override in a subclass to animate other matrices.
    # ``m1`` must have as many columns as ``m2`` has rows.
    m1 = np.array([[1, 2], [2, 9]])
    m2 = np.array([[2, 4, 3], [6, 8, 3]])
    # Result shape (rows, columns) and number of multiply-add steps per
    # result entry; all three are filled in by ``construct``.
    res_a, res_b, num_steps = None, None, None
    # Running numeric value of the product (set by ``create_res_matrix``).
    res_data = None
    # Matrix mobjects for ``m1``, ``m2`` and the result.
    x, y, res = None, None, None

    def construct(self):
        """Run the whole animation.

        Raises:
            ValueError: if the inner dimensions of ``m1`` and ``m2`` differ.
        """
        if self.m1.shape[1] != self.m2.shape[0]:
            raise ValueError("Invalid matrices")
        self.res_a = self.m1.shape[0]
        self.res_b = self.m2.shape[1]
        self.num_steps = self.m1.shape[1]
        self.create_xy_matrices()
        self.wait(0.5)
        self.shift_y_matrix()
        self.create_res_matrix()
        self.wait(0.5)
        self.perform_multiplication()
        self.wait(1)
        self.clean_up()
        self.wait(3)

    def create_xy_matrices(self):
        """Create both operand matrices side by side and fade them in."""
        self.x = IntegerMatrix(self.m1)
        # Small gap around the origin: x is aligned by its RIGHT edge to
        # -0.1 and y by its LEFT edge to +0.1.
        self.x.move_to(p(-0.1), RIGHT)
        self.y = IntegerMatrix(self.m2)
        self.y.move_to(p(0.1), LEFT)
        self.play(FadeIn(self.x), FadeIn(self.y))

    def shift_y_matrix(self):
        """Move the right operand up by the height of the left one (+0.1)."""
        yc = self.y.copy()
        yc.shift(p(y=self.x.get_height() + 0.1))
        # Fixed: the keyword was misspelled ``lag_ration``, so the intended
        # lag ratio was never applied to the Transform.
        self.play(Transform(self.y, yc, lag_ratio=0))
        # Swap the transformed original for the target copy so later code
        # works with a mobject that really has the final geometry.
        self.remove(self.y)
        self.add(yc)
        self.y = yc

    def create_res_matrix(self):
        """Create the zero-filled result matrix and fade it in."""
        self.res_data = np.zeros((self.res_a, self.res_b), dtype=int)
        self.res = IntegerMatrix(self.res_data)
        # Horizontal coordinate from y, vertical coordinate from x.
        self.res.move_to(comp(near(self.y), near(self.x)))
        self.play(FadeIn(self.res))

    def perform_multiplication(self):
        """Animate every result entry in row-major order."""
        for i in range(self.res_a):
            for j in range(self.res_b):
                self.calc_res_elem(i, j)

    def clean_up(self):
        """Fade out the operands and move the result to the origin."""
        rc = self.res.copy()
        rc.move_to(ORIGIN)
        self.play(FadeOut(self.x), FadeOut(self.y), Transform(self.res, rc))
        self.remove(self.x, self.y, self.res)
        self.add(rc)
        self.res = rc

    def circle_matrix(self, m: Matrix, elem: int, color='red'):
        """Return a circle around entry number ``elem`` of matrix ``m``.

        Args:
            m: Matrix mobject whose entry is highlighted.
            elem: Index into ``m.get_entries()`` (callers compute it as
                ``row * n_columns + column``).
            color: Colour passed to ``Circle``.

        Returns:
            The circle, positioned on the entry but not yet added to the
            scene.
        """
        e = m.get_entries()[elem]
        # Radius covers the larger dimension of the entry plus a margin.
        c = Circle(radius=max(e.get_width(), e.get_height()) / 2 + 0.2, color=color)
        c.move_to(e)
        return c

    def calc_res_elem(self, x, y):
        """Animate the computation of the result entry at row ``x``, column ``y``.

        The entry is circled, a running total is written below the left
        operand, and for each step the matching entries of both operands
        are circled, multiplied and added into ``res_data[x, y]``.
        """
        elem = self.res_data.shape[1] * x + y
        res_tex = TexMobject(str(self.res_data[x, y]))
        tex_pos = comp(y=near(self.x, DOWN) - 0.1)
        res_tex.set_color(GREEN)
        res_tex.move_to(tex_pos, UP)
        c = self.circle_matrix(self.res, elem, 'green')
        self.play(ShowCreation(c, run_time=0.5))
        self.play(Write(res_tex, run_time=0.5))
        xcp, ycp = None, None
        for step in range(self.num_steps):
            # Flat entry indices of the operands used in this step.
            x_elem = self.m1.shape[1] * x + step
            y_elem = self.m2.shape[1] * step + y
            a, b = self.m1[x, step], self.m2[step, y]
            dr = a * b
            self.res_data[x, y] += dr
            xcp, ycp = self.calc_step(x_elem, y_elem, xcp, ycp)
            res_tex = self.display_multiplication_tex(res_tex, a, b, dr, self.res_data[x, y])
            # NOTE(review): indentation was lost in the source this was
            # recovered from; the result matrix is assumed to be refreshed
            # once per step rather than once per entry - confirm.
            self.res, c = self.update_res_matrix(c, elem)
        # The operand circles stay None when there are no steps at all;
        # skip them instead of handing None to FadeOut / remove.
        leftovers = [g for g in (c, xcp, ycp, res_tex) if g is not None]
        self.play(*[FadeOut(g) for g in leftovers])
        self.remove(*leftovers)

    def update_res_matrix(self, c, elem):
        """Replace the result matrix with one showing the current ``res_data``.

        Args:
            c: Circle currently highlighting the entry being computed.
            elem: Flat index of that entry.

        Returns:
            Tuple ``(new_matrix, new_circle)`` that is now on the scene;
            the caller must store both.
        """
        rm = IntegerMatrix(self.res_data)
        rm.move_to(self.res)
        # The entry may have changed size, so the circle is rebuilt too.
        cc = self.circle_matrix(rm, elem, 'green')
        self.play(Transform(self.res, rm), Transform(c, cc))
        self.remove(self.res, c)
        self.add(rm, cc)
        return rm, cc

    def display_multiplication_tex(self, res_tex, a, b, dr, res):
        """Animate ``total + a x b -> total + dr -> res`` in place of ``res_tex``.

        Args:
            res_tex: TeX mobject currently showing the running total.
            a: Factor taken from the left operand.
            b: Factor taken from the right operand.
            dr: The product ``a * b``.
            res: New running total after adding ``dr``.

        Returns:
            The TeX mobject showing ``res`` that is left on the scene.
        """
        cur_res = res_tex.get_tex_string()
        # The empty ``_{}`` / ``^{}`` scripts make the substrings distinct
        # colour-map keys; presumably they keep equal numbers (e.g. a == b)
        # from being coloured alike - TODO confirm.
        rtm = TexMobject(f"_{{}}{cur_res}+{a}_{{}}\\times{b}^{{}}", tex_to_color_map={
            f'{a}_{{}}': RED,
            f'{b}^{{}}': YELLOW,
            f'_{{}}{cur_res}': GREEN
        })
        rtm.move_to(res_tex, UP)
        rta = TexMobject(f"{cur_res}^{{}}+{dr}_{{}}", tex_to_color_map={
            f"{cur_res}^{{}}": GREEN,
            f"{dr}_{{}}": ORANGE
        })
        rta.move_to(res_tex, UP)
        rt = TexMobject(str(res))
        rt.set_color(GREEN)
        rt.move_to(res_tex, UP)
        self.chain_transform([res_tex, rtm, rta, rt], 0.5, run_time=0.7)
        return rt

    def chain_transform(self, mobjs, wait_time, **kwargs):
        """Transform each mobject in ``mobjs`` into its successor in turn.

        Args:
            mobjs: Sequence of mobjects; the first is expected to be on
                the scene already, the last one remains afterwards.
            wait_time: Pause after each transformation.
            **kwargs: Forwarded to every ``Transform``.
        """
        for a, b in zip(mobjs, mobjs[1:]):
            self.play(Transform(a, b, **kwargs))
            # Replace the morphed source with the real target mobject.
            self.remove(a)
            self.add(b)
            self.wait(wait_time)

    # Backward-compatible alias for the original, misspelled method name.
    chain_transorm = chain_transform

    def calc_step(self, x_elem, y_elem, xcp, ycp):
        """Highlight the operand entries used in one multiply-add step.

        Args:
            x_elem: Flat index of the entry in the left operand.
            y_elem: Flat index of the entry in the right operand.
            xcp: Circle from the previous step on the left operand, or
                ``None`` on the first step.
            ycp: Same for the right operand.

        Returns:
            Tuple ``(xc, yc)`` of the circles now on the scene.
        """
        if xcp is None or ycp is None:
            # First step: draw fresh circles.
            xc = self.circle_matrix(self.x, x_elem)
            yc = self.circle_matrix(self.y, y_elem, 'yellow')
            self.play(*[ShowCreation(a, run_time=0.5) for a in (xc, yc)])
        else:
            # Later steps: slide the existing circles to the new entries.
            xc = self.update_circle(xcp, self.x, x_elem)
            yc = self.update_circle(ycp, self.y, y_elem)
            self.play(*[Transform(a, b, run_time=0.7, lag_ratio=0) for a, b in ((xcp, xc), (ycp, yc))])
            self.remove(xcp, ycp)
            self.add(xc, yc)
        return xc, yc

    @staticmethod
    def update_circle(c, m, e):
        """Return a copy of circle ``c`` moved onto entry ``e`` of matrix ``m``."""
        cc = c.copy()
        cc.move_to(m.get_entries()[e])
        return cc
Advertisement
Add Comment
Please, Sign In to add comment