learning.replay_buffer

Module Contents

class learning.replay_buffer.ReplayBuffer(plant: jacta.planner.dynamics.simulator_plant.SimulatorPlant, params: jacta.planner.core.parameter_container.ParameterContainer)
Parameters:
  • plant (jacta.planner.dynamics.simulator_plant.SimulatorPlant) –

  • params (jacta.planner.core.parameter_container.ParameterContainer) –

reset() → None
Return type:

None

reset_next_temporary_id() → None
Return type:

None

add_nodes(root_ids: torch.IntTensor, parent_ids: torch.IntTensor, states: torch.FloatTensor, start_actions: torch.FloatTensor, end_actions: torch.FloatTensor, relative_actions: torch.FloatTensor, temporary: bool = False, sub_goal_state: torch.FloatTensor = None) → int
Parameters:
  • root_ids (torch.IntTensor) –

  • parent_ids (torch.IntTensor) –

  • states (torch.FloatTensor) –

  • start_actions (torch.FloatTensor) –

  • end_actions (torch.FloatTensor) –

  • relative_actions (torch.FloatTensor) –

  • temporary (bool) –

  • sub_goal_state (torch.FloatTensor) –

Return type:

int
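The `root_ids` and `parent_ids` arguments suggest that nodes are stored as a tree with parent pointers. The following hypothetical pure-Python sketch (not jacta's implementation) shows how such a store could append a batch of nodes and recover the path from a node back to its root. The `NodeStore` class, the convention that a root is its own parent, and the choice to return the id of the first new node are all assumptions; what the real method's `int` return value denotes is not documented here.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class NodeStore:
    """Hypothetical flat node storage with parent pointers (not jacta's implementation)."""

    root_ids: List[int] = field(default_factory=list)
    parent_ids: List[int] = field(default_factory=list)
    states: List[List[float]] = field(default_factory=list)

    def add_nodes(self, root_ids: List[int], parent_ids: List[int], states: List[List[float]]) -> int:
        """Append a batch of nodes; returns the id of the first new node (assumption)."""
        first_id = len(self.states)
        self.root_ids.extend(root_ids)
        self.parent_ids.extend(parent_ids)
        self.states.extend(states)
        return first_id

    def path_to_root(self, node_id: int) -> List[int]:
        """Follow parent pointers until reaching a root (assumed to be its own parent)."""
        path = [node_id]
        while self.parent_ids[node_id] != node_id:
            node_id = self.parent_ids[node_id]
            path.append(node_id)
        return path[::-1]  # ordered root -> node
```

For example, adding a root (id 0) and then two nodes whose parents are 0 and 1 yields `path_to_root(2) == [0, 1, 2]`.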

sampling(batch_size: int, her_probability: float, reward_function: Callable) → Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]

Sample a random batch from the replay experience, resampling goals with hindsight experience replay (HER).

Parameters:
  • batch_size (int) – Batch size.

  • her_probability (float) – Probability of replacing a sample's goal with an achieved goal.

  • reward_function (Callable) – Reward function of the environment, used to recompute the rewards for resampled goals.

Returns:

A tuple of eight sampled batch tensors, including the state, action, reward, next_state, and goal batches.

Raises:

AssertionError – If the dimension check on the states fails.

Return type:

Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]
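The HER resampling described above can be illustrated with a minimal, stdlib-only sketch. It is not jacta's implementation: the episode layout, the `her_sample` and `sparse_reward` helpers, the "future" relabeling strategy, and the assumption that goals live in the same space as states are hypothetical choices for illustration. It also returns five batches rather than the eight tensors listed above.

```python
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Step = Tuple[Vector, Vector, Vector]  # (state, action, next_state)


def sparse_reward(achieved: Sequence[float], goal: Sequence[float], tol: float = 0.05) -> float:
    """Hypothetical goal-reaching reward: 0.0 within tol of the goal, -1.0 otherwise."""
    dist = sum((a - g) ** 2 for a, g in zip(achieved, goal)) ** 0.5
    return 0.0 if dist <= tol else -1.0


def her_sample(
    episodes: List[List[Step]],
    goals: List[Vector],  # one desired goal per episode
    batch_size: int,
    her_probability: float,
    reward_function: Callable[[Sequence[float], Sequence[float]], float],
    rng: random.Random,
):
    states, actions, rewards, next_states, batch_goals = [], [], [], [], []
    for _ in range(batch_size):
        # Pick a random transition from a random episode.
        ep = rng.randrange(len(episodes))
        t = rng.randrange(len(episodes[ep]))
        state, action, next_state = episodes[ep][t]
        goal = goals[ep]
        if rng.random() < her_probability:
            # "future" strategy (assumption): relabel with a state achieved
            # at the same or a later step of the same episode.
            future = rng.randrange(t, len(episodes[ep]))
            goal = episodes[ep][future][2]
        # Rewards must be recomputed because the goal may have changed.
        rewards.append(reward_function(next_state, goal))
        states.append(state)
        actions.append(action)
        next_states.append(next_state)
        batch_goals.append(goal)
    return states, actions, rewards, next_states, batch_goals
```

With `her_probability=1.0` and one-step episodes, every relabeled goal equals the achieved next state, so `sparse_reward` returns 0.0 for each sample; with `her_probability=0.0`, the original goals and rewards are kept.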