
import torch
from ..utils import log
import comfy.model_management as mm
from comfy.utils import ProgressBar, load_torch_file
from tqdm import tqdm
import gc

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
import folder_paths


class WanVideoUni3C_ControlnetLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (folder_paths.get_filename_list("controlnet"), {"tooltip": "These models are loaded from the 'ComfyUI/models/controlnet' -folder",}),

            "base_precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
            "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e5m2'], {"default": 'disabled', "tooltip": "optional quantization method"}),
            "load_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
            "attention_mode": ([
                    "sdpa",
                    "sageattn",
                    ], {"default": "sdpa"}),
            },
            "optional": {
                "compile_args": ("WANCOMPILEARGS", ),
                #"block_swap_args": ("BLOCKSWAPARGS", ),
            }
        }

    RETURN_TYPES = ("WANVIDEOCONTROLNET",)
    RETURN_NAMES = ("controlnet", )
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    def loadmodel(self, model, base_precision, load_device, quantization, attention_mode, compile_args=None):

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()

        transformer_load_device = device if load_device == "main_device" else offload_device

        base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision]


        model_path = folder_paths.get_full_path_or_raise("controlnet", model)

        sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True)

        if not "controlnet_patch_embedding.weight" in sd:
            raise ValueError("Invalid ControlNet model")

        in_channels = sd["controlnet_patch_embedding.weight"].shape[1]
        ffn_dim = sd["controlnet_blocks.0.ffn.0.bias"].shape[0]

        controlnet_cfg = {
            "in_channels": in_channels,
            "conv_out_dim": 5120,
            "time_embed_dim": 5120,
            "dim": 1024,
            "ffn_dim": ffn_dim,
            "num_heads": 16,
            "num_layers": 20,
            "add_channels": 7,
            "mid_channels": 256,
            "attention_mode": attention_mode,
            "quantized": True if quantization != "disabled" else False,
            "base_dtype": base_dtype
        }

        from .controlnet import WanControlNet

        with init_empty_weights():
            controlnet = WanControlNet(controlnet_cfg)
        controlnet.eval()

        if quantization == "disabled":
            for k, v in sd.items():
                if isinstance(v, torch.Tensor):
                    if v.dtype == torch.float8_e4m3fn:
                        quantization = "fp8_e4m3fn"
                        break
                    elif v.dtype == torch.float8_e5m2:
                        quantization = "fp8_e5m2"
                        break

        if "fp8_e4m3fn" in quantization:
            dtype = torch.float8_e4m3fn
        elif quantization == "fp8_e5m2":
            dtype = torch.float8_e5m2
        else:
            dtype = base_dtype
        params_to_keep = {"norm", "head", "time_in", "vector_in", "controlnet_patch_embedding", "time_", "img_emb", "modulation", "text_embedding", "adapter", "proj_in"}

        log.info("Using accelerate to load and assign controlnet model weights to device...")
        param_count = sum(1 for _ in controlnet.named_parameters())
        for name, param in tqdm(controlnet.named_parameters(),
                desc=f"Loading transformer parameters to {transformer_load_device}",
                total=param_count,
                leave=True):
            dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
            if "controlnet_patch_embedding" in name:
                dtype_to_use = torch.float32
            set_module_tensor_to_device(controlnet, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name])

        del sd

        if compile_args is not None:
            if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'):
                torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
                torch._dynamo.config.force_parameter_static_shapes = compile_args["force_parameter_static_shapes"]
                try:
                    torch._dynamo.config.recompile_limit = compile_args["dynamo_recompile_limit"]
                except Exception as e:
                    log.warning(f"Could not set recompile_limit: {e}")
            if compile_args["compile_transformer_blocks_only"]:
                for i, block in enumerate(controlnet.controlnet_blocks):
                    controlnet.controlnet_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
            else:
                controlnet = torch.compile(controlnet, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])


        if load_device == "offload_device" and controlnet.device != offload_device:
            log.info(f"Moving controlnet model from {controlnet.device} to {offload_device}")
            controlnet.to(offload_device)
            gc.collect()
            mm.soft_empty_cache()

        return (controlnet,)

