Source code for domovoy.plugins.hass

from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass
from typing import Any, Concatenate, ParamSpec

from domovoy_typing.services import HassSyntheticDomainsServiceCalls

from domovoy.applications.types import Interval
from domovoy.core.app_infra import AppStatus, AppWrapper
from domovoy.core.context import context_callback_id, context_logger
from domovoy.core.logging import get_logger
from domovoy.plugins import callbacks
from domovoy.plugins.hass.domains import (
    get_type_instance_for_entity_id,
)
from domovoy.plugins.hass.exceptions import HassUnknownEntityError
from domovoy.plugins.plugins import AppPlugin

from .core import EntityState, HassCore
from .exceptions import HassApiCommandError
from .types import EntityID, HassData, HassValue, PrimitiveHassValue

# ParamSpec capturing the extra positional/keyword arguments that
# HassPlugin.listen_trigger forwards to the user-supplied callback.
P = ParamSpec("P")

# Dedicated logger used to warn about entity ids that are not present in the Hass state cache.
# NOTE(review): the logger name is misspelled ("entitites"); left unchanged because logging
# configuration elsewhere may match on this exact name — confirm before renaming.
_missing_entities_logger = get_logger("missing_entitites")


@dataclass
class ServiceDetails:
    """Per-service summary kept in the plugin's service-definition cache.

    Attributes:
        has_response: True when the service's definition contains a ``response`` section,
            i.e. the service can return response data.
    """

    has_response: bool
