Commit ea508a3a authored by hezhiqiang01

fix error when share_policy = false

parent f2073aa3
*.pyc
results
.*
\ No newline at end of file
......@@ -36,30 +36,30 @@ class DiscreteActionEnv(object):
share_obs_dim = 0
total_action_space = []
for agent in range(self.num_agent):
for agent_idx in range(self.num_agent):
# physical action space
u_action_space = spaces.Discrete(self.signal_action_dim)  # 5 discrete actions
if self.movable:
# if self.movable:
total_action_space.append(u_action_space)
# total action space
if len(total_action_space) > 1:
# all action spaces are discrete, so simplify to MultiDiscrete action space
if all(
[
isinstance(act_space, spaces.Discrete)
for act_space in total_action_space
]
):
act_space = MultiDiscrete(
[[0, act_space.n - 1] for act_space in total_action_space]
)
else:
act_space = spaces.Tuple(total_action_space)
self.action_space.append(act_space)
else:
self.action_space.append(total_action_space[0])
# if len(total_action_space) > 1:
# # all action spaces are discrete, so simplify to MultiDiscrete action space
# if all(
# [
# isinstance(act_space, spaces.Discrete)
# for act_space in total_action_space
# ]
# ):
# act_space = MultiDiscrete(
# [[0, act_space.n - 1] for act_space in total_action_space]
# )
# else:
# act_space = spaces.Tuple(total_action_space)
# self.action_space.append(act_space)
# else:
self.action_space.append(total_action_space[agent_idx])
# observation space
share_obs_dim += self.signal_obs_dim
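Note: with share_policy = False, this hunk now gives every agent its own Discrete action space instead of collapsing them into a single MultiDiscrete. A minimal sketch of the resulting construction, assuming gym.spaces and an illustrative signal_action_dim of 5:
# Sketch only; num_agent and signal_action_dim are illustrative values.
from gym import spaces

num_agent = 2
signal_action_dim = 5

action_space = []
for agent_idx in range(num_agent):
    # each agent keeps an independent Discrete(5); no MultiDiscrete merge
    action_space.append(spaces.Discrete(signal_action_dim))

assert all(isinstance(sp, spaces.Discrete) for sp in action_space)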
......@@ -73,9 +73,7 @@ class DiscreteActionEnv(object):
) # [-inf,inf]
self.share_observation_space = [
spaces.Box(
low=-np.inf, high=+np.inf, shape=(share_obs_dim,), dtype=np.float32
)
spaces.Box(low=-np.inf, high=+np.inf, shape=(share_obs_dim,), dtype=np.float32)
for _ in range(self.num_agent)
]
......@@ -135,12 +133,7 @@ class MultiDiscrete:
"""Returns a array with one sample from each discrete action space"""
# For each row: round(random .* (max - min) + min, 0)
random_array = np.random.rand(self.num_discrete_space)
return [
int(x)
for x in np.floor(
np.multiply((self.high - self.low + 1.0), random_array) + self.low
)
]
return [int(x) for x in np.floor(np.multiply((self.high - self.low + 1.0), random_array) + self.low)]
def contains(self, x):
return (
......@@ -157,9 +150,7 @@ class MultiDiscrete:
return "MultiDiscrete" + str(self.num_discrete_space)
def __eq__(self, other):
return np.array_equal(self.low, other.low) and np.array_equal(
self.high, other.high
)
return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high)
if __name__ == "__main__":
......
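A small usage sketch of the MultiDiscrete helper touched above, assuming the constructor takes [low, high] pairs as the [[0, act_space.n - 1] ...] call suggests; the concrete values are illustrative:
# Hypothetical usage; two sub-spaces with 5 actions each.
md = MultiDiscrete([[0, 4], [0, 4]])
a = md.sample()          # e.g. [3, 1], one integer per sub-space within [low, high]
assert md.contains(a)
assert md == MultiDiscrete([[0, 4], [0, 4]])   # __eq__ compares the low/high bounds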
......@@ -46,7 +46,7 @@ class DummyVecEnv():
return obs, rews, dones, infos
def reset(self):
obs = [env.reset() for env in self.envs]
obs = [env.reset() for env in self.envs] # [env_num, agent_num, obs_dim]
return np.array(obs)
def close(self):
......
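The added comment documents the reset layout; a runnable sketch of the expected shape, using a trivial stand-in env and illustrative sizes:
import numpy as np

class _FakeEnv:
    # Stand-in for a wrapped env; reset() returns one observation per agent.
    def reset(self):
        return np.zeros((2, 14), dtype=np.float32)   # [agent_num, obs_dim]

envs = [_FakeEnv() for _ in range(8)]                # 8 parallel envs
obs = np.array([env.reset() for env in envs])
assert obs.shape == (8, 2, 14)                       # [env_num, agent_num, obs_dim]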
import time
import wandb
import os
import numpy as np
from itertools import chain
......@@ -10,17 +8,18 @@ from tensorboardX import SummaryWriter
from utils.separated_buffer import SeparatedReplayBuffer
from utils.util import update_linear_schedule
def _t2n(x):
return x.detach().cpu().numpy()
class Runner(object):
def __init__(self, config):
self.all_args = config['all_args']
self.envs = config['envs']
self.eval_envs = config['eval_envs']
self.device = config['device']
self.num_agents = config['num_agents']
self.all_args = config["all_args"]
self.envs = config["envs"]
self.eval_envs = config["eval_envs"]
self.device = config["device"]
self.num_agents = config["num_agents"]
# parameters
self.env_name = self.all_args.env_name
......@@ -34,7 +33,6 @@ class Runner(object):
self.n_eval_rollout_threads = self.all_args.n_eval_rollout_threads
self.use_linear_lr_decay = self.all_args.use_linear_lr_decay
self.hidden_size = self.all_args.hidden_size
self.use_wandb = self.all_args.use_wandb
self.use_render = self.all_args.use_render
self.recurrent_N = self.all_args.recurrent_N
......@@ -49,37 +47,42 @@ class Runner(object):
if self.use_render:
import imageio
self.run_dir = config["run_dir"]
self.gif_dir = str(self.run_dir / 'gifs')
self.gif_dir = str(self.run_dir / "gifs")
if not os.path.exists(self.gif_dir):
os.makedirs(self.gif_dir)
else:
if self.use_wandb:
self.save_dir = str(wandb.run.dir)
else:
# if self.use_wandb:
# self.save_dir = str(wandb.run.dir)
# else:
self.run_dir = config["run_dir"]
self.log_dir = str(self.run_dir / 'logs')
self.log_dir = str(self.run_dir / "logs")
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
self.writter = SummaryWriter(self.log_dir)
self.save_dir = str(self.run_dir / 'models')
self.save_dir = str(self.run_dir / "models")
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
from algorithms.algorithm.r_mappo import RMAPPO as TrainAlgo
from algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy as Policy
self.policy = []
for agent_id in range(self.num_agents):
share_observation_space = self.envs.share_observation_space[agent_id] if self.use_centralized_V else self.envs.observation_space[agent_id]
share_observation_space = (
self.envs.share_observation_space[agent_id]
if self.use_centralized_V
else self.envs.observation_space[agent_id]
)
# policy network
po = Policy(self.all_args,
po = Policy(
self.all_args,
self.envs.observation_space[agent_id],
share_observation_space,
self.envs.action_space[agent_id],
device = self.device)
device=self.device,
)
self.policy.append(po)
if self.model_dir is not None:
......@@ -89,13 +92,19 @@ class Runner(object):
self.buffer = []
for agent_id in range(self.num_agents):
# algorithm
tr = TrainAlgo(self.all_args, self.policy[agent_id], device = self.device)
tr = TrainAlgo(self.all_args, self.policy[agent_id], device=self.device)
# buffer
share_observation_space = self.envs.share_observation_space[agent_id] if self.use_centralized_V else self.envs.observation_space[agent_id]
bu = SeparatedReplayBuffer(self.all_args,
share_observation_space = (
self.envs.share_observation_space[agent_id]
if self.use_centralized_V
else self.envs.observation_space[agent_id]
)
bu = SeparatedReplayBuffer(
self.all_args,
self.envs.observation_space[agent_id],
share_observation_space,
self.envs.action_space[agent_id])
self.envs.action_space[agent_id],
)
self.buffer.append(bu)
self.trainer.append(tr)
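These loops are the core of the share_policy = False path: one policy, one trainer and one replay buffer per agent, each built from that agent's own spaces. A stripped-down sketch of the pattern with trivial stand-ins for Policy, TrainAlgo and SeparatedReplayBuffer (the real classes live in algorithms/ and utils/):
# Sketch only; the stand-in classes just record their arguments.
class Policy:
    def __init__(self, args, obs_space, share_obs_space, act_space, device="cpu"):
        self.spaces = (obs_space, share_obs_space, act_space)

class TrainAlgo:
    def __init__(self, args, policy, device="cpu"):
        self.policy = policy

class SeparatedReplayBuffer:
    def __init__(self, args, obs_space, share_obs_space, act_space):
        self.spaces = (obs_space, share_obs_space, act_space)

def build_per_agent(args, envs, num_agents, use_centralized_V, device="cpu"):
    policies, trainers, buffers = [], [], []
    for agent_id in range(num_agents):
        share_space = (envs.share_observation_space[agent_id]
                       if use_centralized_V else envs.observation_space[agent_id])
        po = Policy(args, envs.observation_space[agent_id], share_space,
                    envs.action_space[agent_id], device=device)
        policies.append(po)
        trainers.append(TrainAlgo(args, po, device=device))
        buffers.append(SeparatedReplayBuffer(args, envs.observation_space[agent_id],
                                             share_space, envs.action_space[agent_id]))
    return policies, trainers, buffers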
......@@ -115,9 +124,11 @@ class Runner(object):
def compute(self):
for agent_id in range(self.num_agents):
self.trainer[agent_id].prep_rollout()
next_value = self.trainer[agent_id].policy.get_values(self.buffer[agent_id].share_obs[-1],
next_value = self.trainer[agent_id].policy.get_values(
self.buffer[agent_id].share_obs[-1],
self.buffer[agent_id].rnn_states_critic[-1],
self.buffer[agent_id].masks[-1])
self.buffer[agent_id].masks[-1],
)
next_value = _t2n(next_value)
self.buffer[agent_id].compute_returns(next_value, self.trainer[agent_id].value_normalizer)
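For reference, a hedged sketch of what compute() hands to the buffer: the critic value for the last stored share_obs, detached to numpy via _t2n and then used to bootstrap the returns. The fake critic and tensor shapes below are illustrative, not the project's real get_values:
import torch

def _t2n(x):
    # Detach a torch tensor and convert it to a numpy array, as in the Runner.
    return x.detach().cpu().numpy()

def fake_get_values(share_obs, rnn_states_critic, masks):
    # Stand-in for policy.get_values(...); returns one value per rollout thread.
    return torch.zeros((share_obs.shape[0], 1))

share_obs = torch.zeros((8, 28))          # [n_rollout_threads, share_obs_dim]
rnn_states_critic = torch.zeros((8, 1, 64))
masks = torch.ones((8, 1))

next_value = _t2n(fake_get_values(share_obs, rnn_states_critic, masks))
assert next_value.shape == (8, 1)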
......@@ -134,30 +145,39 @@ class Runner(object):
def save(self):
for agent_id in range(self.num_agents):
policy_actor = self.trainer[agent_id].policy.actor
torch.save(policy_actor.state_dict(), str(self.save_dir) + "/actor_agent" + str(agent_id) + ".pt")
torch.save(
policy_actor.state_dict(),
str(self.save_dir) + "/actor_agent" + str(agent_id) + ".pt",
)
policy_critic = self.trainer[agent_id].policy.critic
torch.save(policy_critic.state_dict(), str(self.save_dir) + "/critic_agent" + str(agent_id) + ".pt")
torch.save(
policy_critic.state_dict(),
str(self.save_dir) + "/critic_agent" + str(agent_id) + ".pt",
)
def restore(self):
for agent_id in range(self.num_agents):
policy_actor_state_dict = torch.load(str(self.model_dir) + '/actor_agent' + str(agent_id) + '.pt')
policy_actor_state_dict = torch.load(str(self.model_dir) + "/actor_agent" + str(agent_id) + ".pt")
self.policy[agent_id].actor.load_state_dict(policy_actor_state_dict)
policy_critic_state_dict = torch.load(str(self.model_dir) + '/critic_agent' + str(agent_id) + '.pt')
policy_critic_state_dict = torch.load(
str(self.model_dir) + "/critic_agent" + str(agent_id) + ".pt"
)
self.policy[agent_id].critic.load_state_dict(policy_critic_state_dict)
def log_train(self, train_infos, total_num_steps):
for agent_id in range(self.num_agents):
for k, v in train_infos[agent_id].items():
agent_k = "agent%i/" % agent_id + k
if self.use_wandb:
wandb.log({agent_k: v}, step=total_num_steps)
else:
# if self.use_wandb:
# pass
# wandb.log({agent_k: v}, step=total_num_steps)
# else:
self.writter.add_scalars(agent_k, {agent_k: v}, total_num_steps)
def log_env(self, env_infos, total_num_steps):
for k, v in env_infos.items():
if len(v) > 0:
if self.use_wandb:
wandb.log({k: np.mean(v)}, step=total_num_steps)
else:
# if self.use_wandb:
# wandb.log({k: np.mean(v)}, step=total_num_steps)
# else:
self.writter.add_scalars(k, {k: np.mean(v)}, total_num_steps)
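With the wandb branches commented out, all scalars now go through tensorboardX only. A minimal sketch of the logging pattern used in log_train, with an illustrative log directory and values:
from tensorboardX import SummaryWriter

writter = SummaryWriter("./logs_demo")                      # illustrative log dir
train_infos = [{"value_loss": 0.10}, {"value_loss": 0.12}]  # one dict per agent
total_num_steps = 1000

for agent_id, info in enumerate(train_infos):
    for k, v in info.items():
        agent_k = "agent%i/" % agent_id + k
        writter.add_scalars(agent_k, {agent_k: v}, total_num_steps)
writter.close()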
......@@ -31,11 +31,15 @@ def make_train_env(all_args):
def init_env():
# TODO: Important: choose a continuous or discrete action space here by commenting/uncommenting the two lines above or the two lines below.
from envs.env_continuous import ContinuousActionEnv
env = ContinuousActionEnv()
# from envs.env_discrete import DiscreteActionEnv
# env = DiscreteActionEnv()
env.seed(all_args.seed + rank * 1000)
return env
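A hedged sketch of the factory pattern this hunk configures: each rollout thread builds its own env (continuous by default, discrete if you swap the commented pair) and gets an offset seed so threads do not share a random stream. The get_env_fn wrapper below is illustrative, not the project's exact code:
def make_train_env(all_args):
    def get_env_fn(rank):
        def init_env():
            from envs.env_continuous import ContinuousActionEnv
            env = ContinuousActionEnv()
            # from envs.env_discrete import DiscreteActionEnv
            # env = DiscreteActionEnv()
            env.seed(all_args.seed + rank * 1000)  # distinct seed per rollout thread
            return env
        return init_env
    return [get_env_fn(i) for i in range(all_args.n_rollout_threads)]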
......@@ -63,9 +67,7 @@ def make_eval_env(all_args):
def parse_args(args, parser):
parser.add_argument(
"--scenario_name", type=str, default="MyEnv", help="Which scenario to run on"
)
parser.add_argument("--scenario_name", type=str, default="MyEnv", help="Which scenario to run on")
parser.add_argument("--num_landmarks", type=int, default=3)
parser.add_argument("--num_agents", type=int, default=2, help="number of players")
......@@ -79,20 +81,16 @@ def main(args):
all_args = parse_args(args, parser)
if all_args.algorithm_name == "rmappo":
assert (
all_args.use_recurrent_policy or all_args.use_naive_recurrent_policy
), "check recurrent policy!"
assert all_args.use_recurrent_policy or all_args.use_naive_recurrent_policy, "check recurrent policy!"
elif all_args.algorithm_name == "mappo":
assert (
all_args.use_recurrent_policy == False
and all_args.use_naive_recurrent_policy == False
all_args.use_recurrent_policy == False and all_args.use_naive_recurrent_policy == False
), "check recurrent policy!"
else:
raise NotImplementedError
assert (
all_args.share_policy == True
and all_args.scenario_name == "simple_speaker_listener"
all_args.share_policy == True and all_args.scenario_name == "simple_speaker_listener"
) == False, "The simple_speaker_listener scenario can not use shared policy. Please check the config.py."
# cuda
......
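The reflowed assertions encode three configuration rules: rmappo needs a recurrent policy, mappo must not use one, and simple_speaker_listener cannot be combined with a shared policy. A compact, runnable restatement of the same checks (the args object below is a stand-in):
from types import SimpleNamespace

def check_config(all_args):
    if all_args.algorithm_name == "rmappo":
        assert all_args.use_recurrent_policy or all_args.use_naive_recurrent_policy, \
            "check recurrent policy!"
    elif all_args.algorithm_name == "mappo":
        assert not (all_args.use_recurrent_policy or all_args.use_naive_recurrent_policy), \
            "check recurrent policy!"
    else:
        raise NotImplementedError
    assert not (all_args.share_policy and all_args.scenario_name == "simple_speaker_listener"), \
        "The simple_speaker_listener scenario cannot use a shared policy."

check_config(SimpleNamespace(algorithm_name="mappo", use_recurrent_policy=False,
                             use_naive_recurrent_policy=False,
                             share_policy=False, scenario_name="MyEnv"))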