"""
APS Authentication module.
Handles 2-legged (client credentials) and 3-legged (OAuth2 authorization code) authentication.
"""

import logging
import os
import secrets
import time
from urllib.parse import urlencode, quote

import requests


BASE_URL = "https://developer.api.autodesk.com"

# --- 2-Legged Token (server-to-server) ---

# Single-slot cache for the most recently fetched 2-legged token.
# "expires_at" is a time.time() timestamp; 0 / a falsy token means "nothing cached".
_token_cache = dict(token=None, expires_at=0)


def get_2legged_token(client_id=None, client_secret=None, scopes="data:read viewables:read", timeout=30):
    """
    Get a 2-legged access token (client credentials flow).

    The token is cached until 60 seconds before it expires. The cache holds a
    single token and is reused only when the same client ID and scopes are
    requested again; any other combination fetches a fresh token, which then
    replaces the cached one.

    Args:
        client_id: APS Client ID (default: env APS_CLIENT_ID)
        client_secret: APS Client Secret (default: env APS_CLIENT_SECRET)
        scopes: Space-separated OAuth2 scopes
        timeout: Seconds to wait for the token endpoint (forwarded to requests)

    Returns:
        str: Access token

    Raises:
        requests.HTTPError: if the token endpoint returns an error status.
    """
    client_id = client_id or os.getenv("APS_CLIENT_ID")
    client_secret = client_secret or os.getenv("APS_CLIENT_SECRET")

    # Reuse the cached token only if it was issued for the same app and the
    # same scopes; otherwise a caller asking for e.g. "data:write" could be
    # handed a token that was issued for "data:read".
    if (
        _token_cache["token"]
        and _token_cache["expires_at"] > time.time()
        and _token_cache.get("client_id") == client_id
        and _token_cache.get("scopes") == scopes
    ):
        return _token_cache["token"]

    resp = requests.post(
        f"{BASE_URL}/authentication/v2/token",
        data={"grant_type": "client_credentials", "scope": scopes},
        auth=(client_id, client_secret),
        timeout=timeout,  # without a timeout requests can block forever
    )
    resp.raise_for_status()
    data = resp.json()
    _token_cache["token"] = data["access_token"]
    # Expire the cache entry 60s early so callers never get a token that
    # lapses mid-request.
    _token_cache["expires_at"] = time.time() + data["expires_in"] - 60
    _token_cache["client_id"] = client_id
    _token_cache["scopes"] = scopes
    return _token_cache["token"]


def clear_token_cache():
    """Forget the cached 2-legged token so the next call to get_2legged_token() fetches a new one."""
    _token_cache.update(token=None, expires_at=0)


# --- 3-Legged OAuth2 (user authorization) ---

_user_tokens = dict()  # NOTE(review): not referenced in this section; confirm whether still used elsewhere
_active_token = None  # token dict installed by set_active_token(); falsy means "not logged in"


def get_auth_url(client_id=None, callback_url=None, scopes="data:read data:write data:create account:read"):
    """
    Build the Autodesk OAuth2 authorization URL that starts the 3-legged flow.

    Args:
        client_id: Client ID (default: env ACC_CLIENT_ID, then APS_CLIENT_ID).
        callback_url: Redirect URI (default: env APS_CALLBACK_URL or
            http://localhost:9091/auth/callback).
        scopes: Space-separated OAuth2 scopes to request.

    Returns:
        tuple: (authorization_url, state) where ``state`` is a fresh random
        value embedded in the URL; the callback handler should check that the
        ``state`` it receives matches it (CSRF protection, RFC 6749 §10.12).
    """
    if not client_id:
        client_id = os.getenv("ACC_CLIENT_ID", os.getenv("APS_CLIENT_ID"))
    if not callback_url:
        callback_url = os.getenv("APS_CALLBACK_URL", "http://localhost:9091/auth/callback")

    nonce = secrets.token_urlsafe(16)
    # quote (not the default quote_plus) so spaces in the scope list become %20.
    query = urlencode(
        [
            ("response_type", "code"),
            ("client_id", client_id),
            ("redirect_uri", callback_url),
            ("scope", scopes),
            ("state", nonce),
            ("prompt", "login"),
        ],
        quote_via=quote,
    )
    return f"{BASE_URL}/authentication/v2/authorize?{query}", nonce


