from __future__ import annotations
import numpy as np
import open3d as o3d
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import math
import time

class TimingLogger:
    """Collect named duration measurements and write them out as a text report.

    Usage: call ``start(name)`` before a stage and ``stop(name, ...)`` after it;
    every successful ``stop`` appends one entry to ``records``.  ``save_report``
    writes all entries plus the total time since construction to a file.

    All clock readings come from ``time.perf_counter()``, a monotonic clock, so
    measured durations are not affected by system clock adjustments.  The values
    in ``start_times`` and ``global_start`` are therefore only meaningful as
    differences, not as timestamps.
    """

    def __init__(self):
        # One dict per completed measurement: name / duration / count / unit.
        self.records = []
        # Stage name -> clock reading taken by the most recent start(name).
        self.start_times = {}
        # Reference point for the "Total Execution Time" line of the report.
        self.global_start = time.perf_counter()

    def start(self, name):
        """Begin (or restart) the measurement called ``name``."""
        self.start_times[name] = time.perf_counter()

    def stop(self, name, count=1, unit="block"):
        """Finish the measurement ``name`` and record it.

        Args:
            name: Stage name previously passed to ``start``.
            count: Number of items processed in the stage; used by the report
                to compute the per-item average.
            unit: Label for what ``count`` counts.

        Returns:
            The elapsed time in seconds, or ``None`` when ``start(name)`` was
            never called (nothing is recorded in that case).
        """
        if name not in self.start_times:
            return None
        # The start reading is kept, so a later stop(name) measures from the
        # same start again.
        duration = time.perf_counter() - self.start_times[name]
        self.records.append({
            "name": name,
            "duration": duration,
            "count": count,
            "unit": unit,
        })
        return duration

    def save_report(self, path: Path):
        """Write the collected measurements to ``path`` as a UTF-8 text table.

        Missing parent directories of ``path`` are created first, so the report
        can be written even if the output directory does not exist yet.
        """
        total_time = time.perf_counter() - self.global_start
        lines = [
            "================ Performance Report ================",
            f"Total Execution Time: {total_time:.4f} sec",
            "----------------------------------------------------",
            f"{'Stage Name':<35} | {'Total(s)':<10} | {'Avg(ms)':<10} | {'Count':<5} {'Unit'}",
            "----------------------------------------------------",
        ]

        for r in self.records:
            # max(1, count) guards the average against a zero item count.
            avg_ms = (r['duration'] / max(1, r['count'])) * 1000.0
            lines.append(f"{r['name']:<35} | {r['duration']:<10.4f} | {avg_ms:<10.2f} | {r['count']:<5} {r['unit']}")

        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
        print(f"\n[Log] Timing report saved to {path}")


# -----------------------------
# Utility functions
# -----------------------------
def ensure_dir(p: Path) -> None:
    """Create directory ``p`` together with any missing parents; do nothing if it already exists."""
    p.mkdir(exist_ok=True, parents=True)

