import xarray as xr
import numpy as np
import os

# -----------------------------
# Paths
# -----------------------------
# Raw monthly GAIA emission files live in INPUT_DIR; regridded output is
# written to a sub-directory of it, created up front if missing.
INPUT_DIR = "/data/jjuvin/AVIATION/GAIA"
OUTPUT_DIR = f"{INPUT_DIR}/GAIA_REGRID"

# File whose latitude / longitude / level coordinates define the target grid.
TARGET_FILE = "/data/jjuvin/RHi/rhi-350to150-025-2025.nc"

os.makedirs(OUTPUT_DIR, exist_ok=True)

# -----------------------------
# Load target grid
# -----------------------------
# Only the coordinate arrays of the target file are used below, so pull them
# into memory and close the file right away instead of holding the handle
# open for the whole run.
with xr.open_dataset(TARGET_FILE) as ds_target:
    target_lat = ds_target["latitude"].load()
    target_lon = ds_target["longitude"].load()
    # NOTE(review): assumed to be in hPa (source levels are converted from Pa
    # to hPa before interpolation) — confirm against the target file's units.
    target_lev = ds_target["level"].load()

# -----------------------------
# Helpers
# -----------------------------
def _wrap_longitude(ds, target_lon):
    """Align ``ds`` longitudes with the target grid and close the periodic seam.

    Source longitudes are mapped to [0, 360) when the target grid extends past
    180 degrees, otherwise to [-180, 180). They are then sorted and
    de-duplicated. If the resulting grid spans a full 360 degrees, one column
    is copied onto each side: the last column shifted by -360, the first by
    +360. Linear interpolation can then bridge the seam instead of falling
    outside the source range and being filled with 0. Variables without a
    longitude dimension are left untouched.
    """
    lon = np.asarray(ds["longitude"].values, dtype=float)
    if float(target_lon.max()) > 180.0:
        lon = np.mod(lon, 360.0)
    else:
        lon = np.mod(lon + 180.0, 360.0) - 180.0
    # np.unique both sorts and drops duplicates such as 0 vs 360.
    lon, keep = np.unique(lon, return_index=True)
    ds = ds.isel(longitude=keep).assign_coords(longitude=lon)

    if lon.size < 2:
        return ds
    step = float(np.median(np.diff(lon)))
    if not np.isclose(lon[-1] - lon[0] + step, 360.0):
        # Not a global grid: wrapping would interpolate across a real gap.
        return ds

    west = ds.isel(longitude=[-1]).assign_coords(longitude=[lon[-1] - 360.0])
    east = ds.isel(longitude=[0]).assign_coords(longitude=[lon[0] + 360.0])
    # "minimal" keeps variables without a longitude dim from being broadcast.
    return xr.concat(
        [west, ds, east],
        dim="longitude",
        data_vars="minimal",
        coords="minimal",
        compat="override",
    )


# -----------------------------
# Loop over months
# -----------------------------
for m in range(1, 13):
    month = f"{m:02d}"

    infile = f"{INPUT_DIR}/emission.GAIA.2019-{month}.360x180.nc"
    outfile = f"{OUTPUT_DIR}/emission.GAIA.2019-{month}.regrid.nc"

    print(f"Processing month {month}")

    # The context manager closes this month's input file once the output is
    # written (data is read lazily, so writing must happen inside the block).
    with xr.open_dataset(infile) as ds_in:
        # -----------------------------
        # Step 1: pressure Pa → hPa
        # -----------------------------
        ds = ds_in.assign_coords(
            level=ds_in["pressure_Pa"] / 100.0
        ).swap_dims({"pressure_Pa": "level"}).drop_vars("pressure_Pa")

        # -----------------------------
        # Step 2: monotonic, target-aligned coordinates
        # -----------------------------
        ds = ds.sortby(["level", "latitude"])
        ds = _wrap_longitude(ds, target_lon)

        # -----------------------------
        # Step 3: vertical interpolation
        # -----------------------------
        # Done before the horizontal step: separable linear interpolation gives
        # the same result in either order, but this way only the target levels
        # are ever expanded onto the fine horizontal grid.
        ds_interp_v = ds.interp(
            level=target_lev,
            method="linear",
            kwargs={"fill_value": 0.0}
        )

        # -----------------------------
        # Step 4: horizontal interpolation
        # -----------------------------
        # NOTE(review): linear interpolation is not mass-conservative. If the
        # emission variables are per-cell totals rather than per-area fluxes,
        # regional/global totals will not be preserved — confirm units.
        # Latitudes outside the source range (e.g. ±90 for a cell-centred
        # source) are still filled with 0 — confirm this is intended.
        ds_interp = ds_interp_v.interp(
            latitude=target_lat,
            longitude=target_lon,
            method="linear",
            kwargs={"fill_value": 0.0}
        )

        # -----------------------------
        # Save
        # -----------------------------
        ds_interp.to_netcdf(outfile)

    print(f"Saved {outfile}")
    print("-----------------------------")
