Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 82 additions & 75 deletions cornac/models/bivaecf/bivae.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# limitations under the License.
# ============================================================================

import itertools as it

import numpy as np
import torch
import torch.nn as nn
from tqdm.auto import trange

torch.set_default_dtype(torch.float32)

EPS = 1e-10

Expand All @@ -32,18 +33,25 @@


class BiVAE(nn.Module):
def __init__(self, k, user_e_structure, item_e_structure, act_fn, likelihood, cap_priors, feature_dim, batch_size):
def __init__(
self,
k,
user_encoder_structure,
item_encoder_structure,
act_fn,
likelihood,
cap_priors,
feature_dim,
batch_size,
):
super(BiVAE, self).__init__()

self.mu_theta = torch.zeros((item_e_structure[0], k)) # n_users*k
self.mu_beta = torch.zeros((user_e_structure[0], k)) # n_items*k
self.mu_theta = torch.zeros((item_encoder_structure[0], k)) # n_users*k
self.mu_beta = torch.zeros((user_encoder_structure[0], k)) # n_items*k

# zero mean for standard normal prior
self.mu_prior = torch.zeros((batch_size, k), requires_grad=False)

self.theta_ = torch.randn(item_e_structure[0], k) * 0.01
self.beta_ = torch.randn(user_e_structure[0], k) * 0.01
torch.nn.init.kaiming_uniform_(self.theta_, a=np.sqrt(5))
self.theta = torch.randn(item_encoder_structure[0], k) * 0.01
self.beta = torch.randn(user_encoder_structure[0], k) * 0.01
torch.nn.init.kaiming_uniform_(self.theta, a=np.sqrt(5))

