import math

import numpy
import numpy as np
import torch
import torch.nn as nn
from torchvision.models.video.mvit import PositionalEncoding

from envs.env_core import EnvCore
from .util import init, get_clones
from transformers import GPT2Config, GPT2Model
"""MLP modules."""
from transformers import EncodecModel
# Width of one encoded token: the output size of MLPLayer2, the d_model handed
# to the attention layer, and the per-token factor in MLPBaseWithTrans' MLP
# input width.
Feature = 4

class MLPLayer(nn.Module):
    """Feed-forward trunk: one input block followed by ``layer_N`` hidden blocks.

    Every block is ``Linear -> activation -> LayerNorm``.

    Args:
        input_dim: width of the incoming feature vector.
        hidden_size: output width of every block.
        layer_N: number of hidden blocks applied after the input block.
        use_orthogonal: 0/False picks ``nn.init.xavier_uniform_``, 1/True picks
            ``nn.init.orthogonal_`` (the value is used as a sequence index).
        use_ReLU: 0/False picks Tanh, 1/True picks ReLU (same indexing).
    """

    def __init__(self, input_dim, hidden_size, layer_N, use_orthogonal, use_ReLU):
        super(MLPLayer, self).__init__()
        self._layer_N = layer_N

        activation = (nn.Tanh, nn.ReLU)[use_ReLU]()
        weight_init = (nn.init.xavier_uniform_, nn.init.orthogonal_)[use_orthogonal]
        gain = nn.init.calculate_gain(('tanh', 'relu')[use_ReLU])

        def zero_bias(tensor):
            return nn.init.constant_(tensor, 0)

        def init_(module):
            # `init` is the project helper from .util; presumably it applies
            # weight_init (with gain) to the weight and zero_bias to the bias
            # - TODO confirm against .util.
            return init(module, weight_init, zero_bias, gain=gain)

        input_linear = init_(nn.Linear(input_dim, hidden_size))
        self.fc1 = nn.Sequential(input_linear, activation, nn.LayerNorm(hidden_size))

        # fc_h is the template block handed to get_clones; it stays registered
        # on the module even though forward() only walks fc2.
        hidden_linear = init_(nn.Linear(hidden_size, hidden_size))
        self.fc_h = nn.Sequential(hidden_linear, activation, nn.LayerNorm(hidden_size))
        self.fc2 = get_clones(self.fc_h, self._layer_N)

    def forward(self, x):
        """Run ``x`` through the input block, then the first ``layer_N`` entries of ``fc2``."""
        out = self.fc1(x)
        for idx in range(self._layer_N):
            hidden_block = self.fc2[idx]
            out = hidden_block(out)
        return out

class MLPLayer2(nn.Module):
    """Two-stage encoder that maps a segment of width ``input_dim`` to ``Feature`` values.

    Stage one is ``Linear -> activation -> LayerNorm`` with the custom
    initialiser; stage two is a plain ``nn.Linear`` left at PyTorch's default
    initialisation.

    Args:
        input_dim: width of the incoming vector.
        hidden_size: width of the intermediate representation.
        layer_N: stored on the instance but not used by this class.
        use_orthogonal: 0/False picks ``nn.init.xavier_uniform_``, 1/True picks
            ``nn.init.orthogonal_`` (the value is used as a sequence index).
        use_ReLU: 0/False picks Tanh, 1/True picks ReLU (same indexing).
    """

    def __init__(self, input_dim, hidden_size, layer_N, use_orthogonal, use_ReLU):
        super(MLPLayer2, self).__init__()
        self._layer_N = layer_N

        activation = (nn.Tanh, nn.ReLU)[use_ReLU]()
        weight_init = (nn.init.xavier_uniform_, nn.init.orthogonal_)[use_orthogonal]
        gain = nn.init.calculate_gain(('tanh', 'relu')[use_ReLU])

        def zero_bias(tensor):
            return nn.init.constant_(tensor, 0)

        def init_(module):
            # `init` comes from .util; its exact contract is not visible here.
            return init(module, weight_init, zero_bias, gain=gain)

        input_linear = init_(nn.Linear(input_dim, hidden_size))
        self.fc1 = nn.Sequential(input_linear, activation, nn.LayerNorm(hidden_size))
        # Projection down to the shared token width (module constant Feature).
        self.fc2 = nn.Linear(hidden_size, Feature)

    def forward(self, x):
        """Return ``fc2(fc1(x))``; the last dimension of the result is ``Feature``."""
        return self.fc2(self.fc1(x))


