mirror of
https://github.com/yv1ing/ShotRDP.git
synced 2025-09-16 15:10:57 +08:00
更新模型调用
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from model import CNNClassifier
|
from model import CustomCNN
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
@@ -9,7 +9,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|||||||
|
|
||||||
# 加载预训练模型
|
# 加载预训练模型
|
||||||
model_path = './shotrdp.pth'
|
model_path = './shotrdp.pth'
|
||||||
model = CNNClassifier(len(class_labels))
|
model = CustomCNN(len(class_labels))
|
||||||
model.load_state_dict(torch.load(model_path, weights_only=True))
|
model.load_state_dict(torch.load(model_path, weights_only=True))
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
@@ -34,13 +34,13 @@ def solve(image_bytes):
|
|||||||
return _result
|
return _result
|
||||||
|
|
||||||
|
|
||||||
# if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# image_path = './screen/0.png'
|
image_path = './screen/0.png'
|
||||||
#
|
|
||||||
# image = Image.open(image_path).convert('RGB')
|
image = Image.open(image_path).convert('RGB')
|
||||||
# image = transform(image).unsqueeze(0).to(device)
|
image = transform(image).unsqueeze(0).to(device)
|
||||||
#
|
|
||||||
# output = model(image)
|
output = model(image)
|
||||||
# _, predicted = torch.max(output.data, 1)
|
_, predicted = torch.max(output.data, 1)
|
||||||
#
|
|
||||||
# print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
|
print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
|
||||||
|
|||||||
Reference in New Issue
Block a user