/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*-
 * vim: set ts=8 sts=2 et sw=2 tw=80:
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */

#include "intgemm/IntegerGemmIntrinsic.h"

#include "mozilla/CheckedInt.h"
#include "mozilla/IntegerPrintfMacros.h"
#include "mozilla/TimeStamp.h"

// NOTE(review): in the text this revision was prepared from, every
// angle-bracket argument list was missing: the target of the system #include
// below, the contents of the xsimd::arch_list<> definitions, the template
// parameter lists and the target types of the casts. They have been restored
// from context; please confirm each of them against a real build.
#include <gemmology_fwd.h>

#include "fmt/format.h"
#include "js/ErrorReport.h"
#include "js/HeapAPI.h"
#include "vm/ArrayBufferObject.h"
#include "vm/GeckoProfiler.h"
#include "vm/JSContext.h"
#include "wasm/WasmBuiltins.h"
#include "wasm/WasmInstance.h"
#include "wasm/WasmLog.h"

// SUPPORTED_ARCHS lists the instruction sets the kernels are compiled for,
// as selected by the USE_* build defines.
#if defined(USE_AVX512BW)
#  if defined(USE_AVX512VNNI)
#    if defined(USE_AVXVNNI)
#      define SUPPORTED_ARCHS                                                 \
        xsimd::arch_list<xsimd::avx512vnni<xsimd::avx512bw>, xsimd::avx512bw, \
                         xsimd::avxvnni, xsimd::avx2, xsimd::ssse3,           \
                         xsimd::sse2>
#    else
#      define SUPPORTED_ARCHS                                                 \
        xsimd::arch_list<xsimd::avx512vnni<xsimd::avx512bw>, xsimd::avx512bw, \
                         xsimd::avx2, xsimd::ssse3, xsimd::sse2>
#    endif
#  elif defined(USE_AVXVNNI)
#    define SUPPORTED_ARCHS                                            \
      xsimd::arch_list<xsimd::avx512bw, xsimd::avxvnni, xsimd::avx2,   \
                       xsimd::ssse3, xsimd::sse2>
#  else
#    define SUPPORTED_ARCHS \
      xsimd::arch_list<xsimd::avx512bw, xsimd::avx2, xsimd::ssse3, xsimd::sse2>
#  endif
#elif defined(USE_AVXVNNI)
#  define SUPPORTED_ARCHS \
    xsimd::arch_list<xsimd::avxvnni, xsimd::avx2, xsimd::ssse3, xsimd::sse2>
#elif defined(USE_AVX2)
#  define SUPPORTED_ARCHS \
    xsimd::arch_list<xsimd::avx2, xsimd::ssse3, xsimd::sse2>
#elif defined(USE_SSSE3)
#  define SUPPORTED_ARCHS xsimd::arch_list<xsimd::ssse3, xsimd::sse2>
#elif defined(USE_SSE2)
#  define SUPPORTED_ARCHS xsimd::arch_list<xsimd::sse2>
#elif defined(USE_NEON) and defined(XSIMD_WITH_NEON64)
#  if defined(USE_NEON_I8MM)
#    define SUPPORTED_ARCHS \
      xsimd::arch_list<xsimd::i8mm<xsimd::neon64>, xsimd::neon64>
#  else
#    define SUPPORTED_ARCHS xsimd::arch_list<xsimd::neon64>
#  endif
#else
#  error no supported architecture
#endif

// Dispatch *at runtime* based on run-time hardware and compile-time
// architectures.
//
// FIXME: Ideally we would not run the dispatch code at each function call.
#define GEMMOLOGY_DISPATCH(FUNC)                                 \
  xsimd::dispatch<SUPPORTED_ARCHS>([](auto arch, auto... args) { \
    return gemmology::Engine<decltype(arch)>::FUNC(args...);     \
  })

// Scoped helper that reports the time between its construction and its
// destruction to the Gecko profiler as an interval marker.
//
// Nothing is recorded at construction time unless the profiler is enabled; the
// marker itself is emitted from the destructor, again only if the profiler is
// enabled at that point.
//
// NOTE(review): the template parameter list, including both default
// arguments, is restored by assumption (see the note next to the includes).
template <size_t TextLength = 512, typename CharT = char>
struct AutoProfilerMarker {
  // Marker with a name only; `text` stays empty.
  AutoProfilerMarker(js::GeckoProfilerRuntime& profiler, const CharT* name)
      : profiler(profiler), name(name) {
    if (profiler.enabled()) {
      startTime = mozilla::TimeStamp::Now();
    }
  }

