Source code for lfd.analysis.plotting.utils

"""A collection of miscellaneous functionality that helps visualize the
profiles and results of convolutions and their numerical values.

"""
import itertools
import os.path as ospath
import contextlib
import warnings

import matplotlib as mpl
import matplotlib.pyplot as plt

from lfd.analysis.profiles.samplers import generic_sampler
import lfd.analysis.utils as utils


# Names exported on a star-import of this module. Note that `plot_profile`
# (singular) is defined in this module but is not part of this list.
__all__ = ["plot_profiles", "paperstyle", "set_ax_props", "get_ls",
           "get_style_path", "get_data_dir", "get_data_file",
           "create_data_file_name", "get_or_create_data"]


# Linestyles used in plots; `get_ls` hands them out one at a time. The list
# holds four named styles followed by four shorthand style strings.
LINESTYLES = [
    "solid", "dashed", "dashdot", "dotted",
    "-", "--", "-.", ":",
]

# Linestyle counter: index into LINESTYLES of the entry `get_ls` returned last
# (0 before the first call).
CUR_LS = 0


def get_data_dir():
    """Return the path to the ``data`` directory of the utils module.

    Returns
    -------
    datadir : `str`
        Path made by joining the directory that contains this file with
        ``"data"``. The directory is not checked for existence.
    """
    return ospath.join(ospath.dirname(__file__), "data")
def create_data_file_name(name):
    """Build a filepath for a file inside the data directory of utils module.

    No check is made on whether the file exists; it may or may not. To get
    the path of a file that is known to exist use `get_data_file`.

    Parameters
    ----------
    name : str
        Name of the desired data file.

    Returns
    -------
    fpath : `str`
        Path to a file in the data directory.
    """
    data_dir = get_data_dir()
    return ospath.join(data_dir, name)
def get_data_file(name):
    """Return the path to an existing file in the data directory of utils
    module.

    Parameters
    ----------
    name : str
        Name of the desired data file.

    Returns
    -------
    fname : `str`
        Full data directory path of the existing file.

    Raises
    ------
    error : OSError
        File of given name doesn't exist.
    """
    fullname = create_data_file_name(name)
    # guard clause: bail out early when there is nothing at that path
    if not ospath.exists(fullname):
        raise OSError(f"File {fullname} does not exist.")
    return fullname
def get_style_path():
    """Return the path to the Matplotlib style file used for the figures in
    the paper.

    The file is looked up through `get_data_file`, so the lookup fails with
    the error that function raises when the style file is missing.

    Returns
    -------
    stylepath : `str`
        Path to ``paperstyle.mplstyle`` in the data directory.
    """
    style_file = "paperstyle.mplstyle"
    return get_data_file(style_file)
def get_ls():
    """Return the next linestyle from `LINESTYLES`.

    Helps with graphing in for loops. The module-level counter `CUR_LS` is
    advanced *before* the lookup and is reset to 0 once it runs past the end
    of the list, so the styles are handed out cyclically.

    Returns
    -------
    linestyle : `str`
        The entry of `LINESTYLES` at the updated counter position.
    """
    global CUR_LS
    following = CUR_LS + 1
    # wrap around to the start once the end of the list has been passed
    CUR_LS = following if following < len(LINESTYLES) else 0
    return LINESTYLES[CUR_LS]
@contextlib.contextmanager
def paperstyle(after_reset=True):
    """Context manager that applies the same Matplotlib style as was used in
    the paper.

    Parameters
    ----------
    after_reset : `bool`, optional
        Forwarded to `matplotlib.pyplot.style.context`. True by default.

    Yields
    ------
    None
        The style is active for the duration of the ``with`` block.
    """
    with plt.style.context(get_style_path(), after_reset=after_reset):
        yield
def plot_profile(ax, profile, normed=True, **kwargs):
    """Plots a profile, optionally normalized, on the given axis.

    Parameters
    ----------
    ax : `matplotlib.pyplot.Axes`
        Axis on which the profile is plotted.
    profile : `object`
        Object exposing ``obj`` (the values), ``scale`` (the x coordinates)
        and ``name`` (the default label) attributes.
    normed : `bool` or number, optional
        When True (default) the values are divided by their own maximum. When
        False the values are plotted as they are. Any other value is used as
        the divisor.
    **kwargs : `dict`
        Forwarded to the matplotlib plot function. ``linestyle`` defaults to
        ``'-'`` and ``label`` defaults to the ``name`` attribute of the
        profile.

    Returns
    -------
    ax : `matplotlib.pyplot.Axes`
        The axis that was given.
    """
    if isinstance(normed, bool):
        # profile.norm is deliberately not invoked on the object in case its
        # height needs to be preserved. False has to skip the division:
        # bool instances used to always end up normalized, which made it
        # impossible to turn normalization off.
        obj = profile.obj / profile.obj.max() if normed else profile.obj
    else:
        obj = profile.obj / normed

    ls = kwargs.pop("linestyle", '-')
    lbl = kwargs.pop("label", profile.name)
    ax.plot(profile.scale, obj, label=lbl, linestyle=ls, **kwargs)

    return ax
