| skipped 45 lines |
46 | 46 | | QueriesDraft = Iterable[QueryDraft] |
47 | 47 | | |
48 | 48 | | |
| 49 | + | def create_task_func(): |
| 50 | + | if sys.version_info.minor > 6: |
| 51 | + | create_asyncio_task = asyncio.create_task |
| 52 | + | else: |
| 53 | + | loop = asyncio.get_event_loop() |
| 54 | + | create_asyncio_task = loop.create_task |
| 55 | + | return create_asyncio_task |
| 56 | + | |
49 | 57 | | class AsyncExecutor: |
50 | 58 | | def __init__(self, *args, **kwargs): |
51 | 59 | | self.logger = kwargs['logger'] |
| skipped 57 lines |
109 | 117 | | self.workers_count = kwargs.get('in_parallel', 10) |
110 | 118 | | self.progress_func = kwargs.get('progress_func', tqdm.tqdm) |
111 | 119 | | self.queue = asyncio.Queue(self.workers_count) |
| 120 | + | self.timeout = kwargs.get('timeout') |
112 | 121 | | |
113 | 122 | | async def worker(self): |
114 | 123 | | while True: |
115 | | - | f, args, kwargs = await self.queue.get() |
116 | | - | result = await f(*args, **kwargs) |
| 124 | + | try: |
| 125 | + | f, args, kwargs = self.queue.get_nowait() |
| 126 | + | except asyncio.QueueEmpty: |
| 127 | + | return |
| 128 | + | |
| 129 | + | query_future = f(*args, **kwargs) |
| 130 | + | query_task = create_task_func()(query_future) |
| 131 | + | try: |
| 132 | + | result = await asyncio.wait_for(query_task, timeout=self.timeout) |
| 133 | + | except asyncio.TimeoutError: |
| 134 | + | result = None |
| 135 | + | |
117 | 136 | | self.results.append(result) |
118 | 137 | | self.progress.update(1) |
119 | 138 | | self.queue.task_done() |
120 | 139 | | |
121 | | - | async def _run(self, tasks: QueriesDraft): |
| 140 | + | async def _run(self, queries: QueriesDraft): |
122 | 141 | | self.results = [] |
123 | 142 | | |
124 | | - | if sys.version_info.minor > 6: |
125 | | - | create_task = asyncio.create_task |
126 | | - | else: |
127 | | - | loop = asyncio.get_event_loop() |
128 | | - | create_task = loop.create_task |
| 143 | + | queries_list = list(queries) |
| 144 | + | |
| 145 | + | min_workers = min(len(queries_list), self.workers_count) |
| 146 | + | |
| 147 | + | workers = [create_task_func()(self.worker()) |
| 148 | + | for _ in range(min_workers)] |
129 | 149 | | |
130 | | - | workers = [create_task(self.worker()) |
131 | | - | for _ in range(self.workers_count)] |
132 | | - | task_list = list(tasks) |
133 | | - | self.progress = self.progress_func(total=len(task_list)) |
134 | | - | for t in task_list: |
| 150 | + | self.progress = self.progress_func(total=len(queries_list)) |
| 151 | + | for t in queries_list: |
135 | 152 | | await self.queue.put(t) |
136 | 153 | | await self.queue.join() |
137 | 154 | | for w in workers: |
| skipped 587 lines |