A lightweight implementation of a GPT-like transformer model with 1.6M parameters (for now!!) using JAX and Flax's NNX module. The code demonstrates autoregressive text generation, multi-head attention, and a training loop with Optax, and it automatically shards training across multiple devices via positional sharding.
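The core of the model is causal multi-head self-attention. A minimal pure-JAX sketch of one attention block is below; the function name, weight shapes, and single-sequence layout are illustrative assumptions, not the repo's actual Flax NNX modules:

```python
import jax
import jax.numpy as jnp

def causal_self_attention(x, w_qkv, w_out, n_heads):
    """Sketch of multi-head causal self-attention for one sequence.

    x:     (T, C) token embeddings
    w_qkv: (C, 3C) fused query/key/value projection
    w_out: (C, C) output projection
    (Hypothetical shapes; the real model lives in train.py.)
    """
    T, C = x.shape
    head_dim = C // n_heads
    q, k, v = jnp.split(x @ w_qkv, 3, axis=-1)           # each (T, C)
    # Reshape to (n_heads, T, head_dim) so heads attend independently.
    q = q.reshape(T, n_heads, head_dim).transpose(1, 0, 2)
    k = k.reshape(T, n_heads, head_dim).transpose(1, 0, 2)
    v = v.reshape(T, n_heads, head_dim).transpose(1, 0, 2)
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(head_dim)  # (H, T, T)
    # Causal mask: position t may only attend to positions <= t.
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    attn = jax.nn.softmax(scores, axis=-1)
    out = (attn @ v).transpose(1, 0, 2).reshape(T, C)    # merge heads
    return out @ w_out
```

The lower-triangular mask is what makes generation autoregressive: perturbing a later token cannot change the outputs at earlier positions.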
- Ensure JAX is installed (follow the official installation instructions for your hardware).
- Install dependencies:

  ```shell
  pip install flax optax
  ```

- Clone the repo, then run:

  ```shell
  python train.py --max_iters 200000 --lr 1e-3
  ```

Example 1: Training on the tiny-shakespeare dataset:
- Generated output:

  ```
  KING RIVH:
  It genton. I was the grief
  Forth the time of is offected: how God noug-leakes?
  Your wift quany at stand pass--
  Stoation incuttection to-danius shorate anginganeness,
  Thy switn's roveried skily before in grief
  Would thee stoble Rome of find,
  Mach thing bound you will all, and Clifford!

  FLORIZEL:
  Who should had thanks you
  Formit gaters thee gods, and swear be
  Intives that thither-buney, we heart this grule commends
  By the ritory striken time of their
  conselds, let where I will becalle
  ```
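Samples like the above come from a standard autoregressive loop: crop the context to the block size, run the model, sample the next token from the last position's logits, append, and repeat. A minimal sketch follows; the `apply_fn` signature and argument names are assumptions for illustration, not the repo's actual generation API:

```python
import jax
import jax.numpy as jnp

def generate(params, apply_fn, prompt, max_new_tokens, key, block_size=8):
    """Sketch of autoregressive sampling.

    apply_fn(params, tokens) -> logits of shape (T, vocab_size).
    (Hypothetical interface; see the repo's train.py for the real one.)
    """
    tokens = prompt
    for _ in range(max_new_tokens):
        context = tokens[-block_size:]                  # crop to context window
        logits = apply_fn(params, context)              # (T, vocab_size)
        key, subkey = jax.random.split(key)
        next_tok = jax.random.categorical(subkey, logits[-1])  # sample last position
        tokens = jnp.concatenate([tokens, next_tok[None]])
    return tokens
```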
Example 2: Training on a names dataset
- Names generated:

  ```
  breesmin
  kindley
  britlynn
  iassiqu
  floreca
  catle
  peava
  theania
  yaanira
  hanvaron
  neasea
  orrell
  cinjaya
  daraz
  quaniella
  blaver
  jushadomie
  treymeca
  shayana
  julene
  share
  istofano
  nariah
  mckinda
  iskia
  kashir
  arhiyah
  letena
  estina
  asamie
  azur
  aralie
  stara
  keiliy
  yyena
  josebe
  breya
  sinna
  seondro
  ramek
  zariah
  shakerra
  fairah
  demariy
  siara
  clanda
  saprion
  naimah
  shety
  marlon
  cailla
  ```
References:
- https://github.com/karpathy/ng-video-lecture
- Attention Is All You Need: https://arxiv.org/abs/1706.03762v7

