Source code for codeair.state_machine

# Copyright (c) 2025 Firia Labs
""" State Machine - lightweight task scheduler and state machine framework."""
import time
from ansi_term import *

class State:
    """Represents a state in the state machine.

    A State wraps a handler function and records when the state was entered,
    so every handler call can be given the time spent in the state.
    """

    def __init__(self, state_machine, handler):
        """Initializes a State instance.

        (Used by StateMachine - not usually used directly by end-user code)

        Args:
            state_machine (StateMachine): The state machine this state belongs to.
            handler (function): The function that handles state events.
        """
        self.name = self._handler_name(handler)
        self.state_machine = state_machine
        self.handler_function = handler
        self._entry_time = None  # monotonic_ns timestamp set by enter()

    @staticmethod
    def _handler_name(handler):
        """Derives a display name for a handler from its string form.

        The name is parsed from str(handler) rather than read from __name__
        because CircuitPython doesn't have the __name__ attribute.

        Args:
            handler (function): The handler to name.

        Returns:
            str: The second whitespace-separated token of str(handler)
            (e.g. "foo" for "<function foo at 0x...>"), the third token when
            the text starts with "<bound" (a "<bound method Cls.fn of ...>"
            style repr), or the whole text when it has no second token.
        """
        text = str(handler)
        tokens = text.split()
        if len(tokens) > 2 and tokens[0] == "<bound":
            return tokens[2]  # skip the literal word "method"
        if len(tokens) > 1:
            return tokens[1]
        return text  # no whitespace to split on; avoid an IndexError

    def _elapsed_seconds(self):
        """Returns seconds since enter() was called, or 0 if never entered."""
        if self._entry_time is None:
            return 0
        return (time.monotonic_ns() - self._entry_time) / 1000000000

    def enter(self):
        """Called when entering the state."""
        self._entry_time = time.monotonic_ns()
        # Call handler with "enter" event, st_time=0
        self.handler_function(self.state_machine, "enter", None, 0)

    def exit(self):
        """Called when exiting the state."""
        # Call handler with "exit" event, st_time at exit
        self.handler_function(self.state_machine, "exit", None, self._elapsed_seconds())

    def handle_event(self, event_type, event_data):
        """Handles events within the state.

        Args:
            event_type (str): The type of event.
            event_data (any): The data associated with the event.
        """
        # Pass event info and st_time
        self.handler_function(self.state_machine, event_type, event_data, self._elapsed_seconds())
