Skip to content

Commit 87be36a

Browse files
committed
prevent the creation of embedded models
1 parent 7dd117f commit 87be36a

File tree

15 files changed

+236
-21
lines changed

15 files changed

+236
-21
lines changed

django_mongodb_backend/fields/embedded_model.py

+11
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,18 @@ def __init__(self, embedded_model, *args, **kwargs):
1818
super().__init__(*args, **kwargs)
1919

2020
def check(self, **kwargs):
21+
from ..models import EmbeddedModel
22+
2123
errors = super().check(**kwargs)
24+
if not issubclass(self.embedded_model, EmbeddedModel):
25+
return [
26+
checks.Error(
27+
"Embedded model must be a subclass of "
28+
"django_mongodb_backend.models.EmbeddedModel.",
29+
obj=self,
30+
id="django_mongodb_backend.embedded_model.E002",
31+
)
32+
]
2233
for field in self.embedded_model._meta.fields:
2334
if field.remote_field:
2435
errors.append(

django_mongodb_backend/managers.py

+41
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,48 @@
1+
from django.db import NotSupportedError
12
from django.db.models.manager import BaseManager
23

34
from .queryset import MongoQuerySet
45

56

67
class MongoManager(BaseManager.from_queryset(MongoQuerySet)):
78
pass
9+
10+
11+
class EmbeddedModelManager(BaseManager):
12+
"""
13+
Prevent all queryset operations on embedded models since they don't have
14+
their own collection.
15+
"""
16+
17+
def get_queryset(self):
18+
raise NotSupportedError("EmbeddedModels cannot be queried.")
19+
20+
def all(self):
21+
raise NotSupportedError("EmbeddedModels cannot be queried.")
22+
23+
def get(self, *args, **kwargs):
24+
raise NotSupportedError("EmbeddedModels cannot be queried.")
25+
26+
def get_or_create(self, **kwargs):
27+
raise NotSupportedError("EmbeddedModels cannot be queried.")
28+
29+
def filter(self, *args, **kwargs):
30+
raise NotSupportedError("EmbeddedModels cannot be queried.")
31+
32+
def create(self, **kwargs):
33+
raise NotSupportedError("EmbeddedModels cannot be created.")
34+
35+
def bulk_create(self, *args, **kwargs):
36+
raise NotSupportedError("EmbeddedModels cannot be created.")
37+
38+
def update(self, *args, **kwargs):
39+
raise NotSupportedError("EmbeddedModels cannot be updated.")
40+
41+
def bulk_update(self, *args, **kwargs):
42+
raise NotSupportedError("EmbeddedModels cannot be updated.")
43+
44+
def update_or_create(self, **kwargs):
45+
raise NotSupportedError("EmbeddedModels cannot be updated or created.")
46+
47+
def delete(self):
48+
raise NotSupportedError("EmbeddedModels cannot be deleted.")

django_mongodb_backend/models.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from django.db import NotSupportedError, models
2+
3+
from .managers import EmbeddedModelManager
4+
5+
6+
class EmbeddedModel(models.Model):
7+
objects = EmbeddedModelManager()
8+
9+
class Meta:
10+
abstract = True
11+
12+
def delete(self, *args, **kwargs):
13+
raise NotSupportedError("EmbeddedModels cannot be deleted.")
14+
15+
def save(self, *args, **kwargs):
16+
raise NotSupportedError("EmbeddedModels cannot be saved.")

django_mongodb_backend/schema.py

+27
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,21 @@
1010
from .utils import OperationCollector
1111

1212

13+
def ignore_embedded_models(func):
14+
"""Make a SchemaEditor a no-op if model is an EmbeddedModel."""
15+
16+
def wrapper(self, model, *args, **kwargs):
17+
# If parent_model isn't None, this is a valid recursive operation.
18+
parent_model = kwargs.get("parent_model")
19+
from .models import EmbeddedModel
20+
21+
if parent_model is None and issubclass(model, EmbeddedModel):
22+
return
23+
func(self, model, *args, **kwargs)
24+
25+
return wrapper
26+
27+
1328
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
1429
def get_collection(self, name):
1530
if self.collect_sql:
@@ -22,6 +37,7 @@ def get_database(self):
2237
return self.connection.get_database()
2338

2439
@wrap_database_errors
40+
@ignore_embedded_models
2541
def create_model(self, model):
2642
self.get_database().create_collection(model._meta.db_table)
2743
self._create_model_indexes(model)
@@ -75,13 +91,15 @@ def _create_model_indexes(self, model, column_prefix="", parent_model=None):
7591
for index in model._meta.indexes:
7692
self.add_index(model, index, column_prefix=column_prefix, parent_model=parent_model)
7793

94+
@ignore_embedded_models
7895
def delete_model(self, model):
7996
# Delete implicit M2m tables.
8097
for field in model._meta.local_many_to_many:
8198
if field.remote_field.through._meta.auto_created:
8299
self.delete_model(field.remote_field.through)
83100
self.get_collection(model._meta.db_table).drop()
84101

102+
@ignore_embedded_models
85103
def add_field(self, model, field):
86104
# Create implicit M2M tables.
87105
if field.many_to_many and field.remote_field.through._meta.auto_created:
@@ -103,6 +121,7 @@ def add_field(self, model, field):
103121
elif self._field_should_have_unique(field):
104122
self._add_field_unique(model, field)
105123

124+
@ignore_embedded_models
106125
def _alter_field(
107126
self,
108127
model,
@@ -149,6 +168,7 @@ def _alter_field(
149168
if not old_field_unique and new_field_unique:
150169
self._add_field_unique(model, new_field)
151170

171+
@ignore_embedded_models
152172
def remove_field(self, model, field):
153173
# Remove implicit M2M tables.
154174
if field.many_to_many and field.remote_field.through._meta.auto_created:
@@ -210,6 +230,7 @@ def _remove_model_indexes(self, model, column_prefix="", parent_model=None):
210230
for index in model._meta.indexes:
211231
self.remove_index(parent_model or model, index)
212232

233+
@ignore_embedded_models
213234
def alter_index_together(self, model, old_index_together, new_index_together, column_prefix=""):
214235
olds = {tuple(fields) for fields in old_index_together}
215236
news = {tuple(fields) for fields in new_index_together}
@@ -222,6 +243,7 @@ def alter_index_together(self, model, old_index_together, new_index_together, co
222243
for field_names in news.difference(olds):
223244
self._add_composed_index(model, field_names, column_prefix=column_prefix)
224245

246+
@ignore_embedded_models
225247
def alter_unique_together(
226248
self, model, old_unique_together, new_unique_together, column_prefix="", parent_model=None
227249
):
@@ -249,6 +271,7 @@ def alter_unique_together(
249271
model, constraint, parent_model=parent_model, column_prefix=column_prefix
250272
)
251273

274+
@ignore_embedded_models
252275
def add_index(
253276
self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None
254277
):
@@ -302,6 +325,7 @@ def _add_field_index(self, model, field, *, column_prefix=""):
302325
index.name = self._create_index_name(model._meta.db_table, [column_prefix + field.column])
303326
self.add_index(model, index, field=field, column_prefix=column_prefix)
304327

328+
@ignore_embedded_models
305329
def remove_index(self, model, index):
306330
if index.contains_expressions:
307331
return
@@ -355,6 +379,7 @@ def _remove_field_index(self, model, field, column_prefix=""):
355379
)
356380
collection.drop_index(index_names[0])
357381

382+
@ignore_embedded_models
358383
def add_constraint(self, model, constraint, field=None, column_prefix="", parent_model=None):
359384
if isinstance(constraint, UniqueConstraint) and self._unique_supported(
360385
condition=constraint.condition,
@@ -384,6 +409,7 @@ def _add_field_unique(self, model, field, column_prefix=""):
384409
constraint = UniqueConstraint(fields=[field.name], name=name)
385410
self.add_constraint(model, constraint, field=field, column_prefix=column_prefix)
386411

412+
@ignore_embedded_models
387413
def remove_constraint(self, model, constraint):
388414
if isinstance(constraint, UniqueConstraint) and self._unique_supported(
389415
condition=constraint.condition,
@@ -417,6 +443,7 @@ def _remove_field_unique(self, model, field, column_prefix=""):
417443
)
418444
self.get_collection(model._meta.db_table).drop_index(constraint_names[0])
419445

446+
@ignore_embedded_models
420447
def alter_db_table(self, model, old_db_table, new_db_table):
421448
if old_db_table == new_db_table:
422449
return

docs/source/embedded-models.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ The basics
1111
Let's consider this example::
1212

1313
from django_mongodb_backend.fields import EmbeddedModelField
14+
from django_mongodb_backend.models import EmbeddedModel
1415

1516
class Customer(models.Model):
1617
name = models.CharField(...)
1718
address = EmbeddedModelField("Address")
1819
...
1920

20-
class Address(models.Model):
21+
class Address(EmbeddedModel):
2122
...
2223
city = models.CharField(...)
2324

docs/source/fields.rst

+5-2
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ Stores a model of type ``embedded_model``.
224224

225225
Specifies the model class to embed. It can be either a concrete model
226226
class or a :ref:`lazy reference <lazy-relationships>` to a model class.
227+
The target model must be a subclass of
228+
``django_mongodb_backend.models.EmbeddedModel``.
227229

228230
The embedded model cannot have relational fields
229231
(:class:`~django.db.models.ForeignKey`,
@@ -234,11 +236,12 @@ Stores a model of type ``embedded_model``.
234236

235237
from django.db import models
236238
from django_mongodb_backend.fields import EmbeddedModelField
239+
from django_mongodb_backend.models import EmbeddedModel
237240

238-
class Address(models.Model):
241+
class Address(EmbeddedModel):
239242
...
240243

241-
class Author(models.Model):
244+
class Author(EmbeddedModel):
242245
address = EmbeddedModelField(Address)
243246

244247
class Book(models.Model):

docs/source/models.rst

Whitespace-only changes.

tests/model_fields_/models.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from django.db import models
44

55
from django_mongodb_backend.fields import ArrayField, EmbeddedModelField, ObjectIdField
6+
from django_mongodb_backend.models import EmbeddedModel
67

78

89
# ObjectIdField
@@ -98,19 +99,19 @@ class Holder(models.Model):
9899
data = EmbeddedModelField("Data", null=True, blank=True)
99100

100101

101-
class Data(models.Model):
102+
class Data(EmbeddedModel):
102103
integer = models.IntegerField(db_column="custom_column")
103104
auto_now = models.DateTimeField(auto_now=True)
104105
auto_now_add = models.DateTimeField(auto_now_add=True)
105106

106107

107-
class Address(models.Model):
108+
class Address(EmbeddedModel):
108109
city = models.CharField(max_length=20)
109110
state = models.CharField(max_length=2)
110111
zip_code = models.IntegerField(db_index=True)
111112

112113

113-
class Author(models.Model):
114+
class Author(EmbeddedModel):
114115
name = models.CharField(max_length=10)
115116
age = models.IntegerField()
116117
address = EmbeddedModelField(Address)

tests/model_fields_/test_embedded_model.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from django.test.utils import isolate_apps
55

66
from django_mongodb_backend.fields import EmbeddedModelField
7+
from django_mongodb_backend.models import EmbeddedModel
78

89
from .models import (
910
Address,
@@ -108,7 +109,7 @@ def test_nested(self):
108109
@isolate_apps("model_fields_")
109110
class CheckTests(SimpleTestCase):
110111
def test_no_relational_fields(self):
111-
class Target(models.Model):
112+
class Target(EmbeddedModel):
112113
key = models.ForeignKey("MyModel", models.CASCADE)
113114

114115
class MyModel(models.Model):
@@ -121,3 +122,18 @@ class MyModel(models.Model):
121122
self.assertEqual(
122123
msg, "Embedded models cannot have relational fields (Target.key is a ForeignKey)."
123124
)
125+
126+
def test_embedded_model_subclass(self):
127+
class Target(models.Model):
128+
pass
129+
130+
class MyModel(models.Model):
131+
field = EmbeddedModelField(Target)
132+
133+
errors = MyModel().check()
134+
self.assertEqual(len(errors), 1)
135+
self.assertEqual(errors[0].id, "django_mongodb_backend.embedded_model.E002")
136+
msg = errors[0].msg
137+
self.assertEqual(
138+
msg, "Embedded model must be a subclass of django_mongodb_backend.models.EmbeddedModel."
139+
)

tests/model_forms_/models.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from django.db import models
22

33
from django_mongodb_backend.fields import EmbeddedModelField
4+
from django_mongodb_backend.models import EmbeddedModel
45

56

6-
class Address(models.Model):
7+
class Address(EmbeddedModel):
78
po_box = models.CharField(max_length=50, blank=True, verbose_name="PO Box")
89
city = models.CharField(max_length=20)
910
state = models.CharField(max_length=2)
@@ -15,8 +16,3 @@ class Author(models.Model):
1516
age = models.IntegerField()
1617
address = EmbeddedModelField(Address)
1718
billing_address = EmbeddedModelField(Address, blank=True, null=True)
18-
19-
20-
class Book(models.Model):
21-
name = models.CharField(max_length=100)
22-
author = EmbeddedModelField(Author)

tests/models_/__init__.py

Whitespace-only changes.

tests/models_/models.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from django_mongodb_backend.models import EmbeddedModel
2+
3+
4+
class Embed(EmbeddedModel):
5+
pass

tests/models_/test_embedded_model.py

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from django.db import NotSupportedError
2+
from django.test import SimpleTestCase
3+
4+
from .models import Embed
5+
6+
7+
class TestMethods(SimpleTestCase):
8+
def test_save(self):
9+
e = Embed()
10+
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be saved."):
11+
e.save()
12+
13+
def test_delete(self):
14+
e = Embed()
15+
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be deleted."):
16+
e.delete()
17+
18+
19+
class TestManagerMethods(SimpleTestCase):
20+
def test_all(self):
21+
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be queried."):
22+
Embed.objects.all()
23+
24+
def test_get(self):
25+
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be queried."):
26+
Embed.objects.get()
27+
28+
def test_get_or_create(self):
29+
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be queried."):
30+
Embed.objects.get_or_create()
31+
32+
def test_filter(self):
33+
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be queried."):
34+
Embed.objects.filter(foo="bar")
35+
36+
def test_create(self):
37+
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be created."):
38+
Embed.objects.create()
39+
40+
def test_bulk_create(self):
41+
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be created."):
42+
Embed.objects.bulk_create()
43+
44+
def test_update(self):
45+
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be updated."):
46+
Embed.objects.update(foo="bar")
47+
48+
def test_bulk_update(self):
49+
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be updated."):
50+
Embed.objects.bulk_update()
51+
52+
def test_update_or_create(self):
53+
msg = "EmbeddedModels cannot be updated or created."
54+
with self.assertRaisesMessage(NotSupportedError, msg):
55+
Embed.objects.update_or_create()
56+
57+
def test_delete(self):
58+
with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be deleted."):
59+
Embed.objects.delete()

0 commit comments

Comments
 (0)