# -*- coding: utf-8 -*-
#
# oc-sso-win.py — Windows-oriented SSO (SAML v2) wrapper for OpenConnect (AnyConnect)
#
# Defaults:
#  - Browser: Edge
#  - Dedicated persistent profile (auto-created): %LOCALAPPDATA%\oc-sso-profile
#
# Requirements:
#   Python 3.9+ (xml.etree.ElementTree.indent is used)
#   pip install websocket-client
#
# Example:
#   python oc-sso-win.py https://example.com --openconnect "C:\Program Files\OpenConnect\openconnect.exe" -- -v

from __future__ import annotations

import argparse
import json
import logging as log
import os
import shutil
import socket
import subprocess
import tempfile
from time import sleep, time
from typing import Optional, Tuple, Dict, Any, List
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import Request, urlopen
from xml.etree import ElementTree as ET

import websocket  # websocket-client


def default_profile_dir() -> str:
    """Return the path of the dedicated persistent browser profile directory.

    The base directory is ``%LOCALAPPDATA%`` when that variable is set and
    non-empty; otherwise ``~\\AppData\\Local`` expanded via ``os.path.expanduser``.
    The directory is not created here.
    """
    base = os.environ.get("LOCALAPPDATA")
    if not base:
        base = os.path.expanduser(r"~\AppData\Local")
    return os.path.join(base, "oc-sso-profile")


def ensure_dir(path: str) -> str:
    """Create *path* (including missing parents) if needed and return it unchanged.

    Raises ``FileExistsError`` if *path* exists but is not a directory.
    """
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)
    return path


