diff --git a/README.md b/README.md index 3bfde01..0616df9 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,8 @@ Finally include the `rest_framework_docs` urls in your `urls.py`: You can find detailed information about the package's settings at [the docs](http://drfdocs.com/settings/). REST_FRAMEWORK_DOCS = { - 'HIDE_DOCS': True # Default: False + 'HIDE_DOCS': True, # Default: False + 'LOGIN_REQUIRED': True, # Default: True } diff --git a/demo/project/accounts/serializers.py b/demo/project/accounts/serializers.py index e4b4cb8..067cf64 100644 --- a/demo/project/accounts/serializers.py +++ b/demo/project/accounts/serializers.py @@ -1,5 +1,6 @@ from rest_framework import serializers from project.accounts.models import User +from rest_framework.authtoken.serializers import AuthTokenSerializer class UserRegistrationSerializer(serializers.ModelSerializer): @@ -30,3 +31,12 @@ class Meta: model = User fields = ('id', 'token', 'password',) extra_kwargs = {'password': {'write_only': True}} + + +class NestedSerializer(serializers.Serializer): + nb_test = serializers.IntegerField(default=0, required=False) + liste_codes = serializers.ListField(child=serializers.CharField()) + + +class CustomAuthTokenSerializer(AuthTokenSerializer): + nested = NestedSerializer(many=True) diff --git a/demo/project/accounts/urls.py b/demo/project/accounts/urls.py index 1486675..003fa41 100644 --- a/demo/project/accounts/urls.py +++ b/demo/project/accounts/urls.py @@ -1,15 +1,23 @@ +import django from django.conf.urls import url from project.accounts import views +from rest_framework.routers import SimpleRouter -urlpatterns = [ - url(r'^test/$', views.TestView.as_view(), name="test-view"), +account_router = SimpleRouter() +account_router.register('user-model-viewsets', views.UserModelViewset, base_name='account') +account_urlpatterns = [ + url(r'^test/$', views.TestView.as_view(), name="test-view"), url(r'^login/$', views.LoginView.as_view(), name="login"), url(r'^register/$', views.UserRegistrationView.as_view(), name="register"), url(r'^reset-password/$', view=views.PasswordResetView.as_view(), name="reset-password"), url(r'^reset-password/confirm/$', views.PasswordResetConfirmView.as_view(), name="reset-password-confirm"), - url(r'^user/profile/$', views.UserProfileView.as_view(), name="profile"), +] + account_router.urls -] +# Django 1.9 Support for the app_name argument is deprecated +# https://docs.djangoproject.com/en/1.9/ref/urls/#include +django_version = django.VERSION +if django.VERSION[:2] >= (1, 9, ): + account_urlpatterns = (account_urlpatterns, 'accounts', ) diff --git a/demo/project/accounts/views.py b/demo/project/accounts/views.py index e1bd9c0..0c0ad0f 100644 --- a/demo/project/accounts/views.py +++ b/demo/project/accounts/views.py @@ -2,13 +2,13 @@ from django.views.generic.base import TemplateView from rest_framework import parsers, renderers, generics, status from rest_framework.authtoken.models import Token -from rest_framework.authtoken.serializers import AuthTokenSerializer from rest_framework.permissions import AllowAny from rest_framework.response import Response from rest_framework.views import APIView +from rest_framework.viewsets import ModelViewSet from project.accounts.models import User from project.accounts.serializers import ( - UserRegistrationSerializer, UserProfileSerializer, ResetPasswordSerializer + UserRegistrationSerializer, UserProfileSerializer, ResetPasswordSerializer, CustomAuthTokenSerializer ) @@ -28,7 +28,7 @@ class LoginView(APIView): permission_classes = () parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) renderer_classes = (renderers.JSONRenderer,) - serializer_class = AuthTokenSerializer + serializer_class = CustomAuthTokenSerializer def post(self, request): serializer = self.serializer_class(data=request.data) @@ -81,3 +81,8 @@ def post(self, request, *args, **kwargs): if not serializer.is_valid(): return Response({'errors': serializer.errors}, status=status.HTTP_400_BAD_REQUEST) return Response({"msg": "Password updated successfully."}, status=status.HTTP_200_OK) + + +class UserModelViewset(ModelViewSet): + queryset = User.objects.all() + serializer_class = UserProfileSerializer diff --git a/demo/project/organisations/urls.py b/demo/project/organisations/urls.py index 4d0311e..bc75484 100644 --- a/demo/project/organisations/urls.py +++ b/demo/project/organisations/urls.py @@ -1,12 +1,27 @@ +import django from django.conf.urls import url from project.organisations import views +from rest_framework.routers import SimpleRouter +from .views import OrganisationModelViewset -urlpatterns = [ +organisation_router = SimpleRouter() +organisation_router.register('organisation-model-viewsets', OrganisationModelViewset, base_name='organisation') +organisations_urlpatterns = [ url(r'^create/$', view=views.CreateOrganisationView.as_view(), name="create"), url(r'^(?P[\w-]+)/$', view=views.RetrieveOrganisationView.as_view(), name="organisation"), url(r'^(?P[\w-]+)/members/$', view=views.OrganisationMembersView.as_view(), name="members"), url(r'^(?P[\w-]+)/leave/$', view=views.LeaveOrganisationView.as_view(), name="leave") +] + organisation_router.urls +members_urlpatterns = [ + url(r'^(?P[\w-]+)/members/$', view=views.OrganisationMembersView.as_view(), name="members"), ] + +# Django 1.9 Support for the app_name argument is deprecated +# https://docs.djangoproject.com/en/1.9/ref/urls/#include +django_version = django.VERSION +if django.VERSION[:2] >= (1, 9, ): + organisations_urlpatterns = (organisations_urlpatterns, 'organisations', ) + members_urlpatterns = (members_urlpatterns, 'organisations', ) diff --git a/demo/project/organisations/views.py b/demo/project/organisations/views.py index 1e2d5fb..a3c8901 100644 --- a/demo/project/organisations/views.py +++ b/demo/project/organisations/views.py @@ -1,5 +1,6 @@ from rest_framework import generics, status from rest_framework.response import Response +from rest_framework.viewsets import ModelViewSet from project.organisations.models import Organisation, Membership from project.organisations.serializers import ( CreateOrganisationSerializer, OrganisationMembersSerializer, RetrieveOrganisationSerializer @@ -34,3 +35,8 @@ def delete(self, request, *args, **kwargs): instance = self.get_object() self.perform_destroy(instance) return Response(status=status.HTTP_204_NO_CONTENT) + + +class OrganisationModelViewset(ModelViewSet): + queryset = Organisation.objects.all() + serializer_class = OrganisationMembersSerializer diff --git a/demo/project/settings.py b/demo/project/settings.py index 5e06207..0c33d3d 100644 --- a/demo/project/settings.py +++ b/demo/project/settings.py @@ -43,7 +43,6 @@ 'project.accounts', 'project.organisations', - ) MIDDLEWARE_CLASSES = ( diff --git a/demo/project/urls.py b/demo/project/urls.py index d8e049f..bc5873c 100644 --- a/demo/project/urls.py +++ b/demo/project/urls.py @@ -13,14 +13,36 @@ 1. Add an import: from blog import urls as blog_urls 2. Add a URL to urlpatterns: url(r'^blog/', include(blog_urls)) """ +import django from django.conf.urls import include, url from django.contrib import admin +from rest_framework_docs.views import DRFDocsView +from .accounts.urls import account_urlpatterns, account_router +from .organisations.urls import organisations_urlpatterns, members_urlpatterns, organisation_router urlpatterns = [ url(r'^admin/', include(admin.site.urls)), - url(r'^docs/', include('rest_framework_docs.urls')), - - # API - url(r'^accounts/', view=include('project.accounts.urls', namespace='accounts')), - url(r'^organisations/', view=include('project.organisations.urls', namespace='organisations')), ] + +# Django 1.9 Support for the app_name argument is deprecated +# https://docs.djangoproject.com/en/1.9/ref/urls/#include +django_version = django.VERSION +if django.VERSION[:2] >= (1, 9, ): + urlpatterns.extend([ + url(r'^accounts/', view=include(account_urlpatterns, namespace='accounts')), + url(r'^organisations/', view=include(organisations_urlpatterns, namespace='organisations')), + url(r'^members/', view=include(members_urlpatterns, namespace='members')), + ]) +else: + urlpatterns.extend([ + url(r'^accounts/', view=include(account_urlpatterns, namespace='accounts', app_name='account_app')), + url(r'^organisations/', view=include(organisations_urlpatterns, namespace='organisations', app_name='organisations_app')), + url(r'^members/', view=include(members_urlpatterns, namespace='members', app_name='organisations_app')), + ]) + + +routers = [account_router, organisation_router] +urlpatterns.extend([ + url(r'^docs/(?P[\w-]+)/$', DRFDocsView.as_view(drf_router=routers), name='drfdocs-filter'), + url(r'^docs/$', DRFDocsView.as_view(drf_router=routers), name='drfdocs'), +]) diff --git a/requirements.txt b/requirements.txt index f449bbb..b799525 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -Django==1.8.7 -djangorestframework==3.3.2 -coverage==4.0.3 -flake8==2.5.1 +Django==2.1.0 +djangorestframework==3.7.7 +coverage==4.2 +flake8<3.0.0 mkdocs==0.15.3 diff --git a/rest_framework_docs/__init__.py b/rest_framework_docs/__init__.py index ad3cf1d..0f4c4de 100644 --- a/rest_framework_docs/__init__.py +++ b/rest_framework_docs/__init__.py @@ -1 +1,3 @@ __version__ = '0.0.11' + +SERIALIZER_FIELDS = {} diff --git a/rest_framework_docs/api_docs.py b/rest_framework_docs/api_docs.py index d22dd4c..2f2d88d 100644 --- a/rest_framework_docs/api_docs.py +++ b/rest_framework_docs/api_docs.py @@ -1,14 +1,20 @@ +from operator import attrgetter from importlib import import_module from django.conf import settings -from django.core.urlresolvers import RegexURLResolver, RegexURLPattern +from django.urls import URLResolver, URLPattern from django.utils.module_loading import import_string from rest_framework.views import APIView +from rest_framework_docs import SERIALIZER_FIELDS from rest_framework_docs.api_endpoint import ApiEndpoint class ApiDocumentation(object): - def __init__(self, drf_router=None): + def __init__(self, drf_router=None, filter_param=None): + """ + :param filter_param: namespace or app_name + """ + SERIALIZER_FIELDS.clear() self.endpoints = [] self.drf_router = drf_router try: @@ -17,18 +23,19 @@ def __init__(self, drf_router=None): # Handle a case when there's no dot in ROOT_URLCONF root_urlconf = import_module(settings.ROOT_URLCONF) if hasattr(root_urlconf, 'urls'): - self.get_all_view_names(root_urlconf.urls.urlpatterns) + self.get_all_view_names(root_urlconf.urls.urlpatterns, filter_param=filter_param) else: - self.get_all_view_names(root_urlconf.urlpatterns) + self.get_all_view_names(root_urlconf.urlpatterns, filter_param=filter_param) - def get_all_view_names(self, urlpatterns, parent_pattern=None): + def get_all_view_names(self, urlpatterns, parent_pattern=None, filter_param=None): for pattern in urlpatterns: - if isinstance(pattern, RegexURLResolver): - parent_pattern = None if pattern._regex == "^" else pattern - self.get_all_view_names(urlpatterns=pattern.url_patterns, parent_pattern=parent_pattern) - elif isinstance(pattern, RegexURLPattern) and self._is_drf_view(pattern) and not self._is_format_endpoint(pattern): - api_endpoint = ApiEndpoint(pattern, parent_pattern, self.drf_router) - self.endpoints.append(api_endpoint) + if isinstance(pattern, URLResolver) and (not filter_param or filter_param in [pattern.app_name, pattern.namespace]): + # parent_pattern = None if pattern._regex == "^" else pattern + self.get_all_view_names(urlpatterns=pattern.url_patterns, parent_pattern=None if pattern.pattern.regex.pattern == "^" else pattern, filter_param=filter_param) + elif isinstance(pattern, URLPattern) and self._is_drf_view(pattern) and not self._is_format_endpoint(pattern): + if not filter_param or parent_pattern: + api_endpoint = ApiEndpoint(pattern, parent_pattern, self.drf_router) + self.endpoints.append(api_endpoint) def _is_drf_view(self, pattern): """ @@ -40,7 +47,7 @@ def _is_format_endpoint(self, pattern): """ Exclude endpoints with a "format" parameter """ - return '?P' in pattern._regex + return '?P' in pattern.pattern.regex.pattern def get_endpoints(self): - return self.endpoints + return sorted(self.endpoints, key=attrgetter('name', 'path')) diff --git a/rest_framework_docs/api_endpoint.py b/rest_framework_docs/api_endpoint.py index 89a33f8..3f15a0a 100644 --- a/rest_framework_docs/api_endpoint.py +++ b/rest_framework_docs/api_endpoint.py @@ -2,50 +2,62 @@ import inspect from django.contrib.admindocs.views import simplify_regex from django.utils.encoding import force_str -from rest_framework.serializers import BaseSerializer +from rest_framework import serializers +from rest_framework.viewsets import ModelViewSet +from rest_framework_docs import SERIALIZER_FIELDS class ApiEndpoint(object): def __init__(self, pattern, parent_pattern=None, drf_router=None): - self.drf_router = drf_router + self.drf_router = drf_router or [] + if not isinstance(self.drf_router, (list, tuple)): + self.drf_router = [self.drf_router] self.pattern = pattern self.callback = pattern.callback - # self.name = pattern.name self.docstring = self.__get_docstring__() - self.name_parent = simplify_regex(parent_pattern.regex.pattern).strip('/') if parent_pattern else None + if parent_pattern: + self.name_parent = parent_pattern.namespace or parent_pattern.app_name or \ + simplify_regex(parent_pattern.pattern.regex.pattern).replace('/', '-') + self.name = self.name_parent + if hasattr(pattern.callback, 'cls') and issubclass(pattern.callback.cls, ModelViewSet): + self.name = '%s (RESTful)' % self.name_parent + else: + self.name_parent = '' + self.name = '' + # self.labels = (self.name_parent, self.name, slugify(self.name)) + self.labels = dict(parent=self.name_parent, name=self.name) self.path = self.__get_path__(parent_pattern) self.allowed_methods = self.__get_allowed_methods__() - # self.view_name = pattern.callback.__name__ self.errors = None self.serializer_class = self.__get_serializer_class__() if self.serializer_class: self.serializer = self.__get_serializer__() - self.fields = self.__get_serializer_fields__(self.serializer) + self.fields = self.__get_serializer_fields__() self.fields_json = self.__get_serializer_fields_json__() - self.permissions = self.__get_permissions_class__() def __get_path__(self, parent_pattern): if parent_pattern: - return "/{0}{1}".format(self.name_parent, simplify_regex(self.pattern.regex.pattern)) - return simplify_regex(self.pattern.regex.pattern) + parent_regex = parent_pattern.pattern.regex.pattern + return simplify_regex(parent_regex + self.pattern.pattern.regex.pattern) + return simplify_regex(self.pattern.pattern.regex.pattern) def __get_allowed_methods__(self): viewset_methods = [] - if self.drf_router: - for prefix, viewset, basename in self.drf_router.registry: + for router in self.drf_router: + for prefix, viewset, basename in router.registry: if self.callback.cls != viewset: continue - lookup = self.drf_router.get_lookup_regex(viewset) - routes = self.drf_router.get_routes(viewset) + lookup = router.get_lookup_regex(viewset) + routes = router.get_routes(viewset) for route in routes: # Only actions which actually exist on the viewset will be bound - mapping = self.drf_router.get_method_map(viewset, route.mapping) + mapping = router.get_method_map(viewset, route.mapping) if not mapping: continue @@ -53,9 +65,9 @@ def __get_allowed_methods__(self): regex = route.url.format( prefix=prefix, lookup=lookup, - trailing_slash=self.drf_router.trailing_slash + trailing_slash=router.trailing_slash ) - if self.pattern.regex.pattern == regex: + if self.pattern.pattern.regex.pattern == regex: funcs, viewset_methods = zip( *[(mapping[m], m.upper()) for m in self.callback.cls.http_method_names if m in mapping] ) @@ -63,7 +75,8 @@ def __get_allowed_methods__(self): if len(set(funcs)) == 1: self.docstring = inspect.getdoc(getattr(self.callback.cls, funcs[0])) - view_methods = [force_str(m).upper() for m in self.callback.cls.http_method_names if hasattr(self.callback.cls, m)] + view_methods = [force_str(m).upper() for m in self.callback.cls.http_method_names if + hasattr(self.callback.cls, m) or m in getattr(self.callback, 'actions', {})] return viewset_methods + view_methods def __get_docstring__(self): @@ -86,31 +99,51 @@ def __get_serializer_class__(self): if hasattr(self.callback.cls, 'get_serializer_class'): return self.callback.cls.get_serializer_class(self.pattern.callback.cls()) - def __get_serializer_fields__(self, serializer): + def __get_serializer_fields__(self): fields = [] + serializer = None + + if hasattr(self.callback.cls, 'serializer_class'): + serializer = self.callback.cls.serializer_class + + elif hasattr(self.callback.cls, 'get_serializer_class'): + serializer = self.callback.cls.get_serializer_class(self.pattern.callback.cls()) if hasattr(serializer, 'get_fields'): - for key, field in serializer.get_fields().items(): - to_many_relation = True if hasattr(field, 'many') else False - sub_fields = [] - - if to_many_relation: - sub_fields = self.__get_serializer_fields__(field.child) if isinstance(field, BaseSerializer) else None - else: - sub_fields = self.__get_serializer_fields__(field) if isinstance(field, BaseSerializer) else None - - fields.append({ - "name": key, - "type": str(field.__class__.__name__), - "sub_fields": sub_fields, - "required": field.required, - "to_many_relation": to_many_relation - }) - # FIXME: - # Show more attibutes of `field`? + try: + fields = self.__get_fields__(serializer) + except KeyError as e: + self.errors = e + fields = [] return fields + def __get_fields__(self, serializer): + if serializer in SERIALIZER_FIELDS: + return SERIALIZER_FIELDS.get(serializer) + + fields = [] + for key, field in serializer().get_fields().items(): + item = dict( + name=key, + type=str(field.__class__.__name__), + required=field.required + ) + + # Nested/List serializer + if isinstance(field, (serializers.ListSerializer, serializers.ListField)): + sub_type = field.child.__class__ + item['sub_type'] = str(sub_type.__name__) + if isinstance(sub_type(), serializers.Serializer): + item['fields'] = self.__get_fields__(sub_type) + elif isinstance(field, serializers.Serializer): + item['fields'] = self.__get_fields__(field.__class__) + fields.append(item) + + # Keep a copy of serializer fields for optimization purposes + SERIALIZER_FIELDS[serializer] = fields + return fields + def __get_serializer_fields_json__(self): # FIXME: # Return JSON or not? diff --git a/rest_framework_docs/settings.py b/rest_framework_docs/settings.py index 2853a7b..c9804c0 100644 --- a/rest_framework_docs/settings.py +++ b/rest_framework_docs/settings.py @@ -5,7 +5,8 @@ class DRFSettings(object): def __init__(self): self.drf_settings = { - "HIDE_DOCS": self.get_setting("HIDE_DOCS") or False + "HIDE_DOCS": self.get_setting("HIDE_DOCS") or False, + "LOGIN_REQUIRED": self.get_setting("LOGIN_REQUIRED") or False, } def get_setting(self, name): diff --git a/rest_framework_docs/templates/rest_framework_docs/base.html b/rest_framework_docs/templates/rest_framework_docs/base.html index 5a2e6fc..61ae64b 100644 --- a/rest_framework_docs/templates/rest_framework_docs/base.html +++ b/rest_framework_docs/templates/rest_framework_docs/base.html @@ -55,10 +55,12 @@ {% block jumbotron %} + {% endblock %} {% block content %}{% endblock %} diff --git a/rest_framework_docs/templates/rest_framework_docs/fields.html b/rest_framework_docs/templates/rest_framework_docs/fields.html new file mode 100644 index 0000000..e99251e --- /dev/null +++ b/rest_framework_docs/templates/rest_framework_docs/fields.html @@ -0,0 +1,12 @@ +
    + {% for field in item.fields %} +
  • + {{ field.name }}: {{ field.type }} + {% if field.required %}R{% endif %} + {% if field.sub_type %} ({{ field.sub_type }}) {% endif %} + {% if field.fields %} + {% include 'rest_framework_docs/fields.html' with item=field only %} + {% endif %} +
  • + {% endfor %} +
