API Reference

This section provides detailed documentation of all public modules and classes.

Configuration

class transformer.config.TransformerConfig(n_layers=12, d_model=1536, n_heads=32, n_kv_heads=None, vocab_size=50000, d_ff=None, norm_design='pre_norm', norm_class='rms_norm', ffn_class='SwiGLU', attn_class='MHA', block_class=None, attn_bias=False, ffn_bias=True, lm_head_bias=False, attn_qk_norm=True, attn_dropout=0.0, tied_weights=False, seq_len=1024, pos_encoding='RoPE', rope_base=10000.0, max_seq_len=4096, **kwargs)[source]

Bases: PretrainedConfig

Configuration class for Transformer models. Inherits from PretrainedConfig for HuggingFace compatibility.

Parameters:
  • n_layers (int) – Number of Transformer Blocks (layers).

  • d_model (int) – Model Dimension.

  • n_heads (int) – Number of Attention Heads.

  • n_kv_heads (int, optional) – Number of key/value heads for Grouped-Query Attention (GQA). Default: n_heads

  • vocab_size (int) – Vocabulary size of the model. Defines the number of different tokens.

  • d_ff (int, optional) – Dimension of the Feed-Forward Hidden Layer.

  • norm_design (str) – Normalization design, one of pre_norm, post_norm, or both. Default: pre_norm

  • norm_class (Union[List[Union[Type[nn.Module], str]], Type[nn.Module], str]) –

    Normalization class or type.

    • If str, one of rms_norm or layer_norm.

    • If Type[nn.Module], it will be instantiated inside the model. It should have the same API as a torch normalization layer.

    • If List[Union[Type[nn.Module], str]] and len(norm_class) == n_layers, each entry will be instantiated inside the model for the corresponding layer.

  • ffn_class (Union[List[Union[Type[nn.Module], str]], Type[nn.Module], str]) –

    Feed-Forward Network class or type.

    • If str, one of SwiGLU or MLP.

    • If Type[nn.Module], it will be instantiated inside the model. It should have the same API as SwiGLU and MLP.

    • If List[Union[Type[nn.Module], str]] and len(ffn_class) == n_layers, each entry will be instantiated inside the model for the corresponding layer.

    Default: SwiGLU for every layer.

  • attn_class (Union[List[Union[Type[nn.Module], str]], Type[nn.Module], str]) –

    Attention class or type.

    • If str, one of MHA, GQA, or CrossAttention. For GQA, also specify n_kv_heads.

    • If Type[nn.Module], it will be instantiated inside the model. It should have the same API as transformer.attn.MHA.

    • If List[Union[Type[nn.Module], str]] and len(attn_class) == n_layers, each entry will be instantiated inside the model for the corresponding layer.

    Default: MHA for every layer.

  • block_class (Optional[Type[nn.Module]]) –

    Transformer Block class for every layer. Default: None

    • If Type[nn.Module], it will be instantiated for every layer inside the model.

    • If None, the default transformer.TransformerBlock will be used.

  • attn_bias (bool, optional) – Whether to use bias in attention Linear Projections. Default: False

  • ffn_bias (bool, optional) – Whether to use bias in Feed-Forward Linear layers. Default: True

  • lm_head_bias (bool, optional) – Whether to use bias in the Language Modeling Head. Default: False

  • attn_qk_norm (bool, optional) – Whether to apply Normalization to Queries and Keys before the Attention Computation. Default: True

  • attn_dropout (float, optional) – Dropout probability for the Attention Layer. Default: 0.0

  • tied_weights (bool, optional) – If True, tie the input embedding and output projection weights. Default: False

  • seq_len (int) – Sequence Length.

  • pos_encoding (Union[List[str], str]) –

    Positional Encoding for attention.

    • If str, one of RoPE, ALiBi, or PartialRoPE. Default: RoPE

    • If List[str] and len(pos_encoding) == n_layers, each entry will be used for the corresponding layer. Default: RoPE for every layer.

    Note: It is recommended to change the default to PartialRoPE, which is used in SOTA models such as Qwen3-Next-80B-A3B.

  • rope_base (float, optional) – Base for the Exponential Frequency Calculation in RoPE. Default: 10000.0

  • max_seq_len (int) – Maximum sequence length for positional embeddings.

  • kwargs (dict, optional) – Additional keyword arguments passed to PretrainedConfig

model_type = 'transformer'
__init__(n_layers=12, d_model=1536, n_heads=32, n_kv_heads=None, vocab_size=50000, d_ff=None, norm_design='pre_norm', norm_class='rms_norm', ffn_class='SwiGLU', attn_class='MHA', block_class=None, attn_bias=False, ffn_bias=True, lm_head_bias=False, attn_qk_norm=True, attn_dropout=0.0, tied_weights=False, seq_len=1024, pos_encoding='RoPE', rope_base=10000.0, max_seq_len=4096, **kwargs)[source]
Parameters:
  • n_layers (int)

  • d_model (int)

  • n_heads (int)

  • n_kv_heads (int | None)

  • vocab_size (int)

  • d_ff (int | None)

  • norm_design (str)

  • norm_class (List[Type[Module] | str] | Type[Module] | str)

  • ffn_class (List[Type[Module] | str] | Type[Module] | str)

  • attn_class (List[Type[Module] | str] | Type[Module] | str)

  • block_class (Type[Module] | None)

  • attn_bias (bool)

  • ffn_bias (bool)

  • lm_head_bias (bool)

  • attn_qk_norm (bool)

  • attn_dropout (float | None)

  • tied_weights (bool)

  • seq_len (int)

  • pos_encoding (str)

  • rope_base (float)

  • max_seq_len (int)

  • kwargs (Dict)
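The constraints these parameters imply can be sketched with a minimal validation helper (an illustration, not the library's own code; the 4 * d_model fallback for d_ff is an assumption — the library may size SwiGLU layers differently):

```python
def check_config(n_layers=12, d_model=1536, n_heads=32, n_kv_heads=None, d_ff=None):
    # d_model is split evenly across attention heads.
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_head = d_model // n_heads
    # GQA: n_kv_heads defaults to n_heads (plain MHA) and must divide n_heads.
    n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
    assert n_heads % n_kv_heads == 0, "n_kv_heads must divide n_heads"
    # d_ff fallback: 4 * d_model is a common convention (assumption).
    d_ff = d_ff if d_ff is not None else 4 * d_model
    return d_head, n_kv_heads, d_ff

print(check_config())  # with the documented defaults: (48, 32, 6144)
```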

Attention Modules

Multi-Head Attention (MHA)

class transformer.attns.MHA(d_model, n_heads, dropout=0.0, attn_bias=False, qk_norm=True, layer_idx=0, pos_encoding='RoPE', pos_encoding_kwargs={}, max_seq_len=1024)[source]

Bases: Module

Multi-Head Attention (MHA) module using the optimized implementation of torch.nn.functional.scaled_dot_product_attention() when possible.

Parameters:
  • d_model (int) – Model dimension.

  • n_heads (int) – Number of attention heads. Note that d_model will be split across n_heads (i.e. each head will have dimension d_model // n_heads).

  • dropout (float, optional) – Dropout probability on attn_output_weights. Default: 0.0 (no dropout). Note: most recent SOTA architectures do not use attention dropout; for research purposes it is recommended to leave it at 0.0.

  • attn_bias (bool, optional) – Whether to use bias in linear projections. Default: False

  • qk_norm (bool, optional) – Whether to apply RMSNorm to queries and keys. Default: True

  • layer_idx (int, optional) – Index of the layer (used for debugging/logging).

  • pos_encoding (str, optional) – Positional Encoding to use. Default: RoPE

  • pos_encoding_kwargs (Dict, optional) – Dictionary of additional arguments for Positional Encoding. Example: {"rope_base": 10000.0, "rot_frac": 0.5}.

  • max_seq_len (int) – Maximum sequence length for RoPE.

__init__(d_model, n_heads, dropout=0.0, attn_bias=False, qk_norm=True, layer_idx=0, pos_encoding='RoPE', pos_encoding_kwargs={}, max_seq_len=1024)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • d_model (int)

  • n_heads (int)

  • dropout (float)

  • attn_bias (bool | None)

  • qk_norm (bool | None)

  • layer_idx (int)

  • pos_encoding (str)

  • pos_encoding_kwargs (Dict)

  • max_seq_len (int)

forward(x, mask=None, pos=None, flash_attn=(False, <SDPBackend.FLASH_ATTENTION: 1>, False), return_states=False)[source]

Forward pass of MHA.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, N, D) where N is the Sequence Length, B is the batch size, and D is the embedding dimension d_model.

  • mask (torch.BoolTensor, optional) – If specified, a 2D or 4D mask preventing attention to certain positions. Must be of shape (N, N) or (B, H, N, N), where B is the batch size, H is the number of heads and N is the Sequence Length. A 2D mask will be broadcasted across the batch while a 4D mask allows for a different mask for each entry in the batch and/or heads dimensions. Note: Should be a boolean mask where True indicates masked positions. When Flash Attention is enabled it is inverted because PyTorch expects True for allowed positions.

  • pos (torch.LongTensor, optional) – Position indices for RoPE, shape (N) or (B, N)

  • flash_attn (Tuple[bool, Union[List[torch.nn.attention.SDPBackend], torch.nn.attention.SDPBackend], bool], optional) –

    Tuple of arguments for Flash Attention and the context manager that selects the backend for scaled dot product attention:

    • bool: Whether or not to use Flash Attention. Default: False

    • Union[List[SDPBackend], SDPBackend]: A backend or list of backends for scaled dot product attention. Default: torch.nn.attention.SDPBackend.FLASH_ATTENTION

    • bool: Whether the ordering of the backends is interpreted as their priority order. Default: False

  • return_states (bool, optional) – If True, return a dictionary of intermediate tensors. Default: False

Returns:

Output tensor (B, N, D) if not return_states, else a dict containing the keys: {output, queries, keys, values, attn_weights, attn_scores, output_before_proj and input}

Return type:

Union[torch.Tensor, Dict]
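The mask convention above (True = masked, inverted before being handed to PyTorch's SDPA) can be sketched with plain torch (a standalone illustration, not the module itself):

```python
import torch
import torch.nn.functional as F

B, H, N, d = 2, 4, 5, 8
q, k, v = (torch.randn(B, H, N, d) for _ in range(3))

# This library's convention: True marks *masked* (disallowed) positions.
masked = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)

# torch.nn.functional.scaled_dot_product_attention expects the opposite
# convention (True = allowed), hence the inversion ~masked.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=~masked)
print(out.shape)  # torch.Size([2, 4, 5, 8])
```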

Grouped-Query Attention (GQA)

class transformer.attns.GQA(d_model, n_heads, n_kv_heads, dropout=0.0, attn_bias=False, qk_norm=True, layer_idx=0, pos_encoding='RoPE', pos_encoding_kwargs={}, max_seq_len=1024)[source]

Bases: Module

Grouped-Query Attention (GQA) module using the optimized implementation of torch.nn.functional.scaled_dot_product_attention() when possible.

Parameters:
  • d_model (int) – Model dimension.

  • n_heads (int) – Number of attention heads. Note that d_model will be split across n_heads (i.e. each head will have dimension d_model // n_heads).

  • n_kv_heads (int) – Number of key/value heads (must divide n_heads).

  • dropout (float, optional) – Dropout probability on attn_output_weights. Default: 0.0 (no dropout). Note: most recent SOTA architectures do not use attention dropout; for research purposes it is recommended to leave it at 0.0.

  • attn_bias (bool, optional) – Whether to use bias in linear projections. Default: False

  • qk_norm (bool, optional) – Whether to apply RMSNorm to queries and keys. Default: True

  • layer_idx (int, optional) – Index of the layer (used for debugging/logging).

  • pos_encoding (str, optional) – Positional Encoding to use. Default: RoPE

  • pos_encoding_kwargs (Dict, optional) – Dictionary of additional arguments for Positional Encoding. Example: {"rope_base": 10000.0, "rot_frac": 0.5}.

  • max_seq_len (int) – Maximum sequence length for RoPE.

__init__(d_model, n_heads, n_kv_heads, dropout=0.0, attn_bias=False, qk_norm=True, layer_idx=0, pos_encoding='RoPE', pos_encoding_kwargs={}, max_seq_len=1024)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • d_model (int)

  • n_heads (int)

  • n_kv_heads (int)

  • dropout (float | None)

  • attn_bias (bool | None)

  • qk_norm (bool | None)

  • layer_idx (int)

  • pos_encoding (str)

  • pos_encoding_kwargs (Dict)

  • max_seq_len (int)

forward(x, mask=None, pos=None, flash_attn=(False, <SDPBackend.FLASH_ATTENTION: 1>, False), return_states=False)[source]

Forward pass of GQA.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, N, D) where N is the Sequence Length, B is the batch size, and D is the embedding dimension d_model.

  • mask (torch.BoolTensor, optional) – If specified, a 2D or 4D mask preventing attention to certain positions. Must be of shape (N, N) or (B, H, N, N), where B is the batch size, H is the number of heads and N is the Sequence Length. A 2D mask will be broadcasted across the batch while a 4D mask allows for a different mask for each entry in the batch and/or heads dimensions. Note: Should be a boolean mask where True indicates masked positions. When Flash Attention is enabled it is inverted because PyTorch expects True for allowed positions.

  • pos (torch.LongTensor, optional) – Position indices for RoPE, shape (N) or (B, N)

  • flash_attn (Tuple[bool, Union[List[torch.nn.attention.SDPBackend], torch.nn.attention.SDPBackend], bool], optional) –

    Tuple of arguments for Flash Attention and the context manager that selects the backend for scaled dot product attention:

    • bool: Whether or not to use Flash Attention. Default: False

    • Union[List[SDPBackend], SDPBackend]: A backend or list of backends for scaled dot product attention. Default: torch.nn.attention.SDPBackend.FLASH_ATTENTION

    • bool: Whether the ordering of the backends is interpreted as their priority order. Default: False

  • return_states (bool, optional) – If True, return a dictionary of intermediate tensors. Default: False

Returns:

Output tensor of shape (B, N, D) if not return_states, else a dict containing the keys: {output, queries, keys, values, attn_weights, attn_scores, output_before_proj and input}

Return type:

Union[torch.Tensor, Dict]
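The key/value-head sharing that GQA performs can be sketched by expanding KV heads across query-head groups (one common implementation strategy; the actual module may instead use SDPA's grouped path):

```python
import torch

B, N, d_head = 2, 6, 16
n_heads, n_kv_heads = 8, 2            # n_kv_heads must divide n_heads
group = n_heads // n_kv_heads         # each KV head serves 4 query heads

k = torch.randn(B, n_kv_heads, N, d_head)
# Repeat each KV head so every query head in its group attends to it.
k_expanded = k.repeat_interleave(group, dim=1)
print(k_expanded.shape)  # torch.Size([2, 8, 6, 16])
```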

Cross-Attention

class transformer.attns.CrossAttention(d_model, n_heads, dropout=0.0, attn_bias=False, qk_norm=True, layer_idx=0, rope_base=10000.0, max_seq_len=1024)[source]

Bases: Module

CrossAttention module using the optimized implementation of torch.nn.functional.scaled_dot_product_attention() when possible.

Parameters:
  • d_model (int) – Model dimension.

  • n_heads (int) – Number of attention heads. Note that d_model will be split across n_heads (i.e. each head will have dimension d_model // n_heads).

  • dropout (float, optional) – Dropout probability on attn_output_weights. Default: 0.0 (no dropout). Note: most recent SOTA architectures do not use attention dropout; for research purposes it is recommended to leave it at 0.0.

  • attn_bias (bool, optional) – Whether to use bias in linear projections. Default: False

  • qk_norm (bool, optional) – Whether to apply RMSNorm to queries and keys. Default: True

  • layer_idx (int, optional) – Index of the layer (used for debugging/logging).

  • rope_base (float, optional) – Base for the Exponential Frequency Calculation in RoPE. Default: 10000.0

  • max_seq_len (int) – Maximum sequence length for RoPE.

__init__(d_model, n_heads, dropout=0.0, attn_bias=False, qk_norm=True, layer_idx=0, rope_base=10000.0, max_seq_len=1024)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • d_model (int)

  • n_heads (int)

  • dropout (float | None)

  • attn_bias (bool | None)

  • qk_norm (bool | None)

  • layer_idx (int)

  • rope_base (float)

  • max_seq_len (int)

forward(queries, kv, mask=None, pos_q=None, pos_k=None, flash_attn=(False, <SDPBackend.FLASH_ATTENTION: 1>, False), return_states=False)[source]

Forward pass of CrossAttention.

Parameters:
  • queries (torch.Tensor) – Input tensor of shape (B, Lq, D) where Lq is the Sequence Length for the query sequence, B is the batch size, and D is the embedding dimension d_model.

  • kv (torch.Tensor) – Input tensor of shape (B, Lk, D) where Lk is the Sequence Length for the key/value sequence, B is the batch size, and D is the embedding dimension d_model.

  • mask (torch.BoolTensor, optional) – If specified, a 2D or 4D mask preventing attention to certain positions. Must be of shape (Lq, Lk) or (B, H, Lq, Lk), where B is the batch size, H is the number of heads, Lq is the Sequence Length of the query sequence and Lk is the Sequence Length of the key/value sequence. A 2D mask will be broadcasted across the batch while a 4D mask allows for a different mask for each entry in the batch and/or heads dimensions. Note: Should be a boolean mask where True indicates masked positions. When Flash Attention is enabled it is inverted because PyTorch expects True for allowed positions.

  • pos_q (torch.LongTensor, optional) – Position indices for Queries, shape (Lq) or (B, Lq)

  • pos_k (torch.LongTensor, optional) – Position indices for Keys, shape (Lk) or (B, Lk)

  • flash_attn (Tuple[bool, Union[List[torch.nn.attention.SDPBackend], torch.nn.attention.SDPBackend], bool], optional) –

    Tuple of arguments for Flash Attention and the context manager that selects the backend for scaled dot product attention:

    • bool: Whether or not to use Flash Attention. Default: False

    • Union[List[SDPBackend], SDPBackend]: A backend or list of backends for scaled dot product attention. Default: torch.nn.attention.SDPBackend.FLASH_ATTENTION

    • bool: Whether the ordering of the backends is interpreted as their priority order. Default: False

  • return_states (bool, optional) – If True, return a dictionary of intermediate tensors. Default: False

Returns:

Output tensor of shape (B, N, D) if not return_states, else a dict containing the keys: {output, queries, keys, values, attn_weights, attn_scores, output_before_proj and input} where input is a tuple (queries, kv)

Return type:

Union[torch.Tensor, Dict]
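Unlike self-attention, the query and key/value sequences may have different lengths, and the output follows the query length. A shape sketch with plain SDPA (assumed shapes, not the module's internals):

```python
import torch
import torch.nn.functional as F

B, H, Lq, Lk, d = 2, 4, 3, 7, 8
q = torch.randn(B, H, Lq, d)          # query sequence
k = torch.randn(B, H, Lk, d)          # key/value sequence (different length)
v = torch.randn(B, H, Lk, d)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 4, 3, 8]) — output follows Lq
```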

Positional Embeddings

RoPE (Rotary Position Embedding)

class transformer.pos.RoPE(max_seq_len, d_head, rope_base=10000.0, persistent=True)[source]

Bases: Module

Rotary Position Embedding (RoPE) module.

Parameters:
  • max_seq_len (int) – Maximum sequence length for which to precompute frequencies.

  • d_head (int) – Dimension per head (must be even).

  • rope_base (float, optional) – Base for the exponential frequency calculation. Default: 10000.0

  • persistent (bool, optional) – Whether to register the precomputed cos/sin as persistent buffers. Default: True

__init__(max_seq_len, d_head, rope_base=10000.0, persistent=True)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • max_seq_len (int)

  • d_head (int)

  • rope_base (float)

  • persistent (bool)

forward(q, k, pos_q, pos_k)[source]

Apply rotary position embeddings to queries and keys.

Parameters:
  • q (torch.Tensor) – Query tensor of shape (B, H, N, d)

  • k (torch.Tensor) – Key tensor of shape (B, H, N, d)

  • pos_q (torch.LongTensor) – Positions for queries, shape (N,) or (B, N)

  • pos_k (torch.LongTensor) – Positions for keys, shape (N,) or (B, N)

Returns:

Rotated queries and keys.

Return type:

Tuple[torch.Tensor, torch.Tensor]
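The rotation RoPE applies can be sketched as follows (a minimal version that pairs adjacent even/odd dimensions; implementations differ in whether they pair adjacent dimensions or split the head dimension in halves, so this may not match the module bit-for-bit):

```python
import torch

def rope_sketch(x, pos, base=10000.0):
    # x: (B, H, N, d) with even d; pos: (N,) position indices.
    d = x.shape[-1]
    freqs = base ** (-torch.arange(0, d, 2).float() / d)   # (d/2,) frequencies
    angles = pos.float()[:, None] * freqs[None, :]         # (N, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]                    # adjacent pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin                   # 2D rotation per pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

x = torch.randn(1, 2, 3, 4)
# Position 0 rotates by angle 0, i.e. leaves the vector unchanged.
assert torch.allclose(rope_sketch(x, torch.zeros(3)), x)
```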

Feed-Forward Modules

SwiGLU

class transformer.ffn.SwiGLU(d_model, d_ff, bias=True)[source]

Bases: Module

SwiGLU feed-forward module

Parameters:
  • d_model (int) – Model dimension.

  • d_ff (int) – Intermediate dimension (should be even, as it’s split into two halves).

  • bias (bool, optional) – Whether to use bias in linear layers. Default: True

__init__(d_model, d_ff, bias=True)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • d_model (int)

  • d_ff (int)

  • bias (bool | None)

forward(x, return_states=False)[source]

Forward pass of SwiGLU.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (..., D)

  • return_states (bool, optional) – If True, return intermediate activations and input. Default: False

Returns:

Output tensor (..., D) if not return_states, else a dict with intermediate states containing the keys: “output”, “y1”, “y2” and “input”.

Return type:

Union[torch.Tensor, Dict]
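The forward computation can be sketched functionally (an illustration; which half gates which is an implementation choice that may differ from the module's):

```python
import torch
import torch.nn.functional as F

def swiglu_sketch(x, w_in, w_out):
    # w_in projects D -> d_ff with d_ff even; the result is split in two,
    # and one half is passed through SiLU ("Swish") to gate the other.
    y1, y2 = (x @ w_in).chunk(2, dim=-1)
    return (F.silu(y1) * y2) @ w_out

x = torch.randn(2, 8)
out = swiglu_sketch(x, torch.randn(8, 12), torch.randn(6, 8))
print(out.shape)  # torch.Size([2, 8])
```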

MLP

class transformer.ffn.MLP(d_model, d_ff, bias=True)[source]

Bases: Module

Classic MLP with GELU activation (as used in the original Transformer).

Parameters:
  • d_model (int) – Model dimension.

  • d_ff (int) – Intermediate dimension.

  • bias (bool, optional) – Whether to use bias in linear layers. Default: True

__init__(d_model, d_ff, bias=True)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • d_model (int)

  • d_ff (int)

  • bias (bool | None)

forward(x, return_states=False)[source]

Forward pass of MLP.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (..., D)

  • return_states (bool, optional) – If True, return intermediate activations. Default: False

Returns:

Output tensor (..., D) if not return_states, else a dict with intermediate states containing the keys: “output”, “h1”, “h2” and “input”.

Return type:

Union[torch.Tensor, Dict]

Transformer Model

TransformerBlock

class transformer.transformer.TransformerBlock(config, attn_kwargs={}, ffn_kwargs={}, norm_kwargs={}, layer_idx=0)[source]

Bases: GradientCheckpointingLayer

A single Transformer Decoder Block with support for Gradient Checkpointing, consisting of Multi-Head Attention and Feed-Forward layers, each with Pre-Normalization (RMSNorm) and standard residual connections.

Parameters:
  • config (TransformerConfig) – Configuration object.

  • attn_kwargs (Dict, optional) – Additional Arguments for the attention class passed from TransformerConfig.attn_class. It is only used if TransformerConfig.attn_class is Type[nn.Module]

  • ffn_kwargs (Dict, optional) – Additional Arguments for the ffn class passed from TransformerConfig.ffn_class. It is only used if TransformerConfig.ffn_class is Type[nn.Module]

  • norm_kwargs (Dict, optional) – Additional Arguments for the normalization class passed from TransformerConfig.norm_class. It is always passed.

  • layer_idx (int, optional) – Index of this block (used for debugging/logging).

__init__(config, attn_kwargs={}, ffn_kwargs={}, norm_kwargs={}, layer_idx=0)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • attn_kwargs (Dict | None)

  • ffn_kwargs (Dict | None)

  • norm_kwargs (Dict | None)

  • layer_idx (int)

forward(x, attn_mask=None, pos=None, flash_attn=(False, <SDPBackend.FLASH_ATTENTION: 1>, False), return_states=False)[source]

Forward pass of the transformer block.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, N, D).

  • attn_mask (torch.Tensor, optional) – Attention mask for the Attention block.

  • pos (torch.Tensor, optional) – Position indices for Positional Encoding.

  • flash_attn (Tuple[bool, Union[list[torch.nn.attention.SDPBackend], torch.nn.attention.SDPBackend], bool], optional) – Tuple of Arguments for Flash Attention.

  • return_states (bool, optional) – If True, return a dictionary of intermediate outputs. Default: False

Returns:

Output tensor (batch_size, seq_len, d_model) if not return_states, else a dict containing the keys: “output”, “attn_output” and “ffn_output”.

Return type:

Union[torch.Tensor, Dict]
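The pre-norm residual wiring described above can be sketched as follows (a minimal stand-in using nn.LayerNorm in place of RMSNorm and nn.Linear in place of the real sublayers; the actual block also handles masks, positional encoding, and gradient checkpointing):

```python
import torch
import torch.nn as nn

class PreNormBlockSketch(nn.Module):
    def __init__(self, d_model, attn, ffn):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)  # stand-in for RMSNorm
        self.norm2 = nn.LayerNorm(d_model)
        self.attn, self.ffn = attn, ffn

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # normalize, attend, add residual
        x = x + self.ffn(self.norm2(x))   # normalize, transform, add residual
        return x

blk = PreNormBlockSketch(8, nn.Linear(8, 8), nn.Linear(8, 8))
print(blk(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 5, 8])
```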

Transformer Class (Main Model)

class transformer.transformer.Transformer(config, attn_kwargs={}, pos_encoding_kwargs={}, ffn_kwargs={}, norm_kwargs={})[source]

Bases: PreTrainedModel, GenerationMixin

Transformer language model, compatible with the HuggingFace interface.

Parameters:
  • config (TransformerConfig) – Model configuration.

  • attn_kwargs (Dict, optional) – Additional Keyword Arguments passed to the Attention Module. Default: {"pos_encoding_kwargs": pos_encoding_kwargs}

  • pos_encoding_kwargs (Dict, optional) – Additional Arguments for Positional Encoding. Default: {} Example: {"rope_base": 12000, "persistent": False}

  • ffn_kwargs (Dict, optional) – Additional Keyword Arguments passed to the Feed-Forward Module. Default: {}

  • norm_kwargs (Dict, optional) – Additional Keyword Arguments passed to the Normalization Layer. Default: {}

config_class

alias of TransformerConfig

base_model_prefix: str = 'transformer'
supports_gradient_checkpointing: bool = True
input_modalities: str | list[str] = 'text'
__init__(config, attn_kwargs={}, pos_encoding_kwargs={}, ffn_kwargs={}, norm_kwargs={})[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • attn_kwargs (Dict)

  • pos_encoding_kwargs (Dict)

  • ffn_kwargs (Dict)

  • norm_kwargs (Dict)

forward(input_ids, labels=None, is_causal=True, attn_mask=None, pos=None, flash_attn=(False, <SDPBackend.FLASH_ATTENTION: 1>, False), return_states=False, loss_kwargs={}, **kwargs)[source]

Forward pass of the Transformer model.

Parameters:
  • input_ids (torch.LongTensor) – Token indices of shape (B, N)

  • labels (torch.LongTensor, optional) – Target token indices for loss computation, same shape as input_ids.

  • is_causal (bool, optional) – If True, create a causal attention mask. Default: True

  • attn_mask (torch.Tensor, optional) – Custom attention mask. If None and is_causal, an upper triangular causal mask is generated.

  • pos (torch.Tensor, optional) – Position indices. If None, uses torch.arange(N).

  • flash_attn (Tuple[bool, Union[list[torch.nn.attention.SDPBackend], torch.nn.attention.SDPBackend], bool], optional) – Tuple of Arguments for Flash Attention.

  • return_states (bool, optional) – If True, return hidden states of all layers. Default: False

  • loss_kwargs (Dict, optional) – Additional keyword arguments passed to F.cross_entropy for loss computation.

  • kwargs (Dict, optional) – Additional keyword arguments

Returns:

Contains loss (None if labels is not given), logits, and optionally the hidden states: a tuple (input_embs, hidden_states) where hidden_states is a list of per-layer output dictionaries.

Return type:

CausalLMOutput
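When labels are given, the loss is a next-token cross-entropy. The shift can be sketched as follows (an illustration of the standard causal-LM recipe; whether the model's forward shifts internally or expects pre-shifted labels is not shown here):

```python
import torch
import torch.nn.functional as F

B, N, V = 2, 5, 11
logits = torch.randn(B, N, V)              # model output (B, N, vocab_size)
input_ids = torch.randint(0, V, (B, N))    # labels share this shape

# Each position predicts the *next* token, so predictions at 0..N-2
# are scored against the tokens at 1..N-1.
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, V),
    input_ids[:, 1:].reshape(-1),
)
print(loss.ndim)  # 0 (a scalar)
```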

get_input_embeddings()[source]

Returns the model’s input embeddings.

Returns:

nn.Module: A torch module mapping vocabulary to hidden states.

Return type:

Embedding

set_input_embeddings(embeddings)[source]

Fallback setter that handles ~70% of models in the code-base.

Order of attempts:

1. self.<_input_embed_layer> (direct attribute)

2. self.embeddings.<_input_embed_layer> (nested embeddings for vision/audio models)

3. self.model.<_input_embed_layer> (encoder/decoder models)

4. Delegate to the base model if one exists.

5. Otherwise raise NotImplementedError so subclasses still can (and should) override for exotic layouts.

Parameters:

embeddings (Embedding)

get_num_params()[source]

Return the number of trainable parameters.

Return type:

int
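The count matches the usual PyTorch idiom (a sketch of what such a method typically does; the class's exact implementation is not shown here):

```python
import torch.nn as nn

def num_params_sketch(module):
    # Count only parameters with requires_grad=True, i.e. trainable ones.
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

print(num_params_sketch(nn.Linear(10, 4)))  # 44 (40 weights + 4 biases)
```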