更新模型调用

This commit is contained in:
2025-05-09 10:52:09 +08:00
committed by GitHub
parent d3d529ed98
commit c505c68553

View File

@@ -1,6 +1,6 @@
import torch import torch
from torchvision import transforms from torchvision import transforms
from model import CNNClassifier from model import CustomCNN
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
@@ -9,7 +9,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载预训练模型 # 加载预训练模型
model_path = './shotrdp.pth' model_path = './shotrdp.pth'
model = CNNClassifier(len(class_labels)) model = CustomCNN(len(class_labels))
model.load_state_dict(torch.load(model_path, weights_only=True)) model.load_state_dict(torch.load(model_path, weights_only=True))
model.to(device) model.to(device)
@@ -34,13 +34,13 @@ def solve(image_bytes):
return _result return _result
# if __name__ == '__main__': if __name__ == '__main__':
# image_path = './screen/0.png' image_path = './screen/0.png'
#
# image = Image.open(image_path).convert('RGB') image = Image.open(image_path).convert('RGB')
# image = transform(image).unsqueeze(0).to(device) image = transform(image).unsqueeze(0).to(device)
#
# output = model(image) output = model(image)
# _, predicted = torch.max(output.data, 1) _, predicted = torch.max(output.data, 1)
#
# print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}") print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")