#!/usr/bin/env python3
"""
global_optimizer.py

Core driver for block-wise fusion and global optimization of large-scale submaps (blocks) (refactored, modularized, with per-stage timing).

Usage:
  python src/global_optimizer.py --pcd_dir data/HK_PCD_DS --max_stage 5
"""

import argparse
from pathlib import Path

# Import Modules
from optimizer_modules.common import Config, OptimizerContext, TimingLogger
from optimizer_modules.loader import run_module0_metadata, run_module1_load_pcd
from optimizer_modules.matching import (
    run_module2_svd_coarse_alignment,
    run_module3a_distance_gating,
    run_module3b_overlap_check,
    run_module4_icp
)
from optimizer_modules.optimization import run_module5_gtsam
from optimizer_modules.exporter import save_pose_table_and_preview, export_global_map

def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        Namespace with the input/output paths, the highest pipeline stage to
        run (1-5), and per-stage tuning overrides (SVD, overlap, ICP, GTSAM).
        Path options are returned as plain strings.
    """
    parser = argparse.ArgumentParser()

    # Input/output locations (all plain strings; converted to Path by the caller).
    path_options = (
        ("--pcd_dir", "data/HK_PCD_DS"),
        ("--pose_root_dir", "data/HK_POSE"),
        ("--gnss_json", "data/HK_POSE/simulated_gnss.json"),
        ("--out_dir", "results"),
    )
    for flag, default_path in path_options:
        parser.add_argument(flag, type=str, default=default_path)

    # Highest stage to execute; only 1..5 are accepted.
    parser.add_argument("--max_stage", type=int, default=5, choices=list(range(1, 6)))

    # Per-stage tuning knobs as (flag, converter, default), in pipeline order.
    tuning_options = (
        ("--min_svd_pairs", int, 20),                 # SVD
        ("--gating_distance_m", float, 500.0),        # Overlap
        ("--overlap_ratio_threshold", float, 0.01),   # Overlap
        ("--icp_rmse_threshold", float, 0.5),         # ICP
        ("--gnss_sigma_xyz", float, 50.0),            # GTSAM
    )
    for flag, converter, default_value in tuning_options:
        parser.add_argument(flag, type=converter, default=default_value)

    return parser.parse_args()

def _timed_step(
    timer: TimingLogger,
    label: str,
    action,
    *,
    count=None,
    unit: str = "step",
) -> None:
    """Run one pipeline step between ``timer.start`` and ``timer.stop``.

    Args:
        timer: Logger that receives the ``start``/``stop`` calls.
        label: Key used for BOTH ``timer.start`` and ``timer.stop``. Passing it
            once keeps the two calls from drifting apart after an edit.
        action: Zero-argument callable that performs the step.
        count: Optional zero-argument callable returning the number of items
            the step handled. It is evaluated only after ``action`` returns,
            so it sees the context state the step produced. ``None`` means 1.
        unit: Unit name recorded alongside the count.
    """
    timer.start(label)
    action()
    timer.stop(label, count=1 if count is None else count(), unit=unit)


def main():
    """Parse CLI options, run pipeline stages 0..max_stage, and save a timing report.

    Stage 0 (metadata) always runs. Each later stage runs only when
    ``--max_stage`` is at least its number. Every step is timed, and the report
    is written to ``<out_dir>/timing_log.txt`` at the end.
    """
    args = parse_args()

    # Init Config & Context: start from Config() defaults, then apply CLI values.
    cfg = Config()
    cfg.pcd_dir = Path(args.pcd_dir)
    cfg.pose_root_dir = Path(args.pose_root_dir)
    cfg.gnss_json = Path(args.gnss_json)
    cfg.out_dir = Path(args.out_dir)
    cfg.max_stage = args.max_stage

    # Override per-stage tuning parameters with the CLI values.
    cfg.min_svd_pairs = args.min_svd_pairs
    cfg.gating_distance_m = args.gating_distance_m
    cfg.overlap_ratio_threshold = args.overlap_ratio_threshold
    cfg.icp_rmse_threshold = args.icp_rmse_threshold
    cfg.gnss_sigma_xyz = args.gnss_sigma_xyz

    # The timing report is always written into out_dir, even when no exporting
    # stage ran (e.g. --max_stage 1). Create the directory up front; exist_ok
    # makes this a no-op when it is already present.
    cfg.out_dir.mkdir(parents=True, exist_ok=True)

    ctx = OptimizerContext()
    timer = TimingLogger()

    print("\n================ Global Optimizer (Modular) ================")
    print(f"[Run] max_stage      = {cfg.max_stage}")
    print("==========================================================\n")

    # Module 0: metadata is loaded unconditionally.
    _timed_step(timer, "Module 0: Load Metadata", lambda: run_module0_metadata(cfg, ctx))

    # Module 1: load point clouds; the count is the number of blocks afterwards.
    if cfg.max_stage >= 1:
        _timed_step(
            timer, "Module 1: Load PCD",
            lambda: run_module1_load_pcd(cfg, ctx),
            count=lambda: len(ctx.block_ids), unit="block",
        )

    # Module 2: SVD coarse alignment, then save the initial pose table/preview.
    if cfg.max_stage >= 2:
        _timed_step(
            timer, "Module 2: SVD Alignment",
            lambda: run_module2_svd_coarse_alignment(cfg, ctx),
            count=lambda: len(ctx.block_ids), unit="block",
        )
        _timed_step(
            timer, "Aux: Save Init Preview",
            lambda: save_pose_table_and_preview(cfg, ctx, mode="initial"),
        )

    # Module 3: distance gating, then overlap check (counted per gated pair).
    if cfg.max_stage >= 3:
        _timed_step(
            timer, "Module 3a: Dist Gating",
            lambda: run_module3a_distance_gating(cfg, ctx),
            unit="all_pairs",
        )
        _timed_step(
            timer, "Module 3b: Overlap Check",
            lambda: run_module3b_overlap_check(cfg, ctx),
            count=lambda: len(ctx.pairs_gated), unit="pair",
        )

    # Module 4: ICP (counted per candidate edge present after the step).
    if cfg.max_stage >= 4:
        _timed_step(
            timer, "Module 4: ICP",
            lambda: run_module4_icp(cfg, ctx),
            count=lambda: len(ctx.candidate_edges), unit="edge",
        )

    # Module 5: GTSAM optimization, then export optimized poses and the map.
    if cfg.max_stage >= 5:
        _timed_step(
            timer, "Module 5: GTSAM Opt",
            lambda: run_module5_gtsam(cfg, ctx),
            unit="graph",
        )

        def _export_optimized():
            # Two exports share one timing entry, as before.
            save_pose_table_and_preview(cfg, ctx, mode="optimized")
            export_global_map(cfg, ctx, cfg.out_dir / "Global_Map_Optimized.pcd")

        _timed_step(
            timer, "Aux: Export Map", _export_optimized,
            count=lambda: len(ctx.block_ids), unit="block",
        )

    # Save Timing Report
    timer.save_report(cfg.out_dir / "timing_log.txt")

if __name__ == "__main__":
    main()
