"""
APS Upload Script for ZIP archives (ReCap point clouds, composite models).
Uploads a zip file to APS, translates with compressedUrn + rootFilename.

Usage:
  python aps_upload_zip.py <path_to_zip> [root_filename]

  root_filename: path inside the zip to the main file (e.g. "Structured/Brownsville.rcp")
                 If omitted, auto-detects .rcp/.rvt/.ifc/.nwd files inside the zip.

Example:
  python aps_upload_zip.py revitproject/Structured_2023-04-12_09-32-50am.zip
"""

import os
import sys
import base64
import time
import json
import zipfile
import requests
from pathlib import Path
from dotenv import load_dotenv

load_dotenv()  # pull APS credentials from a local .env file into the environment

# APS app credentials (required at runtime; validated in main()).
APS_CLIENT_ID = os.getenv("APS_CLIENT_ID")
APS_CLIENT_SECRET = os.getenv("APS_CLIENT_SECRET")
# OSS bucket key; defaults to the demo bucket when not configured.
APS_BUCKET_KEY = os.getenv("APS_BUCKET_KEY", "dthub-bim-demo")

BASE_URL = "https://developer.api.autodesk.com"
CHUNK_SIZE = 20 * 1024 * 1024  # 20 MB per multipart upload part


def get_access_token():
    """Request a 2-legged OAuth token from APS.

    Uses the client-credentials grant with the scopes needed for bucket
    management, object upload and translation.

    Returns:
        str: the bearer access token.

    Raises:
        requests.HTTPError: if the token endpoint returns an error status.
    """
    url = f"{BASE_URL}/authentication/v2/token"
    data = {
        "grant_type": "client_credentials",
        "scope": "data:read data:write data:create bucket:create bucket:read",
    }
    resp = requests.post(url, data=data, auth=(APS_CLIENT_ID, APS_CLIENT_SECRET))
    resp.raise_for_status()
    payload = resp.json()  # parse the body once instead of once per field
    print(f"[OK] Access token received (expires in {payload['expires_in']}s)")
    return payload["access_token"]


def create_bucket(token):
    """Ensure the target OSS bucket exists (idempotent).

    A 409 response means the bucket already exists and is treated as
    success; any other error status raises via raise_for_status().
    """
    resp = requests.post(
        f"{BASE_URL}/oss/v2/buckets",
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        json={"bucketKey": APS_BUCKET_KEY, "policyKey": "persistent"},
    )
    if resp.status_code == 409:
        print(f"[OK] Bucket '{APS_BUCKET_KEY}' already exists")
        return
    resp.raise_for_status()
    print(f"[OK] Bucket '{APS_BUCKET_KEY}' created")

def detect_root_file(zip_path):
    """Return the most likely main design file inside the zip, or None.

    Scans the archive's entry list for known model extensions in priority
    order (.rcp first), skipping macOS resource-fork entries, and prefers
    the entry with the shortest path (closest to the archive root).
    """
    preferred = ('.rcp', '.rvt', '.ifc', '.nwd')
    with zipfile.ZipFile(zip_path, 'r') as archive:
        names = [n for n in archive.namelist() if not n.startswith('__MACOSX')]
    for ext in preferred:
        candidates = [n for n in names if n.lower().endswith(ext)]
        if candidates:
            # min() on length picks the first shortest path, matching a
            # stable sort-by-length followed by taking element 0.
            return min(candidates, key=len)
    return None


