Source code for gomoku_rl.policy.dqn

from typing import Dict

import torch
import torch.nn as nn
from torch.cuda import _device_t

from omegaconf import DictConfig
from tensordict import TensorDict

from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler, Sampler
from torchrl.data.tensor_specs import DiscreteTensorSpec, TensorSpec
from torchrl.objectives import DQNLoss, SoftUpdate

from .base import Policy
from .common import make_dqn_actor, get_optimizer, make_egreedy_actor

def get_replay_buffer(
    buffer_size: int,
    batch_size: int,
    sampler: Sampler | None = None,
    device: _device_t = None,
):
    storage = LazyTensorStorage(max_size=buffer_size, device=device)
    buffer = TensorDictReplayBuffer(
        storage=storage, batch_size=batch_size, sampler=sampler
    )
    return buffer
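# Usage sketch (not part of the original module): the buffer stores flat
# TensorDicts and returns `batch_size` transitions per sample() call.
# The keys and shapes below are illustrative only.
#
#     buffer = get_replay_buffer(buffer_size=100_000, batch_size=256,
#                                sampler=RandomSampler(), device="cpu")
#     buffer.extend(TensorDict({"observation": torch.zeros(512, 3, 15, 15)},
#                              batch_size=[512]))
#     batch = buffer.sample()  # TensorDict of 256 transitions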
class DQN(Policy):
    """Deep Q-Network policy with an epsilon-greedy exploration actor and a
    TensorDict replay buffer."""

    def __init__(
        self,
        cfg: DictConfig,
        action_spec: DiscreteTensorSpec,
        observation_spec: TensorSpec,
        device: _device_t = "cuda",
    ) -> None:
        super().__init__(cfg, action_spec, observation_spec, device)
        self.cfg: DictConfig = cfg
        self.device: _device_t = device

        self.actor = make_dqn_actor(cfg, action_spec, device)

        # Run the actor once on a dummy observation (with every action marked
        # valid) to build any lazily initialized modules.
        fake_input = observation_spec.zero()
        fake_input["action_mask"] = ~fake_input["action_mask"]
        with torch.no_grad():
            self.actor(fake_input)

        self.actor_explore = make_egreedy_actor(
            actor=self.actor,
            action_spec=action_spec,
            annealing_num_steps=cfg.annealing_num_steps,
            eps_init=cfg.eps_init,
            eps_end=cfg.eps_end,
        )

        self.max_grad_norm: float = cfg.max_grad_norm
        self.gamma = cfg.gamma
        self.batch_size: int = cfg.batch_size
        self.buffer_size: int = cfg.buffer_size
        # Optimization steps per batch collected (a.k.a. UPD, updates per data).
        self.n_optim: int = cfg.n_optim
        self.target_update_interval: int = cfg.target_update_interval

        self.loss_module = DQNLoss(self.actor, delay_value=True)
        self.loss_module.make_value_estimator(gamma=self.gamma)
        self.target_updater = SoftUpdate(self.loss_module, eps=0.995)

        self.optimizer = get_optimizer(
            self.cfg.optimizer, self.loss_module.parameters()
        )

        self.replay_buffer = get_replay_buffer(
            buffer_size=self.buffer_size,
            batch_size=self.batch_size,
            sampler=RandomSampler(),
            device=cfg.buffer_device,
        )

        self._eval = False

    def __call__(self, tensordict: TensorDict) -> TensorDict:
        if self._eval:
            return self.actor(tensordict)
        tensordict = self.actor_explore(tensordict)
        # Anneal epsilon one step per call to the exploration actor.
        self.actor_explore.module[-1].step()
        return tensordict
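    # Configuration sketch (values are illustrative, not project defaults):
    # the constructor itself reads only the keys below, plus `cfg.optimizer`,
    # which is forwarded to get_optimizer; make_dqn_actor receives the full
    # cfg and may read additional (network) keys.
    #
    #     cfg = OmegaConf.create({
    #         "eps_init": 1.0, "eps_end": 0.05, "annealing_num_steps": 100_000,
    #         "gamma": 0.99, "max_grad_norm": 1.0,
    #         "batch_size": 256, "buffer_size": 100_000, "buffer_device": "cpu",
    #         "n_optim": 8, "target_update_interval": 10,
    #         "optimizer": {...},
    #     })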
    def learn(self, data: TensorDict) -> Dict:
        # Drop transitions flagged as invalid before storing them.
        invalid = data.get("invalid", None)
        if invalid is not None:
            data = data[~invalid]
        data["next", "done"] = data["next", "done"].unsqueeze(-1)
        self.replay_buffer.extend(data.reshape(-1))

        # Do not start optimizing until the replay buffer is full.
        if len(self.replay_buffer) < self.buffer_size:
            return {}

        losses = []
        grad_norms = []
        for gradient_step in range(1, self.n_optim + 1):
            transition = self.replay_buffer.sample().to(self.device)
            loss: torch.Tensor = self.loss_module(transition)["loss"]
            losses.append(loss.clone().detach())
            self.optimizer.zero_grad()
            loss.backward()
            grad_norm = nn.utils.clip_grad_norm_(
                self.loss_module.parameters(), self.max_grad_norm
            )
            grad_norms.append(grad_norm)
            self.optimizer.step()
            # Softly update the target network every `target_update_interval`
            # gradient steps.
            if gradient_step % self.target_update_interval == 0:
                self.target_updater.step()

        avg_loss = torch.stack(losses).mean().item()
        avg_grad_norm = torch.stack(grad_norms).mean().item()
        return {
            "loss": avg_loss,
            "grad_norm": avg_grad_norm,
        }
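    # Training-loop sketch (`collector`, `env`, and `policy` are hypothetical
    # names, not part of this module): each collected TensorDict is filtered,
    # pushed into the replay buffer, and, once the buffer is full, used for
    # `n_optim` gradient steps.
    #
    #     policy = DQN(cfg, env.action_spec, env.observation_spec, device="cuda")
    #     policy.train()
    #     for data in collector:
    #         info = policy.learn(data)   # {} until the buffer is full
    #         if info:
    #             print(info["loss"], info["grad_norm"])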
    def eval(self):
        self.actor.eval()
        self._eval = True

    def train(self):
        self.actor.train()
        self._eval = False
    def state_dict(self) -> Dict:
        # Keep the loss module's parameters (including the target network);
        # discarding them would make it impossible to resume training.
        return {
            "actor": self.actor.state_dict(),
            "loss": self.loss_module.state_dict(),
        }
    def load_state_dict(self, state_dict: Dict):
        self.actor.load_state_dict(state_dict["actor"])
        self.loss_module.load_state_dict(state_dict["loss"])
        # Rebuild the value estimator and target updater so they reference
        # the restored parameters.
        self.loss_module.make_value_estimator(gamma=self.gamma)
        self.target_updater = SoftUpdate(self.loss_module, eps=0.995)
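# Checkpointing sketch (file name is illustrative): saving the full
# state_dict, including the loss module, is what allows training to resume
# after load_state_dict.
#
#     torch.save(policy.state_dict(), "dqn_checkpoint.pt")
#     policy.load_state_dict(torch.load("dqn_checkpoint.pt"))
#     policy.eval()   # switch to greedy action selection for evaluation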