@@ -328,10 +328,9 @@ def forward(
328328def get_edges_weight (history_baskets ):
329329 edges_weight_dict = defaultdict (float )
330330 for basket_items in history_baskets :
331- for i in range (len (basket_items )):
332- for j in range (i + 1 , len (basket_items )):
333- edges_weight_dict [(basket_items [i ], basket_items [j ])] += 1
334- edges_weight_dict [(basket_items [j ], basket_items [i ])] += 1
331+ for (item_i ,item_j ) in itertools .combinations (basket_items , 2 ):
332+ edges_weight_dict [(item_i , item_j )] += 1
333+ edges_weight_dict [(item_j , item_i )] += 1
335334 return edges_weight_dict
336335
337336
@@ -373,19 +372,11 @@ def transform_data(
373372 torch .tensor (list (range (nodes .shape [0 ]))) for nodes in batch_nodes
374373 ]
375374 batch_src = [
376- (
377- torch .stack ([project_nodes for _ in range (project_nodes .shape [0 ])], dim = 1 )
378- .flatten ()
379- .tolist ()
380- )
375+ project_nodes .repeat ((project_nodes .shape [0 ], 1 )).T .flatten ().tolist ()
381376 for project_nodes in batch_project_nodes
382377 ]
383378 batch_dst = [
384- (
385- torch .stack ([project_nodes for _ in range (project_nodes .shape [0 ])], dim = 0 )
386- .flatten ()
387- .tolist ()
388- )
379+ project_nodes .repeat ((project_nodes .shape [0 ],)).flatten ().tolist ()
389380 for project_nodes in batch_project_nodes
390381 ]
391382 batch_g = [
@@ -487,11 +478,8 @@ def forward(self, predict, truth):
487478 Returns:
488479 output: tensor
489480 """
490- # predict = torch.softmax(predict, dim=-1)
491481 predict = torch .sigmoid (predict )
492482 truth = truth .float ()
493- # print(predict.device)
494- # print(truth.device)
495483 if self .weights is not None :
496484 self .weights = self .weights .to (truth .device )
497485 predict = predict * self .weights
@@ -623,17 +611,3 @@ def learn(
623611
624612 # Note that step should be called after validate
625613 scheduler .step (total_val_loss )
626-
627-
628- def score (model : TemporalSetPrediction , history_baskets , total_items , device = "cpu" ):
629- model = model .to (device )
630- model .eval ()
631- (g , nodes_feature , edges_weight , lengths , nodes , _ ) = transform_data (
632- [history_baskets ],
633- item_embedding = model .embedding_matrix ,
634- total_items = total_items ,
635- device = device ,
636- is_test = True ,
637- )
638- preds = model (g , nodes_feature , edges_weight , lengths , nodes )
639- return preds .cpu ().detach ().numpy ()
0 commit comments