"""
Combined BIM + ACC Server on port 9090
- viewer2.html: BIM Viewer (2-legged token)
- acc_viewer.html: ACC Viewer (3-legged OAuth2)
"""

import base64
import html
import json
import os
import secrets
import time
from http.server import HTTPServer, SimpleHTTPRequestHandler
from urllib.parse import urlparse, parse_qs, urlencode, quote

import requests
from dotenv import load_dotenv

# Populate os.environ from a local .env file before any setting below is read.
load_dotenv()

# 2-legged (BIM/OSS) credentials - must belong to the APS app that owns the
# dthub-bim-demo bucket.
APS_CLIENT_ID = os.getenv("APS_CLIENT_ID")
APS_CLIENT_SECRET = os.getenv("APS_CLIENT_SECRET")
# 3-legged (ACC) credentials - can be the same app or a different one; they
# fall back to the 2-legged values when ACC_CLIENT_ID / ACC_CLIENT_SECRET are unset.
ACC_CLIENT_ID = os.getenv("ACC_CLIENT_ID", APS_CLIENT_ID)
ACC_CLIENT_SECRET = os.getenv("ACC_CLIENT_SECRET", APS_CLIENT_SECRET)
# OAuth redirect_uri sent to APS; it must also be registered as a callback URL
# on the APS app (the startup banner reminds the operator of this).
APS_CALLBACK_URL = os.getenv("APS_CALLBACK_URL", "http://localhost:9090/auth/callback")
BASE_URL = "https://developer.api.autodesk.com"

# 2-legged token cache (BIM): the access token plus the epoch time after which
# it must be fetched again (0 means "nothing cached yet").
_token_cache = {"token": None, "expires_at": 0}
# 3-legged (ACC)
# NOTE(review): _user_tokens is not referenced by the code visible here;
# confirm it is unused before removing it.
_user_tokens = {}
# The single logged-in user's session dict (access_token, refresh_token,
# expires_at) or None. It is module-global, so every browser using this
# server process shares the same ACC login.
_active_token = None


def get_server_token(*, timeout=30):
    """Return a 2-legged (client-credentials) APS access token, cached until near expiry.

    The token is requested with scope ``data:read viewables:read`` using
    APS_CLIENT_ID / APS_CLIENT_SECRET and kept in ``_token_cache`` until
    60 seconds before the expiry APS reports.

    Args:
        timeout: Seconds to wait for the APS token endpoint before giving up.

    Returns:
        The access token string.

    Raises:
        requests.HTTPError: If APS answers with a non-2xx status.
        requests.RequestException: On network failure or timeout.
    """
    if _token_cache["token"] and _token_cache["expires_at"] > time.time():
        return _token_cache["token"]
    # A bounded wait matters: a stalled call here would otherwise block the
    # request that triggered it (and, with a plain HTTPServer, every other one).
    resp = requests.post(
        f"{BASE_URL}/authentication/v2/token",
        data={"grant_type": "client_credentials", "scope": "data:read viewables:read"},
        auth=(APS_CLIENT_ID, APS_CLIENT_SECRET),
        timeout=timeout,
    )
    resp.raise_for_status()
    data = resp.json()
    _token_cache["token"] = data["access_token"]
    # Expire the cache a minute early so callers never get a token about to lapse.
    _token_cache["expires_at"] = time.time() + data["expires_in"] - 60
    return _token_cache["token"]


def get_auth_url():
    """Build the APS 3-legged OAuth authorize URL.

    A fresh random ``state`` value is generated for every call and embedded in
    the URL; it is also returned so the caller can match it against the
    ``state`` echoed back to the callback.

    Returns:
        A ``(url, state)`` tuple.
    """
    csrf_state = secrets.token_urlsafe(16)
    # Order matters only for readability of the resulting URL; values are
    # percent-encoded with quote() so spaces become %20 rather than "+".
    query_pairs = [
        ("response_type", "code"),
        ("client_id", ACC_CLIENT_ID),
        ("redirect_uri", APS_CALLBACK_URL),
        ("scope", "data:read data:write data:create account:read"),
        ("state", csrf_state),
        ("prompt", "login"),
    ]
    query_string = urlencode(query_pairs, quote_via=quote)
    authorize_url = f"{BASE_URL}/authentication/v2/authorize?{query_string}"
    return authorize_url, csrf_state


