Coverage for C:\src\imod-python\imod\prepare\layerregrid.py: 95%

73 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 14:15 +0200

1import numba 

2import numpy as np 

3import xarray as xr 

4 

5from imod.prepare import common 

6 

# Layer regridding accepts every aggregation method from ``common.METHODS``
# except "conductance", "nearest" and "multilinear", which are dropped here.
METHODS = common.METHODS.copy()
del METHODS["conductance"], METHODS["nearest"], METHODS["multilinear"]

12 

13 

@numba.njit(cache=True)
def _regrid_layers(src, dst, src_top, dst_top, src_bot, dst_bot, method):
    """
    Map one set of layers onto the other.

    For every (row, column) cell and every destination layer, the values of
    all source layers that vertically overlap that destination layer are
    collected. The overlap thickness is used as the weight, and the values
    are aggregated with ``method``.

    Parameters
    ----------
    src : np.ndarray, shape (nlayer_src, nrow, ncol)
        Source values.
    dst : np.ndarray, shape (nlayer_dst, nrow, ncol)
        Destination array, written in place. Cells for which no source layer
        overlaps, or whose destination top/bottom is NaN, are not written.
    src_top, src_bot : np.ndarray, indexed as [layer, row, col]
        Source layer tops and bottoms. Layers with a NaN top or bottom are
        skipped.
    dst_top, dst_bot : np.ndarray, indexed as [layer, row, col]
        Destination layer tops and bottoms. Layers with a NaN top or bottom
        are skipped.
    method : jitted callable
        Aggregation ``method(values, weights)`` returning a scalar. It
        receives only the entries collected for the current destination cell.

    Returns
    -------
    dst : np.ndarray
        The same array object as the ``dst`` argument.
    """
    nlayer_src, nrow, ncol = src.shape
    nlayer_dst = dst.shape[0]
    # Scratch buffers, allocated once and reused for every destination cell.
    values = np.zeros(nlayer_src)
    weights = np.zeros(nlayer_src)

    for i in range(nrow):
        for j in range(ncol):
            src_t = src_top[:, i, j]
            dst_t = dst_top[:, i, j]
            src_b = src_bot[:, i, j]
            dst_b = dst_bot[:, i, j]

            # ii is index of dst
            for ii in range(nlayer_dst):
                dt = dst_t[ii]
                db = dst_b[ii]
                if np.isnan(dt) or np.isnan(db):
                    continue

                count = 0
                # jj is index of src
                for jj in range(nlayer_src):
                    st = src_t[jj]
                    sb = src_b[jj]
                    if np.isnan(st) or np.isnan(sb):
                        continue

                    overlap = common._overlap((db, dt), (sb, st))
                    if overlap == 0:
                        continue

                    values[count] = src[jj, i, j]
                    weights[count] = overlap
                    count += 1

                if count > 0:
                    # Pass only the filled part of the buffers. Zero padding
                    # (or stale entries) beyond ``count`` would otherwise be
                    # seen by order statistics such as minimum, maximum and
                    # median. Every cell rewrites [0, count) before reading
                    # it, so no reset of the buffers is needed.
                    dst[ii, i, j] = method(values[:count], weights[:count])

    return dst

64 

65 

