Scaling Up Dataset Distillation to ImageNet-1K with Constant Memory

针对MTT的一些优化。MTT通过匹配在不同数据集上训练的模型参数,来优化\(C_{sync}\). 其损失函数为 \[ L = \frac{||\hat{\theta}_{t+T}-\theta^*_{t+M}||_{2}^2}{||\theta^*_{t+T}-\theta^*_{t+M}||_{2}^2}\tag{1} \] 其中\(\theta^*_{t}\)代表模型在原始数据集\(C_{pre}\)上训练时第t步的参数,再在此基础上,\(\hat{\theta}_{t+T}\)代表在合成数据\(C_{sync}\)上再训练\(T\)步之后的参数,其中\(T\ll M\)

注意到 \[ \hat{\theta}_{t+T} = \theta^*_{t} - \beta \nabla_{l}(\theta^*_{t};\hat{X}_{0}) - \beta \nabla_{\theta}l(\hat{\theta}_{t+1};\hat{X}_{1}) - \dots \beta \nabla_{\theta}l(\hat{\theta}_{t+T-1};\hat{X}_{T-1})\tag{2} \] 从而 \[ L' = ||\hat{\theta}_{t+T}-\theta^*_{t+M}||_{2}^2 = ||\theta^*_{t} - \beta\sum_{i=0}^{T-1} \nabla_{\theta}l(\hat{\theta}_{t+i};\hat{X}_{i})-\theta^*_{t+M}||_{2}^2\tag{3} \] 注意到对于任意\(\tilde{X}_{i}\),有 \[ \frac{\partial L'}{\partial \tilde{X}_{i}} = \frac{\partial L'}{\partial \hat{\theta}_{t+T}}\frac{\hat{\theta}_{t+T}}{\partial \tilde{X}_{i}}\tag{4} \] 由于\(\hat{\theta}_{t+T}\)的存在,需要将共\(T\)步的所有梯度都保存下来,从而导致巨大的显存消耗。 事实上可以对\((3)\)式进行变形,得到 \[ \begin{flalign} L' &= ||\theta^*_{t} - \beta\sum_{i=0}^{T-1} \nabla_{\theta}l(\hat{\theta}_{t+i};\hat{X}_{i})-\theta^*_{t+M}||_{2}^2 \\ &= \beta^2||\sum_{i=0}^{T-1} \nabla_{\theta}l(\hat{\theta}_{t+i};\hat{X}_{i})||_{2}^2-2\beta (\sum_{i=0}^{T-1} \nabla_{\theta}l(\hat{\theta}_{t+i};\hat{X}_{i}))^T(\theta_{t}^* - \theta^*_{t+M})+C \tag{5} \end{flalign} \] 其中\(C = ||\theta_{t}^* - \theta^*_{t+M}||_{2}^2\),是一个常数

\((5)\)式出发,可以将\((4)\)式转化为 \[ \frac{\partial L'}{\partial \tilde{X}_{i}} = 2\beta^2 G^T \frac{\partial}{\partial \tilde{X}_{i}} \nabla_{\theta}l(\hat{\theta}_{t+i};\hat{X}_{i}) - 2\beta \frac{\partial}{\partial \tilde{X}_{i}}(\nabla_{\theta}l(\hat{\theta}_{t+i};\hat{X}_{i}))(\theta_{t}^* - \theta^*_{t+M}) \tag{6} \] 其中\(G = \sum_{i=0}^{T-1} \nabla_{l}(\hat{\theta}_{t+i};\hat{X}_{i})\)

\((6)\)式中,第二项只与每一个\(\tilde{X}_{i}\)有关,只有第二项中\(G^T\)是与全体步骤的梯度有关的,而其作为梯度之和,可以提前计算好,从而整体显存与迭代步数\(T\)可以做到不相关,从而显存占用从\(O(T)\)压到了\(O(1)\) 整体来看,只是对梯度的求法做了一个变形,但确实带来了很好的效果。

另一点在于标签的使用。作者发现原始的0/1 hard label,收敛性差,但如果换成\((0,1)\)之间的soft label,会有更好效果。因此作者使用目标步长\(\theta_{t+M}^*\)的参数作为教师,生成soft label概率分布。这一点其实与KD是一致的。


Scaling Up Dataset Distillation to ImageNet-1K with Constant Memory
https://sophilex.github.io/posts/b581fc49/
作者
Sophilex
发布于
2025年10月13日
许可协议