def exchange_code(code, *, timeout=30):
    """Exchange an OAuth authorization *code* for 3-legged tokens.

    Args:
        code: The ``code`` query parameter APS delivered to APS_CALLBACK_URL.
        timeout: Seconds to wait for the APS token endpoint before giving up.

    Returns:
        The decoded JSON token response (callers in this module read
        ``access_token``, ``expires_in`` and, if present, ``refresh_token``).

    Raises:
        requests.HTTPError: If APS answers with a non-2xx status.
        requests.RequestException: On network failure or timeout.
    """
    resp = requests.post(
        f"{BASE_URL}/authentication/v2/token",
        # redirect_uri must repeat the value used in the authorize request.
        data={"grant_type": "authorization_code", "code": code, "redirect_uri": APS_CALLBACK_URL},
        auth=(ACC_CLIENT_ID, ACC_CLIENT_SECRET),
        timeout=timeout,
    )
    resp.raise_for_status()
    return resp.json()


def refresh_user_token(refresh_token, *, timeout=30):
    """Trade a 3-legged *refresh_token* for a new access token.

    Args:
        refresh_token: The refresh token from a previous token response.
        timeout: Seconds to wait for the APS token endpoint before giving up.

    Returns:
        The decoded JSON token response; it may or may not contain a new
        ``refresh_token``.

    Raises:
        requests.HTTPError: If APS answers with a non-2xx status (e.g. the
            refresh token was revoked or already used).
        requests.RequestException: On network failure or timeout.
    """
    resp = requests.post(
        f"{BASE_URL}/authentication/v2/token",
        data={"grant_type": "refresh_token", "refresh_token": refresh_token},
        auth=(ACC_CLIENT_ID, ACC_CLIENT_SECRET),
        timeout=timeout,
    )
    resp.raise_for_status()
    return resp.json()


def get_user_token():
    """Return the logged-in user's 3-legged access token, refreshing it when expired.

    Reads and updates the module-level ``_active_token`` session.

    Returns:
        The access token string, or None when nobody is logged in, when the
        token has expired and there is no refresh token, or when the refresh
        fails. In the last two cases the session is cleared, so the user has
        to log in again.
    """
    global _active_token
    if not _active_token:
        return None
    if _active_token["expires_at"] < time.time():
        rt = _active_token.get("refresh_token")
        if not rt:
            # Expired and not refreshable: handing out the stale token would
            # only cause 401s downstream and a false "logged in" status.
            _active_token = None
            return None
        try:
            data = refresh_user_token(rt)
            _active_token["access_token"] = data["access_token"]
            # Keep the previous refresh token if the response carries no new one.
            _active_token["refresh_token"] = data.get("refresh_token", rt)
            _active_token["expires_at"] = time.time() + data["expires_in"] - 60
        except Exception:
            # Deliberate best effort: any refresh failure (network, HTTP error,
            # malformed payload) ends the session instead of raising here.
            _active_token = None
            return None
    return _active_token["access_token"]


