# http://pyrocko.org - GPLv3
#
# The Pyrocko Developers, 21st Century
# ---|P------/S----------~Lg----------
from __future__ import absolute_import
from builtins import zip
from builtins import str as newstr
import numpy as num
from io import BytesIO
from base64 import b64decode, b64encode
import binascii
from .guts import TBase, Object, ValidationError, literal
# ``unicode`` is only a builtin on Python 2; on Python 3 alias it to ``str``
# so the name is usable under both interpreters.
try:
    unicode
except NameError:
    unicode = str

# Native dtypes permitted for the 'base64+meta' serialization, mapped to the
# explicit little-endian type strings used in the serialized form, together
# with the inverse lookup (type string -> numpy.dtype).
restricted_dtype_map = dict(
    (num.dtype(native), encoded) for (native, encoded) in (
        ('float64', '<f8'),
        ('float32', '<f4'),
        ('int64', '<i8'),
        ('int32', '<i4'),
        ('int16', '<i2'),
        ('int8', '<i1')))

restricted_dtype_map_rev = {
    encoded: native for (native, encoded) in restricted_dtype_map.items()}
def array_equal(a, b):
    '''
    Check whether two arrays agree in dtype, shape and content.

    NaN entries at matching positions are treated as equal, so an array
    holding NaNs compares equal to a copy of itself. For complex arrays an
    element counts as NaN if either component is NaN, as with
    :py:func:`numpy.array_equal` and ``equal_nan=True``.

    :param a: first array
    :param b: second array
    :returns: ``True`` if dtype, shape and all elements match, else ``False``
    '''
    if a.dtype != b.dtype or a.shape != b.shape:
        return False

    same = a == b
    if a.dtype.kind in 'fc':
        # NaN != NaN: without this, defaults containing NaN would never be
        # recognized as equal to an identical value.
        same = same | (num.isnan(a) & num.isnan(b))

    # Always hand back a plain Python bool rather than numpy.bool_.
    return bool(num.all(same))
