Authors: Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat, Ming-Wei Chang
Paper reference: https://arxiv.org/pdf/2002.08909.pdf
Contribution
This paper proposes REALM (Retrieval-Augmented Language Model pre-training), a framework that augments language model pre-training with a latent knowledge retriever. The retriever allows the model to retrieve and attend over documents from a large corpus such as Wikipedia during pre-training, fine-tuning, and inference.
Experiments show the effectiveness of REALM by fine-tuning on the challenging task of Open-domain Question Answering (Open-QA).
Details
For both pre-training and fine-tuning, REALM takes some input $x$ and learns a distribution $p(y | x)$ over possible outputs $y$.
(1) For pre-training, the task is masked language modeling;
(2) For fine-tuning, the task is Open-QA: $x$ is a question, and $y$ is the answer.
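Concretely, both tasks share the same retrieve-then-predict decomposition from the paper: REALM first retrieves a document $z$ from a knowledge corpus $\mathcal{Z}$ with probability $p(z \mid x)$, then predicts $y$ conditioned on both $x$ and $z$, marginalizing over $z$:

$$
p(y \mid x) = \sum_{z \in \mathcal{Z}} p(y \mid z, x)\, p(z \mid x),
\qquad
p(z \mid x) = \frac{\exp f(x, z)}{\sum_{z^{\prime}} \exp f(x, z^{\prime})},
$$

where the relevance score $f(x, z)$ is the inner product of the BERT-style embeddings of $x$ and $z$. Summing over the whole corpus is intractable, so the sum is approximated by marginalizing over only the top $k$ documents under $p(z \mid x)$.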
Since the model structures of the knowledge retriever and the text generator are similar to those in the previous post (Joint Retrieval and Generation Training for Grounded Text Generation), I focus only on the pre-training of the model and on how document encodings are updated asynchronously (which applies to both works).
Injecting Bias into Pre-Training
The paper develops pre-training strategies that guide the model toward meaningful retrievals.
(1) Salient span masking. Instead of uniform masking, spans that require world knowledge to predict (e.g., named entities and dates) are masked, identified with a BERT-based tagger (a toy sketch of this masking appears at the end of this section).
(2) Null document. An empty null document is added to the top $k$ retrieved documents, since some masked spans do not require any retrieved knowledge to predict.
(3) Prohibiting trivial retrievals. If the masked sentence $x$ comes from document $z$, the knowledge-augmented encoder can trivially predict $y$ by looking at the unmasked version of $x$ in $z$, which only encourages the model to look for exact string matches between $x$ and $z$. This trivial candidate is therefore excluded during pre-training.
(4) Initialization (avoiding the cold-start problem). The paper uses the Inverse Cloze Task (ICT), in which the model is trained to retrieve the document a given sentence came from, to warm-start the retriever's embeddings.
Cold-Start Problem:
At the beginning of training, if the retriever does not have good embeddings for $x$ and $z$, the retrieved documents $z$ will likely be unrelated to $x$. This causes the knowledge-augmented encoder to learn to ignore the retrieved documents. Once this occurs, the knowledge retriever does not receive a meaningful gradient and cannot improve, creating a vicious cycle.
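To make strategy (1) concrete, here is a minimal sketch of salient span masking. The paper uses a BERT-based NER tagger; as a stand-in, this sketch uses spaCy's off-the-shelf NER, a `[MASK]` placeholder, and whole-span masking, all of which are simplifications of the paper's setup.

```python
import random

import spacy  # stand-in for the paper's BERT-based NER tagger

# Requires: pip install spacy && python -m spacy download en_core_web_sm
nlp = spacy.load("en_core_web_sm")

def salient_span_mask(text: str, mask_token: str = "[MASK]") -> str:
    """Mask one randomly chosen salient span (named entity or date)."""
    doc = nlp(text)
    # Entities such as PERSON, GPE, ORG, and DATE require world
    # knowledge to predict, unlike uniformly sampled tokens.
    spans = list(doc.ents)
    if not spans:
        return text  # nothing salient to mask; fall back to the input
    ent = random.choice(spans)
    return text[: ent.start_char] + mask_token + text[ent.end_char :]

print(salient_span_mask("Albert Einstein was born in Ulm in 1879."))
# e.g. "Albert Einstein was born in [MASK] in 1879."
```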
Asynchronous Update
Recall that both the previous post and this paper use Maximum Inner Product Search (MIPS) algorithms to find the approximate top $k$ documents; the document embeddings are pre-computed and stored in a MIPS index. Both works share a similar asynchronous update method, described below.
Re-embedding every document after each gradient update of the retriever is not realistic. The solution is to "refresh" the index by asynchronously re-embedding and re-indexing all documents every several hundred training steps.
Both works asynchronously refresh the MIPS index by running two jobs in parallel: a primary trainer job, which performs gradient updates on the parameters, and a secondary index builder job, which embeds and indexes the documents:
(1) The trainer sends the index builder a snapshot of its parameters, $\theta^{\prime}$;
(2) The trainer then continues training while the index builder uses $\theta^{\prime}$ to construct a new index in the background.
As soon as the index builder is done, it sends the new index back to the trainer, and the process repeats. The MIPS index is slightly stale between refreshes, but it is only used to select the top $k$ documents. After documents are selected, all computations use the updated $\theta$.
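Here is a minimal single-process sketch of this trainer / index-builder interplay. The `embed_docs`, `train_step`, and `IndexBuilder` names are hypothetical stand-ins (not from the papers), and a brute-force inner-product top-$k$ search stands in for a real approximate MIPS library (e.g., FAISS or ScaNN).

```python
import threading

import numpy as np

EMBED_DIM, NUM_DOCS, TOP_K = 128, 10_000, 8

def embed_docs(snapshot: np.ndarray) -> np.ndarray:
    """Index builder job: re-embed the whole corpus using a parameter
    snapshot theta'. Stand-in that returns random document embeddings."""
    rng = np.random.default_rng(0)
    return rng.standard_normal((NUM_DOCS, EMBED_DIM)).astype(np.float32)

def retrieve_top_k(index: np.ndarray, query: np.ndarray, k: int = TOP_K):
    """Brute-force inner-product top-k, standing in for approximate MIPS."""
    return np.argpartition(index @ query, -k)[-k:]

def train_step(params: np.ndarray, index: np.ndarray) -> np.ndarray:
    """Trainer job: the (slightly stale) index is only used to select the
    top-k documents; all later computation uses the fresh params."""
    query = params  # placeholder for the query embedding of x
    _docs = retrieve_top_k(index, query)
    return params + 0.01  # placeholder for a gradient update

class IndexBuilder:
    """Runs embed_docs in a background thread, like the secondary job."""
    def __init__(self):
        self.result, self.thread = None, None
    def submit(self, snapshot: np.ndarray) -> None:
        def job():
            self.result = embed_docs(snapshot)
        self.thread = threading.Thread(target=job)
        self.thread.start()
    def done(self) -> bool:
        return self.thread is not None and not self.thread.is_alive()

params = np.zeros(EMBED_DIM, dtype=np.float32)
index = embed_docs(params)       # initial index (e.g., after ICT warm-start)
builder = IndexBuilder()
builder.submit(params.copy())    # (1) send a snapshot theta' to the builder

for step in range(2_000):
    params = train_step(params, index)  # (2) training continues meanwhile
    if builder.done():                  # new index ready: swap it in,
        index = builder.result          # then start the next refresh cycle
        builder.submit(params.copy())
```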
After pre-training, the MIPS index is built once and kept fixed for both fine-tuning and inference; only the query-side encoder continues to be updated during fine-tuning.