This post is taken from a tutorial I am writing with David Duvenaud.
When you write a nontrivial piece of software, how often do you get it completely correct on the first try? When you implement a machine learning algorithm, how thorough are your tests? If your answers are “rarely” and “not very,” stop and think about the implications.
There’s a large literature on testing the convergence of optimization algorithms and MCMC samplers, but I want to talk about a more basic problem here: how to test if your code correctly implements the mathematical specification of an algorithm. There are a lot of general strategies software developers use to test code, but machine learning algorithms, and especially sampling algorithms, present their own difficulties:
- The algorithms are stochastic, so there’s no single “correct” output
- Algorithms can fail because they get stuck in local optima, even if they are mathematically correct
- Good performance is often a matter of judgment: if your algorithm gets 83.5% of the test examples correct, does that mean it’s working?
Furthermore, many machine learning algorithms have a curious property: they are robust against bugs. Since they’re designed to deal with noisy data, they can often deal pretty well with noise caused by math mistakes as well. If you make a math mistake in your implementation, the algorithm might still make sensible-looking predictions. This is bad news, not good news. It means bugs are subtle and hard to detect. Your algorithm might work well in some situations, such as small toy datasets you use for validation, and completely fail in other situations — high dimensions, large numbers of training examples, noisy observations, etc. Even if it only hurts performance by a few percent, that difference may be important.
If you’re a researcher, there’s another reason to worry about this. Your job isn’t just to write an algorithm that makes good predictions — it’s to run experiments to figure out how one thing changes as a result of varying another. Even if the algorithm “works,” it may have funny behaviors as some parameterizations compensate for mistakes better than others. You may get curious results, such as performance worsening as you add more training data. Even worse, you may get plausible but incorrect results. Even if these aren’t the experiments you eventually publish, they’ll still mess up your intuition about the problem.
In the next two blog posts, I’ll focus on testing MCMC samplers, partly because they’re the kind of algorithm I have the most experience with, and partly because they are especially good illustrations of the challenges involved in testing machine learning code. These strategies should be more widely applicable, though.
In software development, it is useful to distinguish two different kinds of tests: unit tests and integration tests. Unit tests are short and simple tests which check the correctness of small pieces of code, such as individual functions. The idea is for the tests to be as local as possible — ideally, a single bug should cause a single test to fail. Integration tests test the overall behavior of the system, without reference to the underlying implementation. They are more global than unit tests: they test whether different parts of the system interact correctly to produce the desired behavior. Both kinds of tests are relevant to MCMC samplers.
A development pattern called test-driven development is built around testability: the idea is to first write unit and integration tests as a sort of formal specification for the software, and then to write code to make the tests pass. A fully test-driven development process might look something like:
- write integration tests based on high-level specs for the system
- write unit tests
- write code which makes the unit tests pass
- make sure the integration tests pass
In research, because the requirements are so amorphous, it is often difficult to come up with good specs in advance. However, I’ve found the general idea of getting unit tests to pass, followed by integration tests, to be a useful one. I’ll talk about unit tests in this blog post and integration tests in the follow-up.
Unit testing is about independently testing small chunks of code. Naturally, if you want to unit test your code, it needs to be divided into independently testable chunks, i.e. it needs to be modular. Such a modular design doesn’t happen automatically, especially in research, where you’re building up a code base organically with constantly shifting goals and strategies. It can be difficult to retrofit a code base with unit tests; in fact, the book Working Effectively with Legacy Code defines legacy code as “code without unit tests.” The whole book is about strategies for refactoring an existing code base to make it testable.
It’s much easier to try to keep things modular and testable from the beginning. In fact, one of the arguments behind test-driven development is that by writing unit tests along with the actual code, you force yourself to structure your code in a modular way, and this winds up having other benefits like reusability and maintainability. This agrees with my own experience.
Unfortunately, the way machine learning is typically taught encourages a programming style which is neither modular nor testable. In problem sets, you typically work through a bunch of math to derive the update rules for an iterative algorithm, and then implement the updates directly in MATLAB. This approach makes it easy for graders to spot differences from the correct solution, and you’ll probably be able to write a correct-looking homework assignment this way, but it’s not a robust way to build real software projects. Iterative algorithms are by nature difficult to test as black boxes.
Fortunately, most machine learning algorithms are formulated in terms of general mathematical principles that suggest a modular organization into testable chunks. For instance, we often formulate algorithms as optimization problems which can be solved using general techniques such as gradient descent or L-BFGS. From an implementation standpoint, this requires writing functions to (a) evaluate the objective function at a point, and (b) compute the gradient of the objective function with respect to the model parameters. These functions are then fed to library optimization routines. The gradients can be checked using finite difference techniques. Conveniently, most scientific computing environments provide implementations of these checks; for instance, scipy.optimize.check_grad (in Python) or Carl Rasmussen’s checkgrad function (in MATLAB).
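As a rough illustration of the gradient check, here is a minimal sketch using scipy.optimize.check_grad. The ridge regression objective, the function names, and the deliberately buggy gradient are all assumptions chosen for the example; they are not taken from any particular project.

```python
import numpy as np
from scipy.optimize import check_grad


def objective(w, X, y, lam):
    """Illustrative objective: 0.5 * ||Xw - y||^2 + 0.5 * lam * ||w||^2."""
    residual = X @ w - y
    return 0.5 * residual @ residual + 0.5 * lam * w @ w


def gradient(w, X, y, lam):
    """Hand-derived gradient of the objective with respect to w."""
    return X.T @ (X @ w - y) + lam * w


def buggy_gradient(w, X, y, lam):
    """Same gradient with a deliberate mistake: the regularizer term is dropped."""
    return X.T @ (X @ w - y)


# Random inputs with a fixed seed, so the check is deterministic.
rng = np.random.default_rng(0)
X = rng.normal(size=(20, 5))
y = rng.normal(size=20)
w0 = rng.normal(size=5)
lam = 1.0

# check_grad returns the 2-norm of the difference between the supplied
# gradient and a finite-difference approximation at w0.
good_error = check_grad(objective, gradient, w0, X, y, lam)
bad_error = check_grad(objective, buggy_gradient, w0, X, y, lam)
print(good_error, bad_error)
```

The correct gradient should give an error that is tiny relative to the size of the gradient, while the buggy one is off by roughly the norm of the missing term. In a unit test, one would assert that the error falls below a threshold chosen for the problem's scale.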
This organization into gradient computations and general purpose optimization routines exemplifies a generally useful pattern for machine learning code: separate out the implementations of the model and general-purpose algorithms. Consider, for instance, writing a Gibbs sampler for a distribution. Seemingly the most straightforward way to implement a Gibbs sampler would be to derive by hand the update rules for the individual variables and then write a single iterative routine using those rules. As I mentioned before, such an approach can be difficult to test.
Now consider how we might separate out the model from the algorithm. In each update, we sample a variable $x$ from its conditional distribution $p(x \mid z)$, where $z$ denotes the variables it is conditioned on. It is difficult to directly test the correctness of samples from the conditional distribution: such a test would have to be slow and nondeterministic, both highly undesirable properties for a unit test to have. A simpler approach is to test that the conditional distribution is consistent with the joint distribution. In particular, for any two values $x$ and $x'$, we must have:
$$\frac{p(x' \mid z)}{p(x \mid z)} = \frac{p(x', z)}{p(x, z)}$$
Importantly, this relationship must hold exactly for any two values $x$ and $x'$. For the vast majority of models we use, if our formula for the conditional probability is wrong, then this relationship would almost surely fail to hold for two randomly chosen values. Therefore, a simple and highly effective test for conditional probability computations is to choose random values for $x$, $x'$, and $z$, and verify the above equality to a suitable numerical precision. (As usual in probabilistic inference, we should use log probabilities rather than probabilities for numerical reasons.)
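Here is a minimal sketch of such a consistency test, working in log space. Writing x for the variable being updated and z for everything it is conditioned on, the check compares the difference of log conditionals with the difference of log joints. The model is a hypothetical one picked for illustration: a Gaussian mean mu with a Gaussian prior, observed through noisy data y, so mu plays the role of x and y plays the role of z. The variances, function names, and tolerance are all assumptions.

```python
import numpy as np
from scipy.stats import norm

PRIOR_VAR = 4.0   # assumed prior variance of mu
NOISE_VAR = 0.5   # assumed observation noise variance


def log_joint(mu, y):
    """log p(mu, y) = log p(mu) + sum_i log p(y_i | mu)."""
    log_prior = norm.logpdf(mu, 0.0, np.sqrt(PRIOR_VAR))
    log_lik = norm.logpdf(y, mu, np.sqrt(NOISE_VAR)).sum()
    return log_prior + log_lik


def log_conditional(mu, y):
    """Hand-derived log p(mu | y) for the conjugate Gaussian model."""
    post_var = 1.0 / (1.0 / PRIOR_VAR + len(y) / NOISE_VAR)
    post_mean = post_var * y.sum() / NOISE_VAR
    return norm.logpdf(mu, post_mean, np.sqrt(post_var))


def buggy_log_conditional(mu, y):
    """Deliberate mistake: the posterior mean forgets to divide by NOISE_VAR."""
    post_var = 1.0 / (1.0 / PRIOR_VAR + len(y) / NOISE_VAR)
    post_mean = post_var * y.sum()
    return norm.logpdf(mu, post_mean, np.sqrt(post_var))


def check_conditional(log_cond, rng, trials=10, tol=1e-8):
    """Check log p(x'|z) - log p(x|z) == log p(x',z) - log p(x,z) at random points."""
    for _ in range(trials):
        y = rng.normal(size=5)
        mu_a, mu_b = rng.normal(size=2)
        lhs = log_cond(mu_b, y) - log_cond(mu_a, y)
        rhs = log_joint(mu_b, y) - log_joint(mu_a, y)
        if not np.isclose(lhs, rhs, rtol=tol, atol=tol):
            return False
    return True


print(check_conditional(log_conditional, np.random.default_rng(1)))
print(check_conditional(buggy_log_conditional, np.random.default_rng(1)))
```

Note that the normalizing constant of the conditional cancels in the difference, so the test never needs the marginal probability of the conditioning variables. Seeding the random number generator keeps the test deterministic.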
In terms of implementation, this suggests writing three separate modules:
- classes representing probability distributions which know how to (1) evaluate the log probability density function (or probability mass function) at a point, and (2) generate samples
- a specification of the model in terms of functions which compute the probability of a joint assignment and functions which return conditional probability distributions
- a Gibbs sampling routine, which sweeps over the variables in the model, replacing each with a sample from its conditional distribution
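One way the three modules above might fit together is sketched below. This is only an illustration under stated assumptions: the class names, the dictionary-based state, and the toy model (a zero-mean bivariate Gaussian with unit variances and correlation rho) are all hypothetical choices, not a prescribed design.

```python
import numpy as np


class Gaussian:
    """Distribution module: log density and sampling for a univariate Gaussian."""

    def __init__(self, mean, var):
        self.mean = mean
        self.var = var

    def log_prob(self, x):
        return -0.5 * np.log(2 * np.pi * self.var) - 0.5 * (x - self.mean) ** 2 / self.var

    def sample(self, rng):
        return rng.normal(self.mean, np.sqrt(self.var))


class CorrelatedPairModel:
    """Model module: joint log probability and conditionals for variables a and b."""

    variables = ("a", "b")

    def __init__(self, rho):
        self.rho = rho

    def log_joint(self, state):
        a, b, rho = state["a"], state["b"], self.rho
        quad = (a ** 2 - 2 * rho * a * b + b ** 2) / (1 - rho ** 2)
        return -np.log(2 * np.pi) - 0.5 * np.log(1 - rho ** 2) - 0.5 * quad

    def conditional(self, name, state):
        # For this model, each variable given the other is Gaussian with
        # mean rho * other and variance 1 - rho^2.
        other = "b" if name == "a" else "a"
        return Gaussian(self.rho * state[other], 1 - self.rho ** 2)


def gibbs_sweep(model, state, rng):
    """Algorithm module: replace each variable with a draw from its conditional."""
    for name in model.variables:
        state[name] = model.conditional(name, state).sample(rng)
    return state


rng = np.random.default_rng(0)
model = CorrelatedPairModel(rho=0.8)
state = {"a": 0.0, "b": 0.0}
samples = []
for _ in range(1000):
    gibbs_sweep(model, state, rng)
    samples.append((state["a"], state["b"]))
samples = np.array(samples)
```

The point of the split is that gibbs_sweep knows nothing about the model's math, so the conditionals can be unit tested against log_joint without running the sampler at all.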
The distribution classes require their own forms of unit testing. For instance, we may generate large numbers of samples from the distributions, and then ensure that the empirical moments are consistent with the exact moments. I’ll defer discussion of testing the Gibbs sampler itself to my next post.
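A moment check of this kind might look like the following sketch. The helper name check_moments, the gamma distribution example, the sample size, and the five-standard-error threshold are all assumptions made for illustration. A fixed seed keeps the test deterministic, and the loose threshold makes false alarms very unlikely if the seed is ever changed.

```python
import numpy as np


def check_moments(sampler, exact_mean, exact_var, rng, n=200_000, num_std_errs=5.0):
    """Compare the empirical mean and variance of samples with exact values.

    Each comparison allows a deviation of num_std_errs standard errors.
    """
    samples = sampler(rng, n)

    # Standard error of the sample mean, using the claimed exact variance.
    mean_se = np.sqrt(exact_var / n)
    mean_ok = abs(samples.mean() - exact_mean) < num_std_errs * mean_se

    # Treat squared deviations from the claimed mean as a sample whose
    # average should equal the claimed variance.
    sq_dev = (samples - exact_mean) ** 2
    var_se = sq_dev.std() / np.sqrt(n)
    var_ok = abs(sq_dev.mean() - exact_var) < num_std_errs * var_se

    return bool(mean_ok and var_ok)


SHAPE, SCALE = 3.0, 2.0


def sample_gamma(rng, n):
    """Sampler under test: gamma distribution in the shape/scale parameterization."""
    return rng.gamma(SHAPE, SCALE, size=n)


# Correct moments for shape/scale: mean = shape * scale, var = shape * scale^2.
passes = check_moments(sample_gamma, SHAPE * SCALE, SHAPE * SCALE ** 2,
                       np.random.default_rng(0))

# A common mistake is confusing scale with rate, which gives the wrong moments.
fails = check_moments(sample_gamma, SHAPE / SCALE, SHAPE / SCALE ** 2,
                      np.random.default_rng(0))
print(passes, fails)
```

Matching the first two moments does not prove the sampler is correct, but it does catch common slips such as a swapped parameterization or a variance passed where a standard deviation was expected.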
While this modular organization was originally motivated by testability, it also has benefits in terms of reusability: it’s easy to reuse the distribution classes between entirely different projects, and the conditional probability functions can be reused in samplers which are more sophisticated than Gibbs sampling.
While Gibbs sampling is a relatively simple example, this thought process is a good way of coming up with modular and testable ways to implement machine learning algorithms. With more work, it can be extended to trickier samplers like Metropolis-Hastings. The general idea is to make a list of mathematical properties that the component functions should satisfy, and then check that they actually satisfy those properties.
There are a lot of helpful software tools for managing unit tests. I use nose, which is a simple and elegant testing framework for Python. In order to avoid clutter, I keep the tests in a separate directory whose structure parallels the main code directory.
I’ve found that unit testing of the sort that I’ve described takes very little time and provides substantial peace of mind. It still doesn’t guarantee correctness of the full algorithm, of course, so in the next post I’ll talk about integration testing.