Skip to content

Commit 905e2c8

Browse files
authored
feat: basic support for custom fields (#75)
basic support for custom fields
1 parent fd79324 commit 905e2c8

10 files changed

+805
-81
lines changed

netbox_diode_plugin/api/applier.py

+26-10
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
logger = logging.getLogger(__name__)
1919

2020

21-
def apply_changeset(change_set: ChangeSet) -> ChangeSetResult:
21+
def apply_changeset(change_set: ChangeSet, request) -> ChangeSetResult:
2222
"""Apply a change set."""
2323
_validate_change_set(change_set)
2424

@@ -33,7 +33,7 @@ def apply_changeset(change_set: ChangeSet) -> ChangeSetResult:
3333
try:
3434
model_class = get_object_type_model(object_type)
3535
data = _pre_apply(model_class, change, created)
36-
_apply_change(data, model_class, change, created)
36+
_apply_change(data, model_class, change, created, request)
3737
except ValidationError as e:
3838
raise _err_from_validation_error(e, f"changes[{i}]")
3939
except ObjectDoesNotExist:
@@ -45,42 +45,58 @@ def apply_changeset(change_set: ChangeSet) -> ChangeSetResult:
4545
id=change_set.id,
4646
)
4747

48-
def _apply_change(data: dict, model_class: models.Model, change: Change, created: dict):
48+
def _apply_change(data: dict, model_class: models.Model, change: Change, created: dict, request):
4949
serializer_class = get_serializer_for_model(model_class)
5050
change_type = change.change_type
5151
if change_type == ChangeType.CREATE.value:
52-
serializer = serializer_class(data=data)
52+
serializer = serializer_class(data=data, context={"request": request})
5353
serializer.is_valid(raise_exception=True)
5454
instance = serializer.save()
5555
created[change.ref_id] = instance
5656

5757
elif change_type == ChangeType.UPDATE.value:
5858
if object_id := change.object_id:
5959
instance = model_class.objects.get(id=object_id)
60-
serializer = serializer_class(instance, data=data, partial=True)
60+
serializer = serializer_class(instance, data=data, partial=True, context={"request": request})
6161
serializer.is_valid(raise_exception=True)
6262
serializer.save()
6363
# create and update in a same change set
6464
elif change.ref_id and (instance := created[change.ref_id]):
65-
serializer = serializer_class(instance, data=data, partial=True)
65+
serializer = serializer_class(instance, data=data, partial=True, context={"request": request})
6666
serializer.is_valid(raise_exception=True)
6767
serializer.save()
6868

69+
def _set_path(data, path, value):
70+
path = path.split(".")
71+
key = path.pop(0)
72+
while len(path) > 0:
73+
data = data[key]
74+
key = path.pop(0)
75+
data[key] = value
76+
77+
def _get_path(data, path):
78+
path = path.split(".")
79+
v = data
80+
for p in path:
81+
v = v[p]
82+
return v
83+
6984
def _pre_apply(model_class: models.Model, change: Change, created: dict):
7085
data = change.data.copy()
7186

7287
# resolve foreign key references to new objects
7388
for ref_field in change.new_refs:
74-
if isinstance(data[ref_field], (list, tuple)):
89+
v = _get_path(data, ref_field)
90+
if isinstance(v, (list, tuple)):
7591
ref_list = []
76-
for ref in data[ref_field]:
92+
for ref in v:
7793
if isinstance(ref, str):
7894
ref_list.append(created[ref].pk)
7995
elif isinstance(ref, int):
8096
ref_list.append(ref)
81-
data[ref_field] = ref_list
97+
_set_path(data, ref_field, ref_list)
8298
else:
83-
data[ref_field] = created[data[ref_field]].pk
99+
_set_path(data, ref_field, created[v].pk)
84100

85101
# ignore? fields that are not in the data model (error?)
86102
allowed_fields = legal_fields(model_class)

netbox_diode_plugin/api/common.py

+45-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from django.contrib.contenttypes.models import ContentType
1414
from django.core.exceptions import ValidationError
1515
from django.db import models
16+
from extras.models import CustomField
1617
from rest_framework import status
1718

