Skip to content

Adding SetField #2337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mongoengine/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# datastructures
"BaseDict",
"BaseList",
"BaseSet",
"EmbeddedDocumentList",
"LazyReference",
# document
Expand Down
51 changes: 51 additions & 0 deletions mongoengine/base/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"BaseDict",
"StrictDict",
"BaseList",
"BaseSet",
"EmbeddedDocumentList",
"LazyReference",
)
Expand Down Expand Up @@ -187,6 +188,56 @@ def _mark_as_changed(self, key=None):
self._instance._mark_as_changed(self._name)


class BaseSet(set):
"""A special set so we can watch any changes."""

_dereferenced = False
_instance = None
_name = None

def __init__(self, set_items, instance, name):
BaseDocument = _import_class("BaseDocument")

if isinstance(instance, BaseDocument):
self._instance = weakref.proxy(instance)
self._name = name
super().__init__(set_items)

def __getstate__(self):
self.instance = None
self._dereferenced = False
return self

def __setstate__(self, state):
self = state
return self

update = mark_as_changed_wrapper(set.update)
intersection_update = mark_as_changed_wrapper(set.intersection_update)
difference_update = mark_as_changed_wrapper(set.difference_update)
symmetric_difference_update = mark_as_changed_wrapper(
set.symmetric_difference_update
)
add = mark_as_changed_wrapper(set.add)
remove = mark_as_changed_wrapper(set.remove)
discard = mark_as_changed_wrapper(set.discard)
pop = mark_as_changed_wrapper(set.pop)
clear = mark_as_changed_wrapper(set.clear)
__ior__ = mark_as_changed_wrapper(set.__ior__)
__iand__ = mark_as_changed_wrapper(set.__iand__)
__isub__ = mark_as_changed_wrapper(set.__isub__)
__ixor__ = mark_as_changed_wrapper(set.__ixor__)

def _mark_as_changed(self, key=None):
if hasattr(self._instance, "_mark_as_changed"):
if key:
self._instance._mark_as_changed(
"{}.{}".format(self._name, key % len(self))
)
else:
self._instance._mark_as_changed(self._name)


