diff --git a/gui/train.py b/gui/train.py index 591329c..c5e98ac 100644 --- a/gui/train.py +++ b/gui/train.py @@ -6,7 +6,7 @@ import warnings from tqdm import tqdm from torchvision import datasets, transforms from torch.utils.data import DataLoader, random_split -from model import CNNClassifier +from model import SimpleCNN, CustomCNN warnings.filterwarnings("ignore", category=UserWarning) @@ -34,7 +34,8 @@ test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False) # 初始化模型 num_classes = len(full_dataset.classes) -model = CNNClassifier(num_classes) +chosen_model = CustomCNN +model = chosen_model(num_classes) # 初始化损失函数和优化器 criterion = nn.CrossEntropyLoss() @@ -71,7 +72,7 @@ if __name__ == '__main__': # 测试模型 load_saved_model = True if load_saved_model and os.path.exists(model_path): - model = CNNClassifier(num_classes) + model = chosen_model(num_classes) model.load_state_dict(torch.load(model_path, weights_only=True)) model.to(device)