Robust Neural Networks: How to Prevent NaN

Presented here is a method for preventing NaNs from occurring during neural network training. NaNs can be fatal to training: if allowed to propagate, they overwrite every weight in the network, destroying all prior training. Even if NaNs are detected and rolled back, the result is a fragile network that is difficult, if not impossible, to train further. How can we overcome this problem?

NaN Dragon

We first need to understand why some calculations result in NaN. GPU hardware is designed to trade precision for performance. CPUs can perform such calculations with greater precision, but they are much slower and generally not up to the task of training machine learning models. On my test GPU, tanh(𝑥) returns valid results only for inputs in approximately (-3.4, 3.4); outside this input range, it returns NaN. That is a rather narrow range for a function defined from negative infinity to positive infinity. Even when our network’s inputs have a mean of 0 and a standard deviation of 1, it doesn’t take many of them in a weighted sum to overwhelm tanh(𝑥).
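
To make “it doesn’t take many inputs” concrete: if a node’s inputs are independent with mean 0 and standard deviation 1, its weighted sum Σwᵢxᵢ has standard deviation √(Σwᵢ²). The Python sketch below assumes unit weights and treats the sum as normally distributed; TANH_LIMIT is a hypothetical name holding the ±3.4 figure measured above, which will differ on other hardware.

```python
import math

def preactivation_std(weights):
    """Std. dev. of sum(w_i * x_i) when the x_i are independent with mean 0 and std 1."""
    return math.sqrt(sum(w * w for w in weights))

def fraction_outside(limit, std):
    """P(|z| > limit) for a zero-mean normal z with the given standard deviation."""
    return math.erfc(limit / (std * math.sqrt(2.0)))

# Hypothetical constant: the author's measured limit on one GPU; other hardware will differ.
TANH_LIMIT = 3.4

for n in (1, 4, 16, 64):
    std = preactivation_std([1.0] * n)  # n inputs, all weights equal to 1
    print(n, std, round(fraction_outside(TANH_LIMIT, std), 4))
```

With 16 unit-weight inputs the standard deviation of the weighted sum is already 4.0, and roughly 40% of weighted sums would land outside ±3.4.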

To prevent tanh(𝑥) from producing NaN, we need some way to keep the weighted-sum input to every tanh(𝑥) activation in the network within this range.

Reasoning About Error Signal

The error budget is the amount of signal that may be applied to the model to train it. In typical scenarios, 100% of the available error is used to drive the model towards the expected output. For example, when you train your network to output 1 if an image contains a dog, and -1 if it does not, you are using 100% of your error budget to train that objective. The training conditions of the network are:

  • expected 1, adjust weights to drive output higher
  • expected -1, adjust weights to drive output lower

Remember that at a given output node:

errorDelta = (expectedOutput - actualOutput) * TanhDerivative(actualOutput)

Also, to update a given weight:

weights[i, j] += learningRate * errorDeltas[j] * nodeValues[i];
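
A minimal Python sketch of these two rules for a single tanh output node. The helper names are hypothetical, and it assumes TanhDerivative is evaluated on the node’s tanh output y, so the derivative is 1 - y².

```python
import math

def tanh_derivative_from_output(y):
    """d/dz tanh(z), expressed in terms of the output y = tanh(z) (assumed convention)."""
    return 1.0 - y * y

def output_error_delta(expected, actual):
    """errorDelta = (expectedOutput - actualOutput) * TanhDerivative(actualOutput)."""
    return (expected - actual) * tanh_derivative_from_output(actual)

def update_weights(weights, node_values, error_deltas, learning_rate):
    """weights[i][j] += learningRate * errorDeltas[j] * nodeValues[i], in place."""
    for i, x in enumerate(node_values):
        for j, delta in enumerate(error_deltas):
            weights[i][j] += learning_rate * delta * x

# One output node (j = 0) fed by two left-layer nodes (i = 0, 1).
weights = [[0.5], [-0.25]]
inputs = [1.0, 2.0]
z = sum(inputs[i] * weights[i][0] for i in range(2))  # 0.5 - 0.5 = 0.0
delta = output_error_delta(1.0, math.tanh(z))          # (1 - 0) * (1 - 0) = 1.0
update_weights(weights, inputs, [delta], learning_rate=0.5)
print(weights)  # [[1.0], [0.75]]
```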

An example with additional conditions is label smoothing. We introduce the concept of overconfidence by stating that the network should output some value less than 1 when it sees a dog, and some value greater than -1 when it does not. For example, we might define 0.9 as a positive result and -0.9 as a negative result. If the network actually outputs 1 or -1, it is overconfident. In this case, the calculated error will cause the network to adjust the output towards the expected range (-0.9, 0.9). We therefore have additional conditions:

  • expected 0.9, actual result is too high, drive output lower
  • expected 0.9, actual result is too low, drive output higher
  • expected -0.9, actual result is too high, drive output lower
  • expected -0.9, actual result is too low, drive output higher
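
A short Python sketch checking that the sign of errorDelta alone encodes all of these conditions. It uses the output-node rule from above, again assuming TanhDerivative(y) = 1 - y² for a tanh output y; the function names are hypothetical.

```python
import math

def output_error_delta(expected, actual):
    """errorDelta = (expectedOutput - actualOutput) * TanhDerivative(actualOutput),
    assuming TanhDerivative(y) = 1 - y*y for a tanh output y."""
    return (expected - actual) * (1.0 - actual * actual)

def direction(expected, actual):
    """Which way the next weight update pushes this output node."""
    delta = output_error_delta(expected, actual)
    if delta > 0.0:
        return "higher"
    if delta < 0.0:
        return "lower"
    return "none"

# Targets of +/-1: tanh outputs lie strictly inside (-1, 1), so they can never overshoot.
outputs = [math.tanh(z) for z in (-3.0, -1.0, 0.0, 1.0, 3.0)]
print([direction(1.0, y) for y in outputs])   # always "higher"
print([direction(-1.0, y) for y in outputs])  # always "lower"

# Smoothed targets of +/-0.9, in the same order as the four conditions above.
for expected, actual in [(0.9, 0.99), (0.9, 0.5), (-0.9, -0.5), (-0.9, -0.99)]:
    print(expected, actual, direction(expected, actual))
```

With targets of ±1 there are only two conditions because overshoot is impossible; label smoothing adds the "too high" and "too low" cases without any change to the rule itself.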

Understanding this, we can reason about the effect of the error signal at any node during backpropagation. For example, if the input to tanh(𝑥) is near the upper limit and errorDelta is positive, we know that the next weight update will drive the input to tanh(𝑥) even higher, possibly causing NaN. Conversely, if the input to tanh(𝑥) is near the lower limit and errorDelta is negative, the next weight update might drive it past that limit and cause NaN. Because we know this, we can attenuate errorDelta (reduce or even negate it) so that the input to tanh(𝑥) is kept within the desired range. We’re using part of our error budget to establish a boundary for each tanh(𝑥) activation in the network, which effectively prevents NaN from occurring.

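[numthreads(GROUPSIZE, 1, 1)]
void StoreAccumulatedTanhConstrainedErrors(uint3 dtid : SV_DispatchThreadID)
{
    // One thread per node in the destination (leftward) layer.
    uint leftNodeIndex = dtid.x + destinationLayerStartIndex;
    if (dtid.x < destinationLayerCount)
    {
        float leftwardNodeValue = nodeValues[ValuesIndexOfLeftLayer(leftNodeIndex)];
        float leftwardNodeDerivative = nodeDerivatives[ValuesIndexOfLeftLayer(leftNodeIndex)];
        // Backpropagated error for this node: accumulated error times the activation derivative.
        float errorDelta = accumulations[AccumulationsIndexBetween(leftNodeIndex, sourceLayerStartIndex)] * leftwardNodeDerivative;
        // Tanh constraint: if the node is already above +threshold and the error would push it
        // higher, or below -threshold and the error would push it lower, reverse the error so
        // the next weight update pulls it back toward the safe range.
        if ((leftwardNodeValue > TANH_CONSTRAINT_THRESHOLD && errorDelta > 0.0f) ||
            (leftwardNodeValue < -TANH_CONSTRAINT_THRESHOLD && errorDelta < 0.0f))
            errorDelta *= -1.0f;
        errorDeltas[ValuesIndexOfLeftLayer(leftNodeIndex)] = errorDelta;
    }
}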
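
For testing on a CPU, the per-thread logic of StoreAccumulatedTanhConstrainedErrors can be ported to Python. This is a sketch, not the author’s implementation: the index helpers (ValuesIndexOfLeftLayer, AccumulationsIndexBetween) are dropped in favor of plain per-node lists, TANH_CONSTRAINT_THRESHOLD is not given in the post so the 3.0 below is a placeholder, and the stored node value is assumed to be the input to tanh(𝑥), as the prose describes. The second half checks why the sign of errorDelta matters: with wᵢ += learningRate * errorDelta * xᵢ and the inputs xᵢ held fixed, the node’s weighted sum changes by learningRate * errorDelta * Σxᵢ², which always has the sign of errorDelta. The numbers are chosen for easy arithmetic, not as realistic tanh derivatives.

```python
def constrain_tanh_error_deltas(node_values, accumulated_errors, node_derivatives, threshold):
    """Per-node port of StoreAccumulatedTanhConstrainedErrors (index helpers omitted)."""
    deltas = []
    for value, accumulated, derivative in zip(node_values, accumulated_errors, node_derivatives):
        delta = accumulated * derivative
        # Flip the error if it would push a node that is already past a threshold further out.
        if (value > threshold and delta > 0.0) or (value < -threshold and delta < 0.0):
            delta = -delta
        deltas.append(delta)
    return deltas

def weighted_sum(weights, inputs):
    return sum(w * x for w, x in zip(weights, inputs))

def apply_update(weights, inputs, error_delta, learning_rate):
    """w_i += learningRate * errorDelta * x_i for every incoming weight of one node."""
    return [w + learning_rate * error_delta * x for w, x in zip(weights, inputs)]

THRESHOLD = 3.0  # placeholder: the post does not state TANH_CONSTRAINT_THRESHOLD

# A node whose weighted sum (1 + 1 + 1 + 0.25 = 3.25) is already above the threshold.
weights = [1.0, 1.0, 1.0, 0.25]
inputs = [1.0, 1.0, 1.0, 1.0]
z = weighted_sum(weights, inputs)

raw = 1.0 * 0.5  # accumulated error * derivative: positive, so it would push z higher
constrained = constrain_tanh_error_deltas([z], [1.0], [0.5], THRESHOLD)[0]  # flipped to -0.5

z_unconstrained = weighted_sum(apply_update(weights, inputs, raw, 0.5), inputs)
z_constrained = weighted_sum(apply_update(weights, inputs, constrained, 0.5), inputs)
print(z, z_unconstrained, z_constrained)  # 3.25 4.25 2.25
```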

What is the effect on training? We are spending some portion of our error budget to keep the network from exploding, so training is affected only to the extent that this signal would otherwise have been used to further refine the network’s weights. However, any diminished capacity for training isn’t the result of preventing the NaN; it is simply the intrinsic limit of the model as it was originally designed. The tanh(𝑥) constraint takes a model that would have exploded and prevents that from happening. This added robustness is an unalloyed improvement in the model’s capability. After all, models that explode are of no value at all.

Bob Burrough
October 17, 2021