Advertisement
KittehKing

Formulation Problem

Jun 23rd, 2024
197
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.21 KB | None | 0 0
  1. import numpy as np
  2. from z3 import Solver, Int, sat
  3. import matplotlib.pyplot as plt
  4.  
  5. class SolverFunc():
  6.     def __init__(self,filter_type, order):
  7.         self.filter_type=filter_type
  8.         self.half_order = order
  9.  
  10.     def db_to_linear(self,db_arr):
  11.         # Create a mask for NaN values
  12.         nan_mask = np.isnan(db_arr)
  13.  
  14.         # Apply the conversion to non-NaN values (magnitude)
  15.         linear_array = np.zeros_like(db_arr)
  16.         linear_array[~nan_mask] = 10 ** (db_arr[~nan_mask] / 20)
  17.  
  18.         # Preserve NaN values
  19.         linear_array[nan_mask] = np.nan
  20.         return linear_array
  21.    
  22.     def cm_handler(self,m,omega):
  23.         if self.filter_type == 0:
  24.             if m == 0:
  25.                 return 1
  26.             return 2*np.cos(np.pi*omega*m)
  27.        
  28.         #ignore the rest, its for later use if type 1 works
  29.         if self.filter_type == 1:
  30.             return 2*np.cos(omega*np.pi*(m+0.5))
  31.  
  32.         if self.filter_type == 2:
  33.             return 2*np.sin(omega*np.pi*(m-1))
  34.  
  35.         if self.filter_type == 3:
  36.             return 2*np.sin(omega*np.pi*(m+0.5))
  37.  
  38.  
  39. class FIRFilter:
  40.     def __init__(self, filter_type, order_upper, freqx_axis, freq_upper, freq_lower, ignore_lowerbound_lin, app=None):
  41.         self.filter_type = filter_type
  42.         self.order_upper = order_upper
  43.         self.freqx_axis = freqx_axis
  44.         self.freq_upper = freq_upper
  45.         self.freq_lower = freq_lower
  46.         self.ignore_lowerbound_lin = ignore_lowerbound_lin
  47.         self.h_int_res = []
  48.         self.app = app
  49.         self.fig, self.ax = plt.subplots()
  50.         self.freq_upper_lin=0
  51.         self.freq_lower_lin=0
  52.  
  53.     def runsolver(self):
  54.         self.order_current = int(self.order_upper)
  55.        
  56.         print("solver called")
  57.         sf = SolverFunc(self.filter_type, self.order_current)
  58.  
  59.         print("filter order:", self.order_current)
  60.         print("ignore lower than:", self.ignore_lowerbound_lin)
  61.         # linearize the bounds
  62.         self.freq_upper_lin = [sf.db_to_linear(f) for f in self.freq_upper]
  63.         self.freq_lower_lin = [sf.db_to_linear(f) for f in self.freq_lower]
  64.  
  65.         # declaring variables
  66.         h_int = [Int(f'h_int_{i}') for i in range(self.order_current)]
  67.  
  68.         # Create a Z3 solver instance
  69.         solver = Solver()
  70.  
  71.         # Create the sum constraints
  72.         for i in range(len(self.freqx_axis)):
  73.             print("upper freq:", self.freq_upper_lin[i])
  74.             print("lower freq:", self.freq_lower_lin[i])
  75.             print("freq:", self.freqx_axis[i])
  76.             term_sum_exprs = 0
  77.             half_order = self.order_current // 2
  78.             if np.isnan(self.freq_upper_lin[i]) or np.isnan(self.freq_lower_lin[i]):
  79.                 continue
  80.  
  81.             for j in range((self.order_current // 2)):
  82.                 cm_const = sf.cm_handler(j, self.freqx_axis[i])
  83.                 term_sum_exprs += h_int[j] * cm_const
  84.             solver.add(term_sum_exprs <= self.freq_upper_lin[i])
  85.  
  86.             if self.freq_lower_lin[i] < self.ignore_lowerbound_lin:
  87.                 continue
  88.             solver.add(term_sum_exprs >= self.freq_lower_lin[i])
  89.  
  90.         for i in range(self.order_current // 2):
  91.             mirror = (i + 1) * -1
  92.  
  93.             if self.filter_type == 0 or self.filter_type == 1:
  94.                 solver.add(h_int[i] == h_int[mirror])
  95.                 print(f"Added constraint: h_int[{i}] == h_int[{mirror}]")
  96.  
  97.             if self.filter_type == 2 or self.filter_type == 3:
  98.                 solver.add(h_int[i] == -h_int[mirror])
  99.                 print(f"Added constraint: h_int[{i}] == -h_int[{mirror}]")
  100.  
  101.             print(f"stype = {self.filter_type}, {i} is equal with {mirror}")
  102.  
  103.         print("solver running")
  104.  
  105.         if solver.check() == sat:
  106.             print("solver sat")
  107.             model = solver.model()
  108.             for i in range(self.order_current):
  109.                 print(f'h_int_{i} = {model[h_int[i]]}')
  110.                 self.h_int_res.append(model[h_int[i]].as_long())
  111.  
  112.             print(self.h_int_res)
  113.         else:
  114.             print("Unsatisfiable")
  115.  
  116.         print("solver stopped")
  117.  
  118.     def plot_result(self, result_coef):
  119.         print("result plotter called")
  120.         fir_coefficients = np.array(result_coef)
  121.         print("Fir coef in mp", fir_coefficients)
  122.  
  123.         # Compute the FFT of the coefficients
  124.         N = 5120  # Number of points for the FFT
  125.         frequency_response = np.fft.fft(fir_coefficients, N)
  126.         frequencies = np.fft.fftfreq(N, d=1.0)[:N//2]  # Extract positive frequencies up to Nyquist
  127.  
  128.         # Compute the magnitude and phase response for positive frequencies
  129.         magnitude_response = np.abs(frequency_response)[:N//2]
  130.  
  131.         # Convert magnitude response to dB
  132.         magnitude_response_db = 20 * np.log10(np.where(magnitude_response == 0, 1e-10, magnitude_response))
  133.  
  134.         # print("magdb in mp", magnitude_response_db)
  135.  
  136.         # Normalize frequencies to range from 0 to 1
  137.         normalized_frequencies = frequencies / np.max(frequencies)
  138.  
  139.  
  140.         #plot input
  141.         self.ax.scatter(self.freqx_axis, self.freq_upper_lin, color='r', s=20, picker=5)
  142.         self.ax.scatter(self.freqx_axis, self.freq_lower_lin, color='b', s=20, picker=5)
  143.  
  144.         # Plot the updated upper_ydata
  145.         self.ax.set_ylim([-0.5, 4])
  146.         self.ax.plot(normalized_frequencies, magnitude_response, color='y')
  147.  
  148.         if self.app:
  149.             self.app.canvas.draw()
  150.  
  151. # Test inputs
  152. filter_type = 0
  153. order_upper = 10
  154.  
  155.  
  156. # Initialize freq_upper and freq_lower with NaN values
  157. freqx_axis = np.linspace(0, 1, 16*order_upper) #according to Mr. Kumms paper
  158. freq_upper = np.full(16 * order_upper, np.nan)
  159. freq_lower = np.full(16 * order_upper, np.nan)
  160.  
  161. # Manually set specific values for the elements of freq_upper and freq_lower in dB
  162. freq_upper[20:40] = 3
  163. freq_lower[20:40] = 0.4
  164.  
  165. freq_upper[80:100] = -40
  166. freq_lower[80:100] = -1000
  167.  
  168.  
  169.  
  170. #beyond this bound lowerbound will be ignored
  171. ignore_lowerbound_lin = 0.0001
  172.  
  173. # Create FIRFilter instance
  174. fir_filter = FIRFilter(filter_type, order_upper, freqx_axis, freq_upper, freq_lower, ignore_lowerbound_lin)
  175.  
  176. # Run solver and plot result
  177. fir_filter.runsolver()
  178. fir_filter.plot_result(fir_filter.h_int_res)
  179.  
  180. # Show plot
  181. plt.show()
  182.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement