######################################################################################################################
# Copyright (C) 2017-2022 Spine project consortium
# Copyright Spine Toolbox contributors
# This file is part of Spine Toolbox.
# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
# any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
""" Classes for custom QDialogs to add items to databases. """
from contextlib import suppress
from itertools import product
from PySide6.QtWidgets import (
QHBoxLayout,
QVBoxLayout,
QGridLayout,
QWidget,
QLabel,
QLineEdit,
QDialog,
QComboBox,
QSpinBox,
QToolButton,
QTableWidget,
QTableWidgetItem,
QAbstractItemView,
QTreeWidget,
QTreeWidgetItem,
QSplitter,
QDialogButtonBox,
)
from PySide6.QtCore import Slot, Qt, QSize, QModelIndex
from PySide6.QtGui import QIcon
from spinedb_api.helpers import name_from_elements, name_from_dimensions
from ..helpers import string_to_bool, string_to_display_icon
from ...mvcmodels.empty_row_model import EmptyRowModel
from ...mvcmodels.compound_table_model import CompoundTableModel
from ...mvcmodels.minimal_table_model import MinimalTableModel
from ...helpers import DB_ITEM_SEPARATOR
from .custom_delegates import ManageEntityClassesDelegate, ManageEntitiesDelegate
from .manage_items_dialogs import (
ShowIconColorEditorMixin,
GetEntityClassesMixin,
GetEntitiesMixin,
ManageItemsDialog,
DialogWithTableAndButtons,
)
class AddReadyEntitiesDialog(DialogWithTableAndButtons):
    """A dialog to let the user add new 'ready' multidimensional entities.

    Every candidate entity is shown on its own row with a check box in the first column;
    only the rows left checked are added when the dialog is accepted.
    """

    def __init__(self, parent, entity_class, entities, db_mngr, *db_maps, commit_data=True):
        """
        Args:
            parent (SpineDBEditor)
            entity_class (dict)
            entities (list(list(str))): one item per candidate entity, each a sequence of element bynames
            db_mngr (SpineDBManager)
            *db_maps: DatabaseMapping instances
            commit_data (bool): if False, accepting the dialog closes it without adding anything
        """
        super().__init__(parent, db_mngr)
        self._commit_data = commit_data
        self.entity_class = entity_class
        self.entities = entities
        self.db_maps = db_maps
        self.table_view.horizontalHeader().setMinimumSectionSize(0)
        self.setWindowTitle("Add '{0}' entities".format(self.entity_class["name"]))
        self.populate_table_view()
        self.connect_signals()

    def _populate_layout(self):
        """Puts an instruction label on top of whatever the base class lays out."""
        label = QLabel("<p>Please check the entities you want to add and press <b>Ok</b>.</p>")
        label.setWordWrap(True)
        self.layout().addWidget(label)
        super()._populate_layout()

    def make_table_view(self):
        """Returns the table widget used to list the candidate entities."""
        return QTableWidget(self)

    def populate_table_view(self):
        """Fills the table with one checked row per candidate entity."""
        dimension_name_list = self.entity_class["dimension_name_list"]
        self.table_view.setRowCount(len(self.entities))
        # One extra, unlabeled column on the left holds the check boxes.
        self.table_view.setColumnCount(len(dimension_name_list) + 1)
        self.table_view.setHorizontalHeaderLabels(("",) + dimension_name_list)
        self.table_view.verticalHeader().hide()
        for row, entity in enumerate(self.entities):
            check_item = QTableWidgetItem()
            check_item.setFlags(Qt.ItemIsEnabled)
            check_item.setCheckState(Qt.CheckState.Checked)
            self.table_view.setItem(row, 0, check_item)
            for column, element_byname in enumerate(entity, start=1):
                element_item = QTableWidgetItem(DB_ITEM_SEPARATOR.join(element_byname))
                element_item.setFlags(Qt.ItemIsEnabled)
                self.table_view.setItem(row, column, element_item)
        self.table_view.resizeColumnsToContents()
        self.resize_window_to_columns()

    def connect_signals(self):
        """Connect signals to slots."""
        super().connect_signals()
        self.table_view.cellClicked.connect(self._handle_table_view_cell_clicked)
        self.table_view.selectionModel().currentChanged.connect(self._handle_table_view_current_changed)

    def _handle_table_view_cell_clicked(self, row, column):
        """Toggles the check box of the clicked row, whichever column was clicked."""
        item = self.table_view.item(row, 0)
        if item.checkState() == Qt.CheckState.Checked:
            item.setCheckState(Qt.CheckState.Unchecked)
        else:
            item.setCheckState(Qt.CheckState.Checked)

    def _handle_table_view_current_changed(self, current, _previous):
        """Clears the current index as soon as one is set."""
        if current.isValid():
            self.table_view.selectionModel().clearCurrentIndex()

    @Slot()
    def accept(self):
        """Collect info from dialog and try to add items."""
        if not self._commit_data:
            super().accept()
            return
        db_map_data = self.get_db_map_data()
        # get_db_map_data() maps every db_map to the (possibly empty) list of checked entities,
        # so the dict itself is truthy even when nothing is checked; look at the lists instead.
        if not any(db_map_data.values()):
            self.parent().msg_error.emit("Nothing to add")
            return
        self.db_mngr.add_entities(db_map_data)
        super().accept()

    def get_db_map_data(self):
        """Returns the checked entities keyed by database mapping.

        Returns:
            dict: mapping from each db_map to a list of dicts with keys
                "entity_class_name" and "entity_byname"
        """
        class_name = self.entity_class["name"]
        data = []
        for row in range(self.table_view.rowCount()):
            if self.table_view.item(row, 0).checkState() != Qt.CheckState.Checked:
                continue
            # Flatten the element bynames into a single entity byname.
            entity_byname = tuple(name for element_byname in self.entities[row] for name in element_byname)
            data.append({"entity_class_name": class_name, "entity_byname": entity_byname})
        return {db_map: data for db_map in self.db_maps}