class MLPBase(nn.Module):
    """Observation encoder: optional LayerNorm on the raw input, then an MLPLayer.

    Args:
        args: config object; reads ``use_feature_normalization``,
            ``use_orthogonal``, ``use_ReLU``, ``stacked_frames``, ``layer_N``
            and ``hidden_size``.
        obs_shape: indexable whose first entry is the observation width.
        cat_self: accepted for signature compatibility; not used here.
        attn_internal: accepted for signature compatibility; not used here.
    """

    def __init__(self, args, obs_shape, cat_self=True, attn_internal=False):
        super(MLPBase, self).__init__()

        # Mirror the relevant config fields onto the module.
        self._use_feature_normalization = args.use_feature_normalization
        self._use_orthogonal = args.use_orthogonal
        self._use_ReLU = args.use_ReLU
        self._stacked_frames = args.stacked_frames
        self._layer_N = args.layer_N
        self.hidden_size = args.hidden_size

        input_width = obs_shape[0]

        if self._use_feature_normalization:
            self.feature_norm = nn.LayerNorm(input_width)

        self.mlp = MLPLayer(
            input_width,
            self.hidden_size,
            self._layer_N,
            self._use_orthogonal,
            self._use_ReLU,
        )

    def forward(self, x):
        """Encode ``x``; the input is layer-normalised first when enabled."""
        features = self.feature_norm(x) if self._use_feature_normalization else x
        return self.mlp(features)

class MLPBaseGPT2(nn.Module):
    """Observation encoder that swaps the MLP trunk for a small GPT-2 stack.

    The GPT-2 backbone is built from a fixed configuration (embedding width
    128, 6 layers, 8 heads) and wrapped by ``new_mlp_gpt2``.

    Args:
        args: config object; reads ``use_feature_normalization``,
            ``use_orthogonal``, ``use_ReLU``, ``stacked_frames``, ``layer_N``
            and ``hidden_size``.
        obs_shape: indexable whose first entry is the observation width.
        cat_self: accepted for signature compatibility; not used here.
        attn_internal: accepted for signature compatibility; not used here.
    """

    def __init__(self, args, obs_shape, cat_self=True, attn_internal=False):
        super(MLPBaseGPT2, self).__init__()

        self._use_feature_normalization = args.use_feature_normalization
        self._use_orthogonal = args.use_orthogonal
        self._use_ReLU = args.use_ReLU
        self._stacked_frames = args.stacked_frames
        self._layer_N = args.layer_N
        self.hidden_size = args.hidden_size

        input_width = obs_shape[0]

        if self._use_feature_normalization:
            self.feature_norm = nn.LayerNorm(input_width)

        # The same width feeds both the GPT-2 config and the input projection.
        embed_width = 128
        gpt2_config = GPT2Config(n_embd=embed_width, n_layer=6, n_head=8)
        # Third positional argument is new_mlp_gpt2's output_dim.
        self.mlp = new_mlp_gpt2(input_width, embed_width, 1, config=gpt2_config)

    def forward(self, x):
        """Encode ``x``; the input is layer-normalised first when enabled."""
        features = self.feature_norm(x) if self._use_feature_normalization else x
        return self.mlp(features)


