# Source code for spinetoolbox.import_editor.mvcmodels.source_data_table_model

######################################################################################################################
# Copyright (C) 2017-2020 Spine project consortium
# This file is part of Spine Toolbox.
# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
# any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""
Contains the source data table model.

:author: P. Vennström (VTT)
:date:   1.6.2019
"""
from PySide2.QtCore import Qt, Signal, Slot
from spinedb_api import (
    EntityClassMapping,
    ObjectClassMapping,
    RelationshipClassMapping,
    ObjectGroupMapping,
    AlternativeMapping,
    ScenarioMapping,
    ScenarioAlternativeMapping,
    ParameterDefinitionMapping,
    ParameterValueMapping,
    ParameterArrayMapping,
    MappingBase,
    ColumnHeaderMapping,
    ColumnMapping,
    RowMapping,
    ParameterValueFormatError,
    mapping_non_pivoted_columns,
)
from spinetoolbox.mvcmodels.minimal_table_model import MinimalTableModel
from spinetoolbox.spine_io.type_conversion import ConvertSpec
from .mapping_specification_model import MappingSpecificationModel
from ..mapping_colors import ERROR_COLOR, MAPPING_COLORS


class SourceDataTableModel(MinimalTableModel):
    """A model for import mapping specification.

    Highlights columns, rows, and so on, depending on Mapping specification.
    Also keeps track of the conversion types assigned to columns and rows
    and of the cells whose content could not be converted to the assigned type.
    """

    column_types_updated = Signal()
    row_types_updated = Signal()
    mapping_changed = Signal()
    about_to_undo = Signal(str)
    """Emitted when an undo/redo command is going to be executed."""

    def __init__(self, parent=None):
        """
        Args:
            parent (QObject, optional): parent object passed on to the base class
        """
        super().__init__(parent)
        self.default_flags = Qt.ItemIsEnabled | Qt.ItemIsSelectable
        self._mapping_specification = None
        # Section number -> ConvertSpec
        self._column_types = {}
        self._row_types = {}
        # (row, column) -> exception raised while converting the cell
        self._column_type_errors = {}
        self._row_type_errors = {}

    def _reset_type_info(self):
        """Forgets all column/row types and all recorded conversion errors."""
        self._column_type_errors = {}
        self._row_type_errors = {}
        self._column_types = {}
        self._row_types = {}

    def mapping_specification(self):
        """Returns the current mapping specification model or None if it has not been set."""
        return self._mapping_specification

    def clear(self):
        """Clears the type information and conversion errors, then clears the base model."""
        self._reset_type_info()
        super().clear()

    def reset_model(self, main_data=None):
        """Clears the type information and conversion errors, then resets the base model.

        Args:
            main_data: new data passed on to the base class's reset_model()
        """
        self._reset_type_info()
        super().reset_model(main_data)

    def set_mapping(self, mapping):
        """Sets the mapping specification model which determines the cell colors.

        Signals of a previously set specification are disconnected before the new one is connected.
        A falsy ``mapping`` is ignored.

        Args:
            mapping (MappingSpecificationModel): mapping model

        Raises:
            TypeError: if mapping is not a MappingSpecificationModel
        """
        if not mapping:
            return
        if not isinstance(mapping, MappingSpecificationModel):
            raise TypeError(
                f"mapping must be instance of 'MappingSpecificationModel', instead got: '{type(mapping).__name__}'"
            )
        if self._mapping_specification is not None:
            self._mapping_specification.dataChanged.disconnect(self._mapping_data_changed)
            self._mapping_specification.mapping_read_start_row_changed.disconnect(self._mapping_data_changed)
            self._mapping_specification.row_or_column_type_recommendation_changed.disconnect(self.set_type)
            self._mapping_specification.multi_column_type_recommendation_changed.disconnect(self.set_all_column_types)
        self._mapping_specification = mapping
        self._mapping_specification.dataChanged.connect(self._mapping_data_changed)
        self._mapping_specification.mapping_read_start_row_changed.connect(self._mapping_data_changed)
        self._mapping_specification.row_or_column_type_recommendation_changed.connect(self.set_type)
        self._mapping_specification.multi_column_type_recommendation_changed.connect(self.set_all_column_types)
        self._mapping_data_changed()

    def validate(self, section, orientation=Qt.Horizontal):
        """Tries to convert every cell of a column or row to the section's type and records the failures.

        Does nothing if no type has been set for the section.
        Emits dataChanged for the validated cells.

        Args:
            section (int): column (horizontal orientation) or row (vertical orientation) number
            orientation (Qt.Orientation): Qt.Horizontal to validate a column, otherwise a row is validated
        """
        type_class = self.get_type(section, orientation)
        if type_class is None:
            return
        horizontal = orientation == Qt.Horizontal
        if horizontal:
            other_orientation_count = self.rowCount()
            error_dict = self._column_type_errors
        else:
            other_orientation_count = self.columnCount()
            error_dict = self._row_type_errors

        def cell(other_index):
            """Returns the (row, column) tuple of a cell along the validated section."""
            return (other_index, section) if horizontal else (section, other_index)

        converter = type_class.convert_function()
        for other_index in range(other_orientation_count):
            index_tuple = cell(other_index)
            # The stale error must go before reading the cell:
            # data() returns the "Error" placeholder for cells that are still listed in the error dict.
            error_dict.pop(index_tuple, None)
            data = self.data(self.index(*index_tuple))
            if isinstance(data, str) and not data:
                # Empty strings are treated as missing values.
                data = None
            if data is None:
                continue
            try:
                converter(data)
            except (ValueError, ParameterValueFormatError) as e:
                error_dict[index_tuple] = e
        if other_orientation_count == 0:
            # No cells to report.
            return
        # dataChanged takes an inclusive range so the last cell is at count - 1.
        top_left = self.index(*cell(0))
        bottom_right = self.index(*cell(other_orientation_count - 1))
        self.dataChanged.emit(top_left, bottom_right)

    def get_type(self, section, orientation=Qt.Horizontal):
        """Returns the type set for a column or row.

        Args:
            section (int): column or row number
            orientation (Qt.Orientation): Qt.Horizontal for columns, otherwise rows

        Returns:
            ConvertSpec: section's type or None if no type has been set
        """
        if orientation == Qt.Horizontal:
            return self._column_types.get(section, None)
        return self._row_types.get(section, None)

    def get_types(self, orientation=Qt.Horizontal):
        """Returns the dictionary that maps column or row numbers to their types.

        The returned dictionary is the model's internal one, not a copy.

        Args:
            orientation (Qt.Orientation): Qt.Horizontal for columns, otherwise rows

        Returns:
            dict: mapping from section number to type
        """
        if orientation == Qt.Horizontal:
            return self._column_types
        return self._row_types

    @Slot(int, object, object)
    def set_type(self, section, section_type, orientation=Qt.Horizontal):
        """Sets the type of a single column or row and validates it.

        Out-of-range sections are silently ignored.

        Args:
            section (int): column or row number
            section_type (ConvertSpec): new type
            orientation (Qt.Orientation): Qt.Horizontal for columns, otherwise rows

        Raises:
            TypeError: if section_type is not a ConvertSpec
        """
        if orientation == Qt.Horizontal:
            count = self.columnCount()
            emit_signal = self.column_types_updated
            type_dict = self._column_types
        else:
            count = self.rowCount()
            emit_signal = self.row_types_updated
            type_dict = self._row_types
        if not isinstance(section_type, ConvertSpec):
            raise TypeError(
                f"section_type must be a instance of ConvertSpec, instead got {type(section_type).__name__}"
            )
        if section < 0 or section >= count:
            return
        type_dict[section] = section_type
        emit_signal.emit()
        self.validate(section, orientation)

    def set_types(self, sections, section_type, orientation):
        """Sets the same type for multiple columns or rows and validates each of them.

        The corresponding 'types updated' signal is emitted once after all sections have been processed.

        Args:
            sections (Iterable of int): column or row numbers
            section_type (ConvertSpec): new type
            orientation (Qt.Orientation): Qt.Horizontal for columns, otherwise rows
        """
        type_dict = self._column_types if orientation == Qt.Horizontal else self._row_types
        for section in sections:
            type_dict[section] = section_type
            self.validate(section, orientation)
        if orientation == Qt.Horizontal:
            self.column_types_updated.emit()
        else:
            self.row_types_updated.emit()

    @Slot(object, object)
    def set_all_column_types(self, excluded_columns, column_type):
        """Sets the type of every column that is not excluded.

        Unlike set_types(), the columns are not validated here.

        Args:
            excluded_columns (Container of int): columns whose type is left as is
            column_type (ConvertSpec): new type
        """
        for column in range(self.columnCount()):
            if column not in excluded_columns:
                self._column_types[column] = column_type
        self.column_types_updated.emit()

    @Slot()
    def _mapping_data_changed(self):
        """Refreshes the cell colors and emits mapping_changed."""
        self.update_colors()
        self.mapping_changed.emit()

    def update_colors(self):
        """Emits dataChanged for the background color role over the entire table."""
        top_left = self.index(0, 0)
        bottom_right = self.index(self.rowCount() - 1, self.columnCount() - 1)
        self.dataChanged.emit(top_left, bottom_right, [Qt.BackgroundColorRole])

    def data_error(self, index, role=Qt.DisplayRole, orientation=Qt.Horizontal):
        """Returns the data shown in a cell that failed type conversion.

        Args:
            index (QModelIndex): index of the cell
            role (int): data role
            orientation (Qt.Orientation): Qt.Horizontal if the failed type is the column's type,
                otherwise the row's type is used

        Returns:
            "Error" for the display role, an explanatory message for the tool tip role,
            the error color for the background role and None for all other roles
        """
        if role == Qt.DisplayRole:
            return "Error"
        if role == Qt.ToolTipRole:
            # Column types are keyed by column number, row types by row number.
            section = index.column() if orientation == Qt.Horizontal else index.row()
            type_name = self.get_type(section, orientation)
            return f'Could not parse value: "{self._main_data[index.row()][index.column()]}" as a {type_name}'
        if role == Qt.BackgroundColorRole:
            return ERROR_COLOR
        return None

    def data(self, index, role=Qt.DisplayRole):
        """Returns cell data: error information, mapping colors or the base model's data.

        Column type errors are shown on rows below both the last pivoted row and the read start row.
        Row type errors are shown on pivoted rows in columns that are neither non-pivoted nor skipped.
        Otherwise, the background role yields the mapping color (when a specification has been set)
        and the other roles are delegated to the base class.

        Args:
            index (QModelIndex): index of the cell
            role (int): data role

        Returns:
            data for given index and role
        """
        specification = self._mapping_specification
        if specification:
            last_pivoted_row = specification.last_pivot_row
            read_from_row = specification.read_start_row
        else:
            last_pivoted_row = -1
            read_from_row = 0
        row = index.row()
        column = index.column()
        if row > max(last_pivoted_row, read_from_row - 1):
            if (row, column) in self._column_type_errors:
                return self.data_error(index, role)
        # The specification check prevents dereferencing None on an invalid index (row -1).
        # The cheap error lookup comes first so the column lists are computed only for cells with an error.
        if specification and row <= last_pivoted_row and (row, column) in self._row_type_errors:
            if (
                column not in mapping_non_pivoted_columns(specification.mapping, self.columnCount(), self.header)
                and column not in specification.skip_columns
            ):
                return self.data_error(index, role, orientation=Qt.Vertical)
        if role == Qt.BackgroundColorRole and specification:
            return self.data_color(index)
        return super().data(index, role)

    def data_color(self, index):
        """
        Returns background color for index depending on mapping.

        Requires that a mapping specification has been set.

        Args:
            index (PySide2.QtCore.QModelIndex): index

        Returns:
            QColor: color of index or None if the index does not belong to any part of the mapping
        """
        mapping = self._mapping_specification.mapping
        if isinstance(mapping, EntityClassMapping):
            if isinstance(mapping.parameters, ParameterValueMapping):
                # parameter values color
                if mapping.is_pivoted():
                    last_row = max(mapping.last_pivot_row(), mapping.read_start_row - 1)
                    if index.row() > last_row and index.column() not in self.mapping_column_ref_int_list():
                        return MAPPING_COLORS["parameter_value"]
                elif self.index_in_mapping(mapping.parameters.value, index):
                    return MAPPING_COLORS["parameter_value"]
                elif self.index_in_mapping(mapping.parameters.alternative_name, index):
                    return MAPPING_COLORS["alternative"]
                if isinstance(mapping.parameters, ParameterArrayMapping) and mapping.parameters.extra_dimensions:
                    # parameter extra dimensions color
                    for ed in mapping.parameters.extra_dimensions:
                        if self.index_in_mapping(ed, index):
                            return MAPPING_COLORS["parameter_extra_dimension"]
            if isinstance(mapping.parameters, ParameterDefinitionMapping) and self.index_in_mapping(
                mapping.parameters.name, index
            ):
                # parameter name colors
                return MAPPING_COLORS["parameter_name"]
        if not isinstance(
            mapping, (AlternativeMapping, ScenarioMapping, ScenarioAlternativeMapping)
        ) and self.index_in_mapping(mapping.name, index):
            return MAPPING_COLORS["entity_class"]
        classes = []
        objects = []
        if isinstance(mapping, ObjectClassMapping):
            objects = [mapping.objects]
        elif isinstance(mapping, ObjectGroupMapping):
            if self.index_in_mapping(mapping.groups, index):
                return MAPPING_COLORS["group"]
            objects = [mapping.members]
        elif isinstance(mapping, RelationshipClassMapping):
            objects = mapping.objects
            classes = mapping.object_classes
        elif isinstance(mapping, AlternativeMapping):
            if self.index_in_mapping(mapping.name, index):
                return MAPPING_COLORS["alternative"]
        elif isinstance(mapping, ScenarioMapping):
            if self.index_in_mapping(mapping.name, index):
                return MAPPING_COLORS["scenario"]
            if self.index_in_mapping(mapping.active, index):
                return MAPPING_COLORS["active"]
        elif isinstance(mapping, ScenarioAlternativeMapping):
            if self.index_in_mapping(mapping.scenario_name, index):
                return MAPPING_COLORS["scenario"]
            if self.index_in_mapping(mapping.alternative_name, index):
                return MAPPING_COLORS["alternative"]
            if self.index_in_mapping(mapping.before_alternative_name, index):
                return MAPPING_COLORS["before_alternative"]
        for o in objects:
            # object colors
            if self.index_in_mapping(o, index):
                return MAPPING_COLORS["entity"]
        for c in classes:
            # object class colors
            if self.index_in_mapping(c, index):
                return MAPPING_COLORS["entity_class"]
        return None

    def index_in_mapping(self, mapping, index):
        """
        Checks if index is in mapping

        Args:
            mapping (MappingBase): mapping
            index (QModelIndex): index

        Returns:
            bool: True if index is in mapping, False otherwise
        """
        if not isinstance(mapping, MappingBase):
            return False
        if isinstance(mapping, ColumnHeaderMapping):
            # column header can't be in data
            return False
        if isinstance(mapping, ColumnMapping):
            ref = mapping.reference
            if isinstance(ref, str):
                # find header reference; an unknown header is left as a string and matches no column
                if ref in self.header:
                    ref = self.header.index(ref)
            if index.column() == ref:
                if self._mapping_specification.mapping.is_pivoted():
                    # only rows below pivoted rows
                    last_row = max(
                        self._mapping_specification.mapping.last_pivot_row(),
                        self._mapping_specification.read_start_row - 1,
                    )
                    if index.row() > last_row:
                        return True
                elif index.row() >= self._mapping_specification.read_start_row:
                    return True
        if isinstance(mapping, RowMapping):
            if index.row() == mapping.reference:
                if index.column() not in self.mapping_column_ref_int_list():
                    return True
        return False

    def mapping_column_ref_int_list(self):
        """Returns the indexes of the mapping's non-pivoted and skipped columns.

        Columns referred to by header label are converted to column indexes;
        labels that are not found among the horizontal header labels are dropped.

        Returns:
            list of int: column indexes in no particular order; empty if no mapping specification has been set
        """
        if not self._mapping_specification:
            return []
        non_pivoted_columns = self._mapping_specification.mapping.non_pivoted_columns()
        skip_cols = self._mapping_specification.mapping.skip_columns
        if skip_cols is None:
            skip_cols = []
        int_non_piv_cols = []
        for pc in set(non_pivoted_columns + skip_cols):
            if isinstance(pc, str):
                try:
                    pc = self.horizontal_header_labels().index(pc)
                except ValueError:
                    continue
            int_non_piv_cols.append(pc)
        return int_non_piv_cols