| skipped 2 lines |
3 | 3 | | import re |
4 | 4 | | import ssl |
5 | 5 | | import sys |
| 6 | + | import tqdm |
| 7 | + | import time |
| 8 | + | from typing import Callable, Any, Iterable, Tuple |
6 | 9 | | |
7 | 10 | | import aiohttp |
8 | 11 | | import tqdm.asyncio |
| skipped 27 lines |
36 | 39 | | } |
37 | 40 | | |
38 | 41 | | unsupported_characters = '#' |
| 42 | + | |
# A query draft is (coroutine function, positional args, keyword args).
QueryDraft = Tuple[Callable, Any, Any]
QueriesDraft = Iterable[QueryDraft]


class AsyncExecutor:
    """Base class for asynchronous task executors.

    Subclasses implement ``_run``; ``run`` wraps it with elapsed-time
    measurement and debug logging.
    """

    def __init__(self, *args, **kwargs):
        # A logger is required; raises KeyError if not supplied.
        self.logger = kwargs['logger']

    async def run(self, tasks: QueriesDraft):
        """Execute all tasks, record the elapsed time, and return results."""
        # monotonic() is immune to wall-clock adjustments (NTP, DST),
        # unlike time.time(), so elapsed measurements cannot go negative.
        start_time = time.monotonic()
        results = await self._run(tasks)
        self.execution_time = time.monotonic() - start_time
        self.logger.debug(f'Spent time: {self.execution_time}')
        return results

    async def _run(self, tasks: QueriesDraft):
        # Default no-op implementation; subclasses override this.
        await asyncio.sleep(0)
| 59 | + | |
| 60 | + | |
class AsyncioSimpleExecutor(AsyncExecutor):
    """Runs every task at once via asyncio.gather, preserving input order."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def _run(self, tasks: QueriesDraft):
        # Instantiate each coroutine, then await them all concurrently;
        # gather() returns results in the same order the coroutines were given.
        coros = []
        for func, f_args, f_kwargs in tasks:
            coros.append(func(*f_args, **f_kwargs))
        return await asyncio.gather(*coros)
| 68 | + | |
| 69 | + | |
class AsyncioProgressbarExecutor(AsyncExecutor):
    """Awaits all tasks concurrently while showing a tqdm progressbar.

    NOTE: results are collected in completion order, not input order.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def _run(self, tasks: QueriesDraft):
        coros = [func(*f_args, **f_kwargs) for func, f_args, f_kwargs in tasks]
        # as_completed yields awaitables as each one finishes, updating the bar.
        return [await finished
                for finished in tqdm.asyncio.tqdm.as_completed(coros)]
| 80 | + | |
| 81 | + | |
class AsyncioProgressbarSemaphoreExecutor(AsyncExecutor):
    """Progressbar executor with a semaphore capping concurrent tasks."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # At most `in_parallel` queries may run at the same time (default 1).
        self.semaphore = asyncio.Semaphore(kwargs.get('in_parallel', 1))

    async def _run(self, tasks: QueriesDraft):
        async def _throttled(query: QueryDraft):
            # Hold a semaphore slot for the duration of a single query.
            async with self.semaphore:
                func, f_args, f_kwargs = query
                return await func(*f_args, **f_kwargs)

        wrapped = [_throttled(task) for task in tasks]
        results = []
        # Results arrive in completion order while the bar advances.
        for finished in tqdm.asyncio.tqdm.as_completed(wrapped):
            results.append(await finished)
        return results
| 101 | + | |
| 102 | + | |
class AsyncioProgressbarQueueExecutor(AsyncExecutor):
    """Executor with a fixed pool of queue-fed workers and a progressbar.

    `in_parallel` sets the worker count (default 10); `progress_func`
    lets callers substitute another tqdm-compatible progressbar factory.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.workers_count = kwargs.get('in_parallel', 10)
        self.progress_func = kwargs.get('progress_func', tqdm.tqdm)
        # Bounded queue applies backpressure to the producer loop in _run().
        self.queue = asyncio.Queue(self.workers_count)

    async def worker(self):
        """Consume and execute tasks from the queue until cancelled."""
        while True:
            f, args, kwargs = await self.queue.get()
            try:
                result = await f(*args, **kwargs)
                self.results.append(result)
            except Exception as e:
                # A failing task must not kill the worker: previously an
                # exception skipped task_done(), so queue.join() in _run()
                # hung forever. Log it and keep consuming.
                self.logger.error(e, exc_info=True)
            finally:
                self.progress.update(1)
                self.queue.task_done()

    async def _run(self, tasks: QueriesDraft):
        self.results = []
        workers = [asyncio.create_task(self.worker())
                   for _ in range(self.workers_count)]
        task_list = list(tasks)
        self.progress = self.progress_func(total=len(task_list))
        for t in task_list:
            await self.queue.put(t)
        await self.queue.join()
        for w in workers:
            w.cancel()
        # Await the cancelled workers so no task object is left pending
        # (avoids "Task was destroyed but it is pending!" warnings).
        await asyncio.gather(*workers, return_exceptions=True)
        self.progress.close()
        return self.results
