API Reference¶
This page documents all public classes, functions, and configuration objects provided by arch_eval.
Core Classes¶
Trainer¶
Constructor
```python
Trainer(model: nn.Module, config: TrainingConfig)
```

- `model`: PyTorch `nn.Module` to train.
- `config`: `TrainingConfig` object with all training parameters.
Methods
- `train() -> Dict[str, List[float]]` – Runs the training loop and returns a history dictionary mapping metric names to lists of values (one per epoch).
- `load_checkpoint(path: str, load_optimizer: bool = True, load_scheduler: bool = True, weights_only: bool = False)` – Loads a saved checkpoint.
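The history returned by `train()` is a plain dictionary of per-epoch metric lists, so it needs no library support to post-process. A minimal sketch with hypothetical metric names and values:

```python
# Hypothetical history, shaped like the return value of Trainer.train():
# metric name -> one value per epoch.
history = {
    "train_loss": [0.92, 0.54, 0.31],
    "val_loss": [0.98, 0.61, 0.45],
}

# Find the epoch (0-indexed) with the lowest validation loss.
best_epoch = min(range(len(history["val_loss"])), key=history["val_loss"].__getitem__)
best_val = history["val_loss"][best_epoch]
```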
Benchmark¶
Constructor
```python
Benchmark(models: List[Dict[str, Any]], config: BenchmarkConfig)
```

- `models`: List of dictionaries, each with keys `"name"` (optional) and `"model"` (an `nn.Module`).
- `config`: `BenchmarkConfig` object.
Methods
- `run() -> pd.DataFrame` – Executes the benchmark and returns a pandas DataFrame with results.
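A usage sketch, assuming the top-level import path `from arch_eval import Benchmark, BenchmarkConfig` (check your installation) and with `my_dataset` standing in for any dataset object accepted by `BaseConfig.dataset`:

```python
from torch import nn
from arch_eval import Benchmark, BenchmarkConfig  # import path is an assumption

models = [
    {"name": "small_mlp", "model": nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))},
    {"name": "wide_mlp", "model": nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 2))},
]

config = BenchmarkConfig(
    dataset=my_dataset,  # placeholder: any dataset accepted by BaseConfig.dataset
    task="classification",
    training_args={"batch_size": 32, "num_epochs": 5},
    compare_metrics=["val_loss"],
)

results = Benchmark(models, config).run()  # pandas DataFrame of results
```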
HyperparameterOptimizer¶
Constructor
```python
HyperparameterOptimizer(
    model_fn: Callable,
    base_config: TrainingConfig,
    param_grid: Dict[str, List[Any]],
    search_type: str = "grid",
    n_trials: Optional[int] = None,
    metric: str = "val_loss",
    mode: str = "min",
)
```

- `model_fn`: A zero‑argument callable that returns a new model instance.
- `base_config`: Base `TrainingConfig` that will be copied and updated per trial.
- `param_grid`: Dictionary mapping parameter names to lists of values to try.
- `search_type`: `"grid"` or `"random"`.
- `n_trials`: Number of random trials (ignored for grid search).
- `metric`: Metric to optimize (must appear in the training history).
- `mode`: `"min"` or `"max"`.
Methods
- `run() -> pd.DataFrame` – Runs the search and returns a DataFrame with all trials and the target metric.
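With `search_type="grid"`, one trial is run per combination in the cross-product of `param_grid`; with `"random"`, `n_trials` combinations are sampled. A pure-Python sketch of how a grid expands (the parameter names here are just examples):

```python
import itertools

param_grid = {"learning_rate": [1e-3, 1e-4], "batch_size": [32, 64, 128]}

# One trial per combination in the cross-product of the value lists:
trials = [
    dict(zip(param_grid, values))
    for values in itertools.product(*param_grid.values())
]
# len(trials) is 2 * 3 = 6
```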
PluginManager¶
Methods
- `discover_plugins(plugin_paths: Optional[List[str]] = None)` – Scans for plugins (modules whose name starts with `arch_eval_plugin_` or ends with `_plugin`).
- `register_local_hook(hook_name: str, func: Callable)` – Registers a hook function for the current trainer instance.
- `execute_hook(hook_name: str, *args, **kwargs) -> List[Any]` – Executes all global and local hooks for the given hook name.
- `get_plugins() -> Dict[str, Any]` – Returns information about discovered plugins.
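The discovery rule is a simple name match. A standalone sketch of the predicate (an illustration of the documented convention, not the library's actual implementation):

```python
def looks_like_plugin(module_name: str) -> bool:
    # Mirrors the naming convention documented for discover_plugins():
    # module names starting with "arch_eval_plugin_" or ending with "_plugin".
    return module_name.startswith("arch_eval_plugin_") or module_name.endswith("_plugin")
```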
Configuration Classes¶
TrainingConfig¶
Inherits from BaseConfig. All fields are configurable via the constructor.
Key fields (in addition to BaseConfig fields):
- `optimizers: List[Dict[str, Any]]` – List of optimizer specifications (e.g., `{"type": "adam", "lr": 0.001}`). Defaults to a single Adam optimizer.
- `schedulers: List[Dict[str, Any]]` – List of scheduler specs, each with a `"type"` (e.g., `"step"`, `"cosine"`, `"reduce_on_plateau"`) and an optional `"optimizer"` index.
- `training_args: Dict[str, Any]` – Holds `batch_size`, `learning_rate`, `num_epochs`, etc.
- `task: Union[str, TaskType]` – One of `"classification"`, `"regression"`, `"next-token-prediction"`, or a custom object with a `loss_function`.
- `mixed_precision: bool` – Enable automatic mixed precision (AMP).
- `mixed_precision_dtype: MixedPrecisionDtype` – `"float16"`, `"bfloat16"`, or `"fp8"`.
- `distributed_backend: DistributedBackend` – `"none"`, `"dp"` (DataParallel), `"ddp"` (DistributedDataParallel), or `"fsdp"`.
- `callbacks: List[Callback]` – List of callback instances.
- `checkpoint_dir: Optional[str]` – Directory to save checkpoints.
- `early_stopping_patience: Optional[int]` – Patience for early stopping.
- `gradient_clip: Optional[float]` – Max norm for gradient clipping.
- `profiler: Optional[Dict]` – Profiler configuration (e.g., `{"enabled": True, "activities": ["cpu", "cuda"]}`).
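A sketch tying several of these fields together; the top-level import path and the specific values are assumptions, and `my_dataset` is a placeholder:

```python
from arch_eval import TrainingConfig  # import path is an assumption

config = TrainingConfig(
    dataset=my_dataset,  # placeholder: any dataset accepted by BaseConfig.dataset
    task="classification",
    training_args={"batch_size": 64, "learning_rate": 1e-3, "num_epochs": 10},
    optimizers=[{"type": "adam", "lr": 1e-3}],
    schedulers=[{"type": "cosine"}],
    mixed_precision=True,
    gradient_clip=1.0,
    early_stopping_patience=3,
    checkpoint_dir="checkpoints/",
)
```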
BenchmarkConfig¶
Inherits from BaseConfig. Adds:
- `training_args: Dict[str, Any]` – Training parameters (batch size, learning rate, epochs) used for all models.
- `task: Union[str, TaskType]` – Task type.
- `parallel: bool` – Whether to run models in parallel.
- `compare_metrics: List[str]` – Metrics to extract from each model’s history.
- `max_workers: Optional[int]` – Maximum parallel workers.
- `use_processes: bool` – Use process‑based (rather than thread‑based) parallelism.
- `timeout_seconds: int` – Timeout per model.
- `retry_failed: bool` – Whether to retry failed models.
BaseConfig¶
Common fields:
Data:
- `dataset: Any` – A dataset object, string identifier, or dictionary.
- `dataset_params: Dict` – Parameters for synthetic/torchvision datasets.
- `transform`, `target_transform`, `collate_fn` – Optional callables.
- `dataset_streaming: bool` – Use streaming (`IterableDataset`) for Hugging Face datasets.
Computation:
- `dtype: torch.dtype` – Default `torch.float32`.
- `device: Optional[str]` – Auto‑selected if `None`.
Logging & Viz:
- `viz_interval`, `log_interval`, `eval_interval` – Step intervals.
- `realtime: bool` – Enable live plotting window.
- `save_video: List[str]` – List of metric names to record as video.
- `save_plot: List[str]` – List of metric names to save as final plots.
- `log_to_wandb`, `wandb_project`, etc.
Reproducibility:
- `seed: Optional[int]`
- `deterministic: bool`
Enums¶
- `TaskType`: `REGRESSION`, `CLASSIFICATION`, `NEXT_TOKEN_PREDICTION`
- `DistributedBackend`: `NONE`, `DATAPARALLEL`, `DISTRIBUTED`, `FSDP`
- `MixedPrecisionDtype`: `FLOAT16`, `BFLOAT16`, `FP8`
Callbacks¶
All callbacks inherit from `Callback` and can be registered via the `callbacks` list in `TrainingConfig`.
Callback¶
All methods are no‑ops by default; override the ones you need.
Built‑in Callbacks¶
- `EarlyStopping` – Stops training when a monitored metric stops improving.
- `ModelCheckpoint` – Saves model checkpoints.
- `LRSchedulerLogger` – Logs learning rates after each epoch.
- `TensorBoardLogger` – Logs metrics to TensorBoard.
Data Handling¶
DatasetHandler¶
Internal class used by Trainer to prepare data loaders. Usually you do not need to instantiate it directly.
Synthetic Dataset Classes¶
SyntheticDataset¶
A simple torch.utils.data.Dataset wrapping synthetic data tensors for traditional ML tasks (classification, regression, etc.).
TextDataset¶
Dataset for language modeling with token sequences. Returns (input_ids, labels) pairs suitable for transformer training.
Example:
```python
from arch_eval.data import TextDataset
import torch

input_ids = torch.randint(0, 1000, (1000, 128))  # 1000 samples, seq length 128
labels = input_ids.clone()  # For causal LM, labels are the same as the input
dataset = TextDataset(input_ids, labels)
```
ImageDataset¶
Dataset for vision tasks with images. Returns (images, labels) pairs with optional transforms.
Example:
```python
from arch_eval.data import ImageDataset
import torch

images = torch.rand(1000, 3, 32, 32)    # 1000 RGB images, 32x32
labels = torch.randint(0, 10, (1000,))  # 10 classes
dataset = ImageDataset(images, labels)
```
VisionLanguageDataset¶
Dataset for vision-language multi-modal tasks. Returns dictionaries with pixel_values, input_ids, attention_mask, and labels.
Example:
```python
from arch_eval.data import VisionLanguageDataset
import torch

images = torch.rand(500, 3, 32, 32)            # 500 images
input_ids = torch.randint(0, 1000, (500, 64))  # Token sequences
attention_mask = torch.ones(500, 64)
labels = torch.randint(0, 10, (500,))
dataset = VisionLanguageDataset(images, input_ids, attention_mask, labels)
```
Synthetic Dataset Generation Functions¶
create_synthetic_dataset¶
```python
create_synthetic_dataset(dataset_type: str, params: Dict[str, Any]) -> SyntheticDataset
```
Creates a synthetic dataset for traditional ML tasks. Supported types:
- `"classification"`, `"regression"`, `"blobs"`, `"circles"`, `"moons"`
- `"friedman1"`, `"friedman2"`, `"friedman3"`, `"sparse_uncorrelated"`, `"multilabel"`
Example:
```python
from arch_eval.data import create_synthetic_dataset

# Binary classification
dataset = create_synthetic_dataset("classification", {
    "n_samples": 1000,
    "n_features": 20,
    "n_classes": 2,
    "n_informative": 10
})
```
create_synthetic_text_dataset¶
```python
create_synthetic_text_dataset(params: Dict[str, Any]) -> TextDataset
```
Creates synthetic text data for language modeling. Generates token sequences with configurable structure.
Parameters:
- `vocab_size`: Size of vocabulary (default: 1000)
- `seq_length`: Sequence length (default: 128)
- `n_samples`: Number of samples (default: 1000)
- `entropy`: Randomness level 0–1; lower = more pattern (default: 0.8)
- `random_state`: Random seed (default: 42)
Example:
```python
from arch_eval.data import create_synthetic_text_dataset

# Generate synthetic text for transformer pretraining
dataset = create_synthetic_text_dataset({
    "vocab_size": 32000,
    "seq_length": 512,
    "n_samples": 10000,
    "entropy": 0.7  # Some structure, not purely random
})

config = TrainingConfig(
    dataset="synthetic text",
    dataset_params={
        "vocab_size": 32000,
        "seq_length": 512,
        "n_samples": 10000
    },
    task="next-token-prediction",
    training_args={"batch_size": 32, "num_epochs": 10}
)
```
create_synthetic_image_dataset¶
```python
create_synthetic_image_dataset(params: Dict[str, Any]) -> ImageDataset
```
Creates synthetic image data for vision tasks. Supports random noise, gradients, and geometric shapes.
Parameters:
- `img_size`: Image size as int or tuple (H, W) (default: 32)
- `channels`: Number of channels (default: 3)
- `n_samples`: Number of samples (default: 1000)
- `n_classes`: Number of classes (default: 10)
- `pattern`: Type of pattern – `'random'`, `'gradient'`, or `'shapes'` (default: `'random'`)
- `random_state`: Random seed (default: 42)
Example:
```python
from arch_eval.data import create_synthetic_image_dataset

# Generate synthetic images with geometric shapes
dataset = create_synthetic_image_dataset({
    "img_size": 64,
    "channels": 3,
    "n_samples": 5000,
    "n_classes": 10,
    "pattern": "shapes"  # Squares, circles, triangles, lines
})

config = TrainingConfig(
    dataset="synthetic image",
    dataset_params={
        "img_size": 64,
        "channels": 3,
        "pattern": "shapes"
    },
    task="classification",
    training_args={"batch_size": 64, "num_epochs": 20}
)
```
create_synthetic_vision_language_dataset¶
```python
create_synthetic_vision_language_dataset(params: Dict[str, Any]) -> VisionLanguageDataset
```
Creates synthetic vision-language paired data for multi-modal tasks.
Parameters:
- `img_size`: Image size (default: 32)
- `channels`: Number of image channels (default: 3)
- `vocab_size`: Vocabulary size for text (default: 1000)
- `seq_length`: Text sequence length (default: 64)
- `n_samples`: Number of samples (default: 500)
- `correlation`: How much the text correlates with the image class, 0–1 (default: 0.7)
- `random_state`: Random seed (default: 42)
Example:
```python
from arch_eval.data import create_synthetic_vision_language_dataset

# Generate synthetic image-caption pairs
dataset = create_synthetic_vision_language_dataset({
    "img_size": 32,
    "channels": 3,
    "vocab_size": 1000,
    "seq_length": 64,
    "n_samples": 2000,
    "correlation": 0.8  # High correlation between image class and text
})

config = TrainingConfig(
    dataset="synthetic vision_language",
    dataset_params={
        "img_size": 32,
        "vocab_size": 1000,
        "seq_length": 64,
        "correlation": 0.8
    },
    task="next-token-prediction",  # Or custom multi-modal task
    training_args={"batch_size": 32, "num_epochs": 15}
)
```
Distributed Utilities¶
init_distributed¶
```python
init_distributed(backend="nccl", world_size=1, rank=0, master_addr="127.0.0.1", master_port="29500")
```
Initializes the distributed process group.
cleanup_distributed¶
```python
cleanup_distributed()
```
Destroys the process group.
get_wrapped_model¶
```python
get_wrapped_model(model, config)
```
Returns the model wrapped with the appropriate distributed wrapper (DataParallel, DistributedDataParallel, or FSDP).
Logging¶
setup_logging¶
```python
setup_logging(level="INFO", log_file=None, fmt=None)
```
Configures root logger with console output and optional file rotation. Returns the root logger.
LoggerAdapter¶
Provides consistent logging with an `arch_eval.` prefix.
Metrics¶
MetricCalculator¶
Computes task‑specific metrics (accuracy, precision, recall, F1, AUC, R², MSE, perplexity, etc.) and accumulates data for confusion matrices.
Supported Model Output Formats:
The `MetricCalculator` automatically handles various output formats from different model architectures:
- Tensor: standard PyTorch tensor output, shape `(batch_size, num_classes)` or `(batch_size, seq_len, vocab_size)`.
- Tuple: models returning `(logits, loss)`, `(loss, logits)`, or similar tuple patterns.
- Dict: transformer models returning `{"logits": ..., "loss": ...}` or similar dictionaries.
- Hugging Face style: objects with `.logits` and `.loss` attributes (e.g., `CausalLMOutput`, `SequenceClassifierOutput`).
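A framework-free sketch of that dispatch, simplified relative to the real implementation (for instance, it assumes `(logits, loss)` ordering for tuples):

```python
def extract_logits_and_loss(output):
    """Normalize a model's output to a (logits, loss) pair."""
    if isinstance(output, dict):
        return output.get("logits"), output.get("loss")
    if hasattr(output, "logits"):  # Hugging Face style output objects
        return output.logits, getattr(output, "loss", None)
    if isinstance(output, tuple) and len(output) == 2:
        return output  # assumes (logits, loss) ordering
    return output, None  # bare tensor output, no loss available
```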
Example with Hugging Face style output:
```python
import torch.nn.functional as F
from torch import nn
from transformers.modeling_outputs import CausalLMOutput

class MyTransformer(nn.Module):
    def forward(self, input_ids, labels=None):
        # self.model is the underlying decoder stack (built in __init__, omitted here)
        logits = self.model(input_ids)  # (batch, seq_len, vocab_size)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        return CausalLMOutput(loss=loss, logits=logits)

# Trainer and MetricCalculator will automatically extract loss and logits
trainer = Trainer(model, config)
history = trainer.train()
```
Plugins¶
hook decorator¶
```python
from arch_eval.plugins import hook

@hook("before_training")
def my_hook(trainer, config):
    ...
```
Marks a function as a plugin hook. The function will be called at the appropriate time if the plugin is loaded.
Profiling¶
profiler_context¶
```python
from arch_eval.profiler import profiler_context

with profiler_context(config):
    ...  # training code
```
Context manager that conditionally enables PyTorch’s torch.profiler based on the profiler field in the config.
Utilities¶
Device utilities¶
```python
get_optimal_device() -> str
get_device_info() -> Dict[str, Any]
memory_summary() -> str
```
auto_device decorator¶
```python
from arch_eval.utils import auto_device

@auto_device
def my_function(tensor):
    ...
```
Automatically moves input tensors to the device of the first argument (or inferred from a tensor) and optionally returns CPU tensors.
Visualization¶
RealtimeWindow¶
Displays live metric plots and system resource usage.
VideoRecorder¶
Records frames of metric plots and assembles them into a video using ffmpeg.
PlotSaver¶
Saves final metric plots (PNG) to disk.
Exceptions¶
All custom exceptions inherit from `ArchEvalError`:
- `DatasetFormatError`
- `ConfigurationError`
- `ModelError`
- `PluginError`
- `VisualizationError`
- `StopTraining` – can be raised by plugins to gracefully stop training.
- `DistributedError`