from ..utils import log
import torch

def set_transformer_cache_method(transformer, timesteps, cache_args=None):
    """Configure ``transformer`` for the step-caching scheme named in ``cache_args``.

    Args:
        transformer: Model object mutated in place. For the selected cache type it
            must expose the matching ``teacache_state`` / ``magcache_state`` /
            ``easycache_state`` object with a ``clear_all()`` method.
        timesteps: Sized sequence of sampling timesteps. Only ``len(timesteps)`` is
            used, to turn an ``end_step`` of -1 into the last step index.
        cache_args: Mapping with ``"cache_device"`` and ``"cache_type"`` (one of
            ``"TeaCache"``, ``"MagCache"``, ``"EasyCache"``) plus the type-specific
            keys read below. ``None`` (the default) leaves ``transformer`` untouched.
            Any other ``cache_type`` only sets ``transformer.cache_device``.

    Returns:
        The same ``transformer`` object.
    """
    if cache_args is None:
        # The signature advertises None as the default; without this guard the
        # subscript below raised TypeError.
        return transformer

    def resolve_end_step():
        # An end_step of -1 means "through the final step".
        end_step = cache_args["end_step"]
        return len(timesteps) - 1 if end_step == -1 else end_step

    transformer.cache_device = cache_args["cache_device"]
    cache_type = cache_args["cache_type"]

    if cache_type == "TeaCache":
        log.info(f"TeaCache: Using cache device: {transformer.cache_device}")
        transformer.teacache_state.clear_all()
        transformer.enable_teacache = True
        transformer.rel_l1_thresh = cache_args["rel_l1_thresh"]
        transformer.teacache_start_step = cache_args["start_step"]
        transformer.teacache_end_step = resolve_end_step()
        transformer.teacache_use_coefficients = cache_args["use_coefficients"]
        transformer.teacache_mode = cache_args["mode"]
    elif cache_type == "MagCache":
        log.info(f"MagCache: Using cache device: {transformer.cache_device}")
        transformer.magcache_state.clear_all()
        transformer.enable_magcache = True
        transformer.magcache_start_step = cache_args["start_step"]
        transformer.magcache_end_step = resolve_end_step()
        transformer.magcache_thresh = cache_args["magcache_thresh"]
        transformer.magcache_K = cache_args["magcache_K"]
    elif cache_type == "EasyCache":
        log.info(f"EasyCache: Using cache device: {transformer.cache_device}")
        transformer.easycache_state.clear_all()
        transformer.enable_easycache = True
        transformer.easycache_start_step = cache_args["start_step"]
        transformer.easycache_end_step = resolve_end_step()
        transformer.easycache_thresh = cache_args["easycache_thresh"]
    return transformer

class TeaCacheState:
    """Per-prediction bookkeeping for TeaCache.

    Each ``new_prediction`` call allocates a fresh state dict under the next
    integer ID (0, 1, 2, ... in allocation order) inside ``self.states``.
    """

    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0

    def new_prediction(self, cache_device='cpu'):
        """Allocate a fresh TeaCache state and return its integer ID.

        ``cache_device`` overwrites the instance's recorded cache device.
        """
        self.cache_device = cache_device
        pred_id, self._next_pred_id = self._next_pred_id, self._next_pred_id + 1
        self.states[pred_id] = dict(
            previous_residual=None,
            accumulated_rel_l1_distance=0,
            previous_modulated_input=None,
            skipped_steps=[],
        )
        return pred_id

    def update(self, pred_id, **kwargs):
        """Merge ``kwargs`` into the state for ``pred_id``; unknown IDs are ignored."""
        if pred_id in self.states:
            self.states[pred_id].update(kwargs)

    def get(self, pred_id):
        """Return the state dict for ``pred_id``, or a new empty dict if absent."""
        return self.states.get(pred_id, {})

    def clear_all(self):
        """Forget every prediction state and restart ID allocation at 0."""
        # Rebind (rather than .clear()) so outside references to the old dict keep it.
        self.states, self._next_pred_id = {}, 0

