class BaseSet(set):
    """A special set so we can watch any changes.

    Every in-place mutator of :class:`set` is wrapped with
    ``mark_as_changed_wrapper`` so that the owning document is told the field
    stored under ``_name`` was modified.
    """

    # Read by the field descriptor to decide whether the content still has to
    # be dereferenced.
    _dereferenced = False
    # Weak proxy to the owning document, or None while the set is unbound.
    _instance = None
    # Name of the document field this set is stored under.
    _name = None

    def __init__(self, set_items, instance, name):
        """Build the set and, when possible, bind it to its owning document.

        :param set_items: iterable providing the initial members.
        :param instance: owning document; it is only retained (through a weak
            proxy) when it is a ``BaseDocument``.
        :param name: field name reported to the owner on every change.
        """
        BaseDocument = _import_class("BaseDocument")

        if isinstance(instance, BaseDocument):
            # A weak proxy avoids a reference cycle document <-> container.
            self._instance = weakref.proxy(instance)
        self._name = name
        super().__init__(set_items)

    def __reduce__(self):
        """Support ``pickle`` and ``copy`` by rebuilding an unbound set.

        ``set.__reduce__`` reconstructs the object as ``cls(list_of_items)``,
        which cannot work here because ``__init__`` also requires ``instance``
        and ``name``.  The copy keeps its field name but is not attached to
        any document, so no weak proxy ever has to be serialised.
        """
        return (self.__class__, (set(self), None, self._name))

    def __getstate__(self):
        # NOTE(review): ``instance`` looks like a typo for ``_instance``.  It
        # is left untouched on purpose: this method returns the live object,
        # so clearing ``_instance`` here would detach the set from its
        # document as a side effect of serialising it.
        self.instance = None
        self._dereferenced = False
        return self

    def __setstate__(self, state):
        # Rebinding the local name has no effect on the instance; kept only so
        # the method keeps existing for callers that invoke it directly.
        self = state
        return self

    update = mark_as_changed_wrapper(set.update)
    intersection_update = mark_as_changed_wrapper(set.intersection_update)
    difference_update = mark_as_changed_wrapper(set.difference_update)
    symmetric_difference_update = mark_as_changed_wrapper(
        set.symmetric_difference_update
    )
    add = mark_as_changed_wrapper(set.add)
    remove = mark_as_changed_wrapper(set.remove)
    discard = mark_as_changed_wrapper(set.discard)
    pop = mark_as_changed_wrapper(set.pop)
    clear = mark_as_changed_wrapper(set.clear)
    __ior__ = mark_as_changed_wrapper(set.__ior__)
    __iand__ = mark_as_changed_wrapper(set.__iand__)
    __isub__ = mark_as_changed_wrapper(set.__isub__)
    __ixor__ = mark_as_changed_wrapper(set.__ixor__)

    def _mark_as_changed(self, key=None):
        """Flag the whole field as modified on the owning document.

        :param key: accepted for signature compatibility and ignored.  A set
            has no positional addressing, so there is no meaningful
            ``<name>.<index>`` path to report; the previous
            ``key % len(self)`` computation raised on an empty set or on a
            string key.
        """
        if hasattr(self._instance, "_mark_as_changed"):
            self._instance._mark_as_changed(self._name)
