import itertools
import math

import numpy as np
import open3d as o3d

from .common import (
    Config, OptimizerContext,
    load_pose_txt, interp_positions, compute_svd_alignment,
    make_T, transform_points, yaw_from_R, euclidean,
    estimate_normals, voxel_keys
)

def run_module2_svd_coarse_alignment(cfg: Config, ctx: OptimizerContext):
    """Estimate a coarse local->global pose per block from its trajectories.

    For every ``bid`` in ``ctx.block_ids`` the local pose trajectory
    (``ctx.pose_paths[bid]``) is aligned to the GNSS keyframe trajectory
    (``ctx.gnss_kf_paths[bid]``) with ``compute_svd_alignment``.  The
    resulting 4x4 transform maps local positions onto GNSS positions and is
    stored in ``ctx.svd_global_poses[bid]``.  That dict is cleared first.

    Fallbacks:
      * A file is missing, or either trajectory has fewer than 2 samples:
        ``_fallback_pose(ctx, bid)`` is used.
      * Fewer than ``cfg.min_svd_pairs`` GNSS samples fall inside the local
        trajectory's time span: a translation-only pose between the mean
        positions of the two trajectories is used.

    Side effects:
        Successfully loaded trajectories are cached as ``(timestamps, xyz)``
        tuples (as returned by ``load_pose_txt``) in ``ctx.local_traj`` and
        ``ctx.gnss_traj``.
    """
    print("\n=== [Module 2] Trajectory SVD Coarse Alignment ===")
    ctx.svd_global_poses.clear()

    for bid in ctx.block_ids:
        pose_file = ctx.pose_paths[bid]
        gnss_file = ctx.gnss_kf_paths[bid]

        print(f"\n[SVD] {bid}")
        if not pose_file.exists() or not gnss_file.exists():
            print("  [Skip] missing files.")
            ctx.svd_global_poses[bid] = _fallback_pose(ctx, bid)
            continue

        l_ts, l_xyz = load_pose_txt(pose_file)
        g_ts, g_xyz = load_pose_txt(gnss_file)

        if len(l_ts) < 2 or len(g_ts) < 2:
            print("  [Warn] empty/too short trajectories, fallback.")
            ctx.svd_global_poses[bid] = _fallback_pose(ctx, bid)
            continue

        ctx.local_traj[bid] = (l_ts, l_xyz)
        ctx.gnss_traj[bid] = (g_ts, g_xyz)

        # Keep only GNSS samples whose timestamps lie inside the local
        # trajectory's time span.
        in_span = (g_ts >= l_ts.min()) & (g_ts <= l_ts.max())
        g_ts_v = g_ts[in_span]
        g_xyz_v = g_xyz[in_span]

        if len(g_ts_v) < cfg.min_svd_pairs:
            print(f"  [Warn] < {cfg.min_svd_pairs} pairs, fallback centroid.")
            # Translation-only alignment of the mean positions.
            # NOTE(review): this uses the full, un-synchronized trajectories,
            # not just the time-overlapping part.
            c_l = l_xyz.mean(axis=0)
            c_g = g_xyz.mean(axis=0)
            ctx.svd_global_poses[bid] = make_T(np.eye(3), c_g - c_l)
            continue

        # Local positions interpolated at the kept GNSS timestamps, giving one
        # local/GNSS correspondence per GNSS sample.
        l_xyz_i = interp_positions(g_ts_v, l_ts, l_xyz)

        # Evenly subsample the correspondences to bound the SVD problem size.
        if len(g_ts_v) > cfg.max_svd_pairs:
            sel = np.linspace(0, len(g_ts_v) - 1, cfg.max_svd_pairs).astype(int)
            l_xyz_i = l_xyz_i[sel]
            g_xyz_v = g_xyz_v[sel]

        T = compute_svd_alignment(l_xyz_i, g_xyz_v)

        # Diagnostics only: residuals on the pairs used for the fit, and the
        # yaw of the recovered rotation.
        res = np.linalg.norm(transform_points(l_xyz_i, T) - g_xyz_v, axis=1)
        rmse = math.sqrt(float(np.mean(res**2)))
        yaw_deg = math.degrees(yaw_from_R(T[:3, :3]))

        print(f"  [OK] pairs={len(l_xyz_i)}, rmse={rmse:.3f}m, yaw={yaw_deg:.2f}deg")
        ctx.svd_global_poses[bid] = T

