#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Batch classification inference with an Ultralytics YOLO classification model.

Created on Thu Apr 2 13:35:19 2026

Usage:
    python 4_predict_cls.py --model best.pt --source my_images/test/test1
    python 4_predict_cls.py --device cpu

Input (full path):
    my_images/
    └── test/
        └── test1

Output (a sibling directory named "<input name>_pred"):
    my_images/
    └── test/
        └── test1_pred
"""
import argparse
from pathlib import Path

from ultralytics import YOLO
import torch

# =============================================================================
# ★ Defaults
# =============================================================================
SRC_MODEL = "/path/to/weights/best.pt"  # placeholder weights path; edit before use
DATASET = "/path/to/test"  # placeholder source directory; edit before use
DEVICE = None  # None means "auto-detect the device"
IMG_SIZE = 224  # inference image size
# =============================================================================
# ★ Automatic device detection
# =============================================================================
def auto_select_device(user_device=None):
    """Pick the device string to run inference on.

    If *user_device* is given (anything other than None) it is returned
    unchanged. Otherwise backends are probed in priority order and the
    first available one is used: CUDA ("0", i.e. device index 0), then
    MPS ("mps"), and finally "cpu" as the fallback.
    """
    if user_device is not None:
        # An explicit choice always wins over auto-detection.
        return user_device
    if torch.cuda.is_available():
        return "0"
    # MPS is only probed when CUDA is unavailable.
    return "mps" if torch.backends.mps.is_available() else "cpu"
# =============================================================================
# ★ Main inference flow
# =============================================================================
def run_inference(model_path, source_dir, imgsz, device):
    """Classify everything under *source_dir* and save predictions next to it.

    Predictions are written to a sibling directory named
    ``<source name>_pred`` (e.g. ``my_images/test/test1`` ->
    ``my_images/test/test1_pred``). ``model.predict`` is called with
    ``save=True`` and ``save_txt=True`` and ``exist_ok=True``, so an
    existing output directory is reused.

    Args:
        model_path: Weights file passed to ``YOLO(...)``.
        source_dir: Source path to classify (checked with ``Path.exists``).
        imgsz: Inference image size forwarded to ``model.predict``.
        device: Device string (e.g. "cpu", "0", "mps") or None to auto-detect.

    Returns:
        The number of prediction results produced, or None when
        *source_dir* does not exist.
    """
    # 0. Validate the source before the comparatively slow model load so a
    #    bad path fails fast.
    src_path = Path(source_dir)
    if not src_path.exists():
        print(f"找不到來源目錄: '{source_dir}'")
        return None

    # 1. Resolve the device (None -> auto-detect).
    device = auto_select_device(device)
    print(f"\n使用設備: {device}")

    # 2. Load the model.
    model = YOLO(model_path)

    # 3. Output lives next to the input: <parent>/<name>_pred.
    save_dir = src_path.parent / f"{src_path.name}_pred"
    print(f"來源: {src_path}")
    print(f"輸出: {save_dir}")

    # 4. Predict. stream=True yields results one at a time instead of keeping
    #    every Results object (and its image) in memory for the whole run.
    #    The generator must be fully consumed for all outputs to be saved.
    results = model.predict(
        source=str(src_path),
        imgsz=imgsz,
        device=device,
        save=True,
        save_txt=True,
        project=str(save_dir.parent),
        name=save_dir.name,
        exist_ok=True,
        stream=True,
    )
    n_processed = sum(1 for _ in results)

    print(f"\n完成!共處理 {n_processed} 張圖片")
    return n_processed
# =============================================================================
# ★ CLI
# =============================================================================
def main(argv=None):
    """Parse command-line arguments and run batch classification inference.

    Args:
        argv: Argument list to parse; None lets argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="YOLO-cls 批次分類推論 CLI")

    parser.add_argument("--model", type=str, default=SRC_MODEL)
    parser.add_argument("--source", type=str, default=DATASET)
    # Use the module-level default rather than repeating the literal here,
    # so editing IMG_SIZE actually changes the CLI default.
    parser.add_argument("--imgsz", type=int, default=IMG_SIZE)

    # DEVICE defaults to None -> the device is auto-detected.
    parser.add_argument(
        "--device",
        type=str,
        default=DEVICE,
        help="指定設備 (cpu / 0 / mps),預設自動",
    )

    args = parser.parse_args(argv)
    run_inference(args.model, args.source, args.imgsz, args.device)


if __name__ == "__main__":
    main()