API Reference¶
This section provides detailed documentation of all public modules and classes.
Configuration¶
- class transformer.config.TransformerConfig(n_layers=12, d_model=1536, n_heads=32, n_kv_heads=None, vocab_size=50000, d_ff=None, norm_design='pre_norm', norm_class='rms_norm', ffn_class='SwiGLU', attn_class='MHA', block_class=None, attn_bias=False, ffn_bias=True, lm_head_bias=False, attn_qk_norm=True, attn_dropout=0.0, tied_weights=False, seq_len=1024, pos_encoding='RoPE', rope_base=10000.0, max_seq_len=4096, **kwargs)[source]¶
Bases: `PretrainedConfig`

Configuration class for Transformer models. Inherits from `PretrainedConfig` for HuggingFace compatibility.
- Parameters:
n_layers (int) – Number of Transformer blocks (layers).
d_model (int) – Model dimension.
n_heads (int) – Number of attention heads.
n_kv_heads (int, optional) – Number of key/value heads for Grouped-Query Attention (GQA). Default: `n_heads`.
vocab_size (int) – Vocabulary size of the model. Defines the number of different tokens.
d_ff (int, optional) – Dimension of the feed-forward hidden layer.
norm_design (str) – Normalization design, one of `pre_norm`, `post_norm` or `both`. Default: `pre_norm`.
norm_class (Union[List[Union[Type[nn.Module], str]], Type[nn.Module], str]) – Normalization class or type.
- If `str`, one of `rms_norm` or `layer_norm`.
- If `Type[nn.Module]`, it will be instantiated inside the model. It should have the same API as a torch normalization layer.
- If `List[Union[Type[nn.Module], str]]` and `len(norm_class) == n_layers`, each entry will be instantiated inside the model for the corresponding layer.
ffn_class (Union[List[Union[Type[nn.Module], str]], Type[nn.Module], str]) – Feed-forward network class or type.
- If `str`, one of `SwiGLU` or `MLP`. Default: `SwiGLU`.
- If `Type[nn.Module]`, it will be instantiated inside the model. It should have the same API as `SwiGLU` and `MLP`.
- If `List[Union[Type[nn.Module], str]]` and `len(ffn_class) == n_layers`, each entry will be instantiated inside the model for the corresponding layer. Default: `SwiGLU` for every layer.
attn_class (Union[List[Union[Type[nn.Module], str]], Type[nn.Module], str]) – Attention class or type.
- If `str`, one of `MHA`, `GQA` or `CrossAttention`. For `GQA`, also specify `n_kv_heads`. Default: `MHA`.
- If `Type[nn.Module]`, it will be instantiated inside the model. It should have the same API as `transformer.attn.MHA`.
- If `List[Union[Type[nn.Module], str]]` and `len(attn_class) == n_layers`, each entry will be instantiated inside the model for the corresponding layer. Default: `MHA` for every layer.
block_class (Optional[Type[nn.Module]]) – Transformer block class for every layer. Default: `None`.
- If `Type[nn.Module]`, it will be instantiated for every layer inside the model.
- If `None`, the default `transformer.TransformerBlock` will be used.
attn_bias (bool, optional) – Whether to use bias in attention linear projections. Default: `False`.
ffn_bias (bool, optional) – Whether to use bias in feed-forward linear layers. Default: `True`.
lm_head_bias (bool, optional) – Whether to use bias in the language modeling head. Default: `False`.
attn_qk_norm (bool, optional) – Whether to apply normalization to queries and keys before the attention computation. Default: `True`.
attn_dropout (float, optional) – Dropout probability for the attention layer. Default: `0.0`.
tied_weights (bool, optional) – If `True`, tie the input embedding and output projection weights. Default: `False`.
seq_len (int) – Sequence length.
pos_encoding (Union[List[str], str]) – Positional encoding for attention.
- If `str`, one of `RoPE`, `ALiBi` or `PartialRoPE`. Default: `RoPE`.
- If `List[str]` and `len(pos_encoding) == n_layers`, each entry is used for the corresponding layer. Default: `RoPE` for every layer.
Note: it is recommended to change the default to `PartialRoPE`, which is used in SOTA models like Qwen3-Next-80B-A3B.
rope_base (float, optional) – Base for the exponential frequency calculation in RoPE. Default: `10000.0`.
max_seq_len (int) – Maximum sequence length for positional embeddings.
kwargs (dict, optional) – Additional keyword arguments passed to `PretrainedConfig`.
- model_type = 'transformer'¶
- __init__(n_layers=12, d_model=1536, n_heads=32, n_kv_heads=None, vocab_size=50000, d_ff=None, norm_design='pre_norm', norm_class='rms_norm', ffn_class='SwiGLU', attn_class='MHA', block_class=None, attn_bias=False, ffn_bias=True, lm_head_bias=False, attn_qk_norm=True, attn_dropout=0.0, tied_weights=False, seq_len=1024, pos_encoding='RoPE', rope_base=10000.0, max_seq_len=4096, **kwargs)[source]¶
- Parameters:
n_layers (int)
d_model (int)
n_heads (int)
n_kv_heads (int | None)
vocab_size (int)
d_ff (int | None)
norm_design (str)
norm_class (List[Type[Module] | str] | Type[Module] | str)
ffn_class (List[Type[Module] | str] | Type[Module] | str)
attn_class (List[Type[Module] | str] | Type[Module] | str)
block_class (Type[Module] | None)
attn_bias (bool)
ffn_bias (bool)
lm_head_bias (bool)
attn_qk_norm (bool)
attn_dropout (float | None)
tied_weights (bool)
seq_len (int)
pos_encoding (str)
rope_base (float)
max_seq_len (int)
kwargs (Dict)
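As a quick, standalone sanity check on how the defaults interact (plain Python, not library code): with `d_model=1536` and `n_heads=32` each head gets `d_model // n_heads` dimensions, and for GQA `n_kv_heads` must divide `n_heads`.

```python
# Head-dimension arithmetic implied by the TransformerConfig defaults.
d_model, n_heads, n_kv_heads = 1536, 32, 8

assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
d_head = d_model // n_heads          # per-head dimension
assert n_heads % n_kv_heads == 0, "n_kv_heads must divide n_heads"
group_size = n_heads // n_kv_heads   # query heads per key/value head

print(d_head, group_size)            # → 48 4
```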
Attention Modules¶
Multi-Head Attention (MHA)¶
- class transformer.attns.MHA(d_model, n_heads, dropout=0.0, attn_bias=False, qk_norm=True, layer_idx=0, pos_encoding='RoPE', pos_encoding_kwargs={}, max_seq_len=1024)[source]¶
Bases: `Module`

Multi-Head Attention (`MHA`) module using the optimized implementation of `torch.nn.functional.scaled_dot_product_attention()` when possible.

- Parameters:
d_model (int) – Model dimension.
n_heads (int) – Number of attention heads. Note that `d_model` will be split across `n_heads` (i.e. each head will have dimension `d_model // n_heads`).
dropout (float, optional) – Dropout probability on `attn_output_weights`. Default: `0.0` (no dropout). Note: the latest SOTA architectures do not use dropout at all, and for research purposes it is recommended to never use it.
attn_bias (bool, optional) – Whether to use bias in linear projections. Default: `False`.
qk_norm (bool, optional) – Whether to apply RMSNorm to queries and keys. Default: `True`.
layer_idx (int, optional) – Index of the layer (used for debugging/logging).
pos_encoding (str, optional) – Positional encoding to use. Default: `RoPE`.
pos_encoding_kwargs (Dict, optional) – Dictionary of additional arguments for positional encoding. Example: `{"rope_base": 10000.0, "rot_frac": 0.5}`.
max_seq_len (int) – Maximum sequence length for RoPE.
- __init__(d_model, n_heads, dropout=0.0, attn_bias=False, qk_norm=True, layer_idx=0, pos_encoding='RoPE', pos_encoding_kwargs={}, max_seq_len=1024)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
d_model (int)
n_heads (int)
dropout (float)
attn_bias (bool | None)
qk_norm (bool | None)
layer_idx (int)
pos_encoding (str)
pos_encoding_kwargs (Dict)
max_seq_len (int)
- forward(x, mask=None, pos=None, flash_attn=(False, <SDPBackend.FLASH_ATTENTION: 1>, False), return_states=False)[source]¶
Forward pass of MHA.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, N, E), where N is the sequence length, B is the batch size, and E is the embedding dimension `d_model`.
mask (torch.BoolTensor, optional) – If specified, a 2D or 4D mask preventing attention to certain positions. Must be of shape (N, N) or (B, n_heads, N, N), where B is the batch size, n_heads is the number of heads and N is the sequence length. A 2D mask will be broadcast across the batch, while a 4D mask allows a different mask for each entry in the batch and/or heads dimensions.
Note: this should be a boolean mask where True indicates masked positions. When Flash Attention is enabled it is inverted, because PyTorch expects True for allowed positions.
pos (torch.LongTensor, optional) – Position indices for RoPE, shape (N,) or (B, N).
flash_attn (Tuple[bool, Union[List[torch.nn.attention.SDPBackend], torch.nn.attention.SDPBackend], bool], optional) – Tuple of arguments for Flash Attention and the context manager that selects which backend to use for scaled dot product attention.
- bool: Whether to use Flash Attention. Default: `False`.
- Union[List[SDPBackend], SDPBackend]: A backend or list of backends for scaled dot product attention. Default: `torch.nn.attention.SDPBackend.FLASH_ATTENTION`.
- bool: Whether the ordering of the backends is interpreted as their priority order. Default: `False`.
return_states (bool, optional) – If `True`, return a dictionary of intermediate tensors. Default: `False`.
- Returns:
Output tensor of shape (B, N, E) if not `return_states`, else a dict containing the keys: {output, queries, keys, values, attn_weights, attn_scores, output_before_proj, input}.
- Return type:
Union[torch.Tensor, Dict]
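The boolean-mask convention above can be illustrated with a minimal, self-contained sketch (plain PyTorch, not this module's code), showing the inversion applied before `scaled_dot_product_attention` is called:

```python
import torch
import torch.nn.functional as F

B, H, N, d_head = 2, 4, 5, 8
q = torch.randn(B, H, N, d_head)
k = torch.randn(B, H, N, d_head)
v = torch.randn(B, H, N, d_head)

# This API's convention: True marks *masked* (disallowed) positions,
# e.g. the strict upper triangle for causal masking.
masked = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)

# PyTorch's scaled_dot_product_attention expects the opposite:
# True marks *allowed* positions, so the mask is inverted before the call.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=~masked)
print(out.shape)  # torch.Size([2, 4, 5, 8])
```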
Grouped-Query Attention (GQA)¶
- class transformer.attns.GQA(d_model, n_heads, n_kv_heads, dropout=0.0, attn_bias=False, qk_norm=True, layer_idx=0, pos_encoding='RoPE', pos_encoding_kwargs={}, max_seq_len=1024)[source]¶
Bases: `Module`

Grouped-Query Attention (`GQA`) module using the optimized implementation of `torch.nn.functional.scaled_dot_product_attention()` when possible.

- Parameters:
d_model (int) – Model dimension.
n_heads (int) – Number of attention heads. Note that `d_model` will be split across `n_heads` (i.e. each head will have dimension `d_model // n_heads`).
n_kv_heads (int) – Number of key/value heads (must divide `n_heads`).
dropout (float, optional) – Dropout probability on `attn_output_weights`. Default: `0.0` (no dropout). Note: the latest SOTA architectures do not use dropout at all, and for research purposes it is recommended to never use it.
attn_bias (bool, optional) – Whether to use bias in linear projections. Default: `False`.
qk_norm (bool, optional) – Whether to apply RMSNorm to queries and keys. Default: `True`.
layer_idx (int, optional) – Index of the layer (used for debugging/logging).
pos_encoding (str, optional) – Positional encoding to use. Default: `RoPE`.
pos_encoding_kwargs (Dict, optional) – Dictionary of additional arguments for positional encoding. Example: `{"rope_base": 10000.0, "rot_frac": 0.5}`.
max_seq_len (int) – Maximum sequence length for RoPE.
- __init__(d_model, n_heads, n_kv_heads, dropout=0.0, attn_bias=False, qk_norm=True, layer_idx=0, pos_encoding='RoPE', pos_encoding_kwargs={}, max_seq_len=1024)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
d_model (int)
n_heads (int)
n_kv_heads (int)
dropout (float | None)
attn_bias (bool | None)
qk_norm (bool | None)
layer_idx (int)
pos_encoding (str)
pos_encoding_kwargs (Dict)
max_seq_len (int)
- forward(x, mask=None, pos=None, flash_attn=(False, <SDPBackend.FLASH_ATTENTION: 1>, False), return_states=False)[source]¶
Forward pass of GQA.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, N, E), where N is the sequence length, B is the batch size, and E is the embedding dimension `d_model`.
mask (torch.BoolTensor, optional) – If specified, a 2D or 4D mask preventing attention to certain positions. Must be of shape (N, N) or (B, n_heads, N, N), where B is the batch size, n_heads is the number of heads and N is the sequence length. A 2D mask will be broadcast across the batch, while a 4D mask allows a different mask for each entry in the batch and/or heads dimensions.
Note: this should be a boolean mask where True indicates masked positions. When Flash Attention is enabled it is inverted, because PyTorch expects True for allowed positions.
pos (torch.LongTensor, optional) – Position indices for RoPE, shape (N,) or (B, N).
flash_attn (Tuple[bool, Union[List[torch.nn.attention.SDPBackend], torch.nn.attention.SDPBackend], bool], optional) – Tuple of arguments for Flash Attention and the context manager that selects which backend to use for scaled dot product attention.
- bool: Whether to use Flash Attention. Default: `False`.
- Union[List[SDPBackend], SDPBackend]: A backend or list of backends for scaled dot product attention. Default: `torch.nn.attention.SDPBackend.FLASH_ATTENTION`.
- bool: Whether the ordering of the backends is interpreted as their priority order. Default: `False`.
return_states (bool, optional) – If `True`, return a dictionary of intermediate tensors. Default: `False`.
- Returns:
Output tensor of shape (B, N, E) if not `return_states`, else a dict containing the keys: {output, queries, keys, values, attn_weights, attn_scores, output_before_proj, input}.
- Return type:
Union[torch.Tensor, Dict]
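The key/value head sharing behind GQA can be sketched independently of this module (plain PyTorch; `repeat_interleave` is used here for illustration and is not necessarily how the library expands KV heads internally):

```python
import torch

B, N, d_head = 2, 5, 8
n_heads, n_kv_heads = 8, 2
group = n_heads // n_kv_heads  # 4 query heads share each KV head

k = torch.randn(B, n_kv_heads, N, d_head)
# Expand each KV head so it serves its whole group of query heads.
k_expanded = k.repeat_interleave(group, dim=1)
print(k_expanded.shape)  # torch.Size([2, 8, 5, 8])
```

Query heads 0–3 now all attend against copies of KV head 0, which is what reduces the KV-cache size relative to MHA.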
Cross-Attention¶
- class transformer.attns.CrossAttention(d_model, n_heads, dropout=0.0, attn_bias=False, qk_norm=True, layer_idx=0, rope_base=10000.0, max_seq_len=1024)[source]¶
Bases: `Module`

CrossAttention module using the optimized implementation of `torch.nn.functional.scaled_dot_product_attention()` when possible.

- Parameters:
d_model (int) – Model dimension.
n_heads (int) – Number of attention heads. Note that `d_model` will be split across `n_heads` (i.e. each head will have dimension `d_model // n_heads`).
dropout (float, optional) – Dropout probability on `attn_output_weights`. Default: `0.0` (no dropout). Note: the latest SOTA architectures do not use dropout at all, and for research purposes it is recommended to never use it.
attn_bias (bool, optional) – Whether to use bias in linear projections. Default: `False`.
qk_norm (bool, optional) – Whether to apply RMSNorm to queries and keys. Default: `True`.
layer_idx (int, optional) – Index of the layer (used for debugging/logging).
rope_base (float, optional) – Base for the exponential frequency calculation in RoPE. Default: `10000.0`.
max_seq_len (int) – Maximum sequence length for RoPE.
- __init__(d_model, n_heads, dropout=0.0, attn_bias=False, qk_norm=True, layer_idx=0, rope_base=10000.0, max_seq_len=1024)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
d_model (int)
n_heads (int)
dropout (float | None)
attn_bias (bool | None)
qk_norm (bool | None)
layer_idx (int)
rope_base (float)
max_seq_len (int)
- forward(queries, kv, mask=None, pos_q=None, pos_k=None, flash_attn=(False, <SDPBackend.FLASH_ATTENTION: 1>, False), return_states=False)[source]¶
Forward pass of CrossAttention.
- Parameters:
queries (torch.Tensor) – Input tensor of shape (B, N_q, E), where N_q is the sequence length of the query sequence, B is the batch size, and E is the embedding dimension `d_model`.
kv (torch.Tensor) – Input tensor of shape (B, N_k, E), where N_k is the sequence length of the key/value sequence, B is the batch size, and E is the embedding dimension `d_model`.
mask (torch.BoolTensor, optional) – If specified, a 2D or 4D mask preventing attention to certain positions. Must be of shape (N_q, N_k) or (B, n_heads, N_q, N_k), where B is the batch size, n_heads is the number of heads, N_q is the sequence length of the query sequence and N_k is the sequence length of the key/value sequence. A 2D mask will be broadcast across the batch, while a 4D mask allows a different mask for each entry in the batch and/or heads dimensions.
Note: this should be a boolean mask where True indicates masked positions. When Flash Attention is enabled it is inverted, because PyTorch expects True for allowed positions.
pos_q (torch.LongTensor, optional) – Position indices for queries, shape (N_q,) or (B, N_q).
pos_k (torch.LongTensor, optional) – Position indices for keys, shape (N_k,) or (B, N_k).
flash_attn (Tuple[bool, Union[List[torch.nn.attention.SDPBackend], torch.nn.attention.SDPBackend], bool], optional) – Tuple of arguments for Flash Attention and the context manager that selects which backend to use for scaled dot product attention.
- bool: Whether to use Flash Attention. Default: `False`.
- Union[List[SDPBackend], SDPBackend]: A backend or list of backends for scaled dot product attention. Default: `torch.nn.attention.SDPBackend.FLASH_ATTENTION`.
- bool: Whether the ordering of the backends is interpreted as their priority order. Default: `False`.
return_states (bool, optional) – If `True`, return a dictionary of intermediate tensors. Default: `False`.
- Returns:
Output tensor of shape (B, N_q, E) if not `return_states`, else a dict containing the keys: {output, queries, keys, values, attn_weights, attn_scores, output_before_proj, input}, where input is a tuple (queries, kv).
- Return type:
Union[torch.Tensor, Dict]
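Unlike self-attention, the query and key/value sequences may have different lengths, and the output keeps the query length. A minimal sketch with plain PyTorch (not this module's code):

```python
import torch
import torch.nn.functional as F

B, H, Nq, Nk, d_head = 2, 4, 3, 7, 8
q  = torch.randn(B, H, Nq, d_head)   # from the query sequence
kv = torch.randn(B, H, Nk, d_head)   # from the key/value sequence

# Cross-attention: queries attend over the key/value sequence.
out = F.scaled_dot_product_attention(q, kv, kv)
print(out.shape)  # torch.Size([2, 4, 3, 8]) — query length is preserved
```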
Positional Embeddings¶
RoPE (Rotary Position Embedding)¶
- class transformer.pos.RoPE(max_seq_len, d_head, rope_base=10000.0, persistent=True)[source]¶
Bases: `Module`

Rotary Position Embedding (RoPE) module.
- Parameters:
max_seq_len (int) – Maximum sequence length for which to precompute frequencies.
d_head (int) – Dimension per head (must be even).
rope_base (float, optional) – Base for the exponential frequency calculation. Default: `10000.0`.
persistent (bool, optional) – Whether to register the precomputed cos/sin as persistent buffers. Default: `True`.
- __init__(max_seq_len, d_head, rope_base=10000.0, persistent=True)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
max_seq_len (int)
d_head (int)
rope_base (float)
persistent (bool)
- forward(q, k, pos_q, pos_k)[source]¶
Apply rotary position embeddings to queries and keys.
- Parameters:
q (torch.Tensor) – Query tensor of shape (B, n_heads, N_q, d_head).
k (torch.Tensor) – Key tensor of shape (B, n_heads, N_k, d_head).
pos_q (torch.LongTensor) – Positions for queries, shape (N_q,) or (B, N_q).
pos_k (torch.LongTensor) – Positions for keys, shape (N_k,) or (B, N_k).
- Returns:
Rotated queries and keys.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
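A minimal, self-contained sketch of the rotation RoPE applies (this assumes interleaved dimension pairs; the module's exact pairing layout is not specified here). A useful property to verify is that the rotation preserves per-position norms, which is why it leaves attention magnitudes intact:

```python
import torch

def rope_rotate(x, pos, base=10000.0):
    # x: (..., N, d_head) with even d_head. Rotates each consecutive pair
    # of dimensions by a position-dependent angle pos * base^(-2i/d_head).
    d = x.shape[-1]
    freqs = base ** (-torch.arange(0, d, 2, dtype=torch.float) / d)
    angles = pos[:, None].float() * freqs[None, :]   # (N, d_head/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin             # 2D rotation per pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(2, 4, 6, 8)                          # (B, n_heads, N, d_head)
q_rot = rope_rotate(q, torch.arange(6))
# Rotation is norm-preserving per position and head:
print(torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4))  # True
```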
Feed-Forward Modules¶
SwiGLU¶
- class transformer.ffn.SwiGLU(d_model, d_ff, bias=True)[source]¶
Bases: `Module`

SwiGLU feed-forward module.
- Parameters:
d_model (int) – Model dimension.
d_ff (int) – Intermediate dimension (should be even, as it’s split into two halves).
bias (bool, optional) – Whether to use bias in linear layers. Default: `True`.
- __init__(d_model, d_ff, bias=True)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
d_model (int)
d_ff (int)
bias (bool | None)
- forward(x, return_states=False)[source]¶
Forward pass of SwiGLU.
- Parameters:
x (torch.Tensor) – Input tensor of shape (..., d_model).
return_states (bool, optional) – If `True`, return intermediate activations and input. Default: `False`.
- Returns:
Output tensor of shape (..., d_model) if not `return_states`, else a dict with intermediate states containing the keys: "output", "y1", "y2" and "input".
- Return type:
Union[torch.Tensor, Dict]
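A minimal sketch of the SwiGLU computation implied by the parameters above: `d_ff` is split into two halves, one gated by SiLU and multiplied into the other. Which half plays the gate is an assumption of this sketch, not something the docs fix, and this is not the library's implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySwiGLU(nn.Module):
    def __init__(self, d_model, d_ff, bias=True):
        super().__init__()
        assert d_ff % 2 == 0, "d_ff is split into two halves"
        self.w_in  = nn.Linear(d_model, d_ff, bias=bias)
        self.w_out = nn.Linear(d_ff // 2, d_model, bias=bias)

    def forward(self, x):
        y1, y2 = self.w_in(x).chunk(2, dim=-1)  # gate half and value half
        return self.w_out(F.silu(y1) * y2)      # SwiGLU: silu(y1) * y2

ffn = TinySwiGLU(d_model=16, d_ff=32)
out = ffn(torch.randn(2, 5, 16))
print(out.shape)  # torch.Size([2, 5, 16])
```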
MLP¶
- class transformer.ffn.MLP(d_model, d_ff, bias=True)[source]¶
Bases: `Module`

Classic MLP with GELU activation (as used in the original Transformer).
- Parameters:
d_model (int) – Model dimension.
d_ff (int) – Intermediate dimension.
bias (bool, optional) – Whether to use bias in linear layers. Default: `True`.
- __init__(d_model, d_ff, bias=True)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
d_model (int)
d_ff (int)
bias (bool | None)
- forward(x, return_states=False)[source]¶
Forward pass of MLP.
- Parameters:
x (torch.Tensor) – Input tensor of shape (..., d_model).
return_states (bool, optional) – If `True`, return intermediate activations. Default: `False`.
- Returns:
Output tensor of shape (..., d_model) if not `return_states`, else a dict with intermediate states containing the keys: "output", "h1", "h2" and "input".
- Return type:
Union[torch.Tensor, Dict]
Transformer Model¶
TransformerBlock¶
- class transformer.transformer.TransformerBlock(config, attn_kwargs={}, ffn_kwargs={}, norm_kwargs={}, layer_idx=0)[source]¶
Bases: `GradientCheckpointingLayer`

A single Transformer decoder block with support for gradient checkpointing, consisting of multi-head attention and feed-forward layers, each with pre-normalization (RMSNorm) and standard residual connections.
- Parameters:
config (TransformerConfig) – Configuration object.
attn_kwargs (Dict, optional) – Additional arguments for the attention class passed from `TransformerConfig.attn_class`. Only used if `TransformerConfig.attn_class` is a `Type[nn.Module]`.
ffn_kwargs (Dict, optional) – Additional arguments for the FFN class passed from `TransformerConfig.ffn_class`. Only used if `TransformerConfig.ffn_class` is a `Type[nn.Module]`.
norm_kwargs (Dict, optional) – Additional arguments for the normalization class passed from `TransformerConfig.norm_class`. It is always passed.
layer_idx (int, optional) – Index of this block (used for debugging/logging).
- __init__(config, attn_kwargs={}, ffn_kwargs={}, norm_kwargs={}, layer_idx=0)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
attn_kwargs (Dict | None)
ffn_kwargs (Dict | None)
norm_kwargs (Dict | None)
layer_idx (int)
- forward(x, attn_mask=None, pos=None, flash_attn=(False, <SDPBackend.FLASH_ATTENTION: 1>, False), return_states=False)[source]¶
Forward pass of the transformer block.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, seq_len, d_model).
attn_mask (torch.Tensor, optional) – Attention mask for the attention block.
pos (torch.Tensor, optional) – Position indices for positional encoding.
flash_attn (Tuple[bool, Union[List[torch.nn.attention.SDPBackend], torch.nn.attention.SDPBackend], bool], optional) – Tuple of arguments for Flash Attention.
return_states (bool, optional) – If `True`, return a dictionary of intermediate outputs. Default: `False`.
- Returns:
Output tensor of shape (batch_size, seq_len, d_model) if not `return_states`, else a dict containing the keys: "output", "attn_output" and "ffn_output".
- Return type:
Union[torch.Tensor, Dict]
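The pre-norm residual wiring described above can be sketched as follows. To keep the sketch self-contained, `nn.Identity` stands in for the attention and FFN sub-layers and `LayerNorm` stands in for RMSNorm; only the wiring (normalize, transform, add residual) is the point:

```python
import torch
import torch.nn as nn

d_model = 16
norm1, norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
attn = ffn = nn.Identity()           # placeholders for MHA / SwiGLU

x = torch.randn(2, 5, d_model)
x = x + attn(norm1(x))               # attention sub-layer, pre-norm residual
x = x + ffn(norm2(x))                # feed-forward sub-layer, pre-norm residual
print(x.shape)  # torch.Size([2, 5, 16])
```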
Transformer Class (Main Model)¶
- class transformer.transformer.Transformer(config, attn_kwargs={}, pos_encoding_kwargs={}, ffn_kwargs={}, norm_kwargs={})[source]¶
Bases: `PreTrainedModel`, `GenerationMixin`

Transformer language model, compatible with the HuggingFace interface.
- Parameters:
config (TransformerConfig) – Model configuration.
attn_kwargs (Dict, optional) – Additional keyword arguments passed to the attention module. Default: `{"pos_encoding_kwargs": pos_encoding_kwargs}`.
pos_encoding_kwargs (Dict, optional) – Additional arguments for positional encoding. Default: `{}`. Example: `{"rope_base": 12000, "persistent": False}`.
ffn_kwargs (Dict, optional) – Additional keyword arguments passed to the feed-forward module. Default: `{}`.
norm_kwargs (Dict, optional) – Additional keyword arguments passed to the normalization layer. Default: `{}`.
- config_class¶
alias of
TransformerConfig
- base_model_prefix: str = 'transformer'¶
- supports_gradient_checkpointing: bool = True¶
- input_modalities: str | list[str] = 'text'¶
- __init__(config, attn_kwargs={}, pos_encoding_kwargs={}, ffn_kwargs={}, norm_kwargs={})[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
attn_kwargs (Dict)
pos_encoding_kwargs (Dict)
ffn_kwargs (Dict)
norm_kwargs (Dict)
- forward(input_ids, labels=None, is_causal=True, attn_mask=None, pos=None, flash_attn=(False, <SDPBackend.FLASH_ATTENTION: 1>, False), return_states=False, loss_kwargs={}, **kwargs)[source]¶
Forward pass of the Transformer model.
- Parameters:
input_ids (torch.LongTensor) – Token indices of shape (batch_size, seq_len).
labels (torch.LongTensor, optional) – Target token indices for loss computation, same shape as `input_ids`.
is_causal (bool, optional) – If `True`, create a causal attention mask. Default: `True`.
attn_mask (torch.Tensor, optional) – Custom attention mask. If `None` and `is_causal`, an upper-triangular causal mask is generated.
pos (torch.Tensor, optional) – Position indices. If `None`, uses `torch.arange(N)`.
flash_attn (Tuple[bool, Union[List[torch.nn.attention.SDPBackend], torch.nn.attention.SDPBackend], bool], optional) – Tuple of arguments for Flash Attention.
return_states (bool, optional) – If `True`, return hidden states of all layers. Default: `False`.
loss_kwargs (Dict, optional) – Additional keyword arguments passed to `F.cross_entropy` for loss computation.
kwargs (Dict, optional) – Additional keyword arguments.
- Returns:
Contains `loss` (if labels are given, else `None`), `logits`, and optionally the hidden states as a tuple `(input_embs, hidden_states)`, where `hidden_states` is a list of dictionaries with the output of each layer.
- Return type:
CausalLMOutput
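A sketch of the mask generated when `attn_mask` is `None` and `is_causal=True`, under the True-means-masked convention used throughout this API (a plain-PyTorch illustration, not the model's internal code):

```python
import torch

# Upper-triangular causal mask: True marks the positions a query may NOT
# attend to, i.e. every token strictly after it.
N = 4
causal_mask = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
print(causal_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```

Row i is the mask for query position i: the last row is all `False` because the final token may attend to the whole sequence.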
- get_input_embeddings()[source]¶
Returns the model’s input embeddings.
- Returns:
nn.Module: A torch module mapping vocabulary to hidden states.
- Return type:
Embedding
- set_input_embeddings(embeddings)[source]¶
Fallback setter that handles ~70% of models in the code-base.
Order of attempts:
1. `self.<_input_embed_layer>` (direct attribute)
2. `self.embeddings.<_input_embed_layer>` (nested embeddings for vision/audio models)
3. `self.model.<_input_embed_layer>` (encoder/decoder models)
4. Delegate to the base model if one exists
5. Otherwise raise `NotImplementedError`, so subclasses still can (and should) override for exotic layouts.
- Parameters:
embeddings (Embedding)