import threading

import queue


class Worker(threading.Thread):
    """Daemon thread that drains a shared work queue.

    Each worker optionally calls ``init_fn`` once, then repeatedly pulls
    ``(args, kwargs)`` items from the queue and passes them to ``work_fn``
    until the queue is empty or a ``None`` sentinel is received.

    :param queue: Queue of ``(args, kwargs)`` items; a ``None`` item tells the
                  worker that receives it to stop.
    :param init_fn: Optional callable invoked once at the start of the thread.
                    Its return value (a dict, or None) supplies keyword
                    arguments that are passed to every ``work_fn`` call.
    :param work_fn: Callable invoked once per queue item.
    :param errors: Shared list that collects the exceptions raised while
                   initialising the worker or processing items."""

    def __init__(self, queue, init_fn, work_fn, errors):
        super().__init__()
        self.daemon = True
        # NOTE: the ``queue`` parameter shadows the ``queue`` module inside
        # this method only; ``run`` still sees the module.
        self.queue = queue
        self.init_fn = init_fn
        self.work_fn = work_fn
        self.errors = errors

    def run(self):
        """Process queue items until the queue is empty or a sentinel is seen.

        Exceptions raised by ``init_fn``, by unpacking an item, or by
        ``work_fn`` are appended to ``self.errors`` rather than propagated."""
        try:
            init_data = self.init_fn() if self.init_fn else {}
        except Exception as e:
            # Report the failure to the caller instead of letting the thread
            # die with an exception nobody collects. Items are left on the
            # queue for the other workers.
            self.errors.append(e)
            return

        if init_data is None:
            init_data = {}

        while True:
            try:
                task_data = self.queue.get(False)
            except queue.Empty:
                return

            # Everything after a successful get() is covered by the finally
            # clause, so each retrieved item (sentinel included) is matched by
            # exactly one task_done() call.
            try:
                if task_data is None:
                    return

                args, task_kwargs = task_data
                # Per-task kwargs take precedence over the per-thread ones.
                kwargs = init_data.copy()
                kwargs.update(task_kwargs)
                self.work_fn(*args, **kwargs)
            except Exception as e:
                self.errors.append(e)
            finally:
                self.queue.task_done()


class ThreadExecutor:
    """Run one function over a list of argument sets using several threads.

    :param thread_count: Number of worker threads to start for each run.
    :param work_fn: Callable that does the actual work; invoked once per data item.
    :param init_fn: Optional callable invoked once per thread. It may return a dict
                    of keyword arguments that are passed on to the work_fn."""

    def __init__(self, thread_count, work_fn, init_fn=None):
        self.init_fn = init_fn
        self.work_fn = work_fn
        self.thread_count = thread_count

    def run(self, data):
        """Process every item in ``data`` and return the exceptions that occurred.

        All items are queued before any worker is started; the call returns
        once every worker thread has finished.

        :param data: List of (args, kwargs) to pass to the work_fn, where args is a
                     tuple and kwargs is a dict.
        :returns: List of the exceptions collected by the workers."""
        pending = queue.Queue()
        for entry in data:
            pending.put(entry)

        failures = []
        pool = []
        for _ in range(self.thread_count):
            thread = Worker(pending, self.init_fn, self.work_fn, failures)
            pool.append(thread)
            thread.start()

        # Block until every worker has exited.
        for thread in pool:
            thread.join()

        return failures
