I'm implementing a transformer and I have everything working, including attention using the new `scaled_dot_product_attention` from PyTorch 2.0. I'll only be doing causal attention, however, so it seems to make sense to use the `is_causal=True` flag for efficiency. This also works as I'd expect as long as the `k`, `v` and `q` tensors have the same sequence length.
But I'm not sure how to pass past (cached) keys/values to the function in this mode. If the `k`/`v` tensors are longer than `q` along the sequence dimension, I need a rectangular mask as wide as `k`/`v` and as tall as `q`, with the upper-right triangle masked out. All is well if I construct such a mask myself and pass it as `attn_mask`: I get behavior similar to typical causal attention, where past tokens are attended to fully and new tokens (for which there are queries) are attended to causally.
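For reference, the hand-built mask described above can be sketched roughly as follows. The helper name `cached_causal_mask` and the sizes `n_heads = 4`, `head_dim = 16` are hypothetical choices for illustration; the only assumption is that the queries correspond to the last `L` of the `S` key/value positions, which is why the diagonal is shifted by `S - L`.

```python
import torch
import torch.nn.functional as F

def cached_causal_mask(q_len: int, kv_len: int) -> torch.Tensor:
    """Hypothetical helper: boolean (q_len, kv_len) mask, True = may attend.

    Assumes the q_len queries are the *last* q_len positions of the kv_len
    keys, so the causal diagonal is aligned to the bottom-right corner.
    """
    return torch.ones(q_len, kv_len, dtype=torch.bool).tril(diagonal=kv_len - q_len)

n_heads, head_dim = 4, 16  # arbitrary sizes for illustration
q = torch.rand(1, n_heads, 3, head_dim)
k = torch.rand(1, n_heads, 6, head_dim)
v = torch.rand(1, n_heads, 6, head_dim)

mask = cached_causal_mask(q.size(-2), k.size(-2))
# A 2-D (L, S) boolean mask broadcasts over the batch and head dimensions.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(mask.int())
print(out.shape)
```

With a boolean `attn_mask`, `True` marks positions that take part in attention, so each query row here sees all cached keys plus the new keys up to and including its own position.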
According to the documentation, though, `is_causal=True` is equivalent to using a mask built with:
```python
attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
```
where `L` and `S` are the query and key/value lengths, respectively. This masks out everything except the lower-left triangle, so the queries attend to only part of the past tokens and to none of the new tokens. Is this causal mode just not suitable for my use case, or am I missing something?
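The documented equivalence can be checked directly by comparing the `is_causal=True` output against an explicit `tril(diagonal=0)` mask when `L < S`. This is a sketch run on CPU; the seed, `n_heads` and `head_dim` are arbitrary illustrative values, and small numerical differences between attention kernels are absorbed by the tolerance.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_heads, head_dim = 4, 16  # arbitrary sizes for illustration
L, S = 3, 6                # query length and key/value length
q = torch.rand(1, n_heads, L, head_dim)
k = torch.rand(1, n_heads, S, head_dim)
v = torch.rand(1, n_heads, S, head_dim)

# The mask the documentation gives for is_causal=True: the diagonal starts
# at the top-left corner, so query i only sees keys 0..i.
doc_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=doc_mask)
print(doc_mask.int())
print(torch.allclose(out_flag, out_mask, atol=1e-5))
```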
Suppose I have the following tensors:
```python
import torch

n_heads, head_dim = 4, 16  # example sizes
q = torch.rand((1, n_heads, 3, head_dim))
k = torch.rand((1, n_heads, 6, head_dim))
v = torch.rand((1, n_heads, 6, head_dim))
attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
```
Here `k` and `v` are longer because the new keys/values have been concatenated onto the cached results of a previous inference pass. `scaled_dot_product_attention` applies the following mask:
```
[[0, -inf, -inf, -inf, -inf, -inf],
 [0,    0, -inf, -inf, -inf, -inf],
 [0,    0,    0, -inf, -inf, -inf]]
```
But I would (maybe naively?) expect the attention operation to use a mask like this:
```
[[0, 0, 0, 0, -inf, -inf],
 [0, 0, 0, 0,    0, -inf],
 [0, 0, 0, 0,    0,    0]]
```
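One way to see why this bottom-right-aligned mask is the one a key/value cache needs: the cached pass should reproduce the last rows of ordinary square causal attention over the whole sequence, while the top-left `is_causal=True` mask does not. A minimal sketch, assuming 3 cached and 3 new tokens and arbitrary head sizes (all values here are illustrative, not from the original setup):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_heads, head_dim = 4, 16  # arbitrary sizes for illustration
past, new = 3, 3           # cached tokens and new tokens
S = past + new

# Queries/keys/values for the whole 6-token sequence.
q_all = torch.rand(1, n_heads, S, head_dim)
k_all = torch.rand(1, n_heads, S, head_dim)
v_all = torch.rand(1, n_heads, S, head_dim)

# Reference: ordinary square causal attention over every token.
full = F.scaled_dot_product_attention(q_all, k_all, v_all, is_causal=True)

# Cached pass: only the new queries, but all keys/values (past + new).
q_new = q_all[:, :, past:]
mask = torch.ones(new, S, dtype=torch.bool).tril(diagonal=S - new)  # bottom-right aligned
cached = F.scaled_dot_product_attention(q_new, k_all, v_all, attn_mask=mask)

# The same cached pass with the top-left is_causal=True behavior.
top_left = F.scaled_dot_product_attention(q_new, k_all, v_all, is_causal=True)

print(torch.allclose(full[:, :, past:], cached, atol=1e-5))    # bottom-right mask
print(torch.allclose(full[:, :, past:], top_left, atol=1e-5))  # top-left mask
```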
Can I achieve this somehow with `scaled_dot_product_attention`, or am I going about this all wrong? How am I meant to use the function with a cache of past keys/values?