#!/usr/bin/env python3
"""geodesic-dome-mcp — MCP tool: geodesic dome wireframes + bill of materials.

Class I icosahedron subdivision → structural analysis → Pillow render.
Reports: strut lengths/counts, connector types, triangle areas, material BOM.
Pure Python + numpy + Pillow. No Blender, no GPU, no X11.
"""

import json, math, sys
from collections import defaultdict, Counter
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw

# Directory that rendered PNG files are written into.
# NOTE: created as an import-time side effect; mkdir is called without
# parents=True, so the parent directory (/tmp) must already exist.
OUT_DIR = Path("/tmp/geodesic-dome")
OUT_DIR.mkdir(exist_ok=True)

# Golden ratio; the icosahedron vertex coordinates below are built from it.
PHI = (1 + math.sqrt(5)) / 2

# The 12 icosahedron vertices: (±1, ±PHI, 0) and its cyclic permutations.
ICO_V = np.array([
    [-1, PHI, 0], [1, PHI, 0], [-1, -PHI, 0], [1, -PHI, 0],
    [0, -1, PHI], [0, 1, PHI], [0, -1, -PHI], [0, 1, -PHI],
    [PHI, 0, -1], [PHI, 0, 1], [-PHI, 0, -1], [-PHI, 0, 1],
], dtype=np.float64)
# Scale every row to unit length so all vertices lie on the unit sphere.
ICO_V = ICO_V / np.linalg.norm(ICO_V, axis=1, keepdims=True)

# The 20 triangular faces, each a triple of row indices into ICO_V.
ICO_F = [
    [0, 11, 5], [0, 5, 1], [0, 1, 7], [0, 7, 10], [0, 10, 11],
    [1, 5, 9], [5, 11, 4], [11, 10, 2], [10, 7, 6], [7, 1, 8],
    [3, 9, 4], [3, 4, 2], [3, 2, 6], [3, 6, 8], [3, 8, 9],
    [4, 9, 5], [2, 4, 11], [6, 2, 10], [8, 6, 7], [9, 8, 1],
]

# Cladding material specs, keyed by material name.
#   kg_per_m2   - mass per square metre of covered surface
#   cost_per_m2 - cost per square metre (reported as USD in the BOM)
#   transparent - flag copied unchanged into the BOM entry
MATERIALS = {
    "glass":         {"kg_per_m2": 2.5,  "cost_per_m2": 25,  "transparent": True},
    "polycarbonate": {"kg_per_m2": 1.2,  "cost_per_m2": 35,  "transparent": True},
    "insulated":     {"kg_per_m2": 5.0,  "cost_per_m2": 50,  "transparent": False},
}


def geodesic_dome(frequency, half=False):
    """Build a Class I geodesic sphere (or dome) on the unit sphere.

    Every icosahedron face is subdivided into ``frequency`` segments per
    edge and each grid point is pushed out to unit length. Points shared by
    neighbouring faces are merged so each vertex appears once.

    Args:
        frequency: Subdivision frequency, an integer >= 1.
        half: When true, keep only triangles that have at least one vertex
            with z > -0.05, and drop vertices no kept triangle references
            (indices in the returned edges/triangles are renumbered to match).

    Returns:
        ``(verts, edges, triangles)`` where ``verts`` is an (N, 3) float
        array, ``edges`` is a list of ``(i, j)`` index pairs with ``i < j``
        and ``triangles`` is a list of 3-tuples of vertex indices.

    Raises:
        ValueError: If ``frequency`` is smaller than 1.
    """
    if frequency < 1:
        # 0 would divide by zero (NaN vertices); negatives yield an empty mesh.
        raise ValueError(f"frequency must be >= 1, got {frequency!r}")

    vlist, vmap = [], {}

    def add(pt):
        # Round the key so points shared between faces merge despite float noise.
        key = tuple(round(x, 10) for x in pt)
        if key not in vmap:
            vmap[key] = len(vlist)
            vlist.append(pt)
        return vmap[key]

    def tri_edges(v0, v1, v2):
        # The three sides of a triangle as order-independent index pairs.
        for p, q in ((v0, v1), (v1, v2), (v2, v0)):
            yield (p, q) if p < q else (q, p)

    edges = set()
    triangles = []

    def add_triangle(v0, v1, v2):
        triangles.append((v0, v1, v2))
        edges.update(tri_edges(v0, v1, v2))

    for face in ICO_F:
        a, b, c = (ICO_V[idx] for idx in face)

        # Barycentric grid over the face, projected onto the unit sphere.
        grid = {}
        for i in range(frequency + 1):
            for j in range(frequency + 1 - i):
                k = frequency - i - j
                pt = (i * a + j * b + k * c) / frequency
                grid[(i, j)] = add(pt / np.linalg.norm(pt))

        for i in range(frequency):
            for j in range(frequency - i):
                add_triangle(grid[(i, j)], grid[(i + 1, j)], grid[(i, j + 1)])
                # The mirrored triangle exists everywhere except on the last row.
                if i + j + 2 <= frequency:
                    add_triangle(grid[(i + 1, j)], grid[(i + 1, j + 1)],
                                 grid[(i, j + 1)])

    verts = np.array(vlist)

    if half:
        kept = [tri for tri in triangles
                if any(verts[v][2] > -0.05 for v in tri)]

        # Compact the vertex array so counts derived from it describe the
        # dome rather than the full sphere; original relative order is kept.
        used = sorted({v for tri in kept for v in tri})
        remap = {old: new for new, old in enumerate(used)}
        verts = verts[used]
        triangles = [tuple(remap[v] for v in tri) for tri in kept]

        edges = set()
        for tri in triangles:
            edges.update(tri_edges(*tri))

    return verts, list(edges), triangles


