"""Classify a screenshot into one of a fixed set of Windows versions.

A pretrained ``CNNClassifier`` checkpoint is loaded once at import time.
:func:`solve` takes raw encoded image bytes and returns the predicted label.

Example::

    with open('./screen/0.png', 'rb') as f:
        print(solve(f.read()))
"""
from io import BytesIO

import torch
from PIL import Image
from torchvision import transforms

from model import CNNClassifier

# Model output index i maps to class_labels[i]; the order must match the
# label order used when the checkpoint was trained.
class_labels = ['Windows 7', 'Windows 10', 'Windows Server 2008', 'Windows Server 2012']
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained model. NOTE: the path is relative to the current
# working directory, not to this file.
model_path = './shotrdp.pth'
model = CNNClassifier(len(class_labels))
# map_location lets a checkpoint that was saved from a GPU load on a CPU-only
# host (and vice versa); weights_only restricts unpickling to tensors.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.to(device)
# Switch to inference behaviour: disables dropout and stops batch-norm layers
# from using/updating batch statistics (a no-op if the network has neither).
model.eval()

# Data preprocessing. transforms.Resize takes (height, width). The ImageNet
# mean/std normalization presumably matches the training pipeline — confirm.
transform = transforms.Compose([
    transforms.Resize((1024, 800)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def solve(image_bytes: bytes) -> str:
    """Predict which Windows version a screenshot shows.

    The predicted label is also printed to stdout.

    Args:
        image_bytes: Encoded image data in any format PIL can decode.

    Returns:
        The entry of ``class_labels`` with the highest model score.

    Raises:
        PIL.UnidentifiedImageError: If ``image_bytes`` is not a decodable image.
    """
    # Decode and normalize to 3-channel RGB; convert() returns an independent
    # copy, so the source image can be closed right away.
    with Image.open(BytesIO(image_bytes)) as raw:
        image = raw.convert('RGB')

    # Add a leading batch dimension of size 1 before feeding the model.
    batch = transform(image).unsqueeze(0).to(device)

    # No gradients are needed for inference; skip building the autograd graph.
    with torch.no_grad():
        logits = model(batch)

    index = logits.argmax(dim=1).item()
    result = class_labels[index]
    print(f"\nPredicted result: {result}")
    return result