From c505c68553034355359c861ca3b487cf4cdfc3fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=96=BB=E7=81=B5?= Date: Fri, 9 May 2025 10:52:09 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=A8=A1=E5=9E=8B=E8=B0=83?= =?UTF-8?q?=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gui/predict.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/gui/predict.py b/gui/predict.py index e7e6033..c6b58a3 100644 --- a/gui/predict.py +++ b/gui/predict.py @@ -1,6 +1,6 @@ import torch from torchvision import transforms -from model import CNNClassifier +from model import CustomCNN from PIL import Image from io import BytesIO @@ -9,7 +9,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 加载预训练模型 model_path = './shotrdp.pth' -model = CNNClassifier(len(class_labels)) +model = CustomCNN(len(class_labels)) model.load_state_dict(torch.load(model_path, weights_only=True)) model.to(device) @@ -34,13 +34,13 @@ def solve(image_bytes): return _result -# if __name__ == '__main__': -# image_path = './screen/0.png' -# -# image = Image.open(image_path).convert('RGB') -# image = transform(image).unsqueeze(0).to(device) -# -# output = model(image) -# _, predicted = torch.max(output.data, 1) -# -# print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}") +if __name__ == '__main__': + image_path = './screen/0.png' + + image = Image.open(image_path).convert('RGB') + image = transform(image).unsqueeze(0).to(device) + + output = model(image) + _, predicted = torch.max(output.data, 1) + + print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")