def run_module3a_distance_gating(cfg: Config, ctx: OptimizerContext):
    """Keep block pairs whose estimated global centroids are close together.

    Each block's centroid is estimated in one of two ways:
      * If both ``ctx.svd_global_poses[bid]`` and ``ctx.pcd_ds[bid]`` exist,
        the center of the downsampled cloud is mapped through the coarse SVD
        pose.
      * Otherwise the stored ``ctx.gnss_info[bid]["centroid_xyz"]`` is used.

    Every unordered pair ``(bi, bj)`` is considered, with ``bi`` before ``bj``
    in ``ctx.block_ids`` order.  A pair whose centroid distance is strictly
    less than ``cfg.gating_distance_m`` is appended to ``ctx.pairs_gated`` as
    ``(bi, bj, distance)``.  The list is cleared first.
    """
    print("\n=== [Module 3a] Distance Gating (Centroid) ===")
    ctx.pairs_gated.clear()

    centroids = {}
    for bid in ctx.block_ids:
        if bid in ctx.svd_global_poses and bid in ctx.pcd_ds:
            # Prefer the cloud center placed by the SVD pose when available.
            T = ctx.svd_global_poses[bid]
            c_local = ctx.pcd_ds[bid].get_center()
            centroids[bid] = (T[:3, :3] @ c_local) + T[:3, 3]
        else:
            centroids[bid] = np.array(ctx.gnss_info[bid]["centroid_xyz"], dtype=np.float64)

    for bi, bj in itertools.combinations(ctx.block_ids, 2):
        d = euclidean(centroids[bi], centroids[bj])
        if d < cfg.gating_distance_m:
            ctx.pairs_gated.append((bi, bj, d))
            print(f"  [Pass] {bi}<->{bj} dist={d:.2f}m")

    print(f"[Done] Gated pairs: {len(ctx.pairs_gated)}")

def run_module3b_overlap_check(cfg: Config, ctx: OptimizerContext):
    """Filter gated pairs by the voxel overlap of their clouds.

    For each ``(bi, bj, _)`` in ``ctx.pairs_gated`` whose downsampled clouds
    both exist in ``ctx.pcd_ds``:

    1. At most ``cfg.overlap_max_points`` points of each cloud are sampled
       with ``_sample``.
    2. The sampled points are transformed by the block's ``"initial"`` pose
       from ``ctx.get_pose``.
    3. The points are quantized with ``voxel_keys`` at
       ``cfg.overlap_voxel_size``.
    4. The overlap ratio is computed as the number of shared keys divided by
       the smaller key count (clamped to at least 1).

    Pairs with ``ratio >= cfg.overlap_ratio_threshold`` are appended to
    ``ctx.candidate_edges`` as ``(bi, bj, ratio)``.  The list is cleared
    first.
    """
    print("\n=== [Module 3b] Geometric Overlap Check (Voxel) ===")
    ctx.candidate_edges.clear()

    if not ctx.pairs_gated:
        print("[Warn] No gated pairs.")
        return

    clouds = ctx.pcd_ds
    voxel = cfg.overlap_voxel_size
    cap = cfg.overlap_max_points

    for bi, bj, _dist in ctx.pairs_gated:
        if not (bi in clouds and bj in clouds):
            continue

        pose_i = ctx.get_pose(bi, mode="initial")
        pose_j = ctx.get_pose(bj, mode="initial")

        # Sampling order (i, then j) is kept because _sample draws from a
        # random stream.
        sampled_i = _sample(np.asarray(clouds[bi].points), cap)
        sampled_j = _sample(np.asarray(clouds[bj].points), cap)

        keys_i = voxel_keys(transform_points(sampled_i, pose_i), voxel)
        keys_j = voxel_keys(transform_points(sampled_j, pose_j), voxel)

        shared = int(np.intersect1d(keys_i, keys_j).size)
        smaller = max(1, min(keys_i.size, keys_j.size))
        ratio = float(shared / smaller)

        keep = ratio >= cfg.overlap_ratio_threshold
        verdict = "KEEP" if keep else "DROP"
        print(f"[Check] {bi}<->{bj} ratio={ratio:.4f} => {verdict}")

        if keep:
            ctx.candidate_edges.append((bi, bj, ratio))

    print(f"[Done] Candidate edges: {len(ctx.candidate_edges)}")

