Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 160 additions & 4 deletions openedx/core/djangoapps/enrollments/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from openedx.core.djangoapps.course_groups import cohorts
from openedx.core.djangoapps.embargo.models import Country, CountryAccessRule, RestrictedCourse
from openedx.core.djangoapps.embargo.test_utils import restrict_course
from openedx.core.djangoapps.enrollments import api, data
from openedx.core.djangoapps.enrollments import data
from openedx.core.djangoapps.enrollments.errors import CourseEnrollmentError
from openedx.core.djangoapps.enrollments.views import EnrollmentUserThrottle
from openedx.core.djangoapps.notifications.config.waffle import ENABLE_NOTIFICATIONS
Expand Down Expand Up @@ -711,9 +711,9 @@ def test_get_enrollment_details_bad_course(self):
)
assert resp.status_code == status.HTTP_400_BAD_REQUEST

@patch.object(api, "get_enrollment")
def test_get_enrollment_internal_error(self, mock_get_enrollment):
mock_get_enrollment.side_effect = CourseEnrollmentError("Something bad happened.")
@patch.object(CourseEnrollment.objects, "get")
def test_get_enrollment_internal_error(self, mock_get):
mock_get.side_effect = CourseEnrollmentError("Something bad happened.")
resp = self.client.get(
reverse(
'courseenrollment',
Expand Down Expand Up @@ -2031,3 +2031,159 @@ def test_delete_enrollment_allowed(self, delete_data, expected_result):
self.client.post(self.url, self.data)
response = self.client.delete(self.url, delete_data)
assert response.status_code == expected_result

# --- Response-shape tests (ADR 0025 serializer migration) ---

def test_post_response_shape(self):
"""POST 201 response contains the expected fields from CourseEnrollmentAllowedSerializer."""
response = self.client.post(self.url, self.data)
assert response.status_code == status.HTTP_201_CREATED
body = response.json()
assert body['email'] == self.data['email']
assert body['course_id'] == self.data['course_id']
assert body['auto_enroll'] is False
assert 'created' in body

def test_post_auto_enroll_true_in_response(self):
"""POST with auto_enroll=true is reflected in the 201 response."""
response = self.client.post(self.url, {**self.data, 'auto_enroll': True})
assert response.status_code == status.HTTP_201_CREATED
assert response.json()['auto_enroll'] is True

def test_post_missing_email_returns_field_error(self):
"""POST without email returns a serializer field-level 400 with an 'email' key."""
response = self.client.post(self.url, {'course_id': self.data['course_id']})
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert 'email' in response.json()

def test_post_missing_course_id_returns_field_error(self):
"""POST without course_id returns a serializer field-level 400 with a 'course_id' key."""
response = self.client.post(self.url, {'email': self.data['email']})
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert 'course_id' in response.json()

def test_post_duplicate_returns_409_with_message(self):
"""A duplicate POST returns 409 with a 'message' key."""
self.client.post(self.url, self.data)
response = self.client.post(self.url, self.data)
assert response.status_code == status.HTTP_409_CONFLICT
assert 'message' in response.json()

def test_get_response_is_list(self):
"""GET response body is a JSON list."""
response = self.client.get(self.url, {'email': self.data['email']})
assert response.status_code == status.HTTP_200_OK
assert isinstance(response.json(), list)

def test_get_empty_response_is_empty_list(self):
"""GET with no matching enrollments returns an empty list, not null."""
response = self.client.get(self.url, {'email': 'nobody@example.com'})
assert response.status_code == status.HTTP_200_OK
assert response.json() == []

def test_get_item_shape(self):
"""Each item in the GET response has the fields from CourseEnrollmentAllowedSerializer."""
self.client.post(self.url, self.data)
response = self.client.get(self.url, {'email': self.data['email']})
assert response.status_code == status.HTTP_200_OK
item = response.json()[0]
assert item['email'] == self.data['email']
assert item['course_id'] == self.data['course_id']
assert 'auto_enroll' in item
assert 'created' in item

def test_get_multiple_entries_returned(self):
"""GET returns all enrollment-allowed records for a given email."""
second_course = 'course-v1:edX+OtherX+Other_Course'
self.client.post(self.url, self.data)
self.client.post(self.url, {'email': self.data['email'], 'course_id': second_course})
response = self.client.get(self.url, {'email': self.data['email']})
assert response.status_code == status.HTTP_200_OK
results = response.json()
assert len(results) == 2
assert all(r['email'] == self.data['email'] for r in results)

def test_delete_missing_email_returns_field_error(self):
"""DELETE without email returns a serializer field-level 400 with an 'email' key."""
self.client.post(self.url, self.data)
response = self.client.delete(self.url, {'course_id': self.data['course_id']})
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert 'email' in response.json()


@skip_unless_lms
class EnrollmentViewResponseShapeTest(ModuleStoreTestCase, APITestCase):
"""
Tests that verify EnrollmentView (GET /enrollment/v1/enrollment/{course_id} and
/enrollment/v1/enrollment/{username},{course_id}) response structure is preserved
after migrating to direct serializer usage (ADR 0025).
"""

USERNAME = "Bob"
PASSWORD = "edx"

def setUp(self):
super().setUp()
self.course = CourseFactory.create(emit_signals=True)
self.user = UserFactory.create(username=self.USERNAME, password=self.PASSWORD)
self.client.login(username=self.USERNAME, password=self.PASSWORD)
CourseModeFactory.create(
course_id=self.course.id,
mode_slug=CourseMode.DEFAULT_MODE_SLUG,
mode_display_name=CourseMode.DEFAULT_MODE_SLUG,
)
CourseEnrollment.enroll(self.user, self.course.id)

def _get_by_course_id(self):
return self.client.get(
reverse('courseenrollment', kwargs={'course_id': str(self.course.id)})
)

def _get_by_username_and_course_id(self):
return self.client.get(
reverse('courseenrollment', kwargs={'username': self.USERNAME, 'course_id': str(self.course.id)})
)

def test_get_by_course_id_returns_200(self):
assert self._get_by_course_id().status_code == status.HTTP_200_OK

def test_get_by_username_course_id_returns_200(self):
assert self._get_by_username_and_course_id().status_code == status.HTTP_200_OK

def test_get_response_top_level_fields(self):
"""Response contains the expected top-level enrollment fields."""
body = self._get_by_course_id().json()
for field in ('created', 'mode', 'is_active', 'user', 'course_details'):
assert field in body, f"Missing top-level field: {field}"

def test_get_response_user_and_mode(self):
"""user and mode values match the enrollment."""
body = self._get_by_course_id().json()
assert body['user'] == self.USERNAME
assert body['mode'] == CourseMode.DEFAULT_MODE_SLUG
assert body['is_active'] is True

def test_get_by_username_course_id_matches_by_course_id(self):
"""Both URL shapes return identical response bodies."""
by_course = self._get_by_course_id().json()
by_username = self._get_by_username_and_course_id().json()
assert by_course == by_username

def test_get_course_details_fields(self):
"""course_details contains the expected nested fields."""
course_details = self._get_by_course_id().json()['course_details']
for field in (
'course_id', 'course_name', 'enrollment_start', 'enrollment_end',
'course_start', 'course_end', 'invite_only', 'course_modes', 'pacing_type',
):
assert field in course_details, f"Missing course_details field: {field}"
assert course_details['course_id'] == str(self.course.id)

def test_get_no_enrollment_returns_null(self):
"""GET for a course the user never enrolled in returns HTTP 200 with a null body."""
unenrolled_course = CourseFactory.create(emit_signals=True)
resp = self.client.get(
reverse('courseenrollment', kwargs={'course_id': str(unenrolled_course.id)})
)
assert resp.status_code == status.HTTP_200_OK
assert resp.json() is None
75 changes: 37 additions & 38 deletions openedx/core/djangoapps/enrollments/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from openedx.core.djangoapps.enrollments.paginators import CourseEnrollmentsApiListPagination
from openedx.core.djangoapps.enrollments.serializers import (
CourseEnrollmentAllowedSerializer,
CourseEnrollmentSerializer,
CourseEnrollmentsApiListSerializer,
)
from openedx.core.djangoapps.user_api.accounts.permissions import CanRetireUser
Expand Down Expand Up @@ -187,6 +188,7 @@ class EnrollmentView(APIView, ApiKeyPermissionMixIn):
)
permission_classes = (ApiKeyHeaderPermissionIsAuthenticated,)
throttle_classes = (EnrollmentUserThrottle,)
serializer_class = CourseEnrollmentSerializer

# Since the course about page on the marketing site uses this API to auto-enroll users,
# we need to support cross-domain CSRF.
Expand Down Expand Up @@ -221,7 +223,17 @@ def get(self, request, course_id=None, username=None):
return Response(status=status.HTTP_404_NOT_FOUND)

try:
return Response(api.get_enrollment(username, course_id))
course_key = CourseKey.from_string(course_id)
except InvalidKeyError:
return Response(
status=status.HTTP_400_BAD_REQUEST,
data={"message": f"No course '{course_id}' found for enrollment"},
)

try:
enrollment = CourseEnrollment.objects.get(user__username=username, course_id=course_key)
except CourseEnrollment.DoesNotExist:
return Response(None)
except CourseEnrollmentError:
return Response(
status=status.HTTP_400_BAD_REQUEST,
Expand All @@ -233,6 +245,9 @@ def get(self, request, course_id=None, username=None):
},
)

