#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
"""Functional tools for creating and using iterators."""
# TODO(T42113424) Replace stubs with an actual implementation

import operator
from builtins import _number_check

from _builtins import (
    _int_check,
    _int_guard,
    _list_len,
    _list_new,
    _tuple_len,
    _Unbound,
    _unimplemented,
)


class accumulate:
    """Iterator yielding running totals of *iterable*.

    The first value yielded is ``initial`` when it is not None, otherwise the
    first item of ``iterable``.  Every later value is
    ``func(previous_total, next_item)``; ``func`` defaults to
    ``operator.add``.
    """

    def __iter__(self):
        return self

    def __new__(cls, iterable, func=None, initial=None):
        result = object.__new__(cls)
        result._it = iter(iterable)
        result._func = operator.add if func is None else func
        result._initial = initial
        result._accumulated = None
        # Explicit "a total exists" flag.  None cannot serve as the sentinel
        # because it is a perfectly valid total (func may return None, or
        # the first item may be None).
        result._has_total = False
        return result

    def __next__(self):
        initial = self._initial
        if initial is not None:
            # Emit the initial value exactly once and seed the total with it.
            self._initial = None
            self._accumulated = initial
            self._has_total = True
            return initial

        # Propagates StopIteration once the source is exhausted.
        value = next(self._it)
        if self._has_total:
            value = self._func(self._accumulated, value)
        self._accumulated = value
        self._has_total = True
        return value

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()


class chain:
    """Iterator that yields every item of each iterable in turn.

    ``chain(a, b, ...)`` takes the iterables as arguments, while
    ``chain.from_iterable(x)`` takes a single iterable that produces them
    lazily.
    """

    def __new__(cls, *iterables):
        self = object.__new__(cls)
        self._it = None
        self._iterables = iter(iterables)
        return self

    @classmethod
    def from_iterable(cls, iterable):
        self = object.__new__(cls)
        self._it = None
        self._iterables = iter(iterable)
        return self

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            active = self._it
            if active is None:
                # A StopIteration from the outer iterator ends the chain.
                active = self._it = iter(next(self._iterables))
            try:
                return next(active)
            except StopIteration:
                # Current iterable is drained; move on to the next one.
                self._it = None

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()


class combinations:
    """Iterator over the r-length combinations of the items of *iterable*.

    Each result is a tuple whose elements are taken from strictly increasing
    positions of the input, and results are produced in lexicographic order
    of those positions.  Raises ValueError when ``r`` is negative.
    """

    def __new__(cls, iterable, r):
        _int_guard(r)
        if r < 0:
            raise ValueError("r must be non-negative")

        self = object.__new__(cls)
        pool = tuple(iterable)
        size = _tuple_len(pool)

        if r > size:
            # More slots than items: nothing to yield, so start exhausted.
            self._seq = None
            return self

        self._seq = pool
        self._r = r
        self._indices = list(range(r))
        # Slot i may hold at most position i + (size - r).
        self._index_delta = size - r
        return self

    def __iter__(self):
        return self

    def __next__(self):
        pool = self._seq
        if pool is None:
            raise StopIteration

        width = self._r
        positions = self._indices
        slack = self._index_delta

        # Snapshot the combination selected by the current positions before
        # advancing them for the following call.
        current = tuple([pool[positions[k]] for k in range(width)])

        # Walk from the rightmost slot to the left looking for one that can
        # still be moved forward.
        for slot in range(width - 1, -1, -1):
            if positions[slot] < slot + slack:
                positions[slot] += 1
                # Everything to the right restarts at its smallest legal
                # value, i.e. one past its left neighbour.
                for follower in range(slot + 1, width):
                    positions[follower] = positions[follower - 1] + 1
                break
        else:
            # Every slot is at its maximum: this was the last combination.
            self._seq = None

        return current

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    def __sizeof__(self):
        _unimplemented()


