#!/usr/bin/env python3
# =========================================================
# VLADFX ANIMATION TRANSFER v3.3
# Bake-Based Version - Works with ANY rig
# Progress bar fixed + Animation Layer warnings fixed
# =========================================================
import maya.cmds as cmds
import maya.mel as mel
import json
import os
import webbrowser

# Fixed Maya UI object name for the tool window; show_ui() deletes any existing
# window with this name before building a new one, so only one copy is open.
WINDOW_NAME = "VladfxAnimTransfer"

# =========================================================
# UTILS
# =========================================================
def short_name(node):
    """Return the bare name of *node*.

    Takes the last component of a DAG path (after the final ``|``) and then
    drops any namespace (everything up to the final ``:``).

    >>> short_name("|grp|rig:ctrl")
    'ctrl'
    """
    leaf = node.rpartition("|")[2]
    return leaf.rpartition(":")[2]

def get_hierarchy(selection):
    """Return every node in *selection* plus all of its descendants, once each.

    For each root (in selection order), its descendants as reported by
    ``cmds.listRelatives(ad=True, fullPath=True)`` come first, followed by
    the root itself. Duplicates keep their first position.

    The order is deterministic. The previous ``set()``-based dedupe varied
    with Python's per-process string hashing. That in turn changed export
    order and which node won when two nodes share a short name.
    """
    ordered = []
    for root in selection:
        ordered.extend(cmds.listRelatives(root, ad=True, fullPath=True) or [])
        ordered.append(root)
    # dict preserves insertion order, so this dedupes without reordering.
    return list(dict.fromkeys(ordered))

# =========================================================
# EXPORT
# =========================================================
def _exportable_attrs(node):
    """Return the unlocked, settable keyable + channel-box attributes of *node*.

    Keyable attributes come first, then channel-box-only ones, without
    duplicates, in the order ``cmds.listAttr`` reports them.
    """
    keyable = cmds.listAttr(node, keyable=True, unlocked=True, settable=True) or []
    channelbox = cmds.listAttr(node, channelBox=True, unlocked=True, settable=True) or []
    return list(dict.fromkeys(keyable + channelbox))


def _sample_channel(full_attr, start_frame, end_frame, export_static,
                    tolerance=0.00001):
    """Bake one attribute by sampling it on every whole frame of the range.

    Returns the list of ``{"frame": frame, "value": value}`` samples when
    the value changes by more than *tolerance* within the range, or always
    when *export_static* is True. Returns None for a static channel that
    should be skipped.

    Errors from ``cmds.getAttr``, or from comparing value types the check
    cannot handle, propagate to the caller.
    """
    samples = []
    first_value = None
    has_animation = export_static
    for frame in range(int(start_frame), int(end_frame) + 1):
        value = cmds.getAttr(full_attr, time=frame)
        samples.append({"frame": frame, "value": value})
        if first_value is None:
            first_value = value
        elif not has_animation:
            # Flat sequences compare element-wise. Nested values (e.g. a
            # compound returning [(x, y, z)]) raise TypeError here, and the
            # caller skips the channel.
            if isinstance(value, (list, tuple)):
                has_animation = any(abs(a - b) > tolerance
                                    for a, b in zip(value, first_value))
            else:
                has_animation = abs(value - first_value) > tolerance
    return samples if has_animation else None


