1"""Rechunking module."""
3import os
4import logging
5from pathlib import Path
6from typing import cast, Any, Dict, Hashable, Generator, Optional, Tuple
7from typing_extensions import Literal
9from dask.utils import format_bytes
10from dask.array.core import Array
11import xarray as xr
14logging.basicConfig(
15 format="%(name)s - %(levelname)s - %(message)s", level=logging.ERROR
16)
17logger = logging.getLogger("rechunk-data")
19ENCODINGS = dict(
20 h5netcdf={
21 "_FillValue",
22 "complevel",
23 "chunksizes",
24 "dtype",
25 "zlib",
26 "compression_opts",
27 "shuffle",
28 "fletcher32",
29 "compression",
30 "contiguous",
31 },
32 netcdf4={
33 "contiguous",
34 "complevel",
35 "zlib",
36 "shuffle",
37 "_FillValue",
38 "least_significant_digit",
39 "chunksizes",
40 "fletcher32",
41 "dtype",
42 },
43)


def _search_for_nc_files(input_path: Path) -> Generator[Path, None, None]:
    """Recursively find all netCDF files under ``input_path``."""
    # Path.suffix includes the leading dot, so both entries must carry it.
    suffixes = [".nc", ".nc4"]
    input_path = input_path.expanduser().absolute()
    if input_path.is_dir():
        nc_iter = input_path.rglob("*.*")
    elif input_path.is_file():
        nc_iter = cast(Generator[Path, None, None], iter([input_path]))
    else:
        # This could be a path with a glob pattern, let's try to construct it
        nc_iter = input_path.parent.rglob(input_path.name)
    for ncfile in nc_iter:
        if ncfile.suffix in suffixes:
            yield ncfile


def _save_dataset(
    dset: xr.Dataset,
    file_name: Path,
    encoding: Dict[Hashable, Dict[str, Any]],
    engine: Literal["netcdf4", "h5netcdf"],
    override: bool = False,
) -> None:
    """Write a dataset to netCDF, skipping files that need no update."""
    if not encoding and not override:
        logger.debug("Chunk size already optimized for %s", file_name.name)
        return
    logger.debug("Saving file to %s using %s engine", str(file_name), engine)
    try:
        dset.to_netcdf(
            file_name,
            engine=engine,
            encoding=encoding,
        )
    except Exception as error:
        logger.error("Saving to file failed: %s", str(error))


def _rechunk_dataset(
    dset: xr.Dataset,
    engine: Literal["h5netcdf", "netcdf4"],
) -> Tuple[xr.Dataset, Dict[Hashable, Dict[str, Any]]]:
    """Rechunk all suitable data variables and collect their encodings."""
    encoding: Dict[Hashable, Dict[str, Any]] = {}
    try:
        _keywords = ENCODINGS[engine]
    except KeyError as error:
        raise ValueError(
            "Only the following engines are supported: "
            f"{', '.join(ENCODINGS.keys())}"
        ) from error
    for data_var in dset.data_vars:
        var = str(data_var)
        if (
            not isinstance(dset[var].data, Array)
            or "bnds" in var
            or "rotated_pole" in var
        ):
            logger.debug("Skipping rechunking variable %s", var)
            continue
        logger.debug("Rechunking variable %s", var)
        chunks: Dict[int, Optional[str]] = {}
        for i, dim in enumerate(map(str, dset[var].dims)):
            if "lon" in dim.lower() or "lat" in dim.lower() or "bnds" in dim.lower():
                chunks[i] = None
            else:
                chunks[i] = "auto"
        old_chunks = dset[var].encoding.get("chunksizes")
        new_chunks = dset[var].data.rechunk(chunks).chunksize
        if new_chunks == old_chunks:
            logger.debug("%s: chunk sizes already optimized, skipping", var)
            continue
        dset[var] = dset[var].chunk(dict(zip(dset[var].dims, new_chunks)))
        logger.debug(
            "%s: old chunk size: %s, new chunk size: %s",
            var,
            old_chunks,
            new_chunks,
        )
        logger.debug("Setting encoding of variable %s", var)
        encoding[data_var] = {
            str(k): v for k, v in dset[var].encoding.items() if str(k) in _keywords
        }
        if engine != "netcdf4" or encoding[data_var].get("contiguous", False) is False:
            encoding[data_var]["chunksizes"] = new_chunks
    return dset, encoding


def rechunk_dataset(
    dset: xr.Dataset, engine: Literal["h5netcdf", "netcdf4"] = "netcdf4"
) -> xr.Dataset:
    """Rechunk an xarray dataset.

    Parameters
    ----------
    dset: xarray.Dataset
        Input dataset that is going to be rechunked.
    engine: str, default: netcdf4
        The netcdf engine used to create the new netcdf file.

    Returns
    -------
    xarray.Dataset: rechunked dataset
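
    Examples
    --------
    A minimal usage sketch; the dataset is constructed ad hoc for
    illustration:

    >>> import numpy as np
    >>> import xarray as xr
    >>> dset = xr.Dataset(
    ...     {"tas": (("time", "lat", "lon"), np.zeros((10, 96, 192)))}
    ... )
    >>> rechunked = rechunk_dataset(dset)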
145 """
146 data, _ = _rechunk_dataset(dset.chunk(), engine)
147 return data


def rechunk_netcdf_file(
    input_path: os.PathLike,
    output_path: Optional[os.PathLike] = None,
    decode_cf: bool = True,
    engine: Literal["h5netcdf", "netcdf4"] = "netcdf4",
) -> None:
    """Rechunk netcdf files.

    Parameters
    ----------
    input_path: os.PathLike
        Input file/directory. If a directory is given, all ``.nc`` files in
        all subdirectories will be processed.
    output_path: os.PathLike
        Output file/directory of the rechunked netcdf file(s). Note: if
        ``input_path`` is a directory, ``output_path`` should also be a
        directory. If None is given (default), the input files are
        overwritten in place.
    decode_cf: bool, default: True
        Whether to decode these variables, assuming they were saved according
        to CF conventions.
    engine: str, default: netcdf4
        The netcdf engine used to create the new netcdf file.
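
    Examples
    --------
    A minimal usage sketch; the paths below are hypothetical:

    >>> rechunk_netcdf_file(
    ...     "/data/input",  # hypothetical directory containing .nc files
    ...     "/data/output",
    ...     engine="h5netcdf",
    ... )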
172 """
173 input_path = Path(input_path).expanduser().absolute()
174 for input_file in _search_for_nc_files(input_path):
175 logger.info("Working on file: %s", str(input_file))
176 if output_path is None:
177 output_file = input_file
178 elif Path(output_path).expanduser().absolute().is_dir():
179 output_file = Path(output_path).expanduser().absolute()
180 output_file /= input_file.relative_to(input_path)
181 else:
182 output_file = Path(output_path)
183 output_file.parent.mkdir(exist_ok=True, parents=True)
184 try:
185 with xr.open_mfdataset(
186 str(input_file),
187 parallel=True,
188 decode_cf=decode_cf,
189 ) as nc_data:
190 new_data, encoding = _rechunk_dataset(nc_data, engine)
191 if encoding:
192 logger.debug(
193 "Loading data into memory (%s).",
194 format_bytes(new_data.nbytes),
195 )
196 new_data = new_data.load()
197 except Exception as error:
198 logger.error(
199 "Error while processing file %s: %s",
200 str(input_file),
201 str(error),
202 )
203 continue
204 _save_dataset(
205 new_data,
206 output_file.with_suffix(input_file.suffix),
207 encoding,
208 engine,
209 override=output_file != input_file,
210 )