class combinations_with_replacement:
    """Iterator over r-length combinations that may repeat an element.

    Each result is a tuple whose elements are taken from non-decreasing
    positions of the input.  Raises ValueError when ``r`` is negative.
    """

    def __new__(cls, iterable, r):
        _int_guard(r)
        if r < 0:
            raise ValueError("r must be non-negative")

        self = object.__new__(cls)
        pool = tuple(iterable)

        # No combination can be built from an empty pool when r > 0.
        if r and not pool:
            self._seq = None
            return self

        self._seq = pool
        self._r = r
        self._indices = _list_new(r, 0)
        self._max_index = _tuple_len(pool) - 1
        return self

    def __iter__(self):
        return self

    def __next__(self):
        pool = self._seq
        if pool is None:
            raise StopIteration

        width = self._r
        positions = self._indices
        last = self._max_index

        # Snapshot the combination selected by the current positions before
        # advancing them for the following call.
        current = tuple([pool[positions[k]] for k in range(width)])

        # Walk from the rightmost slot to the left looking for one that is
        # not yet at the last position.
        for slot in range(width - 1, -1, -1):
            if positions[slot] < last:
                # Bump that slot and copy the new value into every slot to
                # its right.
                bumped = positions[slot] + 1
                for k in range(slot, width):
                    positions[k] = bumped
                break
        else:
            # Every slot is at the last position: nothing left to yield.
            self._seq = None

        return current

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    def __sizeof__(self):
        _unimplemented()


class compress:
    """Iterator yielding the items of *data* whose paired selector is truthy.

    Stops as soon as either ``data`` or ``selectors`` is exhausted.
    """

    def __new__(cls, data, selectors):
        self = object.__new__(cls)
        self._data = iter(data)
        self._selectors = iter(selectors)
        return self

    def __iter__(self):
        return self

    def __next__(self):
        data, selectors = self._data, self._selectors
        while True:
            # The data item is always pulled before its selector.
            candidate = next(data)
            if next(selectors):
                return candidate

    def __reduce__(self):
        _unimplemented()


class count:
    """Endless iterator yielding ``start``, ``start + step``, ...

    The current value and the increment are exposed as the ``count`` and
    ``step`` attributes.
    """

    def __iter__(self):
        return self

    def __new__(cls, start=0, step=1):
        # Validate both arguments up front so that a bad step fails here
        # rather than on the first increment.
        # NOTE(review): _number_check is assumed to return a truthy value for
        # numeric objects -- confirm against its definition.
        if not _number_check(start) or not _number_check(step):
            raise TypeError("a number is required")

        result = object.__new__(cls)
        result.count = start
        result.step = step
        return result

    def __next__(self):
        result = self.count
        self.count += self.step
        return result

    def __reduce__(self):
        _unimplemented()

    def __repr__(self):
        name = type(self).__name__
        step = self.step
        # The step is only omitted when it is the integer 1, so the repr
        # round-trips for any other increment (e.g. 2 or 1.0).
        if isinstance(step, int) and step == 1:
            return f"{name}({self.count!r})"
        return f"{name}({self.count!r}, {step!r})"


class cycle:
    """Iterator that replays the items of *seq* forever.

    Items are recorded during the first pass over ``seq`` and replayed from
    that record afterwards.  An empty ``seq`` produces an empty iterator.
    """

    def __new__(cls, seq):
        self = object.__new__(cls)
        self._seq = iter(seq)
        self._saved = []
        self._first_pass = True
        return self

    def __iter__(self):
        return self

    def __next__(self):
        try:
            item = next(self._seq)
        except StopIteration:
            # Current pass is finished: restart from the recorded items.  If
            # nothing was recorded this raises StopIteration to the caller.
            self._first_pass = False
            self._seq = iter(self._saved)
            return next(self._seq)
        if self._first_pass:
            self._saved.append(item)
        return item

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()


class dropwhile:
    """Iterator that skips leading items while *predicate* is true.

    The first item for which the predicate is false, and every item after
    it, is yielded unchanged; the predicate is not consulted again.
    """

    def __new__(cls, predicate, iterable):
        self = object.__new__(cls)
        self._it = iter(iterable)
        self._func = predicate
        self._start = False
        return self

    def __iter__(self):
        return self

    def __next__(self):
        if not self._start:
            rejects = self._func
            # Discard items until the predicate first fails.
            candidate = next(self._it)
            while rejects(candidate):
                candidate = next(self._it)
            self._start = True
            return candidate
        return next(self._it)

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()


