import xarray as xr
import numpy as np

# ----------------------------
# User settings
# ----------------------------
# Directory containing the monthly input files (trailing slash required,
# since file names are appended by plain string formatting).
INPUT_PATH = "/data/jjuvin/AVIATION/GAIA/GAIA_REGRID/"
# Directory the monthly output files are written to (trailing slash required).
OUTPUT_PATH = "/data/jjuvin/AVIATION/GAIA/NA_PDF/"

# Zero-padded month tags "01" .. "12", used to build the file names.
months = [str(month_number).zfill(2) for month_number in range(1, 13)]

# ----------------------------
# Loop over months
# ----------------------------
for month in months:
    infile = f"{INPUT_PATH}emission.GAIA.2019-{month}.regrid.nc"
    outfile = f"{OUTPUT_PATH}emission.GAIA.2019-{month}.NA.pdf.nc"

    print(f"Processing month {month}")

    # Open the dataset in a context manager so the file is closed at the end
    # of each iteration instead of leaving one open handle per month. The
    # write is kept inside the block so it happens while the source file is
    # still open.
    with xr.open_dataset(infile) as ds:
        print(ds)

        # The input is expected to be already restricted to the North
        # Atlantic region by an earlier processing step.
        seg = ds["seg_length_m"]

        # NOTE: no latitude weighting is applied here; an earlier
        # sqrt(cos(lat)) weighting variant has been disabled.

        # Normalise along the "level" dimension so that each column sums to
        # one. Columns whose total is not strictly positive get 0.0 instead
        # of the result of dividing by zero.
        col_sum = seg.sum(dim=["level"])
        pdf = xr.where(col_sum > 0, seg / col_sum, 0.0)
        # Make sure the result has the same dimension order as the input
        # variable.
        pdf = pdf.transpose(*seg.dims)

        # Add the PDF variable next to the original variables.
        ds["flight_density_pdf"] = pdf
        ds["flight_density_pdf"].attrs["description"] = "Flight density PDF per column"

        # Save to NetCDF
        ds.to_netcdf(outfile)

    print(f"Saved {outfile}\n")
