Lab notes is a way for me to blog openly about the things I am building and the methods I plan to use to build them. Everything written here should be treated with a healthy amount of skepticism.
I wanted to write about something I built about a month ago that I think is really neat and that I would like to return to someday. A quick disclaimer: I think there is a strong probability that formal research on this idea already exists; I vaguely recall reading something similar about a year ago. If this is true and you would like to see credit given, drop me a line.
I was thinking about an interesting problem recently: the wav2vec ASR model works using a sequence-to-sequence modeling scheme called “Connectionist Temporal Classification” (CTC). This is a hyper-cool way of tackling the alignment problem that haunts speech-to-text, and wav2vec’s performance reflects its power. One thing I’ve been looking into is whether I could reverse this coding scheme. In other words: could I generate speech from CTC codes?
There are two sub-problems here. First, you need to train a model that learns to convert fully aligned CTC codes to speech. This is a perfect application for diffusion models, and it turns out to work quite well. That’s not what this blog post is about, though. 🙂
The second problem is: how do you convert raw text into aligned CTC codes? My first intuition here was to represent aligned CTC codes in a symbolic manner. To understand how this can be done, you first need to understand the difference between “aligned CTC codes” and “unaligned CTC codes” (aka “text”). With aligned CTC codes, each character is aligned with a split second of audio data. For example, if you pronounced the word “hello” by dragging out the “o” at the end, the CTC codes might look something like “hheeelllllloooooooooooooo”. Aligned CTC codes also contain <pad> tokens interspersed between characters. I’ll represent these as 🍒. So “hello” with the dragged-out “o” might actually look like “h🍒ee🍒ll🍒🍒oooooooooooooooo🍒”.
So with aligned CTC codes, you have repeated characters and you have padding tokens. This can be represented symbolically: simply search through your dataset for the longest runs of repeats and pads. Say the longest run of pads has length P and the longest run of repeats has length R. Now you have a classification problem: each character of input text can be classified into one of P*R classes. Which class it belongs to is determined by how many pads precede it, p, and how many times it is repeated, r. The class is (r*P+p).
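To make this concrete, here is a minimal sketch of the encoding in Python. The pad symbol, the function name, and the rule for splitting a run shared by adjacent identical letters (like the two l’s in “hello”) are my own assumptions, not the actual implementation:

```python
PAD = "_"  # stands in for the <pad> / 🍒 token

def align_to_classes(text: str, aligned: str, P: int) -> list[int]:
    """Map an aligned CTC string to one class per text character.

    Each class encodes (p, r): the pads preceding the character and the
    number of times it repeats. Index = r * P + p, as above; for indices
    to be unique, P must exceed the longest pad run in the dataset.
    """
    classes = []
    i = 0
    for j, ch in enumerate(text):
        p = 0
        while i < len(aligned) and aligned[i] == PAD:  # count preceding pads
            p += 1
            i += 1
        n = 0
        while i + n < len(aligned) and aligned[i + n] == ch:  # repeat run length
            n += 1
        # Adjacent identical text characters ("ll") can share one unpadded run;
        # leave one occurrence for the next character. Without a separating pad
        # the split is ambiguous, so this is a guess.
        r = n - 1 if n > 1 and j + 1 < len(text) and text[j + 1] == ch else n
        i += r
        classes.append(r * P + p)
    return classes

# "hello" with the dragged-out "o" from above (🍒 written as "_"):
print(align_to_classes("hello", "h_ee_ll__oooooooooooooooo_", P=3))  # [3, 7, 4, 3, 50]
```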
So now we’ve got a way to convert text into aligned CTC codes. The next question is: how do we model this? One simple way is with a basic transformer encoder: feed text in one end, predict the class at the other. This works OK, but the results are quite boring. For any given input text, the model chases the most likely classes. As a result, all words sound more or less the same and there is almost no variance, regardless of context.
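For reference, a baseline along these lines might look like this PyTorch sketch (the dimensions, layer counts, and learned positional embedding are made up for illustration):

```python
import torch
import torch.nn as nn

class CTCClassEncoder(nn.Module):
    """Baseline: character tokens in, one CTC class prediction per position."""

    def __init__(self, n_chars: int, n_classes: int, dim: int = 256, max_len: int = 512):
        super().__init__()
        self.embed = nn.Embedding(n_chars, dim)
        self.pos = nn.Parameter(torch.randn(1, max_len, dim) * 0.02)  # learned positions
        layer = nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=4)
        self.head = nn.Linear(dim, n_classes)

    def forward(self, chars: torch.Tensor) -> torch.Tensor:
        h = self.embed(chars) + self.pos[:, : chars.shape[1]]
        return self.head(self.encoder(h))  # (batch, seq, n_classes) logits
```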
Context is actually quite important to this problem. Let’s say I showed you someone pronouncing the word “banana”, but only the pronunciation of the first two letters, “ba”, where the person produced a very long, drawn-out “ahhhh”. Would you agree that this is relevant information for predicting how the rest of the word will be pronounced? I would say it is pretty likely that if the first “a” was drawn out, the rest of the “a”s will be as well. At least it is more likely than in the scenario where the person speaks the first “a” short.
By that logic, it’s probably better to predict these CTC “classes” in some sort of ordered fashion. One well-known tool we have for that is an autoregressive decoder, like GPT. Training a GPT model on aligned CTC codes turns out to work quite well. The downside is that it is slow: you sample one character at a time.
I couldn’t help but think that this could be sped up somehow. One idea I had was: what if we let the model learn to choose the decoding order, and decode multiple characters at a time? This seemed intriguing, so I set about designing a way to make it work.
This is how I ended up with what I am calling a “confidence decoder”. An implementation of the exact confidence decoder described here can be found here. Here’s how it works (a training-step sketch in code follows the list):
- Take a sequence of characters and their aligned CTC “classes”, as defined above.
- Convert the characters to embeddings.
- Mask a random fraction of the “classes”, anywhere from 0-100%. Embed them as well. (All masked classes share a single embedding.)
- Sum the class embeddings, the character embeddings, and a positional signal.
- Feed the whole sequence through a transformer.
- Feed the outputs into two separate heads:
  - A class prediction head.
  - A “confidence” head, of dim 1.
- Set the confidence values at unmasked positions to 0.
- Take the softmax over the remaining (masked) confidence values.
- Take the cross entropy of the class prediction head (against the true classes).
- Multiply the confidence softmax by the per-position cross entropy.
- The sum of those products (a confidence-weighted average of the cross entropy) is your loss.
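Put together, one training step might look roughly like this PyTorch sketch. The shapes and hyperparameters are assumptions, as is my reading of “set the unmasked confidence values to 0” (I exclude those positions from the softmax entirely); the linked implementation is the reference:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConfidenceDecoder(nn.Module):
    def __init__(self, n_chars, n_classes, dim=256, max_len=512):
        super().__init__()
        self.mask_id = n_classes                      # shared [MASK] embedding index
        self.char_emb = nn.Embedding(n_chars, dim)
        self.class_emb = nn.Embedding(n_classes + 1, dim)
        self.pos = nn.Parameter(torch.randn(1, max_len, dim) * 0.02)
        layer = nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True)
        self.trunk = nn.TransformerEncoder(layer, num_layers=6)
        self.class_head = nn.Linear(dim, n_classes)   # class prediction head
        self.conf_head = nn.Linear(dim, 1)            # "confidence" head, dim 1

    def forward(self, chars, classes):
        h = self.char_emb(chars) + self.class_emb(classes) + self.pos[:, : chars.shape[1]]
        h = self.trunk(h)
        return self.class_head(h), self.conf_head(h).squeeze(-1)

def training_step(model, chars, true_classes):
    B, T = true_classes.shape
    # mask a random fraction of the classes per sequence (forcing at least one
    # masked position so the confidence softmax below is never empty)
    frac = torch.rand(B, 1, device=chars.device)
    mask = torch.rand(B, T, device=chars.device) < frac
    mask[torch.arange(B, device=chars.device),
         torch.randint(T, (B,), device=chars.device)] = True
    inputs = torch.where(mask, torch.full_like(true_classes, model.mask_id), true_classes)
    logits, conf = model(chars, inputs)
    # softmax confidence over the masked positions only (my reading of
    # "set the unmasked confidence values to 0")
    weights = F.softmax(conf.masked_fill(~mask, float("-inf")), dim=-1)
    # per-position cross entropy against the true classes
    ce = F.cross_entropy(logits.transpose(1, 2), true_classes, reduction="none")
    # confidence-weighted cross entropy, averaged over the batch
    return (weights * ce).sum(dim=-1).mean()
```

Because the softmax weights sum to 1, minimizing this weighted cross entropy pushes the confidence head to place its weight on the positions the class head is most likely to get right.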
What this model ends up learning is not only how to predict the masked classes given a prior of some unmasked classes, but also how to produce a confidence score for which of its predictions it believes are most likely to be correct.
This is a pretty cool property when it comes time to decode with this model. You first feed in text with every class masked. The model tells you what it thinks each class might be, and provides a confidence score for each of those predictions. You can choose the sequence elements with the highest confidence scores, commit those first, then feed the whole thing back into the model. In this way it becomes an interesting hybrid between an encoder-only model and an AR decoder.
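As a sketch, the decode loop could look like the following (reusing the `ConfidenceDecoder` from above; the per-pass commit count `k` is a knob I made up):

```python
@torch.no_grad()
def confidence_decode(model, chars, k=4):
    """Fill in classes iteratively, committing the k most confident
    masked positions per pass."""
    B, T = chars.shape
    classes = torch.full((B, T), model.mask_id, dtype=torch.long, device=chars.device)
    still_masked = torch.ones(B, T, dtype=torch.bool, device=chars.device)
    while still_masked.any():
        logits, conf = model(chars, classes)
        conf = conf.masked_fill(~still_masked, float("-inf"))
        # every row starts fully masked and commits the same count per pass,
        # so the remaining count is the same for all rows
        kk = min(k, int(still_masked[0].sum()))
        idx = conf.topk(kk, dim=-1).indices
        classes.scatter_(-1, idx, logits.argmax(-1).gather(-1, idx))
        still_masked.scatter_(-1, idx, False)
    return classes
```

With k > 1 you commit several positions per forward pass, which is where the speedup over one-character-at-a-time AR sampling comes from.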
I trained a small toy model like the one described above on my CTC codes. It did pretty well! What I thought was particularly neat was how much variance the confidence scores exhibited as decoding went along; they weren’t just pinned to some fixed mean. The model had clearly learned a preference for decoding order! The outputs of this model were, subjectively, quite good. Certainly better than the simple encoder model. I can’t say they were any better than the AR decoder model. The decoding process was faster, though!