多任务loss优化
多任务学习是推荐系统中常见的技术实现. 在很多推荐与排序场景中,业务目标通常有多个,找到一种综合排序方法使得多个目标都达到整体最优,才能实现受益最大化。
多任务学习
多任务学习经常使用联合训练(Joint-train)的模式进行多任务优化学习,公式如下:
$$ L=\min_\theta \sum_{t=1}^T \alpha^t L^t\left(\theta^{s h}, \theta^t\right) $$
公式$\theta^{t}$是任务$t$公式的独享参数,$\theta^{sh}$是所有任务的共享参数, $\alpha^t$是任务$t$对应的权重。 总$Loss$公式是每个子任务对应$Loss$公式的加权求和。
为了能够更好地『共享参数』,让同个模型中多个任务和谐共存、相辅相成,业界有两大优化方向,分别是:
- 网络结构优化,设计更好的参数共享位置与方式
- 优化策略提升,设计更好的优化策略以提升优化$Loss$公式过程中的多任务平衡.
本文主要聚焦在讨论第二个方向。即在学习中如何平衡多任务中的多个$Loss$, 达到多个任务上的效果最优化。
多任务loss优化
多任务loss优化方法更多的考虑的是在已有结构下,更好地结合任务进行训练和参数优化,它从$Loss$与梯度的维度去思考不同任务之间的关系。在优化过程中缓解梯度冲突,参数撕扯,尽量达到多任务的平衡优化。
目前各式各样的多任务优化方法策略,主要集中在3个问题:
- $Magnitude$ (Loss量级)
Loss 值有大有小,取值大的 Loss 可能会主导。典型的例子是二分类任务 + 回归任务的多任务优化,L2 Loss 和交叉熵损失的 Loss 大小与梯度大小的量级和幅度可能差异很大,如果不处理会对优化造成很大干扰。
- $Velocity$ (Loss学习速度)
不同任务因为样本的稀疏性、学习的难度不一致,在训练和优化过程中,存在 Loss 学习速度不一致的情况。如果不加以调整,可能会出现某个任务接近收敛甚至过拟合的时候,其他任务还是欠拟合的状态。
- $Direction$ (Loss梯度冲突)
不同任务的 Loss 对共享参数进行更新,梯度存在不同的大小和方向,相同参数被多个梯度同时更新的时候,可能会出现冲突,导致相互消耗抵消,进而出现跷跷板、甚至负迁移现象。 这也是核心需要处理的问题。
Uncertainty Weight Loss
简单的多任务学习往往是把所有$Loss$进行联合优化,通常需要需要手动调节他们的$weight\ \alpha^t$。
导致模型堆最后的$weight$非常敏感, 同时手工调节这些$weight$也是非常费时费力的工作。 UWL提出直接建模单个任务中的uncertainty,然后通过uncertainty来指导权重的调节。损失函数如下:
$$ \mathcal{L}\left(W, \sigma_1, \sigma_2\right) \approx \frac{1}{2 \sigma_1^2} \mathcal{L}_1(W)+\frac{1}{2 \sigma_2^2} \mathcal{L}_2(W)+\log \sigma_1+\log \sigma_2 $$
背后的含义是: Loss 大的任务,包含的uncertainty也应该多($\sigma$越大),对应的权重就应该小一点。防止任务在不确定性比较大时,往错误的方向”大步迈“。 这样优化的结果就是往往 Loss 小(『相对简单』)的任务会有一个更大的权重。
1def uncertainty_weight_loss(loss_list):
2 total_loss = 0.0
3 for idx, loss in enumerate(loss_list):
4 sigma = tf.get_variable(shape=[], dtype=tf.float32, name='loss_sigma_{}'.format(idx))
5 factor = 1.0 / ( sigma * sigma )
6 loss = factor * loss + tf.log(sigma)
7 total_loss = total_loss + loss
8 return total_loss
GradNorm
Gradient normalization方法的主要思想是:
- 希望不同的任务的 Loss 量级是接近的
- 希望不同的任务以相似的速度学习
$GradNorm$使用了两种loss,一种是$label\ loss$,就是整个任务的loss; 第二种是$gradient\ loss$,是通过梯度下降来更新权重w,来实现平衡任务训练速度的目的。 值得注意的是,这两种loss独立优化不相加,文章的重点就是$gradient loss$的计算.
算法流程:
PCGrad
PCGrad是Google在NIPS 2020《Gradient surgery for multi-task learning》这篇paper里提出的方法,PCGrad指出MTL多目标优化存在3个问题:
- 方向不一致,导致撕扯,需要解决
- 量级不一致,导致大gradients主导,需要解决
- 大曲率,导致容易过拟合,需要解决
解决办法如下:
- 先检测不同任务的梯度是否冲突(梯度向量的余弦相似度是否小于0, negative similarity)
- 如果有冲突,就把冲突的分量 clip 掉(把其中一个任务的梯度投影到另一个任务梯度的正交方向上,只保留正交方向的梯度分量,丢弃有冲突的分量)
其实就是将$g_i$关于$g_j$的方向做一个正交分解,减去与$g_j$共线反向的一部分,剩下的就是在$g_j$的法向量上的投影了。
算法流程:
DWA
Dynamic Weight Averaging:任务的训练速度越快,权重越小。
$$ \lambda_k(t):=\frac{K \exp \left(w_k(t-1) / T\right)}{\sum_i \exp \left(w_i(t-1) / T\right)}, w_k(t-1)=\frac{\mathcal{L}_k(t-1)}{\mathcal{L}_k(t-2)} $$
- $\lambda_k(t)$: 任务$k$在第$t$步的loss权重
- $L_k(t-1)$, $L_k(t-2)$: 任务$k$在第$t-1$步时第$t-2$步的Loss
- $w_k(t-1)$: 任务 $k$ 在第$t-1$步的训练速度 ($w_k(t-1)$越小, 任务训练越快, 下一轮分配的权重就小)
- T 是一个常数, $\mathrm{T}=1$ 时, w 等同于softmax的结果; T 足够大时, w趋近1, 各个任务的loss权重相同。
代码:
1import math
2
3T = 20
4
5def dynamic_weight_average(loss_t_1, loss_t_2):
6 """
7
8 :param loss_t_1: 每个task上一轮的loss列表, 长度为n, 任务数
9 :param loss_t_2: 每个task上上一轮的loss列表, 长度为n, 任务数
10 :return:
11 """
12 # 第1和2轮,w初设化为1,lambda也对应为1
13 if not loss_t_1 or not loss_t_2:
14 return None
15
16 assert len(loss_t_1) == len(loss_t_2)
17 task_n = len(loss_t_1)
18
19 w = [l_1 / l_2 for l_1, l_2 in zip(loss_t_1, loss_t_2)]
20
21 lamb = [math.exp(v / T) for v in w]
22
23 lamb_sum = sum(lamb)
24
25 return [task_n * l / lamb_sum for l in lamb]
Pareto
帕累托最优:排除明显差的解,构建帕累托前沿,从前沿上取最优解。
- 解A优于解B
当解A的所有目标都优于解B, 则称为解A优于解B。如下图的E点,它的函数f1和f2的值都小于C点(这里可以认为f是loss),那么可以说解E由于解C,同理也由于解D
- 解A无差别于解B
当解A的一个目标由于解B,另一个目标不如解B,则称为解A无差别于解B。如下图的A和B,A点的f1值小于B点,但A点的f2值大于B点。
- 帕累托前沿
如果我们找到一条曲线,上面的所有解两两之间都存在A无差别于B的关系,那么就说明找到了一条帕累托前沿(为什么E不在帕累托前沿上呢,因为在前沿上,一定存在一个点,f1和f2的值都小于E点)。
具体需要使用哪个解,则因人而异。一般来说,通过定义一组参数下界,让模型知道每个任务的权重,至少要大于定义的下界,从而实现最优解的选择。 帕累托前沿上的任何一个解都是一个可接受的解。
帕累托在多目标上的应用:
代码实现:
1import numpy as np
2from scipy.optimize import minimize
3from scipy.optimize import nnls
4
5
6def pareto_step(w, c, G):
7 """
8 使用pareto优化更新下一setp的w权重
9 ref:http://ofey.me/papers/Pareto.pdf
10 K : the number of task
11 M : the dim of NN's params
12 :param W: # (K,1)
13 :param C: # (K,1)
14 :param G: # (K,M)
15 :return:
16 """
17 GGT = np.matmul(G, np.transpose(G)) # (K, K)
18 e = np.mat(np.ones(np.shape(w))) # (K, 1)
19 m_up = np.hstack((GGT, e)) # (K, K+1)
20 m_down = np.hstack((np.transpose(e), np.mat(np.zeros((1, 1))))) # (1, K+1)
21 M = np.vstack((m_up, m_down)) # (K+1, K+1)
22 z = np.vstack((-np.matmul(GGT, c), 1 - np.sum(c))) # (K+1, 1)
23 hat_w = np.matmul(np.matmul(np.linalg.inv(np.matmul(np.transpose(M), M)), M), z) # (K+1, 1)
24 hat_w = hat_w[:-1] # (K, 1)
25 hat_w = np.reshape(np.array(hat_w), (hat_w.shape[0],)) # (K,)
26 c = np.reshape(np.array(c), (c.shape[0],)) # (K,)
27 new_w = ASM(hat_w, c)
28 return new_w
29
30
31def ASM(hat_w, c):
32 """
33 保证w_i的值满足一定条件,比如 sum(w_i) = 1, w_i>=c for all w_i
34 ref:
35 http://ofey.me/papers/Pareto.pdf,
36 https://stackoverflow.com/questions/33385898/how-to-include-constraint-to-scipy-nnls-function-solution-so-that-it-sums-to-1
37
38 :param hat_w: # (K,)
39 :param c: # (K,)
40 :return:
41 """
42 A = np.array([[0 if i != j else 1 for i in range(len(c))] for j in range(len(c))])
43 b = hat_w
44 x0, _ = nnls(A, b)
45
46 def _fn(x, A, b):
47 return np.linalg.norm(A.dot(x) - b)
48
49 cons = {'type': 'eq', 'fun': lambda x: np.sum(x) + np.sum(c) - 1}
50 bounds = [[0., None] for _ in range(len(hat_w))]
51 min_out = minimize(_fn, x0, args=(A, b), method='SLSQP', bounds=bounds, constraints=cons)
52 new_w = min_out.x + c
53 return new_w
54
55
56use_pareto = True
57w_a, w_b = 0.5, 0.5
58c_a, c_b = 0.4, 0.2
59for step in range(0, 100):
60 res = sess.run([a_gradients, b_gradients, train, loss, loss_a, loss_b],
61 feed_dict={weight_a: w_a, weight_b: w_b})
62
63 if use_pareto:
64 s = time.time()
65 weights = np.mat([[w_a], [w_b]])
66 paras = np.hstack((res[0], res[1]))
67 paras = np.transpose(paras)
68 w_a, w_b = pareto_step(weights, np.mat([[c_a], [c_b]]), paras)
69 print("pareto cost: {}".format(time.time() - s))
70
71 l, l_a, l_b = res[3:]
72 print("step:{:0>2d} w_a:{:4f} w_b:{:4f} loss:{:4f} loss_a:{:4f} loss_b:{:4f} r:{:4f}".format(step, w_a, w_b, l, l_a, l_b, l_a / l_b))