#!/usr/bin/env python3
# ===================================================================
#  Atlas Flash Tool (cross-platform) — Print and Play Creative Mfg.
#
#  Flashes a pre-compiled Atlas firmware binary to a Waveshare
#  ESP32-S3 Mini. Works on Windows, macOS, and Linux.
#
#  REQUIREMENTS:
#    - Python 3.7+
#    - esptool  (install with:  pip install esptool)
#    - pyserial (install with:  pip install pyserial)  [optional,
#      enables automatic port detection]
#
#  SETUP (once):
#    1. In Arduino IDE: open atlas_vX.Y.Z.ino, set the board to
#       "ESP32S3 Dev Module" with all Tools settings from
#       ATLAS_BUILD_GUIDE.md, then Sketch -> Export Compiled Binary.
#    2. Copy the three generated .bin files into the SAME folder as
#       this script:
#         atlas_vX.Y.Z.ino.bootloader.bin
#         atlas_vX.Y.Z.ino.partitions.bin
#         atlas_vX.Y.Z.ino.bin
#    3. Edit FIRMWARE_NAME below to match your version.
#
#  USAGE (each board):
#    python flash_atlas.py
#    - Plug in a board, pick the port, wait ~10s, repeat.
#    - Or run  python flash_atlas.py --batch  to loop continuously:
#      it waits for each new board, flashes it, and waits for the
#      next one until you press Ctrl+C.
# ===================================================================

import os
import sys
import time
import subprocess

# ---- EDIT THIS to match your firmware version ----
# Base name of the exported Arduino binaries (without ".ino.<suffix>").
FIRMWARE_NAME = "atlas_v5_19_1"

# ESP32-S3 standard flash offsets, kept as the hex strings esptool expects
# on its command line.
OFFSET_BOOTLOADER = "0x0"
OFFSET_PARTITIONS = "0x8000"
OFFSET_APP = "0x10000"

# Absolute directory holding this script; the .bin files are looked up here.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]


def bin_path(suffix):
    """Return the path of ``<FIRMWARE_NAME>.ino.<suffix>`` inside SCRIPT_DIR.

    *suffix* is e.g. ``"bootloader.bin"``, ``"partitions.bin"`` or ``"bin"``.
    """
    filename = "{}.ino.{}".format(FIRMWARE_NAME, suffix)
    return os.path.join(SCRIPT_DIR, filename)


def check_binaries():
    """Report whether all three firmware images sit next to this script.

    Returns True when the bootloader, partition-table and application
    binaries all exist. Otherwise prints the missing file names plus a
    hint on how to export them, and returns False.
    """
    wanted = ("bootloader.bin", "partitions.bin", "bin")
    absent = [
        f"{FIRMWARE_NAME}.ino.{sfx}"
        for sfx in wanted
        if not os.path.isfile(bin_path(sfx))
    ]
    if not absent:
        return True

    print("ERROR: missing firmware binaries:")
    for name in absent:
        print(f"   {name}")
    print()
    print("Fix: in Arduino IDE do Sketch -> Export Compiled Binary,")
    print("then copy the three .bin files into this folder.")
    print(f"Also confirm FIRMWARE_NAME='{FIRMWARE_NAME}' matches them.")
    return False


def list_ports():
    """Enumerate serial ports as a list of ``(device, description)`` pairs.

    Returns None (rather than an empty list) when pyserial cannot be
    imported, so callers can tell "no pyserial" apart from "no ports".
    """
    try:
        from serial.tools import list_ports as port_lister
        found = port_lister.comports()
    except ImportError:
        return None
    return [(info.device, info.description) for info in found]


def pick_port():
    """Ask the user which serial port to flash and return its name.

    With pyserial available and ports detected, the ports are listed and
    the user picks one by index or types a port name directly. Without
    pyserial, or with no ports detected, the name is typed manually.

    Returns:
        The chosen port name. An empty string means the user entered
        nothing, which callers treat as "abort".

    A numeric answer that is not a valid index re-prompts. Previously it
    was passed on to esptool as a bogus port name such as "7".
    """
    ports = list_ports()
    if ports is None:
        print("(pyserial not installed — automatic port list unavailable)")
        print("Install it with:  pip install pyserial")
        print()
        return input("Enter port manually (e.g. COM7 or /dev/ttyUSB0): ").strip()

    if not ports:
        print("No serial ports detected. Is the board plugged in?")
        return input("Enter port manually anyway, or press Enter to abort: ").strip()

    print("Detected serial ports:")
    for i, (dev, desc) in enumerate(ports):
        print(f"  [{i}] {dev}  —  {desc}")
    print()
    while True:
        choice = input("Pick a number, or type a port name directly: ").strip()
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as "²" that int() rejects with ValueError.
        if not choice.isdecimal():
            return choice  # a port name, or "" to abort
        index = int(choice)
        if index < len(ports):
            return ports[index][0]
        print(f"No port numbered {index}; choose 0-{len(ports) - 1}.")


def wait_for_new_board(known_ports):
    """Poll the serial-port list until a port outside *known_ports* appears.

    Args:
        known_ports: set of device names to ignore (already present).

    Returns:
        The device name of the newly appeared port. If several appear at
        once, the alphabetically first is returned. When pyserial is not
        available, falls back to asking the user via pick_port().
    """
    print("Waiting for a board to be plugged in... (Ctrl+C to stop)")
    while True:
        listing = list_ports()
        if listing is None:
            # Auto-detection impossible without pyserial: ask the user.
            return pick_port()
        fresh = sorted({device for device, _desc in listing} - known_ports)
        if not fresh:
            time.sleep(0.4)
            continue
        device = fresh[0]
        print(f"  New board detected on {device}")
        time.sleep(0.5)  # let the port settle before esptool opens it
        return device


