mirror of
https://github.com/yv1ing/ShotRDP.git
synced 2025-09-16 15:10:57 +08:00
调整模型参数
This commit is contained in:
76
gui/model.py
76
gui/model.py
@@ -2,20 +2,20 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
# 轻量双通道注意力
|
class DualAttention(nn.Module):
|
||||||
class LiteDualAttention(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
def __init__(self, in_channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
middle_channels = max(in_channels // 32, 1)
|
||||||
self.channel_attn = nn.Sequential(
|
self.channel_attn = nn.Sequential(
|
||||||
nn.AdaptiveAvgPool2d(1),
|
nn.AdaptiveAvgPool2d(1),
|
||||||
nn.Conv2d(in_channels, in_channels // 16, 1, bias=False),
|
nn.Conv2d(in_channels, middle_channels, 1, bias=False),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Conv2d(in_channels // 16, in_channels, 1, bias=False),
|
nn.Conv2d(middle_channels, in_channels, 1, bias=False),
|
||||||
nn.Sigmoid()
|
nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.spatial_attn = nn.Sequential(
|
self.spatial_attn = nn.Sequential(
|
||||||
nn.Conv2d(2, 1, 5, padding=2, bias=False),
|
nn.Conv2d(2, 1, 3, padding=1, bias=False),
|
||||||
nn.Sigmoid()
|
nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,43 +29,77 @@ class LiteDualAttention(nn.Module):
|
|||||||
x.std(dim=1, keepdim=True)
|
x.std(dim=1, keepdim=True)
|
||||||
], dim=1)
|
], dim=1)
|
||||||
|
|
||||||
# 空间加权(5x5卷积生成单通道注意力)
|
# 空间加权
|
||||||
sa = self.spatial_attn(stats) * ca
|
sa = self.spatial_attn(stats) * ca
|
||||||
return sa
|
return sa
|
||||||
|
|
||||||
|
|
||||||
class CNNClassifier(nn.Module):
|
class CustomCNN(nn.Module):
|
||||||
def __init__(self, num_classes):
|
def __init__(self, num_classes):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.features = nn.Sequential(
|
self.features = nn.Sequential(
|
||||||
nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
|
nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False),
|
||||||
LiteDualAttention(32),
|
DualAttention(24),
|
||||||
nn.BatchNorm2d(32),
|
nn.BatchNorm2d(24),
|
||||||
nn.Hardswish(inplace=True),
|
nn.Hardswish(inplace=True),
|
||||||
|
|
||||||
nn.Conv2d(32, 48, 3, stride=2, padding=1, bias=False),
|
nn.Conv2d(24, 36, 3, stride=2, padding=1, bias=False),
|
||||||
LiteDualAttention(48),
|
DualAttention(36),
|
||||||
|
nn.BatchNorm2d(36),
|
||||||
|
nn.Hardswish(inplace=True),
|
||||||
|
|
||||||
|
nn.Conv2d(36, 48, 3, stride=2, padding=1, bias=False),
|
||||||
|
DualAttention(48),
|
||||||
nn.BatchNorm2d(48),
|
nn.BatchNorm2d(48),
|
||||||
nn.Hardswish(inplace=True),
|
nn.Hardswish(inplace=True),
|
||||||
|
|
||||||
nn.Conv2d(48, 64, 3, stride=2, padding=1, bias=False),
|
nn.Conv2d(48, 64, 3, padding=2, dilation=2, bias=False),
|
||||||
LiteDualAttention(64),
|
DualAttention(64),
|
||||||
nn.BatchNorm2d(64),
|
nn.BatchNorm2d(64),
|
||||||
nn.Hardswish(inplace=True),
|
|
||||||
|
|
||||||
nn.Conv2d(64, 96, 3, padding=2, dilation=2, bias=False),
|
|
||||||
LiteDualAttention(96),
|
|
||||||
nn.BatchNorm2d(96),
|
|
||||||
nn.Hardswish(inplace=True)
|
nn.Hardswish(inplace=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.classifier = nn.Sequential(
|
self.classifier = nn.Sequential(
|
||||||
nn.AdaptiveAvgPool2d(1),
|
nn.AdaptiveAvgPool2d(1),
|
||||||
nn.Flatten(),
|
nn.Flatten(),
|
||||||
nn.Linear(96, 256),
|
nn.Linear(64, 192),
|
||||||
nn.Hardswish(inplace=True),
|
nn.Hardswish(inplace=True),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.Linear(256, num_classes)
|
nn.Linear(192, num_classes)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.classifier(self.features(x))
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleCNN(nn.Module):
|
||||||
|
def __init__(self, num_classes):
|
||||||
|
super().__init__()
|
||||||
|
self.features = nn.Sequential(
|
||||||
|
nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(24),
|
||||||
|
nn.Hardswish(inplace=True),
|
||||||
|
|
||||||
|
nn.Conv2d(24, 36, 3, stride=2, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(36),
|
||||||
|
nn.Hardswish(inplace=True),
|
||||||
|
|
||||||
|
nn.Conv2d(36, 48, 3, stride=2, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(48),
|
||||||
|
nn.Hardswish(inplace=True),
|
||||||
|
|
||||||
|
nn.Conv2d(48, 64, 3, padding=2, dilation=2, bias=False),
|
||||||
|
nn.BatchNorm2d(64),
|
||||||
|
nn.Hardswish(inplace=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.classifier = nn.Sequential(
|
||||||
|
nn.AdaptiveAvgPool2d(1),
|
||||||
|
nn.Flatten(),
|
||||||
|
nn.Linear(64, 192),
|
||||||
|
nn.Hardswish(inplace=True),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(192, num_classes)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|||||||
Reference in New Issue
Block a user