|
8 | 8 |
|
9 | 9 | #include <executorch/extension/llm/sampler/sampler.h> |
10 | 10 |
|
| 11 | +#include <set> |
| 12 | + |
11 | 13 | #include <gtest/gtest.h> |
12 | 14 | #include <torch/torch.h> |
13 | 15 |
|
@@ -39,3 +41,114 @@ TEST(SamplerTest, TestArgMaxWithFP16) { |
39 | 41 | input[0][0][396] = 1.0f; |
40 | 42 | EXPECT_EQ(sampler.sample(input.data_ptr<c10::Half>()), 396); |
41 | 43 | } |
| 44 | + |
| 45 | +TEST(SamplerTest, TestTopKRestrictsToCandidates) { |
| 46 | + // With topk=3, sampling must always return one of the top-3 indices, |
| 47 | + // regardless of the random draw. |
| 48 | + Sampler sampler{ |
| 49 | + /*vocab_size*/ 100, |
| 50 | + /*temperature*/ 1.0f, |
| 51 | + /*topp*/ 0.0f, // disable top-p so we exercise top-k alone |
| 52 | + /*rng_seed*/ 42}; |
| 53 | + sampler.set_topk(3); |
| 54 | + |
| 55 | + // Construct logits where indices {7, 13, 42} dominate. |
| 56 | + torch::Tensor input = torch::full({100}, -10.0f, at::kFloat); |
| 57 | + input[7] = 5.0f; |
| 58 | + input[13] = 4.5f; |
| 59 | + input[42] = 4.0f; |
| 60 | + |
| 61 | + std::set<int32_t> allowed = {7, 13, 42}; |
| 62 | + for (int trial = 0; trial < 50; ++trial) { |
| 63 | + // Re-fill logits each trial because sample() mutates them in place. |
| 64 | + torch::Tensor logits = input.clone(); |
| 65 | + int32_t out = sampler.sample(logits.data_ptr<float>()); |
| 66 | + EXPECT_TRUE(allowed.count(out)) << "trial " << trial << " got " << out; |
| 67 | + } |
| 68 | +} |
| 69 | + |
| 70 | +TEST(SamplerTest, TestTopKDisabledByZero) { |
| 71 | + // topk=0 means disabled. With topp disabled, sampling collapses to |
| 72 | + // multinomial over the full vocab, but the dominant token should still |
| 73 | + // win the vast majority of the time. |
| 74 | + Sampler sampler{ |
| 75 | + /*vocab_size*/ 50, |
| 76 | + /*temperature*/ 1.0f, |
| 77 | + /*topp*/ 0.0f, |
| 78 | + /*rng_seed*/ 7}; |
| 79 | + sampler.set_topk(0); // disabled |
| 80 | + |
| 81 | + torch::Tensor input = torch::full({50}, -10.0f, at::kFloat); |
| 82 | + input[11] = 20.0f; // dominant |
| 83 | + |
| 84 | + int hits = 0; |
| 85 | + for (int trial = 0; trial < 20; ++trial) { |
| 86 | + torch::Tensor logits = input.clone(); |
| 87 | + if (sampler.sample(logits.data_ptr<float>()) == 11) { |
| 88 | + hits++; |
| 89 | + } |
| 90 | + } |
| 91 | + EXPECT_GE(hits, 18); // dominant token should win nearly every time |
| 92 | +} |
| 93 | + |
| 94 | +TEST(SamplerTest, TestTopKWithFP16) { |
| 95 | + // Smoke test the FP16 template instantiation of the top-k path. |
| 96 | + Sampler sampler{ |
| 97 | + /*vocab_size*/ 50, |
| 98 | + /*temperature*/ 1.0f, |
| 99 | + /*topp*/ 0.0f, |
| 100 | + /*rng_seed*/ 99}; |
| 101 | + sampler.set_topk(2); |
| 102 | + |
| 103 | + torch::Tensor input = torch::full({50}, -10.0f, at::kHalf); |
| 104 | + input[3] = 5.0f; |
| 105 | + input[8] = 4.5f; |
| 106 | + |
| 107 | + std::set<int32_t> allowed = {3, 8}; |
| 108 | + for (int trial = 0; trial < 30; ++trial) { |
| 109 | + torch::Tensor logits = input.clone(); |
| 110 | + int32_t out = sampler.sample(logits.data_ptr<c10::Half>()); |
| 111 | + EXPECT_TRUE(allowed.count(out)) << "trial " << trial << " got " << out; |
| 112 | + } |
| 113 | +} |
| 114 | + |
| 115 | +TEST(SamplerTest, TestTopKEqualsOneIsArgmax) { |
| 116 | + // topk=1 should behave like greedy argmax even with temperature > 0. |
| 117 | + Sampler sampler{ |
| 118 | + /*vocab_size*/ 100, |
| 119 | + /*temperature*/ 1.0f, |
| 120 | + /*topp*/ 0.0f, |
| 121 | + /*rng_seed*/ 123}; |
| 122 | + sampler.set_topk(1); |
| 123 | + |
| 124 | + torch::Tensor input = torch::rand({100}, at::kFloat); |
| 125 | + input[57] = 100.0f; // make 57 the unambiguous max |
| 126 | + |
| 127 | + for (int trial = 0; trial < 10; ++trial) { |
| 128 | + torch::Tensor logits = input.clone(); |
| 129 | + EXPECT_EQ(sampler.sample(logits.data_ptr<float>()), 57); |
| 130 | + } |
| 131 | +} |
| 132 | + |
| 133 | +TEST(SamplerTest, TestTopKTakesPrecedenceOverTopP) { |
| 134 | + // When both top-k and top-p are set, top-k should restrict the candidate |
| 135 | + // set; top-p alone would admit a third token that top-k=2 must exclude. |
| 136 | + Sampler sampler{ |
| 137 | + /*vocab_size*/ 100, |
| 138 | + /*temperature*/ 1.0f, |
| 139 | + /*topp*/ 0.99f, // would keep nearly the whole vocab on its own |
| 140 | + /*rng_seed*/ 99}; |
| 141 | + sampler.set_topk(2); |
| 142 | + |
| 143 | + torch::Tensor input = torch::full({100}, -10.0f, at::kFloat); |
| 144 | + input[3] = 5.0f; |
| 145 | + input[8] = 4.5f; |
| 146 | + input[19] = 4.0f; // would be in the top-p set but is excluded by top-k=2 |
| 147 | + |
| 148 | + std::set<int32_t> allowed = {3, 8}; |
| 149 | + for (int trial = 0; trial < 50; ++trial) { |
| 150 | + torch::Tensor logits = input.clone(); |
| 151 | + int32_t out = sampler.sample(logits.data_ptr<float>()); |
| 152 | + EXPECT_TRUE(allowed.count(out)) << "trial " << trial << " got " << out; |
| 153 | + } |
| 154 | +} |
0 commit comments