# /// script
# requires-python = ">=3.10"
# dependencies = ["requests", "pandas", "plotly"]
# ///
"""WeatherLabs — a multi-variable forecast dashboard in one file.

Run it (uv resolves the dependencies from the header above):

    WEATHERLABS_API_KEY=wl_live_... uv run weatherlabs_dashboard.py

Writes `weatherlabs_dashboard.html` (and opens it): line plots of temperature,
wind and precipitation for your sites, plus a map panel of the current
temperature field with your sites pinned on it.

Docs: https://weatherlabs.io/pages/docs.html
"""
from __future__ import annotations

import os
import webbrowser
from pathlib import Path

import pandas as pd
import plotly.graph_objects as go
import requests
from plotly.subplots import make_subplots

# Base URL of the API; override with WEATHERLABS_API_BASE. A trailing slash is
# dropped so the URLs built from it never contain a doubled '//'.
API_BASE = os.environ.get('WEATHERLABS_API_BASE', 'https://api.weatherlabs.io').rstrip('/')
# Prefix that every request path is appended to.
SCOPE = f'{API_BASE}/v1/gfs/latest'
# API key from the environment; surrounding whitespace/newlines (e.g. from a
# pasted or file-sourced value) are removed so they never reach the header.
KEY = os.environ.get('WEATHERLABS_API_KEY', '').strip()
# Bearer-token header, or no extra headers at all when no key is configured.
HDRS = {'Authorization': f'Bearer {KEY}'} if KEY else {}

SITES = [  # name, lat, lon — edit me
    ('London', 51.5074, -0.1278),
    ('Berlin', 52.52, 13.405),
    ('Madrid', 40.4168, -3.7038),
]


def get(path: str, **params) -> dict:
    """GET ``SCOPE + path`` and return the decoded JSON body.

    Keyword arguments are sent as query parameters and ``HDRS`` as request
    headers. The request times out after 60 seconds; a non-2xx response
    raises ``requests.HTTPError`` via ``raise_for_status``.
    """
    url = f'{SCOPE}{path}'
    response = requests.get(
        url,
        headers=HDRS,
        params=params,
        timeout=60,
    )
    response.raise_for_status()
    return response.json()


def point_df(lat: float, lon: float, names: list[str]) -> pd.DataFrame:
    """Fetch the forecast series for one point as a DataFrame.

    All of ``names`` are requested in a single ``/forecast`` call (bilinear
    interpolation), so every variable shares the same ``valid_times``. The
    result has a ``time`` column followed by one column per requested name.
    """
    payload = get(
        '/forecast',
        lat=lat,
        lon=lon,
        vars=','.join(names),
        interp='bilinear',
    )
    table = {'time': pd.to_datetime(payload['valid_times'])}
    table.update((var, payload['vars'][var]['values']) for var in names)
    return pd.DataFrame(table)


def _to_celsius(kelvin: float | None) -> float | None:
    """Convert K to °C, mapping a missing cell (``None`` or NaN) to ``None``."""
    # NaN is the only value that is not equal to itself; JSON null arrives as None.
    if kelvin is None or kelvin != kelvin:
        return None
    return kelvin - 273.15


def _site_frames() -> dict[str, pd.DataFrame]:
    """Fetch every entry of ``SITES`` and add the derived plotting columns.

    Adds ``t2m_c`` (°C), ``wind`` (speed from the u/v components, m/s) and
    ``precip_mmhr`` (mm/h) to each frame. Returns frames keyed by site name.
    """
    frames: dict[str, pd.DataFrame] = {}
    for name, lat, lon in SITES:
        df = point_df(lat, lon, ['t2m', 'u10', 'v10', 'prate'])
        df['t2m_c'] = df['t2m'] - 273.15
        df['wind'] = (df['u10'] ** 2 + df['v10'] ** 2) ** 0.5
        df['precip_mmhr'] = df['prate'] * 3600
        frames[name] = df
        print(f'  {name}: {len(df)} steps')
    return frames


def _temperature_grid() -> tuple[list[float], list[float], list[list[float | None]]]:
    """Fetch the coarsened step-0 ``t2m`` grid, rolled so longitudes run -180..180.

    Returns ``(lons, lats, z_c)`` where ``z_c`` is a row-per-latitude list of
    temperatures in °C with missing cells as ``None``.

    Raises:
        ValueError: if the response holds fewer values than ``nlat * nlon``.
    """
    grid = get('/grid', var='t2m', step=0, max_cells=30000)
    nlat, nlon = grid['nlat'], grid['nlon']
    values = grid['values']
    if len(values) < nlat * nlon:
        raise ValueError(
            f'grid has {len(values)} values, expected {nlat} x {nlon} = {nlat * nlon}'
        )
    lats = [grid['lat0'] + i * grid['dlat'] for i in range(nlat)]
    lons = [grid['lon0'] + j * grid['dlon'] for j in range(nlon)]

    # First column at or past 180°E; everything from there on is moved to the
    # front as negative longitudes. Defaulting to nlon makes the roll a no-op
    # when the grid is already expressed in -180..180.
    shift = next((j for j, lo in enumerate(lons) if lo >= 180), nlon)
    lons = [lo - 360 for lo in lons[shift:]] + lons[:shift]

    z_c = []
    for i in range(nlat):
        row = values[i * nlon:(i + 1) * nlon]  # values are flat, row-major
        z_c.append([_to_celsius(v) for v in row[shift:] + row[:shift]])
    return lons, lats, z_c