[docs] class HassPlugin(AppPlugin): __hass: HassCore _wrapper: AppWrapper __callbacks: callbacks.CallbacksPlugin __cached_service_definitions: dict[str, ServiceDetails] | None = None def __init__( self, name: str, wrapper: AppWrapper, hass_core: HassCore, ) -> None: super().__init__(name, wrapper) self.__hass = hass_core self.services = HassSyntheticDomainsServiceCalls(self)
[docs] def prepare(self) -> None: super().prepare() self.__callbacks = self._wrapper.get_pluginx(callbacks.CallbacksPlugin)
[docs] def get_full_state(self, entity_id: EntityID) -> EntityState: """Get the complete state object for a Home Assistant entity. Args: entity_id: The EntityID to retrieve state for. Can also accept string for backwards compatibility. Returns: EntityState object containing state, attributes, last_changed, and last_updated. Raises: HassUnknownEntityError: If the entity does not exist in the state cache. """ if isinstance(entity_id, str): entity_id = get_type_instance_for_entity_id(entity_id) entity_state = self.__hass.get_state(entity_id) if entity_state is None: raise HassUnknownEntityError(entity_id) return entity_state
[docs] def warn_if_entity_doesnt_exists(self, entity_id: EntityID | Sequence[EntityID] | None) -> None: """Log a warning if the specified entity or entities don't exist in the Home Assistant cache. Useful for debugging typos in entity IDs during development. Args: entity_id: Single EntityID, sequence of EntityIDs, or None. If None, no action is taken. """ if entity_id is None: return entity_id = wrap_entity_id_as_list(entity_id) for eid in entity_id: if not self.__hass.entity_exists_in_cache(eid): _missing_entities_logger.warning( "[{app_name}] '{entity_id}' doesn't exist in Hass.", entity_id=eid, app_name=self._wrapper.get_app_name_for_logs(), )
[docs] def get_entity_id_by_attribute( self, attribute: str, value: str | None, ) -> list[EntityID]: """Find all entities that have a specific attribute with the given value. Args: attribute: The name of the attribute to search for. value: The value to match. If None, returns entities that have the attribute regardless of value. Returns: List of EntityIDs that match the criteria. """ return self.__hass.get_entity_id_by_attribute(attribute, value)
[docs] def get_all_entities(self) -> list[EntityState]: """Get the complete state objects for all entities in the Home Assistant cache. Returns: List of EntityState objects for every entity known to Home Assistant. """ return self.__hass.get_all_entities()
[docs] def get_all_entity_ids(self) -> frozenset[EntityID]: """Get all entity IDs currently in the Home Assistant cache. Returns: Frozen set of EntityID objects representing all known entities. """ return self.__hass.get_all_entity_ids()
[docs] async def fire_event( self, event_type: str, event_data: HassData | None = None, ) -> None: """Fire a custom event on the Home Assistant event bus. Args: event_type: The type/name of the event to fire. event_data: Optional dictionary of data to include with the event. """ await self.__hass.fire_event(event_type, event_data)
[docs] async def get_service_definitions(self) -> HassData: """Retrieve service definitions from Home Assistant. Returns: Dictionary containing all available services and their schemas. """ return await self.__hass.get_service_definitions()
[docs] async def listen_trigger( self, trigger: HassData, callback: Callable[Concatenate[HassData, P], None | Awaitable[None]], oneshot: bool = False, # noqa: FBT001, FBT002 *callback_args: P.args, **callback_kwargs: P.kwargs, ) -> str: """Subscribe to a Home Assistant trigger configuration. Triggers are HA's automation system primitives (state, numeric_state, time, etc.). Args: trigger: Dictionary containing the trigger configuration (e.g., {"platform": "state", "entity_id": "..."}). callback: Function to call when trigger fires. Receives trigger variables as first argument. oneshot: If True, unsubscribe after the first trigger event. *callback_args: Additional positional arguments to pass to callback. **callback_kwargs: Additional keyword arguments to pass to callback. Returns: Subscription ID string that can be used to cancel the trigger subscription. """ context_logger.set(self._wrapper.logger) instrumented_callback = self._wrapper.instrument_app_callback(callback) @self._wrapper.handle_exception_and_logging(callback) async def listen_trigger_callback( subscription_id: int, trigger_vars: HassData, ) -> None: self._wrapper.logger.trace( "Calling Listen Trigger Callback: {cls_name}.{func_name} from callback_id: {subscription_id}", cls_name=callback.__self__.__class__.__name__, # type: ignore func_name=callback.__name__, subscription_id=subscription_id, ) if oneshot: await self.__hass.unsubscribe_trigger(subscription_id) await instrumented_callback( subscription_id, trigger_vars, *callback_args, **callback_kwargs, ) return str( await self.__hass.subscribe_trigger( listen_trigger_callback, trigger, ), )
[docs] async def call_service( self, service_name: str, *, return_response: bool = False, throw_on_error: bool = False, **kwargs: HassValue, ) -> HassData | None: """Call a Home Assistant service. Prefer using the typed service stubs (self.services.domain.service_name) over this method. Args: service_name: Service to call in format "domain.service" (e.g., "light.turn_on"). return_response: If True, wait for and return the service response data. throw_on_error: If True, raise exceptions on errors. If False, log errors without raising. **kwargs: Service data parameters. entity_id can be EntityID or list[EntityID]. Returns: Service response data if return_response is True and service returns data, otherwise None. Raises: HassApiCommandError: If throw_on_error is True and service call fails. """ service_name_segments = service_name.split(".") if len(service_name_segments) != 2: self._wrapper.logger.error( "Cannot call service {service_name}. Invalid service name", service_name=service_name, ) return None domain = service_name_segments[0] service = service_name_segments[1] entity_id: EntityID | list[EntityID] | None = None if "entity_id" in kwargs and ("domovoy_drop_target" not in kwargs or not kwargs["domovoy_drop_target"]): # We add the ignore because there is no easy way # to restrict the typing of kwargs until python 3.12 entity_id = kwargs["entity_id"] # type: ignore if ( entity_id is None or (isinstance(entity_id, list) and not all(isinstance(sub, EntityID) for sub in entity_id)) or (not isinstance(entity_id, EntityID) and not isinstance(entity_id, list)) ): self._wrapper.logger.error( "Cannot call service `{service_name}`. The `entity_id` key has an invalid type." " Only `EntityID` or `list[EntityID]` are allowed. If passing a list, make sure " "all the elements are EntityID. 
{entity_id}", service_name=service_name, entity_id=entity_id, ) return None kwargs.pop("entity_id") if "domovoy_drop_target" in kwargs: kwargs.pop("domovoy_drop_target") if "service_data_entity_id" in kwargs: val = get_type_instance_for_entity_id(str(kwargs["service_data_entity_id"])) kwargs.pop("service_data_entity_id") self.warn_if_entity_doesnt_exists(val if val else None) kwargs["entity_id"] = val self.warn_if_entity_doesnt_exists(entity_id) try: return await self.__hass.call_service( domain=domain, service=service, service_data=kwargs, entity_id=entity_id, return_response=return_response, ) except HassApiCommandError as e: if throw_on_error or self._wrapper.status == AppStatus.INITIALIZING: raise if e.message == "Service call requires responses but caller did not ask for responses": return await self.call_service( service_name, return_response=True, throw_on_error=throw_on_error, **kwargs, ) self._wrapper.logger.error( "There was an error when executing the command. " "Exception was not raised to app. Message: {exception_message}", exception_message=str(e), )
[docs] async def wait_for_state_to_be( self, entity_id: EntityID, states: str | list[str], duration: Interval | None = None, timeout: Interval | None = None, # noqa: ASYNC109 ) -> None: """Asynchronously wait for an entity to reach one of the specified states. This is an awaitable method that blocks until the condition is met or timeout occurs. Args: entity_id: The entity to monitor. states: Single state string or list of state strings to wait for. duration: If specified, entity must stay in the target state for this duration. timeout: Optional timeout interval. Raises asyncio.TimeoutError if exceeded. Raises: asyncio.TimeoutError: If timeout is specified and exceeded before condition is met. """ if timeout is None: await self.__wait_for_state_to_be_implementation( entity_id, states, duration, ) else: async with asyncio.timeout(timeout.total_seconds()): await self.__wait_for_state_to_be_implementation( entity_id, states, duration, )
def __wait_for_state_to_be_implementation( self, entity_id: EntityID, states: str | list[str], duration: Interval | None = None, ) -> asyncio.Future[None]: future = asyncio.get_event_loop().create_future() if isinstance(states, str): states = [states] async def state_callback( entity: EntityID, _attribute: str, _old: HassValue, new: HassValue, ) -> None: callback_id: str = context_callback_id.get() # type: ignore if new in states and not future.done(): entity_full_state = self.get_full_state(entity) if duration is not None and not entity_full_state.has_been_in_current_state_for_at_least( duration, ): await asyncio.sleep( (duration.to_timedelta() - entity_full_state.get_time_in_current_state()).total_seconds() + 0.5, ) if not self.get_full_state( entity, ).has_been_in_current_state_for_at_least(duration): return self.__callbacks.cancel_callback(callback_id) future.set_result(None) self.__callbacks.listen_state_extended(entity_id, state_callback, immediate=True) return future async def _get_cached_service_definitions(self, *, reset: bool = False) -> dict[str, ServiceDetails]: if self.__cached_service_definitions is None or reset is True: domains: dict[str, Any] = await self.get_service_definitions() service_definitions = {} for domain, services in sorted(domains.items()): for service, details in sorted(services.items()): service_definitions[f"{domain}.{service}"] = ServiceDetails(has_response="response" in details) self.__cached_service_definitions = service_definitions return self.__cached_service_definitions
[docs] async def send_raw_command(self, command_type: str, command_args: HassData) -> HassData | list[HassData]: """Send a raw WebSocket command directly to Home Assistant. This is a low-level method for advanced use cases not covered by other plugin methods. Args: command_type: The WebSocket command type to send. command_args: Dictionary of arguments for the command. Returns: The response from Home Assistant, either a single data dict or list of dicts. """ return await self.__hass.send_raw_command(command_type, command_args)
[docs] def get_typed_state[T](self, entity_id: EntityID[T]) -> T | None: """Get the state of an entity, cast to the entity's native type. Uses the EntityID's type information to parse and cast the state value. Args: entity_id: Typed EntityID that includes parsing logic for its domain. Returns: State value cast to the appropriate type, or None if parsing fails. """ full_state: EntityState = self.get_full_state(entity_id) return entity_id.parse_state_typed(full_state)
[docs] def get_state(self, entity_id: EntityID) -> PrimitiveHassValue: """Get the current state value of an entity. Returns only the state value, not the full EntityState object. Args: entity_id: The entity to get state for. Returns: The entity's current state as a primitive value (str, int, float, or bool). """ return self.get_full_state(entity_id).state
def wrap_entity_id_as_list(val: EntityID | Sequence[EntityID]) -> list[EntityID]:
    """Convert a single EntityID or sequence of EntityIDs into a list.

    Utility function to normalize entity ID arguments that can be either single or multiple values.

    A plain string is itself a ``Sequence``; it is treated as ONE entity id rather than being
    split into a list of characters.

    Args:
        val: Single EntityID or sequence of EntityIDs.

    Returns:
        A new list containing the EntityID(s).
    """
    if isinstance(val, Sequence) and not isinstance(val, str):
        return list(val)

    return [val]