加入CNN+Attention的图像分类模块

This commit is contained in:
2025-03-28 16:01:42 +08:00
parent b5fcc72f63
commit f7f26c709b
3 changed files with 175 additions and 0 deletions

48
nn/model.py Normal file
View File

@@ -0,0 +1,48 @@
import torch.nn as nn
# 更简化的注意力模块
class UltraSimplifiedAttention(nn.Module):
def __init__(self, in_channels):
super(UltraSimplifiedAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels // 64, bias=False),
nn.ReLU(inplace=True),
nn.Linear(in_channels // 64, in_channels, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
# CNN结合更简化的注意力机制模型
class CNNWithUltraSimplifiedAttention(nn.Module):
def __init__(self, num_classes):
super(CNNWithUltraSimplifiedAttention, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.attention1 = UltraSimplifiedAttention(16)
self.relu1 = nn.ReLU(inplace=True)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.attention2 = UltraSimplifiedAttention(32)
self.relu2 = nn.ReLU(inplace=True)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc_input_dim = 32 * 256 * 200
self.fc1 = nn.Linear(self.fc_input_dim, 128)
self.relu3 = nn.ReLU(inplace=True)
self.fc2 = nn.Linear(128, num_classes)
def forward(self, x):
x = self.pool1(self.relu1(self.attention1(self.conv1(x))))
x = self.pool2(self.relu2(self.attention2(self.conv2(x))))
x = x.view(-1, self.fc_input_dim)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return x