Source code for spine_io.io_models

######################################################################################################################
# Copyright (C) 2017 - 2019 Spine project consortium
# This file is part of Spine Toolbox.
# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
# any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################

"""
Classes for handling models in PySide2's model/view framework.

:author: P. Vennström (VTT)
:date:   1.6.2019
"""
from spinedb_api import ObjectClassMapping, RelationshipClassMapping, ParameterMapping, Mapping
from PySide2.QtCore import QModelIndex, Qt, QAbstractTableModel, QAbstractListModel
from PySide2.QtGui import QColor
from models import MinimalTableModel

# Parameter type names as shown in the UI, keyed to the identifiers stored in
# ParameterMapping.parameter_type.
_DISPLAY_TYPE_TO_TYPE = {
    "Single value": "single value",
    "List": "1d array",
    "Time series": "time series",
    "Time pattern": "time pattern",
    "Definition": "definition",
}

# Reverse lookup: stored type identifier -> name shown in the UI.
_TYPE_TO_DISPLAY_TYPE = {stored: shown for shown, stored in _DISPLAY_TYPE_TO_TYPE.items()}
class MappingPreviewModel(MinimalTableModel):
    """A model for highlighting columns, rows, and so on, depending on Mapping specification.

    Used by ImportPreviewWidget.
    """

    def __init__(self, parent=None):
        """Creates a preview model with no mapping attached.

        Keyword Arguments:
            parent -- forwarded to the MinimalTableModel constructor (default: {None})
        """
        super().__init__(parent)
        self.default_flags = Qt.ItemIsEnabled | Qt.ItemIsSelectable
        self._mapping = None
        # Result of dataChanged.connect() for the current mapping; None while nothing is connected.
        self._data_changed_signal = None

    def set_mapping(self, mapping):
        """Set mapping to display colors from

        Disconnects update_colors from the previous mapping (if one was connected), connects it
        to the new mapping's dataChanged signal and refreshes the colors.

        Arguments:
            mapping {MappingSpecModel} -- mapping model
        """
        if self._data_changed_signal is not None and self._mapping:
            self._mapping.dataChanged.disconnect(self.update_colors)
            self._data_changed_signal = None
        self._mapping = mapping
        if self._mapping:
            self._data_changed_signal = self._mapping.dataChanged.connect(self.update_colors)
        self.update_colors()

    def update_colors(self):
        """Emits dataChanged for the background color role so that views repaint."""
        # The signal takes QModelIndex *instances*; invalid indexes are used because the whole
        # table may need recoloring.
        self.dataChanged.emit(QModelIndex(), QModelIndex(), [Qt.BackgroundColorRole])

    def data(self, index, role=Qt.DisplayRole):
        """Returns the mapping color for Qt.BackgroundColorRole when a mapping is set,
        otherwise defers to the base class.

        Arguments:
            index {PySide2.QtCore.QModelIndex} -- index
            role {int} -- item data role (default: {Qt.DisplayRole})
        """
        if role == Qt.BackgroundColorRole and self._mapping:
            return self.data_color(index)
        return super().data(index, role)

    def data_color(self, index):
        """returns background color for index depending on mapping

        The first matching rule wins, in this order: parameter values, parameter extra
        dimensions, parameter names, class name, objects, object classes.

        Arguments:
            index {PySide2.QtCore.QModelIndex} -- index

        Returns:
            [QColor] -- QColor of index, or None when the index is not part of the mapping
        """
        mapping = self._mapping._model
        parameters = mapping.parameters
        if parameters is not None:
            # parameter colors
            if mapping.is_pivoted() and parameters.parameter_type != "definition":
                # parameter values color: every cell below the pivot rows that is not in a
                # non-pivoted or skipped column
                last_row = mapping.last_pivot_row()
                if (
                    last_row is not None
                    and index.row() > last_row
                    and index.column() not in self.mapping_column_ref_int_list()
                ):
                    return QColor(1, 133, 113)
            elif self.index_in_mapping(parameters.value, index):
                return QColor(1, 133, 113)
            if parameters.extra_dimensions:
                # parameter extra dimensions color
                for extra_dimension in parameters.extra_dimensions:
                    if self.index_in_mapping(extra_dimension, index):
                        return QColor(128, 205, 193)
            if self.index_in_mapping(parameters.name, index):
                # parameter name colors
                return QColor(128, 205, 193)
        if self.index_in_mapping(mapping.name, index):
            # class name color
            return QColor(166, 97, 26)
        objects = []
        classes = []
        if isinstance(mapping, ObjectClassMapping):
            objects = [mapping.object]
        else:
            if mapping.objects:
                objects = mapping.objects
            if mapping.object_classes:
                classes = mapping.object_classes
        for object_mapping in objects:
            # object colors
            if self.index_in_mapping(object_mapping, index):
                return QColor(223, 194, 125)
        for class_mapping in classes:
            # object class colors (same color as the class name)
            if self.index_in_mapping(class_mapping, index):
                return QColor(166, 97, 26)
        return None

    def index_in_mapping(self, mapping, index):
        """Checks if index is in mapping

        Arguments:
            mapping {Mapping} -- mapping
            index {QModelIndex} -- index

        Returns:
            [bool] -- returns True if mapping is in index
        """
        if not isinstance(mapping, Mapping):
            return False
        if mapping.map_type == "column":
            ref = mapping.value_reference
            if isinstance(ref, str):
                # find header reference
                if ref in self._headers:
                    ref = self._headers.index(ref)
            if index.column() == ref:
                if self._mapping._model.is_pivoted():
                    # only rows below pivoted rows
                    last_row = self._mapping._model.last_pivot_row()
                    if last_row is not None and index.row() > last_row:
                        return True
                else:
                    return True
        if mapping.map_type == "row":
            if index.row() == mapping.value_reference:
                if index.column() not in self.mapping_column_ref_int_list():
                    return True
        return False

    def mapping_column_ref_int_list(self):
        """Returns a list of column indexes that are not pivoted

        Combines the mapping's non-pivoted columns with its skipped columns. Columns given by
        header name are translated to their position; names missing from the header are dropped.

        Returns:
            [List[int]] -- list of ints
        """
        if not self._mapping:
            return []
        non_pivoted_columns = self._mapping._model.non_pivoted_columns()
        skip_cols = self._mapping._model.skip_columns
        if skip_cols is None:
            skip_cols = []
        # Fetch the header once instead of on every loop iteration.
        header_labels = self.horizontal_header_labels()
        int_non_piv_cols = []
        for column in set(non_pivoted_columns + skip_cols):
            if isinstance(column, str):
                if column not in header_labels:
                    continue
                column = header_labels.index(column)
            int_non_piv_cols.append(column)
        return int_non_piv_cols
