/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs";
import {
  createEngine,
  FEATURES,
} from "chrome://global/content/ml/EngineProcess.sys.mjs";
import {
  cosSim,
  KeywordExtractor,
} from "chrome://global/content/ml/NLPUtils.sys.mjs";
import {
  computeCentroidFrom2DArray,
  computeRandScore,
  euclideanDistance,
  getAccuracyStats,
  kmeansPlusPlus,
  silhouetteCoefficients,
} from "chrome://global/content/ml/ClusterAlgos.sys.mjs";
import { AIFeature } from "chrome://global/content/ml/AIFeature.sys.mjs";

const lazy = {};

ChromeUtils.defineESModuleGetters(lazy, {
  NLP: "resource://gre/modules/NLP.sys.mjs",
  MLEngineParent: "resource://gre/actors/MLEngineParent.sys.mjs",
  MultiProgressAggregator: "chrome://global/content/ml/Utils.sys.mjs",
  Progress: "chrome://global/content/ml/Utils.sys.mjs",
  MLUninstallService: "chrome://global/content/ml/Utils.sys.mjs",
});

// Pref value meaning "use the newest model". Any other value of the
// *ModelRevision prefs below is passed to the engine as a pinned revision
// (see SmartTabGroupingManager.getUpdatedInitData).
const LATEST_MODEL_REVISION = "latest";

// Methods for suggesting tabs that are similar to current tab.
// smartTabGroupingForGroup picks one via the
// browser.tabs.groups.smart.suggestOtherTabsMethod pref.
export const SUGGEST_OTHER_TABS_METHODS = {
  KMEANS_WITH_ANCHOR: "KMEANS_WITH_ANCHOR",
  NEAREST_NEIGHBOR: "NEAREST_NEIGHBOR",
  LOGISTIC_REGRESSION: "LOGISTIC_REGRESSION",
};

XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "suggestOtherTabsMethod",
  "browser.tabs.groups.smart.suggestOtherTabsMethod"
);

XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "topicModelRevision",
  "browser.tabs.groups.smart.topicModelRevision"
);

XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "embeddingModelRevision",
  "browser.tabs.groups.smart.embeddingModelRevision"
);

// Similarity threshold in thousandths: findNearestNeighbors compares
// cosine similarity against (this value / 1000).
XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "nearestNeighborThresholdInt",
  "browser.tabs.groups.smart.nearestNeighborThresholdInt"
);

// Key under which _prepareTabData stores the text that gets embedded.
const EMBED_TEXT_KEY = "combined_text";
export const CLUSTER_METHODS = {
  KMEANS: "KMEANS",
};

// Methods for finding similar items for an existing cluster
export const ANCHOR_METHODS = {
  DRIFT: "DRIFT", // We let k-means clustering run, and find the cluster with the most anchor items
  FIXED: "FIXED", // We always group with the anchor items in the 0 cluster, and never let them be reassigned
};

// Methods for handling tabs that already belong to other groups when clustering
export const PREGROUPED_HANDLING_METHODS = {
  EXCLUDE: "EXCLUDE", // Already-grouped tab indices are passed to kmeansPlusPlus as preassignedIndices (presumably keeping them out of the new clusters; see ClusterAlgos)
  IGNORE: "IGNORE", // No special handling: already-grouped tabs are clustered like any other tab
};
const EXPECTED_TOPIC_MODEL_OBJECTS = 6;
const EXPECTED_EMBEDDING_MODEL_OBJECTS = 4;
const MAX_NON_SUMMARIZED_SEARCH_LENGTH = 26;

export const DIM_REDUCTION_METHODS = {};
// Max silhouette penalty (DRIFT anchor mode), scaled by the fraction of
// anchor items that end up outside the chosen cluster.
const MISSING_ANCHOR_IN_CLUSTER_PENALTY = 0.2;
// Only this many tabs of the target group are used as anchors.
const MAX_GROUPED_TABS = 3;
const MAX_SUGGESTED_TABS = 10;
// limit number of tabs to be processed so inference process doesn't crash
const MAX_TABS_TO_PROCESS = 300;

const DISSIMILAR_TAB_LABEL = "none";
const ADULT_TAB_LABEL = "adult content";
const LABELS_TO_EXCLUDE = [DISSIMILAR_TAB_LABEL, ADULT_TAB_LABEL];

const ML_TASK_FEATURE_EXTRACTION = "feature-extraction";
const ML_TASK_TEXT2TEXT = "text2text-generation";

const STG_FEATURE_ID = "smart-tab-grouping";
const STG_EMBEDDING_FEATURE_ID = "smart-tab-embedding";
const STG_TOPIC_FEATURE_ID = "smart-tab-topic";

const LABEL_REASONS = {
  DEFAULT: "DEFAULT",
  LOW_CONFIDENCE: "LOW_CONFIDENCE",
  EXCLUDE: "EXCLUDE",
  ERROR: "ERROR",
};

export const SMART_TAB_GROUPING_CONFIG = {
  embedding: {
    dtype: "q8",
    timeoutMS: 2 * 60 * 1000, // 2 minutes
    taskName: ML_TASK_FEATURE_EXTRACTION,
    featureId: STG_EMBEDDING_FEATURE_ID,
    engineId: FEATURES[STG_EMBEDDING_FEATURE_ID].engineId,
    backend: "onnx-native",
    fallbackBackend: "onnx", // used by _createMLEngine if the primary backend fails
  },
  topicGeneration: {
    dtype: "q8",
    timeoutMS: 2 * 60 * 1000, // 2 minutes
    taskName: ML_TASK_TEXT2TEXT,
    featureId: STG_TOPIC_FEATURE_ID,
    engineId: FEATURES[STG_TOPIC_FEATURE_ID].engineId,
    backend: "onnx-native",
    fallbackBackend:
      "onnx",
  },
  dataConfig: {
    titleKey: "label",
    descriptionKey: "description",
  },
  clustering: {
    dimReductionMethod: null, // Not completed.
    clusterImplementation: CLUSTER_METHODS.KMEANS,
    clusteringTriesPerK: 3,
    anchorMethod: ANCHOR_METHODS.FIXED,
    pregroupedHandlingMethod: PREGROUPED_HANDLING_METHODS.EXCLUDE,
    pregroupedSilhouetteBoost: 2, // Relative weight of the cluster's score and all other cluster's combined
    // NOTE(review): smartTabGroupingForGroup selects its method from the
    // suggestOtherTabsMethod pref, not from this config entry.
    suggestOtherTabsMethod: SUGGEST_OTHER_TABS_METHODS.NEAREST_NEIGHBOR,
  },
};

// these parameters were generated by training a logistic regression
// model on synthetic data. see https://github.com/mozilla/smart-tab-grouping
// and https://github.com/mozilla/smart-tab-grouping/pull/12 for more info
const LOGISTIC_REGRESSION_PARAMS = {
  // Logistic WITH group name
  // Features: s_gc, s_tt_max, s_dd in [0, 1]
  TITLE_WITH_GROUP_NAME: {
    GROUP_SIMILARITY_WEIGHT: 0.10249,
    TITLE_SIMILARITY_WEIGHT: 0.54897,
    DOMAIN_SIMILARITY_WEIGHT: 0.34854,
    INTERCEPT: -0.07397,
    THRESHOLD: 0.59,
  },
  // Logistic WITHOUT group name
  // Features: s_tt_max, s_dd in [0, 1]
  TITLE_ONLY: {
    GROUP_SIMILARITY_WEIGHT: 0, // unused in this variant
    TITLE_SIMILARITY_WEIGHT: 0.92513,
    DOMAIN_SIMILARITY_WEIGHT: 0.07487,
    INTERCEPT: -2.58574,
    THRESHOLD: 0.123,
  },
};

