Triforce: A general recipe for kickass Generative Models

For the past two years, I’ve been tinkering around with generative models in my spare time. I think I’ve landed on an approach that produces by far the most compelling results available today, and which scales like big language models. I’d like to outline the approach here.

First of all, I want to touch on something that’ll become immediately obvious: this isn’t a novel architecture or anything. In fact, it is pretty much OpenAI’s DALL E with a diffusion upsampler attached. Rather, it’s a way of thinking about how one can (1) improve upon DALL E and (2) model generative domains universally using a single set of techniques.

Three Models

This approach uses three different neural networks to produce the finished result, all trained separately from one another.

The first is a Discrete VAE (DVAE). This model is responsible for translating your base medium (images, music, video, etc.) into a string of integers. The DVAE preserves the structure of your medium, but simplifies its contents into discrete bins that can be reasoned about. The DVAE can also compress the medium so it is more computationally tractable to reason about.
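To make "discrete bins" concrete, here is a minimal NumPy sketch of the nearest-codebook lookup that VQ-style discrete autoencoders use to turn continuous latents into integers. It assumes the encoder has already produced one latent vector per spatial position and that the codebook is a learned `(K, D)` matrix. The function name `quantize` is hypothetical, and real DVAEs differ in detail (for example, DALL E's dVAE is trained with a Gumbel-softmax relaxation).

```python
import numpy as np

def quantize(latents, codebook):
    """Map every latent vector to the index of its nearest codebook entry.

    latents:  (H, W, D) encoder outputs (assumed already computed).
    codebook: (K, D) learned embedding table.
    Returns an (H, W) integer token map and the quantized latents.
    """
    flat = latents.reshape(-1, latents.shape[-1])                         # (H*W, D)
    # Squared Euclidean distance from every latent to every codebook entry.
    dist = ((flat[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)  # (H*W, K)
    ids = dist.argmin(axis=1)
    return ids.reshape(latents.shape[:-1]), codebook[ids].reshape(latents.shape)

# Toy usage: latents that sit close to known codebook entries map back to them.
rng = np.random.default_rng(0)
codebook = rng.normal(size=(16, 4))
target = rng.integers(0, 16, size=(3, 5))
latents = codebook[target] + 0.01 * rng.normal(size=(3, 5, 4))
ids, quantized = quantize(latents, codebook)
```

The `(H, W)` grid of `ids` is the "string of integers" that the next stage consumes once it is flattened.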

The second is the causal transformer, essentially a GPT model. This model is trained on next-token prediction, where the tokens are the discrete outputs of the DVAE. These models are especially neat because you can throw anything you like into the sequence and they will learn how to reason about it. Have text and audio and want to produce images? Discretize all three and throw them into your causal transformer! It’ll learn how to convert between these mediums and predict image tokens. Want to flip the problem around and predict text from images and audio clips? Just flip the sequence around! The flexibility of this architecture is incredible.
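One simple way to "throw everything into the sequence" is sketched below. Each medium gets its own disjoint range of token ids and a start-of-segment marker, and segments are concatenated in whatever order you want the model to read them. The vocabulary sizes, the marker scheme and the helper names are all hypothetical illustrations, not a description of any specific model.

```python
# Hypothetical vocabulary sizes for each discretized medium.
VOCABS = {"text": 256, "audio": 1024, "image": 8192}

def build_vocab_offsets(vocabs):
    """Give each medium a disjoint id range, plus one start-of-segment token per medium."""
    offsets, next_id = {}, 0
    for name, size in vocabs.items():
        offsets[name] = next_id
        next_id += size
    starts = {name: next_id + i for i, name in enumerate(vocabs)}
    return offsets, starts, next_id + len(vocabs)

def build_sequence(segments, offsets, starts):
    """segments: list of (medium, token_ids) in the order the model should read them."""
    seq = []
    for name, ids in segments:
        seq.append(starts[name])                     # marks which medium follows
        seq.extend(offsets[name] + t for t in ids)   # shift ids into that medium's range
    return seq

offsets, starts, total_vocab = build_vocab_offsets(VOCABS)
# Text + audio -> image. Reversing the segment order would give image -> text instead.
seq = build_sequence([("text", [5, 9]), ("audio", [3]), ("image", [0, 7])], offsets, starts)
inputs, targets = seq[:-1], seq[1:]  # next-token prediction pairs
```

Flipping the direction of a task is just a matter of reordering the `segments` list. The transformer itself does not change.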

The final stage is the diffusion network. To understand why it is necessary, you first have to understand that DVAEs have absolutely awful decoders. They are always lossy, and that cannot be fixed because VAEs do not scale. Anecdotally, this is almost certainly why DALL E’s generates are so blurry.

Diffusion models are, bar none, the best super resolution models in existence. What is good for super resolution is also good for upsampling the output of your DVAE decoder. You simply feed the output of your DVAE as a prior to your diffusion model and train it to reproduce full resolution images. Unlike DVAEs, diffusion models respond excellently to scaling. Unlike GANs, diffusion models do not suffer from mode collapse.
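As a rough illustration of "feeding the DVAE output as a prior", here is a NumPy sketch of how one training example for a conditional diffusion upsampler could be assembled. It assumes the standard DDPM forward process with a linear beta schedule and an epsilon-prediction target, and it conditions the network by concatenating the (resized) DVAE reconstruction along the channel axis, as SR3-style super-resolution models do. The function names are hypothetical, and the denoising network itself is left out.

```python
import numpy as np

def linear_alpha_bar(T=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for a linear DDPM noise schedule."""
    betas = np.linspace(beta_start, beta_end, T)
    return np.cumprod(1.0 - betas)

def diffusion_training_example(hq, dvae_decoded, t, alpha_bar, rng):
    """Build one training pair for an epsilon-predicting diffusion upsampler.

    hq:           (C, H, W) full-resolution target.
    dvae_decoded: (C, H, W) blurry DVAE reconstruction, already resized to (H, W).
    Returns (network_input, epsilon_target).
    """
    eps = rng.normal(size=hq.shape)
    # Forward process q(x_t | x_0): scale the clean image and add Gaussian noise.
    x_t = np.sqrt(alpha_bar[t]) * hq + np.sqrt(1.0 - alpha_bar[t]) * eps
    # Condition on the DVAE output by stacking it along the channel axis.
    net_in = np.concatenate([x_t, dvae_decoded], axis=0)
    return net_in, eps

rng = np.random.default_rng(0)
alpha_bar = linear_alpha_bar()
hq = rng.uniform(-1, 1, size=(3, 8, 8))
blurry = rng.uniform(-1, 1, size=(3, 8, 8))
net_in, eps = diffusion_training_example(hq, blurry, 0, alpha_bar, rng)
```

The network would be trained to predict `eps` from `net_in` (and `t`). At sampling time the DVAE output stays fixed while the noisy half is iteratively denoised.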

That is the slowest generative model in existence

… You’re right. Autoregressive transformers are slow to sample from, and so are diffusion networks. This is not a fast technique, and it’ll likely never see use on the edge. However, it is capable of producing extremely compelling generates, better than anything I have seen or heard in the literature. While there is certainly a place in the world for something that works fast, there is also a place for something that truly works well. I think we are only a few years away from ML models that produce generates the average human would consider true art. Voice, music, paintings, etc.: it’s all possible, with enough data, compute and patience.

To that end, I am currently building a text-to-speech triforce model which I suspect will blow every previous TTS model out of the water. It’s going to be slow and ungodly large, but my goal is to build something that you can truly enjoy listening to. Something you might actually use to narrate audiobooks or as a stand-in for voice actors, for example.

Like all large transformer models, this thing is going to be enormously data hungry, so I have spent the last few months building a massive speech dataset pulled from podcasts, audiobooks and YouTube. I hope to write about that soon.

Switched Convolutions – Spatial MoE for Convolutions



I present switched convolutions: a method for scaling the parameter count of convolutions by learning a mapping across the spatial dimension that selects the convolutional kernel used at each location. I show how this method can be implemented with only a small increase in computational complexity. Finally, I discuss applications of switched convolutions and show that applying them to a pretrained VAE results in large gains in performance.

I have open sourced all of my work on switched convolutions. It can be found here.


Despite the growing popularity of autoregressive models based on Transformers for image processing tasks, CNNs remain the most efficient way to perform image processing. 

One disadvantage of CNNs is that it is difficult to scale their parameter count effectively. This is normally done by either increasing the depth of the network or increasing the number of channels in the intermediate states. The problem with scaling either of these numbers is that every added parameter is applied at every spatial location: for a 2-D convolution over an n×n input, each new parameter costs O(n^2) additional operations.

Another option for scaling is to move back to stacked dense layers for processing images. The problem with this approach is that it does not encode the translational invariance bias that gives convolutions their prowess at processing diverse images.

In the language modeling space, an interesting idea was put forward by the Mixture of Experts (MoE) paper: scale the parameter count of a model by “deactivating” most of the parameters for any given input. A second paper, “Switch Transformers”, extends this idea by proposing modifications that allow a MoE model to scale parameters while achieving a near fixed computational cost. The resulting model is termed “sparse”: it uses the inputs to dynamically select which parameters to use for any given computation, and most parameters are unused for any given input.

I aim to apply the MoE paradigm to convolutions.

Switched Convolutions

A switched convolution is a convolution which is composed of b independent kernels. Computing the convolution is similar to a standard convolution, except that each spatial input location uses a single one of the b kernels.

Ideally, the mechanism that selects which kernel to use for each spatial location would be learned. I adapt the sparse routing mechanism from Switch Transformers to achieve this, and propose a novel normalization layer that promotes proportional usage of all kernels. 

This drawing visualizes how a switched conv works:


The selector is a parameterized function responsible for producing a discrete mapping from the input space to the kernel selection space. It basically converts an input image into a set of spatially aligned integers, which select the convolutional kernel used at each image location.

The selector can be attached to any input, but in the experiments discussed in this post, I always attach it to the previous layer in the network. It is worth noting that I have tried using separate networks for generating selector inputs, but they have proven difficult to train and do not produce better results.

Here is a sketch of the internals of a selector:

Switch Processing Network

A NN is embedded within the switch to allow it to segment the image into zones of similar content, which will use the same convolutional kernels. It can be useful to think of this network as analogous to the dense layers applied to the transformer attention inputs.

The switch NN can be implemented using any type or number of NN layers capable of adjusting the input channel count, for example a 1×1 convolution, a lambda layer or even a transformer.
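Here is a minimal NumPy sketch of a selector whose processing network is a single 1×1 convolution, which is the simplest of the options listed above. The 1×1 convolution maps the previous layer's features to b logits per pixel, a softmax turns them into per-pixel kernel probabilities, and an argmax gives the integer kernel map. The function name and shapes are assumptions for illustration. This sketch also leaves out the switch norm, which the next section describes.

```python
import numpy as np

def softmax(x, axis=0):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # subtract max for stability
    return e / e.sum(axis=axis, keepdims=True)

def selector(features, w, bias):
    """Hypothetical selector: 1x1 conv -> softmax over kernels -> argmax.

    features: (C, H, W) output of the previous layer.
    w:        (b, C) weights of the 1x1 "switch processing" convolution.
    bias:     (b,)
    Returns per-pixel kernel probabilities (b, H, W) and the integer map (H, W).
    """
    # A 1x1 convolution is a per-pixel matrix multiply over channels.
    logits = np.einsum("bc,chw->bhw", w, features) + bias[:, None, None]
    probs = softmax(logits, axis=0)
    return probs, probs.argmax(axis=0)

rng = np.random.default_rng(0)
feats = rng.normal(size=(16, 6, 6))
probs, sel = selector(feats, rng.normal(size=(8, 16)), np.zeros(8))
```

Swapping the `einsum` for a lambda layer or a small transformer changes only how `logits` is computed. The softmax and argmax steps stay the same.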

Switch Norm

The objective of the switch norm is to promote load balancing across all kernels. Without a switch norm, switched convolutions tend to collapse into using a single kernel. The switch norm replaces the load balancing loss proposed in the MoE and Switch Transformers papers. I tried a similar load balancing loss with switched convolutions, but found the normalization method superior.

The switch norm works similarly to a batch normalization across the selector dimension, except instead of operating across a batch, it operates across a large accumulated set of outputs, p. Every time the switch norm produces a new output, it adds that output to p. To keep memory in check, the accumulator is implemented as a rotating buffer of size q.

Effectively, this simple norm ensures that the average usage of each kernel across q samples and the entire spatial domain of the input is even. As long as q is big enough, there is still ample room for specialization from the selector, but no one kernel will ever dominate a switched conv. 
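Below is a NumPy sketch of one possible reading of the description above. It is not the reference implementation linked below. It assumes the input is per-pixel kernel probabilities of shape `(b, H, W)` that sum to 1 over the kernel axis. Each call records every kernel's usage relative to the mean into a rotating buffer of size q (the p above). The inputs are then divided by each kernel's average relative usage over the buffer and renormalized. The class name and the `update` flag are hypothetical.

```python
import numpy as np

class SwitchNorm:
    """Sketch of a switch norm over b kernels with a rotating buffer of size q.

    Input: per-pixel kernel probabilities (b, H, W) summing to 1 over axis 0.
    self.buffer plays the role of p: it must be saved with the model.
    """

    def __init__(self, b, q=256):
        self.buffer = np.zeros((q, b))  # recent relative usage of each kernel
        self.index = 0                  # next slot to overwrite
        self.filled = 0                 # number of valid rows

    def __call__(self, probs, update=True):
        if update:
            usage = probs.sum(axis=(1, 2))                  # total usage per kernel
            self.buffer[self.index] = usage / usage.mean()  # relative usage, mean 1
            self.index = (self.index + 1) % len(self.buffer)
            self.filled = min(self.filled + 1, len(self.buffer))
        if self.filled == 0:
            return probs
        # Down-weight kernels that have been over-used across the buffered samples...
        norm = self.buffer[: self.filled].mean(axis=0)
        out = probs / norm[:, None, None]
        # ...then renormalize so the kernel probabilities still sum to 1 per pixel.
        return out / out.sum(axis=0, keepdims=True)

# A selector that always favors kernel 0 (0.7 vs 0.3) is pushed back to balance.
probs = np.stack([np.full((4, 4), 0.7), np.full((4, 4), 0.3)])
sn = SwitchNorm(b=2, q=4)
balanced = sn(probs)
```

At inference you would call it with `update=False` on a restored `buffer`, which matches the point below that a saved p is needed.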

I used a value of q=256 for most of my experiments. Future work should explore adjusting this hyperparameter as I did not tinker with it much.

It is important to note that the rotating buffer p becomes a parameter for any network using switch normalization. Even though gradients do not flow to it, it develops a characteristic signal over time. Attempting to perform inference without using a saved p always produces poor results.

A reference implementation for the switch norm can be found here.

Differentiable Argmax or Hard Routing

The argmax function, which returns the integer index of the greatest element along the specified axis, is not normally a differentiable function. In implementing switched convolutions, I produce a “differentiable argmax” function.

The forward pass behaves identically to the standard numpy argmax() function. The numeric value of the input that was fed into diff_argmax is recorded.

In the backwards pass, the gradients are first divided by the input recorded by the forward pass. Then, the gradient is set to zero for all but the max element along the specified axis.

The gradients coming out of diff_argmax are a bit odd: they are exceptionally sparse and you might think that entire kernels would “die” off. This is what the switch norm prevents, however.
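The forward and backward rules above can be sketched in NumPy with a manual backward pass, rather than as a PyTorch autograd function. One assumption: the forward pass returns the argmax as a one-hot mask along `axis`. This carries the same information as the integer index, but it has the input's shape, so a gradient can be routed back element by element. The class name `DiffArgmax` is hypothetical.

```python
import numpy as np

class DiffArgmax:
    """Sketch of the differentiable argmax described above (manual backward)."""

    def __init__(self, axis=0):
        self.axis = axis

    def forward(self, x):
        self.x = x  # record the input values for the backward pass
        idx = np.expand_dims(x.argmax(axis=self.axis), self.axis)
        self.mask = np.zeros_like(x, dtype=float)
        np.put_along_axis(self.mask, idx, 1.0, axis=self.axis)  # one-hot argmax
        return self.mask

    def backward(self, grad):
        # Divide the incoming gradient by the recorded input, but only at the
        # max element; every other element gets a zero gradient.
        out = np.zeros_like(grad, dtype=float)
        np.divide(grad, self.x, out=out, where=self.mask.astype(bool))
        return out

x = np.array([[0.2, 0.5], [0.8, 0.5]])  # e.g. softmax outputs, kernels along axis 0
op = DiffArgmax(axis=0)
mask = op.forward(x)
grad_in = op.backward(np.ones_like(x))
```

Using `where=` avoids dividing by near-zero probabilities at the non-selected elements, which would otherwise produce `inf * 0 = nan`.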

A reference diff_argmax implementation for PyTorch can be found here.

Switched Convolution

The actual switched convolution iterates across each spatial location and uses the output of the selector to determine which convolutional kernel to apply at that location.

Naive Implementation

A simple way to compute the switched convolution output is to perform b standard convolutions, one for each kernel, then multiply each result by the matching channel of the selector's one-hot output and sum the products.

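Such a method can even be used without hard routing, by weighting each convolution's output with the selector's soft probabilities instead of a one-hot mask. In my experiments this does not perform much better than hard routing.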
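The naive method can be sketched in NumPy as follows. It runs all b convolutions and then masks and sums them with per-pixel routing weights. Passing a one-hot map gives hard routing, and passing softmax probabilities gives the soft variant. The helper names and the zero-padded "same" convolution are illustrative choices, not the linked implementation.

```python
import numpy as np

def conv2d_same(x, w):
    """Zero-padded 'same' 2-D cross-correlation.
    x: (C_in, H, W); w: (C_out, C_in, k, k) with odd k."""
    c_out, _, k, _ = w.shape
    pad = k // 2
    H, W = x.shape[1:]
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)))
    out = np.zeros((c_out, H, W))
    for i in range(k):
        for j in range(k):
            # Add the contribution of kernel tap (i, j) at every output location.
            out += np.einsum("oc,chw->ohw", w[:, :, i, j], xp[:, i:i + H, j:j + W])
    return out

def switched_conv_naive(x, kernels, routing):
    """kernels: (b, C_out, C_in, k, k).
    routing: (b, H, W) per-pixel weights; one-hot for hard routing,
    softmax probabilities for soft routing."""
    outs = np.stack([conv2d_same(x, kern) for kern in kernels])  # (b, C_out, H, W)
    return (outs * routing[:, None]).sum(axis=0)

def one_hot_routing(selection, b):
    """(H, W) integer kernel map -> (b, H, W) one-hot routing weights."""
    return np.eye(b)[selection].transpose(2, 0, 1)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 5))
kernels = rng.normal(size=(3, 4, 2, 3, 3))
y = switched_conv_naive(x, kernels, one_hot_routing(np.ones((5, 5), dtype=int), 3))
```

This costs b full convolutions and keeps all b outputs alive for backprop. Those are the compute and memory overheads a custom kernel avoids.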

CUDA implementation

It is worth noting that since only one kernel is active per spatial location, the switched convolution only needs to calculate one dot product per spatial location – exactly the same as a standard convolution.

In contrast to Switch Transformers, which require distributed training processes to start seeing a scaling advantage, switched convolutions can be optimized on a single GPU. However, the larger kernel size and pseudo-random access into the kernel have a significant effect on how quickly a switched convolution can run.

A naive CUDA kernel that implements this can be found here. This custom kernel could use a significant amount of optimization (for example, it does not use tensor cores), but it currently runs at ~15% of the speed of a normal convolution when accounting for both the forward and backward passes with b=8. The naive implementation performs b full convolutions, so it runs at roughly 1/b of a normal convolution's speed (12.5% at b=8). The custom kernel is therefore net-faster at b=8, and its advantage grows linearly with b from there. It also uses significantly less memory because it saves considerably fewer intermediate tensors for backprop.
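To show what "one dot product per spatial location" means, here is a NumPy reference sketch. It is not the CUDA kernel and is not meant to be fast. Every output pixel gathers exactly one of the b kernels and multiplies it with its own input patch (im2col style), so the arithmetic per pixel matches a standard convolution. The helper names are hypothetical. Unlike a real kernel, this sketch materializes a gathered copy of the kernels per pixel.

```python
import numpy as np

def im2col_same(x, k):
    """(C_in, H, W) -> (H*W, C_in*k*k) zero-padded patches, one row per output pixel."""
    c, H, W = x.shape
    pad = k // 2
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)))
    cols = np.empty((H, W, c, k, k))
    for i in range(k):
        for j in range(k):
            cols[:, :, :, i, j] = xp[:, i:i + H, j:j + W].transpose(1, 2, 0)
    return cols.reshape(H * W, c * k * k)

def switched_conv_gather(x, kernels, selection):
    """kernels: (b, C_out, C_in, k, k); selection: (H, W) ints in [0, b)."""
    b, c_out, c_in, k, _ = kernels.shape
    H, W = selection.shape
    patches = im2col_same(x, k)                            # (H*W, C_in*k*k)
    flat_kernels = kernels.reshape(b, c_out, c_in * k * k)
    chosen = flat_kernels[selection.reshape(-1)]           # one kernel per pixel
    out = np.einsum("poc,pc->po", chosen, patches)         # one dot product per pixel
    return out.T.reshape(c_out, H, W)

x = np.ones((1, 3, 3))
kernels = np.stack([np.ones((1, 1, 3, 3)), 2 * np.ones((1, 1, 3, 3))])
sel = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
out = switched_conv_gather(x, kernels, sel)
```

With an all-ones input, pixels routed to kernel 0 sum their zero-padded 3×3 neighborhood, and pixels routed to kernel 1 sum twice that.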


Training models with switched convolutions works best with large batch sizes. This makes sense: switched convolutions are very sparse and their parameters will only accrue meaningful gradients across a large set of examples. For example, if b=8, each parameter in the switched conv is generally only receiving about 1/8th of the gradient signal.

While it is possible to train a model incorporating switched convolutions from scratch, it is tedious since the signals that the selector function feeds off of are exceptionally noisy in the early stages of training.

For this reason, I use a different, staged approach to training models with switched convolutions: first, train a standard CNN model. After this has converged, I convert a subset of the convolutions in that model to switched convolutions and continue training. This has several advantages:

  1. First stage training can be fast: smaller batch sizes can be used alongside simpler computations.
  2. Since the selector functions are only brought online in the second stage, they start training on fairly “mature” latents.

Converting a standard convolution to a switched convolution is simple: copy the kernel parameters across the switch breadth (b) and add a selector. Once you start training, the kernel parameters across the breadth dimension will naturally diverge and specialize as directed by the selector.
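The weight-copying step amounts to one `np.repeat`, sketched below with a hypothetical helper name. Because all b copies start identical, the converted layer initially computes exactly what the original convolution did, whichever kernel the freshly added selector picks. Training can therefore resume from the converged stage-one model without a jump in loss. Biases, if the layer has them, are assumed to be copied the same way.

```python
import numpy as np

def to_switched(weight, b):
    """Hypothetical helper: replicate a trained conv weight (C_out, C_in, k, k)
    across the switch breadth, giving independent copies of shape (b, C_out, C_in, k, k)."""
    return np.repeat(weight[None], b, axis=0)

w = np.random.default_rng(0).normal(size=(4, 3, 3, 3))
switched = to_switched(w, 8)
```

`np.repeat` returns a new array rather than a view, so each copy can later be updated independently.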

Uses & Demonstration

In experimenting with switched convolutions, I have seen the most success in applying them to generative networks. This is intuitive: they decouple what a convolution in a generative network expresses from the network's understanding of what it is actually working on. For example, a selector can learn to apply different kernels to “draw” hair, eyes, and skin, which all have different textures.


To demonstrate how effective switched convolutions are at improving network performance, I apply them to the stage 1 VQVAE network. I first train a vanilla stage 1 VQVAE to convergence:

I then convert the network by replacing 4 convolutions in both the encoder and decoder with switched convolutions that use b=8 and a selector composed of a lambda layer followed by a 1×1 convolution:

The result is a 20% improvement in loss, accounting for both the pixel-MSE reconstruction loss and the commitment loss.

Other Tests

It is worth noting that VQVAE is likely underparameterized for the data I used in this experiment. Inserting switched convolutions in a similar manner into other networks did not show as much success. Here are some notable things I tried:

  1. Classification networks: inserted switched convs in the upper (high resolution) layers of ResNet-50. Performance slightly degraded.
  2. Segmentation networks: inserted switched convs in the high resolution backbone layers. Performance did not change.
  3. StyleGAN2: inserted switched convs in the generator. Performance degraded. (This is a special case because of the way conv weights interact with the mapping network.)
  4. Super-resolution: a 5-layer deep switched conv network of breadth 8 was found to have performance competitive with the 23-block RRDB network from the ESRGAN paper.

Visualizing the Selector Outputs

It is trivial to output the maps produced by the selectors as a colormap. This can be instructive as it shows how the network learns to partition the images. Here are some example selector maps from the high resolution decoder selector from the VQVAE I trained:

As you can see, these selector maps generally seem to resemble edge detectors in function. They also seem to perform shading in generative networks, for example the arms in the third image.
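Producing such maps takes only a few lines. The NumPy sketch below gives each of the b kernels a random but fixed color and indexes the palette with the integer selector output, yielding an RGB image you can save or display. The function name and the random palette are illustrative choices. A perceptual colormap from a plotting library would work just as well.

```python
import numpy as np

def selector_map_to_rgb(selection, b, seed=0):
    """Color an (H, W) integer kernel map with one fixed color per kernel.
    Returns an (H, W, 3) uint8 image."""
    palette = np.random.default_rng(seed).integers(0, 256, size=(b, 3)).astype(np.uint8)
    return palette[selection]

img = selector_map_to_rgb(np.array([[0, 1], [1, 0]]), b=2)
```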

Future Work

At this point, I don’t believe switched convolutions have demonstrated enough value to support continued research as I have currently formulated them. That being said, I still think the concept has value and I would like to revisit them in the future.

In particular, I am not satisfied with the way the selectors operate. This is purely intuition, but I believe the power of switched convs would be best expressed when the semantics of the image are separated from the texture. That is to say, I would have liked regions of the image that exhibit different textures (e.g. hair, eyes, skin, background) to be selected differently.

One project I am currently pondering is working on an unsupervised auto-segmenter. Something in the vein of Pixel-Level Contrastive Learning. If I could train a network that produces useful semantic latents at the per-pixel level, it could likely be applied at the input of the selector in switched convolutions to great effect.

SRGANs and Batch Size

Batch size is one of the oldest hyperparameters in SGD, but it doesn’t get enough attention for super-resolution GANs.

The problem starts with the fact that most SR algorithms are notorious GPU memory hogs. This is because they generally operate on high-dimensional images at high convolutional filter counts.

To put this in context, the final intermediate tensor of the classic RRDB model has a shape of (bs×64×128×128), or over 33M floats at a batch size of 32. This one tensor consumes more than 10% of the model's total memory usage!
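The arithmetic behind that figure is shown below, assuming FP32 storage (4 bytes per element) and decimal megabytes. The helper name is hypothetical.

```python
def tensor_megabytes(shape, bytes_per_element=4):
    """Element count and size in MB (10**6 bytes) of a dense tensor.
    Use 4 bytes per element for FP32 and 2 for FP16."""
    n = 1
    for d in shape:
        n *= d
    return n, n * bytes_per_element / 1e6

count, mb = tensor_megabytes((32, 64, 128, 128))
# count is 33,554,432 elements: about 134 MB in FP32, or about 67 MB in FP16.
```

Backprop keeps many activations of this size alive at once, which is why these models run out of memory so quickly.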

To cope with this high memory usage, SR papers often recommend training with minuscule batch sizes in the range of 4-16 samples per batch. This is wholly inadequate, as I will discuss in this article.

Larger batches are (almost) always better

Training SR models with larger batch sizes has produced an immediate and lasting improvement in performance for every SR model I have trained thus far. I discovered this on a whim with a custom model I was developing, but found out later that it applies to RRDB and SRResNet as well. Here is an example plot:

Perceptual loss of two identical models trained on different batch sizes. Blue line is batch-size=16. Red line is batch-size=64. Blue anomaly is caused by an overflow during 16-bit training.

The plot above conveys my experience in general: a larger batch size does not just accelerate training, it permanently improves it. This difference is visible in the resulting images as well. Models trained on larger batch sizes exhibit fewer artifacts and more coherent fine image structures (e.g. eyes, hair, ears, fingers).

Here is an interesting anecdote from a recent experience I had with this: I am training an ESRGAN model and decided to move from training on 128×128 HQ images to 256×256 ones. To accomplish this, I reused the same model and added a layer to the discriminator. I decided to speed things up by halving the batch size. After nearly a week of training and many tens of thousands of iterations, the results were worse than what I had started with. After doubling the batch size again, the model finally began to visually improve.

Recommendations for larger batches

I’ve done some comparisons between the same model with different batch sizes. The performance improvement that comes with increasing batch size is nearly linear for batch sizes between 16 and 128. I have not experimented heavily past 128 due to my own computational budget limitations. Any model I am serious about these days gets a batch size of 128, though.

Accommodating Large Batches

As mentioned earlier, the authors of SR papers have good reason to recommend smaller batch sizes: the RRDB network proposed in ESRGAN consumes about 10GB of VRAM with a batch size of 16!

As I’ve worked on more SR topics, I’ve come up with several workarounds that can help you scale your batch sizes up.

  1. Gradient Accumulation – You can easily synthesize arbitrarily large batch sizes using a technique called gradient accumulation. This simply involves summing the gradients from multiple backwards passes into your parameters' gradient buffers before performing a single optimizer step. This can affect models that use batch statistics, but it shouldn't matter for SRGAN models because they shouldn't be using batch normalization. Gradient accumulation is controlled in DLAS using the mega_batch_factor configuration parameter.
  2. Gradient Checkpointing – This is an unfortunately named and underutilized feature of PyTorch that allows you to prune most of the intermediate tensors your model produces from GPU memory. This comes at the cost of having to recompute these intermediate tensors in the backwards pass. Trust me: this is much faster than you think it is. The performance penalty of gradient checkpointing is often negligible simply because it allows you to fully utilize your GPU where you would otherwise only be partially using it. Gradient checkpointing is enabled in DLAS using the checkpointing_enabled configuration parameter.
  3. Mixed Precision – This is fairly old hat by now, but training in FP16 or in mixed precision mode results in far lower memory usage. It can be somewhat of a pain, though, as the FP16 overflow in the plot above shows. PyTorch has recently made this a first-class feature.
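To see why gradient accumulation is a faithful stand-in for a large batch, here is a framework-agnostic NumPy sketch on a linear least-squares model. With a mean-reduced loss and equal-sized micro-batches, summing each micro-batch gradient scaled by 1/mega_batch_factor gives exactly the full-batch gradient. The parameter name mirrors the DLAS option, but the code is a hypothetical illustration, not DLAS.

```python
import numpy as np

def mse_grad(w, X, y):
    """Gradient of mean((X @ w - y) ** 2) with respect to w."""
    return 2.0 * X.T @ (X @ w - y) / len(y)

def accumulated_grad(w, X, y, mega_batch_factor):
    """Split the batch into equal micro-batches and accumulate their scaled gradients.
    Matches the large-batch gradient for a mean-reduced loss, provided the model
    has no batch-dependent layers such as batch norm."""
    grad = np.zeros_like(w)
    chunks = zip(np.array_split(X, mega_batch_factor), np.array_split(y, mega_batch_factor))
    for Xm, ym in chunks:
        grad += mse_grad(w, Xm, ym) / mega_batch_factor  # one "backward pass" per chunk
    return grad

rng = np.random.default_rng(0)
X, y, w = rng.normal(size=(128, 5)), rng.normal(size=128), rng.normal(size=5)
```

In PyTorch the same idea amounts to dividing each micro-batch loss by the number of micro-batches, calling `backward()` on each, and stepping the optimizer once at the end.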

(By the way, all of these are implemented in DLAS – my generative network trainer. Check that out if you are interested in trying these out without spending many hours tweaking knobs.)

Training SRFlow in DLAS (and why you shouldn’t)

SRFlow is a really neat adaptation of normalizing flows for the purpose of image super-resolution. It is particularly compelling because it can potentially train SR networks with only a single negative-log-likelihood loss.

Thanks to a reference implementation from the authors of the paper, I was able to bring a trainable SRFlow network into DLAS. I’ve had some fun playing around with the models I have trained with this architecture, but I’ve also had some problems that I want to document here.

First of all – the good

SRFlow does work. It produces images that are perceptually better than those from PSNR-trained models and lack the artifacts of GAN-trained ones. For this reason, I think this is a very promising research direction, especially if we can figure out more effective image processing operations that have tractable determinants.

Before I dig into the “bad”, I want to provide a “pressure relief” for the opinions I express here. These are not simple networks to train or understand. It is very likely that I have done something wrong in my experiments. Everything I state and do is worth being double-checked (and a lot of it is trivial to do so for those who are actually interested).

The Bad

Model Size

SRFlow starts with a standard RRDB backbone and tacks on a normalizing flow network. This comes at significant computational cost. RRDB is already no lightweight, and the normalizing flow net is much, much worse. These networks have a step time about 4x what I was seeing with ESRGAN networks. It is worth noting that reported GPU utilization while training SRFlow networks is far lower than I am used to, averaging about 50%. I believe this is due to inefficiencies in the model code (which I took from the authors). I was tempted to make improvements here, but preferred to keep backwards compatibility so I could use the authors’ pretrained model.

Aside from training slowly, SRFlow has a far higher memory burden. On my RTX3090 with 24GB of VRAM, I ran out of memory when trying to perform inference on images about 1000×1000px in size (on the HQ end).


While SRFlow generally produces aesthetically pleasing results, every trained model I have used generates subtle blocky artifacts. These artifacts are most visible in uniform textures. Here are two good examples of what I am talking about:

Examples of SRFlow artifacts. Images artificially blown up 250% for better visualization.

I have encountered these artifacts in other generative models I have trained in the past. They result from 1×1 convolutions, which cannot properly integrate small differences in latent representation that neighboring pixels might contain. Unfortunately, SRFlow's invertible convolutions are all 1×1, because that is the kind of convolution whose inverse and determinant are cheap to compute.
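For intuition, here is a NumPy sketch of the Glow-style invertible 1×1 convolution that SRFlow's flow steps build on. A single C×C matrix mixes channels at every pixel independently, which makes the inverse and log-determinant cheap. The log-determinant is H·W·log|det W|. The same per-pixel independence is what prevents these layers from smoothing out differences between neighboring pixels. The function names are hypothetical.

```python
import numpy as np

def invertible_1x1(x, weight):
    """Apply a C x C channel-mixing matrix at every pixel of x (C, H, W).
    Returns the output and the log|det Jacobian| = H * W * log|det weight|."""
    _, H, W = x.shape
    y = np.einsum("ij,jhw->ihw", weight, x)  # no interaction between pixels
    logdet = H * W * np.linalg.slogdet(weight)[1]
    return y, logdet

def invertible_1x1_inverse(y, weight):
    return np.einsum("ij,jhw->ihw", np.linalg.inv(weight), y)

rng = np.random.default_rng(0)
weight, _ = np.linalg.qr(rng.normal(size=(4, 4)))  # random rotation, |det| = 1
x = rng.normal(size=(4, 6, 6))
y, logdet = invertible_1x1(x, weight)
```

Perturbing a single input pixel changes only the matching output pixel, which is a concrete way to see why neighboring pixels cannot be reconciled by this layer alone.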

Technically speaking, there is no reason why we could not eliminate these artifacts using additional “consistency” filters trained on the SRFlow output. I think it is worth knowing about them, though, since they point at a deeper problem with the architecture.

The SRFlow architecture currently has poor convergence

This one is a bit more complicated. You first need to understand the objective function of normalizing flows: to map a data distribution (in this case, HQ images conditioned on their LQ counterparts) to a latent distribution indistinguishable from Gaussian noise.
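For reference, this is the standard change-of-variables objective behind conditional flows such as SRFlow. The notation is mine: x is the HQ image, y the LQ image, and f_θ the invertible network. Training minimizes −log p_θ(x | y). A perfectly trained flow therefore maps HQ images to latents z that are indistinguishable from N(0, I) samples. Sampling runs x = f_θ^{-1}(z; y) with z drawn from a Gaussian, so setting z = 0 gives the "mean" output used below.

```latex
z = f_\theta(x; y), \qquad
\log p_\theta(x \mid y) = \log \mathcal{N}\!\left(z;\, 0, I\right)
  + \log\left|\det \frac{\partial f_\theta(x; y)}{\partial x}\right|
```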

To show why I think that SRFlow does a poor job at this, I will use the pretrained 8x face upsampler model provided by the authors. To demonstrate the problems with this model, I pulled a random face from the FFHQ dataset and downsampled it 8x:

I then went to the Jupyter notebook found in the authors’ repo and did a few upsample tests with the CelebA_8x model. Here is the best result:
Note that it is missing a lot of high frequency details and has some of the blocky artifacts discussed earlier.

I then converted that same model into my repo, and ran a script I have been using to play with these models. One thing I can do with this script is generate the “mean” face for any LR input (simple really, you just feed a tensor full of zeros to the gaussian input). Here is the output from that:

So what you are seeing here is what the model thinks the “most likely” HQ image is for the given LQ input. For reference, here is the image difference between the original HQ and the mean:

Note that the mean is missing a lot of the high-frequency details. My original suspicion was that the network is encoding these details into the Z vector that it is supposed to be converting to a Gaussian distribution. To test this, I plotted the std(dim=1) and mean(dim=1) of the Z vectors at the end of the network (dim 1 is the channel/filter dimension):

In a well-trained normalizing flow, these would be indistinguishable from noise. As you can see, they are not: the Z vector contains a ton of structural information about the underlying HQ image. This tells me that the network is unable to properly capture these high frequency details and map them to believable Gaussian noise.
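The same test can be expressed numerically. The NumPy sketch below computes the per-pixel channel mean and std maps of a latent z and applies a crude check. For i.i.d. standard normal channels, the mean map hovers around 0 with a spread of about 1/sqrt(C), and the std map hovers around 1. Image structure leaking into z shows up as a much larger spread in the mean map. The helper names, the tolerance and the 3/sqrt(C) threshold are arbitrary assumptions, not a formal statistical test.

```python
import numpy as np

def z_channel_stats(z):
    """Per-pixel mean and std over the channel axis of a latent z shaped (C, H, W).
    (With a leading batch axis this is the mean(dim=1)/std(dim=1) from the text.)"""
    return z.mean(axis=0), z.std(axis=0)

def looks_gaussian(z, tol=0.2):
    """Crude check that z resembles N(0, I) samples with no spatial structure."""
    C = z.shape[0]
    mean_map, std_map = z_channel_stats(z)
    return bool(abs(mean_map.mean()) < tol
                and abs(std_map.mean() - 1.0) < tol
                and mean_map.std() < 3.0 / np.sqrt(C))

rng = np.random.default_rng(0)
noise = rng.normal(size=(64, 32, 32))
pattern = 2.0 * np.sin(np.linspace(0, 2 * np.pi, 32))[:, None] * np.ones((1, 32))
structured = noise + pattern  # the same image structure leaked into every channel
```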

This is, in general, my experience with SRFlow. I presented one image above, but the same behavior is exhibited in pretty much all inputs I have tested with, and it extends to every other SRFlow network I have trained or worked with. The best I can ever get out of the network is images with Z=0, which produces appealing, “smoothed” images that beat out PSNR losses, but they miss all of the high-frequency details that a true SR algorithm should be creating. No amount of noise at the Z-input produces these details: the network simply does not learn how to convert these high frequency details into true Gaussian noise.

It is worth noting that I brought this up with the authors. They gave this response to my comments, which provides some reasons why I may be seeing these issues. I can buy into these reasons, but they point to limitations with SRFlow that render it much less useful than other types of SR networks.


I think the idea behind SRFlow has some real merit. I hope that the authors or others continue this line of research and find architectures that do a better job converging. For the time being, however, I will continue working with GANs for super-resolution.

Translational Regularization for Image Super Resolution


Modern image super-resolution techniques generally use multiple losses when training. Many techniques use a GAN loss to aid in producing high-frequency details. This GAN loss comes at a cost of producing high-frequency artifacts and distortions on the source image. In this post, I propose a simple regularization method for reducing those artifacts in any SRGAN model.

Background on SR Losses

Most SR models use composite losses to achieve realistic outputs. A pixel-wise loss and/or a perceptual loss coerces the generator to produce images that look structurally similar to the input low-resolution image. With only these losses, the network converges on producing high-resolution images that are essentially the numerical mean of all the training images consistent with the low-resolution input.

To humans, this results in an output image that is blurred and overly smoothed. High-frequency details like pock-marks, individual hair strands, fine scratches, etc are not represented in the high-resolution images. These can be appealing to the eye, but they are also clearly artificial.

Extreme examples of images upsampled using only pixel losses. The network learns how to form sharp edges, but completely fails at producing high frequency details.

To improve on this situation, adding a GAN loss was proposed in the SRGAN paper from 2017. This loss is effective in bringing back many high-frequency details, but comes at a cost: the generator eventually begins to learn to “trick” the discriminator by adding high-frequency artifacts in the image.

Examples of GAN artifacts. They often appear in areas of high-frequency details like eyes and hair. For the hair, notice the “strands” the generator is applying that go against the actual flow of the hair.

These artifacts range from mild to extremely bothersome. I have seen them do anything from simply removing eyebrows from faces to distorting hands or feet into giant blobs, even when the structural information for those features was present in the low-resolution images. So while images generated by GAN SR networks are generally more realistic than their perceptual counterparts, they are even less suited for general use because their failure modes are so severe.

Existing Solutions to SRGAN Artifacts

There are many proposed solutions to GAN artifacts. To name a few:

SPSR trains two separate networks: one built on top of images that have been fed through an edge detector and one on the raw image. The logic is that the network is induced to preserve the structure of the low-resolution image throughout the upsampling process.

TecoGAN (and other video SR architectures) improve on this by adding temporal coherence losses, which force the generator to be self-consistent across multiple frames.

GLEAN uses a pretrained generative network trained with only a GAN loss to guide the SRGAN process towards realistic high-frequency textures.

Posing the loss in the frequency domain or after a wavelet transform has also been explored as a solution to the problem.

Of these, I have found that the TecoGAN approach leads to the most impressive reduction in GAN artifacts. It is particularly intriguing because even though the intention of the paper was to improve temporal consistency, the authors also achieved superior single-image super-resolution.

Exploring Self Consistency Losses

The main divergence between SRGAN and TecoGAN is the ping-pong loss proposed in the TecoGAN paper. This loss is derived by feeding a series of warped video frames recursively forward and then backward through the generative network. The high-resolution outputs produced for the same video frame on the forward and backward passes are compared to each other with a simple pixel loss. The idea is that artifacts introduced by the network will necessarily grow during the “ping-pong” process, causing inconsistent outputs which can then be trained away.

This type of self-consistency loss is more powerful than the standard L1/L2 loss against a fixed target because the network can learn to be self-consistent from the gradients of the feedforward passes that produced both images. For example, the network can learn to fix the problem of growing artifacts by suppressing those artifacts early on (in the first pass of the network) or by suppressing their growth through a better statistical understanding of the underlying natural image. Either way, downstream quality improves.

Self Consistency Losses for Single Image Super Resolution

The same kind of self-consistency loss can be applied to single images as well. The basic method is to take an HQ image and derive two LQ images that share some region of that HQ image. Then, feed these LQ images through your generator and compare the same regions in the generated results.

There are actually many ways to do this. Basically any image augmentation you might read about in DiffAug or similar work will do. For image SR, you should probably steer away from color shifts or blurs, but translation, rotation and zoom are great options.

Having tried all three, I have had particular success with translation. The following simple algorithm has had a noticeable effect on image quality for all of my SR networks:

Example crops from top-left and bottom-right corners of an HQ image.
  1. For any given HQ image, crop patches from each corner of the image; these are downsampled into the LQ inputs. For example, from a 256px image, extract four 224px patches.
  2. Randomly pick one corner crop and feed it forward through the network for the normal losses (e.g. L1, perceptual, GAN).
  3. From the output of (2), crop out the region that is shared by all corner crops.
  4. Randomly pick a second corner crop, feed it forward, and crop out the same shared region from its output.
  5. Perform an L1 loss between the results from (4) and (3).
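Here is a rough PyTorch sketch of the five steps above. It assumes an NCHW batch of HQ images, plain average-pool downsampling as the degradation, and corner offsets (e.g. 256 - 224 = 32) that divide evenly by the SR scale so LQ and HQ pixels stay aligned. The function name and its return convention are my own for illustration, not DLAS code.

```python
import random

import torch
import torch.nn as nn
import torch.nn.functional as F


def translational_consistency_loss(generator, hq, crop, scale):
    """Corner-crop self-consistency loss for single-image SR (illustrative sketch).

    hq: (B, C, H, W) batch of HQ images. (H - crop) and (W - crop) should be
    divisible by `scale` so LQ and HQ pixels stay aligned.
    Returns (sr, hq_crop, loss): the first crop's SR output and HQ target, for
    the normal losses, plus the consistency loss.
    """
    _, _, h, w = hq.shape
    # Region shared by all four corner crops, in full-image coordinates.
    y0, y1, x0, x1 = h - crop, crop, w - crop, crop
    assert y1 > y0 and x1 > x0, "corner crops must overlap"
    corners = [(0, 0), (0, w - crop), (h - crop, 0), (h - crop, w - crop)]

    def run(top, left):
        hq_crop = hq[:, :, top:top + crop, left:left + crop]
        lq_crop = F.avg_pool2d(hq_crop, scale)  # assumed degradation
        sr = generator(lq_crop)
        # Shared region, translated into this crop's coordinate frame.
        shared = sr[:, :, y0 - top:y1 - top, x0 - left:x1 - left]
        return sr, hq_crop, shared

    # Steps 2 and 4: two distinct, randomly chosen corners.
    first, second = random.sample(corners, 2)
    sr, hq_crop, shared_first = run(*first)   # steps 2 and 3
    _, _, shared_second = run(*second)        # step 4
    # Step 5: L1 between the two predictions of the same pixels.
    return sr, hq_crop, F.l1_loss(shared_second, shared_first)


random.seed(0)
gen = nn.Upsample(scale_factor=2, mode="nearest")  # translation-equivariant stand-in
sr, target, loss = translational_consistency_loss(gen, torch.rand(2, 3, 16, 16), crop=12, scale=2)
```

A translation-equivariant stand-in like nearest-neighbor upsampling gives (near-)zero consistency loss, which makes a handy sanity check; a real SR network will not, and that gap is what the loss trains away.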

This algorithm could be improved further by selecting crops that don’t necessarily start in the image corners, but I am not sure the improvement would justify the additional complexity. Shear and zoom can also be added, but they add complexity as well (particularly around pixel alignment). I have tried zoom losses and they did not produce significant performance gains.

Example validation performance gains on an L1 perceptual loss computed with a VGG-16 network, for two networks. The red line represents a baseline network without the translational consistency loss. The blue line restarts training of the baseline network at step 30k with the translational consistency loss added. Performance gains are ~1-2%. Visual quality gains are much larger because of fewer artifacts.

One note about this loss: it should not be applied to an SR network until after it begins to produce coherent images. Applying the loss from the start of training results in networks that never converge because their initial outputs are so noisy that the translational loss dominates the landscape. The TecoGAN authors noted the same result with their ping-pong loss, as an example.

DLAS Super Resolution

Deep Learning Art School (DLAS)

At the beginning of this year, I started working on image super-resolution on a whim: could I update some old analog-TV quality videos I have archived away to look more like modern videos? This has turned out to be a rabbit hole far deeper than I could have imagined.

It started with learning about modern image super-resolution techniques. To this end, I began with a popular GitHub repo called ‘mmsr’. This repo no longer exists and has since been absorbed into mmediting, but at the time it was a very well-written ML trainer library containing all of the components needed to set up an SR training pipeline.

As my SR (and GAN) journey continued, I often needed to make sweeping alterations to the trainer code. This frustrated me, because it invalidated old experiments or added a ton of labor (and messy code) to keep them relevant. It was doubly frustrating because MMSR was, at its core, designed to be configuration-driven. Being a “good” software engineer before being an ML practitioner, I started coming up with a plan to massively overhaul MMSR.

Deep Learning Art School is the manifestation of that plan. With it, I have wholly embraced a configuration-driven ML training pipeline that is targeted at research and experimentation. It was originally designed with image super-resolution in mind, but I have been able to easily build configurations that train everything from pure GANs to object detectors to image recognizers with very small changes to the plugin API. I now edit the core training code so infrequently that I am considering breaking it off into its own repo (or turning it into a Python tool). This has been a design success beyond my wildest dreams.

The repo still has some rough edges, to be sure. Most of that is due to two things:

  1. In the original design, I never imagined I would be using it outside of image SR. There are many unnecessary hard-coded points that make this assumption and make other workflows inconvenient.
  2. I did not bother to write tests for the original implementation. I just never thought it would be as useful as it turned out to be.

In the next couple of months, I plan to slowly chip away at these problems. This tool has been incredible for me as a way to bootstrap my way into implementing pretty much any image-related paper or idea I can come up with, and I want to share it with the world.

Expect to hear more from me about this repo going forwards, but here are some reference implementations of SOTA papers that might whet your appetite for what DLAS is and what it can do:

  • SRFlow implementation – I pulled in the model source code from the author’s repo, made a few minor changes, and was able to train it!
  • GLEAN implementation – I hand-coded this one based on information from the paper and successfully reproduced some of what they accomplished in the paper (haven’t had a chance to test everything yet).
  • ESRGAN implementation – Not new by any measure, but shows what the DLAS way of accomplishing this “classic” method looks like.

Accelerated Differentiable Image Warping in Pytorch

Computing optical flow is an important part of video understanding. There are many ways to train a model to compute this, but one of the more compelling methods is to:

  1. Feed a model an image pair
  2. Have it predict optical flow
  3. Apply that optical flow to the original image
  4. Compute a pixel-wise loss against the second image.

In order to use this algorithm, however, you need a differentiable way to do step (3), typically called an “image warp”. Tensorflow has just such an operation in contrib, but to my knowledge Pytorch does not.

After digging around for a while today, I found what I needed in one of NVIDIA’s open-source repositories.

In this repository, the author has implemented a new CUDA primitive called “resample2d“. Although there isn’t any documentation on this operation, it is exactly what is needed to compute an image warp given an optical flow vector.

Suppose you have an image and a .flo optical flow file (these are available from several sources). Here is how you would use this operation to compute the secondary image:

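import torch
import torchvision
from torchvision.io import read_image
from utils.flow_utils import readFlow
from networks.resample2d_package.resample2d import Resample2d

# Load the first image as a (1, C, H, W) float tensor in [0, 1] on the GPU.
im1 = (read_image('image.png').float() / 255).unsqueeze(0).to('cuda')
# readFlow returns an (H, W, 2) array; Resample2d wants a (1, 2, H, W) tensor.
flow = torch.tensor(readFlow('flow.flo')).permute(2, 0, 1).unsqueeze(0).contiguous().to('cuda')
resample = Resample2d()
# Warp the first image along the flow to synthesize the second image.
synth = resample(im1, flow)
torchvision.utils.save_image(synth, "flowed_image.png")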

You’ll need to import the code from the above linked repository to run this. Note that resample2d must be performed on the GPU. It does not work on CPU and just returns all zeros.
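If you want something that also runs on the CPU, a differentiable warp can be sketched on top of PyTorch’s built-in torch.nn.functional.grid_sample, which is differentiable with respect to both the image and the sampling grid. This is a different implementation from Resample2d and I haven’t checked that it matches its output exactly. It assumes the flow is given in pixels as (dx, dy) on the output image’s grid (backward warping). photometric_step strings together the four training steps listed earlier, with flow_net standing in for any hypothetical model that maps a concatenated image pair to a 2-channel flow field.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def warp(image, flow):
    """Backward-warp image (B, C, H, W) by flow (B, 2, H, W), given in pixels.

    out[:, :, y, x] samples image at (x + flow[:, 0, y, x], y + flow[:, 1, y, x]).
    """
    _, _, h, w = image.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=image.dtype, device=image.device),
        torch.arange(w, dtype=image.dtype, device=image.device),
        indexing="ij",
    )
    x_src = xs + flow[:, 0]
    y_src = ys + flow[:, 1]
    # grid_sample wants (x, y) coordinates normalized to [-1, 1].
    grid = torch.stack((2 * x_src / (w - 1) - 1, 2 * y_src / (h - 1) - 1), dim=-1)
    return F.grid_sample(image, grid, mode="bilinear",
                         padding_mode="border", align_corners=True)


def photometric_step(flow_net, optimizer, im1, im2):
    """One unsupervised optical-flow training step."""
    # Steps 1-2: feed the pair to the model and predict a per-pixel flow.
    flow = flow_net(torch.cat([im1, im2], dim=1))
    # Steps 3-4: warp the first image and compare it with the second.
    loss = F.l1_loss(warp(im1, flow), im2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


flow_net = nn.Conv2d(6, 2, kernel_size=3, padding=1)  # hypothetical flow model
opt = torch.optim.SGD(flow_net.parameters(), lr=1e-2)
step_loss = photometric_step(flow_net, opt, torch.rand(2, 3, 32, 32), torch.rand(2, 3, 32, 32))
```

With align_corners=True, a zero flow reproduces the input and a flow of +1 in x shifts the image one pixel to the left, which is an easy way to check which direction your flow convention points.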


Batch Normalization is a Hack

Batch normalization has a simple goal: stabilize the gradients of large computational graphs. In doing so, this technique has enabled the deep learning renaissance that almost every major ML breakthrough in the last 5 years has relied on.

The concept is sound: by normalizing the mean and variance of the inputs to nearly every layer in a neural network, the gradients of that network rarely explode during the backward pass. The end result is that many neural networks can be trained easily with gradient techniques that would otherwise never have converged.

So why am I calling it a hack? Let’s dig in.

It violates batch invariance

When training a neural network with SGD, we often use mini-batches. Mini-batches allow a computational graph to average the gradients of a forward pass across a greater portion of the input space. They also allow us to easily scale training to fit the computational resources available to us.

There is an important caveat to mini-batches: the choice of which inputs are included in a batch is supposed to be totally random. As a result, there is no structure along the batch dimension of a mini-batched input tensor. Therefore, it is always wrong to compute across this dimension, which is precisely what batch normalization does.
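A quick way to see the batch-invariance problem: in training mode, a BatchNorm2d layer’s output for one sample changes when you swap out the other samples in its batch, while a per-sample normalization like GroupNorm does not. The shapes below are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_a = torch.randn(4, 8, 5, 5)
batch_b = batch_a.clone()
batch_b[1:] = torch.randn(3, 8, 5, 5)  # same first sample, different batch-mates

bn = nn.BatchNorm2d(8).train()  # training mode: normalizes with batch statistics
gn = nn.GroupNorm(num_groups=2, num_channels=8)  # statistics come from each sample alone


def first_sample_changes(layer):
    """True if the first sample's output depends on the rest of the batch."""
    with torch.no_grad():
        return not torch.allclose(layer(batch_a)[0], layer(batch_b)[0])


print(first_sample_changes(bn))  # True
print(first_sample_changes(gn))  # False
```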

To give a grounded example of why this is bad, consider some work on GANs I have been doing recently. The discriminator of this GAN is a VGG-style network that used batch normalization. In conventional GAN style, the discriminator discerns whether or not the inputs it is given are real or fake. Since fake inputs must backprop through the generator, they are always batched together. To provide symmetry, real inputs are also batched together.

While analyzing the performance of the GAN, I noticed that the discriminator was always predicting a near-uniform value across the batch dimension. E.g. if one image was predicted as fake, all of the images in the minibatch would be predicted fake and vice versa. This struck me as odd: shouldn’t the discriminator be making predictions per-image and not per-batch?

It turns out this odd behavior was caused by batch normalization. Switching all of the VGG blocks in the discriminator to group norm caused the issue to go away and improved the generator’s inception score in validation by 5%!
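For reference, a swap like that can be done with a small recursive helper along these lines. The group count of 8 and the toy discriminator are illustrative assumptions, not the exact configuration from the experiment above. The new GroupNorm layers do not inherit any learned BatchNorm parameters, so the network needs to be trained after the swap.

```python
import torch
import torch.nn as nn


def bn_to_gn(module, num_groups=8):
    """Recursively replace every BatchNorm2d in `module` with a GroupNorm."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # GroupNorm requires num_channels to be divisible by num_groups;
            # fall back to a single group (LayerNorm-like) otherwise.
            groups = num_groups if child.num_features % num_groups == 0 else 1
            setattr(module, name, nn.GroupNorm(groups, child.num_features, affine=child.affine))
        else:
            bn_to_gn(child, num_groups)
    return module


# A toy VGG-style discriminator (hypothetical, for illustration only).
disc = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2),
    nn.Conv2d(64, 64, 4, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, 1),
)
bn_to_gn(disc)
```

After the swap, each image’s score depends only on that image, so the discriminator can no longer lean on batch-wide statistics to tell real batches from fake ones.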

I’m not the only one to experience these issues with batch normalization in GANs. In this paper, Xiang et al. took an extensive look at batch normalization in generators and found evidence that it degrades the end result.

I think it’d be pretty silly to assume that this “bad behavior” is confined to GANs. While I can’t point to any specific evidence, I wouldn’t be surprised to see research showing that batch norm hurts some aspect of the performance of most neural networks, and I’d bet that violating batch invariance is the root cause.

Scaling? Good luck.

Since batch normalization operates across the batch axis, it doesn’t play nicely with data parallelism. While the major ML frameworks have workarounds for this problem (synchronized batch norm, for example), these workarounds bottleneck the network with expensive all-reduce operations scattered throughout the graph.

This isn’t really an issue with small networks because you can normalize across the batch dimension separately on each compute unit, but once your network grows to the point that you can no longer fit ~8 samples per compute unit, the normalization really starts to break down. The result is degraded network performance or slower training.

Things get more dire when you don’t have access to a lot of parallel compute. The best way to train large networks in this situation is via gradient accumulation, but this doesn’t play well with batch norm and there are currently no good workarounds pre-implemented in the big frameworks.

Lazy Weight Initialization

The authors of the batch normalization paper argued that it reduces the effect of gradient explosions caused by improper weight initialization.

This seems to be the case, as evidenced by its widespread use in deep residual networks, which have issues with exploding means. In their work titled “Fixup Initialization”, Zhang et al dig into whether or not it is possible to train deep residual networks without batch normalization. They accomplish their goal primarily by deriving a new weight initialization algorithm for use specifically in residual networks. Using Fixup, the authors were able to achieve near-BN performance without any normalization.
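To give a flavor of what Fixup looks like, here is a sketch of a basic two-convolution residual block following the recipe from the Fixup paper as I understand it: He-initialize the branch’s first layer and rescale it by L^(-1/(2m-2)), where L is the number of residual branches and m the number of layers per branch; zero-initialize the last layer of each branch; and add scalar biases plus a scalar multiplier. The paper also zero-initializes the final classification layer, which this sketch leaves out. The class name is my own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FixupBasicBlock(nn.Module):
    """Residual block with no normalization, initialized following the Fixup recipe."""

    def __init__(self, channels, num_blocks):
        super().__init__()
        # Scalar biases before each conv/activation and a scalar multiplier on the branch.
        self.bias1a = nn.Parameter(torch.zeros(1))
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bias1b = nn.Parameter(torch.zeros(1))
        self.bias2a = nn.Parameter(torch.zeros(1))
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias2b = nn.Parameter(torch.zeros(1))
        # He init rescaled by L^(-1/(2m-2)); with m = 2 layers per branch that is L^(-1/2).
        nn.init.kaiming_normal_(self.conv1.weight, nonlinearity="relu")
        with torch.no_grad():
            self.conv1.weight.mul_(num_blocks ** -0.5)
        # The last layer of each residual branch starts at zero.
        nn.init.zeros_(self.conv2.weight)

    def forward(self, x):
        out = F.relu(self.conv1(x + self.bias1a) + self.bias1b)
        out = self.conv2(out + self.bias2a) * self.scale + self.bias2b
        return F.relu(x + out)


blocks = nn.Sequential(*[FixupBasicBlock(16, num_blocks=8) for _ in range(8)])
```

Because each branch starts out contributing nothing, a freshly initialized stack of these blocks passes non-negative inputs through unchanged; the scaled first layer then keeps the branches from blowing up the activations once training moves the zeroed layers.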

The interesting conclusion that can be drawn from this paper is that normalization is (at least in some cases) less a necessity than a way to cover up the need for proper weight initialization schemes. It is not hard to analyze a residual network without BN or Fixup and see that the means and variances quickly get out of whack as network depth increases.
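Here is roughly what that analysis can look like: stack plain residual blocks (x + conv(relu(conv(x)))) with standard He initialization and no normalization, and print the activation standard deviation as depth grows. The channel count, depth and input size are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
channels, depth = 16, 16


def he_conv():
    # Standard He (Kaiming) initialization with zero bias.
    conv = nn.Conv2d(channels, channels, 3, padding=1)
    nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")
    nn.init.zeros_(conv.bias)
    return conv


# Plain residual branches, used as x + branch(x), with no normalization.
branches = [nn.Sequential(he_conv(), nn.ReLU(), he_conv()) for _ in range(depth)]

x = torch.randn(8, channels, 16, 16)
stds = []
with torch.no_grad():
    for branch in branches:
        x = x + branch(x)
        stds.append(x.std().item())

for d in (1, 4, 8, 16):
    print(f"after block {d:2d}: activation std = {stds[d - 1]:.1f}")
```

With He-initialized branches, each block adds roughly twice the incoming variance in expectation, so the spread compounds quickly with depth. That compounding is what Fixup’s rescaling is designed to cancel.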

All this is to say that while it might be appropriate to use normalization in the experimental phase of research, we should hold the ML community to a higher standard when we start using models in production: these models should be closely analyzed for possible improvements via weight initialization. At the very least, weights should be initialized so that the networks can be trained without normalization, even if normalization is still used for performance reasons. I would love to see work similar to Fixup done for other ML focus areas like NLP.

So it’s a hack.

I think batch normalization is far overused in current models. The situation stems from a lack of understanding combined with a lack of good research into alternatives. So much research builds upon existing SOTA models that it becomes difficult to reconsider such a fundamental part of the model, even when it has so many drawbacks.

The next time you are building a model, I urge you to think twice about using batch normalization. For image researchers: try out group normalization, I bet you will be surprised by the results. Other normalization schemes like instance norm can work well too. You might also consider developing your own normalization – it actually isn’t too hard. Often all you need to do is find an input dimension with wide variance and perform normalization across it.
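As an example of rolling your own, here is a sketch of a normalization layer that standardizes over whichever dimensions you choose and then applies a learned per-channel scale and shift. DimNorm is a hypothetical name. Keeping dimension 0 out of dims keeps it batch-invariant; dims=(2, 3) behaves like instance norm and dims=(1, 2, 3) like layer norm over C, H and W.

```python
import torch
import torch.nn as nn


class DimNorm(nn.Module):
    """Hypothetical layer: standardize over `dims`, then apply a per-channel affine."""

    def __init__(self, num_channels, dims, eps=1e-5):
        super().__init__()
        assert 0 not in dims, "normalizing over the batch dim breaks batch invariance"
        self.dims, self.eps = tuple(dims), eps
        self.weight = nn.Parameter(torch.ones(1, num_channels, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1))

    def forward(self, x):
        mean = x.mean(dim=self.dims, keepdim=True)
        var = x.var(dim=self.dims, keepdim=True, unbiased=False)
        return (x - mean) / torch.sqrt(var + self.eps) * self.weight + self.bias


x = torch.randn(4, 8, 6, 6) * 5 + 3
per_channel = DimNorm(8, dims=(2, 3))(x)    # instance-norm style
per_sample = DimNorm(8, dims=(1, 2, 3))(x)  # layer-norm style
```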

If you are putting a model into production, you should be doing some investigations into weight initialization and trainable scaling. A well-designed model should not need normalization to train properly, and you likely have performance to gain if it does.

Super Resolution

Diving into Super Resolution

After finishing my last project, I wanted to understand generative networks a bit better. In particular, GANs interest me because there doesn’t seem to be much research on them going on in the language modeling space.

To build up my GAN chops, I decided to try to figure out image repair and super-resolution. My reasoning was actually pretty simple: I have a large collection of old VHS-quality Good Eats episodes that I enjoy watching with my family. Modern flat screens really bring out how inadequate the visual quality of these old videos is, however. Wouldn’t it be great if I could use machine learning to “fix” these videos to provide a better experience for myself and my family? How hard could it be?

Turns out, really hard.

State of SISR

SISR stands for single-image super-resolution. It is the most basic form of super-resolution and has been around for decades. It is appealing because it is extremely easy to collect data for it: just find a source of high-quality images, downsample them and train a model to reverse that operation.
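Generating SISR training data really is that simple. Here is a minimal PyTorch sketch, assuming NCHW float images in [0, 1] and plain bicubic downsampling; real pipelines often add blur, noise or compression artifacts to the LQ side.

```python
import torch
import torch.nn.functional as F


def make_sr_pair(hq, scale=4):
    """Return an (LQ, HQ) training pair from an HQ batch of shape (B, C, H, W)."""
    _, _, h, w = hq.shape
    # Trim so both sides divide evenly by the scale factor.
    hq = hq[:, :, : h - h % scale, : w - w % scale]
    lq = F.interpolate(hq, size=(hq.shape[-2] // scale, hq.shape[-1] // scale),
                       mode="bicubic", align_corners=False)
    # Bicubic can overshoot slightly, so clamp back to the valid range.
    return lq.clamp(0, 1), hq


lq, hq = make_sr_pair(torch.rand(2, 3, 130, 97), scale=4)
print(lq.shape, hq.shape)  # torch.Size([2, 3, 32, 24]) torch.Size([2, 3, 128, 96])
```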

SISR has gone through the usual trends of data science. Methods run the spectrum from simple mathematical upsampling to PSNR-trained convolutional neural networks to GAN approaches. I decided to start with the latter, specifically a technique that Wang et al. call “ESRGAN”.

This choice was driven primarily by the existence of the excellent ESRGAN GitHub project. This code is well designed and documented, and it has been a pleasure to build on top of.

Although my goal is eventually video super-sampling, my initial investigation into the field showed that video SR is just a subset of image SR (big shocker!). Therefore, I decided to start by really understanding SISR.

Challenges of Super Resolution (and image generation)

Training a deep GAN on image super-resolution is a hardware-challenged problem. I plan to dive into this a bit more in a future article, but TL;DR: these models benefit from training on large images, but large images consume utterly insane amounts of GPU memory during the training passes. Thus, we are forced to train on small snippets of the images. When you take small snippets, you lose context that the model would otherwise use to make better SR “decisions”.

This is coupled with the fact that convolutional networks are typically parameter-poor. Put another way: they can be hard to train because the models just don’t have the capacity and structure to generalize to the enormous variety found in the world of images.

The result of this is often hidden away by research papers. They present only the best results of highly-specialized networks that can do one thing very well, but absolutely fail on anything else. The famous StyleGAN, for example, can only produce one type of image (and one subset of those images to boot). Edge cases produce atrocious results.

Super-resolution does not have the luxury of specialization. An effective SR model must adapt to a wide variety of image contents. Even though you can restrict the domain of the images you are upsampling (for example, Good Eats frames in my case), the variety will still be staggering.

The ESRGAN authors wisely worked around this problem by specifically designing their model to recognize and reconstruct image textures. This can produce great results for the majority of an image, but begins to fall apart when you attempt to super-resolve high-frequency parts of an image – like hair or eyes that have no detail in the LR image.

Super Resolution for Pre-trained Image Models

One facet of SR that is particularly interesting to me is the possibility that, as a technique, it might be used to train models on image understanding. Large NLP models are largely trained on next token prediction, and you can consider SR to be the image-analog to this task.

I can’t shake the feeling that natural image understanding is fundamentally limited by our current image processing techniques. I feel that the whole field is on the cusp of a breakthrough, and SR might very well be the basis of that breakthrough.

Of course, there’s a caveat: images are insanely complex. The adage “a picture is worth a thousand words” comes to mind here. If effective NLP models require billions of parameters – how many parameters are required for true image understanding?

Going Forwards

I started my deep dive into SISR just as the COVID pandemic began to take off in North America in 2020. I’m writing this a little more than 3 months in, and I feel that I’ve learned a lot in the process.

You’re probably wondering what the point of this article is. It’s an introduction to a series of articles on musings, findings, and explorations in the world of SR. For such an obvious field of ML application, SR doesn’t have a whole lot of documentation. My hope is that the things I’ve learned can be useful to others exploring the field. Stay tuned!


Fine-tuning XLNet For Generation Tasks

About a month ago, I decided to take the plunge into learning how to fine-tune a language generation model. One use case of language generation that I found particularly compelling was abstractive document summarization. Many of the currently available papers on abstractive summarization with transformers work by truncating the input text to the maximum sequence length of the model. In a post-Transformer-XL world, I thought it’d be neat to fix that limitation.

XLNet and Transformer-XL are the two recurrent language models currently available in the Transformers NLP library. “Recurrent” in this context means that they were designed to model extremely long sequences by breaking those sequences into chunks and processing them one at a time. The chunks are then tied together via a “memory” that is passed from one forward pass to the next.

XLNet is particularly interesting for language generation because it is pre-trained in an autoregressive manner similar to the GPT family of models. This holds the promise of more coherent text output than what you would typically find with MLM models like BERT.

With all this in mind, I decided to try to fine-tune a model that could do abstractive summarization over an arbitrarily long corpus. I’ve been working on cracking this egg for about a month now, and have seen some success. In this post, I’m going to show you what I did. Along the way, you’ll learn about how XLNet works from a technical perspective and how you can use it for your NLP tasks.

I believe what is covered here could be used for just about any generative task that you have supervised data for. More importantly – you don’t even need to have much of that data. In my tests, the model converges after ~10,000 samples.

Who’s This Guide For?

I’m going to try my best not to go too far into the weeds in this post. I expect that you have a fairly strong knowledge of Python and PyTorch. You should generally know how language models work as well. This includes reading up on tokenizers, word embeddings, transformers, and BERT. I will go into a bit more detail about segment recurrence and XLNet’s flavor of autoregressive modeling, but you should still probably peruse the Transformer XL and XLNet papers if you want to get the most out of this post. You will also find the documentation for the Transformers library from Hugging Face useful, because a lot of this code is built on top of that library.

Worth noting: a data scientist I am not. This is not a paper, and I don’t mean for it to be one. It’s a practical guide on how to make XLNet produce real results for an interesting problem. I will not compare my work to prior art. Some unfounded conjecture will be made. I probably contort mathematics and terminology in awful, ugly ways. There are likely other problems. I apologize in advance for all this, and I welcome comments or pull requests.

The Plan

Before we dive into implementation details, let’s sketch out a rough design of what we’re going to build.

We have a pre-trained language model like XLNet, thanks to our friends at Hugging Face. We need a dataset. For the purposes of this article, that dataset is going to be a list of BBC news articles and their summaries from the XSum dataset. Click here to grab a zip file filled with XSum articles. We’re going to try to teach the model to generate a summary of an article given its text.

Each news article has text that spans thousands of tokens in length. Unfortunately, as you may have read – XLNet is a real monster of a model. A GPU with 16GB of VRAM can only handle sequence lengths of 256 tokens and a paltry batch size of 4 on the base XLNet model!

Fortunately, recurrence comes to the rescue. Even though we’re limited to 256 tokens, XLNet can model far longer sequences using a built-in memory mechanism. To exploit this, we are going to split up the text from our articles into chunks. Each chunk will be exactly 256 tokens. 
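A minimal chunking sketch, assuming the article has already been tokenized into a 1-D LongTensor. pad_id is a placeholder: with the Transformers tokenizer you would pass the tokenizer’s pad_token_id. This sketch only pads the final chunk; it does not build attention masks for the padding.

```python
import math

import torch


def split_into_chunks(tokens, chunk_len=256, pad_id=0):
    """Split a 1-D LongTensor of token ids into chunk_len-sized pieces, padding the last."""
    n_chunks = max(1, math.ceil(len(tokens) / chunk_len))
    padded = torch.full((n_chunks * chunk_len,), pad_id, dtype=torch.long)
    padded[: len(tokens)] = tokens
    return list(padded.view(n_chunks, chunk_len))


chunks = split_into_chunks(torch.arange(1000), chunk_len=256, pad_id=-1)
print(len(chunks), chunks[-1][-1].item())  # 4 -1
```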

To fine-tune the model, we will feed the chunks into the model one forward pass at a time. After each forward pass, we will collect the memory outputs from the model. We will then feed the next chunk into the model with the memory from the last pass. We will repeat this process until we get to the last chunk. This last chunk will have our “target” text appended to it, along with some masks that force the model to predict that target text during the forward pass. On this last pass, we will record the gradients and compute a loss against the target text. We will then perform gradient descent against the loss and (hopefully) end up with a better model of our objective.
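Below is a framework-level sketch of that loop. To keep it self-contained, it uses a hypothetical model interface, model(input_ids, mems) -> (logits, new_mems), and a toy ToyMemLM stand-in rather than the real XLNet API. The loss is computed only on the target positions of the last chunk (labels are -100 elsewhere). Unlike XLNet with a permutation mask, the toy model can see the target tokens it is asked to predict, so it is only useful for checking the plumbing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyMemLM(nn.Module):
    """Hypothetical stand-in for XLNet: model(input_ids, mems) -> (logits, new_mems)."""

    def __init__(self, vocab_size, dim=16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.out = nn.Linear(dim, vocab_size)

    def forward(self, input_ids, mems=None):
        h = self.emb(input_ids)
        if mems is not None:
            h = h + mems.mean(dim=1, keepdim=True)  # crude use of the memory
        return self.out(h), h.detach()  # memory is passed on without gradients


def train_on_batch(model, optimizer, chunks, labels):
    """chunks: list of (B, chunk_len) id tensors; the last one ends with the target text.
    labels: (B, chunk_len) ids for the last chunk, -100 wherever there is no target."""
    mems = None
    # Earlier chunks only build up memory; no gradients are recorded for them.
    with torch.no_grad():
        for chunk in chunks[:-1]:
            _, mems = model(chunk, mems)
    # Only the final chunk (article tail + target) contributes to the loss.
    logits, _ = model(chunks[-1], mems)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1),
                           ignore_index=-100)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = ToyMemLM(vocab_size=20)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
chunks = [torch.randint(0, 20, (2, 8)) for _ in range(3)]
labels = torch.full((2, 8), -100, dtype=torch.long)
labels[:, -3:] = chunks[-1][:, -3:]  # pretend the last 3 tokens are the target
losses = [train_on_batch(model, opt, chunks, labels) for _ in range(50)]
```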

Diving Into XLNet

Let’s talk about the “masks” I mentioned. XLNet has a ton of inputs for the forward pass. The truly unique ones are perm_mask and target_mapping.

perm_mask is what separates XLNet from all the other models out there. Instead of corrupting the inputs with <mask> tokens like BERT, XLNet is trained by predicting each output given a limited input context. Put another way, every output token can only “pay attention” to a limited number of input tokens. Which tokens it can attend to are governed by the permutation mask. That mask is a square tensor of shape sequence_length x sequence_length. The first dimension selects an output token. The second dimension is set to 1 for tokens that cannot be used to predict that output token, and 0 for tokens that can be used.

If you’re familiar with the way the GPT model was trained, this is very similar. The difference is that XLNet is not constrained to predicting a sequence from left to right; it is trained to do so in any order. Interestingly, you can replicate left-to-right autoregressive predictions like the ones GPT makes by building a “stepped” permutation mask, which is exactly what we will use to train the model to predict our target text. I’ll dig into this in a bit more detail when we get to the code below, so don’t worry if you don’t completely get it yet.

target_mapping is another mask that allows you to tell the XLNet implementation how to map decoded tokens to labels so that a loss can be computed. For our purposes, this is going to be an eye tensor across the target text. This allows us to make the labels just the target text.

The transformers documentation is missing some details on how XLNet actually works. The model will internally activate one of two “attention modes” depending on the input parameters you feed it. The first mode activates if you do not specify target_mapping or perm_mask. In this mode, called “standard self-attention”, XLNet behaves similarly to BERT and other MLMs. The second mode is “two-stream self-attention”, which enables the permutation language modeling that makes XLNet special. This was not clear to me when I started working on this project, and I feel it is worth mentioning in case anyone else gets confused.

Alright, back to the plan. We’re going to use the permutation mask to force the model to predict the target sequence one word at a time – all in a single forward pass! To do this, the permutation mask needs to do something like this to the input:


  |<--------------- ARTICLE TEXT ------------------------> | <--------- SUMMARY -----------> |
A brown cat was stranded in a tree on Friday the 13th. <sep> A cat got stuck in a tree <eos>


TARGET | INPUT (* = permutation masked)
A      | A brown cat was stranded in a tree on Friday the 13th. <sep> * * * * * * * *
cat    | A brown cat was stranded in a tree on Friday the 13th. <sep> A * * * * * * *
got    | A brown cat was stranded in a tree on Friday the 13th. <sep> A cat * * * * * *
stuck  | A brown cat was stranded in a tree on Friday the 13th. <sep> A cat got * * * * *
in     | A brown cat was stranded in a tree on Friday the 13th. <sep> A cat got stuck * * * *
a      | A brown cat was stranded in a tree on Friday the 13th. <sep> A cat got stuck in * * *
tree   | A brown cat was stranded in a tree on Friday the 13th. <sep> A cat got stuck in a * *
<eos>  | A brown cat was stranded in a tree on Friday the 13th. <sep> A cat got stuck in a tree *

As you can see, every target word can only pay attention to target words before it. This is exactly how the model is going to work when we go to generate text – which should give us some great results!
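Here is one way to build that stepped permutation mask and the matching target mapping in PyTorch. It is a sketch rather than the exact code in ChunkedTextDataset: it assumes the target occupies the last target_len positions of the chunk and uses the Transformers convention that perm_mask[i, j] == 1 means token i cannot attend to token j. It also hides the target from the article tokens, so the summary cannot leak into the context through their hidden states.

```python
import torch


def build_xlnet_masks(seq_len, target_len):
    """Stepped perm_mask and target_mapping for a chunk whose last target_len tokens are the target.

    perm_mask[i, j] == 1 means position i may not attend to position j.
    """
    start = seq_len - target_len
    i = torch.arange(seq_len).unsqueeze(1)  # attending position (rows)
    j = torch.arange(seq_len).unsqueeze(0)  # attended position (columns)
    # Hide target tokens at or after the attending position. For article rows
    # (i < start) this hides the whole target; for target rows it hides the
    # current and all later target tokens.
    perm_mask = ((j >= start) & (j >= i)).float()
    # One prediction per target position: an identity matrix over the target span.
    target_mapping = torch.zeros(target_len, seq_len)
    target_mapping[:, start:] = torch.eye(target_len)
    return perm_mask, target_mapping


perm_mask, target_mapping = build_xlnet_masks(seq_len=14, target_len=8)
# Transformers expects a leading batch dimension on both masks, e.g.
# model(input_ids, perm_mask=perm_mask[None], target_mapping=target_mapping[None],
#       labels=target_ids[None]), with labels of shape (batch, num_predict).
```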

Pre-processing the Dataset

In my recent work with language models, I’ve taken to processing my datasets in two stages.

In the first stage, I take raw data (that’d be a CSV/JSON/XML file) and extract the relevant features (for example, the text and a classifier label). I then feed the text through the appropriate tokenizer and save the results as a PyTorch pickle (e.g. with torch.save).

In the second stage, the data is loaded into a Pytorch Dataset so that it can be batched, randomized, and served to the model.

There are a few reasons I separate these stages:

  1. It adds a layer of abstraction between the raw data and the logic that loads data into the model, which allows me to use multiple datasets with the same trainer programs.
  2. My training programs load much more quickly when they can operate on the preprocessed data produced by stage 1.
  3. I’ve been working in Windows a lot lately and PyTorch doesn’t support dataloader workers in Windows. This means failing to separate the stages would significantly slow down training.

Stage 1

For the purposes of this post, I’m going to assume you will write your own stage 1 data processor. Here is an example of a stage 1 data processor that I wrote to process the XSum dataset. This is generally how that processor works:

import torch
from transformers import XLNetTokenizer

tok = XLNetTokenizer.from_pretrained("xlnet-base-cased")
output = []
# `datas` holds the raw XSum records: dicts with "text" and "summary" keys.
for data in datas:
    # Tokenize article and summary without special tokens; the stage 2
    # dataset appends the target and builds the masks.
    text_enc = tok.encode(data["text"], add_special_tokens=False)
    title_enc = tok.encode(data["summary"], add_special_tokens=False)
    output.append({
        "text": torch.tensor(text_enc, dtype=torch.long),
        "target": torch.tensor(title_enc, dtype=torch.long),
    })
# Save everything as a PyTorch pickle.
torch.save(output, "")

Stage 2

Stage 2 processing is where all the magic happens. We need a Pytorch Dataset class that can produce batches of “chunked” inputs. These chunks must then be sequentially fed into the model alongside the memory. After all the chunks in a batch are consumed, the memory is reset and a new batch of chunks is retrieved.

This is a challenging problem because the total number of chunks for any given text is not fixed. For example, one article might consist of 1000 tokens – or four 256-token chunks; another might be made up of 1500 tokens which would use six chunks. In order to do batched training, we need to randomly select batches of articles that use four chunks, then six chunks, etc.
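A sketch of the bucketing idea: group article indices by how many chunks they need, then shuffle within and across buckets so every batch is homogeneous. bucket_batches is a hypothetical helper, not the ChunkedTextDataset implementation, and lengths should count whatever tokens actually get chunked (e.g. article plus appended target).

```python
import math
import random
from collections import defaultdict


def bucket_batches(lengths, chunk_len=256, batch_size=4, seed=0):
    """Group item indices into batches whose members all need the same number of chunks."""
    buckets = defaultdict(list)
    for idx, n_tokens in enumerate(lengths):
        buckets[math.ceil(n_tokens / chunk_len)].append(idx)
    rng = random.Random(seed)
    batches = []
    for idxs in buckets.values():
        rng.shuffle(idxs)
        batches.extend(idxs[i:i + batch_size] for i in range(0, len(idxs), batch_size))
    # Shuffle across buckets so chunk counts are visited in random order.
    rng.shuffle(batches)
    return batches


lengths = [1000, 1500, 1000, 1500, 300, 1020]
batches = bucket_batches(lengths, chunk_len=256, batch_size=2)
```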

To accomplish this, I’ve written the chunked text dataloader. It implements both a Dataset and a Dataloader that consume a stage 1 processed file and serve batches of chunked inputs, with every input in a batch having the same number of chunks.

To enable the autoregressive language modeling of XLNet, this dataset also automatically appends the target text onto the end of input text and produces a permutation mask and target mapping for each chunk such that we achieve the masking I outlined in the example above.

I’m not going to dig into this code too much, since I feel I’ve done a good job documenting it. Check out the source code itself for more details. Just know when you see “ChunkedTextDataset” referenced below, this is where it is coming from.

Note that the chunked_text_dataloader script can be executed directly. Doing so produces a really neat view of its outputs, including a full permutation mask tree like the one I outlined above.

Fine-tuning XLNet

The actual training code looks very similar to the example code from the Transformers repo. There are a few differences, though. Let’s walk through it.

We’ll start in the __main__ section. First, we process command line arguments, then we load the datasets:

# chunked_model_config is a dictionary initialized from command line arguments. Most fields
# are self-explanatory or can be inferred from the ChunkedTextDataset docs.
# Get the datasets
train_set = ChunkedTextDataset(
    os.path.join(input_folder, ""),
val_set = ChunkedTextDataset(
    os.path.join(input_folder, ""),
train_loader = train_set.get_dataloader(batch_size, num_workers=0)
val_loader = val_

Next, we load an XLNetConfig and an XLNetLMHeadModel from the Transformers pre-trained model database. This is where an important deviation takes place: you need to tell the model to output its mems by setting “mem_len” in the config:

config = transformers.XLNetConfig.from_pretrained(
config.mem_len = chunked_model_config["mem_len"]
# config.num_labels = 1 # Un-comment this if you want 
model = transformers.XLNetLMHeadModel.from_pretrained(

Configuring the optimizer and scheduler is a direct copy-and-paste from the Transformers example, except for some provisions that enable what I call aggregate_batch_size. Since I can only train batches of size 3 at a time on my GPU, I wanted to combine several mini-batches to get some of the benefits of larger-batch training.

no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0,
    },
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=start_lr, eps=1e-8)
scheduler = transformers.get_linear_schedule_with_warmup(
    num_training_steps=epochs * len(train_set) / aggregate_batch_size,
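The aggregate_batch_size bookkeeping boils down to two numbers: how many mini-batches are processed per optimizer step, and how many optimizer steps the scheduler should expect. Here is a small sketch of that arithmetic in plain Python; aggregate_batch_schedule is a hypothetical helper, and the 36,000-example dataset size is made up for illustration.

```python
def aggregate_batch_schedule(num_examples, epochs, batch_sz, aggregate_batch_sz):
    """Return (mini-batches per optimizer step, total optimizer steps).

    The trainer steps the optimizer once every aggregate_batch_sz // batch_sz
    mini-batches, and the scheduler expects
    epochs * num_examples // aggregate_batch_sz optimizer steps.
    """
    if aggregate_batch_sz % batch_sz != 0:
        # int(desired / batch) would silently truncate, making the real
        # aggregate batch smaller than requested.
        raise ValueError("aggregate_batch_sz should be a multiple of batch_sz")
    batches_per_step = aggregate_batch_sz // batch_sz
    total_optimizer_steps = epochs * num_examples // aggregate_batch_sz
    return batches_per_step, total_optimizer_steps


# batch_sz=3 and aggregate_batch_sz=36 match the training command used in this post.
print(aggregate_batch_schedule(36000, epochs=1, batch_sz=3, aggregate_batch_sz=36))
# -> (12, 1000)
```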

Now we’re ready to run the training loop:

trainer = Trainer()  # constructor arguments elided in the original
for _ in range(epochs):
    trainer.loop()

Let’s look at how the Trainer class’s loop() method works. This method handles both validation and training, so you’ll see some split logic here. It starts off with some boilerplate:

import time

import torch
from tqdm import tqdm


def loop(self, _validate=False, _skip_batches=1):
    # How many steps per logging event.
    _logging_steps = 5
    # How many steps between checkpoint save and validation.
    _steps_till_save = 2000
    _steps_till_validate = 2000

    _dataloader = self.val_dataloader if _validate else self.train_dataloader
    _epoch_iterator = tqdm(
        _dataloader, desc="Val Iteration" if _validate else "Train Iteration")

    # Total batches processed and number of times optimizer.step() has been called.
    _steps, _optimizer_steps = 0, 0
    # Aggregate losses.
    _loss_sum, _logging_loss = 0, 0
    _chunks = 0

    # This controls how many batches are required per optimizer step.
    _batches_required_for_desired_sz = int(
        self.desired_batch_sz / self.chunked_model_config["batch_size"])

    if _validate:
        ...  # validation-specific setup elided in the original

Next we’ll start iterating through the batches. Let’s do a quick recap of what we’ll be getting inside each batch. The batch returned from our dataloader is a dictionary. Each value in the dictionary (with the exception of “labels”) is a list of tensors, and all of those lists have the same length. The lists are meant to be zipped together; the collection of tensors you get by fetching one zipped entry is called a “chunk”. Here’s a pseudocode sketch of what a batch would look like if you printed it out:

{
  "input_ids": [<batched_tensor1>, <batched_tensor2>, ..],
  "attention_masks": [<batched_tensor1>, <batched_tensor2>, ..],
  "permutation_masks": [..],
  "target_mappings": [..],  # Only appears if force_max_len_gen=True
  "labels": <batched_tensor>
}

Most of these values line up with inputs that the XLNet model is expecting. “labels” is only a single tensor because they are only fed into the model along with the last chunk. This is because only the last chunk contains the target text to be predicted.
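To make the zipping concrete, here is a small sketch in plain Python. iter_chunks is a hypothetical helper, not part of ChunkedTextDataset, and strings stand in for the batched tensors. It yields one chunk per zipped entry and attaches labels only to the last one. It also tolerates a missing target_mappings key, since that entry only appears when force_max_len_gen=True.

```python
def iter_chunks(batch):
    """Yield (index, is_last_chunk, chunk) for each chunk in a batch.

    Every per-chunk key maps to a list with one entry per chunk; the single
    "labels" entry travels only with the last chunk.
    """
    per_chunk_keys = [
        k for k in ("input_ids", "attention_masks", "permutation_masks", "target_mappings")
        if k in batch  # target_mappings may be absent
    ]
    num_chunks = len(batch["input_ids"])
    for idx, values in enumerate(zip(*(batch[k] for k in per_chunk_keys))):
        chunk = dict(zip(per_chunk_keys, values))
        is_last = idx == num_chunks - 1
        if is_last:
            chunk["labels"] = batch["labels"]
        yield idx, is_last, chunk


demo_batch = {
    "input_ids": ["ids0", "ids1", "ids2"],
    "attention_masks": ["att0", "att1", "att2"],
    "permutation_masks": ["perm0", "perm1", "perm2"],
    "labels": "target_labels",
}
for idx, is_last, chunk in iter_chunks(demo_batch):
    print(idx, is_last, sorted(chunk))
```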

The general algorithm is to loop over all of the chunks in the batch. The mems output from passing each chunk into model.forward() is saved and used as an input into the next forward() pass. The last chunk of every batch also feeds the labels in and computes a loss. We use this loss to backprop and update the model weights. There is some logic to support the aggregate_batch_size that we discussed earlier.

for batch in _epoch_iterator:
    _steps += 1
    # Reset mems at the start of each batch; they only flow between chunks of the same batch.
    _mems = None
    _num_chunks = len(batch["input_ids"])
    _chunk_counter = 0

    for _masked_input_ids, _attention_masks, _perm_mask, _target_mapping in zip(
            batch["input_ids"], batch["attention_masks"],
            batch["permutation_masks"], batch["target_mappings"]):
        _is_last_chunk = _chunk_counter == (_num_chunks - 1)
        _chunk_counter += 1

        # Forward
        _inputs = {
            "input_ids": _masked_input_ids,
            "attention_mask": _attention_masks,
            "perm_mask": _perm_mask,
            "target_mapping": _target_mapping,
        }
        if _mems is not None:
            _inputs["mems"] = _mems

        # Only compute gradients on the last forward() per-chunkset.
        if _is_last_chunk:
            _inputs["labels"] = batch["labels"]
            _loss, _logits, _mems = self.model.forward(**_inputs)
        else:
            with torch.no_grad():
                _logits, _mems = self.model.forward(**_inputs)

        # Backwards
        # Only compute backwards on the last chunk per chunkset.
        if not _validate and _is_last_chunk:
            _backward_start = time.time()
            _loss.backward()
            backward_time = time.time() - _backward_start
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

    if not _validate and _steps % _batches_required_for_desired_sz == 0:
        # Update weights after all chunks have been processed at an interval to fulfill desired_batch_sz
        self.optimizer.step()
        self.scheduler.step()
        self.model.zero_grad()
        _optimizer_steps += 1

    # Always accumulate loss across the last chunk, where it should be lowest. That's the goal of this model.
    _loss_sum += _loss.item()
    _chunks += _num_chunks
    if not _validate and _steps % _logging_steps == 0:
        _loss_scalar = (_loss_sum - _logging_loss) / _logging_steps
        _logging_loss = _loss_sum
        _logs = {
            "avg_chunks": _chunks / _logging_steps,
            "loss": _loss_scalar,
            "learning_rate": self.scheduler.get_lr()[0],
            "optimizer_steps": _optimizer_steps,
        }
        _chunks = 0
        # Perform logging here.

    # The train loop automatically runs a validation loop and saves checkpoints.
    if not _validate:
        if _steps % _steps_till_save == 0:
            self.save_model("chkpt_%i" % (_steps))
        if _steps % _steps_till_validate == 0:
            self.loop(_validate=True, _skip_batches=10)

In the next section, I’ll take a look at how the model turned out.


I have been doing quite a few experiments with this set-up to see what kind of performance I can wring from it. There are some constants I’ve observed across all of these runs:

  • The initial loss drops quite quickly early on (after ~20 steps) and follows a far more gradual trendline from there on out. Here is a pretty typical loss graph:
  • Training for long periods of time seems to hurt the model rather than help it. The training and validation losses stay stagnant, but the generated text gradually loses its quality (quality in this case is just my own judgement, though I’m working on ways to make this more measurable). I believe what is happening is the model is losing the pre-trained language modeling capabilities in favor of being able to predict the idiosyncrasies of the target text.
  • I developed a scheme for using an “aggregate batch size”, where the gradients from individual batches are summed together to produce larger overall batches. This was necessary because my GPU could only fit batches of 3 on the final model. Using this mechanism, I found that an aggregate batch size of 30-60 seemed to work best.
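The aggregate batch size is what is usually called gradient accumulation. The toy sketch below, in plain Python with no PyTorch, shows why summing mini-batch gradients works: for a loss that is a sum over examples, adding up the gradients of 12 mini-batches of 3 gives exactly the gradient of one batch of 36. The one-parameter model and the data are hypothetical.

```python
import math


def grad_of_sum_sq_error(w, xs, ys):
    """Gradient w.r.t. w of sum((w * x - y) ** 2) over one mini-batch."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys))


def accumulated_grad(w, xs, ys, batch_sz):
    """Add up per-mini-batch gradients, the way repeated loss.backward()
    calls accumulate into .grad before a single optimizer.step()."""
    total = 0.0
    for start in range(0, len(xs), batch_sz):
        end = start + batch_sz
        total += grad_of_sum_sq_error(w, xs[start:end], ys[start:end])
    return total


# Hypothetical data: 36 examples, one parameter w.
xs = [float(i) for i in range(36)]
ys = [2.0 * x + 1.0 for x in xs]
w = 0.5

full_batch = grad_of_sum_sq_error(w, xs, ys)          # one batch of 36
aggregated = accumulated_grad(w, xs, ys, batch_sz=3)  # 12 mini-batches of 3
print(math.isclose(full_batch, aggregated))           # True
```

If each mini-batch loss is a mean rather than a sum (PyTorch’s CrossEntropyLoss averages by default), the accumulated gradient is instead the number of mini-batches times the average gradient, which is one reason the learning rate has to be tuned together with the aggregate batch size.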

Text Generator

Due to the particular way this model was trained, you cannot use the stock Transformers generation API with it at test time. The model expects an input sequence with a fixed space reserved for generation, but the generation API only feeds the model sequences with a single <mask> token on the end. Given only that one slot, this model will almost always put an <eos> token there, because it thinks it has no room to actually write a summary. The other problem is that the existing generation API doesn’t accept chunked inputs and cannot be fed mems to compensate.
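Here is one way to lay out that fixed generation space. build_prediction_inputs is a hypothetical helper, not the generation code from the fork: it appends a separator and num_predict mask slots to the context, then builds perm_mask and target_mapping following the conventions in the Transformers XLNet docs (perm_mask[i][j] = 1 means token i may not attend to token j; target_mapping[k][j] = 1 means the k-th prediction is made at position j). Plain nested lists stand in for tensors, the token ids are made up, and this is the simplest variant, in which no position can see any target slot.

```python
def build_prediction_inputs(context_ids, num_predict, mask_id, sep_id):
    """Return (input_ids, perm_mask, target_mapping) for one example.

    The context is followed by a separator and num_predict mask slots, so the
    model always sees a fixed amount of room in which to write its output.
    """
    input_ids = list(context_ids) + [sep_id] + [mask_id] * num_predict
    seq_len = len(input_ids)
    first_target = seq_len - num_predict

    # perm_mask[i][j] == 1: token i may NOT attend to token j.
    # Here no position is allowed to see any of the target slots.
    perm_mask = [
        [1 if j >= first_target else 0 for j in range(seq_len)]
        for _ in range(seq_len)
    ]
    # target_mapping[k][j] == 1: the k-th prediction is made at position j.
    target_mapping = [
        [1 if j == first_target + k else 0 for j in range(seq_len)]
        for k in range(num_predict)
    ]
    return input_ids, perm_mask, target_mapping


# Hypothetical token ids: three context tokens, sep_id=4, mask_id=6.
ids, perm_mask, target_mapping = build_prediction_inputs([10, 11, 12], 2, mask_id=6, sep_id=4)
print(ids)             # [10, 11, 12, 4, 6, 6]
print(target_mapping)  # [[0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1]]
```

In a real model call these would be tensors with a leading batch dimension: perm_mask of shape (batch, seq_len, seq_len) and target_mapping of shape (batch, num_predict, seq_len).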

To get around this, I hacked the Transformers library to work with my particular generation needs. Here is my fork. Here is the code that you can use to generate summaries.

From testing, I have found that sampling generally produces inferior results. The results I present below were generated with sampling turned off, num_beams=12 and repetition_penalty=5.
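For reference, here is a sketch of the CTRL-style repetition penalty that, as far as I know, Transformers’ generate() applied for repetition_penalty in its 2.x releases: every token that has already been generated has its logit divided by the penalty if positive and multiplied by it if negative. apply_repetition_penalty is a hypothetical helper operating on a plain list of logits; with a penalty of 5 the effect is strong.

```python
def apply_repetition_penalty(logits, previous_tokens, penalty):
    """CTRL-style repetition penalty for one row of next-token logits.

    Each token id that has already been generated is made less likely:
    a positive logit is divided by the penalty, a negative one multiplied
    by it. Returns a new list and leaves `logits` untouched.
    """
    out = list(logits)
    for tok in set(previous_tokens):  # each repeated token is penalised once
        if out[tok] < 0:
            out[tok] *= penalty
        else:
            out[tok] /= penalty
    return out


# Token 0 (positive logit) and token 1 (negative logit) were generated before.
print(apply_repetition_penalty([2.0, -1.0, 0.5], previous_tokens=[0, 1, 0], penalty=5))
# -> [0.4, -5.0, 0.5]
```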

Ablation Testing 

In this section, I want to answer two questions:

  • Did my fine-tuning scheme improve model performance at all, or is XLNet naturally good at producing text summaries?
  • Does my chunking algorithm allow XLNet to improve performance by being able to recall past sequences, or is it primarily generating summaries based on the current sequence?

To test this, I will post demonstrative summaries from 5 different approaches:

  1. “human”: The first is the human generated summary that came with the XSum article.
  2. “base”: Summaries generated using a baseline XLNet model with no fine-tuning.
  3. “candidate”: The best fine-tuned XLNet model I produced during my testing.
  4. “no-mems”: The same fine-tuned model from (3), where mems are not recursively fed to the final chunk (i.e. they are not used).
  5. “corrupt-mems”: The same fine-tuned model from (3), where mems are generated from a single, unrelated article and then applied to the same last chunk as the rest of these examples.

human:        "The US, South Korea and Japan have said they will be united in their response to North Korea's claim to have successfully tested a hydrogen bomb."
base:         "The North has denied any involvement in the test. However, South Korea and Japan have also expressed concern about Pyongyang's claims of a full thermonuclear explosion - \"a clear threat to international peace & security by its latest provocative act on Wednesday!\" Yonhap news agency reported that President Barack Obama had spoken separately with Mr Park over whether he supported further sanctions against his country after it carried out an"
candidate:    "North Korea says it carried out a nuclear test on Wednesday."
no-mems:      "The US and neighbouring countries are thought to be carrying out atmospheric sampling, hoping for leaked radioactive material which would give clues asto what kind of device was tested. South Korea is suspected North Korean testing a "boosted" atomic bomb that uses some fusion fuel to increase the yield on its nuclear reactors; officials have said Pyongyang has claimed a successful hydrogen bombs test whereas Seoul says it claims"
corrupt-mems: "Partick Thistle (Partick Heretle) won a free kick in the second half of their first-half win over Motherwell, but they were not able to convert an attempt from outside his own box. Danny Devine has been shown yellow cards for fouling Liam Lindsay on Saturday night when he was sent off after conceding two goals as part way through play at home against Dundee United and"

human:        "Northampton hooker Dylan Hartley has been confirmed as England's captain for the Six Nations - two days after his six-week suspension for striking ended."
base:         "The Australian was asked about his suitability for the role, Jones interrupted: \"I haven't actually told him. It is the continuity of this job that we can use as best I can to get our preparation right before the first game against France on 4 February at Twickenham in England and then we have a couple more days off! We are pleased with James Haskell back"
candidate:    "England hooker Dylan Hartley has been given the all-clear to return as captain after being banned for two weeks."
no-mems:      "England back row James Haskell has been given the all-clear to link up with France after recovering from a foot injury that kept him out for six months. Jones said people would \"have to wait and see if he will feature againstFrance in 11 days' time because of his lack thereof][ [lack] [[|]]\\{\u2022}</<> <nowiki*&#"
corrupt-mems: "Steven Lawless (Partick Thistle) has been shown the yellow card for a bad foul on Chris Cadden in his first match as part of Partick this season."

human:        "Some of the damage caused by the recent floods could have been prevented if the correct water management techniques had been used, says a group of leading environmental and planning experts."
base:         "In the letter to the Daily Telegraph, "The government will provide leadership and funding for a flood defence policy that is based on best practice developed over many years. It is essential if we are not addressing long-term problems of water management techniques such as stream alleviation streams in towns or cities like the Somerset Levels where flooding has been seen since December 2007 when it began with severe weather."
candidate:    "The Prime Minister has called on the government to take action against flooding in England and Wales."
no-mems:      "The government has announced its scheme to provide grants for homeowners in England hit by the floods and are spending £2.4bn on flood management, protection from coastal erosion as part of our long-term planto improve resilience; we will look at the lessons that be learned to see where additional flooding can help\"Mr Cameron says he is looking forward too much Too MuchToo"
corrupt-mems: "Partick Thistle (Partick Heretle) won a free kick in the second half of their 2-1 win over Motherwell on Saturday."

This testing was quite instructive. As expected, the baseline model tries to continue the article where it left off, as it doesn’t understand that we want a summary following the <sep> tag. We may be able to coax out better performance from this model by forcing it to predict text at the beginning of an article, but we would have no way to use memory in that case.

As for the mems ablation testing – wow! It is very clear that the model is using memory just by comparing the no-mems and candidate examples. The corrupt-mems examples show just how much this model leans on memory, though. Some of the generated text completely ignores the context of the last chunk that it was given. It seems that memory contributes more to most summaries than the text of the final chunk itself!

If you’d like to see more sample outputs, skip ahead to the “More Sample Outputs” section below.

Work Needed

I suspect that there is one major shortcoming to my approach: I believe that the transformers library does not properly apply relative positional encodings to sequences recursively fed into it. What this effectively means is that the model doesn’t have context on where in a sequence of text the “stuff” it holds in its memory came from. Once you’ve processed two chunks into memory, they will likely be “jumbled” up. Fortunately for the purposes of document summarization, sequence ordering actually doesn’t matter all that much. However, this would be something that would be nice to fix for other use cases.

I also suspect that my data could be better cleaned. I remove some useless, repetitive text in my preprocessor, but I know that some articles, for instance, are filled with gibberish and should be removed.

I would love to see how this model performs on xlnet-large-cased, which has more than 3x the parameters. I just don’t have the hardware on hand to do so.

Honestly, at this point, I’ve convinced myself that this type of pre-training is not the way forward for a truly game-changing text abstraction engine. The problem is that during training, I am forcing the model to reproduce my summarization examples, rather than learn to produce good summaries in general.

Think back to school – imagine if you were graded only by your ability to produce an exact copy of a “golden” book report. I suspect you wouldn’t learn much about expressing yourself.

I think the way forward is to have the model produce a summary and then get graded on it. I’m going to pursue that approach rather than fine-tuning this further. Stay tuned.

Alternatives Considered

This project has been through many iterations, a few of which are worth mentioning.

Before I understood that target_mapping was mandatory to get permutation modeling to work in XLNet, I tried an alternative scheme to pre-train the model. Instead of using the permutation_mask to force the model to predict all target tokens, I went with a variant of masked language modeling.

In this scheme, I prepended the target text to each chunk and had the dataset mask out a high percentage of the target tokens (50% seemed to be a good settling point). Prepending worked particularly well because it forced the model to predict where the target text ended and the article text began, and these models were actually quite good at producing summaries of the proper length. They also had very low losses, so they were quite good at the modeling task at hand. Unfortunately, their generation quality was fairly low. The summaries they produced sometimes looked like a summary, but were generally incoherent.
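The masking step of that earlier scheme can be sketched as follows. mask_prepended_target is a hypothetical helper, and several details are assumptions rather than a description of the actual dataset code: exactly round(mask_frac * len(target)) target positions are chosen uniformly at random, masked tokens are replaced by a single mask id, no separator is inserted between target and article, and unmasked positions get the label -100, which PyTorch’s CrossEntropyLoss ignores by default.

```python
import random

IGNORE_INDEX = -100  # default ignore_index of PyTorch's CrossEntropyLoss


def mask_prepended_target(target_ids, chunk_ids, mask_id, mask_frac=0.5, seed=None):
    """Prepend the target to a chunk and mask a fraction of the target tokens.

    Returns (input_ids, labels): labels hold the original token at masked
    target positions and IGNORE_INDEX everywhere else.
    """
    rng = random.Random(seed)
    num_to_mask = int(round(len(target_ids) * mask_frac))
    masked_positions = set(rng.sample(range(len(target_ids)), num_to_mask))

    input_ids, labels = [], []
    for pos, tok in enumerate(target_ids):
        if pos in masked_positions:
            input_ids.append(mask_id)
            labels.append(tok)  # the model must recover this token
        else:
            input_ids.append(tok)
            labels.append(IGNORE_INDEX)
    # The article text is never masked and contributes no loss.
    input_ids += list(chunk_ids)
    labels += [IGNORE_INDEX] * len(chunk_ids)
    return input_ids, labels


ids, labels = mask_prepended_target([101, 102, 103, 104], [7, 8, 9], mask_id=6, seed=0)
print(ids)
print(labels)
```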

The logic behind prepending the target to every chunk was that it allowed me to compute a loss on every chunk and therefore optimize the model more often. One thing I worry about with my current permutation solution is that the model is not really learning how to use the memory properly, because there is no stage where it is both producing memory and having gradients computed. The memory is always just something that is “there”, but it is never fine-tuned. This earlier scheme had to perform predictions on every chunk, which intuitively forced the model to optimize its internal state on every chunk. I do not have solid metrics to back up whether or not this actually matters; this is all conjecture.

Trying it out

I’ve put together a test model and a sample dataset which you can use to test the model out for yourself. In this section, I’ll document how you can do that. I’ll also show you how you can train it yourself.

Caveat: XLNet is no joke. Running the generation script on a CPU will take 20 minutes to 1 hour per inference. Keep in mind that it is doing up to 100 forward() passes on a 110M parameter model every inference. Running on a GPU currently requires at least 11GB of RAM; you can reduce this by reducing the number of beams in the generator and reducing the batch size (further) in the trainer.


First, download the model and the XSum test data.

Next, you’ll want to set up your Python environment to match mine. Chances are this works fine on the latest PyTorch, but here are the details for reproducibility.

pip install torch===1.4.0 torchvision===0.5.0
git clone
git clone

cd transformers
pip install .

cd ../NonIntNLP
git reset --hard remotes/origin/xlnet_xsum_train
cd Python

Use your own data

I’ve provided some test data to work against. If you want to have the model summarize your own text, you’ll first need to do some pre-processing on that data. I’ve provided a simple script to do that in the NonIntNLP repo. First, you need to create a JSON file called “input.json” that looks like this:

[
  { "text": "article text here",
    "summary": "article summary here" },
  { "text": "article2 text here",
    "summary": "article2 summary here" }
]

To preprocess this JSON file, run:

python processors/
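If your articles already live in Python, a small sketch like the one below can write input.json in the shape shown above before you run the preprocessing step. It assumes the file is a top-level JSON array of {"text", "summary"} objects; the preprocessing script isn’t reproduced here, so check what it actually expects. write_input_json is a hypothetical helper.

```python
import json


def write_input_json(pairs, path="input.json"):
    """Write (text, summary) pairs as a JSON array of {"text", "summary"} objects."""
    records = [{"text": text, "summary": summary} for text, summary in pairs]
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False, indent=2)
    return records


write_input_json([
    ("article text here", "article summary here"),
    ("article2 text here", "article2 summary here"),
])
```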

Generate summaries

To run the generator, you’ll need a pre-trained model like the one you downloaded above and a pre-processed data file. Invoke the generator as follows:

python --model_dir=<extracted_model_dir> --data_file=<extracted or generated> --number_to_generate=1 --device=cuda

Train the model

To train the model, you’ll first need some training data. You can download my XSum raw data and run it through processors/. You’ll need to modify the code a bit for your system and file paths; I’ll leave that up to you to figure out. Once run, you’ll get three files: “”, “”, “”. To train a baseline Transformers model against these files, run the following command:

python --input_folder=<folder where and are stored> --epochs=1 --output_dir=<directory to save checkpoints> --seq_sz=256 --max_predict_sz=80 --batch_sz=3 --aggregate_batch_sz=36 --start_lr=5e-6 --device=cuda

Some tips:

  • seq_sz and batch_sz are what you want to modulate if you get a CUDA OOM error. I doubt you will be able to train on a GPU with less than 8GB of RAM.
  • XLNet has a seq_sz limit of 512. Above this it should still work, but I believe the transformers library will automatically start using input recursion, which may reduce performance. I can’t test this because my GPU cannot handle it. Practically speaking, the model does fine at seq_sz=256.
  • This script does not currently support multi-GPU training, but it should be easy to implement.
  • start_lr should be scaled alongside aggregate_batch_sz.
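The post doesn’t say exactly how start_lr should be scaled. One common heuristic, assumed here, is the linear scaling rule: keep the ratio of learning rate to aggregate batch size constant. A minimal sketch, using start_lr=5e-6 at aggregate_batch_sz=36 from the training command above as the reference point:

```python
def scale_lr(base_lr, base_aggregate_batch_sz, new_aggregate_batch_sz):
    """Linear scaling heuristic: keep start_lr / aggregate_batch_sz constant."""
    return base_lr * new_aggregate_batch_sz / base_aggregate_batch_sz


for agg in (18, 36, 72):
    print(agg, scale_lr(5e-6, 36, agg))
```

Treat the result as a starting point to tune from rather than a guarantee; the linear rule is known to break down at large batch sizes.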

More Sample Outputs

I ran the best combination of model and generator configuration options against 50 news articles from the XSum test set I generated and posted the results on Github here. I’ll spend this section discussing some of the results.

To start with, I think a good number of these generated summaries are pretty incredible! I’m going to focus on the negatives below, but I’d say the majority of the generated text looks and sounds like a summary, and many times gives the actual summary a run for its money.

This particular model has a penchant for producing short, succinct summaries. As a result, it often loses some important details that you would normally expect to see in a summary. For example:

Generated summary: “Former Chancellor George Osborne has been appointed as the new editor of The Sunday Times.”
Actual summary: “The rules on MPs taking second jobs are to be discussed by a parliamentary committee on Thursday.”

This article was mostly about the controversy surrounding Osborne’s appointment, not the appointment itself.

Like many automated text summarizers, this one sometimes struggles to get facts right. Here are a couple of pretty egregious examples:

Generated summary: “Scotland’s unemployment rate has fallen by more than half in the last three months, according to figures released today.”

Actual summary: “Unemployment in Scotland rose by 16,000 between November and January to stand at 171,000, according to official figures.”

Generated summary: “A man killed by a taxi driver in Glasgow has pleaded guilty to manslaughter.”

Actual summary: “A lorry driver has admitted causing the death of a student after he drove through a red light and crashed into a taxi in Glasgow.”

The model occasionally has a penchant for run-on summaries which quickly lose their point. I can’t figure out what triggers this to happen. Here’s an example:

Generated summary: “Dublin Airport has apologised to customers for any inconvenience caused by a ground incident in which two aircraft clipped each other on the taxiway at its airport. Ryanair have said they were bused and boarded three replacement flights from Edinburgh FR812 & Zadar as passengers prepared their departures this morning, according to an Irish airline statement issued late Wednesday evening online via Twitter!”

Finally, the model occasionally will extract some information from an article, but totally miss the intended point. Here is an example:

Generated summary: “Women in the UK are being judged more than men, according to a new study.”

Actual summary: “Reverend Libby Lane has become the first female bishop.”

I’ll leave you with an amusing example encountered in the test set. An article that appears to be written in Welsh is fed into the model, and it responds with gibberish. 🙂

Generated summary: “Fe’i y datganiad wedi pry-droed a Mr Deguara difrifol car Vauxhall Caerdyd, ong mewn pa de la fru dux van Gael. Y fe cafwy Daddy Pesticcio no fos to go for his daughter Sophie Taylor as she was pregnant with her son Michael Wheeler at the”

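Actual summary: “Mae menyw yn wynebu carchar wedi iddi achosi marwolaeth cyn-gariad ei chymar wrth ei rasio ar hyd strydoedd Caerdydd.” (In English: “A woman faces prison after causing the death of her partner’s ex while racing them through the streets of Cardiff.”)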

Copyright and Attribution

This code is under a Creative Commons license – do with it what you like. I only ask that if you use this code or post in a public project or paper, you reference my name (James Betker) and this work.