import torch
import torch.nn as nn


class DualAttention(nn.Module):
    """Channel attention followed by spatial attention.

    The channel branch squeezes each channel to a scalar with global average
    pooling. It passes that through a two-layer 1x1-conv bottleneck (SiLU in
    between) and a sigmoid, then rescales the input channels by the result.

    The spatial branch builds a 2-channel map from the per-position channel
    mean and standard deviation of the *input* ``x``. A 3x3 conv plus sigmoid
    turns that map into a single-channel gate, which is multiplied into the
    channel-reweighted features.

    Args:
        in_channels: Number of channels ``C`` of the input feature map.
        reduction: Reduction factor of the channel bottleneck. The hidden
            width is ``max(in_channels // reduction, 1)``. Defaults to 32,
            the value that was previously hard-coded.

    Raises:
        ValueError: If ``reduction`` is smaller than 1.

    Assumes a 4-D input (batch, C, H, W), as required by the
    ``AdaptiveAvgPool2d`` / ``Conv2d`` layers. The output has the same shape
    as the input.
    """

    def __init__(self, in_channels, reduction=32):
        super().__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        middle_channels = max(in_channels // reduction, 1)
        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, middle_channels, 1, bias=False),
            nn.SiLU(),
            nn.Conv2d(middle_channels, in_channels, 1, bias=False),
            nn.Sigmoid(),
        )
        # Two input maps (mean, std) -> one gate; padding=1 keeps H and W.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(2, 1, 3, padding=1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Channel weighting: a (B, C, 1, 1) gate broadcast over H and W.
        ca = self.channel_attn(x) * x

        # Spatial statistics of the input: channel mean and std at each position.
        mean = x.mean(dim=1, keepdim=True)
        if x.size(1) > 1:
            std = x.std(dim=1, keepdim=True)
        else:
            # The unbiased std of a single value is 0/0 = NaN. A lone channel
            # has no spread, so feed zeros instead of poisoning the output.
            std = torch.zeros_like(mean)
        stats = torch.cat([mean, std], dim=1)

        # Spatial weighting, applied to the channel-weighted features.
        return self.spatial_attn(stats) * ca


# (in_channels, out_channels, stride, padding, dilation) of each 3x3 conv
# stage in the trunk shared by CustomCNN and SimpleCNN.
_STAGES = (
    (3, 24, 2, 1, 1),
    (24, 36, 2, 1, 1),
    (36, 48, 2, 1, 1),
    (48, 64, 1, 2, 2),
)


def _make_features(use_attention):
    """Build the conv trunk; insert a DualAttention after each conv if asked.

    Each stage is Conv2d (no bias) -> [DualAttention] -> BatchNorm2d ->
    Hardswish, in exactly the order the models originally spelled out.
    Sequential indices (and therefore state_dict keys) are thus unchanged.
    """
    layers = []
    for cin, cout, stride, padding, dilation in _STAGES:
        layers.append(
            nn.Conv2d(
                cin,
                cout,
                3,
                stride=stride,
                padding=padding,
                dilation=dilation,
                bias=False,
            )
        )
        if use_attention:
            layers.append(DualAttention(cout))
        layers.append(nn.BatchNorm2d(cout))
        layers.append(nn.Hardswish(inplace=True))
    return nn.Sequential(*layers)


def _make_classifier(num_classes):
    """Build the head: global average pool -> 64 -> 192 -> num_classes MLP."""
    return nn.Sequential(
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(64, 192),
        nn.Hardswish(inplace=True),
        nn.Dropout(0.3),
        nn.Linear(192, num_classes),
    )


class CustomCNN(nn.Module):
    """Four-stage 3-channel CNN with a DualAttention block after each conv.

    The first three convs have stride 2; the last is a stride-1 dilated conv
    (dilation 2, padding 2) that keeps the spatial size. A 64-channel feature
    map is then pooled and classified.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.features = _make_features(use_attention=True)
        self.classifier = _make_classifier(num_classes)

    def forward(self, x):
        return self.classifier(self.features(x))


class SimpleCNN(nn.Module):
    """Same architecture as CustomCNN, but without the DualAttention blocks.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.features = _make_features(use_attention=False)
        self.classifier = _make_classifier(num_classes)

    def forward(self, x):
        return self.classifier(self.features(x))