Highway network 和 Residual network 分析和比较

内容纲要

Highway network

Highway network 设计目的和思想

设计的主要目的就是为了训练更深层次的神经网络,借鉴了lstm通过门控机制能够学到相对长时序的依赖,想要设计一个“立起来的”lstm,所以就有了 Highway network。
highway_instance

Highway network 架构核心

H 一般来说就是一个仿射变换加一个非线性的激活函数比如 relu

\mathbf{y}=H\left(\mathbf{x}, \mathbf{W}_{\mathbf{H}}\right)

原本考虑给它两个门

  1. T transform gate (用来决定有多少到下一层的 output 是取决于变换)
  2. C carry gate (决定多大程度维持原来的 input)
    公式如下:
\mathbf{y}=H\left(\mathbf{x}, \mathbf{W}_{\mathbf{H}}\right) \cdot T\left(\mathbf{x}, \mathbf{W}_{\mathbf{T}}\right)+\mathbf{x} \cdot C\left(\mathbf{x}, \mathbf{W}_{\mathbf{C}}\right)

coupled T 和 C ,用 1-T 代替 C ,T 作为门控的角色这里激活函数是 Sigmoid

\mathbf{y}=H\left(\mathbf{x}, \mathbf{W}_{\mathbf{H}}\right) \cdot T\left(\mathbf{x}, \mathbf{W}_{\mathbf{T}}\right)+\mathbf{x} \cdot\left(1-T\left(\mathbf{x}, \mathbf{W}_{\mathbf{T}}\right)\right)

VS Residual network

Residual network 相当于是上面简化上面第二个公式,让 T=C=1

y=H\left(x, W_{H}\right)+x

因为是设计来做图像分类的 Residual network 的 H 是包含多个层 CNN ,BN ,Relu 等的一个 subnetwork

highway_vs_resnet

简单来说,有三点不同(这只是我能归纳出的粗浅的三点,以后有机会再深入)

  1. ResNet 没有额外参数
  2. ResNet 直接使用 X 跳步链接,Highway network 使用门控
  3. ResNet 试验和比赛证明了类似各种变种,验证了自己这一套做法最强

基于以上原因残差连接的方式也被更加广泛使用。

Highway network 核心代码示例

class Demo(nn.Module):
    def __init__(self, args, pretrained):
        super(Demo, self).__init__()
        # 一般会定义多层结构
        assert self.args.hidden_size * 2 == (self.args.char_channel_size + self.args.word_dim)
        for i in range(2):
            setattr(self, 'highway_linear{}'.format(i),
                    nn.Sequential(Linear(args.hidden_size * 2, args.hidden_size * 2),
                                  nn.ReLU()))
            setattr(self, 'highway_gate{}'.format(i),
                    nn.Sequential(Linear(args.hidden_size * 2, args.hidden_size * 2),
                                  nn.Sigmoid()))

    def highway_network(x1, x2):
    """
    :param x1: (batch, seq_len, char_channel_size)
    :param x2: (batch, seq_len, word_dim)
    :return: (batch, seq_len, hidden_size * 2)
    """
    # (batch, seq_len, char_channel_size + word_dim)
    x = torch.cat([x1, x2], dim=-1)
    for i in range(2):
        h = getattr(self, 'highway_linear{}'.format(i))(x)
        g = getattr(self, 'highway_gate{}'.format(i))(x)
        x = g * h + (1 - g) * x
    # (batch, seq_len, hidden_size * 2)
    return x

Residual network 核心代码示例

注意:identity 除了,身份、特征、特性其中有一个意思是 “本体” 。

# init 定义的就不写了
def forward(self, x):
    identity = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)

    out = self.conv2(out)
    out = self.bn2(out)

    if self.downsample is not None:
        identity = self.downsample(x)

    out += identity
    out = self.relu(out)

    return out

参考

Highway network 论文
ResNet 论文
李宏毅视频讲解
Troch vision
quora 讨论
知乎 讨论

发表评论

电子邮件地址不会被公开。 必填项已用*标注