import folder_paths
import math
import torch
import torch.nn.functional as F
import numpy as np

def get_sample_indices(original_fps,
                       total_frames,
                       target_fps,
                       num_sample,
                       fixed_start=None):
    required_duration = num_sample / target_fps
    required_origin_frames = int(np.ceil(required_duration * original_fps))
    if required_duration > total_frames / original_fps:
        raise ValueError("required_duration must be less than video length")

    if not fixed_start is None and fixed_start >= 0:
        start_frame = fixed_start
    else:
        max_start = total_frames - required_origin_frames
        if max_start < 0:
            raise ValueError("video length is too short")
        start_frame = np.random.randint(0, max_start + 1)
    start_time = start_frame / original_fps

    end_time = start_time + required_duration
    time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)

    frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
    frame_indices = np.clip(frame_indices, 0, total_frames - 1)
    return frame_indices

def linear_interpolation(features, input_fps, output_fps, output_len=None):
    """
    features: shape=[1, T, 512]
    input_fps: fps for audio, f_a
    output_fps: fps for video, f_m
    output_len: video length
    """
    features = features.transpose(1, 2)  # [1, 512, T]
    seq_len = features.shape[2] / float(input_fps)  # T/f_a
    if output_len is None:
        output_len = int(seq_len * output_fps)  # f_m*T/f_a
    output_features = F.interpolate(
        features, size=output_len, align_corners=True,
        mode='linear')  # [1, 512, output_len]
    return output_features.transpose(1, 2)  # [1, output_len, 512]

class WanVideoAddS2VEmbeds:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "embeds": ("WANVIDIMAGE_EMBEDS",),
                    "frame_window_size": ("INT", {"default": 80, "min": 1, "max": 100000, "step": 1, "tooltip": "Number of frames in a single window"}),
                    "audio_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "Scale factor for audio embeddings"}),
                    "pose_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage for pose embeddings"}),
                    "pose_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage for pose embeddings"})
                },
                "optional": {
                    "audio_encoder_output": ("AUDIO_ENCODER_OUTPUT",),
                    "ref_latent": ("LATENT",),
                    "pose_latent": ("LATENT",),
                    "vae": ("WANVAE",),
                    "enable_framepack": ("BOOLEAN", {"default": False, "tooltip": "Enable Framepack sampling loop, not compatible with context windows"})
                }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", "INT",)
    RETURN_NAMES = ("image_embeds", "audio_frame_count")
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, embeds, frame_window_size, audio_encoder_output=None, audio_scale=1.0, ref_latent=None, pose_latent=None, vae=None, pose_start_percent=0.0, pose_end_percent=1.0, enable_framepack=False):
        audio_frame_count=0
        if audio_encoder_output is not None:
            all_layers = audio_encoder_output["encoded_audio_all_layers"]
            audio_feat = torch.stack(all_layers, dim=0).squeeze(1)  # shape: [num_layers, T, 512]

            print("audio_feat in", audio_feat.shape)
            input_fps = 50 # determined by the model itself
            output_fps = 30 # determined by the model itself
            bucket_fps = 16 # target fps for the generation

            if input_fps != output_fps:
                audio_feat = linear_interpolation(audio_feat, input_fps=input_fps, output_fps=output_fps)
            print("audio_feat after interpolation", audio_feat.shape)

            audio_feat = audio_feat[:, :embeds["num_frames"] * output_fps // bucket_fps, :]
            print("audio_feat after trim", audio_feat.shape)

            self.video_rate = output_fps

            audio_embed_bucket, num_repeat = self.get_audio_embed_bucket_fps(
                audio_feat,
                fps=bucket_fps,
                batch_frames=frame_window_size
            )
            print("audio_embed_bucket", audio_embed_bucket.shape)

            audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
            if len(audio_embed_bucket.shape) == 3:
                audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
            elif len(audio_embed_bucket.shape) == 4:
                audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)

            audio_frame_count = audio_embed_bucket.shape[-1]

            print("audio_embed_bucket", audio_embed_bucket.shape)

        new_entry = {
            "audio_embed_bucket": audio_embed_bucket if audio_encoder_output is not None else None,
            "num_repeat": num_repeat if audio_encoder_output is not None else None,
            "ref_latent": ref_latent["samples"] if ref_latent is not None else None,
            "pose_latent": pose_latent["samples"] if pose_latent is not None else None,
            "audio_scale": audio_scale,
            "vae": vae,
            "pose_start_percent": pose_start_percent,
            "pose_end_percent": pose_end_percent,
            "enable_framepack": enable_framepack,
            "frame_window_size": frame_window_size
        }
        updated = dict(embeds)
        updated["audio_embeds"] = new_entry
        return (updated, audio_frame_count)
    
    def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):
        num_layers, audio_frame_num, audio_dim = audio_embed.shape

        if num_layers > 1:
            return_all_layers = True
        else:
            return_all_layers = False

        scale = self.video_rate / fps

        min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1

        bucket_num = min_batch_num * batch_frames
        padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num
        batch_idx = get_sample_indices(
            original_fps=self.video_rate,
            total_frames=audio_frame_num + padd_audio_num,
            target_fps=fps,
            num_sample=bucket_num,
            fixed_start=0)
        batch_audio_eb = []
        audio_sample_stride = int(self.video_rate / fps)
        for bi in batch_idx:
            if bi < audio_frame_num:

                chosen_idx = list(
                    range(bi - m * audio_sample_stride,
                          bi + (m + 1) * audio_sample_stride,
                          audio_sample_stride))
                chosen_idx = [0 if c < 0 else c for c in chosen_idx]
                chosen_idx = [
                    audio_frame_num - 1 if c >= audio_frame_num else c
                    for c in chosen_idx
                ]

                if return_all_layers:
                    frame_audio_embed = audio_embed[:, chosen_idx].flatten(
                        start_dim=-2, end_dim=-1)
                else:
                    frame_audio_embed = audio_embed[0][chosen_idx].flatten()
            else:
                frame_audio_embed = \
                torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
                    else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
            batch_audio_eb.append(frame_audio_embed)
        batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb],
                                   dim=0)

        return batch_audio_eb, min_batch_num


    
NODE_CLASS_MAPPINGS = {
    "WanVideoAddS2VEmbeds": WanVideoAddS2VEmbeds,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoAddS2VEmbeds": "WanVideo Add S2V Embeds",
}