牛辰龙 / 结合Transformer与多智能体强化学习的多无人机编码缓存传输方法 (Multi-UAV coded caching transmission method combining Transformer with multi-agent reinforcement learning)
Commit ea508a3a
Authored May 16, 2023 by hezhiqiang01

fix the share_policy = false error

Parent: f2073aa3

Showing 8 changed files with 507 additions and 401 deletions (+507 −401)
Changed files:
.gitignore                       +2   −1
config.py                        +332 −121
envs/env_discrete.py             +22  −31
envs/env_wrappers.py             +1   −1
runner/separated/base_runner.py  +74  −54
runner/separated/env_runner.py   +32  −99
runner/shared/env_runner.py      +36  −84
train/train.py                   +8   −10
.gitignore

*.pyc
results
.*
(no newline at end of file)
config.py

This diff is collapsed.
envs/env_discrete.py

@@ -36,30 +36,30 @@ class DiscreteActionEnv(object):
        share_obs_dim = 0
        total_action_space = []
-        for agent in range(self.num_agent):
+        for agent_idx in range(self.num_agent):
            # physical action space
            u_action_space = spaces.Discrete(self.signal_action_dim)  # 5 discrete actions
-            if self.movable:
-                total_action_space.append(u_action_space)
+            # if self.movable:
+            total_action_space.append(u_action_space)

            # total action space
-            if len(total_action_space) > 1:
-                # all action spaces are discrete, so simplify to MultiDiscrete action space
-                if all([isinstance(act_space, spaces.Discrete) for act_space in total_action_space]):
-                    act_space = MultiDiscrete([[0, act_space.n - 1] for act_space in total_action_space])
-                else:
-                    act_space = spaces.Tuple(total_action_space)
-                self.action_space.append(act_space)
-            else:
-                self.action_space.append(total_action_space[0])
+            # if len(total_action_space) > 1:
+            #     # all action spaces are discrete, so simplify to MultiDiscrete action space
+            #     if all(
+            #         [
+            #             isinstance(act_space, spaces.Discrete)
+            #             for act_space in total_action_space
+            #         ]
+            #     ):
+            #         act_space = MultiDiscrete(
+            #             [[0, act_space.n - 1] for act_space in total_action_space]
+            #         )
+            #     else:
+            #         act_space = spaces.Tuple(total_action_space)
+            #     self.action_space.append(act_space)
+            # else:
+            self.action_space.append(total_action_space[agent_idx])
            # observation space
            share_obs_dim += self.signal_obs_dim
...
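After this change each agent simply receives its own entry of total_action_space, so every agent ends up with an independent Discrete action space instead of the MultiDiscrete/Tuple aggregation. A minimal sketch of the resulting construction, assuming gym's spaces API; the num_agent and signal_action_dim values below are illustrative, not taken from config.py:

    from gym import spaces

    num_agent = 3            # illustrative value, not read from config.py
    signal_action_dim = 5    # "5 discrete actions", as in the comment above

    action_space = []
    total_action_space = []
    for agent_idx in range(num_agent):
        # one Discrete physical action space per agent
        u_action_space = spaces.Discrete(signal_action_dim)
        total_action_space.append(u_action_space)
        # with the MultiDiscrete/Tuple branch commented out, each agent
        # simply gets its own Discrete space back
        action_space.append(total_action_space[agent_idx])

    print(action_space)      # [Discrete(5), Discrete(5), Discrete(5)]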
@@ -73,9 +73,7 @@ class DiscreteActionEnv(object):
        )  # [-inf, inf]
        self.share_observation_space = [
            spaces.Box(low=-np.inf, high=+np.inf, shape=(share_obs_dim,), dtype=np.float32)
            for _ in range(self.num_agent)
        ]
...
@@ -135,12 +133,7 @@ class MultiDiscrete:
        """Returns a array with one sample from each discrete action space"""
        # For each row: round(random .* (max - min) + min, 0)
        random_array = np.random.rand(self.num_discrete_space)
        return [int(x) for x in np.floor(np.multiply((self.high - self.low + 1.0), random_array) + self.low)]

    def contains(self, x):
        return (
...
@@ -157,9 +150,7 @@ class MultiDiscrete:
        return "MultiDiscrete" + str(self.num_discrete_space)

    def __eq__(self, other):
        return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high)


if __name__ == "__main__":
...
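For reference, MultiDiscrete.sample() above draws one value per component using round(random * (max - min) + min): a uniform draw in [0, 1) is scaled to each component's integer range and floored. A small standalone NumPy check of that formula; the low/high vectors are made up for illustration:

    import numpy as np

    low = np.array([0, 0, 0])
    high = np.array([4, 1, 2])   # e.g. component ranges of size 5, 2 and 3

    # one uniform draw in [0, 1) per component, scaled to [low, high]
    random_array = np.random.rand(len(low))
    sample = [int(x) for x in np.floor((high - low + 1.0) * random_array + low)]

    # every component stays inside its [low, high] range
    assert all(low[i] <= sample[i] <= high[i] for i in range(len(low)))
    print(sample)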
envs/env_wrappers.py

@@ -46,7 +46,7 @@ class DummyVecEnv():
        return obs, rews, dones, infos

    def reset(self):
        obs = [env.reset() for env in self.envs]
        # [env_num, agent_num, obs_dim]
        return np.array(obs)

    def close(self):
...
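The reset() shown above stacks one observation per environment into a single array of shape [env_num, agent_num, obs_dim]. A toy sketch with stand-in classes (not the project's real DummyVecEnv) that reproduces just that shape handling:

    import numpy as np

    class _ToyEnv:
        """Stand-in environment: reset() returns one observation per agent."""
        def __init__(self, agent_num=2, obs_dim=4):
            self.agent_num, self.obs_dim = agent_num, obs_dim
        def reset(self):
            return np.zeros((self.agent_num, self.obs_dim))

    class _ToyVecEnv:
        """Mirrors only the reset() logic shown in the diff above."""
        def __init__(self, envs):
            self.envs = envs
        def reset(self):
            obs = [env.reset() for env in self.envs]
            # [env_num, agent_num, obs_dim]
            return np.array(obs)

    vec = _ToyVecEnv([_ToyEnv() for _ in range(3)])
    print(vec.reset().shape)   # (3, 2, 4)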
runner/separated/base_runner.py

import time
import wandb
import os
import numpy as np
from itertools import chain
...
@@ -10,17 +8,18 @@ from tensorboardX import SummaryWriter
from utils.separated_buffer import SeparatedReplayBuffer
from utils.util import update_linear_schedule


def _t2n(x):
    return x.detach().cpu().numpy()


class Runner(object):
    def __init__(self, config):
-        self.all_args = config['all_args']
-        self.envs = config['envs']
-        self.eval_envs = config['eval_envs']
-        self.device = config['device']
-        self.num_agents = config['num_agents']
+        self.all_args = config["all_args"]
+        self.envs = config["envs"]
+        self.eval_envs = config["eval_envs"]
+        self.device = config["device"]
+        self.num_agents = config["num_agents"]

        # parameters
        self.env_name = self.all_args.env_name
...
@@ -34,7 +33,6 @@ class Runner(object):
        self.n_eval_rollout_threads = self.all_args.n_eval_rollout_threads
        self.use_linear_lr_decay = self.all_args.use_linear_lr_decay
        self.hidden_size = self.all_args.hidden_size
        self.use_wandb = self.all_args.use_wandb
        self.use_render = self.all_args.use_render
        self.recurrent_N = self.all_args.recurrent_N
...
@@ -49,37 +47,42 @@ class Runner(object):
        if self.use_render:
            import imageio
            self.run_dir = config["run_dir"]
-            self.gif_dir = str(self.run_dir / 'gifs')
+            self.gif_dir = str(self.run_dir / "gifs")
            if not os.path.exists(self.gif_dir):
                os.makedirs(self.gif_dir)
        else:
-            if self.use_wandb:
-                self.save_dir = str(wandb.run.dir)
-            else:
+            # if self.use_wandb:
+            #     self.save_dir = str(wandb.run.dir)
+            # else:
            self.run_dir = config["run_dir"]
-            self.log_dir = str(self.run_dir / 'logs')
+            self.log_dir = str(self.run_dir / "logs")
            if not os.path.exists(self.log_dir):
                os.makedirs(self.log_dir)
            self.writter = SummaryWriter(self.log_dir)
-            self.save_dir = str(self.run_dir / 'models')
+            self.save_dir = str(self.run_dir / "models")
            if not os.path.exists(self.save_dir):
                os.makedirs(self.save_dir)

        from algorithms.algorithm.r_mappo import RMAPPO as TrainAlgo
        from algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy as Policy

        self.policy = []
        for agent_id in range(self.num_agents):
-            share_observation_space = self.envs.share_observation_space[agent_id] if self.use_centralized_V else self.envs.observation_space[agent_id]
+            share_observation_space = (
+                self.envs.share_observation_space[agent_id]
+                if self.use_centralized_V
+                else self.envs.observation_space[agent_id]
+            )
            # policy network
-            po = Policy(self.all_args,
-                        self.envs.observation_space[agent_id],
-                        share_observation_space,
-                        self.envs.action_space[agent_id],
-                        device=self.device)
+            po = Policy(
+                self.all_args,
+                self.envs.observation_space[agent_id],
+                share_observation_space,
+                self.envs.action_space[agent_id],
+                device=self.device,
+            )
            self.policy.append(po)

        if self.model_dir is not None:
...
@@ -89,13 +92,19 @@ class Runner(object):
        self.buffer = []
        for agent_id in range(self.num_agents):
            # algorithm
            tr = TrainAlgo(self.all_args, self.policy[agent_id], device=self.device)
            # buffer
-            share_observation_space = self.envs.share_observation_space[agent_id] if self.use_centralized_V else self.envs.observation_space[agent_id]
-            bu = SeparatedReplayBuffer(self.all_args,
-                                       self.envs.observation_space[agent_id],
-                                       share_observation_space,
-                                       self.envs.action_space[agent_id])
+            share_observation_space = (
+                self.envs.share_observation_space[agent_id]
+                if self.use_centralized_V
+                else self.envs.observation_space[agent_id]
+            )
+            bu = SeparatedReplayBuffer(
+                self.all_args,
+                self.envs.observation_space[agent_id],
+                share_observation_space,
+                self.envs.action_space[agent_id],
+            )
            self.buffer.append(bu)
            self.trainer.append(tr)
...
@@ -115,9 +124,11 @@ class Runner(object):
    def compute(self):
        for agent_id in range(self.num_agents):
            self.trainer[agent_id].prep_rollout()
-            next_value = self.trainer[agent_id].policy.get_values(self.buffer[agent_id].share_obs[-1],
-                                                                  self.buffer[agent_id].rnn_states_critic[-1],
-                                                                  self.buffer[agent_id].masks[-1])
+            next_value = self.trainer[agent_id].policy.get_values(
+                self.buffer[agent_id].share_obs[-1],
+                self.buffer[agent_id].rnn_states_critic[-1],
+                self.buffer[agent_id].masks[-1],
+            )
            next_value = _t2n(next_value)
            self.buffer[agent_id].compute_returns(next_value, self.trainer[agent_id].value_normalizer)
...
@@ -134,30 +145,39 @@ class Runner(object):
    def save(self):
        for agent_id in range(self.num_agents):
            policy_actor = self.trainer[agent_id].policy.actor
-            torch.save(policy_actor.state_dict(), str(self.save_dir) + "/actor_agent" + str(agent_id) + ".pt")
+            torch.save(
+                policy_actor.state_dict(),
+                str(self.save_dir) + "/actor_agent" + str(agent_id) + ".pt",
+            )
            policy_critic = self.trainer[agent_id].policy.critic
-            torch.save(policy_critic.state_dict(), str(self.save_dir) + "/critic_agent" + str(agent_id) + ".pt")
+            torch.save(
+                policy_critic.state_dict(),
+                str(self.save_dir) + "/critic_agent" + str(agent_id) + ".pt",
+            )

    def restore(self):
        for agent_id in range(self.num_agents):
-            policy_actor_state_dict = torch.load(str(self.model_dir) + '/actor_agent' + str(agent_id) + '.pt')
+            policy_actor_state_dict = torch.load(
+                str(self.model_dir) + "/actor_agent" + str(agent_id) + ".pt"
+            )
            self.policy[agent_id].actor.load_state_dict(policy_actor_state_dict)
-            policy_critic_state_dict = torch.load(str(self.model_dir) + '/critic_agent' + str(agent_id) + '.pt')
+            policy_critic_state_dict = torch.load(
+                str(self.model_dir) + "/critic_agent" + str(agent_id) + ".pt"
+            )
            self.policy[agent_id].critic.load_state_dict(policy_critic_state_dict)

    def log_train(self, train_infos, total_num_steps):
        for agent_id in range(self.num_agents):
            for k, v in train_infos[agent_id].items():
                agent_k = "agent%i/" % agent_id + k
-                if self.use_wandb:
-                    wandb.log({agent_k: v}, step=total_num_steps)
-                else:
+                # if self.use_wandb:
+                # pass
+                # wandb.log({agent_k: v}, step=total_num_steps)
+                # else:
                self.writter.add_scalars(agent_k, {agent_k: v}, total_num_steps)

    def log_env(self, env_infos, total_num_steps):
        for k, v in env_infos.items():
            if len(v) > 0:
-                if self.use_wandb:
-                    wandb.log({k: np.mean(v)}, step=total_num_steps)
-                else:
+                # if self.use_wandb:
+                # wandb.log({k: np.mean(v)}, step=total_num_steps)
+                # else:
                self.writter.add_scalars(k, {k: np.mean(v)}, total_num_steps)
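With share_policy = false the separated Runner above keeps one policy, one trainer and one replay buffer per agent, using the agent's shared observation space only when use_centralized_V is set. A minimal sketch of that per-agent wiring with stub classes; the real TrainAlgo, RMAPPOPolicy and SeparatedReplayBuffer are not reproduced here:

    from types import SimpleNamespace

    class _StubPolicy:
        """Stand-in for the per-agent policy (illustrative only)."""
        def __init__(self, obs_space, share_obs_space, act_space):
            self.obs_space, self.share_obs_space, self.act_space = obs_space, share_obs_space, act_space

    class _StubTrainer:
        """Stand-in for the per-agent training algorithm."""
        def __init__(self, policy):
            self.policy = policy

    class _StubBuffer:
        """Stand-in for the per-agent replay buffer."""
        def __init__(self, obs_space, share_obs_space, act_space):
            self.obs_space, self.share_obs_space, self.act_space = obs_space, share_obs_space, act_space

    def build_separated(envs, num_agents, use_centralized_V=True):
        # one (policy, trainer, buffer) triple per agent, as in Runner.__init__ above
        policies, trainers, buffers = [], [], []
        for agent_id in range(num_agents):
            share_space = (envs.share_observation_space[agent_id]
                           if use_centralized_V
                           else envs.observation_space[agent_id])
            po = _StubPolicy(envs.observation_space[agent_id], share_space, envs.action_space[agent_id])
            policies.append(po)
            trainers.append(_StubTrainer(po))
            buffers.append(_StubBuffer(envs.observation_space[agent_id], share_space, envs.action_space[agent_id]))
        return policies, trainers, buffers

    # illustrative spaces; in the real Runner these come from the vectorized environment
    toy_envs = SimpleNamespace(
        observation_space=["obs0", "obs1"],
        share_observation_space=["share0", "share1"],
        action_space=["act0", "act1"],
    )
    policies, trainers, buffers = build_separated(toy_envs, num_agents=2)
    assert len(policies) == len(trainers) == len(buffers) == 2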
runner/separated/env_runner.py

This diff is collapsed.
runner/shared/env_runner.py

This diff is collapsed.
train/train.py

@@ -31,11 +31,15 @@ def make_train_env(all_args):
    def init_env():
        # TODO: choose a continuous or a discrete action space by commenting out
        #       the two lines above or the two lines below.
        from envs.env_continuous import ContinuousActionEnv
        env = ContinuousActionEnv()
        # from envs.env_discrete import DiscreteActionEnv
        # env = DiscreteActionEnv()
        env.seed(all_args.seed + rank * 1000)
        return env
...
@@ -63,9 +67,7 @@ def make_eval_env(all_args):
def parse_args(args, parser):
    parser.add_argument(
        "--scenario_name", type=str, default="MyEnv", help="Which scenario to run on"
    )
    parser.add_argument("--num_landmarks", type=int, default=3)
    parser.add_argument("--num_agents", type=int, default=2, help="number of players")
...
@@ -79,20 +81,16 @@ def main(args):
    all_args = parse_args(args, parser)

    if all_args.algorithm_name == "rmappo":
        assert (
            all_args.use_recurrent_policy or all_args.use_naive_recurrent_policy
        ), "check recurrent policy!"
    elif all_args.algorithm_name == "mappo":
        assert (
            all_args.use_recurrent_policy == False
            and all_args.use_naive_recurrent_policy == False
        ), "check recurrent policy!"
    else:
        raise NotImplementedError

    assert (
        all_args.share_policy == True
        and all_args.scenario_name == "simple_speaker_listener"
    ) == False, "The simple_speaker_listener scenario can not use shared policy. Please check the config.py."

    # cuda
...
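main() above enforces three consistency checks: rmappo requires a recurrent policy, mappo must not use one, and simple_speaker_listener cannot be combined with a shared policy. A self-contained sketch of the same checks; the argparse flags below are illustrative stand-ins for the options defined in config.py:

    import argparse

    def check_args(all_args):
        # same three checks as main() in train/train.py
        if all_args.algorithm_name == "rmappo":
            assert (
                all_args.use_recurrent_policy or all_args.use_naive_recurrent_policy
            ), "check recurrent policy!"
        elif all_args.algorithm_name == "mappo":
            assert (
                not all_args.use_recurrent_policy
                and not all_args.use_naive_recurrent_policy
            ), "check recurrent policy!"
        else:
            raise NotImplementedError
        assert not (
            all_args.share_policy and all_args.scenario_name == "simple_speaker_listener"
        ), "The simple_speaker_listener scenario can not use shared policy."

    parser = argparse.ArgumentParser()
    parser.add_argument("--algorithm_name", type=str, default="rmappo")
    parser.add_argument("--use_recurrent_policy", action="store_true", default=True)
    parser.add_argument("--use_naive_recurrent_policy", action="store_true", default=False)
    parser.add_argument("--share_policy", action="store_true", default=False)
    parser.add_argument("--scenario_name", type=str, default="MyEnv")
    check_args(parser.parse_args([]))   # passes with these defaults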