๐๐ž๐ฐ ๐๐ž๐œ๐จ๐๐ข๐ง๐  ๐ญ๐ž๐œ๐ก๐ง๐ข๐ช๐ฎ๐ž ๐ข๐ง ๐ญ๐ซ๐š๐ง๐ฌ๐Ÿ๐จ๐ซ๐ฆ๐ž๐ซ๐ฌ ๐ฌ๐ข๐ ๐ง๐ข๐Ÿ๐ข๐œ๐š๐ง๐ญ๐ฅ๐ฒ ๐ซ๐ž๐๐ฎ๐œ๐ž๐ฌ ๐ก๐š๐ฅ๐ฅ๐ฎ๐œ๐ข๐ง๐š๐ญ๐ข๐จ๐ง๐ฌ ๐Ÿ‘

DoLa decoding, presented as a conference paper at ICLR '24, has just been merged into Transformers by João Gante and Yung-Sung Chuang.

This new decoding method is simple yet extremely impressive!

Reminder: decoder LLMs (the GPT kind of LLM, the most common one) generate their output one token at a time: at each step, given the current text, they compute, for each token in their vocabulary, a "logit", a score that (once passed through a softmax) gives the probability that this token comes next.

Then the decoder either picks the token with the highest logit (greedy decoding) or samples one with a probability defined by the logits (sampling). The chosen token gets appended to the current text, the decoder computes logits again, and the cycle continues.
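To make that loop concrete, here is a minimal PyTorch sketch of a single decoding step; the logits are random stand-ins for a real model's output, and the vocabulary size is made up for illustration:

```python
# One decoding step: greedy pick vs. sampling, on dummy logits.
import torch

vocab_size = 8
logits = torch.randn(vocab_size)           # one score per vocabulary token
probs = torch.softmax(logits, dim=-1)      # logits -> next-token probabilities

greedy_token = torch.argmax(probs).item()            # greedy decoding: highest probability
sampled_token = torch.multinomial(probs, 1).item()   # sampling: draw according to the probabilities

print(f"greedy pick: {greedy_token}, sampled pick: {sampled_token}")
```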

The authors of DoLa wanted to improve that simple method.

They started from the established finding that transformer LMs encode low-level information (like basic syntax) in their early layers and higher-level information, such as factual knowledge, in their later layers.

💡 This gave them their key idea: during decoding, rather than picking the token with the highest logit, why not pick the token with the most impressive increase in logit across layers?

The implementation is actually quite simple: at each step, you find the layer whose output distribution diverges most from the final layer's (the paper measures this with the Jensen-Shannon divergence); this chosen layer becomes the "premature" layer. Then you subtract the premature layer's logits from the final layer's, which rewards the tokens whose logits progressed the most through the network. And this lets you pick your next token.
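As a toy illustration of that contrast (a sketch only: the actual method works on log-probabilities, selects the premature layer with the Jensen-Shannon divergence, and masks implausible tokens before contrasting), here is what the selection step could look like on random per-layer logits:

```python
# Toy DoLa-style contrast on made-up early-exit logits (layers x vocab).
import torch
import torch.nn.functional as F

num_layers, vocab_size = 6, 8
layer_logits = torch.randn(num_layers, vocab_size)     # early-exit logits for each layer
final_logp = F.log_softmax(layer_logits[-1], dim=-1)   # final-layer log-probabilities

# 1) Pick the "premature" layer: the candidate whose distribution diverges
#    most from the final layer's (plain KL here as a simple stand-in for JSD).
candidate_logp = F.log_softmax(layer_logits[:-1], dim=-1)
divergences = (candidate_logp.exp() * (candidate_logp - final_logp)).sum(dim=-1)
premature = divergences.argmax()

# 2) Contrast: reward tokens whose score grew the most between the premature
#    layer and the final layer, then pick the next token greedily.
contrast = final_logp - candidate_logp[premature]
next_token = contrast.argmax().item()
print(f"premature layer: {premature.item()}, next token: {next_token}")
```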

Their test settings:

  • ➤ Tested 4 sizes of Llama-1 models (7B to 65B)
  • ➤ Benchmarked on multiple-choice QA (TruthfulQA, FACTOR) and open-ended generation (TruthfulQA open-ended, GSM8K)

โœจ ๐—ฅ๐—ฒ๐˜€๐˜‚๐—น๐˜๐˜€ ๐—ฎ๐—ฟ๐—ฒ ๐—ฒ๐˜…๐˜๐—ฟ๐—ฒ๐—บ๐—ฒ๐—น๐˜† ๐—ถ๐—บ๐—ฝ๐—ฟ๐—ฒ๐˜€๐˜€๐—ถ๐˜ƒ๐—ฒ!

  • 🚀 5 to 20 point gains across the benchmarks
  • 🚀 For instance, on TruthfulQA open-ended, across all model sizes the increase in truthfulness is 14 points, which is around a 40% improvement compared to standard decoding!

🤔 Wouldn't decoding take longer because of this added contrasting step?

  • 👉 The runtime increase is negligible: only 1 to 8%.

The paper has additional insights, such as how token confidence evolves across layers for different types of tokens: I recommend reading it!

๐™๐™š๐™–๐™™ ๐™ž๐™ฉ ๐™๐™š๐™ง๐™š ๐Ÿ‘‰ https://huggingface.co/papers/2309.03883