class filterfalse:
    """Iterator yielding the items for which *predicate* is false.

    With ``predicate=None`` the items themselves are tested, so only falsy
    items are yielded.
    """

    def __new__(cls, predicate, iterable):
        self = object.__new__(cls)
        self._it = iter(iterable)
        self._predicate = predicate if predicate is not None else bool
        return self

    def __iter__(self):
        return self

    def __next__(self):
        # Pull items until one fails the predicate.
        candidate = next(self._it)
        while self._predicate(candidate):
            candidate = next(self._it)
        return candidate

    def __reduce__(self):
        _unimplemented()


# internal helper class for groupby
class _groupby_iterator:
    """Iterator over the values of a single group produced by ``groupby``.

    It keeps no buffer of its own: values are read through the parent
    ``groupby`` object, whose ``_currkey`` / ``_currval`` fields it advances.
    """

    def __iter__(self):
        return self

    def __new__(cls, parent, cur):
        # parent: the groupby instance that created this group.
        # cur: the key of the group this iterator belongs to.

        obj = object.__new__(cls)
        obj._parent = parent
        obj._currkey = cur
        return obj

    def __next__(self):
        parent = self._parent
        # Only yield while the parent is still positioned on this group's
        # key.  The comparison is by key equality only.
        if parent._currkey == self._currkey:
            val = parent._currval
            try:
                # Pre-fetch the following item and its key so the next call
                # (here or on the parent) can tell whether the group ended.
                parent._currval = next(parent._it)
                parent._currkey = (
                    parent._currval
                    if parent._keyfunc is None
                    else parent._keyfunc(parent._currval)
                )
            except StopIteration:
                # Source exhausted: mark the parent as finished.
                # NOTE: a StopIteration raised by the key function is caught
                # here as well.
                parent._currkey = _Unbound
            return val
        raise StopIteration


class groupby:
    """Iterator yielding ``(key, group_iterator)`` for runs of equal keys.

    ``key`` is a function computing the grouping key of each item; when it is
    None the item itself is the key.  The group iterators share this object's
    underlying iterator, so advancing to the next group skips whatever is
    left of the previous one.
    """

    def __iter__(self):
        return self

    def __new__(cls, iterable, key=None):

        obj = object.__new__(cls)
        obj._it = iter(iterable)
        # _tgtkey: key of the group most recently handed out.
        # _currkey / _currval: key and value of the item read most recently.
        # All three start as the _Unbound sentinel.
        obj._tgtkey = obj._currkey = obj._currval = _Unbound
        obj._keyfunc = key
        return obj

    def __next__(self):
        # In middle of previous iterator
        # Read items until the key differs from the last group handed out;
        # on the first call both keys are _Unbound, so this reads the first
        # item.  StopIteration from the source propagates to the caller.
        while self._currkey == self._tgtkey:
            self._currval = next(self._it)
            self._currkey = (
                self._currval if self._keyfunc is None else self._keyfunc(self._currval)
            )
        # A group iterator resets _currkey to _Unbound when it hits the end
        # of the source.
        if self._currkey is _Unbound:
            raise StopIteration
        # remember group of returned iterator
        self._tgtkey = self._currkey
        return self._currkey, _groupby_iterator(self, self._currkey)


