End-to-End Weak Supervision

Salva Rühling Cachay, Benedikt Boecking, Artur Dubrawski

https://arxiv.org/abs/2107.02233

Figure 1

An encoder network $e$ is trained to output weights that are combined linearly with the weak labels provided by multiple noisy labeling functions. The class probabilities produced by this linear combination enter a cross-entropy loss that treats the output of a downstream network $f$ as ground truth; simultaneously, the class probabilities of the downstream network $f$ are penalized using the encoder network $e$'s probabilities as ground truth. The method performs strongly on weakly supervised tasks, and robustness to adversarial or highly correlated weak labels is also assessed; ablation studies provide additional insights. It is essential that gradients are not propagated through whichever output serves as the "ground truth", so that the loss is properly symmetrized.
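To pin down how I read this objective, here is a minimal PyTorch sketch. It assumes softmax-normalized encoder weights and one-hot labeling-function votes; the names, shapes, and normalization are my assumptions, not the paper's actual parameterization.

```python
import torch
import torch.nn.functional as F

def symmetrized_loss(e_logits, f_logits, weak_labels, eps=1e-8):
    """Symmetrized cross-entropy with stop-gradients on the target side.

    e_logits:    (batch, m)    encoder outputs, one score per labeling function
    f_logits:    (batch, C)    downstream model outputs
    weak_labels: (batch, m, C) one-hot votes of the m labeling functions
    """
    # The encoder's scores are normalized into a convex combination, so the
    # aggregated label-model probabilities are valid class probabilities
    # (assumption: the paper may normalize differently).
    weights = torch.softmax(e_logits, dim=-1)                    # (batch, m)
    e_probs = torch.einsum("bm,bmc->bc", weights, weak_labels)   # (batch, C)

    log_f = F.log_softmax(f_logits, dim=-1)
    log_e = torch.log(e_probs + eps)

    # Train f toward e's detached probabilities, and e toward f's:
    # gradients never flow through the side acting as ground truth.
    loss_f = -(e_probs.detach() * log_f).sum(dim=-1).mean()
    f_probs = torch.softmax(f_logits, dim=-1)
    loss_e = -(f_probs.detach() * log_e).sum(dim=-1).mean()
    return loss_f + loss_e
```

The two `.detach()` calls are the sketch's version of stopping the gradient w.r.t. the "ground truth" mentioned above: each term trains only one of the two networks.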

Comments. Cool idea! I am not very familiar with weakly supervised learning, but this approach strikes me as quite ingenious.

Why does it work? After staring at it for a while, I think I somewhat get the intuition. Since the encoder probabilities are constrained to be a linear combination of the weak labels that sums to one, the downstream model is pushed to mimic a weighted vote of the labeling functions. For example, when $e$ is initialized such that it outputs mostly zeros, the weak labels are weighted uniformly, so the learning signal for the downstream model is to mimic the majority vote.

It is less clear what the learning signal for $e$ is. At the start of training I suspect it is degenerate, since $e$ is incentivized to move toward a poorly performing downstream model $f$. Despite this, I think the system converges because each model learns to extract as much information as is available in the input in order to mimic the other model's output. Since the ground-truth label is something both models can agree on (i.e., extract from the input features), it is what they converge to. However, it is not clear to me why the models do not simply collapse to predicting a single class all the time, which would also yield high agreement.
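To make the majority-vote claim concrete, here is a toy check under the same assumptions as the sketch above (zero encoder outputs, softmax normalization, one-hot votes); the numbers are purely illustrative:

```python
import torch

# Three labeling functions vote on a 2-class example:
# two vote for class 0, one votes for class 1.
weak_labels = torch.tensor([[1., 0.],
                            [1., 0.],
                            [0., 1.]])        # (m=3, C=2), one-hot votes

e_logits = torch.zeros(3)                     # encoder initialized to output zeros
weights = torch.softmax(e_logits, dim=-1)     # uniform: [1/3, 1/3, 1/3]

probs = weights @ weak_labels                 # soft majority vote
print(probs)                                  # tensor([0.6667, 0.3333])
```

With all-zero encoder outputs, the softmax yields uniform weights, and the aggregated probabilities are exactly the (soft) majority vote that $f$ is initially trained toward.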