class MLPBaseWithTrans(nn.Module):
    """Observation encoder: TransLayer tokens flattened and fed to an MLPLayer.

    The MLP input width is ``(EnvCore.AerialVehiclesNum + 3) * Feature``, i.e.
    one ``Feature``-wide token per row returned by ``TransLayer``.

    Args:
        args: config object; reads ``use_feature_normalization``,
            ``use_orthogonal``, ``use_ReLU``, ``stacked_frames``, ``layer_N``
            and ``hidden_size``.
        obs_shape: indexable whose first entry is the observation width.
        cat_self: accepted for signature compatibility; not used here.
        attn_internal: accepted for signature compatibility; not used here.
    """

    def __init__(self, args, obs_shape, cat_self=True, attn_internal=False):
        super(MLPBaseWithTrans, self).__init__()

        self._use_feature_normalization = args.use_feature_normalization
        self._use_orthogonal = args.use_orthogonal
        self._use_ReLU = args.use_ReLU
        self._stacked_frames = args.stacked_frames
        self._layer_N = args.layer_N
        self.hidden_size = args.hidden_size

        obs_dim = obs_shape[0]

        if self._use_feature_normalization:
            self.feature_norm = nn.LayerNorm(obs_dim)

        self.mlp = MLPLayer((EnvCore.AerialVehiclesNum + 3) * Feature, self.hidden_size,
                            self._layer_N, self._use_orthogonal, self._use_ReLU)

        self.transLayer = TransLayer(args, obs_shape)

    def forward(self, x):
        """Encode ``x`` and return the MLP output for the first batch entry.

        Only index 0 of the TransLayer output is consumed, so the result has
        a leading dimension of 1.
        """
        if self._use_feature_normalization:
            x = self.feature_norm(x)

        tokens = self.transLayer(x)[0]

        # Flatten row-major into a single (1, rows * Feature) vector. Doing it
        # with tensor ops (rather than copying scalars into a Python list and
        # rebuilding a tensor) keeps the autograd graph and the device intact.
        flat = tokens[:, :Feature].reshape(1, -1)

        return self.mlp(flat)



class TransLayer(nn.Module):
    """Encode observation segments into tokens and mix them with self-attention.

    The first observation row is cut into consecutive segments: one per aerial
    vehicle, then a cache segment, a relation segment and a "k" segment. Each
    segment is projected to a ``Feature``-wide token by its own ``MLPLayer2``;
    the tokens are then passed through one ``MultiHeadAttention`` layer.

    Args:
        args: config object; reads ``use_feature_normalization``,
            ``use_orthogonal``, ``use_ReLU``, ``stacked_frames``, ``layer_N``
            and ``hidden_size``.
        obs_shape: indexable whose first entry is the observation width.
        cat_self: accepted for signature compatibility; not used here.
        attn_internal: accepted for signature compatibility; not used here.
    """

    def __init__(self, args, obs_shape, cat_self=True, attn_internal=False):
        super(TransLayer, self).__init__()

        self._use_feature_normalization = args.use_feature_normalization
        self._use_orthogonal = args.use_orthogonal
        self._use_ReLU = args.use_ReLU
        self._stacked_frames = args.stacked_frames
        self._layer_N = args.layer_N
        self.hidden_size = args.hidden_size

        obs_dim = obs_shape[0]

        # Registered for parity with the other encoders; forward() below does
        # not apply it.
        if self._use_feature_normalization:
            self.feature_norm = nn.LayerNorm(obs_dim)

        self.uavLayer = MLPLayer2(EnvCore.dimension1, self.hidden_size,
                                  self._layer_N, self._use_orthogonal, self._use_ReLU)
        self.cacheLayer = MLPLayer2(EnvCore.dimension2, self.hidden_size,
                                    self._layer_N, self._use_orthogonal, self._use_ReLU)
        self.relationLayer = MLPLayer2(EnvCore.dimension3, self.hidden_size,
                                       self._layer_N, self._use_orthogonal, self._use_ReLU)
        self.kLayer = MLPLayer2(EnvCore.dimension4, self.hidden_size,
                                self._layer_N, self._use_orthogonal, self._use_ReLU)

        self.attentionLayer = MultiHeadAttention(Feature, 6, 64, 64)

    def forward(self, x):
        """Return attended tokens of shape ``[1, AerialVehiclesNum + 3, Feature]``.

        Only ``x[0]`` is read; any further rows of ``x`` are ignored.
        """
        obs = x[0]

        # NOTE(review): vehicle segments are sliced AerialVehiclesNum + 1 wide
        # while uavLayer is built for EnvCore.dimension1 inputs - presumably
        # the two are equal; confirm against EnvCore.
        uav_width = EnvCore.AerialVehiclesNum + 1
        segments = [(self.uavLayer, uav_width)] * EnvCore.AerialVehiclesNum
        segments += [
            (self.cacheLayer, EnvCore.dimension2),
            (self.relationLayer, EnvCore.dimension3),
            (self.kLayer, EnvCore.dimension4),
        ]

        encoded = []
        start = 0
        for encoder, width in segments:
            encoded.append(encoder(obs[start:start + width]))
            start += width

        # Stack with tensor ops rather than a detach().numpy() round trip: that
        # round trip cut the segment encoders out of the autograd graph and
        # raised for tensors living on an accelerator.
        tokens = torch.stack(encoded).unsqueeze(0)

        mask = get_attn_pad_mask(tokens, tokens).to(tokens.device)

        attended, _ = self.attentionLayer(tokens, tokens, tokens, mask)

        return attended



