5
5
from typing import Optional , Union
6
6
7
7
from loguru import logger
8
+ from result import Err , Ok , Result
8
9
from sqlalchemy .ext .declarative import DeclarativeMeta
9
10
from typing_extensions import TypeGuard
10
11
@@ -18,13 +19,13 @@ def __init__(self, schema_factory: SchemaFactory, /):
18
19
@abstractmethod
19
20
def transform (
20
21
self , rawtargets : Iterable [Union [ModuleType , DeclarativeMeta ]], depth : Optional [int ], /
21
- ) -> Schema : ...
22
+ ) -> Result [ Schema , str ] : ...
22
23
23
24
24
25
class JSONSchemaTransformer (AbstractTransformer ):
25
26
def transform (
26
27
self , rawtargets : Iterable [Union [ModuleType , DeclarativeMeta ]], depth : Optional [int ], /
27
- ) -> Schema :
28
+ ) -> Result [ Schema , str ] :
28
29
definitions = {}
29
30
30
31
for item in rawtargets :
@@ -33,33 +34,46 @@ def transform(
33
34
elif inspect .ismodule (item ):
34
35
partial_definitions = self .transform_by_module (item , depth )
35
36
else :
36
- TypeError (f"Expected a class or module, got { item } " )
37
+ return Err (f"Expected a class or module, got { item } " )
37
38
38
- definitions .update (partial_definitions )
39
+ if partial_definitions .is_err ():
40
+ return partial_definitions
39
41
40
- return definitions
42
+ definitions . update ( partial_definitions . unwrap ())
41
43
42
- def transform_by_model (self , model : DeclarativeMeta , depth : Optional [int ], / ) -> Schema :
44
+ return Ok (definitions )
45
+
46
+ def transform_by_model (
47
+ self , model : DeclarativeMeta , depth : Optional [int ], /
48
+ ) -> Result [Schema , str ]:
43
49
return self .schema_factory (model , depth = depth )
44
50
45
- def transform_by_module (self , module : ModuleType , depth : Optional [int ], / ) -> Schema :
51
+ def transform_by_module (
52
+ self , module : ModuleType , depth : Optional [int ], /
53
+ ) -> Result [Schema , str ]:
46
54
subdefinitions = {}
47
55
definitions = {}
48
56
for basemodel in collect_models (module ):
49
- schema = self .schema_factory (basemodel , depth = depth )
57
+ schema_result = self .schema_factory (basemodel , depth = depth )
58
+
59
+ if schema_result .is_err ():
60
+ return schema_result
61
+
62
+ schema = schema_result .unwrap ()
63
+
50
64
if "definitions" in schema :
51
65
subdefinitions .update (schema .pop ("definitions" ))
52
66
definitions [schema ["title" ]] = schema
53
67
d = {}
54
68
d .update (subdefinitions )
55
69
d .update (definitions )
56
- return {"definitions" : definitions }
70
+ return Ok ( {"definitions" : definitions })
57
71
58
72
59
73
class OpenAPI2Transformer (AbstractTransformer ):
60
74
def transform (
61
75
self , rawtargets : Iterable [Union [ModuleType , DeclarativeMeta ]], depth : Optional [int ], /
62
- ) -> Schema :
76
+ ) -> Result [ Schema , str ] :
63
77
definitions = {}
64
78
65
79
for target in rawtargets :
@@ -68,29 +82,46 @@ def transform(
68
82
elif inspect .ismodule (target ):
69
83
partial_definitions = self .transform_by_module (target , depth )
70
84
else :
71
- raise TypeError (f"Expected a class or module, got { target } " )
85
+ return Err (f"Expected a class or module, got { target } " )
86
+
87
+ if partial_definitions .is_err ():
88
+ return partial_definitions
72
89
73
- definitions .update (partial_definitions )
90
+ definitions .update (partial_definitions . unwrap () )
74
91
75
- return {"definitions" : definitions }
92
+ return Ok ( {"definitions" : definitions })
76
93
77
- def transform_by_model (self , model : DeclarativeMeta , depth : Optional [int ], / ) -> Schema :
94
+ def transform_by_model (
95
+ self , model : DeclarativeMeta , depth : Optional [int ], /
96
+ ) -> Result [Schema , str ]:
78
97
definitions = {}
79
- schema = self .schema_factory (model , depth = depth )
98
+ schema_result = self .schema_factory (model , depth = depth )
99
+
100
+ if schema_result .is_err ():
101
+ return schema_result
102
+
103
+ schema = schema_result .unwrap ()
80
104
81
105
if "definitions" in schema :
82
106
definitions .update (schema .pop ("definitions" ))
83
107
84
108
definitions [schema ["title" ]] = schema
85
109
86
- return definitions
110
+ return Ok ( definitions )
87
111
88
- def transform_by_module (self , module : ModuleType , depth : Optional [int ], / ) -> Schema :
112
+ def transform_by_module (
113
+ self , module : ModuleType , depth : Optional [int ], /
114
+ ) -> Result [Schema , str ]:
89
115
subdefinitions = {}
90
116
definitions = {}
91
117
92
118
for basemodel in collect_models (module ):
93
- schema = self .schema_factory (basemodel , depth = depth )
119
+ schema_result = self .schema_factory (basemodel , depth = depth )
120
+
121
+ if schema_result .is_err ():
122
+ return schema_result
123
+
124
+ schema = schema_result .unwrap ()
94
125
95
126
if "definitions" in schema :
96
127
subdefinitions .update (schema .pop ("definitions" ))
@@ -101,7 +132,7 @@ def transform_by_module(self, module: ModuleType, depth: Optional[int], /) -> Sc
101
132
d .update (subdefinitions )
102
133
d .update (definitions )
103
134
104
- return definitions
135
+ return Ok ( definitions )
105
136
106
137
107
138
class OpenAPI3Transformer (OpenAPI2Transformer ):
@@ -118,8 +149,13 @@ def replace_ref(self, d: Union[dict, list], old_prefix: str, new_prefix: str, /)
118
149
119
150
def transform (
120
151
self , rawtargets : Iterable [Union [ModuleType , DeclarativeMeta ]], depth : Optional [int ], /
121
- ) -> Schema :
122
- definitions = super ().transform (rawtargets , depth )
152
+ ) -> Result [Schema , str ]:
153
+ definitions_result = super ().transform (rawtargets , depth )
154
+
155
+ if definitions_result .is_err ():
156
+ return Err (definitions_result .unwrap_err ())
157
+
158
+ definitions = definitions_result .unwrap ()
123
159
124
160
self .replace_ref (definitions , "#/definitions/" , "#/components/schemas/" )
125
161
@@ -128,7 +164,8 @@ def transform(
128
164
if "schemas" not in definitions ["components" ]:
129
165
definitions ["components" ]["schemas" ] = {}
130
166
definitions ["components" ]["schemas" ] = definitions .pop ("definitions" , {})
131
- return definitions
167
+
168
+ return Ok (definitions )
132
169
133
170
134
171
def collect_models (module : ModuleType , / ) -> Iterator [DeclarativeMeta ]:
0 commit comments