"""
visualize.py — Step-by-step PNG visualization for the Scheinman minimizer.
Called by /api/visualize; returns base64-encoded images.
"""
from __future__ import annotations

import io
import os
import sys
import base64
from typing import Optional

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'python'))

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import Circle

from scheinman import (
    Implicant,
    build_trace,
    format_expression,
    InputRowRecord, MatchPairRecord, TreeStepRecord, PIChartRecord, MinimizationTrace,
)

# ── colour palette (matches draw_examples.py) ────────────────────────────────
_C_TAG   = '#1a1a8c'
_C_MATCH = '#8b0000'
_C_ABS   = '#666666'
_C_COLHD = '#f0f4f8'
_C_SEP   = '#cccccc'

_OUTPUT_SYMBOLS = ['α', 'β', 'γ']
_VAR_NAMES = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')


# ── low-level drawing primitives ─────────────────────────────────────────────

def _hline(ax, x0, x1, y, color='k', lw=1.0):
    ax.plot([x0, x1], [y, y], color=color, lw=lw, zorder=2)


def _vline(ax, x, y0, y1, color='k', lw=1.0):
    ax.plot([x, x], [y0, y1], color=color, lw=lw, zorder=2)


def _col_header(ax, cx, cy, text):
    """Draw a bold column heading centred at (cx, cy) inside a rounded, tinted box."""
    frame = {
        'boxstyle': 'round,pad=0.30',
        'facecolor': _C_COLHD,
        'edgecolor': '#aaaacc',
        'lw': 0.8,
    }
    ax.text(cx, cy, text, ha='center', va='center', fontsize=9,
            fontweight='bold', color='#222', bbox=frame, zorder=4)


def _col_text(ax, x, y_top, rows, rh=0.52, fs=9.5, tag_dx=0.42):
    """Draw a vertical column of (value_str, tag_str, absorbed, mset_str|None) rows."""
    for i, (val, tag, abs_, mset) in enumerate(rows):
        yi = y_top - i * rh
        if abs_:
            ax.text(x - 0.22, yi, '✓', ha='right', va='center',
                    fontsize=fs - 1.5, color=_C_ABS)
        ax.text(x, yi, val, ha='center', va='center',
                fontsize=fs, family='monospace')
        if tag:
            ax.text(x + tag_dx, yi, tag, ha='left', va='center',
                    fontsize=fs - 2, style='italic', color=_C_TAG)
        if mset:
            ax.text(x, yi - rh * 0.38, mset, ha='center', va='center',
                    fontsize=fs - 2.5, color='#557799', style='italic')
    return y_top - (len(rows) - 1) * rh if rows else y_top


def _branch_lines(ax, root_x, root_y_bot, xl, xc, xm, branch_y):
    """Draw the three-way connector from the root list down to the branch columns.

    A stem drops from (root_x, root_y_bot) to a bar halfway down to
    *branch_y*. The bar spans xl..xm, and a drop line descends from it to
    *branch_y* at each of xl, xc and xm. All strokes use lw=1.1.
    """
    weight = 1.1
    bar_y = (root_y_bot + branch_y) / 2
    _vline(ax, root_x, root_y_bot, bar_y, lw=weight)
    _hline(ax, xl, xm, bar_y, lw=weight)
    for drop_x in (xl, xc, xm):
        _vline(ax, drop_x, bar_y, branch_y, lw=weight)


