# PlotEigs

May 3rd, 2021
454
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. class EigsVisualizer:
2.     def __init__(self, eigs):
3.         self.eigs = eigs
4.
5.     def show(self):
6.         self.plot_eigs(narrow_view=True, show_axes=True)
7.
8.     def _enforce_ratio(self, goal_ratio, supx, infx, supy, infy):
9.             """
10.            Computes the right value of `supx,infx,supy,infy` to obtain the desired
11.            ratio in :func:`plot_eigs`. Ratio is defined as
12.            ::
13.                dx = supx - infx
14.                dy = supy - infy
15.                max(dx,dy) / min(dx,dy)
16.            :param float goal_ratio: the desired ratio.
17.            :param float supx: the old value of `supx`, to be adjusted.
18.            :param float infx: the old value of `infx`, to be adjusted.
19.            :param float supy: the old value of `supy`, to be adjusted.
20.            :param float infy: the old value of `infy`, to be adjusted.
21.            :return tuple: a tuple which contains the updated values of
22.                `supx,infx,supy,infy` in this order.
23.            """
24.
25.             dx = supx - infx
26.             if dx == 0:
27.                 dx = 1.e-16
28.             dy = supy - infy
29.             if dy == 0:
30.                 dy = 1.e-16
31.             ratio = max(dx, dy) / min(dx, dy)
32.
33.             if ratio >= goal_ratio:
34.                 if dx < dy:
35.                     goal_size = dy / goal_ratio
36.
37.                     supx += (goal_size - dx) / 2
38.                     infx -= (goal_size - dx) / 2
39.                 elif dy < dx:
40.                     goal_size = dx / goal_ratio
41.
42.                     supy += (goal_size - dy) / 2
43.                     infy -= (goal_size - dy) / 2
44.
45.             return (supx,infx,supy,infy)
46.
47.     def _plot_limits(self, narrow_view):
48.             if narrow_view:
49.                 supx = max(self.eigs.real) + 0.05
50.                 infx = min(self.eigs.real) - 0.05
51.
52.                 supy = max(self.eigs.imag) + 0.05
53.                 infy = min(self.eigs.imag) - 0.05
54.
55.                 return self._enforce_ratio(8, supx, infx, supy,
56.                     infy)
57.             else:
58.                 return np.max(np.ceil(np.absolute(self.eigs)))
59.
60.     def plot_eigs(self,
61.                       show_axes=True,
62.                       show_unit_circle=True,
63.                       figsize=(8, 8),
64.                       title='',
65.                       narrow_view=False,
66.                       dpi=None,
67.                       filename=None):
68.             """
69.            Plot the eigenvalues.
70.            :param bool show_axes: if True, the axes will be showed in the plot.
71.                Default is True.
72.            :param bool show_unit_circle: if True, the circle with unitary radius
73.                and center in the origin will be showed. Default is True.
74.            :param tuple(int,int) figsize: tuple in inches defining the figure
75.                size. Default is (8, 8).
76.            :param str title: title of the plot.
77.            :param narrow_view bool: if True, the plot will show only the smallest
78.                rectangular area which contains all the eigenvalues, with a padding
79.                of 0.05. Not compatible with `show_axes=True`. Default is False.
80.            :param dpi int: If not None, the given value is passed to ``plt.figure``.
81.            :param str filename: if specified, the plot is saved at `filename`.
82.            """
83.             if self.eigs is None:
84.                 raise ValueError('The eigenvalues have not been computed.'
85.                                  'You have to perform the fit method.')
86.
87.             if dpi is not None:
88.                 plt.figure(figsize=figsize, dpi=dpi)
89.             else:
90.                 plt.figure(figsize=figsize)
91.
92.             plt.title(title)
93.             plt.gcf()
94.             ax = plt.gca()
95.
96.             points, = ax.plot(self.eigs.real,
97.                               self.eigs.imag,
98.                               'bo',
99.                               label='Eigenvalues')
100.
101.             if narrow_view:
102.                 supx, infx, supy, infy = self._plot_limits(narrow_view)
103.
104.                 # set limits for axis
105.                 ax.set_xlim((infx, supx))
106.                 ax.set_ylim((infy, supy))
107.
108.                 # x and y axes
109.                 if show_axes:
110.                     endx = np.min([supx, 1.])
111.                     ax.annotate('',
112.                                 xy=(endx, 0.),
113.                                 xytext=(np.max([infx, -1.]), 0.),
114.                                 arrowprops=dict(arrowstyle=("->" if endx == 1. else '-')))
115.
116.                     endy = np.min([supy, 1.])
117.                     ax.annotate('',
118.                                 xy=(0., endy),
119.                                 xytext=(0., np.max([infy, -1.])),
120.                                 arrowprops=dict(arrowstyle=("->" if endy == 1. else '-')))
121.             else:
122.                 # set limits for axis
123.                 limit = self._plot_limits(narrow_view)
124.
125.                 ax.set_xlim((-limit, limit))
126.                 ax.set_ylim((-limit, limit))
127.
128.                 # x and y axes
129.                 if show_axes:
130.                     ax.annotate('',
131.                                 xy=(np.max([limit * 0.8, 1.]), 0.),
132.                                 xytext=(np.min([-limit * 0.8, -1.]), 0.),
133.                                 arrowprops=dict(arrowstyle="->"))
134.                     ax.annotate('',
135.                                 xy=(0., np.max([limit * 0.8, 1.])),
136.                                 xytext=(0., np.min([-limit * 0.8, -1.])),
137.                                 arrowprops=dict(arrowstyle="->"))
138.
139.             plt.ylabel('Imaginary part')
140.             plt.xlabel('Real part')
141.
142.             if show_unit_circle:
143.                 unit_circle = plt.Circle((0., 0.),
144.                                          1.,
145.                                          color='green',
146.                                          fill=False,
147.                                          label='Unit circle',
148.                                          linestyle='--')
150.
151.             # Dashed grid
152.             gridlines = ax.get_xgridlines() + ax.get_ygridlines()
153.             for line in gridlines:
154.                 line.set_linestyle('-.')
155.             ax.grid(True)
156.
157.             # legend
158.             if show_unit_circle:
160.                     plt.legend([points, unit_circle],
161.                                ['Eigenvalues', 'Unit circle'],
162.                                loc='best'))
163.             else:
165.
166.             ax.set_aspect('equal')
167.
168.             if filename:
169.                 plt.savefig(filename)
170.             else:
171.                 plt.show()
RAW Paste Data