Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/guts_array.py: 93%
107 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-06 15:01 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-06 15:01 +0000
1# http://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6'''
7NumPy support for :py:mod:`pyrocko.guts`.
8'''
10import numpy as num
11from io import BytesIO
12from base64 import b64decode, b64encode
13import binascii
15from .guts import TBase, Object, ValidationError, literal
# Python 2/3 compatibility alias: on interpreters without a builtin
# ``unicode`` the name is bound to ``str``.
try:
    unicode
except NameError:
    unicode = str
# Maps the NumPy dtypes which may be serialized without an explicitly
# configured ``serialize_dtype`` to their little-endian type codes.
restricted_dtype_map = {
    num.dtype('float64'): '<f8',
    num.dtype('float32'): '<f4',
    num.dtype('int64'): '<i8',
    num.dtype('int32'): '<i4',
    num.dtype('int16'): '<i2',
    num.dtype('int8'): '<i1'}

# Reverse lookup: type code -> NumPy dtype.
restricted_dtype_map_rev = {
    code: dtype for (dtype, code) in restricted_dtype_map.items()}
def array_equal(a, b):
    '''
    Check whether two arrays agree in dtype, shape and all elements.

    :param a: first array, must provide ``dtype`` and ``shape`` attributes
    :param b: second array, must provide ``dtype`` and ``shape`` attributes
    :returns: ``True`` if dtype, shape and every element are equal, else
        ``False``
    '''
    # The element-wise comparison is only evaluated once dtype and shape are
    # known to agree. The result is coerced so that callers always get a plain
    # ``bool`` rather than a mixture of ``bool`` and ``numpy.bool_``.
    return bool(
        a.dtype == b.dtype
        and a.shape == b.shape
        and num.all(a == b))
class Array(Object):
    '''
    Placeholder for :py:class:`numpy.ndarray`

    Normally, no objects of this class should be instantiated. It is needed
    by Guts' type system.
    '''

    dummy_for = num.ndarray
    dummy_for_description = 'numpy.ndarray'

    class __T(TBase):
        # Type descriptor holding the shape/dtype constraints and the
        # serialization settings of an array-valued attribute.

        def __init__(
                self,
                shape=None,
                dtype=None,
                serialize_as='table',
                serialize_dtype=None,
                *args, **kwargs):
            '''
            Set up the array type descriptor.

            :param shape: required shape, one entry per dimension. Entries
                which are ``None`` leave the length of that dimension
                unconstrained. ``None`` disables shape checking altogether.
            :param dtype: required dtype of the array, ``None`` to accept any
            :param serialize_as: serialization format, one of ``'table'``,
                ``'base64'``, ``'list'``, ``'npy'``, ``'base64+meta'``,
                ``'base64-compat'``
            :param serialize_dtype: dtype used for the binary payload of the
                base64 based formats

            Remaining arguments are passed on to ``TBase.__init__``.
            '''

            TBase.__init__(self, *args, **kwargs)
            self.shape = shape
            self.dtype = dtype
            # Kept as an assertion so that the exception type seen by
            # existing callers does not change.
            assert serialize_as in (
                'table', 'base64', 'list', 'npy',
                'base64+meta', 'base64-compat')
            self.serialize_as = serialize_as
            self.serialize_dtype = serialize_dtype

        def is_default(self, val):
            '''
            Check whether ``val`` equals the default stored in
            ``self._default``.

            A ``None`` default only matches ``None``. Otherwise dtype, shape
            and all elements must agree (see :py:func:`array_equal`).
            '''
            if self._default is None:
                return val is None
            elif val is None:
                return False
            else:
                return array_equal(self._default, val)

        def regularize_extra(self, val):
            '''
            Convert a serialized representation into a NumPy array.

            Strings are decoded according to ``self.serialize_as`` (formats
            ``'table'``, ``'base64'``, ``'base64-compat'`` and ``'npy'``).
            Dicts are decoded when the format is ``'base64+meta'``. Strings
            and dicts which do not match the configured format are returned
            unchanged. Any other value is passed through
            :py:func:`numpy.asarray` with the configured dtype.

            :raises ValidationError: if a ``'base64+meta'`` dict is malformed
            '''
            if isinstance(val, str):
                ndim = None
                if self.shape:
                    ndim = len(self.shape)

                if self.serialize_as == 'table':
                    val = num.loadtxt(
                        BytesIO(val.encode('utf-8')),
                        dtype=self.dtype, ndmin=ndim)

                elif self.serialize_as == 'base64':
                    data = b64decode(val)
                    val = num.frombuffer(
                        data, dtype=self.serialize_dtype).astype(self.dtype)

                elif self.serialize_as == 'base64-compat':
                    # Try base64 first and fall back to the text table format
                    # if the string is not decodable as base64.
                    try:
                        data = b64decode(val)
                        val = num.frombuffer(
                            data,
                            dtype=self.serialize_dtype).astype(self.dtype)
                    except binascii.Error:
                        val = num.loadtxt(
                            BytesIO(val.encode('utf-8')),
                            dtype=self.dtype, ndmin=ndim)

                elif self.serialize_as == 'npy':
                    data = b64decode(val)
                    try:
                        val = num.load(BytesIO(data), allow_pickle=False)
                    except TypeError:
                        # allow_pickle only available in newer NumPy
                        # NOTE(review): this fallback cannot disable pickle
                        # support explicitly - confirm it is acceptable for
                        # untrusted input on such NumPy versions.
                        val = num.load(BytesIO(data))

            elif isinstance(val, dict):
                if self.serialize_as == 'base64+meta':
                    if not sorted(val.keys()) == ['data', 'dtype', 'shape']:
                        raise ValidationError(
                            'array in format "base64+meta" must have keys '
                            '"data", "dtype", and "shape"')

                    shape = val['shape']
                    if not isinstance(shape, list):
                        raise ValidationError('invalid shape definition')

                    for n in shape:
                        if not isinstance(n, int):
                            raise ValidationError('invalid shape definition')

                    # Only the restricted set of type codes and, if
                    # configured, this type's own serialize_dtype are
                    # accepted from the input.
                    serialize_dtype = val['dtype']
                    allowed = list(restricted_dtype_map_rev.keys())
                    if self.serialize_dtype is not None:
                        allowed.append(self.serialize_dtype)

                    if serialize_dtype not in allowed:
                        # Entries are stringified so that building the
                        # message cannot fail when serialize_dtype is not a
                        # plain string.
                        raise ValidationError(
                            'only the following dtypes are allowed: %s'
                            % ', '.join(sorted(str(x) for x in allowed)))

                    data = val['data']
                    if not isinstance(data, str):
                        raise ValidationError(
                            'data must be given as a base64 encoded string')

                    data = b64decode(data)

                    # Explicit None check: a NumPy dtype instance can
                    # evaluate as false (non-structured dtypes report a
                    # length of zero), so ``self.dtype or ...`` could
                    # silently ignore a configured dtype.
                    if self.dtype is not None:
                        dtype = self.dtype
                    else:
                        # A custom serialize_dtype which is not part of the
                        # restricted map is kept as is.
                        dtype = restricted_dtype_map_rev.get(
                            serialize_dtype, serialize_dtype)

                    val = num.frombuffer(
                        data, dtype=serialize_dtype).astype(dtype)

                    if val.size != num.prod(shape):
                        raise ValidationError('size/shape mismatch')

                    val = val.reshape(shape)

            else:
                val = num.asarray(val, dtype=self.dtype)

            return val

        def validate_extra(self, val):
            '''
            Check that ``val`` is an array matching the configured
            constraints.

            :raises ValidationError: if ``val`` is not a
                :py:class:`numpy.ndarray`, if its dtype differs from the
                required one, or if its number of dimensions or shape do not
                match ``self.shape``
            '''
            if not isinstance(val, num.ndarray):
                raise ValidationError(
                    'object %s is not of type numpy.ndarray: %s' % (
                        self.xname(), type(val)))
            if self.dtype is not None and self.dtype != val.dtype:
                raise ValidationError(
                    'array %s not of required type: need %s, got %s' % (
                        self.xname(), self.dtype, val.dtype))

            if self.shape is not None:
                la, lb = len(self.shape), len(val.shape)
                if la != lb:
                    raise ValidationError(
                        'array %s dimension mismatch: need %i, got %i' % (
                            self.xname(), la, lb))

                for a, b in zip(self.shape, val.shape):
                    # None entries in the required shape match any length.
                    if a is not None:
                        if a != b:
                            raise ValidationError(
                                'array %s shape mismatch: need %s, got: %s' % (
                                    self.xname(), self.shape, val.shape))

        def to_save(self, val):
            '''
            Convert array ``val`` into its serializable representation.

            The result depends on ``self.serialize_as``: a text table, a
            base64 encoded string of the raw bytes or of the ``.npy`` file
            contents, a (nested) list, or a dict with the keys ``dtype``,
            ``shape`` and ``data`` for ``'base64+meta'``.
            '''
            if self.serialize_as == 'table':
                out = BytesIO()
                num.savetxt(out, val, fmt='%12.7g')
                return literal(out.getvalue().decode('utf-8'))
            elif self.serialize_as in ('base64', 'base64-compat'):
                data = val.astype(self.serialize_dtype).tobytes()
                return literal(b64encode(data).decode('utf-8'))
            elif self.serialize_as == 'list':
                if self.dtype == complex:
                    # Complex values are stored through their repr().
                    return [repr(x) for x in val]
                else:
                    return val.tolist()
            elif self.serialize_as == 'npy':
                out = BytesIO()
                try:
                    num.save(out, val, allow_pickle=False)
                except TypeError:
                    # allow_pickle only available in newer NumPy
                    num.save(out, val)

                return literal(b64encode(out.getvalue()).decode('utf-8'))

            elif self.serialize_as == 'base64+meta':
                # Without a configured serialize_dtype, the type code is
                # looked up from the array's own dtype.
                serialize_dtype = self.serialize_dtype or \
                    restricted_dtype_map[val.dtype]

                data = val.astype(serialize_dtype).tobytes()

                return dict(
                    dtype=serialize_dtype,
                    shape=val.shape,
                    data=literal(b64encode(data).decode('utf-8')))
# Public API of this module.
__all__ = ['Array']