class AddItemsDialog(ManageItemsDialog):
    """Base dialog for collecting the data of new database items from the user."""

    def __init__(self, parent, db_mngr, *db_maps):
        """
        Args:
            parent (SpineDBEditor)
            db_mngr (SpineDBManager)
            *db_maps: DiffDatabaseMapping instances
        """
        super().__init__(parent, db_mngr)
        self.db_maps = db_maps
        # Lookup from codename to database mapping.
        self.keyed_db_maps = {db_map.codename: db_map for db_map in db_maps}
        self.remove_rows_button = QToolButton(self)
        self.remove_rows_button.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
        self.remove_rows_button.setText("Remove selected rows")

    def _populate_layout(self):
        """Places the remove button and the button box into the dialog's grid layout."""
        grid = self.layout()
        grid.addWidget(self.remove_rows_button, 1, 0)
        grid.addWidget(self.button_box, 2, 0, -1, -1)

    def connect_signals(self):
        """Connect signals to slots."""
        super().connect_signals()
        self.remove_rows_button.clicked.connect(self.remove_selected_rows)

    @Slot(bool)
    def remove_selected_rows(self, checked=True):
        """Removes from the model every row that has at least one selected cell."""
        selected_rows = {index.row() for index in self.table_view.selectedIndexes()}
        # Go bottom-up so that removing a row does not shift the rows still to be removed.
        for row in sorted(selected_rows, reverse=True):
            self.model.removeRows(row, 1)

    def all_databases(self, row):
        """Returns a list of db names available for a given row.
        Used by delegates.
        """
        return [db_map.codename for db_map in self.db_maps]
class AddEntityClassesDialog(ShowIconColorEditorMixin, GetEntityClassesMixin, AddItemsDialog):
    """A dialog to query user's preferences for new entity classes."""

    def __init__(self, parent, item, db_mngr, *db_maps, force_default=False):
        """
        Args:
            parent (SpineDBEditor)
            item (MultiDBTreeItem)
            db_mngr (SpineDBManager)
            *db_maps: DatabaseMapping instances
            force_default (bool): if True, defaults are non-editable
        """
        super().__init__(parent, db_mngr, *db_maps)
        self.setWindowTitle("Add entity classes")
        self.table_view.set_column_converter_for_pasting("display icon", string_to_display_icon)
        self.table_view.set_column_converter_for_pasting("active by default", string_to_bool)
        self.model = EmptyRowModel(self)
        self.model.force_default = force_default
        self.table_view.setModel(self.model)
        self.dimension_count_widget = QWidget(self)
        layout = QHBoxLayout(self.dimension_count_widget)
        layout.addWidget(QLabel("Number of dimensions"))
        self.spin_box = QSpinBox(self)
        self.spin_box.setMinimum(0)
        layout.addWidget(self.spin_box)
        layout.addStretch()
        self.remove_rows_button.setIcon(QIcon(":/icons/menu_icons/cube_minus.svg"))
        self.table_view.setItemDelegate(ManageEntityClassesDelegate(self))
        # When opened from a non-root item, that item's name pre-fills the first dimension.
        dimension_one_name = item.name if item.item_type != "root" else None
        self.number_of_dimensions = 1 if dimension_one_name is not None else 0
        # The spin box is set before connecting signals, so this does not trigger the slot.
        self.spin_box.setValue(self.number_of_dimensions)
        self.connect_signals()
        labels = ["dimension name (1)"] if dimension_one_name is not None else []
        labels += ["entity class name", "description", "display icon", "active by default", "databases"]
        self.model.set_horizontal_header_labels(labels)
        db_names = ",".join(x.codename for x in item.db_maps)
        self.default_display_icon = None
        self.model.set_default_row(
            **{
                "dimension name (1)": dimension_one_name,
                "display icon": self.default_display_icon,
                "active by default": self.number_of_dimensions != 0,
                "databases": db_names,
            }
        )
        self.model.clear()

    def _populate_layout(self):
        """Lays out the dimension count widget, table, remove button and button box."""
        self.layout().addWidget(self.dimension_count_widget, 0, 0, 1, -1)
        self.layout().addWidget(self.table_view, 1, 0)
        self.layout().addWidget(self.remove_rows_button, 2, 0)
        self.layout().addWidget(self.button_box, 3, 0, -1, -1)

    def connect_signals(self):
        """Connect signals to slots."""
        super().connect_signals()
        self.spin_box.valueChanged.connect(self._handle_spin_box_value_changed)
        # pylint: disable=unnecessary-lambda
        self.table_view.itemDelegate().icon_color_editor_requested.connect(
            lambda index: self.show_icon_color_editor(index)
        )

    @Slot(int)
    def _handle_spin_box_value_changed(self, i):
        """Adds or removes dimension columns until their count matches the spin box value.

        Args:
            i (int): the new spin box value
        """
        self.spin_box.setEnabled(False)
        try:
            # The value may jump by more than one (e.g. when typed in), so loop
            # rather than inserting or removing a single column.
            while self.number_of_dimensions < i:
                self.insert_column()
            while self.number_of_dimensions > i:
                self.remove_column()
            default_value = self.number_of_dimensions != 0
            self.model.default_row["active by default"] = default_value
            column = self.model.horizontal_header_labels().index("active by default")
            # Also update the trailing empty row so that it reflects the new default.
            last_row = self.model.rowCount() - 1
            index = self.model.index(last_row, column)
            if index.data() != default_value:
                self.model.setData(index, default_value)
        finally:
            # Never leave the spin box disabled, even if updating the model failed.
            self.spin_box.setEnabled(True)
        self.resize_window_to_columns()

    def insert_column(self):
        """Appends one dimension name column right after the existing dimension columns."""
        column = self.number_of_dimensions
        self.number_of_dimensions += 1
        column_name = "dimension name ({0})".format(self.number_of_dimensions)
        self.model.insertColumns(column, 1)
        self.model.insert_horizontal_header_labels(column, [column_name])
        self.table_view.resizeColumnToContents(column)

    def remove_column(self):
        """Removes the last dimension name column."""
        self.number_of_dimensions -= 1
        column = self.number_of_dimensions
        self.model.header.pop(column)
        self.model.removeColumns(column, 1)

    @Slot(QModelIndex, QModelIndex, list)
    def _handle_model_data_changed(self, top_left, bottom_right, roles):
        """Recomputes the entity class name of each row whose dimension names were edited."""
        if Qt.ItemDataRole.EditRole not in roles:
            return
        # Only react to changes that lie entirely within the dimension name columns.
        if top_left.column() >= self.number_of_dimensions:
            return
        if bottom_right.column() >= self.number_of_dimensions:
            return
        for row in range(top_left.row(), bottom_right.row() + 1):
            dimension_names = [
                name for j in range(self.number_of_dimensions) if (name := self.model.index(row, j).data())
            ]
            class_name = name_from_dimensions(dimension_names)
            # The 'entity class name' column comes right after the dimension columns.
            self.model.setData(self.model.index(row, self.number_of_dimensions), class_name)

    @Slot()
    def accept(self):
        """Collect info from dialog and try to add items."""
        db_map_data = {}
        header_labels = self.model.horizontal_header_labels()
        name_column = header_labels.index("entity class name")
        description_column = header_labels.index("description")
        display_icon_column = header_labels.index("display icon")
        active_by_default_column = header_labels.index("active by default")
        db_column = header_labels.index("databases")
        for i in range(self.model.rowCount() - 1):  # last row will always be empty
            row_data = self.model.row_data(i)
            entity_class_name = row_data[name_column]
            if not entity_class_name:
                self.parent().msg_error.emit("Entity class missing at row {}".format(i + 1))
                return
            display_icon = row_data[display_icon_column]
            if not display_icon:
                display_icon = self.default_display_icon
            pre_item = {
                "name": entity_class_name,
                "description": row_data[description_column],
                "display_icon": display_icon,
                "active_by_default": row_data[active_by_default_column],
            }
            db_names = row_data[db_column]
            if db_names is None:
                db_names = ""
            for db_name in db_names.split(","):
                if db_name not in self.keyed_db_maps:
                    self.parent().msg_error.emit("Invalid database {0} at row {1}".format(db_name, i + 1))
                    return
                db_map = self.keyed_db_maps[db_name]
                entity_classes = self.db_map_ent_cls_lookup_by_name[db_map]
                dimension_id_list = []
                for column in range(name_column):  # Leave 'name' column outside
                    dimension_name = row_data[column]
                    if dimension_name not in entity_classes:
                        self.parent().msg_error.emit(
                            "Invalid dimension '{}' for db '{}' at row {}".format(dimension_name, db_name, i + 1)
                        )
                        return
                    dimension_id_list.append(entity_classes[dimension_name]["id"])
                item = pre_item.copy()
                item["dimension_id_list"] = dimension_id_list
                db_map_data.setdefault(db_map, []).append(item)
        if not db_map_data:
            self.parent().msg_error.emit("Nothing to add")
            return
        self.db_mngr.add_entity_classes(db_map_data)
        super().accept()
