#!/usr/bin/env python3
"""
IF.emotion trace bundle verifier + Merkle inclusion proof tool.

Run with the venv:
  /root/tmp/iftrace_venv/bin/python /root/tmp/iftrace.py <command> ...
"""

from __future__ import annotations

import argparse
import hashlib
import io
import json
import os
import sys
import tarfile
import tempfile
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Iterable

try:
    from canonicaljson import encode_canonical_json  # type: ignore
except ModuleNotFoundError:
    # `iftrace.py` should be usable as a single-file verifier for basic bundles (manifest + sha256s)
    # without requiring external dependencies.
    # Fallback encoder: sorted keys, compact separators, UTF-8 — matches the
    # shape of canonicaljson output for str/int/list/dict payloads.
    # NOTE(review): float serialization may differ from the real canonicaljson
    # library — confirm bundle producers never emit floats in signed payloads.
    def encode_canonical_json(obj: Any) -> bytes:  # type: ignore
        return json.dumps(obj, sort_keys=True, separators=(",", ":"), ensure_ascii=False).encode("utf-8")

try:
    from nacl.signing import VerifyKey  # type: ignore
    from nacl.encoding import HexEncoder  # type: ignore
except ModuleNotFoundError:
    # pynacl is optional: when absent, signature verification raises a clear
    # RuntimeError later (see verify_ed25519_hex) instead of failing at import.
    VerifyKey = None  # type: ignore
    HexEncoder = None  # type: ignore


def sha256_bytes(data: bytes) -> str:
    """Return the hex SHA-256 digest of *data*; None/empty hash as b""."""
    payload = data if data else b""
    return hashlib.sha256(payload).hexdigest()


