ant is a PyTorch-based Deep Learning framework. It is inspired by nanoGPT, although it does not borrow code from it. It aims to stay research-friendly while integrating (way) more features, following modern best practices, and remaining modular. Among other things, ant supports:
- The CIFAR-10, OpenWebText, FineWeb-Edu and ClimbMix (SoTA) datasets
- Any Hugging Face or TokenMonster (SoTA) tokenizer
- ResNet, GPT2, Llama 2, nGPT, and OLMo 2 models
- The PSGD, DistributedShampoo, AdEMAMix, SOAP, Muon and Scion optimizers
- μP for zero-shot hyperparameter transfer
- Downstream evaluations (HellaSwag, ARC etc.) through lm_eval during training
- Model summary (parameters and FLOPS) and computational graph visualization
- Offline logging of gradients and weights in simple .dat files
- Attention heatmaps
- Plotting through PGFPlots
- Pure PyTorch attention, FlashAttention, FlexAttention and cuDNN attention with RoPE, ALiBi and Sliding Window Attention (SWA)
- Distributed Data Parallel, torch.compile and Automatic Mixed Precision
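
As a taste of one of the supported attention variants: ALiBi replaces positional embeddings by adding a fixed, head-specific linear penalty to attention scores, with per-head slopes forming a geometric sequence. A minimal pure-Python sketch (function names are ours for illustration, not ant's API):

```python
# ALiBi adds a bias of -slope * (query_pos - key_pos) to each causal
# attention score, so the penalty grows with distance from the query.
def alibi_slopes(n_heads: int) -> list[float]:
    # Simple power-of-two-heads case from the ALiBi paper: 2^(-8h/n)
    return [2 ** (-8 * (h + 1) / n_heads) for h in range(n_heads)]

def alibi_bias(slope: float, q_pos: int, k_pos: int) -> float:
    # Zero on the diagonal, increasingly negative as the key recedes
    return -slope * (q_pos - k_pos)

print(alibi_slopes(4))  # [0.25, 0.0625, 0.015625, 0.00390625]
```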
> [!TIP]
> All scripts have a help menu available via `--help`. When this is not enough, it is recommended that you look at the (self-documenting) code directly. Alternatively, if you are a vibe coder 🤖, you can try feeding the whole codebase to an LLM (e.g. via https://uithub.com/gvlassis/ant), or a coding agent (e.g. Codex CLI).
1. Clone the repo:

   ```
   git clone https://github.com/gvlassis/ant.git
   ```

2. Install PyTorch and FlashAttention.

3. Install `requirements.txt`:

   ```
   pip install -r requirements.txt
   ```

4. Prepare the dataset via `./src/data/make.py`. The dataset is first downloaded from Hugging Face, processed, and then saved as tensors in `.pt` files. If you are lazy 🦥, you can also directly download the artifacts of the following command.

   ```
   python ./src/data/make.py --dataset climbmix10m
   ```

5. Train a neural network via `./src/train.py`. For training on `cuda:0`:

   ```
   # If you are not using μP, k_input is the learning rate
   python ./src/train.py --opt muon --micro_batch_size 32 --train_batches 2000 --k_input 3e-2 --momentum 0.95 --model_device_index 0 ./out/test
   ```

   A lot of settings (e.g. depth, number of heads) are configured in `./src/models/utils_models.py/get_model_opts()` and `./src/models/transformer.py`. For one node with 4 GPUs:

   ```
   OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=4 ./src/train.py --opt shampoo --micro_batch_size 32 --train_batches 2000 --k_input 3e-3 --momentum 0.95 --beta2 0.95 --eps 1e-10 ./out/test
   ```
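
The `k_input` flag above reflects how μP decouples tuning from model width: hyperparameters are tuned once on a small proxy model and transferred zero-shot to a larger one. As an illustration only, here is a sketch of the standard μP rule for Adam-like optimizers, under which hidden-weight learning rates shrink like 1/width (the helper name is hypothetical, not ant's actual code):

```python
# Hypothetical helper: transfer a learning rate tuned at base_width to a
# wider model, following the 1/width scaling muP prescribes for hidden
# (matrix-like) parameters with Adam-style optimizers.
def mup_hidden_lr(base_lr: float, base_width: int, width: int) -> float:
    return base_lr * base_width / width

# A rate tuned on a width-256 proxy transfers to a width-1024 model as:
print(mup_hidden_lr(3e-2, 256, 1024))  # 0.0075
```

Vector-like parameters (embeddings, biases, norms) follow different rules in μP; the point is only that the tuned base rate, not the width-specific rate, is what gets carried over.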