class LayerRegridder:
    """
    Object to repeatedly regrid layers of similar objects. Compiles
    once on first call, can then be repeatedly called without
    JIT compilation overhead.

    Only the vertical (layer) discretization changes. Source and destination
    must share identical ``y`` and ``x`` coordinates.

    Parameters
    ----------
    method : str, function
        The method to use for regridding. Default available methods are:
        ``{"mean", "harmonic_mean", "geometric_mean", "sum", "minimum",
        "maximum", "mode", "median", "max_overlap"}``

    Attributes
    ----------
    method : function
        The aggregation function resolved from ``method`` by
        ``common._get_method``.
    """

    def __init__(self, method):
        _method = common._get_method(method, METHODS)
        self.method = _method
        # Compilation is deferred until the first call to ``regrid``.
        self._first_call = True

    def _make_regrid(self):
        """
        Build the jitted regrid function and store it as ``self._regrid``.

        ``self.method`` is jitted once and captured in a closure, so it is
        not passed as an argument on every call. This avoids numba overhead.
        """
        jit_method = numba.njit(self.method)

        @numba.njit
        def regrid(src, dst, src_top, dst_top, src_bot, dst_bot):
            return _regrid_layers(
                src, dst, src_top, dst_top, src_bot, dst_bot, jit_method
            )

        self._regrid = regrid

    def regrid(
        self, source, source_top, source_bottom, destination_top, destination_bottom
    ):
        """
        Regrid ``source`` from the source layers onto the destination layers.

        Parameters
        ----------
        source : xr.DataArray
            The values of the layered model.
        source_top : xr.DataArray
            The vertical location of the source layer tops.
        source_bottom : xr.DataArray
            The vertical location of the source layer bottoms.
        destination_top : xr.DataArray
            The vertical location of the destination layer tops.
        destination_bottom : xr.DataArray
            The vertical location of the destination layer bottoms.

        All arguments must have dims exactly ("layer", "y", "x").

        Returns
        -------
        regridded : xr.DataArray
            Array with the shape and coordinates of ``destination_top``.
            Cells without any overlapping source layer are NaN. Integer and
            boolean sources are returned as float64 so they can hold NaN.

        Raises
        ------
        TypeError
            If any argument is not an ``xr.DataArray``.
        ValueError
            In any of these cases:

            * dims are not exactly ("layer", "y", "x");
            * the source arrays disagree on coordinates;
            * the destination arrays disagree on coordinates;
            * source and destination disagree on ``y`` or ``x``.
        """

        def dim_format(dims):
            return ", ".join(str(dim) for dim in dims)

        # Checks on inputs
        for da in [
            source_top,
            source_bottom,
            source,
            destination_bottom,
            destination_top,
        ]:
            if not isinstance(da, xr.DataArray):
                raise TypeError("All arguments must be DataArrays.")
            if not da.dims == ("layer", "y", "x"):
                raise ValueError(
                    "Dimensions for top, bottom, and source have to be exactly"
                    f' ("layer", "y", "x"). Got instead {dim_format(da.dims)}.'
                )
        for da in [source_bottom, source]:
            for dim in ["layer", "y", "x"]:
                if not source_top[dim].equals(da[dim]):
                    raise ValueError(f"Input coordinates do not match along {dim}")
        # The destination bottoms must describe the same layers and cells as
        # the destination tops.
        for dim in ["layer", "y", "x"]:
            if not destination_top[dim].equals(destination_bottom[dim]):
                raise ValueError(
                    f"Destination coordinates do not match along {dim}"
                )
        # Only layers are regridded. The jitted kernel loops over the source
        # rows/columns and uses them to index the destination arrays, and
        # numba does not bounds-check by default. A horizontal mismatch would
        # therefore read or write out of bounds silently.
        for dim in ["y", "x"]:
            if not source_top[dim].equals(destination_top[dim]):
                raise ValueError(
                    f"Source and destination coordinates do not match along {dim}"
                )

        if self._first_call:
            self._make_regrid()
            self._first_call = False

        # NaN marks cells without overlapping source layers. It cannot be
        # stored in integer or boolean arrays (it would be cast to garbage),
        # so promote those dtypes to float64.
        if np.issubdtype(source.dtype, np.inexact):
            dtype = source.dtype
        else:
            dtype = np.float64

        dst = xr.full_like(destination_top, np.nan, dtype=dtype)
        dst.values = self._regrid(
            source.values,
            dst.values,
            source_top.values,
            destination_top.values,
            source_bottom.values,
            destination_bottom.values,
        )
        return dst