Although image classification with DL/CNNs (Deep Learning/Convolutional Neural Networks) has become extremely accurate, there are things happening under the hood that make me raise an eyebrow.
One of those "You've got to be kidding" situations is an issue I call "Kitten Stew", and this paper describes the approach I've used to get the CNN to behave in a more intuitively sensible way in this situation.
If you run a CNN for image classification, and you deconvolve the images that are maximally activating the various feature maps, some of the images make intuitive sense... and some of them don't.
The ones that make the most obvious sense are the "early" ones, near the input (Gabor-like filters and variants thereof). That's reasonable: detecting edges and textures is fundamental, and it makes sense for those features to be the building blocks of the more abstract features. But closer to the classification head, the results are... strange. For identifying a kitten (for example), the synthetically constructed images that maximally activate "kitten" simply don't look like kittens. They have a lot of kitten-like textures. They have some easily recognizable kitten eyes scattered around. They have a mixture of kitten noses dotted here and there. They may even have a few kitten heads, feet, and tails sort of randomly stirred in. But they don't look anything like a kitten. They look like kitten stew.
Which means that even though the network has an extremely high bottom-line accuracy, at some deep level the CNN is *not* identifying kittens. In some conceptual sense, the CNN is looking at the image and saying "hmm... we've got noses, and feet, and some eyes, and some tails, and a whole bunch of fur... I betcha this came from a kitten!"
And that feels fundamentally wrong to me.
As it turns out, this issue can have consequences for the bottom-line results as well. Adversarial attacks on CNNs demonstrate this: they make it easy to tweak just about any image so that a human says "That's a dump truck!" (or whatever the image really is), but the CNN says "kitten!", because the features the CNN is keying in on aren't actually, truly the essence of kitten-ness.
Another example of where I've seen this bite classification CNNs is with a roughly hexagonal pattern of dark dots on a light background - if the dots have the right size, the right spacing within rows, and the right spacing between rows, the CNN is absolutely certain it's a dog - look at all those dog eyes and noses!
And that means there's something fundamental about CNNs that we could be doing better.
How Do We Fix It?
We start by thinking about how we teach children. For a picture of a cat, we don't say "Look carefully at the fur. You see, dog fur would be very slightly coarser. Also, look at the edge of the body, where you can see how the cat's fur stands out against the background. There's something subtly different about that from a dog's fur." No. We say "There are the kitty's two ears. And there are its eyes (one, two), and there is its mouth, and there is its tail." And there's something about teaching that way that works.
This illustrates that one thing missing from CNNs is simply... counting. A cat has a nose, two eyes, and a tail. If you can't see some of those things, that's OK, but the identification should be slightly weaker - maybe the cat has its head turned, or is sitting on its tail, or something. But if you can't see any of them, then the identification probably ought to be pretty weak. And if you see too many of them - 3 eyes, for instance - then that's not a cat. Maybe it's two cats. Maybe it's something else. But the CNN should not become super excited about how cat-like the image is, no matter how much cat fur is visible, if the thing in question has 3 eyes.
What Does It Mean to Count Features?
In CNNs, what does "counting features" even mean? Certainly not the number of "neurons" activated in a feature map, which is dependent on input resolution and the aggressiveness of our pooling layers. After thinking about it some, I came up with this:
The "count of features" in a feature map is the count of regions, each made up of contiguous neurons, where the activation of each neuron exceeds some threshold.
For example, if the grid in Example 1 marks with 1s the neurons of a feature map whose activations exceed some threshold, then the count of features is 3, because the grid has 3 contiguous regions of 1s.
If the feature map represents something like puppy feet, then having the information that there are three of them seems like it could be awfully useful.
How Do We Get a CNN to Count Features?
Computer graphics gives us several industry-standard algorithms that do exactly what we want, for instance variants of the "flood fill" algorithm. So counting features isn't hard. However, if we implement these algorithms in our DL network, the results are... horrible. Tragic. Disastrous. That's because all of those algorithms are designed for the old world of iterative, sequential programming without GPUs, and all of them (that I could find) bring parallelized DL networks absolutely to their knees.
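For reference, here is a minimal sketch of that conventional sequential approach: a breadth-first flood fill in Python that visits one cell at a time. It assumes 4-connectivity (neighbors above, below, left, and right), since "contiguous" isn't pinned down further; `count_regions_floodfill` and its `threshold` parameter are illustrative names. The per-cell Python loop is exactly what maps poorly onto GPUs, but the function is handy as ground truth when checking a faster counter.

```python
from collections import deque

import numpy as np


def count_regions_floodfill(A, threshold=0.5):
    """Count 4-connected regions of cells whose activation exceeds threshold.

    Sequential breadth-first flood fill: exact and simple, but it walks the
    map cell by cell instead of using batched tensor operations.
    """
    active = np.asarray(A) > threshold
    H, W = active.shape
    seen = np.zeros((H, W), dtype=bool)
    count = 0
    for r0 in range(H):
        for c0 in range(W):
            if not active[r0, c0] or seen[r0, c0]:
                continue
            count += 1                      # found a new, unvisited region
            seen[r0, c0] = True
            queue = deque([(r0, c0)])
            while queue:                    # flood the whole region
                r, c = queue.popleft()
                for nr, nc in ((r + 1, c), (r - 1, c), (r, c + 1), (r, c - 1)):
                    if 0 <= nr < H and 0 <= nc < W and active[nr, nc] and not seen[nr, nc]:
                        seen[nr, nc] = True
                        queue.append((nr, nc))
    return count


grid = np.array([[1, 1, 0, 0],
                 [0, 0, 0, 1],
                 [1, 0, 0, 1]])
print(count_regions_floodfill(grid))  # 3
```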
How Do We Count Faster?
The key realization on the path to counting features faster is that undercounting is inherent in analyzing 2-d images of the 3-d world, and that it's OK. We frequently see cats that only have one eye visible, and we still recognize that they are cats. We expect the same capability from CNNs doing image classification. So if our algorithm undercounts a feature (there are 4 paws visible but we only count 3, or 2, of them), then it's going to be OK. But we don't want the count to be 0 (it's useful to know that there are paws present!), and we definitely don't want to overcount (things with more than 4 paws aren't cats!), and we'd like the count to be accurate if possible (because accuracy can only help).
Given the constraint that accuracy is desirable but that undercounting is OK, we can create a matrix-centric procedure like this:
1. Start with an HxW activation map A of some feature, with approximate 1s where the activation is above some level and approximate 0s elsewhere. (I say "approximate" because these numbers will be the results of squashing and won't be exactly 1.)
2. Compute C = AB, where B is a column vector of W 1s. This gives a column vector C of H values (the row sums of A).
3. Squash the values in C into 0-1.
4. Pad C with a 0 at each end, then run a 2x1 convolution along C, configured to detect transitions from ~0s to ~1s. (This is just an edge detector.) The result is a column vector with ~1s where we find edges.
5. Sum the column vector from step 4. This is the number of contiguous sequences of rows that contain the feature.
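A minimal NumPy sketch of steps 1-5, assuming A has already been thresholded or squashed so its entries are nonnegative and close to 0 or 1. Clipping to [0, 1] stands in for the squash in step 3, and the 2x1 convolution in step 4 is written as a difference of neighboring values (kernel [-1, +1]) with only the positive, 0-to-1 transitions kept; both are illustrative choices, and `count_row_runs` is a hypothetical name.

```python
import numpy as np


def count_row_runs(A):
    """Count contiguous runs of rows that contain the feature (steps 1-5)."""
    H, W = A.shape
    B = np.ones(W)                                     # step 2: column vector of W ones
    C = A @ B                                          # row sums: H values
    C = np.clip(C, 0.0, 1.0)                           # step 3: squash into 0-1
    padded = np.concatenate(([0.0], C, [0.0]))         # step 4: a 0 at each end
    edges = np.maximum(padded[1:] - padded[:-1], 0.0)  # 2x1 kernel [-1, +1]; keep 0 -> 1 rises
    return float(edges.sum())                          # step 5: number of runs


A = np.array([[1, 1, 0, 0],
              [0, 1, 0, 0],
              [0, 0, 0, 0],
              [0, 0, 1, 1]], dtype=float)
print(count_row_runs(A))  # 2.0: rows 0-1 form one run, row 3 another
```

Every operation here is a matrix product, an elementwise op, or a shift, so the whole thing batches naturally across feature maps.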
For Example 1 above, this procedure gives the result "2", because the sequence of rows 1-3 contains some 1s, and the sequence of rows 5-6 contains some 1s.
Well, this "2" gives us a bit of insight into the count of features, but a) it's wrong (the actual feature count is 3) and b) it's easy to devise input where the feature count from this procedure would be *badly* wrong. Like this:
Actual count of features: 6
Feature count from our procedure: 1
But instead of throwing this approach out altogether, let's first observe that the procedure we used is very fast to compute for tensor-centric computational frameworks: the entire procedure is expressed in the sorts of ways that DL frameworks like to behave.
So let's see if there's a way to build on this.
There is, if we use a bit of metaphor. Imagine the grid is a table top. The 0s are the flat table top, and the regions of 1s are objects sitting on the table. Imagine we crouch down at one end of the table, so that we're looking down its length. There's a light at the other end, so we can only see the objects sticking up in profile. In Example 1, we count 2 objects sticking up, even though there are really 3.
Intuitively, we know that this means there are *at least* two objects on the table, but there could be more. How do we find out? By changing our vantage point: moving around the edge of the table, still level with the table top, to see if there's an angle from which we can see a higher count of objects - i.e. a location from which we can see more gaps in the silhouette.
Implementing a second perspective is straightforward:
- Transpose the matrix.
- Run the procedure from above again.
- Now we have a count of the sequences of *columns* that contain 1s, instead of a count of the sequences of rows. This is the equivalent of looking at the table from the side, instead of from the head or foot.
For Example 1, the result from this procedure is still wrong (2 instead of 3), but for Example 2, it takes us from badly wrong (1 instead of 6) to nearly right (we count 5 objects).
This shows that we can increase our accuracy by taking the maximum of the results of the two procedures (one "looking" from the head of the table and one "looking" from the side), which intuitively makes sense from our metaphor: whichever vantage point shows us the largest number of breaks between objects gets us closest to the number of objects.
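Continuing the NumPy sketch under the same assumptions: the "side of the table" view is just the row procedure applied to the transpose, and the estimate is the larger of the two counts. The sample map is made up (six isolated dots, not Example 2) and shows the count improving from 2 to 3 while still undercounting.

```python
import numpy as np


def count_row_runs(M):
    """Steps 1-5: runs of consecutive rows of M that contain the feature."""
    C = np.clip(M.sum(axis=1), 0.0, 1.0)                   # row sums, squashed
    padded = np.concatenate(([0.0], C, [0.0]))              # a 0 at each end
    return float(np.maximum(np.diff(padded), 0.0).sum())   # count 0 -> 1 rises


def count_two_views(A):
    """Better of the 'head of the table' and 'side of the table' views."""
    return max(count_row_runs(A), count_row_runs(A.T))


dots = np.array([[1, 0, 1, 0, 1],
                 [0, 0, 0, 0, 0],
                 [1, 0, 1, 0, 1]], dtype=float)
print(count_row_runs(dots), count_row_runs(dots.T), count_two_views(dots))  # 2.0 3.0 3.0
```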
From 2 Perspectives to 4 (Slightly More Complicated)
Is this the best we can do? The answer is obviously no - there should be a way to look at the "table top" from other angles as well. The question is how to do it performantly, and that turns out to be easy.
One thing that all DL/tensor-based frameworks support is various ways to transform matrices. Some of them directly support the following re-arrangement and some don't, but any framework that lets us implement new low-level re-arrangement operations on matrices performantly can be made to support it.
Instead of transposing the matrix, let's skew it, so that it descends as it goes to the right:
Original matrix from Example 1:
Matrix from Example 1, skewed so that it descends one row per column, from left to right:
For readability, this shows dots for the padding we need to add, but in actual practice we would use zeros.
If we now apply the original procedure to this skewed matrix, we get the result 3, for three contiguous sequences of rows that contain 1s (rows 4-10, 14-18, and 20-22). And that's the right answer! There are three contiguous regions of ones in the original activation map, corresponding to instances of the target feature.
Metaphorically, this is equivalent to viewing the "table" from the top right corner.
Obviously, by slanting the matrix *up* as it goes to the right, we could correspondingly count features from the "vantage point" of the top-left corner.
By combining these ways of looking at the matrix, we can effectively view it from four different vantage points (top to bottom, side to side, and along both diagonals), thereby having a decent chance of accurately counting the actual number of features (by taking the maximum of the four results).
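Extending the same NumPy sketch, the skew can be done with a single advanced-indexing assignment (a scatter), the kind of re-arrangement tensor frameworks handle well, with zeros as the padding. `skew` and `count_four_views` are illustrative names. The sample map is made up: its row and column views each see a single object, while the descending skew separates its two regions.

```python
import numpy as np


def count_row_runs(M):
    """Steps 1-5: runs of consecutive rows of M that contain the feature."""
    C = np.clip(M.sum(axis=1), 0.0, 1.0)
    padded = np.concatenate(([0.0], C, [0.0]))
    return float(np.maximum(np.diff(padded), 0.0).sum())


def skew(A, descending=True):
    """Shift column c down by c rows (descending to the right) or by W-1-c rows
    (ascending to the right), zero-padding the result to H + W - 1 rows."""
    H, W = A.shape
    r = np.arange(H)[:, None]
    c = np.arange(W)[None, :]
    rows = r + c if descending else r + (W - 1 - c)
    out = np.zeros((H + W - 1, W), dtype=A.dtype)
    out[rows, c] = A                    # one scatter, no Python loop over cells
    return out


def count_four_views(A):
    """Max over head-of-table, side, and both diagonal vantage points."""
    views = (A, A.T, skew(A, descending=True), skew(A, descending=False))
    return max(count_row_runs(v) for v in views)


A = np.array([[1, 1, 0],
              [1, 0, 0],
              [0, 0, 1]], dtype=float)
print(count_row_runs(A), count_row_runs(A.T), count_four_views(A))  # 1.0 1.0 2.0
```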
So... is that enough?
Well, it depends. This approach has two notable properties:
- It may undercount, but it never returns zero if a feature is present.
- It never overcounts the features.
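Assuming 4-connectivity, both properties follow from geometry: a 4-connected region projects onto an unbroken run of rows, columns, or diagonals, so each region adds at most one run to any view, and any active cell makes at least one projected value nonzero. (With 8-connectivity, two diagonally touching cells would form one region but could be split by a diagonal view.) The sketch below checks this empirically against a flood-fill count on random binary maps; the sizes, densities, trial count, and function names are arbitrary, and this is a sanity check rather than a proof. The diagonal views use `np.bincount`, which yields the row sums of the skewed matrix directly.

```python
from collections import deque

import numpy as np


def true_count(active):
    """Flood-fill count of 4-connected regions of True cells (ground truth)."""
    H, W = active.shape
    seen = np.zeros((H, W), dtype=bool)
    n = 0
    for r0 in range(H):
        for c0 in range(W):
            if active[r0, c0] and not seen[r0, c0]:
                n += 1
                seen[r0, c0] = True
                queue = deque([(r0, c0)])
                while queue:
                    r, c = queue.popleft()
                    for nr, nc in ((r + 1, c), (r - 1, c), (r, c + 1), (r, c - 1)):
                        if 0 <= nr < H and 0 <= nc < W and active[nr, nc] and not seen[nr, nc]:
                            seen[nr, nc] = True
                            queue.append((nr, nc))
    return n


def four_view_count(A):
    """Four-perspective estimate: rows, columns, and both diagonals."""
    H, W = A.shape
    r, c = np.indices((H, W))
    projections = (
        A.sum(axis=1),                                                                # head of table
        A.sum(axis=0),                                                                # side
        np.bincount((r + c).ravel(), weights=A.ravel(), minlength=H + W - 1),         # descending skew
        np.bincount((r - c + W - 1).ravel(), weights=A.ravel(), minlength=H + W - 1), # ascending skew
    )
    best = 0.0
    for p in projections:
        padded = np.concatenate(([0.0], np.clip(p, 0.0, 1.0), [0.0]))
        best = max(best, float(np.maximum(np.diff(padded), 0.0).sum()))
    return best


def count_violations(trials=2000, seed=0):
    """Random maps where the estimate overcounts or returns 0 for a present feature."""
    rng = np.random.default_rng(seed)
    violations = 0
    for _ in range(trials):
        H, W = (int(n) for n in rng.integers(1, 9, size=2))
        active = rng.random((H, W)) < rng.uniform(0.1, 0.6)
        estimate = four_view_count(active.astype(float))
        truth = true_count(active)
        if estimate > truth or (truth > 0 and estimate < 1):
            violations += 1
    return violations


print(count_violations())  # 0
```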
Since the DL network needs to be able to handle some level of undercounting anyway (because hidden features are unavoidable in 2-d pictures of 3-d objects), the question of whether four perspectives are good enough is a practical matter of balancing the cost of computing additional perspectives against the benefit of slightly higher accuracy, rather than something with a correct answer. Personally, I've found that 4 perspectives work fine, but there could be applications where the increase in accuracy is worth the extra computation.
From 4 Perspectives to 8 (More Complicated Again)
If we decide we want 8 perspectives instead of 4 (I haven't found a situation that really requires that), we would construct 4 additional matrices: one which descends at a rate of 2 rows per column, one which descends at a rate of 1 row per two columns, one which *ascends* at a rate of 2 rows per column, and one which ascends at a rate of 1 row per two columns.
Here's a (contrived) example of a feature activation map where this "8 perspective" approach results in improved accuracy:
Actual feature count: 2
Computed feature count, from any of the 4 "main" perspectives (top, side, and diagonals): 1
Here's the activation map re-skewed at "2 rows down per 1 column over":
Actual feature count: 2
Computed feature count: 2
Which means the "8 perspectives" approach can in fact increase accuracy in some situations.
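One way to sketch all of these slants with a single helper is to shift each column c down by a per-column offset: `c` gives the one-row-per-column diagonal, `2 * c` gives "2 rows down per 1 column over", and `c // 2` gives "1 row per two columns", while negated offsets ascend instead. Reading "1 row per two columns" as a floor division is an assumption, and `skew_by_offsets` is a hypothetical name. The matrix is Example 3 as it appears in the offset-and-add illustration in this section (rows 110, 100, 101, 101).

```python
import numpy as np


def count_row_runs(M):
    """Steps 1-5: runs of consecutive rows of M that contain the feature."""
    C = np.clip(M.sum(axis=1), 0.0, 1.0)
    padded = np.concatenate(([0.0], C, [0.0]))
    return float(np.maximum(np.diff(padded), 0.0).sum())


def skew_by_offsets(A, offsets):
    """Shift column c of A down by offsets[c] rows, zero-padding as needed."""
    H, W = A.shape
    offsets = np.asarray(offsets, dtype=int)
    offsets = offsets - offsets.min()                  # make every shift nonnegative
    out = np.zeros((H + int(offsets.max()), W), dtype=A.dtype)
    out[np.arange(H)[:, None] + offsets[None, :], np.arange(W)[None, :]] = A
    return out


ex3 = np.array([[1, 1, 0],
                [1, 0, 0],
                [1, 0, 1],
                [1, 0, 1]], dtype=float)
c = np.arange(ex3.shape[1])
views = {
    "rows": ex3,
    "columns": ex3.T,
    "diagonal down": skew_by_offsets(ex3, c),
    "diagonal up": skew_by_offsets(ex3, -c),
    "2 down per column": skew_by_offsets(ex3, 2 * c),
    "1 down per 2 columns": skew_by_offsets(ex3, c // 2),
    "2 up per column": skew_by_offsets(ex3, -2 * c),
    "1 up per 2 columns": skew_by_offsets(ex3, -(c // 2)),
}
# Only "2 down per column" reports 2.0; every other view reports 1.0.
for name, view in views.items():
    print(f"{name}: {count_row_runs(view)}")
```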
However, there's a danger with this, as seen below.
Re-skewed with a 2-to-1 slant:
Actual number of features in the activation map: 1
Computed feature count: 3
Which is not only wrong, but badly and misleadingly wrong. While undercounting is undesirable, overcounting is disastrous: if our DL network has to be tolerant of overcounts, we've lost our edge in distinguishing "kitten stew" from a kitten.
So we need to fix this if we're going to use the "8 perspective" approach. The solution I've developed, purely through experimentation, is to offset the activation map by one row and add it to itself before creating the 2-to-1 up and 2-to-1 down matrices. Here's what that looks like for Example 3:
original   shifted   sum
110        ...       110
100        110       210
101   +    100   =   201
101        101       202
...        101       101
When skewed 2-to-1, this still gives the correct feature count ("2").
And more importantly, it fixes the feature count for Example 4 as well. This "offset and add" effectively "blurs" the matrix slightly in the vertical direction. There are situations where this can decrease the accuracy instead of improving it - but when it decreases accuracy, it does so in the direction of undercounts rather than overcounts.
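A sketch of the offset-and-add step plus the resulting eight-perspective count, reusing the per-column-offset skew idea (again an illustrative helper), with the blur applied only before the two 2-to-1 skews as described above. The 1x3 horizontal bar is a made-up stand-in for the kind of map in Example 4: skewed 2-to-1 without the blur, its three cells land on every other row and count as 3; with the blur the count drops back to 1.

```python
import numpy as np


def count_row_runs(M):
    """Steps 1-5: runs of consecutive rows of M that contain the feature."""
    C = np.clip(M.sum(axis=1), 0.0, 1.0)
    padded = np.concatenate(([0.0], C, [0.0]))
    return float(np.maximum(np.diff(padded), 0.0).sum())


def skew_by_offsets(A, offsets):
    """Shift column c of A down by offsets[c] rows, zero-padding as needed."""
    H, W = A.shape
    offsets = np.asarray(offsets, dtype=int)
    offsets = offsets - offsets.min()
    out = np.zeros((H + int(offsets.max()), W), dtype=A.dtype)
    out[np.arange(H)[:, None] + offsets[None, :], np.arange(W)[None, :]] = A
    return out


def offset_and_add(A):
    """Add A to a copy of itself shifted down one row (a slight vertical blur)."""
    H, W = A.shape
    out = np.zeros((H + 1, W), dtype=A.dtype)
    out[:-1] += A        # original, padded with a zero row below
    out[1:] += A         # copy shifted down one row, padded above
    return out


def count_eight_views(A):
    """Max over eight vantage points; only the 2-to-1 skews use the blurred map."""
    c = np.arange(A.shape[1])
    blurred = offset_and_add(A)
    views = (
        A, A.T,
        skew_by_offsets(A, c), skew_by_offsets(A, -c),
        skew_by_offsets(A, c // 2), skew_by_offsets(A, -(c // 2)),
        skew_by_offsets(blurred, 2 * c), skew_by_offsets(blurred, -2 * c),
    )
    return max(count_row_runs(v) for v in views)


ex3 = np.array([[1, 1, 0], [1, 0, 0], [1, 0, 1], [1, 0, 1]], dtype=float)
bar = np.array([[1, 1, 1]], dtype=float)     # hypothetical single-region map
print(offset_and_add(ex3).astype(int))       # rows 1 1 0 / 2 1 0 / 2 0 1 / 2 0 2 / 1 0 1
print(count_row_runs(skew_by_offsets(bar, 2 * np.arange(3))))  # 3.0: overcount without the blur
print(count_eight_views(ex3), count_eight_views(bar))           # 2.0 1.0
```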
...And More Perspectives Can't Guarantee Accuracy Anyway
Actual feature count: 2
Count when viewed from any angle: 1
So there is no number of perspectives which will guarantee accuracy. And that's OK. Enough said.
How Do We Use This in a DL Network?
So far I've focused on creating a matrix-centric, performant way to count features, with the rather vague belief that the resulting information will be useful in classifying images (although I expect it can be useful in regression, too).
But how do we actually wire this capability into a network so that it does something?
The first thing to observe is that this process results in just one additional output from each feature map - because it is summary information, it has no spatial extent at all. As such, it is unsuitable for getting piped into the "next" layer of the CNN in any obvious way.
The thing that *is* obvious is feeding it directly into the final FC or max pool portions of the network as a fundamental input: "there are two cat eyes in the image" is mighty useful information for classification.
This approach works, but it turns out to require additional FC layers to work well. That's because the number "3 puppy paws" isn't really what we want for classification; what we ideally want is more like this:
- 0 puppy feet: I guess it could still be a puppy but this isn't evidence for it.
- 1 puppy foot: maybe that's a puppy.
- 2 puppy feet: decent evidence for a puppy
- 3-4 puppy feet: solid evidence for a puppy
- 5 puppy feet: uh, maybe we counted wrong, but "5 feet" isn't evidence of a puppy.
- 6 puppy feet: those aren't actually puppy feet and that isn't a puppy.
A couple of non-linearity layers applied to our count can easily produce this shape (especially with PReLUs), but if we're doing that learning in our FC layers, that means adding more FC layers to get the results into a useful form for classification. And simply put, that's folly: FC layers are incredibly expensive in terms of parameters trained, and adding more of them to learn the desired shape of the histograms (because a "histogram" is what I've given above) is nuts.
So instead of piping the feature count directly into the left-most FC layer, we pipe it into a "histogram neuron" which can learn the shape of the most useful histogram, and then pipe *that* result into the FC layer.
My design for a histogram neuron is trivial:
weight and bias, PReLU, weight and bias, PReLU. That's it.
This sequence effectively gives us four line segments to play with in shaping a curve, which is adequate for the histograms that most image classifications require. (In fact, I suspect we could get away with learning just a peak and two slopes, but I haven't tested that).
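A minimal NumPy sketch of the histogram neuron's forward pass, with scalar weights and PReLU slopes. The parameter values are hand-picked rather than learned, chosen so the output forms a bump peaking at 3-4 counts, roughly like the puppy-feet list above: a negative first PReLU slope (`a1 = -1`) folds the line into |count - 3.5|, and the second stage turns that into max(0, 2 - |count - 3.5|). In PyTorch the same structure would be `nn.Linear(1, 1)`, `nn.PReLU()`, `nn.Linear(1, 1)`, `nn.PReLU()`.

```python
import numpy as np


def prelu(z, a):
    """PReLU: z where z > 0, a * z otherwise (a is a learned slope)."""
    return np.where(z > 0, z, a * z)


def histogram_neuron(count, w1, b1, a1, w2, b2, a2):
    """weight and bias, PReLU, weight and bias, PReLU."""
    return prelu(w2 * prelu(w1 * count + b1, a1) + b2, a2)


# Hand-picked illustrative parameters (in a network these would be learned).
params = dict(w1=1.0, b1=-3.5, a1=-1.0, w2=-1.0, b2=2.0, a2=0.0)
counts = np.arange(7, dtype=float)          # 0 through 6 "puppy feet"
print(histogram_neuron(counts, **params))   # bump: 0 for 0-1, peak at 3-4, 0 at 6
```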
So to re-cap:
- Our feature maps each get a feature count computed "beside" them, as an additional characteristic of the feature map.
- Those counts then pass into histogram neurons that learn the desired distribution of counts.
- The activation level of each histogram neuron then becomes an input to the final classification (or regression) levels of the network (or to some other portion of the network that does not express spatial extent).
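Here is one way the recap might look as a forward pass, sketched in NumPy under several assumptions: the feature maps are already thresholded, each map gets its own histogram neuron, the histogram outputs are concatenated with a global max pool of the same maps (standing in for the usual pooled features), and a single linear layer produces class scores. All names and shapes are hypothetical and the weights are random, not trained.

```python
import numpy as np


def four_view_count(A):
    """Four-perspective feature count; the diagonal views use np.bincount,
    which gives the row sums of the skewed matrix without materializing it."""
    H, W = A.shape
    r, c = np.indices((H, W))
    projections = (
        A.sum(axis=1),
        A.sum(axis=0),
        np.bincount((r + c).ravel(), weights=A.ravel(), minlength=H + W - 1),
        np.bincount((r - c + W - 1).ravel(), weights=A.ravel(), minlength=H + W - 1),
    )
    best = 0.0
    for p in projections:
        padded = np.concatenate(([0.0], np.clip(p, 0.0, 1.0), [0.0]))
        best = max(best, float(np.maximum(np.diff(padded), 0.0).sum()))
    return best


def prelu(z, a):
    return np.where(z > 0, z, a * z)


def classifier_head(maps, hist_params, fc_weight, fc_bias):
    """maps: (K, H, W) thresholded feature maps.
    hist_params: (K, 6) rows of (w1, b1, a1, w2, b2, a2), one histogram neuron per map.
    fc_weight: (num_classes, 2K); fc_bias: (num_classes,)."""
    counts = np.array([four_view_count(m) for m in maps])      # one scalar per map
    w1, b1, a1, w2, b2, a2 = hist_params.T                     # each of shape (K,)
    hist = prelu(w2 * prelu(w1 * counts + b1, a1) + b2, a2)    # histogram neurons
    pooled = maps.max(axis=(1, 2))                             # assumed global max pool
    features = np.concatenate([pooled, hist])                  # counts sit "beside" the maps
    return fc_weight @ features + fc_bias                      # class scores


rng = np.random.default_rng(0)
K, num_classes = 4, 3
maps = (rng.random((K, 8, 8)) > 0.8).astype(float)
hist_params = np.tile([1.0, -3.5, -1.0, -1.0, 2.0, 0.0], (K, 1))
logits = classifier_head(maps, hist_params,
                         rng.normal(size=(num_classes, 2 * K)), np.zeros(num_classes))
print(logits.shape)  # (3,)
```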
What About Backprop and Learning?
It's obvious that we can and should backpropagate through the "histogram neurons" to learn the desired shape of the histograms.
Backpropagating further back brings us to the counting procedures themselves. The procedures are expressed in terms of matrix operations and re-arrangements, so it is possible to backprop through all of them. There are three obvious ways I came up with to handle this, based on how we want to establish the cut-off levels for our counting (i.e. we need to establish what activation levels to treat as 1-ish vs. 0-ish). Here they are:
1. We don't backprop through counting at all. Each feature map in the network is going to be followed by a non-linearity layer of some sort anyway (that's the usual pattern), and the parameter(s) of that non-linear layer are going to be learned through the usual channels. We could just attach our counting processes to the activation maps resulting from those existing non-linearity layers, and use whatever thresholds they learn.
2. We backprop through each of our counting processes, and use that backprop as an input into the learning of the associated non-linearity layer, the same as the backprop from the rest of the network is an input into the learning of the non-linearity layer. (This is the obvious approach.)
3. We create a separate non-linearity layer, parallel to the usual non-linearity layer, for each feature map, and we attach our counting processes to these new activation maps. This allows the thresholds used in counting to be learned independently from the thresholds used in the rest of the network. And, of course, we backprop through the whole thing.
#1 has the advantage of not requiring us to write backprop code for those ugly matrix skews, but isn't as effective for obvious reasons. #2 is the cleanest, and #3 is the most flexible. Like most things in DL network design, figuring out which is best in a particular application is going to require trial and error.
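A minimal sketch of the idea behind #3, assuming a steep sigmoid with its own threshold `t` as the parallel counting nonlinearity; the sigmoid form, its steepness `k`, and the function names are assumptions for illustration. Only the row perspective is shown, and the dependence of the count on `t` is estimated by central finite differences in NumPy rather than by an actual backward pass; in an autodiff framework the same ops (sigmoid, sums, clip, difference, ReLU) are differentiable almost everywhere.

```python
import numpy as np


def soft_gate(x, t, k=20.0):
    """Hypothetical counting-only nonlinearity: a steep sigmoid whose threshold t
    would be learned separately from the network's usual nonlinearity."""
    return 1.0 / (1.0 + np.exp(-k * (x - t)))


def soft_row_count(A, t, k=20.0):
    """Row-projection count (steps 2-5) applied to the gated map."""
    C = np.clip(soft_gate(A, t, k).sum(axis=1), 0.0, 1.0)
    padded = np.concatenate(([0.0], C, [0.0]))
    return float(np.maximum(np.diff(padded), 0.0).sum())


A = np.array([[0.9, 0.0],
              [0.0, 0.0],
              [0.3, 0.0]])          # raw activations: one strong row, one weak row
for t in (0.5, 0.2):
    print(t, soft_row_count(A, t))  # lowering the threshold lets the weak row count too

# The count changes smoothly with t here, so there is a usable gradient.
eps = 1e-4
slope = (soft_row_count(A, 0.35 + eps) - soft_row_count(A, 0.35 - eps)) / (2 * eps)
print(slope)                        # negative: raising the threshold lowers the count
```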
But any way you do it, you get something a lot less likely to say "Doggie!" when it sees a particular array of black dots.