
Amol Kapoor

Simple DL Part 6: An End to End Example (with Code!)

June, 2021
Overview

At this point, you hopefully have a high level understanding of some key deep learning principles: features, embeddings, losses, and how they all interact with each other. Unfortunately for me, some of my readers complained that this was not enough, and that without an end to end example that showed how the intuition could be applied, this whole project was meaningless. sigh. Even though I was really hoping to avoid digging into the specifics of a deep learning library, I think my readers are probably right. Hold on tight folks, this one is going to be a long one.

There are countless starter examples for deep learning. Most of these implement a small neural network that can learn to classify handwritten numbers from 0 to 9 (aka MNIST). Many of these tutorials are quite good for understanding the particular syntax of a specific library, but they do a poor job of linking the code to some deeper understanding of ML. In part that is because the deeper understanding doesn't really exist -- the answer to 'why' is 'because'.

I wish I could say that this frustration goes away. It doesn't. I still feel like this when I read new ML papers.

In this tutorial, I'll instead do something totally different by teaching you how to implement a small neural network that can learn to classify handwritten numbers from 0 to 9. We're not going to touch code until the very end -- instead, we'll spend a lot of time trying to think through the problem in order to build some intuition about what we should be doing. All parts of the tutorial will be grounded in the previous SimpleDL lessons. Even though MNIST is a really well known dataset with countless 'solutions', I'll try to approach the task as if it were a real world learning problem.

The Problem

You work for the IRS. You have to deal with millions of tax filings -- over 150M, according to a random website called Google. That's a ton of filings. Most people do these on printed forms, filling in fields by hand. The techs over at the Department of Technology have scanned all the filings. Now they need to pull out all of the numbers.

One problem: the scans are all unparsed images.

We need to build a system that can convert images of numbers into actual numbers in some programming language or database, so that we can do more number crunching down the line. Unfortunately, there are tons of edge cases, which rules out most statistical/geometric/traditional computer vision approaches. A human can probably figure out most of them, but humans are expensive and slow. Can we use deep learning?

Pictured: the US Federal Department of Technology logo (the joke is that there's no such thing).

Using Canonical Tasks: Loss

In Part 5, we laid out three canonical tasks: classification, multiclassification, and regression. If we figure out which bucket our IRS problem falls in, we can infer a default loss for our model. Let's work through each possibility in reverse, starting with regression.

A regression task is one where we try to predict a continuous output. It is tempting to look at the IRS problem and say that since we are predicting numeric output, the problem space must be continuous. Unfortunately, that way lies madness. The trick is that even though we are predicting numbers, we are treating each number as a discrete category. In other words, if the true label is '1', the model is equally wrong whether it predicts '2' or '9'. If the ordering of the outputs doesn't matter, it's not a regression problem.

What about multiclassification? The main difference between classification and multiclassification is whether the categories are mutually exclusive -- in multiclassification, a single input can carry several labels at once. In our IRS problem, each input can only be one of ten digits (0 - 9). In other words, the categories ARE mutually exclusive, so multiclassification is out too.

That leaves classification as the only remaining option. In Part 5, we showed that the Softmax Cross Entropy is a standard loss for classification tasks. Following the example, we can label our training data with a one-hot vector of size 10, where the hot index is the true number. We make the model output a vector (logits) of size 10, so we can compare the output with the ground truth labels. And then we just pass these into the appropriate function and call it a day.
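
To make this concrete, here is a minimal numpy sketch of what a library's softmax cross entropy function is doing under the hood. The function name and the toy logits are made up for illustration; in practice you'd call your library's built-in version. It also demonstrates the point from the regression discussion: two different wrong digits get exactly the same loss.

import numpy as np

def softmax_cross_entropy(logits, one_hot_labels):
  """Mean softmax cross entropy, computed by hand for illustration."""
  # Softmax: exponentiate and normalize so each row sums to 1.
  # Subtracting the row max keeps the exponents numerically stable.
  exps = np.exp(logits - logits.max(axis=1, keepdims=True))
  probs = exps / exps.sum(axis=1, keepdims=True)
  # Cross entropy: the negative log probability assigned to the true class.
  return -np.sum(one_hot_labels * np.log(probs), axis=1).mean()

# True digit is '1', so index 1 of the one-hot label vector is hot.
label = np.zeros((1, 10))
label[0, 1] = 1.0

# Two wrong models: one confidently predicts '2', the other '9'.
guess_2 = np.zeros((1, 10))
guess_2[0, 2] = 5.0
guess_9 = np.zeros((1, 10))
guess_9[0, 9] = 5.0

# Both wrong guesses are penalized identically -- the categories are
# discrete, so being 'numerically close' to '1' earns no partial credit.
print(softmax_cross_entropy(guess_2, label))  # ~5.06
print(softmax_cross_entropy(guess_9, label))  # ~5.06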

If this all made sense, congrats. This is the hardest part.

What about the Features?

Some people think feature selection is a fine art. Those people are wrong. In deep learning, feature selection is something of a misnomer -- if you have features available, you should use them. The problem is that features are really hard to get, because they have to be consistent across all of the training and test data.

In our IRS problem, we can't be sure we have any metadata available consistently. People forget to add names, or addresses, or whatever. In fact, the only thing we can be sure of is that we have the raw pixels of the number we're trying to predict. So...let's just use that as our features.

But of course we can't just pass an image into a model directly from a filepath. You have to convert all features into a numeric vector representation first. Luckily, for images this is really easy -- each pixel is already a numeric RGB value. We can convert the images to numbers using a python library of choice and then use that as input. Each image ends up being a 3D input, with shape [Height, Width, Channels]. If we want to feed multiple images in at once, we can stack them on a new axis, resulting in a [Batch, Height, Width, Channels] input. Oh, and you are going to want to normalize the features, but that's a conversation for later.
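
As a sketch, here's one way that conversion might look using PIL and numpy (stand-ins for your 'python library of choice', and assuming every scan is already cropped to the same size):

import numpy as np
from PIL import Image

def load_images(paths):
  """Load image files into one [Batch, Height, Width, Channels] array."""
  images = []
  for path in paths:
    img = Image.open(path).convert('RGB')  # each pixel becomes an RGB triple
    images.append(np.asarray(img, dtype=np.float32))
  # Stack the individual [Height, Width, Channels] arrays on a new
  # Batch axis. (Normalization comes later, as promised.)
  return np.stack(images, axis=0)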

What about the Model?

At this point, we know the input (image matrices) and the output (a logits vector of size 10). So we can slot in just about any model we want, as long as it conforms to that input and output.

Does the model choice matter? Well, kinda. Models have a certain 'capacity' that limits what kinds of problems they can effectively solve. The problem is, we haven't really figured out how to calculate or define 'capacity'. Most engineers use the number of parameters in the model as a rough heuristic. More parameters = more capacity. On some level this is intuitive, but we also recognize that a multi-layer model is better than a flat model, even if they have the same number of parameters.

I digress.

As long as we choose a model that isn't literally a single flat layer, we should be alright. In the industry, we broadly slice models by task -- convolutional networks for image/video processing, transformers for language processing, graph neural networks for graphs, who-knows-what for audio, etc. Since we are using images, we can use a convolutional network.
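
For illustration, here's what a small convolutional network could look like in PyTorch (one library among several; the layer sizes are arbitrary choices that assume 28x28 grayscale inputs like MNIST):

import torch
from torch import nn

# Convolution layers extract spatial features, pooling shrinks the image,
# and a final flat layer maps the result to a logits vector of size 10.
model = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),  # 28x28 -> 14x14
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),  # 14x14 -> 7x7
    nn.Flatten(),
    nn.Linear(64 * 7 * 7, 10),
)

# Note: PyTorch wants [Batch, Channels, Height, Width], so inputs stored
# as [Batch, Height, Width, Channels] need a transpose first.
logits = model(torch.randn(64, 1, 28, 28))  # fake batch of 64 images
print(logits.shape)  # torch.Size([64, 10])

# The rough capacity heuristic from above: count the parameters (~50k here).
print(sum(p.numel() for p in model.parameters()))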

What about...everything else?

It is a little reductive to say that the above is all you need -- in practice, it's about 95% of what you need. There are two pieces left.

One piece that we're missing is called an optimizer. Optimizers are algorithms that use the gradients of your model's loss to update each of the model's parameters. Generally, you can set a learning rate and a few other parameters that determine how much each step of training impacts the model. A deep dive into how optimizers work isn't really in scope here. Suffice to say, there are a lot of interesting kinds of optimizers out there, and they all do slightly different things, and you pretty much always want to use the Adam optimizer with a learning rate of 0.0001.
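
In PyTorch, for example, wiring up Adam and taking a single training step looks roughly like this (the linear model and random data are throwaway stand-ins, just to make the snippet runnable):

import torch
from torch import nn

model = nn.Linear(784, 10)  # stand-in model, purely for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

data = torch.randn(64, 784)           # fake batch of flattened images
labels = torch.randint(0, 10, (64,))  # fake digit labels

# One training step: compute the loss, backpropagate to get a gradient
# for every parameter, then let the optimizer apply the updates.
logits = model(data)
loss = nn.functional.cross_entropy(logits, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()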

This is Adam. He optimizes things.

The other piece of the puzzle is our batching algorithm. Ideally, we could feed our model every piece of data in the dataset in one go. Unfortunately, for really large datasets, we don't have the computer memory to do that. So we instead sample the dataset into 'batches' and train the model one batch at a time. When we go through the entire dataset, we say that the model has trained for one 'epoch'.

One epoch is many batches.
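
To put numbers on it, here's the back-of-the-envelope arithmetic for our IRS problem (the batch size is just an example value):

# How many batches make up one epoch of IRS filings?
DATASET_SIZE = 150_000_000  # the ~150M filings from above
BATCH_SIZE = 64             # an arbitrary example batch size
steps_per_epoch = DATASET_SIZE // BATCH_SIZE
print(steps_per_epoch)  # 2343750 -- over two million batches per epoch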

A deep dive into how batching works isn't really in scope here either. Suffice to say, there are a lot of batching algorithms, and a lot of people are doing really cool work understanding how sampling strategies impact model training, and you pretty much always want to use random sampling with the largest batch size you can.

Let's code!

import glob
import random

import numpy as np


def get_features(paths):
  """Given image paths, return (data, labels) as numpy arrays."""
  pass

def get_model():
  """Return a model that maps image batches to logits vectors of size 10."""
  pass

def get_loss(logits, labels):
  """Return the softmax cross entropy between logits and labels."""
  pass

def get_optimizer():
  """Return an optimizer, e.g. Adam with a learning rate of 0.0001."""
  pass

def get_accuracy(predictions, labels):
  """Return the fraction of predictions that match the labels."""
  pass

def get_batches(iterable, batch_size=1):
  """Given an iterable, produce batches of size batch_size."""
  for idx in range(0, len(iterable), batch_size):
    yield iterable[idx:min(idx + batch_size, len(iterable))]

if __name__ == '__main__':
  PATHS = glob.glob('path/to/images/*')
  EPOCHS = 10
  BATCH_SIZE = 64
  model = get_model()
  optimizer = get_optimizer()
  for epoch in range(EPOCHS):
    # Randomize the path order for each epoch.
    random.shuffle(PATHS)
    for batch_paths in get_batches(PATHS, BATCH_SIZE):
      data, labels = get_features(batch_paths)
      logits = model(data)
      loss = get_loss(logits, labels)
      optimizer(model, loss)

    print(f'Loss at epoch {epoch} is {loss}')

  INFER_PATHS = glob.glob('path/to/inference/images/*')
  data, labels = get_features(INFER_PATHS)
  logits = model(data)
  # argmax over the logits axis: the highest-scoring class per example.
  predictions = np.argmax(logits, axis=1)
  accuracy = get_accuracy(predictions, labels)
  print('Model accuracy: ', accuracy)

      
Conclusions

Deep learning isn't really software engineering. A software engineer spends time thinking about data structures, and interfaces, and abstractions, and representing all of these things with code. In contrast, the little bit of code that deep learning engineers actually write is more like setting up scaffolding. It's not that hard to write any specific part of a deep learning pipeline; the hard part is making sure your data pipeline does what you want.

This is why people use MNIST as an introduction, and why experts will use MNIST to debug their models when things go wrong. It's pretty easy to set up an MNIST data pipeline because the data is in a format that is very easy to use. All the images are the same size, in the same format, with accurate labels. The entire dataset is small enough that you can actually iterate in a reasonable timeframe. And a simple MNIST solution accurately demonstrates the rough skeleton for tackling deep learning: features, model, loss. Figure out how to turn your data into an embedding that makes sense for your specific problem, and you're good to go.