
CASPR: Combining Axis Preconditioners using Kronecker Sums/Products for Training Large Neural Networks

Inderjit S. Dhillon, Sai S. Duvvuri

Abstract

Deep Neural Networks (DNNs) have transformed fields like computer vision, natural language processing, and scientific research by enabling systems to learn complex patterns, make high-level predictions, and analyze large data sets. DNNs have driven advances in materials science, chemistry, and physics, significantly aiding scientific discovery. However, their large parameter spaces make them difficult to optimize, and their training can require extensive computational resources, so training DNNs effectively remains a central contemporary challenge.

Most DNNs, including Large Language Models, are trained using adaptive regularization methods such as Adam, which can be regarded as diagonally preconditioned stochastic gradient descent. This diagonal preconditioner comes from a diagonal approximation of the gradient outer-product matrix. However, a recent open competition called “AlgoPerf: Training Algorithms benchmark competition” [1] yielded an intriguing finding: a non-diagonal preconditioning method called Shampoo [2], which uses a Kronecker product approximation of the outer-product matrix, was found to be the best method on a varied suite of benchmark problems.
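To make the contrast concrete, here is a brief sketch in my own notation (not taken verbatim from [1]–[3]; damping terms and momentum are omitted). For parameters \(\theta_t\) with stochastic gradients \(g_t\), full-matrix Adagrad preconditions with the accumulated gradient outer-product matrix, while diagonal Adagrad, and Adam via exponential moving averages, keep only its diagonal:
\[
H_t = \sum_{s \le t} g_s g_s^\top, \qquad
\theta_{t+1} = \theta_t - \eta\, H_t^{-1/2} g_t \quad \text{(full-matrix)}, \qquad
\theta_{t+1} = \theta_t - \eta\, \mathrm{diag}(H_t)^{-1/2} g_t \quad \text{(diagonal)}.
\]
For a matrix-shaped parameter \(W_t \in \mathbb{R}^{m \times n}\) with gradient \(G_t\), Shampoo instead maintains one small statistic per axis and preconditions the gradient from both sides:
\[
L_t = \sum_{s \le t} G_s G_s^\top, \qquad
R_t = \sum_{s \le t} G_s^\top G_s, \qquad
W_{t+1} = W_t - \eta\, L_t^{-1/4}\, G_t\, R_t^{-1/4},
\]
which, in vectorized form, corresponds (up to the choice of vec convention) to preconditioning with the Kronecker product \(L_t^{1/4} \otimes R_t^{1/4}\) rather than with the full \(mn \times mn\) matrix \(H_t^{1/2}\).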

In this talk, I will introduce adaptive methods and show how Kronecker products can be used to obtain a computationally efficient preconditioner. I will then describe a general technique called Combining AxeS PReconditioners (CASPR) [3], which optimizes matrix-shaped DNN parameters by finding a different preconditioner for each mode/axis of the parameter and combining them using a Kronecker-sum based approximation. The Kronecker-sum based combination allows us to show that CASPR lies between Shampoo’s Kronecker-product based preconditioner and the full-matrix Adagrad preconditioner in Loewner order, and is therefore a closer approximation to full-matrix Adagrad than Shampoo. Experimental results demonstrate that CASPR approximates the gradient second-moment matrix more accurately, and that it improves training and generalization performance compared to existing practical adaptive regularization methods on a variety of tasks, including a graph neural network on OGBG-molpcba, a Transformer on a Universal Dependencies dataset, and auto-regressive large language modeling on the C4 dataset.
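To indicate how the Kronecker-sum combination works, here is a schematic sketch in my own notation; the exact powers and scaling constants are those given in [3]. From the axis statistics \(L_t\) and \(R_t\) above, both Shampoo and CASPR form one preconditioner per axis, e.g. \(P_L = L_t^{-1/4}\) and \(P_R = R_t^{-1/4}\). Shampoo combines them multiplicatively as the Kronecker product \(P_L \otimes P_R\), acting on the gradient as \(G_t \mapsto P_L G_t P_R\), whereas CASPR combines them additively through a Kronecker sum,
\[
P_L \oplus P_R \;=\; P_L \otimes I_n + I_m \otimes P_R,
\]
which acts as \(G_t \mapsto P_L G_t + G_t P_R\) (CASPR applies a suitably scaled power of this operator so that the overall strength of the step matches Shampoo’s). Because \(P_L \otimes I_n\) and \(I_m \otimes P_R\) are positive semidefinite and commute, the matrix arithmetic-geometric mean inequality gives
\[
\tfrac{1}{2}\left(P_L \otimes I_n + I_m \otimes P_R\right) \;\succeq\; P_L^{1/2} \otimes P_R^{1/2},
\]
and an inequality of this type is what places the Kronecker-sum combination between Shampoo’s Kronecker-product combination and the full-matrix Adagrad preconditioner in Loewner order.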

References