Skip to content

Commit 3741b9f

Browse files
committed
refactor
1 parent 769804b commit 3741b9f

File tree

1 file changed

+47
-48
lines changed

1 file changed

+47
-48
lines changed

cornac/models/sansa/recom_sansa.py

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,6 @@
11
import numpy as np
22
import scipy.sparse as sp
33

4-
from sansa.core import (
5-
FactorizationMethod,
6-
GramianFactorizer,
7-
CHOLMODGramianFactorizerConfig,
8-
ICFGramianFactorizerConfig,
9-
UnitLowerTriangleInverter,
10-
UMRUnitLowerTriangleInverterConfig,
11-
)
12-
from sansa.utils import get_squared_norms_along_compressed_axis, inplace_scale_along_compressed_axis, inplace_scale_along_uncompressed_axis
13-
144
from ..recommender import Recommender
155
from ..recommender import ANNMixin, MEASURE_DOT
166
from ...exception import ScoreException
@@ -108,30 +98,17 @@ def __init__(
10898
self.l2 = l2
10999
self.weight_matrix_density = weight_matrix_density
110100
self.compute_gramian = compute_gramian
101+
self.factorizer_class = factorizer_class
102+
self.factorizer_shift_step = factorizer_shift_step
103+
self.factorizer_shift_multiplier = factorizer_shift_multiplier
104+
self.inverter_scans = inverter_scans
105+
self.inverter_finetune_steps = inverter_finetune_steps
111106
self.use_absolute_value_scores = use_absolute_value_scores
112107
self.verbose = verbose
113108
self.seed = seed
114-
self.X = X
115-
if self.X is not None:
116-
self.X = self.X.astype(np.float32)
109+
self.X = X.astype(np.float32) if X is not None and X.dtype != np.float32 else X
117110
self.weights = (W1, W2)
118111

119-
if factorizer_class == "CHOLMOD":
120-
self.factorizer_config = CHOLMODGramianFactorizerConfig()
121-
else:
122-
self.factorizer_config = ICFGramianFactorizerConfig(
123-
factorization_shift_step=factorizer_shift_step, # initial diagonal shift if incomplete factorization fails
124-
factorization_shift_multiplier=factorizer_shift_multiplier, # multiplier for the shift for subsequent attempts
125-
)
126-
self.factorizer = GramianFactorizer.from_config(self.factorizer_config)
127-
self.factorization_method = self.factorizer_config.factorization_method
128-
129-
self.inverter_config = UMRUnitLowerTriangleInverterConfig(
130-
scans=inverter_scans, # number of scans through all columns of the matrix
131-
finetune_steps=inverter_finetune_steps, # number of finetuning steps, targeting worst columns
132-
)
133-
self.inverter = UnitLowerTriangleInverter.from_config(self.inverter_config)
134-
135112
def fit(self, train_set, val_set=None):
136113
"""Fit the model to observations.
137114
@@ -149,15 +126,55 @@ def fit(self, train_set, val_set=None):
149126
"""
150127
Recommender.fit(self, train_set, val_set)
151128

129+
from sansa.core import (
130+
FactorizationMethod,
131+
GramianFactorizer,
132+
CHOLMODGramianFactorizerConfig,
133+
ICFGramianFactorizerConfig,
134+
UnitLowerTriangleInverter,
135+
UMRUnitLowerTriangleInverterConfig,
136+
)
137+
from sansa.utils import get_squared_norms_along_compressed_axis, inplace_scale_along_compressed_axis, inplace_scale_along_uncompressed_axis
138+
152139
# User-item interaction matrix (sp.csr_matrix)
153140
self.X = train_set.matrix.astype(np.float32)
154141

142+
if self.factorizer_class == "CHOLMOD":
143+
self.factorizer_config = CHOLMODGramianFactorizerConfig()
144+
else:
145+
self.factorizer_config = ICFGramianFactorizerConfig(
146+
factorization_shift_step=self.factorizer_shift_step, # initial diagonal shift if incomplete factorization fails
147+
factorization_shift_multiplier=self.factorizer_shift_multiplier, # multiplier for the shift for subsequent attempts
148+
)
149+
self.factorizer = GramianFactorizer.from_config(self.factorizer_config)
150+
self.factorization_method = self.factorizer_config.factorization_method
151+
152+
self.inverter_config = UMRUnitLowerTriangleInverterConfig(
153+
scans=self.inverter_scans, # number of scans through all columns of the matrix
154+
finetune_steps=self.inverter_finetune_steps, # number of finetuning steps, targeting worst columns
155+
)
156+
self.inverter = UnitLowerTriangleInverter.from_config(self.inverter_config)
157+
155158
# create a working copy of user_item_matrix
156159
X = self.X.copy()
157160

158161
if self.factorization_method == FactorizationMethod.ICF:
159162
# scale matrix X
160-
_apply_icf_scaling(X, self.compute_gramian)
163+
if self.compute_gramian:
164+
# Inplace scale columns of X by square roots of column norms of X^TX.
165+
da = np.sqrt(np.sqrt(get_squared_norms_along_compressed_axis(X.T @ X)))
166+
# Divide columns of X by the computed square roots of column norms of X^TX
167+
da[da == 0] = 1 # ignore zero elements
168+
inplace_scale_along_uncompressed_axis(X, 1 / da) # CSR column scaling
169+
del da
170+
else:
171+
# Inplace scale rows and columns of X by square roots of row norms of X.
172+
da = np.sqrt(np.sqrt(get_squared_norms_along_compressed_axis(X)))
173+
# Divide rows and columns of X by the computed square roots of row norms of X
174+
da[da == 0] = 1 # ignore zero elements
175+
inplace_scale_along_uncompressed_axis(X, 1 / da) # CSR column scaling
176+
inplace_scale_along_compressed_axis(X, 1 / da) # CSR row scaling
177+
del da
161178

162179
# Compute LDL^T decomposition of
163180
# - P(X^TX + self.l2 * I)P^T if compute_gramian=True
@@ -270,21 +287,3 @@ def get_item_vectors(self):
270287
Matrix of item vectors for all items available in the model.
271288
"""
272289
return self.self.weights[1]
273-
274-
275-
def _apply_icf_scaling(X: sp.csr_matrix, compute_gramian: bool) -> None:
276-
if compute_gramian:
277-
# Inplace scale columns of X by square roots of column norms of X^TX.
278-
da = np.sqrt(np.sqrt(get_squared_norms_along_compressed_axis(X.T @ X)))
279-
# Divide columns of X by the computed square roots of column norms of X^TX
280-
da[da == 0] = 1 # ignore zero elements
281-
inplace_scale_along_uncompressed_axis(X, 1 / da) # CSR column scaling
282-
del da
283-
else:
284-
# Inplace scale rows and columns of X by square roots of row norms of X.
285-
da = np.sqrt(np.sqrt(get_squared_norms_along_compressed_axis(X)))
286-
# Divide rows and columns of X by the computed square roots of row norms of X
287-
da[da == 0] = 1 # ignore zero elements
288-
inplace_scale_along_uncompressed_axis(X, 1 / da) # CSR column scaling
289-
inplace_scale_along_compressed_axis(X, 1 / da) # CSR row scaling
290-
del da

0 commit comments

Comments
 (0)