  // Marker with a name and a formatted description. The description is
  // formatted into the fixed-size `text` buffer, and only when the profiler is
  // enabled.
  template <typename... Args>
  AutoProfilerMarker(js::GeckoProfilerRuntime& profiler, const CharT* name,
                     fmt::format_string<Args...> aFormatStr, Args&&... aArgs)
      : profiler(profiler), name(name) {
    if (profiler.enabled()) {
      startTime = mozilla::TimeStamp::Now();
      // One byte of `text` is kept back for the terminating NUL.
      auto [out, size] = fmt::vformat_to_n(
          text, sizeof(text) - 1, aFormatStr,
          fmt::make_format_args<fmt::buffered_context<CharT>>(aArgs...));
      // `size` is the length the complete output needs, so the text was
      // truncated exactly when it exceeds the space handed to the formatter.
      // (The condition used to be inverted and fired on every marker that
      // fit into the buffer.)
      MOZ_ASSERT(size <= sizeof(text) - 1,
                 "Truncated marker, consider increasing the buffer");
      *out = 0;
    }
  }

  ~AutoProfilerMarker() {
    if (profiler.enabled()) {
      profiler.markInterval(name, startTime, text,
                            JS::ProfilingCategoryPair::JS);
    }
  }

  js::GeckoProfilerRuntime& profiler;
  const char* name;
  char text[TextLength]{};
  mozilla::TimeStamp startTime;
};

// Required alignment (in bytes) of the offsets of the int8 matrices.
static constexpr uint32_t ARRAY_ALIGNMENT = 64;
// Every matrix dimension must be a positive multiple of its multiplier.
static constexpr uint32_t ROWS_A_MULTIPLIER = 1;
static constexpr uint32_t COLUMNS_A_MULTIPLIER = 64;
static constexpr uint32_t ROWS_B_MULTIPLIER = COLUMNS_A_MULTIPLIER;
static constexpr uint32_t COLUMNS_B_MULTIPLIER = 8;
static constexpr uint32_t SELECTED_COLUMNS_B_MULTIPLIER = 8;

// Returns the byte length of the wasm raw buffer that `memBase` is the data
// pointer of.
size_t GetWasmRawBufferLength(const uint8_t* memBase) {
  const js::WasmArrayRawBuffer* rawBuf =
      js::WasmArrayRawBuffer::fromDataPtr(memBase);
  return rawBuf->byteLength();
}

// A valid size is a positive integral multiple of `sizeMultiplier`.
bool CheckMatrixDimension(uint32_t size, uint32_t sizeMultiplier) {
  return size != 0 && size % sizeMultiplier == 0;
}

// Returns true if the `inputSize` bytes starting at offset `input` lie inside
// a buffer of `wasmBufferSize` bytes.
//
// The check fails if the end offset overflows or lies beyond the end of the
// buffer. A region that ends exactly at the end of the buffer is inside it and
// is accepted (it used to be rejected by an off-by-one `>=` comparison).
bool CheckMatrixBound(uint32_t input, uint64_t inputSize,
                      size_t wasmBufferSize) {
  mozilla::CheckedUint64 inputUpperLimit(inputSize);
  inputUpperLimit += input;

  return inputUpperLimit.isValid() &&
         inputUpperLimit.value() <= static_cast<uint64_t>(wasmBufferSize);
}

// Same as CheckMatrixBound, but additionally requires `input` to be a multiple
// of ARRAY_ALIGNMENT.
bool CheckMatrixBoundAndAlignment(uint32_t input, uint64_t inputSize,
                                  size_t wasmBufferSize) {
  // Alignment check: It is sufficient to check alignment for the offset rather
  // than for the actual pointer within wasm memory (as long as following assert
  // is satisfied)
  static_assert(js::gc::PageSize >= ARRAY_ALIGNMENT,
                "PageSize should be bigger than Alignment");
  if (input % ARRAY_ALIGNMENT != 0) {
    return false;
  }

  // Check Bound
  return CheckMatrixBound(input, inputSize, wasmBufferSize);
}

