Commit ea508a3a authored by hezhiqiang01

fix the error when share_policy = false

parent f2073aa3
*.pyc *.pyc
results results
.*
\ No newline at end of file
...@@ -36,30 +36,30 @@ class DiscreteActionEnv(object): ...@@ -36,30 +36,30 @@ class DiscreteActionEnv(object):
share_obs_dim = 0 share_obs_dim = 0
total_action_space = [] total_action_space = []
for agent in range(self.num_agent): for agent_idx in range(self.num_agent):
# physical action space # physical action space
u_action_space = spaces.Discrete(self.signal_action_dim) # 5 discrete actions u_action_space = spaces.Discrete(self.signal_action_dim) # 5 discrete actions
if self.movable: # if self.movable:
total_action_space.append(u_action_space) total_action_space.append(u_action_space)
# total action space # total action space
if len(total_action_space) > 1: # if len(total_action_space) > 1:
# all action spaces are discrete, so simplify to MultiDiscrete action space # # all action spaces are discrete, so simplify to MultiDiscrete action space
if all( # if all(
[ # [
isinstance(act_space, spaces.Discrete) # isinstance(act_space, spaces.Discrete)
for act_space in total_action_space # for act_space in total_action_space
] # ]
): # ):
act_space = MultiDiscrete( # act_space = MultiDiscrete(
[[0, act_space.n - 1] for act_space in total_action_space] # [[0, act_space.n - 1] for act_space in total_action_space]
) # )
else: # else:
act_space = spaces.Tuple(total_action_space) # act_space = spaces.Tuple(total_action_space)
self.action_space.append(act_space) # self.action_space.append(act_space)
else: # else:
self.action_space.append(total_action_space[0]) self.action_space.append(total_action_space[agent_idx])
# observation space # observation space
share_obs_dim += self.signal_obs_dim share_obs_dim += self.signal_obs_dim
...@@ -73,9 +73,7 @@ class DiscreteActionEnv(object): ...@@ -73,9 +73,7 @@ class DiscreteActionEnv(object):
) # [-inf,inf] ) # [-inf,inf]
self.share_observation_space = [ self.share_observation_space = [
spaces.Box( spaces.Box(low=-np.inf, high=+np.inf, shape=(share_obs_dim,), dtype=np.float32)
low=-np.inf, high=+np.inf, shape=(share_obs_dim,), dtype=np.float32
)
for _ in range(self.num_agent) for _ in range(self.num_agent)
] ]
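For reference, a minimal standalone sketch (not part of this commit; example sizes assumed: 2 agents, 5 actions, observation dim 14) of the per-agent spaces the updated DiscreteActionEnv now builds: one Discrete action space per agent, plus per-agent observation Boxes and a shared-observation Box whose size accumulates signal_obs_dim over agents. It assumes gym and numpy are installed.

import numpy as np
from gym import spaces

num_agent, signal_action_dim, signal_obs_dim = 2, 5, 14  # assumed example values
action_space = [spaces.Discrete(signal_action_dim) for _ in range(num_agent)]
observation_space = [
    spaces.Box(low=-np.inf, high=+np.inf, shape=(signal_obs_dim,), dtype=np.float32)
    for _ in range(num_agent)
]
share_obs_dim = signal_obs_dim * num_agent  # accumulated over agents in the loop above
share_observation_space = [
    spaces.Box(low=-np.inf, high=+np.inf, shape=(share_obs_dim,), dtype=np.float32)
    for _ in range(num_agent)
]
print(action_space[0].n, observation_space[0].shape, share_observation_space[0].shape)  # 5 (14,) (28,)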
...@@ -135,12 +133,7 @@ class MultiDiscrete: ...@@ -135,12 +133,7 @@ class MultiDiscrete:
"""Returns a array with one sample from each discrete action space""" """Returns a array with one sample from each discrete action space"""
# For each row: round(random .* (max - min) + min, 0) # For each row: round(random .* (max - min) + min, 0)
random_array = np.random.rand(self.num_discrete_space) random_array = np.random.rand(self.num_discrete_space)
return [ return [int(x) for x in np.floor(np.multiply((self.high - self.low + 1.0), random_array) + self.low)]
int(x)
for x in np.floor(
np.multiply((self.high - self.low + 1.0), random_array) + self.low
)
]
def contains(self, x): def contains(self, x):
return ( return (
...@@ -157,9 +150,7 @@ class MultiDiscrete: ...@@ -157,9 +150,7 @@ class MultiDiscrete:
return "MultiDiscrete" + str(self.num_discrete_space) return "MultiDiscrete" + str(self.num_discrete_space)
def __eq__(self, other): def __eq__(self, other):
return np.array_equal(self.low, other.low) and np.array_equal( return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high)
self.high, other.high
)
if __name__ == "__main__": if __name__ == "__main__":
......
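For reference, the reformatted sample() above draws one uniform integer per dimension via floor(rand * (high - low + 1) + low). A standalone sketch with assumed per-dimension bounds:

import numpy as np

low = np.array([0, 0, 0])    # assumed per-dimension minimums
high = np.array([4, 4, 4])   # assumed per-dimension maximums (inclusive)
random_array = np.random.rand(len(low))
sample = [int(x) for x in np.floor((high - low + 1.0) * random_array + low)]
print(sample)  # e.g. [3, 0, 2] -- each entry lies in [low, high]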
...@@ -46,7 +46,7 @@ class DummyVecEnv(): ...@@ -46,7 +46,7 @@ class DummyVecEnv():
return obs, rews, dones, infos return obs, rews, dones, infos
def reset(self): def reset(self):
obs = [env.reset() for env in self.envs] obs = [env.reset() for env in self.envs] # [env_num, agent_num, obs_dim]
return np.array(obs) return np.array(obs)
def close(self): def close(self):
......
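The comment added to reset() documents the stacked shape it returns. A toy illustration (sizes assumed, zeros standing in for real observations) of that stacking:

import numpy as np

env_num, agent_num, obs_dim = 3, 2, 14  # assumed example sizes
obs = [np.zeros((agent_num, obs_dim)) for _ in range(env_num)]  # one entry per env.reset()
stacked = np.array(obs)
print(stacked.shape)  # (3, 2, 14) -> [env_num, agent_num, obs_dim]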
import time import time
import wandb
import os import os
import numpy as np import numpy as np
from itertools import chain from itertools import chain
...@@ -10,17 +8,18 @@ from tensorboardX import SummaryWriter ...@@ -10,17 +8,18 @@ from tensorboardX import SummaryWriter
from utils.separated_buffer import SeparatedReplayBuffer from utils.separated_buffer import SeparatedReplayBuffer
from utils.util import update_linear_schedule from utils.util import update_linear_schedule
def _t2n(x): def _t2n(x):
return x.detach().cpu().numpy() return x.detach().cpu().numpy()
class Runner(object): class Runner(object):
def __init__(self, config): def __init__(self, config):
self.all_args = config["all_args"]
self.all_args = config['all_args'] self.envs = config["envs"]
self.envs = config['envs'] self.eval_envs = config["eval_envs"]
self.eval_envs = config['eval_envs'] self.device = config["device"]
self.device = config['device'] self.num_agents = config["num_agents"]
self.num_agents = config['num_agents']
# parameters # parameters
self.env_name = self.all_args.env_name self.env_name = self.all_args.env_name
...@@ -34,7 +33,6 @@ class Runner(object): ...@@ -34,7 +33,6 @@ class Runner(object):
self.n_eval_rollout_threads = self.all_args.n_eval_rollout_threads self.n_eval_rollout_threads = self.all_args.n_eval_rollout_threads
self.use_linear_lr_decay = self.all_args.use_linear_lr_decay self.use_linear_lr_decay = self.all_args.use_linear_lr_decay
self.hidden_size = self.all_args.hidden_size self.hidden_size = self.all_args.hidden_size
self.use_wandb = self.all_args.use_wandb
self.use_render = self.all_args.use_render self.use_render = self.all_args.use_render
self.recurrent_N = self.all_args.recurrent_N self.recurrent_N = self.all_args.recurrent_N
...@@ -49,37 +47,42 @@ class Runner(object): ...@@ -49,37 +47,42 @@ class Runner(object):
if self.use_render: if self.use_render:
import imageio import imageio
self.run_dir = config["run_dir"] self.run_dir = config["run_dir"]
self.gif_dir = str(self.run_dir / 'gifs') self.gif_dir = str(self.run_dir / "gifs")
if not os.path.exists(self.gif_dir): if not os.path.exists(self.gif_dir):
os.makedirs(self.gif_dir) os.makedirs(self.gif_dir)
else: else:
if self.use_wandb: # if self.use_wandb:
self.save_dir = str(wandb.run.dir) # self.save_dir = str(wandb.run.dir)
else: # else:
self.run_dir = config["run_dir"] self.run_dir = config["run_dir"]
self.log_dir = str(self.run_dir / 'logs') self.log_dir = str(self.run_dir / "logs")
if not os.path.exists(self.log_dir): if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir) os.makedirs(self.log_dir)
self.writter = SummaryWriter(self.log_dir) self.writter = SummaryWriter(self.log_dir)
self.save_dir = str(self.run_dir / 'models') self.save_dir = str(self.run_dir / "models")
if not os.path.exists(self.save_dir): if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir) os.makedirs(self.save_dir)
from algorithms.algorithm.r_mappo import RMAPPO as TrainAlgo from algorithms.algorithm.r_mappo import RMAPPO as TrainAlgo
from algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy as Policy from algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy as Policy
self.policy = [] self.policy = []
for agent_id in range(self.num_agents): for agent_id in range(self.num_agents):
share_observation_space = self.envs.share_observation_space[agent_id] if self.use_centralized_V else self.envs.observation_space[agent_id] share_observation_space = (
self.envs.share_observation_space[agent_id]
if self.use_centralized_V
else self.envs.observation_space[agent_id]
)
# policy network # policy network
po = Policy(self.all_args, po = Policy(
self.all_args,
self.envs.observation_space[agent_id], self.envs.observation_space[agent_id],
share_observation_space, share_observation_space,
self.envs.action_space[agent_id], self.envs.action_space[agent_id],
device = self.device) device=self.device,
)
self.policy.append(po) self.policy.append(po)
if self.model_dir is not None: if self.model_dir is not None:
...@@ -89,13 +92,19 @@ class Runner(object): ...@@ -89,13 +92,19 @@ class Runner(object):
self.buffer = [] self.buffer = []
for agent_id in range(self.num_agents): for agent_id in range(self.num_agents):
# algorithm # algorithm
tr = TrainAlgo(self.all_args, self.policy[agent_id], device = self.device) tr = TrainAlgo(self.all_args, self.policy[agent_id], device=self.device)
# buffer # buffer
share_observation_space = self.envs.share_observation_space[agent_id] if self.use_centralized_V else self.envs.observation_space[agent_id] share_observation_space = (
bu = SeparatedReplayBuffer(self.all_args, self.envs.share_observation_space[agent_id]
if self.use_centralized_V
else self.envs.observation_space[agent_id]
)
bu = SeparatedReplayBuffer(
self.all_args,
self.envs.observation_space[agent_id], self.envs.observation_space[agent_id],
share_observation_space, share_observation_space,
self.envs.action_space[agent_id]) self.envs.action_space[agent_id],
)
self.buffer.append(bu) self.buffer.append(bu)
self.trainer.append(tr) self.trainer.append(tr)
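Since share_policy = False means separated networks, the runner keeps one policy, one trainer and one buffer per agent, all indexed by agent_id. A hedged standalone sketch of that pattern, with stub classes standing in for RMAPPOPolicy, RMAPPO and SeparatedReplayBuffer (placeholder strings stand in for the real spaces):

class StubPolicy:      # stand-in for RMAPPOPolicy
    def __init__(self, obs_space, share_obs_space, act_space):
        self.spaces = (obs_space, share_obs_space, act_space)

class StubTrainer:     # stand-in for RMAPPO
    def __init__(self, policy):
        self.policy = policy

class StubBuffer:      # stand-in for SeparatedReplayBuffer
    def __init__(self, obs_space, share_obs_space, act_space):
        self.spaces = (obs_space, share_obs_space, act_space)

num_agents = 2
obs_spaces = ["obs0", "obs1"]        # placeholder per-agent observation spaces
share_spaces = ["share0", "share1"]  # placeholder shared-observation spaces
act_spaces = ["act0", "act1"]        # placeholder per-agent action spaces
use_centralized_V = True

policy, trainer, buffer = [], [], []
for agent_id in range(num_agents):
    # centralized critics see the shared observation, otherwise fall back to the agent's own
    share_space = share_spaces[agent_id] if use_centralized_V else obs_spaces[agent_id]
    policy.append(StubPolicy(obs_spaces[agent_id], share_space, act_spaces[agent_id]))
    trainer.append(StubTrainer(policy[agent_id]))
    buffer.append(StubBuffer(obs_spaces[agent_id], share_space, act_spaces[agent_id]))

assert len(policy) == len(trainer) == len(buffer) == num_agents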
...@@ -115,9 +124,11 @@ class Runner(object): ...@@ -115,9 +124,11 @@ class Runner(object):
def compute(self): def compute(self):
for agent_id in range(self.num_agents): for agent_id in range(self.num_agents):
self.trainer[agent_id].prep_rollout() self.trainer[agent_id].prep_rollout()
next_value = self.trainer[agent_id].policy.get_values(self.buffer[agent_id].share_obs[-1], next_value = self.trainer[agent_id].policy.get_values(
self.buffer[agent_id].share_obs[-1],
self.buffer[agent_id].rnn_states_critic[-1], self.buffer[agent_id].rnn_states_critic[-1],
self.buffer[agent_id].masks[-1]) self.buffer[agent_id].masks[-1],
)
next_value = _t2n(next_value) next_value = _t2n(next_value)
self.buffer[agent_id].compute_returns(next_value, self.trainer[agent_id].value_normalizer) self.buffer[agent_id].compute_returns(next_value, self.trainer[agent_id].value_normalizer)
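The _t2n helper detaches the bootstrap value from the graph and moves it to CPU/NumPy before it is stored in the per-agent buffer. A small runnable illustration (toy tensor; assumes torch is installed):

import torch

def _t2n(x):
    return x.detach().cpu().numpy()

next_value = torch.randn(4, 1, requires_grad=True)  # toy stand-in for policy.get_values(...)
print(_t2n(next_value).shape)  # (4, 1) -- a plain NumPy array, safe to write into the buffer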
...@@ -134,30 +145,39 @@ class Runner(object): ...@@ -134,30 +145,39 @@ class Runner(object):
def save(self): def save(self):
for agent_id in range(self.num_agents): for agent_id in range(self.num_agents):
policy_actor = self.trainer[agent_id].policy.actor policy_actor = self.trainer[agent_id].policy.actor
torch.save(policy_actor.state_dict(), str(self.save_dir) + "/actor_agent" + str(agent_id) + ".pt") torch.save(
policy_actor.state_dict(),
str(self.save_dir) + "/actor_agent" + str(agent_id) + ".pt",
)
policy_critic = self.trainer[agent_id].policy.critic policy_critic = self.trainer[agent_id].policy.critic
torch.save(policy_critic.state_dict(), str(self.save_dir) + "/critic_agent" + str(agent_id) + ".pt") torch.save(
policy_critic.state_dict(),
str(self.save_dir) + "/critic_agent" + str(agent_id) + ".pt",
)
def restore(self): def restore(self):
for agent_id in range(self.num_agents): for agent_id in range(self.num_agents):
policy_actor_state_dict = torch.load(str(self.model_dir) + '/actor_agent' + str(agent_id) + '.pt') policy_actor_state_dict = torch.load(str(self.model_dir) + "/actor_agent" + str(agent_id) + ".pt")
self.policy[agent_id].actor.load_state_dict(policy_actor_state_dict) self.policy[agent_id].actor.load_state_dict(policy_actor_state_dict)
policy_critic_state_dict = torch.load(str(self.model_dir) + '/critic_agent' + str(agent_id) + '.pt') policy_critic_state_dict = torch.load(
str(self.model_dir) + "/critic_agent" + str(agent_id) + ".pt"
)
self.policy[agent_id].critic.load_state_dict(policy_critic_state_dict) self.policy[agent_id].critic.load_state_dict(policy_critic_state_dict)
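save() and restore() use per-agent checkpoint names of the form actor_agentN.pt / critic_agentN.pt. A hedged sketch with a toy module and an assumed ./models directory (the runner actually uses run_dir / "models"):

import os
import torch
import torch.nn as nn

save_dir = "./models"   # assumed path for this sketch
os.makedirs(save_dir, exist_ok=True)
net = nn.Linear(4, 2)   # toy stand-in for policy_actor / policy_critic
agent_id = 0
torch.save(net.state_dict(), save_dir + "/actor_agent" + str(agent_id) + ".pt")        # mirrors save()
net.load_state_dict(torch.load(save_dir + "/actor_agent" + str(agent_id) + ".pt"))     # mirrors restore()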
def log_train(self, train_infos, total_num_steps): def log_train(self, train_infos, total_num_steps):
for agent_id in range(self.num_agents): for agent_id in range(self.num_agents):
for k, v in train_infos[agent_id].items(): for k, v in train_infos[agent_id].items():
agent_k = "agent%i/" % agent_id + k agent_k = "agent%i/" % agent_id + k
if self.use_wandb: # if self.use_wandb:
wandb.log({agent_k: v}, step=total_num_steps) # pass
else: # wandb.log({agent_k: v}, step=total_num_steps)
# else:
self.writter.add_scalars(agent_k, {agent_k: v}, total_num_steps) self.writter.add_scalars(agent_k, {agent_k: v}, total_num_steps)
def log_env(self, env_infos, total_num_steps): def log_env(self, env_infos, total_num_steps):
for k, v in env_infos.items(): for k, v in env_infos.items():
if len(v) > 0: if len(v) > 0:
if self.use_wandb: # if self.use_wandb:
wandb.log({k: np.mean(v)}, step=total_num_steps) # wandb.log({k: np.mean(v)}, step=total_num_steps)
else: # else:
self.writter.add_scalars(k, {k: np.mean(v)}, total_num_steps) self.writter.add_scalars(k, {k: np.mean(v)}, total_num_steps)
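With wandb removed, train and env metrics go through tensorboardX only. A minimal runnable sketch (toy values; the "agent%i/" key format mirrors log_train above, and "./logs" stands in for run_dir / "logs"):

from tensorboardX import SummaryWriter

writter = SummaryWriter("./logs")                          # assumed log directory
train_infos = [{"value_loss": 0.12, "policy_loss": 0.03}]  # toy per-agent metrics
total_num_steps = 1000
for agent_id, info in enumerate(train_infos):
    for k, v in info.items():
        agent_k = "agent%i/" % agent_id + k
        writter.add_scalars(agent_k, {agent_k: v}, total_num_steps)
writter.close()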
...@@ -31,11 +31,15 @@ def make_train_env(all_args): ...@@ -31,11 +31,15 @@ def make_train_env(all_args):
def init_env(): def init_env():
# TODO Note: choose continuous or discrete here by commenting/uncommenting the two lines above or the two lines below. # TODO Note: choose continuous or discrete here by commenting/uncommenting the two lines above or the two lines below.
# TODO Important, here you can choose continuous or discrete action space by uncommenting the above two lines or the below two lines. # TODO Important, here you can choose continuous or discrete action space by uncommenting the above two lines or the below two lines.
from envs.env_continuous import ContinuousActionEnv from envs.env_continuous import ContinuousActionEnv
env = ContinuousActionEnv() env = ContinuousActionEnv()
# from envs.env_discrete import DiscreteActionEnv # from envs.env_discrete import DiscreteActionEnv
# env = DiscreteActionEnv() # env = DiscreteActionEnv()
env.seed(all_args.seed + rank * 1000) env.seed(all_args.seed + rank * 1000)
return env return env
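init_env seeds each rollout thread with seed + rank * 1000 so parallel envs do not share a random stream. A standalone sketch with a stub environment (the real code picks ContinuousActionEnv or DiscreteActionEnv by swapping the import lines):

class StubEnv:  # stand-in for ContinuousActionEnv / DiscreteActionEnv
    def seed(self, s):
        self.s = s

def make_init_env(seed, rank):
    def init_env():
        env = StubEnv()
        env.seed(seed + rank * 1000)  # each rollout thread gets a distinct seed
        return env
    return init_env

envs = [make_init_env(seed=1, rank=r)() for r in range(2)]
print([e.s for e in envs])  # [1, 1001]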
...@@ -63,9 +67,7 @@ def make_eval_env(all_args): ...@@ -63,9 +67,7 @@ def make_eval_env(all_args):
def parse_args(args, parser): def parse_args(args, parser):
parser.add_argument( parser.add_argument("--scenario_name", type=str, default="MyEnv", help="Which scenario to run on")
"--scenario_name", type=str, default="MyEnv", help="Which scenario to run on"
)
parser.add_argument("--num_landmarks", type=int, default=3) parser.add_argument("--num_landmarks", type=int, default=3)
parser.add_argument("--num_agents", type=int, default=2, help="number of players") parser.add_argument("--num_agents", type=int, default=2, help="number of players")
...@@ -79,20 +81,16 @@ def main(args): ...@@ -79,20 +81,16 @@ def main(args):
all_args = parse_args(args, parser) all_args = parse_args(args, parser)
if all_args.algorithm_name == "rmappo": if all_args.algorithm_name == "rmappo":
assert ( assert all_args.use_recurrent_policy or all_args.use_naive_recurrent_policy, "check recurrent policy!"
all_args.use_recurrent_policy or all_args.use_naive_recurrent_policy
), "check recurrent policy!"
elif all_args.algorithm_name == "mappo": elif all_args.algorithm_name == "mappo":
assert ( assert (
all_args.use_recurrent_policy == False all_args.use_recurrent_policy == False and all_args.use_naive_recurrent_policy == False
and all_args.use_naive_recurrent_policy == False
), "check recurrent policy!" ), "check recurrent policy!"
else: else:
raise NotImplementedError raise NotImplementedError
assert ( assert (
all_args.share_policy == True all_args.share_policy == True and all_args.scenario_name == "simple_speaker_listener"
and all_args.scenario_name == "simple_speaker_listener"
) == False, "The simple_speaker_listener scenario can not use shared policy. Please check the config.py." ) == False, "The simple_speaker_listener scenario can not use shared policy. Please check the config.py."
# cuda # cuda
......
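For reference, a runnable sketch of the reformatted configuration checks above, including the share_policy guard this commit concerns. SimpleNamespace stands in for the parsed args, and `not` is used in place of the `== False` comparisons:

from types import SimpleNamespace

all_args = SimpleNamespace(
    algorithm_name="mappo",
    use_recurrent_policy=False,
    use_naive_recurrent_policy=False,
    share_policy=False,            # the non-shared case this commit fixes
    scenario_name="MyEnv",
)

if all_args.algorithm_name == "rmappo":
    assert all_args.use_recurrent_policy or all_args.use_naive_recurrent_policy, "check recurrent policy!"
elif all_args.algorithm_name == "mappo":
    assert (not all_args.use_recurrent_policy) and (not all_args.use_naive_recurrent_policy), "check recurrent policy!"
else:
    raise NotImplementedError

assert not (
    all_args.share_policy and all_args.scenario_name == "simple_speaker_listener"
), "The simple_speaker_listener scenario can not use shared policy. Please check the config.py."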