# parallel_executor.py
  1. import random
  2. import time
  3. from concurrent.futures import ThreadPoolExecutor, as_completed
  4. from typing import Any, Callable, List, Optional
  5. def parallel_exec(
  6. fn: Callable,
  7. list_of_kwargs: List[dict],
  8. max_workers: Optional[int] = None,
  9. jitter: float = 0.0,
  10. ) -> list:
  11. """
  12. Executes a given function `fn` in parallel, using multiple threads, on a list of argument tuples.
  13. The function limits the number of concurrent executions to `max_workers` and processes tasks in chunks,
  14. pausing between each chunk to avoid hitting rate limits or quotas.
  15. Args:
  16. - fn (Callable): The function to execute in parallel.
  17. - list_of_kwargs (list): A list of dicts, where each dict contains arguments for a single call to `fn`.
  18. - max_workers (int, optional): The maximum number of threads that can be used to execute the tasks
  19. concurrently.
  20. - jitter (float, optional): Wait for jitter * random.random() before submitting the next job.
  21. Returns:
  22. - A list containing the results of the function calls. The order of the results corresponds to the order
  23. the tasks were completed, which may not necessarily be the same as the order of `list_of_kwargs`.
  24. """
  25. results = []
  26. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  27. # Get the tasks for the current chunk
  28. futures = []
  29. for kwargs in list_of_kwargs:
  30. futures.append(executor.submit(fn, **kwargs))
  31. if jitter > 0.0:
  32. time.sleep(jitter * random.random())
  33. for future in as_completed(futures):
  34. results.append(future.result())
  35. return results
# Sequential reference implementation, kept for debugging.
  37. def serial_exec(fn: Callable, list_of_kwargs: List[dict]) -> List[Any]:
  38. results = []
  39. for kwargs in list_of_kwargs:
  40. result = fn(**kwargs)
  41. results.append(result)
  42. return results