class TransBase(nn.Module):
    """Observation encoder: optional LayerNorm on the raw input, then a TransLayer.

    Args:
        args: config object; reads ``use_feature_normalization``,
            ``use_orthogonal``, ``use_ReLU``, ``stacked_frames``, ``layer_N``
            and ``hidden_size``. It is also forwarded to ``TransLayer``.
        obs_shape: indexable whose first entry is the observation width.
        cat_self: accepted for signature compatibility; not used here.
        attn_internal: accepted for signature compatibility; not used here.
    """

    def __init__(self, args, obs_shape, cat_self=True, attn_internal=False):
        super(TransBase, self).__init__()

        self._use_feature_normalization = args.use_feature_normalization
        self._use_orthogonal = args.use_orthogonal
        self._use_ReLU = args.use_ReLU
        self._stacked_frames = args.stacked_frames
        self._layer_N = args.layer_N
        self.hidden_size = args.hidden_size

        obs_dim = obs_shape[0]

        # The attribute name was previously misspelled here, which raised
        # AttributeError as soon as the class was instantiated.
        if self._use_feature_normalization:
            self.feature_norm = nn.LayerNorm(obs_dim)

        # TransLayer's constructor takes (args, obs_shape, ...), not the
        # MLPLayer-style positional arguments.
        self.trans = TransLayer(args, obs_shape)

    def forward(self, x):
        """Encode ``x``; the input is layer-normalised first when enabled."""
        if self._use_feature_normalization:
            x = self.feature_norm(x)

        x = self.trans(x)

        return x


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    Note: this definition rebinds the ``PositionalEncoding`` name imported
    from torchvision at the top of the module.

    Args:
        d_model: size of the last dimension of the input; odd values are
            supported.
        dropout: dropout probability applied after the encodings are added.
        max_len: number of positions precomputed in the ``pe`` buffer.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the frequency vector; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as [max_len, 1, d_model] so it broadcasts over dimension 1 of
        # the input in forward().
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``.

        Dimension 0 of ``x`` indexes positions and its last dimension must
        equal ``d_model``.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# Module-level attention hyper-parameters: model width, head count and the
