gomoku_rl.policy.base module

class gomoku_rl.policy.base.Policy(cfg: DictConfig, action_spec: DiscreteTensorSpec, observation_spec: TensorSpec, device='cuda')

Bases: ABC
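
Policy derives from ABC and marks every method abstract, so Python raises TypeError when code tries to instantiate a subclass that leaves any of them unimplemented. The stand-alone sketch below reproduces that rule with a hypothetical MiniPolicy that declares the same abstract methods as Policy. Plain dicts stand in for TensorDict so the example runs with the standard library alone.

```python
from abc import ABC, abstractmethod


class MiniPolicy(ABC):
    """Hypothetical stand-in that mirrors the abstract interface of Policy."""

    @abstractmethod
    def __init__(self): ...

    @abstractmethod
    def __call__(self, tensordict): ...

    @abstractmethod
    def learn(self, data): ...

    @abstractmethod
    def state_dict(self): ...

    @abstractmethod
    def load_state_dict(self, state_dict): ...

    @abstractmethod
    def train(self): ...

    @abstractmethod
    def eval(self): ...


class Incomplete(MiniPolicy):
    # Implements only __call__; the other abstract methods remain unimplemented.
    def __call__(self, tensordict):
        return tensordict


class Complete(MiniPolicy):
    # Implements every abstract method, so instantiation is allowed.
    def __init__(self):
        self.training = True

    def __call__(self, tensordict):
        return tensordict

    def learn(self, data):
        return {}

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass

    def train(self):
        self.training = True

    def eval(self):
        self.training = False


def can_instantiate(cls) -> bool:
    """Return True if cls() succeeds, False if the ABC machinery blocks it."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(can_instantiate(Incomplete), can_instantiate(Complete))  # False True
```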

REGISTRY: dict[str, Type[Policy]] = {'DQN': <class 'gomoku_rl.policy.dqn.DQN'>, 'PPO': <class 'gomoku_rl.policy.ppo.PPO'>, 'dqn': <class 'gomoku_rl.policy.dqn.DQN'>, 'ppo': <class 'gomoku_rl.policy.ppo.PPO'>}
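
The REGISTRY value above maps both 'DQN' and 'dqn' (and 'PPO' and 'ppo') to the same class, which suggests that each subclass is registered under its class name and a lowercase alias. The documentation does not show how the registry is populated, so the sketch below is only one plausible mechanism: it assumes registration happens in __init_subclass__, and the names SketchPolicy and get_policy_class are hypothetical.

```python
from abc import ABC
from typing import Dict, Type


class SketchPolicy(ABC):
    # Hypothetical registry; the real module's population logic is not documented.
    REGISTRY: Dict[str, Type["SketchPolicy"]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register under the class name and a lowercase alias, matching the
        # 'DQN'/'dqn' and 'PPO'/'ppo' pairs shown in Policy.REGISTRY.
        SketchPolicy.REGISTRY[cls.__name__] = cls
        SketchPolicy.REGISTRY[cls.__name__.lower()] = cls


class DQN(SketchPolicy):
    pass


class PPO(SketchPolicy):
    pass


def get_policy_class(name: str) -> Type[SketchPolicy]:
    """Hypothetical helper: look up a policy class by its registered name."""
    try:
        return SketchPolicy.REGISTRY[name]
    except KeyError:
        known = sorted(SketchPolicy.REGISTRY)
        raise ValueError(f"unknown policy {name!r}; known: {known}") from None


print(sorted(SketchPolicy.REGISTRY))  # ['DQN', 'PPO', 'dqn', 'ppo']
```

With a lookup of this kind, a configuration can name the algorithm in either case and the caller constructs the returned class with the usual (cfg, action_spec, observation_spec, device) arguments.
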

abstract __call__(tensordict: TensorDict) → TensorDict

Defines the computation performed at every call of the policy.

Parameters:

tensordict (TensorDict) – Input tensor dictionary containing at least the observation data.

Returns:

Output tensor dictionary containing at least the actions to be taken in the environment.

Return type:

TensorDict
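
As a minimal sketch of this contract, the hypothetical policy below reads an action mask from its input, picks a random legal move, and writes the move back under an output key. A plain dict stands in for TensorDict, and the key names "observation", "action_mask", and "action" are assumptions rather than names confirmed by gomoku_rl.

```python
import random
from typing import Dict, List


class RandomLegalPolicy:
    """Hypothetical policy illustrating the __call__ contract.

    A plain dict stands in for TensorDict; the keys are assumed names."""

    def __init__(self, seed: int = 0):
        self.rng = random.Random(seed)

    def __call__(self, tensordict: Dict) -> Dict:
        mask: List[bool] = tensordict["action_mask"]
        # Indices of board cells that are still legal moves.
        legal = [i for i, ok in enumerate(mask) if ok]
        # Write the chosen action into the same dictionary and return it,
        # so downstream code sees both the inputs and the new output.
        tensordict["action"] = self.rng.choice(legal)
        return tensordict


policy = RandomLegalPolicy()
mask = [False, True, False, True, False, False, False, False, False]  # 3x3 toy board
td = {"observation": [0] * 9, "action_mask": mask}
out = policy(td)
print(out["action"] in (1, 3))  # True
```

Writing the output into the input dictionary and returning it follows the usual TensorDict style, in which a module adds its results as new keys instead of building a fresh container.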

abstract __init__(cfg: DictConfig, action_spec: DiscreteTensorSpec, observation_spec: TensorSpec, device='cuda')

Initializes the policy.

Parameters:
  • cfg (DictConfig) – Configuration object containing policy-specific settings.

  • action_spec (DiscreteTensorSpec) – Specification of the action space.

  • observation_spec (TensorSpec) – Specification of the observation space.

  • device – The device (e.g., ‘cuda’ or ‘cpu’) where the policy’s tensors will be allocated. Defaults to ‘cuda’.
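
A concrete subclass typically keeps the specs and device and reads its hyperparameters from cfg. The sketch below is hypothetical throughout: the config keys "lr" and "gamma", the stand-in FakeDiscreteSpec with an n attribute, and the 15 x 15 board size are illustrative assumptions. A plain dict replaces DictConfig, which also supports .get(key, default).

```python
from dataclasses import dataclass
from typing import Any, Mapping


@dataclass
class FakeDiscreteSpec:
    """Stand-in for DiscreteTensorSpec; only the number of actions is modeled."""
    n: int


class SketchDQN:
    """Hypothetical subclass constructor; config keys are illustrative."""

    def __init__(self, cfg: Mapping[str, Any], action_spec: FakeDiscreteSpec,
                 observation_spec: Any, device: str = "cuda"):
        self.action_spec = action_spec
        self.observation_spec = observation_spec
        self.device = device
        # Read hyperparameters with fallbacks for keys the config omits.
        self.lr = float(cfg.get("lr", 1e-3))
        self.gamma = float(cfg.get("gamma", 0.99))
        # One discrete action per board cell.
        self.num_actions = action_spec.n


board_size = 15  # assumed board size for the example
policy = SketchDQN({"lr": 3e-4}, FakeDiscreteSpec(n=board_size * board_size),
                   observation_spec=None, device="cpu")
print(policy.num_actions, policy.lr, policy.gamma)  # 225 0.0003 0.99
```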

abstract eval()

Sets the policy to evaluation mode.
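
What "evaluation mode" changes is left to each subclass. One common pattern, assumed here and not confirmed for DQN or PPO in gomoku_rl, is to explore during training and act greedily during evaluation. The hypothetical policy below toggles a training flag in train() and eval() and uses it in an illustrative act() method.

```python
import random
from typing import List


class ModeAwarePolicy:
    """Hypothetical policy: epsilon-greedy while training, greedy in eval mode."""

    def __init__(self, epsilon: float = 0.1, seed: int = 0):
        self.epsilon = epsilon
        self.rng = random.Random(seed)
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    def act(self, q_values: List[float]) -> int:
        # Explore only in training mode, with probability epsilon.
        if self.training and self.rng.random() < self.epsilon:
            return self.rng.randrange(len(q_values))
        # Otherwise exploit: index of the largest value.
        return max(range(len(q_values)), key=q_values.__getitem__)


policy = ModeAwarePolicy(epsilon=1.0)
policy.eval()
print(policy.act([0.1, 0.9, 0.3]))  # 1: evaluation mode always picks the argmax
```

In a PyTorch-based subclass these methods would plausibly also call train() or eval() on the underlying torch.nn.Module objects, which switches layers such as dropout and batch normalization between their training and inference behavior.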

abstract learn(data: TensorDict) → Dict

Updates the policy based on a batch of data.

Parameters:

data (TensorDict) – A batch of data typically including observations, actions, rewards, and next observations.

Returns:

A dictionary containing information about the learning step, such as loss values.

Return type:

Dict
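
The sketch below shows only the shape of this contract: learn() consumes a batch and returns a dictionary of scalar metrics that a caller can log. The learner itself is a hypothetical toy that moves a single estimate toward the batch's mean reward, and the "reward" key is an assumed name.

```python
from typing import Dict, List


class MeanRewardPolicy:
    """Hypothetical toy learner; only the dict-of-metrics return value
    reflects the documented learn() contract."""

    def __init__(self, lr: float = 0.5):
        self.value = 0.0
        self.lr = lr

    def learn(self, data: Dict[str, List[float]]) -> Dict[str, float]:
        rewards = data["reward"]
        # Mean squared error of the current estimate against the batch.
        loss = sum((r - self.value) ** 2 for r in rewards) / len(rewards)
        target = sum(rewards) / len(rewards)
        # Step the estimate part of the way toward the batch mean.
        self.value += self.lr * (target - self.value)
        return {"loss": loss, "value": self.value}


policy = MeanRewardPolicy()
info = policy.learn({"reward": [1.0, 1.0]})
print(info)  # {'loss': 1.0, 'value': 0.5}
```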

abstract load_state_dict(state_dict: Dict)

Loads the policy state from a dictionary.

Parameters:

state_dict (Dict) – The state of the policy, typically one previously returned by state_dict().
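
The documentation does not say how a subclass should treat mismatched keys. A defensive choice, assumed here and similar in spirit to the strict=True option of torch.nn.Module.load_state_dict, is to reject a dictionary with missing or unexpected keys before changing any state. The class below is a hypothetical sketch of that behavior.

```python
from typing import Any, Dict


class StrictStatePolicy:
    """Hypothetical policy whose load_state_dict rejects mismatched keys."""

    def __init__(self):
        self.params: Dict[str, Any] = {"weight": 0.0, "bias": 0.0}

    def state_dict(self) -> Dict[str, Any]:
        # Shallow copy so callers cannot mutate internal state through it.
        return dict(self.params)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        missing = self.params.keys() - state_dict.keys()
        unexpected = state_dict.keys() - self.params.keys()
        # Validate first, so a bad dictionary leaves the policy unchanged.
        if missing or unexpected:
            raise KeyError(f"missing={sorted(missing)}, unexpected={sorted(unexpected)}")
        self.params.update(state_dict)


policy = StrictStatePolicy()
policy.load_state_dict({"weight": 1.5, "bias": -0.5})
print(policy.state_dict())  # {'weight': 1.5, 'bias': -0.5}
```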

abstract state_dict() → Dict

Returns the state of the policy as a dictionary.

Returns:

The state of the policy.

Return type:

Dict
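
Together, state_dict() and load_state_dict() support checkpointing: save the dictionary from one instance and load it into a fresh one. A PyTorch-based subclass would more likely use torch.save and torch.load. The stdlib sketch below uses pickle instead, with a hypothetical CounterPolicy whose whole state is a step counter and a weight list.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


class CounterPolicy:
    """Hypothetical policy whose entire state is a step count and weights."""

    def __init__(self):
        self.steps = 0
        self.weights = [0.0, 0.0]

    def state_dict(self) -> Dict[str, Any]:
        return {"steps": self.steps, "weights": list(self.weights)}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.steps = state_dict["steps"]
        self.weights = list(state_dict["weights"])


def save_checkpoint(policy, path: str) -> None:
    with open(path, "wb") as f:
        pickle.dump(policy.state_dict(), f)


def load_checkpoint(policy, path: str) -> None:
    with open(path, "rb") as f:
        policy.load_state_dict(pickle.load(f))


src = CounterPolicy()
src.steps, src.weights = 42, [0.5, -1.0]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "policy.pkl")
    save_checkpoint(src, path)
    dst = CounterPolicy()
    load_checkpoint(dst, path)
print(dst.state_dict())  # {'steps': 42, 'weights': [0.5, -1.0]}
```

Unpickling can execute arbitrary code, so checkpoints loaded this way should come only from trusted sources.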

abstract train()

Sets the policy to training mode.
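
A plausible outer loop, assumed here and not taken from gomoku_rl, switches to training mode before each learn() call and to evaluation mode before periodic evaluation. The hypothetical RecordingPolicy below only records which methods were called, so the resulting order is easy to check.

```python
from typing import Dict, List


class RecordingPolicy:
    """Hypothetical stand-in that records which interface methods were called."""

    def __init__(self):
        self.calls: List[str] = []

    def train(self):
        self.calls.append("train")

    def eval(self):
        self.calls.append("eval")

    def learn(self, data: Dict) -> Dict:
        self.calls.append("learn")
        return {"loss": 0.0}


def run(policy, epochs: int, eval_every: int) -> None:
    """Hypothetical loop: learn in training mode, evaluate every eval_every epochs."""
    for epoch in range(1, epochs + 1):
        policy.train()
        policy.learn({})  # a real loop would pass freshly collected data here
        if epoch % eval_every == 0:
            policy.eval()
            # Evaluation games would be played here; train() runs again next epoch.


p = RecordingPolicy()
run(p, epochs=2, eval_every=2)
print(p.calls)  # ['train', 'learn', 'train', 'learn', 'eval']
```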