class islice:
    """Iterator over selected items of *seq*, like slicing but lazy.

    Called as ``islice(seq, stop)`` or ``islice(seq, start, stop[, step])``.
    ``start`` and ``step`` may be None (meaning 0 and 1) and ``stop`` may be
    None (meaning no upper bound).  Raises ValueError for a non-integer or
    out-of-range ``start``, ``stop`` or ``step``.
    """

    def __new__(cls, seq, stop_or_start, stop=_Unbound, step=_Unbound):
        self = object.__new__(cls)
        self._it = iter(seq)
        self._count = 0

        if stop is not _Unbound:
            # Multi-argument form: the second argument is the start.
            start = stop_or_start if stop_or_start is not None else 0
            if step is _Unbound or step is None:
                step = 1
            elif not _int_check(step) or step < 1:
                raise ValueError(
                    "Step for islice() must be a positive integer or None."
                )
        else:
            # Two-argument form: the second argument is the stop.
            start, stop, step = 0, stop_or_start, 1

        if stop is None:
            # -1 is the internal marker for "no upper bound", which is why a
            # caller-supplied -1 is rejected just below.
            stop = -1
        elif not _int_check(stop) or stop == -1:
            raise ValueError(
                "Stop argument for islice() must be None or an "
                "integer: 0 <= x <= sys.maxsize."
            )

        if not _int_check(start) or start < 0 or stop < -1:
            raise ValueError(
                "Indices for islice() must be None or an integer: "
                "0 <= x <= sys.maxsize."
            )

        self._next = start
        self._stop = stop
        self._step = step
        return self

    def __iter__(self):
        return self

    def __next__(self):
        source = self._it
        if source is None:
            raise StopIteration

        position = self._count
        target = self._next

        # Discard items until the source is positioned at the next index
        # that should be yielded.
        while position < target:
            try:
                next(source)
            except Exception as exc:
                # Drop the source so later calls stop immediately.
                self._it = None
                raise exc
            position += 1

        limit = self._stop
        unbounded = limit == -1
        if not unbounded and position >= limit:
            self._it = None
            raise StopIteration

        try:
            value = next(source)
        except Exception as exc:
            self._it = None
            raise exc

        self._count = position + 1
        # Never schedule an index past the stop, so the skip loop above does
        # not consume more of the source than the slice covers.
        upcoming = target + self._step
        if not unbounded and upcoming > limit:
            upcoming = limit
        self._next = upcoming
        return value

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()


class permutations:
    """Iterator over the r-length permutations of the items of *iterable*.

    ``r`` defaults to the number of items.  Nothing is yielded when ``r``
    exceeds the number of items.  Raises ValueError when ``r`` is negative.
    """

    def __iter__(self):
        return self

    def __new__(cls, iterable, r=None):
        seq = tuple(iterable)
        n = _tuple_len(seq)

        if r is None:
            r = n
        else:
            # Same validation as combinations().  Without it a negative r
            # silently yielded a single empty tuple.
            # NOTE(review): _int_guard presumably raises TypeError for a
            # non-int argument -- confirm against its definition.
            _int_guard(r)
            if r < 0:
                raise ValueError("r must be non-negative")

        result = object.__new__(cls)

        if r > n:
            # Longer than the pool: start exhausted.
            result._seq = None
            return result

        result._seq = seq
        result._r = r
        result._indices = list(range(n))
        # cycles[i] counts down the choices still available for slot i.
        result._cycles = list(range(n, n - r, -1))
        return result

    def __next__(self):
        seq = self._seq
        if seq is None:
            raise StopIteration
        r = self._r
        indices = self._indices
        indices_len = _list_len(indices)
        # Snapshot the current permutation, then advance the state for the
        # following call.
        result = (*(seq[indices[i]] for i in range(r)),)
        cycles = self._cycles
        i = r - 1
        while i >= 0:
            j = cycles[i] - 1
            if j > 0:
                # Slot i still has choices left: swap in the next one.
                cycles[i] = j
                indices[i], indices[-j] = indices[-j], indices[i]
                break
            # Slot i is used up: reset its counter and rotate indices[i:]
            # left by one, then carry on to the slot on its left.
            cycles[i] = indices_len - i
            tmp = indices[i]
            k = i + 1
            while k < indices_len:
                indices[k - 1] = indices[k]
                k += 1
            indices[k - 1] = tmp
            i -= 1
        else:
            # Every slot wrapped around: that was the last permutation.
            self._seq = None
        return result

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    def __sizeof__(self):
        _unimplemented()