class EmbeddedDocumentList(BaseList):
def __init__(self, list_items, instance, name):
super().__init__(list_items, instance, name)
Expand Down
5 changes: 4 additions & 1 deletion mongoengine/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mongoengine.base.datastructures import (
BaseDict,
BaseList,
BaseSet,
EmbeddedDocumentList,
LazyReference,
StrictDict,
Expand Down Expand Up @@ -477,7 +478,7 @@ def from_json(cls, json_data, created=False):

def __expand_dynamic_values(self, name, value):
"""Expand any dynamic values to their correct types / values."""
if not isinstance(value, (dict, list, tuple)):
if not isinstance(value, (dict, list, tuple, set)):
return value

# If the value is a dict with '_cls' in it, turn it into a document
Expand All @@ -498,6 +499,8 @@ def __expand_dynamic_values(self, name, value):
value = EmbeddedDocumentList(value, self, name)
else:
value = BaseList(value, self, name)
elif isinstance(value, set) and not isinstance(value, BaseSet):
value = BaseSet(value, self, name)
elif isinstance(value, dict) and not isinstance(value, BaseDict):
value = BaseDict(value, self, name)

Expand Down
12 changes: 10 additions & 2 deletions mongoengine/base/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import pymongo

from mongoengine.base.common import UPDATE_OPERATORS
from mongoengine.base.datastructures import BaseDict, BaseList, EmbeddedDocumentList
from mongoengine.base.datastructures import (
BaseDict,
BaseList,
BaseSet,
EmbeddedDocumentList,
)
from mongoengine.common import _import_class
from mongoengine.errors import DeprecatedError, ValidationError

Expand Down Expand Up @@ -316,14 +321,17 @@ def __get__(self, instance, owner):
elif not isinstance(value, BaseList):
value = BaseList(value, instance, self.name)
instance._data[self.name] = value
elif isinstance(value, set) and not isinstance(value, BaseSet):
value = BaseSet(value, instance, self.name)
instance._data[self.name] = value
elif isinstance(value, dict) and not isinstance(value, BaseDict):
value = BaseDict(value, instance, self.name)
instance._data[self.name] = value

if (
auto_dereference
and instance._initialised
and isinstance(value, (BaseList, BaseDict))
and isinstance(value, (BaseList, BaseSet, BaseDict))
and not value._dereferenced
):
value = _dereference(value, max_depth=1, instance=instance, name=self.name)
Expand Down
36 changes: 27 additions & 9 deletions mongoengine/dereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from mongoengine.base import (
BaseDict,
BaseList,
BaseSet,
EmbeddedDocumentList,
TopLevelDocumentMetaclass,
get_document,
Expand Down Expand Up @@ -215,12 +216,14 @@ def _attach_objects(self, items, depth=0, instance=None, name=None):
:class:`~mongoengine.base.ComplexBaseField`
"""
if not items:
if isinstance(items, (BaseDict, BaseList)):
if isinstance(items, (BaseDict, BaseList, BaseSet)):
return items

if instance:
if isinstance(items, dict):
return BaseDict(items, instance, name)
elif isinstance(items, set):
return BaseSet(items, instance, name)
else:
return BaseList(items, instance, name)

Expand All @@ -238,27 +241,37 @@ def _attach_objects(self, items, depth=0, instance=None, name=None):
doc._data["_cls"] = _cls
return doc

if not hasattr(items, "items"):
is_list = True
SET = "set"
LIST = "list"
DICT = "dict"

if isinstance(items, set):
iterable_type = SET
iterator = enumerate(items)
data = set()
elif not hasattr(items, "items"):
iterable_type = LIST
list_type = BaseList
if isinstance(items, EmbeddedDocumentList):
list_type = EmbeddedDocumentList
as_tuple = isinstance(items, tuple)
iterator = enumerate(items)
data = []
else:
is_list = False
iterable_type = DICT
iterator = items.items()
data = {}

depth += 1
for k, v in iterator:
if is_list:
if iterable_type == SET:
data.add(v)
elif iterable_type == LIST:
data.append(v)
else:
data[k] = v

if k in self.object_map and not is_list:
if k in self.object_map and iterable_type == DICT:
data[k] = self.object_map[k]
elif isinstance(v, (Document, EmbeddedDocument)):
for field_name in v._fields:
Expand All @@ -271,12 +284,15 @@ def _attach_objects(self, items, depth=0, instance=None, name=None):
data[k]._data[field_name] = self.object_map.get(
(v["_ref"].collection, v["_ref"].id), v
)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
elif (
isinstance(v, (dict, list, tuple, set))
and depth <= self.max_depth
):
item_name = "{}.{}.{}".format(name, k, field_name)
data[k]._data[field_name] = self._attach_objects(
v, depth, instance=instance, name=item_name
)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
elif isinstance(v, (dict, list, tuple, set)) and depth <= self.max_depth:
item_name = "{}.{}".format(name, k) if name else name
data[k] = self._attach_objects(
v, depth - 1, instance=instance, name=item_name
Expand All @@ -285,7 +301,9 @@ def _attach_objects(self, items, depth=0, instance=None, name=None):
data[k] = self.object_map.get((v.collection, v.id), v)

if instance and name:
if is_list:
if iterable_type == SET:
return BaseSet(data, instance, name)
if iterable_type == LIST:
return tuple(data) if as_tuple else list_type(data, instance, name)
return BaseDict(data, instance, name)
depth += 1
Expand Down
4 changes: 4 additions & 0 deletions mongoengine/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
BaseDict,
BaseDocument,
BaseList,
BaseSet,
DocumentMetaclass,
EmbeddedDocumentList,
TopLevelDocumentMetaclass,
Expand Down Expand Up @@ -780,6 +781,9 @@ def _reload(self, key, value):
elif isinstance(value, BaseList):
value = [self._reload(key, v) for v in value]
value = BaseList(value, self, key)
elif isinstance(value, BaseSet):
value = {self._reload(key, v) for v in value}
value = BaseSet(value, self, key)
elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)):
value._instance = None
value._changed_fields = []
Expand Down
39 changes: 39 additions & 0 deletions mongoengine/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"ListField",
"SortedListField",
"EmbeddedDocumentListField",
"SetField",
"DictField",
"MapField",
"ReferenceField",
Expand Down Expand Up @@ -1020,6 +1021,44 @@ def to_mongo(self, value, use_db_field=True, fields=None):
return sorted(value, reverse=self._order_reverse)


class SetField(ListField):
""" A set field that inherits from list field but creates a set interface
on the python side of things.

The underlying MongoDB list is a sorted list of the set.
"""

def __init__(self, field=None, max_length=None, **kwargs):
kwargs.setdefault("default", lambda: set())
super().__init__(field=field, max_length=max_length, **kwargs)

def __set__(self, instance, value):
if isinstance(value, (list, tuple)):
value = set(value)
return super().__set__(instance, value)

def to_mongo(self, value, use_db_field=True, fields=None):
if isinstance(value, set):
value = sorted(list(value))
return super().to_mongo(value, use_db_field, fields)

def to_python(self, value):
value = super().to_python(value)
if isinstance(value, list):
value = set(value)
return value

def validate(self, value):
"""Make sure that a set of valid fields is being used."""
if not isinstance(value, (list, tuple, set, BaseQuerySet)):
self.error("Only lists, tuples and sets may be used in a set field")

if isinstance(value, set):
value = list(value)

super().validate(value)


def key_not_string(d):
"""Helper function to recursively determine if any key in a
dictionary is not a string.
Expand Down
Loading