diff --git a/README.md b/README.md index d92552f2..0f85f536 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,7 @@ The recommender models supported by Cornac are listed below. Why don't you join | Year | Model and paper | Additional dependencies | Examples | | :---: | --- | :---: | :---: | +| 2021 | [Bilateral Variational Autoencoder for Collaborative Filtering (BiVAECF)](cornac/models/bivaecf), [paper](#) | [requirements.txt](cornac/models/bivaecf/requirements.txt) | | 2018 | [Collaborative Context Poisson Factorization (C2PF)](cornac/models/c2pf), [paper](https://www.ijcai.org/proceedings/2018/0370.pdf) | N/A | [c2pf_exp.py](examples/c2pf_example.py) | | [Multi-Task Explainable Recommendation (MTER)](cornac/models/mter), [paper](https://arxiv.org/pdf/1806.03568.pdf) | N/A | [mter_exp.py](examples/mter_example.py) | | [Neural Attention Rating Regression with Review-level Explanations (NARRE)](cornac/models/narre), [paper](http://www.thuir.cn/group/~YQLiu/publications/WWW2018_CC.pdf) | [requirements.txt](cornac/models/narre/requirements.txt) | [narre_example.py](examples/narre_example.py) diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py index 4b000770..f3c61d66 100644 --- a/cornac/models/__init__.py +++ b/cornac/models/__init__.py @@ -16,6 +16,7 @@ from .recommender import Recommender from .baseline_only import BaselineOnly +from .bivaecf import BiVAECF from .bpr import BPR from .bpr import WBPR from .c2pf import C2PF diff --git a/cornac/models/bivaecf/__init__.py b/cornac/models/bivaecf/__init__.py new file mode 100644 index 00000000..b048076c --- /dev/null +++ b/cornac/models/bivaecf/__init__.py @@ -0,0 +1 @@ +from .recom_bivaecf import BiVAECF \ No newline at end of file diff --git a/cornac/models/bivaecf/bivae.py b/cornac/models/bivaecf/bivae.py new file mode 100644 index 00000000..1d745206 --- /dev/null +++ b/cornac/models/bivaecf/bivae.py @@ -0,0 +1,270 @@ +# Copyright 2018 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import torch +import torch.nn as nn +from tqdm.auto import trange + +torch.set_default_dtype(torch.float32) + +EPS = 1e-10 + +ACT = { + "sigmoid": nn.Sigmoid(), + "tanh": nn.Tanh(), + "elu": nn.ELU(), + "relu": nn.ReLU(), + "relu6": nn.ReLU6(), +} + + +class BiVAE(nn.Module): + def __init__(self, k, user_e_structure, item_e_structure, act_fn, likelihood, cap_priors, feature_dim, batch_size): + super(BiVAE, self).__init__() + + self.mu_theta = torch.zeros((item_e_structure[0], k)) # n_users*k + self.mu_beta = torch.zeros((user_e_structure[0], k)) # n_items*k + + # zero mean for standard normal prior + self.mu_prior = torch.zeros((batch_size, k), requires_grad=False) + + self.theta_ = torch.randn(item_e_structure[0], k) * 0.01 + self.beta_ = torch.randn(user_e_structure[0], k) * 0.01 + torch.nn.init.kaiming_uniform_(self.theta_, a=np.sqrt(5)) + + self.likelihood = likelihood + self.act_fn = ACT.get(act_fn, None) + if self.act_fn is None: + raise ValueError("Supported act_fn: {}".format(ACT.keys())) + + self.cap_priors = cap_priors + if self.cap_priors.get("user", False): + self.user_prior_encoder = nn.Linear(feature_dim.get("user"), k) + if self.cap_priors.get("item", False): + self.item_prior_encoder = nn.Linear(feature_dim.get("item"), k) + + # User Encoder + self.user_encoder = nn.Sequential() + for i in range(len(user_e_structure) - 1): + self.user_encoder.add_module( + "fc{}".format(i), nn.Linear(user_e_structure[i], user_e_structure[i + 1]) + ) + self.user_encoder.add_module("act{}".format(i), self.act_fn) + self.user_mu = nn.Linear(user_e_structure[-1], k) # mu + self.user_std = nn.Linear(user_e_structure[-1], k) + + # Item Encoder + self.item_encoder = nn.Sequential() + for i in range(len(item_e_structure) - 1): + self.item_encoder.add_module( + "fc{}".format(i), nn.Linear(item_e_structure[i], item_e_structure[i + 1]) + ) + self.item_encoder.add_module("act{}".format(i), self.act_fn) + self.item_mu = nn.Linear(item_e_structure[-1], k) # mu + self.item_std = nn.Linear(item_e_structure[-1], k) + + def encode_user_prior(self, x): + h = self.user_prior_encoder(x) + return h + + def encode_item_prior(self, x): + h = self.item_prior_encoder(x) + return h + + def encode_user(self, x): + h = self.user_encoder(x) + return self.user_mu(h), torch.sigmoid(self.user_std(h)) + + def encode_item(self, x): + h = self.item_encoder(x) + return self.item_mu(h), torch.sigmoid(self.item_std(h)) + + def decode_user(self, theta, beta): + h = theta.mm(beta.t()) + return torch.sigmoid(h) + + def decode_item(self, theta, beta): + h = beta.mm(theta.t()) + return torch.sigmoid(h) + + def reparameterize(self, mu, std): + eps = torch.randn_like(mu) + return mu + eps * std + + def forward(self, x, user=True, beta=None, theta=None): + + if user: + mu, std = self.encode_user(x) + theta = self.reparameterize(mu, std) + return theta, self.decode_user(theta, beta), mu, std + else: + mu, std = self.encode_item(x) + beta = self.reparameterize(mu, std) + return beta, self.decode_item(theta, beta), mu, std + + def loss(self, x, x_, mu, mu_prior, std, kl_beta): + # Likelihood + ll_choices = { + "bern": x * torch.log(x_ + EPS) + (1 - x) * torch.log(1 - x_ + EPS), + "gaus": -(x - x_) ** 2, + "pois": x * torch.log(x_ + EPS) - x_, + } + + ll = ll_choices.get(self.likelihood, None) + if ll is None: + raise ValueError("Supported likelihoods: {}".format(ll_choices.keys())) + + ll = torch.sum(ll, dim=1) + + # KL term + kld = -0.5 * (1 + 2.0 * torch.log(std) - (mu - mu_prior).pow(2) - std.pow(2)) + kld = torch.sum(kld, dim=1) + + return torch.mean(kl_beta * kld - ll) + + +def learn( + bivae, + train_set, + n_epochs, + batch_size, + learn_rate, + beta_kl, + verbose, + device=torch.device("cpu"), +): + user_params = [{'params': bivae.user_encoder.parameters()}, \ + {'params': bivae.user_mu.parameters()}, \ + {'params': bivae.user_std.parameters()}, \ + ] + + item_params = [{'params': bivae.item_encoder.parameters()}, \ + {'params': bivae.item_mu.parameters()}, \ + {'params': bivae.item_std.parameters()}, \ + ] + + if bivae.cap_priors.get("user", False): + user_params.append({'params': bivae.user_prior_encoder.parameters()}) + user_features = train_set.user_feature.features[: train_set.num_users] + user_features = torch.from_numpy(user_features).float().to(device) + + if bivae.cap_priors.get("item", False): + item_params.append({'params': bivae.item_prior_encoder.parameters()}) + item_features = train_set.item_feature.features[: train_set.num_items] + item_features = torch.from_numpy(item_features).float().to(device) + + u_optimizer = torch.optim.Adam(params=user_params, lr=learn_rate) + i_optimizer = torch.optim.Adam(params=item_params, lr=learn_rate) + + progress_bar = trange(1, n_epochs + 1, disable=not verbose) + bivae.beta_ = bivae.beta_.to(device=device) + bivae.theta_ = bivae.theta_.to(device=device) + + bivae.mu_beta = bivae.mu_beta.to(device=device) + bivae.mu_theta = bivae.mu_theta.to(device=device) + + bivae.mu_prior = bivae.mu_prior.to(device=device) + + x = train_set.matrix.copy() + x.data = np.ones(len(x.data)) # Binarize data + tx = x.transpose() + + for _ in progress_bar: + + # item side + i_sum_loss = 0.0 + i_count = 0 + for batch_id, i_ids in enumerate( + train_set.item_iter(batch_size, shuffle=False) + ): + i_batch = tx[i_ids, :] + i_batch = i_batch.A + i_batch = torch.tensor(i_batch, dtype=torch.float32, device=device) + + # Reconstructed batch + beta, i_batch_, i_mu, i_std = bivae(i_batch, user=False, theta=bivae.theta_) + + if bivae.cap_priors.get("item", False): + i_batch_f = item_features[i_ids] + i_mu_prior = bivae.encode_item_prior(i_batch_f) + else: + i_mu_prior = bivae.mu_prior[0:len(i_batch)] + + i_loss = bivae.loss(i_batch, i_batch_, i_mu, i_mu_prior, i_std, beta_kl) + i_optimizer.zero_grad() + i_loss.backward() + i_optimizer.step() + + i_sum_loss += i_loss.data.item() + i_count += len(i_batch) + + beta, _, i_mu, _ = bivae(i_batch, user=False, theta=bivae.theta_) + + bivae.beta_.data[i_ids] = beta.data + bivae.mu_beta.data[i_ids] = i_mu.data + + # user side + u_sum_loss = 0.0 + u_count = 0 + for batch_id, u_ids in enumerate( + train_set.user_iter(batch_size, shuffle=False) + ): + u_batch = x[u_ids, :] + u_batch = u_batch.A + u_batch = torch.tensor(u_batch, dtype=torch.float32, device=device) + + # Reconstructed batch + theta, u_batch_, u_mu, u_std = bivae(u_batch, user=True, beta=bivae.beta_) + + if bivae.cap_priors.get("user", False): + u_batch_f = user_features[u_ids] + u_mu_prior = bivae.encode_user_prior(u_batch_f) + else: + u_mu_prior = bivae.mu_prior[0:len(u_batch)] + + u_loss = bivae.loss(u_batch, u_batch_, u_mu, u_mu_prior, u_std, beta_kl) + u_optimizer.zero_grad() + u_loss.backward() + u_optimizer.step() + + u_sum_loss += u_loss.data.item() + u_count += len(u_batch) + + theta, _, u_mu, _ = bivae(u_batch, user=True, beta=bivae.beta_) + bivae.theta_.data[u_ids] = theta.data + bivae.mu_theta.data[u_ids] = u_mu.data + + progress_bar.set_postfix(loss_i=(i_sum_loss / i_count), loss_u=(u_sum_loss / (u_count))) + + # infer mu_beta + for batch_id, i_ids in enumerate(train_set.item_iter(batch_size, shuffle=False)): + i_batch = tx[i_ids, :] + i_batch = i_batch.A + i_batch = torch.tensor(i_batch, dtype=torch.float32, device=device) + + beta, _, i_mu, _ = bivae(i_batch, user=False, theta=bivae.theta_) + bivae.mu_beta.data[i_ids] = i_mu.data + + # infer mu_theta + for batch_id, u_ids in enumerate(train_set.user_iter(batch_size, shuffle=False)): + u_batch = x[u_ids, :] + u_batch = u_batch.A + u_batch = torch.tensor(u_batch, dtype=torch.float32, device=device) + + theta, _, u_mu, _ = bivae(u_batch, user=True, beta=bivae.beta_) + bivae.mu_theta.data[u_ids] = u_mu.data + + return bivae diff --git a/cornac/models/bivaecf/recom_bivaecf.py b/cornac/models/bivaecf/recom_bivaecf.py new file mode 100644 index 00000000..41080d5c --- /dev/null +++ b/cornac/models/bivaecf/recom_bivaecf.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- + +""" +@author: Aghiles Salah +""" + +from ..recommender import Recommender +from ...utils.common import scale +from ...exception import ScoreException + + +class BiVAECF(Recommender): + """Bilateral Variational AutoEncoder for Collaborative Filtering. + + Parameters + ---------- + k: int, optional, default: 10 + The dimension of the stochastic user ``theta'' and item ``beta'' factors. + + encoder_structure: list, default: [20] + The number of neurons per layer of the user and item encoders for BiVAE. + For example, encoder_structure = [20], the user (item) encoder structure will be [num_items, 20, k] ([num_users, 20, k]). + + act_fn: str, default: 'tanh' + Name of the activation function used between hidden layers of the auto-encoder. + Supported functions: ['sigmoid', 'tanh', 'elu', 'relu', 'relu6'] + + likelihood: str, default: 'pois' + The likelihood function used for modeling the observations. + Supported choices: + + bern: Bernoulli likelihood + gaus: Gaussian likelihood + pois: Poisson likelihood + + n_epochs: int, optional, default: 100 + The number of epochs for SGD. + + batch_size: int, optional, default: 100 + The batch size. + + learning_rate: float, optional, default: 0.001 + The learning rate for Adam. + + beta_kl: float, optional, default: 1.0 + The weight of the KL terms as in beta-VAE. + + cap_priors: dict, optional, default: {"user":False, "item":False} + When {"user":True, "item":True}, CAP priors are used (see BiVAE paper for details),\ + otherwise the standard Normal is used as a Prior over the user and item latent variables. + + name: string, optional, default: 'BiVAECF' + The name of the recommender model. + + trainable: boolean, optional, default: True + When False, the model is not trained and Cornac assumes that the model is already \ + pre-trained. + + verbose: boolean, optional, default: False + When True, some running logs are displayed. + + seed: int, optional, default: None + Random seed for parameters initialization. + + use_gpu: boolean, optional, default: True + If True and your system supports CUDA then training is performed on GPUs. + + References + ---------- + * Quoc-Tuan Truong, Aghiles Salah, Hady W. Lauw. " Bilateral Variational Autoencoder for Collaborative Filtering." + ACM International Conference on Web Search and Data Mining (WSDM). 2021. + """ + + def __init__( + self, + name="BiVAECF", + k=10, + encoder_structure=[20], + act_fn="tanh", + likelihood="pois", + n_epochs=100, + batch_size=100, + learning_rate=0.001, + beta_kl=1.0, + cap_priors={"user": False, "item": False}, + trainable=True, + verbose=False, + seed=None, + use_gpu=True, + ): + Recommender.__init__(self, name=name, trainable=trainable, verbose=verbose) + self.k = k + self.encoder_structure = encoder_structure + self.act_fn = act_fn + self.likelihood = likelihood + self.batch_size = batch_size + self.n_epochs = n_epochs + self.learning_rate = learning_rate + self.beta_kl = beta_kl + self.cap_priors = cap_priors + self.seed = seed + self.use_gpu = use_gpu + + def fit(self, train_set, val_set=None): + """Fit the model to observations. + + Parameters + ---------- + train_set: :obj:`cornac.data.Dataset`, required + User-Item preference data as well as additional modalities. + + val_set: :obj:`cornac.data.Dataset`, optional, default: None + User-Item preference data for model selection purposes (e.g., early stopping). + + Returns + ------- + self : object + """ + Recommender.fit(self, train_set, val_set) + + import torch + from .bivae import BiVAE, learn + + self.device = ( + torch.device("cuda:0") + if (self.use_gpu and torch.cuda.is_available()) + else torch.device("cpu") + ) + + if self.trainable: + feature_dim = {"user":None, "item":None} + if self.cap_priors.get("user", False): + if train_set.user_feature is None: + raise ValueError("CAP priors for users is set to True but no user features are provided") + else: + feature_dim["user"] = train_set.user_feature.feature_dim + + if self.cap_priors.get("item", False): + if train_set.item_feature is None: + raise ValueError("CAP priors for items is set to True but no item features are provided") + else: + feature_dim["item"] = train_set.item_feature.feature_dim + + if self.seed is not None: + torch.manual_seed(self.seed) + torch.cuda.manual_seed(self.seed) + + if not hasattr(self, "bivaecf"): + num_items = train_set.matrix.shape[1] + num_users = train_set.matrix.shape[0] + self.bivae = BiVAE( + self.k, + [num_items] + self.encoder_structure, + [num_users] + self.encoder_structure, + self.act_fn, + self.likelihood, + self.cap_priors, + feature_dim, + self.batch_size, + ).to(self.device) + + learn( + self.bivae, + self.train_set, + n_epochs=self.n_epochs, + batch_size=self.batch_size, + learn_rate=self.learning_rate, + beta_kl=self.beta_kl, + verbose=self.verbose, + device=self.device, + ) + + elif self.verbose: + print("%s is trained already (trainable = False)" % (self.name)) + + return self + + def score(self, user_idx, item_idx=None): + """Predict the scores/ratings of a user for an item. + + Parameters + ---------- + user_idx: int, required + The index of the user for whom to perform score prediction. + + item_idx: int, optional, default: None + The index of the item for which to perform score prediction. + If None, scores for all known items will be returned. + + Returns + ------- + res : A scalar or a Numpy array + Relative scores that the user gives to the item or to all known items + + """ + + if item_idx is None: + if self.train_set.is_unk_user(user_idx): + raise ScoreException( + "Can't make score prediction for (user_id=%d)" % user_idx + ) + + theta_u = self.bivae.mu_theta[user_idx].view(1, -1) + beta = self.bivae.mu_beta + known_item_scores = self.bivae.decode_user(theta_u, beta).data.cpu().numpy().flatten() + # known_item_scores = np.dot(self.bivaecf.mu_beta.dot(self.s),self.bivaecf.mu_theta[user_idx]) + return known_item_scores + else: + if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( + item_idx + ): + raise ScoreException( + "Can't make score prediction for (user_id=%d, item_id=%d)" + % (user_idx, item_idx) + ) + + theta_u = self.bivae.mu_theta[user_idx].view(1, -1) + beta_i = self.bivae.mu_beta[item_idx].view(1, -1) + pred = self.bivae.decode_user(theta_u, beta_i).data.cpu().numpy().flatten() + + pred = scale(pred, self.train_set.min_rating, self.train_set.max_rating, 0., 1.) + + return pred diff --git a/cornac/models/bivaecf/requirements.txt b/cornac/models/bivaecf/requirements.txt new file mode 100644 index 00000000..b03b3da9 --- /dev/null +++ b/cornac/models/bivaecf/requirements.txt @@ -0,0 +1 @@ +torch>=0.4.1 \ No newline at end of file diff --git a/docs/source/models.rst b/docs/source/models.rst index 207febcd..46b7a5ea 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -10,6 +10,11 @@ Recommender (Generic Class) .. automodule:: cornac.models.recommender :members: +Bilateral VAE for Collaborative Filtering (BiVAECF) +---------------------------------------------------- +.. automodule:: cornac.models.bivaecf.recom_bivaecf + :members: + Collaborative Context Poisson Factorization (C2PF) ---------------------------------------------------- .. automodule:: cornac.models.c2pf.recom_c2pf