// Bound check for an array of `elementCount` elements of type T starting at
// offset `input`. The element count is converted to a byte count first (with
// overflow detection), so that arrays of float or uint32_t are checked over
// their whole extent rather than over `elementCount` bytes only.
template <typename T>
static bool CheckArrayBound(uint32_t input, uint64_t elementCount,
                            size_t wasmBufferSize) {
  mozilla::CheckedUint64 byteSize =
      mozilla::CheckedUint64(elementCount) * mozilla::CheckedUint64(sizeof(T));
  return byteSize.isValid() &&
         CheckMatrixBound(input, byteSize.value(), wasmBufferSize);
}

// Bound and alignment check for an array of `elementCount` elements of type T
// starting at offset `input`; see CheckArrayBound and
// CheckMatrixBoundAndAlignment.
template <typename T>
static bool CheckArrayBoundAndAlignment(uint32_t input, uint64_t elementCount,
                                        size_t wasmBufferSize) {
  mozilla::CheckedUint64 byteSize =
      mozilla::CheckedUint64(elementCount) * mozilla::CheckedUint64(sizeof(T));
  return byteSize.isValid() &&
         CheckMatrixBoundAndAlignment(input, byteSize.value(), wasmBufferSize);
}

// The intrinsics below share the same conventions:
//  - every uint32_t "matrix" argument is a byte offset into the wasm memory
//    whose data pointer is `memBase`;
//  - dimensions, bounds and (where required) alignment are validated before
//    anything is read or written;
//  - the return value is 0 on success and -1 on a failed validation, matching
//    the FailOnNegI32 failure mode asserted at the top of each function.

// Quantizes the float matrix B (rowsB x colsB) at `inputMatrixB` with `scale`
// and writes the prepared int8 matrix to `outputMatrixB`. `zeroPoint` is
// unused.
int32_t js::intgemm::IntrI8PrepareB(wasm::Instance* instance,
                                    uint32_t inputMatrixB, float scale,
                                    float zeroPoint, uint32_t rowsB,
                                    uint32_t colsB, uint32_t outputMatrixB,
                                    uint8_t* memBase) {
  MOZ_ASSERT(wasm::SASigIntrI8PrepareB.failureMode ==
             wasm::FailureMode::FailOnNegI32);
  JSContext* cx = instance->cx();
  AutoUnsafeCallWithABI unsafe;

  // Size checks for matrices
  if (!CheckMatrixDimension(rowsB, ROWS_B_MULTIPLIER) ||
      !CheckMatrixDimension(colsB, COLUMNS_B_MULTIPLIER)) {
    return -1;
  }

  // Memory Bound and Alignment checks for matrices. The input holds floats,
  // the output int8 values.
  const uint64_t sizeB =
      static_cast<uint64_t>(rowsB) * static_cast<uint64_t>(colsB);
  const size_t wasmBufferSize = GetWasmRawBufferLength(memBase);
  if (!CheckArrayBoundAndAlignment<float>(inputMatrixB, sizeB,
                                          wasmBufferSize) ||
      !CheckArrayBoundAndAlignment<int8_t>(outputMatrixB, sizeB,
                                           wasmBufferSize)) {
    return -1;
  }

  // Actual call to the 3rd party library (intgemm) for PrepareB
  const float* inputMatrixBPtr =
      reinterpret_cast<const float*>(&memBase[inputMatrixB]);
  int8_t* outputMatrixBPtr = reinterpret_cast<int8_t*>(&memBase[outputMatrixB]);
  AutoProfilerMarker marker(cx->runtime()->geckoProfiler(),
                            "intgemm::PrepareB",
                            FMT_STRING("rowsB: {} colsB: {} sizeB: {}"), rowsB,
                            colsB, sizeB);
  GEMMOLOGY_DISPATCH(PrepareB)
  (inputMatrixBPtr, outputMatrixBPtr,
   scale,  // Quant Mult
   rowsB, colsB);
  return 0;
}

// Same as IntrI8PrepareB, but the float input at `inputMatrixBTransposed` is
// handed to the library's PrepareBTransposed. `zeroPoint` is unused.
int32_t js::intgemm::IntrI8PrepareBFromTransposed(
    wasm::Instance* instance, uint32_t inputMatrixBTransposed, float scale,
    float zeroPoint, uint32_t rowsB, uint32_t colsB, uint32_t outputMatrixB,
    uint8_t* memBase) {
  MOZ_ASSERT(wasm::SASigIntrI8PrepareBFromTransposed.failureMode ==
             wasm::FailureMode::FailOnNegI32);
  JSContext* cx = instance->cx();
  AutoUnsafeCallWithABI unsafe;

  // Size checks for matrices
  if (!CheckMatrixDimension(rowsB, ROWS_B_MULTIPLIER) ||
      !CheckMatrixDimension(colsB, COLUMNS_B_MULTIPLIER)) {
    return -1;
  }

  // Memory Bound checks for all matrices. The input holds floats, the output
  // int8 values.
  const uint64_t sizeB =
      static_cast<uint64_t>(rowsB) * static_cast<uint64_t>(colsB);
  const size_t wasmBufferSize = GetWasmRawBufferLength(memBase);
  if (!CheckArrayBoundAndAlignment<float>(inputMatrixBTransposed, sizeB,
                                          wasmBufferSize) ||
      !CheckArrayBoundAndAlignment<int8_t>(outputMatrixB, sizeB,
                                           wasmBufferSize)) {
    return -1;
  }

  // Actual call to the 3rd party library (intgemm) for PrepareBTransposed
  const float* inputMatrixBTransposedPtr =
      reinterpret_cast<const float*>(&memBase[inputMatrixBTransposed]);
  int8_t* outputMatrixBPtr = reinterpret_cast<int8_t*>(&memBase[outputMatrixB]);
  AutoProfilerMarker marker(cx->runtime()->geckoProfiler(),
                            "intgemm::PrepareBTransposed",
                            FMT_STRING("rowsB: {} colsB: {} sizeB: {}"), rowsB,
                            colsB, sizeB);
  GEMMOLOGY_DISPATCH(PrepareBTransposed)
  (inputMatrixBTransposedPtr, outputMatrixBPtr,
   scale,  // Quant Mult
   rowsB, colsB);
  return 0;
}

// Prepares B from an int8 input at `inputMatrixBQuantizedTransposed` and
// writes the int8 result to `outputMatrixB`.
int32_t js::intgemm::IntrI8PrepareBFromQuantizedTransposed(
    wasm::Instance* instance, uint32_t inputMatrixBQuantizedTransposed,
    uint32_t rowsB, uint32_t colsB, uint32_t outputMatrixB, uint8_t* memBase) {
  MOZ_ASSERT(wasm::SASigIntrI8PrepareBFromQuantizedTransposed.failureMode ==
             wasm::FailureMode::FailOnNegI32);
  JSContext* cx = instance->cx();
  AutoUnsafeCallWithABI unsafe;

  // Size checks for matrices
  if (!CheckMatrixDimension(rowsB, ROWS_B_MULTIPLIER) ||
      !CheckMatrixDimension(colsB, COLUMNS_B_MULTIPLIER)) {
    return -1;
  }

  // Memory Bound checks for all matrices. Both matrices hold int8 values.
  const uint64_t sizeB =
      static_cast<uint64_t>(rowsB) * static_cast<uint64_t>(colsB);
  const size_t wasmBufferSize = GetWasmRawBufferLength(memBase);
  if (!CheckArrayBoundAndAlignment<int8_t>(inputMatrixBQuantizedTransposed,
                                           sizeB, wasmBufferSize) ||
      !CheckArrayBoundAndAlignment<int8_t>(outputMatrixB, sizeB,
                                           wasmBufferSize)) {
    return -1;
  }

  // Actual call to the 3rd party library (intgemm)
  const int8_t* inputMatrixBQuantizedTransposedPtr =
      reinterpret_cast<const int8_t*>(
          &memBase[inputMatrixBQuantizedTransposed]);
  int8_t* outputMatrixBPtr = reinterpret_cast<int8_t*>(&memBase[outputMatrixB]);
  AutoProfilerMarker marker(cx->runtime()->geckoProfiler(),
                            "intgemm::PrepareBQuantizedTransposed",
                            FMT_STRING("rowsB: {}, colsB: {}"), rowsB, colsB);
  GEMMOLOGY_DISPATCH(PrepareBQuantizedTransposed)
  (inputMatrixBQuantizedTransposedPtr, outputMatrixBPtr, rowsB, colsB);
  return 0;
}

