import json
import torch
import torchvision.transforms.functional as TF
from ..utils import log
from .trajectory import create_pos_feature_map, draw_tracks_on_video, replace_feature
import os
from comfy import model_management as mm
device = mm.get_torch_device()
script_directory = os.path.dirname(os.path.abspath(__file__))

VAE_STRIDE = (4, 8, 8)  # t, h, w

class WanVideoWanDrawWanMoveTracks:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "images": ("IMAGE",),
                    "tracks": ("TRACKS",),
                },
                "optional": {
                    "line_resolution": ("INT", {"default": 24, "min": 4, "max": 64, "step": 1, "tooltip": "Number of points to use for each line segment"}),
                    "circle_size": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1, "tooltip": "Size of the circle to draw for each track point"}),
                    "opacity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Opacity of the circle to draw for each track point"}),
                    "line_width": ("INT", {"default": 14, "min": 1, "max": 50, "step": 1, "tooltip": "Width of the line to draw for each track"}),
                }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "execute"
    CATEGORY = "WanVideoWrapper"

    def execute(self, images, tracks, line_resolution=24, circle_size=10, opacity=0.5, line_width=14):
        if tracks is None or "track_path" not in tracks:
            log.warning("WanVideoWanDrawWanMoveTracks: No tracks provided.")
            return (images.float().cpu(), )
        track = tracks["track_path"].unsqueeze(0)
        track_visibility = tracks["track_visibility"].unsqueeze(0)
        images_in = images * 255.0
        if images_in.shape[0] != track.shape[1]:
            repeat_count = track.shape[1] // images.shape[0]
            images_in = images_in.repeat(repeat_count, 1, 1, 1)
        track_video = draw_tracks_on_video(images_in, track, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width)
        track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1)

        return (track_video.float().cpu(), )


class WanVideoAddWanMoveTracks:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "image_embeds": ("WANVIDIMAGE_EMBEDS",),
                    "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Strength of the reference embedding"}),
                },
                "optional": {
                    "track_mask": ("MASK",),
                    "track_coords": ("STRING", {"forceInput": True, "tooltip": "JSON string or list of JSON strings representing the tracks"}),
                    "tracks": ("TRACKS", {"tooltip": "Alternatively use Comfy Tracks dictionary"}),
                }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", "TRACKS")
    RETURN_NAMES = ("image_embeds", "tracks")
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, image_embeds, track_coords=None, tracks=None, strength=1.0, track_mask=None):
        updated = dict(image_embeds)

        track_visibility = None

        target_shape = image_embeds.get("target_shape")
        if target_shape is not None:
            height = target_shape[2] * VAE_STRIDE[1]
            width = target_shape[3] * VAE_STRIDE[2]
        else:
            height = image_embeds["lat_h"] * VAE_STRIDE[1]
            width = image_embeds["lat_w"] * VAE_STRIDE[2]
        num_frames = image_embeds["num_frames"]

        if track_coords is not None:
            tracks_data = parse_json_tracks(track_coords)
            track_list = [
                [[track[frame]['x'], track[frame]['y']] for track in tracks_data]
                for frame in range(len(tracks_data[0]))
            ]
            track = torch.tensor(track_list, dtype=torch.float32, device=device)  # shape: (frames, num_tracks, 2)
        elif tracks is not None and "track_path" in tracks:
            track = tracks["track_path"]
            if track_mask is None:
                track_visibility = tracks.get("track_visibility", None)
        track = track[:num_frames]

        num_tracks = track.shape[-2]
        if track_visibility is None:
            if track_mask is None:
                track_visibility = torch.ones((num_frames, num_tracks), dtype=torch.bool, device=device)
            else:
                track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
        feature_map, track_pos = create_pos_feature_map(track, track_visibility, VAE_STRIDE, height, width, 16, track_num=num_tracks, device=device)

        updated.setdefault("wanmove_embeds", {})
        updated["wanmove_embeds"]["track_pos"] = track_pos
        updated["wanmove_embeds"]["strength"] = strength

        tracks_dict = {
            "track_path": track,
            "track_visibility": track_visibility,
        }

        return (updated, tracks_dict,)


def parse_json_tracks(tracks):
    tracks_data = []
    try:
        # If tracks is a string, try to parse it as JSON
        if isinstance(tracks, str):
            parsed = json.loads(tracks.replace("'", '"'))
            tracks_data.extend(parsed)
        else:
            # If tracks is a list of strings, parse each one
            for track_str in tracks:
                parsed = json.loads(track_str.replace("'", '"'))
                tracks_data.append(parsed)

        # Check if we have a single track (dict with x,y) or a list of tracks
        if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]:
            # Single track detected, wrap it in a list
            tracks_data = [tracks_data]
        elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]:
            # Already a list of tracks, nothing to do
            pass
        else:
            # Unexpected format
            log.warning(f"Warning: Unexpected track format: {type(tracks_data[0])}")

    except json.JSONDecodeError as e:
        log.warning(f"Error parsing tracks JSON: {e}")
        tracks_data = []

    return tracks_data

import node_helpers

class WanMove_native:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "positive": ("CONDITIONING",),
            "track_coords": ("STRING", {"forceInput": True, "tooltip": "JSON string or list of JSON strings representing the tracks"}),
            },
            "optional": {
                "track_mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("CONDITIONING", "TRACKS")
    RETURN_NAMES = ("positive", "tracks")
    FUNCTION = "patchcond"
    CATEGORY = "WanVideoWrapper"
    DEPRECATED = True

    def patchcond(self, positive, track_coords, track_mask=None):

        concat_latent_image = positive[0][1]["concat_latent_image"]
        B, C, T, H, W = concat_latent_image.shape
        num_frames = (T-1) * 4 + 1
        width = W * 8
        height = H * 8

        tracks_data = parse_json_tracks(track_coords)
        track_list = [
            [[track[frame]['x'], track[frame]['y']] for track in tracks_data]
            for frame in range(len(tracks_data[0]))
        ]
        track = torch.tensor(track_list, dtype=torch.float32, device=device)  # shape: (frames, num_tracks, 2)
        track = track[:num_frames]

        num_tracks = track.shape[-2]
        if track_mask is None:
            track_visibility = torch.ones((num_frames, num_tracks), dtype=torch.bool, device=device)
        else:
            track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)

        feature_map, track_pos = create_pos_feature_map(track, track_visibility, VAE_STRIDE, height, width, 16, track_num=num_tracks, device=device)
        wanmove_cond = replace_feature(concat_latent_image, track_pos.unsqueeze(0))
        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": wanmove_cond})

        tracks_dict = {
            "track_path": track,
            "track_visibility": track_visibility,
        }
        return (positive, tracks_dict)


NODE_CLASS_MAPPINGS = {
    "WanVideoAddWanMoveTracks": WanVideoAddWanMoveTracks,
    "WanVideoWanDrawWanMoveTracks": WanVideoWanDrawWanMoveTracks,
    "WanMove_native": WanMove_native,
    }
NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoAddWanMoveTracks": "WanVideo Add WanMove Tracks",
    "WanVideoWanDrawWanMoveTracks": "WanVideo Draw WanMove Tracks",
    "WanMove_native": "WanMove Native",
    }