class Array(Object):
    '''
    Guts type for :py:class:`numpy.ndarray` attributes.

    Regularization, validation and serialization are implemented by the
    nested ``__T`` type. Supported ``serialize_as`` formats:

    * ``'table'``: text written by :py:func:`numpy.savetxt`, read back with
      :py:func:`numpy.loadtxt`.
    * ``'base64'``: base64 of the raw array bytes in ``serialize_dtype``.
    * ``'base64-compat'``: like ``'base64'``, but on load falls back to table
      parsing when the text cannot be decoded as base64 array data.
    * ``'list'``: nested Python lists (``repr`` strings for complex dtype).
    * ``'npy'``: base64 of a ``.npy`` file written by :py:func:`numpy.save`.
    * ``'base64+meta'``: dict with keys ``data``, ``dtype`` and ``shape``.
    '''

    dummy_for = num.ndarray

    class __T(TBase):
        '''
        Type handler for :py:class:`Array` attributes.

        :param shape: expected shape; ``None`` entries match any length, and
            ``shape=None`` disables the shape check entirely
        :param dtype: required dtype of the regularized array
        :param serialize_as: serialization format (see :py:class:`Array`)
        :param serialize_dtype: dtype of the raw data in base64 based formats
        '''

        def __init__(
                self,
                shape=None,
                dtype=None,
                serialize_as='table',
                serialize_dtype=None,
                *args, **kwargs):

            TBase.__init__(self, *args, **kwargs)
            self.shape = shape
            self.dtype = dtype
            assert serialize_as in (
                'table', 'base64', 'list', 'npy',
                'base64+meta', 'base64-compat')
            self.serialize_as = serialize_as
            self.serialize_dtype = serialize_dtype

        def is_default(self, val):
            '''Return whether ``val`` equals the default (``_default_cmp``).'''
            if self._default_cmp is None:
                return val is None
            elif val is None:
                return False
            else:
                return array_equal(self._default_cmp, val)

        def _load_table(self, text):
            '''Parse whitespace separated numbers in ``text`` into an array.'''
            # numpy.loadtxt only accepts ndmin in (0, 1, 2). Without a shape
            # constraint use 0; higher ranks are capped at 2 and left for
            # validate_extra to reject.
            if self.shape:
                ndmin = min(len(self.shape), 2)
            else:
                ndmin = 0

            return num.loadtxt(
                BytesIO(text.encode('utf-8')),
                dtype=self.dtype, ndmin=ndmin)

        def _decode_base64(self, text):
            '''Decode base64 ``text`` holding raw ``serialize_dtype`` data.'''
            data = b64decode(text)
            # frombuffer replaces the deprecated binary mode of fromstring; it
            # returns a read-only view, but astype() makes a writeable copy.
            return num.frombuffer(
                data, dtype=self.serialize_dtype).astype(self.dtype)

        def _decode_base64_meta(self, val):
            '''Rebuild an array from a ``{data, dtype, shape}`` dict.'''
            if sorted(val.keys()) != ['data', 'dtype', 'shape']:
                raise ValidationError(
                    'array in format "base64+meta" must have keys '
                    '"data", "dtype", and "shape"')

            shape = val['shape']
            # Tuples are accepted too, for values that never went through a
            # YAML/JSON round trip.
            if not isinstance(shape, (list, tuple)):
                raise ValidationError('invalid shape definition')

            for n in shape:
                if not isinstance(n, int) or n < 0:
                    raise ValidationError('invalid shape definition')

            serialize_dtype = val['dtype']
            allowed = list(restricted_dtype_map_rev.keys())
            if self.serialize_dtype is not None:
                allowed.append(self.serialize_dtype)

            if serialize_dtype not in allowed:
                raise ValidationError(
                    'only the following dtypes are allowed: %s'
                    % ', '.join(sorted(str(x) for x in allowed)))

            data = val['data']
            if not isinstance(data, (str, newstr)):
                raise ValidationError(
                    'data must be given as a base64 encoded string')

            # Compare against None explicitly: numpy.dtype objects may
            # evaluate as false (their len() is the number of fields). A
            # custom serialize_dtype has no entry in the reverse map, so it
            # is then kept as the target dtype instead of raising KeyError.
            if self.dtype is not None:
                dtype = self.dtype
            else:
                dtype = restricted_dtype_map_rev.get(
                    serialize_dtype, serialize_dtype)

            data = b64decode(data)
            arr = num.frombuffer(data, dtype=serialize_dtype).astype(dtype)
            if arr.size != num.prod(shape):
                raise ValidationError('size/shape mismatch')

            return arr.reshape(shape)

        def regularize_extra(self, val):
            '''
            Convert ``val`` into a :py:class:`numpy.ndarray`.

            Strings are decoded according to ``serialize_as``, dicts only in
            ``'base64+meta'`` mode, and any other value is passed through
            :py:func:`numpy.asarray`. Values that are not handled are
            returned unchanged (validate_extra rejects non-ndarray values).

            :raises ValidationError: on malformed ``'base64+meta'`` input
            '''
            if isinstance(val, (str, newstr)):
                if self.serialize_as == 'table':
                    val = self._load_table(val)

                elif self.serialize_as == 'base64':
                    val = self._decode_base64(val)

                elif self.serialize_as == 'base64-compat':
                    try:
                        val = self._decode_base64(val)
                    except (binascii.Error, ValueError):
                        # Not base64, or the decoded byte count is not a
                        # multiple of the item size: assume table format.
                        val = self._load_table(val)

                elif self.serialize_as == 'npy':
                    data = b64decode(val)
                    try:
                        val = num.load(BytesIO(data), allow_pickle=False)
                    except TypeError:
                        # allow_pickle only available in newer NumPy.
                        # NOTE(review): such old versions presumably unpickle
                        # object arrays here - unsafe for untrusted input.
                        val = num.load(BytesIO(data))

            elif isinstance(val, dict):
                if self.serialize_as == 'base64+meta':
                    val = self._decode_base64_meta(val)

            else:
                val = num.asarray(val, dtype=self.dtype)

            return val

        def validate_extra(self, val):
            '''
            Check type, dtype and shape of ``val``.

            :raises ValidationError: if ``val`` is not an ndarray, has the
                wrong dtype, or does not match the configured ``shape``
            '''
            if not isinstance(val, num.ndarray):
                raise ValidationError(
                    'object is not of type numpy.ndarray: %s' % type(val))

            if self.dtype is not None and self.dtype != val.dtype:
                raise ValidationError(
                    'array not of required type: need %s, got %s' % (
                        self.dtype, val.dtype))

            if self.shape is not None:
                la, lb = len(self.shape), len(val.shape)
                if la != lb:
                    raise ValidationError(
                        'array dimension mismatch: need %i, got %i' % (
                            la, lb))

                for a, b in zip(self.shape, val.shape):
                    # None acts as a wildcard for that axis.
                    if a is not None and a != b:
                        raise ValidationError(
                            'array shape mismatch: need %s, got: %s' % (
                                self.shape, val.shape))

        def to_save(self, val):
            '''Convert ``val`` to its serialized form per ``serialize_as``.'''
            if self.serialize_as == 'table':
                out = BytesIO()
                num.savetxt(out, val, fmt='%12.7g')
                return literal(out.getvalue().decode('utf-8'))

            elif self.serialize_as in ('base64', 'base64-compat'):
                data = val.astype(self.serialize_dtype).tobytes()
                return literal(b64encode(data).decode('utf-8'))

            elif self.serialize_as == 'list':
                # num.complex was only an alias of the builtin and is gone
                # from recent NumPy.
                if self.dtype == complex:
                    # tolist() yields Python complex numbers, whose repr is
                    # '(a+bj)' independent of the NumPy scalar repr.
                    return [repr(x) for x in val.tolist()]
                else:
                    return val.tolist()

            elif self.serialize_as == 'npy':
                out = BytesIO()
                try:
                    num.save(out, val, allow_pickle=False)
                except TypeError:
                    # allow_pickle only available in newer NumPy
                    num.save(out, val)

                return literal(b64encode(out.getvalue()).decode('utf-8'))

            elif self.serialize_as == 'base64+meta':
                # Without an explicit serialize_dtype, only dtypes listed in
                # restricted_dtype_map are supported (KeyError otherwise).
                if self.serialize_dtype is not None:
                    serialize_dtype = self.serialize_dtype
                else:
                    serialize_dtype = restricted_dtype_map[val.dtype]

                data = val.astype(serialize_dtype).tobytes()
                # Emit the shape as a plain list so it is written as an
                # ordinary sequence (some YAML dumpers tag tuples).
                return dict(
                    dtype=serialize_dtype,
                    shape=list(val.shape),
                    data=literal(b64encode(data).decode('utf-8')))
# Public API of this module: only the Array guts type is exported by
# ``from ... import *``.
__all__ = ['Array']
|