Abstract:
Transformers have proven to be a successful model for a variety of tasks in
sequence modeling. However, computing the attention matrix, which is their key
component, has quadratic complexity with respect to the sequence length, thus
making them prohibitively expensive for long sequences. To address this, we
propose clustered attention, which, instead of computing the attention
for every query, groups queries into clusters and computes attention just for
the centroids. To further improve this approximation, we use the computed
clusters to identify the keys with the highest attention per query and compute
the exact key/query dot products. This results in a model with linear
complexity with respect to the sequence length for a fixed number of clusters.
We evaluate our approach on two automatic speech recognition datasets and show
that our model consistently outperforms vanilla transformers for a given
computational budget. Finally, we demonstrate that our model can approximate
arbitrarily complex attention distributions with a minimal number of clusters
by approximating a pretrained BERT model on the GLUE and SQuAD benchmarks with only
25 clusters and no loss in performance.
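To make the core idea concrete, below is a minimal sketch of centroid-based attention, assuming PyTorch, a plain k-means (Lloyd) clustering of the queries, and a single attention head. The function name, the clustering initialization, and the shapes are illustrative; the paper's actual method also hashes queries before clustering and refines the result with exact dot products for the top keys, both of which are omitted here.

```python
import torch

def clustered_attention(Q, K, V, num_clusters=25, iters=10):
    """Approximate softmax attention by attending with query-cluster centroids.

    Q: (N, d) queries, K: (M, d) keys, V: (M, d) values.
    Illustrative sketch only; the top-k refinement step is not included.
    """
    N, d = Q.shape
    # Group queries into clusters with a few Lloyd (k-means) iterations.
    centroids = Q[torch.randperm(N)[:num_clusters]].clone()
    for _ in range(iters):
        assign = torch.cdist(Q, centroids).argmin(dim=1)       # (N,) cluster id per query
        for c in range(num_clusters):
            members = Q[assign == c]
            if len(members) > 0:
                centroids[c] = members.mean(dim=0)
    # Compute attention once per centroid instead of once per query:
    # cost is O(C * M) rather than O(N * M).
    attn = torch.softmax(centroids @ K.T / d ** 0.5, dim=-1)   # (C, M)
    out_per_cluster = attn @ V                                  # (C, d)
    # Broadcast each centroid's output to every query in its cluster.
    return out_per_cluster[assign]                              # (N, d)

# Example usage with random tensors.
Q, K, V = (torch.randn(512, 64) for _ in range(3))
print(clustered_attention(Q, K, V).shape)  # torch.Size([512, 64])
```

For a fixed number of clusters C, the attention cost grows linearly with the sequence length, which is the source of the linear complexity claimed above.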