Implementing a Simple CNN in PyTorch

#!/usr/bin/python
# -*- coding: UTF-8 -*-

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

class SimpleCNN(torch.nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__() # b, 3, 32, 32
        layer1 = torch.nn.Sequential()
        layer1.add_module('conv1', torch.nn.Conv2d(3, 32, 3, 1, padding=1))
        # 3 input channels -> 32 output channels; 3x3 kernel, stride 1, padding 1
        # keeps the spatial size at 32x32 -> (b, 32, 32, 32)
        layer1.add_module('relu1', torch.nn.ReLU(True)) 
        layer1.add_module('pool1', torch.nn.MaxPool2d(2, 2))  # 2x2 max pooling halves the size -> (b, 32, 16, 16)
        self.layer1 = layer1

        layer2 = torch.nn.Sequential()
        layer2.add_module('conv2', torch.nn.Conv2d(32, 64, 3, 1, padding=1))
        # 32 -> 64 channels, spatial size stays 16x16 -> (b, 64, 16, 16)
        layer2.add_module('relu2', torch.nn.ReLU(True))
        layer2.add_module('pool2', torch.nn.MaxPool2d(2, 2)) # b, 64, 8, 8
        self.layer2 = layer2

        layer3 = torch.nn.Sequential()
        layer3.add_module('conv3', torch.nn.Conv2d(64, 128, 3, 1, padding=1))
        # 64 -> 128 channels, spatial size stays 8x8 -> (b, 128, 8, 8)
        layer3.add_module('relu3', torch.nn.ReLU(True))
        layer3.add_module('pool3', torch.nn.MaxPool2d(2, 2))  # b, 128, 4, 4
        self.layer3 = layer3

        # fully connected output layers: 128*4*4 = 2048 -> 512 -> 64 -> 10 classes
        layer4 = torch.nn.Sequential()
        layer4.add_module('fc1', torch.nn.Linear(2048, 512))
        layer4.add_module('fc_relu1', torch.nn.ReLU(True))
        layer4.add_module('fc2', torch.nn.Linear(512, 64))
        layer4.add_module('fc_relu2', torch.nn.ReLU(True))
        layer4.add_module('fc3', torch.nn.Linear(64, 10))
        self.layer4 = layer4

    def forward(self, x):
        conv1 = self.layer1(x)
        conv2 = self.layer2(conv1)
        conv3 = self.layer3(conv2)
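        # flatten (b, 128, 4, 4) into (b, 2048) for the fully connected layers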
        fc_input = conv3.view(conv3.size(0), -1)
        fc_out = self.layer4(fc_input)
        return fc_out

model = SimpleCNN()

print(model)

Running the script prints the following model summary:

SimpleCNN(
  (layer1): Sequential(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU(inplace)
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2): ReLU(inplace)
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer3): Sequential(
    (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu3): ReLU(inplace)
    (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer4): Sequential(
    (fc1): Linear(in_features=2048, out_features=512, bias=True)
    (fc_relu1): ReLU(inplace)
    (fc2): Linear(in_features=512, out_features=64, bias=True)
    (fc_relu2): ReLU(inplace)
    (fc3): Linear(in_features=64, out_features=10, bias=True)
  )
)
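
As a quick sanity check (a small addition, not from the original post), we can run a random batch of CIFAR-10-sized images through the model and confirm that the output shape matches the comments above:

x = torch.randn(4, 3, 32, 32)   # a random batch: 4 images, 3 channels, 32x32
out = model(x)
print(out.shape)                # torch.Size([4, 10]) -- one logit per class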

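The (3, 32, 32) input and 10-way output match CIFAR-10, which is presumably why torchvision's datasets and DataLoader are imported above. A minimal training sketch, assuming CIFAR-10 with cross-entropy loss and SGD (the hyperparameters here are illustrative guesses, not from the original post):

transform = transforms.Compose([transforms.ToTensor()])
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

criterion = torch.nn.CrossEntropyLoss()             # standard loss for 10-class classification
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

model.train()
for epoch in range(2):                              # a couple of epochs, just to show the loop
    for images, labels in train_loader:
        optimizer.zero_grad()
        logits = model(images)                      # (b, 10) raw scores
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
    print('epoch %d, last batch loss %.4f' % (epoch, loss.item()))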