def analyze_structure(verts, edges, triangles, radius_meters=5.0):
    """Summarise struts, connectors, panel areas and a material BOM.

    Args:
        verts: (N, 3) unit-sphere vertex coordinates.
        edges: Iterable of ``(i, j)`` vertex index pairs.
        triangles: Iterable of 3-tuples of vertex indices.
        radius_meters: Scale factor applied to the unit-sphere geometry.

    Returns:
        A JSON-serialisable dict. Notable fields:

        * ``struts`` - one row per strut length (grouped to the nearest
          centimetre); ``total_m`` is nominal length x count, since struts
          are cut to the nominal length.
        * ``connectors`` - number of vertices per edge count ("5-way", ...).
        * ``triangles_by_area`` - the first ten area groups (smallest areas
          first). Triangles are grouped to the nearest 0.1 m2; ``area_m2``
          is the mean area of the group and ``total_m2`` its exact sum.
        * ``total_surface_m2`` / ``bill_of_materials`` - computed from the
          exact (un-rounded) triangle areas so rounding error does not
          accumulate across panels.
        * ``frequency`` - always ``None`` here; the caller fills it in.
    """
    verts = np.asarray(verts, dtype=np.float64)
    n_verts = len(verts)
    n_edges = len(edges)
    n_tris = len(triangles)

    # ── Strut lengths (chord between the two end vertices) ──
    chord_lengths = [
        float(np.linalg.norm(verts[a] - verts[b]) * radius_meters)
        for a, b in edges
    ]

    # Group to the nearest centimetre.
    strut_groups = Counter(round(length, 2) for length in chord_lengths)
    strut_table = [
        {
            "length_m": length,
            "count": count,
            "total_m": round(length * count, 2),
        }
        for length, count in sorted(strut_groups.items())
    ]

    unique_strut_sizes = len(strut_table)
    total_strut_m = round(sum(row["total_m"] for row in strut_table), 2)

    # ── Connector types: how many struts meet at each vertex ──
    valence = Counter(v for edge in edges for v in edge)
    connectors = Counter(valence.values())
    connector_report = {
        f"{deg}-way": connectors[deg] for deg in sorted(connectors)
    }

    # ── Triangle areas: 0.5 * |(b - a) x (c - a)| on the scaled vertices ──
    tri_areas = []
    for v0, v1, v2 in triangles:
        a = verts[v0] * radius_meters
        b = verts[v1] * radius_meters
        c = verts[v2] * radius_meters
        tri_areas.append(float(0.5 * np.linalg.norm(np.cross(b - a, c - a))))

    # Group to the nearest 0.1 m2, but keep the exact areas so the totals
    # are not distorted by the grouping (small panels would otherwise
    # collapse to 0.0 m2 and vanish from the BOM).
    area_groups = defaultdict(list)
    for area in tri_areas:
        area_groups[round(area, 1)].append(area)

    tri_table = []
    for bucket in sorted(area_groups):
        members = area_groups[bucket]
        bucket_total = math.fsum(members)
        tri_table.append({
            "area_m2": round(bucket_total / len(members), 2),
            "count": len(members),
            "total_m2": round(bucket_total, 2),
        })

    total_area_m2 = round(math.fsum(tri_areas), 2)

    # ── Bill of materials: the same covered area priced per material ──
    bom = {}
    for mat_name, spec in MATERIALS.items():
        bom[mat_name] = {
            "total_m2": total_area_m2,
            "total_kg": round(total_area_m2 * spec["kg_per_m2"], 1),
            "total_cost_usd": round(total_area_m2 * spec["cost_per_m2"]),
            "transparent": spec["transparent"],
            "per_m2_kg": spec["kg_per_m2"],
            "per_m2_cost_usd": spec["cost_per_m2"],
        }

    return {
        "dome_radius_m": radius_meters,
        "frequency": None,  # set by caller
        "vertices": n_verts,
        "edges": n_edges,
        "triangles": n_tris,
        "unique_strut_sizes": unique_strut_sizes,
        "struts": strut_table,
        "total_strut_meters": total_strut_m,
        "connectors": connector_report,
        "total_connectors": sum(connectors.values()),
        "triangle_types": len(tri_table),
        # Only the first ten groups (smallest areas) to keep the payload small.
        "triangles_by_area": tri_table[:10],
        "total_surface_m2": total_area_m2,
        "bill_of_materials": bom,
    }