isinstance(value, BaseList): value = BaseList(value, instance, self.name) instance._data[self.name] = value + elif isinstance(value, set) and not isinstance(value, BaseSet): + value = BaseSet(value, instance, self.name) + instance._data[self.name] = value elif isinstance(value, dict) and not isinstance(value, BaseDict): value = BaseDict(value, instance, self.name) instance._data[self.name] = value @@ -323,7 +331,7 @@ def __get__(self, instance, owner): if ( auto_dereference and instance._initialised - and isinstance(value, (BaseList, BaseDict)) + and isinstance(value, (BaseList, BaseSet, BaseDict)) and not value._dereferenced ): value = _dereference(value, max_depth=1, instance=instance, name=self.name) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index ff608a3b3..d8938d7fb 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -3,6 +3,7 @@ from mongoengine.base import ( BaseDict, BaseList, + BaseSet, EmbeddedDocumentList, TopLevelDocumentMetaclass, get_document, @@ -215,12 +216,14 @@ def _attach_objects(self, items, depth=0, instance=None, name=None): :class:`~mongoengine.base.ComplexBaseField` """ if not items: - if isinstance(items, (BaseDict, BaseList)): + if isinstance(items, (BaseDict, BaseList, BaseSet)): return items if instance: if isinstance(items, dict): return BaseDict(items, instance, name) + elif isinstance(items, set): + return BaseSet(items, instance, name) else: return BaseList(items, instance, name) @@ -238,8 +241,16 @@ def _attach_objects(self, items, depth=0, instance=None, name=None): doc._data["_cls"] = _cls return doc - if not hasattr(items, "items"): - is_list = True + SET = "set" + LIST = "list" + DICT = "dict" + + if isinstance(items, set): + iterable_type = SET + iterator = enumerate(items) + data = set() + elif not hasattr(items, "items"): + iterable_type = LIST list_type = BaseList if isinstance(items, EmbeddedDocumentList): list_type = EmbeddedDocumentList @@ -247,18 +258,20 @@ def 
_attach_objects(self, items, depth=0, instance=None, name=None): iterator = enumerate(items) data = [] else: - is_list = False + iterable_type = DICT iterator = items.items() data = {} depth += 1 for k, v in iterator: - if is_list: + if iterable_type == SET: + data.add(v) + elif iterable_type == LIST: data.append(v) else: data[k] = v - if k in self.object_map and not is_list: + if k in self.object_map and iterable_type == DICT: data[k] = self.object_map[k] elif isinstance(v, (Document, EmbeddedDocument)): for field_name in v._fields: @@ -271,12 +284,15 @@ def _attach_objects(self, items, depth=0, instance=None, name=None): data[k]._data[field_name] = self.object_map.get( (v["_ref"].collection, v["_ref"].id), v ) - elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: + elif ( + isinstance(v, (dict, list, tuple, set)) + and depth <= self.max_depth + ): item_name = "{}.{}.{}".format(name, k, field_name) data[k]._data[field_name] = self._attach_objects( v, depth, instance=instance, name=item_name ) - elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: + elif isinstance(v, (dict, list, tuple, set)) and depth <= self.max_depth: item_name = "{}.{}".format(name, k) if name else name data[k] = self._attach_objects( v, depth - 1, instance=instance, name=item_name @@ -285,7 +301,9 @@ def _attach_objects(self, items, depth=0, instance=None, name=None): data[k] = self.object_map.get((v.collection, v.id), v) if instance and name: - if is_list: + if iterable_type == SET: + return BaseSet(data, instance, name) + if iterable_type == LIST: return tuple(data) if as_tuple else list_type(data, instance, name) return BaseDict(data, instance, name) depth += 1 diff --git a/mongoengine/document.py b/mongoengine/document.py index db64054a1..7fd6ae6fc 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -10,6 +10,7 @@ BaseDict, BaseDocument, BaseList, + BaseSet, DocumentMetaclass, EmbeddedDocumentList, TopLevelDocumentMetaclass, @@ -780,6 
class SetField(ListField):
    """A set field that inherits from list field but creates a set interface
    on the python side of things.

    The underlying MongoDB list is a sorted list of the set whenever its
    members can be ordered against each other; otherwise the members are
    stored in the set's iteration order.
    """

    def __init__(self, field=None, max_length=None, **kwargs):
        """Create the field; the default value is a fresh empty set.

        :param field: optional field describing each member.
        :param max_length: forwarded unchanged to ``ListField``.
        """
        kwargs.setdefault("default", lambda: set())
        super().__init__(field=field, max_length=max_length, **kwargs)

    def __set__(self, instance, value):
        """Store *value*, turning lists and tuples into a set first."""
        if isinstance(value, (list, tuple)):
            value = set(value)
        return super().__set__(instance, value)

    def to_mongo(self, value, use_db_field=True, fields=None):
        """Convert a set to a list before delegating to ``ListField``."""
        if isinstance(value, set):
            try:
                value = sorted(value)
            except TypeError:
                # Members such as 1 and "a" cannot be compared with each
                # other; storing them unsorted beats failing the whole save.
                value = list(value)
        return super().to_mongo(value, use_db_field, fields)

    def to_python(self, value):
        """Convert the stored list back into a set."""
        value = super().to_python(value)
        if isinstance(value, list):
            value = set(value)
        return value

    def validate(self, value):
        """Make sure that a set of valid fields is being used."""
        if not isinstance(value, (list, tuple, set, BaseQuerySet)):
            self.error("Only lists, tuples and sets may be used in a set field")

        if isinstance(value, set):
            # The parent validation is handed a list rather than a set.
            value = list(value)

        super().validate(value)
def get_from_db(doc):
    """Return a fresh copy of *doc* loaded from the database by its id."""
    document_cls = doc.__class__
    return document_cls.objects.get(id=doc.id)


class TestSetField(MongoDBTestCase):
    """Storage, validation and round-trip behaviour of ``SetField``."""

    def test_storage(self):
        """The value stays a BaseSet and is persisted as a sorted list."""

        class BlogPost(Document):
            info = SetField()

        BlogPost.drop_collection()
        members = {"testvalue1", "testvalue2"}
        post = BlogPost(info=members)
        assert isinstance(post.info, BaseSet)

        post.save()
        assert isinstance(post.info, BaseSet)

        post = get_from_db(post)
        assert isinstance(post.info, BaseSet)

        expected_raw = {"_id": post.id, "info": sorted(members)}
        assert get_as_pymongo(post) == expected_raw

    def test_validate_invalid_type(self):
        """Strings and dicts are rejected by validation."""

        class BlogPost(Document):
            info = SetField()

        BlogPost.drop_collection()

        for bad_value in ("my post", {1: "test"}):
            with pytest.raises(ValidationError):
                BlogPost(info=bad_value).validate()

    def test_general_things(self):
        """Ensure that set types work as expected."""

        class BlogPost(Document):
            info = SetField()

        BlogPost.drop_collection()

        # First document goes through the constructor...
        BlogPost(info=["test1", "test2"]).save()

        # ...the remaining three through attribute assignment.
        assigned_values = (
            {"test3"},
            ["test3", "test3", "test4"],
            {"test2", "test3", "test4"},
        )
        for assigned in assigned_values:
            post = BlogPost()
            post.info = assigned
            post.save()

        assert BlogPost.objects.count() == 4
        assert BlogPost.objects.filter(info="test3").count() == 3
        assert BlogPost.objects.filter(info__0="test1").count() == 1
        assert BlogPost.objects.filter(info__0="test2").count() == 1
        assert BlogPost.objects.filter(info__in=["test2", "test4"]).count() == 3

        # In-place mutation must be persisted by save().
        post = BlogPost.objects.create(info={"test5", "test6"})
        post.info.update({"updated"})
        post.save()
        post.reload()
        assert "updated" in post.info

    def test_list_and_tuples(self):
        """Ensure that sets can be created from lists and tuples."""

        class BlogPost(Document):
            info = SetField()

        BlogPost.drop_collection()

        def check_roundtrip(post):
            # The duplicates must be gone before and after every persistence step.
            assert post.info == {1, 2}
            post.save()
            assert post.info == {1, 2}
            post.reload()
            assert post.info == {1, 2}
            post = get_from_db(post)
            assert post.info == {1, 2}

        for raw in ([1, 2, 2], (1, 2, 2)):
            # Value passed to the constructor.
            check_roundtrip(BlogPost(info=raw))

            # Value assigned to the attribute afterwards.
            post = BlogPost()
            post.info = raw
            check_roundtrip(post)

    def test_set_field_field(self):
        """Ensure subfields are validated."""

        class BlogPost(Document):
            info = SetField(BooleanField())

        BlogPost.drop_collection()

        with pytest.raises(ValidationError):
            post = BlogPost()
            post.info = {"a", "b"}
            post.validate()
class TestBaseSet(unittest.TestCase):
    """Change-tracking behaviour of ``BaseSet``.

    Empty fixtures are written ``set()``: ``{}`` is an empty *dict* literal
    and only happened to work because iterating it yields nothing.
    """

    @staticmethod
    def _get_baseset(set_items):
        """Get a BaseSet bound to a fake document instance"""
        fake_doc = DocumentStub()
        base_set = BaseSet(set_items, instance=None, name="my_name")
        base_set._instance = (
            fake_doc  # hack to inject the mock, it does not work in the constructor
        )
        return base_set

    def test___init___(self):
        """The constructor keeps the owning document, the name and the items."""

        class MyDoc(Document):
            pass

        set_items = {True}
        doc = MyDoc()
        base_set = BaseSet(set_items, instance=doc, name="my_name")
        assert isinstance(base_set._instance, Document)
        assert base_set._name == "my_name"
        assert base_set == set_items

    def test___iter__(self):
        """Iterating yields exactly the members it was built from."""
        values = {0, 1, 2, 2}
        base_set = BaseSet(values, instance=None, name="my_name")
        assert values == set(base_set)

    def test_add_calls_mark_as_changed(self):
        base_set = self._get_baseset(set())
        assert not base_set._instance._changed_fields
        base_set.add(True)
        assert base_set._instance._changed_fields == ["my_name"]

    def test_subclass_add(self):
        # Due to the way mark_as_changed_wrapper is implemented
        # it is good to test subclasses
        class SubBaseSet(BaseSet):
            pass

        base_set = SubBaseSet(set(), instance=None, name="my_name")
        base_set.add(True)

    def test_update_calls_mark_as_changed(self):
        base_set = self._get_baseset(set())
        base_set.update({True})
        assert base_set._instance._changed_fields == ["my_name"]

    def test_intersection_update_calls_mark_as_changed(self):
        base_set = self._get_baseset({True, False})
        base_set.intersection_update({True})
        assert base_set._instance._changed_fields == ["my_name"]

    def test_difference_update_calls_mark_as_changed(self):
        base_set = self._get_baseset({True, False})
        base_set.difference_update({True})
        assert base_set._instance._changed_fields == ["my_name"]

    def test_symmetric_difference_update_calls_mark_as_changed(self):
        base_set = self._get_baseset({0, 1})
        base_set.symmetric_difference_update({1, 2})
        assert base_set._instance._changed_fields == ["my_name"]

    def test_remove_calls_mark_as_changed(self):
        base_set = self._get_baseset({True})
        base_set.remove(True)
        assert base_set._instance._changed_fields == ["my_name"]

    def test_remove_not_mark_as_changed_when_it_fails(self):
        """A remove() that raises must not record a change."""
        base_set = self._get_baseset({True})
        with pytest.raises(KeyError):
            base_set.remove(False)
        assert not base_set._instance._changed_fields

    def test_discard_calls_mark_as_changed(self):
        base_set = self._get_baseset({True})
        base_set.discard(True)
        assert base_set._instance._changed_fields == ["my_name"]

    def test_pop_calls_mark_as_changed(self):
        base_set = self._get_baseset({True})
        base_set.pop()
        assert base_set._instance._changed_fields == ["my_name"]

    def test_clear_calls_mark_as_changed(self):
        base_set = self._get_baseset({True})
        base_set.clear()
        assert base_set._instance._changed_fields == ["my_name"]

    def test___ior___calls_mark_as_changed(self):
        base_set = self._get_baseset(set())
        base_set |= {True}
        assert base_set._instance._changed_fields == ["my_name"]

    def test___iand___calls_mark_as_changed(self):
        base_set = self._get_baseset({True, False})
        base_set &= {True}
        assert base_set._instance._changed_fields == ["my_name"]

    def test___isub___calls_mark_as_changed(self):
        base_set = self._get_baseset({True, False})
        base_set -= {True}
        assert base_set._instance._changed_fields == ["my_name"]

    def test___ixor___calls_mark_as_changed(self):
        base_set = self._get_baseset({0, 1})
        base_set ^= {1, 2}
        assert base_set._instance._changed_fields == ["my_name"]