gomoku_rl.policy.base module
- class gomoku_rl.policy.base.Policy(cfg: DictConfig, action_spec: DiscreteTensorSpec, observation_spec: TensorSpec, device='cuda')
Bases: ABC
- REGISTRY: dict[str, Type[Policy]] = {'DQN': <class 'gomoku_rl.policy.dqn.DQN'>, 'PPO': <class 'gomoku_rl.policy.ppo.PPO'>, 'dqn': <class 'gomoku_rl.policy.dqn.DQN'>, 'ppo': <class 'gomoku_rl.policy.ppo.PPO'>}
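A name-to-class mapping like REGISTRY is typically used to pick a policy class from a configuration string. The following is a minimal sketch of that lookup pattern using stand-in classes; how gomoku_rl actually populates the registry is not shown here, and `lookup_policy` is a hypothetical helper, not part of the library.

```python
from typing import Dict, Type


class Policy:
    """Simplified stand-in for gomoku_rl.policy.base.Policy (assumption)."""

    # Maps algorithm names to policy classes, mirroring the documented REGISTRY.
    REGISTRY: Dict[str, Type["Policy"]] = {}


class DQN(Policy):
    """Placeholder for the real DQN policy class."""


class PPO(Policy):
    """Placeholder for the real PPO policy class."""


# Register each class under its name and the lowercase variant,
# matching the four keys listed in the documented mapping.
for cls in (DQN, PPO):
    Policy.REGISTRY[cls.__name__] = cls
    Policy.REGISTRY[cls.__name__.lower()] = cls


def lookup_policy(name: str) -> Type[Policy]:
    """Hypothetical helper: resolve a policy class from its registered name."""
    try:
        return Policy.REGISTRY[name]
    except KeyError:
        known = sorted(Policy.REGISTRY)
        raise ValueError(f"Unknown policy {name!r}; expected one of {known}") from None
```

With this sketch, `lookup_policy("ppo")` and `lookup_policy("PPO")` resolve to the same class, and an unregistered name raises a `ValueError` that lists the valid keys.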
- abstract __call__(tensordict: TensorDict) -> TensorDict
Defines the computation performed at every call of the policy.
- Parameters:
tensordict (TensorDict) – Input tensor dictionary containing at least the observation data.
- Returns:
Output tensor dictionary containing at least the actions to be taken in the environment.
- Return type:
TensorDict
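The call contract above (observations in, actions out) can be illustrated without the real dependencies. This sketch uses a plain dict as a stand-in for TensorDict and a uniformly random policy; the key names "observation" and "action" and the class `RandomPolicy` are assumptions made for illustration, not names confirmed by this page.

```python
import random
from abc import ABC, abstractmethod
from typing import Any, Dict


class PolicySketch(ABC):
    """Simplified stand-in for the documented abstract Policy."""

    @abstractmethod
    def __call__(self, tensordict: Dict[str, Any]) -> Dict[str, Any]:
        """Map a batch of observations to a batch of actions."""


class RandomPolicy(PolicySketch):
    """Hypothetical policy that picks a uniformly random action per observation."""

    def __init__(self, num_actions: int, seed: int = 0) -> None:
        self.num_actions = num_actions
        self.rng = random.Random(seed)

    def __call__(self, tensordict: Dict[str, Any]) -> Dict[str, Any]:
        # "observation" is an assumed key holding one entry per environment.
        observations = tensordict["observation"]
        actions = [self.rng.randrange(self.num_actions) for _ in observations]
        # Return the input entries plus the chosen actions under an assumed key.
        return {**tensordict, "action": actions}


# Usage: two empty 15x15 boards flattened to 225 cells each.
policy = RandomPolicy(num_actions=225)
batch = {"observation": [[0] * 225, [0] * 225]}
output = policy(batch)
```

The output keeps the observation entry and adds one action per observation, which is the minimum the documented contract asks for.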
- abstract __init__(cfg: DictConfig, action_spec: DiscreteTensorSpec, observation_spec: TensorSpec, device='cuda')
Initializes the policy.
- Parameters:
cfg (DictConfig) – Configuration object containing policy-specific settings.
action_spec (DiscreteTensorSpec) – Specification of the action space.
observation_spec (TensorSpec) – Specification of the observation space.
device – The device (e.g., ‘cuda’ or ‘cpu’) where the policy’s tensors will be allocated. Defaults to ‘cuda’.
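Because the constructor and the other methods are abstract, Python refuses to instantiate the base class or any subclass that leaves a method unimplemented. The sketch below demonstrates this with a simplified stand-in interface; the argument types are loosened to plain Python objects, and `Incomplete`, `Complete`, and `can_instantiate` are hypothetical names used only for this illustration.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class PolicyInterface(ABC):
    """Simplified stand-in mirroring the four documented abstract methods."""

    @abstractmethod
    def __init__(self, cfg: Dict[str, Any], action_spec: Any,
                 observation_spec: Any, device: str = "cuda") -> None: ...

    @abstractmethod
    def __call__(self, tensordict: Dict[str, Any]) -> Dict[str, Any]: ...

    @abstractmethod
    def learn(self, data: Dict[str, Any]) -> Dict[str, Any]: ...

    @abstractmethod
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...


class Incomplete(PolicyInterface):
    """Implements only __init__, so it is still abstract."""

    def __init__(self, cfg, action_spec, observation_spec, device="cuda"):
        self.cfg = cfg
        self.device = device


class Complete(Incomplete):
    """Implements every remaining abstract method with trivial bodies."""

    def __call__(self, tensordict):
        return {**tensordict, "action": 0}

    def learn(self, data):
        return {}

    def load_state_dict(self, state_dict):
        pass


def can_instantiate(cls) -> bool:
    """Return True if cls can be constructed with placeholder arguments."""
    try:
        cls({"lr": 1e-3}, action_spec=None, observation_spec=None, device="cpu")
    except TypeError:
        return False
    return True
```

Passing `device="cpu"` explicitly, as in `can_instantiate`, is the usual way to override the 'cuda' default on machines without a GPU.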
- abstract learn(data: TensorDict) -> Dict
Updates the policy based on a batch of data.
- Parameters:
data (TensorDict) – A batch of data typically including observations, actions, rewards, and next observations.
- Returns:
A dictionary containing information about the learning step, such as loss values.
- Return type:
Dict
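The shape of a learning step (consume a batch, update internal parameters, return a dictionary of metrics) can be sketched with a deliberately simple update rule. The example below uses an incremental action-value update on a plain dict batch; it is not the DQN or PPO update used by the library, and the keys "action", "reward", and "loss" as well as the class `ValueTablePolicy` are assumptions.

```python
from typing import Any, Dict, List


class ValueTablePolicy:
    """Hypothetical learner keeping one value estimate per action."""

    def __init__(self, num_actions: int, lr: float = 0.5) -> None:
        self.q: List[float] = [0.0] * num_actions
        self.lr = lr

    def learn(self, data: Dict[str, Any]) -> Dict[str, float]:
        """Move each taken action's estimate toward its observed reward."""
        squared_errors = []
        for action, reward in zip(data["action"], data["reward"]):
            error = reward - self.q[action]
            squared_errors.append(error ** 2)
            # Incremental update: q <- q + lr * (reward - q)
            self.q[action] += self.lr * error
        # Report the mean squared error measured before each update.
        return {"loss": sum(squared_errors) / len(squared_errors)}


# Usage: a batch of three transitions (only actions and rewards are used here).
policy = ValueTablePolicy(num_actions=3, lr=0.5)
info = policy.learn({"action": [0, 0, 2], "reward": [1.0, 1.0, 0.0]})
```

Returning metrics as a dictionary keeps the training loop agnostic to the algorithm: it can log whatever keys a given policy chooses to report.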
- abstract load_state_dict(state_dict: Dict)
Loads the policy state from a dictionary.
- Parameters:
state_dict (Dict) – The policy state to restore.
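Restoring state is easiest to reason about as a round trip: export a dictionary from one policy, then load it into another. The sketch below assumes a matching `state_dict()` export method, which this page does not document, and uses a hypothetical `TinyPolicy` class with plain Python values in place of tensors.

```python
from typing import Any, Dict, List


class TinyPolicy:
    """Hypothetical policy whose whole state is a value table and a step counter."""

    def __init__(self, num_actions: int) -> None:
        self.q: List[float] = [0.0] * num_actions
        self.steps = 0

    def state_dict(self) -> Dict[str, Any]:
        # Assumed counterpart to load_state_dict; returns a copy of the state.
        return {"q": list(self.q), "steps": self.steps}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Replace the current state, rejecting tables of the wrong size."""
        if len(state_dict["q"]) != len(self.q):
            raise ValueError("state_dict does not match this policy's action count")
        self.q = list(state_dict["q"])
        self.steps = int(state_dict["steps"])


# Usage: copy the state of a trained policy into a fresh one.
trained = TinyPolicy(num_actions=3)
trained.q = [0.1, 0.7, -0.2]
trained.steps = 42

restored = TinyPolicy(num_actions=3)
restored.load_state_dict(trained.state_dict())
```

Copying the lists on both export and load means later updates to one policy do not silently change the other.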