def export_animation(filepath, selection, start_frame, end_frame,
                     export_hierarchy=True, export_static=False):
    """Bake the animation of *selection* to a JSON file.

    Every exportable attribute of each node is sampled on every whole frame
    from *start_frame* to *end_frame* (inclusive). A channel is written only
    if its value changes in that range, unless *export_static* is True.

    Args:
        filepath: destination JSON path (overwritten).
        selection: node names to export.
        start_frame: first frame of the range.
        end_frame: last frame of the range (inclusive).
        export_hierarchy: also export all descendants of each selected node.
        export_static: also export channels that never change in the range.

    Output format:
        Channels are stored under ``"<short_name>.<attr>"``, with DAG path and
        namespace stripped, so nodes sharing a short name overwrite each
        other. A ``"__metadata__"`` entry records the scene time unit and the
        frame range.

    Returns:
        None. Warns and writes nothing when there is nothing to export or the
        range is inverted. Writes no file when the user cancels.
    """
    selection = list(selection or [])
    if export_hierarchy:
        nodes = get_hierarchy(selection)
        hierarchy_text = " (with hierarchy)"
    else:
        nodes = selection
        hierarchy_text = " (ONLY selected controls)"

    if not nodes:
        cmds.warning("Nothing to export - select at least one control.")
        return
    if end_frame < start_frame:
        cmds.warning(f"Invalid frame range: start frame {start_frame} is after end frame {end_frame}.")
        return

    total_frames = end_frame - start_frame + 1
    node_count = len(nodes)
    export_data = {}
    current_time_unit = cmds.currentUnit(q=True, time=True)

    print("\n" + "="*80)
    print("VLADFX ANIMATION TRANSFER v3.3 -- Exporting")
    print("="*80)
    print(f"Controls to process : {node_count}{hierarchy_text}")
    print(f"Frame range         : {start_frame} -> {end_frame} ({total_frames} frames)")
    print(f"Frame rate          : {current_time_unit}")
    print("WARNING: Large rigs (Metahuman face) may take 30-90 seconds. Please wait...")
    print("-" * 80)

    total_attrs = 0
    exported_controls = 0
    skipped_attrs = 0

    cmds.progressWindow(
        title="Vladfx Animation Transfer - Exporting...",
        progress=0,
        maxValue=node_count,
        status="Starting export...",
        isInterruptable=True
    )
    # The finally guarantees the progress window is closed on every exit path:
    # normal completion, user cancel, or an unexpected error.
    try:
        for i, node in enumerate(nodes):
            short = short_name(node)
            cmds.progressWindow(edit=True, progress=i + 1,
                                status=f"Processing: {short} ({i+1}/{node_count})")
            cmds.refresh()

            if cmds.progressWindow(query=True, isCancelled=True):
                print("Export cancelled by user.")
                return

            try:
                attrs = _exportable_attrs(node)
            except (RuntimeError, ValueError) as exc:
                # e.g. the node was deleted/renamed since the selection was taken.
                print(f"   {short} -> skipped ({exc})")
                continue

            node_exported = 0
            for attr in attrs:
                try:
                    values = _sample_channel(f"{node}.{attr}", start_frame,
                                             end_frame, export_static)
                except Exception:
                    # Best effort per channel: unreadable/unsupported attributes
                    # are skipped, but counted so they do not vanish silently.
                    skipped_attrs += 1
                    continue
                if values is None:
                    continue
                export_data[f"{short}.{attr}"] = values
                total_attrs += 1
                node_exported += 1

            if node_exported > 0:
                exported_controls += 1
                print(f"OK {short} -> {node_exported} channels exported")
            else:
                print(f"   {short} -> no animation found")
    finally:
        cmds.progressWindow(endProgress=True)

    export_data["__metadata__"] = {
        "time_unit": current_time_unit,
        "start_frame": start_frame,
        "end_frame": end_frame
    }

    with open(filepath, "w") as f:
        json.dump(export_data, f, indent=4)

    print("\n" + "="*80)
    print("VLADFX ANIMATION TRANSFER -- Export Complete")
    print("="*80)
    print(f"Controls with animation : {exported_controls}")
    print(f"Total channels exported : {total_attrs}")
    if skipped_attrs:
        print(f"Channels skipped (read errors) : {skipped_attrs}")
    print(f"Frame rate saved        : {current_time_unit}")
    print(f"File saved              : {filepath}")
    print("="*80)

    # In-view messages are rich text: close the style attribute quote and use
    # <br> (a literal newline is just whitespace in HTML).
    cmds.inViewMessage(
        amg=("<span style='color:#00FF00;'>Export Complete!</span><br>"
             f"{exported_controls} controls    {total_attrs} channels"),
        pos="topCenter", fade=True)

# =========================================================
# IMPORT (with layer fix)
# =========================================================
def _build_node_lookups():
    """Index every scene node by name so exported channels can be matched.

    Returns ``(by_leaf, by_bare)``:
        by_leaf: maps the last DAG path component *with* its namespace
            (e.g. ``"rig:ctrl"``) to the node's long name.
        by_bare: maps the namespace-stripped name (``"ctrl"``) to the node's
            long name.

    When several nodes share a key, the one listed last by ``cmds.ls`` wins.
    """
    by_leaf = {}
    by_bare = {}
    for node in cmds.ls(long=True) or []:
        by_leaf[node.split("|")[-1]] = node
        by_bare[short_name(node)] = node
    return by_leaf, by_bare


