core/common/common_math.c (221 lines of code) (raw):

// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include <stddef.h> #include <stdint.h> #include <string.h> #include "common_math.h" /** * Count the number of bits that are set (1) in a single byte. * * @param byte The byte to check. * @param stop Bit position to stop looking for set bits. If this is 8 or larger, the entire byte * will be checked. If this is negative, only contiguous bits will be counted. * * @return Number of bits that are set in the specified byte. */ static int common_math_count_set_bits (uint8_t byte, int stop) { int num_bits = 0; while ((byte != 0) && (stop != 0)) { if (byte & 0x01) { ++num_bits; } else if (stop < 0) { stop = 1; } byte >>= 1; stop--; } return num_bits; } /** * Get the number of bits that are set (1) in a single byte. Set bits do not need to be contiguous. * * @param byte Byte to check. * * @return Number of bits that are set in the specified byte. */ int common_math_get_num_bits_set (uint8_t byte) { return common_math_count_set_bits (byte, 8); } /** * Get the number of bits that are set (1) in a single byte before a specific bit position. Set * bits do not need to be contiguous. * * @param byte Byte to check. * @param index Bit position to stop looking for set bits. If this is 8 or larger, the entire byte * will be checked. * * @return Number of bits that are set in the specified byte. */ int common_math_get_num_bits_set_before_index (uint8_t byte, uint8_t index) { return common_math_count_set_bits (byte, index); } /** * Get the total number of bits that are set (1) in an array of bytes. Set bits do not need to be * contiguous. * * @param bytes Byte array to check. * @param length Length of the array. * * @return Total number of bits that are set for all bytes in the array or an error code. */ int common_math_get_num_bits_set_in_array (const uint8_t *bytes, size_t length) { size_t i; int bits = 0; if (bytes == NULL) { return COMMON_MATH_INVALID_ARGUMENT; } for (i = 0; i < length; i++) { bits += common_math_count_set_bits (bytes[i], 8); } return bits; } /** * Get the number of contiguous bits that are set (1) in a single byte, starting from bit 0. Any * bits after the first unset (0) bit will not be counted. * * @param byte Byte to check. * * @return Number of contiguous bits that are set in the specified byte. */ int common_math_get_num_contiguous_bits_set (uint8_t byte) { return common_math_count_set_bits (byte, -1); } /** * Get the total number of contiguous bits that are set (1) in an array of bytes, starting from bit * 0 in byte 0. Any bits after the first unset (0) bit will not be counted. * * @param bytes Byte array to check. * @param length Length of the array. * * @return Total number of contiguous bits that are set for all bytes in the array or an error code. */ int common_math_get_num_contiguous_bits_set_in_array (const uint8_t *bytes, size_t length) { size_t i = 0; int bits = 8; int total = 0; if (bytes == NULL) { return COMMON_MATH_INVALID_ARGUMENT; } while ((i < length) && (bits == 8)) { bits = common_math_count_set_bits (bytes[i++], -1); total += bits; } return total; } /** * Increments a byte array of arbitrary length by 1. * * @param buf Input array to be incremented. This will be treated as a big endian value. * @param len Length of the array. * @param allow_rollover Allows the array value to roll over to 0 when upper boundary is reached. * * @return 0 if the input array is incremented successfully or an error code. */ int common_math_increment_byte_array (uint8_t *buf, size_t length, bool allow_rollover) { size_t index = 0; if ((length == 0) || (buf == NULL)) { return COMMON_MATH_INVALID_ARGUMENT; } while ((index < (length - 1)) && (buf[index] == 0xff)) { buf[index++] = 0; } if ((index == (length - 1)) && (buf[index] == 0xff)) { if (allow_rollover) { buf[index] = 0; } else { memset (buf, 0xff, length); return COMMON_MATH_BOUNDARY_REACHED; } } else { buf[index]++; } return 0; } /** * Check a byte array to see if it contains all zeros. * * @param bytes The byte array to check. * @param length Length of the byte array. * * @return true if all bytes are zero, false otherwise. Empty or null arrays will return false. */ bool common_math_is_array_zero (const uint8_t *bytes, size_t length) { if ((bytes == NULL) || (length == 0)) { return false; } /* memcmp is fine here since the comparison is against a constant value and timing attacks are * not a concern. */ return ((bytes[0] == 0) && (memcmp (bytes, &bytes[1], length - 1) == 0)); } /** * Get the byte position and bit mask for a specific bit in an array of bytes. * * @param bytes The byte array. * @param length Length of the byte array. * @param bit The bit number in the array. * @param byte Output for the byte index in the array. * @param mask Output for the bit mask in the byte. * * @return 0 if the bit mask was successfully determined or an error code. */ static int common_math_get_bit_mask_in_array (const uint8_t *bytes, size_t length, size_t bit, size_t *byte, uint8_t *mask) { if (bytes == NULL) { return COMMON_MATH_INVALID_ARGUMENT; } *byte = bit / 8; *mask = 1U << (bit % 8); if (*byte >= length) { return COMMON_MATH_OUT_OF_RANGE; } return 0; } /** * Check a specific bit position in a byte array and determine if that bit is set. Bit number is * determined as bits 0-7 in byte 0, bits 8-15 in byte 1, bits 16-23 in byte 2, etc. * * @param bytes The byte array to check. * @param length Length of the byte array. * @param bit The bit number in the array to check. * * @return 1 if the bit is set, 0 if the bit is clear, or an error code. */ int common_math_is_bit_set_in_array (const uint8_t *bytes, size_t length, size_t bit) { size_t byte; uint8_t mask; int status; status = common_math_get_bit_mask_in_array (bytes, length, bit, &byte, &mask); if (status != 0) { return status; } return !!(bytes[byte] & mask); } /** * Set a bit at a specific bit position in a byte array. Bit number is determined as bits 0-7 in * byte 0, bits 8-15 in byte 1, bits 16-23 in byte 2, etc. * * @param bytes The byte array to update. * @param length Length of the byte array. * @param bit The bit number in the array to set. * * @return 0 if the bit was set or an error code. */ int common_math_set_bit_in_array (uint8_t *bytes, size_t length, size_t bit) { size_t byte; uint8_t mask; int status; status = common_math_get_bit_mask_in_array (bytes, length, bit, &byte, &mask); if (status != 0) { return status; } bytes[byte] |= mask; return 0; } /** * Clear a bit at a specific bit position in a byte array. Bit number is determined as bits 0-7 in * byte 0, bits 8-15 in byte 1, bits 16-23 in byte 2, etc. * * @param bytes The byte array to update. * @param length Length of the byte array. * @param bit The bit number in the array to clear. * * @return 0 if the bit was cleared or an error code. */ int common_math_clear_bit_in_array (uint8_t *bytes, size_t length, size_t bit) { size_t byte; uint8_t mask; int status; status = common_math_get_bit_mask_in_array (bytes, length, bit, &byte, &mask); if (status != 0) { return status; } bytes[byte] &= ~mask; return 0; } /** * Set the first bit in a byte array that is not already set. The result will be a contiguous * series of set bits that is one bit longer than it was before the call. * * @param bytes The byte array to update. * @param length Length of the byte array. * * @return 0 if the bit was set or an error code. */ int common_math_set_next_bit_in_array (uint8_t *bytes, size_t length) { int bits; if (bytes == NULL) { return COMMON_MATH_INVALID_ARGUMENT; } /* Use the count of contiguous bits as the bit position for the next bit to set. */ bits = common_math_get_num_contiguous_bits_set_in_array (bytes, length); return common_math_set_bit_in_array (bytes, length, bits); } /** * Set bits in a byte array until there are a specified count of contiguous bits. If the count * check is already satisfied, no bits are set. * * @param bytes The byte array to update. * @param length Length of the byte array. * @param even 1 to grow to an even count, 0 for an odd count. * * @return 0 if the bits were set or an error code. */ static int common_math_set_contiguous_bits_to_count (uint8_t *bytes, size_t length, int even) { int bits; if (bytes == NULL) { return COMMON_MATH_INVALID_ARGUMENT; } bits = common_math_get_num_contiguous_bits_set_in_array (bytes, length); while ((bits % 2) == even) { bits = common_math_set_bit_in_array (bytes, length, bits); if (bits == COMMON_MATH_OUT_OF_RANGE) { return bits; } bits = common_math_get_num_contiguous_bits_set_in_array (bytes, length); } return 0; } /** * Set bits in a byte array until there is an even number of contiguous bits that are set. If there * are already an even number of bits set, nothing will be done. * * @param bytes The byte array to update. * @param length Length of the byte array. * * @return 0 if the bits were set or an error code. */ int common_math_set_next_bit_in_array_even_count (uint8_t *bytes, size_t length) { return common_math_set_contiguous_bits_to_count (bytes, length, 1); } /** * Set bits in a byte array until there is an odd number of contiguous bits that are set. If there * are already an odd number of bits set, nothing will be done. * * @param bytes The byte array to update. * @param length Length of the byte array. * * @return 0 if the bits were set or an error code. */ int common_math_set_next_bit_in_array_odd_count (uint8_t *bytes, size_t length) { return common_math_set_contiguous_bits_to_count (bytes, length, 0); } /** * Shift all the bits in an array to the right, starting from the beginning of the array. Left-most * bits will be filled with zeros. * * @param bytes The byte array to shift. * @param length Length of the byte array. * @param shift_bits The number of bits to shift the array. */ void common_math_right_shift_array (uint8_t *bytes, size_t length, size_t shift_bits) { size_t i; size_t shift_bytes; if ((bytes == NULL) || (length == 0) || (shift_bits == 0)) { /* Nothing to do. */ return; } shift_bytes = shift_bits / 8; shift_bits %= 8; if (shift_bytes >= length) { /* The requested shift is larger then the array, so just clear the entire array. */ memset (bytes, 0, length); return; } /* Handle full bytes by moving the whole array to the right. */ memmove (&bytes[shift_bytes], bytes, length - shift_bytes); memset (bytes, 0, shift_bytes); /* Shift each byte, wrapping from the previous byte. */ for (i = (length - 1); i > shift_bytes; i--) { bytes[i] = (bytes[i - 1] << (8 - shift_bits)) | (bytes[i] >> shift_bits); } /* Shift the first byte. */ bytes[shift_bytes] >>= shift_bits; } /** * Shift all the bits in an array to the left, starting from the beginning of the array. Right-most * bits will be filled with zeros. * * @param bytes The byte array to shift. * @param length Length of the byte array. * @param shift_bits The number of bits to shift the array. */ void common_math_left_shift_array (uint8_t *bytes, size_t length, size_t shift_bits) { size_t i; size_t shift_bytes; if ((bytes == NULL) || (length == 0) || (shift_bits == 0)) { /* Nothing to do. */ return; } shift_bytes = shift_bits / 8; shift_bits %= 8; if (shift_bytes >= length) { /* The requested shift is larger then the array, so just clear the entire array. */ memset (bytes, 0, length); return; } /* Handle full bytes by moving the whole array to the left. */ length -= shift_bytes; memmove (bytes, &bytes[shift_bytes], length); memset (&bytes[length], 0, shift_bytes); /* Shift each byte, wrapping from the next byte. */ for (i = 0; i < (length - 1); i++) { bytes[i] = (bytes[i] << shift_bits) | (bytes[i + 1] >> (8 - shift_bits)); } /* Shift the last byte. */ bytes[length - 1] <<= shift_bits; } /** * Saturating increment for 8-bit unsigned integer. * * @param value The value to increment. * * @return The incremented value, or UINT8_MAX if the value is already at the maximum. */ uint8_t common_math_saturating_increment_u8 (uint8_t value) { return ((value == UINT8_MAX) ? UINT8_MAX : (value + 1)); } /** * Saturating increment for 16-bit unsigned integer. * * @param value The value to increment. * * @return The incremented value, or UINT16_MAX if the value is already at the maximum. */ uint16_t common_math_saturating_increment_u16 (uint16_t value) { return ((value == UINT16_MAX) ? UINT16_MAX : (value + 1)); } /** * Saturating increment for 32-bit unsigned integer. * * @param value The value to increment. * * @return The incremented value, or UINT32_MAX if the value is already at the maximum. */ uint32_t common_math_saturating_increment_u32 (uint32_t value) { return ((value == UINT32_MAX) ? UINT32_MAX : (value + 1)); }