/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "mozilla/dom/Tensor.h"

#include <limits>

#include "js/ArrayBuffer.h"
#include "js/BigInt.h"
#include "js/Value.h"
#include "mozilla/Assertions.h"
#include "mozilla/Logging.h"
#include "mozilla/PodOperations.h"
#include "mozilla/RefPtr.h"
#include "mozilla/dom/BindingUtils.h"
#include "mozilla/dom/ONNXBinding.h"
#include "mozilla/dom/Promise.h"
#include "mozilla/dom/ScriptSettings.h"
#include "mozilla/dom/ToJSValue.h"
#include "mozilla/dom/TypedArray.h"
#include "nsContentUtils.h"
#include "nsStringFwd.h"
#include "nsTArray.h"

extern mozilla::LazyLogModule gONNXLog;
#define LOGD(fmt, ...) \
  MOZ_LOG_FMT(gONNXLog, LogLevel::Debug, fmt, ##__VA_ARGS__)

namespace mozilla::dom {

NS_IMPL_CYCLE_COLLECTION_WRAPPERCACHE(Tensor, mGlobal)
NS_IMPL_CYCLE_COLLECTING_ADDREF(Tensor)
NS_IMPL_CYCLE_COLLECTING_RELEASE(Tensor)
NS_INTERFACE_MAP_BEGIN_CYCLE_COLLECTION(Tensor)
  NS_WRAPPERCACHE_INTERFACE_MAP_ENTRY
  NS_INTERFACE_MAP_ENTRY(nsISupports)
NS_INTERFACE_MAP_END

// Construct a tensor from a typed array. The raw bytes are copied as-is into
// mData; the element type is carried separately in mType.
// NOTE(review): the template arguments below (nsIGlobalObject, uint8_t spans,
// int64_t dims) were reconstructed from usage in this file — confirm against
// the declarations in Tensor.h.
Tensor::Tensor(const GlobalObject& aGlobal, const nsACString& aType,
               const ArrayBufferView& aData, const Sequence<int64_t>& aDims)
    : mType(aType) {
  LOGD("{}", __PRETTY_FUNCTION__);
  nsCOMPtr<nsIGlobalObject> global = do_QueryInterface(aGlobal.GetAsSupports());
  mGlobal = global;
  if (!aData.AppendDataTo(mData)) {
    // Copy failed (OOM): only log how many bytes were requested; the tensor
    // is left with empty data.
    size_t len = aData.ProcessFixedData(
        [&](const Span<uint8_t>& aData) -> size_t { return aData.Length(); });
    LOGD("{} OOM (size: {})", __PRETTY_FUNCTION__, len);
  }
  mDims.AppendElements(aDims);
}

// Construct a tensor from an already-serialized byte buffer.
Tensor::Tensor(const GlobalObject& aGlobal, const nsACString& aType,
               const nsTArray<uint8_t>& aData, const Sequence<int64_t>& aDims)
    : mType(aType) {
  LOGD("{} type: {} len: {}", __PRETTY_FUNCTION__, aType, aData.Length());
  nsCOMPtr<nsIGlobalObject> global = do_QueryInterface(aGlobal.GetAsSupports());
  mGlobal = std::move(global);
  // Cast to uint8_t. Type is held in mType.
  mData.AppendElements(aData);
  mDims.AppendElements(aDims);
}

// Construct an output tensor directly from the ONNX runtime: takes ownership
// of the byte buffer, and derives the string type from the ONNX enum.
Tensor::Tensor(const GlobalObject& aGlobal, ONNXTensorElementDataType aType,
               nsTArray<uint8_t> aData, nsTArray<int64_t> aDims)
    : mType(ONNXTypeToString(aType)) {
  LOGD("Output tensor: {} type: {} len: {}", __PRETTY_FUNCTION__,
       ONNXTypeToString(aType), aData.Length());
  nsCOMPtr<nsIGlobalObject> global = do_QueryInterface(aGlobal.GetAsSupports());
  mGlobal = std::move(global);
  mData = std::move(aData);
  mDims.AppendElements(aDims);
}

// WebIDL constructor: accepts either a typed array (bytes copied verbatim) or
// a JS sequence of numbers/BigInts/booleans that is converted to the byte
// representation of `type`.
already_AddRefed<Tensor> Tensor::Constructor(
    const GlobalObject& global, const nsACString& type,
    const ArrayBufferViewOrAnySequence& data, const Sequence<int64_t>& dims,
    ErrorResult& aRv) {
  if (data.IsAnySequence()) {
// Convert a sequence of JS BigInts into the little-endian bytes of c_type.
// Fix: check isBigInt() *before* calling toBigInt() — toBigInt() on a
// non-BigInt JS::Value is an assertion failure (the previous null-check of
// the returned pointer could never fire).
#define CASE_BIGINT(onnx_type, c_type, conversionfn)                        \
  case onnx_type: {                                                         \
    nsTArray<c_type> values;                                                \
    for (const JS::Value& element : data.GetAsAnySequence()) {              \
      if (!element.isBigInt()) {                                            \
        aRv.ThrowTypeError("Inconsistent value in arg 2");                  \
        return nullptr;                                                     \
      }                                                                     \
      values.AppendElement(conversionfn(element.toBigInt()));               \
    }                                                                       \
    valuesAsBytes.AppendElements(                                           \
        reinterpret_cast<uint8_t*>(values.Elements()),                      \
        values.Length() * sizeof(c_type));                                  \
    break;                                                                  \
  }

// Convert a sequence of JS numbers into the bytes of c_type, rejecting
// non-numbers and values outside c_type's representable range.
#define CASE(onnx_type, c_type, checkfn, conversionfn)                      \
  case onnx_type: {                                                         \
    nsTArray<c_type> values;                                                \
    for (const auto& element : data.GetAsAnySequence()) {                   \
      if (!element.checkfn()) {                                             \
        aRv.ThrowTypeError(                                                 \
            "Inconsistency between type and value in second argument");     \
        return nullptr;                                                     \
      }                                                                     \
      if (std::numeric_limits<c_type>::lowest() > element.conversionfn() || \
          std::numeric_limits<c_type>::max() < element.conversionfn()) {    \
        aRv.ThrowTypeError("Value out of range in arg 2");                  \
        return nullptr;                                                     \
      }                                                                     \
      values.AppendElement(element.conversionfn());                         \
    }                                                                       \
    valuesAsBytes.AppendElements(                                           \
        reinterpret_cast<uint8_t*>(values.Elements()),                      \
        values.Length() * sizeof(c_type));                                  \
    break;                                                                  \
  }

    nsTArray<uint8_t> valuesAsBytes;
    // Assume constant type, lock on the type of the first element.
    switch (StringToONNXDataType(type)) {
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, uint8_t, isNumber,
           toDouble)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float, isNumber, toDouble)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t, isNumber, toDouble)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t, isNumber, toDouble)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, uint16_t, isNumber, toDouble)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, int16_t, isNumber, toDouble)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t, isNumber, toDouble)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, int8_t, isNumber, toDouble)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, int16_t, isNumber, toDouble)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, isNumber, toDouble)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, uint32_t, isNumber, toDouble)
      CASE_BIGINT(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, ToBigInt64)
      CASE_BIGINT(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, uint64_t, ToBigUint64)
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
        // Booleans are stored as one byte per element (1 or 0).
        for (const auto& element : data.GetAsAnySequence()) {
          if (!element.isBoolean()) {
            aRv.ThrowTypeError(
                "Inconsistency between type and value in second argument");
            return nullptr;
          }
          valuesAsBytes.AppendElement(element.toBoolean() ? 1 : 0);
        }
        break;
      }
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
        MOZ_CRASH("Not handled");
        break;
    }