diff --git a/rest_framework_docs/templates/rest_framework_docs/home.html b/rest_framework_docs/templates/rest_framework_docs/home.html index 235a6ee..546a95a 100644 --- a/rest_framework_docs/templates/rest_framework_docs/home.html +++ b/rest_framework_docs/templates/rest_framework_docs/home.html @@ -1,100 +1,106 @@ {% extends "rest_framework_docs/docs.html" %} {% block apps_menu %} -{% regroup endpoints by name_parent as endpoints_grouped %} +{% regroup endpoints by labels as endpoints_grouped %} +{% if endpoints_grouped|length > 1 %} +{% endif %} {% endblock %} {% block content %} +{% regroup endpoints by labels as endpoints_grouped %} +{% if endpoints_grouped %} +{% for group in endpoints_grouped %} +

+ {% if group.grouper.parent %} + {{ group.grouper.name }} + {% endif %} +

- {% regroup endpoints by name_parent as endpoints_grouped %} - - {% if endpoints_grouped %} - {% for group in endpoints_grouped %} - -

{{group.grouper}}

- -
+
{% for endpoint in group.list %} -
+
+
{% endfor %} -
+
- {% endfor %} - {% elif not query %} -

There are currently no api endpoints to document.

- {% else %} -

No endpoints found for {{ query }}.