// Quantizes the float matrix A (rowsA x colsA) at `inputMatrixA` with `scale`
// and writes the prepared uint8 matrix to `outputMatrixA`. `zeroPoint` is
// unused.
int32_t js::intgemm::IntrI8PrepareA(wasm::Instance* instance,
                                    uint32_t inputMatrixA, float scale,
                                    float zeroPoint, uint32_t rowsA,
                                    uint32_t colsA, uint32_t outputMatrixA,
                                    uint8_t* memBase) {
  MOZ_ASSERT(wasm::SASigIntrI8PrepareA.failureMode ==
             wasm::FailureMode::FailOnNegI32);
  JSContext* cx = instance->cx();
  AutoUnsafeCallWithABI unsafe;

  // Size checks for matrices
  if (!CheckMatrixDimension(rowsA, ROWS_A_MULTIPLIER) ||
      !CheckMatrixDimension(colsA, COLUMNS_A_MULTIPLIER)) {
    return -1;
  }

  // Memory Bound checks for all matrices. The input holds floats, the output
  // uint8 values.
  const uint64_t sizeA =
      static_cast<uint64_t>(rowsA) * static_cast<uint64_t>(colsA);
  const size_t wasmBufferSize = GetWasmRawBufferLength(memBase);
  if (!CheckArrayBoundAndAlignment<float>(inputMatrixA, sizeA,
                                          wasmBufferSize) ||
      !CheckArrayBoundAndAlignment<uint8_t>(outputMatrixA, sizeA,
                                            wasmBufferSize)) {
    return -1;
  }

  // Actual call to the 3rd party library (intgemm)
  const float* inputMatrixAPtr =
      reinterpret_cast<const float*>(&memBase[inputMatrixA]);
  uint8_t* outputMatrixAPtr = &memBase[outputMatrixA];
  AutoProfilerMarker marker(cx->runtime()->geckoProfiler(),
                            "intgemm::PrepareA",
                            FMT_STRING("rowsA: {}, colsA: {}"), rowsA, colsA);
  GEMMOLOGY_DISPATCH(Shift::PrepareA)
  (inputMatrixAPtr, outputMatrixAPtr, scale, rowsA, colsA);
  return 0;
}

// Prepares the bias for a prepared int8 matrix B (rowsB x colsB) and writes
// `colsB` floats to `output`. If `inputBias` is non-zero it is the offset of
// `colsB` floats that are added to the result; an offset of 0 means "no input
// bias". The zero point arguments are unused.
int32_t js::intgemm::IntrI8PrepareBias(
    wasm::Instance* instance, uint32_t inputMatrixBPrepared, float scaleA,
    float zeroPointA, float scaleB, float zeroPointB, uint32_t rowsB,
    uint32_t colsB, uint32_t inputBias, uint32_t output, uint8_t* memBase) {
  MOZ_ASSERT(wasm::SASigIntrI8PrepareBias.failureMode ==
             wasm::FailureMode::FailOnNegI32);
  JSContext* cx = instance->cx();
  AutoUnsafeCallWithABI unsafe;

  // Size checks for matrices
  if (!CheckMatrixDimension(rowsB, ROWS_B_MULTIPLIER) ||
      !CheckMatrixDimension(colsB, COLUMNS_B_MULTIPLIER)) {
    return -1;
  }

  // Memory Bound checks for all matrices. B holds int8 values, the bias
  // arrays hold floats.
  const uint64_t sizeB =
      static_cast<uint64_t>(rowsB) * static_cast<uint64_t>(colsB);
  const uint64_t sizeBias = colsB;
  const size_t wasmBufferSize = GetWasmRawBufferLength(memBase);
  if (!CheckArrayBoundAndAlignment<int8_t>(inputMatrixBPrepared, sizeB,
                                           wasmBufferSize) ||
      !CheckArrayBound<float>(output, sizeBias, wasmBufferSize)) {
    return -1;
  }

  // Actual call to the 3rd party library (intgemm)
  const int8_t* inputMatrixBPreparedPtr =
      reinterpret_cast<const int8_t*>(&memBase[inputMatrixBPrepared]);
  float* outputPtr = reinterpret_cast<float*>(&memBase[output]);
  float unquantFactor =
      (-1) * ((127.0f / scaleA) * (127.0f / scaleB)) / (127.0f);

  if (inputBias) {
    if (!CheckArrayBound<float>(inputBias, sizeBias, wasmBufferSize)) {
      return -1;
    }
    const float* inputBiasPtr =
        reinterpret_cast<const float*>(&memBase[inputBias]);

    AutoProfilerMarker marker(cx->runtime()->geckoProfiler(),
                              "intgemm::PrepareBias w/ input bias",
                              FMT_STRING("rowsB: {} colsB: {} sizeB: {}"),
                              rowsB, colsB, sizeB);
    GEMMOLOGY_DISPATCH(Shift::PrepareBias)
    (inputMatrixBPreparedPtr, rowsB, colsB,
     gemmology::callbacks::UnquantizeAndAddBiasAndWrite(
         unquantFactor, inputBiasPtr, outputPtr));
  } else {
    AutoProfilerMarker marker(cx->runtime()->geckoProfiler(),
                              "intgemm::PrepareBias",
                              FMT_STRING("rowsB: {} colsB: {} sizeB: {}"),
                              rowsB, colsB, sizeB);
    GEMMOLOGY_DISPATCH(Shift::PrepareBias)
    (inputMatrixBPreparedPtr, rowsB, colsB,
     gemmology::callbacks::UnquantizeAndWrite(unquantFactor, outputPtr));
  }
  return 0;
}

