# Source code for time_domain_astronomy_sandbox.plotting

"""Plotting methods."""
import numpy as np
import scipy.signal
from matplotlib import pyplot as plt
from matplotlib import rc
# Module-level side effect: set matplotlib's global rc font sizes at import
# time, so every figure created afterwards uses them.
rc('font', size=16)
rc('axes', titlesize=18, labelsize=18)

from mpl_toolkits.axes_grid1 import AxesGrid
from matplotlib.offsetbox import AnchoredText

from .snr import SNR

def add_at(ax, t, loc=2):
    """Place text ``t`` in an anchored box on ``ax`` and return the box.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes that receives the text box.
    t : str
        Text to display.
    loc : int
        Anchor location code forwarded to ``AnchoredText``
        (default: 2, i.e. 'upper left' in matplotlib's location codes).

    Returns
    -------
    matplotlib.offsetbox.AnchoredText
        The artist that was added to ``ax``.
    """
    text_box = AnchoredText(t, loc=loc)
    ax.add_artist(text_box)
    return text_box

def set_fig_dims(direction, data_arr, spectrum=False):
    """Return the ``(ncols, nrows)`` subplot grid for a list of data sets.

    Parameters
    ----------
    direction : str
        'horizontal' lays the panels out in a single row, 'vertical' in a
        single column.
    data_arr : sized sequence
        One entry per data set; only its length is used.
    spectrum : bool
        If true, every data set takes two panels (image plus S/N curve), so
        the panel count is doubled.

    Returns
    -------
    tuple of int
        ``(ncols, nrows)``, in that order.

    Raises
    ------
    ValueError
        If ``direction`` is neither 'horizontal' nor 'vertical'.
    """
    n_panels = len(data_arr) * 2 if spectrum else len(data_arr)
    if direction == 'horizontal':
        return n_panels, 1
    if direction == 'vertical':
        return 1, n_panels
    # Without this, an unknown direction fell through and crashed later with
    # an UnboundLocalError on ``ncols``.
    raise ValueError(
        f"direction must be 'horizontal' or 'vertical', got {direction!r}"
    )

def set_multi_axes(ax, direction, xticks, xtick_labels, yticks, ytick_labels,
                   spectrum=False, dual=False):
    """Set axis labels, ticks and tick labels on a group of subplots.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or sequence/array of Axes
        Panels to decorate.  A single Axes, a list/tuple, or a numpy array
        of any shape (flattened in row-major order) is accepted.
    direction : str
        Panel layout: 'horizontal' (one row) or 'vertical' (one column).
        When x ticks are given, only the last panel of a vertical stack gets
        them (the others have their x tick labels hidden).  When y ticks are
        given, only the first panel of a horizontal row gets a y label and
        ticks (the others have their y tick labels hidden).
    xticks, xtick_labels : sequence or None
        Tick positions / labels for the x axis.  If either is empty or None,
        every panel's x axis is labelled 'X Index'.
    yticks, ytick_labels : sequence or None
        Tick positions / labels for the y axis.  If either is empty or None,
        every panel's y axis is given the S/N label.
    spectrum : bool
        Whether (some of) the panels show S/N curves; S/N panels get the
        S/N label and keep their own y ticks.
    dual : bool
        If True, panels alternate image / S/N curve, so only odd positions
        are S/N panels when ``spectrum`` is true, and the S/N label is
        'SNR'.  If False, every panel is an S/N panel when ``spectrum`` is
        true, and the S/N label is 'S/N'.
    """
    # Normalise to a flat list: plt.subplots returns a bare Axes for a 1x1
    # grid, which is not iterable.
    if isinstance(ax, np.ndarray):
        axes = list(ax.ravel())
    elif isinstance(ax, (list, tuple)):
        axes = list(ax)
    else:
        axes = [ax]

    def _nonempty(seq):
        # len() rather than truthiness so numpy arrays of ticks work too.
        return seq is not None and len(seq) > 0

    has_xticks = _nonempty(xticks) and _nonempty(xtick_labels)
    has_yticks = _nonempty(yticks) and _nonempty(ytick_labels)
    snr_label = 'SNR' if dual else 'S/N'
    last = len(axes) - 1

    for i, axi in enumerate(axes):
        is_snr_panel = bool(spectrum) and (not dual or i % 2 == 1)

        if has_xticks:
            if direction == 'horizontal' or (direction == 'vertical' and i == last):
                axi.set_xlabel('Time (s)')
                axi.set_xticks(xticks)
                axi.set_xticklabels(xtick_labels)
            else:
                plt.setp(axi.get_xticklabels(), visible=False)
        else:
            axi.set_xlabel('X Index')

        if has_yticks:
            if direction == 'vertical' or (direction == 'horizontal' and i == 0):
                axi.set_ylabel(snr_label if is_snr_panel else 'Freq. (MHz)')
                # Frequency ticks only make sense on image panels; S/N
                # panels keep matplotlib's automatic ticks.
                if not is_snr_panel:
                    axi.set_yticks(yticks)
                    axi.set_yticklabels(ytick_labels)
            else:
                plt.setp(axi.get_yticklabels(), visible=False)
        else:
            axi.set_ylabel(snr_label)
