unit-scaling
Contents
1. User guide
2. Developer guide
3. Limitations
4. Blog
5. API reference
unit-scaling
Index
Index
A | B | C | D | E | F | G | H | I | J | L | M | O | P | Q | R | S | T | U | V | W
A
add() (in module unit_scaling.functional)
amean() (in module unit_scaling.constraints)
analyse_module() (in module unit_scaling.utils)
apply_constraint() (in module unit_scaling.constraints)
apply_transform() (in module unit_scaling.transforms.utils)
B
backward() (unit_scaling.utils.ScaleTracker static method)
bias (unit_scaling.LayerNorm attribute)
bias (unit_scaling.Linear attribute)
bits (unit_scaling.formats.FPFormat property)
boxed_run() (unit_scaling.utils.ScaleTrackingInterpreter method)
C
call_function() (unit_scaling.utils.ScaleTrackingInterpreter method)
call_method() (unit_scaling.utils.ScaleTrackingInterpreter method)
call_module() (unit_scaling.utils.ScaleTrackingInterpreter method)
compile() (in module unit_scaling.transforms)
cross_entropy() (in module unit_scaling.functional)
CrossEntropyLoss (class in unit_scaling)
D
Dropout (class in unit_scaling)
dropout() (in module unit_scaling.functional)
E
Embedding (class in unit_scaling)
embedding() (in module unit_scaling.functional)
example_batch() (in module unit_scaling.analysis)
F
fetch_args_kwargs_from_env() (unit_scaling.utils.ScaleTrackingInterpreter method)
fetch_attr() (unit_scaling.utils.ScaleTrackingInterpreter method)
format_to_tuple() (in module unit_scaling.formats)
FPFormat (class in unit_scaling.formats)
from_pretrained() (unit_scaling.Embedding class method)
G
GELU (class in unit_scaling)
gelu() (in module unit_scaling.functional)
get_attr() (unit_scaling.utils.ScaleTrackingInterpreter method)
gmean() (in module unit_scaling.constraints)
graph_to_dataframe() (in module unit_scaling.analysis)
H
hmean() (in module unit_scaling.constraints)
I
is_traceable (unit_scaling.utils.ScaleTracker attribute)
J
jvp() (unit_scaling.utils.ScaleTracker static method)
L
layer_norm() (in module unit_scaling.functional)
LayerNorm (class in unit_scaling)
Linear (class in unit_scaling)
linear() (in module unit_scaling.functional)
M
map_nodes_to_values() (unit_scaling.utils.ScaleTrackingInterpreter method)
mark_dirty() (unit_scaling.utils.ScaleTracker method)
mark_non_differentiable() (unit_scaling.utils.ScaleTracker method)
matmul() (in module unit_scaling.functional)
max_absolute_value (unit_scaling.formats.FPFormat property)
Metrics (class in unit_scaling.transforms)
Metrics.Data (class in unit_scaling.transforms)
MHSA (class in unit_scaling)
min_absolute_normal (unit_scaling.formats.FPFormat property)
min_absolute_subnormal (unit_scaling.formats.FPFormat property)
MLP (class in unit_scaling)
module
unit_scaling
unit_scaling.analysis
unit_scaling.constraints
unit_scaling.formats
unit_scaling.functional
unit_scaling.scale
unit_scaling.transforms
unit_scaling.transforms.utils
unit_scaling.utils
O
output() (unit_scaling.utils.ScaleTrackingInterpreter method)
P
patch_to_expand_modules() (in module unit_scaling.transforms.utils)
placeholder() (unit_scaling.utils.ScaleTrackingInterpreter method)
plot() (in module unit_scaling.analysis)
prune_non_float_tensors() (in module unit_scaling.transforms)
prune_same_scale_tensors() (in module unit_scaling.transforms)
prune_selected_nodes() (in module unit_scaling.transforms)
Q
quantise() (unit_scaling.formats.FPFormat method)
quantise_bwd() (unit_scaling.formats.FPFormat method)
quantise_fwd() (unit_scaling.formats.FPFormat method)
R
replace_node_with_function() (in module unit_scaling.transforms.utils)
residual_add() (in module unit_scaling.functional)
residual_split() (in module unit_scaling.functional)
run() (unit_scaling.utils.ScaleTrackingInterpreter method)
run_node() (unit_scaling.utils.ScaleTrackingInterpreter method)
S
save_for_backward() (unit_scaling.utils.ScaleTracker method)
save_for_forward() (unit_scaling.utils.ScaleTracker method)
scale_bwd() (in module unit_scaling.scale)
scale_elementwise() (in module unit_scaling.functional)
scale_fwd() (in module unit_scaling.scale)
scaled_dot_product_attention() (in module unit_scaling.functional)
ScalePair (class in unit_scaling.utils)
ScaleTracker (class in unit_scaling.utils)
ScaleTrackingInterpreter (class in unit_scaling.utils)
set_materialize_grads() (unit_scaling.utils.ScaleTracker method)
setup_context() (unit_scaling.utils.ScaleTracker static method)
simulate_format() (in module unit_scaling.transforms)
simulate_fp8() (in module unit_scaling.transforms)
Softmax (class in unit_scaling)
softmax() (in module unit_scaling.functional)
T
to_grad_input_scale() (in module unit_scaling.constraints)
to_left_grad_scale() (in module unit_scaling.constraints)
to_output_scale() (in module unit_scaling.constraints)
to_right_grad_scale() (in module unit_scaling.constraints)
torch_nn_modules_to_user_modules() (in module unit_scaling.transforms.utils)
track_scales() (in module unit_scaling.transforms)
TransformerDecoder (class in unit_scaling)
TransformerLayer (class in unit_scaling)
tuple_to_format() (in module unit_scaling.formats)
U
unit_scale() (in module unit_scaling.transforms)
unit_scaling
module
unit_scaling.analysis
module
unit_scaling.constraints
module
unit_scaling.formats
module
unit_scaling.functional
module
unit_scaling.scale
module
unit_scaling.transforms
module
unit_scaling.transforms.utils
module
unit_scaling.utils
module
V
visualiser() (in module unit_scaling)
visualiser() (in module unit_scaling.analysis)
vjp() (unit_scaling.utils.ScaleTracker static method)
vmap() (unit_scaling.utils.ScaleTracker static method)
W
weight (unit_scaling.Embedding attribute)
weight (unit_scaling.LayerNorm attribute)
weight (unit_scaling.Linear attribute)