class AddEntitiesOrManageElementsDialog(GetEntityClassesMixin, GetEntitiesMixin, AddItemsDialog):
    """Common base for the dialogs that add entities or manage the elements of a class.

    Subclasses must implement ``_class_key_to_str``, ``_accepts_class``, ``make_model``
    and ``_do_reset_model``.
    """

    def __init__(self, parent, db_mngr, *db_maps):
        """
        Args:
            parent (SpineDBEditor)
            db_mngr (SpineDBManager)
            *db_maps: DatabaseMapping instances
        """
        super().__init__(parent, db_mngr, *db_maps)
        self.model = self.make_model()
        self.table_view.setModel(self.model)
        self.header_widget = QWidget(self)
        layout = QHBoxLayout(self.header_widget)
        layout.addWidget(QLabel("Entity class"))
        self.ent_cls_combo_box = QComboBox(self)
        layout.addWidget(self.ent_cls_combo_box)
        layout.addStretch()
        self.remove_rows_button.setIcon(QIcon(":/icons/menu_icons/cube_minus.svg"))
        # Keys of the classes listed in the combo box, in the same order; filled in by subclasses.
        self.entity_class_keys = None

    def _class_key_to_str(self, key, *db_maps):
        """Returns the combo box text for the given class key. Must be implemented by subclasses."""
        raise NotImplementedError()

    def _accepts_class(self, ent_cls):
        """Returns whether the widget should handle the given entity class.

        Args:
            ent_cls (MappedItem)

        Returns:
            bool
        """
        raise NotImplementedError()

    def make_model(self):
        """Returns the table model for the dialog. Must be implemented by subclasses."""
        raise NotImplementedError()

    def _do_reset_model(self):
        """Rebuilds the model for the current ``class_key``. Must be implemented by subclasses."""
        raise NotImplementedError()

    @Slot(int)
    def reset_model(self, index):
        """Setup model according to current entity class selected in combobox.

        Args:
            index (int): index of the selected class in the combo box; out-of-range values,
                such as the -1 a combo box reports when its selection is cleared, are ignored
        """
        keys = self.entity_class_keys or ()
        # Without this guard, -1 would silently pick the *last* class (or raise on an empty list).
        if not 0 <= index < len(keys):
            return
        self.class_key = keys[index]
        self._do_reset_model()

    def _populate_layout(self):
        """Lays out the header widget, table, remove button and button box."""
        self.layout().addWidget(self.header_widget, 0, 0, 1, -1)
        self.layout().addWidget(self.table_view, 1, 0)
        self.layout().addWidget(self.remove_rows_button, 2, 0)
        self.layout().addWidget(self.button_box, 3, 0, -1, -1)

    def connect_signals(self):
        """Connect signals to slots."""
        super().connect_signals()
        self.ent_cls_combo_box.currentIndexChanged.connect(self.reset_model)

    @Slot(QModelIndex, QModelIndex, list)
    def _handle_model_data_changed(self, top_left, bottom_right, roles):
        """Recomputes the entity name of each row whose element names have changed."""
        # An empty role list is treated as "any role may have changed".
        if (roles and Qt.ItemDataRole.EditRole not in roles) or self.entity_class is None:
            return
        dimension_count = len(self.dimension_name_list)
        if dimension_count == 0:
            return
        header = self.model.horizontal_header_labels()
        left = top_left.column()
        right = bottom_right.column()
        # Do not overwrite a name the user has just edited.
        if left <= header.index("entity name") <= right:
            return
        if "databases" in header and left == right == header.index("databases"):
            return
        for row in range(top_left.row(), bottom_right.row() + 1):
            el_names = [n for n in (self.model.index(row, j).data() for j in range(dimension_count)) if n]
            entity_name = name_from_elements(el_names)
            # The 'entity name' column comes right after the element columns.
            self.model.setData(self.model.index(row, dimension_count), entity_name)