def plot_spectrum(data, ncols=1, nrows=1):
    """Plot a 1-D spectrum as a line against its channel index.

    Parameters
    ----------
    data : numpy.ndarray
        Values to plot; the x axis is the element index ('channel').
    ncols : int
        Number of columns for matplotlib.pyplot.subplots (default: 1).
    nrows : int
        Number of rows for matplotlib.pyplot.subplots (default: 1).

    Notes
    -----
    If a grid larger than 1x1 is requested, the spectrum is drawn on the
    first (top-left) panel.
    """
    # squeeze=False always returns a 2-D array of Axes, so [0, 0] is valid
    # for any grid size (the default returns a bare Axes only for 1x1).
    _, axes = plt.subplots(figsize=(10, 5), ncols=ncols, nrows=nrows,
                           squeeze=False)
    ax = axes[0, 0]
    # ``extent`` is an imshow keyword, not a Line2D property; passing it to
    # ``plot`` made every call fail, so it is not forwarded.
    ax.plot(data)
    ax.set_xlabel('channel')
    ax.set_ylabel('intensity')
def plot_image(data, xticks=None, xtick_labels=None, yticks=None,
               ytick_labels=None, ncols=1, nrows=1, xfig_size=10,
               yfig_size=5):
    """Plot a 2-D array as an image, labelled time (x) vs frequency (y).

    Parameters
    ----------
    data : numpy.ndarray
        Array shown with ``imshow(origin='lower')``, i.e. row 0 at the bottom.
    xticks : sequence, optional
        Tick positions for the x axis.
    xtick_labels : sequence, optional
        Tick labels for the x axis; applied only when both ``xticks`` and
        ``xtick_labels`` are non-empty.
    yticks : sequence, optional
        Tick positions for the y axis.
    ytick_labels : sequence, optional
        Tick labels for the y axis; applied only when both ``yticks`` and
        ``ytick_labels`` are non-empty.
    ncols : int
        Number of columns for matplotlib.pyplot.subplots (default: 1).
    nrows : int
        Number of rows for matplotlib.pyplot.subplots (default: 1).
    xfig_size : int
        Figure width (default: 10).
    yfig_size : int
        Figure height (default: 5).

    Notes
    -----
    If a grid larger than 1x1 is requested, the image is drawn on the first
    (top-left) panel.
    """
    def _nonempty(seq):
        # len() rather than truthiness so numpy arrays of ticks work too.
        return seq is not None and len(seq) > 0

    # squeeze=False always returns a 2-D array of Axes, so [0, 0] is valid
    # for any grid size.
    _, axes = plt.subplots(
        figsize=(xfig_size, yfig_size), ncols=ncols, nrows=nrows,
        squeeze=False,
    )
    ax = axes[0, 0]
    ax.imshow(data, origin='lower')
    ax.set_xlabel('time (s)')
    ax.set_ylabel('frequency (MHz)')
    if _nonempty(xticks) and _nonempty(xtick_labels):
        ax.set_xticks(xticks)
        ax.set_xticklabels(xtick_labels)
    if _nonempty(yticks) and _nonempty(ytick_labels):
        ax.set_yticks(yticks)
        ax.set_yticklabels(ytick_labels)