def upload_file(token, file_path):
    """Upload file to APS bucket via S3 signed URLs in batches.

    Splits the file into CHUNK_SIZE parts, requests signed S3 URLs in
    batches of up to 25 (the APS per-request limit), PUTs each part
    directly to S3 with up to 3 attempts, then completes the multipart
    upload with the uploadKey.

    Args:
        token: bearer access token with data:write/data:create scope.
        file_path: path to the local file to upload.

    Returns:
        The base64-URL-encoded URN of the uploaded object (str), or
        None if any part fails after all retries.
    """
    file_path = Path(file_path)
    file_size = file_path.stat().st_size
    object_key = file_path.name
    # Ceiling division; a zero-byte file still counts as one part.
    total_parts = max(1, (file_size + CHUNK_SIZE - 1) // CHUNK_SIZE)
    batch_size = 25  # APS signed URL batch limit

    print(f"[...] Uploading {object_key} ({file_size / 1024 / 1024:.1f} MB, {total_parts} part(s))...")

    sign_url = f"{BASE_URL}/oss/v2/buckets/{APS_BUCKET_KEY}/objects/{object_key}/signeds3upload"
    headers = {"Authorization": f"Bearer {token}"}

    # Get first batch to obtain uploadKey; the key identifies this multipart
    # session in later batch requests and in the final completion call.
    first_batch = min(batch_size, total_parts)
    resp = requests.get(sign_url, headers=headers, params={"parts": first_batch, "firstPart": 1})
    resp.raise_for_status()
    sign_data = resp.json()
    upload_key = sign_data["uploadKey"]

    part_index = 0  # 0-based index for tracking
    with open(file_path, "rb") as f:
        while part_index < total_parts:
            # Get batch of signed URLs
            remaining = total_parts - part_index
            batch_count = min(batch_size, remaining)
            first_part = part_index + 1  # 1-based

            # The first batch was already fetched above; subsequent batches
            # must pass uploadKey to stay in the same upload session.
            if part_index > 0:
                resp = requests.get(sign_url, headers=headers, params={
                    "parts": batch_count,
                    "firstPart": first_part,
                    "uploadKey": upload_key,
                })
                resp.raise_for_status()
                sign_data = resp.json()

            urls = sign_data["urls"]
            for j, part_url in enumerate(urls):
                chunk = f.read(CHUNK_SIZE)
                if not chunk:
                    # Defensive: EOF earlier than total_parts predicts
                    # (e.g. the file was truncated mid-upload).
                    break
                # Retry up to 3 times per part
                for attempt in range(3):
                    try:
                        part_resp = requests.put(
                            part_url,
                            data=chunk,
                            headers={"Content-Type": "application/octet-stream"},
                            timeout=600,
                        )
                        if part_resp.status_code in (200, 201):
                            break
                        print(f"[WARN] Part {part_index+j+1} status {part_resp.status_code}, retry {attempt+1}/3")
                    except requests.exceptions.RequestException as e:
                        print(f"[WARN] Part {part_index+j+1} error: {e}, retry {attempt+1}/3")
                    if attempt < 2:
                        time.sleep(5)
                else:
                    # for/else: all 3 attempts exhausted without a break.
                    print(f"[ERROR] Part {part_index+j+1}/{total_parts} failed after 3 retries")
                    return None
                current = part_index + j + 1
                pct = int(current * 100 / total_parts)
                print(f"      Part {current}/{total_parts} uploaded ({pct}%)")

            part_index += batch_count

    # Complete the upload
    complete_url = f"{BASE_URL}/oss/v2/buckets/{APS_BUCKET_KEY}/objects/{object_key}/signeds3upload"
    resp = requests.post(
        complete_url,
        headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
        json={"uploadKey": upload_key},
    )
    resp.raise_for_status()

    object_id = resp.json()["objectId"]
    # Model Derivative expects an unpadded URL-safe base64 URN of the objectId.
    urn = base64.urlsafe_b64encode(object_id.encode()).decode().rstrip("=")
    print(f"[OK] Uploaded: {object_key}")
    print(f"     Object ID: {object_id}")
    print(f"     URN: {urn}")
    return urn


def translate_compressed(token, urn, root_filename):
    """Start translation for a compressed (zip) model with rootFilename.

    Submits a Model Derivative job with compressedUrn=True so APS extracts
    root_filename from the uploaded zip before translating to SVF2.

    Args:
        token: bearer access token with data:read/data:write scope.
        urn: base64-URL-encoded URN of the uploaded zip object.
        root_filename: path of the main design file inside the zip.

    Returns:
        The same urn, for chaining with wait_for_translation().

    Raises:
        requests.HTTPError: if the job request is not accepted (any status
            other than 200/201).
    """
    url = f"{BASE_URL}/modelderivative/v2/designdata/job"
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    body = {
        "input": {
            "urn": urn,
            "compressedUrn": True,
            "rootFilename": root_filename,
        },
        "output": {
            "formats": [{"type": "svf2", "views": ["3d", "2d"]}]
        },
    }
    resp = requests.post(url, headers=headers, json=body)
    if resp.status_code not in (200, 201):
        print(f"[ERROR] Translation request failed: {resp.status_code}")
        print(f"        Response: {resp.text}")
        # raise_for_status() is a no-op for non-4xx/5xx codes (e.g. 3xx),
        # which previously let execution fall through and report success;
        # raise explicitly so any unexpected status aborts the pipeline.
        resp.raise_for_status()
        raise requests.HTTPError(
            f"Unexpected status {resp.status_code} from translation job",
            response=resp,
        )
    print(f"[OK] Translation started (compressedUrn=true, rootFilename={root_filename})")
    return urn


def wait_for_translation(token, urn, timeout=1800):
    """Wait for translation to complete (longer timeout for large files).

    Polls the Model Derivative manifest every 15 s until the status is
    "success" or "failed", or until `timeout` seconds elapse.

    Args:
        token: bearer access token with data:read scope.
        urn: base64-URL-encoded URN of the translated object.
        timeout: maximum seconds to wait (default 30 minutes).

    Returns:
        True on successful translation, False on failure or timeout.
    """
    print("[...] Waiting for translation to complete...")
    # Hoisted out of the loop: the manifest endpoint never changes.
    url = f"{BASE_URL}/modelderivative/v2/designdata/{urn}/manifest"
    headers = {"Authorization": f"Bearer {token}"}
    # monotonic() is immune to system clock adjustments during long waits.
    start = time.monotonic()
    while time.monotonic() - start < timeout:
        resp = requests.get(url, headers=headers)
        resp.raise_for_status()
        result = resp.json()
        status = result.get("status", "unknown")
        progress = result.get("progress", "0%")
        print(f"      Status: {status} | Progress: {progress}")
        if status == "success":
            print("[OK] Translation complete!")
            return True
        if status == "failed":
            print("[ERROR] Translation failed!")
            # Surface derivative-level messages so the root cause is visible.
            for d in result.get("derivatives", []):
                for m in d.get("messages", []):
                    print(f"       - {m.get('type','')}: {m.get('message','')}")
            return False
        time.sleep(15)
    print("[ERROR] Translation timed out")
    return False


def save_viewer_config(name, urn):
    """Add model to viewer_config.json (merge with existing)."""
    config_path = Path(__file__).parent / "viewer_config.json"

    if config_path.exists():
        with open(config_path, "r", encoding="utf-8") as fh:
            models = json.load(fh).get("models", [])
    else:
        models = []

    # Replace the urn of an existing entry, or append a new one.
    entry = next((m for m in models if m["name"] == name), None)
    if entry is not None:
        entry["urn"] = urn
    else:
        models.append({"name": name, "urn": urn})

    with open(config_path, "w", encoding="utf-8") as fh:
        json.dump({"models": models}, fh, indent=2, ensure_ascii=False)
    print(f"[OK] Viewer config saved ({len(models)} models)")


def main():
    """Script entry point: validate arguments, upload the zip, translate, save config."""
    if not (APS_CLIENT_ID and APS_CLIENT_SECRET):
        print("[ERROR] Set APS_CLIENT_ID and APS_CLIENT_SECRET in .env")
        sys.exit(1)

    args = sys.argv[1:]
    if not args:
        print("Usage: python aps_upload_zip.py <path_to_zip> [root_filename]")
        sys.exit(1)

    # Resolve relative paths against this script's directory.
    zip_path = Path(args[0])
    if not zip_path.is_absolute():
        zip_path = Path(__file__).parent / zip_path
    if not zip_path.exists():
        print(f"[ERROR] File not found: {zip_path}")
        sys.exit(1)

    # Root filename: explicit 2nd argument wins, otherwise auto-detect.
    if len(args) >= 2:
        root_filename = args[1]
    else:
        root_filename = detect_root_file(zip_path)
        if not root_filename:
            print("[ERROR] Cannot auto-detect root file. Specify it as 2nd argument.")
            sys.exit(1)

    size_mb = zip_path.stat().st_size / 1024 / 1024
    banner = "=" * 60
    print(banner)
    print(f"ZIP Upload: {zip_path.name}")
    print(f"Size: {size_mb:.1f} MB")
    print(f"Root file: {root_filename}")
    print(banner + "\n")

    token = get_access_token()
    create_bucket(token)

    urn = upload_file(token, zip_path)
    if not urn:
        print("[ERROR] Upload failed")
        sys.exit(1)

    # Large uploads can outlive the token's lifetime; get a fresh one.
    print("[...] Refreshing token before translation...")
    token = get_access_token()

    translate_compressed(token, urn, root_filename)
    if not wait_for_translation(token, urn):
        print("\n[ERROR] Translation failed. Model not added to config.")
        sys.exit(1)

    model_name = Path(root_filename).stem
    save_viewer_config(model_name, urn)
    print(f"\n{banner}")
    print("DONE!")
    print(banner)
    print(f"\nModel: {model_name}")
    print(f"URN: {urn}")


if __name__ == "__main__":  # allow importing this module without side effects
    main()