class CombinedHandler(SimpleHTTPRequestHandler):
    """HTTP handler for the combined BIM + ACC server.

    Paths listed in ``_ROUTES`` are answered by the matching ``handle_*``
    method (JSON, a PNG thumbnail, a redirect or an HTML error page). Any
    other path falls through to SimpleHTTPRequestHandler as a static file.
    """

    # Seconds to wait for any outbound APS request. The server started in
    # __main__ is a plain HTTPServer that handles one request at a time, so an
    # unbounded wait would stall every client.
    REQUEST_TIMEOUT = 30
    # Seconds an OAuth ``state`` issued by /auth/login remains acceptable.
    STATE_TTL = 600
    # Pending OAuth states: state -> issue time (epoch seconds). Kept on the
    # class because a new handler instance is created for every request.
    _pending_states = {}

    # Request path -> name of the method that serves it.
    _ROUTES = {
        "/auth/login": "handle_login",
        "/auth/callback": "handle_callback",
        "/auth/status": "handle_auth_status",
        "/auth/logout": "handle_logout",
        "/api/token": "handle_token",
        "/api/token-bim": "handle_token_bim",
        "/api/metadata": "handle_metadata",
        "/api/properties": "handle_properties",
        "/api/tree": "handle_tree",
        "/api/thumbnail": "handle_thumbnail",
        "/api/manifest": "handle_manifest",
        "/api/phases": "handle_phases",
        "/api/hubs": "handle_hubs",
        "/api/projects": "handle_projects",
        "/api/topfolders": "handle_top_folders",
        "/api/folder": "handle_folder_contents",
        "/api/versions": "handle_versions",
        "/api/item": "handle_item_details",
    }

    def do_GET(self):
        """Dispatch API paths to their handler; serve anything else as a static file."""
        # Normalize double slash (e.g. //viewer2.html -> /viewer2.html); otherwise
        # urlparse would read the first segment as a network location.
        if self.path.startswith("//"):
            self.path = "/" + self.path.lstrip("/")
        parsed = urlparse(self.path)
        path = parsed.path.rstrip("/") or "/"
        handler_name = self._ROUTES.get(path)
        if handler_name is None:
            super().do_GET()
            return
        getattr(self, handler_name)(parse_qs(parsed.query))

    # ===== Response / parsing helpers =====
    def send_json(self, data, status=200):
        """Send *data* as a UTF-8 JSON response with CORS open to any origin."""
        body = json.dumps(data, ensure_ascii=False).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.send_header("Access-Control-Allow-Origin", "*")
        self.end_headers()
        self.wfile.write(body)

    def send_error_json(self, msg, status=500):
        """Send ``{"error": msg}`` with the given HTTP status."""
        self.send_json({"error": msg}, status)

    def redirect(self, url):
        """Send a 302 redirect to *url*."""
        self.send_response(302)
        self.send_header("Location", url)
        self.end_headers()

    @staticmethod
    def _first(params, name):
        """Return the first value of query parameter *name* (from parse_qs), or None."""
        return params.get(name, [None])[0]

    @staticmethod
    def _seconds_left(expires_at):
        """Return whole seconds from now until *expires_at* (epoch seconds), never below 0."""
        return max(0, int(expires_at - time.time()))

    @staticmethod
    def _derivative_urn(resource):
        """Return ``relationships.derivatives.data.id`` of a DM version resource, or "" if absent.

        Each level may be missing or JSON null, so ``or {}`` guards every lookup.
        """
        relationships = resource.get("relationships") or {}
        derivative = (relationships.get("derivatives") or {}).get("data") or {}
        return derivative.get("id", "")

    # ===== Upstream (APS) helpers =====
    def _api_get(self, url, params=None):
        """GET an APS URL with the user's 3-legged token.

        Returns:
            ``(json, None)`` on success, or ``(None, message)`` when nobody is
            logged in or APS answers 401.

        Raises:
            requests.HTTPError: For any other non-2xx status.
            requests.RequestException: On network failure or timeout.
        """
        token = get_user_token()
        if not token:
            return None, "Not logged in"
        resp = requests.get(
            url,
            headers={"Authorization": f"Bearer {token}"},
            params=params,
            timeout=self.REQUEST_TIMEOUT,
        )
        if resp.status_code == 401:
            return None, "Token expired. Please login again."
        resp.raise_for_status()
        return resp.json(), None

    def _api_get_all(self, url):
        """Collect the ``data`` items of *url* and every page reached via ``links.next.href``.

        Returns:
            ``(items, None)`` on success, or ``(None, message)`` as from _api_get.
        """
        items = []
        while url:
            page, err = self._api_get(url)
            if err:
                return None, err
            items.extend(page.get("data", []))
            next_link = page.get("links", {}).get("next")
            url = next_link.get("href") if isinstance(next_link, dict) else None
        return items, None

    @staticmethod
    def _is_oss_urn(urn):
        """Return True if base64 *urn* decodes to an OSS object URN.

        Model Derivative URNs are normally URL-safe base64 with the padding
        stripped, so the padding is restored and both base64 alphabets are
        accepted; anything undecodable counts as "not OSS".
        """
        padded = urn + "=" * (-len(urn) % 4)
        try:
            decoded = base64.urlsafe_b64decode(padded).decode("utf-8", errors="ignore")
        except ValueError:  # binascii.Error and non-ASCII input are both ValueError
            return False
        return "adsk.obj" in decoded or "ovs.object" in decoded

    def _token_for_request(self, urn=None):
        """OSS (dthub-bim-demo) needs 2-legged. ACC projects need 3-legged when logged in."""
        if urn and self._is_oss_urn(urn):
            return get_server_token()
        return get_user_token() or get_server_token()

    def _derivative_get(self, urn, suffix="", params=None):
        """GET ``/modelderivative/v2/designdata/{urn}{suffix}`` with the token _token_for_request picks.

        Returns:
            The successful ``requests.Response``.

        Raises:
            requests.HTTPError: For a non-2xx status.
            requests.RequestException: On network failure or timeout.
        """
        token = self._token_for_request(urn)
        resp = requests.get(
            f"{BASE_URL}/modelderivative/v2/designdata/{urn}{suffix}",
            headers={"Authorization": f"Bearer {token}"},
            params=params,
            timeout=self.REQUEST_TIMEOUT,
        )
        resp.raise_for_status()
        return resp

    # ===== Auth =====
    def handle_login(self, params):
        """Start the 3-legged flow: remember a fresh ``state`` and redirect to APS."""
        url, state = get_auth_url()
        now = time.time()
        # Forget abandoned login attempts so the table cannot grow without bound.
        stale = [s for s, issued in self._pending_states.items() if now - issued > self.STATE_TTL]
        for s in stale:
            del self._pending_states[s]
        self._pending_states[state] = now
        self.redirect(url)

    def handle_callback(self, params):
        """Finish the OAuth flow: validate ``state``, exchange ``code`` and store the session."""
        global _active_token
        code = self._first(params, "code")
        error = self._first(params, "error")
        # States are single-use: consume it whatever the outcome.
        issued = self._pending_states.pop(self._first(params, "state"), None)
        if error:
            # ``error`` comes straight from the query string: escape it (XSS).
            page = (
                "<html><body><h2>Login Failed</h2>"
                f"<p>{html.escape(error)}</p>"
                '<a href="/acc_viewer.html">Back</a></body></html>'
            ).encode("utf-8")
            self.send_response(200)
            self.send_header("Content-Type", "text/html; charset=utf-8")
            self.send_header("Content-Length", str(len(page)))
            self.end_headers()
            self.wfile.write(page)
            return
        if not code:
            return self.send_error_json("Missing code", 400)
        # Only accept callbacks for logins this server started (OAuth CSRF protection).
        if issued is None or time.time() - issued > self.STATE_TTL:
            return self.send_error_json("Invalid or expired state", 400)
        try:
            data = exchange_code(code)
            _active_token = {
                "access_token": data["access_token"],
                "refresh_token": data.get("refresh_token"),
                "expires_at": time.time() + data["expires_in"] - 60,
            }
        except Exception as e:
            return self.send_error_json(f"Token exchange failed: {e}", 500)
        self.redirect("/acc_viewer.html#logged-in")

    def handle_auth_status(self, params):
        """Report whether a user is logged in, with name/email when the profile can be fetched."""
        token = get_user_token()
        if not token:
            return self.send_json({"loggedIn": False})
        status = {"loggedIn": True, "name": "User", "email": ""}
        try:
            resp = requests.get(
                f"{BASE_URL}/userprofile/v1/users/@me",
                headers={"Authorization": f"Bearer {token}"},
                timeout=self.REQUEST_TIMEOUT,
            )
            if resp.status_code == 200:
                profile = resp.json()
                status = {
                    "loggedIn": True,
                    "name": f"{profile.get('firstName', '')} {profile.get('lastName', '')}".strip(),
                    "email": profile.get("emailId", ""),
                }
        except Exception:
            # Best effort: the profile is cosmetic, so keep the generic entry.
            pass
        self.send_json(status)

    def handle_logout(self, params):
        """Forget the current ACC session locally (APS itself is not contacted)."""
        global _active_token
        _active_token = None
        self.send_json({"loggedIn": False})

    def handle_token(self, params):
        """3-legged for ACC when logged in, else 2-legged.

        ``expires_in`` is the real remaining lifetime of the (possibly cached)
        token, so the viewer re-requests it before it lapses.
        """
        try:
            token = get_user_token()
            if token:
                expires_at = _active_token["expires_at"]
            else:
                token = get_server_token()
                expires_at = _token_cache["expires_at"]
        except Exception as e:
            return self.send_error_json(str(e))
        if not token:
            return self.send_error_json("No token", 500)
        self.send_json({"access_token": token, "expires_in": self._seconds_left(expires_at)})

    def handle_token_bim(self, params):
        """Always 2-legged for BIM Viewer (OSS bucket). ACC 3-legged has no OSS access."""
        try:
            token = get_server_token()
        except Exception as e:
            return self.send_error_json(str(e))
        self.send_json({"access_token": token, "expires_in": self._seconds_left(_token_cache["expires_at"])})

    # ===== BIM/Model Derivative (2-legged or 3-legged) =====
    def handle_metadata(self, params):
        """Proxy the model's metadata (viewable list)."""
        urn = self._first(params, "urn")
        if not urn:
            return self.send_error_json("Missing urn", 400)
        try:
            self.send_json(self._derivative_get(urn, "/metadata").json())
        except Exception as e:
            self.send_error_json(str(e))

    def handle_properties(self, params):
        """Proxy all object properties of viewable *guid*."""
        urn = self._first(params, "urn")
        guid = self._first(params, "guid")
        if not urn or not guid:
            return self.send_error_json("Missing urn or guid", 400)
        try:
            resp = self._derivative_get(urn, f"/metadata/{guid}/properties", {"forceget": "true"})
            self.send_json(resp.json())
        except Exception as e:
            self.send_error_json(str(e))

    def handle_tree(self, params):
        """Proxy the object tree of viewable *guid*."""
        urn = self._first(params, "urn")
        guid = self._first(params, "guid")
        if not urn or not guid:
            return self.send_error_json("Missing urn or guid", 400)
        try:
            resp = self._derivative_get(urn, f"/metadata/{guid}", {"forceget": "true"})
            self.send_json(resp.json())
        except Exception as e:
            self.send_error_json(str(e))

    def handle_thumbnail(self, params):
        """Proxy a 400x400 thumbnail of the model as image/png."""
        urn = self._first(params, "urn")
        if not urn:
            return self.send_error_json("Missing urn", 400)
        try:
            resp = self._derivative_get(urn, "/thumbnail", {"width": 400, "height": 400})
        except Exception as e:
            return self.send_error_json(str(e))
        # Respond only after the upstream call succeeded, so an error can never
        # be written after the image headers.
        self.send_response(200)
        self.send_header("Content-Type", "image/png")
        self.send_header("Content-Length", str(len(resp.content)))
        self.send_header("Access-Control-Allow-Origin", "*")
        self.end_headers()
        self.wfile.write(resp.content)

    def handle_manifest(self, params):
        """Proxy the model's translation manifest."""
        urn = self._first(params, "urn")
        if not urn:
            return self.send_error_json("Missing urn", 400)
        try:
            self.send_json(self._derivative_get(urn, "/manifest").json())
        except Exception as e:
            self.send_error_json(str(e))

    @staticmethod
    def _summarize_phases(collection):
        """Group property records by their ``Phasing`` > ``Phase Created`` value.

        Args:
            collection: The ``data.collection`` list from the properties endpoint.

        Returns:
            ``(phase_names, phase_elements, elements)``: sorted phase names,
            phase -> list of objectids, and str(objectid) -> element summary.
            Records without a "Phase Created" value are skipped.
        """
        phase_elements, elements = {}, {}
        for elem in collection:
            db_id = elem.get("objectid")
            props = elem.get("properties", {})
            phasing = props.get("Phasing", {})
            if not isinstance(phasing, dict):
                phasing = {}
            phase_created = phasing.get("Phase Created")
            if not phase_created:
                continue
            cat_data = props.get("__category__", {})
            category = cat_data.get("__category__", "") if isinstance(cat_data, dict) else ""
            phase_elements.setdefault(phase_created, []).append(db_id)
            elements[str(db_id)] = {
                "name": elem.get("name", ""),
                "phaseCreated": phase_created,
                "phaseDemolished": phasing.get("Phase Demolished"),
                "category": category,
            }
        return sorted(phase_elements), phase_elements, elements

    def handle_phases(self, params):
        """Return elements of viewable *guid* grouped by Revit phase."""
        urn = self._first(params, "urn")
        guid = self._first(params, "guid")
        if not urn or not guid:
            return self.send_error_json("Missing urn or guid", 400)
        try:
            data = self._derivative_get(urn, f"/metadata/{guid}/properties", {"forceget": "true"}).json()
            collection = data.get("data", {}).get("collection", [])
            phase_names, phase_elements, elements = self._summarize_phases(collection)
            self.send_json({
                "phases": phase_names,
                "phaseElements": phase_elements,
                "elements": elements,
                "totalElements": len(collection),
                "phasedElements": len(elements),
            })
        except Exception as e:
            self.send_error_json(str(e))

    # ===== ACC Data Management (3-legged only) =====
    def handle_hubs(self, params):
        """List the hubs visible to the logged-in user."""
        try:
            data, err = self._api_get(f"{BASE_URL}/project/v1/hubs")
            if err:
                return self.send_error_json(err, 401)
            hubs = [
                {"id": h["id"], "name": h["attributes"]["name"], "type": h["attributes"].get("extension", {}).get("type", "")}
                for h in data.get("data", [])
            ]
            self.send_json({"hubs": hubs})
        except Exception as e:
            self.send_error_json(str(e))

    def handle_projects(self, params):
        """List every project of *hub_id*, following pagination."""
        hub_id = self._first(params, "hub_id")
        if not hub_id:
            return self.send_error_json("Missing hub_id", 400)
        try:
            raw, err = self._api_get_all(f"{BASE_URL}/project/v1/hubs/{hub_id}/projects")
            if err:
                return self.send_error_json(err, 401)
            projects = []
            for p in raw:
                attrs = p["attributes"]
                extension = attrs.get("extension", {})
                projects.append({
                    "id": p["id"],
                    "name": attrs["name"],
                    "type": extension.get("type", ""),
                    "status": extension.get("data", {}).get("projectType", ""),
                })
            self.send_json({"projects": projects})
        except Exception as e:
            self.send_error_json(str(e))

    def handle_top_folders(self, params):
        """List the top-level folders of a project."""
        hub_id = self._first(params, "hub_id")
        project_id = self._first(params, "project_id")
        if not hub_id or not project_id:
            return self.send_error_json("Missing hub_id or project_id", 400)
        try:
            data, err = self._api_get(f"{BASE_URL}/project/v1/hubs/{hub_id}/projects/{project_id}/topFolders")
            if err:
                return self.send_error_json(err, 401)
            folders = [{"id": f["id"], "name": f["attributes"]["name"], "type": "folder"} for f in data.get("data", [])]
            self.send_json({"items": folders})
        except Exception as e:
            self.send_error_json(str(e))

    def handle_folder_contents(self, params):
        """List a folder's contents (all pages), folders first, then by name."""
        project_id = self._first(params, "project_id")
        folder_id = self._first(params, "folder_id")
        if not project_id or not folder_id:
            return self.send_error_json("Missing project_id or folder_id", 400)
        try:
            raw, err = self._api_get_all(f"{BASE_URL}/data/v1/projects/{project_id}/folders/{folder_id}/contents")
            if err:
                return self.send_error_json(err, 401)
            items = []
            for item in raw:
                attrs = item["attributes"]
                item_type = item["type"]
                entry = {
                    "id": item["id"],
                    "name": attrs.get("displayName") or attrs.get("name", ""),
                    "type": "folder" if item_type == "folders" else "file",
                    "lastModified": attrs.get("lastModifiedTime", ""),
                    "createTime": attrs.get("createTime", ""),
                }
                if item_type == "items":
                    entry["fileType"] = attrs.get("fileType", "")
                    entry["extensionType"] = attrs.get("extension", {}).get("type", "")
                items.append(entry)
            items.sort(key=lambda x: (0 if x["type"] == "folder" else 1, x["name"].lower()))
            self.send_json({"items": items})
        except Exception as e:
            self.send_error_json(str(e))

    def handle_versions(self, params):
        """List an item's versions, newest first, with their derivative URNs."""
        project_id = self._first(params, "project_id")
        item_id = self._first(params, "item_id")
        if not project_id or not item_id:
            return self.send_error_json("Missing project_id or item_id", 400)
        try:
            data, err = self._api_get(f"{BASE_URL}/data/v1/projects/{project_id}/items/{item_id}/versions")
            if err:
                return self.send_error_json(err, 401)
            versions = []
            for v in data.get("data", []):
                attrs = v["attributes"]
                versions.append({
                    "id": v["id"],
                    "version": attrs.get("versionNumber", 0),
                    "name": attrs.get("name", ""),
                    "lastModified": attrs.get("lastModifiedTime", ""),
                    "createTime": attrs.get("createTime", ""),
                    "fileType": attrs.get("fileType", ""),
                    "urn": self._derivative_urn(v),
                })
            versions.sort(key=lambda x: x["version"], reverse=True)
            self.send_json({"versions": versions})
        except Exception as e:
            self.send_error_json(str(e))

    def handle_item_details(self, params):
        """Describe an item's tip (latest) version, including its derivative URN."""
        project_id = self._first(params, "project_id")
        item_id = self._first(params, "item_id")
        if not project_id or not item_id:
            return self.send_error_json("Missing project_id or item_id", 400)
        try:
            data, err = self._api_get(f"{BASE_URL}/data/v1/projects/{project_id}/items/{item_id}/tip")
            if err:
                return self.send_error_json(err, 401)
            d = data.get("data", {})
            attrs = d.get("attributes", {})
            self.send_json({
                "name": attrs.get("name", ""),
                "version": attrs.get("versionNumber", 0),
                "fileType": attrs.get("fileType", ""),
                "urn": self._derivative_urn(d),
                "lastModified": attrs.get("lastModifiedTime", ""),
            })
        except Exception as e:
            self.send_error_json(str(e))


if __name__ == "__main__":
    port = 9090
    if not APS_CLIENT_ID or not APS_CLIENT_SECRET:
        print("[ERROR] Set APS_CLIENT_ID and APS_CLIENT_SECRET in .env")
        exit(1)
    print(f"ACC OAuth - Client ID: {ACC_CLIENT_ID}")
    print(f"ACC OAuth - Callback:   {APS_CALLBACK_URL}")
    print("  -> Add this Callback URL to your app at https://aps.autodesk.com/myapps")
    print()
    server = HTTPServer(("0.0.0.0", port), CombinedHandler)
    print(f"BIM + ACC server: http://localhost:{port}")
    print("  BIM:  http://localhost:9090/viewer2.html")
    print("  ACC:  http://localhost:9090/acc_viewer.html")
    server.serve_forever()