// Multiplies the prepared uint8 matrix A (rowsA x width) with the prepared
// int8 matrix B (width x colsB), unquantizes the result with
// unquantMultiplier / (scaleA * scaleB), adds the prepared bias (`colsB`
// floats) and writes rowsA x colsB floats to `output`. The zero point
// arguments are unused.
int32_t js::intgemm::IntrI8MultiplyAndAddBias(
    wasm::Instance* instance, uint32_t inputMatrixAPrepared, float scaleA,
    float zeroPointA, uint32_t inputMatrixBPrepared, float scaleB,
    float zeroPointB, uint32_t inputBiasPrepared, float unquantMultiplier,
    uint32_t rowsA, uint32_t width, uint32_t colsB, uint32_t output,
    uint8_t* memBase) {
  MOZ_ASSERT(wasm::SASigIntrI8MultiplyAndAddBias.failureMode ==
             wasm::FailureMode::FailOnNegI32);
  JSContext* cx = instance->cx();
  AutoUnsafeCallWithABI unsafe;

  // Size checks for matrices
  if (!CheckMatrixDimension(rowsA, ROWS_A_MULTIPLIER) ||
      !CheckMatrixDimension(width, COLUMNS_A_MULTIPLIER) ||
      !CheckMatrixDimension(colsB, COLUMNS_B_MULTIPLIER)) {
    return -1;
  }

  // Memory Bound checks for all matrices. A and B hold 8-bit values, the bias
  // and the output hold floats.
  const uint64_t sizeA =
      static_cast<uint64_t>(rowsA) * static_cast<uint64_t>(width);
  const uint64_t sizeB =
      static_cast<uint64_t>(width) * static_cast<uint64_t>(colsB);
  const uint64_t sizeBias = static_cast<uint64_t>(colsB);
  const uint64_t sizeOutput =
      static_cast<uint64_t>(rowsA) * static_cast<uint64_t>(colsB);
  const size_t wasmBufferSize = GetWasmRawBufferLength(memBase);
  if (!CheckArrayBoundAndAlignment<uint8_t>(inputMatrixAPrepared, sizeA,
                                            wasmBufferSize) ||
      !CheckArrayBoundAndAlignment<int8_t>(inputMatrixBPrepared, sizeB,
                                           wasmBufferSize) ||
      !CheckArrayBound<float>(inputBiasPrepared, sizeBias, wasmBufferSize) ||
      !CheckArrayBound<float>(output, sizeOutput, wasmBufferSize)) {
    return -1;
  }

  // Actual call to the 3rd party library (intgemm)
  const uint8_t* inputMatrixAPreparedPtr = &memBase[inputMatrixAPrepared];
  const int8_t* inputMatrixBPreparedPtr =
      reinterpret_cast<const int8_t*>(&memBase[inputMatrixBPrepared]);
  const float* inputBiasPreparedPtr =
      reinterpret_cast<const float*>(&memBase[inputBiasPrepared]);
  float* outputPtr = reinterpret_cast<float*>(&memBase[output]);
  float unquantFactor = unquantMultiplier / (scaleA * scaleB);

  // The third value reported is colsB (it used to be labelled "colsA").
  AutoProfilerMarker marker(cx->runtime()->geckoProfiler(),
                            "intgemm::Shift::Multiply",
                            FMT_STRING("rowsA: {}, width: {}, colsB: {}"),
                            rowsA, width, colsB);
  GEMMOLOGY_DISPATCH(Shift::Multiply)
  (inputMatrixAPreparedPtr, inputMatrixBPreparedPtr, rowsA, width, colsB,
   gemmology::callbacks::UnquantizeAndAddBiasAndWrite(
       unquantFactor, inputBiasPreparedPtr, outputPtr));
  return 0;
}

