3.2.2. unit_scaling.analysis.graph_to_dataframe

unit_scaling.analysis.graph_to_dataframe(g: Graph) → DataFrame

Converts a torch.fx.Graph annotated with unit_scaling.transforms.Metrics into a pandas.DataFrame.

This graph is intended to have been generated by applying unit_scaling.transforms.track_scales() to an arbitrary torch.nn.Module, running a forward (and possibly backward) pass, and then calling the module.scales_graph() method.

The resulting dataframe contains all of the module's metrics information and is used internally by the unit_scaling.analysis.plot() function.
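The kind of conversion this function performs can be pictured as flattening per-node annotations of a graph into table rows. The sketch below is a hypothetical illustration of that idea, not the library's implementation. The function name annotated_graph_to_dataframe, the "metrics" meta key, and the fwd_std/bwd_std values are all assumptions made for the example, and the real dataframe's columns will likely differ. It uses only torch.fx and pandas.

```python
import pandas as pd
import torch
import torch.fx


def annotated_graph_to_dataframe(g: torch.fx.Graph, key: str = "metrics") -> pd.DataFrame:
    """Hypothetical sketch: emit one row per node whose meta dict holds `key`."""
    rows = []
    for node in g.nodes:
        metrics = node.meta.get(key)
        if metrics is None:
            continue  # skip unannotated nodes (e.g. placeholder/output)
        # Put node identity and the flattened per-node metrics side by side.
        rows.append({"node": node.name, "op": node.op, **metrics})
    return pd.DataFrame(rows)


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


gm = torch.fx.symbolic_trace(TinyModel())

# Attach made-up values in place of real tracked metrics.
for node in gm.graph.nodes:
    if node.op in ("call_module", "call_function"):
        node.meta["metrics"] = {"fwd_std": 1.0, "bwd_std": 0.5}

df = annotated_graph_to_dataframe(gm.graph)
print(df)
```

Keeping one row per annotated node makes the result easy to filter, group, or plot with ordinary pandas operations.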

Parameters:

g (Graph) – the input graph.

Returns:

the metrics dataframe.

Return type:

pd.DataFrame