######################################################################################################################
# Copyright (C) 2017-2020 Spine project consortium
# This file is part of Spine Toolbox.
# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
# any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""
Provides pivot table models for the Tabular View.
:author: P. Vennström (VTT)
:date: 1.11.2018
"""
from PySide2.QtCore import Slot, QAbstractTableModel, Qt, QModelIndex, QSortFilterProxyModel
from PySide2.QtGui import QColor, QFont
from .pivot_model import PivotModel
from ..config import PIVOT_TABLE_HEADER_COLOR
class PivotTableModel(QAbstractTableModel):
    """Table model that presents a PivotModel as a grid of header cells, data cells and trailing empty cells.

    Data rows and columns are exposed lazily: ``_data_row_count`` and ``_data_column_count`` count how many
    pivoted rows/columns are currently visible to views, and ``fetchMore`` grows them in batches.
    """

    # NOTE(review): headerRowCount, index_in_row_headers, index_in_headers, index_in_empty_row_headers,
    # header_data, header_names, _header_name, _header_ids, _batch_set_header_data,
    # _batch_set_empty_header_data and _ITEMS_TO_FETCH are used below but are not defined in this
    # listing; presumably they exist in the complete source - confirm.

    def __init__(self, parent):
        """
        Args:
            parent (TabularViewForm)
        """
        super().__init__()
        self._parent = parent
        self.db_mngr = parent.db_mngr
        self.db_map = parent.db_map
        self.model = PivotModel()
        self._plot_x_column = None
        self._data_row_count = 0
        self._data_column_count = 0
        # A model reset invalidates everything fetched so far; lazy loading starts from scratch.
        self.modelReset.connect(self.reset_data_count)

    @Slot()
    def reset_data_count(self):
        """Forgets how many data rows and columns have been fetched."""
        self.layoutAboutToBeChanged.emit()
        self._data_row_count = 0
        self._data_column_count = 0
        self.layoutChanged.emit()

    def canFetchMore(self, parent):
        """Returns True if some pivoted rows or columns have not been exposed to views yet."""
        return self._data_row_count < len(self.model.rows) or self._data_column_count < len(self.model.columns)

    def fetchMore(self, parent):
        """Exposes the next batch of pivoted rows and columns."""
        self.fetch_more_rows(parent)
        self.fetch_more_columns(parent)

    def fetch_more_rows(self, parent):
        """Exposes up to ``_ITEMS_TO_FETCH`` more data rows.

        Args:
            parent (QModelIndex): parent index forwarded to beginInsertRows
        """
        count = min(self._ITEMS_TO_FETCH, len(self.model.rows) - self._data_row_count)
        if count <= 0:
            # fetchMore() calls this even when only columns remain; an empty range
            # (last < first) is not a valid insertion for Qt.
            return
        first = self.headerRowCount() + self.dataRowCount()
        self.beginInsertRows(parent, first, first + count - 1)
        self._data_row_count += count
        self.endInsertRows()

    def fetch_more_columns(self, parent):
        """Exposes up to ``_ITEMS_TO_FETCH`` more data columns.

        Args:
            parent (QModelIndex): parent index forwarded to beginInsertColumns
        """
        count = min(self._ITEMS_TO_FETCH, len(self.model.columns) - self._data_column_count)
        if count <= 0:
            # fetchMore() calls this even when only rows remain; an empty range
            # (last < first) is not a valid insertion for Qt.
            return
        first = self.headerColumnCount() + self.dataColumnCount()
        self.beginInsertColumns(parent, first, first + count - 1)
        self._data_column_count += count
        self.endInsertColumns()

    def reset_model(self, data, index_ids, rows=(), columns=(), frozen=(), frozen_value=()):
        """Resets the underlying pivot model with new data and pivot, and clears the plot X column."""
        self.beginResetModel()
        self.model.reset_model(data, index_ids, rows, columns, frozen, frozen_value)
        self.endResetModel()
        self._plot_x_column = None

    def clear_model(self):
        """Clears the underlying pivot model and the plot X column."""
        self.beginResetModel()
        self.model.clear_model()
        self.endResetModel()
        self._plot_x_column = None

    def update_model(self, data):
        """Forwards data to the pivot model's update_model; does nothing when data is empty."""
        if not data:
            return
        self.model.update_model(data)

    def add_to_model(self, data):
        """Adds data to the pivot model and inserts the reported number of new rows/columns at the end.

        Args:
            data: data accepted by PivotModel.add_to_model
        """
        if not data:
            return
        row_count, column_count = self.model.add_to_model(data)
        if row_count > 0:
            first = self.headerRowCount() + self.dataRowCount()
            self.beginInsertRows(QModelIndex(), first, first + row_count - 1)
            self._data_row_count += row_count
            self.endInsertRows()
        if column_count > 0:
            first = self.headerColumnCount() + self.dataColumnCount()
            self.beginInsertColumns(QModelIndex(), first, first + column_count - 1)
            self._data_column_count += column_count
            self.endInsertColumns()

    def remove_from_model(self, data):
        """Removes data from the pivot model and removes the reported number of rows/columns from the view.

        Args:
            data: data accepted by PivotModel.remove_from_model
        """
        if not data:
            return
        row_count, column_count = self.model.remove_from_model(data)
        # Never remove more than is currently exposed: the pivot model may report removals among
        # rows/columns that were never fetched, which would make the counters negative and
        # request an out-of-range removal from Qt.
        row_count = min(row_count, self._data_row_count)
        column_count = min(column_count, self._data_column_count)
        if row_count > 0:
            first = self.headerRowCount()
            self.beginRemoveRows(QModelIndex(), first, first + row_count - 1)
            self._data_row_count -= row_count
            self.endRemoveRows()
        if column_count > 0:
            first = self.headerColumnCount()
            self.beginRemoveColumns(QModelIndex(), first, first + column_count - 1)
            self._data_column_count -= column_count
            self.endRemoveColumns()

    def set_pivot(self, rows, columns, frozen, frozen_value):
        """Changes the pivot of the underlying model, resetting this model."""
        self.beginResetModel()
        self.model.set_pivot(rows, columns, frozen, frozen_value)
        self.endResetModel()

    def set_frozen_value(self, frozen_value):
        """Changes the frozen value of the underlying model, resetting this model."""
        self.beginResetModel()
        self.model.set_frozen_value(frozen_value)
        self.endResetModel()

    def set_plot_x_column(self, column, is_x):
        """Sets or clears the X flag on a column.

        Args:
            column (int): column index
            is_x (bool): True to make the column the plot X column, False to clear it if it currently is
        """
        if is_x:
            self._plot_x_column = column
        elif column == self._plot_x_column:
            self._plot_x_column = None
        self.headerDataChanged.emit(Qt.Horizontal, column, column)

    @property
    def plot_x_column(self):
        """Returns the index of the column designated as X values for plotting or None."""
        return self._plot_x_column

    def first_data_row(self):
        """Returns the row index of the first data row (i.e. the number of header rows)."""
        return self.headerRowCount()

    def headerColumnCount(self):
        """Returns number of columns occupied by header."""
        # At least one header column is needed to show the pivot column names when there are no pivot rows.
        return max(int(bool(self.model.pivot_columns)), len(self.model.pivot_rows))

    def dataRowCount(self):
        """Returns number of rows that contain actual data."""
        if self.model.pivot_columns and not self.model.pivot_rows:
            return 1
        return self._data_row_count

    def dataColumnCount(self):
        """Returns number of columns that contain actual data."""
        if self.model.pivot_rows and not self.model.pivot_columns:
            return 1
        return self._data_column_count

    def emptyRowCount(self):
        """Returns 1 if a trailing empty row is shown (there are pivot rows), 0 otherwise."""
        return 1 if self.model.pivot_rows else 0

    def emptyColumnCount(self):
        """Returns 1 if a trailing empty column is shown (there are pivot columns), 0 otherwise."""
        return 1 if self.model.pivot_columns else 0

    def rowCount(self, parent=QModelIndex()):
        """Number of rows in table, number of header rows + datarows + 1 empty row"""
        return self.headerRowCount() + self.dataRowCount() + self.emptyRowCount()

    def columnCount(self, parent=QModelIndex()):
        """Number of columns in table, number of header columns + datacolumns + 1 empty columns"""
        return self.headerColumnCount() + self.dataColumnCount() + self.emptyColumnCount()

    def flags(self, index):
        """Returns item flags: the top left corner is disabled, the separator row is read-only,
        everything else is editable."""
        if index.row() < self.headerRowCount() and index.column() < self.headerColumnCount():
            return ~Qt.ItemIsEnabled
        if self.model.pivot_rows and index.row() == len(self.model.pivot_columns):
            # empty line between column headers and data
            return Qt.ItemIsSelectable | Qt.ItemIsEnabled
        return Qt.ItemIsEditable | Qt.ItemIsEnabled | Qt.ItemIsSelectable

    def top_left_indexes(self):
        """Returns indexes in the top left area.

        Returns
            list(QModelIndex): top indexes (horizontal headers, associated to rows)
            list(QModelIndex): left indexes (vertical headers, associated to columns)
        """
        pivot_column_count = len(self.model.pivot_columns)
        pivot_row_count = len(self.model.pivot_rows)
        top_indexes = [self.index(pivot_column_count, column) for column in range(pivot_row_count)]
        left_column = max(pivot_row_count - 1, 0)
        left_indexes = [self.index(row, left_column) for row in range(pivot_column_count)]
        return top_indexes, left_indexes

    def index_in_top(self, index):
        """Returns whether the index is a pivot row name cell in the top left area."""
        return index.row() == len(self.model.pivot_columns) and index.column() < len(self.model.pivot_rows)

    def index_in_left(self, index):
        """Returns whether the index is a pivot column name cell in the top left area."""
        return index.column() == self.headerColumnCount() - 1 and index.row() < len(self.model.pivot_columns)

    def index_in_top_left(self, index):
        """Returns whether or not the given index is in top left corner, where pivot names are displayed"""
        return self.index_in_top(index) or self.index_in_left(index)

    def index_in_column_headers(self, index):
        """Returns whether or not the given index is in column headers (horizontal) area"""
        return (
            index.row() < len(self.model.pivot_columns)
            and self.headerColumnCount() <= index.column() < self.headerColumnCount() + self.dataColumnCount()
        )

    def index_in_empty_column_headers(self, index):
        """Returns whether or not the given index is in the trailing empty column of the column header area"""
        return index.row() < len(self.model.pivot_columns) and index.column() == self.columnCount() - 1

    def index_in_data(self, index):
        """Returns whether or not the given index is in data area"""
        return (
            self.headerRowCount() <= index.row() < self.rowCount() - self.emptyRowCount()
            and self.headerColumnCount() <= index.column() < self.columnCount() - self.emptyColumnCount()
        )

    def map_to_pivot(self, index):
        """Returns a tuple of row and column in the pivot model that corresponds to the given model index.

        Args:
            index (QModelIndex)

        Returns:
            int: row
            int: column
        """
        return index.row() - self.headerRowCount(), index.column() - self.headerColumnCount()

    def _top_left_id(self, index):
        """Returns the id of the top left header corresponding to the given header index.

        Args:
            index (QModelIndex)

        Returns:
            int, NoneType
        """
        if self.index_in_row_headers(index):
            return self.model.pivot_rows[index.column()]
        if self.index_in_column_headers(index):
            return self.model.pivot_columns[index.row()]
        return None

    def value_name(self, index):
        """Returns a string that concatenates the header names corresponding to the given data index.

        Args:
            index (QModelIndex)

        Returns:
            str
        """
        if not self.index_in_data(index):
            return ""
        object_names, parameter_name = self.header_names(index)
        return self.db_mngr._GROUP_SEP.join(object_names) + " - " + parameter_name

    def column_name(self, column):
        """Returns a string that concatenates the header names corresponding to the given column.

        Args:
            column (int)

        Returns:
            str
        """
        header_names = []
        column -= self.headerColumnCount()
        for row, top_left_id in enumerate(self.model.pivot_columns):
            header_id = self.model._column_data_header[column][row]
            header_names.append(self._header_name(top_left_id, header_id))
        return self.db_mngr._GROUP_SEP.join(header_names)

    def _color_data(self, index):
        """Returns the background color for the top left corner, None elsewhere."""
        if index.row() < self.headerRowCount() and index.column() < self.headerColumnCount():
            return QColor(PIVOT_TABLE_HEADER_COLOR)
        return None

    def data(self, index, role=Qt.DisplayRole):
        """Returns data for the given index and role."""
        if role in (Qt.DisplayRole, Qt.EditRole, Qt.ToolTipRole):
            if self.index_in_top(index):
                return self.model.pivot_rows[index.column()]
            if self.index_in_left(index):
                return self.model.pivot_columns[index.row()]
            if self.index_in_headers(index):
                return self.header_data(index, role)
            if self.index_in_data(index):
                row, column = self.map_to_pivot(index)
                data = self.model.get_pivoted_data([row], [column])
                if not data:
                    return None
                if self._parent.is_value_input_type():
                    if data[0][0] is None:
                        return None
                    return self.db_mngr.get_value(self.db_map, "parameter value", data[0][0], "value", role)
                # Relationship input: a cell is "checked" when a relationship id is present.
                return bool(data[0][0])
            return None
        if role == Qt.FontRole and self.index_in_top_left(index):
            font = QFont()
            font.setBold(True)
            return font
        if role == Qt.BackgroundColorRole:
            return self._color_data(index)
        if role == Qt.TextAlignmentRole and self.index_in_data(index) and not self._parent.is_value_input_type():
            return Qt.AlignHCenter
        return None

    def setData(self, index, value, role=Qt.EditRole):
        """Sets a single value; only EditRole is supported."""
        if role != Qt.EditRole:
            return False
        return self.batch_set_data([index], [value])

    def batch_set_data(self, indexes, values):
        """Sets many values at once, dispatching each index to the handler for its table area.

        Args:
            indexes (list of QModelIndex)
            values (list)

        Returns:
            bool: True if any handler reported a change
        """
        inner_data = []
        header_data = []
        empty_row_header_data = []
        empty_column_header_data = []
        for index, value in zip(indexes, values):
            if self.index_in_data(index):
                inner_data.append((index, value))
            elif self.index_in_headers(index):
                header_data.append((index, value))
            elif self.index_in_empty_row_headers(index):
                empty_row_header_data.append((index, value))
            elif self.index_in_empty_column_headers(index):
                empty_column_header_data.append((index, value))
        result = self._batch_set_inner_data(inner_data)
        result |= self._batch_set_header_data(header_data)
        result |= self._batch_set_empty_header_data(empty_row_header_data, lambda i: self.model.pivot_rows[i.column()])
        result |= self._batch_set_empty_header_data(
            empty_column_header_data, lambda i: self.model.pivot_columns[i.row()]
        )
        return result

    def _batch_set_inner_data(self, inner_data):
        """Sets values in the data area.

        Args:
            inner_data (list of tuple): (QModelIndex, value) pairs, all inside the data area

        Returns:
            bool: True if anything was changed
        """
        if not inner_data:
            return False
        row_map = set()
        column_map = set()
        values = {}
        for index, value in inner_data:
            row, column = self.map_to_pivot(index)
            row_map.add(row)
            column_map.add(column)
            values[row, column] = value
        row_map = list(row_map)
        column_map = list(column_map)
        data = self.model.get_pivoted_data(row_map, column_map)
        if not data:
            return False
        if self._parent.is_value_input_type():
            return self._batch_set_parameter_value_data(row_map, column_map, data, values)
        return self._batch_set_relationship_data(row_map, column_map, data, values)

    def _batch_set_parameter_value_data(self, row_map, column_map, data, values):
        """Adds or updates parameter values for edited data cells.

        Args:
            row_map (list of int): pivot rows covered by ``data``
            column_map (list of int): pivot columns covered by ``data``
            data (list of list): pivoted parameter value ids, None where no value exists yet
            values (dict): new values keyed by (pivot row, pivot column)

        Returns:
            bool: True if any values were passed on for addition or update
        """

        def object_parameter_value_to_add(header_ids, value, _):
            return dict(
                entity_class_id=self._parent.current_class_id,
                entity_id=header_ids[0],
                parameter_definition_id=header_ids[-1],
                value=value,
            )

        def relationship_parameter_value_to_add(header_ids, value, relationship_ids):
            object_id_list = ",".join([str(id_) for id_ in header_ids[:-1]])
            relationship_id = relationship_ids.get(object_id_list)
            if relationship_id is None:
                # No relationship exists for this object combination; there is nothing to attach the value to.
                return None
            return dict(
                entity_class_id=self._parent.current_class_id,
                entity_id=relationship_id,
                parameter_definition_id=header_ids[-1],
                value=value,
            )

        def no_parameter_value_to_add(*_):
            # Unknown class type: new values cannot be created, but existing ones may still be updated.
            return None

        if self._parent.current_class_type == "object class":
            relationship_ids = {}
            parameter_value_to_add = object_parameter_value_to_add
        elif self._parent.current_class_type == "relationship class":
            relationships = self.db_mngr.get_items_by_field(
                self.db_map, "relationship", "class_id", self._parent.current_class_id
            )
            relationship_ids = {x["object_id_list"]: x["id"] for x in relationships}
            parameter_value_to_add = relationship_parameter_value_to_add
        else:
            relationship_ids = {}
            parameter_value_to_add = no_parameter_value_to_add
        to_add = []
        to_update = []
        for i, row in enumerate(row_map):
            for j, column in enumerate(column_map):
                # row_map x column_map is a cross product; skip cells that were not edited.
                if (row, column) not in values:
                    continue
                header_ids = self._header_ids(row, column)
                value = values[row, column]
                if data[i][j] is None:
                    item = parameter_value_to_add(header_ids, value, relationship_ids)
                    if item is not None:
                        to_add.append(item)
                else:
                    to_update.append(dict(id=data[i][j], value=value, parameter_definition_id=header_ids[-1]))
        if not to_add and not to_update:
            return False
        if to_add:
            self._add_parameter_values(to_add)
        if to_update:
            self._update_parameter_values(to_update)
        return True

    def _checked_parameter_values(self, items):
        """Drops items whose value is not in their parameter definition's value list, if it has one.

        Args:
            items (list of dict): items with "parameter_definition_id" and "value" keys

        Returns:
            list of dict: the items that pass the check
        """
        value_lists = {}
        par_def_ids = {item["parameter_definition_id"] for item in items}
        for par_def_id in par_def_ids:
            param_val_list_id = self.db_mngr.get_item(self.db_map, "parameter definition", par_def_id).get(
                "parameter_value_list_id"
            )
            if not param_val_list_id:
                continue
            param_val_list = self.db_mngr.get_item(self.db_map, "parameter value list", param_val_list_id)
            value_list = param_val_list.get("value_list")
            if not value_list:
                continue
            # Value lists are read as comma-separated strings here.
            value_lists[par_def_id] = value_list.split(",")
        checked_items = []
        for item in items:
            value_list = value_lists.get(item["parameter_definition_id"])
            if value_list and item["value"] not in value_list:
                continue
            checked_items.append(item)
        return checked_items

    def _add_parameter_values(self, items):
        """Adds the items that pass the value list check to the database manager."""
        items = self._checked_parameter_values(items)
        self.db_mngr.add_checked_parameter_values({self.db_map: items})

    def _update_parameter_values(self, items):
        """Updates the items that pass the value list check in the database manager."""
        items = self._checked_parameter_values(items)
        self.db_mngr.update_checked_parameter_values({self.db_map: items})

    def _batch_set_relationship_data(self, row_map, column_map, data, values):
        """Adds relationships for cells set truthy and removes them for cells set falsy.

        Args:
            row_map (list of int): pivot rows covered by ``data``
            column_map (list of int): pivot columns covered by ``data``
            data (list of list): pivoted relationship ids, None where no relationship exists
            values (dict): new values keyed by (pivot row, pivot column)

        Returns:
            bool: True if any relationships were added or removed
        """

        def relationship_to_add(header_ids):
            rel_cls_name = self.db_mngr.get_item(self.db_map, "relationship class", self._parent.current_class_id)[
                "name"
            ]
            object_names = [self.db_mngr.get_item(self.db_map, "object", id_)["name"] for id_ in header_ids]
            name = rel_cls_name + "_" + "__".join(object_names)
            return dict(object_id_list=list(header_ids), class_id=self._parent.current_class_id, name=name)

        to_add = []
        to_remove = []
        for i, row in enumerate(row_map):
            for j, column in enumerate(column_map):
                # row_map x column_map is a cross product; skip cells that were not edited
                # (previously this raised KeyError).
                if (row, column) not in values:
                    continue
                value = values[row, column]
                if data[i][j] is None and value:
                    header_ids = self._header_ids(row, column)
                    to_add.append(relationship_to_add(header_ids))
                elif data[i][j] is not None and not value:
                    to_remove.append(self.db_mngr.get_item(self.db_map, "relationship", data[i][j]))
        if not to_add and not to_remove:
            return False
        if to_add:
            self.db_mngr.add_relationships({self.db_map: to_add})
        if to_remove:
            self.db_mngr.remove_items({self.db_map: {"relationship": to_remove}})
        return True
