Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from numbers import Number 

2import operator 

3import os 

4import threading 

5import contextlib 

6 

7import numpy as np 

8# good_size is exposed (and used) from this import 

9from .pypocketfft import good_size 

10 

# Thread-local storage holding the default worker count set by set_workers.
_config = threading.local()
# os.cpu_count() returns None when the count cannot be determined; fall back
# to 1 so that the negative-workers arithmetic in _workers stays well defined.
_cpu_count = os.cpu_count() or 1

13 

14 

def _iterable_of_int(x, name=None):
    """Coerce a scalar or a sequence into a list of ints.

    Parameters
    ----------
    x : value, or sequence of values, convertible to int
        A lone number is treated as a one-element sequence.
    name : str, optional
        Label of the argument being converted; used only in the error
        message. Falls back to ``"value"`` when omitted or empty.

    Returns
    -------
    y : ``List[int]``

    Raises
    ------
    ValueError
        If iterating ``x`` or converting one of its elements with
        ``operator.index`` raises ``TypeError``.
    """
    items = (x,) if isinstance(x, Number) else x

    try:
        return [operator.index(item) for item in items]
    except TypeError as err:
        label = name or "value"
        message = "{} must be a scalar or iterable of integers".format(label)
        raise ValueError(message) from err

39 

40 

def _init_nd_shape_and_axes(x, shape, axes):
    """Handles shape and axes arguments for nd transforms

    Parameters
    ----------
    x : object exposing ``ndim`` and ``shape``
        The input whose dimensions the arguments are validated against.
    shape : int, sequence of int, or None
        Requested length along each transformed axis. An entry of ``-1``
        means "use the length ``x`` already has along that axis".
    axes : int, sequence of int, or None
        Axes to transform. Negative values count from the last axis.

    Returns
    -------
    shape : list of int
        Length along each transformed axis.
    axes : list of int or range
        Non-negative axis indices matching ``shape`` element by element.

    Raises
    ------
    ValueError
        If an axis is out of range or repeated, if ``shape`` and ``axes``
        are both given with different lengths, if ``shape`` alone has more
        entries than ``x`` has dimensions, or if any resulting length is
        smaller than one.
    """
    noshape = shape is None
    noaxes = axes is None

    if not noaxes:
        axes = _iterable_of_int(axes, 'axes')
        # Map negative axes onto their non-negative equivalents.
        axes = [a + x.ndim if a < 0 else a for a in axes]

        if any(a >= x.ndim or a < 0 for a in axes):
            raise ValueError("axes exceeds dimensionality of input")
        if len(set(axes)) != len(axes):
            raise ValueError("all axes must be unique")

    if not noshape:
        shape = _iterable_of_int(shape, 'shape')

        if noaxes:
            if len(shape) > x.ndim:
                raise ValueError("shape requires more axes than are present")
            # Without explicit axes, shape applies to the trailing axes.
            axes = range(x.ndim - len(shape), x.ndim)
        elif len(axes) != len(shape):
            # Compare lengths whenever axes was supplied. Testing the
            # truthiness of axes instead would let an empty axes sequence
            # skip this check and silently discard shape in the zip below.
            raise ValueError("when given, axes and shape arguments"
                             " have to be of the same length")

        shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)]
    elif noaxes:
        shape = list(x.shape)
        axes = range(x.ndim)
    else:
        shape = [x.shape[a] for a in axes]

    if any(s < 1 for s in shape):
        raise ValueError(
            "invalid number of data points ({0}) specified".format(shape))

    return shape, axes

78 

79 

def _asfarray(x):
    """
    Convert to array with floating or complex dtype.

    float16 values are also promoted to float32, and any input whose dtype
    kind is neither float nor complex is converted to float64. Float and
    complex inputs are returned with native byte order; an input whose
    ``ALIGNED`` flag is false is always copied.

    Parameters
    ----------
    x : array_like
        Anything without a ``dtype`` attribute is first passed through
        ``np.asarray``.

    Returns
    -------
    out : ndarray
        May be ``x`` itself or share its memory when no conversion is needed.
    """
    if not hasattr(x, "dtype"):
        x = np.asarray(x)

    if x.dtype == np.float16:
        return np.asarray(x, np.float32)
    elif x.dtype.kind not in 'fc':
        return np.asarray(x, np.float64)

    # Require native byte order
    dtype = x.dtype.newbyteorder('=')

    # Always align input: force a fresh copy when x is not aligned.
    if not x.flags['ALIGNED']:
        return np.array(x, dtype=dtype, copy=True)

    # np.asarray copies only when it has to (e.g. to swap the byte order).
    # np.array(..., copy=False) is not equivalent: on NumPy >= 2 it raises
    # ValueError whenever a copy cannot be avoided.
    return np.asarray(x, dtype=dtype)

99 

