csp256

SymPy quaternion composition verification and code generation

Jul 17th, 2025
13
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.67 KB | Source Code | 0 0
  1. from sympy import *
  2. init_printing()
  3.  
  4. def make_quaternion_expression(x: tuple, i: Symbol, j: Symbol, k: Symbol):
  5.     return x[0] + x[1]*i + x[2]*j + x[3]*k
  6.    
  7. def quaternion_simplification(expr: Expr, i: Symbol, j: Symbol, k: Symbol):
  8.     # Only does one round: if you know nothing about expr you must call this function until it doesn't change `expr`
  9.     # (or do something more clever!)
  10.     expr = expr.subs(i**2, -1)
  11.     expr = expr.subs(j**2, -1)
  12.     expr = expr.subs(k**2, -1)
  13.    
  14.     expr = expr.subs(i*j, k)
  15.     expr = expr.subs(i*k, -j)
  16.  
  17.     expr = expr.subs(j*i, -k)
  18.     expr = expr.subs(j*k, i)
  19.  
  20.     expr = expr.subs(k*i, j)
  21.     expr = expr.subs(k*j, -i)
  22.     return expr
  23.  
  24. def quaternion_expression_to_coefficient_tuple(expr, i, j, k):
  25.     # Strip off each imaginary term one by one
  26.     i_coeff = expr.coeff(i, 1)
  27.     expr = simplify(expr - i_coeff * i)
  28.  
  29.     j_coeff = expr.coeff(j, 1)
  30.     expr = simplify(expr - j_coeff * j)
  31.    
  32.     k_coeff = expr.coeff(k, 1)
  33.     expr = simplify(expr - k_coeff * k)
  34.  
  35.     # Only the real part should be left... let's double check that
  36.     real_part = expr
  37.     assert real_part == real_part.coeff(i, 0).coeff(j, 0).coeff(k, 0)
  38.  
  39.     return (real_part, i_coeff, j_coeff, k_coeff)
  40.  
  41. def quaternion_composition(a: tuple, b: tuple):
  42.     i, j, k = symbols("i, j, k", commutative=False)
  43.  
  44.     # These are our two quaternions
  45.     # 0th index holds the real part
  46.     A = make_quaternion_expression(a, i, j, k)
  47.     B = make_quaternion_expression(b, i, j, k)
  48.  
  49.     # FOIL the product of two quaternions
  50.     expr = expand( B * A )
  51.  
  52.     # Use quaternion algebra to simplify expression
  53.     # we happen to know that no more than two imaginary terms are multiplied together,
  54.     # so we can just do each substitution once
  55.     expr = quaternion_simplification(expr, i, j, k)
  56.  
  57.     return quaternion_expression_to_coefficient_tuple(expr, i, j, k)
  58.  
  59. def quaternion_coefficient_conjugate(a: tuple):
  60.     return (a[0], -a[1], -a[2], -a[3])
  61.  
  62. def quaternion_rotation(quat_coeff: tuple, vec3: ImmutableMatrix):
  63.     i, j, k = symbols('i, j, k', commutative=False)
  64.     q = make_quaternion(quat_coeff, i, j, k)
  65.  
  66.     conj_coeff = quaternion_coefficient_conjugate( quat_coeff )
  67.     q_conj = make_quaternion(conj_coeff, i, j, k)
  68.  
  69.     expr = expand( q * vec3 * q_conj )
  70.     expr = quaternion_simplification(expr, i, j, k)
  71.     return expr
  72.  
  73. # Let us verify that rotating a vector by quaternions `a` then `b` is equal to
  74. # composing `a` and `b` into one quaternion `ab` then rotating a vector by `ab`
  75. vec3 = ImmutableMatrix( symbols("v[:3]") )
  76. a = symbols("a[:4]")
  77. b = symbols("b[:4]")
  78.  
  79. # Method 1
  80. vec3_a = quaternion_rotation(a, vec3)
  81. vec3_a_b = quaternion_rotation(b, vec3_a)
  82.  
  83. # Method 2
  84. ab = quaternion_composition(a, b)
  85. vec3_ab = quaternion_rotation(ab, vec3)
  86.  
  87. # We just verified that composition does indeed work the way we think it does
  88. assert vec3_a_b == vec3_ab
  89.  
  90. # Let's print these functions out as C++ code
  91. from sympy.printing.cxx import *
  92. printer = CXX17CodePrinter()
  93.  
  94. print("template <typename T>")
  95. print("void")
  96. print("quaternion_composition(")
  97. print("\t\tT const * a,")
  98. print("\t\tT const * b,")
  99. print("\t\tT * a_then_b)")
  100. print("{")
  101. for i in range(len(a_then_b)):
  102.     print(f"\ta_then_b[{i}] = " + str(a_then_b[i]) + ";")
  103. print("}")
  104. print("")
  105. # While we're at it, lets emit some code for rotating a vector by a quaternion
  106. print("template <typename T>")
  107. print("void")
  108. print("quaternion_rotation(")
  109. print("\t\tT const * quat,")
  110. print("\t\tT const * vec3,")
  111. print("\t\tT * out)")
  112. print("{")
  113. rotated_vector = quaternion_rotation(
  114.         symbols("quat[:4]"),
  115.          ImmutableMatrix( symbols("vec3[:3]") ))
  116. for i in range(len(rotated_vector)):
  117.     print(f"\tout[{i}] = " + printer.doprint(rotated_vector[i]) + ";")
  118. print("}")
  119.  
  120. # template <typename T>
  121. # void
  122. # quaternion_composition(
  123. #       T const * a,
  124. #       T const * b,
  125. #       T * a_then_b)
  126. # {
  127. #   a_then_b[0] = a[0]*b[0] - a[1]*b[1] - a[2]*b[2] - a[3]*b[3];
  128. #   a_then_b[1] = a[0]*b[1] + a[1]*b[0] - a[2]*b[3] + a[3]*b[2];
  129. #   a_then_b[2] = a[0]*b[2] + a[1]*b[3] + a[2]*b[0] - a[3]*b[1];
  130. #   a_then_b[3] = a[0]*b[3] - a[1]*b[2] + a[2]*b[1] + a[3]*b[0];
  131. # }
  132.  
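# NOTE: the recorded quaternion_rotation output below is not a rotation. Each
# component is just vec3[n] scaled by |quat|^2, because vec3 was multiplied as
# a matrix instead of being embedded as the pure quaternion v0*i + v1*j + v2*k.
# template <typename T>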
  134. # void
  135. # quaternion_rotation(
  136. #       T const * quat,
  137. #       T const * vec3,
  138. #       T * out)
  139. # {
  140. #   out[0] = std::pow(quat[0], 2)*vec3[0] + std::pow(quat[1], 2)*vec3[0] + std::pow(quat[2], 2)*vec3[0] + std::pow(quat[3], 2)*vec3[0];
  141. #   out[1] = std::pow(quat[0], 2)*vec3[1] + std::pow(quat[1], 2)*vec3[1] + std::pow(quat[2], 2)*vec3[1] + std::pow(quat[3], 2)*vec3[1];
  142. #   out[2] = std::pow(quat[0], 2)*vec3[2] + std::pow(quat[1], 2)*vec3[2] + std::pow(quat[2], 2)*vec3[2] + std::pow(quat[3], 2)*vec3[2];
  143. # }
Advertisement
Add Comment
Please, Sign In to add comment