class MagCacheState:
    """Per-prediction bookkeeping for MagCache.

    States live in ``self.states`` keyed by sequential integer IDs handed out by
    ``new_prediction``; ``clear_all`` resets both the states and the counter.
    """

    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0

    def new_prediction(self, cache_device='cpu'):
        """Allocate a fresh MagCache state and return its integer ID."""
        self.cache_device = cache_device
        new_id = self._next_pred_id
        self._next_pred_id = new_id + 1
        fresh = {
            'residual_cache': None,
            'accumulated_ratio': 1.0,
            'accumulated_steps': 0,
            'accumulated_err': 0,
            'skipped_steps': [],
        }
        self.states[new_id] = fresh
        return new_id

    def update(self, pred_id, **kwargs):
        """Overwrite/add the given keys on ``pred_id``'s state; no-op for unknown IDs."""
        if pred_id not in self.states:
            return None
        self.states[pred_id].update(kwargs)

    def get(self, pred_id):
        """Return ``pred_id``'s state dict, or a new empty dict when it does not exist."""
        try:
            return self.states[pred_id]
        except KeyError:
            return {}

    def clear_all(self):
        """Discard all prediction states and restart IDs from 0."""
        self.states = {}
        self._next_pred_id = 0

class EasyCacheState:
    """Per-prediction bookkeeping for EasyCache, keyed by sequential integer IDs."""

    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0

    def new_prediction(self, cache_device='cpu'):
        """Allocate a fresh EasyCache state and return its integer ID.

        The instance's ``cache_device`` is overwritten with the argument.
        """
        self.cache_device = cache_device
        allocated = self._next_pred_id
        self._next_pred_id += 1
        self.states[allocated] = dict(
            previous_raw_input=None,
            previous_raw_output=None,
            cache=None,
            accumulated_error=0.0,
            skipped_steps=[],
            cache_ovi=None,
        )
        return allocated

    def update(self, pred_id, **kwargs):
        """Merge ``kwargs`` into the state for ``pred_id``; ignored if the ID is unknown."""
        target = self.states
        if pred_id in target:
            target[pred_id].update(kwargs)

    def get(self, pred_id):
        """Return the state dict for ``pred_id`` (a new empty dict when absent)."""
        return self.states.get(pred_id, {})

    def clear_all(self):
        """Reset to the freshly-constructed condition (no states, next ID 0)."""
        self.states, self._next_pred_id = {}, 0

def relative_l1_distance(last_tensor, current_tensor):
    """Return ``mean(|last - current|) / mean(|last|)`` as a float32 scalar tensor.

    ``last_tensor`` may live on a different device than ``current_tensor``
    (e.g. a cache kept on CPU). It is moved to ``current_tensor.device`` once,
    and both reductions run there. Previously the denominator was reduced on
    ``last_tensor``'s original device, which meant a second full pass over the
    data on the cache device.

    Note: no epsilon is added to the denominator. An all-zero ``last_tensor``
    yields inf (or nan when the numerator is also zero).
    """
    last_tensor = last_tensor.to(current_tensor.device)
    l1_distance = torch.abs(last_tensor - current_tensor).mean()
    norm = torch.abs(last_tensor).mean()
    # Both operands are already on current_tensor.device, so only the dtype needs fixing.
    return (l1_distance / norm).to(torch.float32)

def cache_report(transformer, cache_args):
    """Log the skipped steps of each prediction for the active cache, then reset all caches.

    Args:
        transformer: Object exposing ``teacache_state``, ``magcache_state`` and
            ``easycache_state``, each with a ``states`` dict and ``clear_all()``.
        cache_args: Mapping whose ``"cache_type"`` selects the state to report
            (``"TeaCache"``, ``"MagCache"`` or ``"EasyCache"``). ``None``, or an
            unrecognised type, skips reporting. Previously an unrecognised type
            raised ``AttributeError`` on ``None.items()`` before the reset ran.

    All three cache states are cleared in every case.
    """
    cache_type = cache_args["cache_type"] if cache_args is not None else None
    state_attr = {
        "TeaCache": "teacache_state",
        "MagCache": "magcache_state",
        "EasyCache": "easycache_state",
    }.get(cache_type)

    if state_attr is not None:
        # Log labels for the first two prediction IDs; later IDs get a generic name.
        state_names = {
            0: "conditional",
            1: "unconditional",
        }
        for pred_id, state in getattr(transformer, state_attr).states.items():
            name = state_names.get(pred_id, f"prediction_{pred_id}")
            if 'skipped_steps' in state:
                log.info(f"{cache_type} skipped: {len(state['skipped_steps'])} {name} steps: {state['skipped_steps']}")

    # Reset every cache type, not just the reported one, so no stale state
    # survives into the next run.
    transformer.teacache_state.clear_all()
    transformer.magcache_state.clear_all()
    transformer.easycache_state.clear_all()