Advertisement
CaptainNaoe

[python] plotting misc

Jan 16th, 2021 (edited)
1,009
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.65 KB | None | 0 0
  1. import matplotlib.pyplot as plt
  2. import seaborn as sns
  3.  
  4.  
  5. SMALL_SIZE = 14
  6. MEDIUM_SIZE = 18
  7. BIGGER_SIZE = 24
  8. BIGGERER_SIZE = 28
  9.  
  10. plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
  11. plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
  12. plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
  13. plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
  14. plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
  15. plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
  16. plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title
  17. plt.rc('axes', titlesize=BIGGER_SIZE)    # fontsize of the figure title
  18.  
  19. plt.rcParams["figure.figsize"] = [10, 6]  # figsize
  20. plt.rcParams["figure.dpi"] = 100  # default dpi
  21. plt.rcParams["figure.autolayout"] = True
  22. plt.rcParams["legend.loc"] = "upper left"
  23. plt.rcParams["savefig.bbox"] = "tight"
  24. plt.rcParams["figure.autolayout"] = True
  25.  
  26. plt.rcParams["font.weight"] = "bold"
  27. plt.rcParams["axes.labelweight"] = "bold"
  28.  
  29.  
  30. linestyle_str = [
  31.      ('solid', 'solid'),      # Same as (0, ()) or '-'
  32.      ('dotted', 'dotted'),    # Same as (0, (1, 1)) or ':'
  33.      ('dashed', 'dashed'),    # Same as '--'
  34.      ('dashdot', 'dashdot')]  # Same as '-.'
  35.  
  36. linestyle_tuple = [
  37.      ('dotted',                (0, (1, 1))),
  38.      ('densely dotted',        (0, (1, 1))),
  39.  
  40.      ('dashed',                (0, (5, 5))),
  41.      ('densely dashed',        (0, (5, 1))),
  42.  
  43.      ('dashdotted',            (0, (3, 5, 1, 5))),
  44.      ('densely dashdotted',    (0, (3, 1, 1, 1))),
  45.  
  46.      ('dashdotdotted',         (0, (3, 5, 1, 5, 1, 5))),
  47.      ('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1)))]
  48.  
  49.  
  50.  
  51. def export_legend(legend, filename="results/legend.pdf", transparent=False):
  52.     fig = legend.figure
  53.     fig.canvas.draw()
  54.     bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
  55.     fig.savefig(filename, dpi=200, bbox_inches=bbox, transparent=transparent)
  56.  
  57. ax.tick_params(axis='both', which='major', labelsize=16)
  58.  
  59. ax.text(0.0, 0.7, f"$r={r}", ha="center", va="center")
  60.  
  61. fig.text(0.5, 0.04, 'common X', ha='center')
  62. fig.text(0.04, 0.5, 'common Y', va='center', rotation='vertical')
  63.  
  64.  
  65. def hue_regplot(data, x, y, hue, palette=None, **kwargs):
  66.     from matplotlib.cm import get_cmap
  67.  
  68.     regplots = []
  69.  
  70.     levels = data[hue].unique()
  71.  
  72.     if palette is None:
  73.         default_colors = get_cmap("tab10")
  74.         palette = {k: default_colors(i) for i, k in enumerate(levels)}
  75.  
  76.     for key in levels:
  77.         regplots.append(sns.regplot(x=x, y=y, data=data[data[hue] == key], label=key, color=palette[key], **kwargs))
  78.  
  79.     return regplots
  80.  
  81.  
  82. def subplots_centered(nrows, ncols, nfigs, figsize=None, dpi=None):
  83.     """
  84.    Modification of matplotlib plt.subplots(),
  85.    useful when some subplots are empty.
  86.  
  87.    It returns a grid where the plots
  88.    in the **last** row are centered.
  89.  
  90.    Inputs
  91.    ------
  92.        nrows, ncols, figsize: same as plt.subplots()
  93.        nfigs: real number of figures
  94.    """
  95.     assert nfigs < nrows * ncols, "No empty subplots, use normal plt.subplots() instead"
  96.  
  97.     fig = plt.figure(figsize=figsize, dpi=dpi)
  98.     axs = []
  99.  
  100.     m = nfigs % ncols
  101.     m = range(1, ncols + 1)[-m]  # subdivision of columns
  102.     gs = gridspec.GridSpec(nrows, m * ncols)
  103.  
  104.     for i in range(0, nfigs):
  105.         row = i // ncols
  106.         col = i % ncols
  107.  
  108.         if row == nrows - 1:  # center only last row
  109.             off = int(m * (ncols - nfigs % ncols) / 2)
  110.         else:
  111.             off = 0
  112.  
  113.         ax = plt.subplot(gs[row, m * col + off : m * (col + 1) + off])
  114.         axs.append(ax)
  115.  
  116.     return fig, np.array(axs)
  117.  
  118.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement