######################################################################################################################
# Copyright (C) 2017-2020 Spine project consortium
# This file is part of Spine Toolbox.
# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
# any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""
Provides pivot table models for the Tabular View.
:author: P. Vennström (VTT)
:date: 1.11.2018
"""
from PySide2.QtCore import Qt, Slot, QTimer, QAbstractTableModel, QModelIndex, QSortFilterProxyModel
from PySide2.QtGui import QColor, QFont
from .pivot_model import PivotModel
from ...mvcmodels.shared import PARSED_ROLE
from ...config import PIVOT_TABLE_HEADER_COLOR
class PivotTableModelBase(QAbstractTableModel):
    """Base Qt table model that exposes a PivotModel as header cells, data cells and trailing empty cells.

    Data rows and columns are revealed in batches: every row/column insertion and
    every model reset schedules another fetch through a single-shot timer.

    NOTE(review): the following members are used below but are not defined in the
    visible part of this class -- confirm they are provided elsewhere:
    ``_FETCH_DELAY``, ``_FETCH_STEP_COUNT``, ``_MIN_FETCH_COUNT``, ``headerRowCount``,
    ``index_in_headers``, ``index_in_row_headers``, ``index_in_empty_row_headers``,
    ``_header_data``, ``_batch_set_header_data`` and ``_batch_set_empty_header_data``.
    """

    def __init__(self, parent):
        """
        Args:
            parent (SpineDBEditor)
        """
        super().__init__()
        self._parent = parent
        self.db_mngr = parent.db_mngr
        self.db_map = parent.pivot_db_map
        self.model = PivotModel()
        self.top_left_headers = {}
        self._plot_x_column = None
        # Number of data rows/columns currently exposed to the views.
        self._data_row_count = 0
        self._data_column_count = 0
        # Each insertion schedules the next batch, so fetching continues until everything is exposed.
        self.rowsInserted.connect(lambda *args: QTimer.singleShot(self._FETCH_DELAY, self.fetch_more_rows))
        self.columnsInserted.connect(lambda *args: QTimer.singleShot(self._FETCH_DELAY, self.fetch_more_columns))
        self.modelAboutToBeReset.connect(self.reset_data_count)
        self.modelReset.connect(lambda *args: QTimer.singleShot(self._FETCH_DELAY, self.start_fetching))

    @property
    def item_type(self):
        """Returns the item type."""
        raise NotImplementedError()

    @Slot()
    def reset_data_count(self):
        """Sets the exposed data row and column counts back to zero."""
        self._data_row_count = 0
        self._data_column_count = 0

    @Slot()
    def start_fetching(self):
        """Fetches the first batch of rows and the first batch of columns."""
        self.fetch_more_rows()
        self.fetch_more_columns()

    @Slot()
    def fetch_more_rows(self):
        """Exposes the next batch of data rows, if any remain."""
        total = len(self.model.rows)
        max_count = max(self._MIN_FETCH_COUNT, total // self._FETCH_STEP_COUNT + 1)
        count = min(max_count, total - self._data_row_count)
        if count <= 0:
            return
        first = self.headerRowCount() + self.dataRowCount()
        self.beginInsertRows(QModelIndex(), first, first + count - 1)
        self._data_row_count += count
        self.endInsertRows()

    @Slot()
    def fetch_more_columns(self):
        """Exposes the next batch of data columns, if any remain."""
        total = len(self.model.columns)
        # The batch size is derived from the number of columns (not rows), mirroring fetch_more_rows.
        max_count = max(self._MIN_FETCH_COUNT, total // self._FETCH_STEP_COUNT + 1)
        count = min(max_count, total - self._data_column_count)
        if count <= 0:
            return
        first = self.headerColumnCount() + self.dataColumnCount()
        self.beginInsertColumns(QModelIndex(), first, first + count - 1)
        self._data_column_count += count
        self.endInsertColumns()

    def call_reset_model(self, object_class_ids, pivot=None):
        """Resets the model; must be implemented by subclasses.

        Args:
            object_class_ids (dict): mapping disambiguated class names to ids
            pivot (tuple, optional): list of rows, list of columns, list of frozen indexes, frozen value
        """
        raise NotImplementedError()

    def reset_model(self, data, index_ids, rows=(), columns=(), frozen=(), frozen_value=()):
        """Resets the underlying pivot model with the given data and pivot, and clears the plot X column."""
        self.beginResetModel()
        self.model.reset_model(data, index_ids, rows, columns, frozen, frozen_value)
        self.endResetModel()
        self._plot_x_column = None

    def clear_model(self):
        """Clears the underlying pivot model and the plot X column."""
        self.beginResetModel()
        self.model.clear_model()
        self.endResetModel()
        self._plot_x_column = None

    def update_model(self, data):
        """Update model with new data, but doesn't grow the model.

        Args:
            data (dict)
        """
        if not data:
            return
        self.model.update_model(data)
        top_left = self.index(self.headerRowCount(), self.headerColumnCount())
        # The last valid cell is at (rowCount - 1, columnCount - 1); one past it is not a valid index.
        bottom_right = self.index(self.rowCount() - 1, self.columnCount() - 1)
        self.dataChanged.emit(top_left, bottom_right)

    def add_to_model(self, data):
        """Adds data to the pivot model and appends the resulting new rows and columns after the exposed ones."""
        if not data:
            return
        row_count, column_count = self.model.add_to_model(data)
        if row_count > 0:
            first = self.headerRowCount() + self.dataRowCount()
            self.beginInsertRows(QModelIndex(), first, first + row_count - 1)
            self._data_row_count += row_count
            self.endInsertRows()
        if column_count > 0:
            first = self.headerColumnCount() + self.dataColumnCount()
            self.beginInsertColumns(QModelIndex(), first, first + column_count - 1)
            self._data_column_count += column_count
            self.endInsertColumns()

    def remove_from_model(self, data):
        """Removes data from the pivot model and shrinks the exposed area by the removed row and column counts.

        The removed ranges are reported to the views as starting at the first data row/column.
        """
        if not data:
            return
        row_count, column_count = self.model.remove_from_model(data)
        if row_count > 0:
            first = self.headerRowCount()
            self.beginRemoveRows(QModelIndex(), first, first + row_count - 1)
            self._data_row_count -= row_count
            self.endRemoveRows()
        if column_count > 0:
            first = self.headerColumnCount()
            self.beginRemoveColumns(QModelIndex(), first, first + column_count - 1)
            self._data_column_count -= column_count
            self.endRemoveColumns()

    def set_pivot(self, rows, columns, frozen, frozen_value):
        """Applies a new pivot to the underlying model within a model reset."""
        self.beginResetModel()
        self.model.set_pivot(rows, columns, frozen, frozen_value)
        self.endResetModel()

    def set_frozen_value(self, frozen_value):
        """Applies a new frozen value to the underlying model within a model reset."""
        self.beginResetModel()
        self.model.set_frozen_value(frozen_value)
        self.endResetModel()

    def set_plot_x_column(self, column, is_x):
        """Sets or clears the X flag on a column"""
        if is_x:
            self._plot_x_column = column
        elif column == self._plot_x_column:
            self._plot_x_column = None
        self.headerDataChanged.emit(Qt.Horizontal, column, column)

    @property
    def plot_x_column(self):
        """Returns the index of the column designated as X values for plotting or None."""
        return self._plot_x_column

    def headerColumnCount(self):
        """Returns number of columns occupied by header."""
        return max(bool(self.model.pivot_columns), len(self.model.pivot_rows))

    def dataRowCount(self):
        """Returns number of rows that contain actual data."""
        if self.model.pivot_columns and not self.model.pivot_rows:
            return 1
        return self._data_row_count

    def dataColumnCount(self):
        """Returns number of columns that contain actual data."""
        if self.model.pivot_rows and not self.model.pivot_columns:
            return 1
        return self._data_column_count

    def emptyRowCount(self):
        """Returns 1 if there are pivot rows, 0 otherwise."""
        return 1 if self.model.pivot_rows else 0

    def emptyColumnCount(self):
        """Returns 1 if there are pivot columns, 0 otherwise."""
        return 1 if self.model.pivot_columns else 0

    def rowCount(self, parent=QModelIndex()):
        """Number of rows in table, number of header rows + datarows + 1 empty row"""
        return self.headerRowCount() + self.dataRowCount() + self.emptyRowCount()

    def columnCount(self, parent=QModelIndex()):
        """Number of columns in table, number of header columns + datacolumns + 1 empty columns"""
        return self.headerColumnCount() + self.dataColumnCount() + self.emptyColumnCount()

    def flags(self, index):
        """Returns the item flags for the given index."""
        if index.row() < self.headerRowCount() and index.column() < self.headerColumnCount():
            return ~Qt.ItemIsEnabled
        if self.model.pivot_rows and index.row() == len(self.model.pivot_columns):
            # empty line between column headers and data
            return Qt.ItemIsSelectable | Qt.ItemIsEnabled
        return Qt.ItemIsEditable | Qt.ItemIsEnabled | Qt.ItemIsSelectable

    def top_left_indexes(self):
        """Returns indexes in the top left area.

        Returns
            list(QModelIndex): top indexes (horizontal headers, associated to rows)
            list(QModelIndex): left indexes (vertical headers, associated to columns)
        """
        pivot_column_count = len(self.model.pivot_columns)
        pivot_row_count = len(self.model.pivot_rows)
        top_indexes = [self.index(pivot_column_count, column) for column in range(pivot_row_count)]
        left_column = max(pivot_row_count - 1, 0)
        left_indexes = [self.index(row, left_column) for row in range(pivot_column_count)]
        return top_indexes, left_indexes

    def index_in_top(self, index):
        """Returns whether or not the given index is in the row just below the pivot columns, under a pivot row."""
        return index.row() == len(self.model.pivot_columns) and index.column() < len(self.model.pivot_rows)

    def index_in_left(self, index):
        """Returns whether or not the given index is in the last header column, beside a pivot column."""
        return index.column() == self.headerColumnCount() - 1 and index.row() < len(self.model.pivot_columns)

    def index_in_top_left(self, index):
        """Returns whether or not the given index is in top left corner, where pivot names are displayed"""
        return self.index_in_top(index) or self.index_in_left(index)

    def index_in_column_headers(self, index):
        """Returns whether or not the given index is in column headers (horizontal) area"""
        return (
            index.row() < len(self.model.pivot_columns)
            and self.headerColumnCount() <= index.column() < self.headerColumnCount() + self.dataColumnCount()
        )

    def index_in_empty_column_headers(self, index):
        """Returns whether or not the given index is in empty column headers (vertical) area"""
        return index.row() < len(self.model.pivot_columns) and index.column() == self.columnCount() - 1

    def index_in_data(self, index):
        """Returns whether or not the given index is in data area"""
        return (
            self.headerRowCount() <= index.row() < self.rowCount() - self.emptyRowCount()
            and self.headerColumnCount() <= index.column() < self.columnCount() - self.emptyColumnCount()
        )

    def column_is_index_column(self, column):  # pylint: disable=no-self-use
        """Returns True if column is the column containing expanded parameter_value indexes."""
        return False

    def map_to_pivot(self, index):
        """Returns a tuple of row and column in the pivot model that corresponds to the given model index.

        Args:
            index (QModelIndex)

        Returns:
            int: row
            int: column
        """
        return index.row() - self.headerRowCount(), index.column() - self.headerColumnCount()

    def _top_left_id(self, index):
        """Returns the id of the top left header corresponding to the given header index.

        Args:
            index (QModelIndex)

        Returns:
            int, NoneType
        """
        if self.index_in_row_headers(index):
            return self.model.pivot_rows[index.column()]
        if self.index_in_column_headers(index):
            return self.model.pivot_columns[index.row()]
        return None

    def _color_data(self, index):
        """Returns the header background color for cells in the top left header area, None for others."""
        if index.row() < self.headerRowCount() and index.column() < self.headerColumnCount():
            return QColor(PIVOT_TABLE_HEADER_COLOR)
        return None

    def _text_alignment_data(self, index):  # pylint: disable=no-self-use
        """Returns the text alignment for the given index; None in this base class."""
        return None

    def _data(self, index, role):
        """Returns the data for an index in the data area; must be implemented by subclasses."""
        raise NotImplementedError()

    def data(self, index, role=Qt.DisplayRole):
        """Returns the data stored under the given role for the given index."""
        if role in (Qt.DisplayRole, Qt.EditRole, Qt.ToolTipRole, PARSED_ROLE):
            if self.index_in_top(index):
                return self.model.pivot_rows[index.column()]
            if self.index_in_left(index):
                return self.model.pivot_columns[index.row()]
            if self.index_in_headers(index):
                return self._header_data(index, role)
            if self.index_in_data(index):
                return self._data(index, role)
            return None
        if role == Qt.FontRole and self.index_in_top_left(index):
            font = QFont()
            font.setBold(True)
            return font
        if role == Qt.BackgroundColorRole:
            return self._color_data(index)
        if role == Qt.TextAlignmentRole:
            return self._text_alignment_data(index)
        return None

    def setData(self, index, value, role=Qt.EditRole):
        """Sets a single value by delegating to batch_set_data; only Qt.EditRole is accepted."""
        if role != Qt.EditRole:
            return False
        return self.batch_set_data([index], [value])

    def batch_set_data(self, indexes, values):
        """Sets several values at once, routing each index to the handler for the area it belongs to.

        Args:
            indexes (list(QModelIndex))
            values (list): one value per index

        Returns:
            the combined (OR-ed) result of the individual handlers
        """
        inner_data = []
        header_data = []
        empty_row_header_data = []
        empty_column_header_data = []
        for index, value in zip(indexes, values):
            if self.index_in_data(index):
                inner_data.append((index, value))
            elif self.index_in_headers(index):
                header_data.append((index, value))
            elif self.index_in_empty_row_headers(index):
                empty_row_header_data.append((index, value))
            elif self.index_in_empty_column_headers(index):
                empty_column_header_data.append((index, value))
        result = self._batch_set_inner_data(inner_data)
        result |= self._batch_set_header_data(header_data)
        result |= self._batch_set_empty_header_data(empty_row_header_data, lambda i: self.model.pivot_rows[i.column()])
        result |= self._batch_set_empty_header_data(
            empty_column_header_data, lambda i: self.model.pivot_columns[i.row()]
        )
        return result

    def _batch_set_inner_data(self, inner_data):
        """Sets values in the data area.

        Args:
            inner_data (list(tuple)): pairs of QModelIndex and value

        Returns:
            False if the pivot model returns no data for the affected rows and columns,
            otherwise the result of ``_do_batch_set_inner_data``
        """
        row_map = set()
        column_map = set()
        values = {}
        for index, value in inner_data:
            row, column = self.map_to_pivot(index)
            row_map.add(row)
            column_map.add(column)
            values[row, column] = value
        row_map = list(row_map)
        column_map = list(column_map)
        data = self.model.get_pivoted_data(row_map, column_map)
        if not data:
            return False
        return self._do_batch_set_inner_data(row_map, column_map, data, values)

    def _do_batch_set_inner_data(self, row_map, column_map, data, values):
        """Performs the actual update of the data area; must be implemented by subclasses."""
        raise NotImplementedError()

    def receive_data_added_or_removed(self, data, action):
        """Calls add_to_model or remove_from_model depending on action ("add" or "remove")."""
        {"add": self.add_to_model, "remove": self.remove_from_model}[action](data)
class ParameterValuePivotTableModel(PivotTableModelBase):
    """A model for the pivot table in parameter_value input type."""

    def __init__(self, parent):
        """
        Args:
            parent (SpineDBEditor)
        """
        super().__init__(parent)
        self._object_class_count = None

    @property
    def item_type(self):
        """Returns the item type, "parameter_value"."""
        return "parameter_value"

    def object_and_parameter_ids(self, index):
        """Returns the object, parameter and alternative ids corresponding to the given data index.
        Used by PivotTableView.

        Args:
            index (QModelIndex)

        Returns:
            list(int): object ids
            int: parameter id
            int: alternative id
        """
        pivot_row, pivot_column = self.map_to_pivot(index)
        ids = self._header_ids(pivot_row, pivot_column)
        object_ids = ids[: self._object_class_count]
        return object_ids, ids[-2], ids[-1]

    def object_and_parameter_names(self, index):
        """Returns the object, parameter and alternative names corresponding to the given data index.
        Used by PivotTableView.

        Args:
            index (QModelIndex)

        Returns:
            list(str): object names
            str: parameter name
            str: alternative name
        """
        object_ids, parameter_id, alternative_id = self.object_and_parameter_ids(index)
        object_names = [self.db_mngr.get_item(self.db_map, "object", object_id)["name"] for object_id in object_ids]
        parameter = self.db_mngr.get_item(self.db_map, "parameter_definition", parameter_id)
        parameter_name = parameter.get("parameter_name", "")
        alternative = self.db_mngr.get_item(self.db_map, "alternative", alternative_id)
        alternative_name = alternative.get("name", "")
        return object_names, parameter_name, alternative_name

    def index_name(self, index):
        """Returns a string that concatenates the object, parameter and alternative names for the given data index.
        Used by plotting and ParameterValueEditor.

        Args:
            index (QModelIndex)

        Returns:
            str: empty if the index is not in the data area
        """
        if not self.index_in_data(index):
            return ""
        object_names, parameter_name, alternative_name = self.object_and_parameter_names(index)
        joined_objects = self.db_mngr._GROUP_SEP.join(object_names)
        return " - ".join((joined_objects, parameter_name, alternative_name))

    def column_name(self, column):
        """Returns a string that concatenates the header names corresponding to the given column.
        Used by plotting.

        Args:
            column (int)

        Returns:
            str
        """
        pivot_column = column - self.headerColumnCount()
        names = [
            self._header_name(top_left_id, self.model._column_data_header[pivot_column][row])
            for row, top_left_id in enumerate(self.model.pivot_columns)
        ]
        return self.db_mngr._GROUP_SEP.join(names)

    def call_reset_model(self, object_class_ids, pivot=None):
        """See base class."""
        self._object_class_count = len(object_class_ids)
        data = self._parent.load_parameter_value_data()
        # One header per object class, followed by the parameter and alternative headers.
        headers = [TopLeftObjectHeaderItem(self, name, id_) for name, id_ in object_class_ids.items()]
        headers.append(TopLeftParameterHeaderItem(self))
        headers.append(TopLeftAlternativeHeaderItem(self))
        self.top_left_headers = {header.name: header for header in headers}
        if pivot is None:
            pivot = self._default_pivot()
        super().reset_model(data, list(self.top_left_headers), *pivot)

    def _default_pivot(self):
        """Returns the default pivot: the last header as the only column, all others as rows, nothing frozen."""
        header_names = list(self.top_left_headers)
        return header_names[:-1], [header_names[-1]], [], ()

    def _data(self, index, role):
        """Returns the parameter value under the given data index for the given role, or None if there is none."""
        pivot_row, pivot_column = self.map_to_pivot(index)
        data = self.model.get_pivoted_data([pivot_row], [pivot_column])
        if not data or data[0][0] is None:
            return None
        return self.db_mngr.get_value(self.db_map, "parameter_value", data[0][0], role)

    def _do_batch_set_inner_data(self, row_map, column_map, data, values):
        """See base class."""
        return self._batch_set_parameter_value_data(row_map, column_map, data, values)

    def _object_parameter_value_to_add(self, header_ids, value):
        """Returns a new parameter value item for an object; the entity id is the first header id."""
        return {
            "entity_class_id": self._parent.current_class_id,
            "entity_id": header_ids[0],
            "parameter_definition_id": header_ids[-2],
            "value": value,
            "alternative_id": header_ids[-1],
        }

    def _relationship_parameter_value_to_add(self, header_ids, value, rel_id_lookup):
        """Returns a new parameter value item for a relationship.

        The relationship id is looked up from the comma-joined object ids (all header ids but the last two).
        """
        object_id_list = ",".join(str(id_) for id_ in header_ids[:-2])
        return {
            "entity_class_id": self._parent.current_class_id,
            "entity_id": rel_id_lookup[object_id_list],
            "parameter_definition_id": header_ids[-2],
            "value": value,
            "alternative_id": header_ids[-1],
        }

    def _make_parameter_value_to_add(self):
        """Returns a callable taking header ids and a value and returning the item to add.

        Returns None if the current class is neither an object class nor a relationship class.
        """
        class_type = self._parent.current_class_type
        if class_type == "object_class":
            return self._object_parameter_value_to_add
        if class_type == "relationship_class":
            relationships = self.db_mngr.get_items_by_field(
                self.db_map, "relationship", "class_id", self._parent.current_class_id
            )
            rel_id_lookup = {rel["object_id_list"]: rel["id"] for rel in relationships}
            return lambda header_ids, value: self._relationship_parameter_value_to_add(
                header_ids, value, rel_id_lookup
            )
        return None

    @staticmethod
    def _parameter_value_to_update(id_, header_ids, value):
        """Returns the item that updates the parameter value with the given id."""
        return {
            "id": id_,
            "value": value,
            "parameter_definition_id": header_ids[-2],
            "alternative_id": header_ids[-1],
        }

    def _batch_set_parameter_value_data(self, row_map, column_map, data, values):
        """Sets parameter values in batch.

        Cells without an existing value are added, the others are updated.

        Returns:
            bool: True if anything was added or updated, False otherwise
        """
        items_to_add = []
        items_to_update = []
        make_item_to_add = self._make_parameter_value_to_add()
        for i, row in enumerate(row_map):
            for j, column in enumerate(column_map):
                key = (row, column)
                if key not in values:
                    continue
                header_ids = self._header_ids(row, column)
                existing_id = data[i][j]
                if existing_id is None:
                    items_to_add.append(make_item_to_add(header_ids, values[key]))
                else:
                    items_to_update.append(self._parameter_value_to_update(existing_id, header_ids, values[key]))
        if items_to_add:
            self._add_parameter_values(items_to_add)
        if items_to_update:
            self._update_parameter_values(items_to_update)
        return bool(items_to_add or items_to_update)

    def _checked_parameter_values(self, items):
        """Returns the given items minus those whose value is not in their parameter's value list.

        Items whose parameter definition has no (or an empty) value list are always kept.
        """
        value_lists = {}
        for par_def_id in {item["parameter_definition_id"] for item in items}:
            definition = self.db_mngr.get_item(self.db_map, "parameter_definition", par_def_id)
            list_id = definition.get("parameter_value_list_id")
            if not list_id:
                continue
            value_list = self.db_mngr.get_item(self.db_map, "parameter_value_list", list_id).get("value_list")
            if value_list:
                value_lists[par_def_id] = value_list.split(";")

        def is_accepted(item):
            allowed = value_lists.get(item["parameter_definition_id"])
            return not allowed or item["value"] in allowed

        return [item for item in items if is_accepted(item)]

    def _add_parameter_values(self, items):
        """Adds the given items that pass the value list check."""
        self.db_mngr.add_checked_parameter_values({self.db_map: self._checked_parameter_values(items)})

    def _update_parameter_values(self, items):
        """Updates the given items that pass the value list check."""
        self.db_mngr.update_checked_parameter_values({self.db_map: self._checked_parameter_values(items)})

    def get_set_data_delayed(self, index):
        """Returns a function that ParameterValueEditor can call to set data for the given index at any later time,
        even if the model changes.

        Args:
            index (QModelIndex)

        Returns:
            function
        """
        pivot_row, pivot_column = self.map_to_pivot(index)
        data = self.model.get_pivoted_data([pivot_row], [pivot_column])
        header_ids = self._header_ids(pivot_row, pivot_column)
        existing_id = data[0][0]
        if existing_id is None:
            # Nothing stored under the index yet: the returned function adds a new value.
            make_item_to_add = self._make_parameter_value_to_add()
            return lambda value: self._add_parameter_values([make_item_to_add(header_ids, value)])
        return lambda value: self._update_parameter_values(
            [self._parameter_value_to_update(existing_id, header_ids, value)]
        )

    def receive_objects_added_or_removed(self, items, action):
        """Adds or removes empty data for the given objects if they belong to the current object class.

        Returns:
            bool: True if the model was touched, False otherwise
        """
        parent = self._parent
        if parent.current_class_type != "object_class":
            return False
        objects = [item for item in items if item["class_id"] == parent.current_class_id]
        if not objects:
            return False
        data = parent.load_empty_parameter_value_data(entities=objects)
        self.receive_data_added_or_removed(data, action)
        return True

    def receive_relationships_added_or_removed(self, relationships, action):
        """Adds or removes empty data for the given relationships. Returns True."""
        data = self._parent.load_empty_parameter_value_data(entities=relationships)
        self.receive_data_added_or_removed(data, action)
        return True

    def receive_parameter_definitions_added_or_removed(self, parameters, action):
        """Adds or removes empty data for the given parameter definitions. Returns True."""
        parameter_ids = {parameter["id"] for parameter in parameters}
        data = self._parent.load_empty_parameter_value_data(parameter_ids=parameter_ids)
        self.receive_data_added_or_removed(data, action)
        return True

    def receive_parameter_values_added_or_removed(self, parameter_values, action):
        """Updates the model with the data loaded for the given parameter values and action. Returns True."""
        data = self._parent.load_full_parameter_value_data(parameter_values=parameter_values, action=action)
        self.update_model(data)
        return True
class IndexExpansionPivotTableModel(ParameterValuePivotTableModel):
    """A model for the pivot table in parameter index expansion input type."""

    def __init__(self, parent):
        """
        Args:
            parent (SpineDBEditor)
        """
        super().__init__(parent)
        self._index_top_left_header = None

    def call_reset_model(self, object_class_ids, pivot=None):
        """See base class."""
        self._object_class_count = len(object_class_ids)
        data = self._parent.load_expanded_parameter_value_data()
        top_left_headers = [TopLeftObjectHeaderItem(self, name, id_) for name, id_ in object_class_ids.items()]
        self._index_top_left_header = TopLeftParameterIndexHeaderItem(self)
        # Header order: object classes, parameter index, parameter, alternative.
        top_left_headers += [
            self._index_top_left_header,
            TopLeftParameterHeaderItem(self),
            TopLeftAlternativeHeaderItem(self),
        ]
        self.top_left_headers = {h.name: h for h in top_left_headers}
        if pivot is None:
            pivot = self._default_pivot()
        super().reset_model(data, list(self.top_left_headers), *pivot)
        pivot_rows = pivot[0]
        try:
            x_column = pivot_rows.index(self._index_top_left_header.name)
        except ValueError:
            # The parameter index is not among the pivot rows, so it does not occupy a header column.
            return
        self.set_plot_x_column(x_column, is_x=True)

    def flags(self, index):
        """Returns the item flags; empty data cells and parameter index headers are not editable."""
        if self.index_in_data(index):
            row, column = self.map_to_pivot(index)
            data = self.model.get_pivoted_data([row], [column])
            if not data or data[0][0] is None:
                # Don't add parameter values in index expansion mode
                return Qt.ItemIsSelectable | Qt.ItemIsEnabled
        if self._top_left_id(index) == self._index_top_left_header.name:
            return Qt.ItemIsSelectable | Qt.ItemIsEnabled
        return super().flags(index)

    def column_is_index_column(self, column):
        """Returns True if column is the column containing expanded parameter_value indexes."""
        try:
            index_column = self.model.pivot_rows.index(self._index_top_left_header.name)
        except ValueError:
            # The parameter index is not among the pivot rows, so it does not occupy a header column.
            return False
        return column == index_column

    def _data(self, index, role):
        """Returns the element of the parameter value under the given data index, or None if there is no value."""
        row, column = self.map_to_pivot(index)
        data = self.model.get_pivoted_data([row], [column])
        if not data or data[0][0] is None:
            return None
        # The last three header ids are (parameter index, parameter definition, alternative); the
        # inherited methods already read [-2] and [-1] as parameter definition and alternative.
        # NOTE(review): assumes _header_ids returns ids in top left header order -- confirm.
        parameter_index = self._header_ids(row, column)[-3]
        return self.db_mngr.get_value_index(self.db_map, "parameter_value", data[0][0], parameter_index, role)

    @staticmethod
    def _parameter_value_to_update(id_, header_ids, value):
        """Returns the item that updates one element of the parameter value with the given id.

        The element is identified by the parameter index, which is third from last in header_ids.
        """
        return {"id": id_, "value": value, "index": header_ids[-3]}

    def _update_parameter_values(self, items):
        """Updates the given expanded parameter value items."""
        self.db_mngr.update_expanded_parameter_values({self.db_map: items})
class RelationshipPivotTableModel(PivotTableModelBase):
    """A model for the pivot table in relationship input type."""

    @property
    def item_type(self):
        """Returns the item type, "relationship"."""
        return "relationship"

    def call_reset_model(self, object_class_ids, pivot=None):
        """See base class."""
        data = self._parent.load_relationship_data()
        self.top_left_headers = {
            name: TopLeftObjectHeaderItem(self, name, id_) for name, id_ in object_class_ids.items()
        }
        if pivot is None:
            pivot = self._default_pivot()
        super().reset_model(data, list(self.top_left_headers), *pivot)

    def _default_pivot(self):
        """Returns the default pivot: every header as a row, no columns, nothing frozen."""
        rows = list(self.top_left_headers)
        columns = []
        frozen = []
        frozen_value = ()
        return rows, columns, frozen, frozen_value

    def _data(self, index, role):
        """Returns True if the pivot model holds a truthy entry under the given data index, False if not.

        Returns None if the pivot model returns no data at all.
        """
        row, column = self.map_to_pivot(index)
        data = self.model.get_pivoted_data([row], [column])
        if not data:
            return None
        return bool(data[0][0])

    def _text_alignment_data(self, index):
        """Returns horizontally centered alignment for data cells, None for others."""
        if self.index_in_data(index):
            return Qt.AlignHCenter
        return None

    def _do_batch_set_inner_data(self, row_map, column_map, data, values):
        """See base class."""
        return self._batch_set_relationship_data(row_map, column_map, data, values)

    def _batch_set_relationship_data(self, row_map, column_map, data, values):
        """Adds and removes relationships in batch.

        A truthy value on a cell without a relationship adds one; a falsy value on a cell
        with a relationship removes it.

        Args:
            row_map (list(int)): pivot rows
            column_map (list(int)): pivot columns
            data (list(list)): existing entries for each row, column pair
            values (dict): mapping (row, column) tuples to the new value

        Returns:
            bool: True if anything was added or removed, False otherwise
        """

        def relationship_to_add(header_ids):
            rel_cls_name = self.db_mngr.get_item(self.db_map, "relationship_class", self._parent.current_class_id)[
                "name"
            ]
            object_names = [self.db_mngr.get_item(self.db_map, "object", id_)["name"] for id_ in header_ids]
            name = rel_cls_name + "_" + "__".join(object_names)
            return dict(object_id_list=list(header_ids), class_id=self._parent.current_class_id, name=name)

        to_add = []
        to_remove = []
        for i, row in enumerate(row_map):
            for j, column in enumerate(column_map):
                # row_map x column_map may cover cells that were not edited (non-rectangular selection).
                if (row, column) not in values:
                    continue
                value = values[row, column]
                if data[i][j] is None and value:
                    to_add.append(relationship_to_add(self._header_ids(row, column)))
                elif data[i][j] is not None and not value:
                    to_remove.append(data[i][j])
        if not to_add and not to_remove:
            return False
        if to_add:
            self.db_mngr.add_relationships({self.db_map: to_add})
        if to_remove:
            self.db_mngr.remove_items({self.db_map: {"relationship": to_remove}})
        return True

    def receive_objects_added_or_removed(self, items, action):
        """Adds or removes empty data for the given objects if any belongs to one of the current object classes.

        Returns:
            bool: True if the model was touched, False otherwise
        """
        objects_per_class = {}
        for item in items:
            objects_per_class.setdefault(item["class_id"], []).append(item)
        if set(objects_per_class).isdisjoint(self._parent.current_object_class_id_list):
            return False
        data = self._parent.load_empty_relationship_data(objects_per_class=objects_per_class)
        self.receive_data_added_or_removed(data, action)
        return True

    def receive_relationships_added_or_removed(self, relationships, action):
        """Updates the model with the data loaded for the given relationships and action. Returns True."""
        data = self._parent.load_full_relationship_data(relationships=relationships, action=action)
        self.update_model(data)
        return True

    def receive_parameter_definitions_added_or_removed(self, db_map_data, action):  # pylint: disable=no-self-use
        """Returns False, this model does not hold parameter data."""
        return False

    def receive_parameter_values_added_or_removed(self, db_map_data, action):  # pylint: disable=no-self-use
        """Returns False, this model does not hold parameter data."""
        return False
class PivotTableSortFilterProxy(QSortFilterProxyModel):
    """A proxy that hides the data rows and columns of a pivot table model rejected by per-header filters."""

    def __init__(self, parent=None):
        """Initialize class."""
        super().__init__(parent)
        # Important so we can edit parameters in the view
        self.setDynamicSortFilter(False)
        self.index_filters = {}

    def set_filter(self, identifier, filter_value):
        """Sets filter for a given index (object_class) name.

        Args:
            identifier (int): index identifier
            filter_value (set, None): A set of accepted values, or None if no filter (all pass)
        """
        self.index_filters[identifier] = filter_value
        # trigger filter update
        self.invalidateFilter()

    def clear_filter(self):
        """Drops all filters."""
        self.index_filters = {}
        # trigger filter update
        self.invalidateFilter()

    def accept_index(self, index, index_ids):
        """Returns True if every value in index passes the filter set for the identifier at the same position.

        Args:
            index (Iterable): header values
            index_ids (Iterable): identifiers, paired with the values in order
        """
        for value, identifier in zip(index, index_ids):
            accepted = self.index_filters.get(identifier)
            if accepted is None:
                # No filter for this identifier, everything passes.
                continue
            if value not in accepted:
                return False
        return True

    def filterAcceptsRow(self, source_row, source_parent):
        """Returns true if the item in the row indicated by the given source_row
        and source_parent should be included in the model; otherwise returns false.
        """
        source = self.sourceModel()
        header_row_count = source.headerRowCount()
        # Header rows and the last row are always shown.
        if source_row < header_row_count or source_row == source.rowCount() - 1:
            return True
        pivot_rows = source.model.pivot_rows
        if not pivot_rows:
            return True
        row_header = source.model._row_data_header[source_row - header_row_count]
        return self.accept_index(row_header, pivot_rows)

    def filterAcceptsColumn(self, source_column, source_parent):
        """Returns true if the item in the column indicated by the given source_column
        and source_parent should be included in the model; otherwise returns false.
        """
        source = self.sourceModel()
        header_column_count = source.headerColumnCount()
        # Header columns and the last column are always shown.
        if source_column < header_column_count or source_column == source.columnCount() - 1:
            return True
        pivot_columns = source.model.pivot_columns
        if not pivot_columns:
            return True
        column_header = source.model._column_data_header[source_column - header_column_count]
        return self.accept_index(column_header, pivot_columns)

    def batch_set_data(self, indexes, values):
        """Maps the given proxy indexes to the source model and sets the values there in batch."""
        source_indexes = [self.mapToSource(index) for index in indexes]
        return self.sourceModel().batch_set_data(source_indexes, values)