gomoku_rl.policy.dqn module

class gomoku_rl.policy.dqn.DQN(cfg: DictConfig, action_spec: DiscreteTensorSpec, observation_spec: TensorSpec, device: device | str | int | None = 'cuda')

Bases: Policy

eval()

Sets the policy to evaluation mode.

learn(data: TensorDict) → Dict

Updates the policy based on a batch of data.

Parameters:

data (TensorDict) – A batch of data, typically including observations, actions, rewards, and next observations.

Returns:

A dictionary containing information about the learning step, such as loss values.

Return type:

Dict
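
The documentation above does not spell out the update rule that learn() applies. As a rough illustration only, the sketch below computes the textbook one-step DQN loss (mean squared TD error) in plain Python and returns it in a dictionary, mirroring the documented return shape. The function name td_learn_step, its parameters, and the "loss" key are hypothetical and are not part of gomoku_rl; the real method works on a TensorDict and also performs an optimizer step.

```python
from typing import Dict, List


def td_learn_step(
    q_values: List[List[float]],       # Q(s, a) per sample, one row per sample
    next_q_values: List[List[float]],  # Q(s', a') per sample
    actions: List[int],                # action taken in each sample
    rewards: List[float],
    dones: List[bool],                 # True if the episode ended at this step
    gamma: float = 0.99,
) -> Dict[str, float]:
    """Hypothetical helper: mean squared TD error for one batch (no gradient step)."""
    squared_errors = []
    for q, next_q, a, r, done in zip(q_values, next_q_values, actions, rewards, dones):
        # Bootstrapped target r + gamma * max_a' Q(s', a'); no bootstrap on terminal steps.
        target = r if done else r + gamma * max(next_q)
        squared_errors.append((target - q[a]) ** 2)
    return {"loss": sum(squared_errors) / len(squared_errors)}


info = td_learn_step(
    q_values=[[0.0, 1.0], [0.5, 0.5]],
    next_q_values=[[2.0, 0.0], [0.0, 0.0]],
    actions=[1, 0],
    rewards=[0.0, 1.0],
    dones=[False, True],
    gamma=0.5,
)
print(info)
```

A caller would typically log the returned dictionary after each update to monitor training.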

load_state_dict(state_dict: Dict)

Loads the policy state from a dictionary.

Parameters:

state_dict (Dict) – The policy state to load, as returned by state_dict().

state_dict() → Dict

Returns the state of the policy as a dictionary.

Returns:

The state of the policy.

Return type:

Dict
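
The state_dict() and load_state_dict() pair is usually used for a save-and-restore round trip. The sketch below shows that pattern with a hypothetical stand-in class, ToyPolicy, so that it runs without gomoku_rl installed; the stored keys are invented for illustration, and whether the real DQN returns copies or references to its parameters is not stated in this reference.

```python
import copy
from typing import Any, Dict


class ToyPolicy:
    """Hypothetical stand-in exposing the same two methods as the documented class."""

    def __init__(self) -> None:
        self.weights: Dict[str, Any] = {"w": [0.0, 0.0], "step": 0}

    def state_dict(self) -> Dict[str, Any]:
        # The toy returns a deep copy so later updates do not alter the snapshot.
        return copy.deepcopy(self.weights)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Copy on load so the caller's dictionary stays independent of the policy.
        self.weights = copy.deepcopy(state_dict)


# Pretend some training happened, then take a snapshot.
policy = ToyPolicy()
policy.weights["w"] = [1.5, -2.0]
policy.weights["step"] = 10
snapshot = policy.state_dict()

# Restore the snapshot into a fresh policy.
restored = ToyPolicy()
restored.load_state_dict(snapshot)

# Further changes to the original do not affect the snapshot or the restored copy.
policy.weights["step"] = 11
```

With the real class, the same two calls would sit on either side of whatever serialization the training script uses to write checkpoints.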

train()

Sets the policy to training mode.

gomoku_rl.policy.dqn.get_replay_buffer(buffer_size: int, batch_size: int, sampler: Sampler | None = None, device: device | str | int | None = None)
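
This reference gives no description or return type for get_replay_buffer. To illustrate what buffer_size and batch_size conventionally mean for a replay buffer, here is a minimal sketch in plain Python. The class ToyReplayBuffer, its methods, and the seed argument are hypothetical; the object the real function returns, and how it uses sampler and device, may differ.

```python
import random
from collections import deque
from typing import Any, Deque, List


class ToyReplayBuffer:
    """Hypothetical fixed-capacity buffer with uniform sampling."""

    def __init__(self, buffer_size: int, batch_size: int, seed: int = 0) -> None:
        # A deque with maxlen drops the oldest transitions once capacity is reached.
        self._storage: Deque[Any] = deque(maxlen=buffer_size)
        self._batch_size = batch_size
        self._rng = random.Random(seed)

    def extend(self, transitions: List[Any]) -> None:
        self._storage.extend(transitions)

    def sample(self) -> List[Any]:
        # Uniform sampling without replacement; raises ValueError if the
        # buffer holds fewer than batch_size transitions.
        return self._rng.sample(list(self._storage), self._batch_size)

    def __len__(self) -> int:
        return len(self._storage)


buffer = ToyReplayBuffer(buffer_size=4, batch_size=2)
buffer.extend(list(range(6)))  # transitions 0 and 1 are evicted
batch = buffer.sample()
```

In this toy, buffer_size bounds memory use and batch_size fixes how many transitions each sample() call returns.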