You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
56 lines
2.0 KiB
56 lines
2.0 KiB
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class Vgg16(torch.nn.Module):
    """The thirteen 3x3 convolution layers of VGG-16, returning five tapped activations.

    Only the convolutional trunk is built here: there are no fully connected
    layers and nothing in this class loads weights, so every convolution starts
    with ``nn.Conv2d``'s default initialization. Layers are registered as
    attributes named ``conv<stage>_<index>`` (``conv1_1`` ... ``conv5_3``), so
    ``state_dict()`` keys read ``conv1_1.weight``, ``conv1_1.bias``, and so on.
    """

    # Convolution attribute names, grouped into the five VGG stages. Stages are
    # separated by 2x2 max pooling in ``forward``; the output of each stage's
    # final convolution (after ReLU) is one of the returned feature maps.
    _STAGES = (
        ("conv1_1", "conv1_2"),
        ("conv2_1", "conv2_2"),
        ("conv3_1", "conv3_2", "conv3_3"),
        ("conv4_1", "conv4_2", "conv4_3"),
        ("conv5_1", "conv5_2", "conv5_3"),
    )
    # Output channel count shared by every convolution in the matching stage.
    _STAGE_WIDTHS = (64, 128, 256, 512, 512)

    def __init__(self):
        super().__init__()
        # Each convolution consumes the previous one's output channels; the
        # very first one reads a 3-channel input.
        channels_in = 3
        for layer_names, channels_out in zip(self._STAGES, self._STAGE_WIDTHS):
            for layer_name in layer_names:
                # Attribute assignment on an nn.Module registers the layer as a
                # submodule, in the same order the names are listed above.
                setattr(
                    self,
                    layer_name,
                    nn.Conv2d(channels_in, channels_out, kernel_size=3, stride=1, padding=1),
                )
                channels_in = channels_out

    def forward(self, X):
        """Run ``X`` through all five stages and collect each stage's output.

        Args:
            X: Input tensor whose channel dimension must be 3 (required by
                ``conv1_1``); presumably an image batch shaped (N, 3, H, W) --
                confirm against callers.

        Returns:
            A list of five tensors: the ReLU activations after ``conv1_2``,
            ``conv2_2``, ``conv3_3``, ``conv4_3`` and ``conv5_3``. The 3x3
            convolutions use stride 1 and padding 1, so only the pooling
            between stages changes the spatial size. Each pool halves it,
            rounding down, and is applied before stages 2, 3, 4 and 5.
        """
        taps = []
        activation = X
        for stage_index, layer_names in enumerate(self._STAGES):
            # No pooling in front of the first stage or after the last one.
            if stage_index > 0:
                activation = F.max_pool2d(activation, kernel_size=2, stride=2)
            for layer_name in layer_names:
                activation = F.relu(getattr(self, layer_name)(activation))
            taps.append(activation)
        return taps
|