1# http://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6import numpy as num
7from io import BytesIO
8from base64 import b64decode, b64encode
9import binascii
11from .guts import TBase, Object, ValidationError, literal
# Python 2/3 compatibility shim: Python 3 has no builtin ``unicode``, so the
# bare lookup below raises NameError there and the name is bound to ``str``.
# NOTE(review): ``unicode`` is not referenced elsewhere in the code visible
# in this module; it may be kept for importers -- confirm before removing.
try:
    unicode
except NameError:
    unicode = str
# Dtypes accepted by the self-describing 'base64+meta' serialization,
# mapped to explicit little-endian type codes. Converting to these codes
# before writing fixes the byte order of the stored data.
restricted_dtype_map = dict(
    (num.dtype(native_name), code) for (native_name, code) in [
        ('float64', '<f8'),
        ('float32', '<f4'),
        ('int64', '<i8'),
        ('int32', '<i4'),
        ('int16', '<i2'),
        ('int8', '<i1')])

# Inverse lookup (type code -> dtype), used when reading 'base64+meta'.
restricted_dtype_map_rev = {
    code: dtype for (dtype, code) in restricted_dtype_map.items()}
def array_equal(a, b):
    '''
    Strict array comparison: same dtype, same shape and equal elements.

    Unlike a plain element-wise comparison, arrays that differ only in
    dtype or shape are reported as unequal.
    '''
    if a.dtype != b.dtype or a.shape != b.shape:
        return False

    return num.all(a == b)
