"""
ACC (Autodesk Construction Cloud) Server
3-legged OAuth2 + Data Management API + Viewer
Port 9091 - Separate from BIM Viewer (port 9090)
"""

import html
import json
import os
import secrets
import time
from http.server import HTTPServer, SimpleHTTPRequestHandler
from urllib.parse import urlparse, parse_qs, urlencode, quote, unquote

import requests
from dotenv import load_dotenv

load_dotenv()

# APS app credentials. The ACC_* variables win; the APS_* ones are the
# fallback. `or` (rather than a getenv default) also falls back when an
# ACC_* variable exists but is empty, e.g. a blank `ACC_CLIENT_ID=` line.
APS_CLIENT_ID = os.getenv("ACC_CLIENT_ID") or os.getenv("APS_CLIENT_ID")
APS_CLIENT_SECRET = os.getenv("ACC_CLIENT_SECRET") or os.getenv("APS_CLIENT_SECRET")
# Must match a callback URL registered for the APS app.
APS_CALLBACK_URL = os.getenv("APS_CALLBACK_URL") or "http://localhost:9091/auth/callback"
BASE_URL = "https://developer.api.autodesk.com"

# Intended as per-session token storage keyed by OAuth state
# (state -> {access_token, refresh_token, expires_at}); nothing in this
# module reads or writes it yet.
_user_tokens = {}
# Token record of the one logged-in user ({access_token, refresh_token,
# expires_at}) or None. A single global: a new login replaces the old one.
_active_token = None


def get_auth_url():
    """Build the Autodesk 3-legged OAuth2 authorization URL.

    A fresh random ``state`` is embedded in the URL and also returned, so
    the caller can check it when Autodesk redirects back to the callback.

    Returns:
        tuple[str, str]: ``(authorize_url, state)``.
    """
    state = secrets.token_urlsafe(16)
    # Ordered pairs keep the query parameters in a stable, readable order;
    # quote_via=quote encodes spaces in the scope list as %20, not '+'.
    query = urlencode(
        [
            ("response_type", "code"),
            ("client_id", APS_CLIENT_ID),
            ("redirect_uri", APS_CALLBACK_URL),
            ("scope", "data:read data:write data:create account:read"),
            ("state", state),
            ("prompt", "login"),
        ],
        quote_via=quote,
    )
    authorize_url = BASE_URL + "/authentication/v2/authorize?" + query
    print(f"[AUTH] Login URL: {authorize_url}")
    return authorize_url, state


def exchange_code(code, timeout=30):
    """Exchange an OAuth2 authorization code for a 3-legged user token.

    Args:
        code: the ``code`` query parameter Autodesk sent to the callback URL.
        timeout: seconds to wait for Autodesk. Without a timeout a stalled
            connection would block the single-threaded server indefinitely.

    Returns:
        dict: the decoded token response (callers read ``access_token``,
        ``refresh_token`` and ``expires_in``).

    Raises:
        requests.HTTPError: if Autodesk rejects the exchange.
        requests.RequestException: on network failure or timeout.
    """
    resp = requests.post(
        f"{BASE_URL}/authentication/v2/token",
        data={
            "grant_type": "authorization_code",
            "code": code,
            # Must match the redirect_uri used when the code was issued.
            "redirect_uri": APS_CALLBACK_URL,
        },
        auth=(APS_CLIENT_ID, APS_CLIENT_SECRET),
        timeout=timeout,
    )
    resp.raise_for_status()
    return resp.json()


def refresh_user_token(refresh_token, timeout=30):
    """Trade a refresh token for a new 3-legged access token.

    Args:
        refresh_token: the refresh token stored with the current session.
        timeout: seconds to wait for Autodesk. Without a timeout a stalled
            connection would block the single-threaded server indefinitely.

    Returns:
        dict: the decoded token response (``access_token``, ``expires_in``
        and possibly a rotated ``refresh_token``).

    Raises:
        requests.HTTPError: if Autodesk rejects the refresh.
        requests.RequestException: on network failure or timeout.
    """
    resp = requests.post(
        f"{BASE_URL}/authentication/v2/token",
        data={
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
        },
        auth=(APS_CLIENT_ID, APS_CLIENT_SECRET),
        timeout=timeout,
    )
    resp.raise_for_status()
    return resp.json()