# per-head key/value widths. They match the values TransLayer passes when it
# builds its MultiHeadAttention (Feature, 6, 64, 64).
d_model = Feature
n_heads = 6
d_k = 64
d_v = 64
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection and no bias terms.

    Args:
        d_model: width of the input and output vectors.
        n_heads: number of attention heads.
        d_k: per-head width of queries and keys.
        d_v: per-head width of values.
    """

    def __init__(self, d_model, n_heads, d_k, d_v):
        super(MultiHeadAttention, self).__init__()
        # Keep the head geometry on the instance so forward() reshapes with
        # the values this layer was built with rather than module globals.
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, seq_len, seq_len]

        Returns (output + input_Q, attn) where output is
        [batch_size, len_q, d_model] and attn is
        [batch_size, n_heads, len_q, len_k].
        '''
        n_heads, d_k, d_v = self.n_heads, self.d_k, self.d_v
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # [batch_size, n_heads, len_v, d_v]

        # Give every head its own copy of the mask.
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # [batch_size, n_heads, seq_len, seq_len]

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        # Merge the heads back: [batch_size, len_q, n_heads * d_v]
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)
        output = self.fc(context)  # [batch_size, len_q, d_model]

        return output + residual, attn


class ScaledDotProductAttention(nn.Module):
    """Parameter-free scaled dot-product attention."""

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, seq_len, seq_len]; nonzero/True marks
            positions that must not be attended to.

        Returns (context, attn): context is [batch_size, n_heads, len_q, d_v],
        attn is [batch_size, n_heads, len_q, len_k].
        '''
        # Scale by the key width actually present in Q instead of a module
        # global, so the layer is correct for any head size.
        scale = math.sqrt(Q.size(-1))
        scores = torch.matmul(Q, K.transpose(-1, -2)) / scale  # [batch_size, n_heads, len_q, len_k]
        # Masked positions get a large negative score so softmax sends them to
        # ~0. The cast accepts uint8 masks as well as bool ones.
        scores.masked_fill_(attn_mask.bool(), -1e9)

        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)  # [batch_size, n_heads, len_q, d_v]
        return context, attn


def get_attn_subsequence_mask(seq):
    '''
    Build a mask that is 1 strictly above the diagonal and 0 elsewhere.

    seq: [batch_size, tgt_len] (only the first two sizes are read)
    returns: uint8 tensor of shape [batch_size, tgt_len, tgt_len]
    '''
    batch_size, tgt_len = seq.size(0), seq.size(1)
    all_ones = torch.ones(batch_size, tgt_len, tgt_len)
    # diagonal=1 keeps only entries strictly above the main diagonal of each
    # [tgt_len, tgt_len] slice.
    return torch.triu(all_ones, diagonal=1).byte()


def get_attn_pad_mask(seq_q, seq_k):
    '''
    Build an all-False attention mask, i.e. no position is masked out.

    seq_q: [batch_size, len_q, ...]
    seq_k: [batch_size, len_k, ...]
    len_q and len_k may differ.

    returns: bool tensor of shape [batch_size, len_q, len_k] on seq_q's device.
    No padding detection is performed; the mask only supplies the shape the
    attention layer expects.
    '''
    batch_size, len_q = seq_q.size(0), seq_q.size(1)
    len_k = seq_k.size(1)
    # Rows index queries and columns index keys, matching the layout of the
    # attention scores the mask is applied to.
    return torch.zeros((batch_size, len_q, len_k), dtype=torch.bool, device=seq_q.device)

class new_mlp_gpt2(nn.Module):
    """Linear projection -> GPT-2 backbone -> linear projection.

    Args:
        input_dim: width of the incoming feature vector.
        embd_dim: width fed to GPT-2 as ``inputs_embeds``.
        output_dim: accepted but not used; the final projection is fixed at
            128 outputs.
        config: ``GPT2Config`` used to build the backbone.
    """

    def __init__(self, input_dim, embd_dim, output_dim, config):
        super(new_mlp_gpt2, self).__init__()

        self.linear1 = nn.Linear(input_dim, embd_dim)
        self.gpt2 = GPT2Model(config)
        # NOTE(review): output width is hard-coded to 128 and output_dim is
        # ignored - confirm with callers before wiring the parameter in.
        self.linear2 = nn.Linear(embd_dim, 128)

    def forward(self, x):
        """Project ``x``, run it through GPT-2 as embeddings, project the last hidden state."""
        embedded = self.linear1(x)
        hidden = self.gpt2(inputs_embeds=embedded).last_hidden_state
        return self.linear2(hidden)