def _resolve_target_node(node_short, namespace_prefix, by_leaf, by_bare):
    """Return the scene node that an exported *node_short* should drive, or None.

    With a prefix, ``prefix + name`` is tried first. If the prefix has no
    trailing colon, ``prefix + ":" + name`` is tried next. With no prefix, or
    when no prefixed name matches, any node with the same namespace-stripped
    name is used.
    """
    if namespace_prefix:
        candidates = [namespace_prefix + node_short]
        if not namespace_prefix.endswith(":"):
            candidates.append(f"{namespace_prefix}:{node_short}")
        for candidate in candidates:
            node = by_leaf.get(candidate)
            if node is not None:
                return node
    return by_bare.get(node_short)


def _unique_layer_name(base):
    """Return *base*, or *base* plus the first numeric suffix not used in the scene."""
    name, index = base, 1
    while cmds.objExists(name):
        name = f"{base}{index}"
        index += 1
    return name


def import_animation(filepath, frame_offset=0, target_namespace_prefix="",
                     create_anim_layer=False, convert_to_trax=False,
                     dry_run=False, match_fps=False):
    """Apply a JSON file written by export_animation() to the current scene.

    Each ``"<short_name>.<attr>"`` channel is matched to a scene node (see
    _resolve_target_node). Every stored sample is then keyed at
    ``frame + frame_offset``.

    Args:
        filepath: JSON file to read.
        frame_offset: frames added to every stored frame number.
        target_namespace_prefix: text prepended to exported names to pick the
            target (e.g. ``"rig:"``). Falls back to a namespace-agnostic match.
        create_anim_layer: key into a new override animation layer
            ("ImportedAnimation", suffixed if taken) instead of replacing the
            base-layer keys.
        convert_to_trax: run MEL ``CreateClip`` on the affected nodes afterwards.
        dry_run: only report what would be imported. Nothing in the scene is
            changed, including the frame rate.
        match_fps: switch the scene time unit to the one stored in the file.
    """
    if not os.path.exists(filepath):
        cmds.warning("File not found!")
        return

    with open(filepath, "r") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        cmds.warning("Not a Vladfx animation file!")
        return

    metadata = data.pop("__metadata__", {})
    exported_time_unit = metadata.get("time_unit")

    by_leaf, by_bare = _build_node_lookups()

    print("\n" + "="*80)
    print("VLADFX ANIMATION TRANSFER v3.3 -- Importing")
    print("="*80)

    if match_fps and exported_time_unit:
        current = cmds.currentUnit(q=True, time=True)
        if current != exported_time_unit:
            if dry_run:
                # A dry run must not modify the scene; just report.
                print(f"-> Dry-run: frame rate would change from {current} to {exported_time_unit}")
            else:
                print(f"-> Changing frame rate from {current} to {exported_time_unit}")
                cmds.currentUnit(time=exported_time_unit)

    # Create the layer BEFORE keying so the baked values land in it. Keys set
    # on the base layer first would leave the new layer empty. Override mode
    # makes the stored absolute values replace (not add to) the base animation.
    layer_name = None
    if create_anim_layer and not dry_run:
        layer_name = _unique_layer_name("ImportedAnimation")
        cmds.animLayer(layer_name, override=True)

    total = len(data)
    success = 0
    imported_nodes = set()

    cmds.progressWindow(title="Vladfx Animation Transfer - Importing...", progress=0,
                        maxValue=total, status="Loading JSON...", isInterruptable=True)
    try:
        for i, (target_attr, values) in enumerate(data.items()):
            cmds.progressWindow(edit=True, progress=i + 1,
                                status=f"Importing: {target_attr} ({i+1}/{total})")
            cmds.refresh()

            if cmds.progressWindow(query=True, isCancelled=True):
                print("Import cancelled by user.")
                break

            try:
                # Maya node names cannot contain '.', so the first dot separates
                # node from attribute; the attribute path itself may contain dots.
                node_short, attr = target_attr.split(".", 1)
                node = _resolve_target_node(node_short, target_namespace_prefix,
                                            by_leaf, by_bare)
                if node is None:
                    continue

                full_attr = f"{node}.{attr}"
                if not cmds.objExists(full_attr):
                    continue

                imported_nodes.add(node)

                if dry_run:
                    success += 1
                    continue

                if layer_name:
                    cmds.animLayer(layer_name, edit=True, attribute=full_attr)
                    key_flags = {"animLayer": layer_name}
                else:
                    # Replace the existing animation with the baked samples.
                    cmds.cutKey(full_attr, clear=True)
                    key_flags = {}

                for item in values:
                    cmds.setKeyframe(full_attr, time=item["frame"] + frame_offset,
                                     value=item["value"], **key_flags)

                success += 1

            except Exception as e:
                print(f"FAILED: {target_attr} -> {e}")
    finally:
        cmds.progressWindow(endProgress=True)

    if layer_name:
        if success:
            cmds.animLayer(layer_name, edit=True, selected=True)
            cmds.animLayer(layer_name, edit=True, preferred=True)
            print(f"OK Created Animation Layer: {layer_name} (set as active)")
        else:
            # Nothing was keyed into it; do not leave an empty layer behind.
            cmds.delete(layer_name)

    if convert_to_trax and imported_nodes and not dry_run:
        cmds.select(list(imported_nodes))
        try:
            mel.eval("CreateClip;")
            print("OK Created Trax clip")
        except Exception as exc:
            print(f"WARNING: Could not create Trax clip ({exc})")

    print("\n" + "="*80)
    print("VLADFX ANIMATION TRANSFER -- Import Complete")
    print("="*80)
    if dry_run:
        print("Dry-run: no changes were made to the scene")
    print(f"Channels successfully imported : {success}")
    print(f"Controls affected             : {len(imported_nodes)}")
    print("="*80)

    # Rich text: close the style quote and use <br> for the line break.
    cmds.inViewMessage(
        amg=("<span style='color:#00FF00;'>Import Complete!</span><br>"
             f"{success} channels    {len(imported_nodes)} controls"),
        pos="topCenter", fade=True)