def get_user_token():
    """Return a usable access token for the logged-in user, or None.

    A token whose ``expires_at`` has passed is refreshed with the stored
    refresh token. If it cannot be refreshed — no refresh token, or the
    refresh request fails — the session is cleared and None is returned,
    so callers answer "Not logged in" instead of forwarding a dead token.

    Returns:
        str | None: the current access token, or None when no valid session
        exists.
    """
    global _active_token
    if not _active_token:
        return None
    if _active_token["expires_at"] >= time.time():
        return _active_token["access_token"]

    # Expired: refresh if we can.
    rt = _active_token.get("refresh_token")
    if not rt:
        # An expired token without a refresh token is unusable; handing it
        # out would only produce 401s from every APS call.
        _active_token = None
        return None
    try:
        data = refresh_user_token(rt)
        _active_token["access_token"] = data["access_token"]
        # Autodesk may rotate the refresh token; keep the old one otherwise.
        _active_token["refresh_token"] = data.get("refresh_token", rt)
        # Treat the token as expired 60 s early so it is renewed before
        # APS actually starts rejecting it.
        _active_token["expires_at"] = time.time() + data["expires_in"] - 60
    except Exception as exc:
        # Best effort: any failure ends the session and the user logs in again.
        print(f"[AUTH] Token refresh failed, logging out: {exc}")
        _active_token = None
        return None
    return _active_token["access_token"]


class ACCHandler(SimpleHTTPRequestHandler):
    """Request handler for the ACC viewer backend.

    GET requests whose path is listed in ``ROUTES`` are answered by the
    matching ``handle_*`` method: the 3-legged OAuth2 flow, Data Management
    browsing, and Model Derivative proxying, all using the single logged-in
    user's token (see ``get_user_token``). Any other GET/HEAD is passed to
    ``SimpleHTTPRequestHandler`` as a static-file request, except hidden
    (dot-prefixed) paths such as ``/.env``, which are refused so a dotenv
    file holding the client secret cannot be downloaded.
    """

    # URL path -> name of the method that answers it. Each handler is called
    # with the query string as parsed by parse_qs (name -> list of values).
    ROUTES = {
        "/auth/login": "handle_login",
        "/auth/callback": "handle_callback",
        "/auth/status": "handle_auth_status",
        "/auth/logout": "handle_logout",
        "/api/token": "handle_token",
        "/api/hubs": "handle_hubs",
        "/api/projects": "handle_projects",
        "/api/topfolders": "handle_top_folders",
        "/api/folder": "handle_folder_contents",
        "/api/versions": "handle_versions",
        "/api/item": "handle_item_details",
        "/api/metadata": "handle_metadata",
        "/api/tree": "handle_tree",
        "/api/properties": "handle_properties",
        "/api/phases": "handle_phases",
    }

    # Seconds to wait on any single Autodesk call. The __main__ block runs
    # this handler under a single-threaded HTTPServer, so an unbounded wait
    # would stall every other request behind it.
    REQUEST_TIMEOUT = 60

    # Hosts whose pages may read our JSON cross-origin (any port, e.g. the
    # BIM viewer on 9090). A wildcard would let any website the user visits
    # read /api/token, i.e. the user's bearer token.
    CORS_HOSTS = ("localhost", "127.0.0.1", "::1")

    # Seconds a login attempt may take between /auth/login and /auth/callback.
    STATE_TTL = 600

    # OAuth2 `state` values issued by /auth/login and not yet redeemed by
    # /auth/callback (state -> issue time, epoch seconds). Stored on the class
    # because the server builds a fresh handler instance for every request.
    _pending_states = {}

    # ===== Routing =====

    def do_GET(self):
        """Dispatch API/auth routes; serve anything else as a static file."""
        parsed = urlparse(self.path)
        method_name = self.ROUTES.get(parsed.path)
        if method_name:
            getattr(self, method_name)(parse_qs(parsed.query))
        elif self._is_hidden_path(parsed.path):
            self.send_error(404, "File not found")
        else:
            super().do_GET()

    def do_HEAD(self):
        """Serve HEAD for static files, refusing hidden paths like GET does."""
        if self._is_hidden_path(urlparse(self.path).path):
            self.send_error(404, "File not found")
        else:
            super().do_HEAD()

    @staticmethod
    def _is_hidden_path(path):
        """Return True if any segment of URL *path* starts with '.'.

        The path is percent-decoded first, so ``/%2Eenv`` is caught like
        ``/.env``; backslashes count as separators as well.
        """
        segments = unquote(path).replace("\\", "/").split("/")
        return any(segment.startswith(".") for segment in segments)

    # ===== Response helpers =====

    @classmethod
    def _cors_origin(cls, origin):
        """Return *origin* if it may read responses cross-origin, else None."""
        if not origin:
            return None
        try:
            parsed = urlparse(origin)
        except ValueError:  # e.g. a malformed IPv6 literal
            return None
        if parsed.scheme in ("http", "https") and parsed.hostname in cls.CORS_HOSTS:
            return origin
        return None

    def send_json(self, data, status=200):
        """Send *data* as a UTF-8 JSON response with HTTP *status*."""
        # Serialize before sending headers so an encoding error can still be
        # reported as a proper error response by the caller.
        body = json.dumps(data, ensure_ascii=False).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        allowed = self._cors_origin(self.headers.get("Origin"))
        if allowed:
            self.send_header("Access-Control-Allow-Origin", allowed)
        # The CORS header depends on the request's Origin.
        self.send_header("Vary", "Origin")
        self.end_headers()
        self.wfile.write(body)

    def send_error_json(self, msg, status=500):
        """Send ``{"error": msg}`` with HTTP *status*."""
        self.send_json({"error": msg}, status)

    def _send_html(self, markup, status=200):
        """Send an HTML page; *markup* must already be escaped as needed."""
        body = markup.encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def redirect(self, url):
        """Send a 302 redirect to *url*."""
        self.send_response(302)
        self.send_header("Location", url)
        self.end_headers()

    # ===== OAuth state bookkeeping =====

    @classmethod
    def _prune_states(cls, now):
        """Forget states older than STATE_TTL so abandoned logins don't pile up."""
        expired = [s for s, issued in cls._pending_states.items() if now - issued > cls.STATE_TTL]
        for s in expired:
            del cls._pending_states[s]

    @classmethod
    def _remember_state(cls, state, now=None):
        """Record *state* as issued at *now* (defaults to the current time)."""
        now = time.time() if now is None else now
        cls._prune_states(now)
        cls._pending_states[state] = now

    @classmethod
    def _consume_state(cls, state, now=None):
        """Return True (once) if *state* was issued and has not expired."""
        now = time.time() if now is None else now
        cls._prune_states(now)
        if not state:
            return False
        return cls._pending_states.pop(state, None) is not None

    # ===== Auth Endpoints =====

    def handle_login(self, params):
        """GET /auth/login - Redirect to the Autodesk login page."""
        url, state = get_auth_url()
        # Remember the state so the callback can prove it answers a login
        # this server started (OAuth2 CSRF protection).
        self._remember_state(state)
        self.redirect(url)

    def handle_callback(self, params):
        """GET /auth/callback?code=xxx&state=yyy - OAuth2 callback.

        Rejects callbacks whose ``state`` was not issued by /auth/login (or
        has expired), then exchanges the code and stores the user's token.
        """
        global _active_token
        code = params.get("code", [None])[0]
        error = params.get("error", [None])[0]
        state = params.get("state", [None])[0]

        if error:
            # The attempt is over either way; drop its state.
            self._consume_state(state)
            # `error` comes from the query string: escape it before
            # embedding it in HTML.
            self._send_html(
                "<html><body>\n"
                f"    <h2>Login Failed</h2><p>{html.escape(error)}</p>\n"
                '    <a href="/acc_viewer.html">Back</a>\n'
                "</body></html>"
            )
            return

        if not code:
            return self.send_error_json("Missing code", 400)
        if not self._consume_state(state):
            return self.send_error_json(
                "Invalid or expired login state. Please login again.", 400
            )

        try:
            data = exchange_code(code)
            _active_token = {
                "access_token": data["access_token"],
                "refresh_token": data.get("refresh_token"),
                # Expire 60 s early so the token is renewed before APS
                # starts rejecting it.
                "expires_at": time.time() + data["expires_in"] - 60,
            }
            self.redirect("/acc_viewer.html#logged-in")
        except Exception as e:
            self.send_error_json(f"Token exchange failed: {e}", 500)

    def handle_auth_status(self, params):
        """GET /auth/status - Report whether a user is logged in, with profile."""
        token = get_user_token()
        if not token:
            return self.send_json({"loggedIn": False})

        profile = None
        try:
            resp = requests.get(
                f"{BASE_URL}/userprofile/v1/users/@me",
                headers={"Authorization": f"Bearer {token}"},
                timeout=self.REQUEST_TIMEOUT,
            )
            if resp.status_code == 200:
                profile = resp.json()
        except Exception:
            # The profile is cosmetic; fall back to a generic name below.
            profile = None

        if isinstance(profile, dict):
            self.send_json({
                "loggedIn": True,
                "name": f"{profile.get('firstName', '')} {profile.get('lastName', '')}".strip(),
                "email": profile.get("emailId", ""),
            })
        else:
            self.send_json({"loggedIn": True, "name": "User", "email": ""})

    def handle_logout(self, params):
        """GET /auth/logout - Clear the user session."""
        global _active_token
        _active_token = None
        self.send_json({"loggedIn": False})

    def handle_token(self, params):
        """GET /api/token - Return the user's 3-legged token for the Viewer.

        ``expires_in`` is the real remaining lifetime (already shortened by
        the 60 s margin applied when the token was stored), so a client
        that schedules renewal from it renews before the token lapses.
        """
        token = get_user_token()
        if not token:
            return self.send_error_json("Not logged in", 401)
        # A truthy token means _active_token holds the matching record.
        remaining = int(_active_token["expires_at"] - time.time())
        self.send_json({"access_token": token, "expires_in": max(remaining, 0)})

    # ===== APS request helpers =====

    def _authed_get(self, url, params=None):
        """GET *url* from APS with the user's bearer token.

        Returns:
            ``(response, None)`` for a successful response (callers may need
            to tell 200 from 202), or ``(None, message)`` when nobody is
            logged in or APS rejects the token with 401.

        Raises:
            requests.HTTPError: for any other error status.
            requests.RequestException: on network failure or timeout.
        """
        token = get_user_token()
        if not token:
            return None, "Not logged in"
        resp = requests.get(
            url,
            headers={"Authorization": f"Bearer {token}"},
            params=params,
            timeout=self.REQUEST_TIMEOUT,
        )
        if resp.status_code == 401:
            return None, "Token expired. Please login again."
        resp.raise_for_status()
        return resp, None

    def _api_get(self, url, params=None):
        """Authenticated GET returning ``(decoded_json, None)`` or ``(None, error)``."""
        resp, err = self._authed_get(url, params)
        if err:
            return None, err
        return resp.json(), None

    @staticmethod
    def _next_link(data):
        """Return the ``links.next`` URL of a paginated response, or None.

        ``next`` may be given either as a plain URL or as ``{"href": url}``.
        """
        nxt = (data.get("links") or {}).get("next")
        if isinstance(nxt, dict):
            nxt = nxt.get("href")
        return nxt or None

    def _api_get_all(self, url):
        """Follow ``links.next`` pagination and concatenate every ``data`` list.

        Returns ``(items, None)``, or ``(None, error)`` if not authenticated.
        """
        items = []
        while url:
            data, err = self._api_get(url)
            if err:
                return None, err
            items.extend(data.get("data") or [])
            url = self._next_link(data)
        return items, None

    @staticmethod
    def _derivative_urn(resource):
        """Return the derivative URN linked from a DM version resource, or ""."""
        relationships = resource.get("relationships") or {}
        derivatives = relationships.get("derivatives") or {}
        return (derivatives.get("data") or {}).get("id", "")

    # ===== ACC Data Management Endpoints =====

    def handle_hubs(self, params):
        """GET /api/hubs - List all ACC/BIM 360 hubs."""
        try:
            raw, err = self._api_get_all(f"{BASE_URL}/project/v1/hubs")
            if err:
                return self.send_error_json(err, 401)
            hubs = [
                {
                    "id": h["id"],
                    "name": h["attributes"]["name"],
                    "type": h["attributes"].get("extension", {}).get("type", ""),
                }
                for h in raw
            ]
            self.send_json({"hubs": hubs})
        except Exception as e:
            self.send_error_json(str(e))

    def handle_projects(self, params):
        """GET /api/projects?hub_id=xxx - List every project in a hub (all pages)."""
        hub_id = params.get("hub_id", [None])[0]
        if not hub_id:
            return self.send_error_json("Missing hub_id", 400)
        try:
            raw, err = self._api_get_all(f"{BASE_URL}/project/v1/hubs/{hub_id}/projects")
            if err:
                return self.send_error_json(err, 401)
            projects = []
            for p in raw:
                attrs = p["attributes"]
                ext = attrs.get("extension", {})
                projects.append({
                    "id": p["id"],
                    "name": attrs["name"],
                    "type": ext.get("type", ""),
                    "status": ext.get("data", {}).get("projectType", ""),
                })
            self.send_json({"projects": projects})
        except Exception as e:
            self.send_error_json(str(e))

    def handle_top_folders(self, params):
        """GET /api/topfolders?hub_id=xxx&project_id=yyy - List top folders in project."""
        hub_id = params.get("hub_id", [None])[0]
        project_id = params.get("project_id", [None])[0]
        if not hub_id or not project_id:
            return self.send_error_json("Missing hub_id or project_id", 400)
        try:
            raw, err = self._api_get_all(
                f"{BASE_URL}/project/v1/hubs/{hub_id}/projects/{project_id}/topFolders"
            )
            if err:
                return self.send_error_json(err, 401)
            folders = [
                {"id": f["id"], "name": f["attributes"]["name"], "type": "folder"}
                for f in raw
            ]
            self.send_json({"items": folders})
        except Exception as e:
            self.send_error_json(str(e))

    def handle_folder_contents(self, params):
        """GET /api/folder?project_id=xxx&folder_id=yyy - List folder contents.

        All pages are fetched; the result lists folders first, then files,
        each group sorted case-insensitively by name.
        """
        project_id = params.get("project_id", [None])[0]
        folder_id = params.get("folder_id", [None])[0]
        if not project_id or not folder_id:
            return self.send_error_json("Missing project_id or folder_id", 400)
        try:
            raw, err = self._api_get_all(
                f"{BASE_URL}/data/v1/projects/{project_id}/folders/{folder_id}/contents"
            )
            if err:
                return self.send_error_json(err, 401)
            entries = []
            for item in raw:
                attrs = item["attributes"]
                item_type = item["type"]  # "folders" or "items"
                entry = {
                    "id": item["id"],
                    "name": attrs.get("displayName") or attrs.get("name") or "",
                    "type": "folder" if item_type == "folders" else "file",
                    "lastModified": attrs.get("lastModifiedTime", ""),
                    "createTime": attrs.get("createTime", ""),
                }
                if item_type == "items":
                    entry["fileType"] = attrs.get("fileType", "")
                    entry["extensionType"] = attrs.get("extension", {}).get("type", "")
                entries.append(entry)
            entries.sort(key=lambda x: (0 if x["type"] == "folder" else 1, x["name"].lower()))
            self.send_json({"items": entries})
        except Exception as e:
            self.send_error_json(str(e))

    def handle_versions(self, params):
        """GET /api/versions?project_id=xxx&item_id=yyy - All item versions, newest first."""
        project_id = params.get("project_id", [None])[0]
        item_id = params.get("item_id", [None])[0]
        if not project_id or not item_id:
            return self.send_error_json("Missing project_id or item_id", 400)
        try:
            raw, err = self._api_get_all(
                f"{BASE_URL}/data/v1/projects/{project_id}/items/{item_id}/versions"
            )
            if err:
                return self.send_error_json(err, 401)
            versions = []
            for v in raw:
                attrs = v["attributes"]
                versions.append({
                    "id": v["id"],
                    "version": attrs.get("versionNumber", 0),
                    "name": attrs.get("name", ""),
                    "lastModified": attrs.get("lastModifiedTime", ""),
                    "createTime": attrs.get("createTime", ""),
                    "fileType": attrs.get("fileType", ""),
                    "urn": self._derivative_urn(v),
                })
            versions.sort(key=lambda x: x["version"], reverse=True)
            self.send_json({"versions": versions})
        except Exception as e:
            self.send_error_json(str(e))

    def handle_item_details(self, params):
        """GET /api/item?project_id=xxx&item_id=yyy - Item tip (latest version URN)."""
        project_id = params.get("project_id", [None])[0]
        item_id = params.get("item_id", [None])[0]
        if not project_id or not item_id:
            return self.send_error_json("Missing project_id or item_id", 400)
        try:
            data, err = self._api_get(
                f"{BASE_URL}/data/v1/projects/{project_id}/items/{item_id}/tip"
            )
            if err:
                return self.send_error_json(err, 401)
            tip = data.get("data") or {}
            attrs = tip.get("attributes") or {}
            self.send_json({
                "name": attrs.get("name", ""),
                "version": attrs.get("versionNumber", 0),
                "fileType": attrs.get("fileType", ""),
                "urn": self._derivative_urn(tip),
                "lastModified": attrs.get("lastModifiedTime", ""),
            })
        except Exception as e:
            self.send_error_json(str(e))

    # ===== Model Derivative Endpoints (proxy with user token) =====

    def _proxy_model_data(self, url):
        """Relay a Model Derivative ``forceget`` GET, preserving its status.

        A 202 from APS is passed through as 202 (not 200) so the client can
        tell "still being prepared" from a finished result.
        """
        resp, err = self._authed_get(url, params={"forceget": "true"})
        if err:
            return self.send_error_json(err, 401)
        self.send_json(resp.json(), resp.status_code)

    def handle_metadata(self, params):
        """GET /api/metadata?urn=xxx - List views/GUIDs in the model."""
        urn = params.get("urn", [None])[0]
        if not urn:
            return self.send_error_json("Missing urn", 400)
        try:
            data, err = self._api_get(f"{BASE_URL}/modelderivative/v2/designdata/{urn}/metadata")
            if err:
                return self.send_error_json(err, 401)
            self.send_json(data)
        except Exception as e:
            self.send_error_json(str(e))

    def handle_tree(self, params):
        """GET /api/tree?urn=xxx&guid=yyy - Object tree hierarchy of one view."""
        urn = params.get("urn", [None])[0]
        guid = params.get("guid", [None])[0]
        if not urn or not guid:
            return self.send_error_json("Missing urn or guid", 400)
        try:
            self._proxy_model_data(
                f"{BASE_URL}/modelderivative/v2/designdata/{urn}/metadata/{guid}"
            )
        except Exception as e:
            self.send_error_json(str(e))

    def handle_properties(self, params):
        """GET /api/properties?urn=xxx&guid=yyy - All element properties of one view."""
        urn = params.get("urn", [None])[0]
        guid = params.get("guid", [None])[0]
        if not urn or not guid:
            return self.send_error_json("Missing urn or guid", 400)
        try:
            self._proxy_model_data(
                f"{BASE_URL}/modelderivative/v2/designdata/{urn}/metadata/{guid}/properties"
            )
        except Exception as e:
            self.send_error_json(str(e))

    @staticmethod
    def _extract_phases(collection):
        """Group property records by their "Phase Created" value.

        Args:
            collection: the ``data.collection`` list from the properties
                endpoint; each record is read for ``objectid``, ``name`` and
                its ``Phasing`` / ``__category__`` property groups.

        Returns:
            dict: the /api/phases payload — sorted phase names, phase ->
            [objectid], str(objectid) -> element summary, and counts.
            Records without a "Phase Created" value count toward
            ``totalElements`` only.
        """
        phase_elements = {}
        elements = {}
        for elem in collection:
            props = elem.get("properties") or {}
            phasing = props.get("Phasing", {})
            if not isinstance(phasing, dict):
                continue
            created = phasing.get("Phase Created")
            if not created:
                continue
            cat_group = props.get("__category__", {})
            category = cat_group.get("__category__", "") if isinstance(cat_group, dict) else ""
            db_id = elem.get("objectid")
            phase_elements.setdefault(created, []).append(db_id)
            elements[str(db_id)] = {
                "name": elem.get("name", ""),
                "phaseCreated": created,
                "phaseDemolished": phasing.get("Phase Demolished"),
                "category": category,
            }
        return {
            "phases": sorted(phase_elements),
            "phaseElements": phase_elements,
            "elements": elements,
            "totalElements": len(collection),
            "phasedElements": len(elements),
        }

    def handle_phases(self, params):
        """GET /api/phases?urn=xxx&guid=yyy - Extract phase data per element.

        When APS answers 202 (properties not ready yet), responds 202 with an
        empty payload plus ``"processing": true`` so the client can retry
        instead of concluding the model has no phases.
        """
        urn = params.get("urn", [None])[0]
        guid = params.get("guid", [None])[0]
        if not urn or not guid:
            return self.send_error_json("Missing urn or guid", 400)
        try:
            resp, err = self._authed_get(
                f"{BASE_URL}/modelderivative/v2/designdata/{urn}/metadata/{guid}/properties",
                params={"forceget": "true"},
            )
            if err:
                return self.send_error_json(err, 401)
            if resp.status_code == 202:
                pending = self._extract_phases([])
                pending["processing"] = True
                return self.send_json(pending, 202)
            collection = (resp.json().get("data") or {}).get("collection") or []
            self.send_json(self._extract_phases(collection))
        except Exception as e:
            self.send_error_json(str(e))