#undef CASE
#undef CASE_BIGINT
    auto rv = MakeRefPtr<Tensor>(global, type, valuesAsBytes, dims);
    LOGD("Tensor from sequence: {}", rv->ToString().get());
    return rv.forget();
  }
  auto rv = MakeRefPtr<Tensor>(global, type, data.GetAsArrayBufferView(), dims);
  LOGD("Tensor from TypedArray: {}", rv->ToString().get());
  return rv.forget();
}

// Drop the tensor's backing bytes (e.g. after the data has been consumed).
void Tensor::Dispose() { mData.Clear(); }

void Tensor::SetDims(const nsTArray<int64_t>& aVal) {
  mDims.Clear();
  mDims.AppendElements(aVal);
}

void Tensor::GetDims(nsTArray<int64_t>& aRetVal) {
  aRetVal.AppendElements(mDims);
}

void Tensor::GetType(nsCString& aRetVal) const { aRetVal.Assign(mType); }

// Synchronous data getter: reinterprets the stored bytes as the JS typed
// array matching this tensor's element type and returns it in aRetVal.
void Tensor::GetData(JSContext* aCx,
                     JS::MutableHandle<JS::Value> aRetVal) const {
  LOGD("{} {} type: {} size: {}", __PRETTY_FUNCTION__, fmt::ptr(this),
       mType.get(), mData.Length());
// The C-style cast is deliberate: Elements() is const-qualified here and the
// TypedArrayCreator needs a mutable pointer (cast also strips const).
#define CASE(onnx_type, typed_array_type, c_type)                     \
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_##onnx_type: {                   \
    nsTArray<c_type> tmp((c_type*)mData.Elements(),                   \
                         mData.Length() / sizeof(c_type));            \
    dom::TypedArrayCreator<typed_array_type> creator(std::move(tmp)); \
    aRetVal.set(creator.Create(aCx));                                 \
    break;                                                            \
  }
  switch (Type()) {
    CASE(INT8, Int8Array, int8_t)
    CASE(UINT8, Uint8Array, uint8_t)
    CASE(INT16, Int16Array, int16_t)
    CASE(UINT16, Uint16Array, uint16_t)
    CASE(INT32, Int32Array, int32_t)
    CASE(UINT32, Uint32Array, uint32_t)
    CASE(INT64, BigInt64Array, int64_t)
    CASE(UINT64, BigUint64Array, uint64_t)
    CASE(BOOL, Uint8Array, uint8_t)
    CASE(DOUBLE, Float64Array, double)
    CASE(FLOAT, Float32Array, float)
    CASE(STRING, Uint8Array, uint8_t)  // hmmm
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED:
      MOZ_CRASH("Missing ONNX data type to js value");
      break;
  }
