Skip to content

Commit c9dcc15

Browse files
authored
Merge pull request #1152 from kmkolasinski/master
Compile _optimize_layout_euclidean_single_epoch once
2 parents c72ac2f + f8895b8 commit c9dcc15

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

umap/layouts.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,22 @@ def _optimize_layout_euclidean_densmap_epoch_init(
219219
re_sum[i] = np.log(epsilon + (re_sum[i] / phi_sum[i]))
220220

221221

222+
_nb_optimize_layout_euclidean_single_epoch = numba.njit(
223+
_optimize_layout_euclidean_single_epoch, fastmath=True, parallel=False
224+
)
225+
226+
_nb_optimize_layout_euclidean_single_epoch_parallel = numba.njit(
227+
_optimize_layout_euclidean_single_epoch, fastmath=True, parallel=True
228+
)
229+
230+
231+
def _get_optimize_layout_euclidean_single_epoch_fn(parallel: bool = False):
232+
if parallel:
233+
return _nb_optimize_layout_euclidean_single_epoch_parallel
234+
else:
235+
return _nb_optimize_layout_euclidean_single_epoch
236+
237+
222238
def optimize_layout_euclidean(
223239
head_embedding,
224240
tail_embedding,
@@ -308,9 +324,10 @@ def optimize_layout_euclidean(
308324
epoch_of_next_negative_sample = epochs_per_negative_sample.copy()
309325
epoch_of_next_sample = epochs_per_sample.copy()
310326

311-
optimize_fn = numba.njit(
312-
_optimize_layout_euclidean_single_epoch, fastmath=True, parallel=parallel
313-
)
327+
# Fix for calling UMAP many times for small datasets, otherwise we spend here
328+
# a lot of time in compilation step (first call to numba function)
329+
optimize_fn = _get_optimize_layout_euclidean_single_epoch_fn(parallel)
330+
314331
if densmap_kwds is None:
315332
densmap_kwds = {}
316333
if tqdm_kwds is None:
@@ -352,7 +369,6 @@ def optimize_layout_euclidean(
352369
) + head_embedding[:, 0].astype(np.float64).view(np.int64).reshape(-1, 1)
353370

354371
for n in tqdm(range(n_epochs), **tqdm_kwds):
355-
356372
densmap_flag = (
357373
densmap
358374
and (densmap_kwds["lambda"] > 0)

0 commit comments

Comments
 (0)