diff --git a/custom_components/xiaomi_home/light.py b/custom_components/xiaomi_home/light.py index 26ed208..d953cc8 100644 --- a/custom_components/xiaomi_home/light.py +++ b/custom_components/xiaomi_home/light.py @@ -45,9 +45,10 @@ off Xiaomi or its affiliates' products. Light entities for Xiaomi Home. """ + from __future__ import annotations import logging -from typing import Any, Optional +from typing import Any, Optional, List, Dict from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant @@ -59,34 +60,31 @@ from homeassistant.components.light import ( ATTR_EFFECT, LightEntity, LightEntityFeature, - ColorMode -) -from homeassistant.util.color import ( - value_to_brightness, - brightness_to_value + ColorMode, ) +from homeassistant.util.color import value_to_brightness, brightness_to_value from .miot.miot_spec import MIoTSpecProperty -from .miot.miot_device import MIoTDevice, MIoTEntityData, MIoTServiceEntity +from .miot.miot_device import MIoTDevice, MIoTEntityData, MIoTServiceEntity from .miot.const import DOMAIN _LOGGER = logging.getLogger(__name__) async def async_setup_entry( - hass: HomeAssistant, - config_entry: ConfigEntry, - async_add_entities: AddEntitiesCallback, + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, ) -> None: """Set up a config entry.""" - device_list: list[MIoTDevice] = hass.data[DOMAIN]['devices'][ - config_entry.entry_id] + device_list: list[MIoTDevice] = hass.data[DOMAIN]["devices"][config_entry.entry_id] new_entities = [] for miot_device in device_list: - for data in miot_device.entity_list.get('light', []): + for data in miot_device.entity_list.get("light", []): new_entities.append( - Light(miot_device=miot_device, entity_data=data)) + Light(miot_device=miot_device, entity_data=data, hass=hass) + ) if new_entities: async_add_entities(new_entities) @@ -94,6 +92,7 @@ async def async_setup_entry( class Light(MIoTServiceEntity, LightEntity): """Light entities for Xiaomi Home.""" + # pylint: disable=unused-argument _VALUE_RANGE_MODE_COUNT_MAX = 30 _prop_on: Optional[MIoTSpecProperty] @@ -106,15 +105,17 @@ class Light(MIoTServiceEntity, LightEntity): _mode_map: Optional[dict[Any, Any]] def __init__( - self, miot_device: MIoTDevice, entity_data: MIoTEntityData + self, miot_device: MIoTDevice, entity_data: MIoTEntityData, hass: HomeAssistant ) -> None: """Initialize the Light.""" - super().__init__(miot_device=miot_device, entity_data=entity_data) + super().__init__(miot_device=miot_device, entity_data=entity_data) + self.hass = hass self._attr_color_mode = None self._attr_supported_color_modes = set() self._attr_supported_features = LightEntityFeature(0) - if miot_device.did.startswith('group.'): - self._attr_icon = 'mdi:lightbulb-group' + self.miot_device = miot_device + if miot_device.did.startswith("group."): + self._attr_icon = "mdi:lightbulb-group" self._prop_on = None self._prop_brightness = None @@ -127,33 +128,32 @@ class Light(MIoTServiceEntity, LightEntity): # properties for prop in entity_data.props: # on - if prop.name == 'on': + if prop.name == "on": self._prop_on = prop # brightness - if prop.name == 'brightness': + if prop.name == "brightness": if prop.value_range: self._brightness_scale = ( - prop.value_range.min_, prop.value_range.max_) + prop.value_range.min_, + prop.value_range.max_, + ) self._prop_brightness = prop - elif ( - self._mode_map is None - and prop.value_list - ): + elif self._mode_map is None and prop.value_list: # For value-list brightness self._mode_map = prop.value_list.to_map() self._attr_effect_list = list(self._mode_map.values()) self._attr_supported_features |= LightEntityFeature.EFFECT self._prop_mode = prop else: - _LOGGER.info( - 'invalid brightness format, %s', self.entity_id) + _LOGGER.info("invalid brightness format, %s", self.entity_id) continue # color-temperature - if prop.name == 'color-temperature': + if prop.name == "color-temperature": if not prop.value_range: _LOGGER.info( - 'invalid color-temperature value_range format, %s', - self.entity_id) + "invalid color-temperature value_range format, %s", + self.entity_id, + ) continue self._attr_min_color_temp_kelvin = prop.value_range.min_ self._attr_max_color_temp_kelvin = prop.value_range.max_ @@ -161,40 +161,44 @@ class Light(MIoTServiceEntity, LightEntity): self._attr_color_mode = ColorMode.COLOR_TEMP self._prop_color_temp = prop # color - if prop.name == 'color': + if prop.name == "color": self._attr_supported_color_modes.add(ColorMode.RGB) self._attr_color_mode = ColorMode.RGB self._prop_color = prop # mode - if prop.name == 'mode': + if prop.name == "mode": mode_list = None if prop.value_list: mode_list = prop.value_list.to_map() elif prop.value_range: mode_list = {} if ( - int(( - prop.value_range.max_ - - prop.value_range.min_ - ) / prop.value_range.step) + int( + (prop.value_range.max_ - prop.value_range.min_) + / prop.value_range.step + ) > self._VALUE_RANGE_MODE_COUNT_MAX ): _LOGGER.error( - 'too many mode values, %s, %s, %s', - self.entity_id, prop.name, prop.value_range) + "too many mode values, %s, %s, %s", + self.entity_id, + prop.name, + prop.value_range, + ) else: for value in range( - prop.value_range.min_, - prop.value_range.max_, - prop.value_range.step): - mode_list[value] = f'mode {value}' + prop.value_range.min_, + prop.value_range.max_, + prop.value_range.step, + ): + mode_list[value] = f"mode {value}" if mode_list: self._mode_map = mode_list self._attr_effect_list = list(self._mode_map.values()) self._attr_supported_features |= LightEntityFeature.EFFECT self._prop_mode = prop else: - _LOGGER.info('invalid mode format, %s', self.entity_id) + _LOGGER.info("invalid mode format, %s", self.entity_id) continue if not self._attr_supported_color_modes: @@ -242,8 +246,8 @@ class Light(MIoTServiceEntity, LightEntity): def effect(self) -> Optional[str]: """Return the current mode.""" return self.get_map_value( - map_=self._mode_map, - key=self.get_prop_value(prop=self._prop_mode)) + map_=self._mode_map, key=self.get_prop_value(prop=self._prop_mode) + ) async def async_turn_on(self, **kwargs) -> None: """Turn the light on. @@ -252,42 +256,135 @@ class Light(MIoTServiceEntity, LightEntity): """ # on # Dirty logic for lumi.gateway.mgl03 indicator light - if self._prop_on: - value_on = True if self._prop_on.format_ == bool else 1 - await self.set_property_async( - prop=self._prop_on, value=value_on) - # brightness - if ATTR_BRIGHTNESS in kwargs: - brightness = brightness_to_value( - self._brightness_scale, kwargs[ATTR_BRIGHTNESS]) - await self.set_property_async( - prop=self._prop_brightness, value=brightness, - write_ha_state=False) - # color-temperature - if ATTR_COLOR_TEMP_KELVIN in kwargs: - await self.set_property_async( - prop=self._prop_color_temp, - value=kwargs[ATTR_COLOR_TEMP_KELVIN], - write_ha_state=False) - self._attr_color_mode = ColorMode.COLOR_TEMP - # rgb color - if ATTR_RGB_COLOR in kwargs: - r = kwargs[ATTR_RGB_COLOR][0] - g = kwargs[ATTR_RGB_COLOR][1] - b = kwargs[ATTR_RGB_COLOR][2] - rgb = (r << 16) | (g << 8) | b - await self.set_property_async( - prop=self._prop_color, value=rgb, - write_ha_state=False) - self._attr_color_mode = ColorMode.RGB - # mode - if ATTR_EFFECT in kwargs: - await self.set_property_async( - prop=self._prop_mode, - value=self.get_map_key( - map_=self._mode_map, value=kwargs[ATTR_EFFECT]), - write_ha_state=False) - self.async_write_ha_state() + # Determine whether the device sends the light-on properties in batches or one by one + select_entity_id = f"select.{self.miot_device.gen_device_entity_id(DOMAIN).split('.')[-1]}_command_send_mode" + command_send_mode = self.hass.states.get(select_entity_id) + if command_send_mode and command_send_mode.state == "Send Together": + set_properties_list: List[Dict[str, Any]] = [] + if self._prop_on: + value_on = True if self._prop_on.format_ == bool else 1 # noqa: E721 + set_properties_list.append({"prop": self._prop_on, "value": value_on}) + # brightness + if ATTR_BRIGHTNESS in kwargs: + brightness = brightness_to_value( + self._brightness_scale, kwargs[ATTR_BRIGHTNESS] + ) + set_properties_list.append( + {"prop": self._prop_brightness, "value": brightness} + ) + # color-temperature + if ATTR_COLOR_TEMP_KELVIN in kwargs: + set_properties_list.append( + { + "prop": self._prop_color_temp, + "value": kwargs[ATTR_COLOR_TEMP_KELVIN], + } + ) + self._attr_color_mode = ColorMode.COLOR_TEMP + # rgb color + if ATTR_RGB_COLOR in kwargs: + r = kwargs[ATTR_RGB_COLOR][0] + g = kwargs[ATTR_RGB_COLOR][1] + b = kwargs[ATTR_RGB_COLOR][2] + rgb = (r << 16) | (g << 8) | b + set_properties_list.append({"prop": self._prop_color, "value": rgb}) + self._attr_color_mode = ColorMode.RGB + # mode + if ATTR_EFFECT in kwargs: + set_properties_list.append( + { + "prop": self._prop_mode, + "value": self.get_map_key( + map_=self._mode_map, value=kwargs[ATTR_EFFECT] + ), + } + ) + await self.set_properties_async(set_properties_list) + self.async_write_ha_state() + elif command_send_mode and command_send_mode.state == "Send Turn On First": + set_properties_list: List[Dict[str, Any]] = [] + if self._prop_on: + value_on = True if self._prop_on.format_ == bool else 1 # noqa: E721 + set_properties_list.append({"prop": self._prop_on, "value": value_on}) + await self.set_property_async(prop=self._prop_on, value=value_on) + # brightness + if ATTR_BRIGHTNESS in kwargs: + brightness = brightness_to_value( + self._brightness_scale, kwargs[ATTR_BRIGHTNESS] + ) + set_properties_list.append( + {"prop": self._prop_brightness, "value": brightness} + ) + # color-temperature + if ATTR_COLOR_TEMP_KELVIN in kwargs: + set_properties_list.append( + { + "prop": self._prop_color_temp, + "value": kwargs[ATTR_COLOR_TEMP_KELVIN], + } + ) + self._attr_color_mode = ColorMode.COLOR_TEMP + # rgb color + if ATTR_RGB_COLOR in kwargs: + r = kwargs[ATTR_RGB_COLOR][0] + g = kwargs[ATTR_RGB_COLOR][1] + b = kwargs[ATTR_RGB_COLOR][2] + rgb = (r << 16) | (g << 8) | b + set_properties_list.append({"prop": self._prop_color, "value": rgb}) + self._attr_color_mode = ColorMode.RGB + # mode + if ATTR_EFFECT in kwargs: + set_properties_list.append( + { + "prop": self._prop_mode, + "value": self.get_map_key( + map_=self._mode_map, value=kwargs[ATTR_EFFECT] + ), + } + ) + await self.set_properties_async(set_properties_list) + self.async_write_ha_state() + + else: + if self._prop_on: + value_on = True if self._prop_on.format_ == bool else 1 # noqa: E721 + await self.set_property_async(prop=self._prop_on, value=value_on) + # brightness + if ATTR_BRIGHTNESS in kwargs: + brightness = brightness_to_value( + self._brightness_scale, kwargs[ATTR_BRIGHTNESS] + ) + await self.set_property_async( + prop=self._prop_brightness, value=brightness, write_ha_state=False + ) + # color-temperature + if ATTR_COLOR_TEMP_KELVIN in kwargs: + await self.set_property_async( + prop=self._prop_color_temp, + value=kwargs[ATTR_COLOR_TEMP_KELVIN], + write_ha_state=False, + ) + self._attr_color_mode = ColorMode.COLOR_TEMP + # rgb color + if ATTR_RGB_COLOR in kwargs: + r = kwargs[ATTR_RGB_COLOR][0] + g = kwargs[ATTR_RGB_COLOR][1] + b = kwargs[ATTR_RGB_COLOR][2] + rgb = (r << 16) | (g << 8) | b + await self.set_property_async( + prop=self._prop_color, value=rgb, write_ha_state=False + ) + self._attr_color_mode = ColorMode.RGB + # mode + if ATTR_EFFECT in kwargs: + await self.set_property_async( + prop=self._prop_mode, + value=self.get_map_key( + map_=self._mode_map, value=kwargs[ATTR_EFFECT] + ), + write_ha_state=False, + ) + self.async_write_ha_state() async def async_turn_off(self, **kwargs) -> None: """Turn the light off.""" diff --git a/custom_components/xiaomi_home/miot/miot_client.py b/custom_components/xiaomi_home/miot/miot_client.py index 58f506d..7864ee9 100644 --- a/custom_components/xiaomi_home/miot/miot_client.py +++ b/custom_components/xiaomi_home/miot/miot_client.py @@ -45,8 +45,9 @@ off Xiaomi or its affiliates' products. MIoT client instance. """ + from copy import deepcopy -from typing import Any, Callable, Optional, final +from typing import Any, Callable, Dict, List, Optional, final import asyncio import json import logging @@ -61,14 +62,23 @@ from homeassistant.components import zeroconf # pylint: disable=relative-beyond-top-level from .common import MIoTMatcher, slugify_did from .const import ( - DEFAULT_CTRL_MODE, DEFAULT_INTEGRATION_LANGUAGE, DEFAULT_NICK_NAME, DOMAIN, - MIHOME_CERT_EXPIRE_MARGIN, NETWORK_REFRESH_INTERVAL, - OAUTH2_CLIENT_ID, SUPPORT_CENTRAL_GATEWAY_CTRL) + DEFAULT_CTRL_MODE, + DEFAULT_INTEGRATION_LANGUAGE, + DEFAULT_NICK_NAME, + DOMAIN, + MIHOME_CERT_EXPIRE_MARGIN, + NETWORK_REFRESH_INTERVAL, + OAUTH2_CLIENT_ID, + SUPPORT_CENTRAL_GATEWAY_CTRL, +) from .miot_cloud import MIoTHttpClient, MIoTOauthClient from .miot_error import MIoTClientError, MIoTErrorCode from .miot_mips import ( - MIoTDeviceState, MipsCloudClient, MipsDeviceState, - MipsLocalClient) + MIoTDeviceState, + MipsCloudClient, + MipsDeviceState, + MipsLocalClient, +) from .miot_lan import MIoTLan from .miot_network import MIoTNetwork from .miot_storage import MIoTCert, MIoTStorage @@ -81,30 +91,33 @@ _LOGGER = logging.getLogger(__name__) @dataclass class MIoTClientSub: """MIoT client subscription.""" + topic: Optional[str] handler: Callable[[dict, Any], None] handler_ctx: Any = None def __str__(self) -> str: - return f'{self.topic}, {id(self.handler)}, {id(self.handler_ctx)}' + return f"{self.topic}, {id(self.handler)}, {id(self.handler_ctx)}" class CtrlMode(Enum): """MIoT client control mode.""" + AUTO = 0 CLOUD = auto() @staticmethod - def load(mode: str) -> 'CtrlMode': - if mode == 'auto': + def load(mode: str) -> "CtrlMode": + if mode == "auto": return CtrlMode.AUTO - if mode == 'cloud': + if mode == "cloud": return CtrlMode.CLOUD - raise MIoTClientError(f'unknown ctrl mode, {mode}') + raise MIoTClientError(f"unknown ctrl mode, {mode}") class MIoTClient: """MIoT client instance.""" + # pylint: disable=unused-argument # pylint: disable=broad-exception-caught # pylint: disable=inconsistent-quotes @@ -175,33 +188,33 @@ class MIoTClient: _display_binary_bool: bool def __init__( - self, - entry_id: str, - entry_data: dict, - network: MIoTNetwork, - storage: MIoTStorage, - mips_service: MipsService, - miot_lan: MIoTLan, - loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + self, + entry_id: str, + entry_data: dict, + network: MIoTNetwork, + storage: MIoTStorage, + mips_service: MipsService, + miot_lan: MIoTLan, + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: # MUST run in a running event loop self._main_loop = loop or asyncio.get_running_loop() # Check params if not isinstance(entry_data, dict): - raise MIoTClientError('invalid entry data') - if 'uid' not in entry_data or 'cloud_server' not in entry_data: - raise MIoTClientError('invalid entry data content') + raise MIoTClientError("invalid entry data") + if "uid" not in entry_data or "cloud_server" not in entry_data: + raise MIoTClientError("invalid entry data content") if not isinstance(network, MIoTNetwork): - raise MIoTClientError('invalid miot network') + raise MIoTClientError("invalid miot network") if not isinstance(storage, MIoTStorage): - raise MIoTClientError('invalid miot storage') + raise MIoTClientError("invalid miot storage") if not isinstance(mips_service, MipsService): - raise MIoTClientError('invalid mips service') + raise MIoTClientError("invalid mips service") self._entry_id = entry_id self._entry_data = entry_data - self._uid = entry_data['uid'] - self._cloud_server = entry_data['cloud_server'] - self._ctrl_mode = CtrlMode.load( - entry_data.get('ctrl_mode', DEFAULT_CTRL_MODE)) + self._uid = entry_data["uid"] + self._cloud_server = entry_data["cloud_server"] + self._ctrl_mode = CtrlMode.load(entry_data.get("ctrl_mode", DEFAULT_CTRL_MODE)) self._network = network self._storage = storage self._mips_service = mips_service @@ -238,125 +251,139 @@ class MIoTClient: self._show_devices_changed_notify_timer = None self._display_devs_notify = entry_data.get( - 'display_devices_changed_notify', ['add', 'del', 'offline']) + "display_devices_changed_notify", ["add", "del", "offline"] + ) self._display_notify_content_hash = None - self._display_binary_text = 'text' in entry_data.get( - 'display_binary_mode', ['text']) - self._display_binary_bool = 'bool' in entry_data.get( - 'display_binary_mode', ['text']) + self._display_binary_text = "text" in entry_data.get( + "display_binary_mode", ["text"] + ) + self._display_binary_bool = "bool" in entry_data.get( + "display_binary_mode", ["text"] + ) async def init_async(self) -> None: # Load user config and check self._user_config = await self._storage.load_user_config_async( - uid=self._uid, cloud_server=self._cloud_server) + uid=self._uid, cloud_server=self._cloud_server + ) if not self._user_config: # Integration need to be add again - raise MIoTClientError('load_user_config_async error') - _LOGGER.debug('user config, %s', json.dumps(self._user_config)) + raise MIoTClientError("load_user_config_async error") + _LOGGER.debug("user config, %s", json.dumps(self._user_config)) # MIoT i18n client self._i18n = MIoTI18n( lang=self._entry_data.get( - 'integration_language', DEFAULT_INTEGRATION_LANGUAGE), - loop=self._main_loop) + "integration_language", DEFAULT_INTEGRATION_LANGUAGE + ), + loop=self._main_loop, + ) await self._i18n.init_async() # Load cache device list await self.__load_cache_device_async() # MIoT oauth client instance self._oauth = MIoTOauthClient( client_id=OAUTH2_CLIENT_ID, - redirect_url=self._entry_data['oauth_redirect_url'], + redirect_url=self._entry_data["oauth_redirect_url"], cloud_server=self._cloud_server, uuid=self._entry_data["uuid"], - loop=self._main_loop) + loop=self._main_loop, + ) # MIoT http client instance self._http = MIoTHttpClient( cloud_server=self._cloud_server, client_id=OAUTH2_CLIENT_ID, - access_token=self._user_config['auth_info']['access_token'], - loop=self._main_loop) + access_token=self._user_config["auth_info"]["access_token"], + loop=self._main_loop, + ) # MIoT cert client self._cert = MIoTCert( - storage=self._storage, - uid=self._uid, - cloud_server=self.cloud_server) + storage=self._storage, uid=self._uid, cloud_server=self.cloud_server + ) # MIoT cloud mips client self._mips_cloud = MipsCloudClient( - uuid=self._entry_data['uuid'], + uuid=self._entry_data["uuid"], cloud_server=self._cloud_server, app_id=OAUTH2_CLIENT_ID, - token=self._user_config['auth_info']['access_token'], - loop=self._main_loop) + token=self._user_config["auth_info"]["access_token"], + loop=self._main_loop, + ) self._mips_cloud.enable_logger(logger=_LOGGER) self._mips_cloud.sub_mips_state( - key=f'{self._uid}-{self._cloud_server}', - handler=self.__on_mips_cloud_state_changed) + key=f"{self._uid}-{self._cloud_server}", + handler=self.__on_mips_cloud_state_changed, + ) # Subscribe network status self._network.sub_network_status( - key=f'{self._uid}-{self._cloud_server}', - handler=self.__on_network_status_changed) - await self.__on_network_status_changed( - status=self._network.network_status) + key=f"{self._uid}-{self._cloud_server}", + handler=self.__on_network_status_changed, + ) + await self.__on_network_status_changed(status=self._network.network_status) # Create multi mips local client instance according to the # number of hub gateways if self._ctrl_mode == CtrlMode.AUTO: # Central hub gateway ctrl if self._cloud_server in SUPPORT_CENTRAL_GATEWAY_CTRL: - for home_id, info in self._entry_data['home_selected'].items(): + for home_id, info in self._entry_data["home_selected"].items(): # Create local mips service changed listener self._mips_service.sub_service_change( - key=f'{self._uid}-{self._cloud_server}', - group_id=info['group_id'], - handler=self.__on_mips_service_state_change) + key=f"{self._uid}-{self._cloud_server}", + group_id=info["group_id"], + handler=self.__on_mips_service_state_change, + ) service_data = self._mips_service.get_services( - group_id=info['group_id']).get(info['group_id'], None) + group_id=info["group_id"] + ).get(info["group_id"], None) if not service_data: - _LOGGER.info( - 'central mips service not scanned, %s', home_id) + _LOGGER.info("central mips service not scanned, %s", home_id) continue _LOGGER.info( - 'central mips service scanned, %s, %s', - home_id, service_data) + "central mips service scanned, %s, %s", home_id, service_data + ) mips = MipsLocalClient( - did=self._entry_data['virtual_did'], - group_id=info['group_id'], - host=service_data['addresses'][0], + did=self._entry_data["virtual_did"], + group_id=info["group_id"], + host=service_data["addresses"][0], ca_file=self._cert.ca_file, cert_file=self._cert.cert_file, key_file=self._cert.key_file, - port=service_data['port'], - home_name=info['home_name'], - loop=self._main_loop) - self._mips_local[info['group_id']] = mips + port=service_data["port"], + home_name=info["home_name"], + loop=self._main_loop, + ) + self._mips_local[info["group_id"]] = mips mips.enable_logger(logger=_LOGGER) mips.on_dev_list_changed = self.__on_gw_device_list_changed mips.sub_mips_state( - key=info['group_id'], - handler=self.__on_mips_local_state_changed) + key=info["group_id"], handler=self.__on_mips_local_state_changed + ) mips.connect() # Lan ctrl await self._miot_lan.vote_for_lan_ctrl_async( - key=f'{self._uid}-{self._cloud_server}', vote=True) + key=f"{self._uid}-{self._cloud_server}", vote=True + ) self._miot_lan.sub_lan_state( - key=f'{self._uid}-{self._cloud_server}', - handler=self.__on_miot_lan_state_change) + key=f"{self._uid}-{self._cloud_server}", + handler=self.__on_miot_lan_state_change, + ) if self._miot_lan.init_done: await self.__on_miot_lan_state_change(True) else: - self._miot_lan.unsub_lan_state( - key=f'{self._uid}-{self._cloud_server}') + self._miot_lan.unsub_lan_state(key=f"{self._uid}-{self._cloud_server}") if self._miot_lan.init_done: self._miot_lan.unsub_device_state( - key=f'{self._uid}-{self._cloud_server}') + key=f"{self._uid}-{self._cloud_server}" + ) self._miot_lan.delete_devices( - devices=list(self._device_list_cache.keys())) + devices=list(self._device_list_cache.keys()) + ) await self._miot_lan.vote_for_lan_ctrl_async( - key=f'{self._uid}-{self._cloud_server}', vote=False) + key=f"{self._uid}-{self._cloud_server}", vote=False + ) - _LOGGER.info('init_async, %s, %s', self._uid, self._cloud_server) + _LOGGER.info("init_async, %s, %s", self._uid, self._cloud_server) async def deinit_async(self) -> None: - self._network.unsub_network_status( - key=f'{self._uid}-{self._cloud_server}') + self._network.unsub_network_status(key=f"{self._uid}-{self._cloud_server}") # Cancel refresh props if self._refresh_props_timer: self._refresh_props_timer.cancel() @@ -364,8 +391,7 @@ class MIoTClient: self._refresh_props_list.clear() self._refresh_props_retry_count = 0 # Cloud mips - self._mips_cloud.unsub_mips_state( - key=f'{self._uid}-{self._cloud_server}') + self._mips_cloud.unsub_mips_state(key=f"{self._uid}-{self._cloud_server}") self._mips_cloud.deinit() # Cancel refresh cloud devices if self._refresh_cloud_devices_timer: @@ -375,25 +401,27 @@ class MIoTClient: # Central hub gateway mips if self._cloud_server in SUPPORT_CENTRAL_GATEWAY_CTRL: self._mips_service.unsub_service_change( - key=f'{self._uid}-{self._cloud_server}') + key=f"{self._uid}-{self._cloud_server}" + ) for mips in self._mips_local.values(): mips.on_dev_list_changed = None mips.unsub_mips_state(key=mips.group_id) mips.deinit() if self._mips_local_state_changed_timers: - for timer_item in ( - self._mips_local_state_changed_timers.values()): + for timer_item in self._mips_local_state_changed_timers.values(): timer_item.cancel() self._mips_local_state_changed_timers.clear() - self._miot_lan.unsub_lan_state( - key=f'{self._uid}-{self._cloud_server}') + self._miot_lan.unsub_lan_state(key=f"{self._uid}-{self._cloud_server}") if self._miot_lan.init_done: self._miot_lan.unsub_device_state( - key=f'{self._uid}-{self._cloud_server}') + key=f"{self._uid}-{self._cloud_server}" + ) self._miot_lan.delete_devices( - devices=list(self._device_list_cache.keys())) + devices=list(self._device_list_cache.keys()) + ) await self._miot_lan.vote_for_lan_ctrl_async( - key=f'{self._uid}-{self._cloud_server}', vote=False) + key=f"{self._uid}-{self._cloud_server}", vote=False + ) # Cancel refresh auth info if self._refresh_token_timer: self._refresh_token_timer.cancel() @@ -408,18 +436,13 @@ class MIoTClient: await self._oauth.deinit_async() await self._http.deinit_async() # Remove notify - self._persistence_notify( - self.__gen_notify_key('dev_list_changed'), None, None) - self.__show_client_error_notify( - message=None, notify_key='oauth_info') - self.__show_client_error_notify( - message=None, notify_key='user_cert') - self.__show_client_error_notify( - message=None, notify_key='device_cache') - self.__show_client_error_notify( - message=None, notify_key='device_cloud') + self._persistence_notify(self.__gen_notify_key("dev_list_changed"), None, None) + self.__show_client_error_notify(message=None, notify_key="oauth_info") + self.__show_client_error_notify(message=None, notify_key="user_cert") + self.__show_client_error_notify(message=None, notify_key="device_cache") + self.__show_client_error_notify(message=None, notify_key="device_cloud") - _LOGGER.info('deinit_async, %s', self._uid) + _LOGGER.info("deinit_async, %s", self._uid) @property def main_loop(self) -> asyncio.AbstractEventLoop: @@ -459,7 +482,7 @@ class MIoTClient: @property def area_name_rule(self) -> Optional[str]: - return self._entry_data.get('area_name_rule', None) + return self._entry_data.get("area_name_rule", None) @property def cloud_server(self) -> str: @@ -467,12 +490,11 @@ class MIoTClient: @property def action_debug(self) -> bool: - return self._entry_data.get('action_debug', False) + return self._entry_data.get("action_debug", False) @property def hide_non_standard_entities(self) -> bool: - return self._entry_data.get( - 'hide_non_standard_entities', False) + return self._entry_data.get("hide_non_standard_entities", False) @property def display_devices_changed_notify(self) -> list[str]: @@ -495,7 +517,8 @@ class MIoTClient: self.__request_show_devices_changed_notify() else: self._persistence_notify( - self.__gen_notify_key('dev_list_changed'), None, None) + self.__gen_notify_key("dev_list_changed"), None, None + ) @property def device_list(self) -> dict: @@ -515,57 +538,67 @@ class MIoTClient: # Load auth info auth_info: Optional[dict] = None user_config: dict = await self._storage.load_user_config_async( - uid=self._uid, cloud_server=self._cloud_server, - keys=['auth_info']) + uid=self._uid, cloud_server=self._cloud_server, keys=["auth_info"] + ) if ( not user_config - or (auth_info := user_config.get('auth_info', None)) is None + or (auth_info := user_config.get("auth_info", None)) is None ): - raise MIoTClientError('load_user_config_async error') + raise MIoTClientError("load_user_config_async error") if ( - 'expires_ts' not in auth_info - or 'access_token' not in auth_info - or 'refresh_token' not in auth_info + "expires_ts" not in auth_info + or "access_token" not in auth_info + or "refresh_token" not in auth_info ): - raise MIoTClientError('invalid auth info') + raise MIoTClientError("invalid auth info") # Determine whether to update token - refresh_time = int(auth_info['expires_ts'] - time.time()) + refresh_time = int(auth_info["expires_ts"] - time.time()) if refresh_time <= 60: valid_auth_info = await self._oauth.refresh_access_token_async( - refresh_token=auth_info['refresh_token']) + refresh_token=auth_info["refresh_token"] + ) auth_info = valid_auth_info # Update http token self._http.update_http_header( - access_token=valid_auth_info['access_token']) + access_token=valid_auth_info["access_token"] + ) # Update mips cloud token self._mips_cloud.update_access_token( - access_token=valid_auth_info['access_token']) + access_token=valid_auth_info["access_token"] + ) # Update storage if not await self._storage.update_user_config_async( - uid=self._uid, cloud_server=self._cloud_server, - config={'auth_info': auth_info}): - raise MIoTClientError('update_user_config_async error') - _LOGGER.info( - 'refresh oauth info, get new access_token, %s', - auth_info) - refresh_time = int(auth_info['expires_ts'] - time.time()) + uid=self._uid, + cloud_server=self._cloud_server, + config={"auth_info": auth_info}, + ): + raise MIoTClientError("update_user_config_async error") + _LOGGER.info("refresh oauth info, get new access_token, %s", auth_info) + refresh_time = int(auth_info["expires_ts"] - time.time()) if refresh_time <= 0: - raise MIoTClientError('invalid expires time') - self.__show_client_error_notify(None, 'oauth_info') + raise MIoTClientError("invalid expires time") + self.__show_client_error_notify(None, "oauth_info") self.__request_refresh_auth_info(refresh_time) _LOGGER.debug( - 'refresh oauth info (%s, %s) after %ds', - self._uid, self._cloud_server, refresh_time) + "refresh oauth info (%s, %s) after %ds", + self._uid, + self._cloud_server, + refresh_time, + ) return True except Exception as err: self.__show_client_error_notify( - message=self._i18n.translate( - 'miot.client.invalid_oauth_info'), # type: ignore - notify_key='oauth_info') + message=self._i18n.translate("miot.client.invalid_oauth_info"), # type: ignore + notify_key="oauth_info", + ) _LOGGER.error( - 'refresh oauth info error (%s, %s), %s, %s', - self._uid, self._cloud_server, err, traceback.format_exc()) + "refresh oauth info error (%s, %s), %s, %s", + self._uid, + self._cloud_server, + err, + traceback.format_exc(), + ) return False async def refresh_user_cert_async(self) -> bool: @@ -573,305 +606,398 @@ class MIoTClient: if self._cloud_server not in SUPPORT_CENTRAL_GATEWAY_CTRL: return True if not await self._cert.verify_ca_cert_async(): - raise MIoTClientError('ca cert is not ready') + raise MIoTClientError("ca cert is not ready") refresh_time = ( - await self._cert.user_cert_remaining_time_async() - - MIHOME_CERT_EXPIRE_MARGIN) + await self._cert.user_cert_remaining_time_async() + - MIHOME_CERT_EXPIRE_MARGIN + ) if refresh_time <= 60: user_key = await self._cert.load_user_key_async() if not user_key: user_key = self._cert.gen_user_key() if not await self._cert.update_user_key_async(key=user_key): - raise MIoTClientError('update_user_key_async failed') + raise MIoTClientError("update_user_key_async failed") csr_str = self._cert.gen_user_csr( - user_key=user_key, did=self._entry_data['virtual_did']) + user_key=user_key, did=self._entry_data["virtual_did"] + ) crt_str = await self.miot_http.get_central_cert_async(csr_str) if not await self._cert.update_user_cert_async(cert=crt_str): - raise MIoTClientError('update user cert error') - _LOGGER.info('update_user_cert_async, %s', crt_str) + raise MIoTClientError("update user cert error") + _LOGGER.info("update_user_cert_async, %s", crt_str) # Create cert update task refresh_time = ( - await self._cert.user_cert_remaining_time_async() - - MIHOME_CERT_EXPIRE_MARGIN) + await self._cert.user_cert_remaining_time_async() + - MIHOME_CERT_EXPIRE_MARGIN + ) if refresh_time <= 0: - raise MIoTClientError('invalid refresh time') - self.__show_client_error_notify(None, 'user_cert') + raise MIoTClientError("invalid refresh time") + self.__show_client_error_notify(None, "user_cert") self.__request_refresh_user_cert(refresh_time) _LOGGER.debug( - 'refresh user cert (%s, %s) after %ds', - self._uid, self._cloud_server, refresh_time) + "refresh user cert (%s, %s) after %ds", + self._uid, + self._cloud_server, + refresh_time, + ) return True except MIoTClientError as error: self.__show_client_error_notify( - message=self._i18n.translate( - 'miot.client.invalid_cert_info'), # type: ignore - notify_key='user_cert') + message=self._i18n.translate("miot.client.invalid_cert_info"), # type: ignore + notify_key="user_cert", + ) _LOGGER.error( - 'refresh user cert error, %s, %s', - error, traceback.format_exc()) + "refresh user cert error, %s, %s", error, traceback.format_exc() + ) return False - async def set_prop_async( - self, did: str, siid: int, piid: int, value: Any - ) -> bool: + async def set_prop_async(self, did: str, siid: int, piid: int, value: Any) -> bool: if did not in self._device_list_cache: - raise MIoTClientError(f'did not exist, {did}') + raise MIoTClientError(f"did not exist, {did}") # Priority local control if self._ctrl_mode == CtrlMode.AUTO: # Gateway control device_gw = self._device_list_gateway.get(did, None) if ( - device_gw and device_gw.get('online', False) - and device_gw.get('specv2_access', False) - and 'group_id' in device_gw + device_gw + and device_gw.get("online", False) + and device_gw.get("specv2_access", False) + and "group_id" in device_gw ): - mips = self._mips_local.get(device_gw['group_id'], None) + mips = self._mips_local.get(device_gw["group_id"], None) if mips is None: - _LOGGER.error( - 'no gw route, %s, try control throw cloud', - device_gw) + _LOGGER.error("no gw route, %s, try control throw cloud", device_gw) else: result = await mips.set_prop_async( - did=did, siid=siid, piid=piid, value=value) + did=did, siid=siid, piid=piid, value=value + ) rc = (result or {}).get( - 'code', MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value) + "code", MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value + ) if rc in [0, 1]: return True - raise MIoTClientError( - self.__get_exec_error_with_rc(rc=rc)) + raise MIoTClientError(self.__get_exec_error_with_rc(rc=rc)) # Lan control device_lan = self._device_list_lan.get(did, None) - if device_lan and device_lan.get('online', False): + if device_lan and device_lan.get("online", False): result = await self._miot_lan.set_prop_async( - did=did, siid=siid, piid=piid, value=value) + did=did, siid=siid, piid=piid, value=value + ) _LOGGER.debug( - 'lan set prop, %s.%d.%d, %s -> %s', - did, siid, piid, value, result) + "lan set prop, %s.%d.%d, %s -> %s", did, siid, piid, value, result + ) rc = (result or {}).get( - 'code', MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value) + "code", MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value + ) if rc in [0, 1]: return True - raise MIoTClientError( - self.__get_exec_error_with_rc(rc=rc)) + raise MIoTClientError(self.__get_exec_error_with_rc(rc=rc)) # Cloud control device_cloud = self._device_list_cloud.get(did, None) - if device_cloud and device_cloud.get('online', False): + if device_cloud and device_cloud.get("online", False): result = await self._http.set_prop_async( - params=[ - {'did': did, 'siid': siid, 'piid': piid, 'value': value} - ]) + params=[{"did": did, "siid": siid, "piid": piid, "value": value}] + ) _LOGGER.debug( - 'set prop response, %s.%d.%d, %s, result, %s', - did, siid, piid, value, result) + "set prop response, %s.%d.%d, %s, result, %s", + did, + siid, + piid, + value, + result, + ) if result and len(result) == 1: - rc = result[0].get( - 'code', MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value) + rc = result[0].get("code", MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value) if rc in [0, 1]: return True if rc in [-704010000, -704042011]: # Device remove or offline - _LOGGER.error('device may be removed or offline, %s', did) + _LOGGER.error("device may be removed or offline, %s", did) self._main_loop.create_task( - await self.__refresh_cloud_device_with_dids_async( - dids=[did])) - raise MIoTClientError( - self.__get_exec_error_with_rc(rc=rc)) + await self.__refresh_cloud_device_with_dids_async(dids=[did]) + ) + raise MIoTClientError(self.__get_exec_error_with_rc(rc=rc)) # Show error message raise MIoTClientError( - f'{self._i18n.translate("miot.client.device_exec_error")}, ' - f'{self._i18n.translate("error.common.-10007")}') + f"{self._i18n.translate('miot.client.device_exec_error')}, " + f"{self._i18n.translate('error.common.-10007')}" + ) - def request_refresh_prop( - self, did: str, siid: int, piid: int - ) -> None: + async def set_props_async( + self, + props_list: List[Dict[str, Any]], + ) -> bool: + did_set = {prop["did"] for prop in props_list} + if len(did_set) > 1: + raise MIoTClientError(f"more than one did once, {did_set}") + did = did_set.pop() if did not in self._device_list_cache: - raise MIoTClientError(f'did not exist, {did}') - key: str = f'{did}|{siid}|{piid}' + raise MIoTClientError(f"did not exist, {did}") + # Priority local control + if self._ctrl_mode == CtrlMode.AUTO: + # Gateway control + device_gw = self._device_list_gateway.get(did, None) + if ( + device_gw + and device_gw.get("online", False) + and device_gw.get("specv2_access", False) + and "group_id" in device_gw + ): + mips = self._mips_local.get(device_gw["group_id"], None) + if mips is None: + _LOGGER.error("no gw route, %s, try control throw cloud", device_gw) + else: + result = await mips.set_props_async(did=did, props_list=props_list) + rc = { + (r or {}).get( + "code", MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value + ) + for r in result + } + if all(t in [0, 1] for t in rc): + return True + else: + raise MIoTClientError( + self.__get_exec_error_with_rc(rc=(rc - {0, 1})[0]) + ) + # Lan control + device_lan = self._device_list_lan.get(did, None) + if device_lan and device_lan.get("online", False): + result = await self._miot_lan.set_props_async( + did=did, props_list=props_list + ) + _LOGGER.debug("lan set prop, %s -> %s", props_list, result) + rc = { + (r or {}).get("code", MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value) + for r in result + } + if all(t in [0, 1] for t in rc): + return True + else: + raise MIoTClientError( + self.__get_exec_error_with_rc(rc=(rc - {0, 1})[0]) + ) + # Cloud control + device_cloud = self._device_list_cloud.get(did, None) + if device_cloud and device_cloud.get("online", False): + result = await self._http.set_props_async(params=props_list) + _LOGGER.debug( + "set prop response, %s, result, %s", + props_list, + result, + ) + if result and len(result) == len(props_list): + rc = { + (r or {}).get("code", MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value) + for r in result + } + if all(t in [0, 1] for t in rc): + return True + + if any(t in [-704010000, -704042011] for t in rc): + # Device remove or offline + _LOGGER.error("device may be removed or offline, %s", did) + self._main_loop.create_task( + await self.__refresh_cloud_device_with_dids_async(dids=[did]) + ) + else: + raise MIoTClientError( + self.__get_exec_error_with_rc(rc=(rc - {0, 1})[0]) + ) + + # Show error message + raise MIoTClientError( + f"{self._i18n.translate('miot.client.device_exec_error')}, " + f"{self._i18n.translate('error.common.-10007')}" + ) + + def request_refresh_prop(self, did: str, siid: int, piid: int) -> None: + if did not in self._device_list_cache: + raise MIoTClientError(f"did not exist, {did}") + key: str = f"{did}|{siid}|{piid}" if key in self._refresh_props_list: return - self._refresh_props_list[key] = { - 'did': did, 'siid': siid, 'piid': piid} + self._refresh_props_list[key] = {"did": did, "siid": siid, "piid": piid} if self._refresh_props_timer: return self._refresh_props_timer = self._main_loop.call_later( - 0.2, lambda: self._main_loop.create_task( - self.__refresh_props_handler())) + 0.2, lambda: self._main_loop.create_task(self.__refresh_props_handler()) + ) async def get_prop_async(self, did: str, siid: int, piid: int) -> Any: if did not in self._device_list_cache: - raise MIoTClientError(f'did not exist, {did}') + raise MIoTClientError(f"did not exist, {did}") # NOTICE: Since there are too many request attributes and obtaining # them directly from the hub or device will cause device abnormalities, # so obtaining the cache from the cloud is the priority here. try: if self._network.network_status: - result = await self._http.get_prop_async( - did=did, siid=siid, piid=piid) + result = await self._http.get_prop_async(did=did, siid=siid, piid=piid) if result: return result except Exception as err: # pylint: disable=broad-exception-caught # Catch all exceptions _LOGGER.error( - 'client get prop from cloud error, %s, %s', - err, traceback.format_exc()) + "client get prop from cloud error, %s, %s", err, traceback.format_exc() + ) if self._ctrl_mode == CtrlMode.AUTO: # Central hub gateway device_gw = self._device_list_gateway.get(did, None) if ( - device_gw and device_gw.get('online', False) - and device_gw.get('specv2_access', False) - and 'group_id' in device_gw + device_gw + and device_gw.get("online", False) + and device_gw.get("specv2_access", False) + and "group_id" in device_gw ): - mips = self._mips_local.get(device_gw['group_id'], None) + mips = self._mips_local.get(device_gw["group_id"], None) if mips is None: - _LOGGER.error('no gw route, %s', device_gw) + _LOGGER.error("no gw route, %s", device_gw) else: - return await mips.get_prop_async( - did=did, siid=siid, piid=piid) + return await mips.get_prop_async(did=did, siid=siid, piid=piid) # Lan device_lan = self._device_list_lan.get(did, None) - if device_lan and device_lan.get('online', False): + if device_lan and device_lan.get("online", False): return await self._miot_lan.get_prop_async( - did=did, siid=siid, piid=piid) + did=did, siid=siid, piid=piid + ) # _LOGGER.error( # 'client get prop failed, no-link, %s.%d.%d', did, siid, piid) return None - async def action_async( - self, did: str, siid: int, aiid: int, in_list: list - ) -> list: + async def action_async(self, did: str, siid: int, aiid: int, in_list: list) -> list: if did not in self._device_list_cache: - raise MIoTClientError(f'did not exist, {did}') + raise MIoTClientError(f"did not exist, {did}") device_gw = self._device_list_gateway.get(did, None) # Priority local control if self._ctrl_mode == CtrlMode.AUTO: if ( - device_gw and device_gw.get('online', False) - and device_gw.get('specv2_access', False) - and 'group_id' in device_gw + device_gw + and device_gw.get("online", False) + and device_gw.get("specv2_access", False) + and "group_id" in device_gw ): - mips = self._mips_local.get( - device_gw['group_id'], None) + mips = self._mips_local.get(device_gw["group_id"], None) if mips is None: - _LOGGER.error('no gw route, %s', device_gw) + _LOGGER.error("no gw route, %s", device_gw) else: result = await mips.action_async( - did=did, siid=siid, aiid=aiid, in_list=in_list) + did=did, siid=siid, aiid=aiid, in_list=in_list + ) rc = (result or {}).get( - 'code', MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value) + "code", MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value + ) if rc in [0, 1]: - return result.get('out', []) - raise MIoTClientError( - self.__get_exec_error_with_rc(rc=rc)) + return result.get("out", []) + raise MIoTClientError(self.__get_exec_error_with_rc(rc=rc)) # Lan control device_lan = self._device_list_lan.get(did, None) - if device_lan and device_lan.get('online', False): + if device_lan and device_lan.get("online", False): result = await self._miot_lan.action_async( - did=did, siid=siid, aiid=aiid, in_list=in_list) - _LOGGER.debug( - 'lan action, %s, %s, %s -> %s', did, siid, aiid, result) + did=did, siid=siid, aiid=aiid, in_list=in_list + ) + _LOGGER.debug("lan action, %s, %s, %s -> %s", did, siid, aiid, result) rc = (result or {}).get( - 'code', MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value) + "code", MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value + ) if rc in [0, 1]: - return result.get('out', []) - raise MIoTClientError( - self.__get_exec_error_with_rc(rc=rc)) + return result.get("out", []) + raise MIoTClientError(self.__get_exec_error_with_rc(rc=rc)) # Cloud control device_cloud = self._device_list_cloud.get(did, None) - if device_cloud and device_cloud.get('online', False): + if device_cloud and device_cloud.get("online", False): result: dict = await self._http.action_async( - did=did, siid=siid, aiid=aiid, in_list=in_list) + did=did, siid=siid, aiid=aiid, in_list=in_list + ) if result: - rc = result.get( - 'code', MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value) + rc = result.get("code", MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value) if rc in [0, 1]: - return result.get('out', []) + return result.get("out", []) if rc in [-704010000, -704042011]: # Device remove or offline - _LOGGER.error('device removed or offline, %s', did) + _LOGGER.error("device removed or offline, %s", did) self._main_loop.create_task( - await self.__refresh_cloud_device_with_dids_async( - dids=[did])) - raise MIoTClientError( - self.__get_exec_error_with_rc(rc=rc)) + await self.__refresh_cloud_device_with_dids_async(dids=[did]) + ) + raise MIoTClientError(self.__get_exec_error_with_rc(rc=rc)) # TODO: Show error message - _LOGGER.error( - 'client action failed, %s.%d.%d', did, siid, aiid) + _LOGGER.error("client action failed, %s.%d.%d", did, siid, aiid) return [] def sub_prop( - self, did: str, handler: Callable[[dict, Any], None], - siid: Optional[int] = None, piid: Optional[int] = None, - handler_ctx: Any = None + self, + did: str, + handler: Callable[[dict, Any], None], + siid: Optional[int] = None, + piid: Optional[int] = None, + handler_ctx: Any = None, ) -> bool: if did not in self._device_list_cache: - raise MIoTClientError(f'did not exist, {did}') + raise MIoTClientError(f"did not exist, {did}") - topic = ( - f'{did}/p/' - f'{"#" if siid is None or piid is None else f"{siid}/{piid}"}') + topic = f"{did}/p/{'#' if siid is None or piid is None else f'{siid}/{piid}'}" self._sub_tree[topic] = MIoTClientSub( - topic=topic, handler=handler, handler_ctx=handler_ctx) - _LOGGER.debug('client sub prop, %s', topic) + topic=topic, handler=handler, handler_ctx=handler_ctx + ) + _LOGGER.debug("client sub prop, %s", topic) return True def unsub_prop( self, did: str, siid: Optional[int] = None, piid: Optional[int] = None ) -> bool: - topic = ( - f'{did}/p/' - f'{"#" if siid is None or piid is None else f"{siid}/{piid}"}') + topic = f"{did}/p/{'#' if siid is None or piid is None else f'{siid}/{piid}'}" if self._sub_tree.get(topic=topic): del self._sub_tree[topic] - _LOGGER.debug('client unsub prop, %s', topic) + _LOGGER.debug("client unsub prop, %s", topic) return True def sub_event( - self, did: str, handler: Callable[[dict, Any], None], - siid: Optional[int] = None, eiid: Optional[int] = None, - handler_ctx: Any = None + self, + did: str, + handler: Callable[[dict, Any], None], + siid: Optional[int] = None, + eiid: Optional[int] = None, + handler_ctx: Any = None, ) -> bool: if did not in self._device_list_cache: - raise MIoTClientError(f'did not exist, {did}') - topic = ( - f'{did}/e/' - f'{"#" if siid is None or eiid is None else f"{siid}/{eiid}"}') + raise MIoTClientError(f"did not exist, {did}") + topic = f"{did}/e/{'#' if siid is None or eiid is None else f'{siid}/{eiid}'}" self._sub_tree[topic] = MIoTClientSub( - topic=topic, handler=handler, handler_ctx=handler_ctx) - _LOGGER.debug('client sub event, %s', topic) + topic=topic, handler=handler, handler_ctx=handler_ctx + ) + _LOGGER.debug("client sub event, %s", topic) return True def unsub_event( self, did: str, siid: Optional[int] = None, eiid: Optional[int] = None ) -> bool: - topic = ( - f'{did}/e/' - f'{"#" if siid is None or eiid is None else f"{siid}/{eiid}"}') + topic = f"{did}/e/{'#' if siid is None or eiid is None else f'{siid}/{eiid}'}" if self._sub_tree.get(topic=topic): del self._sub_tree[topic] - _LOGGER.debug('client unsub event, %s', topic) + _LOGGER.debug("client unsub event, %s", topic) return True def sub_device_state( - self, did: str, handler: Callable[[str, MIoTDeviceState, Any], None], - handler_ctx: Any = None + self, + did: str, + handler: Callable[[str, MIoTDeviceState, Any], None], + handler_ctx: Any = None, ) -> bool: """Call callback handler in main loop""" if did not in self._device_list_cache: - raise MIoTClientError(f'did not exist, {did}') + raise MIoTClientError(f"did not exist, {did}") self._sub_device_state[did] = MipsDeviceState( - did=did, handler=handler, handler_ctx=handler_ctx) - _LOGGER.debug('client sub device state, %s', did) + did=did, handler=handler, handler_ctx=handler_ctx + ) + _LOGGER.debug("client sub device state, %s", did) return True def unsub_device_state(self, did: str) -> bool: self._sub_device_state.pop(did, None) - _LOGGER.debug('client unsub device state, %s', did) + _LOGGER.debug("client unsub device state, %s", did) return True async def remove_device_async(self, did: str) -> None: @@ -883,9 +1009,10 @@ class MIoTClient: self.__unsub_from(sub_from, did) # Storage await self._storage.save_async( - domain='miot_devices', - name=f'{self._uid}_{self._cloud_server}', - data=self._device_list_cache) + domain="miot_devices", + name=f"{self._uid}_{self._cloud_server}", + data=self._device_list_cache, + ) # Update notify self.__request_show_devices_changed_notify() @@ -897,18 +1024,17 @@ class MIoTClient: break def __get_exec_error_with_rc(self, rc: int) -> str: - err_msg: str = self._i18n.translate( - key=f'error.common.{rc}') # type: ignore + err_msg: str = self._i18n.translate(key=f"error.common.{rc}") # type: ignore if not err_msg: - err_msg = f'{self._i18n.translate(key="error.common.-10000")}, ' - err_msg += f'code={rc}' + err_msg = f"{self._i18n.translate(key='error.common.-10000')}, " + err_msg += f"code={rc}" return ( - f'{self._i18n.translate(key="miot.client.device_exec_error")}, ' - + err_msg) + f"{self._i18n.translate(key='miot.client.device_exec_error')}, " + err_msg + ) @final def __gen_notify_key(self, name: str) -> str: - return f'{DOMAIN}-{self._uid}-{self._cloud_server}-{name}' + return f"{DOMAIN}-{self._uid}-{self._cloud_server}-{name}" @final def __request_refresh_auth_info(self, delay_sec: int) -> None: @@ -916,8 +1042,9 @@ class MIoTClient: self._refresh_token_timer.cancel() self._refresh_token_timer = None self._refresh_token_timer = self._main_loop.call_later( - delay_sec, lambda: self._main_loop.create_task( - self.refresh_oauth_info_async())) + delay_sec, + lambda: self._main_loop.create_task(self.refresh_oauth_info_async()), + ) @final def __request_refresh_user_cert(self, delay_sec: int) -> None: @@ -925,15 +1052,16 @@ class MIoTClient: self._refresh_cert_timer.cancel() self._refresh_cert_timer = None self._refresh_cert_timer = self._main_loop.call_later( - delay_sec, lambda: self._main_loop.create_task( - self.refresh_user_cert_async())) + delay_sec, + lambda: self._main_loop.create_task(self.refresh_user_cert_async()), + ) @final def __unsub_from(self, sub_from: str, did: str) -> None: mips: Any = None - if sub_from == 'cloud': + if sub_from == "cloud": mips = self._mips_cloud - elif sub_from == 'lan': + elif sub_from == "lan": mips = self._miot_lan elif sub_from in self._mips_local: mips = self._mips_local[sub_from] @@ -942,7 +1070,7 @@ class MIoTClient: mips.unsub_prop(did=did) mips.unsub_event(did=did) except RuntimeError as e: - if 'Event loop is closed' in str(e): + if "Event loop is closed" in str(e): # Ignore unsub exception when loop is closed pass else: @@ -951,9 +1079,9 @@ class MIoTClient: @final def __sub_from(self, sub_from: str, did: str) -> None: mips = None - if sub_from == 'cloud': + if sub_from == "cloud": mips = self._mips_cloud - elif sub_from == 'lan': + elif sub_from == "lan": mips = self._miot_lan elif sub_from in self._mips_local: mips = self._mips_local[sub_from] @@ -970,23 +1098,23 @@ class MIoTClient: if self._ctrl_mode == CtrlMode.AUTO: if ( did in self._device_list_gateway - and self._device_list_gateway[did].get('online', False) - and self._device_list_gateway[did].get('push_available', False) + and self._device_list_gateway[did].get("online", False) + and self._device_list_gateway[did].get("push_available", False) ): - from_new = self._device_list_gateway[did]['group_id'] + from_new = self._device_list_gateway[did]["group_id"] elif ( did in self._device_list_lan - and self._device_list_lan[did].get('online', False) - and self._device_list_lan[did].get('push_available', False) + and self._device_list_lan[did].get("online", False) + and self._device_list_lan[did].get("push_available", False) ): - from_new = 'lan' + from_new = "lan" if ( from_new is None and did in self._device_list_cloud - and self._device_list_cloud[did].get('online', False) + and self._device_list_cloud[did].get("online", False) ): - from_new = 'cloud' + from_new = "cloud" if from_new == from_old: # No need to update return @@ -996,12 +1124,11 @@ class MIoTClient: # Sub new self.__sub_from(from_new, did) self._sub_source_list[did] = from_new - _LOGGER.info( - 'device sub changed, %s, from %s to %s', did, from_old, from_new) + _LOGGER.info("device sub changed, %s, from %s to %s", did, from_old, from_new) @final async def __on_network_status_changed(self, status: bool) -> None: - _LOGGER.info('network status changed, %s', status) + _LOGGER.info("network status changed, %s", status) if status: # Check auth_info if await self.refresh_oauth_info_async(): @@ -1023,8 +1150,7 @@ class MIoTClient: async def __on_mips_service_state_change( self, group_id: str, state: MipsServiceState, data: dict ) -> None: - _LOGGER.info( - 'mips service state changed, %s, %s, %s', group_id, state, data) + _LOGGER.info("mips service state changed, %s, %s, %s", group_id, state, data) mips = self._mips_local.get(group_id, None) if mips: @@ -1033,81 +1159,78 @@ class MIoTClient: self._mips_local.pop(group_id, None) return if ( - mips.client_id == self._entry_data['virtual_did'] - and mips.host == data['addresses'][0] - and mips.port == data['port'] + mips.client_id == self._entry_data["virtual_did"] + and mips.host == data["addresses"][0] + and mips.port == data["port"] ): return mips.disconnect() self._mips_local.pop(group_id, None) - home_name: str = '' - for info in list(self._entry_data['home_selected'].values()): - if info.get('group_id', None) == group_id: - home_name = info.get('home_name', '') + home_name: str = "" + for info in list(self._entry_data["home_selected"].values()): + if info.get("group_id", None) == group_id: + home_name = info.get("home_name", "") mips = MipsLocalClient( - did=self._entry_data['virtual_did'], + did=self._entry_data["virtual_did"], group_id=group_id, - host=data['addresses'][0], + host=data["addresses"][0], ca_file=self._cert.ca_file, cert_file=self._cert.cert_file, key_file=self._cert.key_file, - port=data['port'], + port=data["port"], home_name=home_name, - loop=self._main_loop) + loop=self._main_loop, + ) self._mips_local[group_id] = mips mips.enable_logger(logger=_LOGGER) mips.on_dev_list_changed = self.__on_gw_device_list_changed - mips.sub_mips_state( - key=group_id, handler=self.__on_mips_local_state_changed) + mips.sub_mips_state(key=group_id, handler=self.__on_mips_local_state_changed) mips.connect() @final - async def __on_mips_cloud_state_changed( - self, key: str, state: bool - ) -> None: - _LOGGER.info('cloud mips state changed, %s, %s', key, state) + async def __on_mips_cloud_state_changed(self, key: str, state: bool) -> None: + _LOGGER.info("cloud mips state changed, %s, %s", key, state) if state: # Connect self.__request_refresh_cloud_devices(immediately=True) # Sub cloud device state for did in list(self._device_list_cache.keys()): self._mips_cloud.sub_device_state( - did=did, handler=self.__on_cloud_device_state_changed) + did=did, handler=self.__on_cloud_device_state_changed + ) else: # Disconnect for did, info in self._device_list_cloud.items(): - cloud_state_old: Optional[bool] = info.get('online', None) + cloud_state_old: Optional[bool] = info.get("online", None) if not cloud_state_old: # Cloud state is None or False, no need to update continue - info['online'] = False + info["online"] = False if did not in self._device_list_cache: continue self.__update_device_msg_sub(did=did) state_old: Optional[bool] = self._device_list_cache[did].get( - 'online', None) + "online", None + ) state_new: Optional[bool] = self.__check_device_state( False, - self._device_list_gateway.get( - did, {}).get('online', False), - self._device_list_lan.get(did, {}).get('online', False)) + self._device_list_gateway.get(did, {}).get("online", False), + self._device_list_lan.get(did, {}).get("online", False), + ) if state_old == state_new: continue - self._device_list_cache[did]['online'] = state_new + self._device_list_cache[did]["online"] = state_new sub = self._sub_device_state.get(did, None) if sub and sub.handler: sub.handler(did, MIoTDeviceState.OFFLINE, sub.handler_ctx) self.__request_show_devices_changed_notify() @final - async def __on_mips_local_state_changed( - self, group_id: str, state: bool - ) -> None: - _LOGGER.info('local mips state changed, %s, %s', group_id, state) + async def __on_mips_local_state_changed(self, group_id: str, state: bool) -> None: + _LOGGER.info("local mips state changed, %s, %s", group_id, state) mips = self._mips_local.get(group_id, None) if not mips: - _LOGGER.error( - 'local mips state changed, mips not exist, %s', group_id) + _LOGGER.error("local mips state changed, mips not exist, %s", group_id) return if state: # Connected @@ -1115,28 +1238,30 @@ class MIoTClient: else: # Disconnect for did, info in self._device_list_gateway.items(): - if info.get('group_id', None) != group_id: + if info.get("group_id", None) != group_id: # Not belong to this gateway continue - if not info.get('online', False): + if not info.get("online", False): # Device offline, no need to update continue # Update local device info - info['online'] = False - info['push_available'] = False + info["online"] = False + info["push_available"] = False if did not in self._device_list_cache: # Device not exist continue self.__update_device_msg_sub(did=did) - state_old: Optional[bool] = self._device_list_cache.get( - did, {}).get('online', None) + state_old: Optional[bool] = self._device_list_cache.get(did, {}).get( + "online", None + ) state_new: Optional[bool] = self.__check_device_state( - self._device_list_cloud.get(did, {}).get('online', None), + self._device_list_cloud.get(did, {}).get("online", None), False, - self._device_list_lan.get(did, {}).get('online', False)) + self._device_list_lan.get(did, {}).get("online", False), + ) if state_old == state_new: continue - self._device_list_cache[did]['online'] = state_new + self._device_list_cache[did]["online"] = state_new sub = self._sub_device_state.get(did, None) if sub and sub.handler: sub.handler(did, MIoTDeviceState.OFFLINE, sub.handler_ctx) @@ -1145,45 +1270,49 @@ class MIoTClient: @final async def __on_miot_lan_state_change(self, state: bool) -> None: _LOGGER.info( - 'miot lan state changed, %s, %s, %s', - self._uid, self._cloud_server, state) + "miot lan state changed, %s, %s, %s", self._uid, self._cloud_server, state + ) if state: # Update device self._miot_lan.sub_device_state( - key=f'{self._uid}-{self._cloud_server}', - handler=self.__on_lan_device_state_changed) - for did, info in ( - await self._miot_lan.get_dev_list_async()).items(): - await self.__on_lan_device_state_changed( - did=did, state=info, ctx=None) - _LOGGER.info('lan device list, %s', self._device_list_lan) - self._miot_lan.update_devices(devices={ - did: { - 'token': info['token'], - 'model': info['model'], - 'connect_type': info['connect_type']} - for did, info in self._device_list_cache.items() - if 'token' in info and 'connect_type' in info - and info['connect_type'] in [0, 8, 12, 23] - }) + key=f"{self._uid}-{self._cloud_server}", + handler=self.__on_lan_device_state_changed, + ) + for did, info in (await self._miot_lan.get_dev_list_async()).items(): + await self.__on_lan_device_state_changed(did=did, state=info, ctx=None) + _LOGGER.info("lan device list, %s", self._device_list_lan) + self._miot_lan.update_devices( + devices={ + did: { + "token": info["token"], + "model": info["model"], + "connect_type": info["connect_type"], + } + for did, info in self._device_list_cache.items() + if "token" in info + and "connect_type" in info + and info["connect_type"] in [0, 8, 12, 23] + } + ) else: for did, info in self._device_list_lan.items(): - if not info.get('online', False): + if not info.get("online", False): continue # Update local device info - info['online'] = False - info['push_available'] = False + info["online"] = False + info["push_available"] = False self.__update_device_msg_sub(did=did) - state_old: Optional[bool] = self._device_list_cache.get( - did, {}).get('online', None) + state_old: Optional[bool] = self._device_list_cache.get(did, {}).get( + "online", None + ) state_new: Optional[bool] = self.__check_device_state( - self._device_list_cloud.get(did, {}).get('online', None), - self._device_list_gateway.get( - did, {}).get('online', False), - False) + self._device_list_cloud.get(did, {}).get("online", None), + self._device_list_gateway.get(did, {}).get("online", False), + False, + ) if state_old == state_new: continue - self._device_list_cache[did]['online'] = state_new + self._device_list_cache[did]["online"] = state_new sub = self._sub_device_state.get(did, None) if sub and sub.handler: sub.handler(did, MIoTDeviceState.OFFLINE, sub.handler_ctx) @@ -1194,87 +1323,92 @@ class MIoTClient: def __on_cloud_device_state_changed( self, did: str, state: MIoTDeviceState, ctx: Any ) -> None: - _LOGGER.info('cloud device state changed, %s, %s', did, state) + _LOGGER.info("cloud device state changed, %s, %s", did, state) cloud_device = self._device_list_cloud.get(did, None) if not cloud_device: return cloud_state_new: bool = state == MIoTDeviceState.ONLINE - if cloud_device.get('online', False) == cloud_state_new: + if cloud_device.get("online", False) == cloud_state_new: return - cloud_device['online'] = cloud_state_new + cloud_device["online"] = cloud_state_new if did not in self._device_list_cache: return self.__update_device_msg_sub(did=did) - state_old: Optional[bool] = self._device_list_cache[did].get( - 'online', None) + state_old: Optional[bool] = self._device_list_cache[did].get("online", None) state_new: Optional[bool] = self.__check_device_state( cloud_state_new, - self._device_list_gateway.get(did, {}).get('online', False), - self._device_list_lan.get(did, {}).get('online', False)) + self._device_list_gateway.get(did, {}).get("online", False), + self._device_list_lan.get(did, {}).get("online", False), + ) if state_old == state_new: return - self._device_list_cache[did]['online'] = state_new + self._device_list_cache[did]["online"] = state_new sub = self._sub_device_state.get(did, None) if sub and sub.handler: sub.handler( - did, MIoTDeviceState.ONLINE if state_new - else MIoTDeviceState.OFFLINE, sub.handler_ctx) + did, + MIoTDeviceState.ONLINE if state_new else MIoTDeviceState.OFFLINE, + sub.handler_ctx, + ) self.__request_show_devices_changed_notify() @final async def __on_gw_device_list_changed( self, mips: MipsLocalClient, did_list: list[str] ) -> None: - _LOGGER.info( - 'gateway devices list changed, %s, %s', mips.group_id, did_list) - payload: dict = {'filter': {'did': did_list}} - gw_list = await mips.get_dev_list_async( - payload=json.dumps(payload)) + _LOGGER.info("gateway devices list changed, %s, %s", mips.group_id, did_list) + payload: dict = {"filter": {"did": did_list}} + gw_list = await mips.get_dev_list_async(payload=json.dumps(payload)) if gw_list is None: - _LOGGER.error('local mips get_dev_list_async failed, %s', did_list) + _LOGGER.error("local mips get_dev_list_async failed, %s", did_list) return await self.__update_devices_from_gw_async( - gw_list=gw_list, group_id=mips.group_id, filter_dids=[ - did for did in did_list - if self._device_list_gateway.get(did, {}).get( - 'group_id', None) == mips.group_id]) + gw_list=gw_list, + group_id=mips.group_id, + filter_dids=[ + did + for did in did_list + if self._device_list_gateway.get(did, {}).get("group_id", None) + == mips.group_id + ], + ) self.__request_show_devices_changed_notify() @final async def __on_lan_device_state_changed( self, did: str, state: dict, ctx: Any ) -> None: - _LOGGER.info('lan device state changed, %s, %s', did, state) - lan_state_new: bool = state.get('online', False) - lan_sub_new: bool = state.get('push_available', False) + _LOGGER.info("lan device state changed, %s, %s", did, state) + lan_state_new: bool = state.get("online", False) + lan_sub_new: bool = state.get("push_available", False) self._device_list_lan.setdefault(did, {}) - if ( - lan_state_new == self._device_list_lan[did].get('online', False) - and lan_sub_new == self._device_list_lan[did].get( - 'push_available', False) - ): + if lan_state_new == self._device_list_lan[did].get( + "online", False + ) and lan_sub_new == self._device_list_lan[did].get("push_available", False): return - self._device_list_lan[did]['online'] = lan_state_new - self._device_list_lan[did]['push_available'] = lan_sub_new + self._device_list_lan[did]["online"] = lan_state_new + self._device_list_lan[did]["push_available"] = lan_sub_new if did not in self._device_list_cache: return self.__update_device_msg_sub(did=did) - if lan_state_new == self._device_list_cache[did].get('online', False): + if lan_state_new == self._device_list_cache[did].get("online", False): return - state_old: Optional[bool] = self._device_list_cache[did].get( - 'online', None) + state_old: Optional[bool] = self._device_list_cache[did].get("online", None) state_new: Optional[bool] = self.__check_device_state( - self._device_list_cloud.get(did, {}).get('online', None), - self._device_list_gateway.get(did, {}).get('online', False), - lan_state_new) + self._device_list_cloud.get(did, {}).get("online", None), + self._device_list_gateway.get(did, {}).get("online", False), + lan_state_new, + ) if state_old == state_new: return - self._device_list_cache[did]['online'] = state_new + self._device_list_cache[did]["online"] = state_new sub = self._sub_device_state.get(did, None) if sub and sub.handler: sub.handler( - did, MIoTDeviceState.ONLINE if state_new - else MIoTDeviceState.OFFLINE, sub.handler_ctx) + did, + MIoTDeviceState.ONLINE if state_new else MIoTDeviceState.OFFLINE, + sub.handler_ctx, + ) self.__request_show_devices_changed_notify() @final @@ -1282,22 +1416,28 @@ class MIoTClient: """params MUST contain did, siid, piid, value""" # BLE device has no online/offline msg try: - subs: list[MIoTClientSub] = list(self._sub_tree.iter_match( - f'{params["did"]}/p/{params["siid"]}/{params["piid"]}')) + subs: list[MIoTClientSub] = list( + self._sub_tree.iter_match( + f"{params['did']}/p/{params['siid']}/{params['piid']}" + ) + ) for sub in subs: sub.handler(params, sub.handler_ctx) except Exception as err: # pylint: disable=broad-exception-caught - _LOGGER.error('on prop msg error, %s, %s', params, err) + _LOGGER.error("on prop msg error, %s, %s", params, err) @final def __on_event_msg(self, params: dict, ctx: Any) -> None: try: - subs: list[MIoTClientSub] = list(self._sub_tree.iter_match( - f'{params["did"]}/e/{params["siid"]}/{params["eiid"]}')) + subs: list[MIoTClientSub] = list( + self._sub_tree.iter_match( + f"{params['did']}/e/{params['siid']}/{params['eiid']}" + ) + ) for sub in subs: sub.handler(params, sub.handler_ctx) except Exception as err: # pylint: disable=broad-exception-caught - _LOGGER.error('on event msg error, %s, %s', params, err) + _LOGGER.error("on event msg error, %s, %s", params, err) @final def __check_device_state( @@ -1314,39 +1454,38 @@ class MIoTClient: async def __load_cache_device_async(self) -> None: """Load device list from cache.""" cache_list: Optional[dict[str, dict]] = await self._storage.load_async( - domain='miot_devices', name=f'{self._uid}_{self._cloud_server}', - type_=dict) # type: ignore + domain="miot_devices", name=f"{self._uid}_{self._cloud_server}", type_=dict + ) # type: ignore if not cache_list: self.__show_client_error_notify( - message=self._i18n.translate( - 'miot.client.invalid_device_cache'), # type: ignore - notify_key='device_cache') - raise MIoTClientError('load device list from cache error') + message=self._i18n.translate("miot.client.invalid_device_cache"), # type: ignore + notify_key="device_cache", + ) + raise MIoTClientError("load device list from cache error") else: - self.__show_client_error_notify( - message=None, notify_key='device_cache') + self.__show_client_error_notify(message=None, notify_key="device_cache") # Set default online status = False self._device_list_cache = {} for did, info in cache_list.items(): - if info.get('online', None): - self._device_list_cache[did] = { - **info, 'online': False} + if info.get("online", None): + self._device_list_cache[did] = {**info, "online": False} else: self._device_list_cache[did] = info self._device_list_cloud = deepcopy(self._device_list_cache) self._device_list_gateway = { did: { - 'did': did, - 'name': info.get('name', None), - 'group_id': info.get('group_id', None), - 'online': False, - 'push_available': False} - for did, info in self._device_list_cache.items()} + "did": did, + "name": info.get("name", None), + "group_id": info.get("group_id", None), + "online": False, + "push_available": False, + } + for did, info in self._device_list_cache.items() + } @final async def __update_devices_from_cloud_async( - self, cloud_list: dict[str, dict], - filter_dids: Optional[list[str]] = None + self, cloud_list: dict[str, dict], filter_dids: Optional[list[str]] = None ) -> None: """Update cloud devices. NOTICE: This function will operate the cloud_list @@ -1354,21 +1493,21 @@ class MIoTClient: for did, info in self._device_list_cache.items(): if filter_dids and did not in filter_dids: continue - state_old: Optional[bool] = info.get('online', None) - cloud_state_old: Optional[bool] = self._device_list_cloud.get( - did, {}).get('online', None) + state_old: Optional[bool] = info.get("online", None) + cloud_state_old: Optional[bool] = self._device_list_cloud.get(did, {}).get( + "online", None + ) cloud_state_new: Optional[bool] = None device_new = cloud_list.pop(did, None) if device_new: - cloud_state_new = device_new.get('online', None) + cloud_state_new = device_new.get("online", None) # Update cache device info - info.update( - {**device_new, 'online': state_old}) + info.update({**device_new, "online": state_old}) # Update cloud device self._device_list_cloud[did] = device_new else: # Device deleted - self._device_list_cloud[did]['online'] = None + self._device_list_cloud[did]["online"] = None if cloud_state_old == cloud_state_new: # Cloud online status no change continue @@ -1376,194 +1515,202 @@ class MIoTClient: self.__update_device_msg_sub(did=did) state_new: Optional[bool] = self.__check_device_state( cloud_state_new, - self._device_list_gateway.get(did, {}).get('online', False), - self._device_list_lan.get(did, {}).get('online', False)) + self._device_list_gateway.get(did, {}).get("online", False), + self._device_list_lan.get(did, {}).get("online", False), + ) if state_old == state_new: # Online status no change continue - info['online'] = state_new + info["online"] = state_new # Call device state changed callback sub = self._sub_device_state.get(did, None) if sub and sub.handler: sub.handler( - did, MIoTDeviceState.ONLINE if state_new - else MIoTDeviceState.OFFLINE, sub.handler_ctx) + did, + MIoTDeviceState.ONLINE if state_new else MIoTDeviceState.OFFLINE, + sub.handler_ctx, + ) # New devices self._device_list_cloud.update(cloud_list) # Update storage if not await self._storage.save_async( - domain='miot_devices', - name=f'{self._uid}_{self._cloud_server}', - data=self._device_list_cache + domain="miot_devices", + name=f"{self._uid}_{self._cloud_server}", + data=self._device_list_cache, ): - _LOGGER.error('save device list to cache failed') + _LOGGER.error("save device list to cache failed") @final async def __refresh_cloud_devices_async(self) -> None: - _LOGGER.debug( - 'refresh cloud devices, %s, %s', self._uid, self._cloud_server) + _LOGGER.debug("refresh cloud devices, %s, %s", self._uid, self._cloud_server) self._refresh_cloud_devices_timer = None result = await self._http.get_devices_async( - home_ids=list(self._entry_data.get('home_selected', {}).keys())) - if not result and 'devices' not in result: + home_ids=list(self._entry_data.get("home_selected", {}).keys()) + ) + if not result and "devices" not in result: self.__show_client_error_notify( - message=self._i18n.translate( - 'miot.client.device_cloud_error'), # type: ignore - notify_key='device_cloud') + message=self._i18n.translate("miot.client.device_cloud_error"), # type: ignore + notify_key="device_cloud", + ) return else: - self.__show_client_error_notify( - message=None, notify_key='device_cloud') - cloud_list: dict[str, dict] = result['devices'] + self.__show_client_error_notify(message=None, notify_key="device_cloud") + cloud_list: dict[str, dict] = result["devices"] await self.__update_devices_from_cloud_async(cloud_list=cloud_list) # Update lan device - if ( - self._ctrl_mode == CtrlMode.AUTO - and self._miot_lan.init_done - ): - self._miot_lan.update_devices(devices={ - did: { - 'token': info['token'], - 'model': info['model'], - 'connect_type': info['connect_type']} - for did, info in self._device_list_cache.items() - if 'token' in info and 'connect_type' in info - and info['connect_type'] in [0, 8, 12, 23] - }) + if self._ctrl_mode == CtrlMode.AUTO and self._miot_lan.init_done: + self._miot_lan.update_devices( + devices={ + did: { + "token": info["token"], + "model": info["model"], + "connect_type": info["connect_type"], + } + for did, info in self._device_list_cache.items() + if "token" in info + and "connect_type" in info + and info["connect_type"] in [0, 8, 12, 23] + } + ) self.__request_show_devices_changed_notify() @final - async def __refresh_cloud_device_with_dids_async( - self, dids: list[str] - ) -> None: - _LOGGER.debug('refresh cloud device with dids, %s', dids) + async def __refresh_cloud_device_with_dids_async(self, dids: list[str]) -> None: + _LOGGER.debug("refresh cloud device with dids, %s", dids) cloud_list = await self._http.get_devices_with_dids_async(dids=dids) if cloud_list is None: - _LOGGER.error('cloud http get_dev_list_async failed, %s', dids) + _LOGGER.error("cloud http get_dev_list_async failed, %s", dids) return await self.__update_devices_from_cloud_async( - cloud_list=cloud_list, filter_dids=dids) + cloud_list=cloud_list, filter_dids=dids + ) self.__request_show_devices_changed_notify() def __request_refresh_cloud_devices(self, immediately=False) -> None: _LOGGER.debug( - 'request refresh cloud devices, %s, %s', - self._uid, self._cloud_server) + "request refresh cloud devices, %s, %s", self._uid, self._cloud_server + ) if immediately: if self._refresh_cloud_devices_timer: self._refresh_cloud_devices_timer.cancel() self._refresh_cloud_devices_timer = self._main_loop.call_later( - 0, lambda: self._main_loop.create_task( - self.__refresh_cloud_devices_async())) + 0, + lambda: self._main_loop.create_task( + self.__refresh_cloud_devices_async() + ), + ) return if self._refresh_cloud_devices_timer: return self._refresh_cloud_devices_timer = self._main_loop.call_later( - 6, lambda: self._main_loop.create_task( - self.__refresh_cloud_devices_async())) + 6, lambda: self._main_loop.create_task(self.__refresh_cloud_devices_async()) + ) @final async def __update_devices_from_gw_async( - self, gw_list: dict[str, dict], + self, + gw_list: dict[str, dict], group_id: Optional[str] = None, - filter_dids: Optional[list[str]] = None + filter_dids: Optional[list[str]] = None, ) -> None: """Update cloud devices. NOTICE: This function will operate the gw_list""" - _LOGGER.debug('update gw devices, %s, %s', group_id, filter_dids) + _LOGGER.debug("update gw devices, %s, %s", group_id, filter_dids) if not gw_list and not filter_dids: return for did, info in self._device_list_cache.items(): if did not in filter_dids: continue device_old = self._device_list_gateway.get(did, None) - gw_state_old = device_old.get( - 'online', False) if device_old else False + gw_state_old = device_old.get("online", False) if device_old else False gw_state_new: bool = False device_new = gw_list.pop(did, None) if device_new: # Update gateway device info - self._device_list_gateway[did] = { - **device_new, 'group_id': group_id} - gw_state_new = device_new.get('online', False) + self._device_list_gateway[did] = {**device_new, "group_id": group_id} + gw_state_new = device_new.get("online", False) else: # Device offline if device_old: - device_old['online'] = False + device_old["online"] = False # Update cache group_id - info['group_id'] = group_id + info["group_id"] = group_id if gw_state_old == gw_state_new: continue self.__update_device_msg_sub(did=did) - state_old: Optional[bool] = info.get('online', None) + state_old: Optional[bool] = info.get("online", None) state_new: Optional[bool] = self.__check_device_state( - self._device_list_cloud.get(did, {}).get('online', None), + self._device_list_cloud.get(did, {}).get("online", None), gw_state_new, - self._device_list_lan.get(did, {}).get('online', False)) + self._device_list_lan.get(did, {}).get("online", False), + ) if state_old == state_new: continue - info['online'] = state_new + info["online"] = state_new sub = self._sub_device_state.get(did, None) if sub and sub.handler: sub.handler( - did, MIoTDeviceState.ONLINE if state_new - else MIoTDeviceState.OFFLINE, sub.handler_ctx) + did, + MIoTDeviceState.ONLINE if state_new else MIoTDeviceState.OFFLINE, + sub.handler_ctx, + ) # New devices or device home info changed for did, info in gw_list.items(): - self._device_list_gateway[did] = {**info, 'group_id': group_id} + self._device_list_gateway[did] = {**info, "group_id": group_id} if did not in self._device_list_cache: continue - group_id_old: str = self._device_list_cache[did].get( - 'group_id', None) - self._device_list_cache[did]['group_id'] = group_id - _LOGGER.info( - 'move device %s from %s to %s', did, group_id_old, group_id) + group_id_old: str = self._device_list_cache[did].get("group_id", None) + self._device_list_cache[did]["group_id"] = group_id + _LOGGER.info("move device %s from %s to %s", did, group_id_old, group_id) self.__update_device_msg_sub(did=did) - state_old: Optional[bool] = self._device_list_cache[did].get( - 'online', None) + state_old: Optional[bool] = self._device_list_cache[did].get("online", None) state_new: Optional[bool] = self.__check_device_state( - self._device_list_cloud.get(did, {}).get('online', None), - info.get('online', False), - self._device_list_lan.get(did, {}).get('online', False)) + self._device_list_cloud.get(did, {}).get("online", None), + info.get("online", False), + self._device_list_lan.get(did, {}).get("online", False), + ) if state_old == state_new: continue - self._device_list_cache[did]['online'] = state_new + self._device_list_cache[did]["online"] = state_new sub = self._sub_device_state.get(did, None) if sub and sub.handler: sub.handler( - did, MIoTDeviceState.ONLINE if state_new - else MIoTDeviceState.OFFLINE, sub.handler_ctx) + did, + MIoTDeviceState.ONLINE if state_new else MIoTDeviceState.OFFLINE, + sub.handler_ctx, + ) @final - async def __refresh_gw_devices_with_group_id_async( - self, group_id: str - ) -> None: + async def __refresh_gw_devices_with_group_id_async(self, group_id: str) -> None: """Refresh gateway devices by group_id""" - _LOGGER.debug( - 'refresh gw devices with group_id, %s', group_id) + _LOGGER.debug("refresh gw devices with group_id, %s", group_id) # Remove timer self._mips_local_state_changed_timers.pop(group_id, None) mips = self._mips_local.get(group_id, None) if not mips: - _LOGGER.error('mips not exist, %s', group_id) + _LOGGER.error("mips not exist, %s", group_id) return if not mips.mips_state: - _LOGGER.debug('local mips disconnect, skip refresh, %s', group_id) + _LOGGER.debug("local mips disconnect, skip refresh, %s", group_id) return gw_list: dict = await mips.get_dev_list_async() if gw_list is None: _LOGGER.error( - 'refresh gw devices with group_id failed, %s, %s', - self._uid, group_id) + "refresh gw devices with group_id failed, %s, %s", self._uid, group_id + ) # Retry until success - self.__request_refresh_gw_devices_by_group_id( - group_id=group_id) + self.__request_refresh_gw_devices_by_group_id(group_id=group_id) return await self.__update_devices_from_gw_async( - gw_list=gw_list, group_id=group_id, filter_dids=[ - did for did, info in self._device_list_gateway.items() - if info.get('group_id', None) == group_id]) + gw_list=gw_list, + group_id=group_id, + filter_dids=[ + did + for did, info in self._device_list_gateway.items() + if info.get("group_id", None) == group_id + ], + ) self.__request_show_devices_changed_notify() @final @@ -1571,24 +1718,27 @@ class MIoTClient: self, group_id: str, immediately: bool = False ) -> None: """Request refresh gateway devices by group_id""" - refresh_timer = self._mips_local_state_changed_timers.get( - group_id, None) + refresh_timer = self._mips_local_state_changed_timers.get(group_id, None) if immediately: if refresh_timer: self._mips_local_state_changed_timers.pop(group_id, None) refresh_timer.cancel() self._mips_local_state_changed_timers[group_id] = ( self._main_loop.call_later( - 0, lambda: self._main_loop.create_task( - self.__refresh_gw_devices_with_group_id_async( - group_id=group_id)))) + 0, + lambda: self._main_loop.create_task( + self.__refresh_gw_devices_with_group_id_async(group_id=group_id) + ), + ) + ) if refresh_timer: return - self._mips_local_state_changed_timers[group_id] = ( - self._main_loop.call_later( - 3, lambda: self._main_loop.create_task( - self.__refresh_gw_devices_with_group_id_async( - group_id=group_id)))) + self._mips_local_state_changed_timers[group_id] = self._main_loop.call_later( + 3, + lambda: self._main_loop.create_task( + self.__refresh_gw_devices_with_group_id_async(group_id=group_id) + ), + ) @final async def __refresh_props_from_cloud(self, patch_len: int = 150) -> bool: @@ -1606,31 +1756,32 @@ class MIoTClient: request_list[key] = value try: results = await self._http.get_props_async( - params=list(request_list.values())) + params=list(request_list.values()) + ) if not results: - raise MIoTClientError('get_props_async failed') + raise MIoTClientError("get_props_async failed") for result in results: if ( - 'did' not in result - or 'siid' not in result - or 'piid' not in result - or 'value' not in result + "did" not in result + or "siid" not in result + or "piid" not in result + or "value" not in result ): continue request_list.pop( - f'{result["did"]}|{result["siid"]}|{result["piid"]}', - None) + f"{result['did']}|{result['siid']}|{result['piid']}", None + ) self.__on_prop_msg(params=result, ctx=None) if request_list: _LOGGER.info( - 'refresh props failed, cloud, %s', - list(request_list.keys())) + "refresh props failed, cloud, %s", list(request_list.keys()) + ) request_list = None return True except Exception as err: # pylint:disable=broad-exception-caught _LOGGER.error( - 'refresh props error, cloud, %s, %s', - err, traceback.format_exc()) + "refresh props error, cloud, %s, %s", err, traceback.format_exc() + ) # Add failed request back to the list self._refresh_props_list.update(request_list) return False @@ -1642,7 +1793,7 @@ class MIoTClient: request_list = {} succeed_once = False for key in list(self._refresh_props_list.keys()): - did = key.split('|')[0] + did = key.split("|")[0] if did in request_list: # NOTICE: A device only requests once a cycle, continuous # acquisition of properties can cause device exceptions. @@ -1652,17 +1803,17 @@ class MIoTClient: if not device_gw: # Device not exist continue - mips_gw = self._mips_local.get(device_gw['group_id'], None) + mips_gw = self._mips_local.get(device_gw["group_id"], None) if not mips_gw: - _LOGGER.error('mips gateway not exist, %s', key) + _LOGGER.error("mips gateway not exist, %s", key) continue request_list[did] = { **params, - 'fut': mips_gw.get_prop_async( - did=did, siid=params['siid'], piid=params['piid'], - timeout_ms=6000)} - results = await asyncio.gather( - *[v['fut'] for v in request_list.values()]) + "fut": mips_gw.get_prop_async( + did=did, siid=params["siid"], piid=params["piid"], timeout_ms=6000 + ), + } + results = await asyncio.gather(*[v["fut"] for v in request_list.values()]) for (did, param), result in zip(request_list.items(), results): if result is None: # Don't use "not result", it will be skipped when result @@ -1670,16 +1821,17 @@ class MIoTClient: continue self.__on_prop_msg( params={ - 'did': did, - 'siid': param['siid'], - 'piid': param['piid'], - 'value': result}, - ctx=None) + "did": did, + "siid": param["siid"], + "piid": param["piid"], + "value": result, + }, + ctx=None, + ) succeed_once = True if succeed_once: return True - _LOGGER.info( - 'refresh props failed, gw, %s', list(request_list.keys())) + _LOGGER.info("refresh props failed, gw, %s", list(request_list.keys())) # Add failed request back to the list self._refresh_props_list.update(request_list) return False @@ -1691,7 +1843,7 @@ class MIoTClient: request_list = {} succeed_once = False for key in list(self._refresh_props_list.keys()): - did = key.split('|')[0] + did = key.split("|")[0] if did in request_list: # NOTICE: A device only requests once a cycle, continuous # acquisition of properties can cause device exceptions. @@ -1701,11 +1853,11 @@ class MIoTClient: continue request_list[did] = { **params, - 'fut': self._miot_lan.get_prop_async( - did=did, siid=params['siid'], piid=params['piid'], - timeout_ms=6000)} - results = await asyncio.gather( - *[v['fut'] for v in request_list.values()]) + "fut": self._miot_lan.get_prop_async( + did=did, siid=params["siid"], piid=params["piid"], timeout_ms=6000 + ), + } + results = await asyncio.gather(*[v["fut"] for v in request_list.values()]) for (did, param), result in zip(request_list.items(), results): if result is None: # Don't use "not result", it will be skipped when result @@ -1713,16 +1865,17 @@ class MIoTClient: continue self.__on_prop_msg( params={ - 'did': did, - 'siid': param['siid'], - 'piid': param['piid'], - 'value': result}, - ctx=None) + "did": did, + "siid": param["siid"], + "piid": param["piid"], + "value": result, + }, + ctx=None, + ) succeed_once = True if succeed_once: return True - _LOGGER.info( - 'refresh props failed, lan, %s', list(request_list.keys())) + _LOGGER.info("refresh props failed, lan, %s", list(request_list.keys())) # Add failed request back to the list self._refresh_props_list.update(request_list) return False @@ -1740,8 +1893,9 @@ class MIoTClient: self._refresh_props_retry_count = 0 if self._refresh_props_list: self._refresh_props_timer = self._main_loop.call_later( - 0.2, lambda: self._main_loop.create_task( - self.__refresh_props_handler())) + 0.2, + lambda: self._main_loop.create_task(self.__refresh_props_handler()), + ) else: self._refresh_props_timer = None return @@ -1753,37 +1907,38 @@ class MIoTClient: if self._refresh_props_timer: self._refresh_props_timer.cancel() self._refresh_props_timer = None - _LOGGER.info('refresh props failed, retry count exceed') + _LOGGER.info("refresh props failed, retry count exceed") return self._refresh_props_retry_count += 1 - _LOGGER.info( - 'refresh props failed, retry, %s', self._refresh_props_retry_count) + _LOGGER.info("refresh props failed, retry, %s", self._refresh_props_retry_count) self._refresh_props_timer = self._main_loop.call_later( - 3, lambda: self._main_loop.create_task( - self.__refresh_props_handler())) + 3, lambda: self._main_loop.create_task(self.__refresh_props_handler()) + ) @final def __show_client_error_notify( - self, message: Optional[str], notify_key: str = '' + self, message: Optional[str], notify_key: str = "" ) -> None: if message: - self._persistence_notify( - f'{DOMAIN}{self._uid}{self._cloud_server}{notify_key}error', + f"{DOMAIN}{self._uid}{self._cloud_server}{notify_key}error", + self._i18n.translate(key="miot.client.xiaomi_home_error_title"), # type: ignore self._i18n.translate( - key='miot.client.xiaomi_home_error_title'), # type: ignore - self._i18n.translate( - key='miot.client.xiaomi_home_error', + key="miot.client.xiaomi_home_error", replace={ - 'nick_name': self._entry_data.get( - 'nick_name', DEFAULT_NICK_NAME), - 'uid': self._uid, - 'cloud_server': self._cloud_server, - 'message': message})) # type: ignore + "nick_name": self._entry_data.get( + "nick_name", DEFAULT_NICK_NAME + ), + "uid": self._uid, + "cloud_server": self._cloud_server, + "message": message, + }, + ), + ) # type: ignore else: self._persistence_notify( - f'{DOMAIN}{self._uid}{self._cloud_server}{notify_key}error', - None, None) + f"{DOMAIN}{self._uid}{self._cloud_server}{notify_key}error", None, None + ) @final def __show_devices_changed_notify(self) -> None: @@ -1792,108 +1947,114 @@ class MIoTClient: if self._persistence_notify is None: return - message_add: str = '' + message_add: str = "" count_add: int = 0 - message_del: str = '' + message_del: str = "" count_del: int = 0 - message_offline: str = '' + message_offline: str = "" count_offline: int = 0 # New devices - if 'add' in self._display_devs_notify: + if "add" in self._display_devs_notify: for did, info in { - **self._device_list_gateway, **self._device_list_cloud + **self._device_list_gateway, + **self._device_list_cloud, }.items(): if did in self._device_list_cache: continue count_add += 1 message_add += ( - f'- {info.get("name", "unknown")} ({did}, ' - f'{info.get("model", "unknown")})\n') + f"- {info.get('name', 'unknown')} ({did}, " + f"{info.get('model', 'unknown')})\n" + ) # Get unavailable and offline devices home_name_del: Optional[str] = None home_name_offline: Optional[str] = None for did, info in self._device_list_cache.items(): - online: Optional[bool] = info.get('online', None) - home_name_new = info.get('home_name', 'unknown') + online: Optional[bool] = info.get("online", None) + home_name_new = info.get("home_name", "unknown") if online: # Skip online device continue - if 'del' in self._display_devs_notify and online is None: + if "del" in self._display_devs_notify and online is None: # Device not exist if home_name_del != home_name_new: - message_del += f'\n[{home_name_new}]\n' + message_del += f"\n[{home_name_new}]\n" home_name_del = home_name_new count_del += 1 message_del += ( - f'- {info.get("name", "unknown")} ({did}, ' - f'{info.get("room_name", "unknown")})\n') + f"- {info.get('name', 'unknown')} ({did}, " + f"{info.get('room_name', 'unknown')})\n" + ) continue - if 'offline' in self._display_devs_notify: + if "offline" in self._display_devs_notify: # Device offline if home_name_offline != home_name_new: - message_offline += f'\n[{home_name_new}]\n' + message_offline += f"\n[{home_name_new}]\n" home_name_offline = home_name_new count_offline += 1 message_offline += ( - f'- {info.get("name", "unknown")} ({did}, ' - f'{info.get("room_name", "unknown")})\n') + f"- {info.get('name', 'unknown')} ({did}, " + f"{info.get('room_name', 'unknown')})\n" + ) - message = '' - if 'add' in self._display_devs_notify and count_add: + message = "" + if "add" in self._display_devs_notify and count_add: message += self._i18n.translate( - key='miot.client.device_list_add', - replace={ - 'count': count_add, - 'message': message_add}) # type: ignore - if 'del' in self._display_devs_notify and count_del: + key="miot.client.device_list_add", + replace={"count": count_add, "message": message_add}, + ) # type: ignore + if "del" in self._display_devs_notify and count_del: message += self._i18n.translate( - key='miot.client.device_list_del', - replace={ - 'count': count_del, - 'message': message_del}) # type: ignore - if 'offline' in self._display_devs_notify and count_offline: + key="miot.client.device_list_del", + replace={"count": count_del, "message": message_del}, + ) # type: ignore + if "offline" in self._display_devs_notify and count_offline: message += self._i18n.translate( - key='miot.client.device_list_offline', - replace={ - 'count': count_offline, - 'message': message_offline}) # type: ignore - if message != '': + key="miot.client.device_list_offline", + replace={"count": count_offline, "message": message_offline}, + ) # type: ignore + if message != "": msg_hash = hash(message) if msg_hash == self._display_notify_content_hash: # Notify content no change, return - _LOGGER.debug( - 'device list changed notify content no change, return') + _LOGGER.debug("device list changed notify content no change, return") return network_status = self._i18n.translate( - key='miot.client.network_status_online' + key="miot.client.network_status_online" if self._network.network_status - else 'miot.client.network_status_offline') + else "miot.client.network_status_offline" + ) self._persistence_notify( - self.__gen_notify_key('dev_list_changed'), + self.__gen_notify_key("dev_list_changed"), + self._i18n.translate("miot.client.device_list_changed_title"), # type: ignore self._i18n.translate( - 'miot.client.device_list_changed_title'), # type: ignore - self._i18n.translate( - key='miot.client.device_list_changed', + key="miot.client.device_list_changed", replace={ - 'nick_name': self._entry_data.get( - 'nick_name', DEFAULT_NICK_NAME), - 'uid': self._uid, - 'cloud_server': self._cloud_server, - 'network_status': network_status, - 'message': message})) # type: ignore + "nick_name": self._entry_data.get( + "nick_name", DEFAULT_NICK_NAME + ), + "uid": self._uid, + "cloud_server": self._cloud_server, + "network_status": network_status, + "message": message, + }, + ), + ) # type: ignore self._display_notify_content_hash = msg_hash _LOGGER.debug( - 'show device list changed notify, add %s, del %s, offline %s', - count_add, count_del, count_offline) + "show device list changed notify, add %s, del %s, offline %s", + count_add, + count_del, + count_offline, + ) else: self._persistence_notify( - self.__gen_notify_key('dev_list_changed'), None, None) + self.__gen_notify_key("dev_list_changed"), None, None + ) @final - def __request_show_devices_changed_notify( - self, delay_sec: float = 6 - ) -> None: + def __request_show_devices_changed_notify(self, delay_sec: float = 6) -> None: if not self._display_devs_notify: return if not self._mips_cloud and not self._mips_local and not self._miot_lan: @@ -1901,71 +2062,74 @@ class MIoTClient: if self._show_devices_changed_notify_timer: self._show_devices_changed_notify_timer.cancel() self._show_devices_changed_notify_timer = self._main_loop.call_later( - delay_sec, self.__show_devices_changed_notify) + delay_sec, self.__show_devices_changed_notify + ) @staticmethod async def get_miot_instance_async( - hass: HomeAssistant, entry_id: str, entry_data: Optional[dict] = None, - persistent_notify: Optional[Callable[[str, str, str], None]] = None + hass: HomeAssistant, + entry_id: str, + entry_data: Optional[dict] = None, + persistent_notify: Optional[Callable[[str, str, str], None]] = None, ) -> MIoTClient: if entry_id is None: - raise MIoTClientError('invalid entry_id') - miot_client = hass.data[DOMAIN].get('miot_clients', {}).get(entry_id, None) + raise MIoTClientError("invalid entry_id") + miot_client = hass.data[DOMAIN].get("miot_clients", {}).get(entry_id, None) if miot_client: - _LOGGER.info('instance exist, %s', entry_id) + _LOGGER.info("instance exist, %s", entry_id) return miot_client # Create new instance if not entry_data: - raise MIoTClientError('entry data is None') + raise MIoTClientError("entry data is None") # Get running loop loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() if not loop: - raise MIoTClientError('loop is None') + raise MIoTClientError("loop is None") # MIoT storage - storage: Optional[MIoTStorage] = hass.data[DOMAIN].get( - 'miot_storage', None) + storage: Optional[MIoTStorage] = hass.data[DOMAIN].get("miot_storage", None) if not storage: - storage = MIoTStorage( - root_path=entry_data['storage_path'], loop=loop) - hass.data[DOMAIN]['miot_storage'] = storage - _LOGGER.info('create miot_storage instance') + storage = MIoTStorage(root_path=entry_data["storage_path"], loop=loop) + hass.data[DOMAIN]["miot_storage"] = storage + _LOGGER.info("create miot_storage instance") global_config: dict = await storage.load_user_config_async( - uid='global_config', cloud_server='all', - keys=['network_detect_addr', 'net_interfaces', 'enable_subscribe']) + uid="global_config", + cloud_server="all", + keys=["network_detect_addr", "net_interfaces", "enable_subscribe"], + ) # MIoT network - network_detect_addr: dict = global_config.get('network_detect_addr', {}) - network: Optional[MIoTNetwork] = hass.data[DOMAIN].get( - 'miot_network', None) + network_detect_addr: dict = global_config.get("network_detect_addr", {}) + network: Optional[MIoTNetwork] = hass.data[DOMAIN].get("miot_network", None) if not network: network = MIoTNetwork( - ip_addr_list=network_detect_addr.get('ip', []), - url_addr_list=network_detect_addr.get('url', []), + ip_addr_list=network_detect_addr.get("ip", []), + url_addr_list=network_detect_addr.get("url", []), refresh_interval=NETWORK_REFRESH_INTERVAL, - loop=loop) - hass.data[DOMAIN]['miot_network'] = network + loop=loop, + ) + hass.data[DOMAIN]["miot_network"] = network await network.init_async() - _LOGGER.info('create miot_network instance') + _LOGGER.info("create miot_network instance") # MIoT service - mips_service: Optional[MipsService] = hass.data[DOMAIN].get( - 'mips_service', None) + mips_service: Optional[MipsService] = hass.data[DOMAIN].get("mips_service", None) if not mips_service: aiozc = await zeroconf.async_get_async_instance(hass) mips_service = MipsService(aiozc=aiozc, loop=loop) - hass.data[DOMAIN]['mips_service'] = mips_service + hass.data[DOMAIN]["mips_service"] = mips_service await mips_service.init_async() - _LOGGER.info('create mips_service instance') + _LOGGER.info("create mips_service instance") # MIoT lan - miot_lan: Optional[MIoTLan] = hass.data[DOMAIN].get('miot_lan', None) + miot_lan: Optional[MIoTLan] = hass.data[DOMAIN].get("miot_lan", None) if not miot_lan: miot_lan = MIoTLan( - net_ifs=global_config.get('net_interfaces', []), + net_ifs=global_config.get("net_interfaces", []), network=network, mips_service=mips_service, - enable_subscribe=global_config.get('enable_subscribe', False), - loop=loop) - hass.data[DOMAIN]['miot_lan'] = miot_lan - _LOGGER.info('create miot_lan instance') + enable_subscribe=global_config.get("enable_subscribe", False), + loop=loop, + ) + hass.data[DOMAIN]["miot_lan"] = miot_lan + _LOGGER.info("create miot_lan instance") # MIoT client miot_client = MIoTClient( entry_id=entry_id, @@ -1974,10 +2138,10 @@ async def get_miot_instance_async( storage=storage, mips_service=mips_service, miot_lan=miot_lan, - loop=loop + loop=loop, ) miot_client.persistent_notify = persistent_notify - hass.data[DOMAIN]['miot_clients'].setdefault(entry_id, miot_client) - _LOGGER.info('new miot_client instance, %s, %s', entry_id, entry_data) + hass.data[DOMAIN]["miot_clients"].setdefault(entry_id, miot_client) + _LOGGER.info("new miot_client instance, %s, %s", entry_id, entry_data) await miot_client.init_async() return miot_client diff --git a/custom_components/xiaomi_home/miot/miot_cloud.py b/custom_components/xiaomi_home/miot/miot_cloud.py index 0b301e8..bbd3357 100644 --- a/custom_components/xiaomi_home/miot/miot_cloud.py +++ b/custom_components/xiaomi_home/miot/miot_cloud.py @@ -45,6 +45,7 @@ off Xiaomi or its affiliates' products. MIoT http client. """ + import asyncio import base64 import hashlib @@ -58,10 +59,7 @@ import aiohttp # pylint: disable=relative-beyond-top-level from .common import calc_group_id -from .const import ( - DEFAULT_OAUTH2_API_HOST, - MIHOME_HTTP_API_TIMEOUT, - OAUTH2_AUTH_URL) +from .const import DEFAULT_OAUTH2_API_HOST, MIHOME_HTTP_API_TIMEOUT, OAUTH2_AUTH_URL from .miot_error import MIoTErrorCode, MIoTHttpError, MIoTOauthError _LOGGER = logging.getLogger(__name__) @@ -71,6 +69,7 @@ TOKEN_EXPIRES_TS_RATIO = 0.7 class MIoTOauthClient: """oauth agent url, default: product env.""" + _main_loop: asyncio.AbstractEventLoop _session: aiohttp.ClientSession _oauth_host: str @@ -80,28 +79,31 @@ class MIoTOauthClient: _state: str def __init__( - self, client_id: str, redirect_url: str, cloud_server: str, - uuid: str, loop: Optional[asyncio.AbstractEventLoop] = None + self, + client_id: str, + redirect_url: str, + cloud_server: str, + uuid: str, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: self._main_loop = loop or asyncio.get_running_loop() - if client_id is None or client_id.strip() == '': - raise MIoTOauthError('invalid client_id') + if client_id is None or client_id.strip() == "": + raise MIoTOauthError("invalid client_id") if not redirect_url: - raise MIoTOauthError('invalid redirect_url') + raise MIoTOauthError("invalid redirect_url") if not cloud_server: - raise MIoTOauthError('invalid cloud_server') + raise MIoTOauthError("invalid cloud_server") if not uuid: - raise MIoTOauthError('invalid uuid') + raise MIoTOauthError("invalid uuid") self._client_id = int(client_id) self._redirect_url = redirect_url - if cloud_server == 'cn': + if cloud_server == "cn": self._oauth_host = DEFAULT_OAUTH2_API_HOST else: - self._oauth_host = f'{cloud_server}.{DEFAULT_OAUTH2_API_HOST}' - self._device_id = f'ha.{uuid}' - self._state = hashlib.sha1( - f'd={self._device_id}'.encode('utf-8')).hexdigest() + self._oauth_host = f"{cloud_server}.{DEFAULT_OAUTH2_API_HOST}" + self._device_id = f"ha.{uuid}" + self._state = hashlib.sha1(f"d={self._device_id}".encode("utf-8")).hexdigest() self._session = aiohttp.ClientSession(loop=self._main_loop) @property @@ -113,8 +115,8 @@ class MIoTOauthClient: await self._session.close() def set_redirect_url(self, redirect_url: str) -> None: - if not isinstance(redirect_url, str) or redirect_url.strip() == '': - raise MIoTOauthError('invalid redirect_url') + if not isinstance(redirect_url, str) or redirect_url.strip() == "": + raise MIoTOauthError("invalid redirect_url") self._redirect_url = redirect_url def gen_auth_url( @@ -141,52 +143,54 @@ class MIoTOauthClient: str: _description_ """ params: dict = { - 'redirect_uri': redirect_url or self._redirect_url, - 'client_id': self._client_id, - 'response_type': 'code', - 'device_id': self._device_id, - 'state': self._state + "redirect_uri": redirect_url or self._redirect_url, + "client_id": self._client_id, + "response_type": "code", + "device_id": self._device_id, + "state": self._state, } if state: - params['state'] = state + params["state"] = state if scope: - params['scope'] = ' '.join(scope).strip() - params['skip_confirm'] = skip_confirm + params["scope"] = " ".join(scope).strip() + params["skip_confirm"] = skip_confirm encoded_params = urlencode(params) - return f'{OAUTH2_AUTH_URL}?{encoded_params}' + return f"{OAUTH2_AUTH_URL}?{encoded_params}" async def __get_token_async(self, data) -> dict: http_res = await self._session.get( - url=f'https://{self._oauth_host}/app/v2/ha/oauth/get_token', - params={'data': json.dumps(data)}, - headers={'content-type': 'application/x-www-form-urlencoded'}, - timeout=MIHOME_HTTP_API_TIMEOUT + url=f"https://{self._oauth_host}/app/v2/ha/oauth/get_token", + params={"data": json.dumps(data)}, + headers={"content-type": "application/x-www-form-urlencoded"}, + timeout=MIHOME_HTTP_API_TIMEOUT, ) if http_res.status == 401: raise MIoTOauthError( - 'unauthorized(401)', MIoTErrorCode.CODE_OAUTH_UNAUTHORIZED) + "unauthorized(401)", MIoTErrorCode.CODE_OAUTH_UNAUTHORIZED + ) if http_res.status != 200: - raise MIoTOauthError( - f'invalid http status code, {http_res.status}') + raise MIoTOauthError(f"invalid http status code, {http_res.status}") res_str = await http_res.text() res_obj = json.loads(res_str) if ( not res_obj - or res_obj.get('code', None) != 0 - or 'result' not in res_obj + or res_obj.get("code", None) != 0 + or "result" not in res_obj or not all( - key in res_obj['result'] - for key in ['access_token', 'refresh_token', 'expires_in']) + key in res_obj["result"] + for key in ["access_token", "refresh_token", "expires_in"] + ) ): - raise MIoTOauthError(f'invalid http response, {res_str}') + raise MIoTOauthError(f"invalid http response, {res_str}") return { - **res_obj['result'], - 'expires_ts': int( - time.time() + - (res_obj['result'].get('expires_in', 0)*TOKEN_EXPIRES_TS_RATIO)) + **res_obj["result"], + "expires_ts": int( + time.time() + + (res_obj["result"].get("expires_in", 0) * TOKEN_EXPIRES_TS_RATIO) + ), } async def get_access_token_async(self, code: str) -> dict: @@ -199,14 +203,16 @@ class MIoTOauthClient: str: _description_ """ if not isinstance(code, str): - raise MIoTOauthError('invalid code') + raise MIoTOauthError("invalid code") - return await self.__get_token_async(data={ - 'client_id': self._client_id, - 'redirect_uri': self._redirect_url, - 'code': code, - 'device_id': self._device_id - }) + return await self.__get_token_async( + data={ + "client_id": self._client_id, + "redirect_uri": self._redirect_url, + "code": code, + "device_id": self._device_id, + } + ) async def refresh_access_token_async(self, refresh_token: str) -> dict: """get access token by refresh token. @@ -218,17 +224,20 @@ class MIoTOauthClient: str: _description_ """ if not isinstance(refresh_token, str): - raise MIoTOauthError('invalid refresh_token') + raise MIoTOauthError("invalid refresh_token") - return await self.__get_token_async(data={ - 'client_id': self._client_id, - 'redirect_uri': self._redirect_url, - 'refresh_token': refresh_token, - }) + return await self.__get_token_async( + data={ + "client_id": self._client_id, + "redirect_uri": self._redirect_url, + "refresh_token": refresh_token, + } + ) class MIoTHttpClient: """MIoT http client.""" + # pylint: disable=inconsistent-quotes GET_PROP_AGGREGATE_INTERVAL: float = 0.2 GET_PROP_MAX_REQ_COUNT = 150 @@ -243,14 +252,17 @@ class MIoTHttpClient: _get_prop_list: dict[str, dict] def __init__( - self, cloud_server: str, client_id: str, access_token: str, - loop: Optional[asyncio.AbstractEventLoop] = None + self, + cloud_server: str, + client_id: str, + access_token: str, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: self._main_loop = loop or asyncio.get_running_loop() self._host = DEFAULT_OAUTH2_API_HOST - self._base_url = '' - self._client_id = '' - self._access_token = '' + self._base_url = "" + self._client_id = "" + self._access_token = "" self._get_prop_timer = None self._get_prop_list = {} @@ -260,11 +272,11 @@ class MIoTHttpClient: or not isinstance(client_id, str) or not isinstance(access_token, str) ): - raise MIoTHttpError('invalid params') + raise MIoTHttpError("invalid params") self.update_http_header( - cloud_server=cloud_server, client_id=client_id, - access_token=access_token) + cloud_server=cloud_server, client_id=client_id, access_token=access_token + ) self._session = aiohttp.ClientSession(loop=self._main_loop) @@ -273,7 +285,7 @@ class MIoTHttpClient: self._get_prop_timer.cancel() self._get_prop_timer = None for item in self._get_prop_list.values(): - fut: Optional[asyncio.Future] = item.get('fut', None) + fut: Optional[asyncio.Future] = item.get("fut", None) if fut: fut.cancel() self._get_prop_list.clear() @@ -281,14 +293,15 @@ class MIoTHttpClient: await self._session.close() def update_http_header( - self, cloud_server: Optional[str] = None, + self, + cloud_server: Optional[str] = None, client_id: Optional[str] = None, - access_token: Optional[str] = None + access_token: Optional[str] = None, ) -> None: if isinstance(cloud_server, str): - if cloud_server != 'cn': - self._host = f'{cloud_server}.{DEFAULT_OAUTH2_API_HOST}' - self._base_url = f'https://{self._host}' + if cloud_server != "cn": + self._host = f"{cloud_server}.{DEFAULT_OAUTH2_API_HOST}" + self._base_url = f"https://{self._host}" if isinstance(client_id, str): self._client_id = client_id if isinstance(access_token, str): @@ -297,318 +310,309 @@ class MIoTHttpClient: @property def __api_request_headers(self) -> dict: return { - 'Host': self._host, - 'X-Client-BizId': 'haapi', - 'Content-Type': 'application/json', - 'Authorization': f'Bearer{self._access_token}', - 'X-Client-AppId': self._client_id, + "Host": self._host, + "X-Client-BizId": "haapi", + "Content-Type": "application/json", + "Authorization": f"Bearer{self._access_token}", + "X-Client-AppId": self._client_id, } # pylint: disable=unused-private-member async def __mihome_api_get_async( - self, url_path: str, params: dict, - timeout: int = MIHOME_HTTP_API_TIMEOUT + self, url_path: str, params: dict, timeout: int = MIHOME_HTTP_API_TIMEOUT ) -> dict: http_res = await self._session.get( - url=f'{self._base_url}{url_path}', + url=f"{self._base_url}{url_path}", params=params, headers=self.__api_request_headers, - timeout=timeout) + timeout=timeout, + ) if http_res.status == 401: raise MIoTHttpError( - 'mihome api get failed, unauthorized(401)', - MIoTErrorCode.CODE_HTTP_INVALID_ACCESS_TOKEN) + "mihome api get failed, unauthorized(401)", + MIoTErrorCode.CODE_HTTP_INVALID_ACCESS_TOKEN, + ) if http_res.status != 200: raise MIoTHttpError( - f'mihome api get failed, {http_res.status}, ' - f'{url_path}, {params}') + f"mihome api get failed, {http_res.status}, {url_path}, {params}" + ) res_str = await http_res.text() res_obj: dict = json.loads(res_str) - if res_obj.get('code', None) != 0: + if res_obj.get("code", None) != 0: raise MIoTHttpError( - f'invalid response code, {res_obj.get("code",None)}, ' - f'{res_obj.get("message","")}') + f"invalid response code, {res_obj.get('code', None)}, " + f"{res_obj.get('message', '')}" + ) _LOGGER.debug( - 'mihome api get, %s%s, %s -> %s', - self._base_url, url_path, params, res_obj) + "mihome api get, %s%s, %s -> %s", self._base_url, url_path, params, res_obj + ) return res_obj async def __mihome_api_post_async( - self, url_path: str, data: dict, - timeout: int = MIHOME_HTTP_API_TIMEOUT + self, url_path: str, data: dict, timeout: int = MIHOME_HTTP_API_TIMEOUT ) -> dict: http_res = await self._session.post( - url=f'{self._base_url}{url_path}', + url=f"{self._base_url}{url_path}", json=data, headers=self.__api_request_headers, - timeout=timeout) + timeout=timeout, + ) if http_res.status == 401: raise MIoTHttpError( - 'mihome api get failed, unauthorized(401)', - MIoTErrorCode.CODE_HTTP_INVALID_ACCESS_TOKEN) + "mihome api get failed, unauthorized(401)", + MIoTErrorCode.CODE_HTTP_INVALID_ACCESS_TOKEN, + ) if http_res.status != 200: raise MIoTHttpError( - f'mihome api post failed, {http_res.status}, ' - f'{url_path}, {data}') + f"mihome api post failed, {http_res.status}, {url_path}, {data}" + ) res_str = await http_res.text() res_obj: dict = json.loads(res_str) - if res_obj.get('code', None) != 0: + if res_obj.get("code", None) != 0: raise MIoTHttpError( - f'invalid response code, {res_obj.get("code",None)}, ' - f'{res_obj.get("message","")}') + f"invalid response code, {res_obj.get('code', None)}, " + f"{res_obj.get('message', '')}" + ) _LOGGER.debug( - 'mihome api post, %s%s, %s -> %s', - self._base_url, url_path, data, res_obj) + "mihome api post, %s%s, %s -> %s", self._base_url, url_path, data, res_obj + ) return res_obj async def get_user_info_async(self) -> dict: http_res = await self._session.get( - url='https://open.account.xiaomi.com/user/profile', - params={ - 'clientId': self._client_id, 'token': self._access_token}, - headers={'content-type': 'application/x-www-form-urlencoded'}, - timeout=MIHOME_HTTP_API_TIMEOUT + url="https://open.account.xiaomi.com/user/profile", + params={"clientId": self._client_id, "token": self._access_token}, + headers={"content-type": "application/x-www-form-urlencoded"}, + timeout=MIHOME_HTTP_API_TIMEOUT, ) res_str = await http_res.text() res_obj = json.loads(res_str) if ( not res_obj - or res_obj.get('code', None) != 0 - or 'data' not in res_obj - or 'miliaoNick' not in res_obj['data'] + or res_obj.get("code", None) != 0 + or "data" not in res_obj + or "miliaoNick" not in res_obj["data"] ): - raise MIoTOauthError(f'invalid http response, {http_res.text}') + raise MIoTOauthError(f"invalid http response, {http_res.text}") - return res_obj['data'] + return res_obj["data"] async def get_central_cert_async(self, csr: str) -> str: if not isinstance(csr, str): - raise MIoTHttpError('invalid params') + raise MIoTHttpError("invalid params") res_obj: dict = await self.__mihome_api_post_async( - url_path='/app/v2/ha/oauth/get_central_crt', - data={ - 'csr': str(base64.b64encode(csr.encode('utf-8')), 'utf-8') - } + url_path="/app/v2/ha/oauth/get_central_crt", + data={"csr": str(base64.b64encode(csr.encode("utf-8")), "utf-8")}, ) - if 'result' not in res_obj: - raise MIoTHttpError('invalid response result') - cert: str = res_obj['result'].get('cert', None) + if "result" not in res_obj: + raise MIoTHttpError("invalid response result") + cert: str = res_obj["result"].get("cert", None) if not isinstance(cert, str): - raise MIoTHttpError('invalid cert') + raise MIoTHttpError("invalid cert") return cert - async def __get_dev_room_page_async( - self, max_id: Optional[str] = None - ) -> dict: + async def __get_dev_room_page_async(self, max_id: Optional[str] = None) -> dict: res_obj = await self.__mihome_api_post_async( - url_path='/app/v2/homeroom/get_dev_room_page', + url_path="/app/v2/homeroom/get_dev_room_page", data={ - 'start_id': max_id, - 'limit': 150, + "start_id": max_id, + "limit": 150, }, ) - if 'result' not in res_obj and 'info' not in res_obj['result']: - raise MIoTHttpError('invalid response result') + if "result" not in res_obj and "info" not in res_obj["result"]: + raise MIoTHttpError("invalid response result") home_list: dict = {} - for home in res_obj['result']['info']: - if 'id' not in home: - _LOGGER.error( - 'get dev room page error, invalid home, %s', home) + for home in res_obj["result"]["info"]: + if "id" not in home: + _LOGGER.error("get dev room page error, invalid home, %s", home) continue - home_list[str(home['id'])] = {'dids': home.get( - 'dids', None) or [], 'room_info': {}} - for room in home.get('roomlist', []): - if 'id' not in room: - _LOGGER.error( - 'get dev room page error, invalid room, %s', room) + home_list[str(home["id"])] = { + "dids": home.get("dids", None) or [], + "room_info": {}, + } + for room in home.get("roomlist", []): + if "id" not in room: + _LOGGER.error("get dev room page error, invalid room, %s", room) continue - home_list[str(home['id'])]['room_info'][str(room['id'])] = { - 'dids': room.get('dids', None) or []} - if ( - res_obj['result'].get('has_more', False) - and isinstance(res_obj['result'].get('max_id', None), str) + home_list[str(home["id"])]["room_info"][str(room["id"])] = { + "dids": room.get("dids", None) or [] + } + if res_obj["result"].get("has_more", False) and isinstance( + res_obj["result"].get("max_id", None), str ): next_list = await self.__get_dev_room_page_async( - max_id=res_obj['result']['max_id']) + max_id=res_obj["result"]["max_id"] + ) for home_id, info in next_list.items(): - home_list.setdefault(home_id, {'dids': [], 'room_info': {}}) - home_list[home_id]['dids'].extend(info['dids']) - for room_id, info in info['room_info'].items(): - home_list[home_id]['room_info'].setdefault( - room_id, {'dids': []}) - home_list[home_id]['room_info'][room_id]['dids'].extend( - info['dids']) + home_list.setdefault(home_id, {"dids": [], "room_info": {}}) + home_list[home_id]["dids"].extend(info["dids"]) + for room_id, info in info["room_info"].items(): + home_list[home_id]["room_info"].setdefault(room_id, {"dids": []}) + home_list[home_id]["room_info"][room_id]["dids"].extend( + info["dids"] + ) return home_list async def get_separated_shared_devices_async(self) -> dict[str, dict]: separated_shared_devices: dict = {} device_list: dict[str, dict] = await self.__get_device_list_page_async( - dids=[], start_did=None) + dids=[], start_did=None + ) for did, value in device_list.items(): - if value['owner'] is not None and ('userid' in value['owner']) and ( - 'nickname' in value['owner'] + if ( + value["owner"] is not None + and ("userid" in value["owner"]) + and ("nickname" in value["owner"]) ): - separated_shared_devices.setdefault(did, value['owner']) + separated_shared_devices.setdefault(did, value["owner"]) return separated_shared_devices async def get_homeinfos_async(self) -> dict: res_obj = await self.__mihome_api_post_async( - url_path='/app/v2/homeroom/gethome', + url_path="/app/v2/homeroom/gethome", data={ - 'limit': 150, - 'fetch_share': True, - 'fetch_share_dev': True, - 'plat_form': 0, - 'app_ver': 9, + "limit": 150, + "fetch_share": True, + "fetch_share_dev": True, + "plat_form": 0, + "app_ver": 9, }, ) - if 'result' not in res_obj: - raise MIoTHttpError('invalid response result') + if "result" not in res_obj: + raise MIoTHttpError("invalid response result") uid: Optional[str] = None home_infos: dict = {} - for device_source in ['homelist', 'share_home_list']: + for device_source in ["homelist", "share_home_list"]: home_infos.setdefault(device_source, {}) - for home in res_obj['result'].get(device_source, []): - if ( - 'id' not in home - or 'name' not in home - or 'roomlist' not in home - ): + for home in res_obj["result"].get(device_source, []): + if "id" not in home or "name" not in home or "roomlist" not in home: continue - if uid is None and device_source == 'homelist': - uid = str(home['uid']) - home_infos[device_source][home['id']] = { - 'home_id': home['id'], - 'home_name': home['name'], - 'city_id': home.get('city_id', None), - 'longitude': home.get('longitude', None), - 'latitude': home.get('latitude', None), - 'address': home.get('address', None), - 'dids': home.get('dids', []), - 'room_info': { - room['id']: { - 'room_id': room['id'], - 'room_name': room['name'], - 'dids': room.get('dids', []) + if uid is None and device_source == "homelist": + uid = str(home["uid"]) + home_infos[device_source][home["id"]] = { + "home_id": home["id"], + "home_name": home["name"], + "city_id": home.get("city_id", None), + "longitude": home.get("longitude", None), + "latitude": home.get("latitude", None), + "address": home.get("address", None), + "dids": home.get("dids", []), + "room_info": { + room["id"]: { + "room_id": room["id"], + "room_name": room["name"], + "dids": room.get("dids", []), } - for room in home.get('roomlist', []) - if 'id' in room + for room in home.get("roomlist", []) + if "id" in room }, - 'group_id': calc_group_id( - uid=home['uid'], home_id=home['id']), - 'uid': str(home['uid']) + "group_id": calc_group_id(uid=home["uid"], home_id=home["id"]), + "uid": str(home["uid"]), } - home_infos['uid'] = uid - if ( - res_obj['result'].get('has_more', False) - and isinstance(res_obj['result'].get('max_id', None), str) + home_infos["uid"] = uid + if res_obj["result"].get("has_more", False) and isinstance( + res_obj["result"].get("max_id", None), str ): more_list = await self.__get_dev_room_page_async( - max_id=res_obj['result']['max_id']) - for device_source in ['homelist', 'share_home_list']: + max_id=res_obj["result"]["max_id"] + ) + for device_source in ["homelist", "share_home_list"]: for home_id, info in more_list.items(): if home_id not in home_infos[device_source]: - _LOGGER.info('unknown home, %s, %s', home_id, info) + _LOGGER.info("unknown home, %s, %s", home_id, info) continue - home_infos[device_source][home_id]['dids'].extend( - info['dids']) - for room_id, info in info['room_info'].items(): - home_infos[device_source][home_id][ - 'room_info'].setdefault( - room_id, { - 'room_id': room_id, - 'room_name': '', - 'dids': []}) - home_infos[device_source][home_id]['room_info'][ - room_id]['dids'].extend(info['dids']) + home_infos[device_source][home_id]["dids"].extend(info["dids"]) + for room_id, info in info["room_info"].items(): + home_infos[device_source][home_id]["room_info"].setdefault( + room_id, {"room_id": room_id, "room_name": "", "dids": []} + ) + home_infos[device_source][home_id]["room_info"][room_id][ + "dids" + ].extend(info["dids"]) return { - 'uid': uid, - 'home_list': home_infos.get('homelist', {}), - 'share_home_list': home_infos.get('share_home_list', []) + "uid": uid, + "home_list": home_infos.get("homelist", {}), + "share_home_list": home_infos.get("share_home_list", []), } async def get_uid_async(self) -> str: - return (await self.get_homeinfos_async()).get('uid', None) + return (await self.get_homeinfos_async()).get("uid", None) async def __get_device_list_page_async( self, dids: list[str], start_did: Optional[str] = None ) -> dict[str, dict]: - req_data: dict = { - 'limit': 200, - 'get_split_device': True, - 'dids': dids - } + req_data: dict = {"limit": 200, "get_split_device": True, "dids": dids} if start_did: - req_data['start_did'] = start_did + req_data["start_did"] = start_did device_infos: dict = {} res_obj = await self.__mihome_api_post_async( - url_path='/app/v2/home/device_list_page', - data=req_data + url_path="/app/v2/home/device_list_page", data=req_data ) - if 'result' not in res_obj: - raise MIoTHttpError('invalid response result') - res_obj = res_obj['result'] + if "result" not in res_obj: + raise MIoTHttpError("invalid response result") + res_obj = res_obj["result"] - for device in res_obj.get('list', []) or []: - did = device.get('did', None) - name = device.get('name', None) - urn = device.get('spec_type', None) - model = device.get('model', None) + for device in res_obj.get("list", []) or []: + did = device.get("did", None) + name = device.get("name", None) + urn = device.get("spec_type", None) + model = device.get("model", None) if did is None or name is None: - _LOGGER.info( - 'invalid device, cloud, %s', device) + _LOGGER.info("invalid device, cloud, %s", device) continue if urn is None or model is None: - _LOGGER.info( - 'missing the urn|model field, cloud, %s', device) + _LOGGER.info("missing the urn|model field, cloud, %s", device) continue - if did.startswith('miwifi.'): + if did.startswith("miwifi."): # The miwifi.* routers defined SPEC functions, but none of them # were implemented. - _LOGGER.info('ignore miwifi.* device, cloud, %s', did) + _LOGGER.info("ignore miwifi.* device, cloud, %s", did) continue device_infos[did] = { - 'did': did, - 'uid': device.get('uid', None), - 'name': name, - 'urn': urn, - 'model': model, - 'connect_type': device.get('pid', -1), - 'token': device.get('token', None), - 'online': device.get('isOnline', False), - 'icon': device.get('icon', None), - 'parent_id': device.get('parent_id', None), - 'manufacturer': model.split('.')[0], + "did": did, + "uid": device.get("uid", None), + "name": name, + "urn": urn, + "model": model, + "connect_type": device.get("pid", -1), + "token": device.get("token", None), + "online": device.get("isOnline", False), + "icon": device.get("icon", None), + "parent_id": device.get("parent_id", None), + "manufacturer": model.split(".")[0], # 2: xiao-ai, 1: general speaker - 'voice_ctrl': device.get('voice_ctrl', 0), - 'rssi': device.get('rssi', None), - 'owner': device.get('owner', None), - 'pid': device.get('pid', None), - 'local_ip': device.get('local_ip', None), - 'ssid': device.get('ssid', None), - 'bssid': device.get('bssid', None), - 'order_time': device.get('orderTime', 0), - 'fw_version': device.get('extra', {}).get( - 'fw_version', 'unknown'), + "voice_ctrl": device.get("voice_ctrl", 0), + "rssi": device.get("rssi", None), + "owner": device.get("owner", None), + "pid": device.get("pid", None), + "local_ip": device.get("local_ip", None), + "ssid": device.get("ssid", None), + "bssid": device.get("bssid", None), + "order_time": device.get("orderTime", 0), + "fw_version": device.get("extra", {}).get("fw_version", "unknown"), } - if isinstance(device.get('extra', None), dict) and device['extra']: - device_infos[did]['fw_version'] = device['extra'].get( - 'fw_version', None) - device_infos[did]['mcu_version'] = device['extra'].get( - 'mcu_version', None) - device_infos[did]['platform'] = device['extra'].get( - 'platform', None) + if isinstance(device.get("extra", None), dict) and device["extra"]: + device_infos[did]["fw_version"] = device["extra"].get( + "fw_version", None + ) + device_infos[did]["mcu_version"] = device["extra"].get( + "mcu_version", None + ) + device_infos[did]["platform"] = device["extra"].get("platform", None) - next_start_did = res_obj.get('next_start_did', None) - if res_obj.get('has_more', False) and next_start_did: - device_infos.update(await self.__get_device_list_page_async( - dids=dids, start_did=next_start_did)) + next_start_did = res_obj.get("next_start_did", None) + if res_obj.get("has_more", False) and next_start_did: + device_infos.update( + await self.__get_device_list_page_async( + dids=dids, start_did=next_start_did + ) + ) return device_infos @@ -616,8 +620,11 @@ class MIoTHttpClient: self, dids: list[str] ) -> Optional[dict[str, dict]]: results: list[dict[str, dict]] = await asyncio.gather( - *[self.__get_device_list_page_async(dids=dids[index:index+150]) - for index in range(0, len(dids), 150)]) + *[ + self.__get_device_list_page_async(dids=dids[index : index + 150]) + for index in range(0, len(dids), 150) + ] + ) devices = {} for result in results: if result is None: @@ -631,87 +638,96 @@ class MIoTHttpClient: homeinfos = await self.get_homeinfos_async() homes: dict[str, dict[str, Any]] = {} devices: dict[str, dict] = {} - for device_type in ['home_list', 'share_home_list']: + for device_type in ["home_list", "share_home_list"]: homes.setdefault(device_type, {}) - for home_id, home_info in (homeinfos.get( - device_type, None) or {}).items(): + for home_id, home_info in (homeinfos.get(device_type, None) or {}).items(): if isinstance(home_ids, list) and home_id not in home_ids: continue - home_name: str = home_info['home_name'] - group_id: str = home_info['group_id'] + home_name: str = home_info["home_name"] + group_id: str = home_info["group_id"] homes[device_type].setdefault( - home_id, { - 'home_name': home_name, - 'uid': home_info['uid'], - 'group_id': group_id, - 'room_info': {} - }) - devices.update({did: { - 'home_id': home_id, - 'home_name': home_name, - 'room_id': home_id, - 'room_name': home_name, - 'group_id': group_id - } for did in home_info.get('dids', [])}) - for room_id, room_info in home_info.get('room_info').items(): - room_name: str = room_info.get('room_name', '') - homes[device_type][home_id]['room_info'][ - room_id] = room_name - devices.update({ + home_id, + { + "home_name": home_name, + "uid": home_info["uid"], + "group_id": group_id, + "room_info": {}, + }, + ) + devices.update( + { did: { - 'home_id': home_id, - 'home_name': home_name, - 'room_id': room_id, - 'room_name': room_name, - 'group_id': group_id - } for did in room_info.get('dids', [])}) - separated_shared_devices: dict = ( - await self.get_separated_shared_devices_async()) + "home_id": home_id, + "home_name": home_name, + "room_id": home_id, + "room_name": home_name, + "group_id": group_id, + } + for did in home_info.get("dids", []) + } + ) + for room_id, room_info in home_info.get("room_info").items(): + room_name: str = room_info.get("room_name", "") + homes[device_type][home_id]["room_info"][room_id] = room_name + devices.update( + { + did: { + "home_id": home_id, + "home_name": home_name, + "room_id": room_id, + "room_name": room_name, + "group_id": group_id, + } + for did in room_info.get("dids", []) + } + ) + separated_shared_devices: dict = await self.get_separated_shared_devices_async() if separated_shared_devices: - homes.setdefault('separated_shared_list', {}) + homes.setdefault("separated_shared_list", {}) for did, owner in separated_shared_devices.items(): - owner_id = str(owner['userid']) - homes['separated_shared_list'].setdefault(owner_id,{ - 'home_name': owner['nickname'], - 'uid': owner_id, - 'group_id': 'NotSupport', - 'room_info': {'shared_device': 'shared_device'} - }) - devices.update({did: { - 'home_id': owner_id, - 'home_name': owner['nickname'], - 'room_id': 'shared_device', - 'room_name': 'shared_device', - 'group_id': 'NotSupport' - }}) + owner_id = str(owner["userid"]) + homes["separated_shared_list"].setdefault( + owner_id, + { + "home_name": owner["nickname"], + "uid": owner_id, + "group_id": "NotSupport", + "room_info": {"shared_device": "shared_device"}, + }, + ) + devices.update( + { + did: { + "home_id": owner_id, + "home_name": owner["nickname"], + "room_id": "shared_device", + "room_name": "shared_device", + "group_id": "NotSupport", + } + } + ) dids = sorted(list(devices.keys())) results = await self.get_devices_with_dids_async(dids=dids) if results is None: - raise MIoTHttpError('get devices failed') + raise MIoTHttpError("get devices failed") for did in dids: if did not in results: devices.pop(did, None) - _LOGGER.info('get device info failed, %s', did) + _LOGGER.info("get device info failed, %s", did) continue devices[did].update(results[did]) # Whether sub devices - match_str = re.search(r'\.s\d+$', did) + match_str = re.search(r"\.s\d+$", did) if not match_str: continue device = devices.pop(did, None) - parent_did = did.replace(match_str.group(), '') + parent_did = did.replace(match_str.group(), "") if parent_did in devices: - devices[parent_did].setdefault('sub_devices', {}) - devices[parent_did]['sub_devices'][match_str.group()[ - 1:]] = device + devices[parent_did].setdefault("sub_devices", {}) + devices[parent_did]["sub_devices"][match_str.group()[1:]] = device else: - _LOGGER.error( - 'unknown sub devices, %s, %s', did, parent_did) - return { - 'uid': homeinfos['uid'], - 'homes': homes, - 'devices': devices - } + _LOGGER.error("unknown sub devices, %s, %s", did, parent_did) + return {"uid": homeinfos["uid"], "homes": homes, "devices": devices} async def get_props_async(self, params: list) -> list: """ @@ -719,71 +735,67 @@ class MIoTHttpClient: {"did": "xxxxxx", "siid": 2, "piid": 2}] """ res_obj = await self.__mihome_api_post_async( - url_path='/app/v2/miotspec/prop/get', - data={ - 'datasource': 1, - 'params': params - }, + url_path="/app/v2/miotspec/prop/get", + data={"datasource": 1, "params": params}, ) - if 'result' not in res_obj: - raise MIoTHttpError('invalid response result') - return res_obj['result'] + if "result" not in res_obj: + raise MIoTHttpError("invalid response result") + return res_obj["result"] async def __get_prop_async(self, did: str, siid: int, piid: int) -> Any: results = await self.get_props_async( - params=[{'did': did, 'siid': siid, 'piid': piid}]) + params=[{"did": did, "siid": siid, "piid": piid}] + ) if not results: return None result = results[0] - if 'value' not in result: + if "value" not in result: return None - return result['value'] + return result["value"] async def __get_prop_handler(self) -> bool: props_req: set[str] = set() props_buffer: list[dict] = [] for key, item in self._get_prop_list.items(): - if item.get('tag', False): + if item.get("tag", False): continue # NOTICE: max req prop if len(props_req) >= self.GET_PROP_MAX_REQ_COUNT: break - item['tag'] = True - props_buffer.append(item['param']) + item["tag"] = True + props_buffer.append(item["param"]) props_req.add(key) if not props_buffer: - _LOGGER.error('get prop error, empty request list') + _LOGGER.error("get prop error, empty request list") return False results = await self.get_props_async(props_buffer) for result in results: - if not all( - key in result for key in ['did', 'siid', 'piid', 'value']): + if not all(key in result for key in ["did", "siid", "piid", "value"]): continue - key = f'{result["did"]}.{result["siid"]}.{result["piid"]}' + key = f"{result['did']}.{result['siid']}.{result['piid']}" prop_obj = self._get_prop_list.pop(key, None) if prop_obj is None: - _LOGGER.info('get prop error, key not exists, %s', result) + _LOGGER.info("get prop error, key not exists, %s", result) continue - prop_obj['fut'].set_result(result['value']) + prop_obj["fut"].set_result(result["value"]) props_req.remove(key) for key in props_req: prop_obj = self._get_prop_list.pop(key, None) if prop_obj is None: continue - prop_obj['fut'].set_result(None) + prop_obj["fut"].set_result(None) if props_req: - _LOGGER.info( - 'get prop from cloud failed, %s', props_req) + _LOGGER.info("get prop from cloud failed, %s", props_req) if self._get_prop_list: self._get_prop_timer = self._main_loop.call_later( self.GET_PROP_AGGREGATE_INTERVAL, - lambda: self._main_loop.create_task( - self.__get_prop_handler())) + lambda: self._main_loop.create_task(self.__get_prop_handler()), + ) else: self._get_prop_timer = None return True @@ -793,20 +805,20 @@ class MIoTHttpClient: ) -> Any: if immediately: return await self.__get_prop_async(did, siid, piid) - key: str = f'{did}.{siid}.{piid}' + key: str = f"{did}.{siid}.{piid}" prop_obj = self._get_prop_list.get(key, None) if prop_obj: - return await prop_obj['fut'] + return await prop_obj["fut"] fut = self._main_loop.create_future() self._get_prop_list[key] = { - 'param': {'did': did, 'siid': siid, 'piid': piid}, - 'fut': fut + "param": {"did": did, "siid": siid, "piid": piid}, + "fut": fut, } if self._get_prop_timer is None: self._get_prop_timer = self._main_loop.call_later( self.GET_PROP_AGGREGATE_INTERVAL, - lambda: self._main_loop.create_task( - self.__get_prop_handler())) + lambda: self._main_loop.create_task(self.__get_prop_handler()), + ) return await fut @@ -815,16 +827,24 @@ class MIoTHttpClient: params = [{"did": "xxxx", "siid": 2, "piid": 1, "value": False}] """ res_obj = await self.__mihome_api_post_async( - url_path='/app/v2/miotspec/prop/set', - data={ - 'params': params - }, - timeout=15 + url_path="/app/v2/miotspec/prop/set", data={"params": params}, timeout=15 ) - if 'result' not in res_obj: - raise MIoTHttpError('invalid response result') + if "result" not in res_obj: + raise MIoTHttpError("invalid response result") - return res_obj['result'] + return res_obj["result"] + + async def set_props_async(self, params: list) -> list: + """ + params = [{"did": "xxxx", "siid": 2, "piid": 1, "value": False}] + """ + res_obj = await self.__mihome_api_post_async( + url_path="/app/v2/miotspec/prop/set", data={"params": params}, timeout=15 + ) + if "result" not in res_obj: + raise MIoTHttpError("invalid response result") + + return res_obj["result"] async def action_async( self, did: str, siid: int, aiid: int, in_list: list[dict] @@ -834,17 +854,18 @@ class MIoTHttpClient: """ # NOTICE: Non-standard action param res_obj = await self.__mihome_api_post_async( - url_path='/app/v2/miotspec/action', + url_path="/app/v2/miotspec/action", data={ - 'params': { - 'did': did, - 'siid': siid, - 'aiid': aiid, - 'in': [item['value'] for item in in_list]} + "params": { + "did": did, + "siid": siid, + "aiid": aiid, + "in": [item["value"] for item in in_list], + } }, - timeout=15 + timeout=15, ) - if 'result' not in res_obj: - raise MIoTHttpError('invalid response result') + if "result" not in res_obj: + raise MIoTHttpError("invalid response result") - return res_obj['result'] + return res_obj["result"] diff --git a/custom_components/xiaomi_home/miot/miot_device.py b/custom_components/xiaomi_home/miot/miot_device.py index e3394c9..1d23d05 100644 --- a/custom_components/xiaomi_home/miot/miot_device.py +++ b/custom_components/xiaomi_home/miot/miot_device.py @@ -45,9 +45,10 @@ off Xiaomi or its affiliates' products. MIoT device instance. """ + import asyncio from abc import abstractmethod -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, List, Optional import logging from homeassistant.helpers.entity import Entity @@ -73,7 +74,7 @@ from homeassistant.const import ( UnitOfPower, UnitOfVolume, UnitOfVolumeFlowRate, - UnitOfDataRate + UnitOfDataRate, ) from homeassistant.helpers.entity import DeviceInfo from homeassistant.components.switch import SwitchDeviceClass @@ -85,7 +86,7 @@ from .specs.specv2entity import ( SPEC_DEVICE_TRANS_MAP, SPEC_EVENT_TRANS_MAP, SPEC_PROP_TRANS_MAP, - SPEC_SERVICE_TRANS_MAP + SPEC_SERVICE_TRANS_MAP, ) from .common import slugify_name, slugify_did from .const import DOMAIN @@ -99,7 +100,7 @@ from .miot_spec import ( MIoTSpecProperty, MIoTSpecService, MIoTSpecValueList, - MIoTSpecValueRange + MIoTSpecValueRange, ) _LOGGER = logging.getLogger(__name__) @@ -107,6 +108,7 @@ _LOGGER = logging.getLogger(__name__) class MIoTEntityData: """MIoT Entity Data.""" + platform: str device_class: Any spec: MIoTSpecInstance | MIoTSpecService @@ -115,9 +117,7 @@ class MIoTEntityData: events: set[MIoTSpecEvent] actions: set[MIoTSpecAction] - def __init__( - self, platform: str, spec: MIoTSpecInstance | MIoTSpecService - ) -> None: + def __init__(self, platform: str, spec: MIoTSpecInstance | MIoTSpecService) -> None: self.platform = platform self.spec = spec self.device_class = None @@ -128,6 +128,7 @@ class MIoTEntityData: class MIoTDevice: """MIoT Device Instance.""" + # pylint: disable=unused-argument miot_client: MIoTClient spec_instance: MIoTSpecInstance @@ -150,8 +151,7 @@ class MIoTDevice: _suggested_area: Optional[str] _sub_id: int - _device_state_sub_list: dict[str, dict[ - str, Callable[[str, MIoTDeviceState], None]]] + _device_state_sub_list: dict[str, dict[str, Callable[[str, MIoTDeviceState], None]]] _value_sub_list: dict[str, dict[str, Callable[[dict, Any], None]]] _entity_list: dict[str, list[MIoTEntityData]] @@ -160,33 +160,33 @@ class MIoTDevice: _action_list: dict[str, list[MIoTSpecAction]] def __init__( - self, miot_client: MIoTClient, + self, + miot_client: MIoTClient, device_info: dict[str, Any], - spec_instance: MIoTSpecInstance + spec_instance: MIoTSpecInstance, ) -> None: self.miot_client = miot_client self.spec_instance = spec_instance - self._online = device_info.get('online', False) - self._did = device_info['did'] - self._name = device_info['name'] - self._model = device_info['model'] - self._model_strs = self._model.split('.') - self._manufacturer = device_info.get('manufacturer', None) - self._fw_version = device_info.get('fw_version', None) + self._online = device_info.get("online", False) + self._did = device_info["did"] + self._name = device_info["name"] + self._model = device_info["model"] + self._model_strs = self._model.split(".") + self._manufacturer = device_info.get("manufacturer", None) + self._fw_version = device_info.get("fw_version", None) - self._icon = device_info.get('icon', None) - self._home_id = device_info.get('home_id', None) - self._home_name = device_info.get('home_name', None) - self._room_id = device_info.get('room_id', None) - self._room_name = device_info.get('room_name', None) + self._icon = device_info.get("icon", None) + self._home_id = device_info.get("home_id", None) + self._home_name = device_info.get("home_name", None) + self._room_id = device_info.get("room_id", None) + self._room_name = device_info.get("room_name", None) match self.miot_client.area_name_rule: - case 'home_room': - self._suggested_area = ( - f'{self._home_name} {self._room_name}'.strip()) - case 'home': + case "home_room": + self._suggested_area = f"{self._home_name} {self._room_name}".strip() + case "home": self._suggested_area = self._home_name.strip() - case 'room': + case "room": self._suggested_area = self._room_name.strip() case _: self._suggested_area = None @@ -200,23 +200,23 @@ class MIoTDevice: self._action_list = {} # Sub devices name - sub_devices: dict[str, dict] = device_info.get('sub_devices', None) + sub_devices: dict[str, dict] = device_info.get("sub_devices", None) if isinstance(sub_devices, dict) and sub_devices: for service in spec_instance.services: - sub_info = sub_devices.get(f's{service.iid}', None) + sub_info = sub_devices.get(f"s{service.iid}", None) if sub_info is None: continue _LOGGER.debug( - 'miot device, update service sub info, %s, %s', - self.did, sub_info) + "miot device, update service sub info, %s, %s", self.did, sub_info + ) service.description_trans = sub_info.get( - 'name', service.description_trans) + "name", service.description_trans + ) # Sub device state - self.miot_client.sub_device_state( - self._did, self.__on_device_state_changed) + self.miot_client.sub_device_state(self._did, self.__on_device_state_changed) - _LOGGER.debug('miot device init %s', device_info) + _LOGGER.debug("miot device init %s", device_info) @property def online(self) -> bool: @@ -240,7 +240,8 @@ class MIoTDevice: async def action_async(self, siid: int, aiid: int, in_list: list) -> list: return await self.miot_client.action_async( - did=self._did, siid=siid, aiid=aiid, in_list=in_list) + did=self._did, siid=siid, aiid=aiid, in_list=in_list + ) def sub_device_state( self, key: str, handler: Callable[[str, MIoTDeviceState], None] @@ -262,7 +263,7 @@ class MIoTDevice: def sub_property( self, handler: Callable[[dict, Any], None], siid: int, piid: int ) -> int: - key: str = f'p.{siid}.{piid}' + key: str = f"p.{siid}.{piid}" def _on_prop_changed(params: dict, ctx: Any) -> None: for handler in self._value_sub_list[key].values(): @@ -274,11 +275,12 @@ class MIoTDevice: else: self._value_sub_list[key] = {str(sub_id): handler} self.miot_client.sub_prop( - did=self._did, handler=_on_prop_changed, siid=siid, piid=piid) + did=self._did, handler=_on_prop_changed, siid=siid, piid=piid + ) return sub_id def unsub_property(self, siid: int, piid: int, sub_id: int) -> None: - key: str = f'p.{siid}.{piid}' + key: str = f"p.{siid}.{piid}" sub_list = self._value_sub_list.get(key, None) if sub_list: @@ -290,7 +292,7 @@ class MIoTDevice: def sub_event( self, handler: Callable[[dict, Any], None], siid: int, eiid: int ) -> int: - key: str = f'e.{siid}.{eiid}' + key: str = f"e.{siid}.{eiid}" def _on_event_occurred(params: dict, ctx: Any) -> None: for handler in self._value_sub_list[key].values(): @@ -302,11 +304,12 @@ class MIoTDevice: else: self._value_sub_list[key] = {str(sub_id): handler} self.miot_client.sub_event( - did=self._did, handler=_on_event_occurred, siid=siid, eiid=eiid) + did=self._did, handler=_on_event_occurred, siid=siid, eiid=eiid + ) return sub_id def unsub_event(self, siid: int, eiid: int, sub_id: int) -> None: - key: str = f'e.{siid}.{eiid}' + key: str = f"e.{siid}.{eiid}" sub_list = self._value_sub_list.get(key, None) if sub_list: @@ -326,8 +329,9 @@ class MIoTDevice: manufacturer=self._manufacturer, suggested_area=self._suggested_area, configuration_url=( - f'https://home.mi.com/webapp/content/baike/product/index.html?' - f'model={self._model}') + f"https://home.mi.com/webapp/content/baike/product/index.html?" + f"model={self._model}" + ), ) @property @@ -337,43 +341,46 @@ class MIoTDevice: @property def did_tag(self) -> str: - return slugify_did( - cloud_server=self.miot_client.cloud_server, did=self._did) + return slugify_did(cloud_server=self.miot_client.cloud_server, did=self._did) def gen_device_entity_id(self, ha_domain: str) -> str: return ( - f'{ha_domain}.{self._model_strs[0][:9]}_{self.did_tag}_' - f'{self._model_strs[-1][:20]}') + f"{ha_domain}.{self._model_strs[0][:9]}_{self.did_tag}_" + f"{self._model_strs[-1][:20]}" + ) - def gen_service_entity_id(self, ha_domain: str, siid: int, - description: str) -> str: + def gen_service_entity_id(self, ha_domain: str, siid: int, description: str) -> str: return ( - f'{ha_domain}.{self._model_strs[0][:9]}_{self.did_tag}_' - f'{self._model_strs[-1][:20]}_s_{siid}_{description}') + f"{ha_domain}.{self._model_strs[0][:9]}_{self.did_tag}_" + f"{self._model_strs[-1][:20]}_s_{siid}_{description}" + ) def gen_prop_entity_id( self, ha_domain: str, spec_name: str, siid: int, piid: int ) -> str: return ( - f'{ha_domain}.{self._model_strs[0][:9]}_{self.did_tag}_' - f'{self._model_strs[-1][:20]}_{slugify_name(spec_name)}' - f'_p_{siid}_{piid}') + f"{ha_domain}.{self._model_strs[0][:9]}_{self.did_tag}_" + f"{self._model_strs[-1][:20]}_{slugify_name(spec_name)}" + f"_p_{siid}_{piid}" + ) def gen_event_entity_id( self, ha_domain: str, spec_name: str, siid: int, eiid: int ) -> str: return ( - f'{ha_domain}.{self._model_strs[0][:9]}_{self.did_tag}_' - f'{self._model_strs[-1][:20]}_{slugify_name(spec_name)}' - f'_e_{siid}_{eiid}') + f"{ha_domain}.{self._model_strs[0][:9]}_{self.did_tag}_" + f"{self._model_strs[-1][:20]}_{slugify_name(spec_name)}" + f"_e_{siid}_{eiid}" + ) def gen_action_entity_id( self, ha_domain: str, spec_name: str, siid: int, aiid: int ) -> str: return ( - f'{ha_domain}.{self._model_strs[0][:9]}_{self.did_tag}_' - f'{self._model_strs[-1][:20]}_{slugify_name(spec_name)}' - f'_a_{siid}_{aiid}') + f"{ha_domain}.{self._model_strs[0][:9]}_{self.did_tag}_" + f"{self._model_strs[-1][:20]}_{slugify_name(spec_name)}" + f"_a_{siid}_{aiid}" + ) @property def name(self) -> str: @@ -417,17 +424,17 @@ class MIoTDevice: spec_name: str = spec_instance.name if isinstance(SPEC_DEVICE_TRANS_MAP[spec_name], str): spec_name = SPEC_DEVICE_TRANS_MAP[spec_name] - if 'required' not in SPEC_DEVICE_TRANS_MAP[spec_name]: + if "required" not in SPEC_DEVICE_TRANS_MAP[spec_name]: return None # 1. The device shall have all required services. - required_services = SPEC_DEVICE_TRANS_MAP[spec_name]['required'].keys() - if not { - service.name for service in spec_instance.services - }.issuperset(required_services): + required_services = SPEC_DEVICE_TRANS_MAP[spec_name]["required"].keys() + if not {service.name for service in spec_instance.services}.issuperset( + required_services + ): return None - optional_services = SPEC_DEVICE_TRANS_MAP[spec_name]['optional'].keys() + optional_services = SPEC_DEVICE_TRANS_MAP[spec_name]["optional"].keys() - platform = SPEC_DEVICE_TRANS_MAP[spec_name]['entity'] + platform = SPEC_DEVICE_TRANS_MAP[spec_name]["entity"] entity_data = MIoTEntityData(platform=platform, spec=spec_instance) for service in spec_instance.services: if service.platform: @@ -438,59 +445,75 @@ class MIoTDevice: optional_actions: set # 2. The service shall have all required properties, actions. if service.name in required_services: - required_properties = SPEC_DEVICE_TRANS_MAP[spec_name][ - 'required'].get( - service.name, {} - ).get('required', {}).get('properties', {}) - optional_properties = SPEC_DEVICE_TRANS_MAP[spec_name][ - 'required'].get( - service.name, {} - ).get('optional', {}).get('properties', set({})) - required_actions = SPEC_DEVICE_TRANS_MAP[spec_name][ - 'required'].get( - service.name, {} - ).get('required', {}).get('actions', set({})) - optional_actions = SPEC_DEVICE_TRANS_MAP[spec_name][ - 'required'].get( - service.name, {} - ).get('optional', {}).get('actions', set({})) + required_properties = ( + SPEC_DEVICE_TRANS_MAP[spec_name]["required"] + .get(service.name, {}) + .get("required", {}) + .get("properties", {}) + ) + optional_properties = ( + SPEC_DEVICE_TRANS_MAP[spec_name]["required"] + .get(service.name, {}) + .get("optional", {}) + .get("properties", set({})) + ) + required_actions = ( + SPEC_DEVICE_TRANS_MAP[spec_name]["required"] + .get(service.name, {}) + .get("required", {}) + .get("actions", set({})) + ) + optional_actions = ( + SPEC_DEVICE_TRANS_MAP[spec_name]["required"] + .get(service.name, {}) + .get("optional", {}) + .get("actions", set({})) + ) elif service.name in optional_services: - required_properties = SPEC_DEVICE_TRANS_MAP[spec_name][ - 'optional'].get( - service.name, {} - ).get('required', {}).get('properties', {}) - optional_properties = SPEC_DEVICE_TRANS_MAP[spec_name][ - 'optional'].get( - service.name, {} - ).get('optional', {}).get('properties', set({})) - required_actions = SPEC_DEVICE_TRANS_MAP[spec_name][ - 'optional'].get( - service.name, {} - ).get('required', {}).get('actions', set({})) - optional_actions = SPEC_DEVICE_TRANS_MAP[spec_name][ - 'optional'].get( - service.name, {} - ).get('optional', {}).get('actions', set({})) + required_properties = ( + SPEC_DEVICE_TRANS_MAP[spec_name]["optional"] + .get(service.name, {}) + .get("required", {}) + .get("properties", {}) + ) + optional_properties = ( + SPEC_DEVICE_TRANS_MAP[spec_name]["optional"] + .get(service.name, {}) + .get("optional", {}) + .get("properties", set({})) + ) + required_actions = ( + SPEC_DEVICE_TRANS_MAP[spec_name]["optional"] + .get(service.name, {}) + .get("required", {}) + .get("actions", set({})) + ) + optional_actions = ( + SPEC_DEVICE_TRANS_MAP[spec_name]["optional"] + .get(service.name, {}) + .get("optional", {}) + .get("actions", set({})) + ) else: continue - if not { - prop.name for prop in service.properties if prop.access - }.issuperset(set(required_properties.keys())): + if not {prop.name for prop in service.properties if prop.access}.issuperset( + set(required_properties.keys()) + ): return None - if not { - action.name for action in service.actions - }.issuperset(required_actions): + if not {action.name for action in service.actions}.issuperset( + required_actions + ): return None # 3. The required property shall have all required access mode. for prop in service.properties: if prop.name in required_properties: - if not set(prop.access).issuperset( - required_properties[prop.name]): + if not set(prop.access).issuperset(required_properties[prop.name]): return None # property for prop in service.properties: if prop.name in set.union( - set(required_properties.keys()), optional_properties): + set(required_properties.keys()), optional_properties + ): if prop.unit: prop.external_unit = self.unit_convert(prop.unit) # prop.icon = self.icon_convert(prop.unit) @@ -498,8 +521,7 @@ class MIoTDevice: entity_data.props.add(prop) # action for action in service.actions: - if action.name in set.union( - required_actions, optional_actions): + if action.name in set.union(required_actions, optional_actions): action.platform = platform entity_data.actions.add(action) # event @@ -510,38 +532,37 @@ class MIoTDevice: def parse_miot_service_entity( self, miot_service: MIoTSpecService ) -> Optional[MIoTEntityData]: - if ( - miot_service.platform - or miot_service.name not in SPEC_SERVICE_TRANS_MAP - ): + if miot_service.platform or miot_service.name not in SPEC_SERVICE_TRANS_MAP: return None service_name = miot_service.name if isinstance(SPEC_SERVICE_TRANS_MAP[service_name], str): service_name = SPEC_SERVICE_TRANS_MAP[service_name] - if 'required' not in SPEC_SERVICE_TRANS_MAP[service_name]: + if "required" not in SPEC_SERVICE_TRANS_MAP[service_name]: return None # Required properties, required access mode required_properties: dict = SPEC_SERVICE_TRANS_MAP[service_name][ - 'required'].get('properties', {}) + "required" + ].get("properties", {}) if not { prop.name for prop in miot_service.properties if prop.access }.issuperset(set(required_properties.keys())): return None for prop in miot_service.properties: if prop.name in required_properties: - if not set(prop.access).issuperset( - required_properties[prop.name]): + if not set(prop.access).issuperset(required_properties[prop.name]): return None # Required actions # Required events - platform = SPEC_SERVICE_TRANS_MAP[service_name]['entity'] + platform = SPEC_SERVICE_TRANS_MAP[service_name]["entity"] entity_data = MIoTEntityData(platform=platform, spec=miot_service) # Optional properties - optional_properties = SPEC_SERVICE_TRANS_MAP[service_name][ - 'optional'].get('properties', set({})) + optional_properties = SPEC_SERVICE_TRANS_MAP[service_name]["optional"].get( + "properties", set({}) + ) for prop in miot_service.properties: if prop.name in set.union( - set(required_properties.keys()), optional_properties): + set(required_properties.keys()), optional_properties + ): if prop.unit: prop.external_unit = self.unit_convert(prop.unit) # prop.icon = self.icon_convert(prop.unit) @@ -552,46 +573,50 @@ class MIoTDevice: miot_service.platform = platform # entity_category if entity_category := SPEC_SERVICE_TRANS_MAP[service_name].get( - 'entity_category', None): + "entity_category", None + ): miot_service.entity_category = entity_category return entity_data def parse_miot_property_entity(self, miot_prop: MIoTSpecProperty) -> bool: if ( miot_prop.platform - or miot_prop.name not in SPEC_PROP_TRANS_MAP['properties'] + or miot_prop.name not in SPEC_PROP_TRANS_MAP["properties"] ): return False prop_name = miot_prop.name - if isinstance(SPEC_PROP_TRANS_MAP['properties'][prop_name], str): - prop_name = SPEC_PROP_TRANS_MAP['properties'][prop_name] - platform = SPEC_PROP_TRANS_MAP['properties'][prop_name]['entity'] + if isinstance(SPEC_PROP_TRANS_MAP["properties"][prop_name], str): + prop_name = SPEC_PROP_TRANS_MAP["properties"][prop_name] + platform = SPEC_PROP_TRANS_MAP["properties"][prop_name]["entity"] # Check prop_access: set = set({}) if miot_prop.readable: - prop_access.add('read') + prop_access.add("read") if miot_prop.writable: - prop_access.add('write') - if prop_access != (SPEC_PROP_TRANS_MAP[ - 'entities'][platform]['access']): + prop_access.add("write") + if prop_access != (SPEC_PROP_TRANS_MAP["entities"][platform]["access"]): return False - if miot_prop.format_.__name__ not in SPEC_PROP_TRANS_MAP[ - 'entities'][platform]['format']: + if ( + miot_prop.format_.__name__ + not in SPEC_PROP_TRANS_MAP["entities"][platform]["format"] + ): return False - miot_prop.device_class = SPEC_PROP_TRANS_MAP['properties'][prop_name][ - 'device_class'] + miot_prop.device_class = SPEC_PROP_TRANS_MAP["properties"][prop_name][ + "device_class" + ] # Optional params - if 'state_class' in SPEC_PROP_TRANS_MAP['properties'][prop_name]: - miot_prop.state_class = SPEC_PROP_TRANS_MAP['properties'][ - prop_name]['state_class'] + if "state_class" in SPEC_PROP_TRANS_MAP["properties"][prop_name]: + miot_prop.state_class = SPEC_PROP_TRANS_MAP["properties"][prop_name][ + "state_class" + ] if ( not miot_prop.external_unit - and 'unit_of_measurement' in SPEC_PROP_TRANS_MAP['properties'][ - prop_name] + and "unit_of_measurement" in SPEC_PROP_TRANS_MAP["properties"][prop_name] ): # Priority: spec_modify.unit > unit_convert > specv2entity.unit - miot_prop.external_unit = SPEC_PROP_TRANS_MAP['properties'][ - prop_name]['unit_of_measurement'] + miot_prop.external_unit = SPEC_PROP_TRANS_MAP["properties"][prop_name][ + "unit_of_measurement" + ] # Priority: default.icon when device_class is set > spec_modify.icon # > icon_convert miot_prop.platform = platform @@ -600,14 +625,12 @@ class MIoTDevice: def spec_transform(self) -> None: """Parse service, property, event, action from device spec.""" # STEP 1: device conversion - device_entity = self.parse_miot_device_entity( - spec_instance=self.spec_instance) + device_entity = self.parse_miot_device_entity(spec_instance=self.spec_instance) if device_entity: self.append_entity(entity_data=device_entity) # STEP 2: service conversion for service in self.spec_instance.services: - service_entity = self.parse_miot_service_entity( - miot_service=service) + service_entity = self.parse_miot_service_entity(miot_service=service) if service_entity: self.append_entity(entity_data=service_entity) # STEP 3.1: property conversion @@ -624,28 +647,28 @@ class MIoTDevice: if not prop.platform: if prop.writable: if prop.format_ == str: - prop.platform = 'text' + prop.platform = "text" elif prop.format_ == bool: - prop.platform = 'switch' + prop.platform = "switch" prop.device_class = SwitchDeviceClass.SWITCH elif prop.value_list: - prop.platform = 'select' + prop.platform = "select" elif prop.value_range: - prop.platform = 'number' + prop.platform = "number" else: # Irregular property will not be transformed. continue elif prop.readable or prop.notifiable: if prop.format_ == bool: - prop.platform = 'binary_sensor' + prop.platform = "binary_sensor" else: - prop.platform = 'sensor' + prop.platform = "sensor" self.append_prop(prop=prop) # STEP 3.2: event conversion for event in service.events: if event.platform: continue - event.platform = 'event' + event.platform = "event" if event.name in SPEC_EVENT_TRANS_MAP: event.device_class = SPEC_EVENT_TRANS_MAP[event.name] self.append_event(event=event) @@ -656,9 +679,9 @@ class MIoTDevice: if action.name in SPEC_ACTION_TRANS_MAP: continue if action.in_: - action.platform = 'notify' + action.platform = "notify" else: - action.platform = 'button' + action.platform = "button" self.append_action(action=action) def unit_convert(self, spec_unit: str) -> Optional[str]: @@ -718,59 +741,59 @@ class MIoTDevice: } """ unit_map = { - 'percentage': PERCENTAGE, - 'weeks': UnitOfTime.WEEKS, - 'days': UnitOfTime.DAYS, - 'hour': UnitOfTime.HOURS, - 'hours': UnitOfTime.HOURS, - 'minutes': UnitOfTime.MINUTES, - 'seconds': UnitOfTime.SECONDS, - 'ms': UnitOfTime.MILLISECONDS, - 'μs': UnitOfTime.MICROSECONDS, - 'celsius': UnitOfTemperature.CELSIUS, - 'fahrenheit': UnitOfTemperature.FAHRENHEIT, - 'kelvin': UnitOfTemperature.KELVIN, - 'μg/m3': CONCENTRATION_MICROGRAMS_PER_CUBIC_METER, - 'mg/m3': CONCENTRATION_MILLIGRAMS_PER_CUBIC_METER, - 'ppm': CONCENTRATION_PARTS_PER_MILLION, - 'ppb': CONCENTRATION_PARTS_PER_BILLION, - 'lux': LIGHT_LUX, - 'pascal': UnitOfPressure.PA, - 'kilopascal': UnitOfPressure.KPA, - 'mmHg': UnitOfPressure.MMHG, - 'bar': UnitOfPressure.BAR, - 'L': UnitOfVolume.LITERS, - 'liter': UnitOfVolume.LITERS, - 'mL': UnitOfVolume.MILLILITERS, - 'km/h': UnitOfSpeed.KILOMETERS_PER_HOUR, - 'm/s': UnitOfSpeed.METERS_PER_SECOND, - 'watt': UnitOfPower.WATT, - 'w': UnitOfPower.WATT, - 'W': UnitOfPower.WATT, - 'kWh': UnitOfEnergy.KILO_WATT_HOUR, - 'A': UnitOfElectricCurrent.AMPERE, - 'mA': UnitOfElectricCurrent.MILLIAMPERE, - 'V': UnitOfElectricPotential.VOLT, - 'mv': UnitOfElectricPotential.MILLIVOLT, - 'mV': UnitOfElectricPotential.MILLIVOLT, - 'cm': UnitOfLength.CENTIMETERS, - 'm': UnitOfLength.METERS, - 'meter': UnitOfLength.METERS, - 'km': UnitOfLength.KILOMETERS, - 'm3/h': UnitOfVolumeFlowRate.CUBIC_METERS_PER_HOUR, - 'gram': UnitOfMass.GRAMS, - 'kilogram': UnitOfMass.KILOGRAMS, - 'dB': SIGNAL_STRENGTH_DECIBELS, - 'arcdegrees': DEGREE, - 'arcdegress': DEGREE, - 'kB': UnitOfInformation.KILOBYTES, - 'MB': UnitOfInformation.MEGABYTES, - 'GB': UnitOfInformation.GIGABYTES, - 'TB': UnitOfInformation.TERABYTES, - 'B/s': UnitOfDataRate.BYTES_PER_SECOND, - 'KB/s': UnitOfDataRate.KILOBYTES_PER_SECOND, - 'MB/s': UnitOfDataRate.MEGABYTES_PER_SECOND, - 'GB/s': UnitOfDataRate.GIGABYTES_PER_SECOND + "percentage": PERCENTAGE, + "weeks": UnitOfTime.WEEKS, + "days": UnitOfTime.DAYS, + "hour": UnitOfTime.HOURS, + "hours": UnitOfTime.HOURS, + "minutes": UnitOfTime.MINUTES, + "seconds": UnitOfTime.SECONDS, + "ms": UnitOfTime.MILLISECONDS, + "μs": UnitOfTime.MICROSECONDS, + "celsius": UnitOfTemperature.CELSIUS, + "fahrenheit": UnitOfTemperature.FAHRENHEIT, + "kelvin": UnitOfTemperature.KELVIN, + "μg/m3": CONCENTRATION_MICROGRAMS_PER_CUBIC_METER, + "mg/m3": CONCENTRATION_MILLIGRAMS_PER_CUBIC_METER, + "ppm": CONCENTRATION_PARTS_PER_MILLION, + "ppb": CONCENTRATION_PARTS_PER_BILLION, + "lux": LIGHT_LUX, + "pascal": UnitOfPressure.PA, + "kilopascal": UnitOfPressure.KPA, + "mmHg": UnitOfPressure.MMHG, + "bar": UnitOfPressure.BAR, + "L": UnitOfVolume.LITERS, + "liter": UnitOfVolume.LITERS, + "mL": UnitOfVolume.MILLILITERS, + "km/h": UnitOfSpeed.KILOMETERS_PER_HOUR, + "m/s": UnitOfSpeed.METERS_PER_SECOND, + "watt": UnitOfPower.WATT, + "w": UnitOfPower.WATT, + "W": UnitOfPower.WATT, + "kWh": UnitOfEnergy.KILO_WATT_HOUR, + "A": UnitOfElectricCurrent.AMPERE, + "mA": UnitOfElectricCurrent.MILLIAMPERE, + "V": UnitOfElectricPotential.VOLT, + "mv": UnitOfElectricPotential.MILLIVOLT, + "mV": UnitOfElectricPotential.MILLIVOLT, + "cm": UnitOfLength.CENTIMETERS, + "m": UnitOfLength.METERS, + "meter": UnitOfLength.METERS, + "km": UnitOfLength.KILOMETERS, + "m3/h": UnitOfVolumeFlowRate.CUBIC_METERS_PER_HOUR, + "gram": UnitOfMass.GRAMS, + "kilogram": UnitOfMass.KILOGRAMS, + "dB": SIGNAL_STRENGTH_DECIBELS, + "arcdegrees": DEGREE, + "arcdegress": DEGREE, + "kB": UnitOfInformation.KILOBYTES, + "MB": UnitOfInformation.MEGABYTES, + "GB": UnitOfInformation.GIGABYTES, + "TB": UnitOfInformation.TERABYTES, + "B/s": UnitOfDataRate.BYTES_PER_SECOND, + "KB/s": UnitOfDataRate.KILOBYTES_PER_SECOND, + "MB/s": UnitOfDataRate.MEGABYTES_PER_SECOND, + "GB/s": UnitOfDataRate.GIGABYTES_PER_SECOND, } # Handle UnitOfConductivity separately since @@ -778,69 +801,77 @@ class MIoTDevice: try: # pylint: disable=import-outside-toplevel from homeassistant.const import UnitOfConductivity # type: ignore - unit_map['μS/cm'] = UnitOfConductivity.MICROSIEMENS_PER_CM - unit_map['mWh'] = UnitOfEnergy.MILLIWATT_HOUR + + unit_map["μS/cm"] = UnitOfConductivity.MICROSIEMENS_PER_CM + unit_map["mWh"] = UnitOfEnergy.MILLIWATT_HOUR except Exception: # pylint: disable=broad-except - unit_map['μS/cm'] = 'μS/cm' - unit_map['mWh'] = 'mWh' + unit_map["μS/cm"] = "μS/cm" + unit_map["mWh"] = "mWh" return unit_map.get(spec_unit, None) def icon_convert(self, spec_unit: str) -> Optional[str]: - if spec_unit in {'percentage'}: - return 'mdi:percent' + if spec_unit in {"percentage"}: + return "mdi:percent" if spec_unit in { - 'weeks', 'days', 'hour', 'hours', 'minutes', 'seconds', 'ms', 'μs' + "weeks", + "days", + "hour", + "hours", + "minutes", + "seconds", + "ms", + "μs", }: - return 'mdi:clock' - if spec_unit in {'celsius'}: - return 'mdi:temperature-celsius' - if spec_unit in {'fahrenheit'}: - return 'mdi:temperature-fahrenheit' - if spec_unit in {'kelvin'}: - return 'mdi:temperature-kelvin' - if spec_unit in {'μg/m3', 'mg/m3', 'ppm', 'ppb'}: - return 'mdi:blur' - if spec_unit in {'lux'}: - return 'mdi:brightness-6' - if spec_unit in {'pascal', 'kilopascal', 'megapascal', 'mmHg', 'bar'}: - return 'mdi:gauge' - if spec_unit in {'watt', 'w', 'W'}: - return 'mdi:flash-triangle' - if spec_unit in {'L', 'mL'}: - return 'mdi:gas-cylinder' - if spec_unit in {'km/h', 'm/s'}: - return 'mdi:speedometer' - if spec_unit in {'kWh'}: - return 'mdi:transmission-tower' - if spec_unit in {'A', 'mA'}: - return 'mdi:current-ac' - if spec_unit in {'V', 'mv', 'mV'}: - return 'mdi:current-dc' - if spec_unit in {'cm', 'm', 'meter', 'km'}: - return 'mdi:ruler' - if spec_unit in {'rgb'}: - return 'mdi:palette' - if spec_unit in {'m3/h', 'L/s'}: - return 'mdi:pipe-leak' - if spec_unit in {'μS/cm'}: - return 'mdi:resistor-nodes' - if spec_unit in {'gram', 'kilogram'}: - return 'mdi:weight' - if spec_unit in {'dB'}: - return 'mdi:signal-distance-variant' - if spec_unit in {'times'}: - return 'mdi:counter' - if spec_unit in {'mmol/L'}: - return 'mdi:dots-hexagon' - if spec_unit in {'kB', 'MB', 'GB'}: - return 'mdi:network-pos' - if spec_unit in {'arcdegress', 'arcdegrees'}: - return 'mdi:angle-obtuse' - if spec_unit in {'B/s', 'KB/s', 'MB/s', 'GB/s'}: - return 'mdi:network' - if spec_unit in {'calorie', 'kCal'}: - return 'mdi:food' + return "mdi:clock" + if spec_unit in {"celsius"}: + return "mdi:temperature-celsius" + if spec_unit in {"fahrenheit"}: + return "mdi:temperature-fahrenheit" + if spec_unit in {"kelvin"}: + return "mdi:temperature-kelvin" + if spec_unit in {"μg/m3", "mg/m3", "ppm", "ppb"}: + return "mdi:blur" + if spec_unit in {"lux"}: + return "mdi:brightness-6" + if spec_unit in {"pascal", "kilopascal", "megapascal", "mmHg", "bar"}: + return "mdi:gauge" + if spec_unit in {"watt", "w", "W"}: + return "mdi:flash-triangle" + if spec_unit in {"L", "mL"}: + return "mdi:gas-cylinder" + if spec_unit in {"km/h", "m/s"}: + return "mdi:speedometer" + if spec_unit in {"kWh"}: + return "mdi:transmission-tower" + if spec_unit in {"A", "mA"}: + return "mdi:current-ac" + if spec_unit in {"V", "mv", "mV"}: + return "mdi:current-dc" + if spec_unit in {"cm", "m", "meter", "km"}: + return "mdi:ruler" + if spec_unit in {"rgb"}: + return "mdi:palette" + if spec_unit in {"m3/h", "L/s"}: + return "mdi:pipe-leak" + if spec_unit in {"μS/cm"}: + return "mdi:resistor-nodes" + if spec_unit in {"gram", "kilogram"}: + return "mdi:weight" + if spec_unit in {"dB"}: + return "mdi:signal-distance-variant" + if spec_unit in {"times"}: + return "mdi:counter" + if spec_unit in {"mmol/L"}: + return "mdi:dots-hexagon" + if spec_unit in {"kB", "MB", "GB"}: + return "mdi:network-pos" + if spec_unit in {"arcdegress", "arcdegrees"}: + return "mdi:angle-obtuse" + if spec_unit in {"B/s", "KB/s", "MB/s", "GB/s"}: + return "mdi:network" + if spec_unit in {"calorie", "kCal"}: + return "mdi:food" return None def __gen_sub_id(self) -> int: @@ -853,12 +884,12 @@ class MIoTDevice: self._online = state == MIoTDeviceState.ONLINE for key, sub_list in self._device_state_sub_list.items(): for handler in sub_list.values(): - self.miot_client.main_loop.call_soon_threadsafe( - handler, key, state) + self.miot_client.main_loop.call_soon_threadsafe(handler, key, state) class MIoTServiceEntity(Entity): """MIoT Service Entity.""" + # pylint: disable=unused-argument # pylint: disable=inconsistent-quotes miot_device: MIoTDevice @@ -869,22 +900,14 @@ class MIoTServiceEntity(Entity): _state_sub_id: int _value_sub_ids: dict[str, int] - _event_occurred_handler: Optional[ - Callable[[MIoTSpecEvent, dict], None]] - _prop_changed_subs: dict[ - MIoTSpecProperty, Callable[[MIoTSpecProperty, Any], None]] + _event_occurred_handler: Optional[Callable[[MIoTSpecEvent, dict], None]] + _prop_changed_subs: dict[MIoTSpecProperty, Callable[[MIoTSpecProperty, Any], None]] _pending_write_ha_state_timer: Optional[asyncio.TimerHandle] - def __init__( - self, miot_device: MIoTDevice, entity_data: MIoTEntityData - ) -> None: - if ( - miot_device is None - or entity_data is None - or entity_data.spec is None - ): - raise MIoTDeviceError('init error, invalid params') + def __init__(self, miot_device: MIoTDevice, entity_data: MIoTEntityData) -> None: + if miot_device is None or entity_data is None or entity_data.spec is None: + raise MIoTDeviceError("init error, invalid params") self.miot_device = miot_device self.entity_data = entity_data self._main_loop = miot_device.miot_client.main_loop @@ -894,14 +917,17 @@ class MIoTServiceEntity(Entity): # Gen entity id if isinstance(self.entity_data.spec, MIoTSpecInstance): self.entity_id = miot_device.gen_device_entity_id(DOMAIN) - self._attr_name = f' {self.entity_data.spec.description_trans}' + self._attr_name = f" {self.entity_data.spec.description_trans}" elif isinstance(self.entity_data.spec, MIoTSpecService): self.entity_id = miot_device.gen_service_entity_id( - DOMAIN, siid=self.entity_data.spec.iid, - description=self.entity_data.spec.description) + DOMAIN, + siid=self.entity_data.spec.iid, + description=self.entity_data.spec.description, + ) self._attr_name = ( - f'{"* "if self.entity_data.spec.proprietary else " "}' - f'{self.entity_data.spec.description_trans}') + f"{'* ' if self.entity_data.spec.proprietary else ' '}" + f"{self.entity_data.spec.description_trans}" + ) self._attr_entity_category = entity_data.spec.entity_category # Set entity attr self._attr_unique_id = self.entity_id @@ -913,14 +939,15 @@ class MIoTServiceEntity(Entity): self._prop_changed_subs = {} self._pending_write_ha_state_timer = None _LOGGER.info( - 'new miot service entity, %s, %s, %s, %s', - self.miot_device.name, self._attr_name, self.entity_data.spec.name, - self.entity_id) + "new miot service entity, %s, %s, %s, %s", + self.miot_device.name, + self._attr_name, + self.entity_data.spec.name, + self.entity_id, + ) @property - def event_occurred_handler( - self - ) -> Optional[Callable[[MIoTSpecEvent, dict], None]]: + def event_occurred_handler(self) -> Optional[Callable[[MIoTSpecEvent, dict], None]]: return self._event_occurred_handler @event_occurred_handler.setter @@ -928,12 +955,10 @@ class MIoTServiceEntity(Entity): self._event_occurred_handler = func def sub_prop_changed( - self, prop: MIoTSpecProperty, - handler: Callable[[MIoTSpecProperty, Any], None] + self, prop: MIoTSpecProperty, handler: Callable[[MIoTSpecProperty, Any], None] ) -> None: if not prop or not handler: - _LOGGER.error( - 'sub_prop_changed error, invalid prop/handler') + _LOGGER.error("sub_prop_changed error, invalid prop/handler") return self._prop_changed_subs[prop] = handler @@ -945,25 +970,28 @@ class MIoTServiceEntity(Entity): return self.miot_device.device_info async def async_added_to_hass(self) -> None: - state_id = 's.0' + state_id = "s.0" if isinstance(self.entity_data.spec, MIoTSpecService): - state_id = f's.{self.entity_data.spec.iid}' + state_id = f"s.{self.entity_data.spec.iid}" self._state_sub_id = self.miot_device.sub_device_state( - key=state_id, handler=self.__on_device_state_changed) + key=state_id, handler=self.__on_device_state_changed + ) # Sub prop for prop in self.entity_data.props: if not prop.notifiable and not prop.readable: continue - key = f'p.{prop.service.iid}.{prop.iid}' + key = f"p.{prop.service.iid}.{prop.iid}" self._value_sub_ids[key] = self.miot_device.sub_property( handler=self.__on_properties_changed, - siid=prop.service.iid, piid=prop.iid) + siid=prop.service.iid, + piid=prop.iid, + ) # Sub event for event in self.entity_data.events: - key = f'e.{event.service.iid}.{event.iid}' + key = f"e.{event.service.iid}.{event.iid}" self._value_sub_ids[key] = self.miot_device.sub_event( - handler=self.__on_event_occurred, - siid=event.service.iid, eiid=event.iid) + handler=self.__on_event_occurred, siid=event.service.iid, eiid=event.iid + ) # Refresh value if self._attr_available: @@ -973,38 +1001,33 @@ class MIoTServiceEntity(Entity): if self._pending_write_ha_state_timer: self._pending_write_ha_state_timer.cancel() self._pending_write_ha_state_timer = None - state_id = 's.0' + state_id = "s.0" if isinstance(self.entity_data.spec, MIoTSpecService): - state_id = f's.{self.entity_data.spec.iid}' - self.miot_device.unsub_device_state( - key=state_id, sub_id=self._state_sub_id) + state_id = f"s.{self.entity_data.spec.iid}" + self.miot_device.unsub_device_state(key=state_id, sub_id=self._state_sub_id) # Unsub prop for prop in self.entity_data.props: if not prop.notifiable and not prop.readable: continue - sub_id = self._value_sub_ids.pop( - f'p.{prop.service.iid}.{prop.iid}', None) + sub_id = self._value_sub_ids.pop(f"p.{prop.service.iid}.{prop.iid}", None) if sub_id: self.miot_device.unsub_property( - siid=prop.service.iid, piid=prop.iid, sub_id=sub_id) + siid=prop.service.iid, piid=prop.iid, sub_id=sub_id + ) # Unsub event for event in self.entity_data.events: - sub_id = self._value_sub_ids.pop( - f'e.{event.service.iid}.{event.iid}', None) + sub_id = self._value_sub_ids.pop(f"e.{event.service.iid}.{event.iid}", None) if sub_id: self.miot_device.unsub_event( - siid=event.service.iid, eiid=event.iid, sub_id=sub_id) + siid=event.service.iid, eiid=event.iid, sub_id=sub_id + ) - def get_map_value( - self, map_: Optional[dict[int, Any]], key: int - ) -> Any: + def get_map_value(self, map_: Optional[dict[int, Any]], key: int) -> Any: if map_ is None: return None return map_.get(key, None) - def get_map_key( - self, map_: Optional[dict[int, Any]], value: Any - ) -> Optional[int]: + def get_map_key(self, map_: Optional[dict[int, Any]], value: Any) -> Optional[int]: if map_ is None: return None for key, value_ in map_.items(): @@ -1015,70 +1038,139 @@ class MIoTServiceEntity(Entity): def get_prop_value(self, prop: Optional[MIoTSpecProperty]) -> Any: if not prop: _LOGGER.error( - 'get_prop_value error, property is None, %s, %s', - self._attr_name, self.entity_id) + "get_prop_value error, property is None, %s, %s", + self._attr_name, + self.entity_id, + ) return None return self._prop_value_map.get(prop, None) - def set_prop_value( - self, prop: Optional[MIoTSpecProperty], value: Any - ) -> None: + def set_prop_value(self, prop: Optional[MIoTSpecProperty], value: Any) -> None: if not prop: _LOGGER.error( - 'set_prop_value error, property is None, %s, %s', - self._attr_name, self.entity_id) + "set_prop_value error, property is None, %s, %s", + self._attr_name, + self.entity_id, + ) return self._prop_value_map[prop] = value async def set_property_async( - self, prop: Optional[MIoTSpecProperty], value: Any, - update_value: bool = True, write_ha_state: bool = True + self, + prop: Optional[MIoTSpecProperty], + value: Any, + update_value: bool = True, + write_ha_state: bool = True, ) -> bool: if not prop: raise RuntimeError( - f'set property failed, property is None, ' - f'{self.entity_id}, {self.name}') + f"set property failed, property is None, {self.entity_id}, {self.name}" + ) value = prop.value_format(value) if prop not in self.entity_data.props: raise RuntimeError( - f'set property failed, unknown property, ' - f'{self.entity_id}, {self.name}, {prop.name}') + f"set property failed, unknown property, " + f"{self.entity_id}, {self.name}, {prop.name}" + ) if not prop.writable: raise RuntimeError( - f'set property failed, not writable, ' - f'{self.entity_id}, {self.name}, {prop.name}') + f"set property failed, not writable, " + f"{self.entity_id}, {self.name}, {prop.name}" + ) try: await self.miot_device.miot_client.set_prop_async( - did=self.miot_device.did, siid=prop.service.iid, - piid=prop.iid, value=value) + did=self.miot_device.did, + siid=prop.service.iid, + piid=prop.iid, + value=value, + ) except MIoTClientError as e: raise RuntimeError( - f'{e}, {self.entity_id}, {self.name}, {prop.name}') from e + f"{e}, {self.entity_id}, {self.name}, {prop.name}" + ) from e if update_value: self._prop_value_map[prop] = value if write_ha_state: self.async_write_ha_state() return True + async def set_properties_async( + self, + set_properties_list: List[Dict[str, Any]], + update_value: bool = True, + write_ha_state: bool = True, + ) -> bool: + for set_property in set_properties_list: + prop = set_property.get("prop") + value = set_property.get("value") + if not prop: + raise RuntimeError( + f"set property failed, property is None, " + f"{self.entity_id}, {self.name}" + ) + set_property["value"] = prop.value_format(value) + if prop not in self.entity_data.props: + raise RuntimeError( + f"set property failed, unknown property, " + f"{self.entity_id}, {self.name}, {prop.name}" + ) + if not prop.writable: + raise RuntimeError( + f"set property failed, not writable, " + f"{self.entity_id}, {self.name}, {prop.name}" + ) + try: + await self.miot_device.miot_client.set_props_async( + [ + { + "did": self.miot_device.did, + "siid": set_property["prop"].service.iid, + "piid": set_property["prop"].iid, + "value": set_property["value"], + } + for set_property in set_properties_list + ] + ) + except MIoTClientError as e: + raise RuntimeError( + f"{e}, {self.entity_id}, {self.name}, {'/'.join([set_property['prop'].name for set_property in set_properties_list])}" + ) from e + if update_value: + for set_property in set_properties_list: + self._prop_value_map[set_property["prop"]] = set_property["value"] + if write_ha_state: + self.async_write_ha_state() + return True + async def get_property_async(self, prop: MIoTSpecProperty) -> Any: if not prop: _LOGGER.error( - 'get property failed, property is None, %s, %s', - self.entity_id, self.name) + "get property failed, property is None, %s, %s", + self.entity_id, + self.name, + ) return None if prop not in self.entity_data.props: _LOGGER.error( - 'get property failed, unknown property, %s, %s, %s', - self.entity_id, self.name, prop.name) + "get property failed, unknown property, %s, %s, %s", + self.entity_id, + self.name, + prop.name, + ) return None if not prop.readable: _LOGGER.error( - 'get property failed, not readable, %s, %s, %s', - self.entity_id, self.name, prop.name) + "get property failed, not readable, %s, %s, %s", + self.entity_id, + self.name, + prop.name, + ) return None result = prop.value_format( await self.miot_device.miot_client.get_prop_async( - did=self.miot_device.did, siid=prop.service.iid, piid=prop.iid)) + did=self.miot_device.did, siid=prop.service.iid, piid=prop.iid + ) + ) if result != self._prop_value_map[prop]: self._prop_value_map[prop] = result self.async_write_ha_state() @@ -1089,25 +1181,27 @@ class MIoTServiceEntity(Entity): ) -> bool: if not action: raise RuntimeError( - f'action failed, action is None, {self.entity_id}, {self.name}') + f"action failed, action is None, {self.entity_id}, {self.name}" + ) try: await self.miot_device.miot_client.action_async( - did=self.miot_device.did, siid=action.service.iid, - aiid=action.iid, in_list=in_list or []) + did=self.miot_device.did, + siid=action.service.iid, + aiid=action.iid, + in_list=in_list or [], + ) except MIoTClientError as e: raise RuntimeError( - f'{e}, {self.entity_id}, {self.name}, {action.name}') from e + f"{e}, {self.entity_id}, {self.name}, {action.name}" + ) from e return True def __on_properties_changed(self, params: dict, ctx: Any) -> None: - _LOGGER.debug('properties changed, %s', params) + _LOGGER.debug("properties changed, %s", params) for prop in self.entity_data.props: - if ( - prop.iid != params['piid'] - or prop.service.iid != params['siid'] - ): + if prop.iid != params["piid"] or prop.service.iid != params["siid"]: continue - value: Any = prop.value_format(params['value']) + value: Any = prop.value_format(params["value"]) self._prop_value_map[prop] = value if prop in self._prop_changed_subs: self._prop_changed_subs[prop](prop, value) @@ -1116,27 +1210,22 @@ class MIoTServiceEntity(Entity): self.async_write_ha_state() def __on_event_occurred(self, params: dict, ctx: Any) -> None: - _LOGGER.debug('event occurred, %s', params) + _LOGGER.debug("event occurred, %s", params) if self._event_occurred_handler is None: return for event in self.entity_data.events: - if ( - event.iid != params['eiid'] - or event.service.iid != params['siid'] - ): + if event.iid != params["eiid"] or event.service.iid != params["siid"]: continue trans_arg = {} - for item in params['arguments']: + for item in params["arguments"]: for prop in event.argument: - if prop.iid == item['piid']: - trans_arg[prop.description_trans] = item['value'] + if prop.iid == item["piid"]: + trans_arg[prop.description_trans] = item["value"] break self._event_occurred_handler(event, trans_arg) break - def __on_device_state_changed( - self, key: str, state: MIoTDeviceState - ) -> None: + def __on_device_state_changed(self, key: str, state: MIoTDeviceState) -> None: state_new = state == MIoTDeviceState.ONLINE if state_new == self._attr_available: return @@ -1151,11 +1240,13 @@ class MIoTServiceEntity(Entity): if not prop.readable: continue self.miot_device.miot_client.request_refresh_prop( - did=self.miot_device.did, siid=prop.service.iid, piid=prop.iid) + did=self.miot_device.did, siid=prop.service.iid, piid=prop.iid + ) if self._pending_write_ha_state_timer: self._pending_write_ha_state_timer.cancel() self._pending_write_ha_state_timer = self._main_loop.call_later( - 1, self.__write_ha_state_handler) + 1, self.__write_ha_state_handler + ) def __write_ha_state_handler(self) -> None: self._pending_write_ha_state_timer = None @@ -1164,6 +1255,7 @@ class MIoTServiceEntity(Entity): class MIoTPropertyEntity(Entity): """MIoT Property Entity.""" + # pylint: disable=unused-argument # pylint: disable=inconsistent-quotes miot_device: MIoTDevice @@ -1182,7 +1274,7 @@ class MIoTPropertyEntity(Entity): def __init__(self, miot_device: MIoTDevice, spec: MIoTSpecProperty) -> None: if miot_device is None or spec is None or spec.service is None: - raise MIoTDeviceError('init error, invalid params') + raise MIoTDeviceError("init error, invalid params") self.miot_device = miot_device self.spec = spec self.service = spec.service @@ -1195,22 +1287,28 @@ class MIoTPropertyEntity(Entity): self._pending_write_ha_state_timer = None # Gen entity_id self.entity_id = self.miot_device.gen_prop_entity_id( - ha_domain=DOMAIN, spec_name=spec.name, - siid=spec.service.iid, piid=spec.iid) + ha_domain=DOMAIN, spec_name=spec.name, siid=spec.service.iid, piid=spec.iid + ) # Set entity attr self._attr_unique_id = self.entity_id self._attr_should_poll = False self._attr_has_entity_name = True self._attr_name = ( - f'{"* "if self.spec.proprietary else " "}' - f'{self.service.description_trans} {spec.description_trans}') + f"{'* ' if self.spec.proprietary else ' '}" + f"{self.service.description_trans} {spec.description_trans}" + ) self._attr_available = miot_device.online _LOGGER.info( - 'new miot property entity, %s, %s, %s, %s, %s, %s, %s', - self.miot_device.name, self._attr_name, spec.platform, - spec.device_class, self.entity_id, self._value_range, - self._value_list) + "new miot property entity, %s, %s, %s, %s, %s, %s, %s", + self.miot_device.name, + self._attr_name, + spec.platform, + spec.device_class, + self.entity_id, + self._value_range, + self._value_list, + ) @property def device_info(self) -> Optional[DeviceInfo]: @@ -1219,12 +1317,13 @@ class MIoTPropertyEntity(Entity): async def async_added_to_hass(self) -> None: # Sub device state changed self._state_sub_id = self.miot_device.sub_device_state( - key=f'{ self.service.iid}.{self.spec.iid}', - handler=self.__on_device_state_changed) + key=f"{self.service.iid}.{self.spec.iid}", + handler=self.__on_device_state_changed, + ) # Sub value changed self._value_sub_id = self.miot_device.sub_property( - handler=self.__on_value_changed, - siid=self.service.iid, piid=self.spec.iid) + handler=self.__on_value_changed, siid=self.service.iid, piid=self.spec.iid + ) # Refresh value if self._attr_available: self.__request_refresh_prop() @@ -1234,11 +1333,11 @@ class MIoTPropertyEntity(Entity): self._pending_write_ha_state_timer.cancel() self._pending_write_ha_state_timer = None self.miot_device.unsub_device_state( - key=f'{ self.service.iid}.{self.spec.iid}', - sub_id=self._state_sub_id) + key=f"{self.service.iid}.{self.spec.iid}", sub_id=self._state_sub_id + ) self.miot_device.unsub_property( - siid=self.service.iid, piid=self.spec.iid, - sub_id=self._value_sub_id) + siid=self.service.iid, piid=self.spec.iid, sub_id=self._value_sub_id + ) def get_vlist_description(self, value: Any) -> Optional[str]: if not self._value_list: @@ -1253,16 +1352,18 @@ class MIoTPropertyEntity(Entity): async def set_property_async(self, value: Any) -> bool: if not self.spec.writable: raise RuntimeError( - f'set property failed, not writable, ' - f'{self.entity_id}, {self.name}') + f"set property failed, not writable, {self.entity_id}, {self.name}" + ) value = self.spec.value_format(value) try: await self.miot_device.miot_client.set_prop_async( - did=self.miot_device.did, siid=self.spec.service.iid, - piid=self.spec.iid, value=value) + did=self.miot_device.did, + siid=self.spec.service.iid, + piid=self.spec.iid, + value=value, + ) except MIoTClientError as e: - raise RuntimeError( - f'{e}, {self.entity_id}, {self.name}') from e + raise RuntimeError(f"{e}, {self.entity_id}, {self.name}") from e self._value = value self.async_write_ha_state() return True @@ -1270,24 +1371,23 @@ class MIoTPropertyEntity(Entity): async def get_property_async(self) -> Any: if not self.spec.readable: _LOGGER.error( - 'get property failed, not readable, %s, %s', - self.entity_id, self.name) + "get property failed, not readable, %s, %s", self.entity_id, self.name + ) return None return self.spec.value_format( await self.miot_device.miot_client.get_prop_async( - did=self.miot_device.did, siid=self.spec.service.iid, - piid=self.spec.iid)) + did=self.miot_device.did, siid=self.spec.service.iid, piid=self.spec.iid + ) + ) def __on_value_changed(self, params: dict, ctx: Any) -> None: - _LOGGER.debug('property changed, %s', params) - self._value = self.spec.value_format(params['value']) + _LOGGER.debug("property changed, %s", params) + self._value = self.spec.value_format(params["value"]) self._value = self.spec.eval_expr(self._value) if not self._pending_write_ha_state_timer: self.async_write_ha_state() - def __on_device_state_changed( - self, key: str, state: MIoTDeviceState - ) -> None: + def __on_device_state_changed(self, key: str, state: MIoTDeviceState) -> None: self._attr_available = state == MIoTDeviceState.ONLINE if not self._attr_available: self.async_write_ha_state() @@ -1298,12 +1398,13 @@ class MIoTPropertyEntity(Entity): def __request_refresh_prop(self) -> None: if self.spec.readable: self.miot_device.miot_client.request_refresh_prop( - did=self.miot_device.did, siid=self.service.iid, - piid=self.spec.iid) + did=self.miot_device.did, siid=self.service.iid, piid=self.spec.iid + ) if self._pending_write_ha_state_timer: self._pending_write_ha_state_timer.cancel() self._pending_write_ha_state_timer = self._main_loop.call_later( - 1, self.__write_ha_state_handler) + 1, self.__write_ha_state_handler + ) def __write_ha_state_handler(self) -> None: self._pending_write_ha_state_timer = None @@ -1312,6 +1413,7 @@ class MIoTPropertyEntity(Entity): class MIoTEventEntity(Entity): """MIoT Event Entity.""" + # pylint: disable=unused-argument # pylint: disable=inconsistent-quotes miot_device: MIoTDevice @@ -1326,22 +1428,23 @@ class MIoTEventEntity(Entity): def __init__(self, miot_device: MIoTDevice, spec: MIoTSpecEvent) -> None: if miot_device is None or spec is None or spec.service is None: - raise MIoTDeviceError('init error, invalid params') + raise MIoTDeviceError("init error, invalid params") self.miot_device = miot_device self.spec = spec self.service = spec.service self._main_loop = miot_device.miot_client.main_loop # Gen entity_id self.entity_id = self.miot_device.gen_event_entity_id( - ha_domain=DOMAIN, spec_name=spec.name, - siid=spec.service.iid, eiid=spec.iid) + ha_domain=DOMAIN, spec_name=spec.name, siid=spec.service.iid, eiid=spec.iid + ) # Set entity attr self._attr_unique_id = self.entity_id self._attr_should_poll = False self._attr_has_entity_name = True self._attr_name = ( - f'{"* "if self.spec.proprietary else " "}' - f'{self.service.description_trans} {spec.description_trans}') + f"{'* ' if self.spec.proprietary else ' '}" + f"{self.service.description_trans} {spec.description_trans}" + ) self._attr_available = miot_device.online self._attr_event_types = [spec.description_trans] @@ -1352,9 +1455,13 @@ class MIoTEventEntity(Entity): self._value_sub_id = 0 _LOGGER.info( - 'new miot event entity, %s, %s, %s, %s, %s', - self.miot_device.name, self._attr_name, spec.platform, - spec.device_class, self.entity_id) + "new miot event entity, %s, %s, %s, %s, %s", + self.miot_device.name, + self._attr_name, + spec.platform, + spec.device_class, + self.entity_id, + ) @property def device_info(self) -> Optional[DeviceInfo]: @@ -1363,20 +1470,21 @@ class MIoTEventEntity(Entity): async def async_added_to_hass(self) -> None: # Sub device state changed self._state_sub_id = self.miot_device.sub_device_state( - key=f'event.{ self.service.iid}.{self.spec.iid}', - handler=self.__on_device_state_changed) + key=f"event.{self.service.iid}.{self.spec.iid}", + handler=self.__on_device_state_changed, + ) # Sub value changed self._value_sub_id = self.miot_device.sub_event( - handler=self.__on_event_occurred, - siid=self.service.iid, eiid=self.spec.iid) + handler=self.__on_event_occurred, siid=self.service.iid, eiid=self.spec.iid + ) async def async_will_remove_from_hass(self) -> None: self.miot_device.unsub_device_state( - key=f'event.{ self.service.iid}.{self.spec.iid}', - sub_id=self._state_sub_id) + key=f"event.{self.service.iid}.{self.spec.iid}", sub_id=self._state_sub_id + ) self.miot_device.unsub_event( - siid=self.service.iid, eiid=self.spec.iid, - sub_id=self._value_sub_id) + siid=self.service.iid, eiid=self.spec.iid, sub_id=self._value_sub_id + ) @abstractmethod def on_event_occurred( @@ -1384,36 +1492,34 @@ class MIoTEventEntity(Entity): ) -> None: ... def __on_event_occurred(self, params: dict, ctx: Any) -> None: - _LOGGER.debug('event occurred, %s', params) + _LOGGER.debug("event occurred, %s", params) trans_arg = {} - for item in params['arguments']: + for item in params["arguments"]: try: - if 'value' not in item: + if "value" not in item: continue - if 'piid' in item: - trans_arg[self._arguments_map[item['piid']]] = item[ - 'value'] - elif ( - isinstance(item['value'], list) - and len(item['value']) == len(self.spec.argument) + if "piid" in item: + trans_arg[self._arguments_map[item["piid"]]] = item["value"] + elif isinstance(item["value"], list) and len(item["value"]) == len( + self.spec.argument ): # Dirty fix for cloud multi-arguments trans_arg = { - prop.description_trans: item['value'][index] + prop.description_trans: item["value"][index] for index, prop in enumerate(self.spec.argument) } break except KeyError as error: _LOGGER.debug( - 'on event msg, invalid args, %s, %s, %s', - self.entity_id, params, error) - self.on_event_occurred( - name=self.spec.description_trans, arguments=trans_arg) + "on event msg, invalid args, %s, %s, %s", + self.entity_id, + params, + error, + ) + self.on_event_occurred(name=self.spec.description_trans, arguments=trans_arg) self.async_write_ha_state() - def __on_device_state_changed( - self, key: str, state: MIoTDeviceState - ) -> None: + def __on_device_state_changed(self, key: str, state: MIoTDeviceState) -> None: state_new = state == MIoTDeviceState.ONLINE if state_new == self._attr_available: return @@ -1423,6 +1529,7 @@ class MIoTEventEntity(Entity): class MIoTActionEntity(Entity): """MIoT Action Entity.""" + # pylint: disable=unused-argument # pylint: disable=inconsistent-quotes miot_device: MIoTDevice @@ -1436,7 +1543,7 @@ class MIoTActionEntity(Entity): def __init__(self, miot_device: MIoTDevice, spec: MIoTSpecAction) -> None: if miot_device is None or spec is None or spec.service is None: - raise MIoTDeviceError('init error, invalid params') + raise MIoTDeviceError("init error, invalid params") self.miot_device = miot_device self.spec = spec self.service = spec.service @@ -1444,21 +1551,26 @@ class MIoTActionEntity(Entity): self._state_sub_id = 0 # Gen entity_id self.entity_id = self.miot_device.gen_action_entity_id( - ha_domain=DOMAIN, spec_name=spec.name, - siid=spec.service.iid, aiid=spec.iid) + ha_domain=DOMAIN, spec_name=spec.name, siid=spec.service.iid, aiid=spec.iid + ) # Set entity attr self._attr_unique_id = self.entity_id self._attr_should_poll = False self._attr_has_entity_name = True self._attr_name = ( - f'{"* "if self.spec.proprietary else " "}' - f'{self.service.description_trans} {spec.description_trans}') + f"{'* ' if self.spec.proprietary else ' '}" + f"{self.service.description_trans} {spec.description_trans}" + ) self._attr_available = miot_device.online _LOGGER.debug( - 'new miot action entity, %s, %s, %s, %s, %s', - self.miot_device.name, self._attr_name, spec.platform, - spec.device_class, self.entity_id) + "new miot action entity, %s, %s, %s, %s, %s", + self.miot_device.name, + self._attr_name, + spec.platform, + spec.device_class, + self.entity_id, + ) @property def device_info(self) -> Optional[DeviceInfo]: @@ -1466,29 +1578,27 @@ class MIoTActionEntity(Entity): async def async_added_to_hass(self) -> None: self._state_sub_id = self.miot_device.sub_device_state( - key=f'a.{ self.service.iid}.{self.spec.iid}', - handler=self.__on_device_state_changed) + key=f"a.{self.service.iid}.{self.spec.iid}", + handler=self.__on_device_state_changed, + ) async def async_will_remove_from_hass(self) -> None: self.miot_device.unsub_device_state( - key=f'a.{ self.service.iid}.{self.spec.iid}', - sub_id=self._state_sub_id) + key=f"a.{self.service.iid}.{self.spec.iid}", sub_id=self._state_sub_id + ) - async def action_async( - self, in_list: Optional[list] = None - ) -> Optional[list]: + async def action_async(self, in_list: Optional[list] = None) -> Optional[list]: try: return await self.miot_device.miot_client.action_async( did=self.miot_device.did, siid=self.service.iid, aiid=self.spec.iid, - in_list=in_list or []) + in_list=in_list or [], + ) except MIoTClientError as e: - raise RuntimeError(f'{e}, {self.entity_id}, {self.name}') from e + raise RuntimeError(f"{e}, {self.entity_id}, {self.name}") from e - def __on_device_state_changed( - self, key: str, state: MIoTDeviceState - ) -> None: + def __on_device_state_changed(self, key: str, state: MIoTDeviceState) -> None: state_new = state == MIoTDeviceState.ONLINE if state_new == self._attr_available: return diff --git a/custom_components/xiaomi_home/miot/miot_lan.py b/custom_components/xiaomi_home/miot/miot_lan.py index 5c56a55..3c1d4c6 100644 --- a/custom_components/xiaomi_home/miot/miot_lan.py +++ b/custom_components/xiaomi_home/miot/miot_lan.py @@ -46,7 +46,6 @@ off Xiaomi or its affiliates' products. MIoT lan device control, only support MIoT SPEC-v2 WiFi devices. """ - import json import time import asyncio @@ -58,7 +57,7 @@ import secrets import socket import struct import threading -from typing import Any, Callable, Coroutine, Optional, final +from typing import Any, Callable, Coroutine, Dict, List, Optional, final from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives import padding from cryptography.hazmat.backends import default_backend @@ -68,8 +67,7 @@ from cryptography.hazmat.primitives import hashes from .miot_error import MIoTError, MIoTLanError, MIoTErrorCode from .miot_network import InterfaceStatus, MIoTNetwork, NetworkInfo from .miot_mdns import MipsService, MipsServiceState -from .common import ( - randomize_float, load_yaml_file, gen_absolute_path, MIoTMatcher) +from .common import randomize_float, load_yaml_file, gen_absolute_path, MIoTMatcher _LOGGER = logging.getLogger(__name__) @@ -130,6 +128,7 @@ class _MIoTLanDeviceState(Enum): class _MIoTLanDevice: """MIoT lan device.""" + # pylint: disable=unused-argument OT_HEADER: int = 0x2131 OT_HEADER_LEN: int = 32 @@ -151,7 +150,7 @@ class _MIoTLanDevice: sub_ts: int supported_wildcard_sub: bool - _manager: 'MIoTLan' + _manager: "MIoTLan" _if_name: Optional[str] _sub_locked: bool _state: _MIoTLanDeviceState @@ -162,14 +161,10 @@ class _MIoTLanDevice: _ka_timer: Optional[asyncio.TimerHandle] _ka_internal: float -# All functions SHOULD be called from the internal loop + # All functions SHOULD be called from the internal loop def __init__( - self, - manager: 'MIoTLan', - did: str, - token: str, - ip: Optional[str] = None + self, manager: "MIoTLan", did: str, token: str, ip: Optional[str] = None ) -> None: self._manager: MIoTLan = manager self.did = did @@ -177,7 +172,8 @@ class _MIoTLanDevice: aes_key: bytes = self.__md5(self.token) aex_iv: bytes = self.__md5(aes_key + self.token) self.cipher = Cipher( - algorithms.AES128(aes_key), modes.CBC(aex_iv), default_backend()) + algorithms.AES128(aes_key), modes.CBC(aex_iv), default_backend() + ) self.ip = ip self.offset = 0 self.subscribed = False @@ -193,17 +189,18 @@ class _MIoTLanDevice: def ka_init_handler() -> None: self._ka_internal = self.KA_INTERVAL_MIN self.__update_keep_alive(state=_MIoTLanDeviceState.DEAD) + self._ka_timer = self._manager.internal_loop.call_later( randomize_float(self.CONSTRUCT_STATE_PENDING, 0.5), - ka_init_handler,) - _LOGGER.debug('miot lan device add, %s', self.did) + ka_init_handler, + ) + _LOGGER.debug("miot lan device add, %s", self.did) def keep_alive(self, ip: str, if_name: str) -> None: self.ip = ip if self._if_name != if_name: self._if_name = if_name - _LOGGER.info( - 'device if_name change, %s, %s', self._if_name, self.did) + _LOGGER.info("device if_name change, %s, %s", self._if_name, self.did) self.__update_keep_alive(state=_MIoTLanDeviceState.FRESH) @property @@ -216,8 +213,9 @@ class _MIoTLanDevice: return self._online = online self._manager.broadcast_device_state( - did=self.did, state={ - 'online': self._online, 'push_available': self.subscribed}) + did=self.did, + state={"online": self._online, "push_available": self.subscribed}, + ) @property def if_name(self) -> Optional[str]: @@ -226,37 +224,37 @@ class _MIoTLanDevice: def gen_packet( self, out_buffer: bytearray, clear_data: dict, did: str, offset: int ) -> int: - clear_bytes = json.dumps(clear_data, ensure_ascii=False).encode('utf-8') + clear_bytes = json.dumps(clear_data, ensure_ascii=False).encode("utf-8") padder = padding.PKCS7(algorithms.AES128.block_size).padder() padded_data = padder.update(clear_bytes) + padder.finalize() if len(padded_data) + self.OT_HEADER_LEN > len(out_buffer): - raise ValueError('rpc too long') + raise ValueError("rpc too long") encryptor = self.cipher.encryptor() encrypted_data = encryptor.update(padded_data) + encryptor.finalize() - data_len: int = len(encrypted_data)+self.OT_HEADER_LEN + data_len: int = len(encrypted_data) + self.OT_HEADER_LEN out_buffer[:32] = struct.pack( - '>HHQI16s', self.OT_HEADER, data_len, int(did), offset, - self.token) + ">HHQI16s", self.OT_HEADER, data_len, int(did), offset, self.token + ) out_buffer[32:data_len] = encrypted_data msg_md5: bytes = self.__md5(out_buffer[0:data_len]) out_buffer[16:32] = msg_md5 return data_len def decrypt_packet(self, encrypted_data: bytearray) -> dict: - data_len: int = struct.unpack('>H', encrypted_data[2:4])[0] + data_len: int = struct.unpack(">H", encrypted_data[2:4])[0] md5_orig: bytes = encrypted_data[16:32] encrypted_data[16:32] = self.token md5_calc: bytes = self.__md5(encrypted_data[0:data_len]) if md5_orig != md5_calc: - raise ValueError(f'invalid md5, {md5_orig}, {md5_calc}') + raise ValueError(f"invalid md5, {md5_orig}, {md5_calc}") decryptor = self.cipher.decryptor() - decrypted_padded_data = decryptor.update( - encrypted_data[32:data_len]) + decryptor.finalize() + decrypted_padded_data = ( + decryptor.update(encrypted_data[32:data_len]) + decryptor.finalize() + ) unpadder = padding.PKCS7(algorithms.AES128.block_size).unpadder() - decrypted_data = unpadder.update( - decrypted_padded_data) + unpadder.finalize() + decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize() # Some device will add a redundant \0 at the end of JSON string - decrypted_data = decrypted_data.rstrip(b'\x00') + decrypted_data = decrypted_data.rstrip(b"\x00") return json.loads(decrypted_data) def subscribe(self) -> None: @@ -268,19 +266,20 @@ class _MIoTLanDevice: self._manager.send2device( did=self.did, msg={ - 'method': 'miIO.sub', - 'params': { - 'version': '2.0', - 'did': self._manager.virtual_did, - 'update_ts': sub_ts, - 'sub_method': '.' - } + "method": "miIO.sub", + "params": { + "version": "2.0", + "did": self._manager.virtual_did, + "update_ts": sub_ts, + "sub_method": ".", + }, }, handler=self.__subscribe_handler, handler_ctx=sub_ts, - timeout_ms=5000) + timeout_ms=5000, + ) except Exception as err: # pylint: disable=broad-exception-caught - _LOGGER.error('subscribe device error, %s', err) + _LOGGER.error("subscribe device error, %s", err) self._sub_locked = False @@ -290,20 +289,22 @@ class _MIoTLanDevice: self._manager.send2device( did=self.did, msg={ - 'method': 'miIO.unsub', - 'params': { - 'version': '2.0', - 'did': self._manager.virtual_did, - 'update_ts': self.sub_ts or 0, - 'sub_method': '.' - } + "method": "miIO.unsub", + "params": { + "version": "2.0", + "did": self._manager.virtual_did, + "update_ts": self.sub_ts or 0, + "sub_method": ".", + }, }, handler=self.__unsubscribe_handler, - timeout_ms=5000) + timeout_ms=5000, + ) self.subscribed = False self._manager.broadcast_device_state( - did=self.did, state={ - 'online': self._online, 'push_available': self.subscribed}) + did=self.did, + state={"online": self._online, "push_available": self.subscribed}, + ) def on_delete(self) -> None: if self._ka_timer: @@ -312,53 +313,54 @@ class _MIoTLanDevice: if self._online_offline_timer: self._online_offline_timer.cancel() self._online_offline_timer = None - _LOGGER.debug('miot lan device delete, %s', self.did) + _LOGGER.debug("miot lan device delete, %s", self.did) def update_info(self, info: dict) -> None: if ( - 'token' in info - and len(info['token']) == 32 - and info['token'].upper() != self.token.hex().upper() + "token" in info + and len(info["token"]) == 32 + and info["token"].upper() != self.token.hex().upper() ): # Update token - self.token = bytes.fromhex(info['token']) + self.token = bytes.fromhex(info["token"]) aes_key: bytes = self.__md5(self.token) aex_iv: bytes = self.__md5(aes_key + self.token) self.cipher = Cipher( - algorithms.AES128(aes_key), - modes.CBC(aex_iv), default_backend()) - _LOGGER.debug('update token, %s', self.did) + algorithms.AES128(aes_key), modes.CBC(aex_iv), default_backend() + ) + _LOGGER.debug("update token, %s", self.did) def __subscribe_handler(self, msg: dict, sub_ts: int) -> None: if ( - 'result' not in msg - or 'code' not in msg['result'] - or msg['result']['code'] != 0 + "result" not in msg + or "code" not in msg["result"] + or msg["result"]["code"] != 0 ): - _LOGGER.error('subscribe device error, %s, %s', self.did, msg) + _LOGGER.error("subscribe device error, %s, %s", self.did, msg) return self.subscribed = True self.sub_ts = sub_ts self._manager.broadcast_device_state( - did=self.did, state={ - 'online': self._online, 'push_available': self.subscribed}) - _LOGGER.info('subscribe success, %s, %s', self._if_name, self.did) + did=self.did, + state={"online": self._online, "push_available": self.subscribed}, + ) + _LOGGER.info("subscribe success, %s, %s", self._if_name, self.did) def __unsubscribe_handler(self, msg: dict, ctx: Any) -> None: if ( - 'result' not in msg - or 'code' not in msg['result'] - or msg['result']['code'] != 0 + "result" not in msg + or "code" not in msg["result"] + or msg["result"]["code"] != 0 ): - _LOGGER.error('unsubscribe device error, %s, %s', self.did, msg) + _LOGGER.error("unsubscribe device error, %s, %s", self.did, msg) return - _LOGGER.info('unsubscribe success, %s, %s', self._if_name, self.did) + _LOGGER.info("unsubscribe success, %s, %s", self._if_name, self.did) def __update_keep_alive(self, state: _MIoTLanDeviceState) -> None: last_state: _MIoTLanDeviceState = self._state self._state = state if self._state != _MIoTLanDeviceState.FRESH: - _LOGGER.debug('device status, %s, %s', self.did, self._state) + _LOGGER.debug("device status, %s, %s", self.did, self._state) if self._ka_timer: self._ka_timer.cancel() self._ka_timer = None @@ -368,24 +370,27 @@ class _MIoTLanDevice: self._ka_internal = self.KA_INTERVAL_MIN self.__change_online(True) self._ka_timer = self._manager.internal_loop.call_later( - self.__get_next_ka_timeout(), self.__update_keep_alive, - _MIoTLanDeviceState.PING1) + self.__get_next_ka_timeout(), + self.__update_keep_alive, + _MIoTLanDeviceState.PING1, + ) case ( - _MIoTLanDeviceState.PING1 - | _MIoTLanDeviceState.PING2 - | _MIoTLanDeviceState.PING3 + _MIoTLanDeviceState.PING1 + | _MIoTLanDeviceState.PING2 + | _MIoTLanDeviceState.PING3 ): # Set the timer first to avoid Any early returns self._ka_timer = self._manager.internal_loop.call_later( - self.FAST_PING_INTERVAL, self.__update_keep_alive, - _MIoTLanDeviceState(state.value+1)) + self.FAST_PING_INTERVAL, + self.__update_keep_alive, + _MIoTLanDeviceState(state.value + 1), + ) # Fast ping if self._if_name is None: - _LOGGER.error( - 'if_name is Not set for device, %s', self.did) + _LOGGER.error("if_name is Not set for device, %s", self.did) return if self.ip is None: - _LOGGER.error('ip is Not set for device, %s', self.did) + _LOGGER.error("ip is Not set for device, %s", self.did) return self._manager.ping(if_name=self._if_name, target_ip=self.ip) case _MIoTLanDeviceState.DEAD: @@ -393,16 +398,16 @@ class _MIoTLanDevice: self._ka_internal = self.KA_INTERVAL_MIN self.__change_online(False) case _: - _LOGGER.error('invalid state, %s', state) + _LOGGER.error("invalid state, %s", state) def __get_next_ka_timeout(self) -> float: - self._ka_internal = min(self._ka_internal*2, self.KA_INTERVAL_MAX) + self._ka_internal = min(self._ka_internal * 2, self.KA_INTERVAL_MAX) return randomize_float(self._ka_internal, 0.1) def __change_online(self, online: bool) -> None: - _LOGGER.info('change online, %s, %s', self.did, online) + _LOGGER.info("change online, %s, %s", self.did, online) ts_now: int = int(time.time()) - self._online_offline_history.append({'ts': ts_now, 'online': online}) + self._online_offline_history.append({"ts": ts_now, "online": online}) if len(self._online_offline_history) > self.NETWORK_UNSTABLE_CNT_TH: self._online_offline_history.pop(0) if self._online_offline_timer: @@ -411,22 +416,19 @@ class _MIoTLanDevice: if not online: self.online = False else: - if ( - len(self._online_offline_history) < self.NETWORK_UNSTABLE_CNT_TH - or ( - ts_now - self._online_offline_history[0]['ts'] > - self.NETWORK_UNSTABLE_TIME_TH) + if len(self._online_offline_history) < self.NETWORK_UNSTABLE_CNT_TH or ( + ts_now - self._online_offline_history[0]["ts"] + > self.NETWORK_UNSTABLE_TIME_TH ): self.online = True else: - _LOGGER.info('unstable device detected, %s', self.did) - self._online_offline_timer = ( - self._manager.internal_loop.call_later( - self.NETWORK_UNSTABLE_RESUME_TH, - self.__online_resume_handler)) + _LOGGER.info("unstable device detected, %s", self.did) + self._online_offline_timer = self._manager.internal_loop.call_later( + self.NETWORK_UNSTABLE_RESUME_TH, self.__online_resume_handler + ) def __online_resume_handler(self) -> None: - _LOGGER.info('unstable resume threshold past, %s', self.did) + _LOGGER.info("unstable resume threshold past, %s", self.did) self.online = True def __md5(self, data: bytes) -> bytes: @@ -437,9 +439,10 @@ class _MIoTLanDevice: class MIoTLan: """MIoT lan device control.""" + # pylint: disable=unused-argument # pylint: disable=inconsistent-quotes - OT_HEADER: bytes = b'\x21\x31' + OT_HEADER: bytes = b"\x21\x31" OT_PORT: int = 54321 OT_PROBE_LEN: int = 32 OT_MSG_LEN: int = 1400 @@ -448,7 +451,7 @@ class MIoTLan: OT_PROBE_INTERVAL_MIN: float = 5 OT_PROBE_INTERVAL_MAX: float = 45 - PROFILE_MODELS_FILE: str = 'lan/profile_models.yaml' + PROFILE_MODELS_FILE: str = "lan/profile_models.yaml" _main_loop: asyncio.AbstractEventLoop _net_ifs: set[str] @@ -483,7 +486,7 @@ class MIoTLan: _init_lock: asyncio.Lock _init_done: bool -# The following should be called from the main loop + # The following should be called from the main loop def __init__( self, @@ -492,32 +495,33 @@ class MIoTLan: mips_service: MipsService, enable_subscribe: bool = False, virtual_did: Optional[int] = None, - loop: Optional[asyncio.AbstractEventLoop] = None + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: if not network: - raise ValueError('network is required') + raise ValueError("network is required") if not mips_service: - raise ValueError('mips_service is required') + raise ValueError("mips_service is required") self._main_loop = loop or asyncio.get_event_loop() self._net_ifs = set(net_ifs) self._network = network self._network.sub_network_info( - key='miot_lan', - handler=self.__on_network_info_change_external_async) + key="miot_lan", handler=self.__on_network_info_change_external_async + ) self._mips_service = mips_service self._mips_service.sub_service_change( - key='miot_lan', group_id='*', - handler=self.__on_mips_service_change) + key="miot_lan", group_id="*", handler=self.__on_mips_service_change + ) self._enable_subscribe = enable_subscribe self._virtual_did = ( - str(virtual_did) if (virtual_did is not None) - else str(secrets.randbits(64))) + str(virtual_did) if (virtual_did is not None) else str(secrets.randbits(64)) + ) # Init socket probe message probe_bytes = bytearray(self.OT_PROBE_LEN) probe_bytes[:20] = ( - b'!1\x00\x20\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFFMDID') - probe_bytes[20:28] = struct.pack('>Q', int(self._virtual_did)) - probe_bytes[28:32] = b'\x00\x00\x00\x00' + b"!1\x00\x20\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xffMDID" + ) + probe_bytes[20:28] = struct.pack(">Q", int(self._virtual_did)) + probe_bytes[28:32] = b"\x00\x00\x00\x00" self._probe_msg = bytes(probe_bytes) self._read_buffer = bytearray(self.OT_MSG_LEN) self._write_buffer = bytearray(self.OT_MSG_LEN) @@ -528,7 +532,7 @@ class MIoTLan: self._local_port = None self._scan_timer = None self._last_scan_interval = None - self._msg_id_counter = int(random.random()*0x7FFFFFFF) + self._msg_id_counter = int(random.random() * 0x7FFFFFFF) self._pending_requests = {} self._device_msg_matcher = MIoTMatcher() self._device_state_sub_map = {} @@ -540,20 +544,17 @@ class MIoTLan: self._init_lock = asyncio.Lock() self._init_done = False - if ( - len(self._mips_service.get_services()) == 0 - and len(self._net_ifs) > 0 - ): - _LOGGER.info('no central hub gateway service, init miot lan') + if len(self._mips_service.get_services()) == 0 and len(self._net_ifs) > 0: + _LOGGER.info("no central hub gateway service, init miot lan") self._main_loop.call_later( - 0, lambda: self._main_loop.create_task( - self.init_async())) + 0, lambda: self._main_loop.create_task(self.init_async()) + ) def __assert_service_ready(self) -> None: if not self._init_done: raise MIoTLanError( - 'MIoT lan is not ready', - MIoTErrorCode.CODE_LAN_UNAVAILABLE) + "MIoT lan is not ready", MIoTErrorCode.CODE_LAN_UNAVAILABLE + ) @property def virtual_did(self) -> str: @@ -571,55 +572,57 @@ class MIoTLan: # Avoid race condition async with self._init_lock: if self._init_done: - _LOGGER.info('miot lan already init') + _LOGGER.info("miot lan already init") return if len(self._net_ifs) == 0: - _LOGGER.info('no net_ifs') + _LOGGER.info("no net_ifs") return if not any(self._lan_ctrl_vote_map.values()): - _LOGGER.info('no vote for lan ctrl') + _LOGGER.info("no vote for lan ctrl") return if len(self._mips_service.get_services()) > 0: - _LOGGER.info('central hub gateway service exist') + _LOGGER.info("central hub gateway service exist") return for if_name in list(self._network.network_info.keys()): self._available_net_ifs.add(if_name) if len(self._available_net_ifs) == 0: - _LOGGER.info('no available net_ifs') + _LOGGER.info("no available net_ifs") return if self._net_ifs.isdisjoint(self._available_net_ifs): - _LOGGER.info('no valid net_ifs') + _LOGGER.info("no valid net_ifs") return try: self._profile_models = await self._main_loop.run_in_executor( - None, load_yaml_file, - gen_absolute_path(self.PROFILE_MODELS_FILE)) + None, load_yaml_file, gen_absolute_path(self.PROFILE_MODELS_FILE) + ) except Exception as err: # pylint: disable=broad-exception-caught - _LOGGER.error('load profile models error, %s', err) + _LOGGER.error("load profile models error, %s", err) self._profile_models = {} self._internal_loop = asyncio.new_event_loop() # All tasks meant for the internal loop should happen in this thread self._thread = threading.Thread(target=self.__internal_loop_thread) - self._thread.name = 'miot_lan' + self._thread.name = "miot_lan" self._thread.daemon = True self._thread.start() self._init_done = True for handler in list(self._lan_state_sub_map.values()): self._main_loop.create_task(handler(True)) _LOGGER.info( - 'miot lan init, %s ,%s', self._net_ifs, self._available_net_ifs) + "miot lan init, %s ,%s", self._net_ifs, self._available_net_ifs + ) def __internal_loop_thread(self) -> None: - _LOGGER.info('miot lan thread start') + _LOGGER.info("miot lan thread start") self.__init_socket() self._scan_timer = self._internal_loop.call_later( - int(3*random.random()), self.__scan_devices) + int(3 * random.random()), self.__scan_devices + ) self._internal_loop.run_forever() - _LOGGER.info('miot lan thread exit') + _LOGGER.info("miot lan thread exit") async def deinit_async(self) -> None: if not self._init_done: - _LOGGER.info('miot lan not init') + _LOGGER.info("miot lan not init") return self._init_done = False self._internal_loop.call_soon_threadsafe(self.__deinit) @@ -632,19 +635,19 @@ class MIoTLan: self._local_port = None self._scan_timer = None self._last_scan_interval = None - self._msg_id_counter = int(random.random()*0x7FFFFFFF) + self._msg_id_counter = int(random.random() * 0x7FFFFFFF) self._pending_requests = {} self._device_msg_matcher = MIoTMatcher() self._device_state_sub_map = {} self._reply_msg_buffer = {} for handler in list(self._lan_state_sub_map.values()): self._main_loop.create_task(handler(False)) - _LOGGER.info('miot lan deinit') + _LOGGER.info("miot lan deinit") async def update_net_ifs_async(self, net_ifs: list[str]) -> None: - _LOGGER.info('update net_ifs, %s', net_ifs) + _LOGGER.info("update net_ifs, %s", net_ifs) if not isinstance(net_ifs, list): - _LOGGER.error('invalid net_ifs, %s', net_ifs) + _LOGGER.error("invalid net_ifs, %s", net_ifs) return if len(net_ifs) == 0: # Deinit lan @@ -655,7 +658,7 @@ class MIoTLan: for if_name in list(self._network.network_info.keys()): available_net_ifs.add(if_name) if set(net_ifs).isdisjoint(available_net_ifs): - _LOGGER.error('no valid net_ifs, %s', net_ifs) + _LOGGER.error("no valid net_ifs, %s", net_ifs) await self.deinit_async() self._net_ifs = set(net_ifs) self._available_net_ifs = available_net_ifs @@ -664,12 +667,10 @@ class MIoTLan: self._net_ifs = set(net_ifs) await self.init_async() return - self._internal_loop.call_soon_threadsafe( - self.__update_net_ifs, - net_ifs) + self._internal_loop.call_soon_threadsafe(self.__update_net_ifs, net_ifs) async def vote_for_lan_ctrl_async(self, key: str, vote: bool) -> None: - _LOGGER.info('vote for lan ctrl, %s, %s', key, vote) + _LOGGER.info("vote for lan ctrl, %s, %s", key, vote) self._lan_ctrl_vote_map[key] = vote if not any(self._lan_ctrl_vote_map.values()): await self.deinit_async() @@ -677,33 +678,29 @@ class MIoTLan: await self.init_async() async def update_subscribe_option(self, enable_subscribe: bool) -> None: - _LOGGER.info('update subscribe option, %s', enable_subscribe) + _LOGGER.info("update subscribe option, %s", enable_subscribe) if not self._init_done: self._enable_subscribe = enable_subscribe return self._internal_loop.call_soon_threadsafe( - self.__update_subscribe_option, - {'enable_subscribe': enable_subscribe}) + self.__update_subscribe_option, {"enable_subscribe": enable_subscribe} + ) def update_devices(self, devices: dict[str, dict]) -> bool: - _LOGGER.info('update devices, %s', devices) + _LOGGER.info("update devices, %s", devices) if not self._init_done: return False - self._internal_loop.call_soon_threadsafe( - self.__update_devices, devices) + self._internal_loop.call_soon_threadsafe(self.__update_devices, devices) return True def delete_devices(self, devices: list[str]) -> bool: - _LOGGER.info('delete devices, %s', devices) + _LOGGER.info("delete devices, %s", devices) if not self._init_done: return False - self._internal_loop.call_soon_threadsafe( - self.__delete_devices, devices) + self._internal_loop.call_soon_threadsafe(self.__delete_devices, devices) return True - def sub_lan_state( - self, key: str, handler: Callable[[bool], Coroutine] - ) -> None: + def sub_lan_state(self, key: str, handler: Callable[[bool], Coroutine]) -> None: self._lan_state_sub_map[key] = handler def unsub_lan_state(self, key: str) -> None: @@ -711,15 +708,17 @@ class MIoTLan: @final def sub_device_state( - self, key: str, handler: Callable[[str, dict, Any], Coroutine], - handler_ctx: Any = None + self, + key: str, + handler: Callable[[str, dict, Any], Coroutine], + handler_ctx: Any = None, ) -> bool: if not self._init_done: return False self._internal_loop.call_soon_threadsafe( self.__sub_device_state, - _MIoTLanSubDeviceData( - key=key, handler=handler, handler_ctx=handler_ctx)) + _MIoTLanSubDeviceData(key=key, handler=handler, handler_ctx=handler_ctx), + ) return True @final @@ -727,7 +726,8 @@ class MIoTLan: if not self._init_done: return False self._internal_loop.call_soon_threadsafe( - self.__unsub_device_state, _MIoTLanUnsubDeviceData(key=key)) + self.__unsub_device_state, _MIoTLanUnsubDeviceData(key=key) + ) return True @final @@ -737,38 +737,33 @@ class MIoTLan: handler: Callable[[dict, Any], None], siid: Optional[int] = None, piid: Optional[int] = None, - handler_ctx: Any = None + handler_ctx: Any = None, ) -> bool: if not self._init_done: return False if not self._enable_subscribe: return False - key = ( - f'{did}/p/' - f'{"#" if siid is None or piid is None else f"{siid}/{piid}"}') + key = f"{did}/p/{'#' if siid is None or piid is None else f'{siid}/{piid}'}" self._internal_loop.call_soon_threadsafe( self.__sub_broadcast, _MIoTLanRegisterBroadcastData( - key=key, handler=handler, handler_ctx=handler_ctx)) + key=key, handler=handler, handler_ctx=handler_ctx + ), + ) return True @final def unsub_prop( - self, - did: str, - siid: Optional[int] = None, - piid: Optional[int] = None + self, did: str, siid: Optional[int] = None, piid: Optional[int] = None ) -> bool: if not self._init_done: return False if not self._enable_subscribe: return False - key = ( - f'{did}/p/' - f'{"#" if siid is None or piid is None else f"{siid}/{piid}"}') + key = f"{did}/p/{'#' if siid is None or piid is None else f'{siid}/{piid}'}" self._internal_loop.call_soon_threadsafe( - self.__unsub_broadcast, - _MIoTLanUnregisterBroadcastData(key=key)) + self.__unsub_broadcast, _MIoTLanUnregisterBroadcastData(key=key) + ) return True @final @@ -778,38 +773,33 @@ class MIoTLan: handler: Callable[[dict, Any], None], siid: Optional[int] = None, eiid: Optional[int] = None, - handler_ctx: Any = None + handler_ctx: Any = None, ) -> bool: if not self._init_done: return False if not self._enable_subscribe: return False - key = ( - f'{did}/e/' - f'{"#" if siid is None or eiid is None else f"{siid}/{eiid}"}') + key = f"{did}/e/{'#' if siid is None or eiid is None else f'{siid}/{eiid}'}" self._internal_loop.call_soon_threadsafe( self.__sub_broadcast, _MIoTLanRegisterBroadcastData( - key=key, handler=handler, handler_ctx=handler_ctx)) + key=key, handler=handler, handler_ctx=handler_ctx + ), + ) return True @final def unsub_event( - self, - did: str, - siid: Optional[int] = None, - eiid: Optional[int] = None + self, did: str, siid: Optional[int] = None, eiid: Optional[int] = None ) -> bool: if not self._init_done: return False if not self._enable_subscribe: return False - key = ( - f'{did}/e/' - f'{"#" if siid is None or eiid is None else f"{siid}/{eiid}"}') + key = f"{did}/e/{'#' if siid is None or eiid is None else f'{siid}/{eiid}'}" self._internal_loop.call_soon_threadsafe( - self.__unsub_broadcast, - _MIoTLanUnregisterBroadcastData(key=key)) + self.__unsub_broadcast, _MIoTLanUnregisterBroadcastData(key=key) + ) return True @final @@ -818,103 +808,131 @@ class MIoTLan: ) -> Any: self.__assert_service_ready() result_obj = await self.__call_api_async( - did=did, msg={ - 'method': 'get_properties', - 'params': [{'did': did, 'siid': siid, 'piid': piid}] - }, timeout_ms=timeout_ms) + did=did, + msg={ + "method": "get_properties", + "params": [{"did": did, "siid": siid, "piid": piid}], + }, + timeout_ms=timeout_ms, + ) if ( - result_obj and 'result' in result_obj - and len(result_obj['result']) == 1 - and 'did' in result_obj['result'][0] - and result_obj['result'][0]['did'] == did + result_obj + and "result" in result_obj + and len(result_obj["result"]) == 1 + and "did" in result_obj["result"][0] + and result_obj["result"][0]["did"] == did ): - return result_obj['result'][0].get('value', None) + return result_obj["result"][0].get("value", None) return None @final async def set_prop_async( - self, did: str, siid: int, piid: int, value: Any, - timeout_ms: int = 10000 + self, did: str, siid: int, piid: int, value: Any, timeout_ms: int = 10000 ) -> dict: self.__assert_service_ready() result_obj = await self.__call_api_async( - did=did, msg={ - 'method': 'set_properties', - 'params': [{ - 'did': did, 'siid': siid, 'piid': piid, 'value': value}] - }, timeout_ms=timeout_ms) + did=did, + msg={ + "method": "set_properties", + "params": [{"did": did, "siid": siid, "piid": piid, "value": value}], + }, + timeout_ms=timeout_ms, + ) if result_obj: if ( - 'result' in result_obj - and len(result_obj['result']) == 1 - and 'did' in result_obj['result'][0] - and result_obj['result'][0]['did'] == did - and 'code' in result_obj['result'][0] + "result" in result_obj + and len(result_obj["result"]) == 1 + and "did" in result_obj["result"][0] + and result_obj["result"][0]["did"] == did + and "code" in result_obj["result"][0] ): - return result_obj['result'][0] - if 'code' in result_obj: + return result_obj["result"][0] + if "code" in result_obj: return result_obj - raise MIoTError('Invalid result', MIoTErrorCode.CODE_INTERNAL_ERROR) + raise MIoTError("Invalid result", MIoTErrorCode.CODE_INTERNAL_ERROR) + + @final + async def set_props_async( + self, did: str, props_list: List[Dict[str, Any]], timeout_ms: int = 10000 + ) -> dict: + self.__assert_service_ready() + result_obj = await self.__call_api_async( + did=did, + msg={ + "method": "set_properties", + "params": props_list, + }, + timeout_ms=timeout_ms, + ) + if result_obj: + if ( + "result" in result_obj + and len(result_obj["result"]) == len(props_list) + and result_obj["result"][0].get("did") == did + and all("code" in item for item in result_obj["result"]) + ): + return result_obj["result"] + if "error" in result_obj: + return result_obj["error"] + return { + "code": MIoTErrorCode.CODE_INTERNAL_ERROR.value, + "message": "Invalid result", + } @final async def action_async( - self, did: str, siid: int, aiid: int, in_list: list, - timeout_ms: int = 10000 + self, did: str, siid: int, aiid: int, in_list: list, timeout_ms: int = 10000 ) -> dict: self.__assert_service_ready() result_obj = await self.__call_api_async( - did=did, msg={ - 'method': 'action', - 'params': { - 'did': did, 'siid': siid, 'aiid': aiid, 'in': in_list} - }, timeout_ms=timeout_ms) + did=did, + msg={ + "method": "action", + "params": {"did": did, "siid": siid, "aiid": aiid, "in": in_list}, + }, + timeout_ms=timeout_ms, + ) if result_obj: - if 'result' in result_obj and 'code' in result_obj['result']: - return result_obj['result'] - if 'code' in result_obj: + if "result" in result_obj and "code" in result_obj["result"]: + return result_obj["result"] + if "code" in result_obj: return result_obj - raise MIoTError('Invalid result', MIoTErrorCode.CODE_INTERNAL_ERROR) + raise MIoTError("Invalid result", MIoTErrorCode.CODE_INTERNAL_ERROR) @final - async def get_dev_list_async( - self, timeout_ms: int = 10000 - ) -> dict[str, dict]: + async def get_dev_list_async(self, timeout_ms: int = 10000) -> dict[str, dict]: if not self._init_done: return {} def get_device_list_handler(msg: dict, fut: asyncio.Future): - self._main_loop.call_soon_threadsafe( - fut.set_result, msg) + self._main_loop.call_soon_threadsafe(fut.set_result, msg) fut: asyncio.Future = self._main_loop.create_future() self._internal_loop.call_soon_threadsafe( self.__get_dev_list, _MIoTLanGetDevListData( - handler=get_device_list_handler, - handler_ctx=fut, - timeout_ms=timeout_ms)) + handler=get_device_list_handler, handler_ctx=fut, timeout_ms=timeout_ms + ), + ) return await fut async def __call_api_async( self, did: str, msg: dict, timeout_ms: int = 10000 ) -> dict: def call_api_handler(msg: dict, fut: asyncio.Future): - self._main_loop.call_soon_threadsafe( - fut.set_result, msg) + self._main_loop.call_soon_threadsafe(fut.set_result, msg) fut: asyncio.Future = self._main_loop.create_future() self._internal_loop.call_soon_threadsafe( - self.__call_api, did, msg, call_api_handler, fut, timeout_ms) + self.__call_api, did, msg, call_api_handler, fut, timeout_ms + ) return await fut async def __on_network_info_change_external_async( - self, - status: InterfaceStatus, - info: NetworkInfo + self, status: InterfaceStatus, info: NetworkInfo ) -> None: - _LOGGER.info( - 'on network info change, status: %s, info: %s', status, info) + _LOGGER.info("on network info change, status: %s, info: %s", status, info) available_net_ifs = set() for if_name in list(self._network.network_info.keys()): available_net_ifs.add(if_name) @@ -923,7 +941,7 @@ class MIoTLan: self._available_net_ifs = available_net_ifs return if self._net_ifs.isdisjoint(available_net_ifs): - _LOGGER.info('no valid net_ifs') + _LOGGER.info("no valid net_ifs") await self.deinit_async() self._available_net_ifs = available_net_ifs return @@ -933,62 +951,65 @@ class MIoTLan: return self._internal_loop.call_soon_threadsafe( self.__on_network_info_change, - _MIoTLanNetworkUpdateData(status=status, if_name=info.name)) + _MIoTLanNetworkUpdateData(status=status, if_name=info.name), + ) async def __on_mips_service_change( - self, group_id: str, state: MipsServiceState, data: dict + self, group_id: str, state: MipsServiceState, data: dict ) -> None: - _LOGGER.info( - 'on mips service change, %s, %s, %s', group_id, state, data) + _LOGGER.info("on mips service change, %s, %s, %s", group_id, state, data) if len(self._mips_service.get_services()) > 0: - _LOGGER.info('find central service, deinit miot lan') + _LOGGER.info("find central service, deinit miot lan") await self.deinit_async() else: - _LOGGER.info('no central service, init miot lan') + _LOGGER.info("no central service, init miot lan") await self.init_async() -# The following methods SHOULD ONLY be called in the internal loop + # The following methods SHOULD ONLY be called in the internal loop def ping(self, if_name: Optional[str], target_ip: str) -> None: if not target_ip: return self.__sendto( - if_name=if_name, data=self._probe_msg, address=target_ip, - port=self.OT_PORT) + if_name=if_name, data=self._probe_msg, address=target_ip, port=self.OT_PORT + ) def send2device( - self, did: str, + self, + did: str, msg: dict, handler: Optional[Callable[[dict, Any], None]] = None, handler_ctx: Any = None, - timeout_ms: Optional[int] = None + timeout_ms: Optional[int] = None, ) -> None: if timeout_ms and not handler: - raise ValueError('handler is required when timeout_ms is set') + raise ValueError("handler is required when timeout_ms is set") device: Optional[_MIoTLanDevice] = self._lan_devices.get(did) if not device: - raise ValueError('invalid device') + raise ValueError("invalid device") if not device.cipher: - raise ValueError('invalid device cipher') + raise ValueError("invalid device cipher") if not device.if_name: - raise ValueError('invalid device if_name') + raise ValueError("invalid device if_name") if not device.ip: - raise ValueError('invalid device ip') - in_msg = {'id': self.__gen_msg_id(), **msg} + raise ValueError("invalid device ip") + in_msg = {"id": self.__gen_msg_id(), **msg} msg_len = device.gen_packet( out_buffer=self._write_buffer, clear_data=in_msg, did=did, - offset=int(time.time())-device.offset) + offset=int(time.time()) - device.offset, + ) return self.__make_request( - msg_id=in_msg['id'], - msg=self._write_buffer[0: msg_len], + msg_id=in_msg["id"], + msg=self._write_buffer[0:msg_len], if_name=device.if_name, ip=device.ip, handler=handler, handler_ctx=handler_ctx, - timeout_ms=timeout_ms) + timeout_ms=timeout_ms, + ) def __make_request( self, @@ -998,25 +1019,24 @@ class MIoTLan: ip: str, handler: Optional[Callable[[dict, Any], None]], handler_ctx: Any = None, - timeout_ms: Optional[int] = None + timeout_ms: Optional[int] = None, ) -> None: def request_timeout_handler(req_data: _MIoTLanRequestData): self._pending_requests.pop(req_data.msg_id, None) if req_data and req_data.handler: - req_data.handler({ - 'code': MIoTErrorCode.CODE_TIMEOUT.value, - 'error': 'timeout'}, - req_data.handler_ctx) + req_data.handler( + {"code": MIoTErrorCode.CODE_TIMEOUT.value, "error": "timeout"}, + req_data.handler_ctx, + ) timer: Optional[asyncio.TimerHandle] = None request_data = _MIoTLanRequestData( - msg_id=msg_id, - handler=handler, - handler_ctx=handler_ctx, - timeout=timer) + msg_id=msg_id, handler=handler, handler_ctx=handler_ctx, timeout=timer + ) if timeout_ms: timer = self._internal_loop.call_later( - timeout_ms/1000, request_timeout_handler, request_data) + timeout_ms / 1000, request_timeout_handler, request_data + ) request_data.timeout = timer self._pending_requests[msg_id] = request_data self.__sendto(if_name=if_name, data=msg, address=ip, port=self.OT_PORT) @@ -1025,11 +1045,12 @@ class MIoTLan: for handler in self._device_state_sub_map.values(): self._main_loop.call_soon_threadsafe( self._main_loop.create_task, - handler.handler(did, state, handler.handler_ctx)) + handler.handler(did, state, handler.handler_ctx), + ) def __gen_msg_id(self) -> int: if not self._msg_id_counter: - self._msg_id_counter = int(random.random()*0x7FFFFFFF) + self._msg_id_counter = int(random.random() * 0x7FFFFFFF) self._msg_id_counter += 1 if self._msg_id_counter > 0x80000000: self._msg_id_counter = 1 @@ -1041,21 +1062,22 @@ class MIoTLan: msg: dict, handler: Callable, handler_ctx: Any, - timeout_ms: int = 10000 + timeout_ms: int = 10000, ) -> None: try: self.send2device( did=did, - msg={'from': 'ha.xiaomi_home', **msg}, + msg={"from": "ha.xiaomi_home", **msg}, handler=handler, handler_ctx=handler_ctx, - timeout_ms=timeout_ms) + timeout_ms=timeout_ms, + ) except Exception as err: # pylint: disable=broad-exception-caught - _LOGGER.error('send2device error, %s', err) - handler({ - 'code': MIoTErrorCode.CODE_INTERNAL_ERROR.value, - 'error': str(err)}, - handler_ctx) + _LOGGER.error("send2device error, %s", err) + handler( + {"code": MIoTErrorCode.CODE_INTERNAL_ERROR.value, "error": str(err)}, + handler_ctx, + ) def __sub_device_state(self, data: _MIoTLanSubDeviceData) -> None: self._device_state_sub_map[data.key] = data @@ -1065,51 +1087,44 @@ class MIoTLan: def __sub_broadcast(self, data: _MIoTLanRegisterBroadcastData) -> None: self._device_msg_matcher[data.key] = data - _LOGGER.debug('lan register broadcast, %s', data.key) + _LOGGER.debug("lan register broadcast, %s", data.key) def __unsub_broadcast(self, data: _MIoTLanUnregisterBroadcastData) -> None: if self._device_msg_matcher.get(topic=data.key): del self._device_msg_matcher[data.key] - _LOGGER.debug('lan unregister broadcast, %s', data.key) + _LOGGER.debug("lan unregister broadcast, %s", data.key) def __get_dev_list(self, data: _MIoTLanGetDevListData) -> None: dev_list = { - device.did: { - 'online': device.online, - 'push_available': device.subscribed - } + device.did: {"online": device.online, "push_available": device.subscribed} for device in self._lan_devices.values() - if device.online} - data.handler( - dev_list, data.handler_ctx) + if device.online + } + data.handler(dev_list, data.handler_ctx) def __update_devices(self, devices: dict[str, dict]) -> None: for did, info in devices.items(): # did MUST be digit(UINT64) if not did.isdigit(): - _LOGGER.info('invalid did, %s', did) + _LOGGER.info("invalid did, %s", did) continue - if ( - 'model' not in info - or info['model'] in self._profile_models): + if "model" not in info or info["model"] in self._profile_models: # Do not support the local control of # Profile device for the time being _LOGGER.info( - 'model not support local ctrl, %s, %s', - did, info.get('model')) + "model not support local ctrl, %s, %s", did, info.get("model") + ) continue if did not in self._lan_devices: - if 'token' not in info: - _LOGGER.error( - 'token not found, %s, %s', did, info) + if "token" not in info: + _LOGGER.error("token not found, %s, %s", did, info) continue - if len(info['token']) != 32: - _LOGGER.error( - 'invalid device token, %s, %s', did, info) + if len(info["token"]) != 32: + _LOGGER.error("invalid device token, %s, %s", did, info) continue self._lan_devices[did] = _MIoTLanDevice( - manager=self, did=did, token=info['token'], - ip=info.get('ip', None)) + manager=self, did=did, token=info["token"], ip=info.get("ip", None) + ) else: self._lan_devices[did].update_info(info) @@ -1139,9 +1154,9 @@ class MIoTLan: self.__destroy_socket(if_name=if_name) def __update_subscribe_option(self, options: dict) -> None: - if 'enable_subscribe' in options: - if options['enable_subscribe'] != self._enable_subscribe: - self._enable_subscribe = options['enable_subscribe'] + if "enable_subscribe" in options: + if options["enable_subscribe"] != self._enable_subscribe: + self._enable_subscribe = options["enable_subscribe"] if not self._enable_subscribe: # Unsubscribe all for device in self._lan_devices.values(): @@ -1176,26 +1191,24 @@ class MIoTLan: def __create_socket(self, if_name: str) -> None: if if_name in self._broadcast_socks: - _LOGGER.info('socket already created, %s', if_name) + _LOGGER.info("socket already created, %s", if_name) return # Create socket try: - sock = socket.socket( - socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # Set SO_BINDTODEVICE - sock.setsockopt( - socket.SOL_SOCKET, socket.SO_BINDTODEVICE, if_name.encode()) - sock.bind(('', self._local_port or 0)) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BINDTODEVICE, if_name.encode()) + sock.bind(("", self._local_port or 0)) self._internal_loop.add_reader( - sock.fileno(), self.__socket_read_handler, (if_name, sock)) + sock.fileno(), self.__socket_read_handler, (if_name, sock) + ) self._broadcast_socks[if_name] = sock self._local_port = self._local_port or sock.getsockname()[1] - _LOGGER.info( - 'created socket, %s, %s', if_name, self._local_port) + _LOGGER.info("created socket, %s, %s", if_name, self._local_port) except Exception as err: # pylint: disable=broad-exception-caught - _LOGGER.error('create socket error, %s, %s', if_name, err) + _LOGGER.error("create socket error, %s, %s", if_name, err) def __deinit_socket(self) -> None: for if_name in list(self._broadcast_socks.keys()): @@ -1208,23 +1221,25 @@ class MIoTLan: return self._internal_loop.remove_reader(sock.fileno()) sock.close() - _LOGGER.info('destroyed socket, %s', if_name) + _LOGGER.info("destroyed socket, %s", if_name) def __socket_read_handler(self, ctx: tuple[str, socket.socket]) -> None: try: data_len, addr = ctx[1].recvfrom_into( - self._read_buffer, self.OT_MSG_LEN, socket.MSG_DONTWAIT) + self._read_buffer, self.OT_MSG_LEN, socket.MSG_DONTWAIT + ) if data_len < 0: # Socket error - _LOGGER.error('socket read error, %s, %s', ctx[0], data_len) + _LOGGER.error("socket read error, %s, %s", ctx[0], data_len) return if addr[1] != self.OT_PORT: # Not ot msg return self.__raw_message_handler( - self._read_buffer[:data_len], data_len, addr[0], ctx[0]) + self._read_buffer[:data_len], data_len, addr[0], ctx[0] + ) except Exception as err: # pylint: disable=broad-exception-caught - _LOGGER.error('socket read handler error, %s', err) + _LOGGER.error("socket read handler error, %s", err) def __raw_message_handler( self, data: bytearray, data_len: int, ip: str, if_name: str @@ -1232,11 +1247,11 @@ class MIoTLan: if data[:2] != self.OT_HEADER: return # Keep alive message - did: str = str(struct.unpack('>Q', data[4:12])[0]) + did: str = str(struct.unpack(">Q", data[4:12])[0]) device: Optional[_MIoTLanDevice] = self._lan_devices.get(did) if not device: return - timestamp: int = struct.unpack('>I', data[12:16])[0] + timestamp: int = struct.unpack(">I", data[12:16])[0] device.offset = int(time.time()) - timestamp # Keep alive if this is a probe if data_len == self.OT_PROBE_LEN or device.subscribed: @@ -1245,12 +1260,13 @@ class MIoTLan: if ( self._enable_subscribe and data_len == self.OT_PROBE_LEN - and data[16:20] == b'MSUB' - and data[24:27] == b'PUB' + and data[16:20] == b"MSUB" + and data[24:27] == b"PUB" ): device.supported_wildcard_sub = ( - int(data[28]) == self.OT_SUPPORT_WILDCARD_SUB) - sub_ts = struct.unpack('>I', data[20:24])[0] + int(data[28]) == self.OT_SUPPORT_WILDCARD_SUB + ) + sub_ts = struct.unpack(">I", data[20:24])[0] sub_type = int(data[27]) if ( device.supported_wildcard_sub @@ -1264,74 +1280,70 @@ class MIoTLan: try: decrypted_data = device.decrypt_packet(data) self.__message_handler(did, decrypted_data) - except Exception as err: # pylint: disable=broad-exception-caught - _LOGGER.error('decrypt packet error, %s, %s', did, err) + except Exception as err: # pylint: disable=broad-exception-caught + _LOGGER.error("decrypt packet error, %s, %s", did, err) return def __message_handler(self, did: str, msg: dict) -> None: - if 'id' not in msg: - _LOGGER.warning('invalid message, no id, %s, %s', did, msg) + if "id" not in msg: + _LOGGER.warning("invalid message, no id, %s, %s", did, msg) return # Reply - req: Optional[_MIoTLanRequestData] = ( - self._pending_requests.pop(msg['id'], None)) + req: Optional[_MIoTLanRequestData] = self._pending_requests.pop(msg["id"], None) if req: if req.timeout: req.timeout.cancel() req.timeout = None if req.handler is not None: - self._main_loop.call_soon_threadsafe( - req.handler, msg, req.handler_ctx) + self._main_loop.call_soon_threadsafe(req.handler, msg, req.handler_ctx) return # Handle up link message - if 'method' not in msg or 'params' not in msg: - _LOGGER.debug( - 'invalid message, no method or params, %s, %s', did, msg) + if "method" not in msg or "params" not in msg: + _LOGGER.debug("invalid message, no method or params, %s, %s", did, msg) return # Filter dup message - if self.__filter_dup_message(did, msg['id']): - self.send2device( - did=did, msg={'id': msg['id'], 'result': {'code': 0}}) + if self.__filter_dup_message(did, msg["id"]): + self.send2device(did=did, msg={"id": msg["id"], "result": {"code": 0}}) return - _LOGGER.debug('lan message, %s, %s', did, msg) - if msg['method'] == 'properties_changed': - for param in msg['params']: - if 'siid' not in param and 'piid' not in param: - _LOGGER.debug( - 'invalid message, no siid or piid, %s, %s', did, msg) + _LOGGER.debug("lan message, %s, %s", did, msg) + if msg["method"] == "properties_changed": + for param in msg["params"]: + if "siid" not in param and "piid" not in param: + _LOGGER.debug("invalid message, no siid or piid, %s, %s", did, msg) continue - key = f'{did}/p/{param["siid"]}/{param["piid"]}' + key = f"{did}/p/{param['siid']}/{param['piid']}" subs: list[_MIoTLanRegisterBroadcastData] = list( - self._device_msg_matcher.iter_match(key)) + self._device_msg_matcher.iter_match(key) + ) for sub in subs: self._main_loop.call_soon_threadsafe( - sub.handler, param, sub.handler_ctx) + sub.handler, param, sub.handler_ctx + ) elif ( - msg['method'] == 'event_occured' - and 'siid' in msg['params'] - and 'eiid' in msg['params'] + msg["method"] == "event_occured" + and "siid" in msg["params"] + and "eiid" in msg["params"] ): - key = f'{did}/e/{msg["params"]["siid"]}/{msg["params"]["eiid"]}' + key = f"{did}/e/{msg['params']['siid']}/{msg['params']['eiid']}" subs: list[_MIoTLanRegisterBroadcastData] = list( - self._device_msg_matcher.iter_match(key)) + self._device_msg_matcher.iter_match(key) + ) for sub in subs: self._main_loop.call_soon_threadsafe( - sub.handler, msg['params'], sub.handler_ctx) + sub.handler, msg["params"], sub.handler_ctx + ) else: - _LOGGER.debug( - 'invalid message, unknown method, %s, %s', did, msg) + _LOGGER.debug("invalid message, unknown method, %s, %s", did, msg) # Reply - self.send2device( - did=did, msg={'id': msg['id'], 'result': {'code': 0}}) + self.send2device(did=did, msg={"id": msg["id"], "result": {"code": 0}}) def __filter_dup_message(self, did: str, msg_id: int) -> bool: - filter_id = f'{did}.{msg_id}' + filter_id = f"{did}.{msg_id}" if filter_id in self._reply_msg_buffer: return True self._reply_msg_buffer[filter_id] = self._internal_loop.call_later( - 5, - lambda filter_id: self._reply_msg_buffer.pop(filter_id, None), - filter_id) + 5, lambda filter_id: self._reply_msg_buffer.pop(filter_id, None), filter_id + ) return False def __sendto( @@ -1340,13 +1352,13 @@ class MIoTLan: if if_name is None: # Broadcast for if_n, sock in self._broadcast_socks.items(): - _LOGGER.debug('send broadcast, %s', if_n) + _LOGGER.debug("send broadcast, %s", if_n) sock.sendto(data, socket.MSG_DONTWAIT, (address, port)) else: # Unicast sock = self._broadcast_socks.get(if_name, None) if not sock: - _LOGGER.error('invalid socket, %s', if_name) + _LOGGER.error("invalid socket, %s", if_name) return sock.sendto(data, socket.MSG_DONTWAIT, (address, port)) @@ -1356,19 +1368,21 @@ class MIoTLan: self._scan_timer = None try: # Scan devices - self.ping(if_name=None, target_ip='255.255.255.255') + self.ping(if_name=None, target_ip="255.255.255.255") except Exception as err: # pylint: disable=broad-exception-caught # Ignore any exceptions to avoid blocking the loop - _LOGGER.error('ping device error, %s', err) + _LOGGER.error("ping device error, %s", err) pass scan_time = self.__get_next_scan_time() self._scan_timer = self._internal_loop.call_later( - scan_time, self.__scan_devices) - _LOGGER.debug('next scan time: %ss', scan_time) + scan_time, self.__scan_devices + ) + _LOGGER.debug("next scan time: %ss", scan_time) def __get_next_scan_time(self) -> float: if not self._last_scan_interval: self._last_scan_interval = self.OT_PROBE_INTERVAL_MIN self._last_scan_interval = min( - self._last_scan_interval*2, self.OT_PROBE_INTERVAL_MAX) + self._last_scan_interval * 2, self.OT_PROBE_INTERVAL_MAX + ) return self._last_scan_interval diff --git a/custom_components/xiaomi_home/miot/miot_mips.py b/custom_components/xiaomi_home/miot/miot_mips.py index f1a4534..f641cc3 100644 --- a/custom_components/xiaomi_home/miot/miot_mips.py +++ b/custom_components/xiaomi_home/miot/miot_mips.py @@ -45,6 +45,7 @@ off Xiaomi or its affiliates' products. MIoT Pub/Sub client. """ + import asyncio import json import logging @@ -56,14 +57,15 @@ import threading from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum, auto -from typing import Any, Callable, Optional, final, Coroutine +from typing import Any, Callable, Dict, List, Optional, final, Coroutine from paho.mqtt.client import ( MQTT_ERR_SUCCESS, MQTT_ERR_UNKNOWN, Client, MQTTv5, - MQTTMessage) + MQTTMessage, +) # pylint: disable=relative-beyond-top-level from .common import MIoTMatcher @@ -75,6 +77,7 @@ _LOGGER = logging.getLogger(__name__) class _MipsMsgTypeOptions(Enum): """MIoT Pub/Sub message type.""" + ID = 0 RET_TOPIC = auto() PAYLOAD = auto() @@ -84,38 +87,35 @@ class _MipsMsgTypeOptions(Enum): class _MipsMessage: """MIoT Pub/Sub message.""" + mid: int = 0 msg_from: Optional[str] = None ret_topic: Optional[str] = None payload: Optional[str] = None @staticmethod - def unpack(data: bytes) -> '_MipsMessage': + def unpack(data: bytes) -> "_MipsMessage": mips_msg = _MipsMessage() data_len = len(data) data_start = 0 data_end = 0 while data_start < data_len: - data_end = data_start+5 - unpack_len, unpack_type = struct.unpack( - ' bytes: if mid is None or payload is None: - raise MIoTMipsError('invalid mid or payload') - pack_msg: bytes = b'' + raise MIoTMipsError("invalid mid or payload") + pack_msg: bytes = b"" # mid - pack_msg += struct.pack(' str: - return f'{self.mid}, {self.msg_from}, {self.ret_topic}, {self.payload}' + return f"{self.mid}, {self.msg_from}, {self.ret_topic}, {self.payload}" @dataclass class _MipsRequest: """MIoT Pub/Sub request.""" + mid: int on_reply: Callable[[str, Any], None] on_reply_ctx: Any @@ -165,6 +175,7 @@ class _MipsRequest: @dataclass class _MipsBroadcast: """MIoT Pub/Sub broadcast.""" + topic: str """ param 1: msg topic @@ -175,12 +186,13 @@ class _MipsBroadcast: handler_ctx: Any def __str__(self) -> str: - return f'{self.topic}, {id(self.handler)}, {id(self.handler_ctx)}' + return f"{self.topic}, {id(self.handler)}, {id(self.handler_ctx)}" @dataclass class _MipsState: """MIoT Pub/Sub state.""" + key: str """ str: key @@ -191,6 +203,7 @@ class _MipsState: class MIoTDeviceState(Enum): """MIoT device state define.""" + DISABLE = 0 OFFLINE = auto() ONLINE = auto() @@ -199,6 +212,7 @@ class MIoTDeviceState(Enum): @dataclass class MipsDeviceState: """MIoT Pub/Sub device state.""" + did: Optional[str] = None """handler str: did @@ -211,6 +225,7 @@ class MipsDeviceState: class _MipsClient(ABC): """MIoT Pub/Sub client.""" + # pylint: disable=unused-argument MQTT_INTERVAL_S = 1 MIPS_QOS: int = 2 @@ -249,16 +264,16 @@ class _MipsClient(ABC): _mips_sub_pending_timer: Optional[asyncio.TimerHandle] def __init__( - self, - client_id: str, - host: str, - port: int, - username: Optional[str] = None, - password: Optional[str] = None, - ca_file: Optional[str] = None, - cert_file: Optional[str] = None, - key_file: Optional[str] = None, - loop: Optional[asyncio.AbstractEventLoop] = None + self, + client_id: str, + host: str, + port: int, + username: Optional[str] = None, + password: Optional[str] = None, + ca_file: Optional[str] = None, + cert_file: Optional[str] = None, + key_file: Optional[str] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: # MUST run with running loop self.main_loop = loop or asyncio.get_running_loop() @@ -323,8 +338,7 @@ class _MipsClient(ABC): self._internal_loop = asyncio.new_event_loop() self._mips_thread = threading.Thread(target=self.__mips_loop_thread) self._mips_thread.daemon = True - self._mips_thread.name = ( - self._client_id if thread_name is None else thread_name) + self._mips_thread.name = self._client_id if thread_name is None else thread_name self._mips_thread.start() async def connect_async(self) -> None: @@ -381,27 +395,24 @@ class _MipsClient(ABC): def update_mqtt_password(self, password: str) -> None: self._password = password if self._mqtt: - self._mqtt.username_pw_set( - username=self._username, password=self._password) + self._mqtt.username_pw_set(username=self._username, password=self._password) def log_debug(self, msg, *args, **kwargs) -> None: if self._logger: - self._logger.debug(f'{self._client_id}, '+msg, *args, **kwargs) + self._logger.debug(f"{self._client_id}, " + msg, *args, **kwargs) def log_info(self, msg, *args, **kwargs) -> None: if self._logger: - self._logger.info(f'{self._client_id}, '+msg, *args, **kwargs) + self._logger.info(f"{self._client_id}, " + msg, *args, **kwargs) def log_error(self, msg, *args, **kwargs) -> None: if self._logger: - self._logger.error(f'{self._client_id}, '+msg, *args, **kwargs) + self._logger.error(f"{self._client_id}, " + msg, *args, **kwargs) def enable_logger(self, logger: Optional[logging.Logger] = None) -> None: self._logger = logger - def enable_mqtt_logger( - self, logger: Optional[logging.Logger] = None - ) -> None: + def enable_mqtt_logger(self, logger: Optional[logging.Logger] = None) -> None: self._mqtt_logger = logger if self._mqtt: if logger: @@ -419,21 +430,21 @@ class _MipsClient(ABC): So use mutex instead of IPC. """ if isinstance(key, str) is False or handler is None: - raise MIoTMipsError('invalid params') + raise MIoTMipsError("invalid params") state = _MipsState(key=key, handler=handler) with self._mips_state_sub_map_lock: self._mips_state_sub_map[key] = state - self.log_debug(f'mips register mips state, {key}') + self.log_debug(f"mips register mips state, {key}") return True @final def unsub_mips_state(self, key: str) -> bool: """Unsubscribe mips state.""" if isinstance(key, str) is False: - raise MIoTMipsError('invalid params') + raise MIoTMipsError("invalid params") with self._mips_state_sub_map_lock: del self._mips_state_sub_map[key] - self.log_debug(f'mips unregister mips state, {key}') + self.log_debug(f"mips unregister mips state, {key}") return True @abstractmethod @@ -443,15 +454,12 @@ class _MipsClient(ABC): handler: Callable[[dict, Any], None], siid: Optional[int] = None, piid: Optional[int] = None, - handler_ctx: Any = None + handler_ctx: Any = None, ) -> bool: ... @abstractmethod def unsub_prop( - self, - did: str, - siid: Optional[int] = None, - piid: Optional[int] = None + self, did: str, siid: Optional[int] = None, piid: Optional[int] = None ) -> bool: ... @abstractmethod @@ -461,22 +469,17 @@ class _MipsClient(ABC): handler: Callable[[dict, Any], None], siid: Optional[int] = None, eiid: Optional[int] = None, - handler_ctx: Any = None + handler_ctx: Any = None, ) -> bool: ... @abstractmethod def unsub_event( - self, - did: str, - siid: Optional[int] = None, - eiid: Optional[int] = None + self, did: str, siid: Optional[int] = None, eiid: Optional[int] = None ) -> bool: ... @abstractmethod async def get_dev_list_async( - self, - payload: Optional[str] = None, - timeout_ms: int = 10000 + self, payload: Optional[str] = None, timeout_ms: int = 10000 ) -> dict[str, dict]: ... @abstractmethod @@ -486,14 +489,12 @@ class _MipsClient(ABC): @abstractmethod async def set_prop_async( - self, did: str, siid: int, piid: int, value: Any, - timeout_ms: int = 10000 + self, did: str, siid: int, piid: int, value: Any, timeout_ms: int = 10000 ) -> dict: ... @abstractmethod async def action_async( - self, did: str, siid: int, aiid: int, in_list: list, - timeout_ms: int = 10000 + self, did: str, siid: int, aiid: int, in_list: list, timeout_ms: int = 10000 ) -> dict: ... @abstractmethod @@ -518,10 +519,11 @@ class _MipsClient(ABC): self._mips_sub_pending_map[topic] = 0 if not self._mips_sub_pending_timer: self._mips_sub_pending_timer = self._internal_loop.call_later( - 0.01, self.__mips_sub_internal_pending_handler, topic) + 0.01, self.__mips_sub_internal_pending_handler, topic + ) except Exception as err: # pylint: disable=broad-exception-caught # Catch all exception - self.log_error(f'mips sub internal error, {topic}. {err}') + self.log_error(f"mips sub internal error, {topic}. {err}") @final def _mips_unsub_internal(self, topic: str) -> None: @@ -534,19 +536,20 @@ class _MipsClient(ABC): try: result, mid = self._mqtt.unsubscribe(topic=topic) if result == MQTT_ERR_SUCCESS: - self.log_debug( - f'mips unsub internal success, {result}, {mid}, {topic}') + self.log_debug(f"mips unsub internal success, {result}, {mid}, {topic}") return - self.log_error( - f'mips unsub internal error, {result}, {mid}, {topic}') + self.log_error(f"mips unsub internal error, {result}, {mid}, {topic}") except Exception as err: # pylint: disable=broad-exception-caught # Catch all exception - self.log_error(f'mips unsub internal error, {topic}, {err}') + self.log_error(f"mips unsub internal error, {topic}, {err}") @final def _mips_publish_internal( - self, topic: str, payload: str | bytes, - wait_for_publish: bool = False, timeout_ms: int = 10000 + self, + topic: str, + payload: str | bytes, + wait_for_publish: bool = False, + timeout_ms: int = 10000, ) -> bool: """mips publish message. NOTICE: Internal function, only mips threads are allowed to call @@ -556,20 +559,19 @@ class _MipsClient(ABC): if not self._mqtt or not self._mqtt.is_connected(): return False try: - handle = self._mqtt.publish( - topic=topic, payload=payload, qos=self.MIPS_QOS) + handle = self._mqtt.publish(topic=topic, payload=payload, qos=self.MIPS_QOS) # self.log_debug(f'_mips_publish_internal, {topic}, {payload}') if wait_for_publish is True: - handle.wait_for_publish(timeout_ms/1000.0) + handle.wait_for_publish(timeout_ms / 1000.0) return True except Exception as err: # pylint: disable=broad-exception-caught # Catch other exception - self.log_error(f'mips publish internal error, {err}') + self.log_error(f"mips publish internal error, {err}") return False def __thread_check(self) -> None: if threading.current_thread() is not self._mips_thread: - raise MIoTMipsError('illegal call') + raise MIoTMipsError("illegal call") def __mqtt_read_handler(self) -> None: self.__mqtt_loop_handler() @@ -582,7 +584,8 @@ class _MipsClient(ABC): self.__mqtt_loop_handler() if self._mqtt: self._mqtt_timer = self._internal_loop.call_later( - self.MQTT_INTERVAL_S, self.__mqtt_timer_handler) + self.MQTT_INTERVAL_S, self.__mqtt_timer_handler + ) def __mqtt_loop_handler(self) -> None: try: @@ -593,33 +596,28 @@ class _MipsClient(ABC): if self._mqtt: self._mqtt.loop_misc() if self._mqtt and self._mqtt.want_write(): - self._internal_loop.add_writer( - self._mqtt_fd, self.__mqtt_write_handler) + self._internal_loop.add_writer(self._mqtt_fd, self.__mqtt_write_handler) except Exception as err: # pylint: disable=broad-exception-caught # Catch all exception - self.log_error(f'__mqtt_loop_handler, {err}') + self.log_error(f"__mqtt_loop_handler, {err}") raise err def __mips_loop_thread(self) -> None: - self.log_info('mips_loop_thread start') + self.log_info("mips_loop_thread start") # mqtt init for API_VERSION2, # callback_api_version=CallbackAPIVersion.VERSION2, self._mqtt = Client(client_id=self._client_id, protocol=MQTTv5) self._mqtt.enable_logger(logger=self._mqtt_logger) # Set mqtt config if self._username: - self._mqtt.username_pw_set( - username=self._username, password=self._password) - if ( - self._ca_file - and self._cert_file - and self._key_file - ): + self._mqtt.username_pw_set(username=self._username, password=self._password) + if self._ca_file and self._cert_file and self._key_file: self._mqtt.tls_set( tls_version=ssl.PROTOCOL_TLS_CLIENT, ca_certs=self._ca_file, certfile=self._cert_file, - keyfile=self._key_file) + keyfile=self._key_file, + ) else: self._mqtt.tls_set(tls_version=ssl.PROTOCOL_TLS_CLIENT) self._mqtt.tls_insecure_set(True) @@ -631,40 +629,38 @@ class _MipsClient(ABC): self.__mips_start_connect_tries() # Run event loop self._internal_loop.run_forever() - self.log_info('mips_loop_thread exit!') + self.log_info("mips_loop_thread exit!") def __on_connect(self, client, user_data, flags, rc, props) -> None: if not self._mqtt: - _LOGGER.error('__on_connect, but mqtt is None') + _LOGGER.error("__on_connect, but mqtt is None") return if not self._mqtt.is_connected(): return - self.log_info(f'mips connect, {flags}, {rc}, {props}') + self.log_info(f"mips connect, {flags}, {rc}, {props}") self._mqtt_state = True - self._internal_loop.call_soon( - self._on_mips_connect, rc, props) + self._internal_loop.call_soon(self._on_mips_connect, rc, props) with self._mips_state_sub_map_lock: for item in self._mips_state_sub_map.values(): if item.handler is None: continue self.main_loop.call_soon_threadsafe( - self.main_loop.create_task, - item.handler(item.key, True)) + self.main_loop.create_task, item.handler(item.key, True) + ) # Resolve future - self.main_loop.call_soon_threadsafe( - self._event_connect.set) - self.main_loop.call_soon_threadsafe( - self._event_disconnect.clear) + self.main_loop.call_soon_threadsafe(self._event_connect.set) + self.main_loop.call_soon_threadsafe(self._event_disconnect.clear) def __on_connect_failed(self, client: Client, user_data: Any) -> None: - self.log_error('mips connect failed') + self.log_error("mips connect failed") # Try to reconnect self.__mips_try_reconnect() - def __on_disconnect(self, client, user_data, rc, props) -> None: + def __on_disconnect(self, client, user_data, rc, props) -> None: if self._mqtt_state: (self.log_info if rc == 0 else self.log_error)( - f'mips disconnect, {rc}, {props}') + f"mips disconnect, {rc}, {props}" + ) self._mqtt_state = False if self._mqtt_timer: self._mqtt_timer.cancel() @@ -678,37 +674,28 @@ class _MipsClient(ABC): self._mips_sub_pending_timer.cancel() self._mips_sub_pending_timer = None self._mips_sub_pending_map = {} - self._internal_loop.call_soon( - self._on_mips_disconnect, rc, props) + self._internal_loop.call_soon(self._on_mips_disconnect, rc, props) # Call state sub handler with self._mips_state_sub_map_lock: for item in self._mips_state_sub_map.values(): if item.handler is None: continue self.main_loop.call_soon_threadsafe( - self.main_loop.create_task, - item.handler(item.key, False)) + self.main_loop.create_task, item.handler(item.key, False) + ) # Try to reconnect self.__mips_try_reconnect() # Set event - self.main_loop.call_soon_threadsafe( - self._event_disconnect.set) - self.main_loop.call_soon_threadsafe( - self._event_connect.clear) + self.main_loop.call_soon_threadsafe(self._event_disconnect.set) + self.main_loop.call_soon_threadsafe(self._event_connect.clear) - def __on_message( - self, - client: Client, - user_data: Any, - msg: MQTTMessage - ) -> None: + def __on_message(self, client: Client, user_data: Any, msg: MQTTMessage) -> None: self._on_mips_message(topic=msg.topic, payload=msg.payload) def __mips_sub_internal_pending_handler(self, ctx: Any) -> None: if not self._mqtt or not self._mqtt.is_connected(): - _LOGGER.error( - 'mips sub internal pending, but mqtt is None or disconnected') + _LOGGER.error("mips sub internal pending, but mqtt is None or disconnected") return subbed_count = 1 for topic in list(self._mips_sub_pending_map.keys()): @@ -717,28 +704,29 @@ class _MipsClient(ABC): count = self._mips_sub_pending_map[topic] if count > 3: self._mips_sub_pending_map.pop(topic) - self.log_error(f'retry mips sub internal error, {topic}') + self.log_error(f"retry mips sub internal error, {topic}") continue subbed_count += 1 result, mid = self._mqtt.subscribe(topic, qos=self.MIPS_QOS) if result == MQTT_ERR_SUCCESS: self._mips_sub_pending_map.pop(topic) - self.log_debug(f'mips sub internal success, {topic}') + self.log_debug(f"mips sub internal success, {topic}") continue - self._mips_sub_pending_map[topic] = count+1 + self._mips_sub_pending_map[topic] = count + 1 self.log_error( - f'retry mips sub internal, {count}, {topic}, {result}, {mid}') + f"retry mips sub internal, {count}, {topic}, {result}, {mid}" + ) if len(self._mips_sub_pending_map): self._mips_sub_pending_timer = self._internal_loop.call_later( - self.MIPS_SUB_INTERVAL, - self.__mips_sub_internal_pending_handler, None) + self.MIPS_SUB_INTERVAL, self.__mips_sub_internal_pending_handler, None + ) else: self._mips_sub_pending_timer = None def __mips_connect(self) -> None: if not self._mqtt: - _LOGGER.error('__mips_connect, but mqtt is None') + _LOGGER.error("__mips_connect, but mqtt is None") return result = MQTT_ERR_UNKNOWN if self._mips_reconnect_timer: @@ -754,30 +742,31 @@ class _MipsClient(ABC): self._internal_loop.remove_writer(self._mqtt_fd) self._mqtt_fd = -1 result = self._mqtt.connect( - host=self._host, port=self._port, - clean_start=True, keepalive=MIHOME_MQTT_KEEPALIVE) - self.log_info(f'__mips_connect success, {result}') + host=self._host, + port=self._port, + clean_start=True, + keepalive=MIHOME_MQTT_KEEPALIVE, + ) + self.log_info(f"__mips_connect success, {result}") except (TimeoutError, OSError) as error: - self.log_error('__mips_connect, connect error, %s', error) + self.log_error("__mips_connect, connect error, %s", error) if result == MQTT_ERR_SUCCESS: socket = self._mqtt.socket() if socket is None: - self.log_error( - '__mips_connect, connect success, but socket is None') + self.log_error("__mips_connect, connect success, but socket is None") self.__mips_try_reconnect() return self._mqtt_fd = socket.fileno() - self.log_debug(f'__mips_connect, _mqtt_fd, {self._mqtt_fd}') - self._internal_loop.add_reader( - self._mqtt_fd, self.__mqtt_read_handler) + self.log_debug(f"__mips_connect, _mqtt_fd, {self._mqtt_fd}") + self._internal_loop.add_reader(self._mqtt_fd, self.__mqtt_read_handler) if self._mqtt.want_write(): - self._internal_loop.add_writer( - self._mqtt_fd, self.__mqtt_write_handler) + self._internal_loop.add_writer(self._mqtt_fd, self.__mqtt_write_handler) self._mqtt_timer = self._internal_loop.call_later( - self.MQTT_INTERVAL_S, self.__mqtt_timer_handler) + self.MQTT_INTERVAL_S, self.__mqtt_timer_handler + ) else: - self.log_error(f'__mips_connect error result, {result}') + self.log_error(f"__mips_connect error result, {result}") self.__mips_try_reconnect() def __mips_try_reconnect(self, immediately: bool = False) -> None: @@ -789,10 +778,10 @@ class _MipsClient(ABC): interval: float = 0 if not immediately: interval = self.__get_next_reconnect_time() - self.log_error( - 'mips try reconnect after %ss', interval) + self.log_error("mips try reconnect after %ss", interval) self._mips_reconnect_timer = self._internal_loop.call_later( - interval, self.__mips_connect) + interval, self.__mips_connect + ) def __mips_start_connect_tries(self) -> None: self._mips_reconnect_tag = True @@ -825,26 +814,36 @@ class _MipsClient(ABC): self._mips_reconnect_interval = self.MIPS_RECONNECT_INTERVAL_MIN else: self._mips_reconnect_interval = min( - self._mips_reconnect_interval*2, - self.MIPS_RECONNECT_INTERVAL_MAX) + self._mips_reconnect_interval * 2, self.MIPS_RECONNECT_INTERVAL_MAX + ) return self._mips_reconnect_interval class MipsCloudClient(_MipsClient): """MIoT Pub/Sub Cloud Client.""" + # pylint: disable=unused-argument # pylint: disable=inconsistent-quotes _msg_matcher: MIoTMatcher def __init__( - self, uuid: str, cloud_server: str, app_id: str, - token: str, port: int = 8883, - loop: Optional[asyncio.AbstractEventLoop] = None + self, + uuid: str, + cloud_server: str, + app_id: str, + token: str, + port: int = 8883, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: self._msg_matcher = MIoTMatcher() super().__init__( - client_id=f'ha.{uuid}', host=f'{cloud_server}-ha.mqtt.io.mi.com', - port=port, username=app_id, password=token, loop=loop) + client_id=f"ha.{uuid}", + host=f"{cloud_server}-ha.mqtt.io.mi.com", + port=port, + username=app_id, + password=token, + loop=loop, + ) @final def disconnect(self) -> None: @@ -853,7 +852,7 @@ class MipsCloudClient(_MipsClient): def update_access_token(self, access_token: str) -> bool: if not isinstance(access_token, str): - raise MIoTMipsError('invalid token') + raise MIoTMipsError("invalid token") self.update_mqtt_password(password=access_token) return True @@ -864,49 +863,48 @@ class MipsCloudClient(_MipsClient): handler: Callable[[dict, Any], None], siid: Optional[int] = None, piid: Optional[int] = None, - handler_ctx: Any = None + handler_ctx: Any = None, ) -> bool: if not isinstance(did, str) or handler is None: - raise MIoTMipsError('invalid params') + raise MIoTMipsError("invalid params") topic: str = ( - f'device/{did}/up/properties_changed/' - f'{"#" if siid is None or piid is None else f"{siid}/{piid}"}') + f"device/{did}/up/properties_changed/" + f"{'#' if siid is None or piid is None else f'{siid}/{piid}'}" + ) def on_prop_msg(topic: str, payload: str, ctx: Any) -> None: try: msg: dict = json.loads(payload) except json.JSONDecodeError: - self.log_error( - f'on_prop_msg, invalid msg, {topic}, {payload}') + self.log_error(f"on_prop_msg, invalid msg, {topic}, {payload}") return if ( - not isinstance(msg.get('params', None), dict) - or 'siid' not in msg['params'] - or 'piid' not in msg['params'] - or 'value' not in msg['params'] + not isinstance(msg.get("params", None), dict) + or "siid" not in msg["params"] + or "piid" not in msg["params"] + or "value" not in msg["params"] ): - self.log_error( - f'on_prop_msg, invalid msg, {topic}, {payload}') + self.log_error(f"on_prop_msg, invalid msg, {topic}, {payload}") return if handler: - self.log_debug('on properties_changed, %s', payload) - handler(msg['params'], ctx) + self.log_debug("on properties_changed, %s", payload) + handler(msg["params"], ctx) + return self.__reg_broadcast_external( - topic=topic, handler=on_prop_msg, handler_ctx=handler_ctx) + topic=topic, handler=on_prop_msg, handler_ctx=handler_ctx + ) @final def unsub_prop( - self, - did: str, - siid: Optional[int] = None, - piid: Optional[int] = None + self, did: str, siid: Optional[int] = None, piid: Optional[int] = None ) -> bool: if not isinstance(did, str): - raise MIoTMipsError('invalid params') + raise MIoTMipsError("invalid params") topic: str = ( - f'device/{did}/up/properties_changed/' - f'{"#" if siid is None or piid is None else f"{siid}/{piid}"}') + f"device/{did}/up/properties_changed/" + f"{'#' if siid is None or piid is None else f'{siid}/{piid}'}" + ) return self.__unreg_broadcast_external(topic=topic) @final @@ -916,136 +914,144 @@ class MipsCloudClient(_MipsClient): handler: Callable[[dict, Any], None], siid: Optional[int] = None, eiid: Optional[int] = None, - handler_ctx: Any = None + handler_ctx: Any = None, ) -> bool: if not isinstance(did, str) or handler is None: - raise MIoTMipsError('invalid params') + raise MIoTMipsError("invalid params") # Spelling error: event_occured topic: str = ( - f'device/{did}/up/event_occured/' - f'{"#" if siid is None or eiid is None else f"{siid}/{eiid}"}') + f"device/{did}/up/event_occured/" + f"{'#' if siid is None or eiid is None else f'{siid}/{eiid}'}" + ) def on_event_msg(topic: str, payload: str, ctx: Any) -> None: try: msg: dict = json.loads(payload) except json.JSONDecodeError: - self.log_error( - f'on_event_msg, invalid msg, {topic}, {payload}') + self.log_error(f"on_event_msg, invalid msg, {topic}, {payload}") return if ( - not isinstance(msg.get('params', None), dict) - or 'siid' not in msg['params'] - or 'eiid' not in msg['params'] - or 'arguments' not in msg['params'] + not isinstance(msg.get("params", None), dict) + or "siid" not in msg["params"] + or "eiid" not in msg["params"] + or "arguments" not in msg["params"] ): - self.log_error( - f'on_event_msg, invalid msg, {topic}, {payload}') + self.log_error(f"on_event_msg, invalid msg, {topic}, {payload}") return if handler: - self.log_debug('on on_event_msg, %s', payload) - msg['params']['from'] = 'cloud' - handler(msg['params'], ctx) + self.log_debug("on on_event_msg, %s", payload) + msg["params"]["from"] = "cloud" + handler(msg["params"], ctx) + return self.__reg_broadcast_external( - topic=topic, handler=on_event_msg, handler_ctx=handler_ctx) + topic=topic, handler=on_event_msg, handler_ctx=handler_ctx + ) @final def unsub_event( - self, - did: str, - siid: Optional[int] = None, - eiid: Optional[int] = None + self, did: str, siid: Optional[int] = None, eiid: Optional[int] = None ) -> bool: if not isinstance(did, str): - raise MIoTMipsError('invalid params') + raise MIoTMipsError("invalid params") # Spelling error: event_occured topic: str = ( - f'device/{did}/up/event_occured/' - f'{"#" if siid is None or eiid is None else f"{siid}/{eiid}"}') + f"device/{did}/up/event_occured/" + f"{'#' if siid is None or eiid is None else f'{siid}/{eiid}'}" + ) return self.__unreg_broadcast_external(topic=topic) @final def sub_device_state( - self, did: str, handler: Callable[[str, MIoTDeviceState, Any], None], - handler_ctx: Any = None + self, + did: str, + handler: Callable[[str, MIoTDeviceState, Any], None], + handler_ctx: Any = None, ) -> bool: """subscribe online state.""" if not isinstance(did, str) or handler is None: - raise MIoTMipsError('invalid params') - topic: str = f'device/{did}/state/#' + raise MIoTMipsError("invalid params") + topic: str = f"device/{did}/state/#" def on_state_msg(topic: str, payload: str, ctx: Any) -> None: msg: dict = json.loads(payload) # {"device_id":"xxxx","device_name":"米家智能插座3 ","event":"online", # "model": "cuco.plug.v3","timestamp":1709001070828,"uid":xxxx} - if msg is None or 'device_id' not in msg or 'event' not in msg: - self.log_error(f'on_state_msg, recv unknown msg, {payload}') + if msg is None or "device_id" not in msg or "event" not in msg: + self.log_error(f"on_state_msg, recv unknown msg, {payload}") return - if msg['device_id'] != did: - self.log_error( - f'on_state_msg, err msg, {did}!={msg["device_id"]}') + if msg["device_id"] != did: + self.log_error(f"on_state_msg, err msg, {did}!={msg['device_id']}") return if handler: - self.log_debug('cloud, device state changed, %s', payload) + self.log_debug("cloud, device state changed, %s", payload) handler( - did, MIoTDeviceState.ONLINE if msg['event'] == 'online' - else MIoTDeviceState.OFFLINE, ctx) + did, + MIoTDeviceState.ONLINE + if msg["event"] == "online" + else MIoTDeviceState.OFFLINE, + ctx, + ) + return self.__reg_broadcast_external( - topic=topic, handler=on_state_msg, handler_ctx=handler_ctx) + topic=topic, handler=on_state_msg, handler_ctx=handler_ctx + ) @final def unsub_device_state(self, did: str) -> bool: if not isinstance(did, str): - raise MIoTMipsError('invalid params') - topic: str = f'device/{did}/state/#' + raise MIoTMipsError("invalid params") + topic: str = f"device/{did}/state/#" return self.__unreg_broadcast_external(topic=topic) async def get_dev_list_async( self, payload: Optional[str] = None, timeout_ms: int = 10000 ) -> dict[str, dict]: - raise NotImplementedError('please call in http client') + raise NotImplementedError("please call in http client") async def get_prop_async( - self, did: str, siid: int, piid: int, timeout_ms: int = 10000 + self, did: str, siid: int, piid: int, timeout_ms: int = 10000 ) -> Any: - raise NotImplementedError('please call in http client') + raise NotImplementedError("please call in http client") async def set_prop_async( - self, did: str, siid: int, piid: int, value: Any, - timeout_ms: int = 10000 + self, did: str, siid: int, piid: int, value: Any, timeout_ms: int = 10000 ) -> dict: - raise NotImplementedError('please call in http client') + raise NotImplementedError("please call in http client") async def action_async( - self, did: str, siid: int, aiid: int, in_list: list, - timeout_ms: int = 10000 + self, did: str, siid: int, aiid: int, in_list: list, timeout_ms: int = 10000 ) -> dict: - raise NotImplementedError('please call in http client') + raise NotImplementedError("please call in http client") def __reg_broadcast_external( - self, topic: str, handler: Callable[[str, str, Any], None], - handler_ctx: Any = None + self, + topic: str, + handler: Callable[[str, str, Any], None], + handler_ctx: Any = None, ) -> bool: self._internal_loop.call_soon_threadsafe( - self.__reg_broadcast, topic, handler, handler_ctx) + self.__reg_broadcast, topic, handler, handler_ctx + ) return True def __unreg_broadcast_external(self, topic: str) -> bool: - self._internal_loop.call_soon_threadsafe( - self.__unreg_broadcast, topic) + self._internal_loop.call_soon_threadsafe(self.__unreg_broadcast, topic) return True def __reg_broadcast( - self, topic: str, handler: Callable[[str, str, Any], None], - handler_ctx: Any = None + self, + topic: str, + handler: Callable[[str, str, Any], None], + handler_ctx: Any = None, ) -> None: if not self._msg_matcher.get(topic=topic): sub_bc: _MipsBroadcast = _MipsBroadcast( - topic=topic, handler=handler, - handler_ctx=handler_ctx) + topic=topic, handler=handler, handler_ctx=handler_ctx + ) self._msg_matcher[topic] = sub_bc self._mips_sub_internal(topic=topic) else: - self.log_debug(f'mips cloud re-reg broadcast, {topic}') + self.log_debug(f"mips cloud re-reg broadcast, {topic}") def __unreg_broadcast(self, topic: str) -> None: if self._msg_matcher.get(topic=topic): @@ -1054,8 +1060,7 @@ class MipsCloudClient(_MipsClient): def _on_mips_connect(self, rc: int, props: dict) -> None: """sub topic.""" - for topic, _ in list( - self._msg_matcher.iter_all_nodes()): + for topic, _ in list(self._msg_matcher.iter_all_nodes()): self._mips_sub_internal(topic=topic) def _on_mips_disconnect(self, rc: int, props: dict) -> None: @@ -1067,23 +1072,24 @@ class MipsCloudClient(_MipsClient): NOTICE thread safe, this function will be called at the **mips** thread """ # broadcast - bc_list: list[_MipsBroadcast] = list( - self._msg_matcher.iter_match(topic)) + bc_list: list[_MipsBroadcast] = list(self._msg_matcher.iter_match(topic)) if not bc_list: return # The message from the cloud is not packed. - payload_str: str = payload.decode('utf-8') + payload_str: str = payload.decode("utf-8") # self.log_debug(f"on broadcast, {topic}, {payload}") for item in bc_list or []: if item.handler is None: continue # NOTICE: call threadsafe self.main_loop.call_soon_threadsafe( - item.handler, topic, payload_str, item.handler_ctx) + item.handler, topic, payload_str, item.handler_ctx + ) class MipsLocalClient(_MipsClient): """MIoT Pub/Sub Local Client.""" + # pylint: disable=unused-argument # pylint: disable=inconsistent-quotes MIPS_RECONNECT_INTERVAL_MIN: float = 6 @@ -1103,17 +1109,23 @@ class MipsLocalClient(_MipsClient): _on_dev_list_changed: Optional[Callable[[Any, list[str]], Coroutine]] def __init__( - self, did: str, host: str, group_id: str, - ca_file: str, cert_file: str, key_file: str, - port: int = 8883, home_name: str = '', - loop: Optional[asyncio.AbstractEventLoop] = None + self, + did: str, + host: str, + group_id: str, + ca_file: str, + cert_file: str, + key_file: str, + port: int = 8883, + home_name: str = "", + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: self._did = did self._group_id = group_id self._home_name = home_name self._mips_seed_id = random.randint(0, self.UINT32_MAX) - self._reply_topic = f'{did}/reply' - self._dev_list_change_topic = f'{did}/appMsg/devListChange' + self._reply_topic = f"{did}/reply" + self._dev_list_change_topic = f"{did}/appMsg/devListChange" self._request_map = {} self._msg_matcher = MIoTMatcher() self._get_prop_queue = {} @@ -1121,8 +1133,14 @@ class MipsLocalClient(_MipsClient): self._on_dev_list_changed = None super().__init__( - client_id=did, host=host, port=port, - ca_file=ca_file, cert_file=cert_file, key_file=key_file, loop=loop) + client_id=did, + host=host, + port=port, + ca_file=ca_file, + cert_file=cert_file, + key_file=key_file, + loop=loop, + ) @property def group_id(self) -> str: @@ -1130,15 +1148,15 @@ class MipsLocalClient(_MipsClient): def log_debug(self, msg, *args, **kwargs) -> None: if self._logger: - self._logger.debug(f'{self._home_name}, '+msg, *args, **kwargs) + self._logger.debug(f"{self._home_name}, " + msg, *args, **kwargs) def log_info(self, msg, *args, **kwargs) -> None: if self._logger: - self._logger.info(f'{self._home_name}, '+msg, *args, **kwargs) + self._logger.info(f"{self._home_name}, " + msg, *args, **kwargs) def log_error(self, msg, *args, **kwargs) -> None: if self._logger: - self._logger.error(f'{self._home_name}, '+msg, *args, **kwargs) + self._logger.error(f"{self._home_name}, " + msg, *args, **kwargs) @final def connect(self, thread_name: Optional[str] = None) -> None: @@ -1158,39 +1176,40 @@ class MipsLocalClient(_MipsClient): handler: Callable[[dict, Any], None], siid: Optional[int] = None, piid: Optional[int] = None, - handler_ctx: Any = None + handler_ctx: Any = None, ) -> bool: topic: str = ( - f'appMsg/notify/iot/{did}/property/' - f'{"#" if siid is None or piid is None else f"{siid}.{piid}"}') + f"appMsg/notify/iot/{did}/property/" + f"{'#' if siid is None or piid is None else f'{siid}.{piid}'}" + ) def on_prop_msg(topic: str, payload: str, ctx: Any): msg: dict = json.loads(payload) if ( msg is None - or 'did' not in msg - or 'siid' not in msg - or 'piid' not in msg - or 'value' not in msg + or "did" not in msg + or "siid" not in msg + or "piid" not in msg + or "value" not in msg ): # self.log_error(f'on_prop_msg, recv unknown msg, {payload}') return if handler: - self.log_debug('local, on properties_changed, %s', payload) + self.log_debug("local, on properties_changed, %s", payload) handler(msg, ctx) + return self.__reg_broadcast_external( - topic=topic, handler=on_prop_msg, handler_ctx=handler_ctx) + topic=topic, handler=on_prop_msg, handler_ctx=handler_ctx + ) @final def unsub_prop( - self, - did: str, - siid: Optional[int] = None, - piid: Optional[int] = None + self, did: str, siid: Optional[int] = None, piid: Optional[int] = None ) -> bool: topic: str = ( - f'appMsg/notify/iot/{did}/property/' - f'{"#" if siid is None or piid is None else f"{siid}.{piid}"}') + f"appMsg/notify/iot/{did}/property/" + f"{'#' if siid is None or piid is None else f'{siid}.{piid}'}" + ) return self.__unreg_broadcast_external(topic=topic) @final @@ -1200,42 +1219,43 @@ class MipsLocalClient(_MipsClient): handler: Callable[[dict, Any], None], siid: Optional[int] = None, eiid: Optional[int] = None, - handler_ctx: Any = None + handler_ctx: Any = None, ) -> bool: topic: str = ( - f'appMsg/notify/iot/{did}/event/' - f'{"#" if siid is None or eiid is None else f"{siid}.{eiid}"}') + f"appMsg/notify/iot/{did}/event/" + f"{'#' if siid is None or eiid is None else f'{siid}.{eiid}'}" + ) def on_event_msg(topic: str, payload: str, ctx: Any): msg: dict = json.loads(payload) if ( msg is None - or 'did' not in msg - or 'siid' not in msg - or 'eiid' not in msg + or "did" not in msg + or "siid" not in msg + or "eiid" not in msg # or 'arguments' not in msg ): - self.log_info('unknown event msg, %s', payload) + self.log_info("unknown event msg, %s", payload) return - if 'arguments' not in msg: - self.log_info('wrong event msg, %s', payload) - msg['arguments'] = [] + if "arguments" not in msg: + self.log_info("wrong event msg, %s", payload) + msg["arguments"] = [] if handler: - self.log_debug('local, on event_occurred, %s', payload) + self.log_debug("local, on event_occurred, %s", payload) handler(msg, ctx) + return self.__reg_broadcast_external( - topic=topic, handler=on_event_msg, handler_ctx=handler_ctx) + topic=topic, handler=on_event_msg, handler_ctx=handler_ctx + ) @final def unsub_event( - self, - did: str, - siid: Optional[int] = None, - eiid: Optional[int] = None + self, did: str, siid: Optional[int] = None, eiid: Optional[int] = None ) -> bool: topic: str = ( - f'appMsg/notify/iot/{did}/event/' - f'{"#" if siid is None or eiid is None else f"{siid}.{eiid}"}') + f"appMsg/notify/iot/{did}/event/" + f"{'#' if siid is None or eiid is None else f'{siid}.{eiid}'}" + ) return self.__unreg_broadcast_external(topic=topic) @final @@ -1244,20 +1264,17 @@ class MipsLocalClient(_MipsClient): ) -> Any: self._get_prop_queue.setdefault(did, []) fut: asyncio.Future = self.main_loop.create_future() - self._get_prop_queue[did].append({ - 'param': json.dumps({ - 'did': did, - 'siid': siid, - 'piid': piid - }), - 'fut': fut, - 'timeout_ms': timeout_ms - }) + self._get_prop_queue[did].append( + { + "param": json.dumps({"did": did, "siid": siid, "piid": piid}), + "fut": fut, + "timeout_ms": timeout_ms, + } + ) if self._get_prop_timer is None: self._get_prop_timer = self.main_loop.call_later( - 0.1, - self.main_loop.create_task, - self.__get_prop_timer_handle()) + 0.1, self.main_loop.create_task, self.__get_prop_timer_handle() + ) return await fut @final @@ -1265,153 +1282,166 @@ class MipsLocalClient(_MipsClient): self, did: str, siid: int, piid: int, timeout_ms: int = 10000 ) -> Any: result_obj = await self.__request_async( - topic='proxy/get', - payload=json.dumps({ - 'did': did, - 'siid': siid, - 'piid': piid - }), - timeout_ms=timeout_ms) - if not isinstance(result_obj, dict) or 'value' not in result_obj: + topic="proxy/get", + payload=json.dumps({"did": did, "siid": siid, "piid": piid}), + timeout_ms=timeout_ms, + ) + if not isinstance(result_obj, dict) or "value" not in result_obj: return None - return result_obj['value'] + return result_obj["value"] @final async def set_prop_async( - self, did: str, siid: int, piid: int, value: Any, - timeout_ms: int = 10000 + self, did: str, siid: int, piid: int, value: Any, timeout_ms: int = 10000 ) -> dict: payload_obj: dict = { - 'did': did, - 'rpc': { - 'id': self.__gen_mips_id, - 'method': 'set_properties', - 'params': [{ - 'did': did, - 'siid': siid, - 'piid': piid, - 'value': value - }] - } + "did": did, + "rpc": { + "id": self.__gen_mips_id, + "method": "set_properties", + "params": [{"did": did, "siid": siid, "piid": piid, "value": value}], + }, } result_obj = await self.__request_async( - topic='proxy/rpcReq', - payload=json.dumps(payload_obj), - timeout_ms=timeout_ms) + topic="proxy/rpcReq", payload=json.dumps(payload_obj), timeout_ms=timeout_ms + ) if result_obj: if ( - 'result' in result_obj - and len(result_obj['result']) == 1 - and 'did' in result_obj['result'][0] - and result_obj['result'][0]['did'] == did - and 'code' in result_obj['result'][0] + "result" in result_obj + and len(result_obj["result"]) == 1 + and "did" in result_obj["result"][0] + and result_obj["result"][0]["did"] == did + and "code" in result_obj["result"][0] ): - return result_obj['result'][0] - if 'error' in result_obj: - return result_obj['error'] + return result_obj["result"][0] + if "error" in result_obj: + return result_obj["error"] return { - 'code': MIoTErrorCode.CODE_INTERNAL_ERROR.value, - 'message': 'Invalid result'} + "code": MIoTErrorCode.CODE_INTERNAL_ERROR.value, + "message": "Invalid result", + } + + @final + async def set_props_async( + self, did: str, props_list: List[Dict[str, Any]], timeout_ms: int = 10000 + ) -> dict: + payload_obj: dict = { + "did": did, + "rpc": { + "id": self.__gen_mips_id, + "method": "set_properties", + "params": props_list, + }, + } + result_obj = await self.__request_async( + topic="proxy/rpcReq", payload=json.dumps(payload_obj), timeout_ms=timeout_ms + ) + if result_obj: + if ( + "result" in result_obj + and len(result_obj["result"]) == len(props_list) + and result_obj["result"][0].get("did") == did + and all("code" in item for item in result_obj["result"]) + ): + return result_obj["result"] + if "error" in result_obj: + return result_obj["error"] + return { + "code": MIoTErrorCode.CODE_INTERNAL_ERROR.value, + "message": "Invalid result", + } @final async def action_async( - self, did: str, siid: int, aiid: int, in_list: list, - timeout_ms: int = 10000 + self, did: str, siid: int, aiid: int, in_list: list, timeout_ms: int = 10000 ) -> dict: payload_obj: dict = { - 'did': did, - 'rpc': { - 'id': self.__gen_mips_id, - 'method': 'action', - 'params': { - 'did': did, - 'siid': siid, - 'aiid': aiid, - 'in': in_list - } - } + "did": did, + "rpc": { + "id": self.__gen_mips_id, + "method": "action", + "params": {"did": did, "siid": siid, "aiid": aiid, "in": in_list}, + }, } result_obj = await self.__request_async( - topic='proxy/rpcReq', payload=json.dumps(payload_obj), - timeout_ms=timeout_ms) + topic="proxy/rpcReq", payload=json.dumps(payload_obj), timeout_ms=timeout_ms + ) if result_obj: - if 'result' in result_obj and 'code' in result_obj['result']: - return result_obj['result'] - if 'error' in result_obj: - return result_obj['error'] + if "result" in result_obj and "code" in result_obj["result"]: + return result_obj["result"] + if "error" in result_obj: + return result_obj["error"] return { - 'code': MIoTErrorCode.CODE_INTERNAL_ERROR.value, - 'message': 'Invalid result'} + "code": MIoTErrorCode.CODE_INTERNAL_ERROR.value, + "message": "Invalid result", + } @final async def get_dev_list_async( self, payload: Optional[str] = None, timeout_ms: int = 10000 ) -> dict[str, dict]: result_obj = await self.__request_async( - topic='proxy/getDevList', payload=payload or '{}', - timeout_ms=timeout_ms) - if not result_obj or 'devList' not in result_obj: - raise MIoTMipsError('invalid result') + topic="proxy/getDevList", payload=payload or "{}", timeout_ms=timeout_ms + ) + if not result_obj or "devList" not in result_obj: + raise MIoTMipsError("invalid result") device_list = {} - for did, info in result_obj['devList'].items(): - name: str = info.get('name', None) - urn: str = info.get('urn', None) - model: str = info.get('model', None) + for did, info in result_obj["devList"].items(): + name: str = info.get("name", None) + urn: str = info.get("urn", None) + model: str = info.get("model", None) if name is None or urn is None or model is None: - self.log_error(f'invalid device info, {did}, {info}') + self.log_error(f"invalid device info, {did}, {info}") continue device_list[did] = { - 'did': did, - 'name': name, - 'urn': urn, - 'model': model, - 'online': info.get('online', False), - 'icon': info.get('icon', None), - 'fw_version': None, - 'home_id': '', - 'home_name': '', - 'room_id': info.get('roomId', ''), - 'room_name': info.get('roomName', ''), - 'specv2_access': info.get('specV2Access', False), - 'push_available': info.get('pushAvailable', False), - 'manufacturer': model.split('.')[0], + "did": did, + "name": name, + "urn": urn, + "model": model, + "online": info.get("online", False), + "icon": info.get("icon", None), + "fw_version": None, + "home_id": "", + "home_name": "", + "room_id": info.get("roomId", ""), + "room_name": info.get("roomName", ""), + "specv2_access": info.get("specV2Access", False), + "push_available": info.get("pushAvailable", False), + "manufacturer": model.split(".")[0], } return device_list @final - async def get_action_group_list_async( - self, timeout_ms: int = 10000 - ) -> list[str]: + async def get_action_group_list_async(self, timeout_ms: int = 10000) -> list[str]: result_obj = await self.__request_async( - topic='proxy/getMijiaActionGroupList', - payload='{}', - timeout_ms=timeout_ms) - if not result_obj or 'result' not in result_obj: - raise MIoTMipsError('invalid result') - return result_obj['result'] + topic="proxy/getMijiaActionGroupList", payload="{}", timeout_ms=timeout_ms + ) + if not result_obj or "result" not in result_obj: + raise MIoTMipsError("invalid result") + return result_obj["result"] @final async def exec_action_group_list_async( self, ag_id: str, timeout_ms: int = 10000 ) -> dict: result_obj = await self.__request_async( - topic='proxy/execMijiaActionGroup', + topic="proxy/execMijiaActionGroup", payload=f'{{"id":"{ag_id}"}}', - timeout_ms=timeout_ms) + timeout_ms=timeout_ms, + ) if result_obj: - if 'result' in result_obj: - return result_obj['result'] - if 'error' in result_obj: - return result_obj['error'] + if "result" in result_obj: + return result_obj["result"] + if "error" in result_obj: + return result_obj["error"] return { - 'code': MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value, - 'message': 'invalid result'} + "code": MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value, + "message": "invalid result", + } @final @property - def on_dev_list_changed( - self - ) -> Optional[Callable[[Any, list[str]], Coroutine]]: + def on_dev_list_changed(self) -> Optional[Callable[[Any, list[str]], Coroutine]]: return self._on_dev_list_changed @final @@ -1423,70 +1453,75 @@ class MipsLocalClient(_MipsClient): self._on_dev_list_changed = func def __request( - self, topic: str, payload: str, - on_reply: Callable[[str, Any], None], - on_reply_ctx: Any = None, timeout_ms: int = 10000 + self, + topic: str, + payload: str, + on_reply: Callable[[str, Any], None], + on_reply_ctx: Any = None, + timeout_ms: int = 10000, ) -> None: req = _MipsRequest( mid=self.__gen_mips_id, on_reply=on_reply, on_reply_ctx=on_reply_ctx, - timer=None) - pub_topic: str = f'master/{topic}' + timer=None, + ) + pub_topic: str = f"master/{topic}" result = self.__mips_publish( - topic=pub_topic, payload=payload, mid=req.mid, - ret_topic=self._reply_topic) + topic=pub_topic, payload=payload, mid=req.mid, ret_topic=self._reply_topic + ) self.log_debug( - f'mips local call api, {result}, {req.mid}, {pub_topic}, ' - f'{payload}') + f"mips local call api, {result}, {req.mid}, {pub_topic}, {payload}" + ) def on_request_timeout(req: _MipsRequest): self.log_error( - f'on mips request timeout, {req.mid}, {pub_topic}' - f', {payload}') + f"on mips request timeout, {req.mid}, {pub_topic}, {payload}" + ) self._request_map.pop(str(req.mid), None) req.on_reply( - '{"error":{"code":-10006, "message":"timeout"}}', - req.on_reply_ctx) + '{"error":{"code":-10006, "message":"timeout"}}', req.on_reply_ctx + ) + req.timer = self._internal_loop.call_later( - timeout_ms/1000, on_request_timeout, req) + timeout_ms / 1000, on_request_timeout, req + ) self._request_map[str(req.mid)] = req def __reg_broadcast( - self, topic: str, handler: Callable[[str, str, Any], None], - handler_ctx: Any + self, topic: str, handler: Callable[[str, str, Any], None], handler_ctx: Any ) -> None: - sub_topic: str = f'{self._did}/{topic}' + sub_topic: str = f"{self._did}/{topic}" if not self._msg_matcher.get(sub_topic): sub_bc: _MipsBroadcast = _MipsBroadcast( - topic=sub_topic, handler=handler, - handler_ctx=handler_ctx) + topic=sub_topic, handler=handler, handler_ctx=handler_ctx + ) self._msg_matcher[sub_topic] = sub_bc - self._mips_sub_internal(topic=f'master/{topic}') + self._mips_sub_internal(topic=f"master/{topic}") else: - self.log_debug(f'mips re-reg broadcast, {sub_topic}') + self.log_debug(f"mips re-reg broadcast, {sub_topic}") def __unreg_broadcast(self, topic) -> None: # Central hub gateway needs to add prefix - unsub_topic: str = f'{self._did}/{topic}' + unsub_topic: str = f"{self._did}/{topic}" if self._msg_matcher.get(unsub_topic): del self._msg_matcher[unsub_topic] self._mips_unsub_internal( - topic=re.sub(f'^{self._did}', 'master', unsub_topic)) + topic=re.sub(f"^{self._did}", "master", unsub_topic) + ) @final def _on_mips_connect(self, rc: int, props: dict) -> None: - self.log_debug('__on_mips_connect_handler') + self.log_debug("__on_mips_connect_handler") # Sub did/#, include reply topic - self._mips_sub_internal(f'{self._did}/#') + self._mips_sub_internal(f"{self._did}/#") # Sub device list change - self._mips_sub_internal('master/appMsg/devListChange') + self._mips_sub_internal("master/appMsg/devListChange") # Do not need to subscribe api topics, for they are covered by did/# # Sub api topic. # Sub broadcast topic for topic, _ in list(self._msg_matcher.iter_all_nodes()): - self._mips_sub_internal( - topic=re.sub(f'^{self._did}', 'master', topic)) + self._mips_sub_internal(topic=re.sub(f"^{self._did}", "master", topic)) @final def _on_mips_disconnect(self, rc: int, props: dict) -> None: @@ -1499,54 +1534,54 @@ class MipsLocalClient(_MipsClient): # f"mips local client, on_message, {topic} -> {mips_msg}") # Reply if topic == self._reply_topic: - self.log_debug(f'on request reply, {mips_msg}') - req: Optional[_MipsRequest] = self._request_map.pop( - str(mips_msg.mid), None) + self.log_debug(f"on request reply, {mips_msg}") + req: Optional[_MipsRequest] = self._request_map.pop(str(mips_msg.mid), None) if req: # Cancel timer if req.timer: req.timer.cancel() if req.on_reply: self.main_loop.call_soon_threadsafe( - req.on_reply, mips_msg.payload or '{}', - req.on_reply_ctx) + req.on_reply, mips_msg.payload or "{}", req.on_reply_ctx + ) return # Broadcast - bc_list: list[_MipsBroadcast] = list(self._msg_matcher.iter_match( - topic=topic)) + bc_list: list[_MipsBroadcast] = list(self._msg_matcher.iter_match(topic=topic)) if bc_list: - self.log_debug(f'on broadcast, {topic}, {mips_msg}') + self.log_debug(f"on broadcast, {topic}, {mips_msg}") for item in bc_list or []: if item.handler is None: continue self.main_loop.call_soon_threadsafe( - item.handler, topic[topic.find('/')+1:], - mips_msg.payload or '{}', item.handler_ctx) + item.handler, + topic[topic.find("/") + 1 :], + mips_msg.payload or "{}", + item.handler_ctx, + ) return # Device list change if topic == self._dev_list_change_topic: if mips_msg.payload is None: - self.log_error('devListChange msg is None') + self.log_error("devListChange msg is None") return payload_obj: dict = json.loads(mips_msg.payload) - dev_list = payload_obj.get('devList', None) + dev_list = payload_obj.get("devList", None) if not isinstance(dev_list, list) or not dev_list: - _LOGGER.error( - 'unknown devListChange msg, %s', mips_msg.payload) + _LOGGER.error("unknown devListChange msg, %s", mips_msg.payload) return if self._on_dev_list_changed: self.main_loop.call_soon_threadsafe( self.main_loop.create_task, - self._on_dev_list_changed(self, dev_list)) + self._on_dev_list_changed(self, dev_list), + ) return - self.log_debug( - f'mips local client, recv unknown msg, {topic} -> {mips_msg}') + self.log_debug(f"mips local client, recv unknown msg, {topic} -> {mips_msg}") @property def __gen_mips_id(self) -> int: mips_id: int = self._mips_seed_id - self._mips_seed_id = int((self._mips_seed_id+1) % self.UINT32_MAX) + self._mips_seed_id = int((self._mips_seed_id + 1) % self.UINT32_MAX) return mips_id def __mips_publish( @@ -1556,38 +1591,46 @@ class MipsLocalClient(_MipsClient): mid: Optional[int] = None, ret_topic: Optional[str] = None, wait_for_publish: bool = False, - timeout_ms: int = 10000 + timeout_ms: int = 10000, ) -> bool: mips_msg: bytes = _MipsMessage.pack( - mid=mid or self.__gen_mips_id, payload=payload, - msg_from='local', ret_topic=ret_topic) + mid=mid or self.__gen_mips_id, + payload=payload, + msg_from="local", + ret_topic=ret_topic, + ) return self._mips_publish_internal( - topic=topic.strip(), payload=mips_msg, - wait_for_publish=wait_for_publish, timeout_ms=timeout_ms) + topic=topic.strip(), + payload=mips_msg, + wait_for_publish=wait_for_publish, + timeout_ms=timeout_ms, + ) def __request_external( - self, topic: str, payload: str, - on_reply: Callable[[str, Any], None], - on_reply_ctx: Any = None, timeout_ms: int = 10000 + self, + topic: str, + payload: str, + on_reply: Callable[[str, Any], None], + on_reply_ctx: Any = None, + timeout_ms: int = 10000, ) -> bool: if topic is None or payload is None or on_reply is None: - raise MIoTMipsError('invalid params') + raise MIoTMipsError("invalid params") self._internal_loop.call_soon_threadsafe( - self.__request, topic, payload, on_reply, on_reply_ctx, timeout_ms) + self.__request, topic, payload, on_reply, on_reply_ctx, timeout_ms + ) return True def __reg_broadcast_external( - self, topic: str, handler: Callable[[str, str, Any], None], - handler_ctx: Any + self, topic: str, handler: Callable[[str, str, Any], None], handler_ctx: Any ) -> bool: self._internal_loop.call_soon_threadsafe( - self.__reg_broadcast, - topic, handler, handler_ctx) + self.__reg_broadcast, topic, handler, handler_ctx + ) return True def __unreg_broadcast_external(self, topic) -> bool: - self._internal_loop.call_soon_threadsafe( - self.__unreg_broadcast, topic) + self._internal_loop.call_soon_threadsafe(self.__unreg_broadcast, topic) return True @final @@ -1600,42 +1643,44 @@ class MipsLocalClient(_MipsClient): fut: asyncio.Future = ctx if fut: self.main_loop.call_soon_threadsafe(fut.set_result, payload) + if not self.__request_external( - topic=topic, - payload=payload, - on_reply=on_msg_reply, - on_reply_ctx=fut_handler, - timeout_ms=timeout_ms): + topic=topic, + payload=payload, + on_reply=on_msg_reply, + on_reply_ctx=fut_handler, + timeout_ms=timeout_ms, + ): # Request error - fut_handler.set_result('internal request error') + fut_handler.set_result("internal request error") result = await fut_handler try: return json.loads(result) except json.JSONDecodeError: return { - 'code': MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value, - 'message': f'Error: {result}'} + "code": MIoTErrorCode.CODE_MIPS_INVALID_RESULT.value, + "message": f"Error: {result}", + } async def __get_prop_timer_handle(self) -> None: for did in list(self._get_prop_queue.keys()): item = self._get_prop_queue[did].pop() - _LOGGER.debug('get prop, %s, %s', did, item) + _LOGGER.debug("get prop, %s, %s", did, item) result_obj = await self.__request_async( - topic='proxy/get', - payload=item['param'], - timeout_ms=item['timeout_ms']) - if result_obj is None or 'value' not in result_obj: - item['fut'].set_result(None) + topic="proxy/get", payload=item["param"], timeout_ms=item["timeout_ms"] + ) + if result_obj is None or "value" not in result_obj: + item["fut"].set_result(None) else: - item['fut'].set_result(result_obj['value']) + item["fut"].set_result(result_obj["value"]) if not self._get_prop_queue[did]: self._get_prop_queue.pop(did, None) if self._get_prop_queue: self._get_prop_timer = self.main_loop.call_later( - 0.1, lambda: self.main_loop.create_task( - self.__get_prop_timer_handle())) + 0.1, lambda: self.main_loop.create_task(self.__get_prop_timer_handle()) + ) else: self._get_prop_timer = None diff --git a/custom_components/xiaomi_home/select.py b/custom_components/xiaomi_home/select.py index 21b5e78..735c99a 100644 --- a/custom_components/xiaomi_home/select.py +++ b/custom_components/xiaomi_home/select.py @@ -45,18 +45,24 @@ off Xiaomi or its affiliates' products. Select entities for Xiaomi Home. """ + from __future__ import annotations +import logging from typing import Optional from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.components.select import SelectEntity +from homeassistant.helpers.entity import EntityCategory +from homeassistant.helpers.restore_state import RestoreEntity from .miot.const import DOMAIN from .miot.miot_device import MIoTDevice, MIoTPropertyEntity from .miot.miot_spec import MIoTSpecProperty +_LOGGER = logging.getLogger(__name__) + async def async_setup_entry( hass: HomeAssistant, @@ -64,17 +70,32 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up a config entry.""" - device_list: list[MIoTDevice] = hass.data[DOMAIN]['devices'][ - config_entry.entry_id] + device_list: list[MIoTDevice] = hass.data[DOMAIN]["devices"][config_entry.entry_id] new_entities = [] for miot_device in device_list: - for prop in miot_device.prop_list.get('select', []): + for prop in miot_device.prop_list.get("select", []): new_entities.append(Select(miot_device=miot_device, spec=prop)) if new_entities: async_add_entities(new_entities) + # create select for light + new_light_select_entities = [] + for miot_device in device_list: + if "device:light" in miot_device.spec_instance.urn: + if miot_device.entity_list.get("light", []): + device_id = list(miot_device.device_info.get("identifiers"))[0][1] + light_entity_id = miot_device.gen_device_entity_id(DOMAIN) + new_light_select_entities.append( + LightCommandSendMode( + hass=hass, light_entity_id=light_entity_id, device_id=device_id + ) + ) + + if new_light_select_entities: + async_add_entities(new_light_select_entities) + class Select(MIoTPropertyEntity, SelectEntity): """Select entities for Xiaomi Home.""" @@ -87,10 +108,46 @@ class Select(MIoTPropertyEntity, SelectEntity): async def async_select_option(self, option: str) -> None: """Change the selected option.""" - await self.set_property_async( - value=self.get_vlist_value(description=option)) + await self.set_property_async(value=self.get_vlist_value(description=option)) @property def current_option(self) -> Optional[str]: """Return the current selected option.""" return self.get_vlist_description(value=self._value) + + +class LightCommandSendMode(SelectEntity, RestoreEntity): + """To control whether to turn on the light, you need to send the light-on command first and + then send other color temperatures and brightness or send them all at the same time. + The default is to send one by one.""" + + def __init__(self, hass: HomeAssistant, light_entity_id: str, device_id: str): + super().__init__() + self.hass = hass + self._device_id = device_id + self._attr_name = "Command Send Mode" + self._attr_unique_id = f"{light_entity_id}_command_send_mode" + self._attr_options = ["Send One by One", "Send Turn On First", "Send Together"] + self._attr_device_info = {"identifiers": {(DOMAIN, device_id)}} + self._attr_current_option = self._attr_options[0] # 默认选项 + self._attr_entity_category = ( + EntityCategory.CONFIG + ) # **重点:告诉 HA 这是配置类实体** + + async def async_select_option(self, option: str): + """处理用户选择的选项。""" + if option in self._attr_options: + self._attr_current_option = option + self.async_write_ha_state() + + async def async_added_to_hass(self): + """在实体添加到 Home Assistant 时恢复上次的状态。""" + await super().async_added_to_hass() + if ( + last_state := await self.async_get_last_state() + ) and last_state.state in self._attr_options: + self._attr_current_option = last_state.state + + @property + def current_option(self): + return self._attr_current_option