Description
In KNeighborsClassifier, the .predict_proba() method is implemented in Python because the underlying oneDAL class does not offer it.
When running in distributed mode, the nearest neighbors of a given point are not necessarily in the data fold held by the same rank, which makes such methods harder to implement. As a result, .predict_proba() raises a NotImplementedError when called on an SPMD KNeighborsClassifier object:
#2700
For class predictions (calls to .predict()), the oneDAL C++ implementation handles this in distributed mode by moving data among ranks:
https://github.com/uxlfoundation/oneDAL/blob/main/cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc_distr.hpp#L669
Similar logic could be implemented in Python through mpi4py, although it is unclear how efficient it would be. For example, it could use this function:
https://mpi4py.readthedocs.io/en/latest/reference/mpi4py.MPI.Comm.html#mpi4py.MPI.Comm.Isendrecv_replace
Or perhaps this:
https://mpi4py.readthedocs.io/en/stable/reference/mpi4py.MPI.Comm.html#mpi4py.MPI.Comm.isend
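To make the idea of moving data among ranks concrete, here is a minimal in-process sketch of one possible scheme: each rank keeps its own queries, and the training blocks rotate around a ring until every rank has seen every block. It does not use MPI at all; the list rotation stands in for a send/receive between neighboring ranks, and all names (`ring_knn`, `blocks`, `queries_per_rank`) are hypothetical. It is not taken from oneDAL or sklearnex and only illustrates the data flow.

```python
import heapq


def sq_dist(a, b):
    """Squared Euclidean distance between two points given as tuples."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def ring_knn(blocks, queries_per_rank, k):
    """Simulate a ring exchange of training blocks among ranks.

    blocks[r] is the list of (point, label) pairs initially held by rank r.
    queries_per_rank[r] is the list of query points owned by rank r.
    Returns best[r][i]: the k smallest (squared_distance, label) pairs
    for query i of rank r, sorted by distance.
    """
    n_ranks = len(blocks)
    best = [[[] for _ in queries] for queries in queries_per_rank]
    visiting = list(blocks)  # block currently "visiting" each rank
    for _ in range(n_ranks):
        for r in range(n_ranks):
            for i, q in enumerate(queries_per_rank[r]):
                # Merge the running top-k with candidates from the visiting block.
                candidates = best[r][i] + [
                    (sq_dist(q, point), label) for point, label in visiting[r]
                ]
                best[r][i] = heapq.nsmallest(k, candidates)
        # Rotation step: rank r hands its visiting block to rank r + 1.
        # In a real SPMD run this would be an MPI exchange between neighbors.
        visiting = visiting[-1:] + visiting[:-1]
    return best


# Two simulated ranks, one-dimensional points, labels 0 and 1.
blocks = [
    [((0.0,), 0), ((1.0,), 0)],
    [((10.0,), 1), ((0.5,), 1)],
]
queries_per_rank = [[(0.2,)], [(9.0,)]]
best = ring_knn(blocks, queries_per_rank, k=3)
```

After the loop, each rank holds the labels of the global nearest neighbors for its own queries, which is the information .predict_proba() needs. Whether a ring like this is efficient enough in Python is exactly the open question raised above.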
There is also the challenge that the Python implementation needs to work with array API objects and restrict itself to operations supported by that standard.
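As an illustration of that constraint, the sketch below turns a matrix of neighbor labels into class probabilities using only indexing with `None`, elementwise comparison, `asarray`, and `sum`, which are expected to be available in array API namespaces. It assumes uniform neighbor weights, uses NumPy as a stand-in namespace, and the function name and `xp` parameter are hypothetical; the actual sklearnex code may look different.

```python
import numpy as np


def proba_from_neighbor_labels(neighbor_labels, classes, xp=np):
    """Class probabilities from neighbor labels, assuming uniform weights.

    neighbor_labels: integer array of shape (n_queries, k)
    classes: integer array of shape (n_classes,)
    xp: array namespace (NumPy here, as a stand-in for any
        array API compatible namespace)
    """
    # One-hot mask of shape (n_queries, k, n_classes) built by broadcasting,
    # avoiding helpers such as bincount that are not part of the standard.
    mask = neighbor_labels[:, :, None] == classes[None, None, :]
    votes = xp.sum(xp.asarray(mask, dtype=xp.float64), axis=1)
    return votes / neighbor_labels.shape[1]


labels = np.asarray([[0, 1, 0], [1, 0, 1], [2, 2, 2]])
classes = np.asarray([0, 1, 2])
proba = proba_from_neighbor_labels(labels, classes)
```

Each row of `proba` sums to one, since every neighbor votes for exactly one class.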