"""Classify a single screenshot with a pretrained CNNWithUltraSimplifiedAttention.

Loads weights from ``./rdp_model.pth``, preprocesses ``./test.png`` and prints
the predicted label from ``class_labels`` together with the raw model output.
"""
import torch
from torchvision import transforms
from model import CNNWithUltraSimplifiedAttention
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained model.
model_path = './rdp_model.pth'
# NOTE(review): the constructor argument is presumably the number of output
# classes (it equals len(class_labels) below) -- confirm against model.py.
model = CNNWithUltraSimplifiedAttention(4)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only
# host, matching the CPU fallback chosen for `device` above.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
model.to(device)
# Switch to inference mode so layers such as dropout / batch-norm (if the
# model contains any) behave deterministically.
model.eval()
print(f"Model loaded from {model_path}")

# Data preprocessing.
transform = transforms.Compose([
    # Resize takes (height, width); this has to match what the model was
    # trained with.
    transforms.Resize((1024, 800)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Predict the category of the screenshot.
image_path = './test.png'
class_labels = [
    'Windows 7',
    'Windows 10',
    'Windows Server 2008',
    'Windows Server 2012',
]

# Use a context manager so the underlying file handle is closed promptly;
# convert() returns a fully loaded copy that outlives the `with` block.
with Image.open(image_path) as raw_image:
    image = raw_image.convert('RGB')
# Add the batch dimension expected by the model and move to the target device.
image = transform(image).unsqueeze(0).to(device)

# No gradients are needed for inference; this avoids building the autograd
# graph and makes the legacy `.data` attribute unnecessary.
with torch.no_grad():
    output = model(image)
_, predicted = torch.max(output, 1)

print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output}")