9
9
10
10
import dataclasses
11
11
import datetime
12
+ import functools
12
13
import logging
13
14
import os
14
15
import shutil
46
47
logger = logging .getLogger (__name__ )
47
48
48
49
49
- @dataclasses .dataclass (order = True )
50
+ @dataclasses .dataclass (eq = True )
51
+ @functools .total_ordering
50
52
class VulnerabilitySeverity :
51
53
# FIXME: this should be named scoring_system, like in the model
52
54
system : ScoringSystem
@@ -55,15 +57,26 @@ class VulnerabilitySeverity:
55
57
published_at : Optional [datetime .datetime ] = None
56
58
57
59
def to_dict (self ):
58
- published_at_dict = (
59
- {"published_at" : self .published_at .isoformat ()} if self .published_at else {}
60
- )
61
- return {
60
+ data = {
62
61
"system" : self .system .identifier ,
63
62
"value" : self .value ,
64
63
"scoring_elements" : self .scoring_elements ,
65
- ** published_at_dict ,
66
64
}
65
+ if self .published_at :
66
+ if isinstance (self .published_at , datetime .datetime ):
67
+ data ["published_at" ] = self .published_at .isoformat ()
68
+ else :
69
+ data ["published_at" ] = self .published_at
70
+ return data
71
+
72
+ def __lt__ (self , other ):
73
+ if not isinstance (other , VulnerabilitySeverity ):
74
+ return NotImplemented
75
+ return self ._cmp_key () < other ._cmp_key ()
76
+
77
+ # TODO: Add cache
78
+ def _cmp_key (self ):
79
+ return (self .system .identifier , self .value , self .scoring_elements , self .published_at )
67
80
68
81
@classmethod
69
82
def from_dict (cls , severity : dict ):
@@ -79,7 +92,8 @@ def from_dict(cls, severity: dict):
79
92
)
80
93
81
94
82
- @dataclasses .dataclass (order = True )
95
+ @dataclasses .dataclass (eq = True )
96
+ @functools .total_ordering
83
97
class Reference :
84
98
reference_id : str = ""
85
99
reference_type : str = ""
@@ -90,27 +104,28 @@ def __post_init__(self):
90
104
if not self .url :
91
105
raise TypeError ("Reference must have a url" )
92
106
93
- def normalized (self ):
94
- severities = sorted ( self . severities )
95
- return Reference (
96
- reference_id = self . reference_id ,
97
- url = self . url ,
98
- severities = severities ,
99
- reference_type = self . reference_type ,
100
- )
107
+ def __lt__ (self , other ):
108
+ if not isinstance ( other , Reference ):
109
+ return NotImplemented
110
+ return self . _cmp_key () < other . _cmp_key ()
111
+
112
+ # TODO: Add cache
113
+ def _cmp_key ( self ):
114
+ return ( self . reference_id , self . reference_type , self . url , tuple ( self . severities ) )
101
115
102
116
def to_dict (self ):
117
+ """Return a normalized dictionary representation"""
103
118
return {
104
119
"reference_id" : self .reference_id ,
105
120
"reference_type" : self .reference_type ,
106
121
"url" : self .url ,
107
- "severities" : [severity .to_dict () for severity in self .severities ],
122
+ "severities" : [severity .to_dict () for severity in sorted ( self .severities ) ],
108
123
}
109
124
110
125
@classmethod
111
126
def from_dict (cls , ref : dict ):
112
127
return cls (
113
- reference_id = ref ["reference_id" ],
128
+ reference_id = str ( ref ["reference_id" ]) ,
114
129
reference_type = ref .get ("reference_type" ) or "" ,
115
130
url = ref ["url" ],
116
131
severities = [
@@ -140,7 +155,8 @@ class NoAffectedPackages(Exception):
140
155
"""
141
156
142
157
143
- @dataclasses .dataclass (order = True , frozen = True )
158
+ @functools .total_ordering
159
+ @dataclasses .dataclass (eq = True )
144
160
class AffectedPackage :
145
161
"""
146
162
Relate a Package URL with a range of affected versions and a fixed version.
@@ -170,6 +186,19 @@ def get_fixed_purl(self):
170
186
raise ValueError (f"Affected Package { self .package !r} does not have a fixed version" )
171
187
return update_purl_version (purl = self .package , version = str (self .fixed_version ))
172
188
189
+ def __lt__ (self , other ):
190
+ if not isinstance (other , AffectedPackage ):
191
+ return NotImplemented
192
+ return self ._cmp_key () < other ._cmp_key ()
193
+
194
+ # TODO: Add cache
195
+ def _cmp_key (self ):
196
+ return (
197
+ str (self .package ),
198
+ str (self .affected_version_range or "" ),
199
+ str (self .fixed_version or "" ),
200
+ )
201
+
173
202
@classmethod
174
203
def merge (
175
204
cls , affected_packages : Iterable
0 commit comments