# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
"""Utilities for working with transforms."""
import copy
import functools
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    TypeVar,
    no_type_check,
)
from unittest.mock import patch
import torch
import torch._dynamo
from torch import Tensor, nn
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from .. import functional as U
from .._internal_utils import generate__all__
T = TypeVar("T")
Backend = Callable[[GraphModule, List[Tensor]], Callable[..., Any]]
_unit_scaled_functions = [getattr(U, f) for f in U.__all__]
def torch_nn_modules_to_user_modules(mod: nn.Module) -> None:
    """
    Convert the :mod:`torch.nn` submodules of `mod` into `trivial_subclass` versions.

    By default TorchDynamo doesn't recurse into :mod:`torch.nn` modules or
    :mod:`torch.nn.functional` functions when capturing the FX graph.
    This function turns the :mod:`torch.nn` submodules of `mod` into user
    modules, in place, so that TorchDynamo will trace into them.

    To use this with a :class:`torch.nn.Module`, the typical use case is to
    call `torch_nn_modules_to_user_modules(module)` (the module is modified
    in place; nothing is returned).
    """
    for n, submod in mod.named_children():
        torch_nn_modules_to_user_modules(submod)
        # Mirroring the check at https://github.com/pytorch/pytorch/blob/34bce27f0d12bf7226b37dfe365660aad456701a/torch/_dynamo/variables/nn_module.py#L307 # noqa: E501
        if submod.__module__.startswith(("torch.nn.", "torch.ao.")):
            # Generate a new name, so e.g. torch.nn.modules.sparse.Embedding
            # becomes trivial_subclass_modules_sparse_Embedding
            modulename = submod.__module__
            modulename = modulename.replace("torch.nn.", "", 1)
            modulename = modulename.replace(".", "_")
            newtypename = "trivial_subclass_" + modulename + "_" + type(submod).__name__
            # Create a new type object deriving from type(submod)
            newmodtype = type(newtypename, (type(submod),), {})
            # Initialize and copy state using pickle
            newsubmod = newmodtype.__new__(newmodtype)  # type: ignore [call-overload]
            state = submod.__getstate__()  # type: ignore [no-untyped-call]
            newsubmod.__setstate__(state)
            # Update module in mod
            setattr(mod, n, newsubmod) 
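
# A minimal usage sketch (hypothetical, not part of the library API): convert a
# small model in place and check that its torch.nn children are now
# `trivial_subclass_*` instances, which TorchDynamo treats as user code.
def _example_torch_nn_modules_to_user_modules() -> None:
    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    torch_nn_modules_to_user_modules(model)
    # nn.Linear is defined in torch.nn.modules.linear, so its replacement type
    # is named "trivial_subclass_modules_linear_Linear".
    assert type(model[0]).__name__.startswith("trivial_subclass_")
    assert isinstance(model[0], nn.Linear)  # still a subclass of the original
    model(torch.randn(2, 4))  # behaviour is unchanged
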
def patch_to_expand_modules(fn: Callable[..., T]) -> Callable[..., T]:
    """By default TorchDynamo doesn't recurse into :mod:`torch.nn` modules or
    :mod:`torch.nn.functional` functions when capturing the FX graph.
    Any function which is wrapped in
    :func:`torch._dynamo.optimize` (or :func:`torch.compile`) and is then passed to
    this function as `fn` will then recurse into
    :mod:`torch.nn` modules and :mod:`torch.nn.functional` functions.

    In practice, to use this with a :class:`torch.nn.Module` the typical use case
    is to call `module = torch._dynamo.optimize(backend)(module)`, followed by
    `module.forward = patch_to_expand_modules(module.forward)`.

    This should be used in conjunction with :func:`torch_nn_modules_to_user_modules`.

    Args:
        fn (Callable[..., T]): the function to be patched.
    Returns:
        Callable[..., T]: the new version of `fn` with patching applied.
    """
    def _patched_call_function(  # type: ignore[no-untyped-def]
        self,
        tx,
        args,
        kwargs,
    ):  # pragma: no cover
        # Removing the check in https://github.com/pytorch/pytorch/blob/72662bf05b3499ce96aae9183a489c78f0c44c84/torch/_dynamo/variables/functions.py#L335 # noqa: E501
        return super(
            torch._dynamo.variables.functions.UserMethodVariable, self
        ).call_function(tx, args, kwargs)
    @functools.wraps(fn)
    def new_fn(*args: Any, **kwargs: Any) -> Any:
        with patch(
            "torch._dynamo.variables.functions.UserMethodVariable.call_function",
            new=_patched_call_function,
        ):
            return fn(*args, **kwargs)
    return new_fn 
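
# A sketch of the usage described in the docstring above (the no-op backend and
# the model are placeholders): convert the torch.nn children to user modules,
# wrap the module with dynamo, then patch its forward so tracing recurses into
# torch.nn / torch.nn.functional code.
def _example_patch_to_expand_modules() -> None:
    def noop_backend(
        gm: GraphModule, example_inputs: List[Tensor]
    ) -> Callable[..., Any]:
        return gm  # a GraphModule is itself callable

    module = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    torch_nn_modules_to_user_modules(module)
    module = torch._dynamo.optimize(noop_backend)(module)
    module.forward = patch_to_expand_modules(module.forward)
    module(torch.randn(2, 4))
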
def replace_node_with_function(
    graph: Graph,
    source: Node,
    target_fn: Callable[..., Any],
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[Any, Any]] = None,
    keep_type_expr: bool = True,
) -> None:
    """Given a source node and its accompanying graph, remove the node and replace it
    with a new node that represents calling the target function.
    Args:
        graph (Graph): the graph in which the node is present.
        source (Node): the node to be replaced.
        target_fn (Callable[..., Any]): the function to be contained in the new node.
        args (Optional[Tuple[Any, ...]], optional): args of the new node.
            Defaults to None.
        kwargs (Optional[Dict[Any, Any]], optional): kwargs of the new node.
            Defaults to None.
        keep_type_expr (bool, optional): retain the type expression of the removed node.
            Defaults to True.
    """
    if args is None:
        args = source.args
    if kwargs is None:
        kwargs = source.kwargs
    type_expr = getattr(source, "type", None) if keep_type_expr else None
    with graph.inserting_after(source):
        new_node = graph.call_function(target_fn, args, kwargs, type_expr)
        source.replace_all_uses_with(new_node)
        graph.erase_node(source) 
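
# Hypothetical illustration of `replace_node_with_function`: trace a module
# with torch.fx and re-point every call to F.relu at F.gelu. The same pattern
# can be used to swap a node over to a unit-scaled op from `U`. Passing an
# explicit `kwargs={}` drops relu's `inplace` keyword, which gelu doesn't take.
def _example_replace_node_with_function() -> None:
    import torch.nn.functional as F
    from torch.fx import symbolic_trace

    class _Model(nn.Module):
        def forward(self, x: Tensor) -> Tensor:
            return F.relu(x) + 1

    gm = symbolic_trace(_Model())
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is F.relu:
            replace_node_with_function(
                gm.graph, node, F.gelu, args=node.args, kwargs={}
            )
    gm.recompile()
    gm(torch.randn(3))
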
def _compose_backends(backends: Iterable[Backend]) -> Backend:
    def composite_backend(
        gm: GraphModule, example_inputs: List[Tensor]
    ) -> Callable[..., Any]:
        for b in backends:
            new_gm = b(gm, example_inputs)
            new_gm._param_name_to_source = getattr(  # type: ignore
                gm,
                "_param_name_to_source",
                None,
            )
            gm = new_gm  # type: ignore[assignment]
        return gm
    return composite_backend
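
# A hedged sketch of composing backends (the two backends here are
# placeholders): each backend receives the GraphModule produced by the
# previous one, and the composite is handed to dynamo as a single backend.
def _example_compose_backends() -> None:
    def show_graph(
        gm: GraphModule, example_inputs: List[Tensor]
    ) -> Callable[..., Any]:
        print(gm.graph)  # inspect the captured FX graph
        return gm

    def identity(
        gm: GraphModule, example_inputs: List[Tensor]
    ) -> Callable[..., Any]:
        return gm

    backend = _compose_backends([show_graph, identity])
    compiled = torch._dynamo.optimize(backend)(lambda x: x * 2 + 1)
    compiled(torch.randn(4))
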
M = TypeVar("M", bound=nn.Module)
__all__ = generate__all__(__name__)