def plot_multi_1D(data_arr, labels=None, xticks=None, xtick_labels=None,
                  yticks=None, ytick_labels=None, direction='horizontal',
                  xfig_size=10, yfig_size=5, loc=4, detection_threshold=None,
                  savefig=False, fig_name='multi-1D', ext='png', dpi=150):
    """Plot ``SNR().simple_snr(data, axis=0)`` for several arrays, one panel each.

    Parameters
    ----------
    data_arr : list of numpy.ndarray
        2-D data arrays, one per panel.
    labels : list of str, optional
        Per-panel text shown in an anchored box; empty strings (and panels
        beyond the end of the list) get no box.
    xticks : sequence, optional
        Tick positions for the x axis.
    xtick_labels : sequence, optional
        Tick labels for the x axis.
    yticks : sequence, optional
        Tick positions for the y axis.
    ytick_labels : sequence, optional
        Tick labels for the y axis.
    direction : str
        General direction onto which subplots are appended: 'horizontal' or
        'vertical' (default: 'horizontal').
    xfig_size : int
        Figure width (default: 10).
    yfig_size : int
        Figure height (default: 5).
    loc : int
        Location code of the label box (default: 4).
    detection_threshold : float, optional
        If given, a constant line at this value is drawn in every panel,
        spanning ``data.shape[1]`` samples.
    savefig : bool
        Save the figure (default: False).
    fig_name : str
        Figure file name without extension (default: 'multi-1D').
    ext : str
        File extension (default: 'png').
    dpi : int
        Resolution used when saving (default: 150).
    """
    labels = [] if labels is None else labels
    xticks = [] if xticks is None else xticks
    xtick_labels = [] if xtick_labels is None else xtick_labels
    yticks = [] if yticks is None else yticks
    ytick_labels = [] if ytick_labels is None else ytick_labels

    ncols, nrows = set_fig_dims(direction, data_arr)
    # squeeze=False + ravel: a single data set would otherwise yield a bare,
    # non-iterable Axes.
    _, axes = plt.subplots(
        figsize=(xfig_size, yfig_size), ncols=ncols, nrows=nrows,
        gridspec_kw={'hspace': 0, 'wspace': 0}, sharex=True, squeeze=False,
    )
    axes = axes.ravel()

    for i, (axi, data) in enumerate(zip(axes, data_arr)):
        axi.plot(SNR().simple_snr(data, axis=0))
        if detection_threshold is not None:
            n_samples = data.shape[1]
            axi.plot(np.arange(n_samples),
                     np.full(n_samples, detection_threshold))
        if i < len(labels) and len(labels[i]) > 0:
            add_at(axi, labels[i], loc=loc)

    set_multi_axes(axes, direction=direction, xticks=xticks,
                   xtick_labels=xtick_labels, yticks=yticks,
                   ytick_labels=ytick_labels, spectrum=True)
    plt.tight_layout()
    if savefig:
        plt.savefig(f"{fig_name}.{ext}", dpi=dpi)
