Source code for spinetoolbox.plotting

######################################################################################################################
# Copyright (C) 2017-2021 Spine project consortium
# This file is part of Spine Toolbox.
# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
# any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################

"""
Functions for plotting on PlotWidget.

Currently, plotting from the table views found in the SpineDBEditor is supported.

The main entry points to plotting are:
- plot_selection() which plots selected cells on a table view returning a PlotWidget object
- plot_pivot_column() which is a specialized method for plotting entire columns of a pivot table
- add_time_series_plot() which adds a time series plot to an existing PlotWidget
- add_map_plot() which adds a map plot to an existing PlotWidget

:author: A. Soininen(VTT)
:date:   9.7.2019
"""

import functools
from numbers import Number
from matplotlib.ticker import MaxNLocator
import numpy as np
from PySide2.QtCore import QModelIndex
from spinedb_api import (
    Array,
    convert_leaf_maps_to_specialized_containers,
    IndexedValue,
    Map,
    ParameterValueFormatError,
    TimeSeries,
)
from .helpers import first_non_null
from .mvcmodels.shared import PARSED_ROLE
from .widgets.plot_widget import PlotWidget


class PlottingError(Exception):
    """An exception signalling failure in plotting.

    The error message is available through the ``message`` property as well as
    through ``str()`` and ``args`` of the exception.
    """

    def __init__(self, message):
        """
        Args:
            message (str): an error message
        """
        # Forward the message to Exception; otherwise str(error), error.args
        # and tracebacks would show an empty message.
        super().__init__(message)
        self._message = message

    @property
    def message(self):
        """str: the error message."""
        return self._message
def plot_pivot_column(proxy_model, column, hints, plot_widget=None):
    """
    Plots all data rows of a single pivot table column.

    The rows plotted start right after the source model's header rows and run to
    the last row of the proxy model.

    Args:
        proxy_model (PivotTableSortFilterProxy): a pivot table filter
        column (int): a column index to the model
        hints (PlottingHints): a helper needed for e.g. plot labels
        plot_widget (PlotWidget): an existing plot widget to draw into or None to create a new widget

    Returns:
        PlotWidget: the plot widget that was drawn into

    Raises:
        PlottingError: if the column's values clash with what the widget already shows
            or cannot be plotted
    """
    # Only a widget handed in by the caller is explicitly redrawn at the end.
    needs_redraw = plot_widget is not None
    if not needs_redraw:
        plot_widget = PlotWidget()
    header_row_count = proxy_model.sourceModel().headerRowCount()
    data_rows = range(header_row_count, proxy_model.rowCount())
    values, labels = _collect_column_values(proxy_model, column, data_rows, hints)
    axes = plot_widget.canvas.axes
    if values:
        if plot_widget.plot_type is not None:
            _raise_if_value_types_clash(values, plot_widget)
        else:
            plot_widget.infer_plot_type(values)
        _add_plot_to_widget(values, labels, plot_widget)
        # A legend is only useful when there is more than one line to tell apart.
        if len(axes.get_lines()) > 1:
            axes.legend(loc="best", fontsize="small")
    axes.set_xlabel(hints.x_label(proxy_model))
    lines = axes.get_lines()
    if lines:
        # The title is taken from the label of the first plotted line.
        axes.set_title(lines[0].get_label())
    if needs_redraw:
        plot_widget.canvas.draw()
    return plot_widget
def plot_selection(model, indexes, hints, plot_widget=None):
    """
    Plots the selected table cells, column by column.

    Args:
        model (QAbstractTableModel): a model
        indexes (Iterable): a list of QModelIndex objects for plotting
        hints (PlottingHints): a helper needed for e.g. plot labels
        plot_widget (PlotWidget): an existing plot widget to draw into or None to create a new widget

    Returns:
        PlotWidget: the plot widget that was drawn into

    Raises:
        PlottingError: if the selected values clash with what the widget already shows
            or cannot be plotted
    """
    # Only a widget handed in by the caller is explicitly redrawn at the end.
    needs_redraw = plot_widget is not None
    if not needs_redraw:
        plot_widget = PlotWidget()
    rows_by_column = hints.filter_columns(_organize_selection_to_columns(indexes), model)
    all_labels = []
    for column, rows in rows_by_column.items():
        values, labels = _collect_column_values(model, column, rows, hints)
        all_labels.extend(labels)
        if not values:
            continue
        if plot_widget.plot_type is not None:
            _raise_if_value_types_clash(values, plot_widget)
        else:
            plot_widget.infer_plot_type(values)
        _add_plot_to_widget(values, labels, plot_widget)
    axes = plot_widget.canvas.axes
    axes.set_xlabel(hints.x_label(model))
    # Several labels go into a legend while a lone label becomes the title.
    label_count = len(all_labels)
    if label_count > 1:
        axes.legend(loc="best", fontsize="small")
    elif label_count == 1:
        axes.set_title(all_labels[0])
    if needs_redraw:
        plot_widget.canvas.draw()
    return plot_widget
