fan_in

如果权重是通过线性层(卷积或全连接)隐性确定的,则需设置mode=fan_in。

例子:

import torch

linear_layer = torch.nn.Linear(node_in, node_out)
init.kaiming_normal_(linear.weight, mode=’fan_in’)
output_data = relu(linear_layer(input_data))

fan_out

如果通过创建随机矩阵显式创建权重,则应进行设置mode=‘fan_out’。

import torch

w1 = torch.randn(node_in, node_out)
init.kaiming_normal_(w1, mode=’fan_out’)
b1 = torch.randn(node_out)
output_data  = relu(linear(input_data, w1, b1))
Logo

一站式 AI 云服务平台

更多推荐