mirror of
https://github.com/yv1ing/ShotRDP.git
synced 2025-09-16 15:10:57 +08:00
加入CNN+Attention的图像分类模块
This commit is contained in:
48
nn/model.py
Normal file
48
nn/model.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# 更简化的注意力模块
|
||||
class UltraSimplifiedAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super(UltraSimplifiedAttention, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(in_channels, in_channels // 64, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(in_channels // 64, in_channels, bias=False),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, _, _ = x.size()
|
||||
y = self.avg_pool(x).view(b, c)
|
||||
y = self.fc(y).view(b, c, 1, 1)
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
# CNN结合更简化的注意力机制模型
|
||||
class CNNWithUltraSimplifiedAttention(nn.Module):
|
||||
def __init__(self, num_classes):
|
||||
super(CNNWithUltraSimplifiedAttention, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
|
||||
self.attention1 = UltraSimplifiedAttention(16)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
|
||||
self.attention2 = UltraSimplifiedAttention(32)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.fc_input_dim = 32 * 256 * 200
|
||||
self.fc1 = nn.Linear(self.fc_input_dim, 128)
|
||||
self.relu3 = nn.ReLU(inplace=True)
|
||||
self.fc2 = nn.Linear(128, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pool1(self.relu1(self.attention1(self.conv1(x))))
|
||||
x = self.pool2(self.relu2(self.attention2(self.conv2(x))))
|
||||
x = x.view(-1, self.fc_input_dim)
|
||||
x = self.relu3(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
Reference in New Issue
Block a user