Module ai.rl.a3c.trainer

Trainer module for A3C.

"""Trainer module for A3C."""

from ._config import Config
from ._trainer import Trainer


__all__ = ["Config", "Trainer"]

Classes

class Config

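Trainer configuration. Trainer.start reads at least two of its fields: workers (the number of worker processes to spawn) and train_time (the training duration in seconds, measured with time.perf_counter).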
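This page does not list the fields of Config. For tests or experiments, a stand-in that models only the two fields Trainer.start is seen to read might look like the sketch below. StubConfig and its defaults are hypothetical; the real Config has its own definition and may define more fields.

```python
from dataclasses import dataclass


@dataclass
class StubConfig:
    """Hypothetical stand-in for ai.rl.a3c.trainer.Config.

    Models only the two fields that Trainer.start reads; the real Config
    may define more.
    """

    workers: int = 4  # hypothetical default: number of worker processes
    train_time: float = 60.0  # hypothetical default: training time in seconds

    def __post_init__(self) -> None:
        # Reject values that would make Trainer.start spawn no workers
        # or stop them almost as soon as they have started.
        if self.workers < 1:
            raise ValueError("workers must be at least 1")
        if self.train_time <= 0:
            raise ValueError("train_time must be positive")
```

For example, StubConfig(workers=8, train_time=3600.0) describes eight workers training for one hour.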

class Trainer (config: Config, environment: Factory, network: torch.nn.modules.module.Module, optimizer_class: torch.optim.optimizer.Optimizer, optimizer_params: Mapping[str, Any])

A3C (Asynchronous Advantage Actor-Critic) trainer. Spawns config.workers worker processes, each of which gets a copy of the network.

Args

config : trainer.Config
Trainer configuration.
environment : environments.Factory
Environment factory.
network : nn.Module
Network with two outputs: policy logits and state value.
optimizer_class : optim.Optimizer
Optimizer class (the class itself, for example torch.optim.Adam, not an instance).
optimizer_params : Mapping[str, Any]
Keyword arguments passed to the optimizer class when it is instantiated.
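For context on the network argument: in the standard A3C formulation (not necessarily exactly what this package's workers compute), the state-value output bootstraps n-step discounted returns, the advantage A_t = R_t - V(s_t) weights the log-probability of the taken action in the policy loss, and the value output is regressed onto R_t. The sketch below uses plain Python lists in place of tensors; softmax, n_step_returns, advantages and a3c_losses are hypothetical helper names.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Converts policy logits into action probabilities (max-shifted for stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def n_step_returns(rewards: Sequence[float], bootstrap_value: float,
                   gamma: float, terminal: bool) -> List[float]:
    """Discounted returns R_t = r_t + gamma * R_{t+1}, computed backwards.

    The tail is bootstrapped with the value head's estimate V(s_T), or 0.0 if
    the rollout ended in a terminal state.
    """
    ret = 0.0 if terminal else bootstrap_value
    returns = []
    for r in reversed(rewards):
        ret = r + gamma * ret
        returns.append(ret)
    returns.reverse()
    return returns


def advantages(returns: Sequence[float], values: Sequence[float]) -> List[float]:
    """Advantage A_t = R_t - V(s_t), using the value head's outputs V(s_t)."""
    return [ret - v for ret, v in zip(returns, values)]


def a3c_losses(logits_per_step: Sequence[Sequence[float]], actions: Sequence[int],
               returns: Sequence[float], values: Sequence[float]) -> Tuple[float, float]:
    """Policy loss sum(-log pi(a_t|s_t) * A_t) and value loss sum(0.5 * A_t**2).

    In a real implementation A_t is treated as a constant in the policy term
    (no gradient flows through it); the usual entropy bonus is omitted here.
    """
    policy_loss = 0.0
    value_loss = 0.0
    for logits, a, adv in zip(logits_per_step, actions, advantages(returns, values)):
        policy_loss -= math.log(softmax(logits)[a]) * adv
        value_loss += 0.5 * adv ** 2
    return policy_loss, value_loss
```

For example, with rewards [1.0, 0.0, 1.0], a bootstrap value of 2.0 and gamma = 0.5, n_step_returns gives [1.5, 1.0, 2.0]; if the rollout had ended in a terminal state it would give [1.25, 0.5, 1.0].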

Methods

def start(self)

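Starts the training and blocks until it has finished. Spawns config.workers worker processes, waits until config.train_time seconds have elapsed, then terminates the workers and the logger and joins them.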
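How long does start actually block? Its loop calls sleep(5.0) between clock checks, and start_time is taken only after every worker has been started, so the loop ends at the first check at or past config.train_time: up to about 5 seconds later than requested. The sketch below reproduces that loop with an injectable clock so the overshoot can be checked deterministically. wait_for, cap_last_sleep and FakeClock are hypothetical names, and the capped variant is a possible alternative, not the package's behavior.

```python
from time import perf_counter, sleep
from typing import Callable


def wait_for(duration: float, poll_interval: float = 5.0,
             clock: Callable[[], float] = perf_counter,
             sleeper: Callable[[float], None] = sleep,
             cap_last_sleep: bool = False) -> float:
    """Blocks until `duration` seconds have elapsed; returns the elapsed time.

    With cap_last_sleep=False this mirrors the loop in Trainer.start, which
    re-checks the clock only every `poll_interval` seconds.
    """
    start = clock()
    while (elapsed := clock() - start) < duration:
        step = poll_interval
        if cap_last_sleep:
            # Never sleep past the deadline.
            step = min(poll_interval, duration - elapsed)
        sleeper(step)
    return clock() - start


class FakeClock:
    """Hypothetical deterministic clock: sleep() advances time instantly."""

    def __init__(self) -> None:
        self.now = 0.0

    def __call__(self) -> float:
        return self.now

    def sleep(self, seconds: float) -> None:
        self.now += seconds
```

With a requested duration of 12 seconds, the uncapped loop returns after 15.0 simulated seconds, while cap_last_sleep=True returns after exactly 12.0. A duration that is a multiple of the poll interval (for example 10 seconds) does not overshoot.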

def start(self):
    """Starts the training and blocks until it has finished."""

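    # Start the logger first; its port is handed to every worker.
    log_port = self._logger.start()

    # One worker per configured process, each given the network and the
    # optimizer class and parameters.
    self._workers = [
        Worker(
            self._config,
            self._environment,
            self._network,
            self._optimizer_class,
            self._optimizer_params,
            log_port,
        )
        for _ in range(self._config.workers)
    ]

    for worker in self._workers:
        worker.start()

    # Block until train_time seconds have passed, re-checking every 5 s.
    start_time = perf_counter()
    while perf_counter() - start_time < self._config.train_time:
        sleep(5.0)

    # Stop all workers and the logger, then wait for every process to exit.
    for worker in self._workers:
        worker.terminate()
    self._logger.terminate()
    for worker in self._workers:
        worker.join()
    self._logger.join()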
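start shuts workers down with terminate() instead of asking them to finish. Assuming Worker behaves like multiprocessing.Process (this page does not show its base class), terminate() sends SIGTERM on POSIX: the child's finally clauses and exit handlers do not run, and after join() its exitcode is the negative signal number. The Python documentation also warns that a pipe or queue in use by a terminated process may become corrupted. The sketch below reproduces the start, terminate, join lifecycle with a stand-in worker; _busy_worker and run_and_stop are hypothetical names, and because it uses the fork start method it is POSIX-only.

```python
import multiprocessing as mp
import time
from typing import List, Optional


def _busy_worker() -> None:
    """Hypothetical stand-in for a training worker: loops until killed."""
    try:
        while True:
            time.sleep(0.05)
    finally:
        # Not reached when the process is stopped with terminate()/SIGTERM.
        print("cleanup")


def run_and_stop(n_workers: int, run_seconds: float) -> List[Optional[int]]:
    """Start, wait, terminate, join: the same lifecycle as Trainer.start."""
    ctx = mp.get_context("fork")  # POSIX-only start method
    workers = [ctx.Process(target=_busy_worker) for _ in range(n_workers)]
    for w in workers:
        w.start()
    time.sleep(run_seconds)
    for w in workers:
        w.terminate()  # sends SIGTERM on POSIX
    for w in workers:
        w.join()  # reap the process so exitcode is set
    return [w.exitcode for w in workers]
```

On Linux, where SIGTERM is signal 15, run_and_stop(2, 0.2) returns [-15, -15] and "cleanup" is never printed. Anything a worker has buffered but not yet written or sent when it is terminated is lost.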