# Source code for ciri.core

import logging
from abc import ABCMeta

from ciri.abstract import (AbstractField, AbstractSchema, SchemaFieldDefault,
                           SchemaFieldMissing, UseSchemaOption)
from ciri.compat import add_metaclass
from ciri.encoder import JSONEncoder
from ciri.exception import SchemaException, SerializationError, ValidationError, FieldValidationError
from ciri.fields import FieldError, Schema as SchemaField
from ciri.registry import schema_registry


logger = logging.getLogger('ciri')


class ErrorHandler(object):
    """
    Default `Schema` Error Handler.

    Collects field errors into a plain nested dictionary keyed by the
    stringified field key.
    """

    def __init__(self):
        #: Holds formatted Errors
        self.errors = {}

    def reset(self):
        """Drop every error collected so far."""
        self.errors = {}

    def add(self, key, field_error):
        """Record a `FieldError` under ``key``.

        The stored entry always carries the error message under ``'msg'``.
        When the error has nested errors they are formatted by a fresh
        handler of the same class and stored under ``'errors'``.

        :param key: error key
        :param field_error: object exposing ``message`` and ``errors``
        :type key: str

        """
        name = str(key)
        entry = self.errors[name] = {'msg': field_error.message}
        children = field_error.errors
        if not children:
            return
        # nested errors get their own handler so the output shape recurses
        nested = self.__class__()
        for child_key, child_error in children.items():
            nested.add(child_key, child_error)
        entry['errors'] = nested.errors

class SchemaOptions(object):
    """
    Holds the schema behavior configuration

    Unknown keyword arguments are silently ignored; every known option that
    is not supplied falls back to its default.

    :param allow_none: Allow :class:`None` values
    :param raise_errors: Whether or not to raise exceptions
    :param error_handler: Schema error handling
    :param encoder: Schema encoding handler
    :param registry: Schema registry
    :type allow_none: bool
    :type raise_errors: bool
    :type error_handler: :class:`~ciri.core.ErrorHandler`
    :type encoder: :class:`~ciri.encoder.SchemaEncoder`
    :type registry: :class:`~ciri.registry.SchemaRegistry`
    """

    def __init__(self, *args, **kwargs):
        defaults = {
            'allow_none': False,
            'raise_errors': True,
            'error_handler': ErrorHandler,
            'encoder': JSONEncoder(),
            'registry': schema_registry
        }
        # only recognised option names may override a default
        options = dict((k, v) for k, v in kwargs.items() if k in defaults)
        defaults.update(options)
        for k, v in defaults.items():
            setattr(self, k, v)


DEFAULT_SCHEMA_OPTIONS = SchemaOptions()


class SchemaCallableObject(object):
    """
    Per-schema store of the pre/post validate, serialize and deserialize
    callables, kept as one ``{field key: [callable, ...]}`` dict per hook.
    """

    def __init__(self, *args, **kwargs):
        self.callables = ['pre_validate', 'pre_serialize', 'pre_deserialize',
                          'post_validate', 'post_serialize', 'post_deserialize']
        for c in self.callables:
            setattr(self, c, kwargs.get(c, {}))

    def find(self, schema):
        """Collect the hook callables declared on the fields of ``schema``.

        Hook entries that are already callable are kept as they are. Any
        other entry is looked up as an attribute name on ``schema`` and, if
        that attribute is callable, bound to ``schema``; entries that do not
        resolve to a callable are dropped.

        :param schema: object exposing a ``_fields`` mapping
        """
        for key, field in schema._fields.items():
            if isinstance(field, Schema):
                continue  # schemas will handle their own
            for c in self.callables:
                field_callable = getattr(field, c, None)
                if field_callable:
                    updated_callables = []
                    for item in field_callable:
                        if callable(item):
                            updated_callables.append(item)
                        else:
                            method = getattr(schema, item, None)
                            if callable(method):
                                updated_callables.append(method.__get__(schema, None))
                    getattr(self, c)[key] = updated_callables