class product:
    """Iterator over the Cartesian product of the input iterables.

    ``repeat`` multiplies the sequence of inputs, so
    ``product(a, repeat=2)`` behaves like ``product(a, a)``.  Results are
    tuples, with the rightmost position varying fastest.  Raises TypeError
    when ``repeat`` is not an int and ValueError when it is negative.
    """

    def __iter__(self):
        return self

    def __new__(cls, *iterables, repeat=1):
        if not _int_check(repeat):
            raise TypeError
        if repeat < 0:
            # A negative count used to fall through to the list arithmetic
            # below and produce a bogus state instead of an error.
            raise ValueError("repeat argument cannot be negative")
        # With repeat == 0 the inputs are ignored entirely.
        length = _tuple_len(iterables) if repeat else 0
        i = 0
        repeated = _list_new(length)
        result = object.__new__(cls)
        while i < length:
            item = tuple(iterables[i])
            if not item:
                # An empty input makes the whole product empty.
                result._iterables = None
                return result
            repeated[i] = item
            i += 1
        repeated *= repeat
        result._iterables = repeated
        # One "digit" per position: the index currently selected in that
        # position's pool.
        result._digits = _list_new(length * repeat, 0)
        return result

    def __next__(self):
        iterables = self._iterables
        if iterables is None:
            raise StopIteration
        digits = self._digits
        length = _list_len(iterables)
        result = _list_new(length)
        i = length - 1
        carry = 1
        # Single right-to-left pass that both reads the current tuple and
        # increments the digits like an odometer.
        while i >= 0:
            j = digits[i]
            result[i] = iterables[i][j]
            j += carry
            if j < _tuple_len(iterables[i]):
                carry = 0
                digits[i] = j
            else:
                carry = 1
                digits[i] = 0
            i -= 1
        if carry:
            # counter overflowed, stop iteration
            self._iterables = None
        return tuple(result)

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()

    def __sizeof__(self):
        _unimplemented()


class repeat:
    """Iterator yielding *elem* over and over.

    With ``times=None`` (the default) it never ends; otherwise it yields
    ``elem`` ``times`` times.  A negative ``times`` is treated as 0.
    """

    def __iter__(self):
        return self

    def __new__(cls, elem, times=None):
        result = object.__new__(cls)
        result._elem = elem
        if times is not None:
            _int_guard(times)
            if times < 0:
                # Clamp so the remaining count reported by repr() and
                # __length_hint__() is never negative.
                times = 0
        result._times = times
        return result

    def __next__(self):
        if self._times is None:
            return self._elem
        if self._times > 0:
            self._times -= 1
            return self._elem
        raise StopIteration

    def __length_hint__(self):
        """Return the number of items left; TypeError when endless."""
        times = self._times
        if times is None:
            raise TypeError("len() of unsized object")
        return times if times > 0 else 0

    def __reduce__(self):
        _unimplemented()

    def __repr__(self):
        name = type(self).__name__
        times = self._times
        if times is None:
            return f"{name}({self._elem!r})"
        # Show the remaining count, never a negative one.
        return f"{name}({self._elem!r}, {times if times > 0 else 0})"


class starmap:
    """Iterator calling *function* with arguments unpacked from each item.

    Every item of ``iterable`` must itself be iterable; it is passed to the
    function as ``function(*item)``.
    """

    def __new__(cls, function, iterable):
        self = object.__new__(cls)
        self._func = function
        self._it = iter(iterable)
        return self

    def __iter__(self):
        return self

    def __next__(self):
        packed = next(self._it)
        return self._func(*packed)

    def __reduce__(self):
        _unimplemented()


def tee(iterable, n=2):
    """Return a tuple of *n* independent iterators over *iterable*.

    Raises ValueError when ``n`` is negative.  For ``n == 0`` an empty tuple
    is returned without touching ``iterable``.
    """
    _int_guard(n)
    if n < 0:
        raise ValueError("n must be >= 0")
    if n == 0:
        return ()

    source = iter(iterable)
    # An iterator that already knows how to copy itself is used as-is;
    # anything else is wrapped in a _tee, which supports __copy__.
    if hasattr(source, "__copy__"):
        first = source
    else:
        first = _tee.from_iterable(source)
    make_copy = first.__copy__
    return (first, *(make_copy() for _ in range(n - 1)))


