# Source code for serdio.serde

"""The Serdio serialization and deserialization implementation.

The ``serdio.serde`` spec is an extension of MessagePack that can handle some extra
Python types, while also allowing users to supply their own hooks for seamless
encoding/decoding of user-defined types.

**Usage** ::

    xyb_bytes = serdio.serialize(2, 3.0, b=2.0)

    x, y = serdio.deserialize(xyb_bytes)
    print(f"{x}, {y}")
    # 2, 3.0

    args, kwargs = serdio.deserialize(xyb_bytes, as_signature=True)
    print(f"{args[0]}, {args[1]}, b={kwargs['b']}")
    # 2, 3.0, b=2.0
"""
import dataclasses
import enum
from typing import Any
from typing import Callable
from typing import Dict
from typing import Tuple

import msgpack

# Keys of the top-level map that ``serialize`` writes: positional arguments are
# stored under ARGS_MARKER, keyword arguments (only when present) under KWARGS_MARKER.
ARGS_MARKER: str = "_serdio_args_"
KWARGS_MARKER: str = "_serdio_kwargs_"


class _MsgpackExtType(enum.IntEnum):
    """Messagepack custom type ids."""

    native_complex = 1
    native_tuple = 2
    native_set = 3
    native_frozenset = 4


def _default_encoder(x, custom_encoder=None):
    """Convert Python-native types that MessagePack lacks into serdio ``ExtType``s.

    Meant to be used as the ``default`` hook of ``msgpack.packb``. A ``complex``
    is packed from its ``(real, imag)`` pair; a ``tuple``, ``set`` or
    ``frozenset`` is packed as a list of its items. Each payload is wrapped in a
    ``msgpack.ExtType`` tagged with the matching ``_MsgpackExtType`` code. Any
    other value is returned unchanged.

    Args:
        x: Value that MessagePack could not pack natively.
        custom_encoder: Optional user hook for user-defined types. When given, it
            is applied (after the built-in conversions above) to nested values of
            the converted collection that MessagePack cannot pack natively.

    Returns:
        A ``msgpack.ExtType`` for the supported native types, otherwise ``x``.
    """
    if custom_encoder is None:
        nested_hook = _default_encoder
    else:

        def nested_hook(item):
            # Built-in conversions first; the user hook sees whatever remains.
            converted = _default_encoder(item, custom_encoder=custom_encoder)
            return custom_encoder(converted)

    def pack_payload(value):
        # Nested values are routed back through the same hook chain.
        return msgpack.packb(value, default=nested_hook, strict_types=True)

    if isinstance(x, complex):
        return msgpack.ExtType(
            _MsgpackExtType.native_complex, pack_payload((x.real, x.imag))
        )

    for native_type, ext_code in (
        (tuple, _MsgpackExtType.native_tuple),
        (set, _MsgpackExtType.native_set),
        (frozenset, _MsgpackExtType.native_frozenset),
    ):
        if isinstance(x, native_type):
            return msgpack.ExtType(ext_code, pack_payload(list(x)))
    return x


def _msgpack_ext_unpack(code, data, custom_decoder=None):
    """Decode a serdio MessagePack ``ExtType`` back into a native Python value.

    Inverse of ``_default_encoder``. The extension payload is itself MessagePack
    bytes; it is unpacked (resolving nested extensions recursively) and the
    result is rebuilt into the native type that ``code`` names.

    Args:
        code: Extension type code: 1 (complex), 2 (tuple), 3 (set) or 4 (frozenset).
        data: Packed payload bytes of the extension.
        custom_decoder: Optional callable used as the ``object_hook`` for maps found
            inside the payload, so user-defined types nested in collections are
            rebuilt too.

    Returns:
        The reconstructed ``complex``/``tuple``/``set``/``frozenset``, or an
        untouched ``msgpack.ExtType(code, data)`` when ``code`` is not one of
        serdio's extension codes.
    """
    if custom_decoder is None:
        map_hook = lambda obj: obj  # noqa: E731
        nested_ext_hook = _msgpack_ext_unpack
    else:
        map_hook = custom_decoder

        def nested_ext_hook(nested_code, nested_data):
            # Thread the user decoder through every level of nesting.
            return _msgpack_ext_unpack(
                nested_code, nested_data, custom_decoder=custom_decoder
            )

    for ext_code, rebuild in (
        (_MsgpackExtType.native_complex, lambda parts: complex(parts[0], parts[1])),
        (_MsgpackExtType.native_tuple, tuple),
        (_MsgpackExtType.native_set, set),
        (_MsgpackExtType.native_frozenset, frozenset),
    ):
        if code == ext_code:
            unpacked = msgpack.unpackb(
                data, ext_hook=nested_ext_hook, object_hook=map_hook
            )
            return rebuild(unpacked)

    # Not a serdio extension: hand it back unchanged.
    return msgpack.ExtType(code, data)


def serialize(*args: Any, encoder: Callable = None, **kwargs: Any) -> bytes:
    """Serializes a set of ``args`` and ``kwargs`` into bytes with MessagePack.

    Args:
        *args: Positional arguments to include in the serialized bytes.
        encoder: Optional callable specifying a MessagePack encoder for
            user-defined types. See :class:`.SerdeHookBundle` for details.
        kwargs: Keyword arguments to include in the serialized bytes.

    Returns:
        MessagePack bytes of a dict holding ``args`` (and ``kwargs`` when any were
        given), packed with serdio's extension types and the optional ``encoder``.

    Raises:
        TypeError: if ``encoder`` is not callable. Other errors can be raised by
            MessagePack during packing.
    """
    payload = {ARGS_MARKER: args}
    if kwargs:
        payload[KWARGS_MARKER] = kwargs

    if encoder is None:
        encode_hook = _default_encoder
    else:
        if not callable(encoder):
            raise TypeError(
                f"`encoder` arg needs to be callable, found type {type(encoder)}"
            )

        def encode_hook(obj):
            # Same composition _default_encoder applies to nested values:
            # built-in conversions first, then the user hook. Applying it here as
            # well means user-defined values the top-level packer reaches directly
            # (e.g. values of the kwargs dict) reach ``encoder`` instead of
            # making msgpack raise TypeError.
            return encoder(_default_encoder(obj, custom_encoder=encoder))

    return msgpack.packb(payload, default=encode_hook, strict_types=True)