def _datacopied(arr, original):
    """
    Report whether `arr` owns its data rather than borrowing from `original`.

    The check is strict and assumes `arr` was obtained as
    ``asarray(original)``. It returns False when both names refer to the same
    object, and when `original` is a non-ndarray object exposing
    ``__array__``. Otherwise `arr` counts as a copy only if it has no base.
    """
    if arr is original:
        return False

    foreign_array_like = (not isinstance(original, np.ndarray)
                          and hasattr(original, '__array__'))
    return False if foreign_array_like else arr.base is None

110 

111 

def _fix_shape(x, shape, axes):
    """Truncate or zero-pad ``x`` so each listed axis has the given length.

    Internal auxiliary function for _raw_fft, _raw_fftnd.

    Parameters
    ----------
    x : ndarray
    shape : sequence of int
        Target length for each entry of ``axes``.
    axes : sequence of int
        Axes of ``x`` to resize, paired element-wise with ``shape``.

    Returns
    -------
    out : ndarray
        ``x`` sliced down to the requested lengths when no axis has to grow;
        otherwise a new zero-filled array holding the overlapping part of
        ``x``.
    copied : bool
        True only when the new zero-filled array was created.
    """
    # Per-axis slices selecting the part of x that fits the target shape.
    selector = [slice(None)] * x.ndim
    needs_padding = False
    for length, axis in zip(shape, axes):
        available = x.shape[axis]
        if available < length:
            # Axis has to grow: take all of it and pad afterwards.
            needs_padding = True
            selector[axis] = slice(0, available)
        else:
            selector[axis] = slice(0, length)
    selector = tuple(selector)

    if needs_padding:
        padded_shape = list(x.shape)
        for length, axis in zip(shape, axes):
            padded_shape[axis] = length

        padded = np.zeros(padded_shape, x.dtype)
        padded[selector] = x[selector]
        return padded, True

    return x[selector], False

137 

138 

def _fix_shape_1d(x, n, axis):
    """Resize ``x`` to length ``n`` along the single axis ``axis``.

    Rejects ``n`` smaller than one with a ValueError, then hands a
    one-element shape and axes tuple to `_fix_shape` and returns its result.
    """
    too_small = n < 1
    if too_small:
        message = "invalid number of data points ({0}) specified".format(n)
        raise ValueError(message)

    target_shape, target_axes = (n,), (axis,)
    return _fix_shape(x, target_shape, target_axes)

145 

146 

def _normalization(norm, forward):
    """Translate the ``norm`` argument into a pypocketfft normalization mode.

    Parameters
    ----------
    norm : None or "ortho"
    forward : bool
        Whether the transform is a forward one; only consulted when ``norm``
        is None.

    Returns
    -------
    int
        0 (forward) or 2 (otherwise) when ``norm`` is None, 1 for "ortho".

    Raises
    ------
    ValueError
        For any other ``norm`` value.
    """
    if norm is None:
        if forward:
            return 0
        return 2

    if norm == 'ortho':
        return 1

    message = 'Invalid norm value {}, should be None or "ortho".'.format(norm)
    raise ValueError(message)

158 

159 

def _workers(workers):
    """Resolve a ``workers`` argument to the value that is actually used.

    ``None`` selects the default stored on ``_config`` (1 when none is set).
    A negative value no smaller than ``-_cpu_count`` is shifted by
    ``1 + _cpu_count``, so -1 maps to ``_cpu_count``. Any other non-zero
    value is returned unchanged.

    Raises
    ------
    ValueError
        If ``workers`` is zero or smaller than ``-_cpu_count``.
    """
    if workers is None:
        return getattr(_config, 'default_workers', 1)

    if workers == 0:
        raise ValueError("workers must not be zero")

    if not workers < 0:
        return workers

    # Negative values count back from the number of CPUs.
    lowest = -_cpu_count
    if workers >= lowest:
        return workers + (1 + _cpu_count)

    raise ValueError("workers value out of range; got {}, must not be"
                     " less than {}".format(workers, lowest))

174 

175 

@contextlib.contextmanager
def set_workers(workers):
    """Context manager for the default number of workers used in `scipy.fft`

    The previous default is restored when the ``with`` block exits, whether
    normally or through an exception.

    Parameters
    ----------
    workers : int
        The default number of workers to use

    Examples
    --------
    >>> import numpy as np
    >>> from scipy import fft, signal
    >>> x = np.random.randn(128, 64)
    >>> with fft.set_workers(4):
    ...     y = signal.fftconvolve(x, x)

    """
    previous = get_workers()
    # Validate before touching the stored default, so a bad value changes
    # nothing.
    requested = operator.index(workers)
    _config.default_workers = _workers(requested)
    try:
        yield
    finally:
        _config.default_workers = previous

199 

200 

def get_workers():
    """Returns the default number of workers within the current context

    The value is 1 unless a default has been stored on ``_config``, as
    `set_workers` does for the duration of its ``with`` block.

    Examples
    --------
    >>> from scipy import fft
    >>> fft.get_workers()
    1
    >>> with fft.set_workers(4):
    ...     fft.get_workers()
    4
    """
    try:
        return _config.default_workers
    except AttributeError:
        # No default has been stored on _config yet.
        return 1