Files
ShotRDP/gui/model.py
2025-04-29 22:58:58 +08:00

107 lines
3.1 KiB
Python

import torch
import torch.nn as nn
class DualAttention(nn.Module):
    """Channel-then-spatial attention gate that preserves the input shape.

    Channel branch: global average pooling -> 1x1 conv bottleneck -> SiLU ->
    1x1 conv -> sigmoid. This gives one weight per channel, which multiplies
    ``x``.

    Spatial branch: the per-pixel mean and standard deviation across channels
    are stacked into a 2-channel map. A 3x3 conv and a sigmoid turn this map
    into one weight per spatial location.

    The result is ``spatial_gate * (channel_gate * x)``.

    Args:
        in_channels: number of channels C of the input feature map.
        reduction: bottleneck ratio of the channel branch. The hidden width
            is ``max(in_channels // reduction, 1)``. Defaults to 32, the
            value that was previously hard-coded.

    Raises:
        ValueError: if ``reduction`` is smaller than 1.
    """

    def __init__(self, in_channels, reduction=32):
        super().__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        middle_channels = max(in_channels // reduction, 1)
        # Keep the submodule layout unchanged so state_dict keys such as
        # ``channel_attn.1.weight`` / ``spatial_attn.0.weight`` still match
        # existing checkpoints.
        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, middle_channels, 1, bias=False),
            nn.SiLU(),
            nn.Conv2d(middle_channels, in_channels, 1, bias=False),
            nn.Sigmoid(),
        )
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(2, 1, 3, padding=1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Apply channel gating, then spatial gating.

        Args:
            x: batched feature map. Dim 1 is treated as the channel axis, so
                this expects (N, C, H, W).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Channel weighting: the (N, C, 1, 1) gate broadcasts over H and W.
        ca = self.channel_attn(x) * x
        # Per-pixel statistics across channels (mean + standard deviation).
        # These are taken from the raw input ``x``, not from ``ca``.
        # The Bessel-corrected std needs at least two channels; with a single
        # channel it is NaN. In that case use the population std (0) instead.
        # For C > 1 the result is identical to the default ``x.std``.
        unbiased = x.size(1) > 1
        stats = torch.cat(
            [
                x.mean(dim=1, keepdim=True),
                x.std(dim=1, keepdim=True, unbiased=unbiased),
            ],
            dim=1,
        )
        # Spatial weighting: the (N, 1, H, W) gate broadcasts over channels.
        return self.spatial_attn(stats) * ca
class CustomCNN(nn.Module):
    """Four-stage convolutional classifier with a DualAttention gate per stage.

    Each stage is Conv2d (3x3, no bias) -> DualAttention -> BatchNorm2d ->
    Hardswish, with widths 24, 36, 48 and 64.

    The first three convolutions downsample with stride 2. The last one keeps
    resolution and uses dilation 2 (padding 2) to widen the receptive field.

    The head global-average-pools the 64 channels and applies
    Linear(64, 192) -> Hardswish -> Dropout(0.3) -> Linear(192, num_classes).

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # (in_channels, out_channels, extra Conv2d kwargs) for each stage.
        # Modules are appended in the same order as a hand-written list, so
        # state_dict keys (features.0 ... features.15) stay checkpoint-compatible.
        stage_specs = (
            (3, 24, {"stride": 2, "padding": 1}),
            (24, 36, {"stride": 2, "padding": 1}),
            (36, 48, {"stride": 2, "padding": 1}),
            (48, 64, {"padding": 2, "dilation": 2}),
        )
        layers = []
        for c_in, c_out, conv_kwargs in stage_specs:
            layers += [
                nn.Conv2d(c_in, c_out, 3, bias=False, **conv_kwargs),
                DualAttention(c_out),
                nn.BatchNorm2d(c_out),
                nn.Hardswish(inplace=True),
            ]
        self.features = nn.Sequential(*layers)

        hidden = 192
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, hidden),
            nn.Hardswish(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Map a batch of 3-channel inputs to (N, num_classes) logits."""
        feats = self.features(x)
        return self.classifier(feats)
class SimpleCNN(nn.Module):
    """Plain four-stage convolutional classifier without attention.

    Each stage is Conv2d (3x3, no bias) -> BatchNorm2d -> Hardswish, with
    widths 24, 36, 48 and 64.

    The first three convolutions downsample with stride 2. The last one keeps
    resolution and uses dilation 2 (padding 2) to widen the receptive field.

    The head global-average-pools the 64 channels and applies
    Linear(64, 192) -> Hardswish -> Dropout(0.3) -> Linear(192, num_classes).

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        widths = [3, 24, 36, 48, 64]
        downsample = {"stride": 2, "padding": 1}
        dilated = {"padding": 2, "dilation": 2}
        last_stage = len(widths) - 2
        layers = []
        # Build the stages in the original order so the Sequential indices
        # (features.0 ... features.11) and parameter init order are unchanged.
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            conv_kwargs = dilated if stage == last_stage else downsample
            layers.extend((
                nn.Conv2d(c_in, c_out, 3, bias=False, **conv_kwargs),
                nn.BatchNorm2d(c_out),
                nn.Hardswish(inplace=True),
            ))
        self.features = nn.Sequential(*layers)

        hidden = 192
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(widths[-1], hidden),
            nn.Hardswish(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Map a batch of 3-channel inputs to (N, num_classes) logits."""
        feats = self.features(x)
        return self.classifier(feats)