1# http://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5from __future__ import absolute_import 

6 

7import numpy as num 

8from io import BytesIO 

9from base64 import b64decode, b64encode 

10import binascii 

11 

12from .guts import TBase, Object, ValidationError, literal, newstr 

13 

14 

# Python 2/3 compatibility shim: Python 3 no longer has a ``unicode``
# builtin, so the name is bound to ``str`` there instead.
try:
    unicode = unicode
except NameError:
    unicode = str

19 

20 

# Native NumPy dtype -> explicit little-endian type code. These are the
# dtypes accepted by default when (de)serializing in 'base64+meta' format.
restricted_dtype_map = dict(
    (num.dtype(type_name), type_code) for (type_name, type_code) in [
        ('float64', '<f8'),
        ('float32', '<f4'),
        ('int64', '<i8'),
        ('int32', '<i4'),
        ('int16', '<i2'),
        ('int8', '<i1'),
    ])

# Inverse lookup: little-endian type code -> native NumPy dtype.
restricted_dtype_map_rev = {
    type_code: dtype for dtype, type_code in restricted_dtype_map.items()}

31 

32 

def array_equal(a, b):
    '''
    Check two arrays for identical dtype, shape and element values.

    Unlike :py:func:`numpy.array_equal`, arrays of different dtype are never
    considered equal. Elements are compared with ``==``, so NaN entries make
    the arrays compare unequal.
    '''
    if a.dtype != b.dtype or a.shape != b.shape:
        return False

    return num.all(a == b)

37 

38 

# Characters that may legitimately appear in standard-alphabet base64 text.
_BASE64_ALPHABET = frozenset(
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    'abcdefghijklmnopqrstuvwxyz'
    '0123456789+/=')


def _b64decode_strict(s):
    '''
    Decode base64 text, rejecting characters outside the base64 alphabet.

    :py:func:`base64.b64decode` (non-validating) silently discards foreign
    characters such as the blanks and decimal points of a numeric table,
    which can make plain-text tables "decode" into garbage. Leading and
    trailing whitespace is tolerated.

    :raises binascii.Error: if ``s`` contains non-base64 characters or has
        invalid padding.
    '''
    s = s.strip()
    for c in s:
        if c not in _BASE64_ALPHABET:
            raise binascii.Error('not a base64 encoded string')

    return b64decode(s)


