runtime/memoryview-builtins.cpp (601 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
#include "memoryview-builtins.h"
#include "builtins.h"
#include "bytes-builtins.h"
#include "float-builtins.h"
#include "int-builtins.h"
#include "mmap-module.h"
#include "type-builtins.h"
namespace py {
// Instance attributes of the memoryview type, each bound to its field offset
// inside RawMemoryView. Entries flagged kHidden back internal state (the
// `_memoryview__` prefix suggests they are used only by the runtime's own
// library code — TODO confirm); entries flagged kReadOnly are presumably
// readable but not assignable from Python.
static const BuiltinAttribute kMemoryViewAttributes[] = {
    // Underlying storage: bytes / mutable bytes / Pointer.
    {ID(_memoryview__buffer), RawMemoryView::kBufferOffset,
     AttributeFlags::kHidden},
    {ID(format), RawMemoryView::kFormatOffset, AttributeFlags::kReadOnly},
    // Byte length of the viewed region.
    {ID(_memoryview__length), RawMemoryView::kLengthOffset,
     AttributeFlags::kHidden},
    {ID(readonly), RawMemoryView::kReadOnlyOffset, AttributeFlags::kReadOnly},
    {ID(obj), RawMemoryView::kObjectOffset, AttributeFlags::kReadOnly},
    {ID(shape), RawMemoryView::kShapeOffset, AttributeFlags::kReadOnly},
    // Byte offset of the view's first element within the buffer.
    {ID(_memoryview__start), RawMemoryView::kStartOffset,
     AttributeFlags::kHidden},
    {ID(strides), RawMemoryView::kStridesOffset, AttributeFlags::kReadOnly},
    {ID(ndim), RawMemoryView::kNdimOffset, AttributeFlags::kReadOnly},
};
// Registers the builtin `memoryview` type with the runtime: layout
// kMemoryView, superclass `object`, instance attributes from
// kMemoryViewAttributes, and instance size MemoryView::kSize.
void initializeMemoryViewType(Thread* thread) {
  // Passed as the `basetype` argument; presumably this prevents Python code
  // from subclassing memoryview — TODO confirm against addBuiltinType.
  const bool basetype = false;
  const LayoutId superclass_id = LayoutId::kObject;
  addBuiltinType(thread, ID(memoryview), LayoutId::kMemoryView, superclass_id,
                 kMemoryViewAttributes, MemoryView::kSize, basetype);
}
static char formatChar(const Str& format) {
if (format.length() == 2) {
if (format.byteAt(0) != '@') return -1;
return format.byteAt(1);
}
if (format.length() != 1) return -1;
return format.byteAt(0);
}
// Maps a struct-format character to the byte size of one native item of that
// format, or -1 when the character is not a supported format.
static word itemSize(char format) {
  word size;
  switch (format) {
    case 'b':
    case 'B':
    case 'c':
      size = kByteSize;
      break;
    case 'h':
    case 'H':
      size = kShortSize;
      break;
    case 'i':
    case 'I':
      size = kIntSize;
      break;
    case 'l':
    case 'L':
      size = kLongSize;
      break;
    case 'q':
    case 'Q':
      size = kLongLongSize;
      break;
    case 'n':
    case 'N':
      size = kWordSize;
      break;
    case 'f':
      size = kFloatSize;
      break;
    case 'd':
      size = kDoubleSize;
      break;
    case '?':
      size = kBoolSize;
      break;
    case 'P':
      size = kPointerSize;
      break;
    default:
      size = -1;
      break;
  }
  return size;
}
// Returns the byte size of one element of `view`, derived from its format
// string. The format is expected to be valid (checked only in debug builds).
word memoryviewItemsize(Thread* thread, const MemoryView& view) {
  HandleScope scope(thread);
  Str view_format(&scope, view.format());
  char code = formatChar(view_format);
  DCHECK(code > 0, "invalid memoryview");
  word size = itemSize(code);
  DCHECK(size > 0, "invalid memoryview");
  return size;
}
// Raises ValueError reporting that a value cannot be stored under the given
// struct-format character.
static RawObject raiseInvalidValueError(Thread* thread, char format_char) {
  return thread->raiseWithFmt(
      LayoutId::kValueError, "memoryview: invalid value for format '%c'",
      format_char);
}

// Raises TypeError reporting that a value has the wrong type for the given
// struct-format character.
static RawObject raiseInvalidTypeError(Thread* thread, char format_char) {
  return thread->raiseWithFmt(LayoutId::kTypeError,
                              "memoryview: invalid type for format '%c'",
                              format_char);
}
// Reports whether `format` is a struct-format character whose items are
// integers (signed, unsigned, size and pointer formats).
static bool isIntFormat(char format) {
  switch (format) {
    // Stacked labels share one body; no fallthrough annotation is needed.
    case 'b':
    case 'B':
    case 'h':
    case 'H':
    case 'i':
    case 'I':
    case 'l':
    case 'L':
    case 'q':
    case 'Q':
    case 'n':
    case 'N':
    case 'P':
      return true;
    default:
      return false;
  }
}
// Converts the int `value` to the C integer type `T` and stores its bytes at
// `dst`. Returns None on success, or raises ValueError (naming `format`) when
// `value` does not fit in `T`. Callers must have checked `value.isInt()`.
template <typename T>
static RawObject packIntValue(Thread* thread, byte* dst, char format,
                              RawObject value) {
  OptInt<T> opt_val = RawInt::cast(value).asInt<T>();
  if (opt_val.error != CastError::None) {
    return raiseInvalidValueError(thread, format);
  }
  std::memcpy(dst, &opt_val.value, sizeof(opt_val.value));
  return NoneType::object();
}

// Writes `value` as one item of struct format `format` at byte offset `index`
// from `address`.
//
// Returns:
//   - None on success;
//   - Unbound when `value` has the wrong type for an int, float ('f'/'d'), or
//     bool ('?') format, leaving the decision to the caller;
//   - a raised ValueError when the value is out of range (ints) or is not a
//     single byte ('c'), or a raised TypeError when 'c' gets a non-bytes.
// Nothing is written when an error or Unbound is returned.
static RawObject packObject(Thread* thread, uword address, char format,
                            word index, RawObject value) {
  byte* dst = reinterpret_cast<byte*>(address + index);
  if (isIntFormat(format)) {
    if (!value.isInt()) return Unbound::object();
    switch (format) {
      case 'b':
        // 'b' is a *signed* char. Plain `char` is unsigned on some ABIs
        // (e.g. AArch64 Linux), which would accept 0..255 instead of
        // -128..127; this also matches unpackObject, which reads
        // `signed char`.
        return packIntValue<signed char>(thread, dst, format, value);
      case 'h':
        return packIntValue<short>(thread, dst, format, value);
      case 'i':
        return packIntValue<int>(thread, dst, format, value);
      case 'l':
        return packIntValue<long>(thread, dst, format, value);
      case 'B':
        return packIntValue<unsigned char>(thread, dst, format, value);
      case 'H':
        return packIntValue<unsigned short>(thread, dst, format, value);
      case 'I':
        return packIntValue<unsigned int>(thread, dst, format, value);
      case 'L':
        return packIntValue<unsigned long>(thread, dst, format, value);
      case 'q':
        return packIntValue<long long>(thread, dst, format, value);
      case 'Q':
        return packIntValue<unsigned long long>(thread, dst, format, value);
      case 'n':
        return packIntValue<ssize_t>(thread, dst, format, value);
      case 'N':
        return packIntValue<size_t>(thread, dst, format, value);
      case 'P':
        return packIntValue<uintptr_t>(thread, dst, format, value);
    }
    // Not reached: every format accepted by isIntFormat is handled above.
    return NoneType::object();
  }
  switch (format) {
    case 'f': {
      if (!value.isFloat()) return Unbound::object();
      float value_float = Float::cast(floatUnderlying(value)).value();
      std::memcpy(dst, &value_float, sizeof(value_float));
      return NoneType::object();
    }
    case 'd': {
      if (!value.isFloat()) return Unbound::object();
      double value_double = Float::cast(floatUnderlying(value)).value();
      std::memcpy(dst, &value_double, sizeof(value_double));
      return NoneType::object();
    }
    case 'c': {
      if (!value.isBytes()) return raiseInvalidTypeError(thread, format);
      RawBytes value_bytes = bytesUnderlying(value);
      if (value_bytes.length() != 1) {
        return raiseInvalidValueError(thread, format);
      }
      *dst = value_bytes.byteAt(0);
      return NoneType::object();
    }
    case '?': {
      if (!value.isBool()) return Unbound::object();
      bool value_bool = Bool::cast(value).value();
      std::memcpy(dst, &value_bool, sizeof(value_bool));
      return NoneType::object();
    }
    default:
      UNREACHABLE("invalid format");
  }
  return NoneType::object();
}
// Reads one item of struct format `format` from byte offset `index` of the
// `length`-byte region starting at `address`, and boxes it as a Python
// object: a 1-byte bytes for 'c', a bool for '?', a float for 'f'/'d', and an
// int for every integer format.
static RawObject unpackObject(Thread* thread, uword address, word length,
                              char format, word index) {
  // The whole item, not only its first byte, must lie inside the region.
  DCHECK_INDEX(index, length - static_cast<word>(itemSize(format) - 1));
  Runtime* rt = thread->runtime();
  byte* ptr = reinterpret_cast<byte*>(address + index);
  switch (format) {
    // 8- and 16-bit integers are boxed directly as SmallInts.
    case 'b':
      return RawSmallInt::fromWord(Utils::readBytes<signed char>(ptr));
    case 'B':
      return RawSmallInt::fromWord(Utils::readBytes<unsigned char>(ptr));
    case 'h':
      return RawSmallInt::fromWord(Utils::readBytes<short>(ptr));
    case 'H':
      return RawSmallInt::fromWord(Utils::readBytes<unsigned short>(ptr));
    // Wider integers go through the runtime's int constructors.
    case 'i':
      return rt->newInt(Utils::readBytes<int>(ptr));
    case 'I':
      return rt->newInt(Utils::readBytes<unsigned int>(ptr));
    case 'l':
      return rt->newInt(Utils::readBytes<long>(ptr));
    case 'q':
      return rt->newInt(Utils::readBytes<long long>(ptr));
    case 'n':
      return rt->newInt(Utils::readBytes<ssize_t>(ptr));
    case 'L':
      return rt->newIntFromUnsigned(Utils::readBytes<unsigned long>(ptr));
    case 'Q':
      return rt->newIntFromUnsigned(Utils::readBytes<unsigned long long>(ptr));
    case 'N':
      return rt->newIntFromUnsigned(Utils::readBytes<size_t>(ptr));
    case 'P':
      return rt->newIntFromCPtr(Utils::readBytes<void*>(ptr));
    // Floating point.
    case 'f':
      return rt->newFloat(Utils::readBytes<float>(ptr));
    case 'd':
      return rt->newFloat(Utils::readBytes<double>(ptr));
    // Byte and boolean.
    case 'c':
      return rt->newBytes(1, Utils::readBytes<byte>(ptr));
    case '?':
      return Bool::fromBool(Utils::readBytes<byte>(ptr) != 0);
    default:
      UNREACHABLE("invalid format");
  }
}
// Translates `index`, a byte offset relative to the start of `view`, into a
// byte offset within the view's underlying buffer. Only unit strides are
// supported.
static word bufferIndex(const MemoryView& view, word index) {
  word stride = intUnderlying(Tuple::cast(view.strides()).at(0)).asWord();
  if (stride != 1) {
    UNIMPLEMENTED("Stride != 1 is not yet supported");
  }
  DCHECK_INDEX(index, view.length());
  word location = view.start() + index;
  return location;
}
// Returns the item of `view` found at byte offset `index` (relative to the
// view), boxed according to the view's format. The backing buffer may be a
// large bytes, a small bytes (copied to the stack first), or a Pointer.
RawObject memoryviewGetitem(Thread* thread, const MemoryView& view,
                            word index) {
  HandleScope scope(thread);
  Runtime* runtime = thread->runtime();
  Object buffer(&scope, view.buffer());
  // TODO(T36619828) support str subclasses
  Str format(&scope, view.format());
  char format_c = formatChar(format);
  // TODO(T58046846): Replace DCHECK(char > 0) checks
  DCHECK(format_c > 0, "invalid memoryview");
  word item_size = itemSize(format_c);
  DCHECK(item_size > 0, "invalid memoryview");
  word buffer_index = bufferIndex(view, index);

  // Resolve the raw memory region to read from.
  uword address;
  word length;
  // Small bytes are immediates with no stable address, so they are copied
  // here; this array must outlive the unpackObject call below.
  byte small_copy[SmallBytes::kMaxLength];
  if (!runtime->isInstanceOfBytes(*buffer)) {
    CHECK(buffer.isPointer(), "memoryview.__getitem__ with non bytes/memory");
    address = reinterpret_cast<uword>(Pointer::cast(*buffer).cptr());
    length = Pointer::cast(*buffer).length();
  } else if (buffer.isLargeBytes()) {
    // TODO(T38246066) support bytes subclasses
    LargeBytes large(&scope, *buffer);
    address = large.address();
    length = large.length();
  } else {
    CHECK(buffer.isSmallBytes(),
          "memoryview.__getitem__ with non bytes/memory");
    Bytes small(&scope, *buffer);
    length = small.length();
    small.copyTo(small_copy, length);
    address = reinterpret_cast<uword>(small_copy);
  }
  return unpackObject(thread, address, length, format_c, buffer_index);
}
// Returns a new memoryview over items [start, stop) of `view`; `start` and
// `stop` count items of the view's format. The new view shares the buffer,
// owner object, format and read-only flag of `view`. Only `step == 1` is
// supported.
RawObject memoryviewGetslice(Thread* thread, const MemoryView& view, word start,
                             word stop, word step) {
  if (step != 1) {
    UNIMPLEMENTED("Stride != 1 is not yet supported");
  }
  HandleScope scope(thread);
  Runtime* runtime = thread->runtime();
  Str format(&scope, view.format());
  char format_c = formatChar(format);
  // TODO(T58046846): Replace DCHECK(char > 0) checks
  DCHECK(format_c > 0, "invalid memoryview");
  word item_size = itemSize(format_c);
  DCHECK(item_size > 0, "invalid memoryview");
  // Convert item counts/offsets into byte counts/offsets.
  word byte_count = Slice::length(start, stop, step) * item_size;
  word byte_offset = view.start() + start * item_size;
  Object buffer(&scope, view.buffer());
  Object obj(&scope, view.object());
  MemoryView slice(
      &scope, runtime->newMemoryView(
                  thread, obj, buffer, byte_count,
                  view.readOnly() ? ReadOnly::ReadOnly : ReadOnly::ReadWrite));
  slice.setFormat(view.format());
  slice.setStart(byte_offset);
  return *slice;
}
// Stores `value` as the item at byte offset `index` (relative to `view`),
// encoded according to the view's format. Returns whatever packObject
// returns: None on success, Unbound on a type mismatch, or a raised error.
// The buffer must be mutable bytes or a Pointer.
RawObject memoryviewSetitem(Thread* thread, const MemoryView& view, word index,
                            const Object& value) {
  HandleScope scope(thread);
  Object buffer(&scope, view.buffer());
  Str format(&scope, view.format());
  char fmt = formatChar(format);
  // TODO(T58046846): Replace DCHECK(char > 0) checks
  DCHECK(fmt > 0, "invalid memoryview");
  word item_size = itemSize(fmt);
  DCHECK(item_size > 0, "invalid memoryview");
  word buffer_index = bufferIndex(view, index);
  uword base;
  if (buffer.isMutableBytes()) {
    base = LargeBytes::cast(*buffer).address();
  } else {
    DCHECK(buffer.isPointer(), "memoryview.__setitem__ with non bytes/memory");
    base = reinterpret_cast<uword>(Pointer::cast(*buffer).cptr());
  }
  return packObject(thread, base, fmt, buffer_index, *value);
}
// Raises ValueError for a slice assignment whose right-hand side does not
// match the format or item count of the target slice.
static RawObject raiseDifferentStructureError(Thread* thread) {
  LayoutId error_type = LayoutId::kValueError;
  return thread->raiseWithFmt(
      error_type,
      "memoryview assignment: lvalue and rvalue have different structures");
}
// Implements slice assignment `view[start:stop:step] = value_obj`.
//
// `start`, `stop` and `step` count items of `view`'s format, and `slice_len`
// is the number of items they select. They are presumably normalized like
// Python slice indices (for a negative step, start > stop) — TODO confirm
// against the caller. `value_obj` must be a bytes, bytearray or memoryview
// whose format and item count match the slice; otherwise ValueError is
// raised. Other bytes-like values are unimplemented, and anything else raises
// TypeError. Only unsliced, unit-stride target views are supported.
RawObject memoryviewSetslice(Thread* thread, const MemoryView& view, word start,
                             word stop, word step, word slice_len,
                             const Object& value_obj) {
  HandleScope scope(thread);
  Runtime* runtime = thread->runtime();
  word stride = intUnderlying(Tuple::cast(view.strides()).at(0)).asWord();
  if (view.start() != 0 || stride != 1) {
    UNIMPLEMENTED("Set item with slicing called on a sliced memoryview");
  }
  Str format(&scope, view.format());
  char fmt = formatChar(format);
  // TODO(T58046846): Replace DCHECK(char > 0) checks
  DCHECK(fmt > 0, "invalid memoryview");
  Object buffer(&scope, view.buffer());
  Bytes value_bytes(&scope, Bytes::empty());
  if (runtime->isInstanceOfBytes(*value_obj)) {
    value_bytes = *value_obj;
    if (fmt != 'B' || value_bytes.length() != slice_len) {
      return raiseDifferentStructureError(thread);
    }
  } else if (runtime->isInstanceOfBytearray(*value_obj)) {
    Bytearray value_bytearray(&scope, *value_obj);
    if (fmt != 'B' || value_bytearray.numItems() != slice_len) {
      return raiseDifferentStructureError(thread);
    }
    value_bytes = value_bytearray.items();
  } else if (value_obj.isMemoryView()) {
    MemoryView value(&scope, *value_obj);
    Str value_format(&scope, value.format());
    char value_fmt = formatChar(value_format);
    word item_size = itemSize(value_fmt);
    DCHECK(item_size > 0, "invalid memoryview");
    if (fmt != value_fmt || (value.length() / item_size) != slice_len) {
      return raiseDifferentStructureError(thread);
    }
    // Locate the rvalue's raw bytes. Small bytes have no stable address, so
    // they are copied to the stack first.
    byte small_bytes_buf[SmallBytes::kMaxLength];
    uword value_address;
    Object value_buffer(&scope, value.buffer());
    if (value_buffer.isLargeBytes() || value_buffer.isMutableBytes()) {
      value_address = LargeBytes::cast(*value_buffer).address();
    } else if (value_buffer.isPointer()) {
      value_address =
          reinterpret_cast<uword>(Pointer::cast(*value_buffer).cptr());
    } else if (value_buffer.isInt()) {
      value_address = Int::cast(*value_buffer).asInt<uword>().value;
    } else {
      DCHECK(value_buffer.isSmallBytes(),
             "memoryview.__setitem__ with non bytes/memory");
      Bytes bytes(&scope, *value_buffer);
      bytes.copyTo(small_bytes_buf, bytes.length());
      value_address = reinterpret_cast<uword>(small_bytes_buf);
    }
    // A sliced rvalue starts part-way into its buffer.
    value_address += value.start();
    uword address;
    if (buffer.isMutableBytes()) {
      address = LargeBytes::cast(*buffer).address();
    } else {
      DCHECK(buffer.isPointer(),
             "memoryview.__setitem__ with non bytes/memory");
      address = reinterpret_cast<uword>(Pointer::cast(*buffer).cptr());
    }
    word byte_start = start * item_size;
    if (step == 1) {
      // Contiguous copy. memmove because the rvalue may view the same buffer
      // as the target, so the two ranges may overlap.
      std::memmove(reinterpret_cast<void*>(address + byte_start),
                   reinterpret_cast<const void*>(value_address),
                   slice_len * item_size);
      return NoneType::object();
    }
    // All three slice bounds must be scaled to bytes.
    // TODO: strided copies between overlapping regions may read items that
    // were already overwritten.
    word byte_stop = stop * item_size;
    word byte_step = step * item_size;
    for (; step > 0 ? byte_start < byte_stop : byte_start > byte_stop;
         byte_start += byte_step, value_address += item_size) {
      std::memcpy(reinterpret_cast<void*>(address + byte_start),
                  reinterpret_cast<const void*>(value_address), item_size);
    }
    return NoneType::object();
  } else if (runtime->isByteslike(*value_obj)) {
    UNIMPLEMENTED("unsupported bytes-like type");
  } else {
    return thread->raiseWithFmt(LayoutId::kTypeError,
                                "a bytes-like object is required, not '%T'",
                                &value_obj);
  }
  // bytes / bytearray rvalue: the target format is 'B', so items are bytes.
  byte* address;
  if (buffer.isMutableBytes()) {
    address = reinterpret_cast<byte*>(LargeBytes::cast(*buffer).address());
  } else {
    DCHECK(buffer.isPointer(), "memoryview.__setitem__ with non bytes/memory");
    address = static_cast<byte*>(Pointer::cast(*buffer).cptr());
  }
  if (step == 1) {
    value_bytes.copyTo(address + start, slice_len);
    return NoneType::object();
  }
  // Walk in the direction of `step` so that negative steps are honored.
  for (word i = 0; step > 0 ? start < stop : start > stop; i++, start += step) {
    address[start] = value_bytes.byteAt(i);
  }
  return NoneType::object();
}
// Computes `dividend % divisor` for a positive power-of-two `divisor` by
// masking off the high bits instead of dividing.
static word pow2_remainder(word dividend, word divisor) {
  DCHECK(divisor > 0 && Utils::isPowerOfTwo(divisor), "must be power of two");
  return dividend & (divisor - 1);
}
// memoryview.cast(format): returns a new memoryview over the same bytes as
// `self`, reinterpreted with the native single-character `format` (optionally
// prefixed with '@').
//
// Raises TypeError if `self` is not a memoryview or `format` is not a str.
// Raises ValueError if `format` is unsupported, or if the byte length of
// `self` is not a multiple of the new item size.
RawObject METH(memoryview, cast)(Thread* thread, Arguments args) {
  HandleScope scope(thread);
  Object self_obj(&scope, args.get(0));
  if (!self_obj.isMemoryView()) {
    return thread->raiseRequiresType(self_obj, ID(memoryview));
  }
  MemoryView self(&scope, *self_obj);
  Runtime* runtime = thread->runtime();
  Object format_obj(&scope, args.get(1));
  if (!runtime->isInstanceOfStr(*format_obj)) {
    return thread->raiseWithFmt(LayoutId::kTypeError,
                                "format argument must be a string");
  }
  Str format(&scope, *format_obj);
  char format_c = formatChar(format);
  word item_size;
  // The itemSize() check also rejects invalid formats on platforms where
  // `char` is unsigned and formatChar's -1 is not negative.
  if (format_c < 0 || (item_size = itemSize(format_c)) < 0) {
    return thread->raiseWithFmt(
        LayoutId::kValueError,
        "memoryview: destination must be a native single character format "
        "prefixed with an optional '@'");
  }
  word length = self.length();
  if (pow2_remainder(length, item_size) != 0) {
    return thread->raiseWithFmt(
        LayoutId::kValueError,
        "memoryview: length is not a multiple of itemsize");
  }
  Object buffer(&scope, self.buffer());
  Object obj(&scope, self.object());
  MemoryView result(
      &scope, runtime->newMemoryView(
                  thread, obj, buffer, length,
                  self.readOnly() ? ReadOnly::ReadOnly : ReadOnly::ReadWrite));
  result.setFormat(*format);
  // A sliced view begins part-way into the shared buffer. Keep that offset so
  // the cast view covers exactly the same bytes as `self`.
  result.setStart(self.start());
  return *result;
}
// memoryview.__len__(): the number of items in the view, i.e. its byte length
// divided by the item size of its format. Raises TypeError if `self` is not a
// memoryview.
RawObject METH(memoryview, __len__)(Thread* thread, Arguments args) {
  HandleScope scope(thread);
  Object self_obj(&scope, args.get(0));
  if (!self_obj.isMemoryView()) {
    return thread->raiseRequiresType(self_obj, ID(memoryview));
  }
  MemoryView self(&scope, *self_obj);
  // TODO(T36619828) support str subclasses
  Str view_format(&scope, self.format());
  char code = formatChar(view_format);
  DCHECK(code > 0, "invalid format");
  word bytes_per_item = itemSize(code);
  DCHECK(bytes_per_item > 0, "invalid memoryview");
  word num_items = self.length() / bytes_per_item;
  return SmallInt::fromWord(num_items);
}
// memoryview.__new__(cls, object): builds a memoryview over `object`.
//
// The sources handled here are:
//   - bytes (and subclasses): read-only view of the underlying bytes;
//   - bytearray: writable view of its item storage;
//   - memoryview: a new view over the same buffer, owner, format, read-only
//     flag and starting offset;
//   - mmap: a Pointer-backed view with format 'B', writable iff the mmap is;
//   - any other object: copied through the buffer protocol into a bytes,
//     giving a read-only view (or the error from that conversion).
// Raises TypeError if `cls` is not exactly memoryview.
RawObject METH(memoryview, __new__)(Thread* thread, Arguments args) {
  HandleScope scope(thread);
  Runtime* runtime = thread->runtime();
  if (args.get(0) != runtime->typeAt(LayoutId::kMemoryView)) {
    return thread->raiseWithFmt(LayoutId::kTypeError,
                                "memoryview.__new__(X): X is not 'memoryview'");
  }
  Object object(&scope, args.get(1));
  if (runtime->isInstanceOfBytes(*object)) {
    Bytes bytes(&scope, bytesUnderlying(*object));
    MemoryView result(
        &scope, runtime->newMemoryView(thread, object, bytes, bytes.length(),
                                       ReadOnly::ReadOnly));
    return *result;
  }
  if (runtime->isInstanceOfBytearray(*object)) {
    Bytearray bytearray(&scope, *object);
    Bytes bytes(&scope, bytearray.items());
    MemoryView result(&scope, runtime->newMemoryView(thread, object, bytes,
                                                     bytearray.numItems(),
                                                     ReadOnly::ReadWrite));
    return *result;
  }
  if (object.isMemoryView()) {
    MemoryView view(&scope, *object);
    Object buffer(&scope, view.buffer());
    Object view_obj(&scope, view.object());
    MemoryView result(
        &scope, runtime->newMemoryView(thread, view_obj, buffer, view.length(),
                                       view.readOnly() ? ReadOnly::ReadOnly
                                                       : ReadOnly::ReadWrite));
    result.setFormat(view.format());
    // A sliced source view begins part-way into the shared buffer. Keep that
    // offset so the new view exposes the same bytes.
    result.setStart(view.start());
    return *result;
  }
  if (object.isMmap()) {
    Mmap mmap_obj(&scope, *object);
    Pointer pointer(&scope, mmap_obj.data());
    MemoryView result(
        &scope,
        runtime->newMemoryViewFromCPtr(
            thread, object, pointer.cptr(), pointer.length(),
            mmap_obj.isWritable() ? ReadOnly::ReadWrite : ReadOnly::ReadOnly));
    result.setFormat(SmallStr::fromCodePoint('B'));
    return *result;
  }
  // Handle a buffer protocol object. Ideally we would skip an intermediate
  // copy, but for now we make one copy of the data into a bytes.
  // TODO(T85440357): Point directly to the buffer protocol object or
  // bufferinfo from the memoryview.
  Object bytes(&scope, newBytesFromBuffer(thread, object));
  if (bytes.isError()) {
    return *bytes;
  }
  return runtime->newMemoryView(
      thread, object, bytes, Bytes::cast(*bytes).length(), ReadOnly::ReadOnly);
}
} // namespace py