更新模型调用

This commit is contained in:
2025-05-09 10:51:43 +08:00
committed by GitHub
parent 5d7b573a79
commit d3d529ed98

View File

@@ -6,7 +6,7 @@ import warnings
from tqdm import tqdm from tqdm import tqdm
from torchvision import datasets, transforms from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split from torch.utils.data import DataLoader, random_split
from model import CNNClassifier from model import SimpleCNN, CustomCNN
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
@@ -34,7 +34,8 @@ test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
# 初始化模型 # 初始化模型
num_classes = len(full_dataset.classes) num_classes = len(full_dataset.classes)
model = CNNClassifier(num_classes) chosen_model = CustomCNN
model = chosen_model(num_classes)
# 初始化损失函数和优化器 # 初始化损失函数和优化器
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
@@ -71,7 +72,7 @@ if __name__ == '__main__':
# 测试模型 # 测试模型
load_saved_model = True load_saved_model = True
if load_saved_model and os.path.exists(model_path): if load_saved_model and os.path.exists(model_path):
model = CNNClassifier(num_classes) model = chosen_model(num_classes)
model.load_state_dict(torch.load(model_path, weights_only=True)) model.load_state_dict(torch.load(model_path, weights_only=True))
model.to(device) model.to(device)