self.likelihood = likelihood
self.act_fn = ACT.get(act_fn, None)
Expand All @@ -58,23 +66,32 @@ def __init__(self, k, user_e_structure, item_e_structure, act_fn, likelihood, ca

# User Encoder
self.user_encoder = nn.Sequential()
for i in range(len(user_e_structure) - 1):
for i in range(len(user_encoder_structure) - 1):
self.user_encoder.add_module(
"fc{}".format(i), nn.Linear(user_e_structure[i], user_e_structure[i + 1])
"fc{}".format(i),
nn.Linear(user_encoder_structure[i], user_encoder_structure[i + 1]),
)
self.user_encoder.add_module("act{}".format(i), self.act_fn)
self.user_mu = nn.Linear(user_e_structure[-1], k) # mu
self.user_std = nn.Linear(user_e_structure[-1], k)
self.user_mu = nn.Linear(user_encoder_structure[-1], k) # mu
self.user_std = nn.Linear(user_encoder_structure[-1], k)

# Item Encoder
self.item_encoder = nn.Sequential()
for i in range(len(item_e_structure) - 1):
for i in range(len(item_encoder_structure) - 1):
self.item_encoder.add_module(
"fc{}".format(i), nn.Linear(item_e_structure[i], item_e_structure[i + 1])
"fc{}".format(i),
nn.Linear(item_encoder_structure[i], item_encoder_structure[i + 1]),
)
self.item_encoder.add_module("act{}".format(i), self.act_fn)
self.item_mu = nn.Linear(item_e_structure[-1], k) # mu
self.item_std = nn.Linear(item_e_structure[-1], k)
self.item_mu = nn.Linear(item_encoder_structure[-1], k) # mu
self.item_std = nn.Linear(item_encoder_structure[-1], k)

def to(self, device):
    """Move the model (and its unregistered latent tensors) to *device*.

    Overrides ``nn.Module.to`` because ``beta``, ``theta``, ``mu_beta`` and
    ``mu_theta`` are created as plain tensors in ``__init__`` (via
    ``torch.zeros`` / ``torch.randn``), not registered as parameters or
    buffers, so the default ``nn.Module.to`` would leave them behind.

    Parameters
    ----------
    device : torch.device or str
        Target device for all tensors and module parameters.

    Returns
    -------
    BiVAE
        ``self``, after the base-class ``to`` has moved the registered
        parameters/buffers.
    """
    # tensor.to() returns a (possibly new) tensor on the target device,
    # so each attribute must be rebound explicitly.
    self.beta = self.beta.to(device=device)
    self.theta = self.theta.to(device=device)
    self.mu_beta = self.mu_beta.to(device=device)
    self.mu_theta = self.mu_theta.to(device=device)
    # Delegate to nn.Module.to for the registered encoder parameters.
    return super(BiVAE, self).to(device)

def encode_user_prior(self, x):
h = self.user_prior_encoder(x)
Expand Down Expand Up @@ -137,71 +154,61 @@ def loss(self, x, x_, mu, mu_prior, std, kl_beta):


def learn(
bivae,
train_set,
n_epochs,
batch_size,
learn_rate,
beta_kl,
verbose,
device=torch.device("cpu"),
bivae,
train_set,
n_epochs,
batch_size,
learn_rate,
beta_kl,
verbose,
device=torch.device("cpu"),
dtype=torch.float32,
):
user_params = [{'params': bivae.user_encoder.parameters()}, \
{'params': bivae.user_mu.parameters()}, \
{'params': bivae.user_std.parameters()}, \
]

item_params = [{'params': bivae.item_encoder.parameters()}, \
{'params': bivae.item_mu.parameters()}, \
{'params': bivae.item_std.parameters()}, \
]
user_params = it.chain(
bivae.user_encoder.parameters(),
bivae.user_mu.parameters(),
bivae.user_std.parameters(),
)

item_params = it.chain(
bivae.item_encoder.parameters(),
bivae.item_mu.parameters(),
bivae.item_std.parameters(),
)

if bivae.cap_priors.get("user", False):
user_params.append({'params': bivae.user_prior_encoder.parameters()})
user_params = it.chain(user_params, bivae.user_prior_encoder.parameters())
user_features = train_set.user_feature.features[: train_set.num_users]
user_features = torch.from_numpy(user_features).float().to(device)

if bivae.cap_priors.get("item", False):
item_params.append({'params': bivae.item_prior_encoder.parameters()})
item_params = it.chain(item_params, bivae.item_prior_encoder.parameters())
item_features = train_set.item_feature.features[: train_set.num_items]
item_features = torch.from_numpy(item_features).float().to(device)

u_optimizer = torch.optim.Adam(params=user_params, lr=learn_rate)
i_optimizer = torch.optim.Adam(params=item_params, lr=learn_rate)

progress_bar = trange(1, n_epochs + 1, disable=not verbose)
bivae.beta_ = bivae.beta_.to(device=device)
bivae.theta_ = bivae.theta_.to(device=device)

bivae.mu_beta = bivae.mu_beta.to(device=device)
bivae.mu_theta = bivae.mu_theta.to(device=device)

bivae.mu_prior = bivae.mu_prior.to(device=device)

x = train_set.matrix.copy()
x.data = np.ones(len(x.data)) # Binarize data
x.data = np.ones_like(x.data) # Binarize data
tx = x.transpose()

progress_bar = trange(1, n_epochs + 1, disable=not verbose)
for _ in progress_bar:

# item side
i_sum_loss = 0.0
i_count = 0
for batch_id, i_ids in enumerate(
train_set.item_iter(batch_size, shuffle=False)
):
for i_ids in train_set.item_iter(batch_size, shuffle=False):
i_batch = tx[i_ids, :]
i_batch = i_batch.A
i_batch = torch.tensor(i_batch, dtype=torch.float32, device=device)
i_batch = torch.tensor(i_batch, dtype=dtype, device=device)

# Reconstructed batch
beta, i_batch_, i_mu, i_std = bivae(i_batch, user=False, theta=bivae.theta_)
beta, i_batch_, i_mu, i_std = bivae(i_batch, user=False, theta=bivae.theta)

i_mu_prior = 0.0 # zero mean for standard normal prior if not CAP prior
if bivae.cap_priors.get("item", False):
i_batch_f = item_features[i_ids]
i_batch_f = torch.tensor(i_batch_f, dtype=dtype, device=device)
i_mu_prior = bivae.encode_item_prior(i_batch_f)
else:
i_mu_prior = bivae.mu_prior[0:len(i_batch)]

i_loss = bivae.loss(i_batch, i_batch_, i_mu, i_mu_prior, i_std, beta_kl)
i_optimizer.zero_grad()
Expand All @@ -211,29 +218,27 @@ def learn(
i_sum_loss += i_loss.data.item()
i_count += len(i_batch)

beta, _, i_mu, _ = bivae(i_batch, user=False, theta=bivae.theta_)
beta, _, i_mu, _ = bivae(i_batch, user=False, theta=bivae.theta)

bivae.beta_.data[i_ids] = beta.data
bivae.beta.data[i_ids] = beta.data
bivae.mu_beta.data[i_ids] = i_mu.data

# user side
u_sum_loss = 0.0
u_count = 0
for batch_id, u_ids in enumerate(
train_set.user_iter(batch_size, shuffle=False)
):
for u_ids in train_set.user_iter(batch_size, shuffle=False):
u_batch = x[u_ids, :]
u_batch = u_batch.A
u_batch = torch.tensor(u_batch, dtype=torch.float32, device=device)
u_batch = torch.tensor(u_batch, dtype=dtype, device=device)

# Reconstructed batch
theta, u_batch_, u_mu, u_std = bivae(u_batch, user=True, beta=bivae.beta_)
theta, u_batch_, u_mu, u_std = bivae(u_batch, user=True, beta=bivae.beta)

u_mu_prior = 0.0 # zero mean for standard normal prior if not CAP prior
if bivae.cap_priors.get("user", False):
u_batch_f = user_features[u_ids]
u_batch_f = torch.tensor(u_batch_f, dtype=dtype, device=device)
u_mu_prior = bivae.encode_user_prior(u_batch_f)
else:
u_mu_prior = bivae.mu_prior[0:len(u_batch)]

u_loss = bivae.loss(u_batch, u_batch_, u_mu, u_mu_prior, u_std, beta_kl)
u_optimizer.zero_grad()
Expand All @@ -243,28 +248,30 @@ def learn(
u_sum_loss += u_loss.data.item()
u_count += len(u_batch)

theta, _, u_mu, _ = bivae(u_batch, user=True, beta=bivae.beta_)
bivae.theta_.data[u_ids] = theta.data
theta, _, u_mu, _ = bivae(u_batch, user=True, beta=bivae.beta)
bivae.theta.data[u_ids] = theta.data
bivae.mu_theta.data[u_ids] = u_mu.data

progress_bar.set_postfix(loss_i=(i_sum_loss / i_count), loss_u=(u_sum_loss / (u_count)))
progress_bar.set_postfix(
loss_i=(i_sum_loss / i_count), loss_u=(u_sum_loss / (u_count))
)

# infer mu_beta
for batch_id, i_ids in enumerate(train_set.item_iter(batch_size, shuffle=False)):
for i_ids in train_set.item_iter(batch_size, shuffle=False):
i_batch = tx[i_ids, :]
i_batch = i_batch.A
i_batch = torch.tensor(i_batch, dtype=torch.float32, device=device)
i_batch = torch.tensor(i_batch, dtype=dtype, device=device)

beta, _, i_mu, _ = bivae(i_batch, user=False, theta=bivae.theta_)
beta, _, i_mu, _ = bivae(i_batch, user=False, theta=bivae.theta)
bivae.mu_beta.data[i_ids] = i_mu.data

# infer mu_theta
for batch_id, u_ids in enumerate(train_set.user_iter(batch_size, shuffle=False)):
for u_ids in train_set.user_iter(batch_size, shuffle=False):
u_batch = x[u_ids, :]
u_batch = u_batch.A
u_batch = torch.tensor(u_batch, dtype=torch.float32, device=device)
u_batch = torch.tensor(u_batch, dtype=dtype, device=device)

theta, _, u_mu, _ = bivae(u_batch, user=True, beta=bivae.beta_)
theta, _, u_mu, _ = bivae(u_batch, user=True, beta=bivae.beta)
bivae.mu_theta.data[u_ids] = u_mu.data

return bivae
Loading