def main() -> int:
    """Command-line entry point: SSO-authenticate, then run OpenConnect.

    Arguments after ``--`` (and any other arguments this parser does not
    recognise) are forwarded to openconnect unchanged.

    The session token is handed to openconnect on stdin (``--cookie-on-stdin``)
    unless the environment variable ``OC_SSO_COOKIE_MODE`` is ``arg``, in which
    case it is passed with ``--cookie`` on the command line.

    Returns:
        Process exit status: openconnect's own status (0 on success); 65 on an
        HTTP error during authentication; 70 for an invalid server address, a
        missing openconnect binary, an authentication error, or failure to start
        openconnect; 130 on Ctrl+C.
    """
    p = argparse.ArgumentParser(
        prog=os.path.basename(__file__),
        description="Windows-oriented SSO (SAML v2) wrapper for OpenConnect using Chromium DevTools Protocol.",
        epilog="All args after '--' are passed through to openconnect.",
        allow_abbrev=False,
    )
    p.add_argument("-L", "--log-level", choices=["ERROR", "WARNING", "INFO", "DEBUG"], default="INFO")
    p.add_argument("--version-string", dest="version", default="5.1.10.233")
    p.add_argument("--sso-timeout", type=int, default=180)
    p.add_argument("--attempts", type=int, default=5)
    p.add_argument("--authgroup", default=None)
    p.add_argument("--group-access", default=None)
    p.add_argument("--openconnect", dest="openconnect_path", default=None)

    # Default browser: Edge
    p.add_argument("--browser", choices=["chrome", "edge", "chromium"], default="edge")
    p.add_argument("--browser-path", default=None)

    # Dedicated persistent profile (default) or user-specified
    p.add_argument("--user-data-dir", default=None,
                   help=("Chromium user data dir to use. If omitted, a dedicated persistent profile is used: "
                         r"%LOCALAPPDATA%\oc-sso-profile (auto-created)."))
    p.add_argument("--profile-directory", default=None,
                   help="Profile directory name inside user data dir (e.g. Default, Profile 1). Optional.")
    p.add_argument("--keep-temp-profile", action="store_true",
                   help="(debug) If using a temporary profile, do not delete it on exit.")
    p.add_argument("--no-page-nudge", action="store_true",
                   help="Do not send Page.navigate after attaching (normally harmless).")

    p.add_argument("server")

    args, oc_args = p.parse_known_args()

    # If the user separates OpenConnect args with '--', argparse keeps it; remove it.
    if oc_args and oc_args[0] == "--":
        oc_args = oc_args[1:]

    # Default to dedicated persistent profile if none specified
    if not args.user_data_dir:
        args.user_data_dir = ensure_dir(default_profile_dir())
        if args.profile_directory is None:
            args.profile_directory = "Default"

    log.basicConfig(format="%(levelname)s: %(message)s", level=getattr(log, args.log_level))

    # Validate local prerequisites *before* the interactive browser sign-in, so
    # a malformed server string or a missing openconnect binary is reported
    # cleanly up front rather than as a traceback after the user has logged in.
    try:
        vpn_url = normalize_server_to_url(args.server)
        oc_bin = resolve_openconnect(args.openconnect_path)
    except (ValueError, FileNotFoundError) as e:
        log.error("Error: %s", e)
        return 70

    try:
        log.info("Authenticating to VPN endpoint %r", vpn_url)
        vpn_url, complete = authenticate(vpn_url, args)
    except HTTPError as e:
        log.error("HTTP error: %r", e)
        return 65
    except KeyboardInterrupt:
        log.warning("Exiting on Ctrl+C")
        return 130
    except Exception as e:
        log.error("Error: %s", e)
        # Full traceback only when running with -L DEBUG.
        log.debug("Error details:", exc_info=True)
        return 70

    # Launch openconnect with the session token
    cmd = [
        oc_bin,
        "--protocol=anyconnect",
        "--useragent", f"AnyConnect Windows_64 {args.version}",
        "--version-string", args.version,
    ]
    # Don't add our authgroup if the user already passed one through to
    # openconnect, in long (--authgroup[=X]) or short (-g X / -gX) form.
    user_set_authgroup = any(a.startswith("--authgroup") or a.startswith("-g") for a in oc_args)
    if args.authgroup and not user_set_authgroup:
        cmd.append(f"--authgroup={args.authgroup}")

    cookie_mode = os.environ.get("OC_SSO_COOKIE_MODE", "stdin").lower()
    if cookie_mode == "arg":
        # NOTE: the token is then visible in the process list; stdin is the default.
        cmd += ["--cookie", complete.session_token]
        cookie_input = None
    else:
        cmd.append("--cookie-on-stdin")
        cookie_input = complete.session_token.encode("utf-8")

    cmd += list(oc_args)
    cmd.append(vpn_url)

    log.debug("Starting OpenConnect: %r", cmd)
    try:
        subprocess.run(cmd, input=cookie_input, check=True)
        return 0
    except subprocess.CalledProcessError as e:
        return e.returncode or 1
    except KeyboardInterrupt:
        return 130
    except OSError as e:
        # e.g. the resolved binary is not executable or vanished meanwhile.
        log.error("Could not start openconnect %r: %s", oc_bin, e)
        return 70


def normalize_server_to_url(server: str) -> str:
    """Turn a user-supplied server string into a VPN URL.

    - A bare host (optionally with a path) gets an ``https://`` prefix.
    - An ``http://`` URL is returned unchanged (after stripping whitespace).
    - An ``https://`` URL is rebuilt from host, path (``/`` if empty) and query;
      any fragment is dropped, since it is never sent to the server.

    Scheme detection is case-insensitive, so ``HTTPS://host`` is recognised as
    a URL rather than being treated as a bare host.

    Raises:
        ValueError: if *server* is empty/whitespace, or is a URL without a host.
    """
    s = server.strip()
    if not s:
        raise ValueError(f"Invalid server address: {server!r}")
    if s.lower().startswith(("https://", "http://")):
        u = urlparse(s)  # urlparse lower-cases the scheme
        if not u.netloc:
            raise ValueError(f"Invalid URL: {server!r}")
        if u.scheme == "http":
            return s
        url = "https://" + u.netloc + (u.path or "/")
        # Keep the query string, consistent with the http branch above.
        if u.query:
            url += "?" + u.query
        return url
    return "https://" + s