class AddEntitiesDialog(AddEntitiesOrManageElementsDialog):
    """A dialog to query user's preferences for new entities."""

    def __init__(self, parent, item, db_mngr, *db_maps, force_default=False, commit_data=True):
        """
        Args:
            parent (SpineDBEditor)
            item (MultiDBTreeItem)
            db_mngr (SpineDBManager)
            *db_maps: DatabaseMapping instances
            force_default (bool): if True, defaults are non-editable
            commit_data (bool): if False, accepting the dialog closes it without adding anything
        """
        super().__init__(parent, db_mngr, *db_maps)
        self._commit_data = commit_data
        # Default cell values keyed by column header (class name -> entity name).
        self.entity_names_by_class_name = {}
        if item.item_type == "entity":
            entity_name = item.name
            entity_class_name = item.parent_item.name
            self.entity_names_by_class_name = {entity_class_name: entity_name}
            if item.parent_item.item_type == "entity_class":
                self.class_key = item.parent_item.display_id
        elif item.item_type == "entity_class":
            self.class_key = item.display_id
        self.model.force_default = force_default
        self.setWindowTitle("Add entities")
        self.table_view.setItemDelegate(ManageEntitiesDelegate(self))
        self.ent_cls_combo_box.setEnabled(not force_default)
        # Collect, for each accepted class key, the databases where the class exists.
        db_maps_by_keys = {}
        for db_map, entity_classes in self.db_map_ent_cls_lookup.items():
            for key, ent_cls in entity_classes.items():
                if self._accepts_class(ent_cls):
                    db_maps_by_keys.setdefault(key, set()).add(db_map)
        self.entity_class_keys = sorted(db_maps_by_keys)
        self.ent_cls_combo_box.addItems(
            [self._class_key_to_str(key, *db_maps_by_keys[key]) for key in self.entity_class_keys]
        )
        try:
            current_index = self.entity_class_keys.index(self.class_key)
            self.reset_model(current_index)
            self._handle_model_reset()
        except ValueError:
            current_index = -1
        self.ent_cls_combo_box.setCurrentIndex(current_index)
        # Signals are connected last so that the setup above does not trigger the slots.
        self.connect_signals()

    def _class_key_to_str(self, key, *db_maps):
        """Returns the class name, suffixed with database codenames unless the class is in all of them."""
        class_name = self.db_map_ent_cls_lookup[db_maps[0]][key]["name"]
        if len(db_maps) == len(self.db_maps):
            return class_name
        return class_name + "@(" + ", ".join(db_map.codename for db_map in db_maps) + ")"

    def _accepts_class(self, ent_cls):
        """Returns True if the given class is, or contains, the dialog's base class."""
        if self.entity_class is None:
            return True
        if self.entity_class["dimension_name_list"]:
            # Base class is multidimensional, check if given ent_cls contains the base in its entirety
            return set(ent_cls["dimension_name_list"]) >= set(self.entity_class["dimension_name_list"])
        # Base class is zero-dimensional, check if given ent_cls contains it
        return self.entity_class["name"] in set(ent_cls["dimension_name_list"]) | {ent_cls["name"]}

    def make_model(self):
        """Returns an empty-row model for typing in the new entities."""
        return EmptyRowModel(self)

    def _do_reset_model(self):
        """Resets header, default row and contents of the model for the current class."""
        header = self.dimension_name_list + ("entity name", "databases")
        self.model.set_horizontal_header_labels(header)
        # By default, propose all the databases that have the current class.
        default_db_maps = [db_map for db_map, keys in self.db_map_ent_cls_lookup.items() if self.class_key in keys]
        db_names = ",".join([db_name for db_name, db_map in self.keyed_db_maps.items() if db_map in default_db_maps])
        defaults = {"databases": db_names}
        defaults.update(self.entity_names_by_class_name)
        self.model.set_default_row(**defaults)
        self.model.clear()

    def get_db_map_data(self):
        """Collects the entities typed into the table, keyed by database mapping.

        Emits an error message through the parent and gives up at the first invalid row.

        Returns:
            dict or None: mapping from db_map to list of entity items (dicts with keys "name",
                "element_id_list" and "class_id"), or None if some row is invalid
        """
        db_map_data = {}
        header = self.model.horizontal_header_labels()
        name_column = header.index("entity name")
        db_column = header.index("databases")
        for i in range(self.model.rowCount() - 1):  # last row will always be empty
            row_data = self.model.row_data(i)
            # All the columns before 'entity name' hold element bynames.
            element_byname_list = [row_data[column] for column in range(name_column)]
            entity_name = row_data[name_column]
            if not entity_name:
                self.parent().msg_error.emit(f"Entity missing at row {i + 1}")
                return None
            pre_item = {"name": entity_name}
            db_names = row_data[db_column]
            if db_names is None:
                db_names = ""
            for db_name in db_names.split(","):
                if db_name not in self.keyed_db_maps:
                    self.parent().msg_error.emit(f"Invalid database {db_name} at row {i + 1}")
                    return None
                db_map = self.keyed_db_maps[db_name]
                # The user may type in a database that does not have the selected class;
                # report it instead of failing with a KeyError.
                ent_cls = self.db_map_ent_cls_lookup[db_map].get(self.class_key)
                if ent_cls is None:
                    self.parent().msg_error.emit(f"Entity class not found in db '{db_name}' at row {i + 1}")
                    return None
                class_id = ent_cls["id"]
                entities = self.db_map_ent_lookup[db_map]
                element_id_list = []
                for entity_class_id, element_byname in zip(ent_cls["dimension_id_list"], element_byname_list):
                    if (entity_class_id, element_byname) not in entities:
                        self.parent().msg_error.emit(
                            f"Invalid element '{element_byname}' for db '{db_name}' at row {i + 1}"
                        )
                        return None
                    element_id_list.append(entities[entity_class_id, element_byname]["id"])
                item = pre_item.copy()
                item.update({"element_id_list": element_id_list, "class_id": class_id})
                db_map_data.setdefault(db_map, []).append(item)
        return db_map_data

    @Slot()
    def accept(self):
        """Collect info from dialog and try to add items."""
        if not self._commit_data:
            super().accept()
            return
        db_map_data = self.get_db_map_data()
        if db_map_data is None:
            # Some row was invalid and the error has been reported already.
            return
        if not db_map_data:
            self.parent().msg_error.emit("Nothing to add")
            return
        self.db_mngr.add_entities(db_map_data)
        super().accept()
