diff --git a/Neural_ODE.tex b/Neural_ODE.tex index 347c274..c5036a0 100644 --- a/Neural_ODE.tex +++ b/Neural_ODE.tex @@ -1,926 +1,926 @@ \documentclass[usenames,dvipsnames,aspectratio=169,10pt]{beamer} \usepackage{multicol} \usetheme{metropolis} \usepackage{appendixnumberbeamer} \usepackage{autonum} \usepackage{booktabs} \usepackage[scale=2]{ccicons} \usepackage{bm} \usepackage{pgfplots} \usepackage[utf8]{inputenc} \usepackage{media9} \usepackage{subcaption} \usepackage[english]{babel} \usepackage{amsmath} \usepackage{mathtools} \usepackage{amsfonts} \usepackage{amssymb} \usepackage{graphicx} \usepackage{xmpmulti} \usepackage{animate} \newcommand{\notimplies}{\;\not\!\!\!\implies} \usepackage{fontspec} % optional \pgfplotsset{compat=newest} \usepgfplotslibrary{groupplots} \usepgfplotslibrary{dateplot} \usepgfplotslibrary{dateplot} \newcommand{\inputTikZ}[2]{% \scalebox{#1}{\input{#2}} } \newcommand\blfootnote[1]{% \begingroup \renewcommand\thefootnote{}\footnote{#1}% \addtocounter{footnote}{-1}% \endgroup } \usepgfplotslibrary{groupplots,dateplot} \usetikzlibrary{patterns,shapes.arrows} \pgfplotsset{compat=newest} \pgfplotsset{compat=1.13} \usepgfplotslibrary{fillbetween} \pgfmathdeclarefunction{gauss}{2} {\pgfmathparse{1/(#2*sqrt(2*pi))*exp(-((x-#1)^2)/(2*#2^2))}} \usepackage{xspace} \newcommand{\themename}{\textbf{\textsc{metropolis}}\xspace} \definecolor{burgundy}{RGB}{255,0,90} \usepackage{algorithm} \usepackage{algpseudocode} %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % % Listings % % %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% \usepackage{listings,bera} \definecolor{keywords}{RGB}{255,0,90} \definecolor{comments}{RGB}{60,179,113} \lstset{language=Python, keywordstyle=\color{keywords}, commentstyle=\color{comments}\emph} \lstset{escapeinside={<@}{@>}} %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % % Color stuff % % %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% \usepackage{xcolor,pifont} \newcommand*\colourcheck[1]{% \expandafter\newcommand\csname #1check\endcsname{\textcolor{#1}{\ding{52}}}% } \colourcheck{blue} \colourcheck{green} \colourcheck{red} \definecolor{fore}{RGB}{249,242,215} \definecolor{back}{RGB}{51,51,51} \definecolor{title}{RGB}{255,0,90} \definecolor{mDarkBrown}{HTML}{604c38} \definecolor{mDarkTeal}{HTML}{23373b} \definecolor{mLightBrown}{HTML}{EB811B} \definecolor{mLightGreen}{HTML}{14B03D} \definecolor{aqb}{HTML}{6FEBBE} \setbeamercolor{titlelike}{fg=} \setbeamercolor{normal text}{fg=fore,bg=back} \newcommand{\pink}[1]{{\color{magenta} #1}} \newcommand{\blue}[1]{{\color{aqb} #1}} \newcommand{\tinto}[1]{{\color{burgundy} #1}} %symbol definitions %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % % Variables % % %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% \newcommand{\yj}{{Y}_{j+1}} \newcommand{\aat}{\mathbf{a}^T} \newcommand{\dat}{\frac{\diff {{a}}^T}{\diff t}} \newcommand{\ym}{{Y}_{j}} \newcommand{\kj}{{K}_j} \newcommand{\bj}{b_j} \newcommand{\yy}{\mathbf{Y}} \newcommand{\te}{\bm{\theta}} \newcommand{\R}{\mathbb{R}} %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % % delimiters % % %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% \newcommand{\lno}{\left \Vert} \newcommand{\rno}{\right \Vert} \newcommand{\lv}{\lvert} \newcommand{\rv}{ \rvert} \newcommand{\inner}[2]{\left\langle #1,#2 \right\rangle} \newcommand{\abs}[1]{\left \vert#1 \right \vert} %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % % operators % % %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% \newcommand{\calc}{\mathcal{C}} \newcommand{\calku}{\mathcal{K}(u)} \newcommand{\ff}{\mathcal{F}} \newcommand{\diff}{\mathsf{d}} \newcommand{\pd}[2]{\frac{\partial 
#1}{\partial #2}}
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% check mark
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
\title{Continuous limits of DNN: Neural networks as ODEs}
%\subtitle{An overview}
\date{November 13, 2020}
\author{Eva Vidli\v ckov\'a and Juan Pablo Madrigal Cianci,\\ CSQI.}
\institute{Foundations of deep neural networks}
\titlegraphic{\hfill\includegraphics[height=.5cm]{logo.png}}

\begin{document}

\maketitle

%\begin{frame}{Table of contents}
% \setbeamertemplate{section in toc}[sections numbered]
% \tableofcontents[hideallsubsections]
%\end{frame}

%\section{Introduction and Motivation: ResNets}
\begin{frame}{ResNets}
A residual network (ResNet) propagates its input through layers of the form
\[
Y_{j+1} = Y_j + h\,\sigma(Y_j K_j + b_j),
\]
i.e.\ each layer adds a learned residual on top of an identity (skip) connection.
Read as an update with step size $h$, this is exactly one forward Euler step of an ODE;
this observation is the starting point for both papers in this talk.
\end{frame}

\begin{frame}{Outline}
The rest of the talk summarizes the following two articles:
\nocite{*}
\bibliography{bib_intro}
\bibliographystyle{plain}
\end{frame}

-%\section{Stable Architectures}
+\section{Stable Architectures}

%\section{Introduction}
\begin{frame}{Classification problem}
\begin{itemize}
\item training data:
\begin{align}
& y_1,\dots, y_s \in \R^n \quad \text{ feature vectors }\\
& c_1,\dots, c_s \in \R^m \quad \text{ label vectors } \\
& (c_l)_k \text{ - likelihood of $y_l$ belonging to class $k$}
\end{align}
\item objective: learn a data-label relation that generalizes well
\item \textbf{deep architectures}
\begin{itemize}
\item[+] successful for highly nonlinear data-label relationships
\item[--] dimensionality, non-convexity, \textbf{instability} of forward model
\end{itemize}
\end{itemize}
\end{frame}

\begin{frame}{ResNets: Forward propagation}
%\textbf{Forward propagation}
\begin{gather}
Y_{j+1} = Y_j + {\color{red}h}\sigma(Y_j K_j + b_j), \quad j= 0,\dots,N-1\\[10pt]
Y_j \in\R^{s\times n},\, K_j\in\R^{n\times n},\, b_j\in\R
\end{gather}
\begin{itemize}
\item $Y_0 = [y_1,\dots,y_s]^\intercal$
\item $Y_1,\dots, Y_{N-1}$ - hidden layers,\; $Y_N$ - output layer
\item activation function
\[\sigma_{ht}(Y) = \tanh(Y),\quad \sigma_{ReLU}(Y) = \max(0,Y)\]
\end{itemize}
\end{frame}

\begin{frame}{ResNets: Classification}
%\textbf{Classification}
\[
h(Y_N W + e_s\mu^\intercal),\quad W\in \R^{n\times m}, \mu\in\R^m
\]
\begin{itemize}
\item hypothesis function (applied componentwise)
\[
h(x) = \frac{\exp(x)}{1+\exp(x)}
\]
\end{itemize}
\end{frame}

\begin{frame}{Learning process}
\begin{columns}
\begin{column}{0.5\textwidth}
\begin{center}
forward prop. parameters\\[10pt]
$(K_j, b_j,\;\; j=0,\dots, N-1)$
\end{center}
\end{column}
\begin{column}{0.5\textwidth}
\begin{center}
classification parameters \\[10pt]
$(W, \mu)$
\end{center}
\end{column}
\end{columns}
\vspace{1cm}
\centerline{\textbf{Optimization problem}}
\begin{gather}
\min \frac{1}{s} S\big(h(Y_N W + e_s\mu^\intercal), C\big) + \alpha R(W,\mu,K_j, b_j) \\[5pt]
\text{ s.t. } Y_{j+1} = Y_j + h\sigma(Y_j K_j + b_j), \qquad j= 0,\dots,N-1
\end{gather}
\begin{itemize}
\item $C = [c_1,c_2,\dots,c_s]^\intercal\in\mathbb{R}^{s\times m}$
\item e.g. $S(C_{pred},C) = \frac{1}{2}\|C_{pred} - C\|_F^2$
\end{itemize}
\end{frame}

\begin{frame}{Learning process}
\begin{itemize}
\item block coordinate descent method
\item
\begin{align}
\frac{1}{s} S\big(h(Y_N W + e_s\mu^\intercal), C\big)
&= \frac{1}{s} \sum_{i=1}^s S\Big(h\big((Y_N)_i^\intercal W + \mu^\intercal\big), c_i^\intercal\Big)\\
&\approx \frac{1}{|\mathcal{T}|} \sum_{i\in\mathcal{T}} S\Big(h\big((Y_N)_i^\intercal W + \mu^\intercal\big), c_i^\intercal\Big)
\end{align}
\item learning data \& validation data
\end{itemize}
\end{frame}
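\begin{frame}[fragile]{Sketch: forward propagation in code}
A minimal NumPy sketch of the forward propagation used above (illustrative only, not the authors' code); shapes follow the slides: $Y_j\in\R^{s\times n}$, $K_j\in\R^{n\times n}$, $b_j\in\R$.
\begin{lstlisting}
# Y_{j+1} = Y_j + h * sigma(Y_j K_j + b_j),  j = 0,...,N-1
import numpy as np

def forward(Y0, K, b, h, sigma=np.tanh):
    """Propagate s feature vectors through N ResNet layers."""
    Y = Y0.copy()                       # Y0: (s, n) feature matrix
    for Kj, bj in zip(K, b):            # K: list of (n, n), b: list of scalars
        Y = Y + h * sigma(Y @ Kj + bj)  # one forward Euler step
    return Y

s, n, N, h = 5, 2, 10, 0.1
rng = np.random.default_rng(0)
Y0 = rng.standard_normal((s, n))
K = [rng.standard_normal((n, n)) for _ in range(N)]
b = [0.0] * N
YN = forward(Y0, K, b, h)               # output layer Y_N, shape (s, n)
\end{lstlisting}
\end{frame}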
%\section{ODE interpretation}
\begin{frame}{ResNets as discretized ODEs}
\textbf{ResNet}
\[
Y_{j+1} = Y_j + h\sigma(Y_j K_j + b_j), \quad j= 0,\dots,N-1
\]
\textbf{Continuous ODE}
\begin{align}
&\dot{y}(t) = \sigma\big( K^\intercal(t)y(t) + b(t) \big), \quad t\in [0,T]\\
&y(0) = y_0
\end{align}
\end{frame}

\begin{frame}{Stability of continuous ODEs}
\[
\dot{y}(t) = f\big(y(t),t\big), \quad y(0) = y_0
\]
\begin{itemize}
\item stability: small perturbations of $y_0$ are not amplified over $[0,T]$
\item linear autonomous case $\dot{y}(t) = A y(t)$: \; $y(t) = e^{tA}y_0$, stable if
\[
\max_{i=1,\dots,n} \text{Re}\big(\lambda_i(A)\big) \leq 0
\]
\item nonlinear / non-autonomous case: same condition on the Jacobian $J(t) = \nabla_y f\big(y(t),t\big)^\intercal$, provided $J(t)$ changes sufficiently slowly in $t$
\end{itemize}
\end{frame}

\begin{frame}{Stability of NN forward propagation}
\textbf{Continuous problem}
\[
\dot{y}(t) = \sigma\big( K(t)^\intercal y(t) + b(t) \big), \quad y(0) = y_0
\]
stable if $K(t)$ changes sufficiently slowly and
\begin{align}
\max_{i=1,\dots,n} &\text{Re}\Big(\lambda_i\big(J(t)\big)\Big) \leq 0,\quad\forall t\in [0,T]\\[10pt]
\text{where } J(t) &= \nabla_y\Big(\sigma\big( K(t)^\intercal y(t) + b(t) \big)\Big)^\intercal\\
&= \text{diag}\Big(\underbrace{\sigma'\big( K(t)^\intercal y(t) + b(t) \big)}_{\geq 0}\Big) K(t)^\intercal
\end{align}
$
\to \max_{i=1,\dots,n} \text{Re}\Big(\lambda_i\big(K(t)\big)\Big) \leq 0,\quad\forall t\in [0,T].
$
\end{frame}

\begin{frame}{Stability of NN forward propagation}
\textbf{Forward Euler method}
\[
Y_{j+1} = Y_j + h\sigma(Y_j K_j + b_j), \quad j= 0,\dots,N-1
\]
stable if
\[
\max_{i=1,\dots,n} |1+h\lambda_i(J_j)| \leq 1,\quad \forall j=0,1,\dots,N-1
\]
\end{frame}

\begin{frame}{Example: Stability of ResNet}
\begin{figure}
\centering
\includegraphics[width = 0.8\textwidth]{figures/ResNet_stab.png}
\end{figure}
\begin{align}
K_{+}&= \begin{pmatrix} 2 & -2\\ 0& 2 \end{pmatrix} &
K_{-} &= \begin{pmatrix} -2 & 0 \\ 2 & -1 \end{pmatrix} &
K_0 &= \begin{pmatrix} 0 & -1 \\ 1 & 0 \end{pmatrix}\\
\lambda(K_+) &= 2,\, 2 & \lambda(K_-) &= -2,\, -1 & \lambda(K_0)&= i,-i
\end{align}
\begin{itemize}
\item $s = 3,\, n=2,\, h = 0.1, \, b = 0,\, \sigma = \tanh,\, N = 10$
\end{itemize}
\end{frame}

\begin{frame}{Well-posed forward propagation}
\begin{enumerate}
\item $\max_i \text{Re}(\lambda_i(K)) > 0$\\[10pt]
\begin{itemize}
\item neurons amplify the signal with no upper bound
\item unreliable generalization\\[20pt]
\end{itemize}
\item $\max_i \text{Re}(\lambda_i(K)) \ll 0$\\[10pt]
\begin{itemize}
\item inverse problem highly ill-posed
\item vanishing gradients problem
\item lossy network\\[20pt]
\end{itemize}
\end{enumerate}
$\implies \text{Re}(\lambda_i(K(t))) \approx 0,\quad\forall i=1,2,\dots,n,\;\forall t\in [0,T]$
\end{frame}

%\section{Stable architectures}
\begin{frame}{Antisymmetric weight matrices}
\[
\dot{y}(t) = \sigma \Big( \frac{1}{2}\big(\underbrace{ K(t) - K(t)^\intercal }_{ \mathclap{\text{antisymmetric $\to$ imaginary eigenvalues}} } - \gamma I\big)y(t) + b(t)\Big),\quad t\in [0,T]
\]
\bigskip
\begin{enumerate}
\item $\gamma = 0$
\begin{figure}
\centering
\includegraphics[scale = 0.25]{figures/RK_stab.png}
\end{figure}
\item $\gamma > 0$ \quad $\to$ Forward Euler discretization
$$Y_{j+1} = Y_j + h\sigma\Big(\frac{1}{2}Y_j (K_j - K_j^\intercal - \gamma I) + b_j\Big)$$
\end{enumerate}
\end{frame}
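\begin{frame}[fragile]{Sketch: checking the antisymmetric parametrization}
A quick numerical check (illustrative sketch, not from the paper): the weight matrix $\frac{1}{2}(K - K^\intercal - \gamma I)$ has eigenvalues with real part exactly $-\gamma/2 \leq 0$, so the continuous forward propagation does not amplify perturbations.
\begin{lstlisting}
import numpy as np

n, gamma, h = 4, 0.1, 0.05
rng = np.random.default_rng(1)
K = rng.standard_normal((n, n))
A = 0.5 * (K - K.T - gamma * np.eye(n))  # antisymmetric part shifted by -gamma/2

lam = np.linalg.eigvals(A)
print(lam.real)                          # all entries equal to -gamma/2

def layer(Y, b=0.0, sigma=np.tanh):
    """One forward Euler layer of the antisymmetric ResNet."""
    return Y + h * sigma(Y @ A + b)
\end{lstlisting}
\end{frame}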
\begin{frame}{Hamiltonian inspired NN}
\vspace{-0.9cm}
\[
\dot{y}(t) = -\nabla_z H(y,z,t), \quad \dot{z}(t) = \nabla_y H(y,z,t),\quad t\in [0,T]
\]
\begin{itemize}
\item Hamiltonian $H: \R^n\times \R^n\times [0,T]\to \R$ conserved\\[10pt]
\item energy conserved, not dissipated\\[20pt]
\end{itemize}
\end{frame}

\begin{frame}{Hamiltonian inspired NN}
Hamiltonian $H(y,z) = \frac{1}{2}z^\intercal z - f(y)$\\[5pt]
$$\dot{y}(t) = -z(t), \; \dot{z}(t) = -\nabla_y f(y(t))\quad\implies \ddot{y}(t) = \nabla_y f(y(t))$$
\begin{itemize}
\item $\ddot{y}(t) = \sigma\Big( K^\intercal (t) y(t) + b(t)\Big),\; y(0) = y_0,\; \dot{y}(0) = 0$\\[5pt]
\item stable for $K$ with non-positive real eigenvalues\\[5pt]
\item $K(C) = -C^\intercal C,\quad C\in\R^{n\times n}$\\[5pt]
\item nonlinear parametrization - complicated optimization\\[5pt]
\item leapfrog discretization scheme (symplectic integrator)
\end{itemize}
\end{frame}

\begin{frame}{Hamiltonian inspired NN}
\[
\dot{y}(t) = \sigma\Big( K (t) z(t) + b(t)\Big) \qquad \dot{z}(t) = -\sigma\Big( K^\intercal (t) y(t) + b(t)\Big)
\]
Associated ODE:
\begin{align}
\frac{\partial}{\partial t} \begin{pmatrix} y\\ z \end{pmatrix}(t)
&= \sigma \left(\begin{pmatrix} 0 & K(t) \\ -K(t)^\intercal & 0 \end{pmatrix} \begin{pmatrix} y\\ z \end{pmatrix}(t) + b(t) \right), \\
\begin{pmatrix} y\\ z \end{pmatrix}(0) &= \begin{pmatrix} y_0\\ 0 \end{pmatrix}
\end{align}
\begin{itemize}
\item antisymmetric matrix
\item Verlet integration scheme (symplectic)
$$ z_{j+1/2} = z_{j-1/2} - h\sigma(K_j^\intercal y_j + b_j),\quad y_{j+1} = y_j + h\sigma (K_j z_{j+1/2} + b_j)$$
\item $K_j$ non-square
\end{itemize}
\end{frame}

%\section{Regularization}
\begin{frame}{Regularization}
\begin{enumerate}
\item Forward propagation\\[10pt]
\begin{itemize}
\item standard: weight decay (Tikhonov regularization)
\[
R(K) = \frac{1}{2}\|K\|_F^2
\]
\item encourage $K$ to be sufficiently smooth in time
\[
R(K) = \frac{1}{2h}\sum \|K_j - K_{j-1}\|_F^2\quad R(b) = \frac{1}{2h}\sum \|b_j - b_{j-1}\|^2
\]
\end{itemize}
\item Classification\\[10pt]
\begin{itemize}
\item $h(y_j^\intercal w_k + \mu_k)\approx h \Big(\text{vol}(\Omega) \int_{\Omega} y(x)w(x)\mathrm{d}x + \mu_k \Big)$
\item $$R(w_k) = \frac{1}{2} \|L w_k\|^2\quad L - \text{discretized differential operator}$$
\end{itemize}
\item Multi-level learning
\end{enumerate}
\end{frame}

%\section{Numerical examples}
\begin{frame}{Concentric ellipses}
\begin{figure}
\centering
\includegraphics[width = \textwidth]{figures/Elipses.png}
\end{figure}
\begin{itemize}
\item 1200 points: 1000 training + 200 validation
\item multi-level: 4, 8, 16, \dots, 1024 layers
\item T = 20, n = 2, $\alpha = 10^{-3}$, $\sigma = \tanh$
\item standard ResNet, antisymmetric ResNet, Hamiltonian - Verlet network
\end{itemize}
\end{frame}

\begin{frame}{Convergence}
\begin{figure}
\centering
\includegraphics[width = \textwidth]{figures/Convergence.png}
\end{figure}
\end{frame}

\begin{frame}{Swiss roll}
\begin{figure}
\centering
\includegraphics[width = \textwidth]{figures/Swiss_roll.png}
\end{figure}
\begin{itemize}
\item 513 points: 257 training + 256 validation
\item multi-level: 4, 8, 16, \dots, 1024 layers
\item T = 20, n = 4,4,2, $\alpha = 5\cdot 10^{-3}$, $\sigma = \tanh$
\item standard ResNet, antisymmetric ResNet, Hamiltonian - Verlet network
\end{itemize}
\end{frame}

\begin{frame}{Peaks}
\begin{figure}
\centering
\includegraphics[width = \textwidth]{figures/Peaks.png}
\end{figure}
\begin{itemize}
\item 5000 samples: 20\% for validation
\item multi-level: 4, 8, 16, \dots, 1024 layers
\item T = 5, n = 8,8,2, $\alpha = 5\cdot 10^{-6}$, $\sigma = \tanh$
\item standard ResNet, antisymmetric ResNet, Hamiltonian - Verlet network
\end{itemize}
\end{frame}
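\begin{frame}[fragile]{Sketch: Hamiltonian - Verlet forward pass}
The Hamiltonian - Verlet network used in the previous examples can be sketched as follows (illustrative only; dimensions, initialization and $z_{-1/2}=0$ are assumptions, not the authors' code).
\begin{lstlisting}
import numpy as np

def verlet_forward(y0, K, b, h, sigma=np.tanh):
    """z_{j+1/2} = z_{j-1/2} - h*sigma(K_j^T y_j + b_j),
       y_{j+1}   = y_j       + h*sigma(K_j z_{j+1/2} + b_j)."""
    y = y0.copy()                    # y0: (n,) features
    z = np.zeros(K[0].shape[1])      # z_{-1/2} = 0, shape (m,)
    for Kj, bj in zip(K, b):         # K_j may be non-square, (n, m)
        z = z - h * sigma(Kj.T @ y + bj)
        y = y + h * sigma(Kj @ z + bj)
    return y

n, m, N, h = 2, 3, 16, 0.1
rng = np.random.default_rng(2)
K = [rng.standard_normal((n, m)) for _ in range(N)]
b = [0.0] * N
yN = verlet_forward(rng.standard_normal(n), K, b, h)
\end{lstlisting}
\end{frame}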
\begin{frame}{MNIST}
\begin{figure}
\centering
\includegraphics[scale = 0.8]{figures/MNIST.png}
\end{figure}
\end{frame}

\begin{frame}{MNIST}
\begin{itemize}
\item 60 000 labeled images: 50 000 training, 10 000 validation
\item 28 $\times$ 28, multi-level: 4,8,16
\item T = 6, n = 4704, $\alpha = 0.005$, $3 \times 3$ convolution operators
\item standard ResNet, antisymmetric ResNet, Hamiltonian - Verlet network
\bigskip
\end{itemize}
\begin{figure}
\centering
\includegraphics[width = 0.9\textwidth]{figures/MNIST_table.png}
\end{figure}
\end{frame}

-%\section{Neural ODEs}
+\section{Neural ODEs}

%%
%%
%%
\begin{frame}[fragile]{Motivation: ResNets and Euler's method}
$$\yj=\ym+ \underbrace{h\sigma\left( \ym\kj+\bj\right)}_\text{$=f(\ym,\te_j)$}, \quad \text{Euler discretization of $\frac{\diff \mathbf{Y}}{\diff t}=f(\mathbf{Y},\te(t))$}$$
\begin{columns}
\begin{column}{0.5\linewidth}
%\begin{lstlisting}
%#Defines the architecture
%def f(Y,t,θ):
%    return neural_net(Y,θ[t])
%
%#Defines the ResNet
%def ResNet(Y):
%    for t in [1:T]:
%        Y=Y+f(Y,t,θ)
%    return Y
%\end{lstlisting}
\vspace{0.5cm}
\texttt{
\blue{\#Defines the architecture}\\
\tinto{def} f(Y,t,θ):\\
\tinto{return} neural\_net(Y,θ[t])\\
\vspace{1cm}}
\texttt{
\blue{\#Defines the ResNet}\\
\tinto{def} ResNet(Y):\\
\pink{for} t in [1:T]:\\
\hspace{0.5 cm} Y=Y+f(Y,t,θ)\\
{\tinto{return} Y}}
\vspace{1cm}
\end{column}
\begin{column}{0.5\linewidth}
\begin{figure}
\centering
\uncover<2->{
\includegraphics[width=0.8\linewidth]{figures/eulers.png}
}
\end{figure}
\uncover<3->{
Can we do better?
}
\end{column}
\end{columns}
\end{frame}

%%
\begin{frame}[fragile]{Improving on Euler's method}
$$\yj=\ym+ \underbrace{h\sigma\left( \ym\kj+\bj\right)}_\text{$=f(\ym,\te_j)$}, \quad \text{Euler discretization of $\frac{\diff \mathbf{Y}}{\diff t}=f(\mathbf{Y},\te(t))$}$$
\begin{columns}
\begin{column}{0.6\linewidth}
%\begin{lstlisting}
%#Defines the architecture
%def f(Y,t,θ):
%    return neural_net([Y,t],θ[t])
%
%#Defines the ODE Net
%def ODE_Net(Y0):
%    return ODE_Solver(f,Y0,θ,t_0=0,t_f=1)
%\end{lstlisting}
\vspace{0.5cm}
\texttt{
\blue{\#Defines the architecture}\\
\tinto{def} f(Y,t,θ):\\
\tinto{return} neural\_net(\pink{[Y,t]},θ[t])\\
\vspace{0.5cm}}
\texttt{
\blue{\#Defines the ODE Net}\\
\tinto{def} ODE\_Net(Y0):\\
{\tinto{return} \alert{ODE\_Solver}(f,Y0,θ,t\_0=0,t\_f=1)}}
\vspace{0.5cm}
Here \texttt{\alert{ODE\_Solver}} is a black-box ODE solver.
\end{column}
\begin{column}{0.5\linewidth}
\centering
\uncover<2->{
\includegraphics[width=0.7\linewidth]{figures/adaptive.png}
}
\end{column}
\end{columns}
\uncover<3->{
\begin{center}
\pink{Main idea:} Continuous depth + good ODE solver.
\end{center}
}
\end{frame}
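\begin{frame}[fragile]{Sketch: an ODE-Net forward pass with a black-box solver}
A runnable sketch of the idea (illustrative only, not the authors' implementation): SciPy's \texttt{solve\_ivp} plays the role of \texttt{ODE\_Solver}, and the adaptive solver decides how many evaluations of $f$ to spend.
\begin{lstlisting}
import numpy as np
from scipy.integrate import solve_ivp

n = 2
rng = np.random.default_rng(0)
K, b = rng.standard_normal((n, n)), 0.1   # "theta": here constant in t

def f(t, y):
    # dynamics dy/dt = sigma(K^T y + b); the solver picks its own steps
    return np.tanh(K.T @ y + b)

def ODE_Net(y0, t0=0.0, tf=1.0):
    sol = solve_ivp(f, (t0, tf), y0, rtol=1e-6, atol=1e-8)
    return sol.y[:, -1], sol.nfev         # final state, number of f-evaluations

y1, nfe = ODE_Net(rng.standard_normal(n))
print(y1, nfe)    # "depth" is implicit: measured in function evaluations (NFE)
\end{lstlisting}
\end{frame}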
\begin{frame}[fragile]
\begin{columns}
\begin{column}{0.5\linewidth}
\textbf{ResNet:}
\texttt{
\blue{\#Defines the architecture}\\
\tinto{def} f(Y,t,θ):\\
\tinto{return} neural\_net(Y,θ[t])\\
\vspace{0.25cm}}
\texttt{
\blue{\#Defines the ResNet}\\
\tinto{def} ResNet(Y):\\
\pink{for} t in [1:T]:\\
\hspace{0.5 cm} Y=Y+f(Y,t,θ)\\
{\tinto{return} Y}}
\begin{figure}
\centering
\end{figure}
\end{column}
\begin{column}{0.5\linewidth}
\textbf{ODENet:}
\texttt{
\blue{\#Defines the architecture}\\
\tinto{def} f(Y,t,θ):\\
\tinto{return} neural\_net(\pink{[Y,t]},θ[t])\\
\vspace{0.25cm}}
\texttt{
\blue{\#Defines the ODENet}\\
\tinto{def} ODE\_Net(Y0):\\
{\tinto{return} \alert{ODE\_Solver}(f,Y0,θ,t\_0=0,t\_f=1)}}
\end{column}
\end{columns}
\begin{center}
\includegraphics[width=0.4\linewidth]{figures/ode_res}
\end{center}
\end{frame}

%%
%%
%%
%\begin{frame}{Some considerations}
%\begin{center}
%	\includegraphics[width=0.2\linewidth]{figures/resnett}\hspace{2cm}
%	\includegraphics[width=0.2\linewidth]{figures/odenett}
%\end{center}
%\end{frame}

\begin{frame}{Training the Neural Network: Adjoint Method}
We aim at minimizing $J:\R^p\to \R,$
$$J(\yy,t_f,\te)=J\left(\yy(t_0)+\int_{t_0}^{t_f}f(\yy,t,\te)\diff t \right)=J(\text{\texttt{\alert{ODE\_Solver}}}(f,\yy(t_0),\te,t_0=0,t_f=1)).$$
\textbf{Difficulties: }
\begin{enumerate}
\item \alert{\texttt{ODE\_Solver}} is a black-box.
\item There is no notion of layers, since we are in the continuous limit.
\end{enumerate}
$$\frac{\partial J}{\partial \te}=?$$
How does $\yy(t)$ depend on $\te$ at each instant $t$? Don't use back-prop, but rather the \tinto{adjoint-state method} (Pontryagin et al., 1962).
\end{frame}

%%
%%
%%
\begin{frame}{Training the Neural Network: Adjoint Method}
Define first
$$G(\yy,t_f,\te):=\int_{t_0}^{t_f} J(\yy,t,\te)\diff t, \quad \frac{\diff}{\diff t_f}G(\yy,t_f,\te)=J(\yy,t_f,\te)$$
and the Lagrangian
$$L=G(\yy,t_f,\te)+\int_{t_0}^{t_f}\aat(t)\left( \dot{\yy}(t,\te)-f(\yy,t,\te) \right)\diff t $$
Then,
\begin{align}
\frac{\partial L}{\partial \te}=\int_{t_0}^{t_f} \left(\frac{\partial J}{\partial \yy}\frac{\partial \yy}{\partial \te} +\frac{\partial J}{\partial \te}\right)\diff t +\int_{t_0}^{t_f}\aat(t)\left( \blue{\frac{\partial\dot{\yy}}{\partial \te}}-\frac{\partial f}{\partial \yy}\frac{\partial \yy}{\partial \te}- \frac{\partial f}{\partial \te} \right)\diff t
\end{align}
IBP:
\begin{align}
\int_{t_0}^{t_f}\aat(t)\blue{\frac{\partial\dot{\yy}}{\partial \te}}\diff t=\aat(t)\frac{\partial{\yy}}{\partial \te}\rvert_{t_0}^{t_f}-\int_{t_0}^{t_f}\dat(t)\blue{\frac{\partial {\yy}}{\partial \te}}\diff t
\end{align}
\end{frame}

\begin{frame}{Adjoint method (cont'd)}
\begin{align}
\frac{\partial L}{\partial \te}&=\int_{t_0}^{t_f} \left(\frac{\partial J}{\partial \yy}\frac{\partial \yy}{\partial \te} +\frac{\partial J}{\partial \te}\right)\diff t +\int_{t_0}^{t_f}\aat(t)\left( \blue{\frac{\partial\dot{\yy}}{\partial \te}}-\frac{\partial f}{\partial \yy}\frac{\partial \yy}{\partial \te}- \frac{\partial f}{\partial \te} \right)\diff t\\
%%
%%
&=\int_{t_0}^{t_f} \left(\frac{\partial \yy}{\partial \te}\right)\alert{\left(\frac{\partial J}{\partial \yy} -\aat\pd{f}{\yy}-\dat\right)}\diff t+\int_{t_0}^{t_f}-\aat \pd{f}{\te}+\pink{\pd{J}{\te}}\diff t +\left(\aat \pd{\yy}{\te}\right)_{t_0}^{t_f}
\end{align}
Setting $\alert{\left(\frac{\partial J}{\partial \yy} -\aat\pd{f}{\yy}-\dat\right)}=0$, one gets
$$ \tinto{\dat}\tinto{=\frac{\partial J}{\partial \yy} -\aat\pd{f}{\yy} \quad \text{(Adjoint Equation)}} $$
and, as such (using that $\yy(t_0)$ does not depend on $\te$, so $\pd{\yy}{\te}(t_0)=0$),
\begin{align}
\frac{\partial L}{\partial \te}&=\int_{t_0}^{t_f}-\aat \pd{f}{\te}+\pink{\pd{J}{\te}}\diff t +\left(\aat \pd{\yy}{\te}\right)_{\blue{t_0}}^{t_f}=\int_{t_0}^{t_f}-\aat \pd{f}{\te}+\pink{\pd{J}{\te}}\diff t +\left(\aat(t_f) \pd{\yy}{\te}(t_f)\right)
\end{align}
\end{frame}

%%
%%
%%
\begin{frame}{Adjoint method (cont'd)}
From $J(\yy,\te)=\frac{\diff}{\diff t_f} G(\yy,t_f,\te)$ and
\alert{\begin{align}
\int_{t_0}^{t_f}-\aat \pd{f}{\te}+\pink{\pd{J}{\te}}\diff t +\left(\aat(t_f) \pd{\yy}{\te}(t_f)\right),
\end{align}}
one then has
\begin{align}
\pd{J}{\te}&=\frac{\diff}{\diff t_f}\left(\int_{t_0}^{t_f}-\aat \pd{f}{\te}+\pink{\pd{J}{\te}}\diff t +\left(\aat(t_f) \pd{\yy}{\te}(t_f)\right)\right)=\dots\\&=\int_{t_0}^{t_f}-\aat \pd{f}{\te}+\pink{\pd{J}{\te}}\diff t, \quad \tinto{\dat}\tinto{=\frac{\partial J}{\partial \yy} -\aat\pd{f}{\yy} \qquad \text{(Adjoint Equation)}}
\end{align}
\uncover<2->{
Thus, the gradient can be computed by solving the ODEs for $\yy$ and $\aat$. This can be done simultaneously and does not require storing values at every ``layer'' (time step),
}
\uncover<3->{which \alert{implies} big savings in memory.}
\end{frame}
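\begin{frame}[fragile]{Sketch: adjoint-based training in code}
The authors released a PyTorch package, \texttt{torchdiffeq}, that wraps this adjoint computation. A minimal usage sketch (illustrative; the exact call signature shown here is an assumption based on that package's documented interface):
\begin{lstlisting}
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint   # adjoint-based gradients

class ODEFunc(nn.Module):
    """Defines the dynamics f(Y, t, theta)."""
    def __init__(self, n):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n, 32), nn.Tanh(), nn.Linear(32, n))
    def forward(self, t, y):
        return self.net(y)

func = ODEFunc(n=2)
y0 = torch.randn(16, 2)                  # batch of initial states
t = torch.tensor([0.0, 1.0])
yT = odeint(func, y0, t)[-1]             # ODE-Net forward pass: Y(t_f)
loss = yT.pow(2).mean()                  # stand-in loss J
loss.backward()                          # dJ/dtheta via the adjoint ODE,
                                         # solved backwards, O(1) memory in depth
\end{lstlisting}
\end{frame}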
\begin{frame}{Some considerations}
\begin{enumerate}
\item \alert{How deep are ODENets?} This is left to the ODE solver; complexity is measured in number of function evaluations (NFE).
\item \tinto{Accuracy-cost trade-off:} the forward pass can be evaluated at lower accuracy and cheaper cost.
\item \blue{Constant memory cost}, due to the adjoint method.
\item In practice, 2-4X more expensive to train than the corresponding ResNet.
\end{enumerate}
\includegraphics[width=1\linewidth]{figures/four_plots}
\end{frame}

%%
%%
%%
\begin{frame}{Application: Density transform}
\alert{Normalizing flows:} given $\yy_0\sim p_0$ and \tinto{$f$} s.t.\ $\yy_1=\tinto{f}(\yy_0)$, one has that
$$\log p_1(\yy_1)=\log p_0(\yy_0)- \log \left \lvert \blue{\det} \ \frac{\partial \tinto{f}}{\partial \yy_0}\right \rvert$$
\uncover<2->{
Thus, if one knows \tinto{$f$} and can compute \blue{$\det$}, one can evaluate the transformed density $p_1$.
}
\uncover<3->{
This has applications in Bayesian inference, image generation, etc.
}
\uncover<3->{
\textbf{Issue?} Computing $\blue{\det}$ can cost, at worst, $\mathcal{O}(M^3)$, where $M$ is the number of hidden units.}
\uncover<4->{
One solution is to make the Jacobian of \tinto{$f$} diagonal, but this reduces the expressiveness of the transformation.}
\uncover<5->{\pink{Continuous normalizing flows might help.}}
\end{frame}

%%
%%
%%
\begin{frame}{Change of variable formula via continuous transformation}
\textbf{Theorem:} Consider instead a \alert{continuous-in-time} transformation of $\yy(t)$ given by
$$\frac{\diff \yy}{\diff t}=f\left( \yy(t),t\right)$$
Then, under the assumption that $f$ is uniformly Lipschitz continuous in $t$, the change in log-probability follows:
$$\frac{\partial \log (p(\yy(t)))}{\partial t}=-\text{Tr}\left(\frac{\diff f}{\diff \yy (t)}\right).$$
\uncover<2->{
Notice that:
\begin{enumerate}
\item It involves a trace instead of a determinant $\implies$ cheaper.
\item $f$ need not be bijective; if the solution is unique, then the whole transformation is bijective.
\end{enumerate}
}
\end{frame}

%%
%%
%%
\begin{frame}{Proof}
$$\frac{\partial \log (p(\yy(t)))}{\partial t}=-\text{Tr}\left(\frac{\diff f}{\diff \yy (t)}\right).$$
Let $\yy(t+\epsilon)=T_\epsilon(\yy(t))$.
\begin{align}
\frac{\partial \log (p(\yy(t)))}{\partial t}&=\lim_{\epsilon \to 0+}\frac{\log p(\yy(t))-\log\lvert \det \frac{\partial T_\epsilon(\yy(t))}{\partial \yy}\rvert-\log(p(\yy(t)))}{\epsilon}\\
&=\lim_{\epsilon \to 0+}\frac{-\log\lvert \det \frac{\partial T_\epsilon(\yy(t))}{\partial \yy}\rvert}{\epsilon}
=-\lim_{\epsilon \to 0+} \frac{\frac{\partial}{\partial \epsilon}\lvert \det \frac{\partial T_\epsilon(\yy(t))}{\partial \yy}\rvert}{\lvert \det \frac{\partial T_\epsilon(\yy(t))}{\partial \yy}\rvert}\quad \text{ (L'H\^opital)}\\
&=-\underbrace{\left(\lim_{\epsilon \to 0+} \frac{\partial}{\partial \epsilon}\lvert \det \frac{\partial T_\epsilon(\yy(t))}{\partial \yy}\rvert\right)}_\text{bounded}\underbrace{\left(\lim_{\epsilon \to 0+}\lvert \det \frac{\partial T_\epsilon(\yy(t))}{\partial \yy}\rvert\right)^{-1}}_\text{=1}\\
&=-\left(\lim_{\epsilon \to 0+} \frac{\partial}{\partial \epsilon}\lvert \det \frac{\partial T_\epsilon(\yy(t))}{\partial \yy}\rvert\right)
\end{align}
\end{frame}

%%
%%
%%
\begin{frame}{Proof (cont'd)}
Recall Jacobi's formula for an $n\times n$ matrix $A$:
$\frac{\diff}{\diff t}\det{A(t)}=\text{Tr}\left( \text{Adj}(A(t))\frac{\diff A(t)}{\diff t}\right) .$
Then,
\begin{align}
&-\lim_{\epsilon \to 0+} \frac{\partial}{\partial \epsilon}\lvert \det \frac{\partial T_\epsilon(\yy(t))}{\partial \yy}\rvert=-\lim_{\epsilon \to 0+}\text{Tr}\left( \text{Adj}\left(\frac{\partial}{\partial \yy} T_\epsilon (\yy(t))\right)\frac{\partial}{\partial \epsilon}\frac{\partial}{\partial \yy}T_\epsilon(\yy(t))\right)\\
&=\text{Tr}\left(\underbrace{\left( \lim_{\epsilon \to 0+} \text{Adj} \left(\frac{\partial}{\partial \yy}T_\epsilon (\yy(t))\right) \right)}_\text{=I} \left(-\lim_{\epsilon \to 0+} \frac{\partial }{\partial \epsilon}\frac{\partial}{\partial \yy}T_\epsilon(\yy(t))\right) \right)\\
&=\text{Tr}\left(-\lim_{\epsilon \to 0+} \frac{\partial }{\partial \epsilon}\frac{\partial}{\partial \yy}\left(\yy+\epsilon f(\yy(t),t)+\mathcal{O}(\epsilon ^2)\right)\right)\\
&=\text{Tr}\left(-\lim_{\epsilon \to 0+} \frac{\partial }{\partial \epsilon}\left(I+\frac{\partial}{\partial \yy}\epsilon f(\yy(t),t)+\mathcal{O}(\epsilon ^2)\right)\right)\\
&=\text{Tr}\left(-\lim_{\epsilon \to 0+} \left(\frac{\partial}{\partial \yy} f(\yy(t),t)+\mathcal{O}(\epsilon) \right)\right)\\
&=-\text{Tr}\left( \frac{\partial}{\partial \yy} f(\yy(t),t)\right)
\end{align}
\end{frame}
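\begin{frame}[fragile]{Sketch: checking the trace formula numerically}
A quick sanity check of the theorem (illustrative sketch, not from the paper): for a linear field $f(\yy)=A\yy$ the flow map and the pushforward of $p_0=\mathcal{N}(0,I)$ are available in closed form, and integrating the trace formula gives $\log p_t(\yy(t)) - \log p_0(\yy_0) = -t\,\text{Tr}(A)$.
\begin{lstlisting}
import numpy as np
from scipy.linalg import expm
from scipy.stats import multivariate_normal as mvn

rng = np.random.default_rng(3)
A = rng.standard_normal((2, 2))
y0 = rng.standard_normal(2)
t = 0.7

Phi = expm(A * t)               # flow map of dy/dt = A y
yt = Phi @ y0                   # particle position at time t
Sigma_t = Phi @ Phi.T           # covariance of the pushforward of N(0, I)

lhs = (mvn.logpdf(yt, mean=np.zeros(2), cov=Sigma_t)
       - mvn.logpdf(y0, mean=np.zeros(2), cov=np.eye(2)))
rhs = -t * np.trace(A)          # -integral of Tr(df/dy) along the trajectory
print(lhs, rhs)                 # agree up to floating point error
\end{lstlisting}
\end{frame}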
%%
%%
%%
\begin{frame}{Example: Density Matching}
Given a \tinto{target} $p$, we construct a \alert{flow} $q$ by minimizing $J=\text{KL}(q\lVert p):=\int \log\left(\frac{q(\te)}{p(\te)}\right)q(\te)\diff \te$ (assuming we can evaluate both $p$ and $q$).
\uncover<2->{
\begin{enumerate}
\item \pink{Normalizing flow (NF):} $\log q(\yy_1)=\log p_0(\yy_0)- \log \left \lvert \blue{\det} \ \frac{\partial \tinto{f}}{\partial \yy_0}\right \rvert$
\item \blue{Continuous normalizing flow (CNF):} $q$ solves $\frac{\partial \log (q(\yy(t)))}{\partial t}=-\text{Tr}\left(\frac{\diff f}{\diff \yy (t)}\right).$
\end{enumerate}
}
\only<3>{
\begin{center}
\includegraphics[width=0.7\linewidth]{figures/comparisson_final}
\end{center}
}
\only<4>{
\begin{center}
\includegraphics[width=0.9\linewidth]{figures/noise_to_data}
\end{center}
}
\end{frame}
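\begin{frame}[fragile]{Sketch: the density-matching objective}
To make the objective concrete, a sketch of the Monte Carlo estimator of $\text{KL}(q\lVert p)$ used as the loss (illustrative only; here $q$ is a Gaussian stand-in for the flow, so the estimate can be checked against the closed form).
\begin{lstlisting}
import numpy as np
from scipy.stats import multivariate_normal as mvn

rng = np.random.default_rng(4)
mq, Cq = np.zeros(2), np.eye(2)                   # stand-in flow q
mp_, Cp = np.array([1.0, -0.5]), 2.0 * np.eye(2)  # target p

theta = rng.multivariate_normal(mq, Cq, size=50_000)        # samples from q
kl_mc = np.mean(mvn.logpdf(theta, mq, Cq) - mvn.logpdf(theta, mp_, Cp))

# closed-form KL(q || p) between the two Gaussians, for comparison
Cp_inv = np.linalg.inv(Cp)
kl_exact = 0.5 * (np.trace(Cp_inv @ Cq) + (mp_ - mq) @ Cp_inv @ (mp_ - mq)
                  - 2 + np.log(np.linalg.det(Cp) / np.linalg.det(Cq)))
print(kl_mc, kl_exact)                            # close for large sample sizes
\end{lstlisting}
\end{frame}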
%%
%%
%%
\begin{frame}{Other applications: Time series}
\begin{center}
\includegraphics[width=1\linewidth]{figures/time_model}
\end{center}
\end{frame}

\begin{frame}{Other applications: Time series}
\begin{center}
\includegraphics[width=0.6\linewidth]{figures/time_dyn}
\end{center}
\end{frame}

%%
%%
%%
\begin{frame}{Summary and conclusions}
This paper takes a more computational perspective than the previous one. The aim is to consider the time-continuous limit of a DNN and its interpretation as an ODE, which makes it possible to use \alert{black-box} ODE solving routines.
\begin{enumerate}
\item There is no notion of layers; the number of function evaluations serves as a measure of depth.
\item Allows trading accuracy for computational cost.
\item No control during the training phase (due to the black-box nature); more expensive than the equivalent ResNet.
\item Constant memory cost.
\item Nice applications to density transport and continuous-time models.
\end{enumerate}
\end{frame}

%%
%%
%%
\end{document}