######################################################################################################################
# Copyright (C) 2017-2021 Spine project consortium
# This file is part of Spine Toolbox.
# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
# any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""
Contains the GraphViewMixin class.
:author: M. Marin (KTH)
:date: 26.11.2018
"""
import itertools
from time import monotonic
from PySide2.QtCore import Slot, QTimer, QThreadPool
from PySide2.QtWidgets import QHBoxLayout
from spinedb_api import from_database
from ...widgets.custom_qgraphicsscene import CustomGraphicsScene
from ...helpers import get_save_file_name_in_last_dir
from ..graphics_items import (
EntityItem,
ObjectItem,
RelationshipItem,
ArcItem,
CrossHairsItem,
CrossHairsRelationshipItem,
CrossHairsArcItem,
)
from .graph_layout_generator import GraphLayoutGenerator
from .add_items_dialogs import AddObjectsDialog, AddReadyRelationshipsDialog
class GraphViewMixin:
    """Provides the graph view for the DS form."""

    # Base extent passed to object items; relationship and cross-hairs items use fractions of it.
    # NOTE(review): this attribute was referenced (bare below and as self.VERTEX_EXTENT) but not defined
    # in the scraped source, which raises NameError at class creation; 64 is assumed from upstream -- confirm.
    VERTEX_EXTENT = 64
    _ARC_WIDTH = 0.15 * VERTEX_EXTENT
    _ARC_LENGTH_HINT = 1.5 * VERTEX_EXTENT

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        QHBoxLayout(self.ui.graphicsView)
        # True while the graph being built should keep the positions of the current items.
        self._persistent = False
        # True when a build was requested while the graph dock was hidden; it is built once shown.
        self._owes_graph = False
        self.scene = None
        self.object_items = []
        self.relationship_items = []
        self.arc_items = []
        self.selected_tree_inds = {}
        # (db_map, id) tuples of the entities in the current graph, in layout-index order.
        self.object_ids = []
        self.relationship_ids = []
        # Parallel lists of arc endpoints: relationship layout index -> object layout index.
        self.src_inds = []
        self.dst_inds = []
        self._relationships_being_added = False
        self._adding_objects_at_pos = None
        self.added_relationship_ids = set()
        self._thread_pool = QThreadPool()
        # Running layout generators keyed by their id.
        self.layout_gens = {}
        # Id of the most recently started layout generator; layouts from any other id are ignored.
        self._layout_gen_id = None
        self.ui.graphicsView.connect_spine_db_editor(self)

    def init_models(self):
        """Initializes models and installs a fresh graphics scene in the graph view."""
        super().init_models()
        self.scene = CustomGraphicsScene(self)
        self.ui.graphicsView.setScene(self.scene)

    def connect_signals(self):
        """Connects signals."""
        super().connect_signals()
        self.ui.treeView_object.tree_selection_changed.connect(self.rebuild_graph)
        self.ui.treeView_relationship.tree_selection_changed.connect(self.rebuild_graph)
        self.ui.dockWidget_entity_graph.visibilityChanged.connect(self._handle_entity_graph_visibility_changed)

    def receive_objects_added(self, db_map_data):
        """Runs when objects are added to the db.

        Restores previously hidden items (undo/redo). If the user asked to add objects at a scene position,
        places the genuinely new objects around that position.

        Args:
            db_map_data (dict): mapping from DiffDatabaseMapping instance to a list of added object items.
        """
        super().receive_objects_added(db_map_data)
        added_ids = {(db_map, x["id"]) for db_map, objects in db_map_data.items() for x in objects}
        restored_ids = self.restore_removed_entities(added_ids)
        added_ids -= restored_ids
        if added_ids and self._adding_objects_at_pos is not None:
            spread = self.VERTEX_EXTENT * self.ui.graphicsView.zoom_factor
            gen = GraphLayoutGenerator(None, len(added_ids), spread=spread)
            # Called directly (not through the thread pool); only the relative offsets gen.x / gen.y are used.
            gen.run()
            x = self._adding_objects_at_pos.x()
            y = self._adding_objects_at_pos.y()
            for dx, dy, object_id in zip(gen.x, gen.y, added_ids):
                object_item = ObjectItem(self, x + dx, y + dy, self.VERTEX_EXTENT, object_id)
                self.scene.addItem(object_item)
                object_item.apply_zoom(self.ui.graphicsView.zoom_factor)
            self._adding_objects_at_pos = None

    def receive_relationships_added(self, db_map_data):
        """Runs when relationships are added to the db.

        Restores previously hidden items (undo/redo). If relationships are being added from the graph,
        rebuilds the graph persistently so that the new relationships appear.

        Args:
            db_map_data (dict): mapping from DiffDatabaseMapping instance to a list of added relationship items.
        """
        super().receive_relationships_added(db_map_data)
        added_ids = {(db_map, x["id"]) for db_map, relationships in db_map_data.items() for x in relationships}
        restored_ids = self.restore_removed_entities(added_ids)
        added_ids -= restored_ids
        if added_ids and self._relationships_being_added:
            self.added_relationship_ids.update(added_ids)
            self.build_graph(persistent=True)
            self._end_add_relationships()

    def receive_object_classes_updated(self, db_map_data):
        """Refreshes icons of graph items whose object class was updated."""
        super().receive_object_classes_updated(db_map_data)
        self.refresh_icons(db_map_data)

    def receive_relationship_classes_updated(self, db_map_data):
        """Refreshes icons of graph items whose relationship class was updated."""
        super().receive_relationship_classes_updated(db_map_data)
        self.refresh_icons(db_map_data)

    def receive_objects_updated(self, db_map_data):
        """Runs when objects are updated in the db. Refreshes names of objects in graph.

        Args:
            db_map_data (dict): mapping from DiffDatabaseMapping instance to a list of updated object items.
        """
        super().receive_objects_updated(db_map_data)
        updated_ids = {(db_map, x["id"]): x["name"] for db_map, objects in db_map_data.items() for x in objects}
        for item in self.ui.graphicsView.items():
            if isinstance(item, ObjectItem) and item.db_map_entity_id in updated_ids:
                item.update_name(updated_ids[item.db_map_entity_id])

    def receive_objects_removed(self, db_map_data):
        """Runs when objects are removed from the db. Hides the corresponding graph items.

        Args:
            db_map_data (dict): mapping from DiffDatabaseMapping instance to a list of removed object items.
        """
        super().receive_objects_removed(db_map_data)
        self.hide_removed_entities(db_map_data)

    def receive_relationships_removed(self, db_map_data):
        """Runs when relationships are removed from the db. Hides the corresponding graph items.

        Args:
            db_map_data (dict): mapping from DiffDatabaseMapping instance to a list of removed relationship items.
        """
        super().receive_relationships_removed(db_map_data)
        self.hide_removed_entities(db_map_data)

    def restore_removed_entities(self, added_ids):
        """Restores any entities that have been previously removed and returns their ids.

        This happens in the context of undo/redo.

        Args:
            added_ids (set(tuple)): newly added (db_map, entity id) tuples.

        Returns:
            set(tuple): the (db_map, entity id) tuples whose items were restored.
        """
        removed_items = self.ui.graphicsView.removed_items
        restored_items = [item for item in removed_items if item.db_map_entity_id in added_ids]
        for item in restored_items:
            removed_items.remove(item)
            item.set_all_visible(True)
        return {item.db_map_entity_id for item in restored_items}

    def hide_removed_entities(self, db_map_data):
        """Hides removed entities while saving them into a list attribute.

        This allows entities to be restored in case the user undoes the operation.

        Args:
            db_map_data (dict): mapping from DiffDatabaseMapping instance to a list of removed items.
        """
        removed_ids = {(db_map, x["id"]) for db_map, items in db_map_data.items() for x in items}
        self.added_relationship_ids -= removed_ids
        removed_items = [
            item
            for item in self.ui.graphicsView.items()
            if isinstance(item, EntityItem) and item.db_map_entity_id in removed_ids
        ]
        if not removed_items:
            return
        self.ui.graphicsView.removed_items.extend(removed_items)
        # self.scene is cleared while hiding, presumably so that item code consulting it skips scene-level
        # work -- confirm. Always restore it, even if hiding an item fails.
        scene = self.scene
        self.scene = None
        try:
            for item in removed_items:
                item.set_all_visible(False)
        finally:
            self.scene = scene

    def refresh_icons(self, db_map_data):
        """Runs when entity classes are updated in the db. Refreshes icons of entities in graph.

        Args:
            db_map_data (dict): mapping from DiffDatabaseMapping instance to a list of updated class items.
        """
        updated_ids = {(db_map, x["id"]) for db_map, items in db_map_data.items() for x in items}
        for item in self.ui.graphicsView.items():
            if isinstance(item, EntityItem) and (item.db_map, item.entity_class_id) in updated_ids:
                item.refresh_icon()

    @Slot(bool)
    def _handle_entity_graph_visibility_changed(self, visible):
        """Stops layout generation when the graph dock is hidden; builds an owed graph once it is shown.

        Args:
            visible (bool): whether the graph dock became visible.
        """
        if not visible:
            self._stop_layout_generators()
            return
        if self._owes_graph:
            QTimer.singleShot(100, self.build_graph)

    @Slot(dict)
    def rebuild_graph(self, selected):
        """Stores the given selection of entity tree indexes and builds graph.

        Args:
            selected (dict): tree indexes keyed by item type ("root", "object", "relationship", ...).
        """
        self.selected_tree_inds = selected
        self.added_relationship_ids.clear()
        self.build_graph()

    def build_graph(self, persistent=False):
        """Builds the graph.

        The layout is computed by a GraphLayoutGenerator in the thread pool. _complete_graph finishes the
        build when a layout becomes available.

        Args:
            persistent (bool, optional): If True, elements in the current graph (if any) retain their position
                in the new one.
        """
        if not self.ui.dockWidget_entity_graph.isVisible():
            self._owes_graph = True
            return
        self._owes_graph = False
        self.ui.graphicsView.clear_cross_hairs_items()  # Needed
        self._persistent = persistent
        self._stop_layout_generators()
        self._update_graph_data()
        self._layout_gen_id = self._new_layout_gen_id()
        self.layout_gens[self._layout_gen_id] = layout_gen = self._make_layout_generator()
        layout_gen.show_progress_widget(self.ui.graphicsView)
        layout_gen.layout_available.connect(self._complete_graph)
        # Tolerate an id that is already gone instead of raising KeyError inside the slot.
        layout_gen.finished.connect(lambda id_: self.layout_gens.pop(id_, None))
        self._thread_pool.start(layout_gen)

    def _new_layout_gen_id(self):
        """Returns an id for a new layout generator that is strictly greater than the previous one.

        time.monotonic() can return the same value for calls made close together on coarse-resolution
        clocks. A repeated id would make a stale generator's layout look current and collide in layout_gens.

        Returns:
            float
        """
        gen_id = monotonic()
        previous_id = self._layout_gen_id
        if previous_id is not None and gen_id <= previous_id:
            gen_id = previous_id + 1e-6
        return gen_id

    def _stop_layout_generators(self):
        """Asks every running layout generator to stop."""
        # Iterate over a snapshot: finished handlers pop entries from layout_gens.
        for layout_gen in list(self.layout_gens.values()):
            layout_gen.stop()

    def _complete_graph(self, layout_gen_id, x, y):
        """Replaces the scene contents with items placed according to a computed layout.

        Args:
            layout_gen_id (object): id of the generator that produced the layout.
            x (list): Horizontal coordinates
            y (list): Vertical coordinates
        """
        # Ignore layouts from obsolete generators, and layouts that arrive after the window was closed.
        if layout_gen_id != self._layout_gen_id or self.scene is None:
            return
        self.ui.graphicsView.removed_items.clear()
        self.ui.graphicsView.selected_items.clear()
        self.ui.graphicsView.hidden_items.clear()
        self.ui.graphicsView.heat_map_items.clear()
        self.scene.clear()
        if self._make_new_items(x, y):
            self._add_new_items()  # pylint: disable=no-value-for-parameter
        if not self._persistent:
            self.ui.graphicsView.reset_zoom()
        else:
            self.ui.graphicsView.apply_zoom()

    def _get_selected_entity_ids(self):
        """Returns the ids corresponding to selected entities in the trees.

        Selecting the root selects every object. Selecting a class selects all of its children.

        Returns:
            set: selected (db_map, object id) tuples
            set: selected (db_map, relationship id) tuples
        """
        if "root" in self.selected_tree_inds:
            return (
                {(db_map, x["id"]) for db_map in self.db_maps for x in self.db_mngr.get_items(db_map, "object")},
                set(),
            )
        selected_object_ids = set()
        selected_relationship_ids = set()
        for index in self.selected_tree_inds.get("object", {}):
            item = index.model().item_from_index(index)
            selected_object_ids.update(item.db_map_ids.items())
        for index in self.selected_tree_inds.get("relationship", {}):
            item = index.model().item_from_index(index)
            selected_relationship_ids.update(item.db_map_ids.items())
        for index in self.selected_tree_inds.get("object_class", {}):
            item = index.model().item_from_index(index)
            for db_map in item.db_maps:
                selected_object_ids.update((db_map, id_) for id_ in item._get_children_ids(db_map))
        for index in self.selected_tree_inds.get("relationship_class", {}):
            item = index.model().item_from_index(index)
            for db_map in item.db_maps:
                selected_relationship_ids.update((db_map, id_) for id_ in item._get_children_ids(db_map))
        return selected_object_ids, selected_relationship_ids

    def _get_all_relationships_for_graph(self, object_ids, relationship_ids):
        """Returns the relationships to show in the graph.

        A relationship qualifies when any of its member objects is in object_ids (auto-expand on), or when all
        of them are (auto-expand off). Explicitly given relationship_ids are always included. The result may
        contain duplicates; callers deduplicate by (db_map, id).

        Args:
            object_ids (set): (db_map, object id) tuples.
            relationship_ids (set): (db_map, relationship id) tuples.

        Returns:
            list: (db_map, relationship item) tuples.
        """
        cond = any if self.ui.graphicsView.auto_expand_objects else all
        return [
            (db_map, x)
            for db_map in self.db_maps
            for x in self.db_mngr.get_items(db_map, "relationship")
            if cond((db_map, int(id_)) in object_ids for id_ in x["object_id_list"].split(","))
        ] + [(db_map, self.db_mngr.get_item(db_map, "relationship", id_)) for db_map, id_ in relationship_ids]

    def _update_graph_data(self):
        """Updates data for graph according to selection in trees."""
        object_ids, relationship_ids = self._get_selected_entity_ids()
        relationship_ids.update(self.added_relationship_ids)
        pruned_ids = {id_ for ids in self.ui.graphicsView.prunned_entity_ids.values() for id_ in ids}
        object_ids -= pruned_ids
        relationship_ids -= pruned_ids
        relationships = self._get_all_relationships_for_graph(object_ids, relationship_ids)
        object_id_lists = {}
        for db_map, relationship in relationships:
            if (db_map, relationship["id"]) in pruned_ids:
                continue
            object_id_list = [
                (db_map, id_)
                for id_ in (int(x) for x in relationship["object_id_list"].split(","))
                if (db_map, id_) not in pruned_ids
            ]
            # A relationship needs at least two visible members to be drawn.
            if len(object_id_list) < 2:
                continue
            object_ids.update(object_id_list)
            object_id_lists[db_map, relationship["id"]] = object_id_list
        self.object_ids = list(object_ids)
        self.relationship_ids = list(object_id_lists)
        self._update_src_dst_inds(object_id_lists)

    def _update_src_dst_inds(self, object_id_lists):
        """Computes arc endpoints as layout indices.

        Objects occupy indices [0, len(object_ids)); relationships follow them.

        Args:
            object_id_lists (dict): member (db_map, object id) lists keyed by (db_map, relationship id).
        """
        self.src_inds = []
        self.dst_inds = []
        object_ind_lookup = {id_: k for k, id_ in enumerate(self.object_ids)}
        relationship_ind_lookup = {id_: len(self.object_ids) + k for k, id_ in enumerate(self.relationship_ids)}
        for relationship_id, object_id_list in object_id_lists.items():
            relationship_ind = relationship_ind_lookup[relationship_id]
            for object_id in object_id_list:
                self.src_inds.append(relationship_ind)
                self.dst_inds.append(object_ind_lookup[object_id])

    def _get_parameter_positions(self, parameter_name):
        """Yields entity positions stored as values of the given parameter.

        Args:
            parameter_name (str or None): name of the position parameter; nothing is yielded if it is empty.

        Yields:
            tuple: ((db_map, entity id), position) for every value that decodes to a float.
        """
        if not parameter_name:
            # 'return' ends the generator; the former 'yield from []' fell through and queried anyway.
            return
        for db_map in self.db_maps:
            for p in self.db_mngr.get_items_by_field(db_map, "parameter_value", "parameter_name", parameter_name):
                pos = from_database(p["value"])
                if isinstance(pos, float):
                    yield (db_map, p["entity_id"]), pos

    def _make_layout_generator(self):
        """Returns a layout generator for the current graph.

        Positions of current items (when persistent) and positions given by the x/y position parameters
        are passed to the generator as heavy (fixed) positions.

        Returns:
            GraphLayoutGenerator
        """
        fixed_positions = {}
        if self._persistent:
            for item in self.ui.graphicsView.items():
                if isinstance(item, EntityItem):
                    fixed_positions[item.db_map_entity_id] = {"x": item.pos().x(), "y": item.pos().y()}
        param_pos_x = dict(self._get_parameter_positions(self.ui.graphicsView.pos_x_parameter))
        param_pos_y = dict(self._get_parameter_positions(self.ui.graphicsView.pos_y_parameter))
        # Parameter positions apply only when both coordinates are available; they override current positions.
        for db_map_entity_id in param_pos_x.keys() & param_pos_y.keys():
            fixed_positions[db_map_entity_id] = {"x": param_pos_x[db_map_entity_id], "y": param_pos_y[db_map_entity_id]}
        entity_ids = self.object_ids + self.relationship_ids
        heavy_positions = {ind: fixed_positions[id_] for ind, id_ in enumerate(entity_ids) if id_ in fixed_positions}
        return GraphLayoutGenerator(
            self._layout_gen_id,
            len(entity_ids),
            self.src_inds,
            self.dst_inds,
            self._ARC_LENGTH_HINT,
            heavy_positions=heavy_positions,
        )

    def _make_new_items(self, x, y):
        """Creates new items for the graph.

        Args:
            x (list): horizontal coordinates, objects first, then relationships
            y (list): vertical coordinates, in the same order

        Returns:
            bool: True if at least one object item was created.
        """
        self.object_items = [
            ObjectItem(self, x[i], y[i], self.VERTEX_EXTENT, object_id) for i, object_id in enumerate(self.object_ids)
        ]
        offset = len(self.object_items)
        self.relationship_items = [
            RelationshipItem(self, x[offset + i], y[offset + i], 0.5 * self.VERTEX_EXTENT, relationship_id)
            for i, relationship_id in enumerate(self.relationship_ids)
        ]
        self.arc_items = [
            ArcItem(self.relationship_items[rel_ind - offset], self.object_items[obj_ind], self._ARC_WIDTH)
            for rel_ind, obj_ind in zip(self.src_inds, self.dst_inds)
        ]
        return bool(self.object_items)

    def _add_new_items(self):
        """Adds the items created by _make_new_items to the scene."""
        for item in self.object_items + self.relationship_items + self.arc_items:
            self.scene.addItem(item)

    def start_relationship(self, relationship_class, obj_item):
        """Starts a relationship from the given object item.

        Stores the object class ids still to be picked into relationship_class["object_class_ids_to_go"]
        and shows cross-hairs items in the graph view.

        Args:
            relationship_class (dict)
            obj_item (..graphics_items.ObjectItem)
        """
        db_map = obj_item.db_map
        object_class_ids_to_go = relationship_class["object_class_id_list"].copy()
        object_class_ids_to_go.remove(obj_item.entity_class_id)
        relationship_class["object_class_ids_to_go"] = object_class_ids_to_go
        ch_item = CrossHairsItem(
            self, obj_item.pos().x(), obj_item.pos().y(), 0.8 * self.VERTEX_EXTENT, db_map_entity_id=(db_map, None)
        )
        ch_rel_item = CrossHairsRelationshipItem(
            self, obj_item.pos().x(), obj_item.pos().y(), 0.5 * self.VERTEX_EXTENT, db_map_entity_id=(db_map, None)
        )
        ch_arc_item1 = CrossHairsArcItem(ch_rel_item, obj_item, self._ARC_WIDTH)
        ch_arc_item2 = CrossHairsArcItem(ch_rel_item, ch_item, self._ARC_WIDTH)
        ch_rel_item.refresh_icon()
        self.ui.graphicsView.set_cross_hairs_items(
            relationship_class, [ch_item, ch_rel_item, ch_arc_item1, ch_arc_item2]
        )

    def finalize_relationship(self, relationship_class, *object_items):
        """Tries to add relationships between the given object items.

        Every ordering of the items whose class ids match the relationship class is offered in a dialog.

        Args:
            relationship_class (dict)
            object_items (..graphics_items.ObjectItem)
        """
        db_map = object_items[0].db_map
        object_class_id_list = relationship_class["object_class_id_list"]
        relationships = {
            tuple(item.entity_name for item in item_permutation)
            for item_permutation in itertools.permutations(object_items)
            if [item.entity_class_id for item in item_permutation] == object_class_id_list
        }
        dialog = AddReadyRelationshipsDialog(self, relationship_class, list(relationships), self.db_mngr, db_map)
        dialog.accepted.connect(self._begin_add_relationships)
        dialog.show()

    def _begin_add_relationships(self):
        """Marks that relationships are being added from the graph."""
        self._relationships_being_added = True

    def _end_add_relationships(self):
        """Clears the relationships-being-added mark."""
        self._relationships_being_added = False

    def add_objects_at_position(self, pos):
        """Opens a dialog to add objects that will be placed around the given scene position.

        Args:
            pos (QPointF): scene position (assumed; only .x() and .y() are used)
        """
        # NOTE(review): if the dialog is cancelled, _adding_objects_at_pos stays set and objects added later
        # through another route would be placed here too -- consider resetting it on rejection.
        self._adding_objects_at_pos = pos
        dialog = AddObjectsDialog(self, self.db_mngr, *self.db_maps)
        dialog.show()

    def get_pdf_file_path(self):
        """Asks the user for a PDF file path to export the graph to.

        Returns:
            str: the chosen path (presumably empty if the user cancels -- confirm)
        """
        self.qsettings.beginGroup(self.settings_group)
        try:
            file_path, _ = get_save_file_name_in_last_dir(
                self.qsettings, "exportGraphAsPDF", self, "Export as PDF...", self._get_base_dir(), "PDF files (*.pdf)"
            )
        finally:
            # Always close the settings group, so later settings calls are not nested under it.
            self.qsettings.endGroup()
        return file_path

    def closeEvent(self, event):
        """Handle close window.

        Args:
            event (QCloseEvent): Closing event
        """
        super().closeEvent(event)
        self.scene = None