#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import numpy as np
from pathlib import Path
from typing import Iterable, List, Tuple

# One trajectory sample. The timestamp is kept as the original text token so it
# can be written back out unchanged; t_xyz holds 3 floats, q_xyzw holds 4 floats.
Pose = Tuple[str, np.ndarray, np.ndarray]  # (timestamp_str, t_xyz, q_xyzw)

def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the stitching tool.

    Returns:
        The namespace produced by ``argparse`` with the attributes
        ``input_dir``, ``input_files``, ``pattern``, ``stitched_output``,
        ``json_output``, ``gnss_interval``, ``gnss_noise`` and ``decimals``.
    """
    # Every option is declared as (flag, add_argument keyword arguments).
    option_table = (
        (
            "--input-dir",
            {
                "default": "data/HK_POSE",
                "help": "Directory containing pose txt files (default: data/HK_POSE).",
            },
        ),
        (
            "--input-files",
            {
                "nargs": "+",
                "help": "Ordered list of pose txt files. Relative paths are resolved under --input-dir.",
            },
        ),
        (
            "--pattern",
            {
                "default": "*_fixed.txt",
                "help": "Glob pattern for input files if --input-files is not set (default: *_fixed.txt).",
            },
        ),
        (
            "--stitched-output",
            {
                "help": "Output path for stitched trajectory txt (default: <input-dir>/stitched_path.txt).",
            },
        ),
        (
            "--json-output",
            {
                "help": "Output path for simulated GNSS summary JSON (default: <input-dir>/simulated_gnss.json).",
            },
        ),
        (
            "--gnss-interval",
            {
                "type": float,
                "default": 1.0,
                "help": "Time interval for simulated GNSS sampling in seconds (default: 1.0).",
            },
        ),
        (
            "--gnss-noise",
            {
                "type": float,
                "default": 0.5,
                "help": "Std dev of simulated GNSS noise in meters (default: 0.5).",
            },
        ),
        (
            "--decimals",
            {
                "type": int,
                "default": 6,
                "help": "Decimal places for output xyz/quaternion values (default: 6).",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description=(
            "Stitch FAST-LIVO2 pose blocks into a single global trajectory, "
            "generate per-block global Mock GNSS files, and export a simulation summary JSON."
        )
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def resolve_paths(args: argparse.Namespace) -> Tuple[List[Path], Path, Path]:
    """Turn the parsed options into concrete input and output paths.

    Args:
        args: Namespace providing ``input_dir``, ``input_files``, ``pattern``,
            ``stitched_output`` and ``json_output``.

    Returns:
        ``(input_paths, stitched_output, json_output)``. Explicit input files
        keep the order given; relative ones are placed under ``input_dir``.
        Without explicit files, the sorted matches of ``pattern`` inside
        ``input_dir`` are used. Unset outputs default to files in ``input_dir``.
    """
    root = Path(args.input_dir)

    if args.input_files:
        candidates = [Path(entry) for entry in args.input_files]
        inputs = [p if p.is_absolute() else root / p for p in candidates]
    else:
        # No explicit list: fall back to globbing the input directory
        inputs = sorted(root.glob(args.pattern))

    stitched_path = root / "stitched_path.txt"
    if args.stitched_output:
        stitched_path = Path(args.stitched_output)

    summary_path = root / "simulated_gnss.json"
    if args.json_output:
        summary_path = Path(args.json_output)

    return inputs, stitched_path, summary_path


def read_pose_file(path: Path) -> List[Pose]:
    """Read whitespace-separated pose lines from a text file.

    Each data line is expected to look like ``timestamp x y z qx qy qz qw``;
    any tokens after the eighth are ignored. Blank lines, lines starting with
    ``#`` (header/comment lines) and lines with fewer than eight tokens are
    skipped.

    Args:
        path: File to read (decoded as UTF-8).

    Returns:
        Poses in file order. The timestamp is kept as the original string
        token; translation and quaternion are float64 arrays of length 3 and 4.

    Raises:
        ValueError: If a data line has a non-numeric position or quaternion
            field. The message names the file and line number.
    """
    poses: List[Pose] = []
    with path.open("r", encoding="utf-8") as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            line = raw_line.strip()
            # A '#' header such as "# timestamp tx ty tz qx qy qz qw" has enough
            # tokens to pass the length check, so it must be filtered explicitly.
            if not line or line.startswith("#"):
                continue
            parts = line.split()
            if len(parts) < 8:
                continue
            try:
                values = [float(token) for token in parts[1:8]]
            except ValueError as err:
                raise ValueError(
                    f"{path}:{line_no}: non-numeric pose field in {line!r}"
                ) from err
            t = np.array(values[:3], dtype=np.float64)
            q = np.array(values[3:], dtype=np.float64)
            poses.append((parts[0], t, q))
    return poses


def quat_to_rot(q: np.ndarray) -> np.ndarray:
    """Convert a quaternion in ``(x, y, z, w)`` order to a 3x3 rotation matrix.

    The quaternion is normalised first, so non-unit input is accepted. A
    zero-norm quaternion yields the identity matrix.
    """
    quat = q.astype(np.float64)
    length = np.linalg.norm(quat)
    if length == 0.0:
        return np.eye(3)

    qx, qy, qz, qw = quat / length

    # Pairwise products shared by several matrix entries
    sq_x, sq_y, sq_z = qx * qx, qy * qy, qz * qz
    x_y, x_z, y_z = qx * qy, qx * qz, qy * qz
    w_x, w_y, w_z = qw * qx, qw * qy, qw * qz

    rot = np.empty((3, 3), dtype=np.float64)
    rot[0, 0] = 1.0 - 2.0 * (sq_y + sq_z)
    rot[0, 1] = 2.0 * (x_y - w_z)
    rot[0, 2] = 2.0 * (x_z + w_y)
    rot[1, 0] = 2.0 * (x_y + w_z)
    rot[1, 1] = 1.0 - 2.0 * (sq_x + sq_z)
    rot[1, 2] = 2.0 * (y_z - w_x)
    rot[2, 0] = 2.0 * (x_z - w_y)
    rot[2, 1] = 2.0 * (y_z + w_x)
    rot[2, 2] = 1.0 - 2.0 * (sq_x + sq_y)
    return rot


def rot_to_quat(R: np.ndarray) -> np.ndarray:
    """Convert a 3x3 rotation matrix to a normalised ``(x, y, z, w)`` quaternion.

    The branch is picked from the trace and the diagonal entries so that the
    divisor ``s`` is built from the dominant term.
    """
    d0, d1, d2 = R[0, 0], R[1, 1], R[2, 2]
    tr = d0 + d1 + d2

    if tr > 0.0:
        s = 0.5 / np.sqrt(tr + 1.0)
        components = (
            (R[2, 1] - R[1, 2]) * s,
            (R[0, 2] - R[2, 0]) * s,
            (R[1, 0] - R[0, 1]) * s,
            0.25 / s,
        )
    elif d0 > d1 and d0 > d2:
        # x is the dominant component
        s = 2.0 * np.sqrt(1.0 + d0 - d1 - d2)
        components = (
            0.25 * s,
            (R[0, 1] + R[1, 0]) / s,
            (R[0, 2] + R[2, 0]) / s,
            (R[2, 1] - R[1, 2]) / s,
        )
    elif d1 > d2:
        # y is the dominant component
        s = 2.0 * np.sqrt(1.0 + d1 - d0 - d2)
        components = (
            (R[0, 1] + R[1, 0]) / s,
            0.25 * s,
            (R[1, 2] + R[2, 1]) / s,
            (R[0, 2] - R[2, 0]) / s,
        )
    else:
        # z is the dominant component
        s = 2.0 * np.sqrt(1.0 + d2 - d0 - d1)
        components = (
            (R[0, 2] + R[2, 0]) / s,
            (R[1, 2] + R[2, 1]) / s,
            0.25 * s,
            (R[1, 0] - R[0, 1]) / s,
        )

    quat = np.array(components, dtype=np.float64)
    length = np.linalg.norm(quat)
    return quat / length if length > 0 else quat


def pose_to_matrix(t: np.ndarray, q: np.ndarray) -> np.ndarray:
    """Assemble a 4x4 homogeneous transform from translation ``t`` and quaternion ``q``.

    The rotation block comes from ``quat_to_rot(q)``; the bottom row stays
    ``[0, 0, 0, 1]``.
    """
    rotation = quat_to_rot(q)
    transform = np.identity(4, dtype=np.float64)
    transform[0:3, 0:3] = rotation
    transform[0:3, 3] = t
    return transform


def matrix_to_pose(T: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Split a 4x4 homogeneous transform into ``(translation, quaternion)``.

    The translation is returned as a copy, so it does not alias ``T``; the
    quaternion is obtained from the upper-left 3x3 block via ``rot_to_quat``.
    """
    translation = np.array(T[:3, 3], copy=True)
    rotation_block = T[:3, :3]
    return translation, rot_to_quat(rotation_block)


