diff --git a/gui/model.py b/gui/model.py index bf6c4d6..475025a 100644 --- a/gui/model.py +++ b/gui/model.py @@ -2,20 +2,20 @@ import torch import torch.nn as nn -# 轻量双通道注意力 -class LiteDualAttention(nn.Module): +class DualAttention(nn.Module): def __init__(self, in_channels): super().__init__() + middle_channels = max(in_channels // 32, 1) self.channel_attn = nn.Sequential( nn.AdaptiveAvgPool2d(1), - nn.Conv2d(in_channels, in_channels // 16, 1, bias=False), + nn.Conv2d(in_channels, middle_channels, 1, bias=False), nn.SiLU(), - nn.Conv2d(in_channels // 16, in_channels, 1, bias=False), + nn.Conv2d(middle_channels, in_channels, 1, bias=False), nn.Sigmoid() ) self.spatial_attn = nn.Sequential( - nn.Conv2d(2, 1, 5, padding=2, bias=False), + nn.Conv2d(2, 1, 3, padding=1, bias=False), nn.Sigmoid() ) @@ -29,43 +29,77 @@ class LiteDualAttention(nn.Module): x.std(dim=1, keepdim=True) ], dim=1) - # 空间加权(5x5卷积生成单通道注意力) + # 空间加权 sa = self.spatial_attn(stats) * ca return sa -class CNNClassifier(nn.Module): +class CustomCNN(nn.Module): def __init__(self, num_classes): super().__init__() self.features = nn.Sequential( - nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False), - LiteDualAttention(32), - nn.BatchNorm2d(32), + nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False), + DualAttention(24), + nn.BatchNorm2d(24), nn.Hardswish(inplace=True), - nn.Conv2d(32, 48, 3, stride=2, padding=1, bias=False), - LiteDualAttention(48), + nn.Conv2d(24, 36, 3, stride=2, padding=1, bias=False), + DualAttention(36), + nn.BatchNorm2d(36), + nn.Hardswish(inplace=True), + + nn.Conv2d(36, 48, 3, stride=2, padding=1, bias=False), + DualAttention(48), nn.BatchNorm2d(48), nn.Hardswish(inplace=True), - nn.Conv2d(48, 64, 3, stride=2, padding=1, bias=False), - LiteDualAttention(64), + nn.Conv2d(48, 64, 3, padding=2, dilation=2, bias=False), + DualAttention(64), nn.BatchNorm2d(64), - nn.Hardswish(inplace=True), - - nn.Conv2d(64, 96, 3, padding=2, dilation=2, bias=False), - LiteDualAttention(96), - nn.BatchNorm2d(96), nn.Hardswish(inplace=True) ) self.classifier = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Flatten(), - nn.Linear(96, 256), + nn.Linear(64, 192), nn.Hardswish(inplace=True), nn.Dropout(0.3), - nn.Linear(256, num_classes) + nn.Linear(192, num_classes) + ) + + def forward(self, x): + return self.classifier(self.features(x)) + + +class SimpleCNN(nn.Module): + def __init__(self, num_classes): + super().__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(24), + nn.Hardswish(inplace=True), + + nn.Conv2d(24, 36, 3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(36), + nn.Hardswish(inplace=True), + + nn.Conv2d(36, 48, 3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(48), + nn.Hardswish(inplace=True), + + nn.Conv2d(48, 64, 3, padding=2, dilation=2, bias=False), + nn.BatchNorm2d(64), + nn.Hardswish(inplace=True) + ) + + self.classifier = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Flatten(), + nn.Linear(64, 192), + nn.Hardswish(inplace=True), + nn.Dropout(0.3), + nn.Linear(192, num_classes) ) def forward(self, x):