/* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ #ifndef DOM_TENSOR_H_ #define DOM_TENSOR_H_ #include "js/TypeDecls.h" #include "mozilla/ErrorResult.h" #include "mozilla/dom/BindingDeclarations.h" #include "mozilla/dom/ONNXBinding.h" #include "mozilla/dom/onnxruntime_c_api.h" #include "nsCycleCollectionParticipant.h" #include "nsIGlobalObject.h" #include "nsWrapperCache.h" namespace mozilla::dom { class Promise; class Tensor final : public nsISupports, public nsWrapperCache { public: NS_DECL_CYCLE_COLLECTING_ISUPPORTS NS_DECL_CYCLE_COLLECTION_WRAPPERCACHE_CLASS(Tensor) public: // Used when created from js using a regular js array, containing numbers. Tensor(const GlobalObject& global, const nsACString& type, const nsTArray& data, const Sequence& dims); // Used when created from JS, e.g. input tensor, with a type array (it can be // of any type) Tensor(const GlobalObject& global, const nsACString& type, const ArrayBufferView& data, const Sequence& dims); // Used when created from C++, e.g. output tensor Tensor(const GlobalObject& aGlobal, ONNXTensorElementDataType aType, nsTArray aData, nsTArray aDims); static already_AddRefed Constructor( const GlobalObject& global, const nsACString& type, const ArrayBufferViewOrAnySequence& data, const Sequence& dims, ErrorResult& aRv); protected: ~Tensor() = default; public: nsIGlobalObject* GetParentObject() const { return mGlobal; }; JSObject* WrapObject(JSContext* aCx, JS::Handle aGivenProto) override; void GetDims(nsTArray& aRetVal); void SetDims(const nsTArray& aVal); void GetType(nsCString& aRetVal) const; void GetData(JSContext* cx, JS::MutableHandle aRetVal) const; TensorDataLocation Location() const; already_AddRefed GetData(const Optional& releaseData); void Dispose(); uint8_t* Data() { return mData.Elements(); } size_t Size() { return mData.Length(); } int32_t* Dims() { return mDims.Elements(); } size_t DimsSize() { return mDims.Length(); } ONNXTensorElementDataType Type() const; nsCString TypeString() const; nsLiteralCString ONNXTypeToString(ONNXTensorElementDataType aType) const; nsCString ToString() const; static ONNXTensorElementDataType StringToONNXDataType( const nsACString& aString); static size_t DataTypeSize(ONNXTensorElementDataType aType); private: nsCOMPtr mGlobal; nsCString mType; nsTArray mData; nsTArray mDims; }; } // namespace mozilla::dom #endif // DOM_TENSOR_H_