def generate_gnss_data(
    global_poses: List[Pose],
    interval: float,
    noise_std: float
) -> Tuple[np.ndarray, np.ndarray]: # (times, noisy_positions)
    """Downsample poses in time and add Gaussian noise to their positions.

    A pose is kept when its timestamp is at least ``interval`` later than the
    previously kept one; the first pose is always a candidate. Noise is drawn
    from the global ``np.random`` state with standard deviation ``noise_std``.

    Returns:
        ``(times, noisy_positions)``; two empty arrays when nothing is kept.
    """
    if not global_poses:
        return np.array([]), np.array([])

    kept = []
    last_kept_time = -np.inf
    for stamp_text, position, _quat in global_poses:
        stamp = float(stamp_text)
        if not (stamp - last_kept_time >= interval):
            continue
        kept.append((stamp, position))
        last_kept_time = stamp

    if not kept:
        return np.array([]), np.array([])

    kept_times, kept_positions = zip(*kept)
    clean = np.array(kept_positions)
    noisy = clean + np.random.normal(0, noise_std, clean.shape)
    return np.array(kept_times), noisy


def _write_gnss_file(
    gnss_path: Path, times: Iterable[float], positions: Iterable[np.ndarray]
) -> None:
    """Write one ``time x y z`` line per simulated GNSS fix to ``gnss_path``."""
    with gnss_path.open("w", encoding="utf-8") as f:
        f.writelines(
            f"{t:.6f} {pos[0]:.4f} {pos[1]:.4f} {pos[2]:.4f}\n"
            for t, pos in zip(times, positions)
        )


