#!/usr/bin/env python3
import argparse
import numpy as np
import open3d as o3d
from pathlib import Path
from scipy.spatial.transform import Rotation as R

def load_pose_txt(path):
    """
    Load a whitespace-delimited pose text file.

    Expected row format: ``timestamp x y z [qx qy qz qw]``. Only the first
    four columns are used; any trailing orientation columns are ignored.

    Args:
        path: Path (str or os.PathLike) of the text file.

    Returns:
        (timestamps, positions): ``timestamps`` has shape (N,) and
        ``positions`` has shape (N, 3) holding x, y, z.

    Raises:
        ValueError: if the file cannot be read or parsed, contains no rows,
            or has fewer than 4 columns. The underlying I/O or parse error is
            chained as ``__cause__``.
    """
    try:
        # ndmin=2 keeps a single-row file as (1, C) and a single-column file
        # as (N, 1); a manual reshape(1, -1) would mangle the latter.
        data = np.loadtxt(path, ndmin=2)
    except (OSError, ValueError) as e:
        raise ValueError(f"Failed to load {path}: {e}") from e

    if data.size == 0:
        raise ValueError(f"Failed to load {path}: no pose rows found")
    if data.shape[1] < 4:
        # Without this check data[:, 1:4] would silently return < 3 columns.
        raise ValueError(
            f"Failed to load {path}: expected at least 4 columns "
            f"(timestamp x y z), got {data.shape[1]}"
        )

    timestamps = data[:, 0]
    positions = data[:, 1:4]
    return timestamps, positions

def interp_positions(target_times, src_times, src_positions):
    """
    Linearly interpolate x, y, z positions at ``target_times``.

    ``np.interp`` requires strictly increasing sample points and silently
    returns meaningless values otherwise. The source samples are therefore
    stable-sorted by time first, so unsorted input is handled correctly.
    Already-sorted input gives exactly the same result as before.

    Args:
        target_times: (M,) times at which to evaluate positions.
        src_times: (N,) sample times of ``src_positions``.
        src_positions: (N, 3) positions; only the first three columns are used.

    Returns:
        (M, 3) array of interpolated positions. Targets outside
        [min(src_times), max(src_times)] are clamped to the first/last sample
        (``np.interp`` behaviour), not extrapolated.
    """
    src_times = np.asarray(src_times, dtype=float)
    src_positions = np.asarray(src_positions, dtype=float)

    # Stable sort keeps the original order of samples that share a timestamp.
    order = np.argsort(src_times, kind="stable")
    t_sorted = src_times[order]
    p_sorted = src_positions[order]

    columns = [np.interp(target_times, t_sorted, p_sorted[:, axis]) for axis in range(3)]
    return np.stack(columns, axis=1)

def compute_initial_alignment_svd(local_poses, global_gnss):
    """
    Estimate the rigid local->global transform with the Kabsch (SVD) method.

    Args:
        local_poses: (N, 3) positions in the local frame.
        global_gnss: (M, 3) positions in the global frame. Rows are paired by
            index with ``local_poses``; only the first min(N, M) pairs are used.

    Returns:
        4x4 homogeneous matrix T such that ``global ~= R @ local + t``
        (rotation + translation, no scale).

    Raises:
        ValueError: if fewer than 3 point pairs are available.
    """
    count = min(len(local_poses), len(global_gnss))
    if count < 3:
        raise ValueError("Need at least 3 points for SVD alignment.")
    src = local_poses[:count]
    dst = global_gnss[:count]

    src_mean = np.mean(src, axis=0)
    dst_mean = np.mean(dst, axis=0)
    src_centered = src - src_mean
    dst_centered = dst - dst_mean

    # Cross-covariance of the centred sets; R = V U^T maximises trace(R H).
    cross_cov = src_centered.T @ dst_centered
    u, _singular, vt = np.linalg.svd(cross_cov)
    rotation = vt.T @ u.T

    # det(R) == -1 means the optimum found is a reflection; negate the axis
    # belonging to the smallest singular value to get a proper rotation.
    if np.linalg.det(rotation) < 0:
        print("[Info] Reflection detected, fixing rotation matrix...")
        sign_fix = np.diag([1.0, 1.0, -1.0])
        rotation = vt.T @ sign_fix @ u.T

    translation = dst_mean - rotation @ src_mean

    transform = np.identity(4)
    transform[:3, :3] = rotation
    transform[:3, 3] = translation
    return transform

