Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/prist5_10K_m_025.Rqz.ort
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# data_source:
# owner:
# name: Artur Glavic
# affiliation: null
# contact: b''
# affiliation: ''
# contact: ''
# experiment:
# title: Structural evolution of the CO2/Water interface
# instrument: Amor
Expand Down Expand Up @@ -39,7 +39,7 @@
# data_set: 0
# columns:
# - {name: Qz, unit: 1/angstrom, physical_quantity: normal momentum transfer}
# - {name: R, unit: '', physical_quantity: specular reflectivity}
# - {name: R, unit: '1', physical_quantity: specular reflectivity}
# - {error_of: R, error_type: uncertainty, value_is: sigma}
# - {error_of: Qz, error_type: resolution, value_is: sigma}
# # Qz (1/angstrom) R () sR sQz
Expand Down
173 changes: 169 additions & 4 deletions orsopy/fileio/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,57 @@ class ORSOSchemaWarning(RuntimeWarning):
pass


@dataclass(frozen=True)
class ORSOValidationResult:
    """
    The result of a validation of ORSO header data as returned from Header.check_valid.

    Includes additional information about the parameters that are missing/additional to a valid format.
    An instance is truthy exactly when ``valid`` is True.
    """

    # overall outcome of the validation
    valid: bool
    # the class the data was validated against; only its ``__name__`` is used for reporting
    header_class: "Header"
    # names of required parameters that were not provided
    missing_attributes: List[str]
    # names of parameters with an invalid type or invalid sub-attributes
    invalid_attributes: List[str]
    # names of optional parameters that could be provided in addition
    missing_optionals: List[str]
    # names of supplied parameters that are not part of the specification
    user_parameters: List[str]
    # nested validation results, looked up by attribute name when building the report
    attribute_validations: Dict[str, "ORSOValidationResult"]

    def __bool__(self):
        return self.valid

    @staticmethod
    def _pprint_list(items, title):
        """
        Format ``items`` as a titled, comma separated list.

        Returns an empty string if there are no items, so that the result can
        be appended to a report unconditionally.
        """
        if not items:
            return ""
        listing = ", ".join(f"{i}" for i in items)
        return f" {title}\n {listing}\n"

    def get_report(self):
        """
        Returns a summary report of the validation result helping to analyze issues with the data.
        """
        name = self.header_class.__name__
        if self.valid:
            parts = [
                f"Dictionary was a valid {name} dataset\n",
                self._pprint_list(self.missing_optionals, "Optional extra parameters that could be provided"),
                self._pprint_list(self.user_parameters, "User parameters not part of specification"),
            ]
            return "".join(parts)

        parts = [
            f"Dictionary was invalid for {name}!\n Reasons for classification:\n",
            self._pprint_list(self.user_parameters, "User parameters not part of specification"),
            self._pprint_list(self.missing_attributes, "Required parameters not provided"),
            self._pprint_list(self.invalid_attributes, "Parameters with invalid type or attributes"),
        ]
        for ia in self.invalid_attributes:
            if ia in self.attribute_validations:
                # embed the report of the nested result, shifted by one column
                parts.append(f" ORSO parameter {ia} extra information:\n ")
                parts.append(self.attribute_validations[ia].get_report().replace("\n", "\n "))
                parts.append("\n")
        return "".join(parts)


class Header:
"""
The super class for all the items in the orso module.
Expand All @@ -61,6 +112,8 @@ class Header:
# _orso_name_export_priority: List[str] # an optional list of attribute names to put first in the yaml export
_subclass_dict_ = {}

_last_failed_type = None

def __init_subclass__(cls, **kwargs):
"""
For each subclass of Header, collect optional arguments in
Expand Down Expand Up @@ -102,7 +155,7 @@ def from_dict(cls, data_dict):
# convert dictionary to Header derived class if possible
if type(ftype) is type and type(value) is dict and issubclass(ftype, Header):
# the field requires a ORSO Header type
value = construct_fields[field_keys.index(key)].type.from_dict(value)
value = ftype.from_dict(value)
construct_dict[key] = value
else:
user_dict[key] = value
Expand All @@ -111,6 +164,115 @@ def from_dict(cls, data_dict):
setattr(output, key, value)
return output

