Transformer Engine
2.4.0

Getting Started

  • Installation
    • Prerequisites
    • Transformer Engine in NGC Containers
    • pip - from PyPI
    • pip - from GitHub
      • Additional Prerequisites
      • Installation (stable release)
      • Installation (development build)
      • Installation (from source)
      • Troubleshooting
  • Getting Started
    • Overview
    • Let’s build a Transformer layer!
    • Meet Transformer Engine
    • Fused TE Modules
    • Enabling FP8
  • Frequently Asked Questions (FAQ)
    • FP8 checkpoint compatibility

Python API documentation

  • Common API
    • Format
    • DelayedScaling
    • MXFP8BlockScaling
    • Float8CurrentScaling
    • Float8BlockScaling
  • Framework-specific API
    • PyTorch
      • Linear
        • Linear.forward()
        • Linear.set_tensor_parallel_group()
      • GroupedLinear
        • GroupedLinear.forward()
        • GroupedLinear.set_tensor_parallel_group()
      • LayerNorm
      • RMSNorm
      • LayerNormLinear
        • LayerNormLinear.forward()
        • LayerNormLinear.set_tensor_parallel_group()
      • LayerNormMLP
        • LayerNormMLP.forward()
        • LayerNormMLP.set_tensor_parallel_group()
      • DotProductAttention
        • DotProductAttention.forward()
        • DotProductAttention.set_context_parallel_group()
      • MultiheadAttention
        • MultiheadAttention.forward()
        • MultiheadAttention.set_context_parallel_group()
        • MultiheadAttention.set_tensor_parallel_group()
      • TransformerLayer
        • TransformerLayer.forward()
        • TransformerLayer.set_context_parallel_group()
        • TransformerLayer.set_tensor_parallel_group()
      • CudaRNGStatesTracker
        • CudaRNGStatesTracker.add()
        • CudaRNGStatesTracker.fork()
        • CudaRNGStatesTracker.get_states()
        • CudaRNGStatesTracker.reset()
        • CudaRNGStatesTracker.set_states()
      • fp8_autocast()
      • fp8_model_init()
      • checkpoint()
      • make_graphed_callables()
      • get_cpu_offload_context()
      • moe_permute()
      • moe_permute_with_probs()
      • moe_unpermute()
      • moe_sort_chunks_by_index()
      • moe_sort_chunks_by_index_with_probs()
      • initialize_ub()
      • destroy_ub()
    • JAX
      • Pre-defined Variable of Logical Axes
      • Modules
        • TransformerLayerType
        • MeshResource
        • fp8_autocast()
        • update_collections()
        • LayerNorm
        • DenseGeneral
        • LayerNormDenseGeneral
        • LayerNormMLP
        • RelativePositionBiases
        • DotProductAttention
        • MultiHeadAttention
        • TransformerLayer
        • extend_logical_axis_rules()

Examples and Tutorials

  • Using FP8 with Transformer Engine
    • Introduction to FP8
      • Structure
      • Mixed precision training - a quick introduction
      • Mixed precision training with FP8
    • MXFP8 and block scaling
      • MXFP8 vs FP8
      • Handling transposes
    • Using FP8 with Transformer Engine
      • FP8 recipe
      • FP8 autocasting
      • Handling backward pass
      • Precision
  • Performance Optimizations
    • Multi-GPU training
    • Gradient accumulation fusion
    • FP8 weight caching
  • Accelerating Hugging Face Llama 2 and Llama 3 models with Transformer Engine
    • Dependencies for this tutorial
    • Table of contents
    • From “Transformer” to “Llama”
    • Hugging Face’s LlamaModel
      • Hugging Face’s LlamaDecoderLayer
        • Self_Attn Layer
        • MLP Layer
    • [Baseline] Running HF LlamaModel (Precision: BF16)
    • [Improvement 1] Replace HF’s LlamaDecoderLayer with TE’s TransformerLayer (Precision: BF16)
      • Transformer Engine’s TransformerLayer
      • TransformerLayer options explained
      • Mapping weights from HF’s LlamaDecoderLayer to TE’s TransformerLayer
    • [Improvement 2] Replace HF’s LlamaDecoderLayer with TE’s TransformerLayer (Precision: FP8)
      • How to run the model in FP8 precision
      • Llama 3 performance results
    • Conclusion

