Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/guts_array.py: 93%

107 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2024-01-03 09:20 +0000

1# http://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5 

6''' 

7NumPy support for :py:mod:`pyrocko.guts`. 

8''' 

9 

10import numpy as num 

11from io import BytesIO 

12from base64 import b64decode, b64encode 

13import binascii 

14 

15from .guts import TBase, Object, ValidationError, literal 

16 

17 

# Python 2/3 compatibility shim: make sure the name ``unicode`` is defined.
try:
    unicode
except NameError:
    unicode = str


# Supported NumPy dtypes mapped to the type strings used when array data is
# serialized together with its metadata.
restricted_dtype_map = {
    num.dtype(type_name): type_code
    for (type_name, type_code) in (
        ('float64', '<f8'),
        ('float32', '<f4'),
        ('int64', '<i8'),
        ('int32', '<i4'),
        ('int16', '<i2'),
        ('int8', '<i1'))}

# Reverse lookup: type string -> NumPy dtype.
restricted_dtype_map_rev = {
    type_code: dtype
    for (dtype, type_code) in restricted_dtype_map.items()}

34 

35 

def array_equal(a, b):
    '''
    Check two arrays for equal dtype, equal shape and equal elements.

    Returns ``False`` as soon as dtype or shape differ; otherwise returns the
    result of :py:func:`numpy.all` on the element-wise comparison.
    '''
    if a.dtype != b.dtype:
        return False

    if a.shape != b.shape:
        return False

    return num.all(a == b)

40 

41 