39 | 131 | | |
40 | 132 | | |
41 | 133 | | async def get_response(request_future, site_name, logger): |
| skipped 45 lines |
87 | 179 | | return html_text, status_code, error_text, expection_text |
88 | 180 | | |
89 | 181 | | |
async def update_site_dict_from_response(sitename, site_dict, results_info, logger, query_notify):
    """Await the pending request for `sitename` and process its response.

    Returns a (sitename, processed_result) tuple, or None when the site
    has no request future.
    """
    site_obj = site_dict[sitename]
    if not site_obj.request_future:
        # ignore: search by incompatible id type
        return None

    response = await get_response(
        request_future=site_obj.request_future,
        site_name=sitename,
        logger=logger,
    )

    return sitename, process_site_result(response, query_notify, logger, results_info, site_obj)
103 | 194 | | |
104 | 195 | | |
105 | 196 | | # TODO: move to separate class |
| skipped 348 lines |
454 | 545 | | # Add this site's results into final dictionary with all of the other results. |
455 | 546 | | results_total[site_name] = results_site |
456 | 547 | | |
457 | | - | # TODO: move into top-level function |
458 | | - | |
459 | | - | sem = asyncio.Semaphore(max_connections) |
460 | | - | |
461 | | - | tasks = [] |
| 548 | + | coroutines = [] |
462 | 549 | | for sitename, result_obj in results_total.items(): |
463 | | - | update_site_coro = update_site_dict_from_response(sitename, site_dict, result_obj, sem, logger, query_notify) |
464 | | - | future = asyncio.ensure_future(update_site_coro) |
465 | | - | tasks.append(future) |
| 550 | + | coroutines.append((update_site_dict_from_response, [sitename, site_dict, result_obj, logger, query_notify], {})) |
466 | 551 | | |
467 | 552 | | if no_progressbar: |
468 | | - | await asyncio.gather(*tasks) |
| 553 | + | executor = AsyncioSimpleExecutor(logger=logger) |
469 | 554 | | else: |
470 | | - | for f in tqdm.asyncio.tqdm.as_completed(tasks, timeout=timeout): |
471 | | - | try: |
472 | | - | await f |
473 | | - | except asyncio.exceptions.TimeoutError: |
474 | | - | # TODO: write timeout to results |
475 | | - | pass |
| 555 | + | executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=max_connections, timeout=timeout+0.5) |
| 556 | + | |
| 557 | + | results = await executor.run(coroutines) |
476 | 558 | | |
477 | 559 | | await session.close() |
478 | 560 | | |
479 | 561 | | # Notify caller that all queries are finished. |
480 | 562 | | query_notify.finish() |
481 | 563 | | |
482 | | - | return results_total |
| 564 | + | data = {} |
| 565 | + | for result in results: |
| 566 | + | # TODO: still can be empty |
| 567 | + | if result: |
| 568 | + | try: |
| 569 | + | data[result[0]] = result[1] |
| 570 | + | except Exception as e: |
| 571 | + | logger.error(e, exc_info=True) |
| 572 | + | logger.info(result) |
| 573 | + | |
| 574 | + | return data |
483 | 575 | | |
484 | 576 | | |
485 | 577 | | def timeout_check(value): |
| skipped 130 lines |