"""
Purpose: train a YOLO11 classification model with configurable data augmentation.
"""
import argparsefrom ultralytics import YOLOimport torchimport os
# =============================================================================
# Defaults (used when the corresponding CLI option is not given)
# =============================================================================
SRC_MODEL = "/path/to/model.pt"   # Model weights loaded with YOLO() (--model)
DATASET = "/path/to/dataset"      # Dataset directory (--data)
PROJECT_DIR = "runs_cls"          # Project folder for run outputs (--project)
EXP_NAME = "cls_aug"              # Experiment name (--name)
EPOCHS = 3                        # Training epochs (--epochs)
IMG_SIZE = 224                    # Image size (--imgsz)
BATCH_SIZE = 16                   # Batch size (--batch)
DEVICE = "mps"                    # Device: 0,1,... for CUDA, 'cpu' or 'mps' for Apple GPU
WORKERS = 4                       # Dataloader workers (--workers)
LR_RATE = 0.01                    # Initial learning rate (--lr0)
PATIENCE = 10                     # Early stopping patience (--patience)
# Accepted spellings for boolean CLI values (compared case-insensitively).
_TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"0", "false", "f", "no", "n", "off"})


def str2bool(value):
    """Convert a CLI string such as ``"true"``, ``"no"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because every non-empty string is truthy. This converter parses the text
    instead, so a flag can actually be switched off from the command line.

    Args:
        value: The raw option value (normally a ``str``), or a ``bool`` that
            is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean
            spelling; argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean (true/false, yes/no, 1/0), got {value!r}"
    )


def parse_args(argv=None):
    """Build the training CLI and parse its arguments.

    Defaults come from the module-level constants (DATASET, SRC_MODEL, ...).

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace holding the basic training options and the
        data-augmentation options.
    """
    parser = argparse.ArgumentParser(
        description="YOLO11 Classification Training CLI with AutoAugmentation"
    )

    # Basic parameters
    parser.add_argument("--data", type=str, default=DATASET, help="Path to dataset directory")
    parser.add_argument("--model", type=str, default=SRC_MODEL, help="Model path")
    parser.add_argument("--project", type=str, default=PROJECT_DIR, help="Project folder")
    parser.add_argument("--name", type=str, default=EXP_NAME, help="Experiment name")
    parser.add_argument("--epochs", type=int, default=EPOCHS, help="Training epochs")
    parser.add_argument("--imgsz", type=int, default=IMG_SIZE, help="Image size")
    parser.add_argument("--batch", type=int, default=BATCH_SIZE, help="Batch size")
    parser.add_argument("--device", type=str, default=DEVICE,
                        help="Device: 0,1,... for CUDA, 'cpu' or 'mps'")
    parser.add_argument("--workers", type=int, default=WORKERS, help="Dataloader workers")
    parser.add_argument("--lr0", type=float, default=LR_RATE, help="Initial learning rate")
    parser.add_argument("--patience", type=int, default=PATIENCE, help="Early stopping patience")

    # Data augmentation parameters
    # type=str2bool (not bool) so "--augment false" really disables it;
    # nargs="?" + const=True lets a bare "--augment" mean True.
    parser.add_argument("--augment", type=str2bool, nargs="?", const=True, default=True,
                        help="Enable data augmentation (true/false; bare --augment means true)")
    parser.add_argument("--degrees", type=float, default=10.0, help="Random rotation ± degrees")
    parser.add_argument("--translate", type=float, default=0.1, help="Random translation fraction")
    parser.add_argument("--scale", type=float, default=0.1, help="Random scaling fraction")
    parser.add_argument("--shear", type=float, default=2.0, help="Random shear angle")
    parser.add_argument("--perspective", type=float, default=0.0, help="Random perspective fraction")
    parser.add_argument("--hsv_h", type=float, default=0.015, help="Hue augmentation")
    parser.add_argument("--hsv_s", type=float, default=0.7, help="Saturation augmentation")
    parser.add_argument("--hsv_v", type=float, default=0.4, help="Value augmentation")

    return parser.parse_args(argv)
def resolve_device(requested):
    """Map the ``--device`` option to a value for ``model.train(device=...)``.

    Accepted forms (case-insensitive, surrounding whitespace ignored):
      * ``"mps"``: ``"mps"`` if ``torch.backends.mps.is_available()``,
        otherwise ``"cpu"``.
      * ``"cpu"``: ``"cpu"``.
      * a single GPU index such as ``"0"``: that index as an ``int``, if
        ``torch.cuda.is_available()``.
      * a comma-separated index list such as ``"0,1"``: the normalised
        string ``"0,1"``, if CUDA is available.
      * ``"cuda"`` / ``"cuda:N"``: treated like ``"0"`` / ``"N"``.

    Anything unavailable or unrecognised falls back to ``"cpu"`` and prints a
    warning (the previous code fell back silently, and also sent the
    documented ``"0,1"`` form to CPU because ``"0,1".isdigit()`` is False).

    Args:
        requested: The raw ``--device`` string.

    Returns:
        ``"mps"``, ``"cpu"``, an ``int`` GPU index, or a ``"i,j,..."`` string.
    """
    spec = str(requested).strip().lower()

    if spec == "mps":
        if torch.backends.mps.is_available():
            return "mps"
        print("[warn] MPS requested but not available; falling back to CPU")
        return "cpu"

    if spec == "cpu":
        return "cpu"

    # "cuda" -> "0", "cuda:1" -> "1"
    if spec.startswith("cuda"):
        spec = spec[len("cuda"):].lstrip(":") or "0"

    ids = [part.strip() for part in spec.split(",")]
    if all(part.isdigit() for part in ids):
        if not torch.cuda.is_available():
            print(f"[warn] CUDA device {requested!r} requested but CUDA is not available; "
                  "falling back to CPU")
            return "cpu"
        # A single index stays an int (as before); several are passed as "0,1".
        return int(ids[0]) if len(ids) == 1 else ",".join(ids)

    print(f"[warn] Unrecognised device {requested!r}; falling back to CPU")
    return "cpu"


def main():
    """Parse CLI options, resolve the device, print the config and train."""
    args = parse_args()

    # Select device
    device = resolve_device(args.device)

    print("===== YOLO11-CLS Training Config =====")
    for key, value in vars(args).items():
        print(f"{key}: {value}")
    print(f"Using device: {device}")
    print("======================================")

    # Load model
    model = YOLO(args.model)

    # Start training.
    # NOTE(review): Ultralytics documents `augment` as a prediction-time
    # (TTA) setting; confirm it affects classification training as intended.
    model.train(
        data=args.data,
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        device=device,
        workers=args.workers,
        lr0=args.lr0,
        patience=args.patience,
        project=args.project,
        name=args.name,
        augment=args.augment,
        degrees=args.degrees,
        translate=args.translate,
        scale=args.scale,
        shear=args.shear,
        perspective=args.perspective,
        hsv_h=args.hsv_h,
        hsv_s=args.hsv_s,
        hsv_v=args.hsv_v,
    )


if __name__ == "__main__":
    main()