def authenticate(vpn_url: str, args) -> Tuple[str, "AuthCompleteResponse"]:
    """Run the AnyConnect aggregate-auth SSO (SAML v2) exchange.

    Steps:
      1. GET *vpn_url* and adopt the final URL after redirects.
      2. POST an ``init`` request; the reply names the SSO login URL and the
         browser cookie that will carry the SSO token.
      3. Open a browser on the login URL and wait (``args.sso_timeout`` s)
         for that cookie.
      4. POST the token back (up to ``args.attempts`` tries) until the server
         answers with a ``complete`` response.

    Args:
        vpn_url: Normalized VPN URL.
        args: Parsed CLI namespace (version, authgroup, group_access, browser
            and profile options, no_page_nudge, sso_timeout, attempts).

    Returns:
        ``(final_vpn_url, complete_response)``.

    Raises:
        RuntimeError: on protocol failures (unexpected/erroneous replies, no
            token cookie within the timeout, no ``complete`` reply).
        Exceptions from urllib (e.g. ``HTTPError``) and from the browser
        helper propagate unchanged.
    """
    # Follow any HTTP redirects so all later POSTs target the final endpoint.
    with urlopen(vpn_url) as landing:
        vpn_url = landing.url
    log.debug("Auth target URL %r", vpn_url)

    headers = {
        "User-Agent": f"AnyConnect Windows_64 {args.version}",
        "Accept": "*/*",
        "Accept-Encoding": "identity",
        "X-Transcend-Version": "1",
        "X-Aggregate-Auth": "1",
        "Content-Type": "text/xml; charset=utf-8",
        "Cache-Control": "no-store",
        "Pragma": "no-cache",
    }

    def post(body: bytes) -> bytes:
        # POST an XML body to the VPN URL; the response is closed after reading.
        with urlopen(Request(vpn_url, body, headers)) as resp:
            return resp.read()

    def auth_init() -> "AuthRequestResponse":
        body = AuthInitRequest(
            version=args.version,
            vpn_url=vpn_url,
            authgroup=args.authgroup,
            group_access=args.group_access,
        ).to_bytes()
        parsed = parse_response(post(body))
        if not isinstance(parsed, AuthRequestResponse):
            raise RuntimeError("Unexpected init response")
        if parsed.error:
            raise RuntimeError(f"Server error: {parsed.error}")
        return parsed

    auth_info = auth_init()

    with ChromiumBrowser(
        browser=args.browser,
        browser_path=args.browser_path,
        user_data_dir=args.user_data_dir,
        profile_directory=args.profile_directory,
        keep_temp_profile=args.keep_temp_profile,
        startup_url=auth_info.login,
    ) as br:
        log.info("Complete sign-in in the opened browser window...")

        if not args.no_page_nudge:
            # Best effort: the login URL is already the startup URL.
            try:
                br.navigate(auth_info.login)
            except Exception as e:
                log.debug("Page.navigate nudge failed (ignored): %s", e)

        urls = [auth_info.login, vpn_url]
        token = br.wait_cookie_value(auth_info.cookie, args.sso_timeout, urls=urls)
        if not token:
            raise RuntimeError(f"Did not receive SSO token cookie {auth_info.cookie!r} in {args.sso_timeout}s")

    finish_body = AuthFinishRequest(args.version, auth_info, token).to_bytes()

    for attempt in range(1, args.attempts + 1):
        resp = post(finish_body)
        try:
            parsed = parse_response(resp)
        except Exception as e:
            log.debug("Could not parse auth-finish response: %s", e)
            parsed = None
        if isinstance(parsed, AuthCompleteResponse):
            return vpn_url, parsed
        # No point announcing a retry (or sleeping) after the final attempt.
        if attempt < args.attempts:
            log.warning("Auth finish returned unexpected response; retrying...")
            sleep(1)

    raise RuntimeError("Could not finish authentication")


