Commit ed31aa10 authored by hezhiqiang01

modifi filepath

parent 6d507319
-from mappo import algorithms, envs, runner, scripts, utils, config
+from . import algorithms, envs, runner, scripts, utils, config
__version__ = "0.1.0"
...
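The first hunk swaps the package's absolute self-import for a relative one in its __init__.py. As a minimal illustration (not code from the commit) of what the two forms require:

# Illustration only: "from mappo import ..." needs a package literally named
# "mappo" importable from sys.path, while "from . import ..." binds to whatever
# package this __init__.py belongs to, so it also works for an un-installed checkout.
import importlib.util

print(importlib.util.find_spec("mappo") is not None)  # False unless a "mappo" package is installed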
@@ -6,8 +6,8 @@
"""
import torch
-from mappo.algorithms.algorithm.r_actor_critic import R_Actor, R_Critic
-from mappo.utils.util import update_linear_schedule
+from algorithms.algorithm.r_actor_critic import R_Actor, R_Critic
+from utils.util import update_linear_schedule
class RMAPPOPolicy:
...
@@ -7,13 +7,13 @@
import torch
import torch.nn as nn
-from mappo.algorithms.utils.util import init, check
-from mappo.algorithms.utils.cnn import CNNBase
-from mappo.algorithms.utils.mlp import MLPBase
-from mappo.algorithms.utils.rnn import RNNLayer
-from mappo.algorithms.utils.act import ACTLayer
-from mappo.algorithms.utils.popart import PopArt
-from mappo.utils.util import get_shape_from_obs_space
+from algorithms.utils.util import init, check
+from algorithms.utils.cnn import CNNBase
+from algorithms.utils.mlp import MLPBase
+from algorithms.utils.rnn import RNNLayer
+from algorithms.utils.act import ACTLayer
+from algorithms.utils.popart import PopArt
+from utils.util import get_shape_from_obs_space
class R_Actor(nn.Module):
...
@@ -8,9 +8,9 @@
import numpy as np
import torch
import torch.nn as nn
-from mappo.utils.util import get_gard_norm, huber_loss, mse_loss
-from mappo.utils.valuenorm import ValueNorm
-from mappo.algorithms.utils.util import check
+from utils.util import get_gard_norm, huber_loss, mse_loss
+from utils.valuenorm import ValueNorm
+from algorithms.utils.util import check
class RMAPPO():
...
@@ -189,7 +189,7 @@ def get_config():
# network parameters
parser.add_argument("--share_policy", action='store_false',
-default=False, help='Whether agent share the same policy')
+default=True, help='Whether agent share the same policy')
parser.add_argument("--use_centralized_V", action='store_false',
default=True, help="Whether to use centralized V function")
parser.add_argument("--stacked_frames", type=int, default=1,
...
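A note on the --share_policy hunk above, since action='store_false' is easy to misread: the flag stores False when passed and the default supplies the value otherwise, so with a default of False the option can never be True, while a default of True makes --share_policy an opt-out switch. A minimal, self-contained sketch (illustration only, not part of the commit):

import argparse

# Same flag definition as in the hunk; only the surrounding parser setup is trimmed.
parser = argparse.ArgumentParser()
parser.add_argument("--share_policy", action='store_false',
                    default=True, help='Whether agent share the same policy')

print(parser.parse_args([]).share_policy)                   # True  -> agents share one policy
print(parser.parse_args(["--share_policy"]).share_policy)   # False -> flag opts out of sharing
# With default=False both calls would print False and the flag would have no effect.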
@@ -7,8 +7,8 @@ from itertools import chain
import torch
from tensorboardX import SummaryWriter
-from mappo.utils.separated_buffer import SeparatedReplayBuffer
-from mappo.utils.util import update_linear_schedule
+from utils.separated_buffer import SeparatedReplayBuffer
+from utils.util import update_linear_schedule
def _t2n(x):
return x.detach().cpu().numpy()
@@ -67,8 +67,8 @@ class Runner(object):
os.makedirs(self.save_dir)
-from mappo.algorithms.algorithm.r_mappo import RMAPPO as TrainAlgo
-from mappo.algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy as Policy
+from algorithms.algorithm.r_mappo import RMAPPO as TrainAlgo
+from algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy as Policy
self.policy = []
...
@@ -12,8 +12,8 @@ import numpy as np
from itertools import chain
import torch
-from mappo.utils.util import update_linear_schedule
-from mappo.runner.separated.base_runner import Runner
+from utils.util import update_linear_schedule
+from runner.separated.base_runner import Runner
import imageio
...
@@ -3,7 +3,7 @@ import os
import numpy as np
import torch
from tensorboardX import SummaryWriter
-from mappo.utils.shared_buffer import SharedReplayBuffer
+from utils.shared_buffer import SharedReplayBuffer
def _t2n(x):
"""Convert torch tensor to a numpy array."""
@@ -63,8 +63,8 @@ class Runner(object):
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
-from mappo.algorithms.algorithm.r_mappo import RMAPPO as TrainAlgo
-from mappo.algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy as Policy
+from algorithms.algorithm.r_mappo import RMAPPO as TrainAlgo
+from algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy as Policy
share_observation_space = self.envs.share_observation_space[0] if self.use_centralized_V else self.envs.observation_space[0]
...
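For context on the two base_runner hunks: the algorithm and policy classes are imported inside the Runner constructor rather than at module top level, so after this commit the un-prefixed algorithms.* paths must be importable when a Runner is built. A trimmed sketch of that pattern (the class body is elided and the attribute names here are hypothetical):

class Runner(object):
    def __init__(self, config):
        # Deferred imports, as in the hunks above; they execute at construction time.
        from algorithms.algorithm.r_mappo import RMAPPO as TrainAlgo
        from algorithms.algorithm.rMAPPOPolicy import RMAPPOPolicy as Policy
        self.trainer_cls, self.policy_cls = TrainAlgo, Policy  # hypothetical attributes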
@@ -15,7 +15,7 @@
import time
import numpy as np
import torch
-from mappo.runner.shared.base_runner import Runner
+from runner.shared.base_runner import Runner
import wandb
import imageio
...
@@ -15,7 +15,7 @@ import numpy as np
from pathlib import Path
import torch
from config import get_config
-from .envs.env_wrappers import SubprocVecEnv, DummyVecEnv
+from envs.env_wrappers import SubprocVecEnv, DummyVecEnv
"""Train script for MPEs."""
@@ -124,9 +124,9 @@ def main(args):
# run experiments
if all_args.share_policy:
-from mappo.runner.shared.env_runner import EnvRunner as Runner
+from runner.shared.env_runner import EnvRunner as Runner
else:
-from mappo.runner.separated.env_runner import EnvRunner as Runner
+from runner.separated.env_runner import EnvRunner as Runner
runner = Runner(config)
runner.run()
...
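The train-script hunk is where the shortened paths get exercised: with the mappo. prefix gone, the runner and env modules resolve only if the directory holding algorithms/, envs/, runner/ and utils/ is on sys.path (for example, by running from the repository root). A minimal sketch of one way to guarantee that; this reflects an assumption about usage and is not something the commit adds:

import os
import sys

# Assumes this snippet lives in the repository root next to runner/, envs/, etc.
REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, REPO_ROOT)

from runner.shared.env_runner import EnvRunner as Runner  # resolves without the mappo. prefix
# Equivalently, set PYTHONPATH to the repository root before launching the train script.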
@@ -2,7 +2,7 @@ import torch
import numpy as np
from collections import defaultdict
-from mappo.utils.util import check, get_shape_from_obs_space, get_shape_from_act_space
+from utils.util import check, get_shape_from_obs_space, get_shape_from_act_space
def _flatten(T, N, x):
return x.reshape(T * N, *x.shape[2:])
...
import torch
import numpy as np
-from mappo.utils.util import get_shape_from_obs_space, get_shape_from_act_space
+from utils.util import get_shape_from_obs_space, get_shape_from_act_space
def _flatten(T, N, x):
...