Source code for creyone_layer.utils.registry
import sys, warnings
from collections import defaultdict
from typing import Any, Callable, Optional
# Registry state written by register_layer and read by layer_entrypoint.
# module short name (last dotted component) -> names of layers registered from
# that module; NOT keyed by family, so it aggregates every family's entries.
_module_to_layers: dict[str, set[str]] = defaultdict(set)
# layer family -> {layer name -> module short name it was registered from}
_layer_to_module: dict[str, dict[str, str]] = defaultdict(dict)
# layer family -> {layer name -> registered entrypoint callable}
_layer_entrypoints: dict[str, dict[str, Callable[..., Any]]] = defaultdict(dict)
def register_layer(layer_family: str = 'any'):
def _register_layer(fn: Callable[..., Any]) -> Callable[..., Any]:
# lookup containing module
mod = sys.modules[fn.__module__]
module_name = fn.__module__.split('.')[-1]
# add model to __all__ in module
layer_name = fn.__name__
if not hasattr(mod, '__all__'): mod.__all__ = []
mod.__all__.append(layer_name)
# add entries to registry dict/sets
if layer_name in _layer_entrypoints[layer_family]:
warnings.warn(
f'Overwriting {layer_name} in registry with {fn.__module__}.{layer_name}. This is because the name being '
'registered conflicts with an existing name. Please check if this is not expected.',
stacklevel=2,
)
_layer_entrypoints[layer_family][layer_name] = fn
_layer_to_module[layer_family][layer_name] = module_name
_module_to_layers[module_name].add(layer_name)
return fn
return _register_layer
[docs]
def layer_entrypoint(layer_name: str,
                     layer_family: Optional[str] = None,
                     module_filter: Optional[str] = None) -> Callable[..., Any]:
    """Fetch the registered entrypoint for ``layer_name``.

    ``layer_family`` is searched first when given, then the ``'any'`` family.
    With ``module_filter``, only a registration made from the module whose
    short name (last dotted component) equals ``module_filter`` is accepted.

    Args:
        layer_name: Name the entrypoint was registered under.
        layer_family: Family to search before falling back to ``'any'``.
        module_filter: Optional module short name the entrypoint must come from.

    Returns:
        The registered entrypoint callable.

    Raises:
        RuntimeError: If no matching registration exists in ``module_filter``,
            or the name is not found in the searched families.
    """
    if module_filter and layer_name not in _module_to_layers.get(module_filter, set()):
        raise RuntimeError(f'Model ({layer_name}) not found in module {module_filter}.')

    # Use .get() on the defaultdict registries: indexing them would insert an
    # empty entry for every unknown family that is merely queried.
    search_order = ['any'] if layer_family in (None, 'any') else [layer_family, 'any']
    rejected_by_module = False
    for family in search_order:
        entrypoints = _layer_entrypoints.get(family, {})
        if layer_name not in entrypoints:
            continue
        # _module_to_layers aggregates all families, so the check above can
        # pass because of another family's registration. Confirm that THIS
        # family's entry comes from the requested module.
        if module_filter and _layer_to_module.get(family, {}).get(layer_name) != module_filter:
            rejected_by_module = True
            continue
        return entrypoints[layer_name]

    if rejected_by_module:
        raise RuntimeError(f'Model ({layer_name}) not found in module {module_filter}.')
    if layer_family in (None, 'any'):
        raise RuntimeError(f'Layer [{layer_name}] not found in `any`.')
    raise RuntimeError(f'Layer [{layer_name}] not found in {layer_family} and `any`.')