优化神经网络,缩减模型大小

This commit is contained in:
2025-03-28 18:07:58 +08:00
parent 2ad08423ce
commit a587bd804e
13 changed files with 91 additions and 69 deletions

92
gui/train.py Normal file
View File

@@ -0,0 +1,92 @@
import os
import torch
import torch.nn as nn
import torch.optim as optim
import warnings
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from model import CNNClassifier
warnings.filterwarnings("ignore", category=UserWarning)
# 数据预处理
transform = transforms.Compose([
transforms.Resize((1024, 800)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载数据集
data_dir = './datasets'
full_dataset = datasets.ImageFolder(data_dir, transform=transform)
# 划分训练集和测试集
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
# 初始化模型
num_classes = len(full_dataset.classes)
model = CNNClassifier(num_classes)
# 初始化损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if __name__ == '__main__':
# 训练模型
num_epochs = 16
model.to(device)
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
train_loader_tqdm = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch')
for i, (images, labels) in enumerate(train_loader_tqdm):
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
train_loader_tqdm.set_postfix(loss=running_loss / (i + 1))
print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(train_loader)}')
# 保存模型
model_path = 'rdp_model.pth'
torch.save(model.state_dict(), model_path)
# 测试模型
load_saved_model = True
if load_saved_model and os.path.exists(model_path):
model = CNNClassifier(num_classes)
model.load_state_dict(torch.load(model_path, weights_only=True))
model.to(device)
model.eval()
correct = 0
total = 0
test_loader_tqdm = tqdm(test_loader, desc='Testing', unit='batch')
with torch.no_grad():
for images, labels in test_loader_tqdm:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
test_loader_tqdm.set_postfix(accuracy=100 * correct / total)
print(f'Accuracy on test set: {100 * correct / total}%')