Kavya G

Online Speculative Decoding

This is an ICML’24 paper with Ion Stoica as one of the authors. We had discussed and come up with this idea ourselves, then stumbled on this paper as prior art that solves the same problem in nearly the same way. So obviously, I am already motivated by the problem and think the solution is clever and neat (!). My only disappointment is that the evaluation section was not as systemsy as I would expect from an Ion Stoica paper (though the appendix does address some of my concerns).

Latency is one of the biggest problems in LLM inference, due to the auto-regressive nature of generation. One of the most popular techniques to reduce it is Speculative Decoding, where a smaller draft model predicts the target model’s outputs and the target model verifies those predictions in parallel. When the predictions are correct, this essentially parallelizes the target model’s generation. However, the efficiency of the speculation, and the corresponding savings, depends heavily on the accuracy of the draft model, which can degrade over time as the nature of the queries changes. This paper presents an online speculative decoding technique that continuously updates the (usually static) draft model so that it keeps matching the target model on the current query distribution.
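To make the mechanism concrete, here is a minimal sketch of one speculation step. It assumes the greedy-verification variant (a draft token is accepted only if it equals the target model's own next token), not the sampling-based acceptance rule, and it models both models as plain callables. The names `speculative_step`, `draft_next`, `target_next` and `k` are hypothetical and not taken from the paper.

```python
def speculative_step(prefix, draft_next, target_next, k):
    """Run one speculation step and return the tokens that get emitted.

    draft_next / target_next: callables mapping a token list to the next token.
    k: number of tokens the draft model proposes per step.
    """
    # 1. The draft model proposes k tokens auto-regressively (cheap).
    proposal = []
    ctx = list(prefix)
    for _ in range(k):
        tok = draft_next(ctx)
        proposal.append(tok)
        ctx.append(tok)

    # 2. The target model checks every position. A real system does this in
    #    one batched forward pass; the loop here is only for readability.
    emitted = []
    ctx = list(prefix)
    for tok in proposal:
        expected = target_next(ctx)
        if tok != expected:
            # First mismatch: emit the target's token and stop.
            emitted.append(expected)
            return emitted
        emitted.append(tok)
        ctx.append(tok)

    # 3. All k proposals matched, so the target's next token comes for free.
    emitted.append(target_next(ctx))
    return emitted


# Toy models: the target counts upward, the draft gets it wrong after a 3.
target = lambda ctx: ctx[-1] + 1
draft = lambda ctx: 0 if ctx[-1] == 3 else ctx[-1] + 1

tokens = speculative_step([1], draft, target, k=4)
```

In the toy run, the draft's first two proposals match, the third does not, and the step emits three tokens for a single target-model verification pass.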

Beyond the dynamic draft model, they also introduce specialized draft models, each corresponding to a different domain or topic, and propose routing each “speculation” request to the best expert to achieve better speculation accuracy. These expert draft models need to be fine-tuned on the fly. This can be done either in a cloud environment, with the updated models pushed over the network, or with the paper’s proposed approach of using the spare compute capacity of the serving GPUs.

Architecture

Training

They use knowledge distillation (where a student model’s predictive distribution is aligned with a teacher model’s) for training the draft model.

They explain here that distilling on the target model’s full probability outputs with a KL-divergence loss does better than normal fine-tuning with a label-based loss.
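The difference between the two losses can be sketched for a single token position. This is an illustration only: it assumes forward KL as the divergence and a three-token vocabulary, and the function names and logits are made up rather than taken from the paper.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def label_loss(draft_logits, target_token):
    """Label-based fine-tuning: cross-entropy against the one emitted token."""
    q = softmax(draft_logits)
    return -math.log(q[target_token])


def kl_distill_loss(draft_logits, target_logits):
    """Distillation: forward KL(target || draft) over the whole vocabulary."""
    p = softmax(target_logits)
    q = softmax(draft_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Hypothetical logits. The target is split between tokens 0 and 1, while the
# draft puts nearly all of its mass on token 0.
target_logits = [2.0, 1.5, -1.0]
draft_logits = [3.0, -2.0, -2.0]

ce = label_loss(draft_logits, target_token=0)
kl = kl_distill_loss(draft_logits, target_logits)
```

Here the label-based loss is small because the draft already ranks the emitted token first, while the KL loss stays large because the draft ignores the mass the target puts on token 1. That extra signal is what the distribution-level loss provides.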

Adaptation

For dynamic training of the draft model, they maintain a Replay Buffer that logs the instances where the speculation was incorrect, along with the expected (target model) output. They then periodically train on the buffer to update the draft model.
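A rough sketch of what such a buffer could look like is below. It assumes greedy token comparison, stores token IDs rather than the target's probability distributions, and triggers an update after a fixed number of logged corrections. All of these are simplifying assumptions, and the class and field names are hypothetical.

```python
from collections import deque
from dataclasses import dataclass


@dataclass
class Correction:
    prefix: tuple      # context up to the position of the wrong guess
    draft_token: int   # what the draft model proposed
    target_token: int  # what the target model actually produced


class ReplayBuffer:
    def __init__(self, capacity, update_every):
        self.items = deque(maxlen=capacity)  # oldest entries fall out first
        self.update_every = update_every
        self.since_update = 0

    def log_step(self, prefix, draft_tokens, target_tokens):
        """Record the first position where the draft disagreed with the target."""
        for i, (d, t) in enumerate(zip(draft_tokens, target_tokens)):
            if d != t:
                context = tuple(prefix) + tuple(target_tokens[:i])
                self.items.append(Correction(context, d, t))
                self.since_update += 1
                break

    def ready(self):
        """True once enough corrections have accumulated for an update."""
        return self.since_update >= self.update_every

    def drain(self):
        """Hand the logged corrections to the trainer and reset the buffer."""
        batch = list(self.items)
        self.items.clear()
        self.since_update = 0
        return batch


buf = ReplayBuffer(capacity=4, update_every=2)
buf.log_step([1, 2], draft_tokens=[3, 4, 5], target_tokens=[3, 9, 5])  # mismatch
buf.log_step([1], draft_tokens=[7], target_tokens=[7])                  # all correct
```

A serving loop would call `log_step` after each verification, and run a distillation update on `drain()` whenever `ready()` returns true.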

Routing

Their specialization/routing strategy is nothing great or fancy. They slap on a small BERT classification model to bucket queries into topics, and assign each draft model a topic (by training it only on queries from that topic).
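The routing logic itself is simple enough to sketch. In the example below a toy keyword classifier stands in for the BERT model, and the draft models are represented by name only. The topics, keywords, and model names are all hypothetical.

```python
class DraftRouter:
    """Send each query to the draft model that owns its topic."""

    def __init__(self, classify, drafts, fallback):
        self.classify = classify  # callable: query text -> topic label
        self.drafts = drafts      # topic label -> draft model
        self.fallback = fallback  # topic to use when the label is unknown

    def pick(self, query):
        topic = self.classify(query)
        if topic not in self.drafts:
            topic = self.fallback
        return topic, self.drafts[topic]


def keyword_classifier(query):
    """Toy stand-in for the small BERT classifier described above."""
    text = query.lower()
    if "def " in text or "compile" in text:
        return "code"
    if "integral" in text or "prove" in text:
        return "math"
    return "other"


router = DraftRouter(
    classify=keyword_classifier,
    drafts={"code": "draft-code", "math": "draft-math", "general": "draft-general"},
    fallback="general",
)

topic, draft = router.pick("Prove that the integral converges")
```

Each draft model would then keep its own replay buffer, so it is only ever updated on queries from its own bucket.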

Evaluation

The key metric to consider is the token acceptance rate (alpha), i.e. how often the target model accepts the tokens proposed by the draft model.
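One way to estimate alpha from serving logs is sketched below, assuming each speculation step records how many tokens were proposed and how many were accepted. The log format and function name are hypothetical. Tokens after the first rejection are never actually checked, so the sketch only counts tokens that were tested. This is one reasonable estimator, not necessarily the one the paper uses.

```python
def estimate_alpha(steps):
    """Estimate the per-token acceptance rate.

    steps: iterable of (proposed, accepted) pairs, one per speculation step,
           where `accepted` counts the draft tokens accepted before the
           first rejection.
    """
    accepted_total = 0
    tested_total = 0
    for proposed, accepted in steps:
        if not 0 <= accepted <= proposed:
            raise ValueError("accepted must be between 0 and proposed")
        accepted_total += accepted
        # The accepted tokens were tested, plus the one that was rejected
        # (if the step ended in a rejection at all).
        tested_total += accepted + (1 if accepted < proposed else 0)
    return accepted_total / tested_total if tested_total else 0.0


# Three steps with 4 proposed tokens each: all accepted, one accepted, none.
alpha_hat = estimate_alpha([(4, 4), (4, 1), (4, 0)])
```

In the example, 5 tokens are accepted out of 7 that were actually tested, so the estimate is 5/7.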

The eval section was a bit disappointing and not as systems-heavy as I had hoped. They use a dataset as a stand-in for the stream of queries. To assess the system, they train the starting draft model on the first 10% of the dataset, then gradually expose it to more and more of the data and measure the acceptance rate.

As the upper-bound baseline, they use a static draft model trained on the entire dataset, which has the best acceptance rate.

They also evaluate the hypothesis that a single draft model will perform worse than multiple specialized draft models, but this again seems highly dataset-specific.

Finally, they compare against Medusa (a competing speculative decoding technique) on some datasets to show better speedups. This is an actual apples-to-apples comparison, but again it is not clear why these particular datasets were chosen.

Key Takeaways

  1. The paper presents a mathematical model of speculative decoding (they are probably not the first to do it, but it is the first one I am reading), which is interesting. The model assumes an acceptance rate for the draft model and calculates the expected speedup and latency from speculating!
  2. They use Knowledge Distillation for training the draft model. They specifically mention that their technique, which uses KL divergence on the output probability distribution, performs better than normal fine-tuning, which uses only the label for loss calculation. They are able to do this because the draft model is trying to copy the target model’s exact probability distribution anyway, so they have that distribution in hand (unlike typical generation cases).
  3. The evaluation section was a bit meh for me, since they evaluated the core claim of their system (online stream of queries, draft model diverges over time, so it needs to adapt) purely on datasets standing in for the queries. I feel that a trace of real inference requests would have been a better benchmark. On a similar note, the specialized-model benchmarking is done on a completely handpicked dataset, leaving it unclear whether specialization actually improves accuracy on a generic user input stream.
  4. Edit: They have an appendix with more microbenchmarks that addresses some of the points above. For example, they test on a real chat workload, but they still cheat a bit on the specialized draft models by picking different languages as the different specializations.
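The speedup model from takeaway 1 can be sketched as follows. This uses the standard analysis from the speculative decoding literature and assumes each draft token is accepted independently with probability alpha. The symbols `k` (tokens proposed per step) and `c` (cost of one draft step relative to one target step) are introduced here for illustration and do not appear in the notes above.

```python
def expected_tokens_per_step(alpha, k):
    """Expected tokens emitted per target-model verification pass.

    Sums 1 + alpha + alpha^2 + ... + alpha^k: the target always contributes
    one token, and the i-th draft token counts only if all i were accepted.
    """
    if alpha >= 1.0:
        return float(k + 1)
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def expected_speedup(alpha, k, c):
    """Expected wall-clock speedup over plain auto-regressive decoding.

    One speculation step costs k draft passes (c each) plus one target pass,
    measured in units of a single target pass.
    """
    return expected_tokens_per_step(alpha, k) / (c * k + 1)


# How much a better draft model helps, for k=4 and a draft 20x cheaper.
low = expected_speedup(alpha=0.5, k=4, c=0.05)
high = expected_speedup(alpha=0.8, k=4, c=0.05)
```

The model makes the motivation for online adaptation explicit: with `k` and `c` fixed, the speedup is driven entirely by alpha, so a draft model that drifts away from the query distribution directly loses latency savings.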