# Internal cache for tee, a linked list where each link is a cached window to
# a section of the source iterator
class _tee_dataobject:
    """One link of the tee cache: a window of values read from the source.

    Links form a singly linked list; each holds up to ``_MAX_VALUES`` values
    read from the shared source iterator.
    """

    # CPython uses 57 here so a link lines up with its cache line size.  We
    # use 55 to line up with cache lines in our system: arrays of <=255
    # elements carry 1 word of header, and the header and every element take
    # 8 bytes on a 64-bit machine.  Cache lines are 64 bytes on all x86
    # machines (though usually fetched in pairs), so any multiple of 8 minus
    # 1 up to 255 works.
    _MAX_VALUES = 55

    def __init__(self, it):
        self._it = it
        self._values = []
        self._num_read = 0
        self._next_link = _Unbound

    def get_item(self, i):
        """Return value *i* of this link, reading it from the source if new."""
        assert i < self.__class__._MAX_VALUES

        if i >= self._num_read:
            # Values are only ever read in order, one past the cached ones.
            assert i == self._num_read
            fresh = next(self._it)
            self._num_read += 1
            # mutable tuple might be a nice future optimization here
            self._values.append(fresh)
            return fresh
        return self._values[i]

    def next_link(self):
        """Return the following link, creating it on first use."""
        link = self._next_link
        if link is _Unbound:
            link = self._next_link = self.__class__(self._it)
        return link


class _tee:
    """Iterator reading through a shared chain of ``_tee_dataobject`` links.

    Each instance tracks its own position (a link plus an index into it), so
    copies advance independently while sharing the cached values.
    """

    def __init__(self, data, index):
        self._data = data
        self._index = index

    @classmethod
    def from_iterable(cls, iterable):
        source = iter(iterable)
        if not isinstance(source, _tee):
            return cls(_tee_dataobject(source), 0)
        # Already a tee iterator: share its cache rather than wrapping it.
        return source.__copy__()

    def __copy__(self):
        return self.__class__(self._data, self._index)

    def __iter__(self):
        return self

    def __next__(self):
        position = self._index
        if position >= _tee_dataobject._MAX_VALUES:
            # Current link is used up: step to the following one.
            self._data = self._data.next_link()
            self._index = position = 0

        item = self._data.get_item(position)
        # Only advance once the value was fetched successfully.
        self._index = position + 1
        return item

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()


class takewhile:
    """Iterator yielding items only as long as *predicate* is true.

    Iteration ends for good at the first item that fails the predicate; that
    item is consumed from the source but not yielded.
    """

    def __new__(cls, predicate, iterable):
        self = object.__new__(cls)
        self._it = iter(iterable)
        self._func = predicate
        self._stop = False
        return self

    def __iter__(self):
        return self

    def __next__(self):
        if not self._stop:
            candidate = next(self._it)
            if self._func(candidate):
                return candidate
            # Remember the failure so later calls do not touch the source.
            self._stop = True
        raise StopIteration

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()


class zip_longest:
    """Iterator zipping the inputs together until the longest one ends.

    Positions belonging to inputs that have already ended are filled with
    ``fillvalue``.
    """

    def __new__(cls, *seqs, fillvalue=None):
        total = _tuple_len(seqs)
        self = object.__new__(cls)
        self._iters = [iter(seq) for seq in seqs]
        self._num_iters = total
        # Number of inputs that have not been exhausted yet.
        self._num_active = total
        self._fillvalue = fillvalue
        return self

    def __iter__(self):
        return self

    def __next__(self):
        if not self._num_active:
            raise StopIteration

        sources = self._iters
        filler = self._fillvalue
        # Start from a row of fill values and overwrite what is available.
        row = _list_new(self._num_iters, filler)
        for slot, source in enumerate(sources):
            try:
                row[slot] = next(source)
            except StopIteration:
                self._num_active -= 1
                if not self._num_active:
                    # The last live input just ended: the zip is done.
                    raise
                # Swap the finished input for an endless supply of fill
                # values so later rows need no special casing.
                sources[slot] = repeat(filler)
        return tuple(row)

    def __reduce__(self):
        _unimplemented()

    def __setstate__(self):
        _unimplemented()
