from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional, Sequence, Type, Union, cast
import attr
from multimethod import multimethod
from visions.relations import TypeRelation
# Sentinel meaning "no default supplied", so ``None`` stays usable as a default.
_DEFAULT = object()


class RelationsIterManager:
    """Class to enable to treat relations as dict.

    Relations are kept in their original order in ``values`` and can be
    fetched either by integer position or by their ``related_type``.
    Iterating the manager yields the relations themselves.
    """

    def __init__(self, relations: Sequence[TypeRelation]):
        """Store ``relations`` and index them by ``related_type``.

        Args:
            relations: Relations to manage; each item must expose a
                ``related_type`` attribute.
        """
        # Materialise once: walking ``relations`` twice would leave ``values``
        # empty whenever a one-shot iterable (e.g. a generator) is passed.
        self.values = tuple(relations)
        self._keys: Dict["Type[VisionsBaseType]", int] = {
            item.related_type: position for position, item in enumerate(self.values)
        }

    def __getitem__(self, index: Union["Type[VisionsBaseType]", int]) -> TypeRelation:
        """Return the relation at position ``index`` or keyed by type ``index``.

        Raises:
            IndexError: If an integer ``index`` is out of range.
            KeyError: If a non-integer ``index`` is not a known related type.
        """
        position = index if isinstance(index, int) else self._keys[index]
        return self.values[position]

    def get(
        self, index: Union["Type[VisionsBaseType]", int], default: Any = _DEFAULT
    ) -> Union[TypeRelation, Any]:
        """Return the relation for ``index``, or ``default`` when it is missing.

        Unlike ``dict.get``, a failed lookup re-raises the original
        ``IndexError``/``KeyError`` when no ``default`` was supplied.
        """
        try:
            return self[index]
        except (IndexError, KeyError):
            if default is _DEFAULT:
                raise
            return default

    def __iter__(self):
        """Iterate over the stored relations in insertion order."""
        yield from self.values
class VisionsBaseTypeMeta(ABCMeta):
    """Metaclass providing operator support and lazily built relations for types.

    Classes using this metaclass gain ``in`` (via ``contains_op``), ``+``
    (building a ``VisionsTypeset``), a name-only ``str``/``repr`` and a cached
    ``relations`` property.
    """

    # Fallback for the ``relations`` cache; ``None`` means "not built yet".
    _relations: Optional[RelationsIterManager] = None

    def __contains__(cls, sequence: Sequence) -> bool:
        """Evaluate ``sequence in cls`` by calling ``contains_op`` with a fresh state dict."""
        return cls.contains_op(sequence, {})

    @staticmethod
    def get_relations() -> Sequence[TypeRelation]:
        """Return the relations of a type; must be provided by the type itself."""
        raise NotImplementedError

    @staticmethod
    def contains_op(item: Any, state: dict) -> bool:
        """Membership test of a type; must be provided by the type itself."""
        raise NotImplementedError

    @property
    def relations(cls) -> RelationsIterManager:
        """Relations of ``cls``, bound to ``cls`` and built on first access.

        Each relation from ``get_relations()`` is copied with ``type=cls``.
        An ``IdentityRelation`` without a relationship falls back to
        ``cls.contains_op``; for any other relation the relationship (when
        set) and the transformer are wrapped in ``multimethod``.
        """
        # Look only at this class's own namespace: a plain ``cls._relations``
        # lookup would also find a parent's cache through the MRO and hand a
        # subclass relations that are bound to the parent type.
        cached = cls.__dict__.get("_relations")
        if cached is not None:
            return cached

        from visions.relations.relations import IdentityRelation

        bound = []
        for r in cls.get_relations():
            if isinstance(r, IdentityRelation):
                relationship = (
                    cls.contains_op if r.relationship is None else r.relationship
                )
                bound.append(attr.evolve(r, type=cls, relationship=relationship))
            else:
                relationship = (
                    multimethod(r.relationship) if r.relationship is not None else None
                )
                bound.append(
                    attr.evolve(
                        r,
                        type=cls,
                        relationship=relationship,
                        transformer=multimethod(r.transformer),
                    )
                )

        manager = RelationsIterManager(bound)
        cls._relations = manager
        return manager

    def __add__(cls, other):
        """Combine ``cls`` and ``other`` into a ``VisionsTypeset``.

        ``Generic`` is added to the set when neither operand is a subclass of it.
        """
        from visions.types import Generic
        from visions.typesets import VisionsTypeset

        if not any(issubclass(x, Generic) for x in [cls, other]):
            return VisionsTypeset({Generic, cls, other})
        return VisionsTypeset({cls, other})

    def __str__(cls) -> str:
        """Return the bare class name."""
        return str(cls.__name__)

    def __repr__(cls) -> str:
        """Return the same text as ``str(cls)``."""
        return str(cls)
class VisionsBaseType(metaclass=VisionsBaseTypeMeta):
    """Abstract implementation of a vision type.

    Provides a common API for building custom visions data types.
    """

    def __init__(self):
        """Initialise the type; no instance state is set up."""
        pass

    @staticmethod
    @abstractmethod
    def get_relations() -> Sequence[TypeRelation]:
        """Return the relations of this type; concrete types must override this."""
        raise NotImplementedError

    @classmethod
    def register_transformer(
        cls, relation: "Type[VisionsBaseType]", dispatch_type: Any
    ):
        """Register a transformer implementation on the relation keyed by ``relation``.

        Args:
            relation: Related type used to look the relation up in ``cls.relations``.
            dispatch_type: First type passed to the transformer's ``register``.

        Returns:
            Whatever the transformer's ``register(dispatch_type, dict)`` returns.
        """
        relation_transformer = cls.relations[relation].transformer
        # NOTE(review): ``dict`` is presumably the dispatch type of the ``state``
        # argument - confirm against the multimethod signatures in use.
        return cast(Any, relation_transformer).register(dispatch_type, dict)

    @classmethod
    def register_relationship(
        cls, relation: "Type[VisionsBaseType]", dispatch_type: Any
    ):
        """Register a relationship implementation on the relation keyed by ``relation``.

        Args:
            relation: Related type used to look the relation up in ``cls.relations``.
            dispatch_type: First type passed to the relationship's ``register``.

        Returns:
            Whatever the relationship's ``register(dispatch_type, dict)`` returns.
        """
        relation_relationship = cls.relations[relation].relationship
        return cast(Any, relation_relationship).register(dispatch_type, dict)

    @staticmethod
    @multimethod
    @abstractmethod
    def contains_op(sequence: Any, state: Any) -> bool:
        """Membership test for this type; this base implementation always raises."""
        raise NotImplementedError