"""Train and evaluate an image classifier on an ImageFolder dataset.

Images are read from ``./datasets`` (one sub-directory per class), split
80/20 into train/test subsets, used to train the model class selected by
``chosen_model``, and finally evaluated on the held-out subset. The trained
weights are written to ``rdp_model.pth``.
"""
import os
import warnings

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm

from model import SimpleCNN, CustomCNN

warnings.filterwarnings("ignore", category=UserWarning)

# Fixed seed for the train/test split so repeated runs (and any later
# evaluation of the saved weights) use the same held-out images instead of
# mixing previously-trained-on images into the test set.
SPLIT_SEED = 42

# Data preprocessing: resize to a fixed size, convert to a tensor, and
# normalize with the ImageNet per-channel mean/std.
transform = transforms.Compose([
    transforms.Resize((1024, 800)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the dataset (ImageFolder: one sub-directory per class).
data_dir = './datasets'
full_dataset = datasets.ImageFolder(data_dir, transform=transform)

# Split into training (80%) and test (remaining) subsets.
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Data loaders.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Model.
num_classes = len(full_dataset.classes)
chosen_model = CustomCNN
model = chosen_model(num_classes)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _train_one_epoch(net, loader, opt, loss_fn, dev, epoch, num_epochs):
    """Run one training pass of *net* over *loader*; return the mean batch loss.

    Args:
        net: model to train (switched to train mode here).
        loader: iterable of ``(images, labels)`` batches.
        opt: optimizer bound to ``net``'s parameters.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``.
        dev: device the batches are moved to.
        epoch: zero-based epoch index (used only for the progress label).
        num_epochs: total number of epochs (used only for the progress label).

    Returns:
        Sum of per-batch losses divided by the number of batches, or 0.0
        when the loader yields no batches.
    """
    net.train()
    running_loss = 0.0
    progress = tqdm(loader, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch')
    for i, (images, labels) in enumerate(progress):
        images, labels = images.to(dev), labels.to(dev)
        opt.zero_grad()
        outputs = net(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()
        running_loss += loss.item()
        progress.set_postfix(loss=running_loss / (i + 1))
    # max(..., 1) avoids ZeroDivisionError when the training split is empty.
    return running_loss / max(len(loader), 1)


def _evaluate(net, loader, dev):
    """Return the top-1 accuracy of *net* on *loader*, in percent.

    Returns 0.0 when the loader contains no samples, instead of dividing by zero.
    """
    net.eval()
    correct = 0
    total = 0
    progress = tqdm(loader, desc='Testing', unit='batch')
    with torch.no_grad():
        for images, labels in progress:
            images, labels = images.to(dev), labels.to(dev)
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            progress.set_postfix(accuracy=100 * correct / total)
    return 100 * correct / total if total else 0.0


if __name__ == '__main__':
    # Train the model.
    num_epochs = 16
    model.to(device)
    for epoch in range(num_epochs):
        epoch_loss = _train_one_epoch(
            model, train_loader, optimizer, criterion, device, epoch, num_epochs
        )
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}')

    # Save the trained weights.
    model_path = 'rdp_model.pth'
    torch.save(model.state_dict(), model_path)

    # Test the model. Optionally rebuild it from the checkpoint first, which
    # also verifies that the saved weights round-trip correctly.
    load_saved_model = True
    if load_saved_model and os.path.exists(model_path):
        model = chosen_model(num_classes)
        # map_location lets a checkpoint saved on one device (e.g. CUDA) be
        # loaded on a host where only another device (e.g. CPU) is available.
        model.load_state_dict(
            torch.load(model_path, map_location=device, weights_only=True)
        )
    model.to(device)

    accuracy = _evaluate(model, test_loader, device)
    print(f'Accuracy on test set: {accuracy}%')