@@ -26,6 +26,8 @@ namespace planc {
2626 std::unique_ptr<arma::mat> WT; // kxm
2727 double lambda, sqrtLambda, objective_err;
2828 bool cleared;
29+ std::unique_ptr<arma::sp_mat> tempE;
30+ void load_whole_E (arma::uword i);
2931
3032 virtual double computeObjectiveError () {
3133 // obj_i = ||E_i - (W + V_i)*H_i||_F^2 + lambda * ||V_i*H_i||_F^2
@@ -39,16 +41,20 @@ namespace planc {
3941 arma::mat* Wptr = this ->W .get ();
4042 arma::mat L (this ->m , this ->k ); // (loading) L = W + V
4143 for (arma::uword i = 0 ; i < this ->nDatasets ; ++i) {
42- T* Eptr = this ->Ei [i].get ();
44+ // T* Eptr = this->Ei[i].get();
45+ // arma::sp_mat* Eptr = this->load_whole_E(i);
46+ this ->load_whole_E (i);
47+ arma::sp_mat* Eptr = this ->tempE .get ();
4348 arma::mat* Hptr = this ->Hi [i].get ();
4449 arma::mat* Vptr = this ->Vi [i].get ();
4550 L = *Wptr + *Vptr;
46- double sqnormE = arma::norm<T >(*Eptr, " fro" );
51+ double sqnormE = arma::norm<arma::sp_mat >(*Eptr, " fro" );
4752 sqnormE *= sqnormE;
4853 arma::mat LtL = L.t () * L; // k x k
4954 arma::mat HtH = Hptr->t () * *Hptr; // k x k
5055 double TrLtLHtH = arma::trace (LtL * HtH);
51- T Et = Eptr->t ();
56+ // T Et = Eptr->t();
57+ arma::sp_mat Et = Eptr->t ();
5258 arma::mat EtL = Et * L; // n_i x k
5359 double TrHtEtL = arma::trace (Hptr->t () * EtL);
5460 arma::mat VtV = Vptr->t () * *Vptr; // k x k
@@ -359,8 +365,75 @@ namespace planc {
359365 }
360366 this ->W .reset ();
361367 if (this ->WT .get () != nullptr ) this ->WT .reset ();
368+ this ->tempE .reset ();
362369 }
363370 this ->cleared = true ;
364371 }
365372 }; // class INMF
373+
374+ template <>
375+ inline void INMF<arma::sp_mat>::load_whole_E(arma::uword i) {
376+ // Make a real matrix copied from what's pointed by this->Ei[i]
377+ arma::sp_mat tempSparse = *(this ->Ei [i].get ());
378+ // Make new unique_ptr to manage the memory
379+ auto tempSparseUniquePtr = std::make_unique<arma::sp_mat>(tempSparse);
380+ this ->tempE = std::move (tempSparseUniquePtr);
381+ // return this->Ei[i].get();
382+ }
383+
384+ template <>
385+ inline void INMF<arma::mat>::load_whole_E(arma::uword i) {
386+ auto tempSparse = std::make_unique<arma::sp_mat>(*(this ->Ei [i]));
387+
388+ // Get raw pointer before transferring ownership
389+ // arma::sp_mat* rawPtr = tempSparse.get();
390+
391+ // Move unique_ptr to a class member to ensure it stays alive
392+ this ->tempE = std::move (tempSparse);
393+
394+ // return rawPtr;
395+ }
396+
397+ template <>
398+ inline void INMF<H5SpMat>::load_whole_E(arma::uword i) {
399+ H5SpMat* Eptr = this ->Ei [i].get ();
400+ // arma::uword m = Eptr->n_rows;
401+ arma::uword n = Eptr->n_cols ;
402+ this ->tempE .reset ();
403+ arma::sp_mat tempSparse = Eptr->cols (0 , n - 1 );
404+ auto tempSparseUniquePtr = std::make_unique<arma::sp_mat>(tempSparse);
405+ // auto tempSparse = std::make_unique<arma::sp_mat>(m, n);
406+ // (*tempSparse) = Eptr->cols(0, n - 1);
407+ // Get raw pointer before transferring ownership
408+ // arma::sp_mat* rawPtr = tempSparseUniquePtr.get();
409+
410+ // Move unique_ptr to a class member to ensure it stays alive
411+ this ->tempE = std::move (tempSparseUniquePtr);
412+ // return rawPtr;
413+ }
414+
415+ template <>
416+ inline void INMF<H5Mat>::load_whole_E(arma::uword i) {
417+ H5Mat* Eptr = this ->Ei [i].get ();
418+ // arma::uword m = Eptr->n_rows;
419+ arma::uword n = Eptr->n_cols ;
420+ auto out = std::make_unique<arma::sp_mat>(m, n);
421+ // Load on-disk dense matrix by chunks and convert to sparse and fill
422+ int numChunks = n / this ->INMF_CHUNK_SIZE ;
423+ if (numChunks * this ->INMF_CHUNK_SIZE < n) numChunks++;
424+ for (int i = 0 ; i < numChunks; ++i) {
425+ int spanStart = i * this ->INMF_CHUNK_SIZE ;
426+ int spanEnd = (i + 1 ) * this ->INMF_CHUNK_SIZE - 1 ;
427+ if (spanEnd > n - 1 ) spanEnd = n - 1 ;
428+ arma::mat dense_span = Eptr->cols (spanStart, spanEnd);
429+ arma::sp_mat sparse_span (dense_span);
430+ (*out).cols (spanStart, spanEnd) = sparse_span;
431+ }
432+ // Get raw pointer before transferring ownership
433+ // arma::sp_mat* rawPtr = out.get();
434+ // Move unique_ptr to a class member to ensure it stays alive
435+ this ->tempE = std::move (out);
436+ // return rawPtr;
437+ }
438+
366439}
0 commit comments