def plot_profiles(ax, profiles, normed=True, **kwargs):
    """Plots all given profiles on a given axis by calling `plot_profile` on
    each of them.

    Parameters
    ----------
    ax : `matplotlib.pyplot.Axes`
        Axis on which the profiles are plotted.
    profiles : `list` or `tuple`
        Profiles to plot. Must support ``len``.
    normed : `bool` or number, optional
        Forwarded unchanged to `plot_profile` for every profile.
    **kwargs : `dict`
        The keys ``linestyles``, ``colors`` and ``labels`` are removed and
        interpreted as per-profile sequences; entries that evaluate to False
        are not forwarded, which leaves the `plot_profile` defaults in place.
        All remaining keywords are forwarded as-is for every profile.

    Returns
    -------
    ax : `matplotlib.pyplot.Axes`
        The axis that was given.

    Notes
    -----
    The profiles and the per-profile sequences are zipped together, so the
    shortest of them determines how many profiles are plotted.
    """
    unset = (False,) * len(profiles)
    lss = kwargs.pop("linestyles", unset)
    cls = kwargs.pop("colors", unset)
    lbls = kwargs.pop("labels", unset)

    # one keyword dictionary per profile, holding only the properties that
    # were actually provided
    plotprops = [
        {key: value
         for key, value in (("label", label),
                            ("color", color),
                            ("linestyle", linestyle))
         if value}
        for _, label, color, linestyle in zip(profiles, lbls, cls, lss)
    ]

    for profile, prop in zip(profiles, plotprops):
        plot_profile(ax, profile, normed, **prop, **kwargs)

    return ax
def set_ax_props(axes, xlims=(), xticks=(), xlabels=(), ylims=((-0.01, 1.1),),
                 ylabels=(),):
    """Sets the labels, ticks and limits on all pairs of axes, ticks and labels
    provided.

    Parameters
    ----------
    axes : `matplotlib.pyplot.Axes` or sequence of them
        Axes on which ticks, limits and labels will be set.
    xlims : `list` or `tuple`
        Nested list or tuple of x-axis limits of the plots per axis.
        Default: ().
    xticks : `list` or `tuple`
        Nested list or tuple of locations at which ticks and tick marks will
        be placed. Default: ().
    xlabels : `list`, `tuple`, `string` or generator expression
        List, tuple or an iterable that provide a label for each of the axes.
    ylims : `list` or `tuple`
        Nested list or tuple of y-axis limits of the plots per axis.
        Default: ((-0.01, 1.1),).
    ylabels : `list`, `tuple`, `string` or generator expression
        List, tuple or an iterable that provide a label for each of the axes.

    Returns
    -------
    axes : `tuple` or sequence
        The given axes; a single axis is returned wrapped in a tuple.

    Note
    ----
    Ticks and lims that are shorter than the number of axes will begin
    repeating themselves until they match the length of axes. The given
    objects themselves are never modified.
    Given ticks and lims *NEED* to be nested list or tuples (i.e. list of
    lists, list of tuples, tuple of lists...) in order to work properly.
    Otherwise only a singular limit can be extracted from them and the limits
    can not be set properly.
    Ticks and lims can be any iterable that supports Python's multiplication
    trick (i.e. [1, 2]*2 = [1, 2, 1, 2]). Given ticks and lims that have
    length zero will not be set.

    The labels are expected to always be given for each axis. When only a
    string is given an attempt will be made to inspect and ascertain which
    axes are shared and which ones are on the edge of the figure and only
    those axes will be labeled. This procedure, however, is susceptible to
    errors - often in situations where additional axes holding colorbars are
    given. In that situation best course of action is to provide the labels as
    expected, one for each axis, even when they are all the same.
    """
    def pad_prop(prop, minlen):
        """Given a prop and some numerical threshold returns
        (False, (prop,)*minlen) when the length of prop is zero and
        (True, prop) when the length of prop is different than zero. If a prop
        is shorter than the threshold it is repeated ``minlen`` times, making
        it at least as long (usually longer).
        """
        # catch generators as limit expressions
        try:
            len(prop)
        except TypeError:
            prop = list(prop)

        if len(prop) == 0:
            # a fake prop return prevents the zip from truncating valid props
            return False, (prop,) * minlen

        if len(prop) < minlen:
            # rebind instead of using `*=`: the in-place form would extend a
            # list owned by the caller
            prop = prop * minlen
        return True, prop

    if isinstance(axes, plt.Axes):
        axes = (axes,)
    axlen = len(axes)

    # padding ticks and lims could be memory expensive, so props are only
    # expanded when they are shorter than the number of axes
    setxticks, xticks = pad_prop(xticks, axlen)
    setxlims, xlims = pad_prop(xlims, axlen)
    setylims, ylims = pad_prop(ylims, axlen)

    # labels are complicated since not only do we need to know if we want to
    # set them, but for which axes we want to set them too.
    xlbls = (xlabels,) if isinstance(xlabels, str) else xlabels
    ylbls = (ylabels,) if isinstance(ylabels, str) else ylabels
    setxlabels, xlbls = pad_prop(xlbls, axlen)
    setylabels, ylbls = pad_prop(ylbls, axlen)

    # only clear-cut case happens when labels are given as lists matching the
    # axes list. Any other case is at least partially ambiguous. Still, attempt
    # is made to resolve the left/right-most axes only and label only those
    # when labels are pure strings. If this happens, setflags are left True
    # only for those axes that are not shared. This will fail if hidden axes,
    # f.e. such as colorbars, are added to the plot. It should, however, always
    # be safe to label border axes of the plot.
    if isinstance(xlabels, str):
        xgrouper = list(axes[0].get_shared_x_axes())
        setxlabels = len(xgrouper) == 0
        axes[-1].set_xlabel(xlabels)
    if isinstance(ylabels, str):
        ygrouper = list(axes[0].get_shared_y_axes())
        setylabels = len(ygrouper) == 0
        axes[0].set_ylabel(ylabels)

    # every padded prop is at least as long as axes, so a single pass over
    # the zip reaches each axis exactly once
    for (ax, ticks, xlim, ylim, xlbl, ylbl) in zip(axes, xticks, xlims, ylims,
                                                   xlbls, ylbls):
        if setxlabels:
            ax.set_xlabel(xlbl)
        if setylabels:
            ax.set_ylabel(ylbl)
        if setxticks:
            ax.set_xticks(ticks)
        if setxlims:
            ax.set_xlim(xlim)
        if setylims:
            ax.set_ylim(ylim)

    return axes
