Commit e9f82c59 authored by hzq

[dev] Remove parallel environments

parent 4bc6d041
@@ -7,151 +7,17 @@ Modified from OpenAI Baselines code to work with multi-agent envs
"""
import numpy as np
import torch
from multiprocessing import Process, Pipe
from abc import ABC, abstractmethod
def tile_images(img_nhwc):
    """
    Tile N images into one big PxQ image
    (P,Q) are chosen to be as close as possible, and if N
    is square, then P=Q.
    input: img_nhwc, list or array of images, ndim=4 once turned into array
        n = batch index, h = height, w = width, c = channel
    returns:
        bigim_HWc, ndarray with ndim=3
    """
    img_nhwc = np.asarray(img_nhwc)
    N, h, w, c = img_nhwc.shape
    H = int(np.ceil(np.sqrt(N)))
    W = int(np.ceil(float(N) / H))
    img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(N, H * W)])
    img_HWhwc = img_nhwc.reshape(H, W, h, w, c)
    img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)
    img_Hh_Ww_c = img_HhWwc.reshape(H * h, W * w, c)
    return img_Hh_Ww_c
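# --- Editor's note: an illustrative sketch, not part of this commit. ---
# tile_images packs N frames into a near-square H x W grid, padding the
# grid with blank frames; e.g. N=5 gives H=3, W=2 and one zero frame.
def _tile_images_example():
    frames = np.random.randint(0, 255, size=(5, 4, 6, 3), dtype=np.uint8)
    big = tile_images(frames)
    assert big.shape == (3 * 4, 2 * 6, 3)  # (H*h, W*w, c)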
class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """

    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)

    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)
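# --- Editor's note: an illustrative sketch, not part of this commit. ---
# Plain pickle cannot serialize closures such as the env_fns passed to
# SubprocVecEnv below; cloudpickle can. The wrapper therefore cloudpickles
# the payload on the way into the child process, and the resulting bytes
# are ordinary pickle data that pickle.loads can read back.
def _cloudpickle_wrapper_example():
    import pickle
    wrapped = CloudpickleWrapper(lambda: "env-0")
    restored = pickle.loads(pickle.dumps(wrapped))  # round-trips via __getstate__/__setstate__
    assert restored.x() == "env-0"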
def worker(remote, parent_remote, env_fn_wrapper):
    parent_remote.close()
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if 'bool' in done.__class__.__name__:
                # scalar done (single-agent case)
                if done:
                    ob = env.reset()
            else:
                # vector of per-agent dones: reset only once all agents are done
                if np.all(done):
                    ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset()
            remote.send(ob)
        elif cmd == 'render':
            if data == "rgb_array":
                fr = env.render(mode=data)
                remote.send(fr)
            elif data == "human":
                env.render(mode=data)
        elif cmd == 'reset_task':
            ob = env.reset_task()
            remote.send(ob)
        elif cmd == 'close':
            env.close()
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.observation_space, env.share_observation_space, env.action_space))
        else:
            raise NotImplementedError
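# --- Editor's note: an illustrative sketch, not part of this commit. ---
# Driving one worker by hand over a Pipe; `make_env` stands for any
# zero-argument factory returning a gym-style env. SubprocVecEnv below
# does exactly this for a whole list of factories.
def _single_worker_example(make_env):
    parent, child = Pipe()
    p = Process(target=worker, args=(child, parent, CloudpickleWrapper(make_env)))
    p.daemon = True
    p.start()
    child.close()                # the child end now lives in the subprocess
    parent.send(('reset', None))
    ob = parent.recv()           # initial observation
    parent.send(('close', None))
    p.join()
    return ob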
class ShareVecEnv(ABC):
    """
    An abstract asynchronous, vectorized environment.
    Used to batch data from multiple copies of an environment, so that
    each observation becomes a batch of observations, and each expected
    action is a batch of actions to be applied per-environment.
    """
    closed = False
    viewer = None

    metadata = {
        'render.modes': ['human', 'rgb_array']
    }

    def __init__(self, num_envs, observation_space, share_observation_space, action_space):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.share_observation_space = share_observation_space
        self.action_space = action_space

    @abstractmethod
    def reset(self):
        """
        Reset all the environments and return an array of
        observations, or a dict of observation arrays.

        If step_async is still doing work, that work will
        be cancelled and step_wait() should not be called
        until step_async() is invoked again.
        """
        pass

    @abstractmethod
    def step_async(self, actions):
        """
        Tell all the environments to start taking a step
        with the given actions.
        Call step_wait() to get the results of the step.

        You should not call this if a step_async run is
        already pending.
        """
        pass

    @abstractmethod
    def step_wait(self):
        """
        Wait for the step taken with step_async().

        Returns (obs, rews, dones, infos):
         - obs: an array of observations, or a dict of
                arrays of observations.
         - rews: an array of rewards
         - dones: an array of "episode done" booleans
         - infos: a sequence of info objects
        """
        pass

    def close_extras(self):
        """
        Clean up the extra resources, beyond what's in this base class.
        Only runs when not self.closed.
        """
        pass

    def close(self):
        if self.closed:
            return
        if self.viewer is not None:
            self.viewer.close()
        self.close_extras()
        self.closed = True
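# --- Editor's note: an illustrative sketch, not part of this commit. ---
# The contract every ShareVecEnv subclass provides: reset() returns a batch
# of observations, and step_async() + step_wait() (composed by step(),
# shown further down) consume one action per env and return batched results.
def _vec_env_contract_example(vec_env):
    obs = vec_env.reset()  # stacked over envs: shape (num_envs, ...)
    actions = [vec_env.action_space.sample() for _ in range(vec_env.num_envs)]
    vec_env.step_async(actions)
    obs, rews, dones, infos = vec_env.step_wait()  # each batched over num_envs
    vec_env.close()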
# single env
class DummyVecEnv():
    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        self.num_envs = len(env_fns)
        self.observation_space = env.observation_space
        self.share_observation_space = env.share_observation_space
        self.action_space = env.action_space
        self.actions = None

    def step(self, actions):
        """
@@ -161,109 +27,6 @@ class ShareVecEnv(ABC):
        self.step_async(actions)
        return self.step_wait()

    def render(self, mode='human'):
        imgs = self.get_images()
        bigimg = tile_images(imgs)
        if mode == 'human':
            self.get_viewer().imshow(bigimg)
            return self.get_viewer().isopen
        elif mode == 'rgb_array':
            return bigimg
        else:
            raise NotImplementedError

    def get_images(self):
        """
        Return RGB images from each environment
        """
        raise NotImplementedError

    @property
    def unwrapped(self):
        if isinstance(self, VecEnvWrapper):
            return self.venv.unwrapped
        else:
            return self

    def get_viewer(self):
        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.SimpleImageViewer()
        return self.viewer
class SubprocVecEnv(ShareVecEnv):
    def __init__(self, env_fns, spaces=None):
        """
        envs: list of gym environments to run in subprocesses
        """
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        for remote in self.work_remotes:
            remote.close()
        self.remotes[0].send(('get_spaces', None))
        observation_space, share_observation_space, action_space = self.remotes[0].recv()
        ShareVecEnv.__init__(self, len(env_fns), observation_space,
                             share_observation_space, action_space)

    def step_async(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        obs = [remote.recv() for remote in self.remotes]
        return np.stack(obs)

    def reset_task(self):
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        if self.closed:
            return
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True

    def render(self, mode="rgb_array"):
        for remote in self.remotes:
            remote.send(('render', mode))
        if mode == "rgb_array":
            frame = [remote.recv() for remote in self.remotes]
            return np.stack(frame)
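# --- Editor's note: an illustrative sketch, not part of this commit. ---
# Typical construction: one factory per parallel env, each cloudpickled into
# its own subprocess. `make_env` is a hypothetical gym-style env constructor;
# the rank=rank default binds the loop variable into each closure.
def _subproc_vec_env_example(make_env, n_rollout_threads=4):
    env_fns = [lambda rank=rank: make_env(rank) for rank in range(n_rollout_threads)]
    envs = SubprocVecEnv(env_fns)
    obs = envs.reset()  # stacked: (n_rollout_threads, ...)
    envs.close()
    return obs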
# single env
class DummyVecEnv(ShareVecEnv):
    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        ShareVecEnv.__init__(self, len(
            env_fns), env.observation_space, env.share_observation_space, env.action_space)
        self.actions = None

    def step_async(self, actions):
        self.actions = actions
......
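# --- Editor's note: an illustrative sketch, not part of this commit. ---
# After this commit DummyVecEnv is the only wrapper left: every env runs in
# the main process, which is simpler to debug at the cost of parallelism.
# Assumes the usual reset()/close() methods in the collapsed part of the class.
def _dummy_vec_env_example(make_env):
    envs = DummyVecEnv([lambda: make_env(0)])  # a single in-process env
    obs = envs.reset()
    envs.close()
    return obs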
@@ -14,7 +14,7 @@ import numpy as np
from pathlib import Path
import torch
from config import get_config
from envs.env_wrappers import SubprocVecEnv, DummyVecEnv
from envs.env_wrappers import DummyVecEnv
"""Train script for MPEs."""
......