# Source code for creyone_layer.wrap

import inspect
from typing import Optional, Union

from torch import nn


def _consume_args(args: tuple, kwargs: dict) -> tuple[list, dict]:
    """Normalize positional args to (c1, c2, k), pulling missing values from kwargs.

    Accepts args in (c1, c2, k) order. Any leading args that are omitted are
    extracted from kwargs using the long-form keys ('in_channels', 'out_channels',
    'kernel_size') or their short aliases ('c1', 'c2', 'k').
    """
    args = list(args)
    if len(args) == 3: return args, kwargs
    k = kwargs.pop('kernel_size', kwargs.pop('k', None))    
    if len(args) == 2: return args + [k], kwargs
    c2 = kwargs.pop('out_channels', kwargs.pop('c2', None))
    if len(args) == 1: return args + [c2, k], kwargs
    c1 = kwargs.pop('in_channels', kwargs.pop('c1', None))
    return [c1, c2, k], kwargs


def _consume_pool_args(args: tuple, kwargs: dict) -> tuple:
    """Extract kernel size from positional args or kwargs for pool layers."""
    if args:
        return args[0], kwargs
    k = kwargs.pop('kernel_size', kwargs.pop('k', None))
    return k, kwargs


def _compute_same_padding(k: int, d: int = 1) -> Union[int, list[int]]:  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if isinstance(k, int): k = [k] 
    if d > 1: k = [d * (x - 1) + 1 for x in k]  # actual kernel-size
    p = [x // 2 for x in k]  # auto-pad
    return p[0] if len(p) == 1 else p


def _wrapfn(**kwargs):
    """Extract and return common convolution/pooling parameters from kwargs.

    Pops 'stride'/'s', 'padding'/'p', 'dilation'/'d', and 'groups'/'g' from
    kwargs, returning them as (s, p, d, g). The remaining kwargs are left for
    the caller to pass through to the underlying layer.
    """
    s = kwargs.pop('stride', kwargs.pop('s', 1))
    p = kwargs.pop('padding', kwargs.pop('p', 0))
    d = kwargs.pop('dilation', kwargs.pop('d', 1))
    g = kwargs.pop('groups', kwargs.pop('g', 1))
    return s, p, d, g


def wrap_conv(cls: type[nn.Module], opt: Union[set, str, None] = None):
    """Wrap a ConvNd class with flexible argument parsing and optional behaviors.

    Args:
        cls: A ConvNd-compatible class to wrap (e.g. ``nn.Conv2d``). It is called as
            ``cls(c1, c2, k, stride=..., padding=..., dilation=..., groups=..., **rest)``.
        opt: Option flags, either a '+'-separated string (e.g. ``'ap+dw'``) or an
            iterable of flag strings:

            - 'grid': use the kernel size as the stride (grid-like sampling).
            - 'ap': auto-pad so the output spatial size matches the input
              (exact at stride 1 for odd effective kernel sizes).
            - 'dw': set groups = in_channels (depthwise convolution).

            Unrecognized flags are ignored.

    Returns:
        A factory function. It accepts (c1, c2, k) positionally or as keywords
        (``in_channels``/``c1``, ``out_channels``/``c2``, ``kernel_size``/``k``).
        It also accepts optional ``stride``/``s`` (default 1), ``padding``/``p``
        (default 0), ``dilation``/``d`` (default 1) and ``groups``/``g``
        (default 1). Any other keyword arguments are forwarded to ``cls``
        unchanged.
    """
    if opt is None:
        opt = set()
    elif isinstance(opt, str):
        opt = set(opt.split('+'))
    else:
        # Copy so later mutation of the caller's collection cannot alter the factory.
        opt = set(opt)

    consumed_keys = ('stride', 's', 'padding', 'p', 'dilation', 'd', 'groups', 'g')

    def _fn(*args, **kwargs):
        args, kwargs = _consume_args(args, kwargs)
        s, p, d, g = _wrapfn(**kwargs)
        # _wrapfn(**kwargs) pops from its own copy, so these keys are still here.
        # If they were forwarded, 'stride=' would reach cls twice and aliases such
        # as 's=' would be rejected as unknown keywords.
        for key in consumed_keys:
            kwargs.pop(key, None)
        c1, _, k = args
        if 'grid' in opt:
            s = k
        if 'ap' in opt:
            p = _compute_same_padding(k, d)
        if 'dw' in opt:
            g = c1
        return cls(*args, stride=s, padding=p, dilation=d, groups=g, **kwargs)

    return _fn
def wrap_pool(cls: type[nn.Module], opt: Union[set, str, None] = None):
    """Wrap a PoolNd class with flexible argument parsing and optional behaviors.

    Args:
        cls: A PoolNd-compatible class to wrap. ``dilation`` is forwarded only
            when ``cls``'s signature has a ``dilation`` parameter (MaxPoolNd does,
            AvgPoolNd does not). If the signature cannot be introspected,
            dilation is not forwarded.
        opt: Option flags, either a '+'-separated string (e.g. ``'grid+ap'``) or an
            iterable of flag strings:

            - 'grid': use the kernel size as the stride (grid-like sampling).
            - 'ap': auto-pad so the output spatial size matches the input
              (exact at stride 1 for odd effective kernel sizes).
            - 'ar': intended to wrap the output in AutoReshape for (B, HW, C)
              tensors. NOTE(review): this function does not act on 'ar'; the
              flag is currently accepted and ignored. Confirm whether that
              wrapping is done elsewhere.

            Unrecognized flags are ignored.

    Returns:
        A factory function. It accepts the kernel size positionally or as
        ``kernel_size``/``k``. It also accepts optional ``stride``/``s``
        (default 1), ``padding``/``p`` (default 0) and, when supported,
        ``dilation``/``d`` (default 1). It forwards everything else to ``cls``
        unchanged. Keys ``cls`` does not accept (e.g. ``groups``, or
        ``dilation`` for AvgPool) are passed through, so ``cls`` reports them.
    """
    if opt is None:
        opt = set()
    elif isinstance(opt, str):
        opt = set(opt.split('+'))
    else:
        # Copy so later mutation of the caller's collection cannot alter the factory.
        opt = set(opt)

    try:
        has_dilation = 'dilation' in inspect.signature(cls).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: do not forward dilation ourselves.
        has_dilation = False

    # Only strip keys this wrapper actually forwards; anything else stays in
    # kwargs so an unsupported argument is rejected by cls, not silently dropped.
    consumed_keys = ('stride', 's', 'padding', 'p')
    if has_dilation:
        consumed_keys += ('dilation', 'd')

    def _fn(*args, **kwargs):
        k, kwargs = _consume_pool_args(args, kwargs)
        s, p, d, _ = _wrapfn(**kwargs)
        # _wrapfn(**kwargs) pops from its own copy; remove the consumed keys here
        # so they are not passed to cls a second time.
        for key in consumed_keys:
            kwargs.pop(key, None)
        if 'grid' in opt:
            s = k
        if 'ap' in opt:
            p = _compute_same_padding(k, d if has_dilation else 1)
        # TODO(review): 'ar' (AutoReshape) is documented but not applied here.
        extra = {'dilation': d} if has_dilation else {}
        return cls(k, stride=s, padding=p, **extra, **kwargs)

    return _fn