mirror of
https://github.com/yv1ing/ShotRDP.git
synced 2025-09-16 15:10:57 +08:00
封装识别接口
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -4,7 +4,7 @@
|
|||||||
# Work dir and files
|
# Work dir and files
|
||||||
.idea
|
.idea
|
||||||
.venv
|
.venv
|
||||||
__pycache__
|
__pycache__/
|
||||||
|
|
||||||
# Binaries for programs and plugins
|
# Binaries for programs and plugins
|
||||||
*.exe
|
*.exe
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import ctypes
|
import ctypes
|
||||||
|
from predict import solve
|
||||||
|
|
||||||
|
|
||||||
def shot(target, width, height):
|
def shot(target, width, height):
|
||||||
@@ -20,9 +21,11 @@ def shot(target, width, height):
|
|||||||
if error_ptr:
|
if error_ptr:
|
||||||
print(ctypes.string_at(error_ptr).decode())
|
print(ctypes.string_at(error_ptr).decode())
|
||||||
else:
|
else:
|
||||||
result = ctypes.string_at(data, length.value)
|
image_bytes = ctypes.string_at(data, length.value)
|
||||||
with open('./screen/0.png', 'wb') as f:
|
result = solve(image_bytes)
|
||||||
f.write(result)
|
|
||||||
|
# with open('./screen/0.png', 'wb') as f:
|
||||||
|
# f.write(result)
|
||||||
|
|
||||||
lib.Free(data)
|
lib.Free(data)
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import torch
|
|||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from model import CNNClassifier
|
from model import CNNClassifier
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
class_labels = ['Windows 7', 'Windows 10', 'Windows Server 2008', 'Windows Server 2012']
|
class_labels = ['Windows 7', 'Windows 10', 'Windows Server 2008', 'Windows Server 2012']
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
@@ -20,13 +20,27 @@ transform = transforms.Compose([
|
|||||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
])
|
])
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
image_path = './screen/0.png'
|
|
||||||
|
|
||||||
image = Image.open(image_path).convert('RGB')
|
def solve(image_bytes):
|
||||||
image = transform(image).unsqueeze(0).to(device)
|
_image = Image.open(BytesIO(image_bytes)).convert('RGB')
|
||||||
|
_image = transform(_image).unsqueeze(0).to(device)
|
||||||
|
|
||||||
output = model(image)
|
_output = model(_image)
|
||||||
_, predicted = torch.max(output.data, 1)
|
_, _predicted = torch.max(_output.data, 1)
|
||||||
|
|
||||||
print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
|
_result = class_labels[_predicted.item()]
|
||||||
|
|
||||||
|
print(f"\nPredicted result: {_result}")
|
||||||
|
return _result
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == '__main__':
|
||||||
|
# image_path = './screen/0.png'
|
||||||
|
#
|
||||||
|
# image = Image.open(image_path).convert('RGB')
|
||||||
|
# image = transform(image).unsqueeze(0).to(device)
|
||||||
|
#
|
||||||
|
# output = model(image)
|
||||||
|
# _, predicted = torch.max(output.data, 1)
|
||||||
|
#
|
||||||
|
# print(f"File name: {image_path}, Predicted result: {class_labels[predicted.item()]}, Output: {output.data}")
|
||||||
|
|||||||
Reference in New Issue
Block a user