def exchange_code(code, client_id=None, client_secret=None, callback_url=None, timeout=30):
    """
    Exchange an authorization code for an access token (3-legged flow).

    Args:
        code: Authorization code received on the OAuth2 callback.
        client_id: Client ID (default: env ACC_CLIENT_ID, then APS_CLIENT_ID).
        client_secret: Client secret (default: env ACC_CLIENT_SECRET, then
            APS_CLIENT_SECRET).
        callback_url: Redirect URI (default: env APS_CALLBACK_URL or
            http://localhost:9091/auth/callback). Per RFC 6749 §4.1.3 it
            should match the redirect_uri used when requesting the code.
        timeout: Seconds to wait for the token endpoint (forwarded to requests).

    Returns:
        dict: Token response with access_token, refresh_token, expires_in

    Raises:
        requests.HTTPError: if the token endpoint rejects the exchange.
    """
    client_id = client_id or os.getenv("ACC_CLIENT_ID", os.getenv("APS_CLIENT_ID"))
    client_secret = client_secret or os.getenv("ACC_CLIENT_SECRET", os.getenv("APS_CLIENT_SECRET"))
    callback_url = callback_url or os.getenv("APS_CALLBACK_URL", "http://localhost:9091/auth/callback")

    resp = requests.post(
        f"{BASE_URL}/authentication/v2/token",
        data={
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": callback_url,
        },
        auth=(client_id, client_secret),
        timeout=timeout,  # without a timeout requests can block forever
    )
    resp.raise_for_status()
    return resp.json()


def refresh_user_token(refresh_token, client_id=None, client_secret=None, timeout=30):
    """
    Refresh an expired user token using its refresh token.

    Args:
        refresh_token: Refresh token from a previous 3-legged token response.
        client_id: Client ID (default: env ACC_CLIENT_ID, then APS_CLIENT_ID).
        client_secret: Client secret (default: env ACC_CLIENT_SECRET, then
            APS_CLIENT_SECRET).
        timeout: Seconds to wait for the token endpoint (forwarded to requests).

    Returns:
        dict: Parsed JSON token response from the token endpoint.

    Raises:
        requests.HTTPError: if the token endpoint rejects the refresh.
    """
    client_id = client_id or os.getenv("ACC_CLIENT_ID", os.getenv("APS_CLIENT_ID"))
    client_secret = client_secret or os.getenv("ACC_CLIENT_SECRET", os.getenv("APS_CLIENT_SECRET"))

    resp = requests.post(
        f"{BASE_URL}/authentication/v2/token",
        data={
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
        },
        auth=(client_id, client_secret),
        timeout=timeout,  # without a timeout requests can block forever
    )
    resp.raise_for_status()
    return resp.json()


def set_active_token(token_data):
    """
    Set the active 3-legged user token.

    The mapping is copied, so refreshes performed later by get_user_token()
    do not mutate the caller's dict. If ``expires_at`` is missing but
    ``expires_in`` is present (the shape exchange_code() returns), then
    ``expires_at`` is derived from it with a 60-second safety margin.

    Args:
        token_data: dict with access_token, refresh_token (optional), and
            either expires_at (time.time() timestamp) or expires_in (seconds).
            A falsy value is stored as-is, which leaves no user logged in.
    """
    global _active_token
    if not token_data:
        _active_token = token_data
        return
    token = dict(token_data)
    if "expires_at" not in token and "expires_in" in token:
        # Same early-expiry margin used for the 2-legged cache and refreshes.
        token["expires_at"] = time.time() + token["expires_in"] - 60
    _active_token = token


def get_user_token():
    """
    Get a valid user access token, refreshing it if it has expired.

    Returns:
        str or None: Access token, or None if no user is logged in or the
        token has expired and cannot be refreshed. In the latter case the
        session is cleared.
    """
    global _active_token
    if not _active_token:
        return None
    # A missing expires_at is treated as already expired, so it is refreshed
    # instead of raising KeyError.
    if _active_token.get("expires_at", 0) >= time.time():
        return _active_token["access_token"]

    rt = _active_token.get("refresh_token")
    if not rt:
        # Expired with no way to renew: handing back the stale token would only
        # cause 401s downstream, so treat the user as logged out.
        _active_token = None
        return None

    try:
        data = refresh_user_token(rt)
        # Build the new values first so the stored token is never half-updated.
        refreshed = {
            "access_token": data["access_token"],
            "refresh_token": data.get("refresh_token", rt),
            "expires_at": time.time() + data["expires_in"] - 60,
        }
    except Exception:
        # Best effort: any refresh failure ends the session, but record why
        # instead of swallowing it silently.
        logging.getLogger(__name__).warning(
            "APS user token refresh failed; clearing session", exc_info=True
        )
        _active_token = None
        return None
    _active_token.update(refreshed)
    return refreshed["access_token"]


def logout():
    """
    End the local user session; get_user_token() returns None afterwards.

    Only module state is reset; no request is sent to Autodesk.
    """
    global _active_token
    _active_token = None