class MappingSpecModel(QAbstractTableModel):
    """
    A model to hold a Mapping specification.

    Each row is one mapping target (class name, object names, parameter names, ...). The columns
    are "Target", "Source type", "Source ref.", "Prepend string" and "Append string".
    """

    def __init__(self, model, parent=None):
        """
        Arguments:
            model {ObjectClassMapping, RelationshipClassMapping} -- mapping to hold, or None

        Keyword Arguments:
            parent -- forwarded to the QAbstractTableModel constructor (default: {None})
        """
        super().__init__(parent)
        self._display_names = []
        self._mappings = []
        self._model = None
        if model is not None:
            self.set_mapping(model)

    @staticmethod
    def _position_from_name(name):
        """Returns the zero-based position encoded in a display name such as "Object class 2 names",
        or None when the name contains no number."""
        positions = [int(word) - 1 for word in name.split() if word.isdigit()]
        return positions[0] if positions else None

    @property
    def map_type(self):
        """Class of the held mapping, or None when there is no mapping."""
        if self._model is None:
            return None
        return type(self._model)

    @property
    def dimension(self):
        """Number of objects in the mapping: 0 without a mapping, 1 for an object class mapping."""
        if self._model is None:
            return 0
        if isinstance(self._model, ObjectClassMapping):
            return 1
        return len(self._model.objects)

    @property
    def import_objects(self):
        """The relationship mapping's import_objects flag; True for object class mappings
        and False when there is no mapping."""
        if self._model is None:
            return False
        if isinstance(self._model, RelationshipClassMapping):
            return self._model.import_objects
        return True

    @property
    def parameter_type(self):
        """Display name of the parameter type, or "None" when no parameters are mapped."""
        # Guard against a missing model like the other properties do.
        if self._model is None or self._model.parameters is None:
            return "None"
        return _TYPE_TO_DISPLAY_TYPE[self._model.parameters.parameter_type]

    @property
    def is_pivoted(self):
        """True if the held mapping reports itself as pivoted; False without a mapping."""
        if self._model:
            return self._model.is_pivoted()
        return False

    def set_import_objects(self, flag):
        """Stores bool(flag) as the mapping's import_objects attribute and emits dataChanged."""
        self._model.import_objects = bool(flag)
        self.dataChanged.emit(QModelIndex(), QModelIndex(), [])

    def set_mapping(self, mapping):
        """Replaces the held mapping and rebuilds the table.

        Nothing happens if mapping is an instance of the current mapping's class.

        Arguments:
            mapping {ObjectClassMapping, RelationshipClassMapping} -- new mapping

        Raises:
            TypeError -- if mapping is of any other type
        """
        if not isinstance(mapping, (RelationshipClassMapping, ObjectClassMapping)):
            raise TypeError(
                f"mapping must be of type: RelationshipClassMapping, ObjectClassMapping instead got {type(mapping)}"
            )
        if isinstance(mapping, type(self._model)):
            return
        self.beginResetModel()
        self._model = mapping
        if isinstance(self._model, RelationshipClassMapping):
            if self._model.objects is None:
                # a relationship mapping always gets at least one object / object class row
                self._model.objects = [None]
                self._model.object_classes = [None]
        self.update_display_table()
        self.dataChanged.emit(QModelIndex(), QModelIndex(), [])
        self.endResetModel()

    def set_dimension(self, dim):
        """Sets the number of objects / object classes of a relationship mapping.

        Lists longer than dim are truncated; shorter lists are padded with None up to dim.
        Does nothing without a mapping or for an object class mapping.

        Arguments:
            dim {int} -- new dimension
        """
        if self._model is None or isinstance(self._model, ObjectClassMapping):
            return
        self.beginResetModel()
        objects = self._model.objects
        object_classes = self._model.object_classes
        if len(objects) >= dim:
            self._model.objects = objects[:dim]
            self._model.object_classes = object_classes[:dim]
        else:
            # Pad all the way to dim; adding a single slot would leave the model out of sync
            # with the caller when dim grows by more than one.
            self._model.objects = objects + [None] * (dim - len(objects))
            self._model.object_classes = object_classes + [None] * max(1, dim - len(object_classes))
        self.update_display_table()
        self.dataChanged.emit(QModelIndex(), QModelIndex(), [])
        self.endResetModel()

    def change_model_class(self, new_class):
        """
        Change model between Relationship and Object class

        Arguments:
            new_class {str} -- "Object" selects ObjectClassMapping, anything else
                RelationshipClassMapping
        """
        self.beginResetModel()
        if new_class == "Object":
            new_class = ObjectClassMapping
        else:
            new_class = RelationshipClassMapping
        if self._model is None:
            self._model = new_class()
        elif not isinstance(self._model, new_class):
            parameters = self._model.parameters
            if new_class == RelationshipClassMapping:
                # convert object mapping to relationship mapping
                obj = [self._model.object]
                object_class = [self._model.name]
                self._model = RelationshipClassMapping(
                    name=None, object_classes=object_class, objects=obj, parameters=parameters
                )
            else:
                # convert relationship mapping to object mapping, keeping the first dimension
                self._model = ObjectClassMapping(
                    name=self._model.object_classes[0], obj=self._model.objects[0], parameters=parameters
                )
        self.update_display_table()
        self.dataChanged.emit(QModelIndex(), QModelIndex(), [])
        self.endResetModel()

    def change_parameter_type(self, new_type):
        """
        Change parameter type

        Arguments:
            new_type {str} -- "None" or one of the keys of _DISPLAY_TYPE_TO_TYPE
        """
        self.beginResetModel()
        if new_type == "None":
            self._model.parameters = None
        elif new_type in ("Single value", "List", "Definition"):
            if self._model.parameters is None:
                self._model.parameters = ParameterMapping()
            self._model.parameters.extra_dimensions = None
            if new_type == "Definition":
                self._model.parameters.value = None
            self._model.parameters.parameter_type = _DISPLAY_TYPE_TO_TYPE[new_type]
        elif new_type in ("Time series", "Time pattern"):
            if self._model.parameters is None:
                self._model.parameters = ParameterMapping(extra_dimensions=[None])
            if not self._model.parameters.extra_dimensions:
                # None or empty: update_display_table() needs exactly one index mapping slot
                self._model.parameters.extra_dimensions = [None]
            else:
                self._model.parameters.extra_dimensions = self._model.parameters.extra_dimensions[:1]
            self._model.parameters.parameter_type = _DISPLAY_TYPE_TO_TYPE[new_type]
        self.update_display_table()
        self.dataChanged.emit(QModelIndex(), QModelIndex(), [])
        self.endResetModel()

    def update_display_table(self):
        """Rebuilds the parallel lists of row names and row mappings from the held mapping."""
        display_name = []
        mappings = [self._model.name]
        if isinstance(self._model, RelationshipClassMapping):
            display_name.append("Relationship class names")
            if self._model.object_classes:
                display_name.extend(
                    f"Object class {number} names" for number in range(1, len(self._model.object_classes) + 1)
                )
                mappings.extend(self._model.object_classes)
            if self._model.objects:
                display_name.extend(f"Object {number} names" for number in range(1, len(self._model.objects) + 1))
                mappings.extend(self._model.objects)
        else:
            display_name.append("Object class names")
            display_name.append("Object names")
            mappings.append(self._model.object)
        parameters = self._model.parameters
        if parameters:
            display_name.append("Parameter names")
            mappings.append(parameters.name)
            if parameters.parameter_type != "definition":
                display_name.append("Parameter values")
                mappings.append(parameters.value)
            if parameters.parameter_type == "time series":
                display_name.append("Parameter time index")
                mappings.append(parameters.extra_dimensions[0])
            if parameters.parameter_type == "time pattern":
                display_name.append("Parameter time pattern index")
                mappings.append(parameters.extra_dimensions[0])
        self._display_names = display_name
        self._mappings = mappings

    def get_map_type_display(self, mapping, name):
        """Returns the "Source type" text for a row.

        Arguments:
            mapping -- the row's mapping: None, a str constant or a Mapping
            name {str} -- the row's display name

        Returns:
            [str] -- "Pivoted", "None", "Constant", "Column", "Header" or "Row"
        """
        if name == "Parameter values" and self._model.is_pivoted():
            return "Pivoted"
        if isinstance(mapping, str):
            return "Constant"
        if isinstance(mapping, Mapping):
            if mapping.map_type == "column":
                return "Column"
            if mapping.map_type == "column_name":
                return "Header"
            if mapping.map_type == "row":
                return "Row"
        # None, and as a fallback anything unrecognized, so that a view never fails on display
        return "None"

    def get_map_value_display(self, mapping, name):
        """Returns the "Source ref." text for a row.

        Arguments:
            mapping -- the row's mapping: None, a str constant or a Mapping
            name {str} -- the row's display name

        Returns:
            [str] -- "Pivoted values", the constant itself, "Headers" for a row mapping that
                references -1, otherwise the mapping's value_reference as a string;
                "" when there is nothing to show
        """
        if name == "Parameter values" and self._model.is_pivoted():
            return "Pivoted values"
        if isinstance(mapping, str):
            return mapping
        if isinstance(mapping, Mapping):
            if mapping.map_type == "row" and mapping.value_reference == -1:
                return "Headers"
            return str(mapping.value_reference)
        # None, and as a fallback anything unrecognized
        return ""

    # pylint: disable=no-self-use
    def get_map_append_display(self, mapping, name):
        """Returns the mapping's append_str if mapping is a Mapping, otherwise ""."""
        append_str = ""
        if isinstance(mapping, Mapping):
            append_str = mapping.append_str
        return append_str

    # pylint: disable=no-self-use
    def get_map_prepend_display(self, mapping, name):
        """Returns the mapping's prepend_str if mapping is a Mapping, otherwise ""."""
        prepend_str = ""
        if isinstance(mapping, Mapping):
            prepend_str = mapping.prepend_str
        return prepend_str

    def data(self, index, role=Qt.DisplayRole):
        """Returns the display text of a cell; None for every role other than Qt.DisplayRole."""
        if role != Qt.DisplayRole:
            return None
        row = index.row()
        name = self._display_names[row]
        mapping = self._mappings[row]
        # one getter per column, in header order
        getters = (
            lambda _mapping, target: target,
            self.get_map_type_display,
            self.get_map_value_display,
            self.get_map_prepend_display,
            self.get_map_append_display,
        )
        return getters[index.column()](mapping, name)

    def rowCount(self, index=None):
        """Returns the number of mapping targets; 0 without a mapping."""
        if not self._model:
            return 0
        return len(self._display_names)

    def columnCount(self, index=None):
        """Returns 5 when a mapping is held, otherwise 0."""
        if not self._model:
            return 0
        return 5

    def headerData(self, section, orientation, role):
        """Returns the horizontal header text for Qt.DisplayRole; None otherwise."""
        if role == Qt.DisplayRole and orientation == Qt.Horizontal:
            return ["Target", "Source type", "Source ref.", "Prepend string", "Append string"][section]
        return None

    def flags(self, index):
        """Returns the item flags of a cell.

        The target column is never editable. Rows without a Mapping object (None or a constant)
        allow editing of the type and reference only.
        """
        editable = Qt.ItemIsEnabled | Qt.ItemIsSelectable | Qt.ItemIsEditable
        non_editable = Qt.ItemIsEnabled | Qt.ItemIsSelectable
        column = index.column()
        if column == 0:
            return non_editable
        mapping = self._mappings[index.row()]
        if self._model.is_pivoted() and self._display_names[index.row()] == "Parameter values":
            # special case when we have pivoted data, the values should be
            # columns under pivoted indexes
            return non_editable
        if mapping is None or isinstance(mapping, str):
            # prepend/append strings need a Mapping object to live in
            return editable if column <= 2 else non_editable
        if isinstance(mapping, Mapping) and mapping.map_type == "row" and mapping.value_reference == -1:
            # a header row mapping has a fixed reference
            return non_editable if column == 2 else editable
        return editable

    def setData(self, index, value, role=Qt.EditRole):
        """Applies an edit to the cell's column; role is not inspected.

        Returns:
            [bool] -- True if the mapping was updated
        """
        name = self._display_names[index.row()]
        column = index.column()
        if column == 1:
            return self.set_type(name, value)
        if column == 2:
            return self.set_value(name, value)
        if column == 3:
            return self.set_prepend_str(name, value)
        if column == 4:
            return self.set_append_str(name, value)
        return False

    def set_type(self, name, value):
        """Replaces the mapping of the named row with a fresh one of the given source type.

        Arguments:
            name {str} -- row display name
            value {str} -- "None", "Constant", "Column", "Header", "Pivoted Headers" or "Row"

        Returns:
            [bool] -- False for an unknown type or row name
        """
        if value in ("None", "", None):
            value = None
        elif value == "Constant":
            value = ""
        elif value == "Column":
            value = Mapping(map_type="column")
        elif value == "Header":
            value = Mapping(map_type="column_name")
        elif value == "Pivoted Headers":
            value = Mapping(map_type="row", value_reference=-1)
        elif value == "Row":
            value = Mapping(map_type="row")
        else:
            return False
        return self.set_mapping_from_name(name, value)

    def set_value(self, name, value):
        """Sets the source reference of the named row.

        Arguments:
            name {str} -- row display name
            value {str} -- text entered by the user

        Returns:
            [bool] -- False if the value is not acceptable for the row's mapping
        """
        mapping = self.get_mapping_from_name(name)
        if mapping is None:
            if value.isdigit():
                # create new mapping
                mapping = Mapping(map_type="column", value_reference=int(value))
            elif value == "":
                return False
            else:
                # string mapping
                mapping = value
        elif isinstance(mapping, str):
            # update constant; clearing it removes the mapping
            mapping = None if value == "" else value
        else:
            # Accept "Headers" too: that is the text get_map_value_display() shows for -1.
            if mapping.map_type == "row" and value.lower() in ("header", "headers"):
                value = -1
            if value == "":
                value = None
            if value is not None:
                try:
                    value = int(value)
                except ValueError:
                    return False
                # rows may reference the header (-1), other references start at 0
                lowest = -1 if mapping.map_type == "row" else 0
                value = max(lowest, value)
            mapping.value_reference = value
        return self.set_mapping_from_name(name, mapping)

    def set_append_str(self, name, value):
        """Sets the append string of the named row's Mapping; "" clears it.

        Returns:
            [bool] -- False if the row does not hold a Mapping
        """
        mapping = self.get_mapping_from_name(name)
        if mapping and isinstance(mapping, Mapping):
            mapping.append_str = None if value == "" else value
            return self.set_mapping_from_name(name, mapping)
        return False

    def set_prepend_str(self, name, value):
        """Sets the prepend string of the named row's Mapping; "" clears it.

        Returns:
            [bool] -- False if the row does not hold a Mapping
        """
        mapping = self.get_mapping_from_name(name)
        if mapping and isinstance(mapping, Mapping):
            mapping.prepend_str = None if value == "" else value
            return self.set_mapping_from_name(name, mapping)
        return False

    def get_mapping_from_name(self, name):
        """Returns the mapping behind a row display name.

        Arguments:
            name {str} -- row display name

        Returns:
            the mapping, or None if there is no model, the name is not recognized or the
            name refers to parameters while none are mapped
        """
        if not self._model:
            return None
        if name in ("Relationship class names", "Object class names"):
            return self._model.name
        if name == "Object names":
            return self._model.object
        if "Object class " in name:
            position = self._position_from_name(name)
            return None if position is None else self._model.object_classes[position]
        if "Object " in name:
            position = self._position_from_name(name)
            return None if position is None else self._model.objects[position]
        parameters = self._model.parameters
        if parameters is None:
            return None
        if name == "Parameter names":
            return parameters.name
        if name == "Parameter values":
            return parameters.value
        if name in ("Parameter time index", "Parameter time pattern index"):
            return parameters.extra_dimensions[0]
        return None

    def set_mapping_from_name(self, name, mapping):
        """Stores mapping as the mapping behind a row display name and refreshes the table.

        Arguments:
            name {str} -- row display name
            mapping -- None, a str constant or a Mapping

        Returns:
            [bool] -- False if the name is not recognized or refers to parameters while none
                are mapped
        """
        if name in ("Relationship class names", "Object class names"):
            self._model.name = mapping
        elif name == "Object names":
            self._model.object = mapping
        elif "Object class " in name:
            position = self._position_from_name(name)
            if position is not None:
                self._model.object_classes[position] = mapping
        elif "Object " in name:
            position = self._position_from_name(name)
            if position is not None:
                self._model.objects[position] = mapping
        elif name in ("Parameter names", "Parameter values", "Parameter time index", "Parameter time pattern index"):
            parameters = self._model.parameters
            if parameters is None:
                return False
            if name == "Parameter names":
                parameters.name = mapping
            elif name == "Parameter values":
                parameters.value = mapping
            else:
                parameters.extra_dimensions = [mapping]
        else:
            return False
        self.update_display_table()
        if name in self._display_names:
            self.dataChanged.emit(QModelIndex(), QModelIndex(), [])
        return True

    def set_skip_columns(self, columns=None):
        """Stores the de-duplicated columns as the mapping's skip_columns and emits dataChanged.

        Keyword Arguments:
            columns {list} -- columns to skip; None means no columns (default: {None})
        """
        if columns is None:
            columns = []
        self._model.skip_columns = list(set(columns))
        # dataChanged takes model indexes, not ints
        self.dataChanged.emit(QModelIndex(), QModelIndex(), [])