class ManageElementsDialog(AddEntitiesOrManageElementsDialog):
    """A dialog to query user's preferences for managing entity dimensions."""

    def __init__(self, parent, item, db_mngr, *db_maps):
        """
        Args:
            parent (SpineDBEditor): data store widget
            item (MultiDBTreeItem)
            db_mngr (SpineDBManager): the manager to do the removal
            *db_maps: DatabaseMapping instances
        """
        super().__init__(parent, db_mngr, *db_maps)
        if item.item_type == "entity_class":
            self.class_key = item.display_id
        self.setWindowTitle("Manage elements")
        self.table_view.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.remove_rows_button.setToolButtonStyle(Qt.ToolButtonIconOnly)
        self.remove_rows_button.setToolTip("<p>Remove selected entities.</p>")
        self.remove_rows_button.setIconSize(QSize(24, 24))
        self.db_map = db_maps[0]
        # Ids of the entities that existed when the model was last reset,
        # keyed by (element names..., entity name).
        self.entity_ids = {}
        layout = self.header_widget.layout()
        self.db_combo_box = QComboBox(self)
        layout.addSpacing(32)
        layout.addWidget(QLabel("Database", self))
        layout.addWidget(self.db_combo_box)
        self.splitter = QSplitter(self)
        self.splitter.setChildrenCollapsible(False)
        self.add_button = QToolButton(self)
        self.add_button.setToolTip("<p>Add entities by combining selected available elements.</p>")
        self.add_button.setIcon(QIcon(":/icons/menu_icons/cubes_plus.svg"))
        self.add_button.setIconSize(QSize(24, 24))
        self.add_button.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
        self.add_button.setText(">>")
        self._label_available = QLabel("Available elements", self)
        self._label_existing = QLabel("Existing entities", self)
        # These stay hidden until a class has been selected, see _do_reset_model().
        self.hidable_widgets = [
            self.add_button,
            self._label_available,
            self._label_existing,
            self.table_view,
            self.remove_rows_button,
        ]
        for widget in self.hidable_widgets:
            widget.hide()
        self.existing_items_model = MinimalTableModel(self, lazy=False)
        self.new_items_model = MinimalTableModel(self, lazy=False)
        self.model.sub_models = [self.new_items_model, self.existing_items_model]
        self.db_combo_box.addItems([db_map.codename for db_map in db_maps])
        self.reset_entity_class_combo_box(db_maps[0].codename)
        self.connect_signals()

    def _populate_layout(self):
        """Lays out the available elements on the left and the existing entities on the right."""
        self.layout().addWidget(self.header_widget, 0, 0, 1, 4, Qt.AlignHCenter)
        self.layout().addWidget(self._label_available, 1, 0)
        self.layout().addWidget(self._label_existing, 1, 2)
        self.layout().addWidget(self.splitter, 2, 0)
        self.layout().addWidget(self.add_button, 2, 1)
        self.layout().addWidget(self.table_view, 2, 2)
        self.layout().addWidget(self.remove_rows_button, 2, 3)
        self.layout().addWidget(self.button_box, 3, 0, -1, -1)

    def make_model(self):
        """Returns a compound model that combines the new and the existing entities."""
        return CompoundTableModel(self)

    def splitter_widgets(self):
        """Returns the widgets currently held by the splitter, one per dimension.

        Returns:
            list of QWidget
        """
        return [self.splitter.widget(i) for i in range(self.splitter.count())]

    def connect_signals(self):
        """Connect signals to slots."""
        super().connect_signals()
        self.db_combo_box.currentTextChanged.connect(self.reset_entity_class_combo_box)
        self.add_button.clicked.connect(self.add_entities)

    def _accepts_class(self, ent_cls):
        """Accepts only the classes that have dimensions."""
        return bool(ent_cls["dimension_id_list"])

    def _class_key_to_str(self, key, *_db_maps):
        """Returns the name of the class with the given key in the current database."""
        return self.db_map_ent_cls_lookup[self.db_map][key]["name"]

    @Slot(str)
    def reset_entity_class_combo_box(self, database):
        """Repopulates the class combo box with the classes of the given database.

        Args:
            database (str): codename of the database selected in the database combo box
        """
        self.db_map = self.keyed_db_maps[database]
        self.entity_class_keys = [
            key for key, ent_cls in self.db_map_ent_cls_lookup[self.db_map].items() if self._accepts_class(ent_cls)
        ]
        # Replace, rather than extend, the previous items so that the combo box indexes always
        # agree with entity_class_keys. Signals are blocked meanwhile because the intermediate
        # index changes do not correspond to any meaningful selection.
        was_blocked = self.ent_cls_combo_box.blockSignals(True)
        try:
            self.ent_cls_combo_box.clear()
            self.ent_cls_combo_box.addItems([self._class_key_to_str(key) for key in self.entity_class_keys])
        finally:
            self.ent_cls_combo_box.blockSignals(was_blocked)
        try:
            current_index = self.entity_class_keys.index(self.class_key)
            self.reset_model(current_index)
            self._handle_model_reset()
        except ValueError:
            current_index = -1
        self.ent_cls_combo_box.setCurrentIndex(current_index)

    @Slot(bool)
    def add_entities(self, checked=True):
        """Adds to the table the combinations of selected elements that are not there yet."""
        element_names = [[item.text(0) for item in wg.selectedItems()] for wg in self.splitter_widgets()]
        candidate = set(product(*element_names))
        # The last column is the entity name, which does not take part in the comparison.
        existing = {
            tuple(elements[:-1]) for elements in self.new_items_model._main_data + self.existing_items_model._main_data
        }
        # Sorted so that the new rows show up in a predictable order.
        to_add = sorted(candidate - existing)
        count = len(to_add)
        if not count:
            return
        self.new_items_model.insertRows(0, count)
        self.new_items_model._main_data[0:count] = [list(row) + [""] for row in to_add]
        self.model.refresh()
        # Announce the change in the element columns so that the entity names get computed.
        self.model.dataChanged.emit(self.model.index(0, 0), self.model.index(count - 1, self.model.columnCount() - 2))

    def _do_reset_model(self):
        """Reloads the existing entities and rebuilds the lists of available elements."""
        horizontal_header_labels = self.dimension_name_list + ("entity name",)
        self.model.set_horizontal_header_labels(horizontal_header_labels)
        self.existing_items_model.set_horizontal_header_labels(horizontal_header_labels)
        self.new_items_model.set_horizontal_header_labels(horizontal_header_labels)
        self.entity_ids.clear()
        for db_map in self.db_maps:
            entity_classes = self.db_map_ent_cls_lookup[db_map]
            ent_cls = entity_classes.get(self.class_key, None)
            if ent_cls is None:
                continue
            for entity in self.db_mngr.get_items_by_field(db_map, "entity", "class_id", ent_cls["id"]):
                key = tuple(DB_ITEM_SEPARATOR.join(byname) for byname in entity["element_byname_list"]) + (
                    entity["name"],
                )
                self.entity_ids[key] = entity["id"]
        existing_items = sorted(self.entity_ids)
        self.existing_items_model.reset_model(existing_items)
        self.model.refresh()
        self.model.modelReset.emit()
        for wg in self.splitter_widgets():
            # deleteLater() alone keeps the widget in the splitter until control returns to the
            # event loop; detach it right away so it is not counted among the new widgets below.
            wg.setParent(None)
            wg.deleteLater()
        for name in self.dimension_name_list:
            tree_widget = QTreeWidget(self)
            tree_widget.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection)
            tree_widget.setColumnCount(1)
            tree_widget.setIndentation(0)
            header_item = QTreeWidgetItem([name])
            header_item.setTextAlignment(0, Qt.AlignHCenter)
            tree_widget.setHeaderItem(header_item)
            # Offer the entities of the dimension class as well as those having it as superclass.
            elements = [
                x
                for k in ("entity_class_name", "superclass_name")
                for x in self.db_mngr.get_items_by_field(self.db_map, "entity", k, name)
            ]
            items = [QTreeWidgetItem([DB_ITEM_SEPARATOR.join(el["entity_byname"])]) for el in elements]
            tree_widget.addTopLevelItems(items)
            tree_widget.resizeColumnToContents(0)
            tree_widget.setMinimumWidth(tree_widget.columnWidth(0) + 10)
            self.splitter.addWidget(tree_widget)
        sizes = [wg.columnWidth(0) for wg in self.splitter_widgets()]
        self.splitter.setSizes(sizes)
        for widget in self.hidable_widgets:
            widget.show()

    def resize_window_to_columns(self, height=None):
        """Resizes the dialog to fit the table, the add button and the splitter side by side.

        Args:
            height (int, optional): the new height; defaults to the height of the size hint
        """
        table_view_width = (
            self.table_view.frameWidth() * 2
            + self.table_view.verticalHeader().width()
            + self.table_view.horizontalHeader().length()
        )
        self.table_view.setMinimumWidth(table_view_width)
        self.table_view.setMinimumHeight(self.table_view.verticalHeader().defaultSectionSize() * 16)
        margins = self.layout().contentsMargins()
        if height is None:
            height = self.sizeHint().height()
        self.resize(
            margins.left() + margins.right() + table_view_width + self.add_button.width() + self.splitter.width(),
            height,
        )

    @Slot()
    def accept(self):
        """Collect info from dialog and try to add items."""
        to_remove = set()
        to_update = []
        new_name_by_element_byname_list = {tuple(row[:-1]): row[-1] for row in self.existing_items_model._main_data}
        for key, id_ in self.entity_ids.items():
            element_byname_list, old_name = key[:-1], key[-1]
            new_name = new_name_by_element_byname_list.get(element_byname_list)
            if new_name is None:
                # The row is gone from the table, so the entity is to be removed.
                to_remove.add(id_)
            elif old_name != new_name:
                to_update.append({"id": id_, "name": new_name})
        to_add = [
            {
                "entity_class_name": self.class_name,
                "entity_byname": tuple(x for byname in row[:-1] for x in byname.split(DB_ITEM_SEPARATOR)),
                "name": row[-1],
            }
            for row in self.new_items_model._main_data
        ]
        # Use a single identifier for all the operations; skip the ones with nothing to do.
        identifier = self.db_mngr.get_command_identifier()
        if to_remove:
            self.db_mngr.remove_items({self.db_map: {"entity": to_remove}}, identifier=identifier)
        if to_update:
            self.db_mngr.update_items("entity", {self.db_map: to_update}, identifier=identifier)
        if to_add:
            self.db_mngr.add_items("entity", {self.db_map: to_add}, identifier=identifier)
        super().accept()