serializer = CourseEnrollmentSerializer(enrollment)
return Response(serializer.data)


class EnrollmentUserRolesView(APIView):
"""
Expand Down Expand Up @@ -1087,12 +1102,9 @@ def get(self, request):
if not user_email:
user_email = request.user.email

enrollments_allowed = CourseEnrollmentAllowed.objects.filter(email=user_email) or []
serialized_enrollments_allowed = [
CourseEnrollmentAllowedSerializer(enrollment).data for enrollment in enrollments_allowed
]

return Response(status=status.HTTP_200_OK, data=serialized_enrollments_allowed)
enrollments_allowed = CourseEnrollmentAllowed.objects.filter(email=user_email)
serializer = CourseEnrollmentAllowedSerializer(enrollments_allowed, many=True)
return Response(status=status.HTTP_200_OK, data=serializer.data)

def post(self, request):
"""
Expand Down Expand Up @@ -1126,23 +1138,24 @@ def post(self, request):
- 403: Forbidden, you need to be staff.
- 409: Conflict, enrollment allowed already exists.
"""
is_bad_request_response, email, course_id = self.check_required_data(request)
auto_enroll = request.data.get("auto_enroll", False)
if is_bad_request_response:
return is_bad_request_response
serializer = CourseEnrollmentAllowedSerializer(data=request.data)
if not serializer.is_valid():
return Response(status=status.HTTP_400_BAD_REQUEST, data=serializer.errors)

try:
enrollment_allowed = CourseEnrollmentAllowed.objects.create(
email=email, course_id=course_id, auto_enroll=auto_enroll
)
enrollment_allowed = serializer.save()
except IntegrityError:
return Response(
status=status.HTTP_409_CONFLICT,
data={"message": f"An enrollment allowed with email {email} and course {course_id} already exists."},
data={
"message": (
f"An enrollment allowed with email {serializer.validated_data.get('email')} "
f"and course {serializer.validated_data.get('course_id')} already exists."
)
},
)

serializer = CourseEnrollmentAllowedSerializer(enrollment_allowed)
return Response(status=status.HTTP_201_CREATED, data=serializer.data)
return Response(status=status.HTTP_201_CREATED, data=CourseEnrollmentAllowedSerializer(enrollment_allowed).data)

def delete(self, request):
"""
Expand Down Expand Up @@ -1174,32 +1187,18 @@ def delete(self, request):
- 403: Forbidden, you need to be staff.
- 404: Not found, the course enrollment allowed doesn't exists.
"""
is_bad_request_response, email, course_id = self.check_required_data(request)
if is_bad_request_response:
return is_bad_request_response
serializer = CourseEnrollmentAllowedSerializer(data=request.data)
if not serializer.is_valid():
return Response(status=status.HTTP_400_BAD_REQUEST, data=serializer.errors)

email = serializer.validated_data.get("email")
course_id = serializer.validated_data.get("course_id")

try:
CourseEnrollmentAllowed.objects.get(email=email, course_id=course_id).delete()
return Response(
status=status.HTTP_204_NO_CONTENT,
)
return Response(status=status.HTTP_204_NO_CONTENT)
except ObjectDoesNotExist:
return Response(
status=status.HTTP_404_NOT_FOUND,
data={"message": f"An enrollment allowed with email {email} and course {course_id} doesn't exists."},
)

def check_required_data(self, request):
"""
Check if the request has email and course_id.
"""
email = request.data.get("email")
course_id = request.data.get("course_id")
if not email or not course_id:
is_bad_request = Response(
status=status.HTTP_400_BAD_REQUEST,
data={"message": "Please provide a value for 'email' and 'course_id' in the request data."},
)
else:
is_bad_request = None
return (is_bad_request, email, course_id)
Loading