Module ai.rl.dqn.rainbow.trainers.seed
Distributed trainer, based on the SEED architecture.
Expand source code
"""Distributed trainer, based on the SEED architecture."""
from ._actor import Actor
from ._trainer import Trainer
from ._config import Config
__all__ = ["Actor", "Trainer", "Config"]
Classes
class Actor (agent_config: AgentConfig, config: Config, environment: Factory, data_port: int, router_port: int, logging_client: Client = None, daemon: bool = True)
-
Process objects represent activity that is run in a separate process.
The class is analogous to threading.Thread.
Ancestors
- multiprocessing.context.Process
- multiprocessing.process.BaseProcess
Methods
def run(self)
-
Method to be run in sub-process; can be overridden in sub-class
Expand source code
def run(self):
    """Entry point of the actor process.

    Spawns ``actor_threads`` ActorThread workers — each constructed with
    the args/kwargs this process was created with — starts them all, then
    parks the main thread in a sleep loop so the process stays alive while
    the worker threads do the actual acting.
    """
    self._threads.extend(
        ActorThread(*self._args, **self._kwargs)
        for _ in range(self._config.actor_threads)
    )
    for worker in self._threads:
        worker.start()
    # Idle forever; the process is torn down externally via terminate().
    while True:
        time.sleep(5.0)
class Config (max_environment_steps: int = -1, n_step: int = 3, epsilon: float = 0.1, actor_processes: int = 1, actor_threads: int = 4, inference_servers: int = 1, broadcast_period: float = 2.5, inference_batchsize: int = 4, inference_delay: float = 0.1, inference_device: torch.device = device(type='cpu'), minimum_buffer_size: int = 1000, max_train_frequency: float = -1)
-
Trainer config.
Class variables
var actor_processes : int
var actor_threads : int
var broadcast_period : float
var epsilon : float
var inference_batchsize : int
var inference_delay : float
var inference_device : torch.device
var inference_servers : int
var max_environment_steps : int
var max_train_frequency : float
var minimum_buffer_size : int
var n_step : int
class Trainer (agent: Agent, config: Config, environment: Factory)
-
SEED trainer.
Methods
def start(self, duration: float)
-
Starts training, and blocks until completed.
Args
duration : float
Training duration in seconds.
Expand source code
def start(self, duration: float):
    """Starts training, and blocks until completed.

    Args:
        duration (float): Training duration in seconds.
    """
    # Inference request routing: actors talk to the router side, the
    # inference servers consume from the dealer side.
    proxy = seed.InferenceProxy()
    router_port, dealer_port = proxy.start()

    # SUB socket collecting transition data published by the actors.
    data_sub = zmq.Context.instance().socket(zmq.SUB)
    data_sub.subscribe("")
    data_port = data_sub.bind_to_random_port("tcp://*")

    # Periodically pushes fresh model weights out to the inference servers.
    broadcaster = seed.Broadcaster(
        self._agent.model_instance, self._config.broadcast_period
    )
    broadcast_port = broadcaster.start()

    logger = create_logger()
    logger_port = logger.start()
    self._agent.set_logging_client(logging.Client("localhost", logger_port))

    stop_event = threading.Event()
    data_listening_thread = threading.Thread(
        target=data_listener,
        args=(self._agent, data_sub, logger_port, stop_event),
        daemon=True,
    )
    data_listening_thread.start()

    servers = [
        create_server(self, dealer_port, broadcast_port)
        for _ in range(self._config.inference_servers)
    ]
    for srv in servers:
        srv.start()
    # Allow some time to start server.
    time.sleep(5.0)

    actors = [
        create_actor(self, data_port, router_port, logger_port)
        for _ in range(self._config.actor_processes)
    ]
    for proc in actors:
        proc.start()

    training_thread = threading.Thread(
        target=trainer,
        args=(self._agent, self._config, stop_event, logger_port),
        daemon=True,
    )
    training_thread.start()

    # Allow some time to start actors.
    time.sleep(5.0)

    # Supervise until the requested duration has elapsed, reviving any
    # actor process that died along the way.
    started = time.perf_counter()
    while time.perf_counter() - started < duration:
        time.sleep(5.0)
        for i, proc in enumerate(actors):
            if not proc.is_alive():
                actors[i] = create_actor(
                    self, data_port, router_port, logger_port
                )
                actors[i].start()
                print("Restarted actor...")

    # Shutdown: signal background threads, kill child processes, then
    # wait for everything to wind down.
    stop_event.set()
    for proc in actors:
        proc.terminate()
    for srv in servers:
        srv.terminate()
    for proc in actors:
        proc.join()
    for srv in servers:
        srv.join()
    data_sub.close()
    training_thread.join()
    data_listening_thread.join()