@@ -219,6 +219,22 @@ def _optimize_layout_euclidean_densmap_epoch_init(
219219 re_sum [i ] = np .log (epsilon + (re_sum [i ] / phi_sum [i ]))
220220
221221
222+ _nb_optimize_layout_euclidean_single_epoch = numba .njit (
223+ _optimize_layout_euclidean_single_epoch , fastmath = True , parallel = False
224+ )
225+
226+ _nb_optimize_layout_euclidean_single_epoch_parallel = numba .njit (
227+ _optimize_layout_euclidean_single_epoch , fastmath = True , parallel = True
228+ )
229+
230+
231+ def _get_optimize_layout_euclidean_single_epoch_fn (parallel : bool = False ):
232+ if parallel :
233+ return _nb_optimize_layout_euclidean_single_epoch_parallel
234+ else :
235+ return _nb_optimize_layout_euclidean_single_epoch
236+
237+
222238def optimize_layout_euclidean (
223239 head_embedding ,
224240 tail_embedding ,
@@ -308,9 +324,10 @@ def optimize_layout_euclidean(
308324 epoch_of_next_negative_sample = epochs_per_negative_sample .copy ()
309325 epoch_of_next_sample = epochs_per_sample .copy ()
310326
311- optimize_fn = numba .njit (
312- _optimize_layout_euclidean_single_epoch , fastmath = True , parallel = parallel
313- )
327+ # Fix for calling UMAP many times for small datasets, otherwise we spend here
328+ # a lot of time in compilation step (first call to numba function)
329+ optimize_fn = _get_optimize_layout_euclidean_single_epoch_fn (parallel )
330+
314331 if densmap_kwds is None :
315332 densmap_kwds = {}
316333 if tqdm_kwds is None :
@@ -352,7 +369,6 @@ def optimize_layout_euclidean(
352369 ) + head_embedding [:, 0 ].astype (np .float64 ).view (np .int64 ).reshape (- 1 , 1 )
353370
354371 for n in tqdm (range (n_epochs ), ** tqdm_kwds ):
355-
356372 densmap_flag = (
357373 densmap
358374 and (densmap_kwds ["lambda" ] > 0 )
0 commit comments