import numpy as np
from typing import Callable, Optional, List
import torch
from ..utils import log

def ordered_halving(val):
    bin_str = f"{val:064b}"
    bin_flip = bin_str[::-1]
    as_int = int(bin_flip, 2)

    return as_int / (1 << 64)

def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
    prev_val = -1
    for i, val in enumerate(window):
        val = val % num_frames
        if val < prev_val:
            return True, i
        prev_val = val
    return False, -1

def shift_window_to_start(window: list[int], num_frames: int):
    start_val = window[0]
    for i in range(len(window)):
        # 1) subtract each element by start_val to move vals relative to the start of all frames
        # 2) add num_frames and take modulus to get adjusted vals
        window[i] = ((window[i] - start_val) + num_frames) % num_frames

def shift_window_to_end(window: list[int], num_frames: int):
    # 1) shift window to start
    shift_window_to_start(window, num_frames)
    end_val = window[-1]
    end_delta = num_frames - end_val - 1
    for i in range(len(window)):
        # 2) add end_delta to each val to slide windows to end
        window[i] = window[i] + end_delta

def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
    all_indexes = list(range(num_frames))
    for w in windows:
        for val in w:
            try:
                all_indexes.remove(val)
            except ValueError:
                pass
    return all_indexes

def uniform_looped(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)

    for context_step in 1 << np.arange(context_stride):
        pad = int(round(num_frames * ordered_halving(step)))
        for j in range(
            int(ordered_halving(step) * context_step) + pad,
            num_frames + pad + (0 if closed_loop else -context_overlap),
            (context_size * context_step - context_overlap),
        ):
            yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)]

#from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
def uniform_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    windows = []
    if num_frames <= context_size:
        windows.append(list(range(num_frames)))
        return windows

    context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)

    for context_step in 1 << np.arange(context_stride):
        pad = int(round(num_frames * ordered_halving(step)))
        for j in range(
            int(ordered_halving(step) * context_step) + pad,
            num_frames + pad + (0 if closed_loop else -context_overlap),
            (context_size * context_step - context_overlap),
        ):
            windows.append([e % num_frames for e in range(j, j + context_size * context_step, context_step)])

    # now that windows are created, shift any windows that loop, and delete duplicate windows
    delete_idxs = []
    win_i = 0
    while win_i < len(windows):
        # if window is rolls over itself, need to shift it
        is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
        if is_roll:
            roll_val = windows[win_i][roll_idx]  # roll_val might not be 0 for windows of higher strides
            shift_window_to_end(windows[win_i], num_frames=num_frames)
            # check if next window (cyclical) is missing roll_val
            if roll_val not in windows[(win_i+1) % len(windows)]:
                # need to insert new window here - just insert window starting at roll_val
                windows.insert(win_i+1, list(range(roll_val, roll_val + context_size)))
        # delete window if it's not unique
        for pre_i in range(0, win_i):
            if windows[win_i] == windows[pre_i]:
                delete_idxs.append(win_i)
                break
        win_i += 1

    # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
    delete_idxs.reverse()
    for i in delete_idxs:
        windows.pop(i)
    return windows

def static_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    windows = []
    if num_frames <= context_size:
        windows.append(list(range(num_frames)))
        return windows
    # always return the same set of windows
    delta = context_size - context_overlap
    for start_idx in range(0, num_frames, delta):
        # if past the end of frames, move start_idx back to allow same context_length
        ending = start_idx + context_size
        if ending >= num_frames:
            final_delta = ending - num_frames
            final_start_idx = start_idx - final_delta
            windows.append(list(range(final_start_idx, final_start_idx + context_size)))
            break
        windows.append(list(range(start_idx, start_idx + context_size)))
    return windows

def get_context_scheduler(name: str) -> Callable:
    if name == "uniform_looped":
        return uniform_looped
    elif name == "uniform_standard":
        return uniform_standard
    elif name == "static_standard":
        return static_standard
    else:
        raise ValueError(f"Unknown context_overlap policy {name}")


def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    return sum(
        len(
            list(
                scheduler(
                    i,
                    num_steps,
                    num_frames,
                    context_size,
                    context_stride,
                    context_overlap,
                )
            )
        )
        for i in range(len(timesteps))
    )

def create_window_mask(noise_pred_context, c, latent_video_length, context_overlap, looped=False, window_type="linear"):
    window_mask = torch.ones_like(noise_pred_context)
    
    if window_type == "pyramid":
        # Create pyramid weights that peak in the middle
        length = noise_pred_context.shape[1]
        if length % 2 == 0:
            max_weight = length // 2
            weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
        else:
            max_weight = (length + 1) // 2
            weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
        
        # Normalize weights to range from 0 to 1
        max_val = max(weight_sequence)
        weight_sequence = [w / max_val for w in weight_sequence]
        
        # Apply the weights to create the mask
        weights_tensor = torch.tensor(weight_sequence, device=noise_pred_context.device)
        weights_tensor = weights_tensor.view(1, -1, 1, 1)
        window_mask = weights_tensor.expand_as(window_mask).clone()
        
        # Adjust for position in sequence if needed
        if not looped:
            if min(c) == 0:  # First chunk
                left_ramp = torch.linspace(0, 1, context_overlap, device=noise_pred_context.device).view(1, -1, 1, 1)
                # Clone to avoid in-place memory conflict
                left_section = window_mask[:, :context_overlap].clone()
                window_mask[:, :context_overlap] = torch.maximum(left_section, left_ramp)
                
            if max(c) == latent_video_length - 1:  # Last chunk
                right_ramp = torch.linspace(1, 0, context_overlap, device=noise_pred_context.device).view(1, -1, 1, 1)
                # Clone to avoid in-place memory conflict
                right_section = window_mask[:, -context_overlap:].clone()
                window_mask[:, -context_overlap:] = torch.maximum(right_section, right_ramp)
    else:  # Original "linear" window masking
        # Apply left-side blending for all except first chunk (or always in loop mode)
        if min(c) > 0 or (looped and max(c) == latent_video_length - 1):
            ramp_up = torch.linspace(0, 1, context_overlap, device=noise_pred_context.device)
            ramp_up = ramp_up.view(1, -1, 1, 1)
            window_mask[:, :context_overlap] = ramp_up
            
        # Apply right-side blending for all except last chunk (or always in loop mode)
        if max(c) < latent_video_length - 1 or (looped and min(c) == 0):
            ramp_down = torch.linspace(1, 0, context_overlap, device=noise_pred_context.device)
            ramp_down = ramp_down.view(1, -1, 1, 1)
            window_mask[:, -context_overlap:] = ramp_down
            
    return window_mask

class WindowTracker:
    def __init__(self, verbose=False):
        self.window_map = {}  # Maps frame sequence to persistent ID
        self.next_id = 0
        self.cache_states = {}  # Maps persistent ID to teacache state
        self.verbose = verbose
    
    def get_window_id(self, frames):
        key = tuple(sorted(frames))  # Order-independent frame sequence
        if key not in self.window_map:
            self.window_map[key] = self.next_id
            if self.verbose:
                log.info(f"New window pattern {key} -> ID {self.next_id}")
            self.next_id += 1
        return self.window_map[key]
    
    def get_teacache(self, window_id, base_state):
        if window_id not in self.cache_states:
            if self.verbose:
                log.info(f"Initializing persistent teacache for window {window_id}")
            self.cache_states[window_id] = base_state.copy()
        return self.cache_states[window_id]
