Neural networks are well known to be over-parameterized and can often easily fit data to near-zero training loss while still generalizing decently on the test set. Although all these parameters are initialized at random, the optimization process consistently leads to similarly good outcomes. This is true even when the number of model parameters exceeds the number of training data points.
The neural tangent kernel (NTK) (Jacot et al. 2018) is a kernel that describes the evolution of neural networks during training via gradient descent. It leads to great insights into why neural networks with enough width can consistently converge to a global minimum when trained to minimize an empirical loss. In this post, we will do a deep dive into the motivation and definition of NTK, as well as the proof that the training dynamics of infinite-width networks are deterministic across different random initializations, by characterizing the NTK in that setting.
Different from my previous posts, this one mainly focuses on a small number of core papers, less on the breadth of the literature review in the field. There are many interesting works after NTK, with modification or expansion of the theory for understanding the learning dynamics of NNs, but they won't be covered here. The goal is to show all the math behind NTK in a clear and easy-to-follow format, so the post is quite math-intensive. If you notice any mistakes, please let me know and I will be happy to correct them quickly. Thanks in advance!
Basics
This section reviews several basic concepts that are core to understanding the neural tangent kernel. Feel free to skip it.
Vector-to-vector Derivative
Given an input vector $\mathbf{x} \in \mathbb{R}^n$ and a function $f: \mathbb{R}^n \to \mathbb{R}^m$, we write the output as $\mathbf{y} = f(\mathbf{x}) \in \mathbb{R}^m$.
Throughout the post, I use integer subscript(s) to refer to a single entry out of a vector or matrix value; e.g., $y_i$ is the $i$-th entry of $\mathbf{y}$ and $W_{ij}$ is the entry of a matrix $W$ at row $i$ and column $j$.
The gradient of a vector with respect to a vector is the Jacobian matrix $\frac{\partial \mathbf{y}}{\partial \mathbf{x}} \in \mathbb{R}^{m \times n}$, whose entry at location $(i, j)$ is $\big[\frac{\partial \mathbf{y}}{\partial \mathbf{x}}\big]_{ij} = \frac{\partial y_i}{\partial x_j}$.
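To make the definition concrete, here is a minimal numerical sketch that approximates a Jacobian with finite differences; the function `f` and the step size `eps` below are arbitrary illustrative choices, not anything from the papers discussed later.

```python
import numpy as np

def f(x):
    # An arbitrary illustrative function f: R^3 -> R^2.
    return np.array([x[0] * x[1], np.sin(x[2])])

def jacobian(f, x, eps=1e-6):
    # Finite-difference approximation of the m x n Jacobian dy/dx.
    fx = f(x)
    J = np.zeros((fx.size, x.size))
    for j in range(x.size):
        x_pert = x.copy()
        x_pert[j] += eps
        J[:, j] = (f(x_pert) - fx) / eps
    return J

x = np.array([1.0, 2.0, 3.0])
print(jacobian(f, x))  # approximately [[2, 1, 0], [0, 0, cos(3)]]
```

An autodiff library would give the same matrix exactly; the finite-difference version just keeps the example dependency-free.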
Differential Equations
Differential equations describe the relationship between one or multiple functions and their derivatives. There are two main types of differential equations.
- (1) ODE (ordinary differential equation) contains only an unknown function of a single independent variable and its derivatives. ODEs are the main form of differential equations used in this post. A general form of an ODE looks like $F\big(t, y(t), y'(t), \dots, y^{(n)}(t)\big) = 0$ for an unknown function $y(t)$.
- (2) PDE (partial differential equation) contains unknown multivariable functions and their partial derivatives.
Let's review the simplest case of a differential equation and its solution. Separation of variables (the Fourier method) can be used when all the terms containing one variable can be moved to one side, while all the other terms are moved to the other side. For example:
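take the exponential-growth ODE $\frac{dy}{dt} = k y$ with constant $k$ and initial value $y(0) = y_0$ (a standard textbook instance, chosen here only for concreteness):

```latex
% Separation of variables: move all y terms to the left, all t terms to the right.
\frac{dy}{dt} = k y
\;\Longrightarrow\; \frac{dy}{y} = k\,dt
\;\Longrightarrow\; \int \frac{dy}{y} = \int k\,dt
\;\Longrightarrow\; \ln |y| = k t + C
\;\Longrightarrow\; y(t) = y_0 e^{k t}
```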
Central Limit Theorem
Given a collection of i.i.d. random variables $x_1, \dots, x_N$ with mean $\mu$ and finite variance $\sigma^2$, the central limit theorem (CLT) states that the rescaled sample mean converges in distribution to a Gaussian as the sample size grows: $\sqrt{N}\big(\frac{1}{N}\sum_{i=1}^N x_i - \mu\big) \xrightarrow{d} \mathcal{N}(0, \sigma^2)$ as $N \to \infty$.
The CLT can also apply to multidimensional random vectors, and then instead of a single scalar variance $\sigma^2$, the limiting distribution is characterized by a covariance matrix.
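A quick simulation illustrates the statement; the uniform distribution, sample size, and number of trials below are arbitrary choices for the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
N, trials = 1000, 10000

# Draw i.i.d. samples from a (non-Gaussian) uniform distribution on [0, 1]:
# mean mu = 0.5, variance sigma^2 = 1/12.
samples = rng.uniform(0.0, 1.0, size=(trials, N))
scaled_means = np.sqrt(N) * (samples.mean(axis=1) - 0.5)

# The rescaled sample means should be approximately N(0, 1/12).
print(scaled_means.mean(), scaled_means.var())  # ~0.0 and ~0.0833
```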
Taylor Expansion
The Taylor expansion expresses a function as an infinite sum of components, each represented in terms of this function's derivatives. The Taylor expansion of a function $f(x)$ around the point $x = a$ is $f(x) = \sum_{n=0}^{\infty} \frac{f^{(n)}(a)}{n!} (x - a)^n$.
The first-order Taylor expansion is often used as a linear approximation of the function value: $f(x) \approx f(a) + f'(a)(x - a)$.
Kernel & Kernel Methods
A kernel is essentially a similarity function between two data points, $K: \mathcal{X} \times \mathcal{X} \to \mathbb{R}$.
Depending on the problem structure, some kernels can be decomposed into two feature maps, one corresponding to each data point, and the kernel value is an inner product of these two feature vectors: $K(\mathbf{x}, \mathbf{x}') = \langle \varphi(\mathbf{x}), \varphi(\mathbf{x}') \rangle$.
Kernel methods are a type of non-parametric, instance-based machine learning algorithm. Assuming we know the labels of all the training samples, the prediction for a new input is a kernel-weighted combination of the training labels, where the weights measure the similarity between the new input and each training sample.
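Below is a minimal sketch of this idea using a Nadaraya-Watson style kernel-weighted average; the RBF kernel, the 1-D toy data, and the function names are assumptions made only for illustration.

```python
import numpy as np

def rbf_kernel(x1, x2, length_scale=1.0):
    # Gaussian (RBF) similarity between two scalar inputs.
    return np.exp(-(x1 - x2) ** 2 / (2 * length_scale ** 2))

def kernel_regression(x_train, y_train, x_query):
    # Predict a kernel-weighted average of the training labels.
    weights = np.array([rbf_kernel(x, x_query) for x in x_train])
    return weights @ y_train / weights.sum()

x_train = np.array([0.0, 1.0, 2.0, 3.0])
y_train = np.sin(x_train)
print(kernel_regression(x_train, y_train, 1.5))  # smoothed estimate near sin(1.5)
```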
Gaussian Processes
A Gaussian process (GP) is a non-parametric method that models a multivariate Gaussian probability distribution over a collection of random variables. A GP assumes a prior over functions and then updates the posterior over functions based on what data points are observed.
Given a collection of data points, a GP assumes that their function values follow a joint Gaussian distribution whose covariance matrix is defined by a kernel evaluated on pairs of inputs; the prediction for a new point is obtained by conditioning this joint Gaussian on the observed values.
Check this post for a high-quality and highly visual tutorial on what Gaussian processes are.
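As a small self-contained sketch of the update from prior to posterior, the snippet below conditions a zero-mean GP with an RBF kernel on a few observed points; the kernel choice, toy data, and jitter constant are illustrative assumptions.

```python
import numpy as np

def rbf(xa, xb, length_scale=1.0):
    # Covariance matrix K[i, j] = exp(-(xa_i - xb_j)^2 / (2 l^2)).
    d = xa[:, None] - xb[None, :]
    return np.exp(-d ** 2 / (2 * length_scale ** 2))

# Observed points and query points.
x_obs = np.array([-2.0, 0.0, 1.5])
y_obs = np.sin(x_obs)
x_new = np.linspace(-3, 3, 50)

# GP posterior: condition the joint Gaussian on the observations.
K = rbf(x_obs, x_obs) + 1e-8 * np.eye(len(x_obs))  # jitter for numerical stability
K_s = rbf(x_new, x_obs)
K_ss = rbf(x_new, x_new)
K_inv = np.linalg.inv(K)

posterior_mean = K_s @ K_inv @ y_obs
posterior_cov = K_ss - K_s @ K_inv @ K_s.T
print(posterior_mean[:5])
print(np.diag(posterior_cov)[:5])
```

The posterior mean interpolates the observations and the posterior variance shrinks near them, which is the GP version of "updating the belief over functions after seeing data."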
Notation
Let us consider a fully-connected neural network with parameters $\theta$, mapping an input $\mathbf{x}$ to an output $f(\mathbf{x}; \theta)$; the network has $L$ layers and the $l$-th hidden layer has width $n_l$.
The training dataset contains $N$ samples, $\mathcal{D} = \{(\mathbf{x}^{(i)}, y^{(i)})\}_{i=1}^N$.
Now let's look into the forward pass computation in every layer in detail. For each layer, the pre-activation is an affine transformation of the previous layer's output, followed by a pointwise nonlinearity in every hidden layer.
Note that the NTK parameterization applies a rescaling weight of $1/\sqrt{n_{l-1}}$ (the inverse square root of the layer's input width) to the pre-activation of layer $l$, so that the activations remain of order one as the widths grow.
All the network parameters are initialized as i.i.d. Gaussians $\mathcal{N}(0, 1)$ in the following analysis.
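The following is a minimal sketch of one possible forward pass under this parameterization: every weight and bias is drawn i.i.d. from $\mathcal{N}(0, 1)$ and each pre-activation is rescaled by one over the square root of the fan-in. The layer widths, the ReLU nonlinearity, and the helper names are illustrative assumptions, not the exact setup of the original paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_params(layer_sizes):
    # All weights and biases are i.i.d. standard Gaussians, N(0, 1).
    return [(rng.standard_normal((n_out, n_in)), rng.standard_normal(n_out))
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]

def forward(params, x):
    # NTK parameterization: each pre-activation is rescaled by 1 / sqrt(fan_in).
    h = x
    for i, (W, b) in enumerate(params):
        a = W @ h / np.sqrt(W.shape[1]) + b
        h = np.maximum(a, 0.0) if i < len(params) - 1 else a  # ReLU hidden layers, linear output
    return h

params = init_params([3, 1024, 1024, 1])  # widths are arbitrary for illustration
print(forward(params, rng.standard_normal(3)))
```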
Neural Tangent Kernel
Neural tangent kernel (NTK) (Jacot et al. 2018) is an important concept for understanding neural network training via gradient descent. At its core, it explains how updating the model parameters on one data sample affects the predictions for other samples.
Let's start with the intuition behind NTK, step by step.
The empirical loss function $\mathcal{L}: \mathbb{R}^P \to \mathbb{R}_+$ that we minimize during training averages a per-sample cost $\ell$ over the training data,

$$\mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^N \ell\big(f(\mathbf{x}^{(i)}; \theta), y^{(i)}\big)$$

and according to the chain rule, the gradient of the loss is:

$$\nabla_\theta \mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^N \nabla_\theta f(\mathbf{x}^{(i)}; \theta)^\top \, \nabla_f \ell\big(f(\mathbf{x}^{(i)}; \theta), y^{(i)}\big)$$

When tracking how the network parameter $\theta$ evolves in time under gradient descent with an infinitesimally small learning rate, the updates can be described by gradient flow:

$$\frac{d\theta}{dt} = -\nabla_\theta \mathcal{L}(\theta)$$

Again, by the chain rule, the network output evolves according to the derivative:

$$\frac{d f(\mathbf{x}; \theta)}{dt} = \nabla_\theta f(\mathbf{x}; \theta) \frac{d\theta}{dt} = -\frac{1}{N} \sum_{i=1}^N \underbrace{\nabla_\theta f(\mathbf{x}; \theta)\, \nabla_\theta f(\mathbf{x}^{(i)}; \theta)^\top}_{\text{Neural tangent kernel}} \nabla_f \ell\big(f(\mathbf{x}^{(i)}; \theta), y^{(i)}\big)$$

Here we find the Neural Tangent Kernel (NTK), the factor marked in the above formula,

$$K(\mathbf{x}, \mathbf{x}'; \theta) = \nabla_\theta f(\mathbf{x}; \theta)\, \nabla_\theta f(\mathbf{x}'; \theta)^\top \in \mathbb{R}^{n_L \times n_L}$$

where each entry in the output matrix at location $(m, n)$, $1 \leq m, n \leq n_L$, is

$$K_{m,n}(\mathbf{x}, \mathbf{x}'; \theta) = \sum_{p=1}^{P} \frac{\partial f_m(\mathbf{x}; \theta)}{\partial \theta_p} \frac{\partial f_n(\mathbf{x}'; \theta)}{\partial \theta_p}$$

The "feature map" form of one input $\mathbf{x}$ is $\varphi(\mathbf{x}) = \nabla_\theta f(\mathbf{x}; \theta)$, so the kernel is an inner product of the gradients of the network output with respect to the parameters, $K(\mathbf{x}, \mathbf{x}'; \theta) = \langle \varphi(\mathbf{x}), \varphi(\mathbf{x}') \rangle$.
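To connect the definition to something computable, here is a sketch of the empirical (finite-width) NTK for a scalar-output MLP using `jax` automatic differentiation: the kernel value is the inner product of the two parameter gradients. The architecture, erf nonlinearity, and widths are placeholder choices.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes):
    # i.i.d. N(0, 1) parameters, as in the NTK parameterization.
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, k_w, k_b = jax.random.split(key, 3)
        params.append((jax.random.normal(k_w, (n_out, n_in)),
                       jax.random.normal(k_b, (n_out,))))
    return params

def f(params, x):
    # Scalar-output MLP with 1/sqrt(width) rescaling and erf nonlinearity.
    h = x
    for W, b in params[:-1]:
        h = jax.scipy.special.erf(W @ h / jnp.sqrt(W.shape[1]) + b)
    W, b = params[-1]
    return (W @ h / jnp.sqrt(W.shape[1]) + b)[0]

def empirical_ntk(params, x1, x2):
    # K(x1, x2) = <df/dtheta(x1), df/dtheta(x2)>, summed over all parameters.
    g1 = jax.grad(f)(params, x1)
    g2 = jax.grad(f)(params, x2)
    leaves1, _ = jax.tree_util.tree_flatten(g1)
    leaves2, _ = jax.tree_util.tree_flatten(g2)
    return sum(jnp.vdot(a, b) for a, b in zip(leaves1, leaves2))

key = jax.random.PRNGKey(0)
params = init_params(key, [2, 512, 512, 1])
x1, x2 = jnp.array([1.0, 0.5]), jnp.array([-0.3, 2.0])
print(empirical_ntk(params, x1, x2))
```

For a multi-dimensional output one would use the full Jacobians instead of `jax.grad`, giving the $n_L \times n_L$ matrix described above.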
Infinite Width Networks
To understand why the effect of one gradient descent update is so similar across different initializations of the network parameters, several pioneering theoretical works start with infinite-width networks. We will look into a detailed proof, using the NTK, of how it guarantees that infinite-width networks can converge to a global minimum when trained to minimize an empirical loss.
Connection with Gaussian Processes
Deep neural networks have a deep connection with Gaussian processes (Neal 1994). The output functions of an $L$-layer network, with parameters initialized i.i.d. as above, converge to i.i.d. centered Gaussian processes as the layer widths go to infinity.
Lee & Bahri et al. (2018) showed a proof by mathematical induction:
(1) Let's start with the base case of a network with a single hidden layer.
Since the weights and biases are initialized i.i.d., and each output unit is a sum of contributions from infinitely many i.i.d. hidden units, all the output dimensions of this network converge to i.i.d. centered Gaussians by the central limit theorem as the hidden width goes to infinity; in other words, the single-hidden-layer network defines a Gaussian process in the infinite-width limit.
(2) Using induction, we first assume the proposition is true for a network of $L-1$ layers.
Then we need to prove the proposition is also true for the corresponding network of $L$ layers.
We can infer that the expectation of the sum of contributions of the previous hidden layers is zero:
Since
When
The Gaussian process derived in the above process is referred to as the Neural Network Gaussian Process (NNGP) (Lee & Bahri et al. 2018).
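A rough empirical sanity check of the NNGP view (under the i.i.d. $\mathcal{N}(0,1)$ initialization and $1/\sqrt{\text{width}}$ scaling assumed above): across many independently initialized wide networks, the output at a fixed input should look like a centered Gaussian. The width, input dimension, and ReLU nonlinearity below are arbitrary choices, and this only checks a single marginal rather than the full GP claim.

```python
import numpy as np

rng = np.random.default_rng(0)
width, n_draws = 2048, 1000
x = rng.standard_normal(10)  # a fixed input; dimension chosen arbitrarily

outputs = []
for _ in range(n_draws):
    # One-hidden-layer network under NTK parameterization, fresh i.i.d. N(0, 1) parameters.
    W0 = rng.standard_normal((width, x.size))
    b0 = rng.standard_normal(width)
    W1 = rng.standard_normal((1, width))
    b1 = rng.standard_normal(1)
    h = np.maximum(W0 @ x / np.sqrt(x.size) + b0, 0.0)
    outputs.append((W1 @ h / np.sqrt(width) + b1)[0])

outputs = np.array(outputs)
# Across random initializations, the output should be approximately a centered Gaussian.
print(outputs.mean(), outputs.std())
```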
Deterministic Neural Tangent Kernel
Finally, we are now prepared to look into the most critical proposition from the NTK paper:
When the width of the hidden layers goes to infinity (with an infinitesimally small learning rate), the NTK converges to a limit that is:
- (1) deterministic at initialization, meaning that the kernel does not depend on the random initialization values and is determined only by the model architecture; and
- (2) stays constant during training.
The proof depends on mathematical induction as well:
(1) First of all, we always have
(2) Now when
Note that
Next let's check the case
The output function of this
And we know its derivative with respect to different sets of parameters; let us denote
where
The NTK for this
where each individual entry at location
When
and the red section has the limit:
Later, Arora et al. (2019) provided a proof under a weaker limit that does not require every hidden layer to be infinitely wide, but only requires the minimum width to be sufficiently large.
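As a quick numerical illustration of the "deterministic at initialization" property, the sketch below computes one NTK entry of a one-hidden-layer ReLU network with analytic gradients and measures how much it fluctuates across random initializations as the width grows; all the concrete choices (ReLU, the widths, 20 seeds) are assumptions for the demo.

```python
import numpy as np

def ntk_entry(x1, x2, width, rng):
    # Empirical NTK of f(x) = a . relu(W x / sqrt(d) + b) / sqrt(m) + c
    # under the NTK parameterization, computed with analytic gradients.
    d, m = x1.size, width
    W = rng.standard_normal((m, d))
    b = rng.standard_normal(m)
    a = rng.standard_normal(m)  # the output bias c contributes a constant 1 to the kernel

    def grads(x):
        u = W @ x / np.sqrt(d) + b
        act, dact = np.maximum(u, 0.0), (u > 0).astype(float)
        g_a = act / np.sqrt(m)                 # df/da
        g_b = a * dact / np.sqrt(m)            # df/db
        g_W = np.outer(g_b, x / np.sqrt(d))    # df/dW
        return g_a, g_b, g_W

    g1, g2 = grads(x1), grads(x2)
    return sum(np.vdot(p, q) for p, q in zip(g1, g2)) + 1.0  # +1 from the output bias

rng = np.random.default_rng(0)
x1, x2 = rng.standard_normal(5), rng.standard_normal(5)
for width in (64, 1024, 16384):
    vals = [ntk_entry(x1, x2, width, np.random.default_rng(seed)) for seed in range(20)]
    print(width, np.std(vals))  # the spread across initializations shrinks as width grows
```

The standard deviation across seeds should shrink as the width grows, matching the intuition that the kernel becomes a deterministic function of the architecture alone.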
Linearized Models
From the previous section, according to the chain rule, we know that the gradient update on the output of an infinite-width network is as follows; for brevity, we omit the inputs in the following analysis:
To track the evolution of $\theta$ over time, we can take a first-order Taylor expansion of the network output around the parameters at initialization $\theta(0)$:
Such a formulation is commonly referred to as the linearized model, given that the output is linear in $\theta$ once $\theta(0)$ is fixed and the higher-order terms are dropped.
Eventually we get the same learning dynamics, which implies that a neural network with infinite width can be considerably simplified: it is governed by the above linearized model (Lee & Xiao, et al. 2019).
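Here is a small sketch of the linearized model in code, using `jax.jvp` to form the first-order Taylor expansion of the network output around the initial parameters $\theta(0)$; the tiny two-layer architecture and the perturbation size are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def f(theta, x):
    # A tiny two-layer network; the architecture is only for illustration.
    W1, W2 = theta
    return jnp.tanh(W1 @ x) @ W2

def f_lin(theta, theta0, x):
    # First-order Taylor expansion of f around theta0:
    # f_lin(x; theta) = f(x; theta0) + grad_theta f(x; theta0) . (theta - theta0)
    delta = jax.tree_util.tree_map(lambda a, b: a - b, theta, theta0)
    y0, jvp_out = jax.jvp(lambda t: f(t, x), (theta0,), (delta,))
    return y0 + jvp_out

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
theta0 = (jax.random.normal(k1, (128, 4)), jax.random.normal(k2, (128,)))
theta = jax.tree_util.tree_map(lambda a: a + 1e-3 * jnp.ones_like(a), theta0)
x = jnp.ones(4)

print(f(theta, x), f_lin(theta, theta0, x))  # close when theta stays near theta0
```

For a sufficiently wide network, training the true model and training this linearized model give nearly the same trajectory, which is the statement of Lee & Xiao, et al. (2019).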
In the simple case where the empirical loss is an MSE loss, $\mathcal{L}(\theta) = \frac{1}{2}\|f(\mathcal{X}; \theta) - \mathcal{Y}\|_2^2$, the gradient-flow dynamics of the outputs on the training inputs $\mathcal{X}$ become a linear ODE that can be solved in closed form.
When $t \to \infty$ and the kernel is positive definite, the outputs on the training data converge to the labels $\mathcal{Y}$ and the training loss is driven to zero.
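For reference, a standard way to write these dynamics (with the learning-rate and $1/N$ constants absorbed for brevity, and $\bar{K}$ denoting the constant infinite-width kernel evaluated on the training inputs $\mathcal{X}$) is the linear ODE and its closed-form solution below.

```latex
% Gradient flow on the MSE loss with a constant kernel \bar{K}:
\frac{d f(\mathcal{X}; \theta(t))}{dt}
  = -\bar{K}\,\big(f(\mathcal{X}; \theta(t)) - \mathcal{Y}\big)
\quad\Longrightarrow\quad
f(\mathcal{X}; \theta(t))
  = \mathcal{Y} + e^{-\bar{K} t}\big(f(\mathcal{X}; \theta(0)) - \mathcal{Y}\big)
```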
Lazy Training
People observe that when a neural network is heavily over-parameterized, the model is able to learn quickly, with the training loss converging to zero, while the network parameters hardly change; lazy training refers to this phenomenon. In other words, when the loss
Let
Still following the first-order Taylor expansion, we can track the change in the differential of
Let
Chizat et al. (2019) showed the proof for a two-layer neural network that
Citation
Cited as:
Weng, Lilian. (Sep 2022). Some math behind neural tangent kernel. Lil'Log. https://lilianweng.github.io/posts/2022-09-08-ntk/.
Or
@article{weng2022ntk,
title = "Some Math behind Neural Tangent Kernel",
author = "Weng, Lilian",
journal = "Lil'Log",
year = "2022",
month = "Sep",
url = "https://lilianweng.github.io/posts/2022-09-08-ntk/"
}
References
[1] Jacot et al. "Neural Tangent Kernel: Convergence and Generalization in Neural Networks." NeurIPS 2018.
[2] Radford M. Neal. "Priors for Infinite Networks." Bayesian Learning for Neural Networks. Springer, New York, NY, 1996. 29-53.
[3] Lee & Bahri et al. "Deep Neural Networks as Gaussian Processes." ICLR 2018.
[4] Chizat et al. "On Lazy Training in Differentiable Programming." NeurIPS 2019.
[5] Lee & Xiao, et al. "Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent." NeurIPS 2019.
[6] Arora, et al. "On Exact Computation with an Infinitely Wide Neural Net." NeurIPS 2019.
[7] (YouTube video) "Neural Tangent Kernel: Convergence and Generalization in Neural Networks" by Arthur Jacot, Nov 2018.
[8] (YouTube video) "Lecture 7 - Deep Learning Foundations: Neural Tangent Kernels" by Soheil Feizi, Sep 2020.
[9] "Understanding the Neural Tangent Kernel." Rajat's Blog.
[10] "Neural Tangent Kernel." Applied Probability Notes, Mar 2021.
[11] "Some Intuition on the Neural Tangent Kernel." inFERENCe, Nov 2020.