Skip to content

Commit

Permalink
update comment
Browse files Browse the repository at this point in the history
  • Loading branch information
PeixuanZuo committed Feb 2, 2023
1 parent d3dddab commit 4a72f93
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions onnxruntime/core/providers/cuda/math/softmax_blockwise_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,16 @@ struct SumExpFloat {

// One block has N(warps_per_block) warps, one warp has M(WARP_SIZE) threads.
// 1. All the threads in one block read data into shared memory.
// 2. Reduce all data to the first warp. Only the first N threads of warp-0 are used. thread-0 computes data in warp-0 and
// writes the result into the location of data0, thread-1 computes data in warp-1 and writes the result into the location of data1.
// __syncwarp(mask) is necessary here to make sure thread-1,...N will delay writing data into warp-0 until thread-0
// has finished reading data from warp-0.
// 2. Reduce all data to the first warp. Only the threads of warp-0 are used. Each thread in warp-0 reads data from the
// same location of every warp and computes result. For example, thread-0 computes the first data of every warp and
// writes the result into the location of data0.
// Shared memory
// -----------------------------------------------------------------------------------------------------------------------
// | data0 | data1 | data2 | .... | dataM | ... | dataM*2 | ... |
// -----------------------------------------------------------------------------------------------------------------------
// | | | |
// -------------------warp-0----------------------------------warp-1----------------------------------warp-2--------------
// TODO: ROCm doesn't support __syncwarp() now, we need another implementation to make sure read before write.
// 3. Thread-0 reduces all vaild data in warp-0 and writes the results into the location of data0, then return data0.
// 3. Thread-0 reduces all data in warp-0 and writes the results into the location of data0, then return data0.

template <template <typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT blockReduce(AccumT* smem, AccumT val,
Expand Down

0 comments on commit 4a72f93

Please sign in to comment.