Files
ShotRDP/gui/train.py
2025-05-09 10:51:43 +08:00

94 lines
2.9 KiB
Python

import os
import torch
import torch.nn as nn
import torch.optim as optim
import warnings
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from model import SimpleCNN, CustomCNN
# Silence every UserWarning for the rest of the process.
# NOTE(review): presumably meant to quiet noisy library warnings during
# training, but this also hides any other UserWarning; consider narrowing it.
warnings.simplefilter("ignore", category=UserWarning)

# Preprocessing applied to every image: resize to (height, width) = (1024, 800),
# convert to a tensor, then normalize each channel with the widely used
# ImageNet mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((1024, 800)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Load the dataset; ImageFolder assigns one class label per sub-directory of
# data_dir and applies `transform` to every image it returns.
data_dir = './datasets'
full_dataset = datasets.ImageFolder(data_dir, transform=transform)

# Split 80/20 into training and test sets. The split is seeded so the same
# images land in the same subset on every run: accuracy numbers stay
# comparable across runs, and a checkpoint evaluated in a later run is not
# scored on images it was trained on.
split_seed = 42
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Data loaders: training batches are reshuffled every epoch; test order is fixed.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
# Compute device: CUDA when available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the classifier with one output per ImageFolder class. chosen_model
# selects the architecture (CustomCNN here; SimpleCNN is the other class
# imported from `model`).
num_classes = len(full_dataset.classes)
chosen_model = CustomCNN
model = chosen_model(num_classes)

# Cross-entropy loss, optimized with Adam at a learning rate of 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
def accuracy_percent(correct, total):
    """Return ``correct / total`` as a percentage, or ``None`` if ``total`` is 0.

    An 80/20 split of a very small dataset can leave the test set empty; the
    ``None`` return lets callers report that instead of raising
    ``ZeroDivisionError``.
    """
    if total == 0:
        return None
    return 100 * correct / total


def train_one_epoch(net, loader, loss_fn, opt, dev, desc='Training'):
    """Run one training pass over ``loader`` and return the mean per-batch loss.

    Args:
        net: model to train; switched to train mode here.
        loader: iterable of ``(images, labels)`` batches.
        loss_fn: loss callable, applied as ``loss_fn(net(images), labels)``.
        opt: optimizer that updates ``net``'s parameters.
        dev: device each batch is moved to before the forward pass.
        desc: label shown on the tqdm progress bar.

    Returns:
        The mean of ``loss.item()`` over all batches, or ``nan`` if ``loader``
        produced no batches.
    """
    net.train()
    running_loss = 0.0
    num_batches = 0
    progress = tqdm(loader, desc=desc, unit='batch')
    for images, labels in progress:
        images, labels = images.to(dev), labels.to(dev)
        opt.zero_grad()
        loss = loss_fn(net(images), labels)
        loss.backward()
        opt.step()
        running_loss += loss.item()
        num_batches += 1
        progress.set_postfix(loss=running_loss / num_batches)
    # Count batches locally so an empty loader cannot divide by zero.
    return running_loss / num_batches if num_batches else float('nan')


def evaluate(net, loader, dev):
    """Return the top-1 accuracy (percent) of ``net`` on ``loader``.

    ``net`` is put in eval mode and run under ``torch.no_grad()``. Returns
    ``None`` when ``loader`` contains no samples.
    """
    net.eval()
    correct = 0
    total = 0
    progress = tqdm(loader, desc='Testing', unit='batch')
    with torch.no_grad():
        for images, labels in progress:
            images, labels = images.to(dev), labels.to(dev)
            # Predicted class = index of the highest score in each row.
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            progress.set_postfix(accuracy=accuracy_percent(correct, total))
    return accuracy_percent(correct, total)


if __name__ == '__main__':
    # Train the model.
    num_epochs = 16
    model.to(device)
    for epoch in range(num_epochs):
        epoch_desc = f'Epoch {epoch + 1}/{num_epochs}'
        epoch_loss = train_one_epoch(model, train_loader, criterion, optimizer,
                                     device, desc=epoch_desc)
        print(f'{epoch_desc}, Loss: {epoch_loss}')

    # Save the trained weights.
    model_path = 'rdp_model.pth'
    torch.save(model.state_dict(), model_path)

    # Test the model. Optionally rebuild it from the saved checkpoint first,
    # which also checks that the checkpoint loads back correctly.
    load_saved_model = True
    if load_saved_model and os.path.exists(model_path):
        model = chosen_model(num_classes)
        # map_location lets a checkpoint written on one device (e.g. CUDA)
        # load on another (e.g. a CPU-only machine).
        model.load_state_dict(
            torch.load(model_path, map_location=device, weights_only=True))
        model.to(device)
    accuracy = evaluate(model, test_loader, device)
    if accuracy is None:
        print('Accuracy on test set: n/a (test set is empty)')
    else:
        print(f'Accuracy on test set: {accuracy}%')