Multi-armed bandits for online optimization of language model pre-training: the use case of dynamic masking

03/24/2022
by   Iñigo Urteaga, et al.

Transformer-based language models (TLMs) provide state-of-the-art performance in many modern natural language processing applications. TLM training is conducted in two phases. First, the model is pre-trained over large volumes of text to minimize a generic objective function, such as the Masked Language Model (MLM) objective. Second, the model is fine-tuned on specific downstream tasks. Pre-training requires large volumes of data and substantial computational resources, and it involves many design choices that remain unresolved. For instance, hyperparameters for language model pre-training are often selected by heuristics or grid search. In this work, we propose a multi-armed bandit-based online optimization framework for the sequential selection of pre-training hyperparameters to optimize language model performance. We pose the pre-training procedure as a sequential decision-making task, where at each pre-training step an agent must determine which hyperparameters to use to optimize the pre-training objective. We propose a Thompson sampling bandit algorithm, based on a surrogate Gaussian process reward model of the MLM pre-training objective, for its sequential minimization. We empirically show that the proposed Gaussian process-based Thompson sampling pre-trains robust and well-performing language models. Namely, by sequentially selecting the masking hyperparameters of the TLM, we achieve satisfactory performance in fewer epochs, not only in terms of the pre-training MLM objective, but also in diverse downstream fine-tuning tasks. The proposed bandit-based technique provides an automated hyperparameter selection method for pre-training TLMs that is of interest to practitioners. In addition, our results indicate that, instead of MLM pre-training with fixed masking probabilities, sequentially adapting the masking hyperparameters improves both pre-training loss and downstream task metrics.
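
The bandit loop described in the abstract can be illustrated with a minimal sketch. It is not the authors' implementation: the arm grid of masking probabilities, the RBF kernel and its settings, the noise level, and the `toy_reward` function (a stand-in for whatever reward is derived from the MLM objective) are all assumptions made for illustration. At each step the agent fits a Gaussian process to the rewards seen so far, draws one sample of the reward function over the candidate masking probabilities, and plays the arm where that sample is highest.

```python
import numpy as np


def rbf_kernel(a, b, length_scale=0.1, variance=1.0):
    """Squared-exponential kernel between two 1-D arrays of inputs."""
    d = a[:, None] - b[None, :]
    return variance * np.exp(-0.5 * (d / length_scale) ** 2)


def gp_posterior(x_obs, y_obs, x_arms, noise=1e-2):
    """Zero-mean GP posterior mean and covariance at the candidate arms."""
    if len(x_obs) == 0:
        # No data yet: return the prior.
        return np.zeros(len(x_arms)), rbf_kernel(x_arms, x_arms)
    K = rbf_kernel(x_obs, x_obs) + noise * np.eye(len(x_obs))
    K_s = rbf_kernel(x_obs, x_arms)
    K_ss = rbf_kernel(x_arms, x_arms)
    solved = np.linalg.solve(K, K_s)  # K^{-1} K_s
    mean = solved.T @ y_obs
    cov = K_ss - K_s.T @ solved
    return mean, cov


def thompson_select(x_obs, y_obs, x_arms, rng):
    """Draw one posterior sample over the arms and return the argmax index."""
    mean, cov = gp_posterior(np.asarray(x_obs, dtype=float),
                             np.asarray(y_obs, dtype=float), x_arms)
    # Symmetrize and add jitter for numerical stability.
    cov = 0.5 * (cov + cov.T) + 1e-8 * np.eye(len(x_arms))
    sample = rng.multivariate_normal(mean, cov)
    return int(np.argmax(sample))


def toy_reward(p, rng):
    """Hypothetical noisy reward, peaked at masking probability 0.2.

    In a real run this would be computed from the MLM pre-training loss
    observed after training with masking probability p.
    """
    return -(p - 0.2) ** 2 + 0.01 * rng.normal()


def run_bandit(n_steps=30, seed=0):
    """Run the Thompson sampling loop over a grid of masking probabilities."""
    rng = np.random.default_rng(seed)
    arms = np.linspace(0.05, 0.5, 10)  # assumed candidate masking probabilities
    xs, ys = [], []
    for _ in range(n_steps):
        i = thompson_select(xs, ys, arms, rng)
        xs.append(arms[i])
        ys.append(toy_reward(arms[i], rng))
    return arms, np.array(xs), np.array(ys)


if __name__ == "__main__":
    arms, played, rewards = run_bandit()
    print("masking probabilities played:", np.round(played, 2))
```

Sampling from the posterior, rather than always playing the arm with the highest posterior mean, is what gives Thompson sampling its exploration: arms with uncertain rewards are occasionally sampled high and therefore tried.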