def analyze_strut_constraint(struts, max_length=2.0, bar_length=6.0):
    """Check struts against a maximum cut length and estimate stock-bar usage.

    The estimate assumes every strut occupies one ``max_length`` slot of a
    stock bar (defaults: 6 m bars cut into 3 x 2 m slots).

    Args:
        struts: Iterable of dicts with ``"length_m"`` and ``"count"`` keys.
        max_length: Longest strut allowed, in metres.
        bar_length: Length of one stock bar, in metres.

    Returns:
        Dict with the over-limit strut groups, the number of compliant
        groups (``struts_ok``), bars needed, leftover metres/slots, and a
        ``compliant`` flag that is true when no group exceeds ``max_length``.

    Raises:
        ValueError: If either length is not positive, or ``max_length`` is
            larger than ``bar_length`` (no slot fits in a bar).
    """
    if max_length <= 0 or bar_length <= 0:
        raise ValueError("max_length and bar_length must be positive")

    # Tolerance keeps float noise (e.g. 0.6 / 0.2 == 2.999...) from
    # silently losing a slot.
    pieces_per_bar = math.floor(bar_length / max_length + 1e-9)
    if pieces_per_bar < 1:
        raise ValueError(
            f"max_length ({max_length}) must not exceed bar_length ({bar_length})"
        )

    struts = list(struts)
    over = [
        {
            "length_m": s["length_m"],
            "count": s["count"],
            "excess_m": round(s["length_m"] - max_length, 2),
        }
        for s in struts
        if s["length_m"] > max_length
    ]
    ok_groups = len(struts) - len(over)

    # Bars needed and what is left over after cutting.
    total_pieces = sum(s["count"] for s in struts)
    total_bars = math.ceil(total_pieces / pieces_per_bar)
    total_cut_m = total_bars * bar_length
    total_strut_m_needed = sum(s["length_m"] * s["count"] for s in struts)
    waste_pieces = total_bars * pieces_per_bar - total_pieces

    return {
        "max_strut_length_m": max_length,
        "bar_length_m": bar_length,
        "pieces_per_bar": pieces_per_bar,
        "struts_over_limit": over,
        "struts_ok": ok_groups,
        "total_pieces_needed": total_pieces,
        "rebar_bars_needed": total_bars,
        "total_bar_meters": round(total_cut_m, 1),
        "waste_meters": round(total_cut_m - total_strut_m_needed, 2),
        "waste_pieces": waste_pieces,
        "waste_use": "corner hardening gussets" if waste_pieces > 0 else "none",
        "compliant": len(over) == 0,
    }


