2026-04-22

LLM Pre-Training


In the era of LLMs, attention has shifted from designing novel architectures to systematic scaling, massive-scale data engineering, and optimized distributed training on large clusters.

1. Optimal ratio of Tokens to Parameters

1:1 Parameter Dominance:

In 2020, Jared Kaplan et al. at OpenAI established the Kaplan scaling laws. For a fixed compute budget \(C\), the optimal number of parameters scales as \(N \propto C^{0.73}\) and the optimal number of training tokens as \(D \propto C^{0.27}\). This implied that parameters should be increased faster than tokens, and led to a ratio of roughly \(D:N \approx 1:1\) to \(2:1\). However, this analysis was based on small models and did not count token embedding parameters as model parameters.
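To make the exponents concrete, here is a small sketch of how much the optimal parameter and token counts grow when the compute budget is multiplied by some factor. The function name `allocation_growth` is hypothetical, made up for this example.

```python
def allocation_growth(compute_factor, n_exponent, d_exponent):
    """If N ∝ C^a and D ∝ C^b, scaling C by k scales N by k**a and D by k**b."""
    return compute_factor ** n_exponent, compute_factor ** d_exponent


# Kaplan exponents: N ∝ C^0.73, D ∝ C^0.27
n_growth, d_growth = allocation_growth(10, 0.73, 0.27)
print(f"10x compute -> {n_growth:.2f}x parameters, {d_growth:.2f}x tokens")
# 10x compute -> 5.37x parameters, 1.86x tokens
```

The two exponents add up to 1 because training compute is commonly approximated as \(C \approx 6ND\) (a standard approximation, assumed here rather than stated above), so the growth factors of N and D always multiply back to the compute factor.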

20:1 Chinchilla Paradigm:

In 2022, the "Chinchilla" study by DeepMind found a different relation: \(N \propto C^{0.5}\) and \(D \propto C^{0.5}\). This means that data and model parameters should be scaled at the same rate, which leads to a ratio of \(D:N \approx 20:1\), i.e. the optimal number of training tokens is about 20 times the number of model parameters. By this measure, earlier models like GPT-3 were undertrained.
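A minimal sketch of turning a FLOP budget into a compute-optimal (N, D) pair, assuming the common approximation \(C \approx 6ND\) for training compute (not stated above) together with the 20:1 ratio. The function name is hypothetical. Plugging in Chinchilla's own configuration (70B parameters, 1.4T tokens) recovers those numbers.

```python
import math


def chinchilla_optimal(compute_flops: float, tokens_per_param: float = 20.0):
    """Split a training FLOP budget C into (N, D) with C ≈ 6·N·D and D = tokens_per_param·N."""
    # C = 6 * N * (r * N) = 6 * r * N^2  =>  N = sqrt(C / (6 * r))
    n_params = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    return n_params, tokens_per_param * n_params


budget = 6 * 70e9 * 1.4e12  # Chinchilla's configuration: 70B params, 1.4T tokens
n, d = chinchilla_optimal(budget)
print(f"N ≈ {n / 1e9:.0f}B parameters, D ≈ {d / 1e12:.1f}T tokens")
# N ≈ 70B parameters, D ≈ 1.4T tokens
```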

Inference-Aware Scaling:

Both analyses take into account only the training compute. If the inference cost of the model is also considered, it is better to train a smaller model on a much greater number of tokens. For example, Llama 3 8B was trained on about 15 trillion tokens, a ratio of \(D:N \approx 1875:1\). This way we get a model that performs closer to a much larger system while remaining cheaper to serve.
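The trade-off can be checked with back-of-the-envelope arithmetic. This sketch assumes \(C \approx 6ND\) training FLOPs and roughly \(2N\) FLOPs per generated token at inference (both standard approximations that ignore attention costs); it asks how large a 20:1 model the Llama 3 8B training budget would have bought, and how much cheaper the 8B model is to serve. Function names are hypothetical.

```python
import math


def training_flops(n_params, n_tokens):
    return 6.0 * n_params * n_tokens  # assumed C ≈ 6ND approximation


def inference_flops_per_token(n_params):
    return 2.0 * n_params  # forward pass only; attention over the context is ignored


n_small, d_small = 8e9, 15e12               # Llama 3 8B figures quoted above
c = training_flops(n_small, d_small)
n_chin = math.sqrt(c / (6.0 * 20.0))        # 20:1 compute-optimal size for the same budget
ratio = d_small / n_small
serve_saving = inference_flops_per_token(n_chin) / inference_flops_per_token(n_small)
print(f"D:N = {ratio:.0f}:1; compute-optimal size ≈ {n_chin / 1e9:.0f}B; "
      f"8B model needs ≈{serve_saving:.1f}x fewer FLOPs per generated token")
```

Under these assumptions, the same training budget would have produced a roughly 77B-parameter compute-optimal model, which would cost about 9.7x more FLOPs per generated token to serve.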

2. Data Pipeline

With datasets this large, curating the data (deduplication, quality filtering, and data mixing) and tokenizing it also require special attention.

2.1. Tokenization

Common tokenization algorithms include:

  1. BPE (Byte-Pair Encoding)
  2. Unigram
  3. SentencePiece (a more language-agnostic library that works on raw text, treating whitespace as an ordinary symbol, and applies BPE or Unigram)

See: this article for details.
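To make BPE concrete, here is a minimal character-level sketch of its training loop: repeatedly count adjacent symbol pairs across the corpus and merge the most frequent pair into a new symbol. It is a toy; production tokenizers (e.g. byte-level BPE) operate on bytes, pre-tokenize text, and are far more efficient. All function names are hypothetical.

```python
from collections import Counter


def get_pair_counts(words):
    """words maps a tuple of symbols to its corpus frequency."""
    counts = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            counts[(a, b)] += freq
    return counts


def merge_pair(words, pair):
    """Replace every adjacent occurrence of `pair` with one merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged


def train_bpe(corpus, num_merges):
    words = Counter(tuple(w) for w in corpus.split())  # start from single characters
    merges = []
    for _ in range(num_merges):
        counts = get_pair_counts(words)
        if not counts:
            break
        best = max(counts, key=counts.get)  # ties: first pair seen wins
        merges.append(best)
        words = merge_pair(words, best)
    return merges, words


merges, vocab = train_bpe("low low low lower lowest", num_merges=2)
print(merges)  # [('l', 'o'), ('lo', 'w')]
```

The learned merge list is the tokenizer: encoding new text replays the same merges in the same order.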

2.2. Deduplication

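Exact pairwise comparison isn't feasible on datasets this large. Hence the following techniques are used:

  1. Bloom filters for exact duplicates: if the filter reports that a document's hash has not been seen, the document is definitely new and the expensive storage lookup can be skipped. A positive answer may be a false positive, so pipelines either confirm it with a lookup or accept dropping a small fraction of unique documents.
  2. MinHash with LSH (Locality Sensitive Hashing) for near-duplicate detection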
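Both techniques fit in a short, self-contained sketch. The Bloom filter sets k bits per document and answers "maybe seen" or "definitely new"; MinHash keeps, for each of many seeded hash functions, the minimum hash over a document's character shingles, so the fraction of matching slots approximates Jaccard similarity, and LSH banding groups signatures so that only documents sharing a band are compared. The seeded SHA-256 hashing, shingle size, and band count are assumptions chosen for readability, not production settings, and all names are hypothetical.

```python
import hashlib


def _hash64(seed: int, data: bytes) -> int:
    """Seeded 64-bit hash built from SHA-256 (slow but deterministic; fine for a sketch)."""
    return int.from_bytes(hashlib.sha256(seed.to_bytes(4, "little") + data).digest()[:8], "little")


class BloomFilter:
    """Bit array + k hash functions: no false negatives, occasional false positives."""

    def __init__(self, num_bits: int, num_hashes: int):
        self.num_bits, self.num_hashes = num_bits, num_hashes
        self.bits = bytearray((num_bits + 7) // 8)

    def _positions(self, item: bytes):
        return [_hash64(i, item) % self.num_bits for i in range(self.num_hashes)]

    def add(self, item: bytes) -> None:
        for p in self._positions(item):
            self.bits[p // 8] |= 1 << (p % 8)

    def might_contain(self, item: bytes) -> bool:
        return all(self.bits[p // 8] & (1 << (p % 8)) for p in self._positions(item))


def exact_dedup(docs, bloom):
    kept = []
    for doc in docs:
        key = doc.encode("utf-8")
        if bloom.might_contain(key):  # seen before (or, rarely, a false positive): drop it
            continue
        bloom.add(key)
        kept.append(doc)
    return kept


def shingles(text: str, k: int = 5) -> set:
    """Character k-grams; assumes len(text) >= k."""
    return {text[i:i + k] for i in range(len(text) - k + 1)}


def minhash(shingle_set: set, num_perm: int = 128) -> list:
    """One minimum per seeded hash; P[slot matches] ≈ Jaccard similarity of the sets."""
    return [min(_hash64(seed, s.encode("utf-8")) for s in shingle_set) for seed in range(num_perm)]


def jaccard_estimate(sig_a: list, sig_b: list) -> float:
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)


def lsh_keys(sig: list, bands: int = 32) -> set:
    """Split the signature into bands; documents sharing any band key become candidate pairs."""
    rows = len(sig) // bands
    return {(b, tuple(sig[b * rows:(b + 1) * rows])) for b in range(bands)}


docs = ["the cat sat on the mat", "a dog ran", "the cat sat on the mat"]
print(exact_dedup(docs, BloomFilter(num_bits=1 << 16, num_hashes=4)))
```

With more rows per band, only very similar pairs collide; with more bands, moderately similar pairs are also caught. The band count trades recall against the number of candidate pairs to verify.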

2.3. Quality Filtering

  • Heuristic filtering: Based on word count, symbol-to-word ratio, mean word length, number of bullet points, number of stop words, the KL divergence between the document's token distribution and the distribution expected for natural language, and other heuristics, a document is classified as either gibberish or natural language.

    However, this approach can also remove ~18% of good tokens.

  • Model-based quality filtering:
    1. Use a large, capable model (e.g. Llama 3 70B) to score a subset of documents
    2. Train a small classifier on the resulting scored dataset
    3. Use that fast classifier to filter the full corpus
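A minimal sketch of rule-based filtering using a few of the heuristic signals listed above. The thresholds and the short stop-word list are illustrative assumptions, not tuned values from any particular pipeline, and the KL-divergence check is omitted for brevity.

```python
STOP_WORDS = {"the", "be", "to", "of", "and", "that", "have", "with"}  # assumed, tiny list
BULLETS = ("-", "*", "•")


def passes_heuristics(text: str, min_words: int = 50, max_words: int = 100_000) -> bool:
    """Return True if the document looks like natural language under simple rules."""
    words = text.split()
    n = len(words)
    if not (min_words <= n <= max_words):            # word count
        return False
    mean_len = sum(len(w) for w in words) / n
    if not (3 <= mean_len <= 10):                     # mean word length
        return False
    symbols = text.count("#") + text.count("...")
    if symbols / n > 0.1:                             # symbol-to-word ratio
        return False
    lines = [ln for ln in text.splitlines() if ln.strip()]
    bullet_frac = sum(ln.lstrip().startswith(BULLETS) for ln in lines) / len(lines)
    if bullet_frac > 0.9:                             # mostly bullet points
        return False
    stop_hits = sum(w.lower() in STOP_WORDS for w in words)
    return stop_hits >= 2                             # real prose contains stop words


prose = "the quick brown fox jumps over the lazy dog and then rests " * 5
print(passes_heuristics(prose), passes_heuristics("# " * 60))
```

Each rule is cheap and explainable, which is exactly why such filters are easy to run at scale and also why they misfire on legitimate but unusual text.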

2.4. Data Mixing

Empirically, natural language data is good for fluency, while code and math data are the drivers of reasoning capability. A typical mix (Llama 3.1) is:

  • 50% General Knowledge
  • 25% Math and reasoning
  • 17% Code
  • 8% Multilingual

Some researchers use dynamic mixing: the proportion of code is higher in the earlier phase to build good reasoning, and training then shifts toward more general-knowledge data for fluency.
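Dynamic mixing can be sketched as a schedule over the source weights plus a sampler that picks the source of each training sequence. The two mixes below are hypothetical (more code early, more general text late) and are not the Llama 3.1 recipe; linear interpolation is just one simple choice of schedule.

```python
import random

# Hypothetical early and late mixtures; each sums to 1.0.
EARLY_MIX = {"general": 0.40, "math": 0.25, "code": 0.27, "multilingual": 0.08}
LATE_MIX = {"general": 0.55, "math": 0.25, "code": 0.12, "multilingual": 0.08}


def mix_at(progress: float) -> dict:
    """Linearly interpolate source weights; progress runs from 0.0 (start) to 1.0 (end)."""
    return {k: (1 - progress) * EARLY_MIX[k] + progress * LATE_MIX[k] for k in EARLY_MIX}


def sample_sources(progress: float, n: int, rng: random.Random) -> list:
    """Draw the data source for each of the next n sequences."""
    weights = mix_at(progress)
    names = list(weights)
    return rng.choices(names, weights=[weights[k] for k in names], k=n)


rng = random.Random(0)
print(sample_sources(progress=0.1, n=8, rng=rng))
```

Because both endpoint mixes sum to 1, every interpolated mix does too, so the sampler always sees a valid distribution.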

At the end, an "annealing" phase is run in which the learning rate is decayed to zero and the model is trained on hard math and high-quality reasoning traces.
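One way to express such a schedule: linear warmup, cosine decay to a floor, then a final annealing window in which the learning rate falls linearly to zero (the window where the high-quality math and reasoning data would be used). The shape, the 10% floor, and all parameter and function names are assumptions for illustration.

```python
import math


def lr_at(step, total_steps, peak_lr, warmup_steps, anneal_steps, floor_ratio=0.1):
    """Warmup -> cosine decay to floor_ratio * peak_lr -> linear anneal to 0 at total_steps."""
    floor_lr = peak_lr * floor_ratio
    anneal_start = total_steps - anneal_steps
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    if step < anneal_start:
        progress = (step - warmup_steps) / max(1, anneal_start - warmup_steps)
        return floor_lr + 0.5 * (peak_lr - floor_lr) * (1 + math.cos(math.pi * progress))
    frac = min(1.0, (step - anneal_start) / anneal_steps)
    return floor_lr * (1 - frac)


def in_anneal_phase(step, total_steps, anneal_steps):
    """True when the (hypothetical) high-quality annealing data mix should be used."""
    return step >= total_steps - anneal_steps


for s in (0, 10, 500, 900, 950, 1000):
    print(s, lr_at(s, total_steps=1000, peak_lr=3e-4, warmup_steps=10, anneal_steps=100))
```

The cosine segment ends exactly at the floor value where the anneal segment begins, so the schedule has no jump when the data mix switches.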

3. Distributed Training

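Classically, there are three dimensions of parallelism:

  1. Data Parallelism: every GPU holds a full copy of the model, and each global batch is divided into mini-batches that are processed on different GPUs
  2. Tensor Parallelism: individual weight matrices are sharded across GPUs
  3. Pipeline Parallelism: the model is divided into consecutive stages of layers, and each stage is placed on a different GPU

To keep the data-parallel copies of the parameters in sync, the gradients are all-reduced (implemented as a reduce-scatter followed by an all-gather) before each optimizer step, so every GPU applies the same update.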
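The reduce-scatter + all-gather decomposition can be simulated in a single process with NumPy, with each list entry standing in for one GPU's gradient buffer. Real systems run these collectives over a communication library such as NCCL; this sketch, with hypothetical function names, only shows the data movement.

```python
import numpy as np


def reduce_scatter(grads):
    """grads[w] is worker w's gradient; afterwards worker c owns the sum of chunk c."""
    n = len(grads)
    chunks = [np.array_split(g, n) for g in grads]  # chunks[w][c]
    return [sum(chunks[w][c] for w in range(n)) for c in range(n)]


def all_gather(shards):
    """Every worker receives every shard and concatenates them in order."""
    full = np.concatenate(shards)
    return [full.copy() for _ in shards]


def all_reduce(grads):
    return all_gather(reduce_scatter(grads))


grads = [np.arange(10.0) * (w + 1) for w in range(4)]  # 4 simulated GPUs
synced = all_reduce(grads)
print(synced[0])  # identical on every worker: the elementwise sum
```

In a ring implementation, each worker sends and receives about \((n-1)/n\) of its buffer in each of the two phases, which is where the roughly \(2\Phi\) per-GPU volume for standard data parallelism comes from.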

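Beyond these, there is ZeRO (Zero Redundancy Optimizer), which removes the redundant copies of training state that data parallelism keeps on every GPU:

  1. Stage 1: Optimizer states are sharded across the data-parallel GPUs, and each GPU updates only the parameters covered by its shard of the states. This requires a reduce-scatter of the gradients followed by an all-gather of the updated parameters.
  2. Stage 2: Gradients are also sharded. They are reduce-scattered as soon as they are computed, and each GPU then discards the parts it does not own.
  3. Stage 3: Parameters are also sharded. Each GPU holds only a slice of every layer's parameters, so before a layer runs in the forward and backward passes, its full parameters must be fetched from the GPUs that own them and released afterwards. In addition to the previous stage, this needs an all-gather of the parameters.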

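Comparison of per-GPU communication volume and memory savings (where \(\Phi\) is the number of model parameters; memory savings assume mixed-precision Adam and a large number of data-parallel GPUs):

| Method | Communication | Memory Savings |
|---|---|---|
| Standard data parallelism | \(2\Phi\) | 1x (baseline) |
| ZeRO Stage 1 | \(2\Phi\) | 4x |
| ZeRO Stage 2 | \(2\Phi\) | 8x |
| ZeRO Stage 3 | \(3\Phi\) | Linear in number of GPUs |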
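The memory column can be reproduced with a short calculation, assuming mixed-precision Adam as in the ZeRO paper: 2 bytes each for 16-bit parameters and gradients plus 12 bytes of FP32 optimizer state (master weights, momentum, variance) per parameter, with activations ignored. The helper name is hypothetical.

```python
def zero_memory_gb(num_params: float, num_gpus: int) -> dict:
    """Per-GPU memory (GB) for model states under each ZeRO stage."""
    p, g, opt = 2 * num_params, 2 * num_params, 12 * num_params  # bytes
    total = p + g + opt
    per_stage = {
        "standard": total,                      # everything replicated
        "zero1": p + g + opt / num_gpus,        # shard optimizer states
        "zero2": p + (g + opt) / num_gpus,      # ... and gradients
        "zero3": total / num_gpus,              # ... and parameters
    }
    return {k: v / 1e9 for k, v in per_stage.items()}


for stage, gb in zero_memory_gb(7.5e9, num_gpus=64).items():
    print(f"{stage:>8}: {gb:6.1f} GB")
```

For a 7.5B-parameter model on 64 GPUs this gives about 120, 31.4, 16.6 and 1.9 GB. As the GPU count grows, stages 1 and 2 approach \(4\Phi\) and \(2\Phi\) bytes (4x and 8x below the \(16\Phi\) baseline), while stage 3 keeps shrinking as \(16\Phi/N\).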

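Activation recomputation (checkpointing) is another technique to save memory: only some intermediate activations are stored during the forward pass, and the rest are recomputed during the backward pass, trading extra compute for lower memory.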
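A toy, pure-Python sketch of the idea on a chain of scalar tanh layers: the checkpointed version keeps only the input of every `segment`-th layer during the forward pass and recomputes the rest one segment at a time during the backward pass. Frameworks provide this for real models (e.g. PyTorch's `torch.utils.checkpoint`); the functions here are hypothetical and count stored activations as a stand-in for memory.

```python
import math


def layer(w, x):
    return math.tanh(w * x)


def layer_backward(w, x, upstream):
    """Chain rule through y = tanh(w * x) with respect to x, given the layer input x."""
    y = math.tanh(w * x)
    return upstream * w * (1 - y * y)


def grad_full(weights, x):
    """Store every activation; return (d output / d x, number of stored activations)."""
    acts = [x]
    for w in weights:
        acts.append(layer(w, acts[-1]))
    g = 1.0
    for w, a in zip(reversed(weights), reversed(acts[:-1])):
        g = layer_backward(w, a, g)
    return g, len(acts)


def grad_checkpointed(weights, x, segment):
    """Keep only segment inputs; recompute inside each segment during the backward pass."""
    n = len(weights)
    checkpoints = {0: x}
    h = x
    for i, w in enumerate(weights):
        h = layer(w, h)
        if (i + 1) % segment == 0 and i + 1 < n:
            checkpoints[i + 1] = h
    g, peak = 1.0, 0
    for start in reversed(range(0, n, segment)):
        end = min(start + segment, n)
        acts = [checkpoints[start]]  # recompute this segment's layer inputs
        for w in weights[start:end - 1]:
            acts.append(layer(w, acts[-1]))
        peak = max(peak, len(checkpoints) + len(acts))  # checkpoints are never freed here
        for w, a in zip(reversed(weights[start:end]), reversed(acts)):
            g = layer_backward(w, a, g)
    return g, peak


weights = [0.5 + 0.05 * i for i in range(16)]
print(grad_full(weights, 0.3), grad_checkpointed(weights, 0.3, segment=4))
```

Both versions return the same gradient, but the checkpointed one holds 8 activations at peak instead of 17. With a segment length near \(\sqrt{n}\), storage grows roughly as \(2\sqrt{n}\) rather than \(n\), at the cost of about one extra forward pass.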

4. Numerical Precision

Moving from 32-bit floating point (FP32) to 16-bit floating point (FP16) can double the throughput and halve the memory requirement. But FP16 uses 5 bits for the exponent and 10 bits for the mantissa. This gives good precision but a small dynamic range (maximum value 65,504).

So Brain Float (BF16) is used to solve this issue. BF16 uses 8 bits for the exponent (the same as FP32) and 7 bits for the mantissa, keeping FP32's dynamic range at the cost of precision.
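The range-versus-precision trade-off is easy to observe numerically. NumPy has an FP16 type but no BF16 type, so this sketch emulates BF16 by rounding the FP32 bit pattern to its top 16 bits (round-half-to-even); the helper `to_bf16` is hypothetical and ignores NaN handling.

```python
import numpy as np


def to_bf16(x):
    """Round float32 values to the nearest bfloat16 (ties to even), returned as float32."""
    f = np.array(x, dtype=np.float32, ndmin=1)
    bits = f.view(np.uint32).astype(np.uint64)      # widen so the rounding carry can't wrap
    bias = 0x7FFF + ((bits >> 16) & 1)              # round-half-to-even on the dropped 16 bits
    rounded = ((bits + bias) & 0xFFFF0000).astype(np.uint32)
    return rounded.view(np.float32)


print(np.finfo(np.float16).max)                       # largest finite FP16 value
print(np.float16(70000.0))                            # overflows FP16 -> inf
print(to_bf16(70000.0)[0])                            # BF16 keeps the range, rounds coarsely
print(np.float16(1 + 2**-9), to_bf16(1 + 2**-9)[0])   # FP16 keeps this detail, BF16 loses it
```

Near 70,000, BF16 values are spaced 512 apart, so 70000 rounds to 70144; near 1.0 the spacing is \(2^{-7}\), which is why \(1 + 2^{-9}\) collapses to 1.0 in BF16 but survives in FP16.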

Additionally, mixed precision training is used to reduce memory and communication overhead. Matrix multiplications can run in 16 bits, while accumulations (such as the loss and the optimizer's update of an FP32 master copy of the weights) are done in FP32 to maintain convergence.
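The reason for keeping an FP32 master copy of the weights shows up in a toy update loop. With a weight near 1.0, an update of about \(10^{-4}\) is smaller than half of FP16's spacing there (\(2^{-11} \approx 4.9 \times 10^{-4}\) just below 1.0), so it rounds away and an FP16-only weight never moves. The learning rate, gradient, step count, and function name are arbitrary illustrative choices.

```python
import numpy as np


def apply_updates(w0, lr, grad, steps, master_dtype):
    """Run `steps` SGD updates w <- w - lr * grad with the weight kept in `master_dtype`.

    The gradient is stored in FP16, as it would be after a 16-bit backward pass.
    """
    w = master_dtype(w0)
    g16 = np.float16(grad)
    for _ in range(steps):
        update = master_dtype(lr) * master_dtype(g16)  # update computed in master precision
        w = master_dtype(w - update)
    return float(w)


fp16_only = apply_updates(1.0, lr=1e-3, grad=0.1, steps=100, master_dtype=np.float16)
fp32_master = apply_updates(1.0, lr=1e-3, grad=0.1, steps=100, master_dtype=np.float32)
print(fp16_only, fp32_master)
```

The FP16-only weight stays at exactly 1.0, while the FP32 master weight moves to about 0.99, as the 100 small updates intend.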

