Just Train Twice: Improving Group Robustness Without Training Group Information

International Conference on Machine Learning (ICML), Vol. 139, 2021

Abstract
Standard training via empirical risk minimization (ERM) can produce models that achieve low error on average but high error on certain groups, especially in the presence of spurious correlations between the input and label. Prior approaches that achieve low worst-group error, such as group distributionally robust optimization (group DRO), require expensive group annotations for each training point, whereas approaches that do not use such group annotations achieve worse worst-group performance. In this paper, we propose a simple two-stage approach, JTT, that minimizes the loss over a reweighted dataset (second stage), upweighting training examples that are misclassified at the end of a few steps of standard training (first stage). Intuitively, this upweights points from groups on which standard ERM models perform poorly, leading to improved worst-group performance. On four image classification and natural language processing tasks with spurious correlations, we show that JTT closes 73% of the gap in worst-group accuracy between standard ERM and group DRO, while only requiring group annotations on a small validation set in order to tune hyperparameters.
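The two-stage procedure described in the abstract lends itself to a short sketch. The snippet below is a minimal illustration, not the authors' released code: it assumes a toy synthetic dataset and a small feed-forward classifier (the paper uses pretrained image and text models on real benchmarks), and it implements the upweighting by duplicating the error-set examples lambda_up times in the second-stage training set. The number of first-stage epochs and the upweight factor lambda_up are the hyperparameters one would tune on the small group-annotated validation set.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset, Subset

torch.manual_seed(0)

# Toy stand-in data for illustration only; in the paper the inputs are
# images or text and the model is a pretrained ResNet or BERT.
X = torch.randn(2000, 20)
y = (X[:, 0] > 0).long()
train_set = TensorDataset(X, y)

def make_model():
    return nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))

def train_erm(model, dataset, epochs, lr=1e-2, batch_size=64):
    """Standard ERM training with average cross-entropy loss."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for xb, yb in loader:
            opt.zero_grad()
            loss_fn(model(xb), yb).backward()
            opt.step()
    return model

# Stage 1: train a standard ERM model for a few epochs (a hyperparameter).
identifier = train_erm(make_model(), train_set, epochs=2)

# Collect the error set: training points the stage-1 model misclassifies.
identifier.eval()
with torch.no_grad():
    preds = identifier(X).argmax(dim=1)
error_idx = (preds != y).nonzero(as_tuple=True)[0].tolist()

# Stage 2: upweight the error set by including it lambda_up times in total
# (lambda_up is assumed here; it is the second tuned hyperparameter), then
# retrain a fresh model on the reweighted dataset.
lambda_up = 20
upsampled = ConcatDataset([train_set] + [Subset(train_set, error_idx)] * (lambda_up - 1))
final_model = train_erm(make_model(), upsampled, epochs=10)
```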