EMC^2: Efficient MCMC Negative Sampling for Contrastive Learning with Global Convergence
arXiv (2024)
Abstract
A key challenge in contrastive learning is to generate negative samples from
a large sample set to contrast with positive samples, in order to learn better
encodings of the data. These negative samples often follow a softmax
distribution that is dynamically updated during training.
However, sampling from this distribution is non-trivial because of the high
computational cost of computing the partition function. In this paper, we
propose an Efficient Markov Chain Monte Carlo negative sampling method for
Contrastive learning (EMC^2). We follow the global contrastive learning loss
introduced in SogCLR and propose EMC^2, which uses an adaptive
Metropolis-Hastings subroutine to generate hardness-aware negative samples in
an online fashion during optimization. We prove that EMC^2 finds an
𝒪(1/√T)-stationary point of the global contrastive loss in
T iterations. Compared to prior works, EMC^2 is the first algorithm that
exhibits global convergence (to stationarity) regardless of the choice of batch
size, while maintaining low computation and memory costs. Numerical experiments
validate that EMC^2 is effective with small-batch training and achieves
comparable or better performance than baseline algorithms. We report
results for pre-training image encoders on STL-10 and ImageNet-100.
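To illustrate why Metropolis-Hastings sidesteps the partition function, here is a minimal sketch under stated assumptions: a fixed anchor embedding, n candidate embeddings scored by cosine similarity, a target p(j) ∝ exp(s_j/τ), and a plain uniform independent proposal. This is a generic MH sampler, not the adaptive EMC^2 subroutine, and it ignores how the chain is coupled to the SogCLR-style optimization. The helper names (`softmax_probs`, `mh_negative_chain`) and all sizes are hypothetical. Exact sampling needs the normalizer over all n candidates, whereas each MH step scores only one new candidate and uses a ratio in which the normalizer cancels.

```python
import numpy as np


def softmax_probs(scores, tau):
    """Exact softmax over all n candidates: needs the partition function, O(n)."""
    logits = scores / tau
    weights = np.exp(logits - logits.max())  # subtract max for numerical stability
    return weights / weights.sum()


def mh_negative_chain(score_fn, n, tau, num_steps, rng, init=0):
    """Metropolis-Hastings chain targeting p(j) proportional to exp(score_fn(j) / tau).

    Hypothetical helper using a uniform independent proposal over {0, ..., n-1}.
    Since the proposal is uniform, the acceptance ratio reduces to
    exp((s_new - s_old) / tau): the partition function cancels and is never
    computed, and each step scores only one new candidate.
    """
    current = init
    current_score = score_fn(current)
    samples = np.empty(num_steps, dtype=np.int64)
    for t in range(num_steps):
        proposal = int(rng.integers(n))
        proposal_score = score_fn(proposal)
        log_accept = (proposal_score - current_score) / tau
        if np.log(rng.random()) < log_accept:  # accept with prob min(1, ratio)
            current, current_score = proposal, proposal_score
        samples[t] = current
    return samples


# Toy check with assumed sizes: cosine similarities between one anchor and n candidates.
rng = np.random.default_rng(0)
n, d, tau = 50, 16, 0.5
anchor = rng.normal(size=d)
anchor /= np.linalg.norm(anchor)
candidates = rng.normal(size=(n, d))
candidates /= np.linalg.norm(candidates, axis=1, keepdims=True)
scores = candidates @ anchor

samples = mh_negative_chain(lambda j: float(scores[j]), n, tau, 100_000, rng)
empirical = np.bincount(samples, minlength=n) / len(samples)
exact = softmax_probs(scores, tau)
print("max |empirical - exact| =", np.abs(empirical - exact).max())
```

The empirical visit frequencies of the chain approach the exact softmax probabilities as the number of steps grows, even though the sampler never forms the normalizer. In an online training loop one would presumably keep the chain state across iterations instead of restarting it, which is the general idea behind sampling "in an online fashion"; the paper's adaptive proposal and hardness-aware details are not reproduced here.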