def plot_multi_images(data_arr, labels=None, xticks=None, xtick_labels=None,
                      yticks=None, ytick_labels=None, direction='horizontal',
                      xfig_size=10, yfig_size=5, loc=4, spectrum=False,
                      detection_threshold=None, colorbar=False, savefig=False,
                      fig_name='multi-images', ext='png', dpi=150):
    """Plot several 2-D arrays as images, optionally each followed by its S/N.

    Parameters
    ----------
    data_arr : list of numpy.ndarray
        2-D data arrays, one image panel each.
    labels : list of str, optional
        Per-image text shown in an anchored box; empty strings (and images
        beyond the end of the list) get no box.
    xticks : sequence, optional
        Tick positions for the x axis.
    xtick_labels : sequence, optional
        Tick labels for the x axis.
    yticks : sequence, optional
        Tick positions for the y axis.
    ytick_labels : sequence, optional
        Tick labels for the y axis.
    direction : str
        General direction onto which subplots are appended: 'horizontal' or
        'vertical' (default: 'horizontal').
    xfig_size : int
        Figure width (default: 10).
    yfig_size : int
        Figure height (default: 5).
    loc : int
        Location code of the label box (default: 4).
    spectrum : bool
        If True, each image panel is followed by a panel showing
        ``SNR().simple_snr(data, axis=0)``; all S/N panels share the y range
        ``(0, max S/N + 1)``.  In the vertical layout image panels are twice
        as tall as S/N panels (default: False).
    detection_threshold : float, optional
        With ``spectrum``, draw a dashed constant line at this value in each
        S/N panel.
    colorbar : bool
        Add a colorbar next to each image (default: False).
    savefig : bool
        Save the figure (default: False).
    fig_name : str
        Figure file name without extension (default: 'multi-images').
    ext : str
        File extension (default: 'png').
    dpi : int
        Resolution used when saving (default: 150).
    """
    labels = [] if labels is None else labels
    xticks = [] if xticks is None else xticks
    xtick_labels = [] if xtick_labels is None else xtick_labels
    yticks = [] if yticks is None else yticks
    ytick_labels = [] if ytick_labels is None else ytick_labels

    ncols, nrows = set_fig_dims(direction, data_arr, spectrum)
    gridspec_kw = dict(hspace=0.1, wspace=0.)
    if spectrum and direction == 'vertical':
        # One ratio per row: image rows 2, S/N rows 1.  Only valid for the
        # single-column layout; the horizontal layout has one row, and
        # passing 2*N height ratios there makes matplotlib raise ValueError.
        gridspec_kw['height_ratios'] = [
            2 if (k % 2) == 0 else 1 for k in range(len(data_arr) * 2)
        ]
    # squeeze=False + ravel: a single image would otherwise yield a bare,
    # non-indexable Axes.
    _, axes = plt.subplots(
        figsize=(xfig_size, yfig_size), ncols=ncols, nrows=nrows,
        gridspec_kw=gridspec_kw, squeeze=False,
    )
    axes = axes.ravel()

    ax_i = 0
    spec_max_snr = -999
    for i, data in enumerate(data_arr):
        im = axes[ax_i].imshow(data, origin='lower')
        if i < len(labels) and len(labels[i]) > 0:
            add_at(axes[ax_i], labels[i], loc=loc)
        if colorbar:
            plt.colorbar(im, ax=axes[ax_i])
        if spectrum:
            ax_i += 1
            n_samples = data.shape[1]
            if detection_threshold is not None:
                axes[ax_i].plot(np.full(n_samples, detection_threshold),
                                '--', color='black', alpha=0.5)
            axes[ax_i].set_xlim(0, n_samples - 1)
            snr = SNR().simple_snr(data, axis=0)
            axes[ax_i].plot(snr)
            peak = np.nanmax(snr)
            if spec_max_snr < peak:
                spec_max_snr = peak
        ax_i += 1

    set_multi_axes(axes, direction, xticks, xtick_labels, yticks,
                   ytick_labels, spectrum, dual=True)
    if spectrum:
        # S/N panels sit at the odd positions; give them a common y range.
        for axi in axes[1::2]:
            axi.set_ylim(0, spec_max_snr + 1)
    plt.tight_layout()
    if savefig:
        plt.savefig(f"{fig_name}.{ext}", dpi=dpi)