3.8.4. unit_scaling.utils.ScaleTrackingInterpreter

class unit_scaling.utils.ScaleTrackingInterpreter(module: GraphModule)

Wraps an fx.GraphModule such that, when executed, it records the standard deviation of every intermediate torch.Tensor in the forward and backward passes.

Parameters:

module (fx.GraphModule) – the module to be instrumented.
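The core idea, executing a graph node by node and recording a scale statistic for every intermediate value, can be modelled in plain Python. This is a hypothetical, simplified sketch rather than the unit_scaling implementation: the toy graph format, `run_and_track_scales`, and the use of lists of floats in place of tensors are all assumptions, and only the forward pass is modelled.

```python
import statistics
from typing import Callable, Dict, List, Tuple

# Hypothetical toy graph format (not torch.fx): (node name, function, input names).
ToyGraph = List[Tuple[str, Callable[..., List[float]], Tuple[str, ...]]]


def run_and_track_scales(graph: ToyGraph, inputs: Dict[str, List[float]]) -> Dict[str, float]:
    """Execute the nodes in order and record the std of every intermediate value."""
    env: Dict[str, List[float]] = dict(inputs)
    scales: Dict[str, float] = {}
    for name, fn, arg_names in graph:
        out = fn(*(env[a] for a in arg_names))
        env[name] = out
        # Population std of the node's output, a stand-in for a tensor statistic.
        scales[name] = statistics.pstdev(out)
    return scales


def double(xs: List[float]) -> List[float]:
    return [2.0 * x for x in xs]


def add(xs: List[float], ys: List[float]) -> List[float]:
    return [x + y for x, y in zip(xs, ys)]


graph: ToyGraph = [("a", double, ("x",)), ("b", add, ("a", "x"))]
scales = run_and_track_scales(graph, {"x": [1.0, -1.0, 1.0, -1.0]})
print(scales)  # {'a': 2.0, 'b': 3.0}
```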

boxed_run(args_list)

Run module via interpretation and return the result. This uses the “boxed” calling convention, where you pass a list of arguments, which will be cleared by the interpreter. This ensures that input tensors are promptly deallocated.

Note

Backwards-compatibility for this API is guaranteed.
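The boxed calling convention can be illustrated without torch. In this hypothetical sketch, `boxed_sum` plays the role of the callee: it takes ownership of the caller's list and empties it, so that list no longer keeps the inputs alive. Memory is only actually released if the caller holds no other references to those inputs.

```python
from typing import Any, List


def boxed_sum(args_list: List[Any]) -> Any:
    """Hypothetical boxed callee: takes ownership of args_list and clears it."""
    # Move the arguments into a local list, then empty the caller's list.
    args = list(args_list)
    args_list.clear()
    total = 0
    while args:
        # Pop each argument so its reference is dropped as soon as it is used.
        total += args.pop()
    return total


inputs = [1, 2, 3]
result = boxed_sum(inputs)
print(result, inputs)  # 6 []
```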

call_function(target: Callable[..., Any] | str, args: Tuple[Any, ...], kwargs: Dict[str, Any]) → Any

Execute a call_function node and return the result.

Parameters:
  • target (Target) – The call target for this node. See Node for details on semantics

  • args (Tuple) – Tuple of positional args for this invocation

  • kwargs (Dict) – Dict of keyword arguments for this invocation

Returns:

The value returned by the function invocation

Return type:

Any

Note

Backwards-compatibility for this API is guaranteed.
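A subclass can hook into call_function by delegating to the default behaviour and then inspecting the result. The sketch below assumes the default simply calls `target(*args, **kwargs)`; `BaseInterpreter`, `RecordingInterpreter`, and `scale` are hypothetical stand-ins rather than the torch.fx or unit_scaling classes, and lists of floats replace tensors.

```python
import statistics
from typing import Any, Callable, Dict, List, Tuple


class BaseInterpreter:
    """Hypothetical stand-in for an interpreter base class."""

    def call_function(self, target: Callable[..., Any], args: Tuple[Any, ...],
                      kwargs: Dict[str, Any]) -> Any:
        # Default behaviour: invoke the target with the concrete args and kwargs.
        return target(*args, **kwargs)


class RecordingInterpreter(BaseInterpreter):
    """Hypothetical subclass that records the std of each function result."""

    def __init__(self) -> None:
        self.stds: List[float] = []

    def call_function(self, target: Callable[..., Any], args: Tuple[Any, ...],
                      kwargs: Dict[str, Any]) -> Any:
        out = super().call_function(target, args, kwargs)
        self.stds.append(statistics.pstdev(out))
        return out


def scale(xs: List[float], factor: float = 1.0) -> List[float]:
    return [factor * x for x in xs]


interp = RecordingInterpreter()
out = interp.call_function(scale, ([1.0, -1.0],), {"factor": 4.0})
print(out, interp.stds)  # [4.0, -4.0] [4.0]
```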

call_method(target: Callable[..., Any] | str, args: Tuple[Tuple[Any, ...] | List[Any] | Dict[str, Any] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None, ...], kwargs: Dict[str, Any]) → Any

Execute a call_method node and return the result.

Parameters:
  • target (Target) – The call target for this node. See Node for details on semantics

  • args (Tuple) – Tuple of positional args for this invocation

  • kwargs (Dict) – Dict of keyword arguments for this invocation

Returns:

The value returned by the method invocation

Return type:

Any

Note

Backwards-compatibility for this API is guaranteed.
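For a call_method node, the target is a method name and the first positional argument is the object the method is called on. This pure-Python sketch of that convention uses string methods; `call_method_sketch` is a hypothetical helper, not the Interpreter method itself.

```python
from typing import Any, Dict, Tuple


def call_method_sketch(target: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
    """args[0] is the receiver; the remaining args are passed to the method."""
    self_obj, *args_tail = args
    # Look up the method by name on the receiver and call it with the rest.
    return getattr(self_obj, target)(*args_tail, **kwargs)


print(call_method_sketch("replace", ("a-b-c", "-", "+"), {}))  # a+b+c
print(call_method_sketch("split", ("a,b,c",), {"sep": ","}))   # ['a', 'b', 'c']
```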

call_module(target: Callable[..., Any] | str, args: Tuple[Tuple[Any, ...] | List[Any] | Dict[str, Any] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None, ...], kwargs: Dict[str, Any]) → Any

Execute a call_module node and return the result.

Parameters:
  • target (Target) – The call target for this node. See Node for details on semantics

  • args (Tuple) – Tuple of positional args for this invocation

  • kwargs (Dict) – Dict of keyword arguments for this invocation

Returns:

The value returned by the module invocation

Return type:

Any

Note

Backwards-compatibility for this API is guaranteed.

fetch_args_kwargs_from_env(n: Node) → Tuple[Tuple, Dict]

Fetch the concrete values of args and kwargs of node n from the current execution environment.

Parameters:

n (Node) – The node for which args and kwargs should be fetched.

Returns:

args and kwargs with concrete values for n.

Return type:

Tuple[Tuple, Dict]

Note

Backwards-compatibility for this API is guaranteed.
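Conceptually, fetching a node's arguments means replacing every reference to another node with the value already computed for it. This is a hypothetical sketch: `ToyNode` and `fetch_args_kwargs` are illustrative names, and only flat args are handled, whereas nested structures need the recursive lookup described for map_nodes_to_values.

```python
from typing import Any, Dict, Optional, Tuple


class ToyNode:
    """Hypothetical stand-in for fx.Node: a name plus args/kwargs that may reference other nodes."""

    def __init__(self, name: str, args: Tuple[Any, ...] = (),
                 kwargs: Optional[Dict[str, Any]] = None) -> None:
        self.name = name
        self.args = args
        self.kwargs = kwargs or {}


def fetch_args_kwargs(n: ToyNode, env: Dict[ToyNode, Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Replace every ToyNode in n.args / n.kwargs with its value from env (flat only)."""

    def lookup(a: Any) -> Any:
        return env[a] if isinstance(a, ToyNode) else a

    args = tuple(lookup(a) for a in n.args)
    kwargs = {k: lookup(v) for k, v in n.kwargs.items()}
    return args, kwargs


x, y = ToyNode("x"), ToyNode("y")
add_node = ToyNode("add", args=(x, 10), kwargs={"other": y})
env = {x: 1.5, y: 2.5}
print(fetch_args_kwargs(add_node, env))  # ((1.5, 10), {'other': 2.5})
```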

fetch_attr(target: str)

Fetch an attribute from the Module hierarchy of self.module.

Parameters:

target (str) – The fully-qualified name of the attribute to fetch

Returns:

The value of the attribute.

Return type:

Any

Note

Backwards-compatibility for this API is guaranteed.
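Resolving a fully-qualified name amounts to walking the dotted path one attribute at a time. In this hypothetical sketch, `fetch_attr_sketch` is an illustrative helper and `SimpleNamespace` stands in for a nested nn.Module hierarchy.

```python
from types import SimpleNamespace
from typing import Any


def fetch_attr_sketch(root: Any, target: str) -> Any:
    """Walk a dotted, fully-qualified name such as 'layer1.weight' starting at root."""
    obj = root
    atoms = target.split(".")
    for i, atom in enumerate(atoms):
        if not hasattr(obj, atom):
            missing = ".".join(atoms[: i + 1])
            raise AttributeError(f"nonexistent target {missing!r}")
        obj = getattr(obj, atom)
    return obj


# SimpleNamespace stands in for a nested module hierarchy.
model = SimpleNamespace(layer1=SimpleNamespace(weight=[0.1, 0.2], bias=[0.0]))
print(fetch_attr_sketch(model, "layer1.weight"))  # [0.1, 0.2]
```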

get_attr(target: Callable[..., Any] | str, args: Tuple[Tuple[Any, ...] | List[Any] | Dict[str, Any] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None, ...], kwargs: Dict[str, Any]) → Any

Execute a get_attr node. This retrieves the attribute referenced by target from the Module hierarchy of self.module.

Parameters:
  • target (Target) – The call target for this node. See Node for details on semantics

  • args (Tuple) – Tuple of positional args for this invocation

  • kwargs (Dict) – Dict of keyword arguments for this invocation

Returns:

The value of the attribute that was retrieved

Return type:

Any

Note

Backwards-compatibility for this API is guaranteed.

map_nodes_to_values(args: Tuple[Any, ...] | List[Any] | Dict[str, Any] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None, n: Node) → Tuple[Any, ...] | List[Any] | Dict[str, Any] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None

Recursively descend through args and look up the concrete value for each Node in the current execution environment.

Parameters:
  • args (Argument) – Data structure within which to look up concrete values

  • n (Node) – Node to which args belongs. This is only used for error reporting.

Note

Backwards-compatibility for this API is guaranteed.
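The recursive descent can be sketched in plain Python. `Ref` is a hypothetical stand-in for Node, and `map_refs_to_values` covers only the common containers (tuple, list, dict, slice); as with the n parameter above, `owner` is used only for error reporting.

```python
from typing import Any, Dict


class Ref:
    """Hypothetical stand-in for fx.Node: a named reference into the environment."""

    def __init__(self, name: str) -> None:
        self.name = name


def map_refs_to_values(arg: Any, env: Dict[str, Any], owner: str = "<node>") -> Any:
    """Recursively replace every Ref in nested tuples, lists, dicts and slices."""
    if isinstance(arg, Ref):
        if arg.name not in env:
            raise RuntimeError(f"{owner}: no value for {arg.name!r} in env")
        return env[arg.name]
    if isinstance(arg, tuple):
        return tuple(map_refs_to_values(a, env, owner) for a in arg)
    if isinstance(arg, list):
        return [map_refs_to_values(a, env, owner) for a in arg]
    if isinstance(arg, dict):
        return {k: map_refs_to_values(v, env, owner) for k, v in arg.items()}
    if isinstance(arg, slice):
        parts = (arg.start, arg.stop, arg.step)
        return slice(*(map_refs_to_values(p, env, owner) for p in parts))
    # Leaves such as ints, strings and None are returned unchanged.
    return arg


env = {"x": 3, "y": 7}
nested = (Ref("x"), [1, {"k": Ref("y")}], slice(Ref("x"), None))
print(map_refs_to_values(nested, env))  # (3, [1, {'k': 7}], slice(3, None, None))
```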

output(target: Callable[..., Any] | str, args: Tuple[Tuple[Any, ...] | List[Any] | Dict[str, Any] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None, ...], kwargs: Dict[str, Any]) → Any

Execute an output node. This simply retrieves the value referenced by the output node and returns it.

Parameters:
  • target (Target) – The call target for this node. See Node for details on semantics

  • args (Tuple) – Tuple of positional args for this invocation

  • kwargs (Dict) – Dict of keyword arguments for this invocation

Returns:

The return value referenced by the output node

Return type:

Any

Note

Backwards-compatibility for this API is guaranteed.

placeholder(target: Callable[..., Any] | str, args: Tuple[Tuple[Any, ...] | List[Any] | Dict[str, Any] | slice | range | Node | str | int | float | bool | complex | dtype | Tensor | device | memory_format | layout | OpOverload | SymInt | SymBool | SymFloat | None, ...], kwargs: Dict[str, Any]) → Any

To handle functions passed as arguments (for example, constraints), the tracer represents them as placeholder nodes. This method extracts the original function from the node, as stored in the target_to_function dict.
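The lookup described above can be modelled as follows. The key type of the dict (here, the placeholder's target string), the fallback that feeds ordinary placeholders from the positional inputs, and every name in the sketch (`ToyPlaceholderResolver`, `clamp_positive`) are assumptions for illustration.

```python
from typing import Any, Callable, Dict, Iterator


class ToyPlaceholderResolver:
    """Hypothetical model of placeholder handling with a target_to_function dict."""

    def __init__(self, target_to_function: Dict[str, Callable[..., Any]],
                 inputs: Iterator[Any]) -> None:
        self.target_to_function = target_to_function
        self.inputs = inputs

    def placeholder(self, target: str) -> Any:
        # Placeholders standing for functions yield the original function object.
        if target in self.target_to_function:
            return self.target_to_function[target]
        # Assumed fallback: ordinary placeholders consume the next positional input.
        return next(self.inputs)


def clamp_positive(v: float) -> float:
    return max(v, 0.0)


resolver = ToyPlaceholderResolver({"constraint": clamp_positive}, iter([-2.0]))
x = resolver.placeholder("x")
constraint = resolver.placeholder("constraint")
print(x, constraint(x))  # -2.0 0.0
```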

run(*args, initial_env: Dict[Node, Any] | None = None, enable_io_processing: bool = True) → Any

Run module via interpretation and return the result.

Parameters:
  • *args – The arguments to the Module to run, in positional order

  • initial_env (Optional[Dict[Node, Any]]) – An optional starting environment for execution. This is a dict mapping Node to any value. This can be used, for example, to pre-populate results for certain Nodes so as to do only partial evaluation within the interpreter.

  • enable_io_processing (bool) – If True, the inputs and outputs are first processed with the graph’s process_inputs and process_outputs functions before they are used.

Returns:

The value returned from executing the Module

Return type:

Any

Note

Backwards-compatibility for this API is guaranteed.
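Partial evaluation with initial_env works by skipping any node whose value is already present in the environment. This is a hypothetical model using a toy graph of named functions (not torch.fx); `run_toy` and the graph format are assumptions, and the last node is treated as the output.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical toy graph (not torch.fx): (node name, function, input names).
ToyGraph = List[Tuple[str, Callable[..., float], Tuple[str, ...]]]


def run_toy(graph: ToyGraph, inputs: Dict[str, float],
            initial_env: Optional[Dict[str, float]] = None) -> Tuple[float, List[str]]:
    """Run the graph, skipping any node whose value is already in the environment."""
    env: Dict[str, float] = {**inputs, **(initial_env or {})}
    executed: List[str] = []
    for name, fn, arg_names in graph:
        if name in env:
            # Pre-populated value: short-circuit instead of recomputing.
            continue
        env[name] = fn(*(env[a] for a in arg_names))
        executed.append(name)
    # Treat the last node as the graph output.
    return env[graph[-1][0]], executed


graph: ToyGraph = [
    ("sq", lambda x: x * x, ("x",)),
    ("inc", lambda s: s + 1.0, ("sq",)),
]
print(run_toy(graph, {"x": 3.0}))                             # (10.0, ['sq', 'inc'])
print(run_toy(graph, {"x": 3.0}, initial_env={"sq": 100.0}))  # (101.0, ['inc'])
```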

run_node(n: Node) → Any

Run a specific node n and return the result. Calls into placeholder, get_attr, call_function, call_method, call_module, or output depending on n.op.

Parameters:

n (Node) – The Node to execute

Returns:

The result of executing n

Return type:

Any

Note

Backwards-compatibility for this API is guaranteed.
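The dispatch on the node's op string can be modelled with getattr. In this hypothetical sketch (`ToyNode`, `ToyInterpreter`), arguments are already concrete, so the environment lookup that precedes dispatch is omitted, and only three op handlers are implemented.

```python
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Tuple


@dataclass
class ToyNode:
    """Hypothetical stand-in for fx.Node with only the fields needed here."""
    op: str
    target: Any
    args: Tuple[Any, ...] = ()


class ToyInterpreter:
    """Hypothetical model of run_node's dispatch on the node's op string."""

    def __init__(self, inputs: Iterable[Any]) -> None:
        self.inputs = iter(inputs)

    def run_node(self, n: ToyNode) -> Any:
        # For example, op == "call_function" dispatches to self.call_function.
        return getattr(self, n.op)(n.target, n.args, {})

    def placeholder(self, target: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        return next(self.inputs)

    def call_function(self, target: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        return target(*args, **kwargs)

    def output(self, target: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        return args[0]


interp = ToyInterpreter([5])
x = interp.run_node(ToyNode("placeholder", "x"))
y = interp.run_node(ToyNode("call_function", abs, (-x,)))
result = interp.run_node(ToyNode("output", "output", (y,)))
print(result)  # 5
```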