Zulip Chat Archive

Stream: Machine Learning for Theorem Proving

Topic: HiPPO papers (efficient sequence modeling)

Jason Rute (Oct 14 2021 at 00:25):

I don't normally mention non-theorem-prover stuff here, but I can't help but mention this sequence of papers:

  1. HiPPO: Recurrent Memory with Optimal Polynomial Projections

    * Neurips 2020 spotlight
    * blog article
    * paper

  2. Combining Recurrent, Convolutional, and Continuous-time Models with Structured Learned Linear State-Space Layers

    * Neurips 2021 poster
    * abstract

  3. Efficiently Modeling Long Sequences with Structured State Space

    * Under review for ICLR 2022
    * preprint

Jason Rute (Oct 14 2021 at 00:25):

In sequence modeling you have a sequence of inputs u(t), where t can be discrete integer time steps, or continuous real time (where the actual data is sampled at intervals of dt). Sequences can be words in a document, audio recordings, images flattened to one dimension, proteins, or video frames in a video game. The goal of sequence modeling is to take this sequence u and output another sequence of the same length. The simplest sequence modeling task is autoregressive modeling, where you just try to predict u(t+1) from u(0),...,u(t). Other sequence tasks are to predict a summary of the sequence u. For example, one might predict from an IMDB movie review what the star rating would be. Now, roughly, there are three methods of using neural networks to work with sequence data:

  • If you are modeling language (autoregressively) or playing a video game, you are computing some output y(t-1) from the partial sequence u(0),...,u(t-1) and then using that y(t-1) to compute the next element u(t) of your sequence (which is then fed back into the next step of your model). Recurrent neural networks (RNNs), including LSTMs and GRUs, compute one step at a time, keeping track of a hidden internal state that summarizes what has happened before in the sequence. The problem is that the information in this hidden state tends to be forgotten over time (the vanishing gradient problem). Decoder Transformers, like the GPT models, keep around the whole initial sequence and directly use u(0),...,u(t-1) to compute y(t-1). This is quadratic in the length of the sequence and can make it difficult to handle long sequences.
  • If one has the whole sequence at once, then there are more efficient techniques to compute over the sequence. For example, convolutional neural networks (CNNs) and Encoder Transformers do this. CNNs, however, only take into account local information in the sequence and not global information, which makes them impractical on long sequences if information has to pass from one end to the other. Once again, Transformers can be quite time and memory intensive since, even in this Encoder Transformer setting, they must compare every timestep t1 to every other timestep t2.
  • If the time step is (theoretically) continuous, then instead of directly computing an output y(t), you can model y(t) via a differential equation where the derivative y'(t) is computed by the neural network instead of y(t). (Think of a spring-mass system where an external force u(t) is also applied to the mass, and we are trying to predict the position of the mass.) If we know the differential equation, we can roll out the sequence as a simulation where we compute the value at regular time steps using the computed derivative.
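
To make the third bullet concrete, here is a minimal sketch of such a roll-out for the spring-mass example. Everything specific in it is an assumption for illustration: the damping term, the constants k, m, c, and the choice of explicit Euler steps. In a neural differential equation the acceleration line would be replaced by a learned network.

```python
def rollout(force, dt, steps, k=1.0, m=1.0, c=0.5):
    """Simulate m*y'' = -k*y - c*y' + u(t) with explicit Euler steps.

    k (spring constant), m (mass) and c (damping) are made-up values.
    """
    y, v = 0.0, 0.0                      # position and velocity, starting at rest
    ys = []
    for i in range(steps):
        t = i * dt
        # The derivative of the velocity; a learned model would supply this.
        a = (force(t) - k * y - c * v) / m
        # One Euler step: advance the state using the computed derivatives.
        y, v = y + dt * v, v + dt * a
        ys.append(y)
    return ys

# Constant unit force: the mass should settle near the equilibrium y = u/k = 1.
positions = rollout(lambda t: 1.0, dt=0.01, steps=5000)
print(round(positions[-1], 3))
```

The same loop works for any sampling step dt that is small enough for Euler integration to stay stable, which is the sense in which the model is independent of the sample rate.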

Jason Rute (Oct 14 2021 at 00:25):

The problem with all of these systems is that they don't handle really long sequences tractably, especially the first two. Every month for the past year or two, there has been a new paper trying to come up with a better long-sequence model (usually a variation of the Transformer).

Jason Rute (Oct 14 2021 at 00:25):

Now, let me quickly summarize these three papers:

Jason Rute (Oct 14 2021 at 00:26):

The first paper (which has a great blog article by the authors summarizing the paper) rigorously looks at the problem of memorizing a sequence into a fixed-size N-dimensional vector x. They show that not only is there a rigorous formalism of what it means to memorize the history of a sequence u, but that there is a well-behaved class of dynamical systems which do this. Basically, there is a family of nice NxN matrices A (called HiPPO matrices) where for a corresponding vector B, the system

x'(t) = Ax(t) + Bu(t)

will memorize the history of a continuous-time u(t) into the state vector x. They show one can drop this HiPPO matrix (without training it) into an RNN for good results on long range dependency tasks, and that it provides an after-the-fact explanation of how to come up with the gating mechanisms for the LSTM and GRU models.
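
As a rough illustration of what such a system looks like, the sketch below builds one member of the family and integrates x' = Ax + Bu. The matrix entries are the Legendre-based ("LegS") variant as it is usually written in the later papers; they are not stated in this thread, so treat them as an assumption and check them against the papers. The Euler integration and the helper names are also just choices made for the example.

```python
import math

def hippo_legs(n):
    """Return (A, B) for a LegS-style HiPPO system x' = A x + B u (assumed form)."""
    A = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i > j:
                A[i][j] = -math.sqrt((2 * i + 1) * (2 * j + 1))
            elif i == j:
                A[i][j] = -(i + 1.0)
            # entries above the diagonal stay 0, so A is lower triangular
    B = [math.sqrt(2 * i + 1) for i in range(n)]
    return A, B

def euler_memorize(A, B, u, dt):
    """Feed the samples u into x' = A x + B u with explicit Euler steps."""
    x = [0.0] * len(B)
    for uk in u:
        dx = [sum(a * xj for a, xj in zip(row, x)) + b * uk for row, b in zip(A, B)]
        x = [xi + dt * d for xi, d in zip(x, dx)]
    return x

A, B = hippo_legs(4)
state = euler_memorize(A, B, [1.0] * 3000, dt=0.01)
# For a constant input u = 1 the state settles where A x + B = 0,
# which for this matrix is x = (1, 0, 0, 0).
print([round(s, 4) for s in state])
```

Note that A and B are fixed here and nothing is trained, which matches the "drop in without training" usage described above.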

Jason Rute (Oct 14 2021 at 00:26):

The second paper is not out yet, but from the abstract (and from the third paper which is available), they suggest a new type of neural network layer called a Linear State-Space Layer (LSSL). It is basically a drop-in replacement for a Transformer layer or a convolutional layer. (The input to the layer has shape batch size x sequence length x input features and the output is the same shape.) It does this by modeling the linear state space system:

x' = Ax + Bu
y = Cx + Du

(Here u is a scalar input and y is a scalar output, and x is an N-dimensional state vector. Since in reality the input to the LSSL has many features, the actual LSSL layer has separate matrices A, B, C, D for each feature of the input.) Unlike the first paper, they find a nice class of HiPPO matrices A (which basically differ by how much emphasis they put on the recent past versus the distant past) which can be learned instead of hard-coded like in the first paper. Moreover, they show that you can think of this layer as a continuous neural differential equation (which can be computed at different time step frequencies), an RNN (which can be computed one step at a time), or a convolution K * u (an actual convolution, which can be computed with a fast Fourier transform, not to be confused with a CNN).
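
The equivalence between the RNN view and the convolution view can be checked on a toy example. This is only a sketch under stated assumptions: the discrete matrices Ab, Bb and the vectors C, D below are made-up numbers standing in for an already discretized system x_k = Ab x_{k-1} + Bb u_k, y_k = C x_k + D u_k, and the convolution is a direct O(L^2) sum rather than the FFT-based one.

```python
def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def run_recurrent(Ab, Bb, C, D, u):
    """RNN view: update the hidden state one step at a time."""
    x = [0.0] * len(Bb)
    ys = []
    for uk in u:
        x = [a + b * uk for a, b in zip(matvec(Ab, x), Bb)]
        ys.append(dot(C, x) + D * uk)
    return ys

def kernel(Ab, Bb, C, L):
    """Convolution view: K[j] = C Ab^j Bb for j = 0..L-1."""
    K, v = [], list(Bb)
    for _ in range(L):
        K.append(dot(C, v))
        v = matvec(Ab, v)
    return K

def run_convolution(K, D, u):
    """Causal convolution y_k = sum_j K[j] u[k-j] + D u[k], computed directly."""
    return [sum(K[j] * u[k - j] for j in range(k + 1)) + D * u[k]
            for k in range(len(u))]

# Hypothetical toy system with N = 2.
Ab = [[0.9, 0.0], [0.1, 0.8]]
Bb = [1.0, 0.5]
C = [0.3, -0.2]
D = 0.1
u = [1.0, 0.0, 2.0, -1.0, 0.5]

y_rec = run_recurrent(Ab, Bb, C, D, u)
y_conv = run_convolution(kernel(Ab, Bb, C, len(u)), D, u)
print(all(abs(a - b) < 1e-12 for a, b in zip(y_rec, y_conv)))
```

The kernel depends only on the matrices and the sequence length, not on u, which is why it can be computed once and reused when the whole sequence is available.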

Jason Rute (Oct 14 2021 at 00:26):

The third paper (which is publicly accessible and very readable) shows that while the LSSL layer is not as bad as a Transformer (which is quadratic in the length L of the sequence), it still takes O(N^2 * L) time to compute all the multiplications with the NxN matrix A. Also, computing the kernel K for the convolution needs O(N * L) space. This paper creates a better version of LSSL, called S3. (Why did they name it this? It will be impossible to Google this model!) This paper rigorously breaks the LSSL computation down by showing a HiPPO matrix A can (under some equivalence) always be written as a diagonal plus low-rank matrix. What this means practically is that instead of learning and computing with the large matrix A, one can learn and compute with its much smaller decomposed parts. They show near-optimal asymptotic runtimes of O(N * L) for the recurrent network version, and O(N + L) for the convolutional version. (There might be a log factor on one of those.) They also show that this speedup holds in practice as well. Now to the best part. They get state of the art on basically all long sequence task benchmarks by a wide margin. They also show they can handle sound recordings with a different sample frequency during testing than during training. Even on the normal-length Transformer tasks, they do comparably to vanilla Transformers.
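
To see why a diagonal plus low-rank structure helps, here is a small sketch comparing a dense matrix-vector product with one that only touches the decomposed parts. The numbers are made up, everything is real-valued, and the low-rank part is rank one; the decomposition in the paper is more involved, so this only illustrates the cost argument, not the paper's algorithm.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def dense_matvec(A, x):
    """O(N^2): touches every entry of the N x N matrix."""
    return [dot(row, x) for row in A]

def dplr_matvec(diag, p, q, x):
    """O(N): computes (diag(d) + p q^T) x without ever forming the matrix."""
    s = dot(q, x)                                   # one scalar shared by all rows
    return [d * xi + pi * s for d, pi, xi in zip(diag, p, x)]

# Hypothetical parts for N = 3.
diag = [-1.0, -2.0, -3.0]
p = [1.0, 0.5, -1.0]
q = [0.2, 0.1, 0.3]
x = [1.0, 2.0, 3.0]

# Build the dense matrix only to check the structured product against it.
A = [[(diag[i] if i == j else 0.0) + p[i] * q[j] for j in range(3)]
     for i in range(3)]

y_dense = dense_matvec(A, x)
y_fast = dplr_matvec(diag, p, q, x)
print(all(abs(a - b) < 1e-12 for a, b in zip(y_dense, y_fast)))
```

The structured version stores 3N numbers instead of N^2, which is the "much smaller decomposed parts" point made above.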

Jason Rute (Oct 14 2021 at 00:26):

What I really like is that a very carefully reasoned mathematical analysis led to some brilliant theory, which led to some great ideas, which (again with reasoned mathematical optimizations) led to some outstanding results. (What else in deep learning could be revolutionized by this sort of rigorous research workflow?)

Jason Rute (Oct 14 2021 at 00:26):

I probably shouldn't get too excited until this has been reproduced and other experiments conducted, but I think this opens up so many possibilities. Transformers are amazing, but the restriction to medium length sequences is very limiting. For example, GPT-3 can be "taught" a new task by just giving it a few examples: France => Paris, Italy => Rome, England =>. You can only put so much stuff in before you run out of room. If good long-range sequence models can be developed, then one could theoretically learn data online by just running the sequence model over, say, a whole document or whole training set.

Jason Rute (Oct 14 2021 at 00:27):

For example, if long sequence models are possible and good, we could do something like this where an AI agent interacts with a computer system. The AI can remember the whole conversation. While it probably won't have the ability to remember every little detail, it should be able to ask for relevant information to be repeated as long as it remembers that such information exists.

Computer> Records:
  Record [Hilbert Space]: A Hilbert space is a ...
  Record [Banach Space]: A Banach space is a ...
  Record [L^p space]: An L^p space is a ...
Computer> Question: For which value of p is L^p a Hilbert space?
AI> Repeat records [Hilbert Space] and [L^p space]
Computer> Records:
  Record [Hilbert Space]: A Hilbert space is a ...
  Record [L^p space]: An L^p space is a ...
Computer> Question: For which value of p is L^p a Hilbert space?
AI> Answer: 2

Jason Rute (Oct 14 2021 at 00:27):

I really feel experiments like this are needed on this model. It could be ground-breaking (or it could be a huge dud of course---time will tell). Also, since it is a drop-in replacement for Transformers, one could still use Transformer layers after this layer so one can still attend to recent information.

Jason Rute (Oct 14 2021 at 00:27):

One thing from the paper that I'm unclear about is if one can mix and match the RNN and convolution formalisms to do the following:

  • Start with a hidden state x(t0).
  • Gather a long sequence of L steps (where say L = 2^14)
  • Run the O(N + L) time convolution on that sequence
  • Extract the hidden state x(t0 + L)
  • Repeat with x(t0) = x(t0 + L)

Jason Rute (Oct 14 2021 at 00:28):

If so, this would (1) allow us to compute really long sequences in O(N + L)-time chunks, and also (2) allow us to do a fast O(N + L)-time convolution to encode the initial sequence prompt, and then a slower O(N * L)-time RNN for decoding.
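
For a generic linear recurrence the algebra of this mix-and-match does work out, as the sketch below checks on a toy system. It says nothing about whether S3's fast algorithms can extract the state at the claimed cost: the matrices are made-up numbers, the convolution is a direct sum, and the end state is assembled from explicitly stored powers of Ab, so this version is no faster than the plain RNN.

```python
def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def recurrent(Ab, Bb, C, D, u, x0):
    """Plain RNN pass x_k = Ab x_{k-1} + Bb u_k, starting from state x0."""
    x, ys = list(x0), []
    for uk in u:
        x = [a + b * uk for a, b in zip(matvec(Ab, x), Bb)]
        ys.append(dot(C, x) + D * uk)
    return ys, x

def chunk_by_convolution(Ab, Bb, C, D, u, x0):
    """One chunk: convolution for the inputs plus a correction for the carried state."""
    L = len(u)
    powers, w = [], list(Bb)
    for _ in range(L):                    # powers[j] = Ab^j Bb
        powers.append(w)
        w = matvec(Ab, w)
    K = [dot(C, pw) for pw in powers]     # kernel K[j] = C Ab^j Bb
    ys, v = [], list(x0)
    for k in range(L):
        v = matvec(Ab, v)                 # v = Ab^(k+1) x0, the decayed carried state
        conv = sum(K[j] * u[k - j] for j in range(k + 1))
        ys.append(dot(C, v) + conv + D * u[k])
    x_end = list(v)                       # Ab^L x0 ...
    for j in range(L):                    # ... plus sum_j Ab^(L-1-j) Bb u_j
        x_end = [xe + pw * u[j] for xe, pw in zip(x_end, powers[L - 1 - j])]
    return ys, x_end

# Hypothetical toy system and an 8-step input split into two chunks of 4.
Ab = [[0.9, 0.0], [0.1, 0.8]]
Bb = [1.0, 0.5]
C = [0.3, -0.2]
D = 0.1
u = [1.0, 0.0, 2.0, -1.0, 0.5, 0.25, -2.0, 1.5]

y_full, x_full = recurrent(Ab, Bb, C, D, u, [0.0, 0.0])
y1, x_mid = chunk_by_convolution(Ab, Bb, C, D, u[:4], [0.0, 0.0])
y2, x_last = chunk_by_convolution(Ab, Bb, C, D, u[4:], x_mid)
print(all(abs(a - b) < 1e-12 for a, b in zip(y_full, y1 + y2)))
```

The second use case (convolution over the prompt, then RNN steps for decoding) corresponds to calling chunk_by_convolution once and passing its end state as x0 to recurrent.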

Jason Rute (Oct 14 2021 at 00:44):

Aside: I decided to ask GPT-J that Hilbert space question. Here is what I got:

Question: For which value of p is L^p a Hilbert space?
Answer: For all $p\geq 1$, L^p is a Hilbert space. This is a famous result (one that many people take for granted but that may not be widely appreciated).

The poor thing is trying... :cry:

Kevin Buzzard (Oct 14 2021 at 06:07):

Dumb comment from a non-expert: sequence modelling seems to be about "predicting a pattern". But a maths proof has no pattern, it is a collection of independent good ideas.

Scott Morrison (Oct 14 2021 at 06:57):

Really? The first n lines of a proof have no predictive power for the n+1-st line, unless you already know what you're doing? I'd hate to read your proofs. :-)

Scott Morrison (Oct 14 2021 at 06:58):

(Maybe this is just saying that maximally compressed proofs are unpleasant to AI agents as well as to humans: it's only the kernel that loves them.)

Jason Rute (Oct 14 2021 at 10:44):

@Kevin Buzzard Before answering your question, let me mention a meta point. When AI researchers talk about AIs proving theorems, even if they hope one day that AIs will prove interesting and novel theorems, right now most of their work is focused on proving standard sorts of results. Even though you see statistics like X prover solved 50% of theorems, that is likely because it solved most of the easy theorems, say theorems whose proofs are 1 to 10 tactics long (with a much better chance in the 1-5 tactic range).

Jason Rute (Oct 14 2021 at 10:46):

Second, not just sequence modeling but all applications of neural networks are about finding patterns. I liken it to human intuition or perception. But what is a pattern? And how can one use patterns to come up with proofs? As per Scott's point, if I gave you the first half of a short proof, you likely could fill in the rest, especially if you were also given the theorem statement and the relevant definitions and lemmas. With your extensive experience you might be able to fill in the proof without much thought, based on your intuition. The key to applying neural networks is to take these "intuition engines" that are neural networks, and use them to the best of their abilities. It is not at all clear right now how to solve longer and more complicated proofs which involve tricks or even just a long chain of reasoning. How do you get a neural network to output the "big idea" of a proof? What is the essence of a proof such that, if I only know that, the details are trivial? How can it be expressed in a way a computer can understand? Etc. (I'm of course definitely interested in your thoughts on this if you have any.)

Jason Rute (Oct 14 2021 at 10:49):

Third, when I'm talking about the potential of long sequence modeling, I'm being vague as to exactly how this would be used for theorem proving, or any other activity. Part of this is that I don't know what is possible. Part of it is that I already have 20+ ideas of what is possible. But here is one example. Currently in Lean GPT-f we predict a tactic from its goal. For example, we feed the sequence model used in Lean GPT-f the string:

GOAL z_fst z_snd: α h: (z_fst, z_snd).fst = (z_fst, z_snd).snd   (y : α), (y, y)  (z_fst, z_snd) PROOFSTEP

and expect it to fill in the tactic to apply to that goal. Now, this works pretty well, but what if instead of just supplying the goal, we also supplied the partial proof so far, or the definitions of snd and fst and . What if we also supplied all the recent lemmas in the file? We run into two problems. First, a very practical problem is that GPT-like models can only handle 2048 tokens (where a token is roughly 4 characters or so). So we would quickly run out of room to put all this information. Second, this could be information overload even for models which can handle longer sequences of characters. They would start to forget the earlier information they were given. A good long-sequence model, however, would be able to handle all this information and use it well to predict the next tactic. When you realize this, the sky is the limit. What if we also put in the failed tactic attempts, so it knows what not to try again? What if we just put in all of mathlib? This would entirely change what it means to "train" a neural network, and would make it easier for these tools to adapt to new situations, like someone's personal project or a new version of Lean, or even a new theorem prover. I'm definitely getting ahead of myself here, but that is why I'm optimistic at least.
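
To put rough numbers on the first problem, here is a back-of-the-envelope sketch using the two figures above (2048 tokens, about 4 characters per token). The goal and lemma strings are placeholders of made-up length, and real tokenizers do not split text into exact 4-character pieces, so the counts are only estimates.

```python
import math

CONTEXT_TOKENS = 2048      # context limit quoted above for GPT-like models
CHARS_PER_TOKEN = 4        # rough average quoted above; real tokenizers vary

def estimated_tokens(text):
    return math.ceil(len(text) / CHARS_PER_TOKEN)

def fits_in_context(prompt):
    return estimated_tokens(prompt) <= CONTEXT_TOKENS

# Stand-in for a goal of about 100 characters in the GOAL ... PROOFSTEP format.
goal_only = "GOAL " + "g" * 100 + " PROOFSTEP"
# Hypothetical 200-character lemma statement, repeated to mimic "recent lemmas".
lemma = "lemma placeholder_name : " + "s" * 175
with_lemmas = " ".join([lemma] * 50) + " " + goal_only

print(estimated_tokens(goal_only), fits_in_context(goal_only))
print(estimated_tokens(with_lemmas), fits_in_context(with_lemmas))
```

Under these assumptions the whole budget is about 8192 characters, so a few dozen lemma statements are already enough to crowd out the goal.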

Kevin Buzzard (Oct 14 2021 at 12:12):

Hmm. When I'm halfway through a proof I agree that it's easier to get to the end, but in my mental model this is because I can see what to do because of the hypotheses and the goal, not because I can see which tactics I just applied. But I realise now that you can let the tactic state be part of the data we're trying to find a pattern in.

Last updated: Dec 20 2023 at 11:08 UTC