import torch.nn as nn


# Ultra-simplified attention module: SE-style channel attention
# (global average pooling -> bottleneck MLP -> sigmoid channel weights).
class UltraSimplifiedAttention(nn.Module):
    def __init__(self, in_channels):
        super(UltraSimplifiedAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Clamp the bottleneck width to at least 1: with a // 64 reduction,
        # small channel counts (e.g. 16 or 32) would otherwise collapse to 0
        # hidden units and the attention weights would degenerate to a constant 0.5.
        hidden_channels = max(in_channels // 64, 1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: per-channel global average, shape (b, c).
        y = self.avg_pool(x).view(b, c)
        # Excite: per-channel weights in (0, 1), broadcast back over H x W.
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
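

# Standalone usage sketch (illustrative addition, not from the original file,
# assuming `import torch`): the module rescales each channel of a feature map
# by a learned weight in (0, 1) and leaves the spatial shape unchanged, e.g.
#   att = UltraSimplifiedAttention(in_channels=32)
#   out = att(torch.randn(2, 32, 64, 64))   # out has the same shape as the input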


# CNN model combining the ultra-simplified attention mechanism:
# two conv -> attention -> ReLU -> max-pool stages followed by a two-layer classifier head.
class CNNWithUltraSimplifiedAttention(nn.Module):
    def __init__(self, num_classes):
        super(CNNWithUltraSimplifiedAttention, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.attention1 = UltraSimplifiedAttention(16)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.attention2 = UltraSimplifiedAttention(32)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Flattened feature size after the two 2x2 poolings; this hard-codes an
        # expected input resolution of 1024 x 800 pixels (1024/4 = 256, 800/4 = 200).
        self.fc_input_dim = 32 * 256 * 200
        self.fc1 = nn.Linear(self.fc_input_dim, 128)
        self.relu3 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool1(self.relu1(self.attention1(self.conv1(x))))
        x = self.pool2(self.relu2(self.attention2(self.conv2(x))))
        # Flatten to (batch, fc_input_dim); inputs must match the expected resolution.
        x = x.view(-1, self.fc_input_dim)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x
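

# Minimal smoke-test sketch (an added example, not part of the original file).
# It assumes RGB inputs of 1024 x 800 pixels so the flattened feature size matches
# fc_input_dim above; the batch size and num_classes below are illustrative. Note
# that fc1 alone holds roughly 210M parameters at this input size, so running the
# test needs a machine with a few GB of free memory.
if __name__ == "__main__":
    import torch

    model = CNNWithUltraSimplifiedAttention(num_classes=10)
    dummy = torch.randn(1, 3, 1024, 800)  # (batch, channels, height, width)
    logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([1, 10])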