Batch内负采样
# In-batch Negative Sampling code:
import torch
import torch.nn as nn
import torch.nn.functional as F


class RecommenderModel(nn.Module):
    """Embedding model with one lookup table for users and one for items."""

    def __init__(self, user_size: int, item_size: int, embedding_dim: int) -> None:
        """Create the two embedding tables.

        Args:
            user_size: Number of rows in the user embedding table.
            item_size: Number of rows in the item embedding table.
            embedding_dim: Width of every user and item embedding vector.
        """
        super().__init__()
        self.user_embedding = nn.Embedding(user_size, embedding_dim)
        self.item_embedding = nn.Embedding(item_size, embedding_dim)

    def forward(
        self, user_ids: torch.Tensor, item_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Look up the embeddings for the given user and item ids.

        Returns:
            A ``(user_embeds, item_embeds)`` pair; each tensor has the shape
            of its id tensor with a trailing ``embedding_dim`` axis appended.
        """
        user_embeds = self.user_embedding(user_ids)
        item_embeds = self.item_embedding(item_ids)
        return user_embeds, item_embeds


def in_batch_negative_sampling_loss(
    user_embeds: torch.Tensor, item_embeds: torch.Tensor
) -> torch.Tensor:
    """Compute the softmax cross-entropy loss with in-batch negatives.

    Row ``i`` of ``user_embeds`` and row ``i`` of ``item_embeds`` form a
    positive pair; every other item row in the batch is used as a negative
    for user ``i``. The function never sees the ids, so an item that occurs
    twice in the batch is still counted as a negative for the other row.

    Args:
        user_embeds: User embeddings, shape ``(batch_size, embedding_dim)``.
        item_embeds: Item embeddings, shape ``(batch_size, embedding_dim)``.

    Returns:
        A scalar tensor: the cross-entropy averaged over the batch.
    """
    batch_size = user_embeds.size(0)

    # Score matrix (batch_size, batch_size): entry [i, j] is the dot product
    # of user i with item j. The diagonal holds the positive-pair scores and
    # the off-diagonal entries hold the in-batch negative scores, so no
    # separate positive-score tensor is needed.
    scores = torch.matmul(user_embeds, item_embeds.t())

    # The correct "class" for row i is column i. Building the index vector
    # directly avoids allocating an identity matrix just to argmax it.
    targets = torch.arange(batch_size, device=user_embeds.device)

    return F.cross_entropy(scores, targets)


# Example data
batch_size = 4
embedding_dim = 8
user_size = 100
item_size = 1000

user_ids = torch.randint(0, user_size, (batch_size,))
item_ids = torch.randint(0, item_size, (batch_size,))

model = RecommenderModel(user_size, item_size, embedding_dim)
user_embeds, item_embeds = model(user_ids, item_ids)

loss = in_batch_negative_sampling_loss(user_embeds, item_embeds)
print(f'Loss: {loss.item()}')

# Article text that followed the listing (kept verbatim; truncated in the source):
# 优点 效性:批量内负采样能够充分利用每个训练批次中的样本,提高……