@classmethod
def check_valid(cls, data_dict, user_is_valid=False) -> ORSOValidationResult:
    """
    Analyze input data to see if it is valid for this class and provide
    additional information to a user of the API to improve export filters.

    By default, user parameters are treated as invalid, as these could be
    unintentional like typos.

    :param data_dict: Dictionary of attribute names and values to validate.
    :param user_is_valid: If True, keys that are not fields of this class
        do not make the result invalid. The flag is passed on to the
        validation of nested Header dictionaries.

    :return: :py:class:`ORSOValidationResult` listing missing, invalid,
        optional and user supplied parameter names.
    """
    # fields that have not been seen in data_dict yet, in declaration order
    remaining_fields = {fi.name: fi for fi in fields(cls)}
    missing_attributes = []
    invalid_attributes = []
    missing_optionals = []
    user_keys = []
    attribute_validations = {}

    for key, value in data_dict.items():
        if key not in remaining_fields:
            # not an attribute of this class
            user_keys.append(key)
            continue

        ftype = remaining_fields.pop(key).type
        hbase = get_origin(ftype)
        type_value = type(value)
        if value is None:
            # value is supplied but interpreted as empty, should only happen for optionals
            if key in cls._orso_optionals:
                missing_optionals.append(key)
            else:
                missing_attributes.append(key)
        elif type(ftype) is type and type_value is dict and issubclass(ftype, Header):
            # the field requires a ORSO Header type
            result = ftype.check_valid(value, user_is_valid=user_is_valid)
            if not result:
                invalid_attributes.append(key)
                # NOTE(review): nested results are stored for failures only,
                # which is all the report needs - confirm against upstream
                attribute_validations[key] = result
        elif hbase in [Union, Optional] and type_value is dict:
            # Case of combined type hints.
            # Validate against each Header type in the hint and accept the first that fits.
            subtypes = get_args(ftype)
            if type_value not in subtypes:
                # item is not exactly in the subtypes
                failed_results = []
                for subt in subtypes:
                    if type(subt) is type and issubclass(subt, Header):
                        result = subt.check_valid(value, user_is_valid=user_is_valid)
                        if result:
                            failed_results = []
                            break
                        failed_results.append(result)
                if failed_results:
                    # select best fitting of failed results, in case there are multiple Header in a Union;
                    # min() keeps the first candidate on ties
                    best_match = min(
                        failed_results,
                        key=lambda res: len(res.missing_attributes) + len(res.invalid_attributes),
                    )
                    attribute_validations[key] = best_match
                    invalid_attributes.append(key)
        else:
            # try to resolve the type of this value
            invalid = False
            resolved = None
            # forget the type recorded by an earlier resolution, it does not belong to this value
            Header._last_failed_type = None
            with warnings.catch_warnings(record=True) as caught:
                try:
                    resolved = cls._resolve_type(ftype, value)
                except Exception:
                    # resolution failed completely, there is no result to inspect
                    invalid = True
                else:
                    if resolved is None:
                        invalid = True
                    elif type_value is not dict and not isinstance(resolved, type_value):
                        # value had to be converted to another type to fit the field
                        invalid = True
            if caught:
                # tried to resolve a type but failed with warning
                failed_type = Header._last_failed_type
                if type_value is dict and failed_type is not None:
                    result = failed_type.check_valid(value, user_is_valid=user_is_valid)
                    if not result:
                        invalid = True
                        attribute_validations[key] = result
                else:
                    invalid = True
            if invalid:
                # record each key once, however many checks it failed
                invalid_attributes.append(key)

    # collect missing key names
    for key in remaining_fields:
        if key in cls._orso_optionals:
            missing_optionals.append(key)
        else:
            missing_attributes.append(key)
    # the comment attribute is not reported as missing; it is absent from the list if it was supplied
    if "comment" in missing_optionals:
        missing_optionals.remove("comment")

    is_valid = (
        len(missing_attributes) == 0 and len(invalid_attributes) == 0 and (user_is_valid or len(user_keys) == 0)
    )
    return ORSOValidationResult(
        is_valid,
        cls,
        missing_attributes,
        invalid_attributes,
        missing_optionals,
        user_keys,
        attribute_validations,
    )

def __post_init__(self):
"""Make sure Header types are correct."""
for fld in fields(self):
Expand Down Expand Up @@ -280,13 +442,16 @@ def _resolve_type(hint: type, item: Any) -> Any:
if res is not None:
# This type conversion worked, return the result.
if len(w) > 0:
potential_res.append((w, res))
potential_res.append((w, res, subt))
else:
return res
if len(potential_res) > 0:
# a potential type was found, but it raised a warning
w, res = potential_res[0]
w, res, subt = potential_res[0]
# make sure the warning is displayed
if w[-1].category is ORSOSchemaWarning:
# for validation, store the failed last failed type
Header._last_failed_type = subt
warnings.warn(w[-1].message, w[-1].category, w[-1].lineno)
return res
elif hbase is Literal:
Expand Down Expand Up @@ -870,7 +1035,7 @@ def __post_init__(self):
Assigns a timestamp for file creation if not defined.
"""
Header.__post_init__(self)
if self.timestamp is None:
if self.timestamp is None and self.file is not None:
fname = pathlib.Path(self.file)
if fname.exists():
self.timestamp = datetime.datetime.fromtimestamp(fname.stat().st_mtime)
Expand Down
4 changes: 2 additions & 2 deletions orsopy/fileio/orso.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,14 @@ def save_orso(
np.savetxt(f, dsi.data, header=hi, fmt="%-22.16e")


def load_orso(fname: Union[TextIO, str]) -> List[OrsoDataset]:
def load_orso(fname: Union[TextIO, str], validate=False) -> List[OrsoDataset]:
"""
:param fname: The Orso file to load.