#undef CASE
}

// Tensors managed by this class always live on the CPU.
TensorDataLocation Tensor::Location() const {
  LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this));
  return TensorDataLocation::Cpu;
}

// Async data getter: resolves a Promise with an ArrayBuffer holding a copy of
// the tensor bytes. If releaseData is passed and true, the tensor's own copy
// is dropped afterwards (the caller keeps the only copy).
already_AddRefed<Promise> Tensor::GetData(const Optional<bool>& releaseData) {
  LOGD("{} {} type: {} size: {}", __PRETTY_FUNCTION__, fmt::ptr(this),
       mType.get(), mData.Length());
  AutoJSContext ctx;
  RefPtr<Promise> p = Promise::CreateInfallible(mGlobal);
  // The two original branches were identical except for the Clear(); build
  // the ArrayBuffer once and clear conditionally (same observable behavior).
  size_t lengthBytes = mData.Length();
  UniquePtr<uint8_t[], JS::FreePolicy> tensorData(
      js_pod_arena_malloc<uint8_t>(js::ArrayBufferContentsArena, lengthBytes));
  PodCopy(tensorData.get(), mData.Elements(), lengthBytes);
  JS::Rooted<JSObject*> data(
      ctx,
      JS::NewArrayBufferWithContents(ctx, lengthBytes, std::move(tensorData)));
  JS::Rooted<JS::Value> value(ctx, JS::ObjectValue(*data));
  p->MaybeResolve(value);
  if (releaseData.WasPassed() && releaseData.Value()) {
    mData.Clear();
  }
  return p.forget();
}

nsCString Tensor::TypeString() const { return ONNXTypeToString(Type()); }

