Module likelihood.models.environments
Classes
class ActionSpace(num_actions)
class ActionSpace:
    def __init__(self, num_actions):
        self._num_actions = num_actions

    @property
    def n(self):
        return self._num_actions

Instance variables
prop n
@property
def n(self):
    return self._num_actions
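A minimal usage sketch (hypothetical, not part of the module source): ActionSpace simply stores whatever is passed as num_actions and exposes it read-only through n. Since OptionCriticEnv constructs it with a list of per-option action counts, n may be a list rather than a single integer.

from likelihood.models.environments import ActionSpace

space = ActionSpace(4)
print(space.n)         # 4

# OptionCriticEnv passes a list of per-option action counts,
# e.g. two options with 3 and 2 unique actions respectively:
per_option = ActionSpace([3, 2])
print(per_option.n)    # [3, 2]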
class OptionCriticEnv(episodes: Dict[int, Dict[str, List]])
class OptionCriticEnv:
    """
    An environment for Option Critic reinforcement learning that processes a dataset of episodes.

    Attributes
    ----------
    episodes : `Dict[int, Dict]`
        Dataset of episodes with state, action, selected_option, reward, next_state, and done information.
    observation_space : `np.ndarray`
        Initial observation (first state of the first episode), used to characterize the observation space
    done : `bool`
        Whether the current episode has terminated
    num_options : `int`
        Number of distinct options available in the dataset
    actions_by_option : `defaultdict(set)`
        Maps selected options to sets of actions that were taken with them
    unique_actions_count : `List[int]`
        Count of unique actions per option index (used for action space definition)
    action_space : `ActionSpace`
        Custom action space defined by unique actions per option
    idx_episode : `int`
        Current episode index being processed
    current_state : `np.ndarray`
        Current state observation in the environment
    """

    def __init__(
        self,
        episodes: Dict[int, Dict[str, List]],
    ):
        """
        Initializes the OptionCriticEnv with a dataset of episodes.

        Parameters
        ----------
        episodes : `Dict[int, Dict]`
            Dataset of episodes where keys are episode identifiers and values are episode data.
            Each episode must contain at least:
            - "state": List of state observations
            - "selected_option": List of selected options
            - "action": List of actions taken
            - "reward": List of rewards
            - "next_state": List of next states
            - "done": List of termination flags

        Raises
        ------
        ValueError
            If any of the required fields ("state", "action", "selected_option",
            "reward", "next_state", "done") are missing from an episode's data
        """
        self.episodes = episodes

        required_keys = ["state", "action", "selected_option", "reward", "next_state", "done"]
        for episode_id, data in episodes.items():
            if not all(k in data for k in required_keys):
                raise ValueError(
                    f"Episode {episode_id} missing keys: {set(required_keys) - set(data.keys())}"
                )

        self.observation_space = np.array(episodes[0]["state"][0])
        self.done = False
        self.idx_episode = 0
        self.current_state = None
        self.num_options = len(set(episodes[0]["selected_option"]))
        self.actions_by_option = defaultdict(set)

        # Build fast lookup for transitions
        self.state_action_option_to_transition: Dict[Tuple, Dict[str, Any]] = {}
        for episode_id, data in episodes.items():
            states = data["state"]
            actions = data["action"]
            options = data["selected_option"]
            next_states = data["next_state"]
            rewards = data["reward"]
            dones = data["done"]

            for i in range(len(states)):
                state_key = tuple(states[i])
                key = (state_key, options[i], actions[i])
                self.state_action_option_to_transition[key] = {
                    "next_state": next_states[i],
                    "reward": rewards[i],
                    "done": dones[i],
                }

            for i, selected in enumerate(options):
                self.actions_by_option[selected].add(actions[i])

        self.unique_actions_count = [
            len(self.actions_by_option.get(i, set()))
            for i in range(max(self.actions_by_option.keys()) + 1)
        ]
        self.action_space = ActionSpace(self.unique_actions_count)

    def reset(self) -> tuple[np.ndarray, dict]:
        """
        Resets the environment to a random episode and returns the initial state.

        Returns
        -------
        observation : `np.ndarray`
            Initial state observation
        info : `Dict`
            Empty dictionary (no additional information)
        """
        episode_id = np.random.choice(list(self.episodes.keys()))
        self.idx_episode = episode_id
        self.current_state = self.episodes[episode_id]["state"][0]
        return self.current_state, {}

    def step(self, action: int, option: int) -> tuple[np.ndarray, float, bool, bool, dict]:
        """
        Takes an action with a specific option and returns the next state, reward, and termination status.

        Parameters
        ----------
        action : `int`
            Action index to execute
        option : `int`
            Selected option index

        Returns
        -------
        next_state : `np.ndarray`
            State after taking the action
        reward : `float`
            Immediate reward for the transition
        done : `bool`
            Whether the episode has terminated (from episode data)
        terminated : `bool`
            Whether the action-option pair was found in the dataset
        info : `Dict`
            Empty dictionary (no additional information)
        """
        key = (tuple(self.current_state), option, action)
        if key in self.state_action_option_to_transition:
            trans = self.state_action_option_to_transition[key]
            self.current_state = trans["next_state"]
            return trans["next_state"].copy(), trans["reward"], trans["done"], True, {}
        else:
            return self.current_state, 0.0, False, False, {}

An environment for Option Critic reinforcement learning that processes a dataset of episodes.
Attributes
episodes : Dict[int, Dict] - Dataset of episodes with state, action, selected_option, reward, next_state, and done information.
observation_space : np.ndarray - Initial observation (first state of the first episode), used to characterize the observation space
done : bool - Whether the current episode has terminated
num_options : int - Number of distinct options available in the dataset
actions_by_option : defaultdict(set) - Maps selected options to sets of actions that were taken with them
unique_actions_count : List[int] - Count of unique actions per option index (used for action space definition)
action_space : ActionSpace - Custom action space defined by unique actions per option
idx_episode : int - Current episode index being processed
current_state : np.ndarray - Current state observation in the environment
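To make the derived attributes concrete, here is a hedged sketch using a hypothetical two-transition dataset (the values are illustrative, not taken from the library):

from likelihood.models.environments import OptionCriticEnv

# Hypothetical toy dataset: one episode with two transitions.
episodes = {
    0: {
        "state": [[0.0, 0.0], [1.0, 0.0]],
        "action": [0, 1],
        "selected_option": [0, 1],
        "reward": [0.0, 1.0],
        "next_state": [[1.0, 0.0], [1.0, 1.0]],
        "done": [False, True],
    }
}
env = OptionCriticEnv(episodes)

print(env.num_options)              # 2 distinct options in episode 0
print(dict(env.actions_by_option))  # {0: {0}, 1: {1}}
print(env.unique_actions_count)     # [1, 1] unique actions per option index
print(env.action_space.n)           # [1, 1]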
Initializes the OptionCriticEnv with a dataset of episodes.
Parameters
episodes : Dict[int, Dict] - Dataset of episodes where keys are episode identifiers and values are episode data. Each episode must contain at least:
- "state": List of state observations
- "selected_option": List of selected options
- "action": List of actions taken
- "reward": List of rewards
- "next_state": List of next states
- "done": List of termination flags
Raises
ValueError - If any of the required fields ("state", "action", "selected_option", "reward", "next_state", "done") are missing from an episode's data
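A sketch of the validation behaviour described above, using a hypothetical episode dict that omits the "reward" field:

from likelihood.models.environments import OptionCriticEnv

bad_episodes = {
    0: {
        "state": [[0.0, 0.0]],
        "action": [0],
        "selected_option": [0],
        "next_state": [[1.0, 0.0]],
        "done": [True],
        # "reward" is deliberately missing
    }
}
try:
    OptionCriticEnv(bad_episodes)
except ValueError as err:
    print(err)  # Episode 0 missing keys: {'reward'}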
Methods
def reset(self) -> tuple[numpy.ndarray, dict]
def reset(self) -> tuple[np.ndarray, dict]:
    """
    Resets the environment to a random episode and returns the initial state.

    Returns
    -------
    observation : `np.ndarray`
        Initial state observation
    info : `Dict`
        Empty dictionary (no additional information)
    """
    episode_id = np.random.choice(list(self.episodes.keys()))
    self.idx_episode = episode_id
    self.current_state = self.episodes[episode_id]["state"][0]
    return self.current_state, {}

Resets the environment to a random episode and returns the initial state.
Returns
observation : np.ndarray - Initial state observation
info : Dict - Empty dictionary (no additional information)
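A short usage sketch, assuming env is the hypothetical toy environment built above:

state, info = env.reset()
print(state)            # first state of the randomly chosen episode
print(env.idx_episode)  # id of the episode that was selected
print(info)             # {}

Note that reset returns the initial state exactly as stored in the dataset, so if states are plain Python lists the caller may want to wrap them with np.asarray before feeding them to a model.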
def step(self, action: int, option: int) -> tuple[numpy.ndarray, float, bool, bool, dict]
def step(self, action: int, option: int) -> tuple[np.ndarray, float, bool, bool, dict]:
    """
    Takes an action with a specific option and returns the next state, reward, and termination status.

    Parameters
    ----------
    action : `int`
        Action index to execute
    option : `int`
        Selected option index

    Returns
    -------
    next_state : `np.ndarray`
        State after taking the action
    reward : `float`
        Immediate reward for the transition
    done : `bool`
        Whether the episode has terminated (from episode data)
    terminated : `bool`
        Whether the action-option pair was found in the dataset
    info : `Dict`
        Empty dictionary (no additional information)
    """
    key = (tuple(self.current_state), option, action)
    if key in self.state_action_option_to_transition:
        trans = self.state_action_option_to_transition[key]
        self.current_state = trans["next_state"]
        return trans["next_state"].copy(), trans["reward"], trans["done"], True, {}
    else:
        return self.current_state, 0.0, False, False, {}

Takes an action with a specific option and returns the next state, reward, and termination status.
Parameters
action : int - Action index to execute
option : int - Selected option index
Returns
next_state : np.ndarray - State after taking the action
reward : float - Immediate reward for the transition
done : bool - Whether the episode has terminated (from episode data)
terminated : bool - Whether the action-option pair was found in the dataset
info : Dict - Empty dictionary (no additional information)
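A rollout sketch, again assuming the hypothetical toy env from above. The fourth return value, documented as terminated, is True when the (state, option, action) triple exists in the dataset, so a loop over the dataset typically checks both it and done:

state, _ = env.reset()
next_state, reward, done, found, info = env.step(action=0, option=0)
if found:
    # Transition was present in the dataset; the environment advanced.
    print(next_state, reward, done)   # [1.0, 0.0] 0.0 False
else:
    # Unknown (state, option, action): state is unchanged and reward is 0.0.
    print("transition not in dataset")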