cudnn.h

Helper for cuDNN initialization.

namespace transformer_engine

Namespace containing C++ API of Transformer Engine.

Functions

void nvte_cudnn_handle_init()

TE/JAX cudaGraph support requires cuDNN initialization to happen outside the capture region. This function is a helper that calls cudnnCreate(), which allocates memory for the handle, so that the allocation is done before capture begins rather than inside it. The function is called in the initialize() phase of the related XLA custom calls.
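The ordering constraint can be modelled without a GPU. The following is a minimal sketch under stated assumptions, not Transformer Engine's implementation: FakeHandle, fake_create, handle_slot, handle_init, get_handle and the g_capturing flag are all hypothetical stand-ins for cudnnHandle_t, cudnnCreate(), a CUDA stream capture, and nvte_cudnn_handle_init(). It shows that creating the handle in an initialize step lets the captured execute step reuse it without allocating, while creating it lazily inside the capture region fails.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>

// Hypothetical stand-in for cudnnHandle_t; no real cuDNN calls are made here.
struct FakeHandle {
  int id;
};

// Hypothetical global state modelling "a CUDA graph capture is in progress".
static bool g_capturing = false;
static int g_allocations = 0;

// Models cudnnCreate(): it allocates, which this sketch forbids while capturing.
std::unique_ptr<FakeHandle> fake_create() {
  if (g_capturing) {
    throw std::runtime_error("allocation inside graph capture region");
  }
  ++g_allocations;
  return std::make_unique<FakeHandle>(FakeHandle{g_allocations});
}

// Process-wide handle slot (hypothetical helper).
std::unique_ptr<FakeHandle>& handle_slot() {
  static std::unique_ptr<FakeHandle> handle;
  return handle;
}

// Analogue of nvte_cudnn_handle_init(): create the handle if it does not exist yet.
void handle_init() {
  if (!handle_slot()) {
    handle_slot() = fake_create();
  }
}

// Analogue of the execute phase: reuses the handle, creating it only as a fallback.
FakeHandle* get_handle() {
  handle_init();  // no-op when the initialize step already ran
  return handle_slot().get();
}

// Runs initialize -> capture -> execute; returns the total allocation count.
int run_two_phase_demo() {
  handle_init();                 // "initialize" phase, outside any capture
  g_capturing = true;            // models cudaStreamBeginCapture
  FakeHandle* h = get_handle();  // reuses the existing handle, no allocation
  g_capturing = false;           // models cudaStreamEndCapture
  assert(h != nullptr);
  return g_allocations;
}

// Shows what the helper avoids: lazy creation inside the capture region fails.
bool lazy_creation_during_capture_fails() {
  handle_slot().reset();  // pretend the initialize step never ran
  g_capturing = true;
  bool failed = false;
  try {
    get_handle();
  } catch (const std::runtime_error&) {
    failed = true;
  }
  g_capturing = false;
  return failed;
}
```

In this model the handle is created exactly once, in the initialize step, and the captured step only reads it; this mirrors the split the description draws between the initialize() phase and the captured execution of the XLA custom calls.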