# ---------------- XML ---------------

class AuthInitRequest:
    """XML body of the initial aggregate-auth ``init`` request.

    The request advertises ``single-sign-on-v2`` as the only auth method.
    ``group-select`` carries *authgroup* (empty when None). ``group-access``
    carries *group_access*, falling back to *vpn_url*.
    """

    def __init__(self, version: str, vpn_url: str, authgroup: Optional[str], group_access: Optional[str]):
        root = ET.Element("config-auth", {"client": "vpn", "type": "init", "aggregate-auth-version": "2"})
        # (tag, attributes, text) for the flat children, in document order.
        leaves = (
            ("version", {"who": "vpn"}, version),
            ("device-id", {}, "win"),
            ("group-select", {}, authgroup or ""),
            ("group-access", {}, group_access or vpn_url),
        )
        for tag, attrs, text in leaves:
            ET.SubElement(root, tag, attrs).text = text
        capabilities = ET.SubElement(root, "capabilities")
        ET.SubElement(capabilities, "auth-method").text = "single-sign-on-v2"
        self.xml = root

    def to_bytes(self) -> bytes:
        """Serialize as pretty-printed UTF-8 XML with a declaration and trailing newline."""
        ET.indent(self.xml)
        serialized = ET.tostring(self.xml, encoding="UTF-8", xml_declaration=True)
        return serialized + b"\n"


class AuthFinishRequest:
    """XML body of the ``auth-reply`` request that submits the SSO token.

    The body includes empty ``session-token`` / ``session-id`` elements and
    appends the ``opaque`` element taken from the server's auth-request. The
    element object itself is appended, not a copy. The token goes in ``auth/sso-token``.
    """

    def __init__(self, version: str, auth_info: "AuthRequestResponse", sso_token: str):
        root = ET.Element("config-auth", {"client": "vpn", "type": "auth-reply", "aggregate-auth-version": "2"})
        ET.SubElement(root, "version", {"who": "vpn"}).text = version
        ET.SubElement(root, "device-id").text = "win"
        for placeholder in ("session-token", "session-id"):
            ET.SubElement(root, placeholder)
        root.append(auth_info.opaque)
        auth_node = ET.SubElement(root, "auth")
        ET.SubElement(auth_node, "sso-token").text = sso_token
        self.xml = root

    def to_bytes(self) -> bytes:
        """Serialize as pretty-printed UTF-8 XML with a declaration and trailing newline."""
        ET.indent(self.xml)
        serialized = ET.tostring(self.xml, encoding="UTF-8", xml_declaration=True)
        return serialized + b"\n"


def parse_response(response: bytes):
    """Parse an aggregate-auth XML reply into a response object.

    The root ``type`` attribute selects the class:
    ``auth-request`` gives ``AuthRequestResponse``, whose ``auth/@id`` must be
    ``"main"``; ``complete`` gives ``AuthCompleteResponse``, whose ``auth/@id``
    must be ``"success"``. The parsed id/message (and title, if any) are logged
    at INFO level.

    Raises:
        RuntimeError: for an unknown ``type`` or an unexpected ``auth/@id``.
    """
    root = ET.fromstring(response)
    kind = root.get("type")
    dispatch = {
        "auth-request": (AuthRequestResponse, "main"),
        "complete": (AuthCompleteResponse, "success"),
    }
    if kind not in dispatch:
        raise RuntimeError(f"Unknown response type: {kind!r}")
    response_cls, expected_id = dispatch[kind]
    result = response_cls(root)

    title_suffix = f", title={getattr(result, 'title', None)!r}" if hasattr(result, "title") else ""
    log.info("Response received: id=%r, message=%r%s",
             getattr(result, "id", None),
             getattr(result, "message", None),
             title_suffix)

    actual_id = getattr(result, "id", None)
    if actual_id != expected_id:
        raise RuntimeError(f"Unexpected response id: {actual_id!r}")
    return result