class WanVideoUni3C_embeds:
    """Node that bundles a Uni3C controlnet and its settings into one dict."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input specification (required and optional inputs)."""
        required_inputs = {
            "controlnet": ("WANVIDEOCONTROLNET",),
            "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply the controlnet"}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply the controlnet"}),
        }
        optional_inputs = {
            "render_latent": ("LATENT",),
            "render_mask": ("MASK", {"tooltip": "NOT IMPLEMENTED!"}),
            "offload": ("BOOLEAN", {"default": True, "tooltip": "If enabled, the controlnet model will be offloaded before main model block processing to save VRAM."}),
        }
        return {"required": required_inputs, "optional": optional_inputs}

    RETURN_TYPES = ("UNI3C_EMBEDS", )
    RETURN_NAMES = ("uni3c_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    def process(self, controlnet, strength, start_percent, end_percent, render_latent=None, render_mask=None, offload=True):
        """Pack the inputs into the embeds dict and return it as a 1-tuple.

        ``render_latent``, when given, contributes its "samples" entry.
        Supplying ``render_mask`` raises ``NotImplementedError``; the
        "render_mask" and "camera_embedding" entries of the result are
        therefore always None.
        """
        # Pull the raw samples out of the LATENT dict, if one was connected.
        latents = None if render_latent is None else render_latent["samples"]
        # nframe = latents.shape[2] * 4
        # height = latents.shape[3] * 8
        # width = latents.shape[4] * 8

        # Mask support is disabled; the sketch below is kept for reference.
        if render_mask is not None:
            raise NotImplementedError("render_mask is not implemented at this time")
            # mask = torch.nn.functional.interpolate(
            #         render_mask.unsqueeze(0).unsqueeze(0),  # Add batch and channel dims [1,1,T,H,W]
            #         size=(nframe, height, width),
            #         mode='trilinear',
            #         align_corners=False
            #     ).squeeze(0)
            # latent_mask = mask.unsqueeze(0).to(device)
            # log.info(f"latent mask shape {latent_mask.shape}")
        latent_mask = None

        # Disabled camera-embedding code, kept for reference:
        # # load camera
        # cam_info = json.load(open(f"{render_path}/cam_info.json"))
        # w2cs = torch.tensor(np.array(cam_info["extrinsic"]), dtype=torch.float32, device=device)
        # intrinsic = torch.tensor(np.array(cam_info["intrinsic"]), dtype=torch.float32, device=device)
        # intrinsic[0, :] = intrinsic[0, :] / cam_info["width"] * width
        # intrinsic[1, :] = intrinsic[1, :] / cam_info["height"] * height
        # intrinsic = intrinsic[None].repeat(nframe, 1, 1)

        # from .utils import build_cameras, set_initial_camera, traj_map

        # focal_length = 1.0
        # start_elevation = 5.0
        # depth_avg = 0.5
        # traj_type = "orbit"
        # cam_traj, x_offset, y_offset, z_offset, d_theta, d_phi, d_r = traj_map(traj_type)
        # focallength_px = focal_length * width

        # K = torch.tensor([[focallength_px, 0, width / 2],
        #                   [0, focallength_px, height / 2],
        #                   [0, 0, 1]], dtype=torch.float32)
        # K_inv = K.inverse()
        # intrinsic = K[None].repeat(nframe, 1, 1)

        # w2c_0, c2w_0 = set_initial_camera(start_elevation, depth_avg)
        # w2cs, c2ws, intrinsic = build_cameras(cam_traj=cam_traj,
        #                                     w2c_0=w2c_0,
        #                                     c2w_0=c2w_0,
        #                                     intrinsic=intrinsic,
        #                                     nframe=nframe,
        #                                     focal_length=focal_length,
        #                                     d_theta=d_theta,
        #                                     d_phi=d_phi,
        #                                     d_r=d_r,
        #                                     radius=depth_avg,
        #                                     x_offset=x_offset,
        #                                     y_offset=y_offset,
        #                                     z_offset=z_offset)

        # from .camera import get_camera_embedding
        # camera_embedding = get_camera_embedding(intrinsic, w2cs, nframe, height, width, normalize=True)
        #print("camera embedding shape", camera_embedding.shape)

        return (
            {
                "controlnet": controlnet,
                "controlnet_weight": strength,
                "start": start_percent,
                "end": end_percent,
                "render_latent": latents,
                "render_mask": latent_mask,
                "camera_embedding": None,
                "offload": offload,
            },
        )

# Registration tables for the nodes defined in this module.
# NOTE(review): presumably read by ComfyUI's custom-node loader -- the keys of
# both dicts must stay in sync with each other.

# Node identifier -> class implementing the node.
NODE_CLASS_MAPPINGS = {
    "WanVideoUni3C_ControlnetLoader": WanVideoUni3C_ControlnetLoader,
    "WanVideoUni3C_embeds": WanVideoUni3C_embeds,
    }
# Node identifier -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoUni3C_ControlnetLoader": "WanVideo Uni3C Controlnet Loader",
    "WanVideoUni3C_embeds": "WanVideo Uni3C Embeds",
    }
