Linear Attention Is (maybe) All You Need (to understand Transformer optimization)
Speaker
Kwangjun Ahn
LIDS & EECS
Host
Thien Le
CSAIL MIT
Abstract: Transformer training is notoriously difficult, requiring careful optimizer design and the use of various heuristics. We make progress towards understanding the subtleties of training transformers by carefully studying a simple yet canonical linearized shallow transformer model. Specifically, we train linear transformers to solve regression tasks, inspired by J. von Oswald et al. (ICML 2023) and K. Ahn et al. (NeurIPS 2023). Most importantly, we observe that the linearized models mimic several prominent aspects of transformers vis-à-vis their training dynamics. Consequently, the results of this paper hold the promise of identifying a simple transformer model that might serve as a valuable, realistic proxy for understanding transformers.
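For readers unfamiliar with the setup, here is a minimal sketch of one residual linear self-attention layer applied to an in-context linear regression prompt, in the spirit of the works cited above. Everything concrete here is an assumption made for illustration, not the speaker's code: the (d+1) x (n+1) prompt layout with a query column whose label is zeroed, the mask M, the 1/n normalization, reading the prediction off the bottom-right entry, and noiseless Gaussian data. The helper names `linear_attention_layer` and `make_prompt` are hypothetical. With hand-picked weights P and Q, the layer's output matches the prediction of one gradient-descent step from w = 0, which is the kind of correspondence von Oswald et al. pointed out.

```python
import numpy as np

def linear_attention_layer(Z, P, Q):
    """Assumed residual linear self-attention: Z + (1/n) P Z M (Z^T Q Z).

    Z has shape (d+1, n+1): n context columns [x_i; y_i] plus a query column
    [x_q; 0]. M zeroes the query column so its hidden label is never used as a value.
    """
    n = Z.shape[1] - 1
    M = np.eye(n + 1)
    M[n, n] = 0.0
    return Z + (P @ Z @ M @ (Z.T @ Q @ Z)) / n

def make_prompt(rng, d, n):
    """Sample a noiseless linear regression prompt (Gaussian data is an assumption)."""
    w = rng.standard_normal(d)
    X = rng.standard_normal((d, n + 1))  # last column is the query input x_q
    y = w @ X                            # true labels, including the query's
    Z = np.vstack([X, y[None, :]])       # copy, so y keeps the query label
    Z[d, n] = 0.0                        # hide the query label from the model
    return Z, X, y

rng = np.random.default_rng(0)
d, n, eta = 5, 20, 0.1
Z, X, y = make_prompt(rng, d, n)

# Hand-picked weights: Q scores x_i^T x_q, P routes the label row into the output.
P = np.zeros((d + 1, d + 1)); P[d, d] = 1.0
Q = np.zeros((d + 1, d + 1)); Q[:d, :d] = eta * np.eye(d)
pred_tf = linear_attention_layer(Z, P, Q)[d, n]

# One gradient-descent step from w = 0 on (1/2n) * sum_i (w^T x_i - y_i)^2.
w1 = eta * (X[:, :n] @ y[:n]) / n
pred_gd = w1 @ X[:, n]
print(np.isclose(pred_tf, pred_gd))  # True
```

In the cited works, P and Q are learned rather than hand-set, and several such layers are stacked; the hand-picked weights above only show that the architecture can express a gradient-style update.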
Speaker bio: Kwangjun Ahn is a final-year PhD student at MIT in the Department of EECS (Electrical Engineering & Computer Science) and the Laboratory for Information and Decision Systems (LIDS). His advisors are Profs. Suvrit Sra and Ali Jadbabaie. He also works part-time at Google Research, where he is accelerating LLM inference with the Speech & Language Algorithms Team. His current research interests include understanding LLM optimization and how to speed it up. Over the years he has worked on various topics, including machine learning theory, optimization, statistics, and learning for control.