Advanced

  • C/C++ API
    • transformer_engine.h
      • NVTETensor
      • NVTEQuantizationConfig
      • NVTEDType
        • NVTEDType::kNVTEByte
        • NVTEDType::kNVTEInt16
        • NVTEDType::kNVTEInt32
        • NVTEDType::kNVTEInt64
        • NVTEDType::kNVTEFloat32
        • NVTEDType::kNVTEFloat16
        • NVTEDType::kNVTEBFloat16
        • NVTEDType::kNVTEFloat8E4M3
        • NVTEDType::kNVTEFloat8E5M2
        • NVTEDType::kNVTEFloat8E8M0
        • NVTEDType::kNVTENumTypes
      • NVTETensorParam
        • NVTETensorParam::kNVTERowwiseData
        • NVTETensorParam::kNVTEColumnwiseData
        • NVTETensorParam::kNVTEScale
        • NVTETensorParam::kNVTEAmax
        • NVTETensorParam::kNVTERowwiseScaleInv
        • NVTETensorParam::kNVTEColumnwiseScaleInv
        • NVTETensorParam::kNVTENumTensorParams
      • NVTEScalingMode
        • NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING
        • NVTEScalingMode::NVTE_MXFP8_1D_SCALING
        • NVTEScalingMode::NVTE_BLOCK_SCALING_1D
        • NVTEScalingMode::NVTE_BLOCK_SCALING_2D
        • NVTEScalingMode::NVTE_INVALID_SCALING
      • NVTEQuantizationConfigAttribute
        • NVTEQuantizationConfigAttribute::kNVTEQuantizationConfigForcePow2Scales
        • NVTEQuantizationConfigAttribute::kNVTEQuantizationConfigAmaxEpsilon
        • NVTEQuantizationConfigAttribute::kNVTEQuantizationConfigNoopTensor
        • NVTEQuantizationConfigAttribute::kNVTEQuantizationConfigNumAttributes
      • nvte_create_tensor()
      • nvte_destroy_tensor()
      • nvte_tensor_data()
      • nvte_tensor_columnwise_data()
      • nvte_make_shape()
      • nvte_tensor_shape()
      • nvte_tensor_columnwise_shape()
      • nvte_tensor_ndims()
      • nvte_tensor_size()
      • nvte_tensor_numel()
      • nvte_tensor_element_size()
      • nvte_tensor_type()
      • nvte_tensor_amax()
      • nvte_tensor_scale()
      • nvte_tensor_scale_inv()
      • nvte_tensor_scale_inv_shape()
      • nvte_zero_tensor()
      • nvte_set_tensor_param()
      • nvte_get_tensor_param()
      • nvte_tensor_scaling_mode()
      • nvte_tensor_pack_create()
      • nvte_tensor_pack_destroy()
      • nvte_create_quantization_config()
      • nvte_get_quantization_config_attribute()
      • nvte_set_quantization_config_attribute()
      • nvte_destroy_quantization_config()
      • nvte_is_non_tn_fp8_gemm_supported()
      • nvte_memset()
      • NVTEShape
        • NVTEShape::data
        • NVTEShape::ndim
      • NVTEBasicTensor
        • NVTEBasicTensor::data_ptr
        • NVTEBasicTensor::dtype
        • NVTEBasicTensor::shape
      • NVTETensorPack
        • NVTETensorPack::tensors
        • NVTETensorPack::size
        • NVTETensorPack::MAX_SIZE
      • transformer_engine
        • transformer_engine::DType
        • transformer_engine::is_fp8_dtype()
        • transformer_engine::QuantizationConfigWrapper
        • transformer_engine::TensorWrapper
    • activation.h
      • NVTE_Activation_Type
        • NVTE_Activation_Type::GELU
        • NVTE_Activation_Type::GEGLU
        • NVTE_Activation_Type::SILU
        • NVTE_Activation_Type::SWIGLU
        • NVTE_Activation_Type::RELU
        • NVTE_Activation_Type::REGLU
        • NVTE_Activation_Type::QGELU
        • NVTE_Activation_Type::QGEGLU
        • NVTE_Activation_Type::SRELU
        • NVTE_Activation_Type::SREGLU
      • nvte_gelu()
      • nvte_silu()
      • nvte_relu()
      • nvte_qgelu()
      • nvte_srelu()
      • nvte_dgelu()
      • nvte_dsilu()
      • nvte_drelu()
      • nvte_dqgelu()
      • nvte_dsrelu()
      • nvte_geglu()
      • nvte_swiglu()
      • nvte_reglu()
      • nvte_qgeglu()
      • nvte_sreglu()
      • nvte_dgeglu()
      • nvte_dswiglu()
      • nvte_dreglu()
      • nvte_dqgeglu()
      • nvte_dsreglu()
    • cast_transpose_noop.h
      • nvte_transpose_with_noop()
      • nvte_cast_transpose_with_noop()
    • cast.h
      • nvte_quantize()
      • nvte_quantize_noop()
      • nvte_quantize_v2()
      • nvte_quantize_dbias()
      • nvte_quantize_dbias_dgelu()
      • nvte_quantize_dbias_dsilu()
      • nvte_quantize_dbias_drelu()
      • nvte_quantize_dbias_dqgelu()
      • nvte_quantize_dbias_dsrelu()
      • nvte_dequantize()
    • cudnn.h
      • transformer_engine
        • transformer_engine::nvte_cudnn_handle_init()
    • fused_attn.h
      • NVTE_QKV_Layout
        • NVTE_QKV_Layout::NVTE_SB3HD
        • NVTE_QKV_Layout::NVTE_SBH3D
        • NVTE_QKV_Layout::NVTE_SBHD_SB2HD
        • NVTE_QKV_Layout::NVTE_SBHD_SBH2D
        • NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD
        • NVTE_QKV_Layout::NVTE_BS3HD
        • NVTE_QKV_Layout::NVTE_BSH3D
        • NVTE_QKV_Layout::NVTE_BSHD_BS2HD
        • NVTE_QKV_Layout::NVTE_BSHD_BSH2D
        • NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD
        • NVTE_QKV_Layout::NVTE_T3HD
        • NVTE_QKV_Layout::NVTE_TH3D
        • NVTE_QKV_Layout::NVTE_THD_T2HD
        • NVTE_QKV_Layout::NVTE_THD_TH2D
        • NVTE_QKV_Layout::NVTE_THD_THD_THD
        • NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD
        • NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD
        • NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD
        • NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD
        • NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD
        • NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD
        • NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD
        • NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD
        • NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD
        • NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD
      • NVTE_QKV_Layout_Group
        • NVTE_QKV_Layout_Group::NVTE_3HD
        • NVTE_QKV_Layout_Group::NVTE_H3D
        • NVTE_QKV_Layout_Group::NVTE_HD_2HD
        • NVTE_QKV_Layout_Group::NVTE_HD_H2D
        • NVTE_QKV_Layout_Group::NVTE_HD_HD_HD
        • NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD
      • NVTE_QKV_Format
        • NVTE_QKV_Format::NVTE_SBHD
        • NVTE_QKV_Format::NVTE_BSHD
        • NVTE_QKV_Format::NVTE_THD
        • NVTE_QKV_Format::NVTE_BSHD_2SBHD
        • NVTE_QKV_Format::NVTE_SBHD_2BSHD
        • NVTE_QKV_Format::NVTE_THD_2BSHD
        • NVTE_QKV_Format::NVTE_THD_2SBHD
      • NVTE_Bias_Type
        • NVTE_Bias_Type::NVTE_NO_BIAS
        • NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS
        • NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
        • NVTE_Bias_Type::NVTE_ALIBI
      • NVTE_Mask_Type
        • NVTE_Mask_Type::NVTE_NO_MASK
        • NVTE_Mask_Type::NVTE_PADDING_MASK
        • NVTE_Mask_Type::NVTE_CAUSAL_MASK
        • NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK
        • NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK
        • NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK
      • NVTE_Fused_Attn_Backend
        • NVTE_Fused_Attn_Backend::NVTE_No_Backend
        • NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen
        • NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen
        • NVTE_Fused_Attn_Backend::NVTE_FP8
      • nvte_get_qkv_layout_group()
      • nvte_get_qkv_format()
      • nvte_get_q_format()
      • nvte_get_kv_format()
      • nvte_get_fused_attn_backend()
      • nvte_fused_attn_fwd_qkvpacked()
      • nvte_fused_attn_bwd_qkvpacked()
      • nvte_fused_attn_fwd_kvpacked()
      • nvte_fused_attn_bwd_kvpacked()
      • nvte_fused_attn_fwd()
      • nvte_fused_attn_bwd()
      • nvte_populate_rng_state_async()
      • nvte_get_runtime_num_segments()
      • nvte_extract_seed_and_offset()
      • nvte_copy_to_kv_cache()
      • nvte_cp_thd_read_half_tensor()
      • nvte_cp_thd_second_half_lse_correction()
      • nvte_cp_thd_read_second_half_lse()
      • nvte_cp_thd_out_correction()
      • nvte_cp_thd_grad_correction()
      • nvte_cp_thd_get_partitioned_indices()
      • nvte_convert_thd_to_bshd()
      • nvte_convert_bshd_to_thd()
      • nvte_prepare_flash_attn_fwd()
      • nvte_prepare_flash_attn_bwd()
    • fused_rope.h
      • nvte_fused_rope_forward()
      • nvte_fused_rope_backward()
    • gemm.h
      • nvte_cublas_gemm()
      • nvte_cublas_atomic_gemm()
      • nvte_multi_stream_cublas_gemm()
      • transformer_engine
        • transformer_engine::nvte_cublas_handle_init()
        • transformer_engine::num_streams
    • multi_tensor.h
      • nvte_multi_tensor_l2norm_cuda()
      • nvte_multi_tensor_unscale_l2norm_cuda()
      • nvte_multi_tensor_adam_cuda()
      • nvte_multi_tensor_adam_param_remainder_cuda()
      • nvte_multi_tensor_adam_fp8_cuda()
      • nvte_multi_tensor_adam_capturable_cuda()
      • nvte_multi_tensor_adam_capturable_master_cuda()
      • nvte_multi_tensor_sgd_cuda()
      • nvte_multi_tensor_scale_cuda()
      • nvte_multi_tensor_compute_scale_and_scale_inv_cuda()
    • normalization.h
      • NVTE_Norm_Type
        • NVTE_Norm_Type::LayerNorm
        • NVTE_Norm_Type::RMSNorm
      • nvte_layernorm_fwd()
      • nvte_layernorm_bwd()
      • nvte_rmsnorm_fwd()
      • nvte_rmsnorm_bwd()
      • nvte_enable_cudnn_norm_fwd()
      • nvte_enable_cudnn_norm_bwd()
      • nvte_enable_zero_centered_gamma_in_weight_dtype()
    • padding.h
      • nvte_multi_padding()
    • permutation.h
      • nvte_permute()
      • nvte_unpermute()
      • nvte_device_radix_sort_pairs()
    • recipe.h
      • nvte_delayed_scaling_recipe_amax_and_scale_update()
      • nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction()
      • nvte_compute_amax()
      • nvte_compute_scale_from_amax()
      • nvte_fp8_block_scaling_compute_partial_amax()
      • nvte_fp8_block_scaling_partial_cast()
    • softmax.h
      • nvte_scaled_softmax_forward()
      • nvte_scaled_softmax_backward()
      • nvte_scaled_masked_softmax_forward()
      • nvte_scaled_masked_softmax_backward()
      • nvte_scaled_upper_triang_masked_softmax_forward()
      • nvte_scaled_upper_triang_masked_softmax_backward()
      • nvte_scaled_aligned_causal_masked_softmax_forward()
      • nvte_scaled_aligned_causal_masked_softmax_backward()
    • swizzle.h
      • nvte_swizzle_scaling_factors()
    • transpose.h
      • nvte_cast_transpose()
      • nvte_transpose()
      • nvte_cast_transpose_dbias()
      • nvte_fp8_transpose_dbias()
      • nvte_multi_cast_transpose()
      • nvte_cast_transpose_dbias_dgelu()
      • nvte_cast_transpose_dbias_dsilu()
      • nvte_cast_transpose_dbias_drelu()
      • nvte_cast_transpose_dbias_dqgelu()
      • nvte_cast_transpose_dbias_dsrelu()
      • nvte_dgeglu_cast_transpose()
      • nvte_dswiglu_cast_transpose()
      • nvte_dreglu_cast_transpose()
      • nvte_dqgeglu_cast_transpose()
      • nvte_dsreglu_cast_transpose()
  • Precision debug tools
    • Getting started
      • Example training script
      • Config file
      • Adjusting Python file
      • Inspecting the logs
      • Logging using TensorBoard
    • Config File Structure
      • General Format
      • Layer Specification
      • Names in Transformer Layers
      • Structured Configuration for GEMMs and Tensors
      • Enabling or Disabling Sections and Features
    • API
      • Setup
        • initialize()
        • set_tensor_reduction_group()
        • set_weight_tensor_tp_group_reduce()
      • Debug features
        • LogTensorStats
        • LogFp8TensorStats
        • DisableFP8GEMM
        • DisableFP8Layer
        • PerTensorScaling
        • FakeQuant
      • Calls to Nvidia-DL-Framework-Inspect
        • modify_tensor()
        • inspect_tensor()
        • inspect_tensor_postquantize()
        • modify_tensor_enabled()
        • fp8_gemm_enabled()
        • inspect_tensor_enabled()
        • inspect_tensor_postquantize_enabled()
    • Distributed training
      • Behavior of the features
      • Reduction groups
      • Microbatching
      • Logging to files and TensorBoard
  • Attention Is All You Need!
    • 1. Attention Backends
      • 1.1 Flash vs. Non-Flash
      • 1.2 flash-attention
      • 1.3 cuDNN Attention
    • 2. Backend Selection
      • 2.1 Debug Information
      • 2.2 User Control
      • 2.3 Example Tests
    • 3. Backend Support
      • 3.1 QKV Layout
      • 3.2 Attention Mask
      • 3.3 Attention Bias
      • 3.4 FP8 Attention



© Copyright 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.