Skip to content

Commit 588fa17

Browse files
feat: added pagination and deletion guard (#7615)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 91efcff commit 588fa17

2 files changed

Lines changed: 138 additions & 15 deletions

File tree

api/experimentation/views.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import logging
22
from typing import Any
33

4+
from django.db import IntegrityError
45
from django.db.models import Q, QuerySet
5-
from rest_framework import mixins, status
6+
from rest_framework import mixins, serializers, status
67
from rest_framework.decorators import action
78
from rest_framework.permissions import IsAuthenticated
89
from rest_framework.request import Request
910
from rest_framework.response import Response
1011
from rest_framework.serializers import BaseSerializer
1112

13+
from app.pagination import CustomPagination
1214
from environments.views import NestedEnvironmentViewSet
1315
from experimentation.models import (
1416
Experiment,
@@ -101,7 +103,7 @@ class ExperimentViewSet(
101103
mixins.DestroyModelMixin,
102104
):
103105
serializer_class = ExperimentSerializer
104-
pagination_class = None
106+
pagination_class = CustomPagination
105107
permission_classes = [IsAuthenticated, ExperimentPermission]
106108
model_class = Experiment
107109
lookup_field = "id"
@@ -125,6 +127,10 @@ def get_queryset(self) -> "QuerySet[Experiment]":
125127
)
126128
status_filter = self.request.query_params.get("status")
127129
if status_filter:
130+
if status_filter not in ExperimentStatus.values:
131+
raise serializers.ValidationError(
132+
{"status": f"Invalid status '{status_filter}'."}
133+
)
128134
qs = qs.filter(status=status_filter)
129135

130136
q = self.request.query_params.get("q")
@@ -152,7 +158,13 @@ def create(self, request: Request, *args: object, **kwargs: object) -> Response:
152158
status=status.HTTP_409_CONFLICT,
153159
)
154160

155-
self.perform_create(serializer)
161+
try:
162+
self.perform_create(serializer)
163+
except IntegrityError:
164+
return Response(
165+
{"detail": "An active experiment already exists for this feature."},
166+
status=status.HTTP_409_CONFLICT,
167+
)
156168
return Response(serializer.data, status=status.HTTP_201_CREATED)
157169

158170
def perform_create(self, serializer: BaseSerializer[Experiment]) -> None:
@@ -162,12 +174,28 @@ def perform_create(self, serializer: BaseSerializer[Experiment]) -> None:
162174
)
163175

164176
def perform_update(self, serializer: BaseSerializer[Experiment]) -> None:
177+
changed_fields = {
178+
field
179+
for field, value in serializer.validated_data.items()
180+
if getattr(serializer.instance, field, None) != value
181+
}
182+
if not changed_fields:
183+
return
165184
experiment: Experiment = serializer.save()
166185
create_experiment_audit_log(
167186
experiment, self._get_user(self.request), action="updated"
168187
)
169188

170189
def perform_destroy(self, instance: Experiment) -> None:
190+
if instance.status == ExperimentStatus.RUNNING:
191+
raise serializers.ValidationError(
192+
{
193+
"detail": (
194+
"Cannot delete a running experiment. "
195+
"Pause or complete it first."
196+
)
197+
}
198+
)
171199
create_experiment_audit_log(
172200
instance, self._get_user(self.request), action="deleted"
173201
)

api/tests/unit/experimentation/test_experiment_views.py

Lines changed: 107 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from typing import TYPE_CHECKING
44

55
import pytest
6+
from django.db import IntegrityError
67
from django.urls import reverse
8+
from pytest_mock import MockerFixture
79
from rest_framework import status
810
from rest_framework.test import APIClient
911