def construction_layers(verts, triangles, radius):
    """Bucket triangles into welding layers, highest first.

    Each triangle is ranked by the Z of its centroid (scaled by ``radius``).
    Walking from the top down, a new layer starts whenever a centroid lies
    more than 0.5 m away from the first (highest) centroid of the current
    layer. The intended build order is: weld the apex layer, lift it, weld
    the next layer underneath, and so on.

    Returns a dict with the layer count, the method description, the
    threshold used, and one record per layer (1-based number, height of its
    highest centroid, triangle count, and a sequence hint).
    """
    layer_threshold = 0.5  # meters

    # (centroid height, triangle index) pairs, highest first.
    ranked = sorted(
        (
            ((verts[i][2] + verts[j][2] + verts[k][2]) / 3 * radius, n)
            for n, (i, j, k) in enumerate(triangles)
        ),
        key=lambda pair: -pair[0],
    )

    layers = []

    def close_layer(top_z, members):
        # The very first record is the apex; every later one follows a lift.
        hint = "start at apex" if not layers else "weld → lift → next layer"
        layers.append({
            "layer": len(layers) + 1,
            "height_m": round(top_z, 2),
            "triangles": len(members),
            "sequence": hint,
        })

    members = []
    top_z = None
    for z, n in ranked:
        if top_z is None:
            top_z = z
        elif abs(z - top_z) > layer_threshold:
            close_layer(top_z, members)
            members, top_z = [], z
        members.append(n)

    # Flush whatever is still open once the walk ends.
    if members:
        close_layer(top_z, members)

    return {
        "total_layers": len(layers),
        "construction_method": "top-to-bottom welding, lift after each layer",
        "layer_threshold_m": layer_threshold,
        "layers": layers,
    }


def render(verts, edges, size=2048, yaw=0.3, pitch=0.4,
           line_color="#00b4d8", bg_color="#0d1117", lw=1):
    """Draw the wireframe as a square perspective image.

    Vertices are rotated by ``yaw`` about the Y axis, then by ``pitch``
    about the X axis, and projected with a simple perspective divide from a
    camera 4 units away on the +Z axis. Edges whose two ends both lie
    behind z = -0.2 are skipped; the rest are dimmed the further back they
    are.

    Args:
        verts: (N, 3) vertex coordinates.
        edges: Iterable of ``(i, j)`` vertex index pairs.
        size: Width and height of the image in pixels.
        yaw, pitch: Rotation angles in radians.
        line_color: Any colour string Pillow understands (``"#rrggbb"``,
            ``"#rgb"``, names, ...); an alpha component is ignored.
        bg_color: Background colour, passed straight to ``Image.new``.
        lw: Line width in pixels.

    Returns:
        The rendered ``PIL.Image.Image`` in RGB mode.
    """
    # Local import: the module only imports Image and ImageDraw at the top.
    from PIL import ImageColor

    sin_yaw, cos_yaw = math.sin(yaw), math.cos(yaw)
    sin_pitch, cos_pitch = math.sin(pitch), math.cos(pitch)

    # Rotate all vertices at once (elementwise, same arithmetic per vertex).
    pts = np.asarray(verts, dtype=np.float64).reshape(-1, 3)
    x_rot = pts[:, 0] * cos_yaw + pts[:, 2] * sin_yaw
    z_yaw = pts[:, 2] * cos_yaw - pts[:, 0] * sin_yaw
    y_rot = pts[:, 1] * cos_pitch - z_yaw * sin_pitch
    z_rot = pts[:, 1] * sin_pitch + z_yaw * cos_pitch

    distance = 4.0
    scale = size * 0.42
    center = size / 2

    # Parse the base colour once instead of once per edge.
    base_rgb = ImageColor.getrgb(line_color)[:3]

    img = Image.new("RGB", (size, size), bg_color)
    draw = ImageDraw.Draw(img)

    for i, j in edges:
        zi, zj = z_rot[i], z_rot[j]
        # Skip edges entirely on the far side of the shape.
        if zi < -0.2 and zj < -0.2:
            continue
        di, dj = distance - zi, distance - zj
        # Skip anything at or behind the camera plane.
        if di < 0.1 or dj < 0.1:
            continue

        start = (float(center + x_rot[i] * scale / di),
                 float(center - y_rot[i] * scale / di))
        end = (float(center + x_rot[j] * scale / dj),
               float(center - y_rot[j] * scale / dj))

        # Brightness from 0.2 (far) to 1.0 (near), based on mean edge depth.
        depth = 0.2 + 0.8 * max(0, min(1, ((zi + zj) / 2 + 1.0) / 2.0))
        shade = tuple(int(channel * depth) for channel in base_rgb)
        draw.line([start, end], fill=shade, width=lw)

    return img


def _as_bool(value):
    """Interpret a JSON boolean/number or a 'true'/'1'/'yes' string as a bool."""
    if isinstance(value, str):
        return value.strip().lower() in ("true", "1", "yes")
    return bool(value)


