gomoku_rl.policy.ppo module
- class gomoku_rl.policy.ppo.PPO(cfg: DictConfig, action_spec: DiscreteTensorSpec, observation_spec: TensorSpec, device: device | str | int | None = 'cuda')
Bases: Policy
- learn(data: TensorDict)
Updates the policy based on a batch of data.
- Parameters:
data (TensorDict) – A batch of data, typically including observations, actions, rewards, and next observations.
- Returns:
A dictionary containing information about the learning step, such as loss values.
- Return type:
Dict
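The kind of loss value such a dictionary might carry can be illustrated with PPO's clipped surrogate objective. The following is a minimal pure-Python sketch, not the gomoku_rl implementation: the function name `clipped_surrogate_loss`, the `clip_param` default of 0.2, and the keys `policy_loss` and `clip_fraction` are all assumptions chosen for illustration.

```python
import math
from typing import Dict, List


def clipped_surrogate_loss(
    new_log_probs: List[float],
    old_log_probs: List[float],
    advantages: List[float],
    clip_param: float = 0.2,  # assumed value; the real setting would come from cfg
) -> Dict[str, float]:
    """Compute PPO's clipped surrogate loss over a batch (hypothetical helper)."""
    losses = []
    clipped = 0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        # Probability ratio between the updated policy and the behavior policy.
        ratio = math.exp(new_lp - old_lp)
        # Restrict the ratio to [1 - clip_param, 1 + clip_param].
        clipped_ratio = min(max(ratio, 1.0 - clip_param), 1.0 + clip_param)
        # Keep the more pessimistic objective, then negate it to obtain a loss.
        losses.append(-min(ratio * adv, clipped_ratio * adv))
        if ratio != clipped_ratio:
            clipped += 1
    n = len(losses)
    # Illustrative keys only; the actual keys returned by learn() may differ.
    return {"policy_loss": sum(losses) / n, "clip_fraction": clipped / n}


info = clipped_surrogate_loss(
    new_log_probs=[math.log(0.5), math.log(0.9)],
    old_log_probs=[math.log(0.5), math.log(0.3)],
    advantages=[1.0, 1.0],
)
```

In this toy batch the first sample's ratio is 1 and passes through unchanged, while the second sample's ratio of about 3 is clipped to 1.2, so half of the samples count toward `clip_fraction`.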
- load_state_dict(state_dict: Dict)
Loads the policy state from a dictionary.
- Parameters:
state_dict (Dict) – The policy state to restore.
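A save-and-restore round trip is the usual way a method like this gets used. The sketch below illustrates the pattern with a hypothetical stand-in class, `TinyPolicy`; it is not the PPO class, and the companion `state_dict()` method and the keys `weight` and `steps` are assumptions, since this page documents only `load_state_dict`.

```python
from typing import Dict


class TinyPolicy:
    """Hypothetical stand-in used only to show the round-trip pattern."""

    def __init__(self, weight: float = 0.0, steps: int = 0) -> None:
        self.weight = weight
        self.steps = steps

    def state_dict(self) -> Dict:
        # Assumed companion method: export everything needed to restore the policy.
        return {"weight": self.weight, "steps": self.steps}

    def load_state_dict(self, state_dict: Dict) -> None:
        # Overwrite the current state with the values from the dictionary.
        self.weight = state_dict["weight"]
        self.steps = state_dict["steps"]


trained = TinyPolicy(weight=0.75, steps=1000)
restored = TinyPolicy()  # fresh instance with default state
restored.load_state_dict(trained.state_dict())
```

After the call, `restored` holds the same values as `trained` without sharing any object with it.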