Implicit neural networks
Neural networks consist of a sequence of operations that are typically defined explicitly. An explicit operation is one whose output is computed directly by evaluating a prescribed sequence of operations on the input. A simple example is a feed-forward MLP, where the transition from one layer to the next is given by the following sequence of operations
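In generic notation (\(z_i\) denotes the activations of layer \(i\), \(W_i\) and \(b_i\) its weight matrix and bias, and \(\sigma\) an elementwise activation function; the symbols are chosen here purely for illustration):

$$ z_{i+1} = \sigma(W_i z_i + b_i), \qquad z_0 = x. $$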
Additionally, one could add operations like batch normalization and max pooling, all of which are given explicitly. Alternatively, two variables can be related via an implicit equation. A simple example of an explicit function versus an implicit function is \(y = x^2\) versus \(x^2 + y^2 = 1\). The second example also shows why implicit descriptions are sometimes favorable: the single implicit equation \(x^2 + y^2 = 1\) has an explicit counterpart that requires two equations, namely \(y = \sqrt{1 - x^2}\) and \(y = -\sqrt{1 - x^2}\). In a more abstract fashion, we can write an explicit equation as \(y = f(x)\) and an implicit equation as \(f(x, y) = 0\).
Neural networks can also be defined implicitly, through the concept of implicit layers; networks of this kind were introduced under the name Deep Equilibrium Models (DEQs). The idea is somewhat abstract, but an appealing property of this paradigm is that the memory required for training is constant, independent of the (effective) depth of the network. To understand it we need to cover two fundamental ingredients:
- Implicit functions and the implicit function theorem. Taking derivatives of explicit functions is easy, since we have an explicit expression for the output in terms of the input and can compute \(\frac{dy}{dx}\) in a straightforward manner. However, if \(y\) is only given implicitly through \(f(x, y) = 0\), then computing the derivative is less straightforward (a worked example follows this list).
- Fixed point iterations. Fixed point iterations are iterations of the form \(x_{k+1} = \mathcal{F}(x_k)\) and we call a vector \(x_{\star}\) a fixed point of \(\mathcal{F}\) if \(x_{\star} = \mathcal{F}(x_{\star})\). DEQs are based on the idea that the layers of a neural network will eventually reach a fixed point.
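As a concrete illustration of the first ingredient, take the circle from above: for \(f(x, y) = x^2 + y^2 - 1 = 0\), differentiating the whole equation with respect to \(x\) gives \(2x + 2y\,\frac{dy}{dx} = 0\), hence

$$ \frac{dy}{dx} = -\frac{x}{y}, $$

without ever writing \(y\) explicitly as a function of \(x\) (valid wherever \(y \neq 0\), i.e. where \(\frac{\partial f}{\partial y} = 2y\) is non-singular).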
Fixed point iterations
Consider the following fixed point iteration:

$$ z_{k+1} = \tanh(W z_k + x + b), \qquad k = 0, 1, 2, \ldots $$
This is essentially repeated application of one layer of a neural network with weight matrix \(W\), bias \(b\), some input \(x\) and activation function \(\tanh\). Writing \(\mathcal{F}(z) = \tanh(W z + x + b)\) for this map and assuming for now that a fixed point actually exists, we iterate until convergence, i.e. until \(z_{\star} = \mathcal{F}(z_{\star})\) holds up to some tolerance.
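To make this concrete, here is a minimal numpy sketch of the iteration (the dimension, the scaling of \(W\) and the tolerance are illustrative choices; rescaling \(W\) to spectral norm below one makes \(\mathcal{F}\) a contraction, so convergence is guaranteed):

```python
import numpy as np

# Fixed point iteration z_{k+1} = tanh(W z_k + x + b).
rng = np.random.default_rng(0)
n = 5
W = rng.standard_normal((n, n))
W *= 0.9 / np.linalg.norm(W, 2)   # spectral norm < 1  =>  the map is a contraction
b = rng.standard_normal(n)
x = rng.standard_normal(n)

z = np.zeros(n)
for k in range(200):
    z_new = np.tanh(W @ z + x + b)
    if np.linalg.norm(z_new - z) < 1e-10:   # converged: z is (numerically) a fixed point
        break
    z = z_new

z_star = z_new
print(f"converged after {k} iterations, z_star = {z_star}")
```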
Alternatively, we can write the above equation as

$$ z_{\star} - \tanh(W z_{\star} + x + b) = 0, $$

where the function is now implicitly defined. Defining

$$ g(x, z) = z - \tanh(W z + x + b), $$

the goal is now to solve the root finding problem

$$ g(x, z_{\star}(x)) = 0, $$
where \(z_{\star}(x)\) denotes the solution, which depends on \(x\). Assume we want to compute \(\frac{dz_{\star}(x)}{dx}\) (note that we could choose to differentiate with respect to any other parameter, for example the weight matrix; the choice of \(x\) is just for illustrative purposes). Since we only have access to \(z_{\star}\) through the equation \(g(x, z_{\star}(x)) = 0\), we need to differentiate through this equation to obtain \(\frac{dz_{\star}(x)}{dx}\). This yields:

$$ \frac{\partial g(x, z_{\star})}{\partial x} + \frac{\partial g(x, z_{\star})}{\partial z_{\star}} \frac{dz_{\star}(x)}{dx} = 0. $$

This equation allows us to solve for \(\frac{dz_{\star}(x)}{dx}\) as follows

$$ \frac{dz_{\star}(x)}{dx} = -\left(\frac{\partial g(x, z_{\star})}{\partial z_{\star}}\right)^{-1} \frac{\partial g(x, z_{\star})}{\partial x}. $$
The main question here is whether such a function \(z_{\star}(x)\) is guaranteed to exist. The implicit function theorem states that if a fixed point exists and \(g\) is continuously differentiable with non-singular Jacobian \(\frac{\partial g}{\partial z}\) at \(z_{\star}\), then there exists (locally) a unique differentiable function \(z_{\star}(x)\). The key point is that one can differentiate through \(z_{\star}\) without needing to differentiate through the solver used to obtain the fixed point. This saves a huge amount of memory that would otherwise be needed to store intermediate results for backpropagation.
This observation has led to the development of the Deep Equilibrium Model (DEQ). This network has the following structure:
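In the notation used above (the single \(\tanh\) layer serves only as an example; any parametrized transformation \(f\) with weights \(W\) can be used), the output of the network is defined as the equilibrium point of one layer:

$$ z_{\star} = f(x, z_{\star}), \qquad \text{e.g. } f(x, z) = \tanh(W z + x + b), $$

i.e. instead of stacking a finite number of explicit layers, the output is the fixed point of a single layer applied over and over.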
As we can see, a DEQ applies a fixed point iteration to a single layer of a neural network. The question is whether this fixed point iteration actually converges: it could also blow up or oscillate. It turns out that in practice the iteration typically does converge. As you can probably guess at this point, the gradient is computed using implicit differentiation, thereby bypassing the need to store any intermediate iterates for the backward pass. This way one can build an extremely deep (weight-tied) network at constant memory cost. If we now want to update the weights of the network, we need the partial derivative of \(z_{\star}\) with respect to \(W\). Given that \(z_{\star}\) is a fixed point we have

$$ z_{\star} = f(x, z_{\star}), $$

where the dependence on \(W\) is hidden inside \(f\).
Differentiating this equation with respect to \(W\) (implicit differentiation, exactly as before) and rearranging terms gives

$$ \frac{\partial z_{\star}}{\partial W} = \left(I - \frac{\partial f(x, z_{\star})}{\partial z_{\star}}\right)^{-1} \frac{\partial f(x, z_{\star})}{\partial W}. $$
Backpropagation actually implements the transpose of this expression, i.e.:

$$ \left(\frac{\partial z_{\star}}{\partial W}\right)^T y = \left(\frac{\partial f(x, z_{\star})}{\partial W}\right)^T \left(I - \frac{\partial f(x, z_{\star})}{\partial z_{\star}}\right)^{-T} y, $$

where \(y\) is the vector we apply the transposed Jacobian to, typically the gradient of the loss with respect to \(z_{\star}\). Evaluating the gradient is now a two-step process (a numerical sketch follows the list):
- Evaluate \(\left(I - \frac{\partial f(x, z_{\star})}{\partial z_{\star}}\right)^{-T}y\). Since this matrix tends to be large we do not evaluate the inverse directly, but rather solve the linear system $$ y = \left(I - \frac{\partial f(x, z_{\star})}{\partial z_{\star}}\right)^T g \quad \Leftrightarrow \quad g = \left(\frac{\partial f(x, z_{\star})}{\partial z_{\star}}\right)^T g + y, $$ which is itself a fixed point equation in \(g\).
- Compute $$ \left(\frac{\partial z_{\star}}{\partial W}\right)^Ty = \left(\frac{\partial f(x, z_{\star})}{\partial W}\right)^Tg $$
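As an illustration, both steps can be sketched in numpy for the \(\tanh\) layer from before (a hypothetical helper, not taken from the source; it assumes \(z_{\star}\) was obtained with the forward iteration shown earlier and that \(y = \partial \ell / \partial z_{\star}\) is the incoming gradient of the loss):

```python
import numpy as np

def deq_backward_W(W, b, x, z_star, y, tol=1e-10, max_iter=200):
    """Gradient of the loss w.r.t. W for f(x, z) = tanh(W z + x + b), via implicit differentiation."""
    # Jacobian of f w.r.t. z at the fixed point:
    # df/dz = diag(1 - z_star**2) @ W, since z_star = tanh(W z_star + x + b).
    J = (1.0 - z_star**2)[:, None] * W

    # Step 1: solve y = (I - J)^T g, written as the fixed point iteration g = J^T g + y.
    g = np.zeros_like(y)
    for _ in range(max_iter):
        g_new = J.T @ g + y
        if np.linalg.norm(g_new - g) < tol:
            break
        g = g_new

    # Step 2: (df/dW)^T g. For this f the vector-Jacobian product reduces to an outer product:
    # dL/dW[j, k] = g[j] * (1 - z_star[j]**2) * z_star[k].
    return np.outer(g * (1.0 - z_star**2), z_star)
```

The result can be checked against finite differences or an autograd framework; note that neither the forward nor the backward pass stores any intermediate iterates.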
So far we have considered a rather simple model for the DEQ: we have assumed a single weight matrix \(W\) shared across the layers and a simple feed-forward model. It turns out that a feed-forward neural network with layer-dependent weight matrices can equivalently be written as a network whose weights are constant (tied) across the layers, which is summarized in the following theorem by Bai et al., 2019:
Consider a traditional \(L\)-layer MLP
This network is equivalent to the following weight-tied network of equivalent depth:
We prove the theorem for the case \(L = 4\), but it extends to general \(L\). Define the matrices
Then after one iteration we have
For the second iteration we have
and hence
Similarly, for the next layer we obtain
which leads to
Then, finally,
and hence
Moreover, note that we have only used a single-layer DEQ, as opposed to the multi-layer architectures that are typical for powerful neural networks. However, any multi-layer network can be represented by a single-layer DEQ. The argument is as follows. Assume that we construct a two-layer network \(g_2(g_1(x))\). This can be posed as a single-layer DEQ using the following relation:
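One possible construction (a sketch; the stacked notation \(z = (z_1, z_2)\) is chosen here for illustration) is to concatenate the two hidden states and define

$$ f(x, z) = f\big(x, (z_1, z_2)\big) = \begin{pmatrix} g_1(x) \\ g_2(z_1) \end{pmatrix}. $$

A fixed point \(z_{\star} = f(x, z_{\star})\) then satisfies \(z_{\star,1} = g_1(x)\) and \(z_{\star,2} = g_2(z_{\star,1}) = g_2(g_1(x))\), so the second block of the equilibrium is exactly the output of the two-layer network.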
That is, the extra layer is absorbed by concatenating the two hidden states into a single, larger state on which one DEQ layer acts. The same argument holds for stacking DEQs: a single DEQ can model any number of stacked DEQs.
Finally, we can increase the complexity of the DEQ by substituting the simple feed-forward neural network with any sequence of operations, including convolutions, normalizations, grouping and skip connections.
Further reading
These notes are essentially a summary of the following tutorial, specifically chapters 1 and 4:
- Kolter, Duvenaud and Johnson, Deep Implicit Layers: Neural ODEs, Deep Equilibrium Models, and Beyond, NeurIPS 2020 tutorial, http://implicit-layers-tutorial.org/
Below is the paper introducing Deep Equilibrium Models (DEQ):
- Bai, Kolter and Koltun, Deep Equilibrium Models, NeurIPS 2019, https://arxiv.org/abs/1909.01377