Gautier Hamon
@hamongautier.bsky.social
PhD student at INRIA Flowers team. MVA master
reytuag.github.io/gautier-hamon/
8/ For the curious, here are the achievement success rates on Craftax across training, when training for 1e9 steps (left) and for 4e9 steps (right).
November 22, 2024 at 10:16 AM
7/ The JAX ecosystem in RL is currently blooming with wonderful open-source projects from others, which I linked at the bottom of the repository. github.com/Reytuag/tran...
This work was done at @FlowersINRIA.
Also feel free to reach out to me if you have questions or suggestions!
GitHub - Reytuag/transformerXL_PPO_JAX
github.com
November 22, 2024 at 10:16 AM
6/ Potential next steps could be to test it on XLand-MiniGrid, an open-ended meta-RL environment: github.com/dunnolab/xla...
I'm also curious to implement Muesli (arxiv.org/abs/2104.06159) with TransformerXL, as in arxiv.org/abs/2301.07608
November 22, 2024 at 10:16 AM
5/ Here is the training curve obtained from training for 1e9 steps, alongside the PPO and PPO-RNN scores reported in the Craftax repo.
Note that PPO-RNN was already beating other baselines that use Unsupervised Environment Design and intrinsic motivation. arxiv.org/pdf/2402.16801
November 22, 2024 at 10:16 AM
4/ Testing it on the challenging Craftax from github.com/MichaelTMatt... (with little hyperparameter tuning), it obtained higher returns in 1e9 steps than PPO-RNN.
Training it for longer led to reaching the 3rd floor in Craftax, making it the first method to obtain advanced achievements.
GitHub - MichaelTMatthews/Craftax: (Crafter + NetHack) in JAX. ICML 2024 Spotlight.
github.com
November 22, 2024 at 10:16 AM
3/ Training a 3M-parameter Transformer for 1e6 steps in MemoryChain-bsuite (from gymnax) takes 10s on an A100 (with 512 envs).
Training a 5M-parameter Transformer for 1e9 steps in Craftax takes ~6h on a single A100 (with 1024 envs).
We also support multi-GPU training.
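To give a rough idea of where this throughput comes from, here is a minimal sketch (illustrative only, not the repo's actual code) of stepping many gymnax environments in parallel with jax.vmap inside a jitted function, using MemoryChain-bsuite and random actions. The 512/1024-env figures above rely on this kind of vectorisation; multi-GPU training replicates the same computation across devices (e.g. with jax.pmap).

```python
import jax
import jax.numpy as jnp
import gymnax

NUM_ENVS = 512  # matches the MemoryChain-bsuite setting above

env, env_params = gymnax.make("MemoryChain-bsuite")

rng = jax.random.PRNGKey(0)
rng, key_reset = jax.random.split(rng)
reset_keys = jax.random.split(key_reset, NUM_ENVS)

# Vectorised reset: a single call initialises all environments at once.
obs, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_keys, env_params)


@jax.jit
def step_all(rng, env_state):
    """One synchronous step of every environment (random actions, just for illustration)."""
    rng, key_act, key_step = jax.random.split(rng, 3)
    actions = jax.random.randint(
        key_act, (NUM_ENVS,), 0, env.action_space(env_params).n)
    step_keys = jax.random.split(key_step, NUM_ENVS)
    obs, env_state, reward, done, _ = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
        step_keys, env_state, actions, env_params)
    return rng, obs, env_state, reward, done


rng, obs, env_state, reward, done = step_all(rng, env_state)
```

In PureJaxRL-style training code the policy picks the actions and jax.lax.scan unrolls the rollout, but the vectorisation pattern is the same.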
November 22, 2024 at 10:16 AM
2/ We implement TransformerXL-PPO following "Stabilizing Transformers for Reinforcement Learning" arxiv.org/abs/1910.06764
The code follows the template from PureJaxRL github.com/luchris429/p...
⚡️Training is fast thanks to JAX
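To make the architecture concrete, here is a minimal Flax sketch of a GRU-gated TransformerXL block in the spirit of that paper. This is my own simplification, not the code from the repository: the names GRUGate and GatedTransformerXLBlock are illustrative, and relative positional encodings as well as the identity-favouring gate-bias initialisation are omitted.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class GRUGate(nn.Module):
    """GRU-style gate used in place of a residual connection (as in GTrXL).

    The paper also initialises the gate bias so the block starts close to the
    identity map; that detail is omitted in this sketch."""
    d_model: int

    @nn.compact
    def __call__(self, x, y):
        # x: the skip stream, y: the sublayer output
        r = nn.sigmoid(nn.Dense(self.d_model)(y) + nn.Dense(self.d_model, use_bias=False)(x))
        z = nn.sigmoid(nn.Dense(self.d_model)(y) + nn.Dense(self.d_model, use_bias=False)(x))
        h = jnp.tanh(nn.Dense(self.d_model)(y) + nn.Dense(self.d_model, use_bias=False)(r * x))
        return (1.0 - z) * x + z * h


class GatedTransformerXLBlock(nn.Module):
    """One block: pre-LayerNorm self-attention over [memory ; segment], then an MLP,
    each followed by a GRU gate instead of a plain residual. TransformerXL's relative
    positional encodings are omitted for brevity."""
    d_model: int
    num_heads: int

    @nn.compact
    def __call__(self, x, memory):
        # x:      (batch, seg_len, d_model)  current segment
        # memory: (batch, mem_len, d_model)  activations cached from past segments
        seg_len, mem_len = x.shape[1], memory.shape[1]
        kv = jnp.concatenate([memory, x], axis=1)

        # Causal mask: step i attends to all of memory and to segment steps <= i.
        causal = jnp.tril(jnp.ones((seg_len, seg_len), dtype=bool))
        mask = jnp.concatenate([jnp.ones((seg_len, mem_len), dtype=bool), causal], axis=1)
        mask = mask[None, None]  # broadcast over (batch, heads)

        attn_out = nn.MultiHeadDotProductAttention(num_heads=self.num_heads)(
            nn.LayerNorm()(x), nn.LayerNorm()(kv), mask=mask)
        x = GRUGate(self.d_model)(x, nn.relu(attn_out))

        mlp_out = nn.Dense(self.d_model)(nn.relu(nn.Dense(4 * self.d_model)(nn.LayerNorm()(x))))
        x = GRUGate(self.d_model)(x, nn.relu(mlp_out))

        # The (detached) output becomes the memory for the next segment.
        return x, jax.lax.stop_gradient(x)


# Usage sketch: batch of 8, segment of 16 steps, memory of 32 cached steps.
block = GatedTransformerXLBlock(d_model=64, num_heads=4)
x = jnp.zeros((8, 16, 64))
mem = jnp.zeros((8, 32, 64))
params = block.init(jax.random.PRNGKey(0), x, mem)
out, new_mem = block.apply(params, x, mem)  # out: (8, 16, 64)
```

The gating in place of plain residuals is what the paper found stabilises on-policy training, and the detached memory lets each step attend over activations from previous segments without backpropagating through them.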
Stabilizing Transformers for Reinforcement Learning
arxiv.org
November 22, 2024 at 10:16 AM
The video encoding might not do it full justice.
Paper: direct.mit.edu/isal/proceed...
November 22, 2024 at 10:04 AM