@@ -265,8 +267,9 @@ def test_get_list__with_experiments__returns_all(
265267

266268
# Then
267269
assert response.status_code == status.HTTP_200_OK
268-
assert len(response.json()) == 1
269-
assert response.json()[0]["id"] == experiment.id
270+
results = response.json()["results"]
271+
assert len(results) == 1
272+
assert results[0]["id"] == experiment.id
270273

271274

272275
def test_get_list__with_experiments__returns_nested_feature(
@@ -284,9 +287,9 @@ def test_get_list__with_experiments__returns_nested_feature(
284287

285288
# Then
286289
assert response.status_code == status.HTTP_200_OK
287-
data = response.json()
288-
assert len(data) == 1
289-
feature_data = data[0]["feature"]
290+
results = response.json()["results"]
291+
assert len(results) == 1
292+
feature_data = results[0]["feature"]
290293
assert isinstance(feature_data, dict)
291294
assert feature_data["id"] == multivariate_feature.id
292295
assert feature_data["name"] == multivariate_feature.name
@@ -329,7 +332,7 @@ def test_get_list__empty__returns_200(
329332

330333
# Then
331334
assert response.status_code == status.HTTP_200_OK
332-
assert response.json() == []
335+
assert response.json()["results"] == []
333336

334337

335338
@pytest.mark.parametrize(
@@ -357,7 +360,7 @@ def test_get_list__filter_by_status__returns_filtered(
357360

358361
# Then
359362
assert response.status_code == status.HTTP_200_OK
360-
assert len(response.json()) == expected_count
363+
assert len(response.json()["results"]) == expected_count
361364

362365

363366
def test_get_list__search_by_experiment_name__returns_matching(
@@ -374,8 +377,9 @@ def test_get_list__search_by_experiment_name__returns_matching(
374377

375378
# Then
376379
assert response.status_code == status.HTTP_200_OK
377-
assert len(response.json()) == 1
378-
assert response.json()[0]["id"] == experiment.id
380+
results = response.json()["results"]
381+
assert len(results) == 1
382+
assert results[0]["id"] == experiment.id
379383

380384

381385
def test_get_list__search_by_feature_name__returns_matching(
@@ -395,8 +399,9 @@ def test_get_list__search_by_feature_name__returns_matching(
395399

396400
# Then
397401
assert response.status_code == status.HTTP_200_OK
398-
assert len(response.json()) == 1
399-
assert response.json()[0]["id"] == experiment.id
402+
results = response.json()["results"]
403+
assert len(results) == 1
404+
assert results[0]["id"] == experiment.id
400405

401406

402407
def test_get_list__search_no_match__returns_empty(
@@ -413,7 +418,7 @@ def test_get_list__search_no_match__returns_empty(
413418

414419
# Then
415420
assert response.status_code == status.HTTP_200_OK
416-
assert len(response.json()) == 0
421+
assert len(response.json()["results"]) == 0
417422

418423

419424
def test_get_detail__exists__returns_200(
@@ -670,3 +675,93 @@ def test_delete__valid_delete__creates_audit_log(
670675
).last()
671676
assert audit is not None
672677
assert "deleted" in audit.log
678+
679+
680+
def test_get_list__invalid_status__returns_400(
681+
admin_client_new: APIClient,
682+
environment: Environment,
683+
enable_features: EnableFeaturesFixture,
684+
) -> None:
685+
# Given
686+
enable_features(EXPERIMENT_FLAG)
687+
688+
# When
689+
response = admin_client_new.get(_list_url(environment), {"status": "garbage"})
690+
691+
# Then
692+
assert response.status_code == status.HTTP_400_BAD_REQUEST
693+
694+
695+
def test_delete__running_experiment__returns_400(
696+
admin_client_new: APIClient,
697+
environment: Environment,
698+
experiment: Experiment,
699+
enable_features: EnableFeaturesFixture,
700+
) -> None:
701+
# Given
702+
enable_features(EXPERIMENT_FLAG)
703+
experiment.status = ExperimentStatus.RUNNING
704+
experiment.save()
705+
706+
# When
707+
response = admin_client_new.delete(_detail_url(environment, experiment))
708+
709+
# Then
710+
assert response.status_code == status.HTTP_400_BAD_REQUEST
711+
assert Experiment.objects.filter(id=experiment.id).exists()
712+
713+
714+
def test_patch__no_change__skips_audit_log(
715+
admin_client_new: APIClient,
716+
environment: Environment,
717+
experiment: Experiment,
718+
enable_features: EnableFeaturesFixture,
719+
) -> None:
720+
# Given
721+
enable_features(EXPERIMENT_FLAG)
722+
audit_count_before = AuditLog.objects.filter(
723+
related_object_type=RelatedObjectType.EXPERIMENT.name
724+
).count()
725+
726+
# When
727+
response = admin_client_new.patch(
728+
_detail_url(environment, experiment),
729+
data={"name": experiment.name},
730+
format="json",
731+
)
732+
733+
# Then
734+
assert response.status_code == status.HTTP_200_OK
735+
audit_count_after = AuditLog.objects.filter(
736+
related_object_type=RelatedObjectType.EXPERIMENT.name
737+
).count()
738+
assert audit_count_after == audit_count_before
739+
740+
741+
def test_post__concurrent_create_race__returns_409(
742+
admin_client_new: APIClient,
743+
environment: Environment,
744+
multivariate_feature: Feature,
745+
enable_features: EnableFeaturesFixture,
746+
mocker: MockerFixture,
747+
) -> None:
748+
# Given
749+
enable_features(EXPERIMENT_FLAG)
750+
mocker.patch(
751+
"experimentation.views.ExperimentViewSet.perform_create",
752+
side_effect=IntegrityError(),
753+
)
754+
755+
# When
756+
response = admin_client_new.post(
757+
_list_url(environment),
758+
data={
759+
"feature": multivariate_feature.id,
760+
"name": "Race",
761+
"hypothesis": "Should 409",
762+
},
763+
format="json",
764+
)
765+
766+
# Then
767+
assert response.status_code == status.HTTP_409_CONFLICT

0 commit comments

Comments
 (0)