Skip to content

Commit 5ed1f9d

Browse files
authored
Refactor BiVAECF code (PreferredAI#376)
1 parent b4f30f4 commit 5ed1f9d

2 files changed

Lines changed: 137 additions & 111 deletions

File tree

cornac/models/bivaecf/bivae.py

Lines changed: 82 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414
# ============================================================================
1515

16+
import itertools as it
17+
1618
import numpy as np
1719
import torch
1820
import torch.nn as nn
1921
from tqdm.auto import trange
2022

21-
torch.set_default_dtype(torch.float32)
2223

2324
EPS = 1e-10
2425

@@ -32,18 +33,25 @@
3233

3334

3435
class BiVAE(nn.Module):
35-
def __init__(self, k, user_e_structure, item_e_structure, act_fn, likelihood, cap_priors, feature_dim, batch_size):
36+
def __init__(
37+
self,
38+
k,
39+
user_encoder_structure,
40+
item_encoder_structure,
41+
act_fn,
42+
likelihood,
43+
cap_priors,
44+
feature_dim,
45+
batch_size,
46+
):
3647
super(BiVAE, self).__init__()
3748

38-
self.mu_theta = torch.zeros((item_e_structure[0], k)) # n_users*k
39-
self.mu_beta = torch.zeros((user_e_structure[0], k)) # n_items*k
49+
self.mu_theta = torch.zeros((item_encoder_structure[0], k)) # n_users*k
50+
self.mu_beta = torch.zeros((user_encoder_structure[0], k)) # n_items*k
4051

41-
# zero mean for standard normal prior
42-
self.mu_prior = torch.zeros((batch_size, k), requires_grad=False)
43-
44-
self.theta_ = torch.randn(item_e_structure[0], k) * 0.01
45-
self.beta_ = torch.randn(user_e_structure[0], k) * 0.01
46-
torch.nn.init.kaiming_uniform_(self.theta_, a=np.sqrt(5))
52+
self.theta = torch.randn(item_encoder_structure[0], k) * 0.01
53+
self.beta = torch.randn(user_encoder_structure[0], k) * 0.01
54+
torch.nn.init.kaiming_uniform_(self.theta, a=np.sqrt(5))
4755

4856
self.likelihood = likelihood
4957
self.act_fn = ACT.get(act_fn, None)
@@ -58,23 +66,32 @@ def __init__(self, k, user_e_structure, item_e_structure, act_fn, likelihood, ca
5866

5967
# User Encoder
6068
self.user_encoder = nn.Sequential()
61-
for i in range(len(user_e_structure) - 1):
69+
for i in range(len(user_encoder_structure) - 1):
6270
self.user_encoder.add_module(
63-
"fc{}".format(i), nn.Linear(user_e_structure[i], user_e_structure[i + 1])
71+
"fc{}".format(i),
72+
nn.Linear(user_encoder_structure[i], user_encoder_structure[i + 1]),
6473
)
6574
self.user_encoder.add_module("act{}".format(i), self.act_fn)
66-
self.user_mu = nn.Linear(user_e_structure[-1], k) # mu
67-
self.user_std = nn.Linear(user_e_structure[-1], k)
75+
self.user_mu = nn.Linear(user_encoder_structure[-1], k) # mu
76+
self.user_std = nn.Linear(user_encoder_structure[-1], k)
6877

6978
# Item Encoder
7079
self.item_encoder = nn.Sequential()
71-
for i in range(len(item_e_structure) - 1):
80+
for i in range(len(item_encoder_structure) - 1):
7281
self.item_encoder.add_module(
73-
"fc{}".format(i), nn.Linear(item_e_structure[i], item_e_structure[i + 1])
82+
"fc{}".format(i),
83+
nn.Linear(item_encoder_structure[i], item_encoder_structure[i + 1]),
7484
)
7585
self.item_encoder.add_module("act{}".format(i), self.act_fn)
76-
self.item_mu = nn.Linear(item_e_structure[-1], k) # mu
77-
self.item_std = nn.Linear(item_e_structure[-1], k)
86+
self.item_mu = nn.Linear(item_encoder_structure[-1], k) # mu
87+
self.item_std = nn.Linear(item_encoder_structure[-1], k)
88+
89+
def to(self, device):
90+
self.beta = self.beta.to(device=device)
91+
self.theta = self.theta.to(device=device)
92+
self.mu_beta = self.mu_beta.to(device=device)
93+
self.mu_theta = self.mu_theta.to(device=device)
94+
return super(BiVAE, self).to(device)
7895

7996
def encode_user_prior(self, x):
8097
h = self.user_prior_encoder(x)
@@ -137,71 +154,61 @@ def loss(self, x, x_, mu, mu_prior, std, kl_beta):
137154

138155

139156
def learn(
140-
bivae,
141-
train_set,
142-
n_epochs,
143-
batch_size,
144-
learn_rate,
145-
beta_kl,
146-
verbose,
147-
device=torch.device("cpu"),
157+
bivae,
158+
train_set,
159+
n_epochs,
160+
batch_size,
161+
learn_rate,
162+
beta_kl,
163+
verbose,
164+
device=torch.device("cpu"),
165+
dtype=torch.float32,
148166
):
149-
user_params = [{'params': bivae.user_encoder.parameters()}, \
150-
{'params': bivae.user_mu.parameters()}, \
151-
{'params': bivae.user_std.parameters()}, \
152-
]
153-
154-
item_params = [{'params': bivae.item_encoder.parameters()}, \
155-
{'params': bivae.item_mu.parameters()}, \
156-
{'params': bivae.item_std.parameters()}, \
157-
]
167+
user_params = it.chain(
168+
bivae.user_encoder.parameters(),
169+
bivae.user_mu.parameters(),
170+
bivae.user_std.parameters(),
171+
)
172+
173+
item_params = it.chain(
174+
bivae.item_encoder.parameters(),
175+
bivae.item_mu.parameters(),
176+
bivae.item_std.parameters(),
177+
)
158178

159179
if bivae.cap_priors.get("user", False):
160-
user_params.append({'params': bivae.user_prior_encoder.parameters()})
180+
user_params = it.chain(user_params, bivae.user_prior_encoder.parameters())
161181
user_features = train_set.user_feature.features[: train_set.num_users]
162-
user_features = torch.from_numpy(user_features).float().to(device)
163182

164183
if bivae.cap_priors.get("item", False):
165-
item_params.append({'params': bivae.item_prior_encoder.parameters()})
184+
item_params = it.chain(item_params, bivae.item_prior_encoder.parameters())
166185
item_features = train_set.item_feature.features[: train_set.num_items]
167-
item_features = torch.from_numpy(item_features).float().to(device)
168186

169187
u_optimizer = torch.optim.Adam(params=user_params, lr=learn_rate)
170188
i_optimizer = torch.optim.Adam(params=item_params, lr=learn_rate)
171189

172-
progress_bar = trange(1, n_epochs + 1, disable=not verbose)
173-
bivae.beta_ = bivae.beta_.to(device=device)
174-
bivae.theta_ = bivae.theta_.to(device=device)
175-
176-
bivae.mu_beta = bivae.mu_beta.to(device=device)
177-
bivae.mu_theta = bivae.mu_theta.to(device=device)
178-
179-
bivae.mu_prior = bivae.mu_prior.to(device=device)
180-
181190
x = train_set.matrix.copy()
182-
x.data = np.ones(len(x.data)) # Binarize data
191+
x.data = np.ones_like(x.data) # Binarize data
183192
tx = x.transpose()
184193

194+
progress_bar = trange(1, n_epochs + 1, disable=not verbose)
185195
for _ in progress_bar:
186-
187196
# item side
188197
i_sum_loss = 0.0
189198
i_count = 0
190-
for batch_id, i_ids in enumerate(
191-
train_set.item_iter(batch_size, shuffle=False)
192-
):
199+
for i_ids in train_set.item_iter(batch_size, shuffle=False):
193200
i_batch = tx[i_ids, :]
194201
i_batch = i_batch.A
195-
i_batch = torch.tensor(i_batch, dtype=torch.float32, device=device)
202+
i_batch = torch.tensor(i_batch, dtype=dtype, device=device)
196203

197204
# Reconstructed batch
198-
beta, i_batch_, i_mu, i_std = bivae(i_batch, user=False, theta=bivae.theta_)
205+
beta, i_batch_, i_mu, i_std = bivae(i_batch, user=False, theta=bivae.theta)
199206

207+
i_mu_prior = 0.0 # default: zero mean of the standard normal prior; overridden below when the item CAP prior is enabled
200208
if bivae.cap_priors.get("item", False):
201209
i_batch_f = item_features[i_ids]
210+
i_batch_f = torch.tensor(i_batch_f, dtype=dtype, device=device)
202211
i_mu_prior = bivae.encode_item_prior(i_batch_f)
203-
else:
204-
i_mu_prior = bivae.mu_prior[0:len(i_batch)]
205212

206213
i_loss = bivae.loss(i_batch, i_batch_, i_mu, i_mu_prior, i_std, beta_kl)
207214
i_optimizer.zero_grad()
@@ -211,29 +218,27 @@ def learn(
211218
i_sum_loss += i_loss.data.item()
212219
i_count += len(i_batch)
213220

214-
beta, _, i_mu, _ = bivae(i_batch, user=False, theta=bivae.theta_)
221+
beta, _, i_mu, _ = bivae(i_batch, user=False, theta=bivae.theta)
215222

216-
bivae.beta_.data[i_ids] = beta.data
223+
bivae.beta.data[i_ids] = beta.data
217224
bivae.mu_beta.data[i_ids] = i_mu.data
218225

219226
# user side
220227
u_sum_loss = 0.0
221228
u_count = 0
222-
for batch_id, u_ids in enumerate(
223-
train_set.user_iter(batch_size, shuffle=False)
224-
):
229+
for u_ids in train_set.user_iter(batch_size, shuffle=False):
225230
u_batch = x[u_ids, :]
226231
u_batch = u_batch.A
227-
u_batch = torch.tensor(u_batch, dtype=torch.float32, device=device)
232+
u_batch = torch.tensor(u_batch, dtype=dtype, device=device)
228233

229234
# Reconstructed batch
230-
theta, u_batch_, u_mu, u_std = bivae(u_batch, user=True, beta=bivae.beta_)
235+
theta, u_batch_, u_mu, u_std = bivae(u_batch, user=True, beta=bivae.beta)
231236

237+
u_mu_prior = 0.0 # default: zero mean of the standard normal prior; overridden below when the user CAP prior is enabled
232238
if bivae.cap_priors.get("user", False):
233239
u_batch_f = user_features[u_ids]
240+
u_batch_f = torch.tensor(u_batch_f, dtype=dtype, device=device)
234241
u_mu_prior = bivae.encode_user_prior(u_batch_f)
235-
else:
236-
u_mu_prior = bivae.mu_prior[0:len(u_batch)]
237242

238243
u_loss = bivae.loss(u_batch, u_batch_, u_mu, u_mu_prior, u_std, beta_kl)
239244
u_optimizer.zero_grad()
@@ -243,28 +248,30 @@ def learn(
243248
u_sum_loss += u_loss.data.item()
244249
u_count += len(u_batch)
245250

246-
theta, _, u_mu, _ = bivae(u_batch, user=True, beta=bivae.beta_)
247-
bivae.theta_.data[u_ids] = theta.data
251+
theta, _, u_mu, _ = bivae(u_batch, user=True, beta=bivae.beta)
252+
bivae.theta.data[u_ids] = theta.data
248253
bivae.mu_theta.data[u_ids] = u_mu.data
249254

250-
progress_bar.set_postfix(loss_i=(i_sum_loss / i_count), loss_u=(u_sum_loss / (u_count)))
255+
progress_bar.set_postfix(
256+
loss_i=(i_sum_loss / i_count), loss_u=(u_sum_loss / (u_count))
257+
)
251258

252259
# infer mu_beta
253-
for batch_id, i_ids in enumerate(train_set.item_iter(batch_size, shuffle=False)):
260+
for i_ids in train_set.item_iter(batch_size, shuffle=False):
254261
i_batch = tx[i_ids, :]
255262
i_batch = i_batch.A
256-
i_batch = torch.tensor(i_batch, dtype=torch.float32, device=device)
263+
i_batch = torch.tensor(i_batch, dtype=dtype, device=device)
257264

258-
beta, _, i_mu, _ = bivae(i_batch, user=False, theta=bivae.theta_)
265+
beta, _, i_mu, _ = bivae(i_batch, user=False, theta=bivae.theta)
259266
bivae.mu_beta.data[i_ids] = i_mu.data
260267

261268
# infer mu_theta
262-
for batch_id, u_ids in enumerate(train_set.user_iter(batch_size, shuffle=False)):
269+
for u_ids in train_set.user_iter(batch_size, shuffle=False):
263270
u_batch = x[u_ids, :]
264271
u_batch = u_batch.A
265-
u_batch = torch.tensor(u_batch, dtype=torch.float32, device=device)
272+
u_batch = torch.tensor(u_batch, dtype=dtype, device=device)
266273

267-
theta, _, u_mu, _ = bivae(u_batch, user=True, beta=bivae.beta_)
274+
theta, _, u_mu, _ = bivae(u_batch, user=True, beta=bivae.beta)
268275
bivae.mu_theta.data[u_ids] = u_mu.data
269276

270277
return bivae

0 commit comments

Comments
 (0)