class PivotTableSortFilterProxy(QSortFilterProxyModel):
    """Proxy that hides pivot table data rows and columns rejected by per-index value filters."""

    def __init__(self, parent=None):
        """Initialize class.

        Args:
            parent (QObject, optional): parent object
        """
        super().__init__(parent)
        # Dynamic filtering is switched off: it is important so we can edit parameters in the view.
        self.setDynamicSortFilter(False)
        # Maps index identifier -> set of accepted values (None means everything passes).
        self.index_filters = {}

    def set_filter(self, identifier, filter_value):
        """Sets filter for a given index (object class) name.

        Args:
            identifier (int): index identifier
            filter_value (set, None): A set of accepted values, or None if no filter (all pass)
        """
        self.index_filters[identifier] = filter_value
        self.invalidateFilter()

    def clear_filter(self):
        """Removes all filters and re-evaluates the proxy."""
        self.index_filters = {}
        self.invalidateFilter()

    def accept_index(self, index, index_ids):
        """Returns True when every element of ``index`` passes the filter registered for its identifier.

        Args:
            index (Sequence): header values of one row or column
            index_ids (Sequence): identifiers paired element-wise with ``index``

        Returns:
            bool
        """

        def passes(item, identifier):
            accepted = self.index_filters.get(identifier)
            return accepted is None or item in accepted

        return all(passes(item, identifier) for item, identifier in zip(index, index_ids))

    def filterAcceptsRow(self, source_row, source_parent):
        """Returns true if the item in the row indicated by the given source_row
        and source_parent should be included in the model; otherwise returns false.
        """
        source = self.sourceModel()
        header_rows = source.headerRowCount()
        # Header rows and the trailing row are always shown.
        if source_row < header_rows or source_row == source.rowCount() - 1:
            return True
        pivot_rows = source.model.pivot_rows
        if not pivot_rows:
            return True
        row_header = source.model._row_data_header[source_row - header_rows]
        return self.accept_index(row_header, pivot_rows)

    def filterAcceptsColumn(self, source_column, source_parent):
        """Returns true if the item in the column indicated by the given source_column
        and source_parent should be included in the model; otherwise returns false.
        """
        source = self.sourceModel()
        header_columns = source.headerColumnCount()
        # Header columns and the trailing column are always shown.
        if source_column < header_columns or source_column == source.columnCount() - 1:
            return True
        pivot_columns = source.model.pivot_columns
        if not pivot_columns:
            return True
        column_header = source.model._column_data_header[source_column - header_columns]
        return self.accept_index(column_header, pivot_columns)

    def batch_set_data(self, indexes, values):
        """Maps proxy indexes to the source model and forwards the batch edit to it."""
        source_indexes = [self.mapToSource(proxy_index) for proxy_index in indexes]
        return self.sourceModel().batch_set_data(source_indexes, values)