1# http://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5from __future__ import absolute_import 

6 

7import numpy as num 

8from io import BytesIO 

9from base64 import b64decode, b64encode 

10import binascii 

11 

12from .guts import TBase, Object, ValidationError, literal, newstr 

13 

14 

# Python 2/3 compatibility shim: where the builtin name ``unicode`` does not
# exist (looking it up raises NameError), bind it to ``str`` instead.
try:
    unicode
except NameError:
    unicode = str

19 

20 

# NumPy dtypes which may be stored without an explicitly configured
# ``serialize_dtype``, mapped to the little-endian type code that is written
# next to the data.
restricted_dtype_map = {
    num.dtype(type_name): type_code
    for (type_name, type_code) in (
        ('float64', '<f8'),
        ('float32', '<f4'),
        ('int64', '<i8'),
        ('int32', '<i4'),
        ('int16', '<i2'),
        ('int8', '<i1'),
    )
}

# Inverse lookup: type code -> NumPy dtype.
restricted_dtype_map_rev = {
    type_code: dtype
    for (dtype, type_code) in restricted_dtype_map.items()
}

31 

32 

def array_equal(a, b):
    """
    Check whether two arrays agree in dtype, shape and content.

    :param a: first array
    :param b: second array
    :returns: ``False`` if the dtypes or the shapes differ, otherwise the
        result of ``num.all(a == b)``
    """
    if a.dtype != b.dtype:
        return False

    if a.shape != b.shape:
        return False

    # Only compare element-wise once dtype and shape are known to agree.
    return num.all(a == b)

37 

38 

