Getting Started
Python API documentation
Format
DelayedScaling
MXFP8BlockScaling
Float8CurrentScaling
Float8BlockScaling
Linear
Linear.forward()
Linear.set_tensor_parallel_group()
GroupedLinear
GroupedLinear.forward()
GroupedLinear.set_tensor_parallel_group()
LayerNorm
RMSNorm
LayerNormLinear
LayerNormLinear.forward()
LayerNormLinear.set_tensor_parallel_group()
LayerNormMLP
LayerNormMLP.forward()
LayerNormMLP.set_tensor_parallel_group()
DotProductAttention
DotProductAttention.forward()
DotProductAttention.set_context_parallel_group()
MultiheadAttention
MultiheadAttention.forward()
MultiheadAttention.set_context_parallel_group()
MultiheadAttention.set_tensor_parallel_group()
TransformerLayer
TransformerLayer.forward()
TransformerLayer.set_context_parallel_group()
TransformerLayer.set_tensor_parallel_group()
CudaRNGStatesTracker
CudaRNGStatesTracker.add()
CudaRNGStatesTracker.fork()
CudaRNGStatesTracker.get_states()
CudaRNGStatesTracker.reset()
CudaRNGStatesTracker.set_states()
fp8_autocast()
fp8_model_init()
checkpoint()
make_graphed_callables()
get_cpu_offload_context()
moe_permute()
moe_permute_with_probs()
moe_unpermute()
moe_sort_chunks_by_index()
moe_sort_chunks_by_index_with_probs()
initialize_ub()
destroy_ub()
TransformerLayerType
MeshResource
update_collections()
DenseGeneral
LayerNormDenseGeneral
RelativePositionBiases
MultiHeadAttention
extend_logical_axis_rules()
Examples and Tutorials
LlamaModel
LlamaDecoderLayer
BF16
FP8
Advanced
NVTETensor
NVTEQuantizationConfig
NVTEDType
NVTEDType::kNVTEByte
NVTEDType::kNVTEInt16
NVTEDType::kNVTEInt32
NVTEDType::kNVTEInt64
NVTEDType::kNVTEFloat32
NVTEDType::kNVTEFloat16
NVTEDType::kNVTEBFloat16
NVTEDType::kNVTEFloat8E4M3
NVTEDType::kNVTEFloat8E5M2
NVTEDType::kNVTEFloat8E8M0
NVTEDType::kNVTENumTypes
NVTETensorParam
NVTETensorParam::kNVTERowwiseData
NVTETensorParam::kNVTEColumnwiseData
NVTETensorParam::kNVTEScale
NVTETensorParam::kNVTEAmax
NVTETensorParam::kNVTERowwiseScaleInv
NVTETensorParam::kNVTEColumnwiseScaleInv
NVTETensorParam::kNVTENumTensorParams
NVTEScalingMode
NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING
NVTEScalingMode::NVTE_MXFP8_1D_SCALING
NVTEScalingMode::NVTE_BLOCK_SCALING_1D
NVTEScalingMode::NVTE_BLOCK_SCALING_2D
NVTEScalingMode::NVTE_INVALID_SCALING
NVTEQuantizationConfigAttribute
NVTEQuantizationConfigAttribute::kNVTEQuantizationConfigForcePow2Scales
NVTEQuantizationConfigAttribute::kNVTEQuantizationConfigAmaxEpsilon
NVTEQuantizationConfigAttribute::kNVTEQuantizationConfigNoopTensor
NVTEQuantizationConfigAttribute::kNVTEQuantizationConfigNumAttributes
nvte_create_tensor()
nvte_destroy_tensor()
nvte_tensor_data()
nvte_tensor_columnwise_data()
nvte_make_shape()
nvte_tensor_shape()
nvte_tensor_columnwise_shape()
nvte_tensor_ndims()
nvte_tensor_size()
nvte_tensor_numel()
nvte_tensor_element_size()
nvte_tensor_type()
nvte_tensor_amax()
nvte_tensor_scale()
nvte_tensor_scale_inv()
nvte_tensor_scale_inv_shape()
nvte_zero_tensor()
nvte_set_tensor_param()
nvte_get_tensor_param()
nvte_tensor_scaling_mode()
nvte_tensor_pack_create()
nvte_tensor_pack_destroy()
nvte_create_quantization_config()
nvte_get_quantization_config_attribute()
nvte_set_quantization_config_attribute()
nvte_destroy_quantization_config()
nvte_is_non_tn_fp8_gemm_supported()
nvte_memset()
NVTEShape
NVTEShape::data
NVTEShape::ndim
NVTEBasicTensor
NVTEBasicTensor::data_ptr
NVTEBasicTensor::dtype
NVTEBasicTensor::shape
NVTETensorPack
NVTETensorPack::tensors
NVTETensorPack::size
NVTETensorPack::MAX_SIZE
transformer_engine
transformer_engine::DType
transformer_engine::is_fp8_dtype()
transformer_engine::QuantizationConfigWrapper
transformer_engine::TensorWrapper
NVTE_Activation_Type
NVTE_Activation_Type::GELU
NVTE_Activation_Type::GEGLU
NVTE_Activation_Type::SILU
NVTE_Activation_Type::SWIGLU
NVTE_Activation_Type::RELU
NVTE_Activation_Type::REGLU
NVTE_Activation_Type::QGELU
NVTE_Activation_Type::QGEGLU
NVTE_Activation_Type::SRELU
NVTE_Activation_Type::SREGLU
nvte_gelu()
nvte_silu()
nvte_relu()
nvte_qgelu()
nvte_srelu()
nvte_dgelu()
nvte_dsilu()
nvte_drelu()
nvte_dqgelu()
nvte_dsrelu()
nvte_geglu()
nvte_swiglu()
nvte_reglu()
nvte_qgeglu()
nvte_sreglu()
nvte_dgeglu()
nvte_dswiglu()
nvte_dreglu()
nvte_dqgeglu()
nvte_dsreglu()
nvte_transpose_with_noop()
nvte_cast_transpose_with_noop()
nvte_quantize()
nvte_quantize_noop()
nvte_quantize_v2()
nvte_quantize_dbias()
nvte_quantize_dbias_dgelu()
nvte_quantize_dbias_dsilu()
nvte_quantize_dbias_drelu()
nvte_quantize_dbias_dqgelu()
nvte_quantize_dbias_dsrelu()
nvte_dequantize()
transformer_engine::nvte_cudnn_handle_init()
NVTE_QKV_Layout
NVTE_QKV_Layout::NVTE_SB3HD
NVTE_QKV_Layout::NVTE_SBH3D
NVTE_QKV_Layout::NVTE_SBHD_SB2HD
NVTE_QKV_Layout::NVTE_SBHD_SBH2D
NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD
NVTE_QKV_Layout::NVTE_BS3HD
NVTE_QKV_Layout::NVTE_BSH3D
NVTE_QKV_Layout::NVTE_BSHD_BS2HD
NVTE_QKV_Layout::NVTE_BSHD_BSH2D
NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD
NVTE_QKV_Layout::NVTE_T3HD
NVTE_QKV_Layout::NVTE_TH3D
NVTE_QKV_Layout::NVTE_THD_T2HD
NVTE_QKV_Layout::NVTE_THD_TH2D
NVTE_QKV_Layout::NVTE_THD_THD_THD
NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD
NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD
NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD
NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD
NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD
NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD
NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD
NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD
NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD
NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD
NVTE_QKV_Layout_Group
NVTE_QKV_Layout_Group::NVTE_3HD
NVTE_QKV_Layout_Group::NVTE_H3D
NVTE_QKV_Layout_Group::NVTE_HD_2HD
NVTE_QKV_Layout_Group::NVTE_HD_H2D
NVTE_QKV_Layout_Group::NVTE_HD_HD_HD
NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD
NVTE_QKV_Format
NVTE_QKV_Format::NVTE_SBHD
NVTE_QKV_Format::NVTE_BSHD
NVTE_QKV_Format::NVTE_THD
NVTE_QKV_Format::NVTE_BSHD_2SBHD
NVTE_QKV_Format::NVTE_SBHD_2BSHD
NVTE_QKV_Format::NVTE_THD_2BSHD
NVTE_QKV_Format::NVTE_THD_2SBHD
NVTE_Bias_Type
NVTE_Bias_Type::NVTE_NO_BIAS
NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS
NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
NVTE_Bias_Type::NVTE_ALIBI
NVTE_Mask_Type
NVTE_Mask_Type::NVTE_NO_MASK
NVTE_Mask_Type::NVTE_PADDING_MASK
NVTE_Mask_Type::NVTE_CAUSAL_MASK
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK
NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK
NVTE_Fused_Attn_Backend
NVTE_Fused_Attn_Backend::NVTE_No_Backend
NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen
NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen
NVTE_Fused_Attn_Backend::NVTE_FP8
nvte_get_qkv_layout_group()
nvte_get_qkv_format()
nvte_get_q_format()
nvte_get_kv_format()
nvte_get_fused_attn_backend()
nvte_fused_attn_fwd_qkvpacked()
nvte_fused_attn_bwd_qkvpacked()
nvte_fused_attn_fwd_kvpacked()
nvte_fused_attn_bwd_kvpacked()
nvte_fused_attn_fwd()
nvte_fused_attn_bwd()
nvte_populate_rng_state_async()
nvte_get_runtime_num_segments()
nvte_extract_seed_and_offset()
nvte_copy_to_kv_cache()
nvte_cp_thd_read_half_tensor()
nvte_cp_thd_second_half_lse_correction()
nvte_cp_thd_read_second_half_lse()
nvte_cp_thd_out_correction()
nvte_cp_thd_grad_correction()
nvte_cp_thd_get_partitioned_indices()
nvte_convert_thd_to_bshd()
nvte_convert_bshd_to_thd()
nvte_prepare_flash_attn_fwd()
nvte_prepare_flash_attn_bwd()
nvte_fused_rope_forward()
nvte_fused_rope_backward()
nvte_cublas_gemm()
nvte_cublas_atomic_gemm()
nvte_multi_stream_cublas_gemm()
transformer_engine::nvte_cublas_handle_init()
transformer_engine::num_streams
nvte_multi_tensor_l2norm_cuda()
nvte_multi_tensor_unscale_l2norm_cuda()
nvte_multi_tensor_adam_cuda()
nvte_multi_tensor_adam_param_remainder_cuda()
nvte_multi_tensor_adam_fp8_cuda()
nvte_multi_tensor_adam_capturable_cuda()
nvte_multi_tensor_adam_capturable_master_cuda()
nvte_multi_tensor_sgd_cuda()
nvte_multi_tensor_scale_cuda()
nvte_multi_tensor_compute_scale_and_scale_inv_cuda()
NVTE_Norm_Type
NVTE_Norm_Type::LayerNorm
NVTE_Norm_Type::RMSNorm
nvte_layernorm_fwd()
nvte_layernorm_bwd()
nvte_rmsnorm_fwd()
nvte_rmsnorm_bwd()
nvte_enable_cudnn_norm_fwd()
nvte_enable_cudnn_norm_bwd()
nvte_enable_zero_centered_gamma_in_weight_dtype()
nvte_multi_padding()
nvte_permute()
nvte_unpermute()
nvte_device_radix_sort_pairs()
nvte_delayed_scaling_recipe_amax_and_scale_update()
nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction()
nvte_compute_amax()
nvte_compute_scale_from_amax()
nvte_fp8_block_scaling_compute_partial_amax()
nvte_fp8_block_scaling_partial_cast()
nvte_scaled_softmax_forward()
nvte_scaled_softmax_backward()
nvte_scaled_masked_softmax_forward()
nvte_scaled_masked_softmax_backward()
nvte_scaled_upper_triang_masked_softmax_forward()
nvte_scaled_upper_triang_masked_softmax_backward()
nvte_scaled_aligned_causal_masked_softmax_forward()
nvte_scaled_aligned_causal_masked_softmax_backward()
nvte_swizzle_scaling_factors()
nvte_cast_transpose()
nvte_transpose()
nvte_cast_transpose_dbias()
nvte_fp8_transpose_dbias()
nvte_multi_cast_transpose()
nvte_cast_transpose_dbias_dgelu()
nvte_cast_transpose_dbias_dsilu()
nvte_cast_transpose_dbias_drelu()
nvte_cast_transpose_dbias_dqgelu()
nvte_cast_transpose_dbias_dsrelu()
nvte_dgeglu_cast_transpose()
nvte_dswiglu_cast_transpose()
nvte_dreglu_cast_transpose()
nvte_dqgeglu_cast_transpose()
nvte_dsreglu_cast_transpose()
LogTensorStats
LogFp8TensorStats
DisableFP8GEMM
DisableFP8Layer
PerTensorScaling
FakeQuant
modify_tensor()
inspect_tensor()
inspect_tensor_postquantize()
modify_tensor_enabled()
fp8_gemm_enabled()
inspect_tensor_enabled()
inspect_tensor_postquantize_enabled()