// Map a WebIDL type string to the ONNX element-type enum. Crashes on unknown
// strings (the WebIDL layer is expected to have validated the input).
ONNXTensorElementDataType Tensor::StringToONNXDataType(
    const nsACString& aString) {
#define CASE(string, suffix)                         \
  do {                                               \
    if (aString.EqualsASCII(#string)) {              \
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_##suffix; \
    }                                                \
  } while (0);
  CASE(int4, INT4);
  CASE(uint4, UINT4);
  CASE(int8, INT8);
  CASE(uint8, UINT8);
  CASE(int16, INT16);
  CASE(uint16, UINT16);
  CASE(int32, INT32);
  CASE(uint32, UINT32);
  CASE(int64, INT64);
  CASE(uint64, UINT64);
  CASE(float16, FLOAT16);
  CASE(float32, FLOAT);
  CASE(float64, DOUBLE);
  // Accept the spellings ONNXTypeToString() emits so that
  // StringToONNXDataType(ONNXTypeToString(t)) round-trips: previously a
  // DOUBLE output tensor (mType == "double") crashed in Type().
  CASE(double, DOUBLE);
  CASE(string, STRING);
  CASE(bfloat16, BFLOAT16);
  CASE(undefined, UNDEFINED);
  CASE(bool, BOOL);
  MOZ_CRASH("Missing string to ONNX data type value");
#undef CASE
}

ONNXTensorElementDataType Tensor::Type() const {
  return StringToONNXDataType(mType);
}

// Map the ONNX element-type enum back to its canonical string spelling.
nsLiteralCString Tensor::ONNXTypeToString(
    ONNXTensorElementDataType aType) const {
  switch (aType) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED:
      return "undefined"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4:
      return "uint4"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
      return "int4"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
      return "uint8"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
      return "int8"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
      return "uint16"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
      return "int16"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
      return "int32"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
      return "int64"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
      return "uint32"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
      return "uint64"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
      return "string"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
      return "bool"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
      return "float16"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
      return "bfloat16"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
      return "float32"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
      return "double"_ns;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ:
      MOZ_CRASH("Missing ONNX data type value to string");
      break;
  }
  // Unreachable: all enum values are handled above, but the compiler wants a
  // return on this path.
  return ""_ns;
}

// Human-readable summary for logging: type, element count, byte size, dims —
// plus the dims and full values when verbose logging is enabled.
nsCString Tensor::ToString() const {
  nsCString rv;
  size_t count = mData.Length() / DataTypeSize(Type());
  rv.AppendFmt("{} {} elements, {} bytes, {} dims", mType, count,
               mData.Length(), mDims.Length());
  if (MOZ_LOG_TEST(gONNXLog, LogLevel::Verbose)) {
    rv.AppendFmt("Dims:\n");
    rv.AppendFmt("{}\n", fmt::join(mDims, ","));
    rv.AppendFmt("Values:\n");
// C-style cast intentionally strips const from Elements() in this const
// member; the span is only read.
#define CASE(onnx_type, c_type)                                           \
  case onnx_type: {                                                       \
    rv.AppendFmt("{}\n",                                                  \
                 fmt::join(Span((c_type*)mData.Elements(), count), ",")); \
    break;                                                                \
  }
    switch (Type()) {
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, uint8_t)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, uint16_t)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, int16_t)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, int8_t)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, int8_t)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, int16_t)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, uint32_t)
      CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, uint64_t)
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4:
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
        MOZ_CRASH("Not handled");
        break;
    }
#undef CASE
  }
  return rv;
}

// Size in bytes of a single element of the given ONNX type. Crashes on types
// this implementation does not store; the trailing `return 0` is unreachable.
size_t Tensor::DataTypeSize(ONNXTensorElementDataType aType) {
// Simplified from the original do{case:...}while(0) contortion — a plain
// case label with an immediate return is equivalent and clearer.
#define CASE(onnx_type, c_type) \
  case onnx_type:               \
    return sizeof(c_type);
  switch (aType) {
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, uint8_t)
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float)
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, uint16_t)
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, int16_t)
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, int8_t)
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, int8_t)
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, int16_t)
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double)
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, uint32_t)
    CASE(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, uint64_t)
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
      MOZ_CRASH("Not handled");
      break;
  }
#undef CASE
  return 0;
}

JSObject* Tensor::WrapObject(JSContext* aCx,
                             JS::Handle<JSObject*> aGivenProto) {
  return Tensor_Binding::Wrap(aCx, this, aGivenProto);
}

}  // namespace mozilla::dom