mirror of
https://github.com/yv1ing/ShotRDP.git
synced 2025-09-16 15:10:57 +08:00
更新模型调用
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from model import CNNClassifier
|
||||
from model import CustomCNN
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
@@ -9,7 +9,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# 加载预训练模型
|
||||
model_path = './shotrdp.pth'
|
||||
model = CNNClassifier(len(class_labels))
|
||||
model = CustomCNN(len(class_labels))
|
||||
model.load_state_dict(torch.load(model_path, weights_only=True))
|
||||
model.to(device)
|
||||
|
||||
@@ -34,13 +34,13 @@ def solve(image_bytes):
|
||||
return _result
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# image_path = './screen/0.png'
|
||||
#
|
||||
# image = Image.open(image_path).convert('RGB')
|
||||
# image = transform(image).unsqueeze(0).to(device)
|
||||
#
|
||||
# output = model(image)
|
||||
# _, predicted = torch.max(output.data, 1)
|
||||
#
|
||||
# print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
|
||||
if __name__ == '__main__':
|
||||
image_path = './screen/0.png'
|
||||
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
image = transform(image).unsqueeze(0).to(device)
|
||||
|
||||
output = model(image)
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
|
||||
print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
|
||||
|
||||
Reference in New Issue
Block a user