#!/usr/bin/env python3
import argparse
import numpy as np
from scipy.spatial.transform import Rotation as R
import sys

def parse_map_string(map_str):
    """
    Parse an axis-mapping string such as "y,z,-x" into a 3x3 signed permutation matrix.

    Component i of the string names the source axis (optionally prefixed with
    '+' or '-') that becomes target axis i, so "y,z,-x" means
    new_x = y, new_y = z, new_z = -x. Row i of the returned matrix holds a single
    +/-1 in the column of the named source axis, so ``matrix @ v`` applies the
    mapping to a column vector v. Axis letters are case-insensitive. Whitespace
    around components, and between a sign and its letter, is ignored.

    Args:
        map_str: comma-separated mapping with exactly three components.

    Returns:
        np.ndarray of shape (3, 3) with float entries in {-1, 0, 1}.

    Raises:
        ValueError: if there are not exactly 3 components, a component names an
            unknown axis, or the same source axis is used more than once. A
            repeated axis would make the matrix singular.
    """
    mapping = map_str.strip().split(',')
    if len(mapping) != 3:
        raise ValueError("Map string must have 3 components, e.g., 'y,z,-x'")

    axis_map = {'x': 0, 'y': 1, 'z': 2}
    matrix = np.zeros((3, 3))

    for i, axis_str in enumerate(mapping):
        axis_str = axis_str.strip().lower()
        sign = 1
        if axis_str.startswith('-'):
            sign = -1
            # Strip again so "- x" is accepted as well as "-x".
            axis_str = axis_str[1:].strip()

        if axis_str.startswith('+'):
            axis_str = axis_str[1:].strip()

        if axis_str not in axis_map:
            raise ValueError(f"Unknown axis: {axis_str}")

        src_idx = axis_map[axis_str]
        # A source axis used twice leaves another axis unused, which yields a
        # singular (non-invertible) matrix that would collapse positions.
        if matrix[:, src_idx].any():
            raise ValueError(
                f"Axis '{axis_str}' is used more than once in map string '{map_str}'"
            )
        matrix[i, src_idx] = sign

    return matrix

def transform_poses(pose_data, map_str):
    """
    Apply an axis-remapping coordinate transform to pose rows (t, x, y, z, qx, qy, qz, qw).

    ``parse_map_string`` turns the map string into a matrix T. The transform is
    then applied as follows:

    - Positions become ``T @ p``.
    - If T is a proper rotation (det = +1), each orientation matrix becomes
      ``T @ R_old`` and is converted back to an (x, y, z, w) quaternion.
    - If T is a reflection (det = -1), orientations are left unchanged and a
      warning is printed, because quaternions cannot represent a reflection.
    - Timestamps (column 0) are copied unchanged.

    Only the first 8 columns appear in the output.

    Args:
        pose_data: array of shape (N, >=8), or a single 1-D row of length >= 8.
        map_str: mapping string such as "y,z,-x".

    Returns:
        np.ndarray of shape (N, 8).

    Raises:
        ValueError: if there are fewer than 8 columns, the map string is
            malformed, or the resulting matrix is not orthogonal.
    """
    # Accept a single pose given as a 1-D row as well as an (N, >=8) table.
    pose_data = np.atleast_2d(pose_data)
    if pose_data.shape[1] < 8:
        raise ValueError(f"Input data must have at least 8 columns (t, x, y, z, qx, qy, qz, qw), got {pose_data.shape[1]}")

    # 1. Parse the transformation matrix.
    trans_mat = parse_map_string(map_str)
    det = np.linalg.det(trans_mat)

    print(f"[Info] Transformation Matrix for '{map_str}':")
    print(trans_mat)
    print(f"[Info] Determinant: {det:.2f}")

    # A valid basis change must be orthogonal (T @ T.T == I). Anything else would
    # collapse or distort positions, so refuse instead of writing a corrupt result.
    if not np.allclose(trans_mat @ trans_mat.T, np.eye(3)):
        raise ValueError(
            "Transformation is not a standard orthogonal basis change "
            f"(determinant {det:.2f}); each of x, y, z must appear exactly once."
        )

    is_valid_rotation = np.isclose(det, 1.0)

    # 2. Transform positions (columns 1..3). The matrix acts on column vectors,
    # so apply it to the transposed (3, N) array and transpose back.
    positions = pose_data[:, 1:4]
    positions_new = (trans_mat @ positions.T).T

    # 3. Transform orientations (columns 4..7, scalar-last quaternion order).
    quats = pose_data[:, 4:8]
    quats_new = quats.copy()

    if is_valid_rotation:
        print("[Status] Valid rotation detected. Transforming quaternions...")
        r_old = R.from_quat(quats).as_matrix()  # (N, 3, 3)
        # (3, 3) @ (N, 3, 3) broadcasts so that R_new[i] = T @ R_old[i].
        r_new = np.matmul(trans_mat, r_old)
        quats_new = R.from_matrix(r_new).as_quat()
    else:
        # Orthogonal with det = -1: the transform is a reflection.
        print("\n[WARNING] ⚠️ 你的变换包含镜像 (Determinant = -1)！")
        print("          四元数无法表示镜像坐标系。")
        print("          位置 (XYZ) 已变换，但姿态可能在可视化中显示错误（由内向外翻转）。")
        print("          建议: 尝试修改 map 字符串为合法的旋转，例如 'z, y, -x' 或 'z, -y, -x'")
        # A reflection has no quaternion representation. Keep the original
        # orientations and remap positions only, which is the common practical fix.

    # 4. Assemble the result, keeping timestamps unchanged.
    new_data = np.hstack((pose_data[:, 0:1], positions_new, quats_new))
    return new_data

def main():
    """
    Command-line entry point: read a TUM pose file, remap its axes, write the result.

    Error messages are printed to stderr.

    Returns:
        0 on success, 1 if the input cannot be read, is empty, cannot be
        transformed, or the output cannot be written.
    """
    parser = argparse.ArgumentParser(description="Pose Coordinate Fixer - Transforms TUM format poses.")
    parser.add_argument("input_file", help="Path to input pose file (TUM format: t x y z qx qy qz qw)")
    parser.add_argument("output_file", help="Path to output pose file")
    parser.add_argument("--map", required=True, help="Mapping string, e.g., 'y,z,-x' or 'z,y,-x'")

    args = parser.parse_args()

    # Read the input data.
    print(f"Reading {args.input_file}...")
    try:
        data = np.loadtxt(args.input_file)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError: unparsable content.
        print(f"Error reading input file: {e}", file=sys.stderr)
        return 1

    # Check emptiness BEFORE reshaping: an empty file loads as shape (0,), which
    # reshape(1, -1) would turn into (1, 0) and hide the empty case.
    if data.size == 0:
        print("Error: Input file is empty.", file=sys.stderr)
        return 1

    # A single-row file loads as a 1-D array; make it a (1, ncols) table.
    if data.ndim == 1:
        data = data.reshape(1, -1)

    # Apply the transform.
    try:
        new_data = transform_poses(data, args.map)
    except ValueError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

    # Save the result. '%.6f' assumes timestamps are in seconds (microsecond resolution).
    print(f"Saving to {args.output_file}...")
    try:
        np.savetxt(args.output_file, new_data, fmt='%.6f')
    except OSError as e:
        print(f"Error writing output file: {e}", file=sys.stderr)
        return 1
    print("Done.")
    return 0

if __name__ == "__main__":
    # Propagate main()'s status so shell scripts can detect failures.
    sys.exit(main())
