Scaled Dot-Product Attention

Hard
~25 min
code completion

The full attention formula is:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

where:

  • $Q$ (queries): shape (n_q, d_k)
  • $K$ (keys): shape (n_k, d_k)
  • $V$ (values): shape (n_k, d_v)
  • $d_k$: key dimension (used for scaling to prevent large dot products)

Steps:

    1. Compute scores: $Q K^\top$ → shape (n_q, n_k)

    2. Scale: divide by $\sqrt{d_k}$

    3. Softmax each row (axis=1): each query gets a distribution over keys

    4. Multiply by $V$: weighted sum of values → shape (n_q, d_v)
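
The four steps above can be sketched in NumPy roughly as follows. This is one possible implementation, not a reference solution: it assumes the inputs are 2-D array-likes with the shapes listed earlier, and it returns a NumPy array (whether the grader expects an array or a nested list is not stated in the problem). Subtracting the row maximum before exponentiating is an added numerical-stability step that does not change the softmax result.

```python
import numpy as np


def scaled_dot_product_attention(Q, K, V):
    # Accept nested lists as well as arrays (assumption: 2-D inputs).
    Q = np.asarray(Q, dtype=float)  # (n_q, d_k)
    K = np.asarray(K, dtype=float)  # (n_k, d_k)
    V = np.asarray(V, dtype=float)  # (n_k, d_v)

    d_k = K.shape[1]

    # Step 1: raw scores, one row per query and one column per key.
    scores = Q @ K.T  # (n_q, n_k)

    # Step 2: scale by sqrt(d_k).
    scaled = scores / np.sqrt(d_k)

    # Step 3: row-wise softmax (axis=1). Shifting by the row max avoids
    # overflow in exp and leaves the resulting distribution unchanged.
    shifted = scaled - scaled.max(axis=1, keepdims=True)
    weights = np.exp(shifted)
    weights = weights / weights.sum(axis=1, keepdims=True)

    # Step 4: weighted sum of the value rows.
    return weights @ V  # (n_q, d_v)
```

If a nested list is required, calling `.tolist()` on the returned array should be enough.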

    Your task:

    Implement scaled_dot_product_attention(Q, K, V).

    Example Tests

    Output shape is (n_q, d_v)

    Input: {"K":[[1,0],[0,1],[1,1]],"Q":[[1,0],[0,1]],"V":[[1,2,3],[4,5,6],[7,8,9]]}

    Expected: [2,3]

    Q = K = V = identity: each query attends most to its matching key (weights ≈ 0.67 and 0.33, not uniform), so the output is a weighted average of the rows of V

    Input: {"K":[[1,0],[0,1]],"Q":[[1,0],[0,1]],"V":[[1,0],[0,1]]}

    Expected: [[0.66976,0.33024],[0.33024,0.66976]]
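
As a check on this expected value, here is the arithmetic for the first query row $q = (1, 0)$ with $d_k = 2$ (values rounded to five decimals):

$$
\begin{aligned}
q K^\top &= (1, 0), \qquad \frac{q K^\top}{\sqrt{2}} \approx (0.70711,\ 0) \\
\mathrm{softmax} &: \quad \frac{e^{0.70711}}{e^{0.70711} + e^{0}} \approx \frac{2.0281}{3.0281} \approx 0.66976, \qquad 1 - 0.66976 = 0.33024 \\
\text{output row} &= 0.66976\,(1, 0) + 0.33024\,(0, 1) = (0.66976,\ 0.33024)
\end{aligned}
$$

The second row follows by symmetry, with the two weights swapped.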

    Attention weights sum to 1 for each query (this test checks only the output shape, (n_q, d_v) = (1, 1))

    Input: {"K":[[1,0],[0,1]],"Q":[[1,0]],"V":[[1],[2]]}

    Expected: [1,1]
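
Since this test checks only the shape, the row-sum property can be verified separately. The snippet below is a small illustrative check on the same input, not part of the problem's test harness. The variable names (`weights`, `row_sums`) are chosen here for illustration.

```python
import numpy as np

# Same input as the test above.
Q = np.array([[1.0, 0.0]])              # (1, 2)
K = np.array([[1.0, 0.0], [0.0, 1.0]])  # (2, 2)
V = np.array([[1.0], [2.0]])            # (2, 1)

# Scaled scores followed by a row-wise softmax.
scaled = (Q @ K.T) / np.sqrt(K.shape[1])
weights = np.exp(scaled - scaled.max(axis=1, keepdims=True))
weights = weights / weights.sum(axis=1, keepdims=True)

output = weights @ V

# Each query's weights should form a probability distribution.
row_sums = weights.sum(axis=1)
print(row_sums)
print(output.shape)  # (1, 1)
```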
