1# http://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
5from __future__ import absolute_import
7import numpy as num
8from io import BytesIO
9from base64 import b64decode, b64encode
10import binascii
12from .guts import TBase, Object, ValidationError, literal, newstr
# Python 2/3 compatibility: ``unicode`` is the text type on both versions
# (the builtin ``unicode`` on Python 2, ``str`` on Python 3).
unicode = type(u'')
# Numeric dtypes allowed in the 'base64+meta' format, mapped to the explicit
# little-endian ('<') type codes written to / read from serialized data.
restricted_dtype_map = dict(
    (num.dtype(type_name), type_code) for (type_name, type_code) in [
        ('float64', '<f8'),
        ('float32', '<f4'),
        ('int64', '<i8'),
        ('int32', '<i4'),
        ('int16', '<i2'),
        ('int8', '<i1'),
    ])

# Inverse lookup: serialized type code -> numpy dtype.
restricted_dtype_map_rev = {
    type_code: dt for (dt, type_code) in restricted_dtype_map.items()}
def array_equal(a, b):
    '''
    Check whether two arrays agree in dtype, shape and every element.

    Dtype and shape are compared first, so arrays of different shape are
    never broadcast against each other. Elements are compared with ``==``,
    hence arrays containing NaN never compare equal.
    '''
    if a.dtype != b.dtype or a.shape != b.shape:
        return False

    return num.all(a == b)