1819
logger = logging.getLogger("netbox.diode_data")
@@ -114,13 +115,41 @@ def validate(self) -> dict[str, list[str]]:
114115
errors[change.object_type] = rel_errors
115116

116117
try:
118+
custom_fields = change_data.pop('custom_fields', None)
119+
if custom_fields:
120+
self._validate_custom_fields(custom_fields, model)
121+
117122
instance = model(**change_data)
118123
instance.clean_fields(exclude=excluded_relation_fields)
119124
except ValidationError as e:
120-
errors[change.object_type].update(e.error_dict)
125+
errors[change.object_type].update(_error_dict(e))
121126

122127
return errors or None
123128

129+
def _validate_custom_fields(self, data: dict, model: models.Model) -> None:
130+
custom_fields = {
131+
cf.name: cf for cf in CustomField.objects.get_for_model(model)
132+
}
133+
134+
unknown_errors = []
135+
for field_name, value in data.items():
136+
if field_name not in custom_fields:
137+
unknown_errors.append(f"Unknown field name '{field_name}' in custom field data.")
138+
continue
139+
if unknown_errors:
140+
raise ValidationError({
141+
"custom_fields": unknown_errors
142+
})
143+
144+
req_errors = []
145+
for field_name, cf in custom_fields.items():
146+
if cf.required and field_name not in data:
147+
req_errors.append(f"Custom field '{field_name}' is required.")
148+
if req_errors:
149+
raise ValidationError({
150+
"custom_fields": req_errors
151+
})
152+
124153
def _validate_relations(self, change_data: dict, model: models.Model) -> tuple[list[str], dict]:
125154
# check that there is some value for every required
126155
# reference field, but don't validate the actual reference.
@@ -191,3 +220,18 @@ def __str__(self):
191220
if self.errors:
192221
return f"{self.message}: {self.errors}"
193222
return self.message
223+
224+
def _error_dict(e: ValidationError) -> dict:
225+
"""Convert a ValidationError to a dictionary."""
226+
if hasattr(e, "error_dict"):
227+
return e.error_dict
228+
return {
229+
"__all__": e.error_list
230+
}
231+
232+
@dataclass
233+
class AutoSlug:
234+
"""A class that marks an auto-generated slug."""
235+
236+
field_name: str
237+
value: str

netbox_diode_plugin/api/differ.py