class AuthRequestResponse:
    """Parsed ``type="auth-request"`` reply carrying the SSO v2 login details.

    Attributes (read from the server XML):
        id: ``auth/@id``.
        title, message: text of ``auth/title`` / ``auth/message`` ("" if absent).
        error: text of ``auth/error``, or None.
        opaque: the ``opaque`` element (later appended to the finish request).
        login: SSO login URL from ``auth/sso-v2-login``.
        cookie: browser cookie name from ``auth/sso-v2-token-cookie-name``.

    Raises:
        RuntimeError: if the ``auth`` element is missing, or if the login URL,
            cookie name or ``opaque`` element is missing. In that case the
            server's ``auth/error`` text is reported when it is present.
    """

    def __init__(self, xml: ET.Element):
        auth = xml.find("auth")
        if auth is None:
            raise RuntimeError("Missing auth element in auth-request response")
        self.id = auth.get("id")
        self.title = (xml.findtext("auth/title") or "")
        self.message = (xml.findtext("auth/message") or "")
        self.error = xml.findtext("auth/error")
        self.opaque = xml.find("opaque")
        self.login = xml.findtext("auth/sso-v2-login")
        self.cookie = xml.findtext("auth/sso-v2-token-cookie-name")
        if not self.login or not self.cookie or self.opaque is None:
            # A rejection reply typically lacks the SSO fields; surface the
            # server's own error text instead of a generic "missing fields".
            if self.error:
                raise RuntimeError(f"Server error: {self.error}")
            raise RuntimeError("Missing SSO fields in auth-request response")


class AuthCompleteResponse:
    """Parsed ``type="complete"`` reply that carries the VPN session token.

    Attributes:
        id: ``auth/@id``.
        message: text of ``auth/message`` ("" if absent).
        session_token: text of the top-level ``session-token`` element.

    Raises:
        RuntimeError: if the ``auth`` element or a non-empty ``session-token``
            is missing.
    """

    def __init__(self, xml: ET.Element):
        auth = xml.find("auth")
        if auth is None:
            # Previously surfaced as an opaque AttributeError on None.
            raise RuntimeError("Missing auth element in complete response")
        self.id = auth.get("id")
        self.message = (xml.findtext("auth/message") or "")
        self.session_token = xml.findtext("session-token")
        if not self.session_token:
            raise RuntimeError("Missing session-token in complete response")


# ------------- Browser (CDP) -------------

class BrowserException(Exception):
    """Raised when the browser / DevTools Protocol session cannot be driven as expected."""