def deserialize(
    serdio_bytes: bytes, decoder: Callable = None, as_signature: bool = False
) -> Any:
    """Unpack serdio-serialized bytes back into Python objects.

    Args:
        serdio_bytes: Byte array to deserialize.
        decoder: Optional callable that rebuilds user-defined types. It is used as
            the MessagePack ``object_hook`` (applied to decoded maps, including
            nested ones). See :class:`.SerdeHookBundle` for details.
        as_signature: If True, return an ``(args, kwargs)`` pair instead of the
            bare positional values. Most useful when the serialized values are the
            inputs to a function of unknown arity.

    Returns:
        With ``as_signature=True``: the stored positional arguments and the
        keyword-arguments dict (``{}`` if none were stored). Otherwise: the stored
        positional arguments, unwrapped to the single value when exactly one was
        stored.

    Raises:
        TypeError: if ``decoder`` is given but is not callable.
    """
    if decoder is None:
        ext_hook = _msgpack_ext_unpack
    else:
        if not callable(decoder):
            raise TypeError(
                f"`decoder` needs to be a callable, found type {type(decoder)}"
            )

        def ext_hook(code, data):
            # Propagate the user decoder into serdio extension payloads.
            return _msgpack_ext_unpack(code, data, custom_decoder=decoder)

    payload = msgpack.unpackb(serdio_bytes, ext_hook=ext_hook, object_hook=decoder)
    positional = payload.get(ARGS_MARKER)
    keyword = payload.get(KWARGS_MARKER, {})

    if as_signature:
        return positional, keyword
    return positional[0] if len(positional) == 1 else positional
@dataclasses.dataclass
class SerdeHookBundle:
    """An encoder-decoder hook pair for user-defined types.

    The ``encoder_hook`` and ``decoder_hook`` specify how to convert from a
    user-defined type into an equivalent collection of Python-native values and
    back. Thus for any object ``X`` of user-defined type ``T``, the following
    relationship should hold: ::

        hook_bundle = SerdeHookBundle(f, g)
        native_X = hook_bundle.encoder_hook(X)  # f(X)
        Y = hook_bundle.decoder_hook(native_X)  # g(native_X)
        assert X == Y

    Note that ``native_X`` above needs to be some collection of native Python
    values, e.g. a simple dataclass can be represented as a dictionary of
    attributes mapping to values.

    Args:
        encoder_hook: An encoder function specifying how
            :func:`.serdio.serde.serialize` should break down any custom types
            into Python native types.
        decoder_hook: The inverse of ``encoder_hook``, specifying how
            :func:`.serdio.serde.deserialize` should re-assemble the
            ``encoder_hook`` output into user-defined types.
    """

    encoder_hook: Callable
    decoder_hook: Callable

    def to_dict(self) -> Dict:
        """Return the encoder-decoder hook pair as a dictionary.

        Keys are the field names (``"encoder_hook"``, ``"decoder_hook"``); values
        are the hook objects themselves, not copies.
        """
        # dataclasses.asdict() deep-copies non-dataclass field values: a bound
        # method or callable object would drag along (or fail to copy) the object
        # it belongs to. A shallow walk over the fields returns the hooks as-is.
        return {
            field.name: getattr(self, field.name)
            for field in dataclasses.fields(self)
        }

    def unbundle(self) -> Tuple:
        """Return the hook pair as an ``(encoder_hook, decoder_hook)`` tuple.

        The tuple holds the hook objects themselves, not copies.
        """
        # Shallow for the same reason as ``to_dict``: astuple() would deep-copy.
        return tuple(getattr(self, field.name) for field in dataclasses.fields(self))
def bundle_serde_hooks(hook_bundle):
    """Normalize an encoder-decoder hook pair into a :class:`.SerdeHookBundle`.

    Args:
        hook_bundle: The hook pair, given as one of:

            * a ``dict`` with exactly the keys ``"encoder_hook"`` and
              ``"decoder_hook"``;
            * a ``tuple`` or ``list`` whose first item is the encoder hook and
              second the decoder hook;
            * an existing :class:`.SerdeHookBundle`.

            Values of any other type are returned unchanged.

    Returns:
        A :class:`.SerdeHookBundle` wrapping the hook pair.

    Raises:
        ValueError: if a ``dict`` ``hook_bundle`` does not have exactly the two
            expected keys.
    """
    if isinstance(hook_bundle, dict):
        _check_dict_hook_bundle(hook_bundle)
        return SerdeHookBundle(**hook_bundle)
    if isinstance(hook_bundle, (tuple, list)):
        return SerdeHookBundle(*hook_bundle)
    return hook_bundle
def _check_dict_hook_bundle(hook_bundle): correct_size = len(hook_bundle) == 2 correct_keys = "encoder_hook" in hook_bundle and "decoder_hook" in hook_bundle if not correct_size or not correct_keys: raise ValueError( "`hook_bundle` dict must have exactly two key-value pairs: 'encoder_hook'" f"and 'decoder_hook'. Found dict with keys: {list(hook_bundle.keys())}." )