def _series_figure(cycle: dict, frames: dict[str, pd.DataFrame]) -> go.Figure:
    """Build the three line-plot panels (temperature, wind, precipitation)."""
    fig = make_subplots(
        rows=2, cols=2,
        specs=[[{'colspan': 2}, None], [{}, {}]],
        subplot_titles=(
            f"2 m temperature — cycle {cycle['init_compact']}",
            'Wind speed (10 m)', 'Precipitation rate',
        ),
        vertical_spacing=0.12,
    )
    fig.update_layout(
        height=700, template='plotly_dark',
        title=f"WeatherLabs forecast dashboard · {', '.join(n for n, _, _ in SITES)}",
        legend=dict(orientation='h', y=1.08),
    )
    # Only the temperature traces appear in the legend, so give each site ONE
    # colour (taken from the active template's colorway) and one legend group:
    # the legend then identifies the site in all three panels, and clicking an
    # entry toggles that site's three traces together.
    colorway = fig.layout.template.layout.colorway
    for i, (name, df) in enumerate(frames.items()):
        shared = dict(mode='lines', legendgroup=name)
        if colorway:
            shared['line'] = dict(color=colorway[i % len(colorway)])
        fig.add_trace(go.Scatter(x=df['time'], y=df['t2m_c'], name=f'{name} °C', **shared), 1, 1)
        fig.add_trace(go.Scatter(x=df['time'], y=df['wind'], name=f'{name} m/s',
                                 showlegend=False, **shared), 2, 1)
        fig.add_trace(go.Scatter(x=df['time'], y=df['precip_mmhr'], name=f'{name} mm/h',
                                 showlegend=False, **shared), 2, 2)
    return fig


def _map_figure(lons: list[float], lats: list[float],
                z_c: list[list[float | None]]) -> go.Figure:
    """Build the map panel: a lat/lon temperature heatmap with ``SITES`` pinned."""
    map_fig = go.Figure(go.Heatmap(x=lons, y=lats, z=z_c, colorscale='RdBu_r',
                                   colorbar=dict(title='°C')))
    map_fig.add_trace(go.Scatter(
        x=[lon for _, _, lon in SITES], y=[lat for _, lat, _ in SITES],
        mode='markers+text', text=[n for n, _, _ in SITES], textposition='top center',
        marker=dict(size=9, color='#ffd54a', line=dict(width=1, color='#222')),
        showlegend=False,
    ))
    map_fig.update_layout(
        height=480, template='plotly_dark', title='Current 2 m temperature (step 0)',
        xaxis_title='longitude', yaxis_title='latitude',
        yaxis=dict(scaleanchor='x', scaleratio=1),
    )
    return map_fig


def _write_dashboard(out: Path, fig: go.Figure, map_fig: go.Figure) -> None:
    """Write both figures into one HTML page at ``out``."""
    with out.open('w', encoding='utf-8') as f:
        # Declare the encoding the file is written in, so non-ASCII text is
        # decoded correctly when the page is opened straight from disk.
        f.write('<meta charset="utf-8"><title>WeatherLabs dashboard</title>'
                '<body style="background:#111;margin:0">')
        # plotly.js is referenced once (CDN) by the first fragment; the second reuses it.
        f.write(fig.to_html(full_html=False, include_plotlyjs='cdn'))
        f.write(map_fig.to_html(full_html=False, include_plotlyjs=False))
        f.write('</body>')


def main() -> None:
    """Fetch the data, build the dashboard, write it to disk and open it.

    Writes ``weatherlabs_dashboard.html`` into the current working directory
    and opens it in the default browser. HTTP failures propagate as
    ``requests.HTTPError`` from ``get``.
    """
    cycle = get('/cycle')
    print(f"cycle {cycle['init_compact']} · {cycle['n_steps']} steps")

    # ---- per-site series: temperature (°C), wind speed (m/s), precip (mm/h)
    frames = _site_frames()

    # ---- a coarsened global temperature field for the map panel
    lons, lats, z_c = _temperature_grid()

    # ---- compose the dashboard
    fig = _series_figure(cycle, frames)
    map_fig = _map_figure(lons, lats, z_c)

    out = Path('weatherlabs_dashboard.html')
    _write_dashboard(out, fig, map_fig)
    print(f'wrote {out}')
    # as_uri() yields a well-formed, percent-encoded file:// URI on every
    # platform (plain 'file://' + path is malformed for Windows drive paths).
    webbrowser.open(out.resolve().as_uri())


# Script entry point: build the dashboard only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
