diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 38da2e873..c79048e67 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -1,3 +1,5 @@ +from functools import partial + from bson import SON, DBRef from mongoengine.base import ( @@ -174,7 +176,10 @@ def _fetch_objects(self, doc_type=None): refs = [ dbref for dbref in dbrefs if (col_name, dbref) not in object_map ] - references = collection.objects.in_bulk(refs) + if isinstance(collection.objects, partial): + references = collection.objects().in_bulk(refs) + else: + references = collection.objects.in_bulk(refs) for key, doc in references.items(): object_map[(col_name, key)] = doc else: # Generic reference: use the refs data to convert to document diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index 8386249f2..724f6471f 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -4016,6 +4016,28 @@ def objects(klass, queryset): assert 0 == Foo.objects.count() assert 1 == Bar.objects.count() + def test_select_related_when_referenced_has_custom_queryset_manager(self): + class Foo(Document): + @queryset_manager + def objects(klass, queryset, arg1=None, arg2=None, **kwargs): + return queryset(**kwargs) + + class Bar(Document): + foo = ReferenceField(Foo) + + Foo.drop_collection() + Bar.drop_collection() + + foo = Foo() + foo.save() + + Bar(foo=foo).save() + + Bar.objects().select_related() + + assert 1 == Foo.objects().count() + assert 1 == Bar.objects().count() + def test_query_value_conversion(self): """Ensure that query values are properly converted when necessary."""