Files
ShotRDP/gui/model.py

73 lines
2.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
# Lightweight dual (channel + spatial) attention.
class LiteDualAttention(nn.Module):
    """Lightweight dual attention: channel gating followed by spatial gating.

    Channel branch: a squeeze-and-excitation style gate (global average pool
    -> 1x1 bottleneck conv -> SiLU -> 1x1 conv -> sigmoid) that rescales each
    channel of the input.

    Spatial branch: the per-pixel mean and standard deviation across channels
    of the *input* (not of the channel-gated tensor) are stacked into a
    2-channel map. A 5x5 conv + sigmoid turns that map into a single-channel
    gate, which rescales the channel-gated features.

    Args:
        in_channels: number of input (and output) channels.
        reduction: bottleneck ratio for the channel branch. The hidden width
            is ``max(1, in_channels // reduction)``, so inputs narrower than
            ``reduction`` still get a non-empty bottleneck instead of a
            0-channel conv.

    Raises:
        ValueError: if ``reduction`` is less than 1.

    Shape:
        The input is assumed to be ``(N, C, H, W)`` with ``C == in_channels``
        (dim 1 is reduced as the channel axis). The output has the same shape
        as the input.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        # Clamp to 1: plain `in_channels // reduction` is 0 for narrow inputs,
        # which silently collapses the channel gate to a constant.
        hidden = max(1, in_channels // reduction)
        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden, 1, bias=False),
            nn.SiLU(),
            nn.Conv2d(hidden, in_channels, 1, bias=False),
            nn.Sigmoid(),
        )
        # 2 input channels = [mean, std]; padding=2 keeps H, W unchanged for
        # the 5x5 kernel.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(2, 1, 5, padding=2, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Apply channel gating, then spatial gating; returns a tensor shaped like ``x``."""
        # Channel weighting: the (N, C, 1, 1) gate broadcasts over H and W.
        ca = self.channel_attn(x) * x
        # Spatial statistics across channels (mean + standard deviation).
        mean = x.mean(dim=1, keepdim=True)
        if x.size(1) > 1:
            std = x.std(dim=1, keepdim=True)
        else:
            # The unbiased std of a single value divides by N-1 = 0 and gives
            # NaN. One value has zero spread, so use zeros instead.
            std = torch.zeros_like(mean)
        stats = torch.cat([mean, std], dim=1)
        # A 5x5 conv produces a single-channel spatial gate, which is applied
        # to the channel-weighted features.
        return self.spatial_attn(stats) * ca
class CNNClassifier(nn.Module):
    """Compact CNN classifier built from attention-augmented conv stages.

    Feature extractor: four stages, each
    ``Conv2d(3x3, no bias) -> LiteDualAttention -> BatchNorm2d -> Hardswish``.
    The first three convs use stride 2 (spatial downsampling). The last is a
    stride-1, dilation-2 conv whose padding of 2 keeps the spatial size
    unchanged.

    Head: global average pool -> flatten -> Linear(96, 256) -> Hardswish ->
    Dropout -> Linear(256, num_classes).

    Args:
        num_classes: number of output scores.
        in_channels: channels of the input image (default 3, the previously
            hard-coded value).
        dropout: dropout probability in the classifier head (default 0.3, the
            previously hard-coded value).
    """

    # Per-stage (out_channels, stride, padding, dilation) for the 3x3 convs.
    _STAGES = (
        (32, 2, 1, 1),
        (48, 2, 1, 1),
        (64, 2, 1, 1),
        # Dilated stride-1 conv: effective 5x5 receptive field, with padding=2
        # preserving H and W.
        (96, 1, 2, 2),
    )

    def __init__(self, num_classes, in_channels=3, dropout=0.3):
        super().__init__()
        layers = []
        prev_channels = in_channels
        for out_channels, stride, padding, dilation in self._STAGES:
            layers += [
                nn.Conv2d(
                    prev_channels,
                    out_channels,
                    3,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    bias=False,  # BatchNorm2d follows, so a conv bias is redundant
                ),
                LiteDualAttention(out_channels),
                nn.BatchNorm2d(out_channels),
                nn.Hardswish(inplace=True),
            ]
            prev_channels = out_channels
        # One flat Sequential keeps submodule names as features.0 ... features.15.
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(prev_channels, 256),
            nn.Hardswish(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return raw class scores (no softmax) of shape ``(N, num_classes)``.

        ``x`` is assumed to be a batch of shape ``(N, in_channels, H, W)``.
        """
        return self.classifier(self.features(x))