class AbstractPolySchema(AbstractSchema):
    """Marker base class used to recognise polymorphic schemas."""
    pass


class ABCSchema(ABCMeta):
    """
    Schema Metaclass

    Looks for :class:`ciri.fields.Field` attributes and handles
    the schema magic methods.
    """

    def __new__(cls, name, bases, attrs):
        cls, name, bases, attrs = cls.prepare_class(cls, name, bases, attrs)
        klass = ABCMeta.__new__(cls, name, bases, dict(attrs))
        # every schema class gets its own bookkeeping containers
        klass._fields = {}
        klass._elements = {}
        klass._subfields = {}
        klass._pending_schemas = {}
        klass._callables = SchemaCallableObject()
        klass._config = DEFAULT_SCHEMA_OPTIONS
        klass.handle_bases(bases)
        klass.handle_poly(cls, name, bases, attrs)
        klass.handle_config()
        klass.find_fields()
        klass.process_fields()
        klass._callables.find(klass)
        return klass

    @staticmethod
    def prepare_class(cls, name, bases, attrs):
        """
        Prepares the class instance for different Schema types.

        Translates the ``Meta`` attributes ``compose``, ``poly_id`` and
        ``options`` into their magic attribute equivalents, and rewrites the
        bases of classes that declare a ``__poly_id__``. Currently this only
        handles the :class:`PolySchema` type and its subclasses.

        :returns: the ``(cls, name, bases, attrs)`` tuple to build the class from
        """
        clear_poly = False

        # Meta : compose attributes
        if 'Meta' in attrs and getattr(attrs['Meta'], 'compose', None):
            compose = getattr(attrs['Meta'], 'compose')
            # ``attrs`` is a mapping, so this has to be a key lookup; an
            # attribute lookup never finds the entry and would discard an
            # explicitly declared ``__schema_include__``.
            base_includes = attrs.get('__schema_include__') or []
            attrs['__schema_include__'] = list(compose) + list(base_includes)

        # Meta : poly_id
        if 'Meta' in attrs and getattr(attrs['Meta'], 'poly_id', None):
            attrs['__poly_id__'] = getattr(attrs['Meta'], 'poly_id')

        # Meta : options
        if 'Meta' in attrs and getattr(attrs['Meta'], 'options', None):
            attrs['__schema_options__'] = getattr(attrs['Meta'], 'options')

        if '__poly_id__' in attrs:
            clear_poly = True
        if clear_poly:
            updated_bases = []
            for base in bases:
                if issubclass(base, AbstractPolySchema):
                    # copy the inheritable attributes of the poly parent
                    # instead of subclassing it
                    props = []
                    for x in base.__poly_inherit__:
                        if x:
                            props.append(x)
                    newattrs = dict((x, getattr(base, x, None)) for x in props)
                    for k, v in newattrs.items():
                        if callable(v):
                            newattrs[k] = v.__get__(cls, None)
                    newattrs.update(attrs)
                    attrs = newattrs
                    attrs['__poly_parent__'] = base
                    attrs.update(base._fields)
                    continue
                updated_bases.append(base)
            updated_bases.append(Schema)
            bases = tuple(updated_bases)
        return cls, name, bases, attrs

    def handle_bases(self, bases):
        """Handles the Schema inheritance, specifically
        bringing in the inherited field attributes"""
        for base in bases:
            if hasattr(base, '_fields'):
                self._fields.update(base._fields)

    def handle_poly(self, cls, name, bases, attrs):
        """Handles magic methods (e.g. `__poly_on__`) of
        :class:`~ciri.core.PolySchema` definitions."""
        for base in bases:
            if issubclass(base, AbstractPolySchema):
                if '__poly_on__' in attrs:
                    base.__poly_mapping__ = {}
                    # names starting with ``__poly`` are stored as None and
                    # skipped when the list is consumed in prepare_class
                    base.__poly_inherit__ = [x if not x.startswith('__poly') else None for x in attrs]
                if '__poly_id__' in attrs:
                    base.__poly_parent__ = base

    def handle_config(self):
        """Handles the schema options magic method"""
        if hasattr(self, '__schema_options__'):
            self._config = getattr(self, '__schema_options__')

    def find_fields(self):
        """Find the :class:`~ciri.fields.Field` attributes and
        store them in the schemas `_fields` attribute."""
        items = dict(vars(self))
        includes = getattr(self, '__schema_include__', None)
        inc = {}
        if includes:
            for inc_item in includes:
                if isinstance(inc_item, ABCSchema):
                    inc.update(inc_item._fields)
                else:
                    inc.update(inc_item)
        # attributes declared on the class itself win over included ones
        inc.update(items)
        ignore_fields = ['__poly_on__']
        for k, v in inc.items():
            if k not in ignore_fields and (isinstance(v, AbstractField) or isinstance(v, AbstractSchema)):
                if not v.name:
                    v.name = k
                self._fields[k] = v
                if k in items:
                    delattr(self, k)
            elif k in self._fields:
                self._fields.pop(k)  # a subclass has overriden the field

    def process_fields(self):
        """Performs field processing.

        Handles:

        * Tracking required fields or fields that should always be checked
        * Tracking nested fields (aka, sub schemas)
        * Tracking deferred schema fields
        * Converting :class:`ciri.fields.Schema` fields to Schemas
        """
        for k, v in self._fields.items():
            if isinstance(v, AbstractField):
                if isinstance(v, SchemaField):
                    try:
                        self._fields[k] = v._get_schema()
                    except AttributeError:
                        # schema could not be resolved yet; retry lazily
                        self._pending_schemas[k] = v
                    self._subfields[k] = v
                    self._elements[k] = True
                else:
                    if v.required or v.allow_none or (v.default is not SchemaFieldDefault):
                        self._elements[k] = True
            elif isinstance(v, AbstractSchema):
                self._subfields[k] = v
                self._elements[k] = True
        self._e = [x for x in self._elements]

    def __init__(self, *args, **kwargs):
        """Metaclass Init - Maps the current Schema to the
        :class:`ciri.core.PolySchema` parent if the Schema has one"""
        poly_id = getattr(self, '__poly_id__', None)
        if poly_id:
            self.__poly_parent__.__poly_mapping__[poly_id] = self