class Array(Object):
    '''
    Guts type for attributes holding :py:class:`numpy.ndarray` values.

    The nested type class validates dtype/shape and converts arrays to and
    from one of several serialization formats selected by ``serialize_as``.
    '''

    # Presumably tells the guts machinery that attribute values are plain
    # numpy arrays rather than Array instances -- confirm against guts.
    dummy_for = num.ndarray

    class __T(TBase):
        def __init__(
                self,
                shape=None,
                dtype=None,
                serialize_as='table',
                serialize_dtype=None,
                *args, **kwargs):
            '''
            :param shape: required array shape; ``None`` entries accept any
                extent along that axis, ``None`` as a whole disables the
                shape check.
            :param dtype: required dtype; ``None`` disables the dtype check.
            :param serialize_as: one of ``'table'``, ``'base64'``,
                ``'list'``, ``'npy'``, ``'base64+meta'``,
                ``'base64-compat'``.
            :param serialize_dtype: dtype (code string) of the raw bytes in
                the base64 based formats.
            '''
            TBase.__init__(self, *args, **kwargs)
            self.shape = shape
            self.dtype = dtype
            assert serialize_as in (
                'table', 'base64', 'list', 'npy',
                'base64+meta', 'base64-compat')
            self.serialize_as = serialize_as
            self.serialize_dtype = serialize_dtype

        def is_default(self, val):
            '''Return whether *val* equals this attribute's default.'''
            if self._default is None:
                return val is None
            elif val is None:
                return False
            else:
                return array_equal(self._default, val)

        def regularize_extra(self, val):
            '''
            Convert a deserialized value into a :py:class:`numpy.ndarray`.

            Strings are decoded according to ``serialize_as``, dicts are
            decoded when ``serialize_as`` is ``'base64+meta'`` and values
            which are neither are passed through :py:func:`numpy.asarray`.
            Strings/dicts that do not match the configured format are
            returned unchanged (presumably rejected later by
            :py:meth:`validate_extra`).

            :raises ValidationError: on malformed ``'base64+meta'`` input.
            '''
            if isinstance(val, (str, newstr)):
                if self.serialize_as == 'table':
                    val = self._array_from_table(val)

                elif self.serialize_as == 'base64':
                    data = b64decode(val)
                    # frombuffer returns a read-only view; astype copies it.
                    val = num.frombuffer(
                        data, dtype=self.serialize_dtype).astype(self.dtype)

                elif self.serialize_as == 'base64-compat':
                    try:
                        data = b64decode(val)
                        val = num.frombuffer(
                            data,
                            dtype=self.serialize_dtype).astype(self.dtype)
                    except (binascii.Error, ValueError):
                        # Not base64, or not a whole number of elements:
                        # fall back to the legacy 'table' format.
                        val = self._array_from_table(val)

                elif self.serialize_as == 'npy':
                    data = b64decode(val)
                    try:
                        val = num.load(BytesIO(data), allow_pickle=False)
                    except TypeError:
                        # allow_pickle only available in newer NumPy. NOTE:
                        # very old NumPy then allows pickles -- unsafe for
                        # untrusted input.
                        val = num.load(BytesIO(data))

            elif isinstance(val, dict):
                if self.serialize_as == 'base64+meta':
                    val = self._array_from_base64_meta(val)

            else:
                val = num.asarray(val, dtype=self.dtype)

            return val

        def _array_from_table(self, text):
            '''Parse text in the whitespace separated 'table' format.'''
            # numpy.loadtxt only accepts ndmin in (0, 1, 2); None (no shape
            # declared) or a rank above 2 would raise ValueError.
            if self.shape is None:
                ndmin = 0
            else:
                ndmin = min(len(self.shape), 2)

            return num.loadtxt(
                BytesIO(text.encode('utf-8')),
                dtype=self.dtype, ndmin=ndmin)

        def _array_from_base64_meta(self, val):
            '''
            Decode a ``{'data', 'dtype', 'shape'}`` dict into an array.

            :raises ValidationError: if keys, shape, dtype or data are
                invalid or inconsistent.
            '''
            if set(val.keys()) != {'data', 'dtype', 'shape'}:
                raise ValidationError(
                    'array in format "base64+meta" must have keys '
                    '"data", "dtype", and "shape"')

            shape = val['shape']
            if not isinstance(shape, (list, tuple)):
                raise ValidationError('invalid shape definition')

            for n in shape:
                if not isinstance(n, int) or n < 0:
                    raise ValidationError('invalid shape definition')

            serialize_dtype = val['dtype']
            allowed = list(restricted_dtype_map_rev.keys())
            if self.serialize_dtype is not None:
                allowed.append(self.serialize_dtype)

            if serialize_dtype not in allowed:
                raise ValidationError(
                    'only the following dtypes are allowed: %s'
                    % ', '.join(sorted(str(x) for x in allowed)))

            data = val['data']
            if not isinstance(data, (str, newstr)):
                raise ValidationError(
                    'data must be given as a base64 encoded string')

            try:
                data = b64decode(data)
            except ValueError:  # includes binascii.Error
                raise ValidationError(
                    'data must be given as a base64 encoded string')

            # Explicit None test: plain numpy dtype objects are falsy.
            if self.dtype is not None:
                dtype = self.dtype
            else:
                try:
                    dtype = restricted_dtype_map_rev[serialize_dtype]
                except KeyError:
                    # custom self.serialize_dtype outside the restricted map
                    dtype = num.dtype(serialize_dtype)

            try:
                arr = num.frombuffer(data, dtype=serialize_dtype)
            except ValueError:
                raise ValidationError(
                    'data size is not a multiple of the size of dtype %s'
                    % serialize_dtype)

            arr = arr.astype(dtype)

            if arr.size != num.prod(shape):
                raise ValidationError('size/shape mismatch')

            return arr.reshape(shape)

        def validate_extra(self, val):
            '''
            Check that *val* is an ndarray matching the declared dtype and
            shape.

            :raises ValidationError: on type, dtype, rank or extent mismatch.
            '''
            if not isinstance(val, num.ndarray):
                raise ValidationError(
                    'object is not of type numpy.ndarray: %s' % type(val))
            if self.dtype is not None and self.dtype != val.dtype:
                raise ValidationError(
                    'array not of required type: need %s, got %s' % (
                        self.dtype, val.dtype))

            if self.shape is not None:
                la, lb = len(self.shape), len(val.shape)
                if la != lb:
                    raise ValidationError(
                        'array dimension mismatch: need %i, got %i' % (
                            la, lb))

                for a, b in zip(self.shape, val.shape):
                    # None in the declared shape accepts any extent.
                    if a is not None and a != b:
                        raise ValidationError(
                            'array shape mismatch: need %s, got: %s' % (
                                self.shape, val.shape))

        def to_save(self, val):
            '''
            Serialize *val* according to ``serialize_as``.

            Returns a ``literal`` string for 'table', 'base64',
            'base64-compat' and 'npy', a list for 'list', and a dict with
            ``dtype``, ``shape`` and ``data`` keys for 'base64+meta'.

            :raises KeyError: in 'base64+meta' mode when no
                ``serialize_dtype`` is set and ``val.dtype`` is not in
                ``restricted_dtype_map``.
            '''
            if self.serialize_as == 'table':
                out = BytesIO()
                # '%12.7g' keeps 7 significant digits only (lossy for
                # float64).
                num.savetxt(out, val, fmt='%12.7g')
                return literal(out.getvalue().decode('utf-8'))

            elif self.serialize_as in ('base64', 'base64-compat'):
                data = val.astype(self.serialize_dtype).tobytes()
                return literal(b64encode(data).decode('utf-8'))

            elif self.serialize_as == 'list':
                if self.dtype == complex:
                    # complex values are stored as their repr strings
                    return [repr(x) for x in val]
                else:
                    return val.tolist()

            elif self.serialize_as == 'npy':
                out = BytesIO()
                try:
                    num.save(out, val, allow_pickle=False)
                except TypeError:
                    # allow_pickle only available in newer NumPy
                    num.save(out, val)

                return literal(b64encode(out.getvalue()).decode('utf-8'))

            elif self.serialize_as == 'base64+meta':
                if self.serialize_dtype is not None:
                    serialize_dtype = self.serialize_dtype
                else:
                    serialize_dtype = restricted_dtype_map[val.dtype]

                data = val.astype(serialize_dtype).tobytes()

                # Shape as a list: the decoder requires a list/sequence and
                # lists serialize portably.
                return dict(
                    dtype=serialize_dtype,
                    shape=list(val.shape),
                    data=literal(b64encode(data).decode('utf-8')))
# Public API of this module: only the Array guts type is exported by
# ``from ... import *``.
__all__ = ['Array']