#!/usr/bin/env python3# -*- coding: utf-8 -*-"""Created on Wed Apr 1 17:13:03 2026
Usage:python val_cls.pypython val_cls.py --model ./runs/classify/train/weights/best.pt --data /datasets/my_images --imgsz 224
# 使用 Apple Silicon GPUpython val_cls.py --device mps
# 使用 CPU (如果想對比效能差異)python val_cls.py --device cpu
Input:my_images/├── train/ # (可有可無,驗證時不讀取)└── val/ # 必須存在,驗證以此為準 ├── class_cat/ # 類別 A 的所有圖片 ├── class_dog/ # 類別 B 的所有圖片 └── ... """
import argparseimport sysimport torchfrom ultralytics import YOLO
# =============================================================================# ★ 預設值(CLI 未指定時使用)# =============================================================================SRC_MODEL = "/path/to/weights/best.pt"DATASET = "/path/to/val"# DEVICE = "mps" #改成自動偵測IMG_SIZE = 224
def get_default_device(): """自動偵測最佳可用設備""" if torch.cuda.is_available(): return "0" # NVIDIA GPU elif torch.backends.mps.is_available(): return "mps" # Apple Silicon GPU return "cpu" # 預設 CPU
def run_validation(): # 建立 CLI 參數解析器 parser = argparse.ArgumentParser(description="YOLO11-cls 分類模型驗證腳本 (支援 MPS/CUDA)")
# --- 設定參數與預設值 --- parser.add_argument("--model", type=str, default=SRC_MODEL, help="模型檔案路徑 (.pt)") parser.add_argument("--data", type=str, default=DATASET, help="資料集目錄路徑 (需包含 val 子目錄)") parser.add_argument("--imgsz", type=int, default=IMG_SIZE, help="輸入影像尺寸 (預設 224)") # 設備參數:預設會自動執行 get_default_device() parser.add_argument("--device", type=str, default=get_default_device(), help="運算設備: '0', 'cpu', 'mps'")
args = parser.parse_args()
try: # 1. 載入模型 print(f"--- 正在載入模型: {args.model} ---") model = YOLO(args.model)
# 2. 執行驗證 print(f"--- 開始驗證 | 設備: {args.device} | 資料集: {args.data} ---") results = model.val( data=args.data, imgsz=args.imgsz, device=args.device, split='val' )
# 3. 輸出結果摘要 print("\n" + "="*40) print("驗證結果摘要:") print(f"使用設備: {args.device.upper()}") print(f"Top-1 Accuracy: {results.top1:.4f}") print(f"Top-5 Accuracy: {results.top5:.4f}") print("="*40)
except Exception as e: print(f"執行過程中發生錯誤: {e}") sys.exit(1)
if __name__ == "__main__": run_validation()