Expand source code
class WatchServer:
    """Parent-side coordinator for a file-watching, auto-restarting worker.

    Watches the paths returned by ``_find_watchable_paths(main_file)`` and runs
    ``worker_runner`` under ``watchfiles.arun_process``, which restarts it when
    a watched Python file changes. A socket pair connects this object to the
    worker. On each reload the current worker is asked for its active jobs, and
    that list is handed back when a worker sends ``ReloadJobsRequest``.
    """

    def __init__(
        self,
        worker_runner: Callable[[proto.CliArgs], Any],
        main_file: pathlib.Path,
        cli_args: proto.CliArgs,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Create the IPC socket pair and initialise reload bookkeeping.

        Args:
            worker_runner: Target passed to ``watchfiles.arun_process``; it is
                called with ``cli_args`` as its only argument.
            main_file: File from which the watchable paths are discovered.
            cli_args: Arguments forwarded to the worker. This object is
                mutated: ``mp_cch`` is set to one end of the socket pair and
                ``reload_count`` is incremented on each handled reload.
            loop: Event loop used for the IPC reader task and for futures.
        """
        # Keep one end of the pair here; the other end reaches the worker
        # through ``cli_args`` (passed as ``args`` to arun_process in run()).
        self._mp_pch, cli_args.mp_cch = socket.socketpair()
        self._cli_args = cli_args
        self._worker_runner = worker_runner
        self._main_file = main_file
        self._loop = loop
        # Bind the future to the loop we were given instead of letting
        # asyncio.Future() look one up implicitly (which may not be the right
        # loop, or may warn, when __init__ runs outside a running loop).
        self._recv_jobs_fut: asyncio.Future[None] = loop.create_future()
        # True from the start of a reload until the worker reports ``Reloaded``.
        self._worker_reloading = False

    async def run(self) -> None:
        """Watch source files and run the worker until ``arun_process`` returns.

        Opens the async duplex over this side of the socket pair and starts the
        IPC reader task. On exit, always cancels the reader and closes the
        channel.
        """
        watch_paths = _find_watchable_paths(self._main_file)
        for pth in watch_paths:
            logger.log(DEV_LEVEL, f"Watching {pth}")

        # ``self._pch`` is used by _on_reload/_read_ipc_task, which only run
        # while arun_process below is active.
        self._pch = await utils.aio.duplex_unix._AsyncDuplex.open(self._mp_pch)
        read_ipc_task = self._loop.create_task(self._read_ipc_task())
        try:
            await watchfiles.arun_process(
                *watch_paths,
                target=self._worker_runner,
                args=(self._cli_args,),
                watch_filter=watchfiles.filters.PythonFilter(),
                callback=self._on_reload,
            )
        finally:
            await utils.aio.gracefully_cancel(read_ipc_task)
            await self._pch.aclose()

    async def _on_reload(self, _: Set[watchfiles.main.FileChange]) -> None:
        """Ask the running worker for its active jobs; used as the reload callback.

        Does nothing while a previous reload is still in progress (no
        ``Reloaded`` message has arrived since it started). Otherwise, sends
        ``ActiveJobsRequest`` and waits at most 1.5 s for the matching
        ``ActiveJobsResponse``. Afterwards, ``reload_count`` is incremented
        whether or not a response arrived.
        """
        if self._worker_reloading:
            return

        self._worker_reloading = True
        try:
            # Create the future *before* sending the request. The reader task
            # can receive the response while asend_message is still awaiting.
            # If the future were replaced afterwards, the response would
            # resolve the stale future and we would always hit the timeout.
            self._recv_jobs_fut = self._loop.create_future()
            await channel.asend_message(self._pch, proto.ActiveJobsRequest())
            with contextlib.suppress(asyncio.TimeoutError):
                # wait max 1.5s to get the active jobs; on timeout wait_for
                # cancels the future.
                await asyncio.wait_for(self._recv_jobs_fut, timeout=1.5)
        finally:
            self._cli_args.reload_count += 1

    @utils.log_exceptions(logger=logger)
    async def _read_ipc_task(self) -> None:
        """Handle IPC messages from the worker until the task is cancelled.

        - ``ActiveJobsResponse`` whose ``reload_count`` matches ours: remember
          its jobs and wake up ``_on_reload``.
        - ``ReloadJobsRequest``: reply with the most recently remembered jobs.
        - ``Reloaded``: clear the reload-in-progress flag.
        """
        active_jobs = []
        while True:
            msg = await channel.arecv_message(self._pch, proto.IPC_MESSAGES)
            if isinstance(msg, proto.ActiveJobsResponse):
                # Drop responses tagged with a different reload_count
                # (presumably late answers to an earlier reload).
                if msg.reload_count != self._cli_args.reload_count:
                    continue
                active_jobs = msg.jobs
                # The future may already be done, e.g. cancelled by wait_for
                # after the 1.5 s timeout in _on_reload.
                with contextlib.suppress(asyncio.InvalidStateError):
                    self._recv_jobs_fut.set_result(None)
            elif isinstance(msg, proto.ReloadJobsRequest):
                await channel.asend_message(
                    self._pch, proto.ReloadJobsResponse(jobs=active_jobs)
                )
            elif isinstance(msg, proto.Reloaded):
                self._worker_reloading = False