class Array(Object):
    """
    Placeholder type for :py:class:`numpy.ndarray` attributes.

    The nested ``__T`` class converts incoming values to arrays, validates
    their dtype and shape, and turns arrays into a serializable form. The
    representation is selected with ``serialize_as``:

    ``'table'``
        Text table written with :py:func:`numpy.savetxt` using the format
        ``'%12.7g'`` and read back with :py:func:`numpy.loadtxt`.
    ``'base64'``
        Raw bytes of the array, cast to ``serialize_dtype``, base64 encoded.
        The shape is not stored; reading yields a one-dimensional array.
    ``'base64-compat'``
        Written like ``'base64'``. On reading, falls back to the ``'table'``
        parser if the string cannot be base64 decoded.
    ``'list'``
        ``val.tolist()``, or a list of ``repr`` strings if ``dtype`` is
        ``complex``.
    ``'npy'``
        Output of :py:func:`numpy.save`, base64 encoded.
    ``'base64+meta'``
        Dict with the keys ``'dtype'``, ``'shape'`` and ``'data'`` (base64
        encoded raw bytes).
    """

    dummy_for = num.ndarray

    class __T(TBase):
        def __init__(
                self,
                shape=None,
                dtype=None,
                serialize_as='table',
                serialize_dtype=None,
                *args, **kwargs):
            """
            :param shape: required shape as a sequence with one entry per
                dimension; ``None`` entries leave that dimension
                unconstrained. ``None`` disables shape checking.
            :param dtype: required dtype of the array, or ``None`` to
                disable dtype checking.
            :param serialize_as: one of ``'table'``, ``'base64'``,
                ``'list'``, ``'npy'``, ``'base64+meta'``,
                ``'base64-compat'``.
            :param serialize_dtype: dtype the data is cast to before its raw
                bytes are encoded in the base64 based modes.

            Remaining arguments are passed on to ``TBase.__init__``.
            """

            TBase.__init__(self, *args, **kwargs)
            self.shape = shape
            self.dtype = dtype
            assert serialize_as in (
                'table', 'base64', 'list', 'npy',
                'base64+meta', 'base64-compat'), \
                'invalid serialize_as: %r' % (serialize_as,)
            self.serialize_as = serialize_as
            self.serialize_dtype = serialize_dtype

        def is_default(self, val):
            """
            Check whether ``val`` equals the configured default value.

            A ``None`` default only matches ``None``; otherwise dtype, shape
            and content are compared with :py:func:`array_equal`.
            """
            if self._default is None:
                return val is None
            elif val is None:
                return False
            else:
                return array_equal(self._default, val)

        def regularize_extra(self, val):
            """
            Convert a serialized or array-like value into an array.

            Strings are decoded according to ``serialize_as`` (``'table'``,
            ``'base64'``, ``'base64-compat'`` and ``'npy'``; strings are
            returned unchanged in the other modes). Dicts are decoded in
            ``'base64+meta'`` mode and returned unchanged otherwise. Any
            other value is passed to :py:func:`numpy.asarray`.

            :raises ValidationError: if a ``'base64+meta'`` dict is malformed.
            """
            if isinstance(val, (str, newstr)):
                # numpy.loadtxt accepts only 0, 1 or 2 for ``ndmin``. Use 0
                # (its default) when no shape is declared instead of passing
                # None, which loadtxt rejects.
                ndim = len(self.shape) if self.shape else 0

                if self.serialize_as == 'table':
                    val = num.loadtxt(
                        BytesIO(val.encode('utf-8')),
                        dtype=self.dtype, ndmin=ndim)

                elif self.serialize_as == 'base64':
                    data = b64decode(val)
                    # frombuffer yields a read-only view on ``data``; astype
                    # copies it into a fresh, writable array.
                    val = num.frombuffer(
                        data, dtype=self.serialize_dtype).astype(self.dtype)

                elif self.serialize_as == 'base64-compat':
                    try:
                        data = b64decode(val)
                        val = num.frombuffer(
                            data,
                            dtype=self.serialize_dtype).astype(self.dtype)
                    except binascii.Error:
                        # Not valid base64: treat the string as a text table.
                        val = num.loadtxt(
                            BytesIO(val.encode('utf-8')),
                            dtype=self.dtype, ndmin=ndim)

                elif self.serialize_as == 'npy':
                    data = b64decode(val)
                    try:
                        val = num.load(BytesIO(data), allow_pickle=False)
                    except TypeError:
                        # allow_pickle only available in newer NumPy
                        # NOTE(review): this fallback cannot disable
                        # unpickling; avoid feeding it untrusted data.
                        val = num.load(BytesIO(data))

            elif isinstance(val, dict):
                if self.serialize_as == 'base64+meta':
                    if not sorted(val.keys()) == ['data', 'dtype', 'shape']:
                        raise ValidationError(
                            'array in format "base64+meta" must have keys '
                            '"data", "dtype", and "shape"')

                    shape = val['shape']
                    if not isinstance(shape, list):
                        raise ValidationError('invalid shape definition')

                    for n in shape:
                        if not isinstance(n, int):
                            raise ValidationError('invalid shape definition')

                    # Only accept the built-in type codes plus the one this
                    # attribute has been configured with.
                    serialize_dtype = val['dtype']
                    allowed = list(restricted_dtype_map_rev.keys())
                    if self.serialize_dtype is not None:
                        allowed.append(self.serialize_dtype)

                    if serialize_dtype not in allowed:
                        # Stringify so that building the message cannot fail
                        # if the configured serialize_dtype is not a str.
                        raise ValidationError(
                            'only the following dtypes are allowed: %s'
                            % ', '.join(sorted(str(x) for x in allowed)))

                    data = val['data']
                    if not isinstance(data, (str, newstr)):
                        raise ValidationError(
                            'data must be given as a base64 encoded string')

                    data = b64decode(data)

                    # Test against None explicitly rather than relying on
                    # the truthiness of the dtype object.
                    if self.dtype is not None:
                        dtype = self.dtype
                    else:
                        dtype = restricted_dtype_map_rev[serialize_dtype]

                    val = num.frombuffer(
                        data, dtype=serialize_dtype).astype(dtype)

                    if val.size != num.prod(shape):
                        raise ValidationError('size/shape mismatch')

                    val = val.reshape(shape)

            else:
                val = num.asarray(val, dtype=self.dtype)

            return val

        def validate_extra(self, val):
            """
            Check that ``val`` is an array of the required dtype and shape.

            :raises ValidationError: if ``val`` is not a
                :py:class:`numpy.ndarray`, if its dtype differs from
                ``self.dtype``, or if its number of dimensions or the size of
                any constrained dimension differs from ``self.shape``.
            """
            if not isinstance(val, num.ndarray):
                raise ValidationError(
                    'object is not of type numpy.ndarray: %s' % type(val))
            if self.dtype is not None and self.dtype != val.dtype:
                raise ValidationError(
                    'array not of required type: need %s, got %s' % (
                        self.dtype, val.dtype))

            if self.shape is not None:
                la, lb = len(self.shape), len(val.shape)
                if la != lb:
                    raise ValidationError(
                        'array dimension mismatch: need %i, got %i' % (
                            la, lb))

                for a, b in zip(self.shape, val.shape):
                    # None entries leave the dimension unconstrained.
                    if a is not None:
                        if a != b:
                            raise ValidationError(
                                'array shape mismatch: need %s, got: %s' % (
                                    self.shape, val.shape))

        def to_save(self, val):
            """
            Convert the array ``val`` into its serializable representation.

            :returns: depending on ``serialize_as`` a ``literal`` string
                (``'table'``, ``'base64'``, ``'base64-compat'``, ``'npy'``),
                a list (``'list'``) or a dict with the keys ``'dtype'``,
                ``'shape'`` and ``'data'`` (``'base64+meta'``).
            """
            if self.serialize_as == 'table':
                out = BytesIO()
                # Values are written with 7 significant digits.
                num.savetxt(out, val, fmt='%12.7g')
                return literal(out.getvalue().decode('utf-8'))
            elif self.serialize_as == 'base64' \
                    or self.serialize_as == 'base64-compat':
                data = val.astype(self.serialize_dtype).tobytes()
                return literal(b64encode(data).decode('utf-8'))
            elif self.serialize_as == 'list':
                if self.dtype == complex:
                    return [repr(x) for x in val]
                else:
                    return val.tolist()
            elif self.serialize_as == 'npy':
                out = BytesIO()
                try:
                    num.save(out, val, allow_pickle=False)
                except TypeError:
                    # allow_pickle only available in newer NumPy
                    num.save(out, val)

                return literal(b64encode(out.getvalue()).decode('utf-8'))

            elif self.serialize_as == 'base64+meta':
                # Without a configured serialize_dtype, only dtypes listed in
                # restricted_dtype_map can be stored.
                serialize_dtype = self.serialize_dtype or \
                    restricted_dtype_map[val.dtype]

                data = val.astype(serialize_dtype).tobytes()

                return dict(
                    dtype=serialize_dtype,
                    shape=val.shape,
                    data=literal(b64encode(data).decode('utf-8')))

209 

210 

# Names exported by ``from <module> import *``.
__all__ = ['Array']