更新模型调用

This commit is contained in:
2025-05-09 10:51:43 +08:00
committed by GitHub
parent 5d7b573a79
commit d3d529ed98

View File

@@ -6,7 +6,7 @@ import warnings
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from model import CNNClassifier
from model import SimpleCNN, CustomCNN
warnings.filterwarnings("ignore", category=UserWarning)
@@ -34,7 +34,8 @@ test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
# 初始化模型
num_classes = len(full_dataset.classes)
model = CNNClassifier(num_classes)
chosen_model = CustomCNN
model = chosen_model(num_classes)
# 初始化损失函数和优化器
criterion = nn.CrossEntropyLoss()
@@ -71,7 +72,7 @@ if __name__ == '__main__':
# 测试模型
load_saved_model = True
if load_saved_model and os.path.exists(model_path):
model = CNNClassifier(num_classes)
model = chosen_model(num_classes)
model.load_state_dict(torch.load(model_path, weights_only=True))
model.to(device)