Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #456 from kshitij12345/iterate/unroll/block_adjacent_difference
Browse files Browse the repository at this point in the history

pragma unroll in block_adjacent_difference::Iterate
  • Loading branch information
alliepiper authored Apr 25, 2022
2 parents d2c014c + e37eac8 commit 191172d
Showing 1 changed file with 27 additions and 59 deletions.
86 changes: 27 additions & 59 deletions cub/block/block_adjacent_difference.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ private:
};

/// Templated unrolling of item comparison (inductive case)
template <int ITERATION, int MAX_ITERATIONS>
struct Iterate
{
/**
Expand All @@ -210,19 +209,15 @@ private:
T (&preds)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
preds[ITERATION] = input[ITERATION - 1];

flags[ITERATION] = ApplyOp<FlagOp>::FlagT(
#pragma unroll
for (int i = 1; i < ITEMS_PER_THREAD; ++i) {
preds[i] = input[i - 1];
flags[i] = ApplyOp<FlagOp>::FlagT(
flag_op,
preds[ITERATION],
input[ITERATION],
(linear_tid * ITEMS_PER_THREAD) + ITERATION);

Iterate<ITERATION + 1, MAX_ITERATIONS>::FlagHeads(linear_tid,
flags,
input,
preds,
flag_op);
preds[i],
input[i],
(linear_tid * ITEMS_PER_THREAD) + i);
}
}

/**
Expand All @@ -239,44 +234,17 @@ private:
T (&input)[ITEMS_PER_THREAD],
FlagOp flag_op)
{
flags[ITERATION] = ApplyOp<FlagOp>::FlagT(
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD - 1; ++i) {
flags[i] = ApplyOp<FlagOp>::FlagT(
flag_op,
input[ITERATION],
input[ITERATION + 1],
(linear_tid * ITEMS_PER_THREAD) + ITERATION + 1);

Iterate<ITERATION + 1, MAX_ITERATIONS>::FlagTails(linear_tid,
flags,
input,
flag_op);
input[i],
input[i + 1],
(linear_tid * ITEMS_PER_THREAD) + i + 1);
}
}
};

/// Templated unrolling of item comparison (termination case)
template <int MAX_ITERATIONS>
struct Iterate<MAX_ITERATIONS, MAX_ITERATIONS>
{
// Head flags
template <int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
static __device__ __forceinline__ void
FlagHeads(int /*linear_tid*/,
FlagT (&/*flags*/)[ITEMS_PER_THREAD],
T (&/*input*/)[ITEMS_PER_THREAD],
T (&/*preds*/)[ITEMS_PER_THREAD],
FlagOp /*flag_op*/)
{}

// Tail flags
template <int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
static __device__ __forceinline__ void
FlagTails(int /*linear_tid*/,
FlagT (&/*flags*/)[ITEMS_PER_THREAD],
T (&/*input*/)[ITEMS_PER_THREAD],
FlagOp /*flag_op*/)
{}
};


/***************************************************************************
* Thread fields
**************************************************************************/
Expand Down Expand Up @@ -991,7 +959,7 @@ public:
}

// Set output for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, output, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, output, input, preds, flag_op);
}

/**
Expand Down Expand Up @@ -1021,7 +989,7 @@ public:
output[0] = ApplyOp<FlagOp>::FlagT(flag_op, preds[0], input[0], linear_tid * ITEMS_PER_THREAD);

// Set output for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, output, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, output, input, preds, flag_op);
}

#endif // DOXYGEN_SHOULD_SKIP_THIS
Expand Down Expand Up @@ -1088,7 +1056,7 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set output for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, output, input, flag_op);
Iterate::FlagTails(linear_tid, output, input, flag_op);
}


Expand Down Expand Up @@ -1123,7 +1091,7 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set output for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, output, input, flag_op);
Iterate::FlagTails(linear_tid, output, input, flag_op);
}


Expand Down Expand Up @@ -1176,10 +1144,10 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down Expand Up @@ -1234,10 +1202,10 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}

/**
Expand Down Expand Up @@ -1285,10 +1253,10 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down Expand Up @@ -1340,10 +1308,10 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}

};
Expand Down

0 comments on commit 191172d

Please sign in to comment.