class EntityGroupDialogBase(QDialog):
    """Base class for the dialogs that edit the members of an entity group.

    Shows the non-members and the members of the group side by side, with buttons
    to move entities between the two lists. Subclasses must implement
    ``initial_member_ids`` and ``initial_entity_id``.
    """

    def __init__(self, parent, entity_class_item, db_mngr, *db_maps):
        """
        Args:
            parent (SpineDBEditor): data store widget
            entity_class_item (EntityClassItem)
            db_mngr (SpineDBManager)
            *db_maps: database mappings
        """
        super().__init__(parent)
        self.entity_class_item = entity_class_item
        self.db_mngr = db_mngr
        self.db_maps = db_maps
        self.db_map = db_maps[0]
        self.db_maps_by_codename = {db_map.codename: db_map for db_map in db_maps}
        self.db_combo_box = QComboBox(self)
        self.header_widget = QWidget(self)
        self.group_name_line_edit = QLineEdit(self)
        header_layout = QHBoxLayout(self.header_widget)
        header_layout.addWidget(QLabel("Group name:"))
        header_layout.addWidget(self.group_name_line_edit)
        header_layout.addSpacing(32)
        header_layout.addWidget(QLabel("Database"))
        header_layout.addWidget(self.db_combo_box)
        self.non_members_tree = QTreeWidget(self)
        self.non_members_tree.setHeaderLabel("Non members")
        self.non_members_tree.setSelectionMode(QTreeWidget.SelectionMode.ExtendedSelection)
        self.non_members_tree.setColumnCount(1)
        self.non_members_tree.setIndentation(0)
        self.members_tree = QTreeWidget(self)
        self.members_tree.setHeaderLabel("Members")
        self.members_tree.setSelectionMode(QTreeWidget.SelectionMode.ExtendedSelection)
        self.members_tree.setColumnCount(1)
        self.members_tree.setIndentation(0)
        self.add_button = QToolButton()
        self.add_button.setToolTip("<p>Add selected non-members.</p>")
        self.add_button.setIcon(QIcon(":/icons/menu_icons/cube_plus.svg"))
        self.add_button.setIconSize(QSize(24, 24))
        self.add_button.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
        self.add_button.setText(">>")
        self.remove_button = QToolButton()
        self.remove_button.setToolTip("<p>Remove selected members.</p>")
        self.remove_button.setIcon(QIcon(":/icons/menu_icons/cube_minus.svg"))
        self.remove_button.setIconSize(QSize(24, 24))
        self.remove_button.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
        self.remove_button.setText("<<")
        self.vertical_button_widget = QWidget()
        vertical_button_layout = QVBoxLayout(self.vertical_button_widget)
        vertical_button_layout.addStretch()
        vertical_button_layout.addWidget(self.add_button)
        vertical_button_layout.addWidget(self.remove_button)
        vertical_button_layout.addStretch()
        self.button_box = QDialogButtonBox(self)
        self.button_box.setStandardButtons(QDialogButtonBox.StandardButton.Cancel | QDialogButtonBox.StandardButton.Ok)
        layout = QGridLayout(self)
        layout.addWidget(self.header_widget, 0, 0, 1, 3, Qt.AlignHCenter)
        layout.addWidget(self.non_members_tree, 1, 0)
        layout.addWidget(self.vertical_button_widget, 1, 1)
        layout.addWidget(self.members_tree, 1, 2)
        layout.addWidget(self.button_box, 2, 0, 1, 3)
        self.setAttribute(Qt.WA_DeleteOnClose)
        self.db_combo_box.addItems(list(self.db_maps_by_codename))
        # Entity ids by name for each database. Each database must be queried with its own
        # mapping and its own class id, not with the first mapping for all of them.
        self.db_map_entity_ids = {
            db_map: {
                x["name"]: x["id"]
                for x in self.db_mngr.get_items_by_field(
                    db_map, "entity", "class_id", self.entity_class_item.db_map_id(db_map)
                )
            }
            for db_map in db_maps
        }

    def connect_signals(self):
        """Connect signals to slots."""
        self.button_box.accepted.connect(self.accept)
        self.button_box.rejected.connect(self.reject)
        self.db_combo_box.currentTextChanged.connect(self.reset_list_widgets)
        self.add_button.clicked.connect(self.add_members)
        self.remove_button.clicked.connect(self.remove_members)

    @Slot(str)
    def reset_list_widgets(self, database):
        """Makes the given database current and refills the member and non-member lists.

        An entity goes to the members list if its id is among ``initial_member_ids()``;
        otherwise it goes to the non-members list, unless it is the group itself.

        Args:
            database (str): codename of the database to show
        """
        self.db_map = self.db_maps_by_codename[database]
        # Query the subclass once, after db_map has been updated, since both may depend on it.
        member_ids = self.initial_member_ids()
        group_id = self.initial_entity_id()
        member_names = []
        non_member_names = []
        for name, id_ in self.db_map_entity_ids[self.db_map].items():
            if id_ in member_ids:
                member_names.append(name)
            elif id_ != group_id:
                non_member_names.append(name)
        # Drop whatever was listed for the previously selected database.
        self.members_tree.clear()
        self.non_members_tree.clear()
        self.members_tree.addTopLevelItems([QTreeWidgetItem([name]) for name in member_names])
        self.non_members_tree.addTopLevelItems([QTreeWidgetItem([name]) for name in non_member_names])

    def initial_member_ids(self):
        """Returns the ids of the entities that are members before any editing.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def initial_entity_id(self):
        """Returns the id of the group entity, or None if it does not exist yet.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    @Slot(bool)
    def add_members(self, checked=False):
        """Moves the selected entities from the non-members list to the members list."""
        # Take items from the bottom up so that the remaining indexes stay valid.
        indexes = sorted(
            [self.non_members_tree.indexOfTopLevelItem(x) for x in self.non_members_tree.selectedItems()], reverse=True
        )
        items = [self.non_members_tree.takeTopLevelItem(ind) for ind in indexes]
        self.members_tree.addTopLevelItems(items)

    @Slot(bool)
    def remove_members(self, checked=False):
        """Moves the selected entities from the members list to the non-members list."""
        # Take items from the bottom up so that the remaining indexes stay valid.
        indexes = sorted(
            [self.members_tree.indexOfTopLevelItem(x) for x in self.members_tree.selectedItems()], reverse=True
        )
        items = [self.members_tree.takeTopLevelItem(ind) for ind in indexes]
        self.non_members_tree.addTopLevelItems(items)

    def _check_validity(self):
        """Returns True if the group has at least one member; otherwise reports an error.

        Returns:
            bool
        """
        if not self.members_tree.topLevelItemCount():
            self.parent().msg_error.emit("Please select at least one member entity.")
            return False
        return True