# =========================================================
# UI
# =========================================================
def toggle_fields(*args):
    """Enable the manual start/end frame fields only when the timeline box is off.

    Used as the checkbox's changeCommand. It reads the widget names that
    show_ui() stores in module globals, and ignores any callback arguments.
    """
    fields_enabled = not cmds.checkBox(use_timeline_cb, q=True, value=True)
    for frame_field in (start_field, end_field):
        cmds.intFieldGrp(frame_field, edit=True, enable=fields_enabled)

def show_ui():
    """Build the Vladfx Animation Transfer window and show it.

    Any existing window named WINDOW_NAME is deleted first, so the window is
    always rebuilt from scratch. The timeline checkbox and the start/end
    frame fields are stored as module globals, because toggle_fields() reads
    them by name. The other widgets are captured by the button lambdas and
    handed to do_export() / do_import().
    """
    # toggle_fields() looks these three widgets up as module-level globals.
    global use_timeline_cb, start_field, end_field

    # Rebuild instead of reusing a stale window.
    if cmds.window(WINDOW_NAME, exists=True):
        cmds.deleteUI(WINDOW_NAME)

    cmds.window(WINDOW_NAME, title="Vladfx Animation Transfer v3.3", 
                widthHeight=(440, 620), sizeable=True)

    # Root column; every section below is parented to it.
    cmds.columnLayout(adjustableColumn=True, rowSpacing=12, columnAlign="center")

    cmds.separator(height=15, style="none")
    cmds.text(label="VLADFX ANIMATION TRANSFER", font="boldLabelFont", height=40, align="center")
    cmds.text(label="Pixels Behaves Sometimes.", font="smallPlainLabelFont", height=20, align="center")
    cmds.separator(height=15)

    # EXPORT SECTION
    cmds.frameLayout(label="EXPORT SETTINGS", collapsable=True, collapse=False, 
                     marginWidth=10, marginHeight=10, width=420)
    cmds.columnLayout(adjustableColumn=True, rowSpacing=8)

    use_timeline_cb = cmds.checkBox(label="Use Entire Timeline Range", 
                                    value=True,
                                    changeCommand=toggle_fields)

    # Both fields start disabled because the timeline checkbox defaults to on;
    # they are pre-filled from the current playback range.
    start_field = cmds.intFieldGrp(label="Start Frame", 
                                   value1=int(cmds.playbackOptions(q=True, min=True)), 
                                   columnWidth=[1, 100], enable=False)

    end_field = cmds.intFieldGrp(label="End Frame", 
                                 value1=int(cmds.playbackOptions(q=True, max=True)), 
                                 columnWidth=[1, 100], enable=False)

    # NOTE(review): labelled "recommended" but defaults to off, while
    # export_animation() itself defaults to export_hierarchy=True — confirm intent.
    hierarchy_cb = cmds.checkBox(label="Export full hierarchy (recommended)", value=False)
    static_cb    = cmds.checkBox(label="Export static channels (neutral pose)", value=False)

    # Pop back out of the inner columnLayout and then the frameLayout.
    cmds.setParent("..")
    cmds.setParent("..")

    cmds.separator(height=15)

    # IMPORT SECTION
    cmds.frameLayout(label="IMPORT SETTINGS", collapsable=True, collapse=False, 
                     marginWidth=10, marginHeight=10, width=420)
    cmds.columnLayout(adjustableColumn=True, rowSpacing=8)

    offset_field = cmds.intFieldGrp(label="Frame Offset", value1=0, columnWidth=[1, 100])
    ns_field     = cmds.textFieldGrp(label="Target Namespace Prefix", text="", 
                                     placeholderText="e.g. rig: or leave empty", columnWidth=[1, 160])
    layer_cb     = cmds.checkBox(label="Create new Animation Layer", value=False)
    trax_cb      = cmds.checkBox(label="Convert to Trax clip", value=False)
    dry_cb       = cmds.checkBox(label="Dry-run (preview only)", value=False)
    match_fps_cb = cmds.checkBox(label="Match exported frame rate", value=False)

    # Pop back out of the inner columnLayout and then the frameLayout.
    cmds.setParent("..")
    cmds.setParent("..")

    cmds.separator(height=20)

    # Action Buttons
    # The lambdas take the single argument Maya passes to button commands and
    # forward the widget names so the handlers can query their values.
    cmds.rowLayout(numberOfColumns=2, columnWidth2=(215, 215), adjustableColumn=2)
    cmds.button(label="EXPORT ANIMATION", height=65, width=200, 
                command=lambda x: do_export(use_timeline_cb, start_field, end_field, hierarchy_cb, static_cb))
    cmds.button(label="IMPORT ANIMATION", height=65, width=200, 
                command=lambda x: do_import(offset_field, ns_field, layer_cb, trax_cb, dry_cb, match_fps_cb))
    # Back to the root column.
    cmds.setParent("..")

    # WEBSITE LINK
    # The command is a Python code string that opens the site with webbrowser.
    cmds.separator(height=20, style="none")
    cmds.button(label="vladfx.com", 
                height=28, 
                width=160,
                align="center",
                command='import webbrowser; webbrowser.open("https://vladfx.com")',
                backgroundColor=(0.15, 0.15, 0.15))

    cmds.separator(height=10, style="none")
    cmds.showWindow()

