PyTorch实例:如何打印神经网络结构

使用PyTorch可以很容易地打印神经网络的结构。具体的实现方法如下:

1. 建立神经网络模型

# 建立一个简单的网络
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 30)
        self.fc3 = nn.Linear(30, 10)

net = Net()

2. 使用print_summary()函数打印网络结构

# 导入打印神经网络结构模块
from torchsummary import summary

# 使用summary函数打印神经网络结构
summary(net, (10,))

3. 输出结果

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                  [-1, 20]              220
            Linear-2                  [-1, 30]              630
            Linear-3                  [-1, 10]              310
================================================================
Total params: 1,160
Trainable params: 1,160
Non-trainable params: 0
----------------------------------------------------------------

从输出结果可以看出,网络结构由三层线性层组成,每层的输入输出形状,以及参数数量都能够被打印出来。

本文链接:http://task.lmcjl.com/news/9624.html

展开阅读全文