@add_metaclass(ABCSchema)
class Schema(AbstractSchema):
    """Base class for schema definitions.

    Keyword arguments whose names match a declared field are stored as
    instance attributes; all other keyword arguments are ignored.
    """

    def __init__(self, *args, **kwargs):
        for k, v in kwargs.items():
            if self._fields.get(k):
                setattr(self, k, v)
        self._validation_opts = {}
        self._serialization_opts = {}
        for k in self._fields:
            self._fields[k]._schema = self
        self.config({})

    def config(self, cfg):
        """Apply ``cfg['options']`` (when not None) and refresh the error
        handler, registry and encoder from the active options."""
        if cfg.get('options') is not None:
            self._config = cfg['options']
        self._error_handler = self._config.error_handler()
        self._registry = self._config.registry
        self._encoder = self._config.encoder

    @property
    def errors(self):
        """Formatted errors held by the schema's error handler."""
        return self._error_handler.errors

    def pre_process(self, data):
        pass

    def _iterate(self, fields, elements, data, validation_opts, parent=None,
                 do_serialize=False, do_deserialize=False, do_validate=False):
        """Walk ``elements`` and validate / serialize / deserialize each one.

        :param fields: mapping of field key to field (or sub schema)
        :param elements: iterable of the field keys to process
        :param data: mapping (or object with ``__dict__``) holding the values
        :param validation_opts: dict; only ``halt_on_error`` is read here
        :param parent: schema owning ``fields``; supplies the callables,
                       sub field tracking and the error handler
        :param do_serialize: run the serialization step
        :param do_deserialize: run the deserialization step
        :param do_validate: run the validation step
        :returns: ``(errors, valid)`` -- errors keyed by field, and the
                  processed values
        """
        errors = {}
        valid = {}
        halt_on_error = validation_opts.get('halt_on_error')
        allow_none = self._config.allow_none

        pre_validate = parent._callables.pre_validate
        pre_serialize = parent._callables.pre_serialize
        pre_deserialize = parent._callables.pre_deserialize
        post_validate = parent._callables.post_validate
        post_serialize = parent._callables.post_serialize
        post_deserialize = parent._callables.post_deserialize

        if do_validate:
            parent._raw_errors = {}
            parent._error_handler.reset()

        if hasattr(data, '__dict__'):
            data = vars(data)

        for key in elements:
            str_key = str(key)

            # field value
            klass_value = data.get(key, SchemaFieldMissing)
            missing = (klass_value is SchemaFieldMissing)
            invalid = False

            field = fields[key]

            # if we encounter a schema field, cache it
            if key in parent._pending_schemas:
                field = self._fields[key] = field._get_schema()
                parent._pending_schemas.pop(key)

            if key in parent._subfields:
                data_keys = []
                for k in data:
                    if field._fields.get(k):
                        data_keys.append(k)
                key_cache = set(field._e + data_keys)
                suberrors, valid[key] = self._iterate(field._fields, key_cache, klass_value, validation_opts,
                                                      parent=field, do_serialize=do_serialize,
                                                      do_validate=do_validate, do_deserialize=do_deserialize)
                if suberrors:
                    errors[key] = FieldError(parent._subfields[key], 'invalid', errors=suberrors)
                continue

            # if the field is missing, set the default value
            if missing and (fields[key].default is not SchemaFieldDefault):
                if callable(fields[key].default):
                    klass_value = fields[key].default(parent, field)
                else:
                    klass_value = fields[key].default
                missing = False

            if do_validate:
                # run pre validation functions
                if pre_validate:
                    for func in pre_validate.get(key, []):
                        try:
                            valid[key] = func(parent, field, klass_value)
                            missing = (valid[key] is SchemaFieldMissing)
                        except FieldValidationError as field_exc:
                            errors[str_key] = field_exc.error
                            invalid = True
                            break
                    # NOTE(review): nesting of this check was reconstructed
                    # from flattened source -- confirm against upstream
                    if errors and halt_on_error:
                        break

                # if the field is missing, but it's required, set an error.
                # if a value of None is allowed and we do not have a field, skip validation
                # otherwise, validate the value
                if missing and field.required:
                    errors[str_key] = FieldError(fields[key], 'required')
                    invalid = True
                elif missing and field.allow_none is True:
                    pass
                elif (missing or klass_value is None) and field.allow_none is False:
                    errors[key] = FieldError(field, 'invalid')
                elif (missing or klass_value is None) and allow_none is False and field.allow_none is not True:
                    errors[key] = FieldError(field, 'invalid')
                elif allow_none and field.allow_none is UseSchemaOption and (klass_value is None or klass_value is SchemaFieldMissing):
                    pass
                elif field.allow_none is True and klass_value is None:
                    pass
                elif not missing:
                    try:
                        valid[key] = field.validate(klass_value)
                    except FieldValidationError as field_exc:
                        errors[str_key] = field_exc.error
                        invalid = True

                if post_validate:
                    for validator in post_validate.get(key, []):
                        try:
                            valid[key] = validator(parent, field, valid[key])
                        except FieldValidationError as field_exc:
                            errors[str_key] = field_exc.error
                            invalid = True
                            break

                if errors and halt_on_error:
                    break

            if not invalid and do_serialize:
                # run pre serialization functions
                if pre_serialize:
                    for func in pre_serialize.get(key, []):
                        valid[key] = func(parent, field, klass_value)
                        missing = (valid[key] is SchemaFieldMissing)

                # determine the field result name (serialized name)
                name = field.name or key

                # if it's allowed, and the field is missing, set the value to None
                if missing and allow_none and field.allow_none is UseSchemaOption:
                    valid[name] = None
                elif missing and field.allow_none:
                    valid[name] = None
                elif klass_value is None and field.allow_none:
                    valid[name] = None
                else:
                    valid[name] = field.serialize(valid.get(key, klass_value))

                # run post serialization functions
                if post_serialize:
                    for func in post_serialize.get(key, []):
                        valid[name] = func(parent, field, klass_value)

                # remove old keys if the serializer renames the field.
                # ``valid[key]`` only exists when an earlier step stored a
                # value under the original key (it does not when validation
                # is skipped or the value was an allowed None), so the
                # removal must tolerate a missing entry.
                if name != key:
                    valid.pop(key, None)

            if not invalid and do_deserialize:
                # run pre deserialization functions
                if pre_deserialize:
                    for func in pre_deserialize.get(key, []):
                        valid[key] = func(parent, field, valid.get(key, klass_value))
                        missing = (valid[key] is SchemaFieldMissing)

                # if it's allowed, and the field is missing, set the value to None
                if missing and allow_none and field.allow_none is UseSchemaOption:
                    valid[key] = None
                elif missing and field.allow_none:
                    valid[key] = None
                elif klass_value is None and field.allow_none:
                    valid[key] = None
                else:
                    valid[key] = field.deserialize(valid.get(key, klass_value))

                # run post deserialization functions
                if post_deserialize:
                    for func in post_deserialize.get(key, []):
                        valid[key] = func(parent, field, klass_value)

        for e, err in errors.items():
            parent._raw_errors[e] = err
            parent._error_handler.add(e, err)

        return (errors, valid)

    def validate(self, data=None, halt_on_error=False, key_cache=None):
        """Validate ``data`` (or this instance when ``data`` is falsy).

        :param data: mapping or object with ``__dict__``
        :param halt_on_error: stop at the first field that produces an error
        :param key_cache: optional precomputed collection of keys to check
        :returns: the validated values
        :raises ValidationError: if errors were found and the schema options
                                 have ``raise_errors`` enabled
        """
        data = data or self
        if hasattr(data, '__dict__'):
            data = vars(data)
        self._validation_opts = {
            'halt_on_error': halt_on_error
        }
        if not key_cache:
            data_keys = []
            for k in data:
                if self._fields.get(k):
                    data_keys.append(k)
            key_cache = set(self._e + data_keys)
        errors, valid = self._iterate(self._fields, key_cache, data, self._validation_opts,
                                      parent=self, do_validate=True)
        if self._config.raise_errors and errors:
            raise ValidationError(self)
        return valid

    def serialize(self, data=None, skip_validation=False):
        """Serialize ``data`` (or this instance when ``data`` is falsy).

        :param skip_validation: do not run the validation step
        :returns: the serialized values
        :raises ValidationError: if errors were found and the schema options
                                 have ``raise_errors`` enabled
        """
        data = data or self
        if hasattr(data, '__dict__'):
            data = vars(data)
        data_keys = []
        append = data_keys.append
        fields = self._fields
        for k in data:
            if fields.get(k):
                append(k)
        elements = set(self._e + data_keys)
        errors, output = self._iterate(self._fields, elements, data, self._validation_opts,
                                       parent=self, do_serialize=True,
                                       do_validate=(not skip_validation))
        if self._config.raise_errors and errors:
            raise ValidationError(self)
        return output

    def deserialize(self, data=None, skip_validation=False):
        """Deserialize ``data`` (or this instance when ``data`` is falsy).

        :param skip_validation: do not run the validation step
        :returns: a new instance of this schema class built from the output
        :raises ValidationError: if errors were found and the schema options
                                 have ``raise_errors`` enabled
        """
        data = data or self
        if hasattr(data, '__dict__'):
            data = vars(data)
        data_keys = []
        append = data_keys.append
        fields = self._fields
        for k in data:
            if fields.get(k):
                append(k)
        elements = set(self._e + data_keys)
        errors, output = self._iterate(self._fields, elements, data, self._validation_opts,
                                       parent=self, do_deserialize=True,
                                       do_validate=(not skip_validation))
        if self._config.raise_errors and errors:
            raise ValidationError(self)
        return self.__class__(**output)

    def encode(self, data=None, skip_validation=False, skip_serialization=False):
        """Process ``data`` and pass the result to the configured encoder.

        :param skip_validation: do not run the validation step
        :param skip_serialization: do not run the serialization step
        :returns: whatever the encoder's ``encode`` returns
        :raises ValidationError: if errors were found and the schema options
                                 have ``raise_errors`` enabled
        """
        self._encode_stream = []
        data = data or self
        if hasattr(data, '__dict__'):
            data = vars(data)
        data_keys = []
        append = data_keys.append
        fields = self._fields
        for k in data:
            if fields.get(k):
                append(k)
        elements = set(self._e + data_keys)
        errors, output = self._iterate(self._fields, elements, data, self._validation_opts,
                                       parent=self, do_serialize=(not skip_serialization),
                                       do_validate=(not skip_validation))
        if self._config.raise_errors and errors:
            raise ValidationError(self)
        return self._encoder.encode(output, self)

    def __eq__(self, other):
        """Schemas are equal when their serialized output is equal."""
        if isinstance(other, AbstractSchema):
            if self.serialize() == other.serialize():
                return True
            return False
        return NotImplemented


class PolySchema(AbstractPolySchema, Schema):
    """Schema that dispatches to the schema registered in
    ``__poly_mapping__`` for the value of its ``__poly_on__`` field."""

    def __init__(self, *args, **kwargs):
        # kept so the mapped schema can be built with the same arguments
        self.__poly_args__ = args
        self.__poly_kwargs__ = kwargs
        super(PolySchema, self).__init__(*args, **kwargs)

    def deserialize(self, data=None, *args, **kwargs):
        """Deserialize ``data`` with the schema mapped to its identity value.

        :raises SerializationError: if the identity value is missing or falsy
        """
        ident_key = self.__poly_on__.name
        data = data or self.__poly_kwargs__ or self
        if hasattr(data, '__dict__'):
            data = vars(data)
        id_ = data.get(ident_key)
        if not id_:
            raise SerializationError
        schema = self.__poly_mapping__.get(id_)(*self.__poly_args__, **self.__poly_kwargs__)
        return schema.deserialize(data, *args, **kwargs)

    def serialize(self, data=None, *args, **kwargs):
        """Serialize ``data`` with the schema mapped to its identity value.

        :raises SerializationError: if the identity value is missing or falsy
        """
        ident_key = self.__poly_on__.name
        data = data or self.__poly_kwargs__ or self
        if hasattr(data, '__dict__'):
            data = vars(data)
        id_ = data.get(ident_key)
        if not id_:
            raise SerializationError
        schema = self.__poly_mapping__.get(id_)(*self.__poly_args__, **self.__poly_kwargs__)
        return schema.serialize(data, *args, **kwargs)

    def encode(self, data=None, *args, **kwargs):
        """Encode ``data`` with the schema mapped to its identity value.

        :raises SerializationError: if the identity value is missing or falsy
        """
        ident_key = self.__poly_on__.name
        data = data or self.__poly_kwargs__ or self
        if hasattr(data, '__dict__'):
            data = vars(data)
        id_ = data.get(ident_key)
        if not id_:
            raise SerializationError
        schema = self.__poly_mapping__.get(id_)(*self.__poly_args__, **self.__poly_kwargs__)
        return schema.encode(data, *args, **kwargs)

    @classmethod
    def polymorph(cls, *args, **kwargs):
        """Build the schema mapped to the identity value found in ``kwargs``."""
        ident_key = cls.__poly_on__.name
        id_ = kwargs.get(ident_key)
        schema = cls.__poly_mapping__.get(id_)(*args, **kwargs)
        return schema