Skip to content

Split-K predicate is missing on TMA store #4164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
jacobhinkle opened this issue Apr 1, 2025 · 1 comment
Open

Split-K predicate is missing on TMA store #4164

jacobhinkle opened this issue Apr 1, 2025 · 1 comment
Assignees
Labels

Comments

@jacobhinkle
Copy link
Collaborator

On Hopper, when we have both splitk and smem epilogue enabled, we generate code like the following:

  bool b24;
  b24 = ((nvfuser_index_t)threadIdx.x) < 32ULL;

  bool b28;
  b28 = ((nvfuser_index_t)blockIdx.z) == (((nvfuser_index_t)gridDim.z) + -1);

// ...

      wgmma::fence();
      wgmma::m64n256k16Half<1, 1, 1, 1>((*reinterpret_cast<Array<float, 128, 1>*>(&T7[0])), (4611686293305294848ULL | ((262143ULL & (uint64_t)(i56)) >> 4ULL)), (4611686293338849280ULL | ((262143ULL & (uint64_t)(i57)) >> 4ULL)), true);
    }
    __syncthreads();
  }
  #pragma unroll
  for(nvfuser_index_t i58 = 0; i58 < 4; ++i58) {
    if (((Hopper::electSync(4294967295U) && b24) && b25)) {
      mbarrier::inval(toSmem((&T10[i58])));
    }
  }
  wgmma::wait<0LL>();
  // Allocate global tensor T9
  grid_sync::blockSerializeWait<false, false, true>(&T9[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  #pragma unroll
  for(nvfuser_index_t i29 = 0; i29 < 32; ++i29) {
    nvfuser_index_t i59;
    i59 = 2 * i29;
    nvfuser_index_t i60;
    i60 = 4 * i29;
    nvfuser_index_t i61;
    i61 = 8 * i29;
    nvfuser_index_t i62;
    i62 = i17 + i61;
    bool b63;
    b63 = i26 < (-i61);
    #pragma unroll
    for(nvfuser_index_t i31 = 0; i31 < 2; ++i31) {
      bool b64;
      b64 = b63 && (i27 < (-(8 * i31)));
      // Allocate global tensor T11
      reduction::serialReductionStep</*vec_size=*/2>(
        &T2[(i59 + (64 * i31))],
        &T7[(i60 + (2 * i31))],
        0.000000000e+00f,
        &T11[(i62 + (i18 * i31))],
        [](float &a, float b) { a = a + b; },
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
        b64,
        b64);
    }
  }
  grid_sync::blockSerializeRelease<false, false, true>(&T9[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  Array<__half, 128, 8> T6;
  #pragma unroll
  for(nvfuser_index_t i65 = 0; i65 < 32; ++i65) {
    nvfuser_index_t i66;
    i66 = 2 * i65;
    nvfuser_index_t i67;
    i67 = 4 * i65;
    #pragma unroll
    for(nvfuser_index_t i68 = 0; i68 < 2; ++i68) {
      nvfuser_index_t i69;
      i69 = i66 + (64 * i68);
      nvfuser_index_t i70;
      i70 = i67 + (2 * i68);
      #pragma unroll
      for(nvfuser_index_t i71 = 0; i71 < 2; ++i71) {
        T6[(i70 + i71)]
           = __float2half(T2[(i69 + i71)]);
      }
    }
  }
  #pragma unroll
  for(nvfuser_index_t i72 = 0; i72 < 16; ++i72) {
    if (b28) {
      stmatrix4((uint32_t)((toSmem(T8) + ((((nvfuser_index_t)threadIdx.y) * 32768) + (((i72 / 4) * 8192) + ((i19 * 128) + (((((((nvfuser_index_t)threadIdx.x) % 32) / 16) + ((i72 % 4) * 2)) ^ (i19 % 8)) * 16)))))), (*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T6[(8 * i72)])));
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i73 = 0; i73 < 4; ++i73) {
    fenceAsyncProxy();
    if ((Hopper::electSync(4294967295U) && b24)) {
      Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr21, (Array<int, 2, 1>{(int32_t)((i6 + (64 * i73))), i23}) }), (i20 + (8192 * i73)));
    }
  }
  cpAsyncBulkCommitGroup();
  cpAsyncBulkWaitGroup<0LL>();
}

The predicate b28 correctly guards the stmatrix call, but it should also guard the TMA store; otherwise every CTA in the grid writes the output, instead of only the last CTA in each split-K segment, which is both slow and incorrect.

@jacobhinkle jacobhinkle self-assigned this Apr 1, 2025
@jacobhinkle
Copy link
Collaborator Author

Note that the code above can still produce correct results as long as the CTAs are scheduled in linear order: there will be multiple TMA stores to each output location, but the last one to land carries the fully reduced value. Even in that case, the redundant stores make this inefficient.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

1 participant