From 9b3ffa8bd7bebb6d328a132eb9ad18cb1e8e1fe9 Mon Sep 17 00:00:00 2001 From: Attack825 <2707138687@qq.com> Date: Tue, 29 Jul 2025 18:36:22 +0800 Subject: [PATCH] feat: adapt enforce/batch_enforce to new Casdoor API parameter format --- src/casdoor/main.py | 123 ++++++++++++++++++++++++---------------- src/casdoor/user.py | 5 +- src/tests/test_oauth.py | 96 +++++++++++++------------------ 3 files changed, 116 insertions(+), 108 deletions(-) diff --git a/src/casdoor/main.py b/src/casdoor/main.py index d91ec66..62ca55f 100644 --- a/src/casdoor/main.py +++ b/src/casdoor/main.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from typing import Dict, List, Optional import jwt @@ -257,86 +258,112 @@ def parse_jwt_token(self, token: str, **kwargs) -> Dict: certificate = x509.load_pem_x509_certificate(self.certification, default_backend()) return_json = jwt.decode( - token, certificate.public_key(), algorithms=self.algorithms, audience=self.client_id, **kwargs + token.encode("utf-8"), + certificate.public_key(), + algorithms=self.algorithms, + audience=self.client_id, + **kwargs, ) return return_json def enforce( self, - permission_model_name: str, - sub: str, - obj: str, - act: str, - v3: Optional[str] = None, - v4: Optional[str] = None, - v5: Optional[str] = None, + permission_id: str, + model_id: str, + resource_id: str, + enforce_id: str, + owner: str, + casbin_request: Optional[List[str]] = None, ) -> bool: """ Send data to Casdoor enforce API - :param permission_model_name: Name permission model - :param sub: sub from Casbin - :param obj: obj from Casbin - :param act: act from Casbin - :param v3: v3 from Casbin - :param v4: v4 from Casbin - :param v5: v5 from Casbin + :param permission_id: the permission id (i.e. organization name/permission name) + :param model_id: the model id + :param resource_id: the resource id + :param enforce_id: the enforce id + :param owner: the owner of the permission + :param casbin_request: a list containing the request data (i.e. sub, obj, act) + :return: a boolean value indicating whether the request is allowed """ url = self.endpoint + "/api/enforce" - query_params = {"clientId": self.client_id, "clientSecret": self.client_secret} params = { - "id": permission_model_name, - "v0": sub, - "v1": obj, - "v2": act, - "v3": v3, - "v4": v4, - "v5": v5, + "permissionId": permission_id, + "modelId": model_id, + "resourceId": resource_id, + "enforceId": enforce_id, + "owner": owner, } - r = requests.post(url, json=params, params=query_params) + r = requests.post( + url, + params=params, + data=json.dumps(casbin_request), + auth=(self.client_id, self.client_secret), + ) if r.status_code != 200 or "json" not in r.headers["content-type"]: error_str = "Casdoor response error:\n" + str(r.text) raise ValueError(error_str) - has_permission = r.json() - + response = r.json() + if isinstance(response, dict): + data = response.get("data") + if isinstance(data, list) and len(data) > 0: + has_permission = data[0] + else: + has_permission = response + else: + has_permission = response if not isinstance(has_permission, bool): error_str = "Casdoor response error:\n" + r.text raise ValueError(error_str) return has_permission - def batch_enforce(self, permission_model_name: str, permission_rules: List[List[str]]) -> List[bool]: + def batch_enforce( + self, + permission_id: str, + model_id: str, + enforce_id: str, + owner: str, + casbin_request: Optional[List[List[str]]] = None, + ) -> List[bool]: """ Send data to Casdoor enforce API - :param permission_model_name: Name permission model - :param permission_rules: permission rules to enforce - [][0] -> sub: sub from Casbin - [][1] -> obj: obj from Casbin - [][2] -> act: act from Casbin - [][3] -> v3: v3 from Casbin (optional) - [][4] -> v4: v4 from Casbin (optional) - [][5] -> v5: v5 from Casbin (optional) + :param permission_id: the permission id (i.e. organization name/permission name) + :param model_id: the model id + :param enforce_id: the enforce id + :param owner: the owner of the permission + :param casbin_request: a list of lists containing the request data + :return: a list of boolean values indicating whether each request is allowed """ url = self.endpoint + "/api/batch-enforce" - query_params = {"clientId": self.client_id, "clientSecret": self.client_secret} - - def map_rule(rule: List[str], idx) -> Dict: - if len(rule) < 3: - raise ValueError("Invalid permission rule[{0}]: {1}".format(idx, rule)) - result = {"id": permission_model_name} - for i in range(0, len(rule)): - result.update({"v{0}".format(i): rule[i]}) - return result - - params = [map_rule(permission_rules[i], i) for i in range(0, len(permission_rules))] - r = requests.post(url, json=params, params=query_params) + params = { + "permissionId": permission_id, + "modelId": model_id, + "enforceId": enforce_id, + "owner": owner, + } + r = requests.post( + url, + params=params, + data=json.dumps(casbin_request), + auth=(self.client_id, self.client_secret), + ) + if r.status_code != 200 or "json" not in r.headers["content-type"]: error_str = "Casdoor response error:\n" + str(r.text) raise ValueError(error_str) - enforce_results = r.json() + response = r.json() + data = response.get("data") + if data is None: + error_str = "Casdoor response error:\n" + r.text + raise ValueError(error_str) + if not isinstance(data, list): + error_str = f"Casdoor 'data' is not a list (got {type(data)}):\n" + r.text + raise ValueError(error_str) + enforce_results = data[0] if ( not isinstance(enforce_results, list) diff --git a/src/casdoor/user.py b/src/casdoor/user.py index 493227d..4e61449 100644 --- a/src/casdoor/user.py +++ b/src/casdoor/user.py @@ -102,7 +102,7 @@ def get_users(self) -> List[Dict]: users.append(User.from_dict(user)) return users - def get_user(self, user_id: str) -> Dict: + def get_user(self, user_id: str) -> User: """ Get the user from Casdoor providing the user_id. @@ -141,7 +141,8 @@ def get_user_count(self, is_online: bool = None) -> int: params["isOnline"] = "1" if is_online else "0" r = requests.get(url, params) - count = r.json() + response = r.json() + count = response.get("data") return count def modify_user(self, method: str, user: User) -> Dict: diff --git a/src/tests/test_oauth.py b/src/tests/test_oauth.py index 6080eb6..1c54fe3 100644 --- a/src/tests/test_oauth.py +++ b/src/tests/test_oauth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest import TestCase, mock +from unittest import TestCase from requests import Response @@ -59,7 +59,7 @@ def test__oauth_token_request(self): "code": self.code, } response = sdk._oauth_token_request(payload=data) - self.assertIsInstance(response, dict) + self.assertIsInstance(response, Response) def test__get_payload_for_authorization_code(self): sdk = self.get_sdk() @@ -137,71 +137,51 @@ def test_parse_jwt_token(self): def test_enforce(self): sdk = self.get_sdk() - status = sdk.enforce("built-in/permission-built-in", "admin", "a", "ac") + status = sdk.enforce( + permission_id="built-in/permission-built-in", + model_id="", + resource_id="", + enforce_id="", + owner="", + casbin_request=["alice", "data1", "read"], + ) self.assertIsInstance(status, bool) - def mocked_enforce_requests_post(*args, **kwargs): - class MockResponse: - def __init__(self, json_data, status_code=200, headers=None): - if headers is None: - headers = {"content-type": "json"} - self.json_data = json_data - self.status_code = status_code - self.headers = headers - - def json(self): - return self.json_data - - result = True - for i in range(0, 5): - if kwargs.get("json").get(f"v{i}") != f"v{i}": - result = False - - return MockResponse(result) - - @mock.patch("requests.post", side_effect=mocked_enforce_requests_post) - def test_enforce_parmas(self, mock_post): + def test_enforce_parmas(self): sdk = self.get_sdk() - status = sdk.enforce("built-in/permission-built-in", "v0", "v1", "v2", v3="v3", v4="v4", v5="v5") - self.assertEqual(status, True) - - def mocked_batch_enforce_requests_post(*args, **kwargs): - class MockResponse: - def __init__(self, json_data, status_code=200, headers=None): - if headers is None: - headers = {"content-type": "json"} - self.json_data = json_data - self.status_code = status_code - self.headers = headers - - def json(self): - return self.json_data - - json = kwargs.get("json") - result = [True for i in range(0, len(json))] - for k in range(0, len(json)): - for i in range(0, len(json[k]) - 1): - if json[k].get(f"v{i}") != f"v{i}": - result[k] = False - - return MockResponse(result) + status = sdk.enforce( + permission_id="built-in/permission-built-in", + model_id="", + resource_id="", + enforce_id="", + owner="", + casbin_request=["alice", "data1", "read"], + ) + self.assertIsInstance(status, bool) - @mock.patch("requests.post", side_effect=mocked_batch_enforce_requests_post) - def test_batch_enforce(self, mock_post): + def test_batch_enforce(self): sdk = self.get_sdk() status = sdk.batch_enforce( - "built-in/permission-built-in", [["v0", "v1", "v2", "v3", "v4", "v5"], ["v0", "v1", "v2", "v3", "v4", "v1"]] + permission_id="built-in/permission-built-in", + model_id="", + enforce_id="", + owner="", + casbin_request=[["alice", "data1", "read"], ["bob", "data2", "write"]], ) self.assertEqual(len(status), 2) - self.assertEqual(status[0], True) - self.assertEqual(status[1], False) + self.assertIsInstance(status[0], bool) + self.assertIsInstance(status[1], bool) - @mock.patch("requests.post", side_effect=mocked_batch_enforce_requests_post) - def test_batch_enforce_raise(self, mock_post): + def test_batch_enforce_raise(self): sdk = self.get_sdk() - with self.assertRaises(ValueError) as context: - sdk.batch_enforce("built-in/permission-built-in", [["v0", "v1"]]) - self.assertEqual("Invalid permission rule[0]: ['v0', 'v1']", str(context.exception)) + with self.assertRaises(ValueError): + sdk.batch_enforce( + permission_id="built-in/permission-built-in", + model_id="", + enforce_id="", + owner="", + casbin_request=[], + ) def test_get_users(self): sdk = self.get_sdk() @@ -221,7 +201,7 @@ def test_get_user_count(self): def test_get_user(self): sdk = self.get_sdk() user = sdk.get_user("admin") - self.assertIsInstance(user, dict) + self.assertIsInstance(user, User) def test_modify_user(self): sdk = self.get_sdk()