fix racing condition

This commit is contained in:
Feng Wang 2024-12-22 19:51:58 +08:00
parent 463216d866
commit d13a6bfb11

View File

@ -491,6 +491,7 @@ class MIoTLan:
_profile_models: dict[str, dict] _profile_models: dict[str, dict]
_init_lock: asyncio.Lock
_init_done: bool _init_done: bool
# The following should be called from the main loop # The following should be called from the main loop
@ -547,6 +548,7 @@ class MIoTLan:
self._lan_state_sub_map = {} self._lan_state_sub_map = {}
self._lan_ctrl_vote_map = {} self._lan_ctrl_vote_map = {}
self._init_lock = asyncio.Lock()
self._init_done = False self._init_done = False
if ( if (
@ -571,44 +573,46 @@ class MIoTLan:
return self._init_done return self._init_done
async def init_async(self) -> None: async def init_async(self) -> None:
if self._init_done: # Avoid race condition
_LOGGER.info('miot lan already init') async with self._init_lock:
return if self._init_done:
if len(self._net_ifs) == 0: _LOGGER.info('miot lan already init')
_LOGGER.info('no net_ifs') return
return if len(self._net_ifs) == 0:
if not any(self._lan_ctrl_vote_map.values()): _LOGGER.info('no net_ifs')
_LOGGER.info('no vote for lan ctrl') return
return if not any(self._lan_ctrl_vote_map.values()):
if len(self._mips_service.get_services()) > 0: _LOGGER.info('no vote for lan ctrl')
_LOGGER.info('central hub gateway service exist') return
return if len(self._mips_service.get_services()) > 0:
for if_name in list(self._network.network_info.keys()): _LOGGER.info('central hub gateway service exist')
self._available_net_ifs.add(if_name) return
if len(self._available_net_ifs) == 0: for if_name in list(self._network.network_info.keys()):
_LOGGER.info('no available net_ifs') self._available_net_ifs.add(if_name)
return if len(self._available_net_ifs) == 0:
if self._net_ifs.isdisjoint(self._available_net_ifs): _LOGGER.info('no available net_ifs')
_LOGGER.info('no valid net_ifs') return
return if self._net_ifs.isdisjoint(self._available_net_ifs):
try: _LOGGER.info('no valid net_ifs')
self._profile_models = await self._main_loop.run_in_executor( return
None, load_yaml_file, try:
gen_absolute_path(self.PROFILE_MODELS_FILE)) self._profile_models = await self._main_loop.run_in_executor(
except Exception as err: # pylint: disable=broad-exception-caught None, load_yaml_file,
_LOGGER.error('load profile models error, %s', err) gen_absolute_path(self.PROFILE_MODELS_FILE))
self._profile_models = {} except Exception as err: # pylint: disable=broad-exception-caught
self._internal_loop = asyncio.new_event_loop() _LOGGER.error('load profile models error, %s', err)
# All tasks meant for the internal loop should happen in this thread self._profile_models = {}
self._thread = threading.Thread(target=self.__internal_loop_thread) self._internal_loop = asyncio.new_event_loop()
self._thread.name = 'miot_lan' # All tasks meant for the internal loop should happen in this thread
self._thread.daemon = True self._thread = threading.Thread(target=self.__internal_loop_thread)
self._thread.start() self._thread.name = 'miot_lan'
self._init_done = True self._thread.daemon = True
for handler in list(self._lan_state_sub_map.values()): self._thread.start()
self._main_loop.create_task(handler(True)) self._init_done = True
_LOGGER.info( for handler in list(self._lan_state_sub_map.values()):
'miot lan init, %s ,%s', self._net_ifs, self._available_net_ifs) self._main_loop.create_task(handler(True))
_LOGGER.info(
'miot lan init, %s ,%s', self._net_ifs, self._available_net_ifs)
def __internal_loop_thread(self) -> None: def __internal_loop_thread(self) -> None:
_LOGGER.info('miot lan thread start') _LOGGER.info('miot lan thread start')
@ -1347,7 +1351,7 @@ class MIoTLan:
scan_time = self.__get_next_scan_time() scan_time = self.__get_next_scan_time()
self._scan_timer = self._internal_loop.call_later( self._scan_timer = self._internal_loop.call_later(
scan_time, self.__scan_devices) scan_time, self.__scan_devices)
_LOGGER.debug('next scan time: %sms', scan_time) _LOGGER.debug('next scan time: %ss', scan_time)
def __get_next_scan_time(self) -> float: def __get_next_scan_time(self) -> float:
if not self._last_scan_interval: if not self._last_scan_interval: