In-batch Negative Sampling

  • code:
 1import torch
 2import torch.nn as nn
 3import torch.nn.functional as F
 4
 5class RecommenderModel(nn.Module):
 6    def __init__(self, user_size, item_size, embedding_dim):
 7        super(RecommenderModel, self).__init__()
 8        self.user_embedding = nn.Embedding(user_size, embedding_dim)
 9        self.item_embedding = nn.Embedding(item_size, embedding_dim)
10
11        def forward(self, user_ids, item_ids):
12        user_embeds = self.user_embedding(user_ids)
13        item_embeds = self.item_embedding(item_ids)
14        return user_embeds, item_embeds
15
16    def in_batch_negative_sampling_loss(user_embeds, item_embeds):
17        batch_size = user_embeds.size(0)
18        
19        # 正样本得分 (batch_size,) 
20        positive_scores = torch.sum(user_embeds * item_embeds, dim=-1) 
21        
22        # 负样本得分 (batch_size, batch_size)
23        negative_scores = torch.matmul(user_embeds, item_embeds.t())  
24        
25        # 创建标签  (batch_size, batch_size)
26        labels = torch.eye(batch_size).to(user_embeds.device) 
27        
28        # 计算损失
29        loss = F.cross_entropy(negative_scores, labels.argmax(dim=-1))
30        
31        return loss
32
33# 示例数据
34batch_size = 4
35embedding_dim = 8
36user_size = 100
37item_size = 1000
38
39user_ids = torch.randint(0, user_size, (batch_size,))
40item_ids = torch.randint(0, item_size, (batch_size,))
41
42model = RecommenderModel(user_size, item_size, embedding_dim)
43user_embeds, item_embeds = model(user_ids, item_ids)
44
45loss = in_batch_negative_sampling_loss(user_embeds, item_embeds)
46print(f'Loss: {loss.item()}')

优点

  • 效性:批量内负采样能够充分利用每个训练批次中的样本,提高训练效率,避免显式生成大量负样本的开销。
  • 适用性:这种方法特别适用于深度学习的推荐系统,在大规模数据训练时效果显著。
  • 实现:通过在每个批次中将其他正样本视为负样本,并使用合适的损失函数(如交叉熵损失),可以有效地优化模型。

缺点

  • 对热门商品的打压过于严重

batch内的item对于当前user都是正样本,这种样本天然要比随机采样的item的热度要高。也就是说,我们采样了一批热门商品当作负样本(hard negtive-sample)。这样难免对于热门商品的打压太过了,可以在计算<user, item>得分时,减去商品的热度,来补偿。

sampled_softmax loss

$$\mathcal{L}=-log{\left[\frac{exp(s_{i,i}-log(p_j))}{\sum_{k\neq i}exp(s_{i,k}-log(p_k))+exp(s_{i,i} - log(p_i))}\right]}$$

对每个item j,假设被采样的概率为$p_j$,那么$log Q$矫正就是在本来的内积上加上 $-log{p_j}$ $$ s^c(x_i, y_j) = s(x_i, y_j) – \log {p_j} $$

$p_j$的概率通过距离上一次看到y的间隔来估计,item越热门,训练过程中就越经常看到item,那么距离上次看到y的间隔B(y)跟概率p成反比。于是有如下算法

在实践当中,对每个y都可以用PS的1维向量来存储对应的step,那么下一次再看到y时就可以计算出对应的间隔和概率了。

1step = get_global_step()  # 获取当前的batch计数tensor
2item_step = get_item_step_vector()   # 获取用于存储item上次见过的step的向量
3item_step.set_gradient(grad=step - item_step, lr=1.0)  
4delta = tf.clip_by_value(step - item_step, 1, 1000)
5logq = tf.stop_gradient(tf.log(delta))
6batch_logits += logq  # batch_logits 是前面计算的logits
7...

$$ \begin{aligned} \frac{\partial\mathcal{L}}{\partial {s_{i,j}}} \\ &=\frac{\exp(s_{i,j}-log(p_j))}{\sum_{j \neq i}\exp(s_{i,j}-log(p_j))+\exp(s_{i,i}-log(p_i))} \\ &=\frac1{p_j}P_{i,j} \end{aligned} $$

可以看到,越热门的负样本,${1} / {p_j}$越小,gradient越小,可以起到一定的补偿机制。

 1def sampled_softmax_loss(weights,
 2                         biases,
 3                         labels,
 4                         inputs,
 5                         num_sampled,
 6                         num_classes,
 7                         num_true=1):
 8    """
 9    weights: 待优化的矩阵,形状[num_classes, dim]。可以理解为所有item embedding矩阵,那时 num_classes = 所有item的个数
10    biases: 待优化变量,[num_classes]。每个item还有自己的bias,与user无关,代表自己本身的受欢迎程度。
11    labels: 正例的item ids,形状是[batch_size,num_true]的正数矩阵。每个元素代表一个用户点击过的一个item id,允许一个用户可以点击过至多num_true个item。
12    inputs: 输入的[batch_size, dim]矩阵,可以认为是user embedding
13    num_sampled:整个batch要采集多少负样本
14    num_classes: 在u2i中,可以理解成所有item的个数
15    num_true: 一条样本中有几个正例,一般就是1
16    """
17     # logits: [batch_size, num_true + num_sampled]的float矩阵
18     # labels: 与logits相同形状,如果num_true=1的话,每行就是[1,0,0,...,0]的形式
19    logits, labels = _compute_sampled_logits(
20              weights=weights,
21              biases=biases,
22              labels=labels,
23              inputs=inputs,
24              num_sampled=num_sampled,
25              num_classes=num_classes,
26              num_true=num_true,
27              sampled_values=sampled_values,
28              subtract_log_q=True,
29              remove_accidental_hits=remove_accidental_hits,
30              partition_strategy=partition_strategy,
31              name=name,
32              seed=seed)
33    labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
34    
35    # sampled_losses:形状与logits相同,也是[batch_size, num_true + num_sampled]
36		# 一行样本包含num_true个正例和num_sampled个负例
37		# 所以一行样本也有num_true + num_sampled个sigmoid loss
38
39		sampled_losses = sigmoid_cross_entropy_with_logits(
40		      labels=labels,
41		      logits=logits,
42		      name="sampled_losses")
43		      
44		# We sum out true and sampled losses.
45		return _sum_rows(sampled_losses)
46
47def _compute_sampled_logits(weights,
48       biases,
49       labels,
50       inputs,
51       num_sampled,
52       num_classes,
53       num_true=1,
54       ......
55       subtract_log_q=True,
56       remove_accidental_hits=False,......):
57    """
58    输入:
59        weights: 待优化的矩阵,形状[num_classes, dim]。可以理解为所有item embedding矩阵,那时num_classes=所有item的个数
60        biases: 待优化变量,[num_classes]。每个item还有自己的bias,与user无关,代表自己的受欢迎程度。
61        labels: 正例的item ids,形状是[batch_size,num_true]的正数矩阵。每个元素代表一个用户点击过的一个item id。允许一个用户可以点击过多个item。
62        inputs: 输入的[batch_size, dim]矩阵,可以认为是user embedding
63        num_sampled:整个batch要采集多少负样本
64        num_classes: 在u2i中,可以理解成所有item的个数
65        num_true: 一条样本中有几个正例,一般就是1
66        subtract_log_q:是否要对匹配度,进行修正
67        remove_accidental_hits:如果采样到的某个负例,恰好等于正例,是否要补救
68    Output:
69        out_logits: [batch_size, num_true + num_sampled]
70        out_labels: 与`out_logits`同形状
71    """

reference:

  1. 双塔召回模型中的logQ矫正