+59-14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""Diode NetBox Plugin - API - Differ."""
44

55
import copy
6+
import datetime
67
import logging
78

89
from django.contrib.contenttypes.models import ContentType
@@ -12,7 +13,7 @@
1213
from .common import Change, ChangeSet, ChangeSetException, ChangeSetResult, ChangeType
1314
from .plugin_utils import get_primary_value, legal_fields
1415
from .supported_models import extract_supported_models
15-
from .transformer import cleanup_unresolved_references, transform_proto_json
16+
from .transformer import cleanup_unresolved_references, set_custom_field_defaults, transform_proto_json
1617

1718
logger = logging.getLogger(__name__)
1819

@@ -68,9 +69,31 @@ def prechange_data_from_instance(instance) -> dict: # noqa: C901
6869
else:
6970
prechange_data[field_name] = value
7071

72+
if hasattr(instance, "get_custom_fields"):
73+
custom_field_values = instance.get_custom_fields()
74+
cfmap = {}
75+
for cf, value in custom_field_values.items():
76+
if isinstance(value, (datetime.datetime, datetime.date)):
77+
cfmap[cf.name] = value
78+
else:
79+
cfmap[cf.name] = cf.serialize(value)
80+
prechange_data["custom_fields"] = cfmap
81+
7182
return prechange_data
7283

7384

85+
def _harmonize_formats(prechange_data: dict, postchange_data: dict):
86+
for k, v in prechange_data.items():
87+
if isinstance(v, datetime.datetime):
88+
prechange_data[k] = v.strftime("%Y-%m-%dT%H:%M:%SZ")
89+
elif isinstance(v, datetime.date):
90+
prechange_data[k] = v.strftime("%Y-%m-%d")
91+
elif isinstance(v, int) and k in postchange_data:
92+
postchange_data[k] = int(postchange_data[k])
93+
elif isinstance(v, dict):
94+
_harmonize_formats(v, postchange_data.get(k, {}))
95+
96+
7497
def clean_diff_data(data: dict, exclude_empty_values: bool = True) -> dict:
7598
"""Clean diff data by removing null values."""
7699
result = {}
@@ -80,8 +103,10 @@ def clean_diff_data(data: dict, exclude_empty_values: bool = True) -> dict:
80103
continue
81104
if isinstance(v, list) and len(v) == 0:
82105
continue
83-
if isinstance(v, dict) and len(v) == 0:
84-
continue
106+
if isinstance(v, dict):
107+
if len(v) == 0:
108+
continue
109+
v = clean_diff_data(v, exclude_empty_values)
85110
if isinstance(v, str) and v == "":
86111
continue
87112
result[k] = v
@@ -100,7 +125,7 @@ def diff_to_change(
100125
if change_type == ChangeType.UPDATE and not len(changed_attrs) > 0:
101126
change_type = ChangeType.NOOP
102127

103-
primary_value = get_primary_value(prechange_data | postchange_data, object_type)
128+
primary_value = str(get_primary_value(prechange_data | postchange_data, object_type))
104129
if primary_value is None:
105130
primary_value = "(unnamed)"
106131

@@ -111,6 +136,8 @@ def diff_to_change(
111136

112137
change = Change(
113138
change_type=change_type,
139+
before=_tidy(prechange_data),
140+
data={},
114141
object_type=object_type,
115142
object_id=prior_id if isinstance(prior_id, int) else None,
116143
ref_id=ref_id,
@@ -119,17 +146,13 @@ def diff_to_change(
119146
)
120147

121148
if change_type != ChangeType.NOOP:
122-
postchange_data_clean = clean_diff_data(postchange_data)
123-
change.data = sort_dict_recursively(postchange_data_clean)
124-
else:
125-
change.data = {}
126-
127-
if change_type == ChangeType.UPDATE or change_type == ChangeType.NOOP:
128-
prechange_data_clean = clean_diff_data(prechange_data)
129-
change.before = sort_dict_recursively(prechange_data_clean)
149+
change.data = _tidy(postchange_data)
130150

131151
return change
132152

153+
def _tidy(data: dict) -> dict:
154+
return sort_dict_recursively(clean_diff_data(data))
155+
133156
def sort_dict_recursively(d):
134157
"""Recursively sorts a dictionary by keys."""
135158
if isinstance(d, dict):
@@ -161,7 +184,11 @@ def generate_changeset(entity: dict, object_type: str) -> ChangeSetResult:
161184
# prior state is a model instance
162185
else:
163186
prechange_data = prechange_data_from_instance(instance)
164-
187+
# merge the prior state that we don't want to overwrite with the new state
188+
# this is also important for custom fields because they do not appear to
189+
# respsect paritial update serialization.
190+
entity = _partially_merge(prechange_data, entity, instance)
191+
_harmonize_formats(prechange_data, entity)
165192
changed_data = shallow_compare_dict(
166193
prechange_data, entity,
167194
)
@@ -187,7 +214,25 @@ def generate_changeset(entity: dict, object_type: str) -> ChangeSetResult:
187214
if errors := change_set.validate():
188215
raise ChangeSetException("Invalid change set", errors)
189216

190-
return ChangeSetResult(
217+
218+
cs = ChangeSetResult(
191219
id=change_set.id,
192220
change_set=change_set,
193221
)
222+
return cs
223+
224+
def _partially_merge(prechange_data: dict, postchange_data: dict, instance) -> dict:
225+
"""Merge lists and custom_fields rather than replacing the full value..."""
226+
result = {}
227+
for key, value in postchange_data.items():
228+
# TODO: partially merge lists like tags? all lists?
229+
result[key] = value
230+
231+
# these are fully merged in from the prechange state because
232+
# they don't respect partial update serialization.
233+
if "custom_fields" in postchange_data:
234+
for key, value in prechange_data.get("custom_fields", {}).items():
235+
if value is not None and key not in postchange_data["custom_fields"]:
236+
result["custom_fields"][key] = value
237+
set_custom_field_defaults(result, instance)
238+
return result

0 commit comments

Comments
 (0)