def add_array_plot(plot_widget, value, label=None):
    """
    Draws an array on the axes of a plot widget.

    The array's indexes are used as x values and its values as y values.

    Args:
        plot_widget (PlotWidget): a plot widget to modify
        value (Array): the array to plot
        label (str): a label for the array
    """
    axes = plot_widget.canvas.axes
    x_data = value.indexes
    y_data = value.values
    axes.plot(x_data, y_data, label=label)
def add_map_plot(plot_widget, map_value, label=None):
    """
    Adds a map plot to a plot widget.

    The map is drawn as unconnected circle markers with its indexes, converted to
    strings when needed, on the x axis. A map without indexes is silently skipped.

    Args:
        plot_widget (PlotWidget): a plot widget to modify
        map_value (Map): the map to plot
        label (str): a label for the map

    Raises:
        PlottingError: if the map is nested or contains non-numerical values
    """
    if not map_value.indexes:
        return
    if map_value.is_nested():
        raise PlottingError("Plotting of nested maps is not supported.")
    # Accept every numeric type, not just float, so that e.g. integer values are
    # not reported as non-numerical.
    if not all(isinstance(value, Number) for value in map_value.values):
        raise PlottingError("Cannot plot non-numerical values in map.")
    if not isinstance(map_value.indexes[0], str):
        indexes_as_strings = list(map(str, map_value.indexes))
    else:
        indexes_as_strings = map_value.indexes
    axes = plot_widget.canvas.axes
    axes.plot(indexes_as_strings, map_value.values, label=label, linestyle="", marker="o")
    # Limit the number of x ticks since the x values are strings.
    axes.xaxis.set_major_locator(MaxNLocator(10))
def add_time_series_plot(plot_widget, value, label=None):
    """
    Draws a time series as a step plot on the axes of a plot widget.

    Args:
        plot_widget (PlotWidget): a plot widget to modify
        value (TimeSeries): the time series to plot
        label (str): a label for the time series
    """
    canvas = plot_widget.canvas
    axes = canvas.axes
    axes.step(value.indexes, value.values, label=label, where='post')
    # matplotlib cannot have time stamps before 0001-01-01T00:00 on the x axis
    # and 1.0 is the axis value that corresponds to 0001-01-01T00:00.
    lower_limit, _ = axes.get_xlim()
    if lower_limit < 1.0:
        axes.set_xlim(left=1.0)
    canvas.figure.autofmt_xdate()
class PlottingHints:
    """Base class for plotting hints.

    Hints let the plotting functions query labels, plottable cells and x values
    without knowing which table model or widget the data comes from. Apart from
    ``normalize_row``, every method must be overridden by subclasses.
    """

    def cell_label(self, model, index):
        """Returns a label for the table cell at ``index``."""
        raise NotImplementedError

    def column_label(self, model, column):
        """Returns a label for the given column."""
        raise NotImplementedError

    def filter_columns(self, selections, model):
        """Returns ``selections`` with the columns that should not be plotted removed."""
        raise NotImplementedError

    def is_index_in_data(self, model, index):
        """Returns True if the cell at ``index`` holds plottable data."""
        raise NotImplementedError

    @staticmethod
    def normalize_row(row, model):
        """Returns a 'human understandable', one-based row number."""
        return row + 1

    def special_x_values(self, model, column, rows):
        """Returns x values if there are any, None otherwise."""
        raise NotImplementedError

    def x_label(self, model):
        """Returns a label for the x axis."""
        raise NotImplementedError
class ParameterTablePlottingHints(PlottingHints):
    """Plotting hints for the Parameter table views."""

    def cell_label(self, model, index):
        """Returns the name the model gives to ``index``."""
        return model.index_name(index)

    def column_label(self, model, column):
        """Returns the header of ``column``."""
        return model.headerData(column)

    def filter_columns(self, selections, model):
        """Keeps only the columns whose header is 'value' or 'default_value'."""
        return {
            column: rows
            for column, rows in selections.items()
            if model.headerData(column) in ("value", "default_value")
        }

    def is_index_in_data(self, model, index):
        """Returns True for every index."""
        return True

    def special_x_values(self, model, column, rows):
        """Returns None since there are no special x values."""
        return None

    def x_label(self, model):
        """Returns an empty x axis label."""
        return ""