class ChromiumBrowser:
    """Launch a Chromium-family browser and drive it over the DevTools Protocol.

    Intended for use as a context manager:

    - ``__enter__`` starts the browser on a free localhost remote-debugging
      port and attaches a WebSocket to a page target.
    - ``__exit__`` closes the socket, terminates the browser (waiting for it to
      exit) and removes the temporary profile, if one was created.

    Args:
        browser: "chrome", "edge" or "chromium" (see resolve_browser_executable).
        browser_path: Explicit executable, or None to auto-detect.
        user_data_dir: Profile root to use; None uses a temporary profile.
        profile_directory: Optional ``--profile-directory`` value.
        keep_temp_profile: With a temporary profile, leave it on disk at exit.
        startup_url: URL opened in the new browser window.
    """

    # Seconds to wait for the browser process to exit after terminate()/kill().
    _STOP_TIMEOUT_S = 5

    def __init__(self,
                 browser: str,
                 browser_path: Optional[str],
                 user_data_dir: Optional[str] = None,
                 profile_directory: Optional[str] = None,
                 keep_temp_profile: bool = False,
                 startup_url: str = "about:blank"):
        self.browser = browser
        self.browser_path = browser_path
        self.user_data_dir = user_data_dir
        self.profile_directory = profile_directory
        self.keep_temp_profile = keep_temp_profile
        self.startup_url = startup_url

        self.proc: Optional[subprocess.Popen] = None
        # Auto-deleted temporary profile; None for caller-supplied or kept profiles.
        self.tmpdir: Optional[tempfile.TemporaryDirectory] = None
        self.devtools_base: Optional[str] = None  # "http://127.0.0.1:<port>"
        self.ws = None  # websocket-client connection to the attached target
        self.msg_id = 0  # id of the last CDP command sent

    def __enter__(self):
        try:
            self._start_browser(initial_url=self.startup_url)
        except BaseException:
            # Python does not call __exit__ when __enter__ raises, so release the
            # browser process / temporary profile here before propagating.
            self.__exit__(None, None, None)
            raise
        return self

    def __exit__(self, exc_type, exc, tb):
        try:
            if self.ws:
                try:
                    self.ws.close()
                except Exception:
                    pass
            if self.proc:
                self._stop_process()
        finally:
            if self.tmpdir and (not self.keep_temp_profile):
                try:
                    self.tmpdir.cleanup()
                except Exception:
                    pass

    def _stop_process(self):
        """Terminate the browser and wait for it, killing it if it lingers.

        Waiting matters because a still-running browser keeps its profile
        files open, and the temporary-profile cleanup could not remove them.
        """
        proc = self.proc
        if proc.poll() is not None:
            return
        try:
            proc.terminate()
            proc.wait(timeout=self._STOP_TIMEOUT_S)
        except subprocess.TimeoutExpired:
            try:
                proc.kill()
                proc.wait(timeout=self._STOP_TIMEOUT_S)
            except Exception:
                pass
        except Exception:
            pass

    def _start_browser(self, initial_url: str):
        # Pick a free port by binding to port 0, then release it for the browser.
        # (Another process could grab it in between; accepted as unlikely.)
        with socket.create_server(("127.0.0.1", 0)) as s:
            _, port = s.getsockname()

        if self.user_data_dir:
            user_data_dir = ensure_dir(self.user_data_dir)
            self.tmpdir = None
        elif self.keep_temp_profile:
            # TemporaryDirectory removes itself when garbage-collected or at
            # interpreter exit, so a profile that must survive uses mkdtemp().
            self.tmpdir = None
            user_data_dir = tempfile.mkdtemp(prefix="oc-sso-win-")
            log.info("Temporary browser profile will be kept at %s", user_data_dir)
        else:
            self.tmpdir = tempfile.TemporaryDirectory(prefix="oc-sso-win-")
            user_data_dir = self.tmpdir.name

        exe = resolve_browser_executable(self.browser, self.browser_path)

        self.devtools_base = f"http://127.0.0.1:{port}"
        allow_origin = self.devtools_base

        args = [
            exe,
            "--new-window",
            f"--remote-debugging-port={port}",
            f"--remote-allow-origins={allow_origin}",
            f"--user-data-dir={user_data_dir}",
        ]
        if self.profile_directory:
            args.append(f"--profile-directory={self.profile_directory}")

        args += [
            "--no-first-run",
            "--no-default-browser-check",
            "--disable-background-networking",
            "--disable-sync",
            "--disable-extensions",
            "--disable-popup-blocking",
            "--disable-features=TranslateUI",
            initial_url,
        ]

        log.debug("Starting browser: %r", args)
        self.proc = subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

        # Poll the DevTools HTTP endpoint (160 x 0.25 s, ~40 s plus request
        # time) until a target with a WebSocket URL is listed, then attach.
        for _ in range(160):
            try:
                with urlopen(f"{self.devtools_base}/json/list", timeout=2) as resp:
                    targets = json.loads(resp.read())
            except (OSError, ValueError):
                # Not listening yet, connection reset during startup, request
                # timeout (all OSError, incl. URLError), or a truncated/non-JSON
                # body (ValueError): retry shortly.
                targets = None
            if targets:
                page = _pick_page_target(targets)
                if page and page.get("webSocketDebuggerUrl"):
                    self.ws = websocket.create_connection(page["webSocketDebuggerUrl"])
                    self.ws.settimeout(10)
                    self._enable()
                    return
            sleep(0.25)

        exit_code = self.proc.poll()
        if exit_code is not None:
            raise BrowserException(
                f"Browser exited (code {exit_code}) before the DevTools endpoint became available; "
                f"another browser instance may already be using profile {user_data_dir!r}"
            )
        raise BrowserException("Could not connect to DevTools endpoint")

    def _enable(self):
        """Enable the Page, Network and Runtime CDP domains (best effort; errors ignored)."""
        for m in ("Page.enable", "Network.enable", "Runtime.enable"):
            try:
                self._cmd(m, quiet=True)
            except Exception:
                pass

    def _cmd(self, method: str, params: Optional[dict] = None, quiet: bool = False):
        """Send a CDP command and block until its response arrives.

        Events and responses for other ids received meanwhile are discarded.
        Returns the response's ``result``.

        Raises:
            BrowserException: on a CDP error reply, or if no matching reply
                arrives within about 10 s (checked between received messages).
        """
        self.msg_id += 1
        my_id = self.msg_id
        payload = {"id": my_id, "method": method}
        if params is not None:
            payload["params"] = params

        if not quiet:
            log.debug("CDP command: %s %s", method, "" if params is None else repr(params))

        self.ws.send(json.dumps(payload))

        deadline = time() + 10
        while True:
            if time() > deadline:
                raise BrowserException(f"Timeout waiting for CDP response to {method} (id={my_id})")
            msg = json.loads(self.ws.recv())

            if "id" not in msg:
                # Unsolicited CDP event (e.g. from Network.enable).
                if not quiet:
                    log.debug("CDP event: %s", msg.get("method"))
                continue

            if msg["id"] != my_id:
                continue

            if not quiet:
                log.debug("CDP response: %r", msg)

            if "result" in msg:
                return msg["result"]
            raise BrowserException(msg.get("error", msg))

    def navigate(self, url: str):
        """Navigate the attached page to *url* (``Page.navigate``)."""
        return self._cmd("Page.navigate", {"url": url}, quiet=False)

    def get_current_url(self) -> Optional[str]:
        """Return the URL of the current navigation-history entry, or None on any error."""
        try:
            res = self._cmd("Page.getNavigationHistory", quiet=True)
            idx = int(res["currentIndex"])
            return res["entries"][idx]["url"]
        except Exception:
            return None

    def _get_cookies_for_urls(self, urls: List[str]) -> List[Dict[str, Any]]:
        """Return cookies for *urls*.

        Falls back to ``Network.getAllCookies`` when the URL-scoped query fails
        or returns nothing. Returns [] if both fail.
        """
        try:
            res = self._cmd("Network.getCookies", {"urls": urls}, quiet=True)
            cookies = res.get("cookies", [])
            if cookies:
                return cookies
        except Exception:
            pass

        try:
            res = self._cmd("Network.getAllCookies", quiet=True)
            return res.get("cookies", [])
        except Exception:
            return []

    def wait_cookie_value(self, name: str, timeout_s: int, urls: List[str]) -> Optional[str]:
        """Poll once per second for a cookie named *name*; return its value.

        Each poll queries cookies for the page's current URL plus *urls*.
        Returns None if the cookie has not appeared within *timeout_s* seconds.
        """
        end = time() + timeout_s
        while time() < end:
            cur = self.get_current_url()
            url_list: List[str] = []
            if cur:
                url_list.append(cur)
            for u in urls:
                if u not in url_list:
                    url_list.append(u)

            cookies = self._get_cookies_for_urls(url_list)
            for c in cookies:
                if c.get("name") == name:
                    log.debug("Token cookie found on %s (domain=%s path=%s)",
                              cur or "?", c.get("domain"), c.get("path"))
                    return c.get("value")

            sleep(1)

        return None


