#!/usr/bin/env python3
from __future__ import annotations

import argparse
import re
from pathlib import Path
from typing import Iterable, List, Tuple

import numpy as np
import open3d as o3d


def parse_args() -> argparse.Namespace:
    """Define the preprocessing CLI and parse ``sys.argv``.

    Returns:
        Namespace exposing ``pcd_dir``, ``out_dir``, ``pattern``,
        ``voxel_size``, ``uniform_every_k`` and ``overwrite``.
    """
    cli = argparse.ArgumentParser(
        description="Preprocess raw PCDs by downsampling and persist them for the optimizer."
    )
    add = cli.add_argument

    # Where to read from / write to, and which files to pick up.
    add("--pcd-dir", default="data/HK_PCD",
        help="Input raw PCD folder (default: data/HK_PCD).")
    add("--out-dir", default="data/HK_PCD_DS",
        help="Output folder for processed PCDs (default: data/HK_PCD_DS).")
    add("--pattern", default="*_all_raw_points_*.pcd",
        help="Glob pattern for input PCDs (default: *_all_raw_points_*.pcd).")

    # Downsampling knobs.
    add("--voxel-size", default=0.5, type=float,
        help="Voxel size for voxel_down_sample in meters (default: 0.5).")
    add("--uniform-every-k", default=1, type=int,
        help="Uniform downsample every k points, 1 disables (default: 1).")

    # Behaviour when an output file already exists.
    add("--overwrite", action="store_true",
        help="Overwrite existing processed PCDs.")

    return cli.parse_args()


def sort_pcd_paths(paths: Iterable[Path]) -> List[Path]:
    """Return PCD paths ordered by the trailing number in their file name.

    Names ending in ``<digits>.pcd`` (case-insensitive) are ordered by that
    integer. Names without such a number come after every numbered file, no
    matter how large the numbers are. Ties (equal numbers, or un-numbered
    files) are broken by file name, so the result does not depend on the
    iteration order of ``paths`` (``Path.glob`` order is filesystem dependent).

    Args:
        paths: Any iterable of paths; it is consumed once.

    Returns:
        A new sorted list.
    """
    pat = re.compile(r"(\d+)\.pcd$", re.IGNORECASE)

    def key(path: Path) -> Tuple[int, int, str]:
        match = pat.search(path.name)
        if match:
            return (0, int(match.group(1)), path.name)
        # Rank group 1 sorts after group 0 for any frame index, unlike a fixed
        # numeric sentinel that large frame numbers could exceed.
        return (1, 0, path.name)

    return sorted(paths, key=key)


def preprocess_one(
    input_path: Path,
    output_path: Path,
    voxel_size: float,
    uniform_every_k: int,
) -> Tuple[int, int, int]:
    """Downsample one PCD file and write the result to ``output_path``.

    Optional uniform decimation (keep every ``uniform_every_k``-th point) runs
    first, followed by voxel-grid downsampling with ``voxel_size``.

    Args:
        input_path: Raw point cloud to read.
        output_path: Destination file; its parent directory is created.
        voxel_size: Voxel edge length passed to ``voxel_down_sample``.
        uniform_every_k: Decimation stride; values <= 1 skip the uniform step.

    Returns:
        ``(n_raw, n_uniform, n_ds)``: point counts after reading, after the
        uniform step and after the voxel step.

    Raises:
        RuntimeError: If no points were read (Open3D's reader reports a
            missing/unreadable file as a warning plus an empty cloud rather
            than an exception), or if the output could not be written.
    """
    pcd = o3d.io.read_point_cloud(str(input_path))
    n_raw = int(np.asarray(pcd.points).shape[0])
    if n_raw == 0:
        # Without this, an empty output would be written and then skipped as
        # "already processed" by every subsequent run.
        raise RuntimeError(
            f"No points read from PCD (missing, unreadable or empty): {input_path}"
        )

    if uniform_every_k > 1:
        pcd = pcd.uniform_down_sample(every_k_points=uniform_every_k)
    n_uniform = int(np.asarray(pcd.points).shape[0])

    pcd = pcd.voxel_down_sample(voxel_size=voxel_size)
    n_ds = int(np.asarray(pcd.points).shape[0])

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Write to a sibling temp file and rename it into place, so an interrupted
    # write never leaves a truncated file that would later be skipped as done.
    # The real suffix stays last because Open3D selects the writer by extension.
    tmp_path = output_path.with_name(
        f".{output_path.stem}.partial{output_path.suffix}"
    )
    try:
        ok = o3d.io.write_point_cloud(str(tmp_path), pcd)
        if not ok:
            raise RuntimeError(f"Failed to write processed PCD: {output_path}")
        tmp_path.replace(output_path)
    finally:
        # Only reached with the temp file present when writing/renaming failed.
        if tmp_path.exists():
            tmp_path.unlink()

    return n_raw, n_uniform, n_ds


def main() -> None:
    """Downsample every matching PCD under ``--pcd-dir`` into ``--out-dir``.

    Outputs keep the input file name. Existing outputs are skipped unless
    ``--overwrite`` is given.

    Raises:
        ValueError: On invalid numeric options, or when the output folder is
            the input folder (outputs would overwrite or shadow the inputs).
        FileNotFoundError: If the input folder is missing or nothing matches.
        RuntimeError: Propagated from ``preprocess_one`` on read/write failure.
    """
    args = parse_args()
    pcd_dir = Path(args.pcd_dir)
    out_dir = Path(args.out_dir)

    if args.voxel_size <= 0.0:
        raise ValueError("voxel_size must be > 0")
    if args.uniform_every_k < 1:
        raise ValueError("uniform_every_k must be >= 1")
    if not pcd_dir.is_dir():
        raise FileNotFoundError(f"Input PCD folder not found: {pcd_dir}")
    # Outputs reuse input file names: with the same folder, --overwrite would
    # destroy the raw data and, without it, every file would be silently skipped.
    if out_dir.resolve() == pcd_dir.resolve():
        raise ValueError(
            f"--out-dir must differ from --pcd-dir ({pcd_dir}); processed files "
            "reuse input names and would overwrite the raw PCDs"
        )

    # Ignore directories that happen to match the glob.
    pcd_paths = sort_pcd_paths(p for p in pcd_dir.glob(args.pattern) if p.is_file())
    if not pcd_paths:
        raise FileNotFoundError(f"No PCDs matched {args.pattern} under {pcd_dir}")

    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"[Input ] {pcd_dir} ({len(pcd_paths)} files)")
    print(f"[Output] {out_dir}")
    print(f"[Config] voxel_size={args.voxel_size}, uniform_every_k={args.uniform_every_k}")

    processed = 0
    skipped = 0
    for path in pcd_paths:
        out_path = out_dir / path.name
        if out_path.exists() and not args.overwrite:
            skipped += 1
            print(f"[Skip ] {path.name} already exists.")
            continue

        n_raw, n_uniform, n_ds = preprocess_one(
            input_path=path,
            output_path=out_path,
            voxel_size=args.voxel_size,
            uniform_every_k=args.uniform_every_k,
        )
        processed += 1
        print(
            f"[Done ] {path.name}: raw={n_raw} -> uniform={n_uniform} -> voxel={n_ds}"
        )

    print(f"[Summary] processed={processed}, skipped={skipped}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
