Specialize sample for sparse weights
#943
base: master
Conversation
nalimilan
left a comment
Thanks!
Bump.
My gut feeling is that we should address #885 first and then add a specialisation to the SparseArrays extension. By keeping a hard dependency on SparseArrays, StatsBase is holding back large parts of the Julia ecosystem.
Yeah but this method is easy to move to an extension as soon as we create it, and it doesn't make things worse until then. |
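For reference, a rough sketch of how this specialization could later live in such an extension (the extension module name is hypothetical and nothing here is decided in this PR): StatsBase's Project.toml would declare SparseArrays under [weakdeps] and map an extension module to it under [extensions], and the extension file could look roughly like this:

# ext/StatsBaseSparseArraysExt.jl (hypothetical name and location)
module StatsBaseSparseArraysExt

using StatsBase, SparseArrays
using Random: AbstractRNG

# Essentially the method from this PR, just loaded only when SparseArrays is.
function StatsBase.sample(rng::AbstractRNG, wv::Weights{<:Real,<:Real,<:SparseVector})
    i = sample(rng, Weights(nonzeros(wv.values), sum(wv)))
    return SparseArrays.nonzeroinds(wv.values)[i]
end

end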
Co-authored-by: Milan Bouchet-Valat <[email protected]>
d852edf to db126d1
i = sample(rng, Weights(nonzeros(wv.values), sum(wv)))
return rowvals(wv.values)[i]
The code is unsafe - in general AbstractWeights are not required to have a values field. It's just a few AbstractWeights subtypes in StatsBase that have an (undocumented and internal) values field.
So it would actually be better to define this method only for the weights types defined in StatsBase. Probably using:
for W in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights)
    @eval function sample(rng::AbstractRNG, wv::$W{<:Real,<:Real,<:SparseVector})
        ...
(I'm saying this because AFAICT there's no public API which allows accessing the backing array. And anyway I'm not aware of custom AbstractWeights types defined elsewhere, so we don't really care about applying this optimization to them.)
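Spelled out, that suggestion might look roughly like the sketch below (the function name is qualified so the snippet is runnable outside StatsBase; the $W interpolation splices the loop variable into each @eval'd definition, and the body just reuses the logic from the diff above):

using Random, SparseArrays, StatsBase

for W in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights)
    # One method per concrete weights type, restricted to sparse backing vectors.
    @eval function StatsBase.sample(rng::AbstractRNG, wv::$W{<:Real,<:Real,<:SparseVector})
        i = sample(rng, Weights(nonzeros(wv.values), sum(wv)))
        return SparseArrays.nonzeroinds(wv.values)[i]
    end
end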
The tests are insufficient - since the method is implemented for AbstractWeights, to be sure it works not only for Weights, we should test all subtypes implemented in StatsBase as well as a custom subtype of AbstractWeights.
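A hedged sketch of what such tests could look like; MyWeights is a made-up custom subtype mirroring the values/sum field layout of the StatsBase weights types, and the test only checks observable sampling behaviour, so it passes whether a call hits the sparse specialization or the generic fallback:

using Random, SparseArrays, StatsBase, Test

# Made-up custom AbstractWeights subtype used to exercise the non-StatsBase path.
struct MyWeights{S<:Real,T<:Real,V<:AbstractVector{T}} <: AbstractWeights{S,T,V}
    values::V
    sum::S
end
MyWeights(vs::AbstractVector{T}, s::S) where {S<:Real,T<:Real} = MyWeights{S,T,typeof(vs)}(vs, s)
MyWeights(vs::AbstractVector{<:Real}) = MyWeights(vs, sum(vs))

@testset "sample with sparse weights" begin
    w = sparsevec([2, 5, 9], [0.2, 0.5, 0.3], 10)
    rng = MersenneTwister(1234)
    for WT in (Weights, AnalyticWeights, FrequencyWeights, ProbabilityWeights, MyWeights)
        wv = WT(w)
        draws = [sample(rng, wv) for _ in 1:10_000]
        @test all(in((2, 5, 9)), draws)  # only indices with non-zero weight are drawn
        @test isapprox(count(==(5), draws) / length(draws), 0.5; atol = 0.05)
    end
end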
This PR adds a new sample method for sparse weights, as well as tests. It brings the time complexity from O(n) to O(n_nonzero). This would be useful for e.g. top-p sampling, where one might have on the order of 100k tokens to sample from, but only a few are considered.
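As a rough usage sketch of that scenario (the vocabulary size, surviving indices, and probabilities below are made up): put the truncated distribution into a SparseVector, wrap it in Weights, and sample from it.

using SparseArrays, StatsBase

n_tokens = 100_000                     # hypothetical vocabulary size
kept     = [17, 4_203, 58_917]         # indices surviving the top-p cutoff (made up)
probs    = [0.6, 0.3, 0.1]

wv = Weights(sparsevec(kept, probs, n_tokens))

# With the specialization, the draw scales with the 3 stored entries
# rather than with all 100_000 candidate tokens.
token = sample(wv)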
Benchmarks across different sizes and densities
Results
This shows the dense baseline, and the relative performance increase over invoking sample with the generic method for sparse weights.
Benchmark setup
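A hedged sketch of how such a benchmark could be put together with BenchmarkTools; the sizes, densities, and seed below are only examples, not necessarily the setup used for the results above:

using BenchmarkTools, Random, SparseArrays, StatsBase

rng = MersenneTwister(42)
for n in (10, 1_000, 100_000), density in (0.001, 0.01, 0.1)
    sp = sprand(rng, n, density)
    isempty(nonzeros(sp)) && continue        # skip draws that happen to be all zero
    dense_wv  = Weights(Vector(sp))          # dense baseline
    sparse_wv = Weights(sp)                  # dispatches to the sparse specialization
    println("n = $n, density = $density")
    @btime sample($rng, $dense_wv)
    @btime sample($rng, $sparse_wv)
end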
Note: For small vector lengths (~10) and low densities (~0.2) the performance difference becomes noisy and less meaningful. The generic method can sometimes be faster in these cases due to less overhead when it happens to find the target probability mass early in the vector. However, for these small cases the absolute timing differences are negligible (few nanoseconds) and sparse storage isn't really beneficial anyway.
Note: The implementation uses SparseArrays.nonzeroinds, which is not public.