mirror of
https://github.com/yv1ing/ShotRDP.git
synced 2025-09-16 15:10:57 +08:00
优化神经网络,缩减模型大小
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -3,11 +3,12 @@
|
||||
|
||||
# Work dir and files
|
||||
.idea
|
||||
.venv
|
||||
__pycache__
|
||||
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import ctypes
|
||||
|
||||
|
||||
def shot(target, width, height):
|
||||
lib = ctypes.CDLL('../shotrdp.dll')
|
||||
lib = ctypes.CDLL('./shotrdp.dll')
|
||||
|
||||
lib.GetScreen.argtypes = [
|
||||
ctypes.c_char_p,
|
||||
@@ -21,7 +21,7 @@ def shot(target, width, height):
|
||||
print(ctypes.string_at(error_ptr).decode())
|
||||
else:
|
||||
result = ctypes.string_at(data, length.value)
|
||||
with open('test.png', 'wb') as f:
|
||||
with open('./screen/0.png', 'wb') as f:
|
||||
f.write(result)
|
||||
|
||||
lib.Free(data)
|
||||
|
||||
72
gui/model.py
Normal file
72
gui/model.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# Lightweight dual (channel + spatial) attention
class LiteDualAttention(nn.Module):
    """Gate a feature map with channel attention, then spatial attention.

    The channel branch pools every channel to a single value and runs it
    through a two-layer 1x1-convolution bottleneck ending in a sigmoid.
    The spatial branch runs a 5x5 convolution over two per-pixel statistics
    of the input (mean and standard deviation across channels) and also ends
    in a sigmoid. The output has the same shape as the input.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction: Divisor used to size the channel bottleneck. The default
            of 16 reproduces the original layer shapes.

    Raises:
        ValueError: If ``reduction`` is smaller than 1.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        if reduction < 1:
            raise ValueError("reduction must be a positive integer")

        # Clamp to one hidden channel: a plain floor division yields a
        # zero-width convolution whenever in_channels < reduction.
        hidden_channels = max(1, in_channels // reduction)

        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.SiLU(),
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False),
            nn.Sigmoid(),
        )

        self.spatial_attn = nn.Sequential(
            nn.Conv2d(2, 1, 5, padding=2, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by the channel gate and then the spatial gate."""
        # Channel weighting
        ca = self.channel_attn(x) * x

        # Spatial statistics (mean + standard deviation across channels).
        # They are taken from the raw input, not from the channel-weighted map.
        # NOTE(review): std uses torch's default correction, so a
        # single-channel input is expected to give NaN here - confirm before
        # using this module with in_channels == 1.
        stats = torch.cat(
            [
                x.mean(dim=1, keepdim=True),
                x.std(dim=1, keepdim=True),
            ],
            dim=1,
        )

        # Spatial weighting (the 5x5 convolution yields a one-channel gate)
        return self.spatial_attn(stats) * ca
|
||||
|
||||
|
||||
class CNNClassifier(nn.Module):
    """Four-stage convolutional classifier with attention after each convolution.

    Every stage is Conv2d -> LiteDualAttention -> BatchNorm2d -> Hardswish.
    The first three stages use a stride-2 3x3 convolution; the fourth uses a
    dilated (dilation 2) 3x3 convolution with no stride. The head applies
    global average pooling, flattens, and maps 96 -> 256 -> ``num_classes``
    with Hardswish and dropout (p=0.3) in between.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()

        # One row per stage: (input channels, output channels, Conv2d kwargs).
        stage_specs = (
            (3, 32, {"stride": 2, "padding": 1}),
            (32, 48, {"stride": 2, "padding": 1}),
            (48, 64, {"stride": 2, "padding": 1}),
            (64, 96, {"padding": 2, "dilation": 2}),
        )

        # Modules are appended in the same order as a flat listing, so the
        # Sequential indices (and therefore the state_dict keys) are unchanged.
        stage_layers = []
        for c_in, c_out, conv_kwargs in stage_specs:
            stage_layers.append(nn.Conv2d(c_in, c_out, 3, bias=False, **conv_kwargs))
            stage_layers.append(LiteDualAttention(c_out))
            stage_layers.append(nn.BatchNorm2d(c_out))
            stage_layers.append(nn.Hardswish(inplace=True))
        self.features = nn.Sequential(*stage_layers)

        head_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(96, 256),
            nn.Hardswish(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        feature_map = self.features(x)
        logits = self.classifier(feature_map)
        return logits
|
||||
@@ -1,20 +1,18 @@
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from model import CNNWithUltraSimplifiedAttention
|
||||
from model import CNNClassifier
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class_labels = ['Windows 7', 'Windows 10', 'Windows Server 2008', 'Windows Server 2012']
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# 加载预训练模型
|
||||
model_path = './rdp_model.pth'
|
||||
model = CNNWithUltraSimplifiedAttention(4)
|
||||
model_path = './shotrdp.pth'
|
||||
model = CNNClassifier(len(class_labels))
|
||||
model.load_state_dict(torch.load(model_path, weights_only=True))
|
||||
model.to(device)
|
||||
|
||||
print(f"Model loaded from {model_path}")
|
||||
|
||||
|
||||
# 数据预处理
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((1024, 800)),
|
||||
@@ -22,14 +20,13 @@ transform = transforms.Compose([
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
if __name__ == '__main__':
|
||||
image_path = './screen/0.png'
|
||||
|
||||
# 预测截图类别
|
||||
image_path = './test.png'
|
||||
class_labels = ['Windows 7', 'Windows 10', 'Windows Server 2008', 'Windows Server 2012']
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
image = transform(image).unsqueeze(0).to(device)
|
||||
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
image = transform(image).unsqueeze(0).to(device)
|
||||
output = model(image)
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
|
||||
output = model(image)
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
|
||||
print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
|
||||
BIN
gui/screen/0.png
Normal file
BIN
gui/screen/0.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 543 KiB |
BIN
gui/screen/1.png
Normal file
BIN
gui/screen/1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 499 KiB |
BIN
gui/screen/2.png
Normal file
BIN
gui/screen/2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 610 KiB |
BIN
gui/screen/3.png
Normal file
BIN
gui/screen/3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 559 KiB |
BIN
gui/screen/4.png
Normal file
BIN
gui/screen/4.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 612 KiB |
BIN
gui/shotrdp.dll
Normal file
BIN
gui/shotrdp.dll
Normal file
Binary file not shown.
BIN
gui/shotrdp.pth
Normal file
BIN
gui/shotrdp.pth
Normal file
Binary file not shown.
@@ -6,7 +6,7 @@ import warnings
|
||||
from tqdm import tqdm
|
||||
from torchvision import datasets, transforms
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from model import CNNWithUltraSimplifiedAttention
|
||||
from model import CNNClassifier
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
@@ -34,7 +34,7 @@ test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
|
||||
|
||||
# 初始化模型
|
||||
num_classes = len(full_dataset.classes)
|
||||
model = CNNWithUltraSimplifiedAttention(num_classes)
|
||||
model = CNNClassifier(num_classes)
|
||||
|
||||
# 初始化损失函数和优化器
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
@@ -44,7 +44,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 训练模型
|
||||
num_epochs = 4
|
||||
num_epochs = 16
|
||||
model.to(device)
|
||||
for epoch in range(num_epochs):
|
||||
model.train()
|
||||
@@ -71,7 +71,7 @@ if __name__ == '__main__':
|
||||
# 测试模型
|
||||
load_saved_model = True
|
||||
if load_saved_model and os.path.exists(model_path):
|
||||
model = CNNWithUltraSimplifiedAttention(num_classes)
|
||||
model = CNNClassifier(num_classes)
|
||||
model.load_state_dict(torch.load(model_path, weights_only=True))
|
||||
model.to(device)
|
||||
|
||||
48
nn/model.py
48
nn/model.py
@@ -1,48 +0,0 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# Simplified channel-attention module
class UltraSimplifiedAttention(nn.Module):
    """Scale each channel by a gate computed from its global average.

    The input is average-pooled to one value per channel, passed through a
    two-layer bias-free bottleneck (Linear -> ReLU -> Linear -> Sigmoid), and
    the resulting per-channel gate multiplies the input.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction: Divisor used to size the bottleneck (default 64).

    Raises:
        ValueError: If ``reduction`` is smaller than 1.

    Note:
        The bottleneck width is clamped to at least 1. With a plain
        ``in_channels // 64`` the width is 0 for fewer than 64 channels, which
        leaves the second Linear with no inputs and makes the gate a constant
        sigmoid(0) = 0.5. Checkpoints saved with such a zero-width bottleneck
        have different parameter shapes and will not load into this version.
    """

    def __init__(self, in_channels, reduction=64):
        super().__init__()
        if reduction < 1:
            raise ValueError("reduction must be a positive integer")

        hidden_features = max(1, in_channels // reduction)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden_features, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_features, in_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate (same shape as ``x``)."""
        b, c, _, _ = x.size()
        # Pool to (b, c) for the Linear layers, then restore (b, c, 1, 1).
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
|
||||
|
||||
|
||||
# CNN combined with the simplified attention mechanism
class CNNWithUltraSimplifiedAttention(nn.Module):
    """Two-stage CNN classifier with channel attention after each convolution.

    Each stage is Conv2d(3x3, padding 1) -> UltraSimplifiedAttention -> ReLU
    -> MaxPool2d(2). The pooled features are flattened and classified by two
    fully connected layers (``fc_input_dim`` -> 128 -> ``num_classes``).

    The first fully connected layer is sized for 32 channels of 256 x 200
    features. The convolutions preserve spatial size and the two pools each
    halve it, so this corresponds to a 1024 x 800 input image.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super(CNNWithUltraSimplifiedAttention, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.attention1 = UltraSimplifiedAttention(16)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.attention2 = UltraSimplifiedAttention(32)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 32 channels x 256 x 200 spatial positions after the two 2x2 pools.
        self.fc_input_dim = 32 * 256 * 200
        self.fc1 = nn.Linear(self.fc_input_dim, 128)
        self.relu3 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        x = self.pool1(self.relu1(self.attention1(self.conv1(x))))
        x = self.pool2(self.relu2(self.attention2(self.conv2(x))))
        # Flatten everything except the batch dimension. A
        # view(-1, fc_input_dim) would let the batch dimension absorb a size
        # mismatch and silently change the batch size; with a fixed batch
        # dimension a wrongly sized input fails at fc1 with a shape error.
        x = x.flatten(1)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x
|
||||
Reference in New Issue
Block a user