:return: :py:class:`OrsoDataset` objects for each dataset contained
within the ORT file.
"""
dct_list, datas, version = _read_header_data(fname)
dct_list, datas, version = _read_header_data(fname, validate=validate)
ods = []

for dct, data in zip(dct_list, datas):
Expand Down
92 changes: 92 additions & 0 deletions orsopy/fileio/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Tests for fileio.base module
"""

# pylint: disable=R0201

import datetime as datetime_module
Expand Down Expand Up @@ -74,6 +75,97 @@ class TestDatetime(base.Header):
finally:
datetime_module.datetime = datetime

def test_simple_validate(self):
    """Check Header.check_valid for a flat Header subclass without nested headers."""

    @dataclass
    class TestClass(base.Header):
        test: str
        test2: int
        test3: float

    # check valid entry is returned correctly
    self.assertTrue(TestClass.check_valid(dict(test="abc", test2=1, test3=2.0)))

    # missing value
    res = TestClass.check_valid(dict(test="abc", test3=2.0))
    self.assertFalse(res.valid)
    self.assertEqual(res.missing_attributes, ["test2"])

    # empty required value
    res = TestClass.check_valid(dict(test="abc", test2=None, test3=2.0))
    self.assertFalse(res.valid)
    self.assertEqual(res.missing_attributes, ["test2"])

    # invalid value
    res = TestClass.check_valid(dict(test="abc", test2=1.0, test3="test"))
    self.assertFalse(res.valid)
    self.assertEqual(res.invalid_attributes, ["test2", "test3"])

    # user value, accepted only if explicitly allowed
    self.assertTrue(
        TestClass.check_valid(dict(test="abc", test2=1, test3=1.0, peter=123), user_is_valid=True)
    )
    res = TestClass.check_valid(dict(test="abc", test2=1, test3=1.0, peter=123))
    self.assertFalse(res.valid)
    self.assertEqual(res.user_parameters, ["peter"])

def test_deep_validate(self):
    """Check Header.check_valid for a Header subclass with nested and Union typed headers."""

    @dataclass
    class TestClass1(base.Header):
        test: str
        test2: int
        test3: float

    @dataclass
    class TestClass1a(base.Header):
        beer: str

    @dataclass
    class TestClass2(base.Header):
        test: TestClass1
        test2: float = 2.5
        test3: Optional[TestClass1] = None
        test4: Optional[Union[TestClass1, TestClass1a, str]] = None

    def inner():
        # a fresh dictionary that is valid for TestClass1
        return dict(test="abc", test2=1, test3=2.0)

    # check valid entry is returned correctly
    self.assertTrue(TestClass2.check_valid(dict(test=inner(), test2=1.23)))
    self.assertTrue(TestClass2.check_valid(dict(test=inner(), test2=1.23, test4=inner())))
    res = TestClass2.check_valid(dict(test=inner(), test2=1.23, test4="peter"))
    self.assertTrue(res)
    res.get_report()  # check no error condition

    # missing value
    res = TestClass2.check_valid(dict(test=inner(), test3=None))
    self.assertFalse(res.valid)
    self.assertEqual(res.missing_attributes, ["test2"])
    self.assertEqual(res.missing_optionals, ["test3", "test4"])
    res.get_report()  # check no error condition

    # missing sub-value
    res = TestClass2.check_valid(dict(test=dict(test="abc", test2=1), test2=1.23))
    self.assertFalse(res.valid)
    self.assertEqual(res.invalid_attributes, ["test"])
    self.assertEqual(res.attribute_validations["test"].missing_attributes, ["test3"])
    res.get_report()  # check no error condition

    # invalid value
    res = TestClass2.check_valid(dict(test=inner(), test2="1.23"))
    self.assertFalse(res.valid)
    self.assertEqual(res.invalid_attributes, ["test2"])
    res.get_report()  # check no error condition

    # invalid sub-value
    res = TestClass2.check_valid(dict(test=dict(test="abc", test2="1", test3=2.0), test2=1.23))
    self.assertFalse(res.valid)
    self.assertEqual(res.invalid_attributes, ["test"])
    self.assertEqual(res.attribute_validations["test"].invalid_attributes, ["test2"])

    # invalid sub-value inside of a Union typed attribute
    res = TestClass2.check_valid(dict(test=inner(), test2=1.23, test4=dict(test=13.3, test2=1, test3=2.0)))
    self.assertFalse(res.valid)
    self.assertEqual(res.invalid_attributes, ["test4"])
    self.assertEqual(res.attribute_validations["test4"].invalid_attributes, ["test"])
    res.get_report()  # check no error condition

    # second Header type of the Union
    res = TestClass2.check_valid(dict(test=inner(), test2=1.23, test4=dict(beer="beer")))
    self.assertTrue(res)
    res = TestClass2.check_valid(dict(test=inner(), test2=1.23, test4=dict(beer=3.5)))
    self.assertFalse(res.valid)
    self.assertEqual(res.invalid_attributes, ["test4"])
    self.assertEqual(res.attribute_validations["test4"].invalid_attributes, ["beer"])

def test_resolve_dictof(self):
if sys.version_info < (3, 8):
# dict type annotation changed in 3.8
Expand Down
Loading