def flash(port):
    """Write bootloader, partition table and app image to *port* via esptool.

    esptool is run as ``python -m esptool`` under the same interpreter as
    this script, with its output going straight to the terminal.

    Returns:
        True if esptool exited with status 0, False otherwise.
    """
    images = (
        (OFFSET_BOOTLOADER, bin_path("bootloader.bin")),
        (OFFSET_PARTITIONS, bin_path("partitions.bin")),
        (OFFSET_APP, bin_path("bin")),
    )
    args = [
        sys.executable, "-m", "esptool",
        "--chip", "esp32s3",
        "--port", port,
        "--baud", "921600",
        "--before", "default_reset",
        "--after", "hard_reset",
        "write_flash",
        "-z",
    ]
    # "keep" asks esptool not to override these settings in the image header.
    for option in ("--flash_mode", "--flash_freq", "--flash_size"):
        args.extend([option, "keep"])
    for offset, path in images:
        args.extend([offset, path])

    print(f"\nFlashing {FIRMWARE_NAME} to {port} ...\n")
    completed = subprocess.run(args)
    return completed.returncode == 0


def print_fail_help():
    """Print the failure banner followed by common troubleshooting steps."""
    rule = "=" * 42
    lines = [
        "",
        rule,
        "  FLASH FAILED",
        rule,
        "Common fixes:",
        " - Wrong port: re-check the port name",
        " - Port busy: close Arduino IDE Serial Monitor / other tools",
        " - Board not in download mode: hold BOOT, tap RESET,",
        "   release BOOT, then try again",
        " - Bad USB cable: use a known data-capable cable",
        " - esptool not installed: pip install esptool",
        "",
    ]
    print("\n".join(lines))


def print_success():
    """Print the banner shown after a successful flash."""
    rule = "=" * 42
    print("\n".join([
        "",
        rule,
        "  FLASH COMPLETE",
        rule,
        "The board has rebooted and is running Atlas.",
        "",
    ]))


def single_mode():
    """Flash one board on an interactively chosen port.

    Prints the success banner or the troubleshooting help afterwards.

    Returns:
        True if esptool reported success. False if flashing failed or no
        port was chosen.
    """
    port = pick_port()
    if not port:
        print("No port chosen. Exiting.")
        return False
    # A plain if/else instead of a conditional expression evaluated only for
    # its side effects.
    if flash(port):
        print_success()
        return True
    print_fail_help()
    return False


def _wait_for_removal(port, settle=2.0, poll=0.4):
    """Block until *port* is no longer listed; return the ports present then.

    Sleeps *settle* seconds before polling. After esptool's hard reset, a
    board on native USB may drop off the bus briefly while it reboots.
    Polling immediately could mistake that gap for an unplug, and the same
    board would then be re-detected as "new" and flashed again.
    NOTE(review): confirm the settle time against real hardware.
    """
    time.sleep(settle)
    while True:
        current = {dev for dev, _ in list_ports()}
        if port not in current:
            return current
        time.sleep(poll)


def batch_mode():
    """Flash boards one after another as they are plugged in.

    Ports present at start-up (hubs, other devices) form the baseline and
    are ignored. Each port that appears afterwards is flashed, and then
    the loop waits for that board to be unplugged before watching for the
    next one. Runs until Ctrl+C, then prints how many boards were flashed
    successfully. Requires pyserial for port detection.
    """
    print("BATCH MODE — flash many boards in a row.")
    print("Each time a new board is plugged in, it gets flashed automatically.")
    print("Press Ctrl+C to stop.\n")
    count = 0
    # Establish the baseline set of ports already present (hubs, etc.)
    base = list_ports()
    if base is None:
        print("Batch mode needs pyserial for auto-detection.")
        print("Install it with:  pip install pyserial")
        return
    known = {d for d, _ in base}
    try:
        while True:
            port = wait_for_new_board(known)
            if flash(port):
                count += 1
                print_success()
                print(f"  Total flashed this session: {count}")
            else:
                print_fail_help()
            print("Unplug this board. Waiting for the next one...\n")
            # Wait until the just-flashed board is really gone, so it is not
            # re-detected while still attached. The ports present at that
            # moment then become the new baseline.
            known = _wait_for_removal(port)
    except KeyboardInterrupt:
        print(f"\n\nDone. Flashed {count} board(s) this session.")


def main():
    """Entry point: parse options, show the banner, verify binaries, then flash.

    Runs batch mode when ``--batch`` is given, otherwise single-board mode.
    Exits with status 1 if any firmware binary is missing. argparse exits
    with a usage error on unknown options.
    """
    import argparse  # local import: only needed when run as a script

    parser = argparse.ArgumentParser(
        description=f"Flash {FIRMWARE_NAME} to a Waveshare ESP32-S3 Mini.")
    parser.add_argument(
        "--batch", action="store_true",
        help="keep running and flash every newly plugged-in board until Ctrl+C")
    # parse_args() rejects mistyped options (e.g. "--bacth") instead of
    # silently falling back to single-board mode, and provides --help.
    args = parser.parse_args()

    print()
    print("=" * 42)
    print("  ATLAS FLASH TOOL")
    print(f"  Firmware: {FIRMWARE_NAME}")
    print("  Print and Play Creative Manufacturing")
    print("=" * 42)
    print()

    if not check_binaries():
        sys.exit(1)

    if args.batch:
        batch_mode()
    else:
        single_mode()


if __name__ == "__main__":
    # Run the tool only when executed as a script, not when imported.
    main()
