How to implement recomputation (gradient checkpointing)

Recomputation in PyTorch

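You only need to set a checkpoint at each of the parameter-heavy layers (picture BERT's stack of encoder layers). The activations inside a checkpointed layer are not kept after the forward pass; they are recomputed when the backward pass reaches that layer.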

        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # Before the change:
            # layer_outputs = layer_module(
            #     hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
            # )
            # After the change (needs `import torch.utils.checkpoint` at the top of the file):
            layer_outputs = torch.utils.checkpoint.checkpoint(
                layer_module, hidden_states, attention_mask, head_mask[i],
                encoder_hidden_states, encoder_attention_mask,
            )

            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
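
The excerpt above lives inside a larger encoder class, so it cannot be run on its own. As a rough, self-contained sketch of the same idea, the snippet below wraps each block of a small layer stack in checkpoint() and checks that the gradients match the non-checkpointed run. TinyEncoder and input_grad are hypothetical names made up for this illustration, not part of BERT or PyTorch, and passing use_reentrant=False assumes a reasonably recent PyTorch (1.11 or later); the original excerpt relies on the older default instead.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a BERT-style encoder: a stack of identical blocks."""

    def __init__(self, hidden=16, num_layers=4, use_checkpoint=False):
        super().__init__()
        self.layer = nn.ModuleList(
            [nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh()) for _ in range(num_layers)]
        )
        self.use_checkpoint = use_checkpoint

    def forward(self, hidden_states):
        for layer_module in self.layer:
            if self.use_checkpoint:
                # Only the block's input is saved; its internal activations
                # are recomputed when backward reaches this block.
                # use_reentrant=False is an assumption (PyTorch >= 1.11).
                hidden_states = checkpoint(layer_module, hidden_states, use_reentrant=False)
            else:
                hidden_states = layer_module(hidden_states)
        return hidden_states


def input_grad(use_checkpoint):
    """Run one forward/backward pass and return input and parameter gradients."""
    torch.manual_seed(0)  # same weights and same input for both runs
    model = TinyEncoder(use_checkpoint=use_checkpoint)
    x = torch.randn(2, 16, requires_grad=True)
    model(x).sum().backward()
    return x.grad.clone(), [p.grad.clone() for p in model.parameters()]


plain_x, plain_params = input_grad(use_checkpoint=False)
ckpt_x, ckpt_params = input_grad(use_checkpoint=True)
print(torch.allclose(plain_x, ckpt_x))
```

Checkpointing trades compute for memory: each wrapped block runs its forward pass a second time during backward, so the gradients stay the same while peak activation memory drops.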

References

PyTorch checkpoint
Training Deep Nets with Sublinear Memory Cost
For a TensorFlow implementation, see gradient-checkpointing