PyTorch神经网络层拆解

如题所述

第1个回答  2022-06-26

本文将拆解常见的PyTorch神经网络层,从开发者的角度来看,这些神经网络层都是一个一个的函数,完成对数据的处理。

第一 :CLASS torch.nn.Flatten( start_dim=1 , end_dim=- 1 ) ,将多维的输入一维化,常用在从卷积层到全连接层的过渡。需要注意的是,Flatten()的默认值start_dim=1,即默认数据数据的格式是[N,C,H,W]第0维度为Batch Size,不参与Flatten。后面的CHW全部展平为一维。

第二 , CLASS torch.nn.Linear( in_features , out_features , bias=True , device=None , dtype=None ) ,Linear又叫全连接层,TensorFlow里面叫Dense,主要用于分类。
Linear类有两个属性:

第三 ,CLASS torch.nn.Conv2d (in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None),卷积层,常用于提取图像特征,CNN+RELU+MaxPooling已经成为一种常见的特征提取操作了。
需要注意的是:CNN要求数据输入格式为:[N, Cin, Hin, Wout],Cin是输入数据Tensor的通道数量,输出为[N, Cout, Hout, Wout],Cout为本CNN层的卷积个数。Hout和Wout计算公式如下所示:

范例程序:

总结: