いまさらではあるものの、Self-attention (自己注意)の計算方法について途中の行列の形状に着目して調べてみました。
結果だけ知りたい方は、最後のまとめに進んでください。
参考文献
元論文はAttention Is All You Needです。
論文の行間を読むか参考文献を遡っていけば、具体的にどういう計算をするのか分かるのかもしれませんが、 論文の最後に、コードへのURL https://github.com/tensorflow/tensor2tensor が記載されていますので、こちらを主に参考にして、計算方法を調べていきます。
Q, K, Vはどうやって計算するの?
論文の式(1)で使われているQ=Query、K=Key、V=Valueの3つの値の計算方法を調べてみます。
論文のFig.1を見ると、入力が3つに分岐しているので、何かをどうにかして入力をQ, K, Vの3つにしていることはわかります。
まず、transformer_encode関数を見てみます。コメントによると、入力であるinputsの形状は(batch_size, input_length, 1, hidden_dim)とのことです。
関数の呼び出し直後に形状を変換していて、結局、inputsは(batch_size, input_length, hidden_dim)になっています。
その後、encoder_functionが呼び出されるのですが、これの中身は、 transformer_encoder です。この関数内の213行目から、common_attention.multihead_attentionを呼び出します。
キャッシュがなく、self-attentionの場合であれば、4650行目にて、compute_qkvが呼び出されます。 ここで、Q, K, Vが計算されているようです。
定義は
def compute_qkv(query_antecedent, ←これがcommon_kayers.layer_preprocess(x, hparams) memory_antecedent, ←これがNone total_key_depth, total_value_depth, ...となっていて、memory_antecedentがNoneであることは、呼び出し元のtransformer_encoder:213に戻るとわかります。
memory_antecedentがNoneならquery_antecedentにしているので、self-attentionの場合、入力はquery_antecedentと考えればよさそうです。
さて、Q, K, V (コード中ではq, k, v)はcompute_attention_componentで計算されていて、入力は
antecedent: a Tensor with shape [batch, length, channels]戻り値が
c : [batch, length, depth] tensorとなっています。filter_widthによって処理内容が異なるようですが、filter_width == 1のケースを見てみると、 4415行目で
のように書かれています。バイアスなしなので、単にMatMulの計算をしているだけになります。return common_layers.dense( antecedent, total_depth, use_bias=False, name=name, layer_collection=layer_collection)
Tensorflowのドキュメントによると、
Dense implements the operation: output = activation(dot(input, kernel) + bias) <中略> kernel is a weights matrix created by the layer,なので、
以上を考慮してcompute_qkvを読むと、q, k, vの形状は
q = (batch, length_q, total_key_depth) k = (batch, length_kv, total_key_depth) v = (batch, length_kv, total_value_depth)となっていることがわかります。
Self-attentionの計算はどうなるの?
multihead_attentionの引数に指定するattention_typeには色々種類があるようですが、デフォルト指定されているdot_productを見てみます。
dot_product_attentionの引数の説明には、
と書かれています。Args: q: Tensor with shape [..., length_q, depth_k]. k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must match with q. v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must match with q.
まず、1648行目で、
と計算しています。論文中の式(1)のQKTの部分です。logits = tf.matmul(q, k, transpose_b=True)
式(1)の√dkで割る部分が見当たりませんが、どこかで計算されているとして、次にsoftmaxの計算をみてみます。これは、1654行目で
で計算されています。weightsの形状はlogitsと同じです。weights = tf.nn.softmax(logits, name="attention_weights")
ドロップアウトの処理をした後、1667行目の
で、式(1)の計算が完了します。これは、return tf.matmul(weights, v)
dot_product_attentionのコメント部分にも
と書いてあります。Returns: Tensor with shape [..., length_q, depth_v].
以上の計算の途中で得られる[weights]の形状が(..., length_q, length_kv)となっており、 self-attentionの場合はlength_q=length_kvですので、系列長の2乗で必要になるメモリや計算量が増えていくことになります。
まとめ
コードを調べたことで、論文の式(1)の各行列の形状は、
Q∈RLq×dk
K∈RLkv×dk
V∈RLkv×dv
QKT∈RLq×Lkv
(QKT)V∈RLq×dv
であることが明確になりました。ここで、LqはQの系列長、LkvはKとVの系列長です。 dkとdvは論文と同じです。
さらに、self-attentionの場合、Lq=Lkvですので、単にLとすれば、
Q∈RL×dk
K∈RL×dk
V∈RL×dv
QKT∈RL×L
(QKT)V∈RL×dv
のようにLの添字をなくせるのですっきりします。
QとKとVは、入力X∈RL×CをバイアスなしのDenseレイヤーに通すことで得ることができます。
つまり、Denseレイヤーの重み行列をそれぞれ
Wq∈RC×dk、
Wk∈RC×dk、
Wv∈RC×dvとすると、
Q=XWq
K=XWk
V=XWv
となります。ここで、Cは、X={x1,x2,...,xi,...xL}としたときの特徴ベクトルxiの次元数です。