def _pick_page_target(targets):
    for t in targets:
        if t.get("type") == "page":
            return t
    return targets[0] if targets else None


def resolve_browser_executable(browser: str, browser_path: Optional[str]) -> str:
    """Locate the Chromium-family browser executable to launch.

    Resolution order:
      1. *browser_path*, if given. It is used as-is when it is an existing
         file; otherwise it is looked up on PATH, so a bare name such as
         ``msedge`` works the same way ``--openconnect`` does.
      2. Well-known executable names for *browser* on PATH.
      3. Standard install locations under Program Files, Program Files (x86)
         and, when set, %LOCALAPPDATA%. For ``chromium``, the per-user
         ``Chromium\\Application\\chrome.exe`` location is tried first.

    Raises:
        FileNotFoundError: if no executable can be found.
    """
    if browser_path:
        if os.path.isfile(browser_path):
            return browser_path
        found = shutil.which(browser_path)
        if found:
            return found
        raise FileNotFoundError(f"Browser not found: {browser_path!r}")

    if browser == "chrome":
        candidates = ["chrome.exe", "chrome"]
    elif browser == "edge":
        candidates = ["msedge.exe", "msedge"]
    else:
        candidates = ["chromium.exe", "chromium", "chrome.exe", "msedge.exe"]

    for c in candidates:
        p = shutil.which(c)
        if p:
            return p

    pf = os.environ.get("ProgramFiles", r"C:\Program Files")
    pfx = os.environ.get("ProgramFiles(x86)", r"C:\Program Files (x86)")
    local = os.environ.get("LOCALAPPDATA", "")

    if browser in ("chrome", "chromium"):
        rel = r"Google\Chrome\Application\chrome.exe"
    else:
        rel = r"Microsoft\Edge\Application\msedge.exe"

    common = [os.path.join(pf, rel), os.path.join(pfx, rel)]
    # Only probe %LOCALAPPDATA% when it is set: joining onto "" would yield a
    # relative path silently resolved against the current working directory.
    if local:
        if browser == "chromium":
            common.insert(0, os.path.join(local, r"Chromium\Application\chrome.exe"))
        common.append(os.path.join(local, rel))

    for pth in common:
        if os.path.isfile(pth):
            return pth

    raise FileNotFoundError(f"Could not locate {browser} executable. Use --browser-path.")


