From f7f26c709bc4bf331a63b2c391462eb7f09c4c9f Mon Sep 17 00:00:00 2001
From: yv1ing
Date: Fri, 28 Mar 2025 16:01:42 +0800
Subject: [PATCH] Add a CNN+Attention image classification module
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 nn/model.py   | 48 ++++++++++++++++++++++++++++
 nn/predict.py | 35 ++++++++++++++++++++
 nn/train.py   | 92 +++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 175 insertions(+)
 create mode 100644 nn/model.py
 create mode 100644 nn/predict.py
 create mode 100644 nn/train.py

diff --git a/nn/model.py b/nn/model.py
new file mode 100644
index 0000000..0d501e2
--- /dev/null
+++ b/nn/model.py
@@ -0,0 +1,48 @@
+import torch.nn as nn
+
+
+# Ultra-simplified attention module (squeeze-and-excitation style channel attention)
+class UltraSimplifiedAttention(nn.Module):
+    def __init__(self, in_channels):
+        super(UltraSimplifiedAttention, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(in_channels, max(1, in_channels // 64), bias=False),  # max(1, ...) keeps the bottleneck non-empty when in_channels < 64
+            nn.ReLU(inplace=True),
+            nn.Linear(max(1, in_channels // 64), in_channels, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+
+# CNN combined with the ultra-simplified attention mechanism
+class CNNWithUltraSimplifiedAttention(nn.Module):
+    def __init__(self, num_classes):
+        super(CNNWithUltraSimplifiedAttention, self).__init__()
+        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
+        self.attention1 = UltraSimplifiedAttention(16)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
+        self.attention2 = UltraSimplifiedAttention(32)
+        self.relu2 = nn.ReLU(inplace=True)
+        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+        self.fc_input_dim = 32 * 256 * 200  # 1024x800 input halved twice by the max pools: 32 channels x 256 x 200
+        self.fc1 = nn.Linear(self.fc_input_dim, 128)
+        self.relu3 = nn.ReLU(inplace=True)
+        self.fc2 = nn.Linear(128, num_classes)
+
+    def forward(self, x):
+        x = self.pool1(self.relu1(self.attention1(self.conv1(x))))
+        x = self.pool2(self.relu2(self.attention2(self.conv2(x))))
+        x = x.view(-1, self.fc_input_dim)
+        x = self.relu3(self.fc1(x))
+        x = self.fc2(x)
+        return x
\ No newline at end of file
diff --git a/nn/predict.py b/nn/predict.py
new file mode 100644
index 0000000..b5b3a23
--- /dev/null
+++ b/nn/predict.py
@@ -0,0 +1,35 @@
+import torch
+from torchvision import transforms
+from model import CNNWithUltraSimplifiedAttention
+from PIL import Image
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Load the trained model
+model_path = './rdp_model.pth'
+model = CNNWithUltraSimplifiedAttention(4)
+model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
+model.to(device)
+
+print(f"Model loaded from {model_path}")
+
+
+# Data preprocessing
+transform = transforms.Compose([
+    transforms.Resize((1024, 800)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+
+# Predict the screenshot class
+image_path = './test.png'
+class_labels = ['Windows 7', 'Windows 10', 'Windows Server 2008', 'Windows Server 2012']
+
+image = Image.open(image_path).convert('RGB')
+image = transform(image).unsqueeze(0).to(device)
+
+output = model(image)
+_, predicted = torch.max(output.data, 1)
+print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
\ No newline at end of file
diff --git a/nn/train.py b/nn/train.py
new file mode 100644
index 0000000..8473019
--- /dev/null
+++ b/nn/train.py
@@ -0,0 +1,92 @@
+import os
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import warnings
+from tqdm import tqdm
+from torchvision import datasets, transforms
+from torch.utils.data import DataLoader, random_split
+from model import CNNWithUltraSimplifiedAttention
+
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+# Data preprocessing
+transform = transforms.Compose([
+    transforms.Resize((1024, 800)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+# Load the dataset
+data_dir = './datasets'
+full_dataset = datasets.ImageFolder(data_dir, transform=transform)
+
+# Split into training and test sets
+train_size = int(0.8 * len(full_dataset))
+test_size = len(full_dataset) - train_size
+train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
+
+# Create the data loaders
+train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
+test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
+
+# Initialize the model
+num_classes = len(full_dataset.classes)
+model = CNNWithUltraSimplifiedAttention(num_classes)
+
+# Initialize the loss function and optimizer
+criterion = nn.CrossEntropyLoss()
+optimizer = optim.Adam(model.parameters(), lr=0.001)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+if __name__ == '__main__':
+    # Train the model
+    num_epochs = 4
+    model.to(device)
+    for epoch in range(num_epochs):
+        model.train()
+        running_loss = 0.0
+        train_loader_tqdm = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch')
+        for i, (images, labels) in enumerate(train_loader_tqdm):
+            images, labels = images.to(device), labels.to(device)
+
+            optimizer.zero_grad()
+            outputs = model(images)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+            train_loader_tqdm.set_postfix(loss=running_loss / (i + 1))
+
+        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(train_loader)}')
+
+    # Save the model
+    model_path = 'rdp_model.pth'
+    torch.save(model.state_dict(), model_path)
+
+    # Evaluate the model on the test set
+    load_saved_model = True
+    if load_saved_model and os.path.exists(model_path):
+        model = CNNWithUltraSimplifiedAttention(num_classes)
+        model.load_state_dict(torch.load(model_path, weights_only=True))
+        model.to(device)
+
+    model.eval()
+    correct = 0
+    total = 0
+
+    test_loader_tqdm = tqdm(test_loader, desc='Testing', unit='batch')
+    with torch.no_grad():
+        for images, labels in test_loader_tqdm:
+            images, labels = images.to(device), labels.to(device)
+            outputs = model(images)
+            _, predicted = torch.max(outputs.data, 1)
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+            test_loader_tqdm.set_postfix(accuracy=100 * correct / total)
+
+    print(f'Accuracy on test set: {100 * correct / total}%')
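
Reviewer note (not part of the patch): the hard-coded fc_input_dim in nn/model.py only matches inputs produced by the transforms.Resize((1024, 800)) used in train.py and predict.py. A minimal shape-check sketch, assuming the working directory is nn/ so that model.py is importable:

    import torch
    from model import CNNWithUltraSimplifiedAttention

    # One RGB image at the size produced by transforms.Resize((1024, 800))
    dummy = torch.randn(1, 3, 1024, 800)

    model = CNNWithUltraSimplifiedAttention(num_classes=4)
    with torch.no_grad():
        logits = model(dummy)

    # Two stride-2 max pools reduce 1024x800 to 256x200, so the flattened
    # feature size is 32 * 256 * 200 = 1,638,400, matching fc_input_dim.
    print(logits.shape)  # expected: torch.Size([1, 4])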