class Array(Object):
    '''
    Guts type for attributes holding :py:class:`numpy.ndarray` values.
    '''

    dummy_for = num.ndarray

    class __T(TBase):
        '''
        Property type for :py:class:`numpy.ndarray` attributes.

        :param shape: required shape; ``None`` entries accept any extent
            along that axis. ``None`` (default) disables the shape check.
        :param dtype: required dtype; ``None`` (default) accepts any dtype.
        :param serialize_as: serialization format, one of ``'table'``,
            ``'base64'``, ``'list'``, ``'npy'``, ``'base64+meta'`` or
            ``'base64-compat'``.
        :param serialize_dtype: dtype used for the raw bytes of the base64
            based formats.
        '''

        def __init__(
                self,
                shape=None,
                dtype=None,
                serialize_as='table',
                serialize_dtype=None,
                *args, **kwargs):

            TBase.__init__(self, *args, **kwargs)
            self.shape = shape
            self.dtype = dtype
            assert serialize_as in (
                'table', 'base64', 'list', 'npy',
                'base64+meta', 'base64-compat')
            self.serialize_as = serialize_as
            self.serialize_dtype = serialize_dtype

        def is_default(self, val):
            '''
            Return whether ``val`` equals this property's default.

            ``None`` matches only a ``None`` default; arrays are compared
            with :py:func:`array_equal` (dtype, shape and values).
            '''
            if self._default is None:
                return val is None
            elif val is None:
                return False
            else:
                return array_equal(self._default, val)

        def regularize_extra(self, val):
            '''
            Convert a deserialized value into a :py:class:`numpy.ndarray`.

            Strings are decoded according to ``serialize_as`` (``'table'``,
            ``'base64'``, ``'base64-compat'``, ``'npy'``); dicts are decoded
            when ``serialize_as`` is ``'base64+meta'``. Strings and dicts not
            matching the selected format are returned unchanged; any other
            value is passed through :py:func:`numpy.asarray`.

            :raises ValidationError: for malformed ``'base64+meta'`` dicts.
            '''
            if isinstance(val, (str, newstr)):
                # numpy.loadtxt only accepts ndmin in (0, 1, 2): passing None
                # (shape not given) or len(shape) > 2 raises ValueError.
                ndmin = min(len(self.shape), 2) if self.shape else 0

                if self.serialize_as == 'table':
                    val = num.loadtxt(
                        BytesIO(val.encode('utf-8')),
                        dtype=self.dtype, ndmin=ndmin)

                elif self.serialize_as == 'base64':
                    # NOTE(review): with dtype=None, astype(None) converts to
                    # float64 regardless of serialize_dtype.
                    data = b64decode(val)
                    val = num.frombuffer(
                        data, dtype=self.serialize_dtype).astype(self.dtype)

                elif self.serialize_as == 'base64-compat':
                    # Accept either base64 or a plain-text table. Decoding
                    # strictly keeps tables from being misread as base64; a
                    # ValueError from frombuffer (byte count not a multiple
                    # of the item size) also means "not our base64".
                    try:
                        data = _b64decode_strict(val)
                        val = num.frombuffer(
                            data,
                            dtype=self.serialize_dtype).astype(self.dtype)
                    except (binascii.Error, ValueError):
                        val = num.loadtxt(
                            BytesIO(val.encode('utf-8')),
                            dtype=self.dtype, ndmin=ndmin)

                elif self.serialize_as == 'npy':
                    data = b64decode(val)
                    try:
                        val = num.load(BytesIO(data), allow_pickle=False)
                    except TypeError:
                        # allow_pickle only available in newer NumPy.
                        # NOTE(review): older NumPy may unpickle object
                        # arrays here; avoid this path for untrusted input.
                        val = num.load(BytesIO(data))

            elif isinstance(val, dict):
                if self.serialize_as == 'base64+meta':
                    if set(val) != {'data', 'dtype', 'shape'}:
                        raise ValidationError(
                            'array in format "base64+meta" must have keys '
                            '"data", "dtype", and "shape"')

                    shape = val['shape']
                    # to_save() emits a tuple; text round trips give a list.
                    if not isinstance(shape, (list, tuple)):
                        raise ValidationError('invalid shape definition')

                    for n in shape:
                        if not isinstance(n, (int, num.integer)):
                            raise ValidationError('invalid shape definition')

                    serialize_dtype = val['dtype']
                    allowed = list(restricted_dtype_map_rev.keys())
                    if self.serialize_dtype is not None:
                        allowed.append(self.serialize_dtype)

                    if serialize_dtype not in allowed:
                        raise ValidationError(
                            'only the following dtypes are allowed: %s'
                            % ', '.join(sorted(str(x) for x in allowed)))

                    data = val['data']
                    if not isinstance(data, (str, newstr)):
                        raise ValidationError(
                            'data must be given as a base64 encoded string')

                    data = b64decode(data)

                    # Explicit None checks: numpy dtype instances are falsy,
                    # so ``self.dtype or ...`` would ignore them.
                    if self.dtype is not None:
                        dtype = self.dtype
                    elif serialize_dtype in restricted_dtype_map_rev:
                        dtype = restricted_dtype_map_rev[serialize_dtype]
                    else:
                        # Custom serialize_dtype configured on this property.
                        dtype = num.dtype(serialize_dtype)

                    val = num.frombuffer(
                        data, dtype=serialize_dtype).astype(dtype)

                    if val.size != num.prod(shape):
                        raise ValidationError('size/shape mismatch')

                    val = val.reshape(shape)

            else:
                val = num.asarray(val, dtype=self.dtype)

            return val

        def validate_extra(self, val):
            '''
            Check that ``val`` is an ndarray matching ``dtype`` and ``shape``.

            :raises ValidationError: on type, dtype, rank or extent mismatch.
            '''
            if not isinstance(val, num.ndarray):
                raise ValidationError(
                    'object %s is not of type numpy.ndarray: %s' % (
                        self.xname(), type(val)))
            if self.dtype is not None and self.dtype != val.dtype:
                raise ValidationError(
                    'array %s not of required type: need %s, got %s' % (
                        self.xname(), self.dtype, val.dtype))

            if self.shape is not None:
                la, lb = len(self.shape), len(val.shape)
                if la != lb:
                    raise ValidationError(
                        'array %s dimension mismatch: need %i, got %i' % (
                            self.xname(), la, lb))

                for a, b in zip(self.shape, val.shape):
                    # None in the required shape accepts any extent.
                    if a is not None:
                        if a != b:
                            raise ValidationError(
                                'array %s shape mismatch: need %s, got: %s' % (
                                    self.xname(), self.shape, val.shape))

        def to_save(self, val):
            '''
            Serialize ``val`` according to ``serialize_as``.

            Returns a ``literal`` string for ``'table'``, ``'base64'``,
            ``'base64-compat'`` and ``'npy'``, a list for ``'list'`` and a
            dict with ``dtype``, ``shape`` and ``data`` for ``'base64+meta'``.
            '''
            if self.serialize_as == 'table':
                out = BytesIO()
                num.savetxt(out, val, fmt='%12.7g')
                return literal(out.getvalue().decode('utf-8'))
            elif self.serialize_as == 'base64' \
                    or self.serialize_as == 'base64-compat':
                data = val.astype(self.serialize_dtype).tobytes()
                return literal(b64encode(data).decode('utf-8'))
            elif self.serialize_as == 'list':
                if self.dtype == complex:
                    return [repr(x) for x in val]
                else:
                    return val.tolist()
            elif self.serialize_as == 'npy':
                out = BytesIO()
                try:
                    num.save(out, val, allow_pickle=False)
                except TypeError:
                    # allow_pickle only available in newer NumPy
                    num.save(out, val)

                return literal(b64encode(out.getvalue()).decode('utf-8'))

            elif self.serialize_as == 'base64+meta':
                serialize_dtype = self.serialize_dtype or \
                    restricted_dtype_map[val.dtype]

                data = val.astype(serialize_dtype).tobytes()

                return dict(
                    dtype=serialize_dtype,
                    shape=val.shape,
                    data=literal(b64encode(data).decode('utf-8')))

210 

211 

# Public API of this module: only the Array guts type is exported.
__all__ = [
    'Array',
]