mirror of
https://github.com/yv1ing/ShotRDP.git
synced 2025-09-16 15:10:57 +08:00
更新模型调用
This commit is contained in:
@@ -6,7 +6,7 @@ import warnings
|
||||
from tqdm import tqdm
|
||||
from torchvision import datasets, transforms
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from model import CNNClassifier
|
||||
from model import SimpleCNN, CustomCNN
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
@@ -34,7 +34,8 @@ test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
|
||||
|
||||
# 初始化模型
|
||||
num_classes = len(full_dataset.classes)
|
||||
model = CNNClassifier(num_classes)
|
||||
chosen_model = CustomCNN
|
||||
model = chosen_model(num_classes)
|
||||
|
||||
# 初始化损失函数和优化器
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
@@ -71,7 +72,7 @@ if __name__ == '__main__':
|
||||
# 测试模型
|
||||
load_saved_model = True
|
||||
if load_saved_model and os.path.exists(model_path):
|
||||
model = CNNClassifier(num_classes)
|
||||
model = chosen_model(num_classes)
|
||||
model.load_state_dict(torch.load(model_path, weights_only=True))
|
||||
model.to(device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user