- {% endif %} +{% endfor %} +{% elif not query %} +

There are currently no api endpoints to document.

+{% else %} +

No endpoints found for {{ query }}.

+{% endif %} - - {% endblock %} diff --git a/rest_framework_docs/urls.py b/rest_framework_docs/urls.py index beb1588..512aa04 100644 --- a/rest_framework_docs/urls.py +++ b/rest_framework_docs/urls.py @@ -1,7 +1,19 @@ from django.conf.urls import url + +from rest_framework_docs.settings import DRFSettings from rest_framework_docs.views import DRFDocsView + +settings = DRFSettings().settings +if settings["LOGIN_REQUIRED"]: + from django.contrib.auth.decorators import login_required + docs_view = login_required(DRFDocsView.as_view()) +else: + docs_view = DRFDocsView.as_view() + urlpatterns = [ # Url to view the API Docs - url(r'^$', DRFDocsView.as_view(), name='drfdocs'), + url(r'^$', docs_view, name='drfdocs'), + # Url to view the API Docs with a specific namespace or app_name + url(r'^(?P[\w-]+)/$', docs_view, name='drfdocs-filter'), ] diff --git a/rest_framework_docs/views.py b/rest_framework_docs/views.py index 50400d4..862eaa4 100644 --- a/rest_framework_docs/views.py +++ b/rest_framework_docs/views.py @@ -1,5 +1,6 @@ from django.http import Http404 from django.views.generic.base import TemplateView + from rest_framework_docs.api_docs import ApiDocumentation from rest_framework_docs.settings import DRFSettings @@ -9,13 +10,13 @@ class DRFDocsView(TemplateView): template_name = "rest_framework_docs/home.html" drf_router = None - def get_context_data(self, **kwargs): + def get_context_data(self, filter_param=None, **kwargs): settings = DRFSettings().settings if settings["HIDE_DOCS"]: raise Http404("Django Rest Framework Docs are hidden. Check your settings.") context = super(DRFDocsView, self).get_context_data(**kwargs) - docs = ApiDocumentation(drf_router=self.drf_router) + docs = ApiDocumentation(drf_router=self.drf_router, filter_param=filter_param) endpoints = docs.get_endpoints() query = self.request.GET.get("search", "") diff --git a/runtests.py b/runtests.py index c388477..4b06497 100644 --- a/runtests.py +++ b/runtests.py @@ -53,6 +53,7 @@ def run_tests_coverage(): cov.report() cov.html_report(directory='covhtml') + exit_on_failure(flake8_main(FLAKE8_ARGS)) exit_on_failure(run_tests_eslint()) exit_on_failure(run_tests_coverage()) diff --git a/tests/settings.py b/tests/settings.py index d80e9f2..7e59296 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -39,3 +39,11 @@ # https://docs.djangoproject.com/en/1.8/howto/static-files/ STATIC_URL = '/static/' + +TEMPLATES = [ + { + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': [], + 'APP_DIRS': True, + }, +] diff --git a/tests/tests.py b/tests/tests.py index 998faee..c0f95a6 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -6,7 +6,8 @@ class DRFDocsViewTests(TestCase): SETTINGS_HIDE_DOCS = { - 'HIDE_DOCS': True # Default: False + 'HIDE_DOCS': True, # Default: False + 'LOGIN_REQUIRED': False, # Default: False } def setUp(self): @@ -27,27 +28,21 @@ def test_index_view_with_endpoints(self): response = self.client.get(reverse('drfdocs')) self.assertEqual(response.status_code, 200) - self.assertEqual(len(response.context["endpoints"]), 15) + self.assertEqual(len(response.context["endpoints"]), 16) # Test the login view - self.assertEqual(response.context["endpoints"][0].name_parent, "accounts") - self.assertEqual(response.context["endpoints"][0].allowed_methods, ['POST', 'OPTIONS']) - self.assertEqual(response.context["endpoints"][0].path, "/accounts/login/") - self.assertEqual(response.context["endpoints"][0].docstring, "A view that allows users to login providing their username and password.") - self.assertEqual(len(response.context["endpoints"][0].fields), 2) - self.assertEqual(response.context["endpoints"][0].fields[0]["type"], "CharField") - self.assertTrue(response.context["endpoints"][0].fields[0]["required"]) - - self.assertEqual(response.context["endpoints"][1].name_parent, "accounts") - self.assertEqual(response.context["endpoints"][1].allowed_methods, ['POST', 'OPTIONS']) - self.assertEqual(response.context["endpoints"][1].path, "/accounts/login2/") - self.assertEqual(response.context["endpoints"][1].docstring, "A view that allows users to login providing their username and password. Without serializer_class") - self.assertEqual(len(response.context["endpoints"][1].fields), 2) - self.assertEqual(response.context["endpoints"][1].fields[0]["type"], "CharField") - self.assertTrue(response.context["endpoints"][1].fields[0]["required"]) + endpoint = response.context["endpoints"][4] + self.assertEqual(endpoint.name_parent, "accounts") + self.assertEqual(endpoint.allowed_methods, ['POST', 'OPTIONS']) + self.assertEqual(endpoint.path, "/accounts/login/") + self.assertEqual(endpoint.docstring, "A view that allows users to login providing their username and password.") + self.assertEqual(len(endpoint.fields), 2) + self.assertEqual(endpoint.fields[0]["type"], "CharField") + self.assertTrue(endpoint.fields[0]["required"]) # The view "OrganisationErroredView" (organisations/(?P[\w-]+)/errored/) should contain an error. - self.assertEqual(str(response.context["endpoints"][9].errors), "'test_value'") + endpoint = response.context["endpoints"][12] + self.assertEqual(str(endpoint.errors), "'test_value'") def test_index_search_with_endpoints(self): response = self.client.get("%s?search=reset-password" % reverse("drfdocs")) @@ -68,16 +63,77 @@ def test_index_view_docs_hidden(self): self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase.upper(), "NOT FOUND") - def test_model_viewset(self): - response = self.client.get(reverse('drfdocs')) + def test_index_view_with_existent_namespace(self): + """ + Should load the drf docs view with all the endpoints contained in the specified namespace. + NOTE: Views that do **not** inherit from DRF's "APIView" are not included. + """ + # Test 'accounts' namespace + response = self.client.get(reverse('drfdocs-filter', args=['accounts'])) + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.context["endpoints"]), 6) + + # Test the login view + self.assertEqual(response.context["endpoints"][0].name_parent, "accounts") + self.assertEqual(response.context["endpoints"][0].allowed_methods, ['POST', 'OPTIONS']) + self.assertEqual(response.context["endpoints"][0].path, "/accounts/login/") + + # Test 'organisations' namespace + response = self.client.get(reverse('drfdocs-filter', args=['organisations'])) + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.context["endpoints"]), 5) + + # The view "OrganisationErroredView" (organisations/(?P[\w-]+)/errored/) should contain an error. + self.assertEqual(str(response.context["endpoints"][1].errors), "'test_value'") + # Test 'members' namespace + response = self.client.get(reverse('drfdocs-filter', args=['members'])) self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.context["endpoints"]), 1) - self.assertEqual(response.context["endpoints"][10].path, '/organisations//') - self.assertEqual(response.context['endpoints'][6].fields[2]['to_many_relation'], True) - self.assertEqual(response.context["endpoints"][11].path, '/organisation-model-viewsets/') - self.assertEqual(response.context["endpoints"][12].path, '/organisation-model-viewsets//') - self.assertEqual(response.context["endpoints"][11].allowed_methods, ['GET', 'POST', 'OPTIONS']) - self.assertEqual(response.context["endpoints"][12].allowed_methods, ['GET', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']) - self.assertEqual(response.context["endpoints"][13].allowed_methods, ['POST', 'OPTIONS']) - self.assertEqual(response.context["endpoints"][13].docstring, 'This is a test.') + def test_index_search_with_existent_namespace(self): + response = self.client.get("%s?search=reset-password" % reverse('drfdocs-filter', args=['accounts'])) + + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.context["endpoints"]), 2) + self.assertEqual(response.context["endpoints"][1].path, "/accounts/reset-password/confirm/") + self.assertEqual(len(response.context["endpoints"][1].fields), 3) + + def test_index_view_with_existent_app_name(self): + """ + Should load the drf docs view with all the endpoints contained in the specified app_name. + NOTE: Views that do **not** inherit from DRF's "APIView" are not included. + """ + # Test 'organisations_app' app_name + response = self.client.get(reverse('drfdocs-filter', args=['organisations_app'])) + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.context["endpoints"]), 6) + parents_name = [e.name_parent for e in response.context["endpoints"]] + self.assertEquals(parents_name.count('organisations'), 5) + self.assertEquals(parents_name.count('members'), 1) + + def test_index_search_with_existent_app_name(self): + response = self.client.get("%s?search=create" % reverse('drfdocs-filter', args=['organisations_app'])) + + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.context["endpoints"]), 1) + self.assertEqual(response.context["endpoints"][0].path, "/organisations/create/") + self.assertEqual(len(response.context["endpoints"][0].fields), 3) + + def test_index_view_with_non_existent_namespace_or_app_name(self): + """ + Should load the drf docs view with no endpoint. + """ + response = self.client.get(reverse('drfdocs-filter', args=['non-existent-ns-or-app-name'])) + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.context["endpoints"]), 0) + + def test_model_viewset(self): + response = self.client.get(reverse('drfdocs')) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.context["endpoints"][1].path, '/organisation-model-viewsets/') + self.assertEqual(response.context["endpoints"][2].path, '/organisation-model-viewsets//') + self.assertEqual(response.context["endpoints"][1].allowed_methods, ['GET', 'POST', 'OPTIONS']) + self.assertEqual(response.context["endpoints"][2].allowed_methods, ['GET', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']) + self.assertEqual(response.context["endpoints"][4].allowed_methods, ['POST', 'OPTIONS']) + self.assertEqual(response.context["endpoints"][3].docstring, 'This is a test.') diff --git a/tests/urls.py b/tests/urls.py index abdf71b..a4cb5b7 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,11 +1,13 @@ from __future__ import absolute_import, division, print_function +import django from django.conf.urls import include, url from django.contrib import admin from rest_framework.routers import SimpleRouter from rest_framework_docs.views import DRFDocsView from tests import views + accounts_urls = [ url(r'^login/$', views.LoginView.as_view(), name="login"), url(r'^login2/$', views.LoginWithSerilaizerClassView.as_view(), name="login2"), @@ -29,15 +31,40 @@ router = SimpleRouter() router.register('organisation-model-viewsets', views.TestModelViewSet, base_name='organisation') +members_urls = [ + url(r'^(?P[\w-]+)/members/$', view=views.OrganisationMembersView.as_view(), name="members"), +] + urlpatterns = [ url(r'^admin/', include(admin.site.urls)), - url(r'^docs/', DRFDocsView.as_view(drf_router=router), name='drfdocs'), + + # url(r'^docs/', include('rest_framework_docs.urls')), + url(r'^docs/(?P[\w-]+)/$', DRFDocsView.as_view(drf_router=router), name='drfdocs-filter'), + url(r'^docs/$', DRFDocsView.as_view(drf_router=router), name='drfdocs'), # API - url(r'^accounts/', view=include(accounts_urls, namespace='accounts')), - url(r'^organisations/', view=include(organisations_urls, namespace='organisations')), + # url(r'^accounts/', view=include(accounts_urls, namespace="accounts")), + # url(r'^organisations/', view=include(organisations_urls, namespace='organisations')), url(r'^', include(router.urls)), # Endpoints without parents/namespaces url(r'^another-login/$', views.LoginView.as_view(), name="login"), ] + +# Django 1.9 Support for the app_name argument is deprecated +# https://docs.djangoproject.com/en/1.9/ref/urls/#include +django_version = django.VERSION +if django.VERSION[:2] >= (1, 9, ): + organisations_urls = (organisations_urls, 'organisations_app', ) + members_urls = (members_urls, 'organisations_app', ) + urlpatterns.extend([ + url(r'^accounts/', view=include(accounts_urls, namespace="accounts")), + url(r'^organisations/', view=include(organisations_urls, namespace='organisations')), + url(r'^members/', view=include(members_urls, namespace='members')), + ]) +else: + urlpatterns.extend([ + url(r'^accounts/', view=include(accounts_urls, namespace="accounts", app_name='accounts_app')), + url(r'^organisations/', view=include(organisations_urls, namespace='organisations', app_name='organisations_app')), + url(r'^members/', view=include(members_urls, namespace='members', app_name='organisations_app')), + ])