更新模型调用

This commit is contained in:
2025-05-09 10:52:09 +08:00
committed by GitHub
parent d3d529ed98
commit c505c68553

View File

@@ -1,6 +1,6 @@
import torch
from torchvision import transforms
from model import CNNClassifier
from model import CustomCNN
from PIL import Image
from io import BytesIO
@@ -9,7 +9,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载预训练模型
model_path = './shotrdp.pth'
model = CNNClassifier(len(class_labels))
model = CustomCNN(len(class_labels))
model.load_state_dict(torch.load(model_path, weights_only=True))
model.to(device)
@@ -34,13 +34,13 @@ def solve(image_bytes):
return _result
# if __name__ == '__main__':
# image_path = './screen/0.png'
#
# image = Image.open(image_path).convert('RGB')
# image = transform(image).unsqueeze(0).to(device)
#
# output = model(image)
# _, predicted = torch.max(output.data, 1)
#
# print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
if __name__ == '__main__':
image_path = './screen/0.png'
image = Image.open(image_path).convert('RGB')
image = transform(image).unsqueeze(0).to(device)
output = model(image)
_, predicted = torch.max(output.data, 1)
print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")