import numpy as np
import math
import open3d as o3d
import matplotlib.pyplot as plt
from .common import Config, OptimizerContext, transform_points, yaw_from_R

def _write_pose_table(csv_path, ctx: OptimizerContext, mode: str) -> None:
    """Write one CSV row per block: translation and roll/pitch/yaw in degrees."""
    try:
        from scipy.spatial.transform import Rotation
    except ImportError:
        # scipy is optional; without it only yaw is reported (roll/pitch written as 0).
        Rotation = None

    lines = ["block_id,tx,ty,tz,roll_deg,pitch_deg,yaw_deg"]
    for bid in ctx.block_ids:
        T = ctx.get_pose(bid, mode)
        t = T[:3, 3]
        if Rotation is not None:
            # Lowercase "xyz" = extrinsic x, y, z rotations, i.e. (roll, pitch, yaw).
            euler = Rotation.from_matrix(T[:3, :3]).as_euler("xyz", degrees=True)
            roll, pitch, yaw = (float(a) for a in euler)
        else:
            roll, pitch = 0.0, 0.0
            yaw = math.degrees(yaw_from_R(T[:3, :3]))
        lines.append(f"{bid},{t[0]:.6f},{t[1]:.6f},{t[2]:.6f},{roll:.3f},{pitch:.3f},{yaw:.3f}")

    # Trailing newline so the file ends like a regular text/CSV file.
    csv_path.write_text("\n".join(lines) + "\n", encoding="utf-8")


def _plot_pose_preview(png_path, cfg: Config, ctx: OptimizerContext, mode: str) -> None:
    """Save a top-down (XY) scatter of every block's points in the global frame plus its pose origin."""
    # Fixed seed: "initial" and "optimized" previews subsample the same points,
    # so the two images are directly comparable.
    rng = np.random.default_rng(0)
    fig, ax = plt.subplots(figsize=(10, 10), dpi=140)
    try:
        plotted_any = False
        for bid in ctx.block_ids:
            if bid not in ctx.pcd_ds:
                continue
            pts = np.asarray(ctx.pcd_ds[bid].points, dtype=np.float64)
            n_max = cfg.preview_points_per_block
            if pts.shape[0] > n_max:
                pts = pts[rng.choice(pts.shape[0], n_max, replace=False)]

            T = ctx.get_pose(bid, mode)
            pts_g = transform_points(pts, T)
            cloud = ax.scatter(pts_g[:, 0], pts_g[:, 1], s=0.2, alpha=0.35, label=bid)

            # Draw the pose origin in the same (opaque) color as its cloud. Passing an
            # explicit color also keeps the marker from consuming a slot of the color
            # cycle, which would otherwise shift every following block's color.
            rgb = tuple(cloud.get_facecolor()[0][:3])
            tg = T[:3, 3]
            ax.scatter([tg[0]], [tg[1]], s=40, color=rgb)
            plotted_any = True

        ax.axis("equal")
        ax.grid(True, linestyle="--", alpha=0.3)
        ax.set_title(f"Global Preview ({mode})")
        if plotted_any:
            # Avoid matplotlib's "no artists with labels" warning on empty input.
            ax.legend(markerscale=10, fontsize=8, loc="best")
        fig.tight_layout()
        fig.savefig(png_path)
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)


def save_pose_table_and_preview(cfg: Config, ctx: OptimizerContext, mode: str = "initial"):
    """Save the per-block pose table (CSV) and a top-down preview image (PNG).

    Files written into ``cfg.out_dir`` (the directory is created if missing):
      * ``global_poses_{mode}.csv``: block_id, translation, roll/pitch/yaw (deg).
      * ``global_preview_{mode}.png``: XY scatter of each block's downsampled
        points transformed by its pose, plus a marker at each pose origin.

    Args:
        cfg: reads ``out_dir`` and ``preview_points_per_block``.
        ctx: reads ``block_ids``, ``pcd_ds`` and ``get_pose(bid, mode)``.
        mode: "initial" or "optimized"; forwarded to ``ctx.get_pose``.
    """
    csv_path = cfg.out_dir / f"global_poses_{mode}.csv"
    png_path = cfg.out_dir / f"global_preview_{mode}.png"
    print(f"\n=== [Output] Save pose table & preview ({mode}) ===")
    csv_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"[Save] {csv_path}")
    _write_pose_table(csv_path, ctx, mode)

    print(f"[Save] {png_path}")
    _plot_pose_preview(png_path, cfg, ctx, mode)

def export_global_map(cfg: Config, ctx: OptimizerContext, out_path, *, voxel_size=None):
    """Merge all downsampled block clouds into one global point cloud and write it.

    Optimized poses are used when ``ctx.optimized_poses`` is non-empty,
    otherwise initial poses.

    Args:
        cfg: pipeline config (not read by the active code path; kept for
            interface stability).
        ctx: reads ``block_ids``, ``pcd_ds``, ``optimized_poses`` and
            ``get_pose(bid, mode)``.
        out_path: destination file (str or path-like). Parent directories are
            created if missing.
        voxel_size: optional final voxel size for downsampling the merged
            cloud; ``None`` or a non-positive value disables it (default).

    Raises:
        OSError: if Open3D reports that writing the file failed.
    """
    from pathlib import Path

    print(f"\n=== [Output] Export Global Map: {out_path} ===")
    merged = o3d.geometry.PointCloud()

    # Prefer optimized if available
    mode = "optimized" if ctx.optimized_poses else "initial"
    print(f"[Info] Using mode='{mode}' poses.")

    for bid in ctx.block_ids:
        if bid not in ctx.pcd_ds:
            continue
        # Copy so the cached block cloud in ctx.pcd_ds is left untouched by transform().
        pcd = o3d.geometry.PointCloud(ctx.pcd_ds[bid])
        pcd.transform(ctx.get_pose(bid, mode))
        merged += pcd
        # Cumulative point count after merging this block.
        print(f"  [Merge] {bid} -> points={len(merged.points)}")

    if voxel_size is not None and voxel_size > 0:
        merged = merged.voxel_down_sample(voxel_size=voxel_size)

    if len(merged.points) == 0:
        print(f"[Warn] Global map is empty; nothing written to {out_path}")
        return

    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    # write_point_cloud signals failure via its boolean return, not an exception.
    if not o3d.io.write_point_cloud(str(out_path), merged):
        raise OSError(f"Failed to write global map to {out_path}")
    print(f"[Save] {out_path} final_points={len(merged.points)}")