class PivotTablePlottingHints(PlottingHints):
    """Support for plotting data in Tabular view.

    The ``model`` arguments are proxy models: the methods reach the underlying pivot
    model through ``model.sourceModel()`` and translate indexes with ``mapToSource()``
    and ``mapFromSource()``. Columns received as arguments are proxy model columns.
    """

    def cell_label(self, model, index):
        """Returns a label for the table cell given by index."""
        source_index = model.mapToSource(index)
        return model.sourceModel().index_name(source_index)

    def column_label(self, model, column):
        """Returns a label for a table column.

        Args:
            model: a proxy model
            column (int): a proxy model column

        Returns:
            the name the source model gives to the corresponding source column
        """
        # The name is looked up from the source model so the proxy column must be
        # mapped to a source column first; the two differ when columns are filtered out.
        source_column = self._map_column_to_source(model, column)
        return model.sourceModel().column_name(source_column)

    def filter_columns(self, selections, model):
        """Filters the X column from selections."""
        x_column = model.sourceModel().plot_x_column
        if x_column is None or not model.filterAcceptsColumn(x_column, QModelIndex()):
            return selections
        proxy_x_column = self._map_column_from_source(model, x_column)
        return {column: rows for column, rows in selections.items() if column != proxy_x_column}

    def is_index_in_data(self, model, index):
        """Returns True if index is in the data portion of the table or in an index column."""
        source_index = model.mapToSource(index)
        source_model = model.sourceModel()
        return source_model.index_in_data(source_index) or source_model.column_is_index_column(source_index.column())

    @staticmethod
    def normalize_row(row, model):
        """Returns a one-based row number counted from the first row after the source model's header rows."""
        source_row = model.mapToSource(model.index(row, 0)).row()
        return source_row + 1 - model.sourceModel().headerRowCount()

    def special_x_values(self, model, column, rows):
        """Returns the values from the X column if one is designated otherwise returns None.

        None is also returned when the X column is filtered out or when ``column`` is
        the X column itself.
        """
        source_model = model.sourceModel()
        x_column = source_model.plot_x_column
        if x_column is None or not model.filterAcceptsColumn(x_column, QModelIndex()):
            return None
        proxy_x_column = self._map_column_from_source(model, x_column)
        if column == proxy_x_column:
            return None
        # x_column is a source model column so it is the one to test against the
        # source model; the proxy column is used only for reading through the proxy.
        if source_model.column_is_index_column(x_column):
            collect = _collect_index_column_values
        else:
            collect = _collect_x_column_values
        return collect(model, proxy_x_column, rows, self)

    def x_label(self, model):
        """Returns the label of the X column, if available."""
        x_column = model.sourceModel().plot_x_column
        if x_column is None or not model.filterAcceptsColumn(x_column, QModelIndex()):
            return ""
        if model.sourceModel().column_is_index_column(x_column):
            return "Index"
        return self.column_label(model, self._map_column_from_source(model, x_column))

    @staticmethod
    def _map_column_to_source(proxy_model, proxy_column):
        """Maps a proxy model column to source model."""
        return proxy_model.mapToSource(proxy_model.index(0, proxy_column)).column()

    @staticmethod
    def _map_column_from_source(proxy_model, source_column):
        """Maps a source model column to proxy model."""
        source_index = proxy_model.sourceModel().index(0, source_column)
        return proxy_model.mapFromSource(source_index).column()
