import gym
from gym import spaces
import numpy as np
from envs.env_core import EnvCore
from envs.env_discrete import MultiDiscrete


class ContinuousActionEnv(object):
    """
    对于连续动作环境的封装
    Wrapper for continuous action environment.
    """

    def __init__(self):
        self.env = EnvCore()
        self.num_agent = self.env.AerialVehiclesNum + 2

        self.signal_obs_dim = self.env.obs_dim
        self.signal_action_dim = self.env.action_dim
        print( self.num_agent)

        # if true, action is a number 0...N, otherwise action is a one-hot N-dimensional vector
        self.discrete_action_input = False

        self.movable = True

        # configure spaces
        self.action_space = []
        self.observation_space = []
        self.share_observation_space = []

        share_obs_dim = 0
        total_action_space = []

        for agent in range(self.env.AerialVehiclesNum):
            # physical action space
            u_action_space = spaces.Box(
                low=-np.inf,
                high=+np.inf,
                shape=(self.signal_action_dim,),
                dtype=np.float32,
            )

            if self.movable:
                total_action_space.append(u_action_space)

            # total action space
            self.action_space.append(total_action_space[0])

            # observation space
            share_obs_dim += self.signal_obs_dim
            self.observation_space.append(
                spaces.Box(
                    low=-np.inf,
                    high=+np.inf,
                    shape=(self.signal_obs_dim,),
                    dtype=np.float32,
                )
            )  # [-inf,inf]
        print('section1')

        mu_action = []
        for i in range(self.env.CustomerNum * self.env.AerialVehiclesNum):
           mu_action.append([0, 1])
        u_action_space = MultiDiscrete(mu_action)

        if self.movable:
            total_action_space.append(u_action_space)

        # total action space
        self.action_space.append(u_action_space)

        # observation space
        share_obs_dim += self.env.getDimension2()
        self.observation_space.append(
            spaces.Box(
                low=-np.inf,
                high=+np.inf,
                shape=(self.env.getDimension2(),),
                dtype=np.float32,
            )
        )

        print('section2')


        mu_action = []

        mu_action.append([0, self.env.ContentNum - 1])
        u_action_space = spaces.Discrete(self.env.ContentNum)

        if self.movable:
            total_action_space.append(u_action_space)

            # total action space
        self.action_space.append(u_action_space)

        # observation space
        share_obs_dim += (self.env.CacheNum + self.env.CustomerNum)
        self.observation_space.append(
            spaces.Box(
                low=-np.inf,
                high=+np.inf,
                shape=(self.env.CacheNum + self.env.CustomerNum,),
                dtype=np.float32,
            )
        )




        print("ons")
        print(share_obs_dim)
        self.share_observation_space = [
            spaces.Box(
                low=-np.inf, high=+np.inf, shape=(share_obs_dim,), dtype=np.float32
            )
            for _ in range(self.num_agent)
        ]


    def step(self, actions):
        """
        输入actions维度假设：
        # actions shape = (5, 2, 5)
        # 5个线程的环境，里面有2个智能体，每个智能体的动作是一个one_hot的5维编码

        Input actions dimension assumption:
        # actions shape = (5, 2, 5)
        # 5 threads of environment, there are 2 agents inside, and each agent's action is a 5-dimensional one_hot encoding
        """

        results = self.env.step(actions)
        obs, rews, dones, infos = results
        return obs, np.stack(rews), np.stack(dones), infos

    def reset(self):
        obs = self.env.reset()
        return obs

    def close(self):
        pass

    def render(self, mode="rgb_array"):
        pass

    def seed(self, seed):
        pass
