Module ai.rl.a3c.trainer
Trainer module for A3C.
Expand source code
"""Trainer module for A3C."""
from ._config import Config
from ._trainer import Trainer
__all__ = ["Config", "Trainer"]
Classes
class Config
-
Trainer configuration.
class Trainer (config: Config, environment: Factory, network: torch.nn.modules.module.Module, optimizer_class: torch.optim.optimizer.Optimizer, optimizer_params: Mapping[str, Any])
-
A3C trainer. Spawns `config.workers` worker processes, each of which gets its own copy of the network.
Args
config
:trainer.Config
- Trainer configuration.
environment
:environments.Factory
- Environment factory.
network
:nn.Module
- Network with two outputs: policy logits and state value.
optimizer_class
:optim.Optimizer
- Optimizer class (not an instance); it is initialized with `optimizer_params`.
optimizer_params
:Mapping[str, Any]
- Keyword arguments sent to the optimizer class at initialization.
Methods
def start(self)
-
Starts the training and blocks until it has finished, i.e. until `config.train_time` seconds have elapsed.
Expand source code
def start(self):
    """Starts the training and blocks until it has finished.

    Starts the logger, then creates ``config.workers`` ``Worker`` objects.
    Each worker receives the config, environment factory, network, optimizer
    class/params and the port returned by the logger. All workers are started,
    and the call then blocks until ``config.train_time`` seconds (measured
    with ``perf_counter``) have elapsed. Finally every started worker and the
    logger are terminated and joined.

    Shutdown runs in a ``finally`` clause. An exception raised while starting
    workers or while waiting (e.g. ``KeyboardInterrupt`` during ``sleep``)
    therefore still stops the workers and the logger instead of leaving them
    running in the background.
    """
    log_port = self._logger.start()
    # Only workers whose start() succeeded are terminated/joined during
    # shutdown; NOTE(review): Worker is presumably process-like, where
    # terminate()/join() on a never-started instance fails.
    started = []
    try:
        self._workers = [
            Worker(
                self._config,
                self._environment,
                self._network,
                self._optimizer_class,
                self._optimizer_params,
                log_port,
            )
            for _ in range(self._config.workers)
        ]
        for worker in self._workers:
            worker.start()
            started.append(worker)

        # Poll every 5 s, but never sleep past the deadline, so training
        # stops close to train_time instead of up to 5 s late.
        end_time = perf_counter() + self._config.train_time
        remaining = end_time - perf_counter()
        while remaining > 0:
            sleep(min(5.0, remaining))
            remaining = end_time - perf_counter()
    finally:
        # Same shutdown order as before: terminate everything first, then
        # wait for each to exit.
        for worker in started:
            worker.terminate()
        self._logger.terminate()
        for worker in started:
            worker.join()
        self._logger.join()