Nikunj Saunshi
nsaunshi.bsky.social
AI Reasoning and Foundations
Senior Research Scientist, Google | PhD, Princeton University
It has been fascinating to uncover this surprising inductive bias of stacking, beyond its efficiency benefits. Hopefully, understanding it will provide insights into improving the reasoning abilities of models. This is in collaboration with the SWEFF/BigML team at Google.

Paper link: arxiv.org/abs/2409.19044

10/10
[Link preview: On the Inductive Bias of Stacking Towards Improving Reasoning, arxiv.org]
March 10, 2025 at 3:41 PM
Why this inductive bias? We don't have a full picture yet. As an initial step, we uncover a connection between stacking and looped models (like ALBERT and Universal Transformers), whose iterative structure could be conducive to reasoning. (More on this in a follow-up paper.) 9/n
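For intuition, a looped model applies one shared block repeatedly (weight tying across depth), as in ALBERT and Universal Transformers. Below is a minimal PyTorch-style sketch of that iterative structure; the class name and setup are illustrative assumptions, not the paper's exact configuration.

```python
import torch.nn as nn


class LoopedEncoder(nn.Module):
    """Reuse a single shared block for a fixed number of iterations.

    Weight tying across depth (as in ALBERT / Universal Transformers) gives the
    iterative structure that stacking appears to be biased towards.
    """

    def __init__(self, block: nn.Module, num_loops: int):
        super().__init__()
        self.block = block          # one set of weights shared across depth
        self.num_loops = num_loops  # effective depth = num_loops

    def forward(self, x):
        for _ in range(self.num_loops):
            x = self.block(x)       # same weights applied at every "layer"
        return x
```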
Are these improvements real? For a more robust verification, we construct reasoning primitives – simple synthetic tasks that are building blocks for reasoning – and find that MIDAS is significantly better on these primitives, with gains of up to +20% absolute accuracy on some of them. 8/n
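For a sense of what such a primitive could look like, here is a toy variable-resolution task in Python. This is a hypothetical illustration invented for this sketch; the primitives used in the paper may differ.

```python
import random


def variable_resolution_example(num_vars: int = 4) -> tuple[str, str]:
    """Toy reasoning primitive (hypothetical): follow a chain of assignments.

    Produces prompts like "a=7; b=a; c=b; d=c; what is d?" with answer "7",
    which requires resolving a reference through the given context rather
    than recalling memorized facts.
    """
    names = random.sample("abcdefghijklmnopqrstuvwxyz", num_vars)
    value = random.randint(0, 9)
    lines = [f"{names[0]}={value}"]
    for i in range(1, num_vars):
        lines.append(f"{names[i]}={names[i - 1]}")
    prompt = "; ".join(lines) + f"; what is {names[-1]}?"
    return prompt, str(value)
```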
Furthermore, for various math word problems (e.g., SVAMP, ASDiv, MAWPS and GSM8K), MIDAS significantly improves few-shot performance over the baseline at similar perplexity. These gains persist (or improve) after fine-tuning on GSM8K. 7/n
This is most apparent on the TyDi QA benchmark: for the same set of questions, MIDAS shows much larger improvements on the GoldP variant, where context is provided for each question, than on the NoContext variant. This clearly points to a bias towards contextual reasoning. 6/n
We plot perplexity vs. downstream evals as training proceeds, for various downstream categories. Surprisingly, we find that MIDAS improves the most on tasks that require reasoning from context (open-book QA, math problems), as opposed to memorization (closed-book QA). 5/n
Hypothesis: this is an "inductive bias" of stacking, à la Saunshi et al. and Liu et al. – good pretraining performance can be achieved in many ways, and some methods have a bias towards learning skills that transfer better to downstream tasks. So what bias does stacking have? 4/n
Firstly, when trained on the same tokens as the baseline, MIDAS achieves similar perplexity with ~25% fewer FLOPs and less wall-clock time. However, at the same perplexity, MIDAS also significantly improves downstream evals on a variety of tasks (hence the ~40% speedup)! How is that possible? 3/n
Stacking (Gong et al., Reddi et al.) gradually grows model depth in stages by duplicating some layers at each stage, and has been shown to reduce training time for BERT. We show that a variant, MIDAS, can also make language model training ~40% more efficient. What does this efficiency mean? 2/n
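To make the growth step concrete, here is a minimal PyTorch-style sketch of stage-wise depth growth by layer duplication. The function name, the hypothetical schedule in the comments, and the choice of which layers to copy are illustrative assumptions, not the exact MIDAS recipe.

```python
import copy

import torch.nn as nn


def grow_depth_by_stacking(layers: nn.ModuleList, num_new: int) -> nn.ModuleList:
    """Return a deeper stack by duplicating the last `num_new` trained layers.

    Illustrative gradual-stacking step: new layers are initialized as copies of
    already-trained ones, and training then continues on the deeper model.
    Which layers are copied (top, middle, ...) is a design choice that differs
    across stacking variants; this is not the exact MIDAS schedule.
    """
    copies = [copy.deepcopy(layer) for layer in layers[-num_new:]]
    return nn.ModuleList(list(layers) + copies)


# Hypothetical 3-stage schedule: 6 -> 12 -> 24 layers, training between stages.
# model.layers = grow_depth_by_stacking(model.layers, num_new=6)   # after stage 1
# model.layers = grow_depth_by_stacking(model.layers, num_new=12)  # after stage 2
```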