Jit/bitvector.cpp (279 lines of code) (raw):

// Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com) #include "Jit/bitvector.h" #include <algorithm> #include <cstdint> namespace jit { namespace util { BitVector::~BitVector() { if (!IsShortVector()) { delete bits_.bit_vec; } } BitVector::BitVector(size_t nb) : num_bits_(nb) { if (IsShortVector()) { bits_.bits = 0; } else { size_t size = num_bits_ / PTR_WIDTH + (num_bits_ % PTR_WIDTH == 0 ? 0 : 1); bits_.bit_vec = new std::vector<uint64_t>(size, 0); } } BitVector& BitVector::operator=(const BitVector& bv) { if (this == &bv) { return *this; } bool lhs_short = IsShortVector(); bool rhs_short = bv.IsShortVector(); num_bits_ = bv.num_bits_; if (lhs_short && rhs_short) { bits_.bits = bv.bits_.bits; } else if (lhs_short && !rhs_short) { bits_.bit_vec = new std::vector<uint64_t>(*bv.bits_.bit_vec); } else if (!lhs_short && rhs_short) { delete bits_.bit_vec; bits_.bits = bv.bits_.bits; } else { // if (!lhs_short && !rhs_short) *bits_.bit_vec = *bv.bits_.bit_vec; } return *this; } BitVector& BitVector::operator=(BitVector&& bv) { if (this == &bv) { return *this; } if (!IsShortVector()) { delete bits_.bit_vec; } num_bits_ = bv.num_bits_; bits_ = bv.bits_; bv.num_bits_ = 0; return *this; } bool BitVector::operator==(const BitVector& rhs) const { JIT_CHECK(num_bits_ == rhs.num_bits_, "LHS and RHS are of different widths."); if (IsShortVector()) { return bits_.bits == rhs.bits_.bits; } return std::equal( bits_.bit_vec->begin(), bits_.bit_vec->end(), rhs.bits_.bit_vec->begin()); } template <typename Op> BitVector BitVector::BinaryOp(const BitVector& rhs, const Op& op) const { JIT_CHECK(num_bits_ == rhs.num_bits_, "LHS and RHS are of different widths."); if (IsShortVector()) { return BitVector(num_bits_, op(bits_.bits, rhs.bits_.bits)); } BitVector bv; bv.num_bits_ = num_bits_; bv.bits_.bit_vec = new std::vector<uint64_t>; bv.bits_.bit_vec->reserve(bits_.bit_vec->size()); std::transform( bits_.bit_vec->begin(), bits_.bit_vec->end(), 
rhs.bits_.bit_vec->begin(), std::back_inserter(*bv.bits_.bit_vec), [op](uint64_t a, uint64_t b) -> uint64_t { return op(a, b); }); return bv; } BitVector BitVector::operator&(const BitVector& rhs) const { return BinaryOp( rhs, [](uint64_t a, uint64_t b) -> uint64_t { return a & b; }); } BitVector BitVector::operator|(const BitVector& rhs) const { return BinaryOp( rhs, [](uint64_t a, uint64_t b) -> uint64_t { return a | b; }); } BitVector BitVector::operator-(const BitVector& rhs) const { return BinaryOp( rhs, [](uint64_t a, uint64_t b) -> uint64_t { return a & ~b; }); } template <typename Op> BitVector& BitVector::BinaryOpAssign(const BitVector& rhs, const Op& op) { JIT_CHECK(num_bits_ == rhs.num_bits_, "LHS and RHS are of different widths."); if (IsShortVector()) { bits_.bits = op(bits_.bits, rhs.bits_.bits); } else { std::transform( bits_.bit_vec->begin(), bits_.bit_vec->end(), rhs.bits_.bit_vec->begin(), bits_.bit_vec->begin(), [op](uint64_t a, uint64_t b) -> uint64_t { return op(a, b); }); } return *this; } BitVector& BitVector::operator&=(const BitVector& rhs) { return BinaryOpAssign( rhs, [](uint64_t a, uint64_t b) -> uint64_t { return a & b; }); } BitVector& BitVector::operator|=(const BitVector& rhs) { return BinaryOpAssign( rhs, [](uint64_t a, uint64_t b) -> uint64_t { return a | b; }); } BitVector& BitVector::operator-=(const BitVector& rhs) { return BinaryOpAssign( rhs, [](uint64_t a, uint64_t b) -> uint64_t { return a & ~b; }); } void BitVector::ResetAll() { if (IsShortVector()) { bits_.bits = 0; } else { for (auto& v : *(bits_.bit_vec)) { v = 0; } } } void BitVector::fill(bool v) { if (!v) { return ResetAll(); } if (IsShortVector()) { if (num_bits_ == PTR_WIDTH) { bits_.bits = -1; } else { bits_.bits = (uintptr_t{1} << num_bits_) - 1; } } else { auto& vec = *bits_.bit_vec; for (size_t i = 0; i < vec.size() - 1; ++i) { vec[i] = -1; } auto remainder = num_bits_ % PTR_WIDTH; if (remainder == 0) { vec.back() = -1; } else { vec.back() = (uintptr_t{1} << 
remainder) - 1; } } } void BitVector::SetBit(size_t bit, bool v) { JIT_CHECK(bit < num_bits_, "bit is too large."); if (IsShortVector()) { auto b = uintptr_t(1) << bit; bits_.bits = v ? (bits_.bits | b) : (bits_.bits & ~b); } else { size_t index = bit / PTR_WIDTH; size_t offset = bit % PTR_WIDTH; auto& val = bits_.bit_vec->at(index); uintptr_t b = uintptr_t(1) << offset; val = v ? (val | b) : (val & ~b); } } size_t BitVector::AddBits(size_t i) { auto new_num_bits = num_bits_ + i; SetBitWidth(new_num_bits); return new_num_bits; } void BitVector::SetBitWidth(size_t size) { if (num_bits_ == size) { return; } bool old_short = IsShortVector(); auto new_num_bits = size; num_bits_ = new_num_bits; bool new_short = IsShortVector(); if (old_short && !new_short) { size_t size = num_bits_ / PTR_WIDTH + (num_bits_ % PTR_WIDTH == 0 ? 0 : 1); auto old_bits = bits_.bits; bits_.bit_vec = new std::vector<uint64_t>(size); bits_.bit_vec->at(0) = old_bits; } else if (!old_short && !new_short) { size_t size = num_bits_ / PTR_WIDTH + (num_bits_ % PTR_WIDTH == 0 ? 0 : 1); bits_.bit_vec->resize(size); } else if (!old_short && new_short) { auto low_bits = bits_.bit_vec->at(0); delete bits_.bit_vec; bits_.bits = low_bits; } // need to clear the unused upper bits // could use BZHI instruction, but this function is not frequently called, // so it is okay. 
auto high_mask = (uint64_t(1) << (num_bits_ % PTR_WIDTH)) - 1; if (new_short) { bits_.bits &= high_mask; } else { auto& chunk = *bits_.bit_vec->rbegin(); chunk &= high_mask; } } bool BitVector::GetBit(size_t bit) const { JIT_CHECK(bit < num_bits_, "bit is out of range."); if (IsShortVector()) { auto b = uintptr_t(1) << bit; return bits_.bits & b; } size_t index = bit / PTR_WIDTH; size_t offset = bit % PTR_WIDTH; return bits_.bit_vec->at(index) & (uintptr_t(1) << offset); } void BitVector::forEachSetBit(std::function<void(size_t)> per_bit_func) const { auto forEachBitInChunk = [&](uint64_t chunk, size_t base) { while (chunk) { int bit = __builtin_ctzl(chunk); chunk ^= chunk & -chunk; per_bit_func(bit + base); } }; if (IsShortVector()) { forEachBitInChunk(bits_.bits, 0); } else { size_t chunk_base = 0; for (uint64_t chunk : *bits_.bit_vec) { forEachBitInChunk(chunk, chunk_base); chunk_base += PTR_WIDTH; } } } uint64_t BitVector::GetBitChunk(size_t chunk) const { if (IsShortVector()) { JIT_CHECK(chunk == 0, "chunk is out of range."); return bits_.bits; } JIT_CHECK(chunk < bits_.bit_vec->size(), "chunk is out of range."); return bits_.bit_vec->at(chunk); } void BitVector::SetBitChunk(size_t chunk, uint64_t bits) { auto num_chunks = (num_bits_ + PTR_WIDTH - 1) / PTR_WIDTH; JIT_CHECK(chunk < num_chunks, "chunk is out of range"); if (chunk == num_chunks - 1) { auto remainder = num_bits_ % PTR_WIDTH; if (remainder != 0) { auto mask = ~((uint64_t{1} << remainder) - 1); JIT_CHECK((mask & bits) == 0, "invalid bit chunk"); } } if (IsShortVector()) { bits_.bits = bits; return; } (*bits_.bit_vec)[chunk] = bits; } size_t BitVector::GetPopCount() const { if (IsShortVector()) { return __builtin_popcountll(bits_.bits); } size_t count = 0; for (auto& b : *bits_.bit_vec) { count += __builtin_popcountll(b); } return count; } bool BitVector::IsEmpty() const { if (IsShortVector()) { return bits_.bits == 0; } for (auto& b : *bits_.bit_vec) { if (b != 0) { return false; } } return true; } 
// Streams `bv` as a bracketed string of '0'/'1' characters, starting with
// bit 0, with a ';' inserted between every group of eight bits.
std::ostream& operator<<(std::ostream& os, const BitVector& bv) {
  os << '[';

  const std::size_t width = bv.GetNumBits();
  for (std::size_t pos = 0; pos != width; ++pos) {
    // A separator goes in front of every eighth bit except the very first.
    const bool starts_new_group = (pos != 0) && (pos % 8 == 0);
    if (starts_new_group) {
      os << ';';
    }

    const char digit = bv.GetBit(pos) ? '1' : '0';
    os << digit;
  }

  return os << ']';
}

} // namespace util
} // namespace jit