learning.replay_buffer
Module Contents
- class learning.replay_buffer.ReplayBuffer(plant: jacta.planner.dynamics.simulator_plant.SimulatorPlant, params: jacta.planner.core.parameter_container.ParameterContainer)
- Parameters:
plant (jacta.planner.dynamics.simulator_plant.SimulatorPlant) –
params (jacta.planner.core.parameter_container.ParameterContainer) –
- reset() → None
- Return type:
None
- reset_next_temporary_id() → None
- Return type:
None
- add_nodes(root_ids: torch.IntTensor, parent_ids: torch.IntTensor, states: torch.FloatTensor, start_actions: torch.FloatTensor, end_actions: torch.FloatTensor, relative_actions: torch.FloatTensor, temporary: bool = False, sub_goal_state: torch.FloatTensor = None) → int
- Parameters:
root_ids (torch.IntTensor) –
parent_ids (torch.IntTensor) –
states (torch.FloatTensor) –
start_actions (torch.FloatTensor) –
end_actions (torch.FloatTensor) –
relative_actions (torch.FloatTensor) –
temporary (bool) –
sub_goal_state (torch.FloatTensor) –
- Return type:
int
- sampling(batch_size: int, her_probability: float, reward_function: Callable) → Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]
Sample a random batch from the replay experience, applying hindsight experience replay (HER) goal resampling.
- Parameters:
batch_size (int) – Batch size.
her_probability (float) – Probability of replacing a sampled goal with an achieved goal.
reward_function (Callable) – Reward function of the environment to recalculate the rewards.
- Returns:
A tuple of batched tensors that includes the sampled state, action, reward, next_state, and goal.
- Raises:
AssertionError – If the dimension check on states fails.
- Return type:
Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]
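The HER goal resampling described for sampling() can be illustrated with a minimal, standalone sketch. It is not the ReplayBuffer implementation: it uses plain Python tuples instead of torch tensors, a single trajectory instead of the node graph, and the common "future" relabeling strategy, all of which are assumptions. The names sample_with_her and sparse_reward, and the two-argument reward signature, are hypothetical and chosen only for illustration.

```python
import random
from typing import Callable, List, Tuple

State = Tuple[float, ...]
# (state, action, reward, next_state, goal)
Transition = Tuple[State, float, float, State, State]


def sparse_reward(achieved: State, goal: State) -> float:
    """Hypothetical reward: 0.0 when the goal is reached, -1.0 otherwise."""
    return 0.0 if achieved == goal else -1.0


def sample_with_her(
    trajectory: List[State],
    actions: List[float],
    goal: State,
    batch_size: int,
    her_probability: float,
    reward_function: Callable[[State, State], float],
    rng: random.Random,
) -> List[Transition]:
    """Sample transitions at random, relabeling goals with probability her_probability."""
    assert len(trajectory) == len(actions) + 1, "need one more state than actions"
    batch: List[Transition] = []
    for _ in range(batch_size):
        t = rng.randrange(len(actions))
        state, action, next_state = trajectory[t], actions[t], trajectory[t + 1]
        sampled_goal = goal
        if rng.random() < her_probability:
            # "future" strategy (an assumption): use a state that was
            # actually reached at step t + 1 or later as the new goal.
            sampled_goal = trajectory[rng.randrange(t + 1, len(trajectory))]
        # Rewards are recomputed because the goal may have changed.
        reward = reward_function(next_state, sampled_goal)
        batch.append((state, action, reward, next_state, sampled_goal))
    return batch


trajectory = [(0.0,), (1.0,), (2.0,), (3.0,)]
actions = [1.0, 1.0, 1.0]
batch = sample_with_her(
    trajectory, actions, goal=(9.0,), batch_size=4,
    her_probability=1.0, reward_function=sparse_reward, rng=random.Random(0),
)
```

With her_probability=1.0 every sampled goal is a state the trajectory actually reached, so some transitions can earn a non-negative reward even though the original goal (9.0,) was never achieved. With her_probability=0.0 the original goal is kept for every transition.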