class AddEntityGroupDialog(EntityGroupDialogBase):
    """Dialog for creating a new entity group and picking its members."""

    def __init__(self, parent, entity_class_item, db_mngr, *db_maps):
        """
        Args:
            parent (SpineDBEditor): data store widget
            entity_class_item (EntityClassItem)
            db_mngr (SpineDBManager)
            *db_maps: database mappings
        """
        super().__init__(parent, entity_class_item, db_mngr, *db_maps)
        self.setWindowTitle("Add entity group")
        name_edit = self.group_name_line_edit
        name_edit.setFocus()
        name_edit.setPlaceholderText("Type group name here")
        first_db_map = db_maps[0]
        self.reset_list_widgets(first_db_map.codename)
        self.connect_signals()

    def initial_member_ids(self):
        """A group that is yet to be created has no members."""
        return set()

    def initial_entity_id(self):
        """A group that is yet to be created has no id."""
        return None

    def _check_validity(self):
        """Checks, on top of the base class checks, that the group name is given and not taken.

        Returns:
            bool
        """
        if not super()._check_validity():
            return False
        name = self.group_name_line_edit.text()
        error = None
        if not name:
            error = "Please enter a name for the group."
        elif name in self.db_map_entity_ids[self.db_map]:
            error = f"An entity called {name} already exists in this class. Please select a different group name."
        if error is not None:
            self.parent().msg_error.emit(error)
            return False
        return True

    @Slot()
    def accept(self):
        """Imports the new group and its members into the current database, then closes the dialog."""
        if not self._check_validity():
            return
        class_name = self.entity_class_item.name
        group_name = self.group_name_line_edit.text()
        # The wildcard matches all the items in the members list.
        member_items = self.members_tree.findItems("*", Qt.MatchWildcard)
        member_names = {member_item.text(0) for member_item in member_items}
        groups = [(class_name, group_name, member_name) for member_name in member_names]
        db_map_data = {self.db_map: {"entities": [(class_name, group_name)], "entity_groups": groups}}
        self.db_mngr.import_data(db_map_data, command_text="Add entity group")
        super().accept()