def _write_stitched_path(stitched_output: Path, poses: List[Pose], decimals: int) -> None:
    """Write ``timestamp x y z qx qy qz qw`` lines with ``decimals`` fixed places."""
    with stitched_output.open("w", encoding="utf-8") as f:
        for timestamp, t, q in poses:
            fields = " ".join(f"{value:.{decimals}f}" for value in (*t, *q))
            f.write(f"{timestamp} {fields}\n")


def main() -> None:
    """Stitch pose blocks into one trajectory and emit simulated GNSS outputs.

    For every input file, in order:

    1. The block is rigidly moved so that its first pose coincides with the
       last global pose of the previous non-empty block. The first non-empty
       block is kept as-is and therefore defines the world frame.
    2. A noisy, time-downsampled GNSS track is generated from the block's
       global poses and written next to the pose file as ``<stem>_gnss.txt``
       (a trailing ``_fixed`` in the stem is dropped).
    3. The first GNSS point and the GNSS centroid are recorded in the summary.

    Finally the stitched trajectory and the JSON summary are written. Empty
    input files are reported and skipped.
    """
    args = parse_args()
    input_paths, stitched_output, json_output = resolve_paths(args)

    np.random.seed(42)  # For reproducible noise

    if not input_paths:
        print("No input files found.")
        return

    print(f"Stitching {len(input_paths)} files...")

    # Create both output directories up front so that a bad output path fails
    # before any processing instead of after all of it.
    stitched_output.parent.mkdir(parents=True, exist_ok=True)
    json_output.parent.mkdir(parents=True, exist_ok=True)

    stitched_all: List[Pose] = []
    gnss_summary = {}
    # Global pose of the last sample placed so far; None until the first
    # non-empty block has been processed.
    T_global_prev_end = None

    for idx, path in enumerate(input_paths):
        block_id = f"block_{idx+1}"
        print(f"Processing {block_id}: {path.name}")
        poses = read_pose_file(path)

        if not poses:
            print(f"  [Warn] Empty file {path}")
            continue

        # --- 1. Stitching Logic ---
        if T_global_prev_end is None:
            # First non-empty block defines the world frame. Keying this on the
            # loop index instead would re-origin the next block whenever the
            # first file happens to be empty.
            T_offset = np.eye(4)
        else:
            # Solve T_offset @ T_local_start == T_global_prev_end
            T_local_start = pose_to_matrix(poses[0][1], poses[0][2])
            T_offset = T_global_prev_end @ np.linalg.inv(T_local_start)

        block_global_poses: List[Pose] = []
        for timestamp, t_local, q_local in poses:
            T_global = T_offset @ pose_to_matrix(t_local, q_local)
            t_glob, q_glob = matrix_to_pose(T_global)
            block_global_poses.append((timestamp, t_glob, q_glob))

        # The block's last global pose anchors the next block
        T_global_prev_end = T_global
        stitched_all.extend(block_global_poses)

        # --- 2. Mock GNSS Generation ---
        gnss_times, gnss_pos = generate_gnss_data(
            block_global_poses, args.gnss_interval, args.gnss_noise
        )

        if len(gnss_times) == 0:
            print(f"  [Warn] No GNSS points generated for {path.name} (too short?)")
            continue

        # "<name>_fixed.txt" -> "<name>_gnss.txt"; other stems just get "_gnss" appended
        gnss_filename = (path.stem + "_gnss.txt").replace("_fixed_gnss", "_gnss")
        gnss_path = path.parent / gnss_filename
        _write_gnss_file(gnss_path, gnss_times, gnss_pos)
        print(f"  -> Wrote GNSS: {gnss_path.name} ({len(gnss_times)} pts)")

        # Summary values are taken from the noisy mock data, not the clean poses
        gnss_summary[block_id] = {
            "origin_xyz": gnss_pos[0].tolist(),
            "centroid_xyz": np.mean(gnss_pos, axis=0).tolist(),
            "gnss_file": gnss_filename,
            "pose_file": path.name
        }

    # --- 3. Save Final Outputs ---
    _write_stitched_path(stitched_output, stitched_all, args.decimals)

    with json_output.open("w", encoding="utf-8") as f:
        json.dump(gnss_summary, f, indent=2)

    print("\nProcessing Complete.")
    print(f"  Total Stitched Poses: {len(stitched_all)}")
    print(f"  Stitched Output: {stitched_output}")
    print(f"  JSON Summary: {json_output}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
