mirror of
https://github.com/yv1ing/ShotRDP.git
synced 2025-09-16 15:10:57 +08:00
更新模型调用
This commit is contained in:
@@ -6,7 +6,7 @@ import warnings
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torchvision import datasets, transforms
|
from torchvision import datasets, transforms
|
||||||
from torch.utils.data import DataLoader, random_split
|
from torch.utils.data import DataLoader, random_split
|
||||||
from model import CNNClassifier
|
from model import SimpleCNN, CustomCNN
|
||||||
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
@@ -34,7 +34,8 @@ test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
|
|||||||
|
|
||||||
# 初始化模型
|
# 初始化模型
|
||||||
num_classes = len(full_dataset.classes)
|
num_classes = len(full_dataset.classes)
|
||||||
model = CNNClassifier(num_classes)
|
chosen_model = CustomCNN
|
||||||
|
model = chosen_model(num_classes)
|
||||||
|
|
||||||
# 初始化损失函数和优化器
|
# 初始化损失函数和优化器
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
@@ -71,7 +72,7 @@ if __name__ == '__main__':
|
|||||||
# 测试模型
|
# 测试模型
|
||||||
load_saved_model = True
|
load_saved_model = True
|
||||||
if load_saved_model and os.path.exists(model_path):
|
if load_saved_model and os.path.exists(model_path):
|
||||||
model = CNNClassifier(num_classes)
|
model = chosen_model(num_classes)
|
||||||
model.load_state_dict(torch.load(model_path, weights_only=True))
|
model.load_state_dict(torch.load(model_path, weights_only=True))
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user