def get_or_create_data(filenames, samplerKwargs=None, samplers=None, cache=True,
                       **kwargs):
    """Retrieves data stored in filename(s) or creates the data and stores it
    at given location(s). If the datafiles are just filenames the files are
    read and written to lfd's cache.

    Parameters
    ----------
    filenames : `list`, `tuple` or `str`
        Singular or a list of filepath(s) to existing files, or filename(s) of
        already cached files.
    samplerKwargs : `list`, `tuple` or `dict`, optional
        Single or multiple dictionaries to pass to the sampler.
    samplers : `list`, `function`, optional
        Sampler(s) to invoke. A single callable is used for every file. By
        default `generic_sampler` is used, in which case the ``returnType``
        keyword is set to ``"grid"``.
    cache : `bool`
        If True, newly created data will be cached under the given filename.
    **kwargs : `dict`
        Optional. Any keywords are forwarded to the sampler.

    Returns
    -------
    data : `list`
        One entry per processed file: the data that was read or, when the
        file was not found, the data returned by the sampler.

    Notes
    -----
    There is no functional difference between providing a single
    ``samplerKwargs`` dictionary or giving the arguments as "plain old"
    kwargs. When ``samplerKwargs`` are a list, however, it is iterated over,
    and each element is passed as kwargs to the sampler, while kwargs are
    passed to the sampler as-is on every iteration.

    Filenames, sampler keywords and samplers are zipped together, so when
    lists are given the shortest one determines how many files are processed.
    """
    filenames = (filenames, ) if isinstance(filenames, str) else filenames

    if samplers is None:
        samplers = itertools.cycle((generic_sampler, ))
        kwargs["returnType"] = "grid"
    elif callable(samplers):
        # a bare function can not be zipped over; reuse it for every file
        samplers = itertools.cycle((samplers, ))

    if samplerKwargs is None:
        # an empty dictionary stands in for "no extra keywords" so that the
        # sampler call below is the same in all cases
        smplrKw = itertools.cycle(({}, ))
    elif isinstance(samplerKwargs, dict):
        smplrKw = itertools.cycle((samplerKwargs, ))
    else:
        smplrKw = samplerKwargs

    data = []
    for fname, samplerKwarg, sampler in zip(filenames, smplrKw, samplers):
        try:
            dat = utils.get_data(fname)
        except FileNotFoundError:
            warnings.warn(f"Creating data file: '{fname}' - this might take a while.")
            dat = sampler(**samplerKwarg, **kwargs)
            if cache:
                utils.cache_data(dat, fname)
        data.append(dat)

    return data