Skip to content

Commit 2d9bbc1

Browse files
authored
Add top-k sampling support to llm Sampler (#19122)
Differential Revision: D102385104 Pull Request resolved: #19122
1 parent bdf1bf4 commit 2d9bbc1

3 files changed

Lines changed: 178 additions & 1 deletion

File tree

extension/llm/sampler/sampler.cpp

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,56 @@ int32_t Sampler::sample_mult(T* probabilities, float coin) {
6969
return vocab_size_ - 1; // in case of rounding errors
7070
}
7171

72+
template <typename T>
73+
int32_t Sampler::sample_topk(T* probabilities, float coin) {
74+
// top-k sampling samples from the k highest-probability tokens.
75+
// coin is a random number in [0, 1), usually from random_f32().
76+
//
77+
// TODO: probindex is allocated on every call; lifting it to a member
78+
// would avoid per-token heap allocation in autoregressive loops.
79+
const int n = vocab_size_;
80+
const int k = std::min(topk_, n);
81+
// Defensive: callers gate on topk_ > 0, but a private helper should not
82+
// rely on external invariants. Fall back to a deterministic index.
83+
if (k <= 0) {
84+
return 0;
85+
}
86+
87+
std::unique_ptr<ProbIndex<T>[]> probindex =
88+
std::make_unique<ProbIndex<T>[]>(n);
89+
for (int i = 0; i < n; i++) {
90+
probindex[i].index = i;
91+
probindex[i].prob = probabilities[i];
92+
}
93+
94+
auto compare = [](const ProbIndex<T>& a, const ProbIndex<T>& b) {
95+
return a.prob > b.prob;
96+
};
97+
// Partial sort: only the top-k entries need to be sorted in descending order.
98+
std::partial_sort(
99+
probindex.get(), probindex.get() + k, probindex.get() + n, compare);
100+
101+
// Sum of the top-k probabilities. Used to scale `coin` instead of
102+
// explicitly renormalizing the k probs — mathematically equivalent and
103+
// saves k divisions. Accumulate in float so FP16/BF16 inputs don't lose
104+
// precision over k summands.
105+
float topk_sum = 0.0f;
106+
for (int i = 0; i < k; i++) {
107+
topk_sum += static_cast<float>(probindex[i].prob);
108+
}
109+
110+
// Sample from the (implicitly renormalized) top-k distribution.
111+
const float r = coin * topk_sum;
112+
float cdf = 0.0f;
113+
for (int i = 0; i < k; i++) {
114+
cdf += static_cast<float>(probindex[i].prob);
115+
if (r < cdf) {
116+
return probindex[i].index;
117+
}
118+
}
119+
return probindex[k - 1].index; // in case of rounding errors
120+
}
121+
72122
template <typename T>
73123
int32_t Sampler::sample_topp(T* probabilities, float coin) {
74124
// top-p sampling (or "nucleus sampling") samples from the smallest set of
@@ -186,7 +236,10 @@ int32_t Sampler::sample(T* logits) {
186236
// flip a (float) coin (this is our source of entropy for sampling)
187237
float coin = random_f32(&rng_state_);
188238
// we sample from this distribution to get the next token
189-
if (topp_ <= 0 || topp_ >= 1) {
239+
if (topk_ > 0 && topk_ < vocab_size_) {
240+
// top-k sampling, restrict to the k most likely tokens
241+
next = sample_topk(logits, coin);
242+
} else if (topp_ <= 0 || topp_ >= 1) {
190243
// simply sample from the predicted probability distribution
191244
next = sample_mult(logits, coin);
192245
} else {

extension/llm/sampler/sampler.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,22 @@ class ET_EXPERIMENTAL Sampler {
4444

4545
Sampler(int32_t vocab_size, float temperature);
4646

47+
// Enable top-k filtering. k <= 0 or k >= vocab_size disables top-k.
// When top-k is enabled, top-p is ignored — the two modes are mutually
// exclusive in this implementation.
void set_topk(int32_t topk) {
  topk_ = topk;
}
53+
4754
template <typename T>
4855
int32_t sample(T* logits);
4956

5057
private:
5158
template <typename T>
5259
int32_t sample_topp(T* probabilities, float coin);
5360
template <typename T>
61+
int32_t sample_topk(T* probabilities, float coin);
62+
template <typename T>
5463
int32_t sample_mult(T* probabilities, float coin);
5564
template <typename T>
5665
int32_t sample_argmax(T* probabilities);
@@ -60,6 +69,8 @@ class ET_EXPERIMENTAL Sampler {
6069
// reciprocal of temperature, or 0 if temperature == 0.
6170
float inv_temperature_;
6271
float topp_;
72+
// 0 (or >= vocab_size_) means top-k is disabled.
73+
int32_t topk_ = 0;
6374
unsigned long long rng_state_;
6475
};
6576

extension/llm/sampler/test/test_sampler.cpp

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/extension/llm/sampler/sampler.h>
1010

11+
#include <set>
12+
1113
#include <gtest/gtest.h>
1214
#include <torch/torch.h>
1315

@@ -39,3 +41,114 @@ TEST(SamplerTest, TestArgMaxWithFP16) {
3941
input[0][0][396] = 1.0f;
4042
EXPECT_EQ(sampler.sample(input.data_ptr<c10::Half>()), 396);
4143
}
44+
45+
TEST(SamplerTest, TestTopKRestrictsToCandidates) {
  // With topk=3, sampling must always return one of the top-3 indices,
  // regardless of the random draw.
  Sampler sampler{
      /*vocab_size*/ 100,
      /*temperature*/ 1.0f,
      /*topp*/ 0.0f, // disable top-p so we exercise top-k alone
      /*rng_seed*/ 42};
  sampler.set_topk(3);

  // Construct logits where indices {7, 13, 42} dominate.
  torch::Tensor base = torch::full({100}, -10.0f, at::kFloat);
  base[7] = 5.0f;
  base[13] = 4.5f;
  base[42] = 4.0f;

  const std::set<int32_t> allowed = {7, 13, 42};
  for (int t = 0; t < 50; ++t) {
    // Re-fill logits each trial because sample() mutates them in place.
    torch::Tensor logits = base.clone();
    const int32_t out = sampler.sample(logits.data_ptr<float>());
    EXPECT_NE(allowed.find(out), allowed.end())
        << "trial " << t << " got " << out;
  }
}
69+
70+
TEST(SamplerTest, TestTopKDisabledByZero) {
  // topk=0 means disabled. With topp disabled, sampling collapses to
  // multinomial over the full vocab, but the dominant token should still
  // win the vast majority of the time.
  Sampler sampler{
      /*vocab_size*/ 50,
      /*temperature*/ 1.0f,
      /*topp*/ 0.0f,
      /*rng_seed*/ 7};
  sampler.set_topk(0); // disabled

  torch::Tensor base = torch::full({50}, -10.0f, at::kFloat);
  base[11] = 20.0f; // dominant token

  int dominant_wins = 0;
  for (int t = 0; t < 20; ++t) {
    // Clone each trial because sample() mutates the logits in place.
    torch::Tensor logits = base.clone();
    if (sampler.sample(logits.data_ptr<float>()) == 11) {
      ++dominant_wins;
    }
  }
  EXPECT_GE(dominant_wins, 18); // dominant token should win nearly every time
}
93+
94+
TEST(SamplerTest, TestTopKWithFP16) {
  // Smoke test the FP16 template instantiation of the top-k path.
  Sampler sampler{
      /*vocab_size*/ 50,
      /*temperature*/ 1.0f,
      /*topp*/ 0.0f,
      /*rng_seed*/ 99};
  sampler.set_topk(2);

  torch::Tensor base = torch::full({50}, -10.0f, at::kHalf);
  base[3] = 5.0f;
  base[8] = 4.5f;

  const std::set<int32_t> allowed = {3, 8};
  for (int t = 0; t < 30; ++t) {
    // Clone each trial because sample() mutates the logits in place.
    torch::Tensor logits = base.clone();
    const int32_t out = sampler.sample(logits.data_ptr<c10::Half>());
    EXPECT_NE(allowed.find(out), allowed.end())
        << "trial " << t << " got " << out;
  }
}
114+
115+
TEST(SamplerTest, TestTopKEqualsOneIsArgmax) {
  // topk=1 should behave like greedy argmax even with temperature > 0.
  Sampler sampler{
      /*vocab_size*/ 100,
      /*temperature*/ 1.0f,
      /*topp*/ 0.0f,
      /*rng_seed*/ 123};
  sampler.set_topk(1);

  torch::Tensor base = torch::rand({100}, at::kFloat);
  base[57] = 100.0f; // make 57 the unambiguous max

  for (int t = 0; t < 10; ++t) {
    // Clone each trial because sample() mutates the logits in place.
    torch::Tensor logits = base.clone();
    EXPECT_EQ(sampler.sample(logits.data_ptr<float>()), 57);
  }
}
132+
133+
TEST(SamplerTest, TestTopKTakesPrecedenceOverTopP) {
  // When both top-k and top-p are set, top-k should restrict the candidate
  // set; top-p alone would admit a third token that top-k=2 must exclude.
  Sampler sampler{
      /*vocab_size*/ 100,
      /*temperature*/ 1.0f,
      /*topp*/ 0.99f, // would keep nearly the whole vocab on its own
      /*rng_seed*/ 99};
  sampler.set_topk(2);

  torch::Tensor base = torch::full({100}, -10.0f, at::kFloat);
  base[3] = 5.0f;
  base[8] = 4.5f;
  base[19] = 4.0f; // would be in the top-p set but is excluded by top-k=2

  const std::set<int32_t> allowed = {3, 8};
  for (int t = 0; t < 50; ++t) {
    // Clone each trial because sample() mutates the logits in place.
    torch::Tensor logits = base.clone();
    const int32_t out = sampler.sample(logits.data_ptr<float>());
    EXPECT_NE(allowed.find(out), allowed.end())
        << "trial " << t << " got " << out;
  }
}

0 commit comments

Comments
 (0)