#!/usr/bin/env python3# -*- coding: utf-8 -*-"""Created on Mon Apr 6 12:44:32 2026
Usage:# 使用預設值直接執行python predict_cls_sort.py
# 自訂參數python predict_cls_sort.py \ --model best.pt \ --source /data/test/images \ --device cpu \ --conf 0.7
Function:YOLO11-cls 批次分類推論 + 依類別自動分類存檔
Input 完整路徑: my_images/test/test1/
Output 結構: my_images/test/test1_pred/ ├── neutrophil/ │ ├── img001.jpg │ └── img002.jpg ├── eosinophil/ │ └── img003.jpg └── ..."""
import argparseimport shutilfrom pathlib import Pathfrom ultralytics import YOLO
# =============================================================================# ★ 預設值(CLI 未指定時使用)# =============================================================================SRC_MODEL = "/path/to/weights/best.pt"DATASET = "/path/to/bloodcells_200raw"DEVICE = "mps"IMG_SIZE = 224CONF_THRES = 0.0 # 信心度門檻,低於此值歸入 _uncertain 資料夾(0.0 = 不過濾)
# 支援的圖片副檔名IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp"}
def run_inference(model_path: str, source_dir: str, imgsz: int, device: str, conf: float) -> None:
# ------------------------------------------------------------------ # 1. 初始化模型 # ------------------------------------------------------------------ print(f"載入模型: {model_path}") model = YOLO(model_path)
# ------------------------------------------------------------------ # 2. 檢查來源目錄 # ------------------------------------------------------------------ src_path = Path(source_dir).resolve() if not src_path.exists(): print(f"[錯誤] 找不到來源目錄: '{src_path}'") return
image_files = sorted([ p for p in src_path.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES ])
if not image_files: print(f"[警告] 來源目錄內沒有支援的圖片檔案: {src_path}") return
print(f"來源目錄 : {src_path}") print(f"找到圖片 : {len(image_files)} 張\n")
# ------------------------------------------------------------------ # 3. 建立目的目錄 # ------------------------------------------------------------------ dest_root = src_path.parent / f"{src_path.name}_pred" dest_root.mkdir(parents=True, exist_ok=True) print(f"目的目錄 : {dest_root}\n")
# ------------------------------------------------------------------ # 4. 逐張推論 + 依類別分類複製 # ------------------------------------------------------------------ stats: dict[str, int] = {} # 各類別計數
for img_path in image_files: results = model.predict( source=str(img_path), imgsz=imgsz, device=device, verbose=False, )
result = results[0] probs = result.probs # Probs 物件 top1_idx = int(probs.top1) # 最高分類別索引 top1_conf = float(probs.top1conf) # 信心度 class_name = result.names[top1_idx]
# 低於門檻歸入 _uncertain if conf > 0.0 and top1_conf < conf: class_name = "_uncertain"
# 建立子目錄(不存在時自動建立) class_dir = dest_root / class_name class_dir.mkdir(exist_ok=True)
# 複製原始圖片到對應子目錄 dest_file = class_dir / img_path.name shutil.copy2(img_path, dest_file)
stats[class_name] = stats.get(class_name, 0) + 1
print(f" {img_path.name:<40} → {class_name:<20} (conf: {top1_conf:.4f})")
# ------------------------------------------------------------------ # 5. 統計摘要 # ------------------------------------------------------------------ print("\n" + "=" * 60) print("推論完成!分類統計:") print("-" * 60) total = 0 for cls, count in sorted(stats.items()): print(f" {cls:<30} : {count} 張") total += count print("-" * 60) print(f" {'合計':<30} : {total} 張") print("=" * 60) print(f"\n結果已儲存至: {dest_root}")
if __name__ == "__main__": parser = argparse.ArgumentParser( description="YOLO11-cls 批次分類推論,並依類別將圖片整理至子目錄", formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument( "--model", type=str, default=SRC_MODEL, help="模型權重路徑 (.pt)" ) parser.add_argument( "--source", type=str, default=DATASET, help="來源圖片目錄(會處理目錄內所有圖片)" ) parser.add_argument( "--imgsz", type=int, default=IMG_SIZE, help="推論影像尺寸" ) parser.add_argument( "--device", type=str, default=DEVICE, help="運算裝置,例如 'cpu'、'0'(GPU)、'mps'(Apple Silicon)" ) parser.add_argument( "--conf", type=float, default=CONF_THRES, help="信心度門檻(低於此值歸入 _uncertain,設 0 表示不過濾)" )
args = parser.parse_args()
run_inference( model_path = args.model, source_dir = args.source, imgsz = args.imgsz, device = args.device, conf = args.conf, )