// Selects columns of the prepared int8 matrix B (rowsB x colsB). The
// `sizeColIndexList` uint32 column indices are read from `colIndexList` and
// the int8 result (rowsB x sizeColIndexList) is written to `output`.
int32_t js::intgemm::IntrI8SelectColumnsOfB(wasm::Instance* instance,
                                            uint32_t inputMatrixBPrepared,
                                            uint32_t rowsB, uint32_t colsB,
                                            uint32_t colIndexList,
                                            uint32_t sizeColIndexList,
                                            uint32_t output,
                                            uint8_t* memBase) {
  MOZ_ASSERT(wasm::SASigIntrI8SelectColumnsOfB.failureMode ==
             wasm::FailureMode::FailOnNegI32);
  JSContext* cx = instance->cx();
  AutoUnsafeCallWithABI unsafe;

  // Size checks for matrices
  if (!CheckMatrixDimension(rowsB, ROWS_B_MULTIPLIER) ||
      !CheckMatrixDimension(colsB, COLUMNS_B_MULTIPLIER) ||
      !CheckMatrixDimension(sizeColIndexList, SELECTED_COLUMNS_B_MULTIPLIER)) {
    return -1;
  }

  // Memory Bound checks for all matrices. B and the output hold int8 values,
  // the index list holds uint32 values.
  const uint64_t sizeB =
      static_cast<uint64_t>(rowsB) * static_cast<uint64_t>(colsB);
  const uint64_t sizeOutput =
      static_cast<uint64_t>(rowsB) * static_cast<uint64_t>(sizeColIndexList);
  const size_t wasmBufferSize = GetWasmRawBufferLength(memBase);
  if (!CheckArrayBoundAndAlignment<int8_t>(inputMatrixBPrepared, sizeB,
                                           wasmBufferSize) ||
      !CheckArrayBound<uint32_t>(colIndexList, sizeColIndexList,
                                 wasmBufferSize) ||
      !CheckArrayBound<int8_t>(output, sizeOutput, wasmBufferSize)) {
    return -1;
  }

  // NOTE(review): the column indices themselves come from wasm memory and are
  // not validated against colsB here; confirm whether the library tolerates
  // out-of-range indices or whether they should be rejected before the call.

  // Actual call to the 3rd party library (intgemm)
  const int8_t* inputMatrixBPreparedPtr =
      reinterpret_cast<const int8_t*>(&memBase[inputMatrixBPrepared]);
  const uint32_t* colIndexListPtr =
      reinterpret_cast<const uint32_t*>(&memBase[colIndexList]);
  int8_t* outputPtr = reinterpret_cast<int8_t*>(&memBase[output]);
  AutoProfilerMarker marker(
      cx->runtime()->geckoProfiler(), "intgemm::SelectColumnsB",
      FMT_STRING("rowsB: {} colsB: {} sizecolList: {}, sizeB: {}"), rowsB,
      colsB, sizeColIndexList, sizeB);
  GEMMOLOGY_DISPATCH(SelectColumnsB)
  (inputMatrixBPreparedPtr, outputPtr, rowsB, colIndexListPtr,
   colIndexListPtr + sizeColIndexList);
  return 0;
}

#undef GEMMOLOGY_DISPATCH
#undef SUPPORTED_ARCHS