Commit 6cef44e7 authored by tianyu-z

Fix some typos and add an English translation

parent 97ccec45
@@ -2,43 +2,41 @@
Lightweight version of MAPPO to help you quickly migrate to your local environment.

- [Video (in Chinese)](https://www.bilibili.com/video/BV1bd4y1L73N)

This is a translated English version. Please click [here](README_CN.md) for the original Chinese readme.

## Table of Contents
- [Background](#background)
- [Installation](#installation)
- [Usage](#usage)

## Background
The original MAPPO code wraps its environments in a rather complicated way, so this project extracts the environment wrapper directly, making it easier to port the MAPPO code to your own project.

## Installation
Simply download the code, create a Conda environment, and run it, installing whatever packages are missing as you go. A concrete dependency list will be added later.
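As a rough guide, the imports used in this project suggest at least the packages below; treat this as a sketch inferred from the code, not an official requirements file:

```
# requirements (a minimal sketch inferred from the imports; not an official list)
numpy
gym
torch
setproctitle
```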
## Usage
- The environment part is an empty placeholder implementation; put your environment logic in `light_mappo/envs/env_core.py`: [Code](https://github.com/tinyzqh/light_mappo/blob/main/envs/env_core.py)
```python
import numpy as np

class EnvCore(object):
    """
    # Environment agent
    """
    def __init__(self):
        self.agent_num = 2  # number of agents (aircraft); here it is set to two
        self.obs_dim = 14  # observation dimension of each agent
        self.action_dim = 5  # action dimension of each agent; here assumed to be five-dimensional
    def reset(self):
        """
        # When self.agent_num is set to 2 agents, the return value is a list in which each element is an observation of shape = (self.obs_dim, )
        """
        sub_agent_obs = []
        for i in range(self.agent_num):
@@ -48,8 +46,8 @@ class EnvCore(object):
    def step(self, actions):
        """
        # When self.agent_num is set to 2 agents, the input actions is a two-dimensional list in which each element is an action of shape = (self.action_dim, )
        # With default parameters, the input is a list containing two elements; because the action dimension is 5, each element has shape = (5, )
        """
        sub_agent_obs = []
        sub_agent_reward = []
@@ -65,9 +63,9 @@ class EnvCore(object):
```
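For intuition, here is a minimal usage sketch of the stub above, assuming the random placeholder implementation shown in this repository and that you run it from the repository root (so that `envs.env_core` is importable):

```python
# Minimal usage sketch of the EnvCore stub (random placeholder data, 2 agents).
import numpy as np
from envs.env_core import EnvCore

env = EnvCore()
obs = env.reset()  # list with one (obs_dim,) array per agent
actions = [np.random.random(env.action_dim) for _ in range(env.agent_num)]
obs, rewards, dones, infos = env.step(actions)  # each is a list with one entry per agent
print(len(obs), rewards, dones)
```

The wrapper classes (env_continuous.py / env_discrete.py) call `reset()` and `step()` in the same way, so replacing the random placeholder data with your own environment logic is all that is needed.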
You only need to write this part of the code to plug into MAPPO seamlessly. On top of env_core.py, two separate files, env_discrete.py and env_continuous.py, wrap the discrete and continuous action spaces respectively. The `elif self.continuous_action:` branch in algorithms/utils/act.py and the `# TODO` block in runner/shared/env_runner.py ("adapt this to whatever your own environment needs") also handle the continuous action space.
In train.py, switch the demo between the continuous and discrete environments by commenting out one of them.
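Concretely, the toggle lives inside `init_env()` in train.py; only one pair of lines should be active at a time:

```python
# Inside init_env() in train.py: keep one pair of lines and comment out the other.
from envs.env_continuous import ContinuousActionEnv
env = ContinuousActionEnv()

# from envs.env_discrete import DiscreteActionEnv
# env = DiscreteActionEnv()
```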
## Related Efforts
@@ -77,6 +75,9 @@ class EnvCore(object):
[@tinyzqh](https://github.com/tinyzqh).

## Translator
[@tianyu-z](https://github.com/tianyu-z)

## License
[MIT](LICENSE) © tinyzqh
......
# light_mappo
Lightweight version of MAPPO to help you quickly migrate to your local environment.
- [Video walkthrough (in Chinese)](https://www.bilibili.com/video/BV1bd4y1L73N/?spm_id_from=333.999.0.0&vd_source=d8ab7686ea514acb6635faa5d2227d61)

For the English translation of this readme, please click [here](README.md).
## Table of Contents
- [Background](#background)
- [Installation](#installation)
- [Usage](#usage)

## Background
The original MAPPO code wraps its environments in a rather complicated way, so this project extracts the environment wrapper directly, making it easier to port the MAPPO code to your own project.

## Installation
Simply download the code, create a Conda environment, and run it, installing whatever packages are missing as you go. A concrete dependency list will be added later.

## Usage
- The environment part is an empty placeholder implementation; put your environment logic in `light_mappo/envs/env_core.py`: [Code](https://github.com/tinyzqh/light_mappo/blob/main/envs/env_core.py)
```python
import numpy as np

class EnvCore(object):
    """
    # Environment agent
    """
    def __init__(self):
        self.agent_num = 2  # number of agents (aircraft); here it is set to two
        self.obs_dim = 14  # observation dimension of each agent
        self.action_dim = 5  # action dimension of each agent; here assumed to be five-dimensional
    def reset(self):
        """
        # When self.agent_num is set to 2 agents, the return value is a list in which each element is an observation of shape = (self.obs_dim, )
        """
        sub_agent_obs = []
        for i in range(self.agent_num):
            sub_obs = np.random.random(size=(14,))
            sub_agent_obs.append(sub_obs)
        return sub_agent_obs
    def step(self, actions):
        """
        # When self.agent_num is set to 2 agents, the input actions is a two-dimensional list in which each element is an action of shape = (self.action_dim, )
        # With default parameters, the input is a list containing two elements; because the action dimension is 5, each element has shape = (5, )
        """
        sub_agent_obs = []
        sub_agent_reward = []
        sub_agent_done = []
        sub_agent_info = []
        for i in range(self.agent_num):
            sub_agent_obs.append(np.random.random(size=(14,)))
            sub_agent_reward.append([np.random.rand()])
            sub_agent_done.append(False)
            sub_agent_info.append({})
        return [sub_agent_obs, sub_agent_reward, sub_agent_done, sub_agent_info]
```
You only need to write this part of the code to plug into MAPPO seamlessly. On top of env_core.py, two separate files, env_discrete.py and env_continuous.py, wrap the discrete and continuous action spaces respectively. The `elif self.continuous_action:` branch in algorithms/utils/act.py and the `# TODO` block in runner/shared/env_runner.py ("adapt this to whatever your own environment needs") also handle the continuous action space.
In train.py, switch the demo between the continuous and discrete environments by commenting out one of them.
## Related Efforts
- [on-policy](https://github.com/marlbenchmark/on-policy) - 💌 Learn the authors' implementation of MAPPO.
## Maintainers
[@tinyzqh](https://github.com/tinyzqh).
## License
[MIT](LICENSE) © tinyzqh
@@ -5,7 +5,11 @@ from envs.env_core import EnvCore
class ContinuousActionEnv(object):
    """
    Wrapper for the continuous action environment.
    """

    def __init__(self):
        self.env = EnvCore()
        self.num_agent = self.env.agent_num
@@ -27,7 +31,12 @@ class ContinuousActionEnv(object):
        total_action_space = []
        for agent in range(self.num_agent):
            # physical action space
            u_action_space = spaces.Box(
                low=-np.inf,
                high=+np.inf,
                shape=(self.signal_action_dim,),
                dtype=np.float32,
            )

            if self.movable:
                total_action_space.append(u_action_space)
@@ -37,17 +46,31 @@ class ContinuousActionEnv(object):
            # observation space
            share_obs_dim += self.signal_obs_dim
            self.observation_space.append(
                spaces.Box(
                    low=-np.inf,
                    high=+np.inf,
                    shape=(self.signal_obs_dim,),
                    dtype=np.float32,
                )
            )  # [-inf,inf]

        self.share_observation_space = [
            spaces.Box(
                low=-np.inf, high=+np.inf, shape=(share_obs_dim,), dtype=np.float32
            )
            for _ in range(self.num_agent)
        ]
    def step(self, actions):
        """
        Input actions dimension assumption:
        # actions shape = (5, 2, 5)
        # 5 parallel environment threads, each containing 2 agents, and each agent's action is a 5-dimensional one-hot encoding
        """
        results = self.env.step(actions)
@@ -65,4 +88,4 @@ class ContinuousActionEnv(object):
        pass

    def seed(self, seed):
        pass
\ No newline at end of file
@@ -5,25 +5,29 @@ class EnvCore(object):
    """
    # Environment agent
    """

    def __init__(self):
        self.agent_num = 2  # number of agents (aircraft); here it is set to two
        self.obs_dim = 14  # observation dimension of each agent
        self.action_dim = 5  # action dimension of each agent; here assumed to be five-dimensional

    def reset(self):
        """
        # When self.agent_num is set to 2 agents, the return value is a list in which each element is an observation of shape = (self.obs_dim, )
        """
        sub_agent_obs = []
        for i in range(self.agent_num):
            sub_obs = np.random.random(size=(14,))
            sub_agent_obs.append(sub_obs)
        return sub_agent_obs

    def step(self, actions):
        """
        # When self.agent_num is set to 2 agents, the input actions is a two-dimensional list in which each element is an action of shape = (self.action_dim, )
        # With default parameters, the input is a list containing two elements; because the action dimension is 5, each element has shape = (5, )
        """
        sub_agent_obs = []
        sub_agent_reward = []
@@ -35,4 +39,4 @@ class EnvCore(object):
            sub_agent_done.append(False)
            sub_agent_info.append({})
        return [sub_agent_obs, sub_agent_reward, sub_agent_done, sub_agent_info]
\ No newline at end of file
@@ -12,7 +12,11 @@ from envs.env_core import EnvCore
class DiscreteActionEnv(object):
    """
    Wrapper for the discrete action environment.
    """

    def __init__(self):
        self.env = EnvCore()
        self.num_agent = self.env.agent_num
@@ -42,8 +46,15 @@ class DiscreteActionEnv(object):
            # total action space
            if len(total_action_space) > 1:
                # all action spaces are discrete, so simplify to MultiDiscrete action space
                if all(
                    [
                        isinstance(act_space, spaces.Discrete)
                        for act_space in total_action_space
                    ]
                ):
                    act_space = MultiDiscrete(
                        [[0, act_space.n - 1] for act_space in total_action_space]
                    )
                else:
                    act_space = spaces.Tuple(total_action_space)
                self.action_space.append(act_space)
@@ -52,17 +63,30 @@ class DiscreteActionEnv(object):
            # observation space
            share_obs_dim += self.signal_obs_dim
            self.observation_space.append(
                spaces.Box(
                    low=-np.inf,
                    high=+np.inf,
                    shape=(self.signal_obs_dim,),
                    dtype=np.float32,
                )
            )  # [-inf,inf]

        self.share_observation_space = [
            spaces.Box(
                low=-np.inf, high=+np.inf, shape=(share_obs_dim,), dtype=np.float32
            )
            for _ in range(self.num_agent)
        ]
    def step(self, actions):
        """
        Input actions dimension assumption:
        # actions shape = (5, 2, 5)
        # 5 parallel environment threads, each containing 2 agents, and each agent's action is a 5-dimensional one-hot encoding
        """
        results = self.env.step(actions)
@@ -82,7 +106,8 @@ class DiscreteActionEnv(object):
    def seed(self, seed):
        pass


class MultiDiscrete:
    """
    - The multi-discrete action space consists of a series of discrete action spaces with different parameters
    - It can be adapted to both a Discrete action space or a continuous (Box) action space
@@ -107,14 +132,22 @@ class MultiDiscrete():
        self.n = np.sum(self.high) + 2

    def sample(self):
        """Returns an array with one sample from each discrete action space"""
        # For each row: round(random .* (max - min) + min, 0)
        random_array = np.random.rand(self.num_discrete_space)
        return [
            int(x)
            for x in np.floor(
                np.multiply((self.high - self.low + 1.0), random_array) + self.low
            )
        ]

    def contains(self, x):
        return (
            len(x) == self.num_discrete_space
            and (np.array(x) >= self.low).all()
            and (np.array(x) <= self.high).all()
        )

    @property
    def shape(self):
@@ -124,8 +157,10 @@ class MultiDiscrete():
        return "MultiDiscrete" + str(self.num_discrete_space)

    def __eq__(self, other):
        return np.array_equal(self.low, other.low) and np.array_equal(
            self.high, other.high
        )


if __name__ == "__main__":
    DiscreteActionEnv().step(actions=None)
\ No newline at end of file
@@ -23,13 +23,17 @@ def make_train_env(all_args):
    def get_env_fn(rank):
        def init_env():
            # TODO: Note: switch between continuous and discrete here by commenting out one pair of import/env lines and uncommenting the other.
            from envs.env_continuous import ContinuousActionEnv

            env = ContinuousActionEnv()
            # from envs.env_discrete import DiscreteActionEnv
            # env = DiscreteActionEnv()
            env.seed(all_args.seed + rank * 1000)
            return env

        return init_env

    return DummyVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])

@@ -37,20 +41,26 @@ def make_eval_env(all_args):
    def get_env_fn(rank):
        def init_env():
            # TODO: Note: switch between continuous and discrete here by commenting out one pair of import/env lines and uncommenting the other.
            from envs.env_continuous import ContinuousActionEnv

            env = ContinuousActionEnv()
            # from envs.env_discrete import DiscreteActionEnv
            # env = DiscreteActionEnv()
            env.seed(all_args.seed + rank * 1000)
            return env

        return init_env

    return DummyVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])
def parse_args(args, parser):
    parser.add_argument(
        "--scenario_name", type=str, default="MyEnv", help="Which scenario to run on"
    )
    parser.add_argument("--num_landmarks", type=int, default=3)
    parser.add_argument("--num_agents", type=int, default=2, help="number of players")

    all_args = parser.parse_known_args(args)[0]

@@ -62,15 +72,21 @@ def main(args):
    all_args = parse_args(args, parser)

    if all_args.algorithm_name == "rmappo":
        assert (
            all_args.use_recurrent_policy or all_args.use_naive_recurrent_policy
        ), "check recurrent policy!"
    elif all_args.algorithm_name == "mappo":
        assert (
            all_args.use_recurrent_policy == False
            and all_args.use_naive_recurrent_policy == False
        ), "check recurrent policy!"
    else:
        raise NotImplementedError

    assert (
        all_args.share_policy == True
        and all_args.scenario_name == "simple_speaker_listener"
    ) == False, "The simple_speaker_listener scenario can not use shared policy. Please check the config.py."
    # cuda
    if all_args.cuda and torch.cuda.is_available():
@@ -86,27 +102,41 @@ def main(args):
    torch.set_num_threads(all_args.n_training_threads)

    # run dir
    run_dir = (
        Path(os.path.split(os.path.dirname(os.path.abspath(__file__)))[0] + "/results")
        / all_args.env_name
        / all_args.scenario_name
        / all_args.algorithm_name
        / all_args.experiment_name
    )
    if not run_dir.exists():
        os.makedirs(str(run_dir))

    if not run_dir.exists():
        curr_run = "run1"
    else:
        exst_run_nums = [
            int(str(folder.name).split("run")[1])
            for folder in run_dir.iterdir()
            if str(folder.name).startswith("run")
        ]
        if len(exst_run_nums) == 0:
            curr_run = "run1"
        else:
            curr_run = "run%i" % (max(exst_run_nums) + 1)
    run_dir = run_dir / curr_run
    if not run_dir.exists():
        os.makedirs(str(run_dir))

    setproctitle.setproctitle(
        str(all_args.algorithm_name)
        + "-"
        + str(all_args.env_name)
        + "-"
        + str(all_args.experiment_name)
        + "@"
        + str(all_args.user_name)
    )

    # seed
    torch.manual_seed(all_args.seed)
@@ -124,7 +154,7 @@ def main(args):
        "eval_envs": eval_envs,
        "num_agents": num_agents,
        "device": device,
        "run_dir": run_dir,
    }

    # run experiments
@@ -141,7 +171,7 @@ def main(args):
    if all_args.use_eval and eval_envs is not envs:
        eval_envs.close()

    runner.writter.export_scalars_to_json(str(runner.log_dir + "/summary.json"))
    runner.writter.close()
......