def resolve_openconnect(openconnect_path: Optional[str]) -> str:
    """Locate the openconnect executable.

    An explicit *openconnect_path* is used as-is when it is an existing file,
    otherwise it is looked up on PATH. Without one, the function searches PATH
    for ``openconnect.exe`` / ``openconnect``. It then falls back to an
    ``openconnect.exe`` placed next to this script.

    Raises:
        FileNotFoundError: if nothing suitable is found.
    """
    if openconnect_path:
        found = openconnect_path if os.path.isfile(openconnect_path) else shutil.which(openconnect_path)
        if not found:
            raise FileNotFoundError(f"openconnect not found: {openconnect_path!r}")
        return found

    # map() is lazy, so the second name is only looked up if the first misses.
    on_path = next(filter(None, map(shutil.which, ("openconnect.exe", "openconnect"))), None)
    if on_path:
        return on_path

    script_dir = os.path.abspath(os.path.dirname(__file__))
    bundled = os.path.join(script_dir, "openconnect.exe")
    if os.path.isfile(bundled):
        return bundled

    raise FileNotFoundError("Could not locate openconnect.exe. Add to PATH or pass --openconnect.")


if __name__ == "__main__":
    # Use main()'s integer result as the process exit status.
    _exit_status = main()
    raise SystemExit(_exit_status)
