mirror of
https://github.com/yv1ing/ShotRDP.git
synced 2025-09-16 15:10:57 +08:00
加入CNN+Attention的图像分类模块
This commit is contained in:
35
nn/predict.py
Normal file
35
nn/predict.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from model import CNNWithUltraSimplifiedAttention
|
||||
from PIL import Image
|
||||
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# 加载预训练模型
|
||||
model_path = './rdp_model.pth'
|
||||
model = CNNWithUltraSimplifiedAttention(4)
|
||||
model.load_state_dict(torch.load(model_path, weights_only=True))
|
||||
model.to(device)
|
||||
|
||||
print(f"Model loaded from {model_path}")
|
||||
|
||||
|
||||
# 数据预处理
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((1024, 800)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
|
||||
# 预测截图类别
|
||||
image_path = './test.png'
|
||||
class_labels = ['Windows 7', 'Windows 10', 'Windows Server 2008', 'Windows Server 2012']
|
||||
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
image = transform(image).unsqueeze(0).to(device)
|
||||
|
||||
output = model(image)
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
|
||||
Reference in New Issue
Block a user