def sha256_file(path: Path) -> str:
    """Stream *path* in 1 MiB chunks and return its hex SHA-256 digest."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        while chunk := fh.read(1024 * 1024):
            digest.update(chunk)
    return digest.hexdigest()


def canonical_json_bytes(obj: Any) -> bytes:
    """Encode *obj* as canonical JSON bytes (sorted keys, compact form)."""
    encoded = encode_canonical_json(obj)
    return encoded


def merkle_root_hex(leaves_hex: list[str]) -> str:
    """Compute the hex SHA-256 Merkle root of *leaves_hex*.

    Entries that are not 64-character hex strings are skipped. With no usable
    leaves the root is the hash of the empty byte string. Odd-sized levels
    duplicate their last node before pairing.
    """
    nodes = [bytes.fromhex(h) for h in leaves_hex if isinstance(h, str) and len(h) == 64]
    if not nodes:
        # No valid leaves at all: root is sha256(b"").
        return hashlib.sha256(b"").hexdigest()
    while len(nodes) > 1:
        if len(nodes) & 1:
            nodes.append(nodes[-1])
        nodes = [
            hashlib.sha256(left + right).digest()
            for left, right in zip(nodes[::2], nodes[1::2])
        ]
    return nodes[0].hex()


def merkle_inclusion_proof(leaves_hex: list[str], index: int) -> dict:
    """Build a Merkle inclusion proof (audit path) for leaf *index*.

    Returns {"index", "root", "path"}; each path step records the sibling
    hash (hex) and which side it sits on relative to the running hash.
    Odd-sized levels duplicate their last node, matching merkle_root_hex.

    Raises ValueError when *index* is out of range.
    """
    if not (0 <= index < len(leaves_hex)):
        raise ValueError("index out of range")
    nodes = [bytes.fromhex(h) for h in leaves_hex]
    path: list[dict] = []
    pos = index
    while len(nodes) > 1:
        if len(nodes) % 2:
            nodes.append(nodes[-1])
        sib_pos = pos ^ 1  # sibling is the other element of the pair
        path.append({
            "sibling": nodes[sib_pos].hex(),
            "side": "left" if sib_pos < pos else "right",
        })
        nodes = [
            hashlib.sha256(nodes[k] + nodes[k + 1]).digest()
            for k in range(0, len(nodes), 2)
        ]
        pos >>= 1
    return {"index": index, "root": nodes[0].hex(), "path": path}


def merkle_verify_proof(leaf_hex: str, proof: dict) -> bool:
    """Recompute the root from *leaf_hex* along proof["path"].

    Returns True iff the recomputed root equals proof["root"]. Malformed
    input (bad hex, missing keys) yields False rather than raising.
    """
    try:
        acc = bytes.fromhex(leaf_hex)
        for step in proof.get("path", []):
            sibling = bytes.fromhex(step["sibling"])
            combined = sibling + acc if step["side"] == "left" else acc + sibling
            acc = hashlib.sha256(combined).digest()
        return acc.hex() == proof.get("root")
    except Exception:
        return False


def read_json(path: Path) -> Any:
    """Parse *path* as strict UTF-8 JSON and return the decoded object."""
    text = path.read_text(encoding="utf-8", errors="strict")
    return json.loads(text)


def verify_ed25519_hex(*, pub_hex: str, msg: bytes, sig_hex: str) -> None:
    """Verify an Ed25519 signature (hex-encoded key and signature) over *msg*.

    Raises RuntimeError when pynacl is unavailable; otherwise lets pynacl's
    verification error propagate on a bad signature. Returns None on success.
    """
    if VerifyKey is None or HexEncoder is None:
        raise RuntimeError("Missing dependency: pynacl (install with `pip install pynacl`) to verify Ed25519 signatures")
    VerifyKey(pub_hex, encoder=HexEncoder).verify(msg, bytes.fromhex(sig_hex))


@dataclass(frozen=True)
class VerifyResult:
    """Immutable outcome of one verification check."""

    ok: bool  # True when the check passed
    notes: list[str]  # human-readable progress/diagnostic lines (printed by cmd_verify)


def safe_check(name: str, func: Callable[[], VerifyResult], *, debug: bool) -> VerifyResult:
    """Run *func*, converting any exception into a failed VerifyResult.

    With debug=True the full traceback is echoed to stderr before the
    failure note is returned; the note is prefixed with *name*.
    """
    try:
        result = func()
    except Exception as exc:
        if debug:
            print(traceback.format_exc(), file=sys.stderr)
        return VerifyResult(False, [f"{name}: exception {type(exc).__name__}: {exc}"])
    return result


def verify_trace_events(events_path: Path) -> VerifyResult:
    """Verify the hash chain in a trace_events.jsonl file.

    Each non-blank line holds {"event": {...}} where the event carries a
    sequential "idx" (starting at 0), the previous event's hash in
    "prev_hash" (64 zeros for the first event), and its own "event_hash" =
    sha256(prev_hash || canonical_json(event without event_hash)).
    """
    notes: list[str] = []
    chain_hash = "0" * 64
    count = 0
    for raw in events_path.read_text(encoding="utf-8", errors="ignore").splitlines():
        if not raw.strip():
            continue
        record = json.loads(raw)
        event = record.get("event") or {}
        idx = int(event.get("idx", -1))
        if idx != count:
            return VerifyResult(False, notes + [f"trace_events: idx mismatch (got {idx}, expected {count})"])
        if str(event.get("prev_hash") or "") != chain_hash:
            return VerifyResult(False, notes + ["trace_events: prev_hash mismatch"])
        claimed = str(event.get("event_hash") or "")
        # Recompute the hash over the event with its own hash field removed.
        body = {k: v for k, v in event.items() if k != "event_hash"}
        recomputed = sha256_bytes(chain_hash.encode("utf-8") + canonical_json_bytes(body))
        if recomputed != claimed:
            return VerifyResult(False, notes + ["trace_events: event_hash mismatch (recomputed != stored)"])
        chain_hash = claimed
        count += 1
    notes.append(f"trace_events: ok (events={count}, head_hash={chain_hash[:16]}…)")
    return VerifyResult(True, notes)


def verify_req_seen(ledger_path: Path, head_path: Path) -> VerifyResult:
    """Verify a REQ_SEEN hourly ledger against its signed head.

    Checks, in order: the Ed25519 signature over the canonical head core,
    each entry's leaf_hash = sha256(canonical_json(entry without leaf_hash)),
    the Merkle root over all leaves, and the head's leaf count.
    """
    notes: list[str] = []
    head = read_json(head_path)
    signer = str(head.get("signer_ed25519") or "").strip()
    signature = str(head.get("sig_ed25519") or "").strip()
    if not signer or not signature:
        return VerifyResult(False, ["req_seen: missing signer_ed25519 or sig_ed25519 in head"])
    # Recreate the message that was signed: the head core as it existed
    # before sig/key_id/signer fields were attached.
    core_fields = ("schema", "hour_utc", "updated_utc", "count", "merkle_root", "last_leaf_hash")
    head_core = {field: head.get(field) for field in core_fields}
    verify_ed25519_hex(pub_hex=signer, msg=canonical_json_bytes(head_core), sig_hex=signature)
    notes.append("req_seen_head: Ed25519 signature OK")

    leaves: list[str] = []
    for raw in ledger_path.read_text(encoding="utf-8", errors="ignore").splitlines():
        if not raw.strip():
            continue
        entry = json.loads(raw)
        claimed_leaf = str(entry.get("leaf_hash") or "").strip()
        body = {k: v for k, v in entry.items() if k != "leaf_hash"}
        if sha256_bytes(canonical_json_bytes(body)) != claimed_leaf:
            return VerifyResult(False, notes + ["req_seen: leaf_hash mismatch"])
        leaves.append(claimed_leaf)

    root = merkle_root_hex(leaves)
    if root != str(head.get("merkle_root") or ""):
        return VerifyResult(False, notes + ["req_seen: merkle_root mismatch"])
    if int(head.get("count") or 0) != len(leaves):
        return VerifyResult(False, notes + ["req_seen: count mismatch"])
    notes.append(f"req_seen: ok (count={len(leaves)}, merkle_root={root[:16]}…)")
    return VerifyResult(True, notes)


def verify_story(story_path: Path, events_path: Path) -> VerifyResult:
    """Cross-check if_story.md references against trace_events.jsonl.

    Every `event_hash=<hex>` token in the story must name an event_hash
    that actually appears in the events log (ground truth).
    """
    # Gather the set of genuine event hashes (64-char hex only).
    known: set[str] = set()
    for raw in events_path.read_text(encoding="utf-8", errors="ignore").splitlines():
        if not raw.strip():
            continue
        event = json.loads(raw).get("event") or {}
        candidate = str(event.get("event_hash") or "").strip()
        if len(candidate) == 64:
            known.add(candidate)
    # Any story line mentioning event_hash= must reference a real event.
    for raw in story_path.read_text(encoding="utf-8", errors="ignore").splitlines():
        if "event_hash=" in raw:
            token = raw.split("event_hash=", 1)[1].strip().split()[0]
            if token and token not in known:
                return VerifyResult(False, [f"if_story: unknown event_hash referenced: {token}"])
    return VerifyResult(True, ["if_story: ok (all referenced event_hash values exist)"])


def verify_manifest(payload_dir: Path) -> VerifyResult:
    """Verify payload files against sha256s.txt and manifest.json.

    sha256s.txt lines have the form "<sha256> <relative-path>". The two
    index files themselves are excluded from self-referential checks.
    Every indexed file must exist and hash to its recorded digest, and
    the two indexes must agree wherever sha256s.txt lists a file.

    Returns a failed VerifyResult naming the first mismatch found.
    """
    notes: list[str] = []
    manifest_path = payload_dir / "manifest.json"
    sha_list_path = payload_dir / "sha256s.txt"
    if not manifest_path.exists() or not sha_list_path.exists():
        return VerifyResult(False, ["manifest: missing manifest.json or sha256s.txt"])

    manifest = read_json(manifest_path)
    files = manifest.get("files") or []
    manifest_map = {f["path"]: f["sha256"] for f in files if isinstance(f, dict) and "path" in f and "sha256" in f}

    sha_map: dict[str, str] = {}
    for line in sha_list_path.read_text(encoding="utf-8", errors="ignore").splitlines():
        parts = line.strip().split()
        if len(parts) >= 2:
            sha_map[parts[1]] = parts[0]

    # sha256s.txt is a checksum file; it must not be self-referential.
    sha_map.pop("sha256s.txt", None)
    # manifest.json is the root index; do not make it self-referential in sha256s.
    sha_map.pop("manifest.json", None)

    for name, sha in sha_map.items():
        p = payload_dir / name
        if not p.exists():
            return VerifyResult(False, [f"manifest: sha256s references missing file: {name}"])
        got = sha256_file(p)
        if got != sha:
            return VerifyResult(False, [f"manifest: sha256 mismatch for {name}"])
        # sha256s.txt entries must also appear in manifest.json with the
        # same digest. (The dead `name != "manifest.json"` guard was
        # removed: manifest.json is already popped from sha_map above.)
        if manifest_map.get(name) != sha:
            return VerifyResult(False, [f"manifest: manifest.json mismatch for {name}"])

    # BUG FIX: files listed in manifest.json but absent from sha256s.txt were
    # previously never hashed at all, so stripping lines from sha256s.txt
    # could hide tampering. Verify those entries directly here.
    for name, sha in manifest_map.items():
        if name in sha_map or name in ("manifest.json", "sha256s.txt"):
            continue
        p = payload_dir / name
        if not p.exists():
            return VerifyResult(False, [f"manifest: manifest.json references missing file: {name}"])
        if sha256_file(p) != sha:
            return VerifyResult(False, [f"manifest: sha256 mismatch for {name}"])

    notes.append(f"manifest: ok (files={len(sha_map)})")
    return VerifyResult(True, notes)


def extract_tarball(tar_path: Path) -> Path:
    tmp = Path(tempfile.mkdtemp(prefix="iftrace_"))
    with tarfile.open(tar_path, "r:gz") as tf:
        base = tmp.resolve()
        for m in tf.getmembers():
            name = str(m.name or "")
            if not name or name.startswith("/") or name.startswith("\\"):
                raise RuntimeError(f"tar: invalid member path: {name!r}")
            dest = (base / name).resolve()
            if dest != base and not str(dest).startswith(str(base) + os.sep):
                raise RuntimeError(f"tar: path traversal member: {name!r}")
        tf.extractall(tmp, filter=tarfile.data_filter)
    return tmp


def cmd_verify(args: argparse.Namespace) -> int:
    """CLI `verify`: check a trace payload tarball end to end.

    Steps, in order:
      1. hash the tarball (and compare against --expected-sha256 when given);
      2. extract it safely to a temp dir and require a payload/ directory;
      3. run each applicable artifact check (manifest/hashes, trace-event
         hash chain, story cross-references, REQ_SEEN ledger+head).

    Prints one line per check note; returns 0 when every check passed,
    2 on any failure.
    """
    tar_path = Path(args.tar).resolve()
    expected_sha = (args.expected_sha256 or "").strip().lower()
    debug = bool(getattr(args, "debug", False))
    try:
        got_sha = sha256_file(tar_path)
    except Exception as e:
        print(f"FAIL tar_sha256: {type(e).__name__}: {e}")
        if debug:
            print(traceback.format_exc(), file=sys.stderr)
        return 2
    # The expected hash is optional; when absent we only report what we saw.
    if expected_sha and got_sha != expected_sha:
        print(f"FAIL tar_sha256 expected={expected_sha} got={got_sha}")
        return 2
    print(f"OK tar_sha256 {got_sha}")

    try:
        root = extract_tarball(tar_path)
    except Exception as e:
        print(f"FAIL extract: {type(e).__name__}: {e}")
        if debug:
            print(traceback.format_exc(), file=sys.stderr)
        return 2
    payload_dir = root / "payload"
    if not payload_dir.exists():
        print("FAIL: tarball missing payload/ directory")
        return 2

    # Each check is wrapped in safe_check so one crash cannot mask the rest.
    checks: list[VerifyResult] = []
    checks.append(safe_check("manifest", lambda: verify_manifest(payload_dir), debug=debug))

    # Optional artifacts: only checked when present in the payload.
    events_path = payload_dir / "trace_events.jsonl"
    if events_path.exists():
        checks.append(safe_check("trace_events", lambda: verify_trace_events(events_path), debug=debug))

    story_path = payload_dir / "if_story.md"
    if story_path.exists() and events_path.exists():
        checks.append(safe_check("if_story", lambda: verify_story(story_path, events_path), debug=debug))

    # REQ_SEEN verification if present (only the first sorted head/ledger pair is used).
    head_files = sorted(payload_dir.glob("req_seen_head_*.json"))
    ledger_files = sorted(payload_dir.glob("req_seen_*.jsonl"))
    if head_files and ledger_files:
        checks.append(safe_check("req_seen", lambda: verify_req_seen(ledger_files[0], head_files[0]), debug=debug))

    # Print every note from every check, then fail if any check failed.
    ok = True
    for res in checks:
        for n in res.notes:
            print(n)
        ok = ok and res.ok

    if not ok:
        print("FAIL verify")
        return 2
    print("OK verify")
    return 0


def cmd_prove_inclusion(args: argparse.Namespace) -> int:
    """CLI `prove-inclusion`: print a Merkle inclusion proof as JSON.

    The target leaf is selected either by --trace-id (first occurrence in
    the ledger wins) or by --leaf-hash. The proof is annotated with the
    leaf hash and the head's hour_utc before printing.
    """
    ledger_path = Path(args.ledger).resolve()
    head_path = Path(args.head).resolve()
    trace_id = (args.trace_id or "").strip()
    leaf_hash = (args.leaf_hash or "").strip().lower()

    leaves: list[str] = []
    first_index_for: dict[str, int] = {}
    for raw in ledger_path.read_text(encoding="utf-8", errors="ignore").splitlines():
        if not raw.strip():
            continue
        entry = json.loads(raw)
        leaves.append(str(entry.get("leaf_hash") or "").strip())
        tid = str(entry.get("trace_id") or "").strip()
        if tid:
            first_index_for.setdefault(tid, len(leaves) - 1)

    if trace_id:
        if trace_id not in first_index_for:
            raise SystemExit("trace_id not found in ledger")
        index = first_index_for[trace_id]
        leaf_hash = leaves[index]
    elif not leaf_hash:
        raise SystemExit("provide --trace-id or --leaf-hash")
    elif leaf_hash not in leaves:
        raise SystemExit("leaf_hash not found in ledger")
    else:
        index = leaves.index(leaf_hash)

    proof = merkle_inclusion_proof(leaves, index)
    proof["leaf_hash"] = leaf_hash
    proof["hour_utc"] = read_json(head_path).get("hour_utc")
    print(json.dumps(proof, indent=2, sort_keys=True))
    return 0


def cmd_verify_inclusion(args: argparse.Namespace) -> int:
    """CLI `verify-inclusion`: check a stored proof JSON; print OK/FAIL."""
    proof = read_json(Path(args.proof).resolve())
    claimed_leaf = str(proof.get("leaf_hash") or "").strip()
    if merkle_verify_proof(claimed_leaf, proof):
        print("OK")
        return 0
    print("FAIL")
    return 2


def cmd_selftest(args: argparse.Namespace) -> int:
    """CLI `selftest`: exercise the Merkle helpers on known vectors.

    Returns 0 and prints "OK selftest" when every expectation holds;
    otherwise prints one FAIL line per problem and returns 2.
    """
    problems: list[str] = []

    def expect(cond: bool, message: str) -> None:
        if not cond:
            problems.append(message)

    empty_root = merkle_root_hex([])
    expect(empty_root == sha256_bytes(b""), f"merkle_root_hex(empty) mismatch: got {empty_root}")

    single = "c" * 64
    expect(merkle_root_hex([single]) == single, "merkle_root_hex(single) should equal the leaf hash")

    leaves = ["a" * 64, "b" * 64, "c" * 64]
    root = merkle_root_hex(leaves)
    for i, leaf in enumerate(leaves):
        proof = merkle_inclusion_proof(leaves, i)
        expect(str(proof.get("root") or "") == root, f"proof root mismatch for index {i}")
        expect(merkle_verify_proof(leaf, proof), f"merkle_verify_proof failed for index {i}")

    proof0 = merkle_inclusion_proof(leaves, 0)
    expect(not merkle_verify_proof("d" * 64, proof0), "merkle_verify_proof should fail for wrong leaf hash")

    # Out-of-range indexes must raise ValueError, not silently succeed.
    bad_cases = (
        (-1, "merkle_inclusion_proof should raise on index -1"),
        (len(leaves), "merkle_inclusion_proof should raise on out-of-range index"),
    )
    for bad_index, label in bad_cases:
        try:
            merkle_inclusion_proof(leaves, bad_index)
        except ValueError:
            pass
        else:
            problems.append(label)

    if problems:
        for p in problems:
            print(f"FAIL selftest: {p}")
        return 2
    print("OK selftest")
    return 0


def main() -> int:
    """Build the CLI, dispatch to the chosen subcommand, return its exit code."""
    parser = argparse.ArgumentParser(prog="iftrace")
    subcommands = parser.add_subparsers(dest="cmd", required=True)

    verify_p = subcommands.add_parser("verify", help="Verify a trace payload tarball (manifest, hashes, chains, signatures)")
    verify_p.add_argument("tar", help="Path to emo_trace_payload_<trace_id>.tar.gz")
    verify_p.add_argument("--expected-sha256", default="", help="Expected tarball SHA256 (optional)")
    verify_p.add_argument("--debug", action="store_true", help="Print tracebacks on errors")
    verify_p.set_defaults(func=cmd_verify)

    prove_p = subcommands.add_parser("prove-inclusion", help="Generate a Merkle inclusion proof for a REQ_SEEN ledger leaf")
    prove_p.add_argument("--ledger", required=True, help="Path to req_seen_<hour>.jsonl")
    prove_p.add_argument("--head", required=True, help="Path to req_seen_head_<hour>.json")
    selector = prove_p.add_mutually_exclusive_group(required=True)
    selector.add_argument("--trace-id", default="", help="Trace ID to prove inclusion for")
    selector.add_argument("--leaf-hash", default="", help="Leaf hash to prove inclusion for")
    prove_p.set_defaults(func=cmd_prove_inclusion)

    check_p = subcommands.add_parser("verify-inclusion", help="Verify a Merkle inclusion proof JSON")
    check_p.add_argument("proof", help="Path to proof JSON")
    check_p.set_defaults(func=cmd_verify_inclusion)

    selftest_p = subcommands.add_parser("selftest", help="Run internal self-tests (Merkle + proof verification)")
    selftest_p.set_defaults(func=cmd_selftest)

    parsed = parser.parse_args()
    return int(parsed.func(parsed))


if __name__ == "__main__":
    raise SystemExit(main())