def main():
    """
    CLI entry point: transform a point cloud from the local frame into the
    global (GNSS) frame.

    Steps:
      1. Load the local-pose and global-GNSS trajectories.
      2. Keep the GNSS samples whose timestamps fall inside the local-pose
         time range, and interpolate local positions at those timestamps.
      3. Estimate the rigid local->global transform with SVD (Kabsch).
      4. Apply the transform to the input PCD and write the result.

    On any failure the process exits with a non-zero status. The error
    message is printed to stderr (via ``SystemExit``) instead of returning
    silently with status 0.
    """
    parser = argparse.ArgumentParser(description="Traj Alignment SVD: Align PCD using Local Pose and Global GNSS trajectory.")
    parser.add_argument("--local-pose", required=True, help="Path to Local Pose TXT (High Freq)")
    parser.add_argument("--global-gnss", required=True, help="Path to Global GNSS TXT (Low Freq)")
    parser.add_argument("--pcd-in", required=True, help="Input PCD file path")
    parser.add_argument("--pcd-out", required=True, help="Output PCD file path")

    args = parser.parse_args()

    # Fail fast: a missing input cloud makes all the trajectory work pointless.
    if not Path(args.pcd_in).exists():
        raise SystemExit(f"Error: Input PCD {args.pcd_in} not found.")

    # 1. Load Trajectories
    try:
        print(f"Loading Local Poses: {args.local_pose}")
        local_ts, local_xyz = load_pose_txt(args.local_pose)

        print(f"Loading Global GNSS: {args.global_gnss}")
        gnss_ts, gnss_xyz = load_pose_txt(args.global_gnss)
    except ValueError as e:
        raise SystemExit(f"Error: {e}") from e

    # 2. Sync Data (Interpolate Local Poses to GNSS times)
    print("Synchronizing Trajectories...")
    # Only use GNSS times inside the local-pose time range; np.interp would
    # otherwise clamp to the end samples and bias the fit.
    valid_mask = (gnss_ts >= local_ts.min()) & (gnss_ts <= local_ts.max())
    valid_gnss_ts = gnss_ts[valid_mask]
    valid_gnss_xyz = gnss_xyz[valid_mask]

    n_overlap = len(valid_gnss_ts)
    if n_overlap < 3:
        # The SVD fit needs at least 3 correspondences; report clearly
        # instead of crashing with a traceback inside the solver.
        raise SystemExit(
            f"Error: Only {n_overlap} overlapping points found; "
            "need at least 3 for alignment."
        )
    if n_overlap < 10:
        print(f"Warning: Only {n_overlap} overlapping points found. Alignment might be unstable.")

    local_xyz_interp = interp_positions(valid_gnss_ts, local_ts, local_xyz)

    # 3. Compute Transform
    print(f"Computing Alignment with {len(local_xyz_interp)} points...")
    T = compute_initial_alignment_svd(local_xyz_interp, valid_gnss_xyz)

    print("\nCalculated Transform (Local -> Global):")
    print(T)

    # Yaw from the rotation's first column; meaningful as a sanity check
    # when the rotation is dominated by a rotation about z.
    yaw_deg = np.degrees(np.arctan2(T[1, 0], T[0, 0]))
    print(f"\nEstimated Yaw Rotation: {yaw_deg:.2f} degrees")
    print(f"Estimated Translation: {T[:3, 3]}")

    # 4. Apply to PCD
    print(f"\nTransforming PCD: {args.pcd_in} -> {args.pcd_out}")
    pcd = o3d.io.read_point_cloud(args.pcd_in)
    if pcd.is_empty():
        raise SystemExit("Error: Empty PCD file.")

    pcd.transform(T)

    Path(args.pcd_out).parent.mkdir(parents=True, exist_ok=True)
    if not o3d.io.write_point_cloud(args.pcd_out, pcd):
        raise SystemExit("Error: Failed to write output PCD.")

    print("Done. Saved transformed PCD.")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import; main()
    # returns None on success, so the exit status stays 0.
    raise SystemExit(main())
