# http://pyrocko.org - GPLv3
#
# The Pyrocko Developers, 21st Century
# ---|P------/S----------~Lg----------

5 

import binascii
from base64 import b64decode, b64encode
from io import BytesIO

import numpy as num

from .guts import TBase, Object, ValidationError, literal

12 

13 

# Python 2/3 compatibility shim: keep the name ``unicode`` resolvable on
# interpreters that do not define it as a builtin by aliasing it to ``str``.
try:
    unicode
except NameError:
    unicode = str

18 

19 

# Mapping from NumPy dtype objects to explicit little-endian type strings.
# Used as the default set of dtypes for the 'base64+meta' serialization.
restricted_dtype_map = {
    num.dtype('float64'): '<f8',
    num.dtype('float32'): '<f4',
    num.dtype('int64'): '<i8',
    num.dtype('int32'): '<i4',
    num.dtype('int16'): '<i2',
    num.dtype('int8'): '<i1'}

# Reverse lookup: type string -> NumPy dtype object.
restricted_dtype_map_rev = {
    code: dtype for (dtype, code) in restricted_dtype_map.items()}

30 

31 

def array_equal(a, b):
    '''
    Check whether two arrays agree in dtype, shape and all elements.

    :param a: first array (anything exposing ``dtype`` and ``shape``)
    :param b: second array
    :returns: ``True`` if dtype, shape and every element are equal,
        ``False`` otherwise. Elements are compared with ``==``, so NaN
        entries never compare equal.
    '''
    # Cheap metadata checks first; they also guarantee that the elementwise
    # comparison below operates on arrays of identical shape.
    if a.dtype != b.dtype or a.shape != b.shape:
        return False

    # Normalize numpy.bool_ to a plain Python bool.
    return bool(num.all(a == b))

36 

37 

