# http://pyrocko.org - GPLv3 # # The Pyrocko Developers, 21st Century # ---|P------/S----------~Lg---------- from __future__ import absolute_import from builtins import zip from builtins import str as newstr
import numpy as num from io import BytesIO from base64 import b64decode, b64encode import binascii
from .guts import TBase, Object, ValidationError, literal
# Make sure the name ``unicode`` is defined: where the builtin is missing
# (NameError), alias it to ``str``.
try:
    unicode
except NameError:
    unicode = str

# NumPy dtype -> explicit little-endian type code. These are the dtypes the
# 'base64+meta' serialization of Array accepts without further configuration.
restricted_dtype_map = dict(
    (num.dtype(type_name), type_code)
    for (type_name, type_code) in [
        ('float64', '<f8'),
        ('float32', '<f4'),
        ('int64', '<i8'),
        ('int32', '<i4'),
        ('int16', '<i2'),
        ('int8', '<i1'),
    ])

# Reverse lookup: type code -> NumPy dtype.
restricted_dtype_map_rev = {
    type_code: dtype for (dtype, type_code) in restricted_dtype_map.items()}
def array_equal(a, b):
    '''
    Check whether two arrays agree in dtype, shape and every element.

    :param a: first array
    :param b: second array
    :returns: ``True`` if ``a`` and ``b`` have the same dtype, the same shape
        and all elements compare equal, ``False`` otherwise. The result is
        always a plain :py:class:`bool`.
    '''
    if a.dtype != b.dtype or a.shape != b.shape:
        return False

    # num.all() yields a numpy.bool_; convert it so that every code path
    # returns the same type.
    return bool(num.all(a == b))
