Welcome to The Nonlinear Library, where we use Text-to-Speech software to convert the best writing from the Rationalist and EA communities into audio. This is: Finding Backward Chaining Circuits in Transformers Trained on Tree Search, published by abhayesian on May 29, 2024 on LessWrong.
This post is a summary of our paper A Mechanistic Analysis of a Transformer Trained on a Symbolic Multi-Step Reasoning Task (ACL 2024). While we wrote and released the paper a couple of months ago, we have done a bad job promoting it so far. As a result, we're writing up a summary of our results here to reinvigorate interest in our work and hopefully find some collaborators for follow-up projects.
If you're interested in the results we describe in this post, please see the paper for more details.
TL;DR - We train transformer models to find the path from the root of a tree to a given leaf (given an edge list of the tree). We use standard techniques from mechanistic interpretability to figure out how our model performs this task. We found circuits that involve backward chaining - the first layer attends to the goal and each successive layer attends to the parent of the output of the previous layer, thus allowing the model to climb up the tree one node at a time.
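The layer-by-layer climbing described above can be sketched in ordinary code. This is an illustration of the idea only; the `parent` map, node labels, and function name are assumptions for the sketch, not taken from the paper.

```python
def backward_chain(parent, root, goal, num_layers):
    """Climb from the goal toward the root, one parent step per 'layer'.

    `parent` maps each node to its parent in the tree. Each iteration
    plays the role of one transformer layer attending to the parent of
    the previous layer's output.
    """
    chain = [goal]
    for _ in range(num_layers):
        if chain[-1] == root:
            break
        chain.append(parent[chain[-1]])
    # Reversing gives the root-to-goal path, but only when the goal is
    # within num_layers steps of the root.
    return list(reversed(chain))

parent = {1: 0, 2: 0, 3: 1, 4: 3}          # toy tree, root = 0
print(backward_chain(parent, 0, 4, 6))     # -> [0, 1, 3, 4]
```

With too few layers the chain stops short of the root (e.g. `num_layers=2` above yields `[1, 3, 4]`), which is exactly the limitation the next paragraph describes.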
However, this algorithm would only find the correct path in graphs where the distance from the starting node to the goal is less than or equal to the number of layers in the model. To solve harder problem instances, the model performs a similar backward chaining procedure at insignificant tokens (which we call register tokens). Random nodes are chosen to serve as subgoals and the model backward chains from all of them in parallel.
In the final layers of the model, information from the register tokens is merged into the model's main backward chaining procedure, allowing it to deduce the correct path to the goal when the distance is greater than the number of layers. In summary, we find a parallelized backward chaining algorithm in our models that allows them to efficiently navigate towards goals in a tree graph.
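A rough, hypothetical sketch of this parallelized variant: besides the main chain grown up from the goal, chains are grown from subgoals (playing the role of register tokens), and a subgoal chain that meets the main chain's deepest ancestor is stitched on at the end. The function names and the stitching rule here are illustrative assumptions, not the model's literal computation.

```python
def climb(parent, start, steps, root):
    """Grow a chain upward from `start` for at most `steps` parent hops."""
    chain = [start]
    for _ in range(steps):
        if chain[-1] == root:
            break
        chain.append(parent[chain[-1]])
    return chain  # ordered bottom-up: start -> deepest reached ancestor

def parallel_backward_chain(parent, root, goal, num_layers, subgoals):
    """Backward chain from the goal; if the root isn't reached, try to
    extend the chain with one grown from a subgoal ('register token')."""
    main = climb(parent, goal, num_layers, root)
    if main[-1] == root:
        return list(reversed(main))
    for sg in subgoals:
        aux = climb(parent, sg, num_layers, root)
        # Merge step: the subgoal chain must start where the main chain
        # topped out, and must itself reach the root.
        if aux[0] == main[-1] and aux[-1] == root:
            return list(reversed(main + aux[1:]))
    return None  # goal too deep for the available chains

# Path graph 0 -> 1 -> ... -> 9: depth 9 exceeds 5 "layers", but a
# well-placed subgoal at node 4 bridges the gap.
path_parent = {i: i - 1 for i in range(1, 10)}
print(parallel_backward_chain(path_parent, 0, 9, 5, [4]))
```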
Motivation & The Task
Many people here have conjectured about what kinds of mechanisms inside future superhuman systems might allow them to perform a wide range of tasks efficiently. John Wentworth coined the term general-purpose search to group several hypothesized mechanisms that share a couple of core properties. Others have proposed projects around how to search for search inside neural networks.

While general-purpose search is still relatively vague and undefined, we can study how language models perform simpler, better-understood forms of search. Graph search, the task of finding the shortest path between two nodes, has been a cornerstone of algorithmic research for decades, is among the first topics covered in virtually every CS course (BFS/DFS/Dijkstra), and serves as the basis for planning algorithms in GOFAI systems. Our project revolves around understanding, at a mechanistic level, how transformer language models perform graph search.
While we initially tried to understand how models find paths over arbitrary directed graphs, we eventually restricted our focus specifically to trees. We trained a small GPT2-style transformer model (6 layers, 1 attention head per layer) to perform this task. The two figures below describe how we generate our dataset and tokenize the examples.
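The figures are not reproduced in this transcript, but a minimal sketch of how such a dataset might be generated could look like the following. The exact sampling scheme and token format are assumptions here; see the paper for the real ones.

```python
import random

def random_tree(n, rng):
    """Sample a random tree on nodes 0..n-1 as a parent map (root = 0).

    Each node attaches to a uniformly random earlier node, which
    guarantees a connected, acyclic structure.
    """
    return {node: rng.randrange(node) for node in range(1, n)}

def make_example(parent, goal):
    """Build one training example: the tree's edge list plus the
    root-to-goal path the model must predict."""
    edges = [(p, c) for c, p in parent.items()]
    path = [goal]
    while path[-1] != 0:          # climb to the root, then reverse
        path.append(parent[path[-1]])
    path.reverse()
    return edges, path

rng = random.Random(0)
parent = random_tree(8, rng)
edges, path = make_example(parent, goal=7)
print(edges)
print(path)
```

In the actual setup, an example like this would be serialized into tokens (edge list, then goal, then path) so the model sees the full tree and target leaf before predicting the path.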
It is important to note that this task cannot be solved trivially. To correctly predict even the first node in the path, the model must already have worked out the entire path, and it must do so in a single forward pass. This is not the case for many other tasks proposed in the literature on evaluating the reasoning capabilities of language models (see Saparov & He (2023), for instance). As a result of this difficulty, we can expect to find much more interesting mechanisms in our models.
We train our model on a dataset of 150,000 randomly generated trees. The model achieves an ac...