优化神经网络,缩减模型大小

This commit is contained in:
2025-03-28 18:07:58 +08:00
parent 2ad08423ce
commit a587bd804e
13 changed files with 91 additions and 69 deletions

3
.gitignore vendored
View File

@@ -3,11 +3,12 @@
# Work dir and files
.idea
.venv
__pycache__
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib

View File

@@ -2,7 +2,7 @@ import ctypes
def shot(target, width, height):
lib = ctypes.CDLL('../shotrdp.dll')
lib = ctypes.CDLL('./shotrdp.dll')
lib.GetScreen.argtypes = [
ctypes.c_char_p,
@@ -21,7 +21,7 @@ def shot(target, width, height):
print(ctypes.string_at(error_ptr).decode())
else:
result = ctypes.string_at(data, length.value)
with open('test.png', 'wb') as f:
with open('./screen/0.png', 'wb') as f:
f.write(result)
lib.Free(data)

72
gui/model.py Normal file
View File

@@ -0,0 +1,72 @@
import torch
import torch.nn as nn
# 轻量双通道注意力
class LiteDualAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.channel_attn = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, in_channels // 16, 1, bias=False),
nn.SiLU(),
nn.Conv2d(in_channels // 16, in_channels, 1, bias=False),
nn.Sigmoid()
)
self.spatial_attn = nn.Sequential(
nn.Conv2d(2, 1, 5, padding=2, bias=False),
nn.Sigmoid()
)
def forward(self, x):
# 通道加权
ca = self.channel_attn(x) * x
# 空间特征统计(均值+标准差)
stats = torch.cat([
x.mean(dim=1, keepdim=True),
x.std(dim=1, keepdim=True)
], dim=1)
# 空间加权5x5卷积生成单通道注意力
sa = self.spatial_attn(stats) * ca
return sa
class CNNClassifier(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
LiteDualAttention(32),
nn.BatchNorm2d(32),
nn.Hardswish(inplace=True),
nn.Conv2d(32, 48, 3, stride=2, padding=1, bias=False),
LiteDualAttention(48),
nn.BatchNorm2d(48),
nn.Hardswish(inplace=True),
nn.Conv2d(48, 64, 3, stride=2, padding=1, bias=False),
LiteDualAttention(64),
nn.BatchNorm2d(64),
nn.Hardswish(inplace=True),
nn.Conv2d(64, 96, 3, padding=2, dilation=2, bias=False),
LiteDualAttention(96),
nn.BatchNorm2d(96),
nn.Hardswish(inplace=True)
)
self.classifier = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(96, 256),
nn.Hardswish(inplace=True),
nn.Dropout(0.3),
nn.Linear(256, num_classes)
)
def forward(self, x):
return self.classifier(self.features(x))

View File

@@ -1,20 +1,18 @@
import torch
from torchvision import transforms
from model import CNNWithUltraSimplifiedAttention
from model import CNNClassifier
from PIL import Image
class_labels = ['Windows 7', 'Windows 10', 'Windows Server 2008', 'Windows Server 2012']
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载预训练模型
model_path = './rdp_model.pth'
model = CNNWithUltraSimplifiedAttention(4)
model_path = './shotrdp.pth'
model = CNNClassifier(len(class_labels))
model.load_state_dict(torch.load(model_path, weights_only=True))
model.to(device)
print(f"Model loaded from {model_path}")
# 数据预处理
transform = transforms.Compose([
transforms.Resize((1024, 800)),
@@ -22,14 +20,13 @@ transform = transforms.Compose([
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
if __name__ == '__main__':
image_path = './screen/0.png'
# 预测截图类别
image_path = './test.png'
class_labels = ['Windows 7', 'Windows 10', 'Windows Server 2008', 'Windows Server 2012']
image = Image.open(image_path).convert('RGB')
image = transform(image).unsqueeze(0).to(device)
image = Image.open(image_path).convert('RGB')
image = transform(image).unsqueeze(0).to(device)
output = model(image)
_, predicted = torch.max(output.data, 1)
output = model(image)
_, predicted = torch.max(output.data, 1)
print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")

BIN
gui/screen/0.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 543 KiB

BIN
gui/screen/1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 499 KiB

BIN
gui/screen/2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 610 KiB

BIN
gui/screen/3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 559 KiB

BIN
gui/screen/4.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 612 KiB

BIN
gui/shotrdp.dll Normal file

Binary file not shown.

BIN
gui/shotrdp.pth Normal file

Binary file not shown.

View File

@@ -6,7 +6,7 @@ import warnings
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from model import CNNWithUltraSimplifiedAttention
from model import CNNClassifier
warnings.filterwarnings("ignore", category=UserWarning)
@@ -34,7 +34,7 @@ test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
# 初始化模型
num_classes = len(full_dataset.classes)
model = CNNWithUltraSimplifiedAttention(num_classes)
model = CNNClassifier(num_classes)
# 初始化损失函数和优化器
criterion = nn.CrossEntropyLoss()
@@ -44,7 +44,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if __name__ == '__main__':
# 训练模型
num_epochs = 4
num_epochs = 16
model.to(device)
for epoch in range(num_epochs):
model.train()
@@ -71,7 +71,7 @@ if __name__ == '__main__':
# 测试模型
load_saved_model = True
if load_saved_model and os.path.exists(model_path):
model = CNNWithUltraSimplifiedAttention(num_classes)
model = CNNClassifier(num_classes)
model.load_state_dict(torch.load(model_path, weights_only=True))
model.to(device)

View File

@@ -1,48 +0,0 @@
import torch.nn as nn
# 更简化的注意力模块
class UltraSimplifiedAttention(nn.Module):
def __init__(self, in_channels):
super(UltraSimplifiedAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels // 64, bias=False),
nn.ReLU(inplace=True),
nn.Linear(in_channels // 64, in_channels, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
# CNN结合更简化的注意力机制模型
class CNNWithUltraSimplifiedAttention(nn.Module):
def __init__(self, num_classes):
super(CNNWithUltraSimplifiedAttention, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.attention1 = UltraSimplifiedAttention(16)
self.relu1 = nn.ReLU(inplace=True)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.attention2 = UltraSimplifiedAttention(32)
self.relu2 = nn.ReLU(inplace=True)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc_input_dim = 32 * 256 * 200
self.fc1 = nn.Linear(self.fc_input_dim, 128)
self.relu3 = nn.ReLU(inplace=True)
self.fc2 = nn.Linear(128, num_classes)
def forward(self, x):
x = self.pool1(self.relu1(self.attention1(self.conv1(x))))
x = self.pool2(self.relu2(self.attention2(self.conv2(x))))
x = x.view(-1, self.fc_input_dim)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return x