1313# limitations under the License.
1414# ============================================================================
1515
16+ import itertools as it
17+
1618import numpy as np
1719import torch
1820import torch .nn as nn
1921from tqdm .auto import trange
2022
21- torch .set_default_dtype (torch .float32 )
2223
2425 EPS = 1e-10  # NOTE(review): presumably a numerical-stability epsilon (e.g. inside a log in the likelihood/loss); its use site is not visible in this chunk — confirm against the full file
2425
3233
3334
3435class BiVAE (nn .Module ):
35- def __init__ (self , k , user_e_structure , item_e_structure , act_fn , likelihood , cap_priors , feature_dim , batch_size ):
36+ def __init__ (
37+ self ,
38+ k ,
39+ user_encoder_structure ,
40+ item_encoder_structure ,
41+ act_fn ,
42+ likelihood ,
43+ cap_priors ,
44+ feature_dim ,
45+ batch_size ,
46+ ):
3647 super (BiVAE , self ).__init__ ()
3748
38- self .mu_theta = torch .zeros ((item_e_structure [0 ], k )) # n_users*k
39- self .mu_beta = torch .zeros ((user_e_structure [0 ], k )) # n_items*k
49+ self .mu_theta = torch .zeros ((item_encoder_structure [0 ], k )) # n_users*k
50+ self .mu_beta = torch .zeros ((user_encoder_structure [0 ], k )) # n_items*k
4051
41- # zero mean for standard normal prior
42- self .mu_prior = torch .zeros ((batch_size , k ), requires_grad = False )
43-
44- self .theta_ = torch .randn (item_e_structure [0 ], k ) * 0.01
45- self .beta_ = torch .randn (user_e_structure [0 ], k ) * 0.01
46- torch .nn .init .kaiming_uniform_ (self .theta_ , a = np .sqrt (5 ))
52+ self .theta = torch .randn (item_encoder_structure [0 ], k ) * 0.01
53+ self .beta = torch .randn (user_encoder_structure [0 ], k ) * 0.01
54+ torch .nn .init .kaiming_uniform_ (self .theta , a = np .sqrt (5 ))
4755
4856 self .likelihood = likelihood
4957 self .act_fn = ACT .get (act_fn , None )
@@ -58,23 +66,32 @@ def __init__(self, k, user_e_structure, item_e_structure, act_fn, likelihood, ca
5866
5967 # User Encoder
6068 self .user_encoder = nn .Sequential ()
61- for i in range (len (user_e_structure ) - 1 ):
69+ for i in range (len (user_encoder_structure ) - 1 ):
6270 self .user_encoder .add_module (
63- "fc{}" .format (i ), nn .Linear (user_e_structure [i ], user_e_structure [i + 1 ])
71+ "fc{}" .format (i ),
72+ nn .Linear (user_encoder_structure [i ], user_encoder_structure [i + 1 ]),
6473 )
6574 self .user_encoder .add_module ("act{}" .format (i ), self .act_fn )
66- self .user_mu = nn .Linear (user_e_structure [- 1 ], k ) # mu
67- self .user_std = nn .Linear (user_e_structure [- 1 ], k )
75+ self .user_mu = nn .Linear (user_encoder_structure [- 1 ], k ) # mu
76+ self .user_std = nn .Linear (user_encoder_structure [- 1 ], k )
6877
6978 # Item Encoder
7079 self .item_encoder = nn .Sequential ()
71- for i in range (len (item_e_structure ) - 1 ):
80+ for i in range (len (item_encoder_structure ) - 1 ):
7281 self .item_encoder .add_module (
73- "fc{}" .format (i ), nn .Linear (item_e_structure [i ], item_e_structure [i + 1 ])
82+ "fc{}" .format (i ),
83+ nn .Linear (item_encoder_structure [i ], item_encoder_structure [i + 1 ]),
7484 )
7585 self .item_encoder .add_module ("act{}" .format (i ), self .act_fn )
76- self .item_mu = nn .Linear (item_e_structure [- 1 ], k ) # mu
77- self .item_std = nn .Linear (item_e_structure [- 1 ], k )
86+ self .item_mu = nn .Linear (item_encoder_structure [- 1 ], k ) # mu
87+ self .item_std = nn .Linear (item_encoder_structure [- 1 ], k )
88+
89+ def to (self , device ):
90+ self .beta = self .beta .to (device = device )
91+ self .theta = self .theta .to (device = device )
92+ self .mu_beta = self .mu_beta .to (device = device )
93+ self .mu_theta = self .mu_theta .to (device = device )
94+ return super (BiVAE , self ).to (device )
7895
7996 def encode_user_prior (self , x ):
8097 h = self .user_prior_encoder (x )
@@ -137,71 +154,61 @@ def loss(self, x, x_, mu, mu_prior, std, kl_beta):
137154
138155
139156def learn (
140- bivae ,
141- train_set ,
142- n_epochs ,
143- batch_size ,
144- learn_rate ,
145- beta_kl ,
146- verbose ,
147- device = torch .device ("cpu" ),
157+ bivae ,
158+ train_set ,
159+ n_epochs ,
160+ batch_size ,
161+ learn_rate ,
162+ beta_kl ,
163+ verbose ,
164+ device = torch .device ("cpu" ),
165+ dtype = torch .float32 ,
148166):
149- user_params = [{'params' : bivae .user_encoder .parameters ()}, \
150- {'params' : bivae .user_mu .parameters ()}, \
151- {'params' : bivae .user_std .parameters ()}, \
152- ]
153-
154- item_params = [{'params' : bivae .item_encoder .parameters ()}, \
155- {'params' : bivae .item_mu .parameters ()}, \
156- {'params' : bivae .item_std .parameters ()}, \
157- ]
167+ user_params = it .chain (
168+ bivae .user_encoder .parameters (),
169+ bivae .user_mu .parameters (),
170+ bivae .user_std .parameters (),
171+ )
172+
173+ item_params = it .chain (
174+ bivae .item_encoder .parameters (),
175+ bivae .item_mu .parameters (),
176+ bivae .item_std .parameters (),
177+ )
158178
159179 if bivae .cap_priors .get ("user" , False ):
160- user_params . append ({ 'params' : bivae .user_prior_encoder .parameters ()} )
180+ user_params = it . chain ( user_params , bivae .user_prior_encoder .parameters ())
161181 user_features = train_set .user_feature .features [: train_set .num_users ]
162- user_features = torch .from_numpy (user_features ).float ().to (device )
163182
164183 if bivae .cap_priors .get ("item" , False ):
165- item_params . append ({ 'params' : bivae .item_prior_encoder .parameters ()} )
184+ item_params = it . chain ( item_params , bivae .item_prior_encoder .parameters ())
166185 item_features = train_set .item_feature .features [: train_set .num_items ]
167- item_features = torch .from_numpy (item_features ).float ().to (device )
168186
169187 u_optimizer = torch .optim .Adam (params = user_params , lr = learn_rate )
170188 i_optimizer = torch .optim .Adam (params = item_params , lr = learn_rate )
171189
172- progress_bar = trange (1 , n_epochs + 1 , disable = not verbose )
173- bivae .beta_ = bivae .beta_ .to (device = device )
174- bivae .theta_ = bivae .theta_ .to (device = device )
175-
176- bivae .mu_beta = bivae .mu_beta .to (device = device )
177- bivae .mu_theta = bivae .mu_theta .to (device = device )
178-
179- bivae .mu_prior = bivae .mu_prior .to (device = device )
180-
181190 x = train_set .matrix .copy ()
182- x .data = np .ones ( len ( x .data ) ) # Binarize data
191+ x .data = np .ones_like ( x .data ) # Binarize data
183192 tx = x .transpose ()
184193
194+ progress_bar = trange (1 , n_epochs + 1 , disable = not verbose )
185195 for _ in progress_bar :
186-
187196 # item side
188197 i_sum_loss = 0.0
189198 i_count = 0
190- for batch_id , i_ids in enumerate (
191- train_set .item_iter (batch_size , shuffle = False )
192- ):
199+ for i_ids in train_set .item_iter (batch_size , shuffle = False ):
193200 i_batch = tx [i_ids , :]
194201 i_batch = i_batch .A
195- i_batch = torch .tensor (i_batch , dtype = torch . float32 , device = device )
202+ i_batch = torch .tensor (i_batch , dtype = dtype , device = device )
196203
197204 # Reconstructed batch
198- beta , i_batch_ , i_mu , i_std = bivae (i_batch , user = False , theta = bivae .theta_ )
205+ beta , i_batch_ , i_mu , i_std = bivae (i_batch , user = False , theta = bivae .theta )
199206
207+ i_mu_prior = 0.0 # zero mean for standard normal prior if not CAP prior
200208 if bivae .cap_priors .get ("item" , False ):
201209 i_batch_f = item_features [i_ids ]
210+ i_batch_f = torch .tensor (i_batch_f , dtype = dtype , device = device )
202211 i_mu_prior = bivae .encode_item_prior (i_batch_f )
203- else :
204- i_mu_prior = bivae .mu_prior [0 :len (i_batch )]
205212
206213 i_loss = bivae .loss (i_batch , i_batch_ , i_mu , i_mu_prior , i_std , beta_kl )
207214 i_optimizer .zero_grad ()
@@ -211,29 +218,27 @@ def learn(
211218 i_sum_loss += i_loss .data .item ()
212219 i_count += len (i_batch )
213220
214- beta , _ , i_mu , _ = bivae (i_batch , user = False , theta = bivae .theta_ )
221+ beta , _ , i_mu , _ = bivae (i_batch , user = False , theta = bivae .theta )
215222
216- bivae .beta_ .data [i_ids ] = beta .data
223+ bivae .beta .data [i_ids ] = beta .data
217224 bivae .mu_beta .data [i_ids ] = i_mu .data
218225
219226 # user side
220227 u_sum_loss = 0.0
221228 u_count = 0
222- for batch_id , u_ids in enumerate (
223- train_set .user_iter (batch_size , shuffle = False )
224- ):
229+ for u_ids in train_set .user_iter (batch_size , shuffle = False ):
225230 u_batch = x [u_ids , :]
226231 u_batch = u_batch .A
227- u_batch = torch .tensor (u_batch , dtype = torch . float32 , device = device )
232+ u_batch = torch .tensor (u_batch , dtype = dtype , device = device )
228233
229234 # Reconstructed batch
230- theta , u_batch_ , u_mu , u_std = bivae (u_batch , user = True , beta = bivae .beta_ )
235+ theta , u_batch_ , u_mu , u_std = bivae (u_batch , user = True , beta = bivae .beta )
231236
237+ u_mu_prior = 0.0 # zero mean for standard normal prior if not CAP prior
232238 if bivae .cap_priors .get ("user" , False ):
233239 u_batch_f = user_features [u_ids ]
240+ u_batch_f = torch .tensor (u_batch_f , dtype = dtype , device = device )
234241 u_mu_prior = bivae .encode_user_prior (u_batch_f )
235- else :
236- u_mu_prior = bivae .mu_prior [0 :len (u_batch )]
237242
238243 u_loss = bivae .loss (u_batch , u_batch_ , u_mu , u_mu_prior , u_std , beta_kl )
239244 u_optimizer .zero_grad ()
@@ -243,28 +248,30 @@ def learn(
243248 u_sum_loss += u_loss .data .item ()
244249 u_count += len (u_batch )
245250
246- theta , _ , u_mu , _ = bivae (u_batch , user = True , beta = bivae .beta_ )
247- bivae .theta_ .data [u_ids ] = theta .data
251+ theta , _ , u_mu , _ = bivae (u_batch , user = True , beta = bivae .beta )
252+ bivae .theta .data [u_ids ] = theta .data
248253 bivae .mu_theta .data [u_ids ] = u_mu .data
249254
250- progress_bar .set_postfix (loss_i = (i_sum_loss / i_count ), loss_u = (u_sum_loss / (u_count )))
255+ progress_bar .set_postfix (
256+ loss_i = (i_sum_loss / i_count ), loss_u = (u_sum_loss / (u_count ))
257+ )
251258
252259 # infer mu_beta
253- for batch_id , i_ids in enumerate ( train_set .item_iter (batch_size , shuffle = False ) ):
260+ for i_ids in train_set .item_iter (batch_size , shuffle = False ):
254261 i_batch = tx [i_ids , :]
255262 i_batch = i_batch .A
256- i_batch = torch .tensor (i_batch , dtype = torch . float32 , device = device )
263+ i_batch = torch .tensor (i_batch , dtype = dtype , device = device )
257264
258- beta , _ , i_mu , _ = bivae (i_batch , user = False , theta = bivae .theta_ )
265+ beta , _ , i_mu , _ = bivae (i_batch , user = False , theta = bivae .theta )
259266 bivae .mu_beta .data [i_ids ] = i_mu .data
260267
261268 # infer mu_theta
262- for batch_id , u_ids in enumerate ( train_set .user_iter (batch_size , shuffle = False ) ):
269+ for u_ids in train_set .user_iter (batch_size , shuffle = False ):
263270 u_batch = x [u_ids , :]
264271 u_batch = u_batch .A
265- u_batch = torch .tensor (u_batch , dtype = torch . float32 , device = device )
272+ u_batch = torch .tensor (u_batch , dtype = dtype , device = device )
266273
267- theta , _ , u_mu , _ = bivae (u_batch , user = True , beta = bivae .beta_ )
274+ theta , _ , u_mu , _ = bivae (u_batch , user = True , beta = bivae .beta )
268275 bivae .mu_theta .data [u_ids ] = u_mu .data
269276
270277 return bivae
0 commit comments