Module ai.rl.dqn.rainbow.trainers.seed

Distributed trainer based on the SEED (Scalable, Efficient Deep-RL) architecture, in which actors send observations to centralized inference servers instead of running the model locally.

"""Distributed trainer, based on the SEED architecture."""


from ._actor import Actor
from ._trainer import Trainer
from ._config import Config


__all__ = ["Actor", "Trainer", "Config"]
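
A minimal usage sketch of how these pieces fit together. The agent and environment factory are constructed elsewhere in the ai.rl package and are not covered by this module, so placeholders stand in for them:

# Hedged sketch: Agent and the environment Factory are assumed to come
# from elsewhere in the ai.rl package; placeholders are used here.
from ai.rl.dqn.rainbow.trainers.seed import Config, Trainer

agent = ...        # ai.rl.dqn.rainbow.Agent instance (assumed)
env_factory = ...  # environment Factory (assumed)

config = Config(actor_processes=2, actor_threads=8)
trainer = Trainer(agent, config, env_factory)
trainer.start(duration=3600.0)  # blocks for one hour of training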

Classes

class Actor (agent_config: AgentConfig, config: Config, environment: Factory, data_port: int, router_port: int, logging_client: Client = None, daemon: bool = True)

Actor process. When started, it spawns config.actor_threads ActorThread workers; collected transitions are published to the trainer over the data port, while action selection is delegated to the central inference servers via the router port.

Ancestors

  • multiprocessing.context.Process
  • multiprocessing.process.BaseProcess

Methods

def run(self)

Entry point of the actor process: starts the configured number of actor threads, then sleeps to keep the process alive until the trainer terminates it.

def run(self):
    # Spawn the configured number of actor threads, forwarding this
    # process's constructor arguments to each ActorThread.
    for _ in range(self._config.actor_threads):
        self._threads.append(ActorThread(*self._args, **self._kwargs))
    for thread in self._threads:
        thread.start()

    # Keep the process alive; the trainer terminates it when training ends.
    while True:
        time.sleep(5.0)
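
Actors are normally spawned by Trainer.start(), but the constructor signature above also allows launching one by hand. A hedged sketch, assuming the ports of an already-running trainer are known (the values below are placeholders):

# Hedged sketch: agent_config, config, and env_factory are assumed to
# exist; the ports must match those bound by a running trainer.
actor = Actor(
    agent_config,     # AgentConfig instance (assumed)
    config,           # Config instance, e.g. Config(actor_threads=4)
    env_factory,      # environment Factory (assumed)
    data_port=6000,   # placeholder: the trainer's data SUB port
    router_port=6001, # placeholder: the inference proxy's router port
)
actor.start()         # run() then spawns config.actor_threads ActorThreads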
class Config (max_environment_steps: int = -1, n_step: int = 3, epsilon: float = 0.1, actor_processes: int = 1, actor_threads: int = 4, inference_servers: int = 1, broadcast_period: float = 2.5, inference_batchsize: int = 4, inference_delay: float = 0.1, inference_device: torch.device = device(type='cpu'), minimum_buffer_size: int = 1000, max_train_frequency: float = -1)

Trainer config.

Class variables

var actor_processes : int
var actor_threads : int
var broadcast_period : float
var epsilon : float
var inference_batchsize : int
var inference_delay : float
var inference_device : torch.device
var inference_servers : int
var max_environment_steps : int
var max_train_frequency : float
var minimum_buffer_size : int
var n_step : int
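
Every field has a default (see the signature above), so a Config is typically built by overriding only a few values. An illustrative sketch, not a recommended setting; the meanings noted as assumed are inferred from the field names:

import torch

config = Config(
    n_step=3,                               # n-step return horizon
    epsilon=0.1,                            # epsilon-greedy exploration rate
    actor_processes=2,                      # number of Actor processes
    actor_threads=8,                        # ActorThreads per process
    inference_batchsize=16,                 # states per inference batch (assumed)
    inference_device=torch.device("cuda"),  # device used by inference servers
    minimum_buffer_size=1000,               # data required before training (assumed)
)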
class Trainer (agent: Agent, config: Config, environment: Factory)

SEED trainer. Orchestrates the inference servers, actor processes, data listener, and training loop (see start()).

Methods

def start(self, duration: float)

Starts training and blocks until completed.

Args

duration : float
Training duration in seconds.
def start(self, duration: float):
    """Starts training and blocks until completed.

    Args:
        duration (float): Training duration in seconds.
    """
    # Router/dealer proxy that relays inference requests from actors to
    # the inference servers.
    proxy = seed.InferenceProxy()
    router_port, dealer_port = proxy.start()

    # SUB socket on which actors publish collected transitions.
    data_sub = zmq.Context.instance().socket(zmq.SUB)
    data_sub.subscribe("")
    data_port = data_sub.bind_to_random_port("tcp://*")

    # Periodically broadcasts updated model weights to the inference servers.
    broadcaster = seed.Broadcaster(self._agent.model_instance, self._config.broadcast_period)
    broadcast_port = broadcaster.start()

    logger = create_logger()
    logger_port = logger.start()

    self._agent.set_logging_client(logging.Client("localhost", logger_port))

    stop_event = threading.Event()

    # Consumes transitions published by the actors and hands them to the agent.
    data_listening_thread = threading.Thread(
        target=data_listener,
        args=(self._agent, data_sub, logger_port, stop_event),
        daemon=True,
    )
    data_listening_thread.start()

    servers = [
        create_server(self, dealer_port, broadcast_port)
        for _ in range(self._config.inference_servers)
    ]
    for server in servers:
        server.start()

    # Allow some time for the servers to start.
    time.sleep(5.0)

    actors = [
        create_actor(self, data_port, router_port, logger_port)
        for _ in range(self._config.actor_processes)
    ]
    for actor in actors:
        actor.start()

    # Runs the agent's optimization loop until the stop event is set.
    training_thread = threading.Thread(
        target=trainer,
        args=(self._agent, self._config, stop_event, logger_port),
        daemon=True,
    )
    training_thread.start()

    # Allow some time for the actors to start.
    time.sleep(5.0)

    # Supervision loop: run for `duration` seconds, replacing any actor
    # process that has died.
    start = time.perf_counter()
    while time.perf_counter() - start < duration:
        time.sleep(5.0)

        for i in range(len(actors)):
            if not actors[i].is_alive():
                actors[i] = create_actor(self, data_port, router_port, logger_port)
                actors[i].start()
                print("Restarted actor...")

    # Shut down: signal the worker threads, then terminate and join the
    # actor and server processes.
    stop_event.set()

    for actor in actors:
        actor.terminate()
    for server in servers:
        server.terminate()
    for actor in actors:
        actor.join()
    for server in servers:
        server.join()
    data_sub.close()

    training_thread.join()
    data_listening_thread.join()
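
The while-loop in start() doubles as a lightweight supervisor: any actor process that dies is replaced with a fresh one. The same pattern in isolation, as a runnable sketch (the sleeping worker is a hypothetical stand-in for an Actor):

import multiprocessing as mp
import time

def make_worker() -> mp.Process:
    # Hypothetical stand-in: a real deployment would construct an Actor here.
    return mp.Process(target=time.sleep, args=(20.0,), daemon=True)

if __name__ == "__main__":
    workers = [make_worker() for _ in range(4)]
    for worker in workers:
        worker.start()

    deadline = time.perf_counter() + 60.0
    while time.perf_counter() < deadline:
        time.sleep(5.0)
        # Replace any worker that exited, mirroring Trainer.start().
        for i, worker in enumerate(workers):
            if not worker.is_alive():
                workers[i] = make_worker()
                workers[i].start()

    # Terminate then join, mirroring the trainer's shutdown sequence.
    for worker in workers:
        worker.terminate()
    for worker in workers:
        worker.join()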