def run_module4_icp(cfg: Config, ctx: OptimizerContext):
    """Refine candidate edges with point-to-plane ICP and keep the good ones.

    For every ``(bi, bj, ratio)`` in ``ctx.candidate_edges``:
      * The cloud of ``bi`` (source) is registered onto the cloud of ``bj``
        (target).
      * Registration starts from the relative guess ``inv(Tj) @ Ti`` built
        from the ``"initial"`` poses.
      * An edge is accepted when ``inlier_rmse <= cfg.icp_rmse_threshold``
        and ``fitness >= cfg.icp_fitness_threshold``.

    Accepted edges are stored in ``ctx.icp_edges[(bi, bj)]`` with keys
    ``"T_ij"``, ``"fitness"``, ``"rmse"`` and ``"overlap_ratio"``.
    ``ctx.icp_edges`` is cleared first.

    Edges that reference a block without a downsampled cloud in
    ``ctx.pcd_ds`` are skipped instead of raising ``KeyError``.
    """
    print("\n=== [Module 4] ICP Fine Matching ===")
    ctx.icp_edges.clear()

    if not ctx.candidate_edges:
        print("[Warn] No candidate edges.")
        return

    for (bi, bj, ratio) in ctx.candidate_edges:
        print(f"\n[ICP] {bi} -> {bj} (overlap={ratio:.4f})")

        # Same guard as the overlap check: the edge list may be stale or may
        # come from elsewhere, and both clouds are required below.
        if bi not in ctx.pcd_ds or bj not in ctx.pcd_ds:
            print("  -> SKIP (missing downsampled cloud)")
            continue

        # Work on copies so whatever estimate_normals writes stays local to
        # this pair and does not leak into ctx.pcd_ds.
        src = o3d.geometry.PointCloud(ctx.pcd_ds[bi])
        tgt = o3d.geometry.PointCloud(ctx.pcd_ds[bj])

        Ti = ctx.get_pose(bi, mode="initial")
        Tj = ctx.get_pose(bj, mode="initial")

        # Initial relative guess that maps block-i points into block-j's
        # frame: p_j = T_rel @ p_i, where T_rel = inv(Tj) @ Ti.
        T_rel_init = np.linalg.inv(Tj) @ Ti

        # Point-to-plane ICP requires normals on the target cloud.
        estimate_normals(src, cfg.voxel_size)
        estimate_normals(tgt, cfg.voxel_size)

        reg = o3d.pipelines.registration.registration_icp(
            src, tgt,
            cfg.icp_max_corr_dist,
            T_rel_init,
            o3d.pipelines.registration.TransformationEstimationPointToPlane(),
            o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=cfg.icp_max_iter),
        )

        fitness = float(reg.fitness)
        rmse = float(reg.inlier_rmse)
        valid = (rmse <= cfg.icp_rmse_threshold) and (fitness >= cfg.icp_fitness_threshold)

        print(f"  -> {'VALID' if valid else 'FAIL'} (fit={fitness:.4f}, rmse={rmse:.4f})")

        if valid:
            ctx.icp_edges[(bi, bj)] = {
                "T_ij": reg.transformation,
                "fitness": fitness,
                "rmse": rmse,
                "overlap_ratio": ratio,
            }

    print(f"[Done] Valid ICP edges: {len(ctx.icp_edges)}")


def _fallback_pose(ctx, bid):
    """Return a pose with no rotation, placed at the block's GNSS origin.

    The rotation passed to ``make_T`` is the 3x3 identity.  The translation
    is ``ctx.gnss_info[bid]["origin_xyz"]`` converted to a float64 array.
    """
    translation = np.array(ctx.gnss_info[bid]["origin_xyz"], dtype=np.float64)
    rotation = np.eye(3)
    return make_T(rotation, translation)

def _sample(pts, max_n):
    if pts.shape[0] <= max_n: return pts
    sel = np.random.choice(pts.shape[0], max_n, replace=False)
    return pts[sel]
