gomoku_rl.utils.augment module

class gomoku_rl.utils.augment.AntiDiagonalFlip[source]

Bases: Transform

inverse_map_board(board: Tensor)[source]
inverse_map_index(index: Tensor, board_size: int) → Tensor[source]
map_board(board: Tensor)[source]
map_index(index: Tensor, board_size: int) → Tensor[source]
class gomoku_rl.utils.augment.DiagonalFlip[source]

Bases: Transform

inverse_map_board(board: Tensor)[source]
inverse_map_index(index: Tensor, board_size: int) → Tensor[source]
map_board(board: Tensor)[source]
map_index(index: Tensor, board_size: int) → Tensor[source]
class gomoku_rl.utils.augment.HorizontalFlip[source]

Bases: Transform

inverse_map_board(board: Tensor)[source]
inverse_map_index(index: Tensor, board_size: int) → Tensor[source]
map_board(board: Tensor)[source]
map_index(index: Tensor, board_size: int) → Tensor[source]
class gomoku_rl.utils.augment.Identity[source]

Bases: Transform

inverse_map_board(board: Tensor) → Tensor[source]
inverse_map_index(index: Tensor, board_size: int) → Tensor[source]
map_board(board: Tensor) → Tensor[source]
map_index(index: Tensor, board_size: int) → Tensor[source]
class gomoku_rl.utils.augment.Rotation(k: int = 1)[source]

Bases: Transform

inverse_map_board(board: Tensor)[source]
inverse_map_index(index: Tensor, board_size: int) → Tensor[source]
map_board(board: Tensor)[source]
map_index(index: Tensor, board_size: int) → Tensor[source]
class gomoku_rl.utils.augment.Transform[source]

Bases: ABC

abstract inverse_map_board(board: Tensor) → Tensor[source]
abstract inverse_map_index(index: Tensor, board_size: int) → Tensor[source]
abstract map_board(board: Tensor) → Tensor[source]
abstract map_index(index: Tensor, board_size: int) → Tensor[source]
class gomoku_rl.utils.augment.VerticalFlip[source]

Bases: Transform

inverse_map_board(board: Tensor)[source]
inverse_map_index(index: Tensor, board_size: int) → Tensor[source]
map_board(board: Tensor)[source]
map_index(index: Tensor, board_size: int) → Tensor[source]
gomoku_rl.utils.augment.augment_transition(transition: TensorDict) → TensorDict[source]
gomoku_rl.utils.augment.get_augmented_transition(transition: TensorDict, transform: Transform, inplace: bool = False) → TensorDict[source]