"""
Functions for generating plots from model outputs
This module implements Atomica's plotting library, which is used to
generate various plots from model outputs.
"""
import itertools
import os
import errno
from collections import defaultdict
from pandas import isna
import numpy as np
import scipy.interpolate
import scipy.integrate
import matplotlib.cm as cmx
import matplotlib.colors as matplotlib_colors
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.legend import Legend
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle, Patch
from matplotlib.ticker import FuncFormatter
import atomica
import sciris as sc
from .model import Compartment, Characteristic, Parameter, Link, SourceCompartment, JunctionCompartment, SinkCompartment
from .results import Result
from .system import logger, NotFoundError
from .function_parser import parse_function
from .system import FrameworkSettings as FS
from .utils import format_duration, nested_loop
__all__ = ["save_figs", "PlotData", "Series", "plot_bars", "plot_series", "plot_legend", "reorder_legend", "relabel_legend"]
settings = dict()
settings["legend_mode"] = "together" # Possible options are ['together','separate','none']
settings["bar_width"] = 1.0 # Width of bars in plot_bars()
settings["line_width"] = 3.0 # Width of lines in plot_series()
settings["marker_edge_width"] = 3.0
settings["dpi"] = 150 # average quality
settings["transparent"] = False
def save_figs(figs, path=".", prefix="", fnames=None, file_format="png") -> None:
"""
Save figures to disk as PNG or other graphics format files
Functions like `plot_series` and `plot_bars` can generate multiple figures, depending on
the data and legend options. This function facilitates saving those figures together.
The name for the file can be automatically selected when saving figures generated
by `plot_series` and `plot_bars`. This function also deals with cases where the figure
list may or may not contain a separate legend (so saving figures with this function means
the legend mode can be changed freely without having to change the figure saving code).
:param figs: A figure or list of figures
:param path: Optionally specify an output directory for the saved figures (it will be created if it does not exist)
:param prefix: Optionally prepend a prefix to the file name
:param fnames: Optionally an array of file names. By default, each figure is named
using its 'label' property. If a figure has an empty 'label' string it is assumed to be
a legend and will be named based on the name of the figure immediately before it.
If you provide an empty string in the `fnames` argument this same operation will be carried
out. If the last figure name is omitted, an empty string will automatically be added.
:param file_format: the file format to save as, default png, allowed formats {png, ps, pdf, svg}
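Example usage (a sketch, assuming ``result`` is a completed :class:`Result`):
>>> figs = plot_series(PlotData(result))
>>> save_figs(figs, path="./figures", prefix="baseline_")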
"""
try:
os.makedirs(path)
except OSError as err:
if err.errno != errno.EEXIST:
raise
# Sanitize fig array input
if not isinstance(figs, list):
figs = [figs]
# Sanitize and populate default fnames values
if fnames is None:
fnames = [fig.get_label() for fig in figs]
elif not isinstance(fnames, list):
fnames = [fnames]
# Add legend figure to the end
if len(fnames) < len(figs):
fnames.append("")
assert len(fnames) == len(figs), "Number of figures must match number of specified filenames, or the last figure must be a legend with no label"
assert fnames[0], "The first figure name cannot be empty"
assert file_format in ["png", "ps", "pdf", "svg"], f'File format {file_format} invalid. Format must be one of "png", "ps", "pdf", or "svg"'
for i, fig in enumerate(figs):
if not fnames[i]: # assert above means that i>0
fnames[i] = fnames[i - 1] + "_legend"
legend = fig.findobj(Legend)[0]
renderer = fig.canvas.get_renderer()
fig.draw(renderer=renderer)
bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
else:
bbox = "tight"
fname = prefix + fnames[i] + "." + file_format
fname = sc.sanitizefilename(fname) # parameters may have inappropriate characters
fig.savefig(os.path.join(path, fname), bbox_inches=bbox, dpi=settings["dpi"], transparent=settings["transparent"])
logger.info('Saved figure "%s"', fname)
class PlotData:
"""
Process model outputs into plottable quantities
This is what gets passed into a plotting function, which displays a View of the data
Conceptually, we are applying visuals to the data.
But we are performing an extraction step rather than doing it directly because things like
labels, colours, groupings etc. only apply to plots, not to results, and there could be several
different views of the same data.
Operators for ``-`` and ``/`` are defined to facilitate looking at differences and relative
differences of derived quantities (quantities computed using ``PlotData`` operations) across
individual results. To keep the implementation tractable, they don't generalize further than that,
and operators ``+`` and ``*`` are not implemented because these operations rarely make sense
for the data being operated on.
:param results: Specify which results to plot. Can be
- a Result,
- a list of Results,
- a dict/odict of results (the name of the result is taken from the Result, not the dict)
:param outputs: The name of an output compartment, characteristic, or
parameter, or list of names. Inside a list, a dict can be given to
specify an aggregation e.g. ``outputs=['sus',{'total':['sus','vac']}]``
where the key is the new name. Or, a formula can be given which will
be evaluated by looking up labels within the model object. Links will
automatically be summed over
:param pops: The name of an output population, or list of names. Like
outputs, can specify a dict with a list of pops to aggregate over them
:param output_aggregation: If an output aggregation is requested, combine the outputs listed using one of
- 'sum' - just add values together
- 'average' - unweighted average of quantities
- 'weighted' - weighted average where the weight is the
compartment size, characteristic value, or link source
compartment size (summed over duplicate links). 'weighted'
method cannot be used with non-transition parameters and a
KeyError will result in that case
:param pop_aggregation: Same as output_aggregation, except that 'weighted'
uses population sizes. Note that output aggregation is performed
before population aggregation. This also means that population
aggregation can be used to combine already aggregated outputs (e.g.
can first sum 'sus'+'vac' within populations, and then take weighted
average across populations)
:param project: Optionally provide a :class:`Project` object, which will be used to convert names to labels in the outputs for plotting.
:param time_aggregation: Optionally specify time aggregation method. Supported methods are 'integrate' and 'average' (no weighting). When aggregating
times, *non-annualized* flow rates will be used.
:param t_bins: Optionally specify time bins, which will enable time aggregation. Supported inputs are
- A vector of bin edges. Time points are included if the time
is >= the lower bin value and < upper bin value.
- A scalar bin size (e.g. 5) which will be expanded to a vector spanning the data
- The string 'all', which maps to bin edges ``[-inf, inf]``, aggregating over all time
:param accumulate: Optionally accumulate outputs over time. Can be 'sum' or 'integrate' to either sum quantities or integrate by multiplying by the timestep. Accumulation happens *after* time aggregation.
The logic is extremely simple - the quantities in the Series pass through ``cumsum``. If 'integrate' is selected, then the quantities are multiplied
by ``dt`` and the units are multiplied by ``years``
:return: A :class:`PlotData` instance that can be passed to :func:`plot_series` or :func:`plot_bars`
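Example usage (a sketch, assuming ``result`` is a completed :class:`Result` with outputs 'sus' and 'vac'):
>>> d = PlotData(result, outputs=["sus", {"total": ["sus", "vac"]}])
>>> d = PlotData(result, outputs="sus", pops="total", t_bins=5, time_aggregation="average")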
.. automethod:: __getitem__
"""
# TODO: Make sure to chuck a useful error when t_bins is greater than sim duration, rather than just crashing.
def __init__(self, results, outputs=None, pops=None, output_aggregation=None, pop_aggregation=None, project=None, time_aggregation=None, t_bins=None, accumulate=None):
# Validate inputs
if isinstance(results, sc.odict):
results = [result for _, result in results.items()]
elif not isinstance(results, list):
results = [results]
result_names = [x.name for x in results]
if len(set(result_names)) != len(result_names):
raise Exception("Results must have different names (in their result.name property)")
if pops in [None, "all"]:
pops = [pop.name for pop in results[0].model.pops]
elif pops == "total":
pops = [{"Total": [pop.name for pop in results[0].model.pops]}]
pops = sc.promotetolist(pops)
if outputs is None:
outputs = [comp.name for comp in results[0].model.pops[0].comps if not isinstance(comp, (SourceCompartment, JunctionCompartment, SinkCompartment))]
elif not isinstance(outputs, list):
outputs = [outputs]
pops = _expand_dict(pops)
outputs = _expand_dict(outputs)
assert output_aggregation in [None, "sum", "average", "weighted"]
assert pop_aggregation in [None, "sum", "average", "weighted"]
# First, get all of the pops and outputs requested by flattening the lists
pops_required = _extract_labels(pops)
outputs_required = _extract_labels(outputs)
self.series = []
tvecs = dict()
# Because aggregations always occur within a Result object, loop over results
for result in results:
result_label = result.name
tvecs[result_label] = result.model.t
dt = result.model.dt
aggregated_outputs = defaultdict(dict) # Dict with aggregated_outputs[pop_label][aggregated_output_label]
aggregated_units = dict() # Dict with aggregated_units[aggregated_output_label]
aggregated_timescales = dict()
output_units = dict()
output_timescales = dict()
compsize = dict()
popsize = dict()
# Defaultdict won't throw key error when checking outputs.
data_label = defaultdict(str) # Label used to identify which data to plot, maps output label to data label.
# Aggregation over outputs takes place first, so loop over pops
for pop_label in pops_required:
pop = result.model.get_pop(pop_label)
popsize[pop_label] = pop.popsize()
data_dict = dict() # Temporary storage for raw outputs
# First pass, extract the original output quantities, summing links and annualizing as required
for output_label in outputs_required:
try:
vars = pop.get_variable(output_label)
except NotFoundError as e:
in_pops = [x.name for x in result.model.pops if output_label in x]
message = f'Variable "{output_label}" was requested in population "{pop.name}" but it is only defined in these populations: {in_pops}'
raise NotFoundError(message) from e
if vars[0].vals is None:
raise Exception('Requested output "%s" was not recorded because only partial results were saved' % (vars[0].name))
if isinstance(vars[0], Link):
data_dict[output_label] = np.zeros(tvecs[result_label].shape)
compsize[output_label] = np.zeros(tvecs[result_label].shape)
for link in vars:
data_dict[output_label] += link.vals
compsize[output_label] += link.source.vals if not isinstance(link.source, JunctionCompartment) else link.source.outflow
# Annualize the units, and record that they correspond to a flow per year
data_dict[output_label] /= dt
output_units[output_label] = vars[0].units
output_timescales[output_label] = 1.0
data_label[output_label] = vars[0].parameter.name if (vars[0].parameter and vars[0].parameter.units == FS.QUANTITY_TYPE_NUMBER) else None # Only use parameter data points if the units match
elif isinstance(vars[0], Parameter):
data_dict[output_label] = vars[0].vals
output_units[output_label] = vars[0].units
output_timescales[output_label] = vars[0].timescale # The timescale attribute for non-transition parameters will already be set to None
data_label[output_label] = vars[0].name
# If there are links, we can retrieve a compsize for the user to do a weighted average
if vars[0].links:
output_units[output_label] = vars[0].units
compsize[output_label] = np.zeros(tvecs[result_label].shape)
for link in vars[0].links:
compsize[output_label] += link.source.vals if not isinstance(link.source, JunctionCompartment) else link.source.outflow
elif isinstance(vars[0], (Compartment, Characteristic)):
data_dict[output_label] = vars[0].vals
compsize[output_label] = vars[0].vals
output_units[output_label] = vars[0].units
output_timescales[output_label] = None
data_label[output_label] = vars[0].name
else:
raise Exception("Unknown type")
# Second pass, add in any dynamically computed quantities
# Using model. Parameter objects will automatically sum over Links and convert Links
# to annualized rates
for output in outputs:
if not isinstance(output, dict):
continue
output_label, f_stack_str = list(output.items())[0] # _extract_labels has already ensured only one key is present
if not sc.isstring(f_stack_str):
continue
def placeholder_pop():
return None
placeholder_pop.name = "None"
par = Parameter(pop=placeholder_pop, name=output_label)
fcn, dep_labels = parse_function(f_stack_str)
deps = {}
displayed_annualization_warning = False
for dep_label in dep_labels:
vars = pop.get_variable(dep_label)
if t_bins is not None and (isinstance(vars[0], Link) or isinstance(vars[0], Parameter)) and time_aggregation == "sum" and not displayed_annualization_warning:
raise Exception("Function includes Parameter/Link so annualized rates are being used. Aggregation should therefore use 'average' rather than 'sum'.")
deps[dep_label] = vars
par._fcn = fcn
par.deps = deps
par.preallocate(tvecs[result_label], dt)
par.update()
data_dict[output_label] = par.vals
output_units[output_label] = par.units
output_timescales[output_label] = None
# Third pass, aggregate them according to any aggregations present
for output in outputs: # For each final output
if isinstance(output, dict):
output_name = list(output.keys())[0]
labels = output[output_name]
# If this was a function, aggregation over outputs doesn't apply so just put it straight in.
if sc.isstring(labels):
aggregated_outputs[pop_label][output_name] = data_dict[output_name]
aggregated_units[output_name] = "unknown" # Also, we don't know what the units of a function are
aggregated_timescales[output_name] = None # Timescale is lost
continue
units = list(set([output_units[x] for x in labels]))
timescales = list(set([np.nan if isna(output_timescales[x]) else output_timescales[x] for x in labels])) # Ensure that None and nan don't appear as different timescales
# Set default aggregation method depending on the units of the quantity
if output_aggregation is None:
if units[0] in ["", FS.QUANTITY_TYPE_FRACTION, FS.QUANTITY_TYPE_PROPORTION, FS.QUANTITY_TYPE_PROBABILITY, FS.QUANTITY_TYPE_RATE]:
output_aggregation = "average"
else:
output_aggregation = "sum"
if len(units) > 1:
logger.warning("Aggregation for output '%s' is mixing units, this is almost certainly not desired.", output_name)
aggregated_units[output_name] = "unknown"
else:
if units[0] in ["", FS.QUANTITY_TYPE_FRACTION, FS.QUANTITY_TYPE_PROPORTION, FS.QUANTITY_TYPE_PROBABILITY, FS.QUANTITY_TYPE_RATE] and output_aggregation == "sum" and len(labels) > 1: # Dimensionless, like prevalance
logger.warning("Output '%s' is not in number units, so output aggregation probably should not be 'sum'.", output_name)
aggregated_units[output_name] = output_units[labels[0]]
if len(timescales) > 1:
logger.warning("Aggregation for output '%s' is mixing timescales, this is almost certainly not desired.", output_name)
aggregated_timescales[output_name] = None
else:
aggregated_timescales[output_name] = output_timescales[labels[0]]
if output_aggregation == "sum":
aggregated_outputs[pop_label][output_name] = sum(data_dict[x] for x in labels) # Add together all the outputs
elif output_aggregation == "average":
aggregated_outputs[pop_label][output_name] = sum(data_dict[x] for x in labels) # Add together all the outputs
aggregated_outputs[pop_label][output_name] /= len(labels)
elif output_aggregation == "weighted":
aggregated_outputs[pop_label][output_name] = sum(data_dict[x] * compsize[x] for x in labels) # Add together all the outputs
aggregated_outputs[pop_label][output_name] /= sum([compsize[x] for x in labels])
else:
aggregated_outputs[pop_label][output] = data_dict[output]
aggregated_units[output] = output_units[output]
aggregated_timescales[output] = output_timescales[output]
# Now aggregate over populations
# If we have requested a reduction over populations, this is done for every output present
for pop in pops: # This is looping over the population entries
for output_name in aggregated_outputs[list(aggregated_outputs.keys())[0]].keys():
if isinstance(pop, dict):
pop_name = list(pop.keys())[0]
pop_labels = pop[pop_name]
# Set default population aggregation method depending on the units of the quantity
if pop_aggregation is None:
if aggregated_units[output_name] in ["", FS.QUANTITY_TYPE_FRACTION, FS.QUANTITY_TYPE_PROPORTION, FS.QUANTITY_TYPE_PROBABILITY, FS.QUANTITY_TYPE_RATE]:
pop_aggregation = "average"
else:
pop_aggregation = "sum"
if pop_aggregation == "sum":
if aggregated_units[output_name] in ["", FS.QUANTITY_TYPE_FRACTION, FS.QUANTITY_TYPE_PROPORTION, FS.QUANTITY_TYPE_PROBABILITY, FS.QUANTITY_TYPE_RATE] and len(pop_labels) > 1:
logger.warning("Output '%s' is not in number units, so population aggregation probably should not be 'sum'", output_name)
vals = sum(aggregated_outputs[x][output_name] for x in pop_labels) # Add together all the outputs
elif pop_aggregation == "average":
vals = sum(aggregated_outputs[x][output_name] for x in pop_labels) # Add together all the outputs
vals /= len(pop_labels)
elif pop_aggregation == "weighted":
numerator = sum(aggregated_outputs[x][output_name] * popsize[x] for x in pop_labels) # Add together all the outputs
denominator = sum([popsize[x] for x in pop_labels])
vals = np.divide(numerator, denominator, out=np.full(numerator.shape, np.nan, dtype=float), where=denominator != 0) # NaN wherever the total population size is zero
else:
raise Exception("Unknown pop aggregation method")
self.series.append(Series(tvecs[result_label], vals, result_label, pop_name, output_name, data_label[output_name], units=aggregated_units[output_name], timescale=aggregated_timescales[output_name], data_pop=pop_name))
else:
vals = aggregated_outputs[pop][output_name]
self.series.append(Series(tvecs[result_label], vals, result_label, pop, output_name, data_label[output_name], units=aggregated_units[output_name], timescale=aggregated_timescales[output_name], data_pop=pop))
self.results = sc.odict()
for result in results:
self.results[result.name] = result.name
self.pops = sc.odict()
for pop in pops:
key = list(pop.keys())[0] if isinstance(pop, dict) else pop
self.pops[key] = _get_full_name(key, project) if project is not None else key
self.outputs = sc.odict()
for output in outputs:
key = list(output.keys())[0] if isinstance(output, dict) else output
self.outputs[key] = _get_full_name(key, project) if project is not None else key
# Handle time aggregation
if t_bins is not None:
self.time_aggregate(t_bins, time_aggregation)
if accumulate is not None:
self.accumulate(accumulate)
def accumulate(self, accumulation_method) -> None:
"""
Accumulate values over time
:param accumulation_method: Select whether to add or integrate. Supported methods are:
- 'sum' : runs `cumsum` on all quantities - should not be used if units are flow rates (so will check for a timescale).
Summation should be used for compartment-based quantities, such as DALYs
- 'integrate' : integrate using trapezoidal rule, assuming initial value of 0
Note that here there is no concept of 'dt' because we might have non-uniform time aggregation bins
Therefore, we need to use the time vector actually contained in the Series object (via `cumtrapz()`)
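Example usage (a sketch, assuming 'sus' is a compartment in number units):
>>> d = PlotData(result, outputs="sus")
>>> d.accumulate("integrate")  # Units become 'Number of person-years'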
"""
# Note, in general we need to be able to explicitly specify the method to use, because we don't
# know how to deal with parameter functions that have unknown units
assert accumulation_method in ["sum", "integrate"]
for s in self.series:
if accumulation_method == "sum":
if not isna(s.timescale):
raise Exception('Quantity "%s" has timescale %g which means it should be accumulated by integration, not summation' % (s.output, s.timescale))
s.vals = np.cumsum(s.vals)
elif accumulation_method == "integrate":
if s.timescale:
s.vals = scipy.integrate.cumulative_trapezoid(s.vals, s.tvec / s.timescale)
else:
s.vals = scipy.integrate.cumulative_trapezoid(s.vals, s.tvec)
s.vals = np.insert(s.vals, 0, 0.0)
# If integrating a quantity with a timescale, then lose the timescale factor
# Otherwise, the units pick up a factor of time
if not isna(s.timescale):
s.timescale = None
else:
if s.units == "Number of people":
s.units = "Number of person-years"
else:
s.units += " years"
else:
raise Exception("Unknown accumulation type")
for k, v in self.outputs.items():
self.outputs[k] = "Cumulative " + v
def time_aggregate(self, t_bins, time_aggregation=None, interpolation_method=None):
"""
Aggregate values over time
Note that *accumulation* is a running total, whereas *aggregation* refers to binning. The two can
both be applied (aggregation should be performed prior to accumulation).
Normally, aggregation is performed when constructing a `PlotData` instance and this method does not need
to be manually called. However, in rare cases, it may be necessary to explicitly set the interpolation method.
Specifically, the interpolation method needs to match the underlying assumption for parameter values. For
parameter scenarios, this may require that the 'previous' method is used (to match the assumption in the parameter overwrite)
rather than relying on the standard assumption that databook quantities can be interpolated directly.
This method modifies the `PlotData` object in-place. However, the modified object is also returned, so that
time aggregation can be chained with other operations, the same as `PlotData.interpolate()`.
:param t_bins: Vector of bin edges OR a scalar bin size, which will be automatically expanded to a vector of bin edges
:param time_aggregation: can be 'integrate' or 'average'. Note that for quantities that have a timescale, flow parameters
in number units will be adjusted accordingly (e.g. a parameter in units of 'people/day'
aggregated over a 1 year period will display as the equivalent number of people that year)
:param interpolation_method: Assumption on how the quantity behaves in between timesteps - in general, 'linear' should be suitable for
most dynamic quantities, while 'previous' should be used for spending and other program-related quantities.
:return: The same modified `PlotData` instance
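Example usage (a sketch, assuming ``result`` is a completed :class:`Result`):
>>> d = PlotData(result, outputs="sus").time_aggregate(5, "average")  # 5-year bins
>>> d2 = PlotData(result, outputs="sus").time_aggregate("all", "integrate")  # Single bin over all times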
"""
assert time_aggregation in [None, "integrate", "average"]
assert interpolation_method in [None, "linear", "previous"]
if interpolation_method is None:
interpolation_method = "linear"
if not hasattr(t_bins, "__len__"):
# If a scalar bin size is provided, expand it to a vector of bin edges spanning the data
if t_bins > (self.series[0].tvec[-1] - self.series[0].tvec[0]):
# If bin width is greater than the sim duration, treat it the same as aggregating over all times
t_bins = "all"
else:
if not (self.series[0].tvec[-1] - self.series[0].tvec[0]) % t_bins:
upper = self.series[0].tvec[-1] + t_bins
else:
upper = self.series[0].tvec[-1]
t_bins = np.arange(self.series[0].tvec[0], upper, t_bins)
elif len(t_bins) < 2:
raise Exception("If passing in t_bins as a list of bin edges, at least two values must be provided")
# Record whether we are aggregating over all times, because `t_bins` is converted to an array below
aggregate_over_all_times = sc.isstring(t_bins) and t_bins == "all"
if aggregate_over_all_times:
t_bins = self.series[0].tvec[[0, -1]].ravel()
t_bins = sc.promotetoarray(t_bins)
lower = t_bins[0:-1]
upper = t_bins[1:]
for s in self.series:
# Decide automatic aggregation method if not specified - this is done on a per-quantity basis
if time_aggregation is None:
if s.units in {FS.QUANTITY_TYPE_DURATION, FS.QUANTITY_TYPE_PROBABILITY, FS.QUANTITY_TYPE_RATE, FS.QUANTITY_TYPE_PROPORTION, FS.QUANTITY_TYPE_FRACTION}:
method = "average"
else:
method = "integrate"
else:
method = time_aggregation
if method == "integrate" and s.units in {FS.QUANTITY_TYPE_DURATION, FS.QUANTITY_TYPE_PROBABILITY, FS.QUANTITY_TYPE_RATE, FS.QUANTITY_TYPE_PROPORTION, FS.QUANTITY_TYPE_FRACTION}:
logger.warning('Units for series "%s" are "%s" so time aggregation should probably be "average", not "integrate"', s, s.units)
if not isna(s.timescale):
scale = s.timescale
else:
scale = 1.0
# We interpolate in time-aggregation because the time bins are independent of the step size. In contrast,
# accumulation preserves the same time bins, so we don't need the interpolation step and instead go straight
# to summation or trapezoidal integration
max_step = 0.5 * min(np.diff(s.tvec)) # Subdivide for trapezoidal integration with at least 2 divisions per timestep. Could be a lot of memory for integrating daily timesteps over a full simulation, but unlikely to be prohibitive
vals = np.full(lower.shape, fill_value=np.nan)
for i, (l, u) in enumerate(zip(lower, upper)):
n = np.ceil((u - l) / max_step) + 1 # Add 1 so that in most cases, we can use the actual timestep values
t2 = np.linspace(l, u, int(n))
if interpolation_method == "linear":
v2 = np.interp(t2, s.tvec, s.vals, left=np.nan, right=np.nan) # Return NaN outside bounds - it should never be valid to use extrapolated output values in time aggregation
vals[i] = np.trapz(y=v2 / scale, x=t2) # Note division by timescale here, which annualizes it
elif interpolation_method == "previous":
v2 = scipy.interpolate.interp1d(s.tvec, s.vals, kind="previous", copy=False, assume_sorted=True, bounds_error=False, fill_value=(np.nan, np.nan))(t2)
vals[i] = sum(v2[:-1] / scale * np.diff(t2))
s.tvec = (lower + upper) / 2.0
if method == "integrate":
s.vals = np.array(vals)
# If integrating the units might change
if not isna(s.timescale):
# Any flow rates get integrated over the bin width, so change the timescale to None
# If the units were 'duration', this doesn't make sense, but integrating a duration doesn't
# make sense either. This would only happen if the user explicitly requests it anyway. For example,
# a parameter might go from 'number of people per month' to 'number of people'
s.timescale = None
else:
# For quantities that don't have a timescale and are being integrated, the scale is 1.0 and
# it picks up 'years' in the units. So for example, 'number of people' becomes 'number of person years'
# This would be the usage 99% of the time (esp. for DALYs that are interested in number of person-years)
if s.units == "Number of people":
s.units = "Number of person-years"
elif not isna(s.units):
s.units += " years"
else:
# If the units are none, decide what to do. It probably makes sense just to do nothing and
# leave the units blank, on the assumption that the user knows what they are doing if they
# are working with dimensionless quantities. More commonly, the quantity wouldn't actually
# be dimensionless, but it might not have had units entered e.g. parameter functions
pass
elif method == "average":
s.vals = np.array(vals) / np.diff(t_bins / scale) # Divide by bin width if averaging within the bins
s.units = "Average %s" % (s.units) # It will look odd to do 'Cumulative Average Number of people' but that's will accurately what the user has requested (combining aggregation and accumulation is permitted, but not likely to be necessary)
else:
raise Exception('Unknown time aggregation type "%s"' % (time_aggregation))
if aggregate_over_all_times:
s.t_labels = ["All"]
else:
s.t_labels = ["%d-%d" % (low, high) for low, high in zip(lower, upper)]
return self
def __repr__(self):
s = "PlotData\n"
s += "Results: {0}\n".format(self.results.keys())
s += "Pops: {0}\n".format(self.pops.keys())
s += "Outputs: {0}\n".format(self.outputs.keys())
return s
def __sub__(self, other):
"""
Difference between two instances
This function iterates over all Series and takes their difference.
The intended functionality is when wanting to compute the difference
of derived quantities between two results. It is only well defined when
the only difference between the two PlotData instances is the result they were
constructed from. For example, typical usage would be
>>> a = PlotData(result1, outputs, pops)
>>> b = PlotData(result2, outputs, pops)
>>> c = a-b
Both PlotData instances must have
- The same pops
- The same outputs
- The same units (i.e. the same aggregation steps)
- The same time points
This method also incorporates singleton expansion for results, which means that one or both
of the PlotData instances can contain a single result instead of multiple results. The single
result will be applied against all of the results in the other PlotData instance, so for example
a single baseline result can be subtracted off a set of scenarios. Note that if both PlotData instances
have more than one result, then an error will be raised (because the result names don't have to match,
it is otherwise impossible to identify which pairs of results to subtract).
Series will be copied either from the PlotData instance that has multiple Results, or from the left :class:`PlotData` instance
if both instances have only one result. Thus, ensure that ordering, formatting, and
labels are set in advance on the appropriate object, if preserving the formatting is important. In practice,
it would usually be best to operate on the :class:`PlotData` values first, before setting formatting etc.
:param other: A :class:`PlotData` instance to subtract off
:return: A new :class:`PlotData` instance
"""
assert isinstance(other, self.__class__), "PlotData subtraction can only operate on another PlotData instance"
assert set(self.pops) == set(other.pops), "PlotData subtraction requires both instances to have the same populations"
assert set(self.outputs) == set(other.outputs), "PlotData subtraction requires both instances to have the same outputs"
assert np.array_equal(self.tvals()[0], other.tvals()[0])
if len(self.results) > 1 and len(other.results) > 1:
raise Exception("When subtracting PlotData instances, both of them cannot have more than one result")
elif len(other.results) > 1:
new = sc.dcp(other)
else:
new = sc.dcp(self)
new.results = sc.odict()
for s1 in new.series:
if len(other.results) > 1:
s2 = self[self.results[0], s1.pop, s1.output]
else:
s2 = other[other.results[0], s1.pop, s1.output]
assert s1.units == s2.units
assert s1.timescale == s2.timescale
if len(other.results) > 1:
# If `b` has more than one result, then `s1` is from `b` and `s2` is from `a`, so the values for `a-b` are `s2-s1`
s1.vals = s2.vals - s1.vals
s1.result = "%s-%s" % (s2.result, s1.result)
else:
s1.vals = s1.vals - s2.vals
s1.result = "%s-%s" % (s1.result, s2.result)
new.results[s1.result] = s1.result
return new
def __truediv__(self, other):
"""
Divide two instances
This function iterates over all Series and divides them. The original intention
is to use this functionality when wanting to compute fractional differences between
instances. It is only well defined when the only difference between the two PlotData
instances is the result they were constructed from. For example, typical usage would be
>>> a = PlotData(result1, outputs, pops)
>>> b = PlotData(result2, outputs, pops)
>>> c = (a-b)/a
Both PlotData instances must have
- The same pops
- The same outputs
- The same units (i.e. the same aggregation steps)
- The same time points
Series will be copied either from the PlotData instance that has multiple Results, or from the left :class:`PlotData` instance
if both instances have only one result. Thus, ensure that ordering, formatting, and
labels are set in advance on the appropriate object, if preserving the formatting is important. In practice,
it would usually be best to operate on the :class:`PlotData` values first, before setting formatting etc.
:param other: A :class:`PlotData` instance to serve as denominator in division
:return: A new :class:`PlotData` instance
"""
assert isinstance(other, self.__class__), "PlotData division can only operate on another PlotData instance"
assert set(self.pops) == set(other.pops), "PlotData division requires both instances to have the same populations"
assert set(self.outputs) == set(other.outputs), "PlotData division requires both instances to have the same outputs"
assert np.array_equal(self.tvals()[0], other.tvals()[0])
if len(self.results) > 1 and len(other.results) > 1:
raise Exception("When subtracting PlotData instances, both of them cannot have more than one result")
elif len(other.results) > 1:
new = sc.dcp(other)
else:
new = sc.dcp(self)
new.results = sc.odict()
for s1 in new.series:
if len(other.results) > 1:
s2 = self[self.results[0], s1.pop, s1.output]
else:
s2 = other[other.results[0], s1.pop, s1.output]
assert s1.units == s2.units
assert s1.timescale == s2.timescale
if len(other.results) > 1:
# If `b` has more than one result, then `s1` is from `b` and `s2` is from `a`, so the values for `a/b` are `s2/s1`
s1.vals = s2.vals / s1.vals
s1.result = "%s/%s" % (s2.result, s1.result)
else:
s1.vals = s1.vals / s2.vals
s1.result = "%s/%s" % (s1.result, s2.result)
s1.units = ""
new.results[s1.result] = s1.result
return new
@staticmethod
def programs(results, outputs=None, t_bins=None, quantity="spending", accumulate=None, nan_outside=False):
"""
Constructs a PlotData instance from program values
This alternate constructor can be used to plot program-related quantities such as spending or coverage.
:param results: single Result, or list of Results
:param outputs: specification of which programs to plot spending for. Can be:
- the name of a single program
- a list of program names
- aggregation dict e.g. {'treatment':['tx-1','tx-2']} or list of such dicts. Output aggregation type is automatically 'sum' for
program spending, and aggregation is NOT permitted for coverages (due to modality interactions)
:param t_bins: aggregate over time, using summation for spending and number coverage, and average for fraction/proportion coverage. Notice that
unlike the `PlotData()` constructor, this function does _not_ allow the time aggregation method to be manually set.
:param quantity: can be 'spending', 'equivalent_spending', 'coverage_number', 'coverage_capacity', 'coverage_eligible', or 'coverage_fraction'. The 'coverage_eligible' is
the sum of compartments reached by a program, such that coverage_fraction = coverage_number/coverage_eligible
:param accumulate: can be 'sum' or 'integrate'
:param nan_outside: If True, then values will be NaN outside the program start/stop year
:return: A new :class:`PlotData` instance
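Example usage (a sketch; the program names 'tx-1' and 'tx-2' are placeholders):
>>> d = PlotData.programs(result, quantity="spending")
>>> d = PlotData.programs(result, outputs={"treatment": ["tx-1", "tx-2"]}, t_bins=10)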
"""
# Sanitize the results input
if isinstance(results, sc.odict):
results = [result for _, result in results.items()]
elif isinstance(results, Result):
results = [results]
result_names = [x.name for x in results]
if len(set(result_names)) != len(result_names):
raise Exception("Results must have different names (in their result.name property)")
for result in results:
if result.model.progset is None:
raise Exception('Tried to plot program outputs for result "%s", but that result did not use programs' % result.name)
if outputs is None:
outputs = results[0].model.progset.programs.keys()
elif not isinstance(outputs, list):
outputs = [outputs]
outputs = _expand_dict(outputs)
assert quantity in ["spending", "equivalent_spending", "coverage_number", "coverage_eligible", "coverage_fraction", "coverage_capacity"]
# Make a new PlotData instance
# We are using __new__ because this method is to be formally considered an alternate constructor and
# thus bears responsibility for ensuring this new instance is initialized correctly
plotdata = PlotData.__new__(PlotData)
plotdata.series = []
# Because aggregations always occur within a Result object, loop over results
for result in results:
if quantity == "spending":
all_vals = result.get_alloc()
units = result.model.progset.currency
timescales = dict.fromkeys(all_vals, 1.0)
elif quantity == "equivalent_spending":
all_vals = result.get_equivalent_alloc()
units = result.model.progset.currency
timescales = dict.fromkeys(all_vals, 1.0)
elif quantity in {"coverage_capacity", "coverage_number"}:
if quantity == "coverage_capacity":
all_vals = result.get_coverage("capacity")
else:
all_vals = result.get_coverage("number")
units = "Number of people"
timescales = dict.fromkeys(all_vals, 1.0)
elif quantity == "coverage_eligible":
all_vals = result.get_coverage("eligible")
units = "Number of people"
timescales = dict.fromkeys(all_vals, None)
elif quantity == "coverage_fraction":
all_vals = result.get_coverage("fraction")
units = "Fraction covered"
timescales = dict.fromkeys(all_vals, None)
else:
raise Exception("Unknown quantity")
for output in outputs: # For each final output
if isinstance(output, dict): # If this is an aggregation over programs
if quantity in ["spending", "equivalent_spending"]:
output_name = list(output.keys())[0] # This is the aggregated name
labels = output[output_name] # These are the quantities being aggregated
# We only support summation for combining program spending, not averaging
vals = sum(all_vals[x] for x in labels)
data_label = None # No data present for aggregations
timescale = timescales[labels[0]]
else:
raise Exception("Cannot use program aggregation for anything other than spending yet")
else:
vals = all_vals[output]
output_name = output
data_label = output # Can look up program spending by the program name
timescale = timescales[output]
if nan_outside:
vals[(result.t < result.model.program_instructions.start_year) | (result.t > result.model.program_instructions.stop_year)] = np.nan
plotdata.series.append(Series(result.t, vals, result=result.name, pop=FS.DEFAULT_SYMBOL_INAPPLICABLE, output=output_name, data_label=data_label, units=units, timescale=timescale)) # The program should specify the units for its unit cost
plotdata.results = sc.odict()
for result in results:
plotdata.results[result.name] = result.name
plotdata.pops = sc.odict({FS.DEFAULT_SYMBOL_INAPPLICABLE: FS.DEFAULT_SYMBOL_INAPPLICABLE})
plotdata.outputs = sc.odict()
for output in outputs:
key = list(output.keys())[0] if isinstance(output, dict) else output
plotdata.outputs[key] = results[0].model.progset.programs[key].label if key in results[0].model.progset.programs else key
if t_bins is not None:
# TODO - time aggregation of coverage_number by integration should only be applied to one-off programs
# TODO - confirm time aggregation of spending is correct for the units entered in databook or in overwrites
if quantity in {"spending", "equivalent_spending", "coverage_number"}:
plotdata.time_aggregate(t_bins, "integrate", interpolation_method="previous")
elif quantity in {"coverage_eligible", "coverage_fraction"}:
plotdata.time_aggregate(t_bins, "average", interpolation_method="previous")
else:
raise Exception("Unknown quantity type for aggregation")
if accumulate is not None:
plotdata.accumulate(accumulate)
return plotdata
def tvals(self):
"""
Return vector of time values
This method returns a vector of time values for the ``PlotData`` object, if all of the series have the
same time axis (otherwise it will throw an error). All series must have the same number of timepoints.
This will always be the case for a ``PlotData`` object unless the instance has been manually modified after construction.
:return: Tuple with (array of time values, array of time labels)
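Example usage (a sketch):
>>> tvals, t_labels = PlotData(result).tvals()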
"""
assert len(set([len(x.tvec) for x in self.series])) == 1, "All series must have the same number of time points."
tvec = self.series[0].tvec
t_labels = self.series[0].t_labels
for i in range(1, len(self.series)):
assert all(np.equal(self.series[i].tvec, tvec)), "All series must have the same time points"
return tvec, t_labels
def interpolate(self, new_tvec):
"""
Interpolate all ``Series`` onto new time values
This will modify all of the contained ``Series`` objects in-place.
The modified ``PlotData`` instance is also returned, so that interpolation and
construction can be performed in one line. i.e. both
>>> d = PlotData(result)
... d.interpolate(tvals)
and
>>> vals = PlotData(result).interpolate(tvals)
will work as intended.
:param new_tvec: Vector of new time values
:return: The modified `PlotData` instance
"""
new_tvec = sc.promotetoarray(new_tvec)
for series in self.series:
series.vals = series.interpolate(new_tvec)
series.tvec = np.copy(new_tvec)
series.t_labels = np.copy(new_tvec)
return self
def __getitem__(self, key: tuple):
"""
Implement custom indexing
The :class:`Series` objects stored within :class:`PlotData` are each bound to a single
result, population, and output. This operator makes it possible to easily retrieve
a particular :class:`Series` instance. For example,
>>> d = PlotData(results)
... d['default','0-4','sus']
:param key: A tuple of (result,pop,output)
:return: A :class:`Series` instance
"""
for s in self.series:
if s.result == key[0] and s.pop == key[1] and s.output == key[2]:
return s
raise Exception("Series %s-%s-%s not found" % (key[0], key[1], key[2]))
def set_colors(self, colors=None, results="all", pops="all", outputs="all", overwrite=False):
"""
Assign colors to quantities
This function facilitates assigning colors to the ``Series`` objects contained in this
``PlotData`` instance.
:param colors: Specify the colours to use. This can be
- A list of colours that applies to the list of all matching items
- A single colour to use for all matching items
- The name of a colormap to use (e.g., 'Blues')
:param results: A list of results to set colors for, or a dict of results where the key names the results (e.g. ``PlotData.results``)
:param pops: A list of pops to set colors for, or a dict of pops where the key names the pops (e.g. ``PlotData.pops``)
:param outputs: A list of outputs to set colors for, or a dict of outputs where the key names the outputs (e.g. ``PlotData.outputs``)
:param overwrite: False (default) or True. If True, then any existing manually set colours will be overwritten
:return: The `PlotData` instance (also modified in-place)
Essentially, the lists of results, pops, and outputs are used to filter the ``Series`` resulting in a list of ``Series`` to operate on.
Then, the colors argument is applied to that list.
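Example usage (a sketch):
>>> d = PlotData(result)
>>> d.set_colors("Blues", pops=d.pops)  # Apply a colormap across all populations
>>> d.set_colors("#ff0000", outputs=["sus"])  # Single color for one output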
"""
if isinstance(results, dict):
results = results.keys()
else:
results = sc.promotetolist(results)
if isinstance(pops, dict):
pops = pops.keys()
else:
pops = sc.promotetolist(pops)
if isinstance(outputs, dict):
outputs = outputs.keys()
else:
outputs = sc.promotetolist(outputs)
targets = list(itertools.product(results, pops, outputs))
if colors is None:
colors = sc.gridcolors(len(targets)) # Default colors
elif isinstance(colors, list):
assert len(colors) == len(targets), "Colors must be specified as a single string, or as a list with one color per item being colored"
elif colors.startswith("#") or colors not in [m for m in plt.cm.datad if not m.endswith("_r")]:
colors = [colors for _ in range(len(targets))] # Apply color to all requested outputs
else:
color_norm = matplotlib_colors.Normalize(vmin=-1, vmax=len(targets))
scalar_map = cmx.ScalarMappable(norm=color_norm, cmap=colors)
colors = [scalar_map.to_rgba(index) for index in range(len(targets))]
# Now each of these colors gets assigned
for color, target in zip(colors, targets):
series = self.series
series = [x for x in series if (x.result == target[0] or target[0] == "all")]
series = [x for x in series if (x.pop == target[1] or target[1] == "all")]
series = [x for x in series if (x.output == target[2] or target[2] == "all")]
for s in series:
s.color = color if (s.color is None or overwrite) else s.color
return self
class Series:
"""
Represent a plottable time series
A Series represents a quantity available for plotting. It is like a `TimeSeries` but contains
additional information only used for plotting, such as color.
:param tvec: array of time values
:param vals: array of values
:param result: name of the result associated with this data
:param pop: name of the pop associated with the data
:param output: name of the output associated with the data
:param data_label: name of a quantity in project data to plot in conjunction with this `Series`
:param color: the color to render the `Series` with
:param units: the units for the values
:param timescale: For number, probability, and duration units, the time denominator associated with the quantity (in years)
"""
def __init__(self, tvec, vals, result="default", pop="default", output="default", data_label="", color=None, units="", timescale=None, data_pop=""):
self.tvec = np.copy(tvec) # : array of time values
self.t_labels = np.copy(self.tvec) # : Iterable array of time labels - could be set to strings like [2010-2014]
self.vals = np.copy(vals) # : array of values
self.result = result # : name of the result associated with this data
self.pop = pop # : name of the pop associated with the data
self.output = output # : name of the output associated with the data
self.color = color # : the color to render the `Series` with
self.data_label = data_label #: Used to identify data for plotting - should match the name of a data TDVE
self.data_pop = data_pop #: Used to identify which population in the TDVE (specified by ``data_label``) to look up
self.units = units #: The units for the quantity to display on the plot
#: If the quantity has a time-like denominator (e.g. number/year, probability/day) then the denominator is stored here (in units of years)
#: This enables quantities to be time-aggregated correctly (e.g. number/day must be converted to number/timestep prior to summation or integration)
#: For links, the timescale is normally just ``dt``. This also enables more rigorous checking for quantities with time denominators than checking
#: for a string like ``'/year'`` because users may not set this specifically.
self.timescale = timescale
if np.any(np.isnan(vals)):
logger.warning("%s contains NaNs", self)
@property
def unit_string(self) -> str:
"""
Return the units for the quantity including timescale
When making plots, it is useful for the axis label to have the units of the quantity. The units should
also include the time scale e.g. "Death rate (probability per day)". However, if the timescale changes
due to aggregation or accumulation, then the value might be different. In that case,
The unit of the quantity is interpreted as a numerator if the Timescale is not None. For example,
Compartments have units of 'number', while Links have units of 'number/timestep' which is stored as
``Series.units='number'`` and ``Series.timescale=0.25`` (if ``dt=0.25``). The `unit_string` attribute
returns a string that is suitable to use for plots e.g. 'number per week'.
:return: A string representation of the units for use in plotting
"""
if not isna(self.timescale):
if self.units == FS.QUANTITY_TYPE_DURATION:
return "%s" % (format_duration(self.timescale, True))
else:
return "%s per %s" % (self.units, format_duration(self.timescale))
else:
return self.units
def __repr__(self):
return "Series(%s,%s,%s)" % (self.result, self.pop, self.output)
def interpolate(self, new_tvec):
"""
Return interpolated vector of values
This function returns an `np.array()` with the values of this series interpolated onto the requested
time array new_tvec. To ensure results are not misleading, extrapolation is disabled
and will return `NaN` if `new_tvec` contains values outside the original time range.
Note that unlike `PlotData.interpolate()`, `Series.interpolate()` does not modify the object but instead
returns the interpolated values. This makes the `Series` object more versatile (`PlotData` is generally
used only for plotting, but the `Series` object can be a convenient way to work with values computed using
the sophisticated aggregations within `PlotData`).
:param new_tvec: array of new time values
:return: array with interpolated values (same size as `new_tvec`)
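Example usage (a sketch, retrieving a ``Series`` from a ``PlotData`` instance first):
>>> s = PlotData(result)["default", "0-4", "sus"]
>>> vals = s.interpolate(np.arange(2000, 2010))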
"""
out_of_bounds = (new_tvec < self.tvec[0]) | (new_tvec > self.tvec[-1])
if np.any(out_of_bounds):
logger.warning("Series has values from %.2f to %.2f so requested time points %s are out of bounds", self.tvec[0], self.tvec[-1], new_tvec[out_of_bounds])
return np.interp(sc.promotetoarray(new_tvec), self.tvec, self.vals, left=np.nan, right=np.nan)
def plot_bars(plotdata, stack_pops=None, stack_outputs=None, outer=None, legend_mode=None, show_all_labels=False, orientation="vertical") -> list:
"""
Produce a bar plot
:param plotdata: a :class:`PlotData` instance to plot
:param stack_pops: A list of lists with populations to stack. A bar is rendered for each item in the list.
For example, `[['0-4','5-14'],['15-64']]` will render two bars, with two populations stacked
in the first bar, and only one population in the second bar. Items not appearing in this list
will be rendered unstacked.
:param stack_outputs: Same as `stack_pops`, but for outputs.
:param outer: Optionally select whether the outermost/highest level of grouping is by `'times'` or by `'results'`
:param legend_mode: override the default legend mode in settings
:param show_all_labels: If True, then inner/outer labels will be shown even if there is only one label
:param orientation: 'vertical' (default) or 'horizontal'
:return: A list of newly created Figures
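Example usage (a sketch, assuming ``result`` contains outputs 'sus' and 'vac'):
>>> d = PlotData(result, outputs=["sus", "vac"])
>>> figs = plot_bars(d, stack_outputs="all", outer="results")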
"""
global settings
if legend_mode is None:
legend_mode = settings["legend_mode"]
assert outer in [None, "times", "results"], 'Supported outer groups are "times" or "results"'
assert orientation in ["vertical", "horizontal"], 'Supported orientations are "vertical" or "horizontal"'
if outer is None:
if len(plotdata.results) == 1:
# If there is only one Result, then use 'outer=results' so that times can be promoted to axis labels
outer = "results"
else:
outer = "times"
plotdata = sc.dcp(plotdata)
# Note - all of the tvecs must be the same
tvals, t_labels = plotdata.tvals() # We have to iterate over these, with offsets, if there is more than one
# If quantities are stacked, then they need to be coloured differently.
if stack_pops is None:
color_by = "outputs"
plotdata.set_colors(outputs=plotdata.outputs.keys())
elif stack_outputs is None:
color_by = "pops"
plotdata.set_colors(pops=plotdata.pops.keys())
else:
color_by = "both"
plotdata.set_colors(pops=plotdata.pops.keys(), outputs=plotdata.outputs.keys())
def process_input_stacks(input_stacks, available_items):
# Sanitize the input. input stack could be
# - A list of stacks, where a stack is a list of pops or a string with a single pop
# - A dict of stacks, where the key is the name, and the value is a list of pops or a string with a single pop
# - None, in which case all available items are used
# - 'all' in which case all of the items appear in a single stack
#
# The return value `output_stacks` is a list of tuples where
# (a,b,c)
# a - The automatic name
# b - User provided manual name
# c - List of pop labels
# Same for outputs
if input_stacks is None:
return [(x, "", [x]) for x in available_items]
elif input_stacks == "all":
# Put all available items into a single stack
return process_input_stacks([available_items], available_items)
items = set()
output_stacks = []
if isinstance(input_stacks, list):
for x in input_stacks:
if isinstance(x, list):
output_stacks.append(("", "", x) if len(x) > 1 else (x[0], "", x))
items.update(x)
elif sc.isstring(x):
output_stacks.append((x, "", [x]))
items.add(x)
else:
raise Exception("Unsupported input")
elif isinstance(input_stacks, dict):
for k, x in input_stacks.items():
if isinstance(x, list):
output_stacks.append(("", k, x) if len(x) > 1 else (x[0], k, x))
items.update(x)
elif sc.isstring(x):
output_stacks.append((x, k, [x]))
items.add(x)
else:
raise Exception("Unsupported input")
# Add missing items
missing = list(set(available_items) - items)
output_stacks += [(x, "", [x]) for x in missing]
return output_stacks
pop_stacks = process_input_stacks(stack_pops, plotdata.pops.keys())
output_stacks = process_input_stacks(stack_outputs, plotdata.outputs.keys())
# Now work out which pops and outputs appear in each bar (a bar is a pop-output combo)
bar_pops = []
bar_outputs = []
for pop in pop_stacks:
for output in output_stacks:
bar_pops.append(pop)
bar_outputs.append(output)
width = settings["bar_width"]
gaps = [0.1, 0.4, 0.8] # Spacing within blocks, between inner groups, and between outer groups
block_width = len(bar_pops) * (width + gaps[0])
# If there is only one bar group, then increase spacing between bars
if len(tvals) == 1 and len(plotdata.results) == 1:
gaps[0] = 0.3
if outer == "times":
if len(plotdata.results) == 1: # If there is only one inner group
gaps[2] = gaps[1]
gaps[1] = 0
result_offset = block_width + gaps[1]
tval_offset = len(plotdata.results) * (block_width + gaps[1]) + gaps[2]
iterator = nested_loop([range(len(plotdata.results)), range(len(tvals))], [0, 1])
elif outer == "results":
if len(tvals) == 1: # If there is only one inner group
gaps[2] = gaps[1]
gaps[1] = 0
result_offset = len(tvals) * (block_width + gaps[1]) + gaps[2]
tval_offset = block_width + gaps[1]
iterator = nested_loop([range(len(plotdata.results)), range(len(tvals))], [1, 0])
else:
raise Exception('outer option must be either "times" or "results"')
figs = []
fig, ax = plt.subplots()
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
fig.set_label("bars")
figs.append(fig)
rectangles = defaultdict(list) # Accumulate the list of rectangles for each colour
color_legend = sc.odict()
# NOTE
# pops, output - colour separates them. To merge colours, aggregate the data first
# results, time - spacing separates them. Can choose to group by one or the other
# Now, there are three levels of ticks
# There is the within-block level, the inner group, and the outer group
block_labels = [] # Labels for individual bars (tick labels)
inner_labels = [] # Labels for bar groups below axis
block_offset = None
base_offset = None
negative_present = False # If True, it means negative quantities were present
# Iterate over the inner and outer groups, rendering blocks at a time
for r_idx, t_idx in iterator:
base_offset = r_idx * result_offset + t_idx * tval_offset # Offset between outer groups
block_offset = 0.0 # Offset between inner groups
if outer == "results":
inner_labels.append((base_offset + block_width / 2.0, t_labels[t_idx]))
elif outer == "times":
inner_labels.append((base_offset + block_width / 2.0, plotdata.results[r_idx]))
for idx, bar_pop, bar_output in zip(range(len(bar_pops)), bar_pops, bar_outputs):
# pop is something like ['0-4','5-14'] or ['0-4']
# output is something like ['sus','vac'] or ['0-4'] depending on the stack
y0 = [0, 0] # Baselines for positive and negative bars, respectively
# Set the name of the bar
# If the user provided a label, it will always be displayed
# In addition, if there is more than one label of the other (output/pop) type,
# then that label will also be shown, otherwise it will be suppressed
if bar_pop[1] or bar_output[1]:
if bar_pop[1]:
if bar_output[1]:
bar_label = "%s\n%s" % (bar_pop[1], bar_output[1])
elif len(output_stacks) > 1 and len(set([x[0] for x in output_stacks])) > 1 and bar_output[0]:
bar_label = "%s\n%s" % (bar_pop[1], bar_output[0])
else:
bar_label = bar_pop[1]
else:
if len(pop_stacks) > 1 and len(set([x[0] for x in pop_stacks])) > 1 and bar_pop[0]:
bar_label = "%s\n%s" % (bar_pop[0], bar_output[1])
else:
bar_label = bar_output[1]
else:
if color_by == "outputs" and len(pop_stacks) > 1 and len(set([x[0] for x in pop_stacks])) > 1:
bar_label = plotdata.pops[bar_pop[0]]
elif color_by == "pops" and len(output_stacks) > 1 and len(set([x[0] for x in output_stacks])) > 1:
bar_label = plotdata.outputs[bar_output[0]]
else:
bar_label = ""
for pop in bar_pop[2]:
for output in bar_output[2]:
series = plotdata[plotdata.results[r_idx], pop, output]
y = series.vals[t_idx]
if y >= 0:
baseline = y0[0]
y0[0] += y
height = y
else:
baseline = y0[1] + y
y0[1] += y
height = -y
negative_present = True
if orientation == "horizontal":
rectangles[series.color].append(Rectangle((baseline, base_offset + block_offset), height, width))
else:
rectangles[series.color].append(Rectangle((base_offset + block_offset, baseline), width, height))
if series.color in color_legend and (pop, output) not in color_legend[series.color]:
color_legend[series.color].append((pop, output))
elif series.color not in color_legend:
color_legend[series.color] = [(pop, output)]
block_labels.append((base_offset + block_offset + width / 2.0, bar_label))
block_offset += width + gaps[0]
# Add the patches to the figure and assemble the legend patches
legend_patches = []
for color, items in color_legend.items():
pc = PatchCollection(rectangles[color], facecolor=color, edgecolor="none")
ax.add_collection(pc)
pops = set([x[0] for x in items])
outputs = set([x[1] for x in items])
if pops == set(plotdata.pops.keys()) and len(outputs) == 1: # If the same color is used for all pops and always the same output
label = plotdata.outputs[items[0][1]] # Use the output name
elif outputs == set(plotdata.outputs.keys()) and len(pops) == 1: # Same color for all outputs and always same pop
label = plotdata.pops[items[0][0]] # Use the pop name
else:
label = ""
for x in items:
label += "%s-%s,\n" % (plotdata.pops[x[0]], plotdata.outputs[x[1]])
label = label.strip()[:-1] # Remove the trailing comma (strip removes the trailing newline)
legend_patches.append(Patch(facecolor=color, label=label))
# Set axes now, because we need block_offset and base_offset after the loop
ax.autoscale()
_turn_off_border(ax)
block_labels = sorted(block_labels, key=lambda x: x[0])
if orientation == "horizontal":
ax.set_ylim(bottom=-2 * gaps[0], top=block_offset + base_offset)
fig.set_figheight(0.75 + 0.75 * (block_offset + base_offset))
if not negative_present:
ax.set_xlim(left=0)
else:
ax.spines["right"].set_color("k")
ax.spines["right"].set_position("zero")
ax.set_yticks([x[0] for x in block_labels])
ax.set_yticklabels([x[1] for x in block_labels])
ax.invert_yaxis()
sc.SIticks(ax=ax, axis="x")
else:
ax.set_xlim(left=-2 * gaps[0], right=block_offset + base_offset)
fig.set_figwidth(1.1 + 1.1 * (block_offset + base_offset))
if not negative_present:
ax.set_ylim(bottom=0)
else:
ax.spines["top"].set_color("k")
ax.spines["top"].set_position("zero")
ax.set_xticks([x[0] for x in block_labels])
ax.set_xticklabels([x[1] for x in block_labels])
sc.SIticks(ax=ax, axis="y")
# Calculate the units. As all bar patches are shown on the same axis, they are all expected to have the
# same units. If they do not, the plot could be misleading
units = list(set([x.unit_string for x in plotdata.series]))
if len(units) == 1 and not isna(units[0]):
if orientation == "horizontal":
ax.set_xlabel(units[0].capitalize())
else:
ax.set_ylabel(units[0].capitalize())
elif len(units) > 1:
logger.warning("Warning - bar plot quantities mix units, double check that output selection is correct")
# Outer group labels are only displayed if there is more than one group
if outer == "times" and (show_all_labels or len(tvals) > 1):
offset = 0.0
for t in t_labels:
# Can't use ax.set_title() here: there is usually more than one of these labels, and each needs to be positioned
# at the particular axis value where its block of bars appears. It is also common for the plot to still
# need a title in addition to these labels (they are essentially tertiary axis ticks, not a title for the plot)
if orientation == "horizontal":
ax.text(1, offset + (tval_offset - gaps[1] - gaps[2]) / 2, t, transform=ax.get_yaxis_transform(), verticalalignment="center", horizontalalignment="left")
else:
ax.text(offset + (tval_offset - gaps[1] - gaps[2]) / 2, 1, t, transform=ax.get_xaxis_transform(), verticalalignment="bottom", horizontalalignment="center")
offset += tval_offset
elif outer == "results" and (show_all_labels or len(plotdata.results) > 1):
offset = 0.0
for r in plotdata.results:
if orientation == "horizontal":
ax.text(1, offset + (result_offset - gaps[1] - gaps[2]) / 2, plotdata.results[r], transform=ax.get_yaxis_transform(), verticalalignment="center", horizontalalignment="left")
else:
ax.text(offset + (result_offset - gaps[1] - gaps[2]) / 2, 1, plotdata.results[r], transform=ax.get_xaxis_transform(), verticalalignment="bottom", horizontalalignment="center")
offset += result_offset
# If there are no block labels (e.g. due to stacking) and the number of inner labels matches the number of bars, then promote the inner group
# labels and use them as bar labels
if not any([x[1] for x in block_labels]) and len(block_labels) == len(inner_labels):
if orientation == "horizontal":
ax.set_yticks([x[0] for x in inner_labels])
ax.set_yticklabels([x[1] for x in inner_labels])
else:
ax.set_xticks([x[0] for x in inner_labels])
ax.set_xticklabels([x[1] for x in inner_labels])
elif show_all_labels or (len(inner_labels) > 1 and len(set([x for _, x in inner_labels])) > 1):
# Otherwise, if there is only one inner group AND there are bar labels, don't show the inner group labels unless show_all_labels is True
if orientation == "horizontal":
ax2 = ax.twinx() # instantiate a second axes that shares the same y-axis
ax2.set_yticks([x[0] for x in inner_labels])
# TODO - At the moment there is a chance these labels will overlap, need to increase the offset somehow e.g. padding with spaces
# Best to leave this until a specific test case arises
# Simply rotating doesn't work because the vertical labels also overlap with the original axis labels
# So would be necessary to apply some offset as well (perhaps from YAxis.get_text_widths)
ax2.set_yticklabels([str(x[1]) for x in inner_labels])
ax2.yaxis.set_ticks_position("left")
ax2.set_ylim(ax.get_ylim())
else:
ax2 = ax.twiny() # instantiate a second axes that shares the same x-axis
ax2.set_xticks([x[0] for x in inner_labels])
ax2.set_xticklabels(["\n\n" + str(x[1]) for x in inner_labels])
ax2.xaxis.set_ticks_position("bottom")
ax2.set_xlim(ax.get_xlim())
ax2.tick_params(axis="both", which="both", length=0)
ax2.spines["right"].set_visible(False)
ax2.spines["top"].set_visible(False)
ax2.spines["left"].set_visible(False)
ax2.spines["bottom"].set_visible(False)
fig.tight_layout() # Do a final resizing
# Do the legend last, so repositioning the axes works properly
if legend_mode == "together":
_render_legend(ax, plot_type="bar", handles=legend_patches)
elif legend_mode == "separate":
figs.append(sc.separatelegend(handles=legend_patches, reverse=True))
return figs
def plot_series(plotdata, plot_type="line", axis=None, data=None, legend_mode=None, lw=None, n_cols: int = None) -> list:
"""
Produce a time series plot
:param plotdata: a :class:`PlotData` instance to plot
:param plot_type: 'line', 'stacked', or 'proportion' (stacked, normalized to 1)
:param axis: Specify which quantity to group outputs on plots by - can be 'outputs', 'results', or 'pops'. A line will
be drawn for each value of the selected quantity, and the remaining quantities will appear as separate figures.
:param data: Draw scatter points for data wherever the output label matches a data label. Data are drawn for 'line' and 'stacked' plots, but not for 'proportion' plots
:param legend_mode: override the default legend mode in settings
:param lw: override the default line width
:param n_cols: If None (default), separate figures will be created for each axis. If provided, axes will be tiled as subplots in a single figure
window with the requested number of columns
:return: A list of newly created Figures
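Example usage (a sketch - `result` stands for a :class:`Result` from a completed model run,
and 'sus' for an output defined in the framework; both names are illustrative):
>>> d = PlotData(result, outputs='sus')
>>> figs = plot_series(d, plot_type='line', axis='pops')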
"""
global settings
if legend_mode is None:
legend_mode = settings["legend_mode"]
if lw is None:
lw = settings["line_width"]
if axis is None:
axis = "outputs"
assert axis in ["outputs", "results", "pops"]
subplots = bool(n_cols) # If True, use subplots
def _prepare_figures(dim1, dim2, n_cols):
n_figs = len(dim1) * len(dim2) + (1 if legend_mode == "separate" else 0)
if subplots:
# Use subplots
n_cols = int(n_cols)
n_rows = int(np.ceil(n_figs / n_cols))
fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, squeeze=False, sharex="all")
fig.set_label("series")
size = fig.get_size_inches()
fig.set_size_inches(size[0] * n_cols, size[1] * n_rows)
figs = [fig]
axes = axes.ravel()
for i in range(n_figs, len(axes)):
axes[i].remove()
else:
figs = []
axes = []
for i in range(n_figs):
fig, ax = plt.subplots()
figs.append(fig)
axes.append(ax)
for fig in figs:
fig.patch.set_alpha(0)
for ax in axes:
ax.patch.set_alpha(0)
return figs[:n_figs], axes[:n_figs]
plotdata = sc.dcp(plotdata)
if min([len(s.vals) for s in plotdata.series]) == 1:
logger.warning("At least one Series has only one timepoint. Series must have at least 2 time points to be rendered as a line - `plot_bars` may be more suitable for such data")
if axis == "results":
plotdata.set_colors(results=plotdata.results.keys())
figs, axes = _prepare_figures(plotdata.pops, plotdata.outputs, n_cols)
for i, (pop, output) in enumerate(itertools.product(plotdata.pops.keys(), plotdata.outputs.keys())):
ax = axes[i]
if not subplots:
figs[i].set_label("%s_%s" % (pop, output))
units = list(set([plotdata[result, pop, output].unit_string for result in plotdata.results]))
if len(units) == 1 and not isna(units[0]) and units[0]:
ax.set_ylabel("%s (%s)" % (plotdata.outputs[output], units[0]))
else:
ax.set_ylabel("%s" % (plotdata.outputs[output]))
if plotdata.pops[pop] != FS.DEFAULT_SYMBOL_INAPPLICABLE:
ax.set_title("%s" % (plotdata.pops[pop]))
if plot_type in ["stacked", "proportion"]:
y = np.stack([plotdata[result, pop, output].vals for result in plotdata.results])
y = y / np.sum(y, axis=0) if plot_type == "proportion" else y
ax.stackplot(plotdata[plotdata.results.keys()[0], pop, output].tvec, y, labels=[plotdata.results[x] for x in plotdata.results], colors=[plotdata[result, pop, output].color for result in plotdata.results])
if plot_type == "stacked" and data is not None:
_stack_data(ax, data, [plotdata[result, pop, output] for result in plotdata.results])
else:
for ri, result in enumerate(plotdata.results):
ax.plot(plotdata[result, pop, output].tvec, plotdata[result, pop, output].vals, color=plotdata[result, pop, output].color, label=plotdata.results[result], lw=lw)
if data is not None and ri == 0:
# The data are the same for every result, so only render them once
_render_data(ax, data, plotdata[result, pop, output])
_apply_series_formatting(ax, plot_type)
if legend_mode == "together":
_render_legend(ax, plot_type)
elif axis == "pops":
plotdata.set_colors(pops=plotdata.pops.keys())
figs, axes = _prepare_figures(plotdata.results, plotdata.outputs, n_cols)
for i, (result, output) in enumerate(itertools.product(plotdata.results.keys(), plotdata.outputs.keys())):
ax = axes[i]
if not subplots:
figs[i].set_label("%s_%s" % (result, output))
units = list(set([plotdata[result, pop, output].unit_string for pop in plotdata.pops]))
if len(units) == 1 and not isna(units[0]) and units[0]:
ax.set_ylabel("%s (%s)" % (plotdata.outputs[output], units[0]))
else:
ax.set_ylabel("%s" % (plotdata.outputs[output]))
ax.set_title("%s" % (plotdata.results[result]))
if plot_type in ["stacked", "proportion"]:
y = np.stack([plotdata[result, pop, output].vals for pop in plotdata.pops])
y = y / np.sum(y, axis=0) if plot_type == "proportion" else y
ax.stackplot(plotdata[result, plotdata.pops.keys()[0], output].tvec, y, labels=[plotdata.pops[x] for x in plotdata.pops], colors=[plotdata[result, pop, output].color for pop in plotdata.pops])
if plot_type == "stacked" and data is not None:
_stack_data(ax, data, [plotdata[result, pop, output] for pop in plotdata.pops])
else:
for pop in plotdata.pops:
ax.plot(plotdata[result, pop, output].tvec, plotdata[result, pop, output].vals, color=plotdata[result, pop, output].color, label=plotdata.pops[pop], lw=lw)
if data is not None:
_render_data(ax, data, plotdata[result, pop, output])
_apply_series_formatting(ax, plot_type)
if legend_mode == "together":
_render_legend(ax, plot_type)
elif axis == "outputs":
plotdata.set_colors(outputs=plotdata.outputs.keys())
figs, axes = _prepare_figures(plotdata.results, plotdata.pops, n_cols)
for i, (result, pop) in enumerate(itertools.product(plotdata.results.keys(), plotdata.pops.keys())):
ax = axes[i]
if not subplots:
figs[i].set_label("%s_%s" % (result, pop))
units = list(set([plotdata[result, pop, output].unit_string for output in plotdata.outputs]))
if len(units) == 1 and not isna(units[0]) and units[0]:
ax.set_ylabel(units[0][0].upper() + units[0][1:])
if plotdata.pops[pop] != FS.DEFAULT_SYMBOL_INAPPLICABLE:
ax.set_title("%s-%s" % (plotdata.results[result], plotdata.pops[pop]))
else:
ax.set_title("%s" % (plotdata.results[result]))
if plot_type in ["stacked", "proportion"]:
y = np.stack([plotdata[result, pop, output].vals for output in plotdata.outputs])
y = y / np.sum(y, axis=0) if plot_type == "proportion" else y
ax.stackplot(plotdata[result, pop, plotdata.outputs.keys()[0]].tvec, y, labels=[plotdata.outputs[x] for x in plotdata.outputs], colors=[plotdata[result, pop, output].color for output in plotdata.outputs])
if plot_type == "stacked" and data is not None:
_stack_data(ax, data, [plotdata[result, pop, output] for output in plotdata.outputs])
else:
for output in plotdata.outputs:
ax.plot(plotdata[result, pop, output].tvec, plotdata[result, pop, output].vals, color=plotdata[result, pop, output].color, label=plotdata.outputs[output], lw=lw)
if data is not None:
_render_data(ax, data, plotdata[result, pop, output])
_apply_series_formatting(ax, plot_type)
if legend_mode == "together":
_render_legend(ax, plot_type)
else:
raise Exception('axis option must be one of "results", "pops" or "outputs"')
if legend_mode == "separate":
reverse_legend = plot_type in ["stacked", "proportion"]
if not subplots:
# Replace the last figure with a legend figure
plt.close(figs[-1]) # TODO - update Sciris to allow passing in an existing figure
figs[-1] = sc.separatelegend(ax, reverse=reverse_legend)
else:
legend_ax = axes[-1]
handles, labels = ax.get_legend_handles_labels()
legend_ax.set_axis_off() # Hide axis lines
if reverse_legend: # pragma: no cover
handles = handles[::-1]
labels = labels[::-1]
legend_ax.legend(handles=handles, labels=labels, loc="center", framealpha=0)
return figs
def _stack_data(ax, data, series) -> None:
"""
Internal function to stack series data
Used by `plot_series` when rendering stacked plots and also showing data.
"""
baselines = np.cumsum(np.stack([s.vals for s in series]), axis=0)
baselines = np.vstack([np.zeros((1, baselines.shape[1])), baselines]) # Insert row of zeros for first data row
for i, s in enumerate(series):
_render_data(ax, data, s, baselines[i, :], True)
def _render_data(ax, data, series, baseline=None, filled=False) -> None:
"""
Renders a scatter plot for a single variable in a single population
:param ax: axis object that data will be rendered in
:param data: a ProjectData instance containing the data to render
:param series: a `Series` object, the 'pop' and 'data_label' attributes are used to extract the TimeSeries from the data
:param baseline: adds an offset to the data e.g. for stacked plots
:param filled: fill the marker with a solid fill e.g. for stacked plots
"""
ts = data.get_ts(series.data_label, series.data_pop)
if ts is None:
return
if not ts.has_time_data:
return
t, y = ts.get_arrays()
if baseline is not None:
y_data = np.interp(sc.promotetoarray(t), series.tvec, baseline, left=np.nan, right=np.nan)
y = y + y_data
if filled:
ax.scatter(t, y, marker="o", s=40, linewidths=1, facecolors=series.color, color="k")
else:
ax.scatter(t, y, marker="o", s=40, linewidths=settings["marker_edge_width"], facecolors="none", color=series.color)
def _apply_series_formatting(ax, plot_type) -> None:
# This function applies formatting that is common to all Series plots
# (irrespective of the 'axis' setting)
ax.autoscale(enable=True, axis="x", tight=True)
ax.set_xlabel("Year")
ax.set_ylim(bottom=0)
_turn_off_border(ax)
if plot_type == "proportion":
ax.set_ylim(top=1)
ax.set_ylabel("Proportion " + ax.get_ylabel())
else:
ax.set_ylim(top=ax.get_ylim()[1] * 1.05)
sc.SIticks(ax=ax, axis="y")
def _turn_off_border(ax) -> None:
"""
Turns off top and right borders.
Note that this function will leave the bottom and left borders on.
:param ax: An axis object
:return: None
"""
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")
ax.xaxis.set_ticks_position("bottom")
ax.yaxis.set_ticks_position("left")
def plot_legend(entries: dict, plot_type=None, fig=None, legendsettings: dict = None):
"""
Render a new legend
:param entries: Dict where key is the label and value is the colour e.g. `{'sus':'blue','vac':'red'}`
:param plot_type: Optionally specify 'patch', 'line', 'circle', or a list of these values with the same length as `entries`
:param fig: Optionally specify an existing figure to render the legend in. If not provided, a new figure will be created
:param legendsettings: Settings for the layout of the legend. If not provided, defaults will be used that depend on whether the legend is rendered separately or together with a plot
:return: The matplotlib `Figure` object containing the legend
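Example usage (rendering a standalone legend with the example entries above):
>>> fig = plot_legend({'sus': 'blue', 'vac': 'red'}, plot_type='patch')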
"""
if plot_type is None:
plot_type = "line"
plot_type = sc.promotetolist(plot_type)
if len(plot_type) == 1:
plot_type = plot_type * len(entries)
assert len(plot_type) == len(entries), "If plot_type is a list, it must have the same number of values as there are entries in the legend (%s vs %s)" % (plot_type, entries)
h = []
for (label, color), p_type in zip(entries.items(), plot_type):
if p_type == "patch":
h.append(Patch(color=color, label=label))
elif p_type == "line":
h.append(Line2D([0], [0], linewidth=settings["line_width"], color=color, label=label))
elif p_type == "circle":
h.append(Line2D([0], [0], marker="o", linewidth=0, markeredgewidth=settings["marker_edge_width"], fillstyle="none", color=color, label=label))
else:
raise Exception(f'Unknown plot type "{p_type}"')
if fig is None: # Draw in a new figure
fig = sc.separatelegend(handles=h, legendsettings=legendsettings)
else:
existing_legend = fig.findobj(Legend)
if existing_legend and existing_legend[0].parent is fig: # If existing legend and this is a separate legend fig
existing_legend[0].remove() # Delete the old legend
if legendsettings is None:
legendsettings = {"loc": "center", "bbox_to_anchor": None, "frameon": False} # Settings for separate legend
fig.legend(handles=h, **legendsettings)
else: # Drawing into an existing figure
ax = fig.axes[0]
if legendsettings is None:
legendsettings = {"loc": "center left", "bbox_to_anchor": (1.05, 0.5), "ncol": 1}
if existing_legend:
existing_legend[0].remove() # Delete the old legend
ax.legend(handles=h, **legendsettings)
else:
ax.legend(handles=h, **legendsettings)
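# Shrink the axes to make room for the legend on the right (only needed the first time a legend is added)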
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
return fig
def _render_legend(ax, plot_type=None, handles=None) -> None:
"""
Internal function to render a legend
:param ax: Axis in which to create the legend
:param plot_type: Used to decide whether to reverse the legend order for stackplots
:param handles: The handles of the objects to enter in the legend. Labels should be stored in the handles
"""
if handles is None:
handles, labels = ax.get_legend_handles_labels()
else:
labels = [h.get_label() for h in handles]
legendsettings = {"loc": "center left", "bbox_to_anchor": (1.05, 0.5), "ncol": 1, "framealpha": 0}
# labels = [textwrap.fill(label, 16) for label in labels]
if plot_type in ["stacked", "proportion", "bar"]:
ax.legend(handles=handles[::-1], labels=labels[::-1], **legendsettings)
else:
ax.legend(handles=handles, labels=labels, **legendsettings)
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
def reorder_legend(figs, order=None) -> None:
"""
Change the order of an existing legend
:param figs: Figure, or list of figures, containing legends for which the order should be changed
:param order: Specification of the order in which to render the legend entries. This can be
- The string `'reverse'` which will reverse the order of the legend
- A list of indices mapping old position to new position. For example, if the
original label order was ['a','b','c'], then order=[1,0,2] would result in ['b','a','c'].
If a partial list is provided, then only a subset of the legend entries will appear. This
allows this function to be used to remove legend entries as well.
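Example usage (assuming `fig` is a figure containing a three-entry legend):
>>> reorder_legend(fig, order='reverse')  # Reverse the order of all entries
>>> reorder_legend(fig, order=[1, 0, 2])  # Swap the first two entries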
"""
if isinstance(figs, list):
for fig in figs: # Apply order operation to all figures passed in
reorder_legend(fig, order=order)
return
else:
fig = figs
legend = fig.findobj(Legend)[0]
assert len(legend._legend_handle_box._children) == 1, "Only single-column legends are supported"
vpacker = legend._legend_handle_box._children[0]
if order is None:
return
elif order == "reverse":
try:
# matplotlib<3.9 (Legend.legendHandles was removed in matplotlib 3.9)
order = range(len(legend.legendHandles) - 1, -1, -1)
except AttributeError:
# matplotlib>=3.9
order = range(len(legend.legend_handles) - 1, -1, -1)
else:
assert max(order) < len(vpacker._children), "Requested index greater than number of legend entries"
new_children = []
for i in range(0, len(order)):
new_children.append(vpacker._children[order[i]])
vpacker._children = new_children
def relabel_legend(figs, labels) -> None:
"""
Change the labels on an existing legend
:param figs: Figure, or list of figures, to change labels in
:param labels: `list` of labels the same length as the number of legend labels OR a `dict` of labels where the key is the index
of the labels to change. The `dict` input option makes it possible to change only a subset of the labels.
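Example usage (assuming `fig` is a figure containing a two-entry legend):
>>> relabel_legend(fig, ['Susceptible', 'Vaccinated'])  # Relabel all entries
>>> relabel_legend(fig, {0: 'Susceptible'})  # Relabel only the first entry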
"""
if isinstance(figs, list):
for fig in figs:
relabel_legend(fig, labels=labels)
return
else:
fig = figs
legend = fig.findobj(Legend)[0]
assert len(legend._legend_handle_box._children) == 1, "Only single-column legends are supported"
vpacker = legend._legend_handle_box._children[0]
if isinstance(labels, list):
assert len(labels) == len(vpacker._children), "If specifying list of labels, length must match number of legend entries"
labels = {i: l for i, l in enumerate(labels)}
elif isinstance(labels, dict):
idx = labels.keys()
assert max(idx) < len(vpacker._children), "Requested index greater than number of legend entries"
else:
raise Exception("Labels must be a list or a dict")
for idx, label in labels.items():
text = vpacker._children[idx]._children[1]._text
text.set_text(label)
def _get_full_name(code_name: str, proj=None) -> str:
"""
Return the label of an object retrieved by name
If a :class:`Project` has been provided, code names can be converted into
labels for plotting. This function differs from `framework.get_label()` in that
it also supports converting population names to labels (this information is
stored in the project's data, not in the framework) and converting
link syntax (e.g. `sus:vac`) into full names. This means that the strings
returned by `_get_full_name` can be as specific as necessary for plotting.
:param code_name: The code name for a variable (e.g. `'sus'`, `'pris'`, `'sus:vac'`)
:param proj: Optionally specify a :class:`Project` instance
:return: If a project was provided, returns the full name. Otherwise, just returns the code name
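Example usage (a sketch - the output shown assumes the framework labels 'sus' and 'vac'
as 'Susceptible' and 'Vaccinated'; actual output depends on the project):
>>> _get_full_name('sus:vac', proj)
'Flow from Susceptible to Vaccinated'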
"""
if proj is None:
return code_name
if code_name in proj.data.pops:
return proj.data.pops[code_name]["label"] # Convert population
if ":" in code_name: # We are parsing a link
# Handle Links specified with colon syntax
output_tokens = code_name.split(":")
if len(output_tokens) == 2:
output_tokens.append("")
src, dest, par = output_tokens
# If 'par_name:flow' syntax was used
if dest == "flow":
if src in proj.framework:
return "{0} (flow)".format(proj.framework.get_label(src))
else:
return "{0} (flow)".format(src)
if src and src in proj.framework:
src = proj.framework.get_label(src)
if dest and dest in proj.framework:
dest = proj.framework.get_label(dest)
if par and par in proj.framework:
par = proj.framework.get_label(par)
full = "Flow"
if src:
full += " from {}".format(src)
if dest:
full += " to {}".format(dest)
if par:
full += " ({})".format(par)
return full
else:
if code_name in proj.framework:
return proj.framework.get_label(code_name)
else:
return code_name
def _expand_dict(x: list) -> list:
"""
Expand a dict with multiple keys into a list of single-key dicts
An aggregation is defined as a mapping of multiple outputs into a single
variable with a single label. This is represented by a dict with a single key,
where the key is the label of the new quantity, and the value represents the instructions
for how to compute the quantity. Sometimes outputs and pops are used directly, without
renaming, so in this case, only the string representing the name of the quantity is required.
Therefore, the format used internally by `PlotData` is that outputs/pops are represented
as lists with length equal to the total number of quantities being returned/computed, and
that list can contain dictionaries with single keys whenever an aggregation is required.
For ease of use, it is convenient for users to enter multiple aggregations as a single dict
with multiple keys. This function processes such a dict into the format used internally
by PlotData.
:param x: A list of inputs, containing strings or dicts that might have multiple keys
:return: A list containing strings or dicts where any dicts have only one key
Example usage:
>>> _expand_dict(['a',{'b':1,'c':2}])
['a', {'b': 1}, {'c': 2}]
"""
# If a list contains a dict with multiple keys, expand it into multiple dicts each
# with a single key
y = list()
for v in x:
if isinstance(v, dict):
y += [{a: b} for a, b in v.items()]
elif sc.isstring(v):
y.append(v)
else:
raise Exception("Unknown type")
return y
def _extract_labels(input_arrays) -> set:
"""
Extract all quantities from list of dicts
The inputs supported by `outputs` and `pops` can contain lists of optional
aggregations. The first step in `PlotData` is to extract all of the quantities
in the `Model` object that are required to compute the requested aggregations.
:param input_arrays: Input string, list, or dict specifying aggregations
:return: Set of unique string values that correspond to model quantities
Example usage:
>>> _extract_labels(['vac',{'a':['vac','sus']}])
{'vac', 'sus'}
The main workflow is:
['vac',{'a':['vac','sus']}] -> ['vac','vac','sus'] -> {'vac','sus'}
i.e. first a flat list is constructed by replacing any dicts with their values
and concatenating, then the list is converted into a set
"""
out = []
for x in input_arrays:
if isinstance(x, dict):
k = list(x.keys())
assert len(k) == 1, "Aggregation dict can only have one key"
if sc.isstring(x[k[0]]):
continue
else:
out += x[k[0]]
else:
out.append(x)
return set(out)