diff --git a/gui/predict.py b/gui/predict.py index e7e6033..c6b58a3 100644 --- a/gui/predict.py +++ b/gui/predict.py @@ -1,6 +1,6 @@ import torch from torchvision import transforms -from model import CNNClassifier +from model import CustomCNN from PIL import Image from io import BytesIO @@ -9,7 +9,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 加载预训练模型 model_path = './shotrdp.pth' -model = CNNClassifier(len(class_labels)) +model = CustomCNN(len(class_labels)) model.load_state_dict(torch.load(model_path, weights_only=True)) model.to(device) @@ -34,13 +34,13 @@ def solve(image_bytes): return _result -# if __name__ == '__main__': -# image_path = './screen/0.png' -# -# image = Image.open(image_path).convert('RGB') -# image = transform(image).unsqueeze(0).to(device) -# -# output = model(image) -# _, predicted = torch.max(output.data, 1) -# -# print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}") +if __name__ == '__main__': + image_path = './screen/0.png' + + image = Image.open(image_path).convert('RGB') + image = transform(image).unsqueeze(0).to(device) + + output = model(image) + _, predicted = torch.max(output.data, 1) + + print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")