DGCF icon indicating copy to clipboard operation
DGCF copied to clipboard

有关update的forward传参

Open imgkkk574 opened this issue 9 months ago • 1 comments

根据论文所提出的3阶更新机制,对某节点进行update时,应包含该节点上一时刻的embedding、新交互节点的embedding以及新交互节点的1-hop邻居的聚合embedding。 即在以下代码中, user_embedding_output = model.forward(user_embedding_input, item_embedding_input, timediffs=user_timediffs_tensor, features=feature_tensor, adj_embeddings=user_adj_embedding, select='user_update') 传入的adj_embedding应为item_adj_embeddingitem_embedding_output = model.forward(user_embedding_input, item_embedding_input, timediffs=item_timediffs_tensor, features=feature_tensor, adj_embeddings=item_adj_embedding, select='item_update')传入的adj_embedding应为user_adj_embedding。 目前仓库中的代码实现似乎有误。

此外,若直接进行以上修改,或许会带来数据泄露的问题:模型依靠t-batch算法进行并行计算,目前的代码实现中, lib.current_tbatches_user_adj[tbatch_to_insert].append(user_adj[userid]) # item``lib.current_tbatches_item_adj[tbatch_to_insert].append(item_adj[itemid]) # user存入的是整个数据的user-item邻接关系,前序t-batch数据可以使用到后续t-batch的邻接关系,造成数据泄露。

另,evaluate_interaction_prediction.py中的set_embeddings_training_end应该注释掉,因为目前DGCF.py的实现中并未存入有意义的_embeddings_time_series

imgkkk574 avatar Jul 05 '25 18:07 imgkkk574

lib.current_tbatches_user_adj[tbatch_to_insert].append(user_adj[userid])lib.current_tbatches_item_adj[tbatch_to_insert].append(item_adj[itemid])传入的adj改为forzenset(_adj[_id])应该能避免数据泄露的问题。原因在于user_id[]item_id[]的数据类型是set,为可变数据类型,后续t-batch的邻接关系会影响前序t-batch已保存的邻接关系。

imgkkk574 avatar Jul 06 '25 06:07 imgkkk574