Source code for pyrocko.guts_array

# http://pyrocko.org - GPLv3
#
# The Pyrocko Developers, 21st Century
# ---|P------/S----------~Lg----------
from __future__ import absolute_import

import numpy as num
from io import BytesIO
from base64 import b64decode, b64encode
import binascii

from .guts import TBase, Object, ValidationError, literal, newstr


try:
    unicode
except NameError:
    unicode = str


restricted_dtype_map = {
    num.dtype('float64'): '<f8',
    num.dtype('float32'): '<f4',
    num.dtype('int64'): '<i8',
    num.dtype('int32'): '<i4',
    num.dtype('int16'): '<i2',
    num.dtype('int8'): '<i1'}

restricted_dtype_map_rev = dict(
    (v, k) for (k, v) in restricted_dtype_map.items())


def array_equal(a, b):
    return a.dtype == b.dtype \
        and a.shape == b.shape \
        and num.all(a == b)


class Array(Object):

    dummy_for = num.ndarray

    class __T(TBase):
        def __init__(
                self,
                shape=None,
                dtype=None,
                serialize_as='table',
                serialize_dtype=None,
                *args, **kwargs):

            TBase.__init__(self, *args, **kwargs)
            self.shape = shape
            self.dtype = dtype
            assert serialize_as in (
                'table', 'base64', 'list', 'npy',
                'base64+meta', 'base64-compat')
            self.serialize_as = serialize_as
            self.serialize_dtype = serialize_dtype

        def is_default(self, val):
            if self._default is None:
                return val is None
            elif val is None:
                return False
            else:
                return array_equal(self._default, val)

        def regularize_extra(self, val):
            if isinstance(val, (str, newstr)):
                ndim = None
                if self.shape:
                    ndim = len(self.shape)

                if self.serialize_as == 'table':
                    val = num.loadtxt(
                        BytesIO(val.encode('utf-8')),
                        dtype=self.dtype, ndmin=ndim)

                elif self.serialize_as == 'base64':
                    data = b64decode(val)
                    val = num.fromstring(
                        data, dtype=self.serialize_dtype).astype(self.dtype)

                elif self.serialize_as == 'base64-compat':
                    try:
                        data = b64decode(val)
                        val = num.fromstring(
                            data,
                            dtype=self.serialize_dtype).astype(self.dtype)
                    except binascii.Error:
                        val = num.loadtxt(
                            BytesIO(val.encode('utf-8')),
                            dtype=self.dtype, ndmin=ndim)

                elif self.serialize_as == 'npy':
                    data = b64decode(val)
                    try:
                        val = num.load(BytesIO(data), allow_pickle=False)
                    except TypeError:
                        # allow_pickle only available in newer NumPy
                        val = num.load(BytesIO(data))

            elif isinstance(val, dict):
                if self.serialize_as == 'base64+meta':
                    if not sorted(val.keys()) == ['data', 'dtype', 'shape']:
                        raise ValidationError(
                            'array in format "base64+meta" must have keys '
                            '"data", "dtype", and "shape"')

                    shape = val['shape']
                    if not isinstance(shape, list):
                        raise ValidationError('invalid shape definition')

                    for n in shape:
                        if not isinstance(n, int):
                            raise ValidationError('invalid shape definition')

                    serialize_dtype = val['dtype']
                    allowed = list(restricted_dtype_map_rev.keys())
                    if self.serialize_dtype is not None:
                        allowed.append(self.serialize_dtype)

                    if serialize_dtype not in allowed:
                        raise ValidationError(
                            'only the following dtypes are allowed: %s'
                            % ', '.join(sorted(allowed)))

                    data = val['data']
                    if not isinstance(data, (str, newstr)):
                        raise ValidationError(
                            'data must be given as a base64 encoded string')

                    data = b64decode(data)

                    dtype = self.dtype or \
                        restricted_dtype_map_rev[serialize_dtype]

                    val = num.fromstring(
                        data, dtype=serialize_dtype).astype(dtype)

                    if val.size != num.product(shape):
                        raise ValidationError('size/shape mismatch')

                    val = val.reshape(shape)

            else:
                val = num.asarray(val, dtype=self.dtype)

            return val

        def validate_extra(self, val):
            if not isinstance(val, num.ndarray):
                raise ValidationError(
                    'object %s is not of type numpy.ndarray: %s' % (
                        self.xname(), type(val)))
            if self.dtype is not None and self.dtype != val.dtype:
                raise ValidationError(
                    'array %s not of required type: need %s, got %s' % (
                        self.xname(), self.dtype, val.dtype))

            if self.shape is not None:
                la, lb = len(self.shape), len(val.shape)
                if la != lb:
                    raise ValidationError(
                        'array %s dimension mismatch: need %i, got %i' % (
                            self.xname(), la, lb))

                for a, b in zip(self.shape, val.shape):
                    if a is not None:
                        if a != b:
                            raise ValidationError(
                                'array %s shape mismatch: need %s, got: %s' % (
                                    self.xname(), self.shape, val.shape))

        def to_save(self, val):
            if self.serialize_as == 'table':
                out = BytesIO()
                num.savetxt(out, val, fmt='%12.7g')
                return literal(out.getvalue().decode('utf-8'))
            elif self.serialize_as == 'base64' \
                    or self.serialize_as == 'base64-compat':
                data = val.astype(self.serialize_dtype).tostring()
                return literal(b64encode(data).decode('utf-8'))
            elif self.serialize_as == 'list':
                if self.dtype == complex:
                    return [repr(x) for x in val]
                else:
                    return val.tolist()
            elif self.serialize_as == 'npy':
                out = BytesIO()
                try:
                    num.save(out, val, allow_pickle=False)
                except TypeError:
                    # allow_pickle only available in newer NumPy
                    num.save(out, val)

                return literal(b64encode(out.getvalue()).decode('utf-8'))

            elif self.serialize_as == 'base64+meta':
                serialize_dtype = self.serialize_dtype or \
                    restricted_dtype_map[val.dtype]

                data = val.astype(serialize_dtype).tostring()

                return dict(
                    dtype=serialize_dtype,
                    shape=val.shape,
                    data=literal(b64encode(data).decode('utf-8')))


__all__ = ['Array']