class StateMachine:
    """A simple state machine framework with task scheduling capabilities.

    States are handler functions registered with add_state(). Periodic and
    one-shot tasks are kept in task_list and executed by calling run_tasks()
    repeatedly from the application's main loop.
    """

    def __init__(self):
        """Initializes a StateMachine instance."""
        self.states = {}  # {state_func: State}
        self.current_state = None
        self.initial_state_func = None
        self.running = False
        self.debug = False
        self.debug_log = print
        self._recursion_depth = 0  # nesting of next_state() calls (debug only)
        self.MAX_RECURSION_DEPTH = 5
        self.task_list = []  # entries: (next_run_ns, interval_s, callback, args)
        self.tasks_for_removal = []  # callbacks dropped at the next run_tasks()
        self.TASK_WARNING_THRESHOLD = 0.05  # seconds

    def enable_debug(self, do_debug, log_func=print):
        """Enables or disables debug logging.

        Args:
            do_debug (bool): If True, enables debug logging.
            log_func (function): The function to use for logging. Defaults to print.
        """
        self.debug = do_debug
        self.debug_log = log_func

    def add_task(self, callback, interval, args=None):
        """Adds a task to be run periodically at a specified interval.

        Args:
            callback (function): The function to be called.
            interval (float): The interval in seconds between task executions.
            args (tuple, optional): Arguments to pass to the callback function.
                A value without an 'index' attribute, or a single str/bytes
                value, is passed to the callback as one argument.

        Raises:
            TypeError: If callback is not callable.
        """
        if not callable(callback):
            raise TypeError("add_task() expects a function object.")
        # str/bytes have an 'index' attribute too, but unpacking them would
        # hand the callback one argument per character.
        if args is not None and (isinstance(args, (str, bytes)) or not hasattr(args, 'index')):
            args = (args,)
        next_run = time.monotonic_ns() + int(interval * 1000000000)
        self.task_list.append((next_run, interval, callback, args))

    def remove_task(self, callback):
        """Removes a task from the task list.

        The removal is deferred: the callback is queued and every task using
        it is dropped at the start of the next run_tasks() call.

        Args:
            callback (function): The function to be removed.
        """
        if self.debug and callback not in [task[2] for task in self.task_list]:
            self.debug_log(f"{ANSI_FG_YELLOW}Warning: task not found for removal.{ANSI_RESET}")
        self.tasks_for_removal.append(callback)

    def sync_task(self, callback, ns_sync=None):
        """Set task periodic interval to align with ns_sync as basis.

        Only the first task in the list that uses callback is updated.

        Args:
            callback (function): The function to be synced.
            ns_sync (int, optional): The time in nanoseconds to sync to.
                Defaults to None, which uses the current time
                (time.monotonic_ns()).

        Raises:
            ValueError: If no task with this callback is in the task list.
        """
        if ns_sync is None:
            ns_sync = time.monotonic_ns()
        # Set task to run at ns_sync + interval
        for i, (next_run, interval, cb, args) in enumerate(self.task_list):
            if cb == callback:
                self.task_list[i] = (ns_sync + int(interval * 1000000000), interval, cb, args)
                return
        raise ValueError(f"Task {callback} not found in task list. Cannot sync.")

    def schedule(self, callback, delay, args=None):
        """Schedules a task to be run once after a delay.

        Args:
            callback (function): The function to be called.
            delay (float): The delay in seconds before the task is executed.
            args (tuple, optional): Arguments to pass to the callback function.

        Raises:
            TypeError: If callback is not callable.
        """
        if not callable(callback):
            raise TypeError("schedule() expects a function object.")

        def sked_callback(*cb_args):
            try:
                callback(*cb_args)
            finally:
                # Queue removal even when the callback raises, otherwise a
                # failing one-shot task would be re-run on every pass.
                self.remove_task(sked_callback)

        self.add_task(sked_callback, delay, args)

    def _run_task_debug(self, callback, args):
        """Runs one task, logging slow execution and failures via debug_log."""
        try:
            t_call = time.monotonic_ns()
            if args:
                callback(*args)
            else:
                callback()
            duration = (time.monotonic_ns() - t_call) / 1000000000
            if duration > self.TASK_WARNING_THRESHOLD:
                self.debug_log(f"{ANSI_FG_YELLOW}Warning: task took too long: {duration}s ({callback}){ANSI_RESET}")
        except Exception:
            self.debug_log(f"{ANSI_FG_YELLOW}Error calling {callback} with args {args}{ANSI_RESET}")
            raise

    def run_tasks(self):
        """Runs all scheduled tasks that are due."""
        if self.tasks_for_removal:
            self.task_list = [task for task in self.task_list if task[2] not in self.tasks_for_removal]
            self.tasks_for_removal.clear()
        now = time.monotonic_ns()
        for i, (next_run, interval, callback, args) in enumerate(self.task_list):
            if now < next_run:
                continue
            # Run the scheduled task
            if self.debug:
                self._run_task_debug(callback, args)
            elif args:
                callback(*args)
            else:
                callback()
            # Update next run time
            self.task_list[i] = (now + int(interval * 1000000000), interval, callback, args)

    def add_state(self, state_func):
        """Adds a state to the state machine.

        state_func must be a callable that takes four arguments:

        .. code-block:: python

            def func(state_machine, event_type, event_data, st_time):
                pass

        * The state_func() will be called on every next_state transition,
          with the event_type set to "enter" or "exit".
        * It will also be called when event() is called, with event_type set
          to the event type passed to event().
        * The st_time argument is the elapsed time since the state was
          entered, in seconds.

        The first state added becomes the default initial state.

        Args:
            state_func (function): The function that handles the state.

        Raises:
            TypeError: If state_func is not callable.
            ValueError: If state_func has already been added.
        """
        if not callable(state_func):
            raise TypeError("add_state() expects a function object.")
        if state_func in self.states:
            raise ValueError(f"State with name '{self.states[state_func].name}' already exists.")
        if not self.states:
            # Set initial state if first state added
            self.initial_state_func = state_func
        self.states[state_func] = State(self, state_func)

    def current(self):
        """Return currently active state function"""
        return self.current_state.handler_function if self.current_state else None

    def start(self, first_state=None):
        """Starts or re-starts the state machine.

        Args:
            first_state (function, optional): The function that handles the
                initial state. Defaults to the first state added.

        Raises:
            Exception: If no states have been added.
            ValueError: If first_state is given but was never added.
        """
        if not self.states:
            raise Exception("StateMachine has no states defined.")
        if first_state:
            # Validate before touching any attribute so a bad argument
            # leaves the machine exactly as it was.
            if first_state not in self.states:
                raise ValueError("State not found.")
            self.initial_state_func = first_state
        self.running = True
        self.current_state = None
        self.next_state(self.initial_state_func)

    def next_state(self, next_state_func, sched_delay=None):
        """Transitions to a new state.

        Does nothing when next_state_func is already the current state.

        Args:
            next_state_func (function): The function that handles the next state.
            sched_delay (float): Optional schedule delay (seconds) before
                transition. Use 0 to schedule ASAP and break recursion.

        Raises:
            Exception: If the state machine has not been started.
            ValueError: If next_state_func was never added.
            RuntimeError: In debug mode, if transitions nest deeper than
                MAX_RECURSION_DEPTH.
        """
        if not self.running:
            raise Exception("StateMachine not started. Call start() first.")
        if next_state_func not in self.states:
            raise ValueError("State not found.")
        if self.current_state and self.current_state.handler_function == next_state_func:
            return  # Already in this state
        if sched_delay is not None:
            self.schedule(self.next_state, sched_delay, next_state_func)
            return

        # Remember whether this call incremented the counter, so the
        # decrement below stays balanced even if a handler toggles debug.
        tracking = self.debug
        if tracking:
            self._recursion_depth += 1
        try:
            if tracking and self._recursion_depth > self.MAX_RECURSION_DEPTH:
                raise RuntimeError("Recursive state transition detected. Check your enter/exit handlers.")
            if self.current_state:
                if self.debug:
                    self.debug_log(f"{ANSI_FG_GREEN}Exiting state: {self.current_state.name}{ANSI_RESET}")
                self.current_state.exit()  # Exit current state
            self.current_state = self.states[next_state_func]
            if self.debug:
                self.debug_log(f"{ANSI_FG_GREEN}Entering state: {self.current_state.name}{ANSI_RESET}")
            self.current_state.enter()  # Enter new state
        finally:
            # Always unwind the counter; otherwise one failed transition
            # would leave the recursion guard tripped for good.
            if tracking:
                self._recursion_depth -= 1

    def event(self, event_type, event_data=None):
        """Injects an event into the state machine, handled by the current state.

        The event is ignored when there is no current state.

        Args:
            event_type (str): The type of event.
            event_data (any, optional): The data associated with the event.
        """
        if self.current_state:
            self.current_state.handle_event(event_type, event_data)