/* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ const { SmartTabGroupingManager, CLUSTER_METHODS, ANCHOR_METHODS, getBestAnchorClusterInfo, ClusterRepresentation, SMART_TAB_GROUPING_CONFIG, isSearchTab, } = ChromeUtils.importESModule( "moz-src:///browser/components/tabbrowser/SmartTabGrouping.sys.mjs" ); const { OPFS: SharedOPFS } = ChromeUtils.importESModule( "chrome://global/content/ml/OPFS.sys.mjs" ); const { TestIndexedDBCache: SharedIndexedDBCache } = ChromeUtils.importESModule( "chrome://global/content/ml/ModelHub.sys.mjs" ); const { SecurityOrchestrator: SharedSecurityOrchestrator, getSecurityOrchestrator: sharedGetSecurityOrchestrator, } = ChromeUtils.importESModule( "chrome://global/content/ml/security/SecurityOrchestrator.sys.mjs" ); const PREF_SECURITY_ENABLED = "browser.ml.security.enabled"; async function setupSecurity() { Services.prefs.setBoolPref(PREF_SECURITY_ENABLED, true); } async function teardownSecurity() { Services.prefs.clearUserPref(PREF_SECURITY_ENABLED); await SharedSecurityOrchestrator.resetForTesting(); } function createRandomBlob(blockSize = 8, count = 1) { const blocks = Array.from({ length: count }, () => Uint32Array.from( { length: blockSize / 4 }, () => Math.random() * 4294967296 ) ); return new Blob(blocks, { type: "application/octet-stream" }); } function createBlob(size = 8) { return createRandomBlob(size); } function stripLastUsed(data) { return data.map(({ lastUsed: _unusedLastUsed, ...rest }) => { return rest; }); } /** * Check whether an IndexedDB database exists without creating or deleting it. * * @param {string} name * @returns {Promise} true if it exists, false otherwise */ function indexedDBExists(name) { return new Promise((resolve, reject) => { const req = indexedDB.open(name); // no version -> avoids VersionError issues let sawUpgradeNeeded = false; req.onupgradeneeded = e => { // DB did not exist; abort so we don't create it. sawUpgradeNeeded = true; e.target.transaction.abort(); }; req.onsuccess = e => { e.target.result.close(); resolve(true); // existed }; req.onerror = () => { // If we aborted because it didn't exist, treat as "doesn't exist". if (sawUpgradeNeeded && req.error?.name === "AbortError") { resolve(false); return; } reject(req.error); }; }); } /** * Helper function to initialize the cache */ async function initializeCache() { const dbName = `modelFiles-${crypto.randomUUID()}`; await SharedIndexedDBCache.deleteDatabaseAndWait(dbName).catch(() => {}); await SharedOPFS.getDirectoryHandle(dbName, { create: true }); return await SharedIndexedDBCache.init({ dbName }); } /** * Helper function to delete the cache database */ async function deleteCache(cache) { await cache.dispose(); await SharedIndexedDBCache.deleteDatabaseAndWait(cache.dbName).catch( () => {} ); try { await SharedOPFS.remove(cache.dbName, { recursive: true }); } catch (e) { // can be empty } } /** * Checks if numbers are close up to decimalPoints decimal points * * @param {number} a * @param {number} b * @param {number} decimalPoints * @returns {boolean} True if numbers are similar */ function numberLooseEquals(a, b, decimalPoints = 2) { return a.toFixed(decimalPoints) === b.toFixed(decimalPoints); } /** * Compares two vectors up to decimalPoints decimal points * Returns true if all items the same up to decimalPoints threshold * * @param {number[]} a * @param {number[]} b * @param {number} decimalPoints * @returns {boolean} True if vectors are similar */ function vectorLooseEquals(a, b, decimalPoints = 2) { return a.every( (item, index) => item.toFixed(decimalPoints) === b[index].toFixed(decimalPoints) ); } /** * Extremely simple generator deterministic seeded list of numbers between * 0 and 1 for use of tests in place of a true random generator * * @param {number} seed * @returns {function(): number} */ function simpleNumberSequence(seed = 0) { const values = [ 0.42, 0.145, 0.5, 0.9234, 0.343, 0.1324, 0.8343, 0.534, 0.634, 0.3233, ]; let counter = Math.floor(seed) % values.length; return () => { counter = (counter + 1) % values.length; return values[counter]; }; } /** * Utility function to shuffle an array, using a random * * @param {object[]} array of items to shuffle * @param {Function} randFunc function that returns between 0 and 1 */ function shuffleArray(array, randFunc) { randFunc = randFunc ?? Math.random; for (let i = array.length - 1; i >= 0; i--) { const j = Math.floor(randFunc() * (i + 1)); [array[i], array[j]] = [array[j], array[i]]; } } /** * Returns dict that averages input values * * @param {object[]} itemArray List of dicts, each with values to average * @returns {object} Object with average of values passed in itemArray */ function averageStatsValues(itemArray) { const result = {}; if (itemArray.length === 0) { return result; } for (const key of Object.keys(itemArray[0])) { let total = 0.0; itemArray.forEach(a => (total += a[key])); result[key] = total / itemArray.length; } return result; } /** * Read tsv file from string * * @param {string} tsvString string to read from * @returns {object} Object with parsed tsv string */ function parseTsvStructured(tsvString) { const rows = tsvString.trim().split("\n"); const keys = rows[0].split("\t"); const arrayOfDicts = rows.slice(1).map(row => { const values = row.split("\t"); // Map keys to corresponding values const dict = {}; keys.forEach((key, index) => { dict[key] = values[index]; }); return dict; }); return arrayOfDicts; } /** * Read tsv string with embeddings * * @param {string} tsvString string with embeddings present * @returns {object} Object containing the embeddings */ function parseTsvEmbeddings(tsvString) { const rows = tsvString.trim().split("\n"); return rows.map(row => { return row.split("\t").map(value => parseFloat(value)); }); } /** * * @param {string} clusterMethod kmeans or kmeans with anchor * @param {string} umapMethod umap or dbscan * @param {object[]} tabs tabs to cluster * @param {object[]} embeddings precomputed embeddings for the tabs * @param {number} iterations number of iterations before stopping clustering * @param {number[]} preGroupedTabIndices indices of tabs that are present in the group * @param {string} anchorMethod fixed or drift anchor methods * @param {number} silBoost what value to multiply silhouette score * @returns {Promise<{object}>} average of metric results */ async function testAugmentGroup( clusterMethod, umapMethod, tabs, embeddings, iterations = 1, preGroupedTabIndices, anchorMethod = ANCHOR_METHODS.FIXED, silBoost = undefined ) { const groupManager = new SmartTabGroupingManager(); groupManager.setAnchorMethod(anchorMethod); if (silBoost !== undefined) { groupManager.setSilBoost(silBoost); } const randFunc = simpleNumberSequence(); groupManager.setDataTitleKey("title"); groupManager.setClusteringMethod(clusterMethod); groupManager.setDimensionReductionMethod(umapMethod); const allScores = []; for (let i = 0; i < iterations; i++) { const groupingResult = await groupManager.generateClusters( tabs, embeddings, 0, randFunc, preGroupedTabIndices ); const titleKey = "title"; const centralClusterTitles = new Set( groupingResult.getAnchorCluster().tabs.map(a => a[titleKey]) ); groupingResult.getAnchorCluster().print(); const anchorTitleSet = new Set( preGroupedTabIndices.map(a => tabs[a][titleKey]) ); Assert.equal( centralClusterTitles.intersection(anchorTitleSet).size, anchorTitleSet.size, `All anchor indices in target cluster` ); const scoreInfo = groupingResult.getAccuracyStatsForCluster( "smart_group_label", groupingResult.getAnchorCluster().tabs[0].smart_group_label ); allScores.push(scoreInfo); } return averageStatsValues(allScores); } /** * Runs clustering test with multiple anchor tabs * * @param {object[]} data tabs to run test on * @param {object []} precomputedEmbeddings embeddings for the tabs * @param {number[]} anchorGroupIndices indices of tabs already present in the group * @param {string} anchorMethod fixed or drift anchor method * @param {number} silBoost value with which to boost silhouette score * @returns {Promise<{}|null>} metric stats from running the clustering test */ async function runAnchorTabTest( data, precomputedEmbeddings = null, anchorGroupIndices, anchorMethod = ANCHOR_METHODS.FIXED, silBoost = undefined ) { const testParams = [[CLUSTER_METHODS.KMEANS]]; let scoreInfo; for (let testP of testParams) { scoreInfo = await testAugmentGroup( testP[0], testP[1], data, precomputedEmbeddings, 1, anchorGroupIndices, anchorMethod, silBoost ); } if (testParams.length === 1) { return scoreInfo; } return null; } /** * Fetches a local file from prefix and filename * * @param {string} host_prefix root data folder path * @param {string} filename name of file * @returns {Promise} */ function fetchFile(host_prefix, filename) { return new Promise((resolve, reject) => { const xhr = new XMLHttpRequest(); // const url = `${HOST_PREFIX}${filename}`; const url = `${host_prefix}${filename}`; xhr.open("GET", url, true); xhr.onload = () => { if (xhr.status === 200) { resolve(xhr.responseText); } else { reject(new Error(`Failed to fetch data: ${xhr.statusText}`)); } }; xhr.onerror = () => reject(new Error(`Network error getting ${url}`)); xhr.send(); }); } /** * Creates a mock tab object with a mocked linkedBrowser, * simulating the tab data structure * * @param {object} options * @param {string|null} options.searchURL - The value to return from getAttribute("triggeringSearchEngineURL"). * @param {string} options.currentURL - The current URI of the tab's linked browser. * @param {string|null} options.title - Title of page * @returns {object} A mock tab object shaped like a real Firefox tab for testing. */ function createMockTab({ searchURL, currentURL, title }) { return { linkedBrowser: { getAttribute(name) { if (name === "triggeringSearchEngineURL") { return searchURL; } return null; }, currentURI: { spec: currentURL, }, }, label: title, }; }