Adjust model parameters (调整模型参数)

commit 5d7b573a79 (parent f0701ef524)
committed by GitHub, 2025-04-29 22:58:58 +08:00


@@ -2,20 +2,20 @@ import torch
 import torch.nn as nn
 
-# Lightweight dual-channel attention
-class LiteDualAttention(nn.Module):
+class DualAttention(nn.Module):
     def __init__(self, in_channels):
         super().__init__()
+        middle_channels = max(in_channels // 32, 1)
         self.channel_attn = nn.Sequential(
             nn.AdaptiveAvgPool2d(1),
-            nn.Conv2d(in_channels, in_channels // 16, 1, bias=False),
+            nn.Conv2d(in_channels, middle_channels, 1, bias=False),
             nn.SiLU(),
-            nn.Conv2d(in_channels // 16, in_channels, 1, bias=False),
+            nn.Conv2d(middle_channels, in_channels, 1, bias=False),
             nn.Sigmoid()
         )
         self.spatial_attn = nn.Sequential(
-            nn.Conv2d(2, 1, 5, padding=2, bias=False),
+            nn.Conv2d(2, 1, 3, padding=1, bias=False),
             nn.Sigmoid()
         )
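The first hunk swaps the fixed 1/16 channel reduction for max(in_channels // 32, 1) and shrinks the spatial-attention kernel from 5x5 to 3x3. A small sketch of what the clamp does at the channel widths the new model actually uses; this helper is illustrative only and is not part of the commit:

# Hypothetical helper reproducing the bottleneck-width arithmetic from the diff.
def bottleneck_width(in_channels: int) -> int:
    # New scheme: 1/32 reduction, clamped so narrow layers never collapse to 0 channels.
    return max(in_channels // 32, 1)

for c in (24, 36, 48, 64):
    old = c // 16               # previous scheme from the removed lines
    new = bottleneck_width(c)
    print(f"in_channels={c}: old 1/16 -> {old}, new clamped 1/32 -> {new}")
# in_channels=24: 24 // 32 == 0, so the max(..., 1) clamp is what keeps the bottleneck non-empty
# in_channels=64: old 1/16 -> 4, new clamped 1/32 -> 2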
@@ -29,43 +29,77 @@ class LiteDualAttention(nn.Module):
             x.std(dim=1, keepdim=True)
         ], dim=1)
-        # Spatial weighting: a 5x5 conv produces the single-channel attention map
+        # Spatial weighting
         sa = self.spatial_attn(stats) * ca
         return sa
 
-class CNNClassifier(nn.Module):
+class CustomCNN(nn.Module):
     def __init__(self, num_classes):
         super().__init__()
         self.features = nn.Sequential(
-            nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
-            LiteDualAttention(32),
-            nn.BatchNorm2d(32),
+            nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False),
+            DualAttention(24),
+            nn.BatchNorm2d(24),
             nn.Hardswish(inplace=True),
-            nn.Conv2d(32, 48, 3, stride=2, padding=1, bias=False),
-            LiteDualAttention(48),
+            nn.Conv2d(24, 36, 3, stride=2, padding=1, bias=False),
+            DualAttention(36),
+            nn.BatchNorm2d(36),
+            nn.Hardswish(inplace=True),
+            nn.Conv2d(36, 48, 3, stride=2, padding=1, bias=False),
+            DualAttention(48),
             nn.BatchNorm2d(48),
             nn.Hardswish(inplace=True),
-            nn.Conv2d(48, 64, 3, stride=2, padding=1, bias=False),
-            LiteDualAttention(64),
+            nn.Conv2d(48, 64, 3, padding=2, dilation=2, bias=False),
+            DualAttention(64),
             nn.BatchNorm2d(64),
-            nn.Hardswish(inplace=True),
-            nn.Conv2d(64, 96, 3, padding=2, dilation=2, bias=False),
-            LiteDualAttention(96),
-            nn.BatchNorm2d(96),
             nn.Hardswish(inplace=True)
         )
         self.classifier = nn.Sequential(
             nn.AdaptiveAvgPool2d(1),
             nn.Flatten(),
-            nn.Linear(96, 256),
+            nn.Linear(64, 192),
             nn.Hardswish(inplace=True),
             nn.Dropout(0.3),
-            nn.Linear(256, num_classes)
+            nn.Linear(192, num_classes)
+        )
+
+    def forward(self, x):
+        return self.classifier(self.features(x))
+
+
+class SimpleCNN(nn.Module):
+    def __init__(self, num_classes):
+        super().__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(24),
+            nn.Hardswish(inplace=True),
+            nn.Conv2d(24, 36, 3, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(36),
+            nn.Hardswish(inplace=True),
+            nn.Conv2d(36, 48, 3, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(48),
+            nn.Hardswish(inplace=True),
+            nn.Conv2d(48, 64, 3, padding=2, dilation=2, bias=False),
+            nn.BatchNorm2d(64),
+            nn.Hardswish(inplace=True)
+        )
+        self.classifier = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Flatten(),
+            nn.Linear(64, 192),
+            nn.Hardswish(inplace=True),
+            nn.Dropout(0.3),
+            nn.Linear(192, num_classes)
         )
     def forward(self, x):
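After the second hunk the file defines two models with the same channel progression (3 -> 24 -> 36 -> 48 -> 64) and the same 64 -> 192 -> num_classes head: CustomCNN with DualAttention after each conv, and SimpleCNN as a plain baseline. The three stride-2 convolutions downsample a 224x224 input to 28x28 before the dilated 3x3 stage, and AdaptiveAvgPool2d(1) keeps the classifier input 64-dimensional regardless of resolution. A minimal smoke test, assuming the file is importable as `model` (the diff does not show the file name) and using a hypothetical num_classes=10:

# Hedged sketch: verifies only that both networks build and produce logits of the
# expected shape; `model` and num_classes=10 are assumptions, not part of the diff.
import torch
from model import CustomCNN, SimpleCNN  # assumed import path

x = torch.randn(2, 3, 224, 224)  # dummy batch of two RGB images
for net in (CustomCNN(num_classes=10), SimpleCNN(num_classes=10)):
    net.eval()
    with torch.no_grad():
        logits = net(x)
    print(type(net).__name__, tuple(logits.shape))  # expected: (2, 10)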