continuous chain of thought: Nonlinear Function

This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

it seems clear that forcing transformers to produce discrete tokens is a significant constraint on their reasoning ability. each sampled token collapses the 'state of consciousness' in the residual stream down to a single choice from the vocabulary. quite a lot of human reasoning and "deep intuition" is not token-based, and relies on longer chains of continuous thought.

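a big advantage of continuous chains of thought is that they're differentiable end to end, so in principle we can train them by backprop from supervision on the final answer alone, without targets for the intermediate steps.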
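A toy scalar illustration of that point, under loudly hypothetical assumptions: a single `tanh` "thought" stands in for the latent state, a hard sign choice stands in for committing to a discrete token, and finite differences stand in for backprop. The loss only looks at the final answer, yet its gradient reaches the parameter `w` that produced the thought; the hard choice makes the loss locally flat in `w`.

```python
import math

def loss_continuous(w: float, x: float = 1.0, v: float = 2.0, y: float = 0.5) -> float:
    thought = math.tanh(w * x)  # continuous latent "thought" (hypothetical toy)
    return (v * thought - y) ** 2  # supervision only on the final answer

def loss_discrete(w: float, x: float = 1.0, v: float = 2.0, y: float = 0.5) -> float:
    # Hard +1/-1 choice, standing in for committing to a discrete token.
    thought = 1.0 if math.tanh(w * x) > 0 else -1.0
    return (v * thought - y) ** 2

def finite_diff(f, w: float, eps: float = 1e-6) -> float:
    # Central-difference estimate of df/dw.
    return (f(w + eps) - f(w - eps)) / (2 * eps)

print(finite_diff(loss_continuous, 0.3))  # nonzero: the answer loss reaches w through the thought
print(finite_diff(loss_discrete, 0.3))    # 0.0: no learning signal through the hard choice
```

Discrete chains of thought are instead usually trained with targets for the intermediate tokens or with RL-style estimators, which is exactly the extra machinery that backprop through a continuous chain avoids.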

unfortunately, there's a fundamental tradeoff in transformer expressivity. transformers are efficient to train because, with teacher forcing, there are no sequential paths of length $> D$ (where $D$ is the number of layers) between any pair of input and output tokens: layer $l$ at one position only reads layer $l-1$ outputs at the same and earlier positions. conditioning on the ground-truth tokens lets us compute all positions in parallel. if you add a pathway from the last layer at one position to the first layer at the next position, you allow a sequential computation whose depth grows without bound as the sequence gets longer. this is great, because it's more expressive! but you essentially have an RNN again, with the residual stream as latent state (formally the state would also include the KV cache, but it's the residual stream that doesn't decouple across positions); you've given up the parallelism advantages of the transformer architecture.
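One way to see the depth argument is to count the longest chain of sequential layer applications in the computation graph. This is a toy sketch that ignores everything about the architecture except dependencies, assuming that layer $l$ at position $t$ reads layer $l-1$ outputs at positions $\le t$, and, when `feedback=True`, that the last layer at position $t-1$ becomes the input at position $t$. The function name is hypothetical.

```python
def critical_path_depth(num_positions: int, num_layers: int, feedback: bool) -> int:
    # depth[t][l]: longest chain of sequential layer applications ending at (position t, layer l).
    # Layer 0 is the input embedding.
    depth = [[0] * (num_layers + 1) for _ in range(num_positions)]
    for t in range(num_positions):
        if feedback and t > 0:
            # Last layer at t-1 feeds the input at t: the RNN-style pathway.
            depth[t][0] = depth[t - 1][num_layers]
        for l in range(1, num_layers + 1):
            # Attention at layer l reads layer l-1 at positions 0..t.
            depth[t][l] = 1 + max(depth[s][l - 1] for s in range(t + 1))
    return max(depth[t][num_layers] for t in range(num_positions))

print(critical_path_depth(8, 12, feedback=False))  # 12: bounded by D
print(critical_path_depth(8, 12, feedback=True))   # 96: grows as T * D
```

Without feedback the critical path stays at $D$ no matter how long the sequence is, which is what lets every position be trained at once; with feedback it grows as $T \cdot D$, i.e. an RNN unrolled over positions.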

the CoCoNUT work (https://arxiv.org/abs/2412.06769) gets around this by only splicing in finite-length continuous chains of thought. the transformer switches between "language mode", where it outputs tokens, and "latent mode", where it feeds its last hidden state back in as the next input embedding and thinks continuously. each continuous thought needs its own sequential forward pass, so this blows up the training cost by the length of the latent chains of thought, but they use relatively short chains ($k \times c$ continuous thoughts, replacing $k=3$ discrete reasoning steps with $c=2$ continuous thoughts per step) and a small model (GPT-2), so it's tractable. they find that:

  • they generally need to prime the model with a reasoning 'curriculum': first train on discrete CoT, then gradually replace the discrete steps with continuous steps
  • continuous CoT 'works' in the sense of being better than no chain of thought. it outperforms a discrete chain of thought on two of their three tasks (the logical-reasoning tasks ProntoQA and ProsQA), but is outperformed on the third (GSM8k math). (some tasks benefit more than others from deeper compute)
  • training got unstable when they increased the continuous CoT length to $c=3$. presumably they're running into some version of vanishing/exploding gradients from the deep compute path. since this is effectively an RNN, I wonder if LSTM-style tricks (gating, additive state updates) would help here?
  • there's evidence that the continuous chain of thought does something BFS- or dynamic-programming-like: the probabilities assigned to graph nodes evolve over the course of the continuous chain of thought.
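A minimal numpy sketch of the two mechanics described above: latent mode, where the last hidden state is appended as the next input embedding instead of a sampled token, and a stage-$k$ curriculum example where the first $k$ reasoning steps are replaced by $k \times c$ latent slots. Everything here is hypothetical and not the paper's code: `toy_causal_model` is a stand-in for GPT-2 (a `tanh` over running means, not attention), and the function names and `<bot>`/`<eot>`/`<latent>` markers are illustrative.

```python
import numpy as np

def toy_causal_model(embeddings: np.ndarray, W: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for a causal transformer:
    # the hidden state at position t depends only on inputs 0..t.
    counts = np.arange(1, len(embeddings) + 1)[:, None]
    prefix_means = np.cumsum(embeddings, axis=0) / counts
    return np.tanh(prefix_means @ W.T)

def latent_rollout(prompt: np.ndarray, W: np.ndarray, num_thoughts: int):
    # Latent mode: append the last hidden state as the next input embedding
    # (no vocabulary projection, no sampling).
    seq = prompt
    passes = 0
    for _ in range(num_thoughts):
        hidden = toy_causal_model(seq, W)  # each thought needs its own sequential pass
        passes += 1
        seq = np.vstack([seq, hidden[-1]])
    hidden = toy_causal_model(seq, W)  # one more pass before returning to language mode
    passes += 1
    return seq, hidden, passes

def curriculum_example(question, steps, answer, stage, c, latent="<latent>"):
    # Stage-k training example: the first k reasoning steps become k * c latent slots.
    k = min(stage, len(steps))
    latent_slots = (["<bot>"] + [latent] * (k * c) + ["<eot>"]) if k > 0 else []
    return [question] + latent_slots + steps[k:] + [answer]

rng = np.random.default_rng(0)
d = 8
W = rng.normal(size=(d, d)) / np.sqrt(d)
prompt = rng.normal(size=(4, d))  # 4 prompt "token" embeddings
seq, hidden, passes = latent_rollout(prompt, W, num_thoughts=3 * 2)  # k * c = 6 thoughts
print(seq.shape, passes)  # (10, 8) 7
print(curriculum_example("Q", ["s1", "s2", "s3"], "A", stage=2, c=2))
```

In this toy, the 6 continuous thoughts cost 7 sequential forward passes, where a teacher-forced discrete chain of the same length would need a single parallel pass; that gap is the training-cost blowup, and the sequential loop is where the long gradient path comes from.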