// Tabs at these URLs are never offered as suggestions (see getTabsToSuggest).
const TAB_URLS_TO_EXCLUDE = [
  "about:newtab",
  "about:home",
  "about:privatebrowsing",
  "chrome://browser/content/blanktab.html",
  "about:firefoxview",
  "about:opentabs",
];

const TITLE_DELIMETER_SET = new Set(["-", "|", "—"]);

/**
 * For a given set of clusters represented by indices, returns the index of the cluster
 * that has the most anchor items inside it.
 *
 * An anchor item is an index that represents the index to a tab that is already grouped and in
 * the cluster we're interested in finding more items for.
 *
 * @param {number[][]} groupIndices - Array of clusters represented as arrays of indices.
 * @param {number[]} anchorItems - Array of anchor item indices.
* @returns {{anchorClusterIndex: number, numAnchorItemsInCluster: number}} Index of best cluster and the number of anchor items. */ export function getBestAnchorClusterInfo(groupIndices, anchorItems) { const anchorItemSet = new Set(anchorItems); const numItemsList = groupIndices.map(g => g.reduce( (cur, itemIndex) => (anchorItemSet.has(itemIndex) ? cur + 1 : cur), 0 ) ); const anchorClusterIndex = numItemsList.indexOf(Math.max(...numItemsList)); const numAnchorItemsInCluster = numItemsList[anchorClusterIndex]; return { anchorClusterIndex, numAnchorItemsInCluster }; } /** * Check tab to see if it's a search page * * @param {object} tab * @returns {boolean} Returns true if the tab is a web search from the Firefox search UI and the user is still on the original page. * Changes in search query after search is made is supported. * Returns false if user started from a hompepage of a site rather than the New Tab / browser UI */ export function isSearchTab(tab) { const linkedBrowser = tab?.linkedBrowser; if (!linkedBrowser) { return false; } const searchURL = linkedBrowser.getAttribute("triggeringSearchEngineURL"); const curURL = linkedBrowser.currentURI?.spec; if (!searchURL) { return false; } const queryFieldsMarker = searchURL.indexOf("?"); if ( queryFieldsMarker > 0 && searchURL.substring(0, queryFieldsMarker) === curURL.substring(0, queryFieldsMarker) ) { return true; } return false; } export class SmartTabGroupingManager extends AIFeature { /** * Creates the SmartTabGroupingManager object. 
*
   * If no config is passed, a deep copy (structuredClone) of
   * SMART_TAB_GROUPING_CONFIG is used, so the setter methods, which mutate
   * this.config, never alter the shared module-level default. A config that
   * is passed in is used as-is (not copied).
   *
   * @param {object} config configuration options
   */
  constructor(config) {
    super();
    this.config = config || structuredClone(SMART_TAB_GROUPING_CONFIG);
  }

  /**
   * Returns id associated with Smart tab grouping
   *
   * @return {string}
   */
  static get id() {
    return STG_FEATURE_ID;
  }

  /**
   * Re-enables smart tab grouping by setting the enabled, userEnabled and
   * optin prefs to true.
   */
  static async enable() {
    Services.prefs.setBoolPref("browser.tabs.groups.smart.enabled", true);
    Services.prefs.setBoolPref("browser.tabs.groups.smart.userEnabled", true);
    Services.prefs.setBoolPref("browser.tabs.groups.smart.optin", true);
  }

  /**
   * Disables user prefs for smart tab grouping and deletes local models
   *
   * @return {Promise}
   */
  static async block() {
    // disable prefs associated with stg, including the opt-in pref.
    // NOTE(review): an older comment here said the opt-in state is kept so the
    // user would not have to go through the opt-in flow twice, but this method
    // does reset browser.tabs.groups.smart.optin to false; enable() sets it
    // back to true.
    Services.prefs.setBoolPref("browser.tabs.groups.smart.enabled", false);
    Services.prefs.setBoolPref("browser.tabs.groups.smart.userEnabled", false);
    Services.prefs.setBoolPref("browser.tabs.groups.smart.optin", false);
    // delete models associated with stg
    await SmartTabGroupingManager.deleteSmartTabModels();
  }

  /**
   * Checks if STG feature is enabled based on various prefs
   * that it depends on
   *
   * @return {boolean} true only when browser.ml.enable and all three smart
   *   tab grouping prefs (enabled, userEnabled, optin) are true
   */
  static get isEnabled() {
    // note that both `browser.tabs.groups.smart.enabled` and
    // `browser.tabs.groups.smart.userEnabled` disable the UI (see isBlocked)
    // but `browser.tabs.groups.smart.optin` does not
    return (
      Services.prefs.getBoolPref("browser.ml.enable") &&
      Services.prefs.getBoolPref("browser.tabs.groups.smart.enabled") &&
      Services.prefs.getBoolPref("browser.tabs.groups.smart.userEnabled") &&
      Services.prefs.getBoolPref("browser.tabs.groups.smart.optin")
    );
  }

  /**
   * Checks for other conditions for smart tab grouping to be turned on,
   * e.g. locale.
   *
   * @return {boolean} true when the app locale is an English locale
   */
  static get isAllowed() {
    return Services.locale.appLocaleAsBCP47.startsWith("en");
  }

  /**
   * Resets smart tab grouping to its default state where UI is visible
   * and user opt-in is required
   */
  static async makeAvailable() {
    // Set explicitly rather than clearing, so that a non-locked policy default
    // of "blocked" does not prevent the user from switching back to "available".
    Services.prefs.setBoolPref("browser.tabs.groups.smart.enabled", true);
    Services.prefs.setBoolPref("browser.tabs.groups.smart.userEnabled", true);
    // reverts optin to its default value (presumably false, so the user must
    // opt in again; confirm against the pref's default)
    Services.prefs.clearUserPref("browser.tabs.groups.smart.optin");
    // remove local models
    await SmartTabGroupingManager.deleteSmartTabModels();
  }

  /**
   * Checks if UI is hidden
   *
   * @return {boolean} true if either the enabled or the userEnabled pref is false
   */
  static get isBlocked() {
    return (
      !Services.prefs.getBoolPref("browser.tabs.groups.smart.enabled") ||
      !Services.prefs.getBoolPref("browser.tabs.groups.smart.userEnabled")
    );
  }

  /**
   * Checks if the feature is managed by enterprise policy.
   *
   * @return {boolean} true when the userEnabled pref is locked
   */
  static get isManagedByPolicy() {
    return Services.prefs.prefIsLocked("browser.tabs.groups.smart.userEnabled");
  }

  /**
   * Deletes model artifacts associated with Smart Tab Grouping
   * (both the topic and the embedding engines).
   *
   * @return {Promise}
   */
  static async deleteSmartTabModels() {
    const engineIds = [
      FEATURES[STG_TOPIC_FEATURE_ID].engineId,
      FEATURES[STG_EMBEDDING_FEATURE_ID].engineId,
    ];
    // Remove all ML Engine files associated with this feature.
    await lazy.MLUninstallService.uninstall({
      engineIds,
      // Used only for attribution/telemetry; the specific value is not significant.
actor: "SmartTabGrouping", }); } /** * * @param {MLEngine} engine the engine to check * @return {boolean} true if the engine has not been initialized or closed */ static isEngineClosed(engine) { return !engine || engine?.engineStatus === "closed"; } /** * Initializes the embedding engine by running a test request * This helps remove the init latency */ async initEmbeddingEngine() { if (!SmartTabGroupingManager.isEngineClosed(this.embeddingEngine)) { return; } try { this.embeddingEngine = await this._createMLEngine(this.config.embedding); const request = { args: ["Test"], options: { pooling: "mean", normalize: true }, }; this.embeddingEngine.run(request); } catch (e) {} } /** * Generates tabs to process with a limit. First MAX_GROUPED_TABS are tabs that are * present in the group of the anchor tab. The remaining "ungrouped" tabs fill the * slots up to MAX_TABS_TO_PROCESS * * @param {Array} tabsInGroup active tabs in anchor group we are adding tabs to * @param {Array} allTabs list of tabs from gbrowser, some of which may be grouped in other groups * @param {number} max_limit_to_process max number of tabs we want to process as part of the flow * @returns a list of suggested new tabs. If no new tabs are suggested an empty list is returned. 
*/ getTabsToProcess( tabsInGroup, allTabs, max_limit_to_process = MAX_TABS_TO_PROCESS ) { const seen = new Set(); let tabsToProcess = []; const shouldInclude = tab => { if (tab.pinned) { return false; } if (!tab?.linkedBrowser?.currentURI?.spec) { return false; } return true; }; // include tabs in the anchor group first for (const tab of tabsInGroup) { if (!shouldInclude(tab)) { continue; } if (!seen.has(tab)) { // make sure we have "seen" all the // tabs already in the current group seen.add(tab); tabsToProcess.push(tab); } } // when generating embeddings, we only look at the first MAX_GROUPED_TABS // so use that limit here tabsToProcess = tabsToProcess.slice(0, MAX_GROUPED_TABS); // fill remaining slots with ungrouped tabs from the window for (const tab of allTabs) { if (tabsToProcess.length >= max_limit_to_process) { break; } if (!shouldInclude(tab)) { continue; } if (!seen.has(tab)) { seen.add(tab); tabsToProcess.push(tab); } } return tabsToProcess; } /** * Generates suggested tabs for an existing or provisional group * * @param {object} group active group we are adding tabs to * @param {Array} tabs list of tabs from gbrowser, some of which may be grouped in other groups * @returns a list of suggested new tabs. If no new tabs are suggested an empty list is returned. */ async smartTabGroupingForGroup(group, tabs) { // Add tabs to suggested group const groupTabs = group.tabs; const allTabs = this.getTabsToProcess(groupTabs, tabs, MAX_TABS_TO_PROCESS); // first (1 up to MAX_GROUPED_TABS) are tabs in the group const groupIndices = []; for (let i = 0; i < MAX_GROUPED_TABS; i++) { if (groupTabs.includes(allTabs[i])) { groupIndices.push(i); } } // find tabs that are part of other groups const alreadyGroupedIndices = allTabs .map((t, i) => (t.group ? 
i : -1)) .filter(a => a >= 0); let suggestedTabs; switch (lazy.suggestOtherTabsMethod) { case SUGGEST_OTHER_TABS_METHODS.KMEANS_WITH_ANCHOR: suggestedTabs = await this.generateClusters( allTabs, null, null, null, groupIndices, alreadyGroupedIndices ).then(clusters => { if (!clusters) { return []; } const targetCluster = clusters.clusterRepresentations.find(c => groupTabs.some(g => c.tabs.includes(g)) ); if (targetCluster) { // Return only tabs not already grouped return targetCluster.tabs.filter(t => !t.group); } return []; }); break; case SUGGEST_OTHER_TABS_METHODS.LOGISTIC_REGRESSION: suggestedTabs = await this.findSimilarTabsLogisticRegression({ allTabs, groupedIndices: groupIndices, alreadyGroupedIndices, groupLabel: group?.label, }); break; case SUGGEST_OTHER_TABS_METHODS.NEAREST_NEIGHBOR: default: // find nearest neighbors to current group suggestedTabs = await this.findNearestNeighbors({ allTabs, groupedIndices: groupIndices, alreadyGroupedIndices, groupLabel: group?.label, }); } return suggestedTabs.slice(0, MAX_SUGGESTED_TABS); } /** * Get tabs that need to be included in suggestions * * @param {Array} allTabs all tabs that are part of the window * @param {Array} groupedIndices indices of tabs that are already part of the group * @param {Array} alreadyGroupedIndices indices of tabs that are part of other groups * @returns {Array} tabs indices to be considered for suggestions */ getTabsToSuggest(allTabs, groupedIndices, alreadyGroupedIndices) { // tabs to be excluded // indices of all tabs that should be excluded (with duplicates) const tabURLIndicesToExclude = allTabs .map((at, index) => (TAB_URLS_TO_EXCLUDE.includes(at.url) ? 
index : -1)) .filter(index => index !== -1); const excludedTabIndices = [ ...groupedIndices, ...alreadyGroupedIndices, ...tabURLIndicesToExclude, ]; // tabs to be included return allTabs .map((_, index) => index) .filter(i => !excludedTabIndices.includes(i)); } /** * Generates similar tabs a grouped list of tabs * * @param {Array} allTabs all tabs that are part of the window * @param {Array} groupedIndices indices of tabs that are already part of the group * @param {Array} alreadyGroupedIndices indices of tabs that are part of other groups * @param {string} groupLabel name of group if present * @param {number} threshold for nearest neighbor similarity * @returns a list of suggested tabs that are similar to the groupedIndices tabs */ async findNearestNeighbors({ allTabs, groupedIndices, alreadyGroupedIndices, groupLabel = "", thresholdMills = lazy.nearestNeighborThresholdInt, precomputedEmbeddings = [], depth = 0, }) { // get embeddings for all the tabs const tabData = await this._prepareTabData(allTabs); let embeddings = precomputedEmbeddings; if (precomputedEmbeddings.length === 0) { embeddings = await this._generateEmbeddings( tabData.map((td, index) => { let text = SmartTabGroupingManager.preprocessText(td[EMBED_TEXT_KEY]); // augment with group name if it's present if (groupLabel && groupedIndices.includes(index)) { text = `${groupLabel.slice(0, 100)}. 
${text}`; } return text; }) ); } // tabs that need to be assigned after filtering const tabsToAssignIndices = this.getTabsToSuggest( tabData, groupedIndices, alreadyGroupedIndices ); let closestTabs = []; const similarTabsIndices = []; for (let i = 0; i < tabsToAssignIndices.length; i++) { let closestScore = null; for ( let j = 0; j < Math.min(groupedIndices.length, MAX_GROUPED_TABS); j++ ) { const cosineSim = cosSim( embeddings[tabsToAssignIndices[i]], embeddings[groupedIndices[j]] ); if (!closestScore || cosineSim > closestScore) { closestScore = cosineSim; } } // threshold could also be set via a nimbus experiment, in which case // it will be an int <= 1000 if (closestScore > thresholdMills / 1000) { closestTabs.push([allTabs[tabsToAssignIndices[i]], closestScore]); similarTabsIndices.push(tabsToAssignIndices[i]); } } closestTabs.sort((a, b) => b[1] - a[1]); closestTabs = closestTabs.map(t => t[0]); // recurse once if the initial call only had a single tab // and we found at least 1 similar tab - this improves recall if (groupedIndices.length === 1 && !!closestTabs.length && depth === 1) { const recurseSimilarTabs = await this.findNearestNeighbors({ allTabs, groupedIndices: similarTabsIndices, alreadyGroupedIndices: alreadyGroupedIndices.concat(groupedIndices), groupLabel, thresholdMills, precomputedEmbeddings: embeddings, depth: depth - 1, }); closestTabs = closestTabs.concat(recurseSimilarTabs); } return closestTabs; } /** * Calculates the average similarity between the anchor embeddings and the candidate embeddings * * @param {number[]} anchorEmbeddings title embeddings for the anchor tabs * @param {number[]} candidateEmbeddings title embeddings for the candidate tabs */ getAverageSimilarity(anchorEmbeddings, candidateEmbeddings) { let averageSimilarities = []; for (let candidate_embedding of candidateEmbeddings) { let averageSimilarity = 0; for (let anchor_embedding of anchorEmbeddings) { averageSimilarity += cosSim(candidate_embedding, anchor_embedding); } 
averageSimilarities.push(averageSimilarity / anchorEmbeddings.length); // NaN (0 / 0) if there are no anchors
    }
    return averageSimilarities;
  }

  /**
   * Calculates the max similarity between the anchor embeddings and the candidate embeddings
   * (used for s_tt_max).
   *
   * @param {number[][]} anchorEmbeddings title embeddings for the anchor tabs
   * @param {number[][]} candidateEmbeddings title embeddings for the candidate tabs
   * @returns {number[]} for each candidate, its highest cosine similarity to any
   *   anchor; -1 (the starting value) when there are no anchors
   */
  getMaxSimilarity(anchorEmbeddings, candidateEmbeddings) {
    let maxSimilarities = [];
    for (let candidate_embedding of candidateEmbeddings) {
      let maxSimilarity = -1;
      for (let anchor_embedding of anchorEmbeddings) {
        const sim = cosSim(candidate_embedding, anchor_embedding);
        if (sim > maxSimilarity) {
          maxSimilarity = sim;
        }
      }
      maxSimilarities.push(maxSimilarity);
    }
    return maxSimilarities;
  }

  /**
   * Extract base domain from a URL with error handling
   *
   * @param {string} url
   * @return {string} the base domain plus one extra label, lowercased and
   *   without a leading "www."; the lowercased hostname when the eTLD service
   *   cannot compute a base domain; "" for an empty or unparsable URL
   */
  static getBaseDomain(url) {
    if (!url) {
      return "";
    }
    let hostname;
    try {
      ({ hostname } = new URL(url));
    } catch (_e) {
      // invalid URL
      return "";
    }
    if (!hostname) {
      return "";
    }
    try {
      // additionalParts = 1 → one label above the registrable domain
      // then remove 'www'
      // https://www.example.com -> www.example.com -> example.com
      // https://www.docs.google.com -> docs.google.com
      // https://localhost -> error
      return Services.eTLD
        .getBaseDomain(Services.io.newURI(url.toLowerCase()), 1)
        .replace(/^www\./, "");
    } catch (_e) {
      // localhost, IPs, internal hosts, etc.
      // bucket by the hostname.
      return hostname.toLowerCase();
    }
  }

  /**
   * For each candidate tab, compute s_dd = fraction of anchors whose base domain
   * matches the candidate's base domain.
   *
   * @param {Array} anchorTabsPrep output of _prepareTabData for anchor tabs
   * @param {Array} candidateTabsPrep output of _prepareTabData for candidate tabs
   * @return {number[]} array of s_dd values in [0, 1]
   */
  getDomainMatchFractions(anchorTabsPrep, candidateTabsPrep) {
    const anchorDomains = anchorTabsPrep.map(t =>
      SmartTabGroupingManager.getBaseDomain(t.url)
    );
    // `|| 1` avoids dividing by zero when there are no anchors (result is then 0)
    const numAnchors = anchorDomains.length || 1;
    return candidateTabsPrep.map(tab => {
      const candDomain = SmartTabGroupingManager.getBaseDomain(tab.url);
      if (!candDomain) {
        return 0;
      }
      let same = 0;
      for (const ad of anchorDomains) {
        // anchors with no resolvable domain ("") never count as a match
        if (ad && ad === candDomain) {
          same++;
        }
      }
      return same / numAnchors;
    });
  }

  /**
   * Calculates the sigmoid value of the input
   *
   * @param {number} z
   * @return {number} 1 / (1 + e^-z), in (0, 1)
   */
  sigmoid(z) {
    return 1 / (1 + Math.exp(-z));
  }

  /**
   * Calculates the probability using the linear combination of the parameters
   *
   * z = groupSimilarity * w_group + titleSimilarity * w_title
   *     + domainSimilarity * w_domain + intercept, and the result is sigmoid(z).
   * Missing (falsy) weights are treated as 0.
   *
   * @param {number} groupSimilarity s_gc in [0,1]
   * @param {number} titleSimilarity s_tt_max in [0,1]
   * @param {number} domainSimilarity s_dd in [0,1]
   * @param {object} params the logistic regression weights assigned to each parameter
   * @return {number}
   */
  calculateProbability(
    groupSimilarity,
    titleSimilarity,
    domainSimilarity,
    params
  ) {
    const wGroup = params.GROUP_SIMILARITY_WEIGHT || 0;
    const wTitle = params.TITLE_SIMILARITY_WEIGHT || 0;
    const wDomain = params.DOMAIN_SIMILARITY_WEIGHT || 0;
    const z =
      groupSimilarity * wGroup +
      titleSimilarity * wTitle +
      domainSimilarity * wDomain +
      params.INTERCEPT;
    return this.sigmoid(z);
  }

  /**
   * Calculates the probabilities given similarity lists (cosine) and domain fractions.
* * @param {number[]|null} groupSimilaritiesCos cosine(group, candidate) in [-1,1] or null * @param {number[]} titleSimilaritiesCos max cosine(anchor, candidate) in [-1,1] * @param {number[]} domainSimilarities s_dd in [0,1] * @return {number[]} probabilities for each candidate tab */ calculateAllProbabilities( groupSimilaritiesCos, titleSimilaritiesCos, domainSimilarities ) { const hasGroupSimilarity = Array.isArray(groupSimilaritiesCos) && groupSimilaritiesCos.length; const useDomain = Array.isArray(domainSimilarities) && domainSimilarities.length; const probabilities = []; for (let i = 0; i < titleSimilaritiesCos.length; i++) { // groupTitleSim and titleSim are (cos + 1)/2 -> [0,1] const groupTitleSim = hasGroupSimilarity ? 0.5 * (groupSimilaritiesCos[i] + 1) : 0; const titleSim = 0.5 * (titleSimilaritiesCos[i] + 1); const domainSim = useDomain ? domainSimilarities[i] : 0; const params = hasGroupSimilarity ? LOGISTIC_REGRESSION_PARAMS.TITLE_WITH_GROUP_NAME : LOGISTIC_REGRESSION_PARAMS.TITLE_ONLY; probabilities.push( this.calculateProbability(groupTitleSim, titleSim, domainSim, params) ); } return probabilities; } /** * Generates similar tabs to a grouped list of tabs using a logistic regression "model" * * @param {Array} allTabs all tabs that are part of the window * @param {Array} groupedIndices indices of tabs that are already part of the group * @param {Array} alreadyGroupedIndices indices of tabs that are part of other groups * @param {string} groupLabel name of group if present */ async findSimilarTabsLogisticRegression({ allTabs, groupedIndices, alreadyGroupedIndices, groupLabel = "", }) { const tabData = await this._prepareTabData(allTabs); const candidateIndices = this.getTabsToSuggest( tabData, groupedIndices, alreadyGroupedIndices ); const candidateTabsData = candidateIndices.map(ci => allTabs[ci]); const candidateTabsPrep = await this._prepareTabData(candidateTabsData); const anchorTabsPrep = groupedIndices .map(gi => tabData[gi]) .slice(0, 
MAX_GROUPED_TABS); // generate embeddings for both anchor and candidate titles const titleEmbeddings = await this._generateEmbeddings( anchorTabsPrep .concat(candidateTabsPrep) .map(tab => SmartTabGroupingManager.preprocessText(tab[EMBED_TEXT_KEY])) ); let groupEmbedding; let groupSimilaritiesCos = null; if (groupLabel) { groupEmbedding = await this._generateEmbeddings([groupLabel]); // cosine(group, candidate_title) in [-1,1] groupSimilaritiesCos = this.getAverageSimilarity( groupEmbedding, titleEmbeddings.slice(anchorTabsPrep.length) ); } // s_tt_max: max cosine(anchor_title, candidate_title) in [-1,1] const titleSimilaritiesCos = this.getMaxSimilarity( titleEmbeddings.slice(0, anchorTabsPrep.length), titleEmbeddings.slice(anchorTabsPrep.length) ); // s_dd: fraction of anchors sharing the candidate's base domain const domainSimilarities = this.getDomainMatchFractions( anchorTabsPrep, candidateTabsPrep ); const candidateProbabilities = this.calculateAllProbabilities( groupSimilaritiesCos, titleSimilaritiesCos, domainSimilarities ); // get matching params depending on the group name availability const probabilityThreshold = groupEmbedding ? LOGISTIC_REGRESSION_PARAMS.TITLE_WITH_GROUP_NAME.THRESHOLD : LOGISTIC_REGRESSION_PARAMS.TITLE_ONLY.THRESHOLD; return ( candidateTabsData // combine candidate tabs with corresponding probabilities .map((ct, index) => ({ ct, prob: candidateProbabilities[index], })) // only keep those that are within the probability threshold .filter(item => item.prob >= probabilityThreshold) // ensure the highest probability candidates come first in the list .sort((a, b) => b.prob - a.prob) // keep the tabs only .map(item => item.ct) ); } /** * This function will terminate a grouping or label generation in progress * It is currently not implemented. */ terminateProcess() { // TODO - teminate AI processes, This method will be // called when tab grouping panel is closed. } /** * Changes the clustering method. Must be one of supported methods. 
* * @param {string} method Name of method */ setClusteringMethod(method) { if (!(method in CLUSTER_METHODS)) { throw new Error(`Clustering method ${method} not supported`); } this.config.clustering.clusterImplementation = method; } /** * Set the technique for clustering when certain tabs are already assigned to groups * * @param {string} method which is one of ANCHOR_METHODS */ setAnchorMethod(method) { if (!(method in ANCHOR_METHODS)) { throw new Error(`Clustering anchor method ${method} not supported`); } this.config.clustering.anchorMethod = method; } setSilBoost(boost) { this.config.clustering.pregroupedSilhouetteBoost = boost; } /** * Sets method to reduce dimensionality of embeddings prior to clustering * * @param {string} method Name of method */ setDimensionReductionMethod(method) { if (method && !(method in DIM_REDUCTION_METHODS)) { throw new Error(`Dimension reduction method ${method} not supported`); } this.config.clustering.dimReductionMethod = method; } /** * Sets the field name of the title of a page to be used when clustering or generating embeddings * This is useful when clustering test data that is not a tab object * * @param {string} titleKey KEY FOR THE TITLE */ setDataTitleKey(titleKey) { this.config.dataConfig.titleKey = titleKey; } /** * Logs to the appropriate place for debugging. 
Console for now * * @param {string} msg Message to log * @param {boolean} useDescription Whether to add description to the final text */ log(_msg) {} /** * Prepares data to be used by the ml models * * @param {object[]} tabList list of tabs in the current window * @param {boolean} useDescription whether we should combined the title and description * @return {Promise<*[Object]>} * @private */ async _prepareTabData(tabList, useDescription = false) { const titleKey = this.config.dataConfig.titleKey; const descriptionKey = this.config.dataConfig.descriptionKey; const structuredData = []; for (let tab of tabList) { const description = useDescription && descriptionKey && tab[descriptionKey]; let textToEmbed; if (description) { textToEmbed = tab[titleKey] + " " + description; } else { textToEmbed = tab[titleKey] || "Unknown"; } structuredData.push({ [EMBED_TEXT_KEY]: textToEmbed, title: tab[titleKey], description, url: tab?.linkedBrowser?.currentURI?.spec, }); } return structuredData; } /** * Get updated config for the ml engine * * @param {object} initData * @param {string} featureId * @return {*} */ static getUpdatedInitData(initData, featureId) { // we're setting a specific modelRevision through about:config or Nimbus if ( featureId === SMART_TAB_GROUPING_CONFIG.topicGeneration.featureId && lazy.topicModelRevision !== LATEST_MODEL_REVISION ) { initData.modelRevision = lazy.topicModelRevision; } else if ( featureId === SMART_TAB_GROUPING_CONFIG.embedding.featureId && lazy.embeddingModelRevision !== LATEST_MODEL_REVISION ) { initData.modelRevision = lazy.embeddingModelRevision; } return initData; } /** * Creates an ML engine for a given config. 
* * @param {*} engineConfig * @param {function} progressCallback * @returns MLEngine */ async _createMLEngine(engineConfig, progressCallback) { const { featureId, engineId, dtype, taskName, timeoutMS, modelId, modelRevision, backend, fallbackBackend, } = engineConfig; let initData = { featureId, engineId, dtype, taskName, timeoutMS, modelId, modelRevision, backend, }; initData = SmartTabGroupingManager.getUpdatedInitData(initData, featureId); let engine; try { engine = await createEngine(initData, progressCallback); this.backend = backend; } catch (e) { engine = await createEngine( { ...initData, backend: fallbackBackend, }, progressCallback ); this.backend = fallbackBackend; } return engine; } /** * Generates embeddings from a list of tab data structures * * @param tabList List of tabs with label (title) and description keys * @returns {Promise<*[]>} List of embeddings (2d array) * @private */ async _generateEmbeddings(textToEmbedList) { const inputData = { inputArgs: textToEmbedList, runOptions: { pooling: "mean", normalize: true, }, }; if (SmartTabGroupingManager.isEngineClosed(this.embeddingEngine)) { this.embeddingEngine = await this._createMLEngine(this.config.embedding); } const request = { args: [inputData.inputArgs], options: inputData.runOptions, }; return await this.embeddingEngine.run(request); } /** * Clusters in desired methods * based on the config of the class * * @param tabList List of tabs as array * @param docEmbeddings Precomputed embeddings for the Tab as two dimensional array * @param k Desired number of clusters. Tries a range of sizes if 0. 
* @param {function} randomFunc Optional seeded random number generator for testing * @returns {SmartTabGroupingResult} * @private */ _clusterEmbeddings({ tabs, embeddings, k, randomFunc, anchorIndices, alreadyGroupedIndices = [], }) { let allItems; const freezeAnchorsInZeroCluster = anchorIndices && this.config.clustering.anchorMethod == ANCHOR_METHODS.FIXED; const dimReductionMethod = this.config.clustering.dimReductionMethod; switch (dimReductionMethod) { default: // Dimensionality reduction support is landing very soon. break; } k = k || 0; let startK = k; let endK = k + 1; if (!k) { startK = 2; // Find a reasonable max # of clusters endK = Math.min( Math.floor(Math.log(embeddings.length) * 2.0), embeddings.length ) + 1; } let bestResult; let bestResultSilScore = -100.0; let bestResultCenterCluster = 0; const clusteringMethod = this.config.clustering.clusterImplementation; const clusteringTriesPerK = this.config.clustering.clusteringTriesPerK; for (let curK = startK; curK < endK; curK++) { let bestItemsForK; let bestInertiaForK = 500000000000; for (let j = 0; j < clusteringTriesPerK; j++) { switch (clusteringMethod) { case CLUSTER_METHODS.KMEANS: allItems = kmeansPlusPlus({ data: embeddings, k: curK, maxIterations: 0, randomFunc, anchorIndices, preassignedIndices: this.config.clustering.pregroupedHandlingMethod === PREGROUPED_HANDLING_METHODS.EXCLUDE ? 
alreadyGroupedIndices : [], freezeAnchorsInZeroCluster, }); break; default: throw Error("Clustering implementation not supported"); } const tempResult = new SmartTabGroupingResult({ indices: allItems, embeddings, config: this.config, }); const inertia = tempResult.getCentroidInertia(); if (inertia < bestInertiaForK) { bestInertiaForK = inertia; bestItemsForK = tempResult; } } const silScores = silhouetteCoefficients( embeddings, bestItemsForK.indices ); if ( freezeAnchorsInZeroCluster && this.config.clustering.pregroupedSilhouetteBoost > 0 ) { // Boost silhouette score of target cluster when we are grouping around an existing cluster // pregroupedSilhouetteBoost indicates the relative weight of the cluster's score and all other cluster's combined silScores[0] *= this.config.clustering.pregroupedSilhouetteBoost; } let avgSil = silScores.reduce((p, c) => p + c, 0) / silScores.length; let curAnchorCluster = 0; if (anchorIndices && !freezeAnchorsInZeroCluster) { const { anchorClusterIndex, numAnchorItemsInCluster } = getBestAnchorClusterInfo(bestItemsForK.indices, anchorIndices); curAnchorCluster = anchorClusterIndex; const penalty = (MISSING_ANCHOR_IN_CLUSTER_PENALTY * (anchorIndices.length - numAnchorItemsInCluster)) / anchorIndices.length; avgSil -= penalty; } if (avgSil > bestResultSilScore) { bestResultSilScore = avgSil; bestResult = bestItemsForK.indices; bestResultCenterCluster = curAnchorCluster; } } const result = new SmartTabGroupingResult({ indices: bestResult, tabs, embeddings, config: this.config, }); if (anchorIndices) { result.setAnchorClusterIndex( freezeAnchorsInZeroCluster ? 
0 : bestResultCenterCluster ); // In our k-means clustering implementation anchor cluster is always first if (!freezeAnchorsInZeroCluster) { result.adjustClusterForAnchors(anchorIndices); } } return result; } /** * Generate a label for tabs in a group created by the user * * @param tabs tabs that are currently in the group * @param otherTabs tabs in the window not part of the group * @return {Promise} */ async getPredictedLabelForGroup(tabs, otherTabs) { const clusters = this.createStaticCluster(tabs); const otherClusters = this.createStaticCluster(otherTabs); let predictedLabel; try { // function below modifies "clusters" object await this.generateGroupLabels(clusters, otherClusters); predictedLabel = clusters.clusterRepresentations[0].predictedTopicLabel; } catch (e) { this.labelReason = LABEL_REASONS.ERROR; predictedLabel = ""; } return predictedLabel; } /** * Generates clusters for a given list of tabs using precomputed embeddings or newly generated ones. * * @param {object[]} tabList - List of tab objects to be clustered. * @param {number[][]} [precomputedEmbeddings] - Precomputed embeddings for tab titles and descriptions. * @param {number} numClusters - Number of clusters to form. * @param {Function} randFunc - Random function used for clustering initialization. * @param {number[]} [anchorIndices=[]] - Indices of anchor tabs that should be prioritized in clustering. * @param {number[]} [alreadyGroupedIndices=[]] - Indices of tabs that are already assigned to groups. * @returns {SmartTabGroupingResult} - The best clustering result based on centroid inertia. */ async generateClusters( tabList, precomputedEmbeddings, numClusters, randFunc, anchorIndices = [], alreadyGroupedIndices = [] ) { numClusters = numClusters ?? 
0; const structuredData = await this._prepareTabData(tabList); // embeddings for title and description if (precomputedEmbeddings) { this.docEmbeddings = precomputedEmbeddings; } else { this.docEmbeddings = await this._generateEmbeddings( structuredData.map(a => a[EMBED_TEXT_KEY]) ); } let bestResultCluster; let bestResultDistance = 50000000.0; const NUM_RUNS = 1; for (let i = 0; i < NUM_RUNS; i++) { const curResult = this._clusterEmbeddings({ tabs: tabList, embeddings: this.docEmbeddings, k: numClusters, randomFunc: randFunc, anchorIndices, alreadyGroupedIndices, }); const distance = curResult.getCentroidInertia(); if (distance < bestResultDistance) { bestResultDistance = distance; bestResultCluster = curResult; } } return bestResultCluster; } /** * Create static cluster from a list of tabs. A single tab is Ok. Returns null for 0 tabs * * @param tabs * @returns {SmartTabGroupingResult} groupingResult */ createStaticCluster(tabs) { if (!tabs) { return null; } return new SmartTabGroupingResult({ indices: [Array.from({ length: tabs.length }, (_, i) => i)], tabs, config: this.config, }); } /** * Utility function that loads all required engines for Smart Tab Grouping and any dependent models * * @param {(progress: { percentage: number }) => void} progressCallback callback function to call. * Callback passes a dict with percentage indicating best effort 0.0-100.0 progress in model download. */ async preloadAllModels(progressCallback) { let previousProgress = -1; const expectedObjects = EXPECTED_TOPIC_MODEL_OBJECTS + EXPECTED_EMBEDDING_MODEL_OBJECTS; // TODO - Find a way to get these fields. 
Add as a transformers js callback or within remotesettings const UPDATE_THRESHOLD_PERCENTAGE = 0.5; const ONE_MB = 1024 * 1024; const START_THRESHOLD_BYTES = ONE_MB * 0.2; const mutliProgressAggregator = new lazy.MultiProgressAggregator({ progressCallback: ({ progress, totalLoaded, metadata }) => { if (totalLoaded < START_THRESHOLD_BYTES) { progress = 0.0; } else { const numObjSeen = metadata.totalObjectsSeen || 0; if (numObjSeen > 0 && numObjSeen < expectedObjects) { // When starting to download we may still be getting configs and not have all the data progress *= numObjSeen / expectedObjects; } if (progress > 100) { progress = 100; } } if ( Math.abs(previousProgress - progress) > UPDATE_THRESHOLD_PERCENTAGE ) { // Update only once changes are above a threshold to avoid throttling the UI with events. progressCallback({ percentage: progress, }); previousProgress = progress; } }, watchedTypes: [ lazy.Progress.ProgressType.DOWNLOAD, lazy.Progress.ProgressType.LOAD_FROM_CACHE, ], }); const [topicEngine, embeddingEngine] = await Promise.all([ this._createMLEngine( this.config.topicGeneration, mutliProgressAggregator?.aggregateCallback.bind( mutliProgressAggregator ) || null ), this._createMLEngine( this.config.embedding, mutliProgressAggregator?.aggregateCallback.bind( mutliProgressAggregator ) || null ), ]); this.topicEngine = topicEngine; this.embeddingEngine = embeddingEngine; } /** * Generate model input from keywords and documents * * @param {string []} keywords * @param {string []} documents */ createModelInput(keywords, documents) { if (!keywords || keywords.length === 0) { return `Topic from keywords: titles: \n${documents.join(" \n")}`; } return `Topic from keywords: ${keywords.join(", ")}. titles: \n${documents.join(" \n")}`; } /** * One artifact of the LLM output is that sometimes words are duplicated * This function cuts the phrase when it sees the first duplicate word. * Handles simple singluar / plural duplicates (-s only). 
* * @param {string} phrase Input phrase * @returns {string} phrase cut before any duplicate word */ static cutAtDuplicateWords(phrase) { if (!phrase.length) { return phrase; } const wordsSet = new Set(); const wordList = phrase.split(" "); for (let i = 0; i < wordList.length; i++) { let baseWord = wordList[i].toLowerCase(); if (baseWord.length > 3) { if (baseWord.slice(-1) === "s") { baseWord = baseWord.slice(0, -1); } } if (wordsSet.has(baseWord)) { // We are seeing a baseWord word. Exit with just the words so far and don't // add any new words return wordList.slice(0, i).join(" "); } wordsSet.add(baseWord); } return phrase; // return original phrase } /** * Removes trailing domain-related text such as '... - Mail' or '... | News' * If there's not enough information remaining after, we keep the text as is * * @param {string} text tab title with potential domain information * @return {string} */ static preprocessText(text) { // Matches 'xyz - Domain' or 'xyz | Domain' // with a space before and after delimiter // or if there are multiple delimiters next to each other const delimiters = /(?<=\s)[|–-]+(?=\s)/; const splitText = text.split(delimiters); // ensure there's enough info without the last element const hasEnoughInfo = !!splitText.length && splitText.slice(0, -1).join(" ").length > 5; // domain related texts are usually shorter, this takes care of the most common cases const isPotentialDomainInfo = splitText.length > 1 && splitText[splitText.length - 1].length < 20; // If both conditions are met, remove the last chunk, filter out empty strings, // join on space, trim, and lowercase if (hasEnoughInfo && isPotentialDomainInfo) { return splitText .slice(0, -1) // everything except the last element .map(t => t.trim()) .filter(Boolean) // remove empty strings .join(" ") // join with spaces .trim(); // remove leading/trailing spaces } // Otherwise, just return the text return text; } /** * Postprocessing of raw output from Topic Model ML Engine * * @param {string | 
undefined} topic Raw topic phrase from topic model or undefined in case of an error */ processTopicModelResult(topic) { let basicResult = (topic || "").trim(); if (!basicResult) { this.labelReason = LABEL_REASONS.LOW_CONFIDENCE; } if (LABELS_TO_EXCLUDE.includes(basicResult.toLowerCase())) { this.labelReason = LABEL_REASONS.EXCLUDE; return ""; } return SmartTabGroupingManager.cutAtDuplicateWords(basicResult); } /** * Add titles to a cluster in a SmartTabGroupingResult using generative tehniques * Currently this function only works with a single target group, and a separate * item that represents all other ungrouped tabs. * * In the future this may be updated to more generally find labels for a set of clusters. * * @param {SmartTabGroupingResult} groupingResult The cluster we are generating the label for * @param {SmartTabGroupingResult} otherGroupingResult A 'made up' cluster representing all other tabs in the window */ async generateGroupLabels(groupingResult, otherGroupingResult = null) { // Special case for a search page const searchTopicSpecialCase = Services.prefs.getBoolPref( "browser.tabs.groups.smart.searchTopicEnabled", true ); if ( searchTopicSpecialCase && groupingResult.clusterRepresentations.length == 1 && groupingResult.clusterRepresentations[0].isSingleTabSearch ) { if (groupingResult.clusterRepresentations[0].setSingleTabSearchLabel()) { return; } } const { keywords, documents } = groupingResult.getRepresentativeDocsAndKeywords( otherGroupingResult ? otherGroupingResult.getRepresentativeDocuments() : [] ); const inputArgs = this.createModelInput( keywords ? 
keywords[0] : [], documents ); const requestInfo = { inputArgs, runOptions: { max_length: 6, }, }; if (SmartTabGroupingManager.isEngineClosed(this.topicEngine)) { this.topicEngine = await this._createMLEngine( this.config.topicGeneration ); } const request = { args: [requestInfo.inputArgs], options: requestInfo.runOptions, }; const genLabelResults = await this.topicEngine.run(request); genLabelResults.forEach((genResult, genResultIndex) => { groupingResult.clusterRepresentations[ genResultIndex ].predictedTopicLabel = this.processTopicModelResult( genResult.generated_text ); }); } getLabelReason() { return this.labelReason || LABEL_REASONS.DEFAULT; } /** * Generates glean metrics for ml smart tab label / topic. * This is currently called when the user saves or cancels the "suggest label" flow. * * @param {string} action "save" or "cancel" * @param {number} numTabsInGroup Number of tabs used to generate the label * @param {string} mlLabel ML generated label for the tab group * @param {string} userLabel User saved label for the tab group * @param {string} id The id of the group */ async handleLabelTelemetry({ action, numTabsInGroup, mlLabel, userLabel, id = "", }) { const { [ML_TASK_TEXT2TEXT]: topicEngineConfig } = await this.getEngineConfigs(); const labelReason = this.getLabelReason(); Glean.tabgroup.smartTabTopic.record({ action, tabs_in_group: numTabsInGroup, ml_label_length: (mlLabel || "").length, user_label_length: (userLabel || "").length, levenshtein_distance: lazy.NLP.levenshtein( userLabel || "", mlLabel || "" ), model_revision: topicEngineConfig.modelRevision || "", id, label_reason: labelReason, backend: this.backend || "onnx-native", }); this.labelReason = LABEL_REASONS.DEFAULT; } /** * Generates glean metrics for ml smart tab label / topic. 
* This is currently called when the user saves or cancels the "suggest other tabs" flow * * @param {string} action "save" or "cancel" * @param {number} numTabsInWindow Number of tabs in the current window * @param {number} numTabsInGroup Number of tabs in the current group * @param {number} numTabsSuggested Number of tabs suggested by the model * @param {number} numTabsApproved Number of tabs approved by the user * @param {number} numTabsRemoved Number of tabs removed by the user * @param {string} id The id of the group */ async handleSuggestTelemetry({ action, numTabsInWindow, numTabsInGroup, numTabsSuggested, numTabsApproved, numTabsRemoved, id = "", }) { const { [ML_TASK_FEATURE_EXTRACTION]: embeddingEngineConfig } = await this.getEngineConfigs(); Glean.tabgroup.smartTabSuggest.record({ action, tabs_in_window: numTabsInWindow, tabs_in_group: numTabsInGroup, tabs_suggested: numTabsSuggested, tabs_approved: numTabsApproved, tabs_removed: numTabsRemoved, model_revision: embeddingEngineConfig.modelRevision || "", id, backend: this.backend || "onnx-native", }); } /** * Gets config that engine was initialized with * * @return {Promise<{"[ML_TASK_TEXT2TEXT]", "[ML_TASK_FEATURE_EXTRACTION]"}>} */ async getEngineConfigs() { if (!this.topicEngineConfig) { this.topicEngineConfig = await lazy.MLEngineParent.getInferenceOptions( this.config.topicGeneration.featureId, this.config.topicGeneration.taskName ); } if (!this.embeddingEngineConfig) { this.embeddingEngineConfig = await lazy.MLEngineParent.getInferenceOptions( this.config.embedding.featureId, this.config.embedding.taskName ); } return { [ML_TASK_TEXT2TEXT]: this.topicEngineConfig, [ML_TASK_FEATURE_EXTRACTION]: this.embeddingEngineConfig, }; } } export class SmartTabGroupingResult { #anchorClusterIndex = -1; // Index of cluster that has original items we're building clustering around, when building around an existing item. /** * Creates a result from indices and complete tab and embedding lists. 
* This may create some extra data for management later * * @param indices indices of clusters (eg [[2,4], [1], [3]]_ * @param tabItems 1D array of tabs * @param embeddingItems Two dimensional array of embeddings * @param config Cluster config */ constructor({ indices = [], tabs, embeddings, config }) { this.embeddingItems = embeddings; this.config = config; this.indices = indices.filter(subArray => !!subArray.length); // Cleanup any empty clusters this.tabItems = tabs; this._buildClusterRepresentations(); } /** * Builds list of ClusterRepresentations */ _buildClusterRepresentations() { this.clusterRepresentations = this.indices.map(subClusterIndices => { const tabItemsMapped = this.tabItems && subClusterIndices.map(idx => this.tabItems[idx]); const embeddingItemsMapped = this.embeddingItems && subClusterIndices.map(idx => this.embeddingItems[idx]); return new ClusterRepresentation({ tabs: tabItemsMapped, embeddings: embeddingItemsMapped, config: this.config, }); }); } /** * Returns a list of documents for each cluster. Currently it is a list of documents picked * in no particular order. * * @return {[strings]} Title and description that represent the cluster. (If no docs are in the class, then titles are returned) */ getRepresentativeDocuments() { if (!this.documents) { this.documents = this.tabItems.map( t => t[this.config.dataConfig.titleKey] ); } // set a limit of 10 for now return this.documents.slice(0, 10); } /** * Returns the keywords and documents for the cluster, computing if needed * Does not return keywods if only one document is passed to the function. 
* * @param {string[]} otherDocuments other clusters that we'll compare against * @return keywords and documents that represent the cluster */ getRepresentativeDocsAndKeywords(otherDocuments = []) { this.documents = this.getRepresentativeDocuments(); if (!this.keywords) { const joinedDocs = this.documents.slice(0, 3).join(" "); const otherDocs = otherDocuments.join(" "); if (this.documents.length > 1) { const keywordExtractor = new KeywordExtractor(); this.keywords = keywordExtractor.fitTransform([joinedDocs, otherDocs]); } else { this.keywords = []; } } return { keywords: this.keywords, documents: this.documents }; } setAnchorClusterIndex(index) { this.#anchorClusterIndex = index; } /** * Get the cluster we originally are grouping around (finding additinoal item) * * @returns ClusterRepresentation */ getAnchorCluster() { if (this.#anchorClusterIndex === -1) { return null; } return this.clusterRepresentations[this.#anchorClusterIndex]; } /** * Given the indices that we were clustering around, make sure they are are all in the target grouping * Our generic k-means clustering might have them in separate groups */ adjustClusterForAnchors(anchorIndices) { if (!anchorIndices.length) { return; } const anchorSet = new Set(anchorIndices); for (let i = 0; i < this.indices.length; i++) { if (i === this.#anchorClusterIndex) { continue; } this.indices[i] = this.indices[i].filter(item => { if (anchorSet.has(item)) { this.indices[this.#anchorClusterIndex].push(item); return false; } return true; }); } this._buildClusterRepresentations(); } /** * Prints information about the cluster */ printClusters() { for (let cluster of this.clusterRepresentations) { cluster.print(); } } /** * Computes the inertia of the cluster which is the sum of square total distance. 
* * @returns {number} */ getCentroidInertia() { let runningTotalDistance = 0; this.clusterRepresentations.forEach(rep => { runningTotalDistance += rep.computeTotalSquaredCentroidDistance(); }); return runningTotalDistance; } /** * Converts a cluster representation to a flat list of tabs, with clusterID key in each * tab representing the id of the cluster it was part of. * * @returns {[object]} */ _flatMapItemsInClusters() { return this.clusterRepresentations.reduce((result, clusterRep) => { const annotatedTabs = clusterRep.tabs.map(a => { let c = {}; Object.assign(c, a); c.clusterID = clusterRep.clusterID; return c; }); return result.concat(annotatedTabs); }, []); } /** * Get rand score which describes the accuracy versus a user labeled * annotation on the dataset. Requires the dataset to be labeled. * * @param labelKey Key in the tabs that represent a unique label ID for the cluster. * @returns {number} The rand score. */ getRandScore(labelKey = "annotatedLabel") { const combinedItems = this._flatMapItemsInClusters(); return computeRandScore(combinedItems, "clusterID", labelKey); } /** * Get accuracy for a specific cluster * * @param labelKey Key in the tabs that represent a unique label ID for the cluster. * @param clusterValue is the cluster we are comparing * @returns {number} The rand score. 
*/ getAccuracyStatsForCluster(labelKey = "annotatedLabel", clusterValue) { const combinedItems = this._flatMapItemsInClusters(); let keyClusterId = combinedItems.find( a => a[labelKey] === clusterValue ).clusterID; let truePositives = 0, trueNegatives = 0, falseNegatives = 0, falsePositives = 0; combinedItems.forEach(item => { const sameLabel = item[labelKey] === clusterValue; const sameCluster = item.clusterID === keyClusterId; if (sameLabel && sameCluster) { truePositives++; } if (!sameLabel && !sameCluster) { trueNegatives++; } if (sameLabel && !sameCluster) { falseNegatives++; } if (!sameLabel && sameCluster) { falsePositives++; } }); return getAccuracyStats({ truePositives, trueNegatives, falsePositives, falseNegatives, }); } } /** * Utility function to generate a random ID string * * @param len Length of the string * @returns {string} */ function genHexString(len) { const hex = "0123456789ABCDEF"; let output = ""; for (let i = 0; i < len; ++i) { output += hex.charAt(Math.floor(Math.random() * hex.length)); } return output; } class EmbeddingCluster { constructor({ tabs, embeddings, centroid }) { this.embeddings = embeddings; this.centroid = centroid || (embeddings && computeCentroidFrom2DArray(this.embeddings)); this.tabs = tabs; } /** * @returns total sum euclidan squared distance of each item from cluster's centroid */ computeTotalSquaredCentroidDistance() { let totalDistance = 0; if (this.embeddings.length === 0) { return 0; } this.embeddings.forEach(embedding => { totalDistance += euclideanDistance(this.centroid, embedding, true); }); return totalDistance; } /** * Returns number of items in the cluster * * @returns {int} */ numItems() { return this.tabs.length; } } /** * Represents a single cluster with additional saved metadata */ export class ClusterRepresentation extends EmbeddingCluster { constructor({ tabs, embeddings, centroid, config }) { super({ tabs, embeddings, centroid }); this.config = config; this.predictedTopicLabel = null; 
this.annotatedTopicLabel = null; this.userEditedTopicLabel = null; this.representativeText = null; this.keywords = null; this.documents = null; this.clusterID = genHexString(10); this.isSingleTabSearch = tabs?.length == 1 && isSearchTab(tabs[0]); } /** * For a single tab cluster with a search field, set the predicted topic * to be the title of the page * * @returns {boolean} True if we updated the cluster label successfully */ setSingleTabSearchLabel() { if (this.tabs.length !== 1) { return false; } const pageTitle = this.tabs[0][this.config.dataConfig.titleKey] || ""; for (let i = pageTitle.length - 1; i > 0; i--) { if (TITLE_DELIMETER_SET.has(pageTitle[i])) { const topicString = pageTitle.substring(0, i).trim(); if (topicString.length > MAX_NON_SUMMARIZED_SEARCH_LENGTH) { return false; } // Capitalize first character of each word. Regex returns first char of each word this.predictedTopicLabel = topicString.replace(/(^|\s)\S/g, t => t.toUpperCase() ); return true; } } return false; } /** * Returns the representative text for a cluster, computing it if needed */ getRepresentativeText() { if (!this.representativeText) { this.representativeText = this._generateRepresentativeText(); } return this.representativeText; } /** * Returns representative text for a cluster. * For this in initial implementation it simply returns title from a few tabs * * @returns {string} * @private */ _generateRepresentativeText() { let text = ""; const titleKey = this.config.dataConfig.titleKey; for (const tab of this.tabs.slice(0, 3)) { text += `\n${tab[titleKey]}`; } return text; } print() { // Add console log for debugging } }