source/FlattenIterator.h (111 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include <iterator> #include <optional> #include <type_traits> namespace marianatrench { template <typename OuterIterator> struct FlattenDereference { using Reference = typename std::iterator_traits<OuterIterator>::reference; using InnerIterator = decltype(std::declval<Reference>().begin()); static InnerIterator begin(Reference reference) { return reference.begin(); } static InnerIterator end(Reference reference) { return reference.end(); } }; template <typename OuterIterator> struct FlattenConstDereference { using Reference = typename std::iterator_traits<OuterIterator>::reference; using InnerIterator = decltype(std::declval<Reference>().cbegin()); static InnerIterator begin(Reference reference) { return reference.cbegin(); } static InnerIterator end(Reference reference) { return reference.cend(); } }; /** * A flattening iterator that iterates on a container of containers. * * For instance, this can be used to treat a `std::vector<std::vector<T>>` as * a single list of `T`. */ template < typename OuterIterator, typename InnerIterator, typename Dereference = FlattenDereference<OuterIterator>> class FlattenIterator { public: using OuterReference = typename std::iterator_traits<OuterIterator>::reference; public: static_assert(std::is_same_v< decltype(Dereference::begin(std::declval<OuterReference>())), InnerIterator>); static_assert(std::is_same_v< decltype(Dereference::end(std::declval<OuterReference>())), InnerIterator>); public: // C++ iterator concept member types using iterator_category = std::forward_iterator_tag; using value_type = typename std::iterator_traits<InnerIterator>::value_type; using difference_type = typename std::iterator_traits<OuterIterator>::difference_type; using pointer = typename std::iterator_traits<InnerIterator>::pointer; using reference = typename std::iterator_traits<InnerIterator>::reference; public: explicit FlattenIterator(OuterIterator begin, OuterIterator end) : outer_(Range<OuterIterator>{std::move(begin), std::move(end)}), inner_(std::nullopt) { if (outer_.begin == outer_.end) { return; } inner_ = Range<InnerIterator>{ Dereference::begin(*outer_.begin), Dereference::end(*outer_.begin)}; advance_empty(); } FlattenIterator(const FlattenIterator&) = default; FlattenIterator(FlattenIterator&&) = default; FlattenIterator& operator=(const FlattenIterator&) = default; FlattenIterator& operator=(FlattenIterator&&) = default; ~FlattenIterator() = default; FlattenIterator& operator++() { ++inner_->begin; advance_empty(); return *this; } FlattenIterator operator++(int) { FlattenIterator result = *this; ++(*this); return result; } bool operator==(const FlattenIterator& other) const { return outer_.begin == other.outer_.begin && ((!inner_.has_value() && !other.inner_.has_value()) || (inner_.has_value() && other.inner_.has_value() && inner_->begin == other.inner_->begin)); } bool operator!=(const FlattenIterator& other) const { return !(*this == other); } reference operator*() { return *inner_->begin; } private: /* Advance the iterator until we find an element. */ void advance_empty() { while (inner_->begin == inner_->end) { ++outer_.begin; if (outer_.begin == outer_.end) { inner_ = std::nullopt; return; } else { inner_ = Range<InnerIterator>{ Dereference::begin(*outer_.begin), Dereference::end(*outer_.begin)}; } } } private: template <typename T> struct Range { T begin; T end; }; private: Range<OuterIterator> outer_; std::optional<Range<InnerIterator>> inner_; }; } // namespace marianatrench