Commit ed31aa10 authored by hezhiqiang01

modify filepath

parent 6d507319
-from mappo import algorithms, envs, runner, scripts, utils, config
+from . import algorithms, envs, runner, scripts, utils, config
 __version__ = "0.1.0"
......
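The hunk above (and the matching edits below) drops the mappo. package prefix, so modules such as algorithms, utils, and runner are now imported as top-level packages. That only resolves if the directory containing them is on sys.path, for example because the scripts are launched from inside that directory. A minimal standalone sketch of that assumption (the path manipulation is illustrative and not part of this commit):

# Illustrative only: make the package directory itself importable so that
# "from algorithms... import ..." style imports resolve without the mappo. prefix.
import os
import sys

package_dir = os.path.dirname(os.path.abspath(__file__))  # assumed to contain algorithms/, utils/, runner/
if package_dir not in sys.path:
    sys.path.insert(0, package_dir)

from algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy  # now found as a top-level package
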
@@ -6,8 +6,8 @@
 """
 import torch
-from mappo.algorithms.algorithm.r_actor_critic import R_Actor, R_Critic
-from mappo.utils.util import update_linear_schedule
+from algorithms.algorithm.r_actor_critic import R_Actor, R_Critic
+from utils.util import update_linear_schedule
 class RMAPPOPolicy:
......
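update_linear_schedule is imported from utils.util, but its body is not part of this diff. Helpers with this name usually anneal the optimizer's learning rate linearly toward zero over training; a generic sketch of that idea follows (signature and names are assumed, not taken from the repo):

# Assumed, generic linear learning-rate decay; not copied from utils.util.
def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    """Anneal the learning rate linearly from initial_lr to 0 over total_num_epochs."""
    lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
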
@@ -7,13 +7,13 @@
 import torch
 import torch.nn as nn
-from mappo.algorithms.utils.util import init, check
-from mappo.algorithms.utils.cnn import CNNBase
-from mappo.algorithms.utils.mlp import MLPBase
-from mappo.algorithms.utils.rnn import RNNLayer
-from mappo.algorithms.utils.act import ACTLayer
-from mappo.algorithms.utils.popart import PopArt
-from mappo.utils.util import get_shape_from_obs_space
+from algorithms.utils.util import init, check
+from algorithms.utils.cnn import CNNBase
+from algorithms.utils.mlp import MLPBase
+from algorithms.utils.rnn import RNNLayer
+from algorithms.utils.act import ACTLayer
+from algorithms.utils.popart import PopArt
+from utils.util import get_shape_from_obs_space
 class R_Actor(nn.Module):
......
@@ -8,9 +8,9 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from mappo.utils.util import get_gard_norm, huber_loss, mse_loss
-from mappo.utils.valuenorm import ValueNorm
-from mappo.algorithms.utils.util import check
+from utils.util import get_gard_norm, huber_loss, mse_loss
+from utils.valuenorm import ValueNorm
+from algorithms.utils.util import check
 class RMAPPO():
......
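huber_loss and mse_loss are imported from utils.util; their exact signatures are not visible in this diff. For reference, a generic Huber loss (quadratic near zero, linear beyond a threshold d) can be sketched as below; treat the signatures as assumptions rather than the repo's API:

import torch

# Generic sketch; the utils.util versions may differ in signature and details.
def huber_loss(error, d):
    a = (abs(error) <= d).float()   # quadratic region mask
    b = (abs(error) > d).float()    # linear region mask
    return a * error ** 2 / 2 + b * d * (abs(error) - d / 2)

def mse_loss(error):
    return error ** 2 / 2

err = torch.tensor([0.5, 3.0])
print(huber_loss(err, 1.0))  # tensor([0.1250, 2.5000])
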
@@ -189,7 +189,7 @@ def get_config():
     # network parameters
     parser.add_argument("--share_policy", action='store_false',
-                        default=False, help='Whether agent share the same policy')
+                        default=True, help='Whether agent share the same policy')
     parser.add_argument("--use_centralized_V", action='store_false',
                         default=True, help="Whether to use centralized V function")
     parser.add_argument("--stacked_frames", type=int, default=1,
......
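Note the interaction in the hunk above: --share_policy uses action='store_false', so supplying the flag on the command line always stores False, and the default (changed here from False to True) is what applies when the flag is omitted. A quick standalone check of that argparse behaviour:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--share_policy", action='store_false', default=True,
                    help='Whether agent share the same policy')

print(parser.parse_args([]).share_policy)                   # True  (flag omitted -> default applies)
print(parser.parse_args(["--share_policy"]).share_policy)   # False (store_false always stores False)
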
@@ -7,8 +7,8 @@ from itertools import chain
 import torch
 from tensorboardX import SummaryWriter
-from mappo.utils.separated_buffer import SeparatedReplayBuffer
-from mappo.utils.util import update_linear_schedule
+from utils.separated_buffer import SeparatedReplayBuffer
+from utils.util import update_linear_schedule
 def _t2n(x):
     return x.detach().cpu().numpy()
@@ -67,8 +67,8 @@ class Runner(object):
             os.makedirs(self.save_dir)
-        from mappo.algorithms.algorithm.r_mappo import RMAPPO as TrainAlgo
-        from mappo.algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy as Policy
+        from algorithms.algorithm.r_mappo import RMAPPO as TrainAlgo
+        from algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy as Policy
         self.policy = []
......
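_t2n, shown in the hunk above, is a small helper that moves a tensor off the computation graph and back to NumPy. A self-contained usage sketch (the tensor shape is made up for illustration):

import torch
import numpy as np

def _t2n(x):
    return x.detach().cpu().numpy()

values = torch.randn(8, 1, requires_grad=True)  # e.g. network outputs still attached to the graph
values_np = _t2n(values)                        # detached, moved to CPU, converted to NumPy
assert isinstance(values_np, np.ndarray) and values_np.shape == (8, 1)
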
@@ -12,8 +12,8 @@ import numpy as np
 from itertools import chain
 import torch
-from mappo.utils.util import update_linear_schedule
-from mappo.runner.separated.base_runner import Runner
+from utils.util import update_linear_schedule
+from runner.separated.base_runner import Runner
 import imageio
......
@@ -3,7 +3,7 @@ import os
 import numpy as np
 import torch
 from tensorboardX import SummaryWriter
-from mappo.utils.shared_buffer import SharedReplayBuffer
+from utils.shared_buffer import SharedReplayBuffer
 def _t2n(x):
     """Convert torch tensor to a numpy array."""
@@ -63,8 +63,8 @@ class Runner(object):
         if not os.path.exists(self.save_dir):
             os.makedirs(self.save_dir)
-        from mappo.algorithms.algorithm.r_mappo import RMAPPO as TrainAlgo
-        from mappo.algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy as Policy
+        from algorithms.algorithm.r_mappo import RMAPPO as TrainAlgo
+        from algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy as Policy
         share_observation_space = self.envs.share_observation_space[0] if self.use_centralized_V else self.envs.observation_space[0]
......
@@ -15,7 +15,7 @@
 import time
 import numpy as np
 import torch
-from mappo.runner.shared.base_runner import Runner
+from runner.shared.base_runner import Runner
 import wandb
 import imageio
......
@@ -15,7 +15,7 @@ import numpy as np
 from pathlib import Path
 import torch
 from config import get_config
-from .envs.env_wrappers import SubprocVecEnv, DummyVecEnv
+from envs.env_wrappers import SubprocVecEnv, DummyVecEnv
 """Train script for MPEs."""
@@ -124,9 +124,9 @@ def main(args):
     # run experiments
     if all_args.share_policy:
-        from mappo.runner.shared.env_runner import EnvRunner as Runner
+        from runner.shared.env_runner import EnvRunner as Runner
     else:
-        from mappo.runner.separated.env_runner import EnvRunner as Runner
+        from runner.separated.env_runner import EnvRunner as Runner
     runner = Runner(config)
     runner.run()
......
@@ -2,7 +2,7 @@ import torch
 import numpy as np
 from collections import defaultdict
-from mappo.utils.util import check, get_shape_from_obs_space, get_shape_from_act_space
+from utils.util import check, get_shape_from_obs_space, get_shape_from_act_space
 def _flatten(T, N, x):
     return x.reshape(T * N, *x.shape[2:])
......
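_flatten in the buffer above collapses the leading time and rollout-thread dimensions of a stored array into a single batch dimension. A small worked example with made-up shapes:

import numpy as np

def _flatten(T, N, x):
    return x.reshape(T * N, *x.shape[2:])

# e.g. 10 timesteps, 8 parallel environments, 5-dimensional observations
obs = np.zeros((10, 8, 5))
print(_flatten(10, 8, obs).shape)  # (80, 5)
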
 import torch
 import numpy as np
-from mappo.utils.util import get_shape_from_obs_space, get_shape_from_act_space
+from utils.util import get_shape_from_obs_space, get_shape_from_act_space
 def _flatten(T, N, x):
......