class Array(Object):
    '''
    Guts type for attributes holding :py:class:`numpy.ndarray` values.

    The nested type descriptor accepts the following arguments:

    :param shape: expected shape; ``None`` entries match any length along
        that axis, ``None`` as a whole disables the shape check
    :param dtype: required dtype of the array, ``None`` to accept any
    :param serialize_as: one of ``'table'``, ``'base64'``, ``'list'``,
        ``'npy'``, ``'base64+meta'`` or ``'base64-compat'``
    :param serialize_dtype: dtype of the raw bytes used by the base64
        based formats
    '''

    dummy_for = num.ndarray

    class __T(TBase):
        def __init__(
                self,
                shape=None,
                dtype=None,
                serialize_as='table',
                serialize_dtype=None,
                *args, **kwargs):

            TBase.__init__(self, *args, **kwargs)
            self.shape = shape
            self.dtype = dtype
            assert serialize_as in (
                'table', 'base64', 'list', 'npy',
                'base64+meta', 'base64-compat')
            self.serialize_as = serialize_as
            self.serialize_dtype = serialize_dtype

        def is_default(self, val):
            '''
            Check whether ``val`` equals the configured default value.

            ``None`` only matches a ``None`` default; arrays are compared
            with :py:func:`array_equal`.
            '''
            if self._default is None:
                return val is None
            elif val is None:
                return False
            else:
                return array_equal(self._default, val)

        def _array_from_table(self, text):
            '''
            Parse a whitespace separated text table into an array.
            '''
            # loadtxt only accepts ndmin values 0, 1 and 2 (not None), so
            # map "no shape configured" to 0 and clamp higher ranks. A rank
            # mismatch is then reported by validate_extra.
            ndmin = min(len(self.shape), 2) if self.shape else 0

            # float is loadtxt's documented default dtype; pass it
            # explicitly instead of None.
            dtype = float if self.dtype is None else self.dtype

            return num.loadtxt(
                BytesIO(text.encode('utf-8')), dtype=dtype, ndmin=ndmin)

        def _array_from_base64(self, text):
            '''
            Decode base64 encoded raw bytes of type ``serialize_dtype``.
            '''
            data = b64decode(text)
            return num.frombuffer(
                data, dtype=self.serialize_dtype).astype(self.dtype)

        def _array_from_npy(self, text):
            '''
            Decode a base64 encoded ``.npy`` payload.
            '''
            data = b64decode(text)
            try:
                return num.load(BytesIO(data), allow_pickle=False)
            except TypeError:
                # allow_pickle only available in newer NumPy
                # NOTE(review): this fallback does not disable pickle
                # loading; do not feed it untrusted data.
                return num.load(BytesIO(data))

        def _array_from_meta(self, val):
            '''
            Decode a ``'base64+meta'`` dict with keys ``data``, ``dtype``
            and ``shape``.

            :raises ValidationError: on malformed or disallowed content
            '''
            if sorted(val.keys()) != ['data', 'dtype', 'shape']:
                raise ValidationError(
                    'array in format "base64+meta" must have keys '
                    '"data", "dtype", and "shape"')

            shape = val['shape']
            if not isinstance(shape, list):
                raise ValidationError('invalid shape definition')

            for n in shape:
                if not isinstance(n, int):
                    raise ValidationError('invalid shape definition')

            serialize_dtype = val['dtype']
            allowed = list(restricted_dtype_map_rev.keys())
            if self.serialize_dtype is not None:
                allowed.append(self.serialize_dtype)

            if serialize_dtype not in allowed:
                # Stringify first: a configured serialize_dtype need not be
                # a string, which would break sorting and joining.
                raise ValidationError(
                    'only the following dtypes are allowed: %s'
                    % ', '.join(sorted(str(x) for x in allowed)))

            data = val['data']
            if not isinstance(data, str):
                raise ValidationError(
                    'data must be given as a base64 encoded string')

            data = b64decode(data)

            # Explicit None test so that a configured dtype is never
            # skipped because of its truth value.
            if self.dtype is not None:
                dtype = self.dtype
            else:
                dtype = restricted_dtype_map_rev.get(serialize_dtype)
                if dtype is None:
                    # Custom serialize_dtype which is not part of the
                    # restricted map.
                    dtype = num.dtype(serialize_dtype)

            arr = num.frombuffer(data, dtype=serialize_dtype).astype(dtype)

            # num.prod replaces num.product, which has been removed from
            # NumPy.
            if arr.size != num.prod(shape):
                raise ValidationError('size/shape mismatch')

            return arr.reshape(shape)

        def regularize_extra(self, val):
            '''
            Convert a serialized representation into an array.

            Strings are decoded according to ``serialize_as`` (left
            untouched for ``'list'`` and ``'base64+meta'``), dicts are
            decoded only for ``'base64+meta'``, anything else is passed
            through :py:func:`numpy.asarray` with the configured dtype.

            :returns: the converted value
            '''
            if isinstance(val, str):
                fmt = self.serialize_as
                if fmt == 'table':
                    val = self._array_from_table(val)

                elif fmt == 'base64':
                    val = self._array_from_base64(val)

                elif fmt == 'base64-compat':
                    try:
                        val = self._array_from_base64(val)
                    except (binascii.Error, ValueError):
                        # Not valid base64, or the decoded bytes do not fit
                        # the element size: treat the text as a table.
                        val = self._array_from_table(val)

                elif fmt == 'npy':
                    val = self._array_from_npy(val)

            elif isinstance(val, dict):
                if self.serialize_as == 'base64+meta':
                    val = self._array_from_meta(val)

            else:
                val = num.asarray(val, dtype=self.dtype)

            return val

        def validate_extra(self, val):
            '''
            Check type, dtype and shape of ``val``.

            :raises ValidationError: if ``val`` is not a
                :py:class:`numpy.ndarray` or does not match the configured
                dtype or shape
            '''
            if not isinstance(val, num.ndarray):
                raise ValidationError(
                    'object %s is not of type numpy.ndarray: %s' % (
                        self.xname(), type(val)))

            if self.dtype is not None and self.dtype != val.dtype:
                raise ValidationError(
                    'array %s not of required type: need %s, got %s' % (
                        self.xname(), self.dtype, val.dtype))

            if self.shape is not None:
                la, lb = len(self.shape), len(val.shape)
                if la != lb:
                    raise ValidationError(
                        'array %s dimension mismatch: need %i, got %i' % (
                            self.xname(), la, lb))

                for a, b in zip(self.shape, val.shape):
                    # None acts as a wildcard for the axis length.
                    if a is not None and a != b:
                        raise ValidationError(
                            'array %s shape mismatch: need %s, got: %s' % (
                                self.xname(), self.shape, val.shape))

        def to_save(self, val):
            '''
            Convert array ``val`` into its serialized representation, as
            selected by ``serialize_as``.
            '''
            if self.serialize_as == 'table':
                out = BytesIO()
                num.savetxt(out, val, fmt='%12.7g')
                return literal(out.getvalue().decode('utf-8'))

            elif self.serialize_as in ('base64', 'base64-compat'):
                data = val.astype(self.serialize_dtype).tobytes()
                return literal(b64encode(data).decode('utf-8'))

            elif self.serialize_as == 'list':
                if self.dtype == complex:
                    return [repr(x) for x in val]
                else:
                    return val.tolist()

            elif self.serialize_as == 'npy':
                out = BytesIO()
                try:
                    num.save(out, val, allow_pickle=False)
                except TypeError:
                    # allow_pickle only available in newer NumPy
                    num.save(out, val)

                return literal(b64encode(out.getvalue()).decode('utf-8'))

            elif self.serialize_as == 'base64+meta':
                # Explicit None test so that a configured serialize_dtype
                # is never skipped because of its truth value.
                serialize_dtype = self.serialize_dtype
                if serialize_dtype is None:
                    serialize_dtype = restricted_dtype_map[val.dtype]

                data = val.astype(serialize_dtype).tobytes()

                return dict(
                    dtype=serialize_dtype,
                    shape=val.shape,
                    data=literal(b64encode(data).decode('utf-8')))

209 

210 

# Public API of this module.
__all__ = ['Array']