|
1 | 1 | import os
|
2 | 2 | from pathlib import Path
|
3 | 3 | from typing import Any, Dict, List, Optional, Set
|
| 4 | +from warnings import warn |
4 | 5 |
|
5 | 6 | import fsspec
|
6 | 7 | import numpy as np
|
7 | 8 |
|
8 | 9 | from ndsl.comm.communicator import Communicator
|
9 |
| -from ndsl.dsl.typing import Float |
| 10 | +from ndsl.dsl.typing import Float, get_precision |
10 | 11 | from ndsl.filesystem import get_fs
|
11 | 12 | from ndsl.logging import ndsl_log
|
12 | 13 | from ndsl.monitor.convert import to_numpy
|
@@ -132,20 +133,10 @@ def __init__(
|
132 | 133 | self._time_chunk_size = time_chunk_size
|
133 | 134 | self.__writer: Optional[_ChunkedNetCDFWriter] = None
|
134 | 135 | self._expected_vars: Optional[Set[str]] = None
|
135 |
| - if precision == Float: |
136 |
| - self._transfer_type = Float |
137 |
| - elif precision == np.float32: |
138 |
| - self._transfer_type = np.float32 |
139 |
| - elif precision == np.float64: |
140 |
| - if np.float32 == Float: |
141 |
| - raise ValueError( |
142 |
| - f"Cannot output float64 with PACE_FLOAT_PRECISION set to {Float}" |
143 |
| - ) |
144 |
| - self._transfer_type = np.float64 |
145 |
| - else: |
146 |
| - raise ValueError( |
147 |
| - "precision must be set to one of 'Float', 'float32', or 'float64" |
148 |
| - f"got {precision}" |
| 136 | + self._transfer_type = precision |
| 137 | + if self._transfer_type == np.float32 and get_precision() > 32: |
| 138 | + warn( |
| 139 | + "NetCDF save: requested 32-bit float but precision of NDSL is {get_precision()}, cast will occur with possible loss of precision" |
149 | 140 | )
|
150 | 141 |
|
151 | 142 | @property
|
|
0 commit comments