[docs]def _add_plot_to_widget(values, labels, plot_widget): """Adds a new plot to plot_widget.""" if not values: return if isinstance(values[0], TimeSeries): for value, label in zip(values, labels): add_time_series_plot(plot_widget, value, label) elif isinstance(values[0], Map): for value, label in zip(values, labels): add_map_plot(plot_widget, value, label) elif isinstance(values[0], Array): for value, label in zip(values, labels): add_array_plot(plot_widget, value, label) elif isinstance(values[1][0], Number): plot_widget.canvas.axes.plot(values[0], values[1], label=labels[0]) if isinstance(values[0][0], str): # matplotlib tries to plot every single x tick label if they are strings. # This can become very slow if the labels are numerous. plot_widget.canvas.axes.xaxis.set_major_locator(MaxNLocator(10)) else: raise PlottingError(f"Cannot plot: Don't know how to plot '{type(values[1][0]).__name__}' values.")
[docs]def _raise_if_not_all_indexed_values(values): """Raises an exception if not all values are TimeSeries or Maps.""" if not values: return values first_value_type = type(values[0]) if issubclass(first_value_type, TimeSeries): # Clump fixed and variable step time series together. We can plot both at the same time. first_value_type = TimeSeries if not all(isinstance(value, first_value_type) for value in values[1:]): raise PlottingError("Cannot plot a mixture of indexed and other data")
[docs]def _filter_name_columns(selections): """Returns a dict with all but the entry with the greatest key removed.""" # In case of Tree and Graph views the user may have selected non-data columns for plotting. # This function removes those from the selected columns. last_column = max(selections.keys()) return {last_column: selections[last_column]}
[docs]def _organize_selection_to_columns(indexes): """Organizes a list of model indexes into a dictionary of {column: (rows)} entries.""" selections = dict() for index in indexes: selections.setdefault(index.column(), set()).add(index.row()) for column, rows in selections.items(): selections[column] = list(sorted(rows)) return selections
[docs]def _collect_single_column_values(model, column, rows, hints): """ Collects selected parameter values from a single column. The return value of this function depends on what type of data the given column contains. In case of plain numbers, a list of scalars and a single label string are returned. In case of indexed parameters (time series, maps), a list of parameter_value objects is returned, accompanied by a list of labels, each label corresponding to one of the indexed parameters. Args: model (QAbstractTableModel): a table model column (int): a column index to the model rows (Sequence): row indexes to plot hints (PlottingHints): a plot support object Returns: tuple: values and label(s) """ values = list() labels = list() for row in sorted(rows): data_index = model.index(row, column) if not hints.is_index_in_data(model, data_index): continue value = model.data(data_index, role=PARSED_ROLE) if isinstance(value, Exception): raise PlottingError(f"Failed to plot row {row}: {value}") if isinstance(value, (Array, Map, TimeSeries)): labels.append(hints.cell_label(model, data_index)) elif value is not None and not isinstance(value, Number): raise PlottingError(f"Cannot plot row {row}: don't know how to plot a '{type(value).__name__}'.") values.append(value) if not values: return values, labels if isinstance(first_non_null(values), float): labels.append(hints.column_label(model, column)) return values, labels
[docs]def _collect_x_column_values(model, column, rows, hints): """ Collects selected parameter values from an x column. Args: model (QAbstractTableModel): a table model column (int): a column index to the model rows (Sequence): row indexes to plot hints (PlottingHints): a plot support object Returns: a tuple of values and label(s) """ values = list() for row in sorted(rows): data_index = model.index(row, column) if not hints.is_index_in_data(model, data_index): continue value = model.data(data_index, role=PARSED_ROLE) if isinstance(value, Exception): raise PlottingError(f"Failed to plot '{value}'") if not isinstance(value, Number): raise PlottingError(f"Cannot plot X column value of type {type(value).__name__}.") values.append(value) if not values: return values return values
[docs]def _collect_index_column_values(model, column, rows, hints): """ Collects selected values from an index column. Args: model (QAbstractTableModel): a table model column (int): a column index to the model rows (Sequence): row indexes to plot hints (PlottingHints): a plot support object Returns: list: column's values """ values = list() for row in sorted(rows): data_index = model.index(row, column) if not hints.is_index_in_data(model, data_index): continue data_index = model.index(row, column) data = model.data(data_index, role=PARSED_ROLE) values.append(data) if not values: return values return values
def _collect_column_values(model, column, rows, hints):
    """
    Collects selected parameter values from a single column for plotting.

    The return value of this function depends on what type of data the given column contains.
    In case of plain numbers, a single tuple of two lists of x and y values
    and a single label string are returned.
    In case of indexed values (arrays, time series, expanded maps), a list of the value objects is returned,
    accompanied by a list of labels, each label corresponding to one of the values.

    Args:
        model (QAbstractTableModel): a table model
        column (int): a column index to the model
        rows (Sequence): row indexes to plot
        hints (PlottingHints): a support object

    Returns:
        tuple: a tuple of values and label(s)

    Raises:
        PlottingError: if the column's values cannot be plotted together
    """
    values, labels = _collect_single_column_values(model, column, rows, hints)
    if not values:
        return values, labels
    if isinstance(first_non_null(values), Map):
        # Empty cells show up as None in values but have no label. They must be dropped
        # before the maps are paired with labels or the pairing gets shifted.
        values, labels = _expand_maps([x for x in values if x is not None], labels)
    if isinstance(first_non_null(values), (Array, Map, TimeSeries)):
        values = [x for x in values if x is not None]
        _raise_if_not_all_indexed_values(values)
        _raise_if_indexed_values_not_plottable(values)
        return values, labels
    # The values are scalars, i.e. y values; collect the x values as well.
    x_values = hints.special_x_values(model, column, rows)
    if x_values is None:
        x_values = _x_values_from_rows(model, rows, hints)
    usable_x, usable_y = _filter_and_check(x_values, values)
    if not usable_x:
        return [], []
    return (usable_x, usable_y), labels
