# Modify from https://github.com/liyunsheng13/dcd/blob/main/models/imagenet/mobilenetv2_dcd.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class Hsigmoid(nn.Module):
    """Piecewise-linear gate computing ``relu6(x + 3) / 3``.

    Unlike the common hard-sigmoid (which divides by 6), the divisor here is
    3, so the output lies in ``[0, 2]`` and equals ``1`` at ``x == 0``.

    Args:
        inplace: Forwarded to ``F.relu6``. It only affects the temporary
            ``x + 3`` tensor, never the caller's input.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        # ``x + 3.`` allocates a fresh tensor, so an in-place relu6 is safe.
        shifted = x + 3.
        clipped = F.relu6(shifted, inplace=self.inplace)
        return clipped / 3.


class DYModule(nn.Module):
    """Dynamic 1x1 convolution in the style of DCD (dynamic convolution decomposition).

    The output is the sum of two branches computed from the same input
    ``x`` of shape ``(b, inp, h, w)``; the result has shape ``(b, oup, h, w)``:

    * a static 1x1 convolution ``inp -> oup`` (``self.conv``) whose output
      channels are rescaled per sample by ``dy_scale``, an ``Hsigmoid`` gate
      predicted from average-pooled features;
    * a low-rank dynamic branch: project to ``self.dim`` latent channels
      (``conv_q`` + ``bn1``), add ``bn2(phi @ q)`` where ``phi`` is a
      per-sample ``dim x dim`` matrix predicted from the pooled features,
      then project back to ``oup`` channels (``conv_p``).

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        fc_squeeze: Divisor that sets the hidden width of the coefficient
            MLP (``max(inp * mul, dim ** 2) // fc_squeeze``, floored at 4).
    """

    def __init__(self, inp, oup, fc_squeeze=8):
        super(DYModule, self).__init__()
        self.conv = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
        if inp < oup:
            # Channel-expanding layer: pool to 2x2 so the coefficient MLP
            # receives 4 * inp pooled features.
            self.mul = 4
            reduction = 8
            self.avg_pool = nn.AdaptiveAvgPool2d(2)
        else:
            self.mul = 1
            reduction = 2
            self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Latent size of the dynamic branch: keep halving it until the
        # dim x dim matrix has at most 2 * (pooled feature count) entries.
        self.dim = min((inp * self.mul) // reduction, oup // reduction)
        while self.dim ** 2 > inp * self.mul * 2:
            reduction *= 2
            self.dim = min((inp * self.mul) // reduction, oup // reduction)
        if self.dim < 4:
            self.dim = 4

        squeeze = max(inp * self.mul, self.dim ** 2) // fc_squeeze
        if squeeze < 4:
            squeeze = 4
        self.conv_q = nn.Conv2d(inp, self.dim, 1, 1, 0, bias=False)

        self.fc = nn.Sequential(
            nn.Linear(inp * self.mul, squeeze, bias=False),
            SEModule_small(squeeze),
        )
        self.fc_phi = nn.Linear(squeeze, self.dim ** 2, bias=False)
        self.fc_scale = nn.Linear(squeeze, oup, bias=False)
        self.hs = Hsigmoid()
        self.conv_p = nn.Conv2d(self.dim, oup, 1, 1, 0, bias=False)

        # The upstream code used BatchNorm here; GroupNorm is used instead.
        # GroupNorm requires num_channels % num_groups == 0, but self.dim
        # need not be a multiple of 4 (e.g. inp=oup=20 gives dim=5), so use
        # the largest of 4 / 2 / 1 that divides it. When dim % 4 == 0 this
        # is exactly the previous GroupNorm(4, dim); parameter shapes do not
        # depend on the group count, so state_dicts stay compatible.
        if self.dim % 4 == 0:
            norm_groups = 4
        elif self.dim % 2 == 0:
            norm_groups = 2
        else:
            norm_groups = 1
        self.bn1 = nn.GroupNorm(num_groups=norm_groups, num_channels=self.dim)
        self.bn2 = nn.GroupNorm(num_groups=norm_groups, num_channels=self.dim)

    def forward(self, x):
        x_type = x.dtype
        b, c, h, w = x.size()

        # Static branch, computed in the conv's dtype and returned in x's.
        r = self.conv(x.to(self.conv.weight.dtype)).to(x_type)

        # Per-sample coefficients from pooled features. Cast to the MLP's
        # parameter dtype, as every other sub-module call here does, so a
        # dtype mismatch between x and the Linear weights does not raise.
        y = self.avg_pool(x).view(b, c * self.mul).to(self.fc[0].weight.dtype)
        y = self.fc(y)
        dy_phi = self.fc_phi(y).view(b, self.dim, self.dim)
        # Cast back so the scaled static branch (and the output) stay in
        # x's dtype instead of being promoted to the MLP's dtype.
        dy_scale = self.hs(self.fc_scale(y)).view(b, -1, 1, 1).to(x_type)
        r = dy_scale.expand_as(r) * r

        # Dynamic low-rank branch in the dim-channel latent space.
        x = self.conv_q(x.to(self.conv_q.weight.dtype)).to(self.bn1.weight.dtype)
        x = self.bn1(x)

        # Flatten spatial dims so the per-sample dim x dim matrix mixes channels.
        x = x.view(b, -1, h * w)
        x = x + self.bn2(torch.matmul(dy_phi, x.to(dy_phi.dtype)).to(self.bn2.weight.dtype))
        x = x.view(b, -1, h, w)

        x = self.conv_p(x.to(self.conv_p.weight.dtype)).to(x_type)

        return x + r


class SEModule_small(nn.Module):
    """Self-gating layer computing ``x * Hsigmoid(W @ x)``.

    ``W`` is a bias-free square ``nn.Linear(channel, channel)``, so it acts on
    the last dimension of ``x``, which must have size ``channel``.

    Args:
        channel: Size of the feature (last) dimension.
    """

    def __init__(self, channel):
        super().__init__()
        gate_layers = [nn.Linear(channel, channel, bias=False), Hsigmoid()]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        gate = self.fc(x)
        return x * gate


class SEModule(nn.Module):
    """Squeeze-and-excitation channel attention for 4-D ``(b, c, h, w)`` inputs.

    Each channel is global-average-pooled and passed through a bias-free
    bottleneck MLP (``channel -> channel // reduction -> channel`` with ReLU,
    then ``Hsigmoid``). ``x`` is then rescaled channel-wise by the resulting gates.

    Args:
        channel: Number of channels ``c`` of the input.
        reduction: Bottleneck divisor for the MLP's hidden width.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            Hsigmoid(),
        )

    def forward(self, x):
        # Four-way unpack: rejects inputs that are not 4-D.
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gates = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gates.expand_as(x)
