Skip to content

Commit 3cc1cb5

Browse files
committed
Add support of CORS
1 parent dbb6360 commit 3cc1cb5

File tree

3 files changed

+398
-0
lines changed

3 files changed

+398
-0
lines changed

src/cowboy_req.erl

+126
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767
-export([has_resp_header/2]).
6868
-export([has_resp_body/1]).
6969
-export([delete_resp_header/2]).
70+
-export([set_cors_headers/2]).
71+
-export([set_cors_preflight_headers/2]).
7072
-export([reply/2]).
7173
-export([reply/3]).
7274
-export([reply/4]).
@@ -305,6 +307,8 @@ parse_header_fun(<<"accept">>) -> fun cow_http_hd:parse_accept/1;
305307
parse_header_fun(<<"accept-charset">>) -> fun cow_http_hd:parse_accept_charset/1;
306308
parse_header_fun(<<"accept-encoding">>) -> fun cow_http_hd:parse_accept_encoding/1;
307309
parse_header_fun(<<"accept-language">>) -> fun cow_http_hd:parse_accept_language/1;
310+
parse_header_fun(<<"access-control-request-headers">>) -> fun cow_http_hd:parse_access_control_request_headers/1;
311+
parse_header_fun(<<"access-control-request-method">>) -> fun cow_http_hd:parse_access_control_request_method/1;
308312
parse_header_fun(<<"authorization">>) -> fun cow_http_hd:parse_authorization/1;
309313
parse_header_fun(<<"connection">>) -> fun cow_http_hd:parse_connection/1;
310314
parse_header_fun(<<"content-length">>) -> fun cow_http_hd:parse_content_length/1;
@@ -315,6 +319,7 @@ parse_header_fun(<<"if-match">>) -> fun cow_http_hd:parse_if_match/1;
315319
parse_header_fun(<<"if-modified-since">>) -> fun cow_http_hd:parse_if_modified_since/1;
316320
parse_header_fun(<<"if-none-match">>) -> fun cow_http_hd:parse_if_none_match/1;
317321
parse_header_fun(<<"if-unmodified-since">>) -> fun cow_http_hd:parse_if_unmodified_since/1;
322+
parse_header_fun(<<"origin">>) -> fun cow_http_hd:parse_origin/1;
318323
parse_header_fun(<<"range">>) -> fun cow_http_hd:parse_range/1;
319324
parse_header_fun(<<"sec-websocket-extensions">>) -> fun cow_http_hd:parse_sec_websocket_extensions/1;
320325
parse_header_fun(<<"sec-websocket-protocol">>) -> fun cow_http_hd:parse_sec_websocket_protocol_req/1;
@@ -666,6 +671,126 @@ delete_resp_header(Name, Req=#http_req{resp_headers=RespHeaders}) ->
666671
RespHeaders2 = lists:keydelete(Name, 1, RespHeaders),
667672
Req#http_req{resp_headers=RespHeaders2}.
668673

674+
-spec set_cors_headers(map(), Req) -> Req when Req :: req().
675+
set_cors_headers(M, Req) ->
676+
try
677+
AllowedOrigins = maps:get(origins, M, []),
678+
Origin =
679+
match_cors_origin(
680+
%% Validating each origin in the list, picking up the first.
681+
case parse_header(<<"origin">>, Req) of
682+
undefined -> throw({bad_origin, undefined, AllowedOrigins});
683+
[H|T] -> _ = [match_cors_origin(Val, AllowedOrigins) || Val <- T], H;
684+
L -> throw({bad_origin, L, AllowedOrigins})
685+
end,
686+
AllowedOrigins),
687+
688+
Req2 = set_cors_allow_credentials(maps:get(credentials, M, false), Origin, Req),
689+
set_cors_exposed_headers(maps:get(exposed_headers, M, []), Req2)
690+
catch throw:_Reason ->
691+
Req
692+
end.
693+
694+
-spec set_cors_preflight_headers(map(), Req) -> Req when Req :: req().
695+
set_cors_preflight_headers(M, Req) ->
696+
try
697+
AllowedOrigins = maps:get(origins, M, []),
698+
Origin =
699+
match_cors_origin(
700+
%% The Origin header can only contain a single origin as the user agent will not follow redirects.
701+
case parse_header(<<"origin">>, Req) of
702+
undefined -> throw({bad_origin, undefined, AllowedOrigins});
703+
[H] -> H;
704+
L -> throw({bad_origin, L, AllowedOrigins})
705+
end,
706+
AllowedOrigins),
707+
Method =
708+
match_cors_method(
709+
parse_header(<<"access-control-request-method">>, Req),
710+
maps:get(methods, M, [])),
711+
Headers =
712+
match_cors_headers(
713+
parse_header(<<"access-control-request-headers">>, Req, []),
714+
maps:get(headers, M, [])),
715+
716+
Req2 = set_cors_allow_credentials(maps:get(credentials, M, false), Origin, Req),
717+
Req3 = set_cors_max_age(maps:get(max_age, M, undefined), Req2),
718+
Req4 = set_cors_allowed_methods([Method], Req3),
719+
set_cors_allowed_headers(Headers, Req4)
720+
catch throw:_Reason ->
721+
Req
722+
end.
723+
724+
-spec set_cors_allow_credentials(boolean(), {binary(), binary(), 0..65535} | reference(), Req) -> Req when Req :: req().
725+
set_cors_allow_credentials(Credentials, Origin, Req) ->
726+
case match_cors_credentials(Credentials, Origin) of
727+
true ->
728+
Req2 = set_resp_header(<<"access-control-allow-origin">>, cow_http_hd:access_control_allow_origin(Origin), Req),
729+
set_resp_header(<<"access-control-allow-credentials">>, cow_http_hd:access_control_allow_credentials(), Req2);
730+
_ ->
731+
set_resp_header(<<"access-control-allow-origin">>, cow_http_hd:access_control_allow_origin(Origin), Req)
732+
end.
733+
734+
-spec set_cors_max_age(non_neg_integer() | undefined, Req) -> Req when Req :: req().
735+
set_cors_max_age(undefined, Req) ->
736+
Req;
737+
set_cors_max_age(Val, Req) ->
738+
set_resp_header(<<"access-control-max-age">>, cow_http_hd:access_control_max_age(Val), Req).
739+
740+
-spec set_cors_allowed_methods([binary()], Req) -> Req when Req :: req().
741+
set_cors_allowed_methods(L, Req) ->
742+
set_resp_header(<<"access-control-allow-methods">>, cow_http_hd:access_control_allow_methods(L), Req).
743+
744+
-spec set_cors_allowed_headers([binary()], Req) -> Req when Req :: req().
745+
set_cors_allowed_headers([], Req) ->
746+
Req;
747+
set_cors_allowed_headers(L, Req) ->
748+
set_resp_header(<<"access-control-allow-headers">>, cow_http_hd:access_control_allow_headers(L), Req).
749+
750+
-spec set_cors_exposed_headers([binary()], Req) -> Req when Req :: req().
751+
set_cors_exposed_headers([], Req) ->
752+
Req;
753+
set_cors_exposed_headers(L, Req) ->
754+
set_resp_header(<<"access-control-expose-headers">>, cow_http_hd:access_control_expose_headers(L), Req).
755+
756+
-spec match_cors_origin(Origin | reference(), [Origin] | Origin | '*')
757+
-> Origin | '*' when Origin :: {binary(), binary(), 0..65535}.
758+
match_cors_origin(Val, '*') when is_reference(Val) ->
759+
'*';
760+
match_cors_origin(Val, '*') ->
761+
Val;
762+
match_cors_origin(Val, Val) ->
763+
Val;
764+
match_cors_origin(Val, AllowedOrigins) when is_list(AllowedOrigins) ->
765+
case lists:member(Val, AllowedOrigins) of
766+
true -> Val;
767+
_ -> throw({nomatch_origin, Val, AllowedOrigins})
768+
end;
769+
match_cors_origin(Val, AllowedOrigins) ->
770+
throw({nomatch_origin, Val, AllowedOrigins}).
771+
772+
-spec match_cors_method(binary() | undefined, [binary()]) -> binary().
773+
match_cors_method(undefined, Methods) ->
774+
throw({bad_method, undefined, Methods});
775+
match_cors_method(Val, AllowedMethods) ->
776+
case lists:member(Val, AllowedMethods) of
777+
true -> Val;
778+
_ -> throw({nomatch_method, Val, AllowedMethods})
779+
end.
780+
781+
-spec match_cors_headers([binary()], [binary()]) -> [binary()].
782+
match_cors_headers(L, AllowedHeaders) ->
783+
[case lists:member(Header, AllowedHeaders) of
784+
false -> throw({nomatch_header, Header, AllowedHeaders});
785+
_ -> Header
786+
end || Header <- L].
787+
788+
-spec match_cors_credentials(boolean(), {binary(), binary(), 0..65535} | reference() | '*') -> boolean().
789+
match_cors_credentials(true, '*') ->
790+
throw({bad_credentials, true, '*'});
791+
match_cors_credentials(Val, _) ->
792+
Val.
793+
669794
-spec reply(cowboy:http_status(), Req) -> Req when Req::req().
670795
reply(Status, Req=#http_req{resp_body=Body}) ->
671796
reply(Status, [], Body, Req).
@@ -1298,4 +1423,5 @@ merge_headers_test_() ->
12981423
{<<"server">>,<<"Cowboy">>}]}
12991424
],
13001425
[fun() -> Res = merge_headers(L,R) end || {L, R, Res} <- Tests].
1426+
13011427
-endif.

0 commit comments

Comments
 (0)