Towards a Statistical Theory of Learning to Learn In-context with Transformers
Abstract
Classical learning theory focuses on supervised learning of functions via empirical risk minimization, where the labeled examples for a particular task are drawn from the data distribution the model experiences during training. Recently, in-context learning has emerged as a paradigm shift in large pre-trained models: when conditioned on a few labeled examples of a task, potentially unseen during training, the model infers the task at hand and makes predictions on new points. Learning to learn in-context, on the other hand, aims at training models in a meta-learning setup so that they generalize to new, unseen tasks from only a few labeled examples. In this paper, we present a statistical learning framework for the problem of in-context meta-learning and define a function class that enables it. The meta-learner is abstracted as a function defined on the product of the space of probability measures (representing the context) and the data space. The data distribution of each task is sampled from a meta-distribution on tasks. Thanks to the regularity we assume on the function class in the Wasserstein geometry, we leverage tools from optimal transport to study the generalization of the meta-learner to unseen tasks. Finally, we show that encoder transformers exhibit this type of regularity, and we leverage our theory to analyze their generalization properties.
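As an illustrative sketch only, the setup described in the abstract can be written in symbols as follows. All notation here (the spaces $\mathcal{X}$, $\mathcal{Y}$, the meta-distribution $\pi$, the loss $\ell$, the constant $L$, and the choice of the 1-Wasserstein distance $W_1$) is assumed for illustration and is not taken from the paper's body. In particular, the Lipschitz condition is one possible reading of "regularity in the Wasserstein geometry", not necessarily the exact assumption used.

```latex
% Assumed notation: data space Z = X x Y, tasks are distributions mu on Z.
\[
  \mathcal{Z} = \mathcal{X} \times \mathcal{Y}, \qquad
  \mu \sim \pi, \quad \pi \in \mathcal{P}\bigl(\mathcal{P}(\mathcal{Z})\bigr)
  \quad \text{(meta-distribution on tasks)}.
\]
% Context: n labeled examples from the task, viewed as an empirical measure.
\[
  \hat{\mu}_n = \frac{1}{n} \sum_{i=1}^{n} \delta_{(x_i, y_i)}, \qquad
  (x_i, y_i) \overset{\text{i.i.d.}}{\sim} \mu .
\]
% Meta-learner: a function of (context measure, query point).
\[
  f : \mathcal{P}(\mathcal{Z}) \times \mathcal{X} \to \mathcal{Y}, \qquad
  (\hat{\mu}_n, x) \mapsto f(\hat{\mu}_n, x).
\]
% One possible regularity assumption (hypothetical): Lipschitz in W_1.
\[
  \bigl| f(\mu, x) - f(\nu, x) \bigr| \le L \, W_1(\mu, \nu)
  \quad \text{for all } \mu, \nu \in \mathcal{P}(\mathcal{Z}),\ x \in \mathcal{X}.
\]
% Meta-risk over unseen tasks, with the context drawn from the task itself.
\[
  R(f) = \mathbb{E}_{\mu \sim \pi} \,
         \mathbb{E}_{\hat{\mu}_n \sim \mu^{\otimes n}} \,
         \mathbb{E}_{(x, y) \sim \mu}
         \bigl[ \ell\bigl( f(\hat{\mu}_n, x),\, y \bigr) \bigr].
\]
```

Under a condition of this kind, the gap between conditioning on the finite context $\hat{\mu}_n$ and on the true task distribution $\mu$ is controlled by $W_1(\hat{\mu}_n, \mu)$, which is where optimal transport tools would naturally enter.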