class ManageMembersDialog(EntityGroupDialogBase):
    """Dialog for editing the members of an existing entity group."""

    def __init__(self, parent, entity_item, db_mngr, *db_maps):
        """
        Args:
            parent (SpineDBEditor): data store widget
            entity_item (entity_tree_item.EntityItem)
            db_mngr (SpineDBManager)
            *db_maps: database mappings
        """
        super().__init__(parent, entity_item.parent_item, db_mngr, *db_maps)
        self.setWindowTitle("Manage members")
        # The group exists already, so its name is shown but cannot be edited.
        self.group_name_line_edit.setReadOnly(True)
        self.group_name_line_edit.setText(entity_item.name)
        self.entity_item = entity_item
        first_db_map = db_maps[0]
        self.reset_list_widgets(first_db_map.codename)
        self.connect_signals()

    def _entity_groups(self):
        """Returns the entity_group items of this group in the current database."""
        group_id = self.initial_entity_id()
        return self.db_mngr.get_items_by_field(self.db_map, "entity_group", "group_id", group_id)

    def initial_member_ids(self):
        """Returns the set of member ids that the group has in the current database."""
        return {group["member_id"] for group in self._entity_groups()}

    def initial_entity_id(self):
        """Returns the id of the group entity in the current database."""
        return self.entity_item.db_map_id(self.db_map)

    @Slot()
    def accept(self):
        """Adds and removes group members as needed to match the members list, then closes the dialog."""
        if not self._check_validity():
            return
        ent = self.entity_item.db_map_data(self.db_map)
        ids_by_name = self.db_map_entity_ids[self.db_map]
        # The wildcard matches all the items in the members list.
        current_member_ids = {
            ids_by_name[member_item.text(0)] for member_item in self.members_tree.findItems("*", Qt.MatchWildcard)
        }
        initial_member_ids = self.initial_member_ids()
        added = current_member_ids - initial_member_ids
        removed = initial_member_ids - current_member_ids
        items_to_add = []
        for member_id in added:
            items_to_add.append({"entity_id": ent["id"], "entity_class_id": ent["class_id"], "member_id": member_id})
        ids_to_remove = [group["id"] for group in self._entity_groups() if group["member_id"] in removed]
        # Both operations share one command identifier.
        identifier = self.db_mngr.get_command_identifier()
        if items_to_add:
            self.db_mngr.add_items("entity_group", {self.db_map: items_to_add}, identifier=identifier)
        if ids_to_remove:
            self.db_mngr.remove_items({self.db_map: {"entity_group": ids_to_remove}}, identifier=identifier)
        super().accept()