class Array(Object):
    '''
    Placeholder type for :py:class:`numpy.ndarray` attributes.

    The nested type class holds the declared shape and dtype of the array and
    converts between arrays and their serialized representations.
    '''

    # Declares numpy.ndarray as the type this class stands in for.
    dummy_for = num.ndarray

    class __T(TBase):
        '''
        Type descriptor for array-valued attributes.

        :param shape: expected shape, one entry per dimension; ``None``
            entries match any length. ``None`` disables the shape check.
        :param dtype: required dtype of the array, ``None`` to accept any.
        :param serialize_as: serialization format, one of ``'table'``,
            ``'base64'``, ``'list'``, ``'npy'``, ``'base64+meta'`` or
            ``'base64-compat'``.
        :param serialize_dtype: dtype of the binary payload used by the
            base64 based formats.

        Further positional and keyword arguments are handed to ``TBase``.
        '''

        def __init__(
                self,
                shape=None,
                dtype=None,
                serialize_as='table',
                serialize_dtype=None,
                *args, **kwargs):

            TBase.__init__(self, *args, **kwargs)
            self.shape = shape
            self.dtype = dtype
            assert serialize_as in (
                'table', 'base64', 'list', 'npy', 'base64+meta',
                'base64-compat')
            self.serialize_as = serialize_as
            self.serialize_dtype = serialize_dtype

        def is_default(self, val):
            '''
            Check whether ``val`` equals the default of this attribute.

            ``self._default_cmp`` is not set in this class; presumably it is
            provided by ``TBase``.
            '''
            if self._default_cmp is None:
                return val is None
            elif val is None:
                return False
            else:
                return array_equal(self._default_cmp, val)

        def regularize_extra(self, val):
            '''
            Convert a serialized representation into a NumPy array.

            Strings are decoded according to ``serialize_as``, dictionaries
            are decoded when ``serialize_as`` is ``'base64+meta'`` and any
            other value is passed through :py:func:`numpy.asarray`.

            :param val: string, dict or array-like object
            :returns: the converted value
            :raises ValidationError: if a ``'base64+meta'`` dictionary is
                malformed
            '''
            if isinstance(val, (str, newstr)):
                # loadtxt() accepts only 0, 1 or 2 for ndmin, so use its
                # default of 0 when no shape has been declared.
                ndim = len(self.shape) if self.shape else 0

                if self.serialize_as == 'table':
                    val = num.loadtxt(
                        BytesIO(val.encode('utf-8')),
                        dtype=self.dtype, ndmin=ndim)

                elif self.serialize_as == 'base64':
                    data = b64decode(val)
                    # frombuffer() gives a read-only view on the bytes;
                    # astype() turns it into an independent array.
                    val = num.frombuffer(
                        data, dtype=self.serialize_dtype).astype(self.dtype)

                elif self.serialize_as == 'base64-compat':
                    try:
                        data = b64decode(val)
                        val = num.frombuffer(
                            data,
                            dtype=self.serialize_dtype).astype(self.dtype)
                    except binascii.Error:
                        # Not decodable as base64: read it as a text table.
                        val = num.loadtxt(
                            BytesIO(val.encode('utf-8')),
                            dtype=self.dtype, ndmin=ndim)

                elif self.serialize_as == 'npy':
                    data = b64decode(val)
                    try:
                        val = num.load(BytesIO(data), allow_pickle=False)
                    except TypeError:
                        # allow_pickle only available in newer NumPy
                        # NOTE(review): this fallback loads with NumPy's
                        # default pickle handling - unsafe on untrusted
                        # input with old NumPy versions.
                        val = num.load(BytesIO(data))

            elif isinstance(val, dict):
                if self.serialize_as == 'base64+meta':
                    if sorted(val.keys()) != ['data', 'dtype', 'shape']:
                        raise ValidationError(
                            'array in format "base64+meta" must have keys '
                            '"data", "dtype", and "shape"')

                    # Tuples are accepted as well, because to_save() emits
                    # the array's shape tuple unchanged.
                    shape = val['shape']
                    if not isinstance(shape, (list, tuple)):
                        raise ValidationError('invalid shape definition')

                    for n in shape:
                        if not isinstance(n, int):
                            raise ValidationError('invalid shape definition')

                    serialize_dtype = val['dtype']
                    allowed = list(restricted_dtype_map_rev.keys())
                    if self.serialize_dtype is not None:
                        allowed.append(self.serialize_dtype)

                    if serialize_dtype not in allowed:
                        raise ValidationError(
                            'only the following dtypes are allowed: %s'
                            % ', '.join(sorted(allowed)))

                    data = val['data']
                    if not isinstance(data, (str, newstr)):
                        raise ValidationError(
                            'data must be given as a base64 encoded string')

                    data = b64decode(data)

                    dtype = self.dtype or \
                        restricted_dtype_map_rev[serialize_dtype]

                    val = num.frombuffer(
                        data, dtype=serialize_dtype).astype(dtype)

                    if val.size != num.prod(shape):
                        raise ValidationError('size/shape mismatch')

                    val = val.reshape(shape)

            else:
                val = num.asarray(val, dtype=self.dtype)

            return val

        def validate_extra(self, val):
            '''
            Check that ``val`` is an array of the declared dtype and shape.

            :raises ValidationError: if ``val`` is not a
                :py:class:`numpy.ndarray`, or its dtype, number of dimensions
                or shape differ from the declaration
            '''
            if not isinstance(val, num.ndarray):
                raise ValidationError(
                    'object is not of type numpy.ndarray: %s' % type(val))

            if self.dtype is not None and self.dtype != val.dtype:
                raise ValidationError(
                    'array not of required type: need %s, got %s' % (
                        self.dtype, val.dtype))

            if self.shape is not None:
                la, lb = len(self.shape), len(val.shape)
                if la != lb:
                    raise ValidationError(
                        'array dimension mismatch: need %i, got %i' % (
                            la, lb))

                # A None entry in the declared shape matches any length.
                for a, b in zip(self.shape, val.shape):
                    if a is not None and a != b:
                        raise ValidationError(
                            'array shape mismatch: need %s, got: %s' % (
                                self.shape, val.shape))

        def to_save(self, val):
            '''
            Convert an array into its serialized representation.

            :param val: array to be serialized
            :returns: depending on ``serialize_as``, a ``literal`` string
                (``'table'``, ``'base64'``, ``'base64-compat'``, ``'npy'``),
                a list (``'list'``) or a dict with the keys ``dtype``,
                ``shape`` and ``data`` (``'base64+meta'``)
            '''
            if self.serialize_as == 'table':
                out = BytesIO()
                num.savetxt(out, val, fmt='%12.7g')
                return literal(out.getvalue().decode('utf-8'))

            elif self.serialize_as in ('base64', 'base64-compat'):
                data = val.astype(self.serialize_dtype).tobytes()
                return literal(b64encode(data).decode('utf-8'))

            elif self.serialize_as == 'list':
                if self.dtype == complex:
                    # Complex values are stored via their repr() strings.
                    return [repr(x) for x in val]
                else:
                    return val.tolist()

            elif self.serialize_as == 'npy':
                out = BytesIO()
                try:
                    num.save(out, val, allow_pickle=False)
                except TypeError:
                    # allow_pickle only available in newer NumPy
                    num.save(out, val)

                return literal(b64encode(out.getvalue()).decode('utf-8'))

            elif self.serialize_as == 'base64+meta':
                # Without an explicit serialize_dtype, only the dtypes
                # listed in restricted_dtype_map can be stored.
                serialize_dtype = self.serialize_dtype or \
                    restricted_dtype_map[val.dtype]

                data = val.astype(serialize_dtype).tobytes()

                return dict(
                    dtype=serialize_dtype,
                    shape=val.shape,
                    data=literal(b64encode(data).decode('utf-8')))
# Names exported by ``from <module> import *``.
__all__ = ['Array']