SampleRNN in PyTorch

Posted by Piotr Kozakowski & Bartosz Michalak on Thu 29 June 2017

Some news

It's been a long time since our last post. We've been busy studying and experimenting with SampleRNN. Unfortunately, none of our ideas for improving SampleRNN have worked so far, and we haven't yet managed to get it to generate higher-quality music. We have discovered, though, that our efforts were slowed down considerably by some problems, not with the model itself but with the reference implementation:

  • The code is quite messy. Model definition and training code are mixed in a single file, there is a lot of repeated code and there are a few odd design choices, all of which make it harder to work with. We don't mean to criticize the authors, though: we realize that this was never meant to be production-quality code, and we are grateful that they provided any code at all!

  • It uses reshapes extensively. That is not a bad thing when they are necessary, but here they are used even where a more suitable operation exists. For example, upsampling an (examples, time, channels) tensor by a factor of k along the time axis is performed by linearly projecting the channels to size k * channels and then reshaping the tensor to shape (examples, k * time, channels). Once you know what it is supposed to do, this tensor manipulation may seem obvious, but the code certainly isn't, especially on first reading. Such an operation could easily be replaced by a transposed convolution, which would be much easier to understand. Moreover, so many reshapes make the code harder to modify: it's easy to make a mistake when you have to specify every shape by hand.

  • It's written in Theano, and debugging in Theano is hard. When working with deep learning, there inevitably comes a moment when your network is not learning properly and you have no idea why. It's then useful to look at how the weights change over time, at the outputs of various layers and at the gradients. But you cannot simply take an arbitrary intermediate value out of a Theano computation graph and do whatever you want with it. You can inject an Op into the graph that does what you want, but that's inflexible. Theano provides the Print op, which prints tensors, but how do you print a transformed version of a tensor (say, just the variance of a layer's output)? You need to include the transformation in the computation graph and make the final result depend on it, so that the transformation gets evaluated, yet without changing the final result. That's quite hacky. And it only works for printing: what if you wanted to gather output variances in each epoch and plot them? You can also return all intermediate values from the compiled Theano function, but that clutters the code pretty badly. If you have any tricks that make debugging in Theano easier, please contact us! We’re still hoping there’s something we’ve missed. ;)
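To make the upsampling point from the second item concrete, here is a minimal sketch in current PyTorch (not code from either implementation; tensor sizes are arbitrary). It performs the reshape-based upsampling and then shows that a `nn.ConvTranspose1d` with `kernel_size = stride = k` computes the same thing once its weights are copied over. Biases are disabled on purpose: a `Linear` layer with bias would have a separate bias for each of the k output positions, whereas `ConvTranspose1d` shares one bias per channel across them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, steps, channels, k = 2, 5, 3, 4
x = torch.randn(batch, steps, channels)  # (examples, time, channels)

# Reshape-based upsampling: project every frame to k * channels features,
# then spread the k chunks of size `channels` along the time axis.
proj = nn.Linear(channels, k * channels, bias=False)
y_reshape = proj(x).view(batch, steps * k, channels)

# The same operation as a transposed convolution with kernel_size = stride = k.
# nn.ConvTranspose1d expects input of shape (examples, channels, time).
conv_t = nn.ConvTranspose1d(channels, channels, kernel_size=k, stride=k, bias=False)

# Copy the Linear weights into the convolution so both compute the same function.
# Linear weight: (k * channels, channels), row j * channels + c_out -> [j, c_out, c_in].
# ConvTranspose1d weight: (in_channels, out_channels, kernel_size) -> [c_in, c_out, j].
with torch.no_grad():
    w = proj.weight.view(k, channels, channels)
    conv_t.weight.copy_(w.permute(2, 1, 0))

y_conv = conv_t(x.transpose(1, 2)).transpose(1, 2)  # back to (examples, time, channels)

print(y_reshape.shape)                               # torch.Size([2, 20, 3])
print(torch.allclose(y_reshape, y_conv, atol=1e-5))  # True
```

The convolution version states the intent (learned upsampling by a factor of k) directly, and only one shape-related argument, the factor k, has to be specified by hand.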

Seeing all of these problems, we decided to rewrite SampleRNN in PyTorch. Porting the whole codebase to a different framework is quite a radical decision, but we think it will pay off in much faster prototyping and debugging. In the following section we’ll try to show why we think we’ve chosen the right tool for the job.

PyTorch

PyTorch is a relatively new deep learning framework developed by Facebook. Its basic building block is the Module: essentially any differentiable function operating on tensors, such as a linear transformation, a convolution or a softmax activation. Modules can be composed of other modules, which makes it possible to build complex models. As the name suggests, this approach promotes modularity and object-oriented design. But the main feature that makes PyTorch stand out from the crowd is that it uses dynamic computation graphs.
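As a small illustration of composing Modules, here is a sketch in current PyTorch; the class `TinyClassifier` and its sizes are hypothetical and have nothing to do with SampleRNN. Submodules assigned as attributes are registered automatically, so their parameters show up in `model.parameters()` without any extra bookkeeping.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical module composed of two nn.Linear submodules."""

    def __init__(self, in_features, hidden, n_classes):
        super().__init__()
        # Assigning Modules as attributes registers them (and their parameters).
        self.hidden = nn.Linear(in_features, hidden)
        self.output = nn.Linear(hidden, n_classes)

    def forward(self, x):
        h = torch.relu(self.hidden(x))
        return torch.softmax(self.output(h), dim=-1)


model = TinyClassifier(in_features=8, hidden=16, n_classes=4)
probs = model(torch.randn(3, 8))
n_params = sum(p.numel() for p in model.parameters())

print(probs.shape)  # torch.Size([3, 4])
print(n_params)     # 212 = (8*16 + 16) + (16*4 + 4)
```

Because `TinyClassifier` is itself a Module, it could in turn be used as a building block inside a larger model.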

Most deep learning frameworks, like Theano and TensorFlow, use static computation graphs. This means that the training program is split into two parts: graph creation and actual training. The program first defines all computations and builds a graph out of them, then differentiates the graph and trains the model by executing the computation on data. At this point the computation graph can no longer change; it's static. This approach has some advantages: for example, abstracting the computations away makes it possible to define a graph once and run it on different hardware architectures. However, as described above, it also makes debugging a nightmare.

In PyTorch you don't need to define the graph first and then run it. It follows a define-by-run paradigm: you run computations on actual tensors, and the graph is generated on the fly. The computation graph can be different in each iteration of the algorithm, because it is rebuilt from scratch every time. This way you can backpropagate through ordinary control structures, like a Python for loop, instead of resorting to symbolic equivalents, like the problematic Theano scan. It's almost magical: you write code that closely resembles NumPy code and operates on concrete values all the time; the only differences are that it can run on a GPU and that you can call the backward() method on the result to compute gradients with respect to the parameters.
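Here is a minimal sketch of define-by-run with a toy recurrence (hypothetical, not SampleRNN's actual RNN). One caveat: in the 2017 API, tensors had to be wrapped in `torch.autograd.Variable` to track gradients; since PyTorch 0.4, tensors created with `requires_grad=True` do this directly, which is what the sketch assumes.

```python
import torch

torch.manual_seed(0)
# A leaf tensor we want gradients for (tracked directly, no Variable wrapper).
w = torch.randn(4, 4, requires_grad=True)


def unrolled_rnn(xs):
    """Toy recurrence unrolled with an ordinary Python for loop."""
    h = torch.zeros(4)
    for x in xs:  # the graph is recorded step by step as the loop runs
        h = torch.tanh(w @ h + x)
    return h.sum()


grad_norms = []
for steps in (3, 7):  # a different sequence length, so a different graph each time
    w.grad = None  # clear gradients left over from the previous iteration
    loss = unrolled_rnn(torch.randn(steps, 4))
    loss.backward()  # backpropagate through however many loop iterations ran
    grad_norms.append(w.grad.norm().item())

print(w.grad.shape)  # torch.Size([4, 4])
```

There is no symbolic loop construct here: the sequence length is just a Python value, and every intermediate `h` is a concrete tensor that could be printed or inspected inside the loop.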

This also solves the debugging problem. Because you're operating on concrete tensors all the time and have access to intermediate outputs at every point, you can compute statistics on them, collect them and do anything else you like with them. In fact, you don't even need to modify your model's code to do that: you can attach a forward (or backward) hook to a Module, which will be called each time the Module computes its forward (or backward) pass and will receive the intermediate values as arguments.
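The "gather output variances in each epoch" scenario that is awkward in Theano becomes a few lines with a forward hook. This is a minimal sketch on a hypothetical toy model, using `Module.register_forward_hook`, whose hook receives `(module, inputs, output)` and which returns a handle whose `remove()` method detaches the hook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
output_variances = []  # one entry per forward pass of the hooked layer


def record_variance(module, inputs, output):
    # Called by PyTorch right after module.forward() returns; detach() so the
    # statistic is not recorded in the autograd graph.
    output_variances.append(output.detach().var().item())


# Observe the first layer without touching the model's code.
handle = model[0].register_forward_hook(record_variance)

for _ in range(3):  # stand-in for a training loop, e.g. one batch per epoch
    model(torch.randn(64, 10))

handle.remove()  # detach the hook once we are done collecting
model(torch.randn(64, 10))  # this pass is no longer recorded

print(len(output_variances))  # 3
```

The collected values are plain Python floats, so they can be plotted or logged with any tool, and the model definition stays untouched.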

In summary, PyTorch is really great and we encourage everyone to try it. You can learn more on its website: http://pytorch.org/about/.

Our implementation of SampleRNN

Our code is available here. For the most part we tried to translate the model from Theano to PyTorch as faithfully as possible, but there are some differences, mostly related to design rather than to actual functionality:

  • We encapsulated all logical parts of the model in classes, which makes them easy to modify. We used a modified version of PyTorch's Trainer class to manage the training process and wrote plugins for it that handle validation, visualization etc. Some of these plugins might be useful in other deep learning projects unrelated to SampleRNN.
  • We got rid of the unnecessary reshapes, making the code clearer.
  • We allowed training models with an arbitrary number of tiers. The original code supports at most three tiers and uses a separate training script for each number of tiers.
  • We changed how the experiment tags, which are used as directory names, are generated. Most hyperparameters have sensible default values, and those left at their defaults are not included in the tag. This avoids overly long directory names, which is a real problem: many filesystems limit a file name to 255 characters.
  • We haven't implemented weight normalization for now, because it would require extra work and the model works well without it; it just converges more slowly.
  • We only support GRU units, not LSTM units, but in our experiments the model worked better with GRUs anyway.
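The tag scheme from the fourth item can be sketched in plain Python. This is an illustrative sketch, not the code from the repository: the hyperparameter names, default values, `key=value` format and `MAX_NAME_BYTES` check are all hypothetical choices made for the example.

```python
# Hypothetical hyperparameter names and default values, for illustration only.
DEFAULTS = {
    'frame_sizes': [16, 4],
    'n_rnn': 2,
    'dim': 1024,
    'learn_h0': True,
    'q_levels': 256,
}

MAX_NAME_BYTES = 255  # common per-file-name limit (e.g. ext4)


def to_string(value):
    """Render a hyperparameter value compactly for use in a directory name."""
    if isinstance(value, bool):
        return 'T' if value else 'F'
    if isinstance(value, (list, tuple)):
        return ','.join(to_string(v) for v in value)
    return str(value)


def make_tag(exp_name, params):
    """Build a tag listing only hyperparameters that differ from DEFAULTS."""
    parts = [exp_name]
    for key in sorted(params):  # sorted so the same settings give the same tag
        if key not in DEFAULTS or params[key] != DEFAULTS[key]:
            parts.append(key + '=' + to_string(params[key]))
    tag = '-'.join(parts)
    n_bytes = len(tag.encode('utf-8'))
    if n_bytes > MAX_NAME_BYTES:
        raise ValueError('experiment tag too long: %d bytes' % n_bytes)
    return tag


print(make_tag('piano', {'frame_sizes': [16, 4], 'n_rnn': 3, 'dim': 1024}))
# piano-n_rnn=3
print(make_tag('piano', {'frame_sizes': [8, 2, 2], 'learn_h0': False}))
# piano-frame_sizes=8,2,2-learn_h0=F
```

Only the settings that actually distinguish an experiment end up in its name, so a typical tag stays short even as the number of hyperparameters grows.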

If you'd like to implement any missing functionality or have any ideas about how to improve the code, you're welcome to contribute!

Stay tuned

We believe that our new implementation will let us make progress faster, and hopefully before long we'll be able to present some actual results. In the meantime, have a great summer, everyone!

