Usage Examples¶
This section provides complete and diverse usage examples for all modules and classes.
Basic Usage¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformer import Transformer, TransformerConfig

# Configuration for a small character-level model
cfg = TransformerConfig(
    n_layers=6,
    n_heads=8,
    d_model=384,
    vocab_size=65,
    seq_len=256,
    max_seq_len=1024,
    tied_weights=False,
)

# Instantiate the model from the configuration
model = Transformer(cfg)

# Build a random batch of token ids
batch_size = 2
seq_len = 128
input_ids = torch.randint(0, cfg.vocab_size, (batch_size, seq_len))

# Forward pass; the output object carries the logits
outputs = model(input_ids)
logits = outputs.logits  # shape: [B, N, V]
print(logits.shape)
Visualization¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformer import Transformer, TransformerConfig
import matplotlib.pyplot as plt

# Same small configuration, this time with weight tying enabled
config = TransformerConfig(
    n_layers=6,
    n_heads=8,
    d_model=384,
    vocab_size=65,
    seq_len=256,
    max_seq_len=1024,
    tied_weights=True,
)

model = Transformer(config)

# A single random sequence of token ids
batch_size = 1
seq_len = 128
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

# Forward pass, asking the model to also return its internal states
with torch.no_grad():
    output = model(input_ids, return_states=True)

logits = output.logits  # shape: [batch_size, seq_len, vocab_size]
hidden_states = output.hidden_states  # Tuple: (input_embs, hidden_states)
hidden_states = hidden_states[1]  # keep only the per-layer hidden states

layer, batch, head = 0, 0, 0

# --- Attention scores ---
# Note: always call .detach() first to avoid a RuntimeError
attn_matrix = hidden_states[layer]["attn_output"]["attn_scores"][batch, head].detach().cpu()  # Shape [N, N]
plt.imshow(attn_matrix)  # matplotlib converts the tensor to numpy automatically
plt.colorbar()
plt.show()

# --- Attention weights ---
attn_matrix = hidden_states[layer]["attn_output"]["attn_weights"][batch, head].detach().cpu()  # Shape [N, N]
plt.imshow(attn_matrix)
plt.colorbar()
plt.show()

# --- First linear layer of SwiGLU, shown as a heat map ---
weights = model.blocks[layer].ffn.W1.weight.mT.detach().cpu()  # Shape [d_ff, d_model]
plt.imshow(weights)
plt.colorbar()
plt.show()

# --- Same weights, shown as line plots ---
weights = weights.mT  # Shape [d_model, d_ff]
plt.plot(weights)
plt.show()
Training a Simple Model¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformer import Transformer, TransformerConfig

# A smaller model for the training demo
config = TransformerConfig(
    n_layers=4,
    n_heads=4,
    d_model=256,
    vocab_size=1000,
    seq_len=128,
    max_seq_len=512,
)

model = Transformer(config)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Random tokens stand in for a real dataset
batch_size, seq_len = 8, 64
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
labels = torch.randint(0, config.vocab_size, (batch_size, seq_len))

# One optimization step: forward (labels produce a loss), backward, update
model.train()
optimizer.zero_grad()
outputs = model(input_ids, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
print(f"Loss: {loss.item()}")
Text Generation with HuggingFace GenerationMixin¶
The model inherits from GenerationMixin, so you can use generate().
# Assume model is trained or loaded
model.eval()

# A short prompt of token ids, shape (B, N)
prompt = torch.tensor([[1, 2, 3, 4]])

# Sample up to 50 new tokens from the prompt
with torch.no_grad():
    generated = model.generate(
        input_ids=prompt,
        max_new_tokens=50,
        do_sample=True,
        temperature=0.8,
        top_k=40,
    )

print(generated.shape)
Using Flash Attention¶
Flash Attention can be enabled via the flash_attn tuple passed in the forward call. The tuple contains:
- `use_flash` (bool): whether to use flash attention.
- `backends`: a backend or a list of backends (e.g., `torch.nn.attention.SDPBackend.FLASH_ATTENTION`).
- `set_priority` (bool): whether the order of the backend list indicates priority.
from torch.nn.attention import SDPBackend

# Enable flash attention with a single backend
flash_attn = (True, SDPBackend.FLASH_ATTENTION, False)
outputs = model(input_ids, flash_attn=flash_attn)

# Several backends, tried in the listed priority order
flash_attn = (True, [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH], True)
outputs = model(input_ids, flash_attn=flash_attn)
Note: When flash attention is used, attn_weights and attn_scores are not returned (they are None in the state dict).
Custom Attention Mask¶
You can provide any boolean mask to control which positions attend to which.
# Causal mask: True strictly above the diagonal (upper triangular)
seq_len = 16
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# Per-batch padding mask, shape (B, N)
pad_mask = torch.randint(0, 2, (2, seq_len)).bool()

# Combine both via broadcasting into a 4D mask of shape (B, H, N, N)
combined_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # (1,1,N,N)
combined_mask = combined_mask | pad_mask.unsqueeze(1).unsqueeze(2)  # (B,1,N,N)

outputs = model(input_ids, attn_mask=combined_mask)