import torch
import torch.nn as nn


class FactorConv3d(nn.Module):
    """
    (2+1)D factorisation of a 3-D convolution:
    1 x k_h x k_w spatial convolution -> SiLU (Swish) -> k_t x 1 x 1 temporal convolution.

    The spatial stage is depthwise (``groups=in_channels``, one filter per
    input channel, no bias). The temporal stage maps ``in_channels`` to
    ``out_channels`` and has a bias.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the temporal convolution.
        kernel_size: an int (used for T, H and W) or a ``(k_t, k_h, k_w)`` triple.
        stride: stride applied to H and W in the spatial conv and to T in
            the temporal conv.
        dilation: dilation applied to H and W in the spatial conv and to T
            in the temporal conv.

    Each axis is padded by ``(k - 1) * dilation // 2``, so odd kernels with
    ``stride=1`` preserve T, H and W.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size,
                 stride: int = 1,
                 dilation: int = 1):
        super().__init__()

        if isinstance(kernel_size, int):
            k_t, k_h, k_w = kernel_size, kernel_size, kernel_size
        else:
            k_t, k_h, k_w = kernel_size

        pad_t = (k_t - 1) * dilation // 2
        pad_h = (k_h - 1) * dilation // 2
        # W needs its own padding: reusing the H value would resize W
        # whenever the spatial kernel is non-square (k_h != k_w).
        pad_w = (k_w - 1) * dilation // 2

        self.spatial = nn.Conv3d(
            in_channels, in_channels,
            kernel_size=(1, k_h, k_w),
            stride=(1, stride, stride),
            padding=(0, pad_h, pad_w),
            dilation=(1, dilation, dilation),
            groups=in_channels,
            bias=False
        )

        self.temporal = nn.Conv3d(
            in_channels, out_channels,
            kernel_size=(k_t, 1, 1),
            stride=(stride, 1, 1),
            padding=(pad_t, 0, 0),
            dilation=(dilation, 1, 1),
            bias=True
        )

        self.act = nn.SiLU()

    def forward(self, x):
        """
        Apply spatial conv -> SiLU -> temporal conv.

        Each conv receives its input cast to that conv's weight dtype, and
        every intermediate/final result is cast back to the dtype of ``x``,
        so the output dtype always matches the input dtype.
        """
        out_dtype = x.dtype
        x = self.spatial(x.to(self.spatial.weight.dtype)).to(out_dtype)
        x = self.act(x)
        return self.temporal(x.to(self.temporal.weight.dtype)).to(out_dtype)

class LayerNorm2D(nn.Module):
    """
    Channel-wise layer normalisation for a 4-D tensor laid out as (B, C, H, W).

    Every (b, h, w) position is normalised independently across its C
    channels using the biased (population) variance. Optionally, the result
    is then scaled and shifted by learnable per-channel parameters of shape
    (1, C, 1, 1).

    Args:
        num_channels: number of channels C (size of the affine parameters).
        eps: constant added to the variance before the square root.
        affine: when True, create ``weight`` (initialised to ones) and
            ``bias`` (initialised to zeros); when False, neither attribute
            is created.
    """
    def __init__(self, num_channels, eps=1e-5, affine=True):
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if not affine:
            return
        param_shape = (1, num_channels, 1, 1)
        self.weight = nn.Parameter(torch.ones(param_shape))
        self.bias = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        # Statistics over the channel axis, kept broadcastable: (B, 1, H, W).
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_var = x.var(dim=1, keepdim=True, unbiased=False)
        normalized = (x - channel_mean) / (channel_var + self.eps).sqrt()
        if not self.affine:
            return normalized
        return normalized * self.weight + self.bias


class PoseRefNetNoBNV3(nn.Module):
    """
    Per-frame cross-attention in which reference features attend to pose features.

    Both inputs are projected to ``hidden_dim`` channels with 1x1 convolutions,
    and each frame is flattened into H*W tokens. Reference tokens are the
    queries; pose tokens are the keys and values. The attention output goes
    through LayerNorm -> residual 1x1-conv FFN -> LayerNorm and is projected
    back to ``in_channels_c`` channels.

    Args:
        in_channels_c: channels of ``pose`` (and of the output).
        in_channels_x: channels of ``ref``.
        hidden_dim: attention width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout on the attention weights.
        batch_first: when True (default), attention runs over the H*W tokens
            of each frame independently. ``False`` reproduces the legacy
            layout, in which ``nn.MultiheadAttention`` treated the B*T axis as
            the sequence and H*W as the batch. In that layout every spatial
            location attends across all frames *and all batch samples*. Keep
            it only to reproduce checkpoints trained with that behaviour.
    """
    def __init__(self,
                 in_channels_c: int,
                 in_channels_x: int,
                 hidden_dim: int = 256,
                 num_heads: int = 8,
                 dropout: float = 0.1,
                 batch_first: bool = True):
        super().__init__()
        self.d_model = hidden_dim
        self.nhead = num_heads
        self.batch_first = batch_first

        self.proj_p = nn.Conv2d(in_channels_c, hidden_dim, kernel_size=1)
        self.proj_r = nn.Conv2d(in_channels_x, hidden_dim, kernel_size=1)

        self.proj_p_back = nn.Conv2d(hidden_dim, in_channels_c, kernel_size=1)

        # batch_first must match the (N, L, E) token layout built in forward();
        # it does not change the parameter set, so state dicts remain loadable.
        self.cross_attn = nn.MultiheadAttention(hidden_dim,
                                                num_heads=num_heads,
                                                dropout=dropout,
                                                batch_first=batch_first)

        self.ffn_pose = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1)
        )

        self.norm1 = LayerNorm2D(hidden_dim)
        self.norm2 = LayerNorm2D(hidden_dim)

    def forward(self, pose, ref, mask=None):
        """
        pose : (B, C1, T, H, W)
        ref  : (B, C2, T, H, W); expected to share T, H, W with ``pose``
        mask : optional key_padding_mask. With ``batch_first=True`` it may be
               given as (B, T*H*W) or (B*T, H*W); True marks pose tokens to
               ignore. With ``batch_first=False`` it is passed through
               unchanged.
        return: (B, C1, T, H, W), a non-contiguous view (T and C transposed)
        """
        B, _, T, H, W = pose.shape
        attn_dtype = self.cross_attn.in_proj_weight.dtype

        # (B, C, T, H, W) -> (B*T, C, H, W): every frame becomes its own 2-D sample.
        p_frames = pose.permute(0, 2, 1, 3, 4).contiguous().flatten(0, 1)
        r_frames = ref.permute(0, 2, 1, 3, 4).contiguous().flatten(0, 1)

        # 1x1 projection, then (B*T, d, H, W) -> (B*T, H*W, d) token sequences.
        p_tokens = (self.proj_p(p_frames.to(self.proj_p.weight.dtype))
                    .to(attn_dtype).flatten(2).transpose(1, 2))
        r_tokens = (self.proj_r(r_frames.to(self.proj_r.weight.dtype))
                    .to(attn_dtype).flatten(2).transpose(1, 2))

        if mask is not None and self.batch_first:
            # Row-major reshape maps token (b, t, h, w) to row b*T + t and
            # column h*W + w, matching the token layout above.
            mask = mask.reshape(B * T, H * W)

        out = self.cross_attn(query=r_tokens,
                              key=p_tokens,
                              value=p_tokens,
                              key_padding_mask=mask)[0]

        # Tokens back to feature maps: (B*T, H*W, d) -> (B*T, d, H, W).
        out = self.norm1(out.transpose(1, 2).contiguous().view(B * T, -1, H, W))

        out_type = out.dtype

        out = out + self.ffn_pose(out.to(self.ffn_pose[0].weight.dtype)).to(out_type)
        out = self.norm2(out)

        out = self.proj_p_back(out.to(self.proj_p_back.weight.dtype)).to(out_type)

        # (B*T, C1, H, W) -> (B, T, C1, H, W) -> (B, C1, T, H, W)
        return out.view(B, T, -1, H, W).transpose(1, 2)