def _fig_to_b64(fig) -> str:
    """Save *fig* as a 120-dpi PNG, close it, and return the base64 text.

    The figure is closed in a ``finally`` block, so a failing ``savefig``
    does not leave the figure open in pyplot.

    Returns:
        The PNG bytes, base64-encoded and decoded to ``str``.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', dpi=120, bbox_inches='tight')
    finally:
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode()


# ── formatting helpers ────────────────────────────────────────────────────────

def _mset_annotation(imp: Implicant, on_set: list[int]) -> str | None:
    covered = [m for m in on_set if imp.covers(m)]
    return '{' + ','.join(str(m) for m in sorted(covered)) + '}' if len(covered) > 1 else None


def _output_tag(outputs: frozenset, n_outputs: int) -> str:
    """Concatenate the symbols of *outputs* in index order (e.g. 'αγ' for {0, 2}).

    Returns '' for single-output functions and for empty output sets.
    Indices beyond ``_OUTPUT_SYMBOLS`` fall back to ``f<index>`` (0-based,
    like the ``V<depth>`` fallback for variable names). Previously they were
    silently dropped, so an item belonging only to a 4th output rendered
    with no tag, as if untagged.
    """
    if n_outputs <= 1 or not outputs:
        return ''
    return ''.join(
        _OUTPUT_SYMBOLS[i] if i < len(_OUTPUT_SYMBOLS) else f'f{i}'
        for i in sorted(outputs)
    )


def _pi_label(pi: Implicant, n: int) -> str:
    """Format one prime implicant as an ASCII product term, e.g. 'A~BC~D'.

    Thin wrapper around ``format_expression`` for a single-term list with
    ``use_unicode=False``.
    """
    single_term = [pi]
    return format_expression(single_term, n, use_unicode=False)


# ── tree rendering ────────────────────────────────────────────────────────────

def _match_pair_rows(step: TreeStepRecord) -> list[tuple[str, str, str, bool]]:
    """Build (left_label, right_label, result_label, kept) tuples for the match column.

    In multi-output mode (``step.n_outputs > 1``), each operand is suffixed
    with its output tag, e.g. ``5{αβ}``. The result then names the shared
    outputs, or reports an empty intersection as a discard.
    """
    n_outputs = step.n_outputs
    rows: list[tuple[str, str, str, bool]] = []
    for pair in step.match_pairs:
        l_str = str(pair.left_value)
        r_str = str(pair.right_value_original)
        result_str = "→  merged  (keep)"
        if n_outputs > 1:
            lt = _output_tag(pair.left_outputs, n_outputs)
            rt = _output_tag(pair.right_outputs, n_outputs)
            if lt:
                l_str += f'{{{lt}}}'
            if rt:
                r_str += f'{{{rt}}}'
            if pair.left_outputs and pair.right_outputs:
                if pair.is_kept:
                    shared_tag = _output_tag(pair.shared_outputs, n_outputs)
                    result_str = f"= {{{shared_tag}}}  →  keep"
                else:
                    result_str = "= ∅  →  discard"
        rows.append((l_str, r_str, result_str, pair.is_kept))
    return rows


def _leaf_pi_labels(step: TreeStepRecord, n: int) -> list[str]:
    """Return the distinct PI expressions produced at the last variable.

    - Unabsorbed left/right items not consumed by a *kept* match become PIs
      as-is. Right items use ``original_value``, which restores the bit that
      ``display_value`` has cleared.
    - Each kept match yields one merged PI with the current bit removed from
      its mask.
    - In multi-output mode, an operand whose outputs are only partly shared
      also survives as its own PI.

    Discarded pairs (empty output intersection) do not merge, so both of
    their operands are reported as ordinary PIs. No merged term is invented
    for them. The result is sorted by (length, text) so the image is
    reproducible.
    """
    def _expr(value, mask) -> str:
        return format_expression([Implicant(value=value, mask=mask)], n)

    kept_pairs = [p for p in step.match_pairs if p.is_kept]
    consumed = {p.left_value for p in kept_pairs}

    labels: list[str] = [
        _expr(row.original_value, row.mask)
        for row in (*step.left_rows, *step.right_rows)
        if not row.absorbed_at_depth and row.display_value not in consumed
    ]
    for pair in kept_pairs:
        lmask = next(
            (r.mask for r in step.left_rows if r.display_value == pair.left_value), 0)
        labels.append(_expr(pair.left_value_raw, lmask & ~(1 << step.bit_pos)))
        shared = pair.shared_outputs
        # Multi-output: outputs of an operand that are not shared still need
        # that operand as a leaf PI.
        if pair.left_outputs and shared and pair.left_outputs != shared:
            labels.append(_expr(pair.left_value_raw, lmask))
        if pair.right_outputs and shared and pair.right_outputs != shared:
            labels.append(_expr(pair.right_value_original, lmask))
    # Sorting by length alone would leave ties in set-iteration order, which
    # changes with string-hash randomisation between processes.
    return sorted(set(labels), key=lambda s: (len(s), s))


def _render_tree_step(step: TreeStepRecord, on_set: list[int], n: int) -> str:
    """Render one level of the Scheinman tree as a base64-encoded PNG.

    The figure is laid out top to bottom as:
    - the title and the step's input list;
    - a three-way branch into the ``var = 0`` (left), ``var = 1`` (right,
      bit cleared) and match columns;
    - footer badges linking to child steps;
    - when ``step.bit_pos == 0``, the PI expressions produced at this leaf
      level.

    Args:
        step: Tree step record to draw.
        on_set: Not used by this function; kept for caller compatibility.
        n: Number of input variables, used to format PI expressions.

    Returns:
        The PNG encoded as a base64 ASCII string.
    """
    var = _VAR_NAMES[step.depth] if step.depth < len(_VAR_NAMES) else f'V{step.depth}'
    n_outputs = step.n_outputs

    def _row_to_tuple(row: InputRowRecord) -> tuple:
        # (value, output tag, absorbed flag, minterm-set note) as _col_text expects.
        return (str(row.display_value), _output_tag(row.outputs, n_outputs),
                row.absorbed_at_depth, row.mset)

    root_rows = [_row_to_tuple(r) for r in step.input_rows]
    left_rows = [_row_to_tuple(r) for r in step.left_rows]
    right_rows = [_row_to_tuple(r) for r in step.right_rows]
    pairs = _match_pair_rows(step)

    RH = 0.52    # vertical pitch of a value row
    MRWH = 0.72  # vertical pitch of a match entry (operands line + result line)

    n_col_rows = max(len(left_rows), len(right_rows), 1)
    n_root_rows = max(len(root_rows), 1)
    n_match = max(len(pairs), 1)

    content_h = (n_root_rows * RH + 2.0 +
                 max(n_col_rows * RH, n_match * MRWH) + 1.5)
    W, H = 17.0, min(max(content_h + 2.0, 9.0), 24.0)

    fig, ax = plt.subplots(figsize=(W, H))
    ax.set_xlim(0, W)
    ax.set_ylim(0, H)
    ax.axis('off')
    fig.patch.set_facecolor('white')

    # Title
    path_label = step.path_label
    n_input = len(step.input_rows)
    if path_label:
        title_top = f'Scheinman Tree  —  {path_label}  —  variable {var}'
    else:
        title_top = f'Scheinman Tree  —  depth {step.depth}  (variable {var})'
    ax.text(W / 2, H - 0.25, title_top,
            ha='center', va='top', fontsize=11, fontweight='bold')

    # Root list
    rx = W / 2
    ry_top = H - 1.25
    ax.text(rx, ry_top + 0.45,
            f'Input list  ({n_input} term{"s" if n_input != 1 else ""})',
            ha='center', va='bottom', fontsize=9.5, fontweight='bold')
    ry_bot = _col_text(ax, rx, ry_top, root_rows, rh=RH, fs=9.5, tag_dx=0.44)

    # Three-way branch
    xl = W * 0.14
    xc = W * 0.50
    xm = W * 0.86
    br_y = ry_bot - 1.4
    _branch_lines(ax, rx, ry_bot - 0.05, xl, xc, xm, br_y + 0.30)

    _col_header(ax, xl, br_y + 0.10, f'{var} = 0  (Left)')
    _col_header(ax, xc, br_y + 0.10, f'{var} = 1  (Right, bit cleared)')
    _col_header(ax, xm, br_y + 0.10, 'Match  (−)')

    col_y = br_y - 0.38
    for col_x, rows, fs, tag_dx in ((xl, left_rows, 9.5, 0.40),
                                    (xc, right_rows, 9.0, 0.55)):
        if rows:
            _col_text(ax, col_x, col_y, rows, rh=RH, fs=fs, tag_dx=tag_dx)
        else:
            ax.text(col_x, col_y, '(none)', ha='center', va='center',
                    fontsize=9, style='italic', color='#aaa')

    if pairs:
        for k, (lv, rv, res, keep) in enumerate(pairs):
            ky = col_y - k * MRWH
            col = _C_MATCH if keep else '#999999'
            ax.text(xm, ky,
                    f'{lv}  ∩  {rv}',
                    ha='center', va='center', fontsize=8.5, color=col, style='italic')
            ax.text(xm, ky - 0.30, res,
                    ha='center', va='center', fontsize=8.5, color=col,
                    fontweight='bold' if keep else 'normal')
    else:
        ax.text(xm, col_y, '(no matches)', ha='center', va='center',
                fontsize=9, style='italic', color='#aaa')

    # Column separators. The bottom is placed below the *tallest* column
    # (value rows take RH each, match entries MRWH each), so the footer
    # badges drawn beneath never overlap a long match list.
    sep_top = br_y + 0.50
    sep_bot = col_y - max(n_col_rows * RH, len(pairs) * MRWH) - 0.4
    for sx in ((xl + xc) / 2, (xc + xm) / 2):
        ax.plot([sx, sx], [sep_bot, sep_top],
                color=_C_SEP, lw=0.7, linestyle='--', zorder=1)

    # Step footer badges — use bit_pos == 0 to detect "last variable"
    at_last_var = (step.bit_pos == 0)
    footer_y = sep_bot - 0.45
    for col_x, child_step in (
        (xl, step.child_step_left),
        (xc, step.child_step_right),
        (xm, step.child_step_matched),
    ):
        if child_step is not None:
            ax.text(col_x, footer_y, f'→ Step {child_step}',
                    ha='center', va='center', fontsize=8.5, color='#336699',
                    fontweight='bold',
                    bbox=dict(boxstyle='round,pad=0.25', facecolor='#ddeeff',
                              edgecolor='#336699', lw=0.8))
        elif at_last_var:
            ax.text(col_x, footer_y, '(last variable — items become PIs)',
                    ha='center', va='center', fontsize=7.5,
                    color='#557700', style='italic')
        else:
            ax.text(col_x, footer_y, '(absorbed — no subtree)',
                    ha='center', va='center', fontsize=7.5,
                    color='#999999', style='italic')

    # At the last variable, list the PI expressions this step produces.
    if at_last_var:
        pi_y = footer_y - 0.60
        ax.text(W / 2, pi_y + 0.22,
                'PIs produced at this step:',
                ha='center', va='center', fontsize=8.5, fontweight='bold', color='#444')
        pi_labels = _leaf_pi_labels(step, n)
        if pi_labels:
            ax.text(W / 2, pi_y - 0.10,
                    '  ,  '.join(pi_labels),
                    ha='center', va='center', fontsize=9.5,
                    color='#8b0000', fontweight='bold', family='monospace')
        else:
            ax.text(W / 2, pi_y - 0.10, '(all items absorbed — no new PIs)',
                    ha='center', va='center', fontsize=8.5,
                    color='#999', style='italic')

    return _fig_to_b64(fig)


# ── PI chart rendering ────────────────────────────────────────────────────────

def _render_pi_chart(chart: PIChartRecord, n: int, n_outputs: int) -> str:
    """Render the prime implicant chart as a base64-encoded PNG.

    The chart has one row per prime implicant and one column per minterm.
    A ``×`` marks every minterm the row's PI covers.
    - Essential PIs get a blue row and circled marks.
    - Other selected PIs get a yellow row.
    - Selected rows carry an ``E`` (essential) or ``G`` (greedy) badge at the
      right edge.
    - In multi-output mode (``n_outputs > 1``), each row label is suffixed
      with the PI's output tag.

    Args:
        chart: Chart record; ``minterms``, ``prime_implicants``,
            ``essential_pis`` and ``selected_pis`` are read.
        n: Number of input variables, used to format the row labels.
        n_outputs: Number of outputs of the minimised function.

    Returns:
        The PNG encoded as a base64 ASCII string. A small placeholder image is
        returned when there are no PIs or no minterms.
    """
    minterms = chart.minterms
    pis = chart.prime_implicants
    # Sets give O(1) membership inside the row/cell loops. Implicants were
    # already required to be hashable, because selected_pis is put in a set.
    essential = set(chart.essential_pis)
    selected_set = set(chart.selected_pis)

    n_rows = len(pis)
    n_cols = len(minterms)

    if n_rows == 0 or n_cols == 0:
        fig, ax = plt.subplots(figsize=(4, 2))
        ax.axis('off')
        ax.text(0.5, 0.5, 'No prime implicants' if n_rows == 0 else 'No minterms',
                ha='center', va='center', fontsize=12)
        return _fig_to_b64(fig)

    cell_w     = max(0.65, 7.5 / n_cols)   # n_cols >= 1 past the guard above
    cell_h     = 0.55
    label_w    = 3.0
    top_margin = 1.1
    bot_margin = 0.8
    right_pad  = 0.5

    fig_w = max(label_w + n_cols * cell_w + right_pad, 5.0)
    fig_h = max(top_margin + n_rows * cell_h + bot_margin, 4.0)

    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    ax.set_xlim(0, fig_w)
    ax.set_ylim(0, fig_h)
    ax.axis('off')
    fig.patch.set_facecolor('white')

    ax.text(fig_w / 2, fig_h - 0.18,
            'Prime Implicant Chart',
            ha='center', va='top', fontsize=11, fontweight='bold')

    grid_top = fig_h - top_margin
    grid_bot = grid_top - n_rows * cell_h
    # x centre of every minterm column, shared by the header and the × marks.
    col_x = [label_w + (j + 0.5) * cell_w for j in range(n_cols)]

    for cx, m in zip(col_x, minterms):
        ax.text(cx, grid_top + 0.15, str(m), ha='center', va='center',
                fontsize=8.0, fontweight='bold')

    # Circle radius: small fraction of cell_h so it never dwarfs the cell
    circle_r = min(cell_w * 0.30, cell_h * 0.40)

    for i, pi in enumerate(pis):
        row_y_ctr = grid_top - (i + 0.5) * cell_h
        is_essential = pi in essential
        is_selected = pi in selected_set

        if is_essential:
            bg = '#e8f0ff'
        elif is_selected:
            bg = '#fff8e0'
        else:
            bg = '#ffffff'
        ax.add_patch(plt.Rectangle(
            (0, row_y_ctr - cell_h / 2), fig_w, cell_h,
            facecolor=bg, edgecolor='none', zorder=0,
        ))

        label = _pi_label(pi, n)
        if n_outputs > 1:
            tag = _output_tag(pi.outputs, n_outputs)
            if tag:
                label += f' {{{tag}}}'
        ax.text(label_w - 0.12, row_y_ctr, label,
                ha='right', va='center', fontsize=8.0, family='monospace')

        for cx, m in zip(col_x, minterms):
            if not pi.covers(m):
                continue
            ax.text(cx, row_y_ctr, '×', ha='center', va='center',
                    fontsize=9, fontweight='bold', color='#333')
            if is_essential:
                ax.add_patch(Circle(
                    (cx, row_y_ctr), circle_r,
                    fill=False, edgecolor='#1a1a8c', lw=1.3, zorder=3,
                ))

        if is_selected:
            ax.text(fig_w - 0.22, row_y_ctr, 'E' if is_essential else 'G',
                    ha='center', va='center', fontsize=7.5,
                    style='italic', color='#555',
                    fontweight='bold' if is_essential else 'normal')

    # Grid
    for i in range(n_rows + 1):
        y = grid_top - i * cell_h
        ax.plot([0, fig_w], [y, y], color='#cccccc', lw=0.5, zorder=1)
    ax.plot([0, fig_w], [grid_top, grid_top], color='#888', lw=0.9, zorder=2)
    ax.plot([0, fig_w], [grid_bot, grid_bot], color='#888', lw=0.9, zorder=2)
    ax.plot([label_w, label_w], [grid_bot, grid_top], color='#888', lw=0.8, zorder=2)
    for j in range(n_cols + 1):
        x = label_w + j * cell_w
        ax.plot([x, x], [grid_bot, grid_top], color='#cccccc', lw=0.5, zorder=1)

    ax.text(fig_w / 2, 0.22,
            'E = essential (circled, blue row)   '
            'G = greedy-selected (yellow row)   '
            '× = covers minterm',
            ha='center', va='center', fontsize=7.0,
            style='italic', color='#666')

    return _fig_to_b64(fig)


# ── tree overview rendering ───────────────────────────────────────────────────

def _render_tree_overview(trace: MinimizationTrace) -> str:
    """Render a flowchart-style overview of the whole tree as a base64 PNG.

    Each traced step becomes a box. Its column is its depth, and its row is
    its order of appearance at that depth. Boxes are linked to their child
    steps by M (matched, solid red), L (left / 0, dashed blue) and
    R (right / 1, dotted blue) arrows.

    Returns:
        The PNG encoded as a base64 ASCII string. When the trace has no tree
        steps, a blank figure is returned.
    """
    from collections import defaultdict

    steps = trace.tree_steps
    if not steps:
        fig, ax = plt.subplots(figsize=(4, 2))
        ax.axis('off')
        return _fig_to_b64(fig)

    depth_to_steps: dict = defaultdict(list)
    for s in steps:
        depth_to_steps[s.depth].append(s)

    max_depth = max(s.depth for s in steps)
    max_per_depth = max(len(v) for v in depth_to_steps.values())

    cell_w, cell_h = 2.8, 1.2
    pad_x, pad_y = 0.9, 0.7
    content_w = (max_depth + 1) * cell_w + 2 * pad_x
    content_h = max_per_depth * cell_h + 2 * pad_y + 0.8
    # Clamp once and use the same size for figsize AND data limits. Clamping
    # only the figsize would stretch boxes and arrows on small trees.
    fig_w = max(content_w, 6)
    fig_h = max(content_h, 4)
    x_off = (fig_w - content_w) / 2  # centre the columns in any extra width

    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    ax.set_xlim(0, fig_w)
    ax.set_ylim(0, fig_h)
    ax.axis('off')
    fig.patch.set_facecolor('white')

    ax.text(fig_w / 2, fig_h - 0.2,
            f'Tree Overview  —  n={trace.n} variable{"s" if trace.n != 1 else ""}',
            ha='center', va='top', fontsize=11, fontweight='bold')

    bw, bh = cell_w * 0.82, cell_h * 0.66  # step box width / height

    # Assign (x, y) to each step box
    step_pos: dict = {}
    for depth in sorted(depth_to_steps):
        x = x_off + pad_x + depth * cell_w + cell_w / 2
        for rank, s in enumerate(depth_to_steps[depth]):
            y = fig_h - 0.7 - pad_y - rank * cell_h
            step_pos[s.step_index] = (x, y)

    def _to_border(dx: float, dy: float) -> tuple[float, float]:
        """Offset from a box centre to its border along direction (dx, dy)."""
        scales = []
        if dx:
            scales.append((bw / 2) / abs(dx))
        if dy:
            scales.append((bh / 2) / abs(dy))
        k = min(scales, default=0.0)
        return dx * k, dy * k

    # Draw edges behind boxes
    edge_cfg = {
        'M': ('#8b0000', '-'),
        'L': ('#336699', '--'),
        'R': ('#336699', ':'),
    }
    for s in steps:
        x0, y0 = step_pos[s.step_index]
        for label, child_idx in (
            ('M', s.child_step_matched),
            ('L', s.child_step_left),
            ('R', s.child_step_right),
        ):
            if child_idx is None:
                continue
            pos = step_pos.get(child_idx)
            if pos is None:
                continue
            x1, y1 = pos
            col, ls = edge_cfg[label]
            # Run the arrow border-to-border. The opaque step boxes are drawn
            # over the edges afterwards, so an arrow aimed at the child's
            # centre would end (arrowhead included) hidden inside the box.
            ox, oy = _to_border(x1 - x0, y1 - y0)
            ax.annotate('', xy=(x1 - ox, y1 - oy), xytext=(x0 + ox, y0 + oy),
                        arrowprops=dict(arrowstyle='->', color=col, lw=1.2,
                                        linestyle=ls, connectionstyle='arc3,rad=0.1'))
            mx, my = (x0 + x1) / 2, (y0 + y1) / 2
            ax.text(mx, my + 0.13, label, ha='center', va='bottom',
                    fontsize=7.5, color=col, fontweight='bold')

    # Draw step boxes on top
    for s in steps:
        x, y = step_pos[s.step_index]
        var = _VAR_NAMES[s.depth] if s.depth < len(_VAR_NAMES) else f'V{s.depth}'
        n_items = len(s.input_rows)
        label = f'Step {s.step_index}\nvar {var} — {n_items} item{"s" if n_items != 1 else ""}'
        rect = plt.Rectangle((x - bw / 2, y - bh / 2), bw, bh,
                             facecolor='#e8f0ff', edgecolor='#336699',
                             lw=1.0, zorder=3)
        ax.add_patch(rect)
        ax.text(x, y, label, ha='center', va='center', fontsize=7.5,
                multialignment='center', zorder=4)

    ax.text(fig_w / 2, 0.20,
            'M = matched branch (solid red)   '
            'L = left / 0 branch (dashed blue)   '
            'R = right / 1 branch (dotted blue)',
            ha='center', va='bottom', fontsize=7, style='italic', color='#666')

    return _fig_to_b64(fig)


# ── public entry point ────────────────────────────────────────────────────────

def build_visualize_payload(
    minterms: list[int],
    dont_cares: list[int],
    n: int,
    overlapping: bool = True,
    functions: list[list[int]] | None = None,
    max_tree_steps: int | None = 16,
) -> dict:
    """
    Run the Scheinman pipeline, capture intermediate data, render step images.
    Returns {"steps": [{"title": str, "image": base64_png}, ...]}.

    The list contains, in order:
    - the tree overview ("Step 0");
    - one image per tree step, rendering at most *max_tree_steps* of them
      (None renders all);
    - the prime implicant chart.

    An empty list is returned when the trace has no tree steps.
    """
    trace = build_trace(minterms, dont_cares, n, overlapping, functions)
    if not trace.tree_steps:
        return {'steps': []}

    # Tree overview first — gives the holistic view matching the source material diagrams
    steps = [{
        'title': 'Step 0 — Tree Overview',
        'image': _render_tree_overview(trace),
    }]

    for step in trace.tree_steps[:max_tree_steps]:
        var = _VAR_NAMES[step.depth] if step.depth < len(_VAR_NAMES) else f'V{step.depth}'
        path = step.path_label
        if path:
            title = f'Tree — {path} — variable {var}'
        else:
            title = f'Tree — depth {step.depth} — variable {var}'
        steps.append({
            'title': f'Step {step.step_index} — {title}',
            'image': _render_tree_step(step, trace.on_set, trace.n),
        })

    # Number the chart after every *traced* step, not just the rendered ones.
    # The overview and the footer badges still show step numbers beyond the
    # truncation point, so len(steps) alone would collide with one of them.
    last_index = max(s.step_index for s in trace.tree_steps)
    chart_no = max(len(steps), last_index + 1)
    steps.append({
        'title': f'Step {chart_no} — Prime Implicant Chart',
        'image': _render_pi_chart(trace.pi_chart, trace.n, trace.n_outputs),
    })

    return {'steps': steps}
