# Mirror of https://github.com/yv1ing/ShotRDP.git
# Synced 2025-09-16 15:10:57 +08:00 (107 lines, 3.1 KiB, Python).
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class DualAttention(nn.Module):
    """Channel gate followed by a spatial gate, both applied multiplicatively.

    The channel branch is a squeeze-and-excitation style gate. It applies global
    average pooling, a 1x1 bottleneck conv, SiLU, a 1x1 expansion conv and a
    sigmoid, giving one weight per channel. The spatial branch convolves two
    per-pixel statistics of the input (the mean and the standard deviation
    across channels) into a single sigmoid map. That map re-weights the
    channel-gated features.

    Args:
        in_channels: Number of channels of the input (and output) feature map.
        reduction: Bottleneck reduction ratio of the channel branch. The hidden
            width is ``max(in_channels // reduction, 1)``. Defaults to 32.
        spatial_kernel: Kernel size of the spatial-gate convolution. It must be
            a positive odd integer so that padding ``spatial_kernel // 2``
            preserves the spatial size. Defaults to 3.

    Raises:
        ValueError: If ``reduction`` < 1 or ``spatial_kernel`` is not a
            positive odd integer.
    """

    def __init__(self, in_channels, reduction=32, spatial_kernel=3):
        super().__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        if spatial_kernel < 1 or spatial_kernel % 2 == 0:
            raise ValueError(
                f"spatial_kernel must be a positive odd integer, got {spatial_kernel}"
            )

        # Clamp to 1 so small channel counts never give a zero-width bottleneck.
        middle_channels = max(in_channels // reduction, 1)

        # Layer order/indices are unchanged, so the parameter names
        # (channel_attn.1/.3, spatial_attn.0) stay the same.
        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, middle_channels, 1, bias=False),
            nn.SiLU(),
            nn.Conv2d(middle_channels, in_channels, 1, bias=False),
            nn.Sigmoid(),
        )

        # Input: 2 statistic maps (mean, std). Output: 1 spatial weight map.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(2, 1, spatial_kernel, padding=spatial_kernel // 2, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Apply the channel gate, then the spatial gate, to ``x``.

        Args:
            x: Feature map. This assumes (N, C, H, W) with C == in_channels,
                since dimension 1 is treated as channels below (TODO confirm
                with callers).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Channel weighting.
        ca = self.channel_attn(x) * x

        # Spatial statistics across channels (mean + standard deviation).
        # NOTE(review): these are taken from the raw input ``x``, not from the
        # channel-gated ``ca``. This is kept as-is so existing model outputs
        # do not change.
        mean = x.mean(dim=1, keepdim=True)
        if x.size(1) > 1:
            std = x.std(dim=1, keepdim=True)
        else:
            # The unbiased std of a single value divides by n - 1 == 0 and
            # returns NaN, which would poison the whole gate. A single channel
            # has no spread, so use zero instead.
            std = torch.zeros_like(mean)
        stats = torch.cat([mean, std], dim=1)

        # Spatial weighting of the channel-gated features.
        return self.spatial_attn(stats) * ca
|
|
|
|
|
|
class CustomCNN(nn.Module):
    """Four-stage convolutional classifier with a DualAttention module per stage.

    Each stage is a 3x3 Conv2d (no bias), then DualAttention, BatchNorm2d and
    Hardswish. The first three convolutions use stride 2, which maps a spatial
    size H to ceil(H / 2). The fourth uses stride 1 with dilation 2 and padding 2,
    which keeps the resolution. The head applies global average pooling,
    Flatten, Linear(64 -> 192), Hardswish, Dropout(0.3) and
    Linear(192 -> num_classes).

    Args:
        num_classes: Output width of the final Linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # One row per stage: (in_ch, out_ch, stride, padding, dilation) of the 3x3 conv.
        stage_specs = (
            (3, 24, 2, 1, 1),
            (24, 36, 2, 1, 1),
            (36, 48, 2, 1, 1),
            (48, 64, 1, 2, 2),  # dilation 2 + padding 2 keeps H and W unchanged
        )

        blocks = []
        for c_in, c_out, stride, pad, dil in stage_specs:
            # Modules are created in the same order as a hand-written list,
            # so seeded weight initialisation is unaffected.
            blocks.extend((
                nn.Conv2d(c_in, c_out, 3, stride=stride, padding=pad,
                          dilation=dil, bias=False),
                DualAttention(c_out),
                nn.BatchNorm2d(c_out),
                nn.Hardswish(inplace=True),
            ))
        self.features = nn.Sequential(*blocks)

        hidden_width = 192
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, hidden_width),
            nn.Hardswish(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(hidden_width, num_classes),
        )

    def forward(self, x):
        """Return raw (unnormalised) class scores for ``x``.

        ``x`` presumably has shape (N, 3, H, W), since the first conv expects
        3 input channels (TODO confirm the batch layout with callers).
        """
        feature_map = self.features(x)
        return self.classifier(feature_map)
|
|
|
|
|
|
class SimpleCNN(nn.Module):
    """Plain four-stage convolutional classifier (no attention modules).

    Each stage is a 3x3 Conv2d (no bias), then BatchNorm2d and Hardswish. The
    first three convolutions use stride 2, which maps a spatial size H to
    ceil(H / 2). The last uses stride 1 with dilation 2 and padding 2, which
    preserves the resolution. A pooled MLP head maps the 64 final channels to
    ``num_classes`` outputs.

    Args:
        num_classes: Output width of the final Linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # (in_ch, out_ch, stride, padding, dilation) for each 3x3 conv stage.
        stage_specs = (
            (3, 24, 2, 1, 1),
            (24, 36, 2, 1, 1),
            (36, 48, 2, 1, 1),
            (48, 64, 1, 2, 2),  # dilated, resolution-preserving stage
        )

        layers = []
        for c_in, c_out, stride, pad, dil in stage_specs:
            layers += [
                nn.Conv2d(c_in, c_out, 3, stride=stride, padding=pad,
                          dilation=dil, bias=False),
                nn.BatchNorm2d(c_out),
                nn.Hardswish(inplace=True),
            ]
        self.features = nn.Sequential(*layers)

        hidden_width = 192
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, hidden_width),
            nn.Hardswish(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(hidden_width, num_classes),
        )

    def forward(self, x):
        """Return raw (unnormalised) class scores for ``x``.

        ``x`` presumably has shape (N, 3, H, W), since the first conv expects
        3 input channels (TODO confirm the batch layout with callers).
        """
        feature_map = self.features(x)
        return self.classifier(feature_map)
|