class MappingListModel(QAbstractListModel):
    """
    A model to hold a list of Mappings.

    Every mapping is wrapped in a MappingSpecModel and shown as "Mapping <n>", where n comes
    from a counter that only ever increases.
    """

    def __init__(self, mapping_list, parent=None):
        """
        Arguments:
            mapping_list -- iterable of mappings accepted by MappingSpecModel

        Keyword Arguments:
            parent -- forwarded to the QAbstractListModel constructor (default: {None})
        """
        super().__init__(parent)
        self._qmappings = []
        self._names = []
        self._counter = 1
        self.set_model(mapping_list)

    def set_model(self, model):
        """Replaces all mappings with those in model. The name counter is not reset."""
        self.beginResetModel()
        self._names = []
        self._qmappings = []
        for mapping in model:
            self._names.append("Mapping " + str(self._counter))
            self._qmappings.append(MappingSpecModel(mapping))
            self._counter += 1
        self.endResetModel()

    def get_mappings(self):
        """Returns the mappings held by the MappingSpecModels, in list order."""
        return [spec_model._model for spec_model in self._qmappings]

    def rowCount(self, index=None):
        """Returns the number of mappings."""
        return len(self._qmappings)

    def data_mapping(self, index):
        """Returns the MappingSpecModel at index.row(), or None if the row is past the end.

        NOTE(review): a negative row (e.g. from an invalid index) wraps around to the end of
        the list.
        """
        if self._qmappings and index.row() < len(self._qmappings):
            return self._qmappings[index.row()]
        return None

    def data(self, index, role=Qt.DisplayRole):
        """Returns the mapping's name for Qt.DisplayRole; None for invalid indexes and other roles."""
        if not index.isValid():
            return None
        if self._qmappings and role == Qt.DisplayRole and index.row() < self.rowCount():
            return self._names[index.row()]
        return None

    def add_mapping(self):
        """Appends a new, empty ObjectClassMapping to the list."""
        new_row = self.rowCount()
        # Rows of a list model are top-level items, so the parent is an invalid index.
        self.beginInsertRows(QModelIndex(), new_row, new_row)
        self._qmappings.append(MappingSpecModel(ObjectClassMapping()))
        self._names.append("Mapping " + str(self._counter))
        self._counter += 1
        self.endInsertRows()

    def remove_mapping(self, row):
        """Removes the mapping at row; rows outside the list are ignored.

        Arguments:
            row {int} -- row to remove
        """
        if 0 <= row < len(self._qmappings):
            # The parent must be an invalid index: passing the item's own index would tell
            # views that a child of that item was removed.
            self.beginRemoveRows(QModelIndex(), row, row)
            self._qmappings.pop(row)
            self._names.pop(row)
            self.endRemoveRows()