moaziat/jaxGPT

jaxGPT

jaxGPT - A Minimal GPT implementation in JAX

A lightweight implementation of a GPT-like transformer model with 1.6M parameters (for now!!) using JAX and Flax's NNX module. This code demonstrates autoregressive text generation, multi-head attention, and a training loop with Optax. The code handles training on multiple devices automatically through positional sharding.
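To make the multi-head attention mentioned above concrete, here is a minimal sketch of causal self-attention written directly in `jax.numpy`. It is not taken from this repository: the function name `causal_self_attention`, the weight layout (`w_qkv`, `w_out`), and the toy sizes are all assumptions chosen for illustration, and the actual model is built with Flax NNX modules rather than raw arrays.

```python
import jax
import jax.numpy as jnp


def causal_self_attention(x, w_qkv, w_out, n_heads):
    """Illustrative multi-head causal self-attention (not the repo's code).

    x:     (T, C) sequence of T token embeddings with C channels
    w_qkv: (C, 3C) joint projection for queries, keys, and values
    w_out: (C, C) output projection
    """
    T, C = x.shape
    head_dim = C // n_heads

    # Project once, then split into queries, keys, and values: each (T, C).
    q, k, v = jnp.split(x @ w_qkv, 3, axis=-1)

    # Split channels into heads: (n_heads, T, head_dim).
    q, k, v = (
        a.reshape(T, n_heads, head_dim).transpose(1, 0, 2) for a in (q, k, v)
    )

    # Scaled dot-product scores per head: (n_heads, T, T).
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(head_dim)

    # Causal mask: position t may only attend to positions <= t.
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)

    weights = jax.nn.softmax(scores, axis=-1)

    # Merge heads back into (T, C) and apply the output projection.
    out = (weights @ v).transpose(1, 0, 2).reshape(T, C)
    return out @ w_out


# Toy usage with random weights (sizes are arbitrary assumptions).
T, C, H = 8, 32, 4
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (T, C))
w_qkv = jax.random.normal(k2, (C, 3 * C)) * 0.02
w_out = jax.random.normal(k3, (C, C)) * 0.02
y = causal_self_attention(x, w_qkv, w_out, H)
```

The causal mask is what makes the model autoregressive: changing a later token cannot affect the outputs at earlier positions.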

Installation

  1. Ensure JAX is installed (follow the official JAX installation instructions for your hardware).
  2. Install dependencies:
pip install flax optax

Run it

Clone the repo, then run:

python train.py --max_iters 200000 --lr 1e-3
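For readers wondering how the two flags in the command above might be consumed, the following is a hypothetical sketch using Python's standard `argparse`. Only the flag names `--max_iters` and `--lr` come from the command itself. The helper `parse_args`, the default values, and the help strings are assumptions and may not match what `train.py` actually does.

```python
import argparse


def parse_args(argv=None):
    """Parse the training flags shown in the command (illustrative only)."""
    parser = argparse.ArgumentParser(description="jaxGPT training (sketch)")
    # Defaults below are placeholders, not the repository's real defaults.
    parser.add_argument("--max_iters", type=int, default=5000,
                        help="number of optimizer steps to run")
    parser.add_argument("--lr", type=float, default=3e-4,
                        help="learning rate passed to the optimizer")
    return parser.parse_args(argv)


# Equivalent to: python train.py --max_iters 200000 --lr 1e-3
args = parse_args(["--max_iters", "200000", "--lr", "1e-3"])
print(args.max_iters, args.lr)  # prints: 200000 0.001
```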

Output

Example 1: Training on the tiny-shakespeare dataset:

  • generated output:

KING RIVH:
It genton. I was the grief
Forth the time of is offected: how God noug-leakes?
Your wift quany at stand pass--
Stoation incuttection to-danius shorate anginganeness,
Thy switn's roveried skily before in grief
Would thee stoble Rome of find,
Mach thing bound you will all, and Clifford!

FLORIZEL:
Who should had thanks you
Formit gaters thee gods, and swear be
Intives that thither-buney, we heart this grule commends
By the ritory striken time of their
conselds, let where I will becalle
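Text like the sample above comes from autoregressive generation: the model predicts a distribution over the next token, one token is sampled, and it is appended to the context before the next step. The sketch below shows that loop in JAX under stated assumptions. `dummy_logits` is a hypothetical stand-in for the trained model, and `VOCAB_SIZE`, `BLOCK_SIZE`, and `generate` are illustrative names, not identifiers from this repository.

```python
import jax
import jax.numpy as jnp

VOCAB_SIZE = 65  # assumption: size of a small character vocabulary
BLOCK_SIZE = 32  # assumption: maximum context length fed to the model


def dummy_logits(tokens):
    """Hypothetical stand-in for the model: next-token logits, (VOCAB_SIZE,)."""
    # Strongly prefers "previous token + 1" so the example has some structure.
    return jnp.zeros(VOCAB_SIZE).at[(tokens[-1] + 1) % VOCAB_SIZE].set(5.0)


def generate(logits_fn, prompt, n_new, key, temperature=1.0):
    """Sample n_new tokens one at a time, feeding each back as context."""
    tokens = list(prompt)
    for _ in range(n_new):
        key, subkey = jax.random.split(key)
        # Crop the context to the last BLOCK_SIZE tokens.
        context = jnp.array(tokens[-BLOCK_SIZE:])
        logits = logits_fn(context) / temperature
        # Draw the next token from the softmax distribution over logits.
        next_token = jax.random.categorical(subkey, logits)
        tokens.append(int(next_token))
    return tokens


sample = generate(dummy_logits, [0], n_new=20, key=jax.random.PRNGKey(0))
```

Lower `temperature` values sharpen the distribution and make samples more repetitive, while higher values make them more varied.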

Example 2: Training on a names dataset

  • names generated:
breesmin
kindley
britlynn
iassiqu
floreca
catle
peava
theania
yaanira
hanvaron
neasea
orrell
cinjaya
daraz
quaniella
blaver
jushadomie
treymeca
shayana
julene
share
istofano
nariah
mckinda
iskia
kashir
arhiyah
letena
estina
asamie
azur
aralie
stara
keiliy
yyena
josebe
breya
sinna
seondro
ramek
zariah
shakerra
fairah
demariy
siara
clanda
saprion
naimah
shety
marlon
cailla
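The generated names suggest a character-level model, although the README does not state how the text is tokenized. Under that assumption, a minimal sketch of a character vocabulary with encode and decode helpers might look like the following. The names `stoi`, `itos`, `encode`, and `decode`, and the use of `"\n"` as an end-of-name marker, are illustrative choices, not the repository's.

```python
# A few names from the list above, used as a tiny stand-in dataset.
names = ["breesmin", "kindley", "britlynn"]

# Build the vocabulary from every distinct character in the data.
chars = sorted(set("".join(names)))

# Reserve id 0 for an end-of-name marker (an assumption for illustration).
stoi = {ch: i + 1 for i, ch in enumerate(chars)}
stoi["\n"] = 0
itos = {i: ch for ch, i in stoi.items()}


def encode(text):
    """Map a string to a list of integer token ids."""
    return [stoi[ch] for ch in text]


def decode(ids):
    """Map a list of integer token ids back to a string."""
    return "".join(itos[i] for i in ids)


ids = encode("kindley\n")
restored = decode(ids)
```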
