mirror of
https://github.com/yv1ing/ShotRDP.git
synced 2025-09-16 15:10:57 +08:00
47 lines
1.3 KiB
Python
47 lines
1.3 KiB
Python
import torch
|
|
from torchvision import transforms
|
|
from model import CNNClassifier
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
|
|
# Classifier output labels: entry i is the label reported for output index i.
class_labels = [
    'Windows 7',
    'Windows 10',
    'Windows Server 2008',
    'Windows Server 2012',
]

# Run on a CUDA GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
# Load the pretrained classifier weights.
# NOTE(review): the path is resolved against the current working directory,
# not this file's directory — confirm callers run from the checkpoint's folder.
model_path = './shotrdp.pth'

model = CNNClassifier(len(class_labels))
# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved from a GPU still loads on a CPU-only host (torch.load otherwise fails
# when the checkpoint holds CUDA tensors and CUDA is unavailable).
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
model.to(device)
# This module only runs inference: put layers such as dropout / batch norm
# (if CNNClassifier has any) into evaluation behavior.
model.eval()
|
|
|
|
# Input preprocessing applied to every image before it is fed to the model.
# These steps must stay in sync with whatever the checkpoint was trained on.
_preprocess_steps = [
    # torchvision interprets a (h, w) pair as the exact output size:
    # 1024 pixels high, 800 pixels wide.
    transforms.Resize((1024, 800)),
    # PIL image -> float tensor with channels first, values scaled to [0, 1].
    transforms.ToTensor(),
    # Per-channel normalization using the standard ImageNet RGB statistics.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)
|
|
|
|
|
|
def solve(image_bytes):
    """Predict which Windows version an image shows.

    Args:
        image_bytes: Encoded image data in any format PIL can decode.

    Returns:
        One of the strings in ``class_labels``.
    """
    # Decode and convert to 3-channel RGB; the context manager releases the
    # decoder's resources once the converted copy exists.
    with Image.open(BytesIO(image_bytes)) as img:
        rgb_image = img.convert('RGB')

    # unsqueeze(0) adds a leading batch dimension: the model sees a batch of one.
    batch = transform(rgb_image).unsqueeze(0).to(device)

    # Pure inference: disable autograd so no gradient graph is built.
    with torch.no_grad():
        scores = model(batch)

    # Index of the highest score along the class dimension (dim 1); on ties
    # the first maximal index is chosen, matching torch.max(..., 1).
    # NOTE(review): assumes scores has shape (1, len(class_labels)).
    predicted_index = scores.argmax(dim=1).item()
    return class_labels[predicted_index]
|
|
|
|
|
|
# Disabled manual check: classifies one screenshot read from a local file and
# prints the file name, the predicted label and the raw model output.
# It repeats solve()'s preprocessing inline but takes a file path instead of
# raw bytes. Uncomment to run this module directly.
# if __name__ == '__main__':
#     image_path = './screen/0.png'
#
#     image = Image.open(image_path).convert('RGB')
#     image = transform(image).unsqueeze(0).to(device)
#
#     output = model(image)
#     _, predicted = torch.max(output.data, 1)
#
#     print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
|