import warnings
from functools import singledispatch
from pathlib import Path
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
)
import networkx as nx
import pandas as pd
from visions.types.generic import Generic
from visions.types.type import VisionsBaseType
TypeOrTypeset = TypeVar("TypeOrTypeset", Type[VisionsBaseType], "VisionsTypeset")
pathTypes = TypeVar(
"pathTypes", Type[VisionsBaseType], Dict[str, Type[VisionsBaseType]]
)
pdT = TypeVar("pdT", pd.Series, pd.DataFrame)
T = Type[VisionsBaseType]
def build_graph(nodes: Set[Type[VisionsBaseType]]) -> Tuple[nx.DiGraph, nx.DiGraph]:
"""Constructs a traversable relation graph between visions types
Builds a type relation graph from a collection of :class:`visions.types.type.VisionsBaseType` where
each node corresponds to a type and each edge is a relation defined on the type.
Args:
nodes: An Sequence of :class:`visions.types.type.VisionsBaseType`
Returns:
A directed graph of type relations for the provided nodes.
"""
style_map = {True: "dashed", False: "solid"}
relation_graph = nx.DiGraph()
relation_graph.add_nodes_from(nodes)
noninferential_edges = []
for node in nodes:
for relation in node.relations:
if relation.related_type not in nodes:
warnings.warn(
f"Provided relations included mapping from {relation.related_type} to {relation.type} "
f"but {relation.related_type} was not included in the provided list of nodes"
)
else:
relation_graph.add_edge(
relation.related_type,
relation.type,
relationship=relation,
style=style_map[relation.inferential],
)
if not relation.inferential:
noninferential_edges.append((relation.related_type, relation.type))
check_graph_constraints(relation_graph)
base_graph = relation_graph.edge_subgraph(noninferential_edges)
return relation_graph, base_graph
def check_graph_constraints(relation_graph: nx.DiGraph) -> None:
"""Validates a relation_graph is appropriately constructed
Args:
relation_graph: A directed graph representing the set of relations between type nodes.
"""
check_isolates(relation_graph)
check_cycles(relation_graph)
def check_isolates(graph: nx.DiGraph) -> None:
"""Check for orphaned nodes.
Args:
graph: the graph to check
"""
nodes = set(graph.nodes)
root_node = next(nx.topological_sort(graph))
isolates = list(set(nx.isolates(graph)) - {root_node}) # root can be isolate
graph.remove_nodes_from(isolates)
orphaned_nodes = nodes - set(graph.nodes)
if orphaned_nodes:
message = f"{orphaned_nodes} were isolates in the type relation map and consequently orphaned. "
message += "Please add some mapping to the orphaned nodes."
warnings.warn(message)
def check_cycles(graph: nx.DiGraph) -> None:
"""Check for cycles and warn if one is found
Args:
graph: the graph to check
"""
cycles = list(nx.simple_cycles(graph))
if len(cycles) > 0:
warnings.warn(f"Cyclical relations between types {cycles} detected")
def traverse_graph_with_series(
base_type: T,
series: Sequence,
graph: nx.DiGraph,
path: List[T] = None,
state: Optional[dict] = None,
) -> Tuple[Sequence, List[T], dict]:
"""Depth First Search traversal. There should be at most one successor that contains the series.
Args:
base_type: Entry-point for graph to start traversal
series: the Series to check
graph: the Graph to traverse
path: the path so far
state: traversal state
Returns:
The most uniquely specified node matching the series.
"""
if state is None:
state = dict()
if path is None:
path = []
path.append(base_type)
for vision_type in graph.successors(base_type):
relation = graph[base_type][vision_type]["relationship"]
if relation.is_relation(series, state):
series = relation.transform(series, state)
return traverse_graph_with_series(vision_type, series, graph, path, state)
return series, path, state
def traverse_graph_with_sampled_series(
base_type: T,
series: pd.Series,
graph: nx.DiGraph,
sample_size: int = 10,
state: dict = dict(),
) -> Tuple[Sequence, List[T], dict]:
"""Depth First Search traversal with sampling. There should be at most one successor that contains the series.
Args:
base_type: Entry-point for graph to start traversal
series: the Series to check
graph: the Graph to traverse
sample_size: number of items used in heuristic traversal
state: traversal state
Returns:
The most uniquely specified node matching the series.
"""
if (series.shape[0] < 1000) or (sample_size > series.shape[0]):
return traverse_graph_with_series(base_type, series, graph, state=state)
series_sample = series.sample(sample_size)
_, path, _ = traverse_graph_with_series(
base_type, series_sample, graph, state=state
)
if len(path) == 1:
return series, path, state
# Cast the full series
from_type = path[0]
for i, to_type in enumerate(path[1:]):
relation = graph[from_type][to_type]["relationship"]
if not relation.is_relation(series, state):
break
series = relation.transform(series, state)
from_type = to_type
return series, path[0 : (i + 2)], state
@singledispatch
def traverse_graph(
data: Sequence, root_node: T, graph: nx.DiGraph
) -> Tuple[Sequence, Union[List[T], Dict[str, List[T]]], Dict[str, dict]]:
return traverse_graph_with_series(root_node, data, graph)
@singledispatch
def get_type_from_path(
path_data: Union[Sequence[T], Dict[str, Sequence[T]]]
) -> Union[T, Dict[str, T]]:
raise TypeError(f"Can't get types from path object of type {type(path_data)}")
@get_type_from_path.register(list)
@get_type_from_path.register(tuple)
def _get_type_from_path_builtin(path_list: Sequence[T]) -> T:
return path_list[-1]
@get_type_from_path.register(dict)
def _get_type_from_path_dict(path_dict: Dict[str, Sequence[T]]) -> Dict[str, T]:
return {k: v[-1] for k, v in path_dict.items()}
[docs]class VisionsTypeset:
"""
A collection of :class:`visions.types.type.VisionsBaseType` with associated relationship map between them.
Attributes:
types: The collection of Visions Types derived from :class:`visions.types.type.VisionsBaseType`
base_graph: The graph of relations composed exclusively of :class:`visions.relations.relations.IdentityRelation`
relation_graph: The full relation graph including both :class:`visions.relations.relations.IdentityRelation`
and :class:`visions.relations.relations.InferenceRelation`
"""
[docs] def __init__(self, types: Set[Type[VisionsBaseType]]) -> None:
"""
Args:
types: a set of types
"""
self._root_node: Optional[T] = None
if not isinstance(types, Iterable):
raise ValueError("types should be Sequence")
self.relation_graph, self.base_graph = build_graph(set(types))
if not issubclass(self.root_node, Generic):
raise ValueError("`root_node` should be a subclass of Generic")
self.types = set(self.relation_graph.nodes)
@property
def root_node(self) -> T:
"""Returns a cached copy of the relation_graphs root node
Args:
Returns:
A cached copy of the relation_graphs root node.
"""
if self._root_node is None:
self._root_node = next(nx.topological_sort(self.relation_graph))
return self._root_node
def detect(self, data: Any) -> Tuple[Sequence, Any, dict]:
"""The results found after only considering IdentityRelations.
Notes:
This is an advanced feature, consider using `detect_type` in case the type is what is needed.
Args:
data: a DataFrame or Series to determine types over
Returns:
A tuple of the coerced sequence, visited nodes and state
"""
return traverse_graph(data, self.root_node, self.base_graph)
def detect_type(self, data: Sequence) -> Union[T, Dict[str, T]]:
"""The inferred type found only considering IdentityRelations.
Args:
data: a DataFrame or Series to determine types over
Returns:
A dictionary of {name: type} pairs in the case of DataFrame input or a type
"""
_, paths, _ = self.detect(data)
return get_type_from_path(paths)
def infer(self, data: Sequence) -> Tuple[Sequence, Any, dict]:
"""The results found after considering all relations.
Notes:
This is an advanced feature, consider using `infer_type` in case the type is what is needed.
Args:
data: a DataFrame or Series to determine types over
Returns:
A tuple of the coerced sequence, visited nodes and state
"""
return traverse_graph(data, self.root_node, self.relation_graph)
def infer_type(self, data: Sequence) -> Union[T, Dict[str, T]]:
"""The inferred type found using all type relations.
Args:
data: a DataFrame or Series to determine types over
Returns:
A dictionary of {name: type} pairs in the case of DataFrame input or a type
"""
_, paths, _ = self.infer(data)
return get_type_from_path(paths)
def cast_to_detected(self, data: Sequence) -> Sequence:
"""Transforms input data into a canonical representation using only IdentityRelations
Args:
data: a DataFrame or Series to determine types over
Returns:
new_data: The transformed DataFrame or Series.
"""
data, _, _ = self.detect(data)
return data
def cast_to_inferred(self, data: Sequence) -> Sequence:
"""Transforms input data and returns it's corresponding new type relation using all relations.
Args:
data: a DataFrame or Series to determine types over
Returns:
new_data: The transformed DataFrame or Series.
types: A dictionary of {name: type} pairs in the case of DataFrame input or a type.
"""
data, _, _ = self.infer(data)
return data
def output_graph(
self,
file_name: Union[str, Path],
base_only: bool = False,
dpi: Optional[int] = None,
) -> None:
"""Write the type graph to a file.
Args:
file_name: the file to save the output to
base_only: if True, plot the graph without relation mapping edges
dpi: set the dpi of the output image
"""
from visions.utils.graph import output_graph
if base_only:
graph = self.base_graph.copy()
else:
graph = self.relation_graph.copy()
graph.graph["node"] = {"shape": "box", "color": "red"}
if dpi is not None:
graph.graph["graph"] = {"dpi": dpi}
output_graph(graph, file_name)
def plot_graph(
self,
dpi: int = 800,
base_only: bool = False,
figsize: Optional[Tuple[int, int]] = None,
):
"""
Args:
dpi: dpi of the matplotlib figure.
figsize: figure size
base_only: Only display the typesets base_graph
Returns:
Displays the image
"""
import os
import tempfile
from matplotlib import image as mpimg
from matplotlib import pyplot as plt
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
self.output_graph(temp_file.name, dpi=dpi, base_only=base_only)
img = mpimg.imread(temp_file.name)
plt.figure(dpi=dpi, figsize=figsize)
plt.axis("off")
plt.imshow(img)
os.unlink(temp_file.name)
def _get_other_type(self, other: TypeOrTypeset) -> Set[T]:
"""Converts input into a set of :class:`visions.types.type.VisionsBaseType`
Args:
other: A :class:`visions.types.type.VisionsBaseType` or :class:`visions.typesets.typeset.VisionsTypeset`
Raises:
NotImplementedError:
Returns:
Set[Type[VisionsBaseType]]:
"""
if isinstance(other, VisionsTypeset):
other_types = set(other.types)
elif issubclass(other, VisionsBaseType):
other_types = {other}
else:
raise NotImplementedError(
f"Typeset operation not implemented for type {type(other)}"
)
return other_types
def replace(self, old: T, new: T) -> "VisionsTypeset":
"""Create a new typeset having replace one type with another.
Args:
old: Visions type to replace.
new: Replacement visions type.
Returns
A VisionsTypeset
"""
types = self.types.copy()
types.add(new)
types.remove(old)
return VisionsTypeset(types)
def __add__(self, other: TypeOrTypeset) -> "VisionsTypeset":
"""Adds a type or typeset into the current typeset.
Args:
other: Type or typeset to be added
Returns
A VisionsTypeset
"""
other_types = self._get_other_type(other)
return VisionsTypeset(self.types | other_types)
def __iadd__(self, other: TypeOrTypeset) -> "VisionsTypeset":
"""Adds a type or typeset into the current typeset.
Args:
other: Type or typeset to be added
Returns
A VisionsTypeset
"""
return self.__add__(other)
def __sub__(self, other: TypeOrTypeset) -> "VisionsTypeset":
"""Subtracts a type or typeset from the current typeset.
Args:
other: Type or typeset to be removed
Returns
A VisionsTypeset
"""
other_types = self._get_other_type(other)
return VisionsTypeset(self.types - other_types)
def __isub__(self, other: TypeOrTypeset) -> "VisionsTypeset":
"""Subtracts a type or typeset from the current typeset.
Args:
other: Type or typeset to be removed
Returns
A VisionsTypeset
"""
return self.__sub__(other)
def __repr__(self) -> str:
"""Pretty representation of the typeset.
Returns
A :class:`visions.typesets.typeset.VisionsTypeset`
"""
return self.__class__.__name__