Coverage for src/rechunk_data/_rechunk.py: 99%

84 statements (coverage.py v7.3.1, created at 2023-09-30 09:58 +0000)

1"""Rechunking module.""" 

2 

3import os 

4import logging 

5from pathlib import Path 

6from typing import cast, Any, Dict, Hashable, Generator, Optional, Tuple 

7from typing_extensions import Literal 

8 

9from dask.utils import format_bytes 

10from dask.array.core import Array 

11import xarray as xr 

12 

13 

14logging.basicConfig( 

15 format="%(name)s - %(levelname)s - %(message)s", level=logging.ERROR 

16) 

17logger = logging.getLogger("rechunk-data") 


# Per-engine sets of encoding keys that are carried over to the output file.
ENCODINGS = dict(
    h5netcdf={
        "_FillValue",
        "complevel",
        "chunksizes",
        "dtype",
        "zlib",
        "compression_opts",
        "shuffle",
        "fletcher32",
        "compression",
        "contiguous",
    },
    netcdf4={
        "contiguous",
        "complevel",
        "zlib",
        "shuffle",
        "_FillValue",
        "least_significant_digit",
        "chunksizes",
        "fletcher32",
        "dtype",
    },
)


def _search_for_nc_files(input_path: Path) -> Generator[Path, None, None]:
    # Path.suffix includes the leading dot, so both entries need it.
    suffixes = [".nc", ".nc4"]
    input_path = input_path.expanduser().absolute()
    if input_path.is_dir() and input_path.exists():
        nc_iter = input_path.rglob("*.*")
    elif input_path.is_file() and input_path.exists():
        nc_iter = cast(Generator[Path, None, None], iter([input_path]))
    else:
        # This could be a path with a glob pattern, let's try to construct it
        nc_iter = input_path.parent.rglob(input_path.name)
    for ncfile in nc_iter:
        if ncfile.suffix in suffixes:
            yield ncfile
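
# Usage sketch for the glob fallback above (hypothetical path, not part of
# the module): a pattern such as "~/data/2020*.nc" is neither an existing
# directory nor an existing file, so its file name part is matched
# recursively below "~/data":
#
#     for nc_file in _search_for_nc_files(Path("~/data/2020*.nc")):
#         print(nc_file)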


def _save_dataset(
    dset: xr.Dataset,
    file_name: Path,
    encoding: Dict[Hashable, Dict[str, Any]],
    engine: Literal["netcdf4", "h5netcdf"],
    override: bool = False,
) -> None:
    # An empty encoding means nothing was rechunked; skip the write unless
    # the caller forces it (e.g. when writing to a different output location).
    if not encoding and not override:
        logger.debug("Chunk size already optimized for %s", file_name.name)
        return
    logger.debug("Saving file to %s using %s engine", str(file_name), engine)
    try:
        dset.to_netcdf(
            file_name,
            engine=engine,
            encoding=encoding,
        )
    except Exception as error:
        logger.error("Saving to file failed: %s", str(error))
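
# Shape of the encoding mapping consumed above (illustrative variable name
# and values only): one entry per data variable, restricted to the keys in
# ENCODINGS for the chosen engine, e.g.
#
#     {"tas": {"zlib": True, "complevel": 4, "chunksizes": (227, 192, 384)}}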


def _rechunk_dataset(
    dset: xr.Dataset,
    engine: Literal["h5netcdf", "netcdf4"],
) -> Tuple[xr.Dataset, Dict[Hashable, Dict[str, Any]]]:
    encoding: Dict[Hashable, Dict[str, Any]] = {}
    try:
        _keywords = ENCODINGS[engine]
    except KeyError as error:
        raise ValueError(
            f"Only the following engines are supported: {', '.join(ENCODINGS.keys())}"
        ) from error
    for data_var in dset.data_vars:
        var = str(data_var)
        if (
            not isinstance(dset[var].data, Array)
            or "bnds" in var
            or "rotated_pole" in var
        ):
            logger.debug("Skipping rechunking variable %s", var)
            continue
        logger.debug("Rechunking variable %s", var)
        # Keep the existing chunking along horizontal (lon/lat) and bounds
        # dimensions; let dask pick chunk sizes ("auto") along the remaining
        # dimensions (typically time).
        chunks: Dict[int, Optional[str]] = {}
        for i, dim in enumerate(map(str, dset[var].dims)):
            if "lon" in dim.lower() or "lat" in dim.lower() or "bnds" in dim.lower():
                chunks[i] = None
            else:
                chunks[i] = "auto"
        old_chunks = dset[var].encoding.get("chunksizes")
        new_chunks = dset[var].data.rechunk(chunks).chunksize
        if new_chunks == old_chunks:
            logger.debug("%s: chunk sizes already optimized, skipping", var)
            continue
        dset[var] = dset[var].chunk(dict(zip(dset[var].dims, new_chunks)))
        logger.debug(
            "%s: old chunk size: %s, new chunk size: %s",
            var,
            old_chunks,
            new_chunks,
        )
        logger.debug("Setting encoding of variable %s", var)
        encoding[data_var] = {
            str(k): v for k, v in dset[var].encoding.items() if str(k) in _keywords
        }
        if engine != "netcdf4" or encoding[data_var].get("contiguous", False) is False:
            encoding[data_var]["chunksizes"] = new_chunks
    return dset, encoding
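
# Illustration of the per-axis spec built above (hedged sketch, not part of
# the module): for a (time, lat, lon) variable the loop produces
# {0: "auto", 1: None, 2: None}, so dask re-chunks along time only and leaves
# the horizontal chunking untouched:
#
#     import dask.array as da
#     arr = da.ones((10_000, 192, 384), chunks=(1, 192, 384))
#     arr.rechunk({0: "auto", 1: None, 2: None}).chunksize
#     # e.g. (227, 192, 384); dask picks the time chunk to hit its default
#     # target chunk size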


def rechunk_dataset(
    dset: xr.Dataset, engine: Literal["h5netcdf", "netcdf4"] = "netcdf4"
) -> xr.Dataset:
    """Rechunk an xarray dataset.

    Parameters
    ----------
    dset: xarray.Dataset
        Input dataset that is going to be rechunked
    engine: str, default: netcdf4
        The netcdf engine used to create the new netcdf file.

    Returns
    -------
    xarray.Dataset: rechunked dataset
    """
    data, _ = _rechunk_dataset(dset.chunk(), engine)
    return data
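
# Usage sketch (hypothetical file name; assumes the package re-exports
# rechunk_dataset at its top level):
#
#     import xarray as xr
#     from rechunk_data import rechunk_dataset
#
#     dset = xr.open_dataset("temperature.nc")
#     rechunked = rechunk_dataset(dset, engine="h5netcdf")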


def rechunk_netcdf_file(
    input_path: os.PathLike,
    output_path: Optional[os.PathLike] = None,
    decode_cf: bool = True,
    engine: Literal["h5netcdf", "netcdf4"] = "netcdf4",
) -> None:
    """Rechunk netcdf files.

    Parameters
    ----------
    input_path: os.PathLike
        Input file/directory. If a directory is given, all ``.nc`` files in
        all subdirectories will be processed.
    output_path: os.PathLike, default: None
        Output file/directory of the rechunked netcdf file(s). Note: if
        ``input_path`` is a directory, ``output_path`` should also be a
        directory. If None is given (default), the input files are
        overwritten in place.
    decode_cf: bool, default: True
        Whether to decode the variables, assuming they were saved according
        to CF conventions.
    engine: str, default: netcdf4
        The netcdf engine used to create the new netcdf file.
    """
    input_path = Path(input_path).expanduser().absolute()
    for input_file in _search_for_nc_files(input_path):
        logger.info("Working on file: %s", str(input_file))
        if output_path is None:
            output_file = input_file
        elif Path(output_path).expanduser().absolute().is_dir():
            # Mirror the input directory layout below the output directory.
            output_file = Path(output_path).expanduser().absolute()
            output_file /= input_file.relative_to(input_path)
        else:
            output_file = Path(output_path)
        output_file.parent.mkdir(exist_ok=True, parents=True)
        try:
            with xr.open_mfdataset(
                str(input_file),
                parallel=True,
                decode_cf=decode_cf,
            ) as nc_data:
                new_data, encoding = _rechunk_dataset(nc_data, engine)
                if encoding:
                    logger.debug(
                        "Loading data into memory (%s).",
                        format_bytes(new_data.nbytes),
                    )
                    new_data = new_data.load()
        except Exception as error:
            logger.error(
                "Error while processing file %s: %s",
                str(input_file),
                str(error),
            )
            continue
        _save_dataset(
            new_data,
            output_file.with_suffix(input_file.suffix),
            encoding,
            engine,
            override=output_file != input_file,
        )
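
# Usage sketch (hypothetical paths; assumes the package re-exports
# rechunk_netcdf_file at its top level):
#
#     from rechunk_data import rechunk_netcdf_file
#
#     # Rechunk every .nc/.nc4 file found below "data", writing results to
#     # the same relative locations below "rechunked".
#     rechunk_netcdf_file("data", "rechunked", engine="h5netcdf")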