######################################################################################################################
# Copyright (C) 2017 - 2019 Spine project consortium
# This file is part of Spine Toolbox.
# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
# any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""
Spine Toolbox grid view
:author: P. Vennström (VTT)
:date: 1.11.2018
"""
import bisect
import operator
from PySide2.QtCore import QAbstractTableModel, Qt, QModelIndex, Signal, QSortFilterProxyModel, QAbstractListModel
from PySide2.QtGui import QColor, QFont
from helpers import tuple_itemgetter
from parameter_value_formatting import format_for_DisplayRole, format_for_EditRole, format_for_ToolTipRole
class PivotModel:
    """Editable model that pivots a dictionary of unpivoted data into rows, columns and frozen indexes.

    Data is kept in ``_data`` as ``{index_tuple: value}`` where every index tuple is ordered like
    ``index_names``. A pivot assigns each index name to exactly one of ``pivot_rows``,
    ``pivot_columns`` or ``pivot_frozen``; ``frozen_value`` selects the visible slice for the
    frozen indexes. Edited, deleted and added data / index entries are tracked in the
    ``_edit_data``, ``_deleted_data``, ``_added_*`` and ``_deleted_*`` attributes.
    """

    def __init__(self):
        self._index_ind = {}  # unique index name -> position in index_names
        self._index_type = {}  # real index name -> type of its values (str or int)
        self.index_entries = {}  # real index name -> set of known index values
        self.tuple_index_entries = {}  # tuple of unique index names -> set of value tuples
        self._model_is_updating = False  # flag if model is being reset/updated
        self._data = {}  # dictionary of unpivoted data
        self._edit_data = {}  # dictionary of edited data, values are original data
        self._deleted_data = {}  # dictionary of deleted data, values are original data
        self._data_frozen = {}  # data filtered with frozen_value
        self._data_frozen_index_values = set()  # valid frozen_value values for current pivot_frozen
        self._index_types = ()  # type of the indexes in _data
        self.index_names = ()  # names of the indexes in _data, can not contain duplicates
        self.index_real_names = ()  # real names of indexes, can contain duplicates
        self.pivot_rows = ()  # current selected rows indexes
        self.pivot_columns = ()  # current selected columns indexes
        self.pivot_frozen = ()  # current filtered frozen indexes
        self.frozen_value = ()  # current selected value of index_frozen
        # placeholder, replaced in set_pivot: _key_getter(row_key + column_key + frozen_value) -> key of _data
        self._key_getter = lambda *x: ()
        self._row_data_header = []  # header values for row data
        self._column_data_header = []  # header values for column data
        self._row_data_header_set = set()  # set of _row_data_header
        self._column_data_header_set = set()  # set of _column_data_header
        self._invalid_row = set()  # positions of rows whose header key contains invalid (None) values
        self._invalid_column = set()  # positions of columns whose header key contains invalid (None) values
        self._invalid_data = {}  # (row, column) -> data typed into invalid rows/columns
        self._added_index_entries = {}  # real index name -> added index entries
        self._added_tuple_index_entries = {}  # added tuple index entries
        self._deleted_tuple_index_entries = {}  # deleted tuple index entries
        self._deleted_index_entries = {}  # real index name -> deleted index entries
        # containers of real index names (tested with `in`) -> set of values already in use
        self._used_index_values = {}
        self._unique_name_2_name = {}  # unique index name -> real index name
        # dict with index name as key and set/range of valid values for that index
        # if set/range is empty or index doesn't exist in valid_index_values
        # then all values are valid
        self._valid_index_values = {}

    def clear_track_data(self):
        """Clears data that is tracked (edits, deletions and added/deleted index entries)."""
        self._edit_data = {}
        self._deleted_data = {}
        self._added_index_entries = {self._unique_name_2_name[n]: set() for n in self.index_names}
        self._added_tuple_index_entries = {}
        self._deleted_tuple_index_entries = {}
        self._deleted_index_entries = {self._unique_name_2_name[n]: set() for n in self.index_names}

    def set_new_data(
        self,
        data,
        index_names,
        index_type,
        rows=(),
        columns=(),
        frozen=(),
        frozen_value=(),
        index_entries=None,
        valid_index_values=None,
        tuple_index_entries=None,
        used_index_values=None,
        index_real_names=None,
    ):
        """Set the data of the model, index names, extra index entries without data and valid index values.

        Args:
            data: iterable of sequences; the first ``len(index_names)`` items are the index values,
                the next item is the data value.
            index_names: unique names of the indexes.
            index_type: type (``str`` or ``int``) of each index, same length as ``index_names``.
            rows, columns, frozen, frozen_value: initial pivot. When rows, columns and frozen are all
                empty, every index is put on the rows.
            index_entries: real index name -> extra index values that have no data.
            valid_index_values: real index name -> collection of allowed values (empty means any).
            tuple_index_entries: tuple of index names -> extra value tuples that have no data.
            used_index_values: containers of real index names -> set of values in use.
            index_real_names: real (possibly duplicated) name of each index, defaults to ``index_names``.

        Raises:
            ValueError: on inconsistent lengths or types, or when the pivot is invalid.
        """
        if index_entries is None:
            index_entries = dict()
        if valid_index_values is None:
            valid_index_values = dict()
        if tuple_index_entries is None:
            tuple_index_entries = dict()
        if used_index_values is None:
            used_index_values = dict()
        if index_real_names is None:
            index_real_names = index_names
        elif len(index_real_names) != len(index_names):
            raise ValueError('index_real_name and index_names must have same length')
        if len(index_names) != len(index_type):
            raise ValueError('index_names and index_type must have same length')
        if data and any(len(d) < len(index_names) + 1 for d in data):
            raise ValueError('data inner lists be of len >= len(index_names) + 1')
        if not all(t in [str, int] for t in index_type):
            raise ValueError('index_type can only contain str or int type')
        if len(set(index_real_names)) != len(index_names):
            # index_real_names contains duplicates, make sure the type is the same
            un_2_n = {unique: name for unique, name in zip(index_names, index_real_names)}
            real_type = {n: set() for n in index_real_names}
            for name, name_type in zip(index_names, index_type):
                real_type[un_2_n[name]].add(name_type)
            # should only have one type per unique name in index_real_names
            if any(len(types) != 1 for types in real_type.values()):
                raise ValueError('inconsistent types for "index_real_names" and "index_types"')
        # normalize to tuples so list and tuple arguments can be concatenated and compared
        rows, columns, frozen, frozen_value = tuple(rows), tuple(columns), tuple(frozen), tuple(frozen_value)
        if not rows + columns + frozen:
            # no pivot given, set default pivot
            rows = tuple(index_names)
            columns = ()
            frozen = ()
            frozen_value = ()
        else:
            # check given pivot
            pivot_error = self._is_invalid_pivot(rows, columns, frozen, frozen_value, index_names)
            if pivot_error:
                raise ValueError(pivot_error)
        self._model_is_updating = True
        try:
            self._unique_name_2_name = {unique: name for unique, name in zip(index_names, index_real_names)}
            self._valid_index_values = valid_index_values
            self._index_ind = {index: ind for ind, index in enumerate(index_names)}
            self.index_names = tuple(index_names)
            self.index_real_names = tuple(index_real_names)
            self._index_type = {self._unique_name_2_name[index_names[i]]: it for i, it in enumerate(index_type)}
            # create data dict with keys as long as index_names
            self._data = {tuple(d[: len(index_names)]): d[len(index_names)] for d in data}
            # item getter so that you can call _key_getter(row_header + column_header + frozen_value)
            # and get a key to use on _data
            key = tuple(self.index_names.index(i) for i in index_names)
            self._key_getter = tuple_itemgetter(operator.itemgetter(*key), len(key))
            self.index_entries = {}
            self.tuple_index_entries = {}
            self.clear_track_data()
            self.pivot_rows = rows
            self.pivot_columns = columns
            self.pivot_frozen = frozen
            self.frozen_value = frozen_value
            self._used_index_values = used_index_values
            # get all index values from data
            for i, c in enumerate(self.index_names):
                name = self._unique_name_2_name[c]
                values = set(d[i] for d in self._data)
                if name in self.index_entries:
                    self.index_entries[name].update(values)
                else:
                    self.index_entries[name] = values
                    self._added_index_entries[name] = set()
                    self._deleted_index_entries[name] = set()
            # add index entries that have no data, unknown index names are ignored
            for k, v in index_entries.items():
                if k in self.index_entries:
                    self.index_entries[k].update(set(v))
            # add tuple entries: combinations found in data plus the given ones
            for k, v in tuple_index_entries.items():
                keys = tuple(self._index_ind[i] for i in k)
                getter = tuple_itemgetter(operator.itemgetter(*keys), len(keys))
                self.tuple_index_entries[k] = set(getter(key) for key in self._data)
                self.tuple_index_entries[k].update(v)
            self.set_pivot(rows, columns, frozen, frozen_value)
        finally:
            # never leave the model flagged as updating, even if set_pivot raised
            self._model_is_updating = False

    @staticmethod
    def _is_invalid_pivot(rows, columns, frozen, frozen_value, index_names):
        """Checks if given pivot is valid for index_names.

        Returns:
            str with an error message if the pivot is invalid, else None.
        """
        rows, columns, frozen = tuple(rows), tuple(columns), tuple(frozen)
        pivot = rows + columns + frozen
        error = None
        if not len(set(index_names)) == len(index_names):
            error = "'index_names' must contain only unique strings"
        if not all(i in index_names for i in frozen):
            error = "'frozen' contains strings that doesn't match with current 'index_names'"
        if not all(i in index_names for i in rows):
            error = "'rows' contains strings that doesn't match with current 'index_names'"
        if not all(c in index_names for c in columns):
            error = "'columns' contains strings that doesn't match with current 'index_names'"
        # both checks are needed: the set size alone does not detect duplicates
        if len(set(pivot)) != len(index_names) or len(pivot) != len(index_names):
            error = "'rows', 'columns' and 'forzen' must contain all unqiue strings in 'index_names' without duplicates"
        if len(frozen) != len(frozen_value):
            error = "'frozen_value' must be same length as 'frozen'"
        return error

    def _change_index_frozen(self):
        """Collects the distinct values of the frozen indexes found in data."""
        if self.pivot_frozen:
            key_getter = self._index_key_getter(self.pivot_frozen)
            self._data_frozen_index_values = set(key_getter(k) for k in self._data)
        else:
            self._data_frozen_index_values = set()

    def _index_key_getter(self, names_of_index):
        """Creates an itemgetter for _data keys that always returns tuples, from a list of index names."""
        keys = tuple(self.index_names.index(i) for i in names_of_index if i in self.index_names)
        return tuple_itemgetter(operator.itemgetter(*keys), len(keys))

    def _get_unique_index_values(self, index, filter_index, filter_value):
        """Finds unique value tuples for the index names in index,
        keeping only keys whose values for filter_index equal filter_value."""
        if not index:
            return set()
        index_getter = self._index_key_getter(index)
        if filter_index:
            frozen_getter = self._index_key_getter(filter_index)
            return set(index_getter(k) for k in self._data if frozen_getter(k) == filter_value)
        return set(index_getter(k) for k in self._data)

    def set_pivot(self, rows, columns, frozen, frozen_value):
        """Sets pivot for current data and rebuilds the row and column headers.

        Raises:
            ValueError: if the pivot is invalid for the current index names.
        """
        rows, columns, frozen, frozen_value = tuple(rows), tuple(columns), tuple(frozen), tuple(frozen_value)
        pivot_error = self._is_invalid_pivot(rows, columns, frozen, frozen_value, self.index_names)
        if pivot_error:
            raise ValueError(pivot_error)
        if not self._model_is_updating:
            # nothing to do if the pivot has not changed
            current = (self.pivot_rows, self.pivot_columns, self.pivot_frozen, self.frozen_value)
            if (rows, columns, frozen, frozen_value) == current:
                return
        self.pivot_rows = rows
        self.pivot_columns = columns
        self.pivot_frozen = frozen
        self.frozen_value = frozen_value
        # set key_getter so that you will get a correct key for _data when doing
        # _key_getter(row_key + col_key + frozen_value); 'order' is the inverse permutation
        order = tuple(self.index_names.index(i) for i in self.pivot_rows + self.pivot_columns + self.pivot_frozen)
        order = tuple(sorted(range(len(order)), key=order.__getitem__))
        self._key_getter = tuple_itemgetter(operator.itemgetter(*order), len(order))
        # find unique set of tuples for row and column headers from data with given pivot
        self._row_data_header_set = self._get_unique_index_values(self.pivot_rows, self.pivot_frozen, self.frozen_value)
        self._column_data_header_set = self._get_unique_index_values(
            self.pivot_columns, self.pivot_frozen, self.frozen_value
        )
        # add tuple index entries that have no data to rows and columns
        new_row_keys, new_row_none_keys, new_entries = self._index_entries_without_data(
            self.pivot_rows, self._row_data_header_set, self.pivot_frozen, self.frozen_value, self.tuple_index_entries
        )
        for name, value in new_entries.items():
            self.index_entries[self._unique_name_2_name[name]].update(value)
        new_column_keys, new_column_none_keys, new_entries = self._index_entries_without_data(
            self.pivot_columns,
            self._column_data_header_set,
            self.pivot_frozen,
            self.frozen_value,
            self.tuple_index_entries,
        )
        for name, value in new_entries.items():
            self.index_entries[self._unique_name_2_name[name]].update(value)
        # complete keys are valid header entries
        self._row_data_header_set.update(new_row_keys)
        self._column_data_header_set.update(new_column_keys)
        self._row_data_header = sorted(self._row_data_header_set)
        self._column_data_header = sorted(self._column_data_header_set)
        len_valid_rows = len(self._row_data_header)
        len_valid_columns = len(self._column_data_header)
        # keys containing None are appended after the valid ones; None sorts last
        none_rows = sorted(new_row_none_keys, key=lambda x: tuple((i is None, i) for i in x))
        none_columns = sorted(new_column_none_keys, key=lambda x: tuple((i is None, i) for i in x))
        self._row_data_header.extend(none_rows)
        self._column_data_header.extend(none_columns)
        # mark the positions of keys with None in them as invalid
        self._invalid_row = set(range(len_valid_rows, len_valid_rows + len(none_rows)))
        self._invalid_column = set(range(len_valid_columns, len_valid_columns + len(none_columns)))
        self._invalid_data = {}

    def set_frozen_value(self, value):
        """Sets the value of the frozen indexes.

        Raises:
            ValueError: if value does not have one item per frozen index.
        """
        if len(value) != len(self.pivot_frozen):
            raise ValueError("'value' must have same lenght as 'self.pivot_frozen'")
        if tuple(value) == self.frozen_value:
            # same as previous do nothing
            return
        self.set_pivot(self.pivot_rows, self.pivot_columns, self.pivot_frozen, value)

    @staticmethod
    def _index_entries_without_data(pivot_index, pivot_set, filter_index, filter_value, tuple_index_entries):
        """Find values in tuple_index_entries that are not present in pivot_set for the names in pivot_index,
        filtered by filter_index and filter_value.

        Returns:
            tuple (new_keys, new_none_keys, new_entries): complete new header keys, partial header keys
            (missing positions are None) and, per pivot index name, the individual values used.
        """
        new_keys = set()
        new_none_keys = set()  # can contain None
        new_entries = {name: set() for name in pivot_index}
        for k in tuple_index_entries:
            # all names of k must be in the pivot or the filter, and at least one in the pivot,
            # i.e. don't add indexes that are split between rows and columns
            if not set(k).issubset(filter_index + pivot_index) or set(filter_index).issuperset(k):
                continue
            position = [i for i, name in enumerate(k) if name in pivot_index]
            position_current = [pivot_index.index(name) for name in k if name in pivot_index]
            position_current_frozen = [filter_index.index(name) for name in k if name in filter_index]
            getter_current = tuple_itemgetter(operator.itemgetter(*position_current), len(position_current))
            # must return tuples, otherwise single values never match the tuples in current_set below
            getter = tuple_itemgetter(operator.itemgetter(*position), len(position))
            v = set(tuple_index_entries[k])
            if filter_index and position_current_frozen:
                # one or more of the index entries are in the filter,
                # keep only those with same value as filter_value
                position_frozen = [i for i, name in enumerate(k) if name in filter_index]
                getter_frozen_current = tuple_itemgetter(
                    operator.itemgetter(*position_current_frozen), len(position_current_frozen)
                )
                getter_frozen = tuple_itemgetter(operator.itemgetter(*position_frozen), len(position_frozen))
                current_frozen = getter_frozen_current(filter_value)
                v = set(getter(i) for i in v if getter_frozen(i) == current_frozen)
                k = tuple(i for i in k if i not in filter_index)
            # drop values already present in the header
            current_set = set(getter_current(d) for d in pivot_set)
            v = v.difference(current_set)
            # create new keys with the same length and order as pivot_index
            for key in v:
                if not isinstance(key, tuple):
                    key = (key,)
                new_key = [None] * len(pivot_index)
                for i, ki in enumerate(position_current):
                    new_key[ki] = key[i]
                    new_entries[k[i]].add(key[i])
                if len(key) == len(pivot_index):
                    new_keys.add(tuple(new_key))
                else:
                    new_none_keys.add(tuple(new_key))
        return new_keys, new_none_keys, new_entries

    def get_pivoted_data(self, row_mask, col_mask):
        """Gets data from current pivot with indexes in row_mask and col_mask as a list of lists.

        Raises:
            ValueError: if a mask contains positions outside the current headers.
        """
        if not self._row_data_header and not self._column_data_header:
            if self.pivot_frozen and len(self.pivot_frozen) == len(self.index_names):
                # special case when all indexes are in pivot frozen
                return [[self._data.get(self._key_getter(self.frozen_value), None)]]
            # no data
            return []
        if self.pivot_rows and any(r >= len(self._row_data_header) or r < 0 for r in row_mask):
            raise ValueError("row_mask contains invalid indexes to current row pivot")
        if self.pivot_columns and any(c >= len(self._column_data_header) or c < 0 for c in col_mask):
            raise ValueError("col_mask contains invalid indexes to current column pivot")
        data = []
        for row in row_mask:
            data_row = []
            invalid_row = row in self._invalid_row
            row_key = self.row(row)
            for col in col_mask:
                if invalid_row or col in self._invalid_column:
                    data_row.append(self._invalid_data.get((row, col), None))
                else:
                    key = self._key_getter(row_key + self.column(col) + self.frozen_value)
                    data_row.append(self._data.get(key, None))
            data.append(data_row)
        return data

    @staticmethod
    def _is_empty_value(value):
        """Returns True for values that mean 'no data': falsy values and whitespace-only strings."""
        return not value or (isinstance(value, str) and value.isspace())

    def set_pivoted_data(self, data, row_mask, col_mask):
        """Paste list of lists into current pivot, no change of indexes.

        Args:
            data: list of rows to paste.
            row_mask: list of header positions where to paste data rows in current pivot.
            col_mask: list of header positions where to paste data columns in current pivot.

        Empty values (falsy or whitespace-only strings) delete existing data; positions outside
        the current headers are ignored.
        """
        if (not data) or (len(data) == 1 and not data[0]):
            return
        if len(data) != len(row_mask):
            raise ValueError('row_mask must be same length as data')
        if not all(len(row) == len(col_mask) for row in data):
            raise ValueError('col_mask must be same length as each sublist in data')
        # keep only valid indexes
        if self.pivot_rows and self.pivot_columns:
            data = [
                [col for c, col in zip(col_mask, row) if c < len(self._column_data_header)]
                for r, row in zip(row_mask, data)
                if r < len(self._row_data_header)
            ]
            row_mask = [r for r in row_mask if r < len(self._row_data_header)]
            col_mask = [c for c in col_mask if c < len(self._column_data_header)]
        elif self.pivot_rows and not self.pivot_columns:
            # only row data
            data = [
                [col for c, col in zip(col_mask, row) if c == 0]
                for r, row in zip(row_mask, data)
                if r < len(self._row_data_header)
            ]
            row_mask = [r for r in row_mask if r < len(self._row_data_header)]
            col_mask = [c for c in col_mask if c == 0]
        elif self.pivot_columns and not self.pivot_rows:
            # only col data
            data = [
                [col for c, col in zip(col_mask, row) if c < len(self._column_data_header)]
                for r, row in zip(row_mask, data)
                if r == 0
            ]
            row_mask = [r for r in row_mask if r == 0]
            col_mask = [c for c in col_mask if c < len(self._column_data_header)]
        for row, row_value in zip(row_mask, data):
            invalid_row = row in self._invalid_row
            row_key = self.row(row)
            for col, paste_value in zip(col_mask, row_value):
                col_key = self.column(col)
                if invalid_row or col in self._invalid_column:
                    # row or col invalid, put data in invalid data dict
                    invalid_index = (row, col)
                    if self._is_empty_value(paste_value):
                        self._invalid_data.pop(invalid_index, None)
                    else:
                        self._invalid_data[invalid_index] = paste_value
                else:
                    # valid index, insert data into dict
                    key = self._key_getter(row_key + col_key + self.frozen_value)
                    if self._is_empty_value(paste_value):
                        self._delete_data(key)
                    else:
                        self._add_data(key, paste_value)

    def _add_index_value(self, value, name):
        """Adds value as an entry of unique index name.

        Returns:
            True if the value exists or was added, False if it is already in use or invalid.
        """
        name = self._unique_name_2_name[name]
        if value in self.index_entries[name]:
            # value for index already exists, no need to add.
            return True
        # check if value for index 'name' is already in use.
        for k, v in self._used_index_values.items():
            if name in k and value in v:
                return False
        # check if new value is valid for index.
        if not self.is_valid_index(value, name):
            return False
        self.index_entries[name].add(value)
        # add new value to used names
        for k, v in self._used_index_values.items():
            if name in k:
                v.add(value)
        if name in self._deleted_index_entries and value in self._deleted_index_entries[name]:
            # value was deleted, now readded, don't add to _added_index_entries
            self._deleted_index_entries[name].discard(value)
        else:
            self._added_index_entries.setdefault(name, set()).add(value)
        return True

    def _delete_data(self, key):
        """Removes key from data, remembering the original value in _deleted_data."""
        if key in self._edit_data:
            # data was edited, track original value
            if self._edit_data[key] and key not in self._deleted_data:
                self._deleted_data[key] = self._edit_data.pop(key)
        elif key in self._data and key not in self._deleted_data:
            # data was not edited, track existing data
            self._deleted_data[key] = self._data.pop(key)
        self._edit_data.pop(key, None)
        self._data.pop(key, None)

    def _add_data(self, key, value):
        """Sets data for key, remembering the original value in _edit_data."""
        if key in self._deleted_data:
            # data was deleted before
            old_value = self._deleted_data.pop(key)
            if not old_value == value:
                self._edit_data[key] = old_value
        elif key in self._edit_data:
            # data has been edited before
            if self._edit_data[key] == value:
                # same as original, no longer an edit
                self._edit_data.pop(key)
        elif value != self._data.get(key, None):
            # first edit, store previous value (None for new data)
            self._edit_data[key] = self._data.get(key, None)
        self._data[key] = value

    def _restore_data(self, key):
        """Restores the original value of key (removing it if it did not exist originally)."""
        if key in self._deleted_data:
            self._add_data(key, self._deleted_data[key])
        elif key in self._edit_data:
            value = self._edit_data[key]
            if value is None:
                # no previous data, delete
                self._delete_data(key)
            else:
                self._add_data(key, value)

    def row(self, row):
        """Returns the row header key at position row (a tuple, () when there is no row pivot)."""
        if self.pivot_rows:
            if self._row_data_header:
                return self._row_data_header[row]
            return tuple(None for _ in self.pivot_rows)
        if row == 0:
            return ()
        raise IndexError('index out of range for current row pivot')

    def column(self, col):
        """Returns the column header key at position col (a tuple, () when there is no column pivot)."""
        if self.pivot_columns:
            if self._column_data_header:
                return self._column_data_header[col]
            return tuple(None for _ in self.pivot_columns)
        if col == 0:
            return ()
        raise IndexError('index out of range for current column pivot')

    @property
    def rows(self):
        """Row header keys of the current pivot."""
        return self._row_data_header

    @property
    def columns(self):
        """Column header keys of the current pivot."""
        return self._column_data_header

    def _check_pivot_indexes(self, indexes):
        """Raises ValueError unless every (row, column) pair is within the current headers.

        The upper bound is inclusive so that position 0 stays usable while a header is empty.
        """
        num_rows = len(self.rows)
        num_columns = len(self.columns)
        if not all(0 <= i[0] <= num_rows and 0 <= i[1] <= num_columns for i in indexes):
            raise ValueError('indexes must be list of valid index for row pivot')

    def restore_pivoted_values(self, indexes):
        """Restores all values for given (row, column) indexes."""
        self._check_pivot_indexes(indexes)
        for i in indexes:
            key = self._key_getter(self.row(i[0]) + self.column(i[1]) + self.frozen_value)
            self._restore_data(key)

    def delete_pivoted_values(self, indexes):
        """Deletes values for given (row, column) indexes."""
        self._check_pivot_indexes(indexes)
        for i in indexes:
            if i[0] in self._invalid_row or i[1] in self._invalid_column:
                self._invalid_data.pop(tuple(i), None)
            else:
                key = self._key_getter(self.row(i[0]) + self.column(i[1]) + self.frozen_value)
                self._delete_data(key)

    def _remove_header_keys(self, direction, should_remove):
        """Removes keys for which should_remove(key) is true from the row or column header.

        The header list and set are updated in place; positions in the invalid row/column set
        and in _invalid_data are remapped so they keep pointing to the same (shifted) entries,
        and invalid data of removed entries is dropped.
        """
        if direction == "row":
            header, header_set, invalid, axis = (
                self._row_data_header, self._row_data_header_set, self._invalid_row, 0
            )
        else:
            header, header_set, invalid, axis = (
                self._column_data_header, self._column_data_header_set, self._invalid_column, 1
            )
        new_position = {}
        kept = []
        for old_position, key in enumerate(header):
            if should_remove(key):
                header_set.discard(key)
            else:
                new_position[old_position] = len(kept)
                kept.append(key)
        if len(kept) == len(header):
            return
        header[:] = kept
        remapped = {new_position[i] for i in invalid if i in new_position}
        invalid.clear()
        invalid.update(remapped)
        invalid_data = {}
        for index, value in self._invalid_data.items():
            if index[axis] in new_position:
                moved = list(index)
                moved[axis] = new_position[index[axis]]
                invalid_data[tuple(moved)] = value
        self._invalid_data = invalid_data

    def _header_keys_matching_tuples(self, tk, indexes, pivot_index, header):
        """Returns keys of header (ordered like pivot_index) matching value tuples in indexes (ordered like tk).

        Only applies when every name of tk is in pivot_index or pivot_frozen and at least one is in
        pivot_index; tuples whose frozen part differs from the current frozen_value are ignored.
        """
        if not all(n in pivot_index + self.pivot_frozen for n in tk):
            return set()
        pivot_names = [n for n in pivot_index if n in tk]
        if not pivot_names:
            # every name of tk is frozen, no header entry corresponds to these tuples
            return set()
        from_tuple = tuple_itemgetter(operator.itemgetter(*[tk.index(n) for n in pivot_names]), len(pivot_names))
        frozen_names = [n for n in self.pivot_frozen if n in tk]
        if frozen_names:
            # compare both sides in pivot_frozen order
            tuple_frozen = tuple_itemgetter(
                operator.itemgetter(*[tk.index(n) for n in frozen_names]), len(frozen_names)
            )
            current_frozen = tuple(self.frozen_value[self.pivot_frozen.index(n)] for n in frozen_names)
            pivot_values = set(from_tuple(i) for i in indexes if tuple_frozen(i) == current_frozen)
        else:
            pivot_values = set(from_tuple(i) for i in indexes)
        from_header = tuple_itemgetter(
            operator.itemgetter(*[pivot_index.index(n) for n in pivot_names]), len(pivot_names)
        )
        return set(key for key in header if from_header(key) in pivot_values)

    def delete_tuple_index_values(self, delete_tuples):
        """Deletes tuple index entries, and the data and header keys matching them.

        Args:
            delete_tuples: dict of tuple of index names -> iterable of value tuples ordered like that name tuple.
        """
        delete_values = set()
        delete_values_row = set()
        delete_values_column = set()
        for tk in self.tuple_index_entries:
            for names, indexes in delete_tuples.items():
                if set(names) != set(tk):
                    continue
                # reorder the given tuples from the order of 'names' to the order of 'tk'
                to_tk_order = tuple_itemgetter(operator.itemgetter(*[names.index(n) for n in tk]), len(tk))
                indexes = set(to_tk_order(i) for i in indexes)
                remove_set = self.tuple_index_entries[tk].intersection(indexes)
                self.tuple_index_entries[tk].difference_update(remove_set)
                if tk in self._added_tuple_index_entries:
                    self._added_tuple_index_entries[tk].difference_update(remove_set)
                self._deleted_tuple_index_entries.setdefault(tk, set()).update(remove_set)
                # data keys are ordered like index_names, pick the tk values in tk order
                from_data_key = tuple_itemgetter(
                    operator.itemgetter(*[self.index_names.index(n) for n in tk]), len(tk)
                )
                delete_values.update(k for k in self._data if from_data_key(k) in indexes)
                delete_values_row.update(
                    self._header_keys_matching_tuples(tk, indexes, self.pivot_rows, self._row_data_header)
                )
                delete_values_column.update(
                    self._header_keys_matching_tuples(tk, indexes, self.pivot_columns, self._column_data_header)
                )
        for k in delete_values:
            self._delete_data(k)
        if delete_values_row:
            self._remove_header_keys("row", delete_values_row.__contains__)
        if delete_values_column:
            self._remove_header_keys("column", delete_values_column.__contains__)

    def delete_index_values(self, delete_indexes):
        """Deletes one or more index values together with the data, tuple entries and header keys using them.

        Args:
            delete_indexes: dict of real index name -> iterable of values to delete.
        """
        delete_values = {}  # position in index_names -> values to delete
        delete_values_row = {}  # position in pivot_rows -> values to delete
        delete_values_column = {}  # position in pivot_columns -> values to delete
        for k, indexes in delete_indexes.items():
            if k not in self.index_real_names or not indexes:
                continue
            deleted_entries = set(indexes).intersection(self.index_entries[k])
            if not deleted_entries:
                # deleted entries not in index, do nothing
                continue
            self.index_entries[k].difference_update(deleted_entries)
            k_unique = [u for u, v in self._unique_name_2_name.items() if v == k]
            for u in k_unique:
                if u in self.pivot_rows:
                    delete_values_row[self.pivot_rows.index(u)] = deleted_entries
                if u in self.pivot_columns:
                    delete_values_column[self.pivot_columns.index(u)] = deleted_entries
                delete_values[self.index_names.index(u)] = deleted_entries
            self._deleted_index_entries[k].update(deleted_entries)
            self._added_index_entries[k].difference_update(deleted_entries)
            # remove only entries that were deleted from index_entries from used values
            for u_name, v in self._used_index_values.items():
                if k in u_name:
                    v.difference_update(deleted_entries)
        # delete from tuple indexes
        for tk in self.tuple_index_entries:
            tk_real = [self._unique_name_2_name[t] for t in tk]
            for k, indexes in delete_indexes.items():
                if k not in tk_real:
                    continue
                pos = [i for i, x in enumerate(tk_real) if x == k]
                remove_set = set(row for row in self.tuple_index_entries[tk] if any(row[p] in indexes for p in pos))
                self.tuple_index_entries[tk].difference_update(remove_set)
                if tk in self._added_tuple_index_entries:
                    self._added_tuple_index_entries[tk].difference_update(remove_set)
                self._deleted_tuple_index_entries.setdefault(tk, set()).update(remove_set)
        if delete_values:
            delete_keys = [
                key for key in self._data if any(key[ind] in values for ind, values in delete_values.items())
            ]
            for key in delete_keys:
                self._delete_data(key)
        if delete_values_row:
            self._remove_header_keys(
                "row", lambda key: any(key[ind] in values for ind, values in delete_values_row.items())
            )
        if delete_values_column:
            self._remove_header_keys(
                "column", lambda key: any(key[ind] in values for ind, values in delete_values_column.items())
            )

    def paste_data(
        self,
        row_start=0,
        row_header_data=None,
        col_start=0,
        col_header_data=None,
        data=None,
        row_mask=None,
        col_mask=None,
    ):
        """Paste a list of list into current view of AbstractTable"""
        if row_mask is None:
            row_mask = list()
        if col_mask is None:
            col_mask = list()
        # NOTE(review): _data_to_header is not defined in PivotModel; confirm it is provided elsewhere
        # (e.g. by a subclass), otherwise pasting header data raises AttributeError.
        if row_header_data:
            edit_rows, add_rows = self._data_to_header(
                row_header_data, row_start, self._row_data_header, self.pivot_rows, row_mask, "row"
            )
            self.edit_index(edit_rows + add_rows, row_mask, "row")
        if col_header_data and col_header_data[0]:
            edit_columns, add_columns = self._data_to_header(
                col_header_data, col_start, self._column_data_header, self.pivot_columns, col_mask, "column"
            )
            self.edit_index(edit_columns + add_columns, col_mask, "column")
        if data:
            self.set_pivoted_data(data, row_mask, col_mask)

    def edit_index(self, new_index, index_mask, direction):
        """Edits the header keys of either rows or columns, moving data to the new keys.

        Args:
            new_index: new header keys, tuples ordered like the pivot rows/columns.
            index_mask: header positions for the new keys; positions past the header are appended.
            direction: "row" or "column".

        Returns:
            dict, always empty (kept for backward compatibility).
        """
        if len(new_index) != len(index_mask):
            raise ValueError('index_mask must be same length as new_index')
        if direction == "row":
            index_name = self.pivot_rows
            other_index_name = self.pivot_columns
            edit_index = self._row_data_header
            edit_index_set = self._row_data_header_set
            invalid_set = self._invalid_row
            other_index = self._column_data_header
            other_invalid_set = self._invalid_column
            order_getter = operator.itemgetter(0, 1)
            key_getter = self._key_getter
        elif direction == "column":
            index_name = self.pivot_columns
            other_index_name = self.pivot_rows
            edit_index = self._column_data_header
            edit_index_set = self._column_data_header_set
            invalid_set = self._invalid_column
            other_index = self._row_data_header
            other_invalid_set = self._invalid_row
            # (header position, other position) -> (row, column) key of _invalid_data
            order_getter = operator.itemgetter(1, 0)
            order = tuple(self.index_names.index(i) for i in self.pivot_columns + self.pivot_rows + self.pivot_frozen)
            order = tuple(sorted(range(len(order)), key=order.__getitem__))
            # tuple_itemgetter so a single index still yields a tuple key like those in _data
            key_getter = tuple_itemgetter(operator.itemgetter(*order), len(order))
        else:
            raise ValueError('parameter direction must be "row" or "column"')
        if not other_index_name:
            other_index = [()]
        # insert new index entries
        for i, name in enumerate(index_name):
            for r in new_index:
                self._add_index_value(r[i], name)
        # update tuple entries spanning this direction (and possibly frozen indexes)
        for k in self.tuple_index_entries:
            if set(k).issubset(index_name + self.pivot_frozen) and not set(self.pivot_frozen).issuperset(k):
                names = list(index_name + self.pivot_frozen)
                valid = [(i, names.index(kn)) for i, kn in enumerate(k) if kn in names]
                keys = tuple(v[1] for v in valid)
                names = tuple(k[v[0]] for v in valid)
                getter = tuple_itemgetter(operator.itemgetter(*keys), len(keys))
                for line in new_index:
                    new_tuple = getter(tuple(line) + self.frozen_value)
                    if all(i in self.index_entries[self._unique_name_2_name[n]] for i, n in zip(new_tuple, names)):
                        # all indexes are valid
                        if new_tuple not in self.tuple_index_entries[k]:
                            self.tuple_index_entries[k].add(new_tuple)
                            self._added_tuple_index_entries.setdefault(k, set()).add(new_tuple)
        # change data values since their index is changed
        for i, new_key in zip(index_mask, new_index):
            if i < len(edit_index) and edit_index[i] == new_key:
                # same as old key, do nothing
                continue
            if self.is_valid_key(new_key, edit_index_set, index_name):
                edit_index_set.add(new_key)
                if i >= len(edit_index):
                    # outside old data do nothing
                    continue
                if i in invalid_set:
                    # previous key was invalid, move data from invalid to valid
                    invalid_set.discard(i)
                    for c, other_key in enumerate(other_index):
                        row_col_index = order_getter((i, c))
                        if c not in other_invalid_set and row_col_index in self._invalid_data:
                            key = key_getter(new_key + other_key + self.frozen_value)
                            value = self._invalid_data.pop(row_col_index)
                            self._add_data(key, value)
                else:
                    # previous key was valid, move data to new key
                    old_index_key = edit_index[i]
                    edit_index_set.remove(old_index_key)
                    for c, other_key in enumerate(other_index):
                        old_key = key_getter(old_index_key + other_key + self.frozen_value)
                        if c not in other_invalid_set and old_key in self._data:
                            key = key_getter(new_key + other_key + self.frozen_value)
                            old_val = self._data[old_key]
                            self._delete_data(old_key)
                            self._add_data(key, old_val)
            else:
                # key is invalid
                if i < len(edit_index):
                    old_index_key = edit_index[i]
                    if old_index_key in edit_index_set and i not in invalid_set:
                        # previous key was valid, remove from set
                        edit_index_set.remove(old_index_key)
                    if i not in invalid_set:
                        # move data to invalid data
                        for c, other_key in enumerate(other_index):
                            old_key = key_getter(old_index_key + other_key + self.frozen_value)
                            if c not in other_invalid_set and old_key in self._data:
                                row_col_index = order_getter((i, c))
                                self._invalid_data[row_col_index] = self._data[old_key]
                                self._delete_data(old_key)
                invalid_set.add(i)
        # add new values
        for i, new_key in zip(index_mask, new_index):
            if i < len(edit_index):
                edit_index[i] = new_key
            else:
                edit_index.append(new_key)
        # update header arrays
        if direction == "row":
            self._row_data_header = edit_index
            self._row_data_header_set = edit_index_set
            self._invalid_row = invalid_set
        else:
            self._column_data_header = edit_index
            self._column_data_header_set = edit_index_set
            self._invalid_column = invalid_set
        return {}

    def is_valid_index(self, index, index_name):
        """Checks if given index value is a valid value for given real index name."""
        # NOTE(review): 'not index' also rejects 0 for int indexes; confirm this is intended
        if not index:
            # index value cannot be empty/None
            return False
        if not isinstance(index, self._index_type[index_name]):
            return False
        if index_name in self._valid_index_values and self._valid_index_values[index_name]:
            # there is a restricted set of valid values for index
            if index not in self._valid_index_values[index_name]:
                return False
        return True

    def is_valid_key(self, key, existing_keys, key_names):
        """Checks if given key (combination of index values for key_names) is valid and not a duplicate."""
        real_names = [self._unique_name_2_name[name] for name in key_names]
        if not all(index in self.index_entries[index_name] for index, index_name in zip(key, real_names)):
            return False
        if key in existing_keys:
            # key cannot be a duplicate of existing keys in index.
            return False
        return True
class PivotTableModel(QAbstractTableModel):
    """Qt table model that wraps a PivotModel instance stored in ``self.model``."""

    # NOTE(review): appears to carry (added, deleted) index entry dicts, see delete_index_values; confirm
    index_entries_changed = Signal(dict, dict)

    def __init__(self, parent=None):
        """Create an empty table model.

        Args:
            parent: optional Qt parent, forwarded to QAbstractTableModel.
        """
        super().__init__(parent)
        self.model = PivotModel()
        # header bookkeeping, presumably filled by _update_header_data (defined elsewhere)
        self._num_headers_row = 0
        self._num_headers_column = 0
        self._data_header = [[]]
        # reset to None whenever new data is loaded
        self._plot_x_column = None
[docs] def set_data(
self,
data,
index_names,
index_type,
rows=(),
columns=(),
frozen=(),
frozen_value=(),
index_entries=None,
valid_index_values=None,
tuple_index_entries=None,
used_index_values=None,
index_real_names=None,
):
if index_entries is None:
index_entries = dict()
if valid_index_values is None:
valid_index_values = dict()
if tuple_index_entries is None:
tuple_index_entries = dict()
if used_index_values is None:
used_index_values = dict()
if index_real_names is None:
index_real_names = list()
self.beginResetModel()
self.model.set_new_data(
data,
index_names,
index_type,
rows,
columns,
frozen,
frozen_value,
index_entries,
valid_index_values,
tuple_index_entries,
used_index_values,
index_real_names,
)
self._plot_x_column = None
self._update_header_data()
self.endResetModel()
[docs] def set_pivot(self, rows, columns, frozen, frozen_value):
self.beginResetModel()
self.model.set_pivot(rows, columns, frozen, frozen_value)
self._update_header_data()
self.endResetModel()
[docs] def set_frozen_value(self, frozen_value):
self.beginResetModel()
self.model.set_frozen_value(frozen_value)
self._update_header_data()
self.endResetModel()
[docs] def delete_values(self, indexes):
# transform to PivotModel index
indexes = self._indexes_to_pivot_index(indexes)
self.beginResetModel()
self.model.delete_pivoted_values(indexes)
self.endResetModel()
[docs] def delete_index_values(self, keys_dict):
add_index = {k: len(v) for k, v in self.model._added_index_entries.items()}
del_index = {k: len(v) for k, v in self.model._deleted_index_entries.items()}
self.beginResetModel()
self.model.delete_index_values(keys_dict)
self.endResetModel()
new_indexes = {}
deleted_indexes = {}
for k, v in self.model._added_index_entries.items():
if k in add_index and not len(v) == add_index[k]:
new_indexes[k] = set(v)
for k, v in self.model._deleted_index_entries.items():
if k in add_index and not len(v) == del_index[k]:
deleted_indexes[k] = set(v)
if new_indexes or deleted_indexes:
self.index_entries_changed.emit(new_indexes, deleted_indexes)
[docs] def delete_tuple_index_values(self, tuple_key_dict):
self.beginResetModel()
self.model.delete_tuple_index_values(tuple_key_dict)
self.endResetModel()
[docs] def restore_values(self, indexes):
indexes = self._indexes_to_pivot_index(indexes)
self.beginResetModel()
self.model.restore_pivoted_values(indexes)
self.endResetModel()
[docs] def get_key(self, index):
row = self.model.row(max(0, index.row() - self._num_headers_row))
col = self.model.column(max(0, index.column() - self._num_headers_column))
return self.model._key_getter(row + col + self.model.frozen_value)
[docs] def get_col_key(self, column):
return self.model.column(max(0, column - self._num_headers_column))
[docs] def paste_data(self, index, data, row_mask, col_mask):
"""paste data into pivot model"""
row_header_data = []
col_header_data = [[]]
skip_cols = max(0, self._num_headers_column - index.column())
skip_rows = max(0, self._num_headers_row - index.row())
if self.model.pivot_columns and index.row() < self._num_headers_row:
# extract data for column headers
if not self.model.pivot_rows or not index.row() == self._num_headers_row - 1:
col_header_data = [line[skip_cols:] for line in data[:skip_rows]]
if self.model.pivot_rows and index.column() < self._num_headers_column:
# extract data for row headers
row_header_data = [data[r][:skip_cols] for r in range(skip_rows, len(data))]
# extract data for pasting in values
value_data = [line[skip_cols:] for line in data[skip_rows:]]
if not value_data:
value_data = [[]]
# translate mask into pivot index
row_mask = [r - self._num_headers_row for r in row_mask if r >= self._num_headers_row]
col_mask = [c - self._num_headers_column for c in col_mask if c >= self._num_headers_column]
new_rows = max(len(value_data), len(row_header_data)) - len(row_mask)
new_cols = max(len(value_data[0]), len(col_header_data[0])) - len(col_mask)
# extend mask if new values are given
if new_rows > 0:
row_mask.extend(list(range(len(self.model.rows), len(self.model.rows) + new_rows)))
if new_cols > 0:
col_mask.extend(list(range(len(self.model.columns), len(self.model.columns) + new_cols)))
add_index = {k: len(v) for k, v in self.model._added_index_entries.items()}
del_index = {k: len(v) for k, v in self.model._deleted_index_entries.items()}
self.beginResetModel()
self.model.paste_data(
index.column(), row_header_data, index.row(), col_header_data, value_data, row_mask, col_mask
)
self.endResetModel()
new_indexes = {}
deleted_indexes = {}
for k, v in self.model._added_index_entries.items():
if k in add_index and not len(v) == add_index[k]:
new_indexes[k] = set(v)
for k, v in self.model._deleted_index_entries.items():
if k in add_index and not len(v) == del_index[k]:
deleted_indexes[k] = set(v)
if new_indexes or deleted_indexes:
self.index_entries_changed.emit(new_indexes, deleted_indexes)
[docs] def _indexes_to_pivot_index(self, indexes):
max_row = len(self.model.rows)
max_col = len(self.model.columns)
if not self.model.pivot_rows:
max_row = 1
if not self.model.pivot_columns:
max_col = 1
indexes = [
(i.row() - self._num_headers_row, i.column() - self._num_headers_column)
for i in indexes
if (i.row() >= self._num_headers_row and i.row() - self._num_headers_row < max_row)
and (i.column() >= self._num_headers_column and i.column() - self._num_headers_column < max_col)
]
return indexes
[docs] def first_data_row(self):
"""Returns the row index to the first data row."""
# Last row is an empty row, exclude it.
return self.rowCount() - self.dataRowCount() - 1
[docs] def dataRowCount(self):
"""number of rows that contains actual data"""
return len(self.model.rows)
[docs] def dataColumnCount(self):
"""number of columns that contains actual data"""
return len(self.model.columns)
[docs] def rowCount(self, parent=QModelIndex()):
"""Number of rows in table, number of header rows + datarows + 1 empty row"""
return self._num_headers_row + self.dataRowCount() + 1
[docs] def columnCount(self, parent=QModelIndex()):
"""Number of columns in table, number of header columns + datacolumns + 1 empty columns"""
return self._num_headers_column + self.dataColumnCount() + 1
[docs] def flags(self, index):
"""Roles for data"""
if index.row() < self._num_headers_row and index.column() < self._num_headers_column:
return super(PivotTableModel, self).flags(index)
if (
self.model.pivot_rows
and self.model.pivot_columns
and index.row() == self._num_headers_row - 1
and index.column() >= self._num_headers_column
):
# empty line between column headers and data
return Qt.ItemIsSelectable | Qt.ItemIsEnabled
return Qt.ItemIsEditable | Qt.ItemIsEnabled | Qt.ItemIsSelectable
[docs] def index_in_top_left(self, index):
"""check if index is in top left corner, where pivot names are displayed"""
return index.row() < self._num_headers_row and index.column() < self._num_headers_column
[docs] def index_in_data(self, index):
"""check if index is in data area"""
if (
self.dataRowCount() == 0
and self.model.pivot_rows
or self.dataColumnCount() == 0
and self.model.pivot_columns
):
# no data
return False
return (
index.row() >= self._num_headers_row
and index.column() >= self._num_headers_column
and index.row() < self._num_headers_row + max(1, self.dataRowCount())
and index.column() < self._num_headers_column + max(1, self.dataColumnCount())
)
[docs] def index_in_column_headers(self, index):
"""check if index is in column headers (horizontal) area"""
return (
index.row() < self._num_headers_row
and index.column() >= self._num_headers_column
and index.column() < self.columnCount() - 1
)
)
[docs] def set_plot_x_column(self, column, is_x):
"""Sets or clears the Y flag on a column"""
if is_x:
self._plot_x_column = column
elif column == self._plot_x_column:
self._plot_x_column = None
self.headerDataChanged.emit(Qt.Horizontal, column, column)
@property
[docs] def plot_x_column(self):
"""Returns the index of the column designated as Y values for plotting or None."""
return self._plot_x_column
[docs] def set_index_key(self, index, value, direction):
"""edits/sets a index value in a index in row/column"""
# TODO: change this to insertRow/Column instead when creating new rows
self.beginResetModel()
if not value or value.isspace():
# empty do nothing
return False
if direction == "column":
header_ind = index.row()
index_ind = index.column() - self._num_headers_column
index_name = self.model.pivot_columns[header_ind]
if len(self.model.columns) <= index_ind:
# edited index outside, add new column
old_key = [None for _ in range(len(self.model.pivot_columns))]
else:
old_key = self.model.column(index_ind)
elif direction == "row":
header_ind = index.column()
index_ind = index.row() - self._num_headers_row
index_name = self.model.pivot_rows[header_ind]
if len(self.model.rows) <= index_ind:
# edited index outside, add new column
old_key = [None for _ in range(len(self.model.pivot_rows))]
else:
old_key = self.model.row(index_ind)
else:
raise ValueError('parameter direction must be "row" or "column"')
# check if value should be int
if index_name in self.model._index_type and self.model._index_type[index_name] == int and value.isdigit():
value = int(value)
# update value
new_key = list(old_key)
new_key[header_ind] = value
new_key = tuple(new_key)
# change index values
add_index = {k: len(v) for k, v in self.model._added_index_entries.items()}
del_index = {k: len(v) for k, v in self.model._deleted_index_entries.items()}
self.model.edit_index([new_key], [index_ind], direction)
self.endResetModel()
self.dataChanged.emit(index, index)
# self.update_index_entries(new_key_entries)
# check if any index has been updated
new_indexes = {}
deleted_indexes = {}
for k, v in self.model._added_index_entries.items():
if k in add_index and not len(v) == add_index[k]:
new_indexes[k] = set(v)
for k, v in self.model._deleted_index_entries.items():
if k in add_index and not len(v) == del_index[k]:
deleted_indexes[k] = set(v)
if new_indexes or deleted_indexes:
self.index_entries_changed.emit(new_indexes, deleted_indexes)
return True
[docs] def setData(self, index, value, role=Qt.EditRole):
if role == Qt.EditRole:
if self.index_in_data(index):
# edit existing data
self.model.set_pivoted_data(
[[value]], [index.row() - self._num_headers_row], [index.column() - self._num_headers_column]
)
return True
if index.row() == self.rowCount() - 1 and index.column() < self._num_headers_column:
# add new row if there are any indexes on the row
if self.model.pivot_rows:
return self.set_index_key(index, value, "row")
elif index.column() == self.columnCount() - 1 and index.row() < self._num_headers_row:
# add new column if there are any columns on the pivot
if self.model.pivot_columns:
return self.set_index_key(index, value, "column")
elif (
index.row() < self._num_headers_row - min(1, self.dataRowCount())
and index.column() >= self._num_headers_column
and index.column() < self.columnCount() - 1
):
# edit column key
return self.set_index_key(index, value, "column")
elif self.index_in_row_headers(index):
# edit row key
return self.set_index_key(index, value, "row")
return False
[docs] def data(self, index, role=Qt.DisplayRole):
if role in (Qt.DisplayRole, Qt.EditRole):
if self.index_in_data(index):
# get values
data = self.model.get_pivoted_data(
[index.row() - self._num_headers_row], [index.column() - self._num_headers_column]
)
if not data or data[0][0] is None:
return ''
data = data[0][0]
if role == Qt.EditRole:
return format_for_EditRole(data)
return format_for_DisplayRole(data)
if self.index_in_column_headers(index):
# draw column header values
if not self.model.pivot_rows:
# when special case when no pivot_index, no empty line padding
return self.model._column_data_header[index.column() - self._num_headers_column][index.row()]
if index.row() < self._num_headers_row - 1:
return self.model._column_data_header[index.column() - self._num_headers_column][index.row()]
elif self.index_in_row_headers(index):
# draw index values
return self.model._row_data_header[index.row() - self._num_headers_row][index.column()]
elif index.row() < self._num_headers_row and index.column() < self._num_headers_column:
# draw header values
return self._data_header[index.row()][index.column()]
else:
return None
elif role == Qt.FontRole:
if self.index_in_top_left(index):
font = QFont()
font.setBold(True)
return font
elif role == Qt.BackgroundColorRole:
return self.data_color(index)
elif role == Qt.ToolTipRole:
if self.index_in_data(index):
data = self.model.get_pivoted_data(
[index.row() - self._num_headers_row], [index.column() - self._num_headers_column]
)
if not data or data[0][0] is None:
return None
data = data[0][0]
return format_for_ToolTipRole(data)
else:
return None
[docs] def data_color(self, index):
if self.index_in_data(index):
# color edited values
r = index.row() - self._num_headers_row
c = index.column() - self._num_headers_column
if r in self.model._invalid_row or c in self.model._invalid_column:
# invalid data, color grey
return QColor(Qt.lightGray)
row = self.model.row(index.row() - self._num_headers_row)
col = self.model.column(index.column() - self._num_headers_column)
key = self.model._key_getter(row + col + self.model.frozen_value)
if key in self.model._deleted_data:
# deleted data, color red
return QColor(Qt.red)
if key in self.model._edit_data:
if self.model._edit_data[key] is None:
# new data color green
return QColor(Qt.green)
# edited data color yellow
return QColor(Qt.yellow)
elif self.index_in_column_headers(index):
# color new indexes or invalid indexes "columns"
if index.row() >= len(self.model.pivot_columns):
return
index_name = self.model._unique_name_2_name[self.model.pivot_columns[index.row()]]
key = self.model.column(index.column() - self._num_headers_column)
index_entry = key[index.row()]
if index.column() - self._num_headers_column in self.model._invalid_column and (
not index_entry in self.model.index_entries[index_name] or key in self.model._column_data_header_set
):
# color invalid columns
return QColor(Qt.red)
if index_entry in self.model._added_index_entries[index_name]:
# color added indexes
return QColor(Qt.green)
elif self.index_in_row_headers(index):
# color new indexes or invalid indexes "rows"
index_name = self.model._unique_name_2_name[self.model.pivot_rows[index.column()]]
key = self.model.row(index.row() - self._num_headers_row)
index_entry = key[index.column()]
if index.row() - self._num_headers_row in self.model._invalid_row and (
not index_entry in self.model.index_entries[index_name] or key in self.model._row_data_header_set
):
# invalid index or duplicate key
return QColor(Qt.red)
if index_entry in self.model._added_index_entries[index_name]:
# color added indexes
return QColor(Qt.green)
class PivotTableSortFilterProxy(QSortFilterProxyModel):
    """Proxy model that hides pivot table rows/columns whose index values are filtered out.

    Header rows/columns, the trailing empty row/column and invalid rows/columns are always shown.
    """

    def __init__(self, parent=None):
        """Initialize class."""
        super().__init__(parent)
        self.setDynamicSortFilter(False)  # Important so we can edit parameters in the view
        # index name -> collection of accepted values; a falsy value means "no filter"
        self.index_filters = {}

    def set_filter(self, index_name, filter_value):
        """Set the accepted values of an index and re-filter."""
        self.index_filters[index_name] = filter_value
        self.invalidateFilter()  # trigger filter update

    def clear_filter(self):
        """Remove all filters and re-filter."""
        self.index_filters = {}
        self.invalidateFilter()  # trigger filter update

    def accept_index(self, index, index_names):
        """Return True if every value in index passes the filter of the matching index name."""
        return all(
            not self.index_filters.get(name) or value in self.index_filters[name]
            for value, name in zip(index, index_names)
        )

    def delete_values(self, delete_indexes):
        """Delete the values at the given proxy indexes."""
        delete_indexes = [self.mapToSource(index) for index in delete_indexes]
        self.sourceModel().delete_values(delete_indexes)

    def restore_values(self, indexes):
        """Restore the values at the given proxy indexes."""
        indexes = [self.mapToSource(index) for index in indexes]
        self.sourceModel().restore_values(indexes)

    @staticmethod
    def _accepted_sections(start, stop, limit, accepts):
        """Return at most limit sections in range(start, stop) for which accepts(section, None) is true."""
        accepted = []
        if limit <= 0:
            return accepted
        for section in range(start, stop):
            if accepts(section, None):
                accepted.append(section)
                if len(accepted) == limit:
                    break
        return accepted

    def paste_data(self, index, data):
        """Paste a 2D block of values at a proxy index, skipping rows/columns hidden by the filter.

        Args:
            index (QModelIndex): proxy index of the top-left cell of the paste
            data (list of list): pasted values, row by row; nothing happens if empty
        """
        if not data:
            return
        source = self.sourceModel()
        model_index = self.mapToSource(index)
        # TODO: this might be cached somewhere?
        row_mask = self._accepted_sections(
            model_index.row(), source.dataRowCount() + source._num_headers_row, len(data), self.filterAcceptsRow
        )
        col_mask = self._accepted_sections(
            model_index.column(),
            source.dataColumnCount() + source._num_headers_column,
            len(data[0]),
            self.filterAcceptsColumn,
        )
        source.paste_data(model_index, data, row_mask, col_mask)

    def filterAcceptsRow(self, source_row, source_parent):
        """Returns true if the item in the row indicated by the given source_row
        and source_parent should be included in the model; otherwise returns false.
        All the rules and subrules need to pass.
        """
        source = self.sourceModel()
        if source_row < source._num_headers_row or source_row == source.rowCount() - 1:
            # always display headers and the trailing empty row
            return True
        # _invalid_row holds pivot-relative positions, not table rows
        pivot_row = source_row - source._num_headers_row
        if pivot_row in source.model._invalid_row:
            # always display invalid rows so they can be fixed
            return True
        if source.model.pivot_rows:
            return self.accept_index(source.model._row_data_header[pivot_row], source.model.pivot_rows)
        return True

    def filterAcceptsColumn(self, source_column, source_parent):
        """Returns true if the item in the column indicated by the given source_column
        and source_parent should be included in the model; otherwise returns false.
        """
        source = self.sourceModel()
        if source_column < source._num_headers_column or source_column == source.columnCount() - 1:
            # always display headers and the trailing empty column
            return True
        # _invalid_column holds pivot-relative positions, not table columns
        pivot_column = source_column - source._num_headers_column
        if pivot_column in source.model._invalid_column:
            # always display invalid columns so they can be fixed
            return True
        if source.model.pivot_columns:
            return self.accept_index(source.model._column_data_header[pivot_column], source.model.pivot_columns)
        return True
class FilterCheckboxListModel(QAbstractListModel):
    """Checkable list of filter values with action rows on top.

    Without a text filter, the action rows are '(Select All)' and, if show_empty is
    set, '(Empty)'. With a text filter active, the action rows are
    '(Select all filtered)' and 'Add current selection to filter', and only the
    matching items are listed below them.
    """

    def __init__(self, parent=None, show_empty=True):
        """Initialize class.

        Args:
            parent (QObject): parent object
            show_empty (bool): whether to show the '(Empty)' row when no text filter is active
        """
        super().__init__(parent)
        self._data = []  # items in sorted order
        self._data_set = set()  # the same items, for membership tests
        self._all_selected = True
        self._empty_selected = True
        self._selected = set()
        self._selected_filtered = set()  # selection among matching items while a filter is active
        self._list_filter = None  # active search string, None when not filtered
        self._is_filtered = False
        self._filter_index = []  # positions in _data of the items matching the filter
        self._select_all_str = '(Select All)'
        self._show_empty = show_empty
        self._empty_str = '(Empty)'
        self._add_to_selection_str = 'Add current selection to filter'
        self._add_to_selection = False
        # number of action rows above the items when no filter is active
        self._index_offset = 2 if self._show_empty else 1

    def _row_offset(self):
        """Number of action rows currently shown above the item rows."""
        # a filtered list always shows '(Select all filtered)' and 'Add current selection to filter'
        return 2 if self._is_filtered else self._index_offset

    def _emit_check_states_changed(self):
        """Notify views that the check state of any row may have changed."""
        last_row = self.rowCount() - 1
        if last_row >= 0:
            self.dataChanged.emit(self.index(0, 0), self.index(last_row, 0), [Qt.CheckStateRole])

    def reset_selection(self):
        """Select all items and the empty value."""
        self._selected = set(self._data_set)
        self._all_selected = True
        self._empty_selected = True
        self._emit_check_states_changed()

    def _select_all_clicked(self):
        """Toggle between selecting everything and nothing (only the matching items while filtered)."""
        if self._all_selected:
            if self._is_filtered:
                self._selected_filtered = set()
            else:
                self._selected = set()
            self._empty_selected = False
        else:
            if self._is_filtered:
                self._selected_filtered = {self._data[i] for i in self._filter_index}
            else:
                self._selected = set(self._data_set)
            self._empty_selected = True
        self._all_selected = not self._all_selected
        self._emit_check_states_changed()

    def _is_all_selected(self):
        """Return True if every listed item (and the empty value when not filtered) is selected."""
        if self._is_filtered:
            return len(self._selected_filtered) == len(self._filter_index)
        return len(self._selected) == len(self._data_set) and self._empty_selected

    def rowCount(self, parent=QModelIndex()):
        """Number of rows: action rows + listed items; 0 for a valid parent or an empty filter result."""
        if parent.isValid():
            # flat list, items have no children
            return 0
        if self._is_filtered:
            if not self._filter_index:
                # no filtered values
                return 0
            return len(self._filter_index) + self._row_offset()
        return len(self._data) + self._index_offset

    def data(self, index, role=Qt.DisplayRole):
        """Return the label (DisplayRole) or check state as bool (CheckStateRole) of a row."""
        if not index.isValid():
            return None
        row = index.row()
        if self._is_filtered:
            action_rows = [self._select_all_str, self._add_to_selection_str]
            action_state = [self._all_selected, self._add_to_selection]
            selected = self._selected_filtered
        else:
            action_rows = [self._select_all_str]
            action_state = [self._all_selected]
            if self._show_empty:
                action_rows.append(self._empty_str)
                action_state.append(self._empty_selected)
            selected = self._selected
        if row < len(action_rows):
            if role == Qt.DisplayRole:
                return action_rows[row]
            if role == Qt.CheckStateRole:
                return action_state[row]
            return None
        item_pos = row - self._row_offset()
        if self._is_filtered:
            item_pos = self._filter_index[item_pos]
        item = self._data[item_pos]
        if role == Qt.DisplayRole:
            return item
        if role == Qt.CheckStateRole:
            return item in selected
        return None

    def click_index(self, index):
        """Toggle the check state of the clicked row."""
        row = index.row()
        if row == 0:
            self._select_all_clicked()
        elif row == 1 and self._is_filtered:
            self._add_to_selection = not self._add_to_selection
        elif row == 1 and self._show_empty:
            self._empty_selected = not self._empty_selected
        else:
            item_pos = row - self._row_offset()
            if self._is_filtered:
                selection = self._selected_filtered
                item = self._data[self._filter_index[item_pos]]
            else:
                selection = self._selected
                item = self._data[item_pos]
            if item in selection:
                selection.discard(item)
            else:
                selection.add(item)
        self._all_selected = self._is_all_selected()
        self.dataChanged.emit(index, index, [Qt.CheckStateRole])
        # the "select all" row may have changed as well
        top = self.index(0, 0)
        self.dataChanged.emit(top, top, [Qt.CheckStateRole])

    def _clear_filter_state(self):
        """Drop the text filter state without notifying views."""
        self._select_all_str = '(Select All)'
        self._list_filter = None
        self._is_filtered = False
        self._filter_index = []
        self._selected_filtered = set()

    def set_list(self, data, all_selected=True):
        """Replace the listed items (duplicates are dropped) and remove any text filter.

        Args:
            data (iterable): new items; any iterable, including a one-shot iterator
            all_selected (bool): whether all items and the empty value start out selected
        """
        self.beginResetModel()
        self._data_set = set(data)
        self._data = sorted(self._data_set)
        self._selected = set(self._data_set) if all_selected else set()
        self._all_selected = bool(all_selected)
        self._empty_selected = bool(all_selected)
        self._clear_filter_state()
        self.endResetModel()

    def _insert_item(self, pos, item, selected):
        """Insert item at position pos of the sorted item list and update filter/selection bookkeeping."""
        matches_filter = False
        if self._is_filtered:
            filter_pos = bisect.bisect_left(self._filter_index, pos)
            # items at or after pos move one step down in _data
            for i in range(filter_pos, len(self._filter_index)):
                self._filter_index[i] += 1
            # same matching rule as set_filter
            matches_filter = self._list_filter in str(item)
            if matches_filter:
                self._filter_index.insert(filter_pos, pos)
        self._data.insert(pos, item)
        self._data_set.add(item)
        if selected:
            self._selected.add(item)
            if matches_filter:
                self._selected_filtered.add(item)
        self._all_selected = self._is_all_selected()

    def add_item(self, items, selected=True):
        """Insert the items not yet listed at their sorted positions.

        Args:
            items (iterable): items to add; items already in the list are ignored
            selected (bool): whether new items start out selected
        """
        for item in items:
            if item in self._data_set:
                continue
            pos = bisect.bisect_left(self._data, item)
            if self._is_filtered:
                # whether and where the item appears depends on the filter; a reset keeps views consistent
                self.beginResetModel()
                self._insert_item(pos, item, selected)
                self.endResetModel()
            else:
                row = pos + self._index_offset
                self.beginInsertRows(QModelIndex(), row, row)
                self._insert_item(pos, item, selected)
                self.endInsertRows()

    def set_selected(self, selected, select_empty=None):
        """Select exactly the given items that are in the list; optionally set the empty value's state."""
        self.beginResetModel()
        self._selected = self._data_set.intersection(selected)
        if select_empty is not None:
            self._empty_selected = select_empty
        self._all_selected = self._is_all_selected()
        self.endResetModel()

    def get_selected(self):
        """Return a copy of the selected items."""
        return set(self._selected)

    def get_not_selected(self):
        """Return the unselected items (empty if everything is selected)."""
        if self._all_selected:
            return set()
        return self._data_set.difference(self._selected)

    def set_filter(self, search_for):
        """List only items whose str() contains search_for; an empty/whitespace/non-str value removes the filter."""
        if not (isinstance(search_for, str) and search_for and not search_for.isspace()):
            self.remove_filter()
            return
        self.beginResetModel()
        self._select_all_str = '(Select all filtered)'
        self._list_filter = search_for
        self._filter_index = [i for i, item in enumerate(self._data) if search_for in str(item)]
        self._selected_filtered = {self._data[i] for i in self._filter_index}
        self._add_to_selection = False
        self._is_filtered = True
        self._all_selected = True
        self.endResetModel()

    def apply_filter(self):
        """Commit the filtered selection (added to or replacing the selection) and remove the filter."""
        if not self._is_filtered:
            return
        if self._add_to_selection:
            self._remove_and_add_filtered()
        else:
            self._remove_and_replace_filtered()

    def _remove_and_add_filtered(self):
        """Merge the filtered selection into the overall selection and remove the filter."""
        if not self._selected:
            # no previous selection, just replace
            self._selected = set(self._selected_filtered)
        else:
            # add selected
            self._selected.update(self._selected_filtered)
            # remove filtered items that were unchecked
            self._selected.difference_update(
                {self._data[i] for i in self._filter_index if self._data[i] not in self._selected_filtered}
            )
        self.remove_filter()

    def _remove_and_replace_filtered(self):
        """Replace the selection with the filtered selection, deselect the empty value and remove the filter."""
        self._selected = set(self._selected_filtered)
        self._empty_selected = False
        self.remove_filter()

    def remove_filter(self):
        """Remove the text filter, if any, and reset attached views."""
        if not self._is_filtered:
            return
        self.beginResetModel()
        self._clear_filter_state()
        self._all_selected = self._is_all_selected()
        self.endResetModel()

    def remove_items(self, items):
        """Remove the given items from the list and from all selections."""
        items = set(items)
        self.beginResetModel()
        if self._is_filtered:
            self._selected_filtered.difference_update(items)
            kept_filtered = {self._data[i] for i in self._filter_index} - items
        self._data_set.difference_update(items)
        self._data = [d for d in self._data if d not in items]
        if self._is_filtered:
            # recompute positions: every removed item, filtered or not, shifts later items
            self._filter_index = [i for i, d in enumerate(self._data) if d in kept_filtered]
        self._selected.difference_update(items)
        self._all_selected = self._is_all_selected()
        self.endResetModel()