def make_T(R: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Assemble a 4x4 homogeneous SE(3) matrix from rotation ``R`` and translation ``t``.

    ``t`` may have any shape that reshapes to three elements; the result is float64.
    """
    transform = np.identity(4, dtype=np.float64)
    transform[0:3, 3] = t.reshape(3)   # translation column
    transform[0:3, 0:3] = R            # rotation block
    return transform

def transform_points(pts: np.ndarray, T: np.ndarray) -> np.ndarray:
    """Apply the SE(3) transform ``T`` to an Nx3 array of points: p' = R p + t."""
    rotation, translation = T[:3, :3], T[:3, 3]
    # Points are stored as rows, so right-multiplying by R^T rotates each of them.
    rotated = np.matmul(pts, rotation.T)
    return rotated + translation

def yaw_from_R(R: np.ndarray) -> float:
    """Return the yaw angle in radians, computed as atan2(R[1, 0], R[0, 0])."""
    sin_term = R[1, 0]
    cos_term = R[0, 0]
    return float(np.arctan2(sin_term, cos_term))

def euclidean(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Euclidean (L2) distance between ``a`` and ``b`` as a Python float."""
    offset = a - b
    return float(np.linalg.norm(offset))

def estimate_normals(pcd: o3d.geometry.PointCloud, voxel_size: float) -> None:
    """Estimate and normalize normals on ``pcd`` in place (point-to-plane ICP needs normals).

    The neighbour search uses a hybrid KD-tree query with radius
    ``3 * voxel_size`` (never below 1e-3) and at most 30 neighbours.
    """
    search_radius = max(1e-3, 3.0 * voxel_size)
    neighbourhood = o3d.geometry.KDTreeSearchParamHybrid(radius=search_radius, max_nn=30)
    pcd.estimate_normals(search_param=neighbourhood)
    pcd.normalize_normals()

def load_pose_txt(path: Path) -> Tuple[np.ndarray, np.ndarray]:
    """Load timestamps and positions from a whitespace-separated pose file.

    Expected row format: ``timestamp x y z [qx qy qz qw]`` (at least 4 columns;
    columns after ``z`` are ignored).

    Args:
        path: Location of the pose text file.

    Returns:
        ``(timestamps, positions)`` as float64 arrays of shape ``(N,)`` and
        ``(N, 3)``, sorted by ascending timestamp with duplicate timestamps
        removed.  When a timestamp occurs more than once, the row that appears
        first in the file is kept.  If the file cannot be read or has fewer
        than 4 columns, a message is printed and empty arrays of shape ``(0,)``
        and ``(0, 3)`` are returned; no exception is raised.
    """
    no_ts = np.empty((0,), dtype=np.float64)
    no_xyz = np.empty((0, 3), dtype=np.float64)

    try:
        data = np.loadtxt(str(path))
    except Exception as e:
        # Deliberate best-effort: a missing or unparsable file yields "no poses".
        print(f"[Error] Failed to load {path}: {e}")
        return no_ts, no_xyz

    # loadtxt squeezes its result: a one-row file comes back 1-D and a
    # single-value file 0-D.  Promote both to a 2-D (rows, columns) table.
    data = np.atleast_2d(data)

    if data.shape[0] == 0 or data.shape[1] < 4:
        print(f"[Warn] {path}: expected at least 4 columns (timestamp x y z), got shape {data.shape}")
        return no_ts, no_xyz

    ts = data[:, 0].astype(np.float64)
    xyz = data[:, 1:4].astype(np.float64)

    # np.unique returns the sorted unique timestamps together with the index of
    # each one's first occurrence in ``ts``, i.e. sort + de-duplicate in one
    # step with a well-defined choice among duplicate rows.
    ts_u, first_idx = np.unique(ts, return_index=True)
    return ts_u, xyz[first_idx]

def interp_positions(target_times: np.ndarray, src_times: np.ndarray, src_positions: np.ndarray) -> np.ndarray:
    """Linearly interpolate 3-D positions at ``target_times``.

    Each coordinate column of ``src_positions`` is interpolated independently
    with ``np.interp`` over ``src_times``.  With fewer than two source samples
    an all-zero ``(len(target_times), 3)`` array is returned instead.
    """
    if len(src_times) >= 2:
        per_axis = [
            np.interp(target_times, src_times, src_positions[:, axis])
            for axis in range(3)
        ]
        return np.stack(per_axis, axis=1)
    return np.zeros((len(target_times), 3), dtype=np.float64)

def compute_svd_alignment(local_pts: np.ndarray, global_pts: np.ndarray) -> np.ndarray:
    """Kabsch/SVD rigid alignment: find the 4x4 ``T`` with global ≈ R * local + t.

    ``local_pts`` and ``global_pts`` are (N,3) arrays of corresponding points
    (row i of one pairs with row i of the other).  If the two inputs differ in
    length, only the common leading rows are used.  With fewer than 3 pairs the
    identity transform is returned.
    """
    pair_count = min(len(local_pts), len(global_pts))
    if pair_count < 3:
        return np.eye(4, dtype=np.float64)

    src = local_pts[:pair_count]
    dst = global_pts[:pair_count]

    # Remove the centroids so that only the rotation remains to be solved.
    src_centroid = src.mean(axis=0)
    dst_centroid = dst.mean(axis=0)
    cross_cov = (src - src_centroid).T @ (dst - dst_centroid)

    U, _, Vt = np.linalg.svd(cross_cov)
    rotation = Vt.T @ U.T

    # A negative determinant means the solution is a reflection; flip the last
    # row of Vt to turn it into a proper rotation.
    if np.linalg.det(rotation) < 0:
        Vt[-1] = -Vt[-1]
        rotation = Vt.T @ U.T

    translation = dst_centroid - (rotation @ src_centroid)
    return make_T(rotation, translation)

def voxel_keys(pts: np.ndarray, voxel: float) -> np.ndarray:
    """Return the sorted unique voxel indices occupied by ``pts`` as structured keys.

    Each point is mapped to ``floor(p / voxel)`` and the three int32 indices
    are packed into one record with fields ``x``, ``y``, ``z``, so that the
    result can be fed to set operations such as ``np.intersect1d``.

    Args:
        pts: Nx3 point coordinates; any memory layout is accepted
            (e.g. a transposed 3xN array).
        voxel: Voxel edge length; must be positive.

    Returns:
        1-D structured array (fields ``x``, ``y``, ``z``, all int32) holding
        each occupied voxel once.  Empty input gives an empty array.

    Raises:
        ValueError: if ``voxel`` is not a positive number.
    """
    if not voxel > 0:
        raise ValueError(f"voxel must be positive, got {voxel!r}")

    key_dtype = np.dtype([("x", np.int32), ("y", np.int32), ("z", np.int32)])

    # order="C" matters: the structured view below reinterprets three
    # consecutive int32 values as one record, which needs a contiguous last
    # axis.  Without it, astype keeps the layout of the input and the view
    # fails for transposed / Fortran-ordered arrays.
    idx = np.floor(np.asarray(pts) / voxel).astype(np.int32, order="C")
    if idx.size == 0:
        return np.empty(0, dtype=key_dtype)

    keys = idx.view(key_dtype).reshape(-1)
    return np.unique(keys)


# -----------------------------
# Configuration
# -----------------------------
@dataclass
class Config:
    """Tunable settings, grouped by the part of the pipeline they belong to.

    Every field has a default, so ``Config()`` is usable as-is.  The code that
    consumes these values is not part of this section; notes marked
    NOTE(review) below are inferred from the field names and should be
    confirmed against that code.
    """

    # Paths (defaults are relative paths)
    pcd_dir: Path = Path("data/HK_PCD_DS")
    pose_root_dir: Path = Path("data/HK_POSE")
    gnss_json: Path = Path("data/HK_POSE/simulated_gnss.json")
    out_dir: Path = Path("results")

    # Common
    voxel_size: float = 0.5  # NOTE(review): presumably a voxel edge length; units not stated here
    max_stage: int = 5  # NOTE(review): presumably the last pipeline stage to run -- confirm

    # SVD
    # NOTE(review): presumably lower/upper bounds on the number of point pairs
    # used for the SVD alignment -- confirm against the caller.
    min_svd_pairs: int = 20
    max_svd_pairs: int = 20000

    # Module 3a: Distance Gating
    gating_distance_m: float = 500.0  # NOTE(review): "_m" suggests metres -- confirm

    # Module 3b: Overlap Check
    overlap_voxel_size: float = 0.5
    overlap_max_points: int = 200000
    overlap_ratio_threshold: float = 0.01

    # ICP
    icp_max_corr_dist: float = 2.0
    icp_max_iter: int = 50
    # NOTE(review): presumably acceptance thresholds for an ICP result; which
    # side of each threshold counts as "accepted" is decided by the caller.
    icp_rmse_threshold: float = 0.5
    icp_fitness_threshold: float = 0.20

    # GTSAM
    # NOTE(review): presumably noise sigmas for GNSS and ICP factors; the
    # "_deg" suffix suggests the rotation sigmas are given in degrees.
    gnss_sigma_xyz: float = 50.0
    gnss_sigma_rot_deg: float = 60.0
    icp_sigma_xyz: float = 0.10
    icp_sigma_rot_deg: float = 5.0

    # Output
    preview_points_per_block: int = 30000


# -----------------------------
# Shared Context
# -----------------------------
@dataclass
class OptimizerContext:
    """Shared container for inputs, loaded data and intermediate/final results.

    Every field starts out as an empty container owned by the instance.
    """

    # --- Metadata ---
    gnss_info: Dict[str, Dict] = field(default_factory=dict)
    block_ids: List[str] = field(default_factory=list)

    # --- File paths, keyed by block id ---
    pcd_paths: Dict[str, Path] = field(default_factory=dict)
    pose_paths: Dict[str, Path] = field(default_factory=dict)
    gnss_kf_paths: Dict[str, Path] = field(default_factory=dict)

    # --- Loaded data, keyed by block id ---
    pcd_ds: Dict[str, o3d.geometry.PointCloud] = field(default_factory=dict)
    local_traj: Dict[str, Tuple[np.ndarray, np.ndarray]] = field(default_factory=dict)
    gnss_traj: Dict[str, Tuple[np.ndarray, np.ndarray]] = field(default_factory=dict)

    # --- Results ---
    svd_global_poses: Dict[str, np.ndarray] = field(default_factory=dict)  # initial guess
    pairs_gated: List[Tuple[str, str, float]] = field(default_factory=list)
    candidate_edges: List[Tuple[str, str, float]] = field(default_factory=list)
    icp_edges: Dict[Tuple[str, str], Dict] = field(default_factory=dict)
    optimized_poses: Dict[str, np.ndarray] = field(default_factory=dict)

    def get_pose(self, bid: str, mode: str = "optimized") -> np.ndarray:
        """Return the best available 4x4 pose for block ``bid``.

        Lookup order:
          1. ``optimized_poses`` (only consulted when ``mode == "optimized"``),
          2. ``svd_global_poses``,
          3. a pure translation to ``gnss_info[bid]["origin_xyz"]``,
          4. the identity matrix.

        For tiers 1 and 2 the stored array itself is returned, not a copy.
        """
        tiers = [self.svd_global_poses]
        if mode == "optimized":
            tiers.insert(0, self.optimized_poses)

        for poses in tiers:
            if bid in poses:
                return poses[bid]

        # No estimated pose yet: fall back to the block's origin, if known.
        if bid in self.gnss_info:
            origin = np.array(self.gnss_info[bid]["origin_xyz"], dtype=np.float64)
            return make_T(np.eye(3), origin)

        return np.eye(4)