class Array(Object):
    '''
    Placeholder for :py:class:`numpy.ndarray`

    Normally, no objects of this class should be instantiated. It is needed
    by Guts' type system.
    '''

    dummy_for = num.ndarray
    dummy_for_description = 'numpy.ndarray'

    # Type descriptor for array-valued attributes. It converts between
    # :py:class:`numpy.ndarray` objects and their serialized representation
    # and checks dtype and shape constraints.
    class __T(TBase):
        def __init__(
                self,
                shape=None,
                dtype=None,
                serialize_as='table',
                serialize_dtype=None,
                *args, **kwargs):
            '''
            Set up the array type descriptor.

            :param shape: required shape as a sequence with one entry per
                dimension; a ``None`` entry accepts any extent along that
                dimension. ``None`` disables the shape check.
            :param dtype: required dtype of the array, ``None`` to accept any.
            :param serialize_as: serialization format, one of ``'table'``,
                ``'base64'``, ``'list'``, ``'npy'``, ``'base64+meta'``,
                ``'base64-compat'``.
            :param serialize_dtype: dtype of the binary payload used by the
                base64 based formats.

            Remaining arguments are passed on to ``TBase.__init__``.
            '''

            TBase.__init__(self, *args, **kwargs)
            self.shape = shape
            self.dtype = dtype
            assert serialize_as in (
                'table', 'base64', 'list', 'npy',
                'base64+meta', 'base64-compat'), \
                'invalid serialize_as: %r' % (serialize_as,)
            self.serialize_as = serialize_as
            self.serialize_dtype = serialize_dtype

        def is_default(self, val):
            '''
            Check whether ``val`` equals the default value ``self._default``.

            ``None`` is only equal to a ``None`` default; otherwise dtype,
            shape and contents are compared with :py:func:`array_equal`.
            '''
            if self._default is None:
                return val is None
            elif val is None:
                return False
            else:
                return array_equal(self._default, val)

        def regularize_extra(self, val):
            '''
            Convert a serialized or array-like value into an array.

            Strings are decoded according to ``self.serialize_as``
            (``'table'``, ``'base64'``, ``'base64-compat'`` or ``'npy'``;
            with other formats the string is returned unchanged). Dicts are
            decoded only if the format is ``'base64+meta'`` and are returned
            unchanged otherwise. Any other value is passed through
            :py:func:`numpy.asarray` using ``self.dtype``.

            :raises ValidationError: if a ``'base64+meta'`` dict is malformed.
            '''
            if isinstance(val, str):
                # Minimum number of dimensions requested from loadtxt, so
                # that e.g. a single-row table still yields a 2D array.
                ndim = None
                if self.shape:
                    ndim = len(self.shape)

                if self.serialize_as == 'table':
                    val = num.loadtxt(
                        BytesIO(val.encode('utf-8')),
                        dtype=self.dtype, ndmin=ndim)

                elif self.serialize_as == 'base64':
                    data = b64decode(val)
                    val = num.frombuffer(
                        data, dtype=self.serialize_dtype).astype(self.dtype)

                elif self.serialize_as == 'base64-compat':
                    # Try base64 first; if the string cannot be decoded as
                    # base64, read it as a text table instead.
                    try:
                        data = b64decode(val)
                        val = num.frombuffer(
                            data,
                            dtype=self.serialize_dtype).astype(self.dtype)
                    except binascii.Error:
                        val = num.loadtxt(
                            BytesIO(val.encode('utf-8')),
                            dtype=self.dtype, ndmin=ndim)

                elif self.serialize_as == 'npy':
                    data = b64decode(val)
                    try:
                        val = num.load(BytesIO(data), allow_pickle=False)
                    except TypeError:
                        # allow_pickle only available in newer NumPy
                        # NOTE(review): on such old NumPy versions this
                        # fallback presumably permits unpickling of object
                        # arrays - confirm before using on untrusted input.
                        val = num.load(BytesIO(data))

            elif isinstance(val, dict):
                if self.serialize_as == 'base64+meta':
                    val = self._regularize_base64_meta(val)

            else:
                val = num.asarray(val, dtype=self.dtype)

            return val

        def _regularize_base64_meta(self, val):
            '''
            Decode a dict in ``'base64+meta'`` format into an array.

            :param val: dict with exactly the keys ``'data'`` (base64 encoded
                string), ``'dtype'`` (type string of the payload) and
                ``'shape'`` (list of non-negative integers).
            :returns: array of dtype ``self.dtype`` (or, if that is ``None``,
                the dtype corresponding to the payload's type string),
                reshaped to ``shape``.
            :raises ValidationError: if keys, shape, dtype or data are
                invalid or inconsistent.
            '''
            if sorted(val.keys()) != ['data', 'dtype', 'shape']:
                raise ValidationError(
                    'array in format "base64+meta" must have keys '
                    '"data", "dtype", and "shape"')

            shape = val['shape']
            if not isinstance(shape, list):
                raise ValidationError('invalid shape definition')

            for n in shape:
                # Negative extents can never describe a valid array and
                # would otherwise only fail later, inside reshape.
                if not isinstance(n, int) or n < 0:
                    raise ValidationError('invalid shape definition')

            serialize_dtype = val['dtype']
            allowed = list(restricted_dtype_map_rev.keys())
            if self.serialize_dtype is not None:
                allowed.append(self.serialize_dtype)

            if serialize_dtype not in allowed:
                # str() keeps the message buildable when serialize_dtype has
                # been configured as a non-string dtype specification.
                raise ValidationError(
                    'only the following dtypes are allowed: %s'
                    % ', '.join(sorted(str(x) for x in allowed)))

            data = val['data']
            if not isinstance(data, str):
                raise ValidationError(
                    'data must be given as a base64 encoded string')

            # Explicit None check: numpy.dtype instances are falsy, so a
            # truthiness test would ignore a dtype configured that way.
            if self.dtype is not None:
                dtype = self.dtype
            else:
                dtype = restricted_dtype_map_rev.get(serialize_dtype)
                if dtype is None:
                    # Custom serialize_dtype which is not part of the
                    # restricted map: use it directly.
                    dtype = num.dtype(serialize_dtype)

            try:
                raw = b64decode(data)
                payload = num.frombuffer(raw, dtype=serialize_dtype)
            except (binascii.Error, ValueError):
                # Broken base64 or a payload length which is not a multiple
                # of the item size.
                raise ValidationError(
                    'data is not valid base64 or does not match dtype "%s"'
                    % str(serialize_dtype))

            val = payload.astype(dtype)

            if val.size != num.prod(shape):
                raise ValidationError('size/shape mismatch')

            return val.reshape(shape)

        def validate_extra(self, val):
            '''
            Check that ``val`` is an array with the required dtype and shape.

            :raises ValidationError: if ``val`` is not a
                :py:class:`numpy.ndarray`, if its dtype differs from
                ``self.dtype``, or if its number of dimensions or any extent
                fixed in ``self.shape`` does not match.
            '''
            if not isinstance(val, num.ndarray):
                raise ValidationError(
                    'object %s is not of type numpy.ndarray: %s' % (
                        self.xname(), type(val)))

            if self.dtype is not None and self.dtype != val.dtype:
                raise ValidationError(
                    'array %s not of required type: need %s, got %s' % (
                        self.xname(), self.dtype, val.dtype))

            if self.shape is not None:
                la, lb = len(self.shape), len(val.shape)
                if la != lb:
                    raise ValidationError(
                        'array %s dimension mismatch: need %i, got %i' % (
                            self.xname(), la, lb))

                for a, b in zip(self.shape, val.shape):
                    # None entries in the required shape act as wildcards.
                    if a is not None and a != b:
                        raise ValidationError(
                            'array %s shape mismatch: need %s, got: %s' % (
                                self.xname(), self.shape, val.shape))

        def to_save(self, val):
            '''
            Convert array ``val`` into its serialized representation.

            Depending on ``self.serialize_as`` this is a text table, a base64
            encoded string of the raw bytes or of the array in ``.npy``
            format, a (nested) list, or a dict with keys ``dtype``, ``shape``
            and ``data`` for ``'base64+meta'``.
            '''
            if self.serialize_as == 'table':
                out = BytesIO()
                num.savetxt(out, val, fmt='%12.7g')
                return literal(out.getvalue().decode('utf-8'))

            elif self.serialize_as == 'base64' \
                    or self.serialize_as == 'base64-compat':
                data = val.astype(self.serialize_dtype).tobytes()
                return literal(b64encode(data).decode('utf-8'))

            elif self.serialize_as == 'list':
                if self.dtype == complex:
                    # Complex elements are stored as their repr() strings.
                    return [repr(x) for x in val]
                else:
                    return val.tolist()

            elif self.serialize_as == 'npy':
                out = BytesIO()
                try:
                    num.save(out, val, allow_pickle=False)
                except TypeError:
                    # allow_pickle only available in newer NumPy
                    num.save(out, val)

                return literal(b64encode(out.getvalue()).decode('utf-8'))

            elif self.serialize_as == 'base64+meta':
                # NOTE(review): a serialize_dtype given as a numpy.dtype
                # instance is presumably falsy here and thus ignored; kept
                # as is so that the serialized output does not change.
                serialize_dtype = self.serialize_dtype or \
                    restricted_dtype_map[val.dtype]

                data = val.astype(serialize_dtype).tobytes()

                return dict(
                    dtype=serialize_dtype,
                    shape=val.shape,
                    data=literal(b64encode(data).decode('utf-8')))

220 

221 

# Names exported by ``from pyrocko.guts_array import *``.
__all__ = [
    'Array',
]