diff --git a/.gitignore b/.gitignore
index 117752e..abe4ff6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,11 +3,12 @@
 
 # Work dir and files
 .idea
+.venv
+__pycache__
 
 # Binaries for programs and plugins
 *.exe
 *.exe~
-*.dll
 *.so
 *.dylib
 
diff --git a/gui/main.py b/gui/main.py
index a743d0a..67a038e 100644
--- a/gui/main.py
+++ b/gui/main.py
@@ -2,7 +2,7 @@ import ctypes
 
 
 def shot(target, width, height):
-    lib = ctypes.CDLL('../shotrdp.dll')
+    lib = ctypes.CDLL('./shotrdp.dll')
 
     lib.GetScreen.argtypes = [
         ctypes.c_char_p,
@@ -21,7 +21,7 @@ def shot(target, width, height):
         print(ctypes.string_at(error_ptr).decode())
     else:
         result = ctypes.string_at(data, length.value)
-        with open('test.png', 'wb') as f:
+        with open('./screen/0.png', 'wb') as f:
             f.write(result)
 
     lib.Free(data)
diff --git a/gui/model.py b/gui/model.py
new file mode 100644
index 0000000..bf6c4d6
--- /dev/null
+++ b/gui/model.py
@@ -0,0 +1,72 @@
+import torch
+import torch.nn as nn
+
+
+# Lightweight dual attention (channel + spatial)
+class LiteDualAttention(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.channel_attn = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_channels, in_channels // 16, 1, bias=False),
+            nn.SiLU(),
+            nn.Conv2d(in_channels // 16, in_channels, 1, bias=False),
+            nn.Sigmoid()
+        )
+
+        self.spatial_attn = nn.Sequential(
+            nn.Conv2d(2, 1, 5, padding=2, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        # Channel re-weighting
+        ca = self.channel_attn(x) * x
+
+        # Spatial statistics (mean + standard deviation)
+        stats = torch.cat([
+            x.mean(dim=1, keepdim=True),
+            x.std(dim=1, keepdim=True)
+        ], dim=1)
+
+        # Spatial re-weighting (a 5x5 conv produces a single-channel attention map)
+        sa = self.spatial_attn(stats) * ca
+        return sa
+
+
+class CNNClassifier(nn.Module):
+    def __init__(self, num_classes):
+        super().__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
+            LiteDualAttention(32),
+            nn.BatchNorm2d(32),
+            nn.Hardswish(inplace=True),
+
+            nn.Conv2d(32, 48, 3, stride=2, padding=1, bias=False),
+            LiteDualAttention(48),
+            nn.BatchNorm2d(48),
+            nn.Hardswish(inplace=True),
+
+            nn.Conv2d(48, 64, 3, stride=2, padding=1, bias=False),
+            LiteDualAttention(64),
+            nn.BatchNorm2d(64),
+            nn.Hardswish(inplace=True),
+
+            nn.Conv2d(64, 96, 3, padding=2, dilation=2, bias=False),
+            LiteDualAttention(96),
+            nn.BatchNorm2d(96),
+            nn.Hardswish(inplace=True)
+        )
+
+        self.classifier = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Flatten(),
+            nn.Linear(96, 256),
+            nn.Hardswish(inplace=True),
+            nn.Dropout(0.3),
+            nn.Linear(256, num_classes)
+        )
+
+    def forward(self, x):
+        return self.classifier(self.features(x))
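A quick shape sanity check for the new `CNNClassifier` (a minimal sketch, assuming the 1024x800 input that gui/predict.py resizes to): the three stride-2 convolutions downsample by 8x, the dilated 3x3 convolution preserves spatial size, and the global average pool makes the classifier head independent of input resolution.

```python
import torch
from model import CNNClassifier

model = CNNClassifier(num_classes=4).eval()
with torch.no_grad():
    x = torch.randn(1, 3, 1024, 800)  # one RGB screenshot, as in predict.py
    feats = model.features(x)         # 8x downsampling from the stride-2 convs
    logits = model(x)

print(feats.shape)   # torch.Size([1, 96, 128, 100])
print(logits.shape)  # torch.Size([1, 4])
```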
diff --git a/nn/predict.py b/gui/predict.py
similarity index 52%
rename from nn/predict.py
rename to gui/predict.py
index b5b3a23..e3865e1 100644
--- a/nn/predict.py
+++ b/gui/predict.py
@@ -1,20 +1,18 @@
 import torch
 from torchvision import transforms
-from model import CNNWithUltraSimplifiedAttention
+from model import CNNClassifier
 from PIL import Image
 
 
+class_labels = ['Windows 7', 'Windows 10', 'Windows Server 2008', 'Windows Server 2012']
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 # Load the pretrained model
-model_path = './rdp_model.pth'
-model = CNNWithUltraSimplifiedAttention(4)
+model_path = './shotrdp.pth'
+model = CNNClassifier(len(class_labels))
 model.load_state_dict(torch.load(model_path, weights_only=True))
 model.to(device)
-print(f"Model loaded from {model_path}")
-
-
 # Data preprocessing
 transform = transforms.Compose([
     transforms.Resize((1024, 800)),
@@ -22,14 +20,13 @@ transform = transforms.Compose([
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 ])
 
 
+if __name__ == '__main__':
+    image_path = './screen/0.png'
-# Predict the screenshot class
-image_path = './test.png'
-class_labels = ['Windows 7', 'Windows 10', 'Windows Server 2008', 'Windows Server 2012']
 
+    image = Image.open(image_path).convert('RGB')
+    image = transform(image).unsqueeze(0).to(device)
-image = Image.open(image_path).convert('RGB')
-image = transform(image).unsqueeze(0).to(device)
 
+    output = model(image)
+    _, predicted = torch.max(output.data, 1)
-output = model(image)
-_, predicted = torch.max(output.data, 1)
-print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
\ No newline at end of file
+    print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
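With the entry point now guarded by `if __name__ == '__main__':`, gui/predict.py can be imported, and its module-level `model`, `transform`, `class_labels`, and `device` reused, without triggering a prediction, so classification can be chained behind the `./screen/0.png` files that `shot()` writes. One gap the diff leaves open: inference never calls `model.eval()`, and `CNNClassifier` contains BatchNorm and Dropout, so predictions made in training mode are nondeterministic. A minimal glue sketch, hypothetical and not part of this diff:

```python
# Hypothetical helper -- reuses the module-level objects gui/predict.py defines.
import torch
from PIL import Image
from predict import model, transform, class_labels, device

def classify(image_path):
    model.eval()  # disable Dropout, use BatchNorm running statistics
    image = Image.open(image_path).convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(batch)
    return class_labels[output.argmax(dim=1).item()]

if __name__ == '__main__':
    print(classify('./screen/0.png'))  # e.g. 'Windows Server 2012'
```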
diff --git a/gui/screen/0.png b/gui/screen/0.png
new file mode 100644
index 0000000..d00d039
Binary files /dev/null and b/gui/screen/0.png differ
diff --git a/gui/screen/1.png b/gui/screen/1.png
new file mode 100644
index 0000000..cfe8606
Binary files /dev/null and b/gui/screen/1.png differ
diff --git a/gui/screen/2.png b/gui/screen/2.png
new file mode 100644
index 0000000..087e39a
Binary files /dev/null and b/gui/screen/2.png differ
diff --git a/gui/screen/3.png b/gui/screen/3.png
new file mode 100644
index 0000000..2440a05
Binary files /dev/null and b/gui/screen/3.png differ
diff --git a/gui/screen/4.png b/gui/screen/4.png
new file mode 100644
index 0000000..307bd58
Binary files /dev/null and b/gui/screen/4.png differ
diff --git a/gui/shotrdp.dll b/gui/shotrdp.dll
new file mode 100644
index 0000000..587d897
Binary files /dev/null and b/gui/shotrdp.dll differ
diff --git a/gui/shotrdp.pth b/gui/shotrdp.pth
new file mode 100644
index 0000000..5b5866a
Binary files /dev/null and b/gui/shotrdp.pth differ
diff --git a/nn/train.py b/gui/train.py
similarity index 93%
rename from nn/train.py
rename to gui/train.py
index 8473019..591329c 100644
--- a/nn/train.py
+++ b/gui/train.py
@@ -6,7 +6,7 @@ import warnings
 from tqdm import tqdm
 from torchvision import datasets, transforms
 from torch.utils.data import DataLoader, random_split
-from model import CNNWithUltraSimplifiedAttention
+from model import CNNClassifier
 
 warnings.filterwarnings("ignore", category=UserWarning)
 
@@ -34,7 +34,7 @@ test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
 
 # Initialize the model
 num_classes = len(full_dataset.classes)
-model = CNNWithUltraSimplifiedAttention(num_classes)
+model = CNNClassifier(num_classes)
 
 # Initialize the loss function and optimizer
 criterion = nn.CrossEntropyLoss()
@@ -44,7 +44,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 if __name__ == '__main__':
     # Train the model
-    num_epochs = 4
+    num_epochs = 16
     model.to(device)
     for epoch in range(num_epochs):
         model.train()
@@ -71,7 +71,7 @@ if __name__ == '__main__':
 
     # Test the model
     load_saved_model = True
     if load_saved_model and os.path.exists(model_path):
-        model = CNNWithUltraSimplifiedAttention(num_classes)
+        model = CNNClassifier(num_classes)
         model.load_state_dict(torch.load(model_path, weights_only=True))
         model.to(device)
diff --git a/nn/model.py b/nn/model.py
deleted file mode 100644
index 0d501e2..0000000
--- a/nn/model.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import torch.nn as nn
-
-
-# A further simplified attention module
-class UltraSimplifiedAttention(nn.Module):
-    def __init__(self, in_channels):
-        super(UltraSimplifiedAttention, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.fc = nn.Sequential(
-            nn.Linear(in_channels, in_channels // 64, bias=False),
-            nn.ReLU(inplace=True),
-            nn.Linear(in_channels // 64, in_channels, bias=False),
-            nn.Sigmoid()
-        )
-
-    def forward(self, x):
-        b, c, _, _ = x.size()
-        y = self.avg_pool(x).view(b, c)
-        y = self.fc(y).view(b, c, 1, 1)
-        return x * y.expand_as(x)
-
-
-# CNN model combined with the simplified attention mechanism
-class CNNWithUltraSimplifiedAttention(nn.Module):
-    def __init__(self, num_classes):
-        super(CNNWithUltraSimplifiedAttention, self).__init__()
-        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
-        self.attention1 = UltraSimplifiedAttention(16)
-        self.relu1 = nn.ReLU(inplace=True)
-        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
-
-        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
-        self.attention2 = UltraSimplifiedAttention(32)
-        self.relu2 = nn.ReLU(inplace=True)
-        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
-
-        self.fc_input_dim = 32 * 256 * 200
-        self.fc1 = nn.Linear(self.fc_input_dim, 128)
-        self.relu3 = nn.ReLU(inplace=True)
-        self.fc2 = nn.Linear(128, num_classes)
-
-    def forward(self, x):
-        x = self.pool1(self.relu1(self.attention1(self.conv1(x))))
-        x = self.pool2(self.relu2(self.attention2(self.conv2(x))))
-        x = x.view(-1, self.fc_input_dim)
-        x = self.relu3(self.fc1(x))
-        x = self.fc2(x)
-        return x
\ No newline at end of file
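The deleted nn/model.py explains why the rewrite was needed: with 16 or 32 channels, `in_channels // 64` is 0, so each `UltraSimplifiedAttention` bottleneck reduces to an empty `nn.Linear` whose Sigmoid emits a constant 0.5 gate instead of learned attention, and the hard-coded `32 * 256 * 200` flatten (a 1024x800 input after two 2x poolings) puts roughly 210 million weights into `fc1` alone. A small demonstration of both points:

```python
import torch
import torch.nn as nn

# The squeeze bottleneck collapses: 16 // 64 == 0, so the excitation path maps
# 16 -> 0 -> 16 features and Sigmoid(0) yields a constant 0.5 for every channel.
fc = nn.Sequential(
    nn.Linear(16, 16 // 64, bias=False),
    nn.ReLU(inplace=True),
    nn.Linear(16 // 64, 16, bias=False),
    nn.Sigmoid()
)
print(fc(torch.randn(1, 16)))  # tensor([[0.5000, 0.5000, ..., 0.5000]])

# The flattened classifier dominated the parameter count:
print(32 * 256 * 200 * 128)    # 209715200 weights in fc1 alone
```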