@@ -139,7 +139,7 @@ def loss(self, x, x_, mu_qz, logvar_qz, mu_qhx, logvar_qhx, mu_qhy, logvar_qhy,
139139 ll_choices = {
140140 "mult" : x * torch .log (x_ + EPS ),
141141 "bern" : x * torch .log (x_ + EPS ) + (1 - x ) * torch .log (1 - x_ + EPS ),
142- "gaus" : - (x - x_ ) ** 2 ,
142+ "gaus" : - (( x - x_ ) ** 2 ) ,
143143 "pois" : x * torch .log (x_ + EPS ) - x_ ,
144144 }
145145
@@ -160,29 +160,34 @@ def loss(self, x, x_, mu_qz, logvar_qz, mu_qhx, logvar_qhx, mu_qhy, logvar_qhy,
160160 std_ph = torch .exp (0.5 * logvar_ph )
161161
162162 # KL(q(h|x)||p(h|x))
163- kld_hx = - 0.5 * (1 + 2.0 * torch .log (std_qhx ) - (mu_qhx - mu_ph ).pow (2 ) - std_qhx .pow (
164- 2 )) # assuming std_ph is 1 for now
163+ kld_hx = - 0.5 * (
164+ 1 + 2.0 * torch .log (std_qhx ) - (mu_qhx - mu_ph ).pow (2 ) - std_qhx .pow (2 )
165+ ) # assuming std_ph is 1 for now
165166 kld_hx = torch .sum (kld_hx , dim = 1 )
166167
167168 # KL(q(h|x)||q(h|y))
168- kld_hy = - 0.5 * (1 + 2.0 * torch .log (std_qhx ) - 2.0 * torch .log (std_qhy ) - (
169- (mu_qhx - mu_qhy ).pow (2 ) + std_qhx .pow (2 )) / std_qhy .pow (2 )) # assuming std_ph is 1 for now
169+ kld_hy = - 0.5 * (
170+ 1
171+ + 2.0 * torch .log (std_qhx )
172+ - 2.0 * torch .log (std_qhy )
173+ - ((mu_qhx - mu_qhy ).pow (2 ) + std_qhx .pow (2 )) / std_qhy .pow (2 )
174+ ) # assuming std_ph is 1 for now
170175 kld_hy = torch .sum (kld_hy , dim = 1 )
171176
172177 return torch .mean (beta * kld_z + alpha_1 * kld_hx + alpha_2 * kld_hy - ll )
173178
174179
175180def learn (
176- cvae ,
177- train_set ,
178- n_epochs ,
179- batch_size ,
180- learn_rate ,
181- beta ,
182- alpha_1 ,
183- alpha_2 ,
184- verbose ,
185- device = torch .device ("cpu" ),
181+ cvae ,
182+ train_set ,
183+ n_epochs ,
184+ batch_size ,
185+ learn_rate ,
186+ beta ,
187+ alpha_1 ,
188+ alpha_2 ,
189+ verbose ,
190+ device = torch .device ("cpu" ),
186191):
187192 optimizer = torch .optim .Adam (params = cvae .parameters (), lr = learn_rate )
188193
@@ -197,11 +202,11 @@ def learn(
197202 ):
198203 y_batch = y [u_ids , :]
199204 y_batch .data = np .ones (len (y_batch .data )) # Binarize data
200- y_batch = y_batch .A
205+ y_batch = y_batch .toarray ()
201206 y_batch = torch .tensor (y_batch , dtype = torch .float32 , device = device )
202207
203208 x_batch = x [u_ids , :]
204- x_batch = x_batch .A
209+ x_batch = x_batch .toarray ()
205210 x_batch = torch .tensor (x_batch , dtype = torch .float32 , device = device )
206211
207212 # Reconstructed batch
0 commit comments