API Reference

This page documents all public classes, functions, and configuration objects provided by arch_eval.

Core Classes

Trainer

Constructor

Trainer(model: nn.Module, config: TrainingConfig)
  • model: PyTorch nn.Module to train.

  • config: TrainingConfig object with all training parameters.

Methods

  • train() -> Dict[str, List[float]] – Runs the training loop and returns a history dictionary mapping metric names to lists of values (one per epoch).

  • load_checkpoint(path: str, load_optimizer: bool = True, load_scheduler: bool = True, weights_only: bool = False) – Loads a saved checkpoint.
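
A minimal end-to-end sketch. The top-level import paths are assumed here (only arch_eval.data is documented explicitly on this page), and the metric keys in the returned history depend on the task:

```python
import torch.nn as nn
from arch_eval import Trainer, TrainingConfig          # import path assumed
from arch_eval.data import create_synthetic_dataset

model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))

# dataset accepts a dataset object, so we pass a synthetic one directly.
dataset = create_synthetic_dataset("classification",
                                   {"n_samples": 1000, "n_features": 20, "n_classes": 2})

config = TrainingConfig(
    dataset=dataset,
    task="classification",
    training_args={"batch_size": 32, "num_epochs": 5},
)

trainer = Trainer(model, config)
history = trainer.train()   # e.g. history["train_loss"]: one value per epoch
```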


Benchmark

Constructor

Benchmark(models: List[Dict[str, Any]], config: BenchmarkConfig)
  • models: List of dictionaries, each with keys "name" (optional) and "model" (an nn.Module).

  • config: BenchmarkConfig object.

Methods

  • run() -> pd.DataFrame – Executes the benchmark and returns a pandas DataFrame with results.
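
A short sketch comparing two models. Import paths and metric names are illustrative assumptions:

```python
import torch.nn as nn
from arch_eval import Benchmark, BenchmarkConfig       # import path assumed
from arch_eval.data import create_synthetic_dataset

models = [
    {"name": "narrow", "model": nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 2))},
    {"name": "wide", "model": nn.Sequential(nn.Linear(20, 128), nn.ReLU(), nn.Linear(128, 2))},
]

config = BenchmarkConfig(
    dataset=create_synthetic_dataset("classification", {"n_samples": 1000, "n_features": 20}),
    task="classification",
    training_args={"batch_size": 32, "num_epochs": 3},
    compare_metrics=["train_loss"],
)

results = Benchmark(models, config).run()   # one DataFrame row per model
```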


HyperparameterOptimizer

Constructor

HyperparameterOptimizer(
    model_fn: Callable,
    base_config: TrainingConfig,
    param_grid: Dict[str, List[Any]],
    search_type: str = "grid",
    n_trials: Optional[int] = None,
    metric: str = "val_loss",
    mode: str = "min"
)
  • model_fn: A zero‑argument callable that returns a new model instance.

  • base_config: Base TrainingConfig that will be copied and updated per trial.

  • param_grid: Dictionary mapping parameter names to lists of values to try.

  • search_type: "grid" or "random".

  • n_trials: Number of random trials (ignored for grid search).

  • metric: Metric to optimize (must appear in the training history).

  • mode: "min" or "max".

Methods

  • run() -> pd.DataFrame – Runs the search and returns a DataFrame with all trials and the target metric.
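
For "grid" search the number of trials is the Cartesian product of the value lists; "random" search instead draws n_trials combinations. A dependency-free illustration of the grid size (the parameter names and the helper are ours, not part of arch_eval):

```python
from itertools import product

# Hypothetical grid: parameter names are illustrative.
param_grid = {
    "learning_rate": [1e-4, 1e-3, 1e-2],
    "batch_size": [16, 32],
}

# One trial per element of the Cartesian product: 3 * 2 = 6 trials.
trials = [dict(zip(param_grid, values)) for values in product(*param_grid.values())]
print(len(trials))  # 6
```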


PluginManager

Methods

  • discover_plugins(plugin_paths: Optional[List[str]] = None) – Scans for plugins (modules whose name starts with arch_eval_plugin_ or ends with _plugin).

  • register_local_hook(hook_name: str, func: Callable) – Registers a hook function for the current trainer instance.

  • execute_hook(hook_name: str, *args, **kwargs) -> List[Any] – Executes all global and local hooks registered under the given hook name.

  • get_plugins() -> Dict[str, Any] – Returns information about discovered plugins.
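
The discovery rule above can be expressed as a simple predicate (illustrative only; arch_eval's internal check may differ in detail):

```python
def looks_like_plugin(module_name: str) -> bool:
    # A module qualifies if it is named arch_eval_plugin_* or *_plugin.
    return module_name.startswith("arch_eval_plugin_") or module_name.endswith("_plugin")

print(looks_like_plugin("arch_eval_plugin_wandb"))  # True
print(looks_like_plugin("my_metrics_plugin"))       # True
print(looks_like_plugin("numpy"))                   # False
```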


Configuration Classes

TrainingConfig

Inherits from BaseConfig. All fields are configurable via the constructor.

Key fields (in addition to BaseConfig fields):

  • optimizers: List[Dict[str, Any]] – List of optimizer specifications (e.g., {"type": "adam", "lr": 0.001}). Defaults to a single Adam optimizer.

  • schedulers: List[Dict[str, Any]] – List of scheduler specs, each with a "type" (e.g., "step", "cosine", "reduce_on_plateau") and optional "optimizer" index.

  • training_args: Dict[str, Any] – Holds batch_size, learning_rate, num_epochs, etc.

  • task: Union[str, TaskType] – One of "classification", "regression", "next-token-prediction", or a custom object with a loss_function.

  • mixed_precision: bool – Enable automatic mixed precision (AMP).

  • mixed_precision_dtype: MixedPrecisionDtype – "float16", "bfloat16", or "fp8".

  • distributed_backend: DistributedBackend – "none", "dp" (DataParallel), "ddp" (DistributedDataParallel), or "fsdp".

  • callbacks: List[Callback] – List of callback instances.

  • checkpoint_dir: Optional[str] – Directory to save checkpoints.

  • early_stopping_patience: Optional[int] – Patience for early stopping.

  • gradient_clip: Optional[float] – Max norm for gradient clipping.

  • profiler: Optional[Dict] – Profiler configuration (e.g., {"enabled": True, "activities": ["cpu", "cuda"]}).
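
A representative configuration touching several of the fields above. Values are illustrative, and the import path for the callback classes is an assumption:

```python
from arch_eval import TrainingConfig                        # import path assumed
from arch_eval.callbacks import EarlyStopping, ModelCheckpoint  # import path assumed

config = TrainingConfig(
    task="classification",
    training_args={"batch_size": 64, "learning_rate": 1e-3, "num_epochs": 50},
    optimizers=[{"type": "adam", "lr": 1e-3}],
    schedulers=[{"type": "cosine", "optimizer": 0}],   # scheduler drives optimizer index 0
    mixed_precision=True,
    mixed_precision_dtype="bfloat16",
    gradient_clip=1.0,
    checkpoint_dir="./checkpoints",
    callbacks=[EarlyStopping(), ModelCheckpoint()],    # constructor args omitted for brevity
    profiler={"enabled": False},
)
```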


BenchmarkConfig

Inherits from BaseConfig. Adds:

  • training_args: Dict[str, Any] – Training parameters (batch size, learning rate, epochs) used for all models.

  • task: Union[str, TaskType] – Task type.

  • parallel: bool – Whether to run models in parallel.

  • compare_metrics: List[str] – Metrics to extract from each model’s history.

  • max_workers: Optional[int] – Maximum parallel workers.

  • use_processes: bool – Use process‑based (rather than thread‑based) parallelism.

  • timeout_seconds: int – Timeout per model.

  • retry_failed: bool – Whether to retry failed models.
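
A parallel benchmark setup exercising the fields above (illustrative values; import path assumed):

```python
from arch_eval import BenchmarkConfig   # import path assumed

config = BenchmarkConfig(
    task="regression",
    training_args={"batch_size": 32, "num_epochs": 10},
    compare_metrics=["val_loss"],
    parallel=True,
    use_processes=True,    # process-based parallelism sidesteps the GIL for CPU-bound work
    max_workers=4,
    timeout_seconds=600,   # per-model time budget
    retry_failed=True,
)
```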


BaseConfig

Common fields:

  • Data:

    • dataset: Any – A dataset object, string identifier, or dictionary.

    • dataset_params: Dict – Parameters for synthetic/torchvision datasets.

    • transform, target_transform, collate_fn – Optional callables.

    • dataset_streaming: bool – Use streaming (IterableDataset) for Hugging Face datasets.

  • Computation:

    • dtype: torch.dtype – Default torch.float32.

    • device: Optional[str] – Auto‑selected if None.

  • Logging & Viz:

    • viz_interval, log_interval, eval_interval – Step intervals.

    • realtime: bool – Enable live plotting window.

    • save_video: List[str] – List of metric names to record as video.

    • save_plot: List[str] – List of metric names to save as final plots.

    • log_to_wandb, wandb_project, etc.

  • Reproducibility:

    • seed: Optional[int]

    • deterministic: bool


Enums

  • TaskType: REGRESSION, CLASSIFICATION, NEXT_TOKEN_PREDICTION

  • DistributedBackend: NONE, DATAPARALLEL, DISTRIBUTED, FSDP

  • MixedPrecisionDtype: FLOAT16, BFLOAT16, FP8


Callbacks

All callbacks inherit from Callback and can be registered via the callbacks list in TrainingConfig.

Callback

All methods are no‑ops by default; override the ones you need.
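
A sketch of the no-op base pattern with stand-in classes. The hook-method names used here (on_epoch_start, on_epoch_end) are assumptions, not arch_eval's documented names:

```python
class Callback:
    # Every hook is a no-op; subclasses override only what they need.
    def on_epoch_start(self, trainer, epoch): ...
    def on_epoch_end(self, trainer, epoch, metrics): ...

class PrintLoss(Callback):
    def on_epoch_end(self, trainer, epoch, metrics):
        print(f"epoch {epoch}: {metrics.get('train_loss')}")

cb = PrintLoss()
cb.on_epoch_end(trainer=None, epoch=0, metrics={"train_loss": 0.5})  # prints: epoch 0: 0.5
```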

Built‑in Callbacks

  • EarlyStopping – Stops training when a monitored metric stops improving.

  • ModelCheckpoint – Saves model checkpoints.

  • LRSchedulerLogger – Logs learning rates after each epoch.

  • TensorBoardLogger – Logs metrics to TensorBoard.


Data Handling

DatasetHandler

Internal class used by Trainer to prepare data loaders. Usually you do not need to instantiate it directly.

Synthetic Dataset Classes

SyntheticDataset

A simple torch.utils.data.Dataset wrapping synthetic data tensors for traditional ML tasks (classification, regression, etc.).

TextDataset

Dataset for language modeling with token sequences. Returns (input_ids, labels) pairs suitable for transformer training.

Example:

from arch_eval.data import TextDataset
import torch

input_ids = torch.randint(0, 1000, (1000, 128))  # 1000 samples, seq length 128
labels = input_ids.clone()  # For causal LM, the labels are the same as the inputs
dataset = TextDataset(input_ids, labels)

ImageDataset

Dataset for vision tasks with images. Returns (images, labels) pairs with optional transforms.

Example:

from arch_eval.data import ImageDataset
import torch

images = torch.rand(1000, 3, 32, 32)  # 1000 RGB images, 32x32
labels = torch.randint(0, 10, (1000,))  # 10 classes
dataset = ImageDataset(images, labels)

VisionLanguageDataset

Dataset for vision-language multi-modal tasks. Returns dictionaries with pixel_values, input_ids, attention_mask, and labels.

Example:

from arch_eval.data import VisionLanguageDataset
import torch

images = torch.rand(500, 3, 32, 32)  # 500 images
input_ids = torch.randint(0, 1000, (500, 64))  # Token sequences
attention_mask = torch.ones(500, 64)
labels = torch.randint(0, 10, (500,))
dataset = VisionLanguageDataset(images, input_ids, attention_mask, labels)

Synthetic Dataset Generation Functions

create_synthetic_dataset

create_synthetic_dataset(dataset_type: str, params: Dict[str, Any]) -> SyntheticDataset

Creates a synthetic dataset for traditional ML tasks. Supported types:

  • "classification", "regression", "blobs", "circles", "moons"

  • "friedman1", "friedman2", "friedman3", "sparse_uncorrelated", "multilabel"

Example:

from arch_eval.data import create_synthetic_dataset

# Binary classification
dataset = create_synthetic_dataset("classification", {
    "n_samples": 1000,
    "n_features": 20,
    "n_classes": 2,
    "n_informative": 10
})

create_synthetic_text_dataset

create_synthetic_text_dataset(params: Dict[str, Any]) -> TextDataset

Creates synthetic text data for language modeling. Generates token sequences with configurable structure.

Parameters:

  • vocab_size: Size of vocabulary (default: 1000)

  • seq_length: Sequence length (default: 128)

  • n_samples: Number of samples (default: 1000)

  • entropy: Randomness level in [0, 1]; lower values yield more structured sequences (default: 0.8)

  • random_state: Random seed (default: 42)

Example:

from arch_eval.data import create_synthetic_text_dataset

# Generate synthetic text for transformer pretraining
dataset = create_synthetic_text_dataset({
    "vocab_size": 32000,
    "seq_length": 512,
    "n_samples": 10000,
    "entropy": 0.7  # Some structure, not purely random
})

config = TrainingConfig(
    dataset="synthetic text",
    dataset_params={
        "vocab_size": 32000,
        "seq_length": 512,
        "n_samples": 10000
    },
    task="next-token-prediction",
    training_args={"batch_size": 32, "num_epochs": 10}
)

create_synthetic_image_dataset

create_synthetic_image_dataset(params: Dict[str, Any]) -> ImageDataset

Creates synthetic image data for vision tasks. Supports random noise, gradients, and geometric shapes.

Parameters:

  • img_size: Image size as int or tuple (H, W) (default: 32)

  • channels: Number of channels (default: 3)

  • n_samples: Number of samples (default: 1000)

  • n_classes: Number of classes (default: 10)

  • pattern: Type of pattern - 'random', 'gradient', 'shapes' (default: 'random')

  • random_state: Random seed (default: 42)

Example:

from arch_eval.data import create_synthetic_image_dataset

# Generate synthetic images with geometric shapes
dataset = create_synthetic_image_dataset({
    "img_size": 64,
    "channels": 3,
    "n_samples": 5000,
    "n_classes": 10,
    "pattern": "shapes"  # Squares, circles, triangles, lines
})

config = TrainingConfig(
    dataset="synthetic image",
    dataset_params={
        "img_size": 64,
        "channels": 3,
        "pattern": "shapes"
    },
    task="classification",
    training_args={"batch_size": 64, "num_epochs": 20}
)

create_synthetic_vision_language_dataset

create_synthetic_vision_language_dataset(params: Dict[str, Any]) -> VisionLanguageDataset

Creates synthetic vision-language paired data for multi-modal tasks.

Parameters:

  • img_size: Image size (default: 32)

  • channels: Number of image channels (default: 3)

  • vocab_size: Vocabulary size for text (default: 1000)

  • seq_length: Text sequence length (default: 64)

  • n_samples: Number of samples (default: 500)

  • correlation: Degree of correlation between the text and the image class, in [0, 1] (default: 0.7)

  • random_state: Random seed (default: 42)

Example:

from arch_eval.data import create_synthetic_vision_language_dataset

# Generate synthetic image-caption pairs
dataset = create_synthetic_vision_language_dataset({
    "img_size": 32,
    "channels": 3,
    "vocab_size": 1000,
    "seq_length": 64,
    "n_samples": 2000,
    "correlation": 0.8  # High correlation between image class and text
})

config = TrainingConfig(
    dataset="synthetic vision_language",
    dataset_params={
        "img_size": 32,
        "vocab_size": 1000,
        "seq_length": 64,
        "correlation": 0.8
    },
    task="next-token-prediction",  # Or custom multi-modal task
    training_args={"batch_size": 32, "num_epochs": 15}
)

Distributed Utilities

init_distributed

init_distributed(backend="nccl", world_size=1, rank=0, master_addr="127.0.0.1", master_port="29500")

Initializes the distributed process group.

cleanup_distributed

cleanup_distributed()

Destroys the process group.

get_wrapped_model

get_wrapped_model(model, config)

Returns the model wrapped with the appropriate distributed wrapper (DataParallel, DistributedDataParallel, or FSDP).
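
init_distributed presumably exports the usual PyTorch rendezvous environment variables before initializing the process group; a stdlib sketch of that env-var step (assumed behavior, not the actual implementation):

```python
import os

def set_rendezvous_env(rank: int, world_size: int,
                       master_addr: str = "127.0.0.1", master_port: str = "29500") -> None:
    # These four variables are what torch.distributed reads for the "env://" init method.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

set_rendezvous_env(rank=0, world_size=1)
print(os.environ["MASTER_ADDR"])  # 127.0.0.1
```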


Logging

setup_logging

setup_logging(level="INFO", log_file=None, fmt=None)

Configures root logger with console output and optional file rotation. Returns the root logger.
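
A stdlib sketch of what such a helper typically does: a console handler plus an optional rotating file handler. This is assumed behavior; the real setup_logging may differ:

```python
import logging
from logging.handlers import RotatingFileHandler
from typing import Optional

def setup_logging_sketch(level: str = "INFO", log_file: Optional[str] = None,
                         fmt: Optional[str] = None) -> logging.Logger:
    root = logging.getLogger()
    root.setLevel(getattr(logging, level))
    formatter = logging.Formatter(fmt or "%(asctime)s %(name)s %(levelname)s %(message)s")
    console = logging.StreamHandler()
    console.setFormatter(formatter)
    root.addHandler(console)
    if log_file is not None:
        # Rotate at roughly 10 MB, keeping three backups.
        handler = RotatingFileHandler(log_file, maxBytes=10_000_000, backupCount=3)
        handler.setFormatter(formatter)
        root.addHandler(handler)
    return root

logger = setup_logging_sketch("DEBUG")
```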

LoggerAdapter

Provides consistent logging with an "arch_eval." prefix.


Metrics

MetricCalculator

Computes task‑specific metrics (accuracy, precision, recall, F1, AUC, R², MSE, perplexity, etc.) and accumulates data for confusion matrices.

Supported Model Output Formats:

The MetricCalculator automatically handles various output formats from different model architectures:

  • Tensor: Standard PyTorch tensor output of shape (batch_size, num_classes) or (batch_size, seq_len, vocab_size)

  • Tuple: Models returning (logits, loss), (loss, logits), or similar tuple patterns

  • Dict: Transformer models returning {"logits": ..., "loss": ...} or similar dictionaries

  • Hugging Face Style: Objects with .logits and .loss attributes (e.g., CausalLMOutput, SequenceClassifierOutput)

Example with Hugging Face style output:

from transformers.modeling_outputs import CausalLMOutput

class MyTransformer(nn.Module):
    def forward(self, input_ids, labels=None):
        logits = self.model(input_ids)
        loss = compute_loss(logits, labels) if labels is not None else None
        return CausalLMOutput(loss=loss, logits=logits)

# Trainer and MetricCalculator will automatically extract loss and logits
trainer = Trainer(model, config)
history = trainer.train()

Plugins

hook decorator

from arch_eval.plugins import hook

@hook("before_training")
def my_hook(trainer, config):
    ...

Marks a function as a plugin hook. The function will be called at the appropriate time if the plugin is loaded.


Profiling

profiler_context

from arch_eval.profiler import profiler_context

with profiler_context(config):
    ...  # training code

Context manager that conditionally enables PyTorch’s torch.profiler based on the profiler field in the config.


Utilities

Device utilities

get_optimal_device() -> str
get_device_info() -> Dict[str, Any]
memory_summary() -> str
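
get_optimal_device plausibly prefers CUDA, then Apple MPS, then CPU; a dependency-free sketch of that preference order (an assumption about its logic, not the actual implementation):

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    # Assumed preference order: CUDA GPU, then Apple-silicon MPS, then CPU fallback.
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"

print(pick_device(False, True))  # mps
```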

auto_device decorator

from arch_eval.utils import auto_device

@auto_device
def my_function(tensor):
    ...

Automatically moves input tensors to the device of the first argument (or to a device inferred from any tensor argument) and optionally moves results back to the CPU.


Visualization

RealtimeWindow

Displays live metric plots and system resource usage.

VideoRecorder

Records frames of metric plots and assembles them into a video using ffmpeg.

PlotSaver

Saves final metric plots (PNG) to disk.


Exceptions

All custom exceptions inherit from ArchEvalError:

  • DatasetFormatError

  • ConfigurationError

  • ModelError

  • PluginError

  • VisualizationError

  • StopTraining – can be raised by plugins to gracefully stop training.

  • DistributedError
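
The StopTraining pattern: a plugin hook raises it and the training loop catches it to exit cleanly. A stand-alone sketch with stand-in classes (not the real arch_eval loop):

```python
class ArchEvalError(Exception):
    """Stand-in for arch_eval's base exception."""

class StopTraining(ArchEvalError):
    """Raised by a plugin hook to end training gracefully."""

def stop_after_three(epoch: int) -> None:
    if epoch >= 3:
        raise StopTraining

completed = []
try:
    for epoch in range(10):
        stop_after_three(epoch)   # a plugin hook would run here
        completed.append(epoch)
except StopTraining:
    pass                          # the loop exits cleanly instead of crashing

print(completed)  # [0, 1, 2]
```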