def generate(params):
    """Build a dome from tool arguments, analyse it and optionally render it.

    Recognised keys in ``params`` (all optional): ``frequency`` (default 3),
    ``half``, ``radius_m`` (5.0), ``size`` (2048), ``yaw`` (0.3),
    ``pitch`` (0.4), ``line_color``, ``bg_color``, ``line_width`` (1) and
    ``analysis_only``. Flags may be real booleans or strings such as
    ``"true"`` / ``"1"`` / ``"yes"``.

    Returns:
        ``{"success": True, "analysis": {...}}`` plus, unless
        ``analysis_only`` is set, ``image_path``, ``size``, ``yaw`` and
        ``pitch`` for the PNG written under ``OUT_DIR``.

    Raises:
        ValueError: If ``frequency`` is below 1 or ``radius_m`` is not
            positive (conversion errors from ``int``/``float`` propagate too).
    """
    freq = int(params.get("frequency", 3))
    half = _as_bool(params.get("half", False))
    radius = float(params.get("radius_m", 5.0))
    size = int(params.get("size", 2048))
    yaw = float(params.get("yaw", 0.3))
    pitch = float(params.get("pitch", 0.4))
    lc = params.get("line_color", "#00b4d8")
    bg = params.get("bg_color", "#0d1117")
    lw = int(params.get("line_width", 1))
    analysis_only = _as_bool(params.get("analysis_only", False))

    # Reject inputs that would otherwise produce an empty or meaningless report.
    if freq < 1:
        raise ValueError(f"frequency must be >= 1, got {freq}")
    if radius <= 0:
        raise ValueError(f"radius_m must be positive, got {radius}")

    verts, edges, triangles = geodesic_dome(freq, half=half)
    analysis = analyze_structure(verts, edges, triangles, radius_meters=radius)
    analysis["frequency"] = freq
    analysis["half_sphere"] = half

    # ── Strut constraint: 2m max (FI12 rebar, 6m bars → 3× 2m cuts) ──
    analysis["strut_constraint"] = analyze_strut_constraint(
        analysis["struts"], max_length=2.0, bar_length=6.0
    )

    # ── Construction layers: order triangles top-to-bottom for welding ──
    analysis["construction_layers"] = construction_layers(verts, triangles, radius)

    result = {"success": True, "analysis": analysis}

    if not analysis_only:
        img = render(verts, edges, size=size, yaw=yaw, pitch=pitch,
                     line_color=lc, bg_color=bg, lw=lw)
        shape = "half" if half else "full"
        stem = f"geodesic-f{freq}-{shape}-y{yaw:.2f}-p{pitch:.2f}"
        # Make sure the target directory is still there at save time.
        OUT_DIR.mkdir(parents=True, exist_ok=True)
        path = OUT_DIR / f"{stem}.png"
        img.save(path)
        result["image_path"] = str(path)
        result["size"] = f"{size}x{size}"
        result["yaw"] = yaw
        result["pitch"] = pitch

    return result


def main():
    """Read one JSON-RPC request from stdin and print one response to stdout.

    The tool arguments are taken from ``params.arguments``. Empty input,
    unparseable JSON and non-object requests are reported and the process
    exits with status 1. Any failure while handling a well-formed request
    is returned as a JSON-RPC error object with code -1.
    """
    raw = sys.stdin.read().strip()
    if not raw:
        print(json.dumps({"error": "empty input"}))
        sys.exit(1)

    try:
        req = json.loads(raw)
    except json.JSONDecodeError as exc:
        # -32700 is the JSON-RPC 2.0 "Parse error" code; the id is unknown here.
        print(json.dumps({
            "jsonrpc": "2.0",
            "id": None,
            "error": {"code": -32700, "message": f"Parse error: {exc}"},
        }))
        sys.exit(1)

    if not isinstance(req, dict):
        # -32600 is the JSON-RPC 2.0 "Invalid Request" code.
        print(json.dumps({
            "jsonrpc": "2.0",
            "id": None,
            "error": {"code": -32600, "message": "Invalid Request: expected a JSON object"},
        }))
        sys.exit(1)

    rid = req.get("id", 1)
    try:
        # "params" / "arguments" may be absent or null; treat both as empty.
        params = req.get("params") or {}
        args = params.get("arguments") or {}
        result = generate(args)
        resp = {
            "jsonrpc": "2.0",
            "id": rid,
            "result": {"content": [{"type": "text", "text": json.dumps(result)}]},
        }
    except Exception as e:  # top-level boundary: report instead of crashing
        resp = {"jsonrpc": "2.0", "id": rid, "error": {"code": -1, "message": str(e)}}
    print(json.dumps(resp))

# Script entry point: handle a single request when run directly.
if __name__ == "__main__":
    main()
