diff --git a/.gitignore b/.gitignore index abe4ff6..1225073 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ # Work dir and files .idea .venv -__pycache__ +__pycache__/ # Binaries for programs and plugins *.exe diff --git a/gui/main.py b/gui/main.py index 67a038e..10a0247 100644 --- a/gui/main.py +++ b/gui/main.py @@ -1,4 +1,5 @@ import ctypes +from predict import solve def shot(target, width, height): @@ -20,9 +21,11 @@ def shot(target, width, height): if error_ptr: print(ctypes.string_at(error_ptr).decode()) else: - result = ctypes.string_at(data, length.value) - with open('./screen/0.png', 'wb') as f: - f.write(result) + image_bytes = ctypes.string_at(data, length.value) + result = solve(image_bytes) + + # with open('./screen/0.png', 'wb') as f: + # f.write(result) lib.Free(data) diff --git a/gui/predict.py b/gui/predict.py index e3865e1..f476246 100644 --- a/gui/predict.py +++ b/gui/predict.py @@ -2,7 +2,7 @@ import torch from torchvision import transforms from model import CNNClassifier from PIL import Image - +from io import BytesIO class_labels = ['Windows 7', 'Windows 10', 'Windows Server 2008', 'Windows Server 2012'] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -20,13 +20,27 @@ transform = transforms.Compose([ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) -if __name__ == '__main__': - image_path = './screen/0.png' - image = Image.open(image_path).convert('RGB') - image = transform(image).unsqueeze(0).to(device) +def solve(image_bytes): + _image = Image.open(BytesIO(image_bytes)).convert('RGB') + _image = transform(_image).unsqueeze(0).to(device) - output = model(image) - _, predicted = torch.max(output.data, 1) + _output = model(_image) + _, _predicted = torch.max(_output.data, 1) - print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}") + _result = class_labels[_predicted.item()] + + print(f"\nPredicted result: {_result}") + return _result + + +# if __name__ == '__main__': +# image_path = './screen/0.png' +# +# image = Image.open(image_path).convert('RGB') +# image = transform(image).unsqueeze(0).to(device) +# +# output = model(image) +# _, predicted = torch.max(output.data, 1) +# +# print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")