Skip to content

Commit 53dbac8

Browse files
committed
refactor code
1 parent 3d4ac7d commit 53dbac8

File tree

2 files changed

+15
-35
lines changed

2 files changed

+15
-35
lines changed

cornac/models/dnntsp/dnntsp.py

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -328,10 +328,9 @@ def forward(
328328
def get_edges_weight(history_baskets):
329329
edges_weight_dict = defaultdict(float)
330330
for basket_items in history_baskets:
331-
for i in range(len(basket_items)):
332-
for j in range(i + 1, len(basket_items)):
333-
edges_weight_dict[(basket_items[i], basket_items[j])] += 1
334-
edges_weight_dict[(basket_items[j], basket_items[i])] += 1
331+
for (item_i,item_j) in itertools.combinations(basket_items, 2):
332+
edges_weight_dict[(item_i, item_j)] += 1
333+
edges_weight_dict[(item_j, item_i)] += 1
335334
return edges_weight_dict
336335

337336

@@ -373,19 +372,11 @@ def transform_data(
373372
torch.tensor(list(range(nodes.shape[0]))) for nodes in batch_nodes
374373
]
375374
batch_src = [
376-
(
377-
torch.stack([project_nodes for _ in range(project_nodes.shape[0])], dim=1)
378-
.flatten()
379-
.tolist()
380-
)
375+
project_nodes.repeat((project_nodes.shape[0], 1)).T.flatten().tolist()
381376
for project_nodes in batch_project_nodes
382377
]
383378
batch_dst = [
384-
(
385-
torch.stack([project_nodes for _ in range(project_nodes.shape[0])], dim=0)
386-
.flatten()
387-
.tolist()
388-
)
379+
project_nodes.repeat((project_nodes.shape[0],)).flatten().tolist()
389380
for project_nodes in batch_project_nodes
390381
]
391382
batch_g = [
@@ -487,11 +478,8 @@ def forward(self, predict, truth):
487478
Returns:
488479
output: tensor
489480
"""
490-
# predict = torch.softmax(predict, dim=-1)
491481
predict = torch.sigmoid(predict)
492482
truth = truth.float()
493-
# print(predict.device)
494-
# print(truth.device)
495483
if self.weights is not None:
496484
self.weights = self.weights.to(truth.device)
497485
predict = predict * self.weights
@@ -623,17 +611,3 @@ def learn(
623611

624612
# Note that step should be called after validate
625613
scheduler.step(total_val_loss)
626-
627-
628-
def score(model: TemporalSetPrediction, history_baskets, total_items, device="cpu"):
629-
model = model.to(device)
630-
model.eval()
631-
(g, nodes_feature, edges_weight, lengths, nodes, _) = transform_data(
632-
[history_baskets],
633-
item_embedding=model.embedding_matrix,
634-
total_items=total_items,
635-
device=device,
636-
is_test=True,
637-
)
638-
preds = model(g, nodes_feature, edges_weight, lengths, nodes)
639-
return preds.cpu().detach().numpy()

cornac/models/dnntsp/recom_dnntsp.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,15 @@ def fit(self, train_set, val_set=None):
122122
return self
123123

124124
def score(self, user_idx, history_baskets, **kwargs):
125-
from .dnntsp import score
125+
from .dnntsp import transform_data
126126

127-
item_scores = score(
128-
self.model, history_baskets, self.total_items, device=self.device
127+
self.model.eval()
128+
(g, nodes_feature, edges_weight, lengths, nodes, _) = transform_data(
129+
[history_baskets],
130+
item_embedding=self.model.embedding_matrix,
131+
total_items=self.total_items,
132+
device=self.device,
133+
is_test=True,
129134
)
130-
return item_scores
135+
preds = self.model(g, nodes_feature, edges_weight, lengths, nodes)
136+
return preds.cpu().detach().numpy()

0 commit comments

Comments
 (0)