import warnings

import astropy.units as u
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
from astropy.coordinates import SkyCoord
from astropy.time import Time
from matplotlib.lines import Line2D

from saltshaker import get_salt_observer, get_visibility_windows

# --- Configuration & Setup ---
# Name of the target to plot and the calendar year the plot covers.
target_name = 'Sirius'
year = 2026

# Observer object supplied by saltshaker; its twilight methods are used below.
observer = get_salt_observer()
# Resolve the target's sky coordinates from its name.
# NOTE(review): SkyCoord.from_name performs an online name lookup, so this
# line needs network access -- confirm that is acceptable where this runs.
target = SkyCoord.from_name(target_name)

# Sample every 1 day for a smooth, continuous "carpet" effect.
# The number of samples is derived from the calendar rather than hard-coded
# to 365, so a leap year keeps its 31 December column.
days_in_year = (
    np.datetime64(f"{year + 1}-01-01") - np.datetime64(f"{year}-01-01")
) // np.timedelta64(1, "D")
dates = Time(f"{year}-01-01") + np.arange(days_in_year) * u.day

# --- Color Palette (Clean Documentation) ---
# Canvas and text
BG_COLOR = "#FFFFFF"        # figure and axes background (pure white)
TEXT_COLOR = "#111827"      # titles, labels and tick text (near-black)
GRID_COLOR = "#D1D5DB"      # gridlines, spines and legend frame (subtle grey)
# Data layers
NIGHT_COLOR = "#E5E7EB"     # dark-time slices (soft light grey)
TRACK_COLOR = "#1D4ED8"     # visibility tracks (bold, high-contrast blue)

# --- Plotting ---
plt.figure(figsize=(10, 6))


def _hours_since(moment, reference):
    """Return the time elapsed from *reference* to *moment* in decimal hours."""
    return (moment - reference).to(u.hour).value


# (ISO date, repr of exception) for every date that could not be drawn, so
# that skipped nights are reported instead of vanishing silently.
skipped = []

for date in dates:
    # Visibility windows for this date; each one is expected to expose
    # ``start_time`` and ``end_time`` (used below).
    windows = get_visibility_windows(target, date)
    try:
        # Calculate twilight limits: the next astronomical evening twilight
        # after ``date`` and the morning twilight that follows it.
        eve = observer.twilight_evening_astronomical(date, which='next')
        morn = observer.twilight_morning_astronomical(eve, which='next')

        # Base time: 10:00 UTC (12:00 SAST) to keep the night continuous
        base = Time(f"{date.iso.split()[0]} 10:00:00")
        day = date.datetime  # x position shared by this date's slices

        # Plot Dark Time as a vertical slice.
        plt.plot([day, day],
                 [_hours_since(eve, base), _hours_since(morn, base)],
                 color=NIGHT_COLOR, alpha=1.0, lw=2, zorder=2)

        # Plot SALT Visibility Tracks.  They sit on a higher zorder than the
        # dark-time slices: with equal zorder, artists are drawn in insertion
        # order, so the next date's (overlapping, 2 pt wide) dark-time slice
        # would paint over the edge of this date's track.
        for w in windows:
            plt.plot([day, day],
                     [_hours_since(w.start_time, base),
                      _hours_since(w.end_time, base)],
                     color=TRACK_COLOR, lw=2, alpha=0.9, zorder=3)
    except Exception as exc:
        # Best effort: one bad date must not abort the whole plot, but keep
        # a record so the failure is visible to the user.
        skipped.append((date.iso.split()[0], repr(exc)))
        continue

if skipped:
    first_date, first_error = skipped[0]
    warnings.warn(
        f"Skipped {len(skipped)} of {len(dates)} dates while plotting; "
        f"first failure on {first_date}: {first_error}",
        RuntimeWarning,
    )

# --- Formatting & Styling ---
# Style the axes that the plotting section drew on.
ax = plt.gca()
plt.gcf().patch.set_facecolor(BG_COLOR)
ax.set_facecolor(BG_COLOR)

# Title and axis labels
ax.set_title(
    f"Preliminary Annual Visibility Cycle: {target_name} (SALT)",
    color=TEXT_COLOR, fontsize=13, pad=12, fontweight='bold',
)
ax.set_xlabel("Date", color=TEXT_COLOR, fontsize=11, fontweight='500')
ax.set_ylabel("Hours from Noon SAST", color=TEXT_COLOR, fontsize=11, fontweight='500')

# One major tick per month, labelled with the abbreviated month name
ax.xaxis.set_major_locator(mdates.MonthLocator())
ax.xaxis.set_major_formatter(mdates.DateFormatter('%b'))
ax.tick_params(colors=TEXT_COLOR, which='both', labelsize=10)

# Thin frame in the grid colour
for frame_edge in ax.spines.values():
    frame_edge.set_color(GRID_COLOR)
    frame_edge.set_linewidth(1)

# Horizontal dotted gridlines only; y runs from 4 h (top) down to 20 h
# (bottom), i.e. the axis is inverted so time increases downwards.
ax.grid(True, axis='y', alpha=0.6, color=GRID_COLOR, linestyle=':')
ax.set_ylim(20, 4)

# Custom legend: proxy lines stand in for the many per-date slices
custom_lines = [
    Line2D([0], [0], color=layer_color, lw=4)
    for layer_color in (NIGHT_COLOR, TRACK_COLOR)
]
legend_labels = ['Astronomical Dark Time', 'Est. SALT Visibility']
legend = ax.legend(
    custom_lines, legend_labels,
    loc='upper right', framealpha=1.0,
    facecolor=BG_COLOR, edgecolor=GRID_COLOR, fontsize=10,
)
for label_text in legend.get_texts():
    label_text.set_color(TEXT_COLOR)

plt.tight_layout()
plt.show()