Lucas Prieto
@lucas-prieto.bsky.social
PhD Student at Imperial
Thank you to my co-authors Tolga Birdal, Pedro Mediano
and Melih Barsbey for all their help!
January 10, 2025 at 3:53 PM
This simple intervention removes the delay in generalization characteristic of grokking, leading to train and test performance increasing in tandem!
January 10, 2025 at 3:52 PM
Since the gradient does not fully align with the NLM direction, we tried ignoring the component of the gradient that aligns with the direction of the weights, preserving only the part of the gradient that is orthogonal to the NLM direction.
January 10, 2025 at 3:52 PM
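A minimal PyTorch sketch of this idea, assuming the projection is applied per parameter tensor (the paper's implementation may group the weights differently):

```python
import torch

@torch.no_grad()
def project_out_weight_direction(model, eps=1e-12):
    """Keep only the part of each gradient that is orthogonal to the
    corresponding weights, discarding the component that merely rescales
    them. Sketch only: the projection is done per parameter tensor here."""
    for p in model.parameters():
        if p.grad is None:
            continue
        w = p.detach().flatten()
        g = p.grad.flatten()
        coeff = torch.dot(g, w) / (torch.dot(w, w) + eps)  # component along w
        p.grad.copy_((g - coeff * w).view_as(p.grad))

# Typical placement in a training loop:
# loss.backward()
# project_out_weight_direction(model)
# optimizer.step()
```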
For homogeneous models, naïvely scaling the logits corresponds to scaling the weights along their current direction, and we observe that after fitting the training data, MLP and transformer gradients do align with the direction of the weights.
January 10, 2025 at 3:52 PM
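One way to probe this alignment is a cosine-similarity diagnostic between the flattened gradient and the flattened weights (a sketch of my own, not the paper's measurement code):

```python
import torch

def grad_weight_alignment(model):
    """Cosine similarity between the flattened gradient and the flattened
    weights. A magnitude near 1 means the gradient lies along the weight
    direction, so the update mostly just rescales the weights (NLM)."""
    params = [p for p in model.parameters() if p.grad is not None]
    w = torch.cat([p.detach().flatten() for p in params])
    g = torch.cat([p.grad.flatten() for p in params])
    return torch.nn.functional.cosine_similarity(g, w, dim=0).item()

# Call once per step after loss.backward(); the thread reports this
# alignment emerging once the training data has been fit.
```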
Why do tasks like modular addition lead to SC? We find ease of overfitting to be a crucial aspect. Beyond the point of 100% training accuracy, models can decrease the CE loss by scaling the logits without changing the decision boundary. We call this Naïve Loss Minimization (NLM).
January 10, 2025 at 3:52 PM
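A toy illustration of NLM with made-up logits: once every sample is classified correctly, multiplying the logits by a scalar leaves the predictions untouched but keeps driving the cross-entropy down.

```python
import torch
import torch.nn.functional as F

# Two already-correctly-classified samples (targets 0 and 1).
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.0, 1.5, -0.5]])
targets = torch.tensor([0, 1])

for alpha in (1.0, 2.0, 10.0, 100.0):
    scaled = alpha * logits                      # same argmax, same accuracy
    print(alpha, F.cross_entropy(scaled, targets).item())
# The loss shrinks toward 0 as alpha grows even though the decision
# boundary never moves -- scaling up is the "easy" way to reduce CE.
```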
This simple intervention prevents SC and leads to grokking without weight decay, with weight norms increasing during generalization.
January 10, 2025 at 3:52 PM
To validate this hypothesis, we replace the exponential in the Softmax with a softer ramp function s(x).
January 10, 2025 at 3:52 PM
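For intuition, here is one ramp with the right shape (linear growth for non-negative inputs, slow decay toward zero for negative ones); treat the exact form of s(x) below as illustrative rather than the paper's definition:

```python
import torch

def soft_ramp(x):
    # Positive, monotone, and far slower-growing than exp:
    # x + 1 for x >= 0, and 1 / (1 - x) for x < 0. (Illustrative choice.)
    return torch.where(x >= 0, x + 1.0, 1.0 / (1.0 - x))

def ramp_softmax(logits, dim=-1):
    s = soft_ramp(logits)
    return s / s.sum(dim=dim, keepdim=True)

z = torch.tensor([40.0, 0.0, -4.0])          # large, confident logits
p_exp = torch.softmax(z, dim=-1)
p_ramp = ramp_softmax(z)
print(p_exp[0].item(), -torch.log(p_exp[0]).item())    # 1.0, 0.0 -> signal gone
print(p_ramp[0].item(), -torch.log(p_ramp[0]).item())  # <1,  >0 -> still learning
```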
Looking at modular addition, we notice that models can start grokking without weight decay but generalization stops suddenly when a large fraction of the samples face SC. As expected, this point is delayed when we increase floating point precision.
January 10, 2025 at 3:52 PM
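A standalone way to see the precision effect (my own probe, not the paper's experiment): find the smallest logit gap at which exp of the smaller logit is absorbed when added to exp of the larger one.

```python
import numpy as np

def absorbed(gap, dtype):
    """True once exp(gap) + exp(0) rounds back to exp(gap) in this precision."""
    e = np.exp(np.array([gap, 0.0], dtype=dtype))
    return e[0] + e[1] == e[0]

for dtype in (np.float16, np.float32, np.float64):
    gap = next(g for g in range(1, 100) if absorbed(float(g), dtype))
    print(dtype.__name__, gap)
# Roughly 8 for float16, 17 for float32, 37 for float64: higher precision
# needs a much larger margin, so the point where learning stalls is delayed.
```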
These "absorption" errors, lead to zeros in the loss and the gradient, putting an end to learning!
January 10, 2025 at 3:52 PM
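To make the absorption error concrete, a minimal NumPy sketch with made-up logits (three classes, correct class far ahead, as happens once logits keep growing after the training set is fit):

```python
import numpy as np

z = np.array([32.0, 0.0, -4.0], dtype=np.float32)   # correct class = index 0

exp_z = np.exp(z)
print(exp_z.sum() == exp_z[0])    # True: the smaller terms are absorbed by the largest

p = exp_z / exp_z.sum()
print(-np.log(p[0]))              # 0.0: p[0] is exactly 1, so the CE loss is exactly 0

grad = p - np.array([1.0, 0.0, 0.0], dtype=np.float32)   # dCE/dlogits = p - one_hot
print(grad)                       # [0., ~1e-14, ~2e-16]: exactly zero for the correct
                                  # class, and far below float32 resolution elsewhere
```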
Why do we not see grokking by default when using cross-entropy loss? The surprising answer is a specific kind of floating point error in the Softmax we call Softmax Collapse (SC)!
January 10, 2025 at 3:52 PM