From d3d529ed98dfa98e2c8629dcd3482d07c8d9a379 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=96=BB=E7=81=B5?= Date: Fri, 9 May 2025 10:51:43 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=A8=A1=E5=9E=8B=E8=B0=83?= =?UTF-8?q?=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gui/train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gui/train.py b/gui/train.py index 591329c..c5e98ac 100644 --- a/gui/train.py +++ b/gui/train.py @@ -6,7 +6,7 @@ import warnings from tqdm import tqdm from torchvision import datasets, transforms from torch.utils.data import DataLoader, random_split -from model import CNNClassifier +from model import SimpleCNN, CustomCNN warnings.filterwarnings("ignore", category=UserWarning) @@ -34,7 +34,8 @@ test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False) # 初始化模型 num_classes = len(full_dataset.classes) -model = CNNClassifier(num_classes) +chosen_model = CustomCNN +model = chosen_model(num_classes) # 初始化损失函数和优化器 criterion = nn.CrossEntropyLoss() @@ -71,7 +72,7 @@ if __name__ == '__main__': # 测试模型 load_saved_model = True if load_saved_model and os.path.exists(model_path): - model = CNNClassifier(num_classes) + model = chosen_model(num_classes) model.load_state_dict(torch.load(model_path, weights_only=True)) model.to(device)