gomoku_rl.collector module
- class gomoku_rl.collector.BlackPlayCollector(env: GomokuEnv, policy_black: Callable[[TensorDict], TensorDict], policy_white: Callable[[TensorDict], TensorDict], out_device=None, augment: bool = False)[source]
Bases: Collector
- __init__(env: GomokuEnv, policy_black: Callable[[TensorDict], TensorDict], policy_white: Callable[[TensorDict], TensorDict], out_device=None, augment: bool = False)[source]
Initializes a collector that captures game transitions in which the black player is controlled by a trainable policy and the white player follows a fixed policy.
- Parameters:
env (GomokuEnv) – The game environment where the collection takes place.
policy_black (_policy_t) – The trainable policy used for the black player.
policy_white (_policy_t) – The fixed policy for the white player, simulating a consistent opponent.
out_device – The device (e.g., CPU, GPU) where the collected data will be stored. Defaults to the environment’s device if not specified.
augment (bool, optional) – Whether to apply data augmentation to the collected transitions, enhancing the dataset’s diversity. Defaults to False.
- rollout(steps: int) → tuple[TensorDict, dict] [source]
Executes a data collection session over a specified number of game steps, focusing on transitions involving the black player.
- Parameters:
steps (int) – The total number of steps to collect data for, adjusted to an even number so that both players take the same number of turns.
- Returns:
A tuple containing the collected transitions for the black player and a dictionary with additional information such as the frames per second (fps) achieved during the collection.
- Return type:
tuple[TensorDict, dict]
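Example: a minimal sketch of driving the collector. The random_policy below, its "action_mask"/"action" keys, and the env construction are illustrative assumptions; only the collector API documented above comes from this module.

    import torch
    from tensordict import TensorDict

    from gomoku_rl.collector import BlackPlayCollector

    def random_policy(td: TensorDict) -> TensorDict:
        # Hypothetical policy: sample a uniformly random legal move.
        # The "action_mask"/"action" keys are assumptions about the env spec.
        mask = td["action_mask"].bool()
        logits = torch.zeros(mask.shape).masked_fill(~mask, float("-inf"))
        td["action"] = torch.distributions.Categorical(logits=logits).sample()
        return td

    env = ...  # a GomokuEnv instance; construction depends on your configuration
    collector = BlackPlayCollector(env, policy_black=random_policy, policy_white=random_policy)
    transitions, info = collector.rollout(steps=1024)  # adjusted to an even count internally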
- class gomoku_rl.collector.SelfPlayCollector(env: GomokuEnv, policy: Callable[[TensorDict], TensorDict], out_device=None, augment: bool = False)[source]
Bases: Collector
- __init__(env: GomokuEnv, policy: Callable[[TensorDict], TensorDict], out_device=None, augment: bool = False)[source]
Initializes a collector for self-play data in a Gomoku environment.
This collector facilitates the collection of game transitions generated through self-play, where both players use the same policy.
- Parameters:
env (GomokuEnv) – The Gomoku game environment where self-play will be conducted.
policy (_policy_t) – The policy function to be used for both players during self-play.
out_device – The device on which collected data will be stored. If None, uses the device specified by the environment.
augment (bool, optional) – If True, applies data augmentation to the collected transitions. Defaults to False.
- rollout(steps: int) → tuple[TensorDict, dict] [source]
Executes a rollout in the environment, collecting data for a specified number of steps.
- Parameters:
steps (int) – The number of steps to execute in the environment for this rollout.
- Returns:
- A tuple containing two elements:
  - A TensorDict holding the transitions collected during the rollout; each transition includes the game state before the action, the action taken, and the resulting state.
  - A dictionary with additional information about the rollout.
- Return type:
tuple[TensorDict, dict]
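Example: a sketch, where policy is any Callable[[TensorDict], TensorDict] (for instance the hypothetical random_policy from the BlackPlayCollector example) and the env construction is assumed.

    from gomoku_rl.collector import SelfPlayCollector

    env = ...     # a GomokuEnv instance
    policy = ...  # any Callable[[TensorDict], TensorDict]
    collector = SelfPlayCollector(env, policy)  # one policy drives both players
    transitions, info = collector.rollout(steps=1024)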
- class gomoku_rl.collector.VersusPlayCollector(env: GomokuEnv, policy_black: Callable[[TensorDict], TensorDict], policy_white: Callable[[TensorDict], TensorDict], out_device=None, augment: bool = False)[source]
Bases: Collector
- __init__(env: GomokuEnv, policy_black: Callable[[TensorDict], TensorDict], policy_white: Callable[[TensorDict], TensorDict], out_device=None, augment: bool = False)[source]
Initializes a collector for versus-play data in a Gomoku environment, capturing game transitions from matches between two players, each using a distinct policy.
- Parameters:
env (GomokuEnv) – The Gomoku game environment where the two-player game will be conducted.
policy_black (_policy_t) – The policy function to be used for the black player.
policy_white (_policy_t) – The policy function to be used for the white player.
out_device – The device on which collected data will be stored. If None, uses the device specified by the environment.
augment (bool, optional) – If True, applies data augmentation to the collected transitions. Defaults to False.
- rollout(steps: int) → tuple[TensorDict, TensorDict, dict] [source]
Executes a rollout in the environment, collecting data for a specified number of steps, alternating between the black and white policies.
- Parameters:
steps (int) – The number of steps to execute in the environment for this rollout. It is adjusted to be an even number to ensure an equal number of actions for both players.
- Returns:
- A tuple containing three elements:
  - A TensorDict of transitions collected for the black player, each representing the game state before the black player’s action, the action taken, and the resulting state.
  - A TensorDict of transitions collected for the white player, structured like the black player’s. Because the white player does not act on the first step, its collection starts from the second step.
  - A dictionary containing additional information about the rollout.
- Return type:
tuple[TensorDict, TensorDict, dict]
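Example: a sketch (env and policies as assumed above) showing the three-way unpacking and the one-step offset of the white player’s data.

    from gomoku_rl.collector import VersusPlayCollector

    env = ...  # a GomokuEnv instance
    policy_black = policy_white = ...  # Callable[[TensorDict], TensorDict]
    collector = VersusPlayCollector(env, policy_black, policy_white)
    black_td, white_td, info = collector.rollout(steps=1000)  # adjusted to even
    # White does not act on the first step, so white_td starts one step
    # later than black_td.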
- class gomoku_rl.collector.WhitePlayCollector(env: GomokuEnv, policy_black: Callable[[TensorDict], TensorDict], policy_white: Callable[[TensorDict], TensorDict], out_device=None, augment: bool = False)[source]
Bases: Collector
- __init__(env: GomokuEnv, policy_black: Callable[[TensorDict], TensorDict], policy_white: Callable[[TensorDict], TensorDict], out_device=None, augment: bool = False)[source]
Initializes a collector that captures game transitions from the perspective of the white player, who is controlled by a trainable policy, while the black player follows a fixed policy.
- Parameters:
env (GomokuEnv) – The game environment where the collection takes place.
policy_black (_policy_t) – The fixed policy for the black player, providing a consistent challenge.
policy_white (_policy_t) – The trainable policy used for the white player.
out_device – The device for storing collected data, defaulting to the environment’s device if not specified.
augment (bool, optional) – Indicates whether to augment the collected transitions to enhance the dataset. Defaults to False.
- rollout(steps: int) → tuple[TensorDict, dict] [source]
Performs a data collection session, focusing on the game transitions where the white player is active, over a specified number of steps.
- Parameters:
steps (int) – The number of steps for which data will be collected, adjusted to an even number so that both players take the same number of turns.
- Returns:
A tuple containing the collected transitions for the white player and additional session information, such as collection performance (fps).
- Return type:
tuple[TensorDict, dict]
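Example: a sketch of pushing the collected transitions into a TorchRL replay buffer for off-policy updates of the white policy. Whether gomoku_rl itself stores data this way is an assumption; TensorDictReplayBuffer and LazyTensorStorage are standard torchrl.data classes.

    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

    from gomoku_rl.collector import WhitePlayCollector

    env = ...  # a GomokuEnv instance
    frozen_opponent = trainable_policy = ...  # Callable[[TensorDict], TensorDict]
    collector = WhitePlayCollector(env, policy_black=frozen_opponent, policy_white=trainable_policy)
    buffer = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=100_000))

    white_td, info = collector.rollout(steps=2048)
    buffer.extend(white_td.reshape(-1))  # flatten batch dims before storing
    batch = buffer.sample(256)           # minibatch for a gradient step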
- gomoku_rl.collector.make_transition(tensordict_t_minus_1: TensorDict, tensordict_t: TensorDict, tensordict_t_plus_1: TensorDict) → TensorDict [source]
Constructs a transition tensor dictionary for a two-player game by integrating the game state and actions from three consecutive time steps (t-1, t, and t+1).
- Parameters:
tensordict_t_minus_1 (TensorDict) – A tensor dictionary containing the game state and associated information at time t-1.
tensordict_t (TensorDict) – A tensor dictionary containing the game state and associated information at time t.
tensordict_t_plus_1 (TensorDict) – A tensor dictionary containing the game state and associated information at time t+1.
- Returns:
A new tensor dictionary representing the transition from time t-1 to t+1.
- Return type:
TensorDict
The function calculates rewards based on the win status at times t and t+1, and flags the transition as done if the game ends at either time t or t+1. The resulting tensor dictionary is structured to facilitate learning from this transition in reinforcement learning algorithms.
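In sketch form, the logic described might look like this (the "win" key name is an assumption inferred from the description; "done" is referenced elsewhere in this module):

    from tensordict import TensorDict

    def sketch_reward_done(td_t: TensorDict, td_t_plus_1: TensorDict):
        # Assumed convention: the player acting at t winning yields +1,
        # the opponent winning at t+1 yields -1.
        reward = td_t["win"].float() - td_t_plus_1["win"].float()
        # The transition is done if the game ends at t or at t+1.
        done = td_t["done"] | td_t_plus_1["done"]
        return reward, done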
- gomoku_rl.collector.round(env: GomokuEnv, policy_black: Callable[[TensorDict], TensorDict], policy_white: Callable[[TensorDict], TensorDict], tensordict_t_minus_1: TensorDict, tensordict_t: TensorDict, return_black_transitions: bool = True, return_white_transitions: bool = True)[source]
Executes two sequential steps in the Gomoku environment, applying black and white policies alternately.
- Parameters:
env (GomokuEnv) – The Gomoku game environment instance.
policy_black (_policy_t) – The policy function for the black player, which determines the action based on the current game state.
policy_white (_policy_t) – The policy function for the white player, similar to policy_black but for the white player.
tensordict_t_minus_1 (TensorDict) – The game state tensor dictionary at time t-1, before the white player’s action.
tensordict_t (TensorDict) – The game state tensor dictionary at time t, before the black player’s action.
return_black_transitions (bool, optional) – If True, returns transition data for the black player. Defaults to True.
return_white_transitions (bool, optional) – If True, returns transition data for the white player. Defaults to True.
- Returns:
Contains transition data for the black player (if requested), transition data for the white player (if requested), the game state after the white player’s action (t+1), and the game state after the black player’s action (t+2).
- Return type:
tuple
Note
If the environment is reset at time t-1, the white player won’t make a move at time t. This is different from the black player’s behavior.
If the black player wins at time t and the environment is reset, the environment does not take a step at t+1. This affects the validity of the white player’s transition from t-1 to t+1, which is marked invalid where tensordict_t_minus_1["done"] is True.
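Given this contract, successive rounds chain by feeding the two returned states back in as the next call’s t-1 and t inputs. A sketch (env and policies as assumed above; the bootstrapping of the first two states is handled internally by the collectors in this module):

    from gomoku_rl.collector import round  # shadows the builtin round

    td_tm1 = ...  # state at t-1 (bootstrapping assumed)
    td_t = ...    # state at t
    black_trs, white_trs = [], []
    for _ in range(100):
        black_tr, white_tr, td_tm1, td_t = round(
            env, policy_black, policy_white, td_tm1, td_t
        )
        black_trs.append(black_tr)
        white_trs.append(white_tr)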
- gomoku_rl.collector.self_play_step(env: GomokuEnv, policy: Callable[[TensorDict], TensorDict], tensordict_t_minus_1: TensorDict, tensordict_t: TensorDict)[source]
Executes a single step of self-play in a Gomoku environment using a specified policy.
- Parameters:
env (GomokuEnv) – The Gomoku game environment instance where the self-play step is executed.
policy (_policy_t) – The policy function to determine the next action based on the current game state.
tensordict_t_minus_1 (TensorDict) – The game state tensor dictionary at time t-1.
tensordict_t (TensorDict) – The game state tensor dictionary at time t.
- Returns:
- A tuple containing three elements:
  - The transition resulting from the action taken at this step: the states at times t-1 and t, the action taken at time t, and the resulting state at time t+1.
  - The game state tensor dictionary at time t, returned unchanged from the input to facilitate chaining of self-play steps or integration with other functions.
  - The game state tensor dictionary at time t+1, reflecting the new state of the environment after applying the action and, if the game concluded in this step, a reset.
- Return type:
tuple
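Because the state at time t is returned unchanged, steps chain naturally: the returned pair becomes the (t-1, t) input of the next call. A sketch (env, policy, and state bootstrapping as assumed above):

    from gomoku_rl.collector import self_play_step

    td_tm1 = ...  # state at t-1 (bootstrapping assumed)
    td_t = ...    # state at t
    transitions = []
    for _ in range(100):
        transition, td_tm1, td_t = self_play_step(env, policy, td_tm1, td_t)
        transitions.append(transition)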