class Array(Object):
    '''
    Guts placeholder for :py:class:`numpy.ndarray` valued attributes.

    The nested type class accepts an optional ``shape`` constraint (``None``
    entries match any extent), an optional required ``dtype`` and a
    ``serialize_as`` format:

    * ``'table'`` -- whitespace separated text, written with
      :py:func:`numpy.savetxt` (format ``'%12.7g'``) and read back with
      :py:func:`numpy.loadtxt`.
    * ``'base64'`` -- the raw bytes of the array converted to
      ``serialize_dtype``, base64 encoded. The shape is not stored; decoding
      yields a 1-D array.
    * ``'base64-compat'`` -- written like ``'base64'``; on reading, text that
      cannot be decoded as base64 array data is parsed as a table.
    * ``'list'`` -- nested Python lists (a list of ``repr`` strings when
      ``dtype`` is ``complex``).
    * ``'npy'`` -- the content of a NumPy ``.npy`` file, base64 encoded.
    * ``'base64+meta'`` -- a dict with keys ``data``, ``dtype`` and
      ``shape``. ``dtype`` must be one of the codes in
      ``restricted_dtype_map`` or the type's ``serialize_dtype``.
    '''

    dummy_for = num.ndarray

    class __T(TBase):
        def __init__(
                self,
                shape=None,
                dtype=None,
                serialize_as='table',
                serialize_dtype=None,
                *args, **kwargs):

            # shape: required shape, ``None`` entries are wildcards
            #   (checked in validate_extra).
            # dtype: required array dtype, or ``None`` for any.
            # serialize_as: one of the formats listed on the class.
            # serialize_dtype: dtype of the bytes in the base64 formats.
            TBase.__init__(self, *args, **kwargs)
            self.shape = shape
            self.dtype = dtype
            assert serialize_as in (
                'table', 'base64', 'list', 'npy',
                'base64+meta', 'base64-compat')
            self.serialize_as = serialize_as
            self.serialize_dtype = serialize_dtype

        def is_default(self, val):
            # An array equals the default only if dtype, shape and all
            # elements match (see array_equal).
            if self._default is None:
                return val is None
            elif val is None:
                return False
            else:
                return array_equal(self._default, val)

        def regularize_extra(self, val):
            # Convert a deserialized value back into a numpy array.
            # Strings are decoded according to ``serialize_as``, dicts are
            # decoded for 'base64+meta', and anything else goes through
            # num.asarray. Combinations matching no branch are returned
            # unchanged and rejected later by validate_extra.
            if isinstance(val, str):
                # num.loadtxt only accepts ndmin in (0, 1, 2): passing None
                # (no shape constraint) or a larger rank would raise.
                ndmin = 0
                if self.shape:
                    ndmin = min(len(self.shape), 2)

                if self.serialize_as == 'table':
                    val = num.loadtxt(
                        BytesIO(val.encode('utf-8')),
                        dtype=self.dtype, ndmin=ndmin)

                elif self.serialize_as == 'base64':
                    data = b64decode(val)
                    val = num.frombuffer(
                        data, dtype=self.serialize_dtype).astype(self.dtype)

                elif self.serialize_as == 'base64-compat':
                    try:
                        data = b64decode(val)
                        val = num.frombuffer(
                            data,
                            dtype=self.serialize_dtype).astype(self.dtype)
                    except (binascii.Error, ValueError):
                        # Not base64 array data. b64decode raises
                        # binascii.Error on bad input, but it silently drops
                        # characters like '.' and ' ', so table text may
                        # "decode" to a byte count frombuffer rejects with
                        # ValueError. Fall back to the table format.
                        val = num.loadtxt(
                            BytesIO(val.encode('utf-8')),
                            dtype=self.dtype, ndmin=ndmin)

                elif self.serialize_as == 'npy':
                    data = b64decode(val)
                    try:
                        val = num.load(BytesIO(data), allow_pickle=False)
                    except TypeError:
                        # allow_pickle only available in newer NumPy
                        val = num.load(BytesIO(data))

            elif isinstance(val, dict):
                if self.serialize_as == 'base64+meta':
                    if set(val.keys()) != {'data', 'dtype', 'shape'}:
                        raise ValidationError(
                            'array in format "base64+meta" must have keys '
                            '"data", "dtype", and "shape"')

                    shape = val['shape']
                    if not isinstance(shape, list):
                        raise ValidationError('invalid shape definition')

                    for n in shape:
                        # bool is an int subclass; negative extents could
                        # pass the size check below (e.g. [-1, -2] -> 2).
                        if isinstance(n, bool) or not isinstance(n, int) \
                                or n < 0:
                            raise ValidationError('invalid shape definition')

                    serialize_dtype = val['dtype']
                    allowed = list(restricted_dtype_map_rev.keys())
                    if self.serialize_dtype is not None:
                        allowed.append(self.serialize_dtype)

                    if serialize_dtype not in allowed:
                        raise ValidationError(
                            'only the following dtypes are allowed: %s'
                            % ', '.join(sorted(str(x) for x in allowed)))

                    data = val['data']
                    if not isinstance(data, str):
                        raise ValidationError(
                            'data must be given as a base64 encoded string')

                    # Explicit None checks: numpy dtype objects can be
                    # falsy, so ``self.dtype or ...`` is unreliable.
                    if self.dtype is not None:
                        dtype = self.dtype
                    else:
                        # A custom serialize_dtype need not be one of the
                        # restricted codes; let numpy interpret it then.
                        dtype = restricted_dtype_map_rev.get(serialize_dtype)
                        if dtype is None:
                            dtype = num.dtype(serialize_dtype)

                    try:
                        data = b64decode(data)
                        val = num.frombuffer(
                            data, dtype=serialize_dtype).astype(dtype)
                    except (binascii.Error, ValueError) as e:
                        raise ValidationError(
                            'invalid array data: %s' % e)

                    # num.prod (num.product was removed in NumPy 2.0)
                    if val.size != num.prod(shape):
                        raise ValidationError('size/shape mismatch')

                    val = val.reshape(shape)

            else:
                val = num.asarray(val, dtype=self.dtype)

            return val

        def validate_extra(self, val):
            # Enforce type, dtype, rank and (non-wildcard) extents.
            if not isinstance(val, num.ndarray):
                raise ValidationError(
                    'object %s is not of type numpy.ndarray: %s' % (
                        self.xname(), type(val)))
            if self.dtype is not None and self.dtype != val.dtype:
                raise ValidationError(
                    'array %s not of required type: need %s, got %s' % (
                        self.xname(), self.dtype, val.dtype))

            if self.shape is not None:
                la, lb = len(self.shape), len(val.shape)
                if la != lb:
                    raise ValidationError(
                        'array %s dimension mismatch: need %i, got %i' % (
                            self.xname(), la, lb))

                for a, b in zip(self.shape, val.shape):
                    if a is not None:
                        if a != b:
                            raise ValidationError(
                                'array %s shape mismatch: need %s, got: %s' % (
                                    self.xname(), self.shape, val.shape))

        def to_save(self, val):
            # Produce the serialized representation selected by
            # ``serialize_as``; the inverse of regularize_extra.
            if self.serialize_as == 'table':
                out = BytesIO()
                num.savetxt(out, val, fmt='%12.7g')
                return literal(out.getvalue().decode('utf-8'))
            elif self.serialize_as in ('base64', 'base64-compat'):
                data = val.astype(self.serialize_dtype).tobytes()
                return literal(b64encode(data).decode('utf-8'))
            elif self.serialize_as == 'list':
                if self.dtype == complex:
                    return [repr(x) for x in val]
                else:
                    return val.tolist()
            elif self.serialize_as == 'npy':
                out = BytesIO()
                try:
                    num.save(out, val, allow_pickle=False)
                except TypeError:
                    # allow_pickle only available in newer NumPy
                    num.save(out, val)

                return literal(b64encode(out.getvalue()).decode('utf-8'))

            elif self.serialize_as == 'base64+meta':
                if self.serialize_dtype is not None:
                    serialize_dtype = self.serialize_dtype
                else:
                    # KeyError for dtypes outside restricted_dtype_map
                    serialize_dtype = restricted_dtype_map[val.dtype]

                data = val.astype(serialize_dtype).tobytes()

                # shape must be a list: regularize_extra rejects tuples.
                return dict(
                    dtype=serialize_dtype,
                    shape=list(val.shape),
                    data=literal(b64encode(data).decode('utf-8')))
# Only Array is exported by ``from ... import *``; the dtype maps and
# array_equal remain importable by explicit name.
__all__ = ['Array']