封装识别接口

This commit is contained in:
2025-03-28 18:19:25 +08:00
parent a587bd804e
commit 553f5a48b0
3 changed files with 29 additions and 12 deletions

View File

@@ -2,7 +2,7 @@ import torch
from torchvision import transforms
from model import CNNClassifier
from PIL import Image
from io import BytesIO
class_labels = ['Windows 7', 'Windows 10', 'Windows Server 2008', 'Windows Server 2012']
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -20,13 +20,27 @@ transform = transforms.Compose([
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
if __name__ == '__main__':
image_path = './screen/0.png'
image = Image.open(image_path).convert('RGB')
image = transform(image).unsqueeze(0).to(device)
def solve(image_bytes):
_image = Image.open(BytesIO(image_bytes)).convert('RGB')
_image = transform(_image).unsqueeze(0).to(device)
output = model(image)
_, predicted = torch.max(output.data, 1)
_output = model(_image)
_, _predicted = torch.max(_output.data, 1)
print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
_result = class_labels[_predicted.item()]
print(f"\nPredicted result: {_result}")
return _result
# if __name__ == '__main__':
# image_path = './screen/0.png'
#
# image = Image.open(image_path).convert('RGB')
# image = transform(image).unsqueeze(0).to(device)
#
# output = model(image)
# _, predicted = torch.max(output.data, 1)
#
# print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")