# =========================================================
# BUTTON COMMANDS
# =========================================================
def do_export(use_timeline_cb, start_field, end_field, hierarchy_cb, static_cb):
    """Read the export options from the UI, ask for a file, and run the export.

    The selection and frame range are validated *before* the save dialog
    opens. Previously the user could pick a file only to get an empty export
    because nothing was selected or the manual range was inverted.

    Args are the Maya widget names created by show_ui().
    """
    selection = cmds.ls(sl=True, long=True) or []
    if not selection:
        cmds.warning("Nothing selected - select the controls to export first.")
        return

    if cmds.checkBox(use_timeline_cb, q=True, value=True):
        start = int(cmds.playbackOptions(q=True, min=True))
        end   = int(cmds.playbackOptions(q=True, max=True))
    else:
        start = cmds.intFieldGrp(start_field, q=True, value1=True)
        end   = cmds.intFieldGrp(end_field, q=True, value1=True)

    if end < start:
        cmds.warning(f"Invalid frame range: start frame {start} is after end frame {end}.")
        return

    hierarchy = cmds.checkBox(hierarchy_cb, q=True, value=True)
    static    = cmds.checkBox(static_cb, q=True, value=True)

    file_result = cmds.fileDialog2(fileMode=0, caption="Export Animation", fileFilter="JSON (*.json)")
    if not file_result:
        return  # dialog cancelled

    export_animation(file_result[0], selection, start, end, hierarchy, static)

def do_import(offset_field, ns_field, layer_cb, trax_cb, dry_cb, match_fps_cb):
    """Collect the import options from the UI, ask for a JSON file, and import it.

    Args are the Maya widget names created by show_ui(). Nothing happens if
    the file dialog is cancelled.
    """
    def is_checked(box):
        return cmds.checkBox(box, q=True, value=True)

    offset = cmds.intFieldGrp(offset_field, q=True, value1=True)
    prefix = cmds.textFieldGrp(ns_field, q=True, text=True).strip()
    # Order matches import_animation's create_anim_layer, convert_to_trax,
    # dry_run, match_fps positional parameters.
    flags = [is_checked(box) for box in (layer_cb, trax_cb, dry_cb, match_fps_cb)]

    chosen = cmds.fileDialog2(fileMode=1, caption="Import Animation", fileFilter="JSON (*.json)")
    if chosen:
        import_animation(chosen[0], offset, prefix, *flags)

# Launch UI
# NOTE: this runs whenever the file is executed or imported, so loading the
# script immediately opens (or rebuilds) the tool window.
show_ui()
