Case study of Grokking via Mechanistic Interpretability
Notes from YT video: The most complex model we actually understand - Welch Labs - https://www.youtube.com/watch?v=D8GOeCFFby4
The video presents a paper [pdf:Progress measures for Grokking via Mechanistic Interpretability] that investigates how grokking occurs. It takes the problem of computing (x + y) mod 113 and a network with one attention layer followed by a two-layer MLP. The numbers are fed in as tokens (one-hot encoded).
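A minimal sketch of this setup, assuming PyTorch; the class name TinyGrokNet and the widths/head count are illustrative choices, not values taken from the video or the paper:

```python
import torch
import torch.nn as nn

P = 113  # modulus of the task: predict (x + y) mod 113

# All (x, y) pairs as token ids; a learned embedding stands in for the one-hot encoding.
xs, ys = torch.meshgrid(torch.arange(P), torch.arange(P), indexing="ij")
tokens = torch.stack([xs.flatten(), ys.flatten()], dim=1)   # shape (P*P, 2)
targets = (tokens[:, 0] + tokens[:, 1]) % P                 # shape (P*P,)

class TinyGrokNet(nn.Module):
    """One attention layer followed by a two-layer MLP; sizes are illustrative."""
    def __init__(self, d_model=128, d_mlp=512, n_heads=4):
        super().__init__()
        self.embed = nn.Embedding(P, d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_mlp), nn.ReLU(), nn.Linear(d_mlp, d_model)
        )
        self.unembed = nn.Linear(d_model, P)

    def forward(self, tok):                   # tok: (batch, 2) token ids
        h = self.embed(tok)                   # (batch, 2, d_model)
        a, _ = self.attn(h, h, h)             # self-attention over the two input tokens
        h = h + a
        h = h + self.mlp(h)
        return self.unembed(h[:, -1])         # read the prediction off the last position

model = TinyGrokNet()
print(model(tokens[:16]).shape)               # torch.Size([16, 113])
```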
It turns out that initially the model memorizes the training data. But after a long period of training,
- the attention layer learns the sines and cosines of the inputs (sin x, cos x, sin y, cos y),
- the first MLP layer learns the products of those sines and cosines (e.g. sin x cos y),
- the second layer learns sums of those products (cos x cos y - sin x sin y = cos(x + y)), effectively doing modular addition,
- and the final layer undoes the sines and cosines to read off the result (see the numerical sketch below).
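A small NumPy check of this mechanism; the frequency k and the example inputs x, y are arbitrary picks, and the real model spreads the computation over several learned frequencies:

```python
import numpy as np

P = 113
k = 17                       # one arbitrary frequency; the trained model uses a handful
w = 2 * np.pi * k / P
x, y = 40, 99                # example inputs; (40 + 99) % 113 == 26

# Step 1: represent each input by its sine and cosine.
sx, cx = np.sin(w * x), np.cos(w * x)
sy, cy = np.sin(w * y), np.cos(w * y)

# Steps 2-3: products of sines/cosines, combined via the angle-addition identities,
# give cos(w(x+y)) and sin(w(x+y)) -- modular addition done in phase space.
cos_sum = cx * cy - sx * sy
sin_sum = sx * cy + cx * sy
assert np.isclose(cos_sum, np.cos(w * (x + y)))

# Step 4: "undo" the trig encoding by scoring every candidate answer z with
# cos(w(x + y - z)); the score peaks exactly at z = (x + y) mod P.
z = np.arange(P)
logits = cos_sum * np.cos(w * z) + sin_sum * np.sin(w * z)
print(z[np.argmax(logits)], (x + y) % P)     # both print 26
```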
Weight decay moves the model towards the sine-and-cosine solution instead of letting it rest at the memorized solution. The memorized solution uses a lot of weights, while the sine-and-cosine-based solution is cleaner and requires fewer active parameters.
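A rough illustration of the quantity weight decay acts on, assuming the TinyGrokNet sketch above; the optimizer settings (lr, weight_decay) are assumed values, not quoted in the video:

```python
import torch

# Assumed training configuration (values are illustrative).
model = TinyGrokNet()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)

# Total L2 norm of the weights: the memorizing solution keeps this large, while the
# sine/cosine circuit reaches low training loss at a much smaller norm, so the
# weight-decay term steadily favors the circuit over the lookup table.
weight_norm = sum(p.pow(2).sum() for p in model.parameters()).sqrt()
print(weight_norm.item())
```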
It turns out that training can be split into three phases (see the training sketch after this list):
- memorization of the training data;
- circuit formation, where the network learns a mechanism that generalizes;
- cleanup, where weight decay removes the memorization components.
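A hedged training-loop sketch that would expose these phases by logging train accuracy, test accuracy, and total weight norm; it reuses the TinyGrokNet sketch above, and the 30% train split, learning rate, weight decay, and step count are all assumptions:

```python
import torch

P = 113
xs, ys = torch.meshgrid(torch.arange(P), torch.arange(P), indexing="ij")
tokens = torch.stack([xs.flatten(), ys.flatten()], dim=1)
targets = (tokens[:, 0] + tokens[:, 1]) % P

# Random train/test split; training on only a fraction of all pairs is what makes
# memorization easy and generalization non-trivial (the 30% here is an assumption).
perm = torch.randperm(P * P)
n_train = int(0.3 * P * P)
train_idx, test_idx = perm[:n_train], perm[n_train:]

model = TinyGrokNet()                        # sketch model defined earlier
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
loss_fn = torch.nn.CrossEntropyLoss()

def accuracy(idx):
    with torch.no_grad():
        return (model(tokens[idx]).argmax(-1) == targets[idx]).float().mean().item()

for step in range(30000):                    # grokking needs many full-batch steps
    opt.zero_grad()
    loss = loss_fn(model(tokens[train_idx]), targets[train_idx])
    loss.backward()
    opt.step()
    if step % 1000 == 0:
        norm = sum(p.pow(2).sum() for p in model.parameters()).sqrt().item()
        # Expected pattern: train accuracy saturates early (memorization), test
        # accuracy climbs much later (circuit formation), and the weight norm
        # falls as weight decay strips out the memorizing components (cleanup).
        print(step, accuracy(train_idx), accuracy(test_idx), norm)
```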
Grokking happens during the cleanup phase. Thus grokking, rather than being a sudden shift, arises from the gradual amplification of structured mechanisms encoded in the weights, followed by the later removal of memorizing components.