if __name__ == "__main__":
    import sys

    port = 9091
    # Listens on all interfaces by default, as before; set ACC_HOST=127.0.0.1
    # to keep it local. /api/token hands the logged-in user's token to any
    # client that can reach this port.
    host = os.getenv("ACC_HOST") or "0.0.0.0"

    if not APS_CLIENT_ID or not APS_CLIENT_SECRET:
        print(
            "[ERROR] Set ACC_CLIENT_ID/ACC_CLIENT_SECRET "
            "(or APS_CLIENT_ID/APS_CLIENT_SECRET) in .env",
            file=sys.stderr,
        )
        # sys.exit, unlike the site-provided exit(), is always available.
        sys.exit(1)

    print(f"APS_CALLBACK_URL = {APS_CALLBACK_URL}")
    print("[IMPORTANT] Make sure this callback URL is registered in your APS app settings!")
    print("  -> https://aps.autodesk.com/myapps")
    print()

    server = HTTPServer((host, port), ACCHandler)
    print(f"ACC Server running at http://localhost:{port}")
    print(f"Open http://localhost:{port}/acc_viewer.html")
    print("Press Ctrl+C to stop")
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl+C is the advertised way to stop: exit without a traceback.
        print("\nStopping ACC Server")
    finally:
        server.server_close()
