import numpy as np
import math
from .common import Config, OptimizerContext

try:
    import gtsam
    from gtsam import Pose3, Rot3, Point3
    GTSAM_AVAILABLE = True
except ImportError:
    GTSAM_AVAILABLE = False

def _pose3_from_matrix(T) -> "Pose3":
    """Convert a 4x4 homogeneous transform (indexable like a numpy array) to a gtsam Pose3."""
    return Pose3(Rot3(T[:3, :3]), Point3(T[0, 3], T[1, 3], T[2, 3]))


def _diag_noise(rot_sigma_deg: float, xyz_sigma: float):
    """Build a 6-DoF diagonal Gaussian noise model for a Pose3 factor.

    GTSAM orders Pose3 tangent coordinates rotation first, then translation,
    so the sigmas are [rot, rot, rot, xyz, xyz, xyz] (rotation in radians).
    """
    sigmas = np.array(
        [math.radians(rot_sigma_deg)] * 3 + [xyz_sigma] * 3, dtype=np.float64
    )
    return gtsam.noiseModel.Diagonal.Sigmas(sigmas)


def run_module5_gtsam(cfg: Config, ctx: OptimizerContext):
    """Refine all block poses jointly with a GTSAM pose graph.

    The graph has one Pose3 variable per block. Each variable gets a prior at
    its initial pose (noise from ``cfg.gnss_sigma_*``) and ICP edges add
    relative-pose factors (noise from ``cfg.icp_sigma_*``). The graph is solved
    with Levenberg-Marquardt.

    Side effects:
        On success, ``ctx.optimized_poses`` is cleared and refilled with
        ``{block_id: 4x4 pose matrix}``. If GTSAM is not installed, or the
        optimization raises, ``ctx.optimized_poses`` is set to a copy of
        ``ctx.svd_global_poses`` instead.
    """
    print("\n=== [Module 5] Global Optimization (GTSAM) ===")
    if not GTSAM_AVAILABLE:
        print("[Error] GTSAM not available. Skipping optimization.")
        ctx.optimized_poses = ctx.svd_global_poses.copy()
        return

    graph = gtsam.NonlinearFactorGraph()
    initial = gtsam.Values()

    # A large GNSS sigma makes the prior weak, so the ICP edges dominate.
    prior_noise = _diag_noise(cfg.gnss_sigma_rot_deg, cfg.gnss_sigma_xyz)
    icp_noise = _diag_noise(cfg.icp_sigma_rot_deg, cfg.icp_sigma_xyz)

    # Keys come from each block's position, not from parsing its id. Ids that
    # share a numeric suffix (e.g. "a_1" / "b_1") or have none therefore
    # cannot collide or fail. dict.fromkeys drops duplicate ids, keeping order.
    keys = {
        bid: gtsam.symbol("x", idx)
        for idx, bid in enumerate(dict.fromkeys(ctx.block_ids))
    }

    # 1. Priors (Weak GNSS/SVD)
    print(f"[Info] Add priors (sigma_xyz={cfg.gnss_sigma_xyz}m)...")
    for bid, k in keys.items():
        # The initial (SVD) pose is both the prior mean and the starting estimate.
        pose = _pose3_from_matrix(ctx.get_pose(bid, mode="initial"))
        graph.add(gtsam.PriorFactorPose3(k, pose, prior_noise))
        initial.insert(k, pose)

    # 2. Between Factors (ICP)
    print(f"[Info] Add ICP factors ({len(ctx.icp_edges)})...")
    skipped = 0
    for (bi, bj), info in ctx.icp_edges.items():
        ki = keys.get(bi)
        kj = keys.get(bj)
        if ki is None or kj is None:
            # A factor on a key with no initial value makes GTSAM reject the
            # whole graph, so drop the edge and keep the rest.
            skipped += 1
            continue

        # ICP gives T_ij with p_j = T_ij * p_i. BetweenFactor(kj, ki, z)
        # compares z with Pose_j^{-1} * Pose_i, which equals T_ij when poses
        # map block -> world.
        # NOTE(review): that frame convention is assumed; confirm against get_pose.
        meas = _pose3_from_matrix(info["T_ij"])
        graph.add(gtsam.BetweenFactorPose3(kj, ki, meas, icp_noise))
    if skipped:
        print(f"[Warn] Skipped {skipped} ICP edge(s) referencing unknown blocks.")

    # 3. Optimize
    print("Optimizing...")
    try:
        err_before = graph.error(initial)
        print(f"  Error Before: {err_before:.6f}")

        params = gtsam.LevenbergMarquardtParams()
        optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial, params)
        result = optimizer.optimize()

        err_after = graph.error(result)
        print(f"  Error After:  {err_after:.6f}")

        # 4. Save: clear and refill the existing dict rather than rebinding it.
        ctx.optimized_poses.clear()
        for bid, k in keys.items():
            if result.exists(k):
                ctx.optimized_poses[bid] = result.atPose3(k).matrix()
    except Exception as e:
        # Best effort: any GTSAM failure falls back to the unoptimized SVD poses.
        print(f"[Error] GTSAM Optimization failed: {e}")
        ctx.optimized_poses = ctx.svd_global_poses.copy()

    print("[Done] Optimization finished.")