[docs]def _expand_maps(maps, labels): """ Gathers the leaf elements from ``maps`` and expands ``labels`` accordingly. Args: maps (list of Map): maps to expand labels (list of str): map labels Returns: tuple: expanded maps and labels """ expanded_values = list() expanded_labels = list() for map_, label in zip(maps, labels): if map_ is None: continue map_ = convert_leaf_maps_to_specialized_containers(map_) if isinstance(map_, (Array, TimeSeries)): expanded_values.append(map_) expanded_labels.append(label) continue nested_values, value_labels = _label_nested_maps(map_, label) expanded_values += nested_values expanded_labels += value_labels return expanded_values, expanded_labels
[docs]def _label_nested_maps(map_, label): """ Collects leaf values from given Maps and labels them. Args: map_ (Map): a map label (str): map's label Returns: tuple: list of values and list of corresponding labels """ if map_ and not map_.is_nested(): if isinstance(map_.values[0], (Array, TimeSeries)): labels = [label + " - " + str(index) for index in map_.indexes] values = list(map_.values) return values, labels return [map_], [label] values = list() labels = list() for index, value in zip(map_.indexes, map_.values): prefix_label = label + str(index) nested_values, nested_labels = _label_nested_maps(value, prefix_label) values += nested_values labels += nested_labels return values, labels
def _filter_and_check(xs, ys):
    """
    Drops the x, y pairs in which either value is None and unifies the data types.

    Every remaining x is converted to the type of the first non-null x and every
    remaining y to the type of the first non-null y.

    Args:
        xs (Iterable): x values
        ys (Iterable): y values

    Returns:
        tuple: list of usable x values and list of usable y values

    Raises:
        PlottingError: if a value cannot be converted to the type of the first value
    """
    to_x_type = type(first_non_null(xs))
    to_y_type = type(first_non_null(ys))
    kept_xs = []
    kept_ys = []
    for x, y in zip(xs, ys):
        if x is None or y is None:
            continue
        try:
            kept_xs.append(to_x_type(x))
            kept_ys.append(to_y_type(y))
        except (ParameterValueFormatError, TypeError, ValueError):
            raise PlottingError("Cannot plot a mixture of different types of data")
    return kept_xs, kept_ys
[docs]def _raise_if_indexed_values_not_plottable(values): """Raises an exception if the indexed values in values contain elements that cannot be plotted.""" for value in values: if isinstance(value.values, np.ndarray): if value.values.dtype.kind not in ("f", "M", "m", "i", "u"): raise PlottingError(f"Cannot plot values of type {value.values.dtype.name}.") continue if any(not isinstance(x, Number) for x in value.values): raise PlottingError(f"Cannot plot values of type {type(value.values[0]).__name__}.")
def _raise_if_value_types_clash(values, plot_widget):
    """
    Checks that ``values`` can be added to what ``plot_widget`` already shows.

    Args:
        values: either a sequence of indexed values or a pair of x and y value sequences
        plot_widget (PlotWidget): a plot widget whose ``plot_type`` has been set

    Raises:
        PlottingError: if the type of the values is incompatible with the widget's plot type
    """
    first = values[0]
    if not isinstance(first, IndexedValue):
        # Scalars: the first y value must match the widget's plot type.
        if not isinstance(values[1][0], plot_widget.plot_type):
            raise PlottingError("Cannot plot a mixture of indexed values and scalars.")
        return
    if isinstance(first, TimeSeries) and plot_widget.plot_type != TimeSeries:
        raise PlottingError("Cannot plot a mixture of time series and other value types.")
    if isinstance(first, Map) and plot_widget.plot_type != Map:
        raise PlottingError("Cannot plot a mixture of maps and other value types.")
[docs]def _x_values_from_rows(model, rows, hints): """Returns x value array constructed from model rows.""" normalize = functools.partial(hints.normalize_row, model=model) def row_to_index(row): return float(normalize(row)) x_values = np.asarray(list(map(row_to_index, rows))) return x_values