Source code for schematics.types.union

import inspect

from ..common import *
from ..datastructures import OrderedDict
from ..exceptions import ConversionError
from ..transforms import get_import_context, get_export_context
from .base import BaseType


def _valid_init_args(type_):
    args = set()
    for cls in type_.__mro__:
        args.update(inspect.getargspec(cls.__init__).args[1:])
        if cls is BaseType:
            break
    return args

def _filter_kwargs(valid_args, kwargs):
    return dict((k, v) for k, v in kwargs.items() if k in valid_args)


class UnionType(BaseType):

    types = None

    MESSAGES = {
        'convert': "Couldn't interpret value '{0}' as any of {1}.",
    }

    _baseclass_args = _valid_init_args(BaseType)

    def __init__(self, types=None, resolver=None, **kwargs):

        self._types = OrderedDict()
        types = types or self.types
        if resolver:
            self.resolve = resolver

        for type_ in types:
            if isinstance(type_, type) and issubclass(type_, BaseType):
                type_ = type_(**_filter_kwargs(_valid_init_args(type_), kwargs))
            elif not isinstance(type_, BaseType):
                raise TypeError("Got '%s' instance instead of a Schematics type" % type_.__class__.__name__)
            self._types[type_.__class__] = type_
            self.typenames = tuple((cls.__name__ for cls in self._types))

        super(UnionType, self).__init__(**_filter_kwargs(self._baseclass_args, kwargs))

    def resolve(self, value, context):
        for field in self._types.values():
            try:
                value = field.convert(value, context)
            except ConversionError:
                pass
            else:
                return field, value
        return None

    def _resolve(self, value, context):
        response = self.resolve(value, context)
        if isinstance(response, type):
            field = self._types[response]
            try:
                response = field, field.convert(value, context)
            except ConversionError:
                pass
        if isinstance(response, tuple):
            return response
        raise ConversionError(self.messages['convert'].format(value, self.typenames))

    def convert(self, value, context=None):
        context = context or get_import_context()
        field, native_value = self._resolve(value, context)
        return native_value

    def validate(self, value, context=None):
        field, _ = self._resolve(value, context)
        return field.validate(value, context)

    def _export(self, value, format, context=None):
        field, _ = self._resolve(value, context)
        return field._export(value, format, context)

    def to_native(self, value, context=None):
        field, _ = self._resolve(value, context)
        return field.to_native(value, context)

    def to_primitive(self, value, context=None):
        field, _ = self._resolve(value, context)
        return field.to_primitive(value, context)


__all__ = module_exports(__name__)