From 77d9b13b6db0afead521713204ffc4dced7ad0f2 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 11 Nov 2024 09:17:13 +0100 Subject: [PATCH 01/13] amd-cache-optimizer Signed-off-by: Peter Jung --- .../sysfs-bus-platform-drivers-amd_x3d_vcache | 14 ++ MAINTAINERS | 8 + drivers/platform/x86/amd/Kconfig | 12 ++ drivers/platform/x86/amd/Makefile | 2 + drivers/platform/x86/amd/x3d_vcache.c | 193 ++++++++++++++++++ 5 files changed, 229 insertions(+) create mode 100644 Documentation/ABI/testing/sysfs-bus-platform-drivers-amd_x3d_vcache create mode 100644 drivers/platform/x86/amd/x3d_vcache.c diff --git a/Documentation/ABI/testing/sysfs-bus-platform-drivers-amd_x3d_vcache b/Documentation/ABI/testing/sysfs-bus-platform-drivers-amd_x3d_vcache new file mode 100644 index 000000000000..1aa6ed0c10d9 --- /dev/null +++ b/Documentation/ABI/testing/sysfs-bus-platform-drivers-amd_x3d_vcache @@ -0,0 +1,14 @@ +What: /sys/bus/platform/drivers/amd_x3d_vcache/AMDI0101\:00/amd_x3d_mode +Date: October 2024 +KernelVersion: 6.13 +Contact: Basavaraj Natikar +Description: (RW) AMD 3D V-Cache optimizer allows users to switch CPU core + rankings dynamically. + + This file switches between these two modes: + - "frequency" cores within the faster CCD are prioritized before + those in the slower CCD. + - "cache" cores within the larger L3 CCD are prioritized before + those in the smaller L3 CCD. + + Format: %s. diff --git a/MAINTAINERS b/MAINTAINERS index 21fdaa19229a..5dc7d5839fe9 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -965,6 +965,14 @@ Q: https://patchwork.kernel.org/project/linux-rdma/list/ F: drivers/infiniband/hw/efa/ F: include/uapi/rdma/efa-abi.h +AMD 3D V-CACHE PERFORMANCE OPTIMIZER DRIVER +M: Basavaraj Natikar +R: Mario Limonciello +L: platform-driver-x86@vger.kernel.org +S: Supported +F: Documentation/ABI/testing/sysfs-bus-platform-drivers-amd_x3d_vcache +F: drivers/platform/x86/amd/x3d_vcache.c + AMD ADDRESS TRANSLATION LIBRARY (ATL) M: Yazen Ghannam L: linux-edac@vger.kernel.org diff --git a/drivers/platform/x86/amd/Kconfig b/drivers/platform/x86/amd/Kconfig index f88682d36447..d73f691020d0 100644 --- a/drivers/platform/x86/amd/Kconfig +++ b/drivers/platform/x86/amd/Kconfig @@ -6,6 +6,18 @@ source "drivers/platform/x86/amd/pmf/Kconfig" source "drivers/platform/x86/amd/pmc/Kconfig" +config AMD_3D_VCACHE + tristate "AMD 3D V-Cache Performance Optimizer Driver" + depends on X86_64 && ACPI + help + The driver provides a sysfs interface, enabling the setting of a bias + that alters CPU core reordering. This bias prefers cores with higher + frequencies or larger L3 caches on processors supporting AMD 3D V-Cache + technology. + + If you choose to compile this driver as a module the module will be + called amd_3d_vcache. + config AMD_HSMP tristate "AMD HSMP Driver" depends on AMD_NB && X86_64 && ACPI diff --git a/drivers/platform/x86/amd/Makefile b/drivers/platform/x86/amd/Makefile index dcec0a46f8af..16e4cce02242 100644 --- a/drivers/platform/x86/amd/Makefile +++ b/drivers/platform/x86/amd/Makefile @@ -4,6 +4,8 @@ # AMD x86 Platform-Specific Drivers # +obj-$(CONFIG_AMD_3D_VCACHE) += amd_3d_vcache.o +amd_3d_vcache-objs := x3d_vcache.o obj-$(CONFIG_AMD_PMC) += pmc/ amd_hsmp-y := hsmp.o obj-$(CONFIG_AMD_HSMP) += amd_hsmp.o diff --git a/drivers/platform/x86/amd/x3d_vcache.c b/drivers/platform/x86/amd/x3d_vcache.c new file mode 100644 index 000000000000..679613d02b9a --- /dev/null +++ b/drivers/platform/x86/amd/x3d_vcache.c @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * AMD 3D V-Cache Performance Optimizer Driver + * + * Copyright (c) 2024, Advanced Micro Devices, Inc. + * All Rights Reserved. + * + * Authors: Basavaraj Natikar + * Perry Yuan + * Mario Limonciello + * + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include + +static char *x3d_mode = "frequency"; +module_param(x3d_mode, charp, 0444); +MODULE_PARM_DESC(x3d_mode, "Initial 3D-VCache mode; 'frequency' (default) or 'cache'"); + +#define DSM_REVISION_ID 0 +#define DSM_GET_FUNCS_SUPPORTED 0 +#define DSM_SET_X3D_MODE 1 + +static guid_t x3d_guid = GUID_INIT(0xdff8e55f, 0xbcfd, 0x46fb, 0xba, 0x0a, + 0xef, 0xd0, 0x45, 0x0f, 0x34, 0xee); + +enum amd_x3d_mode_type { + MODE_INDEX_FREQ, + MODE_INDEX_CACHE, +}; + +static const char * const amd_x3d_mode_strings[] = { + [MODE_INDEX_FREQ] = "frequency", + [MODE_INDEX_CACHE] = "cache", +}; + +struct amd_x3d_dev { + struct device *dev; + acpi_handle ahandle; + /* To protect x3d mode setting */ + struct mutex lock; + enum amd_x3d_mode_type curr_mode; +}; + +static int amd_x3d_mode_switch(struct amd_x3d_dev *data, int new_state) +{ + union acpi_object *out, argv; + + guard(mutex)(&data->lock); + argv.type = ACPI_TYPE_INTEGER; + argv.integer.value = new_state; + + out = acpi_evaluate_dsm(data->ahandle, &x3d_guid, DSM_REVISION_ID, DSM_SET_X3D_MODE, + &argv); + if (!out) { + dev_err(data->dev, "failed to evaluate _DSM\n"); + return -EINVAL; + } + + data->curr_mode = new_state; + + ACPI_FREE(out); + + return 0; +} + +static ssize_t amd_x3d_mode_store(struct device *dev, struct device_attribute *attr, + const char *buf, size_t count) +{ + struct amd_x3d_dev *data = dev_get_drvdata(dev); + int ret; + + ret = sysfs_match_string(amd_x3d_mode_strings, buf); + if (ret < 0) { + dev_err(dev, "no matching mode to set %s\n", buf); + return ret; + } + + ret = amd_x3d_mode_switch(data, ret); + + return ret ? ret : count; +} + +static ssize_t amd_x3d_mode_show(struct device *dev, struct device_attribute *attr, char *buf) +{ + struct amd_x3d_dev *data = dev_get_drvdata(dev); + + if (data->curr_mode > MODE_INDEX_CACHE || data->curr_mode < MODE_INDEX_FREQ) + return -EINVAL; + + return sysfs_emit(buf, "%s\n", amd_x3d_mode_strings[data->curr_mode]); +} +static DEVICE_ATTR_RW(amd_x3d_mode); + +static struct attribute *amd_x3d_attrs[] = { + &dev_attr_amd_x3d_mode.attr, + NULL +}; +ATTRIBUTE_GROUPS(amd_x3d); + +static int amd_x3d_supported(struct amd_x3d_dev *data) +{ + union acpi_object *out; + + out = acpi_evaluate_dsm(data->ahandle, &x3d_guid, DSM_REVISION_ID, + DSM_GET_FUNCS_SUPPORTED, NULL); + if (!out) { + dev_err(data->dev, "failed to evaluate _DSM\n"); + return -ENODEV; + } + + if (out->type != ACPI_TYPE_BUFFER) { + dev_err(data->dev, "invalid type %d\n", out->type); + ACPI_FREE(out); + return -EINVAL; + } + + ACPI_FREE(out); + return 0; +} + +static const struct acpi_device_id amd_x3d_acpi_ids[] = { + {"AMDI0101"}, + { }, +}; +MODULE_DEVICE_TABLE(acpi, amd_x3d_acpi_ids); + +static void amd_x3d_remove(void *context) +{ + struct amd_x3d_dev *data = context; + + mutex_destroy(&data->lock); +} + +static int amd_x3d_probe(struct platform_device *pdev) +{ + const struct acpi_device_id *id; + struct amd_x3d_dev *data; + acpi_handle handle; + int ret; + + handle = ACPI_HANDLE(&pdev->dev); + if (!handle) + return -ENODEV; + + id = acpi_match_device(amd_x3d_acpi_ids, &pdev->dev); + if (!id) + dev_err_probe(&pdev->dev, -ENODEV, "unable to match ACPI ID and data\n"); + + data = devm_kzalloc(&pdev->dev, sizeof(*data), GFP_KERNEL); + if (!data) + return -ENOMEM; + + data->dev = &pdev->dev; + data->ahandle = handle; + platform_set_drvdata(pdev, data); + + ret = amd_x3d_supported(data); + if (ret) + dev_err_probe(&pdev->dev, ret, "not supported on this platform\n"); + + ret = match_string(amd_x3d_mode_strings, ARRAY_SIZE(amd_x3d_mode_strings), x3d_mode); + if (ret < 0) + return dev_err_probe(&pdev->dev, -EINVAL, "invalid mode %s\n", x3d_mode); + + mutex_init(&data->lock); + + ret = amd_x3d_mode_switch(data, ret); + if (ret < 0) + return ret; + + return devm_add_action_or_reset(&pdev->dev, amd_x3d_remove, data); +} + +static struct platform_driver amd_3d_vcache_driver = { + .driver = { + .name = "amd_x3d_vcache", + .dev_groups = amd_x3d_groups, + .acpi_match_table = amd_x3d_acpi_ids, + }, + .probe = amd_x3d_probe, +}; +module_platform_driver(amd_3d_vcache_driver); + +MODULE_DESCRIPTION("AMD 3D V-Cache Performance Optimizer Driver"); +MODULE_LICENSE("GPL"); -- 2.47.0 From 54c4f598ee011b1f701bdc2a924e9930fbf10962 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 11 Nov 2024 09:17:26 +0100 Subject: [PATCH 02/13] amd-pstate Signed-off-by: Peter Jung --- arch/x86/include/asm/cpufeatures.h | 3 +- arch/x86/include/asm/intel-family.h | 6 + arch/x86/include/asm/processor.h | 18 ++ arch/x86/include/asm/topology.h | 9 + arch/x86/kernel/acpi/cppc.c | 23 +++ arch/x86/kernel/cpu/debugfs.c | 1 + arch/x86/kernel/cpu/scattered.c | 3 +- arch/x86/kernel/cpu/topology_amd.c | 3 + arch/x86/kernel/cpu/topology_common.c | 34 ++++ arch/x86/kernel/smpboot.c | 5 +- arch/x86/mm/init.c | 23 ++- drivers/cpufreq/amd-pstate-ut.c | 6 +- drivers/cpufreq/amd-pstate.c | 231 ++++++++++------------- tools/arch/x86/include/asm/cpufeatures.h | 2 +- 14 files changed, 214 insertions(+), 153 deletions(-) diff --git a/arch/x86/include/asm/cpufeatures.h b/arch/x86/include/asm/cpufeatures.h index 913fd3a7bac6..a7c93191b7c6 100644 --- a/arch/x86/include/asm/cpufeatures.h +++ b/arch/x86/include/asm/cpufeatures.h @@ -473,7 +473,8 @@ #define X86_FEATURE_BHI_CTRL (21*32+ 2) /* BHI_DIS_S HW control available */ #define X86_FEATURE_CLEAR_BHB_HW (21*32+ 3) /* BHI_DIS_S HW control enabled */ #define X86_FEATURE_CLEAR_BHB_LOOP_ON_VMEXIT (21*32+ 4) /* Clear branch history at vmexit using SW loop */ -#define X86_FEATURE_FAST_CPPC (21*32 + 5) /* AMD Fast CPPC */ +#define X86_FEATURE_AMD_FAST_CPPC (21*32 + 5) /* Fast CPPC */ +#define X86_FEATURE_AMD_HETEROGENEOUS_CORES (21*32 + 6) /* Heterogeneous Core Topology */ /* * BUG word(s) diff --git a/arch/x86/include/asm/intel-family.h b/arch/x86/include/asm/intel-family.h index 1a42f829667a..736764472048 100644 --- a/arch/x86/include/asm/intel-family.h +++ b/arch/x86/include/asm/intel-family.h @@ -183,4 +183,10 @@ /* Family 19 */ #define INTEL_PANTHERCOVE_X IFM(19, 0x01) /* Diamond Rapids */ +/* CPU core types */ +enum intel_cpu_type { + INTEL_CPU_TYPE_ATOM = 0x20, + INTEL_CPU_TYPE_CORE = 0x40, +}; + #endif /* _ASM_X86_INTEL_FAMILY_H */ diff --git a/arch/x86/include/asm/processor.h b/arch/x86/include/asm/processor.h index 4a686f0e5dbf..c0975815980c 100644 --- a/arch/x86/include/asm/processor.h +++ b/arch/x86/include/asm/processor.h @@ -105,6 +105,24 @@ struct cpuinfo_topology { // Cache level topology IDs u32 llc_id; u32 l2c_id; + + // Hardware defined CPU-type + union { + u32 cpu_type; + struct { + // CPUID.1A.EAX[23-0] + u32 intel_native_model_id :24; + // CPUID.1A.EAX[31-24] + u32 intel_type :8; + }; + struct { + // CPUID 0x80000026.EBX + u32 amd_num_processors :16, + amd_power_eff_ranking :8, + amd_native_model_id :4, + amd_type :4; + }; + }; }; struct cpuinfo_x86 { diff --git a/arch/x86/include/asm/topology.h b/arch/x86/include/asm/topology.h index 92f3664dd933..fd41103ad342 100644 --- a/arch/x86/include/asm/topology.h +++ b/arch/x86/include/asm/topology.h @@ -114,6 +114,12 @@ enum x86_topology_domains { TOPO_MAX_DOMAIN, }; +enum x86_topology_cpu_type { + TOPO_CPU_TYPE_PERFORMANCE, + TOPO_CPU_TYPE_EFFICIENCY, + TOPO_CPU_TYPE_UNKNOWN, +}; + struct x86_topology_system { unsigned int dom_shifts[TOPO_MAX_DOMAIN]; unsigned int dom_size[TOPO_MAX_DOMAIN]; @@ -149,6 +155,9 @@ extern unsigned int __max_threads_per_core; extern unsigned int __num_threads_per_package; extern unsigned int __num_cores_per_package; +const char *get_topology_cpu_type_name(struct cpuinfo_x86 *c); +enum x86_topology_cpu_type get_topology_cpu_type(struct cpuinfo_x86 *c); + static inline unsigned int topology_max_packages(void) { return __max_logical_packages; diff --git a/arch/x86/kernel/acpi/cppc.c b/arch/x86/kernel/acpi/cppc.c index aab9d0570841..d745dd586303 100644 --- a/arch/x86/kernel/acpi/cppc.c +++ b/arch/x86/kernel/acpi/cppc.c @@ -239,8 +239,10 @@ EXPORT_SYMBOL_GPL(amd_detect_prefcore); */ int amd_get_boost_ratio_numerator(unsigned int cpu, u64 *numerator) { + enum x86_topology_cpu_type core_type = get_topology_cpu_type(&cpu_data(cpu)); bool prefcore; int ret; + u32 tmp; ret = amd_detect_prefcore(&prefcore); if (ret) @@ -266,6 +268,27 @@ int amd_get_boost_ratio_numerator(unsigned int cpu, u64 *numerator) break; } } + + /* detect if running on heterogeneous design */ + if (cpu_feature_enabled(X86_FEATURE_AMD_HETEROGENEOUS_CORES)) { + switch (core_type) { + case TOPO_CPU_TYPE_UNKNOWN: + pr_warn("Undefined core type found for cpu %d\n", cpu); + break; + case TOPO_CPU_TYPE_PERFORMANCE: + /* use the max scale for performance cores */ + *numerator = CPPC_HIGHEST_PERF_PERFORMANCE; + return 0; + case TOPO_CPU_TYPE_EFFICIENCY: + /* use the highest perf value for efficiency cores */ + ret = amd_get_highest_perf(cpu, &tmp); + if (ret) + return ret; + *numerator = tmp; + return 0; + } + } + *numerator = CPPC_HIGHEST_PERF_PREFCORE; return 0; diff --git a/arch/x86/kernel/cpu/debugfs.c b/arch/x86/kernel/cpu/debugfs.c index 3baf3e435834..10719aba6276 100644 --- a/arch/x86/kernel/cpu/debugfs.c +++ b/arch/x86/kernel/cpu/debugfs.c @@ -22,6 +22,7 @@ static int cpu_debug_show(struct seq_file *m, void *p) seq_printf(m, "die_id: %u\n", c->topo.die_id); seq_printf(m, "cu_id: %u\n", c->topo.cu_id); seq_printf(m, "core_id: %u\n", c->topo.core_id); + seq_printf(m, "cpu_type: %s\n", get_topology_cpu_type_name(c)); seq_printf(m, "logical_pkg_id: %u\n", c->topo.logical_pkg_id); seq_printf(m, "logical_die_id: %u\n", c->topo.logical_die_id); seq_printf(m, "llc_id: %u\n", c->topo.llc_id); diff --git a/arch/x86/kernel/cpu/scattered.c b/arch/x86/kernel/cpu/scattered.c index c84c30188fdf..307a91741534 100644 --- a/arch/x86/kernel/cpu/scattered.c +++ b/arch/x86/kernel/cpu/scattered.c @@ -45,13 +45,14 @@ static const struct cpuid_bit cpuid_bits[] = { { X86_FEATURE_HW_PSTATE, CPUID_EDX, 7, 0x80000007, 0 }, { X86_FEATURE_CPB, CPUID_EDX, 9, 0x80000007, 0 }, { X86_FEATURE_PROC_FEEDBACK, CPUID_EDX, 11, 0x80000007, 0 }, - { X86_FEATURE_FAST_CPPC, CPUID_EDX, 15, 0x80000007, 0 }, + { X86_FEATURE_AMD_FAST_CPPC, CPUID_EDX, 15, 0x80000007, 0 }, { X86_FEATURE_MBA, CPUID_EBX, 6, 0x80000008, 0 }, { X86_FEATURE_SMBA, CPUID_EBX, 2, 0x80000020, 0 }, { X86_FEATURE_BMEC, CPUID_EBX, 3, 0x80000020, 0 }, { X86_FEATURE_PERFMON_V2, CPUID_EAX, 0, 0x80000022, 0 }, { X86_FEATURE_AMD_LBR_V2, CPUID_EAX, 1, 0x80000022, 0 }, { X86_FEATURE_AMD_LBR_PMC_FREEZE, CPUID_EAX, 2, 0x80000022, 0 }, + { X86_FEATURE_AMD_HETEROGENEOUS_CORES, CPUID_EAX, 30, 0x80000026, 0 }, { 0, 0, 0, 0, 0 } }; diff --git a/arch/x86/kernel/cpu/topology_amd.c b/arch/x86/kernel/cpu/topology_amd.c index 7d476fa697ca..03b3c9c3a45e 100644 --- a/arch/x86/kernel/cpu/topology_amd.c +++ b/arch/x86/kernel/cpu/topology_amd.c @@ -182,6 +182,9 @@ static void parse_topology_amd(struct topo_scan *tscan) if (cpu_feature_enabled(X86_FEATURE_TOPOEXT)) has_topoext = cpu_parse_topology_ext(tscan); + if (cpu_feature_enabled(X86_FEATURE_AMD_HETEROGENEOUS_CORES)) + tscan->c->topo.cpu_type = cpuid_ebx(0x80000026); + if (!has_topoext && !parse_8000_0008(tscan)) return; diff --git a/arch/x86/kernel/cpu/topology_common.c b/arch/x86/kernel/cpu/topology_common.c index 9a6069e7133c..8277c64f88db 100644 --- a/arch/x86/kernel/cpu/topology_common.c +++ b/arch/x86/kernel/cpu/topology_common.c @@ -3,6 +3,7 @@ #include +#include #include #include #include @@ -27,6 +28,36 @@ void topology_set_dom(struct topo_scan *tscan, enum x86_topology_domains dom, } } +enum x86_topology_cpu_type get_topology_cpu_type(struct cpuinfo_x86 *c) +{ + if (c->x86_vendor == X86_VENDOR_INTEL) { + switch (c->topo.intel_type) { + case INTEL_CPU_TYPE_ATOM: return TOPO_CPU_TYPE_EFFICIENCY; + case INTEL_CPU_TYPE_CORE: return TOPO_CPU_TYPE_PERFORMANCE; + } + } + if (c->x86_vendor == X86_VENDOR_AMD) { + switch (c->topo.amd_type) { + case 0: return TOPO_CPU_TYPE_PERFORMANCE; + case 1: return TOPO_CPU_TYPE_EFFICIENCY; + } + } + + return TOPO_CPU_TYPE_UNKNOWN; +} + +const char *get_topology_cpu_type_name(struct cpuinfo_x86 *c) +{ + switch (get_topology_cpu_type(c)) { + case TOPO_CPU_TYPE_PERFORMANCE: + return "performance"; + case TOPO_CPU_TYPE_EFFICIENCY: + return "efficiency"; + default: + return "unknown"; + } +} + static unsigned int __maybe_unused parse_num_cores_legacy(struct cpuinfo_x86 *c) { struct { @@ -87,6 +118,7 @@ static void parse_topology(struct topo_scan *tscan, bool early) .cu_id = 0xff, .llc_id = BAD_APICID, .l2c_id = BAD_APICID, + .cpu_type = TOPO_CPU_TYPE_UNKNOWN, }; struct cpuinfo_x86 *c = tscan->c; struct { @@ -132,6 +164,8 @@ static void parse_topology(struct topo_scan *tscan, bool early) case X86_VENDOR_INTEL: if (!IS_ENABLED(CONFIG_CPU_SUP_INTEL) || !cpu_parse_topology_ext(tscan)) parse_legacy(tscan); + if (c->cpuid_level >= 0x1a) + c->topo.cpu_type = cpuid_eax(0x1a); break; case X86_VENDOR_HYGON: if (IS_ENABLED(CONFIG_CPU_SUP_HYGON)) diff --git a/arch/x86/kernel/smpboot.c b/arch/x86/kernel/smpboot.c index 766f092dab80..b5a8f0891135 100644 --- a/arch/x86/kernel/smpboot.c +++ b/arch/x86/kernel/smpboot.c @@ -497,8 +497,9 @@ static int x86_cluster_flags(void) static int x86_die_flags(void) { - if (cpu_feature_enabled(X86_FEATURE_HYBRID_CPU)) - return x86_sched_itmt_flags(); + if (cpu_feature_enabled(X86_FEATURE_HYBRID_CPU) || + cpu_feature_enabled(X86_FEATURE_AMD_HETEROGENEOUS_CORES)) + return x86_sched_itmt_flags(); return 0; } diff --git a/arch/x86/mm/init.c b/arch/x86/mm/init.c index eb503f53c319..101725c149c4 100644 --- a/arch/x86/mm/init.c +++ b/arch/x86/mm/init.c @@ -263,28 +263,33 @@ static void __init probe_page_size_mask(void) } /* - * INVLPG may not properly flush Global entries - * on these CPUs when PCIDs are enabled. + * INVLPG may not properly flush Global entries on + * these CPUs. New microcode fixes the issue. */ static const struct x86_cpu_id invlpg_miss_ids[] = { - X86_MATCH_VFM(INTEL_ALDERLAKE, 0), - X86_MATCH_VFM(INTEL_ALDERLAKE_L, 0), - X86_MATCH_VFM(INTEL_ATOM_GRACEMONT, 0), - X86_MATCH_VFM(INTEL_RAPTORLAKE, 0), - X86_MATCH_VFM(INTEL_RAPTORLAKE_P, 0), - X86_MATCH_VFM(INTEL_RAPTORLAKE_S, 0), + X86_MATCH_VFM(INTEL_ALDERLAKE, 0x2e), + X86_MATCH_VFM(INTEL_ALDERLAKE_L, 0x42c), + X86_MATCH_VFM(INTEL_ATOM_GRACEMONT, 0x11), + X86_MATCH_VFM(INTEL_RAPTORLAKE, 0x118), + X86_MATCH_VFM(INTEL_RAPTORLAKE_P, 0x4117), + X86_MATCH_VFM(INTEL_RAPTORLAKE_S, 0x2e), {} }; static void setup_pcid(void) { + const struct x86_cpu_id *invlpg_miss_match; + if (!IS_ENABLED(CONFIG_X86_64)) return; if (!boot_cpu_has(X86_FEATURE_PCID)) return; - if (x86_match_cpu(invlpg_miss_ids)) { + invlpg_miss_match = x86_match_cpu(invlpg_miss_ids); + + if (invlpg_miss_match && + boot_cpu_data.microcode < invlpg_miss_match->driver_data) { pr_info("Incomplete global flushes, disabling PCID"); setup_clear_cpu_cap(X86_FEATURE_PCID); return; diff --git a/drivers/cpufreq/amd-pstate-ut.c b/drivers/cpufreq/amd-pstate-ut.c index f66701514d90..a261d7300951 100644 --- a/drivers/cpufreq/amd-pstate-ut.c +++ b/drivers/cpufreq/amd-pstate-ut.c @@ -227,10 +227,10 @@ static void amd_pstate_ut_check_freq(u32 index) goto skip_test; } - if (cpudata->min_freq != policy->min) { + if (cpudata->lowest_nonlinear_freq != policy->min) { amd_pstate_ut_cases[index].result = AMD_PSTATE_UT_RESULT_FAIL; - pr_err("%s cpu%d cpudata_min_freq=%d policy_min=%d, they should be equal!\n", - __func__, cpu, cpudata->min_freq, policy->min); + pr_err("%s cpu%d cpudata_lowest_nonlinear_freq=%d policy_min=%d, they should be equal!\n", + __func__, cpu, cpudata->lowest_nonlinear_freq, policy->min); goto skip_test; } diff --git a/drivers/cpufreq/amd-pstate.c b/drivers/cpufreq/amd-pstate.c index b63863f77c67..d7630bab2516 100644 --- a/drivers/cpufreq/amd-pstate.c +++ b/drivers/cpufreq/amd-pstate.c @@ -233,7 +233,7 @@ static int amd_pstate_get_energy_pref_index(struct amd_cpudata *cpudata) return index; } -static void pstate_update_perf(struct amd_cpudata *cpudata, u32 min_perf, +static void msr_update_perf(struct amd_cpudata *cpudata, u32 min_perf, u32 des_perf, u32 max_perf, bool fast_switch) { if (fast_switch) @@ -243,7 +243,7 @@ static void pstate_update_perf(struct amd_cpudata *cpudata, u32 min_perf, READ_ONCE(cpudata->cppc_req_cached)); } -DEFINE_STATIC_CALL(amd_pstate_update_perf, pstate_update_perf); +DEFINE_STATIC_CALL(amd_pstate_update_perf, msr_update_perf); static inline void amd_pstate_update_perf(struct amd_cpudata *cpudata, u32 min_perf, u32 des_perf, @@ -306,11 +306,17 @@ static int amd_pstate_set_energy_pref_index(struct amd_cpudata *cpudata, return ret; } -static inline int pstate_enable(bool enable) +static inline int msr_cppc_enable(bool enable) { int ret, cpu; unsigned long logical_proc_id_mask = 0; + /* + * MSR_AMD_CPPC_ENABLE is write-once, once set it cannot be cleared. + */ + if (!enable) + return 0; + if (enable == cppc_enabled) return 0; @@ -332,7 +338,7 @@ static inline int pstate_enable(bool enable) return 0; } -static int cppc_enable(bool enable) +static int shmem_cppc_enable(bool enable) { int cpu, ret = 0; struct cppc_perf_ctrls perf_ctrls; @@ -359,14 +365,14 @@ static int cppc_enable(bool enable) return ret; } -DEFINE_STATIC_CALL(amd_pstate_enable, pstate_enable); +DEFINE_STATIC_CALL(amd_pstate_cppc_enable, msr_cppc_enable); -static inline int amd_pstate_enable(bool enable) +static inline int amd_pstate_cppc_enable(bool enable) { - return static_call(amd_pstate_enable)(enable); + return static_call(amd_pstate_cppc_enable)(enable); } -static int pstate_init_perf(struct amd_cpudata *cpudata) +static int msr_init_perf(struct amd_cpudata *cpudata) { u64 cap1; @@ -385,7 +391,7 @@ static int pstate_init_perf(struct amd_cpudata *cpudata) return 0; } -static int cppc_init_perf(struct amd_cpudata *cpudata) +static int shmem_init_perf(struct amd_cpudata *cpudata) { struct cppc_perf_caps cppc_perf; @@ -420,14 +426,14 @@ static int cppc_init_perf(struct amd_cpudata *cpudata) return ret; } -DEFINE_STATIC_CALL(amd_pstate_init_perf, pstate_init_perf); +DEFINE_STATIC_CALL(amd_pstate_init_perf, msr_init_perf); static inline int amd_pstate_init_perf(struct amd_cpudata *cpudata) { return static_call(amd_pstate_init_perf)(cpudata); } -static void cppc_update_perf(struct amd_cpudata *cpudata, +static void shmem_update_perf(struct amd_cpudata *cpudata, u32 min_perf, u32 des_perf, u32 max_perf, bool fast_switch) { @@ -527,9 +533,28 @@ static void amd_pstate_update(struct amd_cpudata *cpudata, u32 min_perf, cpufreq_cpu_put(policy); } -static int amd_pstate_verify(struct cpufreq_policy_data *policy) +static int amd_pstate_verify(struct cpufreq_policy_data *policy_data) { - cpufreq_verify_within_cpu_limits(policy); + /* + * Initialize lower frequency limit (i.e.policy->min) with + * lowest_nonlinear_frequency which is the most energy efficient + * frequency. Override the initial value set by cpufreq core and + * amd-pstate qos_requests. + */ + if (policy_data->min == FREQ_QOS_MIN_DEFAULT_VALUE) { + struct cpufreq_policy *policy = cpufreq_cpu_get(policy_data->cpu); + struct amd_cpudata *cpudata; + + if (!policy) + return -EINVAL; + + cpudata = policy->driver_data; + policy_data->min = cpudata->lowest_nonlinear_freq; + cpufreq_cpu_put(policy); + } + + cpufreq_verify_within_cpu_limits(policy_data); + pr_debug("policy_max =%d, policy_min=%d\n", policy_data->max, policy_data->min); return 0; } @@ -665,34 +690,12 @@ static void amd_pstate_adjust_perf(unsigned int cpu, static int amd_pstate_cpu_boost_update(struct cpufreq_policy *policy, bool on) { struct amd_cpudata *cpudata = policy->driver_data; - struct cppc_perf_ctrls perf_ctrls; - u32 highest_perf, nominal_perf, nominal_freq, max_freq; + u32 nominal_freq, max_freq; int ret = 0; - highest_perf = READ_ONCE(cpudata->highest_perf); - nominal_perf = READ_ONCE(cpudata->nominal_perf); nominal_freq = READ_ONCE(cpudata->nominal_freq); max_freq = READ_ONCE(cpudata->max_freq); - if (boot_cpu_has(X86_FEATURE_CPPC)) { - u64 value = READ_ONCE(cpudata->cppc_req_cached); - - value &= ~GENMASK_ULL(7, 0); - value |= on ? highest_perf : nominal_perf; - WRITE_ONCE(cpudata->cppc_req_cached, value); - - wrmsrl_on_cpu(cpudata->cpu, MSR_AMD_CPPC_REQ, value); - } else { - perf_ctrls.max_perf = on ? highest_perf : nominal_perf; - ret = cppc_set_perf(cpudata->cpu, &perf_ctrls); - if (ret) { - cpufreq_cpu_release(policy); - pr_debug("Failed to set max perf on CPU:%d. ret:%d\n", - cpudata->cpu, ret); - return ret; - } - } - if (on) policy->cpuinfo.max_freq = max_freq; else if (policy->cpuinfo.max_freq > nominal_freq * 1000) @@ -847,7 +850,7 @@ static u32 amd_pstate_get_transition_delay_us(unsigned int cpu) transition_delay_ns = cppc_get_transition_latency(cpu); if (transition_delay_ns == CPUFREQ_ETERNAL) { - if (cpu_feature_enabled(X86_FEATURE_FAST_CPPC)) + if (cpu_feature_enabled(X86_FEATURE_AMD_FAST_CPPC)) return AMD_PSTATE_FAST_CPPC_TRANSITION_DELAY; else return AMD_PSTATE_TRANSITION_DELAY; @@ -1001,7 +1004,7 @@ static int amd_pstate_cpu_init(struct cpufreq_policy *policy) policy->fast_switch_possible = true; ret = freq_qos_add_request(&policy->constraints, &cpudata->req[0], - FREQ_QOS_MIN, policy->cpuinfo.min_freq); + FREQ_QOS_MIN, FREQ_QOS_MIN_DEFAULT_VALUE); if (ret < 0) { dev_err(dev, "Failed to add min-freq constraint (%d)\n", ret); goto free_cpudata1; @@ -1045,7 +1048,7 @@ static int amd_pstate_cpu_resume(struct cpufreq_policy *policy) { int ret; - ret = amd_pstate_enable(true); + ret = amd_pstate_cppc_enable(true); if (ret) pr_err("failed to enable amd-pstate during resume, return %d\n", ret); @@ -1056,7 +1059,7 @@ static int amd_pstate_cpu_suspend(struct cpufreq_policy *policy) { int ret; - ret = amd_pstate_enable(false); + ret = amd_pstate_cppc_enable(false); if (ret) pr_err("failed to disable amd-pstate during suspend, return %d\n", ret); @@ -1189,25 +1192,41 @@ static ssize_t show_energy_performance_preference( static void amd_pstate_driver_cleanup(void) { - amd_pstate_enable(false); + amd_pstate_cppc_enable(false); cppc_state = AMD_PSTATE_DISABLE; current_pstate_driver = NULL; } +static int amd_pstate_set_driver(int mode_idx) +{ + if (mode_idx >= AMD_PSTATE_DISABLE && mode_idx < AMD_PSTATE_MAX) { + cppc_state = mode_idx; + if (cppc_state == AMD_PSTATE_DISABLE) + pr_info("driver is explicitly disabled\n"); + + if (cppc_state == AMD_PSTATE_ACTIVE) + current_pstate_driver = &amd_pstate_epp_driver; + + if (cppc_state == AMD_PSTATE_PASSIVE || cppc_state == AMD_PSTATE_GUIDED) + current_pstate_driver = &amd_pstate_driver; + + return 0; + } + + return -EINVAL; +} + static int amd_pstate_register_driver(int mode) { int ret; - if (mode == AMD_PSTATE_PASSIVE || mode == AMD_PSTATE_GUIDED) - current_pstate_driver = &amd_pstate_driver; - else if (mode == AMD_PSTATE_ACTIVE) - current_pstate_driver = &amd_pstate_epp_driver; - else - return -EINVAL; + ret = amd_pstate_set_driver(mode); + if (ret) + return ret; cppc_state = mode; - ret = amd_pstate_enable(true); + ret = amd_pstate_cppc_enable(true); if (ret) { pr_err("failed to enable cppc during amd-pstate driver registration, return %d\n", ret); @@ -1485,6 +1504,8 @@ static int amd_pstate_epp_cpu_init(struct cpufreq_policy *policy) WRITE_ONCE(cpudata->cppc_cap1_cached, value); } + current_pstate_driver->adjust_perf = NULL; + return 0; free_cpudata1: @@ -1507,26 +1528,13 @@ static void amd_pstate_epp_cpu_exit(struct cpufreq_policy *policy) static int amd_pstate_epp_update_limit(struct cpufreq_policy *policy) { struct amd_cpudata *cpudata = policy->driver_data; - u32 max_perf, min_perf, min_limit_perf, max_limit_perf; + u32 max_perf, min_perf; u64 value; s16 epp; - if (cpudata->boost_supported && !policy->boost_enabled) - max_perf = READ_ONCE(cpudata->nominal_perf); - else - max_perf = READ_ONCE(cpudata->highest_perf); + max_perf = READ_ONCE(cpudata->highest_perf); min_perf = READ_ONCE(cpudata->lowest_perf); - max_limit_perf = div_u64(policy->max * max_perf, policy->cpuinfo.max_freq); - min_limit_perf = div_u64(policy->min * max_perf, policy->cpuinfo.max_freq); - - if (min_limit_perf < min_perf) - min_limit_perf = min_perf; - - if (max_limit_perf < min_limit_perf) - max_limit_perf = min_limit_perf; - - WRITE_ONCE(cpudata->max_limit_perf, max_limit_perf); - WRITE_ONCE(cpudata->min_limit_perf, min_limit_perf); + amd_pstate_update_min_max_limit(policy); max_perf = clamp_t(unsigned long, max_perf, cpudata->min_limit_perf, cpudata->max_limit_perf); @@ -1535,7 +1543,7 @@ static int amd_pstate_epp_update_limit(struct cpufreq_policy *policy) value = READ_ONCE(cpudata->cppc_req_cached); if (cpudata->policy == CPUFREQ_POLICY_PERFORMANCE) - min_perf = max_perf; + min_perf = min(cpudata->nominal_perf, max_perf); /* Initial min/max values for CPPC Performance Controls Register */ value &= ~AMD_CPPC_MIN_PERF(~0L); @@ -1563,12 +1571,6 @@ static int amd_pstate_epp_update_limit(struct cpufreq_policy *policy) if (cpudata->policy == CPUFREQ_POLICY_PERFORMANCE) epp = 0; - /* Set initial EPP value */ - if (cpu_feature_enabled(X86_FEATURE_CPPC)) { - value &= ~GENMASK_ULL(31, 24); - value |= (u64)epp << 24; - } - WRITE_ONCE(cpudata->cppc_req_cached, value); return amd_pstate_set_epp(cpudata, epp); } @@ -1605,7 +1607,7 @@ static void amd_pstate_epp_reenable(struct amd_cpudata *cpudata) u64 value, max_perf; int ret; - ret = amd_pstate_enable(true); + ret = amd_pstate_cppc_enable(true); if (ret) pr_err("failed to enable amd pstate during resume, return %d\n", ret); @@ -1616,8 +1618,9 @@ static void amd_pstate_epp_reenable(struct amd_cpudata *cpudata) wrmsrl_on_cpu(cpudata->cpu, MSR_AMD_CPPC_REQ, value); } else { perf_ctrls.max_perf = max_perf; - perf_ctrls.energy_perf = AMD_CPPC_ENERGY_PERF_PREF(cpudata->epp_cached); cppc_set_perf(cpudata->cpu, &perf_ctrls); + perf_ctrls.energy_perf = AMD_CPPC_ENERGY_PERF_PREF(cpudata->epp_cached); + cppc_set_epp_perf(cpudata->cpu, &perf_ctrls, 1); } } @@ -1657,9 +1660,11 @@ static void amd_pstate_epp_offline(struct cpufreq_policy *policy) wrmsrl_on_cpu(cpudata->cpu, MSR_AMD_CPPC_REQ, value); } else { perf_ctrls.desired_perf = 0; + perf_ctrls.min_perf = min_perf; perf_ctrls.max_perf = min_perf; - perf_ctrls.energy_perf = AMD_CPPC_ENERGY_PERF_PREF(HWP_EPP_BALANCE_POWERSAVE); cppc_set_perf(cpudata->cpu, &perf_ctrls); + perf_ctrls.energy_perf = AMD_CPPC_ENERGY_PERF_PREF(HWP_EPP_BALANCE_POWERSAVE); + cppc_set_epp_perf(cpudata->cpu, &perf_ctrls, 1); } mutex_unlock(&amd_pstate_limits_lock); } @@ -1679,13 +1684,6 @@ static int amd_pstate_epp_cpu_offline(struct cpufreq_policy *policy) return 0; } -static int amd_pstate_epp_verify_policy(struct cpufreq_policy_data *policy) -{ - cpufreq_verify_within_cpu_limits(policy); - pr_debug("policy_max =%d, policy_min=%d\n", policy->max, policy->min); - return 0; -} - static int amd_pstate_epp_suspend(struct cpufreq_policy *policy) { struct amd_cpudata *cpudata = policy->driver_data; @@ -1699,7 +1697,7 @@ static int amd_pstate_epp_suspend(struct cpufreq_policy *policy) cpudata->suspended = true; /* disable CPPC in lowlevel firmware */ - ret = amd_pstate_enable(false); + ret = amd_pstate_cppc_enable(false); if (ret) pr_err("failed to suspend, return %d\n", ret); @@ -1741,7 +1739,7 @@ static struct cpufreq_driver amd_pstate_driver = { static struct cpufreq_driver amd_pstate_epp_driver = { .flags = CPUFREQ_CONST_LOOPS, - .verify = amd_pstate_epp_verify_policy, + .verify = amd_pstate_verify, .setpolicy = amd_pstate_epp_set_policy, .init = amd_pstate_epp_cpu_init, .exit = amd_pstate_epp_cpu_exit, @@ -1755,26 +1753,7 @@ static struct cpufreq_driver amd_pstate_epp_driver = { .attr = amd_pstate_epp_attr, }; -static int __init amd_pstate_set_driver(int mode_idx) -{ - if (mode_idx >= AMD_PSTATE_DISABLE && mode_idx < AMD_PSTATE_MAX) { - cppc_state = mode_idx; - if (cppc_state == AMD_PSTATE_DISABLE) - pr_info("driver is explicitly disabled\n"); - - if (cppc_state == AMD_PSTATE_ACTIVE) - current_pstate_driver = &amd_pstate_epp_driver; - - if (cppc_state == AMD_PSTATE_PASSIVE || cppc_state == AMD_PSTATE_GUIDED) - current_pstate_driver = &amd_pstate_driver; - - return 0; - } - - return -EINVAL; -} - -/** +/* * CPPC function is not supported for family ID 17H with model_ID ranging from 0x10 to 0x2F. * show the debug message that helps to check if the CPU has CPPC support for loading issue. */ @@ -1864,10 +1843,10 @@ static int __init amd_pstate_init(void) if (cppc_state == AMD_PSTATE_UNDEFINED) { /* Disable on the following configs by default: * 1. Undefined platforms - * 2. Server platforms + * 2. Server platforms with CPUs older than Family 0x1A. */ if (amd_pstate_acpi_pm_profile_undefined() || - amd_pstate_acpi_pm_profile_server()) { + (amd_pstate_acpi_pm_profile_server() && boot_cpu_data.x86 < 0x1A)) { pr_info("driver load is disabled, boot with specific mode to enable this\n"); return -ENODEV; } @@ -1875,50 +1854,31 @@ static int __init amd_pstate_init(void) cppc_state = CONFIG_X86_AMD_PSTATE_DEFAULT_MODE; } - switch (cppc_state) { - case AMD_PSTATE_DISABLE: + if (cppc_state == AMD_PSTATE_DISABLE) { pr_info("driver load is disabled, boot with specific mode to enable this\n"); return -ENODEV; - case AMD_PSTATE_PASSIVE: - case AMD_PSTATE_ACTIVE: - case AMD_PSTATE_GUIDED: - ret = amd_pstate_set_driver(cppc_state); - if (ret) - return ret; - break; - default: - return -EINVAL; } /* capability check */ if (cpu_feature_enabled(X86_FEATURE_CPPC)) { pr_debug("AMD CPPC MSR based functionality is supported\n"); - if (cppc_state != AMD_PSTATE_ACTIVE) - current_pstate_driver->adjust_perf = amd_pstate_adjust_perf; } else { pr_debug("AMD CPPC shared memory based functionality is supported\n"); - static_call_update(amd_pstate_enable, cppc_enable); - static_call_update(amd_pstate_init_perf, cppc_init_perf); - static_call_update(amd_pstate_update_perf, cppc_update_perf); + static_call_update(amd_pstate_cppc_enable, shmem_cppc_enable); + static_call_update(amd_pstate_init_perf, shmem_init_perf); + static_call_update(amd_pstate_update_perf, shmem_update_perf); } - if (amd_pstate_prefcore) { - ret = amd_detect_prefcore(&amd_pstate_prefcore); - if (ret) - return ret; - } - - /* enable amd pstate feature */ - ret = amd_pstate_enable(true); + ret = amd_pstate_register_driver(cppc_state); if (ret) { - pr_err("failed to enable driver mode(%d)\n", cppc_state); + pr_err("failed to register with return %d\n", ret); return ret; } - ret = cpufreq_register_driver(current_pstate_driver); - if (ret) { - pr_err("failed to register with return %d\n", ret); - goto disable_driver; + if (amd_pstate_prefcore) { + ret = amd_detect_prefcore(&amd_pstate_prefcore); + if (ret) + return ret; } dev_root = bus_get_dev_root(&cpu_subsys); @@ -1935,8 +1895,7 @@ static int __init amd_pstate_init(void) global_attr_free: cpufreq_unregister_driver(current_pstate_driver); -disable_driver: - amd_pstate_enable(false); + amd_pstate_cppc_enable(false); return ret; } device_initcall(amd_pstate_init); diff --git a/tools/arch/x86/include/asm/cpufeatures.h b/tools/arch/x86/include/asm/cpufeatures.h index dd4682857c12..23698d0f4bb4 100644 --- a/tools/arch/x86/include/asm/cpufeatures.h +++ b/tools/arch/x86/include/asm/cpufeatures.h @@ -472,7 +472,7 @@ #define X86_FEATURE_BHI_CTRL (21*32+ 2) /* BHI_DIS_S HW control available */ #define X86_FEATURE_CLEAR_BHB_HW (21*32+ 3) /* BHI_DIS_S HW control enabled */ #define X86_FEATURE_CLEAR_BHB_LOOP_ON_VMEXIT (21*32+ 4) /* Clear branch history at vmexit using SW loop */ -#define X86_FEATURE_FAST_CPPC (21*32 + 5) /* AMD Fast CPPC */ +#define X86_FEATURE_AMD_FAST_CPPC (21*32 + 5) /* AMD Fast CPPC */ /* * BUG word(s) -- 2.47.0 From 6c34d83a13cc89085c20da699633ac1f6b612596 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 11 Nov 2024 09:17:40 +0100 Subject: [PATCH 03/13] autofdo Signed-off-by: Peter Jung --- Documentation/dev-tools/autofdo.rst | 168 ++++++++++++++++++++++++++ Documentation/dev-tools/index.rst | 2 + Documentation/dev-tools/propeller.rst | 162 +++++++++++++++++++++++++ MAINTAINERS | 14 +++ Makefile | 2 + arch/Kconfig | 39 ++++++ arch/sparc/kernel/vmlinux.lds.S | 5 + arch/x86/Kconfig | 2 + arch/x86/kernel/vmlinux.lds.S | 4 + include/asm-generic/vmlinux.lds.h | 49 ++++++-- scripts/Makefile.autofdo | 24 ++++ scripts/Makefile.lib | 20 +++ scripts/Makefile.propeller | 28 +++++ tools/objtool/check.c | 2 + tools/objtool/elf.c | 15 ++- 15 files changed, 520 insertions(+), 16 deletions(-) create mode 100644 Documentation/dev-tools/autofdo.rst create mode 100644 Documentation/dev-tools/propeller.rst create mode 100644 scripts/Makefile.autofdo create mode 100644 scripts/Makefile.propeller diff --git a/Documentation/dev-tools/autofdo.rst b/Documentation/dev-tools/autofdo.rst new file mode 100644 index 000000000000..1f0a451e9ccd --- /dev/null +++ b/Documentation/dev-tools/autofdo.rst @@ -0,0 +1,168 @@ +.. SPDX-License-Identifier: GPL-2.0 + +=================================== +Using AutoFDO with the Linux kernel +=================================== + +This enables AutoFDO build support for the kernel when using +the Clang compiler. AutoFDO (Auto-Feedback-Directed Optimization) +is a type of profile-guided optimization (PGO) used to enhance the +performance of binary executables. It gathers information about the +frequency of execution of various code paths within a binary using +hardware sampling. This data is then used to guide the compiler's +optimization decisions, resulting in a more efficient binary. AutoFDO +is a powerful optimization technique, and data indicates that it can +significantly improve kernel performance. It's especially beneficial +for workloads affected by front-end stalls. + +For AutoFDO builds, unlike non-FDO builds, the user must supply a +profile. Acquiring an AutoFDO profile can be done in several ways. +AutoFDO profiles are created by converting hardware sampling using +the "perf" tool. It is crucial that the workload used to create these +perf files is representative; they must exhibit runtime +characteristics similar to the workloads that are intended to be +optimized. Failure to do so will result in the compiler optimizing +for the wrong objective. + +The AutoFDO profile often encapsulates the program's behavior. If the +performance-critical codes are architecture-independent, the profile +can be applied across platforms to achieve performance gains. For +instance, using the profile generated on Intel architecture to build +a kernel for AMD architecture can also yield performance improvements. + +There are two methods for acquiring a representative profile: +(1) Sample real workloads using a production environment. +(2) Generate the profile using a representative load test. +When enabling the AutoFDO build configuration without providing an +AutoFDO profile, the compiler only modifies the dwarf information in +the kernel without impacting runtime performance. It's advisable to +use a kernel binary built with the same AutoFDO configuration to +collect the perf profile. While it's possible to use a kernel built +with different options, it may result in inferior performance. + +One can collect profiles using AutoFDO build for the previous kernel. +AutoFDO employs relative line numbers to match the profiles, offering +some tolerance for source changes. This mode is commonly used in a +production environment for profile collection. + +In a profile collection based on a load test, the AutoFDO collection +process consists of the following steps: + +#. Initial build: The kernel is built with AutoFDO options + without a profile. + +#. Profiling: The above kernel is then run with a representative + workload to gather execution frequency data. This data is + collected using hardware sampling, via perf. AutoFDO is most + effective on platforms supporting advanced PMU features like + LBR on Intel machines. + +#. AutoFDO profile generation: Perf output file is converted to + the AutoFDO profile via offline tools. + +The support requires a Clang compiler LLVM 17 or later. + +Preparation +=========== + +Configure the kernel with:: + + CONFIG_AUTOFDO_CLANG=y + +Customization +============= + +The default CONFIG_AUTOFDO_CLANG setting covers kernel space objects for +AutoFDO builds. One can, however, enable or disable AutoFDO build for +individual files and directories by adding a line similar to the following +to the respective kernel Makefile: + +- For enabling a single file (e.g. foo.o) :: + + AUTOFDO_PROFILE_foo.o := y + +- For enabling all files in one directory :: + + AUTOFDO_PROFILE := y + +- For disabling one file :: + + AUTOFDO_PROFILE_foo.o := n + +- For disabling all files in one directory :: + + AUTOFDO_PROFILE := n + +Workflow +======== + +Here is an example workflow for AutoFDO kernel: + +1) Build the kernel on the host machine with LLVM enabled, + for example, :: + + $ make menuconfig LLVM=1 + + Turn on AutoFDO build config:: + + CONFIG_AUTOFDO_CLANG=y + + With a configuration that with LLVM enabled, use the following command:: + + $ scripts/config -e AUTOFDO_CLANG + + After getting the config, build with :: + + $ make LLVM=1 + +2) Install the kernel on the test machine. + +3) Run the load tests. The '-c' option in perf specifies the sample + event period. We suggest using a suitable prime number, like 500009, + for this purpose. + + - For Intel platforms:: + + $ perf record -e BR_INST_RETIRED.NEAR_TAKEN:k -a -N -b -c -o -- + + - For AMD platforms: + + The supported systems are: Zen3 with BRS, or Zen4 with amd_lbr_v2. To check, + + For Zen3:: + + $ cat proc/cpuinfo | grep " brs" + + For Zen4:: + + $ cat proc/cpuinfo | grep amd_lbr_v2 + + The following command generated the perf data file:: + + $ perf record --pfm-events RETIRED_TAKEN_BRANCH_INSTRUCTIONS:k -a -N -b -c -o -- + +4) (Optional) Download the raw perf file to the host machine. + +5) To generate an AutoFDO profile, two offline tools are available: + create_llvm_prof and llvm_profgen. The create_llvm_prof tool is part + of the AutoFDO project and can be found on GitHub + (https://github.com/google/autofdo), version v0.30.1 or later. + The llvm_profgen tool is included in the LLVM compiler itself. It's + important to note that the version of llvm_profgen doesn't need to match + the version of Clang. It needs to be the LLVM 19 release of Clang + or later, or just from the LLVM trunk. :: + + $ llvm-profgen --kernel --binary= --perfdata= -o + + or :: + + $ create_llvm_prof --binary= --profile= --format=extbinary --out= + + Note that multiple AutoFDO profile files can be merged into one via:: + + $ llvm-profdata merge -o ... + +6) Rebuild the kernel using the AutoFDO profile file with the same config as step 1, + (Note CONFIG_AUTOFDO_CLANG needs to be enabled):: + + $ make LLVM=1 CLANG_AUTOFDO_PROFILE= diff --git a/Documentation/dev-tools/index.rst b/Documentation/dev-tools/index.rst index 53d4d124f9c5..3c0ac08b2709 100644 --- a/Documentation/dev-tools/index.rst +++ b/Documentation/dev-tools/index.rst @@ -34,6 +34,8 @@ Documentation/dev-tools/testing-overview.rst ktap checkuapi gpio-sloppy-logic-analyzer + autofdo + propeller .. only:: subproject and html diff --git a/Documentation/dev-tools/propeller.rst b/Documentation/dev-tools/propeller.rst new file mode 100644 index 000000000000..92195958e3db --- /dev/null +++ b/Documentation/dev-tools/propeller.rst @@ -0,0 +1,162 @@ +.. SPDX-License-Identifier: GPL-2.0 + +===================================== +Using Propeller with the Linux kernel +===================================== + +This enables Propeller build support for the kernel when using Clang +compiler. Propeller is a profile-guided optimization (PGO) method used +to optimize binary executables. Like AutoFDO, it utilizes hardware +sampling to gather information about the frequency of execution of +different code paths within a binary. Unlike AutoFDO, this information +is then used right before linking phase to optimize (among others) +block layout within and across functions. + +A few important notes about adopting Propeller optimization: + +#. Although it can be used as a standalone optimization step, it is + strongly recommended to apply Propeller on top of AutoFDO, + AutoFDO+ThinLTO or Instrument FDO. The rest of this document + assumes this paradigm. + +#. Propeller uses another round of profiling on top of + AutoFDO/AutoFDO+ThinLTO/iFDO. The whole build process involves + "build-afdo - train-afdo - build-propeller - train-propeller - + build-optimized". + +#. Propeller requires LLVM 19 release or later for Clang/Clang++ + and the linker(ld.lld). + +#. In addition to LLVM toolchain, Propeller requires a profiling + conversion tool: https://github.com/google/autofdo with a release + after v0.30.1: https://github.com/google/autofdo/releases/tag/v0.30.1. + +The Propeller optimization process involves the following steps: + +#. Initial building: Build the AutoFDO or AutoFDO+ThinLTO binary as + you would normally do, but with a set of compile-time / link-time + flags, so that a special metadata section is created within the + kernel binary. The special section is only intend to be used by the + profiling tool, it is not part of the runtime image, nor does it + change kernel run time text sections. + +#. Profiling: The above kernel is then run with a representative + workload to gather execution frequency data. This data is collected + using hardware sampling, via perf. Propeller is most effective on + platforms supporting advanced PMU features like LBR on Intel + machines. This step is the same as profiling the kernel for AutoFDO + (the exact perf parameters can be different). + +#. Propeller profile generation: Perf output file is converted to a + pair of Propeller profiles via an offline tool. + +#. Optimized build: Build the AutoFDO or AutoFDO+ThinLTO optimized + binary as you would normally do, but with a compile-time / + link-time flag to pick up the Propeller compile time and link time + profiles. This build step uses 3 profiles - the AutoFDO profile, + the Propeller compile-time profile and the Propeller link-time + profile. + +#. Deployment: The optimized kernel binary is deployed and used + in production environments, providing improved performance + and reduced latency. + +Preparation +=========== + +Configure the kernel with:: + + CONFIG_AUTOFDO_CLANG=y + CONFIG_PROPELLER_CLANG=y + +Customization +============= + +The default CONFIG_PROPELLER_CLANG setting covers kernel space objects +for Propeller builds. One can, however, enable or disable Propeller build +for individual files and directories by adding a line similar to the +following to the respective kernel Makefile: + +- For enabling a single file (e.g. foo.o):: + + PROPELLER_PROFILE_foo.o := y + +- For enabling all files in one directory:: + + PROPELLER_PROFILE := y + +- For disabling one file:: + + PROPELLER_PROFILE_foo.o := n + +- For disabling all files in one directory:: + + PROPELLER__PROFILE := n + + +Workflow +======== + +Here is an example workflow for building an AutoFDO+Propeller kernel: + +1) Assuming an AutoFDO profile is already collected following + instructions in the AutoFDO document, build the kernel on the host + machine, with AutoFDO and Propeller build configs :: + + CONFIG_AUTOFDO_CLANG=y + CONFIG_PROPELLER_CLANG=y + + and :: + + $ make LLVM=1 CLANG_AUTOFDO_PROFILE= + +2) Install the kernel on the test machine. + +3) Run the load tests. The '-c' option in perf specifies the sample + event period. We suggest using a suitable prime number, like 500009, + for this purpose. + + - For Intel platforms:: + + $ perf record -e BR_INST_RETIRED.NEAR_TAKEN:k -a -N -b -c -o -- + + - For AMD platforms:: + + $ perf record --pfm-event RETIRED_TAKEN_BRANCH_INSTRUCTIONS:k -a -N -b -c -o -- + + Note you can repeat the above steps to collect multiple s. + +4) (Optional) Download the raw perf file(s) to the host machine. + +5) Use the create_llvm_prof tool (https://github.com/google/autofdo) to + generate Propeller profile. :: + + $ create_llvm_prof --binary= --profile= + --format=propeller --propeller_output_module_name + --out=_cc_profile.txt + --propeller_symorder=_ld_profile.txt + + "" can be something like "/home/user/dir/any_string". + + This command generates a pair of Propeller profiles: + "_cc_profile.txt" and + "_ld_profile.txt". + + If there are more than 1 perf_file collected in the previous step, + you can create a temp list file "" with each line + containing one perf file name and run:: + + $ create_llvm_prof --binary= --profile=@ + --format=propeller --propeller_output_module_name + --out=_cc_profile.txt + --propeller_symorder=_ld_profile.txt + +6) Rebuild the kernel using the AutoFDO and Propeller + profiles. :: + + CONFIG_AUTOFDO_CLANG=y + CONFIG_PROPELLER_CLANG=y + + and :: + + $ make LLVM=1 CLANG_AUTOFDO_PROFILE= CLANG_PROPELLER_PROFILE_PREFIX= diff --git a/MAINTAINERS b/MAINTAINERS index 5dc7d5839fe9..3d4709c29704 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -3674,6 +3674,13 @@ F: kernel/audit* F: lib/*audit.c K: \baudit_[a-z_0-9]\+\b +AUTOFDO BUILD +M: Rong Xu +M: Han Shen +S: Supported +F: Documentation/dev-tools/autofdo.rst +F: scripts/Makefile.autofdo + AUXILIARY BUS DRIVER M: Greg Kroah-Hartman R: Dave Ertman @@ -18505,6 +18512,13 @@ S: Maintained F: include/linux/psi* F: kernel/sched/psi.c +PROPELLER BUILD +M: Rong Xu +M: Han Shen +S: Supported +F: Documentation/dev-tools/propeller.rst +F: scripts/Makefile.propeller + PRINTK M: Petr Mladek R: Steven Rostedt diff --git a/Makefile b/Makefile index 79192a3024bf..e619df4e09b8 100644 --- a/Makefile +++ b/Makefile @@ -1018,6 +1018,8 @@ include-$(CONFIG_KMSAN) += scripts/Makefile.kmsan include-$(CONFIG_UBSAN) += scripts/Makefile.ubsan include-$(CONFIG_KCOV) += scripts/Makefile.kcov include-$(CONFIG_RANDSTRUCT) += scripts/Makefile.randstruct +include-$(CONFIG_AUTOFDO_CLANG) += scripts/Makefile.autofdo +include-$(CONFIG_PROPELLER_CLANG) += scripts/Makefile.propeller include-$(CONFIG_GCC_PLUGINS) += scripts/Makefile.gcc-plugins include $(addprefix $(srctree)/, $(include-y)) diff --git a/arch/Kconfig b/arch/Kconfig index bd9f095d69fa..00551f340dbe 100644 --- a/arch/Kconfig +++ b/arch/Kconfig @@ -811,6 +811,45 @@ config LTO_CLANG_THIN If unsure, say Y. endchoice +config ARCH_SUPPORTS_AUTOFDO_CLANG + bool + +config AUTOFDO_CLANG + bool "Enable Clang's AutoFDO build (EXPERIMENTAL)" + depends on ARCH_SUPPORTS_AUTOFDO_CLANG + depends on CC_IS_CLANG && CLANG_VERSION >= 170000 + help + This option enables Clang’s AutoFDO build. When + an AutoFDO profile is specified in variable + CLANG_AUTOFDO_PROFILE during the build process, + Clang uses the profile to optimize the kernel. + + If no profile is specified, AutoFDO options are + still passed to Clang to facilitate the collection + of perf data for creating an AutoFDO profile in + subsequent builds. + + If unsure, say N. + +config ARCH_SUPPORTS_PROPELLER_CLANG + bool + +config PROPELLER_CLANG + bool "Enable Clang's Propeller build" + depends on ARCH_SUPPORTS_PROPELLER_CLANG + depends on CC_IS_CLANG && CLANG_VERSION >= 190000 + help + This option enables Clang’s Propeller build. When the Propeller + profiles is specified in variable CLANG_PROPELLER_PROFILE_PREFIX + during the build process, Clang uses the profiles to optimize + the kernel. + + If no profile is specified, Propeller options are still passed + to Clang to facilitate the collection of perf data for creating + the Propeller profiles in subsequent builds. + + If unsure, say N. + config ARCH_SUPPORTS_CFI_CLANG bool help diff --git a/arch/sparc/kernel/vmlinux.lds.S b/arch/sparc/kernel/vmlinux.lds.S index d317a843f7ea..f1b86eb30340 100644 --- a/arch/sparc/kernel/vmlinux.lds.S +++ b/arch/sparc/kernel/vmlinux.lds.S @@ -48,6 +48,11 @@ SECTIONS { _text = .; HEAD_TEXT + ALIGN_FUNCTION(); +#ifdef CONFIG_SPARC64 + /* Match text section symbols in head_64.S first */ + *head_64.o(.text) +#endif TEXT_TEXT SCHED_TEXT LOCK_TEXT diff --git a/arch/x86/Kconfig b/arch/x86/Kconfig index 16354dfa6d96..89b8fc452a7c 100644 --- a/arch/x86/Kconfig +++ b/arch/x86/Kconfig @@ -126,6 +126,8 @@ config X86 select ARCH_SUPPORTS_LTO_CLANG select ARCH_SUPPORTS_LTO_CLANG_THIN select ARCH_SUPPORTS_RT + select ARCH_SUPPORTS_AUTOFDO_CLANG + select ARCH_SUPPORTS_PROPELLER_CLANG if X86_64 select ARCH_USE_BUILTIN_BSWAP select ARCH_USE_CMPXCHG_LOCKREF if X86_CMPXCHG64 select ARCH_USE_MEMTEST diff --git a/arch/x86/kernel/vmlinux.lds.S b/arch/x86/kernel/vmlinux.lds.S index b8c5741d2fb4..cf22081601ed 100644 --- a/arch/x86/kernel/vmlinux.lds.S +++ b/arch/x86/kernel/vmlinux.lds.S @@ -443,6 +443,10 @@ SECTIONS STABS_DEBUG DWARF_DEBUG +#ifdef CONFIG_PROPELLER_CLANG + .llvm_bb_addr_map : { *(.llvm_bb_addr_map) } +#endif + ELF_DETAILS DISCARDS diff --git a/include/asm-generic/vmlinux.lds.h b/include/asm-generic/vmlinux.lds.h index eeadbaeccf88..c995474e4c64 100644 --- a/include/asm-generic/vmlinux.lds.h +++ b/include/asm-generic/vmlinux.lds.h @@ -95,18 +95,25 @@ * With LTO_CLANG, the linker also splits sections by default, so we need * these macros to combine the sections during the final link. * + * With AUTOFDO_CLANG and PROPELLER_CLANG, by default, the linker splits + * text sections and regroups functions into subsections. + * * RODATA_MAIN is not used because existing code already defines .rodata.x * sections to be brought in with rodata. */ -#if defined(CONFIG_LD_DEAD_CODE_DATA_ELIMINATION) || defined(CONFIG_LTO_CLANG) +#if defined(CONFIG_LD_DEAD_CODE_DATA_ELIMINATION) || defined(CONFIG_LTO_CLANG) || \ +defined(CONFIG_AUTOFDO_CLANG) || defined(CONFIG_PROPELLER_CLANG) #define TEXT_MAIN .text .text.[0-9a-zA-Z_]* +#else +#define TEXT_MAIN .text +#endif +#if defined(CONFIG_LD_DEAD_CODE_DATA_ELIMINATION) || defined(CONFIG_LTO_CLANG) #define DATA_MAIN .data .data.[0-9a-zA-Z_]* .data..L* .data..compoundliteral* .data.$__unnamed_* .data.$L* #define SDATA_MAIN .sdata .sdata.[0-9a-zA-Z_]* #define RODATA_MAIN .rodata .rodata.[0-9a-zA-Z_]* .rodata..L* #define BSS_MAIN .bss .bss.[0-9a-zA-Z_]* .bss..L* .bss..compoundliteral* #define SBSS_MAIN .sbss .sbss.[0-9a-zA-Z_]* #else -#define TEXT_MAIN .text #define DATA_MAIN .data #define SDATA_MAIN .sdata #define RODATA_MAIN .rodata @@ -549,24 +556,44 @@ __cpuidle_text_end = .; \ __noinstr_text_end = .; +#define TEXT_SPLIT \ + __split_text_start = .; \ + *(.text.split .text.split.[0-9a-zA-Z_]*) \ + __split_text_end = .; + +#define TEXT_UNLIKELY \ + __unlikely_text_start = .; \ + *(.text.unlikely .text.unlikely.*) \ + __unlikely_text_end = .; + +#define TEXT_HOT \ + __hot_text_start = .; \ + *(.text.hot .text.hot.*) \ + __hot_text_end = .; + /* * .text section. Map to function alignment to avoid address changes * during second ld run in second ld pass when generating System.map * - * TEXT_MAIN here will match .text.fixup and .text.unlikely if dead - * code elimination is enabled, so these sections should be converted - * to use ".." first. + * TEXT_MAIN here will match symbols with a fixed pattern (for example, + * .text.hot or .text.unlikely) if dead code elimination or + * function-section is enabled. Match these symbols first before + * TEXT_MAIN to ensure they are grouped together. + * + * Also placing .text.hot section at the beginning of a page, this + * would help the TLB performance. */ #define TEXT_TEXT \ ALIGN_FUNCTION(); \ - *(.text.hot .text.hot.*) \ - *(TEXT_MAIN .text.fixup) \ - *(.text.unlikely .text.unlikely.*) \ + *(.text.asan.* .text.tsan.*) \ *(.text.unknown .text.unknown.*) \ + TEXT_SPLIT \ + TEXT_UNLIKELY \ + . = ALIGN(PAGE_SIZE); \ + TEXT_HOT \ + *(TEXT_MAIN .text.fixup) \ NOINSTR_TEXT \ - *(.ref.text) \ - *(.text.asan.* .text.tsan.*) - + *(.ref.text) /* sched.text is aling to function alignment to secure we have same * address even at second ld pass when generating System.map */ diff --git a/scripts/Makefile.autofdo b/scripts/Makefile.autofdo new file mode 100644 index 000000000000..1caf2457e585 --- /dev/null +++ b/scripts/Makefile.autofdo @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: GPL-2.0 + +# Enable available and selected Clang AutoFDO features. + +CFLAGS_AUTOFDO_CLANG := -fdebug-info-for-profiling -mllvm -enable-fs-discriminator=true -mllvm -improved-fs-discriminator=true + +ifndef CONFIG_DEBUG_INFO + CFLAGS_AUTOFDO_CLANG += -gmlt +endif + +ifdef CLANG_AUTOFDO_PROFILE + CFLAGS_AUTOFDO_CLANG += -fprofile-sample-use=$(CLANG_AUTOFDO_PROFILE) -ffunction-sections + CFLAGS_AUTOFDO_CLANG += -fsplit-machine-functions +endif + +ifdef CONFIG_LTO_CLANG_THIN + ifdef CLANG_AUTOFDO_PROFILE + KBUILD_LDFLAGS += --lto-sample-profile=$(CLANG_AUTOFDO_PROFILE) + endif + KBUILD_LDFLAGS += --mllvm=-enable-fs-discriminator=true --mllvm=-improved-fs-discriminator=true -plugin-opt=thinlto + KBUILD_LDFLAGS += -plugin-opt=-split-machine-functions +endif + +export CFLAGS_AUTOFDO_CLANG diff --git a/scripts/Makefile.lib b/scripts/Makefile.lib index 01a9f567d5af..e7859ad90224 100644 --- a/scripts/Makefile.lib +++ b/scripts/Makefile.lib @@ -191,6 +191,26 @@ _c_flags += $(if $(patsubst n%,, \ -D__KCSAN_INSTRUMENT_BARRIERS__) endif +# +# Enable AutoFDO build flags except some files or directories we don't want to +# enable (depends on variables AUTOFDO_PROFILE_obj.o and AUTOFDO_PROFILE). +# +ifeq ($(CONFIG_AUTOFDO_CLANG),y) +_c_flags += $(if $(patsubst n%,, \ + $(AUTOFDO_PROFILE_$(target-stem).o)$(AUTOFDO_PROFILE)$(is-kernel-object)), \ + $(CFLAGS_AUTOFDO_CLANG)) +endif + +# +# Enable Propeller build flags except some files or directories we don't want to +# enable (depends on variables AUTOFDO_PROPELLER_obj.o and PROPELLER_PROFILE). +# +ifdef CONFIG_PROPELLER_CLANG +_c_flags += $(if $(patsubst n%,, \ + $(AUTOFDO_PROFILE_$(target-stem).o)$(AUTOFDO_PROFILE)$(PROPELLER_PROFILE))$(is-kernel-object), \ + $(CFLAGS_PROPELLER_CLANG)) +endif + # $(src) for including checkin headers from generated source files # $(obj) for including generated headers from checkin source files ifeq ($(KBUILD_EXTMOD),) diff --git a/scripts/Makefile.propeller b/scripts/Makefile.propeller new file mode 100644 index 000000000000..344190717e47 --- /dev/null +++ b/scripts/Makefile.propeller @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: GPL-2.0 + +# Enable available and selected Clang Propeller features. +ifdef CLANG_PROPELLER_PROFILE_PREFIX + CFLAGS_PROPELLER_CLANG := -fbasic-block-sections=list=$(CLANG_PROPELLER_PROFILE_PREFIX)_cc_profile.txt -ffunction-sections + KBUILD_LDFLAGS += --symbol-ordering-file=$(CLANG_PROPELLER_PROFILE_PREFIX)_ld_profile.txt --no-warn-symbol-ordering +else + CFLAGS_PROPELLER_CLANG := -fbasic-block-sections=labels +endif + +# Propeller requires debug information to embed module names in the profiles. +# If CONFIG_DEBUG_INFO is not enabled, set -gmlt option. Skip this for AutoFDO, +# as the option should already be set. +ifndef CONFIG_DEBUG_INFO + ifndef CONFIG_AUTOFDO_CLANG + CFLAGS_PROPELLER_CLANG += -gmlt + endif +endif + +ifdef CONFIG_LTO_CLANG_THIN + ifdef CLANG_PROPELLER_PROFILE_PREFIX + KBUILD_LDFLAGS += --lto-basic-block-sections=$(CLANG_PROPELLER_PROFILE_PREFIX)_cc_profile.txt + else + KBUILD_LDFLAGS += --lto-basic-block-sections=labels + endif +endif + +export CFLAGS_PROPELLER_CLANG diff --git a/tools/objtool/check.c b/tools/objtool/check.c index 6604f5d038aa..05a0fb4a3d1a 100644 --- a/tools/objtool/check.c +++ b/tools/objtool/check.c @@ -4557,6 +4557,8 @@ static int validate_ibt(struct objtool_file *file) !strcmp(sec->name, "__jump_table") || !strcmp(sec->name, "__mcount_loc") || !strcmp(sec->name, ".kcfi_traps") || + !strcmp(sec->name, ".llvm.call-graph-profile") || + !strcmp(sec->name, ".llvm_bb_addr_map") || strstr(sec->name, "__patchable_function_entries")) continue; diff --git a/tools/objtool/elf.c b/tools/objtool/elf.c index 3d27983dc908..6f64d611faea 100644 --- a/tools/objtool/elf.c +++ b/tools/objtool/elf.c @@ -224,12 +224,17 @@ int find_symbol_hole_containing(const struct section *sec, unsigned long offset) if (n) return 0; /* not a hole */ - /* didn't find a symbol for which @offset is after it */ - if (!hole.sym) - return 0; /* not a hole */ + /* + * @offset >= sym->offset + sym->len, find symbol after it. + * When hole.sym is empty, use the first node to compute the hole. + * If there is no symbol in the section, the first node will be NULL, + * in which case, -1 is returned to skip the whole section. + */ + if (hole.sym) + n = rb_next(&hole.sym->node); + else + n = rb_first_cached(&sec->symbol_tree); - /* @offset >= sym->offset + sym->len, find symbol after it */ - n = rb_next(&hole.sym->node); if (!n) return -1; /* until end of address space */ -- 2.47.0 From 9f4066f41c5d80b408109ea740488da2cca89fcc Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 11 Nov 2024 09:17:53 +0100 Subject: [PATCH 04/13] bbr3 Signed-off-by: Peter Jung --- include/linux/tcp.h | 4 +- include/net/inet_connection_sock.h | 4 +- include/net/tcp.h | 72 +- include/uapi/linux/inet_diag.h | 23 + include/uapi/linux/rtnetlink.h | 4 +- include/uapi/linux/tcp.h | 1 + net/ipv4/Kconfig | 21 +- net/ipv4/bpf_tcp_ca.c | 9 +- net/ipv4/tcp.c | 3 + net/ipv4/tcp_bbr.c | 2230 +++++++++++++++++++++------- net/ipv4/tcp_cong.c | 1 + net/ipv4/tcp_input.c | 40 +- net/ipv4/tcp_minisocks.c | 2 + net/ipv4/tcp_output.c | 48 +- net/ipv4/tcp_rate.c | 30 +- net/ipv4/tcp_timer.c | 1 + 16 files changed, 1940 insertions(+), 553 deletions(-) diff --git a/include/linux/tcp.h b/include/linux/tcp.h index 6a5e08b937b3..27aab715490e 100644 --- a/include/linux/tcp.h +++ b/include/linux/tcp.h @@ -369,7 +369,9 @@ struct tcp_sock { u8 compressed_ack; u8 dup_ack_counter:2, tlp_retrans:1, /* TLP is a retransmission */ - unused:5; + fast_ack_mode:2, /* which fast ack mode ? */ + tlp_orig_data_app_limited:1, /* app-limited before TLP rtx? */ + unused:2; u8 thin_lto : 1,/* Use linear timeouts for thin streams */ fastopen_connect:1, /* FASTOPEN_CONNECT sockopt */ fastopen_no_cookie:1, /* Allow send/recv SYN+data without a cookie */ diff --git a/include/net/inet_connection_sock.h b/include/net/inet_connection_sock.h index c0deaafebfdc..d53f042d936e 100644 --- a/include/net/inet_connection_sock.h +++ b/include/net/inet_connection_sock.h @@ -137,8 +137,8 @@ struct inet_connection_sock { u32 icsk_probes_tstamp; u32 icsk_user_timeout; - u64 icsk_ca_priv[104 / sizeof(u64)]; -#define ICSK_CA_PRIV_SIZE sizeof_field(struct inet_connection_sock, icsk_ca_priv) +#define ICSK_CA_PRIV_SIZE (144) + u64 icsk_ca_priv[ICSK_CA_PRIV_SIZE / sizeof(u64)]; }; #define ICSK_TIME_RETRANS 1 /* Retransmit timer */ diff --git a/include/net/tcp.h b/include/net/tcp.h index d1948d357dad..7d99f0bec5f2 100644 --- a/include/net/tcp.h +++ b/include/net/tcp.h @@ -375,6 +375,8 @@ static inline void tcp_dec_quickack_mode(struct sock *sk) #define TCP_ECN_QUEUE_CWR 2 #define TCP_ECN_DEMAND_CWR 4 #define TCP_ECN_SEEN 8 +#define TCP_ECN_LOW 16 +#define TCP_ECN_ECT_PERMANENT 32 enum tcp_tw_status { TCP_TW_SUCCESS = 0, @@ -779,6 +781,15 @@ static inline void tcp_fast_path_check(struct sock *sk) u32 tcp_delack_max(const struct sock *sk); +static inline void tcp_set_ecn_low_from_dst(struct sock *sk, + const struct dst_entry *dst) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (dst_feature(dst, RTAX_FEATURE_ECN_LOW)) + tp->ecn_flags |= TCP_ECN_LOW; +} + /* Compute the actual rto_min value */ static inline u32 tcp_rto_min(const struct sock *sk) { @@ -884,6 +895,11 @@ static inline u32 tcp_stamp_us_delta(u64 t1, u64 t0) return max_t(s64, t1 - t0, 0); } +static inline u32 tcp_stamp32_us_delta(u32 t1, u32 t0) +{ + return max_t(s32, t1 - t0, 0); +} + /* provide the departure time in us unit */ static inline u64 tcp_skb_timestamp_us(const struct sk_buff *skb) { @@ -973,9 +989,14 @@ struct tcp_skb_cb { /* pkts S/ACKed so far upon tx of skb, incl retrans: */ __u32 delivered; /* start of send pipeline phase */ - u64 first_tx_mstamp; + u32 first_tx_mstamp; /* when we reached the "delivered" count */ - u64 delivered_mstamp; + u32 delivered_mstamp; +#define TCPCB_IN_FLIGHT_BITS 20 +#define TCPCB_IN_FLIGHT_MAX ((1U << TCPCB_IN_FLIGHT_BITS) - 1) + u32 in_flight:20, /* packets in flight at transmit */ + unused2:12; + u32 lost; /* packets lost so far upon tx of skb */ } tx; /* only used for outgoing skbs */ union { struct inet_skb_parm h4; @@ -1088,6 +1109,7 @@ enum tcp_ca_event { CA_EVENT_LOSS, /* loss timeout */ CA_EVENT_ECN_NO_CE, /* ECT set, but not CE marked */ CA_EVENT_ECN_IS_CE, /* received CE marked IP packet */ + CA_EVENT_TLP_RECOVERY, /* a lost segment was repaired by TLP probe */ }; /* Information about inbound ACK, passed to cong_ops->in_ack_event() */ @@ -1110,7 +1132,11 @@ enum tcp_ca_ack_event_flags { #define TCP_CONG_NON_RESTRICTED 0x1 /* Requires ECN/ECT set on all packets */ #define TCP_CONG_NEEDS_ECN 0x2 -#define TCP_CONG_MASK (TCP_CONG_NON_RESTRICTED | TCP_CONG_NEEDS_ECN) +/* Wants notification of CE events (CA_EVENT_ECN_IS_CE, CA_EVENT_ECN_NO_CE). */ +#define TCP_CONG_WANTS_CE_EVENTS 0x4 +#define TCP_CONG_MASK (TCP_CONG_NON_RESTRICTED | \ + TCP_CONG_NEEDS_ECN | \ + TCP_CONG_WANTS_CE_EVENTS) union tcp_cc_info; @@ -1130,10 +1156,13 @@ struct ack_sample { */ struct rate_sample { u64 prior_mstamp; /* starting timestamp for interval */ + u32 prior_lost; /* tp->lost at "prior_mstamp" */ u32 prior_delivered; /* tp->delivered at "prior_mstamp" */ u32 prior_delivered_ce;/* tp->delivered_ce at "prior_mstamp" */ + u32 tx_in_flight; /* packets in flight at starting timestamp */ + s32 lost; /* number of packets lost over interval */ s32 delivered; /* number of packets delivered over interval */ - s32 delivered_ce; /* number of packets delivered w/ CE marks*/ + s32 delivered_ce; /* packets delivered w/ CE mark over interval */ long interval_us; /* time for tp->delivered to incr "delivered" */ u32 snd_interval_us; /* snd interval for delivered packets */ u32 rcv_interval_us; /* rcv interval for delivered packets */ @@ -1144,7 +1173,9 @@ struct rate_sample { u32 last_end_seq; /* end_seq of most recently ACKed packet */ bool is_app_limited; /* is sample from packet with bubble in pipe? */ bool is_retrans; /* is sample from retransmission? */ + bool is_acking_tlp_retrans_seq; /* ACKed a TLP retransmit sequence? */ bool is_ack_delayed; /* is this (likely) a delayed ACK? */ + bool is_ece; /* did this ACK have ECN marked? */ }; struct tcp_congestion_ops { @@ -1168,8 +1199,11 @@ struct tcp_congestion_ops { /* hook for packet ack accounting (optional) */ void (*pkts_acked)(struct sock *sk, const struct ack_sample *sample); - /* override sysctl_tcp_min_tso_segs */ - u32 (*min_tso_segs)(struct sock *sk); + /* pick target number of segments per TSO/GSO skb (optional): */ + u32 (*tso_segs)(struct sock *sk, unsigned int mss_now); + + /* react to a specific lost skb (optional) */ + void (*skb_marked_lost)(struct sock *sk, const struct sk_buff *skb); /* call when packets are delivered to update cwnd and pacing rate, * after all the ca_state processing. (optional) @@ -1235,6 +1269,14 @@ static inline char *tcp_ca_get_name_by_key(u32 key, char *buffer) } #endif +static inline bool tcp_ca_wants_ce_events(const struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + + return icsk->icsk_ca_ops->flags & (TCP_CONG_NEEDS_ECN | + TCP_CONG_WANTS_CE_EVENTS); +} + static inline bool tcp_ca_needs_ecn(const struct sock *sk) { const struct inet_connection_sock *icsk = inet_csk(sk); @@ -1254,6 +1296,7 @@ static inline void tcp_ca_event(struct sock *sk, const enum tcp_ca_event event) void tcp_set_ca_state(struct sock *sk, const u8 ca_state); /* From tcp_rate.c */ +void tcp_set_tx_in_flight(struct sock *sk, struct sk_buff *skb); void tcp_rate_skb_sent(struct sock *sk, struct sk_buff *skb); void tcp_rate_skb_delivered(struct sock *sk, struct sk_buff *skb, struct rate_sample *rs); @@ -1266,6 +1309,21 @@ static inline bool tcp_skb_sent_after(u64 t1, u64 t2, u32 seq1, u32 seq2) return t1 > t2 || (t1 == t2 && after(seq1, seq2)); } +/* If a retransmit failed due to local qdisc congestion or other local issues, + * then we may have called tcp_set_skb_tso_segs() to increase the number of + * segments in the skb without increasing the tx.in_flight. In all other cases, + * the tx.in_flight should be at least as big as the pcount of the sk_buff. We + * do not have the state to know whether a retransmit failed due to local qdisc + * congestion or other local issues, so to avoid spurious warnings we consider + * that any skb marked lost may have suffered that fate. + */ +static inline bool tcp_skb_tx_in_flight_is_suspicious(u32 skb_pcount, + u32 skb_sacked_flags, + u32 tx_in_flight) +{ + return (skb_pcount > tx_in_flight) && !(skb_sacked_flags & TCPCB_LOST); +} + /* These functions determine how the current flow behaves in respect of SACK * handling. SACK is negotiated with the peer, and therefore it can vary * between different flows. @@ -2417,7 +2475,7 @@ struct tcp_plb_state { u8 consec_cong_rounds:5, /* consecutive congested rounds */ unused:3; u32 pause_until; /* jiffies32 when PLB can resume rerouting */ -}; +} __attribute__ ((__packed__)); static inline void tcp_plb_init(const struct sock *sk, struct tcp_plb_state *plb) diff --git a/include/uapi/linux/inet_diag.h b/include/uapi/linux/inet_diag.h index 86bb2e8b17c9..9d9a3eb2ce9b 100644 --- a/include/uapi/linux/inet_diag.h +++ b/include/uapi/linux/inet_diag.h @@ -229,6 +229,29 @@ struct tcp_bbr_info { __u32 bbr_min_rtt; /* min-filtered RTT in uSec */ __u32 bbr_pacing_gain; /* pacing gain shifted left 8 bits */ __u32 bbr_cwnd_gain; /* cwnd gain shifted left 8 bits */ + __u32 bbr_bw_hi_lsb; /* lower 32 bits of bw_hi */ + __u32 bbr_bw_hi_msb; /* upper 32 bits of bw_hi */ + __u32 bbr_bw_lo_lsb; /* lower 32 bits of bw_lo */ + __u32 bbr_bw_lo_msb; /* upper 32 bits of bw_lo */ + __u8 bbr_mode; /* current bbr_mode in state machine */ + __u8 bbr_phase; /* current state machine phase */ + __u8 unused1; /* alignment padding; not used yet */ + __u8 bbr_version; /* BBR algorithm version */ + __u32 bbr_inflight_lo; /* lower short-term data volume bound */ + __u32 bbr_inflight_hi; /* higher long-term data volume bound */ + __u32 bbr_extra_acked; /* max excess packets ACKed in epoch */ +}; + +/* TCP BBR congestion control bbr_phase as reported in netlink/ss stats. */ +enum tcp_bbr_phase { + BBR_PHASE_INVALID = 0, + BBR_PHASE_STARTUP = 1, + BBR_PHASE_DRAIN = 2, + BBR_PHASE_PROBE_RTT = 3, + BBR_PHASE_PROBE_BW_UP = 4, + BBR_PHASE_PROBE_BW_DOWN = 5, + BBR_PHASE_PROBE_BW_CRUISE = 6, + BBR_PHASE_PROBE_BW_REFILL = 7, }; union tcp_cc_info { diff --git a/include/uapi/linux/rtnetlink.h b/include/uapi/linux/rtnetlink.h index 3b687d20c9ed..a7c30c243b54 100644 --- a/include/uapi/linux/rtnetlink.h +++ b/include/uapi/linux/rtnetlink.h @@ -507,12 +507,14 @@ enum { #define RTAX_FEATURE_TIMESTAMP (1 << 2) /* unused */ #define RTAX_FEATURE_ALLFRAG (1 << 3) /* unused */ #define RTAX_FEATURE_TCP_USEC_TS (1 << 4) +#define RTAX_FEATURE_ECN_LOW (1 << 5) #define RTAX_FEATURE_MASK (RTAX_FEATURE_ECN | \ RTAX_FEATURE_SACK | \ RTAX_FEATURE_TIMESTAMP | \ RTAX_FEATURE_ALLFRAG | \ - RTAX_FEATURE_TCP_USEC_TS) + RTAX_FEATURE_TCP_USEC_TS | \ + RTAX_FEATURE_ECN_LOW) struct rta_session { __u8 proto; diff --git a/include/uapi/linux/tcp.h b/include/uapi/linux/tcp.h index dbf896f3146c..4702cd2f1ffc 100644 --- a/include/uapi/linux/tcp.h +++ b/include/uapi/linux/tcp.h @@ -178,6 +178,7 @@ enum tcp_fastopen_client_fail { #define TCPI_OPT_ECN_SEEN 16 /* we received at least one packet with ECT */ #define TCPI_OPT_SYN_DATA 32 /* SYN-ACK acked data in SYN sent or rcvd */ #define TCPI_OPT_USEC_TS 64 /* usec timestamps */ +#define TCPI_OPT_ECN_LOW 128 /* Low-latency ECN configured at init */ /* * Sender's congestion state indicating normal or abnormal situations diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig index 6d2c97f8e9ef..ddc116ef22cb 100644 --- a/net/ipv4/Kconfig +++ b/net/ipv4/Kconfig @@ -669,15 +669,18 @@ config TCP_CONG_BBR default n help - BBR (Bottleneck Bandwidth and RTT) TCP congestion control aims to - maximize network utilization and minimize queues. It builds an explicit - model of the bottleneck delivery rate and path round-trip propagation - delay. It tolerates packet loss and delay unrelated to congestion. It - can operate over LAN, WAN, cellular, wifi, or cable modem links. It can - coexist with flows that use loss-based congestion control, and can - operate with shallow buffers, deep buffers, bufferbloat, policers, or - AQM schemes that do not provide a delay signal. It requires the fq - ("Fair Queue") pacing packet scheduler. + BBR (Bottleneck Bandwidth and RTT) TCP congestion control is a + model-based congestion control algorithm that aims to maximize + network utilization, keep queues and retransmit rates low, and to be + able to coexist with Reno/CUBIC in common scenarios. It builds an + explicit model of the network path. It tolerates a targeted degree + of random packet loss and delay. It can operate over LAN, WAN, + cellular, wifi, or cable modem links, and can use shallow-threshold + ECN signals. It can coexist to some degree with flows that use + loss-based congestion control, and can operate with shallow buffers, + deep buffers, bufferbloat, policers, or AQM schemes that do not + provide a delay signal. It requires pacing, using either TCP internal + pacing or the fq ("Fair Queue") pacing packet scheduler. choice prompt "Default TCP congestion control" diff --git a/net/ipv4/bpf_tcp_ca.c b/net/ipv4/bpf_tcp_ca.c index 554804774628..2279e6e7bc9c 100644 --- a/net/ipv4/bpf_tcp_ca.c +++ b/net/ipv4/bpf_tcp_ca.c @@ -280,11 +280,15 @@ static void bpf_tcp_ca_pkts_acked(struct sock *sk, const struct ack_sample *samp { } -static u32 bpf_tcp_ca_min_tso_segs(struct sock *sk) +static u32 bpf_tcp_ca_tso_segs(struct sock *sk, unsigned int mss_now) { return 0; } +static void bpf_tcp_ca_skb_marked_lost(struct sock *sk, const struct sk_buff *skb) +{ +} + static void bpf_tcp_ca_cong_control(struct sock *sk, u32 ack, int flag, const struct rate_sample *rs) { @@ -315,7 +319,8 @@ static struct tcp_congestion_ops __bpf_ops_tcp_congestion_ops = { .cwnd_event = bpf_tcp_ca_cwnd_event, .in_ack_event = bpf_tcp_ca_in_ack_event, .pkts_acked = bpf_tcp_ca_pkts_acked, - .min_tso_segs = bpf_tcp_ca_min_tso_segs, + .tso_segs = bpf_tcp_ca_tso_segs, + .skb_marked_lost = bpf_tcp_ca_skb_marked_lost, .cong_control = bpf_tcp_ca_cong_control, .undo_cwnd = bpf_tcp_ca_undo_cwnd, .sndbuf_expand = bpf_tcp_ca_sndbuf_expand, diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 4f77bd862e95..fd3a5551eda7 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -3384,6 +3384,7 @@ int tcp_disconnect(struct sock *sk, int flags) tp->rx_opt.dsack = 0; tp->rx_opt.num_sacks = 0; tp->rcv_ooopack = 0; + tp->fast_ack_mode = 0; /* Clean up fastopen related fields */ @@ -4110,6 +4111,8 @@ void tcp_get_info(struct sock *sk, struct tcp_info *info) info->tcpi_options |= TCPI_OPT_ECN; if (tp->ecn_flags & TCP_ECN_SEEN) info->tcpi_options |= TCPI_OPT_ECN_SEEN; + if (tp->ecn_flags & TCP_ECN_LOW) + info->tcpi_options |= TCPI_OPT_ECN_LOW; if (tp->syn_data_acked) info->tcpi_options |= TCPI_OPT_SYN_DATA; if (tp->tcp_usec_ts) diff --git a/net/ipv4/tcp_bbr.c b/net/ipv4/tcp_bbr.c index 760941e55153..a180fa648d5e 100644 --- a/net/ipv4/tcp_bbr.c +++ b/net/ipv4/tcp_bbr.c @@ -1,18 +1,19 @@ -/* Bottleneck Bandwidth and RTT (BBR) congestion control +/* BBR (Bottleneck Bandwidth and RTT) congestion control * - * BBR congestion control computes the sending rate based on the delivery - * rate (throughput) estimated from ACKs. In a nutshell: + * BBR is a model-based congestion control algorithm that aims for low queues, + * low loss, and (bounded) Reno/CUBIC coexistence. To maintain a model of the + * network path, it uses measurements of bandwidth and RTT, as well as (if they + * occur) packet loss and/or shallow-threshold ECN signals. Note that although + * it can use ECN or loss signals explicitly, it does not require either; it + * can bound its in-flight data based on its estimate of the BDP. * - * On each ACK, update our model of the network path: - * bottleneck_bandwidth = windowed_max(delivered / elapsed, 10 round trips) - * min_rtt = windowed_min(rtt, 10 seconds) - * pacing_rate = pacing_gain * bottleneck_bandwidth - * cwnd = max(cwnd_gain * bottleneck_bandwidth * min_rtt, 4) - * - * The core algorithm does not react directly to packet losses or delays, - * although BBR may adjust the size of next send per ACK when loss is - * observed, or adjust the sending rate if it estimates there is a - * traffic policer, in order to keep the drop rate reasonable. + * The model has both higher and lower bounds for the operating range: + * lo: bw_lo, inflight_lo: conservative short-term lower bound + * hi: bw_hi, inflight_hi: robust long-term upper bound + * The bandwidth-probing time scale is (a) extended dynamically based on + * estimated BDP to improve coexistence with Reno/CUBIC; (b) bounded by + * an interactive wall-clock time-scale to be more scalable and responsive + * than Reno and CUBIC. * * Here is a state transition diagram for BBR: * @@ -65,6 +66,13 @@ #include #include +#include +#include "tcp_dctcp.h" + +#define BBR_VERSION 3 + +#define bbr_param(sk,name) (bbr_ ## name) + /* Scale factor for rate in pkt/uSec unit to avoid truncation in bandwidth * estimation. The rate unit ~= (1500 bytes / 1 usec / 2^24) ~= 715 bps. * This handles bandwidths from 0.06pps (715bps) to 256Mpps (3Tbps) in a u32. @@ -85,36 +93,41 @@ enum bbr_mode { BBR_PROBE_RTT, /* cut inflight to min to probe min_rtt */ }; +/* How does the incoming ACK stream relate to our bandwidth probing? */ +enum bbr_ack_phase { + BBR_ACKS_INIT, /* not probing; not getting probe feedback */ + BBR_ACKS_REFILLING, /* sending at est. bw to fill pipe */ + BBR_ACKS_PROBE_STARTING, /* inflight rising to probe bw */ + BBR_ACKS_PROBE_FEEDBACK, /* getting feedback from bw probing */ + BBR_ACKS_PROBE_STOPPING, /* stopped probing; still getting feedback */ +}; + /* BBR congestion control block */ struct bbr { u32 min_rtt_us; /* min RTT in min_rtt_win_sec window */ u32 min_rtt_stamp; /* timestamp of min_rtt_us */ u32 probe_rtt_done_stamp; /* end time for BBR_PROBE_RTT mode */ - struct minmax bw; /* Max recent delivery rate in pkts/uS << 24 */ - u32 rtt_cnt; /* count of packet-timed rounds elapsed */ + u32 probe_rtt_min_us; /* min RTT in probe_rtt_win_ms win */ + u32 probe_rtt_min_stamp; /* timestamp of probe_rtt_min_us*/ u32 next_rtt_delivered; /* scb->tx.delivered at end of round */ u64 cycle_mstamp; /* time of this cycle phase start */ - u32 mode:3, /* current bbr_mode in state machine */ + u32 mode:2, /* current bbr_mode in state machine */ prev_ca_state:3, /* CA state on previous ACK */ - packet_conservation:1, /* use packet conservation? */ round_start:1, /* start of packet-timed tx->ack round? */ + ce_state:1, /* If most recent data has CE bit set */ + bw_probe_up_rounds:5, /* cwnd-limited rounds in PROBE_UP */ + try_fast_path:1, /* can we take fast path? */ idle_restart:1, /* restarting after idle? */ probe_rtt_round_done:1, /* a BBR_PROBE_RTT round at 4 pkts? */ - unused:13, - lt_is_sampling:1, /* taking long-term ("LT") samples now? */ - lt_rtt_cnt:7, /* round trips in long-term interval */ - lt_use_bw:1; /* use lt_bw as our bw estimate? */ - u32 lt_bw; /* LT est delivery rate in pkts/uS << 24 */ - u32 lt_last_delivered; /* LT intvl start: tp->delivered */ - u32 lt_last_stamp; /* LT intvl start: tp->delivered_mstamp */ - u32 lt_last_lost; /* LT intvl start: tp->lost */ + init_cwnd:7, /* initial cwnd */ + unused_1:10; u32 pacing_gain:10, /* current gain for setting pacing rate */ cwnd_gain:10, /* current gain for setting cwnd */ full_bw_reached:1, /* reached full bw in Startup? */ full_bw_cnt:2, /* number of rounds without large bw gains */ - cycle_idx:3, /* current index in pacing_gain cycle array */ + cycle_idx:2, /* current index in pacing_gain cycle array */ has_seen_rtt:1, /* have we seen an RTT sample yet? */ - unused_b:5; + unused_2:6; u32 prior_cwnd; /* prior cwnd upon entering loss recovery */ u32 full_bw; /* recent bw, to estimate if pipe is full */ @@ -124,19 +137,67 @@ struct bbr { u32 ack_epoch_acked:20, /* packets (S)ACKed in sampling epoch */ extra_acked_win_rtts:5, /* age of extra_acked, in round trips */ extra_acked_win_idx:1, /* current index in extra_acked array */ - unused_c:6; + /* BBR v3 state: */ + full_bw_now:1, /* recently reached full bw plateau? */ + startup_ecn_rounds:2, /* consecutive hi ECN STARTUP rounds */ + loss_in_cycle:1, /* packet loss in this cycle? */ + ecn_in_cycle:1, /* ECN in this cycle? */ + unused_3:1; + u32 loss_round_delivered; /* scb->tx.delivered ending loss round */ + u32 undo_bw_lo; /* bw_lo before latest losses */ + u32 undo_inflight_lo; /* inflight_lo before latest losses */ + u32 undo_inflight_hi; /* inflight_hi before latest losses */ + u32 bw_latest; /* max delivered bw in last round trip */ + u32 bw_lo; /* lower bound on sending bandwidth */ + u32 bw_hi[2]; /* max recent measured bw sample */ + u32 inflight_latest; /* max delivered data in last round trip */ + u32 inflight_lo; /* lower bound of inflight data range */ + u32 inflight_hi; /* upper bound of inflight data range */ + u32 bw_probe_up_cnt; /* packets delivered per inflight_hi incr */ + u32 bw_probe_up_acks; /* packets (S)ACKed since inflight_hi incr */ + u32 probe_wait_us; /* PROBE_DOWN until next clock-driven probe */ + u32 prior_rcv_nxt; /* tp->rcv_nxt when CE state last changed */ + u32 ecn_eligible:1, /* sender can use ECN (RTT, handshake)? */ + ecn_alpha:9, /* EWMA delivered_ce/delivered; 0..256 */ + bw_probe_samples:1, /* rate samples reflect bw probing? */ + prev_probe_too_high:1, /* did last PROBE_UP go too high? */ + stopped_risky_probe:1, /* last PROBE_UP stopped due to risk? */ + rounds_since_probe:8, /* packet-timed rounds since probed bw */ + loss_round_start:1, /* loss_round_delivered round trip? */ + loss_in_round:1, /* loss marked in this round trip? */ + ecn_in_round:1, /* ECN marked in this round trip? */ + ack_phase:3, /* bbr_ack_phase: meaning of ACKs */ + loss_events_in_round:4,/* losses in STARTUP round */ + initialized:1; /* has bbr_init() been called? */ + u32 alpha_last_delivered; /* tp->delivered at alpha update */ + u32 alpha_last_delivered_ce; /* tp->delivered_ce at alpha update */ + + u8 unused_4; /* to preserve alignment */ + struct tcp_plb_state plb; }; -#define CYCLE_LEN 8 /* number of phases in a pacing gain cycle */ +struct bbr_context { + u32 sample_bw; +}; -/* Window length of bw filter (in rounds): */ -static const int bbr_bw_rtts = CYCLE_LEN + 2; /* Window length of min_rtt filter (in sec): */ static const u32 bbr_min_rtt_win_sec = 10; /* Minimum time (in ms) spent at bbr_cwnd_min_target in BBR_PROBE_RTT mode: */ static const u32 bbr_probe_rtt_mode_ms = 200; -/* Skip TSO below the following bandwidth (bits/sec): */ -static const int bbr_min_tso_rate = 1200000; +/* Window length of probe_rtt_min_us filter (in ms), and consequently the + * typical interval between PROBE_RTT mode entries. The default is 5000ms. + * Note that bbr_probe_rtt_win_ms must be <= bbr_min_rtt_win_sec * MSEC_PER_SEC + */ +static const u32 bbr_probe_rtt_win_ms = 5000; +/* Proportion of cwnd to estimated BDP in PROBE_RTT, in units of BBR_UNIT: */ +static const u32 bbr_probe_rtt_cwnd_gain = BBR_UNIT * 1 / 2; + +/* Use min_rtt to help adapt TSO burst size, with smaller min_rtt resulting + * in bigger TSO bursts. We cut the RTT-based allowance in half + * for every 2^9 usec (aka 512 us) of RTT, so that the RTT-based allowance + * is below 1500 bytes after 6 * ~500 usec = 3ms. + */ +static const u32 bbr_tso_rtt_shift = 9; /* Pace at ~1% below estimated bw, on average, to reduce queue at bottleneck. * In order to help drive the network toward lower queues and low latency while @@ -146,13 +207,15 @@ static const int bbr_min_tso_rate = 1200000; */ static const int bbr_pacing_margin_percent = 1; -/* We use a high_gain value of 2/ln(2) because it's the smallest pacing gain +/* We use a startup_pacing_gain of 4*ln(2) because it's the smallest value * that will allow a smoothly increasing pacing rate that will double each RTT * and send the same number of packets per RTT that an un-paced, slow-starting * Reno or CUBIC flow would: */ -static const int bbr_high_gain = BBR_UNIT * 2885 / 1000 + 1; -/* The pacing gain of 1/high_gain in BBR_DRAIN is calculated to typically drain +static const int bbr_startup_pacing_gain = BBR_UNIT * 277 / 100 + 1; +/* The gain for deriving startup cwnd: */ +static const int bbr_startup_cwnd_gain = BBR_UNIT * 2; +/* The pacing gain in BBR_DRAIN is calculated to typically drain * the queue created in BBR_STARTUP in a single round: */ static const int bbr_drain_gain = BBR_UNIT * 1000 / 2885; @@ -160,13 +223,17 @@ static const int bbr_drain_gain = BBR_UNIT * 1000 / 2885; static const int bbr_cwnd_gain = BBR_UNIT * 2; /* The pacing_gain values for the PROBE_BW gain cycle, to discover/share bw: */ static const int bbr_pacing_gain[] = { - BBR_UNIT * 5 / 4, /* probe for more available bw */ - BBR_UNIT * 3 / 4, /* drain queue and/or yield bw to other flows */ - BBR_UNIT, BBR_UNIT, BBR_UNIT, /* cruise at 1.0*bw to utilize pipe, */ - BBR_UNIT, BBR_UNIT, BBR_UNIT /* without creating excess queue... */ + BBR_UNIT * 5 / 4, /* UP: probe for more available bw */ + BBR_UNIT * 91 / 100, /* DOWN: drain queue and/or yield bw */ + BBR_UNIT, /* CRUISE: try to use pipe w/ some headroom */ + BBR_UNIT, /* REFILL: refill pipe to estimated 100% */ +}; +enum bbr_pacing_gain_phase { + BBR_BW_PROBE_UP = 0, /* push up inflight to probe for bw/vol */ + BBR_BW_PROBE_DOWN = 1, /* drain excess inflight from the queue */ + BBR_BW_PROBE_CRUISE = 2, /* use pipe, w/ headroom in queue/pipe */ + BBR_BW_PROBE_REFILL = 3, /* v2: refill the pipe again to 100% */ }; -/* Randomize the starting gain cycling phase over N phases: */ -static const u32 bbr_cycle_rand = 7; /* Try to keep at least this many packets in flight, if things go smoothly. For * smooth functioning, a sliding window protocol ACKing every other packet @@ -174,24 +241,12 @@ static const u32 bbr_cycle_rand = 7; */ static const u32 bbr_cwnd_min_target = 4; -/* To estimate if BBR_STARTUP mode (i.e. high_gain) has filled pipe... */ +/* To estimate if BBR_STARTUP or BBR_BW_PROBE_UP has filled pipe... */ /* If bw has increased significantly (1.25x), there may be more bw available: */ static const u32 bbr_full_bw_thresh = BBR_UNIT * 5 / 4; /* But after 3 rounds w/o significant bw growth, estimate pipe is full: */ static const u32 bbr_full_bw_cnt = 3; -/* "long-term" ("LT") bandwidth estimator parameters... */ -/* The minimum number of rounds in an LT bw sampling interval: */ -static const u32 bbr_lt_intvl_min_rtts = 4; -/* If lost/delivered ratio > 20%, interval is "lossy" and we may be policed: */ -static const u32 bbr_lt_loss_thresh = 50; -/* If 2 intervals have a bw ratio <= 1/8, their bw is "consistent": */ -static const u32 bbr_lt_bw_ratio = BBR_UNIT / 8; -/* If 2 intervals have a bw diff <= 4 Kbit/sec their bw is "consistent": */ -static const u32 bbr_lt_bw_diff = 4000 / 8; -/* If we estimate we're policed, use lt_bw for this many round trips: */ -static const u32 bbr_lt_bw_max_rtts = 48; - /* Gain factor for adding extra_acked to target cwnd: */ static const int bbr_extra_acked_gain = BBR_UNIT; /* Window length of extra_acked window. */ @@ -201,8 +256,121 @@ static const u32 bbr_ack_epoch_acked_reset_thresh = 1U << 20; /* Time period for clamping cwnd increment due to ack aggregation */ static const u32 bbr_extra_acked_max_us = 100 * 1000; +/* Flags to control BBR ECN-related behavior... */ + +/* Ensure ACKs only ACK packets with consistent ECN CE status? */ +static const bool bbr_precise_ece_ack = true; + +/* Max RTT (in usec) at which to use sender-side ECN logic. + * Disabled when 0 (ECN allowed at any RTT). + */ +static const u32 bbr_ecn_max_rtt_us = 5000; + +/* On losses, scale down inflight and pacing rate by beta scaled by BBR_SCALE. + * No loss response when 0. + */ +static const u32 bbr_beta = BBR_UNIT * 30 / 100; + +/* Gain factor for ECN mark ratio samples, scaled by BBR_SCALE (1/16 = 6.25%) */ +static const u32 bbr_ecn_alpha_gain = BBR_UNIT * 1 / 16; + +/* The initial value for ecn_alpha; 1.0 allows a flow to respond quickly + * to congestion if the bottleneck is congested when the flow starts up. + */ +static const u32 bbr_ecn_alpha_init = BBR_UNIT; + +/* On ECN, cut inflight_lo to (1 - ecn_factor * ecn_alpha) scaled by BBR_SCALE. + * No ECN based bounding when 0. + */ +static const u32 bbr_ecn_factor = BBR_UNIT * 1 / 3; /* 1/3 = 33% */ + +/* Estimate bw probing has gone too far if CE ratio exceeds this threshold. + * Scaled by BBR_SCALE. Disabled when 0. + */ +static const u32 bbr_ecn_thresh = BBR_UNIT * 1 / 2; /* 1/2 = 50% */ + +/* If non-zero, if in a cycle with no losses but some ECN marks, after ECN + * clears then make the first round's increment to inflight_hi the following + * fraction of inflight_hi. + */ +static const u32 bbr_ecn_reprobe_gain = BBR_UNIT * 1 / 2; + +/* Estimate bw probing has gone too far if loss rate exceeds this level. */ +static const u32 bbr_loss_thresh = BBR_UNIT * 2 / 100; /* 2% loss */ + +/* Slow down for a packet loss recovered by TLP? */ +static const bool bbr_loss_probe_recovery = true; + +/* Exit STARTUP if number of loss marking events in a Recovery round is >= N, + * and loss rate is higher than bbr_loss_thresh. + * Disabled if 0. + */ +static const u32 bbr_full_loss_cnt = 6; + +/* Exit STARTUP if number of round trips with ECN mark rate above ecn_thresh + * meets this count. + */ +static const u32 bbr_full_ecn_cnt = 2; + +/* Fraction of unutilized headroom to try to leave in path upon high loss. */ +static const u32 bbr_inflight_headroom = BBR_UNIT * 15 / 100; + +/* How much do we increase cwnd_gain when probing for bandwidth in + * BBR_BW_PROBE_UP? This specifies the increment in units of + * BBR_UNIT/4. The default is 1, meaning 0.25. + * The min value is 0 (meaning 0.0); max is 3 (meaning 0.75). + */ +static const u32 bbr_bw_probe_cwnd_gain = 1; + +/* Max number of packet-timed rounds to wait before probing for bandwidth. If + * we want to tolerate 1% random loss per round, and not have this cut our + * inflight too much, we must probe for bw periodically on roughly this scale. + * If low, limits Reno/CUBIC coexistence; if high, limits loss tolerance. + * We aim to be fair with Reno/CUBIC up to a BDP of at least: + * BDP = 25Mbps * .030sec /(1514bytes) = 61.9 packets + */ +static const u32 bbr_bw_probe_max_rounds = 63; + +/* Max amount of randomness to inject in round counting for Reno-coexistence. + */ +static const u32 bbr_bw_probe_rand_rounds = 2; + +/* Use BBR-native probe time scale starting at this many usec. + * We aim to be fair with Reno/CUBIC up to an inter-loss time epoch of at least: + * BDP*RTT = 25Mbps * .030sec /(1514bytes) * 0.030sec = 1.9 secs + */ +static const u32 bbr_bw_probe_base_us = 2 * USEC_PER_SEC; /* 2 secs */ + +/* Use BBR-native probes spread over this many usec: */ +static const u32 bbr_bw_probe_rand_us = 1 * USEC_PER_SEC; /* 1 secs */ + +/* Use fast path if app-limited, no loss/ECN, and target cwnd was reached? */ +static const bool bbr_fast_path = true; + +/* Use fast ack mode? */ +static const bool bbr_fast_ack_mode = true; + +static u32 bbr_max_bw(const struct sock *sk); +static u32 bbr_bw(const struct sock *sk); +static void bbr_exit_probe_rtt(struct sock *sk); +static void bbr_reset_congestion_signals(struct sock *sk); +static void bbr_run_loss_probe_recovery(struct sock *sk); + static void bbr_check_probe_rtt_done(struct sock *sk); +/* This connection can use ECN if both endpoints have signaled ECN support in + * the handshake and the per-route settings indicated this is a + * shallow-threshold ECN environment, meaning both: + * (a) ECN CE marks indicate low-latency/shallow-threshold congestion, and + * (b) TCP endpoints provide precise ACKs that only ACK data segments + * with consistent ECN CE status + */ +static bool bbr_can_use_ecn(const struct sock *sk) +{ + return (tcp_sk(sk)->ecn_flags & TCP_ECN_OK) && + (tcp_sk(sk)->ecn_flags & TCP_ECN_LOW); +} + /* Do we estimate that STARTUP filled the pipe? */ static bool bbr_full_bw_reached(const struct sock *sk) { @@ -214,17 +382,17 @@ static bool bbr_full_bw_reached(const struct sock *sk) /* Return the windowed max recent bandwidth sample, in pkts/uS << BW_SCALE. */ static u32 bbr_max_bw(const struct sock *sk) { - struct bbr *bbr = inet_csk_ca(sk); + const struct bbr *bbr = inet_csk_ca(sk); - return minmax_get(&bbr->bw); + return max(bbr->bw_hi[0], bbr->bw_hi[1]); } /* Return the estimated bandwidth of the path, in pkts/uS << BW_SCALE. */ static u32 bbr_bw(const struct sock *sk) { - struct bbr *bbr = inet_csk_ca(sk); + const struct bbr *bbr = inet_csk_ca(sk); - return bbr->lt_use_bw ? bbr->lt_bw : bbr_max_bw(sk); + return min(bbr_max_bw(sk), bbr->bw_lo); } /* Return maximum extra acked in past k-2k round trips, @@ -241,15 +409,23 @@ static u16 bbr_extra_acked(const struct sock *sk) * The order here is chosen carefully to avoid overflow of u64. This should * work for input rates of up to 2.9Tbit/sec and gain of 2.89x. */ -static u64 bbr_rate_bytes_per_sec(struct sock *sk, u64 rate, int gain) +static u64 bbr_rate_bytes_per_sec(struct sock *sk, u64 rate, int gain, + int margin) { unsigned int mss = tcp_sk(sk)->mss_cache; rate *= mss; rate *= gain; rate >>= BBR_SCALE; - rate *= USEC_PER_SEC / 100 * (100 - bbr_pacing_margin_percent); - return rate >> BW_SCALE; + rate *= USEC_PER_SEC / 100 * (100 - margin); + rate >>= BW_SCALE; + rate = max(rate, 1ULL); + return rate; +} + +static u64 bbr_bw_bytes_per_sec(struct sock *sk, u64 rate) +{ + return bbr_rate_bytes_per_sec(sk, rate, BBR_UNIT, 0); } /* Convert a BBR bw and gain factor to a pacing rate in bytes per second. */ @@ -257,12 +433,13 @@ static unsigned long bbr_bw_to_pacing_rate(struct sock *sk, u32 bw, int gain) { u64 rate = bw; - rate = bbr_rate_bytes_per_sec(sk, rate, gain); + rate = bbr_rate_bytes_per_sec(sk, rate, gain, + bbr_pacing_margin_percent); rate = min_t(u64, rate, READ_ONCE(sk->sk_max_pacing_rate)); return rate; } -/* Initialize pacing rate to: high_gain * init_cwnd / RTT. */ +/* Initialize pacing rate to: startup_pacing_gain * init_cwnd / RTT. */ static void bbr_init_pacing_rate_from_rtt(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); @@ -279,7 +456,7 @@ static void bbr_init_pacing_rate_from_rtt(struct sock *sk) bw = (u64)tcp_snd_cwnd(tp) * BW_UNIT; do_div(bw, rtt_us); WRITE_ONCE(sk->sk_pacing_rate, - bbr_bw_to_pacing_rate(sk, bw, bbr_high_gain)); + bbr_bw_to_pacing_rate(sk, bw, bbr_param(sk, startup_pacing_gain))); } /* Pace using current bw estimate and a gain factor. */ @@ -295,26 +472,48 @@ static void bbr_set_pacing_rate(struct sock *sk, u32 bw, int gain) WRITE_ONCE(sk->sk_pacing_rate, rate); } -/* override sysctl_tcp_min_tso_segs */ -__bpf_kfunc static u32 bbr_min_tso_segs(struct sock *sk) +/* Return the number of segments BBR would like in a TSO/GSO skb, given a + * particular max gso size as a constraint. TODO: make this simpler and more + * consistent by switching bbr to just call tcp_tso_autosize(). + */ +static u32 bbr_tso_segs_generic(struct sock *sk, unsigned int mss_now, + u32 gso_max_size) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 segs, r; + u64 bytes; + + /* Budget a TSO/GSO burst size allowance based on bw (pacing_rate). */ + bytes = READ_ONCE(sk->sk_pacing_rate) >> READ_ONCE(sk->sk_pacing_shift); + + /* Budget a TSO/GSO burst size allowance based on min_rtt. For every + * K = 2^tso_rtt_shift microseconds of min_rtt, halve the burst. + * The min_rtt-based burst allowance is: 64 KBytes / 2^(min_rtt/K) + */ + if (bbr_param(sk, tso_rtt_shift)) { + r = bbr->min_rtt_us >> bbr_param(sk, tso_rtt_shift); + if (r < BITS_PER_TYPE(u32)) /* prevent undefined behavior */ + bytes += GSO_LEGACY_MAX_SIZE >> r; + } + + bytes = min_t(u32, bytes, gso_max_size - 1 - MAX_TCP_HEADER); + segs = max_t(u32, bytes / mss_now, + sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs); + return segs; +} + +/* Custom tcp_tso_autosize() for BBR, used at transmit time to cap skb size. */ +__bpf_kfunc static u32 bbr_tso_segs(struct sock *sk, unsigned int mss_now) { - return READ_ONCE(sk->sk_pacing_rate) < (bbr_min_tso_rate >> 3) ? 1 : 2; + return bbr_tso_segs_generic(sk, mss_now, sk->sk_gso_max_size); } +/* Like bbr_tso_segs(), using mss_cache, ignoring driver's sk_gso_max_size. */ static u32 bbr_tso_segs_goal(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); - u32 segs, bytes; - - /* Sort of tcp_tso_autosize() but ignoring - * driver provided sk_gso_max_size. - */ - bytes = min_t(unsigned long, - READ_ONCE(sk->sk_pacing_rate) >> READ_ONCE(sk->sk_pacing_shift), - GSO_LEGACY_MAX_SIZE - 1 - MAX_TCP_HEADER); - segs = max_t(u32, bytes / tp->mss_cache, bbr_min_tso_segs(sk)); - return min(segs, 0x7FU); + return bbr_tso_segs_generic(sk, tp->mss_cache, GSO_LEGACY_MAX_SIZE); } /* Save "last known good" cwnd so we can restore it after losses or PROBE_RTT */ @@ -334,7 +533,9 @@ __bpf_kfunc static void bbr_cwnd_event(struct sock *sk, enum tcp_ca_event event) struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - if (event == CA_EVENT_TX_START && tp->app_limited) { + if (event == CA_EVENT_TX_START) { + if (!tp->app_limited) + return; bbr->idle_restart = 1; bbr->ack_epoch_mstamp = tp->tcp_mstamp; bbr->ack_epoch_acked = 0; @@ -345,6 +546,16 @@ __bpf_kfunc static void bbr_cwnd_event(struct sock *sk, enum tcp_ca_event event) bbr_set_pacing_rate(sk, bbr_bw(sk), BBR_UNIT); else if (bbr->mode == BBR_PROBE_RTT) bbr_check_probe_rtt_done(sk); + } else if ((event == CA_EVENT_ECN_IS_CE || + event == CA_EVENT_ECN_NO_CE) && + bbr_can_use_ecn(sk) && + bbr_param(sk, precise_ece_ack)) { + u32 state = bbr->ce_state; + dctcp_ece_ack_update(sk, event, &bbr->prior_rcv_nxt, &state); + bbr->ce_state = state; + } else if (event == CA_EVENT_TLP_RECOVERY && + bbr_param(sk, loss_probe_recovery)) { + bbr_run_loss_probe_recovery(sk); } } @@ -367,10 +578,10 @@ static u32 bbr_bdp(struct sock *sk, u32 bw, int gain) * default. This should only happen when the connection is not using TCP * timestamps and has retransmitted all of the SYN/SYNACK/data packets * ACKed so far. In this case, an RTO can cut cwnd to 1, in which - * case we need to slow-start up toward something safe: TCP_INIT_CWND. + * case we need to slow-start up toward something safe: initial cwnd. */ if (unlikely(bbr->min_rtt_us == ~0U)) /* no valid RTT samples yet? */ - return TCP_INIT_CWND; /* be safe: cap at default initial cwnd*/ + return bbr->init_cwnd; /* be safe: cap at initial cwnd */ w = (u64)bw * bbr->min_rtt_us; @@ -387,23 +598,23 @@ static u32 bbr_bdp(struct sock *sk, u32 bw, int gain) * - one skb in sending host Qdisc, * - one skb in sending host TSO/GSO engine * - one skb being received by receiver host LRO/GRO/delayed-ACK engine - * Don't worry, at low rates (bbr_min_tso_rate) this won't bloat cwnd because - * in such cases tso_segs_goal is 1. The minimum cwnd is 4 packets, + * Don't worry, at low rates this won't bloat cwnd because + * in such cases tso_segs_goal is small. The minimum cwnd is 4 packets, * which allows 2 outstanding 2-packet sequences, to try to keep pipe * full even with ACK-every-other-packet delayed ACKs. */ static u32 bbr_quantization_budget(struct sock *sk, u32 cwnd) { struct bbr *bbr = inet_csk_ca(sk); + u32 tso_segs_goal; - /* Allow enough full-sized skbs in flight to utilize end systems. */ - cwnd += 3 * bbr_tso_segs_goal(sk); - - /* Reduce delayed ACKs by rounding up cwnd to the next even number. */ - cwnd = (cwnd + 1) & ~1U; + tso_segs_goal = 3 * bbr_tso_segs_goal(sk); + /* Allow enough full-sized skbs in flight to utilize end systems. */ + cwnd = max_t(u32, cwnd, tso_segs_goal); + cwnd = max_t(u32, cwnd, bbr_param(sk, cwnd_min_target)); /* Ensure gain cycling gets inflight above BDP even for small BDPs. */ - if (bbr->mode == BBR_PROBE_BW && bbr->cycle_idx == 0) + if (bbr->mode == BBR_PROBE_BW && bbr->cycle_idx == BBR_BW_PROBE_UP) cwnd += 2; return cwnd; @@ -458,10 +669,10 @@ static u32 bbr_ack_aggregation_cwnd(struct sock *sk) { u32 max_aggr_cwnd, aggr_cwnd = 0; - if (bbr_extra_acked_gain && bbr_full_bw_reached(sk)) { + if (bbr_param(sk, extra_acked_gain)) { max_aggr_cwnd = ((u64)bbr_bw(sk) * bbr_extra_acked_max_us) / BW_UNIT; - aggr_cwnd = (bbr_extra_acked_gain * bbr_extra_acked(sk)) + aggr_cwnd = (bbr_param(sk, extra_acked_gain) * bbr_extra_acked(sk)) >> BBR_SCALE; aggr_cwnd = min(aggr_cwnd, max_aggr_cwnd); } @@ -469,66 +680,27 @@ static u32 bbr_ack_aggregation_cwnd(struct sock *sk) return aggr_cwnd; } -/* An optimization in BBR to reduce losses: On the first round of recovery, we - * follow the packet conservation principle: send P packets per P packets acked. - * After that, we slow-start and send at most 2*P packets per P packets acked. - * After recovery finishes, or upon undo, we restore the cwnd we had when - * recovery started (capped by the target cwnd based on estimated BDP). - * - * TODO(ycheng/ncardwell): implement a rate-based approach. - */ -static bool bbr_set_cwnd_to_recover_or_restore( - struct sock *sk, const struct rate_sample *rs, u32 acked, u32 *new_cwnd) +/* Returns the cwnd for PROBE_RTT mode. */ +static u32 bbr_probe_rtt_cwnd(struct sock *sk) { - struct tcp_sock *tp = tcp_sk(sk); - struct bbr *bbr = inet_csk_ca(sk); - u8 prev_state = bbr->prev_ca_state, state = inet_csk(sk)->icsk_ca_state; - u32 cwnd = tcp_snd_cwnd(tp); - - /* An ACK for P pkts should release at most 2*P packets. We do this - * in two steps. First, here we deduct the number of lost packets. - * Then, in bbr_set_cwnd() we slow start up toward the target cwnd. - */ - if (rs->losses > 0) - cwnd = max_t(s32, cwnd - rs->losses, 1); - - if (state == TCP_CA_Recovery && prev_state != TCP_CA_Recovery) { - /* Starting 1st round of Recovery, so do packet conservation. */ - bbr->packet_conservation = 1; - bbr->next_rtt_delivered = tp->delivered; /* start round now */ - /* Cut unused cwnd from app behavior, TSQ, or TSO deferral: */ - cwnd = tcp_packets_in_flight(tp) + acked; - } else if (prev_state >= TCP_CA_Recovery && state < TCP_CA_Recovery) { - /* Exiting loss recovery; restore cwnd saved before recovery. */ - cwnd = max(cwnd, bbr->prior_cwnd); - bbr->packet_conservation = 0; - } - bbr->prev_ca_state = state; - - if (bbr->packet_conservation) { - *new_cwnd = max(cwnd, tcp_packets_in_flight(tp) + acked); - return true; /* yes, using packet conservation */ - } - *new_cwnd = cwnd; - return false; + return max_t(u32, bbr_param(sk, cwnd_min_target), + bbr_bdp(sk, bbr_bw(sk), bbr_param(sk, probe_rtt_cwnd_gain))); } /* Slow-start up toward target cwnd (if bw estimate is growing, or packet loss * has drawn us down below target), or snap down to target if we're above it. */ static void bbr_set_cwnd(struct sock *sk, const struct rate_sample *rs, - u32 acked, u32 bw, int gain) + u32 acked, u32 bw, int gain, u32 cwnd, + struct bbr_context *ctx) { struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - u32 cwnd = tcp_snd_cwnd(tp), target_cwnd = 0; + u32 target_cwnd = 0; if (!acked) goto done; /* no packet fully ACKed; just apply caps */ - if (bbr_set_cwnd_to_recover_or_restore(sk, rs, acked, &cwnd)) - goto done; - target_cwnd = bbr_bdp(sk, bw, gain); /* Increment the cwnd to account for excess ACKed data that seems @@ -537,74 +709,26 @@ static void bbr_set_cwnd(struct sock *sk, const struct rate_sample *rs, target_cwnd += bbr_ack_aggregation_cwnd(sk); target_cwnd = bbr_quantization_budget(sk, target_cwnd); - /* If we're below target cwnd, slow start cwnd toward target cwnd. */ - if (bbr_full_bw_reached(sk)) /* only cut cwnd if we filled the pipe */ - cwnd = min(cwnd + acked, target_cwnd); - else if (cwnd < target_cwnd || tp->delivered < TCP_INIT_CWND) - cwnd = cwnd + acked; - cwnd = max(cwnd, bbr_cwnd_min_target); + /* Update cwnd and enable fast path if cwnd reaches target_cwnd. */ + bbr->try_fast_path = 0; + if (bbr_full_bw_reached(sk)) { /* only cut cwnd if we filled the pipe */ + cwnd += acked; + if (cwnd >= target_cwnd) { + cwnd = target_cwnd; + bbr->try_fast_path = 1; + } + } else if (cwnd < target_cwnd || cwnd < 2 * bbr->init_cwnd) { + cwnd += acked; + } else { + bbr->try_fast_path = 1; + } + cwnd = max_t(u32, cwnd, bbr_param(sk, cwnd_min_target)); done: - tcp_snd_cwnd_set(tp, min(cwnd, tp->snd_cwnd_clamp)); /* apply global cap */ + tcp_snd_cwnd_set(tp, min(cwnd, tp->snd_cwnd_clamp)); /* global cap */ if (bbr->mode == BBR_PROBE_RTT) /* drain queue, refresh min_rtt */ - tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp), bbr_cwnd_min_target)); -} - -/* End cycle phase if it's time and/or we hit the phase's in-flight target. */ -static bool bbr_is_next_cycle_phase(struct sock *sk, - const struct rate_sample *rs) -{ - struct tcp_sock *tp = tcp_sk(sk); - struct bbr *bbr = inet_csk_ca(sk); - bool is_full_length = - tcp_stamp_us_delta(tp->delivered_mstamp, bbr->cycle_mstamp) > - bbr->min_rtt_us; - u32 inflight, bw; - - /* The pacing_gain of 1.0 paces at the estimated bw to try to fully - * use the pipe without increasing the queue. - */ - if (bbr->pacing_gain == BBR_UNIT) - return is_full_length; /* just use wall clock time */ - - inflight = bbr_packets_in_net_at_edt(sk, rs->prior_in_flight); - bw = bbr_max_bw(sk); - - /* A pacing_gain > 1.0 probes for bw by trying to raise inflight to at - * least pacing_gain*BDP; this may take more than min_rtt if min_rtt is - * small (e.g. on a LAN). We do not persist if packets are lost, since - * a path with small buffers may not hold that much. - */ - if (bbr->pacing_gain > BBR_UNIT) - return is_full_length && - (rs->losses || /* perhaps pacing_gain*BDP won't fit */ - inflight >= bbr_inflight(sk, bw, bbr->pacing_gain)); - - /* A pacing_gain < 1.0 tries to drain extra queue we added if bw - * probing didn't find more bw. If inflight falls to match BDP then we - * estimate queue is drained; persisting would underutilize the pipe. - */ - return is_full_length || - inflight <= bbr_inflight(sk, bw, BBR_UNIT); -} - -static void bbr_advance_cycle_phase(struct sock *sk) -{ - struct tcp_sock *tp = tcp_sk(sk); - struct bbr *bbr = inet_csk_ca(sk); - - bbr->cycle_idx = (bbr->cycle_idx + 1) & (CYCLE_LEN - 1); - bbr->cycle_mstamp = tp->delivered_mstamp; -} - -/* Gain cycling: cycle pacing gain to converge to fair share of available bw. */ -static void bbr_update_cycle_phase(struct sock *sk, - const struct rate_sample *rs) -{ - struct bbr *bbr = inet_csk_ca(sk); - - if (bbr->mode == BBR_PROBE_BW && bbr_is_next_cycle_phase(sk, rs)) - bbr_advance_cycle_phase(sk); + tcp_snd_cwnd_set(tp, min_t(u32, tcp_snd_cwnd(tp), + bbr_probe_rtt_cwnd(sk))); } static void bbr_reset_startup_mode(struct sock *sk) @@ -614,191 +738,49 @@ static void bbr_reset_startup_mode(struct sock *sk) bbr->mode = BBR_STARTUP; } -static void bbr_reset_probe_bw_mode(struct sock *sk) -{ - struct bbr *bbr = inet_csk_ca(sk); - - bbr->mode = BBR_PROBE_BW; - bbr->cycle_idx = CYCLE_LEN - 1 - get_random_u32_below(bbr_cycle_rand); - bbr_advance_cycle_phase(sk); /* flip to next phase of gain cycle */ -} - -static void bbr_reset_mode(struct sock *sk) -{ - if (!bbr_full_bw_reached(sk)) - bbr_reset_startup_mode(sk); - else - bbr_reset_probe_bw_mode(sk); -} - -/* Start a new long-term sampling interval. */ -static void bbr_reset_lt_bw_sampling_interval(struct sock *sk) -{ - struct tcp_sock *tp = tcp_sk(sk); - struct bbr *bbr = inet_csk_ca(sk); - - bbr->lt_last_stamp = div_u64(tp->delivered_mstamp, USEC_PER_MSEC); - bbr->lt_last_delivered = tp->delivered; - bbr->lt_last_lost = tp->lost; - bbr->lt_rtt_cnt = 0; -} - -/* Completely reset long-term bandwidth sampling. */ -static void bbr_reset_lt_bw_sampling(struct sock *sk) -{ - struct bbr *bbr = inet_csk_ca(sk); - - bbr->lt_bw = 0; - bbr->lt_use_bw = 0; - bbr->lt_is_sampling = false; - bbr_reset_lt_bw_sampling_interval(sk); -} - -/* Long-term bw sampling interval is done. Estimate whether we're policed. */ -static void bbr_lt_bw_interval_done(struct sock *sk, u32 bw) -{ - struct bbr *bbr = inet_csk_ca(sk); - u32 diff; - - if (bbr->lt_bw) { /* do we have bw from a previous interval? */ - /* Is new bw close to the lt_bw from the previous interval? */ - diff = abs(bw - bbr->lt_bw); - if ((diff * BBR_UNIT <= bbr_lt_bw_ratio * bbr->lt_bw) || - (bbr_rate_bytes_per_sec(sk, diff, BBR_UNIT) <= - bbr_lt_bw_diff)) { - /* All criteria are met; estimate we're policed. */ - bbr->lt_bw = (bw + bbr->lt_bw) >> 1; /* avg 2 intvls */ - bbr->lt_use_bw = 1; - bbr->pacing_gain = BBR_UNIT; /* try to avoid drops */ - bbr->lt_rtt_cnt = 0; - return; - } - } - bbr->lt_bw = bw; - bbr_reset_lt_bw_sampling_interval(sk); -} - -/* Token-bucket traffic policers are common (see "An Internet-Wide Analysis of - * Traffic Policing", SIGCOMM 2016). BBR detects token-bucket policers and - * explicitly models their policed rate, to reduce unnecessary losses. We - * estimate that we're policed if we see 2 consecutive sampling intervals with - * consistent throughput and high packet loss. If we think we're being policed, - * set lt_bw to the "long-term" average delivery rate from those 2 intervals. +/* See if we have reached next round trip. Upon start of the new round, + * returns packets delivered since previous round start plus this ACK. */ -static void bbr_lt_bw_sampling(struct sock *sk, const struct rate_sample *rs) -{ - struct tcp_sock *tp = tcp_sk(sk); - struct bbr *bbr = inet_csk_ca(sk); - u32 lost, delivered; - u64 bw; - u32 t; - - if (bbr->lt_use_bw) { /* already using long-term rate, lt_bw? */ - if (bbr->mode == BBR_PROBE_BW && bbr->round_start && - ++bbr->lt_rtt_cnt >= bbr_lt_bw_max_rtts) { - bbr_reset_lt_bw_sampling(sk); /* stop using lt_bw */ - bbr_reset_probe_bw_mode(sk); /* restart gain cycling */ - } - return; - } - - /* Wait for the first loss before sampling, to let the policer exhaust - * its tokens and estimate the steady-state rate allowed by the policer. - * Starting samples earlier includes bursts that over-estimate the bw. - */ - if (!bbr->lt_is_sampling) { - if (!rs->losses) - return; - bbr_reset_lt_bw_sampling_interval(sk); - bbr->lt_is_sampling = true; - } - - /* To avoid underestimates, reset sampling if we run out of data. */ - if (rs->is_app_limited) { - bbr_reset_lt_bw_sampling(sk); - return; - } - - if (bbr->round_start) - bbr->lt_rtt_cnt++; /* count round trips in this interval */ - if (bbr->lt_rtt_cnt < bbr_lt_intvl_min_rtts) - return; /* sampling interval needs to be longer */ - if (bbr->lt_rtt_cnt > 4 * bbr_lt_intvl_min_rtts) { - bbr_reset_lt_bw_sampling(sk); /* interval is too long */ - return; - } - - /* End sampling interval when a packet is lost, so we estimate the - * policer tokens were exhausted. Stopping the sampling before the - * tokens are exhausted under-estimates the policed rate. - */ - if (!rs->losses) - return; - - /* Calculate packets lost and delivered in sampling interval. */ - lost = tp->lost - bbr->lt_last_lost; - delivered = tp->delivered - bbr->lt_last_delivered; - /* Is loss rate (lost/delivered) >= lt_loss_thresh? If not, wait. */ - if (!delivered || (lost << BBR_SCALE) < bbr_lt_loss_thresh * delivered) - return; - - /* Find average delivery rate in this sampling interval. */ - t = div_u64(tp->delivered_mstamp, USEC_PER_MSEC) - bbr->lt_last_stamp; - if ((s32)t < 1) - return; /* interval is less than one ms, so wait */ - /* Check if can multiply without overflow */ - if (t >= ~0U / USEC_PER_MSEC) { - bbr_reset_lt_bw_sampling(sk); /* interval too long; reset */ - return; - } - t *= USEC_PER_MSEC; - bw = (u64)delivered * BW_UNIT; - do_div(bw, t); - bbr_lt_bw_interval_done(sk, bw); -} - -/* Estimate the bandwidth based on how fast packets are delivered */ -static void bbr_update_bw(struct sock *sk, const struct rate_sample *rs) +static u32 bbr_update_round_start(struct sock *sk, + const struct rate_sample *rs, struct bbr_context *ctx) { struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - u64 bw; + u32 round_delivered = 0; bbr->round_start = 0; - if (rs->delivered < 0 || rs->interval_us <= 0) - return; /* Not a valid observation */ /* See if we've reached the next RTT */ - if (!before(rs->prior_delivered, bbr->next_rtt_delivered)) { + if (rs->interval_us > 0 && + !before(rs->prior_delivered, bbr->next_rtt_delivered)) { + round_delivered = tp->delivered - bbr->next_rtt_delivered; bbr->next_rtt_delivered = tp->delivered; - bbr->rtt_cnt++; bbr->round_start = 1; - bbr->packet_conservation = 0; } + return round_delivered; +} - bbr_lt_bw_sampling(sk, rs); +/* Calculate the bandwidth based on how fast packets are delivered */ +static void bbr_calculate_bw_sample(struct sock *sk, + const struct rate_sample *rs, struct bbr_context *ctx) +{ + u64 bw = 0; /* Divide delivered by the interval to find a (lower bound) bottleneck * bandwidth sample. Delivered is in packets and interval_us in uS and * ratio will be <<1 for most connections. So delivered is first scaled. + * Round up to allow growth at low rates, even with integer division. */ - bw = div64_long((u64)rs->delivered * BW_UNIT, rs->interval_us); - - /* If this sample is application-limited, it is likely to have a very - * low delivered count that represents application behavior rather than - * the available network rate. Such a sample could drag down estimated - * bw, causing needless slow-down. Thus, to continue to send at the - * last measured network rate, we filter out app-limited samples unless - * they describe the path bw at least as well as our bw model. - * - * So the goal during app-limited phase is to proceed with the best - * network rate no matter how long. We automatically leave this - * phase when app writes faster than the network can deliver :) - */ - if (!rs->is_app_limited || bw >= bbr_max_bw(sk)) { - /* Incorporate new sample into our max bw filter. */ - minmax_running_max(&bbr->bw, bbr_bw_rtts, bbr->rtt_cnt, bw); + if (rs->interval_us > 0) { + if (WARN_ONCE(rs->delivered < 0, + "negative delivered: %d interval_us: %ld\n", + rs->delivered, rs->interval_us)) + return; + + bw = DIV_ROUND_UP_ULL((u64)rs->delivered * BW_UNIT, rs->interval_us); } + + ctx->sample_bw = bw; } /* Estimates the windowed max degree of ack aggregation. @@ -812,7 +794,7 @@ static void bbr_update_bw(struct sock *sk, const struct rate_sample *rs) * * Max extra_acked is clamped by cwnd and bw * bbr_extra_acked_max_us (100 ms). * Max filter is an approximate sliding window of 5-10 (packet timed) round - * trips. + * trips for non-startup phase, and 1-2 round trips for startup. */ static void bbr_update_ack_aggregation(struct sock *sk, const struct rate_sample *rs) @@ -820,15 +802,19 @@ static void bbr_update_ack_aggregation(struct sock *sk, u32 epoch_us, expected_acked, extra_acked; struct bbr *bbr = inet_csk_ca(sk); struct tcp_sock *tp = tcp_sk(sk); + u32 extra_acked_win_rtts_thresh = bbr_param(sk, extra_acked_win_rtts); - if (!bbr_extra_acked_gain || rs->acked_sacked <= 0 || + if (!bbr_param(sk, extra_acked_gain) || rs->acked_sacked <= 0 || rs->delivered < 0 || rs->interval_us <= 0) return; if (bbr->round_start) { bbr->extra_acked_win_rtts = min(0x1F, bbr->extra_acked_win_rtts + 1); - if (bbr->extra_acked_win_rtts >= bbr_extra_acked_win_rtts) { + if (!bbr_full_bw_reached(sk)) + extra_acked_win_rtts_thresh = 1; + if (bbr->extra_acked_win_rtts >= + extra_acked_win_rtts_thresh) { bbr->extra_acked_win_rtts = 0; bbr->extra_acked_win_idx = bbr->extra_acked_win_idx ? 0 : 1; @@ -862,49 +848,6 @@ static void bbr_update_ack_aggregation(struct sock *sk, bbr->extra_acked[bbr->extra_acked_win_idx] = extra_acked; } -/* Estimate when the pipe is full, using the change in delivery rate: BBR - * estimates that STARTUP filled the pipe if the estimated bw hasn't changed by - * at least bbr_full_bw_thresh (25%) after bbr_full_bw_cnt (3) non-app-limited - * rounds. Why 3 rounds: 1: rwin autotuning grows the rwin, 2: we fill the - * higher rwin, 3: we get higher delivery rate samples. Or transient - * cross-traffic or radio noise can go away. CUBIC Hystart shares a similar - * design goal, but uses delay and inter-ACK spacing instead of bandwidth. - */ -static void bbr_check_full_bw_reached(struct sock *sk, - const struct rate_sample *rs) -{ - struct bbr *bbr = inet_csk_ca(sk); - u32 bw_thresh; - - if (bbr_full_bw_reached(sk) || !bbr->round_start || rs->is_app_limited) - return; - - bw_thresh = (u64)bbr->full_bw * bbr_full_bw_thresh >> BBR_SCALE; - if (bbr_max_bw(sk) >= bw_thresh) { - bbr->full_bw = bbr_max_bw(sk); - bbr->full_bw_cnt = 0; - return; - } - ++bbr->full_bw_cnt; - bbr->full_bw_reached = bbr->full_bw_cnt >= bbr_full_bw_cnt; -} - -/* If pipe is probably full, drain the queue and then enter steady-state. */ -static void bbr_check_drain(struct sock *sk, const struct rate_sample *rs) -{ - struct bbr *bbr = inet_csk_ca(sk); - - if (bbr->mode == BBR_STARTUP && bbr_full_bw_reached(sk)) { - bbr->mode = BBR_DRAIN; /* drain queue we created */ - tcp_sk(sk)->snd_ssthresh = - bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT); - } /* fall through to check if in-flight is already small: */ - if (bbr->mode == BBR_DRAIN && - bbr_packets_in_net_at_edt(sk, tcp_packets_in_flight(tcp_sk(sk))) <= - bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT)) - bbr_reset_probe_bw_mode(sk); /* we estimate queue is drained */ -} - static void bbr_check_probe_rtt_done(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); @@ -914,9 +857,9 @@ static void bbr_check_probe_rtt_done(struct sock *sk) after(tcp_jiffies32, bbr->probe_rtt_done_stamp))) return; - bbr->min_rtt_stamp = tcp_jiffies32; /* wait a while until PROBE_RTT */ + bbr->probe_rtt_min_stamp = tcp_jiffies32; /* schedule next PROBE_RTT */ tcp_snd_cwnd_set(tp, max(tcp_snd_cwnd(tp), bbr->prior_cwnd)); - bbr_reset_mode(sk); + bbr_exit_probe_rtt(sk); } /* The goal of PROBE_RTT mode is to have BBR flows cooperatively and @@ -942,23 +885,35 @@ static void bbr_update_min_rtt(struct sock *sk, const struct rate_sample *rs) { struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - bool filter_expired; + bool probe_rtt_expired, min_rtt_expired; + u32 expire; - /* Track min RTT seen in the min_rtt_win_sec filter window: */ - filter_expired = after(tcp_jiffies32, - bbr->min_rtt_stamp + bbr_min_rtt_win_sec * HZ); + /* Track min RTT in probe_rtt_win_ms to time next PROBE_RTT state. */ + expire = bbr->probe_rtt_min_stamp + + msecs_to_jiffies(bbr_param(sk, probe_rtt_win_ms)); + probe_rtt_expired = after(tcp_jiffies32, expire); if (rs->rtt_us >= 0 && - (rs->rtt_us < bbr->min_rtt_us || - (filter_expired && !rs->is_ack_delayed))) { - bbr->min_rtt_us = rs->rtt_us; - bbr->min_rtt_stamp = tcp_jiffies32; + (rs->rtt_us < bbr->probe_rtt_min_us || + (probe_rtt_expired && !rs->is_ack_delayed))) { + bbr->probe_rtt_min_us = rs->rtt_us; + bbr->probe_rtt_min_stamp = tcp_jiffies32; + } + /* Track min RTT seen in the min_rtt_win_sec filter window: */ + expire = bbr->min_rtt_stamp + bbr_param(sk, min_rtt_win_sec) * HZ; + min_rtt_expired = after(tcp_jiffies32, expire); + if (bbr->probe_rtt_min_us <= bbr->min_rtt_us || + min_rtt_expired) { + bbr->min_rtt_us = bbr->probe_rtt_min_us; + bbr->min_rtt_stamp = bbr->probe_rtt_min_stamp; } - if (bbr_probe_rtt_mode_ms > 0 && filter_expired && + if (bbr_param(sk, probe_rtt_mode_ms) > 0 && probe_rtt_expired && !bbr->idle_restart && bbr->mode != BBR_PROBE_RTT) { bbr->mode = BBR_PROBE_RTT; /* dip, drain queue */ bbr_save_cwnd(sk); /* note cwnd so we can restore it */ bbr->probe_rtt_done_stamp = 0; + bbr->ack_phase = BBR_ACKS_PROBE_STOPPING; + bbr->next_rtt_delivered = tp->delivered; } if (bbr->mode == BBR_PROBE_RTT) { @@ -967,9 +922,9 @@ static void bbr_update_min_rtt(struct sock *sk, const struct rate_sample *rs) (tp->delivered + tcp_packets_in_flight(tp)) ? : 1; /* Maintain min packets in flight for max(200 ms, 1 round). */ if (!bbr->probe_rtt_done_stamp && - tcp_packets_in_flight(tp) <= bbr_cwnd_min_target) { + tcp_packets_in_flight(tp) <= bbr_probe_rtt_cwnd(sk)) { bbr->probe_rtt_done_stamp = tcp_jiffies32 + - msecs_to_jiffies(bbr_probe_rtt_mode_ms); + msecs_to_jiffies(bbr_param(sk, probe_rtt_mode_ms)); bbr->probe_rtt_round_done = 0; bbr->next_rtt_delivered = tp->delivered; } else if (bbr->probe_rtt_done_stamp) { @@ -990,18 +945,20 @@ static void bbr_update_gains(struct sock *sk) switch (bbr->mode) { case BBR_STARTUP: - bbr->pacing_gain = bbr_high_gain; - bbr->cwnd_gain = bbr_high_gain; + bbr->pacing_gain = bbr_param(sk, startup_pacing_gain); + bbr->cwnd_gain = bbr_param(sk, startup_cwnd_gain); break; case BBR_DRAIN: - bbr->pacing_gain = bbr_drain_gain; /* slow, to drain */ - bbr->cwnd_gain = bbr_high_gain; /* keep cwnd */ + bbr->pacing_gain = bbr_param(sk, drain_gain); /* slow, to drain */ + bbr->cwnd_gain = bbr_param(sk, startup_cwnd_gain); /* keep cwnd */ break; case BBR_PROBE_BW: - bbr->pacing_gain = (bbr->lt_use_bw ? - BBR_UNIT : - bbr_pacing_gain[bbr->cycle_idx]); - bbr->cwnd_gain = bbr_cwnd_gain; + bbr->pacing_gain = bbr_pacing_gain[bbr->cycle_idx]; + bbr->cwnd_gain = bbr_param(sk, cwnd_gain); + if (bbr_param(sk, bw_probe_cwnd_gain) && + bbr->cycle_idx == BBR_BW_PROBE_UP) + bbr->cwnd_gain += + BBR_UNIT * bbr_param(sk, bw_probe_cwnd_gain) / 4; break; case BBR_PROBE_RTT: bbr->pacing_gain = BBR_UNIT; @@ -1013,144 +970,1387 @@ static void bbr_update_gains(struct sock *sk) } } -static void bbr_update_model(struct sock *sk, const struct rate_sample *rs) +__bpf_kfunc static u32 bbr_sndbuf_expand(struct sock *sk) { - bbr_update_bw(sk, rs); - bbr_update_ack_aggregation(sk, rs); - bbr_update_cycle_phase(sk, rs); - bbr_check_full_bw_reached(sk, rs); - bbr_check_drain(sk, rs); - bbr_update_min_rtt(sk, rs); - bbr_update_gains(sk); + /* Provision 3 * cwnd since BBR may slow-start even during recovery. */ + return 3; } -__bpf_kfunc static void bbr_main(struct sock *sk, u32 ack, int flag, const struct rate_sample *rs) +/* Incorporate a new bw sample into the current window of our max filter. */ +static void bbr_take_max_bw_sample(struct sock *sk, u32 bw) { struct bbr *bbr = inet_csk_ca(sk); - u32 bw; - - bbr_update_model(sk, rs); - bw = bbr_bw(sk); - bbr_set_pacing_rate(sk, bw, bbr->pacing_gain); - bbr_set_cwnd(sk, rs, rs->acked_sacked, bw, bbr->cwnd_gain); + bbr->bw_hi[1] = max(bw, bbr->bw_hi[1]); } -__bpf_kfunc static void bbr_init(struct sock *sk) +/* Keep max of last 1-2 cycles. Each PROBE_BW cycle, flip filter window. */ +static void bbr_advance_max_bw_filter(struct sock *sk) { - struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - bbr->prior_cwnd = 0; - tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; - bbr->rtt_cnt = 0; - bbr->next_rtt_delivered = tp->delivered; - bbr->prev_ca_state = TCP_CA_Open; - bbr->packet_conservation = 0; - - bbr->probe_rtt_done_stamp = 0; - bbr->probe_rtt_round_done = 0; - bbr->min_rtt_us = tcp_min_rtt(tp); - bbr->min_rtt_stamp = tcp_jiffies32; - - minmax_reset(&bbr->bw, bbr->rtt_cnt, 0); /* init max bw to 0 */ + if (!bbr->bw_hi[1]) + return; /* no samples in this window; remember old window */ + bbr->bw_hi[0] = bbr->bw_hi[1]; + bbr->bw_hi[1] = 0; +} - bbr->has_seen_rtt = 0; - bbr_init_pacing_rate_from_rtt(sk); +/* Reset the estimator for reaching full bandwidth based on bw plateau. */ +static void bbr_reset_full_bw(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); - bbr->round_start = 0; - bbr->idle_restart = 0; - bbr->full_bw_reached = 0; bbr->full_bw = 0; bbr->full_bw_cnt = 0; - bbr->cycle_mstamp = 0; - bbr->cycle_idx = 0; - bbr_reset_lt_bw_sampling(sk); - bbr_reset_startup_mode(sk); + bbr->full_bw_now = 0; +} - bbr->ack_epoch_mstamp = tp->tcp_mstamp; - bbr->ack_epoch_acked = 0; - bbr->extra_acked_win_rtts = 0; - bbr->extra_acked_win_idx = 0; - bbr->extra_acked[0] = 0; - bbr->extra_acked[1] = 0; +/* How much do we want in flight? Our BDP, unless congestion cut cwnd. */ +static u32 bbr_target_inflight(struct sock *sk) +{ + u32 bdp = bbr_inflight(sk, bbr_bw(sk), BBR_UNIT); - cmpxchg(&sk->sk_pacing_status, SK_PACING_NONE, SK_PACING_NEEDED); + return min(bdp, tcp_sk(sk)->snd_cwnd); } -__bpf_kfunc static u32 bbr_sndbuf_expand(struct sock *sk) +static bool bbr_is_probing_bandwidth(struct sock *sk) { - /* Provision 3 * cwnd since BBR may slow-start even during recovery. */ - return 3; + struct bbr *bbr = inet_csk_ca(sk); + + return (bbr->mode == BBR_STARTUP) || + (bbr->mode == BBR_PROBE_BW && + (bbr->cycle_idx == BBR_BW_PROBE_REFILL || + bbr->cycle_idx == BBR_BW_PROBE_UP)); +} + +/* Has the given amount of time elapsed since we marked the phase start? */ +static bool bbr_has_elapsed_in_phase(const struct sock *sk, u32 interval_us) +{ + const struct tcp_sock *tp = tcp_sk(sk); + const struct bbr *bbr = inet_csk_ca(sk); + + return tcp_stamp_us_delta(tp->tcp_mstamp, + bbr->cycle_mstamp + interval_us) > 0; +} + +static void bbr_handle_queue_too_high_in_startup(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 bdp; /* estimated BDP in packets, with quantization budget */ + + bbr->full_bw_reached = 1; + + bdp = bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT); + bbr->inflight_hi = max(bdp, bbr->inflight_latest); +} + +/* Exit STARTUP upon N consecutive rounds with ECN mark rate > ecn_thresh. */ +static void bbr_check_ecn_too_high_in_startup(struct sock *sk, u32 ce_ratio) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr_full_bw_reached(sk) || !bbr->ecn_eligible || + !bbr_param(sk, full_ecn_cnt) || !bbr_param(sk, ecn_thresh)) + return; + + if (ce_ratio >= bbr_param(sk, ecn_thresh)) + bbr->startup_ecn_rounds++; + else + bbr->startup_ecn_rounds = 0; + + if (bbr->startup_ecn_rounds >= bbr_param(sk, full_ecn_cnt)) { + bbr_handle_queue_too_high_in_startup(sk); + return; + } +} + +/* Updates ecn_alpha and returns ce_ratio. -1 if not available. */ +static int bbr_update_ecn_alpha(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + struct bbr *bbr = inet_csk_ca(sk); + s32 delivered, delivered_ce; + u64 alpha, ce_ratio; + u32 gain; + bool want_ecn_alpha; + + /* See if we should use ECN sender logic for this connection. */ + if (!bbr->ecn_eligible && bbr_can_use_ecn(sk) && + bbr_param(sk, ecn_factor) && + (bbr->min_rtt_us <= bbr_ecn_max_rtt_us || + !bbr_ecn_max_rtt_us)) + bbr->ecn_eligible = 1; + + /* Skip updating alpha only if not ECN-eligible and PLB is disabled. */ + want_ecn_alpha = (bbr->ecn_eligible || + (bbr_can_use_ecn(sk) && + READ_ONCE(net->ipv4.sysctl_tcp_plb_enabled))); + if (!want_ecn_alpha) + return -1; + + delivered = tp->delivered - bbr->alpha_last_delivered; + delivered_ce = tp->delivered_ce - bbr->alpha_last_delivered_ce; + + if (delivered == 0 || /* avoid divide by zero */ + WARN_ON_ONCE(delivered < 0 || delivered_ce < 0)) /* backwards? */ + return -1; + + BUILD_BUG_ON(BBR_SCALE != TCP_PLB_SCALE); + ce_ratio = (u64)delivered_ce << BBR_SCALE; + do_div(ce_ratio, delivered); + + gain = bbr_param(sk, ecn_alpha_gain); + alpha = ((BBR_UNIT - gain) * bbr->ecn_alpha) >> BBR_SCALE; + alpha += (gain * ce_ratio) >> BBR_SCALE; + bbr->ecn_alpha = min_t(u32, alpha, BBR_UNIT); + + bbr->alpha_last_delivered = tp->delivered; + bbr->alpha_last_delivered_ce = tp->delivered_ce; + + bbr_check_ecn_too_high_in_startup(sk, ce_ratio); + return (int)ce_ratio; } -/* In theory BBR does not need to undo the cwnd since it does not - * always reduce cwnd on losses (see bbr_main()). Keep it for now. +/* Protective Load Balancing (PLB). PLB rehashes outgoing data (to a new IPv6 + * flow label) if it encounters sustained congestion in the form of ECN marks. */ -__bpf_kfunc static u32 bbr_undo_cwnd(struct sock *sk) +static void bbr_plb(struct sock *sk, const struct rate_sample *rs, int ce_ratio) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr->round_start && ce_ratio >= 0) + tcp_plb_update_state(sk, &bbr->plb, ce_ratio); + + tcp_plb_check_rehash(sk, &bbr->plb); +} + +/* Each round trip of BBR_BW_PROBE_UP, double volume of probing data. */ +static void bbr_raise_inflight_hi_slope(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u32 growth_this_round, cnt; + + /* Calculate "slope": packets S/Acked per inflight_hi increment. */ + growth_this_round = 1 << bbr->bw_probe_up_rounds; + bbr->bw_probe_up_rounds = min(bbr->bw_probe_up_rounds + 1, 30); + cnt = tcp_snd_cwnd(tp) / growth_this_round; + cnt = max(cnt, 1U); + bbr->bw_probe_up_cnt = cnt; +} + +/* In BBR_BW_PROBE_UP, not seeing high loss/ECN/queue, so raise inflight_hi. */ +static void bbr_probe_inflight_hi_upward(struct sock *sk, + const struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u32 delta; + + if (!tp->is_cwnd_limited || tcp_snd_cwnd(tp) < bbr->inflight_hi) + return; /* not fully using inflight_hi, so don't grow it */ + + /* For each bw_probe_up_cnt packets ACKed, increase inflight_hi by 1. */ + bbr->bw_probe_up_acks += rs->acked_sacked; + if (bbr->bw_probe_up_acks >= bbr->bw_probe_up_cnt) { + delta = bbr->bw_probe_up_acks / bbr->bw_probe_up_cnt; + bbr->bw_probe_up_acks -= delta * bbr->bw_probe_up_cnt; + bbr->inflight_hi += delta; + bbr->try_fast_path = 0; /* Need to update cwnd */ + } + + if (bbr->round_start) + bbr_raise_inflight_hi_slope(sk); +} + +/* Does loss/ECN rate for this sample say inflight is "too high"? + * This is used by both the bbr_check_loss_too_high_in_startup() function, + * which can be used in either v1 or v2, and the PROBE_UP phase of v2, which + * uses it to notice when loss/ECN rates suggest inflight is too high. + */ +static bool bbr_is_inflight_too_high(const struct sock *sk, + const struct rate_sample *rs) +{ + const struct bbr *bbr = inet_csk_ca(sk); + u32 loss_thresh, ecn_thresh; + + if (rs->lost > 0 && rs->tx_in_flight) { + loss_thresh = (u64)rs->tx_in_flight * bbr_param(sk, loss_thresh) >> + BBR_SCALE; + if (rs->lost > loss_thresh) { + return true; + } + } + + if (rs->delivered_ce > 0 && rs->delivered > 0 && + bbr->ecn_eligible && bbr_param(sk, ecn_thresh)) { + ecn_thresh = (u64)rs->delivered * bbr_param(sk, ecn_thresh) >> + BBR_SCALE; + if (rs->delivered_ce > ecn_thresh) { + return true; + } + } + + return false; +} + +/* Calculate the tx_in_flight level that corresponded to excessive loss. + * We find "lost_prefix" segs of the skb where loss rate went too high, + * by solving for "lost_prefix" in the following equation: + * lost / inflight >= loss_thresh + * (lost_prev + lost_prefix) / (inflight_prev + lost_prefix) >= loss_thresh + * Then we take that equation, convert it to fixed point, and + * round up to the nearest packet. + */ +static u32 bbr_inflight_hi_from_lost_skb(const struct sock *sk, + const struct rate_sample *rs, + const struct sk_buff *skb) +{ + const struct tcp_sock *tp = tcp_sk(sk); + u32 loss_thresh = bbr_param(sk, loss_thresh); + u32 pcount, divisor, inflight_hi; + s32 inflight_prev, lost_prev; + u64 loss_budget, lost_prefix; + + pcount = tcp_skb_pcount(skb); + + /* How much data was in flight before this skb? */ + inflight_prev = rs->tx_in_flight - pcount; + if (inflight_prev < 0) { + WARN_ONCE(tcp_skb_tx_in_flight_is_suspicious( + pcount, + TCP_SKB_CB(skb)->sacked, + rs->tx_in_flight), + "tx_in_flight: %u pcount: %u reneg: %u", + rs->tx_in_flight, pcount, tcp_sk(sk)->is_sack_reneg); + return ~0U; + } + + /* How much inflight data was marked lost before this skb? */ + lost_prev = rs->lost - pcount; + if (WARN_ONCE(lost_prev < 0, + "cwnd: %u ca: %d out: %u lost: %u pif: %u " + "tx_in_flight: %u tx.lost: %u tp->lost: %u rs->lost: %d " + "lost_prev: %d pcount: %d seq: %u end_seq: %u reneg: %u", + tcp_snd_cwnd(tp), inet_csk(sk)->icsk_ca_state, + tp->packets_out, tp->lost_out, tcp_packets_in_flight(tp), + rs->tx_in_flight, TCP_SKB_CB(skb)->tx.lost, tp->lost, + rs->lost, lost_prev, pcount, + TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq, + tp->is_sack_reneg)) + return ~0U; + + /* At what prefix of this lost skb did losss rate exceed loss_thresh? */ + loss_budget = (u64)inflight_prev * loss_thresh + BBR_UNIT - 1; + loss_budget >>= BBR_SCALE; + if (lost_prev >= loss_budget) { + lost_prefix = 0; /* previous losses crossed loss_thresh */ + } else { + lost_prefix = loss_budget - lost_prev; + lost_prefix <<= BBR_SCALE; + divisor = BBR_UNIT - loss_thresh; + if (WARN_ON_ONCE(!divisor)) /* loss_thresh is 8 bits */ + return ~0U; + do_div(lost_prefix, divisor); + } + + inflight_hi = inflight_prev + lost_prefix; + return inflight_hi; +} + +/* If loss/ECN rates during probing indicated we may have overfilled a + * buffer, return an operating point that tries to leave unutilized headroom in + * the path for other flows, for fairness convergence and lower RTTs and loss. + */ +static u32 bbr_inflight_with_headroom(const struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 headroom, headroom_fraction; + + if (bbr->inflight_hi == ~0U) + return ~0U; + + headroom_fraction = bbr_param(sk, inflight_headroom); + headroom = ((u64)bbr->inflight_hi * headroom_fraction) >> BBR_SCALE; + headroom = max(headroom, 1U); + return max_t(s32, bbr->inflight_hi - headroom, + bbr_param(sk, cwnd_min_target)); +} + +/* Bound cwnd to a sensible level, based on our current probing state + * machine phase and model of a good inflight level (inflight_lo, inflight_hi). + */ +static void bbr_bound_cwnd_for_inflight_model(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u32 cap; + + /* tcp_rcv_synsent_state_process() currently calls tcp_ack() + * and thus cong_control() without first initializing us(!). + */ + if (!bbr->initialized) + return; + + cap = ~0U; + if (bbr->mode == BBR_PROBE_BW && + bbr->cycle_idx != BBR_BW_PROBE_CRUISE) { + /* Probe to see if more packets fit in the path. */ + cap = bbr->inflight_hi; + } else { + if (bbr->mode == BBR_PROBE_RTT || + (bbr->mode == BBR_PROBE_BW && + bbr->cycle_idx == BBR_BW_PROBE_CRUISE)) + cap = bbr_inflight_with_headroom(sk); + } + /* Adapt to any loss/ECN since our last bw probe. */ + cap = min(cap, bbr->inflight_lo); + + cap = max_t(u32, cap, bbr_param(sk, cwnd_min_target)); + tcp_snd_cwnd_set(tp, min(cap, tcp_snd_cwnd(tp))); +} + +/* How should we multiplicatively cut bw or inflight limits based on ECN? */ +static u32 bbr_ecn_cut(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + return BBR_UNIT - + ((bbr->ecn_alpha * bbr_param(sk, ecn_factor)) >> BBR_SCALE); +} + +/* Init lower bounds if have not inited yet. */ +static void bbr_init_lower_bounds(struct sock *sk, bool init_bw) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + if (init_bw && bbr->bw_lo == ~0U) + bbr->bw_lo = bbr_max_bw(sk); + if (bbr->inflight_lo == ~0U) + bbr->inflight_lo = tcp_snd_cwnd(tp); +} + +/* Reduce bw and inflight to (1 - beta). */ +static void bbr_loss_lower_bounds(struct sock *sk, u32 *bw, u32 *inflight) +{ + struct bbr* bbr = inet_csk_ca(sk); + u32 loss_cut = BBR_UNIT - bbr_param(sk, beta); + + *bw = max_t(u32, bbr->bw_latest, + (u64)bbr->bw_lo * loss_cut >> BBR_SCALE); + *inflight = max_t(u32, bbr->inflight_latest, + (u64)bbr->inflight_lo * loss_cut >> BBR_SCALE); +} + +/* Reduce inflight to (1 - alpha*ecn_factor). */ +static void bbr_ecn_lower_bounds(struct sock *sk, u32 *inflight) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 ecn_cut = bbr_ecn_cut(sk); + + *inflight = (u64)bbr->inflight_lo * ecn_cut >> BBR_SCALE; +} + +/* Estimate a short-term lower bound on the capacity available now, based + * on measurements of the current delivery process and recent history. When we + * are seeing loss/ECN at times when we are not probing bw, then conservatively + * move toward flow balance by multiplicatively cutting our short-term + * estimated safe rate and volume of data (bw_lo and inflight_lo). We use a + * multiplicative decrease in order to converge to a lower capacity in time + * logarithmic in the magnitude of the decrease. + * + * However, we do not cut our short-term estimates lower than the current rate + * and volume of delivered data from this round trip, since from the current + * delivery process we can estimate the measured capacity available now. + * + * Anything faster than that approach would knowingly risk high loss, which can + * cause low bw for Reno/CUBIC and high loss recovery latency for + * request/response flows using any congestion control. + */ +static void bbr_adapt_lower_bounds(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 ecn_inflight_lo = ~0U; + + /* We only use lower-bound estimates when not probing bw. + * When probing we need to push inflight higher to probe bw. + */ + if (bbr_is_probing_bandwidth(sk)) + return; + + /* ECN response. */ + if (bbr->ecn_in_round && bbr_param(sk, ecn_factor)) { + bbr_init_lower_bounds(sk, false); + bbr_ecn_lower_bounds(sk, &ecn_inflight_lo); + } + + /* Loss response. */ + if (bbr->loss_in_round) { + bbr_init_lower_bounds(sk, true); + bbr_loss_lower_bounds(sk, &bbr->bw_lo, &bbr->inflight_lo); + } + + /* Adjust to the lower of the levels implied by loss/ECN. */ + bbr->inflight_lo = min(bbr->inflight_lo, ecn_inflight_lo); + bbr->bw_lo = max(1U, bbr->bw_lo); +} + +/* Reset any short-term lower-bound adaptation to congestion, so that we can + * push our inflight up. + */ +static void bbr_reset_lower_bounds(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->bw_lo = ~0U; + bbr->inflight_lo = ~0U; +} + +/* After bw probing (STARTUP/PROBE_UP), reset signals before entering a state + * machine phase where we adapt our lower bound based on congestion signals. + */ +static void bbr_reset_congestion_signals(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->loss_in_round = 0; + bbr->ecn_in_round = 0; + bbr->loss_in_cycle = 0; + bbr->ecn_in_cycle = 0; + bbr->bw_latest = 0; + bbr->inflight_latest = 0; +} + +static void bbr_exit_loss_recovery(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + tcp_snd_cwnd_set(tp, max(tcp_snd_cwnd(tp), bbr->prior_cwnd)); + bbr->try_fast_path = 0; /* bound cwnd using latest model */ +} + +/* Update rate and volume of delivered data from latest round trip. */ +static void bbr_update_latest_delivery_signals( + struct sock *sk, const struct rate_sample *rs, struct bbr_context *ctx) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr->loss_round_start = 0; + if (rs->interval_us <= 0 || !rs->acked_sacked) + return; /* Not a valid observation */ + + bbr->bw_latest = max_t(u32, bbr->bw_latest, ctx->sample_bw); + bbr->inflight_latest = max_t(u32, bbr->inflight_latest, rs->delivered); + + if (!before(rs->prior_delivered, bbr->loss_round_delivered)) { + bbr->loss_round_delivered = tp->delivered; + bbr->loss_round_start = 1; /* mark start of new round trip */ + } +} + +/* Once per round, reset filter for latest rate and volume of delivered data. */ +static void bbr_advance_latest_delivery_signals( + struct sock *sk, const struct rate_sample *rs, struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + + /* If ACK matches a TLP retransmit, persist the filter. If we detect + * that a TLP retransmit plugged a tail loss, we'll want to remember + * how much data the path delivered before the tail loss. + */ + if (bbr->loss_round_start && !rs->is_acking_tlp_retrans_seq) { + bbr->bw_latest = ctx->sample_bw; + bbr->inflight_latest = rs->delivered; + } +} + +/* Update (most of) our congestion signals: track the recent rate and volume of + * delivered data, presence of loss, and EWMA degree of ECN marking. + */ +static void bbr_update_congestion_signals( + struct sock *sk, const struct rate_sample *rs, struct bbr_context *ctx) { struct bbr *bbr = inet_csk_ca(sk); + u64 bw; + + if (rs->interval_us <= 0 || !rs->acked_sacked) + return; /* Not a valid observation */ + bw = ctx->sample_bw; - bbr->full_bw = 0; /* spurious slow-down; reset full pipe detection */ + if (!rs->is_app_limited || bw >= bbr_max_bw(sk)) + bbr_take_max_bw_sample(sk, bw); + + bbr->loss_in_round |= (rs->losses > 0); + + if (!bbr->loss_round_start) + return; /* skip the per-round-trip updates */ + /* Now do per-round-trip updates. */ + bbr_adapt_lower_bounds(sk, rs); + + bbr->loss_in_round = 0; + bbr->ecn_in_round = 0; +} + +/* Bandwidth probing can cause loss. To help coexistence with loss-based + * congestion control we spread out our probing in a Reno-conscious way. Due to + * the shape of the Reno sawtooth, the time required between loss epochs for an + * idealized Reno flow is a number of round trips that is the BDP of that + * flow. We count packet-timed round trips directly, since measured RTT can + * vary widely, and Reno is driven by packet-timed round trips. + */ +static bool bbr_is_reno_coexistence_probe_time(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 rounds; + + /* Random loss can shave some small percentage off of our inflight + * in each round. To survive this, flows need robust periodic probes. + */ + rounds = min_t(u32, bbr_param(sk, bw_probe_max_rounds), bbr_target_inflight(sk)); + return bbr->rounds_since_probe >= rounds; +} + +/* How long do we want to wait before probing for bandwidth (and risking + * loss)? We randomize the wait, for better mixing and fairness convergence. + * + * We bound the Reno-coexistence inter-bw-probe time to be 62-63 round trips. + * This is calculated to allow fairness with a 25Mbps, 30ms Reno flow, + * (eg 4K video to a broadband user): + * BDP = 25Mbps * .030sec /(1514bytes) = 61.9 packets + * + * We bound the BBR-native inter-bw-probe wall clock time to be: + * (a) higher than 2 sec: to try to avoid causing loss for a long enough time + * to allow Reno at 30ms to get 4K video bw, the inter-bw-probe time must + * be at least: 25Mbps * .030sec / (1514bytes) * 0.030sec = 1.9secs + * (b) lower than 3 sec: to ensure flows can start probing in a reasonable + * amount of time to discover unutilized bw on human-scale interactive + * time-scales (e.g. perhaps traffic from a web page download that we + * were competing with is now complete). + */ +static void bbr_pick_probe_wait(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + /* Decide the random round-trip bound for wait until probe: */ + bbr->rounds_since_probe = + get_random_u32_below(bbr_param(sk, bw_probe_rand_rounds)); + /* Decide the random wall clock bound for wait until probe: */ + bbr->probe_wait_us = bbr_param(sk, bw_probe_base_us) + + get_random_u32_below(bbr_param(sk, bw_probe_rand_us)); +} + +static void bbr_set_cycle_idx(struct sock *sk, int cycle_idx) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->cycle_idx = cycle_idx; + /* New phase, so need to update cwnd and pacing rate. */ + bbr->try_fast_path = 0; +} + +/* Send at estimated bw to fill the pipe, but not queue. We need this phase + * before PROBE_UP, because as soon as we send faster than the available bw + * we will start building a queue, and if the buffer is shallow we can cause + * loss. If we do not fill the pipe before we cause this loss, our bw_hi and + * inflight_hi estimates will underestimate. + */ +static void bbr_start_bw_probe_refill(struct sock *sk, u32 bw_probe_up_rounds) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr_reset_lower_bounds(sk); + bbr->bw_probe_up_rounds = bw_probe_up_rounds; + bbr->bw_probe_up_acks = 0; + bbr->stopped_risky_probe = 0; + bbr->ack_phase = BBR_ACKS_REFILLING; + bbr->next_rtt_delivered = tp->delivered; + bbr_set_cycle_idx(sk, BBR_BW_PROBE_REFILL); +} + +/* Now probe max deliverable data rate and volume. */ +static void bbr_start_bw_probe_up(struct sock *sk, struct bbr_context *ctx) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr->ack_phase = BBR_ACKS_PROBE_STARTING; + bbr->next_rtt_delivered = tp->delivered; + bbr->cycle_mstamp = tp->tcp_mstamp; + bbr_reset_full_bw(sk); + bbr->full_bw = ctx->sample_bw; + bbr_set_cycle_idx(sk, BBR_BW_PROBE_UP); + bbr_raise_inflight_hi_slope(sk); +} + +/* Start a new PROBE_BW probing cycle of some wall clock length. Pick a wall + * clock time at which to probe beyond an inflight that we think to be + * safe. This will knowingly risk packet loss, so we want to do this rarely, to + * keep packet loss rates low. Also start a round-trip counter, to probe faster + * if we estimate a Reno flow at our BDP would probe faster. + */ +static void bbr_start_bw_probe_down(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr_reset_congestion_signals(sk); + bbr->bw_probe_up_cnt = ~0U; /* not growing inflight_hi any more */ + bbr_pick_probe_wait(sk); + bbr->cycle_mstamp = tp->tcp_mstamp; /* start wall clock */ + bbr->ack_phase = BBR_ACKS_PROBE_STOPPING; + bbr->next_rtt_delivered = tp->delivered; + bbr_set_cycle_idx(sk, BBR_BW_PROBE_DOWN); +} + +/* Cruise: maintain what we estimate to be a neutral, conservative + * operating point, without attempting to probe up for bandwidth or down for + * RTT, and only reducing inflight in response to loss/ECN signals. + */ +static void bbr_start_bw_probe_cruise(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr->inflight_lo != ~0U) + bbr->inflight_lo = min(bbr->inflight_lo, bbr->inflight_hi); + + bbr_set_cycle_idx(sk, BBR_BW_PROBE_CRUISE); +} + +/* Loss and/or ECN rate is too high while probing. + * Adapt (once per bw probe) by cutting inflight_hi and then restarting cycle. + */ +static void bbr_handle_inflight_too_high(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + const u32 beta = bbr_param(sk, beta); + + bbr->prev_probe_too_high = 1; + bbr->bw_probe_samples = 0; /* only react once per probe */ + /* If we are app-limited then we are not robustly + * probing the max volume of inflight data we think + * might be safe (analogous to how app-limited bw + * samples are not known to be robustly probing bw). + */ + if (!rs->is_app_limited) { + bbr->inflight_hi = max_t(u32, rs->tx_in_flight, + (u64)bbr_target_inflight(sk) * + (BBR_UNIT - beta) >> BBR_SCALE); + } + if (bbr->mode == BBR_PROBE_BW && bbr->cycle_idx == BBR_BW_PROBE_UP) + bbr_start_bw_probe_down(sk); +} + +/* If we're seeing bw and loss samples reflecting our bw probing, adapt + * using the signals we see. If loss or ECN mark rate gets too high, then adapt + * inflight_hi downward. If we're able to push inflight higher without such + * signals, push higher: adapt inflight_hi upward. + */ +static bool bbr_adapt_upper_bounds(struct sock *sk, + const struct rate_sample *rs, + struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + + /* Track when we'll see bw/loss samples resulting from our bw probes. */ + if (bbr->ack_phase == BBR_ACKS_PROBE_STARTING && bbr->round_start) + bbr->ack_phase = BBR_ACKS_PROBE_FEEDBACK; + if (bbr->ack_phase == BBR_ACKS_PROBE_STOPPING && bbr->round_start) { + /* End of samples from bw probing phase. */ + bbr->bw_probe_samples = 0; + bbr->ack_phase = BBR_ACKS_INIT; + /* At this point in the cycle, our current bw sample is also + * our best recent chance at finding the highest available bw + * for this flow. So now is the best time to forget the bw + * samples from the previous cycle, by advancing the window. + */ + if (bbr->mode == BBR_PROBE_BW && !rs->is_app_limited) + bbr_advance_max_bw_filter(sk); + /* If we had an inflight_hi, then probed and pushed inflight all + * the way up to hit that inflight_hi without seeing any + * high loss/ECN in all the resulting ACKs from that probing, + * then probe up again, this time letting inflight persist at + * inflight_hi for a round trip, then accelerating beyond. + */ + if (bbr->mode == BBR_PROBE_BW && + bbr->stopped_risky_probe && !bbr->prev_probe_too_high) { + bbr_start_bw_probe_refill(sk, 0); + return true; /* yes, decided state transition */ + } + } + if (bbr_is_inflight_too_high(sk, rs)) { + if (bbr->bw_probe_samples) /* sample is from bw probing? */ + bbr_handle_inflight_too_high(sk, rs); + } else { + /* Loss/ECN rate is declared safe. Adjust upper bound upward. */ + + if (bbr->inflight_hi == ~0U) + return false; /* no excess queue signals yet */ + + /* To be resilient to random loss, we must raise bw/inflight_hi + * if we observe in any phase that a higher level is safe. + */ + if (rs->tx_in_flight > bbr->inflight_hi) { + bbr->inflight_hi = rs->tx_in_flight; + } + + if (bbr->mode == BBR_PROBE_BW && + bbr->cycle_idx == BBR_BW_PROBE_UP) + bbr_probe_inflight_hi_upward(sk, rs); + } + + return false; +} + +/* Check if it's time to probe for bandwidth now, and if so, kick it off. */ +static bool bbr_check_time_to_probe_bw(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 n; + + /* If we seem to be at an operating point where we are not seeing loss + * but we are seeing ECN marks, then when the ECN marks cease we reprobe + * quickly (in case cross-traffic has ceased and freed up bw). + */ + if (bbr_param(sk, ecn_reprobe_gain) && bbr->ecn_eligible && + bbr->ecn_in_cycle && !bbr->loss_in_cycle && + inet_csk(sk)->icsk_ca_state == TCP_CA_Open) { + /* Calculate n so that when bbr_raise_inflight_hi_slope() + * computes growth_this_round as 2^n it will be roughly the + * desired volume of data (inflight_hi*ecn_reprobe_gain). + */ + n = ilog2((((u64)bbr->inflight_hi * + bbr_param(sk, ecn_reprobe_gain)) >> BBR_SCALE)); + bbr_start_bw_probe_refill(sk, n); + return true; + } + + if (bbr_has_elapsed_in_phase(sk, bbr->probe_wait_us) || + bbr_is_reno_coexistence_probe_time(sk)) { + bbr_start_bw_probe_refill(sk, 0); + return true; + } + return false; +} + +/* Is it time to transition from PROBE_DOWN to PROBE_CRUISE? */ +static bool bbr_check_time_to_cruise(struct sock *sk, u32 inflight, u32 bw) +{ + /* Always need to pull inflight down to leave headroom in queue. */ + if (inflight > bbr_inflight_with_headroom(sk)) + return false; + + return inflight <= bbr_inflight(sk, bw, BBR_UNIT); +} + +/* PROBE_BW state machine: cruise, refill, probe for bw, or drain? */ +static void bbr_update_cycle_phase(struct sock *sk, + const struct rate_sample *rs, + struct bbr_context *ctx) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + bool is_bw_probe_done = false; + u32 inflight, bw; + + if (!bbr_full_bw_reached(sk)) + return; + + /* In DRAIN, PROBE_BW, or PROBE_RTT, adjust upper bounds. */ + if (bbr_adapt_upper_bounds(sk, rs, ctx)) + return; /* already decided state transition */ + + if (bbr->mode != BBR_PROBE_BW) + return; + + inflight = bbr_packets_in_net_at_edt(sk, rs->prior_in_flight); + bw = bbr_max_bw(sk); + + switch (bbr->cycle_idx) { + /* First we spend most of our time cruising with a pacing_gain of 1.0, + * which paces at the estimated bw, to try to fully use the pipe + * without building queue. If we encounter loss/ECN marks, we adapt + * by slowing down. + */ + case BBR_BW_PROBE_CRUISE: + if (bbr_check_time_to_probe_bw(sk, rs)) + return; /* already decided state transition */ + break; + + /* After cruising, when it's time to probe, we first "refill": we send + * at the estimated bw to fill the pipe, before probing higher and + * knowingly risking overflowing the bottleneck buffer (causing loss). + */ + case BBR_BW_PROBE_REFILL: + if (bbr->round_start) { + /* After one full round trip of sending in REFILL, we + * start to see bw samples reflecting our REFILL, which + * may be putting too much data in flight. + */ + bbr->bw_probe_samples = 1; + bbr_start_bw_probe_up(sk, ctx); + } + break; + + /* After we refill the pipe, we probe by using a pacing_gain > 1.0, to + * probe for bw. If we have not seen loss/ECN, we try to raise inflight + * to at least pacing_gain*BDP; note that this may take more than + * min_rtt if min_rtt is small (e.g. on a LAN). + * + * We terminate PROBE_UP bandwidth probing upon any of the following: + * + * (1) We've pushed inflight up to hit the inflight_hi target set in the + * most recent previous bw probe phase. Thus we want to start + * draining the queue immediately because it's very likely the most + * recently sent packets will fill the queue and cause drops. + * (2) If inflight_hi has not limited bandwidth growth recently, and + * yet delivered bandwidth has not increased much recently + * (bbr->full_bw_now). + * (3) Loss filter says loss rate is "too high". + * (4) ECN filter says ECN mark rate is "too high". + * + * (1) (2) checked here, (3) (4) checked in bbr_is_inflight_too_high() + */ + case BBR_BW_PROBE_UP: + if (bbr->prev_probe_too_high && + inflight >= bbr->inflight_hi) { + bbr->stopped_risky_probe = 1; + is_bw_probe_done = true; + } else { + if (tp->is_cwnd_limited && + tcp_snd_cwnd(tp) >= bbr->inflight_hi) { + /* inflight_hi is limiting bw growth */ + bbr_reset_full_bw(sk); + bbr->full_bw = ctx->sample_bw; + } else if (bbr->full_bw_now) { + /* Plateau in estimated bw. Pipe looks full. */ + is_bw_probe_done = true; + } + } + if (is_bw_probe_done) { + bbr->prev_probe_too_high = 0; /* no loss/ECN (yet) */ + bbr_start_bw_probe_down(sk); /* restart w/ down */ + } + break; + + /* After probing in PROBE_UP, we have usually accumulated some data in + * the bottleneck buffer (if bw probing didn't find more bw). We next + * enter PROBE_DOWN to try to drain any excess data from the queue. To + * do this, we use a pacing_gain < 1.0. We hold this pacing gain until + * our inflight is less then that target cruising point, which is the + * minimum of (a) the amount needed to leave headroom, and (b) the + * estimated BDP. Once inflight falls to match the target, we estimate + * the queue is drained; persisting would underutilize the pipe. + */ + case BBR_BW_PROBE_DOWN: + if (bbr_check_time_to_probe_bw(sk, rs)) + return; /* already decided state transition */ + if (bbr_check_time_to_cruise(sk, inflight, bw)) + bbr_start_bw_probe_cruise(sk); + break; + + default: + WARN_ONCE(1, "BBR invalid cycle index %u\n", bbr->cycle_idx); + } +} + +/* Exiting PROBE_RTT, so return to bandwidth probing in STARTUP or PROBE_BW. */ +static void bbr_exit_probe_rtt(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr_reset_lower_bounds(sk); + if (bbr_full_bw_reached(sk)) { + bbr->mode = BBR_PROBE_BW; + /* Raising inflight after PROBE_RTT may cause loss, so reset + * the PROBE_BW clock and schedule the next bandwidth probe for + * a friendly and randomized future point in time. + */ + bbr_start_bw_probe_down(sk); + /* Since we are exiting PROBE_RTT, we know inflight is + * below our estimated BDP, so it is reasonable to cruise. + */ + bbr_start_bw_probe_cruise(sk); + } else { + bbr->mode = BBR_STARTUP; + } +} + +/* Exit STARTUP based on loss rate > 1% and loss gaps in round >= N. Wait until + * the end of the round in recovery to get a good estimate of how many packets + * have been lost, and how many we need to drain with a low pacing rate. + */ +static void bbr_check_loss_too_high_in_startup(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr_full_bw_reached(sk)) + return; + + /* For STARTUP exit, check the loss rate at the end of each round trip + * of Recovery episodes in STARTUP. We check the loss rate at the end + * of the round trip to filter out noisy/low loss and have a better + * sense of inflight (extent of loss), so we can drain more accurately. + */ + if (rs->losses && bbr->loss_events_in_round < 0xf) + bbr->loss_events_in_round++; /* update saturating counter */ + if (bbr_param(sk, full_loss_cnt) && bbr->loss_round_start && + inet_csk(sk)->icsk_ca_state == TCP_CA_Recovery && + bbr->loss_events_in_round >= bbr_param(sk, full_loss_cnt) && + bbr_is_inflight_too_high(sk, rs)) { + bbr_handle_queue_too_high_in_startup(sk); + return; + } + if (bbr->loss_round_start) + bbr->loss_events_in_round = 0; +} + +/* Estimate when the pipe is full, using the change in delivery rate: BBR + * estimates bw probing filled the pipe if the estimated bw hasn't changed by + * at least bbr_full_bw_thresh (25%) after bbr_full_bw_cnt (3) non-app-limited + * rounds. Why 3 rounds: 1: rwin autotuning grows the rwin, 2: we fill the + * higher rwin, 3: we get higher delivery rate samples. Or transient + * cross-traffic or radio noise can go away. CUBIC Hystart shares a similar + * design goal, but uses delay and inter-ACK spacing instead of bandwidth. + */ +static void bbr_check_full_bw_reached(struct sock *sk, + const struct rate_sample *rs, + struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 bw_thresh, full_cnt, thresh; + + if (bbr->full_bw_now || rs->is_app_limited) + return; + + thresh = bbr_param(sk, full_bw_thresh); + full_cnt = bbr_param(sk, full_bw_cnt); + bw_thresh = (u64)bbr->full_bw * thresh >> BBR_SCALE; + if (ctx->sample_bw >= bw_thresh) { + bbr_reset_full_bw(sk); + bbr->full_bw = ctx->sample_bw; + return; + } + if (!bbr->round_start) + return; + ++bbr->full_bw_cnt; + bbr->full_bw_now = bbr->full_bw_cnt >= full_cnt; + bbr->full_bw_reached |= bbr->full_bw_now; +} + +/* If pipe is probably full, drain the queue and then enter steady-state. */ +static void bbr_check_drain(struct sock *sk, const struct rate_sample *rs, + struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr->mode == BBR_STARTUP && bbr_full_bw_reached(sk)) { + bbr->mode = BBR_DRAIN; /* drain queue we created */ + /* Set ssthresh to export purely for monitoring, to signal + * completion of initial STARTUP by setting to a non- + * TCP_INFINITE_SSTHRESH value (ssthresh is not used by BBR). + */ + tcp_sk(sk)->snd_ssthresh = + bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT); + bbr_reset_congestion_signals(sk); + } /* fall through to check if in-flight is already small: */ + if (bbr->mode == BBR_DRAIN && + bbr_packets_in_net_at_edt(sk, tcp_packets_in_flight(tcp_sk(sk))) <= + bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT)) { + bbr->mode = BBR_PROBE_BW; + bbr_start_bw_probe_down(sk); + } +} + +static void bbr_update_model(struct sock *sk, const struct rate_sample *rs, + struct bbr_context *ctx) +{ + bbr_update_congestion_signals(sk, rs, ctx); + bbr_update_ack_aggregation(sk, rs); + bbr_check_loss_too_high_in_startup(sk, rs); + bbr_check_full_bw_reached(sk, rs, ctx); + bbr_check_drain(sk, rs, ctx); + bbr_update_cycle_phase(sk, rs, ctx); + bbr_update_min_rtt(sk, rs); +} + +/* Fast path for app-limited case. + * + * On each ack, we execute bbr state machine, which primarily consists of: + * 1) update model based on new rate sample, and + * 2) update control based on updated model or state change. + * + * There are certain workload/scenarios, e.g. app-limited case, where + * either we can skip updating model or we can skip update of both model + * as well as control. This provides signifcant softirq cpu savings for + * processing incoming acks. + * + * In case of app-limited, if there is no congestion (loss/ecn) and + * if observed bw sample is less than current estimated bw, then we can + * skip some of the computation in bbr state processing: + * + * - if there is no rtt/mode/phase change: In this case, since all the + * parameters of the network model are constant, we can skip model + * as well control update. + * + * - else we can skip rest of the model update. But we still need to + * update the control to account for the new rtt/mode/phase. + * + * Returns whether we can take fast path or not. + */ +static bool bbr_run_fast_path(struct sock *sk, bool *update_model, + const struct rate_sample *rs, struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 prev_min_rtt_us, prev_mode; + + if (bbr_param(sk, fast_path) && bbr->try_fast_path && + rs->is_app_limited && ctx->sample_bw < bbr_max_bw(sk) && + !bbr->loss_in_round && !bbr->ecn_in_round ) { + prev_mode = bbr->mode; + prev_min_rtt_us = bbr->min_rtt_us; + bbr_check_drain(sk, rs, ctx); + bbr_update_cycle_phase(sk, rs, ctx); + bbr_update_min_rtt(sk, rs); + + if (bbr->mode == prev_mode && + bbr->min_rtt_us == prev_min_rtt_us && + bbr->try_fast_path) { + return true; + } + + /* Skip model update, but control still needs to be updated */ + *update_model = false; + } + return false; +} + +__bpf_kfunc static void bbr_main(struct sock *sk, u32 ack, int flag, const struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + struct bbr_context ctx = { 0 }; + bool update_model = true; + u32 bw, round_delivered; + int ce_ratio = -1; + + round_delivered = bbr_update_round_start(sk, rs, &ctx); + if (bbr->round_start) { + bbr->rounds_since_probe = + min_t(s32, bbr->rounds_since_probe + 1, 0xFF); + ce_ratio = bbr_update_ecn_alpha(sk); + } + bbr_plb(sk, rs, ce_ratio); + + bbr->ecn_in_round |= (bbr->ecn_eligible && rs->is_ece); + bbr_calculate_bw_sample(sk, rs, &ctx); + bbr_update_latest_delivery_signals(sk, rs, &ctx); + + if (bbr_run_fast_path(sk, &update_model, rs, &ctx)) + goto out; + + if (update_model) + bbr_update_model(sk, rs, &ctx); + + bbr_update_gains(sk); + bw = bbr_bw(sk); + bbr_set_pacing_rate(sk, bw, bbr->pacing_gain); + bbr_set_cwnd(sk, rs, rs->acked_sacked, bw, bbr->cwnd_gain, + tcp_snd_cwnd(tp), &ctx); + bbr_bound_cwnd_for_inflight_model(sk); + +out: + bbr_advance_latest_delivery_signals(sk, rs, &ctx); + bbr->prev_ca_state = inet_csk(sk)->icsk_ca_state; + bbr->loss_in_cycle |= rs->lost > 0; + bbr->ecn_in_cycle |= rs->delivered_ce > 0; +} + +__bpf_kfunc static void bbr_init(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr->initialized = 1; + + bbr->init_cwnd = min(0x7FU, tcp_snd_cwnd(tp)); + bbr->prior_cwnd = tp->prior_cwnd; + tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; + bbr->next_rtt_delivered = tp->delivered; + bbr->prev_ca_state = TCP_CA_Open; + + bbr->probe_rtt_done_stamp = 0; + bbr->probe_rtt_round_done = 0; + bbr->probe_rtt_min_us = tcp_min_rtt(tp); + bbr->probe_rtt_min_stamp = tcp_jiffies32; + bbr->min_rtt_us = tcp_min_rtt(tp); + bbr->min_rtt_stamp = tcp_jiffies32; + + bbr->has_seen_rtt = 0; + bbr_init_pacing_rate_from_rtt(sk); + + bbr->round_start = 0; + bbr->idle_restart = 0; + bbr->full_bw_reached = 0; + bbr->full_bw = 0; bbr->full_bw_cnt = 0; - bbr_reset_lt_bw_sampling(sk); - return tcp_snd_cwnd(tcp_sk(sk)); + bbr->cycle_mstamp = 0; + bbr->cycle_idx = 0; + + bbr_reset_startup_mode(sk); + + bbr->ack_epoch_mstamp = tp->tcp_mstamp; + bbr->ack_epoch_acked = 0; + bbr->extra_acked_win_rtts = 0; + bbr->extra_acked_win_idx = 0; + bbr->extra_acked[0] = 0; + bbr->extra_acked[1] = 0; + + bbr->ce_state = 0; + bbr->prior_rcv_nxt = tp->rcv_nxt; + bbr->try_fast_path = 0; + + cmpxchg(&sk->sk_pacing_status, SK_PACING_NONE, SK_PACING_NEEDED); + + /* Start sampling ECN mark rate after first full flight is ACKed: */ + bbr->loss_round_delivered = tp->delivered + 1; + bbr->loss_round_start = 0; + bbr->undo_bw_lo = 0; + bbr->undo_inflight_lo = 0; + bbr->undo_inflight_hi = 0; + bbr->loss_events_in_round = 0; + bbr->startup_ecn_rounds = 0; + bbr_reset_congestion_signals(sk); + bbr->bw_lo = ~0U; + bbr->bw_hi[0] = 0; + bbr->bw_hi[1] = 0; + bbr->inflight_lo = ~0U; + bbr->inflight_hi = ~0U; + bbr_reset_full_bw(sk); + bbr->bw_probe_up_cnt = ~0U; + bbr->bw_probe_up_acks = 0; + bbr->bw_probe_up_rounds = 0; + bbr->probe_wait_us = 0; + bbr->stopped_risky_probe = 0; + bbr->ack_phase = BBR_ACKS_INIT; + bbr->rounds_since_probe = 0; + bbr->bw_probe_samples = 0; + bbr->prev_probe_too_high = 0; + bbr->ecn_eligible = 0; + bbr->ecn_alpha = bbr_param(sk, ecn_alpha_init); + bbr->alpha_last_delivered = 0; + bbr->alpha_last_delivered_ce = 0; + bbr->plb.pause_until = 0; + + tp->fast_ack_mode = bbr_fast_ack_mode ? 1 : 0; + + if (bbr_can_use_ecn(sk)) + tp->ecn_flags |= TCP_ECN_ECT_PERMANENT; +} + +/* BBR marks the current round trip as a loss round. */ +static void bbr_note_loss(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + /* Capture "current" data over the full round trip of loss, to + * have a better chance of observing the full capacity of the path. + */ + if (!bbr->loss_in_round) /* first loss in this round trip? */ + bbr->loss_round_delivered = tp->delivered; /* set round trip */ + bbr->loss_in_round = 1; + bbr->loss_in_cycle = 1; } -/* Entering loss recovery, so save cwnd for when we exit or undo recovery. */ +/* Core TCP stack informs us that the given skb was just marked lost. */ +__bpf_kfunc static void bbr_skb_marked_lost(struct sock *sk, + const struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + struct tcp_skb_cb *scb = TCP_SKB_CB(skb); + struct rate_sample rs = {}; + + bbr_note_loss(sk); + + if (!bbr->bw_probe_samples) + return; /* not an skb sent while probing for bandwidth */ + if (unlikely(!scb->tx.delivered_mstamp)) + return; /* skb was SACKed, reneged, marked lost; ignore it */ + /* We are probing for bandwidth. Construct a rate sample that + * estimates what happened in the flight leading up to this lost skb, + * then see if the loss rate went too high, and if so at which packet. + */ + rs.tx_in_flight = scb->tx.in_flight; + rs.lost = tp->lost - scb->tx.lost; + rs.is_app_limited = scb->tx.is_app_limited; + if (bbr_is_inflight_too_high(sk, &rs)) { + rs.tx_in_flight = bbr_inflight_hi_from_lost_skb(sk, &rs, skb); + bbr_handle_inflight_too_high(sk, &rs); + } +} + +static void bbr_run_loss_probe_recovery(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + struct rate_sample rs = {0}; + + bbr_note_loss(sk); + + if (!bbr->bw_probe_samples) + return; /* not sent while probing for bandwidth */ + /* We are probing for bandwidth. Construct a rate sample that + * estimates what happened in the flight leading up to this + * loss, then see if the loss rate went too high. + */ + rs.lost = 1; /* TLP probe repaired loss of a single segment */ + rs.tx_in_flight = bbr->inflight_latest + rs.lost; + rs.is_app_limited = tp->tlp_orig_data_app_limited; + if (bbr_is_inflight_too_high(sk, &rs)) + bbr_handle_inflight_too_high(sk, &rs); +} + +/* Revert short-term model if current loss recovery event was spurious. */ +__bpf_kfunc static u32 bbr_undo_cwnd(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr_reset_full_bw(sk); /* spurious slow-down; reset full bw detector */ + bbr->loss_in_round = 0; + + /* Revert to cwnd and other state saved before loss episode. */ + bbr->bw_lo = max(bbr->bw_lo, bbr->undo_bw_lo); + bbr->inflight_lo = max(bbr->inflight_lo, bbr->undo_inflight_lo); + bbr->inflight_hi = max(bbr->inflight_hi, bbr->undo_inflight_hi); + bbr->try_fast_path = 0; /* take slow path to set proper cwnd, pacing */ + return bbr->prior_cwnd; +} + +/* Entering loss recovery, so save state for when we undo recovery. */ __bpf_kfunc static u32 bbr_ssthresh(struct sock *sk) { + struct bbr *bbr = inet_csk_ca(sk); + bbr_save_cwnd(sk); + /* For undo, save state that adapts based on loss signal. */ + bbr->undo_bw_lo = bbr->bw_lo; + bbr->undo_inflight_lo = bbr->inflight_lo; + bbr->undo_inflight_hi = bbr->inflight_hi; return tcp_sk(sk)->snd_ssthresh; } +static enum tcp_bbr_phase bbr_get_phase(struct bbr *bbr) +{ + switch (bbr->mode) { + case BBR_STARTUP: + return BBR_PHASE_STARTUP; + case BBR_DRAIN: + return BBR_PHASE_DRAIN; + case BBR_PROBE_BW: + break; + case BBR_PROBE_RTT: + return BBR_PHASE_PROBE_RTT; + default: + return BBR_PHASE_INVALID; + } + switch (bbr->cycle_idx) { + case BBR_BW_PROBE_UP: + return BBR_PHASE_PROBE_BW_UP; + case BBR_BW_PROBE_DOWN: + return BBR_PHASE_PROBE_BW_DOWN; + case BBR_BW_PROBE_CRUISE: + return BBR_PHASE_PROBE_BW_CRUISE; + case BBR_BW_PROBE_REFILL: + return BBR_PHASE_PROBE_BW_REFILL; + default: + return BBR_PHASE_INVALID; + } +} + static size_t bbr_get_info(struct sock *sk, u32 ext, int *attr, - union tcp_cc_info *info) + union tcp_cc_info *info) { if (ext & (1 << (INET_DIAG_BBRINFO - 1)) || ext & (1 << (INET_DIAG_VEGASINFO - 1))) { - struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - u64 bw = bbr_bw(sk); - - bw = bw * tp->mss_cache * USEC_PER_SEC >> BW_SCALE; - memset(&info->bbr, 0, sizeof(info->bbr)); - info->bbr.bbr_bw_lo = (u32)bw; - info->bbr.bbr_bw_hi = (u32)(bw >> 32); - info->bbr.bbr_min_rtt = bbr->min_rtt_us; - info->bbr.bbr_pacing_gain = bbr->pacing_gain; - info->bbr.bbr_cwnd_gain = bbr->cwnd_gain; + u64 bw = bbr_bw_bytes_per_sec(sk, bbr_bw(sk)); + u64 bw_hi = bbr_bw_bytes_per_sec(sk, bbr_max_bw(sk)); + u64 bw_lo = bbr->bw_lo == ~0U ? + ~0ULL : bbr_bw_bytes_per_sec(sk, bbr->bw_lo); + struct tcp_bbr_info *bbr_info = &info->bbr; + + memset(bbr_info, 0, sizeof(*bbr_info)); + bbr_info->bbr_bw_lo = (u32)bw; + bbr_info->bbr_bw_hi = (u32)(bw >> 32); + bbr_info->bbr_min_rtt = bbr->min_rtt_us; + bbr_info->bbr_pacing_gain = bbr->pacing_gain; + bbr_info->bbr_cwnd_gain = bbr->cwnd_gain; + bbr_info->bbr_bw_hi_lsb = (u32)bw_hi; + bbr_info->bbr_bw_hi_msb = (u32)(bw_hi >> 32); + bbr_info->bbr_bw_lo_lsb = (u32)bw_lo; + bbr_info->bbr_bw_lo_msb = (u32)(bw_lo >> 32); + bbr_info->bbr_mode = bbr->mode; + bbr_info->bbr_phase = (__u8)bbr_get_phase(bbr); + bbr_info->bbr_version = (__u8)BBR_VERSION; + bbr_info->bbr_inflight_lo = bbr->inflight_lo; + bbr_info->bbr_inflight_hi = bbr->inflight_hi; + bbr_info->bbr_extra_acked = bbr_extra_acked(sk); *attr = INET_DIAG_BBRINFO; - return sizeof(info->bbr); + return sizeof(*bbr_info); } return 0; } __bpf_kfunc static void bbr_set_state(struct sock *sk, u8 new_state) { + struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); if (new_state == TCP_CA_Loss) { - struct rate_sample rs = { .losses = 1 }; bbr->prev_ca_state = TCP_CA_Loss; - bbr->full_bw = 0; - bbr->round_start = 1; /* treat RTO like end of a round */ - bbr_lt_bw_sampling(sk, &rs); + tcp_plb_update_state_upon_rto(sk, &bbr->plb); + /* The tcp_write_timeout() call to sk_rethink_txhash() likely + * repathed this flow, so re-learn the min network RTT on the + * new path: + */ + bbr_reset_full_bw(sk); + if (!bbr_is_probing_bandwidth(sk) && bbr->inflight_lo == ~0U) { + /* bbr_adapt_lower_bounds() needs cwnd before + * we suffered an RTO, to update inflight_lo: + */ + bbr->inflight_lo = + max(tcp_snd_cwnd(tp), bbr->prior_cwnd); + } + } else if (bbr->prev_ca_state == TCP_CA_Loss && + new_state != TCP_CA_Loss) { + bbr_exit_loss_recovery(sk); } } + static struct tcp_congestion_ops tcp_bbr_cong_ops __read_mostly = { - .flags = TCP_CONG_NON_RESTRICTED, + .flags = TCP_CONG_NON_RESTRICTED | TCP_CONG_WANTS_CE_EVENTS, .name = "bbr", .owner = THIS_MODULE, .init = bbr_init, .cong_control = bbr_main, .sndbuf_expand = bbr_sndbuf_expand, + .skb_marked_lost = bbr_skb_marked_lost, .undo_cwnd = bbr_undo_cwnd, .cwnd_event = bbr_cwnd_event, .ssthresh = bbr_ssthresh, - .min_tso_segs = bbr_min_tso_segs, + .tso_segs = bbr_tso_segs, .get_info = bbr_get_info, .set_state = bbr_set_state, }; @@ -1159,10 +2359,11 @@ BTF_KFUNCS_START(tcp_bbr_check_kfunc_ids) BTF_ID_FLAGS(func, bbr_init) BTF_ID_FLAGS(func, bbr_main) BTF_ID_FLAGS(func, bbr_sndbuf_expand) +BTF_ID_FLAGS(func, bbr_skb_marked_lost) BTF_ID_FLAGS(func, bbr_undo_cwnd) BTF_ID_FLAGS(func, bbr_cwnd_event) BTF_ID_FLAGS(func, bbr_ssthresh) -BTF_ID_FLAGS(func, bbr_min_tso_segs) +BTF_ID_FLAGS(func, bbr_tso_segs) BTF_ID_FLAGS(func, bbr_set_state) BTF_KFUNCS_END(tcp_bbr_check_kfunc_ids) @@ -1195,5 +2396,12 @@ MODULE_AUTHOR("Van Jacobson "); MODULE_AUTHOR("Neal Cardwell "); MODULE_AUTHOR("Yuchung Cheng "); MODULE_AUTHOR("Soheil Hassas Yeganeh "); +MODULE_AUTHOR("Priyaranjan Jha "); +MODULE_AUTHOR("Yousuk Seung "); +MODULE_AUTHOR("Kevin Yang "); +MODULE_AUTHOR("Arjun Roy "); +MODULE_AUTHOR("David Morley "); + MODULE_LICENSE("Dual BSD/GPL"); MODULE_DESCRIPTION("TCP BBR (Bottleneck Bandwidth and RTT)"); +MODULE_VERSION(__stringify(BBR_VERSION)); diff --git a/net/ipv4/tcp_cong.c b/net/ipv4/tcp_cong.c index 0306d257fa64..28f581c0dab7 100644 --- a/net/ipv4/tcp_cong.c +++ b/net/ipv4/tcp_cong.c @@ -237,6 +237,7 @@ void tcp_init_congestion_control(struct sock *sk) struct inet_connection_sock *icsk = inet_csk(sk); tcp_sk(sk)->prior_ssthresh = 0; + tcp_sk(sk)->fast_ack_mode = 0; if (icsk->icsk_ca_ops->init) icsk->icsk_ca_ops->init(sk); if (tcp_ca_needs_ecn(sk)) diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c index 2d844e1f867f..efb92e47a632 100644 --- a/net/ipv4/tcp_input.c +++ b/net/ipv4/tcp_input.c @@ -370,7 +370,7 @@ static void __tcp_ecn_check_ce(struct sock *sk, const struct sk_buff *skb) tcp_enter_quickack_mode(sk, 2); break; case INET_ECN_CE: - if (tcp_ca_needs_ecn(sk)) + if (tcp_ca_wants_ce_events(sk)) tcp_ca_event(sk, CA_EVENT_ECN_IS_CE); if (!(tp->ecn_flags & TCP_ECN_DEMAND_CWR)) { @@ -381,7 +381,7 @@ static void __tcp_ecn_check_ce(struct sock *sk, const struct sk_buff *skb) tp->ecn_flags |= TCP_ECN_SEEN; break; default: - if (tcp_ca_needs_ecn(sk)) + if (tcp_ca_wants_ce_events(sk)) tcp_ca_event(sk, CA_EVENT_ECN_NO_CE); tp->ecn_flags |= TCP_ECN_SEEN; break; @@ -1120,7 +1120,12 @@ static void tcp_verify_retransmit_hint(struct tcp_sock *tp, struct sk_buff *skb) */ static void tcp_notify_skb_loss_event(struct tcp_sock *tp, const struct sk_buff *skb) { + struct sock *sk = (struct sock *)tp; + const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops; + tp->lost += tcp_skb_pcount(skb); + if (ca_ops->skb_marked_lost) + ca_ops->skb_marked_lost(sk, skb); } void tcp_mark_skb_lost(struct sock *sk, struct sk_buff *skb) @@ -1501,6 +1506,17 @@ static bool tcp_shifted_skb(struct sock *sk, struct sk_buff *prev, WARN_ON_ONCE(tcp_skb_pcount(skb) < pcount); tcp_skb_pcount_add(skb, -pcount); + /* Adjust tx.in_flight as pcount is shifted from skb to prev. */ + if (WARN_ONCE(TCP_SKB_CB(skb)->tx.in_flight < pcount, + "prev in_flight: %u skb in_flight: %u pcount: %u", + TCP_SKB_CB(prev)->tx.in_flight, + TCP_SKB_CB(skb)->tx.in_flight, + pcount)) + TCP_SKB_CB(skb)->tx.in_flight = 0; + else + TCP_SKB_CB(skb)->tx.in_flight -= pcount; + TCP_SKB_CB(prev)->tx.in_flight += pcount; + /* When we're adding to gso_segs == 1, gso_size will be zero, * in theory this shouldn't be necessary but as long as DSACK * code can come after this skb later on it's better to keep @@ -3826,7 +3842,8 @@ static void tcp_replace_ts_recent(struct tcp_sock *tp, u32 seq) /* This routine deals with acks during a TLP episode and ends an episode by * resetting tlp_high_seq. Ref: TLP algorithm in draft-ietf-tcpm-rack */ -static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag) +static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag, + struct rate_sample *rs) { struct tcp_sock *tp = tcp_sk(sk); @@ -3843,6 +3860,7 @@ static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag) /* ACK advances: there was a loss, so reduce cwnd. Reset * tlp_high_seq in tcp_init_cwnd_reduction() */ + tcp_ca_event(sk, CA_EVENT_TLP_RECOVERY); tcp_init_cwnd_reduction(sk); tcp_set_ca_state(sk, TCP_CA_CWR); tcp_end_cwnd_reduction(sk); @@ -3853,6 +3871,11 @@ static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag) FLAG_NOT_DUP | FLAG_DATA_SACKED))) { /* Pure dupack: original and TLP probe arrived; no loss */ tp->tlp_high_seq = 0; + } else { + /* This ACK matches a TLP retransmit. We cannot yet tell if + * this ACK is for the original or the TLP retransmit. + */ + rs->is_acking_tlp_retrans_seq = 1; } } @@ -3961,6 +3984,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) prior_fack = tcp_is_sack(tp) ? tcp_highest_sack_seq(tp) : tp->snd_una; rs.prior_in_flight = tcp_packets_in_flight(tp); + tcp_rate_check_app_limited(sk); /* ts_recent update must be made after we are sure that the packet * is in window. @@ -4035,7 +4059,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) tcp_rack_update_reo_wnd(sk, &rs); if (tp->tlp_high_seq) - tcp_process_tlp_ack(sk, ack, flag); + tcp_process_tlp_ack(sk, ack, flag, &rs); if (tcp_ack_is_dubious(sk, flag)) { if (!(flag & (FLAG_SND_UNA_ADVANCED | @@ -4059,6 +4083,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) delivered = tcp_newly_delivered(sk, delivered, flag); lost = tp->lost - lost; /* freshly marked lost */ rs.is_ack_delayed = !!(flag & FLAG_ACK_MAYBE_DELAYED); + rs.is_ece = !!(flag & FLAG_ECE); tcp_rate_gen(sk, delivered, lost, is_sack_reneg, sack_state.rate); tcp_cong_control(sk, ack, delivered, flag, sack_state.rate); tcp_xmit_recovery(sk, rexmit); @@ -4078,7 +4103,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) tcp_ack_probe(sk); if (tp->tlp_high_seq) - tcp_process_tlp_ack(sk, ack, flag); + tcp_process_tlp_ack(sk, ack, flag, &rs); return 1; old_ack: @@ -5752,13 +5777,14 @@ static void __tcp_ack_snd_check(struct sock *sk, int ofo_possible) /* More than one full frame received... */ if (((tp->rcv_nxt - tp->rcv_wup) > inet_csk(sk)->icsk_ack.rcv_mss && + (tp->fast_ack_mode == 1 || /* ... and right edge of window advances far enough. * (tcp_recvmsg() will send ACK otherwise). * If application uses SO_RCVLOWAT, we want send ack now if * we have not received enough bytes to satisfy the condition. */ - (tp->rcv_nxt - tp->copied_seq < sk->sk_rcvlowat || - __tcp_select_window(sk) >= tp->rcv_wnd)) || + (tp->rcv_nxt - tp->copied_seq < sk->sk_rcvlowat || + __tcp_select_window(sk) >= tp->rcv_wnd))) || /* We ACK each frame or... */ tcp_in_quickack_mode(sk) || /* Protocol state mandates a one-time immediate ACK */ diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c index bb1fe1ba867a..050a80769de6 100644 --- a/net/ipv4/tcp_minisocks.c +++ b/net/ipv4/tcp_minisocks.c @@ -462,6 +462,8 @@ void tcp_ca_openreq_child(struct sock *sk, const struct dst_entry *dst) u32 ca_key = dst_metric(dst, RTAX_CC_ALGO); bool ca_got_dst = false; + tcp_set_ecn_low_from_dst(sk, dst); + if (ca_key != TCP_CA_UNSPEC) { const struct tcp_congestion_ops *ca; diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c index 68804fd01daf..afdb62febe42 100644 --- a/net/ipv4/tcp_output.c +++ b/net/ipv4/tcp_output.c @@ -336,10 +336,9 @@ static void tcp_ecn_send_syn(struct sock *sk, struct sk_buff *skb) bool bpf_needs_ecn = tcp_bpf_ca_needs_ecn(sk); bool use_ecn = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_ecn) == 1 || tcp_ca_needs_ecn(sk) || bpf_needs_ecn; + const struct dst_entry *dst = __sk_dst_get(sk); if (!use_ecn) { - const struct dst_entry *dst = __sk_dst_get(sk); - if (dst && dst_feature(dst, RTAX_FEATURE_ECN)) use_ecn = true; } @@ -351,6 +350,9 @@ static void tcp_ecn_send_syn(struct sock *sk, struct sk_buff *skb) tp->ecn_flags = TCP_ECN_OK; if (tcp_ca_needs_ecn(sk) || bpf_needs_ecn) INET_ECN_xmit(sk); + + if (dst) + tcp_set_ecn_low_from_dst(sk, dst); } } @@ -388,7 +390,8 @@ static void tcp_ecn_send(struct sock *sk, struct sk_buff *skb, th->cwr = 1; skb_shinfo(skb)->gso_type |= SKB_GSO_TCP_ECN; } - } else if (!tcp_ca_needs_ecn(sk)) { + } else if (!(tp->ecn_flags & TCP_ECN_ECT_PERMANENT) && + !tcp_ca_needs_ecn(sk)) { /* ACK or retransmitted segment: clear ECT|CE */ INET_ECN_dontxmit(sk); } @@ -1601,7 +1604,7 @@ int tcp_fragment(struct sock *sk, enum tcp_queue tcp_queue, { struct tcp_sock *tp = tcp_sk(sk); struct sk_buff *buff; - int old_factor; + int old_factor, inflight_prev; long limit; int nlen; u8 flags; @@ -1676,6 +1679,30 @@ int tcp_fragment(struct sock *sk, enum tcp_queue tcp_queue, if (diff) tcp_adjust_pcount(sk, skb, diff); + + inflight_prev = TCP_SKB_CB(skb)->tx.in_flight - old_factor; + if (inflight_prev < 0) { + WARN_ONCE(tcp_skb_tx_in_flight_is_suspicious( + old_factor, + TCP_SKB_CB(skb)->sacked, + TCP_SKB_CB(skb)->tx.in_flight), + "inconsistent: tx.in_flight: %u " + "old_factor: %d mss: %u sacked: %u " + "1st pcount: %d 2nd pcount: %d " + "1st len: %u 2nd len: %u ", + TCP_SKB_CB(skb)->tx.in_flight, old_factor, + mss_now, TCP_SKB_CB(skb)->sacked, + tcp_skb_pcount(skb), tcp_skb_pcount(buff), + skb->len, buff->len); + inflight_prev = 0; + } + /* Set 1st tx.in_flight as if 1st were sent by itself: */ + TCP_SKB_CB(skb)->tx.in_flight = inflight_prev + + tcp_skb_pcount(skb); + /* Set 2nd tx.in_flight with new 1st and 2nd pcounts: */ + TCP_SKB_CB(buff)->tx.in_flight = inflight_prev + + tcp_skb_pcount(skb) + + tcp_skb_pcount(buff); } /* Link BUFF into the send queue. */ @@ -2033,13 +2060,12 @@ static u32 tcp_tso_autosize(const struct sock *sk, unsigned int mss_now, static u32 tcp_tso_segs(struct sock *sk, unsigned int mss_now) { const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops; - u32 min_tso, tso_segs; - - min_tso = ca_ops->min_tso_segs ? - ca_ops->min_tso_segs(sk) : - READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs); + u32 tso_segs; - tso_segs = tcp_tso_autosize(sk, mss_now, min_tso); + tso_segs = ca_ops->tso_segs ? + ca_ops->tso_segs(sk, mss_now) : + tcp_tso_autosize(sk, mss_now, + sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs); return min_t(u32, tso_segs, sk->sk_gso_max_segs); } @@ -2765,6 +2791,7 @@ static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, skb_set_delivery_time(skb, tp->tcp_wstamp_ns, SKB_CLOCK_MONOTONIC); list_move_tail(&skb->tcp_tsorted_anchor, &tp->tsorted_sent_queue); tcp_init_tso_segs(skb, mss_now); + tcp_set_tx_in_flight(sk, skb); goto repair; /* Skip network transmission */ } @@ -2979,6 +3006,7 @@ void tcp_send_loss_probe(struct sock *sk) if (WARN_ON(!skb || !tcp_skb_pcount(skb))) goto rearm_timer; + tp->tlp_orig_data_app_limited = TCP_SKB_CB(skb)->tx.is_app_limited; if (__tcp_retransmit_skb(sk, skb, 1)) goto rearm_timer; diff --git a/net/ipv4/tcp_rate.c b/net/ipv4/tcp_rate.c index a8f6d9d06f2e..8737f2134648 100644 --- a/net/ipv4/tcp_rate.c +++ b/net/ipv4/tcp_rate.c @@ -34,6 +34,24 @@ * ready to send in the write queue. */ +void tcp_set_tx_in_flight(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 in_flight; + + /* Check, sanitize, and record packets in flight after skb was sent. */ + in_flight = tcp_packets_in_flight(tp) + tcp_skb_pcount(skb); + if (WARN_ONCE(in_flight > TCPCB_IN_FLIGHT_MAX, + "insane in_flight %u cc %s mss %u " + "cwnd %u pif %u %u %u %u\n", + in_flight, inet_csk(sk)->icsk_ca_ops->name, + tp->mss_cache, tp->snd_cwnd, + tp->packets_out, tp->retrans_out, + tp->sacked_out, tp->lost_out)) + in_flight = TCPCB_IN_FLIGHT_MAX; + TCP_SKB_CB(skb)->tx.in_flight = in_flight; +} + /* Snapshot the current delivery information in the skb, to generate * a rate sample later when the skb is (s)acked in tcp_rate_skb_delivered(). */ @@ -66,7 +84,9 @@ void tcp_rate_skb_sent(struct sock *sk, struct sk_buff *skb) TCP_SKB_CB(skb)->tx.delivered_mstamp = tp->delivered_mstamp; TCP_SKB_CB(skb)->tx.delivered = tp->delivered; TCP_SKB_CB(skb)->tx.delivered_ce = tp->delivered_ce; + TCP_SKB_CB(skb)->tx.lost = tp->lost; TCP_SKB_CB(skb)->tx.is_app_limited = tp->app_limited ? 1 : 0; + tcp_set_tx_in_flight(sk, skb); } /* When an skb is sacked or acked, we fill in the rate sample with the (prior) @@ -91,18 +111,21 @@ void tcp_rate_skb_delivered(struct sock *sk, struct sk_buff *skb, if (!rs->prior_delivered || tcp_skb_sent_after(tx_tstamp, tp->first_tx_mstamp, scb->end_seq, rs->last_end_seq)) { + rs->prior_lost = scb->tx.lost; rs->prior_delivered_ce = scb->tx.delivered_ce; rs->prior_delivered = scb->tx.delivered; rs->prior_mstamp = scb->tx.delivered_mstamp; rs->is_app_limited = scb->tx.is_app_limited; rs->is_retrans = scb->sacked & TCPCB_RETRANS; + rs->tx_in_flight = scb->tx.in_flight; rs->last_end_seq = scb->end_seq; /* Record send time of most recently ACKed packet: */ tp->first_tx_mstamp = tx_tstamp; /* Find the duration of the "send phase" of this window: */ - rs->interval_us = tcp_stamp_us_delta(tp->first_tx_mstamp, - scb->tx.first_tx_mstamp); + rs->interval_us = tcp_stamp32_us_delta( + tp->first_tx_mstamp, + scb->tx.first_tx_mstamp); } /* Mark off the skb delivered once it's sacked to avoid being @@ -144,6 +167,7 @@ void tcp_rate_gen(struct sock *sk, u32 delivered, u32 lost, return; } rs->delivered = tp->delivered - rs->prior_delivered; + rs->lost = tp->lost - rs->prior_lost; rs->delivered_ce = tp->delivered_ce - rs->prior_delivered_ce; /* delivered_ce occupies less than 32 bits in the skb control block */ @@ -155,7 +179,7 @@ void tcp_rate_gen(struct sock *sk, u32 delivered, u32 lost, * longer phase. */ snd_us = rs->interval_us; /* send phase */ - ack_us = tcp_stamp_us_delta(tp->tcp_mstamp, + ack_us = tcp_stamp32_us_delta(tp->tcp_mstamp, rs->prior_mstamp); /* ack phase */ rs->interval_us = max(snd_us, ack_us); diff --git a/net/ipv4/tcp_timer.c b/net/ipv4/tcp_timer.c index 79064580c8c0..697270ce1ea6 100644 --- a/net/ipv4/tcp_timer.c +++ b/net/ipv4/tcp_timer.c @@ -690,6 +690,7 @@ void tcp_write_timer_handler(struct sock *sk) return; } + tcp_rate_check_app_limited(sk); tcp_mstamp_refresh(tcp_sk(sk)); event = icsk->icsk_pending; -- 2.47.0 From d87383343350575ce203091b2001bde085b12fc9 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 11 Nov 2024 09:18:05 +0100 Subject: [PATCH 05/13] cachy Signed-off-by: Peter Jung --- .../admin-guide/kernel-parameters.txt | 12 + Makefile | 8 + arch/x86/Kconfig.cpu | 359 +- arch/x86/Makefile | 87 +- arch/x86/include/asm/pci.h | 6 + arch/x86/include/asm/vermagic.h | 70 + arch/x86/pci/common.c | 7 +- drivers/Makefile | 13 +- drivers/ata/ahci.c | 23 +- drivers/cpufreq/Kconfig.x86 | 2 - drivers/cpufreq/intel_pstate.c | 2 + drivers/gpu/drm/amd/amdgpu/amdgpu.h | 1 + drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c | 10 + drivers/gpu/drm/amd/display/Kconfig | 6 + .../gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c | 2 +- .../amd/display/amdgpu_dm/amdgpu_dm_color.c | 2 +- .../amd/display/amdgpu_dm/amdgpu_dm_crtc.c | 6 +- .../amd/display/amdgpu_dm/amdgpu_dm_plane.c | 6 +- drivers/gpu/drm/amd/pm/amdgpu_pm.c | 3 + drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c | 14 +- drivers/input/evdev.c | 19 +- drivers/md/dm-crypt.c | 5 + drivers/media/v4l2-core/Kconfig | 5 + drivers/media/v4l2-core/Makefile | 2 + drivers/media/v4l2-core/v4l2loopback.c | 3184 +++++++++++++++++ drivers/media/v4l2-core/v4l2loopback.h | 98 + .../media/v4l2-core/v4l2loopback_formats.h | 445 +++ drivers/pci/controller/Makefile | 6 + drivers/pci/controller/intel-nvme-remap.c | 462 +++ drivers/pci/quirks.c | 101 + include/linux/pagemap.h | 2 +- include/linux/user_namespace.h | 4 + include/linux/wait.h | 2 + init/Kconfig | 26 + kernel/Kconfig.hz | 24 + kernel/fork.c | 14 + kernel/locking/rwsem.c | 4 +- kernel/sched/fair.c | 13 + kernel/sched/sched.h | 2 +- kernel/sched/wait.c | 24 + kernel/sysctl.c | 12 + kernel/user_namespace.c | 7 + mm/Kconfig | 2 +- mm/compaction.c | 4 + mm/page-writeback.c | 8 + mm/page_alloc.c | 4 + mm/swap.c | 5 + mm/vmpressure.c | 4 + mm/vmscan.c | 8 + net/ipv4/inet_connection_sock.c | 2 +- 50 files changed, 5073 insertions(+), 64 deletions(-) create mode 100644 drivers/media/v4l2-core/v4l2loopback.c create mode 100644 drivers/media/v4l2-core/v4l2loopback.h create mode 100644 drivers/media/v4l2-core/v4l2loopback_formats.h create mode 100644 drivers/pci/controller/intel-nvme-remap.c diff --git a/Documentation/admin-guide/kernel-parameters.txt b/Documentation/admin-guide/kernel-parameters.txt index 1666576acc0e..5b0b02e6988a 100644 --- a/Documentation/admin-guide/kernel-parameters.txt +++ b/Documentation/admin-guide/kernel-parameters.txt @@ -2248,6 +2248,9 @@ disable Do not enable intel_pstate as the default scaling driver for the supported processors + enable + Enable intel_pstate in-case "disable" was passed + previously in the kernel boot parameters active Use intel_pstate driver to bypass the scaling governors layer of cpufreq and provides it own @@ -4473,6 +4476,15 @@ nomsi [MSI] If the PCI_MSI kernel config parameter is enabled, this kernel boot option can be used to disable the use of MSI interrupts system-wide. + pcie_acs_override = + [PCIE] Override missing PCIe ACS support for: + downstream + All downstream ports - full ACS capabilities + multfunction + All multifunction devices - multifunction ACS subset + id:nnnn:nnnn + Specfic device - full ACS capabilities + Specified as vid:did (vendor/device ID) in hex noioapicquirk [APIC] Disable all boot interrupt quirks. Safety option to keep boot IRQs enabled. This should never be necessary. diff --git a/Makefile b/Makefile index e619df4e09b8..7223a0d87413 100644 --- a/Makefile +++ b/Makefile @@ -801,11 +801,19 @@ KBUILD_CFLAGS += -fno-delete-null-pointer-checks ifdef CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE KBUILD_CFLAGS += -O2 KBUILD_RUSTFLAGS += -Copt-level=2 +else ifdef CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3 +KBUILD_CFLAGS += -O3 +KBUILD_RUSTFLAGS += -Copt-level=3 else ifdef CONFIG_CC_OPTIMIZE_FOR_SIZE KBUILD_CFLAGS += -Os KBUILD_RUSTFLAGS += -Copt-level=s endif +# Perform swing modulo scheduling immediately before the first scheduling pass. +# This pass looks at innermost loops and reorders their instructions by +# overlapping different iterations. +KBUILD_CFLAGS += $(call cc-option,-fmodulo-sched -fmodulo-sched-allow-regmoves -fivopts -fmodulo-sched) + # Always set `debug-assertions` and `overflow-checks` because their default # depends on `opt-level` and `debug-assertions`, respectively. KBUILD_RUSTFLAGS += -Cdebug-assertions=$(if $(CONFIG_RUST_DEBUG_ASSERTIONS),y,n) diff --git a/arch/x86/Kconfig.cpu b/arch/x86/Kconfig.cpu index 2a7279d80460..f5849153b385 100644 --- a/arch/x86/Kconfig.cpu +++ b/arch/x86/Kconfig.cpu @@ -155,9 +155,8 @@ config MPENTIUM4 -Paxville -Dempsey - config MK6 - bool "K6/K6-II/K6-III" + bool "AMD K6/K6-II/K6-III" depends on X86_32 help Select this for an AMD K6-family processor. Enables use of @@ -165,7 +164,7 @@ config MK6 flags to GCC. config MK7 - bool "Athlon/Duron/K7" + bool "AMD Athlon/Duron/K7" depends on X86_32 help Select this for an AMD Athlon K7-family processor. Enables use of @@ -173,12 +172,114 @@ config MK7 flags to GCC. config MK8 - bool "Opteron/Athlon64/Hammer/K8" + bool "AMD Opteron/Athlon64/Hammer/K8" help Select this for an AMD Opteron or Athlon64 Hammer-family processor. Enables use of some extended instructions, and passes appropriate optimization flags to GCC. +config MK8SSE3 + bool "AMD Opteron/Athlon64/Hammer/K8 with SSE3" + help + Select this for improved AMD Opteron or Athlon64 Hammer-family processors. + Enables use of some extended instructions, and passes appropriate + optimization flags to GCC. + +config MK10 + bool "AMD 61xx/7x50/PhenomX3/X4/II/K10" + help + Select this for an AMD 61xx Eight-Core Magny-Cours, Athlon X2 7x50, + Phenom X3/X4/II, Athlon II X2/X3/X4, or Turion II-family processor. + Enables use of some extended instructions, and passes appropriate + optimization flags to GCC. + +config MBARCELONA + bool "AMD Barcelona" + help + Select this for AMD Family 10h Barcelona processors. + + Enables -march=barcelona + +config MBOBCAT + bool "AMD Bobcat" + help + Select this for AMD Family 14h Bobcat processors. + + Enables -march=btver1 + +config MJAGUAR + bool "AMD Jaguar" + help + Select this for AMD Family 16h Jaguar processors. + + Enables -march=btver2 + +config MBULLDOZER + bool "AMD Bulldozer" + help + Select this for AMD Family 15h Bulldozer processors. + + Enables -march=bdver1 + +config MPILEDRIVER + bool "AMD Piledriver" + help + Select this for AMD Family 15h Piledriver processors. + + Enables -march=bdver2 + +config MSTEAMROLLER + bool "AMD Steamroller" + help + Select this for AMD Family 15h Steamroller processors. + + Enables -march=bdver3 + +config MEXCAVATOR + bool "AMD Excavator" + help + Select this for AMD Family 15h Excavator processors. + + Enables -march=bdver4 + +config MZEN + bool "AMD Zen" + help + Select this for AMD Family 17h Zen processors. + + Enables -march=znver1 + +config MZEN2 + bool "AMD Zen 2" + help + Select this for AMD Family 17h Zen 2 processors. + + Enables -march=znver2 + +config MZEN3 + bool "AMD Zen 3" + depends on (CC_IS_GCC && GCC_VERSION >= 100300) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + help + Select this for AMD Family 19h Zen 3 processors. + + Enables -march=znver3 + +config MZEN4 + bool "AMD Zen 4" + depends on (CC_IS_GCC && GCC_VERSION >= 130000) || (CC_IS_CLANG && CLANG_VERSION >= 160000) + help + Select this for AMD Family 19h Zen 4 processors. + + Enables -march=znver4 + +config MZEN5 + bool "AMD Zen 5" + depends on (CC_IS_GCC && GCC_VERSION > 140000) || (CC_IS_CLANG && CLANG_VERSION >= 191000) + help + Select this for AMD Family 19h Zen 5 processors. + + Enables -march=znver5 + config MCRUSOE bool "Crusoe" depends on X86_32 @@ -269,8 +370,17 @@ config MPSC using the cpu family field in /proc/cpuinfo. Family 15 is an older Xeon, Family 6 a newer one. +config MATOM + bool "Intel Atom" + help + + Select this for the Intel Atom platform. Intel Atom CPUs have an + in-order pipelining architecture and thus can benefit from + accordingly optimized code. Use a recent GCC with specific Atom + support in order to fully benefit from selecting this option. + config MCORE2 - bool "Core 2/newer Xeon" + bool "Intel Core 2" help Select this for Intel Core 2 and newer Core 2 Xeons (Xeon 51xx and @@ -278,14 +388,191 @@ config MCORE2 family in /proc/cpuinfo. Newer ones have 6 and older ones 15 (not a typo) -config MATOM - bool "Intel Atom" + Enables -march=core2 + +config MNEHALEM + bool "Intel Nehalem" help - Select this for the Intel Atom platform. Intel Atom CPUs have an - in-order pipelining architecture and thus can benefit from - accordingly optimized code. Use a recent GCC with specific Atom - support in order to fully benefit from selecting this option. + Select this for 1st Gen Core processors in the Nehalem family. + + Enables -march=nehalem + +config MWESTMERE + bool "Intel Westmere" + help + + Select this for the Intel Westmere formerly Nehalem-C family. + + Enables -march=westmere + +config MSILVERMONT + bool "Intel Silvermont" + help + + Select this for the Intel Silvermont platform. + + Enables -march=silvermont + +config MGOLDMONT + bool "Intel Goldmont" + help + + Select this for the Intel Goldmont platform including Apollo Lake and Denverton. + + Enables -march=goldmont + +config MGOLDMONTPLUS + bool "Intel Goldmont Plus" + help + + Select this for the Intel Goldmont Plus platform including Gemini Lake. + + Enables -march=goldmont-plus + +config MSANDYBRIDGE + bool "Intel Sandy Bridge" + help + + Select this for 2nd Gen Core processors in the Sandy Bridge family. + + Enables -march=sandybridge + +config MIVYBRIDGE + bool "Intel Ivy Bridge" + help + + Select this for 3rd Gen Core processors in the Ivy Bridge family. + + Enables -march=ivybridge + +config MHASWELL + bool "Intel Haswell" + help + + Select this for 4th Gen Core processors in the Haswell family. + + Enables -march=haswell + +config MBROADWELL + bool "Intel Broadwell" + help + + Select this for 5th Gen Core processors in the Broadwell family. + + Enables -march=broadwell + +config MSKYLAKE + bool "Intel Skylake" + help + + Select this for 6th Gen Core processors in the Skylake family. + + Enables -march=skylake + +config MSKYLAKEX + bool "Intel Skylake X" + help + + Select this for 6th Gen Core processors in the Skylake X family. + + Enables -march=skylake-avx512 + +config MCANNONLAKE + bool "Intel Cannon Lake" + help + + Select this for 8th Gen Core processors + + Enables -march=cannonlake + +config MICELAKE + bool "Intel Ice Lake" + help + + Select this for 10th Gen Core processors in the Ice Lake family. + + Enables -march=icelake-client + +config MCASCADELAKE + bool "Intel Cascade Lake" + help + + Select this for Xeon processors in the Cascade Lake family. + + Enables -march=cascadelake + +config MCOOPERLAKE + bool "Intel Cooper Lake" + depends on (CC_IS_GCC && GCC_VERSION > 100100) || (CC_IS_CLANG && CLANG_VERSION >= 100000) + help + + Select this for Xeon processors in the Cooper Lake family. + + Enables -march=cooperlake + +config MTIGERLAKE + bool "Intel Tiger Lake" + depends on (CC_IS_GCC && GCC_VERSION > 100100) || (CC_IS_CLANG && CLANG_VERSION >= 100000) + help + + Select this for third-generation 10 nm process processors in the Tiger Lake family. + + Enables -march=tigerlake + +config MSAPPHIRERAPIDS + bool "Intel Sapphire Rapids" + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + help + + Select this for fourth-generation 10 nm process processors in the Sapphire Rapids family. + + Enables -march=sapphirerapids + +config MROCKETLAKE + bool "Intel Rocket Lake" + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + help + + Select this for eleventh-generation processors in the Rocket Lake family. + + Enables -march=rocketlake + +config MALDERLAKE + bool "Intel Alder Lake" + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + help + + Select this for twelfth-generation processors in the Alder Lake family. + + Enables -march=alderlake + +config MRAPTORLAKE + bool "Intel Raptor Lake" + depends on (CC_IS_GCC && GCC_VERSION >= 130000) || (CC_IS_CLANG && CLANG_VERSION >= 150500) + help + + Select this for thirteenth-generation processors in the Raptor Lake family. + + Enables -march=raptorlake + +config MMETEORLAKE + bool "Intel Meteor Lake" + depends on (CC_IS_GCC && GCC_VERSION >= 130000) || (CC_IS_CLANG && CLANG_VERSION >= 150500) + help + + Select this for fourteenth-generation processors in the Meteor Lake family. + + Enables -march=meteorlake + +config MEMERALDRAPIDS + bool "Intel Emerald Rapids" + depends on (CC_IS_GCC && GCC_VERSION > 130000) || (CC_IS_CLANG && CLANG_VERSION >= 150500) + help + + Select this for fifth-generation 10 nm process processors in the Emerald Rapids family. + + Enables -march=emeraldrapids config GENERIC_CPU bool "Generic-x86-64" @@ -294,6 +581,26 @@ config GENERIC_CPU Generic x86-64 CPU. Run equally well on all x86-64 CPUs. +config MNATIVE_INTEL + bool "Intel-Native optimizations autodetected by the compiler" + help + + Clang 3.8, GCC 4.2 and above support -march=native, which automatically detects + the optimum settings to use based on your processor. Do NOT use this + for AMD CPUs. Intel Only! + + Enables -march=native + +config MNATIVE_AMD + bool "AMD-Native optimizations autodetected by the compiler" + help + + Clang 3.8, GCC 4.2 and above support -march=native, which automatically detects + the optimum settings to use based on your processor. Do NOT use this + for Intel CPUs. AMD Only! + + Enables -march=native + endchoice config X86_GENERIC @@ -308,6 +615,30 @@ config X86_GENERIC This is really intended for distributors who need more generic optimizations. +config X86_64_VERSION + int "x86-64 compiler ISA level" + range 1 4 + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + depends on X86_64 && GENERIC_CPU + help + Specify a specific x86-64 compiler ISA level. + + There are three x86-64 ISA levels that work on top of + the x86-64 baseline, namely: x86-64-v2, x86-64-v3, and x86-64-v4. + + x86-64-v2 brings support for vector instructions up to Streaming SIMD + Extensions 4.2 (SSE4.2) and Supplemental Streaming SIMD Extensions 3 + (SSSE3), the POPCNT instruction, and CMPXCHG16B. + + x86-64-v3 adds vector instructions up to AVX2, MOVBE, and additional + bit-manipulation instructions. + + x86-64-v4 is not included since the kernel does not use AVX512 instructions + + You can find the best version for your CPU by running one of the following: + /lib/ld-linux-x86-64.so.2 --help | grep supported + /lib64/ld-linux-x86-64.so.2 --help | grep supported + # # Define implied options from the CPU selection here config X86_INTERNODE_CACHE_SHIFT @@ -318,7 +649,7 @@ config X86_INTERNODE_CACHE_SHIFT config X86_L1_CACHE_SHIFT int default "7" if MPENTIUM4 || MPSC - default "6" if MK7 || MK8 || MPENTIUMM || MCORE2 || MATOM || MVIAC7 || X86_GENERIC || GENERIC_CPU + default "6" if MK7 || MK8 || MPENTIUMM || MCORE2 || MATOM || MVIAC7 || X86_GENERIC || GENERIC_CPU || MK8SSE3 || MK10 || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MZEN4 || MZEN5 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MRAPTORLAKE || MMETEORLAKE || MEMERALDRAPIDS || MNATIVE_INTEL || MNATIVE_AMD default "4" if MELAN || M486SX || M486 || MGEODEGX1 default "5" if MWINCHIP3D || MWINCHIPC6 || MCRUSOE || MEFFICEON || MCYRIXIII || MK6 || MPENTIUMIII || MPENTIUMII || M686 || M586MMX || M586TSC || M586 || MVIAC3_2 || MGEODE_LX @@ -336,11 +667,11 @@ config X86_ALIGNMENT_16 config X86_INTEL_USERCOPY def_bool y - depends on MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M586MMX || X86_GENERIC || MK8 || MK7 || MEFFICEON || MCORE2 + depends on MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M586MMX || X86_GENERIC || MK8 || MK7 || MEFFICEON || MCORE2 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MRAPTORLAKE || MMETEORLAKE || MEMERALDRAPIDS || MNATIVE_INTEL config X86_USE_PPRO_CHECKSUM def_bool y - depends on MWINCHIP3D || MWINCHIPC6 || MCYRIXIII || MK7 || MK6 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || MK8 || MVIAC3_2 || MVIAC7 || MEFFICEON || MGEODE_LX || MCORE2 || MATOM + depends on MWINCHIP3D || MWINCHIPC6 || MCYRIXIII || MK7 || MK6 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || MK8 || MVIAC3_2 || MVIAC7 || MEFFICEON || MGEODE_LX || MCORE2 || MATOM || MK8SSE3 || MK10 || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MZEN4 || MZEN5 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MRAPTORLAKE || MMETEORLAKE || MEMERALDRAPIDS || MNATIVE_INTEL || MNATIVE_AMD # # P6_NOPs are a relatively minor optimization that require a family >= diff --git a/arch/x86/Makefile b/arch/x86/Makefile index cd75e78a06c1..396d1db12bca 100644 --- a/arch/x86/Makefile +++ b/arch/x86/Makefile @@ -181,15 +181,96 @@ else cflags-$(CONFIG_MK8) += -march=k8 cflags-$(CONFIG_MPSC) += -march=nocona cflags-$(CONFIG_MCORE2) += -march=core2 - cflags-$(CONFIG_MATOM) += -march=atom - cflags-$(CONFIG_GENERIC_CPU) += -mtune=generic + cflags-$(CONFIG_MATOM) += -march=bonnell + ifeq ($(CONFIG_X86_64_VERSION),1) + cflags-$(CONFIG_GENERIC_CPU) += -mtune=generic + rustflags-$(CONFIG_GENERIC_CPU) += -Ztune-cpu=generic + else + cflags-$(CONFIG_GENERIC_CPU) += -march=x86-64-v$(CONFIG_X86_64_VERSION) + rustflags-$(CONFIG_GENERIC_CPU) += -Ctarget-cpu=x86-64-v$(CONFIG_X86_64_VERSION) + endif + cflags-$(CONFIG_MK8SSE3) += -march=k8-sse3 + cflags-$(CONFIG_MK10) += -march=amdfam10 + cflags-$(CONFIG_MBARCELONA) += -march=barcelona + cflags-$(CONFIG_MBOBCAT) += -march=btver1 + cflags-$(CONFIG_MJAGUAR) += -march=btver2 + cflags-$(CONFIG_MBULLDOZER) += -march=bdver1 + cflags-$(CONFIG_MPILEDRIVER) += -march=bdver2 -mno-tbm + cflags-$(CONFIG_MSTEAMROLLER) += -march=bdver3 -mno-tbm + cflags-$(CONFIG_MEXCAVATOR) += -march=bdver4 -mno-tbm + cflags-$(CONFIG_MZEN) += -march=znver1 + cflags-$(CONFIG_MZEN2) += -march=znver2 + cflags-$(CONFIG_MZEN3) += -march=znver3 + cflags-$(CONFIG_MZEN4) += -march=znver4 + cflags-$(CONFIG_MZEN5) += -march=znver5 + cflags-$(CONFIG_MNATIVE_INTEL) += -march=native + cflags-$(CONFIG_MNATIVE_AMD) += -march=native -mno-tbm + cflags-$(CONFIG_MNEHALEM) += -march=nehalem + cflags-$(CONFIG_MWESTMERE) += -march=westmere + cflags-$(CONFIG_MSILVERMONT) += -march=silvermont + cflags-$(CONFIG_MGOLDMONT) += -march=goldmont + cflags-$(CONFIG_MGOLDMONTPLUS) += -march=goldmont-plus + cflags-$(CONFIG_MSANDYBRIDGE) += -march=sandybridge + cflags-$(CONFIG_MIVYBRIDGE) += -march=ivybridge + cflags-$(CONFIG_MHASWELL) += -march=haswell + cflags-$(CONFIG_MBROADWELL) += -march=broadwell + cflags-$(CONFIG_MSKYLAKE) += -march=skylake + cflags-$(CONFIG_MSKYLAKEX) += -march=skylake-avx512 + cflags-$(CONFIG_MCANNONLAKE) += -march=cannonlake + cflags-$(CONFIG_MICELAKE) += -march=icelake-client + cflags-$(CONFIG_MCASCADELAKE) += -march=cascadelake + cflags-$(CONFIG_MCOOPERLAKE) += -march=cooperlake + cflags-$(CONFIG_MTIGERLAKE) += -march=tigerlake + cflags-$(CONFIG_MSAPPHIRERAPIDS) += -march=sapphirerapids + cflags-$(CONFIG_MROCKETLAKE) += -march=rocketlake + cflags-$(CONFIG_MALDERLAKE) += -march=alderlake + cflags-$(CONFIG_MRAPTORLAKE) += -march=raptorlake + cflags-$(CONFIG_MMETEORLAKE) += -march=meteorlake + cflags-$(CONFIG_MEMERALDRAPIDS) += -march=emeraldrapids KBUILD_CFLAGS += $(cflags-y) rustflags-$(CONFIG_MK8) += -Ctarget-cpu=k8 rustflags-$(CONFIG_MPSC) += -Ctarget-cpu=nocona rustflags-$(CONFIG_MCORE2) += -Ctarget-cpu=core2 rustflags-$(CONFIG_MATOM) += -Ctarget-cpu=atom - rustflags-$(CONFIG_GENERIC_CPU) += -Ztune-cpu=generic + rustflags-$(CONFIG_MK8SSE3) += -Ctarget-cpu=k8-sse3 + rustflags-$(CONFIG_MK10) += -Ctarget-cpu=amdfam10 + rustflags-$(CONFIG_MBARCELONA) += -Ctarget-cpu=barcelona + rustflags-$(CONFIG_MBOBCAT) += -Ctarget-cpu=btver1 + rustflags-$(CONFIG_MJAGUAR) += -Ctarget-cpu=btver2 + rustflags-$(CONFIG_MBULLDOZER) += -Ctarget-cpu=bdver1 + rustflags-$(CONFIG_MPILEDRIVER) += -Ctarget-cpu=bdver2 + rustflags-$(CONFIG_MSTEAMROLLER) += -Ctarget-cpu=bdver3 + rustflags-$(CONFIG_MEXCAVATOR) += -Ctarget-cpu=bdver4 + rustflags-$(CONFIG_MZEN) += -Ctarget-cpu=znver1 + rustflags-$(CONFIG_MZEN2) += -Ctarget-cpu=znver2 + rustflags-$(CONFIG_MZEN3) += -Ctarget-cpu=znver3 + rustflags-$(CONFIG_MZEN4) += -Ctarget-cpu=znver4 + rustflags-$(CONFIG_MZEN5) += -Ctarget-cpu=znver5 + rustflags-$(CONFIG_MNATIVE_INTEL) += -Ctarget-cpu=native + rustflags-$(CONFIG_MNATIVE_AMD) += -Ctarget-cpu=native + rustflags-$(CONFIG_MNEHALEM) += -Ctarget-cpu=nehalem + rustflags-$(CONFIG_MWESTMERE) += -Ctarget-cpu=westmere + rustflags-$(CONFIG_MSILVERMONT) += -Ctarget-cpu=silvermont + rustflags-$(CONFIG_MGOLDMONT) += -Ctarget-cpu=goldmont + rustflags-$(CONFIG_MGOLDMONTPLUS) += -Ctarget-cpu=goldmont-plus + rustflags-$(CONFIG_MSANDYBRIDGE) += -Ctarget-cpu=sandybridge + rustflags-$(CONFIG_MIVYBRIDGE) += -Ctarget-cpu=ivybridge + rustflags-$(CONFIG_MHASWELL) += -Ctarget-cpu=haswell + rustflags-$(CONFIG_MBROADWELL) += -Ctarget-cpu=broadwell + rustflags-$(CONFIG_MSKYLAKE) += -Ctarget-cpu=skylake + rustflags-$(CONFIG_MSKYLAKEX) += -Ctarget-cpu=skylake-avx512 + rustflags-$(CONFIG_MCANNONLAKE) += -Ctarget-cpu=cannonlake + rustflags-$(CONFIG_MICELAKE) += -Ctarget-cpu=icelake-client + rustflags-$(CONFIG_MCASCADELAKE) += -Ctarget-cpu=cascadelake + rustflags-$(CONFIG_MCOOPERLAKE) += -Ctarget-cpu=cooperlake + rustflags-$(CONFIG_MTIGERLAKE) += -Ctarget-cpu=tigerlake + rustflags-$(CONFIG_MSAPPHIRERAPIDS) += -Ctarget-cpu=sapphirerapids + rustflags-$(CONFIG_MROCKETLAKE) += -Ctarget-cpu=rocketlake + rustflags-$(CONFIG_MALDERLAKE) += -Ctarget-cpu=alderlake + rustflags-$(CONFIG_MRAPTORLAKE) += -Ctarget-cpu=raptorlake + rustflags-$(CONFIG_MMETEORLAKE) += -Ctarget-cpu=meteorlake + rustflags-$(CONFIG_MEMERALDRAPIDS) += -Ctarget-cpu=emeraldrapids KBUILD_RUSTFLAGS += $(rustflags-y) KBUILD_CFLAGS += -mno-red-zone diff --git a/arch/x86/include/asm/pci.h b/arch/x86/include/asm/pci.h index b3ab80a03365..5e883b397ff3 100644 --- a/arch/x86/include/asm/pci.h +++ b/arch/x86/include/asm/pci.h @@ -26,6 +26,7 @@ struct pci_sysdata { #if IS_ENABLED(CONFIG_VMD) struct pci_dev *vmd_dev; /* VMD Device if in Intel VMD domain */ #endif + struct pci_dev *nvme_remap_dev; /* AHCI Device if NVME remapped bus */ }; extern int pci_routeirq; @@ -69,6 +70,11 @@ static inline bool is_vmd(struct pci_bus *bus) #define is_vmd(bus) false #endif /* CONFIG_VMD */ +static inline bool is_nvme_remap(struct pci_bus *bus) +{ + return to_pci_sysdata(bus)->nvme_remap_dev != NULL; +} + /* Can be used to override the logic in pci_scan_bus for skipping already-configured bus numbers - to be used for buggy BIOSes or architectures with incomplete PCI setup by the loader */ diff --git a/arch/x86/include/asm/vermagic.h b/arch/x86/include/asm/vermagic.h index 75884d2cdec3..f4e29563473d 100644 --- a/arch/x86/include/asm/vermagic.h +++ b/arch/x86/include/asm/vermagic.h @@ -17,6 +17,54 @@ #define MODULE_PROC_FAMILY "586MMX " #elif defined CONFIG_MCORE2 #define MODULE_PROC_FAMILY "CORE2 " +#elif defined CONFIG_MNATIVE_INTEL +#define MODULE_PROC_FAMILY "NATIVE_INTEL " +#elif defined CONFIG_MNATIVE_AMD +#define MODULE_PROC_FAMILY "NATIVE_AMD " +#elif defined CONFIG_MNEHALEM +#define MODULE_PROC_FAMILY "NEHALEM " +#elif defined CONFIG_MWESTMERE +#define MODULE_PROC_FAMILY "WESTMERE " +#elif defined CONFIG_MSILVERMONT +#define MODULE_PROC_FAMILY "SILVERMONT " +#elif defined CONFIG_MGOLDMONT +#define MODULE_PROC_FAMILY "GOLDMONT " +#elif defined CONFIG_MGOLDMONTPLUS +#define MODULE_PROC_FAMILY "GOLDMONTPLUS " +#elif defined CONFIG_MSANDYBRIDGE +#define MODULE_PROC_FAMILY "SANDYBRIDGE " +#elif defined CONFIG_MIVYBRIDGE +#define MODULE_PROC_FAMILY "IVYBRIDGE " +#elif defined CONFIG_MHASWELL +#define MODULE_PROC_FAMILY "HASWELL " +#elif defined CONFIG_MBROADWELL +#define MODULE_PROC_FAMILY "BROADWELL " +#elif defined CONFIG_MSKYLAKE +#define MODULE_PROC_FAMILY "SKYLAKE " +#elif defined CONFIG_MSKYLAKEX +#define MODULE_PROC_FAMILY "SKYLAKEX " +#elif defined CONFIG_MCANNONLAKE +#define MODULE_PROC_FAMILY "CANNONLAKE " +#elif defined CONFIG_MICELAKE +#define MODULE_PROC_FAMILY "ICELAKE " +#elif defined CONFIG_MCASCADELAKE +#define MODULE_PROC_FAMILY "CASCADELAKE " +#elif defined CONFIG_MCOOPERLAKE +#define MODULE_PROC_FAMILY "COOPERLAKE " +#elif defined CONFIG_MTIGERLAKE +#define MODULE_PROC_FAMILY "TIGERLAKE " +#elif defined CONFIG_MSAPPHIRERAPIDS +#define MODULE_PROC_FAMILY "SAPPHIRERAPIDS " +#elif defined CONFIG_ROCKETLAKE +#define MODULE_PROC_FAMILY "ROCKETLAKE " +#elif defined CONFIG_MALDERLAKE +#define MODULE_PROC_FAMILY "ALDERLAKE " +#elif defined CONFIG_MRAPTORLAKE +#define MODULE_PROC_FAMILY "RAPTORLAKE " +#elif defined CONFIG_MMETEORLAKE +#define MODULE_PROC_FAMILY "METEORLAKE " +#elif defined CONFIG_MEMERALDRAPIDS +#define MODULE_PROC_FAMILY "EMERALDRAPIDS " #elif defined CONFIG_MATOM #define MODULE_PROC_FAMILY "ATOM " #elif defined CONFIG_M686 @@ -35,6 +83,28 @@ #define MODULE_PROC_FAMILY "K7 " #elif defined CONFIG_MK8 #define MODULE_PROC_FAMILY "K8 " +#elif defined CONFIG_MK8SSE3 +#define MODULE_PROC_FAMILY "K8SSE3 " +#elif defined CONFIG_MK10 +#define MODULE_PROC_FAMILY "K10 " +#elif defined CONFIG_MBARCELONA +#define MODULE_PROC_FAMILY "BARCELONA " +#elif defined CONFIG_MBOBCAT +#define MODULE_PROC_FAMILY "BOBCAT " +#elif defined CONFIG_MBULLDOZER +#define MODULE_PROC_FAMILY "BULLDOZER " +#elif defined CONFIG_MPILEDRIVER +#define MODULE_PROC_FAMILY "PILEDRIVER " +#elif defined CONFIG_MSTEAMROLLER +#define MODULE_PROC_FAMILY "STEAMROLLER " +#elif defined CONFIG_MJAGUAR +#define MODULE_PROC_FAMILY "JAGUAR " +#elif defined CONFIG_MEXCAVATOR +#define MODULE_PROC_FAMILY "EXCAVATOR " +#elif defined CONFIG_MZEN +#define MODULE_PROC_FAMILY "ZEN " +#elif defined CONFIG_MZEN2 +#define MODULE_PROC_FAMILY "ZEN2 " #elif defined CONFIG_MELAN #define MODULE_PROC_FAMILY "ELAN " #elif defined CONFIG_MCRUSOE diff --git a/arch/x86/pci/common.c b/arch/x86/pci/common.c index ddb798603201..7c20387d8202 100644 --- a/arch/x86/pci/common.c +++ b/arch/x86/pci/common.c @@ -723,12 +723,15 @@ int pci_ext_cfg_avail(void) return 0; } -#if IS_ENABLED(CONFIG_VMD) struct pci_dev *pci_real_dma_dev(struct pci_dev *dev) { +#if IS_ENABLED(CONFIG_VMD) if (is_vmd(dev->bus)) return to_pci_sysdata(dev->bus)->vmd_dev; +#endif + + if (is_nvme_remap(dev->bus)) + return to_pci_sysdata(dev->bus)->nvme_remap_dev; return dev; } -#endif diff --git a/drivers/Makefile b/drivers/Makefile index 45d1c3e630f7..4f5ab2429a7f 100644 --- a/drivers/Makefile +++ b/drivers/Makefile @@ -64,14 +64,8 @@ obj-y += char/ # iommu/ comes before gpu as gpu are using iommu controllers obj-y += iommu/ -# gpu/ comes after char for AGP vs DRM startup and after iommu -obj-y += gpu/ - obj-$(CONFIG_CONNECTOR) += connector/ -# i810fb depends on char/agp/ -obj-$(CONFIG_FB_I810) += video/fbdev/i810/ - obj-$(CONFIG_PARPORT) += parport/ obj-y += base/ block/ misc/ mfd/ nfc/ obj-$(CONFIG_LIBNVDIMM) += nvdimm/ @@ -83,6 +77,13 @@ obj-y += macintosh/ obj-y += scsi/ obj-y += nvme/ obj-$(CONFIG_ATA) += ata/ + +# gpu/ comes after char for AGP vs DRM startup and after iommu +obj-y += gpu/ + +# i810fb depends on char/agp/ +obj-$(CONFIG_FB_I810) += video/fbdev/i810/ + obj-$(CONFIG_TARGET_CORE) += target/ obj-$(CONFIG_MTD) += mtd/ obj-$(CONFIG_SPI) += spi/ diff --git a/drivers/ata/ahci.c b/drivers/ata/ahci.c index 45f63b09828a..d8bcb8b7544f 100644 --- a/drivers/ata/ahci.c +++ b/drivers/ata/ahci.c @@ -1618,7 +1618,7 @@ static irqreturn_t ahci_thunderx_irq_handler(int irq, void *dev_instance) } #endif -static void ahci_remap_check(struct pci_dev *pdev, int bar, +static int ahci_remap_check(struct pci_dev *pdev, int bar, struct ahci_host_priv *hpriv) { int i; @@ -1631,7 +1631,7 @@ static void ahci_remap_check(struct pci_dev *pdev, int bar, pci_resource_len(pdev, bar) < SZ_512K || bar != AHCI_PCI_BAR_STANDARD || !(readl(hpriv->mmio + AHCI_VSCAP) & 1)) - return; + return 0; cap = readq(hpriv->mmio + AHCI_REMAP_CAP); for (i = 0; i < AHCI_MAX_REMAP; i++) { @@ -1646,18 +1646,11 @@ static void ahci_remap_check(struct pci_dev *pdev, int bar, } if (!hpriv->remapped_nvme) - return; - - dev_warn(&pdev->dev, "Found %u remapped NVMe devices.\n", - hpriv->remapped_nvme); - dev_warn(&pdev->dev, - "Switch your BIOS from RAID to AHCI mode to use them.\n"); + return 0; - /* - * Don't rely on the msi-x capability in the remap case, - * share the legacy interrupt across ahci and remapped devices. - */ - hpriv->flags |= AHCI_HFLAG_NO_MSI; + /* Abort probe, allowing intel-nvme-remap to step in when available */ + dev_info(&pdev->dev, "Device will be handled by intel-nvme-remap.\n"); + return -ENODEV; } static int ahci_get_irq_vector(struct ata_host *host, int port) @@ -1896,7 +1889,9 @@ static int ahci_init_one(struct pci_dev *pdev, const struct pci_device_id *ent) hpriv->mmio = pcim_iomap_table(pdev)[ahci_pci_bar]; /* detect remapped nvme devices */ - ahci_remap_check(pdev, ahci_pci_bar, hpriv); + rc = ahci_remap_check(pdev, ahci_pci_bar, hpriv); + if (rc) + return rc; sysfs_add_file_to_group(&pdev->dev.kobj, &dev_attr_remapped_nvme.attr, diff --git a/drivers/cpufreq/Kconfig.x86 b/drivers/cpufreq/Kconfig.x86 index 97c2d4f15d76..5a3af44d785a 100644 --- a/drivers/cpufreq/Kconfig.x86 +++ b/drivers/cpufreq/Kconfig.x86 @@ -9,7 +9,6 @@ config X86_INTEL_PSTATE select ACPI_PROCESSOR if ACPI select ACPI_CPPC_LIB if X86_64 && ACPI && SCHED_MC_PRIO select CPU_FREQ_GOV_PERFORMANCE - select CPU_FREQ_GOV_SCHEDUTIL if SMP help This driver provides a P state for Intel core processors. The driver implements an internal governor and will become @@ -39,7 +38,6 @@ config X86_AMD_PSTATE depends on X86 && ACPI select ACPI_PROCESSOR select ACPI_CPPC_LIB if X86_64 - select CPU_FREQ_GOV_SCHEDUTIL if SMP help This driver adds a CPUFreq driver which utilizes a fine grain processor performance frequency control range instead of legacy diff --git a/drivers/cpufreq/intel_pstate.c b/drivers/cpufreq/intel_pstate.c index cd2ac1ba53d2..ac3647df1431 100644 --- a/drivers/cpufreq/intel_pstate.c +++ b/drivers/cpufreq/intel_pstate.c @@ -3820,6 +3820,8 @@ static int __init intel_pstate_setup(char *str) if (!strcmp(str, "disable")) no_load = 1; + else if (!strcmp(str, "enable")) + no_load = 0; else if (!strcmp(str, "active")) default_driver = &intel_pstate; else if (!strcmp(str, "passive")) diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu.h b/drivers/gpu/drm/amd/amdgpu/amdgpu.h index 9b1e0ede05a4..7617963901fa 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu.h +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu.h @@ -164,6 +164,7 @@ struct amdgpu_watchdog_timer { */ extern int amdgpu_modeset; extern unsigned int amdgpu_vram_limit; +extern int amdgpu_ignore_min_pcap; extern int amdgpu_vis_vram_limit; extern int amdgpu_gart_size; extern int amdgpu_gtt_size; diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c index 81d9877c8735..852e6f315576 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c @@ -136,6 +136,7 @@ enum AMDGPU_DEBUG_MASK { }; unsigned int amdgpu_vram_limit = UINT_MAX; +int amdgpu_ignore_min_pcap = 0; /* do not ignore by default */ int amdgpu_vis_vram_limit; int amdgpu_gart_size = -1; /* auto */ int amdgpu_gtt_size = -1; /* auto */ @@ -259,6 +260,15 @@ struct amdgpu_watchdog_timer amdgpu_watchdog_timer = { .period = 0x0, /* default to 0x0 (timeout disable) */ }; +/** + * DOC: ignore_min_pcap (int) + * Ignore the minimum power cap. + * Useful on graphics cards where the minimum power cap is very high. + * The default is 0 (Do not ignore). + */ +MODULE_PARM_DESC(ignore_min_pcap, "Ignore the minimum power cap"); +module_param_named(ignore_min_pcap, amdgpu_ignore_min_pcap, int, 0600); + /** * DOC: vramlimit (int) * Restrict the total amount of VRAM in MiB for testing. The default is 0 (Use full VRAM). diff --git a/drivers/gpu/drm/amd/display/Kconfig b/drivers/gpu/drm/amd/display/Kconfig index df17e79c45c7..e454488c1a31 100644 --- a/drivers/gpu/drm/amd/display/Kconfig +++ b/drivers/gpu/drm/amd/display/Kconfig @@ -53,4 +53,10 @@ config DRM_AMD_SECURE_DISPLAY This option enables the calculation of crc of specific region via debugfs. Cooperate with specific DMCU FW. +config AMD_PRIVATE_COLOR + bool "Enable KMS color management by AMD for AMD" + default n + help + This option extends the KMS color management API with AMD driver-specific properties to enhance the color management support on AMD Steam Deck. + endmenu diff --git a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c index 07e9ce99694f..cf966e8f61fa 100644 --- a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c +++ b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm.c @@ -4473,7 +4473,7 @@ static int amdgpu_dm_mode_config_init(struct amdgpu_device *adev) return r; } -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR if (amdgpu_dm_create_color_properties(adev)) { dc_state_release(state->context); kfree(state); diff --git a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_color.c b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_color.c index ebabfe3a512f..4d3ebcaacca1 100644 --- a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_color.c +++ b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_color.c @@ -97,7 +97,7 @@ static inline struct fixed31_32 amdgpu_dm_fixpt_from_s3132(__u64 x) return val; } -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR /* Pre-defined Transfer Functions (TF) * * AMD driver supports pre-defined mathematical functions for transferring diff --git a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_crtc.c b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_crtc.c index a2cf2c066a76..285f5a045ca5 100644 --- a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_crtc.c +++ b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_crtc.c @@ -474,7 +474,7 @@ static int amdgpu_dm_crtc_late_register(struct drm_crtc *crtc) } #endif -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR /** * dm_crtc_additional_color_mgmt - enable additional color properties * @crtc: DRM CRTC @@ -556,7 +556,7 @@ static const struct drm_crtc_funcs amdgpu_dm_crtc_funcs = { #if defined(CONFIG_DEBUG_FS) .late_register = amdgpu_dm_crtc_late_register, #endif -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR .atomic_set_property = amdgpu_dm_atomic_crtc_set_property, .atomic_get_property = amdgpu_dm_atomic_crtc_get_property, #endif @@ -735,7 +735,7 @@ int amdgpu_dm_crtc_init(struct amdgpu_display_manager *dm, drm_mode_crtc_set_gamma_size(&acrtc->base, MAX_COLOR_LEGACY_LUT_ENTRIES); -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR dm_crtc_additional_color_mgmt(&acrtc->base); #endif return 0; diff --git a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_plane.c b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_plane.c index 495e3cd70426..704a48209657 100644 --- a/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_plane.c +++ b/drivers/gpu/drm/amd/display/amdgpu_dm/amdgpu_dm_plane.c @@ -1573,7 +1573,7 @@ static void amdgpu_dm_plane_drm_plane_destroy_state(struct drm_plane *plane, drm_atomic_helper_plane_destroy_state(plane, state); } -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR static void dm_atomic_plane_attach_color_mgmt_properties(struct amdgpu_display_manager *dm, struct drm_plane *plane) @@ -1764,7 +1764,7 @@ static const struct drm_plane_funcs dm_plane_funcs = { .atomic_duplicate_state = amdgpu_dm_plane_drm_plane_duplicate_state, .atomic_destroy_state = amdgpu_dm_plane_drm_plane_destroy_state, .format_mod_supported = amdgpu_dm_plane_format_mod_supported, -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR .atomic_set_property = dm_atomic_plane_set_property, .atomic_get_property = dm_atomic_plane_get_property, #endif @@ -1857,7 +1857,7 @@ int amdgpu_dm_plane_init(struct amdgpu_display_manager *dm, drm_plane_helper_add(plane, &dm_plane_helper_funcs); -#ifdef AMD_PRIVATE_COLOR +#ifdef CONFIG_AMD_PRIVATE_COLOR dm_atomic_plane_attach_color_mgmt_properties(dm, plane); #endif /* Create (reset) the plane state */ diff --git a/drivers/gpu/drm/amd/pm/amdgpu_pm.c b/drivers/gpu/drm/amd/pm/amdgpu_pm.c index d5d6ab484e5a..dccba7bcdf97 100644 --- a/drivers/gpu/drm/amd/pm/amdgpu_pm.c +++ b/drivers/gpu/drm/amd/pm/amdgpu_pm.c @@ -3272,6 +3272,9 @@ static ssize_t amdgpu_hwmon_show_power_cap_min(struct device *dev, struct device_attribute *attr, char *buf) { + if (amdgpu_ignore_min_pcap) + return sysfs_emit(buf, "%i\n", 0); + return amdgpu_hwmon_show_power_cap_generic(dev, attr, buf, PP_PWR_LIMIT_MIN); } diff --git a/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c b/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c index ee1bcfaae3e3..3388604f222b 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c +++ b/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c @@ -2785,7 +2785,10 @@ int smu_get_power_limit(void *handle, *limit = smu->max_power_limit; break; case SMU_PPT_LIMIT_MIN: - *limit = smu->min_power_limit; + if (amdgpu_ignore_min_pcap) + *limit = 0; + else + *limit = smu->min_power_limit; break; default: return -EINVAL; @@ -2809,7 +2812,14 @@ static int smu_set_power_limit(void *handle, uint32_t limit) if (smu->ppt_funcs->set_power_limit) return smu->ppt_funcs->set_power_limit(smu, limit_type, limit); - if ((limit > smu->max_power_limit) || (limit < smu->min_power_limit)) { + if (amdgpu_ignore_min_pcap) { + if ((limit > smu->max_power_limit)) { + dev_err(smu->adev->dev, + "New power limit (%d) is over the max allowed %d\n", + limit, smu->max_power_limit); + return -EINVAL; + } + } else if ((limit > smu->max_power_limit) || (limit < smu->min_power_limit)) { dev_err(smu->adev->dev, "New power limit (%d) is out of range [%d,%d]\n", limit, smu->min_power_limit, smu->max_power_limit); diff --git a/drivers/input/evdev.c b/drivers/input/evdev.c index b5cbb57ee5f6..a0f7fa1518c6 100644 --- a/drivers/input/evdev.c +++ b/drivers/input/evdev.c @@ -46,6 +46,7 @@ struct evdev_client { struct fasync_struct *fasync; struct evdev *evdev; struct list_head node; + struct rcu_head rcu; enum input_clock_type clk_type; bool revoked; unsigned long *evmasks[EV_CNT]; @@ -368,13 +369,22 @@ static void evdev_attach_client(struct evdev *evdev, spin_unlock(&evdev->client_lock); } +static void evdev_reclaim_client(struct rcu_head *rp) +{ + struct evdev_client *client = container_of(rp, struct evdev_client, rcu); + unsigned int i; + for (i = 0; i < EV_CNT; ++i) + bitmap_free(client->evmasks[i]); + kvfree(client); +} + static void evdev_detach_client(struct evdev *evdev, struct evdev_client *client) { spin_lock(&evdev->client_lock); list_del_rcu(&client->node); spin_unlock(&evdev->client_lock); - synchronize_rcu(); + call_rcu(&client->rcu, evdev_reclaim_client); } static int evdev_open_device(struct evdev *evdev) @@ -427,7 +437,6 @@ static int evdev_release(struct inode *inode, struct file *file) { struct evdev_client *client = file->private_data; struct evdev *evdev = client->evdev; - unsigned int i; mutex_lock(&evdev->mutex); @@ -439,11 +448,6 @@ static int evdev_release(struct inode *inode, struct file *file) evdev_detach_client(evdev, client); - for (i = 0; i < EV_CNT; ++i) - bitmap_free(client->evmasks[i]); - - kvfree(client); - evdev_close_device(evdev); return 0; @@ -486,7 +490,6 @@ static int evdev_open(struct inode *inode, struct file *file) err_free_client: evdev_detach_client(evdev, client); - kvfree(client); return error; } diff --git a/drivers/md/dm-crypt.c b/drivers/md/dm-crypt.c index 1ae2c71bb383..784829ada178 100644 --- a/drivers/md/dm-crypt.c +++ b/drivers/md/dm-crypt.c @@ -3315,6 +3315,11 @@ static int crypt_ctr(struct dm_target *ti, unsigned int argc, char **argv) goto bad; } +#ifdef CONFIG_CACHY + set_bit(DM_CRYPT_NO_READ_WORKQUEUE, &cc->flags); + set_bit(DM_CRYPT_NO_WRITE_WORKQUEUE, &cc->flags); +#endif + ret = crypt_ctr_cipher(ti, argv[0], argv[1]); if (ret < 0) goto bad; diff --git a/drivers/media/v4l2-core/Kconfig b/drivers/media/v4l2-core/Kconfig index 331b8e535e5b..80dabeebf580 100644 --- a/drivers/media/v4l2-core/Kconfig +++ b/drivers/media/v4l2-core/Kconfig @@ -40,6 +40,11 @@ config VIDEO_TUNER config V4L2_JPEG_HELPER tristate +config V4L2_LOOPBACK + tristate "V4L2 loopback device" + help + V4L2 loopback device + # Used by drivers that need v4l2-h264.ko config V4L2_H264 tristate diff --git a/drivers/media/v4l2-core/Makefile b/drivers/media/v4l2-core/Makefile index 2177b9d63a8f..c179507cedc4 100644 --- a/drivers/media/v4l2-core/Makefile +++ b/drivers/media/v4l2-core/Makefile @@ -33,5 +33,7 @@ obj-$(CONFIG_V4L2_JPEG_HELPER) += v4l2-jpeg.o obj-$(CONFIG_V4L2_MEM2MEM_DEV) += v4l2-mem2mem.o obj-$(CONFIG_V4L2_VP9) += v4l2-vp9.o +obj-$(CONFIG_V4L2_LOOPBACK) += v4l2loopback.o + obj-$(CONFIG_VIDEO_TUNER) += tuner.o obj-$(CONFIG_VIDEO_DEV) += v4l2-dv-timings.o videodev.o diff --git a/drivers/media/v4l2-core/v4l2loopback.c b/drivers/media/v4l2-core/v4l2loopback.c new file mode 100644 index 000000000000..25cb1beb26e5 --- /dev/null +++ b/drivers/media/v4l2-core/v4l2loopback.c @@ -0,0 +1,3184 @@ +/* -*- c-file-style: "linux" -*- */ +/* + * v4l2loopback.c -- video4linux2 loopback driver + * + * Copyright (C) 2005-2009 Vasily Levin (vasaka@gmail.com) + * Copyright (C) 2010-2023 IOhannes m zmoelnig (zmoelnig@iem.at) + * Copyright (C) 2011 Stefan Diewald (stefan.diewald@mytum.de) + * Copyright (C) 2012 Anton Novikov (random.plant@gmail.com) + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "v4l2loopback.h" + +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 0, 0) +#error This module is not supported on kernels before 4.0.0. +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 3, 0) +#define strscpy strlcpy +#endif + +#if defined(timer_setup) && defined(from_timer) +#define HAVE_TIMER_SETUP +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 7, 0) +#define VFL_TYPE_VIDEO VFL_TYPE_GRABBER +#endif + +#define V4L2LOOPBACK_VERSION_CODE \ + KERNEL_VERSION(V4L2LOOPBACK_VERSION_MAJOR, V4L2LOOPBACK_VERSION_MINOR, \ + V4L2LOOPBACK_VERSION_BUGFIX) + +MODULE_DESCRIPTION("V4L2 loopback video device"); +MODULE_AUTHOR("Vasily Levin, " + "IOhannes m zmoelnig ," + "Stefan Diewald," + "Anton Novikov" + "et al."); +#ifdef SNAPSHOT_VERSION +MODULE_VERSION(__stringify(SNAPSHOT_VERSION)); +#else +MODULE_VERSION("" __stringify(V4L2LOOPBACK_VERSION_MAJOR) "." __stringify( + V4L2LOOPBACK_VERSION_MINOR) "." __stringify(V4L2LOOPBACK_VERSION_BUGFIX)); +#endif +MODULE_LICENSE("GPL"); + +/* + * helpers + */ +#define dprintk(fmt, args...) \ + do { \ + if (debug > 0) { \ + printk(KERN_INFO "v4l2-loopback[" __stringify( \ + __LINE__) "], pid(%d): " fmt, \ + task_pid_nr(current), ##args); \ + } \ + } while (0) + +#define MARK() \ + do { \ + if (debug > 1) { \ + printk(KERN_INFO "%s:%d[%s], pid(%d)\n", __FILE__, \ + __LINE__, __func__, task_pid_nr(current)); \ + } \ + } while (0) + +#define dprintkrw(fmt, args...) \ + do { \ + if (debug > 2) { \ + printk(KERN_INFO "v4l2-loopback[" __stringify( \ + __LINE__) "], pid(%d): " fmt, \ + task_pid_nr(current), ##args); \ + } \ + } while (0) + +static inline void v4l2l_get_timestamp(struct v4l2_buffer *b) +{ + struct timespec64 ts; + ktime_get_ts64(&ts); + + b->timestamp.tv_sec = ts.tv_sec; + b->timestamp.tv_usec = (ts.tv_nsec / NSEC_PER_USEC); + b->flags |= V4L2_BUF_FLAG_TIMESTAMP_MONOTONIC; +} + +#if BITS_PER_LONG == 32 +#include /* do_div() for 64bit division */ +static inline int v4l2l_mod64(const s64 A, const u32 B) +{ + u64 a = (u64)A; + u32 b = B; + + if (A > 0) + return do_div(a, b); + a = -A; + return -do_div(a, b); +} +#else +static inline int v4l2l_mod64(const s64 A, const u32 B) +{ + return A % B; +} +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 16, 0) +typedef unsigned __poll_t; +#endif + +/* module constants + * can be overridden during he build process using something like + * make KCPPFLAGS="-DMAX_DEVICES=100" + */ + +/* maximum number of v4l2loopback devices that can be created */ +#ifndef MAX_DEVICES +#define MAX_DEVICES 8 +#endif + +/* whether the default is to announce capabilities exclusively or not */ +#ifndef V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS +#define V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS 0 +#endif + +/* when a producer is considered to have gone stale */ +#ifndef MAX_TIMEOUT +#define MAX_TIMEOUT (100 * 1000) /* in msecs */ +#endif + +/* max buffers that can be mapped, actually they + * are all mapped to max_buffers buffers */ +#ifndef MAX_BUFFERS +#define MAX_BUFFERS 32 +#endif + +/* module parameters */ +static int debug = 0; +module_param(debug, int, S_IRUGO | S_IWUSR); +MODULE_PARM_DESC(debug, "debugging level (higher values == more verbose)"); + +#define V4L2LOOPBACK_DEFAULT_MAX_BUFFERS 2 +static int max_buffers = V4L2LOOPBACK_DEFAULT_MAX_BUFFERS; +module_param(max_buffers, int, S_IRUGO); +MODULE_PARM_DESC(max_buffers, + "how many buffers should be allocated [DEFAULT: " __stringify( + V4L2LOOPBACK_DEFAULT_MAX_BUFFERS) "]"); + +/* how many times a device can be opened + * the per-module default value can be overridden on a per-device basis using + * the /sys/devices interface + * + * note that max_openers should be at least 2 in order to get a working system: + * one opener for the producer and one opener for the consumer + * however, we leave that to the user + */ +#define V4L2LOOPBACK_DEFAULT_MAX_OPENERS 10 +static int max_openers = V4L2LOOPBACK_DEFAULT_MAX_OPENERS; +module_param(max_openers, int, S_IRUGO | S_IWUSR); +MODULE_PARM_DESC( + max_openers, + "how many users can open the loopback device [DEFAULT: " __stringify( + V4L2LOOPBACK_DEFAULT_MAX_OPENERS) "]"); + +static int devices = -1; +module_param(devices, int, 0); +MODULE_PARM_DESC(devices, "how many devices should be created"); + +static int video_nr[MAX_DEVICES] = { [0 ...(MAX_DEVICES - 1)] = -1 }; +module_param_array(video_nr, int, NULL, 0444); +MODULE_PARM_DESC(video_nr, + "video device numbers (-1=auto, 0=/dev/video0, etc.)"); + +static char *card_label[MAX_DEVICES]; +module_param_array(card_label, charp, NULL, 0000); +MODULE_PARM_DESC(card_label, "card labels for each device"); + +static bool exclusive_caps[MAX_DEVICES] = { + [0 ...(MAX_DEVICES - 1)] = V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS +}; +module_param_array(exclusive_caps, bool, NULL, 0444); +/* FIXXME: wording */ +MODULE_PARM_DESC( + exclusive_caps, + "whether to announce OUTPUT/CAPTURE capabilities exclusively or not [DEFAULT: " __stringify( + V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS) "]"); + +/* format specifications */ +#define V4L2LOOPBACK_SIZE_MIN_WIDTH 2 +#define V4L2LOOPBACK_SIZE_MIN_HEIGHT 1 +#define V4L2LOOPBACK_SIZE_DEFAULT_MAX_WIDTH 8192 +#define V4L2LOOPBACK_SIZE_DEFAULT_MAX_HEIGHT 8192 + +#define V4L2LOOPBACK_SIZE_DEFAULT_WIDTH 640 +#define V4L2LOOPBACK_SIZE_DEFAULT_HEIGHT 480 + +static int max_width = V4L2LOOPBACK_SIZE_DEFAULT_MAX_WIDTH; +module_param(max_width, int, S_IRUGO); +MODULE_PARM_DESC(max_width, + "maximum allowed frame width [DEFAULT: " __stringify( + V4L2LOOPBACK_SIZE_DEFAULT_MAX_WIDTH) "]"); +static int max_height = V4L2LOOPBACK_SIZE_DEFAULT_MAX_HEIGHT; +module_param(max_height, int, S_IRUGO); +MODULE_PARM_DESC(max_height, + "maximum allowed frame height [DEFAULT: " __stringify( + V4L2LOOPBACK_SIZE_DEFAULT_MAX_HEIGHT) "]"); + +static DEFINE_IDR(v4l2loopback_index_idr); +static DEFINE_MUTEX(v4l2loopback_ctl_mutex); + +/* frame intervals */ +#define V4L2LOOPBACK_FPS_MIN 0 +#define V4L2LOOPBACK_FPS_MAX 1000 + +/* control IDs */ +#define V4L2LOOPBACK_CID_BASE (V4L2_CID_USER_BASE | 0xf000) +#define CID_KEEP_FORMAT (V4L2LOOPBACK_CID_BASE + 0) +#define CID_SUSTAIN_FRAMERATE (V4L2LOOPBACK_CID_BASE + 1) +#define CID_TIMEOUT (V4L2LOOPBACK_CID_BASE + 2) +#define CID_TIMEOUT_IMAGE_IO (V4L2LOOPBACK_CID_BASE + 3) + +static int v4l2loopback_s_ctrl(struct v4l2_ctrl *ctrl); +static const struct v4l2_ctrl_ops v4l2loopback_ctrl_ops = { + .s_ctrl = v4l2loopback_s_ctrl, +}; +static const struct v4l2_ctrl_config v4l2loopback_ctrl_keepformat = { + // clang-format off + .ops = &v4l2loopback_ctrl_ops, + .id = CID_KEEP_FORMAT, + .name = "keep_format", + .type = V4L2_CTRL_TYPE_BOOLEAN, + .min = 0, + .max = 1, + .step = 1, + .def = 0, + // clang-format on +}; +static const struct v4l2_ctrl_config v4l2loopback_ctrl_sustainframerate = { + // clang-format off + .ops = &v4l2loopback_ctrl_ops, + .id = CID_SUSTAIN_FRAMERATE, + .name = "sustain_framerate", + .type = V4L2_CTRL_TYPE_BOOLEAN, + .min = 0, + .max = 1, + .step = 1, + .def = 0, + // clang-format on +}; +static const struct v4l2_ctrl_config v4l2loopback_ctrl_timeout = { + // clang-format off + .ops = &v4l2loopback_ctrl_ops, + .id = CID_TIMEOUT, + .name = "timeout", + .type = V4L2_CTRL_TYPE_INTEGER, + .min = 0, + .max = MAX_TIMEOUT, + .step = 1, + .def = 0, + // clang-format on +}; +static const struct v4l2_ctrl_config v4l2loopback_ctrl_timeoutimageio = { + // clang-format off + .ops = &v4l2loopback_ctrl_ops, + .id = CID_TIMEOUT_IMAGE_IO, + .name = "timeout_image_io", + .type = V4L2_CTRL_TYPE_BUTTON, + .min = 0, + .max = 1, + .step = 1, + .def = 0, + // clang-format on +}; + +/* module structures */ +struct v4l2loopback_private { + int device_nr; +}; + +/* TODO(vasaka) use typenames which are common to kernel, but first find out if + * it is needed */ +/* struct keeping state and settings of loopback device */ + +struct v4l2l_buffer { + struct v4l2_buffer buffer; + struct list_head list_head; + int use_count; +}; + +struct v4l2_loopback_device { + struct v4l2_device v4l2_dev; + struct v4l2_ctrl_handler ctrl_handler; + struct video_device *vdev; + /* pixel and stream format */ + struct v4l2_pix_format pix_format; + bool pix_format_has_valid_sizeimage; + struct v4l2_captureparm capture_param; + unsigned long frame_jiffies; + + /* ctrls */ + int keep_format; /* CID_KEEP_FORMAT; stay ready_for_capture even when all + openers close() the device */ + int sustain_framerate; /* CID_SUSTAIN_FRAMERATE; duplicate frames to maintain + (close to) nominal framerate */ + + /* buffers stuff */ + u8 *image; /* pointer to actual buffers data */ + unsigned long int imagesize; /* size of buffers data */ + int buffers_number; /* should not be big, 4 is a good choice */ + struct v4l2l_buffer buffers[MAX_BUFFERS]; /* inner driver buffers */ + int used_buffers; /* number of the actually used buffers */ + int max_openers; /* how many times can this device be opened */ + + s64 write_position; /* number of last written frame + 1 */ + struct list_head outbufs_list; /* buffers in output DQBUF order */ + int bufpos2index + [MAX_BUFFERS]; /* mapping of (read/write_position % used_buffers) + * to inner buffer index */ + long buffer_size; + + /* sustain_framerate stuff */ + struct timer_list sustain_timer; + unsigned int reread_count; + + /* timeout stuff */ + unsigned long timeout_jiffies; /* CID_TIMEOUT; 0 means disabled */ + int timeout_image_io; /* CID_TIMEOUT_IMAGE_IO; next opener will + * read/write to timeout_image */ + u8 *timeout_image; /* copy of it will be captured when timeout passes */ + struct v4l2l_buffer timeout_image_buffer; + struct timer_list timeout_timer; + int timeout_happened; + + /* sync stuff */ + atomic_t open_count; + + int ready_for_capture; /* set to the number of writers that opened the + * device and negotiated format. */ + int ready_for_output; /* set to true when no writer is currently attached + * this differs slightly from !ready_for_capture, + * e.g. when using fallback images */ + int active_readers; /* increase if any reader starts streaming */ + int announce_all_caps; /* set to false, if device caps (OUTPUT/CAPTURE) + * should only be announced if the resp. "ready" + * flag is set; default=TRUE */ + + int min_width, max_width; + int min_height, max_height; + + char card_label[32]; + + wait_queue_head_t read_event; + spinlock_t lock, list_lock; +}; + +/* types of opener shows what opener wants to do with loopback */ +enum opener_type { + // clang-format off + UNNEGOTIATED = 0, + READER = 1, + WRITER = 2, + // clang-format on +}; + +/* struct keeping state and type of opener */ +struct v4l2_loopback_opener { + enum opener_type type; + s64 read_position; /* number of last processed frame + 1 or + * write_position - 1 if reader went out of sync */ + unsigned int reread_count; + struct v4l2_buffer *buffers; + int buffers_number; /* should not be big, 4 is a good choice */ + int timeout_image_io; + + struct v4l2_fh fh; +}; + +#define fh_to_opener(ptr) container_of((ptr), struct v4l2_loopback_opener, fh) + +/* this is heavily inspired by the bttv driver found in the linux kernel */ +struct v4l2l_format { + char *name; + int fourcc; /* video4linux 2 */ + int depth; /* bit/pixel */ + int flags; +}; +/* set the v4l2l_format.flags to PLANAR for non-packed formats */ +#define FORMAT_FLAGS_PLANAR 0x01 +#define FORMAT_FLAGS_COMPRESSED 0x02 + +#include "v4l2loopback_formats.h" + +#ifndef V4L2_TYPE_IS_CAPTURE +#define V4L2_TYPE_IS_CAPTURE(type) \ + ((type) == V4L2_BUF_TYPE_VIDEO_CAPTURE || \ + (type) == V4L2_BUF_TYPE_VIDEO_CAPTURE_MPLANE) +#endif /* V4L2_TYPE_IS_CAPTURE */ +#ifndef V4L2_TYPE_IS_OUTPUT +#define V4L2_TYPE_IS_OUTPUT(type) \ + ((type) == V4L2_BUF_TYPE_VIDEO_OUTPUT || \ + (type) == V4L2_BUF_TYPE_VIDEO_OUTPUT_MPLANE) +#endif /* V4L2_TYPE_IS_OUTPUT */ + +/* whether the format can be changed */ +/* the format is fixated if we + - have writers (ready_for_capture>0) + - and/or have readers (active_readers>0) +*/ +#define V4L2LOOPBACK_IS_FIXED_FMT(device) \ + (device->ready_for_capture > 0 || device->active_readers > 0 || \ + device->keep_format) + +static const unsigned int FORMATS = ARRAY_SIZE(formats); + +static char *fourcc2str(unsigned int fourcc, char buf[4]) +{ + buf[0] = (fourcc >> 0) & 0xFF; + buf[1] = (fourcc >> 8) & 0xFF; + buf[2] = (fourcc >> 16) & 0xFF; + buf[3] = (fourcc >> 24) & 0xFF; + + return buf; +} + +static const struct v4l2l_format *format_by_fourcc(int fourcc) +{ + unsigned int i; + + for (i = 0; i < FORMATS; i++) { + if (formats[i].fourcc == fourcc) + return formats + i; + } + + dprintk("unsupported format '%c%c%c%c'\n", (fourcc >> 0) & 0xFF, + (fourcc >> 8) & 0xFF, (fourcc >> 16) & 0xFF, + (fourcc >> 24) & 0xFF); + return NULL; +} + +static void pix_format_set_size(struct v4l2_pix_format *f, + const struct v4l2l_format *fmt, + unsigned int width, unsigned int height) +{ + f->width = width; + f->height = height; + + if (fmt->flags & FORMAT_FLAGS_PLANAR) { + f->bytesperline = width; /* Y plane */ + f->sizeimage = (width * height * fmt->depth) >> 3; + } else if (fmt->flags & FORMAT_FLAGS_COMPRESSED) { + /* doesn't make sense for compressed formats */ + f->bytesperline = 0; + f->sizeimage = (width * height * fmt->depth) >> 3; + } else { + f->bytesperline = (width * fmt->depth) >> 3; + f->sizeimage = height * f->bytesperline; + } +} + +static int v4l2l_fill_format(struct v4l2_format *fmt, int capture, + const u32 minwidth, const u32 maxwidth, + const u32 minheight, const u32 maxheight) +{ + u32 width = fmt->fmt.pix.width, height = fmt->fmt.pix.height; + u32 pixelformat = fmt->fmt.pix.pixelformat; + struct v4l2_format fmt0 = *fmt; + u32 bytesperline = 0, sizeimage = 0; + if (!width) + width = V4L2LOOPBACK_SIZE_DEFAULT_WIDTH; + if (!height) + height = V4L2LOOPBACK_SIZE_DEFAULT_HEIGHT; + if (width < minwidth) + width = minwidth; + if (width > maxwidth) + width = maxwidth; + if (height < minheight) + height = minheight; + if (height > maxheight) + height = maxheight; + + /* sets: width,height,pixelformat,bytesperline,sizeimage */ + if (!(V4L2_TYPE_IS_MULTIPLANAR(fmt0.type))) { + fmt0.fmt.pix.bytesperline = 0; + fmt0.fmt.pix.sizeimage = 0; + } + + if (0) { + ; +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 2, 0) + } else if (!v4l2_fill_pixfmt(&fmt0.fmt.pix, pixelformat, width, + height)) { + ; + } else if (!v4l2_fill_pixfmt_mp(&fmt0.fmt.pix_mp, pixelformat, width, + height)) { + ; +#endif + } else { + const struct v4l2l_format *format = + format_by_fourcc(pixelformat); + if (!format) + return -EINVAL; + pix_format_set_size(&fmt0.fmt.pix, format, width, height); + fmt0.fmt.pix.pixelformat = format->fourcc; + } + + if (V4L2_TYPE_IS_MULTIPLANAR(fmt0.type)) { + *fmt = fmt0; + + if ((fmt->fmt.pix_mp.colorspace == V4L2_COLORSPACE_DEFAULT) || + (fmt->fmt.pix_mp.colorspace > V4L2_COLORSPACE_DCI_P3)) + fmt->fmt.pix_mp.colorspace = V4L2_COLORSPACE_SRGB; + if (V4L2_FIELD_ANY == fmt->fmt.pix_mp.field) + fmt->fmt.pix_mp.field = V4L2_FIELD_NONE; + if (capture) + fmt->type = V4L2_BUF_TYPE_VIDEO_CAPTURE_MPLANE; + else + fmt->type = V4L2_BUF_TYPE_VIDEO_OUTPUT_MPLANE; + } else { + bytesperline = fmt->fmt.pix.bytesperline; + sizeimage = fmt->fmt.pix.sizeimage; + + *fmt = fmt0; + + if (!fmt->fmt.pix.bytesperline) + fmt->fmt.pix.bytesperline = bytesperline; + if (!fmt->fmt.pix.sizeimage) + fmt->fmt.pix.sizeimage = sizeimage; + + if ((fmt->fmt.pix.colorspace == V4L2_COLORSPACE_DEFAULT) || + (fmt->fmt.pix.colorspace > V4L2_COLORSPACE_DCI_P3)) + fmt->fmt.pix.colorspace = V4L2_COLORSPACE_SRGB; + if (V4L2_FIELD_ANY == fmt->fmt.pix.field) + fmt->fmt.pix.field = V4L2_FIELD_NONE; + if (capture) + fmt->type = V4L2_BUF_TYPE_VIDEO_CAPTURE; + else + fmt->type = V4L2_BUF_TYPE_VIDEO_OUTPUT; + } + + return 0; +} + +/* Checks if v4l2l_fill_format() has set a valid, fixed sizeimage val. */ +static bool v4l2l_pix_format_has_valid_sizeimage(struct v4l2_format *fmt) +{ +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 2, 0) + const struct v4l2_format_info *info; + + info = v4l2_format_info(fmt->fmt.pix.pixelformat); + if (info && info->mem_planes == 1) + return true; +#endif + + return false; +} + +static int pix_format_eq(const struct v4l2_pix_format *ref, + const struct v4l2_pix_format *tgt, int strict) +{ + /* check if the two formats are equivalent. + * ANY fields are handled gracefully + */ +#define _pix_format_eq0(x) \ + if (ref->x != tgt->x) \ + result = 0 +#define _pix_format_eq1(x, def) \ + do { \ + if ((def != tgt->x) && (ref->x != tgt->x)) { \ + printk(KERN_INFO #x " failed"); \ + result = 0; \ + } \ + } while (0) + int result = 1; + _pix_format_eq0(width); + _pix_format_eq0(height); + _pix_format_eq0(pixelformat); + if (!strict) + return result; + _pix_format_eq1(field, V4L2_FIELD_ANY); + _pix_format_eq0(bytesperline); + _pix_format_eq0(sizeimage); + _pix_format_eq1(colorspace, V4L2_COLORSPACE_DEFAULT); + return result; +} + +static struct v4l2_loopback_device *v4l2loopback_getdevice(struct file *f); +static int inner_try_setfmt(struct file *file, struct v4l2_format *fmt) +{ + int capture = V4L2_TYPE_IS_CAPTURE(fmt->type); + struct v4l2_loopback_device *dev; + int needschange = 0; + char buf[5]; + buf[4] = 0; + + dev = v4l2loopback_getdevice(file); + + needschange = !(pix_format_eq(&dev->pix_format, &fmt->fmt.pix, 0)); + if (V4L2LOOPBACK_IS_FIXED_FMT(dev)) { + fmt->fmt.pix = dev->pix_format; + if (needschange) { + if (dev->active_readers > 0 && capture) { + /* cannot call fmt_cap while there are readers */ + return -EBUSY; + } + if (dev->ready_for_capture > 0 && !capture) { + /* cannot call fmt_out while there are writers */ + return -EBUSY; + } + } + } + if (v4l2l_fill_format(fmt, capture, dev->min_width, dev->max_width, + dev->min_height, dev->max_height) != 0) { + return -EINVAL; + } + + if (1) { + char buf[5]; + buf[4] = 0; + dprintk("capFOURCC=%s\n", + fourcc2str(dev->pix_format.pixelformat, buf)); + } + return 0; +} + +static int set_timeperframe(struct v4l2_loopback_device *dev, + struct v4l2_fract *tpf) +{ + if ((tpf->denominator < 1) || (tpf->numerator < 1)) { + return -EINVAL; + } + dev->capture_param.timeperframe = *tpf; + dev->frame_jiffies = max(1UL, msecs_to_jiffies(1000) * tpf->numerator / + tpf->denominator); + return 0; +} + +static struct v4l2_loopback_device *v4l2loopback_cd2dev(struct device *cd); + +/* device attributes */ +/* available via sysfs: /sys/devices/virtual/video4linux/video* */ + +static ssize_t attr_show_format(struct device *cd, + struct device_attribute *attr, char *buf) +{ + /* gets the current format as "FOURCC:WxH@f/s", e.g. "YUYV:320x240@1000/30" */ + struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); + const struct v4l2_fract *tpf; + char buf4cc[5], buf_fps[32]; + + if (!dev || !V4L2LOOPBACK_IS_FIXED_FMT(dev)) + return 0; + tpf = &dev->capture_param.timeperframe; + + fourcc2str(dev->pix_format.pixelformat, buf4cc); + buf4cc[4] = 0; + if (tpf->numerator == 1) + snprintf(buf_fps, sizeof(buf_fps), "%d", tpf->denominator); + else + snprintf(buf_fps, sizeof(buf_fps), "%d/%d", tpf->denominator, + tpf->numerator); + return sprintf(buf, "%4s:%dx%d@%s\n", buf4cc, dev->pix_format.width, + dev->pix_format.height, buf_fps); +} + +static ssize_t attr_store_format(struct device *cd, + struct device_attribute *attr, const char *buf, + size_t len) +{ + struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); + int fps_num = 0, fps_den = 1; + + if (!dev) + return -ENODEV; + + /* only fps changing is supported */ + if (sscanf(buf, "@%d/%d", &fps_num, &fps_den) > 0) { + struct v4l2_fract f = { .numerator = fps_den, + .denominator = fps_num }; + int err = 0; + if ((err = set_timeperframe(dev, &f)) < 0) + return err; + return len; + } + return -EINVAL; +} + +static DEVICE_ATTR(format, S_IRUGO | S_IWUSR, attr_show_format, + attr_store_format); + +static ssize_t attr_show_buffers(struct device *cd, + struct device_attribute *attr, char *buf) +{ + struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); + + if (!dev) + return -ENODEV; + + return sprintf(buf, "%d\n", dev->used_buffers); +} + +static DEVICE_ATTR(buffers, S_IRUGO, attr_show_buffers, NULL); + +static ssize_t attr_show_maxopeners(struct device *cd, + struct device_attribute *attr, char *buf) +{ + struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); + + if (!dev) + return -ENODEV; + + return sprintf(buf, "%d\n", dev->max_openers); +} + +static ssize_t attr_store_maxopeners(struct device *cd, + struct device_attribute *attr, + const char *buf, size_t len) +{ + struct v4l2_loopback_device *dev = NULL; + unsigned long curr = 0; + + if (kstrtoul(buf, 0, &curr)) + return -EINVAL; + + dev = v4l2loopback_cd2dev(cd); + if (!dev) + return -ENODEV; + + if (dev->max_openers == curr) + return len; + + if (curr > __INT_MAX__ || dev->open_count.counter > curr) { + /* request to limit to less openers as are currently attached to us */ + return -EINVAL; + } + + dev->max_openers = (int)curr; + + return len; +} + +static DEVICE_ATTR(max_openers, S_IRUGO | S_IWUSR, attr_show_maxopeners, + attr_store_maxopeners); + +static ssize_t attr_show_state(struct device *cd, struct device_attribute *attr, + char *buf) +{ + struct v4l2_loopback_device *dev = v4l2loopback_cd2dev(cd); + + if (!dev) + return -ENODEV; + + if (dev->ready_for_capture) + return sprintf(buf, "capture\n"); + if (dev->ready_for_output) + return sprintf(buf, "output\n"); + + return -EAGAIN; +} + +static DEVICE_ATTR(state, S_IRUGO, attr_show_state, NULL); + +static void v4l2loopback_remove_sysfs(struct video_device *vdev) +{ +#define V4L2_SYSFS_DESTROY(x) device_remove_file(&vdev->dev, &dev_attr_##x) + + if (vdev) { + V4L2_SYSFS_DESTROY(format); + V4L2_SYSFS_DESTROY(buffers); + V4L2_SYSFS_DESTROY(max_openers); + V4L2_SYSFS_DESTROY(state); + /* ... */ + } +} + +static void v4l2loopback_create_sysfs(struct video_device *vdev) +{ + int res = 0; + +#define V4L2_SYSFS_CREATE(x) \ + res = device_create_file(&vdev->dev, &dev_attr_##x); \ + if (res < 0) \ + break + if (!vdev) + return; + do { + V4L2_SYSFS_CREATE(format); + V4L2_SYSFS_CREATE(buffers); + V4L2_SYSFS_CREATE(max_openers); + V4L2_SYSFS_CREATE(state); + /* ... */ + } while (0); + + if (res >= 0) + return; + dev_err(&vdev->dev, "%s error: %d\n", __func__, res); +} + +/* Event APIs */ + +#define V4L2LOOPBACK_EVENT_BASE (V4L2_EVENT_PRIVATE_START) +#define V4L2LOOPBACK_EVENT_OFFSET 0x08E00000 +#define V4L2_EVENT_PRI_CLIENT_USAGE \ + (V4L2LOOPBACK_EVENT_BASE + V4L2LOOPBACK_EVENT_OFFSET + 1) + +struct v4l2_event_client_usage { + __u32 count; +}; + +/* global module data */ +/* find a device based on it's device-number (e.g. '3' for /dev/video3) */ +struct v4l2loopback_lookup_cb_data { + int device_nr; + struct v4l2_loopback_device *device; +}; +static int v4l2loopback_lookup_cb(int id, void *ptr, void *data) +{ + struct v4l2_loopback_device *device = ptr; + struct v4l2loopback_lookup_cb_data *cbdata = data; + if (cbdata && device && device->vdev) { + if (device->vdev->num == cbdata->device_nr) { + cbdata->device = device; + cbdata->device_nr = id; + return 1; + } + } + return 0; +} +static int v4l2loopback_lookup(int device_nr, + struct v4l2_loopback_device **device) +{ + struct v4l2loopback_lookup_cb_data data = { + .device_nr = device_nr, + .device = NULL, + }; + int err = idr_for_each(&v4l2loopback_index_idr, &v4l2loopback_lookup_cb, + &data); + if (1 == err) { + if (device) + *device = data.device; + return data.device_nr; + } + return -ENODEV; +} +static struct v4l2_loopback_device *v4l2loopback_cd2dev(struct device *cd) +{ + struct video_device *loopdev = to_video_device(cd); + struct v4l2loopback_private *ptr = + (struct v4l2loopback_private *)video_get_drvdata(loopdev); + int nr = ptr->device_nr; + + return idr_find(&v4l2loopback_index_idr, nr); +} + +static struct v4l2_loopback_device *v4l2loopback_getdevice(struct file *f) +{ + struct v4l2loopback_private *ptr = video_drvdata(f); + int nr = ptr->device_nr; + + return idr_find(&v4l2loopback_index_idr, nr); +} + +/* forward declarations */ +static void client_usage_queue_event(struct video_device *vdev); +static void init_buffers(struct v4l2_loopback_device *dev); +static int allocate_buffers(struct v4l2_loopback_device *dev); +static void free_buffers(struct v4l2_loopback_device *dev); +static void try_free_buffers(struct v4l2_loopback_device *dev); +static int allocate_timeout_image(struct v4l2_loopback_device *dev); +static void check_timers(struct v4l2_loopback_device *dev); +static const struct v4l2_file_operations v4l2_loopback_fops; +static const struct v4l2_ioctl_ops v4l2_loopback_ioctl_ops; + +/* Queue helpers */ +/* next functions sets buffer flags and adjusts counters accordingly */ +static inline void set_done(struct v4l2l_buffer *buffer) +{ + buffer->buffer.flags &= ~V4L2_BUF_FLAG_QUEUED; + buffer->buffer.flags |= V4L2_BUF_FLAG_DONE; +} + +static inline void set_queued(struct v4l2l_buffer *buffer) +{ + buffer->buffer.flags &= ~V4L2_BUF_FLAG_DONE; + buffer->buffer.flags |= V4L2_BUF_FLAG_QUEUED; +} + +static inline void unset_flags(struct v4l2l_buffer *buffer) +{ + buffer->buffer.flags &= ~V4L2_BUF_FLAG_QUEUED; + buffer->buffer.flags &= ~V4L2_BUF_FLAG_DONE; +} + +/* V4L2 ioctl caps and params calls */ +/* returns device capabilities + * called on VIDIOC_QUERYCAP + */ +static int vidioc_querycap(struct file *file, void *priv, + struct v4l2_capability *cap) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + int device_nr = + ((struct v4l2loopback_private *)video_get_drvdata(dev->vdev)) + ->device_nr; + __u32 capabilities = V4L2_CAP_STREAMING | V4L2_CAP_READWRITE; + + strscpy(cap->driver, "v4l2 loopback", sizeof(cap->driver)); + snprintf(cap->card, sizeof(cap->card), "%s", dev->card_label); + snprintf(cap->bus_info, sizeof(cap->bus_info), + "platform:v4l2loopback-%03d", device_nr); + + if (dev->announce_all_caps) { + capabilities |= V4L2_CAP_VIDEO_CAPTURE | V4L2_CAP_VIDEO_OUTPUT; + } else { + if (dev->ready_for_capture) { + capabilities |= V4L2_CAP_VIDEO_CAPTURE; + } + if (dev->ready_for_output) { + capabilities |= V4L2_CAP_VIDEO_OUTPUT; + } + } + +#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 7, 0) + dev->vdev->device_caps = +#endif /* >=linux-4.7.0 */ + cap->device_caps = cap->capabilities = capabilities; + + cap->capabilities |= V4L2_CAP_DEVICE_CAPS; + + memset(cap->reserved, 0, sizeof(cap->reserved)); + return 0; +} + +static int vidioc_enum_framesizes(struct file *file, void *fh, + struct v4l2_frmsizeenum *argp) +{ + struct v4l2_loopback_device *dev; + + /* there can be only one... */ + if (argp->index) + return -EINVAL; + + dev = v4l2loopback_getdevice(file); + if (V4L2LOOPBACK_IS_FIXED_FMT(dev)) { + /* format has already been negotiated + * cannot change during runtime + */ + if (argp->pixel_format != dev->pix_format.pixelformat) + return -EINVAL; + + argp->type = V4L2_FRMSIZE_TYPE_DISCRETE; + + argp->discrete.width = dev->pix_format.width; + argp->discrete.height = dev->pix_format.height; + } else { + /* if the format has not been negotiated yet, we accept anything + */ + if (NULL == format_by_fourcc(argp->pixel_format)) + return -EINVAL; + + if (dev->min_width == dev->max_width && + dev->min_height == dev->max_height) { + argp->type = V4L2_FRMSIZE_TYPE_DISCRETE; + + argp->discrete.width = dev->min_width; + argp->discrete.height = dev->min_height; + } else { + argp->type = V4L2_FRMSIZE_TYPE_CONTINUOUS; + + argp->stepwise.min_width = dev->min_width; + argp->stepwise.min_height = dev->min_height; + + argp->stepwise.max_width = dev->max_width; + argp->stepwise.max_height = dev->max_height; + + argp->stepwise.step_width = 1; + argp->stepwise.step_height = 1; + } + } + return 0; +} + +/* returns frameinterval (fps) for the set resolution + * called on VIDIOC_ENUM_FRAMEINTERVALS + */ +static int vidioc_enum_frameintervals(struct file *file, void *fh, + struct v4l2_frmivalenum *argp) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + + /* there can be only one... */ + if (argp->index) + return -EINVAL; + + if (V4L2LOOPBACK_IS_FIXED_FMT(dev)) { + if (argp->width != dev->pix_format.width || + argp->height != dev->pix_format.height || + argp->pixel_format != dev->pix_format.pixelformat) + return -EINVAL; + + argp->type = V4L2_FRMIVAL_TYPE_DISCRETE; + argp->discrete = dev->capture_param.timeperframe; + } else { + if (argp->width < dev->min_width || + argp->width > dev->max_width || + argp->height < dev->min_height || + argp->height > dev->max_height || + NULL == format_by_fourcc(argp->pixel_format)) + return -EINVAL; + + argp->type = V4L2_FRMIVAL_TYPE_CONTINUOUS; + argp->stepwise.min.numerator = 1; + argp->stepwise.min.denominator = V4L2LOOPBACK_FPS_MAX; + argp->stepwise.max.numerator = 1; + argp->stepwise.max.denominator = V4L2LOOPBACK_FPS_MIN; + argp->stepwise.step.numerator = 1; + argp->stepwise.step.denominator = 1; + } + + return 0; +} + +/* ------------------ CAPTURE ----------------------- */ + +/* returns device formats + * called on VIDIOC_ENUM_FMT, with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_CAPTURE + */ +static int vidioc_enum_fmt_cap(struct file *file, void *fh, + struct v4l2_fmtdesc *f) +{ + struct v4l2_loopback_device *dev; + const struct v4l2l_format *fmt; + MARK(); + + dev = v4l2loopback_getdevice(file); + + if (f->index) + return -EINVAL; + + if (V4L2LOOPBACK_IS_FIXED_FMT(dev)) { + /* format has been fixed, so only one single format is supported */ + const __u32 format = dev->pix_format.pixelformat; + + if ((fmt = format_by_fourcc(format))) { + snprintf(f->description, sizeof(f->description), "%s", + fmt->name); + } else { + snprintf(f->description, sizeof(f->description), + "[%c%c%c%c]", (format >> 0) & 0xFF, + (format >> 8) & 0xFF, (format >> 16) & 0xFF, + (format >> 24) & 0xFF); + } + + f->pixelformat = dev->pix_format.pixelformat; + } else { + return -EINVAL; + } + f->flags = 0; + MARK(); + return 0; +} + +/* returns current video format + * called on VIDIOC_G_FMT, with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_CAPTURE + */ +static int vidioc_g_fmt_cap(struct file *file, void *priv, + struct v4l2_format *fmt) +{ + struct v4l2_loopback_device *dev; + MARK(); + + dev = v4l2loopback_getdevice(file); + if (!dev->ready_for_capture && !dev->ready_for_output) + return -EINVAL; + + fmt->fmt.pix = dev->pix_format; + MARK(); + return 0; +} + +/* checks if it is OK to change to format fmt; + * actual check is done by inner_try_setfmt + * just checking that pixelformat is OK and set other parameters, app should + * obey this decision + * called on VIDIOC_TRY_FMT, with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_CAPTURE + */ +static int vidioc_try_fmt_cap(struct file *file, void *priv, + struct v4l2_format *fmt) +{ + int ret = 0; + if (!V4L2_TYPE_IS_CAPTURE(fmt->type)) + return -EINVAL; + ret = inner_try_setfmt(file, fmt); + if (-EBUSY == ret) + return 0; + return ret; +} + +/* sets new output format, if possible + * actually format is set by input and we even do not check it, just return + * current one, but it is possible to set subregions of input TODO(vasaka) + * called on VIDIOC_S_FMT, with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_CAPTURE + */ +static int vidioc_s_fmt_cap(struct file *file, void *priv, + struct v4l2_format *fmt) +{ + int ret; + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + if (!V4L2_TYPE_IS_CAPTURE(fmt->type)) + return -EINVAL; + ret = inner_try_setfmt(file, fmt); + if (!ret) { + dev->pix_format = fmt->fmt.pix; + } + return ret; +} + +/* ------------------ OUTPUT ----------------------- */ + +/* returns device formats; + * LATER: allow all formats + * called on VIDIOC_ENUM_FMT, with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_OUTPUT + */ +static int vidioc_enum_fmt_out(struct file *file, void *fh, + struct v4l2_fmtdesc *f) +{ + struct v4l2_loopback_device *dev; + const struct v4l2l_format *fmt; + + dev = v4l2loopback_getdevice(file); + + if (V4L2LOOPBACK_IS_FIXED_FMT(dev)) { + /* format has been fixed, so only one single format is supported */ + const __u32 format = dev->pix_format.pixelformat; + + if (f->index) + return -EINVAL; + + if ((fmt = format_by_fourcc(format))) { + snprintf(f->description, sizeof(f->description), "%s", + fmt->name); + } else { + snprintf(f->description, sizeof(f->description), + "[%c%c%c%c]", (format >> 0) & 0xFF, + (format >> 8) & 0xFF, (format >> 16) & 0xFF, + (format >> 24) & 0xFF); + } + + f->pixelformat = dev->pix_format.pixelformat; + } else { + /* fill in a dummy format */ + /* coverity[unsigned_compare] */ + if (f->index < 0 || f->index >= FORMATS) + return -EINVAL; + + fmt = &formats[f->index]; + + f->pixelformat = fmt->fourcc; + snprintf(f->description, sizeof(f->description), "%s", + fmt->name); + } + f->flags = 0; + + return 0; +} + +/* returns current video format format fmt */ +/* NOTE: this is called from the producer + * so if format has not been negotiated yet, + * it should return ALL of available formats, + * called on VIDIOC_G_FMT, with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_OUTPUT + */ +static int vidioc_g_fmt_out(struct file *file, void *priv, + struct v4l2_format *fmt) +{ + struct v4l2_loopback_device *dev; + MARK(); + + dev = v4l2loopback_getdevice(file); + + /* + * LATER: this should return the currently valid format + * gstreamer doesn't like it, if this returns -EINVAL, as it + * then concludes that there is _no_ valid format + * CHECK whether this assumption is wrong, + * or whether we have to always provide a valid format + */ + + fmt->fmt.pix = dev->pix_format; + return 0; +} + +/* checks if it is OK to change to format fmt; + * if format is negotiated do not change it + * called on VIDIOC_TRY_FMT with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_OUTPUT + */ +static int vidioc_try_fmt_out(struct file *file, void *priv, + struct v4l2_format *fmt) +{ + int ret = 0; + if (!V4L2_TYPE_IS_OUTPUT(fmt->type)) + return -EINVAL; + ret = inner_try_setfmt(file, fmt); + if (-EBUSY == ret) + return 0; + return ret; +} + +/* sets new output format, if possible; + * allocate data here because we do not know if it will be streaming or + * read/write IO + * called on VIDIOC_S_FMT with v4l2_buf_type set to V4L2_BUF_TYPE_VIDEO_OUTPUT + */ +static int vidioc_s_fmt_out(struct file *file, void *priv, + struct v4l2_format *fmt) +{ + struct v4l2_loopback_device *dev; + int ret; + char buf[5]; + buf[4] = 0; + if (!V4L2_TYPE_IS_OUTPUT(fmt->type)) + return -EINVAL; + dev = v4l2loopback_getdevice(file); + + ret = inner_try_setfmt(file, fmt); + if (!ret) { + dev->pix_format = fmt->fmt.pix; + dev->pix_format_has_valid_sizeimage = + v4l2l_pix_format_has_valid_sizeimage(fmt); + dprintk("s_fmt_out(%d) %d...%d\n", ret, dev->ready_for_capture, + dev->pix_format.sizeimage); + dprintk("outFOURCC=%s\n", + fourcc2str(dev->pix_format.pixelformat, buf)); + + if (!dev->ready_for_capture) { + dev->buffer_size = + PAGE_ALIGN(dev->pix_format.sizeimage); + // JMZ: TODO get rid of the next line + fmt->fmt.pix.sizeimage = dev->buffer_size; + ret = allocate_buffers(dev); + } + } + return ret; +} + +// #define V4L2L_OVERLAY +#ifdef V4L2L_OVERLAY +/* ------------------ OVERLAY ----------------------- */ +/* currently unsupported */ +/* GSTreamer's v4l2sink is buggy, as it requires the overlay to work + * while it should only require it, if overlay is requested + * once the gstreamer element is fixed, remove the overlay dummies + */ +#warning OVERLAY dummies +static int vidioc_g_fmt_overlay(struct file *file, void *priv, + struct v4l2_format *fmt) +{ + return 0; +} + +static int vidioc_s_fmt_overlay(struct file *file, void *priv, + struct v4l2_format *fmt) +{ + return 0; +} +#endif /* V4L2L_OVERLAY */ + +/* ------------------ PARAMs ----------------------- */ + +/* get some data flow parameters, only capability, fps and readbuffers has + * effect on this driver + * called on VIDIOC_G_PARM + */ +static int vidioc_g_parm(struct file *file, void *priv, + struct v4l2_streamparm *parm) +{ + /* do not care about type of opener, hope these enums would always be + * compatible */ + struct v4l2_loopback_device *dev; + MARK(); + + dev = v4l2loopback_getdevice(file); + parm->parm.capture = dev->capture_param; + return 0; +} + +/* get some data flow parameters, only capability, fps and readbuffers has + * effect on this driver + * called on VIDIOC_S_PARM + */ +static int vidioc_s_parm(struct file *file, void *priv, + struct v4l2_streamparm *parm) +{ + struct v4l2_loopback_device *dev; + int err = 0; + MARK(); + + dev = v4l2loopback_getdevice(file); + dprintk("vidioc_s_parm called frate=%d/%d\n", + parm->parm.capture.timeperframe.numerator, + parm->parm.capture.timeperframe.denominator); + + switch (parm->type) { + case V4L2_BUF_TYPE_VIDEO_CAPTURE: + if ((err = set_timeperframe( + dev, &parm->parm.capture.timeperframe)) < 0) + return err; + break; + case V4L2_BUF_TYPE_VIDEO_OUTPUT: + if ((err = set_timeperframe( + dev, &parm->parm.capture.timeperframe)) < 0) + return err; + break; + default: + return -1; + } + + parm->parm.capture = dev->capture_param; + return 0; +} + +#ifdef V4L2LOOPBACK_WITH_STD +/* sets a tv standard, actually we do not need to handle this any special way + * added to support effecttv + * called on VIDIOC_S_STD + */ +static int vidioc_s_std(struct file *file, void *fh, v4l2_std_id *_std) +{ + v4l2_std_id req_std = 0, supported_std = 0; + const v4l2_std_id all_std = V4L2_STD_ALL, no_std = 0; + + if (_std) { + req_std = *_std; + *_std = all_std; + } + + /* we support everything in V4L2_STD_ALL, but not more... */ + supported_std = (all_std & req_std); + if (no_std == supported_std) + return -EINVAL; + + return 0; +} + +/* gets a fake video standard + * called on VIDIOC_G_STD + */ +static int vidioc_g_std(struct file *file, void *fh, v4l2_std_id *norm) +{ + if (norm) + *norm = V4L2_STD_ALL; + return 0; +} +/* gets a fake video standard + * called on VIDIOC_QUERYSTD + */ +static int vidioc_querystd(struct file *file, void *fh, v4l2_std_id *norm) +{ + if (norm) + *norm = V4L2_STD_ALL; + return 0; +} +#endif /* V4L2LOOPBACK_WITH_STD */ + +static int v4l2loopback_set_ctrl(struct v4l2_loopback_device *dev, u32 id, + s64 val) +{ + switch (id) { + case CID_KEEP_FORMAT: + if (val < 0 || val > 1) + return -EINVAL; + dev->keep_format = val; + try_free_buffers( + dev); /* will only free buffers if !keep_format */ + break; + case CID_SUSTAIN_FRAMERATE: + if (val < 0 || val > 1) + return -EINVAL; + spin_lock_bh(&dev->lock); + dev->sustain_framerate = val; + check_timers(dev); + spin_unlock_bh(&dev->lock); + break; + case CID_TIMEOUT: + if (val < 0 || val > MAX_TIMEOUT) + return -EINVAL; + spin_lock_bh(&dev->lock); + dev->timeout_jiffies = msecs_to_jiffies(val); + check_timers(dev); + spin_unlock_bh(&dev->lock); + allocate_timeout_image(dev); + break; + case CID_TIMEOUT_IMAGE_IO: + dev->timeout_image_io = 1; + break; + default: + return -EINVAL; + } + return 0; +} + +static int v4l2loopback_s_ctrl(struct v4l2_ctrl *ctrl) +{ + struct v4l2_loopback_device *dev = container_of( + ctrl->handler, struct v4l2_loopback_device, ctrl_handler); + return v4l2loopback_set_ctrl(dev, ctrl->id, ctrl->val); +} + +/* returns set of device outputs, in our case there is only one + * called on VIDIOC_ENUMOUTPUT + */ +static int vidioc_enum_output(struct file *file, void *fh, + struct v4l2_output *outp) +{ + __u32 index = outp->index; + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + MARK(); + + if (!dev->announce_all_caps && !dev->ready_for_output) + return -ENOTTY; + + if (0 != index) + return -EINVAL; + + /* clear all data (including the reserved fields) */ + memset(outp, 0, sizeof(*outp)); + + outp->index = index; + strscpy(outp->name, "loopback in", sizeof(outp->name)); + outp->type = V4L2_OUTPUT_TYPE_ANALOG; + outp->audioset = 0; + outp->modulator = 0; +#ifdef V4L2LOOPBACK_WITH_STD + outp->std = V4L2_STD_ALL; +#ifdef V4L2_OUT_CAP_STD + outp->capabilities |= V4L2_OUT_CAP_STD; +#endif /* V4L2_OUT_CAP_STD */ +#endif /* V4L2LOOPBACK_WITH_STD */ + + return 0; +} + +/* which output is currently active, + * called on VIDIOC_G_OUTPUT + */ +static int vidioc_g_output(struct file *file, void *fh, unsigned int *i) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + if (!dev->announce_all_caps && !dev->ready_for_output) + return -ENOTTY; + if (i) + *i = 0; + return 0; +} + +/* set output, can make sense if we have more than one video src, + * called on VIDIOC_S_OUTPUT + */ +static int vidioc_s_output(struct file *file, void *fh, unsigned int i) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + if (!dev->announce_all_caps && !dev->ready_for_output) + return -ENOTTY; + + if (i) + return -EINVAL; + + return 0; +} + +/* returns set of device inputs, in our case there is only one, + * but later I may add more + * called on VIDIOC_ENUMINPUT + */ +static int vidioc_enum_input(struct file *file, void *fh, + struct v4l2_input *inp) +{ + struct v4l2_loopback_device *dev; + __u32 index = inp->index; + MARK(); + + if (0 != index) + return -EINVAL; + + /* clear all data (including the reserved fields) */ + memset(inp, 0, sizeof(*inp)); + + inp->index = index; + strscpy(inp->name, "loopback", sizeof(inp->name)); + inp->type = V4L2_INPUT_TYPE_CAMERA; + inp->audioset = 0; + inp->tuner = 0; + inp->status = 0; + +#ifdef V4L2LOOPBACK_WITH_STD + inp->std = V4L2_STD_ALL; +#ifdef V4L2_IN_CAP_STD + inp->capabilities |= V4L2_IN_CAP_STD; +#endif +#endif /* V4L2LOOPBACK_WITH_STD */ + + dev = v4l2loopback_getdevice(file); + if (!dev->ready_for_capture) { + inp->status |= V4L2_IN_ST_NO_SIGNAL; + } + + return 0; +} + +/* which input is currently active, + * called on VIDIOC_G_INPUT + */ +static int vidioc_g_input(struct file *file, void *fh, unsigned int *i) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + if (!dev->announce_all_caps && !dev->ready_for_capture) + return -ENOTTY; + if (i) + *i = 0; + return 0; +} + +/* set input, can make sense if we have more than one video src, + * called on VIDIOC_S_INPUT + */ +static int vidioc_s_input(struct file *file, void *fh, unsigned int i) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + if (!dev->announce_all_caps && !dev->ready_for_capture) + return -ENOTTY; + if (i == 0) + return 0; + return -EINVAL; +} + +/* --------------- V4L2 ioctl buffer related calls ----------------- */ + +/* negotiate buffer type + * only mmap streaming supported + * called on VIDIOC_REQBUFS + */ +static int vidioc_reqbufs(struct file *file, void *fh, + struct v4l2_requestbuffers *b) +{ + struct v4l2_loopback_device *dev; + struct v4l2_loopback_opener *opener; + int i; + MARK(); + + dev = v4l2loopback_getdevice(file); + opener = fh_to_opener(fh); + + dprintk("reqbufs: %d\t%d=%d\n", b->memory, b->count, + dev->buffers_number); + + if (opener->timeout_image_io) { + dev->timeout_image_io = 0; + if (b->memory != V4L2_MEMORY_MMAP) + return -EINVAL; + b->count = 2; + return 0; + } + + if (V4L2_TYPE_IS_OUTPUT(b->type) && (!dev->ready_for_output)) { + return -EBUSY; + } + + init_buffers(dev); + switch (b->memory) { + case V4L2_MEMORY_MMAP: + /* do nothing here, buffers are always allocated */ + if (b->count < 1 || dev->buffers_number < 1) + return 0; + + if (b->count > dev->buffers_number) + b->count = dev->buffers_number; + + /* make sure that outbufs_list contains buffers from 0 to used_buffers-1 + * actually, it will have been already populated via v4l2_loopback_init() + * at this point */ + if (list_empty(&dev->outbufs_list)) { + for (i = 0; i < dev->used_buffers; ++i) + list_add_tail(&dev->buffers[i].list_head, + &dev->outbufs_list); + } + + /* also, if dev->used_buffers is going to be decreased, we should remove + * out-of-range buffers from outbufs_list, and fix bufpos2index mapping */ + if (b->count < dev->used_buffers) { + struct v4l2l_buffer *pos, *n; + + list_for_each_entry_safe(pos, n, &dev->outbufs_list, + list_head) { + if (pos->buffer.index >= b->count) + list_del(&pos->list_head); + } + + /* after we update dev->used_buffers, buffers in outbufs_list will + * correspond to dev->write_position + [0;b->count-1] range */ + i = v4l2l_mod64(dev->write_position, b->count); + list_for_each_entry(pos, &dev->outbufs_list, + list_head) { + dev->bufpos2index[i % b->count] = + pos->buffer.index; + ++i; + } + } + + opener->buffers_number = b->count; + if (opener->buffers_number < dev->used_buffers) + dev->used_buffers = opener->buffers_number; + return 0; + default: + return -EINVAL; + } +} + +/* returns buffer asked for; + * give app as many buffers as it wants, if it less than MAX, + * but map them in our inner buffers + * called on VIDIOC_QUERYBUF + */ +static int vidioc_querybuf(struct file *file, void *fh, struct v4l2_buffer *b) +{ + enum v4l2_buf_type type; + int index; + struct v4l2_loopback_device *dev; + struct v4l2_loopback_opener *opener; + + MARK(); + + type = b->type; + index = b->index; + dev = v4l2loopback_getdevice(file); + opener = fh_to_opener(fh); + + if ((b->type != V4L2_BUF_TYPE_VIDEO_CAPTURE) && + (b->type != V4L2_BUF_TYPE_VIDEO_OUTPUT)) { + return -EINVAL; + } + if (b->index > max_buffers) + return -EINVAL; + + if (opener->timeout_image_io) + *b = dev->timeout_image_buffer.buffer; + else + *b = dev->buffers[b->index % dev->used_buffers].buffer; + + b->type = type; + b->index = index; + dprintkrw("buffer type: %d (of %d with size=%ld)\n", b->memory, + dev->buffers_number, dev->buffer_size); + + /* Hopefully fix 'DQBUF return bad index if queue bigger then 2 for capture' + https://github.com/umlaeute/v4l2loopback/issues/60 */ + b->flags &= ~V4L2_BUF_FLAG_DONE; + b->flags |= V4L2_BUF_FLAG_QUEUED; + + return 0; +} + +static void buffer_written(struct v4l2_loopback_device *dev, + struct v4l2l_buffer *buf) +{ + del_timer_sync(&dev->sustain_timer); + del_timer_sync(&dev->timeout_timer); + + spin_lock_bh(&dev->list_lock); + list_move_tail(&buf->list_head, &dev->outbufs_list); + spin_unlock_bh(&dev->list_lock); + + spin_lock_bh(&dev->lock); + dev->bufpos2index[v4l2l_mod64(dev->write_position, dev->used_buffers)] = + buf->buffer.index; + ++dev->write_position; + dev->reread_count = 0; + + check_timers(dev); + spin_unlock_bh(&dev->lock); +} + +/* put buffer to queue + * called on VIDIOC_QBUF + */ +static int vidioc_qbuf(struct file *file, void *fh, struct v4l2_buffer *buf) +{ + struct v4l2_loopback_device *dev; + struct v4l2_loopback_opener *opener; + struct v4l2l_buffer *b; + int index; + + dev = v4l2loopback_getdevice(file); + opener = fh_to_opener(fh); + + if (buf->index > max_buffers) + return -EINVAL; + if (opener->timeout_image_io) + return 0; + + index = buf->index % dev->used_buffers; + b = &dev->buffers[index]; + + switch (buf->type) { + case V4L2_BUF_TYPE_VIDEO_CAPTURE: + dprintkrw( + "qbuf(CAPTURE)#%d: buffer#%d @ %p type=%d bytesused=%d length=%d flags=%x field=%d timestamp=%lld.%06ld sequence=%d\n", + index, buf->index, buf, buf->type, buf->bytesused, + buf->length, buf->flags, buf->field, + (long long)buf->timestamp.tv_sec, + (long int)buf->timestamp.tv_usec, buf->sequence); + set_queued(b); + return 0; + case V4L2_BUF_TYPE_VIDEO_OUTPUT: + dprintkrw( + "qbuf(OUTPUT)#%d: buffer#%d @ %p type=%d bytesused=%d length=%d flags=%x field=%d timestamp=%lld.%06ld sequence=%d\n", + index, buf->index, buf, buf->type, buf->bytesused, + buf->length, buf->flags, buf->field, + (long long)buf->timestamp.tv_sec, + (long int)buf->timestamp.tv_usec, buf->sequence); + if ((!(b->buffer.flags & V4L2_BUF_FLAG_TIMESTAMP_COPY)) && + (buf->timestamp.tv_sec == 0 && buf->timestamp.tv_usec == 0)) + v4l2l_get_timestamp(&b->buffer); + else { + b->buffer.timestamp = buf->timestamp; + b->buffer.flags |= V4L2_BUF_FLAG_TIMESTAMP_COPY; + } + if (dev->pix_format_has_valid_sizeimage) { + if (buf->bytesused >= dev->pix_format.sizeimage) { + b->buffer.bytesused = dev->pix_format.sizeimage; + } else { +#if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 5, 0) + dev_warn_ratelimited( + &dev->vdev->dev, +#else + dprintkrw( +#endif + "warning queued output buffer bytesused too small %d < %d\n", + buf->bytesused, + dev->pix_format.sizeimage); + b->buffer.bytesused = buf->bytesused; + } + } else { + b->buffer.bytesused = buf->bytesused; + } + + set_done(b); + buffer_written(dev, b); + + /* Hopefully fix 'DQBUF return bad index if queue bigger then 2 for capture' + https://github.com/umlaeute/v4l2loopback/issues/60 */ + buf->flags &= ~V4L2_BUF_FLAG_DONE; + buf->flags |= V4L2_BUF_FLAG_QUEUED; + + wake_up_all(&dev->read_event); + return 0; + default: + return -EINVAL; + } +} + +static int can_read(struct v4l2_loopback_device *dev, + struct v4l2_loopback_opener *opener) +{ + int ret; + + spin_lock_bh(&dev->lock); + check_timers(dev); + ret = dev->write_position > opener->read_position || + dev->reread_count > opener->reread_count || dev->timeout_happened; + spin_unlock_bh(&dev->lock); + return ret; +} + +static int get_capture_buffer(struct file *file) +{ + struct v4l2_loopback_device *dev = v4l2loopback_getdevice(file); + struct v4l2_loopback_opener *opener = fh_to_opener(file->private_data); + int pos, ret; + int timeout_happened; + + if ((file->f_flags & O_NONBLOCK) && + (dev->write_position <= opener->read_position && + dev->reread_count <= opener->reread_count && + !dev->timeout_happened)) + return -EAGAIN; + wait_event_interruptible(dev->read_event, can_read(dev, opener)); + + spin_lock_bh(&dev->lock); + if (dev->write_position == opener->read_position) { + if (dev->reread_count > opener->reread_count + 2) + opener->reread_count = dev->reread_count - 1; + ++opener->reread_count; + pos = v4l2l_mod64(opener->read_position + dev->used_buffers - 1, + dev->used_buffers); + } else { + opener->reread_count = 0; + if (dev->write_position > + opener->read_position + dev->used_buffers) + opener->read_position = dev->write_position - 1; + pos = v4l2l_mod64(opener->read_position, dev->used_buffers); + ++opener->read_position; + } + timeout_happened = dev->timeout_happened; + dev->timeout_happened = 0; + spin_unlock_bh(&dev->lock); + + ret = dev->bufpos2index[pos]; + if (timeout_happened) { + if (ret < 0) { + dprintk("trying to return not mapped buf[%d]\n", ret); + return -EFAULT; + } + /* although allocated on-demand, timeout_image is freed only + * in free_buffers(), so we don't need to worry about it being + * deallocated suddenly */ + memcpy(dev->image + dev->buffers[ret].buffer.m.offset, + dev->timeout_image, dev->buffer_size); + } + return ret; +} + +/* put buffer to dequeue + * called on VIDIOC_DQBUF + */ +static int vidioc_dqbuf(struct file *file, void *fh, struct v4l2_buffer *buf) +{ + struct v4l2_loopback_device *dev; + struct v4l2_loopback_opener *opener; + int index; + struct v4l2l_buffer *b; + + dev = v4l2loopback_getdevice(file); + opener = fh_to_opener(fh); + if (opener->timeout_image_io) { + *buf = dev->timeout_image_buffer.buffer; + return 0; + } + + switch (buf->type) { + case V4L2_BUF_TYPE_VIDEO_CAPTURE: + index = get_capture_buffer(file); + if (index < 0) + return index; + dprintkrw("capture DQBUF pos: %lld index: %d\n", + (long long)(opener->read_position - 1), index); + if (!(dev->buffers[index].buffer.flags & + V4L2_BUF_FLAG_MAPPED)) { + dprintk("trying to return not mapped buf[%d]\n", index); + return -EINVAL; + } + unset_flags(&dev->buffers[index]); + *buf = dev->buffers[index].buffer; + dprintkrw( + "dqbuf(CAPTURE)#%d: buffer#%d @ %p type=%d bytesused=%d length=%d flags=%x field=%d timestamp=%lld.%06ld sequence=%d\n", + index, buf->index, buf, buf->type, buf->bytesused, + buf->length, buf->flags, buf->field, + (long long)buf->timestamp.tv_sec, + (long int)buf->timestamp.tv_usec, buf->sequence); + return 0; + case V4L2_BUF_TYPE_VIDEO_OUTPUT: + spin_lock_bh(&dev->list_lock); + + b = list_entry(dev->outbufs_list.prev, struct v4l2l_buffer, + list_head); + list_move_tail(&b->list_head, &dev->outbufs_list); + + spin_unlock_bh(&dev->list_lock); + dprintkrw("output DQBUF index: %d\n", b->buffer.index); + unset_flags(b); + *buf = b->buffer; + buf->type = V4L2_BUF_TYPE_VIDEO_OUTPUT; + dprintkrw( + "dqbuf(OUTPUT)#%d: buffer#%d @ %p type=%d bytesused=%d length=%d flags=%x field=%d timestamp=%lld.%06ld sequence=%d\n", + index, buf->index, buf, buf->type, buf->bytesused, + buf->length, buf->flags, buf->field, + (long long)buf->timestamp.tv_sec, + (long int)buf->timestamp.tv_usec, buf->sequence); + return 0; + default: + return -EINVAL; + } +} + +/* ------------- STREAMING ------------------- */ + +/* start streaming + * called on VIDIOC_STREAMON + */ +static int vidioc_streamon(struct file *file, void *fh, enum v4l2_buf_type type) +{ + struct v4l2_loopback_device *dev; + struct v4l2_loopback_opener *opener; + MARK(); + + dev = v4l2loopback_getdevice(file); + opener = fh_to_opener(fh); + + switch (type) { + case V4L2_BUF_TYPE_VIDEO_OUTPUT: + if (!dev->ready_for_capture) { + int ret = allocate_buffers(dev); + if (ret < 0) + return ret; + } + opener->type = WRITER; + dev->ready_for_output = 0; + dev->ready_for_capture++; + return 0; + case V4L2_BUF_TYPE_VIDEO_CAPTURE: + if (!dev->ready_for_capture) + return -EIO; + if (dev->active_readers > 0) + return -EBUSY; + opener->type = READER; + dev->active_readers++; + client_usage_queue_event(dev->vdev); + return 0; + default: + return -EINVAL; + } + return -EINVAL; +} + +/* stop streaming + * called on VIDIOC_STREAMOFF + */ +static int vidioc_streamoff(struct file *file, void *fh, + enum v4l2_buf_type type) +{ + struct v4l2_loopback_device *dev; + struct v4l2_loopback_opener *opener; + + MARK(); + dprintk("%d\n", type); + + dev = v4l2loopback_getdevice(file); + opener = fh_to_opener(fh); + switch (type) { + case V4L2_BUF_TYPE_VIDEO_OUTPUT: + if (dev->ready_for_capture > 0) + dev->ready_for_capture--; + return 0; + case V4L2_BUF_TYPE_VIDEO_CAPTURE: + if (opener->type == READER) { + opener->type = 0; + dev->active_readers--; + client_usage_queue_event(dev->vdev); + } + return 0; + default: + return -EINVAL; + } + return -EINVAL; +} + +#ifdef CONFIG_VIDEO_V4L1_COMPAT +static int vidiocgmbuf(struct file *file, void *fh, struct video_mbuf *p) +{ + struct v4l2_loopback_device *dev; + MARK(); + + dev = v4l2loopback_getdevice(file); + p->frames = dev->buffers_number; + p->offsets[0] = 0; + p->offsets[1] = 0; + p->size = dev->buffer_size; + return 0; +} +#endif + +static void client_usage_queue_event(struct video_device *vdev) +{ + struct v4l2_event ev; + struct v4l2_loopback_device *dev; + + dev = container_of(vdev->v4l2_dev, struct v4l2_loopback_device, + v4l2_dev); + + memset(&ev, 0, sizeof(ev)); + ev.type = V4L2_EVENT_PRI_CLIENT_USAGE; + ((struct v4l2_event_client_usage *)&ev.u)->count = dev->active_readers; + + v4l2_event_queue(vdev, &ev); +} + +static int client_usage_ops_add(struct v4l2_subscribed_event *sev, + unsigned elems) +{ + if (!(sev->flags & V4L2_EVENT_SUB_FL_SEND_INITIAL)) + return 0; + + client_usage_queue_event(sev->fh->vdev); + return 0; +} + +static void client_usage_ops_replace(struct v4l2_event *old, + const struct v4l2_event *new) +{ + *((struct v4l2_event_client_usage *)&old->u) = + *((struct v4l2_event_client_usage *)&new->u); +} + +static void client_usage_ops_merge(const struct v4l2_event *old, + struct v4l2_event *new) +{ + *((struct v4l2_event_client_usage *)&new->u) = + *((struct v4l2_event_client_usage *)&old->u); +} + +const struct v4l2_subscribed_event_ops client_usage_ops = { + .add = client_usage_ops_add, + .replace = client_usage_ops_replace, + .merge = client_usage_ops_merge, +}; + +static int vidioc_subscribe_event(struct v4l2_fh *fh, + const struct v4l2_event_subscription *sub) +{ + switch (sub->type) { + case V4L2_EVENT_CTRL: + return v4l2_ctrl_subscribe_event(fh, sub); + case V4L2_EVENT_PRI_CLIENT_USAGE: + return v4l2_event_subscribe(fh, sub, 0, &client_usage_ops); + } + + return -EINVAL; +} + +/* file operations */ +static void vm_open(struct vm_area_struct *vma) +{ + struct v4l2l_buffer *buf; + MARK(); + + buf = vma->vm_private_data; + buf->use_count++; + + buf->buffer.flags |= V4L2_BUF_FLAG_MAPPED; +} + +static void vm_close(struct vm_area_struct *vma) +{ + struct v4l2l_buffer *buf; + MARK(); + + buf = vma->vm_private_data; + buf->use_count--; + + if (buf->use_count <= 0) + buf->buffer.flags &= ~V4L2_BUF_FLAG_MAPPED; +} + +static struct vm_operations_struct vm_ops = { + .open = vm_open, + .close = vm_close, +}; + +static int v4l2_loopback_mmap(struct file *file, struct vm_area_struct *vma) +{ + u8 *addr; + unsigned long start; + unsigned long size; + struct v4l2_loopback_device *dev; + struct v4l2_loopback_opener *opener; + struct v4l2l_buffer *buffer = NULL; + MARK(); + + start = (unsigned long)vma->vm_start; + size = (unsigned long)(vma->vm_end - vma->vm_start); + + dev = v4l2loopback_getdevice(file); + opener = fh_to_opener(file->private_data); + + if (size > dev->buffer_size) { + dprintk("userspace tries to mmap too much, fail\n"); + return -EINVAL; + } + if (opener->timeout_image_io) { + /* we are going to map the timeout_image_buffer */ + if ((vma->vm_pgoff << PAGE_SHIFT) != + dev->buffer_size * MAX_BUFFERS) { + dprintk("invalid mmap offset for timeout_image_io mode\n"); + return -EINVAL; + } + } else if ((vma->vm_pgoff << PAGE_SHIFT) > + dev->buffer_size * (dev->buffers_number - 1)) { + dprintk("userspace tries to mmap too far, fail\n"); + return -EINVAL; + } + + /* FIXXXXXME: allocation should not happen here! */ + if (NULL == dev->image) + if (allocate_buffers(dev) < 0) + return -EINVAL; + + if (opener->timeout_image_io) { + buffer = &dev->timeout_image_buffer; + addr = dev->timeout_image; + } else { + int i; + for (i = 0; i < dev->buffers_number; ++i) { + buffer = &dev->buffers[i]; + if ((buffer->buffer.m.offset >> PAGE_SHIFT) == + vma->vm_pgoff) + break; + } + + if (i >= dev->buffers_number) + return -EINVAL; + + addr = dev->image + (vma->vm_pgoff << PAGE_SHIFT); + } + + while (size > 0) { + struct page *page; + + page = vmalloc_to_page(addr); + + if (vm_insert_page(vma, start, page) < 0) + return -EAGAIN; + + start += PAGE_SIZE; + addr += PAGE_SIZE; + size -= PAGE_SIZE; + } + + vma->vm_ops = &vm_ops; + vma->vm_private_data = buffer; + + vm_open(vma); + + MARK(); + return 0; +} + +static unsigned int v4l2_loopback_poll(struct file *file, + struct poll_table_struct *pts) +{ + struct v4l2_loopback_opener *opener; + struct v4l2_loopback_device *dev; + __poll_t req_events = poll_requested_events(pts); + int ret_mask = 0; + MARK(); + + opener = fh_to_opener(file->private_data); + dev = v4l2loopback_getdevice(file); + + if (req_events & POLLPRI) { + if (!v4l2_event_pending(&opener->fh)) + poll_wait(file, &opener->fh.wait, pts); + if (v4l2_event_pending(&opener->fh)) { + ret_mask |= POLLPRI; + if (!(req_events & DEFAULT_POLLMASK)) + return ret_mask; + } + } + + switch (opener->type) { + case WRITER: + ret_mask |= POLLOUT | POLLWRNORM; + break; + case READER: + if (!can_read(dev, opener)) { + if (ret_mask) + return ret_mask; + poll_wait(file, &dev->read_event, pts); + } + if (can_read(dev, opener)) + ret_mask |= POLLIN | POLLRDNORM; + if (v4l2_event_pending(&opener->fh)) + ret_mask |= POLLPRI; + break; + default: + break; + } + + MARK(); + return ret_mask; +} + +/* do not want to limit device opens, it can be as many readers as user want, + * writers are limited by means of setting writer field */ +static int v4l2_loopback_open(struct file *file) +{ + struct v4l2_loopback_device *dev; + struct v4l2_loopback_opener *opener; + MARK(); + dev = v4l2loopback_getdevice(file); + if (dev->open_count.counter >= dev->max_openers) + return -EBUSY; + /* kfree on close */ + opener = kzalloc(sizeof(*opener), GFP_KERNEL); + if (opener == NULL) + return -ENOMEM; + + atomic_inc(&dev->open_count); + + opener->timeout_image_io = dev->timeout_image_io; + if (opener->timeout_image_io) { + int r = allocate_timeout_image(dev); + + if (r < 0) { + dprintk("timeout image allocation failed\n"); + + atomic_dec(&dev->open_count); + + kfree(opener); + return r; + } + } + + v4l2_fh_init(&opener->fh, video_devdata(file)); + file->private_data = &opener->fh; + + v4l2_fh_add(&opener->fh); + dprintk("opened dev:%p with image:%p\n", dev, dev ? dev->image : NULL); + MARK(); + return 0; +} + +static int v4l2_loopback_close(struct file *file) +{ + struct v4l2_loopback_opener *opener; + struct v4l2_loopback_device *dev; + int is_writer = 0, is_reader = 0; + MARK(); + + opener = fh_to_opener(file->private_data); + dev = v4l2loopback_getdevice(file); + + if (WRITER == opener->type) + is_writer = 1; + if (READER == opener->type) + is_reader = 1; + + atomic_dec(&dev->open_count); + if (dev->open_count.counter == 0) { + del_timer_sync(&dev->sustain_timer); + del_timer_sync(&dev->timeout_timer); + } + try_free_buffers(dev); + + v4l2_fh_del(&opener->fh); + v4l2_fh_exit(&opener->fh); + + kfree(opener); + if (is_writer) + dev->ready_for_output = 1; + if (is_reader) { + dev->active_readers--; + client_usage_queue_event(dev->vdev); + } + MARK(); + return 0; +} + +static ssize_t v4l2_loopback_read(struct file *file, char __user *buf, + size_t count, loff_t *ppos) +{ + int read_index; + struct v4l2_loopback_device *dev; + struct v4l2_buffer *b; + MARK(); + + dev = v4l2loopback_getdevice(file); + + read_index = get_capture_buffer(file); + if (read_index < 0) + return read_index; + if (count > dev->buffer_size) + count = dev->buffer_size; + b = &dev->buffers[read_index].buffer; + if (count > b->bytesused) + count = b->bytesused; + if (copy_to_user((void *)buf, (void *)(dev->image + b->m.offset), + count)) { + printk(KERN_ERR + "v4l2-loopback: failed copy_to_user() in read buf\n"); + return -EFAULT; + } + dprintkrw("leave v4l2_loopback_read()\n"); + return count; +} + +static ssize_t v4l2_loopback_write(struct file *file, const char __user *buf, + size_t count, loff_t *ppos) +{ + struct v4l2_loopback_opener *opener; + struct v4l2_loopback_device *dev; + int write_index; + struct v4l2_buffer *b; + int err = 0; + + MARK(); + + dev = v4l2loopback_getdevice(file); + opener = fh_to_opener(file->private_data); + + if (UNNEGOTIATED == opener->type) { + spin_lock(&dev->lock); + + if (dev->ready_for_output) { + err = vidioc_streamon(file, file->private_data, + V4L2_BUF_TYPE_VIDEO_OUTPUT); + } + + spin_unlock(&dev->lock); + + if (err < 0) + return err; + } + + if (WRITER != opener->type) + return -EINVAL; + + if (!dev->ready_for_capture) { + int ret = allocate_buffers(dev); + if (ret < 0) + return ret; + dev->ready_for_capture = 1; + } + dprintkrw("v4l2_loopback_write() trying to write %zu bytes\n", count); + if (count > dev->buffer_size) + count = dev->buffer_size; + + write_index = v4l2l_mod64(dev->write_position, dev->used_buffers); + b = &dev->buffers[write_index].buffer; + + if (copy_from_user((void *)(dev->image + b->m.offset), (void *)buf, + count)) { + printk(KERN_ERR + "v4l2-loopback: failed copy_from_user() in write buf, could not write %zu\n", + count); + return -EFAULT; + } + v4l2l_get_timestamp(b); + b->bytesused = count; + b->sequence = dev->write_position; + buffer_written(dev, &dev->buffers[write_index]); + wake_up_all(&dev->read_event); + dprintkrw("leave v4l2_loopback_write()\n"); + return count; +} + +/* init functions */ +/* frees buffers, if already allocated */ +static void free_buffers(struct v4l2_loopback_device *dev) +{ + MARK(); + dprintk("freeing image@%p for dev:%p\n", dev ? dev->image : NULL, dev); + if (!dev) + return; + if (dev->image) { + vfree(dev->image); + dev->image = NULL; + } + if (dev->timeout_image) { + vfree(dev->timeout_image); + dev->timeout_image = NULL; + } + dev->imagesize = 0; +} +/* frees buffers, if they are no longer needed */ +static void try_free_buffers(struct v4l2_loopback_device *dev) +{ + MARK(); + if (0 == dev->open_count.counter && !dev->keep_format) { + free_buffers(dev); + dev->ready_for_capture = 0; + dev->buffer_size = 0; + dev->write_position = 0; + } +} +/* allocates buffers, if buffer_size is set */ +static int allocate_buffers(struct v4l2_loopback_device *dev) +{ + int err; + + MARK(); + /* vfree on close file operation in case no open handles left */ + + if (dev->buffer_size < 1 || dev->buffers_number < 1) + return -EINVAL; + + if ((__LONG_MAX__ / dev->buffer_size) < dev->buffers_number) + return -ENOSPC; + + if (dev->image) { + dprintk("allocating buffers again: %ld %ld\n", + dev->buffer_size * dev->buffers_number, dev->imagesize); + /* FIXME: prevent double allocation more intelligently! */ + if (dev->buffer_size * dev->buffers_number == dev->imagesize) + return 0; + + /* check whether the total number of readers/writers is <=1 */ + if ((dev->ready_for_capture + dev->active_readers) <= 1) + free_buffers(dev); + else + return -EINVAL; + } + + dev->imagesize = (unsigned long)dev->buffer_size * + (unsigned long)dev->buffers_number; + + dprintk("allocating %ld = %ldx%d\n", dev->imagesize, dev->buffer_size, + dev->buffers_number); + err = -ENOMEM; + + if (dev->timeout_jiffies > 0) { + err = allocate_timeout_image(dev); + if (err < 0) + goto error; + } + + dev->image = vmalloc(dev->imagesize); + if (dev->image == NULL) + goto error; + + dprintk("vmallocated %ld bytes\n", dev->imagesize); + MARK(); + + init_buffers(dev); + return 0; + +error: + free_buffers(dev); + return err; +} + +/* init inner buffers, they are capture mode and flags are set as + * for capture mod buffers */ +static void init_buffers(struct v4l2_loopback_device *dev) +{ + int i; + int buffer_size; + int bytesused; + MARK(); + + buffer_size = dev->buffer_size; + bytesused = dev->pix_format.sizeimage; + for (i = 0; i < dev->buffers_number; ++i) { + struct v4l2_buffer *b = &dev->buffers[i].buffer; + b->index = i; + b->bytesused = bytesused; + b->length = buffer_size; + b->field = V4L2_FIELD_NONE; + b->flags = 0; + b->m.offset = i * buffer_size; + b->memory = V4L2_MEMORY_MMAP; + b->sequence = 0; + b->timestamp.tv_sec = 0; + b->timestamp.tv_usec = 0; + b->type = V4L2_BUF_TYPE_VIDEO_CAPTURE; + + v4l2l_get_timestamp(b); + } + dev->timeout_image_buffer = dev->buffers[0]; + dev->timeout_image_buffer.buffer.m.offset = MAX_BUFFERS * buffer_size; + MARK(); +} + +static int allocate_timeout_image(struct v4l2_loopback_device *dev) +{ + MARK(); + if (dev->buffer_size <= 0) { + dev->timeout_image_io = 0; + return -EINVAL; + } + + if (dev->timeout_image == NULL) { + dev->timeout_image = vzalloc(dev->buffer_size); + if (dev->timeout_image == NULL) { + dev->timeout_image_io = 0; + return -ENOMEM; + } + } + return 0; +} + +/* fills and register video device */ +static void init_vdev(struct video_device *vdev, int nr) +{ + MARK(); + +#ifdef V4L2LOOPBACK_WITH_STD + vdev->tvnorms = V4L2_STD_ALL; +#endif /* V4L2LOOPBACK_WITH_STD */ + + vdev->vfl_type = VFL_TYPE_VIDEO; + vdev->fops = &v4l2_loopback_fops; + vdev->ioctl_ops = &v4l2_loopback_ioctl_ops; + vdev->release = &video_device_release; + vdev->minor = -1; +#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 7, 0) + vdev->device_caps = V4L2_CAP_DEVICE_CAPS | V4L2_CAP_VIDEO_CAPTURE | + V4L2_CAP_VIDEO_OUTPUT | V4L2_CAP_READWRITE | + V4L2_CAP_STREAMING; +#endif + + if (debug > 1) + vdev->dev_debug = V4L2_DEV_DEBUG_IOCTL | + V4L2_DEV_DEBUG_IOCTL_ARG; + + vdev->vfl_dir = VFL_DIR_M2M; + + MARK(); +} + +/* init default capture parameters, only fps may be changed in future */ +static void init_capture_param(struct v4l2_captureparm *capture_param) +{ + MARK(); + capture_param->capability = 0; + capture_param->capturemode = 0; + capture_param->extendedmode = 0; + capture_param->readbuffers = max_buffers; + capture_param->timeperframe.numerator = 1; + capture_param->timeperframe.denominator = 30; +} + +static void check_timers(struct v4l2_loopback_device *dev) +{ + if (!dev->ready_for_capture) + return; + + if (dev->timeout_jiffies > 0 && !timer_pending(&dev->timeout_timer)) + mod_timer(&dev->timeout_timer, jiffies + dev->timeout_jiffies); + if (dev->sustain_framerate && !timer_pending(&dev->sustain_timer)) + mod_timer(&dev->sustain_timer, + jiffies + dev->frame_jiffies * 3 / 2); +} +#ifdef HAVE_TIMER_SETUP +static void sustain_timer_clb(struct timer_list *t) +{ + struct v4l2_loopback_device *dev = from_timer(dev, t, sustain_timer); +#else +static void sustain_timer_clb(unsigned long nr) +{ + struct v4l2_loopback_device *dev = + idr_find(&v4l2loopback_index_idr, nr); +#endif + spin_lock(&dev->lock); + if (dev->sustain_framerate) { + dev->reread_count++; + dprintkrw("reread: %lld %d\n", (long long)dev->write_position, + dev->reread_count); + if (dev->reread_count == 1) + mod_timer(&dev->sustain_timer, + jiffies + max(1UL, dev->frame_jiffies / 2)); + else + mod_timer(&dev->sustain_timer, + jiffies + dev->frame_jiffies); + wake_up_all(&dev->read_event); + } + spin_unlock(&dev->lock); +} +#ifdef HAVE_TIMER_SETUP +static void timeout_timer_clb(struct timer_list *t) +{ + struct v4l2_loopback_device *dev = from_timer(dev, t, timeout_timer); +#else +static void timeout_timer_clb(unsigned long nr) +{ + struct v4l2_loopback_device *dev = + idr_find(&v4l2loopback_index_idr, nr); +#endif + spin_lock(&dev->lock); + if (dev->timeout_jiffies > 0) { + dev->timeout_happened = 1; + mod_timer(&dev->timeout_timer, jiffies + dev->timeout_jiffies); + wake_up_all(&dev->read_event); + } + spin_unlock(&dev->lock); +} + +/* init loopback main structure */ +#define DEFAULT_FROM_CONF(confmember, default_condition, default_value) \ + ((conf) ? \ + ((conf->confmember default_condition) ? (default_value) : \ + (conf->confmember)) : \ + default_value) + +static int v4l2_loopback_add(struct v4l2_loopback_config *conf, int *ret_nr) +{ + struct v4l2_loopback_device *dev; + struct v4l2_ctrl_handler *hdl; + struct v4l2loopback_private *vdev_priv = NULL; + + int err = -ENOMEM; + + u32 _width = V4L2LOOPBACK_SIZE_DEFAULT_WIDTH; + u32 _height = V4L2LOOPBACK_SIZE_DEFAULT_HEIGHT; + + u32 _min_width = DEFAULT_FROM_CONF(min_width, + < V4L2LOOPBACK_SIZE_MIN_WIDTH, + V4L2LOOPBACK_SIZE_MIN_WIDTH); + u32 _min_height = DEFAULT_FROM_CONF(min_height, + < V4L2LOOPBACK_SIZE_MIN_HEIGHT, + V4L2LOOPBACK_SIZE_MIN_HEIGHT); + u32 _max_width = DEFAULT_FROM_CONF(max_width, < _min_width, max_width); + u32 _max_height = + DEFAULT_FROM_CONF(max_height, < _min_height, max_height); + bool _announce_all_caps = (conf && conf->announce_all_caps >= 0) ? + (conf->announce_all_caps) : + V4L2LOOPBACK_DEFAULT_EXCLUSIVECAPS; + int _max_buffers = DEFAULT_FROM_CONF(max_buffers, <= 0, max_buffers); + int _max_openers = DEFAULT_FROM_CONF(max_openers, <= 0, max_openers); + + int nr = -1; + + _announce_all_caps = (!!_announce_all_caps); + + if (conf) { + const int output_nr = conf->output_nr; +#ifdef SPLIT_DEVICES + const int capture_nr = conf->capture_nr; +#else + const int capture_nr = output_nr; +#endif + if (capture_nr >= 0 && output_nr == capture_nr) { + nr = output_nr; + } else if (capture_nr < 0 && output_nr < 0) { + nr = -1; + } else if (capture_nr < 0) { + nr = output_nr; + } else if (output_nr < 0) { + nr = capture_nr; + } else { + printk(KERN_ERR + "split OUTPUT and CAPTURE devices not yet supported."); + printk(KERN_INFO + "both devices must have the same number (%d != %d).", + output_nr, capture_nr); + return -EINVAL; + } + } + + if (idr_find(&v4l2loopback_index_idr, nr)) + return -EEXIST; + + dprintk("creating v4l2loopback-device #%d\n", nr); + dev = kzalloc(sizeof(*dev), GFP_KERNEL); + if (!dev) + return -ENOMEM; + + /* allocate id, if @id >= 0, we're requesting that specific id */ + if (nr >= 0) { + err = idr_alloc(&v4l2loopback_index_idr, dev, nr, nr + 1, + GFP_KERNEL); + if (err == -ENOSPC) + err = -EEXIST; + } else { + err = idr_alloc(&v4l2loopback_index_idr, dev, 0, 0, GFP_KERNEL); + } + if (err < 0) + goto out_free_dev; + nr = err; + err = -ENOMEM; + + if (conf && conf->card_label[0]) { + snprintf(dev->card_label, sizeof(dev->card_label), "%s", + conf->card_label); + } else { + snprintf(dev->card_label, sizeof(dev->card_label), + "Dummy video device (0x%04X)", nr); + } + snprintf(dev->v4l2_dev.name, sizeof(dev->v4l2_dev.name), + "v4l2loopback-%03d", nr); + + err = v4l2_device_register(NULL, &dev->v4l2_dev); + if (err) + goto out_free_idr; + MARK(); + + dev->vdev = video_device_alloc(); + if (dev->vdev == NULL) { + err = -ENOMEM; + goto out_unregister; + } + + vdev_priv = kzalloc(sizeof(struct v4l2loopback_private), GFP_KERNEL); + if (vdev_priv == NULL) { + err = -ENOMEM; + goto out_unregister; + } + + video_set_drvdata(dev->vdev, vdev_priv); + if (video_get_drvdata(dev->vdev) == NULL) { + err = -ENOMEM; + goto out_unregister; + } + + MARK(); + snprintf(dev->vdev->name, sizeof(dev->vdev->name), "%s", + dev->card_label); + + vdev_priv->device_nr = nr; + + init_vdev(dev->vdev, nr); + dev->vdev->v4l2_dev = &dev->v4l2_dev; + init_capture_param(&dev->capture_param); + err = set_timeperframe(dev, &dev->capture_param.timeperframe); + if (err) + goto out_unregister; + dev->keep_format = 0; + dev->sustain_framerate = 0; + + dev->announce_all_caps = _announce_all_caps; + dev->min_width = _min_width; + dev->min_height = _min_height; + dev->max_width = _max_width; + dev->max_height = _max_height; + dev->max_openers = _max_openers; + dev->buffers_number = dev->used_buffers = _max_buffers; + + dev->write_position = 0; + + MARK(); + spin_lock_init(&dev->lock); + spin_lock_init(&dev->list_lock); + INIT_LIST_HEAD(&dev->outbufs_list); + if (list_empty(&dev->outbufs_list)) { + int i; + + for (i = 0; i < dev->used_buffers; ++i) + list_add_tail(&dev->buffers[i].list_head, + &dev->outbufs_list); + } + memset(dev->bufpos2index, 0, sizeof(dev->bufpos2index)); + atomic_set(&dev->open_count, 0); + dev->ready_for_capture = 0; + dev->ready_for_output = 1; + + dev->buffer_size = 0; + dev->image = NULL; + dev->imagesize = 0; +#ifdef HAVE_TIMER_SETUP + timer_setup(&dev->sustain_timer, sustain_timer_clb, 0); + timer_setup(&dev->timeout_timer, timeout_timer_clb, 0); +#else + setup_timer(&dev->sustain_timer, sustain_timer_clb, nr); + setup_timer(&dev->timeout_timer, timeout_timer_clb, nr); +#endif + dev->reread_count = 0; + dev->timeout_jiffies = 0; + dev->timeout_image = NULL; + dev->timeout_happened = 0; + + hdl = &dev->ctrl_handler; + err = v4l2_ctrl_handler_init(hdl, 4); + if (err) + goto out_unregister; + v4l2_ctrl_new_custom(hdl, &v4l2loopback_ctrl_keepformat, NULL); + v4l2_ctrl_new_custom(hdl, &v4l2loopback_ctrl_sustainframerate, NULL); + v4l2_ctrl_new_custom(hdl, &v4l2loopback_ctrl_timeout, NULL); + v4l2_ctrl_new_custom(hdl, &v4l2loopback_ctrl_timeoutimageio, NULL); + if (hdl->error) { + err = hdl->error; + goto out_free_handler; + } + dev->v4l2_dev.ctrl_handler = hdl; + + err = v4l2_ctrl_handler_setup(hdl); + if (err) + goto out_free_handler; + + /* FIXME set buffers to 0 */ + + /* Set initial format */ + if (_width < _min_width) + _width = _min_width; + if (_width > _max_width) + _width = _max_width; + if (_height < _min_height) + _height = _min_height; + if (_height > _max_height) + _height = _max_height; + + dev->pix_format.width = _width; + dev->pix_format.height = _height; + dev->pix_format.pixelformat = formats[0].fourcc; + dev->pix_format.colorspace = + V4L2_COLORSPACE_DEFAULT; /* do we need to set this ? */ + dev->pix_format.field = V4L2_FIELD_NONE; + + dev->buffer_size = PAGE_ALIGN(dev->pix_format.sizeimage); + dprintk("buffer_size = %ld (=%d)\n", dev->buffer_size, + dev->pix_format.sizeimage); + + if (dev->buffer_size && ((err = allocate_buffers(dev)) < 0)) + goto out_free_handler; + + init_waitqueue_head(&dev->read_event); + + /* register the device -> it creates /dev/video* */ + if (video_register_device(dev->vdev, VFL_TYPE_VIDEO, nr) < 0) { + printk(KERN_ERR + "v4l2loopback: failed video_register_device()\n"); + err = -EFAULT; + goto out_free_device; + } + v4l2loopback_create_sysfs(dev->vdev); + + MARK(); + if (ret_nr) + *ret_nr = dev->vdev->num; + return 0; + +out_free_device: + video_device_release(dev->vdev); +out_free_handler: + v4l2_ctrl_handler_free(&dev->ctrl_handler); +out_unregister: + video_set_drvdata(dev->vdev, NULL); + if (vdev_priv != NULL) + kfree(vdev_priv); + v4l2_device_unregister(&dev->v4l2_dev); +out_free_idr: + idr_remove(&v4l2loopback_index_idr, nr); +out_free_dev: + kfree(dev); + return err; +} + +static void v4l2_loopback_remove(struct v4l2_loopback_device *dev) +{ + free_buffers(dev); + v4l2loopback_remove_sysfs(dev->vdev); + kfree(video_get_drvdata(dev->vdev)); + video_unregister_device(dev->vdev); + v4l2_device_unregister(&dev->v4l2_dev); + v4l2_ctrl_handler_free(&dev->ctrl_handler); + kfree(dev); +} + +static long v4l2loopback_control_ioctl(struct file *file, unsigned int cmd, + unsigned long parm) +{ + struct v4l2_loopback_device *dev; + struct v4l2_loopback_config conf; + struct v4l2_loopback_config *confptr = &conf; + int device_nr, capture_nr, output_nr; + int ret; + + ret = mutex_lock_killable(&v4l2loopback_ctl_mutex); + if (ret) + return ret; + + ret = -EINVAL; + switch (cmd) { + default: + ret = -ENOSYS; + break; + /* add a v4l2loopback device (pair), based on the user-provided specs */ + case V4L2LOOPBACK_CTL_ADD: + if (parm) { + if ((ret = copy_from_user(&conf, (void *)parm, + sizeof(conf))) < 0) + break; + } else + confptr = NULL; + ret = v4l2_loopback_add(confptr, &device_nr); + if (ret >= 0) + ret = device_nr; + break; + /* remove a v4l2loopback device (both capture and output) */ + case V4L2LOOPBACK_CTL_REMOVE: + ret = v4l2loopback_lookup((int)parm, &dev); + if (ret >= 0 && dev) { + int nr = ret; + ret = -EBUSY; + if (dev->open_count.counter > 0) + break; + idr_remove(&v4l2loopback_index_idr, nr); + v4l2_loopback_remove(dev); + ret = 0; + }; + break; + /* get information for a loopback device. + * this is mostly about limits (which cannot be queried directly with VIDIOC_G_FMT and friends + */ + case V4L2LOOPBACK_CTL_QUERY: + if (!parm) + break; + if ((ret = copy_from_user(&conf, (void *)parm, sizeof(conf))) < + 0) + break; + capture_nr = output_nr = conf.output_nr; +#ifdef SPLIT_DEVICES + capture_nr = conf.capture_nr; +#endif + device_nr = (output_nr < 0) ? capture_nr : output_nr; + MARK(); + /* get the device from either capture_nr or output_nr (whatever is valid) */ + if ((ret = v4l2loopback_lookup(device_nr, &dev)) < 0) + break; + MARK(); + /* if we got the device from output_nr and there is a valid capture_nr, + * make sure that both refer to the same device (or bail out) + */ + if ((device_nr != capture_nr) && (capture_nr >= 0) && + ((ret = v4l2loopback_lookup(capture_nr, 0)) < 0)) + break; + MARK(); + /* if otoh, we got the device from capture_nr and there is a valid output_nr, + * make sure that both refer to the same device (or bail out) + */ + if ((device_nr != output_nr) && (output_nr >= 0) && + ((ret = v4l2loopback_lookup(output_nr, 0)) < 0)) + break; + MARK(); + + /* v4l2_loopback_config identified a single device, so fetch the data */ + snprintf(conf.card_label, sizeof(conf.card_label), "%s", + dev->card_label); + MARK(); + conf.output_nr = dev->vdev->num; +#ifdef SPLIT_DEVICES + conf.capture_nr = dev->vdev->num; +#endif + conf.min_width = dev->min_width; + conf.min_height = dev->min_height; + conf.max_width = dev->max_width; + conf.max_height = dev->max_height; + conf.announce_all_caps = dev->announce_all_caps; + conf.max_buffers = dev->buffers_number; + conf.max_openers = dev->max_openers; + conf.debug = debug; + MARK(); + if (copy_to_user((void *)parm, &conf, sizeof(conf))) { + ret = -EFAULT; + break; + } + MARK(); + ret = 0; + ; + break; + } + + MARK(); + mutex_unlock(&v4l2loopback_ctl_mutex); + MARK(); + return ret; +} + +/* LINUX KERNEL */ + +static const struct file_operations v4l2loopback_ctl_fops = { + // clang-format off + .owner = THIS_MODULE, + .open = nonseekable_open, + .unlocked_ioctl = v4l2loopback_control_ioctl, + .compat_ioctl = v4l2loopback_control_ioctl, + .llseek = noop_llseek, + // clang-format on +}; + +static struct miscdevice v4l2loopback_misc = { + // clang-format off + .minor = MISC_DYNAMIC_MINOR, + .name = "v4l2loopback", + .fops = &v4l2loopback_ctl_fops, + // clang-format on +}; + +static const struct v4l2_file_operations v4l2_loopback_fops = { + // clang-format off + .owner = THIS_MODULE, + .open = v4l2_loopback_open, + .release = v4l2_loopback_close, + .read = v4l2_loopback_read, + .write = v4l2_loopback_write, + .poll = v4l2_loopback_poll, + .mmap = v4l2_loopback_mmap, + .unlocked_ioctl = video_ioctl2, + // clang-format on +}; + +static const struct v4l2_ioctl_ops v4l2_loopback_ioctl_ops = { + // clang-format off + .vidioc_querycap = &vidioc_querycap, + .vidioc_enum_framesizes = &vidioc_enum_framesizes, + .vidioc_enum_frameintervals = &vidioc_enum_frameintervals, + + .vidioc_enum_output = &vidioc_enum_output, + .vidioc_g_output = &vidioc_g_output, + .vidioc_s_output = &vidioc_s_output, + + .vidioc_enum_input = &vidioc_enum_input, + .vidioc_g_input = &vidioc_g_input, + .vidioc_s_input = &vidioc_s_input, + + .vidioc_enum_fmt_vid_cap = &vidioc_enum_fmt_cap, + .vidioc_g_fmt_vid_cap = &vidioc_g_fmt_cap, + .vidioc_s_fmt_vid_cap = &vidioc_s_fmt_cap, + .vidioc_try_fmt_vid_cap = &vidioc_try_fmt_cap, + + .vidioc_enum_fmt_vid_out = &vidioc_enum_fmt_out, + .vidioc_s_fmt_vid_out = &vidioc_s_fmt_out, + .vidioc_g_fmt_vid_out = &vidioc_g_fmt_out, + .vidioc_try_fmt_vid_out = &vidioc_try_fmt_out, + +#ifdef V4L2L_OVERLAY + .vidioc_s_fmt_vid_overlay = &vidioc_s_fmt_overlay, + .vidioc_g_fmt_vid_overlay = &vidioc_g_fmt_overlay, +#endif + +#ifdef V4L2LOOPBACK_WITH_STD + .vidioc_s_std = &vidioc_s_std, + .vidioc_g_std = &vidioc_g_std, + .vidioc_querystd = &vidioc_querystd, +#endif /* V4L2LOOPBACK_WITH_STD */ + + .vidioc_g_parm = &vidioc_g_parm, + .vidioc_s_parm = &vidioc_s_parm, + + .vidioc_reqbufs = &vidioc_reqbufs, + .vidioc_querybuf = &vidioc_querybuf, + .vidioc_qbuf = &vidioc_qbuf, + .vidioc_dqbuf = &vidioc_dqbuf, + + .vidioc_streamon = &vidioc_streamon, + .vidioc_streamoff = &vidioc_streamoff, + +#ifdef CONFIG_VIDEO_V4L1_COMPAT + .vidiocgmbuf = &vidiocgmbuf, +#endif + + .vidioc_subscribe_event = &vidioc_subscribe_event, + .vidioc_unsubscribe_event = &v4l2_event_unsubscribe, + // clang-format on +}; + +static int free_device_cb(int id, void *ptr, void *data) +{ + struct v4l2_loopback_device *dev = ptr; + v4l2_loopback_remove(dev); + return 0; +} +static void free_devices(void) +{ + idr_for_each(&v4l2loopback_index_idr, &free_device_cb, NULL); + idr_destroy(&v4l2loopback_index_idr); +} + +static int __init v4l2loopback_init_module(void) +{ + const u32 min_width = V4L2LOOPBACK_SIZE_MIN_WIDTH; + const u32 min_height = V4L2LOOPBACK_SIZE_MIN_HEIGHT; + int err; + int i; + MARK(); + + err = misc_register(&v4l2loopback_misc); + if (err < 0) + return err; + + if (devices < 0) { + devices = 1; + + /* try guessing the devices from the "video_nr" parameter */ + for (i = MAX_DEVICES - 1; i >= 0; i--) { + if (video_nr[i] >= 0) { + devices = i + 1; + break; + } + } + } + + if (devices > MAX_DEVICES) { + devices = MAX_DEVICES; + printk(KERN_INFO + "v4l2loopback: number of initial devices is limited to: %d\n", + MAX_DEVICES); + } + + if (max_buffers > MAX_BUFFERS) { + max_buffers = MAX_BUFFERS; + printk(KERN_INFO + "v4l2loopback: number of buffers is limited to: %d\n", + MAX_BUFFERS); + } + + if (max_openers < 0) { + printk(KERN_INFO + "v4l2loopback: allowing %d openers rather than %d\n", + 2, max_openers); + max_openers = 2; + } + + if (max_width < min_width) { + max_width = V4L2LOOPBACK_SIZE_DEFAULT_MAX_WIDTH; + printk(KERN_INFO "v4l2loopback: using max_width %d\n", + max_width); + } + if (max_height < min_height) { + max_height = V4L2LOOPBACK_SIZE_DEFAULT_MAX_HEIGHT; + printk(KERN_INFO "v4l2loopback: using max_height %d\n", + max_height); + } + + for (i = 0; i < devices; i++) { + struct v4l2_loopback_config cfg = { + // clang-format off + .output_nr = video_nr[i], +#ifdef SPLIT_DEVICES + .capture_nr = video_nr[i], +#endif + .min_width = min_width, + .min_height = min_height, + .max_width = max_width, + .max_height = max_height, + .announce_all_caps = (!exclusive_caps[i]), + .max_buffers = max_buffers, + .max_openers = max_openers, + .debug = debug, + // clang-format on + }; + cfg.card_label[0] = 0; + if (card_label[i]) + snprintf(cfg.card_label, sizeof(cfg.card_label), "%s", + card_label[i]); + err = v4l2_loopback_add(&cfg, 0); + if (err) { + free_devices(); + goto error; + } + } + + dprintk("module installed\n"); + + printk(KERN_INFO "v4l2loopback driver version %d.%d.%d%s loaded\n", + // clang-format off + (V4L2LOOPBACK_VERSION_CODE >> 16) & 0xff, + (V4L2LOOPBACK_VERSION_CODE >> 8) & 0xff, + (V4L2LOOPBACK_VERSION_CODE ) & 0xff, +#ifdef SNAPSHOT_VERSION + " (" __stringify(SNAPSHOT_VERSION) ")" +#else + "" +#endif + ); + // clang-format on + + return 0; +error: + misc_deregister(&v4l2loopback_misc); + return err; +} + +static void v4l2loopback_cleanup_module(void) +{ + MARK(); + /* unregister the device -> it deletes /dev/video* */ + free_devices(); + /* and get rid of /dev/v4l2loopback */ + misc_deregister(&v4l2loopback_misc); + dprintk("module removed\n"); +} + +MODULE_ALIAS_MISCDEV(MISC_DYNAMIC_MINOR); + +module_init(v4l2loopback_init_module); +module_exit(v4l2loopback_cleanup_module); diff --git a/drivers/media/v4l2-core/v4l2loopback.h b/drivers/media/v4l2-core/v4l2loopback.h new file mode 100644 index 000000000000..1bc7e6b747a4 --- /dev/null +++ b/drivers/media/v4l2-core/v4l2loopback.h @@ -0,0 +1,98 @@ +/* SPDX-License-Identifier: GPL-2.0+ WITH Linux-syscall-note */ +/* + * v4l2loopback.h + * + * Written by IOhannes m zmölnig, 7/1/20. + * + * Copyright 2020 by IOhannes m zmölnig. Redistribution of this file is + * permitted under the GNU General Public License. + */ +#ifndef _V4L2LOOPBACK_H +#define _V4L2LOOPBACK_H + +#define V4L2LOOPBACK_VERSION_MAJOR 0 +#define V4L2LOOPBACK_VERSION_MINOR 13 +#define V4L2LOOPBACK_VERSION_BUGFIX 1 + +/* /dev/v4l2loopback interface */ + +struct v4l2_loopback_config { + /** + * the device-number (/dev/video) + * V4L2LOOPBACK_CTL_ADD: + * setting this to a value<0, will allocate an available one + * if nr>=0 and the device already exists, the ioctl will EEXIST + * if output_nr and capture_nr are the same, only a single device will be created + * NOTE: currently split-devices (where output_nr and capture_nr differ) + * are not implemented yet. + * until then, requesting different device-IDs will result in EINVAL. + * + * V4L2LOOPBACK_CTL_QUERY: + * either both output_nr and capture_nr must refer to the same loopback, + * or one (and only one) of them must be -1 + * + */ + int output_nr; + int unused; /*capture_nr;*/ + + /** + * a nice name for your device + * if (*card_label)==0, an automatic name is assigned + */ + char card_label[32]; + + /** + * allowed frame size + * if too low, default values are used + */ + unsigned int min_width; + unsigned int max_width; + unsigned int min_height; + unsigned int max_height; + + /** + * number of buffers to allocate for the queue + * if set to <=0, default values are used + */ + int max_buffers; + + /** + * how many consumers are allowed to open this device concurrently + * if set to <=0, default values are used + */ + int max_openers; + + /** + * set the debugging level for this device + */ + int debug; + + /** + * whether to announce OUTPUT/CAPTURE capabilities exclusively + * for this device or not + * (!exclusive_caps) + * NOTE: this is going to be removed once separate output/capture + * devices are implemented + */ + int announce_all_caps; +}; + +/* a pointer to a (struct v4l2_loopback_config) that has all values you wish to impose on the + * to-be-created device set. + * if the ptr is NULL, a new device is created with default values at the driver's discretion. + * + * returns the device_nr of the OUTPUT device (which can be used with V4L2LOOPBACK_CTL_QUERY, + * to get more information on the device) + */ +#define V4L2LOOPBACK_CTL_ADD 0x4C80 + +/* a pointer to a (struct v4l2_loopback_config) that has output_nr and/or capture_nr set + * (the two values must either refer to video-devices associated with the same loopback device + * or exactly one of them must be <0 + */ +#define V4L2LOOPBACK_CTL_QUERY 0x4C82 + +/* the device-number (either CAPTURE or OUTPUT) associated with the loopback-device */ +#define V4L2LOOPBACK_CTL_REMOVE 0x4C81 + +#endif /* _V4L2LOOPBACK_H */ diff --git a/drivers/media/v4l2-core/v4l2loopback_formats.h b/drivers/media/v4l2-core/v4l2loopback_formats.h new file mode 100644 index 000000000000..d855a3796554 --- /dev/null +++ b/drivers/media/v4l2-core/v4l2loopback_formats.h @@ -0,0 +1,445 @@ +static const struct v4l2l_format formats[] = { +#ifndef V4L2_PIX_FMT_VP9 +#define V4L2_PIX_FMT_VP9 v4l2_fourcc('V', 'P', '9', '0') +#endif +#ifndef V4L2_PIX_FMT_HEVC +#define V4L2_PIX_FMT_HEVC v4l2_fourcc('H', 'E', 'V', 'C') +#endif + + /* here come the packed formats */ + { + .name = "32 bpp RGB, le", + .fourcc = V4L2_PIX_FMT_BGR32, + .depth = 32, + .flags = 0, + }, + { + .name = "32 bpp RGB, be", + .fourcc = V4L2_PIX_FMT_RGB32, + .depth = 32, + .flags = 0, + }, + { + .name = "24 bpp RGB, le", + .fourcc = V4L2_PIX_FMT_BGR24, + .depth = 24, + .flags = 0, + }, + { + .name = "24 bpp RGB, be", + .fourcc = V4L2_PIX_FMT_RGB24, + .depth = 24, + .flags = 0, + }, +#ifdef V4L2_PIX_FMT_ABGR32 + { + .name = "32 bpp RGBA, le", + .fourcc = V4L2_PIX_FMT_ABGR32, + .depth = 32, + .flags = 0, + }, +#endif +#ifdef V4L2_PIX_FMT_RGBA32 + { + .name = "32 bpp RGBA", + .fourcc = V4L2_PIX_FMT_RGBA32, + .depth = 32, + .flags = 0, + }, +#endif +#ifdef V4L2_PIX_FMT_RGB332 + { + .name = "8 bpp RGB-3-3-2", + .fourcc = V4L2_PIX_FMT_RGB332, + .depth = 8, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_RGB332 */ +#ifdef V4L2_PIX_FMT_RGB444 + { + .name = "16 bpp RGB (xxxxrrrr ggggbbbb)", + .fourcc = V4L2_PIX_FMT_RGB444, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_RGB444 */ +#ifdef V4L2_PIX_FMT_RGB555 + { + .name = "16 bpp RGB-5-5-5", + .fourcc = V4L2_PIX_FMT_RGB555, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_RGB555 */ +#ifdef V4L2_PIX_FMT_RGB565 + { + .name = "16 bpp RGB-5-6-5", + .fourcc = V4L2_PIX_FMT_RGB565, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_RGB565 */ +#ifdef V4L2_PIX_FMT_RGB555X + { + .name = "16 bpp RGB-5-5-5 BE", + .fourcc = V4L2_PIX_FMT_RGB555X, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_RGB555X */ +#ifdef V4L2_PIX_FMT_RGB565X + { + .name = "16 bpp RGB-5-6-5 BE", + .fourcc = V4L2_PIX_FMT_RGB565X, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_RGB565X */ +#ifdef V4L2_PIX_FMT_BGR666 + { + .name = "18 bpp BGR-6-6-6", + .fourcc = V4L2_PIX_FMT_BGR666, + .depth = 18, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_BGR666 */ + { + .name = "4:2:2, packed, YUYV", + .fourcc = V4L2_PIX_FMT_YUYV, + .depth = 16, + .flags = 0, + }, + { + .name = "4:2:2, packed, UYVY", + .fourcc = V4L2_PIX_FMT_UYVY, + .depth = 16, + .flags = 0, + }, +#ifdef V4L2_PIX_FMT_YVYU + { + .name = "4:2:2, packed YVYU", + .fourcc = V4L2_PIX_FMT_YVYU, + .depth = 16, + .flags = 0, + }, +#endif +#ifdef V4L2_PIX_FMT_VYUY + { + .name = "4:2:2, packed VYUY", + .fourcc = V4L2_PIX_FMT_VYUY, + .depth = 16, + .flags = 0, + }, +#endif + { + .name = "4:2:2, packed YYUV", + .fourcc = V4L2_PIX_FMT_YYUV, + .depth = 16, + .flags = 0, + }, + { + .name = "YUV-8-8-8-8", + .fourcc = V4L2_PIX_FMT_YUV32, + .depth = 32, + .flags = 0, + }, + { + .name = "8 bpp, Greyscale", + .fourcc = V4L2_PIX_FMT_GREY, + .depth = 8, + .flags = 0, + }, +#ifdef V4L2_PIX_FMT_Y4 + { + .name = "4 bpp Greyscale", + .fourcc = V4L2_PIX_FMT_Y4, + .depth = 4, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_Y4 */ +#ifdef V4L2_PIX_FMT_Y6 + { + .name = "6 bpp Greyscale", + .fourcc = V4L2_PIX_FMT_Y6, + .depth = 6, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_Y6 */ +#ifdef V4L2_PIX_FMT_Y10 + { + .name = "10 bpp Greyscale", + .fourcc = V4L2_PIX_FMT_Y10, + .depth = 10, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_Y10 */ +#ifdef V4L2_PIX_FMT_Y12 + { + .name = "12 bpp Greyscale", + .fourcc = V4L2_PIX_FMT_Y12, + .depth = 12, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_Y12 */ + { + .name = "16 bpp, Greyscale", + .fourcc = V4L2_PIX_FMT_Y16, + .depth = 16, + .flags = 0, + }, +#ifdef V4L2_PIX_FMT_YUV444 + { + .name = "16 bpp xxxxyyyy uuuuvvvv", + .fourcc = V4L2_PIX_FMT_YUV444, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_YUV444 */ +#ifdef V4L2_PIX_FMT_YUV555 + { + .name = "16 bpp YUV-5-5-5", + .fourcc = V4L2_PIX_FMT_YUV555, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_YUV555 */ +#ifdef V4L2_PIX_FMT_YUV565 + { + .name = "16 bpp YUV-5-6-5", + .fourcc = V4L2_PIX_FMT_YUV565, + .depth = 16, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_YUV565 */ + +/* bayer formats */ +#ifdef V4L2_PIX_FMT_SRGGB8 + { + .name = "Bayer RGGB 8bit", + .fourcc = V4L2_PIX_FMT_SRGGB8, + .depth = 8, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_SRGGB8 */ +#ifdef V4L2_PIX_FMT_SGRBG8 + { + .name = "Bayer GRBG 8bit", + .fourcc = V4L2_PIX_FMT_SGRBG8, + .depth = 8, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_SGRBG8 */ +#ifdef V4L2_PIX_FMT_SGBRG8 + { + .name = "Bayer GBRG 8bit", + .fourcc = V4L2_PIX_FMT_SGBRG8, + .depth = 8, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_SGBRG8 */ +#ifdef V4L2_PIX_FMT_SBGGR8 + { + .name = "Bayer BA81 8bit", + .fourcc = V4L2_PIX_FMT_SBGGR8, + .depth = 8, + .flags = 0, + }, +#endif /* V4L2_PIX_FMT_SBGGR8 */ + + /* here come the planar formats */ + { + .name = "4:1:0, planar, Y-Cr-Cb", + .fourcc = V4L2_PIX_FMT_YVU410, + .depth = 9, + .flags = FORMAT_FLAGS_PLANAR, + }, + { + .name = "4:2:0, planar, Y-Cr-Cb", + .fourcc = V4L2_PIX_FMT_YVU420, + .depth = 12, + .flags = FORMAT_FLAGS_PLANAR, + }, + { + .name = "4:1:0, planar, Y-Cb-Cr", + .fourcc = V4L2_PIX_FMT_YUV410, + .depth = 9, + .flags = FORMAT_FLAGS_PLANAR, + }, + { + .name = "4:2:0, planar, Y-Cb-Cr", + .fourcc = V4L2_PIX_FMT_YUV420, + .depth = 12, + .flags = FORMAT_FLAGS_PLANAR, + }, +#ifdef V4L2_PIX_FMT_YUV422P + { + .name = "16 bpp YVU422 planar", + .fourcc = V4L2_PIX_FMT_YUV422P, + .depth = 16, + .flags = FORMAT_FLAGS_PLANAR, + }, +#endif /* V4L2_PIX_FMT_YUV422P */ +#ifdef V4L2_PIX_FMT_YUV411P + { + .name = "16 bpp YVU411 planar", + .fourcc = V4L2_PIX_FMT_YUV411P, + .depth = 16, + .flags = FORMAT_FLAGS_PLANAR, + }, +#endif /* V4L2_PIX_FMT_YUV411P */ +#ifdef V4L2_PIX_FMT_Y41P + { + .name = "12 bpp YUV 4:1:1", + .fourcc = V4L2_PIX_FMT_Y41P, + .depth = 12, + .flags = FORMAT_FLAGS_PLANAR, + }, +#endif /* V4L2_PIX_FMT_Y41P */ +#ifdef V4L2_PIX_FMT_NV12 + { + .name = "12 bpp Y/CbCr 4:2:0 ", + .fourcc = V4L2_PIX_FMT_NV12, + .depth = 12, + .flags = FORMAT_FLAGS_PLANAR, + }, +#endif /* V4L2_PIX_FMT_NV12 */ + +/* here come the compressed formats */ + +#ifdef V4L2_PIX_FMT_MJPEG + { + .name = "Motion-JPEG", + .fourcc = V4L2_PIX_FMT_MJPEG, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_MJPEG */ +#ifdef V4L2_PIX_FMT_JPEG + { + .name = "JFIF JPEG", + .fourcc = V4L2_PIX_FMT_JPEG, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_JPEG */ +#ifdef V4L2_PIX_FMT_DV + { + .name = "DV1394", + .fourcc = V4L2_PIX_FMT_DV, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_DV */ +#ifdef V4L2_PIX_FMT_MPEG + { + .name = "MPEG-1/2/4 Multiplexed", + .fourcc = V4L2_PIX_FMT_MPEG, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_MPEG */ +#ifdef V4L2_PIX_FMT_H264 + { + .name = "H264 with start codes", + .fourcc = V4L2_PIX_FMT_H264, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_H264 */ +#ifdef V4L2_PIX_FMT_H264_NO_SC + { + .name = "H264 without start codes", + .fourcc = V4L2_PIX_FMT_H264_NO_SC, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_H264_NO_SC */ +#ifdef V4L2_PIX_FMT_H264_MVC + { + .name = "H264 MVC", + .fourcc = V4L2_PIX_FMT_H264_MVC, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_H264_MVC */ +#ifdef V4L2_PIX_FMT_H263 + { + .name = "H263", + .fourcc = V4L2_PIX_FMT_H263, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_H263 */ +#ifdef V4L2_PIX_FMT_MPEG1 + { + .name = "MPEG-1 ES", + .fourcc = V4L2_PIX_FMT_MPEG1, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_MPEG1 */ +#ifdef V4L2_PIX_FMT_MPEG2 + { + .name = "MPEG-2 ES", + .fourcc = V4L2_PIX_FMT_MPEG2, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_MPEG2 */ +#ifdef V4L2_PIX_FMT_MPEG4 + { + .name = "MPEG-4 part 2 ES", + .fourcc = V4L2_PIX_FMT_MPEG4, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_MPEG4 */ +#ifdef V4L2_PIX_FMT_XVID + { + .name = "Xvid", + .fourcc = V4L2_PIX_FMT_XVID, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_XVID */ +#ifdef V4L2_PIX_FMT_VC1_ANNEX_G + { + .name = "SMPTE 421M Annex G compliant stream", + .fourcc = V4L2_PIX_FMT_VC1_ANNEX_G, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_VC1_ANNEX_G */ +#ifdef V4L2_PIX_FMT_VC1_ANNEX_L + { + .name = "SMPTE 421M Annex L compliant stream", + .fourcc = V4L2_PIX_FMT_VC1_ANNEX_L, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_VC1_ANNEX_L */ +#ifdef V4L2_PIX_FMT_VP8 + { + .name = "VP8", + .fourcc = V4L2_PIX_FMT_VP8, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_VP8 */ +#ifdef V4L2_PIX_FMT_VP9 + { + .name = "VP9", + .fourcc = V4L2_PIX_FMT_VP9, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_VP9 */ +#ifdef V4L2_PIX_FMT_HEVC + { + .name = "HEVC", + .fourcc = V4L2_PIX_FMT_HEVC, + .depth = 32, + .flags = FORMAT_FLAGS_COMPRESSED, + }, +#endif /* V4L2_PIX_FMT_HEVC */ +}; diff --git a/drivers/pci/controller/Makefile b/drivers/pci/controller/Makefile index 038ccbd9e3ba..de5e4f5145af 100644 --- a/drivers/pci/controller/Makefile +++ b/drivers/pci/controller/Makefile @@ -1,4 +1,10 @@ # SPDX-License-Identifier: GPL-2.0 +ifdef CONFIG_X86_64 +ifdef CONFIG_SATA_AHCI +obj-y += intel-nvme-remap.o +endif +endif + obj-$(CONFIG_PCIE_CADENCE) += cadence/ obj-$(CONFIG_PCI_FTPCI100) += pci-ftpci100.o obj-$(CONFIG_PCI_IXP4XX) += pci-ixp4xx.o diff --git a/drivers/pci/controller/intel-nvme-remap.c b/drivers/pci/controller/intel-nvme-remap.c new file mode 100644 index 000000000000..e105e6f5cc91 --- /dev/null +++ b/drivers/pci/controller/intel-nvme-remap.c @@ -0,0 +1,462 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Intel remapped NVMe device support. + * + * Copyright (c) 2019 Endless Mobile, Inc. + * Author: Daniel Drake + * + * Some products ship by default with the SATA controller in "RAID" or + * "Intel RST Premium With Intel Optane System Acceleration" mode. Under this + * mode, which we refer to as "remapped NVMe" mode, any installed NVMe + * devices disappear from the PCI bus, and instead their I/O memory becomes + * available within the AHCI device BARs. + * + * This scheme is understood to be a way of avoiding usage of the standard + * Windows NVMe driver under that OS, instead mandating usage of Intel's + * driver instead, which has better power management, and presumably offers + * some RAID/disk-caching solutions too. + * + * Here in this driver, we support the remapped NVMe mode by claiming the + * AHCI device and creating a fake PCIe root port. On the new bus, the + * original AHCI device is exposed with only minor tweaks. Then, fake PCI + * devices corresponding to the remapped NVMe devices are created. The usual + * ahci and nvme drivers are then expected to bind to these devices and + * operate as normal. + * + * The PCI configuration space for the NVMe devices is completely + * unavailable, so we fake a minimal one and hope for the best. + * + * Interrupts are shared between the AHCI and NVMe devices. For simplicity, + * we only support the legacy interrupt here, although MSI support + * could potentially be added later. + */ + +#define MODULE_NAME "intel-nvme-remap" + +#include +#include +#include +#include +#include + +#define AHCI_PCI_BAR_STANDARD 5 + +struct nvme_remap_dev { + struct pci_dev *dev; /* AHCI device */ + struct pci_bus *bus; /* our fake PCI bus */ + struct pci_sysdata sysdata; + int irq_base; /* our fake interrupts */ + + /* + * When we detect an all-ones write to a BAR register, this flag + * is set, so that we return the BAR size on the next read (a + * standard PCI behaviour). + * This includes the assumption that an all-ones BAR write is + * immediately followed by a read of the same register. + */ + bool bar_sizing; + + /* + * Resources copied from the AHCI device, to be regarded as + * resources on our fake bus. + */ + struct resource ahci_resources[PCI_NUM_RESOURCES]; + + /* Resources corresponding to the NVMe devices. */ + struct resource remapped_dev_mem[AHCI_MAX_REMAP]; + + /* Number of remapped NVMe devices found. */ + int num_remapped_devices; +}; + +static inline struct nvme_remap_dev *nrdev_from_bus(struct pci_bus *bus) +{ + return container_of(bus->sysdata, struct nvme_remap_dev, sysdata); +} + + +/******** PCI configuration space **********/ + +/* + * Helper macros for tweaking returned contents of PCI configuration space. + * + * value contains len bytes of data read from reg. + * If fixup_reg is included in that range, fix up the contents of that + * register to fixed_value. + */ +#define NR_FIX8(fixup_reg, fixed_value) do { \ + if (reg <= fixup_reg && fixup_reg < reg + len) \ + ((u8 *) value)[fixup_reg - reg] = (u8) (fixed_value); \ + } while (0) + +#define NR_FIX16(fixup_reg, fixed_value) do { \ + NR_FIX8(fixup_reg, fixed_value); \ + NR_FIX8(fixup_reg + 1, fixed_value >> 8); \ + } while (0) + +#define NR_FIX24(fixup_reg, fixed_value) do { \ + NR_FIX8(fixup_reg, fixed_value); \ + NR_FIX8(fixup_reg + 1, fixed_value >> 8); \ + NR_FIX8(fixup_reg + 2, fixed_value >> 16); \ + } while (0) + +#define NR_FIX32(fixup_reg, fixed_value) do { \ + NR_FIX16(fixup_reg, (u16) fixed_value); \ + NR_FIX16(fixup_reg + 2, fixed_value >> 16); \ + } while (0) + +/* + * Read PCI config space of the slot 0 (AHCI) device. + * We pass through the read request to the underlying device, but + * tweak the results in some cases. + */ +static int nvme_remap_pci_read_slot0(struct pci_bus *bus, int reg, + int len, u32 *value) +{ + struct nvme_remap_dev *nrdev = nrdev_from_bus(bus); + struct pci_bus *ahci_dev_bus = nrdev->dev->bus; + int ret; + + ret = ahci_dev_bus->ops->read(ahci_dev_bus, nrdev->dev->devfn, + reg, len, value); + if (ret) + return ret; + + /* + * Adjust the device class, to prevent this driver from attempting to + * additionally probe the device we're simulating here. + */ + NR_FIX24(PCI_CLASS_PROG, PCI_CLASS_STORAGE_SATA_AHCI); + + /* + * Unset interrupt pin, otherwise ACPI tries to find routing + * info for our virtual IRQ, fails, and complains. + */ + NR_FIX8(PCI_INTERRUPT_PIN, 0); + + /* + * Truncate the AHCI BAR to not include the region that covers the + * hidden devices. This will cause the ahci driver to successfully + * probe th new device (instead of handing it over to this driver). + */ + if (nrdev->bar_sizing) { + NR_FIX32(PCI_BASE_ADDRESS_5, ~(SZ_16K - 1)); + nrdev->bar_sizing = false; + } + + return PCIBIOS_SUCCESSFUL; +} + +/* + * Read PCI config space of a remapped device. + * Since the original PCI config space is inaccessible, we provide a minimal, + * fake config space instead. + */ +static int nvme_remap_pci_read_remapped(struct pci_bus *bus, unsigned int port, + int reg, int len, u32 *value) +{ + struct nvme_remap_dev *nrdev = nrdev_from_bus(bus); + struct resource *remapped_mem; + + if (port > nrdev->num_remapped_devices) + return PCIBIOS_DEVICE_NOT_FOUND; + + *value = 0; + remapped_mem = &nrdev->remapped_dev_mem[port - 1]; + + /* Set a Vendor ID, otherwise Linux assumes no device is present */ + NR_FIX16(PCI_VENDOR_ID, PCI_VENDOR_ID_INTEL); + + /* Always appear on & bus mastering */ + NR_FIX16(PCI_COMMAND, PCI_COMMAND_MEMORY | PCI_COMMAND_MASTER); + + /* Set class so that nvme driver probes us */ + NR_FIX24(PCI_CLASS_PROG, PCI_CLASS_STORAGE_EXPRESS); + + if (nrdev->bar_sizing) { + NR_FIX32(PCI_BASE_ADDRESS_0, + ~(resource_size(remapped_mem) - 1)); + nrdev->bar_sizing = false; + } else { + resource_size_t mem_start = remapped_mem->start; + + mem_start |= PCI_BASE_ADDRESS_MEM_TYPE_64; + NR_FIX32(PCI_BASE_ADDRESS_0, mem_start); + mem_start >>= 32; + NR_FIX32(PCI_BASE_ADDRESS_1, mem_start); + } + + return PCIBIOS_SUCCESSFUL; +} + +/* Read PCI configuration space. */ +static int nvme_remap_pci_read(struct pci_bus *bus, unsigned int devfn, + int reg, int len, u32 *value) +{ + if (PCI_SLOT(devfn) == 0) + return nvme_remap_pci_read_slot0(bus, reg, len, value); + else + return nvme_remap_pci_read_remapped(bus, PCI_SLOT(devfn), + reg, len, value); +} + +/* + * Write PCI config space of the slot 0 (AHCI) device. + * Apart from the special case of BAR sizing, we disable all writes. + * Otherwise, the ahci driver could make changes (e.g. unset PCI bus master) + * that would affect the operation of the NVMe devices. + */ +static int nvme_remap_pci_write_slot0(struct pci_bus *bus, int reg, + int len, u32 value) +{ + struct nvme_remap_dev *nrdev = nrdev_from_bus(bus); + struct pci_bus *ahci_dev_bus = nrdev->dev->bus; + + if (reg >= PCI_BASE_ADDRESS_0 && reg <= PCI_BASE_ADDRESS_5) { + /* + * Writing all-ones to a BAR means that the size of the + * memory region is being checked. Flag this so that we can + * reply with an appropriate size on the next read. + */ + if (value == ~0) + nrdev->bar_sizing = true; + + return ahci_dev_bus->ops->write(ahci_dev_bus, + nrdev->dev->devfn, + reg, len, value); + } + + return PCIBIOS_SET_FAILED; +} + +/* + * Write PCI config space of a remapped device. + * Since the original PCI config space is inaccessible, we reject all + * writes, except for the special case of BAR probing. + */ +static int nvme_remap_pci_write_remapped(struct pci_bus *bus, + unsigned int port, + int reg, int len, u32 value) +{ + struct nvme_remap_dev *nrdev = nrdev_from_bus(bus); + + if (port > nrdev->num_remapped_devices) + return PCIBIOS_DEVICE_NOT_FOUND; + + /* + * Writing all-ones to a BAR means that the size of the memory + * region is being checked. Flag this so that we can reply with + * an appropriate size on the next read. + */ + if (value == ~0 && reg >= PCI_BASE_ADDRESS_0 + && reg <= PCI_BASE_ADDRESS_5) { + nrdev->bar_sizing = true; + return PCIBIOS_SUCCESSFUL; + } + + return PCIBIOS_SET_FAILED; +} + +/* Write PCI configuration space. */ +static int nvme_remap_pci_write(struct pci_bus *bus, unsigned int devfn, + int reg, int len, u32 value) +{ + if (PCI_SLOT(devfn) == 0) + return nvme_remap_pci_write_slot0(bus, reg, len, value); + else + return nvme_remap_pci_write_remapped(bus, PCI_SLOT(devfn), + reg, len, value); +} + +static struct pci_ops nvme_remap_pci_ops = { + .read = nvme_remap_pci_read, + .write = nvme_remap_pci_write, +}; + + +/******** Initialization & exit **********/ + +/* + * Find a PCI domain ID to use for our fake bus. + * Start at 0x10000 to not clash with ACPI _SEG domains (16 bits). + */ +static int find_free_domain(void) +{ + int domain = 0xffff; + struct pci_bus *bus = NULL; + + while ((bus = pci_find_next_bus(bus)) != NULL) + domain = max_t(int, domain, pci_domain_nr(bus)); + + return domain + 1; +} + +static int find_remapped_devices(struct nvme_remap_dev *nrdev, + struct list_head *resources) +{ + void __iomem *mmio; + int i, count = 0; + u32 cap; + + mmio = pcim_iomap(nrdev->dev, AHCI_PCI_BAR_STANDARD, + pci_resource_len(nrdev->dev, + AHCI_PCI_BAR_STANDARD)); + if (!mmio) + return -ENODEV; + + /* Check if this device might have remapped nvme devices. */ + if (pci_resource_len(nrdev->dev, AHCI_PCI_BAR_STANDARD) < SZ_512K || + !(readl(mmio + AHCI_VSCAP) & 1)) + return -ENODEV; + + cap = readq(mmio + AHCI_REMAP_CAP); + for (i = AHCI_MAX_REMAP-1; i >= 0; i--) { + struct resource *remapped_mem; + + if ((cap & (1 << i)) == 0) + continue; + if (readl(mmio + ahci_remap_dcc(i)) + != PCI_CLASS_STORAGE_EXPRESS) + continue; + + /* We've found a remapped device */ + remapped_mem = &nrdev->remapped_dev_mem[count++]; + remapped_mem->start = + pci_resource_start(nrdev->dev, AHCI_PCI_BAR_STANDARD) + + ahci_remap_base(i); + remapped_mem->end = remapped_mem->start + + AHCI_REMAP_N_SIZE - 1; + remapped_mem->flags = IORESOURCE_MEM | IORESOURCE_PCI_FIXED; + pci_add_resource(resources, remapped_mem); + } + + pcim_iounmap(nrdev->dev, mmio); + + if (count == 0) + return -ENODEV; + + nrdev->num_remapped_devices = count; + dev_info(&nrdev->dev->dev, "Found %d remapped NVMe devices\n", + nrdev->num_remapped_devices); + return 0; +} + +static void nvme_remap_remove_root_bus(void *data) +{ + struct pci_bus *bus = data; + + pci_stop_root_bus(bus); + pci_remove_root_bus(bus); +} + +static int nvme_remap_probe(struct pci_dev *dev, + const struct pci_device_id *id) +{ + struct nvme_remap_dev *nrdev; + LIST_HEAD(resources); + int i; + int ret; + struct pci_dev *child; + + nrdev = devm_kzalloc(&dev->dev, sizeof(*nrdev), GFP_KERNEL); + nrdev->sysdata.domain = find_free_domain(); + nrdev->sysdata.nvme_remap_dev = dev; + nrdev->dev = dev; + pci_set_drvdata(dev, nrdev); + + ret = pcim_enable_device(dev); + if (ret < 0) + return ret; + + pci_set_master(dev); + + ret = find_remapped_devices(nrdev, &resources); + if (ret) + return ret; + + /* Add resources from the original AHCI device */ + for (i = 0; i < PCI_NUM_RESOURCES; i++) { + struct resource *res = &dev->resource[i]; + + if (res->start) { + struct resource *nr_res = &nrdev->ahci_resources[i]; + + nr_res->start = res->start; + nr_res->end = res->end; + nr_res->flags = res->flags; + pci_add_resource(&resources, nr_res); + } + } + + /* Create virtual interrupts */ + nrdev->irq_base = devm_irq_alloc_descs(&dev->dev, -1, 0, + nrdev->num_remapped_devices + 1, + 0); + if (nrdev->irq_base < 0) + return nrdev->irq_base; + + /* Create and populate PCI bus */ + nrdev->bus = pci_create_root_bus(&dev->dev, 0, &nvme_remap_pci_ops, + &nrdev->sysdata, &resources); + if (!nrdev->bus) + return -ENODEV; + + if (devm_add_action_or_reset(&dev->dev, nvme_remap_remove_root_bus, + nrdev->bus)) + return -ENOMEM; + + /* We don't support sharing MSI interrupts between these devices */ + nrdev->bus->bus_flags |= PCI_BUS_FLAGS_NO_MSI; + + pci_scan_child_bus(nrdev->bus); + + list_for_each_entry(child, &nrdev->bus->devices, bus_list) { + /* + * Prevent PCI core from trying to move memory BARs around. + * The hidden NVMe devices are at fixed locations. + */ + for (i = 0; i < PCI_NUM_RESOURCES; i++) { + struct resource *res = &child->resource[i]; + + if (res->flags & IORESOURCE_MEM) + res->flags |= IORESOURCE_PCI_FIXED; + } + + /* Share the legacy IRQ between all devices */ + child->irq = dev->irq; + } + + pci_assign_unassigned_bus_resources(nrdev->bus); + pci_bus_add_devices(nrdev->bus); + + return 0; +} + +static const struct pci_device_id nvme_remap_ids[] = { + /* + * Match all Intel RAID controllers. + * + * There's overlap here with the set of devices detected by the ahci + * driver, but ahci will only successfully probe when there + * *aren't* any remapped NVMe devices, and this driver will only + * successfully probe when there *are* remapped NVMe devices that + * need handling. + */ + { + PCI_VDEVICE(INTEL, PCI_ANY_ID), + .class = PCI_CLASS_STORAGE_RAID << 8, + .class_mask = 0xffffff00, + }, + {0,} +}; +MODULE_DEVICE_TABLE(pci, nvme_remap_ids); + +static struct pci_driver nvme_remap_drv = { + .name = MODULE_NAME, + .id_table = nvme_remap_ids, + .probe = nvme_remap_probe, +}; +module_pci_driver(nvme_remap_drv); + +MODULE_AUTHOR("Daniel Drake "); +MODULE_LICENSE("GPL v2"); diff --git a/drivers/pci/quirks.c b/drivers/pci/quirks.c index dccb60c1d9cc..d9a8af789de8 100644 --- a/drivers/pci/quirks.c +++ b/drivers/pci/quirks.c @@ -3747,6 +3747,106 @@ static void quirk_no_bus_reset(struct pci_dev *dev) dev->dev_flags |= PCI_DEV_FLAGS_NO_BUS_RESET; } +static bool acs_on_downstream; +static bool acs_on_multifunction; + +#define NUM_ACS_IDS 16 +struct acs_on_id { + unsigned short vendor; + unsigned short device; +}; +static struct acs_on_id acs_on_ids[NUM_ACS_IDS]; +static u8 max_acs_id; + +static __init int pcie_acs_override_setup(char *p) +{ + if (!p) + return -EINVAL; + + while (*p) { + if (!strncmp(p, "downstream", 10)) + acs_on_downstream = true; + if (!strncmp(p, "multifunction", 13)) + acs_on_multifunction = true; + if (!strncmp(p, "id:", 3)) { + char opt[5]; + int ret; + long val; + + if (max_acs_id >= NUM_ACS_IDS - 1) { + pr_warn("Out of PCIe ACS override slots (%d)\n", + NUM_ACS_IDS); + goto next; + } + + p += 3; + snprintf(opt, 5, "%s", p); + ret = kstrtol(opt, 16, &val); + if (ret) { + pr_warn("PCIe ACS ID parse error %d\n", ret); + goto next; + } + acs_on_ids[max_acs_id].vendor = val; + + p += strcspn(p, ":"); + if (*p != ':') { + pr_warn("PCIe ACS invalid ID\n"); + goto next; + } + + p++; + snprintf(opt, 5, "%s", p); + ret = kstrtol(opt, 16, &val); + if (ret) { + pr_warn("PCIe ACS ID parse error %d\n", ret); + goto next; + } + acs_on_ids[max_acs_id].device = val; + max_acs_id++; + } +next: + p += strcspn(p, ","); + if (*p == ',') + p++; + } + + if (acs_on_downstream || acs_on_multifunction || max_acs_id) + pr_warn("Warning: PCIe ACS overrides enabled; This may allow non-IOMMU protected peer-to-peer DMA\n"); + + return 0; +} +early_param("pcie_acs_override", pcie_acs_override_setup); + +static int pcie_acs_overrides(struct pci_dev *dev, u16 acs_flags) +{ + int i; + + /* Never override ACS for legacy devices or devices with ACS caps */ + if (!pci_is_pcie(dev) || + pci_find_ext_capability(dev, PCI_EXT_CAP_ID_ACS)) + return -ENOTTY; + + for (i = 0; i < max_acs_id; i++) + if (acs_on_ids[i].vendor == dev->vendor && + acs_on_ids[i].device == dev->device) + return 1; + + switch (pci_pcie_type(dev)) { + case PCI_EXP_TYPE_DOWNSTREAM: + case PCI_EXP_TYPE_ROOT_PORT: + if (acs_on_downstream) + return 1; + break; + case PCI_EXP_TYPE_ENDPOINT: + case PCI_EXP_TYPE_UPSTREAM: + case PCI_EXP_TYPE_LEG_END: + case PCI_EXP_TYPE_RC_END: + if (acs_on_multifunction && dev->multifunction) + return 1; + } + + return -ENOTTY; +} /* * Some NVIDIA GPU devices do not work with bus reset, SBR needs to be * prevented for those affected devices. @@ -5168,6 +5268,7 @@ static const struct pci_dev_acs_enabled { { PCI_VENDOR_ID_ZHAOXIN, PCI_ANY_ID, pci_quirk_zhaoxin_pcie_ports_acs }, /* Wangxun nics */ { PCI_VENDOR_ID_WANGXUN, PCI_ANY_ID, pci_quirk_wangxun_nic_acs }, + { PCI_ANY_ID, PCI_ANY_ID, pcie_acs_overrides }, { 0 } }; diff --git a/include/linux/pagemap.h b/include/linux/pagemap.h index 68a5f1ff3301..291873a34079 100644 --- a/include/linux/pagemap.h +++ b/include/linux/pagemap.h @@ -1362,7 +1362,7 @@ struct readahead_control { ._index = i, \ } -#define VM_READAHEAD_PAGES (SZ_128K / PAGE_SIZE) +#define VM_READAHEAD_PAGES (SZ_8M / PAGE_SIZE) void page_cache_ra_unbounded(struct readahead_control *, unsigned long nr_to_read, unsigned long lookahead_count); diff --git a/include/linux/user_namespace.h b/include/linux/user_namespace.h index 7183e5aca282..56573371a2f8 100644 --- a/include/linux/user_namespace.h +++ b/include/linux/user_namespace.h @@ -159,6 +159,8 @@ static inline void set_userns_rlimit_max(struct user_namespace *ns, #ifdef CONFIG_USER_NS +extern int unprivileged_userns_clone; + static inline struct user_namespace *get_user_ns(struct user_namespace *ns) { if (ns) @@ -192,6 +194,8 @@ extern bool current_in_userns(const struct user_namespace *target_ns); struct ns_common *ns_get_owner(struct ns_common *ns); #else +#define unprivileged_userns_clone 0 + static inline struct user_namespace *get_user_ns(struct user_namespace *ns) { return &init_user_ns; diff --git a/include/linux/wait.h b/include/linux/wait.h index 8aa3372f21a0..924778a426ce 100644 --- a/include/linux/wait.h +++ b/include/linux/wait.h @@ -163,6 +163,7 @@ static inline bool wq_has_sleeper(struct wait_queue_head *wq_head) extern void add_wait_queue(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry); extern void add_wait_queue_exclusive(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry); +extern void add_wait_queue_exclusive_lifo(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry); extern void add_wait_queue_priority(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry); extern void remove_wait_queue(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry); @@ -1191,6 +1192,7 @@ do { \ */ void prepare_to_wait(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry, int state); bool prepare_to_wait_exclusive(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry, int state); +void prepare_to_wait_exclusive_lifo(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry, int state); long prepare_to_wait_event(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry, int state); void finish_wait(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry); long wait_woken(struct wait_queue_entry *wq_entry, unsigned mode, long timeout); diff --git a/init/Kconfig b/init/Kconfig index c521e1421ad4..38dbd16da6a9 100644 --- a/init/Kconfig +++ b/init/Kconfig @@ -145,6 +145,10 @@ config THREAD_INFO_IN_TASK menu "General setup" +config CACHY + bool "Some kernel tweaks by CachyOS" + default y + config BROKEN bool @@ -1300,6 +1304,22 @@ config USER_NS If unsure, say N. +config USER_NS_UNPRIVILEGED + bool "Allow unprivileged users to create namespaces" + default y + depends on USER_NS + help + When disabled, unprivileged users will not be able to create + new namespaces. Allowing users to create their own namespaces + has been part of several recent local privilege escalation + exploits, so if you need user namespaces but are + paranoid^Wsecurity-conscious you want to disable this. + + This setting can be overridden at runtime via the + kernel.unprivileged_userns_clone sysctl. + + If unsure, say Y. + config PID_NS bool "PID Namespaces" default y @@ -1442,6 +1462,12 @@ config CC_OPTIMIZE_FOR_PERFORMANCE with the "-O2" compiler flag for best performance and most helpful compile-time warnings. +config CC_OPTIMIZE_FOR_PERFORMANCE_O3 + bool "Optimize more for performance (-O3)" + help + Choosing this option will pass "-O3" to your compiler to optimize + the kernel yet more for performance. + config CC_OPTIMIZE_FOR_SIZE bool "Optimize for size (-Os)" help diff --git a/kernel/Kconfig.hz b/kernel/Kconfig.hz index 38ef6d06888e..0f78364efd4f 100644 --- a/kernel/Kconfig.hz +++ b/kernel/Kconfig.hz @@ -40,6 +40,27 @@ choice on SMP and NUMA systems and exactly dividing by both PAL and NTSC frame rates for video and multimedia work. + config HZ_500 + bool "500 HZ" + help + 500 Hz is a balanced timer frequency. Provides fast interactivity + on desktops with good smoothness without increasing CPU power + consumption and sacrificing the battery life on laptops. + + config HZ_600 + bool "600 HZ" + help + 600 Hz is a balanced timer frequency. Provides fast interactivity + on desktops with good smoothness without increasing CPU power + consumption and sacrificing the battery life on laptops. + + config HZ_750 + bool "750 HZ" + help + 750 Hz is a balanced timer frequency. Provides fast interactivity + on desktops with good smoothness without increasing CPU power + consumption and sacrificing the battery life on laptops. + config HZ_1000 bool "1000 HZ" help @@ -53,6 +74,9 @@ config HZ default 100 if HZ_100 default 250 if HZ_250 default 300 if HZ_300 + default 500 if HZ_500 + default 600 if HZ_600 + default 750 if HZ_750 default 1000 if HZ_1000 config SCHED_HRTICK diff --git a/kernel/fork.c b/kernel/fork.c index 22f43721d031..8287afdd01d2 100644 --- a/kernel/fork.c +++ b/kernel/fork.c @@ -107,6 +107,10 @@ #include #include +#ifdef CONFIG_USER_NS +#include +#endif + #include #include #include @@ -2138,6 +2142,10 @@ __latent_entropy struct task_struct *copy_process( if ((clone_flags & (CLONE_NEWUSER|CLONE_FS)) == (CLONE_NEWUSER|CLONE_FS)) return ERR_PTR(-EINVAL); + if ((clone_flags & CLONE_NEWUSER) && !unprivileged_userns_clone) + if (!capable(CAP_SYS_ADMIN)) + return ERR_PTR(-EPERM); + /* * Thread groups must share signals as well, and detached threads * can only be started up within the thread group. @@ -3291,6 +3299,12 @@ int ksys_unshare(unsigned long unshare_flags) if (unshare_flags & CLONE_NEWNS) unshare_flags |= CLONE_FS; + if ((unshare_flags & CLONE_NEWUSER) && !unprivileged_userns_clone) { + err = -EPERM; + if (!capable(CAP_SYS_ADMIN)) + goto bad_unshare_out; + } + err = check_unshare_flags(unshare_flags); if (err) goto bad_unshare_out; diff --git a/kernel/locking/rwsem.c b/kernel/locking/rwsem.c index 2bbb6eca5144..125cdf85741c 100644 --- a/kernel/locking/rwsem.c +++ b/kernel/locking/rwsem.c @@ -747,6 +747,7 @@ rwsem_spin_on_owner(struct rw_semaphore *sem) struct task_struct *new, *owner; unsigned long flags, new_flags; enum owner_state state; + int i = 0; lockdep_assert_preemption_disabled(); @@ -783,7 +784,8 @@ rwsem_spin_on_owner(struct rw_semaphore *sem) break; } - cpu_relax(); + if (i++ > 1000) + cpu_relax(); } return state; diff --git a/kernel/sched/fair.c b/kernel/sched/fair.c index 2d16c8545c71..54e7c4c3e2c5 100644 --- a/kernel/sched/fair.c +++ b/kernel/sched/fair.c @@ -73,10 +73,19 @@ unsigned int sysctl_sched_tunable_scaling = SCHED_TUNABLESCALING_LOG; * * (default: 0.75 msec * (1 + ilog(ncpus)), units: nanoseconds) */ +#ifdef CONFIG_CACHY +unsigned int sysctl_sched_base_slice = 350000ULL; +static unsigned int normalized_sysctl_sched_base_slice = 350000ULL; +#else unsigned int sysctl_sched_base_slice = 750000ULL; static unsigned int normalized_sysctl_sched_base_slice = 750000ULL; +#endif +#ifdef CONFIG_CACHY +const_debug unsigned int sysctl_sched_migration_cost = 300000UL; +#else const_debug unsigned int sysctl_sched_migration_cost = 500000UL; +#endif static int __init setup_sched_thermal_decay_shift(char *str) { @@ -121,8 +130,12 @@ int __weak arch_asym_cpu_priority(int cpu) * * (default: 5 msec, units: microseconds) */ +#ifdef CONFIG_CACHY +static unsigned int sysctl_sched_cfs_bandwidth_slice = 3000UL; +#else static unsigned int sysctl_sched_cfs_bandwidth_slice = 5000UL; #endif +#endif #ifdef CONFIG_NUMA_BALANCING /* Restrict the NUMA promotion throughput (MB/s) for each target node. */ diff --git a/kernel/sched/sched.h b/kernel/sched/sched.h index 6c54a57275cc..f610df2e0811 100644 --- a/kernel/sched/sched.h +++ b/kernel/sched/sched.h @@ -2815,7 +2815,7 @@ extern void deactivate_task(struct rq *rq, struct task_struct *p, int flags); extern void wakeup_preempt(struct rq *rq, struct task_struct *p, int flags); -#ifdef CONFIG_PREEMPT_RT +#if defined(CONFIG_PREEMPT_RT) || defined(CONFIG_CACHY) # define SCHED_NR_MIGRATE_BREAK 8 #else # define SCHED_NR_MIGRATE_BREAK 32 diff --git a/kernel/sched/wait.c b/kernel/sched/wait.c index 51e38f5f4701..c5cc616484ba 100644 --- a/kernel/sched/wait.c +++ b/kernel/sched/wait.c @@ -47,6 +47,17 @@ void add_wait_queue_priority(struct wait_queue_head *wq_head, struct wait_queue_ } EXPORT_SYMBOL_GPL(add_wait_queue_priority); +void add_wait_queue_exclusive_lifo(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry) +{ + unsigned long flags; + + wq_entry->flags |= WQ_FLAG_EXCLUSIVE; + spin_lock_irqsave(&wq_head->lock, flags); + __add_wait_queue(wq_head, wq_entry); + spin_unlock_irqrestore(&wq_head->lock, flags); +} +EXPORT_SYMBOL(add_wait_queue_exclusive_lifo); + void remove_wait_queue(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry) { unsigned long flags; @@ -258,6 +269,19 @@ prepare_to_wait_exclusive(struct wait_queue_head *wq_head, struct wait_queue_ent } EXPORT_SYMBOL(prepare_to_wait_exclusive); +void prepare_to_wait_exclusive_lifo(struct wait_queue_head *wq_head, struct wait_queue_entry *wq_entry, int state) +{ + unsigned long flags; + + wq_entry->flags |= WQ_FLAG_EXCLUSIVE; + spin_lock_irqsave(&wq_head->lock, flags); + if (list_empty(&wq_entry->entry)) + __add_wait_queue(wq_head, wq_entry); + set_current_state(state); + spin_unlock_irqrestore(&wq_head->lock, flags); +} +EXPORT_SYMBOL(prepare_to_wait_exclusive_lifo); + void init_wait_entry(struct wait_queue_entry *wq_entry, int flags) { wq_entry->flags = flags; diff --git a/kernel/sysctl.c b/kernel/sysctl.c index 79e6cb1d5c48..676e89dc38c3 100644 --- a/kernel/sysctl.c +++ b/kernel/sysctl.c @@ -80,6 +80,9 @@ #ifdef CONFIG_RT_MUTEXES #include #endif +#ifdef CONFIG_USER_NS +#include +#endif /* shared constants to be used in various sysctls */ const int sysctl_vals[] = { 0, 1, 2, 3, 4, 100, 200, 1000, 3000, INT_MAX, 65535, -1 }; @@ -1618,6 +1621,15 @@ static struct ctl_table kern_table[] = { .mode = 0644, .proc_handler = proc_dointvec, }, +#ifdef CONFIG_USER_NS + { + .procname = "unprivileged_userns_clone", + .data = &unprivileged_userns_clone, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#endif #ifdef CONFIG_PROC_SYSCTL { .procname = "tainted", diff --git a/kernel/user_namespace.c b/kernel/user_namespace.c index aa0b2e47f2f2..d74d857b1696 100644 --- a/kernel/user_namespace.c +++ b/kernel/user_namespace.c @@ -22,6 +22,13 @@ #include #include +/* sysctl */ +#ifdef CONFIG_USER_NS_UNPRIVILEGED +int unprivileged_userns_clone = 1; +#else +int unprivileged_userns_clone; +#endif + static struct kmem_cache *user_ns_cachep __ro_after_init; static DEFINE_MUTEX(userns_state_mutex); diff --git a/mm/Kconfig b/mm/Kconfig index 33fa51d608dc..6bfea371341e 100644 --- a/mm/Kconfig +++ b/mm/Kconfig @@ -648,7 +648,7 @@ config COMPACTION config COMPACT_UNEVICTABLE_DEFAULT int depends on COMPACTION - default 0 if PREEMPT_RT + default 0 if PREEMPT_RT || CACHY default 1 # diff --git a/mm/compaction.c b/mm/compaction.c index a2b16b08cbbf..48d611e58ad3 100644 --- a/mm/compaction.c +++ b/mm/compaction.c @@ -1920,7 +1920,11 @@ static int sysctl_compact_unevictable_allowed __read_mostly = CONFIG_COMPACT_UNE * aggressively the kernel should compact memory in the * background. It takes values in the range [0, 100]. */ +#ifdef CONFIG_CACHY +static unsigned int __read_mostly sysctl_compaction_proactiveness; +#else static unsigned int __read_mostly sysctl_compaction_proactiveness = 20; +#endif static int sysctl_extfrag_threshold = 500; static int __read_mostly sysctl_compact_memory; diff --git a/mm/page-writeback.c b/mm/page-writeback.c index fcd4c1439cb9..e2f7d709e819 100644 --- a/mm/page-writeback.c +++ b/mm/page-writeback.c @@ -71,7 +71,11 @@ static long ratelimit_pages = 32; /* * Start background writeback (via writeback threads) at this percentage */ +#ifdef CONFIG_CACHY +static int dirty_background_ratio = 5; +#else static int dirty_background_ratio = 10; +#endif /* * dirty_background_bytes starts at 0 (disabled) so that it is a function of @@ -99,7 +103,11 @@ static unsigned long vm_dirty_bytes; /* * The interval between `kupdate'-style writebacks */ +#ifdef CONFIG_CACHY +unsigned int dirty_writeback_interval = 10 * 100; /* centiseconds */ +#else unsigned int dirty_writeback_interval = 5 * 100; /* centiseconds */ +#endif EXPORT_SYMBOL_GPL(dirty_writeback_interval); diff --git a/mm/page_alloc.c b/mm/page_alloc.c index c6c7bb3ea71b..d8ba1df0b5e1 100644 --- a/mm/page_alloc.c +++ b/mm/page_alloc.c @@ -271,7 +271,11 @@ const char * const migratetype_names[MIGRATE_TYPES] = { int min_free_kbytes = 1024; int user_min_free_kbytes = -1; +#ifdef CONFIG_CACHY +static int watermark_boost_factor __read_mostly; +#else static int watermark_boost_factor __read_mostly = 15000; +#endif static int watermark_scale_factor = 10; /* movable_zone is the "real" zone pages in ZONE_MOVABLE are taken from */ diff --git a/mm/swap.c b/mm/swap.c index b8e3259ea2c4..4e7d140d422c 100644 --- a/mm/swap.c +++ b/mm/swap.c @@ -1094,6 +1094,10 @@ void folio_batch_remove_exceptionals(struct folio_batch *fbatch) */ void __init swap_setup(void) { +#ifdef CONFIG_CACHY + /* Only swap-in pages requested, avoid readahead */ + page_cluster = 0; +#else unsigned long megs = totalram_pages() >> (20 - PAGE_SHIFT); /* Use a smaller cluster for small-memory machines */ @@ -1105,4 +1109,5 @@ void __init swap_setup(void) * Right now other parts of the system means that we * _really_ don't want to cluster much more */ +#endif } diff --git a/mm/vmpressure.c b/mm/vmpressure.c index bd5183dfd879..3a410f53a07c 100644 --- a/mm/vmpressure.c +++ b/mm/vmpressure.c @@ -43,7 +43,11 @@ static const unsigned long vmpressure_win = SWAP_CLUSTER_MAX * 16; * essence, they are percents: the higher the value, the more number * unsuccessful reclaims there were. */ +#ifdef CONFIG_CACHY +static const unsigned int vmpressure_level_med = 65; +#else static const unsigned int vmpressure_level_med = 60; +#endif static const unsigned int vmpressure_level_critical = 95; /* diff --git a/mm/vmscan.c b/mm/vmscan.c index 28ba2b06fc7d..99568ccfb0fd 100644 --- a/mm/vmscan.c +++ b/mm/vmscan.c @@ -200,7 +200,11 @@ struct scan_control { /* * From 0 .. MAX_SWAPPINESS. Higher means more swappy. */ +#ifdef CONFIG_CACHY +int vm_swappiness = 20; +#else int vm_swappiness = 60; +#endif #ifdef CONFIG_MEMCG @@ -3992,7 +3996,11 @@ static bool lruvec_is_reclaimable(struct lruvec *lruvec, struct scan_control *sc } /* to protect the working set of the last N jiffies */ +#ifdef CONFIG_CACHY +static unsigned long lru_gen_min_ttl __read_mostly = 1000; +#else static unsigned long lru_gen_min_ttl __read_mostly; +#endif static void lru_gen_age_node(struct pglist_data *pgdat, struct scan_control *sc) { diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c index 2b698f8419fe..fd039c41d1c8 100644 --- a/net/ipv4/inet_connection_sock.c +++ b/net/ipv4/inet_connection_sock.c @@ -634,7 +634,7 @@ static int inet_csk_wait_for_connect(struct sock *sk, long timeo) * having to remove and re-insert us on the wait queue. */ for (;;) { - prepare_to_wait_exclusive(sk_sleep(sk), &wait, + prepare_to_wait_exclusive_lifo(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); release_sock(sk); if (reqsk_queue_empty(&icsk->icsk_accept_queue)) -- 2.47.0 From cb33f67ae0f185239bebf9bd3491e5c671c72df0 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 11 Nov 2024 09:18:19 +0100 Subject: [PATCH 06/13] crypto Signed-off-by: Peter Jung --- arch/x86/crypto/Kconfig | 4 +- arch/x86/crypto/aegis128-aesni-asm.S | 532 ++++++++-------------- arch/x86/crypto/aegis128-aesni-glue.c | 145 +++--- arch/x86/crypto/crc32c-intel_glue.c | 2 +- arch/x86/crypto/crc32c-pcl-intel-asm_64.S | 354 +++++--------- 5 files changed, 387 insertions(+), 650 deletions(-) diff --git a/arch/x86/crypto/Kconfig b/arch/x86/crypto/Kconfig index 7b1bebed879d..3d2e38ba5240 100644 --- a/arch/x86/crypto/Kconfig +++ b/arch/x86/crypto/Kconfig @@ -363,7 +363,7 @@ config CRYPTO_CHACHA20_X86_64 - AVX-512VL (Advanced Vector Extensions-512VL) config CRYPTO_AEGIS128_AESNI_SSE2 - tristate "AEAD ciphers: AEGIS-128 (AES-NI/SSE2)" + tristate "AEAD ciphers: AEGIS-128 (AES-NI/SSE4.1)" depends on X86 && 64BIT select CRYPTO_AEAD select CRYPTO_SIMD @@ -372,7 +372,7 @@ config CRYPTO_AEGIS128_AESNI_SSE2 Architecture: x86_64 using: - AES-NI (AES New Instructions) - - SSE2 (Streaming SIMD Extensions 2) + - SSE4.1 (Streaming SIMD Extensions 4.1) config CRYPTO_NHPOLY1305_SSE2 tristate "Hash functions: NHPoly1305 (SSE2)" diff --git a/arch/x86/crypto/aegis128-aesni-asm.S b/arch/x86/crypto/aegis128-aesni-asm.S index ad7f4c891625..7294dc0ee7ba 100644 --- a/arch/x86/crypto/aegis128-aesni-asm.S +++ b/arch/x86/crypto/aegis128-aesni-asm.S @@ -1,14 +1,13 @@ /* SPDX-License-Identifier: GPL-2.0-only */ /* - * AES-NI + SSE2 implementation of AEGIS-128 + * AES-NI + SSE4.1 implementation of AEGIS-128 * * Copyright (c) 2017-2018 Ondrej Mosnacek * Copyright (C) 2017-2018 Red Hat, Inc. All rights reserved. + * Copyright 2024 Google LLC */ #include -#include -#include #define STATE0 %xmm0 #define STATE1 %xmm1 @@ -20,11 +19,6 @@ #define T0 %xmm6 #define T1 %xmm7 -#define STATEP %rdi -#define LEN %rsi -#define SRC %rdx -#define DST %rcx - .section .rodata.cst16.aegis128_const, "aM", @progbits, 32 .align 16 .Laegis128_const_0: @@ -34,11 +28,11 @@ .byte 0xdb, 0x3d, 0x18, 0x55, 0x6d, 0xc2, 0x2f, 0xf1 .byte 0x20, 0x11, 0x31, 0x42, 0x73, 0xb5, 0x28, 0xdd -.section .rodata.cst16.aegis128_counter, "aM", @progbits, 16 -.align 16 -.Laegis128_counter: - .byte 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 - .byte 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f +.section .rodata.cst32.zeropad_mask, "aM", @progbits, 32 +.align 32 +.Lzeropad_mask: + .octa 0xffffffffffffffffffffffffffffffff + .octa 0 .text @@ -61,140 +55,102 @@ .endm /* - * __load_partial: internal ABI - * input: - * LEN - bytes - * SRC - src - * output: - * MSG - message block - * changed: - * T0 - * %r8 - * %r9 + * Load 1 <= LEN (%ecx) <= 15 bytes from the pointer SRC into the xmm register + * MSG and zeroize any remaining bytes. Clobbers %rax, %rcx, and %r8. */ -SYM_FUNC_START_LOCAL(__load_partial) - xor %r9d, %r9d - pxor MSG, MSG - - mov LEN, %r8 - and $0x1, %r8 - jz .Lld_partial_1 - - mov LEN, %r8 - and $0x1E, %r8 - add SRC, %r8 - mov (%r8), %r9b - -.Lld_partial_1: - mov LEN, %r8 - and $0x2, %r8 - jz .Lld_partial_2 - - mov LEN, %r8 - and $0x1C, %r8 - add SRC, %r8 - shl $0x10, %r9 - mov (%r8), %r9w - -.Lld_partial_2: - mov LEN, %r8 - and $0x4, %r8 - jz .Lld_partial_4 - - mov LEN, %r8 - and $0x18, %r8 - add SRC, %r8 - shl $32, %r9 - mov (%r8), %r8d - xor %r8, %r9 - -.Lld_partial_4: - movq %r9, MSG - - mov LEN, %r8 - and $0x8, %r8 - jz .Lld_partial_8 - - mov LEN, %r8 - and $0x10, %r8 - add SRC, %r8 - pslldq $8, MSG - movq (%r8), T0 - pxor T0, MSG - -.Lld_partial_8: - RET -SYM_FUNC_END(__load_partial) +.macro load_partial + sub $8, %ecx /* LEN - 8 */ + jle .Lle8\@ + + /* Load 9 <= LEN <= 15 bytes: */ + movq (SRC), MSG /* Load first 8 bytes */ + mov (SRC, %rcx), %rax /* Load last 8 bytes */ + neg %ecx + shl $3, %ecx + shr %cl, %rax /* Discard overlapping bytes */ + pinsrq $1, %rax, MSG + jmp .Ldone\@ + +.Lle8\@: + add $4, %ecx /* LEN - 4 */ + jl .Llt4\@ + + /* Load 4 <= LEN <= 8 bytes: */ + mov (SRC), %eax /* Load first 4 bytes */ + mov (SRC, %rcx), %r8d /* Load last 4 bytes */ + jmp .Lcombine\@ + +.Llt4\@: + /* Load 1 <= LEN <= 3 bytes: */ + add $2, %ecx /* LEN - 2 */ + movzbl (SRC), %eax /* Load first byte */ + jl .Lmovq\@ + movzwl (SRC, %rcx), %r8d /* Load last 2 bytes */ +.Lcombine\@: + shl $3, %ecx + shl %cl, %r8 + or %r8, %rax /* Combine the two parts */ +.Lmovq\@: + movq %rax, MSG +.Ldone\@: +.endm /* - * __store_partial: internal ABI - * input: - * LEN - bytes - * DST - dst - * output: - * T0 - message block - * changed: - * %r8 - * %r9 - * %r10 + * Store 1 <= LEN (%ecx) <= 15 bytes from the xmm register \msg to the pointer + * DST. Clobbers %rax, %rcx, and %r8. */ -SYM_FUNC_START_LOCAL(__store_partial) - mov LEN, %r8 - mov DST, %r9 - - movq T0, %r10 - - cmp $8, %r8 - jl .Lst_partial_8 - - mov %r10, (%r9) - psrldq $8, T0 - movq T0, %r10 - - sub $8, %r8 - add $8, %r9 - -.Lst_partial_8: - cmp $4, %r8 - jl .Lst_partial_4 - - mov %r10d, (%r9) - shr $32, %r10 - - sub $4, %r8 - add $4, %r9 - -.Lst_partial_4: - cmp $2, %r8 - jl .Lst_partial_2 - - mov %r10w, (%r9) - shr $0x10, %r10 - - sub $2, %r8 - add $2, %r9 - -.Lst_partial_2: - cmp $1, %r8 - jl .Lst_partial_1 - - mov %r10b, (%r9) - -.Lst_partial_1: - RET -SYM_FUNC_END(__store_partial) +.macro store_partial msg + sub $8, %ecx /* LEN - 8 */ + jl .Llt8\@ + + /* Store 8 <= LEN <= 15 bytes: */ + pextrq $1, \msg, %rax + mov %ecx, %r8d + shl $3, %ecx + ror %cl, %rax + mov %rax, (DST, %r8) /* Store last LEN - 8 bytes */ + movq \msg, (DST) /* Store first 8 bytes */ + jmp .Ldone\@ + +.Llt8\@: + add $4, %ecx /* LEN - 4 */ + jl .Llt4\@ + + /* Store 4 <= LEN <= 7 bytes: */ + pextrd $1, \msg, %eax + mov %ecx, %r8d + shl $3, %ecx + ror %cl, %eax + mov %eax, (DST, %r8) /* Store last LEN - 4 bytes */ + movd \msg, (DST) /* Store first 4 bytes */ + jmp .Ldone\@ + +.Llt4\@: + /* Store 1 <= LEN <= 3 bytes: */ + pextrb $0, \msg, 0(DST) + cmp $-2, %ecx /* LEN - 4 == -2, i.e. LEN == 2? */ + jl .Ldone\@ + pextrb $1, \msg, 1(DST) + je .Ldone\@ + pextrb $2, \msg, 2(DST) +.Ldone\@: +.endm /* - * void crypto_aegis128_aesni_init(void *state, const void *key, const void *iv); + * void aegis128_aesni_init(struct aegis_state *state, + * const struct aegis_block *key, + * const u8 iv[AEGIS128_NONCE_SIZE]); */ -SYM_FUNC_START(crypto_aegis128_aesni_init) - FRAME_BEGIN +SYM_FUNC_START(aegis128_aesni_init) + .set STATEP, %rdi + .set KEYP, %rsi + .set IVP, %rdx /* load IV: */ - movdqu (%rdx), T1 + movdqu (IVP), T1 /* load key: */ - movdqa (%rsi), KEY + movdqa (KEYP), KEY pxor KEY, T1 movdqa T1, STATE0 movdqa KEY, STATE3 @@ -224,20 +180,22 @@ SYM_FUNC_START(crypto_aegis128_aesni_init) movdqu STATE2, 0x20(STATEP) movdqu STATE3, 0x30(STATEP) movdqu STATE4, 0x40(STATEP) - - FRAME_END RET -SYM_FUNC_END(crypto_aegis128_aesni_init) +SYM_FUNC_END(aegis128_aesni_init) /* - * void crypto_aegis128_aesni_ad(void *state, unsigned int length, - * const void *data); + * void aegis128_aesni_ad(struct aegis_state *state, const u8 *data, + * unsigned int len); + * + * len must be a multiple of 16. */ -SYM_FUNC_START(crypto_aegis128_aesni_ad) - FRAME_BEGIN +SYM_FUNC_START(aegis128_aesni_ad) + .set STATEP, %rdi + .set SRC, %rsi + .set LEN, %edx - cmp $0x10, LEN - jb .Lad_out + test LEN, LEN + jz .Lad_out /* load the state: */ movdqu 0x00(STATEP), STATE0 @@ -246,89 +204,40 @@ SYM_FUNC_START(crypto_aegis128_aesni_ad) movdqu 0x30(STATEP), STATE3 movdqu 0x40(STATEP), STATE4 - mov SRC, %r8 - and $0xF, %r8 - jnz .Lad_u_loop - -.align 8 -.Lad_a_loop: - movdqa 0x00(SRC), MSG - aegis128_update - pxor MSG, STATE4 - sub $0x10, LEN - cmp $0x10, LEN - jl .Lad_out_1 - - movdqa 0x10(SRC), MSG - aegis128_update - pxor MSG, STATE3 - sub $0x10, LEN - cmp $0x10, LEN - jl .Lad_out_2 - - movdqa 0x20(SRC), MSG - aegis128_update - pxor MSG, STATE2 - sub $0x10, LEN - cmp $0x10, LEN - jl .Lad_out_3 - - movdqa 0x30(SRC), MSG - aegis128_update - pxor MSG, STATE1 - sub $0x10, LEN - cmp $0x10, LEN - jl .Lad_out_4 - - movdqa 0x40(SRC), MSG - aegis128_update - pxor MSG, STATE0 - sub $0x10, LEN - cmp $0x10, LEN - jl .Lad_out_0 - - add $0x50, SRC - jmp .Lad_a_loop - .align 8 -.Lad_u_loop: +.Lad_loop: movdqu 0x00(SRC), MSG aegis128_update pxor MSG, STATE4 sub $0x10, LEN - cmp $0x10, LEN - jl .Lad_out_1 + jz .Lad_out_1 movdqu 0x10(SRC), MSG aegis128_update pxor MSG, STATE3 sub $0x10, LEN - cmp $0x10, LEN - jl .Lad_out_2 + jz .Lad_out_2 movdqu 0x20(SRC), MSG aegis128_update pxor MSG, STATE2 sub $0x10, LEN - cmp $0x10, LEN - jl .Lad_out_3 + jz .Lad_out_3 movdqu 0x30(SRC), MSG aegis128_update pxor MSG, STATE1 sub $0x10, LEN - cmp $0x10, LEN - jl .Lad_out_4 + jz .Lad_out_4 movdqu 0x40(SRC), MSG aegis128_update pxor MSG, STATE0 sub $0x10, LEN - cmp $0x10, LEN - jl .Lad_out_0 + jz .Lad_out_0 add $0x50, SRC - jmp .Lad_u_loop + jmp .Lad_loop /* store the state: */ .Lad_out_0: @@ -337,7 +246,6 @@ SYM_FUNC_START(crypto_aegis128_aesni_ad) movdqu STATE2, 0x20(STATEP) movdqu STATE3, 0x30(STATEP) movdqu STATE4, 0x40(STATEP) - FRAME_END RET .Lad_out_1: @@ -346,7 +254,6 @@ SYM_FUNC_START(crypto_aegis128_aesni_ad) movdqu STATE1, 0x20(STATEP) movdqu STATE2, 0x30(STATEP) movdqu STATE3, 0x40(STATEP) - FRAME_END RET .Lad_out_2: @@ -355,7 +262,6 @@ SYM_FUNC_START(crypto_aegis128_aesni_ad) movdqu STATE0, 0x20(STATEP) movdqu STATE1, 0x30(STATEP) movdqu STATE2, 0x40(STATEP) - FRAME_END RET .Lad_out_3: @@ -364,7 +270,6 @@ SYM_FUNC_START(crypto_aegis128_aesni_ad) movdqu STATE4, 0x20(STATEP) movdqu STATE0, 0x30(STATEP) movdqu STATE1, 0x40(STATEP) - FRAME_END RET .Lad_out_4: @@ -373,41 +278,38 @@ SYM_FUNC_START(crypto_aegis128_aesni_ad) movdqu STATE3, 0x20(STATEP) movdqu STATE4, 0x30(STATEP) movdqu STATE0, 0x40(STATEP) - FRAME_END - RET - .Lad_out: - FRAME_END RET -SYM_FUNC_END(crypto_aegis128_aesni_ad) +SYM_FUNC_END(aegis128_aesni_ad) -.macro encrypt_block a s0 s1 s2 s3 s4 i - movdq\a (\i * 0x10)(SRC), MSG +.macro encrypt_block s0 s1 s2 s3 s4 i + movdqu (\i * 0x10)(SRC), MSG movdqa MSG, T0 pxor \s1, T0 pxor \s4, T0 movdqa \s2, T1 pand \s3, T1 pxor T1, T0 - movdq\a T0, (\i * 0x10)(DST) + movdqu T0, (\i * 0x10)(DST) aegis128_update pxor MSG, \s4 sub $0x10, LEN - cmp $0x10, LEN - jl .Lenc_out_\i + jz .Lenc_out_\i .endm /* - * void crypto_aegis128_aesni_enc(void *state, unsigned int length, - * const void *src, void *dst); + * void aegis128_aesni_enc(struct aegis_state *state, const u8 *src, u8 *dst, + * unsigned int len); + * + * len must be nonzero and a multiple of 16. */ -SYM_TYPED_FUNC_START(crypto_aegis128_aesni_enc) - FRAME_BEGIN - - cmp $0x10, LEN - jb .Lenc_out +SYM_FUNC_START(aegis128_aesni_enc) + .set STATEP, %rdi + .set SRC, %rsi + .set DST, %rdx + .set LEN, %ecx /* load the state: */ movdqu 0x00(STATEP), STATE0 @@ -416,34 +318,17 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_enc) movdqu 0x30(STATEP), STATE3 movdqu 0x40(STATEP), STATE4 - mov SRC, %r8 - or DST, %r8 - and $0xF, %r8 - jnz .Lenc_u_loop - .align 8 -.Lenc_a_loop: - encrypt_block a STATE0 STATE1 STATE2 STATE3 STATE4 0 - encrypt_block a STATE4 STATE0 STATE1 STATE2 STATE3 1 - encrypt_block a STATE3 STATE4 STATE0 STATE1 STATE2 2 - encrypt_block a STATE2 STATE3 STATE4 STATE0 STATE1 3 - encrypt_block a STATE1 STATE2 STATE3 STATE4 STATE0 4 +.Lenc_loop: + encrypt_block STATE0 STATE1 STATE2 STATE3 STATE4 0 + encrypt_block STATE4 STATE0 STATE1 STATE2 STATE3 1 + encrypt_block STATE3 STATE4 STATE0 STATE1 STATE2 2 + encrypt_block STATE2 STATE3 STATE4 STATE0 STATE1 3 + encrypt_block STATE1 STATE2 STATE3 STATE4 STATE0 4 add $0x50, SRC add $0x50, DST - jmp .Lenc_a_loop - -.align 8 -.Lenc_u_loop: - encrypt_block u STATE0 STATE1 STATE2 STATE3 STATE4 0 - encrypt_block u STATE4 STATE0 STATE1 STATE2 STATE3 1 - encrypt_block u STATE3 STATE4 STATE0 STATE1 STATE2 2 - encrypt_block u STATE2 STATE3 STATE4 STATE0 STATE1 3 - encrypt_block u STATE1 STATE2 STATE3 STATE4 STATE0 4 - - add $0x50, SRC - add $0x50, DST - jmp .Lenc_u_loop + jmp .Lenc_loop /* store the state: */ .Lenc_out_0: @@ -452,7 +337,6 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_enc) movdqu STATE1, 0x20(STATEP) movdqu STATE2, 0x30(STATEP) movdqu STATE3, 0x40(STATEP) - FRAME_END RET .Lenc_out_1: @@ -461,7 +345,6 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_enc) movdqu STATE0, 0x20(STATEP) movdqu STATE1, 0x30(STATEP) movdqu STATE2, 0x40(STATEP) - FRAME_END RET .Lenc_out_2: @@ -470,7 +353,6 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_enc) movdqu STATE4, 0x20(STATEP) movdqu STATE0, 0x30(STATEP) movdqu STATE1, 0x40(STATEP) - FRAME_END RET .Lenc_out_3: @@ -479,7 +361,6 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_enc) movdqu STATE3, 0x20(STATEP) movdqu STATE4, 0x30(STATEP) movdqu STATE0, 0x40(STATEP) - FRAME_END RET .Lenc_out_4: @@ -488,20 +369,19 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_enc) movdqu STATE2, 0x20(STATEP) movdqu STATE3, 0x30(STATEP) movdqu STATE4, 0x40(STATEP) - FRAME_END - RET - .Lenc_out: - FRAME_END RET -SYM_FUNC_END(crypto_aegis128_aesni_enc) +SYM_FUNC_END(aegis128_aesni_enc) /* - * void crypto_aegis128_aesni_enc_tail(void *state, unsigned int length, - * const void *src, void *dst); + * void aegis128_aesni_enc_tail(struct aegis_state *state, const u8 *src, + * u8 *dst, unsigned int len); */ -SYM_TYPED_FUNC_START(crypto_aegis128_aesni_enc_tail) - FRAME_BEGIN +SYM_FUNC_START(aegis128_aesni_enc_tail) + .set STATEP, %rdi + .set SRC, %rsi + .set DST, %rdx + .set LEN, %ecx /* {load,store}_partial rely on this being %ecx */ /* load the state: */ movdqu 0x00(STATEP), STATE0 @@ -511,7 +391,8 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_enc_tail) movdqu 0x40(STATEP), STATE4 /* encrypt message: */ - call __load_partial + mov LEN, %r9d + load_partial movdqa MSG, T0 pxor STATE1, T0 @@ -520,7 +401,8 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_enc_tail) pand STATE3, T1 pxor T1, T0 - call __store_partial + mov %r9d, LEN + store_partial T0 aegis128_update pxor MSG, STATE4 @@ -531,37 +413,36 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_enc_tail) movdqu STATE1, 0x20(STATEP) movdqu STATE2, 0x30(STATEP) movdqu STATE3, 0x40(STATEP) - - FRAME_END RET -SYM_FUNC_END(crypto_aegis128_aesni_enc_tail) +SYM_FUNC_END(aegis128_aesni_enc_tail) -.macro decrypt_block a s0 s1 s2 s3 s4 i - movdq\a (\i * 0x10)(SRC), MSG +.macro decrypt_block s0 s1 s2 s3 s4 i + movdqu (\i * 0x10)(SRC), MSG pxor \s1, MSG pxor \s4, MSG movdqa \s2, T1 pand \s3, T1 pxor T1, MSG - movdq\a MSG, (\i * 0x10)(DST) + movdqu MSG, (\i * 0x10)(DST) aegis128_update pxor MSG, \s4 sub $0x10, LEN - cmp $0x10, LEN - jl .Ldec_out_\i + jz .Ldec_out_\i .endm /* - * void crypto_aegis128_aesni_dec(void *state, unsigned int length, - * const void *src, void *dst); + * void aegis128_aesni_dec(struct aegis_state *state, const u8 *src, u8 *dst, + * unsigned int len); + * + * len must be nonzero and a multiple of 16. */ -SYM_TYPED_FUNC_START(crypto_aegis128_aesni_dec) - FRAME_BEGIN - - cmp $0x10, LEN - jb .Ldec_out +SYM_FUNC_START(aegis128_aesni_dec) + .set STATEP, %rdi + .set SRC, %rsi + .set DST, %rdx + .set LEN, %ecx /* load the state: */ movdqu 0x00(STATEP), STATE0 @@ -570,34 +451,17 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_dec) movdqu 0x30(STATEP), STATE3 movdqu 0x40(STATEP), STATE4 - mov SRC, %r8 - or DST, %r8 - and $0xF, %r8 - jnz .Ldec_u_loop - .align 8 -.Ldec_a_loop: - decrypt_block a STATE0 STATE1 STATE2 STATE3 STATE4 0 - decrypt_block a STATE4 STATE0 STATE1 STATE2 STATE3 1 - decrypt_block a STATE3 STATE4 STATE0 STATE1 STATE2 2 - decrypt_block a STATE2 STATE3 STATE4 STATE0 STATE1 3 - decrypt_block a STATE1 STATE2 STATE3 STATE4 STATE0 4 +.Ldec_loop: + decrypt_block STATE0 STATE1 STATE2 STATE3 STATE4 0 + decrypt_block STATE4 STATE0 STATE1 STATE2 STATE3 1 + decrypt_block STATE3 STATE4 STATE0 STATE1 STATE2 2 + decrypt_block STATE2 STATE3 STATE4 STATE0 STATE1 3 + decrypt_block STATE1 STATE2 STATE3 STATE4 STATE0 4 add $0x50, SRC add $0x50, DST - jmp .Ldec_a_loop - -.align 8 -.Ldec_u_loop: - decrypt_block u STATE0 STATE1 STATE2 STATE3 STATE4 0 - decrypt_block u STATE4 STATE0 STATE1 STATE2 STATE3 1 - decrypt_block u STATE3 STATE4 STATE0 STATE1 STATE2 2 - decrypt_block u STATE2 STATE3 STATE4 STATE0 STATE1 3 - decrypt_block u STATE1 STATE2 STATE3 STATE4 STATE0 4 - - add $0x50, SRC - add $0x50, DST - jmp .Ldec_u_loop + jmp .Ldec_loop /* store the state: */ .Ldec_out_0: @@ -606,7 +470,6 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_dec) movdqu STATE1, 0x20(STATEP) movdqu STATE2, 0x30(STATEP) movdqu STATE3, 0x40(STATEP) - FRAME_END RET .Ldec_out_1: @@ -615,7 +478,6 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_dec) movdqu STATE0, 0x20(STATEP) movdqu STATE1, 0x30(STATEP) movdqu STATE2, 0x40(STATEP) - FRAME_END RET .Ldec_out_2: @@ -624,7 +486,6 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_dec) movdqu STATE4, 0x20(STATEP) movdqu STATE0, 0x30(STATEP) movdqu STATE1, 0x40(STATEP) - FRAME_END RET .Ldec_out_3: @@ -633,7 +494,6 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_dec) movdqu STATE3, 0x20(STATEP) movdqu STATE4, 0x30(STATEP) movdqu STATE0, 0x40(STATEP) - FRAME_END RET .Ldec_out_4: @@ -642,20 +502,19 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_dec) movdqu STATE2, 0x20(STATEP) movdqu STATE3, 0x30(STATEP) movdqu STATE4, 0x40(STATEP) - FRAME_END - RET - .Ldec_out: - FRAME_END RET -SYM_FUNC_END(crypto_aegis128_aesni_dec) +SYM_FUNC_END(aegis128_aesni_dec) /* - * void crypto_aegis128_aesni_dec_tail(void *state, unsigned int length, - * const void *src, void *dst); + * void aegis128_aesni_dec_tail(struct aegis_state *state, const u8 *src, + * u8 *dst, unsigned int len); */ -SYM_TYPED_FUNC_START(crypto_aegis128_aesni_dec_tail) - FRAME_BEGIN +SYM_FUNC_START(aegis128_aesni_dec_tail) + .set STATEP, %rdi + .set SRC, %rsi + .set DST, %rdx + .set LEN, %ecx /* {load,store}_partial rely on this being %ecx */ /* load the state: */ movdqu 0x00(STATEP), STATE0 @@ -665,7 +524,8 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_dec_tail) movdqu 0x40(STATEP), STATE4 /* decrypt message: */ - call __load_partial + mov LEN, %r9d + load_partial pxor STATE1, MSG pxor STATE4, MSG @@ -673,17 +533,13 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_dec_tail) pand STATE3, T1 pxor T1, MSG - movdqa MSG, T0 - call __store_partial + mov %r9d, LEN + store_partial MSG /* mask with byte count: */ - movq LEN, T0 - punpcklbw T0, T0 - punpcklbw T0, T0 - punpcklbw T0, T0 - punpcklbw T0, T0 - movdqa .Laegis128_counter(%rip), T1 - pcmpgtb T1, T0 + lea .Lzeropad_mask+16(%rip), %rax + sub %r9, %rax + movdqu (%rax), T0 pand T0, MSG aegis128_update @@ -695,17 +551,19 @@ SYM_TYPED_FUNC_START(crypto_aegis128_aesni_dec_tail) movdqu STATE1, 0x20(STATEP) movdqu STATE2, 0x30(STATEP) movdqu STATE3, 0x40(STATEP) - - FRAME_END RET -SYM_FUNC_END(crypto_aegis128_aesni_dec_tail) +SYM_FUNC_END(aegis128_aesni_dec_tail) /* - * void crypto_aegis128_aesni_final(void *state, void *tag_xor, - * u64 assoclen, u64 cryptlen); + * void aegis128_aesni_final(struct aegis_state *state, + * struct aegis_block *tag_xor, + * unsigned int assoclen, unsigned int cryptlen); */ -SYM_FUNC_START(crypto_aegis128_aesni_final) - FRAME_BEGIN +SYM_FUNC_START(aegis128_aesni_final) + .set STATEP, %rdi + .set TAG_XOR, %rsi + .set ASSOCLEN, %edx + .set CRYPTLEN, %ecx /* load the state: */ movdqu 0x00(STATEP), STATE0 @@ -715,10 +573,8 @@ SYM_FUNC_START(crypto_aegis128_aesni_final) movdqu 0x40(STATEP), STATE4 /* prepare length block: */ - movq %rdx, MSG - movq %rcx, T0 - pslldq $8, T0 - pxor T0, MSG + movd ASSOCLEN, MSG + pinsrd $2, CRYPTLEN, MSG psllq $3, MSG /* multiply by 8 (to get bit count) */ pxor STATE3, MSG @@ -733,7 +589,7 @@ SYM_FUNC_START(crypto_aegis128_aesni_final) aegis128_update; pxor MSG, STATE3 /* xor tag: */ - movdqu (%rsi), MSG + movdqu (TAG_XOR), MSG pxor STATE0, MSG pxor STATE1, MSG @@ -741,8 +597,6 @@ SYM_FUNC_START(crypto_aegis128_aesni_final) pxor STATE3, MSG pxor STATE4, MSG - movdqu MSG, (%rsi) - - FRAME_END + movdqu MSG, (TAG_XOR) RET -SYM_FUNC_END(crypto_aegis128_aesni_final) +SYM_FUNC_END(aegis128_aesni_final) diff --git a/arch/x86/crypto/aegis128-aesni-glue.c b/arch/x86/crypto/aegis128-aesni-glue.c index 4623189000d8..c19d8e3d96a3 100644 --- a/arch/x86/crypto/aegis128-aesni-glue.c +++ b/arch/x86/crypto/aegis128-aesni-glue.c @@ -1,7 +1,7 @@ // SPDX-License-Identifier: GPL-2.0-or-later /* * The AEGIS-128 Authenticated-Encryption Algorithm - * Glue for AES-NI + SSE2 implementation + * Glue for AES-NI + SSE4.1 implementation * * Copyright (c) 2017-2018 Ondrej Mosnacek * Copyright (C) 2017-2018 Red Hat, Inc. All rights reserved. @@ -23,27 +23,6 @@ #define AEGIS128_MIN_AUTH_SIZE 8 #define AEGIS128_MAX_AUTH_SIZE 16 -asmlinkage void crypto_aegis128_aesni_init(void *state, void *key, void *iv); - -asmlinkage void crypto_aegis128_aesni_ad( - void *state, unsigned int length, const void *data); - -asmlinkage void crypto_aegis128_aesni_enc( - void *state, unsigned int length, const void *src, void *dst); - -asmlinkage void crypto_aegis128_aesni_dec( - void *state, unsigned int length, const void *src, void *dst); - -asmlinkage void crypto_aegis128_aesni_enc_tail( - void *state, unsigned int length, const void *src, void *dst); - -asmlinkage void crypto_aegis128_aesni_dec_tail( - void *state, unsigned int length, const void *src, void *dst); - -asmlinkage void crypto_aegis128_aesni_final( - void *state, void *tag_xor, unsigned int cryptlen, - unsigned int assoclen); - struct aegis_block { u8 bytes[AEGIS128_BLOCK_SIZE] __aligned(AEGIS128_BLOCK_ALIGN); }; @@ -56,15 +35,31 @@ struct aegis_ctx { struct aegis_block key; }; -struct aegis_crypt_ops { - int (*skcipher_walk_init)(struct skcipher_walk *walk, - struct aead_request *req, bool atomic); +asmlinkage void aegis128_aesni_init(struct aegis_state *state, + const struct aegis_block *key, + const u8 iv[AEGIS128_NONCE_SIZE]); - void (*crypt_blocks)(void *state, unsigned int length, const void *src, - void *dst); - void (*crypt_tail)(void *state, unsigned int length, const void *src, - void *dst); -}; +asmlinkage void aegis128_aesni_ad(struct aegis_state *state, const u8 *data, + unsigned int len); + +asmlinkage void aegis128_aesni_enc(struct aegis_state *state, const u8 *src, + u8 *dst, unsigned int len); + +asmlinkage void aegis128_aesni_dec(struct aegis_state *state, const u8 *src, + u8 *dst, unsigned int len); + +asmlinkage void aegis128_aesni_enc_tail(struct aegis_state *state, + const u8 *src, u8 *dst, + unsigned int len); + +asmlinkage void aegis128_aesni_dec_tail(struct aegis_state *state, + const u8 *src, u8 *dst, + unsigned int len); + +asmlinkage void aegis128_aesni_final(struct aegis_state *state, + struct aegis_block *tag_xor, + unsigned int assoclen, + unsigned int cryptlen); static void crypto_aegis128_aesni_process_ad( struct aegis_state *state, struct scatterlist *sg_src, @@ -85,16 +80,15 @@ static void crypto_aegis128_aesni_process_ad( if (pos > 0) { unsigned int fill = AEGIS128_BLOCK_SIZE - pos; memcpy(buf.bytes + pos, src, fill); - crypto_aegis128_aesni_ad(state, - AEGIS128_BLOCK_SIZE, - buf.bytes); + aegis128_aesni_ad(state, buf.bytes, + AEGIS128_BLOCK_SIZE); pos = 0; left -= fill; src += fill; } - crypto_aegis128_aesni_ad(state, left, src); - + aegis128_aesni_ad(state, src, + left & ~(AEGIS128_BLOCK_SIZE - 1)); src += left & ~(AEGIS128_BLOCK_SIZE - 1); left &= AEGIS128_BLOCK_SIZE - 1; } @@ -110,24 +104,37 @@ static void crypto_aegis128_aesni_process_ad( if (pos > 0) { memset(buf.bytes + pos, 0, AEGIS128_BLOCK_SIZE - pos); - crypto_aegis128_aesni_ad(state, AEGIS128_BLOCK_SIZE, buf.bytes); + aegis128_aesni_ad(state, buf.bytes, AEGIS128_BLOCK_SIZE); } } -static void crypto_aegis128_aesni_process_crypt( - struct aegis_state *state, struct skcipher_walk *walk, - const struct aegis_crypt_ops *ops) +static __always_inline void +crypto_aegis128_aesni_process_crypt(struct aegis_state *state, + struct skcipher_walk *walk, bool enc) { while (walk->nbytes >= AEGIS128_BLOCK_SIZE) { - ops->crypt_blocks(state, - round_down(walk->nbytes, AEGIS128_BLOCK_SIZE), - walk->src.virt.addr, walk->dst.virt.addr); + if (enc) + aegis128_aesni_enc(state, walk->src.virt.addr, + walk->dst.virt.addr, + round_down(walk->nbytes, + AEGIS128_BLOCK_SIZE)); + else + aegis128_aesni_dec(state, walk->src.virt.addr, + walk->dst.virt.addr, + round_down(walk->nbytes, + AEGIS128_BLOCK_SIZE)); skcipher_walk_done(walk, walk->nbytes % AEGIS128_BLOCK_SIZE); } if (walk->nbytes) { - ops->crypt_tail(state, walk->nbytes, walk->src.virt.addr, - walk->dst.virt.addr); + if (enc) + aegis128_aesni_enc_tail(state, walk->src.virt.addr, + walk->dst.virt.addr, + walk->nbytes); + else + aegis128_aesni_dec_tail(state, walk->src.virt.addr, + walk->dst.virt.addr, + walk->nbytes); skcipher_walk_done(walk, 0); } } @@ -162,42 +169,39 @@ static int crypto_aegis128_aesni_setauthsize(struct crypto_aead *tfm, return 0; } -static void crypto_aegis128_aesni_crypt(struct aead_request *req, - struct aegis_block *tag_xor, - unsigned int cryptlen, - const struct aegis_crypt_ops *ops) +static __always_inline void +crypto_aegis128_aesni_crypt(struct aead_request *req, + struct aegis_block *tag_xor, + unsigned int cryptlen, bool enc) { struct crypto_aead *tfm = crypto_aead_reqtfm(req); struct aegis_ctx *ctx = crypto_aegis128_aesni_ctx(tfm); struct skcipher_walk walk; struct aegis_state state; - ops->skcipher_walk_init(&walk, req, true); + if (enc) + skcipher_walk_aead_encrypt(&walk, req, true); + else + skcipher_walk_aead_decrypt(&walk, req, true); kernel_fpu_begin(); - crypto_aegis128_aesni_init(&state, ctx->key.bytes, req->iv); + aegis128_aesni_init(&state, &ctx->key, req->iv); crypto_aegis128_aesni_process_ad(&state, req->src, req->assoclen); - crypto_aegis128_aesni_process_crypt(&state, &walk, ops); - crypto_aegis128_aesni_final(&state, tag_xor, req->assoclen, cryptlen); + crypto_aegis128_aesni_process_crypt(&state, &walk, enc); + aegis128_aesni_final(&state, tag_xor, req->assoclen, cryptlen); kernel_fpu_end(); } static int crypto_aegis128_aesni_encrypt(struct aead_request *req) { - static const struct aegis_crypt_ops OPS = { - .skcipher_walk_init = skcipher_walk_aead_encrypt, - .crypt_blocks = crypto_aegis128_aesni_enc, - .crypt_tail = crypto_aegis128_aesni_enc_tail, - }; - struct crypto_aead *tfm = crypto_aead_reqtfm(req); struct aegis_block tag = {}; unsigned int authsize = crypto_aead_authsize(tfm); unsigned int cryptlen = req->cryptlen; - crypto_aegis128_aesni_crypt(req, &tag, cryptlen, &OPS); + crypto_aegis128_aesni_crypt(req, &tag, cryptlen, true); scatterwalk_map_and_copy(tag.bytes, req->dst, req->assoclen + cryptlen, authsize, 1); @@ -208,12 +212,6 @@ static int crypto_aegis128_aesni_decrypt(struct aead_request *req) { static const struct aegis_block zeros = {}; - static const struct aegis_crypt_ops OPS = { - .skcipher_walk_init = skcipher_walk_aead_decrypt, - .crypt_blocks = crypto_aegis128_aesni_dec, - .crypt_tail = crypto_aegis128_aesni_dec_tail, - }; - struct crypto_aead *tfm = crypto_aead_reqtfm(req); struct aegis_block tag; unsigned int authsize = crypto_aead_authsize(tfm); @@ -222,27 +220,16 @@ static int crypto_aegis128_aesni_decrypt(struct aead_request *req) scatterwalk_map_and_copy(tag.bytes, req->src, req->assoclen + cryptlen, authsize, 0); - crypto_aegis128_aesni_crypt(req, &tag, cryptlen, &OPS); + crypto_aegis128_aesni_crypt(req, &tag, cryptlen, false); return crypto_memneq(tag.bytes, zeros.bytes, authsize) ? -EBADMSG : 0; } -static int crypto_aegis128_aesni_init_tfm(struct crypto_aead *aead) -{ - return 0; -} - -static void crypto_aegis128_aesni_exit_tfm(struct crypto_aead *aead) -{ -} - static struct aead_alg crypto_aegis128_aesni_alg = { .setkey = crypto_aegis128_aesni_setkey, .setauthsize = crypto_aegis128_aesni_setauthsize, .encrypt = crypto_aegis128_aesni_encrypt, .decrypt = crypto_aegis128_aesni_decrypt, - .init = crypto_aegis128_aesni_init_tfm, - .exit = crypto_aegis128_aesni_exit_tfm, .ivsize = AEGIS128_NONCE_SIZE, .maxauthsize = AEGIS128_MAX_AUTH_SIZE, @@ -267,7 +254,7 @@ static struct simd_aead_alg *simd_alg; static int __init crypto_aegis128_aesni_module_init(void) { - if (!boot_cpu_has(X86_FEATURE_XMM2) || + if (!boot_cpu_has(X86_FEATURE_XMM4_1) || !boot_cpu_has(X86_FEATURE_AES) || !cpu_has_xfeatures(XFEATURE_MASK_SSE, NULL)) return -ENODEV; @@ -286,6 +273,6 @@ module_exit(crypto_aegis128_aesni_module_exit); MODULE_LICENSE("GPL"); MODULE_AUTHOR("Ondrej Mosnacek "); -MODULE_DESCRIPTION("AEGIS-128 AEAD algorithm -- AESNI+SSE2 implementation"); +MODULE_DESCRIPTION("AEGIS-128 AEAD algorithm -- AESNI+SSE4.1 implementation"); MODULE_ALIAS_CRYPTO("aegis128"); MODULE_ALIAS_CRYPTO("aegis128-aesni"); diff --git a/arch/x86/crypto/crc32c-intel_glue.c b/arch/x86/crypto/crc32c-intel_glue.c index feccb5254c7e..52c5d47ef5a1 100644 --- a/arch/x86/crypto/crc32c-intel_glue.c +++ b/arch/x86/crypto/crc32c-intel_glue.c @@ -41,7 +41,7 @@ */ #define CRC32C_PCL_BREAKEVEN 512 -asmlinkage unsigned int crc_pcl(const u8 *buffer, int len, +asmlinkage unsigned int crc_pcl(const u8 *buffer, unsigned int len, unsigned int crc_init); #endif /* CONFIG_X86_64 */ diff --git a/arch/x86/crypto/crc32c-pcl-intel-asm_64.S b/arch/x86/crypto/crc32c-pcl-intel-asm_64.S index bbcff1fb78cb..752812bc4991 100644 --- a/arch/x86/crypto/crc32c-pcl-intel-asm_64.S +++ b/arch/x86/crypto/crc32c-pcl-intel-asm_64.S @@ -7,6 +7,7 @@ * http://www.intel.com/content/dam/www/public/us/en/documents/white-papers/fast-crc-computation-paper.pdf * * Copyright (C) 2012 Intel Corporation. + * Copyright 2024 Google LLC * * Authors: * Wajdi Feghali @@ -44,185 +45,129 @@ */ #include -#include ## ISCSI CRC 32 Implementation with crc32 and pclmulqdq Instruction -.macro LABEL prefix n -.L\prefix\n\(): -.endm - -.macro JMPTBL_ENTRY i -.quad .Lcrc_\i -.endm - -.macro JNC_LESS_THAN j - jnc .Lless_than_\j -.endm - -# Define threshold where buffers are considered "small" and routed to more -# efficient "by-1" code. This "by-1" code only handles up to 255 bytes, so -# SMALL_SIZE can be no larger than 255. - +# Define threshold below which buffers are considered "small" and routed to +# regular CRC code that does not interleave the CRC instructions. #define SMALL_SIZE 200 -.if (SMALL_SIZE > 255) -.error "SMALL_ SIZE must be < 256" -.endif - -# unsigned int crc_pcl(u8 *buffer, int len, unsigned int crc_init); +# unsigned int crc_pcl(const u8 *buffer, unsigned int len, unsigned int crc_init); .text SYM_FUNC_START(crc_pcl) -#define bufp rdi -#define bufp_dw %edi -#define bufp_w %di -#define bufp_b %dil -#define bufptmp %rcx -#define block_0 %rcx -#define block_1 %rdx -#define block_2 %r11 -#define len %rsi -#define len_dw %esi -#define len_w %si -#define len_b %sil -#define crc_init_arg %rdx -#define tmp %rbx -#define crc_init %r8 -#define crc_init_dw %r8d -#define crc1 %r9 -#define crc2 %r10 - - pushq %rbx - pushq %rdi - pushq %rsi - - ## Move crc_init for Linux to a different - mov crc_init_arg, crc_init +#define bufp %rdi +#define bufp_d %edi +#define len %esi +#define crc_init %edx +#define crc_init_q %rdx +#define n_misaligned %ecx /* overlaps chunk_bytes! */ +#define n_misaligned_q %rcx +#define chunk_bytes %ecx /* overlaps n_misaligned! */ +#define chunk_bytes_q %rcx +#define crc1 %r8 +#define crc2 %r9 + + cmp $SMALL_SIZE, len + jb .Lsmall ################################################################ ## 1) ALIGN: ################################################################ - - mov %bufp, bufptmp # rdi = *buf - neg %bufp - and $7, %bufp # calculate the unalignment amount of + mov bufp_d, n_misaligned + neg n_misaligned + and $7, n_misaligned # calculate the misalignment amount of # the address - je .Lproc_block # Skip if aligned - - ## If len is less than 8 and we're unaligned, we need to jump - ## to special code to avoid reading beyond the end of the buffer - cmp $8, len - jae .Ldo_align - # less_than_8 expects length in upper 3 bits of len_dw - # less_than_8_post_shl1 expects length = carryflag * 8 + len_dw[31:30] - shl $32-3+1, len_dw - jmp .Lless_than_8_post_shl1 + je .Laligned # Skip if aligned + # Process 1 <= n_misaligned <= 7 bytes individually in order to align + # the remaining data to an 8-byte boundary. .Ldo_align: - #### Calculate CRC of unaligned bytes of the buffer (if any) - movq (bufptmp), tmp # load a quadward from the buffer - add %bufp, bufptmp # align buffer pointer for quadword - # processing - sub %bufp, len # update buffer length + movq (bufp), %rax + add n_misaligned_q, bufp + sub n_misaligned, len .Lalign_loop: - crc32b %bl, crc_init_dw # compute crc32 of 1-byte - shr $8, tmp # get next byte - dec %bufp + crc32b %al, crc_init # compute crc32 of 1-byte + shr $8, %rax # get next byte + dec n_misaligned jne .Lalign_loop - -.Lproc_block: +.Laligned: ################################################################ - ## 2) PROCESS BLOCKS: + ## 2) PROCESS BLOCK: ################################################################ - ## compute num of bytes to be processed - movq len, tmp # save num bytes in tmp - - cmpq $128*24, len + cmp $128*24, len jae .Lfull_block -.Lcontinue_block: - cmpq $SMALL_SIZE, len - jb .Lsmall - - ## len < 128*24 - movq $2731, %rax # 2731 = ceil(2^16 / 24) - mul len_dw - shrq $16, %rax - - ## eax contains floor(bytes / 24) = num 24-byte chunks to do - - ## process rax 24-byte chunks (128 >= rax >= 0) - - ## compute end address of each block - ## block 0 (base addr + RAX * 8) - ## block 1 (base addr + RAX * 16) - ## block 2 (base addr + RAX * 24) - lea (bufptmp, %rax, 8), block_0 - lea (block_0, %rax, 8), block_1 - lea (block_1, %rax, 8), block_2 +.Lpartial_block: + # Compute floor(len / 24) to get num qwords to process from each lane. + imul $2731, len, %eax # 2731 = ceil(2^16 / 24) + shr $16, %eax + jmp .Lcrc_3lanes - xor crc1, crc1 - xor crc2, crc2 - - ## branch into array - leaq jump_table(%rip), %bufp - mov (%bufp,%rax,8), %bufp - JMP_NOSPEC bufp - - ################################################################ - ## 2a) PROCESS FULL BLOCKS: - ################################################################ .Lfull_block: - movl $128,%eax - lea 128*8*2(block_0), block_1 - lea 128*8*3(block_0), block_2 - add $128*8*1, block_0 - - xor crc1,crc1 - xor crc2,crc2 - - # Fall through into top of crc array (crc_128) + # Processing 128 qwords from each lane. + mov $128, %eax ################################################################ - ## 3) CRC Array: + ## 3) CRC each of three lanes: ################################################################ - i=128 -.rept 128-1 -.altmacro -LABEL crc_ %i -.noaltmacro - ENDBR - crc32q -i*8(block_0), crc_init - crc32q -i*8(block_1), crc1 - crc32q -i*8(block_2), crc2 - i=(i-1) -.endr - -.altmacro -LABEL crc_ %i -.noaltmacro - ENDBR - crc32q -i*8(block_0), crc_init - crc32q -i*8(block_1), crc1 -# SKIP crc32 -i*8(block_2), crc2 ; Don't do this one yet - - mov block_2, block_0 +.Lcrc_3lanes: + xor crc1,crc1 + xor crc2,crc2 + mov %eax, chunk_bytes + shl $3, chunk_bytes # num bytes to process from each lane + sub $5, %eax # 4 for 4x_loop, 1 for special last iter + jl .Lcrc_3lanes_4x_done + + # Unroll the loop by a factor of 4 to reduce the overhead of the loop + # bookkeeping instructions, which can compete with crc32q for the ALUs. +.Lcrc_3lanes_4x_loop: + crc32q (bufp), crc_init_q + crc32q (bufp,chunk_bytes_q), crc1 + crc32q (bufp,chunk_bytes_q,2), crc2 + crc32q 8(bufp), crc_init_q + crc32q 8(bufp,chunk_bytes_q), crc1 + crc32q 8(bufp,chunk_bytes_q,2), crc2 + crc32q 16(bufp), crc_init_q + crc32q 16(bufp,chunk_bytes_q), crc1 + crc32q 16(bufp,chunk_bytes_q,2), crc2 + crc32q 24(bufp), crc_init_q + crc32q 24(bufp,chunk_bytes_q), crc1 + crc32q 24(bufp,chunk_bytes_q,2), crc2 + add $32, bufp + sub $4, %eax + jge .Lcrc_3lanes_4x_loop + +.Lcrc_3lanes_4x_done: + add $4, %eax + jz .Lcrc_3lanes_last_qword + +.Lcrc_3lanes_1x_loop: + crc32q (bufp), crc_init_q + crc32q (bufp,chunk_bytes_q), crc1 + crc32q (bufp,chunk_bytes_q,2), crc2 + add $8, bufp + dec %eax + jnz .Lcrc_3lanes_1x_loop + +.Lcrc_3lanes_last_qword: + crc32q (bufp), crc_init_q + crc32q (bufp,chunk_bytes_q), crc1 +# SKIP crc32q (bufp,chunk_bytes_q,2), crc2 ; Don't do this one yet ################################################################ ## 4) Combine three results: ################################################################ - lea (K_table-8)(%rip), %bufp # first entry is for idx 1 - shlq $3, %rax # rax *= 8 - pmovzxdq (%bufp,%rax), %xmm0 # 2 consts: K1:K2 - leal (%eax,%eax,2), %eax # rax *= 3 (total *24) - subq %rax, tmp # tmp -= rax*24 + lea (K_table-8)(%rip), %rax # first entry is for idx 1 + pmovzxdq (%rax,chunk_bytes_q), %xmm0 # 2 consts: K1:K2 + lea (chunk_bytes,chunk_bytes,2), %eax # chunk_bytes * 3 + sub %eax, len # len -= chunk_bytes * 3 - movq crc_init, %xmm1 # CRC for block 1 + movq crc_init_q, %xmm1 # CRC for block 1 pclmulqdq $0x00, %xmm0, %xmm1 # Multiply by K2 movq crc1, %xmm2 # CRC for block 2 @@ -230,103 +175,54 @@ LABEL crc_ %i pxor %xmm2,%xmm1 movq %xmm1, %rax - xor -i*8(block_2), %rax - mov crc2, crc_init - crc32 %rax, crc_init + xor (bufp,chunk_bytes_q,2), %rax + mov crc2, crc_init_q + crc32 %rax, crc_init_q + lea 8(bufp,chunk_bytes_q,2), bufp ################################################################ - ## 5) Check for end: + ## 5) If more blocks remain, goto (2): ################################################################ -LABEL crc_ 0 - ENDBR - mov tmp, len - cmp $128*24, tmp - jae .Lfull_block - cmp $24, tmp - jae .Lcontinue_block - -.Lless_than_24: - shl $32-4, len_dw # less_than_16 expects length - # in upper 4 bits of len_dw - jnc .Lless_than_16 - crc32q (bufptmp), crc_init - crc32q 8(bufptmp), crc_init - jz .Ldo_return - add $16, bufptmp - # len is less than 8 if we got here - # less_than_8 expects length in upper 3 bits of len_dw - # less_than_8_post_shl1 expects length = carryflag * 8 + len_dw[31:30] - shl $2, len_dw - jmp .Lless_than_8_post_shl1 + cmp $128*24, len + jae .Lfull_block + cmp $SMALL_SIZE, len + jae .Lpartial_block ####################################################################### - ## 6) LESS THAN 256-bytes REMAIN AT THIS POINT (8-bits of len are full) + ## 6) Process any remainder without interleaving: ####################################################################### .Lsmall: - shl $32-8, len_dw # Prepare len_dw for less_than_256 - j=256 -.rept 5 # j = {256, 128, 64, 32, 16} -.altmacro -LABEL less_than_ %j # less_than_j: Length should be in - # upper lg(j) bits of len_dw - j=(j/2) - shl $1, len_dw # Get next MSB - JNC_LESS_THAN %j -.noaltmacro - i=0 -.rept (j/8) - crc32q i(bufptmp), crc_init # Compute crc32 of 8-byte data - i=i+8 -.endr - jz .Ldo_return # Return if remaining length is zero - add $j, bufptmp # Advance buf -.endr - -.Lless_than_8: # Length should be stored in - # upper 3 bits of len_dw - shl $1, len_dw -.Lless_than_8_post_shl1: - jnc .Lless_than_4 - crc32l (bufptmp), crc_init_dw # CRC of 4 bytes - jz .Ldo_return # return if remaining data is zero - add $4, bufptmp -.Lless_than_4: # Length should be stored in - # upper 2 bits of len_dw - shl $1, len_dw - jnc .Lless_than_2 - crc32w (bufptmp), crc_init_dw # CRC of 2 bytes - jz .Ldo_return # return if remaining data is zero - add $2, bufptmp -.Lless_than_2: # Length should be stored in the MSB - # of len_dw - shl $1, len_dw - jnc .Lless_than_1 - crc32b (bufptmp), crc_init_dw # CRC of 1 byte -.Lless_than_1: # Length should be zero -.Ldo_return: - movq crc_init, %rax - popq %rsi - popq %rdi - popq %rbx + test len, len + jz .Ldone + mov len, %eax + shr $3, %eax + jz .Ldo_dword +.Ldo_qwords: + crc32q (bufp), crc_init_q + add $8, bufp + dec %eax + jnz .Ldo_qwords +.Ldo_dword: + test $4, len + jz .Ldo_word + crc32l (bufp), crc_init + add $4, bufp +.Ldo_word: + test $2, len + jz .Ldo_byte + crc32w (bufp), crc_init + add $2, bufp +.Ldo_byte: + test $1, len + jz .Ldone + crc32b (bufp), crc_init +.Ldone: + mov crc_init, %eax RET SYM_FUNC_END(crc_pcl) .section .rodata, "a", @progbits - ################################################################ - ## jump table Table is 129 entries x 2 bytes each - ################################################################ -.align 4 -jump_table: - i=0 -.rept 129 -.altmacro -JMPTBL_ENTRY %i -.noaltmacro - i=i+1 -.endr - - ################################################################ ## PCLMULQDQ tables ## Table is 128 entries x 2 words (8 bytes) each -- 2.47.0 From 10b2f8b54a3363a982c4d021ed29a191bea5c0b3 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 11 Nov 2024 09:20:20 +0100 Subject: [PATCH 07/13] fixes Signed-off-by: Peter Jung --- arch/Kconfig | 4 +- arch/x86/kernel/cpu/amd.c | 11 ++ arch/x86/mm/tlb.c | 22 ++-- drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c | 5 + drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c | 108 ++++++------------ drivers/gpu/drm/amd/pm/swsmu/inc/amdgpu_smu.h | 11 +- .../gpu/drm/amd/pm/swsmu/smu11/arcturus_ppt.c | 20 ++-- .../gpu/drm/amd/pm/swsmu/smu11/navi10_ppt.c | 20 ++-- .../amd/pm/swsmu/smu11/sienna_cichlid_ppt.c | 21 ++-- .../gpu/drm/amd/pm/swsmu/smu11/vangogh_ppt.c | 17 +-- .../gpu/drm/amd/pm/swsmu/smu12/renoir_ppt.c | 17 +-- .../drm/amd/pm/swsmu/smu13/smu_v13_0_0_ppt.c | 33 +++--- .../drm/amd/pm/swsmu/smu13/smu_v13_0_7_ppt.c | 21 ++-- .../drm/amd/pm/swsmu/smu14/smu_v14_0_2_ppt.c | 24 ++-- drivers/gpu/drm/amd/pm/swsmu/smu_cmn.c | 8 -- drivers/gpu/drm/amd/pm/swsmu/smu_cmn.h | 2 - drivers/gpu/drm/drm_edid.c | 47 +++++++- drivers/misc/lkdtm/bugs.c | 2 +- fs/ntfs3/attrib.c | 9 +- fs/ntfs3/bitmap.c | 62 +++------- fs/ntfs3/file.c | 34 +++--- fs/ntfs3/frecord.c | 104 +++-------------- fs/ntfs3/fsntfs.c | 2 +- fs/ntfs3/ntfs_fs.h | 3 +- fs/ntfs3/record.c | 16 ++- fs/ntfs3/run.c | 40 +++++-- include/linux/compiler_attributes.h | 13 --- include/linux/compiler_types.h | 19 +++ include/linux/mm_types.h | 1 + init/Kconfig | 8 ++ kernel/sched/core.c | 46 +++++--- kernel/sched/sched.h | 5 + lib/overflow_kunit.c | 2 +- scripts/package/PKGBUILD | 5 + 34 files changed, 376 insertions(+), 386 deletions(-) diff --git a/arch/Kconfig b/arch/Kconfig index 00551f340dbe..833b2344ce79 100644 --- a/arch/Kconfig +++ b/arch/Kconfig @@ -1128,7 +1128,7 @@ config ARCH_MMAP_RND_BITS int "Number of bits to use for ASLR of mmap base address" if EXPERT range ARCH_MMAP_RND_BITS_MIN ARCH_MMAP_RND_BITS_MAX default ARCH_MMAP_RND_BITS_DEFAULT if ARCH_MMAP_RND_BITS_DEFAULT - default ARCH_MMAP_RND_BITS_MIN + default ARCH_MMAP_RND_BITS_MAX depends on HAVE_ARCH_MMAP_RND_BITS help This value can be used to select the number of bits to use to @@ -1162,7 +1162,7 @@ config ARCH_MMAP_RND_COMPAT_BITS int "Number of bits to use for ASLR of mmap base address for compatible applications" if EXPERT range ARCH_MMAP_RND_COMPAT_BITS_MIN ARCH_MMAP_RND_COMPAT_BITS_MAX default ARCH_MMAP_RND_COMPAT_BITS_DEFAULT if ARCH_MMAP_RND_COMPAT_BITS_DEFAULT - default ARCH_MMAP_RND_COMPAT_BITS_MIN + default ARCH_MMAP_RND_COMPAT_BITS_MAX depends on HAVE_ARCH_MMAP_RND_COMPAT_BITS help This value can be used to select the number of bits to use to diff --git a/arch/x86/kernel/cpu/amd.c b/arch/x86/kernel/cpu/amd.c index fab5caec0b72..823f44f7bc94 100644 --- a/arch/x86/kernel/cpu/amd.c +++ b/arch/x86/kernel/cpu/amd.c @@ -924,6 +924,17 @@ static void init_amd_zen4(struct cpuinfo_x86 *c) { if (!cpu_has(c, X86_FEATURE_HYPERVISOR)) msr_set_bit(MSR_ZEN4_BP_CFG, MSR_ZEN4_BP_CFG_SHARED_BTB_FIX_BIT); + + /* + * These Zen4 SoCs advertise support for virtualized VMLOAD/VMSAVE + * in some BIOS versions but they can lead to random host reboots. + */ + switch (c->x86_model) { + case 0x18 ... 0x1f: + case 0x60 ... 0x7f: + clear_cpu_cap(c, X86_FEATURE_V_VMSAVE_VMLOAD); + break; + } } static void init_amd_zen5(struct cpuinfo_x86 *c) diff --git a/arch/x86/mm/tlb.c b/arch/x86/mm/tlb.c index 86593d1b787d..1aac4fa90d3d 100644 --- a/arch/x86/mm/tlb.c +++ b/arch/x86/mm/tlb.c @@ -568,7 +568,7 @@ void switch_mm_irqs_off(struct mm_struct *unused, struct mm_struct *next, * mm_cpumask. The TLB shootdown code can figure out from * cpu_tlbstate_shared.is_lazy whether or not to send an IPI. */ - if (WARN_ON_ONCE(prev != &init_mm && + if (IS_ENABLED(CONFIG_DEBUG_VM) && WARN_ON_ONCE(prev != &init_mm && !cpumask_test_cpu(cpu, mm_cpumask(next)))) cpumask_set_cpu(cpu, mm_cpumask(next)); @@ -606,18 +606,15 @@ void switch_mm_irqs_off(struct mm_struct *unused, struct mm_struct *next, cond_mitigation(tsk); /* - * Stop remote flushes for the previous mm. - * Skip kernel threads; we never send init_mm TLB flushing IPIs, - * but the bitmap manipulation can cause cache line contention. + * Leave this CPU in prev's mm_cpumask. Atomic writes to + * mm_cpumask can be expensive under contention. The CPU + * will be removed lazily at TLB flush time. */ - if (prev != &init_mm) { - VM_WARN_ON_ONCE(!cpumask_test_cpu(cpu, - mm_cpumask(prev))); - cpumask_clear_cpu(cpu, mm_cpumask(prev)); - } + VM_WARN_ON_ONCE(prev != &init_mm && !cpumask_test_cpu(cpu, + mm_cpumask(prev))); /* Start receiving IPIs and then read tlb_gen (and LAM below) */ - if (next != &init_mm) + if (next != &init_mm && !cpumask_test_cpu(cpu, mm_cpumask(next))) cpumask_set_cpu(cpu, mm_cpumask(next)); next_tlb_gen = atomic64_read(&next->context.tlb_gen); @@ -761,8 +758,11 @@ static void flush_tlb_func(void *info) count_vm_tlb_event(NR_TLB_REMOTE_FLUSH_RECEIVED); /* Can only happen on remote CPUs */ - if (f->mm && f->mm != loaded_mm) + if (f->mm && f->mm != loaded_mm) { + cpumask_clear_cpu(raw_smp_processor_id(), mm_cpumask(f->mm)); + trace_tlb_flush(TLB_REMOTE_WRONG_CPU, 0); return; + } } if (unlikely(loaded_mm == &init_mm)) diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c index 852e6f315576..f6a6fc6a4f5c 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c @@ -3078,6 +3078,11 @@ static int __init amdgpu_init(void) /* Ignore KFD init failures. Normal when CONFIG_HSA_AMD is not set. */ amdgpu_amdkfd_init(); + if (amdgpu_pp_feature_mask & PP_OVERDRIVE_MASK) { + add_taint(TAINT_CPU_OUT_OF_SPEC, LOCKDEP_STILL_OK); + pr_crit("Overdrive is enabled, please disable it before reporting any bugs.\n"); + } + /* let modprobe override vga console setting */ return pci_register_driver(&amdgpu_kms_pci_driver); diff --git a/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c b/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c index 3388604f222b..daa870302cc3 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c +++ b/drivers/gpu/drm/amd/pm/swsmu/amdgpu_smu.c @@ -1257,42 +1257,18 @@ static int smu_sw_init(void *handle) INIT_WORK(&smu->interrupt_work, smu_interrupt_work_fn); atomic64_set(&smu->throttle_int_counter, 0); smu->watermarks_bitmap = 0; - smu->power_profile_mode = PP_SMC_POWER_PROFILE_BOOTUP_DEFAULT; - smu->default_power_profile_mode = PP_SMC_POWER_PROFILE_BOOTUP_DEFAULT; - smu->user_dpm_profile.user_workload_mask = 0; atomic_set(&smu->smu_power.power_gate.vcn_gated, 1); atomic_set(&smu->smu_power.power_gate.jpeg_gated, 1); atomic_set(&smu->smu_power.power_gate.vpe_gated, 1); atomic_set(&smu->smu_power.power_gate.umsch_mm_gated, 1); - smu->workload_priority[PP_SMC_POWER_PROFILE_BOOTUP_DEFAULT] = 0; - smu->workload_priority[PP_SMC_POWER_PROFILE_FULLSCREEN3D] = 1; - smu->workload_priority[PP_SMC_POWER_PROFILE_POWERSAVING] = 2; - smu->workload_priority[PP_SMC_POWER_PROFILE_VIDEO] = 3; - smu->workload_priority[PP_SMC_POWER_PROFILE_VR] = 4; - smu->workload_priority[PP_SMC_POWER_PROFILE_COMPUTE] = 5; - smu->workload_priority[PP_SMC_POWER_PROFILE_CUSTOM] = 6; - if (smu->is_apu || - !smu_is_workload_profile_available(smu, PP_SMC_POWER_PROFILE_FULLSCREEN3D)) { - smu->driver_workload_mask = - 1 << smu->workload_priority[PP_SMC_POWER_PROFILE_BOOTUP_DEFAULT]; - } else { - smu->driver_workload_mask = - 1 << smu->workload_priority[PP_SMC_POWER_PROFILE_FULLSCREEN3D]; - smu->default_power_profile_mode = PP_SMC_POWER_PROFILE_FULLSCREEN3D; - } - - smu->workload_mask = smu->driver_workload_mask | - smu->user_dpm_profile.user_workload_mask; - smu->workload_setting[0] = PP_SMC_POWER_PROFILE_BOOTUP_DEFAULT; - smu->workload_setting[1] = PP_SMC_POWER_PROFILE_FULLSCREEN3D; - smu->workload_setting[2] = PP_SMC_POWER_PROFILE_POWERSAVING; - smu->workload_setting[3] = PP_SMC_POWER_PROFILE_VIDEO; - smu->workload_setting[4] = PP_SMC_POWER_PROFILE_VR; - smu->workload_setting[5] = PP_SMC_POWER_PROFILE_COMPUTE; - smu->workload_setting[6] = PP_SMC_POWER_PROFILE_CUSTOM; + !smu_is_workload_profile_available(smu, PP_SMC_POWER_PROFILE_FULLSCREEN3D)) + smu->power_profile_mode = PP_SMC_POWER_PROFILE_BOOTUP_DEFAULT; + else + smu->power_profile_mode = PP_SMC_POWER_PROFILE_FULLSCREEN3D; + smu->display_config = &adev->pm.pm_display_cfg; smu->smu_dpm.dpm_level = AMD_DPM_FORCED_LEVEL_AUTO; @@ -2232,24 +2208,23 @@ static int smu_enable_umd_pstate(void *handle, } static int smu_bump_power_profile_mode(struct smu_context *smu, - long *param, - uint32_t param_size) + long *param, + uint32_t param_size, + bool enable) { int ret = 0; if (smu->ppt_funcs->set_power_profile_mode) - ret = smu->ppt_funcs->set_power_profile_mode(smu, param, param_size); + ret = smu->ppt_funcs->set_power_profile_mode(smu, param, param_size, enable); return ret; } static int smu_adjust_power_state_dynamic(struct smu_context *smu, enum amd_dpm_forced_level level, - bool skip_display_settings, - bool init) + bool skip_display_settings) { int ret = 0; - int index = 0; long workload[1]; struct smu_dpm_context *smu_dpm_ctx = &(smu->smu_dpm); @@ -2287,13 +2262,10 @@ static int smu_adjust_power_state_dynamic(struct smu_context *smu, } if (smu_dpm_ctx->dpm_level != AMD_DPM_FORCED_LEVEL_MANUAL && - smu_dpm_ctx->dpm_level != AMD_DPM_FORCED_LEVEL_PERF_DETERMINISM) { - index = fls(smu->workload_mask); - index = index > 0 && index <= WORKLOAD_POLICY_MAX ? index - 1 : 0; - workload[0] = smu->workload_setting[index]; + smu_dpm_ctx->dpm_level != AMD_DPM_FORCED_LEVEL_PERF_DETERMINISM) { + workload[0] = smu->power_profile_mode; - if (init || smu->power_profile_mode != workload[0]) - smu_bump_power_profile_mode(smu, workload, 0); + smu_bump_power_profile_mode(smu, workload, 0, true); } return ret; @@ -2313,13 +2285,13 @@ static int smu_handle_task(struct smu_context *smu, ret = smu_pre_display_config_changed(smu); if (ret) return ret; - ret = smu_adjust_power_state_dynamic(smu, level, false, false); + ret = smu_adjust_power_state_dynamic(smu, level, false); break; case AMD_PP_TASK_COMPLETE_INIT: - ret = smu_adjust_power_state_dynamic(smu, level, true, true); + ret = smu_adjust_power_state_dynamic(smu, level, true); break; case AMD_PP_TASK_READJUST_POWER_STATE: - ret = smu_adjust_power_state_dynamic(smu, level, true, false); + ret = smu_adjust_power_state_dynamic(smu, level, true); break; default: break; @@ -2341,12 +2313,11 @@ static int smu_handle_dpm_task(void *handle, static int smu_switch_power_profile(void *handle, enum PP_SMC_POWER_PROFILE type, - bool en) + bool enable) { struct smu_context *smu = handle; struct smu_dpm_context *smu_dpm_ctx = &(smu->smu_dpm); long workload[1]; - uint32_t index; if (!smu->pm_enabled || !smu->adev->pm.dpm_enabled) return -EOPNOTSUPP; @@ -2354,24 +2325,15 @@ static int smu_switch_power_profile(void *handle, if (!(type < PP_SMC_POWER_PROFILE_CUSTOM)) return -EINVAL; - if (!en) { - smu->driver_workload_mask &= ~(1 << smu->workload_priority[type]); - index = fls(smu->workload_mask); - index = index > 0 && index <= WORKLOAD_POLICY_MAX ? index - 1 : 0; - workload[0] = smu->workload_setting[index]; - } else { - smu->driver_workload_mask |= (1 << smu->workload_priority[type]); - index = fls(smu->workload_mask); - index = index <= WORKLOAD_POLICY_MAX ? index - 1 : 0; - workload[0] = smu->workload_setting[index]; - } + /* don't disable the user's preference */ + if (!enable && type == smu->power_profile_mode) + return 0; - smu->workload_mask = smu->driver_workload_mask | - smu->user_dpm_profile.user_workload_mask; + workload[0] = type; if (smu_dpm_ctx->dpm_level != AMD_DPM_FORCED_LEVEL_MANUAL && - smu_dpm_ctx->dpm_level != AMD_DPM_FORCED_LEVEL_PERF_DETERMINISM) - smu_bump_power_profile_mode(smu, workload, 0); + smu_dpm_ctx->dpm_level != AMD_DPM_FORCED_LEVEL_PERF_DETERMINISM) + smu_bump_power_profile_mode(smu, workload, 0, enable); return 0; } @@ -3069,21 +3031,25 @@ static int smu_set_power_profile_mode(void *handle, uint32_t param_size) { struct smu_context *smu = handle; - int ret; + long workload[1]; + int ret = 0; if (!smu->pm_enabled || !smu->adev->pm.dpm_enabled || !smu->ppt_funcs->set_power_profile_mode) return -EOPNOTSUPP; - if (smu->user_dpm_profile.user_workload_mask & - (1 << smu->workload_priority[param[param_size]])) - return 0; - - smu->user_dpm_profile.user_workload_mask = - (1 << smu->workload_priority[param[param_size]]); - smu->workload_mask = smu->user_dpm_profile.user_workload_mask | - smu->driver_workload_mask; - ret = smu_bump_power_profile_mode(smu, param, param_size); + if (param[param_size] != smu->power_profile_mode) { + /* clear the old user preference */ + workload[0] = smu->power_profile_mode; + ret = smu_bump_power_profile_mode(smu, workload, 0, false); + if (ret) + return ret; + /* set the new user preference */ + ret = smu_bump_power_profile_mode(smu, param, param_size, true); + if (!ret) + /* store the user's preference */ + smu->power_profile_mode = param[param_size]; + } return ret; } diff --git a/drivers/gpu/drm/amd/pm/swsmu/inc/amdgpu_smu.h b/drivers/gpu/drm/amd/pm/swsmu/inc/amdgpu_smu.h index d60d9a12a47e..fc54b2c6ede8 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/inc/amdgpu_smu.h +++ b/drivers/gpu/drm/amd/pm/swsmu/inc/amdgpu_smu.h @@ -240,7 +240,6 @@ struct smu_user_dpm_profile { /* user clock state information */ uint32_t clk_mask[SMU_CLK_COUNT]; uint32_t clk_dependency; - uint32_t user_workload_mask; }; #define SMU_TABLE_INIT(tables, table_id, s, a, d) \ @@ -557,12 +556,10 @@ struct smu_context { uint32_t hard_min_uclk_req_from_dal; bool disable_uclk_switch; + /* backend specific workload mask */ uint32_t workload_mask; - uint32_t driver_workload_mask; - uint32_t workload_priority[WORKLOAD_POLICY_MAX]; - uint32_t workload_setting[WORKLOAD_POLICY_MAX]; + /* default/user workload preference */ uint32_t power_profile_mode; - uint32_t default_power_profile_mode; bool pm_enabled; bool is_apu; @@ -734,8 +731,10 @@ struct pptable_funcs { * create/set custom power profile modes. * &input: Power profile mode parameters. * &size: Size of &input. + * &enable: enable/disable the profile */ - int (*set_power_profile_mode)(struct smu_context *smu, long *input, uint32_t size); + int (*set_power_profile_mode)(struct smu_context *smu, long *input, + uint32_t size, bool enable); /** * @dpm_set_vcn_enable: Enable/disable VCN engine dynamic power diff --git a/drivers/gpu/drm/amd/pm/swsmu/smu11/arcturus_ppt.c b/drivers/gpu/drm/amd/pm/swsmu/smu11/arcturus_ppt.c index 31fe512028f4..ac7fbb815644 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/smu11/arcturus_ppt.c +++ b/drivers/gpu/drm/amd/pm/swsmu/smu11/arcturus_ppt.c @@ -1443,7 +1443,8 @@ static int arcturus_get_power_profile_mode(struct smu_context *smu, static int arcturus_set_power_profile_mode(struct smu_context *smu, long *input, - uint32_t size) + uint32_t size, + bool enable) { DpmActivityMonitorCoeffInt_t activity_monitor; int workload_type = 0; @@ -1455,8 +1456,9 @@ static int arcturus_set_power_profile_mode(struct smu_context *smu, return -EINVAL; } - if ((profile_mode == PP_SMC_POWER_PROFILE_CUSTOM) && - (smu->smc_fw_version >= 0x360d00)) { + if (enable && + (profile_mode == PP_SMC_POWER_PROFILE_CUSTOM) && + (smu->smc_fw_version >= 0x360d00)) { if (size != 10) return -EINVAL; @@ -1520,18 +1522,18 @@ static int arcturus_set_power_profile_mode(struct smu_context *smu, return -EINVAL; } + if (enable) + smu->workload_mask |= (1 << workload_type); + else + smu->workload_mask &= ~(1 << workload_type); ret = smu_cmn_send_smc_msg_with_param(smu, SMU_MSG_SetWorkloadMask, smu->workload_mask, NULL); - if (ret) { + if (ret) dev_err(smu->adev->dev, "Fail to set workload type %d\n", workload_type); - return ret; - } - - smu_cmn_assign_power_profile(smu); - return 0; + return ret; } static int arcturus_set_performance_level(struct smu_context *smu, diff --git a/drivers/gpu/drm/amd/pm/swsmu/smu11/navi10_ppt.c b/drivers/gpu/drm/amd/pm/swsmu/smu11/navi10_ppt.c index 12223f507977..656df9fce471 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/smu11/navi10_ppt.c +++ b/drivers/gpu/drm/amd/pm/swsmu/smu11/navi10_ppt.c @@ -2004,19 +2004,19 @@ static int navi10_get_power_profile_mode(struct smu_context *smu, char *buf) return size; } -static int navi10_set_power_profile_mode(struct smu_context *smu, long *input, uint32_t size) +static int navi10_set_power_profile_mode(struct smu_context *smu, long *input, + uint32_t size, bool enable) { DpmActivityMonitorCoeffInt_t activity_monitor; int workload_type, ret = 0; + uint32_t profile_mode = input[size]; - smu->power_profile_mode = input[size]; - - if (smu->power_profile_mode > PP_SMC_POWER_PROFILE_CUSTOM) { - dev_err(smu->adev->dev, "Invalid power profile mode %d\n", smu->power_profile_mode); + if (profile_mode > PP_SMC_POWER_PROFILE_CUSTOM) { + dev_err(smu->adev->dev, "Invalid power profile mode %d\n", profile_mode); return -EINVAL; } - if (smu->power_profile_mode == PP_SMC_POWER_PROFILE_CUSTOM) { + if (enable && profile_mode == PP_SMC_POWER_PROFILE_CUSTOM) { if (size != 10) return -EINVAL; @@ -2078,16 +2078,18 @@ static int navi10_set_power_profile_mode(struct smu_context *smu, long *input, u /* conv PP_SMC_POWER_PROFILE* to WORKLOAD_PPLIB_*_BIT */ workload_type = smu_cmn_to_asic_specific_index(smu, CMN2ASIC_MAPPING_WORKLOAD, - smu->power_profile_mode); + profile_mode); if (workload_type < 0) return -EINVAL; + if (enable) + smu->workload_mask |= (1 << workload_type); + else + smu->workload_mask &= ~(1 << workload_type); ret = smu_cmn_send_smc_msg_with_param(smu, SMU_MSG_SetWorkloadMask, smu->workload_mask, NULL); if (ret) dev_err(smu->adev->dev, "[%s] Failed to set work load mask!", __func__); - else - smu_cmn_assign_power_profile(smu); return ret; } diff --git a/drivers/gpu/drm/amd/pm/swsmu/smu11/sienna_cichlid_ppt.c b/drivers/gpu/drm/amd/pm/swsmu/smu11/sienna_cichlid_ppt.c index 3b7b2ec8319a..289cba0f741e 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/smu11/sienna_cichlid_ppt.c +++ b/drivers/gpu/drm/amd/pm/swsmu/smu11/sienna_cichlid_ppt.c @@ -1706,22 +1706,23 @@ static int sienna_cichlid_get_power_profile_mode(struct smu_context *smu, char * return size; } -static int sienna_cichlid_set_power_profile_mode(struct smu_context *smu, long *input, uint32_t size) +static int sienna_cichlid_set_power_profile_mode(struct smu_context *smu, + long *input, uint32_t size, + bool enable) { DpmActivityMonitorCoeffIntExternal_t activity_monitor_external; DpmActivityMonitorCoeffInt_t *activity_monitor = &(activity_monitor_external.DpmActivityMonitorCoeffInt); + uint32_t profile_mode = input[size]; int workload_type, ret = 0; - smu->power_profile_mode = input[size]; - - if (smu->power_profile_mode > PP_SMC_POWER_PROFILE_CUSTOM) { - dev_err(smu->adev->dev, "Invalid power profile mode %d\n", smu->power_profile_mode); + if (profile_mode > PP_SMC_POWER_PROFILE_CUSTOM) { + dev_err(smu->adev->dev, "Invalid power profile mode %d\n", profile_mode); return -EINVAL; } - if (smu->power_profile_mode == PP_SMC_POWER_PROFILE_CUSTOM) { + if (enable && profile_mode == PP_SMC_POWER_PROFILE_CUSTOM) { if (size != 10) return -EINVAL; @@ -1783,16 +1784,18 @@ static int sienna_cichlid_set_power_profile_mode(struct smu_context *smu, long * /* conv PP_SMC_POWER_PROFILE* to WORKLOAD_PPLIB_*_BIT */ workload_type = smu_cmn_to_asic_specific_index(smu, CMN2ASIC_MAPPING_WORKLOAD, - smu->power_profile_mode); + profile_mode); if (workload_type < 0) return -EINVAL; + if (enable) + smu->workload_mask |= (1 << workload_type); + else + smu->workload_mask &= ~(1 << workload_type); ret = smu_cmn_send_smc_msg_with_param(smu, SMU_MSG_SetWorkloadMask, smu->workload_mask, NULL); if (ret) dev_err(smu->adev->dev, "[%s] Failed to set work load mask!", __func__); - else - smu_cmn_assign_power_profile(smu); return ret; } diff --git a/drivers/gpu/drm/amd/pm/swsmu/smu11/vangogh_ppt.c b/drivers/gpu/drm/amd/pm/swsmu/smu11/vangogh_ppt.c index 952ee22cbc90..a123ae7809ec 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/smu11/vangogh_ppt.c +++ b/drivers/gpu/drm/amd/pm/swsmu/smu11/vangogh_ppt.c @@ -1054,7 +1054,8 @@ static int vangogh_get_power_profile_mode(struct smu_context *smu, return size; } -static int vangogh_set_power_profile_mode(struct smu_context *smu, long *input, uint32_t size) +static int vangogh_set_power_profile_mode(struct smu_context *smu, long *input, + uint32_t size, bool enable) { int workload_type, ret; uint32_t profile_mode = input[size]; @@ -1065,7 +1066,7 @@ static int vangogh_set_power_profile_mode(struct smu_context *smu, long *input, } if (profile_mode == PP_SMC_POWER_PROFILE_BOOTUP_DEFAULT || - profile_mode == PP_SMC_POWER_PROFILE_POWERSAVING) + profile_mode == PP_SMC_POWER_PROFILE_POWERSAVING) return 0; /* conv PP_SMC_POWER_PROFILE* to WORKLOAD_PPLIB_*_BIT */ @@ -1078,18 +1079,18 @@ static int vangogh_set_power_profile_mode(struct smu_context *smu, long *input, return -EINVAL; } + if (enable) + smu->workload_mask |= (1 << workload_type); + else + smu->workload_mask &= ~(1 << workload_type); ret = smu_cmn_send_smc_msg_with_param(smu, SMU_MSG_ActiveProcessNotify, smu->workload_mask, NULL); - if (ret) { + if (ret) dev_err_once(smu->adev->dev, "Fail to set workload type %d\n", workload_type); - return ret; - } - - smu_cmn_assign_power_profile(smu); - return 0; + return ret; } static int vangogh_set_soft_freq_limited_range(struct smu_context *smu, diff --git a/drivers/gpu/drm/amd/pm/swsmu/smu12/renoir_ppt.c b/drivers/gpu/drm/amd/pm/swsmu/smu12/renoir_ppt.c index 62316a6707ef..25779abc5447 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/smu12/renoir_ppt.c +++ b/drivers/gpu/drm/amd/pm/swsmu/smu12/renoir_ppt.c @@ -862,7 +862,8 @@ static int renoir_force_clk_levels(struct smu_context *smu, return ret; } -static int renoir_set_power_profile_mode(struct smu_context *smu, long *input, uint32_t size) +static int renoir_set_power_profile_mode(struct smu_context *smu, long *input, + uint32_t size, bool enable) { int workload_type, ret; uint32_t profile_mode = input[size]; @@ -873,7 +874,7 @@ static int renoir_set_power_profile_mode(struct smu_context *smu, long *input, u } if (profile_mode == PP_SMC_POWER_PROFILE_BOOTUP_DEFAULT || - profile_mode == PP_SMC_POWER_PROFILE_POWERSAVING) + profile_mode == PP_SMC_POWER_PROFILE_POWERSAVING) return 0; /* conv PP_SMC_POWER_PROFILE* to WORKLOAD_PPLIB_*_BIT */ @@ -889,17 +890,17 @@ static int renoir_set_power_profile_mode(struct smu_context *smu, long *input, u return -EINVAL; } + if (enable) + smu->workload_mask |= (1 << workload_type); + else + smu->workload_mask &= ~(1 << workload_type); ret = smu_cmn_send_smc_msg_with_param(smu, SMU_MSG_ActiveProcessNotify, smu->workload_mask, NULL); - if (ret) { + if (ret) dev_err_once(smu->adev->dev, "Fail to set workload type %d\n", workload_type); - return ret; - } - smu_cmn_assign_power_profile(smu); - - return 0; + return ret; } static int renoir_set_peak_clock_by_device(struct smu_context *smu) diff --git a/drivers/gpu/drm/amd/pm/swsmu/smu13/smu_v13_0_0_ppt.c b/drivers/gpu/drm/amd/pm/swsmu/smu13/smu_v13_0_0_ppt.c index 5dd7ceca64fe..6861267b68fb 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/smu13/smu_v13_0_0_ppt.c +++ b/drivers/gpu/drm/amd/pm/swsmu/smu13/smu_v13_0_0_ppt.c @@ -2479,22 +2479,22 @@ static int smu_v13_0_0_get_power_profile_mode(struct smu_context *smu, static int smu_v13_0_0_set_power_profile_mode(struct smu_context *smu, long *input, - uint32_t size) + uint32_t size, + bool enable) { DpmActivityMonitorCoeffIntExternal_t activity_monitor_external; DpmActivityMonitorCoeffInt_t *activity_monitor = &(activity_monitor_external.DpmActivityMonitorCoeffInt); + uint32_t profile_mode = input[size]; int workload_type, ret = 0; u32 workload_mask; - smu->power_profile_mode = input[size]; - - if (smu->power_profile_mode >= PP_SMC_POWER_PROFILE_COUNT) { - dev_err(smu->adev->dev, "Invalid power profile mode %d\n", smu->power_profile_mode); + if (profile_mode >= PP_SMC_POWER_PROFILE_COUNT) { + dev_err(smu->adev->dev, "Invalid power profile mode %d\n", profile_mode); return -EINVAL; } - if (smu->power_profile_mode == PP_SMC_POWER_PROFILE_CUSTOM) { + if (enable && profile_mode == PP_SMC_POWER_PROFILE_CUSTOM) { if (size != 9) return -EINVAL; @@ -2547,13 +2547,18 @@ static int smu_v13_0_0_set_power_profile_mode(struct smu_context *smu, /* conv PP_SMC_POWER_PROFILE* to WORKLOAD_PPLIB_*_BIT */ workload_type = smu_cmn_to_asic_specific_index(smu, CMN2ASIC_MAPPING_WORKLOAD, - smu->power_profile_mode); + profile_mode); if (workload_type < 0) return -EINVAL; workload_mask = 1 << workload_type; + if (enable) + smu->workload_mask |= workload_mask; + else + smu->workload_mask &= ~workload_mask; + /* Add optimizations for SMU13.0.0/10. Reuse the power saving profile */ if ((amdgpu_ip_version(smu->adev, MP1_HWIP, 0) == IP_VERSION(13, 0, 0) && ((smu->adev->pm.fw_version == 0x004e6601) || @@ -2564,25 +2569,13 @@ static int smu_v13_0_0_set_power_profile_mode(struct smu_context *smu, CMN2ASIC_MAPPING_WORKLOAD, PP_SMC_POWER_PROFILE_POWERSAVING); if (workload_type >= 0) - workload_mask |= 1 << workload_type; + smu->workload_mask |= 1 << workload_type; } - smu->workload_mask |= workload_mask; ret = smu_cmn_send_smc_msg_with_param(smu, SMU_MSG_SetWorkloadMask, smu->workload_mask, NULL); - if (!ret) { - smu_cmn_assign_power_profile(smu); - if (smu->power_profile_mode == PP_SMC_POWER_PROFILE_POWERSAVING) { - workload_type = smu_cmn_to_asic_specific_index(smu, - CMN2ASIC_MAPPING_WORKLOAD, - PP_SMC_POWER_PROFILE_FULLSCREEN3D); - smu->power_profile_mode = smu->workload_mask & (1 << workload_type) - ? PP_SMC_POWER_PROFILE_FULLSCREEN3D - : PP_SMC_POWER_PROFILE_BOOTUP_DEFAULT; - } - } return ret; } diff --git a/drivers/gpu/drm/amd/pm/swsmu/smu13/smu_v13_0_7_ppt.c b/drivers/gpu/drm/amd/pm/swsmu/smu13/smu_v13_0_7_ppt.c index 9d0b19419de0..bf1f8e63e228 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/smu13/smu_v13_0_7_ppt.c +++ b/drivers/gpu/drm/amd/pm/swsmu/smu13/smu_v13_0_7_ppt.c @@ -2434,22 +2434,23 @@ do { \ return result; } -static int smu_v13_0_7_set_power_profile_mode(struct smu_context *smu, long *input, uint32_t size) +static int smu_v13_0_7_set_power_profile_mode(struct smu_context *smu, + long *input, uint32_t size, + bool enable) { DpmActivityMonitorCoeffIntExternal_t activity_monitor_external; DpmActivityMonitorCoeffInt_t *activity_monitor = &(activity_monitor_external.DpmActivityMonitorCoeffInt); + uint32_t profile_mode = input[size]; int workload_type, ret = 0; - smu->power_profile_mode = input[size]; - - if (smu->power_profile_mode > PP_SMC_POWER_PROFILE_WINDOW3D) { - dev_err(smu->adev->dev, "Invalid power profile mode %d\n", smu->power_profile_mode); + if (profile_mode > PP_SMC_POWER_PROFILE_WINDOW3D) { + dev_err(smu->adev->dev, "Invalid power profile mode %d\n", profile_mode); return -EINVAL; } - if (smu->power_profile_mode == PP_SMC_POWER_PROFILE_CUSTOM) { + if (enable && profile_mode == PP_SMC_POWER_PROFILE_CUSTOM) { if (size != 8) return -EINVAL; @@ -2496,17 +2497,19 @@ static int smu_v13_0_7_set_power_profile_mode(struct smu_context *smu, long *inp /* conv PP_SMC_POWER_PROFILE* to WORKLOAD_PPLIB_*_BIT */ workload_type = smu_cmn_to_asic_specific_index(smu, CMN2ASIC_MAPPING_WORKLOAD, - smu->power_profile_mode); + profile_mode); if (workload_type < 0) return -EINVAL; + if (enable) + smu->workload_mask |= (1 << workload_type); + else + smu->workload_mask &= ~(1 << workload_type); ret = smu_cmn_send_smc_msg_with_param(smu, SMU_MSG_SetWorkloadMask, smu->workload_mask, NULL); if (ret) dev_err(smu->adev->dev, "[%s] Failed to set work load mask!", __func__); - else - smu_cmn_assign_power_profile(smu); return ret; } diff --git a/drivers/gpu/drm/amd/pm/swsmu/smu14/smu_v14_0_2_ppt.c b/drivers/gpu/drm/amd/pm/swsmu/smu14/smu_v14_0_2_ppt.c index 1aa13d32ceb2..e9c75caaebd7 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/smu14/smu_v14_0_2_ppt.c +++ b/drivers/gpu/drm/amd/pm/swsmu/smu14/smu_v14_0_2_ppt.c @@ -1731,21 +1731,22 @@ static int smu_v14_0_2_get_power_profile_mode(struct smu_context *smu, static int smu_v14_0_2_set_power_profile_mode(struct smu_context *smu, long *input, - uint32_t size) + uint32_t size, + bool enable) { DpmActivityMonitorCoeffIntExternal_t activity_monitor_external; DpmActivityMonitorCoeffInt_t *activity_monitor = &(activity_monitor_external.DpmActivityMonitorCoeffInt); + uint32_t profile_mode = input[size]; int workload_type, ret = 0; uint32_t current_profile_mode = smu->power_profile_mode; - smu->power_profile_mode = input[size]; - if (smu->power_profile_mode >= PP_SMC_POWER_PROFILE_COUNT) { - dev_err(smu->adev->dev, "Invalid power profile mode %d\n", smu->power_profile_mode); + if (profile_mode >= PP_SMC_POWER_PROFILE_COUNT) { + dev_err(smu->adev->dev, "Invalid power profile mode %d\n", profile_mode); return -EINVAL; } - if (smu->power_profile_mode == PP_SMC_POWER_PROFILE_CUSTOM) { + if (enable && profile_mode == PP_SMC_POWER_PROFILE_CUSTOM) { if (size != 9) return -EINVAL; @@ -1795,7 +1796,7 @@ static int smu_v14_0_2_set_power_profile_mode(struct smu_context *smu, } } - if (smu->power_profile_mode == PP_SMC_POWER_PROFILE_COMPUTE) + if (profile_mode == PP_SMC_POWER_PROFILE_COMPUTE) smu_v14_0_deep_sleep_control(smu, false); else if (current_profile_mode == PP_SMC_POWER_PROFILE_COMPUTE) smu_v14_0_deep_sleep_control(smu, true); @@ -1803,15 +1804,16 @@ static int smu_v14_0_2_set_power_profile_mode(struct smu_context *smu, /* conv PP_SMC_POWER_PROFILE* to WORKLOAD_PPLIB_*_BIT */ workload_type = smu_cmn_to_asic_specific_index(smu, CMN2ASIC_MAPPING_WORKLOAD, - smu->power_profile_mode); + profile_mode); if (workload_type < 0) return -EINVAL; + if (enable) + smu->workload_mask |= (1 << workload_type); + else + smu->workload_mask &= ~(1 << workload_type); ret = smu_cmn_send_smc_msg_with_param(smu, SMU_MSG_SetWorkloadMask, - smu->workload_mask, NULL); - - if (!ret) - smu_cmn_assign_power_profile(smu); + smu->workload_mask, NULL); return ret; } diff --git a/drivers/gpu/drm/amd/pm/swsmu/smu_cmn.c b/drivers/gpu/drm/amd/pm/swsmu/smu_cmn.c index bdfc5e617333..91ad434bcdae 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/smu_cmn.c +++ b/drivers/gpu/drm/amd/pm/swsmu/smu_cmn.c @@ -1138,14 +1138,6 @@ int smu_cmn_set_mp1_state(struct smu_context *smu, return ret; } -void smu_cmn_assign_power_profile(struct smu_context *smu) -{ - uint32_t index; - index = fls(smu->workload_mask); - index = index > 0 && index <= WORKLOAD_POLICY_MAX ? index - 1 : 0; - smu->power_profile_mode = smu->workload_setting[index]; -} - bool smu_cmn_is_audio_func_enabled(struct amdgpu_device *adev) { struct pci_dev *p = NULL; diff --git a/drivers/gpu/drm/amd/pm/swsmu/smu_cmn.h b/drivers/gpu/drm/amd/pm/swsmu/smu_cmn.h index 8a801e389659..1de685defe85 100644 --- a/drivers/gpu/drm/amd/pm/swsmu/smu_cmn.h +++ b/drivers/gpu/drm/amd/pm/swsmu/smu_cmn.h @@ -130,8 +130,6 @@ void smu_cmn_init_soft_gpu_metrics(void *table, uint8_t frev, uint8_t crev); int smu_cmn_set_mp1_state(struct smu_context *smu, enum pp_mp1_state mp1_state); -void smu_cmn_assign_power_profile(struct smu_context *smu); - /* * Helper function to make sysfs_emit_at() happy. Align buf to * the current page boundary and record the offset. diff --git a/drivers/gpu/drm/drm_edid.c b/drivers/gpu/drm/drm_edid.c index 855beafb76ff..ad78059ee954 100644 --- a/drivers/gpu/drm/drm_edid.c +++ b/drivers/gpu/drm/drm_edid.c @@ -94,6 +94,8 @@ static int oui(u8 first, u8 second, u8 third) #define EDID_QUIRK_NON_DESKTOP (1 << 12) /* Cap the DSC target bitrate to 15bpp */ #define EDID_QUIRK_CAP_DSC_15BPP (1 << 13) +/* Fix up a particular 5120x1440@240Hz timing */ +#define EDID_QUIRK_FIXUP_5120_1440_240 (1 << 14) #define MICROSOFT_IEEE_OUI 0xca125c @@ -182,6 +184,12 @@ static const struct edid_quirk { EDID_QUIRK('S', 'A', 'M', 596, EDID_QUIRK_PREFER_LARGE_60), EDID_QUIRK('S', 'A', 'M', 638, EDID_QUIRK_PREFER_LARGE_60), + /* Samsung C49G95T */ + EDID_QUIRK('S', 'A', 'M', 0x7053, EDID_QUIRK_FIXUP_5120_1440_240), + + /* Samsung S49AG95 */ + EDID_QUIRK('S', 'A', 'M', 0x71ac, EDID_QUIRK_FIXUP_5120_1440_240), + /* Sony PVM-2541A does up to 12 bpc, but only reports max 8 bpc */ EDID_QUIRK('S', 'N', 'Y', 0x2541, EDID_QUIRK_FORCE_12BPC), @@ -6753,7 +6761,37 @@ static void update_display_info(struct drm_connector *connector, drm_edid_to_eld(connector, drm_edid); } -static struct drm_display_mode *drm_mode_displayid_detailed(struct drm_device *dev, +static void drm_mode_displayid_detailed_edid_quirks(struct drm_connector *connector, + struct drm_display_mode *mode) +{ + unsigned int hsync_width; + unsigned int vsync_width; + + if (connector->display_info.quirks & EDID_QUIRK_FIXUP_5120_1440_240) { + if (mode->hdisplay == 5120 && mode->vdisplay == 1440 && + mode->clock == 1939490) { + hsync_width = mode->hsync_end - mode->hsync_start; + vsync_width = mode->vsync_end - mode->vsync_start; + + mode->clock = 2018490; + mode->hdisplay = 5120; + mode->hsync_start = 5120 + 8; + mode->hsync_end = 5120 + 8 + hsync_width; + mode->htotal = 5200; + + mode->vdisplay = 1440; + mode->vsync_start = 1440 + 165; + mode->vsync_end = 1440 + 165 + vsync_width; + mode->vtotal = 1619; + + drm_dbg_kms(connector->dev, + "[CONNECTOR:%d:%s] Samsung 240Hz mode quirk applied\n", + connector->base.id, connector->name); + } + } +} + +static struct drm_display_mode *drm_mode_displayid_detailed(struct drm_connector *connector, struct displayid_detailed_timings_1 *timings, bool type_7) { @@ -6772,7 +6810,7 @@ static struct drm_display_mode *drm_mode_displayid_detailed(struct drm_device *d bool hsync_positive = (timings->hsync[1] >> 7) & 0x1; bool vsync_positive = (timings->vsync[1] >> 7) & 0x1; - mode = drm_mode_create(dev); + mode = drm_mode_create(connector->dev); if (!mode) return NULL; @@ -6795,6 +6833,9 @@ static struct drm_display_mode *drm_mode_displayid_detailed(struct drm_device *d if (timings->flags & 0x80) mode->type |= DRM_MODE_TYPE_PREFERRED; + + drm_mode_displayid_detailed_edid_quirks(connector, mode); + drm_mode_set_name(mode); return mode; @@ -6817,7 +6858,7 @@ static int add_displayid_detailed_1_modes(struct drm_connector *connector, for (i = 0; i < num_timings; i++) { struct displayid_detailed_timings_1 *timings = &det->timings[i]; - newmode = drm_mode_displayid_detailed(connector->dev, timings, type_7); + newmode = drm_mode_displayid_detailed(connector, timings, type_7); if (!newmode) continue; diff --git a/drivers/misc/lkdtm/bugs.c b/drivers/misc/lkdtm/bugs.c index 62ba01525479..376047beea3d 100644 --- a/drivers/misc/lkdtm/bugs.c +++ b/drivers/misc/lkdtm/bugs.c @@ -445,7 +445,7 @@ static void lkdtm_FAM_BOUNDS(void) pr_err("FAIL: survived access of invalid flexible array member index!\n"); - if (!__has_attribute(__counted_by__)) + if (!IS_ENABLED(CONFIG_CC_HAS_COUNTED_BY)) pr_warn("This is expected since this %s was built with a compiler that does not support __counted_by\n", lkdtm_kernel_info); else if (IS_ENABLED(CONFIG_UBSAN_BOUNDS)) diff --git a/fs/ntfs3/attrib.c b/fs/ntfs3/attrib.c index 0763202d00c9..8d789b017fa9 100644 --- a/fs/ntfs3/attrib.c +++ b/fs/ntfs3/attrib.c @@ -977,7 +977,7 @@ int attr_data_get_block(struct ntfs_inode *ni, CLST vcn, CLST clen, CLST *lcn, /* Check for compressed frame. */ err = attr_is_frame_compressed(ni, attr_b, vcn >> NTFS_LZNT_CUNIT, - &hint); + &hint, run); if (err) goto out; @@ -1521,16 +1521,16 @@ int attr_wof_frame_info(struct ntfs_inode *ni, struct ATTRIB *attr, * attr_is_frame_compressed - Used to detect compressed frame. * * attr - base (primary) attribute segment. + * run - run to use, usually == &ni->file.run. * Only base segments contains valid 'attr->nres.c_unit' */ int attr_is_frame_compressed(struct ntfs_inode *ni, struct ATTRIB *attr, - CLST frame, CLST *clst_data) + CLST frame, CLST *clst_data, struct runs_tree *run) { int err; u32 clst_frame; CLST clen, lcn, vcn, alen, slen, vcn_next; size_t idx; - struct runs_tree *run; *clst_data = 0; @@ -1542,7 +1542,6 @@ int attr_is_frame_compressed(struct ntfs_inode *ni, struct ATTRIB *attr, clst_frame = 1u << attr->nres.c_unit; vcn = frame * clst_frame; - run = &ni->file.run; if (!run_lookup_entry(run, vcn, &lcn, &clen, &idx)) { err = attr_load_runs_vcn(ni, attr->type, attr_name(attr), @@ -1678,7 +1677,7 @@ int attr_allocate_frame(struct ntfs_inode *ni, CLST frame, size_t compr_size, if (err) goto out; - err = attr_is_frame_compressed(ni, attr_b, frame, &clst_data); + err = attr_is_frame_compressed(ni, attr_b, frame, &clst_data, run); if (err) goto out; diff --git a/fs/ntfs3/bitmap.c b/fs/ntfs3/bitmap.c index cf4fe21a5039..04107b950717 100644 --- a/fs/ntfs3/bitmap.c +++ b/fs/ntfs3/bitmap.c @@ -710,20 +710,17 @@ int wnd_set_free(struct wnd_bitmap *wnd, size_t bit, size_t bits) { int err = 0; struct super_block *sb = wnd->sb; - size_t bits0 = bits; u32 wbits = 8 * sb->s_blocksize; size_t iw = bit >> (sb->s_blocksize_bits + 3); u32 wbit = bit & (wbits - 1); struct buffer_head *bh; + u32 op; - while (iw < wnd->nwnd && bits) { - u32 tail, op; - + for (; iw < wnd->nwnd && bits; iw++, bit += op, bits -= op, wbit = 0) { if (iw + 1 == wnd->nwnd) wbits = wnd->bits_last; - tail = wbits - wbit; - op = min_t(u32, tail, bits); + op = min_t(u32, wbits - wbit, bits); bh = wnd_map(wnd, iw); if (IS_ERR(bh)) { @@ -736,20 +733,15 @@ int wnd_set_free(struct wnd_bitmap *wnd, size_t bit, size_t bits) ntfs_bitmap_clear_le(bh->b_data, wbit, op); wnd->free_bits[iw] += op; + wnd->total_zeroes += op; set_buffer_uptodate(bh); mark_buffer_dirty(bh); unlock_buffer(bh); put_bh(bh); - wnd->total_zeroes += op; - bits -= op; - wbit = 0; - iw += 1; + wnd_add_free_ext(wnd, bit, op, false); } - - wnd_add_free_ext(wnd, bit, bits0, false); - return err; } @@ -760,20 +752,17 @@ int wnd_set_used(struct wnd_bitmap *wnd, size_t bit, size_t bits) { int err = 0; struct super_block *sb = wnd->sb; - size_t bits0 = bits; size_t iw = bit >> (sb->s_blocksize_bits + 3); u32 wbits = 8 * sb->s_blocksize; u32 wbit = bit & (wbits - 1); struct buffer_head *bh; + u32 op; - while (iw < wnd->nwnd && bits) { - u32 tail, op; - + for (; iw < wnd->nwnd && bits; iw++, bit += op, bits -= op, wbit = 0) { if (unlikely(iw + 1 == wnd->nwnd)) wbits = wnd->bits_last; - tail = wbits - wbit; - op = min_t(u32, tail, bits); + op = min_t(u32, wbits - wbit, bits); bh = wnd_map(wnd, iw); if (IS_ERR(bh)) { @@ -785,21 +774,16 @@ int wnd_set_used(struct wnd_bitmap *wnd, size_t bit, size_t bits) ntfs_bitmap_set_le(bh->b_data, wbit, op); wnd->free_bits[iw] -= op; + wnd->total_zeroes -= op; set_buffer_uptodate(bh); mark_buffer_dirty(bh); unlock_buffer(bh); put_bh(bh); - wnd->total_zeroes -= op; - bits -= op; - wbit = 0; - iw += 1; + if (!RB_EMPTY_ROOT(&wnd->start_tree)) + wnd_remove_free_ext(wnd, bit, op); } - - if (!RB_EMPTY_ROOT(&wnd->start_tree)) - wnd_remove_free_ext(wnd, bit, bits0); - return err; } @@ -852,15 +836,13 @@ static bool wnd_is_free_hlp(struct wnd_bitmap *wnd, size_t bit, size_t bits) size_t iw = bit >> (sb->s_blocksize_bits + 3); u32 wbits = 8 * sb->s_blocksize; u32 wbit = bit & (wbits - 1); + u32 op; - while (iw < wnd->nwnd && bits) { - u32 tail, op; - + for (; iw < wnd->nwnd && bits; iw++, bits -= op, wbit = 0) { if (unlikely(iw + 1 == wnd->nwnd)) wbits = wnd->bits_last; - tail = wbits - wbit; - op = min_t(u32, tail, bits); + op = min_t(u32, wbits - wbit, bits); if (wbits != wnd->free_bits[iw]) { bool ret; @@ -875,10 +857,6 @@ static bool wnd_is_free_hlp(struct wnd_bitmap *wnd, size_t bit, size_t bits) if (!ret) return false; } - - bits -= op; - wbit = 0; - iw += 1; } return true; @@ -928,6 +906,7 @@ bool wnd_is_used(struct wnd_bitmap *wnd, size_t bit, size_t bits) size_t iw = bit >> (sb->s_blocksize_bits + 3); u32 wbits = 8 * sb->s_blocksize; u32 wbit = bit & (wbits - 1); + u32 op; size_t end; struct rb_node *n; struct e_node *e; @@ -945,14 +924,11 @@ bool wnd_is_used(struct wnd_bitmap *wnd, size_t bit, size_t bits) return false; use_wnd: - while (iw < wnd->nwnd && bits) { - u32 tail, op; - + for (; iw < wnd->nwnd && bits; iw++, bits -= op, wbit = 0) { if (unlikely(iw + 1 == wnd->nwnd)) wbits = wnd->bits_last; - tail = wbits - wbit; - op = min_t(u32, tail, bits); + op = min_t(u32, wbits - wbit, bits); if (wnd->free_bits[iw]) { bool ret; @@ -966,10 +942,6 @@ bool wnd_is_used(struct wnd_bitmap *wnd, size_t bit, size_t bits) if (!ret) goto out; } - - bits -= op; - wbit = 0; - iw += 1; } ret = true; diff --git a/fs/ntfs3/file.c b/fs/ntfs3/file.c index e370eaf9bfe2..3f96a11804c9 100644 --- a/fs/ntfs3/file.c +++ b/fs/ntfs3/file.c @@ -182,13 +182,15 @@ static int ntfs_extend_initialized_size(struct file *file, loff_t pos = valid; int err; + if (valid >= new_valid) + return 0; + if (is_resident(ni)) { ni->i_valid = new_valid; return 0; } WARN_ON(is_compressed(ni)); - WARN_ON(valid >= new_valid); for (;;) { u32 zerofrom, len; @@ -222,7 +224,7 @@ static int ntfs_extend_initialized_size(struct file *file, if (err) goto out; - folio_zero_range(folio, zerofrom, folio_size(folio)); + folio_zero_range(folio, zerofrom, folio_size(folio) - zerofrom); err = ntfs_write_end(file, mapping, pos, len, len, folio, NULL); if (err < 0) @@ -987,6 +989,7 @@ static ssize_t ntfs_compress_write(struct kiocb *iocb, struct iov_iter *from) u64 frame_vbo; pgoff_t index; bool frame_uptodate; + struct folio *folio; if (frame_size < PAGE_SIZE) { /* @@ -1041,8 +1044,9 @@ static ssize_t ntfs_compress_write(struct kiocb *iocb, struct iov_iter *from) if (err) { for (ip = 0; ip < pages_per_frame; ip++) { page = pages[ip]; - unlock_page(page); - put_page(page); + folio = page_folio(page); + folio_unlock(folio); + folio_put(folio); } goto out; } @@ -1052,9 +1056,10 @@ static ssize_t ntfs_compress_write(struct kiocb *iocb, struct iov_iter *from) off = offset_in_page(valid); for (; ip < pages_per_frame; ip++, off = 0) { page = pages[ip]; + folio = page_folio(page); zero_user_segment(page, off, PAGE_SIZE); flush_dcache_page(page); - SetPageUptodate(page); + folio_mark_uptodate(folio); } ni_lock(ni); @@ -1063,9 +1068,10 @@ static ssize_t ntfs_compress_write(struct kiocb *iocb, struct iov_iter *from) for (ip = 0; ip < pages_per_frame; ip++) { page = pages[ip]; - SetPageUptodate(page); - unlock_page(page); - put_page(page); + folio = page_folio(page); + folio_mark_uptodate(folio); + folio_unlock(folio); + folio_put(folio); } if (err) @@ -1107,8 +1113,9 @@ static ssize_t ntfs_compress_write(struct kiocb *iocb, struct iov_iter *from) for (ip = 0; ip < pages_per_frame; ip++) { page = pages[ip]; - unlock_page(page); - put_page(page); + folio = page_folio(page); + folio_unlock(folio); + folio_put(folio); } goto out; } @@ -1149,9 +1156,10 @@ static ssize_t ntfs_compress_write(struct kiocb *iocb, struct iov_iter *from) for (ip = 0; ip < pages_per_frame; ip++) { page = pages[ip]; ClearPageDirty(page); - SetPageUptodate(page); - unlock_page(page); - put_page(page); + folio = page_folio(page); + folio_mark_uptodate(folio); + folio_unlock(folio); + folio_put(folio); } if (err) diff --git a/fs/ntfs3/frecord.c b/fs/ntfs3/frecord.c index 41c7ffad2790..8b39d0ce5f28 100644 --- a/fs/ntfs3/frecord.c +++ b/fs/ntfs3/frecord.c @@ -1900,46 +1900,6 @@ enum REPARSE_SIGN ni_parse_reparse(struct ntfs_inode *ni, struct ATTRIB *attr, return REPARSE_LINK; } -/* - * fiemap_fill_next_extent_k - a copy of fiemap_fill_next_extent - * but it uses 'fe_k' instead of fieinfo->fi_extents_start - */ -static int fiemap_fill_next_extent_k(struct fiemap_extent_info *fieinfo, - struct fiemap_extent *fe_k, u64 logical, - u64 phys, u64 len, u32 flags) -{ - struct fiemap_extent extent; - - /* only count the extents */ - if (fieinfo->fi_extents_max == 0) { - fieinfo->fi_extents_mapped++; - return (flags & FIEMAP_EXTENT_LAST) ? 1 : 0; - } - - if (fieinfo->fi_extents_mapped >= fieinfo->fi_extents_max) - return 1; - - if (flags & FIEMAP_EXTENT_DELALLOC) - flags |= FIEMAP_EXTENT_UNKNOWN; - if (flags & FIEMAP_EXTENT_DATA_ENCRYPTED) - flags |= FIEMAP_EXTENT_ENCODED; - if (flags & (FIEMAP_EXTENT_DATA_TAIL | FIEMAP_EXTENT_DATA_INLINE)) - flags |= FIEMAP_EXTENT_NOT_ALIGNED; - - memset(&extent, 0, sizeof(extent)); - extent.fe_logical = logical; - extent.fe_physical = phys; - extent.fe_length = len; - extent.fe_flags = flags; - - memcpy(fe_k + fieinfo->fi_extents_mapped, &extent, sizeof(extent)); - - fieinfo->fi_extents_mapped++; - if (fieinfo->fi_extents_mapped == fieinfo->fi_extents_max) - return 1; - return (flags & FIEMAP_EXTENT_LAST) ? 1 : 0; -} - /* * ni_fiemap - Helper for file_fiemap(). * @@ -1950,11 +1910,9 @@ int ni_fiemap(struct ntfs_inode *ni, struct fiemap_extent_info *fieinfo, __u64 vbo, __u64 len) { int err = 0; - struct fiemap_extent *fe_k = NULL; struct ntfs_sb_info *sbi = ni->mi.sbi; u8 cluster_bits = sbi->cluster_bits; - struct runs_tree *run; - struct rw_semaphore *run_lock; + struct runs_tree run; struct ATTRIB *attr; CLST vcn = vbo >> cluster_bits; CLST lcn, clen; @@ -1965,13 +1923,11 @@ int ni_fiemap(struct ntfs_inode *ni, struct fiemap_extent_info *fieinfo, u32 flags; bool ok; + run_init(&run); if (S_ISDIR(ni->vfs_inode.i_mode)) { - run = &ni->dir.alloc_run; attr = ni_find_attr(ni, NULL, NULL, ATTR_ALLOC, I30_NAME, ARRAY_SIZE(I30_NAME), NULL, NULL); - run_lock = &ni->dir.run_lock; } else { - run = &ni->file.run; attr = ni_find_attr(ni, NULL, NULL, ATTR_DATA, NULL, 0, NULL, NULL); if (!attr) { @@ -1986,7 +1942,6 @@ int ni_fiemap(struct ntfs_inode *ni, struct fiemap_extent_info *fieinfo, "fiemap is not supported for compressed file (cp -r)"); goto out; } - run_lock = &ni->file.run_lock; } if (!attr || !attr->non_res) { @@ -1998,51 +1953,32 @@ int ni_fiemap(struct ntfs_inode *ni, struct fiemap_extent_info *fieinfo, goto out; } - /* - * To avoid lock problems replace pointer to user memory by pointer to kernel memory. - */ - fe_k = kmalloc_array(fieinfo->fi_extents_max, - sizeof(struct fiemap_extent), - GFP_NOFS | __GFP_ZERO); - if (!fe_k) { - err = -ENOMEM; - goto out; - } - end = vbo + len; alloc_size = le64_to_cpu(attr->nres.alloc_size); if (end > alloc_size) end = alloc_size; - down_read(run_lock); - while (vbo < end) { if (idx == -1) { - ok = run_lookup_entry(run, vcn, &lcn, &clen, &idx); + ok = run_lookup_entry(&run, vcn, &lcn, &clen, &idx); } else { CLST vcn_next = vcn; - ok = run_get_entry(run, ++idx, &vcn, &lcn, &clen) && + ok = run_get_entry(&run, ++idx, &vcn, &lcn, &clen) && vcn == vcn_next; if (!ok) vcn = vcn_next; } if (!ok) { - up_read(run_lock); - down_write(run_lock); - err = attr_load_runs_vcn(ni, attr->type, attr_name(attr), - attr->name_len, run, vcn); - - up_write(run_lock); - down_read(run_lock); + attr->name_len, &run, vcn); if (err) break; - ok = run_lookup_entry(run, vcn, &lcn, &clen, &idx); + ok = run_lookup_entry(&run, vcn, &lcn, &clen, &idx); if (!ok) { err = -EINVAL; @@ -2067,8 +2003,9 @@ int ni_fiemap(struct ntfs_inode *ni, struct fiemap_extent_info *fieinfo, } else if (is_attr_compressed(attr)) { CLST clst_data; - err = attr_is_frame_compressed( - ni, attr, vcn >> attr->nres.c_unit, &clst_data); + err = attr_is_frame_compressed(ni, attr, + vcn >> attr->nres.c_unit, + &clst_data, &run); if (err) break; if (clst_data < NTFS_LZNT_CLUSTERS) @@ -2097,8 +2034,8 @@ int ni_fiemap(struct ntfs_inode *ni, struct fiemap_extent_info *fieinfo, if (vbo + dlen >= end) flags |= FIEMAP_EXTENT_LAST; - err = fiemap_fill_next_extent_k(fieinfo, fe_k, vbo, lbo, - dlen, flags); + err = fiemap_fill_next_extent(fieinfo, vbo, lbo, dlen, + flags); if (err < 0) break; @@ -2119,8 +2056,7 @@ int ni_fiemap(struct ntfs_inode *ni, struct fiemap_extent_info *fieinfo, if (vbo + bytes >= end) flags |= FIEMAP_EXTENT_LAST; - err = fiemap_fill_next_extent_k(fieinfo, fe_k, vbo, lbo, bytes, - flags); + err = fiemap_fill_next_extent(fieinfo, vbo, lbo, bytes, flags); if (err < 0) break; if (err == 1) { @@ -2131,19 +2067,8 @@ int ni_fiemap(struct ntfs_inode *ni, struct fiemap_extent_info *fieinfo, vbo += bytes; } - up_read(run_lock); - - /* - * Copy to user memory out of lock - */ - if (copy_to_user(fieinfo->fi_extents_start, fe_k, - fieinfo->fi_extents_max * - sizeof(struct fiemap_extent))) { - err = -EFAULT; - } - out: - kfree(fe_k); + run_close(&run); return err; } @@ -2672,7 +2597,8 @@ int ni_read_frame(struct ntfs_inode *ni, u64 frame_vbo, struct page **pages, down_write(&ni->file.run_lock); run_truncate_around(run, le64_to_cpu(attr->nres.svcn)); frame = frame_vbo >> (cluster_bits + NTFS_LZNT_CUNIT); - err = attr_is_frame_compressed(ni, attr, frame, &clst_data); + err = attr_is_frame_compressed(ni, attr, frame, &clst_data, + run); up_write(&ni->file.run_lock); if (err) goto out1; diff --git a/fs/ntfs3/fsntfs.c b/fs/ntfs3/fsntfs.c index 0fa636038b4e..03471bc9371c 100644 --- a/fs/ntfs3/fsntfs.c +++ b/fs/ntfs3/fsntfs.c @@ -2699,4 +2699,4 @@ int ntfs_set_label(struct ntfs_sb_info *sbi, u8 *label, int len) out: __putname(uni); return err; -} \ No newline at end of file +} diff --git a/fs/ntfs3/ntfs_fs.h b/fs/ntfs3/ntfs_fs.h index 26e1e1379c04..cd8e8374bb5a 100644 --- a/fs/ntfs3/ntfs_fs.h +++ b/fs/ntfs3/ntfs_fs.h @@ -446,7 +446,8 @@ int attr_wof_frame_info(struct ntfs_inode *ni, struct ATTRIB *attr, struct runs_tree *run, u64 frame, u64 frames, u8 frame_bits, u32 *ondisk_size, u64 *vbo_data); int attr_is_frame_compressed(struct ntfs_inode *ni, struct ATTRIB *attr, - CLST frame, CLST *clst_data); + CLST frame, CLST *clst_data, + struct runs_tree *run); int attr_allocate_frame(struct ntfs_inode *ni, CLST frame, size_t compr_size, u64 new_valid); int attr_collapse_range(struct ntfs_inode *ni, u64 vbo, u64 bytes); diff --git a/fs/ntfs3/record.c b/fs/ntfs3/record.c index f810f0419d25..61d53d39f3b9 100644 --- a/fs/ntfs3/record.c +++ b/fs/ntfs3/record.c @@ -212,7 +212,7 @@ struct ATTRIB *mi_enum_attr(struct mft_inode *mi, struct ATTRIB *attr) return NULL; if (off >= used || off < MFTRECORD_FIXUP_OFFSET_1 || - !IS_ALIGNED(off, 4)) { + !IS_ALIGNED(off, 8)) { return NULL; } @@ -236,8 +236,11 @@ struct ATTRIB *mi_enum_attr(struct mft_inode *mi, struct ATTRIB *attr) off += asize; } - /* Can we use the first field (attr->type). */ - /* NOTE: this code also checks attr->size availability. */ + /* + * Can we use the first fields: + * attr->type, + * attr->size + */ if (off + 8 > used) { static_assert(ALIGN(sizeof(enum ATTR_TYPE), 8) == 8); return NULL; @@ -259,10 +262,17 @@ struct ATTRIB *mi_enum_attr(struct mft_inode *mi, struct ATTRIB *attr) asize = le32_to_cpu(attr->size); + if (!IS_ALIGNED(asize, 8)) + return NULL; + /* Check overflow and boundary. */ if (off + asize < off || off + asize > used) return NULL; + /* Can we use the field attr->non_res. */ + if (off + 9 > used) + return NULL; + /* Check size of attribute. */ if (!attr->non_res) { /* Check resident fields. */ diff --git a/fs/ntfs3/run.c b/fs/ntfs3/run.c index 58e988cd8049..6e86d66197ef 100644 --- a/fs/ntfs3/run.c +++ b/fs/ntfs3/run.c @@ -1055,8 +1055,8 @@ int run_unpack_ex(struct runs_tree *run, struct ntfs_sb_info *sbi, CLST ino, { int ret, err; CLST next_vcn, lcn, len; - size_t index; - bool ok; + size_t index, done; + bool ok, zone; struct wnd_bitmap *wnd; ret = run_unpack(run, sbi, ino, svcn, evcn, vcn, run_buf, run_buf_size); @@ -1087,8 +1087,9 @@ int run_unpack_ex(struct runs_tree *run, struct ntfs_sb_info *sbi, CLST ino, continue; down_read_nested(&wnd->rw_lock, BITMAP_MUTEX_CLUSTERS); + zone = max(wnd->zone_bit, lcn) < min(wnd->zone_end, lcn + len); /* Check for free blocks. */ - ok = wnd_is_used(wnd, lcn, len); + ok = !zone && wnd_is_used(wnd, lcn, len); up_read(&wnd->rw_lock); if (ok) continue; @@ -1096,14 +1097,33 @@ int run_unpack_ex(struct runs_tree *run, struct ntfs_sb_info *sbi, CLST ino, /* Looks like volume is corrupted. */ ntfs_set_state(sbi, NTFS_DIRTY_ERROR); - if (down_write_trylock(&wnd->rw_lock)) { - /* Mark all zero bits as used in range [lcn, lcn+len). */ - size_t done; - err = wnd_set_used_safe(wnd, lcn, len, &done); - up_write(&wnd->rw_lock); - if (err) - return err; + if (!down_write_trylock(&wnd->rw_lock)) + continue; + + if (zone) { + /* + * Range [lcn, lcn + len) intersects with zone. + * To avoid complex with zone just turn it off. + */ + wnd_zone_set(wnd, 0, 0); + } + + /* Mark all zero bits as used in range [lcn, lcn+len). */ + err = wnd_set_used_safe(wnd, lcn, len, &done); + if (zone) { + /* Restore zone. Lock mft run. */ + struct rw_semaphore *lock = + is_mounted(sbi) ? &sbi->mft.ni->file.run_lock : + NULL; + if (lock) + down_read(lock); + ntfs_refresh_zone(sbi); + if (lock) + up_read(lock); } + up_write(&wnd->rw_lock); + if (err) + return err; } return ret; diff --git a/include/linux/compiler_attributes.h b/include/linux/compiler_attributes.h index 32284cd26d52..c16d4199bf92 100644 --- a/include/linux/compiler_attributes.h +++ b/include/linux/compiler_attributes.h @@ -94,19 +94,6 @@ # define __copy(symbol) #endif -/* - * Optional: only supported since gcc >= 15 - * Optional: only supported since clang >= 18 - * - * gcc: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=108896 - * clang: https://github.com/llvm/llvm-project/pull/76348 - */ -#if __has_attribute(__counted_by__) -# define __counted_by(member) __attribute__((__counted_by__(member))) -#else -# define __counted_by(member) -#endif - /* * Optional: not supported by gcc * Optional: only supported since clang >= 14.0 diff --git a/include/linux/compiler_types.h b/include/linux/compiler_types.h index 1a957ea2f4fe..639be0f30b45 100644 --- a/include/linux/compiler_types.h +++ b/include/linux/compiler_types.h @@ -323,6 +323,25 @@ struct ftrace_likely_data { #define __no_sanitize_or_inline __always_inline #endif +/* + * Optional: only supported since gcc >= 15 + * Optional: only supported since clang >= 18 + * + * gcc: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=108896 + * clang: https://github.com/llvm/llvm-project/pull/76348 + * + * __bdos on clang < 19.1.2 can erroneously return 0: + * https://github.com/llvm/llvm-project/pull/110497 + * + * __bdos on clang < 19.1.3 can be off by 4: + * https://github.com/llvm/llvm-project/pull/112636 + */ +#ifdef CONFIG_CC_HAS_COUNTED_BY +# define __counted_by(member) __attribute__((__counted_by__(member))) +#else +# define __counted_by(member) +#endif + /* * Apply __counted_by() when the Endianness matches to increase test coverage. */ diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h index 6e3bdf8e38bc..6b6f05404304 100644 --- a/include/linux/mm_types.h +++ b/include/linux/mm_types.h @@ -1335,6 +1335,7 @@ enum tlb_flush_reason { TLB_LOCAL_SHOOTDOWN, TLB_LOCAL_MM_SHOOTDOWN, TLB_REMOTE_SEND_IPI, + TLB_REMOTE_WRONG_CPU, NR_TLB_FLUSH_REASONS, }; diff --git a/init/Kconfig b/init/Kconfig index 38dbd16da6a9..504e8a7c4e2a 100644 --- a/init/Kconfig +++ b/init/Kconfig @@ -120,6 +120,14 @@ config CC_HAS_ASM_INLINE config CC_HAS_NO_PROFILE_FN_ATTR def_bool $(success,echo '__attribute__((no_profile_instrument_function)) int x();' | $(CC) -x c - -c -o /dev/null -Werror) +# clang needs to be at least 19.1.3 to avoid __bdos miscalculations +# https://github.com/llvm/llvm-project/pull/110497 +# https://github.com/llvm/llvm-project/pull/112636 +# TODO: when gcc 15 is released remove the build test and add gcc version check +config CC_HAS_COUNTED_BY + def_bool $(success,echo 'struct flex { int count; int array[] __attribute__((__counted_by__(count))); };' | $(CC) $(CLANG_FLAGS) -x c - -c -o /dev/null -Werror) + depends on !(CC_IS_CLANG && CLANG_VERSION < 190103) + config PAHOLE_VERSION int default $(shell,$(srctree)/scripts/pahole-version.sh $(PAHOLE)) diff --git a/kernel/sched/core.c b/kernel/sched/core.c index 719e0ed1e976..b35752fdbcc0 100644 --- a/kernel/sched/core.c +++ b/kernel/sched/core.c @@ -3734,28 +3734,38 @@ ttwu_do_activate(struct rq *rq, struct task_struct *p, int wake_flags, */ static int ttwu_runnable(struct task_struct *p, int wake_flags) { - struct rq_flags rf; - struct rq *rq; - int ret = 0; + CLASS(__task_rq_lock, rq_guard)(p); + struct rq *rq = rq_guard.rq; - rq = __task_rq_lock(p, &rf); - if (task_on_rq_queued(p)) { - update_rq_clock(rq); - if (p->se.sched_delayed) - enqueue_task(rq, p, ENQUEUE_NOCLOCK | ENQUEUE_DELAYED); - if (!task_on_cpu(rq, p)) { - /* - * When on_rq && !on_cpu the task is preempted, see if - * it should preempt the task that is current now. - */ - wakeup_preempt(rq, p, wake_flags); + if (!task_on_rq_queued(p)) + return 0; + + update_rq_clock(rq); + if (p->se.sched_delayed) { + int queue_flags = ENQUEUE_DELAYED | ENQUEUE_NOCLOCK; + + /* + * Since sched_delayed means we cannot be current anywhere, + * dequeue it here and have it fall through to the + * select_task_rq() case further along the ttwu() path. + */ + if (rq->nr_running > 1 && p->nr_cpus_allowed > 1) { + dequeue_task(rq, p, DEQUEUE_SLEEP | queue_flags); + return 0; } - ttwu_do_wakeup(p); - ret = 1; + + enqueue_task(rq, p, queue_flags); } - __task_rq_unlock(rq, &rf); + if (!task_on_cpu(rq, p)) { + /* + * When on_rq && !on_cpu the task is preempted, see if + * it should preempt the task that is current now. + */ + wakeup_preempt(rq, p, wake_flags); + } + ttwu_do_wakeup(p); - return ret; + return 1; } #ifdef CONFIG_SMP diff --git a/kernel/sched/sched.h b/kernel/sched/sched.h index f610df2e0811..e7fbb1d0f316 100644 --- a/kernel/sched/sched.h +++ b/kernel/sched/sched.h @@ -1779,6 +1779,11 @@ task_rq_unlock(struct rq *rq, struct task_struct *p, struct rq_flags *rf) raw_spin_unlock_irqrestore(&p->pi_lock, rf->flags); } +DEFINE_LOCK_GUARD_1(__task_rq_lock, struct task_struct, + _T->rq = __task_rq_lock(_T->lock, &_T->rf), + __task_rq_unlock(_T->rq, &_T->rf), + struct rq *rq; struct rq_flags rf) + DEFINE_LOCK_GUARD_1(task_rq_lock, struct task_struct, _T->rq = task_rq_lock(_T->lock, &_T->rf), task_rq_unlock(_T->rq, _T->lock, &_T->rf), diff --git a/lib/overflow_kunit.c b/lib/overflow_kunit.c index 2abc78367dd1..5222c6393f11 100644 --- a/lib/overflow_kunit.c +++ b/lib/overflow_kunit.c @@ -1187,7 +1187,7 @@ static void DEFINE_FLEX_test(struct kunit *test) { /* Using _RAW_ on a __counted_by struct will initialize "counter" to zero */ DEFINE_RAW_FLEX(struct foo, two_but_zero, array, 2); -#if __has_attribute(__counted_by__) +#ifdef CONFIG_CC_HAS_COUNTED_BY int expected_raw_size = sizeof(struct foo); #else int expected_raw_size = sizeof(struct foo) + 2 * sizeof(s16); diff --git a/scripts/package/PKGBUILD b/scripts/package/PKGBUILD index f83493838cf9..4010899652b8 100644 --- a/scripts/package/PKGBUILD +++ b/scripts/package/PKGBUILD @@ -91,6 +91,11 @@ _package-headers() { "${srctree}/scripts/package/install-extmod-build" "${builddir}" fi + # required when DEBUG_INFO_BTF_MODULES is enabled + if [ -f tools/bpf/resolve_btfids/resolve_btfids ]; then + install -Dt "$builddir/tools/bpf/resolve_btfids" tools/bpf/resolve_btfids/resolve_btfids + fi + echo "Installing System.map and config..." mkdir -p "${builddir}" cp System.map "${builddir}/System.map" -- 2.47.0 From 9001aa3709fdcb60967ed205910b873f14eed07b Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 11 Nov 2024 09:20:32 +0100 Subject: [PATCH 08/13] ksm Signed-off-by: Peter Jung --- arch/alpha/kernel/syscalls/syscall.tbl | 3 + arch/arm/tools/syscall.tbl | 3 + arch/m68k/kernel/syscalls/syscall.tbl | 3 + arch/microblaze/kernel/syscalls/syscall.tbl | 3 + arch/mips/kernel/syscalls/syscall_n32.tbl | 3 + arch/mips/kernel/syscalls/syscall_n64.tbl | 3 + arch/mips/kernel/syscalls/syscall_o32.tbl | 3 + arch/parisc/kernel/syscalls/syscall.tbl | 3 + arch/powerpc/kernel/syscalls/syscall.tbl | 3 + arch/s390/kernel/syscalls/syscall.tbl | 3 + arch/sh/kernel/syscalls/syscall.tbl | 3 + arch/sparc/kernel/syscalls/syscall.tbl | 3 + arch/x86/entry/syscalls/syscall_32.tbl | 3 + arch/x86/entry/syscalls/syscall_64.tbl | 3 + arch/xtensa/kernel/syscalls/syscall.tbl | 3 + include/linux/syscalls.h | 3 + include/uapi/asm-generic/unistd.h | 9 +- kernel/sys.c | 138 ++++++++++++++++++ kernel/sys_ni.c | 3 + scripts/syscall.tbl | 3 + .../arch/powerpc/entry/syscalls/syscall.tbl | 3 + .../perf/arch/s390/entry/syscalls/syscall.tbl | 3 + 22 files changed, 206 insertions(+), 1 deletion(-) diff --git a/arch/alpha/kernel/syscalls/syscall.tbl b/arch/alpha/kernel/syscalls/syscall.tbl index 74720667fe09..e6a11f3c0a2e 100644 --- a/arch/alpha/kernel/syscalls/syscall.tbl +++ b/arch/alpha/kernel/syscalls/syscall.tbl @@ -502,3 +502,6 @@ 570 common lsm_set_self_attr sys_lsm_set_self_attr 571 common lsm_list_modules sys_lsm_list_modules 572 common mseal sys_mseal +573 common process_ksm_enable sys_process_ksm_enable +574 common process_ksm_disable sys_process_ksm_disable +575 common process_ksm_status sys_process_ksm_status diff --git a/arch/arm/tools/syscall.tbl b/arch/arm/tools/syscall.tbl index 23c98203c40f..10a3099decbe 100644 --- a/arch/arm/tools/syscall.tbl +++ b/arch/arm/tools/syscall.tbl @@ -477,3 +477,6 @@ 460 common lsm_set_self_attr sys_lsm_set_self_attr 461 common lsm_list_modules sys_lsm_list_modules 462 common mseal sys_mseal +463 common process_ksm_enable sys_process_ksm_enable +464 common process_ksm_disable sys_process_ksm_disable +465 common process_ksm_status sys_process_ksm_status diff --git a/arch/m68k/kernel/syscalls/syscall.tbl b/arch/m68k/kernel/syscalls/syscall.tbl index 22a3cbd4c602..12d2c7594bf0 100644 --- a/arch/m68k/kernel/syscalls/syscall.tbl +++ b/arch/m68k/kernel/syscalls/syscall.tbl @@ -462,3 +462,6 @@ 460 common lsm_set_self_attr sys_lsm_set_self_attr 461 common lsm_list_modules sys_lsm_list_modules 462 common mseal sys_mseal +463 common process_ksm_enable sys_process_ksm_enable +464 common process_ksm_disable sys_process_ksm_disable +465 common process_ksm_status sys_process_ksm_status diff --git a/arch/microblaze/kernel/syscalls/syscall.tbl b/arch/microblaze/kernel/syscalls/syscall.tbl index 2b81a6bd78b2..e2a93c856eed 100644 --- a/arch/microblaze/kernel/syscalls/syscall.tbl +++ b/arch/microblaze/kernel/syscalls/syscall.tbl @@ -468,3 +468,6 @@ 460 common lsm_set_self_attr sys_lsm_set_self_attr 461 common lsm_list_modules sys_lsm_list_modules 462 common mseal sys_mseal +463 common process_ksm_enable sys_process_ksm_enable +464 common process_ksm_disable sys_process_ksm_disable +465 common process_ksm_status sys_process_ksm_status diff --git a/arch/mips/kernel/syscalls/syscall_n32.tbl b/arch/mips/kernel/syscalls/syscall_n32.tbl index 953f5b7dc723..b921fbf56fa6 100644 --- a/arch/mips/kernel/syscalls/syscall_n32.tbl +++ b/arch/mips/kernel/syscalls/syscall_n32.tbl @@ -401,3 +401,6 @@ 460 n32 lsm_set_self_attr sys_lsm_set_self_attr 461 n32 lsm_list_modules sys_lsm_list_modules 462 n32 mseal sys_mseal +463 n32 process_ksm_enable sys_process_ksm_enable +464 n32 process_ksm_disable sys_process_ksm_disable +465 n32 process_ksm_status sys_process_ksm_status diff --git a/arch/mips/kernel/syscalls/syscall_n64.tbl b/arch/mips/kernel/syscalls/syscall_n64.tbl index 1464c6be6eb3..8d7f9ddd66f4 100644 --- a/arch/mips/kernel/syscalls/syscall_n64.tbl +++ b/arch/mips/kernel/syscalls/syscall_n64.tbl @@ -377,3 +377,6 @@ 460 n64 lsm_set_self_attr sys_lsm_set_self_attr 461 n64 lsm_list_modules sys_lsm_list_modules 462 n64 mseal sys_mseal +463 n64 process_ksm_enable sys_process_ksm_enable +464 n64 process_ksm_disable sys_process_ksm_disable +465 n64 process_ksm_status sys_process_ksm_status diff --git a/arch/mips/kernel/syscalls/syscall_o32.tbl b/arch/mips/kernel/syscalls/syscall_o32.tbl index 2439a2491cff..9d6142739954 100644 --- a/arch/mips/kernel/syscalls/syscall_o32.tbl +++ b/arch/mips/kernel/syscalls/syscall_o32.tbl @@ -450,3 +450,6 @@ 460 o32 lsm_set_self_attr sys_lsm_set_self_attr 461 o32 lsm_list_modules sys_lsm_list_modules 462 o32 mseal sys_mseal +463 o32 process_ksm_enable sys_process_ksm_enable +464 o32 process_ksm_disable sys_process_ksm_disable +465 o32 process_ksm_status sys_process_ksm_status diff --git a/arch/parisc/kernel/syscalls/syscall.tbl b/arch/parisc/kernel/syscalls/syscall.tbl index 66dc406b12e4..9d46476fd908 100644 --- a/arch/parisc/kernel/syscalls/syscall.tbl +++ b/arch/parisc/kernel/syscalls/syscall.tbl @@ -461,3 +461,6 @@ 460 common lsm_set_self_attr sys_lsm_set_self_attr 461 common lsm_list_modules sys_lsm_list_modules 462 common mseal sys_mseal +463 common process_ksm_enable sys_process_ksm_enable +464 common process_ksm_disable sys_process_ksm_disable +465 common process_ksm_status sys_process_ksm_status diff --git a/arch/powerpc/kernel/syscalls/syscall.tbl b/arch/powerpc/kernel/syscalls/syscall.tbl index ebae8415dfbb..16f71bc2f6f0 100644 --- a/arch/powerpc/kernel/syscalls/syscall.tbl +++ b/arch/powerpc/kernel/syscalls/syscall.tbl @@ -553,3 +553,6 @@ 460 common lsm_set_self_attr sys_lsm_set_self_attr 461 common lsm_list_modules sys_lsm_list_modules 462 common mseal sys_mseal +463 common process_ksm_enable sys_process_ksm_enable +464 common process_ksm_disable sys_process_ksm_disable +465 common process_ksm_status sys_process_ksm_status diff --git a/arch/s390/kernel/syscalls/syscall.tbl b/arch/s390/kernel/syscalls/syscall.tbl index 01071182763e..7394bad8178e 100644 --- a/arch/s390/kernel/syscalls/syscall.tbl +++ b/arch/s390/kernel/syscalls/syscall.tbl @@ -465,3 +465,6 @@ 460 common lsm_set_self_attr sys_lsm_set_self_attr sys_lsm_set_self_attr 461 common lsm_list_modules sys_lsm_list_modules sys_lsm_list_modules 462 common mseal sys_mseal sys_mseal +463 common process_ksm_enable sys_process_ksm_enable sys_process_ksm_enable +464 common process_ksm_disable sys_process_ksm_disable sys_process_ksm_disable +465 common process_ksm_status sys_process_ksm_status sys_process_ksm_status diff --git a/arch/sh/kernel/syscalls/syscall.tbl b/arch/sh/kernel/syscalls/syscall.tbl index c55fd7696d40..b9fc31221b87 100644 --- a/arch/sh/kernel/syscalls/syscall.tbl +++ b/arch/sh/kernel/syscalls/syscall.tbl @@ -466,3 +466,6 @@ 460 common lsm_set_self_attr sys_lsm_set_self_attr 461 common lsm_list_modules sys_lsm_list_modules 462 common mseal sys_mseal +463 common process_ksm_enable sys_process_ksm_enable +464 common process_ksm_disable sys_process_ksm_disable +465 common process_ksm_status sys_process_ksm_status diff --git a/arch/sparc/kernel/syscalls/syscall.tbl b/arch/sparc/kernel/syscalls/syscall.tbl index cfdfb3707c16..0d79fd772854 100644 --- a/arch/sparc/kernel/syscalls/syscall.tbl +++ b/arch/sparc/kernel/syscalls/syscall.tbl @@ -508,3 +508,6 @@ 460 common lsm_set_self_attr sys_lsm_set_self_attr 461 common lsm_list_modules sys_lsm_list_modules 462 common mseal sys_mseal +463 common process_ksm_enable sys_process_ksm_enable +464 common process_ksm_disable sys_process_ksm_disable +465 common process_ksm_status sys_process_ksm_status diff --git a/arch/x86/entry/syscalls/syscall_32.tbl b/arch/x86/entry/syscalls/syscall_32.tbl index 534c74b14fab..c546a30575f1 100644 --- a/arch/x86/entry/syscalls/syscall_32.tbl +++ b/arch/x86/entry/syscalls/syscall_32.tbl @@ -468,3 +468,6 @@ 460 i386 lsm_set_self_attr sys_lsm_set_self_attr 461 i386 lsm_list_modules sys_lsm_list_modules 462 i386 mseal sys_mseal +463 i386 process_ksm_enable sys_process_ksm_enable +464 i386 process_ksm_disable sys_process_ksm_disable +465 i386 process_ksm_status sys_process_ksm_status diff --git a/arch/x86/entry/syscalls/syscall_64.tbl b/arch/x86/entry/syscalls/syscall_64.tbl index 7093ee21c0d1..0fcd10ba8dfe 100644 --- a/arch/x86/entry/syscalls/syscall_64.tbl +++ b/arch/x86/entry/syscalls/syscall_64.tbl @@ -386,6 +386,9 @@ 460 common lsm_set_self_attr sys_lsm_set_self_attr 461 common lsm_list_modules sys_lsm_list_modules 462 common mseal sys_mseal +463 common process_ksm_enable sys_process_ksm_enable +464 common process_ksm_disable sys_process_ksm_disable +465 common process_ksm_status sys_process_ksm_status # # Due to a historical design error, certain syscalls are numbered differently diff --git a/arch/xtensa/kernel/syscalls/syscall.tbl b/arch/xtensa/kernel/syscalls/syscall.tbl index 67083fc1b2f5..c1aecee4ad9b 100644 --- a/arch/xtensa/kernel/syscalls/syscall.tbl +++ b/arch/xtensa/kernel/syscalls/syscall.tbl @@ -433,3 +433,6 @@ 460 common lsm_set_self_attr sys_lsm_set_self_attr 461 common lsm_list_modules sys_lsm_list_modules 462 common mseal sys_mseal +463 common process_ksm_enable sys_process_ksm_enable +464 common process_ksm_disable sys_process_ksm_disable +465 common process_ksm_status sys_process_ksm_status diff --git a/include/linux/syscalls.h b/include/linux/syscalls.h index 5758104921e6..cc9c4fac2412 100644 --- a/include/linux/syscalls.h +++ b/include/linux/syscalls.h @@ -818,6 +818,9 @@ asmlinkage long sys_madvise(unsigned long start, size_t len, int behavior); asmlinkage long sys_process_madvise(int pidfd, const struct iovec __user *vec, size_t vlen, int behavior, unsigned int flags); asmlinkage long sys_process_mrelease(int pidfd, unsigned int flags); +asmlinkage long sys_process_ksm_enable(int pidfd, unsigned int flags); +asmlinkage long sys_process_ksm_disable(int pidfd, unsigned int flags); +asmlinkage long sys_process_ksm_status(int pidfd, unsigned int flags); asmlinkage long sys_remap_file_pages(unsigned long start, unsigned long size, unsigned long prot, unsigned long pgoff, unsigned long flags); diff --git a/include/uapi/asm-generic/unistd.h b/include/uapi/asm-generic/unistd.h index 5bf6148cac2b..613e559ad6e0 100644 --- a/include/uapi/asm-generic/unistd.h +++ b/include/uapi/asm-generic/unistd.h @@ -841,8 +841,15 @@ __SYSCALL(__NR_lsm_list_modules, sys_lsm_list_modules) #define __NR_mseal 462 __SYSCALL(__NR_mseal, sys_mseal) +#define __NR_process_ksm_enable 463 +__SYSCALL(__NR_process_ksm_enable, sys_process_ksm_enable) +#define __NR_process_ksm_disable 464 +__SYSCALL(__NR_process_ksm_disable, sys_process_ksm_disable) +#define __NR_process_ksm_status 465 +__SYSCALL(__NR_process_ksm_status, sys_process_ksm_status) + #undef __NR_syscalls -#define __NR_syscalls 463 +#define __NR_syscalls 466 /* * 32 bit systems traditionally used different diff --git a/kernel/sys.c b/kernel/sys.c index 4da31f28fda8..fcd3aeaddd05 100644 --- a/kernel/sys.c +++ b/kernel/sys.c @@ -2791,6 +2791,144 @@ SYSCALL_DEFINE5(prctl, int, option, unsigned long, arg2, unsigned long, arg3, return error; } +#ifdef CONFIG_KSM +enum pkc_action { + PKSM_ENABLE = 0, + PKSM_DISABLE, + PKSM_STATUS, +}; + +static long do_process_ksm_control(int pidfd, enum pkc_action action) +{ + long ret; + struct task_struct *task; + struct mm_struct *mm; + unsigned int f_flags; + + task = pidfd_get_task(pidfd, &f_flags); + if (IS_ERR(task)) { + ret = PTR_ERR(task); + goto out; + } + + /* Require PTRACE_MODE_READ to avoid leaking ASLR metadata. */ + mm = mm_access(task, PTRACE_MODE_READ_FSCREDS); + if (IS_ERR_OR_NULL(mm)) { + ret = IS_ERR(mm) ? PTR_ERR(mm) : -ESRCH; + goto release_task; + } + + /* Require CAP_SYS_NICE for influencing process performance. */ + if (!capable(CAP_SYS_NICE)) { + ret = -EPERM; + goto release_mm; + } + + if (mmap_write_lock_killable(mm)) { + ret = -EINTR; + goto release_mm; + } + + switch (action) { + case PKSM_ENABLE: + ret = ksm_enable_merge_any(mm); + break; + case PKSM_DISABLE: + ret = ksm_disable_merge_any(mm); + break; + case PKSM_STATUS: + ret = !!test_bit(MMF_VM_MERGE_ANY, &mm->flags); + break; + } + + mmap_write_unlock(mm); + +release_mm: + mmput(mm); +release_task: + put_task_struct(task); +out: + return ret; +} +#endif /* CONFIG_KSM */ + +SYSCALL_DEFINE2(process_ksm_enable, int, pidfd, unsigned int, flags) +{ +#ifdef CONFIG_KSM + if (flags != 0) + return -EINVAL; + + return do_process_ksm_control(pidfd, PKSM_ENABLE); +#else /* CONFIG_KSM */ + return -ENOSYS; +#endif /* CONFIG_KSM */ +} + +SYSCALL_DEFINE2(process_ksm_disable, int, pidfd, unsigned int, flags) +{ +#ifdef CONFIG_KSM + if (flags != 0) + return -EINVAL; + + return do_process_ksm_control(pidfd, PKSM_DISABLE); +#else /* CONFIG_KSM */ + return -ENOSYS; +#endif /* CONFIG_KSM */ +} + +SYSCALL_DEFINE2(process_ksm_status, int, pidfd, unsigned int, flags) +{ +#ifdef CONFIG_KSM + if (flags != 0) + return -EINVAL; + + return do_process_ksm_control(pidfd, PKSM_STATUS); +#else /* CONFIG_KSM */ + return -ENOSYS; +#endif /* CONFIG_KSM */ +} + +#ifdef CONFIG_KSM +static ssize_t process_ksm_enable_show(struct kobject *kobj, + struct kobj_attribute *attr, char *buf) +{ + return sprintf(buf, "%u\n", __NR_process_ksm_enable); +} +static struct kobj_attribute process_ksm_enable_attr = __ATTR_RO(process_ksm_enable); + +static ssize_t process_ksm_disable_show(struct kobject *kobj, + struct kobj_attribute *attr, char *buf) +{ + return sprintf(buf, "%u\n", __NR_process_ksm_disable); +} +static struct kobj_attribute process_ksm_disable_attr = __ATTR_RO(process_ksm_disable); + +static ssize_t process_ksm_status_show(struct kobject *kobj, + struct kobj_attribute *attr, char *buf) +{ + return sprintf(buf, "%u\n", __NR_process_ksm_status); +} +static struct kobj_attribute process_ksm_status_attr = __ATTR_RO(process_ksm_status); + +static struct attribute *process_ksm_sysfs_attrs[] = { + &process_ksm_enable_attr.attr, + &process_ksm_disable_attr.attr, + &process_ksm_status_attr.attr, + NULL, +}; + +static const struct attribute_group process_ksm_sysfs_attr_group = { + .attrs = process_ksm_sysfs_attrs, + .name = "process_ksm", +}; + +static int __init process_ksm_sysfs_init(void) +{ + return sysfs_create_group(kernel_kobj, &process_ksm_sysfs_attr_group); +} +subsys_initcall(process_ksm_sysfs_init); +#endif /* CONFIG_KSM */ + SYSCALL_DEFINE3(getcpu, unsigned __user *, cpup, unsigned __user *, nodep, struct getcpu_cache __user *, unused) { diff --git a/kernel/sys_ni.c b/kernel/sys_ni.c index c00a86931f8c..d82213d68522 100644 --- a/kernel/sys_ni.c +++ b/kernel/sys_ni.c @@ -186,6 +186,9 @@ COND_SYSCALL(mincore); COND_SYSCALL(madvise); COND_SYSCALL(process_madvise); COND_SYSCALL(process_mrelease); +COND_SYSCALL(process_ksm_enable); +COND_SYSCALL(process_ksm_disable); +COND_SYSCALL(process_ksm_status); COND_SYSCALL(remap_file_pages); COND_SYSCALL(mbind); COND_SYSCALL(get_mempolicy); diff --git a/scripts/syscall.tbl b/scripts/syscall.tbl index 845e24eb372e..227d9cc12365 100644 --- a/scripts/syscall.tbl +++ b/scripts/syscall.tbl @@ -403,3 +403,6 @@ 460 common lsm_set_self_attr sys_lsm_set_self_attr 461 common lsm_list_modules sys_lsm_list_modules 462 common mseal sys_mseal +463 common process_ksm_enable sys_process_ksm_enable +464 common process_ksm_disable sys_process_ksm_disable +465 common process_ksm_status sys_process_ksm_status diff --git a/tools/perf/arch/powerpc/entry/syscalls/syscall.tbl b/tools/perf/arch/powerpc/entry/syscalls/syscall.tbl index ebae8415dfbb..16f71bc2f6f0 100644 --- a/tools/perf/arch/powerpc/entry/syscalls/syscall.tbl +++ b/tools/perf/arch/powerpc/entry/syscalls/syscall.tbl @@ -553,3 +553,6 @@ 460 common lsm_set_self_attr sys_lsm_set_self_attr 461 common lsm_list_modules sys_lsm_list_modules 462 common mseal sys_mseal +463 common process_ksm_enable sys_process_ksm_enable +464 common process_ksm_disable sys_process_ksm_disable +465 common process_ksm_status sys_process_ksm_status diff --git a/tools/perf/arch/s390/entry/syscalls/syscall.tbl b/tools/perf/arch/s390/entry/syscalls/syscall.tbl index 01071182763e..7394bad8178e 100644 --- a/tools/perf/arch/s390/entry/syscalls/syscall.tbl +++ b/tools/perf/arch/s390/entry/syscalls/syscall.tbl @@ -465,3 +465,6 @@ 460 common lsm_set_self_attr sys_lsm_set_self_attr sys_lsm_set_self_attr 461 common lsm_list_modules sys_lsm_list_modules sys_lsm_list_modules 462 common mseal sys_mseal sys_mseal +463 common process_ksm_enable sys_process_ksm_enable sys_process_ksm_enable +464 common process_ksm_disable sys_process_ksm_disable sys_process_ksm_disable +465 common process_ksm_status sys_process_ksm_status sys_process_ksm_status -- 2.47.0 From fe2e45be58ad904ff3ab40356c10d73e057c18fd Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 11 Nov 2024 09:20:44 +0100 Subject: [PATCH 09/13] ntsync Signed-off-by: Peter Jung --- Documentation/userspace-api/index.rst | 1 + Documentation/userspace-api/ntsync.rst | 398 +++++ MAINTAINERS | 9 + drivers/misc/Kconfig | 1 - drivers/misc/ntsync.c | 989 +++++++++++- include/uapi/linux/ntsync.h | 39 + tools/testing/selftests/Makefile | 1 + .../selftests/drivers/ntsync/.gitignore | 1 + .../testing/selftests/drivers/ntsync/Makefile | 7 + tools/testing/selftests/drivers/ntsync/config | 1 + .../testing/selftests/drivers/ntsync/ntsync.c | 1407 +++++++++++++++++ 11 files changed, 2850 insertions(+), 4 deletions(-) create mode 100644 Documentation/userspace-api/ntsync.rst create mode 100644 tools/testing/selftests/drivers/ntsync/.gitignore create mode 100644 tools/testing/selftests/drivers/ntsync/Makefile create mode 100644 tools/testing/selftests/drivers/ntsync/config create mode 100644 tools/testing/selftests/drivers/ntsync/ntsync.c diff --git a/Documentation/userspace-api/index.rst b/Documentation/userspace-api/index.rst index 274cc7546efc..9c1b15cd89ab 100644 --- a/Documentation/userspace-api/index.rst +++ b/Documentation/userspace-api/index.rst @@ -63,6 +63,7 @@ Everything else vduse futex2 perf_ring_buffer + ntsync .. only:: subproject and html diff --git a/Documentation/userspace-api/ntsync.rst b/Documentation/userspace-api/ntsync.rst new file mode 100644 index 000000000000..767844637a7d --- /dev/null +++ b/Documentation/userspace-api/ntsync.rst @@ -0,0 +1,398 @@ +=================================== +NT synchronization primitive driver +=================================== + +This page documents the user-space API for the ntsync driver. + +ntsync is a support driver for emulation of NT synchronization +primitives by user-space NT emulators. It exists because implementation +in user-space, using existing tools, cannot match Windows performance +while offering accurate semantics. It is implemented entirely in +software, and does not drive any hardware device. + +This interface is meant as a compatibility tool only, and should not +be used for general synchronization. Instead use generic, versatile +interfaces such as futex(2) and poll(2). + +Synchronization primitives +========================== + +The ntsync driver exposes three types of synchronization primitives: +semaphores, mutexes, and events. + +A semaphore holds a single volatile 32-bit counter, and a static 32-bit +integer denoting the maximum value. It is considered signaled (that is, +can be acquired without contention, or will wake up a waiting thread) +when the counter is nonzero. The counter is decremented by one when a +wait is satisfied. Both the initial and maximum count are established +when the semaphore is created. + +A mutex holds a volatile 32-bit recursion count, and a volatile 32-bit +identifier denoting its owner. A mutex is considered signaled when its +owner is zero (indicating that it is not owned). The recursion count is +incremented when a wait is satisfied, and ownership is set to the given +identifier. + +A mutex also holds an internal flag denoting whether its previous owner +has died; such a mutex is said to be abandoned. Owner death is not +tracked automatically based on thread death, but rather must be +communicated using ``NTSYNC_IOC_MUTEX_KILL``. An abandoned mutex is +inherently considered unowned. + +Except for the "unowned" semantics of zero, the actual value of the +owner identifier is not interpreted by the ntsync driver at all. The +intended use is to store a thread identifier; however, the ntsync +driver does not actually validate that a calling thread provides +consistent or unique identifiers. + +An event is similar to a semaphore with a maximum count of one. It holds +a volatile boolean state denoting whether it is signaled or not. There +are two types of events, auto-reset and manual-reset. An auto-reset +event is designaled when a wait is satisfied; a manual-reset event is +not. The event type is specified when the event is created. + +Unless specified otherwise, all operations on an object are atomic and +totally ordered with respect to other operations on the same object. + +Objects are represented by files. When all file descriptors to an +object are closed, that object is deleted. + +Char device +=========== + +The ntsync driver creates a single char device /dev/ntsync. Each file +description opened on the device represents a unique instance intended +to back an individual NT virtual machine. Objects created by one ntsync +instance may only be used with other objects created by the same +instance. + +ioctl reference +=============== + +All operations on the device are done through ioctls. There are four +structures used in ioctl calls:: + + struct ntsync_sem_args { + __u32 sem; + __u32 count; + __u32 max; + }; + + struct ntsync_mutex_args { + __u32 mutex; + __u32 owner; + __u32 count; + }; + + struct ntsync_event_args { + __u32 event; + __u32 signaled; + __u32 manual; + }; + + struct ntsync_wait_args { + __u64 timeout; + __u64 objs; + __u32 count; + __u32 owner; + __u32 index; + __u32 alert; + __u32 flags; + __u32 pad; + }; + +Depending on the ioctl, members of the structure may be used as input, +output, or not at all. All ioctls return 0 on success. + +The ioctls on the device file are as follows: + +.. c:macro:: NTSYNC_IOC_CREATE_SEM + + Create a semaphore object. Takes a pointer to struct + :c:type:`ntsync_sem_args`, which is used as follows: + + .. list-table:: + + * - ``sem`` + - On output, contains a file descriptor to the created semaphore. + * - ``count`` + - Initial count of the semaphore. + * - ``max`` + - Maximum count of the semaphore. + + Fails with ``EINVAL`` if ``count`` is greater than ``max``. + +.. c:macro:: NTSYNC_IOC_CREATE_MUTEX + + Create a mutex object. Takes a pointer to struct + :c:type:`ntsync_mutex_args`, which is used as follows: + + .. list-table:: + + * - ``mutex`` + - On output, contains a file descriptor to the created mutex. + * - ``count`` + - Initial recursion count of the mutex. + * - ``owner`` + - Initial owner of the mutex. + + If ``owner`` is nonzero and ``count`` is zero, or if ``owner`` is + zero and ``count`` is nonzero, the function fails with ``EINVAL``. + +.. c:macro:: NTSYNC_IOC_CREATE_EVENT + + Create an event object. Takes a pointer to struct + :c:type:`ntsync_event_args`, which is used as follows: + + .. list-table:: + + * - ``event`` + - On output, contains a file descriptor to the created event. + * - ``signaled`` + - If nonzero, the event is initially signaled, otherwise + nonsignaled. + * - ``manual`` + - If nonzero, the event is a manual-reset event, otherwise + auto-reset. + +The ioctls on the individual objects are as follows: + +.. c:macro:: NTSYNC_IOC_SEM_POST + + Post to a semaphore object. Takes a pointer to a 32-bit integer, + which on input holds the count to be added to the semaphore, and on + output contains its previous count. + + If adding to the semaphore's current count would raise the latter + past the semaphore's maximum count, the ioctl fails with + ``EOVERFLOW`` and the semaphore is not affected. If raising the + semaphore's count causes it to become signaled, eligible threads + waiting on this semaphore will be woken and the semaphore's count + decremented appropriately. + +.. c:macro:: NTSYNC_IOC_MUTEX_UNLOCK + + Release a mutex object. Takes a pointer to struct + :c:type:`ntsync_mutex_args`, which is used as follows: + + .. list-table:: + + * - ``mutex`` + - Ignored. + * - ``owner`` + - Specifies the owner trying to release this mutex. + * - ``count`` + - On output, contains the previous recursion count. + + If ``owner`` is zero, the ioctl fails with ``EINVAL``. If ``owner`` + is not the current owner of the mutex, the ioctl fails with + ``EPERM``. + + The mutex's count will be decremented by one. If decrementing the + mutex's count causes it to become zero, the mutex is marked as + unowned and signaled, and eligible threads waiting on it will be + woken as appropriate. + +.. c:macro:: NTSYNC_IOC_SET_EVENT + + Signal an event object. Takes a pointer to a 32-bit integer, which on + output contains the previous state of the event. + + Eligible threads will be woken, and auto-reset events will be + designaled appropriately. + +.. c:macro:: NTSYNC_IOC_RESET_EVENT + + Designal an event object. Takes a pointer to a 32-bit integer, which + on output contains the previous state of the event. + +.. c:macro:: NTSYNC_IOC_PULSE_EVENT + + Wake threads waiting on an event object while leaving it in an + unsignaled state. Takes a pointer to a 32-bit integer, which on + output contains the previous state of the event. + + A pulse operation can be thought of as a set followed by a reset, + performed as a single atomic operation. If two threads are waiting on + an auto-reset event which is pulsed, only one will be woken. If two + threads are waiting a manual-reset event which is pulsed, both will + be woken. However, in both cases, the event will be unsignaled + afterwards, and a simultaneous read operation will always report the + event as unsignaled. + +.. c:macro:: NTSYNC_IOC_READ_SEM + + Read the current state of a semaphore object. Takes a pointer to + struct :c:type:`ntsync_sem_args`, which is used as follows: + + .. list-table:: + + * - ``sem`` + - Ignored. + * - ``count`` + - On output, contains the current count of the semaphore. + * - ``max`` + - On output, contains the maximum count of the semaphore. + +.. c:macro:: NTSYNC_IOC_READ_MUTEX + + Read the current state of a mutex object. Takes a pointer to struct + :c:type:`ntsync_mutex_args`, which is used as follows: + + .. list-table:: + + * - ``mutex`` + - Ignored. + * - ``owner`` + - On output, contains the current owner of the mutex, or zero + if the mutex is not currently owned. + * - ``count`` + - On output, contains the current recursion count of the mutex. + + If the mutex is marked as abandoned, the function fails with + ``EOWNERDEAD``. In this case, ``count`` and ``owner`` are set to + zero. + +.. c:macro:: NTSYNC_IOC_READ_EVENT + + Read the current state of an event object. Takes a pointer to struct + :c:type:`ntsync_event_args`, which is used as follows: + + .. list-table:: + + * - ``event`` + - Ignored. + * - ``signaled`` + - On output, contains the current state of the event. + * - ``manual`` + - On output, contains 1 if the event is a manual-reset event, + and 0 otherwise. + +.. c:macro:: NTSYNC_IOC_KILL_OWNER + + Mark a mutex as unowned and abandoned if it is owned by the given + owner. Takes an input-only pointer to a 32-bit integer denoting the + owner. If the owner is zero, the ioctl fails with ``EINVAL``. If the + owner does not own the mutex, the function fails with ``EPERM``. + + Eligible threads waiting on the mutex will be woken as appropriate + (and such waits will fail with ``EOWNERDEAD``, as described below). + +.. c:macro:: NTSYNC_IOC_WAIT_ANY + + Poll on any of a list of objects, atomically acquiring at most one. + Takes a pointer to struct :c:type:`ntsync_wait_args`, which is + used as follows: + + .. list-table:: + + * - ``timeout`` + - Absolute timeout in nanoseconds. If ``NTSYNC_WAIT_REALTIME`` + is set, the timeout is measured against the REALTIME clock; + otherwise it is measured against the MONOTONIC clock. If the + timeout is equal to or earlier than the current time, the + function returns immediately without sleeping. If ``timeout`` + is U64_MAX, the function will sleep until an object is + signaled, and will not fail with ``ETIMEDOUT``. + * - ``objs`` + - Pointer to an array of ``count`` file descriptors + (specified as an integer so that the structure has the same + size regardless of architecture). If any object is + invalid, the function fails with ``EINVAL``. + * - ``count`` + - Number of objects specified in the ``objs`` array. + If greater than ``NTSYNC_MAX_WAIT_COUNT``, the function fails + with ``EINVAL``. + * - ``owner`` + - Mutex owner identifier. If any object in ``objs`` is a mutex, + the ioctl will attempt to acquire that mutex on behalf of + ``owner``. If ``owner`` is zero, the ioctl fails with + ``EINVAL``. + * - ``index`` + - On success, contains the index (into ``objs``) of the object + which was signaled. If ``alert`` was signaled instead, + this contains ``count``. + * - ``alert`` + - Optional event object file descriptor. If nonzero, this + specifies an "alert" event object which, if signaled, will + terminate the wait. If nonzero, the identifier must point to a + valid event. + * - ``flags`` + - Zero or more flags. Currently the only flag is + ``NTSYNC_WAIT_REALTIME``, which causes the timeout to be + measured against the REALTIME clock instead of MONOTONIC. + * - ``pad`` + - Unused, must be set to zero. + + This function attempts to acquire one of the given objects. If unable + to do so, it sleeps until an object becomes signaled, subsequently + acquiring it, or the timeout expires. In the latter case the ioctl + fails with ``ETIMEDOUT``. The function only acquires one object, even + if multiple objects are signaled. + + A semaphore is considered to be signaled if its count is nonzero, and + is acquired by decrementing its count by one. A mutex is considered + to be signaled if it is unowned or if its owner matches the ``owner`` + argument, and is acquired by incrementing its recursion count by one + and setting its owner to the ``owner`` argument. An auto-reset event + is acquired by designaling it; a manual-reset event is not affected + by acquisition. + + Acquisition is atomic and totally ordered with respect to other + operations on the same object. If two wait operations (with different + ``owner`` identifiers) are queued on the same mutex, only one is + signaled. If two wait operations are queued on the same semaphore, + and a value of one is posted to it, only one is signaled. + + If an abandoned mutex is acquired, the ioctl fails with + ``EOWNERDEAD``. Although this is a failure return, the function may + otherwise be considered successful. The mutex is marked as owned by + the given owner (with a recursion count of 1) and as no longer + abandoned, and ``index`` is still set to the index of the mutex. + + The ``alert`` argument is an "extra" event which can terminate the + wait, independently of all other objects. + + It is valid to pass the same object more than once, including by + passing the same event in the ``objs`` array and in ``alert``. If a + wakeup occurs due to that object being signaled, ``index`` is set to + the lowest index corresponding to that object. + + The function may fail with ``EINTR`` if a signal is received. + +.. c:macro:: NTSYNC_IOC_WAIT_ALL + + Poll on a list of objects, atomically acquiring all of them. Takes a + pointer to struct :c:type:`ntsync_wait_args`, which is used + identically to ``NTSYNC_IOC_WAIT_ANY``, except that ``index`` is + always filled with zero on success if not woken via alert. + + This function attempts to simultaneously acquire all of the given + objects. If unable to do so, it sleeps until all objects become + simultaneously signaled, subsequently acquiring them, or the timeout + expires. In the latter case the ioctl fails with ``ETIMEDOUT`` and no + objects are modified. + + Objects may become signaled and subsequently designaled (through + acquisition by other threads) while this thread is sleeping. Only + once all objects are simultaneously signaled does the ioctl acquire + them and return. The entire acquisition is atomic and totally ordered + with respect to other operations on any of the given objects. + + If an abandoned mutex is acquired, the ioctl fails with + ``EOWNERDEAD``. Similarly to ``NTSYNC_IOC_WAIT_ANY``, all objects are + nevertheless marked as acquired. Note that if multiple mutex objects + are specified, there is no way to know which were marked as + abandoned. + + As with "any" waits, the ``alert`` argument is an "extra" event which + can terminate the wait. Critically, however, an "all" wait will + succeed if all members in ``objs`` are signaled, *or* if ``alert`` is + signaled. In the latter case ``index`` will be set to ``count``. As + with "any" waits, if both conditions are filled, the former takes + priority, and objects in ``objs`` will be acquired. + + Unlike ``NTSYNC_IOC_WAIT_ANY``, it is not valid to pass the same + object more than once, nor is it valid to pass the same object in + ``objs`` and in ``alert``. If this is attempted, the function fails + with ``EINVAL``. diff --git a/MAINTAINERS b/MAINTAINERS index 3d4709c29704..3ca514d82269 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -16501,6 +16501,15 @@ T: git https://github.com/Paragon-Software-Group/linux-ntfs3.git F: Documentation/filesystems/ntfs3.rst F: fs/ntfs3/ +NTSYNC SYNCHRONIZATION PRIMITIVE DRIVER +M: Elizabeth Figura +L: wine-devel@winehq.org +S: Supported +F: Documentation/userspace-api/ntsync.rst +F: drivers/misc/ntsync.c +F: include/uapi/linux/ntsync.h +F: tools/testing/selftests/drivers/ntsync/ + NUBUS SUBSYSTEM M: Finn Thain L: linux-m68k@lists.linux-m68k.org diff --git a/drivers/misc/Kconfig b/drivers/misc/Kconfig index 3fe7e2a9bd29..6c8b999a5e08 100644 --- a/drivers/misc/Kconfig +++ b/drivers/misc/Kconfig @@ -517,7 +517,6 @@ config OPEN_DICE config NTSYNC tristate "NT synchronization primitive emulation" - depends on BROKEN help This module provides kernel support for emulation of Windows NT synchronization primitives. It is not a hardware driver. diff --git a/drivers/misc/ntsync.c b/drivers/misc/ntsync.c index 4954553b7baa..3fac06270549 100644 --- a/drivers/misc/ntsync.c +++ b/drivers/misc/ntsync.c @@ -6,11 +6,17 @@ */ #include +#include #include #include +#include +#include #include #include +#include #include +#include +#include #include #include #include @@ -19,6 +25,8 @@ enum ntsync_type { NTSYNC_TYPE_SEM, + NTSYNC_TYPE_MUTEX, + NTSYNC_TYPE_EVENT, }; /* @@ -30,10 +38,13 @@ enum ntsync_type { * * Both rely on struct file for reference counting. Individual * ntsync_obj objects take a reference to the device when created. + * Wait operations take a reference to each object being waited on for + * the duration of the wait. */ struct ntsync_obj { spinlock_t lock; + int dev_locked; enum ntsync_type type; @@ -46,13 +57,335 @@ struct ntsync_obj { __u32 count; __u32 max; } sem; + struct { + __u32 count; + pid_t owner; + bool ownerdead; + } mutex; + struct { + bool manual; + bool signaled; + } event; } u; + + /* + * any_waiters is protected by the object lock, but all_waiters is + * protected by the device wait_all_lock. + */ + struct list_head any_waiters; + struct list_head all_waiters; + + /* + * Hint describing how many tasks are queued on this object in a + * wait-all operation. + * + * Any time we do a wake, we may need to wake "all" waiters as well as + * "any" waiters. In order to atomically wake "all" waiters, we must + * lock all of the objects, and that means grabbing the wait_all_lock + * below (and, due to lock ordering rules, before locking this object). + * However, wait-all is a rare operation, and grabbing the wait-all + * lock for every wake would create unnecessary contention. + * Therefore we first check whether all_hint is zero, and, if it is, + * we skip trying to wake "all" waiters. + * + * Since wait requests must originate from user-space threads, we're + * limited here by PID_MAX_LIMIT, so there's no risk of overflow. + */ + atomic_t all_hint; +}; + +struct ntsync_q_entry { + struct list_head node; + struct ntsync_q *q; + struct ntsync_obj *obj; + __u32 index; +}; + +struct ntsync_q { + struct task_struct *task; + __u32 owner; + + /* + * Protected via atomic_try_cmpxchg(). Only the thread that wins the + * compare-and-swap may actually change object states and wake this + * task. + */ + atomic_t signaled; + + bool all; + bool ownerdead; + __u32 count; + struct ntsync_q_entry entries[]; }; struct ntsync_device { + /* + * Wait-all operations must atomically grab all objects, and be totally + * ordered with respect to each other and wait-any operations. + * If one thread is trying to acquire several objects, another thread + * cannot touch the object at the same time. + * + * This device-wide lock is used to serialize wait-for-all + * operations, and operations on an object that is involved in a + * wait-for-all. + */ + struct mutex wait_all_lock; + struct file *file; }; +/* + * Single objects are locked using obj->lock. + * + * Multiple objects are 'locked' while holding dev->wait_all_lock. + * In this case however, individual objects are not locked by holding + * obj->lock, but by setting obj->dev_locked. + * + * This means that in order to lock a single object, the sequence is slightly + * more complicated than usual. Specifically it needs to check obj->dev_locked + * after acquiring obj->lock, if set, it needs to drop the lock and acquire + * dev->wait_all_lock in order to serialize against the multi-object operation. + */ + +static void dev_lock_obj(struct ntsync_device *dev, struct ntsync_obj *obj) +{ + lockdep_assert_held(&dev->wait_all_lock); + lockdep_assert(obj->dev == dev); + spin_lock(&obj->lock); + /* + * By setting obj->dev_locked inside obj->lock, it is ensured that + * anyone holding obj->lock must see the value. + */ + obj->dev_locked = 1; + spin_unlock(&obj->lock); +} + +static void dev_unlock_obj(struct ntsync_device *dev, struct ntsync_obj *obj) +{ + lockdep_assert_held(&dev->wait_all_lock); + lockdep_assert(obj->dev == dev); + spin_lock(&obj->lock); + obj->dev_locked = 0; + spin_unlock(&obj->lock); +} + +static void obj_lock(struct ntsync_obj *obj) +{ + struct ntsync_device *dev = obj->dev; + + for (;;) { + spin_lock(&obj->lock); + if (likely(!obj->dev_locked)) + break; + + spin_unlock(&obj->lock); + mutex_lock(&dev->wait_all_lock); + spin_lock(&obj->lock); + /* + * obj->dev_locked should be set and released under the same + * wait_all_lock section, since we now own this lock, it should + * be clear. + */ + lockdep_assert(!obj->dev_locked); + spin_unlock(&obj->lock); + mutex_unlock(&dev->wait_all_lock); + } +} + +static void obj_unlock(struct ntsync_obj *obj) +{ + spin_unlock(&obj->lock); +} + +static bool ntsync_lock_obj(struct ntsync_device *dev, struct ntsync_obj *obj) +{ + bool all; + + obj_lock(obj); + all = atomic_read(&obj->all_hint); + if (unlikely(all)) { + obj_unlock(obj); + mutex_lock(&dev->wait_all_lock); + dev_lock_obj(dev, obj); + } + + return all; +} + +static void ntsync_unlock_obj(struct ntsync_device *dev, struct ntsync_obj *obj, bool all) +{ + if (all) { + dev_unlock_obj(dev, obj); + mutex_unlock(&dev->wait_all_lock); + } else { + obj_unlock(obj); + } +} + +#define ntsync_assert_held(obj) \ + lockdep_assert((lockdep_is_held(&(obj)->lock) != LOCK_STATE_NOT_HELD) || \ + ((lockdep_is_held(&(obj)->dev->wait_all_lock) != LOCK_STATE_NOT_HELD) && \ + (obj)->dev_locked)) + +static bool is_signaled(struct ntsync_obj *obj, __u32 owner) +{ + ntsync_assert_held(obj); + + switch (obj->type) { + case NTSYNC_TYPE_SEM: + return !!obj->u.sem.count; + case NTSYNC_TYPE_MUTEX: + if (obj->u.mutex.owner && obj->u.mutex.owner != owner) + return false; + return obj->u.mutex.count < UINT_MAX; + case NTSYNC_TYPE_EVENT: + return obj->u.event.signaled; + } + + WARN(1, "bad object type %#x\n", obj->type); + return false; +} + +/* + * "locked_obj" is an optional pointer to an object which is already locked and + * should not be locked again. This is necessary so that changing an object's + * state and waking it can be a single atomic operation. + */ +static void try_wake_all(struct ntsync_device *dev, struct ntsync_q *q, + struct ntsync_obj *locked_obj) +{ + __u32 count = q->count; + bool can_wake = true; + int signaled = -1; + __u32 i; + + lockdep_assert_held(&dev->wait_all_lock); + if (locked_obj) + lockdep_assert(locked_obj->dev_locked); + + for (i = 0; i < count; i++) { + if (q->entries[i].obj != locked_obj) + dev_lock_obj(dev, q->entries[i].obj); + } + + for (i = 0; i < count; i++) { + if (!is_signaled(q->entries[i].obj, q->owner)) { + can_wake = false; + break; + } + } + + if (can_wake && atomic_try_cmpxchg(&q->signaled, &signaled, 0)) { + for (i = 0; i < count; i++) { + struct ntsync_obj *obj = q->entries[i].obj; + + switch (obj->type) { + case NTSYNC_TYPE_SEM: + obj->u.sem.count--; + break; + case NTSYNC_TYPE_MUTEX: + if (obj->u.mutex.ownerdead) + q->ownerdead = true; + obj->u.mutex.ownerdead = false; + obj->u.mutex.count++; + obj->u.mutex.owner = q->owner; + break; + case NTSYNC_TYPE_EVENT: + if (!obj->u.event.manual) + obj->u.event.signaled = false; + break; + } + } + wake_up_process(q->task); + } + + for (i = 0; i < count; i++) { + if (q->entries[i].obj != locked_obj) + dev_unlock_obj(dev, q->entries[i].obj); + } +} + +static void try_wake_all_obj(struct ntsync_device *dev, struct ntsync_obj *obj) +{ + struct ntsync_q_entry *entry; + + lockdep_assert_held(&dev->wait_all_lock); + lockdep_assert(obj->dev_locked); + + list_for_each_entry(entry, &obj->all_waiters, node) + try_wake_all(dev, entry->q, obj); +} + +static void try_wake_any_sem(struct ntsync_obj *sem) +{ + struct ntsync_q_entry *entry; + + ntsync_assert_held(sem); + lockdep_assert(sem->type == NTSYNC_TYPE_SEM); + + list_for_each_entry(entry, &sem->any_waiters, node) { + struct ntsync_q *q = entry->q; + int signaled = -1; + + if (!sem->u.sem.count) + break; + + if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) { + sem->u.sem.count--; + wake_up_process(q->task); + } + } +} + +static void try_wake_any_mutex(struct ntsync_obj *mutex) +{ + struct ntsync_q_entry *entry; + + ntsync_assert_held(mutex); + lockdep_assert(mutex->type == NTSYNC_TYPE_MUTEX); + + list_for_each_entry(entry, &mutex->any_waiters, node) { + struct ntsync_q *q = entry->q; + int signaled = -1; + + if (mutex->u.mutex.count == UINT_MAX) + break; + if (mutex->u.mutex.owner && mutex->u.mutex.owner != q->owner) + continue; + + if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) { + if (mutex->u.mutex.ownerdead) + q->ownerdead = true; + mutex->u.mutex.ownerdead = false; + mutex->u.mutex.count++; + mutex->u.mutex.owner = q->owner; + wake_up_process(q->task); + } + } +} + +static void try_wake_any_event(struct ntsync_obj *event) +{ + struct ntsync_q_entry *entry; + + ntsync_assert_held(event); + lockdep_assert(event->type == NTSYNC_TYPE_EVENT); + + list_for_each_entry(entry, &event->any_waiters, node) { + struct ntsync_q *q = entry->q; + int signaled = -1; + + if (!event->u.event.signaled) + break; + + if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) { + if (!event->u.event.manual) + event->u.event.signaled = false; + wake_up_process(q->task); + } + } +} + /* * Actually change the semaphore state, returning -EOVERFLOW if it is made * invalid. @@ -61,7 +394,7 @@ static int post_sem_state(struct ntsync_obj *sem, __u32 count) { __u32 sum; - lockdep_assert_held(&sem->lock); + ntsync_assert_held(sem); if (check_add_overflow(sem->u.sem.count, count, &sum) || sum > sem->u.sem.max) @@ -73,9 +406,11 @@ static int post_sem_state(struct ntsync_obj *sem, __u32 count) static int ntsync_sem_post(struct ntsync_obj *sem, void __user *argp) { + struct ntsync_device *dev = sem->dev; __u32 __user *user_args = argp; __u32 prev_count; __u32 args; + bool all; int ret; if (copy_from_user(&args, argp, sizeof(args))) @@ -84,12 +419,17 @@ static int ntsync_sem_post(struct ntsync_obj *sem, void __user *argp) if (sem->type != NTSYNC_TYPE_SEM) return -EINVAL; - spin_lock(&sem->lock); + all = ntsync_lock_obj(dev, sem); prev_count = sem->u.sem.count; ret = post_sem_state(sem, args); + if (!ret) { + if (all) + try_wake_all_obj(dev, sem); + try_wake_any_sem(sem); + } - spin_unlock(&sem->lock); + ntsync_unlock_obj(dev, sem, all); if (!ret && put_user(prev_count, user_args)) ret = -EFAULT; @@ -97,6 +437,226 @@ static int ntsync_sem_post(struct ntsync_obj *sem, void __user *argp) return ret; } +/* + * Actually change the mutex state, returning -EPERM if not the owner. + */ +static int unlock_mutex_state(struct ntsync_obj *mutex, + const struct ntsync_mutex_args *args) +{ + ntsync_assert_held(mutex); + + if (mutex->u.mutex.owner != args->owner) + return -EPERM; + + if (!--mutex->u.mutex.count) + mutex->u.mutex.owner = 0; + return 0; +} + +static int ntsync_mutex_unlock(struct ntsync_obj *mutex, void __user *argp) +{ + struct ntsync_mutex_args __user *user_args = argp; + struct ntsync_device *dev = mutex->dev; + struct ntsync_mutex_args args; + __u32 prev_count; + bool all; + int ret; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + if (!args.owner) + return -EINVAL; + + if (mutex->type != NTSYNC_TYPE_MUTEX) + return -EINVAL; + + all = ntsync_lock_obj(dev, mutex); + + prev_count = mutex->u.mutex.count; + ret = unlock_mutex_state(mutex, &args); + if (!ret) { + if (all) + try_wake_all_obj(dev, mutex); + try_wake_any_mutex(mutex); + } + + ntsync_unlock_obj(dev, mutex, all); + + if (!ret && put_user(prev_count, &user_args->count)) + ret = -EFAULT; + + return ret; +} + +/* + * Actually change the mutex state to mark its owner as dead, + * returning -EPERM if not the owner. + */ +static int kill_mutex_state(struct ntsync_obj *mutex, __u32 owner) +{ + ntsync_assert_held(mutex); + + if (mutex->u.mutex.owner != owner) + return -EPERM; + + mutex->u.mutex.ownerdead = true; + mutex->u.mutex.owner = 0; + mutex->u.mutex.count = 0; + return 0; +} + +static int ntsync_mutex_kill(struct ntsync_obj *mutex, void __user *argp) +{ + struct ntsync_device *dev = mutex->dev; + __u32 owner; + bool all; + int ret; + + if (get_user(owner, (__u32 __user *)argp)) + return -EFAULT; + if (!owner) + return -EINVAL; + + if (mutex->type != NTSYNC_TYPE_MUTEX) + return -EINVAL; + + all = ntsync_lock_obj(dev, mutex); + + ret = kill_mutex_state(mutex, owner); + if (!ret) { + if (all) + try_wake_all_obj(dev, mutex); + try_wake_any_mutex(mutex); + } + + ntsync_unlock_obj(dev, mutex, all); + + return ret; +} + +static int ntsync_event_set(struct ntsync_obj *event, void __user *argp, bool pulse) +{ + struct ntsync_device *dev = event->dev; + __u32 prev_state; + bool all; + + if (event->type != NTSYNC_TYPE_EVENT) + return -EINVAL; + + all = ntsync_lock_obj(dev, event); + + prev_state = event->u.event.signaled; + event->u.event.signaled = true; + if (all) + try_wake_all_obj(dev, event); + try_wake_any_event(event); + if (pulse) + event->u.event.signaled = false; + + ntsync_unlock_obj(dev, event, all); + + if (put_user(prev_state, (__u32 __user *)argp)) + return -EFAULT; + + return 0; +} + +static int ntsync_event_reset(struct ntsync_obj *event, void __user *argp) +{ + struct ntsync_device *dev = event->dev; + __u32 prev_state; + bool all; + + if (event->type != NTSYNC_TYPE_EVENT) + return -EINVAL; + + all = ntsync_lock_obj(dev, event); + + prev_state = event->u.event.signaled; + event->u.event.signaled = false; + + ntsync_unlock_obj(dev, event, all); + + if (put_user(prev_state, (__u32 __user *)argp)) + return -EFAULT; + + return 0; +} + +static int ntsync_sem_read(struct ntsync_obj *sem, void __user *argp) +{ + struct ntsync_sem_args __user *user_args = argp; + struct ntsync_device *dev = sem->dev; + struct ntsync_sem_args args; + bool all; + + if (sem->type != NTSYNC_TYPE_SEM) + return -EINVAL; + + args.sem = 0; + + all = ntsync_lock_obj(dev, sem); + + args.count = sem->u.sem.count; + args.max = sem->u.sem.max; + + ntsync_unlock_obj(dev, sem, all); + + if (copy_to_user(user_args, &args, sizeof(args))) + return -EFAULT; + return 0; +} + +static int ntsync_mutex_read(struct ntsync_obj *mutex, void __user *argp) +{ + struct ntsync_mutex_args __user *user_args = argp; + struct ntsync_device *dev = mutex->dev; + struct ntsync_mutex_args args; + bool all; + int ret; + + if (mutex->type != NTSYNC_TYPE_MUTEX) + return -EINVAL; + + args.mutex = 0; + + all = ntsync_lock_obj(dev, mutex); + + args.count = mutex->u.mutex.count; + args.owner = mutex->u.mutex.owner; + ret = mutex->u.mutex.ownerdead ? -EOWNERDEAD : 0; + + ntsync_unlock_obj(dev, mutex, all); + + if (copy_to_user(user_args, &args, sizeof(args))) + return -EFAULT; + return ret; +} + +static int ntsync_event_read(struct ntsync_obj *event, void __user *argp) +{ + struct ntsync_event_args __user *user_args = argp; + struct ntsync_device *dev = event->dev; + struct ntsync_event_args args; + bool all; + + if (event->type != NTSYNC_TYPE_EVENT) + return -EINVAL; + + args.event = 0; + + all = ntsync_lock_obj(dev, event); + + args.manual = event->u.event.manual; + args.signaled = event->u.event.signaled; + + ntsync_unlock_obj(dev, event, all); + + if (copy_to_user(user_args, &args, sizeof(args))) + return -EFAULT; + return 0; +} + static int ntsync_obj_release(struct inode *inode, struct file *file) { struct ntsync_obj *obj = file->private_data; @@ -116,6 +676,22 @@ static long ntsync_obj_ioctl(struct file *file, unsigned int cmd, switch (cmd) { case NTSYNC_IOC_SEM_POST: return ntsync_sem_post(obj, argp); + case NTSYNC_IOC_SEM_READ: + return ntsync_sem_read(obj, argp); + case NTSYNC_IOC_MUTEX_UNLOCK: + return ntsync_mutex_unlock(obj, argp); + case NTSYNC_IOC_MUTEX_KILL: + return ntsync_mutex_kill(obj, argp); + case NTSYNC_IOC_MUTEX_READ: + return ntsync_mutex_read(obj, argp); + case NTSYNC_IOC_EVENT_SET: + return ntsync_event_set(obj, argp, false); + case NTSYNC_IOC_EVENT_RESET: + return ntsync_event_reset(obj, argp); + case NTSYNC_IOC_EVENT_PULSE: + return ntsync_event_set(obj, argp, true); + case NTSYNC_IOC_EVENT_READ: + return ntsync_event_read(obj, argp); default: return -ENOIOCTLCMD; } @@ -140,6 +716,9 @@ static struct ntsync_obj *ntsync_alloc_obj(struct ntsync_device *dev, obj->dev = dev; get_file(dev->file); spin_lock_init(&obj->lock); + INIT_LIST_HEAD(&obj->any_waiters); + INIT_LIST_HEAD(&obj->all_waiters); + atomic_set(&obj->all_hint, 0); return obj; } @@ -190,6 +769,400 @@ static int ntsync_create_sem(struct ntsync_device *dev, void __user *argp) return put_user(fd, &user_args->sem); } +static int ntsync_create_mutex(struct ntsync_device *dev, void __user *argp) +{ + struct ntsync_mutex_args __user *user_args = argp; + struct ntsync_mutex_args args; + struct ntsync_obj *mutex; + int fd; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + if (!args.owner != !args.count) + return -EINVAL; + + mutex = ntsync_alloc_obj(dev, NTSYNC_TYPE_MUTEX); + if (!mutex) + return -ENOMEM; + mutex->u.mutex.count = args.count; + mutex->u.mutex.owner = args.owner; + fd = ntsync_obj_get_fd(mutex); + if (fd < 0) { + kfree(mutex); + return fd; + } + + return put_user(fd, &user_args->mutex); +} + +static int ntsync_create_event(struct ntsync_device *dev, void __user *argp) +{ + struct ntsync_event_args __user *user_args = argp; + struct ntsync_event_args args; + struct ntsync_obj *event; + int fd; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + event = ntsync_alloc_obj(dev, NTSYNC_TYPE_EVENT); + if (!event) + return -ENOMEM; + event->u.event.manual = args.manual; + event->u.event.signaled = args.signaled; + fd = ntsync_obj_get_fd(event); + if (fd < 0) { + kfree(event); + return fd; + } + + return put_user(fd, &user_args->event); +} + +static struct ntsync_obj *get_obj(struct ntsync_device *dev, int fd) +{ + struct file *file = fget(fd); + struct ntsync_obj *obj; + + if (!file) + return NULL; + + if (file->f_op != &ntsync_obj_fops) { + fput(file); + return NULL; + } + + obj = file->private_data; + if (obj->dev != dev) { + fput(file); + return NULL; + } + + return obj; +} + +static void put_obj(struct ntsync_obj *obj) +{ + fput(obj->file); +} + +static int ntsync_schedule(const struct ntsync_q *q, const struct ntsync_wait_args *args) +{ + ktime_t timeout = ns_to_ktime(args->timeout); + clockid_t clock = CLOCK_MONOTONIC; + ktime_t *timeout_ptr; + int ret = 0; + + timeout_ptr = (args->timeout == U64_MAX ? NULL : &timeout); + + if (args->flags & NTSYNC_WAIT_REALTIME) + clock = CLOCK_REALTIME; + + do { + if (signal_pending(current)) { + ret = -ERESTARTSYS; + break; + } + + set_current_state(TASK_INTERRUPTIBLE); + if (atomic_read(&q->signaled) != -1) { + ret = 0; + break; + } + ret = schedule_hrtimeout_range_clock(timeout_ptr, 0, HRTIMER_MODE_ABS, clock); + } while (ret < 0); + __set_current_state(TASK_RUNNING); + + return ret; +} + +/* + * Allocate and initialize the ntsync_q structure, but do not queue us yet. + */ +static int setup_wait(struct ntsync_device *dev, + const struct ntsync_wait_args *args, bool all, + struct ntsync_q **ret_q) +{ + int fds[NTSYNC_MAX_WAIT_COUNT + 1]; + const __u32 count = args->count; + struct ntsync_q *q; + __u32 total_count; + __u32 i, j; + + if (args->pad || (args->flags & ~NTSYNC_WAIT_REALTIME)) + return -EINVAL; + + if (args->count > NTSYNC_MAX_WAIT_COUNT) + return -EINVAL; + + total_count = count; + if (args->alert) + total_count++; + + if (copy_from_user(fds, u64_to_user_ptr(args->objs), + array_size(count, sizeof(*fds)))) + return -EFAULT; + if (args->alert) + fds[count] = args->alert; + + q = kmalloc(struct_size(q, entries, total_count), GFP_KERNEL); + if (!q) + return -ENOMEM; + q->task = current; + q->owner = args->owner; + atomic_set(&q->signaled, -1); + q->all = all; + q->ownerdead = false; + q->count = count; + + for (i = 0; i < total_count; i++) { + struct ntsync_q_entry *entry = &q->entries[i]; + struct ntsync_obj *obj = get_obj(dev, fds[i]); + + if (!obj) + goto err; + + if (all) { + /* Check that the objects are all distinct. */ + for (j = 0; j < i; j++) { + if (obj == q->entries[j].obj) { + put_obj(obj); + goto err; + } + } + } + + entry->obj = obj; + entry->q = q; + entry->index = i; + } + + *ret_q = q; + return 0; + +err: + for (j = 0; j < i; j++) + put_obj(q->entries[j].obj); + kfree(q); + return -EINVAL; +} + +static void try_wake_any_obj(struct ntsync_obj *obj) +{ + switch (obj->type) { + case NTSYNC_TYPE_SEM: + try_wake_any_sem(obj); + break; + case NTSYNC_TYPE_MUTEX: + try_wake_any_mutex(obj); + break; + case NTSYNC_TYPE_EVENT: + try_wake_any_event(obj); + break; + } +} + +static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp) +{ + struct ntsync_wait_args args; + __u32 i, total_count; + struct ntsync_q *q; + int signaled; + bool all; + int ret; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + ret = setup_wait(dev, &args, false, &q); + if (ret < 0) + return ret; + + total_count = args.count; + if (args.alert) + total_count++; + + /* queue ourselves */ + + for (i = 0; i < total_count; i++) { + struct ntsync_q_entry *entry = &q->entries[i]; + struct ntsync_obj *obj = entry->obj; + + all = ntsync_lock_obj(dev, obj); + list_add_tail(&entry->node, &obj->any_waiters); + ntsync_unlock_obj(dev, obj, all); + } + + /* + * Check if we are already signaled. + * + * Note that the API requires that normal objects are checked before + * the alert event. Hence we queue the alert event last, and check + * objects in order. + */ + + for (i = 0; i < total_count; i++) { + struct ntsync_obj *obj = q->entries[i].obj; + + if (atomic_read(&q->signaled) != -1) + break; + + all = ntsync_lock_obj(dev, obj); + try_wake_any_obj(obj); + ntsync_unlock_obj(dev, obj, all); + } + + /* sleep */ + + ret = ntsync_schedule(q, &args); + + /* and finally, unqueue */ + + for (i = 0; i < total_count; i++) { + struct ntsync_q_entry *entry = &q->entries[i]; + struct ntsync_obj *obj = entry->obj; + + all = ntsync_lock_obj(dev, obj); + list_del(&entry->node); + ntsync_unlock_obj(dev, obj, all); + + put_obj(obj); + } + + signaled = atomic_read(&q->signaled); + if (signaled != -1) { + struct ntsync_wait_args __user *user_args = argp; + + /* even if we caught a signal, we need to communicate success */ + ret = q->ownerdead ? -EOWNERDEAD : 0; + + if (put_user(signaled, &user_args->index)) + ret = -EFAULT; + } else if (!ret) { + ret = -ETIMEDOUT; + } + + kfree(q); + return ret; +} + +static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp) +{ + struct ntsync_wait_args args; + struct ntsync_q *q; + int signaled; + __u32 i; + int ret; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + ret = setup_wait(dev, &args, true, &q); + if (ret < 0) + return ret; + + /* queue ourselves */ + + mutex_lock(&dev->wait_all_lock); + + for (i = 0; i < args.count; i++) { + struct ntsync_q_entry *entry = &q->entries[i]; + struct ntsync_obj *obj = entry->obj; + + atomic_inc(&obj->all_hint); + + /* + * obj->all_waiters is protected by dev->wait_all_lock rather + * than obj->lock, so there is no need to acquire obj->lock + * here. + */ + list_add_tail(&entry->node, &obj->all_waiters); + } + if (args.alert) { + struct ntsync_q_entry *entry = &q->entries[args.count]; + struct ntsync_obj *obj = entry->obj; + + dev_lock_obj(dev, obj); + list_add_tail(&entry->node, &obj->any_waiters); + dev_unlock_obj(dev, obj); + } + + /* check if we are already signaled */ + + try_wake_all(dev, q, NULL); + + mutex_unlock(&dev->wait_all_lock); + + /* + * Check if the alert event is signaled, making sure to do so only + * after checking if the other objects are signaled. + */ + + if (args.alert) { + struct ntsync_obj *obj = q->entries[args.count].obj; + + if (atomic_read(&q->signaled) == -1) { + bool all = ntsync_lock_obj(dev, obj); + try_wake_any_obj(obj); + ntsync_unlock_obj(dev, obj, all); + } + } + + /* sleep */ + + ret = ntsync_schedule(q, &args); + + /* and finally, unqueue */ + + mutex_lock(&dev->wait_all_lock); + + for (i = 0; i < args.count; i++) { + struct ntsync_q_entry *entry = &q->entries[i]; + struct ntsync_obj *obj = entry->obj; + + /* + * obj->all_waiters is protected by dev->wait_all_lock rather + * than obj->lock, so there is no need to acquire it here. + */ + list_del(&entry->node); + + atomic_dec(&obj->all_hint); + + put_obj(obj); + } + + mutex_unlock(&dev->wait_all_lock); + + if (args.alert) { + struct ntsync_q_entry *entry = &q->entries[args.count]; + struct ntsync_obj *obj = entry->obj; + bool all; + + all = ntsync_lock_obj(dev, obj); + list_del(&entry->node); + ntsync_unlock_obj(dev, obj, all); + + put_obj(obj); + } + + signaled = atomic_read(&q->signaled); + if (signaled != -1) { + struct ntsync_wait_args __user *user_args = argp; + + /* even if we caught a signal, we need to communicate success */ + ret = q->ownerdead ? -EOWNERDEAD : 0; + + if (put_user(signaled, &user_args->index)) + ret = -EFAULT; + } else if (!ret) { + ret = -ETIMEDOUT; + } + + kfree(q); + return ret; +} + static int ntsync_char_open(struct inode *inode, struct file *file) { struct ntsync_device *dev; @@ -198,6 +1171,8 @@ static int ntsync_char_open(struct inode *inode, struct file *file) if (!dev) return -ENOMEM; + mutex_init(&dev->wait_all_lock); + file->private_data = dev; dev->file = file; return nonseekable_open(inode, file); @@ -219,8 +1194,16 @@ static long ntsync_char_ioctl(struct file *file, unsigned int cmd, void __user *argp = (void __user *)parm; switch (cmd) { + case NTSYNC_IOC_CREATE_EVENT: + return ntsync_create_event(dev, argp); + case NTSYNC_IOC_CREATE_MUTEX: + return ntsync_create_mutex(dev, argp); case NTSYNC_IOC_CREATE_SEM: return ntsync_create_sem(dev, argp); + case NTSYNC_IOC_WAIT_ALL: + return ntsync_wait_all(dev, argp); + case NTSYNC_IOC_WAIT_ANY: + return ntsync_wait_any(dev, argp); default: return -ENOIOCTLCMD; } diff --git a/include/uapi/linux/ntsync.h b/include/uapi/linux/ntsync.h index dcfa38fdc93c..4a8095a3fc34 100644 --- a/include/uapi/linux/ntsync.h +++ b/include/uapi/linux/ntsync.h @@ -16,8 +16,47 @@ struct ntsync_sem_args { __u32 max; }; +struct ntsync_mutex_args { + __u32 mutex; + __u32 owner; + __u32 count; +}; + +struct ntsync_event_args { + __u32 event; + __u32 manual; + __u32 signaled; +}; + +#define NTSYNC_WAIT_REALTIME 0x1 + +struct ntsync_wait_args { + __u64 timeout; + __u64 objs; + __u32 count; + __u32 index; + __u32 flags; + __u32 owner; + __u32 alert; + __u32 pad; +}; + +#define NTSYNC_MAX_WAIT_COUNT 64 + #define NTSYNC_IOC_CREATE_SEM _IOWR('N', 0x80, struct ntsync_sem_args) +#define NTSYNC_IOC_WAIT_ANY _IOWR('N', 0x82, struct ntsync_wait_args) +#define NTSYNC_IOC_WAIT_ALL _IOWR('N', 0x83, struct ntsync_wait_args) +#define NTSYNC_IOC_CREATE_MUTEX _IOWR('N', 0x84, struct ntsync_sem_args) +#define NTSYNC_IOC_CREATE_EVENT _IOWR('N', 0x87, struct ntsync_event_args) #define NTSYNC_IOC_SEM_POST _IOWR('N', 0x81, __u32) +#define NTSYNC_IOC_MUTEX_UNLOCK _IOWR('N', 0x85, struct ntsync_mutex_args) +#define NTSYNC_IOC_MUTEX_KILL _IOW ('N', 0x86, __u32) +#define NTSYNC_IOC_EVENT_SET _IOR ('N', 0x88, __u32) +#define NTSYNC_IOC_EVENT_RESET _IOR ('N', 0x89, __u32) +#define NTSYNC_IOC_EVENT_PULSE _IOR ('N', 0x8a, __u32) +#define NTSYNC_IOC_SEM_READ _IOR ('N', 0x8b, struct ntsync_sem_args) +#define NTSYNC_IOC_MUTEX_READ _IOR ('N', 0x8c, struct ntsync_mutex_args) +#define NTSYNC_IOC_EVENT_READ _IOR ('N', 0x8d, struct ntsync_event_args) #endif diff --git a/tools/testing/selftests/Makefile b/tools/testing/selftests/Makefile index 363d031a16f7..ff18c0361e38 100644 --- a/tools/testing/selftests/Makefile +++ b/tools/testing/selftests/Makefile @@ -18,6 +18,7 @@ TARGETS += devices/error_logs TARGETS += devices/probe TARGETS += dmabuf-heaps TARGETS += drivers/dma-buf +TARGETS += drivers/ntsync TARGETS += drivers/s390x/uvdevice TARGETS += drivers/net TARGETS += drivers/net/bonding diff --git a/tools/testing/selftests/drivers/ntsync/.gitignore b/tools/testing/selftests/drivers/ntsync/.gitignore new file mode 100644 index 000000000000..848573a3d3ea --- /dev/null +++ b/tools/testing/selftests/drivers/ntsync/.gitignore @@ -0,0 +1 @@ +ntsync diff --git a/tools/testing/selftests/drivers/ntsync/Makefile b/tools/testing/selftests/drivers/ntsync/Makefile new file mode 100644 index 000000000000..dbf2b055c0b2 --- /dev/null +++ b/tools/testing/selftests/drivers/ntsync/Makefile @@ -0,0 +1,7 @@ +# SPDX-LICENSE-IDENTIFIER: GPL-2.0-only +TEST_GEN_PROGS := ntsync + +CFLAGS += $(KHDR_INCLUDES) +LDLIBS += -lpthread + +include ../../lib.mk diff --git a/tools/testing/selftests/drivers/ntsync/config b/tools/testing/selftests/drivers/ntsync/config new file mode 100644 index 000000000000..60539c826d06 --- /dev/null +++ b/tools/testing/selftests/drivers/ntsync/config @@ -0,0 +1 @@ +CONFIG_WINESYNC=y diff --git a/tools/testing/selftests/drivers/ntsync/ntsync.c b/tools/testing/selftests/drivers/ntsync/ntsync.c new file mode 100644 index 000000000000..5fa2c9a0768c --- /dev/null +++ b/tools/testing/selftests/drivers/ntsync/ntsync.c @@ -0,0 +1,1407 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Various unit tests for the "ntsync" synchronization primitive driver. + * + * Copyright (C) 2021-2022 Elizabeth Figura + */ + +#define _GNU_SOURCE +#include +#include +#include +#include +#include +#include +#include "../../kselftest_harness.h" + +static int read_sem_state(int sem, __u32 *count, __u32 *max) +{ + struct ntsync_sem_args args; + int ret; + + memset(&args, 0xcc, sizeof(args)); + ret = ioctl(sem, NTSYNC_IOC_SEM_READ, &args); + *count = args.count; + *max = args.max; + return ret; +} + +#define check_sem_state(sem, count, max) \ + ({ \ + __u32 __count, __max; \ + int ret = read_sem_state((sem), &__count, &__max); \ + EXPECT_EQ(0, ret); \ + EXPECT_EQ((count), __count); \ + EXPECT_EQ((max), __max); \ + }) + +static int post_sem(int sem, __u32 *count) +{ + return ioctl(sem, NTSYNC_IOC_SEM_POST, count); +} + +static int read_mutex_state(int mutex, __u32 *count, __u32 *owner) +{ + struct ntsync_mutex_args args; + int ret; + + memset(&args, 0xcc, sizeof(args)); + ret = ioctl(mutex, NTSYNC_IOC_MUTEX_READ, &args); + *count = args.count; + *owner = args.owner; + return ret; +} + +#define check_mutex_state(mutex, count, owner) \ + ({ \ + __u32 __count, __owner; \ + int ret = read_mutex_state((mutex), &__count, &__owner); \ + EXPECT_EQ(0, ret); \ + EXPECT_EQ((count), __count); \ + EXPECT_EQ((owner), __owner); \ + }) + +static int unlock_mutex(int mutex, __u32 owner, __u32 *count) +{ + struct ntsync_mutex_args args; + int ret; + + args.owner = owner; + args.count = 0xdeadbeef; + ret = ioctl(mutex, NTSYNC_IOC_MUTEX_UNLOCK, &args); + *count = args.count; + return ret; +} + +static int read_event_state(int event, __u32 *signaled, __u32 *manual) +{ + struct ntsync_event_args args; + int ret; + + memset(&args, 0xcc, sizeof(args)); + ret = ioctl(event, NTSYNC_IOC_EVENT_READ, &args); + *signaled = args.signaled; + *manual = args.manual; + return ret; +} + +#define check_event_state(event, signaled, manual) \ + ({ \ + __u32 __signaled, __manual; \ + int ret = read_event_state((event), &__signaled, &__manual); \ + EXPECT_EQ(0, ret); \ + EXPECT_EQ((signaled), __signaled); \ + EXPECT_EQ((manual), __manual); \ + }) + +static int wait_objs(int fd, unsigned long request, __u32 count, + const int *objs, __u32 owner, int alert, __u32 *index) +{ + struct ntsync_wait_args args = {0}; + struct timespec timeout; + int ret; + + clock_gettime(CLOCK_MONOTONIC, &timeout); + + args.timeout = timeout.tv_sec * 1000000000 + timeout.tv_nsec; + args.count = count; + args.objs = (uintptr_t)objs; + args.owner = owner; + args.index = 0xdeadbeef; + args.alert = alert; + ret = ioctl(fd, request, &args); + *index = args.index; + return ret; +} + +static int wait_any(int fd, __u32 count, const int *objs, __u32 owner, __u32 *index) +{ + return wait_objs(fd, NTSYNC_IOC_WAIT_ANY, count, objs, owner, 0, index); +} + +static int wait_all(int fd, __u32 count, const int *objs, __u32 owner, __u32 *index) +{ + return wait_objs(fd, NTSYNC_IOC_WAIT_ALL, count, objs, owner, 0, index); +} + +static int wait_any_alert(int fd, __u32 count, const int *objs, + __u32 owner, int alert, __u32 *index) +{ + return wait_objs(fd, NTSYNC_IOC_WAIT_ANY, + count, objs, owner, alert, index); +} + +static int wait_all_alert(int fd, __u32 count, const int *objs, + __u32 owner, int alert, __u32 *index) +{ + return wait_objs(fd, NTSYNC_IOC_WAIT_ALL, + count, objs, owner, alert, index); +} + +TEST(semaphore_state) +{ + struct ntsync_sem_args sem_args; + struct timespec timeout; + __u32 count, index; + int fd, ret, sem; + + clock_gettime(CLOCK_MONOTONIC, &timeout); + + fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 3; + sem_args.max = 2; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + sem_args.count = 2; + sem_args.max = 2; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + sem = sem_args.sem; + check_sem_state(sem, 2, 2); + + count = 0; + ret = post_sem(sem, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, count); + check_sem_state(sem, 2, 2); + + count = 1; + ret = post_sem(sem, &count); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOVERFLOW, errno); + check_sem_state(sem, 2, 2); + + ret = wait_any(fd, 1, &sem, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(sem, 1, 2); + + ret = wait_any(fd, 1, &sem, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(sem, 0, 2); + + ret = wait_any(fd, 1, &sem, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + count = 3; + ret = post_sem(sem, &count); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOVERFLOW, errno); + check_sem_state(sem, 0, 2); + + count = 2; + ret = post_sem(sem, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, count); + check_sem_state(sem, 2, 2); + + ret = wait_any(fd, 1, &sem, 123, &index); + EXPECT_EQ(0, ret); + ret = wait_any(fd, 1, &sem, 123, &index); + EXPECT_EQ(0, ret); + + count = 1; + ret = post_sem(sem, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, count); + check_sem_state(sem, 1, 2); + + count = ~0u; + ret = post_sem(sem, &count); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOVERFLOW, errno); + check_sem_state(sem, 1, 2); + + close(sem); + + close(fd); +} + +TEST(mutex_state) +{ + struct ntsync_mutex_args mutex_args; + __u32 owner, count, index; + struct timespec timeout; + int fd, ret, mutex; + + clock_gettime(CLOCK_MONOTONIC, &timeout); + + fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + mutex_args.owner = 123; + mutex_args.count = 0; + ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + mutex_args.owner = 0; + mutex_args.count = 2; + ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + mutex_args.owner = 123; + mutex_args.count = 2; + mutex_args.mutex = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, mutex_args.mutex); + mutex = mutex_args.mutex; + check_mutex_state(mutex, 2, 123); + + ret = unlock_mutex(mutex, 0, &count); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = unlock_mutex(mutex, 456, &count); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EPERM, errno); + check_mutex_state(mutex, 2, 123); + + ret = unlock_mutex(mutex, 123, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, count); + check_mutex_state(mutex, 1, 123); + + ret = unlock_mutex(mutex, 123, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, count); + check_mutex_state(mutex, 0, 0); + + ret = unlock_mutex(mutex, 123, &count); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EPERM, errno); + + ret = wait_any(fd, 1, &mutex, 456, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_mutex_state(mutex, 1, 456); + + ret = wait_any(fd, 1, &mutex, 456, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_mutex_state(mutex, 2, 456); + + ret = unlock_mutex(mutex, 456, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, count); + check_mutex_state(mutex, 1, 456); + + ret = wait_any(fd, 1, &mutex, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + owner = 0; + ret = ioctl(mutex, NTSYNC_IOC_MUTEX_KILL, &owner); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + owner = 123; + ret = ioctl(mutex, NTSYNC_IOC_MUTEX_KILL, &owner); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EPERM, errno); + check_mutex_state(mutex, 1, 456); + + owner = 456; + ret = ioctl(mutex, NTSYNC_IOC_MUTEX_KILL, &owner); + EXPECT_EQ(0, ret); + + memset(&mutex_args, 0xcc, sizeof(mutex_args)); + ret = ioctl(mutex, NTSYNC_IOC_MUTEX_READ, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + EXPECT_EQ(0, mutex_args.count); + EXPECT_EQ(0, mutex_args.owner); + + memset(&mutex_args, 0xcc, sizeof(mutex_args)); + ret = ioctl(mutex, NTSYNC_IOC_MUTEX_READ, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + EXPECT_EQ(0, mutex_args.count); + EXPECT_EQ(0, mutex_args.owner); + + ret = wait_any(fd, 1, &mutex, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + EXPECT_EQ(0, index); + check_mutex_state(mutex, 1, 123); + + owner = 123; + ret = ioctl(mutex, NTSYNC_IOC_MUTEX_KILL, &owner); + EXPECT_EQ(0, ret); + + memset(&mutex_args, 0xcc, sizeof(mutex_args)); + ret = ioctl(mutex, NTSYNC_IOC_MUTEX_READ, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + EXPECT_EQ(0, mutex_args.count); + EXPECT_EQ(0, mutex_args.owner); + + ret = wait_any(fd, 1, &mutex, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + EXPECT_EQ(0, index); + check_mutex_state(mutex, 1, 123); + + close(mutex); + + mutex_args.owner = 0; + mutex_args.count = 0; + mutex_args.mutex = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, mutex_args.mutex); + mutex = mutex_args.mutex; + check_mutex_state(mutex, 0, 0); + + ret = wait_any(fd, 1, &mutex, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_mutex_state(mutex, 1, 123); + + close(mutex); + + mutex_args.owner = 123; + mutex_args.count = ~0u; + mutex_args.mutex = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, mutex_args.mutex); + mutex = mutex_args.mutex; + check_mutex_state(mutex, ~0u, 123); + + ret = wait_any(fd, 1, &mutex, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + close(mutex); + + close(fd); +} + +TEST(manual_event_state) +{ + struct ntsync_event_args event_args; + __u32 index, signaled; + int fd, event, ret; + + fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + event_args.manual = 1; + event_args.signaled = 0; + event_args.event = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, event_args.event); + event = event_args.event; + check_event_state(event, 0, 1); + + signaled = 0xdeadbeef; + ret = ioctl(event, NTSYNC_IOC_EVENT_SET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, signaled); + check_event_state(event, 1, 1); + + ret = ioctl(event, NTSYNC_IOC_EVENT_SET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, signaled); + check_event_state(event, 1, 1); + + ret = wait_any(fd, 1, &event, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_event_state(event, 1, 1); + + signaled = 0xdeadbeef; + ret = ioctl(event, NTSYNC_IOC_EVENT_RESET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, signaled); + check_event_state(event, 0, 1); + + ret = ioctl(event, NTSYNC_IOC_EVENT_RESET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, signaled); + check_event_state(event, 0, 1); + + ret = wait_any(fd, 1, &event, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + ret = ioctl(event, NTSYNC_IOC_EVENT_SET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, signaled); + + ret = ioctl(event, NTSYNC_IOC_EVENT_PULSE, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, signaled); + check_event_state(event, 0, 1); + + ret = ioctl(event, NTSYNC_IOC_EVENT_PULSE, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, signaled); + check_event_state(event, 0, 1); + + close(event); + + close(fd); +} + +TEST(auto_event_state) +{ + struct ntsync_event_args event_args; + __u32 index, signaled; + int fd, event, ret; + + fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + event_args.manual = 0; + event_args.signaled = 1; + event_args.event = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, event_args.event); + event = event_args.event; + + check_event_state(event, 1, 0); + + signaled = 0xdeadbeef; + ret = ioctl(event, NTSYNC_IOC_EVENT_SET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, signaled); + check_event_state(event, 1, 0); + + ret = wait_any(fd, 1, &event, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_event_state(event, 0, 0); + + signaled = 0xdeadbeef; + ret = ioctl(event, NTSYNC_IOC_EVENT_RESET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, signaled); + check_event_state(event, 0, 0); + + ret = wait_any(fd, 1, &event, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + ret = ioctl(event, NTSYNC_IOC_EVENT_SET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, signaled); + + ret = ioctl(event, NTSYNC_IOC_EVENT_PULSE, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, signaled); + check_event_state(event, 0, 0); + + ret = ioctl(event, NTSYNC_IOC_EVENT_PULSE, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, signaled); + check_event_state(event, 0, 0); + + close(event); + + close(fd); +} + +TEST(test_wait_any) +{ + int objs[NTSYNC_MAX_WAIT_COUNT + 1], fd, ret; + struct ntsync_mutex_args mutex_args = {0}; + struct ntsync_sem_args sem_args = {0}; + __u32 owner, index, count, i; + struct timespec timeout; + + clock_gettime(CLOCK_MONOTONIC, &timeout); + + fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 2; + sem_args.max = 3; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + + mutex_args.owner = 0; + mutex_args.count = 0; + mutex_args.mutex = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, mutex_args.mutex); + + objs[0] = sem_args.sem; + objs[1] = mutex_args.mutex; + + ret = wait_any(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(sem_args.sem, 1, 3); + check_mutex_state(mutex_args.mutex, 0, 0); + + ret = wait_any(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(sem_args.sem, 0, 3); + check_mutex_state(mutex_args.mutex, 0, 0); + + ret = wait_any(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, index); + check_sem_state(sem_args.sem, 0, 3); + check_mutex_state(mutex_args.mutex, 1, 123); + + count = 1; + ret = post_sem(sem_args.sem, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, count); + + ret = wait_any(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(sem_args.sem, 0, 3); + check_mutex_state(mutex_args.mutex, 1, 123); + + ret = wait_any(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, index); + check_sem_state(sem_args.sem, 0, 3); + check_mutex_state(mutex_args.mutex, 2, 123); + + ret = wait_any(fd, 2, objs, 456, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + owner = 123; + ret = ioctl(mutex_args.mutex, NTSYNC_IOC_MUTEX_KILL, &owner); + EXPECT_EQ(0, ret); + + ret = wait_any(fd, 2, objs, 456, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + EXPECT_EQ(1, index); + + ret = wait_any(fd, 2, objs, 456, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, index); + + /* test waiting on the same object twice */ + count = 2; + ret = post_sem(sem_args.sem, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, count); + + objs[0] = objs[1] = sem_args.sem; + ret = wait_any(fd, 2, objs, 456, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(sem_args.sem, 1, 3); + + ret = wait_any(fd, 0, NULL, 456, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + for (i = 0; i < NTSYNC_MAX_WAIT_COUNT + 1; ++i) + objs[i] = sem_args.sem; + + ret = wait_any(fd, NTSYNC_MAX_WAIT_COUNT, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + + ret = wait_any(fd, NTSYNC_MAX_WAIT_COUNT + 1, objs, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = wait_any(fd, -1, objs, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + close(sem_args.sem); + close(mutex_args.mutex); + + close(fd); +} + +TEST(test_wait_all) +{ + struct ntsync_event_args event_args = {0}; + struct ntsync_mutex_args mutex_args = {0}; + struct ntsync_sem_args sem_args = {0}; + __u32 owner, index, count; + int objs[2], fd, ret; + + fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 2; + sem_args.max = 3; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + + mutex_args.owner = 0; + mutex_args.count = 0; + mutex_args.mutex = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, mutex_args.mutex); + + event_args.manual = true; + event_args.signaled = true; + ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + objs[0] = sem_args.sem; + objs[1] = mutex_args.mutex; + + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(sem_args.sem, 1, 3); + check_mutex_state(mutex_args.mutex, 1, 123); + + ret = wait_all(fd, 2, objs, 456, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + check_sem_state(sem_args.sem, 1, 3); + check_mutex_state(mutex_args.mutex, 1, 123); + + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(sem_args.sem, 0, 3); + check_mutex_state(mutex_args.mutex, 2, 123); + + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + check_sem_state(sem_args.sem, 0, 3); + check_mutex_state(mutex_args.mutex, 2, 123); + + count = 3; + ret = post_sem(sem_args.sem, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, count); + + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(sem_args.sem, 2, 3); + check_mutex_state(mutex_args.mutex, 3, 123); + + owner = 123; + ret = ioctl(mutex_args.mutex, NTSYNC_IOC_MUTEX_KILL, &owner); + EXPECT_EQ(0, ret); + + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + check_sem_state(sem_args.sem, 1, 3); + check_mutex_state(mutex_args.mutex, 1, 123); + + objs[0] = sem_args.sem; + objs[1] = event_args.event; + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(sem_args.sem, 0, 3); + check_event_state(event_args.event, 1, 1); + + /* test waiting on the same object twice */ + objs[0] = objs[1] = sem_args.sem; + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + close(sem_args.sem); + close(mutex_args.mutex); + close(event_args.event); + + close(fd); +} + +struct wake_args { + int fd; + int obj; +}; + +struct wait_args { + int fd; + unsigned long request; + struct ntsync_wait_args *args; + int ret; + int err; +}; + +static void *wait_thread(void *arg) +{ + struct wait_args *args = arg; + + args->ret = ioctl(args->fd, args->request, args->args); + args->err = errno; + return NULL; +} + +static __u64 get_abs_timeout(unsigned int ms) +{ + struct timespec timeout; + clock_gettime(CLOCK_MONOTONIC, &timeout); + return (timeout.tv_sec * 1000000000) + timeout.tv_nsec + (ms * 1000000); +} + +static int wait_for_thread(pthread_t thread, unsigned int ms) +{ + struct timespec timeout; + + clock_gettime(CLOCK_REALTIME, &timeout); + timeout.tv_nsec += ms * 1000000; + timeout.tv_sec += (timeout.tv_nsec / 1000000000); + timeout.tv_nsec %= 1000000000; + return pthread_timedjoin_np(thread, NULL, &timeout); +} + +TEST(wake_any) +{ + struct ntsync_event_args event_args = {0}; + struct ntsync_mutex_args mutex_args = {0}; + struct ntsync_wait_args wait_args = {0}; + struct ntsync_sem_args sem_args = {0}; + struct wait_args thread_args; + __u32 count, index, signaled; + int objs[2], fd, ret; + pthread_t thread; + + fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 0; + sem_args.max = 3; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + + mutex_args.owner = 123; + mutex_args.count = 1; + mutex_args.mutex = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, mutex_args.mutex); + + objs[0] = sem_args.sem; + objs[1] = mutex_args.mutex; + + /* test waking the semaphore */ + + wait_args.timeout = get_abs_timeout(1000); + wait_args.objs = (uintptr_t)objs; + wait_args.count = 2; + wait_args.owner = 456; + wait_args.index = 0xdeadbeef; + thread_args.fd = fd; + thread_args.args = &wait_args; + thread_args.request = NTSYNC_IOC_WAIT_ANY; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + count = 1; + ret = post_sem(sem_args.sem, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, count); + check_sem_state(sem_args.sem, 0, 3); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(0, wait_args.index); + + /* test waking the mutex */ + + /* first grab it again for owner 123 */ + ret = wait_any(fd, 1, &mutex_args.mutex, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + + wait_args.timeout = get_abs_timeout(1000); + wait_args.owner = 456; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = unlock_mutex(mutex_args.mutex, 123, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, count); + + ret = pthread_tryjoin_np(thread, NULL); + EXPECT_EQ(EBUSY, ret); + + ret = unlock_mutex(mutex_args.mutex, 123, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, mutex_args.count); + check_mutex_state(mutex_args.mutex, 1, 456); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(1, wait_args.index); + + /* test waking events */ + + event_args.manual = false; + event_args.signaled = false; + ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + objs[1] = event_args.event; + wait_args.timeout = get_abs_timeout(1000); + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, signaled); + check_event_state(event_args.event, 0, 0); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(1, wait_args.index); + + wait_args.timeout = get_abs_timeout(1000); + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_PULSE, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, signaled); + check_event_state(event_args.event, 0, 0); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(1, wait_args.index); + + close(event_args.event); + + event_args.manual = true; + event_args.signaled = false; + ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + objs[1] = event_args.event; + wait_args.timeout = get_abs_timeout(1000); + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, signaled); + check_event_state(event_args.event, 1, 1); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(1, wait_args.index); + + ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_RESET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, signaled); + + wait_args.timeout = get_abs_timeout(1000); + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_PULSE, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, signaled); + check_event_state(event_args.event, 0, 1); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(1, wait_args.index); + + close(event_args.event); + + /* delete an object while it's being waited on */ + + wait_args.timeout = get_abs_timeout(200); + wait_args.owner = 123; + objs[1] = mutex_args.mutex; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + close(sem_args.sem); + close(mutex_args.mutex); + + ret = wait_for_thread(thread, 200); + EXPECT_EQ(0, ret); + EXPECT_EQ(-1, thread_args.ret); + EXPECT_EQ(ETIMEDOUT, thread_args.err); + + close(fd); +} + +TEST(wake_all) +{ + struct ntsync_event_args manual_event_args = {0}; + struct ntsync_event_args auto_event_args = {0}; + struct ntsync_mutex_args mutex_args = {0}; + struct ntsync_wait_args wait_args = {0}; + struct ntsync_sem_args sem_args = {0}; + struct wait_args thread_args; + __u32 count, index, signaled; + int objs[4], fd, ret; + pthread_t thread; + + fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 0; + sem_args.max = 3; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + + mutex_args.owner = 123; + mutex_args.count = 1; + mutex_args.mutex = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, mutex_args.mutex); + + manual_event_args.manual = true; + manual_event_args.signaled = true; + ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &manual_event_args); + EXPECT_EQ(0, ret); + + auto_event_args.manual = false; + auto_event_args.signaled = true; + ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &auto_event_args); + EXPECT_EQ(0, ret); + + objs[0] = sem_args.sem; + objs[1] = mutex_args.mutex; + objs[2] = manual_event_args.event; + objs[3] = auto_event_args.event; + + wait_args.timeout = get_abs_timeout(1000); + wait_args.objs = (uintptr_t)objs; + wait_args.count = 4; + wait_args.owner = 456; + thread_args.fd = fd; + thread_args.args = &wait_args; + thread_args.request = NTSYNC_IOC_WAIT_ALL; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + count = 1; + ret = post_sem(sem_args.sem, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, count); + + ret = pthread_tryjoin_np(thread, NULL); + EXPECT_EQ(EBUSY, ret); + + check_sem_state(sem_args.sem, 1, 3); + + ret = wait_any(fd, 1, &sem_args.sem, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + + ret = unlock_mutex(mutex_args.mutex, 123, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, count); + + ret = pthread_tryjoin_np(thread, NULL); + EXPECT_EQ(EBUSY, ret); + + check_mutex_state(mutex_args.mutex, 0, 0); + + ret = ioctl(manual_event_args.event, NTSYNC_IOC_EVENT_RESET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, signaled); + + count = 2; + ret = post_sem(sem_args.sem, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, count); + check_sem_state(sem_args.sem, 2, 3); + + ret = ioctl(auto_event_args.event, NTSYNC_IOC_EVENT_RESET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, signaled); + + ret = ioctl(manual_event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, signaled); + + ret = ioctl(auto_event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, signaled); + + check_sem_state(sem_args.sem, 1, 3); + check_mutex_state(mutex_args.mutex, 1, 456); + check_event_state(manual_event_args.event, 1, 1); + check_event_state(auto_event_args.event, 0, 0); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + + /* delete an object while it's being waited on */ + + wait_args.timeout = get_abs_timeout(200); + wait_args.owner = 123; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + close(sem_args.sem); + close(mutex_args.mutex); + close(manual_event_args.event); + close(auto_event_args.event); + + ret = wait_for_thread(thread, 200); + EXPECT_EQ(0, ret); + EXPECT_EQ(-1, thread_args.ret); + EXPECT_EQ(ETIMEDOUT, thread_args.err); + + close(fd); +} + +TEST(alert_any) +{ + struct ntsync_event_args event_args = {0}; + struct ntsync_wait_args wait_args = {0}; + struct ntsync_sem_args sem_args = {0}; + __u32 index, count, signaled; + struct wait_args thread_args; + int objs[2], fd, ret; + pthread_t thread; + + fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 0; + sem_args.max = 2; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + objs[0] = sem_args.sem; + + sem_args.count = 1; + sem_args.max = 2; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + objs[1] = sem_args.sem; + + event_args.manual = true; + event_args.signaled = true; + ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + ret = wait_any_alert(fd, 0, NULL, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + + ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_RESET, &signaled); + EXPECT_EQ(0, ret); + + ret = wait_any_alert(fd, 0, NULL, 123, event_args.event, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); + EXPECT_EQ(0, ret); + + ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, index); + + ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, index); + + /* test wakeup via alert */ + + ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_RESET, &signaled); + EXPECT_EQ(0, ret); + + wait_args.timeout = get_abs_timeout(1000); + wait_args.objs = (uintptr_t)objs; + wait_args.count = 2; + wait_args.owner = 123; + wait_args.index = 0xdeadbeef; + wait_args.alert = event_args.event; + thread_args.fd = fd; + thread_args.args = &wait_args; + thread_args.request = NTSYNC_IOC_WAIT_ANY; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(2, wait_args.index); + + close(event_args.event); + + /* test with an auto-reset event */ + + event_args.manual = false; + event_args.signaled = true; + ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + count = 1; + ret = post_sem(objs[0], &count); + EXPECT_EQ(0, ret); + + ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + + ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, index); + + ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + close(event_args.event); + + close(objs[0]); + close(objs[1]); + + close(fd); +} + +TEST(alert_all) +{ + struct ntsync_event_args event_args = {0}; + struct ntsync_wait_args wait_args = {0}; + struct ntsync_sem_args sem_args = {0}; + struct wait_args thread_args; + __u32 index, count, signaled; + int objs[2], fd, ret; + pthread_t thread; + + fd = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 2; + sem_args.max = 2; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + objs[0] = sem_args.sem; + + sem_args.count = 1; + sem_args.max = 2; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, NTSYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + objs[1] = sem_args.sem; + + event_args.manual = true; + event_args.signaled = true; + ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + + ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, index); + + /* test wakeup via alert */ + + ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_RESET, &signaled); + EXPECT_EQ(0, ret); + + wait_args.timeout = get_abs_timeout(1000); + wait_args.objs = (uintptr_t)objs; + wait_args.count = 2; + wait_args.owner = 123; + wait_args.index = 0xdeadbeef; + wait_args.alert = event_args.event; + thread_args.fd = fd; + thread_args.args = &wait_args; + thread_args.request = NTSYNC_IOC_WAIT_ALL; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(event_args.event, NTSYNC_IOC_EVENT_SET, &signaled); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(2, wait_args.index); + + close(event_args.event); + + /* test with an auto-reset event */ + + event_args.manual = false; + event_args.signaled = true; + ret = ioctl(fd, NTSYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + count = 2; + ret = post_sem(objs[1], &count); + EXPECT_EQ(0, ret); + + ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + + ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, index); + + ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + close(event_args.event); + + close(objs[0]); + close(objs[1]); + + close(fd); +} + +#define STRESS_LOOPS 10000 +#define STRESS_THREADS 4 + +static unsigned int stress_counter; +static int stress_device, stress_start_event, stress_mutex; + +static void *stress_thread(void *arg) +{ + struct ntsync_wait_args wait_args = {0}; + __u32 index, count, i; + int ret; + + wait_args.timeout = UINT64_MAX; + wait_args.count = 1; + wait_args.objs = (uintptr_t)&stress_start_event; + wait_args.owner = gettid(); + wait_args.index = 0xdeadbeef; + + ioctl(stress_device, NTSYNC_IOC_WAIT_ANY, &wait_args); + + wait_args.objs = (uintptr_t)&stress_mutex; + + for (i = 0; i < STRESS_LOOPS; ++i) { + ioctl(stress_device, NTSYNC_IOC_WAIT_ANY, &wait_args); + + ++stress_counter; + + unlock_mutex(stress_mutex, wait_args.owner, &count); + } + + return NULL; +} + +TEST(stress_wait) +{ + struct ntsync_event_args event_args; + struct ntsync_mutex_args mutex_args; + pthread_t threads[STRESS_THREADS]; + __u32 signaled, i; + int ret; + + stress_device = open("/dev/ntsync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, stress_device); + + mutex_args.owner = 0; + mutex_args.count = 0; + ret = ioctl(stress_device, NTSYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + stress_mutex = mutex_args.mutex; + + event_args.manual = 1; + event_args.signaled = 0; + ret = ioctl(stress_device, NTSYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + stress_start_event = event_args.event; + + for (i = 0; i < STRESS_THREADS; ++i) + pthread_create(&threads[i], NULL, stress_thread, NULL); + + ret = ioctl(stress_start_event, NTSYNC_IOC_EVENT_SET, &signaled); + EXPECT_EQ(0, ret); + + for (i = 0; i < STRESS_THREADS; ++i) { + ret = pthread_join(threads[i], NULL); + EXPECT_EQ(0, ret); + } + + EXPECT_EQ(STRESS_LOOPS * STRESS_THREADS, stress_counter); + + close(stress_start_event); + close(stress_mutex); + close(stress_device); +} + +TEST_HARNESS_MAIN -- 2.47.0 From b71dfd054ca2932d63b75136a3191c01e74a374c Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 11 Nov 2024 09:20:57 +0100 Subject: [PATCH 10/13] openvpn-dco Signed-off-by: Peter Jung --- Documentation/netlink/specs/ovpn.yaml | 362 +++ MAINTAINERS | 11 + drivers/net/Kconfig | 14 + drivers/net/Makefile | 1 + drivers/net/ovpn/Makefile | 22 + drivers/net/ovpn/bind.c | 54 + drivers/net/ovpn/bind.h | 117 + drivers/net/ovpn/crypto.c | 214 ++ drivers/net/ovpn/crypto.h | 145 + drivers/net/ovpn/crypto_aead.c | 386 +++ drivers/net/ovpn/crypto_aead.h | 33 + drivers/net/ovpn/io.c | 462 ++++ drivers/net/ovpn/io.h | 25 + drivers/net/ovpn/main.c | 337 +++ drivers/net/ovpn/main.h | 24 + drivers/net/ovpn/netlink-gen.c | 212 ++ drivers/net/ovpn/netlink-gen.h | 41 + drivers/net/ovpn/netlink.c | 1135 ++++++++ drivers/net/ovpn/netlink.h | 18 + drivers/net/ovpn/ovpnstruct.h | 61 + drivers/net/ovpn/packet.h | 40 + drivers/net/ovpn/peer.c | 1201 +++++++++ drivers/net/ovpn/peer.h | 165 ++ drivers/net/ovpn/pktid.c | 130 + drivers/net/ovpn/pktid.h | 87 + drivers/net/ovpn/proto.h | 104 + drivers/net/ovpn/skb.h | 56 + drivers/net/ovpn/socket.c | 178 ++ drivers/net/ovpn/socket.h | 55 + drivers/net/ovpn/stats.c | 21 + drivers/net/ovpn/stats.h | 47 + drivers/net/ovpn/tcp.c | 506 ++++ drivers/net/ovpn/tcp.h | 44 + drivers/net/ovpn/udp.c | 406 +++ drivers/net/ovpn/udp.h | 26 + include/net/netlink.h | 1 + include/uapi/linux/if_link.h | 15 + include/uapi/linux/ovpn.h | 109 + include/uapi/linux/udp.h | 1 + tools/net/ynl/ynl-gen-c.py | 2 + tools/testing/selftests/Makefile | 1 + tools/testing/selftests/net/ovpn/.gitignore | 2 + tools/testing/selftests/net/ovpn/Makefile | 17 + tools/testing/selftests/net/ovpn/config | 10 + tools/testing/selftests/net/ovpn/data64.key | 5 + tools/testing/selftests/net/ovpn/ovpn-cli.c | 2370 +++++++++++++++++ .../testing/selftests/net/ovpn/tcp_peers.txt | 5 + .../selftests/net/ovpn/test-chachapoly.sh | 9 + .../testing/selftests/net/ovpn/test-float.sh | 9 + tools/testing/selftests/net/ovpn/test-tcp.sh | 9 + tools/testing/selftests/net/ovpn/test.sh | 183 ++ .../testing/selftests/net/ovpn/udp_peers.txt | 5 + 52 files changed, 9493 insertions(+) create mode 100644 Documentation/netlink/specs/ovpn.yaml create mode 100644 drivers/net/ovpn/Makefile create mode 100644 drivers/net/ovpn/bind.c create mode 100644 drivers/net/ovpn/bind.h create mode 100644 drivers/net/ovpn/crypto.c create mode 100644 drivers/net/ovpn/crypto.h create mode 100644 drivers/net/ovpn/crypto_aead.c create mode 100644 drivers/net/ovpn/crypto_aead.h create mode 100644 drivers/net/ovpn/io.c create mode 100644 drivers/net/ovpn/io.h create mode 100644 drivers/net/ovpn/main.c create mode 100644 drivers/net/ovpn/main.h create mode 100644 drivers/net/ovpn/netlink-gen.c create mode 100644 drivers/net/ovpn/netlink-gen.h create mode 100644 drivers/net/ovpn/netlink.c create mode 100644 drivers/net/ovpn/netlink.h create mode 100644 drivers/net/ovpn/ovpnstruct.h create mode 100644 drivers/net/ovpn/packet.h create mode 100644 drivers/net/ovpn/peer.c create mode 100644 drivers/net/ovpn/peer.h create mode 100644 drivers/net/ovpn/pktid.c create mode 100644 drivers/net/ovpn/pktid.h create mode 100644 drivers/net/ovpn/proto.h create mode 100644 drivers/net/ovpn/skb.h create mode 100644 drivers/net/ovpn/socket.c create mode 100644 drivers/net/ovpn/socket.h create mode 100644 drivers/net/ovpn/stats.c create mode 100644 drivers/net/ovpn/stats.h create mode 100644 drivers/net/ovpn/tcp.c create mode 100644 drivers/net/ovpn/tcp.h create mode 100644 drivers/net/ovpn/udp.c create mode 100644 drivers/net/ovpn/udp.h create mode 100644 include/uapi/linux/ovpn.h create mode 100644 tools/testing/selftests/net/ovpn/.gitignore create mode 100644 tools/testing/selftests/net/ovpn/Makefile create mode 100644 tools/testing/selftests/net/ovpn/config create mode 100644 tools/testing/selftests/net/ovpn/data64.key create mode 100644 tools/testing/selftests/net/ovpn/ovpn-cli.c create mode 100644 tools/testing/selftests/net/ovpn/tcp_peers.txt create mode 100755 tools/testing/selftests/net/ovpn/test-chachapoly.sh create mode 100755 tools/testing/selftests/net/ovpn/test-float.sh create mode 100755 tools/testing/selftests/net/ovpn/test-tcp.sh create mode 100755 tools/testing/selftests/net/ovpn/test.sh create mode 100644 tools/testing/selftests/net/ovpn/udp_peers.txt diff --git a/Documentation/netlink/specs/ovpn.yaml b/Documentation/netlink/specs/ovpn.yaml new file mode 100644 index 000000000000..79339c25d607 --- /dev/null +++ b/Documentation/netlink/specs/ovpn.yaml @@ -0,0 +1,362 @@ +# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) +# +# Author: Antonio Quartulli +# +# Copyright (c) 2024, OpenVPN Inc. +# + +name: ovpn + +protocol: genetlink + +doc: Netlink protocol to control OpenVPN network devices + +definitions: + - + type: const + name: nonce-tail-size + value: 8 + - + type: enum + name: cipher-alg + entries: [ none, aes-gcm, chacha20-poly1305 ] + - + type: enum + name: del-peer-reason + entries: [ teardown, userspace, expired, transport-error, transport-disconnect ] + - + type: enum + name: key-slot + entries: [ primary, secondary ] + +attribute-sets: + - + name: peer + attributes: + - + name: id + type: u32 + doc: | + The unique ID of the peer. To be used to identify peers during + operations + checks: + max: 0xFFFFFF + - + name: remote-ipv4 + type: u32 + doc: The remote IPv4 address of the peer + byte-order: big-endian + display-hint: ipv4 + - + name: remote-ipv6 + type: binary + doc: The remote IPv6 address of the peer + display-hint: ipv6 + checks: + exact-len: 16 + - + name: remote-ipv6-scope-id + type: u32 + doc: The scope id of the remote IPv6 address of the peer (RFC2553) + - + name: remote-port + type: u16 + doc: The remote port of the peer + byte-order: big-endian + checks: + min: 1 + - + name: socket + type: u32 + doc: The socket to be used to communicate with the peer + - + name: vpn-ipv4 + type: u32 + doc: The IPv4 address assigned to the peer by the server + byte-order: big-endian + display-hint: ipv4 + - + name: vpn-ipv6 + type: binary + doc: The IPv6 address assigned to the peer by the server + display-hint: ipv6 + checks: + exact-len: 16 + - + name: local-ipv4 + type: u32 + doc: The local IPv4 to be used to send packets to the peer (UDP only) + byte-order: big-endian + display-hint: ipv4 + - + name: local-ipv6 + type: binary + doc: The local IPv6 to be used to send packets to the peer (UDP only) + display-hint: ipv6 + checks: + exact-len: 16 + - + name: local-port + type: u16 + doc: The local port to be used to send packets to the peer (UDP only) + byte-order: big-endian + checks: + min: 1 + - + name: keepalive-interval + type: u32 + doc: | + The number of seconds after which a keep alive message is sent to the + peer + - + name: keepalive-timeout + type: u32 + doc: | + The number of seconds from the last activity after which the peer is + assumed dead + - + name: del-reason + type: u32 + doc: The reason why a peer was deleted + enum: del-peer-reason + - + name: vpn-rx-bytes + type: uint + doc: Number of bytes received over the tunnel + - + name: vpn-tx-bytes + type: uint + doc: Number of bytes transmitted over the tunnel + - + name: vpn-rx-packets + type: uint + doc: Number of packets received over the tunnel + - + name: vpn-tx-packets + type: uint + doc: Number of packets transmitted over the tunnel + - + name: link-rx-bytes + type: uint + doc: Number of bytes received at the transport level + - + name: link-tx-bytes + type: uint + doc: Number of bytes transmitted at the transport level + - + name: link-rx-packets + type: u32 + doc: Number of packets received at the transport level + - + name: link-tx-packets + type: u32 + doc: Number of packets transmitted at the transport level + - + name: keyconf + attributes: + - + name: peer-id + type: u32 + doc: | + The unique ID of the peer. To be used to identify peers during + key operations + checks: + max: 0xFFFFFF + - + name: slot + type: u32 + doc: The slot where the key should be stored + enum: key-slot + - + name: key-id + doc: | + The unique ID of the key. Used to fetch the correct key upon + decryption + type: u32 + checks: + max: 7 + - + name: cipher-alg + type: u32 + doc: The cipher to be used when communicating with the peer + enum: cipher-alg + - + name: encrypt-dir + type: nest + doc: Key material for encrypt direction + nested-attributes: keydir + - + name: decrypt-dir + type: nest + doc: Key material for decrypt direction + nested-attributes: keydir + - + name: keydir + attributes: + - + name: cipher-key + type: binary + doc: The actual key to be used by the cipher + checks: + max-len: 256 + - + name: nonce-tail + type: binary + doc: | + Random nonce to be concatenated to the packet ID, in order to + obtain the actual cipher IV + checks: + exact-len: nonce-tail-size + - + name: ovpn + attributes: + - + name: ifindex + type: u32 + doc: Index of the ovpn interface to operate on + - + name: ifname + type: string + doc: Name of the ovpn interface + - + name: peer + type: nest + doc: | + The peer object containing the attributed of interest for the specific + operation + nested-attributes: peer + - + name: keyconf + type: nest + doc: Peer specific cipher configuration + nested-attributes: keyconf + +operations: + list: + - + name: peer-new + attribute-set: ovpn + flags: [ admin-perm ] + doc: Add a remote peer + do: + pre: ovpn-nl-pre-doit + post: ovpn-nl-post-doit + request: + attributes: + - ifindex + - peer + - + name: peer-set + attribute-set: ovpn + flags: [ admin-perm ] + doc: modify a remote peer + do: + pre: ovpn-nl-pre-doit + post: ovpn-nl-post-doit + request: + attributes: + - ifindex + - peer + - + name: peer-get + attribute-set: ovpn + flags: [ admin-perm ] + doc: Retrieve data about existing remote peers (or a specific one) + do: + pre: ovpn-nl-pre-doit + post: ovpn-nl-post-doit + request: + attributes: + - ifindex + - peer + reply: + attributes: + - peer + dump: + request: + attributes: + - ifindex + reply: + attributes: + - peer + - + name: peer-del + attribute-set: ovpn + flags: [ admin-perm ] + doc: Delete existing remote peer + do: + pre: ovpn-nl-pre-doit + post: ovpn-nl-post-doit + request: + attributes: + - ifindex + - peer + - + name: peer-del-ntf + doc: Notification about a peer being deleted + notify: peer-get + mcgrp: peers + + - + name: key-new + attribute-set: ovpn + flags: [ admin-perm ] + doc: Add a cipher key for a specific peer + do: + pre: ovpn-nl-pre-doit + post: ovpn-nl-post-doit + request: + attributes: + - ifindex + - keyconf + - + name: key-get + attribute-set: ovpn + flags: [ admin-perm ] + doc: Retrieve non-sensitive data about peer key and cipher + do: + pre: ovpn-nl-pre-doit + post: ovpn-nl-post-doit + request: + attributes: + - ifindex + - keyconf + reply: + attributes: + - keyconf + - + name: key-swap + attribute-set: ovpn + flags: [ admin-perm ] + doc: Swap primary and secondary session keys for a specific peer + do: + pre: ovpn-nl-pre-doit + post: ovpn-nl-post-doit + request: + attributes: + - ifindex + - keyconf + - + name: key-swap-ntf + notify: key-get + doc: | + Notification about key having exhausted its IV space and requiring + renegotiation + mcgrp: peers + - + name: key-del + attribute-set: ovpn + flags: [ admin-perm ] + doc: Delete cipher key for a specific peer + do: + pre: ovpn-nl-pre-doit + post: ovpn-nl-post-doit + request: + attributes: + - ifindex + - keyconf + +mcast-groups: + list: + - + name: peers diff --git a/MAINTAINERS b/MAINTAINERS index 3ca514d82269..f509050e63ed 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -17362,6 +17362,17 @@ F: arch/openrisc/ F: drivers/irqchip/irq-ompic.c F: drivers/irqchip/irq-or1k-* +OPENVPN DATA CHANNEL OFFLOAD +M: Antonio Quartulli +L: openvpn-devel@lists.sourceforge.net (moderated for non-subscribers) +L: netdev@vger.kernel.org +S: Supported +T: git https://github.com/OpenVPN/linux-kernel-ovpn.git +F: Documentation/netlink/specs/ovpn.yaml +F: drivers/net/ovpn/ +F: include/uapi/linux/ovpn.h +F: tools/testing/selftests/net/ovpn/ + OPENVSWITCH M: Pravin B Shelar L: netdev@vger.kernel.org diff --git a/drivers/net/Kconfig b/drivers/net/Kconfig index 9920b3a68ed1..ddc65bc1e218 100644 --- a/drivers/net/Kconfig +++ b/drivers/net/Kconfig @@ -115,6 +115,20 @@ config WIREGUARD_DEBUG Say N here unless you know what you're doing. +config OVPN + tristate "OpenVPN data channel offload" + depends on NET && INET + select STREAM_PARSER + select NET_UDP_TUNNEL + select DST_CACHE + select CRYPTO + select CRYPTO_AES + select CRYPTO_GCM + select CRYPTO_CHACHA20POLY1305 + help + This module enhances the performance of the OpenVPN userspace software + by offloading the data channel processing to kernelspace. + config EQUALIZER tristate "EQL (serial line load balancing) support" help diff --git a/drivers/net/Makefile b/drivers/net/Makefile index 13743d0e83b5..5152b3330e28 100644 --- a/drivers/net/Makefile +++ b/drivers/net/Makefile @@ -11,6 +11,7 @@ obj-$(CONFIG_IPVLAN) += ipvlan/ obj-$(CONFIG_IPVTAP) += ipvlan/ obj-$(CONFIG_DUMMY) += dummy.o obj-$(CONFIG_WIREGUARD) += wireguard/ +obj-$(CONFIG_OVPN) += ovpn/ obj-$(CONFIG_EQUALIZER) += eql.o obj-$(CONFIG_IFB) += ifb.o obj-$(CONFIG_MACSEC) += macsec.o diff --git a/drivers/net/ovpn/Makefile b/drivers/net/ovpn/Makefile new file mode 100644 index 000000000000..f4d4bd87c851 --- /dev/null +++ b/drivers/net/ovpn/Makefile @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# ovpn -- OpenVPN data channel offload in kernel space +# +# Copyright (C) 2020-2024 OpenVPN, Inc. +# +# Author: Antonio Quartulli + +obj-$(CONFIG_OVPN) := ovpn.o +ovpn-y += bind.o +ovpn-y += crypto.o +ovpn-y += crypto_aead.o +ovpn-y += main.o +ovpn-y += io.o +ovpn-y += netlink.o +ovpn-y += netlink-gen.o +ovpn-y += peer.o +ovpn-y += pktid.o +ovpn-y += socket.o +ovpn-y += stats.o +ovpn-y += tcp.o +ovpn-y += udp.o diff --git a/drivers/net/ovpn/bind.c b/drivers/net/ovpn/bind.c new file mode 100644 index 000000000000..d17d078c5730 --- /dev/null +++ b/drivers/net/ovpn/bind.c @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel offload + * + * Copyright (C) 2012-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include +#include + +#include "ovpnstruct.h" +#include "bind.h" +#include "peer.h" + +/** + * ovpn_bind_from_sockaddr - retrieve binding matching sockaddr + * @ss: the sockaddr to match + * + * Return: the bind matching the passed sockaddr if found, NULL otherwise + */ +struct ovpn_bind *ovpn_bind_from_sockaddr(const struct sockaddr_storage *ss) +{ + struct ovpn_bind *bind; + size_t sa_len; + + if (ss->ss_family == AF_INET) + sa_len = sizeof(struct sockaddr_in); + else if (ss->ss_family == AF_INET6) + sa_len = sizeof(struct sockaddr_in6); + else + return ERR_PTR(-EAFNOSUPPORT); + + bind = kzalloc(sizeof(*bind), GFP_ATOMIC); + if (unlikely(!bind)) + return ERR_PTR(-ENOMEM); + + memcpy(&bind->remote, ss, sa_len); + + return bind; +} + +/** + * ovpn_bind_reset - assign new binding to peer + * @peer: the peer whose binding has to be replaced + * @new: the new bind to assign + */ +void ovpn_bind_reset(struct ovpn_peer *peer, struct ovpn_bind *new) + __must_hold(&peer->lock) +{ + kfree_rcu(rcu_replace_pointer(peer->bind, new, + lockdep_is_held(&peer->lock)), rcu); +} diff --git a/drivers/net/ovpn/bind.h b/drivers/net/ovpn/bind.h new file mode 100644 index 000000000000..859213d5040d --- /dev/null +++ b/drivers/net/ovpn/bind.h @@ -0,0 +1,117 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2012-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_OVPNBIND_H_ +#define _NET_OVPN_OVPNBIND_H_ + +#include +#include +#include +#include +#include +#include + +struct ovpn_peer; + +/** + * union ovpn_sockaddr - basic transport layer address + * @in4: IPv4 address + * @in6: IPv6 address + */ +union ovpn_sockaddr { + struct sockaddr_in in4; + struct sockaddr_in6 in6; +}; + +/** + * struct ovpn_bind - remote peer binding + * @remote: the remote peer sockaddress + * @local: local endpoint used to talk to the peer + * @local.ipv4: local IPv4 used to talk to the peer + * @local.ipv6: local IPv6 used to talk to the peer + * @rcu: used to schedule RCU cleanup job + */ +struct ovpn_bind { + union ovpn_sockaddr remote; /* remote sockaddr */ + + union { + struct in_addr ipv4; + struct in6_addr ipv6; + } local; + + struct rcu_head rcu; +}; + +/** + * skb_protocol_to_family - translate skb->protocol to AF_INET or AF_INET6 + * @skb: the packet sk_buff to inspect + * + * Return: AF_INET, AF_INET6 or 0 in case of unknown protocol + */ +static inline unsigned short skb_protocol_to_family(const struct sk_buff *skb) +{ + switch (skb->protocol) { + case htons(ETH_P_IP): + return AF_INET; + case htons(ETH_P_IPV6): + return AF_INET6; + default: + return 0; + } +} + +/** + * ovpn_bind_skb_src_match - match packet source with binding + * @bind: the binding to match + * @skb: the packet to match + * + * Return: true if the packet source matches the remote peer sockaddr + * in the binding + */ +static inline bool ovpn_bind_skb_src_match(const struct ovpn_bind *bind, + const struct sk_buff *skb) +{ + const unsigned short family = skb_protocol_to_family(skb); + const union ovpn_sockaddr *remote; + + if (unlikely(!bind)) + return false; + + remote = &bind->remote; + + if (unlikely(remote->in4.sin_family != family)) + return false; + + switch (family) { + case AF_INET: + if (unlikely(remote->in4.sin_addr.s_addr != ip_hdr(skb)->saddr)) + return false; + + if (unlikely(remote->in4.sin_port != udp_hdr(skb)->source)) + return false; + break; + case AF_INET6: + if (unlikely(!ipv6_addr_equal(&remote->in6.sin6_addr, + &ipv6_hdr(skb)->saddr))) + return false; + + if (unlikely(remote->in6.sin6_port != udp_hdr(skb)->source)) + return false; + break; + default: + return false; + } + + return true; +} + +struct ovpn_bind *ovpn_bind_from_sockaddr(const struct sockaddr_storage *sa); +void ovpn_bind_reset(struct ovpn_peer *peer, struct ovpn_bind *bind); + +#endif /* _NET_OVPN_OVPNBIND_H_ */ diff --git a/drivers/net/ovpn/crypto.c b/drivers/net/ovpn/crypto.c new file mode 100644 index 000000000000..a2346bc630be --- /dev/null +++ b/drivers/net/ovpn/crypto.c @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include +#include +#include +#include + +#include "ovpnstruct.h" +#include "main.h" +#include "packet.h" +#include "pktid.h" +#include "crypto_aead.h" +#include "crypto.h" + +static void ovpn_ks_destroy_rcu(struct rcu_head *head) +{ + struct ovpn_crypto_key_slot *ks; + + ks = container_of(head, struct ovpn_crypto_key_slot, rcu); + ovpn_aead_crypto_key_slot_destroy(ks); +} + +void ovpn_crypto_key_slot_release(struct kref *kref) +{ + struct ovpn_crypto_key_slot *ks; + + ks = container_of(kref, struct ovpn_crypto_key_slot, refcount); + call_rcu(&ks->rcu, ovpn_ks_destroy_rcu); +} + +/* can only be invoked when all peer references have been dropped (i.e. RCU + * release routine) + */ +void ovpn_crypto_state_release(struct ovpn_crypto_state *cs) +{ + struct ovpn_crypto_key_slot *ks; + + ks = rcu_access_pointer(cs->slots[0]); + if (ks) { + RCU_INIT_POINTER(cs->slots[0], NULL); + ovpn_crypto_key_slot_put(ks); + } + + ks = rcu_access_pointer(cs->slots[1]); + if (ks) { + RCU_INIT_POINTER(cs->slots[1], NULL); + ovpn_crypto_key_slot_put(ks); + } +} + +/* removes the key matching the specified id from the crypto context */ +void ovpn_crypto_kill_key(struct ovpn_crypto_state *cs, u8 key_id) +{ + struct ovpn_crypto_key_slot *ks = NULL; + + spin_lock_bh(&cs->lock); + if (rcu_access_pointer(cs->slots[0])->key_id == key_id) { + ks = rcu_replace_pointer(cs->slots[0], NULL, + lockdep_is_held(&cs->lock)); + } else if (rcu_access_pointer(cs->slots[1])->key_id == key_id) { + ks = rcu_replace_pointer(cs->slots[1], NULL, + lockdep_is_held(&cs->lock)); + } + spin_unlock_bh(&cs->lock); + + if (ks) + ovpn_crypto_key_slot_put(ks); +} + +/* Reset the ovpn_crypto_state object in a way that is atomic + * to RCU readers. + */ +int ovpn_crypto_state_reset(struct ovpn_crypto_state *cs, + const struct ovpn_peer_key_reset *pkr) +{ + struct ovpn_crypto_key_slot *old = NULL, *new; + u8 idx; + + if (pkr->slot != OVPN_KEY_SLOT_PRIMARY && + pkr->slot != OVPN_KEY_SLOT_SECONDARY) + return -EINVAL; + + new = ovpn_aead_crypto_key_slot_new(&pkr->key); + if (IS_ERR(new)) + return PTR_ERR(new); + + spin_lock_bh(&cs->lock); + idx = cs->primary_idx; + switch (pkr->slot) { + case OVPN_KEY_SLOT_PRIMARY: + old = rcu_replace_pointer(cs->slots[idx], new, + lockdep_is_held(&cs->lock)); + break; + case OVPN_KEY_SLOT_SECONDARY: + old = rcu_replace_pointer(cs->slots[!idx], new, + lockdep_is_held(&cs->lock)); + break; + } + spin_unlock_bh(&cs->lock); + + if (old) + ovpn_crypto_key_slot_put(old); + + return 0; +} + +void ovpn_crypto_key_slot_delete(struct ovpn_crypto_state *cs, + enum ovpn_key_slot slot) +{ + struct ovpn_crypto_key_slot *ks = NULL; + u8 idx; + + if (slot != OVPN_KEY_SLOT_PRIMARY && + slot != OVPN_KEY_SLOT_SECONDARY) { + pr_warn("Invalid slot to release: %u\n", slot); + return; + } + + spin_lock_bh(&cs->lock); + idx = cs->primary_idx; + switch (slot) { + case OVPN_KEY_SLOT_PRIMARY: + ks = rcu_replace_pointer(cs->slots[idx], NULL, + lockdep_is_held(&cs->lock)); + break; + case OVPN_KEY_SLOT_SECONDARY: + ks = rcu_replace_pointer(cs->slots[!idx], NULL, + lockdep_is_held(&cs->lock)); + break; + } + spin_unlock_bh(&cs->lock); + + if (!ks) { + pr_debug("Key slot already released: %u\n", slot); + return; + } + + pr_debug("deleting key slot %u, key_id=%u\n", slot, ks->key_id); + ovpn_crypto_key_slot_put(ks); +} + +/* this swap is not atomic, but there will be a very short time frame where the + * old_secondary key won't be available. This should not be a big deal as most + * likely both peers are already using the new primary at this point. + */ +void ovpn_crypto_key_slots_swap(struct ovpn_crypto_state *cs) +{ + const struct ovpn_crypto_key_slot *old_primary, *old_secondary; + u8 idx; + + spin_lock_bh(&cs->lock); + idx = cs->primary_idx; + old_primary = rcu_dereference_protected(cs->slots[idx], + lockdep_is_held(&cs->lock)); + old_secondary = rcu_dereference_protected(cs->slots[!idx], + lockdep_is_held(&cs->lock)); + /* perform real swap by switching the index of the primary key */ + cs->primary_idx = !cs->primary_idx; + + pr_debug("key swapped: (old primary) %d <-> (new primary) %d\n", + old_primary ? old_primary->key_id : -1, + old_secondary ? old_secondary->key_id : -1); + + spin_unlock_bh(&cs->lock); +} + +/** + * ovpn_crypto_config_get - populate keyconf object with non-sensible key data + * @cs: the crypto state to extract the key data from + * @slot: the specific slot to inspect + * @keyconf: the output object to populate + * + * Return: 0 on success or a negative error code otherwise + */ +int ovpn_crypto_config_get(struct ovpn_crypto_state *cs, + enum ovpn_key_slot slot, + struct ovpn_key_config *keyconf) +{ + struct ovpn_crypto_key_slot *ks; + int idx; + + switch (slot) { + case OVPN_KEY_SLOT_PRIMARY: + idx = cs->primary_idx; + break; + case OVPN_KEY_SLOT_SECONDARY: + idx = !cs->primary_idx; + break; + default: + return -EINVAL; + } + + rcu_read_lock(); + ks = rcu_dereference(cs->slots[idx]); + if (!ks || (ks && !ovpn_crypto_key_slot_hold(ks))) { + rcu_read_unlock(); + return -ENOENT; + } + rcu_read_unlock(); + + keyconf->cipher_alg = ovpn_aead_crypto_alg(ks); + keyconf->key_id = ks->key_id; + + ovpn_crypto_key_slot_put(ks); + + return 0; +} diff --git a/drivers/net/ovpn/crypto.h b/drivers/net/ovpn/crypto.h new file mode 100644 index 000000000000..b7a7be752d54 --- /dev/null +++ b/drivers/net/ovpn/crypto.h @@ -0,0 +1,145 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_OVPNCRYPTO_H_ +#define _NET_OVPN_OVPNCRYPTO_H_ + +#include "packet.h" +#include "pktid.h" + +/* info needed for both encrypt and decrypt directions */ +struct ovpn_key_direction { + const u8 *cipher_key; + size_t cipher_key_size; + const u8 *nonce_tail; /* only needed for GCM modes */ + size_t nonce_tail_size; /* only needed for GCM modes */ +}; + +/* all info for a particular symmetric key (primary or secondary) */ +struct ovpn_key_config { + enum ovpn_cipher_alg cipher_alg; + u8 key_id; + struct ovpn_key_direction encrypt; + struct ovpn_key_direction decrypt; +}; + +/* used to pass settings from netlink to the crypto engine */ +struct ovpn_peer_key_reset { + enum ovpn_key_slot slot; + struct ovpn_key_config key; +}; + +struct ovpn_crypto_key_slot { + u8 key_id; + + struct crypto_aead *encrypt; + struct crypto_aead *decrypt; + struct ovpn_nonce_tail nonce_tail_xmit; + struct ovpn_nonce_tail nonce_tail_recv; + + struct ovpn_pktid_recv pid_recv ____cacheline_aligned_in_smp; + struct ovpn_pktid_xmit pid_xmit ____cacheline_aligned_in_smp; + struct kref refcount; + struct rcu_head rcu; +}; + +struct ovpn_crypto_state { + struct ovpn_crypto_key_slot __rcu *slots[2]; + u8 primary_idx; + + /* protects primary and secondary slots */ + spinlock_t lock; +}; + +static inline bool ovpn_crypto_key_slot_hold(struct ovpn_crypto_key_slot *ks) +{ + return kref_get_unless_zero(&ks->refcount); +} + +static inline void ovpn_crypto_state_init(struct ovpn_crypto_state *cs) +{ + RCU_INIT_POINTER(cs->slots[0], NULL); + RCU_INIT_POINTER(cs->slots[1], NULL); + cs->primary_idx = 0; + spin_lock_init(&cs->lock); +} + +static inline struct ovpn_crypto_key_slot * +ovpn_crypto_key_id_to_slot(const struct ovpn_crypto_state *cs, u8 key_id) +{ + struct ovpn_crypto_key_slot *ks; + u8 idx; + + if (unlikely(!cs)) + return NULL; + + rcu_read_lock(); + idx = cs->primary_idx; + ks = rcu_dereference(cs->slots[idx]); + if (ks && ks->key_id == key_id) { + if (unlikely(!ovpn_crypto_key_slot_hold(ks))) + ks = NULL; + goto out; + } + + ks = rcu_dereference(cs->slots[idx ^ 1]); + if (ks && ks->key_id == key_id) { + if (unlikely(!ovpn_crypto_key_slot_hold(ks))) + ks = NULL; + goto out; + } + + /* when both key slots are occupied but no matching key ID is found, ks + * has to be reset to NULL to avoid carrying a stale pointer + */ + ks = NULL; +out: + rcu_read_unlock(); + + return ks; +} + +static inline struct ovpn_crypto_key_slot * +ovpn_crypto_key_slot_primary(const struct ovpn_crypto_state *cs) +{ + struct ovpn_crypto_key_slot *ks; + + rcu_read_lock(); + ks = rcu_dereference(cs->slots[cs->primary_idx]); + if (unlikely(ks && !ovpn_crypto_key_slot_hold(ks))) + ks = NULL; + rcu_read_unlock(); + + return ks; +} + +void ovpn_crypto_key_slot_release(struct kref *kref); + +static inline void ovpn_crypto_key_slot_put(struct ovpn_crypto_key_slot *ks) +{ + kref_put(&ks->refcount, ovpn_crypto_key_slot_release); +} + +int ovpn_crypto_state_reset(struct ovpn_crypto_state *cs, + const struct ovpn_peer_key_reset *pkr); + +void ovpn_crypto_key_slot_delete(struct ovpn_crypto_state *cs, + enum ovpn_key_slot slot); + +void ovpn_crypto_state_release(struct ovpn_crypto_state *cs); + +void ovpn_crypto_key_slots_swap(struct ovpn_crypto_state *cs); + +int ovpn_crypto_config_get(struct ovpn_crypto_state *cs, + enum ovpn_key_slot slot, + struct ovpn_key_config *keyconf); + +void ovpn_crypto_kill_key(struct ovpn_crypto_state *cs, u8 key_id); + +#endif /* _NET_OVPN_OVPNCRYPTO_H_ */ diff --git a/drivers/net/ovpn/crypto_aead.c b/drivers/net/ovpn/crypto_aead.c new file mode 100644 index 000000000000..25e4e4a453b2 --- /dev/null +++ b/drivers/net/ovpn/crypto_aead.c @@ -0,0 +1,386 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include +#include +#include +#include +#include + +#include "ovpnstruct.h" +#include "main.h" +#include "io.h" +#include "packet.h" +#include "pktid.h" +#include "crypto_aead.h" +#include "crypto.h" +#include "peer.h" +#include "proto.h" +#include "skb.h" + +#define AUTH_TAG_SIZE 16 + +#define ALG_NAME_AES "gcm(aes)" +#define ALG_NAME_CHACHAPOLY "rfc7539(chacha20,poly1305)" + +static int ovpn_aead_encap_overhead(const struct ovpn_crypto_key_slot *ks) +{ + return OVPN_OP_SIZE_V2 + /* OP header size */ + 4 + /* Packet ID */ + crypto_aead_authsize(ks->encrypt); /* Auth Tag */ +} + +int ovpn_aead_encrypt(struct ovpn_peer *peer, struct ovpn_crypto_key_slot *ks, + struct sk_buff *skb) +{ + const unsigned int tag_size = crypto_aead_authsize(ks->encrypt); + const unsigned int head_size = ovpn_aead_encap_overhead(ks); + struct aead_request *req; + struct sk_buff *trailer; + struct scatterlist *sg; + u8 iv[NONCE_SIZE]; + int nfrags, ret; + u32 pktid, op; + + ovpn_skb_cb(skb)->orig_len = skb->len; + ovpn_skb_cb(skb)->peer = peer; + ovpn_skb_cb(skb)->ks = ks; + + /* Sample AEAD header format: + * 48000001 00000005 7e7046bd 444a7e28 cc6387b1 64a4d6c1 380275a... + * [ OP32 ] [seq # ] [ auth tag ] [ payload ... ] + * [4-byte + * IV head] + */ + + /* check that there's enough headroom in the skb for packet + * encapsulation, after adding network header and encryption overhead + */ + if (unlikely(skb_cow_head(skb, OVPN_HEAD_ROOM + head_size))) + return -ENOBUFS; + + /* get number of skb frags and ensure that packet data is writable */ + nfrags = skb_cow_data(skb, 0, &trailer); + if (unlikely(nfrags < 0)) + return nfrags; + + if (unlikely(nfrags + 2 > (MAX_SKB_FRAGS + 2))) + return -ENOSPC; + + ovpn_skb_cb(skb)->sg = kmalloc(sizeof(*ovpn_skb_cb(skb)->sg) * + (nfrags + 2), GFP_ATOMIC); + if (unlikely(!ovpn_skb_cb(skb)->sg)) + return -ENOMEM; + + sg = ovpn_skb_cb(skb)->sg; + + /* sg table: + * 0: op, wire nonce (AD, len=OVPN_OP_SIZE_V2+NONCE_WIRE_SIZE), + * 1, 2, 3, ..., n: payload, + * n+1: auth_tag (len=tag_size) + */ + sg_init_table(sg, nfrags + 2); + + /* build scatterlist to encrypt packet payload */ + ret = skb_to_sgvec_nomark(skb, sg + 1, 0, skb->len); + if (unlikely(nfrags != ret)) { + ret = -EINVAL; + goto free_sg; + } + + /* append auth_tag onto scatterlist */ + __skb_push(skb, tag_size); + sg_set_buf(sg + nfrags + 1, skb->data, tag_size); + + /* obtain packet ID, which is used both as a first + * 4 bytes of nonce and last 4 bytes of associated data. + */ + ret = ovpn_pktid_xmit_next(&ks->pid_xmit, &pktid); + if (unlikely(ret < 0)) + goto free_sg; + + /* concat 4 bytes packet id and 8 bytes nonce tail into 12 bytes + * nonce + */ + ovpn_pktid_aead_write(pktid, &ks->nonce_tail_xmit, iv); + + /* make space for packet id and push it to the front */ + __skb_push(skb, NONCE_WIRE_SIZE); + memcpy(skb->data, iv, NONCE_WIRE_SIZE); + + /* add packet op as head of additional data */ + op = ovpn_opcode_compose(OVPN_DATA_V2, ks->key_id, peer->id); + __skb_push(skb, OVPN_OP_SIZE_V2); + BUILD_BUG_ON(sizeof(op) != OVPN_OP_SIZE_V2); + *((__force __be32 *)skb->data) = htonl(op); + + /* AEAD Additional data */ + sg_set_buf(sg, skb->data, OVPN_OP_SIZE_V2 + NONCE_WIRE_SIZE); + + req = aead_request_alloc(ks->encrypt, GFP_ATOMIC); + if (unlikely(!req)) { + ret = -ENOMEM; + goto free_sg; + } + + ovpn_skb_cb(skb)->req = req; + + /* setup async crypto operation */ + aead_request_set_tfm(req, ks->encrypt); + aead_request_set_callback(req, 0, ovpn_encrypt_post, skb); + aead_request_set_crypt(req, sg, sg, skb->len - head_size, iv); + aead_request_set_ad(req, OVPN_OP_SIZE_V2 + NONCE_WIRE_SIZE); + + /* encrypt it */ + return crypto_aead_encrypt(req); +free_sg: + kfree(ovpn_skb_cb(skb)->sg); + ovpn_skb_cb(skb)->sg = NULL; + return ret; +} + +int ovpn_aead_decrypt(struct ovpn_peer *peer, struct ovpn_crypto_key_slot *ks, + struct sk_buff *skb) +{ + const unsigned int tag_size = crypto_aead_authsize(ks->decrypt); + int ret, payload_len, nfrags; + unsigned int payload_offset; + struct aead_request *req; + struct sk_buff *trailer; + struct scatterlist *sg; + unsigned int sg_len; + u8 iv[NONCE_SIZE]; + + payload_offset = OVPN_OP_SIZE_V2 + NONCE_WIRE_SIZE + tag_size; + payload_len = skb->len - payload_offset; + + ovpn_skb_cb(skb)->orig_len = skb->len; + ovpn_skb_cb(skb)->payload_offset = payload_offset; + ovpn_skb_cb(skb)->peer = peer; + ovpn_skb_cb(skb)->ks = ks; + + /* sanity check on packet size, payload size must be >= 0 */ + if (unlikely(payload_len < 0)) + return -EINVAL; + + /* Prepare the skb data buffer to be accessed up until the auth tag. + * This is required because this area is directly mapped into the sg + * list. + */ + if (unlikely(!pskb_may_pull(skb, payload_offset))) + return -ENODATA; + + /* get number of skb frags and ensure that packet data is writable */ + nfrags = skb_cow_data(skb, 0, &trailer); + if (unlikely(nfrags < 0)) + return nfrags; + + if (unlikely(nfrags + 2 > (MAX_SKB_FRAGS + 2))) + return -ENOSPC; + + ovpn_skb_cb(skb)->sg = kmalloc(sizeof(*ovpn_skb_cb(skb)->sg) * + (nfrags + 2), GFP_ATOMIC); + if (unlikely(!ovpn_skb_cb(skb)->sg)) + return -ENOMEM; + + sg = ovpn_skb_cb(skb)->sg; + + /* sg table: + * 0: op, wire nonce (AD, len=OVPN_OP_SIZE_V2+NONCE_WIRE_SIZE), + * 1, 2, 3, ..., n: payload, + * n+1: auth_tag (len=tag_size) + */ + sg_init_table(sg, nfrags + 2); + + /* packet op is head of additional data */ + sg_len = OVPN_OP_SIZE_V2 + NONCE_WIRE_SIZE; + sg_set_buf(sg, skb->data, sg_len); + + /* build scatterlist to decrypt packet payload */ + ret = skb_to_sgvec_nomark(skb, sg + 1, payload_offset, payload_len); + if (unlikely(nfrags != ret)) { + ret = -EINVAL; + goto free_sg; + } + + /* append auth_tag onto scatterlist */ + sg_set_buf(sg + nfrags + 1, skb->data + sg_len, tag_size); + + /* copy nonce into IV buffer */ + memcpy(iv, skb->data + OVPN_OP_SIZE_V2, NONCE_WIRE_SIZE); + memcpy(iv + NONCE_WIRE_SIZE, ks->nonce_tail_recv.u8, + sizeof(struct ovpn_nonce_tail)); + + req = aead_request_alloc(ks->decrypt, GFP_ATOMIC); + if (unlikely(!req)) { + ret = -ENOMEM; + goto free_sg; + } + + ovpn_skb_cb(skb)->req = req; + + /* setup async crypto operation */ + aead_request_set_tfm(req, ks->decrypt); + aead_request_set_callback(req, 0, ovpn_decrypt_post, skb); + aead_request_set_crypt(req, sg, sg, payload_len + tag_size, iv); + + aead_request_set_ad(req, NONCE_WIRE_SIZE + OVPN_OP_SIZE_V2); + + /* decrypt it */ + return crypto_aead_decrypt(req); +free_sg: + kfree(ovpn_skb_cb(skb)->sg); + ovpn_skb_cb(skb)->sg = NULL; + return ret; +} + +/* Initialize a struct crypto_aead object */ +struct crypto_aead *ovpn_aead_init(const char *title, const char *alg_name, + const unsigned char *key, + unsigned int keylen) +{ + struct crypto_aead *aead; + int ret; + + aead = crypto_alloc_aead(alg_name, 0, 0); + if (IS_ERR(aead)) { + ret = PTR_ERR(aead); + pr_err("%s crypto_alloc_aead failed, err=%d\n", title, ret); + aead = NULL; + goto error; + } + + ret = crypto_aead_setkey(aead, key, keylen); + if (ret) { + pr_err("%s crypto_aead_setkey size=%u failed, err=%d\n", title, + keylen, ret); + goto error; + } + + ret = crypto_aead_setauthsize(aead, AUTH_TAG_SIZE); + if (ret) { + pr_err("%s crypto_aead_setauthsize failed, err=%d\n", title, + ret); + goto error; + } + + /* basic AEAD assumption */ + if (crypto_aead_ivsize(aead) != NONCE_SIZE) { + pr_err("%s IV size must be %d\n", title, NONCE_SIZE); + ret = -EINVAL; + goto error; + } + + pr_debug("********* Cipher %s (%s)\n", alg_name, title); + pr_debug("*** IV size=%u\n", crypto_aead_ivsize(aead)); + pr_debug("*** req size=%u\n", crypto_aead_reqsize(aead)); + pr_debug("*** block size=%u\n", crypto_aead_blocksize(aead)); + pr_debug("*** auth size=%u\n", crypto_aead_authsize(aead)); + pr_debug("*** alignmask=0x%x\n", crypto_aead_alignmask(aead)); + + return aead; + +error: + crypto_free_aead(aead); + return ERR_PTR(ret); +} + +void ovpn_aead_crypto_key_slot_destroy(struct ovpn_crypto_key_slot *ks) +{ + if (!ks) + return; + + crypto_free_aead(ks->encrypt); + crypto_free_aead(ks->decrypt); + kfree(ks); +} + +struct ovpn_crypto_key_slot * +ovpn_aead_crypto_key_slot_new(const struct ovpn_key_config *kc) +{ + struct ovpn_crypto_key_slot *ks = NULL; + const char *alg_name; + int ret; + + /* validate crypto alg */ + switch (kc->cipher_alg) { + case OVPN_CIPHER_ALG_AES_GCM: + alg_name = ALG_NAME_AES; + break; + case OVPN_CIPHER_ALG_CHACHA20_POLY1305: + alg_name = ALG_NAME_CHACHAPOLY; + break; + default: + return ERR_PTR(-EOPNOTSUPP); + } + + if (sizeof(struct ovpn_nonce_tail) != kc->encrypt.nonce_tail_size || + sizeof(struct ovpn_nonce_tail) != kc->decrypt.nonce_tail_size) + return ERR_PTR(-EINVAL); + + /* build the key slot */ + ks = kmalloc(sizeof(*ks), GFP_KERNEL); + if (!ks) + return ERR_PTR(-ENOMEM); + + ks->encrypt = NULL; + ks->decrypt = NULL; + kref_init(&ks->refcount); + ks->key_id = kc->key_id; + + ks->encrypt = ovpn_aead_init("encrypt", alg_name, + kc->encrypt.cipher_key, + kc->encrypt.cipher_key_size); + if (IS_ERR(ks->encrypt)) { + ret = PTR_ERR(ks->encrypt); + ks->encrypt = NULL; + goto destroy_ks; + } + + ks->decrypt = ovpn_aead_init("decrypt", alg_name, + kc->decrypt.cipher_key, + kc->decrypt.cipher_key_size); + if (IS_ERR(ks->decrypt)) { + ret = PTR_ERR(ks->decrypt); + ks->decrypt = NULL; + goto destroy_ks; + } + + memcpy(ks->nonce_tail_xmit.u8, kc->encrypt.nonce_tail, + sizeof(struct ovpn_nonce_tail)); + memcpy(ks->nonce_tail_recv.u8, kc->decrypt.nonce_tail, + sizeof(struct ovpn_nonce_tail)); + + /* init packet ID generation/validation */ + ovpn_pktid_xmit_init(&ks->pid_xmit); + ovpn_pktid_recv_init(&ks->pid_recv); + + return ks; + +destroy_ks: + ovpn_aead_crypto_key_slot_destroy(ks); + return ERR_PTR(ret); +} + +enum ovpn_cipher_alg ovpn_aead_crypto_alg(struct ovpn_crypto_key_slot *ks) +{ + const char *alg_name; + + if (!ks->encrypt) + return OVPN_CIPHER_ALG_NONE; + + alg_name = crypto_tfm_alg_name(crypto_aead_tfm(ks->encrypt)); + + if (!strcmp(alg_name, ALG_NAME_AES)) + return OVPN_CIPHER_ALG_AES_GCM; + else if (!strcmp(alg_name, ALG_NAME_CHACHAPOLY)) + return OVPN_CIPHER_ALG_CHACHA20_POLY1305; + else + return OVPN_CIPHER_ALG_NONE; +} diff --git a/drivers/net/ovpn/crypto_aead.h b/drivers/net/ovpn/crypto_aead.h new file mode 100644 index 000000000000..fb65be82436e --- /dev/null +++ b/drivers/net/ovpn/crypto_aead.h @@ -0,0 +1,33 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_OVPNAEAD_H_ +#define _NET_OVPN_OVPNAEAD_H_ + +#include "crypto.h" + +#include +#include + +struct crypto_aead *ovpn_aead_init(const char *title, const char *alg_name, + const unsigned char *key, + unsigned int keylen); + +int ovpn_aead_encrypt(struct ovpn_peer *peer, struct ovpn_crypto_key_slot *ks, + struct sk_buff *skb); +int ovpn_aead_decrypt(struct ovpn_peer *peer, struct ovpn_crypto_key_slot *ks, + struct sk_buff *skb); + +struct ovpn_crypto_key_slot * +ovpn_aead_crypto_key_slot_new(const struct ovpn_key_config *kc); +void ovpn_aead_crypto_key_slot_destroy(struct ovpn_crypto_key_slot *ks); + +enum ovpn_cipher_alg ovpn_aead_crypto_alg(struct ovpn_crypto_key_slot *ks); + +#endif /* _NET_OVPN_OVPNAEAD_H_ */ diff --git a/drivers/net/ovpn/io.c b/drivers/net/ovpn/io.c new file mode 100644 index 000000000000..c04791a508e5 --- /dev/null +++ b/drivers/net/ovpn/io.c @@ -0,0 +1,462 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel offload + * + * Copyright (C) 2019-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include +#include +#include +#include +#include +#include + +#include "ovpnstruct.h" +#include "peer.h" +#include "io.h" +#include "bind.h" +#include "crypto.h" +#include "crypto_aead.h" +#include "netlink.h" +#include "proto.h" +#include "tcp.h" +#include "udp.h" +#include "skb.h" +#include "socket.h" + +const unsigned char ovpn_keepalive_message[OVPN_KEEPALIVE_SIZE] = { + 0x2a, 0x18, 0x7b, 0xf3, 0x64, 0x1e, 0xb4, 0xcb, + 0x07, 0xed, 0x2d, 0x0a, 0x98, 0x1f, 0xc7, 0x48 +}; + +/** + * ovpn_is_keepalive - check if skb contains a keepalive message + * @skb: packet to check + * + * Assumes that the first byte of skb->data is defined. + * + * Return: true if skb contains a keepalive or false otherwise + */ +static bool ovpn_is_keepalive(struct sk_buff *skb) +{ + if (*skb->data != ovpn_keepalive_message[0]) + return false; + + if (skb->len != OVPN_KEEPALIVE_SIZE) + return false; + + if (!pskb_may_pull(skb, OVPN_KEEPALIVE_SIZE)) + return false; + + return !memcmp(skb->data, ovpn_keepalive_message, OVPN_KEEPALIVE_SIZE); +} + +/* Called after decrypt to write the IP packet to the device. + * This method is expected to manage/free the skb. + */ +static void ovpn_netdev_write(struct ovpn_peer *peer, struct sk_buff *skb) +{ + unsigned int pkt_len; + + /* we can't guarantee the packet wasn't corrupted before entering the + * VPN, therefore we give other layers a chance to check that + */ + skb->ip_summed = CHECKSUM_NONE; + + /* skb hash for transport packet no longer valid after decapsulation */ + skb_clear_hash(skb); + + /* post-decrypt scrub -- prepare to inject encapsulated packet onto the + * interface, based on __skb_tunnel_rx() in dst.h + */ + skb->dev = peer->ovpn->dev; + skb_set_queue_mapping(skb, 0); + skb_scrub_packet(skb, true); + + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + skb_probe_transport_header(skb); + skb_reset_inner_headers(skb); + + memset(skb->cb, 0, sizeof(skb->cb)); + + /* cause packet to be "received" by the interface */ + pkt_len = skb->len; + if (likely(gro_cells_receive(&peer->ovpn->gro_cells, + skb) == NET_RX_SUCCESS)) + /* update RX stats with the size of decrypted packet */ + dev_sw_netstats_rx_add(peer->ovpn->dev, pkt_len); +} + +void ovpn_decrypt_post(void *data, int ret) +{ + struct ovpn_crypto_key_slot *ks; + unsigned int payload_offset = 0; + struct sk_buff *skb = data; + struct ovpn_peer *peer; + unsigned int orig_len; + __be16 proto; + __be32 *pid; + + /* crypto is happening asynchronously. this function will be called + * again later by the crypto callback with a proper return code + */ + if (unlikely(ret == -EINPROGRESS)) + return; + + payload_offset = ovpn_skb_cb(skb)->payload_offset; + ks = ovpn_skb_cb(skb)->ks; + peer = ovpn_skb_cb(skb)->peer; + orig_len = ovpn_skb_cb(skb)->orig_len; + + /* crypto is done, cleanup skb CB and its members */ + + if (likely(ovpn_skb_cb(skb)->sg)) + kfree(ovpn_skb_cb(skb)->sg); + + if (likely(ovpn_skb_cb(skb)->req)) + aead_request_free(ovpn_skb_cb(skb)->req); + + if (unlikely(ret < 0)) + goto drop; + + /* PID sits after the op */ + pid = (__force __be32 *)(skb->data + OVPN_OP_SIZE_V2); + ret = ovpn_pktid_recv(&ks->pid_recv, ntohl(*pid), 0); + if (unlikely(ret < 0)) { + net_err_ratelimited("%s: PKT ID RX error: %d\n", + peer->ovpn->dev->name, ret); + goto drop; + } + + /* keep track of last received authenticated packet for keepalive */ + peer->last_recv = ktime_get_real_seconds(); + + if (peer->sock->sock->sk->sk_protocol == IPPROTO_UDP) { + /* check if this peer changed it's IP address and update + * state + */ + ovpn_peer_float(peer, skb); + /* update source endpoint for this peer */ + ovpn_peer_update_local_endpoint(peer, skb); + } + + /* point to encapsulated IP packet */ + __skb_pull(skb, payload_offset); + + /* check if this is a valid datapacket that has to be delivered to the + * ovpn interface + */ + skb_reset_network_header(skb); + proto = ovpn_ip_check_protocol(skb); + if (unlikely(!proto)) { + /* check if null packet */ + if (unlikely(!pskb_may_pull(skb, 1))) { + net_info_ratelimited("%s: NULL packet received from peer %u\n", + peer->ovpn->dev->name, peer->id); + goto drop; + } + + if (ovpn_is_keepalive(skb)) { + net_dbg_ratelimited("%s: ping received from peer %u\n", + peer->ovpn->dev->name, peer->id); + goto drop; + } + + net_info_ratelimited("%s: unsupported protocol received from peer %u\n", + peer->ovpn->dev->name, peer->id); + goto drop; + } + skb->protocol = proto; + + /* perform Reverse Path Filtering (RPF) */ + if (unlikely(!ovpn_peer_check_by_src(peer->ovpn, skb, peer))) { + if (skb_protocol_to_family(skb) == AF_INET6) + net_dbg_ratelimited("%s: RPF dropped packet from peer %u, src: %pI6c\n", + peer->ovpn->dev->name, peer->id, + &ipv6_hdr(skb)->saddr); + else + net_dbg_ratelimited("%s: RPF dropped packet from peer %u, src: %pI4\n", + peer->ovpn->dev->name, peer->id, + &ip_hdr(skb)->saddr); + goto drop; + } + + /* increment RX stats */ + ovpn_peer_stats_increment_rx(&peer->vpn_stats, skb->len); + ovpn_peer_stats_increment_rx(&peer->link_stats, orig_len); + + ovpn_netdev_write(peer, skb); + /* skb is passed to upper layer - don't free it */ + skb = NULL; +drop: + if (unlikely(skb)) + dev_core_stats_rx_dropped_inc(peer->ovpn->dev); + if (likely(peer)) + ovpn_peer_put(peer); + if (likely(ks)) + ovpn_crypto_key_slot_put(ks); + kfree_skb(skb); +} + +/* pick next packet from RX queue, decrypt and forward it to the device */ +void ovpn_recv(struct ovpn_peer *peer, struct sk_buff *skb) +{ + struct ovpn_crypto_key_slot *ks; + u8 key_id; + + /* get the key slot matching the key ID in the received packet */ + key_id = ovpn_key_id_from_skb(skb); + ks = ovpn_crypto_key_id_to_slot(&peer->crypto, key_id); + if (unlikely(!ks)) { + net_info_ratelimited("%s: no available key for peer %u, key-id: %u\n", + peer->ovpn->dev->name, peer->id, key_id); + dev_core_stats_rx_dropped_inc(peer->ovpn->dev); + kfree_skb(skb); + return; + } + + memset(ovpn_skb_cb(skb), 0, sizeof(struct ovpn_cb)); + ovpn_decrypt_post(skb, ovpn_aead_decrypt(peer, ks, skb)); +} + +void ovpn_encrypt_post(void *data, int ret) +{ + struct ovpn_crypto_key_slot *ks; + struct sk_buff *skb = data; + struct ovpn_peer *peer; + unsigned int orig_len; + + /* encryption is happening asynchronously. This function will be + * called later by the crypto callback with a proper return value + */ + if (unlikely(ret == -EINPROGRESS)) + return; + + ks = ovpn_skb_cb(skb)->ks; + peer = ovpn_skb_cb(skb)->peer; + orig_len = ovpn_skb_cb(skb)->orig_len; + + /* crypto is done, cleanup skb CB and its members */ + + if (likely(ovpn_skb_cb(skb)->sg)) + kfree(ovpn_skb_cb(skb)->sg); + + if (likely(ovpn_skb_cb(skb)->req)) + aead_request_free(ovpn_skb_cb(skb)->req); + + if (unlikely(ret == -ERANGE)) { + /* we ran out of IVs and we must kill the key as it can't be + * use anymore + */ + netdev_warn(peer->ovpn->dev, + "killing key %u for peer %u\n", ks->key_id, + peer->id); + ovpn_crypto_kill_key(&peer->crypto, ks->key_id); + /* let userspace know so that a new key must be negotiated */ + ovpn_nl_key_swap_notify(peer, ks->key_id); + goto err; + } + + if (unlikely(ret < 0)) + goto err; + + skb_mark_not_on_list(skb); + ovpn_peer_stats_increment_tx(&peer->link_stats, skb->len); + ovpn_peer_stats_increment_tx(&peer->vpn_stats, orig_len); + + switch (peer->sock->sock->sk->sk_protocol) { + case IPPROTO_UDP: + ovpn_udp_send_skb(peer->ovpn, peer, skb); + break; + case IPPROTO_TCP: + ovpn_tcp_send_skb(peer, skb); + break; + default: + /* no transport configured yet */ + goto err; + } + + /* keep track of last sent packet for keepalive */ + peer->last_sent = ktime_get_real_seconds(); + + /* skb passed down the stack - don't free it */ + skb = NULL; +err: + if (unlikely(skb)) + dev_core_stats_tx_dropped_inc(peer->ovpn->dev); + if (likely(peer)) + ovpn_peer_put(peer); + if (likely(ks)) + ovpn_crypto_key_slot_put(ks); + kfree_skb(skb); +} + +static bool ovpn_encrypt_one(struct ovpn_peer *peer, struct sk_buff *skb) +{ + struct ovpn_crypto_key_slot *ks; + + if (unlikely(skb->ip_summed == CHECKSUM_PARTIAL && + skb_checksum_help(skb))) { + net_warn_ratelimited("%s: cannot compute checksum for outgoing packet\n", + peer->ovpn->dev->name); + return false; + } + + /* get primary key to be used for encrypting data */ + ks = ovpn_crypto_key_slot_primary(&peer->crypto); + if (unlikely(!ks)) { + net_warn_ratelimited("%s: error while retrieving primary key slot for peer %u\n", + peer->ovpn->dev->name, peer->id); + return false; + } + + /* take a reference to the peer because the crypto code may run async. + * ovpn_encrypt_post() will release it upon completion + */ + if (unlikely(!ovpn_peer_hold(peer))) { + DEBUG_NET_WARN_ON_ONCE(1); + return false; + } + + memset(ovpn_skb_cb(skb), 0, sizeof(struct ovpn_cb)); + ovpn_encrypt_post(skb, ovpn_aead_encrypt(peer, ks, skb)); + return true; +} + +/* send skb to connected peer, if any */ +static void ovpn_send(struct ovpn_struct *ovpn, struct sk_buff *skb, + struct ovpn_peer *peer) +{ + struct sk_buff *curr, *next; + + if (likely(!peer)) + /* retrieve peer serving the destination IP of this packet */ + peer = ovpn_peer_get_by_dst(ovpn, skb); + if (unlikely(!peer)) { + net_dbg_ratelimited("%s: no peer to send data to\n", + ovpn->dev->name); + dev_core_stats_tx_dropped_inc(ovpn->dev); + goto drop; + } + + /* this might be a GSO-segmented skb list: process each skb + * independently + */ + skb_list_walk_safe(skb, curr, next) + if (unlikely(!ovpn_encrypt_one(peer, curr))) { + dev_core_stats_tx_dropped_inc(ovpn->dev); + kfree_skb(curr); + } + + /* skb passed over, no need to free */ + skb = NULL; +drop: + if (likely(peer)) + ovpn_peer_put(peer); + kfree_skb_list(skb); +} + +/* Send user data to the network + */ +netdev_tx_t ovpn_net_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct ovpn_struct *ovpn = netdev_priv(dev); + struct sk_buff *segments, *curr, *next; + struct sk_buff_head skb_list; + __be16 proto; + int ret; + + /* reset netfilter state */ + nf_reset_ct(skb); + + /* verify IP header size in network packet */ + proto = ovpn_ip_check_protocol(skb); + if (unlikely(!proto || skb->protocol != proto)) { + net_err_ratelimited("%s: dropping malformed payload packet\n", + dev->name); + dev_core_stats_tx_dropped_inc(ovpn->dev); + goto drop; + } + + if (skb_is_gso(skb)) { + segments = skb_gso_segment(skb, 0); + if (IS_ERR(segments)) { + ret = PTR_ERR(segments); + net_err_ratelimited("%s: cannot segment packet: %d\n", + dev->name, ret); + dev_core_stats_tx_dropped_inc(ovpn->dev); + goto drop; + } + + consume_skb(skb); + skb = segments; + } + + /* from this moment on, "skb" might be a list */ + + __skb_queue_head_init(&skb_list); + skb_list_walk_safe(skb, curr, next) { + skb_mark_not_on_list(curr); + + curr = skb_share_check(curr, GFP_ATOMIC); + if (unlikely(!curr)) { + net_err_ratelimited("%s: skb_share_check failed\n", + dev->name); + dev_core_stats_tx_dropped_inc(ovpn->dev); + continue; + } + + __skb_queue_tail(&skb_list, curr); + } + skb_list.prev->next = NULL; + + ovpn_send(ovpn, skb_list.next, NULL); + + return NETDEV_TX_OK; + +drop: + skb_tx_error(skb); + kfree_skb_list(skb); + return NET_XMIT_DROP; +} + +/** + * ovpn_xmit_special - encrypt and transmit an out-of-band message to peer + * @peer: peer to send the message to + * @data: message content + * @len: message length + * + * Assumes that caller holds a reference to peer + */ +void ovpn_xmit_special(struct ovpn_peer *peer, const void *data, + const unsigned int len) +{ + struct ovpn_struct *ovpn; + struct sk_buff *skb; + + ovpn = peer->ovpn; + if (unlikely(!ovpn)) + return; + + skb = alloc_skb(256 + len, GFP_ATOMIC); + if (unlikely(!skb)) + return; + + skb_reserve(skb, 128); + skb->priority = TC_PRIO_BESTEFFORT; + __skb_put_data(skb, data, len); + + /* increase reference counter when passing peer to sending queue */ + if (!ovpn_peer_hold(peer)) { + netdev_dbg(ovpn->dev, "%s: cannot hold peer reference for sending special packet\n", + __func__); + kfree_skb(skb); + return; + } + + ovpn_send(ovpn, skb, peer); +} diff --git a/drivers/net/ovpn/io.h b/drivers/net/ovpn/io.h new file mode 100644 index 000000000000..eb224114152c --- /dev/null +++ b/drivers/net/ovpn/io.h @@ -0,0 +1,25 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2019-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_OVPN_H_ +#define _NET_OVPN_OVPN_H_ + +#define OVPN_KEEPALIVE_SIZE 16 +extern const unsigned char ovpn_keepalive_message[OVPN_KEEPALIVE_SIZE]; + +netdev_tx_t ovpn_net_xmit(struct sk_buff *skb, struct net_device *dev); + +void ovpn_recv(struct ovpn_peer *peer, struct sk_buff *skb); +void ovpn_xmit_special(struct ovpn_peer *peer, const void *data, + const unsigned int len); + +void ovpn_encrypt_post(void *data, int ret); +void ovpn_decrypt_post(void *data, int ret); + +#endif /* _NET_OVPN_OVPN_H_ */ diff --git a/drivers/net/ovpn/main.c b/drivers/net/ovpn/main.c new file mode 100644 index 000000000000..9dcf51ae1497 --- /dev/null +++ b/drivers/net/ovpn/main.c @@ -0,0 +1,337 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: Antonio Quartulli + * James Yonan + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ovpnstruct.h" +#include "main.h" +#include "netlink.h" +#include "io.h" +#include "packet.h" +#include "peer.h" +#include "tcp.h" + +/* Driver info */ +#define DRV_DESCRIPTION "OpenVPN data channel offload (ovpn)" +#define DRV_COPYRIGHT "(C) 2020-2024 OpenVPN, Inc." + +static void ovpn_struct_free(struct net_device *net) +{ + struct ovpn_struct *ovpn = netdev_priv(net); + + kfree(ovpn->peers); +} + +static int ovpn_net_init(struct net_device *dev) +{ + struct ovpn_struct *ovpn = netdev_priv(dev); + + return gro_cells_init(&ovpn->gro_cells, dev); +} + +static void ovpn_net_uninit(struct net_device *dev) +{ + struct ovpn_struct *ovpn = netdev_priv(dev); + + gro_cells_destroy(&ovpn->gro_cells); +} + +static int ovpn_net_open(struct net_device *dev) +{ + /* ovpn keeps the carrier always on to avoid losing IP or route + * configuration upon disconnection. This way it can prevent leaks + * of traffic outside of the VPN tunnel. + * The user may override this behaviour by tearing down the interface + * manually. + */ + netif_carrier_on(dev); + netif_tx_start_all_queues(dev); + return 0; +} + +static int ovpn_net_stop(struct net_device *dev) +{ + netif_tx_stop_all_queues(dev); + return 0; +} + +static const struct net_device_ops ovpn_netdev_ops = { + .ndo_init = ovpn_net_init, + .ndo_uninit = ovpn_net_uninit, + .ndo_open = ovpn_net_open, + .ndo_stop = ovpn_net_stop, + .ndo_start_xmit = ovpn_net_xmit, +}; + +static const struct device_type ovpn_type = { + .name = OVPN_FAMILY_NAME, +}; + +static const struct nla_policy ovpn_policy[IFLA_OVPN_MAX + 1] = { + [IFLA_OVPN_MODE] = NLA_POLICY_RANGE(NLA_U8, OVPN_MODE_P2P, + OVPN_MODE_MP), +}; + +/** + * ovpn_dev_is_valid - check if the netdevice is of type 'ovpn' + * @dev: the interface to check + * + * Return: whether the netdevice is of type 'ovpn' + */ +bool ovpn_dev_is_valid(const struct net_device *dev) +{ + return dev->netdev_ops->ndo_start_xmit == ovpn_net_xmit; +} + +static void ovpn_get_drvinfo(struct net_device *dev, + struct ethtool_drvinfo *info) +{ + strscpy(info->driver, OVPN_FAMILY_NAME, sizeof(info->driver)); + strscpy(info->bus_info, "ovpn", sizeof(info->bus_info)); +} + +static const struct ethtool_ops ovpn_ethtool_ops = { + .get_drvinfo = ovpn_get_drvinfo, + .get_link = ethtool_op_get_link, + .get_ts_info = ethtool_op_get_ts_info, +}; + +static void ovpn_setup(struct net_device *dev) +{ + /* compute the overhead considering AEAD encryption */ + const int overhead = sizeof(u32) + NONCE_WIRE_SIZE + 16 + + sizeof(struct udphdr) + + max(sizeof(struct ipv6hdr), sizeof(struct iphdr)); + + netdev_features_t feat = NETIF_F_SG | NETIF_F_HW_CSUM | NETIF_F_RXCSUM | + NETIF_F_GSO | NETIF_F_GSO_SOFTWARE | + NETIF_F_HIGHDMA; + + dev->needs_free_netdev = true; + + dev->pcpu_stat_type = NETDEV_PCPU_STAT_TSTATS; + + dev->ethtool_ops = &ovpn_ethtool_ops; + dev->netdev_ops = &ovpn_netdev_ops; + + dev->priv_destructor = ovpn_struct_free; + + dev->hard_header_len = 0; + dev->addr_len = 0; + dev->mtu = ETH_DATA_LEN - overhead; + dev->min_mtu = IPV4_MIN_MTU; + dev->max_mtu = IP_MAX_MTU - overhead; + + dev->type = ARPHRD_NONE; + dev->flags = IFF_POINTOPOINT | IFF_NOARP; + dev->priv_flags |= IFF_NO_QUEUE; + + dev->lltx = true; + dev->features |= feat; + dev->hw_features |= feat; + dev->hw_enc_features |= feat; + + dev->needed_headroom = OVPN_HEAD_ROOM; + dev->needed_tailroom = OVPN_MAX_PADDING; + + SET_NETDEV_DEVTYPE(dev, &ovpn_type); +} + +static int ovpn_mp_alloc(struct ovpn_struct *ovpn) +{ + struct in_device *dev_v4; + int i; + + if (ovpn->mode != OVPN_MODE_MP) + return 0; + + dev_v4 = __in_dev_get_rtnl(ovpn->dev); + if (dev_v4) { + /* disable redirects as Linux gets confused by ovpn + * handling same-LAN routing. + * This happens because a multipeer interface is used as + * relay point between hosts in the same subnet, while + * in a classic LAN this would not be needed because the + * two hosts would be able to talk directly. + */ + IN_DEV_CONF_SET(dev_v4, SEND_REDIRECTS, false); + IPV4_DEVCONF_ALL(dev_net(ovpn->dev), SEND_REDIRECTS) = false; + } + + /* the peer container is fairly large, therefore we allocate it only in + * MP mode + */ + ovpn->peers = kzalloc(sizeof(*ovpn->peers), GFP_KERNEL); + if (!ovpn->peers) + return -ENOMEM; + + spin_lock_init(&ovpn->peers->lock); + + for (i = 0; i < ARRAY_SIZE(ovpn->peers->by_id); i++) { + INIT_HLIST_HEAD(&ovpn->peers->by_id[i]); + INIT_HLIST_NULLS_HEAD(&ovpn->peers->by_vpn_addr[i], i); + INIT_HLIST_NULLS_HEAD(&ovpn->peers->by_transp_addr[i], i); + } + + return 0; +} + +static int ovpn_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ovpn_struct *ovpn = netdev_priv(dev); + enum ovpn_mode mode = OVPN_MODE_P2P; + int err; + + if (data && data[IFLA_OVPN_MODE]) { + mode = nla_get_u8(data[IFLA_OVPN_MODE]); + netdev_dbg(dev, "setting device mode: %u\n", mode); + } + + ovpn->dev = dev; + ovpn->mode = mode; + spin_lock_init(&ovpn->lock); + INIT_DELAYED_WORK(&ovpn->keepalive_work, ovpn_peer_keepalive_work); + + err = ovpn_mp_alloc(ovpn); + if (err < 0) + return err; + + /* turn carrier explicitly off after registration, this way state is + * clearly defined + */ + netif_carrier_off(dev); + + return register_netdevice(dev); +} + +static struct rtnl_link_ops ovpn_link_ops = { + .kind = OVPN_FAMILY_NAME, + .netns_refund = false, + .priv_size = sizeof(struct ovpn_struct), + .setup = ovpn_setup, + .policy = ovpn_policy, + .maxtype = IFLA_OVPN_MAX, + .newlink = ovpn_newlink, + .dellink = unregister_netdevice_queue, +}; + +static int ovpn_netdev_notifier_call(struct notifier_block *nb, + unsigned long state, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct ovpn_struct *ovpn; + + if (!ovpn_dev_is_valid(dev)) + return NOTIFY_DONE; + + ovpn = netdev_priv(dev); + + switch (state) { + case NETDEV_REGISTER: + ovpn->registered = true; + break; + case NETDEV_UNREGISTER: + /* twiddle thumbs on netns device moves */ + if (dev->reg_state != NETREG_UNREGISTERING) + break; + + /* can be delivered multiple times, so check registered flag, + * then destroy the interface + */ + if (!ovpn->registered) + return NOTIFY_DONE; + + netif_carrier_off(dev); + ovpn->registered = false; + + cancel_delayed_work_sync(&ovpn->keepalive_work); + + switch (ovpn->mode) { + case OVPN_MODE_P2P: + ovpn_peer_release_p2p(ovpn); + break; + case OVPN_MODE_MP: + ovpn_peers_free(ovpn); + break; + } + break; + case NETDEV_POST_INIT: + case NETDEV_GOING_DOWN: + case NETDEV_DOWN: + case NETDEV_UP: + case NETDEV_PRE_UP: + break; + default: + return NOTIFY_DONE; + } + + return NOTIFY_OK; +} + +static struct notifier_block ovpn_netdev_notifier = { + .notifier_call = ovpn_netdev_notifier_call, +}; + +static int __init ovpn_init(void) +{ + int err = register_netdevice_notifier(&ovpn_netdev_notifier); + + if (err) { + pr_err("ovpn: can't register netdevice notifier: %d\n", err); + return err; + } + + err = rtnl_link_register(&ovpn_link_ops); + if (err) { + pr_err("ovpn: can't register rtnl link ops: %d\n", err); + goto unreg_netdev; + } + + err = ovpn_nl_register(); + if (err) { + pr_err("ovpn: can't register netlink family: %d\n", err); + goto unreg_rtnl; + } + + ovpn_tcp_init(); + + return 0; + +unreg_rtnl: + rtnl_link_unregister(&ovpn_link_ops); +unreg_netdev: + unregister_netdevice_notifier(&ovpn_netdev_notifier); + return err; +} + +static __exit void ovpn_cleanup(void) +{ + ovpn_nl_unregister(); + rtnl_link_unregister(&ovpn_link_ops); + unregister_netdevice_notifier(&ovpn_netdev_notifier); + + rcu_barrier(); +} + +module_init(ovpn_init); +module_exit(ovpn_cleanup); + +MODULE_DESCRIPTION(DRV_DESCRIPTION); +MODULE_AUTHOR(DRV_COPYRIGHT); +MODULE_LICENSE("GPL"); diff --git a/drivers/net/ovpn/main.h b/drivers/net/ovpn/main.h new file mode 100644 index 000000000000..28e5c44816e1 --- /dev/null +++ b/drivers/net/ovpn/main.h @@ -0,0 +1,24 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2019-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_MAIN_H_ +#define _NET_OVPN_MAIN_H_ + +bool ovpn_dev_is_valid(const struct net_device *dev); + +#define SKB_HEADER_LEN \ + (max(sizeof(struct iphdr), sizeof(struct ipv6hdr)) + \ + sizeof(struct udphdr) + NET_SKB_PAD) + +#define OVPN_HEAD_ROOM ALIGN(16 + SKB_HEADER_LEN, 4) +#define OVPN_MAX_PADDING 16 + +#define OVPN_QUEUE_LEN 1024 + +#endif /* _NET_OVPN_MAIN_H_ */ diff --git a/drivers/net/ovpn/netlink-gen.c b/drivers/net/ovpn/netlink-gen.c new file mode 100644 index 000000000000..6a43eab9a136 --- /dev/null +++ b/drivers/net/ovpn/netlink-gen.c @@ -0,0 +1,212 @@ +// SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) +/* Do not edit directly, auto-generated from: */ +/* Documentation/netlink/specs/ovpn.yaml */ +/* YNL-GEN kernel source */ + +#include +#include + +#include "netlink-gen.h" + +#include + +/* Integer value ranges */ +static const struct netlink_range_validation ovpn_a_peer_id_range = { + .max = 16777215ULL, +}; + +static const struct netlink_range_validation ovpn_a_keyconf_peer_id_range = { + .max = 16777215ULL, +}; + +/* Common nested types */ +const struct nla_policy ovpn_keyconf_nl_policy[OVPN_A_KEYCONF_DECRYPT_DIR + 1] = { + [OVPN_A_KEYCONF_PEER_ID] = NLA_POLICY_FULL_RANGE(NLA_U32, &ovpn_a_keyconf_peer_id_range), + [OVPN_A_KEYCONF_SLOT] = NLA_POLICY_MAX(NLA_U32, 1), + [OVPN_A_KEYCONF_KEY_ID] = NLA_POLICY_MAX(NLA_U32, 7), + [OVPN_A_KEYCONF_CIPHER_ALG] = NLA_POLICY_MAX(NLA_U32, 2), + [OVPN_A_KEYCONF_ENCRYPT_DIR] = NLA_POLICY_NESTED(ovpn_keydir_nl_policy), + [OVPN_A_KEYCONF_DECRYPT_DIR] = NLA_POLICY_NESTED(ovpn_keydir_nl_policy), +}; + +const struct nla_policy ovpn_keydir_nl_policy[OVPN_A_KEYDIR_NONCE_TAIL + 1] = { + [OVPN_A_KEYDIR_CIPHER_KEY] = NLA_POLICY_MAX_LEN(256), + [OVPN_A_KEYDIR_NONCE_TAIL] = NLA_POLICY_EXACT_LEN(OVPN_NONCE_TAIL_SIZE), +}; + +const struct nla_policy ovpn_peer_nl_policy[OVPN_A_PEER_LINK_TX_PACKETS + 1] = { + [OVPN_A_PEER_ID] = NLA_POLICY_FULL_RANGE(NLA_U32, &ovpn_a_peer_id_range), + [OVPN_A_PEER_REMOTE_IPV4] = { .type = NLA_U32, }, + [OVPN_A_PEER_REMOTE_IPV6] = NLA_POLICY_EXACT_LEN(16), + [OVPN_A_PEER_REMOTE_IPV6_SCOPE_ID] = { .type = NLA_U32, }, + [OVPN_A_PEER_REMOTE_PORT] = NLA_POLICY_MIN(NLA_U16, 1), + [OVPN_A_PEER_SOCKET] = { .type = NLA_U32, }, + [OVPN_A_PEER_VPN_IPV4] = { .type = NLA_U32, }, + [OVPN_A_PEER_VPN_IPV6] = NLA_POLICY_EXACT_LEN(16), + [OVPN_A_PEER_LOCAL_IPV4] = { .type = NLA_U32, }, + [OVPN_A_PEER_LOCAL_IPV6] = NLA_POLICY_EXACT_LEN(16), + [OVPN_A_PEER_LOCAL_PORT] = NLA_POLICY_MIN(NLA_U16, 1), + [OVPN_A_PEER_KEEPALIVE_INTERVAL] = { .type = NLA_U32, }, + [OVPN_A_PEER_KEEPALIVE_TIMEOUT] = { .type = NLA_U32, }, + [OVPN_A_PEER_DEL_REASON] = NLA_POLICY_MAX(NLA_U32, 4), + [OVPN_A_PEER_VPN_RX_BYTES] = { .type = NLA_UINT, }, + [OVPN_A_PEER_VPN_TX_BYTES] = { .type = NLA_UINT, }, + [OVPN_A_PEER_VPN_RX_PACKETS] = { .type = NLA_UINT, }, + [OVPN_A_PEER_VPN_TX_PACKETS] = { .type = NLA_UINT, }, + [OVPN_A_PEER_LINK_RX_BYTES] = { .type = NLA_UINT, }, + [OVPN_A_PEER_LINK_TX_BYTES] = { .type = NLA_UINT, }, + [OVPN_A_PEER_LINK_RX_PACKETS] = { .type = NLA_U32, }, + [OVPN_A_PEER_LINK_TX_PACKETS] = { .type = NLA_U32, }, +}; + +/* OVPN_CMD_PEER_NEW - do */ +static const struct nla_policy ovpn_peer_new_nl_policy[OVPN_A_PEER + 1] = { + [OVPN_A_IFINDEX] = { .type = NLA_U32, }, + [OVPN_A_PEER] = NLA_POLICY_NESTED(ovpn_peer_nl_policy), +}; + +/* OVPN_CMD_PEER_SET - do */ +static const struct nla_policy ovpn_peer_set_nl_policy[OVPN_A_PEER + 1] = { + [OVPN_A_IFINDEX] = { .type = NLA_U32, }, + [OVPN_A_PEER] = NLA_POLICY_NESTED(ovpn_peer_nl_policy), +}; + +/* OVPN_CMD_PEER_GET - do */ +static const struct nla_policy ovpn_peer_get_do_nl_policy[OVPN_A_PEER + 1] = { + [OVPN_A_IFINDEX] = { .type = NLA_U32, }, + [OVPN_A_PEER] = NLA_POLICY_NESTED(ovpn_peer_nl_policy), +}; + +/* OVPN_CMD_PEER_GET - dump */ +static const struct nla_policy ovpn_peer_get_dump_nl_policy[OVPN_A_IFINDEX + 1] = { + [OVPN_A_IFINDEX] = { .type = NLA_U32, }, +}; + +/* OVPN_CMD_PEER_DEL - do */ +static const struct nla_policy ovpn_peer_del_nl_policy[OVPN_A_PEER + 1] = { + [OVPN_A_IFINDEX] = { .type = NLA_U32, }, + [OVPN_A_PEER] = NLA_POLICY_NESTED(ovpn_peer_nl_policy), +}; + +/* OVPN_CMD_KEY_NEW - do */ +static const struct nla_policy ovpn_key_new_nl_policy[OVPN_A_KEYCONF + 1] = { + [OVPN_A_IFINDEX] = { .type = NLA_U32, }, + [OVPN_A_KEYCONF] = NLA_POLICY_NESTED(ovpn_keyconf_nl_policy), +}; + +/* OVPN_CMD_KEY_GET - do */ +static const struct nla_policy ovpn_key_get_nl_policy[OVPN_A_KEYCONF + 1] = { + [OVPN_A_IFINDEX] = { .type = NLA_U32, }, + [OVPN_A_KEYCONF] = NLA_POLICY_NESTED(ovpn_keyconf_nl_policy), +}; + +/* OVPN_CMD_KEY_SWAP - do */ +static const struct nla_policy ovpn_key_swap_nl_policy[OVPN_A_KEYCONF + 1] = { + [OVPN_A_IFINDEX] = { .type = NLA_U32, }, + [OVPN_A_KEYCONF] = NLA_POLICY_NESTED(ovpn_keyconf_nl_policy), +}; + +/* OVPN_CMD_KEY_DEL - do */ +static const struct nla_policy ovpn_key_del_nl_policy[OVPN_A_KEYCONF + 1] = { + [OVPN_A_IFINDEX] = { .type = NLA_U32, }, + [OVPN_A_KEYCONF] = NLA_POLICY_NESTED(ovpn_keyconf_nl_policy), +}; + +/* Ops table for ovpn */ +static const struct genl_split_ops ovpn_nl_ops[] = { + { + .cmd = OVPN_CMD_PEER_NEW, + .pre_doit = ovpn_nl_pre_doit, + .doit = ovpn_nl_peer_new_doit, + .post_doit = ovpn_nl_post_doit, + .policy = ovpn_peer_new_nl_policy, + .maxattr = OVPN_A_PEER, + .flags = GENL_ADMIN_PERM | GENL_CMD_CAP_DO, + }, + { + .cmd = OVPN_CMD_PEER_SET, + .pre_doit = ovpn_nl_pre_doit, + .doit = ovpn_nl_peer_set_doit, + .post_doit = ovpn_nl_post_doit, + .policy = ovpn_peer_set_nl_policy, + .maxattr = OVPN_A_PEER, + .flags = GENL_ADMIN_PERM | GENL_CMD_CAP_DO, + }, + { + .cmd = OVPN_CMD_PEER_GET, + .pre_doit = ovpn_nl_pre_doit, + .doit = ovpn_nl_peer_get_doit, + .post_doit = ovpn_nl_post_doit, + .policy = ovpn_peer_get_do_nl_policy, + .maxattr = OVPN_A_PEER, + .flags = GENL_ADMIN_PERM | GENL_CMD_CAP_DO, + }, + { + .cmd = OVPN_CMD_PEER_GET, + .dumpit = ovpn_nl_peer_get_dumpit, + .policy = ovpn_peer_get_dump_nl_policy, + .maxattr = OVPN_A_IFINDEX, + .flags = GENL_ADMIN_PERM | GENL_CMD_CAP_DUMP, + }, + { + .cmd = OVPN_CMD_PEER_DEL, + .pre_doit = ovpn_nl_pre_doit, + .doit = ovpn_nl_peer_del_doit, + .post_doit = ovpn_nl_post_doit, + .policy = ovpn_peer_del_nl_policy, + .maxattr = OVPN_A_PEER, + .flags = GENL_ADMIN_PERM | GENL_CMD_CAP_DO, + }, + { + .cmd = OVPN_CMD_KEY_NEW, + .pre_doit = ovpn_nl_pre_doit, + .doit = ovpn_nl_key_new_doit, + .post_doit = ovpn_nl_post_doit, + .policy = ovpn_key_new_nl_policy, + .maxattr = OVPN_A_KEYCONF, + .flags = GENL_ADMIN_PERM | GENL_CMD_CAP_DO, + }, + { + .cmd = OVPN_CMD_KEY_GET, + .pre_doit = ovpn_nl_pre_doit, + .doit = ovpn_nl_key_get_doit, + .post_doit = ovpn_nl_post_doit, + .policy = ovpn_key_get_nl_policy, + .maxattr = OVPN_A_KEYCONF, + .flags = GENL_ADMIN_PERM | GENL_CMD_CAP_DO, + }, + { + .cmd = OVPN_CMD_KEY_SWAP, + .pre_doit = ovpn_nl_pre_doit, + .doit = ovpn_nl_key_swap_doit, + .post_doit = ovpn_nl_post_doit, + .policy = ovpn_key_swap_nl_policy, + .maxattr = OVPN_A_KEYCONF, + .flags = GENL_ADMIN_PERM | GENL_CMD_CAP_DO, + }, + { + .cmd = OVPN_CMD_KEY_DEL, + .pre_doit = ovpn_nl_pre_doit, + .doit = ovpn_nl_key_del_doit, + .post_doit = ovpn_nl_post_doit, + .policy = ovpn_key_del_nl_policy, + .maxattr = OVPN_A_KEYCONF, + .flags = GENL_ADMIN_PERM | GENL_CMD_CAP_DO, + }, +}; + +static const struct genl_multicast_group ovpn_nl_mcgrps[] = { + [OVPN_NLGRP_PEERS] = { "peers", }, +}; + +struct genl_family ovpn_nl_family __ro_after_init = { + .name = OVPN_FAMILY_NAME, + .version = OVPN_FAMILY_VERSION, + .netnsok = true, + .parallel_ops = true, + .module = THIS_MODULE, + .split_ops = ovpn_nl_ops, + .n_split_ops = ARRAY_SIZE(ovpn_nl_ops), + .mcgrps = ovpn_nl_mcgrps, + .n_mcgrps = ARRAY_SIZE(ovpn_nl_mcgrps), +}; diff --git a/drivers/net/ovpn/netlink-gen.h b/drivers/net/ovpn/netlink-gen.h new file mode 100644 index 000000000000..66a4e4a0a055 --- /dev/null +++ b/drivers/net/ovpn/netlink-gen.h @@ -0,0 +1,41 @@ +/* SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) */ +/* Do not edit directly, auto-generated from: */ +/* Documentation/netlink/specs/ovpn.yaml */ +/* YNL-GEN kernel header */ + +#ifndef _LINUX_OVPN_GEN_H +#define _LINUX_OVPN_GEN_H + +#include +#include + +#include + +/* Common nested types */ +extern const struct nla_policy ovpn_keyconf_nl_policy[OVPN_A_KEYCONF_DECRYPT_DIR + 1]; +extern const struct nla_policy ovpn_keydir_nl_policy[OVPN_A_KEYDIR_NONCE_TAIL + 1]; +extern const struct nla_policy ovpn_peer_nl_policy[OVPN_A_PEER_LINK_TX_PACKETS + 1]; + +int ovpn_nl_pre_doit(const struct genl_split_ops *ops, struct sk_buff *skb, + struct genl_info *info); +void +ovpn_nl_post_doit(const struct genl_split_ops *ops, struct sk_buff *skb, + struct genl_info *info); + +int ovpn_nl_peer_new_doit(struct sk_buff *skb, struct genl_info *info); +int ovpn_nl_peer_set_doit(struct sk_buff *skb, struct genl_info *info); +int ovpn_nl_peer_get_doit(struct sk_buff *skb, struct genl_info *info); +int ovpn_nl_peer_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb); +int ovpn_nl_peer_del_doit(struct sk_buff *skb, struct genl_info *info); +int ovpn_nl_key_new_doit(struct sk_buff *skb, struct genl_info *info); +int ovpn_nl_key_get_doit(struct sk_buff *skb, struct genl_info *info); +int ovpn_nl_key_swap_doit(struct sk_buff *skb, struct genl_info *info); +int ovpn_nl_key_del_doit(struct sk_buff *skb, struct genl_info *info); + +enum { + OVPN_NLGRP_PEERS, +}; + +extern struct genl_family ovpn_nl_family; + +#endif /* _LINUX_OVPN_GEN_H */ diff --git a/drivers/net/ovpn/netlink.c b/drivers/net/ovpn/netlink.c new file mode 100644 index 000000000000..4d7d835cb47f --- /dev/null +++ b/drivers/net/ovpn/netlink.c @@ -0,0 +1,1135 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: Antonio Quartulli + */ + +#include +#include +#include + +#include + +#include "ovpnstruct.h" +#include "main.h" +#include "io.h" +#include "netlink.h" +#include "netlink-gen.h" +#include "bind.h" +#include "crypto.h" +#include "packet.h" +#include "peer.h" +#include "socket.h" + +MODULE_ALIAS_GENL_FAMILY(OVPN_FAMILY_NAME); + +/** + * ovpn_get_dev_from_attrs - retrieve the ovpn private data from the netdevice + * a netlink message is targeting + * @net: network namespace where to look for the interface + * @info: generic netlink info from the user request + * + * Return: the ovpn private data, if found, or an error otherwise + */ +static struct ovpn_struct * +ovpn_get_dev_from_attrs(struct net *net, const struct genl_info *info) +{ + struct ovpn_struct *ovpn; + struct net_device *dev; + int ifindex; + + if (GENL_REQ_ATTR_CHECK(info, OVPN_A_IFINDEX)) + return ERR_PTR(-EINVAL); + + ifindex = nla_get_u32(info->attrs[OVPN_A_IFINDEX]); + + rcu_read_lock(); + dev = dev_get_by_index_rcu(net, ifindex); + if (!dev) { + rcu_read_unlock(); + NL_SET_ERR_MSG_MOD(info->extack, + "ifindex does not match any interface"); + return ERR_PTR(-ENODEV); + } + + if (!ovpn_dev_is_valid(dev)) { + rcu_read_unlock(); + NL_SET_ERR_MSG_MOD(info->extack, + "specified interface is not ovpn"); + NL_SET_BAD_ATTR(info->extack, info->attrs[OVPN_A_IFINDEX]); + return ERR_PTR(-EINVAL); + } + + ovpn = netdev_priv(dev); + netdev_hold(dev, &ovpn->dev_tracker, GFP_KERNEL); + rcu_read_unlock(); + + return ovpn; +} + +int ovpn_nl_pre_doit(const struct genl_split_ops *ops, struct sk_buff *skb, + struct genl_info *info) +{ + struct ovpn_struct *ovpn = ovpn_get_dev_from_attrs(genl_info_net(info), + info); + + if (IS_ERR(ovpn)) + return PTR_ERR(ovpn); + + info->user_ptr[0] = ovpn; + + return 0; +} + +void ovpn_nl_post_doit(const struct genl_split_ops *ops, struct sk_buff *skb, + struct genl_info *info) +{ + struct ovpn_struct *ovpn = info->user_ptr[0]; + + if (ovpn) + netdev_put(ovpn->dev, &ovpn->dev_tracker); +} + +static int ovpn_nl_attr_sockaddr_remote(struct nlattr **attrs, + struct sockaddr_storage *ss) +{ + struct sockaddr_in6 *sin6; + struct sockaddr_in *sin; + struct in6_addr *in6; + __be16 port = 0; + __be32 *in; + int af; + + ss->ss_family = AF_UNSPEC; + + if (attrs[OVPN_A_PEER_REMOTE_PORT]) + port = nla_get_be16(attrs[OVPN_A_PEER_REMOTE_PORT]); + + if (attrs[OVPN_A_PEER_REMOTE_IPV4]) { + af = AF_INET; + ss->ss_family = AF_INET; + in = nla_data(attrs[OVPN_A_PEER_REMOTE_IPV4]); + } else if (attrs[OVPN_A_PEER_REMOTE_IPV6]) { + af = AF_INET6; + ss->ss_family = AF_INET6; + in6 = nla_data(attrs[OVPN_A_PEER_REMOTE_IPV6]); + } else { + return AF_UNSPEC; + } + + switch (ss->ss_family) { + case AF_INET6: + /* If this is a regular IPv6 just break and move on, + * otherwise switch to AF_INET and extract the IPv4 accordingly + */ + if (!ipv6_addr_v4mapped(in6)) { + sin6 = (struct sockaddr_in6 *)ss; + sin6->sin6_port = port; + memcpy(&sin6->sin6_addr, in6, sizeof(*in6)); + break; + } + + /* v4-mapped-v6 address */ + ss->ss_family = AF_INET; + in = &in6->s6_addr32[3]; + fallthrough; + case AF_INET: + sin = (struct sockaddr_in *)ss; + sin->sin_port = port; + sin->sin_addr.s_addr = *in; + break; + } + + /* don't return ss->ss_family as it may have changed in case of + * v4-mapped-v6 address + */ + return af; +} + +static u8 *ovpn_nl_attr_local_ip(struct nlattr **attrs) +{ + u8 *addr6; + + if (!attrs[OVPN_A_PEER_LOCAL_IPV4] && !attrs[OVPN_A_PEER_LOCAL_IPV6]) + return NULL; + + if (attrs[OVPN_A_PEER_LOCAL_IPV4]) + return nla_data(attrs[OVPN_A_PEER_LOCAL_IPV4]); + + addr6 = nla_data(attrs[OVPN_A_PEER_LOCAL_IPV6]); + /* this is an IPv4-mapped IPv6 address, therefore extract the actual + * v4 address from the last 4 bytes + */ + if (ipv6_addr_v4mapped((struct in6_addr *)addr6)) + return addr6 + 12; + + return addr6; +} + +static int ovpn_nl_peer_precheck(struct ovpn_struct *ovpn, + struct genl_info *info, + struct nlattr **attrs) +{ + if (NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_PEER], attrs, + OVPN_A_PEER_ID)) + return -EINVAL; + + if (attrs[OVPN_A_PEER_REMOTE_IPV4] && attrs[OVPN_A_PEER_REMOTE_IPV6]) { + NL_SET_ERR_MSG_MOD(info->extack, + "cannot specify both remote IPv4 or IPv6 address"); + return -EINVAL; + } + + if (!attrs[OVPN_A_PEER_REMOTE_IPV4] && + !attrs[OVPN_A_PEER_REMOTE_IPV6] && attrs[OVPN_A_PEER_REMOTE_PORT]) { + NL_SET_ERR_MSG_MOD(info->extack, + "cannot specify remote port without IP address"); + return -EINVAL; + } + + if (!attrs[OVPN_A_PEER_REMOTE_IPV4] && + attrs[OVPN_A_PEER_LOCAL_IPV4]) { + NL_SET_ERR_MSG_MOD(info->extack, + "cannot specify local IPv4 address without remote"); + return -EINVAL; + } + + if (!attrs[OVPN_A_PEER_REMOTE_IPV6] && + attrs[OVPN_A_PEER_LOCAL_IPV6]) { + NL_SET_ERR_MSG_MOD(info->extack, + "cannot specify local IPV6 address without remote"); + return -EINVAL; + } + + if (!attrs[OVPN_A_PEER_REMOTE_IPV6] && + attrs[OVPN_A_PEER_REMOTE_IPV6_SCOPE_ID]) { + NL_SET_ERR_MSG_MOD(info->extack, + "cannot specify scope id without remote IPv6 address"); + return -EINVAL; + } + + /* VPN IPs are needed only in MP mode for selecting the right peer */ + if (ovpn->mode == OVPN_MODE_P2P && (attrs[OVPN_A_PEER_VPN_IPV4] || + attrs[OVPN_A_PEER_VPN_IPV6])) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "VPN IP unexpected in P2P mode"); + return -EINVAL; + } + + if ((attrs[OVPN_A_PEER_KEEPALIVE_INTERVAL] && + !attrs[OVPN_A_PEER_KEEPALIVE_TIMEOUT]) || + (!attrs[OVPN_A_PEER_KEEPALIVE_INTERVAL] && + attrs[OVPN_A_PEER_KEEPALIVE_TIMEOUT])) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "keepalive interval and timeout are required together"); + return -EINVAL; + } + + return 0; +} + +/** + * ovpn_nl_peer_modify - modify the peer attributes according to the incoming msg + * @peer: the peer to modify + * @info: generic netlink info from the user request + * @attrs: the attributes from the user request + * + * Return: a negative error code in case of failure, 0 on success or 1 on + * success and the VPN IPs have been modified (requires rehashing in MP + * mode) + */ +static int ovpn_nl_peer_modify(struct ovpn_peer *peer, struct genl_info *info, + struct nlattr **attrs) +{ + struct sockaddr_storage ss = {}; + u32 sockfd, interv, timeout; + struct socket *sock = NULL; + u8 *local_ip = NULL; + bool rehash = false; + int ret; + + if (attrs[OVPN_A_PEER_SOCKET]) { + /* lookup the fd in the kernel table and extract the socket + * object + */ + sockfd = nla_get_u32(attrs[OVPN_A_PEER_SOCKET]); + /* sockfd_lookup() increases sock's refcounter */ + sock = sockfd_lookup(sockfd, &ret); + if (!sock) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "cannot lookup peer socket (fd=%u): %d", + sockfd, ret); + return -ENOTSOCK; + } + + /* Only when using UDP as transport protocol the remote endpoint + * can be configured so that ovpn knows where to send packets + * to. + * + * In case of TCP, the socket is connected to the peer and ovpn + * will just send bytes over it, without the need to specify a + * destination. + */ + if (sock->sk->sk_protocol != IPPROTO_UDP && + (attrs[OVPN_A_PEER_REMOTE_IPV4] || + attrs[OVPN_A_PEER_REMOTE_IPV6])) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "unexpected remote IP address for non UDP socket"); + sockfd_put(sock); + return -EINVAL; + } + + if (peer->sock) + ovpn_socket_put(peer->sock); + + peer->sock = ovpn_socket_new(sock, peer); + if (IS_ERR(peer->sock)) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "cannot encapsulate socket: %ld", + PTR_ERR(peer->sock)); + sockfd_put(sock); + peer->sock = NULL; + return -ENOTSOCK; + } + } + + if (ovpn_nl_attr_sockaddr_remote(attrs, &ss) != AF_UNSPEC) { + /* we carry the local IP in a generic container. + * ovpn_peer_reset_sockaddr() will properly interpret it + * based on ss.ss_family + */ + local_ip = ovpn_nl_attr_local_ip(attrs); + + spin_lock_bh(&peer->lock); + /* set peer sockaddr */ + ret = ovpn_peer_reset_sockaddr(peer, &ss, local_ip); + if (ret < 0) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "cannot set peer sockaddr: %d", + ret); + spin_unlock_bh(&peer->lock); + return ret; + } + spin_unlock_bh(&peer->lock); + } + + if (attrs[OVPN_A_PEER_VPN_IPV4]) { + rehash = true; + peer->vpn_addrs.ipv4.s_addr = + nla_get_in_addr(attrs[OVPN_A_PEER_VPN_IPV4]); + } + + if (attrs[OVPN_A_PEER_VPN_IPV6]) { + rehash = true; + peer->vpn_addrs.ipv6 = + nla_get_in6_addr(attrs[OVPN_A_PEER_VPN_IPV6]); + } + + /* when setting the keepalive, both parameters have to be configured */ + if (attrs[OVPN_A_PEER_KEEPALIVE_INTERVAL] && + attrs[OVPN_A_PEER_KEEPALIVE_TIMEOUT]) { + interv = nla_get_u32(attrs[OVPN_A_PEER_KEEPALIVE_INTERVAL]); + timeout = nla_get_u32(attrs[OVPN_A_PEER_KEEPALIVE_TIMEOUT]); + ovpn_peer_keepalive_set(peer, interv, timeout); + } + + netdev_dbg(peer->ovpn->dev, + "%s: peer id=%u endpoint=%pIScp/%s VPN-IPv4=%pI4 VPN-IPv6=%pI6c\n", + __func__, peer->id, &ss, + peer->sock->sock->sk->sk_prot_creator->name, + &peer->vpn_addrs.ipv4.s_addr, &peer->vpn_addrs.ipv6); + + return rehash ? 1 : 0; +} + +int ovpn_nl_peer_new_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_A_PEER_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + struct ovpn_peer *peer; + u32 peer_id; + int ret; + + if (GENL_REQ_ATTR_CHECK(info, OVPN_A_PEER)) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_A_PEER_MAX, info->attrs[OVPN_A_PEER], + ovpn_peer_nl_policy, info->extack); + if (ret) + return ret; + + ret = ovpn_nl_peer_precheck(ovpn, info, attrs); + if (ret < 0) + return ret; + + if (NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_PEER], attrs, + OVPN_A_PEER_SOCKET)) + return -EINVAL; + + peer_id = nla_get_u32(attrs[OVPN_A_PEER_ID]); + peer = ovpn_peer_new(ovpn, peer_id); + if (IS_ERR(peer)) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "cannot create new peer object for peer %u: %ld", + peer_id, PTR_ERR(peer)); + return PTR_ERR(peer); + } + + ret = ovpn_nl_peer_modify(peer, info, attrs); + if (ret < 0) + goto peer_release; + + ret = ovpn_peer_add(ovpn, peer); + if (ret < 0) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "cannot add new peer (id=%u) to hashtable: %d\n", + peer->id, ret); + goto peer_release; + } + + return 0; + +peer_release: + /* release right away because peer is not used in any context */ + ovpn_peer_release(peer); + + return ret; +} + +int ovpn_nl_peer_set_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_A_PEER_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + struct ovpn_peer *peer; + u32 peer_id; + int ret; + + if (GENL_REQ_ATTR_CHECK(info, OVPN_A_PEER)) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_A_PEER_MAX, info->attrs[OVPN_A_PEER], + ovpn_peer_nl_policy, info->extack); + if (ret) + return ret; + + ret = ovpn_nl_peer_precheck(ovpn, info, attrs); + if (ret < 0) + return ret; + + peer_id = nla_get_u32(attrs[OVPN_A_PEER_ID]); + peer = ovpn_peer_get_by_id(ovpn, peer_id); + if (!peer) + return -ENOENT; + + ret = ovpn_nl_peer_modify(peer, info, attrs); + if (ret < 0) { + ovpn_peer_put(peer); + return ret; + } + + /* ret == 1 means that VPN IPv4/6 has been modified and rehashing + * is required + */ + if (ret > 0) { + spin_lock_bh(&ovpn->peers->lock); + ovpn_peer_hash_vpn_ip(peer); + spin_unlock_bh(&ovpn->peers->lock); + } + + ovpn_peer_put(peer); + + return 0; +} + +static int ovpn_nl_send_peer(struct sk_buff *skb, const struct genl_info *info, + const struct ovpn_peer *peer, u32 portid, u32 seq, + int flags) +{ + const struct ovpn_bind *bind; + struct nlattr *attr; + void *hdr; + + hdr = genlmsg_put(skb, portid, seq, &ovpn_nl_family, flags, + OVPN_CMD_PEER_GET); + if (!hdr) + return -ENOBUFS; + + attr = nla_nest_start(skb, OVPN_A_PEER); + if (!attr) + goto err; + + if (nla_put_u32(skb, OVPN_A_PEER_ID, peer->id)) + goto err; + + if (peer->vpn_addrs.ipv4.s_addr != htonl(INADDR_ANY)) + if (nla_put_in_addr(skb, OVPN_A_PEER_VPN_IPV4, + peer->vpn_addrs.ipv4.s_addr)) + goto err; + + if (!ipv6_addr_equal(&peer->vpn_addrs.ipv6, &in6addr_any)) + if (nla_put_in6_addr(skb, OVPN_A_PEER_VPN_IPV6, + &peer->vpn_addrs.ipv6)) + goto err; + + if (nla_put_u32(skb, OVPN_A_PEER_KEEPALIVE_INTERVAL, + peer->keepalive_interval) || + nla_put_u32(skb, OVPN_A_PEER_KEEPALIVE_TIMEOUT, + peer->keepalive_timeout)) + goto err; + + rcu_read_lock(); + bind = rcu_dereference(peer->bind); + if (bind) { + if (bind->remote.in4.sin_family == AF_INET) { + if (nla_put_in_addr(skb, OVPN_A_PEER_REMOTE_IPV4, + bind->remote.in4.sin_addr.s_addr) || + nla_put_net16(skb, OVPN_A_PEER_REMOTE_PORT, + bind->remote.in4.sin_port) || + nla_put_in_addr(skb, OVPN_A_PEER_LOCAL_IPV4, + bind->local.ipv4.s_addr)) + goto err_unlock; + } else if (bind->remote.in4.sin_family == AF_INET6) { + if (nla_put_in6_addr(skb, OVPN_A_PEER_REMOTE_IPV6, + &bind->remote.in6.sin6_addr) || + nla_put_u32(skb, OVPN_A_PEER_REMOTE_IPV6_SCOPE_ID, + bind->remote.in6.sin6_scope_id) || + nla_put_net16(skb, OVPN_A_PEER_REMOTE_PORT, + bind->remote.in6.sin6_port) || + nla_put_in6_addr(skb, OVPN_A_PEER_LOCAL_IPV6, + &bind->local.ipv6)) + goto err_unlock; + } + } + rcu_read_unlock(); + + if (nla_put_net16(skb, OVPN_A_PEER_LOCAL_PORT, + inet_sk(peer->sock->sock->sk)->inet_sport) || + /* VPN RX stats */ + nla_put_uint(skb, OVPN_A_PEER_VPN_RX_BYTES, + atomic64_read(&peer->vpn_stats.rx.bytes)) || + nla_put_uint(skb, OVPN_A_PEER_VPN_RX_PACKETS, + atomic64_read(&peer->vpn_stats.rx.packets)) || + /* VPN TX stats */ + nla_put_uint(skb, OVPN_A_PEER_VPN_TX_BYTES, + atomic64_read(&peer->vpn_stats.tx.bytes)) || + nla_put_uint(skb, OVPN_A_PEER_VPN_TX_PACKETS, + atomic64_read(&peer->vpn_stats.tx.packets)) || + /* link RX stats */ + nla_put_uint(skb, OVPN_A_PEER_LINK_RX_BYTES, + atomic64_read(&peer->link_stats.rx.bytes)) || + nla_put_uint(skb, OVPN_A_PEER_LINK_RX_PACKETS, + atomic64_read(&peer->link_stats.rx.packets)) || + /* link TX stats */ + nla_put_uint(skb, OVPN_A_PEER_LINK_TX_BYTES, + atomic64_read(&peer->link_stats.tx.bytes)) || + nla_put_uint(skb, OVPN_A_PEER_LINK_TX_PACKETS, + atomic64_read(&peer->link_stats.tx.packets))) + goto err; + + nla_nest_end(skb, attr); + genlmsg_end(skb, hdr); + + return 0; +err_unlock: + rcu_read_unlock(); +err: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +int ovpn_nl_peer_get_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_A_PEER_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + struct ovpn_peer *peer; + struct sk_buff *msg; + u32 peer_id; + int ret; + + if (GENL_REQ_ATTR_CHECK(info, OVPN_A_PEER)) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_A_PEER_MAX, info->attrs[OVPN_A_PEER], + ovpn_peer_nl_policy, info->extack); + if (ret) + return ret; + + if (NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_PEER], attrs, + OVPN_A_PEER_ID)) + return -EINVAL; + + peer_id = nla_get_u32(attrs[OVPN_A_PEER_ID]); + peer = ovpn_peer_get_by_id(ovpn, peer_id); + if (!peer) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "cannot find peer with id %u", peer_id); + return -ENOENT; + } + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) { + ret = -ENOMEM; + goto err; + } + + ret = ovpn_nl_send_peer(msg, info, peer, info->snd_portid, + info->snd_seq, 0); + if (ret < 0) { + nlmsg_free(msg); + goto err; + } + + ret = genlmsg_reply(msg, info); +err: + ovpn_peer_put(peer); + return ret; +} + +int ovpn_nl_peer_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct genl_info *info = genl_info_dump(cb); + int bkt, last_idx = cb->args[1], dumped = 0; + struct ovpn_struct *ovpn; + struct ovpn_peer *peer; + + ovpn = ovpn_get_dev_from_attrs(sock_net(cb->skb->sk), info); + if (IS_ERR(ovpn)) + return PTR_ERR(ovpn); + + if (ovpn->mode == OVPN_MODE_P2P) { + /* if we already dumped a peer it means we are done */ + if (last_idx) + goto out; + + rcu_read_lock(); + peer = rcu_dereference(ovpn->peer); + if (peer) { + if (ovpn_nl_send_peer(skb, info, peer, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI) == 0) + dumped++; + } + rcu_read_unlock(); + } else { + rcu_read_lock(); + hash_for_each_rcu(ovpn->peers->by_id, bkt, peer, + hash_entry_id) { + /* skip already dumped peers that were dumped by + * previous invocations + */ + if (last_idx > 0) { + last_idx--; + continue; + } + + if (ovpn_nl_send_peer(skb, info, peer, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI) < 0) + break; + + /* count peers being dumped during this invocation */ + dumped++; + } + rcu_read_unlock(); + } + +out: + netdev_put(ovpn->dev, &ovpn->dev_tracker); + + /* sum up peers dumped in this message, so that at the next invocation + * we can continue from where we left + */ + cb->args[1] += dumped; + return skb->len; +} + +int ovpn_nl_peer_del_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_A_PEER_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + struct ovpn_peer *peer; + u32 peer_id; + int ret; + + if (GENL_REQ_ATTR_CHECK(info, OVPN_A_PEER)) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_A_PEER_MAX, info->attrs[OVPN_A_PEER], + ovpn_peer_nl_policy, info->extack); + if (ret) + return ret; + + if (NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_PEER], attrs, + OVPN_A_PEER_ID)) + return -EINVAL; + + peer_id = nla_get_u32(attrs[OVPN_A_PEER_ID]); + + peer = ovpn_peer_get_by_id(ovpn, peer_id); + if (!peer) + return -ENOENT; + + netdev_dbg(ovpn->dev, "%s: peer id=%u\n", __func__, peer->id); + ret = ovpn_peer_del(peer, OVPN_DEL_PEER_REASON_USERSPACE); + ovpn_peer_put(peer); + + return ret; +} + +static int ovpn_nl_get_key_dir(struct genl_info *info, struct nlattr *key, + enum ovpn_cipher_alg cipher, + struct ovpn_key_direction *dir) +{ + struct nlattr *attrs[OVPN_A_KEYDIR_MAX + 1]; + int ret; + + ret = nla_parse_nested(attrs, OVPN_A_KEYDIR_MAX, key, + ovpn_keydir_nl_policy, info->extack); + if (ret) + return ret; + + switch (cipher) { + case OVPN_CIPHER_ALG_AES_GCM: + case OVPN_CIPHER_ALG_CHACHA20_POLY1305: + if (NL_REQ_ATTR_CHECK(info->extack, key, attrs, + OVPN_A_KEYDIR_CIPHER_KEY) || + NL_REQ_ATTR_CHECK(info->extack, key, attrs, + OVPN_A_KEYDIR_NONCE_TAIL)) + return -EINVAL; + + dir->cipher_key = nla_data(attrs[OVPN_A_KEYDIR_CIPHER_KEY]); + dir->cipher_key_size = nla_len(attrs[OVPN_A_KEYDIR_CIPHER_KEY]); + + /* These algorithms require a 96bit nonce, + * Construct it by combining 4-bytes packet id and + * 8-bytes nonce-tail from userspace + */ + dir->nonce_tail = nla_data(attrs[OVPN_A_KEYDIR_NONCE_TAIL]); + dir->nonce_tail_size = nla_len(attrs[OVPN_A_KEYDIR_NONCE_TAIL]); + break; + default: + NL_SET_ERR_MSG_MOD(info->extack, "unsupported cipher"); + return -EINVAL; + } + + return 0; +} + +/** + * ovpn_nl_key_new_doit - configure a new key for the specified peer + * @skb: incoming netlink message + * @info: genetlink metadata + * + * This function allows the user to install a new key in the peer crypto + * state. + * Each peer has two 'slots', namely 'primary' and 'secondary', where + * keys can be installed. The key in the 'primary' slot is used for + * encryption, while both keys can be used for decryption by matching the + * key ID carried in the incoming packet. + * + * The user is responsible for rotating keys when necessary. The user + * may fetch peer traffic statistics via netlink in order to better + * identify the right time to rotate keys. + * The renegotiation follows these steps: + * 1. a new key is computed by the user and is installed in the 'secondary' + * slot + * 2. at user discretion (usually after a predetermined time) 'primary' and + * 'secondary' contents are swapped and the new key starts being used for + * encryption, while the old key is kept around for decryption of late + * packets. + * + * Return: 0 on success or a negative error code otherwise. + */ +int ovpn_nl_key_new_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_A_KEYCONF_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + struct ovpn_peer_key_reset pkr; + struct ovpn_peer *peer; + u32 peer_id; + int ret; + + if (GENL_REQ_ATTR_CHECK(info, OVPN_A_KEYCONF)) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_A_KEYCONF_MAX, + info->attrs[OVPN_A_KEYCONF], + ovpn_keyconf_nl_policy, info->extack); + if (ret) + return ret; + + if (NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_KEYCONF], attrs, + OVPN_A_KEYCONF_PEER_ID)) + return -EINVAL; + + if (NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_KEYCONF], attrs, + OVPN_A_KEYCONF_SLOT) || + NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_KEYCONF], attrs, + OVPN_A_KEYCONF_KEY_ID) || + NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_KEYCONF], attrs, + OVPN_A_KEYCONF_CIPHER_ALG) || + NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_KEYCONF], attrs, + OVPN_A_KEYCONF_ENCRYPT_DIR) || + NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_KEYCONF], attrs, + OVPN_A_KEYCONF_DECRYPT_DIR)) + return -EINVAL; + + peer_id = nla_get_u32(attrs[OVPN_A_KEYCONF_PEER_ID]); + pkr.slot = nla_get_u8(attrs[OVPN_A_KEYCONF_SLOT]); + pkr.key.key_id = nla_get_u16(attrs[OVPN_A_KEYCONF_KEY_ID]); + pkr.key.cipher_alg = nla_get_u16(attrs[OVPN_A_KEYCONF_CIPHER_ALG]); + + ret = ovpn_nl_get_key_dir(info, attrs[OVPN_A_KEYCONF_ENCRYPT_DIR], + pkr.key.cipher_alg, &pkr.key.encrypt); + if (ret < 0) + return ret; + + ret = ovpn_nl_get_key_dir(info, attrs[OVPN_A_KEYCONF_DECRYPT_DIR], + pkr.key.cipher_alg, &pkr.key.decrypt); + if (ret < 0) + return ret; + + peer = ovpn_peer_get_by_id(ovpn, peer_id); + if (!peer) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "no peer with id %u to set key for", + peer_id); + return -ENOENT; + } + + ret = ovpn_crypto_state_reset(&peer->crypto, &pkr); + if (ret < 0) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "cannot install new key for peer %u", + peer_id); + goto out; + } + + netdev_dbg(ovpn->dev, "%s: new key installed (id=%u) for peer %u\n", + __func__, pkr.key.key_id, peer_id); +out: + ovpn_peer_put(peer); + return ret; +} + +static int ovpn_nl_send_key(struct sk_buff *skb, const struct genl_info *info, + u32 peer_id, enum ovpn_key_slot slot, + const struct ovpn_key_config *keyconf, u32 portid, + u32 seq, int flags) +{ + struct nlattr *attr; + void *hdr; + + hdr = genlmsg_put(skb, portid, seq, &ovpn_nl_family, flags, + OVPN_CMD_KEY_GET); + if (!hdr) + return -ENOBUFS; + + attr = nla_nest_start(skb, OVPN_A_KEYCONF); + if (!attr) + goto err; + + if (nla_put_u32(skb, OVPN_A_KEYCONF_PEER_ID, peer_id)) + goto err; + + if (nla_put_u32(skb, OVPN_A_KEYCONF_SLOT, slot) || + nla_put_u32(skb, OVPN_A_KEYCONF_KEY_ID, keyconf->key_id) || + nla_put_u32(skb, OVPN_A_KEYCONF_CIPHER_ALG, keyconf->cipher_alg)) + goto err; + + nla_nest_end(skb, attr); + genlmsg_end(skb, hdr); + + return 0; +err: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +int ovpn_nl_key_get_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_A_KEYCONF_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + struct ovpn_key_config keyconf = { 0 }; + enum ovpn_key_slot slot; + struct ovpn_peer *peer; + struct sk_buff *msg; + u32 peer_id; + int ret; + + if (GENL_REQ_ATTR_CHECK(info, OVPN_A_KEYCONF)) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_A_KEYCONF_MAX, + info->attrs[OVPN_A_KEYCONF], + ovpn_keyconf_nl_policy, info->extack); + if (ret) + return ret; + + if (NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_KEYCONF], attrs, + OVPN_A_KEYCONF_PEER_ID)) + return -EINVAL; + + peer_id = nla_get_u32(attrs[OVPN_A_KEYCONF_PEER_ID]); + + peer = ovpn_peer_get_by_id(ovpn, peer_id); + if (!peer) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "cannot find peer with id %u", 0); + return -ENOENT; + } + + if (NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_KEYCONF], attrs, + OVPN_A_KEYCONF_SLOT)) + return -EINVAL; + + slot = nla_get_u32(attrs[OVPN_A_KEYCONF_SLOT]); + + ret = ovpn_crypto_config_get(&peer->crypto, slot, &keyconf); + if (ret < 0) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "cannot extract key from slot %u for peer %u", + slot, peer_id); + goto err; + } + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) { + ret = -ENOMEM; + goto err; + } + + ret = ovpn_nl_send_key(msg, info, peer->id, slot, &keyconf, + info->snd_portid, info->snd_seq, 0); + if (ret < 0) { + nlmsg_free(msg); + goto err; + } + + ret = genlmsg_reply(msg, info); +err: + ovpn_peer_put(peer); + return ret; +} + +int ovpn_nl_key_swap_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct ovpn_struct *ovpn = info->user_ptr[0]; + struct nlattr *attrs[OVPN_A_PEER_MAX + 1]; + struct ovpn_peer *peer; + u32 peer_id; + int ret; + + if (GENL_REQ_ATTR_CHECK(info, OVPN_A_KEYCONF)) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_A_KEYCONF_MAX, + info->attrs[OVPN_A_KEYCONF], + ovpn_keyconf_nl_policy, info->extack); + if (ret) + return ret; + + if (NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_KEYCONF], attrs, + OVPN_A_KEYCONF_PEER_ID)) + return -EINVAL; + + peer_id = nla_get_u32(attrs[OVPN_A_KEYCONF_PEER_ID]); + + peer = ovpn_peer_get_by_id(ovpn, peer_id); + if (!peer) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "no peer with id %u to swap keys for", + peer_id); + return -ENOENT; + } + + ovpn_crypto_key_slots_swap(&peer->crypto); + ovpn_peer_put(peer); + + return 0; +} + +int ovpn_nl_key_del_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_A_KEYCONF_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + enum ovpn_key_slot slot; + struct ovpn_peer *peer; + u32 peer_id; + int ret; + + if (GENL_REQ_ATTR_CHECK(info, OVPN_A_KEYCONF)) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_A_KEYCONF_MAX, + info->attrs[OVPN_A_KEYCONF], + ovpn_keyconf_nl_policy, info->extack); + if (ret) + return ret; + + if (NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_KEYCONF], attrs, + OVPN_A_KEYCONF_PEER_ID)) + return -EINVAL; + + if (ret) + return ret; + + if (NL_REQ_ATTR_CHECK(info->extack, info->attrs[OVPN_A_KEYCONF], attrs, + OVPN_A_KEYCONF_SLOT)) + return -EINVAL; + + peer_id = nla_get_u32(attrs[OVPN_A_KEYCONF_PEER_ID]); + slot = nla_get_u8(attrs[OVPN_A_KEYCONF_SLOT]); + + peer = ovpn_peer_get_by_id(ovpn, peer_id); + if (!peer) { + NL_SET_ERR_MSG_FMT_MOD(info->extack, + "no peer with id %u to delete key for", + peer_id); + return -ENOENT; + } + + ovpn_crypto_key_slot_delete(&peer->crypto, slot); + ovpn_peer_put(peer); + + return 0; +} + +/** + * ovpn_nl_peer_del_notify - notify userspace about peer being deleted + * @peer: the peer being deleted + * + * Return: 0 on success or a negative error code otherwise + */ +int ovpn_nl_peer_del_notify(struct ovpn_peer *peer) +{ + struct sk_buff *msg; + struct nlattr *attr; + int ret = -EMSGSIZE; + void *hdr; + + netdev_info(peer->ovpn->dev, "deleting peer with id %u, reason %d\n", + peer->id, peer->delete_reason); + + msg = nlmsg_new(100, GFP_ATOMIC); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &ovpn_nl_family, 0, OVPN_CMD_PEER_DEL_NTF); + if (!hdr) { + ret = -ENOBUFS; + goto err_free_msg; + } + + if (nla_put_u32(msg, OVPN_A_IFINDEX, peer->ovpn->dev->ifindex)) + goto err_cancel_msg; + + attr = nla_nest_start(msg, OVPN_A_PEER); + if (!attr) + goto err_cancel_msg; + + if (nla_put_u8(msg, OVPN_A_PEER_DEL_REASON, peer->delete_reason)) + goto err_cancel_msg; + + if (nla_put_u32(msg, OVPN_A_PEER_ID, peer->id)) + goto err_cancel_msg; + + nla_nest_end(msg, attr); + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&ovpn_nl_family, dev_net(peer->ovpn->dev), msg, + 0, OVPN_NLGRP_PEERS, GFP_ATOMIC); + + return 0; + +err_cancel_msg: + genlmsg_cancel(msg, hdr); +err_free_msg: + nlmsg_free(msg); + return ret; +} + +/** + * ovpn_nl_key_swap_notify - notify userspace peer's key must be renewed + * @peer: the peer whose key needs to be renewed + * @key_id: the ID of the key that needs to be renewed + * + * Return: 0 on success or a negative error code otherwise + */ +int ovpn_nl_key_swap_notify(struct ovpn_peer *peer, u8 key_id) +{ + struct nlattr *k_attr; + struct sk_buff *msg; + int ret = -EMSGSIZE; + void *hdr; + + netdev_info(peer->ovpn->dev, "peer with id %u must rekey - primary key unusable.\n", + peer->id); + + msg = nlmsg_new(100, GFP_ATOMIC); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &ovpn_nl_family, 0, OVPN_CMD_KEY_SWAP_NTF); + if (!hdr) { + ret = -ENOBUFS; + goto err_free_msg; + } + + if (nla_put_u32(msg, OVPN_A_IFINDEX, peer->ovpn->dev->ifindex)) + goto err_cancel_msg; + + k_attr = nla_nest_start(msg, OVPN_A_KEYCONF); + if (!k_attr) + goto err_cancel_msg; + + if (nla_put_u32(msg, OVPN_A_KEYCONF_PEER_ID, peer->id)) + goto err_cancel_msg; + + if (nla_put_u16(msg, OVPN_A_KEYCONF_KEY_ID, key_id)) + goto err_cancel_msg; + + nla_nest_end(msg, k_attr); + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&ovpn_nl_family, dev_net(peer->ovpn->dev), msg, + 0, OVPN_NLGRP_PEERS, GFP_ATOMIC); + + return 0; + +err_cancel_msg: + genlmsg_cancel(msg, hdr); +err_free_msg: + nlmsg_free(msg); + return ret; +} + +/** + * ovpn_nl_register - perform any needed registration in the NL subsustem + * + * Return: 0 on success, a negative error code otherwise + */ +int __init ovpn_nl_register(void) +{ + int ret = genl_register_family(&ovpn_nl_family); + + if (ret) { + pr_err("ovpn: genl_register_family failed: %d\n", ret); + return ret; + } + + return 0; +} + +/** + * ovpn_nl_unregister - undo any module wide netlink registration + */ +void ovpn_nl_unregister(void) +{ + genl_unregister_family(&ovpn_nl_family); +} diff --git a/drivers/net/ovpn/netlink.h b/drivers/net/ovpn/netlink.h new file mode 100644 index 000000000000..4ab3abcf23db --- /dev/null +++ b/drivers/net/ovpn/netlink.h @@ -0,0 +1,18 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: Antonio Quartulli + */ + +#ifndef _NET_OVPN_NETLINK_H_ +#define _NET_OVPN_NETLINK_H_ + +int ovpn_nl_register(void); +void ovpn_nl_unregister(void); + +int ovpn_nl_peer_del_notify(struct ovpn_peer *peer); +int ovpn_nl_key_swap_notify(struct ovpn_peer *peer, u8 key_id); + +#endif /* _NET_OVPN_NETLINK_H_ */ diff --git a/drivers/net/ovpn/ovpnstruct.h b/drivers/net/ovpn/ovpnstruct.h new file mode 100644 index 000000000000..4ac00d550ecb --- /dev/null +++ b/drivers/net/ovpn/ovpnstruct.h @@ -0,0 +1,61 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2019-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_OVPNSTRUCT_H_ +#define _NET_OVPN_OVPNSTRUCT_H_ + +#include +#include +#include +#include + +/** + * struct ovpn_peer_collection - container of peers for MultiPeer mode + * @by_id: table of peers index by ID + * @by_vpn_addr: table of peers indexed by VPN IP address (items can be + * rehashed on the fly due to peer IP change) + * @by_transp_addr: table of peers indexed by transport address (items can be + * rehashed on the fly due to peer IP change) + * @lock: protects writes to peer tables + */ +struct ovpn_peer_collection { + DECLARE_HASHTABLE(by_id, 12); + struct hlist_nulls_head by_vpn_addr[1 << 12]; + struct hlist_nulls_head by_transp_addr[1 << 12]; + + spinlock_t lock; /* protects writes to peer tables */ +}; + +/** + * struct ovpn_struct - per ovpn interface state + * @dev: the actual netdev representing the tunnel + * @dev_tracker: reference tracker for associated dev + * @registered: whether dev is still registered with netdev or not + * @mode: device operation mode (i.e. p2p, mp, ..) + * @lock: protect this object + * @peers: data structures holding multi-peer references + * @peer: in P2P mode, this is the only remote peer + * @dev_list: entry for the module wide device list + * @gro_cells: pointer to the Generic Receive Offload cell + * @keepalive_work: struct used to schedule keepalive periodic job + */ +struct ovpn_struct { + struct net_device *dev; + netdevice_tracker dev_tracker; + bool registered; + enum ovpn_mode mode; + spinlock_t lock; /* protect writing to the ovpn_struct object */ + struct ovpn_peer_collection *peers; + struct ovpn_peer __rcu *peer; + struct list_head dev_list; + struct gro_cells gro_cells; + struct delayed_work keepalive_work; +}; + +#endif /* _NET_OVPN_OVPNSTRUCT_H_ */ diff --git a/drivers/net/ovpn/packet.h b/drivers/net/ovpn/packet.h new file mode 100644 index 000000000000..e14c9bf464f7 --- /dev/null +++ b/drivers/net/ovpn/packet.h @@ -0,0 +1,40 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: Antonio Quartulli + * James Yonan + */ + +#ifndef _NET_OVPN_PACKET_H_ +#define _NET_OVPN_PACKET_H_ + +/* When the OpenVPN protocol is run in AEAD mode, use + * the OpenVPN packet ID as the AEAD nonce: + * + * 00000005 521c3b01 4308c041 + * [seq # ] [ nonce_tail ] + * [ 12-byte full IV ] -> NONCE_SIZE + * [4-bytes -> NONCE_WIRE_SIZE + * on wire] + */ + +/* OpenVPN nonce size */ +#define NONCE_SIZE 12 + +/* OpenVPN nonce size reduced by 8-byte nonce tail -- this is the + * size of the AEAD Associated Data (AD) sent over the wire + * and is normally the head of the IV + */ +#define NONCE_WIRE_SIZE (NONCE_SIZE - sizeof(struct ovpn_nonce_tail)) + +/* Last 8 bytes of AEAD nonce + * Provided by userspace and usually derived from + * key material generated during TLS handshake + */ +struct ovpn_nonce_tail { + u8 u8[OVPN_NONCE_TAIL_SIZE]; +}; + +#endif /* _NET_OVPN_PACKET_H_ */ diff --git a/drivers/net/ovpn/peer.c b/drivers/net/ovpn/peer.c new file mode 100644 index 000000000000..91c608f1ffa1 --- /dev/null +++ b/drivers/net/ovpn/peer.c @@ -0,0 +1,1201 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include +#include +#include +#include + +#include "ovpnstruct.h" +#include "bind.h" +#include "pktid.h" +#include "crypto.h" +#include "io.h" +#include "main.h" +#include "netlink.h" +#include "peer.h" +#include "socket.h" + +/** + * ovpn_peer_keepalive_set - configure keepalive values for peer + * @peer: the peer to configure + * @interval: outgoing keepalive interval + * @timeout: incoming keepalive timeout + */ +void ovpn_peer_keepalive_set(struct ovpn_peer *peer, u32 interval, u32 timeout) +{ + time64_t now = ktime_get_real_seconds(); + + netdev_dbg(peer->ovpn->dev, + "%s: scheduling keepalive for peer %u: interval=%u timeout=%u\n", + __func__, peer->id, interval, timeout); + + peer->keepalive_interval = interval; + peer->last_sent = now; + peer->keepalive_xmit_exp = now + interval; + + peer->keepalive_timeout = timeout; + peer->last_recv = now; + peer->keepalive_recv_exp = now + timeout; + + /* now that interval and timeout have been changed, kick + * off the worker so that the next delay can be recomputed + */ + mod_delayed_work(system_wq, &peer->ovpn->keepalive_work, 0); +} + +/** + * ovpn_peer_new - allocate and initialize a new peer object + * @ovpn: the openvpn instance inside which the peer should be created + * @id: the ID assigned to this peer + * + * Return: a pointer to the new peer on success or an error code otherwise + */ +struct ovpn_peer *ovpn_peer_new(struct ovpn_struct *ovpn, u32 id) +{ + struct ovpn_peer *peer; + int ret; + + /* alloc and init peer object */ + peer = kzalloc(sizeof(*peer), GFP_KERNEL); + if (!peer) + return ERR_PTR(-ENOMEM); + + peer->id = id; + peer->halt = false; + peer->ovpn = ovpn; + + peer->vpn_addrs.ipv4.s_addr = htonl(INADDR_ANY); + peer->vpn_addrs.ipv6 = in6addr_any; + + RCU_INIT_POINTER(peer->bind, NULL); + ovpn_crypto_state_init(&peer->crypto); + spin_lock_init(&peer->lock); + kref_init(&peer->refcount); + ovpn_peer_stats_init(&peer->vpn_stats); + ovpn_peer_stats_init(&peer->link_stats); + + ret = dst_cache_init(&peer->dst_cache, GFP_KERNEL); + if (ret < 0) { + netdev_err(ovpn->dev, "%s: cannot initialize dst cache\n", + __func__); + kfree(peer); + return ERR_PTR(ret); + } + + netdev_hold(ovpn->dev, &ovpn->dev_tracker, GFP_KERNEL); + + return peer; +} + +/** + * ovpn_peer_reset_sockaddr - recreate binding for peer + * @peer: peer to recreate the binding for + * @ss: sockaddr to use as remote endpoint for the binding + * @local_ip: local IP for the binding + * + * Return: 0 on success or a negative error code otherwise + */ +int ovpn_peer_reset_sockaddr(struct ovpn_peer *peer, + const struct sockaddr_storage *ss, + const u8 *local_ip) + __must_hold(&peer->lock) +{ + struct ovpn_bind *bind; + size_t ip_len; + + /* create new ovpn_bind object */ + bind = ovpn_bind_from_sockaddr(ss); + if (IS_ERR(bind)) + return PTR_ERR(bind); + + if (local_ip) { + if (ss->ss_family == AF_INET) { + ip_len = sizeof(struct in_addr); + } else if (ss->ss_family == AF_INET6) { + ip_len = sizeof(struct in6_addr); + } else { + netdev_dbg(peer->ovpn->dev, "%s: invalid family for remote endpoint\n", + __func__); + kfree(bind); + return -EINVAL; + } + + memcpy(&bind->local, local_ip, ip_len); + } + + /* set binding */ + ovpn_bind_reset(peer, bind); + + return 0; +} + +#define ovpn_get_hash_head(_tbl, _key, _key_len) ({ \ + typeof(_tbl) *__tbl = &(_tbl); \ + (&(*__tbl)[jhash(_key, _key_len, 0) % HASH_SIZE(*__tbl)]); }) \ + +/** + * ovpn_peer_float - update remote endpoint for peer + * @peer: peer to update the remote endpoint for + * @skb: incoming packet to retrieve the source address (remote) from + */ +void ovpn_peer_float(struct ovpn_peer *peer, struct sk_buff *skb) +{ + struct hlist_nulls_head *nhead; + struct sockaddr_storage ss; + const u8 *local_ip = NULL; + struct sockaddr_in6 *sa6; + struct sockaddr_in *sa; + struct ovpn_bind *bind; + sa_family_t family; + size_t salen; + + rcu_read_lock(); + bind = rcu_dereference(peer->bind); + if (unlikely(!bind)) { + rcu_read_unlock(); + return; + } + + spin_lock_bh(&peer->lock); + if (likely(ovpn_bind_skb_src_match(bind, skb))) + goto unlock; + + family = skb_protocol_to_family(skb); + + if (bind->remote.in4.sin_family == family) + local_ip = (u8 *)&bind->local; + + switch (family) { + case AF_INET: + sa = (struct sockaddr_in *)&ss; + sa->sin_family = AF_INET; + sa->sin_addr.s_addr = ip_hdr(skb)->saddr; + sa->sin_port = udp_hdr(skb)->source; + salen = sizeof(*sa); + break; + case AF_INET6: + sa6 = (struct sockaddr_in6 *)&ss; + sa6->sin6_family = AF_INET6; + sa6->sin6_addr = ipv6_hdr(skb)->saddr; + sa6->sin6_port = udp_hdr(skb)->source; + sa6->sin6_scope_id = ipv6_iface_scope_id(&ipv6_hdr(skb)->saddr, + skb->skb_iif); + salen = sizeof(*sa6); + break; + default: + goto unlock; + } + + netdev_dbg(peer->ovpn->dev, "%s: peer %d floated to %pIScp", __func__, + peer->id, &ss); + ovpn_peer_reset_sockaddr(peer, (struct sockaddr_storage *)&ss, + local_ip); + spin_unlock_bh(&peer->lock); + rcu_read_unlock(); + + /* rehashing is required only in MP mode as P2P has one peer + * only and thus there is no hashtable + */ + if (peer->ovpn->mode == OVPN_MODE_MP) { + spin_lock_bh(&peer->ovpn->peers->lock); + /* remove old hashing */ + hlist_nulls_del_init_rcu(&peer->hash_entry_transp_addr); + /* re-add with new transport address */ + nhead = ovpn_get_hash_head(peer->ovpn->peers->by_transp_addr, + &ss, salen); + hlist_nulls_add_head_rcu(&peer->hash_entry_transp_addr, nhead); + spin_unlock_bh(&peer->ovpn->peers->lock); + } + return; +unlock: + spin_unlock_bh(&peer->lock); + rcu_read_unlock(); +} + +void ovpn_peer_release(struct ovpn_peer *peer) +{ + if (peer->sock) + ovpn_socket_put(peer->sock); + + ovpn_crypto_state_release(&peer->crypto); + spin_lock_bh(&peer->lock); + ovpn_bind_reset(peer, NULL); + spin_unlock_bh(&peer->lock); + + dst_cache_destroy(&peer->dst_cache); + netdev_put(peer->ovpn->dev, &peer->ovpn->dev_tracker); + kfree_rcu(peer, rcu); +} + +/** + * ovpn_peer_release_kref - callback for kref_put + * @kref: the kref object belonging to the peer + */ +void ovpn_peer_release_kref(struct kref *kref) +{ + struct ovpn_peer *peer = container_of(kref, struct ovpn_peer, refcount); + + ovpn_nl_peer_del_notify(peer); + ovpn_peer_release(peer); +} + +/** + * ovpn_peer_skb_to_sockaddr - fill sockaddr with skb source address + * @skb: the packet to extract data from + * @ss: the sockaddr to fill + * + * Return: true on success or false otherwise + */ +static bool ovpn_peer_skb_to_sockaddr(struct sk_buff *skb, + struct sockaddr_storage *ss) +{ + struct sockaddr_in6 *sa6; + struct sockaddr_in *sa4; + + ss->ss_family = skb_protocol_to_family(skb); + switch (ss->ss_family) { + case AF_INET: + sa4 = (struct sockaddr_in *)ss; + sa4->sin_family = AF_INET; + sa4->sin_addr.s_addr = ip_hdr(skb)->saddr; + sa4->sin_port = udp_hdr(skb)->source; + break; + case AF_INET6: + sa6 = (struct sockaddr_in6 *)ss; + sa6->sin6_family = AF_INET6; + sa6->sin6_addr = ipv6_hdr(skb)->saddr; + sa6->sin6_port = udp_hdr(skb)->source; + break; + default: + return false; + } + + return true; +} + +/** + * ovpn_nexthop_from_skb4 - retrieve IPv4 nexthop for outgoing skb + * @skb: the outgoing packet + * + * Return: the IPv4 of the nexthop + */ +static __be32 ovpn_nexthop_from_skb4(struct sk_buff *skb) +{ + const struct rtable *rt = skb_rtable(skb); + + if (rt && rt->rt_uses_gateway) + return rt->rt_gw4; + + return ip_hdr(skb)->daddr; +} + +/** + * ovpn_nexthop_from_skb6 - retrieve IPv6 nexthop for outgoing skb + * @skb: the outgoing packet + * + * Return: the IPv6 of the nexthop + */ +static struct in6_addr ovpn_nexthop_from_skb6(struct sk_buff *skb) +{ + const struct rt6_info *rt = skb_rt6_info(skb); + + if (!rt || !(rt->rt6i_flags & RTF_GATEWAY)) + return ipv6_hdr(skb)->daddr; + + return rt->rt6i_gateway; +} + +/** + * ovpn_peer_get_by_vpn_addr4 - retrieve peer by its VPN IPv4 address + * @ovpn: the openvpn instance to search + * @addr: VPN IPv4 to use as search key + * + * Refcounter is not increased for the returned peer. + * + * Return: the peer if found or NULL otherwise + */ +static struct ovpn_peer *ovpn_peer_get_by_vpn_addr4(struct ovpn_struct *ovpn, + __be32 addr) +{ + struct hlist_nulls_head *nhead; + struct hlist_nulls_node *ntmp; + struct ovpn_peer *tmp; + + nhead = ovpn_get_hash_head(ovpn->peers->by_vpn_addr, &addr, + sizeof(addr)); + + hlist_nulls_for_each_entry_rcu(tmp, ntmp, nhead, hash_entry_addr4) + if (addr == tmp->vpn_addrs.ipv4.s_addr) + return tmp; + + return NULL; +} + +/** + * ovpn_peer_get_by_vpn_addr6 - retrieve peer by its VPN IPv6 address + * @ovpn: the openvpn instance to search + * @addr: VPN IPv6 to use as search key + * + * Refcounter is not increased for the returned peer. + * + * Return: the peer if found or NULL otherwise + */ +static struct ovpn_peer *ovpn_peer_get_by_vpn_addr6(struct ovpn_struct *ovpn, + struct in6_addr *addr) +{ + struct hlist_nulls_head *nhead; + struct hlist_nulls_node *ntmp; + struct ovpn_peer *tmp; + + nhead = ovpn_get_hash_head(ovpn->peers->by_vpn_addr, addr, + sizeof(*addr)); + + hlist_nulls_for_each_entry_rcu(tmp, ntmp, nhead, hash_entry_addr6) + if (ipv6_addr_equal(addr, &tmp->vpn_addrs.ipv6)) + return tmp; + + return NULL; +} + +/** + * ovpn_peer_transp_match - check if sockaddr and peer binding match + * @peer: the peer to get the binding from + * @ss: the sockaddr to match + * + * Return: true if sockaddr and binding match or false otherwise + */ +static bool ovpn_peer_transp_match(const struct ovpn_peer *peer, + const struct sockaddr_storage *ss) +{ + struct ovpn_bind *bind = rcu_dereference(peer->bind); + struct sockaddr_in6 *sa6; + struct sockaddr_in *sa4; + + if (unlikely(!bind)) + return false; + + if (ss->ss_family != bind->remote.in4.sin_family) + return false; + + switch (ss->ss_family) { + case AF_INET: + sa4 = (struct sockaddr_in *)ss; + if (sa4->sin_addr.s_addr != bind->remote.in4.sin_addr.s_addr) + return false; + if (sa4->sin_port != bind->remote.in4.sin_port) + return false; + break; + case AF_INET6: + sa6 = (struct sockaddr_in6 *)ss; + if (!ipv6_addr_equal(&sa6->sin6_addr, + &bind->remote.in6.sin6_addr)) + return false; + if (sa6->sin6_port != bind->remote.in6.sin6_port) + return false; + break; + default: + return false; + } + + return true; +} + +/** + * ovpn_peer_get_by_transp_addr_p2p - get peer by transport address in a P2P + * instance + * @ovpn: the openvpn instance to search + * @ss: the transport socket address + * + * Return: the peer if found or NULL otherwise + */ +static struct ovpn_peer * +ovpn_peer_get_by_transp_addr_p2p(struct ovpn_struct *ovpn, + struct sockaddr_storage *ss) +{ + struct ovpn_peer *tmp, *peer = NULL; + + rcu_read_lock(); + tmp = rcu_dereference(ovpn->peer); + if (likely(tmp && ovpn_peer_transp_match(tmp, ss) && + ovpn_peer_hold(tmp))) + peer = tmp; + rcu_read_unlock(); + + return peer; +} + +/** + * ovpn_peer_get_by_transp_addr - retrieve peer by transport address + * @ovpn: the openvpn instance to search + * @skb: the skb to retrieve the source transport address from + * + * Return: a pointer to the peer if found or NULL otherwise + */ +struct ovpn_peer *ovpn_peer_get_by_transp_addr(struct ovpn_struct *ovpn, + struct sk_buff *skb) +{ + struct ovpn_peer *tmp, *peer = NULL; + struct sockaddr_storage ss = { 0 }; + struct hlist_nulls_head *nhead; + struct hlist_nulls_node *ntmp; + size_t sa_len; + + if (unlikely(!ovpn_peer_skb_to_sockaddr(skb, &ss))) + return NULL; + + if (ovpn->mode == OVPN_MODE_P2P) + return ovpn_peer_get_by_transp_addr_p2p(ovpn, &ss); + + switch (ss.ss_family) { + case AF_INET: + sa_len = sizeof(struct sockaddr_in); + break; + case AF_INET6: + sa_len = sizeof(struct sockaddr_in6); + break; + default: + return NULL; + } + + nhead = ovpn_get_hash_head(ovpn->peers->by_transp_addr, &ss, sa_len); + + rcu_read_lock(); + hlist_nulls_for_each_entry_rcu(tmp, ntmp, nhead, + hash_entry_transp_addr) { + if (!ovpn_peer_transp_match(tmp, &ss)) + continue; + + if (!ovpn_peer_hold(tmp)) + continue; + + peer = tmp; + break; + } + rcu_read_unlock(); + + return peer; +} + +/** + * ovpn_peer_get_by_id_p2p - get peer by ID in a P2P instance + * @ovpn: the openvpn instance to search + * @peer_id: the ID of the peer to find + * + * Return: the peer if found or NULL otherwise + */ +static struct ovpn_peer *ovpn_peer_get_by_id_p2p(struct ovpn_struct *ovpn, + u32 peer_id) +{ + struct ovpn_peer *tmp, *peer = NULL; + + rcu_read_lock(); + tmp = rcu_dereference(ovpn->peer); + if (likely(tmp && tmp->id == peer_id && ovpn_peer_hold(tmp))) + peer = tmp; + rcu_read_unlock(); + + return peer; +} + +/** + * ovpn_peer_get_by_id - retrieve peer by ID + * @ovpn: the openvpn instance to search + * @peer_id: the unique peer identifier to match + * + * Return: a pointer to the peer if found or NULL otherwise + */ +struct ovpn_peer *ovpn_peer_get_by_id(struct ovpn_struct *ovpn, u32 peer_id) +{ + struct ovpn_peer *tmp, *peer = NULL; + struct hlist_head *head; + + if (ovpn->mode == OVPN_MODE_P2P) + return ovpn_peer_get_by_id_p2p(ovpn, peer_id); + + head = ovpn_get_hash_head(ovpn->peers->by_id, &peer_id, + sizeof(peer_id)); + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp, head, hash_entry_id) { + if (tmp->id != peer_id) + continue; + + if (!ovpn_peer_hold(tmp)) + continue; + + peer = tmp; + break; + } + rcu_read_unlock(); + + return peer; +} + +/** + * ovpn_peer_update_local_endpoint - update local endpoint for peer + * @peer: peer to update the endpoint for + * @skb: incoming packet to retrieve the destination address (local) from + */ +void ovpn_peer_update_local_endpoint(struct ovpn_peer *peer, + struct sk_buff *skb) +{ + struct ovpn_bind *bind; + + rcu_read_lock(); + bind = rcu_dereference(peer->bind); + if (unlikely(!bind)) + goto unlock; + + spin_lock_bh(&peer->lock); + switch (skb_protocol_to_family(skb)) { + case AF_INET: + if (unlikely(bind->local.ipv4.s_addr != ip_hdr(skb)->daddr)) { + netdev_dbg(peer->ovpn->dev, + "%s: learning local IPv4 for peer %d (%pI4 -> %pI4)\n", + __func__, peer->id, &bind->local.ipv4.s_addr, + &ip_hdr(skb)->daddr); + bind->local.ipv4.s_addr = ip_hdr(skb)->daddr; + } + break; + case AF_INET6: + if (unlikely(!ipv6_addr_equal(&bind->local.ipv6, + &ipv6_hdr(skb)->daddr))) { + netdev_dbg(peer->ovpn->dev, + "%s: learning local IPv6 for peer %d (%pI6c -> %pI6c\n", + __func__, peer->id, &bind->local.ipv6, + &ipv6_hdr(skb)->daddr); + bind->local.ipv6 = ipv6_hdr(skb)->daddr; + } + break; + default: + break; + } + spin_unlock_bh(&peer->lock); + +unlock: + rcu_read_unlock(); +} + +/** + * ovpn_peer_get_by_dst - Lookup peer to send skb to + * @ovpn: the private data representing the current VPN session + * @skb: the skb to extract the destination address from + * + * This function takes a tunnel packet and looks up the peer to send it to + * after encapsulation. The skb is expected to be the in-tunnel packet, without + * any OpenVPN related header. + * + * Assume that the IP header is accessible in the skb data. + * + * Return: the peer if found or NULL otherwise. + */ +struct ovpn_peer *ovpn_peer_get_by_dst(struct ovpn_struct *ovpn, + struct sk_buff *skb) +{ + struct ovpn_peer *peer = NULL; + struct in6_addr addr6; + __be32 addr4; + + /* in P2P mode, no matter the destination, packets are always sent to + * the single peer listening on the other side + */ + if (ovpn->mode == OVPN_MODE_P2P) { + rcu_read_lock(); + peer = rcu_dereference(ovpn->peer); + if (unlikely(peer && !ovpn_peer_hold(peer))) + peer = NULL; + rcu_read_unlock(); + return peer; + } + + rcu_read_lock(); + switch (skb_protocol_to_family(skb)) { + case AF_INET: + addr4 = ovpn_nexthop_from_skb4(skb); + peer = ovpn_peer_get_by_vpn_addr4(ovpn, addr4); + break; + case AF_INET6: + addr6 = ovpn_nexthop_from_skb6(skb); + peer = ovpn_peer_get_by_vpn_addr6(ovpn, &addr6); + break; + } + + if (unlikely(peer && !ovpn_peer_hold(peer))) + peer = NULL; + rcu_read_unlock(); + + return peer; +} + +/** + * ovpn_nexthop_from_rt4 - look up the IPv4 nexthop for the given destination + * @ovpn: the private data representing the current VPN session + * @dest: the destination to be looked up + * + * Looks up in the IPv4 system routing table the IP of the nexthop to be used + * to reach the destination passed as argument. If no nexthop can be found, the + * destination itself is returned as it probably has to be used as nexthop. + * + * Return: the IP of the next hop if found or dest itself otherwise + */ +static __be32 ovpn_nexthop_from_rt4(struct ovpn_struct *ovpn, __be32 dest) +{ + struct rtable *rt; + struct flowi4 fl = { + .daddr = dest + }; + + rt = ip_route_output_flow(dev_net(ovpn->dev), &fl, NULL); + if (IS_ERR(rt)) { + net_dbg_ratelimited("%s: no route to host %pI4\n", __func__, + &dest); + /* if we end up here this packet is probably going to be + * thrown away later + */ + return dest; + } + + if (!rt->rt_uses_gateway) + goto out; + + dest = rt->rt_gw4; +out: + ip_rt_put(rt); + return dest; +} + +/** + * ovpn_nexthop_from_rt6 - look up the IPv6 nexthop for the given destination + * @ovpn: the private data representing the current VPN session + * @dest: the destination to be looked up + * + * Looks up in the IPv6 system routing table the IP of the nexthop to be used + * to reach the destination passed as argument. If no nexthop can be found, the + * destination itself is returned as it probably has to be used as nexthop. + * + * Return: the IP of the next hop if found or dest itself otherwise + */ +static struct in6_addr ovpn_nexthop_from_rt6(struct ovpn_struct *ovpn, + struct in6_addr dest) +{ +#if IS_ENABLED(CONFIG_IPV6) + struct dst_entry *entry; + struct rt6_info *rt; + struct flowi6 fl = { + .daddr = dest, + }; + + entry = ipv6_stub->ipv6_dst_lookup_flow(dev_net(ovpn->dev), NULL, &fl, + NULL); + if (IS_ERR(entry)) { + net_dbg_ratelimited("%s: no route to host %pI6c\n", __func__, + &dest); + /* if we end up here this packet is probably going to be + * thrown away later + */ + return dest; + } + + rt = dst_rt6_info(entry); + + if (!(rt->rt6i_flags & RTF_GATEWAY)) + goto out; + + dest = rt->rt6i_gateway; +out: + dst_release((struct dst_entry *)rt); +#endif + return dest; +} + +/** + * ovpn_peer_check_by_src - check that skb source is routed via peer + * @ovpn: the openvpn instance to search + * @skb: the packet to extract source address from + * @peer: the peer to check against the source address + * + * Return: true if the peer is matching or false otherwise + */ +bool ovpn_peer_check_by_src(struct ovpn_struct *ovpn, struct sk_buff *skb, + struct ovpn_peer *peer) +{ + bool match = false; + struct in6_addr addr6; + __be32 addr4; + + if (ovpn->mode == OVPN_MODE_P2P) { + /* in P2P mode, no matter the destination, packets are always + * sent to the single peer listening on the other side + */ + rcu_read_lock(); + match = (peer == rcu_dereference(ovpn->peer)); + rcu_read_unlock(); + return match; + } + + /* This function performs a reverse path check, therefore we now + * lookup the nexthop we would use if we wanted to route a packet + * to the source IP. If the nexthop matches the sender we know the + * latter is valid and we allow the packet to come in + */ + + switch (skb_protocol_to_family(skb)) { + case AF_INET: + addr4 = ovpn_nexthop_from_rt4(ovpn, ip_hdr(skb)->saddr); + rcu_read_lock(); + match = (peer == ovpn_peer_get_by_vpn_addr4(ovpn, addr4)); + rcu_read_unlock(); + break; + case AF_INET6: + addr6 = ovpn_nexthop_from_rt6(ovpn, ipv6_hdr(skb)->saddr); + rcu_read_lock(); + match = (peer == ovpn_peer_get_by_vpn_addr6(ovpn, &addr6)); + rcu_read_unlock(); + break; + } + + return match; +} + +void ovpn_peer_hash_vpn_ip(struct ovpn_peer *peer) + __must_hold(&peer->ovpn->peers->lock) +{ + struct hlist_nulls_head *nhead; + + if (peer->vpn_addrs.ipv4.s_addr != htonl(INADDR_ANY)) { + /* remove potential old hashing */ + hlist_nulls_del_init_rcu(&peer->hash_entry_transp_addr); + + nhead = ovpn_get_hash_head(peer->ovpn->peers->by_vpn_addr, + &peer->vpn_addrs.ipv4, + sizeof(peer->vpn_addrs.ipv4)); + hlist_nulls_add_head_rcu(&peer->hash_entry_addr4, nhead); + } + + if (!ipv6_addr_any(&peer->vpn_addrs.ipv6)) { + /* remove potential old hashing */ + hlist_nulls_del_init_rcu(&peer->hash_entry_transp_addr); + + nhead = ovpn_get_hash_head(peer->ovpn->peers->by_vpn_addr, + &peer->vpn_addrs.ipv6, + sizeof(peer->vpn_addrs.ipv6)); + hlist_nulls_add_head_rcu(&peer->hash_entry_addr6, nhead); + } +} + +/** + * ovpn_peer_add_mp - add peer to related tables in a MP instance + * @ovpn: the instance to add the peer to + * @peer: the peer to add + * + * Return: 0 on success or a negative error code otherwise + */ +static int ovpn_peer_add_mp(struct ovpn_struct *ovpn, struct ovpn_peer *peer) +{ + struct sockaddr_storage sa = { 0 }; + struct hlist_nulls_head *nhead; + struct sockaddr_in6 *sa6; + struct sockaddr_in *sa4; + struct ovpn_bind *bind; + struct ovpn_peer *tmp; + size_t salen; + int ret = 0; + + spin_lock_bh(&ovpn->peers->lock); + /* do not add duplicates */ + tmp = ovpn_peer_get_by_id(ovpn, peer->id); + if (tmp) { + ovpn_peer_put(tmp); + ret = -EEXIST; + goto out; + } + + bind = rcu_dereference_protected(peer->bind, true); + /* peers connected via TCP have bind == NULL */ + if (bind) { + switch (bind->remote.in4.sin_family) { + case AF_INET: + sa4 = (struct sockaddr_in *)&sa; + + sa4->sin_family = AF_INET; + sa4->sin_addr.s_addr = bind->remote.in4.sin_addr.s_addr; + sa4->sin_port = bind->remote.in4.sin_port; + salen = sizeof(*sa4); + break; + case AF_INET6: + sa6 = (struct sockaddr_in6 *)&sa; + + sa6->sin6_family = AF_INET6; + sa6->sin6_addr = bind->remote.in6.sin6_addr; + sa6->sin6_port = bind->remote.in6.sin6_port; + salen = sizeof(*sa6); + break; + default: + ret = -EPROTONOSUPPORT; + goto out; + } + + nhead = ovpn_get_hash_head(ovpn->peers->by_transp_addr, &sa, + salen); + hlist_nulls_add_head_rcu(&peer->hash_entry_transp_addr, nhead); + } + + hlist_add_head_rcu(&peer->hash_entry_id, + ovpn_get_hash_head(ovpn->peers->by_id, &peer->id, + sizeof(peer->id))); + + ovpn_peer_hash_vpn_ip(peer); +out: + spin_unlock_bh(&ovpn->peers->lock); + return ret; +} + +/** + * ovpn_peer_add_p2p - add peer to related tables in a P2P instance + * @ovpn: the instance to add the peer to + * @peer: the peer to add + * + * Return: 0 on success or a negative error code otherwise + */ +static int ovpn_peer_add_p2p(struct ovpn_struct *ovpn, struct ovpn_peer *peer) +{ + struct ovpn_peer *tmp; + + spin_lock_bh(&ovpn->lock); + /* in p2p mode it is possible to have a single peer only, therefore the + * old one is released and substituted by the new one + */ + tmp = rcu_dereference_protected(ovpn->peer, + lockdep_is_held(&ovpn->lock)); + if (tmp) { + tmp->delete_reason = OVPN_DEL_PEER_REASON_TEARDOWN; + ovpn_peer_put(tmp); + } + + rcu_assign_pointer(ovpn->peer, peer); + spin_unlock_bh(&ovpn->lock); + + return 0; +} + +/** + * ovpn_peer_add - add peer to the related tables + * @ovpn: the openvpn instance the peer belongs to + * @peer: the peer object to add + * + * Assume refcounter was increased by caller + * + * Return: 0 on success or a negative error code otherwise + */ +int ovpn_peer_add(struct ovpn_struct *ovpn, struct ovpn_peer *peer) +{ + switch (ovpn->mode) { + case OVPN_MODE_MP: + return ovpn_peer_add_mp(ovpn, peer); + case OVPN_MODE_P2P: + return ovpn_peer_add_p2p(ovpn, peer); + default: + return -EOPNOTSUPP; + } +} + +/** + * ovpn_peer_unhash - remove peer reference from all hashtables + * @peer: the peer to remove + * @reason: the delete reason to attach to the peer + */ +static void ovpn_peer_unhash(struct ovpn_peer *peer, + enum ovpn_del_peer_reason reason) + __must_hold(&ovpn->peers->lock) +{ + hlist_del_init_rcu(&peer->hash_entry_id); + + hlist_nulls_del_init_rcu(&peer->hash_entry_addr4); + hlist_nulls_del_init_rcu(&peer->hash_entry_addr6); + hlist_nulls_del_init_rcu(&peer->hash_entry_transp_addr); + + ovpn_peer_put(peer); + peer->delete_reason = reason; +} + +/** + * ovpn_peer_del_mp - delete peer from related tables in a MP instance + * @peer: the peer to delete + * @reason: reason why the peer was deleted (sent to userspace) + * + * Return: 0 on success or a negative error code otherwise + */ +static int ovpn_peer_del_mp(struct ovpn_peer *peer, + enum ovpn_del_peer_reason reason) + __must_hold(&peer->ovpn->peers->lock) +{ + struct ovpn_peer *tmp; + int ret = -ENOENT; + + tmp = ovpn_peer_get_by_id(peer->ovpn, peer->id); + if (tmp == peer) { + ovpn_peer_unhash(peer, reason); + ret = 0; + } + + if (tmp) + ovpn_peer_put(tmp); + + return ret; +} + +/** + * ovpn_peer_del_p2p - delete peer from related tables in a P2P instance + * @peer: the peer to delete + * @reason: reason why the peer was deleted (sent to userspace) + * + * Return: 0 on success or a negative error code otherwise + */ +static int ovpn_peer_del_p2p(struct ovpn_peer *peer, + enum ovpn_del_peer_reason reason) + __must_hold(&peer->ovpn->lock) +{ + struct ovpn_peer *tmp; + + tmp = rcu_dereference_protected(peer->ovpn->peer, + lockdep_is_held(&peer->ovpn->lock)); + if (tmp != peer) { + DEBUG_NET_WARN_ON_ONCE(1); + if (tmp) + ovpn_peer_put(tmp); + + return -ENOENT; + } + + tmp->delete_reason = reason; + RCU_INIT_POINTER(peer->ovpn->peer, NULL); + ovpn_peer_put(tmp); + + return 0; +} + +/** + * ovpn_peer_release_p2p - release peer upon P2P device teardown + * @ovpn: the instance being torn down + */ +void ovpn_peer_release_p2p(struct ovpn_struct *ovpn) +{ + struct ovpn_peer *tmp; + + spin_lock_bh(&ovpn->lock); + tmp = rcu_dereference_protected(ovpn->peer, + lockdep_is_held(&ovpn->lock)); + if (tmp) + ovpn_peer_del_p2p(tmp, OVPN_DEL_PEER_REASON_TEARDOWN); + spin_unlock_bh(&ovpn->lock); +} + +/** + * ovpn_peer_del - delete peer from related tables + * @peer: the peer object to delete + * @reason: reason for deleting peer (will be sent to userspace) + * + * Return: 0 on success or a negative error code otherwise + */ +int ovpn_peer_del(struct ovpn_peer *peer, enum ovpn_del_peer_reason reason) +{ + int ret; + + switch (peer->ovpn->mode) { + case OVPN_MODE_MP: + spin_lock_bh(&peer->ovpn->peers->lock); + ret = ovpn_peer_del_mp(peer, reason); + spin_unlock_bh(&peer->ovpn->peers->lock); + return ret; + case OVPN_MODE_P2P: + spin_lock_bh(&peer->ovpn->lock); + ret = ovpn_peer_del_p2p(peer, reason); + spin_unlock_bh(&peer->ovpn->lock); + return ret; + default: + return -EOPNOTSUPP; + } +} + +static int ovpn_peer_del_nolock(struct ovpn_peer *peer, + enum ovpn_del_peer_reason reason) +{ + switch (peer->ovpn->mode) { + case OVPN_MODE_MP: + return ovpn_peer_del_mp(peer, reason); + case OVPN_MODE_P2P: + return ovpn_peer_del_p2p(peer, reason); + default: + return -EOPNOTSUPP; + } +} + +/** + * ovpn_peers_free - free all peers in the instance + * @ovpn: the instance whose peers should be released + */ +void ovpn_peers_free(struct ovpn_struct *ovpn) +{ + struct hlist_node *tmp; + struct ovpn_peer *peer; + int bkt; + + spin_lock_bh(&ovpn->peers->lock); + hash_for_each_safe(ovpn->peers->by_id, bkt, tmp, peer, hash_entry_id) + ovpn_peer_unhash(peer, OVPN_DEL_PEER_REASON_TEARDOWN); + spin_unlock_bh(&ovpn->peers->lock); +} + +static time64_t ovpn_peer_keepalive_work_single(struct ovpn_peer *peer, + time64_t now) +{ + time64_t next_run1, next_run2, delta; + unsigned long timeout, interval; + bool expired; + + spin_lock_bh(&peer->lock); + /* we expect both timers to be configured at the same time, + * therefore bail out if either is not set + */ + if (!peer->keepalive_timeout || !peer->keepalive_interval) { + spin_unlock_bh(&peer->lock); + return 0; + } + + /* check for peer timeout */ + expired = false; + timeout = peer->keepalive_timeout; + delta = now - peer->last_recv; + if (delta < timeout) { + peer->keepalive_recv_exp = now + timeout - delta; + next_run1 = peer->keepalive_recv_exp; + } else if (peer->keepalive_recv_exp > now) { + next_run1 = peer->keepalive_recv_exp; + } else { + expired = true; + } + + if (expired) { + /* peer is dead -> kill it and move on */ + spin_unlock_bh(&peer->lock); + netdev_dbg(peer->ovpn->dev, "peer %u expired\n", + peer->id); + ovpn_peer_del_nolock(peer, OVPN_DEL_PEER_REASON_EXPIRED); + return 0; + } + + /* check for peer keepalive */ + expired = false; + interval = peer->keepalive_interval; + delta = now - peer->last_sent; + if (delta < interval) { + peer->keepalive_xmit_exp = now + interval - delta; + next_run2 = peer->keepalive_xmit_exp; + } else if (peer->keepalive_xmit_exp > now) { + next_run2 = peer->keepalive_xmit_exp; + } else { + expired = true; + next_run2 = now + interval; + } + spin_unlock_bh(&peer->lock); + + if (expired) { + /* a keepalive packet is required */ + netdev_dbg(peer->ovpn->dev, + "sending keepalive to peer %u\n", + peer->id); + ovpn_xmit_special(peer, ovpn_keepalive_message, + sizeof(ovpn_keepalive_message)); + } + + if (next_run1 < next_run2) + return next_run1; + + return next_run2; +} + +static time64_t ovpn_peer_keepalive_work_mp(struct ovpn_struct *ovpn, + time64_t now) +{ + time64_t tmp_next_run, next_run = 0; + struct hlist_node *tmp; + struct ovpn_peer *peer; + int bkt; + + spin_lock_bh(&ovpn->peers->lock); + hash_for_each_safe(ovpn->peers->by_id, bkt, tmp, peer, hash_entry_id) { + tmp_next_run = ovpn_peer_keepalive_work_single(peer, now); + if (!tmp_next_run) + continue; + + /* the next worker run will be scheduled based on the shortest + * required interval across all peers + */ + if (!next_run || tmp_next_run < next_run) + next_run = tmp_next_run; + } + spin_unlock_bh(&ovpn->peers->lock); + + return next_run; +} + +static time64_t ovpn_peer_keepalive_work_p2p(struct ovpn_struct *ovpn, + time64_t now) +{ + struct ovpn_peer *peer; + time64_t next_run = 0; + + spin_lock_bh(&ovpn->lock); + peer = rcu_dereference_protected(ovpn->peer, + lockdep_is_held(&ovpn->lock)); + if (peer) + next_run = ovpn_peer_keepalive_work_single(peer, now); + spin_unlock_bh(&ovpn->lock); + + return next_run; +} + +/** + * ovpn_peer_keepalive_work - run keepalive logic on each known peer + * @work: pointer to the work member of the related ovpn object + * + * Each peer has two timers (if configured): + * 1. peer timeout: when no data is received for a certain interval, + * the peer is considered dead and it gets killed. + * 2. peer keepalive: when no data is sent to a certain peer for a + * certain interval, a special 'keepalive' packet is explicitly sent. + * + * This function iterates across the whole peer collection while + * checking the timers described above. + */ +void ovpn_peer_keepalive_work(struct work_struct *work) +{ + struct ovpn_struct *ovpn = container_of(work, struct ovpn_struct, + keepalive_work.work); + time64_t next_run = 0, now = ktime_get_real_seconds(); + + switch (ovpn->mode) { + case OVPN_MODE_MP: + next_run = ovpn_peer_keepalive_work_mp(ovpn, now); + break; + case OVPN_MODE_P2P: + next_run = ovpn_peer_keepalive_work_p2p(ovpn, now); + break; + } + + /* prevent rearming if the interface is being destroyed */ + if (next_run > 0 && ovpn->registered) { + netdev_dbg(ovpn->dev, + "scheduling keepalive work: now=%llu next_run=%llu delta=%llu\n", + next_run, now, next_run - now); + schedule_delayed_work(&ovpn->keepalive_work, + (next_run - now) * HZ); + } +} diff --git a/drivers/net/ovpn/peer.h b/drivers/net/ovpn/peer.h new file mode 100644 index 000000000000..1adecd0f79f8 --- /dev/null +++ b/drivers/net/ovpn/peer.h @@ -0,0 +1,165 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_OVPNPEER_H_ +#define _NET_OVPN_OVPNPEER_H_ + +#include +#include + +#include "crypto.h" +#include "stats.h" + +/** + * struct ovpn_peer - the main remote peer object + * @ovpn: main openvpn instance this peer belongs to + * @id: unique identifier + * @vpn_addrs: IP addresses assigned over the tunnel + * @vpn_addrs.ipv4: IPv4 assigned to peer on the tunnel + * @vpn_addrs.ipv6: IPv6 assigned to peer on the tunnel + * @hash_entry_id: entry in the peer ID hashtable + * @hash_entry_addr4: entry in the peer IPv4 hashtable + * @hash_entry_addr6: entry in the peer IPv6 hashtable + * @hash_entry_transp_addr: entry in the peer transport address hashtable + * @sock: the socket being used to talk to this peer + * @tcp: keeps track of TCP specific state + * @tcp.strp: stream parser context (TCP only) + * @tcp.tx_work: work for deferring outgoing packet processing (TCP only) + * @tcp.user_queue: received packets that have to go to userspace (TCP only) + * @tcp.tx_in_progress: true if TX is already ongoing (TCP only) + * @tcp.out_msg.skb: packet scheduled for sending (TCP only) + * @tcp.out_msg.offset: offset where next send should start (TCP only) + * @tcp.out_msg.len: remaining data to send within packet (TCP only) + * @tcp.sk_cb.sk_data_ready: pointer to original cb (TCP only) + * @tcp.sk_cb.sk_write_space: pointer to original cb (TCP only) + * @tcp.sk_cb.prot: pointer to original prot object (TCP only) + * @tcp.sk_cb.ops: pointer to the original prot_ops object (TCP only) + * @crypto: the crypto configuration (ciphers, keys, etc..) + * @dst_cache: cache for dst_entry used to send to peer + * @bind: remote peer binding + * @keepalive_interval: seconds after which a new keepalive should be sent + * @keepalive_xmit_exp: future timestamp when next keepalive should be sent + * @last_sent: timestamp of the last successfully sent packet + * @keepalive_timeout: seconds after which an inactive peer is considered dead + * @keepalive_recv_exp: future timestamp when the peer should expire + * @last_recv: timestamp of the last authenticated received packet + * @halt: true if ovpn_peer_mark_delete was called + * @vpn_stats: per-peer in-VPN TX/RX stays + * @link_stats: per-peer link/transport TX/RX stats + * @delete_reason: why peer was deleted (i.e. timeout, transport error, ..) + * @lock: protects binding to peer (bind) + * @refcount: reference counter + * @rcu: used to free peer in an RCU safe way + * @delete_work: deferred cleanup work, used to notify userspace + */ +struct ovpn_peer { + struct ovpn_struct *ovpn; + u32 id; + struct { + struct in_addr ipv4; + struct in6_addr ipv6; + } vpn_addrs; + struct hlist_node hash_entry_id; + struct hlist_nulls_node hash_entry_addr4; + struct hlist_nulls_node hash_entry_addr6; + struct hlist_nulls_node hash_entry_transp_addr; + struct ovpn_socket *sock; + + /* state of the TCP reading. Needed to keep track of how much of a + * single packet has already been read from the stream and how much is + * missing + */ + struct { + struct strparser strp; + struct work_struct tx_work; + struct sk_buff_head user_queue; + bool tx_in_progress; + + struct { + struct sk_buff *skb; + int offset; + int len; + } out_msg; + + struct { + void (*sk_data_ready)(struct sock *sk); + void (*sk_write_space)(struct sock *sk); + struct proto *prot; + const struct proto_ops *ops; + } sk_cb; + } tcp; + struct ovpn_crypto_state crypto; + struct dst_cache dst_cache; + struct ovpn_bind __rcu *bind; + unsigned long keepalive_interval; + unsigned long keepalive_xmit_exp; + time64_t last_sent; + unsigned long keepalive_timeout; + unsigned long keepalive_recv_exp; + time64_t last_recv; + bool halt; + struct ovpn_peer_stats vpn_stats; + struct ovpn_peer_stats link_stats; + enum ovpn_del_peer_reason delete_reason; + spinlock_t lock; /* protects bind */ + struct kref refcount; + struct rcu_head rcu; + struct work_struct delete_work; +}; + +/** + * ovpn_peer_hold - increase reference counter + * @peer: the peer whose counter should be increased + * + * Return: true if the counter was increased or false if it was zero already + */ +static inline bool ovpn_peer_hold(struct ovpn_peer *peer) +{ + return kref_get_unless_zero(&peer->refcount); +} + +void ovpn_peer_release(struct ovpn_peer *peer); +void ovpn_peer_release_kref(struct kref *kref); + +/** + * ovpn_peer_put - decrease reference counter + * @peer: the peer whose counter should be decreased + */ +static inline void ovpn_peer_put(struct ovpn_peer *peer) +{ + kref_put(&peer->refcount, ovpn_peer_release_kref); +} + +struct ovpn_peer *ovpn_peer_new(struct ovpn_struct *ovpn, u32 id); +int ovpn_peer_add(struct ovpn_struct *ovpn, struct ovpn_peer *peer); +int ovpn_peer_del(struct ovpn_peer *peer, enum ovpn_del_peer_reason reason); +void ovpn_peer_release_p2p(struct ovpn_struct *ovpn); +void ovpn_peers_free(struct ovpn_struct *ovpn); + +struct ovpn_peer *ovpn_peer_get_by_transp_addr(struct ovpn_struct *ovpn, + struct sk_buff *skb); +struct ovpn_peer *ovpn_peer_get_by_id(struct ovpn_struct *ovpn, u32 peer_id); +struct ovpn_peer *ovpn_peer_get_by_dst(struct ovpn_struct *ovpn, + struct sk_buff *skb); +void ovpn_peer_hash_vpn_ip(struct ovpn_peer *peer); +bool ovpn_peer_check_by_src(struct ovpn_struct *ovpn, struct sk_buff *skb, + struct ovpn_peer *peer); + +void ovpn_peer_keepalive_set(struct ovpn_peer *peer, u32 interval, u32 timeout); +void ovpn_peer_keepalive_work(struct work_struct *work); + +void ovpn_peer_update_local_endpoint(struct ovpn_peer *peer, + struct sk_buff *skb); + +void ovpn_peer_float(struct ovpn_peer *peer, struct sk_buff *skb); +int ovpn_peer_reset_sockaddr(struct ovpn_peer *peer, + const struct sockaddr_storage *ss, + const u8 *local_ip); + +#endif /* _NET_OVPN_OVPNPEER_H_ */ diff --git a/drivers/net/ovpn/pktid.c b/drivers/net/ovpn/pktid.c new file mode 100644 index 000000000000..96dc87635670 --- /dev/null +++ b/drivers/net/ovpn/pktid.c @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: Antonio Quartulli + * James Yonan + */ + +#include +#include +#include +#include +#include + +#include "ovpnstruct.h" +#include "main.h" +#include "packet.h" +#include "pktid.h" + +void ovpn_pktid_xmit_init(struct ovpn_pktid_xmit *pid) +{ + atomic64_set(&pid->seq_num, 1); +} + +void ovpn_pktid_recv_init(struct ovpn_pktid_recv *pr) +{ + memset(pr, 0, sizeof(*pr)); + spin_lock_init(&pr->lock); +} + +/* Packet replay detection. + * Allows ID backtrack of up to REPLAY_WINDOW_SIZE - 1. + */ +int ovpn_pktid_recv(struct ovpn_pktid_recv *pr, u32 pkt_id, u32 pkt_time) +{ + const unsigned long now = jiffies; + int ret; + + /* ID must not be zero */ + if (unlikely(pkt_id == 0)) + return -EINVAL; + + spin_lock_bh(&pr->lock); + + /* expire backtracks at or below pr->id after PKTID_RECV_EXPIRE time */ + if (unlikely(time_after_eq(now, pr->expire))) + pr->id_floor = pr->id; + + /* time changed? */ + if (unlikely(pkt_time != pr->time)) { + if (pkt_time > pr->time) { + /* time moved forward, accept */ + pr->base = 0; + pr->extent = 0; + pr->id = 0; + pr->time = pkt_time; + pr->id_floor = 0; + } else { + /* time moved backward, reject */ + ret = -ETIME; + goto out; + } + } + + if (likely(pkt_id == pr->id + 1)) { + /* well-formed ID sequence (incremented by 1) */ + pr->base = REPLAY_INDEX(pr->base, -1); + pr->history[pr->base / 8] |= (1 << (pr->base % 8)); + if (pr->extent < REPLAY_WINDOW_SIZE) + ++pr->extent; + pr->id = pkt_id; + } else if (pkt_id > pr->id) { + /* ID jumped forward by more than one */ + const unsigned int delta = pkt_id - pr->id; + + if (delta < REPLAY_WINDOW_SIZE) { + unsigned int i; + + pr->base = REPLAY_INDEX(pr->base, -delta); + pr->history[pr->base / 8] |= (1 << (pr->base % 8)); + pr->extent += delta; + if (pr->extent > REPLAY_WINDOW_SIZE) + pr->extent = REPLAY_WINDOW_SIZE; + for (i = 1; i < delta; ++i) { + unsigned int newb = REPLAY_INDEX(pr->base, i); + + pr->history[newb / 8] &= ~BIT(newb % 8); + } + } else { + pr->base = 0; + pr->extent = REPLAY_WINDOW_SIZE; + memset(pr->history, 0, sizeof(pr->history)); + pr->history[0] = 1; + } + pr->id = pkt_id; + } else { + /* ID backtrack */ + const unsigned int delta = pr->id - pkt_id; + + if (delta > pr->max_backtrack) + pr->max_backtrack = delta; + if (delta < pr->extent) { + if (pkt_id > pr->id_floor) { + const unsigned int ri = REPLAY_INDEX(pr->base, + delta); + u8 *p = &pr->history[ri / 8]; + const u8 mask = (1 << (ri % 8)); + + if (*p & mask) { + ret = -EINVAL; + goto out; + } + *p |= mask; + } else { + ret = -EINVAL; + goto out; + } + } else { + ret = -EINVAL; + goto out; + } + } + + pr->expire = now + PKTID_RECV_EXPIRE; + ret = 0; +out: + spin_unlock_bh(&pr->lock); + return ret; +} diff --git a/drivers/net/ovpn/pktid.h b/drivers/net/ovpn/pktid.h new file mode 100644 index 000000000000..fe02f0667e1a --- /dev/null +++ b/drivers/net/ovpn/pktid.h @@ -0,0 +1,87 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: Antonio Quartulli + * James Yonan + */ + +#ifndef _NET_OVPN_OVPNPKTID_H_ +#define _NET_OVPN_OVPNPKTID_H_ + +#include "packet.h" + +/* If no packets received for this length of time, set a backtrack floor + * at highest received packet ID thus far. + */ +#define PKTID_RECV_EXPIRE (30 * HZ) + +/* Packet-ID state for transmitter */ +struct ovpn_pktid_xmit { + atomic64_t seq_num; +}; + +/* replay window sizing in bytes = 2^REPLAY_WINDOW_ORDER */ +#define REPLAY_WINDOW_ORDER 8 + +#define REPLAY_WINDOW_BYTES BIT(REPLAY_WINDOW_ORDER) +#define REPLAY_WINDOW_SIZE (REPLAY_WINDOW_BYTES * 8) +#define REPLAY_INDEX(base, i) (((base) + (i)) & (REPLAY_WINDOW_SIZE - 1)) + +/* Packet-ID state for receiver. + * Other than lock member, can be zeroed to initialize. + */ +struct ovpn_pktid_recv { + /* "sliding window" bitmask of recent packet IDs received */ + u8 history[REPLAY_WINDOW_BYTES]; + /* bit position of deque base in history */ + unsigned int base; + /* extent (in bits) of deque in history */ + unsigned int extent; + /* expiration of history in jiffies */ + unsigned long expire; + /* highest sequence number received */ + u32 id; + /* highest time stamp received */ + u32 time; + /* we will only accept backtrack IDs > id_floor */ + u32 id_floor; + unsigned int max_backtrack; + /* protects entire pktd ID state */ + spinlock_t lock; +}; + +/* Get the next packet ID for xmit */ +static inline int ovpn_pktid_xmit_next(struct ovpn_pktid_xmit *pid, u32 *pktid) +{ + const s64 seq_num = atomic64_fetch_add_unless(&pid->seq_num, 1, + 0x100000000LL); + /* when the 32bit space is over, we return an error because the packet + * ID is used to create the cipher IV and we do not want to reuse the + * same value more than once + */ + if (unlikely(seq_num == 0x100000000LL)) + return -ERANGE; + + *pktid = (u32)seq_num; + + return 0; +} + +/* Write 12-byte AEAD IV to dest */ +static inline void ovpn_pktid_aead_write(const u32 pktid, + const struct ovpn_nonce_tail *nt, + unsigned char *dest) +{ + *(__force __be32 *)(dest) = htonl(pktid); + BUILD_BUG_ON(4 + sizeof(struct ovpn_nonce_tail) != NONCE_SIZE); + memcpy(dest + 4, nt->u8, sizeof(struct ovpn_nonce_tail)); +} + +void ovpn_pktid_xmit_init(struct ovpn_pktid_xmit *pid); +void ovpn_pktid_recv_init(struct ovpn_pktid_recv *pr); + +int ovpn_pktid_recv(struct ovpn_pktid_recv *pr, u32 pkt_id, u32 pkt_time); + +#endif /* _NET_OVPN_OVPNPKTID_H_ */ diff --git a/drivers/net/ovpn/proto.h b/drivers/net/ovpn/proto.h new file mode 100644 index 000000000000..0de8bafadc89 --- /dev/null +++ b/drivers/net/ovpn/proto.h @@ -0,0 +1,104 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: Antonio Quartulli + * James Yonan + */ + +#ifndef _NET_OVPN_OVPNPROTO_H_ +#define _NET_OVPN_OVPNPROTO_H_ + +#include "main.h" + +#include + +/* Methods for operating on the initial command + * byte of the OpenVPN protocol. + */ + +/* packet opcode (high 5 bits) and key-id (low 3 bits) are combined in + * one byte + */ +#define OVPN_KEY_ID_MASK 0x07 +#define OVPN_OPCODE_SHIFT 3 +#define OVPN_OPCODE_MASK 0x1F +/* upper bounds on opcode and key ID */ +#define OVPN_KEY_ID_MAX (OVPN_KEY_ID_MASK + 1) +#define OVPN_OPCODE_MAX (OVPN_OPCODE_MASK + 1) +/* packet opcodes of interest to us */ +#define OVPN_DATA_V1 6 /* data channel V1 packet */ +#define OVPN_DATA_V2 9 /* data channel V2 packet */ +/* size of initial packet opcode */ +#define OVPN_OP_SIZE_V1 1 +#define OVPN_OP_SIZE_V2 4 +#define OVPN_PEER_ID_MASK 0x00FFFFFF +#define OVPN_PEER_ID_UNDEF 0x00FFFFFF +/* first byte of exit message */ +#define OVPN_EXPLICIT_EXIT_NOTIFY_FIRST_BYTE 0x28 + +/** + * ovpn_opcode_from_skb - extract OP code from skb at specified offset + * @skb: the packet to extract the OP code from + * @offset: the offset in the data buffer where the OP code is located + * + * Note: this function assumes that the skb head was pulled enough + * to access the first byte. + * + * Return: the OP code + */ +static inline u8 ovpn_opcode_from_skb(const struct sk_buff *skb, u16 offset) +{ + u8 byte = *(skb->data + offset); + + return byte >> OVPN_OPCODE_SHIFT; +} + +/** + * ovpn_peer_id_from_skb - extract peer ID from skb at specified offset + * @skb: the packet to extract the OP code from + * @offset: the offset in the data buffer where the OP code is located + * + * Note: this function assumes that the skb head was pulled enough + * to access the first 4 bytes. + * + * Return: the peer ID. + */ +static inline u32 ovpn_peer_id_from_skb(const struct sk_buff *skb, u16 offset) +{ + return ntohl(*(__be32 *)(skb->data + offset)) & OVPN_PEER_ID_MASK; +} + +/** + * ovpn_key_id_from_skb - extract key ID from the skb head + * @skb: the packet to extract the key ID code from + * + * Note: this function assumes that the skb head was pulled enough + * to access the first byte. + * + * Return: the key ID + */ +static inline u8 ovpn_key_id_from_skb(const struct sk_buff *skb) +{ + return *skb->data & OVPN_KEY_ID_MASK; +} + +/** + * ovpn_opcode_compose - combine OP code, key ID and peer ID to wire format + * @opcode: the OP code + * @key_id: the key ID + * @peer_id: the peer ID + * + * Return: a 4 bytes integer obtained combining all input values following the + * OpenVPN wire format. This integer can then be written to the packet header. + */ +static inline u32 ovpn_opcode_compose(u8 opcode, u8 key_id, u32 peer_id) +{ + const u8 op = (opcode << OVPN_OPCODE_SHIFT) | + (key_id & OVPN_KEY_ID_MASK); + + return (op << 24) | (peer_id & OVPN_PEER_ID_MASK); +} + +#endif /* _NET_OVPN_OVPNPROTO_H_ */ diff --git a/drivers/net/ovpn/skb.h b/drivers/net/ovpn/skb.h new file mode 100644 index 000000000000..96afa01466ab --- /dev/null +++ b/drivers/net/ovpn/skb.h @@ -0,0 +1,56 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: Antonio Quartulli + * James Yonan + */ + +#ifndef _NET_OVPN_SKB_H_ +#define _NET_OVPN_SKB_H_ + +#include +#include +#include +#include +#include +#include + +struct ovpn_cb { + struct ovpn_peer *peer; + struct ovpn_crypto_key_slot *ks; + struct aead_request *req; + struct scatterlist *sg; + unsigned int orig_len; + unsigned int payload_offset; +}; + +static inline struct ovpn_cb *ovpn_skb_cb(struct sk_buff *skb) +{ + BUILD_BUG_ON(sizeof(struct ovpn_cb) > sizeof(skb->cb)); + return (struct ovpn_cb *)skb->cb; +} + +/* Return IP protocol version from skb header. + * Return 0 if protocol is not IPv4/IPv6 or cannot be read. + */ +static inline __be16 ovpn_ip_check_protocol(struct sk_buff *skb) +{ + __be16 proto = 0; + + /* skb could be non-linear, + * make sure IP header is in non-fragmented part + */ + if (!pskb_network_may_pull(skb, sizeof(struct iphdr))) + return 0; + + if (ip_hdr(skb)->version == 4) + proto = htons(ETH_P_IP); + else if (ip_hdr(skb)->version == 6) + proto = htons(ETH_P_IPV6); + + return proto; +} + +#endif /* _NET_OVPN_SKB_H_ */ diff --git a/drivers/net/ovpn/socket.c b/drivers/net/ovpn/socket.c new file mode 100644 index 000000000000..a0c2a02ff205 --- /dev/null +++ b/drivers/net/ovpn/socket.c @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include +#include + +#include "ovpnstruct.h" +#include "main.h" +#include "io.h" +#include "peer.h" +#include "socket.h" +#include "tcp.h" +#include "udp.h" + +static void ovpn_socket_detach(struct socket *sock) +{ + if (!sock) + return; + + if (sock->sk->sk_protocol == IPPROTO_UDP) + ovpn_udp_socket_detach(sock); + else if (sock->sk->sk_protocol == IPPROTO_TCP) + ovpn_tcp_socket_detach(sock); + + sockfd_put(sock); +} + +static void ovpn_socket_release_work(struct work_struct *work) +{ + struct ovpn_socket *sock = container_of(work, struct ovpn_socket, work); + + ovpn_socket_detach(sock->sock); + kfree_rcu(sock, rcu); +} + +static void ovpn_socket_schedule_release(struct ovpn_socket *sock) +{ + INIT_WORK(&sock->work, ovpn_socket_release_work); + schedule_work(&sock->work); +} + +/** + * ovpn_socket_release_kref - kref_put callback + * @kref: the kref object + */ +void ovpn_socket_release_kref(struct kref *kref) +{ + struct ovpn_socket *sock = container_of(kref, struct ovpn_socket, + refcount); + + ovpn_socket_schedule_release(sock); +} + +static bool ovpn_socket_hold(struct ovpn_socket *sock) +{ + return kref_get_unless_zero(&sock->refcount); +} + +static struct ovpn_socket *ovpn_socket_get(struct socket *sock) +{ + struct ovpn_socket *ovpn_sock; + + rcu_read_lock(); + ovpn_sock = rcu_dereference_sk_user_data(sock->sk); + if (!ovpn_socket_hold(ovpn_sock)) { + pr_warn("%s: found ovpn_socket with ref = 0\n", __func__); + ovpn_sock = NULL; + } + rcu_read_unlock(); + + return ovpn_sock; +} + +static int ovpn_socket_attach(struct socket *sock, struct ovpn_peer *peer) +{ + int ret = -EOPNOTSUPP; + + if (!sock || !peer) + return -EINVAL; + + if (sock->sk->sk_protocol == IPPROTO_UDP) + ret = ovpn_udp_socket_attach(sock, peer->ovpn); + else if (sock->sk->sk_protocol == IPPROTO_TCP) + ret = ovpn_tcp_socket_attach(sock, peer); + + return ret; +} + +/* Retrieve the corresponding ovpn object from a UDP socket + * rcu_read_lock must be held on entry + */ +struct ovpn_struct *ovpn_from_udp_sock(struct sock *sk) +{ + struct ovpn_socket *ovpn_sock; + + if (unlikely(READ_ONCE(udp_sk(sk)->encap_type) != UDP_ENCAP_OVPNINUDP)) + return NULL; + + ovpn_sock = rcu_dereference_sk_user_data(sk); + if (unlikely(!ovpn_sock)) + return NULL; + + /* make sure that sk matches our stored transport socket */ + if (unlikely(!ovpn_sock->sock || sk != ovpn_sock->sock->sk)) + return NULL; + + return ovpn_sock->ovpn; +} + +/** + * ovpn_socket_new - create a new socket and initialize it + * @sock: the kernel socket to embed + * @peer: the peer reachable via this socket + * + * Return: an openvpn socket on success or a negative error code otherwise + */ +struct ovpn_socket *ovpn_socket_new(struct socket *sock, struct ovpn_peer *peer) +{ + struct ovpn_socket *ovpn_sock; + int ret; + + ret = ovpn_socket_attach(sock, peer); + if (ret < 0 && ret != -EALREADY) + return ERR_PTR(ret); + + /* if this socket is already owned by this interface, just increase the + * refcounter and use it as expected. + * + * Since UDP sockets can be used to talk to multiple remote endpoints, + * openvpn normally instantiates only one socket and shares it among all + * its peers. For this reason, when we find out that a socket is already + * used for some other peer in *this* instance, we can happily increase + * its refcounter and use it normally. + */ + if (ret == -EALREADY) { + /* caller is expected to increase the sock refcounter before + * passing it to this function. For this reason we drop it if + * not needed, like when this socket is already owned. + */ + ovpn_sock = ovpn_socket_get(sock); + sockfd_put(sock); + return ovpn_sock; + } + + ovpn_sock = kzalloc(sizeof(*ovpn_sock), GFP_KERNEL); + if (!ovpn_sock) { + ret = -ENOMEM; + goto err; + } + + ovpn_sock->sock = sock; + kref_init(&ovpn_sock->refcount); + + /* TCP sockets are per-peer, therefore they are linked to their unique + * peer + */ + if (sock->sk->sk_protocol == IPPROTO_TCP) { + ovpn_sock->peer = peer; + } else { + /* in UDP we only link the ovpn instance since the socket is + * shared among multiple peers + */ + ovpn_sock->ovpn = peer->ovpn; + } + + rcu_assign_sk_user_data(sock->sk, ovpn_sock); + + return ovpn_sock; +err: + ovpn_socket_detach(sock); + return ERR_PTR(ret); +} diff --git a/drivers/net/ovpn/socket.h b/drivers/net/ovpn/socket.h new file mode 100644 index 000000000000..bc22fff453ad --- /dev/null +++ b/drivers/net/ovpn/socket.h @@ -0,0 +1,55 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_SOCK_H_ +#define _NET_OVPN_SOCK_H_ + +#include +#include +#include + +struct ovpn_struct; +struct ovpn_peer; + +/** + * struct ovpn_socket - a kernel socket referenced in the ovpn code + * @ovpn: ovpn instance owning this socket (UDP only) + * @peer: unique peer transmitting over this socket (TCP only) + * @sock: the low level sock object + * @refcount: amount of contexts currently referencing this object + * @work: member used to schedule release routine (it may block) + * @rcu: member used to schedule RCU destructor callback + */ +struct ovpn_socket { + union { + struct ovpn_struct *ovpn; + struct ovpn_peer *peer; + }; + + struct socket *sock; + struct kref refcount; + struct work_struct work; + struct rcu_head rcu; +}; + +void ovpn_socket_release_kref(struct kref *kref); + +/** + * ovpn_socket_put - decrease reference counter + * @sock: the socket whose reference counter should be decreased + */ +static inline void ovpn_socket_put(struct ovpn_socket *sock) +{ + kref_put(&sock->refcount, ovpn_socket_release_kref); +} + +struct ovpn_socket *ovpn_socket_new(struct socket *sock, + struct ovpn_peer *peer); + +#endif /* _NET_OVPN_SOCK_H_ */ diff --git a/drivers/net/ovpn/stats.c b/drivers/net/ovpn/stats.c new file mode 100644 index 000000000000..a383842c3449 --- /dev/null +++ b/drivers/net/ovpn/stats.c @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include + +#include "stats.h" + +void ovpn_peer_stats_init(struct ovpn_peer_stats *ps) +{ + atomic64_set(&ps->rx.bytes, 0); + atomic64_set(&ps->rx.packets, 0); + + atomic64_set(&ps->tx.bytes, 0); + atomic64_set(&ps->tx.packets, 0); +} diff --git a/drivers/net/ovpn/stats.h b/drivers/net/ovpn/stats.h new file mode 100644 index 000000000000..868f49d25eaa --- /dev/null +++ b/drivers/net/ovpn/stats.h @@ -0,0 +1,47 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + * Lev Stipakov + */ + +#ifndef _NET_OVPN_OVPNSTATS_H_ +#define _NET_OVPN_OVPNSTATS_H_ + +/* one stat */ +struct ovpn_peer_stat { + atomic64_t bytes; + atomic64_t packets; +}; + +/* rx and tx stats combined */ +struct ovpn_peer_stats { + struct ovpn_peer_stat rx; + struct ovpn_peer_stat tx; +}; + +void ovpn_peer_stats_init(struct ovpn_peer_stats *ps); + +static inline void ovpn_peer_stats_increment(struct ovpn_peer_stat *stat, + const unsigned int n) +{ + atomic64_add(n, &stat->bytes); + atomic64_inc(&stat->packets); +} + +static inline void ovpn_peer_stats_increment_rx(struct ovpn_peer_stats *stats, + const unsigned int n) +{ + ovpn_peer_stats_increment(&stats->rx, n); +} + +static inline void ovpn_peer_stats_increment_tx(struct ovpn_peer_stats *stats, + const unsigned int n) +{ + ovpn_peer_stats_increment(&stats->tx, n); +} + +#endif /* _NET_OVPN_OVPNSTATS_H_ */ diff --git a/drivers/net/ovpn/tcp.c b/drivers/net/ovpn/tcp.c new file mode 100644 index 000000000000..d6f377a116ef --- /dev/null +++ b/drivers/net/ovpn/tcp.c @@ -0,0 +1,506 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel offload + * + * Copyright (C) 2019-2024 OpenVPN, Inc. + * + * Author: Antonio Quartulli + */ + +#include +#include +#include +#include +#include +#include + +#include "ovpnstruct.h" +#include "main.h" +#include "io.h" +#include "packet.h" +#include "peer.h" +#include "proto.h" +#include "skb.h" +#include "tcp.h" + +static struct proto ovpn_tcp_prot __ro_after_init; +static struct proto_ops ovpn_tcp_ops __ro_after_init; +static struct proto ovpn_tcp6_prot; +static struct proto_ops ovpn_tcp6_ops; +static DEFINE_MUTEX(tcp6_prot_mutex); + +static int ovpn_tcp_parse(struct strparser *strp, struct sk_buff *skb) +{ + struct strp_msg *rxm = strp_msg(skb); + __be16 blen; + u16 len; + int err; + + /* when packets are written to the TCP stream, they are prepended with + * two bytes indicating the actual packet size. + * Here we read those two bytes and move the skb data pointer to the + * beginning of the packet + */ + + if (skb->len < rxm->offset + 2) + return 0; + + err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen)); + if (err < 0) + return err; + + len = be16_to_cpu(blen); + if (len < 2) + return -EINVAL; + + return len + 2; +} + +/* queue skb for sending to userspace via recvmsg on the socket */ +static void ovpn_tcp_to_userspace(struct ovpn_peer *peer, struct sock *sk, + struct sk_buff *skb) +{ + skb_set_owner_r(skb, sk); + memset(skb->cb, 0, sizeof(skb->cb)); + skb_queue_tail(&peer->tcp.user_queue, skb); + peer->tcp.sk_cb.sk_data_ready(sk); +} + +static void ovpn_tcp_rcv(struct strparser *strp, struct sk_buff *skb) +{ + struct ovpn_peer *peer = container_of(strp, struct ovpn_peer, tcp.strp); + struct strp_msg *msg = strp_msg(skb); + size_t pkt_len = msg->full_len - 2; + size_t off = msg->offset + 2; + + /* ensure skb->data points to the beginning of the openvpn packet */ + if (!pskb_pull(skb, off)) { + net_warn_ratelimited("%s: packet too small\n", + peer->ovpn->dev->name); + goto err; + } + + /* strparser does not trim the skb for us, therefore we do it now */ + if (pskb_trim(skb, pkt_len) != 0) { + net_warn_ratelimited("%s: trimming skb failed\n", + peer->ovpn->dev->name); + goto err; + } + + /* we need the first byte of data to be accessible + * to extract the opcode and the key ID later on + */ + if (!pskb_may_pull(skb, 1)) { + net_warn_ratelimited("%s: packet too small to fetch opcode\n", + peer->ovpn->dev->name); + goto err; + } + + /* DATA_V2 packets are handled in kernel, the rest goes to user space */ + if (likely(ovpn_opcode_from_skb(skb, 0) == OVPN_DATA_V2)) { + /* hold reference to peer as required by ovpn_recv(). + * + * NOTE: in this context we should already be holding a + * reference to this peer, therefore ovpn_peer_hold() is + * not expected to fail + */ + if (WARN_ON(!ovpn_peer_hold(peer))) + goto err; + + ovpn_recv(peer, skb); + } else { + /* The packet size header must be there when sending the packet + * to userspace, therefore we put it back + */ + skb_push(skb, 2); + ovpn_tcp_to_userspace(peer, strp->sk, skb); + } + + return; +err: + netdev_err(peer->ovpn->dev, + "cannot process incoming TCP data for peer %u\n", peer->id); + dev_core_stats_rx_dropped_inc(peer->ovpn->dev); + kfree_skb(skb); + ovpn_peer_del(peer, OVPN_DEL_PEER_REASON_TRANSPORT_ERROR); +} + +static int ovpn_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + int err = 0, off, copied = 0, ret; + struct ovpn_socket *sock; + struct ovpn_peer *peer; + struct sk_buff *skb; + + rcu_read_lock(); + sock = rcu_dereference_sk_user_data(sk); + if (!sock || !sock->peer) { + rcu_read_unlock(); + return -EBADF; + } + /* we take a reference to the peer linked to this TCP socket, because + * in turn the peer holds a reference to the socket itself. + * By doing so we also ensure that the peer stays alive along with + * the socket while executing this function + */ + ovpn_peer_hold(sock->peer); + peer = sock->peer; + rcu_read_unlock(); + + skb = __skb_recv_datagram(sk, &peer->tcp.user_queue, flags, &off, &err); + if (!skb) { + if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN) { + ret = 0; + goto out; + } + ret = err; + goto out; + } + + copied = len; + if (copied > skb->len) + copied = skb->len; + else if (copied < skb->len) + msg->msg_flags |= MSG_TRUNC; + + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (unlikely(err)) { + kfree_skb(skb); + ret = err; + goto out; + } + + if (flags & MSG_TRUNC) + copied = skb->len; + kfree_skb(skb); + ret = copied; +out: + ovpn_peer_put(peer); + return ret; +} + +void ovpn_tcp_socket_detach(struct socket *sock) +{ + struct ovpn_socket *ovpn_sock; + struct ovpn_peer *peer; + + if (!sock) + return; + + rcu_read_lock(); + ovpn_sock = rcu_dereference_sk_user_data(sock->sk); + + if (!ovpn_sock->peer) { + rcu_read_unlock(); + return; + } + + peer = ovpn_sock->peer; + strp_stop(&peer->tcp.strp); + + skb_queue_purge(&peer->tcp.user_queue); + + /* restore CBs that were saved in ovpn_sock_set_tcp_cb() */ + sock->sk->sk_data_ready = peer->tcp.sk_cb.sk_data_ready; + sock->sk->sk_write_space = peer->tcp.sk_cb.sk_write_space; + sock->sk->sk_prot = peer->tcp.sk_cb.prot; + sock->sk->sk_socket->ops = peer->tcp.sk_cb.ops; + rcu_assign_sk_user_data(sock->sk, NULL); + + rcu_read_unlock(); + + /* cancel any ongoing work. Done after removing the CBs so that these + * workers cannot be re-armed + */ + cancel_work_sync(&peer->tcp.tx_work); + strp_done(&peer->tcp.strp); +} + +static void ovpn_tcp_send_sock(struct ovpn_peer *peer) +{ + struct sk_buff *skb = peer->tcp.out_msg.skb; + + if (!skb) + return; + + if (peer->tcp.tx_in_progress) + return; + + peer->tcp.tx_in_progress = true; + + do { + int ret = skb_send_sock_locked(peer->sock->sock->sk, skb, + peer->tcp.out_msg.offset, + peer->tcp.out_msg.len); + if (unlikely(ret < 0)) { + if (ret == -EAGAIN) + goto out; + + net_warn_ratelimited("%s: TCP error to peer %u: %d\n", + peer->ovpn->dev->name, peer->id, + ret); + + /* in case of TCP error we can't recover the VPN + * stream therefore we abort the connection + */ + ovpn_peer_del(peer, + OVPN_DEL_PEER_REASON_TRANSPORT_ERROR); + break; + } + + peer->tcp.out_msg.len -= ret; + peer->tcp.out_msg.offset += ret; + } while (peer->tcp.out_msg.len > 0); + + if (!peer->tcp.out_msg.len) + dev_sw_netstats_tx_add(peer->ovpn->dev, 1, skb->len); + + kfree_skb(peer->tcp.out_msg.skb); + peer->tcp.out_msg.skb = NULL; + peer->tcp.out_msg.len = 0; + peer->tcp.out_msg.offset = 0; + +out: + peer->tcp.tx_in_progress = false; +} + +static void ovpn_tcp_tx_work(struct work_struct *work) +{ + struct ovpn_peer *peer; + + peer = container_of(work, struct ovpn_peer, tcp.tx_work); + + lock_sock(peer->sock->sock->sk); + ovpn_tcp_send_sock(peer); + release_sock(peer->sock->sock->sk); +} + +void ovpn_tcp_send_sock_skb(struct ovpn_peer *peer, struct sk_buff *skb) +{ + if (peer->tcp.out_msg.skb) + return; + + peer->tcp.out_msg.skb = skb; + peer->tcp.out_msg.len = skb->len; + peer->tcp.out_msg.offset = 0; + + ovpn_tcp_send_sock(peer); +} + +static int ovpn_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ + struct ovpn_socket *sock; + int ret, linear = PAGE_SIZE; + struct ovpn_peer *peer; + struct sk_buff *skb; + + rcu_read_lock(); + sock = rcu_dereference_sk_user_data(sk); + peer = sock->peer; + if (unlikely(!ovpn_peer_hold(peer))) { + rcu_read_unlock(); + return -EIO; + } + rcu_read_unlock(); + + if (msg->msg_flags & ~MSG_DONTWAIT) { + ret = -EOPNOTSUPP; + goto peer_free; + } + + lock_sock(sk); + + if (peer->tcp.out_msg.skb) { + ret = -EAGAIN; + goto unlock; + } + + if (size < linear) + linear = size; + + skb = sock_alloc_send_pskb(sk, linear, size - linear, + msg->msg_flags & MSG_DONTWAIT, &ret, 0); + if (!skb) { + net_err_ratelimited("%s: skb alloc failed: %d\n", + sock->peer->ovpn->dev->name, ret); + goto unlock; + } + + skb_put(skb, linear); + skb->len = size; + skb->data_len = size - linear; + + ret = skb_copy_datagram_from_iter(skb, 0, &msg->msg_iter, size); + if (ret) { + kfree_skb(skb); + net_err_ratelimited("%s: skb copy from iter failed: %d\n", + sock->peer->ovpn->dev->name, ret); + goto unlock; + } + + ovpn_tcp_send_sock_skb(sock->peer, skb); + ret = size; +unlock: + release_sock(sk); +peer_free: + ovpn_peer_put(peer); + return ret; +} + +static void ovpn_tcp_data_ready(struct sock *sk) +{ + struct ovpn_socket *sock; + + trace_sk_data_ready(sk); + + rcu_read_lock(); + sock = rcu_dereference_sk_user_data(sk); + strp_data_ready(&sock->peer->tcp.strp); + rcu_read_unlock(); +} + +static void ovpn_tcp_write_space(struct sock *sk) +{ + struct ovpn_socket *sock; + + rcu_read_lock(); + sock = rcu_dereference_sk_user_data(sk); + schedule_work(&sock->peer->tcp.tx_work); + sock->peer->tcp.sk_cb.sk_write_space(sk); + rcu_read_unlock(); +} + +static void ovpn_tcp_build_protos(struct proto *new_prot, + struct proto_ops *new_ops, + const struct proto *orig_prot, + const struct proto_ops *orig_ops); + +/* Set TCP encapsulation callbacks */ +int ovpn_tcp_socket_attach(struct socket *sock, struct ovpn_peer *peer) +{ + struct strp_callbacks cb = { + .rcv_msg = ovpn_tcp_rcv, + .parse_msg = ovpn_tcp_parse, + }; + int ret; + + /* make sure no pre-existing encapsulation handler exists */ + if (sock->sk->sk_user_data) + return -EBUSY; + + /* sanity check */ + if (sock->sk->sk_protocol != IPPROTO_TCP) { + netdev_err(peer->ovpn->dev, + "provided socket is not TCP as expected\n"); + return -EINVAL; + } + + /* only a fully connected socket are expected. Connection should be + * handled in userspace + */ + if (sock->sk->sk_state != TCP_ESTABLISHED) { + netdev_err(peer->ovpn->dev, + "provided TCP socket is not in ESTABLISHED state: %d\n", + sock->sk->sk_state); + return -EINVAL; + } + + lock_sock(sock->sk); + + ret = strp_init(&peer->tcp.strp, sock->sk, &cb); + if (ret < 0) { + DEBUG_NET_WARN_ON_ONCE(1); + release_sock(sock->sk); + return ret; + } + + INIT_WORK(&peer->tcp.tx_work, ovpn_tcp_tx_work); + __sk_dst_reset(sock->sk); + skb_queue_head_init(&peer->tcp.user_queue); + + /* save current CBs so that they can be restored upon socket release */ + peer->tcp.sk_cb.sk_data_ready = sock->sk->sk_data_ready; + peer->tcp.sk_cb.sk_write_space = sock->sk->sk_write_space; + peer->tcp.sk_cb.prot = sock->sk->sk_prot; + peer->tcp.sk_cb.ops = sock->sk->sk_socket->ops; + + /* assign our static CBs and prot/ops */ + sock->sk->sk_data_ready = ovpn_tcp_data_ready; + sock->sk->sk_write_space = ovpn_tcp_write_space; + + if (sock->sk->sk_family == AF_INET) { + sock->sk->sk_prot = &ovpn_tcp_prot; + sock->sk->sk_socket->ops = &ovpn_tcp_ops; + } else { + mutex_lock(&tcp6_prot_mutex); + if (!ovpn_tcp6_prot.recvmsg) + ovpn_tcp_build_protos(&ovpn_tcp6_prot, &ovpn_tcp6_ops, + sock->sk->sk_prot, + sock->sk->sk_socket->ops); + mutex_unlock(&tcp6_prot_mutex); + + sock->sk->sk_prot = &ovpn_tcp6_prot; + sock->sk->sk_socket->ops = &ovpn_tcp6_ops; + } + + /* avoid using task_frag */ + sock->sk->sk_allocation = GFP_ATOMIC; + sock->sk->sk_use_task_frag = false; + + /* enqueue the RX worker */ + strp_check_rcv(&peer->tcp.strp); + + release_sock(sock->sk); + return 0; +} + +static void ovpn_tcp_close(struct sock *sk, long timeout) +{ + struct ovpn_socket *sock; + + rcu_read_lock(); + sock = rcu_dereference_sk_user_data(sk); + + strp_stop(&sock->peer->tcp.strp); + barrier(); + + tcp_close(sk, timeout); + + ovpn_peer_del(sock->peer, OVPN_DEL_PEER_REASON_TRANSPORT_ERROR); + rcu_read_unlock(); +} + +static __poll_t ovpn_tcp_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + __poll_t mask = datagram_poll(file, sock, wait); + struct ovpn_socket *ovpn_sock; + + rcu_read_lock(); + ovpn_sock = rcu_dereference_sk_user_data(sock->sk); + if (!skb_queue_empty(&ovpn_sock->peer->tcp.user_queue)) + mask |= EPOLLIN | EPOLLRDNORM; + rcu_read_unlock(); + + return mask; +} + +static void ovpn_tcp_build_protos(struct proto *new_prot, + struct proto_ops *new_ops, + const struct proto *orig_prot, + const struct proto_ops *orig_ops) +{ + memcpy(new_prot, orig_prot, sizeof(*new_prot)); + memcpy(new_ops, orig_ops, sizeof(*new_ops)); + new_prot->recvmsg = ovpn_tcp_recvmsg; + new_prot->sendmsg = ovpn_tcp_sendmsg; + new_prot->close = ovpn_tcp_close; + new_ops->poll = ovpn_tcp_poll; +} + +/* Initialize TCP static objects */ +void __init ovpn_tcp_init(void) +{ + ovpn_tcp_build_protos(&ovpn_tcp_prot, &ovpn_tcp_ops, &tcp_prot, + &inet_stream_ops); +} diff --git a/drivers/net/ovpn/tcp.h b/drivers/net/ovpn/tcp.h new file mode 100644 index 000000000000..fb2cd0b606b4 --- /dev/null +++ b/drivers/net/ovpn/tcp.h @@ -0,0 +1,44 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2019-2024 OpenVPN, Inc. + * + * Author: Antonio Quartulli + */ + +#ifndef _NET_OVPN_TCP_H_ +#define _NET_OVPN_TCP_H_ + +#include +#include +#include + +#include "peer.h" +#include "skb.h" +#include "socket.h" + +void __init ovpn_tcp_init(void); + +int ovpn_tcp_socket_attach(struct socket *sock, struct ovpn_peer *peer); +void ovpn_tcp_socket_detach(struct socket *sock); +void ovpn_tcp_send_sock_skb(struct ovpn_peer *peer, struct sk_buff *skb); + +/* Prepare skb and enqueue it for sending to peer. + * + * Preparation consist in prepending the skb payload with its size. + * Required by the OpenVPN protocol in order to extract packets from + * the TCP stream on the receiver side. + */ +static inline void ovpn_tcp_send_skb(struct ovpn_peer *peer, + struct sk_buff *skb) +{ + u16 len = skb->len; + + *(__be16 *)__skb_push(skb, sizeof(u16)) = htons(len); + + bh_lock_sock(peer->sock->sock->sk); + ovpn_tcp_send_sock_skb(peer, skb); + bh_unlock_sock(peer->sock->sock->sk); +} + +#endif /* _NET_OVPN_TCP_H_ */ diff --git a/drivers/net/ovpn/udp.c b/drivers/net/ovpn/udp.c new file mode 100644 index 000000000000..d1e88ae83843 --- /dev/null +++ b/drivers/net/ovpn/udp.c @@ -0,0 +1,406 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel offload + * + * Copyright (C) 2019-2024 OpenVPN, Inc. + * + * Author: Antonio Quartulli + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ovpnstruct.h" +#include "main.h" +#include "bind.h" +#include "io.h" +#include "peer.h" +#include "proto.h" +#include "socket.h" +#include "udp.h" + +/** + * ovpn_udp_encap_recv - Start processing a received UDP packet. + * @sk: socket over which the packet was received + * @skb: the received packet + * + * If the first byte of the payload is DATA_V2, the packet is further processed, + * otherwise it is forwarded to the UDP stack for delivery to user space. + * + * Return: + * 0 if skb was consumed or dropped + * >0 if skb should be passed up to userspace as UDP (packet not consumed) + * <0 if skb should be resubmitted as proto -N (packet not consumed) + */ +static int ovpn_udp_encap_recv(struct sock *sk, struct sk_buff *skb) +{ + struct ovpn_peer *peer = NULL; + struct ovpn_struct *ovpn; + u32 peer_id; + u8 opcode; + + ovpn = ovpn_from_udp_sock(sk); + if (unlikely(!ovpn)) { + net_err_ratelimited("%s: cannot obtain ovpn object from UDP socket\n", + __func__); + goto drop_noovpn; + } + + /* Make sure the first 4 bytes of the skb data buffer after the UDP + * header are accessible. + * They are required to fetch the OP code, the key ID and the peer ID. + */ + if (unlikely(!pskb_may_pull(skb, sizeof(struct udphdr) + + OVPN_OP_SIZE_V2))) { + net_dbg_ratelimited("%s: packet too small\n", __func__); + goto drop; + } + + opcode = ovpn_opcode_from_skb(skb, sizeof(struct udphdr)); + if (unlikely(opcode != OVPN_DATA_V2)) { + /* DATA_V1 is not supported */ + if (opcode == OVPN_DATA_V1) + goto drop; + + /* unknown or control packet: let it bubble up to userspace */ + return 1; + } + + peer_id = ovpn_peer_id_from_skb(skb, sizeof(struct udphdr)); + /* some OpenVPN server implementations send data packets with the + * peer-id set to undef. In this case we skip the peer lookup by peer-id + * and we try with the transport address + */ + if (peer_id != OVPN_PEER_ID_UNDEF) { + peer = ovpn_peer_get_by_id(ovpn, peer_id); + if (!peer) { + net_err_ratelimited("%s: received data from unknown peer (id: %d)\n", + __func__, peer_id); + goto drop; + } + } + + if (!peer) { + /* data packet with undef peer-id */ + peer = ovpn_peer_get_by_transp_addr(ovpn, skb); + if (unlikely(!peer)) { + net_dbg_ratelimited("%s: received data with undef peer-id from unknown source\n", + __func__); + goto drop; + } + } + + /* pop off outer UDP header */ + __skb_pull(skb, sizeof(struct udphdr)); + ovpn_recv(peer, skb); + return 0; + +drop: + if (peer) + ovpn_peer_put(peer); + dev_core_stats_rx_dropped_inc(ovpn->dev); +drop_noovpn: + kfree_skb(skb); + return 0; +} + +/** + * ovpn_udp4_output - send IPv4 packet over udp socket + * @ovpn: the openvpn instance + * @bind: the binding related to the destination peer + * @cache: dst cache + * @sk: the socket to send the packet over + * @skb: the packet to send + * + * Return: 0 on success or a negative error code otherwise + */ +static int ovpn_udp4_output(struct ovpn_struct *ovpn, struct ovpn_bind *bind, + struct dst_cache *cache, struct sock *sk, + struct sk_buff *skb) +{ + struct rtable *rt; + struct flowi4 fl = { + .saddr = bind->local.ipv4.s_addr, + .daddr = bind->remote.in4.sin_addr.s_addr, + .fl4_sport = inet_sk(sk)->inet_sport, + .fl4_dport = bind->remote.in4.sin_port, + .flowi4_proto = sk->sk_protocol, + .flowi4_mark = sk->sk_mark, + }; + int ret; + + local_bh_disable(); + rt = dst_cache_get_ip4(cache, &fl.saddr); + if (rt) + goto transmit; + + if (unlikely(!inet_confirm_addr(sock_net(sk), NULL, 0, fl.saddr, + RT_SCOPE_HOST))) { + /* we may end up here when the cached address is not usable + * anymore. In this case we reset address/cache and perform a + * new look up + */ + fl.saddr = 0; + bind->local.ipv4.s_addr = 0; + dst_cache_reset(cache); + } + + rt = ip_route_output_flow(sock_net(sk), &fl, sk); + if (IS_ERR(rt) && PTR_ERR(rt) == -EINVAL) { + fl.saddr = 0; + bind->local.ipv4.s_addr = 0; + dst_cache_reset(cache); + + rt = ip_route_output_flow(sock_net(sk), &fl, sk); + } + + if (IS_ERR(rt)) { + ret = PTR_ERR(rt); + net_dbg_ratelimited("%s: no route to host %pISpc: %d\n", + ovpn->dev->name, &bind->remote.in4, ret); + goto err; + } + dst_cache_set_ip4(cache, &rt->dst, fl.saddr); + +transmit: + udp_tunnel_xmit_skb(rt, sk, skb, fl.saddr, fl.daddr, 0, + ip4_dst_hoplimit(&rt->dst), 0, fl.fl4_sport, + fl.fl4_dport, false, sk->sk_no_check_tx); + ret = 0; +err: + local_bh_enable(); + return ret; +} + +#if IS_ENABLED(CONFIG_IPV6) +/** + * ovpn_udp6_output - send IPv6 packet over udp socket + * @ovpn: the openvpn instance + * @bind: the binding related to the destination peer + * @cache: dst cache + * @sk: the socket to send the packet over + * @skb: the packet to send + * + * Return: 0 on success or a negative error code otherwise + */ +static int ovpn_udp6_output(struct ovpn_struct *ovpn, struct ovpn_bind *bind, + struct dst_cache *cache, struct sock *sk, + struct sk_buff *skb) +{ + struct dst_entry *dst; + int ret; + + struct flowi6 fl = { + .saddr = bind->local.ipv6, + .daddr = bind->remote.in6.sin6_addr, + .fl6_sport = inet_sk(sk)->inet_sport, + .fl6_dport = bind->remote.in6.sin6_port, + .flowi6_proto = sk->sk_protocol, + .flowi6_mark = sk->sk_mark, + .flowi6_oif = bind->remote.in6.sin6_scope_id, + }; + + local_bh_disable(); + dst = dst_cache_get_ip6(cache, &fl.saddr); + if (dst) + goto transmit; + + if (unlikely(!ipv6_chk_addr(sock_net(sk), &fl.saddr, NULL, 0))) { + /* we may end up here when the cached address is not usable + * anymore. In this case we reset address/cache and perform a + * new look up + */ + fl.saddr = in6addr_any; + bind->local.ipv6 = in6addr_any; + dst_cache_reset(cache); + } + + dst = ipv6_stub->ipv6_dst_lookup_flow(sock_net(sk), sk, &fl, NULL); + if (IS_ERR(dst)) { + ret = PTR_ERR(dst); + net_dbg_ratelimited("%s: no route to host %pISpc: %d\n", + ovpn->dev->name, &bind->remote.in6, ret); + goto err; + } + dst_cache_set_ip6(cache, dst, &fl.saddr); + +transmit: + udp_tunnel6_xmit_skb(dst, sk, skb, skb->dev, &fl.saddr, &fl.daddr, 0, + ip6_dst_hoplimit(dst), 0, fl.fl6_sport, + fl.fl6_dport, udp_get_no_check6_tx(sk)); + ret = 0; +err: + local_bh_enable(); + return ret; +} +#endif + +/** + * ovpn_udp_output - transmit skb using udp-tunnel + * @ovpn: the openvpn instance + * @bind: the binding related to the destination peer + * @cache: dst cache + * @sk: the socket to send the packet over + * @skb: the packet to send + * + * rcu_read_lock should be held on entry. + * On return, the skb is consumed. + * + * Return: 0 on success or a negative error code otherwise + */ +static int ovpn_udp_output(struct ovpn_struct *ovpn, struct ovpn_bind *bind, + struct dst_cache *cache, struct sock *sk, + struct sk_buff *skb) +{ + int ret; + + /* set sk to null if skb is already orphaned */ + if (!skb->destructor) + skb->sk = NULL; + + /* always permit openvpn-created packets to be (outside) fragmented */ + skb->ignore_df = 1; + + switch (bind->remote.in4.sin_family) { + case AF_INET: + ret = ovpn_udp4_output(ovpn, bind, cache, sk, skb); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + ret = ovpn_udp6_output(ovpn, bind, cache, sk, skb); + break; +#endif + default: + ret = -EAFNOSUPPORT; + break; + } + + return ret; +} + +/** + * ovpn_udp_send_skb - prepare skb and send it over via UDP + * @ovpn: the openvpn instance + * @peer: the destination peer + * @skb: the packet to send + */ +void ovpn_udp_send_skb(struct ovpn_struct *ovpn, struct ovpn_peer *peer, + struct sk_buff *skb) +{ + struct ovpn_bind *bind; + unsigned int pkt_len; + struct socket *sock; + int ret = -1; + + skb->dev = ovpn->dev; + /* no checksum performed at this layer */ + skb->ip_summed = CHECKSUM_NONE; + + /* get socket info */ + sock = peer->sock->sock; + if (unlikely(!sock)) { + net_warn_ratelimited("%s: no sock for remote peer\n", __func__); + goto out; + } + + rcu_read_lock(); + /* get binding */ + bind = rcu_dereference(peer->bind); + if (unlikely(!bind)) { + net_warn_ratelimited("%s: no bind for remote peer\n", __func__); + goto out_unlock; + } + + /* crypto layer -> transport (UDP) */ + pkt_len = skb->len; + ret = ovpn_udp_output(ovpn, bind, &peer->dst_cache, sock->sk, skb); + +out_unlock: + rcu_read_unlock(); +out: + if (unlikely(ret < 0)) { + dev_core_stats_tx_dropped_inc(ovpn->dev); + kfree_skb(skb); + return; + } + + dev_sw_netstats_tx_add(ovpn->dev, 1, pkt_len); +} + +/** + * ovpn_udp_socket_attach - set udp-tunnel CBs on socket and link it to ovpn + * @sock: socket to configure + * @ovpn: the openvp instance to link + * + * After invoking this function, the sock will be controlled by ovpn so that + * any incoming packet may be processed by ovpn first. + * + * Return: 0 on success or a negative error code otherwise + */ +int ovpn_udp_socket_attach(struct socket *sock, struct ovpn_struct *ovpn) +{ + struct udp_tunnel_sock_cfg cfg = { + .encap_type = UDP_ENCAP_OVPNINUDP, + .encap_rcv = ovpn_udp_encap_recv, + }; + struct ovpn_socket *old_data; + int ret; + + /* sanity check */ + if (sock->sk->sk_protocol != IPPROTO_UDP) { + DEBUG_NET_WARN_ON_ONCE(1); + return -EINVAL; + } + + /* make sure no pre-existing encapsulation handler exists */ + rcu_read_lock(); + old_data = rcu_dereference_sk_user_data(sock->sk); + if (!old_data) { + /* socket is currently unused - we can take it */ + rcu_read_unlock(); + setup_udp_tunnel_sock(sock_net(sock->sk), sock, &cfg); + return 0; + } + + /* socket is in use. We need to understand if it's owned by this ovpn + * instance or by something else. + * In the former case, we can increase the refcounter and happily + * use it, because the same UDP socket is expected to be shared among + * different peers. + * + * Unlikely TCP, a single UDP socket can be used to talk to many remote + * hosts and therefore openvpn instantiates one only for all its peers + */ + if ((READ_ONCE(udp_sk(sock->sk)->encap_type) == UDP_ENCAP_OVPNINUDP) && + old_data->ovpn == ovpn) { + netdev_dbg(ovpn->dev, + "%s: provided socket already owned by this interface\n", + __func__); + ret = -EALREADY; + } else { + netdev_err(ovpn->dev, + "%s: provided socket already taken by other user\n", + __func__); + ret = -EBUSY; + } + rcu_read_unlock(); + + return ret; +} + +/** + * ovpn_udp_socket_detach - clean udp-tunnel status for this socket + * @sock: the socket to clean + */ +void ovpn_udp_socket_detach(struct socket *sock) +{ + struct udp_tunnel_sock_cfg cfg = { }; + + setup_udp_tunnel_sock(sock_net(sock->sk), sock, &cfg); +} diff --git a/drivers/net/ovpn/udp.h b/drivers/net/ovpn/udp.h new file mode 100644 index 000000000000..fecb68464896 --- /dev/null +++ b/drivers/net/ovpn/udp.h @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* OpenVPN data channel offload + * + * Copyright (C) 2019-2024 OpenVPN, Inc. + * + * Author: Antonio Quartulli + */ + +#ifndef _NET_OVPN_UDP_H_ +#define _NET_OVPN_UDP_H_ + +#include +#include + +struct ovpn_peer; +struct ovpn_struct; +struct sk_buff; +struct socket; + +int ovpn_udp_socket_attach(struct socket *sock, struct ovpn_struct *ovpn); +void ovpn_udp_socket_detach(struct socket *sock); +void ovpn_udp_send_skb(struct ovpn_struct *ovpn, struct ovpn_peer *peer, + struct sk_buff *skb); +struct ovpn_struct *ovpn_from_udp_sock(struct sock *sk); + +#endif /* _NET_OVPN_UDP_H_ */ diff --git a/include/net/netlink.h b/include/net/netlink.h index db6af207287c..2dc671c977ff 100644 --- a/include/net/netlink.h +++ b/include/net/netlink.h @@ -469,6 +469,7 @@ struct nla_policy { .max = _len \ } #define NLA_POLICY_MIN_LEN(_len) NLA_POLICY_MIN(NLA_BINARY, _len) +#define NLA_POLICY_MAX_LEN(_len) NLA_POLICY_MAX(NLA_BINARY, _len) /** * struct nl_info - netlink source information diff --git a/include/uapi/linux/if_link.h b/include/uapi/linux/if_link.h index 6dc258993b17..9a5419d60100 100644 --- a/include/uapi/linux/if_link.h +++ b/include/uapi/linux/if_link.h @@ -1959,4 +1959,19 @@ enum { #define IFLA_DSA_MAX (__IFLA_DSA_MAX - 1) +/* OVPN section */ + +enum ovpn_mode { + OVPN_MODE_P2P, + OVPN_MODE_MP, +}; + +enum { + IFLA_OVPN_UNSPEC, + IFLA_OVPN_MODE, + __IFLA_OVPN_MAX, +}; + +#define IFLA_OVPN_MAX (__IFLA_OVPN_MAX - 1) + #endif /* _UAPI_LINUX_IF_LINK_H */ diff --git a/include/uapi/linux/ovpn.h b/include/uapi/linux/ovpn.h new file mode 100644 index 000000000000..7bac0803cd9f --- /dev/null +++ b/include/uapi/linux/ovpn.h @@ -0,0 +1,109 @@ +/* SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) */ +/* Do not edit directly, auto-generated from: */ +/* Documentation/netlink/specs/ovpn.yaml */ +/* YNL-GEN uapi header */ + +#ifndef _UAPI_LINUX_OVPN_H +#define _UAPI_LINUX_OVPN_H + +#define OVPN_FAMILY_NAME "ovpn" +#define OVPN_FAMILY_VERSION 1 + +#define OVPN_NONCE_TAIL_SIZE 8 + +enum ovpn_cipher_alg { + OVPN_CIPHER_ALG_NONE, + OVPN_CIPHER_ALG_AES_GCM, + OVPN_CIPHER_ALG_CHACHA20_POLY1305, +}; + +enum ovpn_del_peer_reason { + OVPN_DEL_PEER_REASON_TEARDOWN, + OVPN_DEL_PEER_REASON_USERSPACE, + OVPN_DEL_PEER_REASON_EXPIRED, + OVPN_DEL_PEER_REASON_TRANSPORT_ERROR, + OVPN_DEL_PEER_REASON_TRANSPORT_DISCONNECT, +}; + +enum ovpn_key_slot { + OVPN_KEY_SLOT_PRIMARY, + OVPN_KEY_SLOT_SECONDARY, +}; + +enum { + OVPN_A_PEER_ID = 1, + OVPN_A_PEER_REMOTE_IPV4, + OVPN_A_PEER_REMOTE_IPV6, + OVPN_A_PEER_REMOTE_IPV6_SCOPE_ID, + OVPN_A_PEER_REMOTE_PORT, + OVPN_A_PEER_SOCKET, + OVPN_A_PEER_VPN_IPV4, + OVPN_A_PEER_VPN_IPV6, + OVPN_A_PEER_LOCAL_IPV4, + OVPN_A_PEER_LOCAL_IPV6, + OVPN_A_PEER_LOCAL_PORT, + OVPN_A_PEER_KEEPALIVE_INTERVAL, + OVPN_A_PEER_KEEPALIVE_TIMEOUT, + OVPN_A_PEER_DEL_REASON, + OVPN_A_PEER_VPN_RX_BYTES, + OVPN_A_PEER_VPN_TX_BYTES, + OVPN_A_PEER_VPN_RX_PACKETS, + OVPN_A_PEER_VPN_TX_PACKETS, + OVPN_A_PEER_LINK_RX_BYTES, + OVPN_A_PEER_LINK_TX_BYTES, + OVPN_A_PEER_LINK_RX_PACKETS, + OVPN_A_PEER_LINK_TX_PACKETS, + + __OVPN_A_PEER_MAX, + OVPN_A_PEER_MAX = (__OVPN_A_PEER_MAX - 1) +}; + +enum { + OVPN_A_KEYCONF_PEER_ID = 1, + OVPN_A_KEYCONF_SLOT, + OVPN_A_KEYCONF_KEY_ID, + OVPN_A_KEYCONF_CIPHER_ALG, + OVPN_A_KEYCONF_ENCRYPT_DIR, + OVPN_A_KEYCONF_DECRYPT_DIR, + + __OVPN_A_KEYCONF_MAX, + OVPN_A_KEYCONF_MAX = (__OVPN_A_KEYCONF_MAX - 1) +}; + +enum { + OVPN_A_KEYDIR_CIPHER_KEY = 1, + OVPN_A_KEYDIR_NONCE_TAIL, + + __OVPN_A_KEYDIR_MAX, + OVPN_A_KEYDIR_MAX = (__OVPN_A_KEYDIR_MAX - 1) +}; + +enum { + OVPN_A_IFINDEX = 1, + OVPN_A_IFNAME, + OVPN_A_PEER, + OVPN_A_KEYCONF, + + __OVPN_A_MAX, + OVPN_A_MAX = (__OVPN_A_MAX - 1) +}; + +enum { + OVPN_CMD_PEER_NEW = 1, + OVPN_CMD_PEER_SET, + OVPN_CMD_PEER_GET, + OVPN_CMD_PEER_DEL, + OVPN_CMD_PEER_DEL_NTF, + OVPN_CMD_KEY_NEW, + OVPN_CMD_KEY_GET, + OVPN_CMD_KEY_SWAP, + OVPN_CMD_KEY_SWAP_NTF, + OVPN_CMD_KEY_DEL, + + __OVPN_CMD_MAX, + OVPN_CMD_MAX = (__OVPN_CMD_MAX - 1) +}; + +#define OVPN_MCGRP_PEERS "peers" + +#endif /* _UAPI_LINUX_OVPN_H */ diff --git a/include/uapi/linux/udp.h b/include/uapi/linux/udp.h index 1a0fe8b151fb..f9f8ffddfd0c 100644 --- a/include/uapi/linux/udp.h +++ b/include/uapi/linux/udp.h @@ -43,5 +43,6 @@ struct udphdr { #define UDP_ENCAP_GTP1U 5 /* 3GPP TS 29.060 */ #define UDP_ENCAP_RXRPC 6 #define TCP_ENCAP_ESPINTCP 7 /* Yikes, this is really xfrm encap types. */ +#define UDP_ENCAP_OVPNINUDP 8 /* OpenVPN traffic */ #endif /* _UAPI_LINUX_UDP_H */ diff --git a/tools/net/ynl/ynl-gen-c.py b/tools/net/ynl/ynl-gen-c.py index 717530bc9c52..3ccbb301be87 100755 --- a/tools/net/ynl/ynl-gen-c.py +++ b/tools/net/ynl/ynl-gen-c.py @@ -466,6 +466,8 @@ class TypeBinary(Type): def _attr_policy(self, policy): if 'exact-len' in self.checks: mem = 'NLA_POLICY_EXACT_LEN(' + str(self.get_limit('exact-len')) + ')' + elif 'max-len' in self.checks: + mem = 'NLA_POLICY_MAX_LEN(' + str(self.get_limit('max-len')) + ')' else: mem = '{ ' if len(self.checks) == 1 and 'min-len' in self.checks: diff --git a/tools/testing/selftests/Makefile b/tools/testing/selftests/Makefile index ff18c0361e38..e4b4494b0765 100644 --- a/tools/testing/selftests/Makefile +++ b/tools/testing/selftests/Makefile @@ -69,6 +69,7 @@ TARGETS += net/hsr TARGETS += net/mptcp TARGETS += net/netfilter TARGETS += net/openvswitch +TARGETS += net/ovpn TARGETS += net/packetdrill TARGETS += net/rds TARGETS += net/tcp_ao diff --git a/tools/testing/selftests/net/ovpn/.gitignore b/tools/testing/selftests/net/ovpn/.gitignore new file mode 100644 index 000000000000..ee44c081ca7c --- /dev/null +++ b/tools/testing/selftests/net/ovpn/.gitignore @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: GPL-2.0+ +ovpn-cli diff --git a/tools/testing/selftests/net/ovpn/Makefile b/tools/testing/selftests/net/ovpn/Makefile new file mode 100644 index 000000000000..c76d8fd953c5 --- /dev/null +++ b/tools/testing/selftests/net/ovpn/Makefile @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: GPL-2.0 +# Copyright (C) 2020-2024 OpenVPN, Inc. +# +CFLAGS = -pedantic -Wextra -Wall -Wl,--no-as-needed -g -O0 -ggdb $(KHDR_INCLUDES) +CFLAGS += $(shell pkg-config --cflags libnl-3.0 libnl-genl-3.0) + +LDFLAGS = -lmbedtls -lmbedcrypto +LDFLAGS += $(shell pkg-config --libs libnl-3.0 libnl-genl-3.0) + +TEST_PROGS = test.sh \ + test-chachapoly.sh \ + test-tcp.sh \ + test-float.sh + +TEST_GEN_FILES = ovpn-cli + +include ../../lib.mk diff --git a/tools/testing/selftests/net/ovpn/config b/tools/testing/selftests/net/ovpn/config new file mode 100644 index 000000000000..71946ba9fa17 --- /dev/null +++ b/tools/testing/selftests/net/ovpn/config @@ -0,0 +1,10 @@ +CONFIG_NET=y +CONFIG_INET=y +CONFIG_STREAM_PARSER=y +CONFIG_NET_UDP_TUNNEL=y +CONFIG_DST_CACHE=y +CONFIG_CRYPTO=y +CONFIG_CRYPTO_AES=y +CONFIG_CRYPTO_GCM=y +CONFIG_CRYPTO_CHACHA20POLY1305=y +CONFIG_OVPN=m diff --git a/tools/testing/selftests/net/ovpn/data64.key b/tools/testing/selftests/net/ovpn/data64.key new file mode 100644 index 000000000000..a99e88c4e290 --- /dev/null +++ b/tools/testing/selftests/net/ovpn/data64.key @@ -0,0 +1,5 @@ +jRqMACN7d7/aFQNT8S7jkrBD8uwrgHbG5OQZP2eu4R1Y7tfpS2bf5RHv06Vi163CGoaIiTX99R3B +ia9ycAH8Wz1+9PWv51dnBLur9jbShlgZ2QHLtUc4a/gfT7zZwULXuuxdLnvR21DDeMBaTbkgbai9 +uvAa7ne1liIgGFzbv+Bas4HDVrygxIxuAnP5Qgc3648IJkZ0QEXPF+O9f0n5+QIvGCxkAUVx+5K6 +KIs+SoeWXnAopELmoGSjUpFtJbagXK82HfdqpuUxT2Tnuef0/14SzVE/vNleBNu2ZbyrSAaah8tE +BofkPJUBFY+YQcfZNM5Dgrw3i+Bpmpq/gpdg5w== diff --git a/tools/testing/selftests/net/ovpn/ovpn-cli.c b/tools/testing/selftests/net/ovpn/ovpn-cli.c new file mode 100644 index 000000000000..046dd069aaaf --- /dev/null +++ b/tools/testing/selftests/net/ovpn/ovpn-cli.c @@ -0,0 +1,2370 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2024 OpenVPN, Inc. + * + * Author: Antonio Quartulli + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include + +/* defines to make checkpatch happy */ +#define strscpy strncpy +#define __always_unused __attribute__((__unused__)) + +/* libnl < 3.5.0 does not set the NLA_F_NESTED on its own, therefore we + * have to explicitly do it to prevent the kernel from failing upon + * parsing of the message + */ +#define nla_nest_start(_msg, _type) \ + nla_nest_start(_msg, (_type) | NLA_F_NESTED) + +uint64_t nla_get_uint(struct nlattr *attr) +{ + if (nla_len(attr) == sizeof(uint32_t)) + return nla_get_u32(attr); + else + return nla_get_u64(attr); +} + +typedef int (*ovpn_nl_cb)(struct nl_msg *msg, void *arg); + +enum ovpn_key_direction { + KEY_DIR_IN = 0, + KEY_DIR_OUT, +}; + +#define KEY_LEN (256 / 8) +#define NONCE_LEN 8 + +#define PEER_ID_UNDEF 0x00FFFFFF + +struct nl_ctx { + struct nl_sock *nl_sock; + struct nl_msg *nl_msg; + struct nl_cb *nl_cb; + + int ovpn_dco_id; +}; + +enum ovpn_cmd { + CMD_INVALID, + CMD_NEW_IFACE, + CMD_DEL_IFACE, + CMD_LISTEN, + CMD_CONNECT, + CMD_NEW_PEER, + CMD_NEW_MULTI_PEER, + CMD_SET_PEER, + CMD_DEL_PEER, + CMD_GET_PEER, + CMD_NEW_KEY, + CMD_DEL_KEY, + CMD_GET_KEY, + CMD_SWAP_KEYS, + CMD_LISTEN_MCAST, +}; + +struct ovpn_ctx { + enum ovpn_cmd cmd; + + __u8 key_enc[KEY_LEN]; + __u8 key_dec[KEY_LEN]; + __u8 nonce[NONCE_LEN]; + + enum ovpn_cipher_alg cipher; + + sa_family_t sa_family; + + unsigned long peer_id; + unsigned long lport; + + union { + struct sockaddr_in in4; + struct sockaddr_in6 in6; + } remote; + + union { + struct sockaddr_in in4; + struct sockaddr_in6 in6; + } peer_ip; + + bool peer_ip_set; + + unsigned int ifindex; + char ifname[IFNAMSIZ]; + enum ovpn_mode mode; + bool mode_set; + + int socket; + int cli_socket; + + __u32 keepalive_interval; + __u32 keepalive_timeout; + + enum ovpn_key_direction key_dir; + enum ovpn_key_slot key_slot; + int key_id; + + const char *peers_file; +}; + +static int ovpn_nl_recvmsgs(struct nl_ctx *ctx) +{ + int ret; + + ret = nl_recvmsgs(ctx->nl_sock, ctx->nl_cb); + + switch (ret) { + case -NLE_INTR: + fprintf(stderr, + "netlink received interrupt due to signal - ignoring\n"); + break; + case -NLE_NOMEM: + fprintf(stderr, "netlink out of memory error\n"); + break; + case -NLE_AGAIN: + fprintf(stderr, + "netlink reports blocking read - aborting wait\n"); + break; + default: + if (ret) + fprintf(stderr, "netlink reports error (%d): %s\n", + ret, nl_geterror(-ret)); + break; + } + + return ret; +} + +static struct nl_ctx *nl_ctx_alloc_flags(struct ovpn_ctx *ovpn, int cmd, + int flags) +{ + struct nl_ctx *ctx; + int err, ret; + + ctx = calloc(1, sizeof(*ctx)); + if (!ctx) + return NULL; + + ctx->nl_sock = nl_socket_alloc(); + if (!ctx->nl_sock) { + fprintf(stderr, "cannot allocate netlink socket\n"); + goto err_free; + } + + nl_socket_set_buffer_size(ctx->nl_sock, 8192, 8192); + + ret = genl_connect(ctx->nl_sock); + if (ret) { + fprintf(stderr, "cannot connect to generic netlink: %s\n", + nl_geterror(ret)); + goto err_sock; + } + + /* enable Extended ACK for detailed error reporting */ + err = 1; + setsockopt(nl_socket_get_fd(ctx->nl_sock), SOL_NETLINK, NETLINK_EXT_ACK, + &err, sizeof(err)); + + ctx->ovpn_dco_id = genl_ctrl_resolve(ctx->nl_sock, OVPN_FAMILY_NAME); + if (ctx->ovpn_dco_id < 0) { + fprintf(stderr, "cannot find ovpn_dco netlink component: %d\n", + ctx->ovpn_dco_id); + goto err_free; + } + + ctx->nl_msg = nlmsg_alloc(); + if (!ctx->nl_msg) { + fprintf(stderr, "cannot allocate netlink message\n"); + goto err_sock; + } + + ctx->nl_cb = nl_cb_alloc(NL_CB_DEFAULT); + if (!ctx->nl_cb) { + fprintf(stderr, "failed to allocate netlink callback\n"); + goto err_msg; + } + + nl_socket_set_cb(ctx->nl_sock, ctx->nl_cb); + + genlmsg_put(ctx->nl_msg, 0, 0, ctx->ovpn_dco_id, 0, flags, cmd, 0); + + if (ovpn->ifindex > 0) + NLA_PUT_U32(ctx->nl_msg, OVPN_A_IFINDEX, ovpn->ifindex); + + return ctx; +nla_put_failure: +err_msg: + nlmsg_free(ctx->nl_msg); +err_sock: + nl_socket_free(ctx->nl_sock); +err_free: + free(ctx); + return NULL; +} + +static struct nl_ctx *nl_ctx_alloc(struct ovpn_ctx *ovpn, int cmd) +{ + return nl_ctx_alloc_flags(ovpn, cmd, 0); +} + +static void nl_ctx_free(struct nl_ctx *ctx) +{ + if (!ctx) + return; + + nl_socket_free(ctx->nl_sock); + nlmsg_free(ctx->nl_msg); + nl_cb_put(ctx->nl_cb); + free(ctx); +} + +static int ovpn_nl_cb_error(struct sockaddr_nl (*nla)__always_unused, + struct nlmsgerr *err, void *arg) +{ + struct nlmsghdr *nlh = (struct nlmsghdr *)err - 1; + struct nlattr *tb_msg[NLMSGERR_ATTR_MAX + 1]; + int len = nlh->nlmsg_len; + struct nlattr *attrs; + int *ret = arg; + int ack_len = sizeof(*nlh) + sizeof(int) + sizeof(*nlh); + + *ret = err->error; + + if (!(nlh->nlmsg_flags & NLM_F_ACK_TLVS)) + return NL_STOP; + + if (!(nlh->nlmsg_flags & NLM_F_CAPPED)) + ack_len += err->msg.nlmsg_len - sizeof(*nlh); + + if (len <= ack_len) + return NL_STOP; + + attrs = (void *)((uint8_t *)nlh + ack_len); + len -= ack_len; + + nla_parse(tb_msg, NLMSGERR_ATTR_MAX, attrs, len, NULL); + if (tb_msg[NLMSGERR_ATTR_MSG]) { + len = strnlen((char *)nla_data(tb_msg[NLMSGERR_ATTR_MSG]), + nla_len(tb_msg[NLMSGERR_ATTR_MSG])); + fprintf(stderr, "kernel error: %*s\n", len, + (char *)nla_data(tb_msg[NLMSGERR_ATTR_MSG])); + } + + if (tb_msg[NLMSGERR_ATTR_MISS_NEST]) { + fprintf(stderr, "missing required nesting type %u\n", + nla_get_u32(tb_msg[NLMSGERR_ATTR_MISS_NEST])); + } + + if (tb_msg[NLMSGERR_ATTR_MISS_TYPE]) { + fprintf(stderr, "missing required attribute type %u\n", + nla_get_u32(tb_msg[NLMSGERR_ATTR_MISS_TYPE])); + } + + return NL_STOP; +} + +static int ovpn_nl_cb_finish(struct nl_msg (*msg)__always_unused, + void *arg) +{ + int *status = arg; + + *status = 0; + return NL_SKIP; +} + +static int ovpn_nl_cb_ack(struct nl_msg (*msg)__always_unused, + void *arg) +{ + int *status = arg; + + *status = 0; + return NL_STOP; +} + +static int ovpn_nl_msg_send(struct nl_ctx *ctx, ovpn_nl_cb cb) +{ + int status = 1; + + nl_cb_err(ctx->nl_cb, NL_CB_CUSTOM, ovpn_nl_cb_error, &status); + nl_cb_set(ctx->nl_cb, NL_CB_FINISH, NL_CB_CUSTOM, ovpn_nl_cb_finish, + &status); + nl_cb_set(ctx->nl_cb, NL_CB_ACK, NL_CB_CUSTOM, ovpn_nl_cb_ack, &status); + + if (cb) + nl_cb_set(ctx->nl_cb, NL_CB_VALID, NL_CB_CUSTOM, cb, ctx); + + nl_send_auto_complete(ctx->nl_sock, ctx->nl_msg); + + while (status == 1) + ovpn_nl_recvmsgs(ctx); + + if (status < 0) + fprintf(stderr, "failed to send netlink message: %s (%d)\n", + strerror(-status), status); + + return status; +} + +static int ovpn_parse_key(const char *file, struct ovpn_ctx *ctx) +{ + int idx_enc, idx_dec, ret = -1; + unsigned char *ckey = NULL; + __u8 *bkey = NULL; + size_t olen = 0; + long ckey_len; + FILE *fp; + + fp = fopen(file, "r"); + if (!fp) { + fprintf(stderr, "cannot open: %s\n", file); + return -1; + } + + /* get file size */ + fseek(fp, 0L, SEEK_END); + ckey_len = ftell(fp); + rewind(fp); + + /* if the file is longer, let's just read a portion */ + if (ckey_len > 256) + ckey_len = 256; + + ckey = malloc(ckey_len); + if (!ckey) + goto err; + + ret = fread(ckey, 1, ckey_len, fp); + if (ret != ckey_len) { + fprintf(stderr, + "couldn't read enough data from key file: %dbytes read\n", + ret); + goto err; + } + + olen = 0; + ret = mbedtls_base64_decode(NULL, 0, &olen, ckey, ckey_len); + if (ret != MBEDTLS_ERR_BASE64_BUFFER_TOO_SMALL) { + char buf[256]; + + mbedtls_strerror(ret, buf, sizeof(buf)); + fprintf(stderr, "unexpected base64 error1: %s (%d)\n", buf, + ret); + + goto err; + } + + bkey = malloc(olen); + if (!bkey) { + fprintf(stderr, "cannot allocate binary key buffer\n"); + goto err; + } + + ret = mbedtls_base64_decode(bkey, olen, &olen, ckey, ckey_len); + if (ret) { + char buf[256]; + + mbedtls_strerror(ret, buf, sizeof(buf)); + fprintf(stderr, "unexpected base64 error2: %s (%d)\n", buf, + ret); + + goto err; + } + + if (olen < 2 * KEY_LEN + NONCE_LEN) { + fprintf(stderr, + "not enough data in key file, found %zdB but needs %dB\n", + olen, 2 * KEY_LEN + NONCE_LEN); + goto err; + } + + switch (ctx->key_dir) { + case KEY_DIR_IN: + idx_enc = 0; + idx_dec = 1; + break; + case KEY_DIR_OUT: + idx_enc = 1; + idx_dec = 0; + break; + default: + goto err; + } + + memcpy(ctx->key_enc, bkey + KEY_LEN * idx_enc, KEY_LEN); + memcpy(ctx->key_dec, bkey + KEY_LEN * idx_dec, KEY_LEN); + memcpy(ctx->nonce, bkey + 2 * KEY_LEN, NONCE_LEN); + + ret = 0; + +err: + fclose(fp); + free(bkey); + free(ckey); + + return ret; +} + +static int ovpn_parse_cipher(const char *cipher, struct ovpn_ctx *ctx) +{ + if (strcmp(cipher, "aes") == 0) + ctx->cipher = OVPN_CIPHER_ALG_AES_GCM; + else if (strcmp(cipher, "chachapoly") == 0) + ctx->cipher = OVPN_CIPHER_ALG_CHACHA20_POLY1305; + else if (strcmp(cipher, "none") == 0) + ctx->cipher = OVPN_CIPHER_ALG_NONE; + else + return -ENOTSUP; + + return 0; +} + +static int ovpn_parse_key_direction(const char *dir, struct ovpn_ctx *ctx) +{ + int in_dir; + + in_dir = strtoll(dir, NULL, 10); + switch (in_dir) { + case KEY_DIR_IN: + case KEY_DIR_OUT: + ctx->key_dir = in_dir; + break; + default: + fprintf(stderr, + "invalid key direction provided. Can be 0 or 1 only\n"); + return -1; + } + + return 0; +} + +static int ovpn_socket(struct ovpn_ctx *ctx, sa_family_t family, int proto) +{ + struct sockaddr_storage local_sock = { 0 }; + struct sockaddr_in6 *in6; + struct sockaddr_in *in; + int ret, s, sock_type; + size_t sock_len; + + if (proto == IPPROTO_UDP) + sock_type = SOCK_DGRAM; + else if (proto == IPPROTO_TCP) + sock_type = SOCK_STREAM; + else + return -EINVAL; + + s = socket(family, sock_type, 0); + if (s < 0) { + perror("cannot create socket"); + return -1; + } + + switch (family) { + case AF_INET: + in = (struct sockaddr_in *)&local_sock; + in->sin_family = family; + in->sin_port = htons(ctx->lport); + in->sin_addr.s_addr = htonl(INADDR_ANY); + sock_len = sizeof(*in); + break; + case AF_INET6: + in6 = (struct sockaddr_in6 *)&local_sock; + in6->sin6_family = family; + in6->sin6_port = htons(ctx->lport); + in6->sin6_addr = in6addr_any; + sock_len = sizeof(*in6); + break; + default: + return -1; + } + + int opt = 1; + + ret = setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + + if (ret < 0) { + perror("setsockopt for SO_REUSEADDR"); + return ret; + } + + ret = setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)); + if (ret < 0) { + perror("setsockopt for SO_REUSEPORT"); + return ret; + } + + if (family == AF_INET6) { + opt = 0; + if (setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &opt, + sizeof(opt))) { + perror("failed to set IPV6_V6ONLY"); + return -1; + } + } + + ret = bind(s, (struct sockaddr *)&local_sock, sock_len); + if (ret < 0) { + perror("cannot bind socket"); + goto err_socket; + } + + ctx->socket = s; + ctx->sa_family = family; + return 0; + +err_socket: + close(s); + return -1; +} + +static int ovpn_udp_socket(struct ovpn_ctx *ctx, sa_family_t family) +{ + return ovpn_socket(ctx, family, IPPROTO_UDP); +} + +static int ovpn_listen(struct ovpn_ctx *ctx, sa_family_t family) +{ + int ret; + + ret = ovpn_socket(ctx, family, IPPROTO_TCP); + if (ret < 0) + return ret; + + ret = listen(ctx->socket, 10); + if (ret < 0) { + perror("listen"); + close(ctx->socket); + return -1; + } + + return 0; +} + +static int ovpn_accept(struct ovpn_ctx *ctx) +{ + socklen_t socklen; + int ret; + + socklen = sizeof(ctx->remote); + ret = accept(ctx->socket, (struct sockaddr *)&ctx->remote, &socklen); + if (ret < 0) { + perror("accept"); + goto err; + } + + fprintf(stderr, "Connection received!\n"); + + switch (socklen) { + case sizeof(struct sockaddr_in): + case sizeof(struct sockaddr_in6): + break; + default: + fprintf(stderr, "error: expecting IPv4 or IPv6 connection\n"); + close(ret); + ret = -EINVAL; + goto err; + } + + return ret; +err: + close(ctx->socket); + return ret; +} + +static int ovpn_connect(struct ovpn_ctx *ovpn) +{ + socklen_t socklen; + int s, ret; + + s = socket(ovpn->remote.in4.sin_family, SOCK_STREAM, 0); + if (s < 0) { + perror("cannot create socket"); + return -1; + } + + switch (ovpn->remote.in4.sin_family) { + case AF_INET: + socklen = sizeof(struct sockaddr_in); + break; + case AF_INET6: + socklen = sizeof(struct sockaddr_in6); + break; + default: + return -EOPNOTSUPP; + } + + ret = connect(s, (struct sockaddr *)&ovpn->remote, socklen); + if (ret < 0) { + perror("connect"); + goto err; + } + + fprintf(stderr, "connected\n"); + + ovpn->socket = s; + + return 0; +err: + close(s); + return ret; +} + +static int ovpn_new_peer(struct ovpn_ctx *ovpn, bool is_tcp) +{ + struct nlattr *attr; + struct nl_ctx *ctx; + int ret = -1; + + ctx = nl_ctx_alloc(ovpn, OVPN_CMD_PEER_NEW); + if (!ctx) + return -ENOMEM; + + attr = nla_nest_start(ctx->nl_msg, OVPN_A_PEER); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_PEER_ID, ovpn->peer_id); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_PEER_SOCKET, ovpn->socket); + + if (!is_tcp) { + switch (ovpn->remote.in4.sin_family) { + case AF_INET: + NLA_PUT_U32(ctx->nl_msg, OVPN_A_PEER_REMOTE_IPV4, + ovpn->remote.in4.sin_addr.s_addr); + NLA_PUT_U16(ctx->nl_msg, OVPN_A_PEER_REMOTE_PORT, + ovpn->remote.in4.sin_port); + break; + case AF_INET6: + NLA_PUT(ctx->nl_msg, OVPN_A_PEER_REMOTE_IPV6, + sizeof(ovpn->remote.in6.sin6_addr), + &ovpn->remote.in6.sin6_addr); + NLA_PUT_U32(ctx->nl_msg, + OVPN_A_PEER_REMOTE_IPV6_SCOPE_ID, + ovpn->remote.in6.sin6_scope_id); + NLA_PUT_U16(ctx->nl_msg, OVPN_A_PEER_REMOTE_PORT, + ovpn->remote.in6.sin6_port); + break; + default: + fprintf(stderr, + "Invalid family for remote socket address\n"); + goto nla_put_failure; + } + } + + if (ovpn->peer_ip_set) { + switch (ovpn->peer_ip.in4.sin_family) { + case AF_INET: + NLA_PUT_U32(ctx->nl_msg, OVPN_A_PEER_VPN_IPV4, + ovpn->peer_ip.in4.sin_addr.s_addr); + break; + case AF_INET6: + NLA_PUT(ctx->nl_msg, OVPN_A_PEER_VPN_IPV6, + sizeof(struct in6_addr), + &ovpn->peer_ip.in6.sin6_addr); + break; + default: + fprintf(stderr, "Invalid family for peer address\n"); + goto nla_put_failure; + } + } + + nla_nest_end(ctx->nl_msg, attr); + + ret = ovpn_nl_msg_send(ctx, NULL); +nla_put_failure: + nl_ctx_free(ctx); + return ret; +} + +static int ovpn_set_peer(struct ovpn_ctx *ovpn) +{ + struct nlattr *attr; + struct nl_ctx *ctx; + int ret = -1; + + ctx = nl_ctx_alloc(ovpn, OVPN_CMD_PEER_SET); + if (!ctx) + return -ENOMEM; + + attr = nla_nest_start(ctx->nl_msg, OVPN_A_PEER); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_PEER_ID, ovpn->peer_id); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_PEER_KEEPALIVE_INTERVAL, + ovpn->keepalive_interval); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_PEER_KEEPALIVE_TIMEOUT, + ovpn->keepalive_timeout); + nla_nest_end(ctx->nl_msg, attr); + + ret = ovpn_nl_msg_send(ctx, NULL); +nla_put_failure: + nl_ctx_free(ctx); + return ret; +} + +static int ovpn_del_peer(struct ovpn_ctx *ovpn) +{ + struct nlattr *attr; + struct nl_ctx *ctx; + int ret = -1; + + ctx = nl_ctx_alloc(ovpn, OVPN_CMD_PEER_DEL); + if (!ctx) + return -ENOMEM; + + attr = nla_nest_start(ctx->nl_msg, OVPN_A_PEER); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_PEER_ID, ovpn->peer_id); + nla_nest_end(ctx->nl_msg, attr); + + ret = ovpn_nl_msg_send(ctx, NULL); +nla_put_failure: + nl_ctx_free(ctx); + return ret; +} + +static int ovpn_handle_peer(struct nl_msg *msg, void (*arg)__always_unused) +{ + struct nlattr *pattrs[OVPN_A_PEER_MAX + 1]; + struct genlmsghdr *gnlh = nlmsg_data(nlmsg_hdr(msg)); + struct nlattr *attrs[OVPN_A_MAX + 1]; + __u16 rport = 0, lport = 0; + + nla_parse(attrs, OVPN_A_MAX, genlmsg_attrdata(gnlh, 0), + genlmsg_attrlen(gnlh, 0), NULL); + + if (!attrs[OVPN_A_PEER]) { + fprintf(stderr, "no packet content in netlink message\n"); + return NL_SKIP; + } + + nla_parse(pattrs, OVPN_A_PEER_MAX, nla_data(attrs[OVPN_A_PEER]), + nla_len(attrs[OVPN_A_PEER]), NULL); + + if (pattrs[OVPN_A_PEER_ID]) + fprintf(stderr, "* Peer %u\n", + nla_get_u32(pattrs[OVPN_A_PEER_ID])); + + if (pattrs[OVPN_A_PEER_VPN_IPV4]) { + char buf[INET_ADDRSTRLEN]; + + inet_ntop(AF_INET, nla_data(pattrs[OVPN_A_PEER_VPN_IPV4]), + buf, sizeof(buf)); + fprintf(stderr, "\tVPN IPv4: %s\n", buf); + } + + if (pattrs[OVPN_A_PEER_VPN_IPV6]) { + char buf[INET6_ADDRSTRLEN]; + + inet_ntop(AF_INET6, nla_data(pattrs[OVPN_A_PEER_VPN_IPV6]), + buf, sizeof(buf)); + fprintf(stderr, "\tVPN IPv6: %s\n", buf); + } + + if (pattrs[OVPN_A_PEER_LOCAL_PORT]) + lport = ntohs(nla_get_u16(pattrs[OVPN_A_PEER_LOCAL_PORT])); + + if (pattrs[OVPN_A_PEER_REMOTE_PORT]) + rport = ntohs(nla_get_u16(pattrs[OVPN_A_PEER_REMOTE_PORT])); + + if (pattrs[OVPN_A_PEER_REMOTE_IPV6]) { + void *ip = pattrs[OVPN_A_PEER_REMOTE_IPV6]; + char buf[INET6_ADDRSTRLEN]; + int scope_id = -1; + + if (pattrs[OVPN_A_PEER_REMOTE_IPV6_SCOPE_ID]) { + void *p = pattrs[OVPN_A_PEER_REMOTE_IPV6_SCOPE_ID]; + + scope_id = nla_get_u32(p); + } + + inet_ntop(AF_INET6, nla_data(ip), buf, sizeof(buf)); + fprintf(stderr, "\tRemote: %s:%hu (scope-id: %u)\n", buf, rport, + scope_id); + + if (pattrs[OVPN_A_PEER_LOCAL_IPV6]) { + void *ip = pattrs[OVPN_A_PEER_LOCAL_IPV6]; + + inet_ntop(AF_INET6, nla_data(ip), buf, sizeof(buf)); + fprintf(stderr, "\tLocal: %s:%hu\n", buf, lport); + } + } + + if (pattrs[OVPN_A_PEER_REMOTE_IPV4]) { + void *ip = pattrs[OVPN_A_PEER_REMOTE_IPV4]; + char buf[INET_ADDRSTRLEN]; + + inet_ntop(AF_INET, nla_data(ip), buf, sizeof(buf)); + fprintf(stderr, "\tRemote: %s:%hu\n", buf, rport); + + if (pattrs[OVPN_A_PEER_LOCAL_IPV4]) { + void *p = pattrs[OVPN_A_PEER_LOCAL_IPV4]; + + inet_ntop(AF_INET, nla_data(p), buf, sizeof(buf)); + fprintf(stderr, "\tLocal: %s:%hu\n", buf, lport); + } + } + + if (pattrs[OVPN_A_PEER_KEEPALIVE_INTERVAL]) { + void *p = pattrs[OVPN_A_PEER_KEEPALIVE_INTERVAL]; + + fprintf(stderr, "\tKeepalive interval: %u sec\n", + nla_get_u32(p)); + } + + if (pattrs[OVPN_A_PEER_KEEPALIVE_TIMEOUT]) + fprintf(stderr, "\tKeepalive timeout: %u sec\n", + nla_get_u32(pattrs[OVPN_A_PEER_KEEPALIVE_TIMEOUT])); + + if (pattrs[OVPN_A_PEER_VPN_RX_BYTES]) + fprintf(stderr, "\tVPN RX bytes: %" PRIu64 "\n", + nla_get_uint(pattrs[OVPN_A_PEER_VPN_RX_BYTES])); + + if (pattrs[OVPN_A_PEER_VPN_TX_BYTES]) + fprintf(stderr, "\tVPN TX bytes: %" PRIu64 "\n", + nla_get_uint(pattrs[OVPN_A_PEER_VPN_TX_BYTES])); + + if (pattrs[OVPN_A_PEER_VPN_RX_PACKETS]) + fprintf(stderr, "\tVPN RX packets: %" PRIu64 "\n", + nla_get_uint(pattrs[OVPN_A_PEER_VPN_RX_PACKETS])); + + if (pattrs[OVPN_A_PEER_VPN_TX_PACKETS]) + fprintf(stderr, "\tVPN TX packets: %" PRIu64 "\n", + nla_get_uint(pattrs[OVPN_A_PEER_VPN_TX_PACKETS])); + + if (pattrs[OVPN_A_PEER_LINK_RX_BYTES]) + fprintf(stderr, "\tLINK RX bytes: %" PRIu64 "\n", + nla_get_uint(pattrs[OVPN_A_PEER_LINK_RX_BYTES])); + + if (pattrs[OVPN_A_PEER_LINK_TX_BYTES]) + fprintf(stderr, "\tLINK TX bytes: %" PRIu64 "\n", + nla_get_uint(pattrs[OVPN_A_PEER_LINK_TX_BYTES])); + + if (pattrs[OVPN_A_PEER_LINK_RX_PACKETS]) + fprintf(stderr, "\tLINK RX packets: %" PRIu64 "\n", + nla_get_uint(pattrs[OVPN_A_PEER_LINK_RX_PACKETS])); + + if (pattrs[OVPN_A_PEER_LINK_TX_PACKETS]) + fprintf(stderr, "\tLINK TX packets: %" PRIu64 "\n", + nla_get_uint(pattrs[OVPN_A_PEER_LINK_TX_PACKETS])); + + return NL_SKIP; +} + +static int ovpn_get_peer(struct ovpn_ctx *ovpn) +{ + int flags = 0, ret = -1; + struct nlattr *attr; + struct nl_ctx *ctx; + + if (ovpn->peer_id == PEER_ID_UNDEF) + flags = NLM_F_DUMP; + + ctx = nl_ctx_alloc_flags(ovpn, OVPN_CMD_PEER_GET, flags); + if (!ctx) + return -ENOMEM; + + if (ovpn->peer_id != PEER_ID_UNDEF) { + attr = nla_nest_start(ctx->nl_msg, OVPN_A_PEER); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_PEER_ID, ovpn->peer_id); + nla_nest_end(ctx->nl_msg, attr); + } + + ret = ovpn_nl_msg_send(ctx, ovpn_handle_peer); +nla_put_failure: + nl_ctx_free(ctx); + return ret; +} + +static int ovpn_new_key(struct ovpn_ctx *ovpn) +{ + struct nlattr *keyconf, *key_dir; + struct nl_ctx *ctx; + int ret = -1; + + ctx = nl_ctx_alloc(ovpn, OVPN_CMD_KEY_NEW); + if (!ctx) + return -ENOMEM; + + keyconf = nla_nest_start(ctx->nl_msg, OVPN_A_KEYCONF); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_KEYCONF_PEER_ID, ovpn->peer_id); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_KEYCONF_SLOT, ovpn->key_slot); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_KEYCONF_KEY_ID, ovpn->key_id); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_KEYCONF_CIPHER_ALG, ovpn->cipher); + + key_dir = nla_nest_start(ctx->nl_msg, OVPN_A_KEYCONF_ENCRYPT_DIR); + NLA_PUT(ctx->nl_msg, OVPN_A_KEYDIR_CIPHER_KEY, KEY_LEN, ovpn->key_enc); + NLA_PUT(ctx->nl_msg, OVPN_A_KEYDIR_NONCE_TAIL, NONCE_LEN, ovpn->nonce); + nla_nest_end(ctx->nl_msg, key_dir); + + key_dir = nla_nest_start(ctx->nl_msg, OVPN_A_KEYCONF_DECRYPT_DIR); + NLA_PUT(ctx->nl_msg, OVPN_A_KEYDIR_CIPHER_KEY, KEY_LEN, ovpn->key_dec); + NLA_PUT(ctx->nl_msg, OVPN_A_KEYDIR_NONCE_TAIL, NONCE_LEN, ovpn->nonce); + nla_nest_end(ctx->nl_msg, key_dir); + + nla_nest_end(ctx->nl_msg, keyconf); + + ret = ovpn_nl_msg_send(ctx, NULL); +nla_put_failure: + nl_ctx_free(ctx); + return ret; +} + +static int ovpn_del_key(struct ovpn_ctx *ovpn) +{ + struct nlattr *keyconf; + struct nl_ctx *ctx; + int ret = -1; + + ctx = nl_ctx_alloc(ovpn, OVPN_CMD_KEY_DEL); + if (!ctx) + return -ENOMEM; + + keyconf = nla_nest_start(ctx->nl_msg, OVPN_A_KEYCONF); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_KEYCONF_PEER_ID, ovpn->peer_id); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_KEYCONF_SLOT, ovpn->key_slot); + nla_nest_end(ctx->nl_msg, keyconf); + + ret = ovpn_nl_msg_send(ctx, NULL); +nla_put_failure: + nl_ctx_free(ctx); + return ret; +} + +static int ovpn_handle_key(struct nl_msg *msg, void (*arg)__always_unused) +{ + struct nlattr *kattrs[OVPN_A_KEYCONF_MAX + 1]; + struct genlmsghdr *gnlh = nlmsg_data(nlmsg_hdr(msg)); + struct nlattr *attrs[OVPN_A_MAX + 1]; + + nla_parse(attrs, OVPN_A_MAX, genlmsg_attrdata(gnlh, 0), + genlmsg_attrlen(gnlh, 0), NULL); + + if (!attrs[OVPN_A_KEYCONF]) { + fprintf(stderr, "no packet content in netlink message\n"); + return NL_SKIP; + } + + nla_parse(kattrs, OVPN_A_KEYCONF_MAX, nla_data(attrs[OVPN_A_KEYCONF]), + nla_len(attrs[OVPN_A_KEYCONF]), NULL); + + if (kattrs[OVPN_A_KEYCONF_PEER_ID]) + fprintf(stderr, "* Peer %u\n", + nla_get_u32(kattrs[OVPN_A_KEYCONF_PEER_ID])); + if (kattrs[OVPN_A_KEYCONF_SLOT]) { + fprintf(stderr, "\t- Slot: "); + switch (nla_get_u32(kattrs[OVPN_A_KEYCONF_SLOT])) { + case OVPN_KEY_SLOT_PRIMARY: + fprintf(stderr, "primary\n"); + break; + case OVPN_KEY_SLOT_SECONDARY: + fprintf(stderr, "secondary\n"); + break; + default: + fprintf(stderr, "invalid (%u)\n", + nla_get_u32(kattrs[OVPN_A_KEYCONF_SLOT])); + break; + } + } + if (kattrs[OVPN_A_KEYCONF_KEY_ID]) + fprintf(stderr, "\t- Key ID: %u\n", + nla_get_u32(kattrs[OVPN_A_KEYCONF_KEY_ID])); + if (kattrs[OVPN_A_KEYCONF_CIPHER_ALG]) { + fprintf(stderr, "\t- Cipher: "); + switch (nla_get_u32(kattrs[OVPN_A_KEYCONF_CIPHER_ALG])) { + case OVPN_CIPHER_ALG_NONE: + fprintf(stderr, "none\n"); + break; + case OVPN_CIPHER_ALG_AES_GCM: + fprintf(stderr, "aes-gcm\n"); + break; + case OVPN_CIPHER_ALG_CHACHA20_POLY1305: + fprintf(stderr, "chacha20poly1305\n"); + break; + default: + fprintf(stderr, "invalid (%u)\n", + nla_get_u32(kattrs[OVPN_A_KEYCONF_CIPHER_ALG])); + break; + } + } + + return NL_SKIP; +} + +static int ovpn_get_key(struct ovpn_ctx *ovpn) +{ + struct nlattr *keyconf; + struct nl_ctx *ctx; + int ret = -1; + + ctx = nl_ctx_alloc(ovpn, OVPN_CMD_KEY_GET); + if (!ctx) + return -ENOMEM; + + keyconf = nla_nest_start(ctx->nl_msg, OVPN_A_KEYCONF); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_KEYCONF_PEER_ID, ovpn->peer_id); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_KEYCONF_SLOT, ovpn->key_slot); + nla_nest_end(ctx->nl_msg, keyconf); + + ret = ovpn_nl_msg_send(ctx, ovpn_handle_key); +nla_put_failure: + nl_ctx_free(ctx); + return ret; +} + +static int ovpn_swap_keys(struct ovpn_ctx *ovpn) +{ + struct nl_ctx *ctx; + struct nlattr *kc; + int ret = -1; + + ctx = nl_ctx_alloc(ovpn, OVPN_CMD_KEY_SWAP); + if (!ctx) + return -ENOMEM; + + kc = nla_nest_start(ctx->nl_msg, OVPN_A_KEYCONF); + NLA_PUT_U32(ctx->nl_msg, OVPN_A_KEYCONF_PEER_ID, ovpn->peer_id); + nla_nest_end(ctx->nl_msg, kc); + + ret = ovpn_nl_msg_send(ctx, NULL); +nla_put_failure: + nl_ctx_free(ctx); + return ret; +} + +/** + * Helper function used to easily add attributes to a rtnl message + */ +static int ovpn_addattr(struct nlmsghdr *n, int maxlen, int type, + const void *data, int alen) +{ + int len = RTA_LENGTH(alen); + struct rtattr *rta; + + if ((int)(NLMSG_ALIGN(n->nlmsg_len) + RTA_ALIGN(len)) > maxlen) { + fprintf(stderr, "%s: rtnl: message exceeded bound of %d\n", + __func__, maxlen); + return -EMSGSIZE; + } + + rta = nlmsg_tail(n); + rta->rta_type = type; + rta->rta_len = len; + + if (!data) + memset(RTA_DATA(rta), 0, alen); + else + memcpy(RTA_DATA(rta), data, alen); + + n->nlmsg_len = NLMSG_ALIGN(n->nlmsg_len) + RTA_ALIGN(len); + + return 0; +} + +static struct rtattr *ovpn_nest_start(struct nlmsghdr *msg, size_t max_size, + int attr) +{ + struct rtattr *nest = nlmsg_tail(msg); + + if (ovpn_addattr(msg, max_size, attr, NULL, 0) < 0) + return NULL; + + return nest; +} + +static void ovpn_nest_end(struct nlmsghdr *msg, struct rtattr *nest) +{ + nest->rta_len = (uint8_t *)nlmsg_tail(msg) - (uint8_t *)nest; +} + +#define RT_SNDBUF_SIZE (1024 * 2) +#define RT_RCVBUF_SIZE (1024 * 4) + +/** + * Open RTNL socket + */ +static int ovpn_rt_socket(void) +{ + int sndbuf = RT_SNDBUF_SIZE, rcvbuf = RT_RCVBUF_SIZE, fd; + + fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); + if (fd < 0) { + fprintf(stderr, "%s: cannot open netlink socket\n", __func__); + return fd; + } + + if (setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &sndbuf, + sizeof(sndbuf)) < 0) { + fprintf(stderr, "%s: SO_SNDBUF\n", __func__); + close(fd); + return -1; + } + + if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &rcvbuf, + sizeof(rcvbuf)) < 0) { + fprintf(stderr, "%s: SO_RCVBUF\n", __func__); + close(fd); + return -1; + } + + return fd; +} + +/** + * Bind socket to Netlink subsystem + */ +static int ovpn_rt_bind(int fd, uint32_t groups) +{ + struct sockaddr_nl local = { 0 }; + socklen_t addr_len; + + local.nl_family = AF_NETLINK; + local.nl_groups = groups; + + if (bind(fd, (struct sockaddr *)&local, sizeof(local)) < 0) { + fprintf(stderr, "%s: cannot bind netlink socket: %d\n", + __func__, errno); + return -errno; + } + + addr_len = sizeof(local); + if (getsockname(fd, (struct sockaddr *)&local, &addr_len) < 0) { + fprintf(stderr, "%s: cannot getsockname: %d\n", __func__, + errno); + return -errno; + } + + if (addr_len != sizeof(local)) { + fprintf(stderr, "%s: wrong address length %d\n", __func__, + addr_len); + return -EINVAL; + } + + if (local.nl_family != AF_NETLINK) { + fprintf(stderr, "%s: wrong address family %d\n", __func__, + local.nl_family); + return -EINVAL; + } + + return 0; +} + +typedef int (*ovpn_parse_reply_cb)(struct nlmsghdr *msg, void *arg); + +/** + * Send Netlink message and run callback on reply (if specified) + */ +static int ovpn_rt_send(struct nlmsghdr *payload, pid_t peer, + unsigned int groups, ovpn_parse_reply_cb cb, + void *arg_cb) +{ + int len, rem_len, fd, ret, rcv_len; + struct sockaddr_nl nladdr = { 0 }; + struct nlmsgerr *err; + struct nlmsghdr *h; + char buf[1024 * 16]; + struct iovec iov = { + .iov_base = payload, + .iov_len = payload->nlmsg_len, + }; + struct msghdr nlmsg = { + .msg_name = &nladdr, + .msg_namelen = sizeof(nladdr), + .msg_iov = &iov, + .msg_iovlen = 1, + }; + + nladdr.nl_family = AF_NETLINK; + nladdr.nl_pid = peer; + nladdr.nl_groups = groups; + + payload->nlmsg_seq = time(NULL); + + /* no need to send reply */ + if (!cb) + payload->nlmsg_flags |= NLM_F_ACK; + + fd = ovpn_rt_socket(); + if (fd < 0) { + fprintf(stderr, "%s: can't open rtnl socket\n", __func__); + return -errno; + } + + ret = ovpn_rt_bind(fd, 0); + if (ret < 0) { + fprintf(stderr, "%s: can't bind rtnl socket\n", __func__); + ret = -errno; + goto out; + } + + ret = sendmsg(fd, &nlmsg, 0); + if (ret < 0) { + fprintf(stderr, "%s: rtnl: error on sendmsg()\n", __func__); + ret = -errno; + goto out; + } + + /* prepare buffer to store RTNL replies */ + memset(buf, 0, sizeof(buf)); + iov.iov_base = buf; + + while (1) { + /* + * iov_len is modified by recvmsg(), therefore has to be initialized before + * using it again + */ + iov.iov_len = sizeof(buf); + rcv_len = recvmsg(fd, &nlmsg, 0); + if (rcv_len < 0) { + if (errno == EINTR || errno == EAGAIN) { + fprintf(stderr, "%s: interrupted call\n", + __func__); + continue; + } + fprintf(stderr, "%s: rtnl: error on recvmsg()\n", + __func__); + ret = -errno; + goto out; + } + + if (rcv_len == 0) { + fprintf(stderr, + "%s: rtnl: socket reached unexpected EOF\n", + __func__); + ret = -EIO; + goto out; + } + + if (nlmsg.msg_namelen != sizeof(nladdr)) { + fprintf(stderr, + "%s: sender address length: %u (expected %zu)\n", + __func__, nlmsg.msg_namelen, sizeof(nladdr)); + ret = -EIO; + goto out; + } + + h = (struct nlmsghdr *)buf; + while (rcv_len >= (int)sizeof(*h)) { + len = h->nlmsg_len; + rem_len = len - sizeof(*h); + + if (rem_len < 0 || len > rcv_len) { + if (nlmsg.msg_flags & MSG_TRUNC) { + fprintf(stderr, "%s: truncated message\n", + __func__); + ret = -EIO; + goto out; + } + fprintf(stderr, "%s: malformed message: len=%d\n", + __func__, len); + ret = -EIO; + goto out; + } + + if (h->nlmsg_type == NLMSG_DONE) { + ret = 0; + goto out; + } + + if (h->nlmsg_type == NLMSG_ERROR) { + err = (struct nlmsgerr *)NLMSG_DATA(h); + if (rem_len < (int)sizeof(struct nlmsgerr)) { + fprintf(stderr, "%s: ERROR truncated\n", + __func__); + ret = -EIO; + goto out; + } + + if (err->error) { + fprintf(stderr, "%s: (%d) %s\n", + __func__, err->error, + strerror(-err->error)); + ret = err->error; + goto out; + } + + ret = 0; + if (cb) { + int r = cb(h, arg_cb); + + if (r <= 0) + ret = r; + } + goto out; + } + + if (cb) { + int r = cb(h, arg_cb); + + if (r <= 0) { + ret = r; + goto out; + } + } else { + fprintf(stderr, "%s: RTNL: unexpected reply\n", + __func__); + } + + rcv_len -= NLMSG_ALIGN(len); + h = (struct nlmsghdr *)((uint8_t *)h + + NLMSG_ALIGN(len)); + } + + if (nlmsg.msg_flags & MSG_TRUNC) { + fprintf(stderr, "%s: message truncated\n", __func__); + continue; + } + + if (rcv_len) { + fprintf(stderr, "%s: rtnl: %d not parsed bytes\n", + __func__, rcv_len); + ret = -1; + goto out; + } + } +out: + close(fd); + + return ret; +} + +struct ovpn_link_req { + struct nlmsghdr n; + struct ifinfomsg i; + char buf[256]; +}; + +static int ovpn_new_iface(struct ovpn_ctx *ovpn) +{ + struct rtattr *linkinfo, *data; + struct ovpn_link_req req = { 0 }; + int ret = -1; + + fprintf(stdout, "Creating interface %s with mode %u\n", ovpn->ifname, + ovpn->mode); + + req.n.nlmsg_len = NLMSG_LENGTH(sizeof(req.i)); + req.n.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL; + req.n.nlmsg_type = RTM_NEWLINK; + + if (ovpn_addattr(&req.n, sizeof(req), IFLA_IFNAME, ovpn->ifname, + strlen(ovpn->ifname) + 1) < 0) + goto err; + + linkinfo = ovpn_nest_start(&req.n, sizeof(req), IFLA_LINKINFO); + if (!linkinfo) + goto err; + + if (ovpn_addattr(&req.n, sizeof(req), IFLA_INFO_KIND, OVPN_FAMILY_NAME, + strlen(OVPN_FAMILY_NAME) + 1) < 0) + goto err; + + if (ovpn->mode_set) { + data = ovpn_nest_start(&req.n, sizeof(req), IFLA_INFO_DATA); + if (!data) + goto err; + + if (ovpn_addattr(&req.n, sizeof(req), IFLA_OVPN_MODE, + &ovpn->mode, sizeof(uint8_t)) < 0) + goto err; + + ovpn_nest_end(&req.n, data); + } + + ovpn_nest_end(&req.n, linkinfo); + + req.i.ifi_family = AF_PACKET; + + ret = ovpn_rt_send(&req.n, 0, 0, NULL, NULL); +err: + return ret; +} + +static int ovpn_del_iface(struct ovpn_ctx *ovpn) +{ + struct ovpn_link_req req = { 0 }; + + fprintf(stdout, "Deleting interface %s ifindex %u\n", ovpn->ifname, + ovpn->ifindex); + + req.n.nlmsg_len = NLMSG_LENGTH(sizeof(req.i)); + req.n.nlmsg_flags = NLM_F_REQUEST; + req.n.nlmsg_type = RTM_DELLINK; + + req.i.ifi_family = AF_PACKET; + req.i.ifi_index = ovpn->ifindex; + + return ovpn_rt_send(&req.n, 0, 0, NULL, NULL); +} + +static int nl_seq_check(struct nl_msg (*msg)__always_unused, + void (*arg)__always_unused) +{ + return NL_OK; +} + +struct mcast_handler_args { + const char *group; + int id; +}; + +static int mcast_family_handler(struct nl_msg *msg, void *arg) +{ + struct mcast_handler_args *grp = arg; + struct nlattr *tb[CTRL_ATTR_MAX + 1]; + struct genlmsghdr *gnlh = nlmsg_data(nlmsg_hdr(msg)); + struct nlattr *mcgrp; + int rem_mcgrp; + + nla_parse(tb, CTRL_ATTR_MAX, genlmsg_attrdata(gnlh, 0), + genlmsg_attrlen(gnlh, 0), NULL); + + if (!tb[CTRL_ATTR_MCAST_GROUPS]) + return NL_SKIP; + + nla_for_each_nested(mcgrp, tb[CTRL_ATTR_MCAST_GROUPS], rem_mcgrp) { + struct nlattr *tb_mcgrp[CTRL_ATTR_MCAST_GRP_MAX + 1]; + + nla_parse(tb_mcgrp, CTRL_ATTR_MCAST_GRP_MAX, + nla_data(mcgrp), nla_len(mcgrp), NULL); + + if (!tb_mcgrp[CTRL_ATTR_MCAST_GRP_NAME] || + !tb_mcgrp[CTRL_ATTR_MCAST_GRP_ID]) + continue; + if (strncmp(nla_data(tb_mcgrp[CTRL_ATTR_MCAST_GRP_NAME]), + grp->group, nla_len(tb_mcgrp[CTRL_ATTR_MCAST_GRP_NAME]))) + continue; + grp->id = nla_get_u32(tb_mcgrp[CTRL_ATTR_MCAST_GRP_ID]); + break; + } + + return NL_SKIP; +} + +static int mcast_error_handler(struct sockaddr_nl (*nla)__always_unused, + struct nlmsgerr *err, void *arg) +{ + int *ret = arg; + + *ret = err->error; + return NL_STOP; +} + +static int mcast_ack_handler(struct nl_msg (*msg)__always_unused, void *arg) +{ + int *ret = arg; + + *ret = 0; + return NL_STOP; +} + +static int ovpn_handle_msg(struct nl_msg *msg, void *arg) +{ + struct genlmsghdr *gnlh = nlmsg_data(nlmsg_hdr(msg)); + struct nlattr *attrs[OVPN_A_MAX + 1]; + struct nlmsghdr *nlh = nlmsg_hdr(msg); + //enum ovpn_del_peer_reason reason; + char ifname[IF_NAMESIZE]; + int *ret = arg; + __u32 ifindex; + + fprintf(stderr, "received message from ovpn-dco\n"); + + *ret = -1; + + if (!genlmsg_valid_hdr(nlh, 0)) { + fprintf(stderr, "invalid header\n"); + return NL_STOP; + } + + if (nla_parse(attrs, OVPN_A_MAX, genlmsg_attrdata(gnlh, 0), + genlmsg_attrlen(gnlh, 0), NULL)) { + fprintf(stderr, "received bogus data from ovpn-dco\n"); + return NL_STOP; + } + + if (!attrs[OVPN_A_IFINDEX]) { + fprintf(stderr, "no ifindex in this message\n"); + return NL_STOP; + } + + ifindex = nla_get_u32(attrs[OVPN_A_IFINDEX]); + if (!if_indextoname(ifindex, ifname)) { + fprintf(stderr, "cannot resolve ifname for ifindex: %u\n", + ifindex); + return NL_STOP; + } + + switch (gnlh->cmd) { + case OVPN_CMD_PEER_DEL_NTF: + /*if (!attrs[OVPN_A_DEL_PEER_REASON]) { + * fprintf(stderr, "no reason in DEL_PEER message\n"); + * return NL_STOP; + *} + * + *reason = nla_get_u8(attrs[OVPN_A_DEL_PEER_REASON]); + *fprintf(stderr, + * "received CMD_DEL_PEER, ifname: %s reason: %d\n", + * ifname, reason); + */ + fprintf(stdout, "received CMD_PEER_DEL_NTF\n"); + break; + case OVPN_CMD_KEY_SWAP_NTF: + fprintf(stdout, "received CMD_KEY_SWAP_NTF\n"); + break; + default: + fprintf(stderr, "received unknown command: %d\n", gnlh->cmd); + return NL_STOP; + } + + *ret = 0; + return NL_OK; +} + +static int ovpn_get_mcast_id(struct nl_sock *sock, const char *family, + const char *group) +{ + struct nl_msg *msg; + struct nl_cb *cb; + int ret, ctrlid; + struct mcast_handler_args grp = { + .group = group, + .id = -ENOENT, + }; + + msg = nlmsg_alloc(); + if (!msg) + return -ENOMEM; + + cb = nl_cb_alloc(NL_CB_DEFAULT); + if (!cb) { + ret = -ENOMEM; + goto out_fail_cb; + } + + ctrlid = genl_ctrl_resolve(sock, "nlctrl"); + + genlmsg_put(msg, 0, 0, ctrlid, 0, 0, CTRL_CMD_GETFAMILY, 0); + + ret = -ENOBUFS; + NLA_PUT_STRING(msg, CTRL_ATTR_FAMILY_NAME, family); + + ret = nl_send_auto_complete(sock, msg); + if (ret < 0) + goto nla_put_failure; + + ret = 1; + + nl_cb_err(cb, NL_CB_CUSTOM, mcast_error_handler, &ret); + nl_cb_set(cb, NL_CB_ACK, NL_CB_CUSTOM, mcast_ack_handler, &ret); + nl_cb_set(cb, NL_CB_VALID, NL_CB_CUSTOM, mcast_family_handler, &grp); + + while (ret > 0) + nl_recvmsgs(sock, cb); + + if (ret == 0) + ret = grp.id; + nla_put_failure: + nl_cb_put(cb); + out_fail_cb: + nlmsg_free(msg); + return ret; +} + +static int ovpn_listen_mcast(void) +{ + struct nl_sock *sock; + struct nl_cb *cb; + int mcid, ret; + + sock = nl_socket_alloc(); + if (!sock) { + fprintf(stderr, "cannot allocate netlink socket\n"); + goto err_free; + } + + nl_socket_set_buffer_size(sock, 8192, 8192); + + ret = genl_connect(sock); + if (ret < 0) { + fprintf(stderr, "cannot connect to generic netlink: %s\n", + nl_geterror(ret)); + goto err_free; + } + + mcid = ovpn_get_mcast_id(sock, OVPN_FAMILY_NAME, OVPN_MCGRP_PEERS); + if (mcid < 0) { + fprintf(stderr, "cannot get mcast group: %s\n", + nl_geterror(mcid)); + goto err_free; + } + + ret = nl_socket_add_membership(sock, mcid); + if (ret) { + fprintf(stderr, "failed to join mcast group: %d\n", ret); + goto err_free; + } + + ret = 1; + cb = nl_cb_alloc(NL_CB_DEFAULT); + nl_cb_set(cb, NL_CB_SEQ_CHECK, NL_CB_CUSTOM, nl_seq_check, NULL); + nl_cb_set(cb, NL_CB_VALID, NL_CB_CUSTOM, ovpn_handle_msg, &ret); + nl_cb_err(cb, NL_CB_CUSTOM, ovpn_nl_cb_error, &ret); + + while (ret == 1) { + int err = nl_recvmsgs(sock, cb); + + if (err < 0) { + fprintf(stderr, + "cannot receive netlink message: (%d) %s\n", + err, nl_geterror(-err)); + ret = -1; + break; + } + } + + nl_cb_put(cb); +err_free: + nl_socket_free(sock); + return ret; +} + +static void usage(const char *cmd) +{ + fprintf(stderr, + "Usage %s [arguments..]\n", + cmd); + fprintf(stderr, "where can be one of the following\n\n"); + + fprintf(stderr, "* new_iface [mode]: create new ovpn interface\n"); + fprintf(stderr, "\tiface: ovpn interface name\n"); + fprintf(stderr, "\tmode:\n"); + fprintf(stderr, "\t\t- P2P for peer-to-peer mode (i.e. client)\n"); + fprintf(stderr, "\t\t- MP for multi-peer mode (i.e. server)\n"); + + fprintf(stderr, "* del_iface : delete ovpn interface\n"); + fprintf(stderr, "\tiface: ovpn interface name\n"); + + fprintf(stderr, + "* listen [ipv6]: listen for incoming peer TCP connections\n"); + fprintf(stderr, "\tiface: ovpn interface name\n"); + fprintf(stderr, "\tlport: TCP port to listen to\n"); + fprintf(stderr, + "\tpeers_file: file containing one peer per line: Line format:\n"); + fprintf(stderr, "\t\t \n"); + fprintf(stderr, + "\tipv6: whether the socket should listen to the IPv6 wildcard address\n"); + + fprintf(stderr, + "* connect [key_file]: start connecting peer of TCP-based VPN session\n"); + fprintf(stderr, "\tiface: ovpn interface name\n"); + fprintf(stderr, "\tpeer_id: peer ID of the connecting peer\n"); + fprintf(stderr, "\traddr: peer IP address to connect to\n"); + fprintf(stderr, "\trport: peer TCP port to connect to\n"); + fprintf(stderr, + "\tkey_file: file containing the symmetric key for encryption\n"); + + fprintf(stderr, + "* new_peer [vpnaddr]: add new peer\n"); + fprintf(stderr, "\tiface: ovpn interface name\n"); + fprintf(stderr, "\tlport: local UDP port to bind to\n"); + fprintf(stderr, + "\tpeer_id: peer ID to be used in data packets to/from this peer\n"); + fprintf(stderr, "\traddr: peer IP address\n"); + fprintf(stderr, "\trport: peer UDP port\n"); + fprintf(stderr, "\tvpnaddr: peer VPN IP\n"); + + fprintf(stderr, + "* new_multi_peer : add multiple peers as listed in the file\n"); + fprintf(stderr, "\tiface: ovpn interface name\n"); + fprintf(stderr, "\tlport: local UDP port to bind to\n"); + fprintf(stderr, + "\tpeers_file: text file containing one peer per line. Line format:\n"); + fprintf(stderr, "\t\t \n"); + + fprintf(stderr, + "* set_peer : set peer attributes\n"); + fprintf(stderr, "\tiface: ovpn interface name\n"); + fprintf(stderr, "\tpeer_id: peer ID of the peer to modify\n"); + fprintf(stderr, + "\tkeepalive_interval: interval for sending ping messages\n"); + fprintf(stderr, + "\tkeepalive_timeout: time after which a peer is timed out\n"); + + fprintf(stderr, "* del_peer : delete peer\n"); + fprintf(stderr, "\tiface: ovpn interface name\n"); + fprintf(stderr, "\tpeer_id: peer ID of the peer to delete\n"); + + fprintf(stderr, "* get_peer [peer_id]: retrieve peer(s) status\n"); + fprintf(stderr, "\tiface: ovpn interface name\n"); + fprintf(stderr, + "\tpeer_id: peer ID of the peer to query. All peers are returned if omitted\n"); + + fprintf(stderr, + "* new_key : set data channel key\n"); + fprintf(stderr, "\tiface: ovpn interface name\n"); + fprintf(stderr, + "\tpeer_id: peer ID of the peer to configure the key for\n"); + fprintf(stderr, "\tslot: either 1 (primary) or 2 (secondary)\n"); + fprintf(stderr, "\tkey_id: an ID from 0 to 7\n"); + fprintf(stderr, + "\tcipher: cipher to use, supported: aes (AES-GCM), chachapoly (CHACHA20POLY1305)\n"); + fprintf(stderr, + "\tkey_dir: key direction, must 0 on one host and 1 on the other\n"); + fprintf(stderr, "\tkey_file: file containing the pre-shared key\n"); + + fprintf(stderr, + "* del_key [slot]: erase existing data channel key\n"); + fprintf(stderr, "\tiface: ovpn interface name\n"); + fprintf(stderr, "\tpeer_id: peer ID of the peer to modify\n"); + fprintf(stderr, "\tslot: slot to erase. PRIMARY if omitted\n"); + + fprintf(stderr, + "* get_key : retrieve non sensible key data\n"); + fprintf(stderr, "\tiface: ovpn interface name\n"); + fprintf(stderr, "\tpeer_id: peer ID of the peer to query\n"); + fprintf(stderr, "\tslot: either 1 (primary) or 2 (secondary)\n"); + + fprintf(stderr, + "* swap_keys : swap content of primary and secondary key slots\n"); + fprintf(stderr, "\tiface: ovpn interface name\n"); + fprintf(stderr, "\tpeer_id: peer ID of the peer to modify\n"); + + fprintf(stderr, + "* listen_mcast: listen to ovpn netlink multicast messages\n"); +} + +static int ovpn_parse_remote(struct ovpn_ctx *ovpn, const char *host, + const char *service, const char *vpnip) +{ + int ret; + struct addrinfo *result; + struct addrinfo hints = { + .ai_family = ovpn->sa_family, + .ai_socktype = SOCK_DGRAM, + .ai_protocol = IPPROTO_UDP + }; + + if (host) { + ret = getaddrinfo(host, service, &hints, &result); + if (ret == EAI_NONAME || ret == EAI_FAIL) + return -1; + + if (!(result->ai_family == AF_INET && + result->ai_addrlen == sizeof(struct sockaddr_in)) && + !(result->ai_family == AF_INET6 && + result->ai_addrlen == sizeof(struct sockaddr_in6))) { + ret = -EINVAL; + goto out; + } + + memcpy(&ovpn->remote, result->ai_addr, result->ai_addrlen); + } + + if (vpnip) { + ret = getaddrinfo(vpnip, NULL, &hints, &result); + if (ret == EAI_NONAME || ret == EAI_FAIL) + return -1; + + if (!(result->ai_family == AF_INET && + result->ai_addrlen == sizeof(struct sockaddr_in)) && + !(result->ai_family == AF_INET6 && + result->ai_addrlen == sizeof(struct sockaddr_in6))) { + ret = -EINVAL; + goto out; + } + + memcpy(&ovpn->peer_ip, result->ai_addr, result->ai_addrlen); + ovpn->sa_family = result->ai_family; + + ovpn->peer_ip_set = true; + } + + ret = 0; +out: + freeaddrinfo(result); + return ret; +} + +static int ovpn_parse_new_peer(struct ovpn_ctx *ovpn, const char *peer_id, + const char *raddr, const char *rport, + const char *vpnip) +{ + ovpn->peer_id = strtoul(peer_id, NULL, 10); + if (errno == ERANGE || ovpn->peer_id > PEER_ID_UNDEF) { + fprintf(stderr, "peer ID value out of range\n"); + return -1; + } + + return ovpn_parse_remote(ovpn, raddr, rport, vpnip); +} + +static int ovpn_parse_key_slot(const char *arg, struct ovpn_ctx *ovpn) +{ + int slot = strtoul(arg, NULL, 10); + + if (errno == ERANGE || slot < 1 || slot > 2) { + fprintf(stderr, "key slot out of range\n"); + return -1; + } + + switch (slot) { + case 1: + ovpn->key_slot = OVPN_KEY_SLOT_PRIMARY; + break; + case 2: + ovpn->key_slot = OVPN_KEY_SLOT_SECONDARY; + break; + } + + return 0; +} + +static int ovpn_send_tcp_data(int socket) +{ + uint16_t len = htons(1000); + uint8_t buf[1002]; + int ret; + + memcpy(buf, &len, sizeof(len)); + memset(buf + sizeof(len), 0x86, sizeof(buf) - sizeof(len)); + + ret = send(socket, buf, sizeof(buf), 0); + + fprintf(stdout, "Sent %u bytes over TCP socket\n", ret); + + return ret > 0 ? 0 : ret; +} + +static int ovpn_recv_tcp_data(int socket) +{ + uint8_t buf[1002]; + uint16_t len; + int ret; + + ret = recv(socket, buf, sizeof(buf), 0); + + if (ret < 2) { + fprintf(stderr, ">>>> Error while reading TCP data: %d\n", ret); + return ret; + } + + memcpy(&len, buf, sizeof(len)); + len = ntohs(len); + + fprintf(stdout, ">>>> Received %u bytes over TCP socket, header: %u\n", + ret, len); + +/* int i; + * for (i = 2; i < ret; i++) { + * fprintf(stdout, "0x%.2x ", buf[i]); + * if (i && !((i - 2) % 16)) + * fprintf(stdout, "\n"); + * } + * fprintf(stdout, "\n"); + */ + return 0; +} + +static enum ovpn_cmd ovpn_parse_cmd(const char *cmd) +{ + if (!strcmp(cmd, "new_iface")) + return CMD_NEW_IFACE; + + if (!strcmp(cmd, "del_iface")) + return CMD_DEL_IFACE; + + if (!strcmp(cmd, "listen")) + return CMD_LISTEN; + + if (!strcmp(cmd, "connect")) + return CMD_CONNECT; + + if (!strcmp(cmd, "new_peer")) + return CMD_NEW_PEER; + + if (!strcmp(cmd, "new_multi_peer")) + return CMD_NEW_MULTI_PEER; + + if (!strcmp(cmd, "set_peer")) + return CMD_SET_PEER; + + if (!strcmp(cmd, "del_peer")) + return CMD_DEL_PEER; + + if (!strcmp(cmd, "get_peer")) + return CMD_GET_PEER; + + if (!strcmp(cmd, "new_key")) + return CMD_NEW_KEY; + + if (!strcmp(cmd, "del_key")) + return CMD_DEL_KEY; + + if (!strcmp(cmd, "get_key")) + return CMD_GET_KEY; + + if (!strcmp(cmd, "swap_keys")) + return CMD_SWAP_KEYS; + + if (!strcmp(cmd, "listen_mcast")) + return CMD_LISTEN_MCAST; + + return CMD_INVALID; +} + +static int ovpn_run_cmd(struct ovpn_ctx *ovpn) +{ + char peer_id[10], vpnip[INET6_ADDRSTRLEN], raddr[128], rport[10]; + int n, ret; + FILE *fp; + + switch (ovpn->cmd) { + case CMD_NEW_IFACE: + ret = ovpn_new_iface(ovpn); + break; + case CMD_DEL_IFACE: + ret = ovpn_del_iface(ovpn); + break; + case CMD_LISTEN: + ret = ovpn_listen(ovpn, ovpn->sa_family); + if (ret < 0) { + fprintf(stderr, "cannot listen on TCP socket\n"); + return ret; + } + + fp = fopen(ovpn->peers_file, "r"); + if (!fp) { + fprintf(stderr, "cannot open file: %s\n", + ovpn->peers_file); + return -1; + } + + while ((n = fscanf(fp, "%s %s\n", peer_id, vpnip)) == 2) { + struct ovpn_ctx peer_ctx = { 0 }; + + peer_ctx.ifindex = ovpn->ifindex; + peer_ctx.sa_family = ovpn->sa_family; + + peer_ctx.socket = ovpn_accept(ovpn); + if (peer_ctx.socket < 0) { + fprintf(stderr, "cannot accept connection!\n"); + return -1; + } + + /* store the socket of the first peer to test TCP I/O */ + if (ovpn->cli_socket < 0) + ovpn->cli_socket = peer_ctx.socket; + + ret = ovpn_parse_new_peer(&peer_ctx, peer_id, NULL, + NULL, vpnip); + if (ret < 0) { + fprintf(stderr, "error while parsing line\n"); + return -1; + } + + ret = ovpn_new_peer(&peer_ctx, true); + if (ret < 0) { + fprintf(stderr, + "cannot add peer to VPN: %s %s\n", + peer_id, vpnip); + return ret; + } + } + + if (ovpn->cli_socket >= 0) + ret = ovpn_recv_tcp_data(ovpn->cli_socket); + + break; + case CMD_CONNECT: + ret = ovpn_connect(ovpn); + if (ret < 0) { + fprintf(stderr, "cannot connect TCP socket\n"); + return ret; + } + + ret = ovpn_new_peer(ovpn, true); + if (ret < 0) { + fprintf(stderr, "cannot add peer to VPN\n"); + close(ovpn->socket); + return ret; + } + + if (ovpn->cipher != OVPN_CIPHER_ALG_NONE) { + ret = ovpn_new_key(ovpn); + if (ret < 0) { + fprintf(stderr, "cannot set key\n"); + return ret; + } + } + + ret = ovpn_send_tcp_data(ovpn->socket); + break; + case CMD_NEW_PEER: + ret = ovpn_udp_socket(ovpn, AF_INET6); //ovpn->sa_family ? + if (ret < 0) + return ret; + + ret = ovpn_new_peer(ovpn, false); + break; + case CMD_NEW_MULTI_PEER: + ret = ovpn_udp_socket(ovpn, AF_INET6); + if (ret < 0) + return ret; + + fp = fopen(ovpn->peers_file, "r"); + if (!fp) { + fprintf(stderr, "cannot open file: %s\n", + ovpn->peers_file); + return -1; + } + + while ((n = fscanf(fp, "%s %s %s %s\n", peer_id, raddr, rport, + vpnip)) == 4) { + struct ovpn_ctx peer_ctx = { 0 }; + + peer_ctx.ifindex = ovpn->ifindex; + peer_ctx.socket = ovpn->socket; + peer_ctx.sa_family = AF_UNSPEC; + + ret = ovpn_parse_new_peer(&peer_ctx, peer_id, raddr, + rport, vpnip); + if (ret < 0) { + fprintf(stderr, "error while parsing line\n"); + return -1; + } + + ret = ovpn_new_peer(&peer_ctx, false); + if (ret < 0) { + fprintf(stderr, + "cannot add peer to VPN: %s %s %s %s\n", + peer_id, raddr, rport, vpnip); + return ret; + } + } + break; + case CMD_SET_PEER: + ret = ovpn_set_peer(ovpn); + break; + case CMD_DEL_PEER: + ret = ovpn_del_peer(ovpn); + break; + case CMD_GET_PEER: + if (ovpn->peer_id == PEER_ID_UNDEF) + fprintf(stderr, "List of peers connected to: %s\n", + ovpn->ifname); + + ret = ovpn_get_peer(ovpn); + break; + case CMD_NEW_KEY: + ret = ovpn_new_key(ovpn); + break; + case CMD_DEL_KEY: + ret = ovpn_del_key(ovpn); + break; + case CMD_GET_KEY: + ret = ovpn_get_key(ovpn); + break; + case CMD_SWAP_KEYS: + ret = ovpn_swap_keys(ovpn); + break; + case CMD_LISTEN_MCAST: + ret = ovpn_listen_mcast(); + break; + case CMD_INVALID: + break; + } + + return ret; +} + +static int ovpn_parse_cmd_args(struct ovpn_ctx *ovpn, int argc, char *argv[]) +{ + int ret; + + /* no args required for LISTEN_MCAST */ + if (ovpn->cmd == CMD_LISTEN_MCAST) + return 0; + + /* all commands need an ifname */ + if (argc < 3) + return -EINVAL; + + strscpy(ovpn->ifname, argv[2], IFNAMSIZ - 1); + ovpn->ifname[IFNAMSIZ - 1] = '\0'; + + /* all commands, except NEW_IFNAME, needs an ifindex */ + if (ovpn->cmd != CMD_NEW_IFACE) { + ovpn->ifindex = if_nametoindex(ovpn->ifname); + if (!ovpn->ifindex) { + fprintf(stderr, "cannot find interface: %s\n", + strerror(errno)); + return -1; + } + } + + switch (ovpn->cmd) { + case CMD_NEW_IFACE: + if (argc < 4) + break; + + if (!strcmp(argv[3], "P2P")) { + ovpn->mode = OVPN_MODE_P2P; + } else if (!strcmp(argv[3], "MP")) { + ovpn->mode = OVPN_MODE_MP; + } else { + fprintf(stderr, "Cannot parse iface mode: %s\n", + argv[3]); + return -1; + } + ovpn->mode_set = true; + break; + case CMD_DEL_IFACE: + break; + case CMD_LISTEN: + if (argc < 5) + return -EINVAL; + + ovpn->lport = strtoul(argv[3], NULL, 10); + if (errno == ERANGE || ovpn->lport > 65535) { + fprintf(stderr, "lport value out of range\n"); + return -1; + } + + ovpn->peers_file = argv[4]; + + if (argc > 5 && !strcmp(argv[5], "ipv6")) + ovpn->sa_family = AF_INET6; + break; + case CMD_CONNECT: + if (argc < 6) + return -EINVAL; + + ovpn->sa_family = AF_INET; + + ret = ovpn_parse_new_peer(ovpn, argv[3], argv[4], argv[5], + NULL); + if (ret < 0) { + fprintf(stderr, "Cannot parse remote peer data\n"); + return -1; + } + + if (argc > 6) { + ovpn->key_slot = OVPN_KEY_SLOT_PRIMARY; + ovpn->key_id = 0; + ovpn->cipher = OVPN_CIPHER_ALG_AES_GCM; + ovpn->key_dir = KEY_DIR_OUT; + + ret = ovpn_parse_key(argv[6], ovpn); + if (ret) + return -1; + } + break; + case CMD_NEW_PEER: + if (argc < 7) + return -EINVAL; + + ovpn->lport = strtoul(argv[4], NULL, 10); + if (errno == ERANGE || ovpn->lport > 65535) { + fprintf(stderr, "lport value out of range\n"); + return -1; + } + + const char *vpnip = (argc > 7) ? argv[7] : NULL; + + ret = ovpn_parse_new_peer(ovpn, argv[3], argv[5], argv[6], + vpnip); + if (ret < 0) + return -1; + break; + case CMD_NEW_MULTI_PEER: + if (argc < 5) + return -EINVAL; + + ovpn->lport = strtoul(argv[3], NULL, 10); + if (errno == ERANGE || ovpn->lport > 65535) { + fprintf(stderr, "lport value out of range\n"); + return -1; + } + + ovpn->peers_file = argv[4]; + break; + case CMD_SET_PEER: + if (argc < 6) + return -EINVAL; + + ovpn->peer_id = strtoul(argv[3], NULL, 10); + if (errno == ERANGE || ovpn->peer_id > PEER_ID_UNDEF) { + fprintf(stderr, "peer ID value out of range\n"); + return -1; + } + + ovpn->keepalive_interval = strtoul(argv[4], NULL, 10); + if (errno == ERANGE) { + fprintf(stderr, + "keepalive interval value out of range\n"); + return -1; + } + + ovpn->keepalive_timeout = strtoul(argv[5], NULL, 10); + if (errno == ERANGE) { + fprintf(stderr, + "keepalive interval value out of range\n"); + return -1; + } + break; + case CMD_DEL_PEER: + if (argc < 4) + return -EINVAL; + + ovpn->peer_id = strtoul(argv[3], NULL, 10); + if (errno == ERANGE || ovpn->peer_id > PEER_ID_UNDEF) { + fprintf(stderr, "peer ID value out of range\n"); + return -1; + } + break; + case CMD_GET_PEER: + ovpn->peer_id = PEER_ID_UNDEF; + if (argc > 3) { + ovpn->peer_id = strtoul(argv[3], NULL, 10); + if (errno == ERANGE || ovpn->peer_id > PEER_ID_UNDEF) { + fprintf(stderr, "peer ID value out of range\n"); + return -1; + } + } + break; + case CMD_NEW_KEY: + if (argc < 9) + return -EINVAL; + + ovpn->peer_id = strtoul(argv[3], NULL, 10); + if (errno == ERANGE) { + fprintf(stderr, "peer ID value out of range\n"); + return -1; + } + + ret = ovpn_parse_key_slot(argv[4], ovpn); + if (ret) + return -1; + + ovpn->key_id = strtoul(argv[5], NULL, 10); + if (errno == ERANGE || ovpn->key_id > 2) { + fprintf(stderr, "key ID out of range\n"); + return -1; + } + + ret = ovpn_parse_cipher(argv[6], ovpn); + if (ret < 0) + return -1; + + ret = ovpn_parse_key_direction(argv[7], ovpn); + if (ret < 0) + return -1; + + ret = ovpn_parse_key(argv[8], ovpn); + if (ret) + return -1; + break; + case CMD_DEL_KEY: + if (argc < 4) + return -EINVAL; + + ovpn->peer_id = strtoul(argv[3], NULL, 10); + if (errno == ERANGE) { + fprintf(stderr, "peer ID value out of range\n"); + return -1; + } + + ret = ovpn_parse_key_slot(argv[4], ovpn); + if (ret) + return ret; + break; + case CMD_GET_KEY: + if (argc < 5) + return -EINVAL; + + ovpn->peer_id = strtoul(argv[3], NULL, 10); + if (errno == ERANGE) { + fprintf(stderr, "peer ID value out of range\n"); + return -1; + } + + ret = ovpn_parse_key_slot(argv[4], ovpn); + if (ret) + return ret; + break; + case CMD_SWAP_KEYS: + if (argc < 4) + return -EINVAL; + + ovpn->peer_id = strtoul(argv[3], NULL, 10); + if (errno == ERANGE) { + fprintf(stderr, "peer ID value out of range\n"); + return -1; + } + break; + case CMD_LISTEN_MCAST: + break; + case CMD_INVALID: + break; + } + + return 0; +} + +int main(int argc, char *argv[]) +{ + struct ovpn_ctx ovpn; + int ret; + + if (argc < 2) { + usage(argv[0]); + return -1; + } + + memset(&ovpn, 0, sizeof(ovpn)); + ovpn.sa_family = AF_INET; + ovpn.cipher = OVPN_CIPHER_ALG_NONE; + ovpn.cli_socket = -1; + + ovpn.cmd = ovpn_parse_cmd(argv[1]); + if (ovpn.cmd == CMD_INVALID) { + fprintf(stderr, "Error: unknown command.\n\n"); + usage(argv[0]); + return -1; + } + + ret = ovpn_parse_cmd_args(&ovpn, argc, argv); + if (ret < 0) { + fprintf(stderr, "Error: invalid arguments.\n\n"); + if (ret == -EINVAL) + usage(argv[0]); + return ret; + } + + ret = ovpn_run_cmd(&ovpn); + if (ret) + fprintf(stderr, "Cannot execute command: %s (%d)\n", + strerror(-ret), ret); + + return ret; +} diff --git a/tools/testing/selftests/net/ovpn/tcp_peers.txt b/tools/testing/selftests/net/ovpn/tcp_peers.txt new file mode 100644 index 000000000000..d753eebe8716 --- /dev/null +++ b/tools/testing/selftests/net/ovpn/tcp_peers.txt @@ -0,0 +1,5 @@ +1 5.5.5.2 +2 5.5.5.3 +3 5.5.5.4 +4 5.5.5.5 +5 5.5.5.6 diff --git a/tools/testing/selftests/net/ovpn/test-chachapoly.sh b/tools/testing/selftests/net/ovpn/test-chachapoly.sh new file mode 100755 index 000000000000..79788f10d33b --- /dev/null +++ b/tools/testing/selftests/net/ovpn/test-chachapoly.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# SPDX-License-Identifier: GPL-2.0 +# Copyright (C) 2024 OpenVPN, Inc. +# +# Author: Antonio Quartulli + +ALG="chachapoly" + +source test.sh diff --git a/tools/testing/selftests/net/ovpn/test-float.sh b/tools/testing/selftests/net/ovpn/test-float.sh new file mode 100755 index 000000000000..93e1b729861d --- /dev/null +++ b/tools/testing/selftests/net/ovpn/test-float.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# SPDX-License-Identifier: GPL-2.0 +# Copyright (C) 2024 OpenVPN, Inc. +# +# Author: Antonio Quartulli + +FLOAT="1" + +source test.sh diff --git a/tools/testing/selftests/net/ovpn/test-tcp.sh b/tools/testing/selftests/net/ovpn/test-tcp.sh new file mode 100755 index 000000000000..7542f595cc56 --- /dev/null +++ b/tools/testing/selftests/net/ovpn/test-tcp.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# SPDX-License-Identifier: GPL-2.0 +# Copyright (C) 2024 OpenVPN, Inc. +# +# Author: Antonio Quartulli + +PROTO="TCP" + +source test.sh diff --git a/tools/testing/selftests/net/ovpn/test.sh b/tools/testing/selftests/net/ovpn/test.sh new file mode 100755 index 000000000000..07f3a82df8f3 --- /dev/null +++ b/tools/testing/selftests/net/ovpn/test.sh @@ -0,0 +1,183 @@ +#!/bin/bash +# SPDX-License-Identifier: GPL-2.0 +# Copyright (C) 2020-2024 OpenVPN, Inc. +# +# Author: Antonio Quartulli + +#set -x +set -e + +UDP_PEERS_FILE=${UDP_PEERS_FILE:-udp_peers.txt} +TCP_PEERS_FILE=${TCP_PEERS_FILE:-tcp_peers.txt} +OVPN_CLI=${OVPN_CLI:-./ovpn-cli} +ALG=${ALG:-aes} +PROTO=${PROTO:-UDP} +FLOAT=${FLOAT:-0} + +create_ns() { + ip netns add peer${1} +} + +setup_ns() { + MODE="P2P" + + if [ ${1} -eq 0 ]; then + MODE="MP" + for p in $(seq 1 ${NUM_PEERS}); do + ip link add veth${p} netns peer0 type veth peer name veth${p} netns peer${p} + + ip -n peer0 addr add 10.10.${p}.1/24 dev veth${p} + ip -n peer0 link set veth${p} up + + ip -n peer${p} addr add 10.10.${p}.2/24 dev veth${p} + ip -n peer${p} link set veth${p} up + done + fi + + ip netns exec peer${1} ${OVPN_CLI} new_iface tun${1} $MODE + ip -n peer${1} addr add ${2} dev tun${1} + ip -n peer${1} link set tun${1} up +} + +add_peer() { + if [ "${PROTO}" == "UDP" ]; then + if [ ${1} -eq 0 ]; then + ip netns exec peer0 ${OVPN_CLI} new_multi_peer tun0 1 ${UDP_PEERS_FILE} + + for p in $(seq 1 ${NUM_PEERS}); do + ip netns exec peer0 ${OVPN_CLI} new_key tun0 ${p} 1 0 ${ALG} 0 \ + data64.key + done + else + ip netns exec peer${1} ${OVPN_CLI} new_peer tun${1} ${1} 1 10.10.${1}.1 1 + ip netns exec peer${1} ${OVPN_CLI} new_key tun${1} ${1} 1 0 ${ALG} 1 \ + data64.key + fi + else + if [ ${1} -eq 0 ]; then + (ip netns exec peer0 ${OVPN_CLI} listen tun0 1 ${TCP_PEERS_FILE} && { + for p in $(seq 1 ${NUM_PEERS}); do + ip netns exec peer0 ${OVPN_CLI} new_key tun0 ${p} 1 0 \ + ${ALG} 0 data64.key + done + }) & + sleep 5 + else + ip netns exec peer${1} ${OVPN_CLI} connect tun${1} ${1} 10.10.${1}.1 1 \ + data64.key + fi + fi +} + +cleanup() { + for p in $(seq 1 10); do + ip -n peer0 link del veth${p} 2>/dev/null || true + done + for p in $(seq 0 10); do + ip netns exec peer${p} ${OVPN_CLI} del_iface tun${p} 2>/dev/null || true + ip netns del peer${p} 2>/dev/null || true + done +} + +if [ "${PROTO}" == "UDP" ]; then + NUM_PEERS=${NUM_PEERS:-$(wc -l ${UDP_PEERS_FILE} | awk '{print $1}')} +else + NUM_PEERS=${NUM_PEERS:-$(wc -l ${TCP_PEERS_FILE} | awk '{print $1}')} +fi + +cleanup + +modprobe -q ovpn || true + +for p in $(seq 0 ${NUM_PEERS}); do + create_ns ${p} +done + +for p in $(seq 0 ${NUM_PEERS}); do + setup_ns ${p} 5.5.5.$((${p} + 1))/24 +done + +for p in $(seq 0 ${NUM_PEERS}); do + add_peer ${p} +done + +for p in $(seq 1 ${NUM_PEERS}); do + ip netns exec peer0 ${OVPN_CLI} set_peer tun0 ${p} 60 120 + ip netns exec peer${p} ${OVPN_CLI} set_peer tun${p} ${p} 60 120 +done + +for p in $(seq 1 ${NUM_PEERS}); do + ip netns exec peer0 ping -qfc 1000 -w 5 5.5.5.$((${p} + 1)) +done + +if [ "$FLOAT" == "1" ]; then + # make clients float.. + for p in $(seq 1 ${NUM_PEERS}); do + ip -n peer${p} addr del 10.10.${p}.2/24 dev veth${p} + ip -n peer${p} addr add 10.10.${p}.3/24 dev veth${p} + done + for p in $(seq 1 ${NUM_PEERS}); do + ip netns exec peer${p} ping -qfc 1000 -w 5 5.5.5.1 + done +fi + +ip netns exec peer0 iperf3 -1 -s & +sleep 1 +ip netns exec peer1 iperf3 -Z -t 3 -c 5.5.5.1 + +echo "Adding secondary key and then swap:" +for p in $(seq 1 ${NUM_PEERS}); do + ip netns exec peer0 ${OVPN_CLI} new_key tun0 ${p} 2 1 ${ALG} 0 data64.key + ip netns exec peer${p} ${OVPN_CLI} new_key tun${p} ${p} 2 1 ${ALG} 1 data64.key + ip netns exec peer${p} ${OVPN_CLI} swap_keys tun${p} ${p} +done + +sleep 1 +echo "Querying all peers:" +ip netns exec peer0 ${OVPN_CLI} get_peer tun0 +ip netns exec peer1 ${OVPN_CLI} get_peer tun1 + +echo "Querying peer 1:" +ip netns exec peer0 ${OVPN_CLI} get_peer tun0 1 + +echo "Querying non-existent peer 10:" +ip netns exec peer0 ${OVPN_CLI} get_peer tun0 10 || true + +echo "Deleting peer 1:" +ip netns exec peer0 ${OVPN_CLI} del_peer tun0 1 +ip netns exec peer1 ${OVPN_CLI} del_peer tun1 1 + +echo "Querying keys:" +for p in $(seq 2 ${NUM_PEERS}); do + ip netns exec peer${p} ${OVPN_CLI} get_key tun${p} ${p} 1 + ip netns exec peer${p} ${OVPN_CLI} get_key tun${p} ${p} 2 +done + +echo "Deleting keys:" +for p in $(seq 2 ${NUM_PEERS}); do + ip netns exec peer${p} ${OVPN_CLI} del_key tun${p} ${p} 1 + ip netns exec peer${p} ${OVPN_CLI} del_key tun${p} ${p} 2 +done + +echo "Setting timeout to 10s MP:" +# bring ifaces down to prevent traffic being sent +for p in $(seq 0 ${NUM_PEERS}); do + ip -n peer${p} link set tun${p} down +done +# set short timeout +for p in $(seq 2 ${NUM_PEERS}); do + ip netns exec peer0 ${OVPN_CLI} set_peer tun0 ${p} 10 10 || true + ip netns exec peer${p} ${OVPN_CLI} set_peer tun${p} ${p} 0 0 +done +# wait for peers to timeout +sleep 15 + +echo "Setting timeout to 10s P2P:" +for p in $(seq 2 ${NUM_PEERS}); do + ip netns exec peer${p} ${OVPN_CLI} set_peer tun${p} ${p} 10 10 +done +sleep 15 + +cleanup + +modprobe -r ovpn || true diff --git a/tools/testing/selftests/net/ovpn/udp_peers.txt b/tools/testing/selftests/net/ovpn/udp_peers.txt new file mode 100644 index 000000000000..32f14bd9347a --- /dev/null +++ b/tools/testing/selftests/net/ovpn/udp_peers.txt @@ -0,0 +1,5 @@ +1 10.10.1.2 1 5.5.5.2 +2 10.10.2.2 1 5.5.5.3 +3 10.10.3.2 1 5.5.5.4 +4 10.10.4.2 1 5.5.5.5 +5 10.10.5.2 1 5.5.5.6 -- 2.47.0 From df5fd664d8ea01bad5af3ff4c6575bda0917383c Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 11 Nov 2024 09:22:05 +0100 Subject: [PATCH 11/13] perf-per-core Signed-off-by: Peter Jung --- Documentation/arch/x86/topology.rst | 4 + arch/x86/events/rapl.c | 408 ++++++++++++++++++-------- arch/x86/include/asm/processor.h | 1 + arch/x86/include/asm/topology.h | 1 + arch/x86/kernel/cpu/debugfs.c | 1 + arch/x86/kernel/cpu/topology_common.c | 1 + 6 files changed, 288 insertions(+), 128 deletions(-) diff --git a/Documentation/arch/x86/topology.rst b/Documentation/arch/x86/topology.rst index 7352ab89a55a..c12837e61bda 100644 --- a/Documentation/arch/x86/topology.rst +++ b/Documentation/arch/x86/topology.rst @@ -135,6 +135,10 @@ Thread-related topology information in the kernel: The ID of the core to which a thread belongs. It is also printed in /proc/cpuinfo "core_id." + - topology_logical_core_id(); + + The logical core ID to which a thread belongs. + System topology examples diff --git a/arch/x86/events/rapl.c b/arch/x86/events/rapl.c index a481a939862e..6b405bf46781 100644 --- a/arch/x86/events/rapl.c +++ b/arch/x86/events/rapl.c @@ -39,6 +39,10 @@ * event: rapl_energy_psys * perf code: 0x5 * + * per_core counter: consumption of a single physical core + * event: rapl_energy_per_core (power_per_core PMU) + * perf code: 0x1 + * * We manage those counters as free running (read-only). They may be * use simultaneously by other tools, such as turbostat. * @@ -70,18 +74,25 @@ MODULE_LICENSE("GPL"); /* * RAPL energy status counters */ -enum perf_rapl_events { +enum perf_rapl_pkg_events { PERF_RAPL_PP0 = 0, /* all cores */ PERF_RAPL_PKG, /* entire package */ PERF_RAPL_RAM, /* DRAM */ PERF_RAPL_PP1, /* gpu */ PERF_RAPL_PSYS, /* psys */ - PERF_RAPL_MAX, - NR_RAPL_DOMAINS = PERF_RAPL_MAX, + PERF_RAPL_PKG_EVENTS_MAX, + NR_RAPL_PKG_DOMAINS = PERF_RAPL_PKG_EVENTS_MAX, +}; + +enum perf_rapl_core_events { + PERF_RAPL_PER_CORE = 0, /* per-core */ + + PERF_RAPL_CORE_EVENTS_MAX, + NR_RAPL_CORE_DOMAINS = PERF_RAPL_CORE_EVENTS_MAX, }; -static const char *const rapl_domain_names[NR_RAPL_DOMAINS] __initconst = { +static const char *const rapl_pkg_domain_names[NR_RAPL_PKG_DOMAINS] __initconst = { "pp0-core", "package", "dram", @@ -89,6 +100,10 @@ static const char *const rapl_domain_names[NR_RAPL_DOMAINS] __initconst = { "psys", }; +static const char *const rapl_core_domain_names[NR_RAPL_CORE_DOMAINS] __initconst = { + "per-core", +}; + /* * event code: LSB 8 bits, passed in attr->config * any other bit is reserved @@ -128,8 +143,9 @@ struct rapl_pmu { struct rapl_pmus { struct pmu pmu; + cpumask_t cpumask; unsigned int nr_rapl_pmu; - struct rapl_pmu *pmus[] __counted_by(nr_rapl_pmu); + struct rapl_pmu *rapl_pmu[] __counted_by(nr_rapl_pmu); }; enum rapl_unit_quirk { @@ -139,19 +155,22 @@ enum rapl_unit_quirk { }; struct rapl_model { - struct perf_msr *rapl_msrs; - unsigned long events; + struct perf_msr *rapl_pkg_msrs; + struct perf_msr *rapl_core_msrs; + unsigned long pkg_events; + unsigned long core_events; unsigned int msr_power_unit; enum rapl_unit_quirk unit_quirk; }; /* 1/2^hw_unit Joule */ -static int rapl_hw_unit[NR_RAPL_DOMAINS] __read_mostly; -static struct rapl_pmus *rapl_pmus; -static cpumask_t rapl_cpu_mask; -static unsigned int rapl_cntr_mask; +static int rapl_hw_unit[NR_RAPL_PKG_DOMAINS] __read_mostly; +static struct rapl_pmus *rapl_pmus_pkg; +static struct rapl_pmus *rapl_pmus_core; +static unsigned int rapl_pkg_cntr_mask; +static unsigned int rapl_core_cntr_mask; static u64 rapl_timer_ms; -static struct perf_msr *rapl_msrs; +static struct rapl_model *rapl_model; /* * Helper functions to get the correct topology macros according to the @@ -177,7 +196,8 @@ static inline struct rapl_pmu *cpu_to_rapl_pmu(unsigned int cpu) * The unsigned check also catches the '-1' return value for non * existent mappings in the topology map. */ - return rapl_pmu_idx < rapl_pmus->nr_rapl_pmu ? rapl_pmus->pmus[rapl_pmu_idx] : NULL; + return rapl_pmu_idx < rapl_pmus_pkg->nr_rapl_pmu ? + rapl_pmus_pkg->rapl_pmu[rapl_pmu_idx] : NULL; } static inline u64 rapl_read_counter(struct perf_event *event) @@ -189,7 +209,7 @@ static inline u64 rapl_read_counter(struct perf_event *event) static inline u64 rapl_scale(u64 v, int cfg) { - if (cfg > NR_RAPL_DOMAINS) { + if (cfg > NR_RAPL_PKG_DOMAINS) { pr_warn("Invalid domain %d, failed to scale data\n", cfg); return v; } @@ -241,34 +261,34 @@ static void rapl_start_hrtimer(struct rapl_pmu *pmu) static enum hrtimer_restart rapl_hrtimer_handle(struct hrtimer *hrtimer) { - struct rapl_pmu *pmu = container_of(hrtimer, struct rapl_pmu, hrtimer); + struct rapl_pmu *rapl_pmu = container_of(hrtimer, struct rapl_pmu, hrtimer); struct perf_event *event; unsigned long flags; - if (!pmu->n_active) + if (!rapl_pmu->n_active) return HRTIMER_NORESTART; - raw_spin_lock_irqsave(&pmu->lock, flags); + raw_spin_lock_irqsave(&rapl_pmu->lock, flags); - list_for_each_entry(event, &pmu->active_list, active_entry) + list_for_each_entry(event, &rapl_pmu->active_list, active_entry) rapl_event_update(event); - raw_spin_unlock_irqrestore(&pmu->lock, flags); + raw_spin_unlock_irqrestore(&rapl_pmu->lock, flags); - hrtimer_forward_now(hrtimer, pmu->timer_interval); + hrtimer_forward_now(hrtimer, rapl_pmu->timer_interval); return HRTIMER_RESTART; } -static void rapl_hrtimer_init(struct rapl_pmu *pmu) +static void rapl_hrtimer_init(struct rapl_pmu *rapl_pmu) { - struct hrtimer *hr = &pmu->hrtimer; + struct hrtimer *hr = &rapl_pmu->hrtimer; hrtimer_init(hr, CLOCK_MONOTONIC, HRTIMER_MODE_REL); hr->function = rapl_hrtimer_handle; } -static void __rapl_pmu_event_start(struct rapl_pmu *pmu, +static void __rapl_pmu_event_start(struct rapl_pmu *rapl_pmu, struct perf_event *event) { if (WARN_ON_ONCE(!(event->hw.state & PERF_HES_STOPPED))) @@ -276,39 +296,39 @@ static void __rapl_pmu_event_start(struct rapl_pmu *pmu, event->hw.state = 0; - list_add_tail(&event->active_entry, &pmu->active_list); + list_add_tail(&event->active_entry, &rapl_pmu->active_list); local64_set(&event->hw.prev_count, rapl_read_counter(event)); - pmu->n_active++; - if (pmu->n_active == 1) - rapl_start_hrtimer(pmu); + rapl_pmu->n_active++; + if (rapl_pmu->n_active == 1) + rapl_start_hrtimer(rapl_pmu); } static void rapl_pmu_event_start(struct perf_event *event, int mode) { - struct rapl_pmu *pmu = event->pmu_private; + struct rapl_pmu *rapl_pmu = event->pmu_private; unsigned long flags; - raw_spin_lock_irqsave(&pmu->lock, flags); - __rapl_pmu_event_start(pmu, event); - raw_spin_unlock_irqrestore(&pmu->lock, flags); + raw_spin_lock_irqsave(&rapl_pmu->lock, flags); + __rapl_pmu_event_start(rapl_pmu, event); + raw_spin_unlock_irqrestore(&rapl_pmu->lock, flags); } static void rapl_pmu_event_stop(struct perf_event *event, int mode) { - struct rapl_pmu *pmu = event->pmu_private; + struct rapl_pmu *rapl_pmu = event->pmu_private; struct hw_perf_event *hwc = &event->hw; unsigned long flags; - raw_spin_lock_irqsave(&pmu->lock, flags); + raw_spin_lock_irqsave(&rapl_pmu->lock, flags); /* mark event as deactivated and stopped */ if (!(hwc->state & PERF_HES_STOPPED)) { - WARN_ON_ONCE(pmu->n_active <= 0); - pmu->n_active--; - if (pmu->n_active == 0) - hrtimer_cancel(&pmu->hrtimer); + WARN_ON_ONCE(rapl_pmu->n_active <= 0); + rapl_pmu->n_active--; + if (rapl_pmu->n_active == 0) + hrtimer_cancel(&rapl_pmu->hrtimer); list_del(&event->active_entry); @@ -326,23 +346,23 @@ static void rapl_pmu_event_stop(struct perf_event *event, int mode) hwc->state |= PERF_HES_UPTODATE; } - raw_spin_unlock_irqrestore(&pmu->lock, flags); + raw_spin_unlock_irqrestore(&rapl_pmu->lock, flags); } static int rapl_pmu_event_add(struct perf_event *event, int mode) { - struct rapl_pmu *pmu = event->pmu_private; + struct rapl_pmu *rapl_pmu = event->pmu_private; struct hw_perf_event *hwc = &event->hw; unsigned long flags; - raw_spin_lock_irqsave(&pmu->lock, flags); + raw_spin_lock_irqsave(&rapl_pmu->lock, flags); hwc->state = PERF_HES_UPTODATE | PERF_HES_STOPPED; if (mode & PERF_EF_START) - __rapl_pmu_event_start(pmu, event); + __rapl_pmu_event_start(rapl_pmu, event); - raw_spin_unlock_irqrestore(&pmu->lock, flags); + raw_spin_unlock_irqrestore(&rapl_pmu->lock, flags); return 0; } @@ -356,10 +376,14 @@ static int rapl_pmu_event_init(struct perf_event *event) { u64 cfg = event->attr.config & RAPL_EVENT_MASK; int bit, ret = 0; - struct rapl_pmu *pmu; + struct rapl_pmu *rapl_pmu; + struct rapl_pmus *curr_rapl_pmus; /* only look at RAPL events */ - if (event->attr.type != rapl_pmus->pmu.type) + if (event->attr.type == rapl_pmus_pkg->pmu.type || + (rapl_pmus_core && event->attr.type == rapl_pmus_core->pmu.type)) + curr_rapl_pmus = container_of(event->pmu, struct rapl_pmus, pmu); + else return -ENOENT; /* check only supported bits are set */ @@ -369,16 +393,18 @@ static int rapl_pmu_event_init(struct perf_event *event) if (event->cpu < 0) return -EINVAL; - event->event_caps |= PERF_EV_CAP_READ_ACTIVE_PKG; + if (curr_rapl_pmus == rapl_pmus_pkg) + event->event_caps |= PERF_EV_CAP_READ_ACTIVE_PKG; - if (!cfg || cfg >= NR_RAPL_DOMAINS + 1) + if (!cfg || cfg >= NR_RAPL_PKG_DOMAINS + 1) return -EINVAL; - cfg = array_index_nospec((long)cfg, NR_RAPL_DOMAINS + 1); + cfg = array_index_nospec((long)cfg, NR_RAPL_PKG_DOMAINS + 1); bit = cfg - 1; /* check event supported */ - if (!(rapl_cntr_mask & (1 << bit))) + if (!(rapl_pkg_cntr_mask & (1 << bit)) && + !(rapl_core_cntr_mask & (1 << bit))) return -EINVAL; /* unsupported modes and filters */ @@ -386,12 +412,18 @@ static int rapl_pmu_event_init(struct perf_event *event) return -EINVAL; /* must be done before validate_group */ - pmu = cpu_to_rapl_pmu(event->cpu); - if (!pmu) + if (curr_rapl_pmus == rapl_pmus_core) { + rapl_pmu = curr_rapl_pmus->rapl_pmu[topology_logical_core_id(event->cpu)]; + event->hw.event_base = rapl_model->rapl_core_msrs[bit].msr; + } else { + rapl_pmu = curr_rapl_pmus->rapl_pmu[get_rapl_pmu_idx(event->cpu)]; + event->hw.event_base = rapl_model->rapl_pkg_msrs[bit].msr; + } + + if (!rapl_pmu) return -EINVAL; - event->cpu = pmu->cpu; - event->pmu_private = pmu; - event->hw.event_base = rapl_msrs[bit].msr; + event->cpu = rapl_pmu->cpu; + event->pmu_private = rapl_pmu; event->hw.config = cfg; event->hw.idx = bit; @@ -406,7 +438,7 @@ static void rapl_pmu_event_read(struct perf_event *event) static ssize_t rapl_get_attr_cpumask(struct device *dev, struct device_attribute *attr, char *buf) { - return cpumap_print_to_pagebuf(true, buf, &rapl_cpu_mask); + return cpumap_print_to_pagebuf(true, buf, &rapl_pmus_pkg->cpumask); } static DEVICE_ATTR(cpumask, S_IRUGO, rapl_get_attr_cpumask, NULL); @@ -420,17 +452,38 @@ static struct attribute_group rapl_pmu_attr_group = { .attrs = rapl_pmu_attrs, }; +static ssize_t rapl_get_attr_per_core_cpumask(struct device *dev, + struct device_attribute *attr, char *buf) +{ + return cpumap_print_to_pagebuf(true, buf, &rapl_pmus_core->cpumask); +} + +static struct device_attribute dev_attr_per_core_cpumask = __ATTR(cpumask, 0444, + rapl_get_attr_per_core_cpumask, + NULL); + +static struct attribute *rapl_pmu_per_core_attrs[] = { + &dev_attr_per_core_cpumask.attr, + NULL, +}; + +static struct attribute_group rapl_pmu_per_core_attr_group = { + .attrs = rapl_pmu_per_core_attrs, +}; + RAPL_EVENT_ATTR_STR(energy-cores, rapl_cores, "event=0x01"); RAPL_EVENT_ATTR_STR(energy-pkg , rapl_pkg, "event=0x02"); RAPL_EVENT_ATTR_STR(energy-ram , rapl_ram, "event=0x03"); RAPL_EVENT_ATTR_STR(energy-gpu , rapl_gpu, "event=0x04"); RAPL_EVENT_ATTR_STR(energy-psys, rapl_psys, "event=0x05"); +RAPL_EVENT_ATTR_STR(energy-per-core, rapl_per_core, "event=0x01"); RAPL_EVENT_ATTR_STR(energy-cores.unit, rapl_cores_unit, "Joules"); RAPL_EVENT_ATTR_STR(energy-pkg.unit , rapl_pkg_unit, "Joules"); RAPL_EVENT_ATTR_STR(energy-ram.unit , rapl_ram_unit, "Joules"); RAPL_EVENT_ATTR_STR(energy-gpu.unit , rapl_gpu_unit, "Joules"); RAPL_EVENT_ATTR_STR(energy-psys.unit, rapl_psys_unit, "Joules"); +RAPL_EVENT_ATTR_STR(energy-per-core.unit, rapl_per_core_unit, "Joules"); /* * we compute in 0.23 nJ increments regardless of MSR @@ -440,6 +493,7 @@ RAPL_EVENT_ATTR_STR(energy-pkg.scale, rapl_pkg_scale, "2.3283064365386962890 RAPL_EVENT_ATTR_STR(energy-ram.scale, rapl_ram_scale, "2.3283064365386962890625e-10"); RAPL_EVENT_ATTR_STR(energy-gpu.scale, rapl_gpu_scale, "2.3283064365386962890625e-10"); RAPL_EVENT_ATTR_STR(energy-psys.scale, rapl_psys_scale, "2.3283064365386962890625e-10"); +RAPL_EVENT_ATTR_STR(energy-per-core.scale, rapl_per_core_scale, "2.3283064365386962890625e-10"); /* * There are no default events, but we need to create @@ -473,6 +527,13 @@ static const struct attribute_group *rapl_attr_groups[] = { NULL, }; +static const struct attribute_group *rapl_per_core_attr_groups[] = { + &rapl_pmu_per_core_attr_group, + &rapl_pmu_format_group, + &rapl_pmu_events_group, + NULL, +}; + static struct attribute *rapl_events_cores[] = { EVENT_PTR(rapl_cores), EVENT_PTR(rapl_cores_unit), @@ -533,6 +594,18 @@ static struct attribute_group rapl_events_psys_group = { .attrs = rapl_events_psys, }; +static struct attribute *rapl_events_per_core[] = { + EVENT_PTR(rapl_per_core), + EVENT_PTR(rapl_per_core_unit), + EVENT_PTR(rapl_per_core_scale), + NULL, +}; + +static struct attribute_group rapl_events_per_core_group = { + .name = "events", + .attrs = rapl_events_per_core, +}; + static bool test_msr(int idx, void *data) { return test_bit(idx, (unsigned long *) data); @@ -558,11 +631,11 @@ static struct perf_msr intel_rapl_spr_msrs[] = { }; /* - * Force to PERF_RAPL_MAX size due to: - * - perf_msr_probe(PERF_RAPL_MAX) + * Force to PERF_RAPL_PKG_EVENTS_MAX size due to: + * - perf_msr_probe(PERF_RAPL_PKG_EVENTS_MAX) * - want to use same event codes across both architectures */ -static struct perf_msr amd_rapl_msrs[] = { +static struct perf_msr amd_rapl_pkg_msrs[] = { [PERF_RAPL_PP0] = { 0, &rapl_events_cores_group, NULL, false, 0 }, [PERF_RAPL_PKG] = { MSR_AMD_PKG_ENERGY_STATUS, &rapl_events_pkg_group, test_msr, false, RAPL_MSR_MASK }, [PERF_RAPL_RAM] = { 0, &rapl_events_ram_group, NULL, false, 0 }, @@ -570,77 +643,104 @@ static struct perf_msr amd_rapl_msrs[] = { [PERF_RAPL_PSYS] = { 0, &rapl_events_psys_group, NULL, false, 0 }, }; -static int rapl_cpu_offline(unsigned int cpu) +static struct perf_msr amd_rapl_core_msrs[] = { + [PERF_RAPL_PER_CORE] = { MSR_AMD_CORE_ENERGY_STATUS, &rapl_events_per_core_group, + test_msr, false, RAPL_MSR_MASK }, +}; + +static int __rapl_cpu_offline(struct rapl_pmus *rapl_pmus, unsigned int rapl_pmu_idx, + const struct cpumask *event_cpumask, unsigned int cpu) { - struct rapl_pmu *pmu = cpu_to_rapl_pmu(cpu); + struct rapl_pmu *rapl_pmu = rapl_pmus->rapl_pmu[rapl_pmu_idx]; int target; /* Check if exiting cpu is used for collecting rapl events */ - if (!cpumask_test_and_clear_cpu(cpu, &rapl_cpu_mask)) + if (!cpumask_test_and_clear_cpu(cpu, &rapl_pmus->cpumask)) return 0; - pmu->cpu = -1; + rapl_pmu->cpu = -1; /* Find a new cpu to collect rapl events */ - target = cpumask_any_but(get_rapl_pmu_cpumask(cpu), cpu); + target = cpumask_any_but(event_cpumask, cpu); /* Migrate rapl events to the new target */ if (target < nr_cpu_ids) { - cpumask_set_cpu(target, &rapl_cpu_mask); - pmu->cpu = target; - perf_pmu_migrate_context(pmu->pmu, cpu, target); + cpumask_set_cpu(target, &rapl_pmus->cpumask); + rapl_pmu->cpu = target; + perf_pmu_migrate_context(rapl_pmu->pmu, cpu, target); } return 0; } -static int rapl_cpu_online(unsigned int cpu) +static int rapl_cpu_offline(unsigned int cpu) { - s32 rapl_pmu_idx = get_rapl_pmu_idx(cpu); - if (rapl_pmu_idx < 0) { - pr_err("topology_logical_(package/die)_id() returned a negative value"); - return -EINVAL; - } - struct rapl_pmu *pmu = cpu_to_rapl_pmu(cpu); + int ret = __rapl_cpu_offline(rapl_pmus_pkg, get_rapl_pmu_idx(cpu), + get_rapl_pmu_cpumask(cpu), cpu); + + if (ret == 0 && rapl_model->core_events) + ret = __rapl_cpu_offline(rapl_pmus_core, topology_logical_core_id(cpu), + topology_sibling_cpumask(cpu), cpu); + + return ret; +} + +static int __rapl_cpu_online(struct rapl_pmus *rapl_pmus, unsigned int rapl_pmu_idx, + const struct cpumask *event_cpumask, unsigned int cpu) +{ + struct rapl_pmu *rapl_pmu = rapl_pmus->rapl_pmu[rapl_pmu_idx]; int target; - if (!pmu) { - pmu = kzalloc_node(sizeof(*pmu), GFP_KERNEL, cpu_to_node(cpu)); - if (!pmu) + if (!rapl_pmu) { + rapl_pmu = kzalloc_node(sizeof(*rapl_pmu), GFP_KERNEL, cpu_to_node(cpu)); + if (!rapl_pmu) return -ENOMEM; - raw_spin_lock_init(&pmu->lock); - INIT_LIST_HEAD(&pmu->active_list); - pmu->pmu = &rapl_pmus->pmu; - pmu->timer_interval = ms_to_ktime(rapl_timer_ms); - rapl_hrtimer_init(pmu); + raw_spin_lock_init(&rapl_pmu->lock); + INIT_LIST_HEAD(&rapl_pmu->active_list); + rapl_pmu->pmu = &rapl_pmus->pmu; + rapl_pmu->timer_interval = ms_to_ktime(rapl_timer_ms); + rapl_hrtimer_init(rapl_pmu); - rapl_pmus->pmus[rapl_pmu_idx] = pmu; + rapl_pmus->rapl_pmu[rapl_pmu_idx] = rapl_pmu; } /* * Check if there is an online cpu in the package which collects rapl * events already. */ - target = cpumask_any_and(&rapl_cpu_mask, get_rapl_pmu_cpumask(cpu)); + target = cpumask_any_and(&rapl_pmus->cpumask, event_cpumask); if (target < nr_cpu_ids) return 0; - cpumask_set_cpu(cpu, &rapl_cpu_mask); - pmu->cpu = cpu; + cpumask_set_cpu(cpu, &rapl_pmus->cpumask); + rapl_pmu->cpu = cpu; return 0; } -static int rapl_check_hw_unit(struct rapl_model *rm) +static int rapl_cpu_online(unsigned int cpu) +{ + int ret = __rapl_cpu_online(rapl_pmus_pkg, get_rapl_pmu_idx(cpu), + get_rapl_pmu_cpumask(cpu), cpu); + + if (ret == 0 && rapl_model->core_events) + ret = __rapl_cpu_online(rapl_pmus_core, topology_logical_core_id(cpu), + topology_sibling_cpumask(cpu), cpu); + + return ret; +} + + +static int rapl_check_hw_unit(void) { u64 msr_rapl_power_unit_bits; int i; /* protect rdmsrl() to handle virtualization */ - if (rdmsrl_safe(rm->msr_power_unit, &msr_rapl_power_unit_bits)) + if (rdmsrl_safe(rapl_model->msr_power_unit, &msr_rapl_power_unit_bits)) return -1; - for (i = 0; i < NR_RAPL_DOMAINS; i++) + for (i = 0; i < NR_RAPL_PKG_DOMAINS; i++) rapl_hw_unit[i] = (msr_rapl_power_unit_bits >> 8) & 0x1FULL; - switch (rm->unit_quirk) { + switch (rapl_model->unit_quirk) { /* * DRAM domain on HSW server and KNL has fixed energy unit which can be * different than the unit from power unit MSR. See @@ -679,22 +779,29 @@ static void __init rapl_advertise(void) int i; pr_info("API unit is 2^-32 Joules, %d fixed counters, %llu ms ovfl timer\n", - hweight32(rapl_cntr_mask), rapl_timer_ms); + hweight32(rapl_pkg_cntr_mask) + hweight32(rapl_core_cntr_mask), rapl_timer_ms); + + for (i = 0; i < NR_RAPL_PKG_DOMAINS; i++) { + if (rapl_pkg_cntr_mask & (1 << i)) { + pr_info("hw unit of domain %s 2^-%d Joules\n", + rapl_pkg_domain_names[i], rapl_hw_unit[i]); + } + } - for (i = 0; i < NR_RAPL_DOMAINS; i++) { - if (rapl_cntr_mask & (1 << i)) { + for (i = 0; i < NR_RAPL_CORE_DOMAINS; i++) { + if (rapl_core_cntr_mask & (1 << i)) { pr_info("hw unit of domain %s 2^-%d Joules\n", - rapl_domain_names[i], rapl_hw_unit[i]); + rapl_core_domain_names[i], rapl_hw_unit[i]); } } } -static void cleanup_rapl_pmus(void) +static void cleanup_rapl_pmus(struct rapl_pmus *rapl_pmus) { int i; for (i = 0; i < rapl_pmus->nr_rapl_pmu; i++) - kfree(rapl_pmus->pmus[i]); + kfree(rapl_pmus->rapl_pmu[i]); kfree(rapl_pmus); } @@ -707,14 +814,17 @@ static const struct attribute_group *rapl_attr_update[] = { NULL, }; -static int __init init_rapl_pmus(void) -{ - int nr_rapl_pmu = topology_max_packages(); +static const struct attribute_group *rapl_per_core_attr_update[] = { + &rapl_events_per_core_group, +}; - if (!rapl_pmu_is_pkg_scope()) - nr_rapl_pmu *= topology_max_dies_per_package(); +static int __init init_rapl_pmus(struct rapl_pmus **rapl_pmus_ptr, int nr_rapl_pmu, + const struct attribute_group **rapl_attr_groups, + const struct attribute_group **rapl_attr_update) +{ + struct rapl_pmus *rapl_pmus; - rapl_pmus = kzalloc(struct_size(rapl_pmus, pmus, nr_rapl_pmu), GFP_KERNEL); + rapl_pmus = kzalloc(struct_size(rapl_pmus, rapl_pmu, nr_rapl_pmu), GFP_KERNEL); if (!rapl_pmus) return -ENOMEM; @@ -730,75 +840,80 @@ static int __init init_rapl_pmus(void) rapl_pmus->pmu.read = rapl_pmu_event_read; rapl_pmus->pmu.module = THIS_MODULE; rapl_pmus->pmu.capabilities = PERF_PMU_CAP_NO_EXCLUDE; + + *rapl_pmus_ptr = rapl_pmus; + return 0; } static struct rapl_model model_snb = { - .events = BIT(PERF_RAPL_PP0) | + .pkg_events = BIT(PERF_RAPL_PP0) | BIT(PERF_RAPL_PKG) | BIT(PERF_RAPL_PP1), .msr_power_unit = MSR_RAPL_POWER_UNIT, - .rapl_msrs = intel_rapl_msrs, + .rapl_pkg_msrs = intel_rapl_msrs, }; static struct rapl_model model_snbep = { - .events = BIT(PERF_RAPL_PP0) | + .pkg_events = BIT(PERF_RAPL_PP0) | BIT(PERF_RAPL_PKG) | BIT(PERF_RAPL_RAM), .msr_power_unit = MSR_RAPL_POWER_UNIT, - .rapl_msrs = intel_rapl_msrs, + .rapl_pkg_msrs = intel_rapl_msrs, }; static struct rapl_model model_hsw = { - .events = BIT(PERF_RAPL_PP0) | + .pkg_events = BIT(PERF_RAPL_PP0) | BIT(PERF_RAPL_PKG) | BIT(PERF_RAPL_RAM) | BIT(PERF_RAPL_PP1), .msr_power_unit = MSR_RAPL_POWER_UNIT, - .rapl_msrs = intel_rapl_msrs, + .rapl_pkg_msrs = intel_rapl_msrs, }; static struct rapl_model model_hsx = { - .events = BIT(PERF_RAPL_PP0) | + .pkg_events = BIT(PERF_RAPL_PP0) | BIT(PERF_RAPL_PKG) | BIT(PERF_RAPL_RAM), .unit_quirk = RAPL_UNIT_QUIRK_INTEL_HSW, .msr_power_unit = MSR_RAPL_POWER_UNIT, - .rapl_msrs = intel_rapl_msrs, + .rapl_pkg_msrs = intel_rapl_msrs, }; static struct rapl_model model_knl = { - .events = BIT(PERF_RAPL_PKG) | + .pkg_events = BIT(PERF_RAPL_PKG) | BIT(PERF_RAPL_RAM), .unit_quirk = RAPL_UNIT_QUIRK_INTEL_HSW, .msr_power_unit = MSR_RAPL_POWER_UNIT, - .rapl_msrs = intel_rapl_msrs, + .rapl_pkg_msrs = intel_rapl_msrs, }; static struct rapl_model model_skl = { - .events = BIT(PERF_RAPL_PP0) | + .pkg_events = BIT(PERF_RAPL_PP0) | BIT(PERF_RAPL_PKG) | BIT(PERF_RAPL_RAM) | BIT(PERF_RAPL_PP1) | BIT(PERF_RAPL_PSYS), .msr_power_unit = MSR_RAPL_POWER_UNIT, - .rapl_msrs = intel_rapl_msrs, + .rapl_pkg_msrs = intel_rapl_msrs, }; static struct rapl_model model_spr = { - .events = BIT(PERF_RAPL_PP0) | + .pkg_events = BIT(PERF_RAPL_PP0) | BIT(PERF_RAPL_PKG) | BIT(PERF_RAPL_RAM) | BIT(PERF_RAPL_PSYS), .unit_quirk = RAPL_UNIT_QUIRK_INTEL_SPR, .msr_power_unit = MSR_RAPL_POWER_UNIT, - .rapl_msrs = intel_rapl_spr_msrs, + .rapl_pkg_msrs = intel_rapl_spr_msrs, }; static struct rapl_model model_amd_hygon = { - .events = BIT(PERF_RAPL_PKG), + .pkg_events = BIT(PERF_RAPL_PKG), + .core_events = BIT(PERF_RAPL_PER_CORE), .msr_power_unit = MSR_AMD_RAPL_POWER_UNIT, - .rapl_msrs = amd_rapl_msrs, + .rapl_pkg_msrs = amd_rapl_pkg_msrs, + .rapl_core_msrs = amd_rapl_core_msrs, }; static const struct x86_cpu_id rapl_model_match[] __initconst = { @@ -854,28 +969,47 @@ MODULE_DEVICE_TABLE(x86cpu, rapl_model_match); static int __init rapl_pmu_init(void) { const struct x86_cpu_id *id; - struct rapl_model *rm; int ret; + int nr_rapl_pmu = topology_max_packages() * topology_max_dies_per_package(); + int nr_cores = topology_max_packages() * topology_num_cores_per_package(); + + if (rapl_pmu_is_pkg_scope()) + nr_rapl_pmu = topology_max_packages(); id = x86_match_cpu(rapl_model_match); if (!id) return -ENODEV; - rm = (struct rapl_model *) id->driver_data; + rapl_model = (struct rapl_model *) id->driver_data; - rapl_msrs = rm->rapl_msrs; + rapl_pkg_cntr_mask = perf_msr_probe(rapl_model->rapl_pkg_msrs, PERF_RAPL_PKG_EVENTS_MAX, + false, (void *) &rapl_model->pkg_events); - rapl_cntr_mask = perf_msr_probe(rapl_msrs, PERF_RAPL_MAX, - false, (void *) &rm->events); - - ret = rapl_check_hw_unit(rm); + ret = rapl_check_hw_unit(); if (ret) return ret; - ret = init_rapl_pmus(); + ret = init_rapl_pmus(&rapl_pmus_pkg, nr_rapl_pmu, rapl_attr_groups, rapl_attr_update); if (ret) return ret; + if (rapl_model->core_events) { + rapl_core_cntr_mask = perf_msr_probe(rapl_model->rapl_core_msrs, + PERF_RAPL_CORE_EVENTS_MAX, false, + (void *) &rapl_model->core_events); + + ret = init_rapl_pmus(&rapl_pmus_core, nr_cores, + rapl_per_core_attr_groups, rapl_per_core_attr_update); + if (ret) { + /* + * If initialization of per_core PMU fails, reset per_core + * flag, and continue with power PMU initialization. + */ + pr_warn("Per-core PMU initialization failed (%d)\n", ret); + rapl_model->core_events = 0UL; + } + } + /* * Install callbacks. Core will call them for each online cpu. */ @@ -885,10 +1019,24 @@ static int __init rapl_pmu_init(void) if (ret) goto out; - ret = perf_pmu_register(&rapl_pmus->pmu, "power", -1); + ret = perf_pmu_register(&rapl_pmus_pkg->pmu, "power", -1); if (ret) goto out1; + if (rapl_model->core_events) { + ret = perf_pmu_register(&rapl_pmus_core->pmu, "power_per_core", -1); + if (ret) { + /* + * If registration of per_core PMU fails, cleanup per_core PMU + * variables, reset the per_core flag and keep the + * power PMU untouched. + */ + pr_warn("Per-core PMU registration failed (%d)\n", ret); + cleanup_rapl_pmus(rapl_pmus_core); + rapl_model->core_events = 0UL; + } + } + rapl_advertise(); return 0; @@ -896,7 +1044,7 @@ static int __init rapl_pmu_init(void) cpuhp_remove_state(CPUHP_AP_PERF_X86_RAPL_ONLINE); out: pr_warn("Initialization failed (%d), disabled\n", ret); - cleanup_rapl_pmus(); + cleanup_rapl_pmus(rapl_pmus_pkg); return ret; } module_init(rapl_pmu_init); @@ -904,7 +1052,11 @@ module_init(rapl_pmu_init); static void __exit intel_rapl_exit(void) { cpuhp_remove_state_nocalls(CPUHP_AP_PERF_X86_RAPL_ONLINE); - perf_pmu_unregister(&rapl_pmus->pmu); - cleanup_rapl_pmus(); + perf_pmu_unregister(&rapl_pmus_pkg->pmu); + cleanup_rapl_pmus(rapl_pmus_pkg); + if (rapl_model->core_events) { + perf_pmu_unregister(&rapl_pmus_core->pmu); + cleanup_rapl_pmus(rapl_pmus_core); + } } module_exit(intel_rapl_exit); diff --git a/arch/x86/include/asm/processor.h b/arch/x86/include/asm/processor.h index c0975815980c..cfd8a5591421 100644 --- a/arch/x86/include/asm/processor.h +++ b/arch/x86/include/asm/processor.h @@ -98,6 +98,7 @@ struct cpuinfo_topology { // Logical ID mappings u32 logical_pkg_id; u32 logical_die_id; + u32 logical_core_id; // AMD Node ID and Nodes per Package info u32 amd_node_id; diff --git a/arch/x86/include/asm/topology.h b/arch/x86/include/asm/topology.h index fd41103ad342..3973cb9bb2e6 100644 --- a/arch/x86/include/asm/topology.h +++ b/arch/x86/include/asm/topology.h @@ -143,6 +143,7 @@ extern const struct cpumask *cpu_clustergroup_mask(int cpu); #define topology_logical_package_id(cpu) (cpu_data(cpu).topo.logical_pkg_id) #define topology_physical_package_id(cpu) (cpu_data(cpu).topo.pkg_id) #define topology_logical_die_id(cpu) (cpu_data(cpu).topo.logical_die_id) +#define topology_logical_core_id(cpu) (cpu_data(cpu).topo.logical_core_id) #define topology_die_id(cpu) (cpu_data(cpu).topo.die_id) #define topology_core_id(cpu) (cpu_data(cpu).topo.core_id) #define topology_ppin(cpu) (cpu_data(cpu).ppin) diff --git a/arch/x86/kernel/cpu/debugfs.c b/arch/x86/kernel/cpu/debugfs.c index 10719aba6276..cacfd3f6abef 100644 --- a/arch/x86/kernel/cpu/debugfs.c +++ b/arch/x86/kernel/cpu/debugfs.c @@ -25,6 +25,7 @@ static int cpu_debug_show(struct seq_file *m, void *p) seq_printf(m, "cpu_type: %s\n", get_topology_cpu_type_name(c)); seq_printf(m, "logical_pkg_id: %u\n", c->topo.logical_pkg_id); seq_printf(m, "logical_die_id: %u\n", c->topo.logical_die_id); + seq_printf(m, "logical_core_id: %u\n", c->topo.logical_core_id); seq_printf(m, "llc_id: %u\n", c->topo.llc_id); seq_printf(m, "l2c_id: %u\n", c->topo.l2c_id); seq_printf(m, "amd_node_id: %u\n", c->topo.amd_node_id); diff --git a/arch/x86/kernel/cpu/topology_common.c b/arch/x86/kernel/cpu/topology_common.c index 8277c64f88db..b5a5e1411469 100644 --- a/arch/x86/kernel/cpu/topology_common.c +++ b/arch/x86/kernel/cpu/topology_common.c @@ -185,6 +185,7 @@ static void topo_set_ids(struct topo_scan *tscan, bool early) if (!early) { c->topo.logical_pkg_id = topology_get_logical_id(apicid, TOPO_PKG_DOMAIN); c->topo.logical_die_id = topology_get_logical_id(apicid, TOPO_DIE_DOMAIN); + c->topo.logical_core_id = topology_get_logical_id(apicid, TOPO_CORE_DOMAIN); } /* Package relative core ID */ -- 2.47.0 From 063003cdcfa118a0e75173a1c02094c2978bc532 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 11 Nov 2024 09:22:19 +0100 Subject: [PATCH 12/13] t2 Signed-off-by: Peter Jung --- .../ABI/testing/sysfs-driver-hid-appletb-kbd | 13 + Documentation/core-api/printk-formats.rst | 32 + MAINTAINERS | 6 + drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c | 3 + drivers/gpu/drm/drm_format_helper.c | 54 + drivers/gpu/drm/i915/display/intel_ddi.c | 4 + drivers/gpu/drm/i915/display/intel_fbdev.c | 6 +- drivers/gpu/drm/i915/display/intel_quirks.c | 15 + drivers/gpu/drm/i915/display/intel_quirks.h | 1 + .../gpu/drm/tests/drm_format_helper_test.c | 81 ++ drivers/gpu/drm/tiny/Kconfig | 12 + drivers/gpu/drm/tiny/Makefile | 1 + drivers/gpu/drm/tiny/appletbdrm.c | 624 +++++++++ drivers/gpu/vga/vga_switcheroo.c | 7 +- drivers/hid/Kconfig | 22 + drivers/hid/Makefile | 2 + drivers/hid/hid-appletb-bl.c | 207 +++ drivers/hid/hid-appletb-kbd.c | 432 +++++++ drivers/hid/hid-multitouch.c | 60 +- drivers/hid/hid-quirks.c | 8 +- drivers/hwmon/applesmc.c | 1138 ++++++++++++----- drivers/input/mouse/bcm5974.c | 138 ++ drivers/pci/vgaarb.c | 1 + drivers/platform/x86/apple-gmux.c | 18 + drivers/staging/Kconfig | 2 + drivers/staging/Makefile | 1 + drivers/staging/apple-bce/Kconfig | 18 + drivers/staging/apple-bce/Makefile | 28 + drivers/staging/apple-bce/apple_bce.c | 445 +++++++ drivers/staging/apple-bce/apple_bce.h | 38 + drivers/staging/apple-bce/audio/audio.c | 711 ++++++++++ drivers/staging/apple-bce/audio/audio.h | 125 ++ drivers/staging/apple-bce/audio/description.h | 42 + drivers/staging/apple-bce/audio/pcm.c | 308 +++++ drivers/staging/apple-bce/audio/pcm.h | 16 + drivers/staging/apple-bce/audio/protocol.c | 347 +++++ drivers/staging/apple-bce/audio/protocol.h | 147 +++ .../staging/apple-bce/audio/protocol_bce.c | 226 ++++ .../staging/apple-bce/audio/protocol_bce.h | 72 ++ drivers/staging/apple-bce/mailbox.c | 151 +++ drivers/staging/apple-bce/mailbox.h | 53 + drivers/staging/apple-bce/queue.c | 390 ++++++ drivers/staging/apple-bce/queue.h | 177 +++ drivers/staging/apple-bce/queue_dma.c | 220 ++++ drivers/staging/apple-bce/queue_dma.h | 50 + drivers/staging/apple-bce/vhci/command.h | 204 +++ drivers/staging/apple-bce/vhci/queue.c | 268 ++++ drivers/staging/apple-bce/vhci/queue.h | 76 ++ drivers/staging/apple-bce/vhci/transfer.c | 661 ++++++++++ drivers/staging/apple-bce/vhci/transfer.h | 73 ++ drivers/staging/apple-bce/vhci/vhci.c | 759 +++++++++++ drivers/staging/apple-bce/vhci/vhci.h | 52 + include/drm/drm_format_helper.h | 3 + lib/test_printf.c | 20 +- lib/vsprintf.c | 36 +- scripts/checkpatch.pl | 2 +- 56 files changed, 8270 insertions(+), 336 deletions(-) create mode 100644 Documentation/ABI/testing/sysfs-driver-hid-appletb-kbd create mode 100644 drivers/gpu/drm/tiny/appletbdrm.c create mode 100644 drivers/hid/hid-appletb-bl.c create mode 100644 drivers/hid/hid-appletb-kbd.c create mode 100644 drivers/staging/apple-bce/Kconfig create mode 100644 drivers/staging/apple-bce/Makefile create mode 100644 drivers/staging/apple-bce/apple_bce.c create mode 100644 drivers/staging/apple-bce/apple_bce.h create mode 100644 drivers/staging/apple-bce/audio/audio.c create mode 100644 drivers/staging/apple-bce/audio/audio.h create mode 100644 drivers/staging/apple-bce/audio/description.h create mode 100644 drivers/staging/apple-bce/audio/pcm.c create mode 100644 drivers/staging/apple-bce/audio/pcm.h create mode 100644 drivers/staging/apple-bce/audio/protocol.c create mode 100644 drivers/staging/apple-bce/audio/protocol.h create mode 100644 drivers/staging/apple-bce/audio/protocol_bce.c create mode 100644 drivers/staging/apple-bce/audio/protocol_bce.h create mode 100644 drivers/staging/apple-bce/mailbox.c create mode 100644 drivers/staging/apple-bce/mailbox.h create mode 100644 drivers/staging/apple-bce/queue.c create mode 100644 drivers/staging/apple-bce/queue.h create mode 100644 drivers/staging/apple-bce/queue_dma.c create mode 100644 drivers/staging/apple-bce/queue_dma.h create mode 100644 drivers/staging/apple-bce/vhci/command.h create mode 100644 drivers/staging/apple-bce/vhci/queue.c create mode 100644 drivers/staging/apple-bce/vhci/queue.h create mode 100644 drivers/staging/apple-bce/vhci/transfer.c create mode 100644 drivers/staging/apple-bce/vhci/transfer.h create mode 100644 drivers/staging/apple-bce/vhci/vhci.c create mode 100644 drivers/staging/apple-bce/vhci/vhci.h diff --git a/Documentation/ABI/testing/sysfs-driver-hid-appletb-kbd b/Documentation/ABI/testing/sysfs-driver-hid-appletb-kbd new file mode 100644 index 000000000000..2a19584d091e --- /dev/null +++ b/Documentation/ABI/testing/sysfs-driver-hid-appletb-kbd @@ -0,0 +1,13 @@ +What: /sys/bus/hid/drivers/hid-appletb-kbd//mode +Date: September, 2023 +KernelVersion: 6.5 +Contact: linux-input@vger.kernel.org +Description: + The set of keys displayed on the Touch Bar. + Valid values are: + == ================= + 0 Escape key only + 1 Function keys + 2 Media/brightness keys + 3 None + == ================= diff --git a/Documentation/core-api/printk-formats.rst b/Documentation/core-api/printk-formats.rst index 14e093da3ccd..ccd7bd29a6d6 100644 --- a/Documentation/core-api/printk-formats.rst +++ b/Documentation/core-api/printk-formats.rst @@ -630,6 +630,38 @@ Examples:: %p4cc Y10 little-endian (0x20303159) %p4cc NV12 big-endian (0xb231564e) +Generic FourCC code +------------------- + +:: + %p4c[hnbl] gP00 (0x67503030) + +Print a generic FourCC code, as both ASCII characters and its numerical +value as hexadecimal. + +The additional ``h``, ``r``, ``b``, and ``l`` specifiers are used to specify +host, reversed, big or little endian order data respectively. Host endian +order means the data is interpreted as a 32-bit integer and the most +significant byte is printed first; that is, the character code as printed +matches the byte order stored in memory on big-endian systems, and is reversed +on little-endian systems. + +Passed by reference. + +Examples for a little-endian machine, given &(u32)0x67503030:: + + %p4ch gP00 (0x67503030) + %p4cl gP00 (0x67503030) + %p4cb 00Pg (0x30305067) + %p4cr 00Pg (0x30305067) + +Examples for a big-endian machine, given &(u32)0x67503030:: + + %p4ch gP00 (0x67503030) + %p4cl 00Pg (0x30305067) + %p4cb gP00 (0x67503030) + %p4cr 00Pg (0x30305067) + Rust ---- diff --git a/MAINTAINERS b/MAINTAINERS index f509050e63ed..a3bbf3d5fb9e 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -7013,6 +7013,12 @@ S: Supported T: git https://gitlab.freedesktop.org/drm/misc/kernel.git F: drivers/gpu/drm/sun4i/sun8i* +DRM DRIVER FOR APPLE TOUCH BARS +M: Kerem Karabay +L: dri-devel@lists.freedesktop.org +S: Maintained +F: drivers/gpu/drm/tiny/appletbdrm.c + DRM DRIVER FOR ARM PL111 CLCD S: Orphan T: git https://gitlab.freedesktop.org/drm/misc/kernel.git diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c index f6a6fc6a4f5c..e71b6dfad958 100644 --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c @@ -2260,6 +2260,9 @@ static int amdgpu_pci_probe(struct pci_dev *pdev, int ret, retry = 0, i; bool supports_atomic = false; + if (vga_switcheroo_client_probe_defer(pdev)) + return -EPROBE_DEFER; + /* skip devices which are owned by radeon */ for (i = 0; i < ARRAY_SIZE(amdgpu_unsupported_pciidlist); i++) { if (amdgpu_unsupported_pciidlist[i] == pdev->device) diff --git a/drivers/gpu/drm/drm_format_helper.c b/drivers/gpu/drm/drm_format_helper.c index b1be458ed4dd..28c0e76a1e88 100644 --- a/drivers/gpu/drm/drm_format_helper.c +++ b/drivers/gpu/drm/drm_format_helper.c @@ -702,6 +702,57 @@ void drm_fb_xrgb8888_to_rgb888(struct iosys_map *dst, const unsigned int *dst_pi } EXPORT_SYMBOL(drm_fb_xrgb8888_to_rgb888); +static void drm_fb_xrgb8888_to_bgr888_line(void *dbuf, const void *sbuf, unsigned int pixels) +{ + u8 *dbuf8 = dbuf; + const __le32 *sbuf32 = sbuf; + unsigned int x; + u32 pix; + + for (x = 0; x < pixels; x++) { + pix = le32_to_cpu(sbuf32[x]); + /* write red-green-blue to output in little endianness */ + *dbuf8++ = (pix & 0x00FF0000) >> 16; + *dbuf8++ = (pix & 0x0000FF00) >> 8; + *dbuf8++ = (pix & 0x000000FF) >> 0; + } +} + +/** + * drm_fb_xrgb8888_to_bgr888 - Convert XRGB8888 to BGR888 clip buffer + * @dst: Array of BGR888 destination buffers + * @dst_pitch: Array of numbers of bytes between the start of two consecutive scanlines + * within @dst; can be NULL if scanlines are stored next to each other. + * @src: Array of XRGB8888 source buffers + * @fb: DRM framebuffer + * @clip: Clip rectangle area to copy + * @state: Transform and conversion state + * + * This function copies parts of a framebuffer to display memory and converts the + * color format during the process. Destination and framebuffer formats must match. The + * parameters @dst, @dst_pitch and @src refer to arrays. Each array must have at + * least as many entries as there are planes in @fb's format. Each entry stores the + * value for the format's respective color plane at the same index. + * + * This function does not apply clipping on @dst (i.e. the destination is at the + * top-left corner). + * + * Drivers can use this function for BGR888 devices that don't natively + * support XRGB8888. + */ +void drm_fb_xrgb8888_to_bgr888(struct iosys_map *dst, const unsigned int *dst_pitch, + const struct iosys_map *src, const struct drm_framebuffer *fb, + const struct drm_rect *clip, struct drm_format_conv_state *state) +{ + static const u8 dst_pixsize[DRM_FORMAT_MAX_PLANES] = { + 3, + }; + + drm_fb_xfrm(dst, dst_pitch, dst_pixsize, src, fb, clip, false, state, + drm_fb_xrgb8888_to_bgr888_line); +} +EXPORT_SYMBOL(drm_fb_xrgb8888_to_bgr888); + static void drm_fb_xrgb8888_to_argb8888_line(void *dbuf, const void *sbuf, unsigned int pixels) { __le32 *dbuf32 = dbuf; @@ -1035,6 +1086,9 @@ int drm_fb_blit(struct iosys_map *dst, const unsigned int *dst_pitch, uint32_t d } else if (dst_format == DRM_FORMAT_RGB888) { drm_fb_xrgb8888_to_rgb888(dst, dst_pitch, src, fb, clip, state); return 0; + } else if (dst_format == DRM_FORMAT_BGR888) { + drm_fb_xrgb8888_to_bgr888(dst, dst_pitch, src, fb, clip, state); + return 0; } else if (dst_format == DRM_FORMAT_ARGB8888) { drm_fb_xrgb8888_to_argb8888(dst, dst_pitch, src, fb, clip, state); return 0; diff --git a/drivers/gpu/drm/i915/display/intel_ddi.c b/drivers/gpu/drm/i915/display/intel_ddi.c index b1c294236cc8..21e23ba5391e 100644 --- a/drivers/gpu/drm/i915/display/intel_ddi.c +++ b/drivers/gpu/drm/i915/display/intel_ddi.c @@ -4641,6 +4641,7 @@ intel_ddi_init_hdmi_connector(struct intel_digital_port *dig_port) static bool intel_ddi_a_force_4_lanes(struct intel_digital_port *dig_port) { + struct intel_display *display = to_intel_display(dig_port); struct drm_i915_private *dev_priv = to_i915(dig_port->base.base.dev); if (dig_port->base.port != PORT_A) @@ -4649,6 +4650,9 @@ static bool intel_ddi_a_force_4_lanes(struct intel_digital_port *dig_port) if (dig_port->saved_port_bits & DDI_A_4_LANES) return false; + if (intel_has_quirk(display, QUIRK_DDI_A_FORCE_4_LANES)) + return true; + /* Broxton/Geminilake: Bspec says that DDI_A_4_LANES is the only * supported configuration */ diff --git a/drivers/gpu/drm/i915/display/intel_fbdev.c b/drivers/gpu/drm/i915/display/intel_fbdev.c index 49a1ac4f5491..c8c10a6104c4 100644 --- a/drivers/gpu/drm/i915/display/intel_fbdev.c +++ b/drivers/gpu/drm/i915/display/intel_fbdev.c @@ -199,10 +199,10 @@ static int intelfb_create(struct drm_fb_helper *helper, ifbdev->fb = NULL; if (fb && - (sizes->fb_width > fb->base.width || - sizes->fb_height > fb->base.height)) { + (sizes->fb_width != fb->base.width || + sizes->fb_height != fb->base.height)) { drm_dbg_kms(&dev_priv->drm, - "BIOS fb too small (%dx%d), we require (%dx%d)," + "BIOS fb not valid (%dx%d), we require (%dx%d)," " releasing it\n", fb->base.width, fb->base.height, sizes->fb_width, sizes->fb_height); diff --git a/drivers/gpu/drm/i915/display/intel_quirks.c b/drivers/gpu/drm/i915/display/intel_quirks.c index 29b56d53a340..7226ec8fdd9c 100644 --- a/drivers/gpu/drm/i915/display/intel_quirks.c +++ b/drivers/gpu/drm/i915/display/intel_quirks.c @@ -64,6 +64,18 @@ static void quirk_increase_ddi_disabled_time(struct intel_display *display) drm_info(display->drm, "Applying Increase DDI Disabled quirk\n"); } +/* + * In some cases, the firmware might not set the lane count to 4 (for example, + * when booting in some dual GPU Macs with the dGPU as the default GPU), this + * quirk is used to force it as otherwise it might not be possible to compute a + * valid link configuration. + */ +static void quirk_ddi_a_force_4_lanes(struct intel_display *display) +{ + intel_set_quirk(display, QUIRK_DDI_A_FORCE_4_LANES); + drm_info(display->drm, "Applying DDI A Forced 4 Lanes quirk\n"); +} + static void quirk_no_pps_backlight_power_hook(struct intel_display *display) { intel_set_quirk(display, QUIRK_NO_PPS_BACKLIGHT_POWER_HOOK); @@ -229,6 +241,9 @@ static struct intel_quirk intel_quirks[] = { { 0x3184, 0x1019, 0xa94d, quirk_increase_ddi_disabled_time }, /* HP Notebook - 14-r206nv */ { 0x0f31, 0x103c, 0x220f, quirk_invert_brightness }, + + /* Apple MacBookPro15,1 */ + { 0x3e9b, 0x106b, 0x0176, quirk_ddi_a_force_4_lanes }, }; static struct intel_dpcd_quirk intel_dpcd_quirks[] = { diff --git a/drivers/gpu/drm/i915/display/intel_quirks.h b/drivers/gpu/drm/i915/display/intel_quirks.h index cafdebda7535..a5296f82776e 100644 --- a/drivers/gpu/drm/i915/display/intel_quirks.h +++ b/drivers/gpu/drm/i915/display/intel_quirks.h @@ -20,6 +20,7 @@ enum intel_quirk_id { QUIRK_LVDS_SSC_DISABLE, QUIRK_NO_PPS_BACKLIGHT_POWER_HOOK, QUIRK_FW_SYNC_LEN, + QUIRK_DDI_A_FORCE_4_LANES, }; void intel_init_quirks(struct intel_display *display); diff --git a/drivers/gpu/drm/tests/drm_format_helper_test.c b/drivers/gpu/drm/tests/drm_format_helper_test.c index 08992636ec05..35cd3405d045 100644 --- a/drivers/gpu/drm/tests/drm_format_helper_test.c +++ b/drivers/gpu/drm/tests/drm_format_helper_test.c @@ -60,6 +60,11 @@ struct convert_to_rgb888_result { const u8 expected[TEST_BUF_SIZE]; }; +struct convert_to_bgr888_result { + unsigned int dst_pitch; + const u8 expected[TEST_BUF_SIZE]; +}; + struct convert_to_argb8888_result { unsigned int dst_pitch; const u32 expected[TEST_BUF_SIZE]; @@ -107,6 +112,7 @@ struct convert_xrgb8888_case { struct convert_to_argb1555_result argb1555_result; struct convert_to_rgba5551_result rgba5551_result; struct convert_to_rgb888_result rgb888_result; + struct convert_to_bgr888_result bgr888_result; struct convert_to_argb8888_result argb8888_result; struct convert_to_xrgb2101010_result xrgb2101010_result; struct convert_to_argb2101010_result argb2101010_result; @@ -151,6 +157,10 @@ static struct convert_xrgb8888_case convert_xrgb8888_cases[] = { .dst_pitch = TEST_USE_DEFAULT_PITCH, .expected = { 0x00, 0x00, 0xFF }, }, + .bgr888_result = { + .dst_pitch = TEST_USE_DEFAULT_PITCH, + .expected = { 0xFF, 0x00, 0x00 }, + }, .argb8888_result = { .dst_pitch = TEST_USE_DEFAULT_PITCH, .expected = { 0xFFFF0000 }, @@ -217,6 +227,10 @@ static struct convert_xrgb8888_case convert_xrgb8888_cases[] = { .dst_pitch = TEST_USE_DEFAULT_PITCH, .expected = { 0x00, 0x00, 0xFF }, }, + .bgr888_result = { + .dst_pitch = TEST_USE_DEFAULT_PITCH, + .expected = { 0xFF, 0x00, 0x00 }, + }, .argb8888_result = { .dst_pitch = TEST_USE_DEFAULT_PITCH, .expected = { 0xFFFF0000 }, @@ -330,6 +344,15 @@ static struct convert_xrgb8888_case convert_xrgb8888_cases[] = { 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, }, }, + .bgr888_result = { + .dst_pitch = TEST_USE_DEFAULT_PITCH, + .expected = { + 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, + 0xFF, 0x00, 0x00, 0x00, 0xFF, 0x00, + 0x00, 0x00, 0xFF, 0xFF, 0x00, 0xFF, + 0xFF, 0xFF, 0x00, 0x00, 0xFF, 0xFF, + }, + }, .argb8888_result = { .dst_pitch = TEST_USE_DEFAULT_PITCH, .expected = { @@ -468,6 +491,17 @@ static struct convert_xrgb8888_case convert_xrgb8888_cases[] = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, }, + .bgr888_result = { + .dst_pitch = 15, + .expected = { + 0x0E, 0x44, 0x9C, 0x11, 0x4D, 0x05, 0xA8, 0xF3, 0x03, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x6C, 0xF0, 0x73, 0x0E, 0x44, 0x9C, 0x11, 0x4D, 0x05, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xA8, 0x03, 0x03, 0x6C, 0xF0, 0x73, 0x0E, 0x44, 0x9C, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, .argb8888_result = { .dst_pitch = 20, .expected = { @@ -914,6 +948,52 @@ static void drm_test_fb_xrgb8888_to_rgb888(struct kunit *test) KUNIT_EXPECT_MEMEQ(test, buf, result->expected, dst_size); } +static void drm_test_fb_xrgb8888_to_bgr888(struct kunit *test) +{ + const struct convert_xrgb8888_case *params = test->param_value; + const struct convert_to_bgr888_result *result = ¶ms->bgr888_result; + size_t dst_size; + u8 *buf = NULL; + __le32 *xrgb8888 = NULL; + struct iosys_map dst, src; + + struct drm_framebuffer fb = { + .format = drm_format_info(DRM_FORMAT_XRGB8888), + .pitches = { params->pitch, 0, 0 }, + }; + + dst_size = conversion_buf_size(DRM_FORMAT_BGR888, result->dst_pitch, + ¶ms->clip, 0); + KUNIT_ASSERT_GT(test, dst_size, 0); + + buf = kunit_kzalloc(test, dst_size, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, buf); + iosys_map_set_vaddr(&dst, buf); + + xrgb8888 = cpubuf_to_le32(test, params->xrgb8888, TEST_BUF_SIZE); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, xrgb8888); + iosys_map_set_vaddr(&src, xrgb8888); + + /* + * BGR888 expected results are already in little-endian + * order, so there's no need to convert the test output. + */ + drm_fb_xrgb8888_to_bgr888(&dst, &result->dst_pitch, &src, &fb, ¶ms->clip, + &fmtcnv_state); + KUNIT_EXPECT_MEMEQ(test, buf, result->expected, dst_size); + + buf = dst.vaddr; /* restore original value of buf */ + memset(buf, 0, dst_size); + + int blit_result = 0; + + blit_result = drm_fb_blit(&dst, &result->dst_pitch, DRM_FORMAT_BGR888, &src, &fb, ¶ms->clip, + &fmtcnv_state); + + KUNIT_EXPECT_FALSE(test, blit_result); + KUNIT_EXPECT_MEMEQ(test, buf, result->expected, dst_size); +} + static void drm_test_fb_xrgb8888_to_argb8888(struct kunit *test) { const struct convert_xrgb8888_case *params = test->param_value; @@ -1851,6 +1931,7 @@ static struct kunit_case drm_format_helper_test_cases[] = { KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_argb1555, convert_xrgb8888_gen_params), KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_rgba5551, convert_xrgb8888_gen_params), KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_rgb888, convert_xrgb8888_gen_params), + KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_bgr888, convert_xrgb8888_gen_params), KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_argb8888, convert_xrgb8888_gen_params), KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_xrgb2101010, convert_xrgb8888_gen_params), KUNIT_CASE_PARAM(drm_test_fb_xrgb8888_to_argb2101010, convert_xrgb8888_gen_params), diff --git a/drivers/gpu/drm/tiny/Kconfig b/drivers/gpu/drm/tiny/Kconfig index f6889f649bc1..559a97bce12c 100644 --- a/drivers/gpu/drm/tiny/Kconfig +++ b/drivers/gpu/drm/tiny/Kconfig @@ -1,5 +1,17 @@ # SPDX-License-Identifier: GPL-2.0-only +config DRM_APPLETBDRM + tristate "DRM support for Apple Touch Bars" + depends on DRM && USB && MMU + select DRM_KMS_HELPER + select DRM_GEM_SHMEM_HELPER + help + Say Y here if you want support for the display of Touch Bars on x86 + MacBook Pros. + + To compile this driver as a module, choose M here: the + module will be called appletbdrm. + config DRM_ARCPGU tristate "ARC PGU" depends on DRM && OF diff --git a/drivers/gpu/drm/tiny/Makefile b/drivers/gpu/drm/tiny/Makefile index 76dde89a044b..9a1b412e764a 100644 --- a/drivers/gpu/drm/tiny/Makefile +++ b/drivers/gpu/drm/tiny/Makefile @@ -1,5 +1,6 @@ # SPDX-License-Identifier: GPL-2.0-only +obj-$(CONFIG_DRM_APPLETBDRM) += appletbdrm.o obj-$(CONFIG_DRM_ARCPGU) += arcpgu.o obj-$(CONFIG_DRM_BOCHS) += bochs.o obj-$(CONFIG_DRM_CIRRUS_QEMU) += cirrus.o diff --git a/drivers/gpu/drm/tiny/appletbdrm.c b/drivers/gpu/drm/tiny/appletbdrm.c new file mode 100644 index 000000000000..7a74c8ad37cd --- /dev/null +++ b/drivers/gpu/drm/tiny/appletbdrm.c @@ -0,0 +1,624 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Apple Touch Bar DRM Driver + * + * Copyright (c) 2023 Kerem Karabay + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define _APPLETBDRM_FOURCC(s) (((s)[0] << 24) | ((s)[1] << 16) | ((s)[2] << 8) | (s)[3]) +#define APPLETBDRM_FOURCC(s) _APPLETBDRM_FOURCC(#s) + +#define APPLETBDRM_PIXEL_FORMAT APPLETBDRM_FOURCC(RGBA) /* The actual format is BGR888 */ +#define APPLETBDRM_BITS_PER_PIXEL 24 + +#define APPLETBDRM_MSG_CLEAR_DISPLAY APPLETBDRM_FOURCC(CLRD) +#define APPLETBDRM_MSG_GET_INFORMATION APPLETBDRM_FOURCC(GINF) +#define APPLETBDRM_MSG_UPDATE_COMPLETE APPLETBDRM_FOURCC(UDCL) +#define APPLETBDRM_MSG_SIGNAL_READINESS APPLETBDRM_FOURCC(REDY) + +#define APPLETBDRM_BULK_MSG_TIMEOUT 1000 + +#define drm_to_adev(_drm) container_of(_drm, struct appletbdrm_device, drm) +#define adev_to_udev(adev) interface_to_usbdev(to_usb_interface(adev->dev)) + +struct appletbdrm_device { + struct device *dev; + + u8 in_ep; + u8 out_ep; + + u32 width; + u32 height; + + struct drm_device drm; + struct drm_display_mode mode; + struct drm_connector connector; + struct drm_simple_display_pipe pipe; + + bool readiness_signal_received; +}; + +struct appletbdrm_request_header { + __le16 unk_00; + __le16 unk_02; + __le32 unk_04; + __le32 unk_08; + __le32 size; +} __packed; + +struct appletbdrm_response_header { + u8 unk_00[16]; + u32 msg; +} __packed; + +struct appletbdrm_simple_request { + struct appletbdrm_request_header header; + u32 msg; + u8 unk_14[8]; + __le32 size; +} __packed; + +struct appletbdrm_information { + struct appletbdrm_response_header header; + u8 unk_14[12]; + __le32 width; + __le32 height; + u8 bits_per_pixel; + __le32 bytes_per_row; + __le32 orientation; + __le32 bitmap_info; + u32 pixel_format; + __le32 width_inches; /* floating point */ + __le32 height_inches; /* floating point */ +} __packed; + +struct appletbdrm_frame { + __le16 begin_x; + __le16 begin_y; + __le16 width; + __le16 height; + __le32 buf_size; + u8 buf[]; +} __packed; + +struct appletbdrm_fb_request_footer { + u8 unk_00[12]; + __le32 unk_0c; + u8 unk_10[12]; + __le32 unk_1c; + __le64 timestamp; + u8 unk_28[12]; + __le32 unk_34; + u8 unk_38[20]; + __le32 unk_4c; +} __packed; + +struct appletbdrm_fb_request { + struct appletbdrm_request_header header; + __le16 unk_10; + u8 msg_id; + u8 unk_13[29]; + /* + * Contents of `data`: + * - struct appletbdrm_frame frames[]; + * - struct appletbdrm_fb_request_footer footer; + * - padding to make the total size a multiple of 16 + */ + u8 data[]; +} __packed; + +struct appletbdrm_fb_request_response { + struct appletbdrm_response_header header; + u8 unk_14[12]; + __le64 timestamp; +} __packed; + +static int appletbdrm_send_request(struct appletbdrm_device *adev, + struct appletbdrm_request_header *request, size_t size) +{ + struct usb_device *udev = adev_to_udev(adev); + struct drm_device *drm = &adev->drm; + int ret, actual_size; + + ret = usb_bulk_msg(udev, usb_sndbulkpipe(udev, adev->out_ep), + request, size, &actual_size, APPLETBDRM_BULK_MSG_TIMEOUT); + if (ret) { + drm_err(drm, "Failed to send message (%pe)\n", ERR_PTR(ret)); + return ret; + } + + if (actual_size != size) { + drm_err(drm, "Actual size (%d) doesn't match expected size (%lu)\n", + actual_size, size); + return -EIO; + } + + return ret; +} + +static int appletbdrm_read_response(struct appletbdrm_device *adev, + struct appletbdrm_response_header *response, + size_t size, u32 expected_response) +{ + struct usb_device *udev = adev_to_udev(adev); + struct drm_device *drm = &adev->drm; + int ret, actual_size; + +retry: + ret = usb_bulk_msg(udev, usb_rcvbulkpipe(udev, adev->in_ep), + response, size, &actual_size, APPLETBDRM_BULK_MSG_TIMEOUT); + if (ret) { + drm_err(drm, "Failed to read response (%pe)\n", ERR_PTR(ret)); + return ret; + } + + /* + * The device responds to the first request sent in a particular + * timeframe after the USB device configuration is set with a readiness + * signal, in which case the response should be read again + */ + if (response->msg == APPLETBDRM_MSG_SIGNAL_READINESS) { + if (!adev->readiness_signal_received) { + adev->readiness_signal_received = true; + goto retry; + } + + drm_err(drm, "Encountered unexpected readiness signal\n"); + return -EIO; + } + + if (actual_size != size) { + drm_err(drm, "Actual size (%d) doesn't match expected size (%lu)\n", + actual_size, size); + return -EIO; + } + + if (response->msg != expected_response) { + drm_err(drm, "Unexpected response from device (expected %p4ch found %p4ch)\n", + &expected_response, &response->msg); + return -EIO; + } + + return 0; +} + +static int appletbdrm_send_msg(struct appletbdrm_device *adev, u32 msg) +{ + struct appletbdrm_simple_request *request; + int ret; + + request = kzalloc(sizeof(*request), GFP_KERNEL); + if (!request) + return -ENOMEM; + + request->header.unk_00 = cpu_to_le16(2); + request->header.unk_02 = cpu_to_le16(0x1512); + request->header.size = cpu_to_le32(sizeof(*request) - sizeof(request->header)); + request->msg = msg; + request->size = request->header.size; + + ret = appletbdrm_send_request(adev, &request->header, sizeof(*request)); + + kfree(request); + + return ret; +} + +static int appletbdrm_clear_display(struct appletbdrm_device *adev) +{ + return appletbdrm_send_msg(adev, APPLETBDRM_MSG_CLEAR_DISPLAY); +} + +static int appletbdrm_signal_readiness(struct appletbdrm_device *adev) +{ + return appletbdrm_send_msg(adev, APPLETBDRM_MSG_SIGNAL_READINESS); +} + +static int appletbdrm_get_information(struct appletbdrm_device *adev) +{ + struct appletbdrm_information *info; + struct drm_device *drm = &adev->drm; + u8 bits_per_pixel; + u32 pixel_format; + int ret; + + info = kzalloc(sizeof(*info), GFP_KERNEL); + if (!info) + return -ENOMEM; + + ret = appletbdrm_send_msg(adev, APPLETBDRM_MSG_GET_INFORMATION); + if (ret) + return ret; + + ret = appletbdrm_read_response(adev, &info->header, sizeof(*info), + APPLETBDRM_MSG_GET_INFORMATION); + if (ret) + goto free_info; + + bits_per_pixel = info->bits_per_pixel; + pixel_format = get_unaligned(&info->pixel_format); + + adev->width = get_unaligned_le32(&info->width); + adev->height = get_unaligned_le32(&info->height); + + if (bits_per_pixel != APPLETBDRM_BITS_PER_PIXEL) { + drm_err(drm, "Encountered unexpected bits per pixel value (%d)\n", bits_per_pixel); + ret = -EINVAL; + goto free_info; + } + + if (pixel_format != APPLETBDRM_PIXEL_FORMAT) { + drm_err(drm, "Encountered unknown pixel format (%p4ch)\n", &pixel_format); + ret = -EINVAL; + goto free_info; + } + +free_info: + kfree(info); + + return ret; +} + +static u32 rect_size(struct drm_rect *rect) +{ + return drm_rect_width(rect) * drm_rect_height(rect) * (APPLETBDRM_BITS_PER_PIXEL / 8); +} + +static int appletbdrm_flush_damage(struct appletbdrm_device *adev, + struct drm_plane_state *old_state, + struct drm_plane_state *state) +{ + struct drm_shadow_plane_state *shadow_plane_state = to_drm_shadow_plane_state(state); + struct appletbdrm_fb_request_response *response; + struct appletbdrm_fb_request_footer *footer; + struct drm_atomic_helper_damage_iter iter; + struct drm_framebuffer *fb = state->fb; + struct appletbdrm_fb_request *request; + struct drm_device *drm = &adev->drm; + struct appletbdrm_frame *frame; + u64 timestamp = ktime_get_ns(); + struct drm_rect damage; + size_t frames_size = 0; + size_t request_size; + int ret; + + drm_atomic_helper_damage_iter_init(&iter, old_state, state); + drm_atomic_for_each_plane_damage(&iter, &damage) { + frames_size += struct_size(frame, buf, rect_size(&damage)); + } + + if (!frames_size) + return 0; + + request_size = ALIGN(sizeof(*request) + frames_size + sizeof(*footer), 16); + + request = kzalloc(request_size, GFP_KERNEL); + if (!request) + return -ENOMEM; + + response = kzalloc(sizeof(*response), GFP_KERNEL); + if (!response) { + ret = -ENOMEM; + goto free_request; + } + + ret = drm_gem_fb_begin_cpu_access(fb, DMA_FROM_DEVICE); + if (ret) { + drm_err(drm, "Failed to start CPU framebuffer access (%pe)\n", ERR_PTR(ret)); + goto free_response; + } + + request->header.unk_00 = cpu_to_le16(2); + request->header.unk_02 = cpu_to_le16(0x12); + request->header.unk_04 = cpu_to_le32(9); + request->header.size = cpu_to_le32(request_size - sizeof(request->header)); + request->unk_10 = cpu_to_le16(1); + request->msg_id = timestamp & 0xff; + + frame = (struct appletbdrm_frame *)request->data; + + drm_atomic_helper_damage_iter_init(&iter, old_state, state); + drm_atomic_for_each_plane_damage(&iter, &damage) { + struct iosys_map dst = IOSYS_MAP_INIT_VADDR(frame->buf); + u32 buf_size = rect_size(&damage); + + /* + * The coordinates need to be translated to the coordinate + * system the device expects, see the comment in + * appletbdrm_setup_mode_config + */ + frame->begin_x = cpu_to_le16(damage.y1); + frame->begin_y = cpu_to_le16(adev->height - damage.x2); + frame->width = cpu_to_le16(drm_rect_height(&damage)); + frame->height = cpu_to_le16(drm_rect_width(&damage)); + frame->buf_size = cpu_to_le32(buf_size); + + ret = drm_fb_blit(&dst, NULL, DRM_FORMAT_BGR888, + &shadow_plane_state->data[0], fb, &damage, &shadow_plane_state->fmtcnv_state); + if (ret) { + drm_err(drm, "Failed to copy damage clip (%pe)\n", ERR_PTR(ret)); + goto end_fb_cpu_access; + } + + frame = (void *)frame + struct_size(frame, buf, buf_size); + } + + footer = (struct appletbdrm_fb_request_footer *)&request->data[frames_size]; + + footer->unk_0c = cpu_to_le32(0xfffe); + footer->unk_1c = cpu_to_le32(0x80001); + footer->unk_34 = cpu_to_le32(0x80002); + footer->unk_4c = cpu_to_le32(0xffff); + footer->timestamp = cpu_to_le64(timestamp); + + ret = appletbdrm_send_request(adev, &request->header, request_size); + if (ret) + goto end_fb_cpu_access; + + ret = appletbdrm_read_response(adev, &response->header, sizeof(*response), + APPLETBDRM_MSG_UPDATE_COMPLETE); + if (ret) + goto end_fb_cpu_access; + + if (response->timestamp != footer->timestamp) { + drm_err(drm, "Response timestamp (%llu) doesn't match request timestamp (%llu)\n", + le64_to_cpu(response->timestamp), timestamp); + goto end_fb_cpu_access; + } + +end_fb_cpu_access: + drm_gem_fb_end_cpu_access(fb, DMA_FROM_DEVICE); +free_response: + kfree(response); +free_request: + kfree(request); + + return ret; +} + +static int appletbdrm_connector_helper_get_modes(struct drm_connector *connector) +{ + struct appletbdrm_device *adev = drm_to_adev(connector->dev); + + return drm_connector_helper_get_modes_fixed(connector, &adev->mode); +} + +static enum drm_mode_status appletbdrm_pipe_mode_valid(struct drm_simple_display_pipe *pipe, + const struct drm_display_mode *mode) +{ + struct drm_crtc *crtc = &pipe->crtc; + struct appletbdrm_device *adev = drm_to_adev(crtc->dev); + + return drm_crtc_helper_mode_valid_fixed(crtc, mode, &adev->mode); +} + +static void appletbdrm_pipe_disable(struct drm_simple_display_pipe *pipe) +{ + struct appletbdrm_device *adev = drm_to_adev(pipe->crtc.dev); + int idx; + + if (!drm_dev_enter(&adev->drm, &idx)) + return; + + appletbdrm_clear_display(adev); + + drm_dev_exit(idx); +} + +static void appletbdrm_pipe_update(struct drm_simple_display_pipe *pipe, + struct drm_plane_state *old_state) +{ + struct drm_crtc *crtc = &pipe->crtc; + struct appletbdrm_device *adev = drm_to_adev(crtc->dev); + int idx; + + if (!crtc->state->active || !drm_dev_enter(&adev->drm, &idx)) + return; + + appletbdrm_flush_damage(adev, old_state, pipe->plane.state); + + drm_dev_exit(idx); +} + +static const u32 appletbdrm_formats[] = { + DRM_FORMAT_BGR888, + DRM_FORMAT_XRGB8888, /* emulated */ +}; + +static const struct drm_mode_config_funcs appletbdrm_mode_config_funcs = { + .fb_create = drm_gem_fb_create_with_dirty, + .atomic_check = drm_atomic_helper_check, + .atomic_commit = drm_atomic_helper_commit, +}; + +static const struct drm_connector_funcs appletbdrm_connector_funcs = { + .reset = drm_atomic_helper_connector_reset, + .destroy = drm_connector_cleanup, + .fill_modes = drm_helper_probe_single_connector_modes, + .atomic_destroy_state = drm_atomic_helper_connector_destroy_state, + .atomic_duplicate_state = drm_atomic_helper_connector_duplicate_state, +}; + +static const struct drm_connector_helper_funcs appletbdrm_connector_helper_funcs = { + .get_modes = appletbdrm_connector_helper_get_modes, +}; + +static const struct drm_simple_display_pipe_funcs appletbdrm_pipe_funcs = { + DRM_GEM_SIMPLE_DISPLAY_PIPE_SHADOW_PLANE_FUNCS, + .update = appletbdrm_pipe_update, + .disable = appletbdrm_pipe_disable, + .mode_valid = appletbdrm_pipe_mode_valid, +}; + +DEFINE_DRM_GEM_FOPS(appletbdrm_drm_fops); + +static const struct drm_driver appletbdrm_drm_driver = { + DRM_GEM_SHMEM_DRIVER_OPS, + .name = "appletbdrm", + .desc = "Apple Touch Bar DRM Driver", + .date = "20230910", + .major = 1, + .minor = 0, + .driver_features = DRIVER_MODESET | DRIVER_GEM | DRIVER_ATOMIC, + .fops = &appletbdrm_drm_fops, +}; + +static int appletbdrm_setup_mode_config(struct appletbdrm_device *adev) +{ + struct drm_connector *connector = &adev->connector; + struct drm_device *drm = &adev->drm; + struct device *dev = adev->dev; + int ret; + + ret = drmm_mode_config_init(drm); + if (ret) + return dev_err_probe(dev, ret, "Failed to initialize mode configuration\n"); + + /* + * The coordinate system used by the device is different from the + * coordinate system of the framebuffer in that the x and y axes are + * swapped, and that the y axis is inverted; so what the device reports + * as the height is actually the width of the framebuffer and vice + * versa + */ + drm->mode_config.min_width = 0; + drm->mode_config.min_height = 0; + drm->mode_config.max_width = max(adev->height, DRM_SHADOW_PLANE_MAX_WIDTH); + drm->mode_config.max_height = max(adev->width, DRM_SHADOW_PLANE_MAX_HEIGHT); + drm->mode_config.preferred_depth = APPLETBDRM_BITS_PER_PIXEL; + drm->mode_config.funcs = &appletbdrm_mode_config_funcs; + + adev->mode = (struct drm_display_mode) { + DRM_MODE_INIT(60, adev->height, adev->width, + DRM_MODE_RES_MM(adev->height, 218), + DRM_MODE_RES_MM(adev->width, 218)) + }; + + ret = drm_connector_init(drm, connector, + &appletbdrm_connector_funcs, DRM_MODE_CONNECTOR_USB); + if (ret) + return dev_err_probe(dev, ret, "Failed to initialize connector\n"); + + drm_connector_helper_add(connector, &appletbdrm_connector_helper_funcs); + + ret = drm_connector_set_panel_orientation(connector, + DRM_MODE_PANEL_ORIENTATION_RIGHT_UP); + if (ret) + return dev_err_probe(dev, ret, "Failed to set panel orientation\n"); + + connector->display_info.non_desktop = true; + ret = drm_object_property_set_value(&connector->base, + drm->mode_config.non_desktop_property, true); + if (ret) + return dev_err_probe(dev, ret, "Failed to set non-desktop property\n"); + + ret = drm_simple_display_pipe_init(drm, &adev->pipe, &appletbdrm_pipe_funcs, + appletbdrm_formats, ARRAY_SIZE(appletbdrm_formats), + NULL, &adev->connector); + if (ret) + return dev_err_probe(dev, ret, "Failed to initialize simple display pipe\n"); + + drm_plane_enable_fb_damage_clips(&adev->pipe.plane); + + drm_mode_config_reset(drm); + + ret = drm_dev_register(drm, 0); + if (ret) + return dev_err_probe(dev, ret, "Failed to register DRM device\n"); + + return 0; +} + +static int appletbdrm_probe(struct usb_interface *intf, + const struct usb_device_id *id) +{ + struct usb_endpoint_descriptor *bulk_in, *bulk_out; + struct device *dev = &intf->dev; + struct appletbdrm_device *adev; + int ret; + + ret = usb_find_common_endpoints(intf->cur_altsetting, &bulk_in, &bulk_out, NULL, NULL); + if (ret) + return dev_err_probe(dev, ret, "Failed to find bulk endpoints\n"); + + adev = devm_drm_dev_alloc(dev, &appletbdrm_drm_driver, struct appletbdrm_device, drm); + if (IS_ERR(adev)) + return PTR_ERR(adev); + + adev->dev = dev; + adev->in_ep = bulk_in->bEndpointAddress; + adev->out_ep = bulk_out->bEndpointAddress; + + usb_set_intfdata(intf, adev); + + ret = appletbdrm_get_information(adev); + if (ret) + return dev_err_probe(dev, ret, "Failed to get display information\n"); + + ret = appletbdrm_signal_readiness(adev); + if (ret) + return dev_err_probe(dev, ret, "Failed to signal readiness\n"); + + ret = appletbdrm_clear_display(adev); + if (ret) + return dev_err_probe(dev, ret, "Failed to clear display\n"); + + return appletbdrm_setup_mode_config(adev); +} + +static void appletbdrm_disconnect(struct usb_interface *intf) +{ + struct appletbdrm_device *adev = usb_get_intfdata(intf); + struct drm_device *drm = &adev->drm; + + drm_dev_unplug(drm); + drm_atomic_helper_shutdown(drm); +} + +static void appletbdrm_shutdown(struct usb_interface *intf) +{ + struct appletbdrm_device *adev = usb_get_intfdata(intf); + + /* + * The framebuffer needs to be cleared on shutdown since its content + * persists across boots + */ + drm_atomic_helper_shutdown(&adev->drm); +} + +static const struct usb_device_id appletbdrm_usb_id_table[] = { + { USB_DEVICE_INTERFACE_CLASS(0x05ac, 0x8302, USB_CLASS_AUDIO_VIDEO) }, + {} +}; +MODULE_DEVICE_TABLE(usb, appletbdrm_usb_id_table); + +static struct usb_driver appletbdrm_usb_driver = { + .name = "appletbdrm", + .probe = appletbdrm_probe, + .disconnect = appletbdrm_disconnect, + .shutdown = appletbdrm_shutdown, + .id_table = appletbdrm_usb_id_table, +}; +module_usb_driver(appletbdrm_usb_driver); + +MODULE_AUTHOR("Kerem Karabay "); +MODULE_DESCRIPTION("Apple Touch Bar DRM Driver"); +MODULE_LICENSE("GPL"); diff --git a/drivers/gpu/vga/vga_switcheroo.c b/drivers/gpu/vga/vga_switcheroo.c index 18f2c92beff8..3de1bca45ed2 100644 --- a/drivers/gpu/vga/vga_switcheroo.c +++ b/drivers/gpu/vga/vga_switcheroo.c @@ -438,12 +438,7 @@ find_active_client(struct list_head *head) bool vga_switcheroo_client_probe_defer(struct pci_dev *pdev) { if ((pdev->class >> 16) == PCI_BASE_CLASS_DISPLAY) { - /* - * apple-gmux is needed on pre-retina MacBook Pro - * to probe the panel if pdev is the inactive GPU. - */ - if (apple_gmux_present() && pdev != vga_default_device() && - !vgasr_priv.handler_flags) + if (apple_gmux_present() && !vgasr_priv.handler_flags) return true; } diff --git a/drivers/hid/Kconfig b/drivers/hid/Kconfig index f8a56d631242..6c8e9e004907 100644 --- a/drivers/hid/Kconfig +++ b/drivers/hid/Kconfig @@ -148,6 +148,27 @@ config HID_APPLEIR Say Y here if you want support for Apple infrared remote control. +config HID_APPLETB_BL + tristate "Apple Touch Bar Backlight" + depends on BACKLIGHT_CLASS_DEVICE + help + Say Y here if you want support for the backlight of Touch Bars on x86 + MacBook Pros. + + To compile this driver as a module, choose M here: the + module will be called hid-appletb-bl. + +config HID_APPLETB_KBD + tristate "Apple Touch Bar Keyboard Mode" + depends on USB_HID + help + Say Y here if you want support for the keyboard mode (escape, + function, media and brightness keys) of Touch Bars on x86 MacBook + Pros. + + To compile this driver as a module, choose M here: the + module will be called hid-appletb-kbd. + config HID_ASUS tristate "Asus" depends on USB_HID @@ -729,6 +750,7 @@ config HID_MULTITOUCH Say Y here if you have one of the following devices: - 3M PCT touch screens - ActionStar dual touch panels + - Touch Bars on x86 MacBook Pros - Atmel panels - Cando dual touch panels - Chunghwa panels diff --git a/drivers/hid/Makefile b/drivers/hid/Makefile index 496dab54c73a..13d32f55e5d4 100644 --- a/drivers/hid/Makefile +++ b/drivers/hid/Makefile @@ -29,6 +29,8 @@ obj-$(CONFIG_HID_ALPS) += hid-alps.o obj-$(CONFIG_HID_ACRUX) += hid-axff.o obj-$(CONFIG_HID_APPLE) += hid-apple.o obj-$(CONFIG_HID_APPLEIR) += hid-appleir.o +obj-$(CONFIG_HID_APPLETB_BL) += hid-appletb-bl.o +obj-$(CONFIG_HID_APPLETB_KBD) += hid-appletb-kbd.o obj-$(CONFIG_HID_CREATIVE_SB0540) += hid-creative-sb0540.o obj-$(CONFIG_HID_ASUS) += hid-asus.o obj-$(CONFIG_HID_AUREAL) += hid-aureal.o diff --git a/drivers/hid/hid-appletb-bl.c b/drivers/hid/hid-appletb-bl.c new file mode 100644 index 000000000000..819157686e59 --- /dev/null +++ b/drivers/hid/hid-appletb-bl.c @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Apple Touch Bar Backlight Driver + * + * Copyright (c) 2017-2018 Ronald Tschalär + * Copyright (c) 2022-2023 Kerem Karabay + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include + +#include "hid-ids.h" + +#define APPLETB_BL_ON 1 +#define APPLETB_BL_DIM 3 +#define APPLETB_BL_OFF 4 + +#define HID_UP_APPLEVENDOR_TB_BL 0xff120000 + +#define HID_VD_APPLE_TB_BRIGHTNESS 0xff120001 +#define HID_USAGE_AUX1 0xff120020 +#define HID_USAGE_BRIGHTNESS 0xff120021 + +static int appletb_bl_def_brightness = 2; +module_param_named(brightness, appletb_bl_def_brightness, int, 0444); +MODULE_PARM_DESC(brightness, "Default brightness:\n" + " 0 - Touchbar is off\n" + " 1 - Dim brightness\n" + " [2] - Full brightness"); + +struct appletb_bl { + struct hid_field *aux1_field, *brightness_field; + struct backlight_device *bdev; + + bool full_on; +}; + +static const u8 appletb_bl_brightness_map[] = { + APPLETB_BL_OFF, + APPLETB_BL_DIM, + APPLETB_BL_ON, +}; + +static int appletb_bl_set_brightness(struct appletb_bl *bl, u8 brightness) +{ + struct hid_report *report = bl->brightness_field->report; + struct hid_device *hdev = report->device; + int ret; + + ret = hid_set_field(bl->aux1_field, 0, 1); + if (ret) { + hid_err(hdev, "Failed to set auxiliary field (%pe)\n", ERR_PTR(ret)); + return ret; + } + + ret = hid_set_field(bl->brightness_field, 0, brightness); + if (ret) { + hid_err(hdev, "Failed to set brightness field (%pe)\n", ERR_PTR(ret)); + return ret; + } + + if (!bl->full_on) { + ret = hid_hw_power(hdev, PM_HINT_FULLON); + if (ret < 0) { + hid_err(hdev, "Device didn't power on (%pe)\n", ERR_PTR(ret)); + return ret; + } + + bl->full_on = true; + } + + hid_hw_request(hdev, report, HID_REQ_SET_REPORT); + + if (brightness == APPLETB_BL_OFF) { + hid_hw_power(hdev, PM_HINT_NORMAL); + bl->full_on = false; + } + + return 0; +} + +static int appletb_bl_update_status(struct backlight_device *bdev) +{ + struct appletb_bl *bl = bl_get_data(bdev); + u8 brightness; + + if (backlight_is_blank(bdev)) + brightness = APPLETB_BL_OFF; + else + brightness = appletb_bl_brightness_map[backlight_get_brightness(bdev)]; + + return appletb_bl_set_brightness(bl, brightness); +} + +static const struct backlight_ops appletb_bl_backlight_ops = { + .options = BL_CORE_SUSPENDRESUME, + .update_status = appletb_bl_update_status, +}; + +static int appletb_bl_probe(struct hid_device *hdev, const struct hid_device_id *id) +{ + struct hid_field *aux1_field, *brightness_field; + struct backlight_properties bl_props = { 0 }; + struct device *dev = &hdev->dev; + struct appletb_bl *bl; + int ret; + + ret = hid_parse(hdev); + if (ret) + return dev_err_probe(dev, ret, "HID parse failed\n"); + + aux1_field = hid_find_field(hdev, HID_FEATURE_REPORT, + HID_VD_APPLE_TB_BRIGHTNESS, HID_USAGE_AUX1); + + brightness_field = hid_find_field(hdev, HID_FEATURE_REPORT, + HID_VD_APPLE_TB_BRIGHTNESS, HID_USAGE_BRIGHTNESS); + + if (!aux1_field || !brightness_field) + return -ENODEV; + + if (aux1_field->report != brightness_field->report) + return dev_err_probe(dev, -ENODEV, "Encountered unexpected report structure\n"); + + bl = devm_kzalloc(dev, sizeof(*bl), GFP_KERNEL); + if (!bl) + return -ENOMEM; + + ret = hid_hw_start(hdev, HID_CONNECT_DRIVER); + if (ret) + return dev_err_probe(dev, ret, "HID hardware start failed\n"); + + ret = hid_hw_open(hdev); + if (ret) { + dev_err_probe(dev, ret, "HID hardware open failed\n"); + goto stop_hw; + } + + bl->aux1_field = aux1_field; + bl->brightness_field = brightness_field; + + if (appletb_bl_def_brightness == 0) + ret = appletb_bl_set_brightness(bl, APPLETB_BL_OFF); + else if (appletb_bl_def_brightness == 1) + ret = appletb_bl_set_brightness(bl, APPLETB_BL_DIM); + else + ret = appletb_bl_set_brightness(bl, APPLETB_BL_ON); + + if (ret) { + dev_err_probe(dev, ret, "Failed to set touch bar brightness to off\n"); + goto close_hw; + } + + bl_props.type = BACKLIGHT_RAW; + bl_props.max_brightness = ARRAY_SIZE(appletb_bl_brightness_map) - 1; + + bl->bdev = devm_backlight_device_register(dev, "appletb_backlight", dev, bl, + &appletb_bl_backlight_ops, &bl_props); + if (IS_ERR(bl->bdev)) { + ret = PTR_ERR(bl->bdev); + dev_err_probe(dev, ret, "Failed to register backlight device\n"); + goto close_hw; + } + + hid_set_drvdata(hdev, bl); + + return 0; + +close_hw: + hid_hw_close(hdev); +stop_hw: + hid_hw_stop(hdev); + + return ret; +} + +static void appletb_bl_remove(struct hid_device *hdev) +{ + struct appletb_bl *bl = hid_get_drvdata(hdev); + + appletb_bl_set_brightness(bl, APPLETB_BL_OFF); + + hid_hw_close(hdev); + hid_hw_stop(hdev); +} + +static const struct hid_device_id appletb_bl_hid_ids[] = { + /* MacBook Pro's 2018, 2019, with T2 chip: iBridge DFR Brightness */ + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_BACKLIGHT) }, + { } +}; +MODULE_DEVICE_TABLE(hid, appletb_bl_hid_ids); + +static struct hid_driver appletb_bl_hid_driver = { + .name = "hid-appletb-bl", + .id_table = appletb_bl_hid_ids, + .probe = appletb_bl_probe, + .remove = appletb_bl_remove, +}; +module_hid_driver(appletb_bl_hid_driver); + +MODULE_AUTHOR("Ronald Tschalär"); +MODULE_AUTHOR("Kerem Karabay "); +MODULE_DESCRIPTION("MacBookPro Touch Bar Backlight Driver"); +MODULE_LICENSE("GPL"); diff --git a/drivers/hid/hid-appletb-kbd.c b/drivers/hid/hid-appletb-kbd.c new file mode 100644 index 000000000000..442c4d8848df --- /dev/null +++ b/drivers/hid/hid-appletb-kbd.c @@ -0,0 +1,432 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Apple Touch Bar Keyboard Mode Driver + * + * Copyright (c) 2017-2018 Ronald Tschalär + * Copyright (c) 2022-2023 Kerem Karabay + * Copyright (c) 2024 Aditya Garg + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hid-ids.h" + +#define APPLETB_KBD_MODE_ESC 0 +#define APPLETB_KBD_MODE_FN 1 +#define APPLETB_KBD_MODE_SPCL 2 +#define APPLETB_KBD_MODE_OFF 3 +#define APPLETB_KBD_MODE_MAX APPLETB_KBD_MODE_OFF + +#define APPLETB_DEVID_KEYBOARD 1 + +#define HID_USAGE_MODE 0x00ff0004 + +static int appletb_tb_def_mode = APPLETB_KBD_MODE_SPCL; +module_param_named(mode, appletb_tb_def_mode, int, 0444); +MODULE_PARM_DESC(mode, "Default touchbar mode:\n" + " 0 - escape key only\n" + " 1 - function-keys\n" + " [2] - special keys"); + +static bool appletb_tb_fn_toggle = true; +module_param_named(fntoggle, appletb_tb_fn_toggle, bool, 0644); +MODULE_PARM_DESC(fntoggle, "Switch between Fn and media controls on pressing Fn key"); + +struct appletb_kbd { + struct hid_field *mode_field; + + u8 saved_mode; + u8 current_mode; + struct input_handler inp_handler; + struct input_handle kbd_handle; + +}; + +static const struct key_entry appletb_kbd_keymap[] = { + { KE_KEY, KEY_ESC, { KEY_ESC } }, + { KE_KEY, KEY_F1, { KEY_BRIGHTNESSDOWN } }, + { KE_KEY, KEY_F2, { KEY_BRIGHTNESSUP } }, + { KE_KEY, KEY_F3, { KEY_RESERVED } }, + { KE_KEY, KEY_F4, { KEY_RESERVED } }, + { KE_KEY, KEY_F5, { KEY_KBDILLUMDOWN } }, + { KE_KEY, KEY_F6, { KEY_KBDILLUMUP } }, + { KE_KEY, KEY_F7, { KEY_PREVIOUSSONG } }, + { KE_KEY, KEY_F8, { KEY_PLAYPAUSE } }, + { KE_KEY, KEY_F9, { KEY_NEXTSONG } }, + { KE_KEY, KEY_F10, { KEY_MUTE } }, + { KE_KEY, KEY_F11, { KEY_VOLUMEDOWN } }, + { KE_KEY, KEY_F12, { KEY_VOLUMEUP } }, + { KE_END, 0 } +}; + +static int appletb_kbd_set_mode(struct appletb_kbd *kbd, u8 mode) +{ + struct hid_report *report = kbd->mode_field->report; + struct hid_device *hdev = report->device; + int ret; + + ret = hid_hw_power(hdev, PM_HINT_FULLON); + if (ret) { + hid_err(hdev, "Device didn't resume (%pe)\n", ERR_PTR(ret)); + return ret; + } + + ret = hid_set_field(kbd->mode_field, 0, mode); + if (ret) { + hid_err(hdev, "Failed to set mode field to %u (%pe)\n", mode, ERR_PTR(ret)); + goto power_normal; + } + + hid_hw_request(hdev, report, HID_REQ_SET_REPORT); + + kbd->current_mode = mode; + +power_normal: + hid_hw_power(hdev, PM_HINT_NORMAL); + + return ret; +} + +static ssize_t mode_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct appletb_kbd *kbd = dev_get_drvdata(dev); + + return sysfs_emit(buf, "%d\n", kbd->current_mode); +} + +static ssize_t mode_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t size) +{ + struct appletb_kbd *kbd = dev_get_drvdata(dev); + u8 mode; + int ret; + + ret = kstrtou8(buf, 0, &mode); + if (ret) + return ret; + + if (mode > APPLETB_KBD_MODE_MAX) + return -EINVAL; + + ret = appletb_kbd_set_mode(kbd, mode); + + return ret < 0 ? ret : size; +} +static DEVICE_ATTR_RW(mode); + +struct attribute *appletb_kbd_attrs[] = { + &dev_attr_mode.attr, + NULL +}; +ATTRIBUTE_GROUPS(appletb_kbd); + +static int appletb_tb_key_to_slot(unsigned int code) +{ + switch (code) { + case KEY_ESC: + return 0; + case KEY_F1 ... KEY_F10: + return code - KEY_F1 + 1; + case KEY_F11 ... KEY_F12: + return code - KEY_F11 + 11; + + default: + return -EINVAL; + } +} + +static int appletb_kbd_hid_event(struct hid_device *hdev, struct hid_field *field, + struct hid_usage *usage, __s32 value) +{ + struct appletb_kbd *kbd = hid_get_drvdata(hdev); + struct key_entry *translation; + struct input_dev *input; + int slot; + + if ((usage->hid & HID_USAGE_PAGE) != HID_UP_KEYBOARD || usage->type != EV_KEY) + return 0; + + input = field->hidinput->input; + + /* + * Skip non-touch-bar keys. + * + * Either the touch bar itself or usbhid generate a slew of key-down + * events for all the meta keys. None of which we're at all interested + * in. + */ + slot = appletb_tb_key_to_slot(usage->code); + if (slot < 0) + return 0; + + translation = sparse_keymap_entry_from_scancode(input, usage->code); + + if (translation && kbd->current_mode == APPLETB_KBD_MODE_SPCL) { + input_event(input, usage->type, translation->keycode, value); + + return 1; + } + + return kbd->current_mode == APPLETB_KBD_MODE_OFF; +} + +static void appletb_kbd_inp_event(struct input_handle *handle, unsigned int type, + unsigned int code, int value) +{ + struct appletb_kbd *kbd = handle->private; + + if (type == EV_KEY && code == KEY_FN && appletb_tb_fn_toggle) { + if (value == 1) { + kbd->saved_mode = kbd->current_mode; + if (kbd->current_mode == APPLETB_KBD_MODE_SPCL) + appletb_kbd_set_mode(kbd, APPLETB_KBD_MODE_FN); + else if (kbd->current_mode == APPLETB_KBD_MODE_FN) + appletb_kbd_set_mode(kbd, APPLETB_KBD_MODE_SPCL); + } else if (value == 0) { + if (kbd->saved_mode != kbd->current_mode) + appletb_kbd_set_mode(kbd, kbd->saved_mode); + } + } +} + +static int appletb_kbd_inp_connect(struct input_handler *handler, + struct input_dev *dev, + const struct input_device_id *id) +{ + struct appletb_kbd *kbd = handler->private; + struct input_handle *handle; + int rc; + + if (id->driver_info == APPLETB_DEVID_KEYBOARD) { + handle = &kbd->kbd_handle; + handle->name = "tbkbd"; + } else { + return -ENOENT; + } + + if (handle->dev) + return -EEXIST; + + handle->open = 0; + handle->dev = input_get_device(dev); + handle->handler = handler; + handle->private = kbd; + + rc = input_register_handle(handle); + if (rc) + goto err_free_dev; + + rc = input_open_device(handle); + if (rc) + goto err_unregister_handle; + + return 0; + + err_unregister_handle: + input_unregister_handle(handle); + err_free_dev: + input_put_device(handle->dev); + handle->dev = NULL; + return rc; +} + +static void appletb_kbd_inp_disconnect(struct input_handle *handle) +{ + input_close_device(handle); + input_unregister_handle(handle); + + input_put_device(handle->dev); + handle->dev = NULL; +} + +static int appletb_kbd_input_configured(struct hid_device *hdev, struct hid_input *hidinput) +{ + int idx; + struct input_dev *input = hidinput->input; + + /* + * Clear various input capabilities that are blindly set by the hid + * driver (usbkbd.c) + */ + memset(input->evbit, 0, sizeof(input->evbit)); + memset(input->keybit, 0, sizeof(input->keybit)); + memset(input->ledbit, 0, sizeof(input->ledbit)); + + __set_bit(EV_REP, input->evbit); + + sparse_keymap_setup(input, appletb_kbd_keymap, NULL); + + for (idx = 0; appletb_kbd_keymap[idx].type != KE_END; idx++) { + input_set_capability(input, EV_KEY, appletb_kbd_keymap[idx].code); + } + + return 0; +} + +static const struct input_device_id appletb_kbd_input_devices[] = { + { + .flags = INPUT_DEVICE_ID_MATCH_BUS | + INPUT_DEVICE_ID_MATCH_VENDOR | + INPUT_DEVICE_ID_MATCH_KEYBIT, + .bustype = BUS_USB, + .vendor = USB_VENDOR_ID_APPLE, + .keybit = { [BIT_WORD(KEY_FN)] = BIT_MASK(KEY_FN) }, + .driver_info = APPLETB_DEVID_KEYBOARD, + }, + { } +}; + +static bool appletb_kbd_match_internal_device(struct input_handler *handler, + struct input_dev *inp_dev) +{ + struct device *dev = &inp_dev->dev; + + /* in kernel: dev && !is_usb_device(dev) */ + while (dev && !(dev->type && dev->type->name && + !strcmp(dev->type->name, "usb_device"))) + dev = dev->parent; + + /* + * Apple labels all their internal keyboards and trackpads as such, + * instead of maintaining an ever expanding list of product-id's we + * just look at the device's product name. + */ + if (dev) + return !!strstr(to_usb_device(dev)->product, "Internal Keyboard"); + + return false; +} + +static int appletb_kbd_probe(struct hid_device *hdev, const struct hid_device_id *id) +{ + struct appletb_kbd *kbd; + struct device *dev = &hdev->dev; + struct hid_field *mode_field; + int ret; + + ret = hid_parse(hdev); + if (ret) + return dev_err_probe(dev, ret, "HID parse failed\n"); + + mode_field = hid_find_field(hdev, HID_OUTPUT_REPORT, + HID_GD_KEYBOARD, HID_USAGE_MODE); + if (!mode_field) + return -ENODEV; + + kbd = devm_kzalloc(dev, sizeof(*kbd), GFP_KERNEL); + if (!kbd) + return -ENOMEM; + + kbd->mode_field = mode_field; + + ret = hid_hw_start(hdev, HID_CONNECT_HIDINPUT); + if (ret) + return dev_err_probe(dev, ret, "HID hw start failed\n"); + + ret = hid_hw_open(hdev); + if (ret) { + dev_err_probe(dev, ret, "HID hw open failed\n"); + goto stop_hw; + } + + kbd->inp_handler.event = appletb_kbd_inp_event; + kbd->inp_handler.connect = appletb_kbd_inp_connect; + kbd->inp_handler.disconnect = appletb_kbd_inp_disconnect; + kbd->inp_handler.name = "appletb"; + kbd->inp_handler.id_table = appletb_kbd_input_devices; + kbd->inp_handler.match = appletb_kbd_match_internal_device; + kbd->inp_handler.private = kbd; + + ret = input_register_handler(&kbd->inp_handler); + if (ret) { + dev_err_probe(dev, ret, "Unable to register keyboard handler\n"); + goto close_hw; + } + + ret = appletb_kbd_set_mode(kbd, appletb_tb_def_mode); + if (ret) { + dev_err_probe(dev, ret, "Failed to set touchbar mode\n"); + goto close_hw; + } + + hid_set_drvdata(hdev, kbd); + + return 0; + +close_hw: + hid_hw_close(hdev); +stop_hw: + hid_hw_stop(hdev); + return ret; +} + +static void appletb_kbd_remove(struct hid_device *hdev) +{ + struct appletb_kbd *kbd = hid_get_drvdata(hdev); + + appletb_kbd_set_mode(kbd, APPLETB_KBD_MODE_OFF); + + input_unregister_handler(&kbd->inp_handler); + + hid_hw_close(hdev); + hid_hw_stop(hdev); +} + +#ifdef CONFIG_PM +static int appletb_kbd_suspend(struct hid_device *hdev, pm_message_t msg) +{ + struct appletb_kbd *kbd = hid_get_drvdata(hdev); + + kbd->saved_mode = kbd->current_mode; + appletb_kbd_set_mode(kbd, APPLETB_KBD_MODE_OFF); + + return 0; +} + +static int appletb_kbd_reset_resume(struct hid_device *hdev) +{ + struct appletb_kbd *kbd = hid_get_drvdata(hdev); + + appletb_kbd_set_mode(kbd, kbd->saved_mode); + + return 0; +} +#endif + +static const struct hid_device_id appletb_kbd_hid_ids[] = { + /* MacBook Pro's 2018, 2019, with T2 chip: iBridge Display */ + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_DISPLAY) }, + { } +}; +MODULE_DEVICE_TABLE(hid, appletb_kbd_hid_ids); + +static struct hid_driver appletb_kbd_hid_driver = { + .name = "hid-appletb-kbd", + .id_table = appletb_kbd_hid_ids, + .probe = appletb_kbd_probe, + .remove = appletb_kbd_remove, + .event = appletb_kbd_hid_event, + .input_configured = appletb_kbd_input_configured, +#ifdef CONFIG_PM + .suspend = appletb_kbd_suspend, + .reset_resume = appletb_kbd_reset_resume, +#endif + .driver.dev_groups = appletb_kbd_groups, +}; +module_hid_driver(appletb_kbd_hid_driver); + +MODULE_AUTHOR("Ronald Tschalär"); +MODULE_AUTHOR("Kerem Karabay "); +MODULE_DESCRIPTION("MacBookPro Touch Bar Keyboard Mode Driver"); +MODULE_LICENSE("GPL"); diff --git a/drivers/hid/hid-multitouch.c b/drivers/hid/hid-multitouch.c index e936019d21fe..0d5382e965de 100644 --- a/drivers/hid/hid-multitouch.c +++ b/drivers/hid/hid-multitouch.c @@ -72,6 +72,7 @@ MODULE_LICENSE("GPL"); #define MT_QUIRK_FORCE_MULTI_INPUT BIT(20) #define MT_QUIRK_DISABLE_WAKEUP BIT(21) #define MT_QUIRK_ORIENTATION_INVERT BIT(22) +#define MT_QUIRK_TOUCH_IS_TIPSTATE BIT(23) #define MT_INPUTMODE_TOUCHSCREEN 0x02 #define MT_INPUTMODE_TOUCHPAD 0x03 @@ -145,6 +146,7 @@ struct mt_class { __s32 sn_height; /* Signal/noise ratio for height events */ __s32 sn_pressure; /* Signal/noise ratio for pressure events */ __u8 maxcontacts; + bool is_direct; /* true for touchscreens */ bool is_indirect; /* true for touchpads */ bool export_all_inputs; /* do not ignore mouse, keyboards, etc... */ }; @@ -212,6 +214,7 @@ static void mt_post_parse(struct mt_device *td, struct mt_application *app); #define MT_CLS_GOOGLE 0x0111 #define MT_CLS_RAZER_BLADE_STEALTH 0x0112 #define MT_CLS_SMART_TECH 0x0113 +#define MT_CLS_APPLE_TOUCHBAR 0x0114 #define MT_CLS_SIS 0x0457 #define MT_DEFAULT_MAXCONTACT 10 @@ -397,6 +400,13 @@ static const struct mt_class mt_classes[] = { MT_QUIRK_CONTACT_CNT_ACCURATE | MT_QUIRK_SEPARATE_APP_REPORT, }, + { .name = MT_CLS_APPLE_TOUCHBAR, + .quirks = MT_QUIRK_HOVERING | + MT_QUIRK_TOUCH_IS_TIPSTATE | + MT_QUIRK_SLOT_IS_CONTACTID_MINUS_ONE, + .is_direct = true, + .maxcontacts = 11, + }, { .name = MT_CLS_SIS, .quirks = MT_QUIRK_NOT_SEEN_MEANS_UP | MT_QUIRK_ALWAYS_VALID | @@ -495,9 +505,6 @@ static void mt_feature_mapping(struct hid_device *hdev, if (!td->maxcontacts && field->logical_maximum <= MT_MAX_MAXCONTACT) td->maxcontacts = field->logical_maximum; - if (td->mtclass.maxcontacts) - /* check if the maxcontacts is given by the class */ - td->maxcontacts = td->mtclass.maxcontacts; break; case HID_DG_BUTTONTYPE: @@ -571,13 +578,13 @@ static struct mt_application *mt_allocate_application(struct mt_device *td, mt_application->application = application; INIT_LIST_HEAD(&mt_application->mt_usages); - if (application == HID_DG_TOUCHSCREEN) + if (application == HID_DG_TOUCHSCREEN && !td->mtclass.is_indirect) mt_application->mt_flags |= INPUT_MT_DIRECT; /* * Model touchscreens providing buttons as touchpads. */ - if (application == HID_DG_TOUCHPAD) { + if (application == HID_DG_TOUCHPAD && !td->mtclass.is_direct) { mt_application->mt_flags |= INPUT_MT_POINTER; td->inputmode_value = MT_INPUTMODE_TOUCHPAD; } @@ -641,7 +648,9 @@ static struct mt_report_data *mt_allocate_report_data(struct mt_device *td, if (field->logical == HID_DG_FINGER || td->hdev->group != HID_GROUP_MULTITOUCH_WIN_8) { for (n = 0; n < field->report_count; n++) { - if (field->usage[n].hid == HID_DG_CONTACTID) { + unsigned int hid = field->usage[n].hid; + + if (hid == HID_DG_CONTACTID || hid == HID_DG_TRANSDUCER_INDEX) { rdata->is_mt_collection = true; break; } @@ -813,6 +822,15 @@ static int mt_touch_input_mapping(struct hid_device *hdev, struct hid_input *hi, MT_STORE_FIELD(confidence_state); return 1; + case HID_DG_TOUCH: + /* + * Legacy devices use TIPSWITCH and not TOUCH. + * Let's just ignore this field unless the quirk is set. + */ + if (!(cls->quirks & MT_QUIRK_TOUCH_IS_TIPSTATE)) + return -1; + + fallthrough; case HID_DG_TIPSWITCH: if (field->application != HID_GD_SYSTEM_MULTIAXIS) input_set_capability(hi->input, @@ -820,6 +838,7 @@ static int mt_touch_input_mapping(struct hid_device *hdev, struct hid_input *hi, MT_STORE_FIELD(tip_state); return 1; case HID_DG_CONTACTID: + case HID_DG_TRANSDUCER_INDEX: MT_STORE_FIELD(contactid); app->touches_by_report++; return 1; @@ -875,10 +894,6 @@ static int mt_touch_input_mapping(struct hid_device *hdev, struct hid_input *hi, case HID_DG_CONTACTMAX: /* contact max are global to the report */ return -1; - case HID_DG_TOUCH: - /* Legacy devices use TIPSWITCH and not TOUCH. - * Let's just ignore this field. */ - return -1; } /* let hid-input decide for the others */ return 0; @@ -1306,6 +1321,10 @@ static int mt_touch_input_configured(struct hid_device *hdev, struct input_dev *input = hi->input; int ret; + /* check if the maxcontacts is given by the class */ + if (cls->maxcontacts) + td->maxcontacts = cls->maxcontacts; + if (!td->maxcontacts) td->maxcontacts = MT_DEFAULT_MAXCONTACT; @@ -1313,6 +1332,9 @@ static int mt_touch_input_configured(struct hid_device *hdev, if (td->serial_maybe) mt_post_parse_default_settings(td, app); + if (cls->is_direct) + app->mt_flags |= INPUT_MT_DIRECT; + if (cls->is_indirect) app->mt_flags |= INPUT_MT_POINTER; @@ -1764,6 +1786,15 @@ static int mt_probe(struct hid_device *hdev, const struct hid_device_id *id) } } + ret = hid_parse(hdev); + if (ret != 0) + return ret; + + if (mtclass->name == MT_CLS_APPLE_TOUCHBAR && + !hid_find_field(hdev, HID_INPUT_REPORT, + HID_DG_TOUCHPAD, HID_DG_TRANSDUCER_INDEX)) + return -ENODEV; + td = devm_kzalloc(&hdev->dev, sizeof(struct mt_device), GFP_KERNEL); if (!td) { dev_err(&hdev->dev, "cannot allocate multitouch data\n"); @@ -1811,10 +1842,6 @@ static int mt_probe(struct hid_device *hdev, const struct hid_device_id *id) timer_setup(&td->release_timer, mt_expired_timeout, 0); - ret = hid_parse(hdev); - if (ret != 0) - return ret; - if (mtclass->quirks & MT_QUIRK_FIX_CONST_CONTACT_ID) mt_fix_const_fields(hdev, HID_DG_CONTACTID); @@ -2299,6 +2326,11 @@ static const struct hid_device_id mt_devices[] = { MT_USB_DEVICE(USB_VENDOR_ID_XIROKU, USB_DEVICE_ID_XIROKU_CSR2) }, + /* Apple Touch Bars */ + { .driver_data = MT_CLS_APPLE_TOUCHBAR, + HID_USB_DEVICE(USB_VENDOR_ID_APPLE, + USB_DEVICE_ID_APPLE_TOUCHBAR_DISPLAY) }, + /* Google MT devices */ { .driver_data = MT_CLS_GOOGLE, HID_DEVICE(HID_BUS_ANY, HID_GROUP_ANY, USB_VENDOR_ID_GOOGLE, diff --git a/drivers/hid/hid-quirks.c b/drivers/hid/hid-quirks.c index e0bbf0c6345d..7c576d6540fe 100644 --- a/drivers/hid/hid-quirks.c +++ b/drivers/hid/hid-quirks.c @@ -328,8 +328,6 @@ static const struct hid_device_id hid_have_special_driver[] = { { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_GEYSER1_TP_ONLY) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_MAGIC_KEYBOARD_2021) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_MAGIC_KEYBOARD_FINGERPRINT_2021) }, - { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_BACKLIGHT) }, - { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_DISPLAY) }, #endif #if IS_ENABLED(CONFIG_HID_APPLEIR) { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_IRCONTROL) }, @@ -338,6 +336,12 @@ static const struct hid_device_id hid_have_special_driver[] = { { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_IRCONTROL4) }, { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_IRCONTROL5) }, #endif +#if IS_ENABLED(CONFIG_HID_APPLETB_BL) + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_BACKLIGHT) }, +#endif +#if IS_ENABLED(CONFIG_HID_APPLETB_KBD) + { HID_USB_DEVICE(USB_VENDOR_ID_APPLE, USB_DEVICE_ID_APPLE_TOUCHBAR_DISPLAY) }, +#endif #if IS_ENABLED(CONFIG_HID_ASUS) { HID_I2C_DEVICE(USB_VENDOR_ID_ASUSTEK, USB_DEVICE_ID_ASUSTEK_I2C_KEYBOARD) }, { HID_I2C_DEVICE(USB_VENDOR_ID_ASUSTEK, USB_DEVICE_ID_ASUSTEK_I2C_TOUCHPAD) }, diff --git a/drivers/hwmon/applesmc.c b/drivers/hwmon/applesmc.c index fc6d6a9053ce..698f44794453 100644 --- a/drivers/hwmon/applesmc.c +++ b/drivers/hwmon/applesmc.c @@ -6,6 +6,7 @@ * * Copyright (C) 2007 Nicolas Boichat * Copyright (C) 2010 Henrik Rydberg + * Copyright (C) 2019 Paul Pawlowski * * Based on hdaps.c driver: * Copyright (C) 2005 Robert Love @@ -18,7 +19,7 @@ #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt #include -#include +#include #include #include #include @@ -35,12 +36,24 @@ #include /* data port used by Apple SMC */ -#define APPLESMC_DATA_PORT 0x300 +#define APPLESMC_DATA_PORT 0 /* command/status port used by Apple SMC */ -#define APPLESMC_CMD_PORT 0x304 +#define APPLESMC_CMD_PORT 4 #define APPLESMC_NR_PORTS 32 /* 0x300-0x31f */ +#define APPLESMC_IOMEM_KEY_DATA 0 +#define APPLESMC_IOMEM_KEY_STATUS 0x4005 +#define APPLESMC_IOMEM_KEY_NAME 0x78 +#define APPLESMC_IOMEM_KEY_DATA_LEN 0x7D +#define APPLESMC_IOMEM_KEY_SMC_ID 0x7E +#define APPLESMC_IOMEM_KEY_CMD 0x7F +#define APPLESMC_IOMEM_MIN_SIZE 0x4006 + +#define APPLESMC_IOMEM_KEY_TYPE_CODE 0 +#define APPLESMC_IOMEM_KEY_TYPE_DATA_LEN 5 +#define APPLESMC_IOMEM_KEY_TYPE_FLAGS 6 + #define APPLESMC_MAX_DATA_LENGTH 32 /* Apple SMC status bits */ @@ -74,6 +87,7 @@ #define FAN_ID_FMT "F%dID" /* r-o char[16] */ #define TEMP_SENSOR_TYPE "sp78" +#define FLOAT_TYPE "flt " /* List of keys used to read/write fan speeds */ static const char *const fan_speed_fmt[] = { @@ -83,6 +97,7 @@ static const char *const fan_speed_fmt[] = { "F%dSf", /* safe speed - not all models */ "F%dTg", /* target speed (manual: rw) */ }; +#define FAN_MANUAL_FMT "F%dMd" #define INIT_TIMEOUT_MSECS 5000 /* wait up to 5s for device init ... */ #define INIT_WAIT_MSECS 50 /* ... in 50ms increments */ @@ -119,7 +134,7 @@ struct applesmc_entry { }; /* Register lookup and registers common to all SMCs */ -static struct applesmc_registers { +struct applesmc_registers { struct mutex mutex; /* register read/write mutex */ unsigned int key_count; /* number of SMC registers */ unsigned int fan_count; /* number of fans */ @@ -133,26 +148,38 @@ static struct applesmc_registers { bool init_complete; /* true when fully initialized */ struct applesmc_entry *cache; /* cached key entries */ const char **index; /* temperature key index */ -} smcreg = { - .mutex = __MUTEX_INITIALIZER(smcreg.mutex), }; -static const int debug; -static struct platform_device *pdev; -static s16 rest_x; -static s16 rest_y; -static u8 backlight_state[2]; +struct applesmc_device { + struct acpi_device *dev; + struct device *ldev; + struct applesmc_registers reg; -static struct device *hwmon_dev; -static struct input_dev *applesmc_idev; + bool port_base_set, iomem_base_set; + u16 port_base; + u8 *__iomem iomem_base; + u32 iomem_base_addr, iomem_base_size; -/* - * Last index written to key_at_index sysfs file, and value to use for all other - * key_at_index_* sysfs files. - */ -static unsigned int key_at_index; + s16 rest_x; + s16 rest_y; + + u8 backlight_state[2]; + + struct device *hwmon_dev; + struct input_dev *idev; + + /* + * Last index written to key_at_index sysfs file, and value to use for all other + * key_at_index_* sysfs files. + */ + unsigned int key_at_index; -static struct workqueue_struct *applesmc_led_wq; + struct workqueue_struct *backlight_wq; + struct work_struct backlight_work; + struct led_classdev backlight_dev; +}; + +static const int debug; /* * Wait for specific status bits with a mask on the SMC. @@ -162,7 +189,7 @@ static struct workqueue_struct *applesmc_led_wq; * run out past 500ms. */ -static int wait_status(u8 val, u8 mask) +static int port_wait_status(struct applesmc_device *smc, u8 val, u8 mask) { u8 status; int us; @@ -170,7 +197,7 @@ static int wait_status(u8 val, u8 mask) us = APPLESMC_MIN_WAIT; for (i = 0; i < 24 ; i++) { - status = inb(APPLESMC_CMD_PORT); + status = inb(smc->port_base + APPLESMC_CMD_PORT); if ((status & mask) == val) return 0; usleep_range(us, us * 2); @@ -180,13 +207,13 @@ static int wait_status(u8 val, u8 mask) return -EIO; } -/* send_byte - Write to SMC data port. Callers must hold applesmc_lock. */ +/* port_send_byte - Write to SMC data port. Callers must hold applesmc_lock. */ -static int send_byte(u8 cmd, u16 port) +static int port_send_byte(struct applesmc_device *smc, u8 cmd, u16 port) { int status; - status = wait_status(0, SMC_STATUS_IB_CLOSED); + status = port_wait_status(smc, 0, SMC_STATUS_IB_CLOSED); if (status) return status; /* @@ -195,24 +222,25 @@ static int send_byte(u8 cmd, u16 port) * this extra read may not happen if status returns both * simultaneously and this would appear to be required. */ - status = wait_status(SMC_STATUS_BUSY, SMC_STATUS_BUSY); + status = port_wait_status(smc, SMC_STATUS_BUSY, SMC_STATUS_BUSY); if (status) return status; - outb(cmd, port); + outb(cmd, smc->port_base + port); return 0; } -/* send_command - Write a command to the SMC. Callers must hold applesmc_lock. */ +/* port_send_command - Write a command to the SMC. Callers must hold applesmc_lock. */ -static int send_command(u8 cmd) +static int port_send_command(struct applesmc_device *smc, u8 cmd) { int ret; - ret = wait_status(0, SMC_STATUS_IB_CLOSED); + ret = port_wait_status(smc, 0, SMC_STATUS_IB_CLOSED); if (ret) return ret; - outb(cmd, APPLESMC_CMD_PORT); + + outb(cmd, smc->port_base + APPLESMC_CMD_PORT); return 0; } @@ -222,110 +250,304 @@ static int send_command(u8 cmd) * If busy is stuck high after the command then the SMC is jammed. */ -static int smc_sane(void) +static int port_smc_sane(struct applesmc_device *smc) { int ret; - ret = wait_status(0, SMC_STATUS_BUSY); + ret = port_wait_status(smc, 0, SMC_STATUS_BUSY); if (!ret) return ret; - ret = send_command(APPLESMC_READ_CMD); + ret = port_send_command(smc, APPLESMC_READ_CMD); if (ret) return ret; - return wait_status(0, SMC_STATUS_BUSY); + return port_wait_status(smc, 0, SMC_STATUS_BUSY); } -static int send_argument(const char *key) +static int port_send_argument(struct applesmc_device *smc, const char *key) { int i; for (i = 0; i < 4; i++) - if (send_byte(key[i], APPLESMC_DATA_PORT)) + if (port_send_byte(smc, key[i], APPLESMC_DATA_PORT)) return -EIO; return 0; } -static int read_smc(u8 cmd, const char *key, u8 *buffer, u8 len) +static int port_read_smc(struct applesmc_device *smc, u8 cmd, const char *key, + u8 *buffer, u8 len) { u8 status, data = 0; int i; int ret; - ret = smc_sane(); + ret = port_smc_sane(smc); if (ret) return ret; - if (send_command(cmd) || send_argument(key)) { + if (port_send_command(smc, cmd) || port_send_argument(smc, key)) { pr_warn("%.4s: read arg fail\n", key); return -EIO; } /* This has no effect on newer (2012) SMCs */ - if (send_byte(len, APPLESMC_DATA_PORT)) { + if (port_send_byte(smc, len, APPLESMC_DATA_PORT)) { pr_warn("%.4s: read len fail\n", key); return -EIO; } for (i = 0; i < len; i++) { - if (wait_status(SMC_STATUS_AWAITING_DATA | SMC_STATUS_BUSY, + if (port_wait_status(smc, + SMC_STATUS_AWAITING_DATA | SMC_STATUS_BUSY, SMC_STATUS_AWAITING_DATA | SMC_STATUS_BUSY)) { pr_warn("%.4s: read data[%d] fail\n", key, i); return -EIO; } - buffer[i] = inb(APPLESMC_DATA_PORT); + buffer[i] = inb(smc->port_base + APPLESMC_DATA_PORT); } /* Read the data port until bit0 is cleared */ for (i = 0; i < 16; i++) { udelay(APPLESMC_MIN_WAIT); - status = inb(APPLESMC_CMD_PORT); + status = inb(smc->port_base + APPLESMC_CMD_PORT); if (!(status & SMC_STATUS_AWAITING_DATA)) break; - data = inb(APPLESMC_DATA_PORT); + data = inb(smc->port_base + APPLESMC_DATA_PORT); } if (i) pr_warn("flushed %d bytes, last value is: %d\n", i, data); - return wait_status(0, SMC_STATUS_BUSY); + return port_wait_status(smc, 0, SMC_STATUS_BUSY); } -static int write_smc(u8 cmd, const char *key, const u8 *buffer, u8 len) +static int port_write_smc(struct applesmc_device *smc, u8 cmd, const char *key, + const u8 *buffer, u8 len) { int i; int ret; - ret = smc_sane(); + ret = port_smc_sane(smc); if (ret) return ret; - if (send_command(cmd) || send_argument(key)) { + if (port_send_command(smc, cmd) || port_send_argument(smc, key)) { pr_warn("%s: write arg fail\n", key); return -EIO; } - if (send_byte(len, APPLESMC_DATA_PORT)) { + if (port_send_byte(smc, len, APPLESMC_DATA_PORT)) { pr_warn("%.4s: write len fail\n", key); return -EIO; } for (i = 0; i < len; i++) { - if (send_byte(buffer[i], APPLESMC_DATA_PORT)) { + if (port_send_byte(smc, buffer[i], APPLESMC_DATA_PORT)) { pr_warn("%s: write data fail\n", key); return -EIO; } } - return wait_status(0, SMC_STATUS_BUSY); + return port_wait_status(smc, 0, SMC_STATUS_BUSY); } -static int read_register_count(unsigned int *count) +static int port_get_smc_key_info(struct applesmc_device *smc, + const char *key, struct applesmc_entry *info) { - __be32 be; int ret; + u8 raw[6]; - ret = read_smc(APPLESMC_READ_CMD, KEY_COUNT_KEY, (u8 *)&be, 4); + ret = port_read_smc(smc, APPLESMC_GET_KEY_TYPE_CMD, key, raw, 6); if (ret) return ret; + info->len = raw[0]; + memcpy(info->type, &raw[1], 4); + info->flags = raw[5]; + return 0; +} + + +/* + * MMIO based communication. + * TODO: Use updated mechanism for cmd timeout/retry + */ + +static void iomem_clear_status(struct applesmc_device *smc) +{ + if (ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_STATUS)) + iowrite8(0, smc->iomem_base + APPLESMC_IOMEM_KEY_STATUS); +} + +static int iomem_wait_read(struct applesmc_device *smc) +{ + u8 status; + int us; + int i; + + us = APPLESMC_MIN_WAIT; + for (i = 0; i < 24 ; i++) { + status = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_STATUS); + if (status & 0x20) + return 0; + usleep_range(us, us * 2); + if (i > 9) + us <<= 1; + } + + dev_warn(smc->ldev, "%s... timeout\n", __func__); + return -EIO; +} + +static int iomem_read_smc(struct applesmc_device *smc, u8 cmd, const char *key, + u8 *buffer, u8 len) +{ + u8 err, remote_len; + u32 key_int = *((u32 *) key); + + iomem_clear_status(smc); + iowrite32(key_int, smc->iomem_base + APPLESMC_IOMEM_KEY_NAME); + iowrite32(0, smc->iomem_base + APPLESMC_IOMEM_KEY_SMC_ID); + iowrite32(cmd, smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); + + if (iomem_wait_read(smc)) + return -EIO; + + err = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); + if (err != 0) { + dev_warn(smc->ldev, "read_smc_mmio(%x %8x/%.4s) failed: %u\n", + cmd, key_int, key, err); + return -EIO; + } + + if (cmd == APPLESMC_READ_CMD) { + remote_len = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_DATA_LEN); + if (remote_len != len) { + dev_warn(smc->ldev, + "read_smc_mmio(%x %8x/%.4s) failed: buffer length mismatch (remote = %u, requested = %u)\n", + cmd, key_int, key, remote_len, len); + return -EINVAL; + } + } else { + remote_len = len; + } + + memcpy_fromio(buffer, smc->iomem_base + APPLESMC_IOMEM_KEY_DATA, + remote_len); + + dev_dbg(smc->ldev, "read_smc_mmio(%x %8x/%.4s): buflen=%u reslen=%u\n", + cmd, key_int, key, len, remote_len); + print_hex_dump_bytes("read_smc_mmio(): ", DUMP_PREFIX_NONE, buffer, remote_len); + return 0; +} + +static int iomem_get_smc_key_type(struct applesmc_device *smc, const char *key, + struct applesmc_entry *e) +{ + u8 err; + u8 cmd = APPLESMC_GET_KEY_TYPE_CMD; + u32 key_int = *((u32 *) key); + + iomem_clear_status(smc); + iowrite32(key_int, smc->iomem_base + APPLESMC_IOMEM_KEY_NAME); + iowrite32(0, smc->iomem_base + APPLESMC_IOMEM_KEY_SMC_ID); + iowrite32(cmd, smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); + + if (iomem_wait_read(smc)) + return -EIO; + + err = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); + if (err != 0) { + dev_warn(smc->ldev, "get_smc_key_type_mmio(%.4s) failed: %u\n", key, err); + return -EIO; + } + + e->len = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_TYPE_DATA_LEN); + *((uint32_t *) e->type) = ioread32( + smc->iomem_base + APPLESMC_IOMEM_KEY_TYPE_CODE); + e->flags = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_TYPE_FLAGS); + + dev_dbg(smc->ldev, "get_smc_key_type_mmio(%.4s): len=%u type=%.4s flags=%x\n", + key, e->len, e->type, e->flags); + return 0; +} + +static int iomem_write_smc(struct applesmc_device *smc, u8 cmd, const char *key, + const u8 *buffer, u8 len) +{ + u8 err; + u32 key_int = *((u32 *) key); + + iomem_clear_status(smc); + iowrite32(key_int, smc->iomem_base + APPLESMC_IOMEM_KEY_NAME); + memcpy_toio(smc->iomem_base + APPLESMC_IOMEM_KEY_DATA, buffer, len); + iowrite32(len, smc->iomem_base + APPLESMC_IOMEM_KEY_DATA_LEN); + iowrite32(0, smc->iomem_base + APPLESMC_IOMEM_KEY_SMC_ID); + iowrite32(cmd, smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); + + if (iomem_wait_read(smc)) + return -EIO; + + err = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_CMD); + if (err != 0) { + dev_warn(smc->ldev, "write_smc_mmio(%x %.4s) failed: %u\n", cmd, key, err); + print_hex_dump_bytes("write_smc_mmio(): ", DUMP_PREFIX_NONE, buffer, len); + return -EIO; + } + + dev_dbg(smc->ldev, "write_smc_mmio(%x %.4s): buflen=%u\n", cmd, key, len); + print_hex_dump_bytes("write_smc_mmio(): ", DUMP_PREFIX_NONE, buffer, len); + return 0; +} + + +static int read_smc(struct applesmc_device *smc, const char *key, + u8 *buffer, u8 len) +{ + if (smc->iomem_base_set) + return iomem_read_smc(smc, APPLESMC_READ_CMD, key, buffer, len); + else + return port_read_smc(smc, APPLESMC_READ_CMD, key, buffer, len); +} + +static int write_smc(struct applesmc_device *smc, const char *key, + const u8 *buffer, u8 len) +{ + if (smc->iomem_base_set) + return iomem_write_smc(smc, APPLESMC_WRITE_CMD, key, buffer, len); + else + return port_write_smc(smc, APPLESMC_WRITE_CMD, key, buffer, len); +} + +static int get_smc_key_by_index(struct applesmc_device *smc, + unsigned int index, char *key) +{ + __be32 be; + + be = cpu_to_be32(index); + if (smc->iomem_base_set) + return iomem_read_smc(smc, APPLESMC_GET_KEY_BY_INDEX_CMD, + (const char *) &be, (u8 *) key, 4); + else + return port_read_smc(smc, APPLESMC_GET_KEY_BY_INDEX_CMD, + (const char *) &be, (u8 *) key, 4); +} + +static int get_smc_key_info(struct applesmc_device *smc, const char *key, + struct applesmc_entry *info) +{ + if (smc->iomem_base_set) + return iomem_get_smc_key_type(smc, key, info); + else + return port_get_smc_key_info(smc, key, info); +} + +static int read_register_count(struct applesmc_device *smc, + unsigned int *count) +{ + __be32 be; + int ret; + + ret = read_smc(smc, KEY_COUNT_KEY, (u8 *)&be, 4); + if (ret < 0) + return ret; *count = be32_to_cpu(be); return 0; @@ -338,76 +560,73 @@ static int read_register_count(unsigned int *count) * All functions below are concurrency safe - callers should NOT hold lock. */ -static int applesmc_read_entry(const struct applesmc_entry *entry, - u8 *buf, u8 len) +static int applesmc_read_entry(struct applesmc_device *smc, + const struct applesmc_entry *entry, u8 *buf, u8 len) { int ret; if (entry->len != len) return -EINVAL; - mutex_lock(&smcreg.mutex); - ret = read_smc(APPLESMC_READ_CMD, entry->key, buf, len); - mutex_unlock(&smcreg.mutex); + mutex_lock(&smc->reg.mutex); + ret = read_smc(smc, entry->key, buf, len); + mutex_unlock(&smc->reg.mutex); return ret; } -static int applesmc_write_entry(const struct applesmc_entry *entry, - const u8 *buf, u8 len) +static int applesmc_write_entry(struct applesmc_device *smc, + const struct applesmc_entry *entry, const u8 *buf, u8 len) { int ret; if (entry->len != len) return -EINVAL; - mutex_lock(&smcreg.mutex); - ret = write_smc(APPLESMC_WRITE_CMD, entry->key, buf, len); - mutex_unlock(&smcreg.mutex); + mutex_lock(&smc->reg.mutex); + ret = write_smc(smc, entry->key, buf, len); + mutex_unlock(&smc->reg.mutex); return ret; } -static const struct applesmc_entry *applesmc_get_entry_by_index(int index) +static const struct applesmc_entry *applesmc_get_entry_by_index( + struct applesmc_device *smc, int index) { - struct applesmc_entry *cache = &smcreg.cache[index]; - u8 key[4], info[6]; - __be32 be; + struct applesmc_entry *cache = &smc->reg.cache[index]; + char key[4]; int ret = 0; if (cache->valid) return cache; - mutex_lock(&smcreg.mutex); + mutex_lock(&smc->reg.mutex); if (cache->valid) goto out; - be = cpu_to_be32(index); - ret = read_smc(APPLESMC_GET_KEY_BY_INDEX_CMD, (u8 *)&be, key, 4); + ret = get_smc_key_by_index(smc, index, key); if (ret) goto out; - ret = read_smc(APPLESMC_GET_KEY_TYPE_CMD, key, info, 6); + memcpy(cache->key, key, 4); + + ret = get_smc_key_info(smc, key, cache); if (ret) goto out; - - memcpy(cache->key, key, 4); - cache->len = info[0]; - memcpy(cache->type, &info[1], 4); - cache->flags = info[5]; cache->valid = true; out: - mutex_unlock(&smcreg.mutex); + mutex_unlock(&smc->reg.mutex); if (ret) return ERR_PTR(ret); return cache; } -static int applesmc_get_lower_bound(unsigned int *lo, const char *key) +static int applesmc_get_lower_bound(struct applesmc_device *smc, + unsigned int *lo, const char *key) { - int begin = 0, end = smcreg.key_count; + int begin = 0, end = smc->reg.key_count; const struct applesmc_entry *entry; while (begin != end) { int middle = begin + (end - begin) / 2; - entry = applesmc_get_entry_by_index(middle); + entry = applesmc_get_entry_by_index(smc, middle); if (IS_ERR(entry)) { *lo = 0; return PTR_ERR(entry); @@ -422,16 +641,17 @@ static int applesmc_get_lower_bound(unsigned int *lo, const char *key) return 0; } -static int applesmc_get_upper_bound(unsigned int *hi, const char *key) +static int applesmc_get_upper_bound(struct applesmc_device *smc, + unsigned int *hi, const char *key) { - int begin = 0, end = smcreg.key_count; + int begin = 0, end = smc->reg.key_count; const struct applesmc_entry *entry; while (begin != end) { int middle = begin + (end - begin) / 2; - entry = applesmc_get_entry_by_index(middle); + entry = applesmc_get_entry_by_index(smc, middle); if (IS_ERR(entry)) { - *hi = smcreg.key_count; + *hi = smc->reg.key_count; return PTR_ERR(entry); } if (strcmp(key, entry->key) < 0) @@ -444,50 +664,54 @@ static int applesmc_get_upper_bound(unsigned int *hi, const char *key) return 0; } -static const struct applesmc_entry *applesmc_get_entry_by_key(const char *key) +static const struct applesmc_entry *applesmc_get_entry_by_key( + struct applesmc_device *smc, const char *key) { int begin, end; int ret; - ret = applesmc_get_lower_bound(&begin, key); + ret = applesmc_get_lower_bound(smc, &begin, key); if (ret) return ERR_PTR(ret); - ret = applesmc_get_upper_bound(&end, key); + ret = applesmc_get_upper_bound(smc, &end, key); if (ret) return ERR_PTR(ret); if (end - begin != 1) return ERR_PTR(-EINVAL); - return applesmc_get_entry_by_index(begin); + return applesmc_get_entry_by_index(smc, begin); } -static int applesmc_read_key(const char *key, u8 *buffer, u8 len) +static int applesmc_read_key(struct applesmc_device *smc, + const char *key, u8 *buffer, u8 len) { const struct applesmc_entry *entry; - entry = applesmc_get_entry_by_key(key); + entry = applesmc_get_entry_by_key(smc, key); if (IS_ERR(entry)) return PTR_ERR(entry); - return applesmc_read_entry(entry, buffer, len); + return applesmc_read_entry(smc, entry, buffer, len); } -static int applesmc_write_key(const char *key, const u8 *buffer, u8 len) +static int applesmc_write_key(struct applesmc_device *smc, + const char *key, const u8 *buffer, u8 len) { const struct applesmc_entry *entry; - entry = applesmc_get_entry_by_key(key); + entry = applesmc_get_entry_by_key(smc, key); if (IS_ERR(entry)) return PTR_ERR(entry); - return applesmc_write_entry(entry, buffer, len); + return applesmc_write_entry(smc, entry, buffer, len); } -static int applesmc_has_key(const char *key, bool *value) +static int applesmc_has_key(struct applesmc_device *smc, + const char *key, bool *value) { const struct applesmc_entry *entry; - entry = applesmc_get_entry_by_key(key); + entry = applesmc_get_entry_by_key(smc, key); if (IS_ERR(entry) && PTR_ERR(entry) != -EINVAL) return PTR_ERR(entry); @@ -498,12 +722,13 @@ static int applesmc_has_key(const char *key, bool *value) /* * applesmc_read_s16 - Read 16-bit signed big endian register */ -static int applesmc_read_s16(const char *key, s16 *value) +static int applesmc_read_s16(struct applesmc_device *smc, + const char *key, s16 *value) { u8 buffer[2]; int ret; - ret = applesmc_read_key(key, buffer, 2); + ret = applesmc_read_key(smc, key, buffer, 2); if (ret) return ret; @@ -511,31 +736,68 @@ static int applesmc_read_s16(const char *key, s16 *value) return 0; } +/** + * applesmc_float_to_u32 - Retrieve the integral part of a float. + * This is needed because Apple made fans use float values in the T2. + * The fractional point is not significantly useful though, and the integral + * part can be easily extracted. + */ +static inline u32 applesmc_float_to_u32(u32 d) +{ + u8 sign = (u8) ((d >> 31) & 1); + s32 exp = (s32) ((d >> 23) & 0xff) - 0x7f; + u32 fr = d & ((1u << 23) - 1); + + if (sign || exp < 0) + return 0; + + return (u32) ((1u << exp) + (fr >> (23 - exp))); +} + +/** + * applesmc_u32_to_float - Convert an u32 into a float. + * See applesmc_float_to_u32 for a rationale. + */ +static inline u32 applesmc_u32_to_float(u32 d) +{ + u32 dc = d, bc = 0, exp; + + if (!d) + return 0; + + while (dc >>= 1) + ++bc; + exp = 0x7f + bc; + + return (u32) ((exp << 23) | + ((d << (23 - (exp - 0x7f))) & ((1u << 23) - 1))); +} /* * applesmc_device_init - initialize the accelerometer. Can sleep. */ -static void applesmc_device_init(void) +static void applesmc_device_init(struct applesmc_device *smc) { int total; u8 buffer[2]; - if (!smcreg.has_accelerometer) + if (!smc->reg.has_accelerometer) return; for (total = INIT_TIMEOUT_MSECS; total > 0; total -= INIT_WAIT_MSECS) { - if (!applesmc_read_key(MOTION_SENSOR_KEY, buffer, 2) && + if (!applesmc_read_key(smc, MOTION_SENSOR_KEY, buffer, 2) && (buffer[0] != 0x00 || buffer[1] != 0x00)) return; buffer[0] = 0xe0; buffer[1] = 0x00; - applesmc_write_key(MOTION_SENSOR_KEY, buffer, 2); + applesmc_write_key(smc, MOTION_SENSOR_KEY, buffer, 2); msleep(INIT_WAIT_MSECS); } pr_warn("failed to init the device\n"); } -static int applesmc_init_index(struct applesmc_registers *s) +static int applesmc_init_index(struct applesmc_device *smc, + struct applesmc_registers *s) { const struct applesmc_entry *entry; unsigned int i; @@ -548,7 +810,7 @@ static int applesmc_init_index(struct applesmc_registers *s) return -ENOMEM; for (i = s->temp_begin; i < s->temp_end; i++) { - entry = applesmc_get_entry_by_index(i); + entry = applesmc_get_entry_by_index(smc, i); if (IS_ERR(entry)) continue; if (strcmp(entry->type, TEMP_SENSOR_TYPE)) @@ -562,9 +824,9 @@ static int applesmc_init_index(struct applesmc_registers *s) /* * applesmc_init_smcreg_try - Try to initialize register cache. Idempotent. */ -static int applesmc_init_smcreg_try(void) +static int applesmc_init_smcreg_try(struct applesmc_device *smc) { - struct applesmc_registers *s = &smcreg; + struct applesmc_registers *s = &smc->reg; bool left_light_sensor = false, right_light_sensor = false; unsigned int count; u8 tmp[1]; @@ -573,7 +835,7 @@ static int applesmc_init_smcreg_try(void) if (s->init_complete) return 0; - ret = read_register_count(&count); + ret = read_register_count(smc, &count); if (ret) return ret; @@ -590,35 +852,35 @@ static int applesmc_init_smcreg_try(void) if (!s->cache) return -ENOMEM; - ret = applesmc_read_key(FANS_COUNT, tmp, 1); + ret = applesmc_read_key(smc, FANS_COUNT, tmp, 1); if (ret) return ret; s->fan_count = tmp[0]; if (s->fan_count > 10) s->fan_count = 10; - ret = applesmc_get_lower_bound(&s->temp_begin, "T"); + ret = applesmc_get_lower_bound(smc, &s->temp_begin, "T"); if (ret) return ret; - ret = applesmc_get_lower_bound(&s->temp_end, "U"); + ret = applesmc_get_lower_bound(smc, &s->temp_end, "U"); if (ret) return ret; s->temp_count = s->temp_end - s->temp_begin; - ret = applesmc_init_index(s); + ret = applesmc_init_index(smc, s); if (ret) return ret; - ret = applesmc_has_key(LIGHT_SENSOR_LEFT_KEY, &left_light_sensor); + ret = applesmc_has_key(smc, LIGHT_SENSOR_LEFT_KEY, &left_light_sensor); if (ret) return ret; - ret = applesmc_has_key(LIGHT_SENSOR_RIGHT_KEY, &right_light_sensor); + ret = applesmc_has_key(smc, LIGHT_SENSOR_RIGHT_KEY, &right_light_sensor); if (ret) return ret; - ret = applesmc_has_key(MOTION_SENSOR_KEY, &s->has_accelerometer); + ret = applesmc_has_key(smc, MOTION_SENSOR_KEY, &s->has_accelerometer); if (ret) return ret; - ret = applesmc_has_key(BACKLIGHT_KEY, &s->has_key_backlight); + ret = applesmc_has_key(smc, BACKLIGHT_KEY, &s->has_key_backlight); if (ret) return ret; @@ -634,13 +896,13 @@ static int applesmc_init_smcreg_try(void) return 0; } -static void applesmc_destroy_smcreg(void) +static void applesmc_destroy_smcreg(struct applesmc_device *smc) { - kfree(smcreg.index); - smcreg.index = NULL; - kfree(smcreg.cache); - smcreg.cache = NULL; - smcreg.init_complete = false; + kfree(smc->reg.index); + smc->reg.index = NULL; + kfree(smc->reg.cache); + smc->reg.cache = NULL; + smc->reg.init_complete = false; } /* @@ -649,12 +911,12 @@ static void applesmc_destroy_smcreg(void) * Retries until initialization is successful, or the operation times out. * */ -static int applesmc_init_smcreg(void) +static int applesmc_init_smcreg(struct applesmc_device *smc) { int ms, ret; for (ms = 0; ms < INIT_TIMEOUT_MSECS; ms += INIT_WAIT_MSECS) { - ret = applesmc_init_smcreg_try(); + ret = applesmc_init_smcreg_try(smc); if (!ret) { if (ms) pr_info("init_smcreg() took %d ms\n", ms); @@ -663,50 +925,223 @@ static int applesmc_init_smcreg(void) msleep(INIT_WAIT_MSECS); } - applesmc_destroy_smcreg(); + applesmc_destroy_smcreg(smc); return ret; } /* Device model stuff */ -static int applesmc_probe(struct platform_device *dev) + +static int applesmc_init_resources(struct applesmc_device *smc); +static void applesmc_free_resources(struct applesmc_device *smc); +static int applesmc_create_modules(struct applesmc_device *smc); +static void applesmc_destroy_modules(struct applesmc_device *smc); + +static int applesmc_add(struct acpi_device *dev) { + struct applesmc_device *smc; int ret; - ret = applesmc_init_smcreg(); + smc = kzalloc(sizeof(struct applesmc_device), GFP_KERNEL); + if (!smc) + return -ENOMEM; + smc->dev = dev; + smc->ldev = &dev->dev; + mutex_init(&smc->reg.mutex); + + dev_set_drvdata(&dev->dev, smc); + + ret = applesmc_init_resources(smc); if (ret) - return ret; + goto out_mem; + + ret = applesmc_init_smcreg(smc); + if (ret) + goto out_res; + + applesmc_device_init(smc); + + ret = applesmc_create_modules(smc); + if (ret) + goto out_reg; + + return 0; + +out_reg: + applesmc_destroy_smcreg(smc); +out_res: + applesmc_free_resources(smc); +out_mem: + dev_set_drvdata(&dev->dev, NULL); + mutex_destroy(&smc->reg.mutex); + kfree(smc); + + return ret; +} + +static void applesmc_remove(struct acpi_device *dev) +{ + struct applesmc_device *smc = dev_get_drvdata(&dev->dev); + + applesmc_destroy_modules(smc); + applesmc_destroy_smcreg(smc); + applesmc_free_resources(smc); - applesmc_device_init(); + mutex_destroy(&smc->reg.mutex); + kfree(smc); + + return; +} + +static acpi_status applesmc_walk_resources(struct acpi_resource *res, + void *data) +{ + struct applesmc_device *smc = data; + + switch (res->type) { + case ACPI_RESOURCE_TYPE_IO: + if (!smc->port_base_set) { + if (res->data.io.address_length < APPLESMC_NR_PORTS) + return AE_ERROR; + smc->port_base = res->data.io.minimum; + smc->port_base_set = true; + } + return AE_OK; + + case ACPI_RESOURCE_TYPE_FIXED_MEMORY32: + if (!smc->iomem_base_set) { + if (res->data.fixed_memory32.address_length < + APPLESMC_IOMEM_MIN_SIZE) { + dev_warn(smc->ldev, "found iomem but it's too small: %u\n", + res->data.fixed_memory32.address_length); + return AE_OK; + } + smc->iomem_base_addr = res->data.fixed_memory32.address; + smc->iomem_base_size = res->data.fixed_memory32.address_length; + smc->iomem_base_set = true; + } + return AE_OK; + + case ACPI_RESOURCE_TYPE_END_TAG: + if (smc->port_base_set) + return AE_OK; + else + return AE_NOT_FOUND; + + default: + return AE_OK; + } +} + +static int applesmc_try_enable_iomem(struct applesmc_device *smc); + +static int applesmc_init_resources(struct applesmc_device *smc) +{ + int ret; + + ret = acpi_walk_resources(smc->dev->handle, METHOD_NAME__CRS, + applesmc_walk_resources, smc); + if (ACPI_FAILURE(ret)) + return -ENXIO; + + if (!request_region(smc->port_base, APPLESMC_NR_PORTS, "applesmc")) + return -ENXIO; + + if (smc->iomem_base_set) { + if (applesmc_try_enable_iomem(smc)) + smc->iomem_base_set = false; + } + + return 0; +} + +static int applesmc_try_enable_iomem(struct applesmc_device *smc) +{ + u8 test_val, ldkn_version; + + dev_dbg(smc->ldev, "Trying to enable iomem based communication\n"); + smc->iomem_base = ioremap(smc->iomem_base_addr, smc->iomem_base_size); + if (!smc->iomem_base) + goto out; + + /* Apple's driver does this check for some reason */ + test_val = ioread8(smc->iomem_base + APPLESMC_IOMEM_KEY_STATUS); + if (test_val == 0xff) { + dev_warn(smc->ldev, + "iomem enable failed: initial status is 0xff (is %x)\n", + test_val); + goto out_iomem; + } + + if (read_smc(smc, "LDKN", &ldkn_version, 1)) { + dev_warn(smc->ldev, "iomem enable failed: ldkn read failed\n"); + goto out_iomem; + } + + if (ldkn_version < 2) { + dev_warn(smc->ldev, + "iomem enable failed: ldkn version %u is less than minimum (2)\n", + ldkn_version); + goto out_iomem; + } return 0; + +out_iomem: + iounmap(smc->iomem_base); + +out: + return -ENXIO; +} + +static void applesmc_free_resources(struct applesmc_device *smc) +{ + if (smc->iomem_base_set) + iounmap(smc->iomem_base); + release_region(smc->port_base, APPLESMC_NR_PORTS); } /* Synchronize device with memorized backlight state */ static int applesmc_pm_resume(struct device *dev) { - if (smcreg.has_key_backlight) - applesmc_write_key(BACKLIGHT_KEY, backlight_state, 2); + struct applesmc_device *smc = dev_get_drvdata(dev); + + if (smc->reg.has_key_backlight) + applesmc_write_key(smc, BACKLIGHT_KEY, smc->backlight_state, 2); + return 0; } /* Reinitialize device on resume from hibernation */ static int applesmc_pm_restore(struct device *dev) { - applesmc_device_init(); + struct applesmc_device *smc = dev_get_drvdata(dev); + + applesmc_device_init(smc); + return applesmc_pm_resume(dev); } +static const struct acpi_device_id applesmc_ids[] = { + {"APP0001", 0}, + {"", 0}, +}; + static const struct dev_pm_ops applesmc_pm_ops = { .resume = applesmc_pm_resume, .restore = applesmc_pm_restore, }; -static struct platform_driver applesmc_driver = { - .probe = applesmc_probe, - .driver = { - .name = "applesmc", - .pm = &applesmc_pm_ops, +static struct acpi_driver applesmc_driver = { + .name = "applesmc", + .class = "applesmc", + .ids = applesmc_ids, + .ops = { + .add = applesmc_add, + .remove = applesmc_remove + }, + .drv = { + .pm = &applesmc_pm_ops }, }; @@ -714,25 +1149,26 @@ static struct platform_driver applesmc_driver = { * applesmc_calibrate - Set our "resting" values. Callers must * hold applesmc_lock. */ -static void applesmc_calibrate(void) +static void applesmc_calibrate(struct applesmc_device *smc) { - applesmc_read_s16(MOTION_SENSOR_X_KEY, &rest_x); - applesmc_read_s16(MOTION_SENSOR_Y_KEY, &rest_y); - rest_x = -rest_x; + applesmc_read_s16(smc, MOTION_SENSOR_X_KEY, &smc->rest_x); + applesmc_read_s16(smc, MOTION_SENSOR_Y_KEY, &smc->rest_y); + smc->rest_x = -smc->rest_x; } static void applesmc_idev_poll(struct input_dev *idev) { + struct applesmc_device *smc = dev_get_drvdata(&idev->dev); s16 x, y; - if (applesmc_read_s16(MOTION_SENSOR_X_KEY, &x)) + if (applesmc_read_s16(smc, MOTION_SENSOR_X_KEY, &x)) return; - if (applesmc_read_s16(MOTION_SENSOR_Y_KEY, &y)) + if (applesmc_read_s16(smc, MOTION_SENSOR_Y_KEY, &y)) return; x = -x; - input_report_abs(idev, ABS_X, x - rest_x); - input_report_abs(idev, ABS_Y, y - rest_y); + input_report_abs(idev, ABS_X, x - smc->rest_x); + input_report_abs(idev, ABS_Y, y - smc->rest_y); input_sync(idev); } @@ -747,16 +1183,17 @@ static ssize_t applesmc_name_show(struct device *dev, static ssize_t applesmc_position_show(struct device *dev, struct device_attribute *attr, char *buf) { + struct applesmc_device *smc = dev_get_drvdata(dev); int ret; s16 x, y, z; - ret = applesmc_read_s16(MOTION_SENSOR_X_KEY, &x); + ret = applesmc_read_s16(smc, MOTION_SENSOR_X_KEY, &x); if (ret) goto out; - ret = applesmc_read_s16(MOTION_SENSOR_Y_KEY, &y); + ret = applesmc_read_s16(smc, MOTION_SENSOR_Y_KEY, &y); if (ret) goto out; - ret = applesmc_read_s16(MOTION_SENSOR_Z_KEY, &z); + ret = applesmc_read_s16(smc, MOTION_SENSOR_Z_KEY, &z); if (ret) goto out; @@ -770,6 +1207,7 @@ static ssize_t applesmc_position_show(struct device *dev, static ssize_t applesmc_light_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); const struct applesmc_entry *entry; static int data_length; int ret; @@ -777,7 +1215,7 @@ static ssize_t applesmc_light_show(struct device *dev, u8 buffer[10]; if (!data_length) { - entry = applesmc_get_entry_by_key(LIGHT_SENSOR_LEFT_KEY); + entry = applesmc_get_entry_by_key(smc, LIGHT_SENSOR_LEFT_KEY); if (IS_ERR(entry)) return PTR_ERR(entry); if (entry->len > 10) @@ -786,7 +1224,7 @@ static ssize_t applesmc_light_show(struct device *dev, pr_info("light sensor data length set to %d\n", data_length); } - ret = applesmc_read_key(LIGHT_SENSOR_LEFT_KEY, buffer, data_length); + ret = applesmc_read_key(smc, LIGHT_SENSOR_LEFT_KEY, buffer, data_length); if (ret) goto out; /* newer macbooks report a single 10-bit bigendian value */ @@ -796,7 +1234,7 @@ static ssize_t applesmc_light_show(struct device *dev, } left = buffer[2]; - ret = applesmc_read_key(LIGHT_SENSOR_RIGHT_KEY, buffer, data_length); + ret = applesmc_read_key(smc, LIGHT_SENSOR_RIGHT_KEY, buffer, data_length); if (ret) goto out; right = buffer[2]; @@ -812,7 +1250,8 @@ static ssize_t applesmc_light_show(struct device *dev, static ssize_t applesmc_show_sensor_label(struct device *dev, struct device_attribute *devattr, char *sysfsbuf) { - const char *key = smcreg.index[to_index(devattr)]; + struct applesmc_device *smc = dev_get_drvdata(dev); + const char *key = smc->reg.index[to_index(devattr)]; return sysfs_emit(sysfsbuf, "%s\n", key); } @@ -821,12 +1260,13 @@ static ssize_t applesmc_show_sensor_label(struct device *dev, static ssize_t applesmc_show_temperature(struct device *dev, struct device_attribute *devattr, char *sysfsbuf) { - const char *key = smcreg.index[to_index(devattr)]; + struct applesmc_device *smc = dev_get_drvdata(dev); + const char *key = smc->reg.index[to_index(devattr)]; int ret; s16 value; int temp; - ret = applesmc_read_s16(key, &value); + ret = applesmc_read_s16(smc, key, &value); if (ret) return ret; @@ -838,6 +1278,8 @@ static ssize_t applesmc_show_temperature(struct device *dev, static ssize_t applesmc_show_fan_speed(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); + const struct applesmc_entry *entry; int ret; unsigned int speed = 0; char newkey[5]; @@ -846,11 +1288,21 @@ static ssize_t applesmc_show_fan_speed(struct device *dev, scnprintf(newkey, sizeof(newkey), fan_speed_fmt[to_option(attr)], to_index(attr)); - ret = applesmc_read_key(newkey, buffer, 2); + entry = applesmc_get_entry_by_key(smc, newkey); + if (IS_ERR(entry)) + return PTR_ERR(entry); + + if (!strcmp(entry->type, FLOAT_TYPE)) { + ret = applesmc_read_entry(smc, entry, (u8 *) &speed, 4); + speed = applesmc_float_to_u32(speed); + } else { + ret = applesmc_read_entry(smc, entry, buffer, 2); + speed = ((buffer[0] << 8 | buffer[1]) >> 2); + } + if (ret) return ret; - speed = ((buffer[0] << 8 | buffer[1]) >> 2); return sysfs_emit(sysfsbuf, "%u\n", speed); } @@ -858,6 +1310,8 @@ static ssize_t applesmc_store_fan_speed(struct device *dev, struct device_attribute *attr, const char *sysfsbuf, size_t count) { + struct applesmc_device *smc = dev_get_drvdata(dev); + const struct applesmc_entry *entry; int ret; unsigned long speed; char newkey[5]; @@ -869,9 +1323,18 @@ static ssize_t applesmc_store_fan_speed(struct device *dev, scnprintf(newkey, sizeof(newkey), fan_speed_fmt[to_option(attr)], to_index(attr)); - buffer[0] = (speed >> 6) & 0xff; - buffer[1] = (speed << 2) & 0xff; - ret = applesmc_write_key(newkey, buffer, 2); + entry = applesmc_get_entry_by_key(smc, newkey); + if (IS_ERR(entry)) + return PTR_ERR(entry); + + if (!strcmp(entry->type, FLOAT_TYPE)) { + speed = applesmc_u32_to_float(speed); + ret = applesmc_write_entry(smc, entry, (u8 *) &speed, 4); + } else { + buffer[0] = (speed >> 6) & 0xff; + buffer[1] = (speed << 2) & 0xff; + ret = applesmc_write_key(smc, newkey, buffer, 2); + } if (ret) return ret; @@ -882,15 +1345,30 @@ static ssize_t applesmc_store_fan_speed(struct device *dev, static ssize_t applesmc_show_fan_manual(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); int ret; u16 manual = 0; u8 buffer[2]; + char newkey[5]; + bool has_newkey = false; + + scnprintf(newkey, sizeof(newkey), FAN_MANUAL_FMT, to_index(attr)); + + ret = applesmc_has_key(smc, newkey, &has_newkey); + if (ret) + return ret; + + if (has_newkey) { + ret = applesmc_read_key(smc, newkey, buffer, 1); + manual = buffer[0]; + } else { + ret = applesmc_read_key(smc, FANS_MANUAL, buffer, 2); + manual = ((buffer[0] << 8 | buffer[1]) >> to_index(attr)) & 0x01; + } - ret = applesmc_read_key(FANS_MANUAL, buffer, 2); if (ret) return ret; - manual = ((buffer[0] << 8 | buffer[1]) >> to_index(attr)) & 0x01; return sysfs_emit(sysfsbuf, "%d\n", manual); } @@ -898,29 +1376,42 @@ static ssize_t applesmc_store_fan_manual(struct device *dev, struct device_attribute *attr, const char *sysfsbuf, size_t count) { + struct applesmc_device *smc = dev_get_drvdata(dev); int ret; u8 buffer[2]; + char newkey[5]; + bool has_newkey = false; unsigned long input; u16 val; if (kstrtoul(sysfsbuf, 10, &input) < 0) return -EINVAL; - ret = applesmc_read_key(FANS_MANUAL, buffer, 2); + scnprintf(newkey, sizeof(newkey), FAN_MANUAL_FMT, to_index(attr)); + + ret = applesmc_has_key(smc, newkey, &has_newkey); if (ret) - goto out; + return ret; - val = (buffer[0] << 8 | buffer[1]); + if (has_newkey) { + buffer[0] = input & 1; + ret = applesmc_write_key(smc, newkey, buffer, 1); + } else { + ret = applesmc_read_key(smc, FANS_MANUAL, buffer, 2); + val = (buffer[0] << 8 | buffer[1]); + if (ret) + goto out; - if (input) - val = val | (0x01 << to_index(attr)); - else - val = val & ~(0x01 << to_index(attr)); + if (input) + val = val | (0x01 << to_index(attr)); + else + val = val & ~(0x01 << to_index(attr)); - buffer[0] = (val >> 8) & 0xFF; - buffer[1] = val & 0xFF; + buffer[0] = (val >> 8) & 0xFF; + buffer[1] = val & 0xFF; - ret = applesmc_write_key(FANS_MANUAL, buffer, 2); + ret = applesmc_write_key(smc, FANS_MANUAL, buffer, 2); + } out: if (ret) @@ -932,13 +1423,14 @@ static ssize_t applesmc_store_fan_manual(struct device *dev, static ssize_t applesmc_show_fan_position(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); int ret; char newkey[5]; u8 buffer[17]; scnprintf(newkey, sizeof(newkey), FAN_ID_FMT, to_index(attr)); - ret = applesmc_read_key(newkey, buffer, 16); + ret = applesmc_read_key(smc, newkey, buffer, 16); buffer[16] = 0; if (ret) @@ -950,43 +1442,79 @@ static ssize_t applesmc_show_fan_position(struct device *dev, static ssize_t applesmc_calibrate_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { - return sysfs_emit(sysfsbuf, "(%d,%d)\n", rest_x, rest_y); + struct applesmc_device *smc = dev_get_drvdata(dev); + + return sysfs_emit(sysfsbuf, "(%d,%d)\n", smc->rest_x, smc->rest_y); } static ssize_t applesmc_calibrate_store(struct device *dev, struct device_attribute *attr, const char *sysfsbuf, size_t count) { - applesmc_calibrate(); + struct applesmc_device *smc = dev_get_drvdata(dev); + + applesmc_calibrate(smc); return count; } static void applesmc_backlight_set(struct work_struct *work) { - applesmc_write_key(BACKLIGHT_KEY, backlight_state, 2); + struct applesmc_device *smc = container_of(work, struct applesmc_device, backlight_work); + + applesmc_write_key(smc, BACKLIGHT_KEY, smc->backlight_state, 2); } -static DECLARE_WORK(backlight_work, &applesmc_backlight_set); static void applesmc_brightness_set(struct led_classdev *led_cdev, enum led_brightness value) { + struct applesmc_device *smc = dev_get_drvdata(led_cdev->dev); int ret; - backlight_state[0] = value; - ret = queue_work(applesmc_led_wq, &backlight_work); + smc->backlight_state[0] = value; + ret = queue_work(smc->backlight_wq, &smc->backlight_work); if (debug && (!ret)) dev_dbg(led_cdev->dev, "work was already on the queue.\n"); } +static ssize_t applesmc_BCLM_store(struct device *dev, + struct device_attribute *attr, char *sysfsbuf, size_t count) +{ + struct applesmc_device *smc = dev_get_drvdata(dev); + u8 val; + + if (kstrtou8(sysfsbuf, 10, &val) < 0) + return -EINVAL; + + if (val < 0 || val > 100) + return -EINVAL; + + if (applesmc_write_key(smc, "BCLM", &val, 1)) + return -ENODEV; + return count; +} + +static ssize_t applesmc_BCLM_show(struct device *dev, + struct device_attribute *attr, char *sysfsbuf) +{ + struct applesmc_device *smc = dev_get_drvdata(dev); + u8 val; + + if (applesmc_read_key(smc, "BCLM", &val, 1)) + return -ENODEV; + + return sysfs_emit(sysfsbuf, "%d\n", val); +} + static ssize_t applesmc_key_count_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); int ret; u8 buffer[4]; u32 count; - ret = applesmc_read_key(KEY_COUNT_KEY, buffer, 4); + ret = applesmc_read_key(smc, KEY_COUNT_KEY, buffer, 4); if (ret) return ret; @@ -998,13 +1526,14 @@ static ssize_t applesmc_key_count_show(struct device *dev, static ssize_t applesmc_key_at_index_read_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); const struct applesmc_entry *entry; int ret; - entry = applesmc_get_entry_by_index(key_at_index); + entry = applesmc_get_entry_by_index(smc, smc->key_at_index); if (IS_ERR(entry)) return PTR_ERR(entry); - ret = applesmc_read_entry(entry, sysfsbuf, entry->len); + ret = applesmc_read_entry(smc, entry, sysfsbuf, entry->len); if (ret) return ret; @@ -1014,9 +1543,10 @@ static ssize_t applesmc_key_at_index_read_show(struct device *dev, static ssize_t applesmc_key_at_index_data_length_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); const struct applesmc_entry *entry; - entry = applesmc_get_entry_by_index(key_at_index); + entry = applesmc_get_entry_by_index(smc, smc->key_at_index); if (IS_ERR(entry)) return PTR_ERR(entry); @@ -1026,9 +1556,10 @@ static ssize_t applesmc_key_at_index_data_length_show(struct device *dev, static ssize_t applesmc_key_at_index_type_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); const struct applesmc_entry *entry; - entry = applesmc_get_entry_by_index(key_at_index); + entry = applesmc_get_entry_by_index(smc, smc->key_at_index); if (IS_ERR(entry)) return PTR_ERR(entry); @@ -1038,9 +1569,10 @@ static ssize_t applesmc_key_at_index_type_show(struct device *dev, static ssize_t applesmc_key_at_index_name_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { + struct applesmc_device *smc = dev_get_drvdata(dev); const struct applesmc_entry *entry; - entry = applesmc_get_entry_by_index(key_at_index); + entry = applesmc_get_entry_by_index(smc, smc->key_at_index); if (IS_ERR(entry)) return PTR_ERR(entry); @@ -1050,28 +1582,25 @@ static ssize_t applesmc_key_at_index_name_show(struct device *dev, static ssize_t applesmc_key_at_index_show(struct device *dev, struct device_attribute *attr, char *sysfsbuf) { - return sysfs_emit(sysfsbuf, "%d\n", key_at_index); + struct applesmc_device *smc = dev_get_drvdata(dev); + + return sysfs_emit(sysfsbuf, "%d\n", smc->key_at_index); } static ssize_t applesmc_key_at_index_store(struct device *dev, struct device_attribute *attr, const char *sysfsbuf, size_t count) { + struct applesmc_device *smc = dev_get_drvdata(dev); unsigned long newkey; if (kstrtoul(sysfsbuf, 10, &newkey) < 0 - || newkey >= smcreg.key_count) + || newkey >= smc->reg.key_count) return -EINVAL; - key_at_index = newkey; + smc->key_at_index = newkey; return count; } -static struct led_classdev applesmc_backlight = { - .name = "smc::kbd_backlight", - .default_trigger = "nand-disk", - .brightness_set = applesmc_brightness_set, -}; - static struct applesmc_node_group info_group[] = { { "name", applesmc_name_show }, { "key_count", applesmc_key_count_show }, @@ -1111,19 +1640,25 @@ static struct applesmc_node_group temp_group[] = { { } }; +static struct applesmc_node_group BCLM_group[] = { + { "battery_charge_limit", applesmc_BCLM_show, applesmc_BCLM_store }, + { } +}; + /* Module stuff */ /* * applesmc_destroy_nodes - remove files and free associated memory */ -static void applesmc_destroy_nodes(struct applesmc_node_group *groups) +static void applesmc_destroy_nodes(struct applesmc_device *smc, + struct applesmc_node_group *groups) { struct applesmc_node_group *grp; struct applesmc_dev_attr *node; for (grp = groups; grp->nodes; grp++) { for (node = grp->nodes; node->sda.dev_attr.attr.name; node++) - sysfs_remove_file(&pdev->dev.kobj, + sysfs_remove_file(&smc->dev->dev.kobj, &node->sda.dev_attr.attr); kfree(grp->nodes); grp->nodes = NULL; @@ -1133,7 +1668,8 @@ static void applesmc_destroy_nodes(struct applesmc_node_group *groups) /* * applesmc_create_nodes - create a two-dimensional group of sysfs files */ -static int applesmc_create_nodes(struct applesmc_node_group *groups, int num) +static int applesmc_create_nodes(struct applesmc_device *smc, + struct applesmc_node_group *groups, int num) { struct applesmc_node_group *grp; struct applesmc_dev_attr *node; @@ -1157,7 +1693,7 @@ static int applesmc_create_nodes(struct applesmc_node_group *groups, int num) sysfs_attr_init(attr); attr->name = node->name; attr->mode = 0444 | (grp->store ? 0200 : 0); - ret = sysfs_create_file(&pdev->dev.kobj, attr); + ret = sysfs_create_file(&smc->dev->dev.kobj, attr); if (ret) { attr->name = NULL; goto out; @@ -1167,57 +1703,56 @@ static int applesmc_create_nodes(struct applesmc_node_group *groups, int num) return 0; out: - applesmc_destroy_nodes(groups); + applesmc_destroy_nodes(smc, groups); return ret; } /* Create accelerometer resources */ -static int applesmc_create_accelerometer(void) +static int applesmc_create_accelerometer(struct applesmc_device *smc) { int ret; - - if (!smcreg.has_accelerometer) + if (!smc->reg.has_accelerometer) return 0; - ret = applesmc_create_nodes(accelerometer_group, 1); + ret = applesmc_create_nodes(smc, accelerometer_group, 1); if (ret) goto out; - applesmc_idev = input_allocate_device(); - if (!applesmc_idev) { + smc->idev = input_allocate_device(); + if (!smc->idev) { ret = -ENOMEM; goto out_sysfs; } /* initial calibrate for the input device */ - applesmc_calibrate(); + applesmc_calibrate(smc); /* initialize the input device */ - applesmc_idev->name = "applesmc"; - applesmc_idev->id.bustype = BUS_HOST; - applesmc_idev->dev.parent = &pdev->dev; - input_set_abs_params(applesmc_idev, ABS_X, + smc->idev->name = "applesmc"; + smc->idev->id.bustype = BUS_HOST; + smc->idev->dev.parent = &smc->dev->dev; + input_set_abs_params(smc->idev, ABS_X, -256, 256, APPLESMC_INPUT_FUZZ, APPLESMC_INPUT_FLAT); - input_set_abs_params(applesmc_idev, ABS_Y, + input_set_abs_params(smc->idev, ABS_Y, -256, 256, APPLESMC_INPUT_FUZZ, APPLESMC_INPUT_FLAT); - ret = input_setup_polling(applesmc_idev, applesmc_idev_poll); + ret = input_setup_polling(smc->idev, applesmc_idev_poll); if (ret) goto out_idev; - input_set_poll_interval(applesmc_idev, APPLESMC_POLL_INTERVAL); + input_set_poll_interval(smc->idev, APPLESMC_POLL_INTERVAL); - ret = input_register_device(applesmc_idev); + ret = input_register_device(smc->idev); if (ret) goto out_idev; return 0; out_idev: - input_free_device(applesmc_idev); + input_free_device(smc->idev); out_sysfs: - applesmc_destroy_nodes(accelerometer_group); + applesmc_destroy_nodes(smc, accelerometer_group); out: pr_warn("driver init failed (ret=%d)!\n", ret); @@ -1225,44 +1760,55 @@ static int applesmc_create_accelerometer(void) } /* Release all resources used by the accelerometer */ -static void applesmc_release_accelerometer(void) +static void applesmc_release_accelerometer(struct applesmc_device *smc) { - if (!smcreg.has_accelerometer) + if (!smc->reg.has_accelerometer) return; - input_unregister_device(applesmc_idev); - applesmc_destroy_nodes(accelerometer_group); + input_unregister_device(smc->idev); + applesmc_destroy_nodes(smc, accelerometer_group); } -static int applesmc_create_light_sensor(void) +static int applesmc_create_light_sensor(struct applesmc_device *smc) { - if (!smcreg.num_light_sensors) + if (!smc->reg.num_light_sensors) return 0; - return applesmc_create_nodes(light_sensor_group, 1); + return applesmc_create_nodes(smc, light_sensor_group, 1); } -static void applesmc_release_light_sensor(void) +static void applesmc_release_light_sensor(struct applesmc_device *smc) { - if (!smcreg.num_light_sensors) + if (!smc->reg.num_light_sensors) return; - applesmc_destroy_nodes(light_sensor_group); + applesmc_destroy_nodes(smc, light_sensor_group); } -static int applesmc_create_key_backlight(void) +static int applesmc_create_key_backlight(struct applesmc_device *smc) { - if (!smcreg.has_key_backlight) + int ret; + + if (!smc->reg.has_key_backlight) return 0; - applesmc_led_wq = create_singlethread_workqueue("applesmc-led"); - if (!applesmc_led_wq) + smc->backlight_wq = create_singlethread_workqueue("applesmc-led"); + if (!smc->backlight_wq) return -ENOMEM; - return led_classdev_register(&pdev->dev, &applesmc_backlight); + + INIT_WORK(&smc->backlight_work, applesmc_backlight_set); + smc->backlight_dev.name = "smc::kbd_backlight"; + smc->backlight_dev.default_trigger = "nand-disk"; + smc->backlight_dev.brightness_set = applesmc_brightness_set; + ret = led_classdev_register(&smc->dev->dev, &smc->backlight_dev); + if (ret) + destroy_workqueue(smc->backlight_wq); + + return ret; } -static void applesmc_release_key_backlight(void) +static void applesmc_release_key_backlight(struct applesmc_device *smc) { - if (!smcreg.has_key_backlight) + if (!smc->reg.has_key_backlight) return; - led_classdev_unregister(&applesmc_backlight); - destroy_workqueue(applesmc_led_wq); + led_classdev_unregister(&smc->backlight_dev); + destroy_workqueue(smc->backlight_wq); } static int applesmc_dmi_match(const struct dmi_system_id *id) @@ -1291,6 +1837,10 @@ static const struct dmi_system_id applesmc_whitelist[] __initconst = { DMI_MATCH(DMI_BOARD_VENDOR, "Apple"), DMI_MATCH(DMI_PRODUCT_NAME, "Macmini") }, }, + { applesmc_dmi_match, "Apple iMacPro", { + DMI_MATCH(DMI_BOARD_VENDOR, "Apple"), + DMI_MATCH(DMI_PRODUCT_NAME, "iMacPro") }, + }, { applesmc_dmi_match, "Apple MacPro", { DMI_MATCH(DMI_BOARD_VENDOR, "Apple"), DMI_MATCH(DMI_PRODUCT_NAME, "MacPro") }, @@ -1306,90 +1856,91 @@ static const struct dmi_system_id applesmc_whitelist[] __initconst = { { .ident = NULL } }; -static int __init applesmc_init(void) +static int applesmc_create_modules(struct applesmc_device *smc) { int ret; - if (!dmi_check_system(applesmc_whitelist)) { - pr_warn("supported laptop not found!\n"); - ret = -ENODEV; - goto out; - } - - if (!request_region(APPLESMC_DATA_PORT, APPLESMC_NR_PORTS, - "applesmc")) { - ret = -ENXIO; - goto out; - } - - ret = platform_driver_register(&applesmc_driver); - if (ret) - goto out_region; - - pdev = platform_device_register_simple("applesmc", APPLESMC_DATA_PORT, - NULL, 0); - if (IS_ERR(pdev)) { - ret = PTR_ERR(pdev); - goto out_driver; - } - - /* create register cache */ - ret = applesmc_init_smcreg(); + ret = applesmc_create_nodes(smc, info_group, 1); if (ret) - goto out_device; - - ret = applesmc_create_nodes(info_group, 1); + goto out; + ret = applesmc_create_nodes(smc, BCLM_group, 1); if (ret) - goto out_smcreg; + goto out_info; - ret = applesmc_create_nodes(fan_group, smcreg.fan_count); + ret = applesmc_create_nodes(smc, fan_group, smc->reg.fan_count); if (ret) - goto out_info; + goto out_bclm; - ret = applesmc_create_nodes(temp_group, smcreg.index_count); + ret = applesmc_create_nodes(smc, temp_group, smc->reg.index_count); if (ret) goto out_fans; - ret = applesmc_create_accelerometer(); + ret = applesmc_create_accelerometer(smc); if (ret) goto out_temperature; - ret = applesmc_create_light_sensor(); + ret = applesmc_create_light_sensor(smc); if (ret) goto out_accelerometer; - ret = applesmc_create_key_backlight(); + ret = applesmc_create_key_backlight(smc); if (ret) goto out_light_sysfs; - hwmon_dev = hwmon_device_register(&pdev->dev); - if (IS_ERR(hwmon_dev)) { - ret = PTR_ERR(hwmon_dev); + smc->hwmon_dev = hwmon_device_register(&smc->dev->dev); + if (IS_ERR(smc->hwmon_dev)) { + ret = PTR_ERR(smc->hwmon_dev); goto out_light_ledclass; } return 0; out_light_ledclass: - applesmc_release_key_backlight(); + applesmc_release_key_backlight(smc); out_light_sysfs: - applesmc_release_light_sensor(); + applesmc_release_light_sensor(smc); out_accelerometer: - applesmc_release_accelerometer(); + applesmc_release_accelerometer(smc); out_temperature: - applesmc_destroy_nodes(temp_group); + applesmc_destroy_nodes(smc, temp_group); out_fans: - applesmc_destroy_nodes(fan_group); + applesmc_destroy_nodes(smc, fan_group); +out_bclm: + applesmc_destroy_nodes(smc, BCLM_group); out_info: - applesmc_destroy_nodes(info_group); -out_smcreg: - applesmc_destroy_smcreg(); -out_device: - platform_device_unregister(pdev); -out_driver: - platform_driver_unregister(&applesmc_driver); -out_region: - release_region(APPLESMC_DATA_PORT, APPLESMC_NR_PORTS); + applesmc_destroy_nodes(smc, info_group); +out: + return ret; +} + +static void applesmc_destroy_modules(struct applesmc_device *smc) +{ + hwmon_device_unregister(smc->hwmon_dev); + applesmc_release_key_backlight(smc); + applesmc_release_light_sensor(smc); + applesmc_release_accelerometer(smc); + applesmc_destroy_nodes(smc, temp_group); + applesmc_destroy_nodes(smc, fan_group); + applesmc_destroy_nodes(smc, BCLM_group); + applesmc_destroy_nodes(smc, info_group); +} + +static int __init applesmc_init(void) +{ + int ret; + + if (!dmi_check_system(applesmc_whitelist)) { + pr_warn("supported laptop not found!\n"); + ret = -ENODEV; + goto out; + } + + ret = acpi_bus_register_driver(&applesmc_driver); + if (ret) + goto out; + + return 0; + out: pr_warn("driver init failed (ret=%d)!\n", ret); return ret; @@ -1397,23 +1948,14 @@ static int __init applesmc_init(void) static void __exit applesmc_exit(void) { - hwmon_device_unregister(hwmon_dev); - applesmc_release_key_backlight(); - applesmc_release_light_sensor(); - applesmc_release_accelerometer(); - applesmc_destroy_nodes(temp_group); - applesmc_destroy_nodes(fan_group); - applesmc_destroy_nodes(info_group); - applesmc_destroy_smcreg(); - platform_device_unregister(pdev); - platform_driver_unregister(&applesmc_driver); - release_region(APPLESMC_DATA_PORT, APPLESMC_NR_PORTS); + acpi_bus_unregister_driver(&applesmc_driver); } module_init(applesmc_init); module_exit(applesmc_exit); MODULE_AUTHOR("Nicolas Boichat"); +MODULE_AUTHOR("Paul Pawlowski"); MODULE_DESCRIPTION("Apple SMC"); MODULE_LICENSE("GPL v2"); MODULE_DEVICE_TABLE(dmi, applesmc_whitelist); diff --git a/drivers/input/mouse/bcm5974.c b/drivers/input/mouse/bcm5974.c index dfdfb59cc8b5..e0da70576167 100644 --- a/drivers/input/mouse/bcm5974.c +++ b/drivers/input/mouse/bcm5974.c @@ -83,6 +83,24 @@ #define USB_DEVICE_ID_APPLE_WELLSPRING9_ISO 0x0273 #define USB_DEVICE_ID_APPLE_WELLSPRING9_JIS 0x0274 +/* T2-Attached Devices */ +/* MacbookAir8,1 (2018) */ +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K 0x027a +/* MacbookPro15,2 (2018) */ +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132 0x027b +/* MacbookPro15,1 (2018) */ +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680 0x027c +/* MacbookPro15,4 (2019) */ +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213 0x027d +/* MacbookPro16,2 (2020) */ +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K 0x027e +/* MacbookPro16,3 (2020) */ +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223 0x027f +/* MacbookAir9,1 (2020) */ +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K 0x0280 +/* MacbookPro16,1 (2019)*/ +#define USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F 0x0340 + #define BCM5974_DEVICE(prod) { \ .match_flags = (USB_DEVICE_ID_MATCH_DEVICE | \ USB_DEVICE_ID_MATCH_INT_CLASS | \ @@ -147,6 +165,22 @@ static const struct usb_device_id bcm5974_table[] = { BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRING9_ANSI), BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRING9_ISO), BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRING9_JIS), + /* MacbookAir8,1 */ + BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K), + /* MacbookPro15,2 */ + BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132), + /* MacbookPro15,1 */ + BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680), + /* MacbookPro15,4 */ + BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213), + /* MacbookPro16,2 */ + BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K), + /* MacbookPro16,3 */ + BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223), + /* MacbookAir9,1 */ + BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K), + /* MacbookPro16,1 */ + BCM5974_DEVICE(USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F), /* Terminating entry */ {} }; @@ -483,6 +517,110 @@ static const struct bcm5974_config bcm5974_config_table[] = { { SN_COORD, -203, 6803 }, { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } }, + { + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J140K, + 0, + 0, + HAS_INTEGRATED_BUTTON, + 0, sizeof(struct bt_data), + 0x83, DATAFORMAT(TYPE4), + { SN_PRESSURE, 0, 300 }, + { SN_WIDTH, 0, 2048 }, + { SN_COORD, -6243, 6749 }, + { SN_COORD, -170, 7685 }, + { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } + }, + { + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J132, + 0, + 0, + HAS_INTEGRATED_BUTTON, + 0, sizeof(struct bt_data), + 0x83, DATAFORMAT(TYPE4), + { SN_PRESSURE, 0, 300 }, + { SN_WIDTH, 0, 2048 }, + { SN_COORD, -6243, 6749 }, + { SN_COORD, -170, 7685 }, + { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } + }, + { + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J680, + 0, + 0, + HAS_INTEGRATED_BUTTON, + 0, sizeof(struct bt_data), + 0x83, DATAFORMAT(TYPE4), + { SN_PRESSURE, 0, 300 }, + { SN_WIDTH, 0, 2048 }, + { SN_COORD, -7456, 7976 }, + { SN_COORD, -1768, 7685 }, + { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } + }, + { + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J213, + 0, + 0, + HAS_INTEGRATED_BUTTON, + 0, sizeof(struct bt_data), + 0x83, DATAFORMAT(TYPE4), + { SN_PRESSURE, 0, 300 }, + { SN_WIDTH, 0, 2048 }, + { SN_COORD, -6243, 6749 }, + { SN_COORD, -170, 7685 }, + { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } + }, + { + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J214K, + 0, + 0, + HAS_INTEGRATED_BUTTON, + 0, sizeof(struct bt_data), + 0x83, DATAFORMAT(TYPE4), + { SN_PRESSURE, 0, 300 }, + { SN_WIDTH, 0, 2048 }, + { SN_COORD, -7823, 8329 }, + { SN_COORD, -370, 7925 }, + { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } + }, + { + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J223, + 0, + 0, + HAS_INTEGRATED_BUTTON, + 0, sizeof(struct bt_data), + 0x83, DATAFORMAT(TYPE4), + { SN_PRESSURE, 0, 300 }, + { SN_WIDTH, 0, 2048 }, + { SN_COORD, -6243, 6749 }, + { SN_COORD, -170, 7685 }, + { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } + }, + { + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J230K, + 0, + 0, + HAS_INTEGRATED_BUTTON, + 0, sizeof(struct bt_data), + 0x83, DATAFORMAT(TYPE4), + { SN_PRESSURE, 0, 300 }, + { SN_WIDTH, 0, 2048 }, + { SN_COORD, -6243, 6749 }, + { SN_COORD, -170, 7685 }, + { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } + }, + { + USB_DEVICE_ID_APPLE_WELLSPRINGT2_J152F, + 0, + 0, + HAS_INTEGRATED_BUTTON, + 0, sizeof(struct bt_data), + 0x83, DATAFORMAT(TYPE4), + { SN_PRESSURE, 0, 300 }, + { SN_WIDTH, 0, 2048 }, + { SN_COORD, -8916, 9918 }, + { SN_COORD, -1934, 9835 }, + { SN_ORIENT, -MAX_FINGER_ORIENTATION, MAX_FINGER_ORIENTATION } + }, {} }; diff --git a/drivers/pci/vgaarb.c b/drivers/pci/vgaarb.c index 78748e8d2dba..2b2b558cebe6 100644 --- a/drivers/pci/vgaarb.c +++ b/drivers/pci/vgaarb.c @@ -143,6 +143,7 @@ void vga_set_default_device(struct pci_dev *pdev) pci_dev_put(vga_default); vga_default = pci_dev_get(pdev); } +EXPORT_SYMBOL_GPL(vga_set_default_device); /** * vga_remove_vgacon - deactivate VGA console diff --git a/drivers/platform/x86/apple-gmux.c b/drivers/platform/x86/apple-gmux.c index 1417e230edbd..e69785af8e1d 100644 --- a/drivers/platform/x86/apple-gmux.c +++ b/drivers/platform/x86/apple-gmux.c @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -107,6 +108,10 @@ struct apple_gmux_config { # define MMIO_GMUX_MAX_BRIGHTNESS 0xffff +static bool force_igd; +module_param(force_igd, bool, 0); +MODULE_PARM_DESC(force_idg, "Switch gpu to igd on module load. Make sure that you have apple-set-os set up and the iGPU is in `lspci -s 00:02.0`. (default: false) (bool)"); + static u8 gmux_pio_read8(struct apple_gmux_data *gmux_data, int port) { return inb(gmux_data->iostart + port); @@ -945,6 +950,19 @@ static int gmux_probe(struct pnp_dev *pnp, const struct pnp_device_id *id) gmux_enable_interrupts(gmux_data); gmux_read_switch_state(gmux_data); + if (force_igd) { + struct pci_dev *pdev; + + pdev = pci_get_domain_bus_and_slot(0, 0, PCI_DEVFN(2, 0)); + if (pdev) { + pr_info("Switching to IGD"); + gmux_switchto(VGA_SWITCHEROO_IGD); + vga_set_default_device(pdev); + } else { + pr_err("force_idg is true, but couldn't find iGPU at 00:02.0! Is apple-set-os working?"); + } + } + /* * Retina MacBook Pros cannot switch the panel's AUX separately * and need eDP pre-calibration. They are distinguishable from diff --git a/drivers/staging/Kconfig b/drivers/staging/Kconfig index 3fb68d60dfc1..7337f658fe96 100644 --- a/drivers/staging/Kconfig +++ b/drivers/staging/Kconfig @@ -64,4 +64,6 @@ source "drivers/staging/fieldbus/Kconfig" source "drivers/staging/vme_user/Kconfig" +source "drivers/staging/apple-bce/Kconfig" + endif # STAGING diff --git a/drivers/staging/Makefile b/drivers/staging/Makefile index c977aa13fad4..241ea7562045 100644 --- a/drivers/staging/Makefile +++ b/drivers/staging/Makefile @@ -21,3 +21,4 @@ obj-$(CONFIG_GREYBUS) += greybus/ obj-$(CONFIG_BCM2835_VCHIQ) += vc04_services/ obj-$(CONFIG_XIL_AXIS_FIFO) += axis-fifo/ obj-$(CONFIG_FIELDBUS_DEV) += fieldbus/ +obj-$(CONFIG_APPLE_BCE) += apple-bce/ diff --git a/drivers/staging/apple-bce/Kconfig b/drivers/staging/apple-bce/Kconfig new file mode 100644 index 000000000000..fe92bc441e89 --- /dev/null +++ b/drivers/staging/apple-bce/Kconfig @@ -0,0 +1,18 @@ +config APPLE_BCE + tristate "Apple BCE driver (VHCI and Audio support)" + default m + depends on X86 + select SOUND + select SND + select SND_PCM + select SND_JACK + help + VHCI and audio support on Apple MacBooks with the T2 Chip. + This driver is divided in three components: + - BCE (Buffer Copy Engine): which establishes a basic communication + channel with the T2 chip. This component is required by the other two: + - VHCI (Virtual Host Controller Interface): Access to keyboard, mouse + and other system devices depend on this virtual USB host controller + - Audio: a driver for the T2 audio interface. + + If "M" is selected, the module will be called apple-bce.' diff --git a/drivers/staging/apple-bce/Makefile b/drivers/staging/apple-bce/Makefile new file mode 100644 index 000000000000..8cfbd3f64af6 --- /dev/null +++ b/drivers/staging/apple-bce/Makefile @@ -0,0 +1,28 @@ +modname := apple-bce +obj-$(CONFIG_APPLE_BCE) += $(modname).o + +apple-bce-objs := apple_bce.o mailbox.o queue.o queue_dma.o vhci/vhci.o vhci/queue.o vhci/transfer.o audio/audio.o audio/protocol.o audio/protocol_bce.o audio/pcm.o + +MY_CFLAGS += -DWITHOUT_NVME_PATCH +#MY_CFLAGS += -g -DDEBUG +ccflags-y += ${MY_CFLAGS} +CC += ${MY_CFLAGS} + +KVERSION := $(KERNELRELEASE) +ifeq ($(origin KERNELRELEASE), undefined) +KVERSION := $(shell uname -r) +endif + +KDIR := /lib/modules/$(KVERSION)/build +PWD := $(shell pwd) + +.PHONY: all + +all: + $(MAKE) -C $(KDIR) M=$(PWD) modules + +clean: + $(MAKE) -C $(KDIR) M=$(PWD) clean + +install: + $(MAKE) -C $(KDIR) M=$(PWD) modules_install diff --git a/drivers/staging/apple-bce/apple_bce.c b/drivers/staging/apple-bce/apple_bce.c new file mode 100644 index 000000000000..4fd2415d7028 --- /dev/null +++ b/drivers/staging/apple-bce/apple_bce.c @@ -0,0 +1,445 @@ +#include "apple_bce.h" +#include +#include +#include "audio/audio.h" +#include + +static dev_t bce_chrdev; +static struct class *bce_class; + +struct apple_bce_device *global_bce; + +static int bce_create_command_queues(struct apple_bce_device *bce); +static void bce_free_command_queues(struct apple_bce_device *bce); +static irqreturn_t bce_handle_mb_irq(int irq, void *dev); +static irqreturn_t bce_handle_dma_irq(int irq, void *dev); +static int bce_fw_version_handshake(struct apple_bce_device *bce); +static int bce_register_command_queue(struct apple_bce_device *bce, struct bce_queue_memcfg *cfg, int is_sq); + +static int apple_bce_probe(struct pci_dev *dev, const struct pci_device_id *id) +{ + struct apple_bce_device *bce = NULL; + int status = 0; + int nvec; + + pr_info("apple-bce: capturing our device\n"); + + if (pci_enable_device(dev)) + return -ENODEV; + if (pci_request_regions(dev, "apple-bce")) { + status = -ENODEV; + goto fail; + } + pci_set_master(dev); + nvec = pci_alloc_irq_vectors(dev, 1, 8, PCI_IRQ_MSI); + if (nvec < 5) { + status = -EINVAL; + goto fail; + } + + bce = kzalloc(sizeof(struct apple_bce_device), GFP_KERNEL); + if (!bce) { + status = -ENOMEM; + goto fail; + } + + bce->pci = dev; + pci_set_drvdata(dev, bce); + + bce->devt = bce_chrdev; + bce->dev = device_create(bce_class, &dev->dev, bce->devt, NULL, "apple-bce"); + if (IS_ERR_OR_NULL(bce->dev)) { + status = PTR_ERR(bce_class); + goto fail; + } + + bce->reg_mem_mb = pci_iomap(dev, 4, 0); + bce->reg_mem_dma = pci_iomap(dev, 2, 0); + + if (IS_ERR_OR_NULL(bce->reg_mem_mb) || IS_ERR_OR_NULL(bce->reg_mem_dma)) { + dev_warn(&dev->dev, "apple-bce: Failed to pci_iomap required regions\n"); + goto fail; + } + + bce_mailbox_init(&bce->mbox, bce->reg_mem_mb); + bce_timestamp_init(&bce->timestamp, bce->reg_mem_mb); + + spin_lock_init(&bce->queues_lock); + ida_init(&bce->queue_ida); + + if ((status = pci_request_irq(dev, 0, bce_handle_mb_irq, NULL, dev, "bce_mbox"))) + goto fail; + if ((status = pci_request_irq(dev, 4, NULL, bce_handle_dma_irq, dev, "bce_dma"))) + goto fail_interrupt_0; + + if ((status = dma_set_mask_and_coherent(&dev->dev, DMA_BIT_MASK(37)))) { + dev_warn(&dev->dev, "dma: Setting mask failed\n"); + goto fail_interrupt; + } + + /* Gets the function 0's interface. This is needed because Apple only accepts DMA on our function if function 0 + is a bus master, so we need to work around this. */ + bce->pci0 = pci_get_slot(dev->bus, PCI_DEVFN(PCI_SLOT(dev->devfn), 0)); +#ifndef WITHOUT_NVME_PATCH + if ((status = pci_enable_device_mem(bce->pci0))) { + dev_warn(&dev->dev, "apple-bce: failed to enable function 0\n"); + goto fail_dev0; + } +#endif + pci_set_master(bce->pci0); + + bce_timestamp_start(&bce->timestamp, true); + + if ((status = bce_fw_version_handshake(bce))) + goto fail_ts; + pr_info("apple-bce: handshake done\n"); + + if ((status = bce_create_command_queues(bce))) { + pr_info("apple-bce: Creating command queues failed\n"); + goto fail_ts; + } + + global_bce = bce; + + bce_vhci_create(bce, &bce->vhci); + + return 0; + +fail_ts: + bce_timestamp_stop(&bce->timestamp); +#ifndef WITHOUT_NVME_PATCH + pci_disable_device(bce->pci0); +fail_dev0: +#endif + pci_dev_put(bce->pci0); +fail_interrupt: + pci_free_irq(dev, 4, dev); +fail_interrupt_0: + pci_free_irq(dev, 0, dev); +fail: + if (bce && bce->dev) { + device_destroy(bce_class, bce->devt); + + if (!IS_ERR_OR_NULL(bce->reg_mem_mb)) + pci_iounmap(dev, bce->reg_mem_mb); + if (!IS_ERR_OR_NULL(bce->reg_mem_dma)) + pci_iounmap(dev, bce->reg_mem_dma); + + kfree(bce); + } + + pci_free_irq_vectors(dev); + pci_release_regions(dev); + pci_disable_device(dev); + + if (!status) + status = -EINVAL; + return status; +} + +static int bce_create_command_queues(struct apple_bce_device *bce) +{ + int status; + struct bce_queue_memcfg *cfg; + + bce->cmd_cq = bce_alloc_cq(bce, 0, 0x20); + bce->cmd_cmdq = bce_alloc_cmdq(bce, 1, 0x20); + if (bce->cmd_cq == NULL || bce->cmd_cmdq == NULL) { + status = -ENOMEM; + goto err; + } + bce->queues[0] = (struct bce_queue *) bce->cmd_cq; + bce->queues[1] = (struct bce_queue *) bce->cmd_cmdq->sq; + + cfg = kzalloc(sizeof(struct bce_queue_memcfg), GFP_KERNEL); + if (!cfg) { + status = -ENOMEM; + goto err; + } + bce_get_cq_memcfg(bce->cmd_cq, cfg); + if ((status = bce_register_command_queue(bce, cfg, false))) + goto err; + bce_get_sq_memcfg(bce->cmd_cmdq->sq, bce->cmd_cq, cfg); + if ((status = bce_register_command_queue(bce, cfg, true))) + goto err; + kfree(cfg); + + return 0; + +err: + if (bce->cmd_cq) + bce_free_cq(bce, bce->cmd_cq); + if (bce->cmd_cmdq) + bce_free_cmdq(bce, bce->cmd_cmdq); + return status; +} + +static void bce_free_command_queues(struct apple_bce_device *bce) +{ + bce_free_cq(bce, bce->cmd_cq); + bce_free_cmdq(bce, bce->cmd_cmdq); + bce->cmd_cq = NULL; + bce->queues[0] = NULL; +} + +static irqreturn_t bce_handle_mb_irq(int irq, void *dev) +{ + struct apple_bce_device *bce = pci_get_drvdata(dev); + bce_mailbox_handle_interrupt(&bce->mbox); + return IRQ_HANDLED; +} + +static irqreturn_t bce_handle_dma_irq(int irq, void *dev) +{ + int i; + struct apple_bce_device *bce = pci_get_drvdata(dev); + spin_lock(&bce->queues_lock); + for (i = 0; i < BCE_MAX_QUEUE_COUNT; i++) + if (bce->queues[i] && bce->queues[i]->type == BCE_QUEUE_CQ) + bce_handle_cq_completions(bce, (struct bce_queue_cq *) bce->queues[i]); + spin_unlock(&bce->queues_lock); + return IRQ_HANDLED; +} + +static int bce_fw_version_handshake(struct apple_bce_device *bce) +{ + u64 result; + int status; + + if ((status = bce_mailbox_send(&bce->mbox, BCE_MB_MSG(BCE_MB_SET_FW_PROTOCOL_VERSION, BC_PROTOCOL_VERSION), + &result))) + return status; + if (BCE_MB_TYPE(result) != BCE_MB_SET_FW_PROTOCOL_VERSION || + BCE_MB_VALUE(result) != BC_PROTOCOL_VERSION) { + pr_err("apple-bce: FW version handshake failed %x:%llx\n", BCE_MB_TYPE(result), BCE_MB_VALUE(result)); + return -EINVAL; + } + return 0; +} + +static int bce_register_command_queue(struct apple_bce_device *bce, struct bce_queue_memcfg *cfg, int is_sq) +{ + int status; + int cmd_type; + u64 result; + // OS X uses an bidirectional direction, but that's not really needed + dma_addr_t a = dma_map_single(&bce->pci->dev, cfg, sizeof(struct bce_queue_memcfg), DMA_TO_DEVICE); + if (dma_mapping_error(&bce->pci->dev, a)) + return -ENOMEM; + cmd_type = is_sq ? BCE_MB_REGISTER_COMMAND_SQ : BCE_MB_REGISTER_COMMAND_CQ; + status = bce_mailbox_send(&bce->mbox, BCE_MB_MSG(cmd_type, a), &result); + dma_unmap_single(&bce->pci->dev, a, sizeof(struct bce_queue_memcfg), DMA_TO_DEVICE); + if (status) + return status; + if (BCE_MB_TYPE(result) != BCE_MB_REGISTER_COMMAND_QUEUE_REPLY) + return -EINVAL; + return 0; +} + +static void apple_bce_remove(struct pci_dev *dev) +{ + struct apple_bce_device *bce = pci_get_drvdata(dev); + bce->is_being_removed = true; + + bce_vhci_destroy(&bce->vhci); + + bce_timestamp_stop(&bce->timestamp); +#ifndef WITHOUT_NVME_PATCH + pci_disable_device(bce->pci0); +#endif + pci_dev_put(bce->pci0); + pci_free_irq(dev, 0, dev); + pci_free_irq(dev, 4, dev); + bce_free_command_queues(bce); + pci_iounmap(dev, bce->reg_mem_mb); + pci_iounmap(dev, bce->reg_mem_dma); + device_destroy(bce_class, bce->devt); + pci_free_irq_vectors(dev); + pci_release_regions(dev); + pci_disable_device(dev); + kfree(bce); +} + +static int bce_save_state_and_sleep(struct apple_bce_device *bce) +{ + int attempt, status = 0; + u64 resp; + dma_addr_t dma_addr; + void *dma_ptr = NULL; + size_t size = max(PAGE_SIZE, 4096UL); + + for (attempt = 0; attempt < 5; ++attempt) { + pr_debug("apple-bce: suspend: attempt %i, buffer size %li\n", attempt, size); + dma_ptr = dma_alloc_coherent(&bce->pci->dev, size, &dma_addr, GFP_KERNEL); + if (!dma_ptr) { + pr_err("apple-bce: suspend failed (data alloc failed)\n"); + break; + } + BUG_ON((dma_addr % 4096) != 0); + status = bce_mailbox_send(&bce->mbox, + BCE_MB_MSG(BCE_MB_SAVE_STATE_AND_SLEEP, (dma_addr & ~(4096LLU - 1)) | (size / 4096)), &resp); + if (status) { + pr_err("apple-bce: suspend failed (mailbox send)\n"); + break; + } + if (BCE_MB_TYPE(resp) == BCE_MB_SAVE_RESTORE_STATE_COMPLETE) { + bce->saved_data_dma_addr = dma_addr; + bce->saved_data_dma_ptr = dma_ptr; + bce->saved_data_dma_size = size; + return 0; + } else if (BCE_MB_TYPE(resp) == BCE_MB_SAVE_STATE_AND_SLEEP_FAILURE) { + dma_free_coherent(&bce->pci->dev, size, dma_ptr, dma_addr); + /* The 0x10ff magic value was extracted from Apple's driver */ + size = (BCE_MB_VALUE(resp) + 0x10ff) & ~(4096LLU - 1); + pr_debug("apple-bce: suspend: device requested a larger buffer (%li)\n", size); + continue; + } else { + pr_err("apple-bce: suspend failed (invalid device response)\n"); + status = -EINVAL; + break; + } + } + if (dma_ptr) + dma_free_coherent(&bce->pci->dev, size, dma_ptr, dma_addr); + if (!status) + return bce_mailbox_send(&bce->mbox, BCE_MB_MSG(BCE_MB_SLEEP_NO_STATE, 0), &resp); + return status; +} + +static int bce_restore_state_and_wake(struct apple_bce_device *bce) +{ + int status; + u64 resp; + if (!bce->saved_data_dma_ptr) { + if ((status = bce_mailbox_send(&bce->mbox, BCE_MB_MSG(BCE_MB_RESTORE_NO_STATE, 0), &resp))) { + pr_err("apple-bce: resume with no state failed (mailbox send)\n"); + return status; + } + if (BCE_MB_TYPE(resp) != BCE_MB_RESTORE_NO_STATE) { + pr_err("apple-bce: resume with no state failed (invalid device response)\n"); + return -EINVAL; + } + return 0; + } + + if ((status = bce_mailbox_send(&bce->mbox, BCE_MB_MSG(BCE_MB_RESTORE_STATE_AND_WAKE, + (bce->saved_data_dma_addr & ~(4096LLU - 1)) | (bce->saved_data_dma_size / 4096)), &resp))) { + pr_err("apple-bce: resume with state failed (mailbox send)\n"); + goto finish_with_state; + } + if (BCE_MB_TYPE(resp) != BCE_MB_SAVE_RESTORE_STATE_COMPLETE) { + pr_err("apple-bce: resume with state failed (invalid device response)\n"); + status = -EINVAL; + goto finish_with_state; + } + +finish_with_state: + dma_free_coherent(&bce->pci->dev, bce->saved_data_dma_size, bce->saved_data_dma_ptr, bce->saved_data_dma_addr); + bce->saved_data_dma_ptr = NULL; + return status; +} + +static int apple_bce_suspend(struct device *dev) +{ + struct apple_bce_device *bce = pci_get_drvdata(to_pci_dev(dev)); + int status; + + bce_timestamp_stop(&bce->timestamp); + + if ((status = bce_save_state_and_sleep(bce))) + return status; + + return 0; +} + +static int apple_bce_resume(struct device *dev) +{ + struct apple_bce_device *bce = pci_get_drvdata(to_pci_dev(dev)); + int status; + + pci_set_master(bce->pci); + pci_set_master(bce->pci0); + + if ((status = bce_restore_state_and_wake(bce))) + return status; + + bce_timestamp_start(&bce->timestamp, false); + + return 0; +} + +static struct pci_device_id apple_bce_ids[ ] = { + { PCI_DEVICE(PCI_VENDOR_ID_APPLE, 0x1801) }, + { 0, }, +}; + +MODULE_DEVICE_TABLE(pci, apple_bce_ids); + +struct dev_pm_ops apple_bce_pci_driver_pm = { + .suspend = apple_bce_suspend, + .resume = apple_bce_resume +}; +struct pci_driver apple_bce_pci_driver = { + .name = "apple-bce", + .id_table = apple_bce_ids, + .probe = apple_bce_probe, + .remove = apple_bce_remove, + .driver = { + .pm = &apple_bce_pci_driver_pm + } +}; + + +static int __init apple_bce_module_init(void) +{ + int result; + if ((result = alloc_chrdev_region(&bce_chrdev, 0, 1, "apple-bce"))) + goto fail_chrdev; +#if LINUX_VERSION_CODE < KERNEL_VERSION(6,4,0) + bce_class = class_create(THIS_MODULE, "apple-bce"); +#else + bce_class = class_create("apple-bce"); +#endif + if (IS_ERR(bce_class)) { + result = PTR_ERR(bce_class); + goto fail_class; + } + if ((result = bce_vhci_module_init())) { + pr_err("apple-bce: bce-vhci init failed"); + goto fail_class; + } + + result = pci_register_driver(&apple_bce_pci_driver); + if (result) + goto fail_drv; + + aaudio_module_init(); + + return 0; + +fail_drv: + pci_unregister_driver(&apple_bce_pci_driver); +fail_class: + class_destroy(bce_class); +fail_chrdev: + unregister_chrdev_region(bce_chrdev, 1); + if (!result) + result = -EINVAL; + return result; +} +static void __exit apple_bce_module_exit(void) +{ + pci_unregister_driver(&apple_bce_pci_driver); + + aaudio_module_exit(); + bce_vhci_module_exit(); + class_destroy(bce_class); + unregister_chrdev_region(bce_chrdev, 1); +} + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("MrARM"); +MODULE_DESCRIPTION("Apple BCE Driver"); +MODULE_VERSION("0.01"); +module_init(apple_bce_module_init); +module_exit(apple_bce_module_exit); diff --git a/drivers/staging/apple-bce/apple_bce.h b/drivers/staging/apple-bce/apple_bce.h new file mode 100644 index 000000000000..f13ab8d5742e --- /dev/null +++ b/drivers/staging/apple-bce/apple_bce.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include "mailbox.h" +#include "queue.h" +#include "vhci/vhci.h" + +#define BC_PROTOCOL_VERSION 0x20001 +#define BCE_MAX_QUEUE_COUNT 0x100 + +#define BCE_QUEUE_USER_MIN 2 +#define BCE_QUEUE_USER_MAX (BCE_MAX_QUEUE_COUNT - 1) + +struct apple_bce_device { + struct pci_dev *pci, *pci0; + dev_t devt; + struct device *dev; + void __iomem *reg_mem_mb; + void __iomem *reg_mem_dma; + struct bce_mailbox mbox; + struct bce_timestamp timestamp; + struct bce_queue *queues[BCE_MAX_QUEUE_COUNT]; + struct spinlock queues_lock; + struct ida queue_ida; + struct bce_queue_cq *cmd_cq; + struct bce_queue_cmdq *cmd_cmdq; + struct bce_queue_sq *int_sq_list[BCE_MAX_QUEUE_COUNT]; + bool is_being_removed; + + dma_addr_t saved_data_dma_addr; + void *saved_data_dma_ptr; + size_t saved_data_dma_size; + + struct bce_vhci vhci; +}; + +extern struct apple_bce_device *global_bce; \ No newline at end of file diff --git a/drivers/staging/apple-bce/audio/audio.c b/drivers/staging/apple-bce/audio/audio.c new file mode 100644 index 000000000000..bd16ddd16c1d --- /dev/null +++ b/drivers/staging/apple-bce/audio/audio.c @@ -0,0 +1,711 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include "audio.h" +#include "pcm.h" +#include + +static int aaudio_alsa_index = SNDRV_DEFAULT_IDX1; +static char *aaudio_alsa_id = SNDRV_DEFAULT_STR1; + +static dev_t aaudio_chrdev; +static struct class *aaudio_class; + +static int aaudio_init_cmd(struct aaudio_device *a); +static int aaudio_init_bs(struct aaudio_device *a); +static void aaudio_init_dev(struct aaudio_device *a, aaudio_device_id_t dev_id); +static void aaudio_free_dev(struct aaudio_subdevice *sdev); + +static int aaudio_probe(struct pci_dev *dev, const struct pci_device_id *id) +{ + struct aaudio_device *aaudio = NULL; + struct aaudio_subdevice *sdev = NULL; + int status = 0; + u32 cfg; + + pr_info("aaudio: capturing our device\n"); + + if (pci_enable_device(dev)) + return -ENODEV; + if (pci_request_regions(dev, "aaudio")) { + status = -ENODEV; + goto fail; + } + pci_set_master(dev); + + aaudio = kzalloc(sizeof(struct aaudio_device), GFP_KERNEL); + if (!aaudio) { + status = -ENOMEM; + goto fail; + } + + aaudio->bce = global_bce; + if (!aaudio->bce) { + dev_warn(&dev->dev, "aaudio: No BCE available\n"); + status = -EINVAL; + goto fail; + } + + aaudio->pci = dev; + pci_set_drvdata(dev, aaudio); + + aaudio->devt = aaudio_chrdev; + aaudio->dev = device_create(aaudio_class, &dev->dev, aaudio->devt, NULL, "aaudio"); + if (IS_ERR_OR_NULL(aaudio->dev)) { + status = PTR_ERR(aaudio_class); + goto fail; + } + device_link_add(aaudio->dev, aaudio->bce->dev, DL_FLAG_PM_RUNTIME | DL_FLAG_AUTOREMOVE_CONSUMER); + + init_completion(&aaudio->remote_alive); + INIT_LIST_HEAD(&aaudio->subdevice_list); + + /* Init: set an unknown flag in the bitset */ + if (pci_read_config_dword(dev, 4, &cfg)) + dev_warn(&dev->dev, "aaudio: pci_read_config_dword fail\n"); + if (pci_write_config_dword(dev, 4, cfg | 6u)) + dev_warn(&dev->dev, "aaudio: pci_write_config_dword fail\n"); + + dev_info(aaudio->dev, "aaudio: bs len = %llx\n", pci_resource_len(dev, 0)); + aaudio->reg_mem_bs_dma = pci_resource_start(dev, 0); + aaudio->reg_mem_bs = pci_iomap(dev, 0, 0); + aaudio->reg_mem_cfg = pci_iomap(dev, 4, 0); + + aaudio->reg_mem_gpr = (u32 __iomem *) ((u8 __iomem *) aaudio->reg_mem_cfg + 0xC000); + + if (IS_ERR_OR_NULL(aaudio->reg_mem_bs) || IS_ERR_OR_NULL(aaudio->reg_mem_cfg)) { + dev_warn(&dev->dev, "aaudio: Failed to pci_iomap required regions\n"); + goto fail; + } + + if (aaudio_bce_init(aaudio)) { + dev_warn(&dev->dev, "aaudio: Failed to init BCE command transport\n"); + goto fail; + } + + if (snd_card_new(aaudio->dev, aaudio_alsa_index, aaudio_alsa_id, THIS_MODULE, 0, &aaudio->card)) { + dev_err(&dev->dev, "aaudio: Failed to create ALSA card\n"); + goto fail; + } + + strcpy(aaudio->card->shortname, "Apple T2 Audio"); + strcpy(aaudio->card->longname, "Apple T2 Audio"); + strcpy(aaudio->card->mixername, "Apple T2 Audio"); + /* Dynamic alsa ids start at 100 */ + aaudio->next_alsa_id = 100; + + if (aaudio_init_cmd(aaudio)) { + dev_err(&dev->dev, "aaudio: Failed to initialize over BCE\n"); + goto fail_snd; + } + + if (aaudio_init_bs(aaudio)) { + dev_err(&dev->dev, "aaudio: Failed to initialize BufferStruct\n"); + goto fail_snd; + } + + if ((status = aaudio_cmd_set_remote_access(aaudio, AAUDIO_REMOTE_ACCESS_ON))) { + dev_err(&dev->dev, "Failed to set remote access\n"); + return status; + } + + if (snd_card_register(aaudio->card)) { + dev_err(&dev->dev, "aaudio: Failed to register ALSA sound device\n"); + goto fail_snd; + } + + list_for_each_entry(sdev, &aaudio->subdevice_list, list) { + struct aaudio_buffer_struct_device *dev = &aaudio->bs->devices[sdev->buf_id]; + + if (sdev->out_stream_cnt == 1 && !strcmp(dev->name, "Speaker")) { + struct snd_pcm_hardware *hw = sdev->out_streams[0].alsa_hw_desc; + + snprintf(aaudio->card->driver, sizeof(aaudio->card->driver) / sizeof(char), "AppleT2x%d", hw->channels_min); + } + } + + return 0; + +fail_snd: + snd_card_free(aaudio->card); +fail: + if (aaudio && aaudio->dev) + device_destroy(aaudio_class, aaudio->devt); + kfree(aaudio); + + if (!IS_ERR_OR_NULL(aaudio->reg_mem_bs)) + pci_iounmap(dev, aaudio->reg_mem_bs); + if (!IS_ERR_OR_NULL(aaudio->reg_mem_cfg)) + pci_iounmap(dev, aaudio->reg_mem_cfg); + + pci_release_regions(dev); + pci_disable_device(dev); + + if (!status) + status = -EINVAL; + return status; +} + + + +static void aaudio_remove(struct pci_dev *dev) +{ + struct aaudio_subdevice *sdev; + struct aaudio_device *aaudio = pci_get_drvdata(dev); + + snd_card_free(aaudio->card); + while (!list_empty(&aaudio->subdevice_list)) { + sdev = list_first_entry(&aaudio->subdevice_list, struct aaudio_subdevice, list); + list_del(&sdev->list); + aaudio_free_dev(sdev); + } + pci_iounmap(dev, aaudio->reg_mem_bs); + pci_iounmap(dev, aaudio->reg_mem_cfg); + device_destroy(aaudio_class, aaudio->devt); + pci_free_irq_vectors(dev); + pci_release_regions(dev); + pci_disable_device(dev); + kfree(aaudio); +} + +static int aaudio_suspend(struct device *dev) +{ + struct aaudio_device *aaudio = pci_get_drvdata(to_pci_dev(dev)); + + if (aaudio_cmd_set_remote_access(aaudio, AAUDIO_REMOTE_ACCESS_OFF)) + dev_warn(aaudio->dev, "Failed to reset remote access\n"); + + pci_disable_device(aaudio->pci); + return 0; +} + +static int aaudio_resume(struct device *dev) +{ + int status; + struct aaudio_device *aaudio = pci_get_drvdata(to_pci_dev(dev)); + + if ((status = pci_enable_device(aaudio->pci))) + return status; + pci_set_master(aaudio->pci); + + if ((status = aaudio_cmd_set_remote_access(aaudio, AAUDIO_REMOTE_ACCESS_ON))) { + dev_err(aaudio->dev, "Failed to set remote access\n"); + return status; + } + + return 0; +} + +static int aaudio_init_cmd(struct aaudio_device *a) +{ + int status; + struct aaudio_send_ctx sctx; + struct aaudio_msg buf; + u64 dev_cnt, dev_i; + aaudio_device_id_t *dev_l; + + if ((status = aaudio_send(a, &sctx, 500, + aaudio_msg_write_alive_notification, 1, 3))) { + dev_err(a->dev, "Sending alive notification failed\n"); + return status; + } + + if (wait_for_completion_timeout(&a->remote_alive, msecs_to_jiffies(500)) == 0) { + dev_err(a->dev, "Timed out waiting for remote\n"); + return -ETIMEDOUT; + } + dev_info(a->dev, "Continuing init\n"); + + buf = aaudio_reply_alloc(); + if ((status = aaudio_cmd_get_device_list(a, &buf, &dev_l, &dev_cnt))) { + dev_err(a->dev, "Failed to get device list\n"); + aaudio_reply_free(&buf); + return status; + } + for (dev_i = 0; dev_i < dev_cnt; ++dev_i) + aaudio_init_dev(a, dev_l[dev_i]); + aaudio_reply_free(&buf); + + return 0; +} + +static void aaudio_init_stream_info(struct aaudio_subdevice *sdev, struct aaudio_stream *strm); +static void aaudio_handle_jack_connection_change(struct aaudio_subdevice *sdev); + +static void aaudio_init_dev(struct aaudio_device *a, aaudio_device_id_t dev_id) +{ + struct aaudio_subdevice *sdev; + struct aaudio_msg buf = aaudio_reply_alloc(); + u64 uid_len, stream_cnt, i; + aaudio_object_id_t *stream_list; + char *uid; + + sdev = kzalloc(sizeof(struct aaudio_subdevice), GFP_KERNEL); + + if (aaudio_cmd_get_property(a, &buf, dev_id, dev_id, AAUDIO_PROP(AAUDIO_PROP_SCOPE_GLOBAL, AAUDIO_PROP_UID, 0), + NULL, 0, (void **) &uid, &uid_len) || uid_len > AAUDIO_DEVICE_MAX_UID_LEN) { + dev_err(a->dev, "Failed to get device uid for device %llx\n", dev_id); + goto fail; + } + dev_info(a->dev, "Remote device %llx %.*s\n", dev_id, (int) uid_len, uid); + + sdev->a = a; + INIT_LIST_HEAD(&sdev->list); + sdev->dev_id = dev_id; + sdev->buf_id = AAUDIO_BUFFER_ID_NONE; + strncpy(sdev->uid, uid, uid_len); + sdev->uid[uid_len + 1] = '\0'; + + if (aaudio_cmd_get_primitive_property(a, dev_id, dev_id, + AAUDIO_PROP(AAUDIO_PROP_SCOPE_INPUT, AAUDIO_PROP_LATENCY, 0), NULL, 0, &sdev->in_latency, sizeof(u32))) + dev_warn(a->dev, "Failed to query device input latency\n"); + if (aaudio_cmd_get_primitive_property(a, dev_id, dev_id, + AAUDIO_PROP(AAUDIO_PROP_SCOPE_OUTPUT, AAUDIO_PROP_LATENCY, 0), NULL, 0, &sdev->out_latency, sizeof(u32))) + dev_warn(a->dev, "Failed to query device output latency\n"); + + if (aaudio_cmd_get_input_stream_list(a, &buf, dev_id, &stream_list, &stream_cnt)) { + dev_err(a->dev, "Failed to get input stream list for device %llx\n", dev_id); + goto fail; + } + if (stream_cnt > AAUDIO_DEIVCE_MAX_INPUT_STREAMS) { + dev_warn(a->dev, "Device %s input stream count %llu is larger than the supported count of %u\n", + sdev->uid, stream_cnt, AAUDIO_DEIVCE_MAX_INPUT_STREAMS); + stream_cnt = AAUDIO_DEIVCE_MAX_INPUT_STREAMS; + } + sdev->in_stream_cnt = stream_cnt; + for (i = 0; i < stream_cnt; i++) { + sdev->in_streams[i].id = stream_list[i]; + sdev->in_streams[i].buffer_cnt = 0; + aaudio_init_stream_info(sdev, &sdev->in_streams[i]); + sdev->in_streams[i].latency += sdev->in_latency; + } + + if (aaudio_cmd_get_output_stream_list(a, &buf, dev_id, &stream_list, &stream_cnt)) { + dev_err(a->dev, "Failed to get output stream list for device %llx\n", dev_id); + goto fail; + } + if (stream_cnt > AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS) { + dev_warn(a->dev, "Device %s input stream count %llu is larger than the supported count of %u\n", + sdev->uid, stream_cnt, AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS); + stream_cnt = AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS; + } + sdev->out_stream_cnt = stream_cnt; + for (i = 0; i < stream_cnt; i++) { + sdev->out_streams[i].id = stream_list[i]; + sdev->out_streams[i].buffer_cnt = 0; + aaudio_init_stream_info(sdev, &sdev->out_streams[i]); + sdev->out_streams[i].latency += sdev->in_latency; + } + + if (sdev->is_pcm) + aaudio_create_pcm(sdev); + /* Headphone Jack status */ + if (!strcmp(sdev->uid, "Codec Output")) { + if (snd_jack_new(a->card, sdev->uid, SND_JACK_HEADPHONE, &sdev->jack, true, false)) + dev_warn(a->dev, "Failed to create an attached jack for %s\n", sdev->uid); + aaudio_cmd_property_listener(a, sdev->dev_id, sdev->dev_id, + AAUDIO_PROP(AAUDIO_PROP_SCOPE_OUTPUT, AAUDIO_PROP_JACK_PLUGGED, 0)); + aaudio_handle_jack_connection_change(sdev); + } + + aaudio_reply_free(&buf); + + list_add_tail(&sdev->list, &a->subdevice_list); + return; + +fail: + aaudio_reply_free(&buf); + kfree(sdev); +} + +static void aaudio_init_stream_info(struct aaudio_subdevice *sdev, struct aaudio_stream *strm) +{ + if (aaudio_cmd_get_primitive_property(sdev->a, sdev->dev_id, strm->id, + AAUDIO_PROP(AAUDIO_PROP_SCOPE_GLOBAL, AAUDIO_PROP_PHYS_FORMAT, 0), NULL, 0, + &strm->desc, sizeof(strm->desc))) + dev_warn(sdev->a->dev, "Failed to query stream descriptor\n"); + if (aaudio_cmd_get_primitive_property(sdev->a, sdev->dev_id, strm->id, + AAUDIO_PROP(AAUDIO_PROP_SCOPE_GLOBAL, AAUDIO_PROP_LATENCY, 0), NULL, 0, &strm->latency, sizeof(u32))) + dev_warn(sdev->a->dev, "Failed to query stream latency\n"); + if (strm->desc.format_id == AAUDIO_FORMAT_LPCM) + sdev->is_pcm = true; +} + +static void aaudio_free_dev(struct aaudio_subdevice *sdev) +{ + size_t i; + for (i = 0; i < sdev->in_stream_cnt; i++) { + if (sdev->in_streams[i].alsa_hw_desc) + kfree(sdev->in_streams[i].alsa_hw_desc); + if (sdev->in_streams[i].buffers) + kfree(sdev->in_streams[i].buffers); + } + for (i = 0; i < sdev->out_stream_cnt; i++) { + if (sdev->out_streams[i].alsa_hw_desc) + kfree(sdev->out_streams[i].alsa_hw_desc); + if (sdev->out_streams[i].buffers) + kfree(sdev->out_streams[i].buffers); + } + kfree(sdev); +} + +static struct aaudio_subdevice *aaudio_find_dev_by_dev_id(struct aaudio_device *a, aaudio_device_id_t dev_id) +{ + struct aaudio_subdevice *sdev; + list_for_each_entry(sdev, &a->subdevice_list, list) { + if (dev_id == sdev->dev_id) + return sdev; + } + return NULL; +} + +static struct aaudio_subdevice *aaudio_find_dev_by_uid(struct aaudio_device *a, const char *uid) +{ + struct aaudio_subdevice *sdev; + list_for_each_entry(sdev, &a->subdevice_list, list) { + if (!strcmp(uid, sdev->uid)) + return sdev; + } + return NULL; +} + +static void aaudio_init_bs_stream(struct aaudio_device *a, struct aaudio_stream *strm, + struct aaudio_buffer_struct_stream *bs_strm); +static void aaudio_init_bs_stream_host(struct aaudio_device *a, struct aaudio_stream *strm, + struct aaudio_buffer_struct_stream *bs_strm); + +static int aaudio_init_bs(struct aaudio_device *a) +{ + int i, j; + struct aaudio_buffer_struct_device *dev; + struct aaudio_subdevice *sdev; + u32 ver, sig, bs_base; + + ver = ioread32(&a->reg_mem_gpr[0]); + if (ver < 3) { + dev_err(a->dev, "aaudio: Bad GPR version (%u)", ver); + return -EINVAL; + } + sig = ioread32(&a->reg_mem_gpr[1]); + if (sig != AAUDIO_SIG) { + dev_err(a->dev, "aaudio: Bad GPR sig (%x)", sig); + return -EINVAL; + } + bs_base = ioread32(&a->reg_mem_gpr[2]); + a->bs = (struct aaudio_buffer_struct *) ((u8 *) a->reg_mem_bs + bs_base); + if (a->bs->signature != AAUDIO_SIG) { + dev_err(a->dev, "aaudio: Bad BufferStruct sig (%x)", a->bs->signature); + return -EINVAL; + } + dev_info(a->dev, "aaudio: BufferStruct ver = %i\n", a->bs->version); + dev_info(a->dev, "aaudio: Num devices = %i\n", a->bs->num_devices); + for (i = 0; i < a->bs->num_devices; i++) { + dev = &a->bs->devices[i]; + dev_info(a->dev, "aaudio: Device %i %s\n", i, dev->name); + + sdev = aaudio_find_dev_by_uid(a, dev->name); + if (!sdev) { + dev_err(a->dev, "aaudio: Subdevice not found for BufferStruct device %s\n", dev->name); + continue; + } + sdev->buf_id = (u8) i; + dev->num_input_streams = 0; + for (j = 0; j < dev->num_output_streams; j++) { + dev_info(a->dev, "aaudio: Device %i Stream %i: Output; Buffer Count = %i\n", i, j, + dev->output_streams[j].num_buffers); + if (j < sdev->out_stream_cnt) + aaudio_init_bs_stream(a, &sdev->out_streams[j], &dev->output_streams[j]); + } + } + + list_for_each_entry(sdev, &a->subdevice_list, list) { + if (sdev->buf_id != AAUDIO_BUFFER_ID_NONE) + continue; + sdev->buf_id = i; + dev_info(a->dev, "aaudio: Created device %i %s\n", i, sdev->uid); + strcpy(a->bs->devices[i].name, sdev->uid); + a->bs->devices[i].num_input_streams = 0; + a->bs->devices[i].num_output_streams = 0; + a->bs->num_devices = ++i; + } + list_for_each_entry(sdev, &a->subdevice_list, list) { + if (sdev->in_stream_cnt == 1) { + dev_info(a->dev, "aaudio: Device %i Host Stream; Input\n", sdev->buf_id); + aaudio_init_bs_stream_host(a, &sdev->in_streams[0], &a->bs->devices[sdev->buf_id].input_streams[0]); + a->bs->devices[sdev->buf_id].num_input_streams = 1; + wmb(); + + if (aaudio_cmd_set_input_stream_address_ranges(a, sdev->dev_id)) { + dev_err(a->dev, "aaudio: Failed to set input stream address ranges\n"); + } + } + } + + return 0; +} + +static void aaudio_init_bs_stream(struct aaudio_device *a, struct aaudio_stream *strm, + struct aaudio_buffer_struct_stream *bs_strm) +{ + size_t i; + strm->buffer_cnt = bs_strm->num_buffers; + if (bs_strm->num_buffers > AAUDIO_DEIVCE_MAX_BUFFER_COUNT) { + dev_warn(a->dev, "BufferStruct buffer count %u exceeds driver limit of %u\n", bs_strm->num_buffers, + AAUDIO_DEIVCE_MAX_BUFFER_COUNT); + strm->buffer_cnt = AAUDIO_DEIVCE_MAX_BUFFER_COUNT; + } + if (!strm->buffer_cnt) + return; + strm->buffers = kmalloc_array(strm->buffer_cnt, sizeof(struct aaudio_dma_buf), GFP_KERNEL); + if (!strm->buffers) { + dev_err(a->dev, "Buffer list allocation failed\n"); + return; + } + for (i = 0; i < strm->buffer_cnt; i++) { + strm->buffers[i].dma_addr = a->reg_mem_bs_dma + (dma_addr_t) bs_strm->buffers[i].address; + strm->buffers[i].ptr = a->reg_mem_bs + bs_strm->buffers[i].address; + strm->buffers[i].size = bs_strm->buffers[i].size; + } + + if (strm->buffer_cnt == 1) { + strm->alsa_hw_desc = kmalloc(sizeof(struct snd_pcm_hardware), GFP_KERNEL); + if (aaudio_create_hw_info(&strm->desc, strm->alsa_hw_desc, strm->buffers[0].size)) { + kfree(strm->alsa_hw_desc); + strm->alsa_hw_desc = NULL; + } + } +} + +static void aaudio_init_bs_stream_host(struct aaudio_device *a, struct aaudio_stream *strm, + struct aaudio_buffer_struct_stream *bs_strm) +{ + size_t size; + dma_addr_t dma_addr; + void *dma_ptr; + size = strm->desc.bytes_per_packet * 16640; + dma_ptr = dma_alloc_coherent(&a->pci->dev, size, &dma_addr, GFP_KERNEL); + if (!dma_ptr) { + dev_err(a->dev, "dma_alloc_coherent failed\n"); + return; + } + bs_strm->buffers[0].address = dma_addr; + bs_strm->buffers[0].size = size; + bs_strm->num_buffers = 1; + + memset(dma_ptr, 0, size); + + strm->buffer_cnt = 1; + strm->buffers = kmalloc_array(strm->buffer_cnt, sizeof(struct aaudio_dma_buf), GFP_KERNEL); + if (!strm->buffers) { + dev_err(a->dev, "Buffer list allocation failed\n"); + return; + } + strm->buffers[0].dma_addr = dma_addr; + strm->buffers[0].ptr = dma_ptr; + strm->buffers[0].size = size; + + strm->alsa_hw_desc = kmalloc(sizeof(struct snd_pcm_hardware), GFP_KERNEL); + if (aaudio_create_hw_info(&strm->desc, strm->alsa_hw_desc, strm->buffers[0].size)) { + kfree(strm->alsa_hw_desc); + strm->alsa_hw_desc = NULL; + } +} + +static void aaudio_handle_prop_change(struct aaudio_device *a, struct aaudio_msg *msg); + +void aaudio_handle_notification(struct aaudio_device *a, struct aaudio_msg *msg) +{ + struct aaudio_send_ctx sctx; + struct aaudio_msg_base base; + if (aaudio_msg_read_base(msg, &base)) + return; + switch (base.msg) { + case AAUDIO_MSG_NOTIFICATION_BOOT: + dev_info(a->dev, "Received boot notification from remote\n"); + + /* Resend the alive notify */ + if (aaudio_send(a, &sctx, 500, + aaudio_msg_write_alive_notification, 1, 3)) { + pr_err("Sending alive notification failed\n"); + } + break; + case AAUDIO_MSG_NOTIFICATION_ALIVE: + dev_info(a->dev, "Received alive notification from remote\n"); + complete_all(&a->remote_alive); + break; + case AAUDIO_MSG_PROPERTY_CHANGED: + aaudio_handle_prop_change(a, msg); + break; + default: + dev_info(a->dev, "Unhandled notification %i", base.msg); + break; + } +} + +struct aaudio_prop_change_work_struct { + struct work_struct ws; + struct aaudio_device *a; + aaudio_device_id_t dev; + aaudio_object_id_t obj; + struct aaudio_prop_addr prop; +}; + +static void aaudio_handle_jack_connection_change(struct aaudio_subdevice *sdev) +{ + u32 plugged; + if (!sdev->jack) + return; + /* NOTE: Apple made the plug status scoped to the input and output streams. This makes no sense for us, so I just + * always pick the OUTPUT status. */ + if (aaudio_cmd_get_primitive_property(sdev->a, sdev->dev_id, sdev->dev_id, + AAUDIO_PROP(AAUDIO_PROP_SCOPE_OUTPUT, AAUDIO_PROP_JACK_PLUGGED, 0), NULL, 0, &plugged, sizeof(plugged))) { + dev_err(sdev->a->dev, "Failed to get jack enable status\n"); + return; + } + dev_dbg(sdev->a->dev, "Jack is now %s\n", plugged ? "plugged" : "unplugged"); + snd_jack_report(sdev->jack, plugged ? sdev->jack->type : 0); +} + +void aaudio_handle_prop_change_work(struct work_struct *ws) +{ + struct aaudio_prop_change_work_struct *work = container_of(ws, struct aaudio_prop_change_work_struct, ws); + struct aaudio_subdevice *sdev; + + sdev = aaudio_find_dev_by_dev_id(work->a, work->dev); + if (!sdev) { + dev_err(work->a->dev, "Property notification change: device not found\n"); + goto done; + } + dev_dbg(work->a->dev, "Property changed for device: %s\n", sdev->uid); + + if (work->prop.scope == AAUDIO_PROP_SCOPE_OUTPUT && work->prop.selector == AAUDIO_PROP_JACK_PLUGGED) { + aaudio_handle_jack_connection_change(sdev); + } + +done: + kfree(work); +} + +void aaudio_handle_prop_change(struct aaudio_device *a, struct aaudio_msg *msg) +{ + /* NOTE: This is a scheduled work because this callback will generally need to query device information and this + * is not possible when we are in the reply parsing code's context. */ + struct aaudio_prop_change_work_struct *work; + work = kmalloc(sizeof(struct aaudio_prop_change_work_struct), GFP_KERNEL); + work->a = a; + INIT_WORK(&work->ws, aaudio_handle_prop_change_work); + aaudio_msg_read_property_changed(msg, &work->dev, &work->obj, &work->prop); + schedule_work(&work->ws); +} + +#define aaudio_send_cmd_response(a, sctx, msg, fn, ...) \ + if (aaudio_send_with_tag(a, sctx, ((struct aaudio_msg_header *) msg->data)->tag, 500, fn, ##__VA_ARGS__)) \ + pr_err("aaudio: Failed to reply to a command\n"); + +void aaudio_handle_cmd_timestamp(struct aaudio_device *a, struct aaudio_msg *msg) +{ + ktime_t time_os = ktime_get_boottime(); + struct aaudio_send_ctx sctx; + struct aaudio_subdevice *sdev; + u64 devid, timestamp, update_seed; + aaudio_msg_read_update_timestamp(msg, &devid, ×tamp, &update_seed); + dev_dbg(a->dev, "Received timestamp update for dev=%llx ts=%llx seed=%llx\n", devid, timestamp, update_seed); + + sdev = aaudio_find_dev_by_dev_id(a, devid); + aaudio_handle_timestamp(sdev, time_os, timestamp); + + aaudio_send_cmd_response(a, &sctx, msg, + aaudio_msg_write_update_timestamp_response); +} + +void aaudio_handle_command(struct aaudio_device *a, struct aaudio_msg *msg) +{ + struct aaudio_msg_base base; + if (aaudio_msg_read_base(msg, &base)) + return; + switch (base.msg) { + case AAUDIO_MSG_UPDATE_TIMESTAMP: + aaudio_handle_cmd_timestamp(a, msg); + break; + default: + dev_info(a->dev, "Unhandled device command %i", base.msg); + break; + } +} + +static struct pci_device_id aaudio_ids[ ] = { + { PCI_DEVICE(PCI_VENDOR_ID_APPLE, 0x1803) }, + { 0, }, +}; + +struct dev_pm_ops aaudio_pci_driver_pm = { + .suspend = aaudio_suspend, + .resume = aaudio_resume +}; +struct pci_driver aaudio_pci_driver = { + .name = "aaudio", + .id_table = aaudio_ids, + .probe = aaudio_probe, + .remove = aaudio_remove, + .driver = { + .pm = &aaudio_pci_driver_pm + } +}; + + +int aaudio_module_init(void) +{ + int result; + if ((result = alloc_chrdev_region(&aaudio_chrdev, 0, 1, "aaudio"))) + goto fail_chrdev; +#if LINUX_VERSION_CODE < KERNEL_VERSION(6,4,0) + aaudio_class = class_create(THIS_MODULE, "aaudio"); +#else + aaudio_class = class_create("aaudio"); +#endif + if (IS_ERR(aaudio_class)) { + result = PTR_ERR(aaudio_class); + goto fail_class; + } + + result = pci_register_driver(&aaudio_pci_driver); + if (result) + goto fail_drv; + return 0; + +fail_drv: + pci_unregister_driver(&aaudio_pci_driver); +fail_class: + class_destroy(aaudio_class); +fail_chrdev: + unregister_chrdev_region(aaudio_chrdev, 1); + if (!result) + result = -EINVAL; + return result; +} + +void aaudio_module_exit(void) +{ + pci_unregister_driver(&aaudio_pci_driver); + class_destroy(aaudio_class); + unregister_chrdev_region(aaudio_chrdev, 1); +} + +struct aaudio_alsa_pcm_id_mapping aaudio_alsa_id_mappings[] = { + {"Speaker", 0}, + {"Digital Mic", 1}, + {"Codec Output", 2}, + {"Codec Input", 3}, + {"Bridge Loopback", 4}, + {} +}; + +module_param_named(index, aaudio_alsa_index, int, 0444); +MODULE_PARM_DESC(index, "Index value for Apple Internal Audio soundcard."); +module_param_named(id, aaudio_alsa_id, charp, 0444); +MODULE_PARM_DESC(id, "ID string for Apple Internal Audio soundcard."); diff --git a/drivers/staging/apple-bce/audio/audio.h b/drivers/staging/apple-bce/audio/audio.h new file mode 100644 index 000000000000..004bc1e22ea4 --- /dev/null +++ b/drivers/staging/apple-bce/audio/audio.h @@ -0,0 +1,125 @@ +#ifndef AAUDIO_H +#define AAUDIO_H + +#include +#include +#include "../apple_bce.h" +#include "protocol_bce.h" +#include "description.h" + +#define AAUDIO_SIG 0x19870423 + +#define AAUDIO_DEVICE_MAX_UID_LEN 128 +#define AAUDIO_DEIVCE_MAX_INPUT_STREAMS 1 +#define AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS 1 +#define AAUDIO_DEIVCE_MAX_BUFFER_COUNT 1 + +#define AAUDIO_BUFFER_ID_NONE 0xffu + +struct snd_card; +struct snd_pcm; +struct snd_pcm_hardware; +struct snd_jack; + +struct __attribute__((packed)) __attribute__((aligned(4))) aaudio_buffer_struct_buffer { + size_t address; + size_t size; + size_t pad[4]; +}; +struct aaudio_buffer_struct_stream { + u8 num_buffers; + struct aaudio_buffer_struct_buffer buffers[100]; + char filler[32]; +}; +struct aaudio_buffer_struct_device { + char name[128]; + u8 num_input_streams; + u8 num_output_streams; + struct aaudio_buffer_struct_stream input_streams[5]; + struct aaudio_buffer_struct_stream output_streams[5]; + char filler[128]; +}; +struct aaudio_buffer_struct { + u32 version; + u32 signature; + u32 flags; + u8 num_devices; + struct aaudio_buffer_struct_device devices[20]; +}; + +struct aaudio_device; +struct aaudio_dma_buf { + dma_addr_t dma_addr; + void *ptr; + size_t size; +}; +struct aaudio_stream { + aaudio_object_id_t id; + size_t buffer_cnt; + struct aaudio_dma_buf *buffers; + + struct aaudio_apple_description desc; + struct snd_pcm_hardware *alsa_hw_desc; + u32 latency; + + bool waiting_for_first_ts; + + ktime_t remote_timestamp; + snd_pcm_sframes_t frame_min; + int started; +}; +struct aaudio_subdevice { + struct aaudio_device *a; + struct list_head list; + aaudio_device_id_t dev_id; + u32 in_latency, out_latency; + u8 buf_id; + int alsa_id; + char uid[AAUDIO_DEVICE_MAX_UID_LEN + 1]; + size_t in_stream_cnt; + struct aaudio_stream in_streams[AAUDIO_DEIVCE_MAX_INPUT_STREAMS]; + size_t out_stream_cnt; + struct aaudio_stream out_streams[AAUDIO_DEIVCE_MAX_OUTPUT_STREAMS]; + bool is_pcm; + struct snd_pcm *pcm; + struct snd_jack *jack; +}; +struct aaudio_alsa_pcm_id_mapping { + const char *name; + int alsa_id; +}; + +struct aaudio_device { + struct pci_dev *pci; + dev_t devt; + struct device *dev; + void __iomem *reg_mem_bs; + dma_addr_t reg_mem_bs_dma; + void __iomem *reg_mem_cfg; + + u32 __iomem *reg_mem_gpr; + + struct aaudio_buffer_struct *bs; + + struct apple_bce_device *bce; + struct aaudio_bce bcem; + + struct snd_card *card; + + struct list_head subdevice_list; + int next_alsa_id; + + struct completion remote_alive; +}; + +void aaudio_handle_notification(struct aaudio_device *a, struct aaudio_msg *msg); +void aaudio_handle_prop_change_work(struct work_struct *ws); +void aaudio_handle_cmd_timestamp(struct aaudio_device *a, struct aaudio_msg *msg); +void aaudio_handle_command(struct aaudio_device *a, struct aaudio_msg *msg); + +int aaudio_module_init(void); +void aaudio_module_exit(void); + +extern struct aaudio_alsa_pcm_id_mapping aaudio_alsa_id_mappings[]; + +#endif //AAUDIO_H diff --git a/drivers/staging/apple-bce/audio/description.h b/drivers/staging/apple-bce/audio/description.h new file mode 100644 index 000000000000..dfef3ab68f27 --- /dev/null +++ b/drivers/staging/apple-bce/audio/description.h @@ -0,0 +1,42 @@ +#ifndef AAUDIO_DESCRIPTION_H +#define AAUDIO_DESCRIPTION_H + +#include + +struct aaudio_apple_description { + u64 sample_rate_double; + u32 format_id; + u32 format_flags; + u32 bytes_per_packet; + u32 frames_per_packet; + u32 bytes_per_frame; + u32 channels_per_frame; + u32 bits_per_channel; + u32 reserved; +}; + +enum { + AAUDIO_FORMAT_LPCM = 0x6c70636d // 'lpcm' +}; + +enum { + AAUDIO_FORMAT_FLAG_FLOAT = 1, + AAUDIO_FORMAT_FLAG_BIG_ENDIAN = 2, + AAUDIO_FORMAT_FLAG_SIGNED = 4, + AAUDIO_FORMAT_FLAG_PACKED = 8, + AAUDIO_FORMAT_FLAG_ALIGNED_HIGH = 16, + AAUDIO_FORMAT_FLAG_NON_INTERLEAVED = 32, + AAUDIO_FORMAT_FLAG_NON_MIXABLE = 64 +}; + +static inline u64 aaudio_double_to_u64(u64 d) +{ + u8 sign = (u8) ((d >> 63) & 1); + s32 exp = (s32) ((d >> 52) & 0x7ff) - 1023; + u64 fr = d & ((1LL << 52) - 1); + if (sign || exp < 0) + return 0; + return (u64) ((1LL << exp) + (fr >> (52 - exp))); +} + +#endif //AAUDIO_DESCRIPTION_H diff --git a/drivers/staging/apple-bce/audio/pcm.c b/drivers/staging/apple-bce/audio/pcm.c new file mode 100644 index 000000000000..1026e10a9ac5 --- /dev/null +++ b/drivers/staging/apple-bce/audio/pcm.c @@ -0,0 +1,308 @@ +#include "pcm.h" +#include "audio.h" + +static u64 aaudio_get_alsa_fmtbit(struct aaudio_apple_description *desc) +{ + if (desc->format_flags & AAUDIO_FORMAT_FLAG_FLOAT) { + if (desc->bits_per_channel == 32) { + if (desc->format_flags & AAUDIO_FORMAT_FLAG_BIG_ENDIAN) + return SNDRV_PCM_FMTBIT_FLOAT_BE; + else + return SNDRV_PCM_FMTBIT_FLOAT_LE; + } else if (desc->bits_per_channel == 64) { + if (desc->format_flags & AAUDIO_FORMAT_FLAG_BIG_ENDIAN) + return SNDRV_PCM_FMTBIT_FLOAT64_BE; + else + return SNDRV_PCM_FMTBIT_FLOAT64_LE; + } else { + pr_err("aaudio: unsupported bits per channel for float format: %u\n", desc->bits_per_channel); + return 0; + } + } +#define DEFINE_BPC_OPTION(val, b) \ + case val: \ + if (desc->format_flags & AAUDIO_FORMAT_FLAG_BIG_ENDIAN) { \ + if (desc->format_flags & AAUDIO_FORMAT_FLAG_SIGNED) \ + return SNDRV_PCM_FMTBIT_S ## b ## BE; \ + else \ + return SNDRV_PCM_FMTBIT_U ## b ## BE; \ + } else { \ + if (desc->format_flags & AAUDIO_FORMAT_FLAG_SIGNED) \ + return SNDRV_PCM_FMTBIT_S ## b ## LE; \ + else \ + return SNDRV_PCM_FMTBIT_U ## b ## LE; \ + } + if (desc->format_flags & AAUDIO_FORMAT_FLAG_PACKED) { + switch (desc->bits_per_channel) { + case 8: + case 16: + case 32: + break; + DEFINE_BPC_OPTION(24, 24_3) + default: + pr_err("aaudio: unsupported bits per channel for packed format: %u\n", desc->bits_per_channel); + return 0; + } + } + if (desc->format_flags & AAUDIO_FORMAT_FLAG_ALIGNED_HIGH) { + switch (desc->bits_per_channel) { + DEFINE_BPC_OPTION(24, 32_) + default: + pr_err("aaudio: unsupported bits per channel for high-aligned format: %u\n", desc->bits_per_channel); + return 0; + } + } + switch (desc->bits_per_channel) { + case 8: + if (desc->format_flags & AAUDIO_FORMAT_FLAG_SIGNED) + return SNDRV_PCM_FMTBIT_S8; + else + return SNDRV_PCM_FMTBIT_U8; + DEFINE_BPC_OPTION(16, 16_) + DEFINE_BPC_OPTION(24, 24_) + DEFINE_BPC_OPTION(32, 32_) + default: + pr_err("aaudio: unsupported bits per channel: %u\n", desc->bits_per_channel); + return 0; + } +} +int aaudio_create_hw_info(struct aaudio_apple_description *desc, struct snd_pcm_hardware *alsa_hw, + size_t buf_size) +{ + uint rate; + alsa_hw->info = (SNDRV_PCM_INFO_MMAP | + SNDRV_PCM_INFO_BLOCK_TRANSFER | + SNDRV_PCM_INFO_MMAP_VALID | + SNDRV_PCM_INFO_DOUBLE); + if (desc->format_flags & AAUDIO_FORMAT_FLAG_NON_MIXABLE) + pr_warn("aaudio: unsupported hw flag: NON_MIXABLE\n"); + if (!(desc->format_flags & AAUDIO_FORMAT_FLAG_NON_INTERLEAVED)) + alsa_hw->info |= SNDRV_PCM_INFO_INTERLEAVED; + alsa_hw->formats = aaudio_get_alsa_fmtbit(desc); + if (!alsa_hw->formats) + return -EINVAL; + rate = (uint) aaudio_double_to_u64(desc->sample_rate_double); + alsa_hw->rates = snd_pcm_rate_to_rate_bit(rate); + alsa_hw->rate_min = rate; + alsa_hw->rate_max = rate; + alsa_hw->channels_min = desc->channels_per_frame; + alsa_hw->channels_max = desc->channels_per_frame; + alsa_hw->buffer_bytes_max = buf_size; + alsa_hw->period_bytes_min = desc->bytes_per_packet; + alsa_hw->period_bytes_max = desc->bytes_per_packet; + alsa_hw->periods_min = (uint) (buf_size / desc->bytes_per_packet); + alsa_hw->periods_max = (uint) (buf_size / desc->bytes_per_packet); + pr_debug("aaudio_create_hw_info: format = %llu, rate = %u/%u. channels = %u, periods = %u, period size = %lu\n", + alsa_hw->formats, alsa_hw->rate_min, alsa_hw->rates, alsa_hw->channels_min, alsa_hw->periods_min, + alsa_hw->period_bytes_min); + return 0; +} + +static struct aaudio_stream *aaudio_pcm_stream(struct snd_pcm_substream *substream) +{ + struct aaudio_subdevice *sdev = snd_pcm_substream_chip(substream); + if (substream->stream == SNDRV_PCM_STREAM_PLAYBACK) + return &sdev->out_streams[substream->number]; + else + return &sdev->in_streams[substream->number]; +} + +static int aaudio_pcm_open(struct snd_pcm_substream *substream) +{ + pr_debug("aaudio_pcm_open\n"); + substream->runtime->hw = *aaudio_pcm_stream(substream)->alsa_hw_desc; + + return 0; +} + +static int aaudio_pcm_close(struct snd_pcm_substream *substream) +{ + pr_debug("aaudio_pcm_close\n"); + return 0; +} + +static int aaudio_pcm_prepare(struct snd_pcm_substream *substream) +{ + return 0; +} + +static int aaudio_pcm_hw_params(struct snd_pcm_substream *substream, struct snd_pcm_hw_params *hw_params) +{ + struct aaudio_stream *astream = aaudio_pcm_stream(substream); + pr_debug("aaudio_pcm_hw_params\n"); + + if (!astream->buffer_cnt || !astream->buffers) + return -EINVAL; + + substream->runtime->dma_area = astream->buffers[0].ptr; + substream->runtime->dma_addr = astream->buffers[0].dma_addr; + substream->runtime->dma_bytes = astream->buffers[0].size; + return 0; +} + +static int aaudio_pcm_hw_free(struct snd_pcm_substream *substream) +{ + pr_debug("aaudio_pcm_hw_free\n"); + return 0; +} + +static void aaudio_pcm_start(struct snd_pcm_substream *substream) +{ + struct aaudio_subdevice *sdev = snd_pcm_substream_chip(substream); + struct aaudio_stream *stream = aaudio_pcm_stream(substream); + void *buf; + size_t s; + ktime_t time_start, time_end; + bool back_buffer; + time_start = ktime_get(); + + back_buffer = (substream->stream == SNDRV_PCM_STREAM_PLAYBACK); + + if (back_buffer) { + s = frames_to_bytes(substream->runtime, substream->runtime->control->appl_ptr); + buf = kmalloc(s, GFP_KERNEL); + memcpy_fromio(buf, substream->runtime->dma_area, s); + time_end = ktime_get(); + pr_debug("aaudio: Backed up the buffer in %lluns [%li]\n", ktime_to_ns(time_end - time_start), + substream->runtime->control->appl_ptr); + } + + stream->waiting_for_first_ts = true; + stream->frame_min = stream->latency; + + aaudio_cmd_start_io(sdev->a, sdev->dev_id); + if (back_buffer) + memcpy_toio(substream->runtime->dma_area, buf, s); + + time_end = ktime_get(); + pr_debug("aaudio: Started the audio device in %lluns\n", ktime_to_ns(time_end - time_start)); +} + +static int aaudio_pcm_trigger(struct snd_pcm_substream *substream, int cmd) +{ + struct aaudio_subdevice *sdev = snd_pcm_substream_chip(substream); + struct aaudio_stream *stream = aaudio_pcm_stream(substream); + pr_debug("aaudio_pcm_trigger %x\n", cmd); + + /* We only supports triggers on the #0 buffer */ + if (substream->number != 0) + return 0; + switch (cmd) { + case SNDRV_PCM_TRIGGER_START: + aaudio_pcm_start(substream); + stream->started = 1; + break; + case SNDRV_PCM_TRIGGER_STOP: + aaudio_cmd_stop_io(sdev->a, sdev->dev_id); + stream->started = 0; + break; + default: + return -EINVAL; + } + return 0; +} + +static snd_pcm_uframes_t aaudio_pcm_pointer(struct snd_pcm_substream *substream) +{ + struct aaudio_stream *stream = aaudio_pcm_stream(substream); + ktime_t time_from_start; + snd_pcm_sframes_t frames; + snd_pcm_sframes_t buffer_time_length; + + if (!stream->started || stream->waiting_for_first_ts) { + pr_warn("aaudio_pcm_pointer while not started\n"); + return 0; + } + + /* Approximate the pointer based on the last received timestamp */ + time_from_start = ktime_get_boottime() - stream->remote_timestamp; + buffer_time_length = NSEC_PER_SEC * substream->runtime->buffer_size / substream->runtime->rate; + frames = (ktime_to_ns(time_from_start) % buffer_time_length) * substream->runtime->buffer_size / buffer_time_length; + if (ktime_to_ns(time_from_start) < buffer_time_length) { + if (frames < stream->frame_min) + frames = stream->frame_min; + else + stream->frame_min = 0; + } else { + if (ktime_to_ns(time_from_start) < 2 * buffer_time_length) + stream->frame_min = frames; + else + stream->frame_min = 0; /* Heavy desync */ + } + frames -= stream->latency; + if (frames < 0) + frames += ((-frames - 1) / substream->runtime->buffer_size + 1) * substream->runtime->buffer_size; + return (snd_pcm_uframes_t) frames; +} + +static struct snd_pcm_ops aaudio_pcm_ops = { + .open = aaudio_pcm_open, + .close = aaudio_pcm_close, + .ioctl = snd_pcm_lib_ioctl, + .hw_params = aaudio_pcm_hw_params, + .hw_free = aaudio_pcm_hw_free, + .prepare = aaudio_pcm_prepare, + .trigger = aaudio_pcm_trigger, + .pointer = aaudio_pcm_pointer, + .mmap = snd_pcm_lib_mmap_iomem +}; + +int aaudio_create_pcm(struct aaudio_subdevice *sdev) +{ + struct snd_pcm *pcm; + struct aaudio_alsa_pcm_id_mapping *id_mapping; + int err; + + if (!sdev->is_pcm || (sdev->in_stream_cnt == 0 && sdev->out_stream_cnt == 0)) { + return -EINVAL; + } + + for (id_mapping = aaudio_alsa_id_mappings; id_mapping->name; id_mapping++) { + if (!strcmp(sdev->uid, id_mapping->name)) { + sdev->alsa_id = id_mapping->alsa_id; + break; + } + } + if (!id_mapping->name) + sdev->alsa_id = sdev->a->next_alsa_id++; + err = snd_pcm_new(sdev->a->card, sdev->uid, sdev->alsa_id, + (int) sdev->out_stream_cnt, (int) sdev->in_stream_cnt, &pcm); + if (err < 0) + return err; + pcm->private_data = sdev; + pcm->nonatomic = 1; + sdev->pcm = pcm; + strcpy(pcm->name, sdev->uid); + snd_pcm_set_ops(pcm, SNDRV_PCM_STREAM_PLAYBACK, &aaudio_pcm_ops); + snd_pcm_set_ops(pcm, SNDRV_PCM_STREAM_CAPTURE, &aaudio_pcm_ops); + return 0; +} + +static void aaudio_handle_stream_timestamp(struct snd_pcm_substream *substream, ktime_t timestamp) +{ + unsigned long flags; + struct aaudio_stream *stream; + + stream = aaudio_pcm_stream(substream); + snd_pcm_stream_lock_irqsave(substream, flags); + stream->remote_timestamp = timestamp; + if (stream->waiting_for_first_ts) { + stream->waiting_for_first_ts = false; + snd_pcm_stream_unlock_irqrestore(substream, flags); + return; + } + snd_pcm_stream_unlock_irqrestore(substream, flags); + snd_pcm_period_elapsed(substream); +} + +void aaudio_handle_timestamp(struct aaudio_subdevice *sdev, ktime_t os_timestamp, u64 dev_timestamp) +{ + struct snd_pcm_substream *substream; + + substream = sdev->pcm->streams[SNDRV_PCM_STREAM_PLAYBACK].substream; + if (substream) + aaudio_handle_stream_timestamp(substream, dev_timestamp); + substream = sdev->pcm->streams[SNDRV_PCM_STREAM_CAPTURE].substream; + if (substream) + aaudio_handle_stream_timestamp(substream, os_timestamp); +} diff --git a/drivers/staging/apple-bce/audio/pcm.h b/drivers/staging/apple-bce/audio/pcm.h new file mode 100644 index 000000000000..ea5f35fbe408 --- /dev/null +++ b/drivers/staging/apple-bce/audio/pcm.h @@ -0,0 +1,16 @@ +#ifndef AAUDIO_PCM_H +#define AAUDIO_PCM_H + +#include +#include + +struct aaudio_subdevice; +struct aaudio_apple_description; +struct snd_pcm_hardware; + +int aaudio_create_hw_info(struct aaudio_apple_description *desc, struct snd_pcm_hardware *alsa_hw, size_t buf_size); +int aaudio_create_pcm(struct aaudio_subdevice *sdev); + +void aaudio_handle_timestamp(struct aaudio_subdevice *sdev, ktime_t os_timestamp, u64 dev_timestamp); + +#endif //AAUDIO_PCM_H diff --git a/drivers/staging/apple-bce/audio/protocol.c b/drivers/staging/apple-bce/audio/protocol.c new file mode 100644 index 000000000000..2314813aeead --- /dev/null +++ b/drivers/staging/apple-bce/audio/protocol.c @@ -0,0 +1,347 @@ +#include "protocol.h" +#include "protocol_bce.h" +#include "audio.h" + +int aaudio_msg_read_base(struct aaudio_msg *msg, struct aaudio_msg_base *base) +{ + if (msg->size < sizeof(struct aaudio_msg_header) + sizeof(struct aaudio_msg_base) * 2) + return -EINVAL; + *base = *((struct aaudio_msg_base *) ((struct aaudio_msg_header *) msg->data + 1)); + return 0; +} + +#define READ_START(type) \ + size_t offset = sizeof(struct aaudio_msg_header) + sizeof(struct aaudio_msg_base); (void)offset; \ + if (((struct aaudio_msg_base *) ((struct aaudio_msg_header *) msg->data + 1))->msg != type) \ + return -EINVAL; +#define READ_DEVID_VAR(devid) *devid = ((struct aaudio_msg_header *) msg->data)->device_id +#define READ_VAL(type) ({ offset += sizeof(type); *((type *) ((u8 *) msg->data + offset - sizeof(type))); }) +#define READ_VAR(type, var) *var = READ_VAL(type) + +int aaudio_msg_read_start_io_response(struct aaudio_msg *msg) +{ + READ_START(AAUDIO_MSG_START_IO_RESPONSE); + return 0; +} + +int aaudio_msg_read_stop_io_response(struct aaudio_msg *msg) +{ + READ_START(AAUDIO_MSG_STOP_IO_RESPONSE); + return 0; +} + +int aaudio_msg_read_update_timestamp(struct aaudio_msg *msg, aaudio_device_id_t *devid, + u64 *timestamp, u64 *update_seed) +{ + READ_START(AAUDIO_MSG_UPDATE_TIMESTAMP); + READ_DEVID_VAR(devid); + READ_VAR(u64, timestamp); + READ_VAR(u64, update_seed); + return 0; +} + +int aaudio_msg_read_get_property_response(struct aaudio_msg *msg, aaudio_object_id_t *obj, + struct aaudio_prop_addr *prop, void **data, u64 *data_size) +{ + READ_START(AAUDIO_MSG_GET_PROPERTY_RESPONSE); + READ_VAR(aaudio_object_id_t, obj); + READ_VAR(u32, &prop->element); + READ_VAR(u32, &prop->scope); + READ_VAR(u32, &prop->selector); + READ_VAR(u64, data_size); + *data = ((u8 *) msg->data + offset); + /* offset += data_size; */ + return 0; +} + +int aaudio_msg_read_set_property_response(struct aaudio_msg *msg, aaudio_object_id_t *obj) +{ + READ_START(AAUDIO_MSG_SET_PROPERTY_RESPONSE); + READ_VAR(aaudio_object_id_t, obj); + return 0; +} + +int aaudio_msg_read_property_listener_response(struct aaudio_msg *msg, aaudio_object_id_t *obj, + struct aaudio_prop_addr *prop) +{ + READ_START(AAUDIO_MSG_PROPERTY_LISTENER_RESPONSE); + READ_VAR(aaudio_object_id_t, obj); + READ_VAR(u32, &prop->element); + READ_VAR(u32, &prop->scope); + READ_VAR(u32, &prop->selector); + return 0; +} + +int aaudio_msg_read_property_changed(struct aaudio_msg *msg, aaudio_device_id_t *devid, aaudio_object_id_t *obj, + struct aaudio_prop_addr *prop) +{ + READ_START(AAUDIO_MSG_PROPERTY_CHANGED); + READ_DEVID_VAR(devid); + READ_VAR(aaudio_object_id_t, obj); + READ_VAR(u32, &prop->element); + READ_VAR(u32, &prop->scope); + READ_VAR(u32, &prop->selector); + return 0; +} + +int aaudio_msg_read_set_input_stream_address_ranges_response(struct aaudio_msg *msg) +{ + READ_START(AAUDIO_MSG_SET_INPUT_STREAM_ADDRESS_RANGES_RESPONSE); + return 0; +} + +int aaudio_msg_read_get_input_stream_list_response(struct aaudio_msg *msg, aaudio_object_id_t **str_l, u64 *str_cnt) +{ + READ_START(AAUDIO_MSG_GET_INPUT_STREAM_LIST_RESPONSE); + READ_VAR(u64, str_cnt); + *str_l = (aaudio_device_id_t *) ((u8 *) msg->data + offset); + /* offset += str_cnt * sizeof(aaudio_object_id_t); */ + return 0; +} + +int aaudio_msg_read_get_output_stream_list_response(struct aaudio_msg *msg, aaudio_object_id_t **str_l, u64 *str_cnt) +{ + READ_START(AAUDIO_MSG_GET_OUTPUT_STREAM_LIST_RESPONSE); + READ_VAR(u64, str_cnt); + *str_l = (aaudio_device_id_t *) ((u8 *) msg->data + offset); + /* offset += str_cnt * sizeof(aaudio_object_id_t); */ + return 0; +} + +int aaudio_msg_read_set_remote_access_response(struct aaudio_msg *msg) +{ + READ_START(AAUDIO_MSG_SET_REMOTE_ACCESS_RESPONSE); + return 0; +} + +int aaudio_msg_read_get_device_list_response(struct aaudio_msg *msg, aaudio_device_id_t **dev_l, u64 *dev_cnt) +{ + READ_START(AAUDIO_MSG_GET_DEVICE_LIST_RESPONSE); + READ_VAR(u64, dev_cnt); + *dev_l = (aaudio_device_id_t *) ((u8 *) msg->data + offset); + /* offset += dev_cnt * sizeof(aaudio_device_id_t); */ + return 0; +} + +#define WRITE_START_OF_TYPE(typev, devid) \ + size_t offset = sizeof(struct aaudio_msg_header); (void) offset; \ + ((struct aaudio_msg_header *) msg->data)->type = (typev); \ + ((struct aaudio_msg_header *) msg->data)->device_id = (devid); +#define WRITE_START_COMMAND(devid) WRITE_START_OF_TYPE(AAUDIO_MSG_TYPE_COMMAND, devid) +#define WRITE_START_RESPONSE() WRITE_START_OF_TYPE(AAUDIO_MSG_TYPE_RESPONSE, 0) +#define WRITE_START_NOTIFICATION() WRITE_START_OF_TYPE(AAUDIO_MSG_TYPE_NOTIFICATION, 0) +#define WRITE_VAL(type, value) { *((type *) ((u8 *) msg->data + offset)) = value; offset += sizeof(value); } +#define WRITE_BIN(value, size) { memcpy((u8 *) msg->data + offset, value, size); offset += size; } +#define WRITE_BASE(type) WRITE_VAL(u32, type) WRITE_VAL(u32, 0) +#define WRITE_END() { msg->size = offset; } + +void aaudio_msg_write_start_io(struct aaudio_msg *msg, aaudio_device_id_t dev) +{ + WRITE_START_COMMAND(dev); + WRITE_BASE(AAUDIO_MSG_START_IO); + WRITE_END(); +} + +void aaudio_msg_write_stop_io(struct aaudio_msg *msg, aaudio_device_id_t dev) +{ + WRITE_START_COMMAND(dev); + WRITE_BASE(AAUDIO_MSG_STOP_IO); + WRITE_END(); +} + +void aaudio_msg_write_get_property(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size) +{ + WRITE_START_COMMAND(dev); + WRITE_BASE(AAUDIO_MSG_GET_PROPERTY); + WRITE_VAL(aaudio_object_id_t, obj); + WRITE_VAL(u32, prop.element); + WRITE_VAL(u32, prop.scope); + WRITE_VAL(u32, prop.selector); + WRITE_VAL(u64, qualifier_size); + WRITE_BIN(qualifier, qualifier_size); + WRITE_END(); +} + +void aaudio_msg_write_set_property(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *data, u64 data_size, void *qualifier, u64 qualifier_size) +{ + WRITE_START_COMMAND(dev); + WRITE_BASE(AAUDIO_MSG_SET_PROPERTY); + WRITE_VAL(aaudio_object_id_t, obj); + WRITE_VAL(u32, prop.element); + WRITE_VAL(u32, prop.scope); + WRITE_VAL(u32, prop.selector); + WRITE_VAL(u64, data_size); + WRITE_BIN(data, data_size); + WRITE_VAL(u64, qualifier_size); + WRITE_BIN(qualifier, qualifier_size); + WRITE_END(); +} + +void aaudio_msg_write_property_listener(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, + struct aaudio_prop_addr prop) +{ + WRITE_START_COMMAND(dev); + WRITE_BASE(AAUDIO_MSG_PROPERTY_LISTENER); + WRITE_VAL(aaudio_object_id_t, obj); + WRITE_VAL(u32, prop.element); + WRITE_VAL(u32, prop.scope); + WRITE_VAL(u32, prop.selector); + WRITE_END(); +} + +void aaudio_msg_write_set_input_stream_address_ranges(struct aaudio_msg *msg, aaudio_device_id_t devid) +{ + WRITE_START_COMMAND(devid); + WRITE_BASE(AAUDIO_MSG_SET_INPUT_STREAM_ADDRESS_RANGES); + WRITE_END(); +} + +void aaudio_msg_write_get_input_stream_list(struct aaudio_msg *msg, aaudio_device_id_t devid) +{ + WRITE_START_COMMAND(devid); + WRITE_BASE(AAUDIO_MSG_GET_INPUT_STREAM_LIST); + WRITE_END(); +} + +void aaudio_msg_write_get_output_stream_list(struct aaudio_msg *msg, aaudio_device_id_t devid) +{ + WRITE_START_COMMAND(devid); + WRITE_BASE(AAUDIO_MSG_GET_OUTPUT_STREAM_LIST); + WRITE_END(); +} + +void aaudio_msg_write_set_remote_access(struct aaudio_msg *msg, u64 mode) +{ + WRITE_START_COMMAND(0); + WRITE_BASE(AAUDIO_MSG_SET_REMOTE_ACCESS); + WRITE_VAL(u64, mode); + WRITE_END(); +} + +void aaudio_msg_write_alive_notification(struct aaudio_msg *msg, u32 proto_ver, u32 msg_ver) +{ + WRITE_START_NOTIFICATION(); + WRITE_BASE(AAUDIO_MSG_NOTIFICATION_ALIVE); + WRITE_VAL(u32, proto_ver); + WRITE_VAL(u32, msg_ver); + WRITE_END(); +} + +void aaudio_msg_write_update_timestamp_response(struct aaudio_msg *msg) +{ + WRITE_START_RESPONSE(); + WRITE_BASE(AAUDIO_MSG_UPDATE_TIMESTAMP_RESPONSE); + WRITE_END(); +} + +void aaudio_msg_write_get_device_list(struct aaudio_msg *msg) +{ + WRITE_START_COMMAND(0); + WRITE_BASE(AAUDIO_MSG_GET_DEVICE_LIST); + WRITE_END(); +} + +#define CMD_SHARED_VARS_NO_REPLY \ + int status = 0; \ + struct aaudio_send_ctx sctx; +#define CMD_SHARED_VARS \ + CMD_SHARED_VARS_NO_REPLY \ + struct aaudio_msg reply = aaudio_reply_alloc(); \ + struct aaudio_msg *buf = &reply; +#define CMD_SEND_REQUEST(fn, ...) \ + if ((status = aaudio_send_cmd_sync(a, &sctx, buf, 500, fn, ##__VA_ARGS__))) \ + return status; +#define CMD_DEF_SHARED_AND_SEND(fn, ...) \ + CMD_SHARED_VARS \ + CMD_SEND_REQUEST(fn, ##__VA_ARGS__); +#define CMD_DEF_SHARED_NO_REPLY_AND_SEND(fn, ...) \ + CMD_SHARED_VARS_NO_REPLY \ + CMD_SEND_REQUEST(fn, ##__VA_ARGS__); +#define CMD_HNDL_REPLY_NO_FREE(fn, ...) \ + status = fn(buf, ##__VA_ARGS__); \ + return status; +#define CMD_HNDL_REPLY_AND_FREE(fn, ...) \ + status = fn(buf, ##__VA_ARGS__); \ + aaudio_reply_free(&reply); \ + return status; + +int aaudio_cmd_start_io(struct aaudio_device *a, aaudio_device_id_t devid) +{ + CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_start_io, devid); + CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_start_io_response); +} +int aaudio_cmd_stop_io(struct aaudio_device *a, aaudio_device_id_t devid) +{ + CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_stop_io, devid); + CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_stop_io_response); +} +int aaudio_cmd_get_property(struct aaudio_device *a, struct aaudio_msg *buf, + aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void **data, u64 *data_size) +{ + CMD_DEF_SHARED_NO_REPLY_AND_SEND(aaudio_msg_write_get_property, devid, obj, prop, qualifier, qualifier_size); + CMD_HNDL_REPLY_NO_FREE(aaudio_msg_read_get_property_response, &obj, &prop, data, data_size); +} +int aaudio_cmd_get_primitive_property(struct aaudio_device *a, + aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void *data, u64 data_size) +{ + int status; + struct aaudio_msg reply = aaudio_reply_alloc(); + void *r_data; + u64 r_data_size; + if ((status = aaudio_cmd_get_property(a, &reply, devid, obj, prop, qualifier, qualifier_size, + &r_data, &r_data_size))) + goto finish; + if (r_data_size != data_size) { + status = -EINVAL; + goto finish; + } + memcpy(data, r_data, data_size); +finish: + aaudio_reply_free(&reply); + return status; +} +int aaudio_cmd_set_property(struct aaudio_device *a, aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void *data, u64 data_size) +{ + CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_set_property, devid, obj, prop, data, data_size, + qualifier, qualifier_size); + CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_set_property_response, &obj); +} +int aaudio_cmd_property_listener(struct aaudio_device *a, aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop) +{ + CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_property_listener, devid, obj, prop); + CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_property_listener_response, &obj, &prop); +} +int aaudio_cmd_set_input_stream_address_ranges(struct aaudio_device *a, aaudio_device_id_t devid) +{ + CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_set_input_stream_address_ranges, devid); + CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_set_input_stream_address_ranges_response); +} +int aaudio_cmd_get_input_stream_list(struct aaudio_device *a, struct aaudio_msg *buf, aaudio_device_id_t devid, + aaudio_object_id_t **str_l, u64 *str_cnt) +{ + CMD_DEF_SHARED_NO_REPLY_AND_SEND(aaudio_msg_write_get_input_stream_list, devid); + CMD_HNDL_REPLY_NO_FREE(aaudio_msg_read_get_input_stream_list_response, str_l, str_cnt); +} +int aaudio_cmd_get_output_stream_list(struct aaudio_device *a, struct aaudio_msg *buf, aaudio_device_id_t devid, + aaudio_object_id_t **str_l, u64 *str_cnt) +{ + CMD_DEF_SHARED_NO_REPLY_AND_SEND(aaudio_msg_write_get_output_stream_list, devid); + CMD_HNDL_REPLY_NO_FREE(aaudio_msg_read_get_output_stream_list_response, str_l, str_cnt); +} +int aaudio_cmd_set_remote_access(struct aaudio_device *a, u64 mode) +{ + CMD_DEF_SHARED_AND_SEND(aaudio_msg_write_set_remote_access, mode); + CMD_HNDL_REPLY_AND_FREE(aaudio_msg_read_set_remote_access_response); +} +int aaudio_cmd_get_device_list(struct aaudio_device *a, struct aaudio_msg *buf, + aaudio_device_id_t **dev_l, u64 *dev_cnt) +{ + CMD_DEF_SHARED_NO_REPLY_AND_SEND(aaudio_msg_write_get_device_list); + CMD_HNDL_REPLY_NO_FREE(aaudio_msg_read_get_device_list_response, dev_l, dev_cnt); +} \ No newline at end of file diff --git a/drivers/staging/apple-bce/audio/protocol.h b/drivers/staging/apple-bce/audio/protocol.h new file mode 100644 index 000000000000..3427486f3f57 --- /dev/null +++ b/drivers/staging/apple-bce/audio/protocol.h @@ -0,0 +1,147 @@ +#ifndef AAUDIO_PROTOCOL_H +#define AAUDIO_PROTOCOL_H + +#include + +struct aaudio_device; + +typedef u64 aaudio_device_id_t; +typedef u64 aaudio_object_id_t; + +struct aaudio_msg { + void *data; + size_t size; +}; + +struct __attribute__((packed)) aaudio_msg_header { + char tag[4]; + u8 type; + aaudio_device_id_t device_id; // Idk, use zero for commands? +}; +struct __attribute__((packed)) aaudio_msg_base { + u32 msg; + u32 status; +}; + +struct aaudio_prop_addr { + u32 scope; + u32 selector; + u32 element; +}; +#define AAUDIO_PROP(scope, sel, el) (struct aaudio_prop_addr) { scope, sel, el } + +enum { + AAUDIO_MSG_TYPE_COMMAND = 1, + AAUDIO_MSG_TYPE_RESPONSE = 2, + AAUDIO_MSG_TYPE_NOTIFICATION = 3 +}; + +enum { + AAUDIO_MSG_START_IO = 0, + AAUDIO_MSG_START_IO_RESPONSE = 1, + AAUDIO_MSG_STOP_IO = 2, + AAUDIO_MSG_STOP_IO_RESPONSE = 3, + AAUDIO_MSG_UPDATE_TIMESTAMP = 4, + AAUDIO_MSG_GET_PROPERTY = 7, + AAUDIO_MSG_GET_PROPERTY_RESPONSE = 8, + AAUDIO_MSG_SET_PROPERTY = 9, + AAUDIO_MSG_SET_PROPERTY_RESPONSE = 10, + AAUDIO_MSG_PROPERTY_LISTENER = 11, + AAUDIO_MSG_PROPERTY_LISTENER_RESPONSE = 12, + AAUDIO_MSG_PROPERTY_CHANGED = 13, + AAUDIO_MSG_SET_INPUT_STREAM_ADDRESS_RANGES = 18, + AAUDIO_MSG_SET_INPUT_STREAM_ADDRESS_RANGES_RESPONSE = 19, + AAUDIO_MSG_GET_INPUT_STREAM_LIST = 24, + AAUDIO_MSG_GET_INPUT_STREAM_LIST_RESPONSE = 25, + AAUDIO_MSG_GET_OUTPUT_STREAM_LIST = 26, + AAUDIO_MSG_GET_OUTPUT_STREAM_LIST_RESPONSE = 27, + AAUDIO_MSG_SET_REMOTE_ACCESS = 32, + AAUDIO_MSG_SET_REMOTE_ACCESS_RESPONSE = 33, + AAUDIO_MSG_UPDATE_TIMESTAMP_RESPONSE = 34, + + AAUDIO_MSG_NOTIFICATION_ALIVE = 100, + AAUDIO_MSG_GET_DEVICE_LIST = 101, + AAUDIO_MSG_GET_DEVICE_LIST_RESPONSE = 102, + AAUDIO_MSG_NOTIFICATION_BOOT = 104 +}; + +enum { + AAUDIO_REMOTE_ACCESS_OFF = 0, + AAUDIO_REMOTE_ACCESS_ON = 2 +}; + +enum { + AAUDIO_PROP_SCOPE_GLOBAL = 0x676c6f62, // 'glob' + AAUDIO_PROP_SCOPE_INPUT = 0x696e7074, // 'inpt' + AAUDIO_PROP_SCOPE_OUTPUT = 0x6f757470 // 'outp' +}; + +enum { + AAUDIO_PROP_UID = 0x75696420, // 'uid ' + AAUDIO_PROP_BOOL_VALUE = 0x6263766c, // 'bcvl' + AAUDIO_PROP_JACK_PLUGGED = 0x6a61636b, // 'jack' + AAUDIO_PROP_SEL_VOLUME = 0x64656176, // 'deav' + AAUDIO_PROP_LATENCY = 0x6c746e63, // 'ltnc' + AAUDIO_PROP_PHYS_FORMAT = 0x70667420 // 'pft ' +}; + +int aaudio_msg_read_base(struct aaudio_msg *msg, struct aaudio_msg_base *base); + +int aaudio_msg_read_start_io_response(struct aaudio_msg *msg); +int aaudio_msg_read_stop_io_response(struct aaudio_msg *msg); +int aaudio_msg_read_update_timestamp(struct aaudio_msg *msg, aaudio_device_id_t *devid, + u64 *timestamp, u64 *update_seed); +int aaudio_msg_read_get_property_response(struct aaudio_msg *msg, aaudio_object_id_t *obj, + struct aaudio_prop_addr *prop, void **data, u64 *data_size); +int aaudio_msg_read_set_property_response(struct aaudio_msg *msg, aaudio_object_id_t *obj); +int aaudio_msg_read_property_listener_response(struct aaudio_msg *msg,aaudio_object_id_t *obj, + struct aaudio_prop_addr *prop); +int aaudio_msg_read_property_changed(struct aaudio_msg *msg, aaudio_device_id_t *devid, aaudio_object_id_t *obj, + struct aaudio_prop_addr *prop); +int aaudio_msg_read_set_input_stream_address_ranges_response(struct aaudio_msg *msg); +int aaudio_msg_read_get_input_stream_list_response(struct aaudio_msg *msg, aaudio_object_id_t **str_l, u64 *str_cnt); +int aaudio_msg_read_get_output_stream_list_response(struct aaudio_msg *msg, aaudio_object_id_t **str_l, u64 *str_cnt); +int aaudio_msg_read_set_remote_access_response(struct aaudio_msg *msg); +int aaudio_msg_read_get_device_list_response(struct aaudio_msg *msg, aaudio_device_id_t **dev_l, u64 *dev_cnt); + +void aaudio_msg_write_start_io(struct aaudio_msg *msg, aaudio_device_id_t dev); +void aaudio_msg_write_stop_io(struct aaudio_msg *msg, aaudio_device_id_t dev); +void aaudio_msg_write_get_property(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size); +void aaudio_msg_write_set_property(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *data, u64 data_size, void *qualifier, u64 qualifier_size); +void aaudio_msg_write_property_listener(struct aaudio_msg *msg, aaudio_device_id_t dev, aaudio_object_id_t obj, + struct aaudio_prop_addr prop); +void aaudio_msg_write_set_input_stream_address_ranges(struct aaudio_msg *msg, aaudio_device_id_t devid); +void aaudio_msg_write_get_input_stream_list(struct aaudio_msg *msg, aaudio_device_id_t devid); +void aaudio_msg_write_get_output_stream_list(struct aaudio_msg *msg, aaudio_device_id_t devid); +void aaudio_msg_write_set_remote_access(struct aaudio_msg *msg, u64 mode); +void aaudio_msg_write_alive_notification(struct aaudio_msg *msg, u32 proto_ver, u32 msg_ver); +void aaudio_msg_write_update_timestamp_response(struct aaudio_msg *msg); +void aaudio_msg_write_get_device_list(struct aaudio_msg *msg); + + +int aaudio_cmd_start_io(struct aaudio_device *a, aaudio_device_id_t devid); +int aaudio_cmd_stop_io(struct aaudio_device *a, aaudio_device_id_t devid); +int aaudio_cmd_get_property(struct aaudio_device *a, struct aaudio_msg *buf, + aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void **data, u64 *data_size); +int aaudio_cmd_get_primitive_property(struct aaudio_device *a, + aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void *data, u64 data_size); +int aaudio_cmd_set_property(struct aaudio_device *a, aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop, void *qualifier, u64 qualifier_size, void *data, u64 data_size); +int aaudio_cmd_property_listener(struct aaudio_device *a, aaudio_device_id_t devid, aaudio_object_id_t obj, + struct aaudio_prop_addr prop); +int aaudio_cmd_set_input_stream_address_ranges(struct aaudio_device *a, aaudio_device_id_t devid); +int aaudio_cmd_get_input_stream_list(struct aaudio_device *a, struct aaudio_msg *buf, aaudio_device_id_t devid, + aaudio_object_id_t **str_l, u64 *str_cnt); +int aaudio_cmd_get_output_stream_list(struct aaudio_device *a, struct aaudio_msg *buf, aaudio_device_id_t devid, + aaudio_object_id_t **str_l, u64 *str_cnt); +int aaudio_cmd_set_remote_access(struct aaudio_device *a, u64 mode); +int aaudio_cmd_get_device_list(struct aaudio_device *a, struct aaudio_msg *buf, + aaudio_device_id_t **dev_l, u64 *dev_cnt); + + + +#endif //AAUDIO_PROTOCOL_H diff --git a/drivers/staging/apple-bce/audio/protocol_bce.c b/drivers/staging/apple-bce/audio/protocol_bce.c new file mode 100644 index 000000000000..28f2dfd44d67 --- /dev/null +++ b/drivers/staging/apple-bce/audio/protocol_bce.c @@ -0,0 +1,226 @@ +#include "protocol_bce.h" + +#include "audio.h" + +static void aaudio_bce_out_queue_completion(struct bce_queue_sq *sq); +static void aaudio_bce_in_queue_completion(struct bce_queue_sq *sq); +static int aaudio_bce_queue_init(struct aaudio_device *dev, struct aaudio_bce_queue *q, const char *name, int direction, + bce_sq_completion cfn); +void aaudio_bce_in_queue_submit_pending(struct aaudio_bce_queue *q, size_t count); + +int aaudio_bce_init(struct aaudio_device *dev) +{ + int status; + struct aaudio_bce *bce = &dev->bcem; + bce->cq = bce_create_cq(dev->bce, 0x80); + spin_lock_init(&bce->spinlock); + if (!bce->cq) + return -EINVAL; + if ((status = aaudio_bce_queue_init(dev, &bce->qout, "com.apple.BridgeAudio.IntelToARM", DMA_TO_DEVICE, + aaudio_bce_out_queue_completion))) { + return status; + } + if ((status = aaudio_bce_queue_init(dev, &bce->qin, "com.apple.BridgeAudio.ARMToIntel", DMA_FROM_DEVICE, + aaudio_bce_in_queue_completion))) { + return status; + } + aaudio_bce_in_queue_submit_pending(&bce->qin, bce->qin.el_count); + return 0; +} + +int aaudio_bce_queue_init(struct aaudio_device *dev, struct aaudio_bce_queue *q, const char *name, int direction, + bce_sq_completion cfn) +{ + q->cq = dev->bcem.cq; + q->el_size = AAUDIO_BCE_QUEUE_ELEMENT_SIZE; + q->el_count = AAUDIO_BCE_QUEUE_ELEMENT_COUNT; + /* NOTE: The Apple impl uses 0x80 as the queue size, however we use 21 (in fact 20) to simplify the impl */ + q->sq = bce_create_sq(dev->bce, q->cq, name, (u32) (q->el_count + 1), direction, cfn, dev); + if (!q->sq) + return -EINVAL; + + q->data = dma_alloc_coherent(&dev->bce->pci->dev, q->el_size * q->el_count, &q->dma_addr, GFP_KERNEL); + if (!q->data) { + bce_destroy_sq(dev->bce, q->sq); + return -EINVAL; + } + return 0; +} + +static void aaudio_send_create_tag(struct aaudio_bce *b, int *tagn, char tag[4]) +{ + char tag_zero[5]; + b->tag_num = (b->tag_num + 1) % AAUDIO_BCE_QUEUE_TAG_COUNT; + *tagn = b->tag_num; + snprintf(tag_zero, 5, "S%03d", b->tag_num); + *((u32 *) tag) = *((u32 *) tag_zero); +} + +int __aaudio_send_prepare(struct aaudio_bce *b, struct aaudio_send_ctx *ctx, char *tag) +{ + int status; + size_t index; + void *dptr; + struct aaudio_msg_header *header; + if ((status = bce_reserve_submission(b->qout.sq, &ctx->timeout))) + return status; + spin_lock_irqsave(&b->spinlock, ctx->irq_flags); + index = b->qout.data_tail; + dptr = (u8 *) b->qout.data + index * b->qout.el_size; + ctx->msg.data = dptr; + header = dptr; + if (tag) + *((u32 *) header->tag) = *((u32 *) tag); + else + aaudio_send_create_tag(b, &ctx->tag_n, header->tag); + return 0; +} + +void __aaudio_send(struct aaudio_bce *b, struct aaudio_send_ctx *ctx) +{ + struct bce_qe_submission *s = bce_next_submission(b->qout.sq); +#ifdef DEBUG + pr_debug("aaudio: Sending command data\n"); + print_hex_dump(KERN_DEBUG, "aaudio:OUT ", DUMP_PREFIX_NONE, 32, 1, ctx->msg.data, ctx->msg.size, true); +#endif + bce_set_submission_single(s, b->qout.dma_addr + (dma_addr_t) (ctx->msg.data - b->qout.data), ctx->msg.size); + bce_submit_to_device(b->qout.sq); + b->qout.data_tail = (b->qout.data_tail + 1) % b->qout.el_count; + spin_unlock_irqrestore(&b->spinlock, ctx->irq_flags); +} + +int __aaudio_send_cmd_sync(struct aaudio_bce *b, struct aaudio_send_ctx *ctx, struct aaudio_msg *reply) +{ + struct aaudio_bce_queue_entry ent; + DECLARE_COMPLETION_ONSTACK(cmpl); + ent.msg = reply; + ent.cmpl = &cmpl; + b->pending_entries[ctx->tag_n] = &ent; + __aaudio_send(b, ctx); /* unlocks the spinlock */ + ctx->timeout = wait_for_completion_timeout(&cmpl, ctx->timeout); + if (ctx->timeout == 0) { + /* Remove the pending queue entry; this will be normally handled by the completion route but + * during a timeout it won't */ + spin_lock_irqsave(&b->spinlock, ctx->irq_flags); + if (b->pending_entries[ctx->tag_n] == &ent) + b->pending_entries[ctx->tag_n] = NULL; + spin_unlock_irqrestore(&b->spinlock, ctx->irq_flags); + return -ETIMEDOUT; + } + return 0; +} + +static void aaudio_handle_reply(struct aaudio_bce *b, struct aaudio_msg *reply) +{ + const char *tag; + int tagn; + unsigned long irq_flags; + char tag_zero[5]; + struct aaudio_bce_queue_entry *entry; + + tag = ((struct aaudio_msg_header *) reply->data)->tag; + if (tag[0] != 'S') { + pr_err("aaudio_handle_reply: Unexpected tag: %.4s\n", tag); + return; + } + *((u32 *) tag_zero) = *((u32 *) tag); + tag_zero[4] = 0; + if (kstrtoint(&tag_zero[1], 10, &tagn)) { + pr_err("aaudio_handle_reply: Tag parse failed: %.4s\n", tag); + return; + } + + spin_lock_irqsave(&b->spinlock, irq_flags); + entry = b->pending_entries[tagn]; + if (entry) { + if (reply->size < entry->msg->size) + entry->msg->size = reply->size; + memcpy(entry->msg->data, reply->data, entry->msg->size); + complete(entry->cmpl); + + b->pending_entries[tagn] = NULL; + } else { + pr_err("aaudio_handle_reply: No queued item found for tag: %.4s\n", tag); + } + spin_unlock_irqrestore(&b->spinlock, irq_flags); +} + +static void aaudio_bce_out_queue_completion(struct bce_queue_sq *sq) +{ + while (bce_next_completion(sq)) { + //pr_info("aaudio: Send confirmed\n"); + bce_notify_submission_complete(sq); + } +} + +static void aaudio_bce_in_queue_handle_msg(struct aaudio_device *a, struct aaudio_msg *msg); + +static void aaudio_bce_in_queue_completion(struct bce_queue_sq *sq) +{ + struct aaudio_msg msg; + struct aaudio_device *dev = sq->userdata; + struct aaudio_bce_queue *q = &dev->bcem.qin; + struct bce_sq_completion_data *c; + size_t cnt = 0; + + mb(); + while ((c = bce_next_completion(sq))) { + msg.data = (u8 *) q->data + q->data_head * q->el_size; + msg.size = c->data_size; +#ifdef DEBUG + pr_debug("aaudio: Received command data %llx\n", c->data_size); + print_hex_dump(KERN_DEBUG, "aaudio:IN ", DUMP_PREFIX_NONE, 32, 1, msg.data, min(msg.size, 128UL), true); +#endif + aaudio_bce_in_queue_handle_msg(dev, &msg); + + q->data_head = (q->data_head + 1) % q->el_count; + + bce_notify_submission_complete(sq); + ++cnt; + } + aaudio_bce_in_queue_submit_pending(q, cnt); +} + +static void aaudio_bce_in_queue_handle_msg(struct aaudio_device *a, struct aaudio_msg *msg) +{ + struct aaudio_msg_header *header = (struct aaudio_msg_header *) msg->data; + if (msg->size < sizeof(struct aaudio_msg_header)) { + pr_err("aaudio: Msg size smaller than header (%lx)", msg->size); + return; + } + if (header->type == AAUDIO_MSG_TYPE_RESPONSE) { + aaudio_handle_reply(&a->bcem, msg); + } else if (header->type == AAUDIO_MSG_TYPE_COMMAND) { + aaudio_handle_command(a, msg); + } else if (header->type == AAUDIO_MSG_TYPE_NOTIFICATION) { + aaudio_handle_notification(a, msg); + } +} + +void aaudio_bce_in_queue_submit_pending(struct aaudio_bce_queue *q, size_t count) +{ + struct bce_qe_submission *s; + while (count--) { + if (bce_reserve_submission(q->sq, NULL)) { + pr_err("aaudio: Failed to reserve an event queue submission\n"); + break; + } + s = bce_next_submission(q->sq); + bce_set_submission_single(s, q->dma_addr + (dma_addr_t) (q->data_tail * q->el_size), q->el_size); + q->data_tail = (q->data_tail + 1) % q->el_count; + } + bce_submit_to_device(q->sq); +} + +struct aaudio_msg aaudio_reply_alloc(void) +{ + struct aaudio_msg ret; + ret.size = AAUDIO_BCE_QUEUE_ELEMENT_SIZE; + ret.data = kmalloc(ret.size, GFP_KERNEL); + return ret; +} + +void aaudio_reply_free(struct aaudio_msg *reply) +{ + kfree(reply->data); +} diff --git a/drivers/staging/apple-bce/audio/protocol_bce.h b/drivers/staging/apple-bce/audio/protocol_bce.h new file mode 100644 index 000000000000..14d26c05ddf9 --- /dev/null +++ b/drivers/staging/apple-bce/audio/protocol_bce.h @@ -0,0 +1,72 @@ +#ifndef AAUDIO_PROTOCOL_BCE_H +#define AAUDIO_PROTOCOL_BCE_H + +#include "protocol.h" +#include "../queue.h" + +#define AAUDIO_BCE_QUEUE_ELEMENT_SIZE 0x1000 +#define AAUDIO_BCE_QUEUE_ELEMENT_COUNT 20 + +#define AAUDIO_BCE_QUEUE_TAG_COUNT 1000 + +struct aaudio_device; + +struct aaudio_bce_queue_entry { + struct aaudio_msg *msg; + struct completion *cmpl; +}; +struct aaudio_bce_queue { + struct bce_queue_cq *cq; + struct bce_queue_sq *sq; + void *data; + dma_addr_t dma_addr; + size_t data_head, data_tail; + size_t el_size, el_count; +}; +struct aaudio_bce { + struct bce_queue_cq *cq; + struct aaudio_bce_queue qin; + struct aaudio_bce_queue qout; + int tag_num; + struct aaudio_bce_queue_entry *pending_entries[AAUDIO_BCE_QUEUE_TAG_COUNT]; + struct spinlock spinlock; +}; + +struct aaudio_send_ctx { + int status; + int tag_n; + unsigned long irq_flags; + struct aaudio_msg msg; + unsigned long timeout; +}; + +int aaudio_bce_init(struct aaudio_device *dev); +int __aaudio_send_prepare(struct aaudio_bce *b, struct aaudio_send_ctx *ctx, char *tag); +void __aaudio_send(struct aaudio_bce *b, struct aaudio_send_ctx *ctx); +int __aaudio_send_cmd_sync(struct aaudio_bce *b, struct aaudio_send_ctx *ctx, struct aaudio_msg *reply); + +#define aaudio_send_with_tag(a, ctx, tag, tout, fn, ...) ({ \ + (ctx)->timeout = msecs_to_jiffies(tout); \ + (ctx)->status = __aaudio_send_prepare(&(a)->bcem, (ctx), (tag)); \ + if (!(ctx)->status) { \ + fn(&(ctx)->msg, ##__VA_ARGS__); \ + __aaudio_send(&(a)->bcem, (ctx)); \ + } \ + (ctx)->status; \ +}) +#define aaudio_send(a, ctx, tout, fn, ...) aaudio_send_with_tag(a, ctx, NULL, tout, fn, ##__VA_ARGS__) + +#define aaudio_send_cmd_sync(a, ctx, reply, tout, fn, ...) ({ \ + (ctx)->timeout = msecs_to_jiffies(tout); \ + (ctx)->status = __aaudio_send_prepare(&(a)->bcem, (ctx), NULL); \ + if (!(ctx)->status) { \ + fn(&(ctx)->msg, ##__VA_ARGS__); \ + (ctx)->status = __aaudio_send_cmd_sync(&(a)->bcem, (ctx), (reply)); \ + } \ + (ctx)->status; \ +}) + +struct aaudio_msg aaudio_reply_alloc(void); +void aaudio_reply_free(struct aaudio_msg *reply); + +#endif //AAUDIO_PROTOCOL_BCE_H diff --git a/drivers/staging/apple-bce/mailbox.c b/drivers/staging/apple-bce/mailbox.c new file mode 100644 index 000000000000..e24bd35215c0 --- /dev/null +++ b/drivers/staging/apple-bce/mailbox.c @@ -0,0 +1,151 @@ +#include "mailbox.h" +#include +#include "apple_bce.h" + +#define REG_MBOX_OUT_BASE 0x820 +#define REG_MBOX_REPLY_COUNTER 0x108 +#define REG_MBOX_REPLY_BASE 0x810 +#define REG_TIMESTAMP_BASE 0xC000 + +#define BCE_MBOX_TIMEOUT_MS 200 + +void bce_mailbox_init(struct bce_mailbox *mb, void __iomem *reg_mb) +{ + mb->reg_mb = reg_mb; + init_completion(&mb->mb_completion); +} + +int bce_mailbox_send(struct bce_mailbox *mb, u64 msg, u64* recv) +{ + u32 __iomem *regb; + + if (atomic_cmpxchg(&mb->mb_status, 0, 1) != 0) { + return -EEXIST; // We don't support two messages at once + } + reinit_completion(&mb->mb_completion); + + pr_debug("bce_mailbox_send: %llx\n", msg); + regb = (u32*) ((u8*) mb->reg_mb + REG_MBOX_OUT_BASE); + iowrite32((u32) msg, regb); + iowrite32((u32) (msg >> 32), regb + 1); + iowrite32(0, regb + 2); + iowrite32(0, regb + 3); + + wait_for_completion_timeout(&mb->mb_completion, msecs_to_jiffies(BCE_MBOX_TIMEOUT_MS)); + if (atomic_read(&mb->mb_status) != 2) { // Didn't get the reply + atomic_set(&mb->mb_status, 0); + return -ETIMEDOUT; + } + + *recv = mb->mb_result; + pr_debug("bce_mailbox_send: reply %llx\n", *recv); + + atomic_set(&mb->mb_status, 0); + return 0; +} + +static int bce_mailbox_retrive_response(struct bce_mailbox *mb) +{ + u32 __iomem *regb; + u32 lo, hi; + int count, counter; + u32 res = ioread32((u8*) mb->reg_mb + REG_MBOX_REPLY_COUNTER); + count = (res >> 20) & 0xf; + counter = count; + pr_debug("bce_mailbox_retrive_response count=%i\n", count); + while (counter--) { + regb = (u32*) ((u8*) mb->reg_mb + REG_MBOX_REPLY_BASE); + lo = ioread32(regb); + hi = ioread32(regb + 1); + ioread32(regb + 2); + ioread32(regb + 3); + pr_debug("bce_mailbox_retrive_response %llx\n", ((u64) hi << 32) | lo); + mb->mb_result = ((u64) hi << 32) | lo; + } + return count > 0 ? 0 : -ENODATA; +} + +int bce_mailbox_handle_interrupt(struct bce_mailbox *mb) +{ + int status = bce_mailbox_retrive_response(mb); + if (!status) { + atomic_set(&mb->mb_status, 2); + complete(&mb->mb_completion); + } + return status; +} + +static void bc_send_timestamp(struct timer_list *tl); + +void bce_timestamp_init(struct bce_timestamp *ts, void __iomem *reg) +{ + u32 __iomem *regb; + + spin_lock_init(&ts->stop_sl); + ts->stopped = false; + + ts->reg = reg; + + regb = (u32*) ((u8*) ts->reg + REG_TIMESTAMP_BASE); + + ioread32(regb); + mb(); + + timer_setup(&ts->timer, bc_send_timestamp, 0); +} + +void bce_timestamp_start(struct bce_timestamp *ts, bool is_initial) +{ + unsigned long flags; + u32 __iomem *regb = (u32*) ((u8*) ts->reg + REG_TIMESTAMP_BASE); + + if (is_initial) { + iowrite32((u32) -4, regb + 2); + iowrite32((u32) -1, regb); + } else { + iowrite32((u32) -3, regb + 2); + iowrite32((u32) -1, regb); + } + + spin_lock_irqsave(&ts->stop_sl, flags); + ts->stopped = false; + spin_unlock_irqrestore(&ts->stop_sl, flags); + mod_timer(&ts->timer, jiffies + msecs_to_jiffies(150)); +} + +void bce_timestamp_stop(struct bce_timestamp *ts) +{ + unsigned long flags; + u32 __iomem *regb = (u32*) ((u8*) ts->reg + REG_TIMESTAMP_BASE); + + spin_lock_irqsave(&ts->stop_sl, flags); + ts->stopped = true; + spin_unlock_irqrestore(&ts->stop_sl, flags); + del_timer_sync(&ts->timer); + + iowrite32((u32) -2, regb + 2); + iowrite32((u32) -1, regb); +} + +static void bc_send_timestamp(struct timer_list *tl) +{ + struct bce_timestamp *ts; + unsigned long flags; + u32 __iomem *regb; + ktime_t bt; + + ts = container_of(tl, struct bce_timestamp, timer); + regb = (u32*) ((u8*) ts->reg + REG_TIMESTAMP_BASE); + local_irq_save(flags); + ioread32(regb + 2); + mb(); + bt = ktime_get_boottime(); + iowrite32((u32) bt, regb + 2); + iowrite32((u32) (bt >> 32), regb); + + spin_lock(&ts->stop_sl); + if (!ts->stopped) + mod_timer(&ts->timer, jiffies + msecs_to_jiffies(150)); + spin_unlock(&ts->stop_sl); + local_irq_restore(flags); +} \ No newline at end of file diff --git a/drivers/staging/apple-bce/mailbox.h b/drivers/staging/apple-bce/mailbox.h new file mode 100644 index 000000000000..f3323f95ba51 --- /dev/null +++ b/drivers/staging/apple-bce/mailbox.h @@ -0,0 +1,53 @@ +#ifndef BCE_MAILBOX_H +#define BCE_MAILBOX_H + +#include +#include +#include + +struct bce_mailbox { + void __iomem *reg_mb; + + atomic_t mb_status; // possible statuses: 0 (no msg), 1 (has active msg), 2 (got reply) + struct completion mb_completion; + uint64_t mb_result; +}; + +enum bce_message_type { + BCE_MB_REGISTER_COMMAND_SQ = 0x7, // to-device + BCE_MB_REGISTER_COMMAND_CQ = 0x8, // to-device + BCE_MB_REGISTER_COMMAND_QUEUE_REPLY = 0xB, // to-host + BCE_MB_SET_FW_PROTOCOL_VERSION = 0xC, // both + BCE_MB_SLEEP_NO_STATE = 0x14, // to-device + BCE_MB_RESTORE_NO_STATE = 0x15, // to-device + BCE_MB_SAVE_STATE_AND_SLEEP = 0x17, // to-device + BCE_MB_RESTORE_STATE_AND_WAKE = 0x18, // to-device + BCE_MB_SAVE_STATE_AND_SLEEP_FAILURE = 0x19, // from-device + BCE_MB_SAVE_RESTORE_STATE_COMPLETE = 0x1A, // from-device +}; + +#define BCE_MB_MSG(type, value) (((u64) (type) << 58) | ((value) & 0x3FFFFFFFFFFFFFFLL)) +#define BCE_MB_TYPE(v) ((u32) (v >> 58)) +#define BCE_MB_VALUE(v) (v & 0x3FFFFFFFFFFFFFFLL) + +void bce_mailbox_init(struct bce_mailbox *mb, void __iomem *reg_mb); + +int bce_mailbox_send(struct bce_mailbox *mb, u64 msg, u64* recv); + +int bce_mailbox_handle_interrupt(struct bce_mailbox *mb); + + +struct bce_timestamp { + void __iomem *reg; + struct timer_list timer; + struct spinlock stop_sl; + bool stopped; +}; + +void bce_timestamp_init(struct bce_timestamp *ts, void __iomem *reg); + +void bce_timestamp_start(struct bce_timestamp *ts, bool is_initial); + +void bce_timestamp_stop(struct bce_timestamp *ts); + +#endif //BCEDRIVER_MAILBOX_H diff --git a/drivers/staging/apple-bce/queue.c b/drivers/staging/apple-bce/queue.c new file mode 100644 index 000000000000..bc9cd3bc6f0c --- /dev/null +++ b/drivers/staging/apple-bce/queue.c @@ -0,0 +1,390 @@ +#include "queue.h" +#include "apple_bce.h" + +#define REG_DOORBELL_BASE 0x44000 + +struct bce_queue_cq *bce_alloc_cq(struct apple_bce_device *dev, int qid, u32 el_count) +{ + struct bce_queue_cq *q; + q = kzalloc(sizeof(struct bce_queue_cq), GFP_KERNEL); + q->qid = qid; + q->type = BCE_QUEUE_CQ; + q->el_count = el_count; + q->data = dma_alloc_coherent(&dev->pci->dev, el_count * sizeof(struct bce_qe_completion), + &q->dma_handle, GFP_KERNEL); + if (!q->data) { + pr_err("DMA queue memory alloc failed\n"); + kfree(q); + return NULL; + } + return q; +} + +void bce_get_cq_memcfg(struct bce_queue_cq *cq, struct bce_queue_memcfg *cfg) +{ + cfg->qid = (u16) cq->qid; + cfg->el_count = (u16) cq->el_count; + cfg->vector_or_cq = 0; + cfg->_pad = 0; + cfg->addr = cq->dma_handle; + cfg->length = cq->el_count * sizeof(struct bce_qe_completion); +} + +void bce_free_cq(struct apple_bce_device *dev, struct bce_queue_cq *cq) +{ + dma_free_coherent(&dev->pci->dev, cq->el_count * sizeof(struct bce_qe_completion), cq->data, cq->dma_handle); + kfree(cq); +} + +static void bce_handle_cq_completion(struct apple_bce_device *dev, struct bce_qe_completion *e, size_t *ce) +{ + struct bce_queue *target; + struct bce_queue_sq *target_sq; + struct bce_sq_completion_data *cmpl; + if (e->qid >= BCE_MAX_QUEUE_COUNT) { + pr_err("Device sent a response for qid (%u) >= BCE_MAX_QUEUE_COUNT\n", e->qid); + return; + } + target = dev->queues[e->qid]; + if (!target || target->type != BCE_QUEUE_SQ) { + pr_err("Device sent a response for qid (%u), which does not exist\n", e->qid); + return; + } + target_sq = (struct bce_queue_sq *) target; + if (target_sq->completion_tail != e->completion_index) { + pr_err("Completion index mismatch; this is likely going to make this driver unusable\n"); + return; + } + if (!target_sq->has_pending_completions) { + target_sq->has_pending_completions = true; + dev->int_sq_list[(*ce)++] = target_sq; + } + cmpl = &target_sq->completion_data[e->completion_index]; + cmpl->status = e->status; + cmpl->data_size = e->data_size; + cmpl->result = e->result; + wmb(); + target_sq->completion_tail = (target_sq->completion_tail + 1) % target_sq->el_count; +} + +void bce_handle_cq_completions(struct apple_bce_device *dev, struct bce_queue_cq *cq) +{ + size_t ce = 0; + struct bce_qe_completion *e; + struct bce_queue_sq *sq; + e = bce_cq_element(cq, cq->index); + if (!(e->flags & BCE_COMPLETION_FLAG_PENDING)) + return; + mb(); + while (true) { + e = bce_cq_element(cq, cq->index); + if (!(e->flags & BCE_COMPLETION_FLAG_PENDING)) + break; + // pr_info("apple-bce: compl: %i: %i %llx %llx", e->qid, e->status, e->data_size, e->result); + bce_handle_cq_completion(dev, e, &ce); + e->flags = 0; + cq->index = (cq->index + 1) % cq->el_count; + } + mb(); + iowrite32(cq->index, (u32 *) ((u8 *) dev->reg_mem_dma + REG_DOORBELL_BASE) + cq->qid); + while (ce) { + --ce; + sq = dev->int_sq_list[ce]; + sq->completion(sq); + sq->has_pending_completions = false; + } +} + + +struct bce_queue_sq *bce_alloc_sq(struct apple_bce_device *dev, int qid, u32 el_size, u32 el_count, + bce_sq_completion compl, void *userdata) +{ + struct bce_queue_sq *q; + q = kzalloc(sizeof(struct bce_queue_sq), GFP_KERNEL); + q->qid = qid; + q->type = BCE_QUEUE_SQ; + q->el_size = el_size; + q->el_count = el_count; + q->data = dma_alloc_coherent(&dev->pci->dev, el_count * el_size, + &q->dma_handle, GFP_KERNEL); + q->completion = compl; + q->userdata = userdata; + q->completion_data = kzalloc(sizeof(struct bce_sq_completion_data) * el_count, GFP_KERNEL); + q->reg_mem_dma = dev->reg_mem_dma; + atomic_set(&q->available_commands, el_count - 1); + init_completion(&q->available_command_completion); + atomic_set(&q->available_command_completion_waiting_count, 0); + if (!q->data) { + pr_err("DMA queue memory alloc failed\n"); + kfree(q); + return NULL; + } + return q; +} + +void bce_get_sq_memcfg(struct bce_queue_sq *sq, struct bce_queue_cq *cq, struct bce_queue_memcfg *cfg) +{ + cfg->qid = (u16) sq->qid; + cfg->el_count = (u16) sq->el_count; + cfg->vector_or_cq = (u16) cq->qid; + cfg->_pad = 0; + cfg->addr = sq->dma_handle; + cfg->length = sq->el_count * sq->el_size; +} + +void bce_free_sq(struct apple_bce_device *dev, struct bce_queue_sq *sq) +{ + dma_free_coherent(&dev->pci->dev, sq->el_count * sq->el_size, sq->data, sq->dma_handle); + kfree(sq); +} + +int bce_reserve_submission(struct bce_queue_sq *sq, unsigned long *timeout) +{ + while (atomic_dec_if_positive(&sq->available_commands) < 0) { + if (!timeout || !*timeout) + return -EAGAIN; + atomic_inc(&sq->available_command_completion_waiting_count); + *timeout = wait_for_completion_timeout(&sq->available_command_completion, *timeout); + if (!*timeout) { + if (atomic_dec_if_positive(&sq->available_command_completion_waiting_count) < 0) + try_wait_for_completion(&sq->available_command_completion); /* consume the pending completion */ + } + } + return 0; +} + +void bce_cancel_submission_reservation(struct bce_queue_sq *sq) +{ + atomic_inc(&sq->available_commands); +} + +void *bce_next_submission(struct bce_queue_sq *sq) +{ + void *ret = bce_sq_element(sq, sq->tail); + sq->tail = (sq->tail + 1) % sq->el_count; + return ret; +} + +void bce_submit_to_device(struct bce_queue_sq *sq) +{ + mb(); + iowrite32(sq->tail, (u32 *) ((u8 *) sq->reg_mem_dma + REG_DOORBELL_BASE) + sq->qid); +} + +void bce_notify_submission_complete(struct bce_queue_sq *sq) +{ + sq->head = (sq->head + 1) % sq->el_count; + atomic_inc(&sq->available_commands); + if (atomic_dec_if_positive(&sq->available_command_completion_waiting_count) >= 0) { + complete(&sq->available_command_completion); + } +} + +void bce_set_submission_single(struct bce_qe_submission *element, dma_addr_t addr, size_t size) +{ + element->addr = addr; + element->length = size; + element->segl_addr = element->segl_length = 0; +} + +static void bce_cmdq_completion(struct bce_queue_sq *q); + +struct bce_queue_cmdq *bce_alloc_cmdq(struct apple_bce_device *dev, int qid, u32 el_count) +{ + struct bce_queue_cmdq *q; + q = kzalloc(sizeof(struct bce_queue_cmdq), GFP_KERNEL); + q->sq = bce_alloc_sq(dev, qid, BCE_CMD_SIZE, el_count, bce_cmdq_completion, q); + if (!q->sq) { + kfree(q); + return NULL; + } + spin_lock_init(&q->lck); + q->tres = kzalloc(sizeof(struct bce_queue_cmdq_result_el*) * el_count, GFP_KERNEL); + if (!q->tres) { + kfree(q); + return NULL; + } + return q; +} + +void bce_free_cmdq(struct apple_bce_device *dev, struct bce_queue_cmdq *cmdq) +{ + bce_free_sq(dev, cmdq->sq); + kfree(cmdq->tres); + kfree(cmdq); +} + +void bce_cmdq_completion(struct bce_queue_sq *q) +{ + struct bce_queue_cmdq_result_el *el; + struct bce_queue_cmdq *cmdq = q->userdata; + struct bce_sq_completion_data *result; + + spin_lock(&cmdq->lck); + while ((result = bce_next_completion(q))) { + el = cmdq->tres[cmdq->sq->head]; + if (el) { + el->result = result->result; + el->status = result->status; + mb(); + complete(&el->cmpl); + } else { + pr_err("apple-bce: Unexpected command queue completion\n"); + } + cmdq->tres[cmdq->sq->head] = NULL; + bce_notify_submission_complete(q); + } + spin_unlock(&cmdq->lck); +} + +static __always_inline void *bce_cmd_start(struct bce_queue_cmdq *cmdq, struct bce_queue_cmdq_result_el *res) +{ + void *ret; + unsigned long timeout; + init_completion(&res->cmpl); + mb(); + + timeout = msecs_to_jiffies(1000L * 60 * 5); /* wait for up to ~5 minutes */ + if (bce_reserve_submission(cmdq->sq, &timeout)) + return NULL; + + spin_lock(&cmdq->lck); + cmdq->tres[cmdq->sq->tail] = res; + ret = bce_next_submission(cmdq->sq); + return ret; +} + +static __always_inline void bce_cmd_finish(struct bce_queue_cmdq *cmdq, struct bce_queue_cmdq_result_el *res) +{ + bce_submit_to_device(cmdq->sq); + spin_unlock(&cmdq->lck); + + wait_for_completion(&res->cmpl); + mb(); +} + +u32 bce_cmd_register_queue(struct bce_queue_cmdq *cmdq, struct bce_queue_memcfg *cfg, const char *name, bool isdirout) +{ + struct bce_queue_cmdq_result_el res; + struct bce_cmdq_register_memory_queue_cmd *cmd = bce_cmd_start(cmdq, &res); + if (!cmd) + return (u32) -1; + cmd->cmd = BCE_CMD_REGISTER_MEMORY_QUEUE; + cmd->flags = (u16) ((name ? 2 : 0) | (isdirout ? 1 : 0)); + cmd->qid = cfg->qid; + cmd->el_count = cfg->el_count; + cmd->vector_or_cq = cfg->vector_or_cq; + memset(cmd->name, 0, sizeof(cmd->name)); + if (name) { + cmd->name_len = (u16) min(strlen(name), (size_t) sizeof(cmd->name)); + memcpy(cmd->name, name, cmd->name_len); + } else { + cmd->name_len = 0; + } + cmd->addr = cfg->addr; + cmd->length = cfg->length; + + bce_cmd_finish(cmdq, &res); + return res.status; +} + +u32 bce_cmd_unregister_memory_queue(struct bce_queue_cmdq *cmdq, u16 qid) +{ + struct bce_queue_cmdq_result_el res; + struct bce_cmdq_simple_memory_queue_cmd *cmd = bce_cmd_start(cmdq, &res); + if (!cmd) + return (u32) -1; + cmd->cmd = BCE_CMD_UNREGISTER_MEMORY_QUEUE; + cmd->flags = 0; + cmd->qid = qid; + bce_cmd_finish(cmdq, &res); + return res.status; +} + +u32 bce_cmd_flush_memory_queue(struct bce_queue_cmdq *cmdq, u16 qid) +{ + struct bce_queue_cmdq_result_el res; + struct bce_cmdq_simple_memory_queue_cmd *cmd = bce_cmd_start(cmdq, &res); + if (!cmd) + return (u32) -1; + cmd->cmd = BCE_CMD_FLUSH_MEMORY_QUEUE; + cmd->flags = 0; + cmd->qid = qid; + bce_cmd_finish(cmdq, &res); + return res.status; +} + + +struct bce_queue_cq *bce_create_cq(struct apple_bce_device *dev, u32 el_count) +{ + struct bce_queue_cq *cq; + struct bce_queue_memcfg cfg; + int qid = ida_simple_get(&dev->queue_ida, BCE_QUEUE_USER_MIN, BCE_QUEUE_USER_MAX, GFP_KERNEL); + if (qid < 0) + return NULL; + cq = bce_alloc_cq(dev, qid, el_count); + if (!cq) + return NULL; + bce_get_cq_memcfg(cq, &cfg); + if (bce_cmd_register_queue(dev->cmd_cmdq, &cfg, NULL, false) != 0) { + pr_err("apple-bce: CQ registration failed (%i)", qid); + bce_free_cq(dev, cq); + ida_simple_remove(&dev->queue_ida, (uint) qid); + return NULL; + } + dev->queues[qid] = (struct bce_queue *) cq; + return cq; +} + +struct bce_queue_sq *bce_create_sq(struct apple_bce_device *dev, struct bce_queue_cq *cq, const char *name, u32 el_count, + int direction, bce_sq_completion compl, void *userdata) +{ + struct bce_queue_sq *sq; + struct bce_queue_memcfg cfg; + int qid; + if (cq == NULL) + return NULL; /* cq can not be null */ + if (name == NULL) + return NULL; /* name can not be null */ + if (direction != DMA_TO_DEVICE && direction != DMA_FROM_DEVICE) + return NULL; /* unsupported direction */ + qid = ida_simple_get(&dev->queue_ida, BCE_QUEUE_USER_MIN, BCE_QUEUE_USER_MAX, GFP_KERNEL); + if (qid < 0) + return NULL; + sq = bce_alloc_sq(dev, qid, sizeof(struct bce_qe_submission), el_count, compl, userdata); + if (!sq) + return NULL; + bce_get_sq_memcfg(sq, cq, &cfg); + if (bce_cmd_register_queue(dev->cmd_cmdq, &cfg, name, direction != DMA_FROM_DEVICE) != 0) { + pr_err("apple-bce: SQ registration failed (%i)", qid); + bce_free_sq(dev, sq); + ida_simple_remove(&dev->queue_ida, (uint) qid); + return NULL; + } + spin_lock(&dev->queues_lock); + dev->queues[qid] = (struct bce_queue *) sq; + spin_unlock(&dev->queues_lock); + return sq; +} + +void bce_destroy_cq(struct apple_bce_device *dev, struct bce_queue_cq *cq) +{ + if (!dev->is_being_removed && bce_cmd_unregister_memory_queue(dev->cmd_cmdq, (u16) cq->qid)) + pr_err("apple-bce: CQ unregister failed"); + spin_lock(&dev->queues_lock); + dev->queues[cq->qid] = NULL; + spin_unlock(&dev->queues_lock); + ida_simple_remove(&dev->queue_ida, (uint) cq->qid); + bce_free_cq(dev, cq); +} + +void bce_destroy_sq(struct apple_bce_device *dev, struct bce_queue_sq *sq) +{ + if (!dev->is_being_removed && bce_cmd_unregister_memory_queue(dev->cmd_cmdq, (u16) sq->qid)) + pr_err("apple-bce: CQ unregister failed"); + spin_lock(&dev->queues_lock); + dev->queues[sq->qid] = NULL; + spin_unlock(&dev->queues_lock); + ida_simple_remove(&dev->queue_ida, (uint) sq->qid); + bce_free_sq(dev, sq); +} \ No newline at end of file diff --git a/drivers/staging/apple-bce/queue.h b/drivers/staging/apple-bce/queue.h new file mode 100644 index 000000000000..8368ac5dfca8 --- /dev/null +++ b/drivers/staging/apple-bce/queue.h @@ -0,0 +1,177 @@ +#ifndef BCE_QUEUE_H +#define BCE_QUEUE_H + +#include +#include + +#define BCE_CMD_SIZE 0x40 + +struct apple_bce_device; + +enum bce_queue_type { + BCE_QUEUE_CQ, BCE_QUEUE_SQ +}; +struct bce_queue { + int qid; + int type; +}; +struct bce_queue_cq { + int qid; + int type; + u32 el_count; + dma_addr_t dma_handle; + void *data; + + u32 index; +}; +struct bce_queue_sq; +typedef void (*bce_sq_completion)(struct bce_queue_sq *q); +struct bce_sq_completion_data { + u32 status; + u64 data_size; + u64 result; +}; +struct bce_queue_sq { + int qid; + int type; + u32 el_size; + u32 el_count; + dma_addr_t dma_handle; + void *data; + void *userdata; + void __iomem *reg_mem_dma; + + atomic_t available_commands; + struct completion available_command_completion; + atomic_t available_command_completion_waiting_count; + u32 head, tail; + + u32 completion_cidx, completion_tail; + struct bce_sq_completion_data *completion_data; + bool has_pending_completions; + bce_sq_completion completion; +}; + +struct bce_queue_cmdq_result_el { + struct completion cmpl; + u32 status; + u64 result; +}; +struct bce_queue_cmdq { + struct bce_queue_sq *sq; + struct spinlock lck; + struct bce_queue_cmdq_result_el **tres; +}; + +struct bce_queue_memcfg { + u16 qid; + u16 el_count; + u16 vector_or_cq; + u16 _pad; + u64 addr; + u64 length; +}; + +enum bce_qe_completion_status { + BCE_COMPLETION_SUCCESS = 0, + BCE_COMPLETION_ERROR = 1, + BCE_COMPLETION_ABORTED = 2, + BCE_COMPLETION_NO_SPACE = 3, + BCE_COMPLETION_OVERRUN = 4 +}; +enum bce_qe_completion_flags { + BCE_COMPLETION_FLAG_PENDING = 0x8000 +}; +struct bce_qe_completion { + u64 result; + u64 data_size; + u16 qid; + u16 completion_index; + u16 status; // bce_qe_completion_status + u16 flags; // bce_qe_completion_flags +}; + +struct bce_qe_submission { + u64 length; + u64 addr; + + u64 segl_addr; + u64 segl_length; +}; + +enum bce_cmdq_command { + BCE_CMD_REGISTER_MEMORY_QUEUE = 0x20, + BCE_CMD_UNREGISTER_MEMORY_QUEUE = 0x30, + BCE_CMD_FLUSH_MEMORY_QUEUE = 0x40, + BCE_CMD_SET_MEMORY_QUEUE_PROPERTY = 0x50 +}; +struct bce_cmdq_simple_memory_queue_cmd { + u16 cmd; // bce_cmdq_command + u16 flags; + u16 qid; +}; +struct bce_cmdq_register_memory_queue_cmd { + u16 cmd; // bce_cmdq_command + u16 flags; + u16 qid; + u16 _pad; + u16 el_count; + u16 vector_or_cq; + u16 _pad2; + u16 name_len; + char name[0x20]; + u64 addr; + u64 length; +}; + +static __always_inline void *bce_sq_element(struct bce_queue_sq *q, int i) { + return (void *) ((u8 *) q->data + q->el_size * i); +} +static __always_inline void *bce_cq_element(struct bce_queue_cq *q, int i) { + return (void *) ((struct bce_qe_completion *) q->data + i); +} + +static __always_inline struct bce_sq_completion_data *bce_next_completion(struct bce_queue_sq *sq) { + struct bce_sq_completion_data *res; + rmb(); + if (sq->completion_cidx == sq->completion_tail) + return NULL; + res = &sq->completion_data[sq->completion_cidx]; + sq->completion_cidx = (sq->completion_cidx + 1) % sq->el_count; + return res; +} + +struct bce_queue_cq *bce_alloc_cq(struct apple_bce_device *dev, int qid, u32 el_count); +void bce_get_cq_memcfg(struct bce_queue_cq *cq, struct bce_queue_memcfg *cfg); +void bce_free_cq(struct apple_bce_device *dev, struct bce_queue_cq *cq); +void bce_handle_cq_completions(struct apple_bce_device *dev, struct bce_queue_cq *cq); + +struct bce_queue_sq *bce_alloc_sq(struct apple_bce_device *dev, int qid, u32 el_size, u32 el_count, + bce_sq_completion compl, void *userdata); +void bce_get_sq_memcfg(struct bce_queue_sq *sq, struct bce_queue_cq *cq, struct bce_queue_memcfg *cfg); +void bce_free_sq(struct apple_bce_device *dev, struct bce_queue_sq *sq); +int bce_reserve_submission(struct bce_queue_sq *sq, unsigned long *timeout); +void bce_cancel_submission_reservation(struct bce_queue_sq *sq); +void *bce_next_submission(struct bce_queue_sq *sq); +void bce_submit_to_device(struct bce_queue_sq *sq); +void bce_notify_submission_complete(struct bce_queue_sq *sq); + +void bce_set_submission_single(struct bce_qe_submission *element, dma_addr_t addr, size_t size); + +struct bce_queue_cmdq *bce_alloc_cmdq(struct apple_bce_device *dev, int qid, u32 el_count); +void bce_free_cmdq(struct apple_bce_device *dev, struct bce_queue_cmdq *cmdq); + +u32 bce_cmd_register_queue(struct bce_queue_cmdq *cmdq, struct bce_queue_memcfg *cfg, const char *name, bool isdirout); +u32 bce_cmd_unregister_memory_queue(struct bce_queue_cmdq *cmdq, u16 qid); +u32 bce_cmd_flush_memory_queue(struct bce_queue_cmdq *cmdq, u16 qid); + + +/* User API - Creates and registers the queue */ + +struct bce_queue_cq *bce_create_cq(struct apple_bce_device *dev, u32 el_count); +struct bce_queue_sq *bce_create_sq(struct apple_bce_device *dev, struct bce_queue_cq *cq, const char *name, u32 el_count, + int direction, bce_sq_completion compl, void *userdata); +void bce_destroy_cq(struct apple_bce_device *dev, struct bce_queue_cq *cq); +void bce_destroy_sq(struct apple_bce_device *dev, struct bce_queue_sq *sq); + +#endif //BCEDRIVER_MAILBOX_H diff --git a/drivers/staging/apple-bce/queue_dma.c b/drivers/staging/apple-bce/queue_dma.c new file mode 100644 index 000000000000..b236613285c0 --- /dev/null +++ b/drivers/staging/apple-bce/queue_dma.c @@ -0,0 +1,220 @@ +#include "queue_dma.h" +#include +#include +#include "queue.h" + +static int bce_alloc_scatterlist_from_vm(struct sg_table *tbl, void *data, size_t len); +static struct bce_segment_list_element_hostinfo *bce_map_segment_list( + struct device *dev, struct scatterlist *pages, int pagen); +static void bce_unmap_segement_list(struct device *dev, struct bce_segment_list_element_hostinfo *list); + +int bce_map_dma_buffer(struct device *dev, struct bce_dma_buffer *buf, struct sg_table scatterlist, + enum dma_data_direction dir) +{ + int cnt; + + buf->direction = dir; + buf->scatterlist = scatterlist; + buf->seglist_hostinfo = NULL; + + cnt = dma_map_sg(dev, buf->scatterlist.sgl, buf->scatterlist.nents, dir); + if (cnt != buf->scatterlist.nents) { + pr_err("apple-bce: DMA scatter list mapping returned an unexpected count: %i\n", cnt); + dma_unmap_sg(dev, buf->scatterlist.sgl, buf->scatterlist.nents, dir); + return -EIO; + } + if (cnt == 1) + return 0; + + buf->seglist_hostinfo = bce_map_segment_list(dev, buf->scatterlist.sgl, buf->scatterlist.nents); + if (!buf->seglist_hostinfo) { + pr_err("apple-bce: Creating segment list failed\n"); + dma_unmap_sg(dev, buf->scatterlist.sgl, buf->scatterlist.nents, dir); + return -EIO; + } + return 0; +} + +int bce_map_dma_buffer_vm(struct device *dev, struct bce_dma_buffer *buf, void *data, size_t len, + enum dma_data_direction dir) +{ + int status; + struct sg_table scatterlist; + if ((status = bce_alloc_scatterlist_from_vm(&scatterlist, data, len))) + return status; + if ((status = bce_map_dma_buffer(dev, buf, scatterlist, dir))) { + sg_free_table(&scatterlist); + return status; + } + return 0; +} + +int bce_map_dma_buffer_km(struct device *dev, struct bce_dma_buffer *buf, void *data, size_t len, + enum dma_data_direction dir) +{ + /* Kernel memory is continuous which is great for us. */ + int status; + struct sg_table scatterlist; + if ((status = sg_alloc_table(&scatterlist, 1, GFP_KERNEL))) { + sg_free_table(&scatterlist); + return status; + } + sg_set_buf(scatterlist.sgl, data, (uint) len); + if ((status = bce_map_dma_buffer(dev, buf, scatterlist, dir))) { + sg_free_table(&scatterlist); + return status; + } + return 0; +} + +void bce_unmap_dma_buffer(struct device *dev, struct bce_dma_buffer *buf) +{ + dma_unmap_sg(dev, buf->scatterlist.sgl, buf->scatterlist.nents, buf->direction); + bce_unmap_segement_list(dev, buf->seglist_hostinfo); +} + + +static int bce_alloc_scatterlist_from_vm(struct sg_table *tbl, void *data, size_t len) +{ + int status, i; + struct page **pages; + size_t off, start_page, end_page, page_count; + off = (size_t) data % PAGE_SIZE; + start_page = (size_t) data / PAGE_SIZE; + end_page = ((size_t) data + len - 1) / PAGE_SIZE; + page_count = end_page - start_page + 1; + + if (page_count > PAGE_SIZE / sizeof(struct page *)) + pages = vmalloc(page_count * sizeof(struct page *)); + else + pages = kmalloc(page_count * sizeof(struct page *), GFP_KERNEL); + + for (i = 0; i < page_count; i++) + pages[i] = vmalloc_to_page((void *) ((start_page + i) * PAGE_SIZE)); + + if ((status = sg_alloc_table_from_pages(tbl, pages, page_count, (unsigned int) off, len, GFP_KERNEL))) { + sg_free_table(tbl); + } + + if (page_count > PAGE_SIZE / sizeof(struct page *)) + vfree(pages); + else + kfree(pages); + return status; +} + +#define BCE_ELEMENTS_PER_PAGE ((PAGE_SIZE - sizeof(struct bce_segment_list_header)) \ + / sizeof(struct bce_segment_list_element)) +#define BCE_ELEMENTS_PER_ADDITIONAL_PAGE (PAGE_SIZE / sizeof(struct bce_segment_list_element)) + +static struct bce_segment_list_element_hostinfo *bce_map_segment_list( + struct device *dev, struct scatterlist *pages, int pagen) +{ + size_t ptr, pptr = 0; + struct bce_segment_list_header theader; /* a temp header, to store the initial seg */ + struct bce_segment_list_header *header; + struct bce_segment_list_element *el, *el_end; + struct bce_segment_list_element_hostinfo *out, *pout, *out_root; + struct scatterlist *sg; + int i; + header = &theader; + out = out_root = NULL; + el = el_end = NULL; + for_each_sg(pages, sg, pagen, i) { + if (el >= el_end) { + /* allocate a new page, this will be also done for the first element */ + ptr = __get_free_page(GFP_KERNEL); + if (pptr && ptr == pptr + PAGE_SIZE) { + out->page_count++; + header->element_count += BCE_ELEMENTS_PER_ADDITIONAL_PAGE; + el_end += BCE_ELEMENTS_PER_ADDITIONAL_PAGE; + } else { + header = (void *) ptr; + header->element_count = BCE_ELEMENTS_PER_PAGE; + header->data_size = 0; + header->next_segl_addr = 0; + header->next_segl_length = 0; + el = (void *) (header + 1); + el_end = el + BCE_ELEMENTS_PER_PAGE; + + if (out) { + out->next = kmalloc(sizeof(struct bce_segment_list_element_hostinfo), GFP_KERNEL); + out = out->next; + } else { + out_root = out = kmalloc(sizeof(struct bce_segment_list_element_hostinfo), GFP_KERNEL); + } + out->page_start = (void *) ptr; + out->page_count = 1; + out->dma_start = DMA_MAPPING_ERROR; + out->next = NULL; + } + pptr = ptr; + } + el->addr = sg->dma_address; + el->length = sg->length; + header->data_size += el->length; + } + + /* DMA map */ + out = out_root; + pout = NULL; + while (out) { + out->dma_start = dma_map_single(dev, out->page_start, out->page_count * PAGE_SIZE, DMA_TO_DEVICE); + if (dma_mapping_error(dev, out->dma_start)) + goto error; + if (pout) { + header = pout->page_start; + header->next_segl_addr = out->dma_start; + header->next_segl_length = out->page_count * PAGE_SIZE; + } + pout = out; + out = out->next; + } + return out_root; + + error: + bce_unmap_segement_list(dev, out_root); + return NULL; +} + +static void bce_unmap_segement_list(struct device *dev, struct bce_segment_list_element_hostinfo *list) +{ + struct bce_segment_list_element_hostinfo *next; + while (list) { + if (list->dma_start != DMA_MAPPING_ERROR) + dma_unmap_single(dev, list->dma_start, list->page_count * PAGE_SIZE, DMA_TO_DEVICE); + next = list->next; + kfree(list); + list = next; + } +} + +int bce_set_submission_buf(struct bce_qe_submission *element, struct bce_dma_buffer *buf, size_t offset, size_t length) +{ + struct bce_segment_list_element_hostinfo *seg; + struct bce_segment_list_header *seg_header; + + seg = buf->seglist_hostinfo; + if (!seg) { + element->addr = buf->scatterlist.sgl->dma_address + offset; + element->length = length; + element->segl_addr = 0; + element->segl_length = 0; + return 0; + } + + while (seg) { + seg_header = seg->page_start; + if (offset <= seg_header->data_size) + break; + offset -= seg_header->data_size; + seg = seg->next; + } + if (!seg) + return -EINVAL; + element->addr = offset; + element->length = buf->scatterlist.sgl->dma_length; + element->segl_addr = seg->dma_start; + element->segl_length = seg->page_count * PAGE_SIZE; + return 0; +} \ No newline at end of file diff --git a/drivers/staging/apple-bce/queue_dma.h b/drivers/staging/apple-bce/queue_dma.h new file mode 100644 index 000000000000..f8a57e50e7a3 --- /dev/null +++ b/drivers/staging/apple-bce/queue_dma.h @@ -0,0 +1,50 @@ +#ifndef BCE_QUEUE_DMA_H +#define BCE_QUEUE_DMA_H + +#include + +struct bce_qe_submission; + +struct bce_segment_list_header { + u64 element_count; + u64 data_size; + + u64 next_segl_addr; + u64 next_segl_length; +}; +struct bce_segment_list_element { + u64 addr; + u64 length; +}; + +struct bce_segment_list_element_hostinfo { + struct bce_segment_list_element_hostinfo *next; + void *page_start; + size_t page_count; + dma_addr_t dma_start; +}; + + +struct bce_dma_buffer { + enum dma_data_direction direction; + struct sg_table scatterlist; + struct bce_segment_list_element_hostinfo *seglist_hostinfo; +}; + +/* NOTE: Takes ownership of the sg_table if it succeeds. Ownership is not transferred on failure. */ +int bce_map_dma_buffer(struct device *dev, struct bce_dma_buffer *buf, struct sg_table scatterlist, + enum dma_data_direction dir); + +/* Creates a buffer from virtual memory (vmalloc) */ +int bce_map_dma_buffer_vm(struct device *dev, struct bce_dma_buffer *buf, void *data, size_t len, + enum dma_data_direction dir); + +/* Creates a buffer from kernel memory (kmalloc) */ +int bce_map_dma_buffer_km(struct device *dev, struct bce_dma_buffer *buf, void *data, size_t len, + enum dma_data_direction dir); + +void bce_unmap_dma_buffer(struct device *dev, struct bce_dma_buffer *buf); + +int bce_set_submission_buf(struct bce_qe_submission *element, struct bce_dma_buffer *buf, size_t offset, size_t length); + +#endif //BCE_QUEUE_DMA_H diff --git a/drivers/staging/apple-bce/vhci/command.h b/drivers/staging/apple-bce/vhci/command.h new file mode 100644 index 000000000000..26619e0bccfa --- /dev/null +++ b/drivers/staging/apple-bce/vhci/command.h @@ -0,0 +1,204 @@ +#ifndef BCE_VHCI_COMMAND_H +#define BCE_VHCI_COMMAND_H + +#include "queue.h" +#include +#include + +#define BCE_VHCI_CMD_TIMEOUT_SHORT msecs_to_jiffies(2000) +#define BCE_VHCI_CMD_TIMEOUT_LONG msecs_to_jiffies(30000) + +#define BCE_VHCI_BULK_MAX_ACTIVE_URBS_POW2 2 +#define BCE_VHCI_BULK_MAX_ACTIVE_URBS (1 << BCE_VHCI_BULK_MAX_ACTIVE_URBS_POW2) + +typedef u8 bce_vhci_port_t; +typedef u8 bce_vhci_device_t; + +enum bce_vhci_command { + BCE_VHCI_CMD_CONTROLLER_ENABLE = 1, + BCE_VHCI_CMD_CONTROLLER_DISABLE = 2, + BCE_VHCI_CMD_CONTROLLER_START = 3, + BCE_VHCI_CMD_CONTROLLER_PAUSE = 4, + + BCE_VHCI_CMD_PORT_POWER_ON = 0x10, + BCE_VHCI_CMD_PORT_POWER_OFF = 0x11, + BCE_VHCI_CMD_PORT_RESUME = 0x12, + BCE_VHCI_CMD_PORT_SUSPEND = 0x13, + BCE_VHCI_CMD_PORT_RESET = 0x14, + BCE_VHCI_CMD_PORT_DISABLE = 0x15, + BCE_VHCI_CMD_PORT_STATUS = 0x16, + + BCE_VHCI_CMD_DEVICE_CREATE = 0x30, + BCE_VHCI_CMD_DEVICE_DESTROY = 0x31, + + BCE_VHCI_CMD_ENDPOINT_CREATE = 0x40, + BCE_VHCI_CMD_ENDPOINT_DESTROY = 0x41, + BCE_VHCI_CMD_ENDPOINT_SET_STATE = 0x42, + BCE_VHCI_CMD_ENDPOINT_RESET = 0x44, + + /* Device to host only */ + BCE_VHCI_CMD_ENDPOINT_REQUEST_STATE = 0x43, + BCE_VHCI_CMD_TRANSFER_REQUEST = 0x1000, + BCE_VHCI_CMD_CONTROL_TRANSFER_STATUS = 0x1005 +}; + +enum bce_vhci_endpoint_state { + BCE_VHCI_ENDPOINT_ACTIVE = 0, + BCE_VHCI_ENDPOINT_PAUSED = 1, + BCE_VHCI_ENDPOINT_STALLED = 2 +}; + +static inline int bce_vhci_cmd_controller_enable(struct bce_vhci_command_queue *q, u8 busNum, u16 *portMask) +{ + int status; + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_CONTROLLER_ENABLE; + cmd.param1 = 0x7100u | busNum; + status = bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); + if (!status) + *portMask = (u16) res.param2; + return status; +} +static inline int bce_vhci_cmd_controller_disable(struct bce_vhci_command_queue *q) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_CONTROLLER_DISABLE; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); +} +static inline int bce_vhci_cmd_controller_start(struct bce_vhci_command_queue *q) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_CONTROLLER_START; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); +} +static inline int bce_vhci_cmd_controller_pause(struct bce_vhci_command_queue *q) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_CONTROLLER_PAUSE; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); +} + +static inline int bce_vhci_cmd_port_power_on(struct bce_vhci_command_queue *q, bce_vhci_port_t port) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_POWER_ON; + cmd.param1 = port; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} +static inline int bce_vhci_cmd_port_power_off(struct bce_vhci_command_queue *q, bce_vhci_port_t port) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_POWER_OFF; + cmd.param1 = port; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} +static inline int bce_vhci_cmd_port_resume(struct bce_vhci_command_queue *q, bce_vhci_port_t port) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_RESUME; + cmd.param1 = port; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); +} +static inline int bce_vhci_cmd_port_suspend(struct bce_vhci_command_queue *q, bce_vhci_port_t port) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_SUSPEND; + cmd.param1 = port; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); +} +static inline int bce_vhci_cmd_port_reset(struct bce_vhci_command_queue *q, bce_vhci_port_t port, u32 timeout) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_RESET; + cmd.param1 = port; + cmd.param2 = timeout; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} +static inline int bce_vhci_cmd_port_disable(struct bce_vhci_command_queue *q, bce_vhci_port_t port) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_DISABLE; + cmd.param1 = port; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} +static inline int bce_vhci_cmd_port_status(struct bce_vhci_command_queue *q, bce_vhci_port_t port, + u32 clearFlags, u32 *resStatus) +{ + int status; + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_PORT_STATUS; + cmd.param1 = port; + cmd.param2 = clearFlags & 0x560000; + status = bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); + if (status >= 0) + *resStatus = (u32) res.param2; + return status; +} + +static inline int bce_vhci_cmd_device_create(struct bce_vhci_command_queue *q, bce_vhci_port_t port, + bce_vhci_device_t *dev) +{ + int status; + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_DEVICE_CREATE; + cmd.param1 = port; + status = bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); + if (!status) + *dev = (bce_vhci_device_t) res.param2; + return status; +} +static inline int bce_vhci_cmd_device_destroy(struct bce_vhci_command_queue *q, bce_vhci_device_t dev) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_DEVICE_DESTROY; + cmd.param1 = dev; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_LONG); +} + +static inline int bce_vhci_cmd_endpoint_create(struct bce_vhci_command_queue *q, bce_vhci_device_t dev, + struct usb_endpoint_descriptor *desc) +{ + struct bce_vhci_message cmd, res; + int endpoint_type = usb_endpoint_type(desc); + int maxp = usb_endpoint_maxp(desc); + int maxp_burst = usb_endpoint_maxp_mult(desc) * maxp; + u8 max_active_requests_pow2 = 0; + cmd.cmd = BCE_VHCI_CMD_ENDPOINT_CREATE; + cmd.param1 = dev | ((desc->bEndpointAddress & 0x8Fu) << 8); + if (endpoint_type == USB_ENDPOINT_XFER_BULK) + max_active_requests_pow2 = BCE_VHCI_BULK_MAX_ACTIVE_URBS_POW2; + cmd.param2 = endpoint_type | ((max_active_requests_pow2 & 0xf) << 4) | (maxp << 16) | ((u64) maxp_burst << 32); + if (endpoint_type == USB_ENDPOINT_XFER_INT) + cmd.param2 |= (desc->bInterval - 1) << 8; + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} +static inline int bce_vhci_cmd_endpoint_destroy(struct bce_vhci_command_queue *q, bce_vhci_device_t dev, u8 endpoint) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_ENDPOINT_DESTROY; + cmd.param1 = dev | (endpoint << 8); + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} +static inline int bce_vhci_cmd_endpoint_set_state(struct bce_vhci_command_queue *q, bce_vhci_device_t dev, u8 endpoint, + enum bce_vhci_endpoint_state newState, enum bce_vhci_endpoint_state *retState) +{ + int status; + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_ENDPOINT_SET_STATE; + cmd.param1 = dev | (endpoint << 8); + cmd.param2 = (u64) newState; + status = bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); + if (status != BCE_VHCI_INTERNAL_ERROR && status != BCE_VHCI_NO_POWER) + *retState = (enum bce_vhci_endpoint_state) res.param2; + return status; +} +static inline int bce_vhci_cmd_endpoint_reset(struct bce_vhci_command_queue *q, bce_vhci_device_t dev, u8 endpoint) +{ + struct bce_vhci_message cmd, res; + cmd.cmd = BCE_VHCI_CMD_ENDPOINT_RESET; + cmd.param1 = dev | (endpoint << 8); + return bce_vhci_command_queue_execute(q, &cmd, &res, BCE_VHCI_CMD_TIMEOUT_SHORT); +} + + +#endif //BCE_VHCI_COMMAND_H diff --git a/drivers/staging/apple-bce/vhci/queue.c b/drivers/staging/apple-bce/vhci/queue.c new file mode 100644 index 000000000000..7b0b5027157b --- /dev/null +++ b/drivers/staging/apple-bce/vhci/queue.c @@ -0,0 +1,268 @@ +#include "queue.h" +#include "vhci.h" +#include "../apple_bce.h" + + +static void bce_vhci_message_queue_completion(struct bce_queue_sq *sq); + +int bce_vhci_message_queue_create(struct bce_vhci *vhci, struct bce_vhci_message_queue *ret, const char *name) +{ + int status; + ret->cq = bce_create_cq(vhci->dev, VHCI_EVENT_QUEUE_EL_COUNT); + if (!ret->cq) + return -EINVAL; + ret->sq = bce_create_sq(vhci->dev, ret->cq, name, VHCI_EVENT_QUEUE_EL_COUNT, DMA_TO_DEVICE, + bce_vhci_message_queue_completion, ret); + if (!ret->sq) { + status = -EINVAL; + goto fail_cq; + } + ret->data = dma_alloc_coherent(&vhci->dev->pci->dev, sizeof(struct bce_vhci_message) * VHCI_EVENT_QUEUE_EL_COUNT, + &ret->dma_addr, GFP_KERNEL); + if (!ret->data) { + status = -EINVAL; + goto fail_sq; + } + return 0; + +fail_sq: + bce_destroy_sq(vhci->dev, ret->sq); + ret->sq = NULL; +fail_cq: + bce_destroy_cq(vhci->dev, ret->cq); + ret->cq = NULL; + return status; +} + +void bce_vhci_message_queue_destroy(struct bce_vhci *vhci, struct bce_vhci_message_queue *q) +{ + if (!q->cq) + return; + dma_free_coherent(&vhci->dev->pci->dev, sizeof(struct bce_vhci_message) * VHCI_EVENT_QUEUE_EL_COUNT, + q->data, q->dma_addr); + bce_destroy_sq(vhci->dev, q->sq); + bce_destroy_cq(vhci->dev, q->cq); +} + +void bce_vhci_message_queue_write(struct bce_vhci_message_queue *q, struct bce_vhci_message *req) +{ + int sidx; + struct bce_qe_submission *s; + sidx = q->sq->tail; + s = bce_next_submission(q->sq); + pr_debug("bce-vhci: Send message: %x s=%x p1=%x p2=%llx\n", req->cmd, req->status, req->param1, req->param2); + q->data[sidx] = *req; + bce_set_submission_single(s, q->dma_addr + sizeof(struct bce_vhci_message) * sidx, + sizeof(struct bce_vhci_message)); + bce_submit_to_device(q->sq); +} + +static void bce_vhci_message_queue_completion(struct bce_queue_sq *sq) +{ + while (bce_next_completion(sq)) + bce_notify_submission_complete(sq); +} + + + +static void bce_vhci_event_queue_completion(struct bce_queue_sq *sq); + +int __bce_vhci_event_queue_create(struct bce_vhci *vhci, struct bce_vhci_event_queue *ret, const char *name, + bce_sq_completion compl) +{ + ret->vhci = vhci; + + ret->sq = bce_create_sq(vhci->dev, vhci->ev_cq, name, VHCI_EVENT_QUEUE_EL_COUNT, DMA_FROM_DEVICE, compl, ret); + if (!ret->sq) + return -EINVAL; + ret->data = dma_alloc_coherent(&vhci->dev->pci->dev, sizeof(struct bce_vhci_message) * VHCI_EVENT_QUEUE_EL_COUNT, + &ret->dma_addr, GFP_KERNEL); + if (!ret->data) { + bce_destroy_sq(vhci->dev, ret->sq); + ret->sq = NULL; + return -EINVAL; + } + + init_completion(&ret->queue_empty_completion); + bce_vhci_event_queue_submit_pending(ret, VHCI_EVENT_PENDING_COUNT); + return 0; +} + +int bce_vhci_event_queue_create(struct bce_vhci *vhci, struct bce_vhci_event_queue *ret, const char *name, + bce_vhci_event_queue_callback cb) +{ + ret->cb = cb; + return __bce_vhci_event_queue_create(vhci, ret, name, bce_vhci_event_queue_completion); +} + +void bce_vhci_event_queue_destroy(struct bce_vhci *vhci, struct bce_vhci_event_queue *q) +{ + if (!q->sq) + return; + dma_free_coherent(&vhci->dev->pci->dev, sizeof(struct bce_vhci_message) * VHCI_EVENT_QUEUE_EL_COUNT, + q->data, q->dma_addr); + bce_destroy_sq(vhci->dev, q->sq); +} + +static void bce_vhci_event_queue_completion(struct bce_queue_sq *sq) +{ + struct bce_sq_completion_data *cd; + struct bce_vhci_event_queue *ev = sq->userdata; + struct bce_vhci_message *msg; + size_t cnt = 0; + + while ((cd = bce_next_completion(sq))) { + if (cd->status == BCE_COMPLETION_ABORTED) { /* We flushed the queue */ + bce_notify_submission_complete(sq); + continue; + } + msg = &ev->data[sq->head]; + pr_debug("bce-vhci: Got event: %x s=%x p1=%x p2=%llx\n", msg->cmd, msg->status, msg->param1, msg->param2); + ev->cb(ev, msg); + + bce_notify_submission_complete(sq); + ++cnt; + } + bce_vhci_event_queue_submit_pending(ev, cnt); + if (atomic_read(&sq->available_commands) == sq->el_count - 1) + complete(&ev->queue_empty_completion); +} + +void bce_vhci_event_queue_submit_pending(struct bce_vhci_event_queue *q, size_t count) +{ + int idx; + struct bce_qe_submission *s; + while (count--) { + if (bce_reserve_submission(q->sq, NULL)) { + pr_err("bce-vhci: Failed to reserve an event queue submission\n"); + break; + } + idx = q->sq->tail; + s = bce_next_submission(q->sq); + bce_set_submission_single(s, + q->dma_addr + idx * sizeof(struct bce_vhci_message), sizeof(struct bce_vhci_message)); + } + bce_submit_to_device(q->sq); +} + +void bce_vhci_event_queue_pause(struct bce_vhci_event_queue *q) +{ + unsigned long timeout; + reinit_completion(&q->queue_empty_completion); + if (bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, q->sq->qid)) + pr_warn("bce-vhci: failed to flush event queue\n"); + timeout = msecs_to_jiffies(5000); + while (atomic_read(&q->sq->available_commands) != q->sq->el_count - 1) { + timeout = wait_for_completion_timeout(&q->queue_empty_completion, timeout); + if (timeout == 0) { + pr_err("bce-vhci: waiting for queue to be flushed timed out\n"); + break; + } + } +} + +void bce_vhci_event_queue_resume(struct bce_vhci_event_queue *q) +{ + if (atomic_read(&q->sq->available_commands) != q->sq->el_count - 1) { + pr_err("bce-vhci: resume of a queue with pending submissions\n"); + return; + } + bce_vhci_event_queue_submit_pending(q, VHCI_EVENT_PENDING_COUNT); +} + +void bce_vhci_command_queue_create(struct bce_vhci_command_queue *ret, struct bce_vhci_message_queue *mq) +{ + ret->mq = mq; + ret->completion.result = NULL; + init_completion(&ret->completion.completion); + spin_lock_init(&ret->completion_lock); + mutex_init(&ret->mutex); +} + +void bce_vhci_command_queue_destroy(struct bce_vhci_command_queue *cq) +{ + spin_lock(&cq->completion_lock); + if (cq->completion.result) { + memset(cq->completion.result, 0, sizeof(struct bce_vhci_message)); + cq->completion.result->status = BCE_VHCI_ABORT; + complete(&cq->completion.completion); + cq->completion.result = NULL; + } + spin_unlock(&cq->completion_lock); + mutex_lock(&cq->mutex); + mutex_unlock(&cq->mutex); + mutex_destroy(&cq->mutex); +} + +void bce_vhci_command_queue_deliver_completion(struct bce_vhci_command_queue *cq, struct bce_vhci_message *msg) +{ + struct bce_vhci_command_queue_completion *c = &cq->completion; + + spin_lock(&cq->completion_lock); + if (c->result) { + *c->result = *msg; + complete(&c->completion); + c->result = NULL; + } + spin_unlock(&cq->completion_lock); +} + +static int __bce_vhci_command_queue_execute(struct bce_vhci_command_queue *cq, struct bce_vhci_message *req, + struct bce_vhci_message *res, unsigned long timeout) +{ + int status; + struct bce_vhci_command_queue_completion *c; + struct bce_vhci_message creq; + c = &cq->completion; + + if ((status = bce_reserve_submission(cq->mq->sq, &timeout))) + return status; + + spin_lock(&cq->completion_lock); + c->result = res; + reinit_completion(&c->completion); + spin_unlock(&cq->completion_lock); + + bce_vhci_message_queue_write(cq->mq, req); + + if (!wait_for_completion_timeout(&c->completion, timeout)) { + /* we ran out of time, send cancellation */ + pr_debug("bce-vhci: command timed out req=%x\n", req->cmd); + if ((status = bce_reserve_submission(cq->mq->sq, &timeout))) + return status; + + creq = *req; + creq.cmd |= 0x4000; + bce_vhci_message_queue_write(cq->mq, &creq); + + if (!wait_for_completion_timeout(&c->completion, 1000)) { + pr_err("bce-vhci: Possible desync, cmd cancel timed out\n"); + + spin_lock(&cq->completion_lock); + c->result = NULL; + spin_unlock(&cq->completion_lock); + return -ETIMEDOUT; + } + if ((res->cmd & ~0x8000) == creq.cmd) + return -ETIMEDOUT; + /* reply for the previous command most likely arrived */ + } + + if ((res->cmd & ~0x8000) != req->cmd) { + pr_err("bce-vhci: Possible desync, cmd reply mismatch req=%x, res=%x\n", req->cmd, res->cmd); + return -EIO; + } + if (res->status == BCE_VHCI_SUCCESS) + return 0; + return res->status; +} + +int bce_vhci_command_queue_execute(struct bce_vhci_command_queue *cq, struct bce_vhci_message *req, + struct bce_vhci_message *res, unsigned long timeout) +{ + int status; + mutex_lock(&cq->mutex); + status = __bce_vhci_command_queue_execute(cq, req, res, timeout); + mutex_unlock(&cq->mutex); + return status; +} diff --git a/drivers/staging/apple-bce/vhci/queue.h b/drivers/staging/apple-bce/vhci/queue.h new file mode 100644 index 000000000000..adb705b6ba1d --- /dev/null +++ b/drivers/staging/apple-bce/vhci/queue.h @@ -0,0 +1,76 @@ +#ifndef BCE_VHCI_QUEUE_H +#define BCE_VHCI_QUEUE_H + +#include +#include "../queue.h" + +#define VHCI_EVENT_QUEUE_EL_COUNT 256 +#define VHCI_EVENT_PENDING_COUNT 32 + +struct bce_vhci; +struct bce_vhci_event_queue; + +enum bce_vhci_message_status { + BCE_VHCI_SUCCESS = 1, + BCE_VHCI_ERROR = 2, + BCE_VHCI_USB_PIPE_STALL = 3, + BCE_VHCI_ABORT = 4, + BCE_VHCI_BAD_ARGUMENT = 5, + BCE_VHCI_OVERRUN = 6, + BCE_VHCI_INTERNAL_ERROR = 7, + BCE_VHCI_NO_POWER = 8, + BCE_VHCI_UNSUPPORTED = 9 +}; +struct bce_vhci_message { + u16 cmd; + u16 status; // bce_vhci_message_status + u32 param1; + u64 param2; +}; + +struct bce_vhci_message_queue { + struct bce_queue_cq *cq; + struct bce_queue_sq *sq; + struct bce_vhci_message *data; + dma_addr_t dma_addr; +}; +typedef void (*bce_vhci_event_queue_callback)(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg); +struct bce_vhci_event_queue { + struct bce_vhci *vhci; + struct bce_queue_sq *sq; + struct bce_vhci_message *data; + dma_addr_t dma_addr; + bce_vhci_event_queue_callback cb; + struct completion queue_empty_completion; +}; +struct bce_vhci_command_queue_completion { + struct bce_vhci_message *result; + struct completion completion; +}; +struct bce_vhci_command_queue { + struct bce_vhci_message_queue *mq; + struct bce_vhci_command_queue_completion completion; + struct spinlock completion_lock; + struct mutex mutex; +}; + +int bce_vhci_message_queue_create(struct bce_vhci *vhci, struct bce_vhci_message_queue *ret, const char *name); +void bce_vhci_message_queue_destroy(struct bce_vhci *vhci, struct bce_vhci_message_queue *q); +void bce_vhci_message_queue_write(struct bce_vhci_message_queue *q, struct bce_vhci_message *req); + +int __bce_vhci_event_queue_create(struct bce_vhci *vhci, struct bce_vhci_event_queue *ret, const char *name, + bce_sq_completion compl); +int bce_vhci_event_queue_create(struct bce_vhci *vhci, struct bce_vhci_event_queue *ret, const char *name, + bce_vhci_event_queue_callback cb); +void bce_vhci_event_queue_destroy(struct bce_vhci *vhci, struct bce_vhci_event_queue *q); +void bce_vhci_event_queue_submit_pending(struct bce_vhci_event_queue *q, size_t count); +void bce_vhci_event_queue_pause(struct bce_vhci_event_queue *q); +void bce_vhci_event_queue_resume(struct bce_vhci_event_queue *q); + +void bce_vhci_command_queue_create(struct bce_vhci_command_queue *ret, struct bce_vhci_message_queue *mq); +void bce_vhci_command_queue_destroy(struct bce_vhci_command_queue *cq); +int bce_vhci_command_queue_execute(struct bce_vhci_command_queue *cq, struct bce_vhci_message *req, + struct bce_vhci_message *res, unsigned long timeout); +void bce_vhci_command_queue_deliver_completion(struct bce_vhci_command_queue *cq, struct bce_vhci_message *msg); + +#endif //BCE_VHCI_QUEUE_H diff --git a/drivers/staging/apple-bce/vhci/transfer.c b/drivers/staging/apple-bce/vhci/transfer.c new file mode 100644 index 000000000000..8226363d69c8 --- /dev/null +++ b/drivers/staging/apple-bce/vhci/transfer.c @@ -0,0 +1,661 @@ +#include "transfer.h" +#include "../queue.h" +#include "vhci.h" +#include "../apple_bce.h" +#include + +static void bce_vhci_transfer_queue_completion(struct bce_queue_sq *sq); +static void bce_vhci_transfer_queue_giveback(struct bce_vhci_transfer_queue *q); +static void bce_vhci_transfer_queue_remove_pending(struct bce_vhci_transfer_queue *q); + +static int bce_vhci_urb_init(struct bce_vhci_urb *vurb); +static int bce_vhci_urb_update(struct bce_vhci_urb *urb, struct bce_vhci_message *msg); +static int bce_vhci_urb_transfer_completion(struct bce_vhci_urb *urb, struct bce_sq_completion_data *c); + +static void bce_vhci_transfer_queue_reset_w(struct work_struct *work); + +void bce_vhci_create_transfer_queue(struct bce_vhci *vhci, struct bce_vhci_transfer_queue *q, + struct usb_host_endpoint *endp, bce_vhci_device_t dev_addr, enum dma_data_direction dir) +{ + char name[0x21]; + INIT_LIST_HEAD(&q->evq); + INIT_LIST_HEAD(&q->giveback_urb_list); + spin_lock_init(&q->urb_lock); + mutex_init(&q->pause_lock); + q->vhci = vhci; + q->endp = endp; + q->dev_addr = dev_addr; + q->endp_addr = (u8) (endp->desc.bEndpointAddress & 0x8F); + q->state = BCE_VHCI_ENDPOINT_ACTIVE; + q->active = true; + q->stalled = false; + q->max_active_requests = 1; + if (usb_endpoint_type(&endp->desc) == USB_ENDPOINT_XFER_BULK) + q->max_active_requests = BCE_VHCI_BULK_MAX_ACTIVE_URBS; + q->remaining_active_requests = q->max_active_requests; + q->cq = bce_create_cq(vhci->dev, 0x100); + INIT_WORK(&q->w_reset, bce_vhci_transfer_queue_reset_w); + q->sq_in = NULL; + if (dir == DMA_FROM_DEVICE || dir == DMA_BIDIRECTIONAL) { + snprintf(name, sizeof(name), "VHC1-%i-%02x", dev_addr, 0x80 | usb_endpoint_num(&endp->desc)); + q->sq_in = bce_create_sq(vhci->dev, q->cq, name, 0x100, DMA_FROM_DEVICE, + bce_vhci_transfer_queue_completion, q); + } + q->sq_out = NULL; + if (dir == DMA_TO_DEVICE || dir == DMA_BIDIRECTIONAL) { + snprintf(name, sizeof(name), "VHC1-%i-%02x", dev_addr, usb_endpoint_num(&endp->desc)); + q->sq_out = bce_create_sq(vhci->dev, q->cq, name, 0x100, DMA_TO_DEVICE, + bce_vhci_transfer_queue_completion, q); + } +} + +void bce_vhci_destroy_transfer_queue(struct bce_vhci *vhci, struct bce_vhci_transfer_queue *q) +{ + bce_vhci_transfer_queue_giveback(q); + bce_vhci_transfer_queue_remove_pending(q); + if (q->sq_in) + bce_destroy_sq(vhci->dev, q->sq_in); + if (q->sq_out) + bce_destroy_sq(vhci->dev, q->sq_out); + bce_destroy_cq(vhci->dev, q->cq); +} + +static inline bool bce_vhci_transfer_queue_can_init_urb(struct bce_vhci_transfer_queue *q) +{ + return q->remaining_active_requests > 0; +} + +static void bce_vhci_transfer_queue_defer_event(struct bce_vhci_transfer_queue *q, struct bce_vhci_message *msg) +{ + struct bce_vhci_list_message *lm; + lm = kmalloc(sizeof(struct bce_vhci_list_message), GFP_KERNEL); + INIT_LIST_HEAD(&lm->list); + lm->msg = *msg; + list_add_tail(&lm->list, &q->evq); +} + +static void bce_vhci_transfer_queue_giveback(struct bce_vhci_transfer_queue *q) +{ + unsigned long flags; + struct urb *urb; + spin_lock_irqsave(&q->urb_lock, flags); + while (!list_empty(&q->giveback_urb_list)) { + urb = list_first_entry(&q->giveback_urb_list, struct urb, urb_list); + list_del(&urb->urb_list); + + spin_unlock_irqrestore(&q->urb_lock, flags); + usb_hcd_giveback_urb(q->vhci->hcd, urb, urb->status); + spin_lock_irqsave(&q->urb_lock, flags); + } + spin_unlock_irqrestore(&q->urb_lock, flags); +} + +static void bce_vhci_transfer_queue_init_pending_urbs(struct bce_vhci_transfer_queue *q); + +static void bce_vhci_transfer_queue_deliver_pending(struct bce_vhci_transfer_queue *q) +{ + struct urb *urb; + struct bce_vhci_list_message *lm; + + while (!list_empty(&q->endp->urb_list) && !list_empty(&q->evq)) { + urb = list_first_entry(&q->endp->urb_list, struct urb, urb_list); + + lm = list_first_entry(&q->evq, struct bce_vhci_list_message, list); + if (bce_vhci_urb_update(urb->hcpriv, &lm->msg) == -EAGAIN) + break; + list_del(&lm->list); + kfree(lm); + } + + /* some of the URBs could have been completed, so initialize more URBs if possible */ + bce_vhci_transfer_queue_init_pending_urbs(q); +} + +static void bce_vhci_transfer_queue_remove_pending(struct bce_vhci_transfer_queue *q) +{ + unsigned long flags; + struct bce_vhci_list_message *lm; + spin_lock_irqsave(&q->urb_lock, flags); + while (!list_empty(&q->evq)) { + lm = list_first_entry(&q->evq, struct bce_vhci_list_message, list); + list_del(&lm->list); + kfree(lm); + } + spin_unlock_irqrestore(&q->urb_lock, flags); +} + +void bce_vhci_transfer_queue_event(struct bce_vhci_transfer_queue *q, struct bce_vhci_message *msg) +{ + unsigned long flags; + struct bce_vhci_urb *turb; + struct urb *urb; + spin_lock_irqsave(&q->urb_lock, flags); + bce_vhci_transfer_queue_deliver_pending(q); + + if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST && + (!list_empty(&q->evq) || list_empty(&q->endp->urb_list))) { + bce_vhci_transfer_queue_defer_event(q, msg); + goto complete; + } + if (list_empty(&q->endp->urb_list)) { + pr_err("bce-vhci: [%02x] Unexpected transfer queue event\n", q->endp_addr); + goto complete; + } + urb = list_first_entry(&q->endp->urb_list, struct urb, urb_list); + turb = urb->hcpriv; + if (bce_vhci_urb_update(turb, msg) == -EAGAIN) { + bce_vhci_transfer_queue_defer_event(q, msg); + } else { + bce_vhci_transfer_queue_init_pending_urbs(q); + } + +complete: + spin_unlock_irqrestore(&q->urb_lock, flags); + bce_vhci_transfer_queue_giveback(q); +} + +static void bce_vhci_transfer_queue_completion(struct bce_queue_sq *sq) +{ + unsigned long flags; + struct bce_sq_completion_data *c; + struct urb *urb; + struct bce_vhci_transfer_queue *q = sq->userdata; + spin_lock_irqsave(&q->urb_lock, flags); + while ((c = bce_next_completion(sq))) { + if (c->status == BCE_COMPLETION_ABORTED) { /* We flushed the queue */ + pr_debug("bce-vhci: [%02x] Got an abort completion\n", q->endp_addr); + bce_notify_submission_complete(sq); + continue; + } + if (list_empty(&q->endp->urb_list)) { + pr_err("bce-vhci: [%02x] Got a completion while no requests are pending\n", q->endp_addr); + continue; + } + pr_debug("bce-vhci: [%02x] Got a transfer queue completion\n", q->endp_addr); + urb = list_first_entry(&q->endp->urb_list, struct urb, urb_list); + bce_vhci_urb_transfer_completion(urb->hcpriv, c); + bce_notify_submission_complete(sq); + } + bce_vhci_transfer_queue_deliver_pending(q); + spin_unlock_irqrestore(&q->urb_lock, flags); + bce_vhci_transfer_queue_giveback(q); +} + +int bce_vhci_transfer_queue_do_pause(struct bce_vhci_transfer_queue *q) +{ + unsigned long flags; + int status; + u8 endp_addr = (u8) (q->endp->desc.bEndpointAddress & 0x8F); + spin_lock_irqsave(&q->urb_lock, flags); + q->active = false; + spin_unlock_irqrestore(&q->urb_lock, flags); + if (q->sq_out) { + pr_err("bce-vhci: Not implemented: wait for pending output requests\n"); + } + bce_vhci_transfer_queue_remove_pending(q); + if ((status = bce_vhci_cmd_endpoint_set_state( + &q->vhci->cq, q->dev_addr, endp_addr, BCE_VHCI_ENDPOINT_PAUSED, &q->state))) + return status; + if (q->state != BCE_VHCI_ENDPOINT_PAUSED) + return -EINVAL; + if (q->sq_in) + bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, (u16) q->sq_in->qid); + if (q->sq_out) + bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, (u16) q->sq_out->qid); + return 0; +} + +static void bce_vhci_urb_resume(struct bce_vhci_urb *urb); + +int bce_vhci_transfer_queue_do_resume(struct bce_vhci_transfer_queue *q) +{ + unsigned long flags; + int status; + struct urb *urb, *urbt; + struct bce_vhci_urb *vurb; + u8 endp_addr = (u8) (q->endp->desc.bEndpointAddress & 0x8F); + if ((status = bce_vhci_cmd_endpoint_set_state( + &q->vhci->cq, q->dev_addr, endp_addr, BCE_VHCI_ENDPOINT_ACTIVE, &q->state))) + return status; + if (q->state != BCE_VHCI_ENDPOINT_ACTIVE) + return -EINVAL; + spin_lock_irqsave(&q->urb_lock, flags); + q->active = true; + list_for_each_entry_safe(urb, urbt, &q->endp->urb_list, urb_list) { + vurb = urb->hcpriv; + if (vurb->state == BCE_VHCI_URB_INIT_PENDING) { + if (!bce_vhci_transfer_queue_can_init_urb(q)) + break; + bce_vhci_urb_init(vurb); + } else { + bce_vhci_urb_resume(vurb); + } + } + bce_vhci_transfer_queue_deliver_pending(q); + spin_unlock_irqrestore(&q->urb_lock, flags); + return 0; +} + +int bce_vhci_transfer_queue_pause(struct bce_vhci_transfer_queue *q, enum bce_vhci_pause_source src) +{ + int ret = 0; + mutex_lock(&q->pause_lock); + if ((q->paused_by & src) != src) { + if (!q->paused_by) + ret = bce_vhci_transfer_queue_do_pause(q); + if (!ret) + q->paused_by |= src; + } + mutex_unlock(&q->pause_lock); + return ret; +} + +int bce_vhci_transfer_queue_resume(struct bce_vhci_transfer_queue *q, enum bce_vhci_pause_source src) +{ + int ret = 0; + mutex_lock(&q->pause_lock); + if (q->paused_by & src) { + if (!(q->paused_by & ~src)) + ret = bce_vhci_transfer_queue_do_resume(q); + if (!ret) + q->paused_by &= ~src; + } + mutex_unlock(&q->pause_lock); + return ret; +} + +static void bce_vhci_transfer_queue_reset_w(struct work_struct *work) +{ + unsigned long flags; + struct bce_vhci_transfer_queue *q = container_of(work, struct bce_vhci_transfer_queue, w_reset); + + mutex_lock(&q->pause_lock); + spin_lock_irqsave(&q->urb_lock, flags); + if (!q->stalled) { + spin_unlock_irqrestore(&q->urb_lock, flags); + mutex_unlock(&q->pause_lock); + return; + } + q->active = false; + spin_unlock_irqrestore(&q->urb_lock, flags); + q->paused_by |= BCE_VHCI_PAUSE_INTERNAL_WQ; + bce_vhci_transfer_queue_remove_pending(q); + if (q->sq_in) + bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, (u16) q->sq_in->qid); + if (q->sq_out) + bce_cmd_flush_memory_queue(q->vhci->dev->cmd_cmdq, (u16) q->sq_out->qid); + bce_vhci_cmd_endpoint_reset(&q->vhci->cq, q->dev_addr, (u8) (q->endp->desc.bEndpointAddress & 0x8F)); + spin_lock_irqsave(&q->urb_lock, flags); + q->stalled = false; + spin_unlock_irqrestore(&q->urb_lock, flags); + mutex_unlock(&q->pause_lock); + bce_vhci_transfer_queue_resume(q, BCE_VHCI_PAUSE_INTERNAL_WQ); +} + +void bce_vhci_transfer_queue_request_reset(struct bce_vhci_transfer_queue *q) +{ + queue_work(q->vhci->tq_state_wq, &q->w_reset); +} + +static void bce_vhci_transfer_queue_init_pending_urbs(struct bce_vhci_transfer_queue *q) +{ + struct urb *urb, *urbt; + struct bce_vhci_urb *vurb; + list_for_each_entry_safe(urb, urbt, &q->endp->urb_list, urb_list) { + vurb = urb->hcpriv; + if (!bce_vhci_transfer_queue_can_init_urb(q)) + break; + if (vurb->state == BCE_VHCI_URB_INIT_PENDING) + bce_vhci_urb_init(vurb); + } +} + + + +static int bce_vhci_urb_data_start(struct bce_vhci_urb *urb, unsigned long *timeout); + +int bce_vhci_urb_create(struct bce_vhci_transfer_queue *q, struct urb *urb) +{ + unsigned long flags; + int status = 0; + struct bce_vhci_urb *vurb; + vurb = kzalloc(sizeof(struct bce_vhci_urb), GFP_KERNEL); + urb->hcpriv = vurb; + + vurb->q = q; + vurb->urb = urb; + vurb->dir = usb_urb_dir_in(urb) ? DMA_FROM_DEVICE : DMA_TO_DEVICE; + vurb->is_control = (usb_endpoint_num(&urb->ep->desc) == 0); + + spin_lock_irqsave(&q->urb_lock, flags); + status = usb_hcd_link_urb_to_ep(q->vhci->hcd, urb); + if (status) { + spin_unlock_irqrestore(&q->urb_lock, flags); + urb->hcpriv = NULL; + kfree(vurb); + return status; + } + + if (q->active) { + if (bce_vhci_transfer_queue_can_init_urb(vurb->q)) + status = bce_vhci_urb_init(vurb); + else + vurb->state = BCE_VHCI_URB_INIT_PENDING; + } else { + if (q->stalled) + bce_vhci_transfer_queue_request_reset(q); + vurb->state = BCE_VHCI_URB_INIT_PENDING; + } + if (status) { + usb_hcd_unlink_urb_from_ep(q->vhci->hcd, urb); + urb->hcpriv = NULL; + kfree(vurb); + } else { + bce_vhci_transfer_queue_deliver_pending(q); + } + spin_unlock_irqrestore(&q->urb_lock, flags); + pr_debug("bce-vhci: [%02x] URB enqueued (dir = %s, size = %i)\n", q->endp_addr, + usb_urb_dir_in(urb) ? "IN" : "OUT", urb->transfer_buffer_length); + return status; +} + +static int bce_vhci_urb_init(struct bce_vhci_urb *vurb) +{ + int status = 0; + + if (vurb->q->remaining_active_requests == 0) { + pr_err("bce-vhci: cannot init request (remaining_active_requests = 0)\n"); + return -EINVAL; + } + + if (vurb->is_control) { + vurb->state = BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_REQUEST; + } else { + status = bce_vhci_urb_data_start(vurb, NULL); + } + + if (!status) { + --vurb->q->remaining_active_requests; + } + return status; +} + +static void bce_vhci_urb_complete(struct bce_vhci_urb *urb, int status) +{ + struct bce_vhci_transfer_queue *q = urb->q; + struct bce_vhci *vhci = q->vhci; + struct urb *real_urb = urb->urb; + pr_debug("bce-vhci: [%02x] URB complete %i\n", q->endp_addr, status); + usb_hcd_unlink_urb_from_ep(vhci->hcd, real_urb); + real_urb->hcpriv = NULL; + real_urb->status = status; + if (urb->state != BCE_VHCI_URB_INIT_PENDING) + ++urb->q->remaining_active_requests; + kfree(urb); + list_add_tail(&real_urb->urb_list, &q->giveback_urb_list); +} + +int bce_vhci_urb_request_cancel(struct bce_vhci_transfer_queue *q, struct urb *urb, int status) +{ + struct bce_vhci_urb *vurb; + unsigned long flags; + int ret; + + spin_lock_irqsave(&q->urb_lock, flags); + if ((ret = usb_hcd_check_unlink_urb(q->vhci->hcd, urb, status))) { + spin_unlock_irqrestore(&q->urb_lock, flags); + return ret; + } + + vurb = urb->hcpriv; + /* If the URB wasn't posted to the device yet, we can still remove it on the host without pausing the queue. */ + if (vurb->state != BCE_VHCI_URB_INIT_PENDING) { + pr_debug("bce-vhci: [%02x] Cancelling URB\n", q->endp_addr); + + spin_unlock_irqrestore(&q->urb_lock, flags); + bce_vhci_transfer_queue_pause(q, BCE_VHCI_PAUSE_INTERNAL_WQ); + spin_lock_irqsave(&q->urb_lock, flags); + + ++q->remaining_active_requests; + } + + usb_hcd_unlink_urb_from_ep(q->vhci->hcd, urb); + + spin_unlock_irqrestore(&q->urb_lock, flags); + + usb_hcd_giveback_urb(q->vhci->hcd, urb, status); + + if (vurb->state != BCE_VHCI_URB_INIT_PENDING) + bce_vhci_transfer_queue_resume(q, BCE_VHCI_PAUSE_INTERNAL_WQ); + + kfree(vurb); + + return 0; +} + +static int bce_vhci_urb_data_transfer_in(struct bce_vhci_urb *urb, unsigned long *timeout) +{ + struct bce_vhci_message msg; + struct bce_qe_submission *s; + u32 tr_len; + int reservation1, reservation2 = -EFAULT; + + pr_debug("bce-vhci: [%02x] DMA from device %llx %x\n", urb->q->endp_addr, + (u64) urb->urb->transfer_dma, urb->urb->transfer_buffer_length); + + /* Reserve both a message and a submission, so we don't run into issues later. */ + reservation1 = bce_reserve_submission(urb->q->vhci->msg_asynchronous.sq, timeout); + if (!reservation1) + reservation2 = bce_reserve_submission(urb->q->sq_in, timeout); + if (reservation1 || reservation2) { + pr_err("bce-vhci: Failed to reserve a submission for URB data transfer\n"); + if (!reservation1) + bce_cancel_submission_reservation(urb->q->vhci->msg_asynchronous.sq); + return -ENOMEM; + } + + urb->send_offset = urb->receive_offset; + + tr_len = urb->urb->transfer_buffer_length - urb->send_offset; + + spin_lock(&urb->q->vhci->msg_asynchronous_lock); + msg.cmd = BCE_VHCI_CMD_TRANSFER_REQUEST; + msg.status = 0; + msg.param1 = ((urb->urb->ep->desc.bEndpointAddress & 0x8Fu) << 8) | urb->q->dev_addr; + msg.param2 = tr_len; + bce_vhci_message_queue_write(&urb->q->vhci->msg_asynchronous, &msg); + spin_unlock(&urb->q->vhci->msg_asynchronous_lock); + + s = bce_next_submission(urb->q->sq_in); + bce_set_submission_single(s, urb->urb->transfer_dma + urb->send_offset, tr_len); + bce_submit_to_device(urb->q->sq_in); + + urb->state = BCE_VHCI_URB_WAITING_FOR_COMPLETION; + return 0; +} + +static int bce_vhci_urb_data_start(struct bce_vhci_urb *urb, unsigned long *timeout) +{ + if (urb->dir == DMA_TO_DEVICE) { + if (urb->urb->transfer_buffer_length > 0) + urb->state = BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST; + else + urb->state = BCE_VHCI_URB_DATA_TRANSFER_COMPLETE; + return 0; + } else { + return bce_vhci_urb_data_transfer_in(urb, timeout); + } +} + +static int bce_vhci_urb_send_out_data(struct bce_vhci_urb *urb, dma_addr_t addr, size_t size) +{ + struct bce_qe_submission *s; + unsigned long timeout = 0; + if (bce_reserve_submission(urb->q->sq_out, &timeout)) { + pr_err("bce-vhci: Failed to reserve a submission for URB data transfer\n"); + return -EPIPE; + } + + pr_debug("bce-vhci: [%02x] DMA to device %llx %lx\n", urb->q->endp_addr, (u64) addr, size); + + s = bce_next_submission(urb->q->sq_out); + bce_set_submission_single(s, addr, size); + bce_submit_to_device(urb->q->sq_out); + return 0; +} + +static int bce_vhci_urb_data_update(struct bce_vhci_urb *urb, struct bce_vhci_message *msg) +{ + u32 tr_len; + int status; + if (urb->state == BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST) { + if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST) { + tr_len = min(urb->urb->transfer_buffer_length - urb->send_offset, (u32) msg->param2); + if ((status = bce_vhci_urb_send_out_data(urb, urb->urb->transfer_dma + urb->send_offset, tr_len))) + return status; + urb->send_offset += tr_len; + urb->state = BCE_VHCI_URB_WAITING_FOR_COMPLETION; + return 0; + } + } + + /* 0x1000 in out queues aren't really unexpected */ + if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST && urb->q->sq_out != NULL) + return -EAGAIN; + pr_err("bce-vhci: [%02x] %s URB unexpected message (state = %x, msg: %x %x %x %llx)\n", + urb->q->endp_addr, (urb->is_control ? "Control (data update)" : "Data"), urb->state, + msg->cmd, msg->status, msg->param1, msg->param2); + return -EAGAIN; +} + +static int bce_vhci_urb_data_transfer_completion(struct bce_vhci_urb *urb, struct bce_sq_completion_data *c) +{ + if (urb->state == BCE_VHCI_URB_WAITING_FOR_COMPLETION) { + urb->receive_offset += c->data_size; + if (urb->dir == DMA_FROM_DEVICE || urb->receive_offset >= urb->urb->transfer_buffer_length) { + urb->urb->actual_length = (u32) urb->receive_offset; + urb->state = BCE_VHCI_URB_DATA_TRANSFER_COMPLETE; + if (!urb->is_control) { + bce_vhci_urb_complete(urb, 0); + return -ENOENT; + } + } + } else { + pr_err("bce-vhci: [%02x] Data URB unexpected completion\n", urb->q->endp_addr); + } + return 0; +} + + +static int bce_vhci_urb_control_check_status(struct bce_vhci_urb *urb) +{ + struct bce_vhci_transfer_queue *q = urb->q; + if (urb->received_status == 0) + return 0; + if (urb->state == BCE_VHCI_URB_DATA_TRANSFER_COMPLETE || + (urb->received_status != BCE_VHCI_SUCCESS && urb->state != BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_REQUEST && + urb->state != BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_COMPLETION)) { + urb->state = BCE_VHCI_URB_CONTROL_COMPLETE; + if (urb->received_status != BCE_VHCI_SUCCESS) { + pr_err("bce-vhci: [%02x] URB failed: %x\n", urb->q->endp_addr, urb->received_status); + urb->q->active = false; + urb->q->stalled = true; + bce_vhci_urb_complete(urb, -EPIPE); + if (!list_empty(&q->endp->urb_list)) + bce_vhci_transfer_queue_request_reset(q); + return -ENOENT; + } + bce_vhci_urb_complete(urb, 0); + return -ENOENT; + } + return 0; +} + +static int bce_vhci_urb_control_update(struct bce_vhci_urb *urb, struct bce_vhci_message *msg) +{ + int status; + if (msg->cmd == BCE_VHCI_CMD_CONTROL_TRANSFER_STATUS) { + urb->received_status = msg->status; + return bce_vhci_urb_control_check_status(urb); + } + + if (urb->state == BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_REQUEST) { + if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST) { + if (bce_vhci_urb_send_out_data(urb, urb->urb->setup_dma, sizeof(struct usb_ctrlrequest))) { + pr_err("bce-vhci: [%02x] Failed to start URB setup transfer\n", urb->q->endp_addr); + return 0; /* TODO: fail the URB? */ + } + urb->state = BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_COMPLETION; + pr_debug("bce-vhci: [%02x] Sent setup %llx\n", urb->q->endp_addr, urb->urb->setup_dma); + return 0; + } + } else if (urb->state == BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST || + urb->state == BCE_VHCI_URB_WAITING_FOR_COMPLETION) { + if ((status = bce_vhci_urb_data_update(urb, msg))) + return status; + return bce_vhci_urb_control_check_status(urb); + } + + /* 0x1000 in out queues aren't really unexpected */ + if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST && urb->q->sq_out != NULL) + return -EAGAIN; + pr_err("bce-vhci: [%02x] Control URB unexpected message (state = %x, msg: %x %x %x %llx)\n", urb->q->endp_addr, + urb->state, msg->cmd, msg->status, msg->param1, msg->param2); + return -EAGAIN; +} + +static int bce_vhci_urb_control_transfer_completion(struct bce_vhci_urb *urb, struct bce_sq_completion_data *c) +{ + int status; + unsigned long timeout; + + if (urb->state == BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_COMPLETION) { + if (c->data_size != sizeof(struct usb_ctrlrequest)) + pr_err("bce-vhci: [%02x] transfer complete data size mistmatch for usb_ctrlrequest (%llx instead of %lx)\n", + urb->q->endp_addr, c->data_size, sizeof(struct usb_ctrlrequest)); + + timeout = 1000; + status = bce_vhci_urb_data_start(urb, &timeout); + if (status) { + bce_vhci_urb_complete(urb, status); + return -ENOENT; + } + return 0; + } else if (urb->state == BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST || + urb->state == BCE_VHCI_URB_WAITING_FOR_COMPLETION) { + if ((status = bce_vhci_urb_data_transfer_completion(urb, c))) + return status; + return bce_vhci_urb_control_check_status(urb); + } else { + pr_err("bce-vhci: [%02x] Control URB unexpected completion (state = %x)\n", urb->q->endp_addr, urb->state); + } + return 0; +} + +static int bce_vhci_urb_update(struct bce_vhci_urb *urb, struct bce_vhci_message *msg) +{ + if (urb->state == BCE_VHCI_URB_INIT_PENDING) + return -EAGAIN; + if (urb->is_control) + return bce_vhci_urb_control_update(urb, msg); + else + return bce_vhci_urb_data_update(urb, msg); +} + +static int bce_vhci_urb_transfer_completion(struct bce_vhci_urb *urb, struct bce_sq_completion_data *c) +{ + if (urb->is_control) + return bce_vhci_urb_control_transfer_completion(urb, c); + else + return bce_vhci_urb_data_transfer_completion(urb, c); +} + +static void bce_vhci_urb_resume(struct bce_vhci_urb *urb) +{ + int status = 0; + if (urb->state == BCE_VHCI_URB_WAITING_FOR_COMPLETION) { + status = bce_vhci_urb_data_transfer_in(urb, NULL); + } + if (status) + bce_vhci_urb_complete(urb, status); +} diff --git a/drivers/staging/apple-bce/vhci/transfer.h b/drivers/staging/apple-bce/vhci/transfer.h new file mode 100644 index 000000000000..89ecad6bcf8f --- /dev/null +++ b/drivers/staging/apple-bce/vhci/transfer.h @@ -0,0 +1,73 @@ +#ifndef BCEDRIVER_TRANSFER_H +#define BCEDRIVER_TRANSFER_H + +#include +#include "queue.h" +#include "command.h" +#include "../queue.h" + +struct bce_vhci_list_message { + struct list_head list; + struct bce_vhci_message msg; +}; +enum bce_vhci_pause_source { + BCE_VHCI_PAUSE_INTERNAL_WQ = 1, + BCE_VHCI_PAUSE_FIRMWARE = 2, + BCE_VHCI_PAUSE_SUSPEND = 4, + BCE_VHCI_PAUSE_SHUTDOWN = 8 +}; +struct bce_vhci_transfer_queue { + struct bce_vhci *vhci; + struct usb_host_endpoint *endp; + enum bce_vhci_endpoint_state state; + u32 max_active_requests, remaining_active_requests; + bool active, stalled; + u32 paused_by; + bce_vhci_device_t dev_addr; + u8 endp_addr; + struct bce_queue_cq *cq; + struct bce_queue_sq *sq_in; + struct bce_queue_sq *sq_out; + struct list_head evq; + struct spinlock urb_lock; + struct mutex pause_lock; + struct list_head giveback_urb_list; + + struct work_struct w_reset; +}; +enum bce_vhci_urb_state { + BCE_VHCI_URB_INIT_PENDING, + + BCE_VHCI_URB_WAITING_FOR_TRANSFER_REQUEST, + BCE_VHCI_URB_WAITING_FOR_COMPLETION, + BCE_VHCI_URB_DATA_TRANSFER_COMPLETE, + + BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_REQUEST, + BCE_VHCI_URB_CONTROL_WAITING_FOR_SETUP_COMPLETION, + BCE_VHCI_URB_CONTROL_COMPLETE +}; +struct bce_vhci_urb { + struct urb *urb; + struct bce_vhci_transfer_queue *q; + enum dma_data_direction dir; + bool is_control; + enum bce_vhci_urb_state state; + int received_status; + u32 send_offset; + u32 receive_offset; +}; + +void bce_vhci_create_transfer_queue(struct bce_vhci *vhci, struct bce_vhci_transfer_queue *q, + struct usb_host_endpoint *endp, bce_vhci_device_t dev_addr, enum dma_data_direction dir); +void bce_vhci_destroy_transfer_queue(struct bce_vhci *vhci, struct bce_vhci_transfer_queue *q); +void bce_vhci_transfer_queue_event(struct bce_vhci_transfer_queue *q, struct bce_vhci_message *msg); +int bce_vhci_transfer_queue_do_pause(struct bce_vhci_transfer_queue *q); +int bce_vhci_transfer_queue_do_resume(struct bce_vhci_transfer_queue *q); +int bce_vhci_transfer_queue_pause(struct bce_vhci_transfer_queue *q, enum bce_vhci_pause_source src); +int bce_vhci_transfer_queue_resume(struct bce_vhci_transfer_queue *q, enum bce_vhci_pause_source src); +void bce_vhci_transfer_queue_request_reset(struct bce_vhci_transfer_queue *q); + +int bce_vhci_urb_create(struct bce_vhci_transfer_queue *q, struct urb *urb); +int bce_vhci_urb_request_cancel(struct bce_vhci_transfer_queue *q, struct urb *urb, int status); + +#endif //BCEDRIVER_TRANSFER_H diff --git a/drivers/staging/apple-bce/vhci/vhci.c b/drivers/staging/apple-bce/vhci/vhci.c new file mode 100644 index 000000000000..eb26f55000d8 --- /dev/null +++ b/drivers/staging/apple-bce/vhci/vhci.c @@ -0,0 +1,759 @@ +#include "vhci.h" +#include "../apple_bce.h" +#include "command.h" +#include +#include +#include +#include + +static dev_t bce_vhci_chrdev; +static struct class *bce_vhci_class; +static const struct hc_driver bce_vhci_driver; +static u16 bce_vhci_port_mask = U16_MAX; + +static int bce_vhci_create_event_queues(struct bce_vhci *vhci); +static void bce_vhci_destroy_event_queues(struct bce_vhci *vhci); +static int bce_vhci_create_message_queues(struct bce_vhci *vhci); +static void bce_vhci_destroy_message_queues(struct bce_vhci *vhci); +static void bce_vhci_handle_firmware_events_w(struct work_struct *ws); +static void bce_vhci_firmware_event_completion(struct bce_queue_sq *sq); + +int bce_vhci_create(struct apple_bce_device *dev, struct bce_vhci *vhci) +{ + int status; + + spin_lock_init(&vhci->hcd_spinlock); + + vhci->dev = dev; + + vhci->vdevt = bce_vhci_chrdev; + vhci->vdev = device_create(bce_vhci_class, dev->dev, vhci->vdevt, NULL, "bce-vhci"); + if (IS_ERR_OR_NULL(vhci->vdev)) { + status = PTR_ERR(vhci->vdev); + goto fail_dev; + } + + if ((status = bce_vhci_create_message_queues(vhci))) + goto fail_mq; + if ((status = bce_vhci_create_event_queues(vhci))) + goto fail_eq; + + vhci->tq_state_wq = alloc_ordered_workqueue("bce-vhci-tq-state", 0); + INIT_WORK(&vhci->w_fw_events, bce_vhci_handle_firmware_events_w); + + vhci->hcd = usb_create_hcd(&bce_vhci_driver, vhci->vdev, "bce-vhci"); + if (!vhci->hcd) { + status = -ENOMEM; + goto fail_hcd; + } + vhci->hcd->self.sysdev = &dev->pci->dev; +#if LINUX_VERSION_CODE < KERNEL_VERSION(5,4,0) + vhci->hcd->self.uses_dma = 1; +#endif + *((struct bce_vhci **) vhci->hcd->hcd_priv) = vhci; + vhci->hcd->speed = HCD_USB2; + + if ((status = usb_add_hcd(vhci->hcd, 0, 0))) + goto fail_hcd; + + return 0; + +fail_hcd: + bce_vhci_destroy_event_queues(vhci); +fail_eq: + bce_vhci_destroy_message_queues(vhci); +fail_mq: + device_destroy(bce_vhci_class, vhci->vdevt); +fail_dev: + if (!status) + status = -EINVAL; + return status; +} + +void bce_vhci_destroy(struct bce_vhci *vhci) +{ + usb_remove_hcd(vhci->hcd); + bce_vhci_destroy_event_queues(vhci); + bce_vhci_destroy_message_queues(vhci); + device_destroy(bce_vhci_class, vhci->vdevt); +} + +struct bce_vhci *bce_vhci_from_hcd(struct usb_hcd *hcd) +{ + return *((struct bce_vhci **) hcd->hcd_priv); +} + +int bce_vhci_start(struct usb_hcd *hcd) +{ + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + int status; + u16 port_mask = 0; + bce_vhci_port_t port_no = 0; + if ((status = bce_vhci_cmd_controller_enable(&vhci->cq, 1, &port_mask))) + return status; + vhci->port_mask = port_mask; + vhci->port_power_mask = 0; + if ((status = bce_vhci_cmd_controller_start(&vhci->cq))) + return status; + port_mask = vhci->port_mask; + while (port_mask) { + port_no += 1; + port_mask >>= 1; + } + vhci->port_count = port_no; + return 0; +} + +void bce_vhci_stop(struct usb_hcd *hcd) +{ + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + bce_vhci_cmd_controller_disable(&vhci->cq); +} + +static int bce_vhci_hub_status_data(struct usb_hcd *hcd, char *buf) +{ + return 0; +} + +static int bce_vhci_reset_device(struct bce_vhci *vhci, int index, u16 timeout); + +static int bce_vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue, u16 wIndex, char *buf, u16 wLength) +{ + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + int status; + struct usb_hub_descriptor *hd; + struct usb_hub_status *hs; + struct usb_port_status *ps; + u32 port_status; + // pr_info("bce-vhci: bce_vhci_hub_control %x %i %i [bufl=%i]\n", typeReq, wValue, wIndex, wLength); + if (typeReq == GetHubDescriptor && wLength >= sizeof(struct usb_hub_descriptor)) { + hd = (struct usb_hub_descriptor *) buf; + memset(hd, 0, sizeof(*hd)); + hd->bDescLength = sizeof(struct usb_hub_descriptor); + hd->bDescriptorType = USB_DT_HUB; + hd->bNbrPorts = (u8) vhci->port_count; + hd->wHubCharacteristics = HUB_CHAR_INDV_PORT_LPSM | HUB_CHAR_INDV_PORT_OCPM; + hd->bPwrOn2PwrGood = 0; + hd->bHubContrCurrent = 0; + return 0; + } else if (typeReq == GetHubStatus && wLength >= sizeof(struct usb_hub_status)) { + hs = (struct usb_hub_status *) buf; + memset(hs, 0, sizeof(*hs)); + hs->wHubStatus = 0; + hs->wHubChange = 0; + return 0; + } else if (typeReq == GetPortStatus && wLength >= 4 /* usb 2.0 */) { + ps = (struct usb_port_status *) buf; + ps->wPortStatus = 0; + ps->wPortChange = 0; + + if (vhci->port_power_mask & BIT(wIndex)) + ps->wPortStatus |= USB_PORT_STAT_POWER; + + if (!(bce_vhci_port_mask & BIT(wIndex))) + return 0; + + if ((status = bce_vhci_cmd_port_status(&vhci->cq, (u8) wIndex, 0, &port_status))) + return status; + + if (port_status & 16) + ps->wPortStatus |= USB_PORT_STAT_ENABLE | USB_PORT_STAT_HIGH_SPEED; + if (port_status & 4) + ps->wPortStatus |= USB_PORT_STAT_CONNECTION; + if (port_status & 2) + ps->wPortStatus |= USB_PORT_STAT_OVERCURRENT; + if (port_status & 8) + ps->wPortStatus |= USB_PORT_STAT_RESET; + if (port_status & 0x60) + ps->wPortStatus |= USB_PORT_STAT_SUSPEND; + + if (port_status & 0x40000) + ps->wPortChange |= USB_PORT_STAT_C_CONNECTION; + + pr_debug("bce-vhci: Translated status %x to %x:%x\n", port_status, ps->wPortStatus, ps->wPortChange); + return 0; + } else if (typeReq == SetPortFeature) { + if (wValue == USB_PORT_FEAT_POWER) { + status = bce_vhci_cmd_port_power_on(&vhci->cq, (u8) wIndex); + /* As far as I am aware, power status is not part of the port status so store it separately */ + if (!status) + vhci->port_power_mask |= BIT(wIndex); + return status; + } + if (wValue == USB_PORT_FEAT_RESET) { + return bce_vhci_reset_device(vhci, wIndex, wValue); + } + if (wValue == USB_PORT_FEAT_SUSPEND) { + /* TODO: Am I supposed to also suspend the endpoints? */ + pr_debug("bce-vhci: Suspending port %i\n", wIndex); + return bce_vhci_cmd_port_suspend(&vhci->cq, (u8) wIndex); + } + } else if (typeReq == ClearPortFeature) { + if (wValue == USB_PORT_FEAT_ENABLE) + return bce_vhci_cmd_port_disable(&vhci->cq, (u8) wIndex); + if (wValue == USB_PORT_FEAT_POWER) { + status = bce_vhci_cmd_port_power_off(&vhci->cq, (u8) wIndex); + if (!status) + vhci->port_power_mask &= ~BIT(wIndex); + return status; + } + if (wValue == USB_PORT_FEAT_C_CONNECTION) + return bce_vhci_cmd_port_status(&vhci->cq, (u8) wIndex, 0x40000, &port_status); + if (wValue == USB_PORT_FEAT_C_RESET) { /* I don't think I can transfer it in any way */ + return 0; + } + if (wValue == USB_PORT_FEAT_SUSPEND) { + pr_debug("bce-vhci: Resuming port %i\n", wIndex); + return bce_vhci_cmd_port_resume(&vhci->cq, (u8) wIndex); + } + } + pr_err("bce-vhci: bce_vhci_hub_control unhandled request: %x %i %i [bufl=%i]\n", typeReq, wValue, wIndex, wLength); + dump_stack(); + return -EIO; +} + +static int bce_vhci_enable_device(struct usb_hcd *hcd, struct usb_device *udev) +{ + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + struct bce_vhci_device *vdev; + bce_vhci_device_t devid; + pr_info("bce_vhci_enable_device\n"); + + if (vhci->port_to_device[udev->portnum]) + return 0; + + /* We need to early address the device */ + if (bce_vhci_cmd_device_create(&vhci->cq, udev->portnum, &devid)) + return -EIO; + + pr_info("bce_vhci_cmd_device_create %i -> %i\n", udev->portnum, devid); + + vdev = kzalloc(sizeof(struct bce_vhci_device), GFP_KERNEL); + vhci->port_to_device[udev->portnum] = devid; + vhci->devices[devid] = vdev; + + bce_vhci_create_transfer_queue(vhci, &vdev->tq[0], &udev->ep0, devid, DMA_BIDIRECTIONAL); + udev->ep0.hcpriv = &vdev->tq[0]; + vdev->tq_mask |= BIT(0); + + bce_vhci_cmd_endpoint_create(&vhci->cq, devid, &udev->ep0.desc); + return 0; +} + +static int bce_vhci_address_device(struct usb_hcd *hcd, struct usb_device *udev, unsigned int timeout_ms) //TODO: follow timeout +{ + /* This is the same as enable_device, but instead in the old scheme */ + return bce_vhci_enable_device(hcd, udev); +} + +static void bce_vhci_free_device(struct usb_hcd *hcd, struct usb_device *udev) +{ + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + int i; + bce_vhci_device_t devid; + struct bce_vhci_device *dev; + pr_info("bce_vhci_free_device %i\n", udev->portnum); + if (!vhci->port_to_device[udev->portnum]) + return; + devid = vhci->port_to_device[udev->portnum]; + dev = vhci->devices[devid]; + for (i = 0; i < 32; i++) { + if (dev->tq_mask & BIT(i)) { + bce_vhci_transfer_queue_pause(&dev->tq[i], BCE_VHCI_PAUSE_SHUTDOWN); + bce_vhci_cmd_endpoint_destroy(&vhci->cq, devid, (u8) i); + bce_vhci_destroy_transfer_queue(vhci, &dev->tq[i]); + } + } + vhci->devices[devid] = NULL; + vhci->port_to_device[udev->portnum] = 0; + bce_vhci_cmd_device_destroy(&vhci->cq, devid); + kfree(dev); +} + +static int bce_vhci_reset_device(struct bce_vhci *vhci, int index, u16 timeout) +{ + struct bce_vhci_device *dev = NULL; + bce_vhci_device_t devid; + int i; + int status; + enum dma_data_direction dir; + pr_info("bce_vhci_reset_device %i\n", index); + + devid = vhci->port_to_device[index]; + if (devid) { + dev = vhci->devices[devid]; + + for (i = 0; i < 32; i++) { + if (dev->tq_mask & BIT(i)) { + bce_vhci_transfer_queue_pause(&dev->tq[i], BCE_VHCI_PAUSE_SHUTDOWN); + bce_vhci_cmd_endpoint_destroy(&vhci->cq, devid, (u8) i); + bce_vhci_destroy_transfer_queue(vhci, &dev->tq[i]); + } + } + vhci->devices[devid] = NULL; + vhci->port_to_device[index] = 0; + bce_vhci_cmd_device_destroy(&vhci->cq, devid); + } + status = bce_vhci_cmd_port_reset(&vhci->cq, (u8) index, timeout); + + if (dev) { + if ((status = bce_vhci_cmd_device_create(&vhci->cq, index, &devid))) + return status; + vhci->devices[devid] = dev; + vhci->port_to_device[index] = devid; + + for (i = 0; i < 32; i++) { + if (dev->tq_mask & BIT(i)) { + dir = usb_endpoint_dir_in(&dev->tq[i].endp->desc) ? DMA_FROM_DEVICE : DMA_TO_DEVICE; + if (i == 0) + dir = DMA_BIDIRECTIONAL; + bce_vhci_create_transfer_queue(vhci, &dev->tq[i], dev->tq[i].endp, devid, dir); + bce_vhci_cmd_endpoint_create(&vhci->cq, devid, &dev->tq[i].endp->desc); + } + } + } + + return status; +} + +static int bce_vhci_check_bandwidth(struct usb_hcd *hcd, struct usb_device *udev) +{ + return 0; +} + +static int bce_vhci_get_frame_number(struct usb_hcd *hcd) +{ + return 0; +} + +static int bce_vhci_bus_suspend(struct usb_hcd *hcd) +{ + int i, j; + int status; + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + pr_info("bce_vhci: suspend started\n"); + + pr_info("bce_vhci: suspend endpoints\n"); + for (i = 0; i < 16; i++) { + if (!vhci->port_to_device[i]) + continue; + for (j = 0; j < 32; j++) { + if (!(vhci->devices[vhci->port_to_device[i]]->tq_mask & BIT(j))) + continue; + bce_vhci_transfer_queue_pause(&vhci->devices[vhci->port_to_device[i]]->tq[j], + BCE_VHCI_PAUSE_SUSPEND); + } + } + + pr_info("bce_vhci: suspend ports\n"); + for (i = 0; i < 16; i++) { + if (!vhci->port_to_device[i]) + continue; + bce_vhci_cmd_port_suspend(&vhci->cq, i); + } + pr_info("bce_vhci: suspend controller\n"); + if ((status = bce_vhci_cmd_controller_pause(&vhci->cq))) + return status; + + bce_vhci_event_queue_pause(&vhci->ev_commands); + bce_vhci_event_queue_pause(&vhci->ev_system); + bce_vhci_event_queue_pause(&vhci->ev_isochronous); + bce_vhci_event_queue_pause(&vhci->ev_interrupt); + bce_vhci_event_queue_pause(&vhci->ev_asynchronous); + pr_info("bce_vhci: suspend done\n"); + return 0; +} + +static int bce_vhci_bus_resume(struct usb_hcd *hcd) +{ + int i, j; + int status; + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + pr_info("bce_vhci: resume started\n"); + + bce_vhci_event_queue_resume(&vhci->ev_system); + bce_vhci_event_queue_resume(&vhci->ev_isochronous); + bce_vhci_event_queue_resume(&vhci->ev_interrupt); + bce_vhci_event_queue_resume(&vhci->ev_asynchronous); + bce_vhci_event_queue_resume(&vhci->ev_commands); + + pr_info("bce_vhci: resume controller\n"); + if ((status = bce_vhci_cmd_controller_start(&vhci->cq))) + return status; + + pr_info("bce_vhci: resume ports\n"); + for (i = 0; i < 16; i++) { + if (!vhci->port_to_device[i]) + continue; + bce_vhci_cmd_port_resume(&vhci->cq, i); + } + pr_info("bce_vhci: resume endpoints\n"); + for (i = 0; i < 16; i++) { + if (!vhci->port_to_device[i]) + continue; + for (j = 0; j < 32; j++) { + if (!(vhci->devices[vhci->port_to_device[i]]->tq_mask & BIT(j))) + continue; + bce_vhci_transfer_queue_resume(&vhci->devices[vhci->port_to_device[i]]->tq[j], + BCE_VHCI_PAUSE_SUSPEND); + } + } + + pr_info("bce_vhci: resume done\n"); + return 0; +} + +static int bce_vhci_urb_enqueue(struct usb_hcd *hcd, struct urb *urb, gfp_t mem_flags) +{ + struct bce_vhci_transfer_queue *q = urb->ep->hcpriv; + pr_debug("bce_vhci_urb_enqueue %i:%x\n", q->dev_addr, urb->ep->desc.bEndpointAddress); + if (!q) + return -ENOENT; + return bce_vhci_urb_create(q, urb); +} + +static int bce_vhci_urb_dequeue(struct usb_hcd *hcd, struct urb *urb, int status) +{ + struct bce_vhci_transfer_queue *q = urb->ep->hcpriv; + pr_debug("bce_vhci_urb_dequeue %x\n", urb->ep->desc.bEndpointAddress); + return bce_vhci_urb_request_cancel(q, urb, status); +} + +static void bce_vhci_endpoint_reset(struct usb_hcd *hcd, struct usb_host_endpoint *ep) +{ + struct bce_vhci_transfer_queue *q = ep->hcpriv; + pr_debug("bce_vhci_endpoint_reset\n"); + if (q) + bce_vhci_transfer_queue_request_reset(q); +} + +static u8 bce_vhci_endpoint_index(u8 addr) +{ + if (addr & 0x80) + return (u8) (0x10 + (addr & 0xf)); + return (u8) (addr & 0xf); +} + +static int bce_vhci_add_endpoint(struct usb_hcd *hcd, struct usb_device *udev, struct usb_host_endpoint *endp) +{ + u8 endp_index = bce_vhci_endpoint_index(endp->desc.bEndpointAddress); + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + bce_vhci_device_t devid = vhci->port_to_device[udev->portnum]; + struct bce_vhci_device *vdev = vhci->devices[devid]; + pr_debug("bce_vhci_add_endpoint %x/%x:%x\n", udev->portnum, devid, endp_index); + + if (udev->bus->root_hub == udev) /* The USB hub */ + return 0; + if (vdev == NULL) + return -ENODEV; + if (vdev->tq_mask & BIT(endp_index)) { + endp->hcpriv = &vdev->tq[endp_index]; + return 0; + } + + bce_vhci_create_transfer_queue(vhci, &vdev->tq[endp_index], endp, devid, + usb_endpoint_dir_in(&endp->desc) ? DMA_FROM_DEVICE : DMA_TO_DEVICE); + endp->hcpriv = &vdev->tq[endp_index]; + vdev->tq_mask |= BIT(endp_index); + + bce_vhci_cmd_endpoint_create(&vhci->cq, devid, &endp->desc); + return 0; +} + +static int bce_vhci_drop_endpoint(struct usb_hcd *hcd, struct usb_device *udev, struct usb_host_endpoint *endp) +{ + u8 endp_index = bce_vhci_endpoint_index(endp->desc.bEndpointAddress); + struct bce_vhci *vhci = bce_vhci_from_hcd(hcd); + bce_vhci_device_t devid = vhci->port_to_device[udev->portnum]; + struct bce_vhci_transfer_queue *q = endp->hcpriv; + struct bce_vhci_device *vdev = vhci->devices[devid]; + pr_info("bce_vhci_drop_endpoint %x:%x\n", udev->portnum, endp_index); + if (!q) { + if (vdev && vdev->tq_mask & BIT(endp_index)) { + pr_err("something deleted the hcpriv?\n"); + q = &vdev->tq[endp_index]; + } else { + return 0; + } + } + + bce_vhci_cmd_endpoint_destroy(&vhci->cq, devid, (u8) (endp->desc.bEndpointAddress & 0x8Fu)); + vhci->devices[devid]->tq_mask &= ~BIT(endp_index); + bce_vhci_destroy_transfer_queue(vhci, q); + return 0; +} + +static int bce_vhci_create_message_queues(struct bce_vhci *vhci) +{ + if (bce_vhci_message_queue_create(vhci, &vhci->msg_commands, "VHC1HostCommands") || + bce_vhci_message_queue_create(vhci, &vhci->msg_system, "VHC1HostSystemEvents") || + bce_vhci_message_queue_create(vhci, &vhci->msg_isochronous, "VHC1HostIsochronousEvents") || + bce_vhci_message_queue_create(vhci, &vhci->msg_interrupt, "VHC1HostInterruptEvents") || + bce_vhci_message_queue_create(vhci, &vhci->msg_asynchronous, "VHC1HostAsynchronousEvents")) { + bce_vhci_destroy_message_queues(vhci); + return -EINVAL; + } + spin_lock_init(&vhci->msg_asynchronous_lock); + bce_vhci_command_queue_create(&vhci->cq, &vhci->msg_commands); + return 0; +} + +static void bce_vhci_destroy_message_queues(struct bce_vhci *vhci) +{ + bce_vhci_command_queue_destroy(&vhci->cq); + bce_vhci_message_queue_destroy(vhci, &vhci->msg_commands); + bce_vhci_message_queue_destroy(vhci, &vhci->msg_system); + bce_vhci_message_queue_destroy(vhci, &vhci->msg_isochronous); + bce_vhci_message_queue_destroy(vhci, &vhci->msg_interrupt); + bce_vhci_message_queue_destroy(vhci, &vhci->msg_asynchronous); +} + +static void bce_vhci_handle_system_event(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg); +static void bce_vhci_handle_usb_event(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg); + +static int bce_vhci_create_event_queues(struct bce_vhci *vhci) +{ + vhci->ev_cq = bce_create_cq(vhci->dev, 0x100); + if (!vhci->ev_cq) + return -EINVAL; +#define CREATE_EVENT_QUEUE(field, name, cb) bce_vhci_event_queue_create(vhci, &vhci->field, name, cb) + if (__bce_vhci_event_queue_create(vhci, &vhci->ev_commands, "VHC1FirmwareCommands", + bce_vhci_firmware_event_completion) || + CREATE_EVENT_QUEUE(ev_system, "VHC1FirmwareSystemEvents", bce_vhci_handle_system_event) || + CREATE_EVENT_QUEUE(ev_isochronous, "VHC1FirmwareIsochronousEvents", bce_vhci_handle_usb_event) || + CREATE_EVENT_QUEUE(ev_interrupt, "VHC1FirmwareInterruptEvents", bce_vhci_handle_usb_event) || + CREATE_EVENT_QUEUE(ev_asynchronous, "VHC1FirmwareAsynchronousEvents", bce_vhci_handle_usb_event)) { + bce_vhci_destroy_event_queues(vhci); + return -EINVAL; + } +#undef CREATE_EVENT_QUEUE + return 0; +} + +static void bce_vhci_destroy_event_queues(struct bce_vhci *vhci) +{ + bce_vhci_event_queue_destroy(vhci, &vhci->ev_commands); + bce_vhci_event_queue_destroy(vhci, &vhci->ev_system); + bce_vhci_event_queue_destroy(vhci, &vhci->ev_isochronous); + bce_vhci_event_queue_destroy(vhci, &vhci->ev_interrupt); + bce_vhci_event_queue_destroy(vhci, &vhci->ev_asynchronous); + if (vhci->ev_cq) + bce_destroy_cq(vhci->dev, vhci->ev_cq); +} + +static void bce_vhci_send_fw_event_response(struct bce_vhci *vhci, struct bce_vhci_message *req, u16 status) +{ + unsigned long timeout = 1000; + struct bce_vhci_message r = *req; + r.cmd = (u16) (req->cmd | 0x8000u); + r.status = status; + r.param1 = req->param1; + r.param2 = 0; + + if (bce_reserve_submission(vhci->msg_system.sq, &timeout)) { + pr_err("bce-vhci: Cannot reserve submision for FW event reply\n"); + return; + } + bce_vhci_message_queue_write(&vhci->msg_system, &r); +} + +static int bce_vhci_handle_firmware_event(struct bce_vhci *vhci, struct bce_vhci_message *msg) +{ + unsigned long flags; + bce_vhci_device_t devid; + u8 endp; + struct bce_vhci_device *dev; + struct bce_vhci_transfer_queue *tq; + if (msg->cmd == BCE_VHCI_CMD_ENDPOINT_REQUEST_STATE || msg->cmd == BCE_VHCI_CMD_ENDPOINT_SET_STATE) { + devid = (bce_vhci_device_t) (msg->param1 & 0xff); + endp = bce_vhci_endpoint_index((u8) ((msg->param1 >> 8) & 0xff)); + dev = vhci->devices[devid]; + if (!dev || !(dev->tq_mask & BIT(endp))) + return BCE_VHCI_BAD_ARGUMENT; + tq = &dev->tq[endp]; + } + + if (msg->cmd == BCE_VHCI_CMD_ENDPOINT_REQUEST_STATE) { + if (msg->param2 == BCE_VHCI_ENDPOINT_ACTIVE) { + bce_vhci_transfer_queue_resume(tq, BCE_VHCI_PAUSE_FIRMWARE); + return BCE_VHCI_SUCCESS; + } else if (msg->param2 == BCE_VHCI_ENDPOINT_PAUSED) { + bce_vhci_transfer_queue_pause(tq, BCE_VHCI_PAUSE_FIRMWARE); + return BCE_VHCI_SUCCESS; + } + return BCE_VHCI_BAD_ARGUMENT; + } else if (msg->cmd == BCE_VHCI_CMD_ENDPOINT_SET_STATE) { + if (msg->param2 == BCE_VHCI_ENDPOINT_STALLED) { + tq->state = msg->param2; + spin_lock_irqsave(&tq->urb_lock, flags); + tq->stalled = true; + spin_unlock_irqrestore(&tq->urb_lock, flags); + return BCE_VHCI_SUCCESS; + } + return BCE_VHCI_BAD_ARGUMENT; + } + pr_warn("bce-vhci: Unhandled firmware event: %x s=%x p1=%x p2=%llx\n", + msg->cmd, msg->status, msg->param1, msg->param2); + return BCE_VHCI_BAD_ARGUMENT; +} + +static void bce_vhci_handle_firmware_events_w(struct work_struct *ws) +{ + size_t cnt = 0; + int result; + struct bce_vhci *vhci = container_of(ws, struct bce_vhci, w_fw_events); + struct bce_queue_sq *sq = vhci->ev_commands.sq; + struct bce_sq_completion_data *cq; + struct bce_vhci_message *msg, *msg2 = NULL; + + while (true) { + if (msg2) { + msg = msg2; + msg2 = NULL; + } else if ((cq = bce_next_completion(sq))) { + if (cq->status == BCE_COMPLETION_ABORTED) { + bce_notify_submission_complete(sq); + continue; + } + msg = &vhci->ev_commands.data[sq->head]; + } else { + break; + } + + pr_debug("bce-vhci: Got fw event: %x s=%x p1=%x p2=%llx\n", msg->cmd, msg->status, msg->param1, msg->param2); + if ((cq = bce_next_completion(sq))) { + msg2 = &vhci->ev_commands.data[(sq->head + 1) % sq->el_count]; + pr_debug("bce-vhci: Got second fw event: %x s=%x p1=%x p2=%llx\n", + msg->cmd, msg->status, msg->param1, msg->param2); + if (cq->status != BCE_COMPLETION_ABORTED && + msg2->cmd == (msg->cmd | 0x4000) && msg2->param1 == msg->param1) { + /* Take two elements */ + pr_debug("bce-vhci: Cancelled\n"); + bce_vhci_send_fw_event_response(vhci, msg, BCE_VHCI_ABORT); + + bce_notify_submission_complete(sq); + bce_notify_submission_complete(sq); + msg2 = NULL; + cnt += 2; + continue; + } + + pr_warn("bce-vhci: Handle fw event - unexpected cancellation\n"); + } + + result = bce_vhci_handle_firmware_event(vhci, msg); + bce_vhci_send_fw_event_response(vhci, msg, (u16) result); + + + bce_notify_submission_complete(sq); + ++cnt; + } + bce_vhci_event_queue_submit_pending(&vhci->ev_commands, cnt); + if (atomic_read(&sq->available_commands) == sq->el_count - 1) { + pr_debug("bce-vhci: complete\n"); + complete(&vhci->ev_commands.queue_empty_completion); + } +} + +static void bce_vhci_firmware_event_completion(struct bce_queue_sq *sq) +{ + struct bce_vhci_event_queue *q = sq->userdata; + queue_work(q->vhci->tq_state_wq, &q->vhci->w_fw_events); +} + +static void bce_vhci_handle_system_event(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg) +{ + if (msg->cmd & 0x8000) { + bce_vhci_command_queue_deliver_completion(&q->vhci->cq, msg); + } else { + pr_warn("bce-vhci: Unhandled system event: %x s=%x p1=%x p2=%llx\n", + msg->cmd, msg->status, msg->param1, msg->param2); + } +} + +static void bce_vhci_handle_usb_event(struct bce_vhci_event_queue *q, struct bce_vhci_message *msg) +{ + bce_vhci_device_t devid; + u8 endp; + struct bce_vhci_device *dev; + if (msg->cmd & 0x8000) { + bce_vhci_command_queue_deliver_completion(&q->vhci->cq, msg); + } else if (msg->cmd == BCE_VHCI_CMD_TRANSFER_REQUEST || msg->cmd == BCE_VHCI_CMD_CONTROL_TRANSFER_STATUS) { + devid = (bce_vhci_device_t) (msg->param1 & 0xff); + endp = bce_vhci_endpoint_index((u8) ((msg->param1 >> 8) & 0xff)); + dev = q->vhci->devices[devid]; + if (!dev || (dev->tq_mask & BIT(endp)) == 0) { + pr_err("bce-vhci: Didn't find destination for transfer queue event\n"); + return; + } + bce_vhci_transfer_queue_event(&dev->tq[endp], msg); + } else { + pr_warn("bce-vhci: Unhandled USB event: %x s=%x p1=%x p2=%llx\n", + msg->cmd, msg->status, msg->param1, msg->param2); + } +} + + + +static const struct hc_driver bce_vhci_driver = { + .description = "bce-vhci", + .product_desc = "BCE VHCI Host Controller", + .hcd_priv_size = sizeof(struct bce_vhci *), + +#if LINUX_VERSION_CODE < KERNEL_VERSION(5,4,0) + .flags = HCD_USB2, +#else + .flags = HCD_USB2 | HCD_DMA, +#endif + + .start = bce_vhci_start, + .stop = bce_vhci_stop, + .hub_status_data = bce_vhci_hub_status_data, + .hub_control = bce_vhci_hub_control, + .urb_enqueue = bce_vhci_urb_enqueue, + .urb_dequeue = bce_vhci_urb_dequeue, + .enable_device = bce_vhci_enable_device, + .free_dev = bce_vhci_free_device, + .address_device = bce_vhci_address_device, + .add_endpoint = bce_vhci_add_endpoint, + .drop_endpoint = bce_vhci_drop_endpoint, + .endpoint_reset = bce_vhci_endpoint_reset, + .check_bandwidth = bce_vhci_check_bandwidth, + .get_frame_number = bce_vhci_get_frame_number, + .bus_suspend = bce_vhci_bus_suspend, + .bus_resume = bce_vhci_bus_resume +}; + + +int __init bce_vhci_module_init(void) +{ + int result; + if ((result = alloc_chrdev_region(&bce_vhci_chrdev, 0, 1, "bce-vhci"))) + goto fail_chrdev; +#if LINUX_VERSION_CODE < KERNEL_VERSION(6,4,0) + bce_vhci_class = class_create(THIS_MODULE, "bce-vhci"); +#else + bce_vhci_class = class_create("bce-vhci"); +#endif + if (IS_ERR(bce_vhci_class)) { + result = PTR_ERR(bce_vhci_class); + goto fail_class; + } + return 0; + +fail_class: + class_destroy(bce_vhci_class); +fail_chrdev: + unregister_chrdev_region(bce_vhci_chrdev, 1); + if (!result) + result = -EINVAL; + return result; +} +void __exit bce_vhci_module_exit(void) +{ + class_destroy(bce_vhci_class); + unregister_chrdev_region(bce_vhci_chrdev, 1); +} + +module_param_named(vhci_port_mask, bce_vhci_port_mask, ushort, 0444); +MODULE_PARM_DESC(vhci_port_mask, "Specifies which VHCI ports are enabled"); diff --git a/drivers/staging/apple-bce/vhci/vhci.h b/drivers/staging/apple-bce/vhci/vhci.h new file mode 100644 index 000000000000..6c2e22622f4c --- /dev/null +++ b/drivers/staging/apple-bce/vhci/vhci.h @@ -0,0 +1,52 @@ +#ifndef BCE_VHCI_H +#define BCE_VHCI_H + +#include "queue.h" +#include "transfer.h" + +struct usb_hcd; +struct bce_queue_cq; + +struct bce_vhci_device { + struct bce_vhci_transfer_queue tq[32]; + u32 tq_mask; +}; +struct bce_vhci { + struct apple_bce_device *dev; + dev_t vdevt; + struct device *vdev; + struct usb_hcd *hcd; + struct spinlock hcd_spinlock; + struct bce_vhci_message_queue msg_commands; + struct bce_vhci_message_queue msg_system; + struct bce_vhci_message_queue msg_isochronous; + struct bce_vhci_message_queue msg_interrupt; + struct bce_vhci_message_queue msg_asynchronous; + struct spinlock msg_asynchronous_lock; + struct bce_vhci_command_queue cq; + struct bce_queue_cq *ev_cq; + struct bce_vhci_event_queue ev_commands; + struct bce_vhci_event_queue ev_system; + struct bce_vhci_event_queue ev_isochronous; + struct bce_vhci_event_queue ev_interrupt; + struct bce_vhci_event_queue ev_asynchronous; + u16 port_mask; + u8 port_count; + u16 port_power_mask; + bce_vhci_device_t port_to_device[16]; + struct bce_vhci_device *devices[16]; + struct workqueue_struct *tq_state_wq; + struct work_struct w_fw_events; +}; + +int __init bce_vhci_module_init(void); +void __exit bce_vhci_module_exit(void); + +int bce_vhci_create(struct apple_bce_device *dev, struct bce_vhci *vhci); +void bce_vhci_destroy(struct bce_vhci *vhci); +int bce_vhci_start(struct usb_hcd *hcd); +void bce_vhci_stop(struct usb_hcd *hcd); + +struct bce_vhci *bce_vhci_from_hcd(struct usb_hcd *hcd); + +#endif //BCE_VHCI_H diff --git a/include/drm/drm_format_helper.h b/include/drm/drm_format_helper.h index 428d81afe215..aa1604d92c1a 100644 --- a/include/drm/drm_format_helper.h +++ b/include/drm/drm_format_helper.h @@ -96,6 +96,9 @@ void drm_fb_xrgb8888_to_rgba5551(struct iosys_map *dst, const unsigned int *dst_ void drm_fb_xrgb8888_to_rgb888(struct iosys_map *dst, const unsigned int *dst_pitch, const struct iosys_map *src, const struct drm_framebuffer *fb, const struct drm_rect *clip, struct drm_format_conv_state *state); +void drm_fb_xrgb8888_to_bgr888(struct iosys_map *dst, const unsigned int *dst_pitch, + const struct iosys_map *src, const struct drm_framebuffer *fb, + const struct drm_rect *clip, struct drm_format_conv_state *state); void drm_fb_xrgb8888_to_argb8888(struct iosys_map *dst, const unsigned int *dst_pitch, const struct iosys_map *src, const struct drm_framebuffer *fb, const struct drm_rect *clip, struct drm_format_conv_state *state); diff --git a/lib/test_printf.c b/lib/test_printf.c index 8448b6d02bd9..f63591b3ee69 100644 --- a/lib/test_printf.c +++ b/lib/test_printf.c @@ -719,18 +719,26 @@ static void __init fwnode_pointer(void) static void __init fourcc_pointer(void) { struct { + char type; u32 code; char *str; } const try[] = { - { 0x3231564e, "NV12 little-endian (0x3231564e)", }, - { 0xb231564e, "NV12 big-endian (0xb231564e)", }, - { 0x10111213, ".... little-endian (0x10111213)", }, - { 0x20303159, "Y10 little-endian (0x20303159)", }, + { 'c', 0x3231564e, "NV12 little-endian (0x3231564e)", }, + { 'c', 0xb231564e, "NV12 big-endian (0xb231564e)", }, + { 'c', 0x10111213, ".... little-endian (0x10111213)", }, + { 'c', 0x20303159, "Y10 little-endian (0x20303159)", }, + { 'h', 0x67503030, "gP00 (0x67503030)", }, + { 'r', 0x30305067, "gP00 (0x67503030)", }, + { 'l', cpu_to_le32(0x67503030), "gP00 (0x67503030)", }, + { 'b', cpu_to_be32(0x67503030), "gP00 (0x67503030)", }, }; unsigned int i; - for (i = 0; i < ARRAY_SIZE(try); i++) - test(try[i].str, "%p4cc", &try[i].code); + for (i = 0; i < ARRAY_SIZE(try); i++) { + char fmt[] = { '%', 'p', '4', 'c', try[i].type, '\0' }; + + test(try[i].str, fmt, &try[i].code); + } } static void __init diff --git a/lib/vsprintf.c b/lib/vsprintf.c index c5e2ec9303c5..874e3af8104c 100644 --- a/lib/vsprintf.c +++ b/lib/vsprintf.c @@ -1760,27 +1760,50 @@ char *fourcc_string(char *buf, char *end, const u32 *fourcc, char output[sizeof("0123 little-endian (0x01234567)")]; char *p = output; unsigned int i; + bool pix_fmt = false; u32 orig, val; - if (fmt[1] != 'c' || fmt[2] != 'c') + if (fmt[1] != 'c') return error_string(buf, end, "(%p4?)", spec); if (check_pointer(&buf, end, fourcc, spec)) return buf; orig = get_unaligned(fourcc); - val = orig & ~BIT(31); + switch (fmt[2]) { + case 'h': + val = orig; + break; + case 'r': + val = orig = swab32(orig); + break; + case 'l': + val = orig = le32_to_cpu(orig); + break; + case 'b': + val = orig = be32_to_cpu(orig); + break; + case 'c': + /* Pixel formats are printed LSB-first */ + val = swab32(orig & ~BIT(31)); + pix_fmt = true; + break; + default: + return error_string(buf, end, "(%p4?)", spec); + } for (i = 0; i < sizeof(u32); i++) { - unsigned char c = val >> (i * 8); + unsigned char c = val >> ((3 - i) * 8); /* Print non-control ASCII characters as-is, dot otherwise */ *p++ = isascii(c) && isprint(c) ? c : '.'; } - *p++ = ' '; - strcpy(p, orig & BIT(31) ? "big-endian" : "little-endian"); - p += strlen(p); + if (pix_fmt) { + *p++ = ' '; + strcpy(p, orig & BIT(31) ? "big-endian" : "little-endian"); + p += strlen(p); + } *p++ = ' '; *p++ = '('; @@ -2334,6 +2357,7 @@ char *rust_fmt_argument(char *buf, char *end, void *ptr); * read the documentation (path below) first. * - 'NF' For a netdev_features_t * - '4cc' V4L2 or DRM FourCC code, with endianness and raw numerical value. + * - '4c[hlbr]' Generic FourCC code. * - 'h[CDN]' For a variable-length buffer, it prints it as a hex string with * a certain separator (' ' by default): * C colon diff --git a/scripts/checkpatch.pl b/scripts/checkpatch.pl index 4427572b2477..b60c99d61882 100755 --- a/scripts/checkpatch.pl +++ b/scripts/checkpatch.pl @@ -6917,7 +6917,7 @@ sub process { ($extension eq "f" && defined $qualifier && $qualifier !~ /^w/) || ($extension eq "4" && - defined $qualifier && $qualifier !~ /^cc/)) { + defined $qualifier && $qualifier !~ /^c[chlbr]/)) { $bad_specifier = $specifier; last; } -- 2.47.0 From 126ef40989e28bba3ff5a4bb41333942de1c9dbf Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 11 Nov 2024 09:22:31 +0100 Subject: [PATCH 13/13] zstd Signed-off-by: Peter Jung --- include/linux/zstd.h | 2 +- include/linux/zstd_errors.h | 23 +- include/linux/zstd_lib.h | 850 +++++-- lib/zstd/Makefile | 2 +- lib/zstd/common/allocations.h | 56 + lib/zstd/common/bits.h | 149 ++ lib/zstd/common/bitstream.h | 127 +- lib/zstd/common/compiler.h | 134 +- lib/zstd/common/cpu.h | 3 +- lib/zstd/common/debug.c | 9 +- lib/zstd/common/debug.h | 34 +- lib/zstd/common/entropy_common.c | 42 +- lib/zstd/common/error_private.c | 12 +- lib/zstd/common/error_private.h | 84 +- lib/zstd/common/fse.h | 94 +- lib/zstd/common/fse_decompress.c | 130 +- lib/zstd/common/huf.h | 237 +- lib/zstd/common/mem.h | 3 +- lib/zstd/common/portability_macros.h | 28 +- lib/zstd/common/zstd_common.c | 38 +- lib/zstd/common/zstd_deps.h | 16 +- lib/zstd/common/zstd_internal.h | 109 +- lib/zstd/compress/clevels.h | 3 +- lib/zstd/compress/fse_compress.c | 74 +- lib/zstd/compress/hist.c | 3 +- lib/zstd/compress/hist.h | 3 +- lib/zstd/compress/huf_compress.c | 441 ++-- lib/zstd/compress/zstd_compress.c | 2111 ++++++++++++----- lib/zstd/compress/zstd_compress_internal.h | 359 ++- lib/zstd/compress/zstd_compress_literals.c | 155 +- lib/zstd/compress/zstd_compress_literals.h | 25 +- lib/zstd/compress/zstd_compress_sequences.c | 7 +- lib/zstd/compress/zstd_compress_sequences.h | 3 +- lib/zstd/compress/zstd_compress_superblock.c | 376 ++- lib/zstd/compress/zstd_compress_superblock.h | 3 +- lib/zstd/compress/zstd_cwksp.h | 169 +- lib/zstd/compress/zstd_double_fast.c | 143 +- lib/zstd/compress/zstd_double_fast.h | 17 +- lib/zstd/compress/zstd_fast.c | 596 +++-- lib/zstd/compress/zstd_fast.h | 6 +- lib/zstd/compress/zstd_lazy.c | 732 +++--- lib/zstd/compress/zstd_lazy.h | 138 +- lib/zstd/compress/zstd_ldm.c | 21 +- lib/zstd/compress/zstd_ldm.h | 3 +- lib/zstd/compress/zstd_ldm_geartab.h | 3 +- lib/zstd/compress/zstd_opt.c | 497 ++-- lib/zstd/compress/zstd_opt.h | 41 +- lib/zstd/decompress/huf_decompress.c | 887 ++++--- lib/zstd/decompress/zstd_ddict.c | 9 +- lib/zstd/decompress/zstd_ddict.h | 3 +- lib/zstd/decompress/zstd_decompress.c | 358 ++- lib/zstd/decompress/zstd_decompress_block.c | 708 +++--- lib/zstd/decompress/zstd_decompress_block.h | 10 +- .../decompress/zstd_decompress_internal.h | 9 +- lib/zstd/decompress_sources.h | 2 +- lib/zstd/zstd_common_module.c | 5 +- lib/zstd/zstd_compress_module.c | 2 +- lib/zstd/zstd_decompress_module.c | 4 +- 58 files changed, 6577 insertions(+), 3531 deletions(-) create mode 100644 lib/zstd/common/allocations.h create mode 100644 lib/zstd/common/bits.h diff --git a/include/linux/zstd.h b/include/linux/zstd.h index b2c7cf310c8f..ac59ae9a18d7 100644 --- a/include/linux/zstd.h +++ b/include/linux/zstd.h @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/include/linux/zstd_errors.h b/include/linux/zstd_errors.h index 58b6dd45a969..6d5cf55f0bf3 100644 --- a/include/linux/zstd_errors.h +++ b/include/linux/zstd_errors.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -17,8 +18,17 @@ /* ===== ZSTDERRORLIB_API : control library symbols visibility ===== */ -#define ZSTDERRORLIB_VISIBILITY -#define ZSTDERRORLIB_API ZSTDERRORLIB_VISIBILITY +#define ZSTDERRORLIB_VISIBLE + +#ifndef ZSTDERRORLIB_HIDDEN +# if (__GNUC__ >= 4) && !defined(__MINGW32__) +# define ZSTDERRORLIB_HIDDEN __attribute__ ((visibility ("hidden"))) +# else +# define ZSTDERRORLIB_HIDDEN +# endif +#endif + +#define ZSTDERRORLIB_API ZSTDERRORLIB_VISIBLE /*-********************************************* * Error codes list @@ -43,14 +53,17 @@ typedef enum { ZSTD_error_frameParameter_windowTooLarge = 16, ZSTD_error_corruption_detected = 20, ZSTD_error_checksum_wrong = 22, + ZSTD_error_literals_headerWrong = 24, ZSTD_error_dictionary_corrupted = 30, ZSTD_error_dictionary_wrong = 32, ZSTD_error_dictionaryCreation_failed = 34, ZSTD_error_parameter_unsupported = 40, + ZSTD_error_parameter_combination_unsupported = 41, ZSTD_error_parameter_outOfBound = 42, ZSTD_error_tableLog_tooLarge = 44, ZSTD_error_maxSymbolValue_tooLarge = 46, ZSTD_error_maxSymbolValue_tooSmall = 48, + ZSTD_error_stabilityCondition_notRespected = 50, ZSTD_error_stage_wrong = 60, ZSTD_error_init_missing = 62, ZSTD_error_memory_allocation = 64, @@ -58,11 +71,15 @@ typedef enum { ZSTD_error_dstSize_tooSmall = 70, ZSTD_error_srcSize_wrong = 72, ZSTD_error_dstBuffer_null = 74, + ZSTD_error_noForwardProgress_destFull = 80, + ZSTD_error_noForwardProgress_inputEmpty = 82, /* following error codes are __NOT STABLE__, they can be removed or changed in future versions */ ZSTD_error_frameIndex_tooLarge = 100, ZSTD_error_seekableIO = 102, ZSTD_error_dstBuffer_wrong = 104, ZSTD_error_srcBuffer_wrong = 105, + ZSTD_error_sequenceProducer_failed = 106, + ZSTD_error_externalSequences_invalid = 107, ZSTD_error_maxCode = 120 /* never EVER use this value directly, it can change in future versions! Use ZSTD_isError() instead */ } ZSTD_ErrorCode; diff --git a/include/linux/zstd_lib.h b/include/linux/zstd_lib.h index 79d55465d5c1..6320fedcf8a4 100644 --- a/include/linux/zstd_lib.h +++ b/include/linux/zstd_lib.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,23 +12,42 @@ #ifndef ZSTD_H_235446 #define ZSTD_H_235446 -/* ====== Dependency ======*/ +/* ====== Dependencies ======*/ #include /* INT_MAX */ #include /* size_t */ /* ===== ZSTDLIB_API : control library symbols visibility ===== */ -#ifndef ZSTDLIB_VISIBLE +#define ZSTDLIB_VISIBLE + +#ifndef ZSTDLIB_HIDDEN # if (__GNUC__ >= 4) && !defined(__MINGW32__) -# define ZSTDLIB_VISIBLE __attribute__ ((visibility ("default"))) # define ZSTDLIB_HIDDEN __attribute__ ((visibility ("hidden"))) # else -# define ZSTDLIB_VISIBLE # define ZSTDLIB_HIDDEN # endif #endif + #define ZSTDLIB_API ZSTDLIB_VISIBLE +/* Deprecation warnings : + * Should these warnings be a problem, it is generally possible to disable them, + * typically with -Wno-deprecated-declarations for gcc or _CRT_SECURE_NO_WARNINGS in Visual. + * Otherwise, it's also possible to define ZSTD_DISABLE_DEPRECATE_WARNINGS. + */ +#ifdef ZSTD_DISABLE_DEPRECATE_WARNINGS +# define ZSTD_DEPRECATED(message) /* disable deprecation warnings */ +#else +# if (defined(GNUC) && (GNUC > 4 || (GNUC == 4 && GNUC_MINOR >= 5))) || defined(__clang__) +# define ZSTD_DEPRECATED(message) __attribute__((deprecated(message))) +# elif (__GNUC__ >= 3) +# define ZSTD_DEPRECATED(message) __attribute__((deprecated)) +# else +# pragma message("WARNING: You need to implement ZSTD_DEPRECATED for this compiler") +# define ZSTD_DEPRECATED(message) +# endif +#endif /* ZSTD_DISABLE_DEPRECATE_WARNINGS */ + /* ***************************************************************************** Introduction @@ -65,7 +85,7 @@ /*------ Version ------*/ #define ZSTD_VERSION_MAJOR 1 #define ZSTD_VERSION_MINOR 5 -#define ZSTD_VERSION_RELEASE 2 +#define ZSTD_VERSION_RELEASE 6 #define ZSTD_VERSION_NUMBER (ZSTD_VERSION_MAJOR *100*100 + ZSTD_VERSION_MINOR *100 + ZSTD_VERSION_RELEASE) /*! ZSTD_versionNumber() : @@ -107,7 +127,8 @@ ZSTDLIB_API const char* ZSTD_versionString(void); ***************************************/ /*! ZSTD_compress() : * Compresses `src` content as a single zstd compressed frame into already allocated `dst`. - * Hint : compression runs faster if `dstCapacity` >= `ZSTD_compressBound(srcSize)`. + * NOTE: Providing `dstCapacity >= ZSTD_compressBound(srcSize)` guarantees that zstd will have + * enough space to successfully compress the data. * @return : compressed size written into `dst` (<= `dstCapacity), * or an error code if it fails (which can be tested using ZSTD_isError()). */ ZSTDLIB_API size_t ZSTD_compress( void* dst, size_t dstCapacity, @@ -156,7 +177,9 @@ ZSTDLIB_API unsigned long long ZSTD_getFrameContentSize(const void *src, size_t * "empty", "unknown" and "error" results to the same return value (0), * while ZSTD_getFrameContentSize() gives them separate return values. * @return : decompressed size of `src` frame content _if known and not empty_, 0 otherwise. */ -ZSTDLIB_API unsigned long long ZSTD_getDecompressedSize(const void* src, size_t srcSize); +ZSTD_DEPRECATED("Replaced by ZSTD_getFrameContentSize") +ZSTDLIB_API +unsigned long long ZSTD_getDecompressedSize(const void* src, size_t srcSize); /*! ZSTD_findFrameCompressedSize() : Requires v1.4.0+ * `src` should point to the start of a ZSTD frame or skippable frame. @@ -168,8 +191,30 @@ ZSTDLIB_API size_t ZSTD_findFrameCompressedSize(const void* src, size_t srcSize) /*====== Helper functions ======*/ -#define ZSTD_COMPRESSBOUND(srcSize) ((srcSize) + ((srcSize)>>8) + (((srcSize) < (128<<10)) ? (((128<<10) - (srcSize)) >> 11) /* margin, from 64 to 0 */ : 0)) /* this formula ensures that bound(A) + bound(B) <= bound(A+B) as long as A and B >= 128 KB */ -ZSTDLIB_API size_t ZSTD_compressBound(size_t srcSize); /*!< maximum compressed size in worst case single-pass scenario */ +/* ZSTD_compressBound() : + * maximum compressed size in worst case single-pass scenario. + * When invoking `ZSTD_compress()` or any other one-pass compression function, + * it's recommended to provide @dstCapacity >= ZSTD_compressBound(srcSize) + * as it eliminates one potential failure scenario, + * aka not enough room in dst buffer to write the compressed frame. + * Note : ZSTD_compressBound() itself can fail, if @srcSize > ZSTD_MAX_INPUT_SIZE . + * In which case, ZSTD_compressBound() will return an error code + * which can be tested using ZSTD_isError(). + * + * ZSTD_COMPRESSBOUND() : + * same as ZSTD_compressBound(), but as a macro. + * It can be used to produce constants, which can be useful for static allocation, + * for example to size a static array on stack. + * Will produce constant value 0 if srcSize too large. + */ +#define ZSTD_MAX_INPUT_SIZE ((sizeof(size_t)==8) ? 0xFF00FF00FF00FF00ULL : 0xFF00FF00U) +#define ZSTD_COMPRESSBOUND(srcSize) (((size_t)(srcSize) >= ZSTD_MAX_INPUT_SIZE) ? 0 : (srcSize) + ((srcSize)>>8) + (((srcSize) < (128<<10)) ? (((128<<10) - (srcSize)) >> 11) /* margin, from 64 to 0 */ : 0)) /* this formula ensures that bound(A) + bound(B) <= bound(A+B) as long as A and B >= 128 KB */ +ZSTDLIB_API size_t ZSTD_compressBound(size_t srcSize); /*!< maximum compressed size in worst case single-pass scenario */ +/* ZSTD_isError() : + * Most ZSTD_* functions returning a size_t value can be tested for error, + * using ZSTD_isError(). + * @return 1 if error, 0 otherwise + */ ZSTDLIB_API unsigned ZSTD_isError(size_t code); /*!< tells if a `size_t` function result is an error code */ ZSTDLIB_API const char* ZSTD_getErrorName(size_t code); /*!< provides readable string from an error code */ ZSTDLIB_API int ZSTD_minCLevel(void); /*!< minimum negative compression level allowed, requires v1.4.0+ */ @@ -183,7 +228,7 @@ ZSTDLIB_API int ZSTD_defaultCLevel(void); /*!< default compres /*= Compression context * When compressing many times, * it is recommended to allocate a context just once, - * and re-use it for each successive compression operation. + * and reuse it for each successive compression operation. * This will make workload friendlier for system's memory. * Note : re-using context is just a speed / resource optimization. * It doesn't change the compression ratio, which remains identical. @@ -196,9 +241,9 @@ ZSTDLIB_API size_t ZSTD_freeCCtx(ZSTD_CCtx* cctx); /* accept NULL pointer * /*! ZSTD_compressCCtx() : * Same as ZSTD_compress(), using an explicit ZSTD_CCtx. - * Important : in order to behave similarly to `ZSTD_compress()`, - * this function compresses at requested compression level, - * __ignoring any other parameter__ . + * Important : in order to mirror `ZSTD_compress()` behavior, + * this function compresses at the requested compression level, + * __ignoring any other advanced parameter__ . * If any advanced parameter was set using the advanced API, * they will all be reset. Only `compressionLevel` remains. */ @@ -210,7 +255,7 @@ ZSTDLIB_API size_t ZSTD_compressCCtx(ZSTD_CCtx* cctx, /*= Decompression context * When decompressing many times, * it is recommended to allocate a context only once, - * and re-use it for each successive compression operation. + * and reuse it for each successive compression operation. * This will make workload friendlier for system's memory. * Use one context per thread for parallel execution. */ typedef struct ZSTD_DCtx_s ZSTD_DCtx; @@ -220,7 +265,7 @@ ZSTDLIB_API size_t ZSTD_freeDCtx(ZSTD_DCtx* dctx); /* accept NULL pointer * /*! ZSTD_decompressDCtx() : * Same as ZSTD_decompress(), * requires an allocated ZSTD_DCtx. - * Compatible with sticky parameters. + * Compatible with sticky parameters (see below). */ ZSTDLIB_API size_t ZSTD_decompressDCtx(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, @@ -236,12 +281,12 @@ ZSTDLIB_API size_t ZSTD_decompressDCtx(ZSTD_DCtx* dctx, * using ZSTD_CCtx_set*() functions. * Pushed parameters are sticky : they are valid for next compressed frame, and any subsequent frame. * "sticky" parameters are applicable to `ZSTD_compress2()` and `ZSTD_compressStream*()` ! - * __They do not apply to "simple" one-shot variants such as ZSTD_compressCCtx()__ . + * __They do not apply to one-shot variants such as ZSTD_compressCCtx()__ . * * It's possible to reset all parameters to "default" using ZSTD_CCtx_reset(). * * This API supersedes all other "advanced" API entry points in the experimental section. - * In the future, we expect to remove from experimental API entry points which are redundant with this API. + * In the future, we expect to remove API entry points from experimental which are redundant with this API. */ @@ -324,6 +369,19 @@ typedef enum { * The higher the value of selected strategy, the more complex it is, * resulting in stronger and slower compression. * Special: value 0 means "use default strategy". */ + + ZSTD_c_targetCBlockSize=130, /* v1.5.6+ + * Attempts to fit compressed block size into approximatively targetCBlockSize. + * Bound by ZSTD_TARGETCBLOCKSIZE_MIN and ZSTD_TARGETCBLOCKSIZE_MAX. + * Note that it's not a guarantee, just a convergence target (default:0). + * No target when targetCBlockSize == 0. + * This is helpful in low bandwidth streaming environments to improve end-to-end latency, + * when a client can make use of partial documents (a prominent example being Chrome). + * Note: this parameter is stable since v1.5.6. + * It was present as an experimental parameter in earlier versions, + * but it's not recommended using it with earlier library versions + * due to massive performance regressions. + */ /* LDM mode parameters */ ZSTD_c_enableLongDistanceMatching=160, /* Enable long distance matching. * This parameter is designed to improve compression ratio @@ -403,7 +461,6 @@ typedef enum { * ZSTD_c_forceMaxWindow * ZSTD_c_forceAttachDict * ZSTD_c_literalCompressionMode - * ZSTD_c_targetCBlockSize * ZSTD_c_srcSizeHint * ZSTD_c_enableDedicatedDictSearch * ZSTD_c_stableInBuffer @@ -412,6 +469,9 @@ typedef enum { * ZSTD_c_validateSequences * ZSTD_c_useBlockSplitter * ZSTD_c_useRowMatchFinder + * ZSTD_c_prefetchCDictTables + * ZSTD_c_enableSeqProducerFallback + * ZSTD_c_maxBlockSize * Because they are not stable, it's necessary to define ZSTD_STATIC_LINKING_ONLY to access them. * note : never ever use experimentalParam? names directly; * also, the enums values themselves are unstable and can still change. @@ -421,7 +481,7 @@ typedef enum { ZSTD_c_experimentalParam3=1000, ZSTD_c_experimentalParam4=1001, ZSTD_c_experimentalParam5=1002, - ZSTD_c_experimentalParam6=1003, + /* was ZSTD_c_experimentalParam6=1003; is now ZSTD_c_targetCBlockSize */ ZSTD_c_experimentalParam7=1004, ZSTD_c_experimentalParam8=1005, ZSTD_c_experimentalParam9=1006, @@ -430,7 +490,11 @@ typedef enum { ZSTD_c_experimentalParam12=1009, ZSTD_c_experimentalParam13=1010, ZSTD_c_experimentalParam14=1011, - ZSTD_c_experimentalParam15=1012 + ZSTD_c_experimentalParam15=1012, + ZSTD_c_experimentalParam16=1013, + ZSTD_c_experimentalParam17=1014, + ZSTD_c_experimentalParam18=1015, + ZSTD_c_experimentalParam19=1016 } ZSTD_cParameter; typedef struct { @@ -493,7 +557,7 @@ typedef enum { * They will be used to compress next frame. * Resetting session never fails. * - The parameters : changes all parameters back to "default". - * This removes any reference to any dictionary too. + * This also removes any reference to any dictionary or external sequence producer. * Parameters can only be changed between 2 sessions (i.e. no compression is currently ongoing) * otherwise the reset fails, and function returns an error value (which can be tested using ZSTD_isError()) * - Both : similar to resetting the session, followed by resetting parameters. @@ -502,11 +566,13 @@ ZSTDLIB_API size_t ZSTD_CCtx_reset(ZSTD_CCtx* cctx, ZSTD_ResetDirective reset); /*! ZSTD_compress2() : * Behave the same as ZSTD_compressCCtx(), but compression parameters are set using the advanced API. + * (note that this entry point doesn't even expose a compression level parameter). * ZSTD_compress2() always starts a new frame. * Should cctx hold data from a previously unfinished frame, everything about it is forgotten. * - Compression parameters are pushed into CCtx before starting compression, using ZSTD_CCtx_set*() * - The function is always blocking, returns when compression is completed. - * Hint : compression runs faster if `dstCapacity` >= `ZSTD_compressBound(srcSize)`. + * NOTE: Providing `dstCapacity >= ZSTD_compressBound(srcSize)` guarantees that zstd will have + * enough space to successfully compress the data, though it is possible it fails for other reasons. * @return : compressed size written into `dst` (<= `dstCapacity), * or an error code if it fails (which can be tested using ZSTD_isError()). */ @@ -543,13 +609,17 @@ typedef enum { * ZSTD_d_stableOutBuffer * ZSTD_d_forceIgnoreChecksum * ZSTD_d_refMultipleDDicts + * ZSTD_d_disableHuffmanAssembly + * ZSTD_d_maxBlockSize * Because they are not stable, it's necessary to define ZSTD_STATIC_LINKING_ONLY to access them. * note : never ever use experimentalParam? names directly */ ZSTD_d_experimentalParam1=1000, ZSTD_d_experimentalParam2=1001, ZSTD_d_experimentalParam3=1002, - ZSTD_d_experimentalParam4=1003 + ZSTD_d_experimentalParam4=1003, + ZSTD_d_experimentalParam5=1004, + ZSTD_d_experimentalParam6=1005 } ZSTD_dParameter; @@ -604,14 +674,14 @@ typedef struct ZSTD_outBuffer_s { * A ZSTD_CStream object is required to track streaming operation. * Use ZSTD_createCStream() and ZSTD_freeCStream() to create/release resources. * ZSTD_CStream objects can be reused multiple times on consecutive compression operations. -* It is recommended to re-use ZSTD_CStream since it will play nicer with system's memory, by re-using already allocated memory. +* It is recommended to reuse ZSTD_CStream since it will play nicer with system's memory, by re-using already allocated memory. * * For parallel execution, use one separate ZSTD_CStream per thread. * * note : since v1.3.0, ZSTD_CStream and ZSTD_CCtx are the same thing. * * Parameters are sticky : when starting a new compression on the same context, -* it will re-use the same sticky parameters as previous compression session. +* it will reuse the same sticky parameters as previous compression session. * When in doubt, it's recommended to fully initialize the context before usage. * Use ZSTD_CCtx_reset() to reset the context and ZSTD_CCtx_setParameter(), * ZSTD_CCtx_setPledgedSrcSize(), or ZSTD_CCtx_loadDictionary() and friends to @@ -700,6 +770,11 @@ typedef enum { * only ZSTD_e_end or ZSTD_e_flush operations are allowed. * Before starting a new compression job, or changing compression parameters, * it is required to fully flush internal buffers. + * - note: if an operation ends with an error, it may leave @cctx in an undefined state. + * Therefore, it's UB to invoke ZSTD_compressStream2() of ZSTD_compressStream() on such a state. + * In order to be re-employed after an error, a state must be reset, + * which can be done explicitly (ZSTD_CCtx_reset()), + * or is sometimes implied by methods starting a new compression job (ZSTD_initCStream(), ZSTD_compressCCtx()) */ ZSTDLIB_API size_t ZSTD_compressStream2( ZSTD_CCtx* cctx, ZSTD_outBuffer* output, @@ -728,8 +803,6 @@ ZSTDLIB_API size_t ZSTD_CStreamOutSize(void); /*< recommended size for output * This following is a legacy streaming API, available since v1.0+ . * It can be replaced by ZSTD_CCtx_reset() and ZSTD_compressStream2(). * It is redundant, but remains fully supported. - * Streaming in combination with advanced parameters and dictionary compression - * can only be used through the new API. ******************************************************************************/ /*! @@ -738,6 +811,9 @@ ZSTDLIB_API size_t ZSTD_CStreamOutSize(void); /*< recommended size for output * ZSTD_CCtx_reset(zcs, ZSTD_reset_session_only); * ZSTD_CCtx_refCDict(zcs, NULL); // clear the dictionary (if any) * ZSTD_CCtx_setParameter(zcs, ZSTD_c_compressionLevel, compressionLevel); + * + * Note that ZSTD_initCStream() clears any previously set dictionary. Use the new API + * to compress with a dictionary. */ ZSTDLIB_API size_t ZSTD_initCStream(ZSTD_CStream* zcs, int compressionLevel); /*! @@ -758,7 +834,7 @@ ZSTDLIB_API size_t ZSTD_endStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output); * * A ZSTD_DStream object is required to track streaming operations. * Use ZSTD_createDStream() and ZSTD_freeDStream() to create/release resources. -* ZSTD_DStream objects can be re-used multiple times. +* ZSTD_DStream objects can be reused multiple times. * * Use ZSTD_initDStream() to start a new decompression operation. * @return : recommended first input size @@ -788,13 +864,37 @@ ZSTDLIB_API size_t ZSTD_freeDStream(ZSTD_DStream* zds); /* accept NULL pointer /*===== Streaming decompression functions =====*/ -/* This function is redundant with the advanced API and equivalent to: +/*! ZSTD_initDStream() : + * Initialize/reset DStream state for new decompression operation. + * Call before new decompression operation using same DStream. * + * Note : This function is redundant with the advanced API and equivalent to: * ZSTD_DCtx_reset(zds, ZSTD_reset_session_only); * ZSTD_DCtx_refDDict(zds, NULL); */ ZSTDLIB_API size_t ZSTD_initDStream(ZSTD_DStream* zds); +/*! ZSTD_decompressStream() : + * Streaming decompression function. + * Call repetitively to consume full input updating it as necessary. + * Function will update both input and output `pos` fields exposing current state via these fields: + * - `input.pos < input.size`, some input remaining and caller should provide remaining input + * on the next call. + * - `output.pos < output.size`, decoder finished and flushed all remaining buffers. + * - `output.pos == output.size`, potentially uncflushed data present in the internal buffers, + * call ZSTD_decompressStream() again to flush remaining data to output. + * Note : with no additional input, amount of data flushed <= ZSTD_BLOCKSIZE_MAX. + * + * @return : 0 when a frame is completely decoded and fully flushed, + * or an error code, which can be tested using ZSTD_isError(), + * or any other value > 0, which means there is some decoding or flushing to do to complete current frame. + * + * Note: when an operation returns with an error code, the @zds state may be left in undefined state. + * It's UB to invoke `ZSTD_decompressStream()` on such a state. + * In order to re-use such a state, it must be first reset, + * which can be done explicitly (`ZSTD_DCtx_reset()`), + * or is implied for operations starting some new decompression job (`ZSTD_initDStream`, `ZSTD_decompressDCtx()`, `ZSTD_decompress_usingDict()`) + */ ZSTDLIB_API size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inBuffer* input); ZSTDLIB_API size_t ZSTD_DStreamInSize(void); /*!< recommended size for input buffer */ @@ -913,7 +1013,7 @@ ZSTDLIB_API unsigned ZSTD_getDictID_fromDDict(const ZSTD_DDict* ddict); * If @return == 0, the dictID could not be decoded. * This could for one of the following reasons : * - The frame does not require a dictionary to be decoded (most common case). - * - The frame was built with dictID intentionally removed. Whatever dictionary is necessary is a hidden information. + * - The frame was built with dictID intentionally removed. Whatever dictionary is necessary is a hidden piece of information. * Note : this use case also happens when using a non-conformant dictionary. * - `srcSize` is too small, and as a result, the frame header could not be decoded (only possible if `srcSize < ZSTD_FRAMEHEADERSIZE_MAX`). * - This is not a Zstandard frame. @@ -925,9 +1025,11 @@ ZSTDLIB_API unsigned ZSTD_getDictID_fromFrame(const void* src, size_t srcSize); * Advanced dictionary and prefix API (Requires v1.4.0+) * * This API allows dictionaries to be used with ZSTD_compress2(), - * ZSTD_compressStream2(), and ZSTD_decompressDCtx(). Dictionaries are sticky, and - * only reset with the context is reset with ZSTD_reset_parameters or - * ZSTD_reset_session_and_parameters. Prefixes are single-use. + * ZSTD_compressStream2(), and ZSTD_decompressDCtx(). + * Dictionaries are sticky, they remain valid when same context is reused, + * they only reset when the context is reset + * with ZSTD_reset_parameters or ZSTD_reset_session_and_parameters. + * In contrast, Prefixes are single-use. ******************************************************************************/ @@ -937,8 +1039,9 @@ ZSTDLIB_API unsigned ZSTD_getDictID_fromFrame(const void* src, size_t srcSize); * @result : 0, or an error code (which can be tested with ZSTD_isError()). * Special: Loading a NULL (or 0-size) dictionary invalidates previous dictionary, * meaning "return to no-dictionary mode". - * Note 1 : Dictionary is sticky, it will be used for all future compressed frames. - * To return to "no-dictionary" situation, load a NULL dictionary (or reset parameters). + * Note 1 : Dictionary is sticky, it will be used for all future compressed frames, + * until parameters are reset, a new dictionary is loaded, or the dictionary + * is explicitly invalidated by loading a NULL dictionary. * Note 2 : Loading a dictionary involves building tables. * It's also a CPU consuming operation, with non-negligible impact on latency. * Tables are dependent on compression parameters, and for this reason, @@ -947,11 +1050,15 @@ ZSTDLIB_API unsigned ZSTD_getDictID_fromFrame(const void* src, size_t srcSize); * Use experimental ZSTD_CCtx_loadDictionary_byReference() to reference content instead. * In such a case, dictionary buffer must outlive its users. * Note 4 : Use ZSTD_CCtx_loadDictionary_advanced() - * to precisely select how dictionary content must be interpreted. */ + * to precisely select how dictionary content must be interpreted. + * Note 5 : This method does not benefit from LDM (long distance mode). + * If you want to employ LDM on some large dictionary content, + * prefer employing ZSTD_CCtx_refPrefix() described below. + */ ZSTDLIB_API size_t ZSTD_CCtx_loadDictionary(ZSTD_CCtx* cctx, const void* dict, size_t dictSize); /*! ZSTD_CCtx_refCDict() : Requires v1.4.0+ - * Reference a prepared dictionary, to be used for all next compressed frames. + * Reference a prepared dictionary, to be used for all future compressed frames. * Note that compression parameters are enforced from within CDict, * and supersede any compression parameter previously set within CCtx. * The parameters ignored are labelled as "superseded-by-cdict" in the ZSTD_cParameter enum docs. @@ -970,6 +1077,7 @@ ZSTDLIB_API size_t ZSTD_CCtx_refCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict); * Decompression will need same prefix to properly regenerate data. * Compressing with a prefix is similar in outcome as performing a diff and compressing it, * but performs much faster, especially during decompression (compression speed is tunable with compression level). + * This method is compatible with LDM (long distance mode). * @result : 0, or an error code (which can be tested with ZSTD_isError()). * Special: Adding any prefix (including NULL) invalidates any previous prefix or dictionary * Note 1 : Prefix buffer is referenced. It **must** outlive compression. @@ -986,9 +1094,9 @@ ZSTDLIB_API size_t ZSTD_CCtx_refPrefix(ZSTD_CCtx* cctx, const void* prefix, size_t prefixSize); /*! ZSTD_DCtx_loadDictionary() : Requires v1.4.0+ - * Create an internal DDict from dict buffer, - * to be used to decompress next frames. - * The dictionary remains valid for all future frames, until explicitly invalidated. + * Create an internal DDict from dict buffer, to be used to decompress all future frames. + * The dictionary remains valid for all future frames, until explicitly invalidated, or + * a new dictionary is loaded. * @result : 0, or an error code (which can be tested with ZSTD_isError()). * Special : Adding a NULL (or 0-size) dictionary invalidates any previous dictionary, * meaning "return to no-dictionary mode". @@ -1012,9 +1120,10 @@ ZSTDLIB_API size_t ZSTD_DCtx_loadDictionary(ZSTD_DCtx* dctx, const void* dict, s * The memory for the table is allocated on the first call to refDDict, and can be * freed with ZSTD_freeDCtx(). * + * If called with ZSTD_d_refMultipleDDicts disabled (the default), only one dictionary + * will be managed, and referencing a dictionary effectively "discards" any previous one. + * * @result : 0, or an error code (which can be tested with ZSTD_isError()). - * Note 1 : Currently, only one dictionary can be managed. - * Referencing a new dictionary effectively "discards" any previous one. * Special: referencing a NULL DDict means "return to no-dictionary mode". * Note 2 : DDict is just referenced, its lifetime must outlive its usage from DCtx. */ @@ -1071,24 +1180,6 @@ ZSTDLIB_API size_t ZSTD_sizeof_DDict(const ZSTD_DDict* ddict); #define ZSTDLIB_STATIC_API ZSTDLIB_VISIBLE #endif -/* Deprecation warnings : - * Should these warnings be a problem, it is generally possible to disable them, - * typically with -Wno-deprecated-declarations for gcc or _CRT_SECURE_NO_WARNINGS in Visual. - * Otherwise, it's also possible to define ZSTD_DISABLE_DEPRECATE_WARNINGS. - */ -#ifdef ZSTD_DISABLE_DEPRECATE_WARNINGS -# define ZSTD_DEPRECATED(message) ZSTDLIB_STATIC_API /* disable deprecation warnings */ -#else -# if (defined(GNUC) && (GNUC > 4 || (GNUC == 4 && GNUC_MINOR >= 5))) || defined(__clang__) -# define ZSTD_DEPRECATED(message) ZSTDLIB_STATIC_API __attribute__((deprecated(message))) -# elif (__GNUC__ >= 3) -# define ZSTD_DEPRECATED(message) ZSTDLIB_STATIC_API __attribute__((deprecated)) -# else -# pragma message("WARNING: You need to implement ZSTD_DEPRECATED for this compiler") -# define ZSTD_DEPRECATED(message) ZSTDLIB_STATIC_API -# endif -#endif /* ZSTD_DISABLE_DEPRECATE_WARNINGS */ - /* ************************************************************************************** * experimental API (static linking only) **************************************************************************************** @@ -1123,6 +1214,7 @@ ZSTDLIB_API size_t ZSTD_sizeof_DDict(const ZSTD_DDict* ddict); #define ZSTD_TARGETLENGTH_MIN 0 /* note : comparing this constant to an unsigned results in a tautological test */ #define ZSTD_STRATEGY_MIN ZSTD_fast #define ZSTD_STRATEGY_MAX ZSTD_btultra2 +#define ZSTD_BLOCKSIZE_MAX_MIN (1 << 10) /* The minimum valid max blocksize. Maximum blocksizes smaller than this make compressBound() inaccurate. */ #define ZSTD_OVERLAPLOG_MIN 0 @@ -1146,7 +1238,7 @@ ZSTDLIB_API size_t ZSTD_sizeof_DDict(const ZSTD_DDict* ddict); #define ZSTD_LDM_HASHRATELOG_MAX (ZSTD_WINDOWLOG_MAX - ZSTD_HASHLOG_MIN) /* Advanced parameter bounds */ -#define ZSTD_TARGETCBLOCKSIZE_MIN 64 +#define ZSTD_TARGETCBLOCKSIZE_MIN 1340 /* suitable to fit into an ethernet / wifi / 4G transport frame */ #define ZSTD_TARGETCBLOCKSIZE_MAX ZSTD_BLOCKSIZE_MAX #define ZSTD_SRCSIZEHINT_MIN 0 #define ZSTD_SRCSIZEHINT_MAX INT_MAX @@ -1303,7 +1395,7 @@ typedef enum { } ZSTD_paramSwitch_e; /* ************************************* -* Frame size functions +* Frame header and size functions ***************************************/ /*! ZSTD_findDecompressedSize() : @@ -1350,29 +1442,122 @@ ZSTDLIB_STATIC_API unsigned long long ZSTD_decompressBound(const void* src, size * or an error code (if srcSize is too small) */ ZSTDLIB_STATIC_API size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize); +typedef enum { ZSTD_frame, ZSTD_skippableFrame } ZSTD_frameType_e; +typedef struct { + unsigned long long frameContentSize; /* if == ZSTD_CONTENTSIZE_UNKNOWN, it means this field is not available. 0 means "empty" */ + unsigned long long windowSize; /* can be very large, up to <= frameContentSize */ + unsigned blockSizeMax; + ZSTD_frameType_e frameType; /* if == ZSTD_skippableFrame, frameContentSize is the size of skippable content */ + unsigned headerSize; + unsigned dictID; + unsigned checksumFlag; + unsigned _reserved1; + unsigned _reserved2; +} ZSTD_frameHeader; + +/*! ZSTD_getFrameHeader() : + * decode Frame Header, or requires larger `srcSize`. + * @return : 0, `zfhPtr` is correctly filled, + * >0, `srcSize` is too small, value is wanted `srcSize` amount, + * or an error code, which can be tested using ZSTD_isError() */ +ZSTDLIB_STATIC_API size_t ZSTD_getFrameHeader(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize); /*< doesn't consume input */ +/*! ZSTD_getFrameHeader_advanced() : + * same as ZSTD_getFrameHeader(), + * with added capability to select a format (like ZSTD_f_zstd1_magicless) */ +ZSTDLIB_STATIC_API size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format); + +/*! ZSTD_decompressionMargin() : + * Zstd supports in-place decompression, where the input and output buffers overlap. + * In this case, the output buffer must be at least (Margin + Output_Size) bytes large, + * and the input buffer must be at the end of the output buffer. + * + * _______________________ Output Buffer ________________________ + * | | + * | ____ Input Buffer ____| + * | | | + * v v v + * |---------------------------------------|-----------|----------| + * ^ ^ ^ + * |___________________ Output_Size ___________________|_ Margin _| + * + * NOTE: See also ZSTD_DECOMPRESSION_MARGIN(). + * NOTE: This applies only to single-pass decompression through ZSTD_decompress() or + * ZSTD_decompressDCtx(). + * NOTE: This function supports multi-frame input. + * + * @param src The compressed frame(s) + * @param srcSize The size of the compressed frame(s) + * @returns The decompression margin or an error that can be checked with ZSTD_isError(). + */ +ZSTDLIB_STATIC_API size_t ZSTD_decompressionMargin(const void* src, size_t srcSize); + +/*! ZSTD_DECOMPRESS_MARGIN() : + * Similar to ZSTD_decompressionMargin(), but instead of computing the margin from + * the compressed frame, compute it from the original size and the blockSizeLog. + * See ZSTD_decompressionMargin() for details. + * + * WARNING: This macro does not support multi-frame input, the input must be a single + * zstd frame. If you need that support use the function, or implement it yourself. + * + * @param originalSize The original uncompressed size of the data. + * @param blockSize The block size == MIN(windowSize, ZSTD_BLOCKSIZE_MAX). + * Unless you explicitly set the windowLog smaller than + * ZSTD_BLOCKSIZELOG_MAX you can just use ZSTD_BLOCKSIZE_MAX. + */ +#define ZSTD_DECOMPRESSION_MARGIN(originalSize, blockSize) ((size_t)( \ + ZSTD_FRAMEHEADERSIZE_MAX /* Frame header */ + \ + 4 /* checksum */ + \ + ((originalSize) == 0 ? 0 : 3 * (((originalSize) + (blockSize) - 1) / blockSize)) /* 3 bytes per block */ + \ + (blockSize) /* One block of margin */ \ + )) + typedef enum { ZSTD_sf_noBlockDelimiters = 0, /* Representation of ZSTD_Sequence has no block delimiters, sequences only */ ZSTD_sf_explicitBlockDelimiters = 1 /* Representation of ZSTD_Sequence contains explicit block delimiters */ } ZSTD_sequenceFormat_e; +/*! ZSTD_sequenceBound() : + * `srcSize` : size of the input buffer + * @return : upper-bound for the number of sequences that can be generated + * from a buffer of srcSize bytes + * + * note : returns number of sequences - to get bytes, multiply by sizeof(ZSTD_Sequence). + */ +ZSTDLIB_STATIC_API size_t ZSTD_sequenceBound(size_t srcSize); + /*! ZSTD_generateSequences() : - * Generate sequences using ZSTD_compress2, given a source buffer. + * WARNING: This function is meant for debugging and informational purposes ONLY! + * Its implementation is flawed, and it will be deleted in a future version. + * It is not guaranteed to succeed, as there are several cases where it will give + * up and fail. You should NOT use this function in production code. + * + * This function is deprecated, and will be removed in a future version. + * + * Generate sequences using ZSTD_compress2(), given a source buffer. + * + * @param zc The compression context to be used for ZSTD_compress2(). Set any + * compression parameters you need on this context. + * @param outSeqs The output sequences buffer of size @p outSeqsSize + * @param outSeqsSize The size of the output sequences buffer. + * ZSTD_sequenceBound(srcSize) is an upper bound on the number + * of sequences that can be generated. + * @param src The source buffer to generate sequences from of size @p srcSize. + * @param srcSize The size of the source buffer. * * Each block will end with a dummy sequence * with offset == 0, matchLength == 0, and litLength == length of last literals. * litLength may be == 0, and if so, then the sequence of (of: 0 ml: 0 ll: 0) * simply acts as a block delimiter. * - * zc can be used to insert custom compression params. - * This function invokes ZSTD_compress2 - * - * The output of this function can be fed into ZSTD_compressSequences() with CCtx - * setting of ZSTD_c_blockDelimiters as ZSTD_sf_explicitBlockDelimiters - * @return : number of sequences generated + * @returns The number of sequences generated, necessarily less than + * ZSTD_sequenceBound(srcSize), or an error code that can be checked + * with ZSTD_isError(). */ - -ZSTDLIB_STATIC_API size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* outSeqs, - size_t outSeqsSize, const void* src, size_t srcSize); +ZSTD_DEPRECATED("For debugging only, will be replaced by ZSTD_extractSequences()") +ZSTDLIB_STATIC_API size_t +ZSTD_generateSequences(ZSTD_CCtx* zc, + ZSTD_Sequence* outSeqs, size_t outSeqsSize, + const void* src, size_t srcSize); /*! ZSTD_mergeBlockDelimiters() : * Given an array of ZSTD_Sequence, remove all sequences that represent block delimiters/last literals @@ -1388,7 +1573,9 @@ ZSTDLIB_STATIC_API size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* o ZSTDLIB_STATIC_API size_t ZSTD_mergeBlockDelimiters(ZSTD_Sequence* sequences, size_t seqsSize); /*! ZSTD_compressSequences() : - * Compress an array of ZSTD_Sequence, generated from the original source buffer, into dst. + * Compress an array of ZSTD_Sequence, associated with @src buffer, into dst. + * @src contains the entire input (not just the literals). + * If @srcSize > sum(sequence.length), the remaining bytes are considered all literals * If a dictionary is included, then the cctx should reference the dict. (see: ZSTD_CCtx_refCDict(), ZSTD_CCtx_loadDictionary(), etc.) * The entire source is compressed into a single frame. * @@ -1413,11 +1600,12 @@ ZSTDLIB_STATIC_API size_t ZSTD_mergeBlockDelimiters(ZSTD_Sequence* sequences, si * Note: Repcodes are, as of now, always re-calculated within this function, so ZSTD_Sequence::rep is unused. * Note 2: Once we integrate ability to ingest repcodes, the explicit block delims mode must respect those repcodes exactly, * and cannot emit an RLE block that disagrees with the repcode history - * @return : final compressed size or a ZSTD error. + * @return : final compressed size, or a ZSTD error code. */ -ZSTDLIB_STATIC_API size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* dst, size_t dstSize, - const ZSTD_Sequence* inSeqs, size_t inSeqsSize, - const void* src, size_t srcSize); +ZSTDLIB_STATIC_API size_t +ZSTD_compressSequences( ZSTD_CCtx* cctx, void* dst, size_t dstSize, + const ZSTD_Sequence* inSeqs, size_t inSeqsSize, + const void* src, size_t srcSize); /*! ZSTD_writeSkippableFrame() : @@ -1464,48 +1652,59 @@ ZSTDLIB_API unsigned ZSTD_isSkippableFrame(const void* buffer, size_t size); /*! ZSTD_estimate*() : * These functions make it possible to estimate memory usage * of a future {D,C}Ctx, before its creation. + * This is useful in combination with ZSTD_initStatic(), + * which makes it possible to employ a static buffer for ZSTD_CCtx* state. * * ZSTD_estimateCCtxSize() will provide a memory budget large enough - * for any compression level up to selected one. - * Note : Unlike ZSTD_estimateCStreamSize*(), this estimate - * does not include space for a window buffer. - * Therefore, the estimation is only guaranteed for single-shot compressions, not streaming. + * to compress data of any size using one-shot compression ZSTD_compressCCtx() or ZSTD_compress2() + * associated with any compression level up to max specified one. * The estimate will assume the input may be arbitrarily large, * which is the worst case. * + * Note that the size estimation is specific for one-shot compression, + * it is not valid for streaming (see ZSTD_estimateCStreamSize*()) + * nor other potential ways of using a ZSTD_CCtx* state. + * * When srcSize can be bound by a known and rather "small" value, - * this fact can be used to provide a tighter estimation - * because the CCtx compression context will need less memory. - * This tighter estimation can be provided by more advanced functions + * this knowledge can be used to provide a tighter budget estimation + * because the ZSTD_CCtx* state will need less memory for small inputs. + * This tighter estimation can be provided by employing more advanced functions * ZSTD_estimateCCtxSize_usingCParams(), which can be used in tandem with ZSTD_getCParams(), * and ZSTD_estimateCCtxSize_usingCCtxParams(), which can be used in tandem with ZSTD_CCtxParams_setParameter(). * Both can be used to estimate memory using custom compression parameters and arbitrary srcSize limits. * - * Note 2 : only single-threaded compression is supported. + * Note : only single-threaded compression is supported. * ZSTD_estimateCCtxSize_usingCCtxParams() will return an error code if ZSTD_c_nbWorkers is >= 1. */ -ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize(int compressionLevel); +ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize(int maxCompressionLevel); ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize_usingCParams(ZSTD_compressionParameters cParams); ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize_usingCCtxParams(const ZSTD_CCtx_params* params); ZSTDLIB_STATIC_API size_t ZSTD_estimateDCtxSize(void); /*! ZSTD_estimateCStreamSize() : - * ZSTD_estimateCStreamSize() will provide a budget large enough for any compression level up to selected one. - * It will also consider src size to be arbitrarily "large", which is worst case. + * ZSTD_estimateCStreamSize() will provide a memory budget large enough for streaming compression + * using any compression level up to the max specified one. + * It will also consider src size to be arbitrarily "large", which is a worst case scenario. * If srcSize is known to always be small, ZSTD_estimateCStreamSize_usingCParams() can provide a tighter estimation. * ZSTD_estimateCStreamSize_usingCParams() can be used in tandem with ZSTD_getCParams() to create cParams from compressionLevel. * ZSTD_estimateCStreamSize_usingCCtxParams() can be used in tandem with ZSTD_CCtxParams_setParameter(). Only single-threaded compression is supported. This function will return an error code if ZSTD_c_nbWorkers is >= 1. * Note : CStream size estimation is only correct for single-threaded compression. - * ZSTD_DStream memory budget depends on window Size. + * ZSTD_estimateCStreamSize_usingCCtxParams() will return an error code if ZSTD_c_nbWorkers is >= 1. + * Note 2 : ZSTD_estimateCStreamSize* functions are not compatible with the Block-Level Sequence Producer API at this time. + * Size estimates assume that no external sequence producer is registered. + * + * ZSTD_DStream memory budget depends on frame's window Size. * This information can be passed manually, using ZSTD_estimateDStreamSize, * or deducted from a valid frame Header, using ZSTD_estimateDStreamSize_fromFrame(); + * Any frame requesting a window size larger than max specified one will be rejected. * Note : if streaming is init with function ZSTD_init?Stream_usingDict(), * an internal ?Dict will be created, which additional size is not estimated here. - * In this case, get total size by adding ZSTD_estimate?DictSize */ -ZSTDLIB_STATIC_API size_t ZSTD_estimateCStreamSize(int compressionLevel); + * In this case, get total size by adding ZSTD_estimate?DictSize + */ +ZSTDLIB_STATIC_API size_t ZSTD_estimateCStreamSize(int maxCompressionLevel); ZSTDLIB_STATIC_API size_t ZSTD_estimateCStreamSize_usingCParams(ZSTD_compressionParameters cParams); ZSTDLIB_STATIC_API size_t ZSTD_estimateCStreamSize_usingCCtxParams(const ZSTD_CCtx_params* params); -ZSTDLIB_STATIC_API size_t ZSTD_estimateDStreamSize(size_t windowSize); +ZSTDLIB_STATIC_API size_t ZSTD_estimateDStreamSize(size_t maxWindowSize); ZSTDLIB_STATIC_API size_t ZSTD_estimateDStreamSize_fromFrame(const void* src, size_t srcSize); /*! ZSTD_estimate?DictSize() : @@ -1649,22 +1848,45 @@ ZSTDLIB_STATIC_API size_t ZSTD_checkCParams(ZSTD_compressionParameters params); * This function never fails (wide contract) */ ZSTDLIB_STATIC_API ZSTD_compressionParameters ZSTD_adjustCParams(ZSTD_compressionParameters cPar, unsigned long long srcSize, size_t dictSize); +/*! ZSTD_CCtx_setCParams() : + * Set all parameters provided within @p cparams into the working @p cctx. + * Note : if modifying parameters during compression (MT mode only), + * note that changes to the .windowLog parameter will be ignored. + * @return 0 on success, or an error code (can be checked with ZSTD_isError()). + * On failure, no parameters are updated. + */ +ZSTDLIB_STATIC_API size_t ZSTD_CCtx_setCParams(ZSTD_CCtx* cctx, ZSTD_compressionParameters cparams); + +/*! ZSTD_CCtx_setFParams() : + * Set all parameters provided within @p fparams into the working @p cctx. + * @return 0 on success, or an error code (can be checked with ZSTD_isError()). + */ +ZSTDLIB_STATIC_API size_t ZSTD_CCtx_setFParams(ZSTD_CCtx* cctx, ZSTD_frameParameters fparams); + +/*! ZSTD_CCtx_setParams() : + * Set all parameters provided within @p params into the working @p cctx. + * @return 0 on success, or an error code (can be checked with ZSTD_isError()). + */ +ZSTDLIB_STATIC_API size_t ZSTD_CCtx_setParams(ZSTD_CCtx* cctx, ZSTD_parameters params); + /*! ZSTD_compress_advanced() : * Note : this function is now DEPRECATED. * It can be replaced by ZSTD_compress2(), in combination with ZSTD_CCtx_setParameter() and other parameter setters. * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_compress2") +ZSTDLIB_STATIC_API size_t ZSTD_compress_advanced(ZSTD_CCtx* cctx, - void* dst, size_t dstCapacity, - const void* src, size_t srcSize, - const void* dict,size_t dictSize, - ZSTD_parameters params); + void* dst, size_t dstCapacity, + const void* src, size_t srcSize, + const void* dict,size_t dictSize, + ZSTD_parameters params); /*! ZSTD_compress_usingCDict_advanced() : * Note : this function is now DEPRECATED. * It can be replaced by ZSTD_compress2(), in combination with ZSTD_CCtx_loadDictionary() and other parameter setters. * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_compress2 with ZSTD_CCtx_loadDictionary") +ZSTDLIB_STATIC_API size_t ZSTD_compress_usingCDict_advanced(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize, @@ -1737,11 +1959,6 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo */ #define ZSTD_c_literalCompressionMode ZSTD_c_experimentalParam5 -/* Tries to fit compressed block size to be around targetCBlockSize. - * No target when targetCBlockSize == 0. - * There is no guarantee on compressed block size (default:0) */ -#define ZSTD_c_targetCBlockSize ZSTD_c_experimentalParam6 - /* User's best guess of source size. * Hint is not valid when srcSizeHint == 0. * There is no guarantee that hint is close to actual source size, @@ -1808,13 +2025,16 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo * Experimental parameter. * Default is 0 == disabled. Set to 1 to enable. * - * Tells the compressor that the ZSTD_inBuffer will ALWAYS be the same - * between calls, except for the modifications that zstd makes to pos (the - * caller must not modify pos). This is checked by the compressor, and - * compression will fail if it ever changes. This means the only flush - * mode that makes sense is ZSTD_e_end, so zstd will error if ZSTD_e_end - * is not used. The data in the ZSTD_inBuffer in the range [src, src + pos) - * MUST not be modified during compression or you will get data corruption. + * Tells the compressor that input data presented with ZSTD_inBuffer + * will ALWAYS be the same between calls. + * Technically, the @src pointer must never be changed, + * and the @pos field can only be updated by zstd. + * However, it's possible to increase the @size field, + * allowing scenarios where more data can be appended after compressions starts. + * These conditions are checked by the compressor, + * and compression will fail if they are not respected. + * Also, data in the ZSTD_inBuffer within the range [src, src + pos) + * MUST not be modified during compression or it will result in data corruption. * * When this flag is enabled zstd won't allocate an input window buffer, * because the user guarantees it can reference the ZSTD_inBuffer until @@ -1822,18 +2042,15 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo * large enough to fit a block (see ZSTD_c_stableOutBuffer). This will also * avoid the memcpy() from the input buffer to the input window buffer. * - * NOTE: ZSTD_compressStream2() will error if ZSTD_e_end is not used. - * That means this flag cannot be used with ZSTD_compressStream(). - * * NOTE: So long as the ZSTD_inBuffer always points to valid memory, using * this flag is ALWAYS memory safe, and will never access out-of-bounds - * memory. However, compression WILL fail if you violate the preconditions. + * memory. However, compression WILL fail if conditions are not respected. * - * WARNING: The data in the ZSTD_inBuffer in the range [dst, dst + pos) MUST - * not be modified during compression or you will get data corruption. This - * is because zstd needs to reference data in the ZSTD_inBuffer to find + * WARNING: The data in the ZSTD_inBuffer in the range [src, src + pos) MUST + * not be modified during compression or it will result in data corruption. + * This is because zstd needs to reference data in the ZSTD_inBuffer to find * matches. Normally zstd maintains its own window buffer for this purpose, - * but passing this flag tells zstd to use the user provided buffer. + * but passing this flag tells zstd to rely on user provided buffer instead. */ #define ZSTD_c_stableInBuffer ZSTD_c_experimentalParam9 @@ -1878,7 +2095,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo * Without validation, providing a sequence that does not conform to the zstd spec will cause * undefined behavior, and may produce a corrupted block. * - * With validation enabled, a if sequence is invalid (see doc/zstd_compression_format.md for + * With validation enabled, if sequence is invalid (see doc/zstd_compression_format.md for * specifics regarding offset/matchlength requirements) then the function will bail out and * return an error. * @@ -1928,6 +2145,79 @@ ZSTDLIB_STATIC_API size_t ZSTD_CCtx_refPrefix_advanced(ZSTD_CCtx* cctx, const vo */ #define ZSTD_c_deterministicRefPrefix ZSTD_c_experimentalParam15 +/* ZSTD_c_prefetchCDictTables + * Controlled with ZSTD_paramSwitch_e enum. Default is ZSTD_ps_auto. + * + * In some situations, zstd uses CDict tables in-place rather than copying them + * into the working context. (See docs on ZSTD_dictAttachPref_e above for details). + * In such situations, compression speed is seriously impacted when CDict tables are + * "cold" (outside CPU cache). This parameter instructs zstd to prefetch CDict tables + * when they are used in-place. + * + * For sufficiently small inputs, the cost of the prefetch will outweigh the benefit. + * For sufficiently large inputs, zstd will by default memcpy() CDict tables + * into the working context, so there is no need to prefetch. This parameter is + * targeted at a middle range of input sizes, where a prefetch is cheap enough to be + * useful but memcpy() is too expensive. The exact range of input sizes where this + * makes sense is best determined by careful experimentation. + * + * Note: for this parameter, ZSTD_ps_auto is currently equivalent to ZSTD_ps_disable, + * but in the future zstd may conditionally enable this feature via an auto-detection + * heuristic for cold CDicts. + * Use ZSTD_ps_disable to opt out of prefetching under any circumstances. + */ +#define ZSTD_c_prefetchCDictTables ZSTD_c_experimentalParam16 + +/* ZSTD_c_enableSeqProducerFallback + * Allowed values are 0 (disable) and 1 (enable). The default setting is 0. + * + * Controls whether zstd will fall back to an internal sequence producer if an + * external sequence producer is registered and returns an error code. This fallback + * is block-by-block: the internal sequence producer will only be called for blocks + * where the external sequence producer returns an error code. Fallback parsing will + * follow any other cParam settings, such as compression level, the same as in a + * normal (fully-internal) compression operation. + * + * The user is strongly encouraged to read the full Block-Level Sequence Producer API + * documentation (below) before setting this parameter. */ +#define ZSTD_c_enableSeqProducerFallback ZSTD_c_experimentalParam17 + +/* ZSTD_c_maxBlockSize + * Allowed values are between 1KB and ZSTD_BLOCKSIZE_MAX (128KB). + * The default is ZSTD_BLOCKSIZE_MAX, and setting to 0 will set to the default. + * + * This parameter can be used to set an upper bound on the blocksize + * that overrides the default ZSTD_BLOCKSIZE_MAX. It cannot be used to set upper + * bounds greater than ZSTD_BLOCKSIZE_MAX or bounds lower than 1KB (will make + * compressBound() inaccurate). Only currently meant to be used for testing. + * + */ +#define ZSTD_c_maxBlockSize ZSTD_c_experimentalParam18 + +/* ZSTD_c_searchForExternalRepcodes + * This parameter affects how zstd parses external sequences, such as sequences + * provided through the compressSequences() API or from an external block-level + * sequence producer. + * + * If set to ZSTD_ps_enable, the library will check for repeated offsets in + * external sequences, even if those repcodes are not explicitly indicated in + * the "rep" field. Note that this is the only way to exploit repcode matches + * while using compressSequences() or an external sequence producer, since zstd + * currently ignores the "rep" field of external sequences. + * + * If set to ZSTD_ps_disable, the library will not exploit repeated offsets in + * external sequences, regardless of whether the "rep" field has been set. This + * reduces sequence compression overhead by about 25% while sacrificing some + * compression ratio. + * + * The default value is ZSTD_ps_auto, for which the library will enable/disable + * based on compression level. + * + * Note: for now, this param only has an effect if ZSTD_c_blockDelimiters is + * set to ZSTD_sf_explicitBlockDelimiters. That may change in the future. + */ +#define ZSTD_c_searchForExternalRepcodes ZSTD_c_experimentalParam19 + /*! ZSTD_CCtx_getParameter() : * Get the requested compression parameter value, selected by enum ZSTD_cParameter, * and store it into int* value. @@ -2084,7 +2374,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParamete * in the range [dst, dst + pos) MUST not be modified during decompression * or you will get data corruption. * - * When this flags is enabled zstd won't allocate an output buffer, because + * When this flag is enabled zstd won't allocate an output buffer, because * it can write directly to the ZSTD_outBuffer, but it will still allocate * an input buffer large enough to fit any compressed block. This will also * avoid the memcpy() from the internal output buffer to the ZSTD_outBuffer. @@ -2137,6 +2427,33 @@ ZSTDLIB_STATIC_API size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParamete */ #define ZSTD_d_refMultipleDDicts ZSTD_d_experimentalParam4 +/* ZSTD_d_disableHuffmanAssembly + * Set to 1 to disable the Huffman assembly implementation. + * The default value is 0, which allows zstd to use the Huffman assembly + * implementation if available. + * + * This parameter can be used to disable Huffman assembly at runtime. + * If you want to disable it at compile time you can define the macro + * ZSTD_DISABLE_ASM. + */ +#define ZSTD_d_disableHuffmanAssembly ZSTD_d_experimentalParam5 + +/* ZSTD_d_maxBlockSize + * Allowed values are between 1KB and ZSTD_BLOCKSIZE_MAX (128KB). + * The default is ZSTD_BLOCKSIZE_MAX, and setting to 0 will set to the default. + * + * Forces the decompressor to reject blocks whose content size is + * larger than the configured maxBlockSize. When maxBlockSize is + * larger than the windowSize, the windowSize is used instead. + * This saves memory on the decoder when you know all blocks are small. + * + * This option is typically used in conjunction with ZSTD_c_maxBlockSize. + * + * WARNING: This causes the decoder to reject otherwise valid frames + * that have block sizes larger than the configured maxBlockSize. + */ +#define ZSTD_d_maxBlockSize ZSTD_d_experimentalParam6 + /*! ZSTD_DCtx_setFormat() : * This function is REDUNDANT. Prefer ZSTD_DCtx_setParameter(). @@ -2145,6 +2462,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParamete * such ZSTD_f_zstd1_magicless for example. * @return : 0, or an error code (which can be tested using ZSTD_isError()). */ ZSTD_DEPRECATED("use ZSTD_DCtx_setParameter() instead") +ZSTDLIB_STATIC_API size_t ZSTD_DCtx_setFormat(ZSTD_DCtx* dctx, ZSTD_format_e format); /*! ZSTD_decompressStream_simpleArgs() : @@ -2181,6 +2499,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_decompressStream_simpleArgs ( * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_CCtx_reset, see zstd.h for detailed instructions") +ZSTDLIB_STATIC_API size_t ZSTD_initCStream_srcSize(ZSTD_CStream* zcs, int compressionLevel, unsigned long long pledgedSrcSize); @@ -2198,17 +2517,15 @@ size_t ZSTD_initCStream_srcSize(ZSTD_CStream* zcs, * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_CCtx_reset, see zstd.h for detailed instructions") +ZSTDLIB_STATIC_API size_t ZSTD_initCStream_usingDict(ZSTD_CStream* zcs, const void* dict, size_t dictSize, int compressionLevel); /*! ZSTD_initCStream_advanced() : - * This function is DEPRECATED, and is approximately equivalent to: + * This function is DEPRECATED, and is equivalent to: * ZSTD_CCtx_reset(zcs, ZSTD_reset_session_only); - * // Pseudocode: Set each zstd parameter and leave the rest as-is. - * for ((param, value) : params) { - * ZSTD_CCtx_setParameter(zcs, param, value); - * } + * ZSTD_CCtx_setParams(zcs, params); * ZSTD_CCtx_setPledgedSrcSize(zcs, pledgedSrcSize); * ZSTD_CCtx_loadDictionary(zcs, dict, dictSize); * @@ -2218,6 +2535,7 @@ size_t ZSTD_initCStream_usingDict(ZSTD_CStream* zcs, * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_CCtx_reset, see zstd.h for detailed instructions") +ZSTDLIB_STATIC_API size_t ZSTD_initCStream_advanced(ZSTD_CStream* zcs, const void* dict, size_t dictSize, ZSTD_parameters params, @@ -2232,15 +2550,13 @@ size_t ZSTD_initCStream_advanced(ZSTD_CStream* zcs, * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_CCtx_reset and ZSTD_CCtx_refCDict, see zstd.h for detailed instructions") +ZSTDLIB_STATIC_API size_t ZSTD_initCStream_usingCDict(ZSTD_CStream* zcs, const ZSTD_CDict* cdict); /*! ZSTD_initCStream_usingCDict_advanced() : - * This function is DEPRECATED, and is approximately equivalent to: + * This function is DEPRECATED, and is equivalent to: * ZSTD_CCtx_reset(zcs, ZSTD_reset_session_only); - * // Pseudocode: Set each zstd frame parameter and leave the rest as-is. - * for ((fParam, value) : fParams) { - * ZSTD_CCtx_setParameter(zcs, fParam, value); - * } + * ZSTD_CCtx_setFParams(zcs, fParams); * ZSTD_CCtx_setPledgedSrcSize(zcs, pledgedSrcSize); * ZSTD_CCtx_refCDict(zcs, cdict); * @@ -2250,6 +2566,7 @@ size_t ZSTD_initCStream_usingCDict(ZSTD_CStream* zcs, const ZSTD_CDict* cdict); * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_CCtx_reset and ZSTD_CCtx_refCDict, see zstd.h for detailed instructions") +ZSTDLIB_STATIC_API size_t ZSTD_initCStream_usingCDict_advanced(ZSTD_CStream* zcs, const ZSTD_CDict* cdict, ZSTD_frameParameters fParams, @@ -2264,7 +2581,7 @@ size_t ZSTD_initCStream_usingCDict_advanced(ZSTD_CStream* zcs, * explicitly specified. * * start a new frame, using same parameters from previous frame. - * This is typically useful to skip dictionary loading stage, since it will re-use it in-place. + * This is typically useful to skip dictionary loading stage, since it will reuse it in-place. * Note that zcs must be init at least once before using ZSTD_resetCStream(). * If pledgedSrcSize is not known at reset time, use macro ZSTD_CONTENTSIZE_UNKNOWN. * If pledgedSrcSize > 0, its value must be correct, as it will be written in header, and controlled at the end. @@ -2274,6 +2591,7 @@ size_t ZSTD_initCStream_usingCDict_advanced(ZSTD_CStream* zcs, * This prototype will generate compilation warnings. */ ZSTD_DEPRECATED("use ZSTD_CCtx_reset, see zstd.h for detailed instructions") +ZSTDLIB_STATIC_API size_t ZSTD_resetCStream(ZSTD_CStream* zcs, unsigned long long pledgedSrcSize); @@ -2319,8 +2637,8 @@ ZSTDLIB_STATIC_API size_t ZSTD_toFlushNow(ZSTD_CCtx* cctx); * ZSTD_DCtx_loadDictionary(zds, dict, dictSize); * * note: no dictionary will be used if dict == NULL or dictSize < 8 - * Note : this prototype will be marked as deprecated and generate compilation warnings on reaching v1.5.x */ +ZSTD_DEPRECATED("use ZSTD_DCtx_reset + ZSTD_DCtx_loadDictionary, see zstd.h for detailed instructions") ZSTDLIB_STATIC_API size_t ZSTD_initDStream_usingDict(ZSTD_DStream* zds, const void* dict, size_t dictSize); /*! @@ -2330,8 +2648,8 @@ ZSTDLIB_STATIC_API size_t ZSTD_initDStream_usingDict(ZSTD_DStream* zds, const vo * ZSTD_DCtx_refDDict(zds, ddict); * * note : ddict is referenced, it must outlive decompression session - * Note : this prototype will be marked as deprecated and generate compilation warnings on reaching v1.5.x */ +ZSTD_DEPRECATED("use ZSTD_DCtx_reset + ZSTD_DCtx_refDDict, see zstd.h for detailed instructions") ZSTDLIB_STATIC_API size_t ZSTD_initDStream_usingDDict(ZSTD_DStream* zds, const ZSTD_DDict* ddict); /*! @@ -2339,18 +2657,202 @@ ZSTDLIB_STATIC_API size_t ZSTD_initDStream_usingDDict(ZSTD_DStream* zds, const Z * * ZSTD_DCtx_reset(zds, ZSTD_reset_session_only); * - * re-use decompression parameters from previous init; saves dictionary loading - * Note : this prototype will be marked as deprecated and generate compilation warnings on reaching v1.5.x + * reuse decompression parameters from previous init; saves dictionary loading */ +ZSTD_DEPRECATED("use ZSTD_DCtx_reset, see zstd.h for detailed instructions") ZSTDLIB_STATIC_API size_t ZSTD_resetDStream(ZSTD_DStream* zds); +/* ********************* BLOCK-LEVEL SEQUENCE PRODUCER API ********************* + * + * *** OVERVIEW *** + * The Block-Level Sequence Producer API allows users to provide their own custom + * sequence producer which libzstd invokes to process each block. The produced list + * of sequences (literals and matches) is then post-processed by libzstd to produce + * valid compressed blocks. + * + * This block-level offload API is a more granular complement of the existing + * frame-level offload API compressSequences() (introduced in v1.5.1). It offers + * an easier migration story for applications already integrated with libzstd: the + * user application continues to invoke the same compression functions + * ZSTD_compress2() or ZSTD_compressStream2() as usual, and transparently benefits + * from the specific advantages of the external sequence producer. For example, + * the sequence producer could be tuned to take advantage of known characteristics + * of the input, to offer better speed / ratio, or could leverage hardware + * acceleration not available within libzstd itself. + * + * See contrib/externalSequenceProducer for an example program employing the + * Block-Level Sequence Producer API. + * + * *** USAGE *** + * The user is responsible for implementing a function of type + * ZSTD_sequenceProducer_F. For each block, zstd will pass the following + * arguments to the user-provided function: + * + * - sequenceProducerState: a pointer to a user-managed state for the sequence + * producer. + * + * - outSeqs, outSeqsCapacity: an output buffer for the sequence producer. + * outSeqsCapacity is guaranteed >= ZSTD_sequenceBound(srcSize). The memory + * backing outSeqs is managed by the CCtx. + * + * - src, srcSize: an input buffer for the sequence producer to parse. + * srcSize is guaranteed to be <= ZSTD_BLOCKSIZE_MAX. + * + * - dict, dictSize: a history buffer, which may be empty, which the sequence + * producer may reference as it parses the src buffer. Currently, zstd will + * always pass dictSize == 0 into external sequence producers, but this will + * change in the future. + * + * - compressionLevel: a signed integer representing the zstd compression level + * set by the user for the current operation. The sequence producer may choose + * to use this information to change its compression strategy and speed/ratio + * tradeoff. Note: the compression level does not reflect zstd parameters set + * through the advanced API. + * + * - windowSize: a size_t representing the maximum allowed offset for external + * sequences. Note that sequence offsets are sometimes allowed to exceed the + * windowSize if a dictionary is present, see doc/zstd_compression_format.md + * for details. + * + * The user-provided function shall return a size_t representing the number of + * sequences written to outSeqs. This return value will be treated as an error + * code if it is greater than outSeqsCapacity. The return value must be non-zero + * if srcSize is non-zero. The ZSTD_SEQUENCE_PRODUCER_ERROR macro is provided + * for convenience, but any value greater than outSeqsCapacity will be treated as + * an error code. + * + * If the user-provided function does not return an error code, the sequences + * written to outSeqs must be a valid parse of the src buffer. Data corruption may + * occur if the parse is not valid. A parse is defined to be valid if the + * following conditions hold: + * - The sum of matchLengths and literalLengths must equal srcSize. + * - All sequences in the parse, except for the final sequence, must have + * matchLength >= ZSTD_MINMATCH_MIN. The final sequence must have + * matchLength >= ZSTD_MINMATCH_MIN or matchLength == 0. + * - All offsets must respect the windowSize parameter as specified in + * doc/zstd_compression_format.md. + * - If the final sequence has matchLength == 0, it must also have offset == 0. + * + * zstd will only validate these conditions (and fail compression if they do not + * hold) if the ZSTD_c_validateSequences cParam is enabled. Note that sequence + * validation has a performance cost. + * + * If the user-provided function returns an error, zstd will either fall back + * to an internal sequence producer or fail the compression operation. The user can + * choose between the two behaviors by setting the ZSTD_c_enableSeqProducerFallback + * cParam. Fallback compression will follow any other cParam settings, such as + * compression level, the same as in a normal compression operation. + * + * The user shall instruct zstd to use a particular ZSTD_sequenceProducer_F + * function by calling + * ZSTD_registerSequenceProducer(cctx, + * sequenceProducerState, + * sequenceProducer) + * This setting will persist until the next parameter reset of the CCtx. + * + * The sequenceProducerState must be initialized by the user before calling + * ZSTD_registerSequenceProducer(). The user is responsible for destroying the + * sequenceProducerState. + * + * *** LIMITATIONS *** + * This API is compatible with all zstd compression APIs which respect advanced parameters. + * However, there are three limitations: + * + * First, the ZSTD_c_enableLongDistanceMatching cParam is not currently supported. + * COMPRESSION WILL FAIL if it is enabled and the user tries to compress with a block-level + * external sequence producer. + * - Note that ZSTD_c_enableLongDistanceMatching is auto-enabled by default in some + * cases (see its documentation for details). Users must explicitly set + * ZSTD_c_enableLongDistanceMatching to ZSTD_ps_disable in such cases if an external + * sequence producer is registered. + * - As of this writing, ZSTD_c_enableLongDistanceMatching is disabled by default + * whenever ZSTD_c_windowLog < 128MB, but that's subject to change. Users should + * check the docs on ZSTD_c_enableLongDistanceMatching whenever the Block-Level Sequence + * Producer API is used in conjunction with advanced settings (like ZSTD_c_windowLog). + * + * Second, history buffers are not currently supported. Concretely, zstd will always pass + * dictSize == 0 to the external sequence producer (for now). This has two implications: + * - Dictionaries are not currently supported. Compression will *not* fail if the user + * references a dictionary, but the dictionary won't have any effect. + * - Stream history is not currently supported. All advanced compression APIs, including + * streaming APIs, work with external sequence producers, but each block is treated as + * an independent chunk without history from previous blocks. + * + * Third, multi-threading within a single compression is not currently supported. In other words, + * COMPRESSION WILL FAIL if ZSTD_c_nbWorkers > 0 and an external sequence producer is registered. + * Multi-threading across compressions is fine: simply create one CCtx per thread. + * + * Long-term, we plan to overcome all three limitations. There is no technical blocker to + * overcoming them. It is purely a question of engineering effort. + */ + +#define ZSTD_SEQUENCE_PRODUCER_ERROR ((size_t)(-1)) + +typedef size_t (*ZSTD_sequenceProducer_F) ( + void* sequenceProducerState, + ZSTD_Sequence* outSeqs, size_t outSeqsCapacity, + const void* src, size_t srcSize, + const void* dict, size_t dictSize, + int compressionLevel, + size_t windowSize +); + +/*! ZSTD_registerSequenceProducer() : + * Instruct zstd to use a block-level external sequence producer function. + * + * The sequenceProducerState must be initialized by the caller, and the caller is + * responsible for managing its lifetime. This parameter is sticky across + * compressions. It will remain set until the user explicitly resets compression + * parameters. + * + * Sequence producer registration is considered to be an "advanced parameter", + * part of the "advanced API". This means it will only have an effect on compression + * APIs which respect advanced parameters, such as compress2() and compressStream2(). + * Older compression APIs such as compressCCtx(), which predate the introduction of + * "advanced parameters", will ignore any external sequence producer setting. + * + * The sequence producer can be "cleared" by registering a NULL function pointer. This + * removes all limitations described above in the "LIMITATIONS" section of the API docs. + * + * The user is strongly encouraged to read the full API documentation (above) before + * calling this function. */ +ZSTDLIB_STATIC_API void +ZSTD_registerSequenceProducer( + ZSTD_CCtx* cctx, + void* sequenceProducerState, + ZSTD_sequenceProducer_F sequenceProducer +); + +/*! ZSTD_CCtxParams_registerSequenceProducer() : + * Same as ZSTD_registerSequenceProducer(), but operates on ZSTD_CCtx_params. + * This is used for accurate size estimation with ZSTD_estimateCCtxSize_usingCCtxParams(), + * which is needed when creating a ZSTD_CCtx with ZSTD_initStaticCCtx(). + * + * If you are using the external sequence producer API in a scenario where ZSTD_initStaticCCtx() + * is required, then this function is for you. Otherwise, you probably don't need it. + * + * See tests/zstreamtest.c for example usage. */ +ZSTDLIB_STATIC_API void +ZSTD_CCtxParams_registerSequenceProducer( + ZSTD_CCtx_params* params, + void* sequenceProducerState, + ZSTD_sequenceProducer_F sequenceProducer +); + + /* ******************************************************************* -* Buffer-less and synchronous inner streaming functions +* Buffer-less and synchronous inner streaming functions (DEPRECATED) +* +* This API is deprecated, and will be removed in a future version. +* It allows streaming (de)compression with user allocated buffers. +* However, it is hard to use, and not as well tested as the rest of +* our API. * -* This is an advanced API, giving full control over buffer management, for users which need direct control over memory. -* But it's also a complex one, with several restrictions, documented below. -* Prefer normal streaming API for an easier experience. +* Please use the normal streaming API instead: ZSTD_compressStream2, +* and ZSTD_decompressStream. +* If there is functionality that you need, but it doesn't provide, +* please open an issue on our GitHub. ********************************************************************* */ /* @@ -2358,11 +2860,10 @@ ZSTDLIB_STATIC_API size_t ZSTD_resetDStream(ZSTD_DStream* zds); A ZSTD_CCtx object is required to track streaming operations. Use ZSTD_createCCtx() / ZSTD_freeCCtx() to manage resource. - ZSTD_CCtx object can be re-used multiple times within successive compression operations. + ZSTD_CCtx object can be reused multiple times within successive compression operations. Start by initializing a context. Use ZSTD_compressBegin(), or ZSTD_compressBegin_usingDict() for dictionary compression. - It's also possible to duplicate a reference context which has already been initialized, using ZSTD_copyCCtx() Then, consume your input using ZSTD_compressContinue(). There are some important considerations to keep in mind when using this advanced function : @@ -2380,36 +2881,46 @@ ZSTDLIB_STATIC_API size_t ZSTD_resetDStream(ZSTD_DStream* zds); It's possible to use srcSize==0, in which case, it will write a final empty block to end the frame. Without last block mark, frames are considered unfinished (hence corrupted) by compliant decoders. - `ZSTD_CCtx` object can be re-used (ZSTD_compressBegin()) to compress again. + `ZSTD_CCtx` object can be reused (ZSTD_compressBegin()) to compress again. */ /*===== Buffer-less streaming compression functions =====*/ +ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_compressBegin(ZSTD_CCtx* cctx, int compressionLevel); +ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_compressBegin_usingDict(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, int compressionLevel); +ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_compressBegin_usingCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict); /*< note: fails if cdict==NULL */ -ZSTDLIB_STATIC_API size_t ZSTD_copyCCtx(ZSTD_CCtx* cctx, const ZSTD_CCtx* preparedCCtx, unsigned long long pledgedSrcSize); /*< note: if pledgedSrcSize is not known, use ZSTD_CONTENTSIZE_UNKNOWN */ +ZSTD_DEPRECATED("This function will likely be removed in a future release. It is misleading and has very limited utility.") +ZSTDLIB_STATIC_API +size_t ZSTD_copyCCtx(ZSTD_CCtx* cctx, const ZSTD_CCtx* preparedCCtx, unsigned long long pledgedSrcSize); /*< note: if pledgedSrcSize is not known, use ZSTD_CONTENTSIZE_UNKNOWN */ + +ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_compressContinue(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); +ZSTD_DEPRECATED("The buffer-less API is deprecated in favor of the normal streaming API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_compressEnd(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); /* The ZSTD_compressBegin_advanced() and ZSTD_compressBegin_usingCDict_advanced() are now DEPRECATED and will generate a compiler warning */ ZSTD_DEPRECATED("use advanced API to access custom parameters") +ZSTDLIB_STATIC_API size_t ZSTD_compressBegin_advanced(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, ZSTD_parameters params, unsigned long long pledgedSrcSize); /*< pledgedSrcSize : If srcSize is not known at init time, use ZSTD_CONTENTSIZE_UNKNOWN */ ZSTD_DEPRECATED("use advanced API to access custom parameters") +ZSTDLIB_STATIC_API size_t ZSTD_compressBegin_usingCDict_advanced(ZSTD_CCtx* const cctx, const ZSTD_CDict* const cdict, ZSTD_frameParameters const fParams, unsigned long long const pledgedSrcSize); /* compression parameters are already set within cdict. pledgedSrcSize must be correct. If srcSize is not known, use macro ZSTD_CONTENTSIZE_UNKNOWN */ /* Buffer-less streaming decompression (synchronous mode) A ZSTD_DCtx object is required to track streaming operations. Use ZSTD_createDCtx() / ZSTD_freeDCtx() to manage it. - A ZSTD_DCtx object can be re-used multiple times. + A ZSTD_DCtx object can be reused multiple times. First typical operation is to retrieve frame parameters, using ZSTD_getFrameHeader(). Frame header is extracted from the beginning of compressed frame, so providing only the frame's beginning is enough. Data fragment must be large enough to ensure successful decoding. `ZSTD_frameHeaderSize_max` bytes is guaranteed to always be large enough. - @result : 0 : successful decoding, the `ZSTD_frameHeader` structure is correctly filled. - >0 : `srcSize` is too small, please provide at least @result bytes on next attempt. + result : 0 : successful decoding, the `ZSTD_frameHeader` structure is correctly filled. + >0 : `srcSize` is too small, please provide at least result bytes on next attempt. errorCode, which can be tested using ZSTD_isError(). It fills a ZSTD_frameHeader structure with important information to correctly decode the frame, @@ -2428,7 +2939,7 @@ size_t ZSTD_compressBegin_usingCDict_advanced(ZSTD_CCtx* const cctx, const ZSTD_ The most memory efficient way is to use a round buffer of sufficient size. Sufficient size is determined by invoking ZSTD_decodingBufferSize_min(), - which can @return an error code if required value is too large for current system (in 32-bits mode). + which can return an error code if required value is too large for current system (in 32-bits mode). In a round buffer methodology, ZSTD_decompressContinue() decompresses each block next to previous one, up to the moment there is not enough room left in the buffer to guarantee decoding another full block, which maximum size is provided in `ZSTD_frameHeader` structure, field `blockSizeMax`. @@ -2448,7 +2959,7 @@ size_t ZSTD_compressBegin_usingCDict_advanced(ZSTD_CCtx* const cctx, const ZSTD_ ZSTD_nextSrcSizeToDecompress() tells how many bytes to provide as 'srcSize' to ZSTD_decompressContinue(). ZSTD_decompressContinue() requires this _exact_ amount of bytes, or it will fail. - @result of ZSTD_decompressContinue() is the number of bytes regenerated within 'dst' (necessarily <= dstCapacity). + result of ZSTD_decompressContinue() is the number of bytes regenerated within 'dst' (necessarily <= dstCapacity). It can be zero : it just means ZSTD_decompressContinue() has decoded some metadata item. It can also be an error code, which can be tested with ZSTD_isError(). @@ -2471,27 +2982,7 @@ size_t ZSTD_compressBegin_usingCDict_advanced(ZSTD_CCtx* const cctx, const ZSTD_ */ /*===== Buffer-less streaming decompression functions =====*/ -typedef enum { ZSTD_frame, ZSTD_skippableFrame } ZSTD_frameType_e; -typedef struct { - unsigned long long frameContentSize; /* if == ZSTD_CONTENTSIZE_UNKNOWN, it means this field is not available. 0 means "empty" */ - unsigned long long windowSize; /* can be very large, up to <= frameContentSize */ - unsigned blockSizeMax; - ZSTD_frameType_e frameType; /* if == ZSTD_skippableFrame, frameContentSize is the size of skippable content */ - unsigned headerSize; - unsigned dictID; - unsigned checksumFlag; -} ZSTD_frameHeader; -/*! ZSTD_getFrameHeader() : - * decode Frame Header, or requires larger `srcSize`. - * @return : 0, `zfhPtr` is correctly filled, - * >0, `srcSize` is too small, value is wanted `srcSize` amount, - * or an error code, which can be tested using ZSTD_isError() */ -ZSTDLIB_STATIC_API size_t ZSTD_getFrameHeader(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize); /*< doesn't consume input */ -/*! ZSTD_getFrameHeader_advanced() : - * same as ZSTD_getFrameHeader(), - * with added capability to select a format (like ZSTD_f_zstd1_magicless) */ -ZSTDLIB_STATIC_API size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format); ZSTDLIB_STATIC_API size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long long frameContentSize); /*< when frame content size is not known, pass in frameContentSize == ZSTD_CONTENTSIZE_UNKNOWN */ ZSTDLIB_STATIC_API size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx); @@ -2502,6 +2993,7 @@ ZSTDLIB_STATIC_API size_t ZSTD_nextSrcSizeToDecompress(ZSTD_DCtx* dctx); ZSTDLIB_STATIC_API size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); /* misc */ +ZSTD_DEPRECATED("This function will likely be removed in the next minor release. It is misleading and has very limited utility.") ZSTDLIB_STATIC_API void ZSTD_copyDCtx(ZSTD_DCtx* dctx, const ZSTD_DCtx* preparedDCtx); typedef enum { ZSTDnit_frameHeader, ZSTDnit_blockHeader, ZSTDnit_block, ZSTDnit_lastBlock, ZSTDnit_checksum, ZSTDnit_skippableFrame } ZSTD_nextInputType_e; ZSTDLIB_STATIC_API ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx); @@ -2509,11 +3001,23 @@ ZSTDLIB_STATIC_API ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx); -/* ============================ */ -/* Block level API */ -/* ============================ */ +/* ========================================= */ +/* Block level API (DEPRECATED) */ +/* ========================================= */ /*! + + This API is deprecated in favor of the regular compression API. + You can get the frame header down to 2 bytes by setting: + - ZSTD_c_format = ZSTD_f_zstd1_magicless + - ZSTD_c_contentSizeFlag = 0 + - ZSTD_c_checksumFlag = 0 + - ZSTD_c_dictIDFlag = 0 + + This API is not as well tested as our normal API, so we recommend not using it. + We will be removing it in a future version. If the normal API doesn't provide + the functionality you need, please open a GitHub issue. + Block functions produce and decode raw zstd blocks, without frame metadata. Frame metadata cost is typically ~12 bytes, which can be non-negligible for very small blocks (< 100 bytes). But users will have to take in charge needed metadata to regenerate data, such as compressed and content sizes. @@ -2524,7 +3028,6 @@ ZSTDLIB_STATIC_API ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx); - It is necessary to init context before starting + compression : any ZSTD_compressBegin*() variant, including with dictionary + decompression : any ZSTD_decompressBegin*() variant, including with dictionary - + copyCCtx() and copyDCtx() can be used too - Block size is limited, it must be <= ZSTD_getBlockSize() <= ZSTD_BLOCKSIZE_MAX == 128 KB + If input is larger than a block size, it's necessary to split input data into multiple blocks + For inputs larger than a single block, consider using regular ZSTD_compress() instead. @@ -2541,11 +3044,14 @@ ZSTDLIB_STATIC_API ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx); */ /*===== Raw zstd block functions =====*/ +ZSTD_DEPRECATED("The block API is deprecated in favor of the normal compression API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_getBlockSize (const ZSTD_CCtx* cctx); +ZSTD_DEPRECATED("The block API is deprecated in favor of the normal compression API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_compressBlock (ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); +ZSTD_DEPRECATED("The block API is deprecated in favor of the normal compression API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_decompressBlock(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); +ZSTD_DEPRECATED("The block API is deprecated in favor of the normal compression API. See docs.") ZSTDLIB_STATIC_API size_t ZSTD_insertBlock (ZSTD_DCtx* dctx, const void* blockStart, size_t blockSize); /*< insert uncompressed block into `dctx` history. Useful for multi-blocks decompression. */ - #endif /* ZSTD_H_ZSTD_STATIC_LINKING_ONLY */ diff --git a/lib/zstd/Makefile b/lib/zstd/Makefile index 20f08c644b71..464c410b2768 100644 --- a/lib/zstd/Makefile +++ b/lib/zstd/Makefile @@ -1,6 +1,6 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause # ################################################################ -# Copyright (c) Facebook, Inc. +# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/common/allocations.h b/lib/zstd/common/allocations.h new file mode 100644 index 000000000000..16c3d08e8d1a --- /dev/null +++ b/lib/zstd/common/allocations.h @@ -0,0 +1,56 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the + * LICENSE file in the root directory of this source tree) and the GPLv2 (found + * in the COPYING file in the root directory of this source tree). + * You may select, at your option, one of the above-listed licenses. + */ + +/* This file provides custom allocation primitives + */ + +#define ZSTD_DEPS_NEED_MALLOC +#include "zstd_deps.h" /* ZSTD_malloc, ZSTD_calloc, ZSTD_free, ZSTD_memset */ + +#include "compiler.h" /* MEM_STATIC */ +#define ZSTD_STATIC_LINKING_ONLY +#include /* ZSTD_customMem */ + +#ifndef ZSTD_ALLOCATIONS_H +#define ZSTD_ALLOCATIONS_H + +/* custom memory allocation functions */ + +MEM_STATIC void* ZSTD_customMalloc(size_t size, ZSTD_customMem customMem) +{ + if (customMem.customAlloc) + return customMem.customAlloc(customMem.opaque, size); + return ZSTD_malloc(size); +} + +MEM_STATIC void* ZSTD_customCalloc(size_t size, ZSTD_customMem customMem) +{ + if (customMem.customAlloc) { + /* calloc implemented as malloc+memset; + * not as efficient as calloc, but next best guess for custom malloc */ + void* const ptr = customMem.customAlloc(customMem.opaque, size); + ZSTD_memset(ptr, 0, size); + return ptr; + } + return ZSTD_calloc(1, size); +} + +MEM_STATIC void ZSTD_customFree(void* ptr, ZSTD_customMem customMem) +{ + if (ptr!=NULL) { + if (customMem.customFree) + customMem.customFree(customMem.opaque, ptr); + else + ZSTD_free(ptr); + } +} + +#endif /* ZSTD_ALLOCATIONS_H */ diff --git a/lib/zstd/common/bits.h b/lib/zstd/common/bits.h new file mode 100644 index 000000000000..aa3487ec4b6a --- /dev/null +++ b/lib/zstd/common/bits.h @@ -0,0 +1,149 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the + * LICENSE file in the root directory of this source tree) and the GPLv2 (found + * in the COPYING file in the root directory of this source tree). + * You may select, at your option, one of the above-listed licenses. + */ + +#ifndef ZSTD_BITS_H +#define ZSTD_BITS_H + +#include "mem.h" + +MEM_STATIC unsigned ZSTD_countTrailingZeros32_fallback(U32 val) +{ + assert(val != 0); + { + static const U32 DeBruijnBytePos[32] = {0, 1, 28, 2, 29, 14, 24, 3, + 30, 22, 20, 15, 25, 17, 4, 8, + 31, 27, 13, 23, 21, 19, 16, 7, + 26, 12, 18, 6, 11, 5, 10, 9}; + return DeBruijnBytePos[((U32) ((val & -(S32) val) * 0x077CB531U)) >> 27]; + } +} + +MEM_STATIC unsigned ZSTD_countTrailingZeros32(U32 val) +{ + assert(val != 0); +# if (__GNUC__ >= 4) + return (unsigned)__builtin_ctz(val); +# else + return ZSTD_countTrailingZeros32_fallback(val); +# endif +} + +MEM_STATIC unsigned ZSTD_countLeadingZeros32_fallback(U32 val) { + assert(val != 0); + { + static const U32 DeBruijnClz[32] = {0, 9, 1, 10, 13, 21, 2, 29, + 11, 14, 16, 18, 22, 25, 3, 30, + 8, 12, 20, 28, 15, 17, 24, 7, + 19, 27, 23, 6, 26, 5, 4, 31}; + val |= val >> 1; + val |= val >> 2; + val |= val >> 4; + val |= val >> 8; + val |= val >> 16; + return 31 - DeBruijnClz[(val * 0x07C4ACDDU) >> 27]; + } +} + +MEM_STATIC unsigned ZSTD_countLeadingZeros32(U32 val) +{ + assert(val != 0); +# if (__GNUC__ >= 4) + return (unsigned)__builtin_clz(val); +# else + return ZSTD_countLeadingZeros32_fallback(val); +# endif +} + +MEM_STATIC unsigned ZSTD_countTrailingZeros64(U64 val) +{ + assert(val != 0); +# if (__GNUC__ >= 4) && defined(__LP64__) + return (unsigned)__builtin_ctzll(val); +# else + { + U32 mostSignificantWord = (U32)(val >> 32); + U32 leastSignificantWord = (U32)val; + if (leastSignificantWord == 0) { + return 32 + ZSTD_countTrailingZeros32(mostSignificantWord); + } else { + return ZSTD_countTrailingZeros32(leastSignificantWord); + } + } +# endif +} + +MEM_STATIC unsigned ZSTD_countLeadingZeros64(U64 val) +{ + assert(val != 0); +# if (__GNUC__ >= 4) + return (unsigned)(__builtin_clzll(val)); +# else + { + U32 mostSignificantWord = (U32)(val >> 32); + U32 leastSignificantWord = (U32)val; + if (mostSignificantWord == 0) { + return 32 + ZSTD_countLeadingZeros32(leastSignificantWord); + } else { + return ZSTD_countLeadingZeros32(mostSignificantWord); + } + } +# endif +} + +MEM_STATIC unsigned ZSTD_NbCommonBytes(size_t val) +{ + if (MEM_isLittleEndian()) { + if (MEM_64bits()) { + return ZSTD_countTrailingZeros64((U64)val) >> 3; + } else { + return ZSTD_countTrailingZeros32((U32)val) >> 3; + } + } else { /* Big Endian CPU */ + if (MEM_64bits()) { + return ZSTD_countLeadingZeros64((U64)val) >> 3; + } else { + return ZSTD_countLeadingZeros32((U32)val) >> 3; + } + } +} + +MEM_STATIC unsigned ZSTD_highbit32(U32 val) /* compress, dictBuilder, decodeCorpus */ +{ + assert(val != 0); + return 31 - ZSTD_countLeadingZeros32(val); +} + +/* ZSTD_rotateRight_*(): + * Rotates a bitfield to the right by "count" bits. + * https://en.wikipedia.org/w/index.php?title=Circular_shift&oldid=991635599#Implementing_circular_shifts + */ +MEM_STATIC +U64 ZSTD_rotateRight_U64(U64 const value, U32 count) { + assert(count < 64); + count &= 0x3F; /* for fickle pattern recognition */ + return (value >> count) | (U64)(value << ((0U - count) & 0x3F)); +} + +MEM_STATIC +U32 ZSTD_rotateRight_U32(U32 const value, U32 count) { + assert(count < 32); + count &= 0x1F; /* for fickle pattern recognition */ + return (value >> count) | (U32)(value << ((0U - count) & 0x1F)); +} + +MEM_STATIC +U16 ZSTD_rotateRight_U16(U16 const value, U32 count) { + assert(count < 16); + count &= 0x0F; /* for fickle pattern recognition */ + return (value >> count) | (U16)(value << ((0U - count) & 0x0F)); +} + +#endif /* ZSTD_BITS_H */ diff --git a/lib/zstd/common/bitstream.h b/lib/zstd/common/bitstream.h index feef3a1b1d60..6a13f1f0f1e8 100644 --- a/lib/zstd/common/bitstream.h +++ b/lib/zstd/common/bitstream.h @@ -1,7 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* ****************************************************************** * bitstream * Part of FSE library - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -27,6 +28,7 @@ #include "compiler.h" /* UNLIKELY() */ #include "debug.h" /* assert(), DEBUGLOG(), RAWLOG() */ #include "error_private.h" /* error codes and messages */ +#include "bits.h" /* ZSTD_highbit32 */ /*========================================= @@ -79,19 +81,20 @@ MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC); /*-******************************************** * bitStream decoding API (read backward) **********************************************/ +typedef size_t BitContainerType; typedef struct { - size_t bitContainer; + BitContainerType bitContainer; unsigned bitsConsumed; const char* ptr; const char* start; const char* limitPtr; } BIT_DStream_t; -typedef enum { BIT_DStream_unfinished = 0, - BIT_DStream_endOfBuffer = 1, - BIT_DStream_completed = 2, - BIT_DStream_overflow = 3 } BIT_DStream_status; /* result of BIT_reloadDStream() */ - /* 1,2,4,8 would be better for bitmap combinations, but slows down performance a bit ... :( */ +typedef enum { BIT_DStream_unfinished = 0, /* fully refilled */ + BIT_DStream_endOfBuffer = 1, /* still some bits left in bitstream */ + BIT_DStream_completed = 2, /* bitstream entirely consumed, bit-exact */ + BIT_DStream_overflow = 3 /* user requested more bits than present in bitstream */ + } BIT_DStream_status; /* result of BIT_reloadDStream() */ MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, size_t srcSize); MEM_STATIC size_t BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits); @@ -101,7 +104,7 @@ MEM_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t* bitD); /* Start by invoking BIT_initDStream(). * A chunk of the bitStream is then stored into a local register. -* Local register size is 64-bits on 64-bits systems, 32-bits on 32-bits systems (size_t). +* Local register size is 64-bits on 64-bits systems, 32-bits on 32-bits systems (BitContainerType). * You can then retrieve bitFields stored into the local register, **in reverse order**. * Local register is explicitly reloaded from memory by the BIT_reloadDStream() method. * A reload guarantee a minimum of ((8*sizeof(bitD->bitContainer))-7) bits when its result is BIT_DStream_unfinished. @@ -122,33 +125,6 @@ MEM_STATIC void BIT_flushBitsFast(BIT_CStream_t* bitC); MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits); /* faster, but works only if nbBits >= 1 */ - - -/*-************************************************************** -* Internal functions -****************************************************************/ -MEM_STATIC unsigned BIT_highbit32 (U32 val) -{ - assert(val != 0); - { -# if (__GNUC__ >= 3) /* Use GCC Intrinsic */ - return __builtin_clz (val) ^ 31; -# else /* Software version */ - static const unsigned DeBruijnClz[32] = { 0, 9, 1, 10, 13, 21, 2, 29, - 11, 14, 16, 18, 22, 25, 3, 30, - 8, 12, 20, 28, 15, 17, 24, 7, - 19, 27, 23, 6, 26, 5, 4, 31 }; - U32 v = val; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - return DeBruijnClz[ (U32) (v * 0x07C4ACDDU) >> 27]; -# endif - } -} - /*===== Local Constants =====*/ static const unsigned BIT_mask[] = { 0, 1, 3, 7, 0xF, 0x1F, @@ -178,6 +154,12 @@ MEM_STATIC size_t BIT_initCStream(BIT_CStream_t* bitC, return 0; } +FORCE_INLINE_TEMPLATE size_t BIT_getLowerBits(size_t bitContainer, U32 const nbBits) +{ + assert(nbBits < BIT_MASK_SIZE); + return bitContainer & BIT_mask[nbBits]; +} + /*! BIT_addBits() : * can add up to 31 bits into `bitC`. * Note : does not check for register overflow ! */ @@ -187,7 +169,7 @@ MEM_STATIC void BIT_addBits(BIT_CStream_t* bitC, DEBUG_STATIC_ASSERT(BIT_MASK_SIZE == 32); assert(nbBits < BIT_MASK_SIZE); assert(nbBits + bitC->bitPos < sizeof(bitC->bitContainer) * 8); - bitC->bitContainer |= (value & BIT_mask[nbBits]) << bitC->bitPos; + bitC->bitContainer |= BIT_getLowerBits(value, nbBits) << bitC->bitPos; bitC->bitPos += nbBits; } @@ -266,35 +248,35 @@ MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, si bitD->ptr = (const char*)srcBuffer + srcSize - sizeof(bitD->bitContainer); bitD->bitContainer = MEM_readLEST(bitD->ptr); { BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1]; - bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; /* ensures bitsConsumed is always set */ + bitD->bitsConsumed = lastByte ? 8 - ZSTD_highbit32(lastByte) : 0; /* ensures bitsConsumed is always set */ if (lastByte == 0) return ERROR(GENERIC); /* endMark not present */ } } else { bitD->ptr = bitD->start; bitD->bitContainer = *(const BYTE*)(bitD->start); switch(srcSize) { - case 7: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[6]) << (sizeof(bitD->bitContainer)*8 - 16); + case 7: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[6]) << (sizeof(bitD->bitContainer)*8 - 16); ZSTD_FALLTHROUGH; - case 6: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[5]) << (sizeof(bitD->bitContainer)*8 - 24); + case 6: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[5]) << (sizeof(bitD->bitContainer)*8 - 24); ZSTD_FALLTHROUGH; - case 5: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[4]) << (sizeof(bitD->bitContainer)*8 - 32); + case 5: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[4]) << (sizeof(bitD->bitContainer)*8 - 32); ZSTD_FALLTHROUGH; - case 4: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[3]) << 24; + case 4: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[3]) << 24; ZSTD_FALLTHROUGH; - case 3: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[2]) << 16; + case 3: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[2]) << 16; ZSTD_FALLTHROUGH; - case 2: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[1]) << 8; + case 2: bitD->bitContainer += (BitContainerType)(((const BYTE*)(srcBuffer))[1]) << 8; ZSTD_FALLTHROUGH; default: break; } { BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1]; - bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; + bitD->bitsConsumed = lastByte ? 8 - ZSTD_highbit32(lastByte) : 0; if (lastByte == 0) return ERROR(corruption_detected); /* endMark not present */ } bitD->bitsConsumed += (U32)(sizeof(bitD->bitContainer) - srcSize)*8; @@ -303,12 +285,12 @@ MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, si return srcSize; } -MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getUpperBits(size_t bitContainer, U32 const start) +FORCE_INLINE_TEMPLATE size_t BIT_getUpperBits(BitContainerType bitContainer, U32 const start) { return bitContainer >> start; } -MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getMiddleBits(size_t bitContainer, U32 const start, U32 const nbBits) +FORCE_INLINE_TEMPLATE size_t BIT_getMiddleBits(BitContainerType bitContainer, U32 const start, U32 const nbBits) { U32 const regMask = sizeof(bitContainer)*8 - 1; /* if start > regMask, bitstream is corrupted, and result is undefined */ @@ -325,19 +307,13 @@ MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getMiddleBits(size_t bitContainer, U32 c #endif } -MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getLowerBits(size_t bitContainer, U32 const nbBits) -{ - assert(nbBits < BIT_MASK_SIZE); - return bitContainer & BIT_mask[nbBits]; -} - /*! BIT_lookBits() : * Provides next n bits from local register. * local register is not modified. * On 32-bits, maxNbBits==24. * On 64-bits, maxNbBits==56. * @return : value extracted */ -MEM_STATIC FORCE_INLINE_ATTR size_t BIT_lookBits(const BIT_DStream_t* bitD, U32 nbBits) +FORCE_INLINE_TEMPLATE size_t BIT_lookBits(const BIT_DStream_t* bitD, U32 nbBits) { /* arbitrate between double-shift and shift+mask */ #if 1 @@ -360,7 +336,7 @@ MEM_STATIC size_t BIT_lookBitsFast(const BIT_DStream_t* bitD, U32 nbBits) return (bitD->bitContainer << (bitD->bitsConsumed & regMask)) >> (((regMask+1)-nbBits) & regMask); } -MEM_STATIC FORCE_INLINE_ATTR void BIT_skipBits(BIT_DStream_t* bitD, U32 nbBits) +FORCE_INLINE_TEMPLATE void BIT_skipBits(BIT_DStream_t* bitD, U32 nbBits) { bitD->bitsConsumed += nbBits; } @@ -369,7 +345,7 @@ MEM_STATIC FORCE_INLINE_ATTR void BIT_skipBits(BIT_DStream_t* bitD, U32 nbBits) * Read (consume) next n bits from local register and update. * Pay attention to not read more than nbBits contained into local register. * @return : extracted value. */ -MEM_STATIC FORCE_INLINE_ATTR size_t BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits) +FORCE_INLINE_TEMPLATE size_t BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits) { size_t const value = BIT_lookBits(bitD, nbBits); BIT_skipBits(bitD, nbBits); @@ -377,7 +353,7 @@ MEM_STATIC FORCE_INLINE_ATTR size_t BIT_readBits(BIT_DStream_t* bitD, unsigned n } /*! BIT_readBitsFast() : - * unsafe version; only works only if nbBits >= 1 */ + * unsafe version; only works if nbBits >= 1 */ MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits) { size_t const value = BIT_lookBitsFast(bitD, nbBits); @@ -386,6 +362,21 @@ MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits) return value; } +/*! BIT_reloadDStream_internal() : + * Simple variant of BIT_reloadDStream(), with two conditions: + * 1. bitstream is valid : bitsConsumed <= sizeof(bitD->bitContainer)*8 + * 2. look window is valid after shifted down : bitD->ptr >= bitD->start + */ +MEM_STATIC BIT_DStream_status BIT_reloadDStream_internal(BIT_DStream_t* bitD) +{ + assert(bitD->bitsConsumed <= sizeof(bitD->bitContainer)*8); + bitD->ptr -= bitD->bitsConsumed >> 3; + assert(bitD->ptr >= bitD->start); + bitD->bitsConsumed &= 7; + bitD->bitContainer = MEM_readLEST(bitD->ptr); + return BIT_DStream_unfinished; +} + /*! BIT_reloadDStreamFast() : * Similar to BIT_reloadDStream(), but with two differences: * 1. bitsConsumed <= sizeof(bitD->bitContainer)*8 must hold! @@ -396,31 +387,35 @@ MEM_STATIC BIT_DStream_status BIT_reloadDStreamFast(BIT_DStream_t* bitD) { if (UNLIKELY(bitD->ptr < bitD->limitPtr)) return BIT_DStream_overflow; - assert(bitD->bitsConsumed <= sizeof(bitD->bitContainer)*8); - bitD->ptr -= bitD->bitsConsumed >> 3; - bitD->bitsConsumed &= 7; - bitD->bitContainer = MEM_readLEST(bitD->ptr); - return BIT_DStream_unfinished; + return BIT_reloadDStream_internal(bitD); } /*! BIT_reloadDStream() : * Refill `bitD` from buffer previously set in BIT_initDStream() . - * This function is safe, it guarantees it will not read beyond src buffer. + * This function is safe, it guarantees it will not never beyond src buffer. * @return : status of `BIT_DStream_t` internal register. * when status == BIT_DStream_unfinished, internal register is filled with at least 25 or 57 bits */ -MEM_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD) +FORCE_INLINE_TEMPLATE BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD) { - if (bitD->bitsConsumed > (sizeof(bitD->bitContainer)*8)) /* overflow detected, like end of stream */ + /* note : once in overflow mode, a bitstream remains in this mode until it's reset */ + if (UNLIKELY(bitD->bitsConsumed > (sizeof(bitD->bitContainer)*8))) { + static const BitContainerType zeroFilled = 0; + bitD->ptr = (const char*)&zeroFilled; /* aliasing is allowed for char */ + /* overflow detected, erroneous scenario or end of stream: no update */ return BIT_DStream_overflow; + } + + assert(bitD->ptr >= bitD->start); if (bitD->ptr >= bitD->limitPtr) { - return BIT_reloadDStreamFast(bitD); + return BIT_reloadDStream_internal(bitD); } if (bitD->ptr == bitD->start) { + /* reached end of bitStream => no update */ if (bitD->bitsConsumed < sizeof(bitD->bitContainer)*8) return BIT_DStream_endOfBuffer; return BIT_DStream_completed; } - /* start < ptr < limitPtr */ + /* start < ptr < limitPtr => cautious update */ { U32 nbBytes = bitD->bitsConsumed >> 3; BIT_DStream_status result = BIT_DStream_unfinished; if (bitD->ptr - nbBytes < bitD->start) { diff --git a/lib/zstd/common/compiler.h b/lib/zstd/common/compiler.h index c42d39faf9bd..508ee25537bb 100644 --- a/lib/zstd/common/compiler.h +++ b/lib/zstd/common/compiler.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,6 +12,8 @@ #ifndef ZSTD_COMPILER_H #define ZSTD_COMPILER_H +#include + #include "portability_macros.h" /*-******************************************************* @@ -41,12 +44,15 @@ */ #define WIN_CDECL +/* UNUSED_ATTR tells the compiler it is okay if the function is unused. */ +#define UNUSED_ATTR __attribute__((unused)) + /* * FORCE_INLINE_TEMPLATE is used to define C "templates", which take constant * parameters. They must be inlined for the compiler to eliminate the constant * branches. */ -#define FORCE_INLINE_TEMPLATE static INLINE_KEYWORD FORCE_INLINE_ATTR +#define FORCE_INLINE_TEMPLATE static INLINE_KEYWORD FORCE_INLINE_ATTR UNUSED_ATTR /* * HINT_INLINE is used to help the compiler generate better code. It is *not* * used for "templates", so it can be tweaked based on the compilers @@ -61,11 +67,21 @@ #if !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 4 && __GNUC_MINOR__ >= 8 && __GNUC__ < 5 # define HINT_INLINE static INLINE_KEYWORD #else -# define HINT_INLINE static INLINE_KEYWORD FORCE_INLINE_ATTR +# define HINT_INLINE FORCE_INLINE_TEMPLATE #endif -/* UNUSED_ATTR tells the compiler it is okay if the function is unused. */ -#define UNUSED_ATTR __attribute__((unused)) +/* "soft" inline : + * The compiler is free to select if it's a good idea to inline or not. + * The main objective is to silence compiler warnings + * when a defined function in included but not used. + * + * Note : this macro is prefixed `MEM_` because it used to be provided by `mem.h` unit. + * Updating the prefix is probably preferable, but requires a fairly large codemod, + * since this name is used everywhere. + */ +#ifndef MEM_STATIC /* already defined in Linux Kernel mem.h */ +#define MEM_STATIC static __inline UNUSED_ATTR +#endif /* force no inlining */ #define FORCE_NOINLINE static __attribute__((__noinline__)) @@ -86,23 +102,24 @@ # define PREFETCH_L1(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */) # define PREFETCH_L2(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 2 /* locality */) #elif defined(__aarch64__) -# define PREFETCH_L1(ptr) __asm__ __volatile__("prfm pldl1keep, %0" ::"Q"(*(ptr))) -# define PREFETCH_L2(ptr) __asm__ __volatile__("prfm pldl2keep, %0" ::"Q"(*(ptr))) +# define PREFETCH_L1(ptr) do { __asm__ __volatile__("prfm pldl1keep, %0" ::"Q"(*(ptr))); } while (0) +# define PREFETCH_L2(ptr) do { __asm__ __volatile__("prfm pldl2keep, %0" ::"Q"(*(ptr))); } while (0) #else -# define PREFETCH_L1(ptr) (void)(ptr) /* disabled */ -# define PREFETCH_L2(ptr) (void)(ptr) /* disabled */ +# define PREFETCH_L1(ptr) do { (void)(ptr); } while (0) /* disabled */ +# define PREFETCH_L2(ptr) do { (void)(ptr); } while (0) /* disabled */ #endif /* NO_PREFETCH */ #define CACHELINE_SIZE 64 -#define PREFETCH_AREA(p, s) { \ - const char* const _ptr = (const char*)(p); \ - size_t const _size = (size_t)(s); \ - size_t _pos; \ - for (_pos=0; _pos<_size; _pos+=CACHELINE_SIZE) { \ - PREFETCH_L2(_ptr + _pos); \ - } \ -} +#define PREFETCH_AREA(p, s) \ + do { \ + const char* const _ptr = (const char*)(p); \ + size_t const _size = (size_t)(s); \ + size_t _pos; \ + for (_pos=0; _pos<_size; _pos+=CACHELINE_SIZE) { \ + PREFETCH_L2(_ptr + _pos); \ + } \ + } while (0) /* vectorization * older GCC (pre gcc-4.3 picked as the cutoff) uses a different syntax, @@ -126,9 +143,9 @@ #define UNLIKELY(x) (__builtin_expect((x), 0)) #if __has_builtin(__builtin_unreachable) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 5))) -# define ZSTD_UNREACHABLE { assert(0), __builtin_unreachable(); } +# define ZSTD_UNREACHABLE do { assert(0), __builtin_unreachable(); } while (0) #else -# define ZSTD_UNREACHABLE { assert(0); } +# define ZSTD_UNREACHABLE do { assert(0); } while (0) #endif /* disable warnings */ @@ -179,6 +196,85 @@ * Sanitizer *****************************************************************/ +/* + * Zstd relies on pointer overflow in its decompressor. + * We add this attribute to functions that rely on pointer overflow. + */ +#ifndef ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +# if __has_attribute(no_sanitize) +# if !defined(__clang__) && defined(__GNUC__) && __GNUC__ < 8 + /* gcc < 8 only has signed-integer-overlow which triggers on pointer overflow */ +# define ZSTD_ALLOW_POINTER_OVERFLOW_ATTR __attribute__((no_sanitize("signed-integer-overflow"))) +# else + /* older versions of clang [3.7, 5.0) will warn that pointer-overflow is ignored. */ +# define ZSTD_ALLOW_POINTER_OVERFLOW_ATTR __attribute__((no_sanitize("pointer-overflow"))) +# endif +# else +# define ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +# endif +#endif + +/* + * Helper function to perform a wrapped pointer difference without trigging + * UBSAN. + * + * @returns lhs - rhs with wrapping + */ +MEM_STATIC +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +ptrdiff_t ZSTD_wrappedPtrDiff(unsigned char const* lhs, unsigned char const* rhs) +{ + return lhs - rhs; +} + +/* + * Helper function to perform a wrapped pointer add without triggering UBSAN. + * + * @return ptr + add with wrapping + */ +MEM_STATIC +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +unsigned char const* ZSTD_wrappedPtrAdd(unsigned char const* ptr, ptrdiff_t add) +{ + return ptr + add; +} + +/* + * Helper function to perform a wrapped pointer subtraction without triggering + * UBSAN. + * + * @return ptr - sub with wrapping + */ +MEM_STATIC +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +unsigned char const* ZSTD_wrappedPtrSub(unsigned char const* ptr, ptrdiff_t sub) +{ + return ptr - sub; +} + +/* + * Helper function to add to a pointer that works around C's undefined behavior + * of adding 0 to NULL. + * + * @returns `ptr + add` except it defines `NULL + 0 == NULL`. + */ +MEM_STATIC +unsigned char* ZSTD_maybeNullPtrAdd(unsigned char* ptr, ptrdiff_t add) +{ + return add > 0 ? ptr + add : ptr; +} + +/* Issue #3240 reports an ASAN failure on an llvm-mingw build. Out of an + * abundance of caution, disable our custom poisoning on mingw. */ +#ifdef __MINGW32__ +#ifndef ZSTD_ASAN_DONT_POISON_WORKSPACE +#define ZSTD_ASAN_DONT_POISON_WORKSPACE 1 +#endif +#ifndef ZSTD_MSAN_DONT_POISON_WORKSPACE +#define ZSTD_MSAN_DONT_POISON_WORKSPACE 1 +#endif +#endif + #endif /* ZSTD_COMPILER_H */ diff --git a/lib/zstd/common/cpu.h b/lib/zstd/common/cpu.h index 0db7b42407ee..d8319a2bef4c 100644 --- a/lib/zstd/common/cpu.h +++ b/lib/zstd/common/cpu.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/common/debug.c b/lib/zstd/common/debug.c index bb863c9ea616..8eb6aa9a3b20 100644 --- a/lib/zstd/common/debug.c +++ b/lib/zstd/common/debug.c @@ -1,7 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* ****************************************************************** * debug * Part of FSE library - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -21,4 +22,10 @@ #include "debug.h" +#if (DEBUGLEVEL>=2) +/* We only use this when DEBUGLEVEL>=2, but we get -Werror=pedantic errors if a + * translation unit is empty. So remove this from Linux kernel builds, but + * otherwise just leave it in. + */ int g_debuglevel = DEBUGLEVEL; +#endif diff --git a/lib/zstd/common/debug.h b/lib/zstd/common/debug.h index 6dd88d1fbd02..226ba3c57ec3 100644 --- a/lib/zstd/common/debug.h +++ b/lib/zstd/common/debug.h @@ -1,7 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* ****************************************************************** * debug * Part of FSE library - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -82,18 +83,27 @@ extern int g_debuglevel; /* the variable is only declared, It's useful when enabling very verbose levels on selective conditions (such as position in src) */ -# define RAWLOG(l, ...) { \ - if (l<=g_debuglevel) { \ - ZSTD_DEBUG_PRINT(__VA_ARGS__); \ - } } -# define DEBUGLOG(l, ...) { \ - if (l<=g_debuglevel) { \ - ZSTD_DEBUG_PRINT(__FILE__ ": " __VA_ARGS__); \ - ZSTD_DEBUG_PRINT(" \n"); \ - } } +# define RAWLOG(l, ...) \ + do { \ + if (l<=g_debuglevel) { \ + ZSTD_DEBUG_PRINT(__VA_ARGS__); \ + } \ + } while (0) + +#define STRINGIFY(x) #x +#define TOSTRING(x) STRINGIFY(x) +#define LINE_AS_STRING TOSTRING(__LINE__) + +# define DEBUGLOG(l, ...) \ + do { \ + if (l<=g_debuglevel) { \ + ZSTD_DEBUG_PRINT(__FILE__ ":" LINE_AS_STRING ": " __VA_ARGS__); \ + ZSTD_DEBUG_PRINT(" \n"); \ + } \ + } while (0) #else -# define RAWLOG(l, ...) {} /* disabled */ -# define DEBUGLOG(l, ...) {} /* disabled */ +# define RAWLOG(l, ...) do { } while (0) /* disabled */ +# define DEBUGLOG(l, ...) do { } while (0) /* disabled */ #endif diff --git a/lib/zstd/common/entropy_common.c b/lib/zstd/common/entropy_common.c index fef67056f052..6cdd82233fb5 100644 --- a/lib/zstd/common/entropy_common.c +++ b/lib/zstd/common/entropy_common.c @@ -1,6 +1,7 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* ****************************************************************** * Common functions of New Generation Entropy library - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -19,8 +20,8 @@ #include "error_private.h" /* ERR_*, ERROR */ #define FSE_STATIC_LINKING_ONLY /* FSE_MIN_TABLELOG */ #include "fse.h" -#define HUF_STATIC_LINKING_ONLY /* HUF_TABLELOG_ABSOLUTEMAX */ #include "huf.h" +#include "bits.h" /* ZSDT_highbit32, ZSTD_countTrailingZeros32 */ /*=== Version ===*/ @@ -38,23 +39,6 @@ const char* HUF_getErrorName(size_t code) { return ERR_getErrorName(code); } /*-************************************************************** * FSE NCount encoding-decoding ****************************************************************/ -static U32 FSE_ctz(U32 val) -{ - assert(val != 0); - { -# if (__GNUC__ >= 3) /* GCC Intrinsic */ - return __builtin_ctz(val); -# else /* Software version */ - U32 count = 0; - while ((val & 1) == 0) { - val >>= 1; - ++count; - } - return count; -# endif - } -} - FORCE_INLINE_TEMPLATE size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, const void* headerBuffer, size_t hbSize) @@ -102,7 +86,7 @@ size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigne * repeat. * Avoid UB by setting the high bit to 1. */ - int repeats = FSE_ctz(~bitStream | 0x80000000) >> 1; + int repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1; while (repeats >= 12) { charnum += 3 * 12; if (LIKELY(ip <= iend-7)) { @@ -113,7 +97,7 @@ size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigne ip = iend - 4; } bitStream = MEM_readLE32(ip) >> bitCount; - repeats = FSE_ctz(~bitStream | 0x80000000) >> 1; + repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1; } charnum += 3 * repeats; bitStream >>= 2 * repeats; @@ -178,7 +162,7 @@ size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigne * know that threshold > 1. */ if (remaining <= 1) break; - nbBits = BIT_highbit32(remaining) + 1; + nbBits = ZSTD_highbit32(remaining) + 1; threshold = 1 << (nbBits - 1); } if (charnum >= maxSV1) break; @@ -253,7 +237,7 @@ size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats, const void* src, size_t srcSize) { U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32]; - return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* bmi2 */ 0); + return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* flags */ 0); } FORCE_INLINE_TEMPLATE size_t @@ -301,14 +285,14 @@ HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats, if (weightTotal == 0) return ERROR(corruption_detected); /* get last non-null symbol weight (implied, total must be 2^n) */ - { U32 const tableLog = BIT_highbit32(weightTotal) + 1; + { U32 const tableLog = ZSTD_highbit32(weightTotal) + 1; if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected); *tableLogPtr = tableLog; /* determine last weight */ { U32 const total = 1 << tableLog; U32 const rest = total - weightTotal; - U32 const verif = 1 << BIT_highbit32(rest); - U32 const lastWeight = BIT_highbit32(rest) + 1; + U32 const verif = 1 << ZSTD_highbit32(rest); + U32 const lastWeight = ZSTD_highbit32(rest) + 1; if (verif != rest) return ERROR(corruption_detected); /* last value must be a clean power of 2 */ huffWeight[oSize] = (BYTE)lastWeight; rankStats[lastWeight]++; @@ -345,13 +329,13 @@ size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats, U32* nbSymbolsPtr, U32* tableLogPtr, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, - int bmi2) + int flags) { #if DYNAMIC_BMI2 - if (bmi2) { + if (flags & HUF_flags_bmi2) { return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); } #endif - (void)bmi2; + (void)flags; return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); } diff --git a/lib/zstd/common/error_private.c b/lib/zstd/common/error_private.c index 6d1135f8c373..a4062d30d170 100644 --- a/lib/zstd/common/error_private.c +++ b/lib/zstd/common/error_private.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -27,9 +28,11 @@ const char* ERR_getErrorString(ERR_enum code) case PREFIX(version_unsupported): return "Version not supported"; case PREFIX(frameParameter_unsupported): return "Unsupported frame parameter"; case PREFIX(frameParameter_windowTooLarge): return "Frame requires too much memory for decoding"; - case PREFIX(corruption_detected): return "Corrupted block detected"; + case PREFIX(corruption_detected): return "Data corruption detected"; case PREFIX(checksum_wrong): return "Restored data doesn't match checksum"; + case PREFIX(literals_headerWrong): return "Header of Literals' block doesn't respect format specification"; case PREFIX(parameter_unsupported): return "Unsupported parameter"; + case PREFIX(parameter_combination_unsupported): return "Unsupported combination of parameters"; case PREFIX(parameter_outOfBound): return "Parameter is out of bound"; case PREFIX(init_missing): return "Context should be init first"; case PREFIX(memory_allocation): return "Allocation error : not enough memory"; @@ -38,17 +41,22 @@ const char* ERR_getErrorString(ERR_enum code) case PREFIX(tableLog_tooLarge): return "tableLog requires too much memory : unsupported"; case PREFIX(maxSymbolValue_tooLarge): return "Unsupported max Symbol Value : too large"; case PREFIX(maxSymbolValue_tooSmall): return "Specified maxSymbolValue is too small"; + case PREFIX(stabilityCondition_notRespected): return "pledged buffer stability condition is not respected"; case PREFIX(dictionary_corrupted): return "Dictionary is corrupted"; case PREFIX(dictionary_wrong): return "Dictionary mismatch"; case PREFIX(dictionaryCreation_failed): return "Cannot create Dictionary from provided samples"; case PREFIX(dstSize_tooSmall): return "Destination buffer is too small"; case PREFIX(srcSize_wrong): return "Src size is incorrect"; case PREFIX(dstBuffer_null): return "Operation on NULL destination buffer"; + case PREFIX(noForwardProgress_destFull): return "Operation made no progress over multiple calls, due to output buffer being full"; + case PREFIX(noForwardProgress_inputEmpty): return "Operation made no progress over multiple calls, due to input being empty"; /* following error codes are not stable and may be removed or changed in a future version */ case PREFIX(frameIndex_tooLarge): return "Frame index is too large"; case PREFIX(seekableIO): return "An I/O error occurred when reading/seeking"; case PREFIX(dstBuffer_wrong): return "Destination buffer is wrong"; case PREFIX(srcBuffer_wrong): return "Source buffer is wrong"; + case PREFIX(sequenceProducer_failed): return "Block-level external sequence producer returned an error code"; + case PREFIX(externalSequences_invalid): return "External sequences are not valid"; case PREFIX(maxCode): default: return notErrorCode; } diff --git a/lib/zstd/common/error_private.h b/lib/zstd/common/error_private.h index ca5101e542fa..0410ca415b54 100644 --- a/lib/zstd/common/error_private.h +++ b/lib/zstd/common/error_private.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -49,8 +50,13 @@ ERR_STATIC unsigned ERR_isError(size_t code) { return (code > ERROR(maxCode)); } ERR_STATIC ERR_enum ERR_getErrorCode(size_t code) { if (!ERR_isError(code)) return (ERR_enum)0; return (ERR_enum) (0-code); } /* check and forward error code */ -#define CHECK_V_F(e, f) size_t const e = f; if (ERR_isError(e)) return e -#define CHECK_F(f) { CHECK_V_F(_var_err__, f); } +#define CHECK_V_F(e, f) \ + size_t const e = f; \ + do { \ + if (ERR_isError(e)) \ + return e; \ + } while (0) +#define CHECK_F(f) do { CHECK_V_F(_var_err__, f); } while (0) /*-**************************************** @@ -84,10 +90,12 @@ void _force_has_format_string(const char *format, ...) { * We want to force this function invocation to be syntactically correct, but * we don't want to force runtime evaluation of its arguments. */ -#define _FORCE_HAS_FORMAT_STRING(...) \ - if (0) { \ - _force_has_format_string(__VA_ARGS__); \ - } +#define _FORCE_HAS_FORMAT_STRING(...) \ + do { \ + if (0) { \ + _force_has_format_string(__VA_ARGS__); \ + } \ + } while (0) #define ERR_QUOTE(str) #str @@ -98,48 +106,50 @@ void _force_has_format_string(const char *format, ...) { * In order to do that (particularly, printing the conditional that failed), * this can't just wrap RETURN_ERROR(). */ -#define RETURN_ERROR_IF(cond, err, ...) \ - if (cond) { \ - RAWLOG(3, "%s:%d: ERROR!: check %s failed, returning %s", \ - __FILE__, __LINE__, ERR_QUOTE(cond), ERR_QUOTE(ERROR(err))); \ - _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ - RAWLOG(3, ": " __VA_ARGS__); \ - RAWLOG(3, "\n"); \ - return ERROR(err); \ - } +#define RETURN_ERROR_IF(cond, err, ...) \ + do { \ + if (cond) { \ + RAWLOG(3, "%s:%d: ERROR!: check %s failed, returning %s", \ + __FILE__, __LINE__, ERR_QUOTE(cond), ERR_QUOTE(ERROR(err))); \ + _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ + RAWLOG(3, ": " __VA_ARGS__); \ + RAWLOG(3, "\n"); \ + return ERROR(err); \ + } \ + } while (0) /* * Unconditionally return the specified error. * * In debug modes, prints additional information. */ -#define RETURN_ERROR(err, ...) \ - do { \ - RAWLOG(3, "%s:%d: ERROR!: unconditional check failed, returning %s", \ - __FILE__, __LINE__, ERR_QUOTE(ERROR(err))); \ - _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ - RAWLOG(3, ": " __VA_ARGS__); \ - RAWLOG(3, "\n"); \ - return ERROR(err); \ - } while(0); +#define RETURN_ERROR(err, ...) \ + do { \ + RAWLOG(3, "%s:%d: ERROR!: unconditional check failed, returning %s", \ + __FILE__, __LINE__, ERR_QUOTE(ERROR(err))); \ + _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ + RAWLOG(3, ": " __VA_ARGS__); \ + RAWLOG(3, "\n"); \ + return ERROR(err); \ + } while(0) /* * If the provided expression evaluates to an error code, returns that error code. * * In debug modes, prints additional information. */ -#define FORWARD_IF_ERROR(err, ...) \ - do { \ - size_t const err_code = (err); \ - if (ERR_isError(err_code)) { \ - RAWLOG(3, "%s:%d: ERROR!: forwarding error in %s: %s", \ - __FILE__, __LINE__, ERR_QUOTE(err), ERR_getErrorName(err_code)); \ - _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ - RAWLOG(3, ": " __VA_ARGS__); \ - RAWLOG(3, "\n"); \ - return err_code; \ - } \ - } while(0); +#define FORWARD_IF_ERROR(err, ...) \ + do { \ + size_t const err_code = (err); \ + if (ERR_isError(err_code)) { \ + RAWLOG(3, "%s:%d: ERROR!: forwarding error in %s: %s", \ + __FILE__, __LINE__, ERR_QUOTE(err), ERR_getErrorName(err_code)); \ + _FORCE_HAS_FORMAT_STRING(__VA_ARGS__); \ + RAWLOG(3, ": " __VA_ARGS__); \ + RAWLOG(3, "\n"); \ + return err_code; \ + } \ + } while(0) #endif /* ERROR_H_MODULE */ diff --git a/lib/zstd/common/fse.h b/lib/zstd/common/fse.h index 4507043b2287..2185a578617d 100644 --- a/lib/zstd/common/fse.h +++ b/lib/zstd/common/fse.h @@ -1,7 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* ****************************************************************** * FSE : Finite State Entropy codec * Public Prototypes declaration - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -50,34 +51,6 @@ FSE_PUBLIC_API unsigned FSE_versionNumber(void); /*< library version number; to be used when checking dll version */ -/*-**************************************** -* FSE simple functions -******************************************/ -/*! FSE_compress() : - Compress content of buffer 'src', of size 'srcSize', into destination buffer 'dst'. - 'dst' buffer must be already allocated. Compression runs faster is dstCapacity >= FSE_compressBound(srcSize). - @return : size of compressed data (<= dstCapacity). - Special values : if return == 0, srcData is not compressible => Nothing is stored within dst !!! - if return == 1, srcData is a single byte symbol * srcSize times. Use RLE compression instead. - if FSE_isError(return), compression failed (more details using FSE_getErrorName()) -*/ -FSE_PUBLIC_API size_t FSE_compress(void* dst, size_t dstCapacity, - const void* src, size_t srcSize); - -/*! FSE_decompress(): - Decompress FSE data from buffer 'cSrc', of size 'cSrcSize', - into already allocated destination buffer 'dst', of size 'dstCapacity'. - @return : size of regenerated data (<= maxDstSize), - or an error code, which can be tested using FSE_isError() . - - ** Important ** : FSE_decompress() does not decompress non-compressible nor RLE data !!! - Why ? : making this distinction requires a header. - Header management is intentionally delegated to the user layer, which can better manage special cases. -*/ -FSE_PUBLIC_API size_t FSE_decompress(void* dst, size_t dstCapacity, - const void* cSrc, size_t cSrcSize); - - /*-***************************************** * Tool functions ******************************************/ @@ -88,20 +61,6 @@ FSE_PUBLIC_API unsigned FSE_isError(size_t code); /* tells if a return FSE_PUBLIC_API const char* FSE_getErrorName(size_t code); /* provides error code string (useful for debugging) */ -/*-***************************************** -* FSE advanced functions -******************************************/ -/*! FSE_compress2() : - Same as FSE_compress(), but allows the selection of 'maxSymbolValue' and 'tableLog' - Both parameters can be defined as '0' to mean : use default value - @return : size of compressed data - Special values : if return == 0, srcData is not compressible => Nothing is stored within cSrc !!! - if return == 1, srcData is a single byte symbol * srcSize times. Use RLE compression. - if FSE_isError(return), it's an error code. -*/ -FSE_PUBLIC_API size_t FSE_compress2 (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog); - - /*-***************************************** * FSE detailed API ******************************************/ @@ -161,8 +120,6 @@ FSE_PUBLIC_API size_t FSE_writeNCount (void* buffer, size_t bufferSize, /*! Constructor and Destructor of FSE_CTable. Note that FSE_CTable size depends on 'tableLog' and 'maxSymbolValue' */ typedef unsigned FSE_CTable; /* don't allocate that. It's only meant to be more restrictive than void* */ -FSE_PUBLIC_API FSE_CTable* FSE_createCTable (unsigned maxSymbolValue, unsigned tableLog); -FSE_PUBLIC_API void FSE_freeCTable (FSE_CTable* ct); /*! FSE_buildCTable(): Builds `ct`, which must be already allocated, using FSE_createCTable(). @@ -238,23 +195,7 @@ FSE_PUBLIC_API size_t FSE_readNCount_bmi2(short* normalizedCounter, unsigned* maxSymbolValuePtr, unsigned* tableLogPtr, const void* rBuffer, size_t rBuffSize, int bmi2); -/*! Constructor and Destructor of FSE_DTable. - Note that its size depends on 'tableLog' */ typedef unsigned FSE_DTable; /* don't allocate that. It's just a way to be more restrictive than void* */ -FSE_PUBLIC_API FSE_DTable* FSE_createDTable(unsigned tableLog); -FSE_PUBLIC_API void FSE_freeDTable(FSE_DTable* dt); - -/*! FSE_buildDTable(): - Builds 'dt', which must be already allocated, using FSE_createDTable(). - return : 0, or an errorCode, which can be tested using FSE_isError() */ -FSE_PUBLIC_API size_t FSE_buildDTable (FSE_DTable* dt, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog); - -/*! FSE_decompress_usingDTable(): - Decompress compressed source `cSrc` of size `cSrcSize` using `dt` - into `dst` which must be already allocated. - @return : size of regenerated data (necessarily <= `dstCapacity`), - or an errorCode, which can be tested using FSE_isError() */ -FSE_PUBLIC_API size_t FSE_decompress_usingDTable(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, const FSE_DTable* dt); /*! Tutorial : @@ -286,6 +227,7 @@ If there is an error, the function will return an error code, which can be teste #endif /* FSE_H */ + #if !defined(FSE_H_FSE_STATIC_LINKING_ONLY) #define FSE_H_FSE_STATIC_LINKING_ONLY @@ -317,16 +259,6 @@ If there is an error, the function will return an error code, which can be teste unsigned FSE_optimalTableLog_internal(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, unsigned minus); /*< same as FSE_optimalTableLog(), which used `minus==2` */ -/* FSE_compress_wksp() : - * Same as FSE_compress2(), but using an externally allocated scratch buffer (`workSpace`). - * FSE_COMPRESS_WKSP_SIZE_U32() provides the minimum size required for `workSpace` as a table of FSE_CTable. - */ -#define FSE_COMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) ( FSE_CTABLE_SIZE_U32(maxTableLog, maxSymbolValue) + ((maxTableLog > 12) ? (1 << (maxTableLog - 2)) : 1024) ) -size_t FSE_compress_wksp (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize); - -size_t FSE_buildCTable_raw (FSE_CTable* ct, unsigned nbBits); -/*< build a fake FSE_CTable, designed for a flat distribution, where each symbol uses nbBits */ - size_t FSE_buildCTable_rle (FSE_CTable* ct, unsigned char symbolValue); /*< build a fake FSE_CTable, designed to compress always the same symbolValue */ @@ -344,19 +276,11 @@ size_t FSE_buildCTable_wksp(FSE_CTable* ct, const short* normalizedCounter, unsi FSE_PUBLIC_API size_t FSE_buildDTable_wksp(FSE_DTable* dt, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize); /*< Same as FSE_buildDTable(), using an externally allocated `workspace` produced with `FSE_BUILD_DTABLE_WKSP_SIZE_U32(maxSymbolValue)` */ -size_t FSE_buildDTable_raw (FSE_DTable* dt, unsigned nbBits); -/*< build a fake FSE_DTable, designed to read a flat distribution where each symbol uses nbBits */ - -size_t FSE_buildDTable_rle (FSE_DTable* dt, unsigned char symbolValue); -/*< build a fake FSE_DTable, designed to always generate the same symbolValue */ - -#define FSE_DECOMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) (FSE_DTABLE_SIZE_U32(maxTableLog) + FSE_BUILD_DTABLE_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) + (FSE_MAX_SYMBOL_VALUE + 1) / 2 + 1) +#define FSE_DECOMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) (FSE_DTABLE_SIZE_U32(maxTableLog) + 1 + FSE_BUILD_DTABLE_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) + (FSE_MAX_SYMBOL_VALUE + 1) / 2 + 1) #define FSE_DECOMPRESS_WKSP_SIZE(maxTableLog, maxSymbolValue) (FSE_DECOMPRESS_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) * sizeof(unsigned)) -size_t FSE_decompress_wksp(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, unsigned maxLog, void* workSpace, size_t wkspSize); -/*< same as FSE_decompress(), using an externally allocated `workSpace` produced with `FSE_DECOMPRESS_WKSP_SIZE_U32(maxLog, maxSymbolValue)` */ - size_t FSE_decompress_wksp_bmi2(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, unsigned maxLog, void* workSpace, size_t wkspSize, int bmi2); -/*< Same as FSE_decompress_wksp() but with dynamic BMI2 support. Pass 1 if your CPU supports BMI2 or 0 if it doesn't. */ +/*< same as FSE_decompress(), using an externally allocated `workSpace` produced with `FSE_DECOMPRESS_WKSP_SIZE_U32(maxLog, maxSymbolValue)`. + * Set bmi2 to 1 if your CPU supports BMI2 or 0 if it doesn't */ typedef enum { FSE_repeat_none, /*< Cannot use the previous table */ @@ -539,20 +463,20 @@ MEM_STATIC void FSE_encodeSymbol(BIT_CStream_t* bitC, FSE_CState_t* statePtr, un FSE_symbolCompressionTransform const symbolTT = ((const FSE_symbolCompressionTransform*)(statePtr->symbolTT))[symbol]; const U16* const stateTable = (const U16*)(statePtr->stateTable); U32 const nbBitsOut = (U32)((statePtr->value + symbolTT.deltaNbBits) >> 16); - BIT_addBits(bitC, statePtr->value, nbBitsOut); + BIT_addBits(bitC, (size_t)statePtr->value, nbBitsOut); statePtr->value = stateTable[ (statePtr->value >> nbBitsOut) + symbolTT.deltaFindState]; } MEM_STATIC void FSE_flushCState(BIT_CStream_t* bitC, const FSE_CState_t* statePtr) { - BIT_addBits(bitC, statePtr->value, statePtr->stateLog); + BIT_addBits(bitC, (size_t)statePtr->value, statePtr->stateLog); BIT_flushBits(bitC); } /* FSE_getMaxNbBits() : * Approximate maximum cost of a symbol, in bits. - * Fractional get rounded up (i.e : a symbol with a normalized frequency of 3 gives the same result as a frequency of 2) + * Fractional get rounded up (i.e. a symbol with a normalized frequency of 3 gives the same result as a frequency of 2) * note 1 : assume symbolValue is valid (<= maxSymbolValue) * note 2 : if freq[symbolValue]==0, @return a fake cost of tableLog+1 bits */ MEM_STATIC U32 FSE_getMaxNbBits(const void* symbolTTPtr, U32 symbolValue) diff --git a/lib/zstd/common/fse_decompress.c b/lib/zstd/common/fse_decompress.c index 8dcb8ca39767..3a17e84f27bf 100644 --- a/lib/zstd/common/fse_decompress.c +++ b/lib/zstd/common/fse_decompress.c @@ -1,6 +1,7 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* ****************************************************************** * FSE : Finite State Entropy decoder - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - FSE source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -22,8 +23,8 @@ #define FSE_STATIC_LINKING_ONLY #include "fse.h" #include "error_private.h" -#define ZSTD_DEPS_NEED_MALLOC -#include "zstd_deps.h" +#include "zstd_deps.h" /* ZSTD_memcpy */ +#include "bits.h" /* ZSTD_highbit32 */ /* ************************************************************** @@ -55,19 +56,6 @@ #define FSE_FUNCTION_NAME(X,Y) FSE_CAT(X,Y) #define FSE_TYPE_NAME(X,Y) FSE_CAT(X,Y) - -/* Function templates */ -FSE_DTable* FSE_createDTable (unsigned tableLog) -{ - if (tableLog > FSE_TABLELOG_ABSOLUTE_MAX) tableLog = FSE_TABLELOG_ABSOLUTE_MAX; - return (FSE_DTable*)ZSTD_malloc( FSE_DTABLE_SIZE_U32(tableLog) * sizeof (U32) ); -} - -void FSE_freeDTable (FSE_DTable* dt) -{ - ZSTD_free(dt); -} - static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize) { void* const tdPtr = dt+1; /* because *dt is unsigned, 32-bits aligned on 32-bits */ @@ -96,7 +84,7 @@ static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCo symbolNext[s] = 1; } else { if (normalizedCounter[s] >= largeLimit) DTableH.fastMode=0; - symbolNext[s] = normalizedCounter[s]; + symbolNext[s] = (U16)normalizedCounter[s]; } } } ZSTD_memcpy(dt, &DTableH, sizeof(DTableH)); } @@ -111,8 +99,7 @@ static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCo * all symbols have counts <= 8. We ensure we have 8 bytes at the end of * our buffer to handle the over-write. */ - { - U64 const add = 0x0101010101010101ull; + { U64 const add = 0x0101010101010101ull; size_t pos = 0; U64 sv = 0; U32 s; @@ -123,14 +110,13 @@ static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCo for (i = 8; i < n; i += 8) { MEM_write64(spread + pos + i, sv); } - pos += n; - } - } + pos += (size_t)n; + } } /* Now we spread those positions across the table. - * The benefit of doing it in two stages is that we avoid the the + * The benefit of doing it in two stages is that we avoid the * variable size inner loop, which caused lots of branch misses. * Now we can run through all the positions without any branch misses. - * We unroll the loop twice, since that is what emperically worked best. + * We unroll the loop twice, since that is what empirically worked best. */ { size_t position = 0; @@ -166,7 +152,7 @@ static size_t FSE_buildDTable_internal(FSE_DTable* dt, const short* normalizedCo for (u=0; utableLog = 0; - DTableH->fastMode = 0; - - cell->newState = 0; - cell->symbol = symbolValue; - cell->nbBits = 0; - - return 0; -} - - -size_t FSE_buildDTable_raw (FSE_DTable* dt, unsigned nbBits) -{ - void* ptr = dt; - FSE_DTableHeader* const DTableH = (FSE_DTableHeader*)ptr; - void* dPtr = dt + 1; - FSE_decode_t* const dinfo = (FSE_decode_t*)dPtr; - const unsigned tableSize = 1 << nbBits; - const unsigned tableMask = tableSize - 1; - const unsigned maxSV1 = tableMask+1; - unsigned s; - - /* Sanity checks */ - if (nbBits < 1) return ERROR(GENERIC); /* min size */ - - /* Build Decoding Table */ - DTableH->tableLog = (U16)nbBits; - DTableH->fastMode = 1; - for (s=0; sfastMode; - - /* select fast mode (static) */ - if (fastMode) return FSE_decompress_usingDTable_generic(dst, originalSize, cSrc, cSrcSize, dt, 1); - return FSE_decompress_usingDTable_generic(dst, originalSize, cSrc, cSrcSize, dt, 0); -} - - -size_t FSE_decompress_wksp(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, unsigned maxLog, void* workSpace, size_t wkspSize) -{ - return FSE_decompress_wksp_bmi2(dst, dstCapacity, cSrc, cSrcSize, maxLog, workSpace, wkspSize, /* bmi2 */ 0); + assert(op >= ostart); + return (size_t)(op-ostart); } typedef struct { short ncount[FSE_MAX_SYMBOL_VALUE + 1]; - FSE_DTable dtable[]; /* Dynamically sized */ } FSE_DecompressWksp; @@ -327,13 +250,18 @@ FORCE_INLINE_TEMPLATE size_t FSE_decompress_wksp_body( unsigned tableLog; unsigned maxSymbolValue = FSE_MAX_SYMBOL_VALUE; FSE_DecompressWksp* const wksp = (FSE_DecompressWksp*)workSpace; + size_t const dtablePos = sizeof(FSE_DecompressWksp) / sizeof(FSE_DTable); + FSE_DTable* const dtable = (FSE_DTable*)workSpace + dtablePos; - DEBUG_STATIC_ASSERT((FSE_MAX_SYMBOL_VALUE + 1) % 2 == 0); + FSE_STATIC_ASSERT((FSE_MAX_SYMBOL_VALUE + 1) % 2 == 0); if (wkspSize < sizeof(*wksp)) return ERROR(GENERIC); + /* correct offset to dtable depends on this property */ + FSE_STATIC_ASSERT(sizeof(FSE_DecompressWksp) % sizeof(FSE_DTable) == 0); + /* normal FSE decoding mode */ - { - size_t const NCountLength = FSE_readNCount_bmi2(wksp->ncount, &maxSymbolValue, &tableLog, istart, cSrcSize, bmi2); + { size_t const NCountLength = + FSE_readNCount_bmi2(wksp->ncount, &maxSymbolValue, &tableLog, istart, cSrcSize, bmi2); if (FSE_isError(NCountLength)) return NCountLength; if (tableLog > maxLog) return ERROR(tableLog_tooLarge); assert(NCountLength <= cSrcSize); @@ -342,19 +270,20 @@ FORCE_INLINE_TEMPLATE size_t FSE_decompress_wksp_body( } if (FSE_DECOMPRESS_WKSP_SIZE(tableLog, maxSymbolValue) > wkspSize) return ERROR(tableLog_tooLarge); - workSpace = wksp->dtable + FSE_DTABLE_SIZE_U32(tableLog); + assert(sizeof(*wksp) + FSE_DTABLE_SIZE(tableLog) <= wkspSize); + workSpace = (BYTE*)workSpace + sizeof(*wksp) + FSE_DTABLE_SIZE(tableLog); wkspSize -= sizeof(*wksp) + FSE_DTABLE_SIZE(tableLog); - CHECK_F( FSE_buildDTable_internal(wksp->dtable, wksp->ncount, maxSymbolValue, tableLog, workSpace, wkspSize) ); + CHECK_F( FSE_buildDTable_internal(dtable, wksp->ncount, maxSymbolValue, tableLog, workSpace, wkspSize) ); { - const void* ptr = wksp->dtable; + const void* ptr = dtable; const FSE_DTableHeader* DTableH = (const FSE_DTableHeader*)ptr; const U32 fastMode = DTableH->fastMode; /* select fast mode (static) */ - if (fastMode) return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, wksp->dtable, 1); - return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, wksp->dtable, 0); + if (fastMode) return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, dtable, 1); + return FSE_decompress_usingDTable_generic(dst, dstCapacity, ip, cSrcSize, dtable, 0); } } @@ -382,9 +311,4 @@ size_t FSE_decompress_wksp_bmi2(void* dst, size_t dstCapacity, const void* cSrc, return FSE_decompress_wksp_body_default(dst, dstCapacity, cSrc, cSrcSize, maxLog, workSpace, wkspSize); } - -typedef FSE_DTable DTable_max_t[FSE_DTABLE_SIZE_U32(FSE_MAX_TABLELOG)]; - - - #endif /* FSE_COMMONDEFS_ONLY */ diff --git a/lib/zstd/common/huf.h b/lib/zstd/common/huf.h index 5042ff870308..57462466e188 100644 --- a/lib/zstd/common/huf.h +++ b/lib/zstd/common/huf.h @@ -1,7 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* ****************************************************************** * huff0 huffman codec, * part of Finite State Entropy library - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -18,99 +19,22 @@ /* *** Dependencies *** */ #include "zstd_deps.h" /* size_t */ - - -/* *** library symbols visibility *** */ -/* Note : when linking with -fvisibility=hidden on gcc, or by default on Visual, - * HUF symbols remain "private" (internal symbols for library only). - * Set macro FSE_DLL_EXPORT to 1 if you want HUF symbols visible on DLL interface */ -#if defined(FSE_DLL_EXPORT) && (FSE_DLL_EXPORT==1) && defined(__GNUC__) && (__GNUC__ >= 4) -# define HUF_PUBLIC_API __attribute__ ((visibility ("default"))) -#elif defined(FSE_DLL_EXPORT) && (FSE_DLL_EXPORT==1) /* Visual expected */ -# define HUF_PUBLIC_API __declspec(dllexport) -#elif defined(FSE_DLL_IMPORT) && (FSE_DLL_IMPORT==1) -# define HUF_PUBLIC_API __declspec(dllimport) /* not required, just to generate faster code (saves a function pointer load from IAT and an indirect jump) */ -#else -# define HUF_PUBLIC_API -#endif - - -/* ========================== */ -/* *** simple functions *** */ -/* ========================== */ - -/* HUF_compress() : - * Compress content from buffer 'src', of size 'srcSize', into buffer 'dst'. - * 'dst' buffer must be already allocated. - * Compression runs faster if `dstCapacity` >= HUF_compressBound(srcSize). - * `srcSize` must be <= `HUF_BLOCKSIZE_MAX` == 128 KB. - * @return : size of compressed data (<= `dstCapacity`). - * Special values : if return == 0, srcData is not compressible => Nothing is stored within dst !!! - * if HUF_isError(return), compression failed (more details using HUF_getErrorName()) - */ -HUF_PUBLIC_API size_t HUF_compress(void* dst, size_t dstCapacity, - const void* src, size_t srcSize); - -/* HUF_decompress() : - * Decompress HUF data from buffer 'cSrc', of size 'cSrcSize', - * into already allocated buffer 'dst', of minimum size 'dstSize'. - * `originalSize` : **must** be the ***exact*** size of original (uncompressed) data. - * Note : in contrast with FSE, HUF_decompress can regenerate - * RLE (cSrcSize==1) and uncompressed (cSrcSize==dstSize) data, - * because it knows size to regenerate (originalSize). - * @return : size of regenerated data (== originalSize), - * or an error code, which can be tested using HUF_isError() - */ -HUF_PUBLIC_API size_t HUF_decompress(void* dst, size_t originalSize, - const void* cSrc, size_t cSrcSize); +#include "mem.h" /* U32 */ +#define FSE_STATIC_LINKING_ONLY +#include "fse.h" /* *** Tool functions *** */ -#define HUF_BLOCKSIZE_MAX (128 * 1024) /*< maximum input size for a single block compressed with HUF_compress */ -HUF_PUBLIC_API size_t HUF_compressBound(size_t size); /*< maximum compressed size (worst case) */ +#define HUF_BLOCKSIZE_MAX (128 * 1024) /*< maximum input size for a single block compressed with HUF_compress */ +size_t HUF_compressBound(size_t size); /*< maximum compressed size (worst case) */ /* Error Management */ -HUF_PUBLIC_API unsigned HUF_isError(size_t code); /*< tells if a return value is an error code */ -HUF_PUBLIC_API const char* HUF_getErrorName(size_t code); /*< provides error code string (useful for debugging) */ - +unsigned HUF_isError(size_t code); /*< tells if a return value is an error code */ +const char* HUF_getErrorName(size_t code); /*< provides error code string (useful for debugging) */ -/* *** Advanced function *** */ -/* HUF_compress2() : - * Same as HUF_compress(), but offers control over `maxSymbolValue` and `tableLog`. - * `maxSymbolValue` must be <= HUF_SYMBOLVALUE_MAX . - * `tableLog` must be `<= HUF_TABLELOG_MAX` . */ -HUF_PUBLIC_API size_t HUF_compress2 (void* dst, size_t dstCapacity, - const void* src, size_t srcSize, - unsigned maxSymbolValue, unsigned tableLog); - -/* HUF_compress4X_wksp() : - * Same as HUF_compress2(), but uses externally allocated `workSpace`. - * `workspace` must be at least as large as HUF_WORKSPACE_SIZE */ #define HUF_WORKSPACE_SIZE ((8 << 10) + 512 /* sorting scratch space */) #define HUF_WORKSPACE_SIZE_U64 (HUF_WORKSPACE_SIZE / sizeof(U64)) -HUF_PUBLIC_API size_t HUF_compress4X_wksp (void* dst, size_t dstCapacity, - const void* src, size_t srcSize, - unsigned maxSymbolValue, unsigned tableLog, - void* workSpace, size_t wkspSize); - -#endif /* HUF_H_298734234 */ - -/* ****************************************************************** - * WARNING !! - * The following section contains advanced and experimental definitions - * which shall never be used in the context of a dynamic library, - * because they are not guaranteed to remain stable in the future. - * Only consider them in association with static linking. - * *****************************************************************/ -#if !defined(HUF_H_HUF_STATIC_LINKING_ONLY) -#define HUF_H_HUF_STATIC_LINKING_ONLY - -/* *** Dependencies *** */ -#include "mem.h" /* U32 */ -#define FSE_STATIC_LINKING_ONLY -#include "fse.h" - /* *** Constants *** */ #define HUF_TABLELOG_MAX 12 /* max runtime value of tableLog (due to static allocation); can be modified up to HUF_TABLELOG_ABSOLUTEMAX */ @@ -151,25 +75,49 @@ typedef U32 HUF_DTable; /* **************************************** * Advanced decompression functions ******************************************/ -size_t HUF_decompress4X1 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< single-symbol decoder */ -#ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_decompress4X2 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< double-symbols decoder */ -#endif -size_t HUF_decompress4X_DCtx (HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< decodes RLE and uncompressed */ -size_t HUF_decompress4X_hufOnly(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< considers RLE and uncompressed as errors */ -size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< considers RLE and uncompressed as errors */ -size_t HUF_decompress4X1_DCtx(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< single-symbol decoder */ -size_t HUF_decompress4X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< single-symbol decoder */ -#ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_decompress4X2_DCtx(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< double-symbols decoder */ -size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< double-symbols decoder */ -#endif +/* + * Huffman flags bitset. + * For all flags, 0 is the default value. + */ +typedef enum { + /* + * If compiled with DYNAMIC_BMI2: Set flag only if the CPU supports BMI2 at runtime. + * Otherwise: Ignored. + */ + HUF_flags_bmi2 = (1 << 0), + /* + * If set: Test possible table depths to find the one that produces the smallest header + encoded size. + * If unset: Use heuristic to find the table depth. + */ + HUF_flags_optimalDepth = (1 << 1), + /* + * If set: If the previous table can encode the input, always reuse the previous table. + * If unset: If the previous table can encode the input, reuse the previous table if it results in a smaller output. + */ + HUF_flags_preferRepeat = (1 << 2), + /* + * If set: Sample the input and check if the sample is uncompressible, if it is then don't attempt to compress. + * If unset: Always histogram the entire input. + */ + HUF_flags_suspectUncompressible = (1 << 3), + /* + * If set: Don't use assembly implementations + * If unset: Allow using assembly implementations + */ + HUF_flags_disableAsm = (1 << 4), + /* + * If set: Don't use the fast decoding loop, always use the fallback decoding loop. + * If unset: Use the fast decoding loop when possible. + */ + HUF_flags_disableFast = (1 << 5) +} HUF_flags_e; /* **************************************** * HUF detailed API * ****************************************/ +#define HUF_OPTIMAL_DEPTH_THRESHOLD ZSTD_btultra /*! HUF_compress() does the following: * 1. count symbol occurrence from source[] into table count[] using FSE_count() (exposed within "fse.h") @@ -182,12 +130,12 @@ size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, * For example, it's possible to compress several blocks using the same 'CTable', * or to save and regenerate 'CTable' using external methods. */ -unsigned HUF_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue); -size_t HUF_buildCTable (HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue, unsigned maxNbBits); /* @return : maxNbBits; CTable and count can overlap. In which case, CTable will overwrite count content */ -size_t HUF_writeCTable (void* dst, size_t maxDstSize, const HUF_CElt* CTable, unsigned maxSymbolValue, unsigned huffLog); +unsigned HUF_minTableLog(unsigned symbolCardinality); +unsigned HUF_cardinality(const unsigned* count, unsigned maxSymbolValue); +unsigned HUF_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, void* workSpace, + size_t wkspSize, HUF_CElt* table, const unsigned* count, int flags); /* table is used as scratch space for building and testing tables, not a return value */ size_t HUF_writeCTable_wksp(void* dst, size_t maxDstSize, const HUF_CElt* CTable, unsigned maxSymbolValue, unsigned huffLog, void* workspace, size_t workspaceSize); -size_t HUF_compress4X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable); -size_t HUF_compress4X_usingCTable_bmi2(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int bmi2); +size_t HUF_compress4X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags); size_t HUF_estimateCompressedSize(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue); int HUF_validateCTable(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue); @@ -196,6 +144,7 @@ typedef enum { HUF_repeat_check, /*< Can use the previous table but it must be checked. Note : The previous table must have been constructed by HUF_compress{1, 4}X_repeat */ HUF_repeat_valid /*< Can use the previous table and it is assumed to be valid */ } HUF_repeat; + /* HUF_compress4X_repeat() : * Same as HUF_compress4X_wksp(), but considers using hufTable if *repeat != HUF_repeat_none. * If it uses hufTable it does not modify hufTable or repeat. @@ -206,13 +155,13 @@ size_t HUF_compress4X_repeat(void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize, /*< `workSpace` must be aligned on 4-bytes boundaries, `wkspSize` must be >= HUF_WORKSPACE_SIZE */ - HUF_CElt* hufTable, HUF_repeat* repeat, int preferRepeat, int bmi2, unsigned suspectUncompressible); + HUF_CElt* hufTable, HUF_repeat* repeat, int flags); /* HUF_buildCTable_wksp() : * Same as HUF_buildCTable(), but using externally allocated scratch buffer. * `workSpace` must be aligned on 4-bytes boundaries, and its size must be >= HUF_CTABLE_WORKSPACE_SIZE. */ -#define HUF_CTABLE_WORKSPACE_SIZE_U32 (2*HUF_SYMBOLVALUE_MAX +1 +1) +#define HUF_CTABLE_WORKSPACE_SIZE_U32 ((4 * (HUF_SYMBOLVALUE_MAX + 1)) + 192) #define HUF_CTABLE_WORKSPACE_SIZE (HUF_CTABLE_WORKSPACE_SIZE_U32 * sizeof(unsigned)) size_t HUF_buildCTable_wksp (HUF_CElt* tree, const unsigned* count, U32 maxSymbolValue, U32 maxNbBits, @@ -238,7 +187,7 @@ size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats, U32* nbSymbolsPtr, U32* tableLogPtr, const void* src, size_t srcSize, void* workspace, size_t wkspSize, - int bmi2); + int flags); /* HUF_readCTable() : * Loading a CTable saved with HUF_writeCTable() */ @@ -246,9 +195,22 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void /* HUF_getNbBitsFromCTable() : * Read nbBits from CTable symbolTable, for symbol `symbolValue` presumed <= HUF_SYMBOLVALUE_MAX - * Note 1 : is not inlined, as HUF_CElt definition is private */ + * Note 1 : If symbolValue > HUF_readCTableHeader(symbolTable).maxSymbolValue, returns 0 + * Note 2 : is not inlined, as HUF_CElt definition is private + */ U32 HUF_getNbBitsFromCTable(const HUF_CElt* symbolTable, U32 symbolValue); +typedef struct { + BYTE tableLog; + BYTE maxSymbolValue; + BYTE unused[sizeof(size_t) - 2]; +} HUF_CTableHeader; + +/* HUF_readCTableHeader() : + * @returns The header from the CTable specifying the tableLog and the maxSymbolValue. + */ +HUF_CTableHeader HUF_readCTableHeader(HUF_CElt const* ctable); + /* * HUF_decompress() does the following: * 1. select the decompression algorithm (X1, X2) based on pre-computed heuristics @@ -276,32 +238,12 @@ U32 HUF_selectDecoder (size_t dstSize, size_t cSrcSize); #define HUF_DECOMPRESS_WORKSPACE_SIZE ((2 << 10) + (1 << 9)) #define HUF_DECOMPRESS_WORKSPACE_SIZE_U32 (HUF_DECOMPRESS_WORKSPACE_SIZE / sizeof(U32)) -#ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_readDTableX1 (HUF_DTable* DTable, const void* src, size_t srcSize); -size_t HUF_readDTableX1_wksp (HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize); -#endif -#ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_readDTableX2 (HUF_DTable* DTable, const void* src, size_t srcSize); -size_t HUF_readDTableX2_wksp (HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize); -#endif - -size_t HUF_decompress4X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); -#ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_decompress4X1_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); -#endif -#ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_decompress4X2_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); -#endif - /* ====================== */ /* single stream variants */ /* ====================== */ -size_t HUF_compress1X (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog); -size_t HUF_compress1X_wksp (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize); /*< `workSpace` must be a table of at least HUF_WORKSPACE_SIZE_U64 U64 */ -size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable); -size_t HUF_compress1X_usingCTable_bmi2(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int bmi2); +size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags); /* HUF_compress1X_repeat() : * Same as HUF_compress1X_wksp(), but considers using hufTable if *repeat != HUF_repeat_none. * If it uses hufTable it does not modify hufTable or repeat. @@ -312,47 +254,28 @@ size_t HUF_compress1X_repeat(void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize, /*< `workSpace` must be aligned on 4-bytes boundaries, `wkspSize` must be >= HUF_WORKSPACE_SIZE */ - HUF_CElt* hufTable, HUF_repeat* repeat, int preferRepeat, int bmi2, unsigned suspectUncompressible); + HUF_CElt* hufTable, HUF_repeat* repeat, int flags); -size_t HUF_decompress1X1 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /* single-symbol decoder */ -#ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_decompress1X2 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /* double-symbol decoder */ -#endif - -size_t HUF_decompress1X_DCtx (HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); -size_t HUF_decompress1X_DCtx_wksp (HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); -#ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_decompress1X1_DCtx(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< single-symbol decoder */ -size_t HUF_decompress1X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< single-symbol decoder */ -#endif +size_t HUF_decompress1X_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); #ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_decompress1X2_DCtx(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize); /*< double-symbols decoder */ -size_t HUF_decompress1X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize); /*< double-symbols decoder */ -#endif - -size_t HUF_decompress1X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); /*< automatic selection of sing or double symbol decoder, based on DTable */ -#ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_decompress1X1_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); -#endif -#ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_decompress1X2_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable); +size_t HUF_decompress1X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); /*< double-symbols decoder */ #endif /* BMI2 variants. * If the CPU has BMI2 support, pass bmi2=1, otherwise pass bmi2=0. */ -size_t HUF_decompress1X_usingDTable_bmi2(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int bmi2); +size_t HUF_decompress1X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags); #ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_decompress1X1_DCtx_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int bmi2); +size_t HUF_decompress1X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); #endif -size_t HUF_decompress4X_usingDTable_bmi2(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int bmi2); -size_t HUF_decompress4X_hufOnly_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int bmi2); +size_t HUF_decompress4X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags); +size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); #ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int bmi2); +size_t HUF_readDTableX1_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int flags); #endif #ifndef HUF_FORCE_DECOMPRESS_X1 -size_t HUF_readDTableX2_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int bmi2); +size_t HUF_readDTableX2_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int flags); #endif -#endif /* HUF_STATIC_LINKING_ONLY */ +#endif /* HUF_H_298734234 */ diff --git a/lib/zstd/common/mem.h b/lib/zstd/common/mem.h index c22a2e69bf46..d9bd752fe17b 100644 --- a/lib/zstd/common/mem.h +++ b/lib/zstd/common/mem.h @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -24,6 +24,7 @@ /*-**************************************** * Compiler specifics ******************************************/ +#undef MEM_STATIC /* may be already defined from common/compiler.h */ #define MEM_STATIC static inline /*-************************************************************** diff --git a/lib/zstd/common/portability_macros.h b/lib/zstd/common/portability_macros.h index 0e3b2c0a527d..f08638cced6c 100644 --- a/lib/zstd/common/portability_macros.h +++ b/lib/zstd/common/portability_macros.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -12,7 +13,7 @@ #define ZSTD_PORTABILITY_MACROS_H /* - * This header file contains macro defintions to support portability. + * This header file contains macro definitions to support portability. * This header is shared between C and ASM code, so it MUST only * contain macro definitions. It MUST not contain any C code. * @@ -45,6 +46,8 @@ /* Mark the internal assembly functions as hidden */ #ifdef __ELF__ # define ZSTD_HIDE_ASM_FUNCTION(func) .hidden func +#elif defined(__APPLE__) +# define ZSTD_HIDE_ASM_FUNCTION(func) .private_extern func #else # define ZSTD_HIDE_ASM_FUNCTION(func) #endif @@ -65,7 +68,7 @@ #endif /* - * Only enable assembly for GNUC comptabile compilers, + * Only enable assembly for GNUC compatible compilers, * because other platforms may not support GAS assembly syntax. * * Only enable assembly for Linux / MacOS, other platforms may @@ -90,4 +93,23 @@ */ #define ZSTD_ENABLE_ASM_X86_64_BMI2 0 +/* + * For x86 ELF targets, add .note.gnu.property section for Intel CET in + * assembly sources when CET is enabled. + * + * Additionally, any function that may be called indirectly must begin + * with ZSTD_CET_ENDBRANCH. + */ +#if defined(__ELF__) && (defined(__x86_64__) || defined(__i386__)) \ + && defined(__has_include) +# if __has_include() +# include +# define ZSTD_CET_ENDBRANCH _CET_ENDBR +# endif +#endif + +#ifndef ZSTD_CET_ENDBRANCH +# define ZSTD_CET_ENDBRANCH +#endif + #endif /* ZSTD_PORTABILITY_MACROS_H */ diff --git a/lib/zstd/common/zstd_common.c b/lib/zstd/common/zstd_common.c index 3d7e35b309b5..44b95b25344a 100644 --- a/lib/zstd/common/zstd_common.c +++ b/lib/zstd/common/zstd_common.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -14,7 +15,6 @@ * Dependencies ***************************************/ #define ZSTD_DEPS_NEED_MALLOC -#include "zstd_deps.h" /* ZSTD_malloc, ZSTD_calloc, ZSTD_free, ZSTD_memset */ #include "error_private.h" #include "zstd_internal.h" @@ -47,37 +47,3 @@ ZSTD_ErrorCode ZSTD_getErrorCode(size_t code) { return ERR_getErrorCode(code); } /*! ZSTD_getErrorString() : * provides error code string from enum */ const char* ZSTD_getErrorString(ZSTD_ErrorCode code) { return ERR_getErrorString(code); } - - - -/*=************************************************************** -* Custom allocator -****************************************************************/ -void* ZSTD_customMalloc(size_t size, ZSTD_customMem customMem) -{ - if (customMem.customAlloc) - return customMem.customAlloc(customMem.opaque, size); - return ZSTD_malloc(size); -} - -void* ZSTD_customCalloc(size_t size, ZSTD_customMem customMem) -{ - if (customMem.customAlloc) { - /* calloc implemented as malloc+memset; - * not as efficient as calloc, but next best guess for custom malloc */ - void* const ptr = customMem.customAlloc(customMem.opaque, size); - ZSTD_memset(ptr, 0, size); - return ptr; - } - return ZSTD_calloc(1, size); -} - -void ZSTD_customFree(void* ptr, ZSTD_customMem customMem) -{ - if (ptr!=NULL) { - if (customMem.customFree) - customMem.customFree(customMem.opaque, ptr); - else - ZSTD_free(ptr); - } -} diff --git a/lib/zstd/common/zstd_deps.h b/lib/zstd/common/zstd_deps.h index 2c34e8a33a1c..f931f7d0e294 100644 --- a/lib/zstd/common/zstd_deps.h +++ b/lib/zstd/common/zstd_deps.h @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -105,3 +105,17 @@ static uint64_t ZSTD_div64(uint64_t dividend, uint32_t divisor) { #endif /* ZSTD_DEPS_IO */ #endif /* ZSTD_DEPS_NEED_IO */ + +/* + * Only requested when MSAN is enabled. + * Need: + * intptr_t + */ +#ifdef ZSTD_DEPS_NEED_STDINT +#ifndef ZSTD_DEPS_STDINT +#define ZSTD_DEPS_STDINT + +/* intptr_t already provided by ZSTD_DEPS_COMMON */ + +#endif /* ZSTD_DEPS_STDINT */ +#endif /* ZSTD_DEPS_NEED_STDINT */ diff --git a/lib/zstd/common/zstd_internal.h b/lib/zstd/common/zstd_internal.h index 93305d9b41bb..11da1233e890 100644 --- a/lib/zstd/common/zstd_internal.h +++ b/lib/zstd/common/zstd_internal.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -28,7 +29,6 @@ #include #define FSE_STATIC_LINKING_ONLY #include "fse.h" -#define HUF_STATIC_LINKING_ONLY #include "huf.h" #include /* XXH_reset, update, digest */ #define ZSTD_TRACE 0 @@ -83,9 +83,9 @@ typedef enum { bt_raw, bt_rle, bt_compressed, bt_reserved } blockType_e; #define ZSTD_FRAMECHECKSUMSIZE 4 #define MIN_SEQUENCES_SIZE 1 /* nbSeq==0 */ -#define MIN_CBLOCK_SIZE (1 /*litCSize*/ + 1 /* RLE or RAW */ + MIN_SEQUENCES_SIZE /* nbSeq==0 */) /* for a non-null block */ +#define MIN_CBLOCK_SIZE (1 /*litCSize*/ + 1 /* RLE or RAW */) /* for a non-null block */ +#define MIN_LITERALS_FOR_4_STREAMS 6 -#define HufLog 12 typedef enum { set_basic, set_rle, set_compressed, set_repeat } symbolEncodingType_e; #define LONGNBSEQ 0x7F00 @@ -93,6 +93,7 @@ typedef enum { set_basic, set_rle, set_compressed, set_repeat } symbolEncodingTy #define MINMATCH 3 #define Litbits 8 +#define LitHufLog 11 #define MaxLit ((1<= WILDCOPY_VECLEN || diff <= -WILDCOPY_VECLEN); @@ -225,12 +228,6 @@ void ZSTD_wildcopy(void* dst, const void* src, ptrdiff_t length, ZSTD_overlap_e * one COPY16() in the first call. Then, do two calls per loop since * at that point it is more likely to have a high trip count. */ -#ifdef __aarch64__ - do { - COPY16(op, ip); - } - while (op < oend); -#else ZSTD_copy16(op, ip); if (16 >= length) return; op += 16; @@ -240,7 +237,6 @@ void ZSTD_wildcopy(void* dst, const void* src, ptrdiff_t length, ZSTD_overlap_e COPY16(op, ip); } while (op < oend); -#endif } } @@ -289,11 +285,11 @@ typedef enum { typedef struct { seqDef* sequencesStart; seqDef* sequences; /* ptr to end of sequences */ - BYTE* litStart; - BYTE* lit; /* ptr to end of literals */ - BYTE* llCode; - BYTE* mlCode; - BYTE* ofCode; + BYTE* litStart; + BYTE* lit; /* ptr to end of literals */ + BYTE* llCode; + BYTE* mlCode; + BYTE* ofCode; size_t maxNbSeq; size_t maxNbLit; @@ -301,8 +297,8 @@ typedef struct { * in the seqStore that has a value larger than U16 (if it exists). To do so, we increment * the existing value of the litLength or matchLength by 0x10000. */ - ZSTD_longLengthType_e longLengthType; - U32 longLengthPos; /* Index of the sequence to apply long length modification to */ + ZSTD_longLengthType_e longLengthType; + U32 longLengthPos; /* Index of the sequence to apply long length modification to */ } seqStore_t; typedef struct { @@ -321,10 +317,10 @@ MEM_STATIC ZSTD_sequenceLength ZSTD_getSequenceLength(seqStore_t const* seqStore seqLen.matchLength = seq->mlBase + MINMATCH; if (seqStore->longLengthPos == (U32)(seq - seqStore->sequencesStart)) { if (seqStore->longLengthType == ZSTD_llt_literalLength) { - seqLen.litLength += 0xFFFF; + seqLen.litLength += 0x10000; } if (seqStore->longLengthType == ZSTD_llt_matchLength) { - seqLen.matchLength += 0xFFFF; + seqLen.matchLength += 0x10000; } } return seqLen; @@ -337,72 +333,13 @@ MEM_STATIC ZSTD_sequenceLength ZSTD_getSequenceLength(seqStore_t const* seqStore * `decompressedBound != ZSTD_CONTENTSIZE_ERROR` */ typedef struct { + size_t nbBlocks; size_t compressedSize; unsigned long long decompressedBound; } ZSTD_frameSizeInfo; /* decompress & legacy */ const seqStore_t* ZSTD_getSeqStore(const ZSTD_CCtx* ctx); /* compress & dictBuilder */ -void ZSTD_seqToCodes(const seqStore_t* seqStorePtr); /* compress, dictBuilder, decodeCorpus (shouldn't get its definition from here) */ - -/* custom memory allocation functions */ -void* ZSTD_customMalloc(size_t size, ZSTD_customMem customMem); -void* ZSTD_customCalloc(size_t size, ZSTD_customMem customMem); -void ZSTD_customFree(void* ptr, ZSTD_customMem customMem); - - -MEM_STATIC U32 ZSTD_highbit32(U32 val) /* compress, dictBuilder, decodeCorpus */ -{ - assert(val != 0); - { -# if (__GNUC__ >= 3) /* GCC Intrinsic */ - return __builtin_clz (val) ^ 31; -# else /* Software version */ - static const U32 DeBruijnClz[32] = { 0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30, 8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31 }; - U32 v = val; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - return DeBruijnClz[(v * 0x07C4ACDDU) >> 27]; -# endif - } -} - -/* - * Counts the number of trailing zeros of a `size_t`. - * Most compilers should support CTZ as a builtin. A backup - * implementation is provided if the builtin isn't supported, but - * it may not be terribly efficient. - */ -MEM_STATIC unsigned ZSTD_countTrailingZeros(size_t val) -{ - if (MEM_64bits()) { -# if (__GNUC__ >= 4) - return __builtin_ctzll((U64)val); -# else - static const int DeBruijnBytePos[64] = { 0, 1, 2, 7, 3, 13, 8, 19, - 4, 25, 14, 28, 9, 34, 20, 56, - 5, 17, 26, 54, 15, 41, 29, 43, - 10, 31, 38, 35, 21, 45, 49, 57, - 63, 6, 12, 18, 24, 27, 33, 55, - 16, 53, 40, 42, 30, 37, 44, 48, - 62, 11, 23, 32, 52, 39, 36, 47, - 61, 22, 51, 46, 60, 50, 59, 58 }; - return DeBruijnBytePos[((U64)((val & -(long long)val) * 0x0218A392CDABBD3FULL)) >> 58]; -# endif - } else { /* 32 bits */ -# if (__GNUC__ >= 3) - return __builtin_ctz((U32)val); -# else - static const int DeBruijnBytePos[32] = { 0, 1, 28, 2, 29, 14, 24, 3, - 30, 22, 20, 15, 25, 17, 4, 8, - 31, 27, 13, 23, 21, 19, 16, 7, - 26, 12, 18, 6, 11, 5, 10, 9 }; - return DeBruijnBytePos[((U32)((val & -(S32)val) * 0x077CB531U)) >> 27]; -# endif - } -} +int ZSTD_seqToCodes(const seqStore_t* seqStorePtr); /* compress, dictBuilder, decodeCorpus (shouldn't get its definition from here) */ /* ZSTD_invalidateRepCodes() : @@ -420,13 +357,13 @@ typedef struct { /*! ZSTD_getcBlockSize() : * Provides the size of compressed block from block header `src` */ -/* Used by: decompress, fullbench (does not get its definition from here) */ +/* Used by: decompress, fullbench */ size_t ZSTD_getcBlockSize(const void* src, size_t srcSize, blockProperties_t* bpPtr); /*! ZSTD_decodeSeqHeaders() : * decode sequence header from src */ -/* Used by: decompress, fullbench (does not get its definition from here) */ +/* Used by: zstd_decompress_block, fullbench */ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr, const void* src, size_t srcSize); diff --git a/lib/zstd/compress/clevels.h b/lib/zstd/compress/clevels.h index d9a76112ec3a..6ab8be6532ef 100644 --- a/lib/zstd/compress/clevels.h +++ b/lib/zstd/compress/clevels.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/compress/fse_compress.c b/lib/zstd/compress/fse_compress.c index ec5b1ca6d71a..44a3c10becf2 100644 --- a/lib/zstd/compress/fse_compress.c +++ b/lib/zstd/compress/fse_compress.c @@ -1,6 +1,7 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* ****************************************************************** * FSE : Finite State Entropy encoder - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - FSE source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -25,7 +26,8 @@ #include "../common/error_private.h" #define ZSTD_DEPS_NEED_MALLOC #define ZSTD_DEPS_NEED_MATH64 -#include "../common/zstd_deps.h" /* ZSTD_malloc, ZSTD_free, ZSTD_memcpy, ZSTD_memset */ +#include "../common/zstd_deps.h" /* ZSTD_memset */ +#include "../common/bits.h" /* ZSTD_highbit32 */ /* ************************************************************** @@ -90,7 +92,7 @@ size_t FSE_buildCTable_wksp(FSE_CTable* ct, assert(tableLog < 16); /* required for threshold strategy to work */ /* For explanations on how to distribute symbol values over the table : - * http://fastcompression.blogspot.fr/2014/02/fse-distributing-symbol-values.html */ + * https://fastcompression.blogspot.fr/2014/02/fse-distributing-symbol-values.html */ #ifdef __clang_analyzer__ ZSTD_memset(tableSymbol, 0, sizeof(*tableSymbol) * tableSize); /* useless initialization, just to keep scan-build happy */ @@ -191,7 +193,7 @@ size_t FSE_buildCTable_wksp(FSE_CTable* ct, break; default : assert(normalizedCounter[s] > 1); - { U32 const maxBitsOut = tableLog - BIT_highbit32 ((U32)normalizedCounter[s]-1); + { U32 const maxBitsOut = tableLog - ZSTD_highbit32 ((U32)normalizedCounter[s]-1); U32 const minStatePlus = (U32)normalizedCounter[s] << maxBitsOut; symbolTT[s].deltaNbBits = (maxBitsOut << 16) - minStatePlus; symbolTT[s].deltaFindState = (int)(total - (unsigned)normalizedCounter[s]); @@ -224,8 +226,8 @@ size_t FSE_NCountWriteBound(unsigned maxSymbolValue, unsigned tableLog) size_t const maxHeaderSize = (((maxSymbolValue+1) * tableLog + 4 /* bitCount initialized at 4 */ + 2 /* first two symbols may use one additional bit each */) / 8) - + 1 /* round up to whole nb bytes */ - + 2 /* additional two bytes for bitstream flush */; + + 1 /* round up to whole nb bytes */ + + 2 /* additional two bytes for bitstream flush */; return maxSymbolValue ? maxHeaderSize : FSE_NCOUNTBOUND; /* maxSymbolValue==0 ? use default */ } @@ -254,7 +256,7 @@ FSE_writeNCount_generic (void* header, size_t headerBufferSize, /* Init */ remaining = tableSize+1; /* +1 for extra accuracy */ threshold = tableSize; - nbBits = tableLog+1; + nbBits = (int)tableLog+1; while ((symbol < alphabetSize) && (remaining>1)) { /* stops at 1 */ if (previousIs0) { @@ -273,7 +275,7 @@ FSE_writeNCount_generic (void* header, size_t headerBufferSize, } while (symbol >= start+3) { start+=3; - bitStream += 3 << bitCount; + bitStream += 3U << bitCount; bitCount += 2; } bitStream += (symbol-start) << bitCount; @@ -293,7 +295,7 @@ FSE_writeNCount_generic (void* header, size_t headerBufferSize, count++; /* +1 for extra accuracy */ if (count>=threshold) count += max; /* [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[ */ - bitStream += count << bitCount; + bitStream += (U32)count << bitCount; bitCount += nbBits; bitCount -= (count>8); out+= (bitCount+7) /8; - return (out-ostart); + assert(out >= ostart); + return (size_t)(out-ostart); } @@ -342,21 +345,11 @@ size_t FSE_writeNCount (void* buffer, size_t bufferSize, * FSE Compression Code ****************************************************************/ -FSE_CTable* FSE_createCTable (unsigned maxSymbolValue, unsigned tableLog) -{ - size_t size; - if (tableLog > FSE_TABLELOG_ABSOLUTE_MAX) tableLog = FSE_TABLELOG_ABSOLUTE_MAX; - size = FSE_CTABLE_SIZE_U32 (tableLog, maxSymbolValue) * sizeof(U32); - return (FSE_CTable*)ZSTD_malloc(size); -} - -void FSE_freeCTable (FSE_CTable* ct) { ZSTD_free(ct); } - /* provides the minimum logSize to safely represent a distribution */ static unsigned FSE_minTableLog(size_t srcSize, unsigned maxSymbolValue) { - U32 minBitsSrc = BIT_highbit32((U32)(srcSize)) + 1; - U32 minBitsSymbols = BIT_highbit32(maxSymbolValue) + 2; + U32 minBitsSrc = ZSTD_highbit32((U32)(srcSize)) + 1; + U32 minBitsSymbols = ZSTD_highbit32(maxSymbolValue) + 2; U32 minBits = minBitsSrc < minBitsSymbols ? minBitsSrc : minBitsSymbols; assert(srcSize > 1); /* Not supported, RLE should be used instead */ return minBits; @@ -364,7 +357,7 @@ static unsigned FSE_minTableLog(size_t srcSize, unsigned maxSymbolValue) unsigned FSE_optimalTableLog_internal(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, unsigned minus) { - U32 maxBitsSrc = BIT_highbit32((U32)(srcSize - 1)) - minus; + U32 maxBitsSrc = ZSTD_highbit32((U32)(srcSize - 1)) - minus; U32 tableLog = maxTableLog; U32 minBits = FSE_minTableLog(srcSize, maxSymbolValue); assert(srcSize > 1); /* Not supported, RLE should be used instead */ @@ -532,40 +525,6 @@ size_t FSE_normalizeCount (short* normalizedCounter, unsigned tableLog, return tableLog; } - -/* fake FSE_CTable, for raw (uncompressed) input */ -size_t FSE_buildCTable_raw (FSE_CTable* ct, unsigned nbBits) -{ - const unsigned tableSize = 1 << nbBits; - const unsigned tableMask = tableSize - 1; - const unsigned maxSymbolValue = tableMask; - void* const ptr = ct; - U16* const tableU16 = ( (U16*) ptr) + 2; - void* const FSCT = ((U32*)ptr) + 1 /* header */ + (tableSize>>1); /* assumption : tableLog >= 1 */ - FSE_symbolCompressionTransform* const symbolTT = (FSE_symbolCompressionTransform*) (FSCT); - unsigned s; - - /* Sanity checks */ - if (nbBits < 1) return ERROR(GENERIC); /* min size */ - - /* header */ - tableU16[-2] = (U16) nbBits; - tableU16[-1] = (U16) maxSymbolValue; - - /* Build table */ - for (s=0; s= 2 + +static size_t showU32(const U32* arr, size_t size) { - return FSE_optimalTableLog_internal(maxTableLog, srcSize, maxSymbolValue, 1); + size_t u; + for (u=0; u= sizeof(HUF_WriteCTableWksp)); + + assert(HUF_readCTableHeader(CTable).maxSymbolValue == maxSymbolValue); + assert(HUF_readCTableHeader(CTable).tableLog == huffLog); + /* check conditions */ if (workspaceSize < sizeof(HUF_WriteCTableWksp)) return ERROR(GENERIC); if (maxSymbolValue > HUF_SYMBOLVALUE_MAX) return ERROR(maxSymbolValue_tooLarge); @@ -204,16 +286,6 @@ size_t HUF_writeCTable_wksp(void* dst, size_t maxDstSize, return ((maxSymbolValue+1)/2) + 1; } -/*! HUF_writeCTable() : - `CTable` : Huffman tree to save, using huf representation. - @return : size of saved CTable */ -size_t HUF_writeCTable (void* dst, size_t maxDstSize, - const HUF_CElt* CTable, unsigned maxSymbolValue, unsigned huffLog) -{ - HUF_WriteCTableWksp wksp; - return HUF_writeCTable_wksp(dst, maxDstSize, CTable, maxSymbolValue, huffLog, &wksp, sizeof(wksp)); -} - size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize, unsigned* hasZeroWeights) { @@ -231,7 +303,9 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void if (tableLog > HUF_TABLELOG_MAX) return ERROR(tableLog_tooLarge); if (nbSymbols > *maxSymbolValuePtr+1) return ERROR(maxSymbolValue_tooSmall); - CTable[0] = tableLog; + *maxSymbolValuePtr = nbSymbols - 1; + + HUF_writeCTableHeader(CTable, tableLog, *maxSymbolValuePtr); /* Prepare base value per rank */ { U32 n, nextRankStart = 0; @@ -263,74 +337,71 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void { U32 n; for (n=0; n HUF_readCTableHeader(CTable).maxSymbolValue) + return 0; return (U32)HUF_getNbBits(ct[symbolValue]); } -typedef struct nodeElt_s { - U32 count; - U16 parent; - BYTE byte; - BYTE nbBits; -} nodeElt; - /* * HUF_setMaxHeight(): - * Enforces maxNbBits on the Huffman tree described in huffNode. + * Try to enforce @targetNbBits on the Huffman tree described in @huffNode. * - * It sets all nodes with nbBits > maxNbBits to be maxNbBits. Then it adjusts - * the tree to so that it is a valid canonical Huffman tree. + * It attempts to convert all nodes with nbBits > @targetNbBits + * to employ @targetNbBits instead. Then it adjusts the tree + * so that it remains a valid canonical Huffman tree. * * @pre The sum of the ranks of each symbol == 2^largestBits, * where largestBits == huffNode[lastNonNull].nbBits. * @post The sum of the ranks of each symbol == 2^largestBits, - * where largestBits is the return value <= maxNbBits. + * where largestBits is the return value (expected <= targetNbBits). * - * @param huffNode The Huffman tree modified in place to enforce maxNbBits. + * @param huffNode The Huffman tree modified in place to enforce targetNbBits. + * It's presumed sorted, from most frequent to rarest symbol. * @param lastNonNull The symbol with the lowest count in the Huffman tree. - * @param maxNbBits The maximum allowed number of bits, which the Huffman tree + * @param targetNbBits The allowed number of bits, which the Huffman tree * may not respect. After this function the Huffman tree will - * respect maxNbBits. - * @return The maximum number of bits of the Huffman tree after adjustment, - * necessarily no more than maxNbBits. + * respect targetNbBits. + * @return The maximum number of bits of the Huffman tree after adjustment. */ -static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) +static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 targetNbBits) { const U32 largestBits = huffNode[lastNonNull].nbBits; - /* early exit : no elt > maxNbBits, so the tree is already valid. */ - if (largestBits <= maxNbBits) return largestBits; + /* early exit : no elt > targetNbBits, so the tree is already valid. */ + if (largestBits <= targetNbBits) return largestBits; + + DEBUGLOG(5, "HUF_setMaxHeight (targetNbBits = %u)", targetNbBits); /* there are several too large elements (at least >= 2) */ { int totalCost = 0; - const U32 baseCost = 1 << (largestBits - maxNbBits); + const U32 baseCost = 1 << (largestBits - targetNbBits); int n = (int)lastNonNull; - /* Adjust any ranks > maxNbBits to maxNbBits. + /* Adjust any ranks > targetNbBits to targetNbBits. * Compute totalCost, which is how far the sum of the ranks is * we are over 2^largestBits after adjust the offending ranks. */ - while (huffNode[n].nbBits > maxNbBits) { + while (huffNode[n].nbBits > targetNbBits) { totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits)); - huffNode[n].nbBits = (BYTE)maxNbBits; + huffNode[n].nbBits = (BYTE)targetNbBits; n--; } - /* n stops at huffNode[n].nbBits <= maxNbBits */ - assert(huffNode[n].nbBits <= maxNbBits); - /* n end at index of smallest symbol using < maxNbBits */ - while (huffNode[n].nbBits == maxNbBits) --n; + /* n stops at huffNode[n].nbBits <= targetNbBits */ + assert(huffNode[n].nbBits <= targetNbBits); + /* n end at index of smallest symbol using < targetNbBits */ + while (huffNode[n].nbBits == targetNbBits) --n; - /* renorm totalCost from 2^largestBits to 2^maxNbBits + /* renorm totalCost from 2^largestBits to 2^targetNbBits * note : totalCost is necessarily a multiple of baseCost */ - assert((totalCost & (baseCost - 1)) == 0); - totalCost >>= (largestBits - maxNbBits); + assert(((U32)totalCost & (baseCost - 1)) == 0); + totalCost >>= (largestBits - targetNbBits); assert(totalCost > 0); /* repay normalized cost */ @@ -339,19 +410,19 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) /* Get pos of last (smallest = lowest cum. count) symbol per rank */ ZSTD_memset(rankLast, 0xF0, sizeof(rankLast)); - { U32 currentNbBits = maxNbBits; + { U32 currentNbBits = targetNbBits; int pos; for (pos=n ; pos >= 0; pos--) { if (huffNode[pos].nbBits >= currentNbBits) continue; - currentNbBits = huffNode[pos].nbBits; /* < maxNbBits */ - rankLast[maxNbBits-currentNbBits] = (U32)pos; + currentNbBits = huffNode[pos].nbBits; /* < targetNbBits */ + rankLast[targetNbBits-currentNbBits] = (U32)pos; } } while (totalCost > 0) { /* Try to reduce the next power of 2 above totalCost because we * gain back half the rank. */ - U32 nBitsToDecrease = BIT_highbit32((U32)totalCost) + 1; + U32 nBitsToDecrease = ZSTD_highbit32((U32)totalCost) + 1; for ( ; nBitsToDecrease > 1; nBitsToDecrease--) { U32 const highPos = rankLast[nBitsToDecrease]; U32 const lowPos = rankLast[nBitsToDecrease-1]; @@ -391,7 +462,7 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) rankLast[nBitsToDecrease] = noSymbol; else { rankLast[nBitsToDecrease]--; - if (huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease) + if (huffNode[rankLast[nBitsToDecrease]].nbBits != targetNbBits-nBitsToDecrease) rankLast[nBitsToDecrease] = noSymbol; /* this rank is now empty */ } } /* while (totalCost > 0) */ @@ -403,11 +474,11 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) * TODO. */ while (totalCost < 0) { /* Sometimes, cost correction overshoot */ - /* special case : no rank 1 symbol (using maxNbBits-1); - * let's create one from largest rank 0 (using maxNbBits). + /* special case : no rank 1 symbol (using targetNbBits-1); + * let's create one from largest rank 0 (using targetNbBits). */ if (rankLast[1] == noSymbol) { - while (huffNode[n].nbBits == maxNbBits) n--; + while (huffNode[n].nbBits == targetNbBits) n--; huffNode[n+1].nbBits--; assert(n >= 0); rankLast[1] = (U32)(n+1); @@ -421,7 +492,7 @@ static U32 HUF_setMaxHeight(nodeElt* huffNode, U32 lastNonNull, U32 maxNbBits) } /* repay normalized cost */ } /* there are several too large elements (at least >= 2) */ - return maxNbBits; + return targetNbBits; } typedef struct { @@ -429,7 +500,7 @@ typedef struct { U16 curr; } rankPos; -typedef nodeElt huffNodeTable[HUF_CTABLE_WORKSPACE_SIZE_U32]; +typedef nodeElt huffNodeTable[2 * (HUF_SYMBOLVALUE_MAX + 1)]; /* Number of buckets available for HUF_sort() */ #define RANK_POSITION_TABLE_SIZE 192 @@ -448,8 +519,8 @@ typedef struct { * Let buckets 166 to 192 represent all remaining counts up to RANK_POSITION_MAX_COUNT_LOG using log2 bucketing. */ #define RANK_POSITION_MAX_COUNT_LOG 32 -#define RANK_POSITION_LOG_BUCKETS_BEGIN (RANK_POSITION_TABLE_SIZE - 1) - RANK_POSITION_MAX_COUNT_LOG - 1 /* == 158 */ -#define RANK_POSITION_DISTINCT_COUNT_CUTOFF RANK_POSITION_LOG_BUCKETS_BEGIN + BIT_highbit32(RANK_POSITION_LOG_BUCKETS_BEGIN) /* == 166 */ +#define RANK_POSITION_LOG_BUCKETS_BEGIN ((RANK_POSITION_TABLE_SIZE - 1) - RANK_POSITION_MAX_COUNT_LOG - 1 /* == 158 */) +#define RANK_POSITION_DISTINCT_COUNT_CUTOFF (RANK_POSITION_LOG_BUCKETS_BEGIN + ZSTD_highbit32(RANK_POSITION_LOG_BUCKETS_BEGIN) /* == 166 */) /* Return the appropriate bucket index for a given count. See definition of * RANK_POSITION_DISTINCT_COUNT_CUTOFF for explanation of bucketing strategy. @@ -457,7 +528,7 @@ typedef struct { static U32 HUF_getIndex(U32 const count) { return (count < RANK_POSITION_DISTINCT_COUNT_CUTOFF) ? count - : BIT_highbit32(count) + RANK_POSITION_LOG_BUCKETS_BEGIN; + : ZSTD_highbit32(count) + RANK_POSITION_LOG_BUCKETS_BEGIN; } /* Helper swap function for HUF_quickSortPartition() */ @@ -580,7 +651,7 @@ static void HUF_sort(nodeElt huffNode[], const unsigned count[], U32 const maxSy /* Sort each bucket. */ for (n = RANK_POSITION_DISTINCT_COUNT_CUTOFF; n < RANK_POSITION_TABLE_SIZE - 1; ++n) { - U32 const bucketSize = rankPosition[n].curr-rankPosition[n].base; + int const bucketSize = rankPosition[n].curr - rankPosition[n].base; U32 const bucketStartIdx = rankPosition[n].base; if (bucketSize > 1) { assert(bucketStartIdx < maxSymbolValue1); @@ -591,6 +662,7 @@ static void HUF_sort(nodeElt huffNode[], const unsigned count[], U32 const maxSy assert(HUF_isSorted(huffNode, maxSymbolValue1)); } + /* HUF_buildCTable_wksp() : * Same as HUF_buildCTable(), but using externally allocated scratch buffer. * `workSpace` must be aligned on 4-bytes boundaries, and be at least as large as sizeof(HUF_buildCTable_wksp_tables). @@ -611,6 +683,7 @@ static int HUF_buildTree(nodeElt* huffNode, U32 maxSymbolValue) int lowS, lowN; int nodeNb = STARTNODE; int n, nodeRoot; + DEBUGLOG(5, "HUF_buildTree (alphabet size = %u)", maxSymbolValue + 1); /* init for parents */ nonNullRank = (int)maxSymbolValue; while(huffNode[nonNullRank].count == 0) nonNullRank--; @@ -637,6 +710,8 @@ static int HUF_buildTree(nodeElt* huffNode, U32 maxSymbolValue) for (n=0; n<=nonNullRank; n++) huffNode[n].nbBits = huffNode[ huffNode[n].parent ].nbBits + 1; + DEBUGLOG(6, "Initial distribution of bits completed (%zu sorted symbols)", showHNodeBits(huffNode, maxSymbolValue+1)); + return nonNullRank; } @@ -671,31 +746,40 @@ static void HUF_buildCTableFromTree(HUF_CElt* CTable, nodeElt const* huffNode, i HUF_setNbBits(ct + huffNode[n].byte, huffNode[n].nbBits); /* push nbBits per symbol, symbol order */ for (n=0; nhuffNodeTbl; nodeElt* const huffNode = huffNode0+1; int nonNullRank; + HUF_STATIC_ASSERT(HUF_CTABLE_WORKSPACE_SIZE == sizeof(HUF_buildCTable_wksp_tables)); + + DEBUGLOG(5, "HUF_buildCTable_wksp (alphabet size = %u)", maxSymbolValue+1); + /* safety checks */ if (wkspSize < sizeof(HUF_buildCTable_wksp_tables)) - return ERROR(workSpace_tooSmall); + return ERROR(workSpace_tooSmall); if (maxNbBits == 0) maxNbBits = HUF_TABLELOG_DEFAULT; if (maxSymbolValue > HUF_SYMBOLVALUE_MAX) - return ERROR(maxSymbolValue_tooLarge); + return ERROR(maxSymbolValue_tooLarge); ZSTD_memset(huffNode0, 0, sizeof(huffNodeTable)); /* sort, decreasing order */ HUF_sort(huffNode, count, maxSymbolValue, wksp_tables->rankPosition); + DEBUGLOG(6, "sorted symbols completed (%zu symbols)", showHNodeSymbols(huffNode, maxSymbolValue+1)); /* build tree */ nonNullRank = HUF_buildTree(huffNode, maxSymbolValue); - /* enforce maxTableLog */ + /* determine and enforce maxTableLog */ maxNbBits = HUF_setMaxHeight(huffNode, (U32)nonNullRank, maxNbBits); if (maxNbBits > HUF_TABLELOG_MAX) return ERROR(GENERIC); /* check fit into table */ @@ -716,13 +800,20 @@ size_t HUF_estimateCompressedSize(const HUF_CElt* CTable, const unsigned* count, } int HUF_validateCTable(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue) { - HUF_CElt const* ct = CTable + 1; - int bad = 0; - int s; - for (s = 0; s <= (int)maxSymbolValue; ++s) { - bad |= (count[s] != 0) & (HUF_getNbBits(ct[s]) == 0); - } - return !bad; + HUF_CTableHeader header = HUF_readCTableHeader(CTable); + HUF_CElt const* ct = CTable + 1; + int bad = 0; + int s; + + assert(header.tableLog <= HUF_TABLELOG_ABSOLUTEMAX); + + if (header.maxSymbolValue < maxSymbolValue) + return 0; + + for (s = 0; s <= (int)maxSymbolValue; ++s) { + bad |= (count[s] != 0) & (HUF_getNbBits(ct[s]) == 0); + } + return !bad; } size_t HUF_compressBound(size_t size) { return HUF_COMPRESSBOUND(size); } @@ -804,7 +895,7 @@ FORCE_INLINE_TEMPLATE void HUF_addBits(HUF_CStream_t* bitC, HUF_CElt elt, int id #if DEBUGLEVEL >= 1 { size_t const nbBits = HUF_getNbBits(elt); - size_t const dirtyBits = nbBits == 0 ? 0 : BIT_highbit32((U32)nbBits) + 1; + size_t const dirtyBits = nbBits == 0 ? 0 : ZSTD_highbit32((U32)nbBits) + 1; (void)dirtyBits; /* Middle bits are 0. */ assert(((elt >> dirtyBits) << (dirtyBits + nbBits)) == 0); @@ -884,7 +975,7 @@ static size_t HUF_closeCStream(HUF_CStream_t* bitC) { size_t const nbBits = bitC->bitPos[0] & 0xFF; if (bitC->ptr >= bitC->endPtr) return 0; /* overflow detected */ - return (bitC->ptr - bitC->startPtr) + (nbBits > 0); + return (size_t)(bitC->ptr - bitC->startPtr) + (nbBits > 0); } } @@ -964,17 +1055,17 @@ HUF_compress1X_usingCTable_internal_body(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable) { - U32 const tableLog = (U32)CTable[0]; + U32 const tableLog = HUF_readCTableHeader(CTable).tableLog; HUF_CElt const* ct = CTable + 1; const BYTE* ip = (const BYTE*) src; BYTE* const ostart = (BYTE*)dst; BYTE* const oend = ostart + dstSize; - BYTE* op = ostart; HUF_CStream_t bitC; /* init */ if (dstSize < 8) return 0; /* not enough space to compress */ - { size_t const initErr = HUF_initCStream(&bitC, op, (size_t)(oend-op)); + { BYTE* op = ostart; + size_t const initErr = HUF_initCStream(&bitC, op, (size_t)(oend-op)); if (HUF_isError(initErr)) return 0; } if (dstSize < HUF_tightCompressBound(srcSize, (size_t)tableLog) || tableLog > 11) @@ -1045,9 +1136,9 @@ HUF_compress1X_usingCTable_internal_default(void* dst, size_t dstSize, static size_t HUF_compress1X_usingCTable_internal(void* dst, size_t dstSize, const void* src, size_t srcSize, - const HUF_CElt* CTable, const int bmi2) + const HUF_CElt* CTable, const int flags) { - if (bmi2) { + if (flags & HUF_flags_bmi2) { return HUF_compress1X_usingCTable_internal_bmi2(dst, dstSize, src, srcSize, CTable); } return HUF_compress1X_usingCTable_internal_default(dst, dstSize, src, srcSize, CTable); @@ -1058,28 +1149,23 @@ HUF_compress1X_usingCTable_internal(void* dst, size_t dstSize, static size_t HUF_compress1X_usingCTable_internal(void* dst, size_t dstSize, const void* src, size_t srcSize, - const HUF_CElt* CTable, const int bmi2) + const HUF_CElt* CTable, const int flags) { - (void)bmi2; + (void)flags; return HUF_compress1X_usingCTable_internal_body(dst, dstSize, src, srcSize, CTable); } #endif -size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable) +size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags) { - return HUF_compress1X_usingCTable_bmi2(dst, dstSize, src, srcSize, CTable, /* bmi2 */ 0); -} - -size_t HUF_compress1X_usingCTable_bmi2(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int bmi2) -{ - return HUF_compress1X_usingCTable_internal(dst, dstSize, src, srcSize, CTable, bmi2); + return HUF_compress1X_usingCTable_internal(dst, dstSize, src, srcSize, CTable, flags); } static size_t HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, const void* src, size_t srcSize, - const HUF_CElt* CTable, int bmi2) + const HUF_CElt* CTable, int flags) { size_t const segmentSize = (srcSize+3)/4; /* first 3 segments */ const BYTE* ip = (const BYTE*) src; @@ -1093,7 +1179,7 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, op += 6; /* jumpTable */ assert(op <= oend); - { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, bmi2) ); + { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, flags) ); if (cSize == 0 || cSize > 65535) return 0; MEM_writeLE16(ostart, (U16)cSize); op += cSize; @@ -1101,7 +1187,7 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, ip += segmentSize; assert(op <= oend); - { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, bmi2) ); + { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, flags) ); if (cSize == 0 || cSize > 65535) return 0; MEM_writeLE16(ostart+2, (U16)cSize); op += cSize; @@ -1109,7 +1195,7 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, ip += segmentSize; assert(op <= oend); - { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, bmi2) ); + { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, segmentSize, CTable, flags) ); if (cSize == 0 || cSize > 65535) return 0; MEM_writeLE16(ostart+4, (U16)cSize); op += cSize; @@ -1118,7 +1204,7 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, ip += segmentSize; assert(op <= oend); assert(ip <= iend); - { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, (size_t)(iend-ip), CTable, bmi2) ); + { CHECK_V_F(cSize, HUF_compress1X_usingCTable_internal(op, (size_t)(oend-op), ip, (size_t)(iend-ip), CTable, flags) ); if (cSize == 0 || cSize > 65535) return 0; op += cSize; } @@ -1126,14 +1212,9 @@ HUF_compress4X_usingCTable_internal(void* dst, size_t dstSize, return (size_t)(op-ostart); } -size_t HUF_compress4X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable) -{ - return HUF_compress4X_usingCTable_bmi2(dst, dstSize, src, srcSize, CTable, /* bmi2 */ 0); -} - -size_t HUF_compress4X_usingCTable_bmi2(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int bmi2) +size_t HUF_compress4X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags) { - return HUF_compress4X_usingCTable_internal(dst, dstSize, src, srcSize, CTable, bmi2); + return HUF_compress4X_usingCTable_internal(dst, dstSize, src, srcSize, CTable, flags); } typedef enum { HUF_singleStream, HUF_fourStreams } HUF_nbStreams_e; @@ -1141,11 +1222,11 @@ typedef enum { HUF_singleStream, HUF_fourStreams } HUF_nbStreams_e; static size_t HUF_compressCTable_internal( BYTE* const ostart, BYTE* op, BYTE* const oend, const void* src, size_t srcSize, - HUF_nbStreams_e nbStreams, const HUF_CElt* CTable, const int bmi2) + HUF_nbStreams_e nbStreams, const HUF_CElt* CTable, const int flags) { size_t const cSize = (nbStreams==HUF_singleStream) ? - HUF_compress1X_usingCTable_internal(op, (size_t)(oend - op), src, srcSize, CTable, bmi2) : - HUF_compress4X_usingCTable_internal(op, (size_t)(oend - op), src, srcSize, CTable, bmi2); + HUF_compress1X_usingCTable_internal(op, (size_t)(oend - op), src, srcSize, CTable, flags) : + HUF_compress4X_usingCTable_internal(op, (size_t)(oend - op), src, srcSize, CTable, flags); if (HUF_isError(cSize)) { return cSize; } if (cSize==0) { return 0; } /* uncompressible */ op += cSize; @@ -1168,6 +1249,81 @@ typedef struct { #define SUSPECT_INCOMPRESSIBLE_SAMPLE_SIZE 4096 #define SUSPECT_INCOMPRESSIBLE_SAMPLE_RATIO 10 /* Must be >= 2 */ +unsigned HUF_cardinality(const unsigned* count, unsigned maxSymbolValue) +{ + unsigned cardinality = 0; + unsigned i; + + for (i = 0; i < maxSymbolValue + 1; i++) { + if (count[i] != 0) cardinality += 1; + } + + return cardinality; +} + +unsigned HUF_minTableLog(unsigned symbolCardinality) +{ + U32 minBitsSymbols = ZSTD_highbit32(symbolCardinality) + 1; + return minBitsSymbols; +} + +unsigned HUF_optimalTableLog( + unsigned maxTableLog, + size_t srcSize, + unsigned maxSymbolValue, + void* workSpace, size_t wkspSize, + HUF_CElt* table, + const unsigned* count, + int flags) +{ + assert(srcSize > 1); /* Not supported, RLE should be used instead */ + assert(wkspSize >= sizeof(HUF_buildCTable_wksp_tables)); + + if (!(flags & HUF_flags_optimalDepth)) { + /* cheap evaluation, based on FSE */ + return FSE_optimalTableLog_internal(maxTableLog, srcSize, maxSymbolValue, 1); + } + + { BYTE* dst = (BYTE*)workSpace + sizeof(HUF_WriteCTableWksp); + size_t dstSize = wkspSize - sizeof(HUF_WriteCTableWksp); + size_t hSize, newSize; + const unsigned symbolCardinality = HUF_cardinality(count, maxSymbolValue); + const unsigned minTableLog = HUF_minTableLog(symbolCardinality); + size_t optSize = ((size_t) ~0) - 1; + unsigned optLog = maxTableLog, optLogGuess; + + DEBUGLOG(6, "HUF_optimalTableLog: probing huf depth (srcSize=%zu)", srcSize); + + /* Search until size increases */ + for (optLogGuess = minTableLog; optLogGuess <= maxTableLog; optLogGuess++) { + DEBUGLOG(7, "checking for huffLog=%u", optLogGuess); + + { size_t maxBits = HUF_buildCTable_wksp(table, count, maxSymbolValue, optLogGuess, workSpace, wkspSize); + if (ERR_isError(maxBits)) continue; + + if (maxBits < optLogGuess && optLogGuess > minTableLog) break; + + hSize = HUF_writeCTable_wksp(dst, dstSize, table, maxSymbolValue, (U32)maxBits, workSpace, wkspSize); + } + + if (ERR_isError(hSize)) continue; + + newSize = HUF_estimateCompressedSize(table, count, maxSymbolValue) + hSize; + + if (newSize > optSize + 1) { + break; + } + + if (newSize < optSize) { + optSize = newSize; + optLog = optLogGuess; + } + } + assert(optLog <= HUF_TABLELOG_MAX); + return optLog; + } +} + /* HUF_compress_internal() : * `workSpace_align4` must be aligned on 4-bytes boundaries, * and occupies the same space as a table of HUF_WORKSPACE_SIZE_U64 unsigned */ @@ -1177,14 +1333,14 @@ HUF_compress_internal (void* dst, size_t dstSize, unsigned maxSymbolValue, unsigned huffLog, HUF_nbStreams_e nbStreams, void* workSpace, size_t wkspSize, - HUF_CElt* oldHufTable, HUF_repeat* repeat, int preferRepeat, - const int bmi2, unsigned suspectUncompressible) + HUF_CElt* oldHufTable, HUF_repeat* repeat, int flags) { HUF_compress_tables_t* const table = (HUF_compress_tables_t*)HUF_alignUpWorkspace(workSpace, &wkspSize, ZSTD_ALIGNOF(size_t)); BYTE* const ostart = (BYTE*)dst; BYTE* const oend = ostart + dstSize; BYTE* op = ostart; + DEBUGLOG(5, "HUF_compress_internal (srcSize=%zu)", srcSize); HUF_STATIC_ASSERT(sizeof(*table) + HUF_WORKSPACE_MAX_ALIGNMENT <= HUF_WORKSPACE_SIZE); /* checks & inits */ @@ -1198,16 +1354,17 @@ HUF_compress_internal (void* dst, size_t dstSize, if (!huffLog) huffLog = HUF_TABLELOG_DEFAULT; /* Heuristic : If old table is valid, use it for small inputs */ - if (preferRepeat && repeat && *repeat == HUF_repeat_valid) { + if ((flags & HUF_flags_preferRepeat) && repeat && *repeat == HUF_repeat_valid) { return HUF_compressCTable_internal(ostart, op, oend, src, srcSize, - nbStreams, oldHufTable, bmi2); + nbStreams, oldHufTable, flags); } /* If uncompressible data is suspected, do a smaller sampling first */ DEBUG_STATIC_ASSERT(SUSPECT_INCOMPRESSIBLE_SAMPLE_RATIO >= 2); - if (suspectUncompressible && srcSize >= (SUSPECT_INCOMPRESSIBLE_SAMPLE_SIZE * SUSPECT_INCOMPRESSIBLE_SAMPLE_RATIO)) { + if ((flags & HUF_flags_suspectUncompressible) && srcSize >= (SUSPECT_INCOMPRESSIBLE_SAMPLE_SIZE * SUSPECT_INCOMPRESSIBLE_SAMPLE_RATIO)) { size_t largestTotal = 0; + DEBUGLOG(5, "input suspected incompressible : sampling to check"); { unsigned maxSymbolValueBegin = maxSymbolValue; CHECK_V_F(largestBegin, HIST_count_simple (table->count, &maxSymbolValueBegin, (const BYTE*)src, SUSPECT_INCOMPRESSIBLE_SAMPLE_SIZE) ); largestTotal += largestBegin; @@ -1224,6 +1381,7 @@ HUF_compress_internal (void* dst, size_t dstSize, if (largest == srcSize) { *ostart = ((const BYTE*)src)[0]; return 1; } /* single symbol, rle */ if (largest <= (srcSize >> 7)+4) return 0; /* heuristic : probably not compressible enough */ } + DEBUGLOG(6, "histogram detail completed (%zu symbols)", showU32(table->count, maxSymbolValue+1)); /* Check validity of previous table */ if ( repeat @@ -1232,25 +1390,20 @@ HUF_compress_internal (void* dst, size_t dstSize, *repeat = HUF_repeat_none; } /* Heuristic : use existing table for small inputs */ - if (preferRepeat && repeat && *repeat != HUF_repeat_none) { + if ((flags & HUF_flags_preferRepeat) && repeat && *repeat != HUF_repeat_none) { return HUF_compressCTable_internal(ostart, op, oend, src, srcSize, - nbStreams, oldHufTable, bmi2); + nbStreams, oldHufTable, flags); } /* Build Huffman Tree */ - huffLog = HUF_optimalTableLog(huffLog, srcSize, maxSymbolValue); + huffLog = HUF_optimalTableLog(huffLog, srcSize, maxSymbolValue, &table->wksps, sizeof(table->wksps), table->CTable, table->count, flags); { size_t const maxBits = HUF_buildCTable_wksp(table->CTable, table->count, maxSymbolValue, huffLog, &table->wksps.buildCTable_wksp, sizeof(table->wksps.buildCTable_wksp)); CHECK_F(maxBits); huffLog = (U32)maxBits; - } - /* Zero unused symbols in CTable, so we can check it for validity */ - { - size_t const ctableSize = HUF_CTABLE_SIZE_ST(maxSymbolValue); - size_t const unusedSize = sizeof(table->CTable) - ctableSize * sizeof(HUF_CElt); - ZSTD_memset(table->CTable + ctableSize, 0, unusedSize); + DEBUGLOG(6, "bit distribution completed (%zu symbols)", showCTableBits(table->CTable + 1, maxSymbolValue+1)); } /* Write table description header */ @@ -1263,7 +1416,7 @@ HUF_compress_internal (void* dst, size_t dstSize, if (oldSize <= hSize + newSize || hSize + 12 >= srcSize) { return HUF_compressCTable_internal(ostart, op, oend, src, srcSize, - nbStreams, oldHufTable, bmi2); + nbStreams, oldHufTable, flags); } } /* Use the new huffman table */ @@ -1275,61 +1428,35 @@ HUF_compress_internal (void* dst, size_t dstSize, } return HUF_compressCTable_internal(ostart, op, oend, src, srcSize, - nbStreams, table->CTable, bmi2); -} - - -size_t HUF_compress1X_wksp (void* dst, size_t dstSize, - const void* src, size_t srcSize, - unsigned maxSymbolValue, unsigned huffLog, - void* workSpace, size_t wkspSize) -{ - return HUF_compress_internal(dst, dstSize, src, srcSize, - maxSymbolValue, huffLog, HUF_singleStream, - workSpace, wkspSize, - NULL, NULL, 0, 0 /*bmi2*/, 0); + nbStreams, table->CTable, flags); } size_t HUF_compress1X_repeat (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned huffLog, void* workSpace, size_t wkspSize, - HUF_CElt* hufTable, HUF_repeat* repeat, int preferRepeat, - int bmi2, unsigned suspectUncompressible) + HUF_CElt* hufTable, HUF_repeat* repeat, int flags) { + DEBUGLOG(5, "HUF_compress1X_repeat (srcSize = %zu)", srcSize); return HUF_compress_internal(dst, dstSize, src, srcSize, maxSymbolValue, huffLog, HUF_singleStream, workSpace, wkspSize, hufTable, - repeat, preferRepeat, bmi2, suspectUncompressible); -} - -/* HUF_compress4X_repeat(): - * compress input using 4 streams. - * provide workspace to generate compression tables */ -size_t HUF_compress4X_wksp (void* dst, size_t dstSize, - const void* src, size_t srcSize, - unsigned maxSymbolValue, unsigned huffLog, - void* workSpace, size_t wkspSize) -{ - return HUF_compress_internal(dst, dstSize, src, srcSize, - maxSymbolValue, huffLog, HUF_fourStreams, - workSpace, wkspSize, - NULL, NULL, 0, 0 /*bmi2*/, 0); + repeat, flags); } /* HUF_compress4X_repeat(): * compress input using 4 streams. * consider skipping quickly - * re-use an existing huffman compression table */ + * reuse an existing huffman compression table */ size_t HUF_compress4X_repeat (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned huffLog, void* workSpace, size_t wkspSize, - HUF_CElt* hufTable, HUF_repeat* repeat, int preferRepeat, int bmi2, unsigned suspectUncompressible) + HUF_CElt* hufTable, HUF_repeat* repeat, int flags) { + DEBUGLOG(5, "HUF_compress4X_repeat (srcSize = %zu)", srcSize); return HUF_compress_internal(dst, dstSize, src, srcSize, maxSymbolValue, huffLog, HUF_fourStreams, workSpace, wkspSize, - hufTable, repeat, preferRepeat, bmi2, suspectUncompressible); + hufTable, repeat, flags); } - diff --git a/lib/zstd/compress/zstd_compress.c b/lib/zstd/compress/zstd_compress.c index 16bb995bc6c4..885167f7e47b 100644 --- a/lib/zstd/compress/zstd_compress.c +++ b/lib/zstd/compress/zstd_compress.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,12 +12,12 @@ /*-************************************* * Dependencies ***************************************/ +#include "../common/allocations.h" /* ZSTD_customMalloc, ZSTD_customCalloc, ZSTD_customFree */ #include "../common/zstd_deps.h" /* INT_MAX, ZSTD_memset, ZSTD_memcpy */ #include "../common/mem.h" #include "hist.h" /* HIST_countFast_wksp */ #define FSE_STATIC_LINKING_ONLY /* FSE_encodeSymbol */ #include "../common/fse.h" -#define HUF_STATIC_LINKING_ONLY #include "../common/huf.h" #include "zstd_compress_internal.h" #include "zstd_compress_sequences.h" @@ -27,6 +28,7 @@ #include "zstd_opt.h" #include "zstd_ldm.h" #include "zstd_compress_superblock.h" +#include "../common/bits.h" /* ZSTD_highbit32, ZSTD_rotateRight_U64 */ /* *************************************************************** * Tuning parameters @@ -55,14 +57,17 @@ * Helper functions ***************************************/ /* ZSTD_compressBound() - * Note that the result from this function is only compatible with the "normal" - * full-block strategy. - * When there are a lot of small blocks due to frequent flush in streaming mode - * the overhead of headers can make the compressed data to be larger than the - * return value of ZSTD_compressBound(). + * Note that the result from this function is only valid for + * the one-pass compression functions. + * When employing the streaming mode, + * if flushes are frequently altering the size of blocks, + * the overhead from block headers can make the compressed data larger + * than the return value of ZSTD_compressBound(). */ size_t ZSTD_compressBound(size_t srcSize) { - return ZSTD_COMPRESSBOUND(srcSize); + size_t const r = ZSTD_COMPRESSBOUND(srcSize); + if (r==0) return ERROR(srcSize_wrong); + return r; } @@ -168,15 +173,13 @@ static void ZSTD_freeCCtxContent(ZSTD_CCtx* cctx) size_t ZSTD_freeCCtx(ZSTD_CCtx* cctx) { + DEBUGLOG(3, "ZSTD_freeCCtx (address: %p)", (void*)cctx); if (cctx==NULL) return 0; /* support free on NULL */ RETURN_ERROR_IF(cctx->staticSize, memory_allocation, "not compatible with static CCtx"); - { - int cctxInWorkspace = ZSTD_cwksp_owns_buffer(&cctx->workspace, cctx); + { int cctxInWorkspace = ZSTD_cwksp_owns_buffer(&cctx->workspace, cctx); ZSTD_freeCCtxContent(cctx); - if (!cctxInWorkspace) { - ZSTD_customFree(cctx, cctx->customMem); - } + if (!cctxInWorkspace) ZSTD_customFree(cctx, cctx->customMem); } return 0; } @@ -257,9 +260,9 @@ static int ZSTD_allocateChainTable(const ZSTD_strategy strategy, return forDDSDict || ((strategy != ZSTD_fast) && !ZSTD_rowMatchFinderUsed(strategy, useRowMatchFinder)); } -/* Returns 1 if compression parameters are such that we should +/* Returns ZSTD_ps_enable if compression parameters are such that we should * enable long distance matching (wlog >= 27, strategy >= btopt). - * Returns 0 otherwise. + * Returns ZSTD_ps_disable otherwise. */ static ZSTD_paramSwitch_e ZSTD_resolveEnableLdm(ZSTD_paramSwitch_e mode, const ZSTD_compressionParameters* const cParams) { @@ -267,6 +270,34 @@ static ZSTD_paramSwitch_e ZSTD_resolveEnableLdm(ZSTD_paramSwitch_e mode, return (cParams->strategy >= ZSTD_btopt && cParams->windowLog >= 27) ? ZSTD_ps_enable : ZSTD_ps_disable; } +static int ZSTD_resolveExternalSequenceValidation(int mode) { + return mode; +} + +/* Resolves maxBlockSize to the default if no value is present. */ +static size_t ZSTD_resolveMaxBlockSize(size_t maxBlockSize) { + if (maxBlockSize == 0) { + return ZSTD_BLOCKSIZE_MAX; + } else { + return maxBlockSize; + } +} + +static ZSTD_paramSwitch_e ZSTD_resolveExternalRepcodeSearch(ZSTD_paramSwitch_e value, int cLevel) { + if (value != ZSTD_ps_auto) return value; + if (cLevel < 10) { + return ZSTD_ps_disable; + } else { + return ZSTD_ps_enable; + } +} + +/* Returns 1 if compression parameters are such that CDict hashtable and chaintable indices are tagged. + * If so, the tags need to be removed in ZSTD_resetCCtx_byCopyingCDict. */ +static int ZSTD_CDictIndicesAreTagged(const ZSTD_compressionParameters* const cParams) { + return cParams->strategy == ZSTD_fast || cParams->strategy == ZSTD_dfast; +} + static ZSTD_CCtx_params ZSTD_makeCCtxParamsFromCParams( ZSTD_compressionParameters cParams) { @@ -284,6 +315,10 @@ static ZSTD_CCtx_params ZSTD_makeCCtxParamsFromCParams( } cctxParams.useBlockSplitter = ZSTD_resolveBlockSplitterMode(cctxParams.useBlockSplitter, &cParams); cctxParams.useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(cctxParams.useRowMatchFinder, &cParams); + cctxParams.validateSequences = ZSTD_resolveExternalSequenceValidation(cctxParams.validateSequences); + cctxParams.maxBlockSize = ZSTD_resolveMaxBlockSize(cctxParams.maxBlockSize); + cctxParams.searchForExternalRepcodes = ZSTD_resolveExternalRepcodeSearch(cctxParams.searchForExternalRepcodes, + cctxParams.compressionLevel); assert(!ZSTD_checkCParams(cParams)); return cctxParams; } @@ -329,10 +364,13 @@ size_t ZSTD_CCtxParams_init(ZSTD_CCtx_params* cctxParams, int compressionLevel) #define ZSTD_NO_CLEVEL 0 /* - * Initializes the cctxParams from params and compressionLevel. + * Initializes `cctxParams` from `params` and `compressionLevel`. * @param compressionLevel If params are derived from a compression level then that compression level, otherwise ZSTD_NO_CLEVEL. */ -static void ZSTD_CCtxParams_init_internal(ZSTD_CCtx_params* cctxParams, ZSTD_parameters const* params, int compressionLevel) +static void +ZSTD_CCtxParams_init_internal(ZSTD_CCtx_params* cctxParams, + const ZSTD_parameters* params, + int compressionLevel) { assert(!ZSTD_checkCParams(params->cParams)); ZSTD_memset(cctxParams, 0, sizeof(*cctxParams)); @@ -345,6 +383,9 @@ static void ZSTD_CCtxParams_init_internal(ZSTD_CCtx_params* cctxParams, ZSTD_par cctxParams->useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(cctxParams->useRowMatchFinder, ¶ms->cParams); cctxParams->useBlockSplitter = ZSTD_resolveBlockSplitterMode(cctxParams->useBlockSplitter, ¶ms->cParams); cctxParams->ldmParams.enableLdm = ZSTD_resolveEnableLdm(cctxParams->ldmParams.enableLdm, ¶ms->cParams); + cctxParams->validateSequences = ZSTD_resolveExternalSequenceValidation(cctxParams->validateSequences); + cctxParams->maxBlockSize = ZSTD_resolveMaxBlockSize(cctxParams->maxBlockSize); + cctxParams->searchForExternalRepcodes = ZSTD_resolveExternalRepcodeSearch(cctxParams->searchForExternalRepcodes, compressionLevel); DEBUGLOG(4, "ZSTD_CCtxParams_init_internal: useRowMatchFinder=%d, useBlockSplitter=%d ldm=%d", cctxParams->useRowMatchFinder, cctxParams->useBlockSplitter, cctxParams->ldmParams.enableLdm); } @@ -359,7 +400,7 @@ size_t ZSTD_CCtxParams_init_advanced(ZSTD_CCtx_params* cctxParams, ZSTD_paramete /* * Sets cctxParams' cParams and fParams from params, but otherwise leaves them alone. - * @param param Validated zstd parameters. + * @param params Validated zstd parameters. */ static void ZSTD_CCtxParams_setZstdParams( ZSTD_CCtx_params* cctxParams, const ZSTD_parameters* params) @@ -455,8 +496,8 @@ ZSTD_bounds ZSTD_cParam_getBounds(ZSTD_cParameter param) return bounds; case ZSTD_c_enableLongDistanceMatching: - bounds.lowerBound = 0; - bounds.upperBound = 1; + bounds.lowerBound = (int)ZSTD_ps_auto; + bounds.upperBound = (int)ZSTD_ps_disable; return bounds; case ZSTD_c_ldmHashLog: @@ -549,6 +590,26 @@ ZSTD_bounds ZSTD_cParam_getBounds(ZSTD_cParameter param) bounds.upperBound = 1; return bounds; + case ZSTD_c_prefetchCDictTables: + bounds.lowerBound = (int)ZSTD_ps_auto; + bounds.upperBound = (int)ZSTD_ps_disable; + return bounds; + + case ZSTD_c_enableSeqProducerFallback: + bounds.lowerBound = 0; + bounds.upperBound = 1; + return bounds; + + case ZSTD_c_maxBlockSize: + bounds.lowerBound = ZSTD_BLOCKSIZE_MAX_MIN; + bounds.upperBound = ZSTD_BLOCKSIZE_MAX; + return bounds; + + case ZSTD_c_searchForExternalRepcodes: + bounds.lowerBound = (int)ZSTD_ps_auto; + bounds.upperBound = (int)ZSTD_ps_disable; + return bounds; + default: bounds.error = ERROR(parameter_unsupported); return bounds; @@ -567,10 +628,11 @@ static size_t ZSTD_cParam_clampBounds(ZSTD_cParameter cParam, int* value) return 0; } -#define BOUNDCHECK(cParam, val) { \ - RETURN_ERROR_IF(!ZSTD_cParam_withinBounds(cParam,val), \ - parameter_outOfBound, "Param out of bounds"); \ -} +#define BOUNDCHECK(cParam, val) \ + do { \ + RETURN_ERROR_IF(!ZSTD_cParam_withinBounds(cParam,val), \ + parameter_outOfBound, "Param out of bounds"); \ + } while (0) static int ZSTD_isUpdateAuthorized(ZSTD_cParameter param) @@ -613,6 +675,10 @@ static int ZSTD_isUpdateAuthorized(ZSTD_cParameter param) case ZSTD_c_useBlockSplitter: case ZSTD_c_useRowMatchFinder: case ZSTD_c_deterministicRefPrefix: + case ZSTD_c_prefetchCDictTables: + case ZSTD_c_enableSeqProducerFallback: + case ZSTD_c_maxBlockSize: + case ZSTD_c_searchForExternalRepcodes: default: return 0; } @@ -625,7 +691,7 @@ size_t ZSTD_CCtx_setParameter(ZSTD_CCtx* cctx, ZSTD_cParameter param, int value) if (ZSTD_isUpdateAuthorized(param)) { cctx->cParamsChanged = 1; } else { - RETURN_ERROR(stage_wrong, "can only set params in ctx init stage"); + RETURN_ERROR(stage_wrong, "can only set params in cctx init stage"); } } switch(param) @@ -668,6 +734,10 @@ size_t ZSTD_CCtx_setParameter(ZSTD_CCtx* cctx, ZSTD_cParameter param, int value) case ZSTD_c_useBlockSplitter: case ZSTD_c_useRowMatchFinder: case ZSTD_c_deterministicRefPrefix: + case ZSTD_c_prefetchCDictTables: + case ZSTD_c_enableSeqProducerFallback: + case ZSTD_c_maxBlockSize: + case ZSTD_c_searchForExternalRepcodes: break; default: RETURN_ERROR(parameter_unsupported, "unknown parameter"); @@ -723,12 +793,12 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, case ZSTD_c_minMatch : if (value!=0) /* 0 => use default */ BOUNDCHECK(ZSTD_c_minMatch, value); - CCtxParams->cParams.minMatch = value; + CCtxParams->cParams.minMatch = (U32)value; return CCtxParams->cParams.minMatch; case ZSTD_c_targetLength : BOUNDCHECK(ZSTD_c_targetLength, value); - CCtxParams->cParams.targetLength = value; + CCtxParams->cParams.targetLength = (U32)value; return CCtxParams->cParams.targetLength; case ZSTD_c_strategy : @@ -741,12 +811,12 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, /* Content size written in frame header _when known_ (default:1) */ DEBUGLOG(4, "set content size flag = %u", (value!=0)); CCtxParams->fParams.contentSizeFlag = value != 0; - return CCtxParams->fParams.contentSizeFlag; + return (size_t)CCtxParams->fParams.contentSizeFlag; case ZSTD_c_checksumFlag : /* A 32-bits content checksum will be calculated and written at end of frame (default:0) */ CCtxParams->fParams.checksumFlag = value != 0; - return CCtxParams->fParams.checksumFlag; + return (size_t)CCtxParams->fParams.checksumFlag; case ZSTD_c_dictIDFlag : /* When applicable, dictionary's dictID is provided in frame header (default:1) */ DEBUGLOG(4, "set dictIDFlag = %u", (value!=0)); @@ -755,18 +825,18 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, case ZSTD_c_forceMaxWindow : CCtxParams->forceWindow = (value != 0); - return CCtxParams->forceWindow; + return (size_t)CCtxParams->forceWindow; case ZSTD_c_forceAttachDict : { const ZSTD_dictAttachPref_e pref = (ZSTD_dictAttachPref_e)value; - BOUNDCHECK(ZSTD_c_forceAttachDict, pref); + BOUNDCHECK(ZSTD_c_forceAttachDict, (int)pref); CCtxParams->attachDictPref = pref; return CCtxParams->attachDictPref; } case ZSTD_c_literalCompressionMode : { const ZSTD_paramSwitch_e lcm = (ZSTD_paramSwitch_e)value; - BOUNDCHECK(ZSTD_c_literalCompressionMode, lcm); + BOUNDCHECK(ZSTD_c_literalCompressionMode, (int)lcm); CCtxParams->literalCompressionMode = lcm; return CCtxParams->literalCompressionMode; } @@ -789,47 +859,50 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, case ZSTD_c_enableDedicatedDictSearch : CCtxParams->enableDedicatedDictSearch = (value!=0); - return CCtxParams->enableDedicatedDictSearch; + return (size_t)CCtxParams->enableDedicatedDictSearch; case ZSTD_c_enableLongDistanceMatching : + BOUNDCHECK(ZSTD_c_enableLongDistanceMatching, value); CCtxParams->ldmParams.enableLdm = (ZSTD_paramSwitch_e)value; return CCtxParams->ldmParams.enableLdm; case ZSTD_c_ldmHashLog : if (value!=0) /* 0 ==> auto */ BOUNDCHECK(ZSTD_c_ldmHashLog, value); - CCtxParams->ldmParams.hashLog = value; + CCtxParams->ldmParams.hashLog = (U32)value; return CCtxParams->ldmParams.hashLog; case ZSTD_c_ldmMinMatch : if (value!=0) /* 0 ==> default */ BOUNDCHECK(ZSTD_c_ldmMinMatch, value); - CCtxParams->ldmParams.minMatchLength = value; + CCtxParams->ldmParams.minMatchLength = (U32)value; return CCtxParams->ldmParams.minMatchLength; case ZSTD_c_ldmBucketSizeLog : if (value!=0) /* 0 ==> default */ BOUNDCHECK(ZSTD_c_ldmBucketSizeLog, value); - CCtxParams->ldmParams.bucketSizeLog = value; + CCtxParams->ldmParams.bucketSizeLog = (U32)value; return CCtxParams->ldmParams.bucketSizeLog; case ZSTD_c_ldmHashRateLog : if (value!=0) /* 0 ==> default */ BOUNDCHECK(ZSTD_c_ldmHashRateLog, value); - CCtxParams->ldmParams.hashRateLog = value; + CCtxParams->ldmParams.hashRateLog = (U32)value; return CCtxParams->ldmParams.hashRateLog; case ZSTD_c_targetCBlockSize : - if (value!=0) /* 0 ==> default */ + if (value!=0) { /* 0 ==> default */ + value = MAX(value, ZSTD_TARGETCBLOCKSIZE_MIN); BOUNDCHECK(ZSTD_c_targetCBlockSize, value); - CCtxParams->targetCBlockSize = value; + } + CCtxParams->targetCBlockSize = (U32)value; return CCtxParams->targetCBlockSize; case ZSTD_c_srcSizeHint : if (value!=0) /* 0 ==> default */ BOUNDCHECK(ZSTD_c_srcSizeHint, value); CCtxParams->srcSizeHint = value; - return CCtxParams->srcSizeHint; + return (size_t)CCtxParams->srcSizeHint; case ZSTD_c_stableInBuffer: BOUNDCHECK(ZSTD_c_stableInBuffer, value); @@ -849,7 +922,7 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, case ZSTD_c_validateSequences: BOUNDCHECK(ZSTD_c_validateSequences, value); CCtxParams->validateSequences = value; - return CCtxParams->validateSequences; + return (size_t)CCtxParams->validateSequences; case ZSTD_c_useBlockSplitter: BOUNDCHECK(ZSTD_c_useBlockSplitter, value); @@ -864,7 +937,28 @@ size_t ZSTD_CCtxParams_setParameter(ZSTD_CCtx_params* CCtxParams, case ZSTD_c_deterministicRefPrefix: BOUNDCHECK(ZSTD_c_deterministicRefPrefix, value); CCtxParams->deterministicRefPrefix = !!value; - return CCtxParams->deterministicRefPrefix; + return (size_t)CCtxParams->deterministicRefPrefix; + + case ZSTD_c_prefetchCDictTables: + BOUNDCHECK(ZSTD_c_prefetchCDictTables, value); + CCtxParams->prefetchCDictTables = (ZSTD_paramSwitch_e)value; + return CCtxParams->prefetchCDictTables; + + case ZSTD_c_enableSeqProducerFallback: + BOUNDCHECK(ZSTD_c_enableSeqProducerFallback, value); + CCtxParams->enableMatchFinderFallback = value; + return (size_t)CCtxParams->enableMatchFinderFallback; + + case ZSTD_c_maxBlockSize: + if (value!=0) /* 0 ==> default */ + BOUNDCHECK(ZSTD_c_maxBlockSize, value); + CCtxParams->maxBlockSize = value; + return CCtxParams->maxBlockSize; + + case ZSTD_c_searchForExternalRepcodes: + BOUNDCHECK(ZSTD_c_searchForExternalRepcodes, value); + CCtxParams->searchForExternalRepcodes = (ZSTD_paramSwitch_e)value; + return CCtxParams->searchForExternalRepcodes; default: RETURN_ERROR(parameter_unsupported, "unknown parameter"); } @@ -980,6 +1074,18 @@ size_t ZSTD_CCtxParams_getParameter( case ZSTD_c_deterministicRefPrefix: *value = (int)CCtxParams->deterministicRefPrefix; break; + case ZSTD_c_prefetchCDictTables: + *value = (int)CCtxParams->prefetchCDictTables; + break; + case ZSTD_c_enableSeqProducerFallback: + *value = CCtxParams->enableMatchFinderFallback; + break; + case ZSTD_c_maxBlockSize: + *value = (int)CCtxParams->maxBlockSize; + break; + case ZSTD_c_searchForExternalRepcodes: + *value = (int)CCtxParams->searchForExternalRepcodes; + break; default: RETURN_ERROR(parameter_unsupported, "unknown parameter"); } return 0; @@ -1006,9 +1112,47 @@ size_t ZSTD_CCtx_setParametersUsingCCtxParams( return 0; } +size_t ZSTD_CCtx_setCParams(ZSTD_CCtx* cctx, ZSTD_compressionParameters cparams) +{ + ZSTD_STATIC_ASSERT(sizeof(cparams) == 7 * 4 /* all params are listed below */); + DEBUGLOG(4, "ZSTD_CCtx_setCParams"); + /* only update if all parameters are valid */ + FORWARD_IF_ERROR(ZSTD_checkCParams(cparams), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_windowLog, cparams.windowLog), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_chainLog, cparams.chainLog), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_hashLog, cparams.hashLog), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_searchLog, cparams.searchLog), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_minMatch, cparams.minMatch), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_targetLength, cparams.targetLength), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_strategy, cparams.strategy), ""); + return 0; +} + +size_t ZSTD_CCtx_setFParams(ZSTD_CCtx* cctx, ZSTD_frameParameters fparams) +{ + ZSTD_STATIC_ASSERT(sizeof(fparams) == 3 * 4 /* all params are listed below */); + DEBUGLOG(4, "ZSTD_CCtx_setFParams"); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_contentSizeFlag, fparams.contentSizeFlag != 0), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_checksumFlag, fparams.checksumFlag != 0), ""); + FORWARD_IF_ERROR(ZSTD_CCtx_setParameter(cctx, ZSTD_c_dictIDFlag, fparams.noDictIDFlag == 0), ""); + return 0; +} + +size_t ZSTD_CCtx_setParams(ZSTD_CCtx* cctx, ZSTD_parameters params) +{ + DEBUGLOG(4, "ZSTD_CCtx_setParams"); + /* First check cParams, because we want to update all or none. */ + FORWARD_IF_ERROR(ZSTD_checkCParams(params.cParams), ""); + /* Next set fParams, because this could fail if the cctx isn't in init stage. */ + FORWARD_IF_ERROR(ZSTD_CCtx_setFParams(cctx, params.fParams), ""); + /* Finally set cParams, which should succeed. */ + FORWARD_IF_ERROR(ZSTD_CCtx_setCParams(cctx, params.cParams), ""); + return 0; +} + size_t ZSTD_CCtx_setPledgedSrcSize(ZSTD_CCtx* cctx, unsigned long long pledgedSrcSize) { - DEBUGLOG(4, "ZSTD_CCtx_setPledgedSrcSize to %u bytes", (U32)pledgedSrcSize); + DEBUGLOG(4, "ZSTD_CCtx_setPledgedSrcSize to %llu bytes", pledgedSrcSize); RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong, "Can't set pledgedSrcSize when not in init stage."); cctx->pledgedSrcSizePlusOne = pledgedSrcSize+1; @@ -1024,9 +1168,9 @@ static void ZSTD_dedicatedDictSearch_revertCParams( ZSTD_compressionParameters* cParams); /* - * Initializes the local dict using the requested parameters. - * NOTE: This does not use the pledged src size, because it may be used for more - * than one compression. + * Initializes the local dictionary using requested parameters. + * NOTE: Initialization does not employ the pledged src size, + * because the dictionary may be used for multiple compressions. */ static size_t ZSTD_initLocalDict(ZSTD_CCtx* cctx) { @@ -1039,8 +1183,8 @@ static size_t ZSTD_initLocalDict(ZSTD_CCtx* cctx) return 0; } if (dl->cdict != NULL) { - assert(cctx->cdict == dl->cdict); /* Local dictionary already initialized. */ + assert(cctx->cdict == dl->cdict); return 0; } assert(dl->dictSize > 0); @@ -1060,26 +1204,30 @@ static size_t ZSTD_initLocalDict(ZSTD_CCtx* cctx) } size_t ZSTD_CCtx_loadDictionary_advanced( - ZSTD_CCtx* cctx, const void* dict, size_t dictSize, - ZSTD_dictLoadMethod_e dictLoadMethod, ZSTD_dictContentType_e dictContentType) + ZSTD_CCtx* cctx, + const void* dict, size_t dictSize, + ZSTD_dictLoadMethod_e dictLoadMethod, + ZSTD_dictContentType_e dictContentType) { - RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong, - "Can't load a dictionary when ctx is not in init stage."); DEBUGLOG(4, "ZSTD_CCtx_loadDictionary_advanced (size: %u)", (U32)dictSize); - ZSTD_clearAllDicts(cctx); /* in case one already exists */ - if (dict == NULL || dictSize == 0) /* no dictionary mode */ + RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong, + "Can't load a dictionary when cctx is not in init stage."); + ZSTD_clearAllDicts(cctx); /* erase any previously set dictionary */ + if (dict == NULL || dictSize == 0) /* no dictionary */ return 0; if (dictLoadMethod == ZSTD_dlm_byRef) { cctx->localDict.dict = dict; } else { + /* copy dictionary content inside CCtx to own its lifetime */ void* dictBuffer; RETURN_ERROR_IF(cctx->staticSize, memory_allocation, - "no malloc for static CCtx"); + "static CCtx can't allocate for an internal copy of dictionary"); dictBuffer = ZSTD_customMalloc(dictSize, cctx->customMem); - RETURN_ERROR_IF(!dictBuffer, memory_allocation, "NULL pointer!"); + RETURN_ERROR_IF(dictBuffer==NULL, memory_allocation, + "allocation failed for dictionary content"); ZSTD_memcpy(dictBuffer, dict, dictSize); - cctx->localDict.dictBuffer = dictBuffer; - cctx->localDict.dict = dictBuffer; + cctx->localDict.dictBuffer = dictBuffer; /* owned ptr to free */ + cctx->localDict.dict = dictBuffer; /* read-only reference */ } cctx->localDict.dictSize = dictSize; cctx->localDict.dictContentType = dictContentType; @@ -1149,7 +1297,7 @@ size_t ZSTD_CCtx_reset(ZSTD_CCtx* cctx, ZSTD_ResetDirective reset) if ( (reset == ZSTD_reset_parameters) || (reset == ZSTD_reset_session_and_parameters) ) { RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong, - "Can't reset parameters only when not in init stage."); + "Reset parameters is only possible during init stage."); ZSTD_clearAllDicts(cctx); return ZSTD_CCtxParams_reset(&cctx->requestedParams); } @@ -1178,11 +1326,12 @@ size_t ZSTD_checkCParams(ZSTD_compressionParameters cParams) static ZSTD_compressionParameters ZSTD_clampCParams(ZSTD_compressionParameters cParams) { -# define CLAMP_TYPE(cParam, val, type) { \ - ZSTD_bounds const bounds = ZSTD_cParam_getBounds(cParam); \ - if ((int)valbounds.upperBound) val=(type)bounds.upperBound; \ - } +# define CLAMP_TYPE(cParam, val, type) \ + do { \ + ZSTD_bounds const bounds = ZSTD_cParam_getBounds(cParam); \ + if ((int)valbounds.upperBound) val=(type)bounds.upperBound; \ + } while (0) # define CLAMP(cParam, val) CLAMP_TYPE(cParam, val, unsigned) CLAMP(ZSTD_c_windowLog, cParams.windowLog); CLAMP(ZSTD_c_chainLog, cParams.chainLog); @@ -1247,12 +1396,55 @@ static ZSTD_compressionParameters ZSTD_adjustCParams_internal(ZSTD_compressionParameters cPar, unsigned long long srcSize, size_t dictSize, - ZSTD_cParamMode_e mode) + ZSTD_cParamMode_e mode, + ZSTD_paramSwitch_e useRowMatchFinder) { const U64 minSrcSize = 513; /* (1<<9) + 1 */ const U64 maxWindowResize = 1ULL << (ZSTD_WINDOWLOG_MAX-1); assert(ZSTD_checkCParams(cPar)==0); + /* Cascade the selected strategy down to the next-highest one built into + * this binary. */ +#ifdef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_btultra2) { + cPar.strategy = ZSTD_btultra; + } + if (cPar.strategy == ZSTD_btultra) { + cPar.strategy = ZSTD_btopt; + } +#endif +#ifdef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_btopt) { + cPar.strategy = ZSTD_btlazy2; + } +#endif +#ifdef ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_btlazy2) { + cPar.strategy = ZSTD_lazy2; + } +#endif +#ifdef ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_lazy2) { + cPar.strategy = ZSTD_lazy; + } +#endif +#ifdef ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_lazy) { + cPar.strategy = ZSTD_greedy; + } +#endif +#ifdef ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_greedy) { + cPar.strategy = ZSTD_dfast; + } +#endif +#ifdef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR + if (cPar.strategy == ZSTD_dfast) { + cPar.strategy = ZSTD_fast; + cPar.targetLength = 0; + } +#endif + switch (mode) { case ZSTD_cpm_unknown: case ZSTD_cpm_noAttachDict: @@ -1281,8 +1473,8 @@ ZSTD_adjustCParams_internal(ZSTD_compressionParameters cPar, } /* resize windowLog if input is small enough, to use less memory */ - if ( (srcSize < maxWindowResize) - && (dictSize < maxWindowResize) ) { + if ( (srcSize <= maxWindowResize) + && (dictSize <= maxWindowResize) ) { U32 const tSize = (U32)(srcSize + dictSize); static U32 const hashSizeMin = 1 << ZSTD_HASHLOG_MIN; U32 const srcLog = (tSize < hashSizeMin) ? ZSTD_HASHLOG_MIN : @@ -1300,6 +1492,42 @@ ZSTD_adjustCParams_internal(ZSTD_compressionParameters cPar, if (cPar.windowLog < ZSTD_WINDOWLOG_ABSOLUTEMIN) cPar.windowLog = ZSTD_WINDOWLOG_ABSOLUTEMIN; /* minimum wlog required for valid frame header */ + /* We can't use more than 32 bits of hash in total, so that means that we require: + * (hashLog + 8) <= 32 && (chainLog + 8) <= 32 + */ + if (mode == ZSTD_cpm_createCDict && ZSTD_CDictIndicesAreTagged(&cPar)) { + U32 const maxShortCacheHashLog = 32 - ZSTD_SHORT_CACHE_TAG_BITS; + if (cPar.hashLog > maxShortCacheHashLog) { + cPar.hashLog = maxShortCacheHashLog; + } + if (cPar.chainLog > maxShortCacheHashLog) { + cPar.chainLog = maxShortCacheHashLog; + } + } + + + /* At this point, we aren't 100% sure if we are using the row match finder. + * Unless it is explicitly disabled, conservatively assume that it is enabled. + * In this case it will only be disabled for small sources, so shrinking the + * hash log a little bit shouldn't result in any ratio loss. + */ + if (useRowMatchFinder == ZSTD_ps_auto) + useRowMatchFinder = ZSTD_ps_enable; + + /* We can't hash more than 32-bits in total. So that means that we require: + * (hashLog - rowLog + 8) <= 32 + */ + if (ZSTD_rowMatchFinderUsed(cPar.strategy, useRowMatchFinder)) { + /* Switch to 32-entry rows if searchLog is 5 (or more) */ + U32 const rowLog = BOUNDED(4, cPar.searchLog, 6); + U32 const maxRowHashLog = 32 - ZSTD_ROW_HASH_TAG_BITS; + U32 const maxHashLog = maxRowHashLog + rowLog; + assert(cPar.hashLog >= rowLog); + if (cPar.hashLog > maxHashLog) { + cPar.hashLog = maxHashLog; + } + } + return cPar; } @@ -1310,7 +1538,7 @@ ZSTD_adjustCParams(ZSTD_compressionParameters cPar, { cPar = ZSTD_clampCParams(cPar); /* resulting cPar is necessarily valid (all parameters within range) */ if (srcSize == 0) srcSize = ZSTD_CONTENTSIZE_UNKNOWN; - return ZSTD_adjustCParams_internal(cPar, srcSize, dictSize, ZSTD_cpm_unknown); + return ZSTD_adjustCParams_internal(cPar, srcSize, dictSize, ZSTD_cpm_unknown, ZSTD_ps_auto); } static ZSTD_compressionParameters ZSTD_getCParams_internal(int compressionLevel, unsigned long long srcSizeHint, size_t dictSize, ZSTD_cParamMode_e mode); @@ -1341,7 +1569,7 @@ ZSTD_compressionParameters ZSTD_getCParamsFromCCtxParams( ZSTD_overrideCParams(&cParams, &CCtxParams->cParams); assert(!ZSTD_checkCParams(cParams)); /* srcSizeHint == 0 means 0 */ - return ZSTD_adjustCParams_internal(cParams, srcSizeHint, dictSize, mode); + return ZSTD_adjustCParams_internal(cParams, srcSizeHint, dictSize, mode, CCtxParams->useRowMatchFinder); } static size_t @@ -1367,10 +1595,10 @@ ZSTD_sizeof_matchState(const ZSTD_compressionParameters* const cParams, + ZSTD_cwksp_aligned_alloc_size((MaxLL+1) * sizeof(U32)) + ZSTD_cwksp_aligned_alloc_size((MaxOff+1) * sizeof(U32)) + ZSTD_cwksp_aligned_alloc_size((1<strategy, useRowMatchFinder) - ? ZSTD_cwksp_aligned_alloc_size(hSize*sizeof(U16)) + ? ZSTD_cwksp_aligned_alloc_size(hSize) : 0; size_t const optSpace = (forCCtx && (cParams->strategy >= ZSTD_btopt)) ? optPotentialSpace @@ -1386,6 +1614,13 @@ ZSTD_sizeof_matchState(const ZSTD_compressionParameters* const cParams, return tableSpace + optSpace + slackSpace + lazyAdditionalSpace; } +/* Helper function for calculating memory requirements. + * Gives a tighter bound than ZSTD_sequenceBound() by taking minMatch into account. */ +static size_t ZSTD_maxNbSeq(size_t blockSize, unsigned minMatch, int useSequenceProducer) { + U32 const divider = (minMatch==3 || useSequenceProducer) ? 3 : 4; + return blockSize / divider; +} + static size_t ZSTD_estimateCCtxSize_usingCCtxParams_internal( const ZSTD_compressionParameters* cParams, const ldmParams_t* ldmParams, @@ -1393,12 +1628,13 @@ static size_t ZSTD_estimateCCtxSize_usingCCtxParams_internal( const ZSTD_paramSwitch_e useRowMatchFinder, const size_t buffInSize, const size_t buffOutSize, - const U64 pledgedSrcSize) + const U64 pledgedSrcSize, + int useSequenceProducer, + size_t maxBlockSize) { size_t const windowSize = (size_t) BOUNDED(1ULL, 1ULL << cParams->windowLog, pledgedSrcSize); - size_t const blockSize = MIN(ZSTD_BLOCKSIZE_MAX, windowSize); - U32 const divider = (cParams->minMatch==3) ? 3 : 4; - size_t const maxNbSeq = blockSize / divider; + size_t const blockSize = MIN(ZSTD_resolveMaxBlockSize(maxBlockSize), windowSize); + size_t const maxNbSeq = ZSTD_maxNbSeq(blockSize, cParams->minMatch, useSequenceProducer); size_t const tokenSpace = ZSTD_cwksp_alloc_size(WILDCOPY_OVERLENGTH + blockSize) + ZSTD_cwksp_aligned_alloc_size(maxNbSeq * sizeof(seqDef)) + 3 * ZSTD_cwksp_alloc_size(maxNbSeq * sizeof(BYTE)); @@ -1417,6 +1653,11 @@ static size_t ZSTD_estimateCCtxSize_usingCCtxParams_internal( size_t const cctxSpace = isStatic ? ZSTD_cwksp_alloc_size(sizeof(ZSTD_CCtx)) : 0; + size_t const maxNbExternalSeq = ZSTD_sequenceBound(blockSize); + size_t const externalSeqSpace = useSequenceProducer + ? ZSTD_cwksp_aligned_alloc_size(maxNbExternalSeq * sizeof(ZSTD_Sequence)) + : 0; + size_t const neededSpace = cctxSpace + entropySpace + @@ -1425,7 +1666,8 @@ static size_t ZSTD_estimateCCtxSize_usingCCtxParams_internal( ldmSeqSpace + matchStateSize + tokenSpace + - bufferSpace; + bufferSpace + + externalSeqSpace; DEBUGLOG(5, "estimate workspace : %u", (U32)neededSpace); return neededSpace; @@ -1443,7 +1685,7 @@ size_t ZSTD_estimateCCtxSize_usingCCtxParams(const ZSTD_CCtx_params* params) * be needed. However, we still allocate two 0-sized buffers, which can * take space under ASAN. */ return ZSTD_estimateCCtxSize_usingCCtxParams_internal( - &cParams, ¶ms->ldmParams, 1, useRowMatchFinder, 0, 0, ZSTD_CONTENTSIZE_UNKNOWN); + &cParams, ¶ms->ldmParams, 1, useRowMatchFinder, 0, 0, ZSTD_CONTENTSIZE_UNKNOWN, ZSTD_hasExtSeqProd(params), params->maxBlockSize); } size_t ZSTD_estimateCCtxSize_usingCParams(ZSTD_compressionParameters cParams) @@ -1493,7 +1735,7 @@ size_t ZSTD_estimateCStreamSize_usingCCtxParams(const ZSTD_CCtx_params* params) RETURN_ERROR_IF(params->nbWorkers > 0, GENERIC, "Estimate CCtx size is supported for single-threaded compression only."); { ZSTD_compressionParameters const cParams = ZSTD_getCParamsFromCCtxParams(params, ZSTD_CONTENTSIZE_UNKNOWN, 0, ZSTD_cpm_noAttachDict); - size_t const blockSize = MIN(ZSTD_BLOCKSIZE_MAX, (size_t)1 << cParams.windowLog); + size_t const blockSize = MIN(ZSTD_resolveMaxBlockSize(params->maxBlockSize), (size_t)1 << cParams.windowLog); size_t const inBuffSize = (params->inBufferMode == ZSTD_bm_buffered) ? ((size_t)1 << cParams.windowLog) + blockSize : 0; @@ -1504,7 +1746,7 @@ size_t ZSTD_estimateCStreamSize_usingCCtxParams(const ZSTD_CCtx_params* params) return ZSTD_estimateCCtxSize_usingCCtxParams_internal( &cParams, ¶ms->ldmParams, 1, useRowMatchFinder, inBuffSize, outBuffSize, - ZSTD_CONTENTSIZE_UNKNOWN); + ZSTD_CONTENTSIZE_UNKNOWN, ZSTD_hasExtSeqProd(params), params->maxBlockSize); } } @@ -1637,6 +1879,19 @@ typedef enum { ZSTD_resetTarget_CCtx } ZSTD_resetTarget_e; +/* Mixes bits in a 64 bits in a value, based on XXH3_rrmxmx */ +static U64 ZSTD_bitmix(U64 val, U64 len) { + val ^= ZSTD_rotateRight_U64(val, 49) ^ ZSTD_rotateRight_U64(val, 24); + val *= 0x9FB21C651E98DF25ULL; + val ^= (val >> 35) + len ; + val *= 0x9FB21C651E98DF25ULL; + return val ^ (val >> 28); +} + +/* Mixes in the hashSalt and hashSaltEntropy to create a new hashSalt */ +static void ZSTD_advanceHashSalt(ZSTD_matchState_t* ms) { + ms->hashSalt = ZSTD_bitmix(ms->hashSalt, 8) ^ ZSTD_bitmix((U64) ms->hashSaltEntropy, 4); +} static size_t ZSTD_reset_matchState(ZSTD_matchState_t* ms, @@ -1664,6 +1919,7 @@ ZSTD_reset_matchState(ZSTD_matchState_t* ms, } ms->hashLog3 = hashLog3; + ms->lazySkipping = 0; ZSTD_invalidateMatchState(ms); @@ -1685,22 +1941,19 @@ ZSTD_reset_matchState(ZSTD_matchState_t* ms, ZSTD_cwksp_clean_tables(ws); } - /* opt parser space */ - if ((forWho == ZSTD_resetTarget_CCtx) && (cParams->strategy >= ZSTD_btopt)) { - DEBUGLOG(4, "reserving optimal parser space"); - ms->opt.litFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (1<opt.litLengthFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxLL+1) * sizeof(unsigned)); - ms->opt.matchLengthFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxML+1) * sizeof(unsigned)); - ms->opt.offCodeFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxOff+1) * sizeof(unsigned)); - ms->opt.matchTable = (ZSTD_match_t*)ZSTD_cwksp_reserve_aligned(ws, (ZSTD_OPT_NUM+1) * sizeof(ZSTD_match_t)); - ms->opt.priceTable = (ZSTD_optimal_t*)ZSTD_cwksp_reserve_aligned(ws, (ZSTD_OPT_NUM+1) * sizeof(ZSTD_optimal_t)); - } - if (ZSTD_rowMatchFinderUsed(cParams->strategy, useRowMatchFinder)) { - { /* Row match finder needs an additional table of hashes ("tags") */ - size_t const tagTableSize = hSize*sizeof(U16); - ms->tagTable = (U16*)ZSTD_cwksp_reserve_aligned(ws, tagTableSize); - if (ms->tagTable) ZSTD_memset(ms->tagTable, 0, tagTableSize); + /* Row match finder needs an additional table of hashes ("tags") */ + size_t const tagTableSize = hSize; + /* We want to generate a new salt in case we reset a Cctx, but we always want to use + * 0 when we reset a Cdict */ + if(forWho == ZSTD_resetTarget_CCtx) { + ms->tagTable = (BYTE*) ZSTD_cwksp_reserve_aligned_init_once(ws, tagTableSize); + ZSTD_advanceHashSalt(ms); + } else { + /* When we are not salting we want to always memset the memory */ + ms->tagTable = (BYTE*) ZSTD_cwksp_reserve_aligned(ws, tagTableSize); + ZSTD_memset(ms->tagTable, 0, tagTableSize); + ms->hashSalt = 0; } { /* Switch to 32-entry rows if searchLog is 5 (or more) */ U32 const rowLog = BOUNDED(4, cParams->searchLog, 6); @@ -1709,6 +1962,17 @@ ZSTD_reset_matchState(ZSTD_matchState_t* ms, } } + /* opt parser space */ + if ((forWho == ZSTD_resetTarget_CCtx) && (cParams->strategy >= ZSTD_btopt)) { + DEBUGLOG(4, "reserving optimal parser space"); + ms->opt.litFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (1<opt.litLengthFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxLL+1) * sizeof(unsigned)); + ms->opt.matchLengthFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxML+1) * sizeof(unsigned)); + ms->opt.offCodeFreq = (unsigned*)ZSTD_cwksp_reserve_aligned(ws, (MaxOff+1) * sizeof(unsigned)); + ms->opt.matchTable = (ZSTD_match_t*)ZSTD_cwksp_reserve_aligned(ws, ZSTD_OPT_SIZE * sizeof(ZSTD_match_t)); + ms->opt.priceTable = (ZSTD_optimal_t*)ZSTD_cwksp_reserve_aligned(ws, ZSTD_OPT_SIZE * sizeof(ZSTD_optimal_t)); + } + ms->cParams = *cParams; RETURN_ERROR_IF(ZSTD_cwksp_reserve_failed(ws), memory_allocation, @@ -1768,6 +2032,7 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, assert(params->useRowMatchFinder != ZSTD_ps_auto); assert(params->useBlockSplitter != ZSTD_ps_auto); assert(params->ldmParams.enableLdm != ZSTD_ps_auto); + assert(params->maxBlockSize != 0); if (params->ldmParams.enableLdm == ZSTD_ps_enable) { /* Adjust long distance matching parameters */ ZSTD_ldm_adjustParameters(&zc->appliedParams.ldmParams, ¶ms->cParams); @@ -1776,9 +2041,8 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, } { size_t const windowSize = MAX(1, (size_t)MIN(((U64)1 << params->cParams.windowLog), pledgedSrcSize)); - size_t const blockSize = MIN(ZSTD_BLOCKSIZE_MAX, windowSize); - U32 const divider = (params->cParams.minMatch==3) ? 3 : 4; - size_t const maxNbSeq = blockSize / divider; + size_t const blockSize = MIN(params->maxBlockSize, windowSize); + size_t const maxNbSeq = ZSTD_maxNbSeq(blockSize, params->cParams.minMatch, ZSTD_hasExtSeqProd(params)); size_t const buffOutSize = (zbuff == ZSTDb_buffered && params->outBufferMode == ZSTD_bm_buffered) ? ZSTD_compressBound(blockSize) + 1 : 0; @@ -1795,8 +2059,7 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, size_t const neededSpace = ZSTD_estimateCCtxSize_usingCCtxParams_internal( ¶ms->cParams, ¶ms->ldmParams, zc->staticSize != 0, params->useRowMatchFinder, - buffInSize, buffOutSize, pledgedSrcSize); - int resizeWorkspace; + buffInSize, buffOutSize, pledgedSrcSize, ZSTD_hasExtSeqProd(params), params->maxBlockSize); FORWARD_IF_ERROR(neededSpace, "cctx size estimate failed!"); @@ -1805,7 +2068,7 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, { /* Check if workspace is large enough, alloc a new one if needed */ int const workspaceTooSmall = ZSTD_cwksp_sizeof(ws) < neededSpace; int const workspaceWasteful = ZSTD_cwksp_check_wasteful(ws, neededSpace); - resizeWorkspace = workspaceTooSmall || workspaceWasteful; + int resizeWorkspace = workspaceTooSmall || workspaceWasteful; DEBUGLOG(4, "Need %zu B workspace", neededSpace); DEBUGLOG(4, "windowSize: %zu - blockSize: %zu", windowSize, blockSize); @@ -1838,6 +2101,7 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, /* init params */ zc->blockState.matchState.cParams = params->cParams; + zc->blockState.matchState.prefetchCDictTables = params->prefetchCDictTables == ZSTD_ps_enable; zc->pledgedSrcSizePlusOne = pledgedSrcSize+1; zc->consumedSrcSize = 0; zc->producedCSize = 0; @@ -1854,13 +2118,46 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, ZSTD_reset_compressedBlockState(zc->blockState.prevCBlock); + FORWARD_IF_ERROR(ZSTD_reset_matchState( + &zc->blockState.matchState, + ws, + ¶ms->cParams, + params->useRowMatchFinder, + crp, + needsIndexReset, + ZSTD_resetTarget_CCtx), ""); + + zc->seqStore.sequencesStart = (seqDef*)ZSTD_cwksp_reserve_aligned(ws, maxNbSeq * sizeof(seqDef)); + + /* ldm hash table */ + if (params->ldmParams.enableLdm == ZSTD_ps_enable) { + /* TODO: avoid memset? */ + size_t const ldmHSize = ((size_t)1) << params->ldmParams.hashLog; + zc->ldmState.hashTable = (ldmEntry_t*)ZSTD_cwksp_reserve_aligned(ws, ldmHSize * sizeof(ldmEntry_t)); + ZSTD_memset(zc->ldmState.hashTable, 0, ldmHSize * sizeof(ldmEntry_t)); + zc->ldmSequences = (rawSeq*)ZSTD_cwksp_reserve_aligned(ws, maxNbLdmSeq * sizeof(rawSeq)); + zc->maxNbLdmSequences = maxNbLdmSeq; + + ZSTD_window_init(&zc->ldmState.window); + zc->ldmState.loadedDictEnd = 0; + } + + /* reserve space for block-level external sequences */ + if (ZSTD_hasExtSeqProd(params)) { + size_t const maxNbExternalSeq = ZSTD_sequenceBound(blockSize); + zc->extSeqBufCapacity = maxNbExternalSeq; + zc->extSeqBuf = + (ZSTD_Sequence*)ZSTD_cwksp_reserve_aligned(ws, maxNbExternalSeq * sizeof(ZSTD_Sequence)); + } + + /* buffers */ + /* ZSTD_wildcopy() is used to copy into the literals buffer, * so we have to oversize the buffer by WILDCOPY_OVERLENGTH bytes. */ zc->seqStore.litStart = ZSTD_cwksp_reserve_buffer(ws, blockSize + WILDCOPY_OVERLENGTH); zc->seqStore.maxNbLit = blockSize; - /* buffers */ zc->bufferedPolicy = zbuff; zc->inBuffSize = buffInSize; zc->inBuff = (char*)ZSTD_cwksp_reserve_buffer(ws, buffInSize); @@ -1883,32 +2180,9 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc, zc->seqStore.llCode = ZSTD_cwksp_reserve_buffer(ws, maxNbSeq * sizeof(BYTE)); zc->seqStore.mlCode = ZSTD_cwksp_reserve_buffer(ws, maxNbSeq * sizeof(BYTE)); zc->seqStore.ofCode = ZSTD_cwksp_reserve_buffer(ws, maxNbSeq * sizeof(BYTE)); - zc->seqStore.sequencesStart = (seqDef*)ZSTD_cwksp_reserve_aligned(ws, maxNbSeq * sizeof(seqDef)); - - FORWARD_IF_ERROR(ZSTD_reset_matchState( - &zc->blockState.matchState, - ws, - ¶ms->cParams, - params->useRowMatchFinder, - crp, - needsIndexReset, - ZSTD_resetTarget_CCtx), ""); - - /* ldm hash table */ - if (params->ldmParams.enableLdm == ZSTD_ps_enable) { - /* TODO: avoid memset? */ - size_t const ldmHSize = ((size_t)1) << params->ldmParams.hashLog; - zc->ldmState.hashTable = (ldmEntry_t*)ZSTD_cwksp_reserve_aligned(ws, ldmHSize * sizeof(ldmEntry_t)); - ZSTD_memset(zc->ldmState.hashTable, 0, ldmHSize * sizeof(ldmEntry_t)); - zc->ldmSequences = (rawSeq*)ZSTD_cwksp_reserve_aligned(ws, maxNbLdmSeq * sizeof(rawSeq)); - zc->maxNbLdmSequences = maxNbLdmSeq; - - ZSTD_window_init(&zc->ldmState.window); - zc->ldmState.loadedDictEnd = 0; - } DEBUGLOG(3, "wksp: finished allocating, %zd bytes remain available", ZSTD_cwksp_available_space(ws)); - assert(ZSTD_cwksp_estimated_space_within_bounds(ws, neededSpace, resizeWorkspace)); + assert(ZSTD_cwksp_estimated_space_within_bounds(ws, neededSpace)); zc->initialized = 1; @@ -1980,7 +2254,8 @@ ZSTD_resetCCtx_byAttachingCDict(ZSTD_CCtx* cctx, } params.cParams = ZSTD_adjustCParams_internal(adjusted_cdict_cParams, pledgedSrcSize, - cdict->dictContentSize, ZSTD_cpm_attachDict); + cdict->dictContentSize, ZSTD_cpm_attachDict, + params.useRowMatchFinder); params.cParams.windowLog = windowLog; params.useRowMatchFinder = cdict->useRowMatchFinder; /* cdict overrides */ FORWARD_IF_ERROR(ZSTD_resetCCtx_internal(cctx, ¶ms, pledgedSrcSize, @@ -2019,6 +2294,22 @@ ZSTD_resetCCtx_byAttachingCDict(ZSTD_CCtx* cctx, return 0; } +static void ZSTD_copyCDictTableIntoCCtx(U32* dst, U32 const* src, size_t tableSize, + ZSTD_compressionParameters const* cParams) { + if (ZSTD_CDictIndicesAreTagged(cParams)){ + /* Remove tags from the CDict table if they are present. + * See docs on "short cache" in zstd_compress_internal.h for context. */ + size_t i; + for (i = 0; i < tableSize; i++) { + U32 const taggedIndex = src[i]; + U32 const index = taggedIndex >> ZSTD_SHORT_CACHE_TAG_BITS; + dst[i] = index; + } + } else { + ZSTD_memcpy(dst, src, tableSize * sizeof(U32)); + } +} + static size_t ZSTD_resetCCtx_byCopyingCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict, ZSTD_CCtx_params params, @@ -2054,21 +2345,23 @@ static size_t ZSTD_resetCCtx_byCopyingCDict(ZSTD_CCtx* cctx, : 0; size_t const hSize = (size_t)1 << cdict_cParams->hashLog; - ZSTD_memcpy(cctx->blockState.matchState.hashTable, - cdict->matchState.hashTable, - hSize * sizeof(U32)); + ZSTD_copyCDictTableIntoCCtx(cctx->blockState.matchState.hashTable, + cdict->matchState.hashTable, + hSize, cdict_cParams); + /* Do not copy cdict's chainTable if cctx has parameters such that it would not use chainTable */ if (ZSTD_allocateChainTable(cctx->appliedParams.cParams.strategy, cctx->appliedParams.useRowMatchFinder, 0 /* forDDSDict */)) { - ZSTD_memcpy(cctx->blockState.matchState.chainTable, - cdict->matchState.chainTable, - chainSize * sizeof(U32)); + ZSTD_copyCDictTableIntoCCtx(cctx->blockState.matchState.chainTable, + cdict->matchState.chainTable, + chainSize, cdict_cParams); } /* copy tag table */ if (ZSTD_rowMatchFinderUsed(cdict_cParams->strategy, cdict->useRowMatchFinder)) { - size_t const tagTableSize = hSize*sizeof(U16); + size_t const tagTableSize = hSize; ZSTD_memcpy(cctx->blockState.matchState.tagTable, - cdict->matchState.tagTable, - tagTableSize); + cdict->matchState.tagTable, + tagTableSize); + cctx->blockState.matchState.hashSalt = cdict->matchState.hashSalt; } } @@ -2147,6 +2440,7 @@ static size_t ZSTD_copyCCtx_internal(ZSTD_CCtx* dstCCtx, params.useBlockSplitter = srcCCtx->appliedParams.useBlockSplitter; params.ldmParams = srcCCtx->appliedParams.ldmParams; params.fParams = fParams; + params.maxBlockSize = srcCCtx->appliedParams.maxBlockSize; ZSTD_resetCCtx_internal(dstCCtx, ¶ms, pledgedSrcSize, /* loadedDictSize */ 0, ZSTDcrp_leaveDirty, zbuff); @@ -2294,7 +2588,7 @@ static void ZSTD_reduceIndex (ZSTD_matchState_t* ms, ZSTD_CCtx_params const* par /* See doc/zstd_compression_format.md for detailed format description */ -void ZSTD_seqToCodes(const seqStore_t* seqStorePtr) +int ZSTD_seqToCodes(const seqStore_t* seqStorePtr) { const seqDef* const sequences = seqStorePtr->sequencesStart; BYTE* const llCodeTable = seqStorePtr->llCode; @@ -2302,18 +2596,24 @@ void ZSTD_seqToCodes(const seqStore_t* seqStorePtr) BYTE* const mlCodeTable = seqStorePtr->mlCode; U32 const nbSeq = (U32)(seqStorePtr->sequences - seqStorePtr->sequencesStart); U32 u; + int longOffsets = 0; assert(nbSeq <= seqStorePtr->maxNbSeq); for (u=0; u= STREAM_ACCUMULATOR_MIN)); + if (MEM_32bits() && ofCode >= STREAM_ACCUMULATOR_MIN) + longOffsets = 1; } if (seqStorePtr->longLengthType==ZSTD_llt_literalLength) llCodeTable[seqStorePtr->longLengthPos] = MaxLL; if (seqStorePtr->longLengthType==ZSTD_llt_matchLength) mlCodeTable[seqStorePtr->longLengthPos] = MaxML; + return longOffsets; } /* ZSTD_useTargetCBlockSize(): @@ -2347,6 +2647,7 @@ typedef struct { U32 MLtype; size_t size; size_t lastCountSize; /* Accounts for bug in 1.3.4. More detail in ZSTD_entropyCompressSeqStore_internal() */ + int longOffsets; } ZSTD_symbolEncodingTypeStats_t; /* ZSTD_buildSequencesStatistics(): @@ -2357,11 +2658,13 @@ typedef struct { * entropyWkspSize must be of size at least ENTROPY_WORKSPACE_SIZE - (MaxSeq + 1)*sizeof(U32) */ static ZSTD_symbolEncodingTypeStats_t -ZSTD_buildSequencesStatistics(seqStore_t* seqStorePtr, size_t nbSeq, - const ZSTD_fseCTables_t* prevEntropy, ZSTD_fseCTables_t* nextEntropy, - BYTE* dst, const BYTE* const dstEnd, - ZSTD_strategy strategy, unsigned* countWorkspace, - void* entropyWorkspace, size_t entropyWkspSize) { +ZSTD_buildSequencesStatistics( + const seqStore_t* seqStorePtr, size_t nbSeq, + const ZSTD_fseCTables_t* prevEntropy, ZSTD_fseCTables_t* nextEntropy, + BYTE* dst, const BYTE* const dstEnd, + ZSTD_strategy strategy, unsigned* countWorkspace, + void* entropyWorkspace, size_t entropyWkspSize) +{ BYTE* const ostart = dst; const BYTE* const oend = dstEnd; BYTE* op = ostart; @@ -2375,7 +2678,7 @@ ZSTD_buildSequencesStatistics(seqStore_t* seqStorePtr, size_t nbSeq, stats.lastCountSize = 0; /* convert length/distances into codes */ - ZSTD_seqToCodes(seqStorePtr); + stats.longOffsets = ZSTD_seqToCodes(seqStorePtr); assert(op <= oend); assert(nbSeq != 0); /* ZSTD_selectEncodingType() divides by nbSeq */ /* build CTable for Literal Lengths */ @@ -2480,22 +2783,22 @@ ZSTD_buildSequencesStatistics(seqStore_t* seqStorePtr, size_t nbSeq, */ #define SUSPECT_UNCOMPRESSIBLE_LITERAL_RATIO 20 MEM_STATIC size_t -ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, - const ZSTD_entropyCTables_t* prevEntropy, - ZSTD_entropyCTables_t* nextEntropy, - const ZSTD_CCtx_params* cctxParams, - void* dst, size_t dstCapacity, - void* entropyWorkspace, size_t entropyWkspSize, - const int bmi2) +ZSTD_entropyCompressSeqStore_internal( + const seqStore_t* seqStorePtr, + const ZSTD_entropyCTables_t* prevEntropy, + ZSTD_entropyCTables_t* nextEntropy, + const ZSTD_CCtx_params* cctxParams, + void* dst, size_t dstCapacity, + void* entropyWorkspace, size_t entropyWkspSize, + const int bmi2) { - const int longOffsets = cctxParams->cParams.windowLog > STREAM_ACCUMULATOR_MIN; ZSTD_strategy const strategy = cctxParams->cParams.strategy; unsigned* count = (unsigned*)entropyWorkspace; FSE_CTable* CTable_LitLength = nextEntropy->fse.litlengthCTable; FSE_CTable* CTable_OffsetBits = nextEntropy->fse.offcodeCTable; FSE_CTable* CTable_MatchLength = nextEntropy->fse.matchlengthCTable; const seqDef* const sequences = seqStorePtr->sequencesStart; - const size_t nbSeq = seqStorePtr->sequences - seqStorePtr->sequencesStart; + const size_t nbSeq = (size_t)(seqStorePtr->sequences - seqStorePtr->sequencesStart); const BYTE* const ofCodeTable = seqStorePtr->ofCode; const BYTE* const llCodeTable = seqStorePtr->llCode; const BYTE* const mlCodeTable = seqStorePtr->mlCode; @@ -2503,29 +2806,31 @@ ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, BYTE* const oend = ostart + dstCapacity; BYTE* op = ostart; size_t lastCountSize; + int longOffsets = 0; entropyWorkspace = count + (MaxSeq + 1); entropyWkspSize -= (MaxSeq + 1) * sizeof(*count); - DEBUGLOG(4, "ZSTD_entropyCompressSeqStore_internal (nbSeq=%zu)", nbSeq); + DEBUGLOG(5, "ZSTD_entropyCompressSeqStore_internal (nbSeq=%zu, dstCapacity=%zu)", nbSeq, dstCapacity); ZSTD_STATIC_ASSERT(HUF_WORKSPACE_SIZE >= (1<= HUF_WORKSPACE_SIZE); /* Compress literals */ { const BYTE* const literals = seqStorePtr->litStart; - size_t const numSequences = seqStorePtr->sequences - seqStorePtr->sequencesStart; - size_t const numLiterals = seqStorePtr->lit - seqStorePtr->litStart; + size_t const numSequences = (size_t)(seqStorePtr->sequences - seqStorePtr->sequencesStart); + size_t const numLiterals = (size_t)(seqStorePtr->lit - seqStorePtr->litStart); /* Base suspicion of uncompressibility on ratio of literals to sequences */ unsigned const suspectUncompressible = (numSequences == 0) || (numLiterals / numSequences >= SUSPECT_UNCOMPRESSIBLE_LITERAL_RATIO); size_t const litSize = (size_t)(seqStorePtr->lit - literals); + size_t const cSize = ZSTD_compressLiterals( - &prevEntropy->huf, &nextEntropy->huf, - cctxParams->cParams.strategy, - ZSTD_literalsCompressionIsDisabled(cctxParams), op, dstCapacity, literals, litSize, entropyWorkspace, entropyWkspSize, - bmi2, suspectUncompressible); + &prevEntropy->huf, &nextEntropy->huf, + cctxParams->cParams.strategy, + ZSTD_literalsCompressionIsDisabled(cctxParams), + suspectUncompressible, bmi2); FORWARD_IF_ERROR(cSize, "ZSTD_compressLiterals failed"); assert(cSize <= dstCapacity); op += cSize; @@ -2551,11 +2856,10 @@ ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, ZSTD_memcpy(&nextEntropy->fse, &prevEntropy->fse, sizeof(prevEntropy->fse)); return (size_t)(op - ostart); } - { - ZSTD_symbolEncodingTypeStats_t stats; - BYTE* seqHead = op++; + { BYTE* const seqHead = op++; /* build stats for sequences */ - stats = ZSTD_buildSequencesStatistics(seqStorePtr, nbSeq, + const ZSTD_symbolEncodingTypeStats_t stats = + ZSTD_buildSequencesStatistics(seqStorePtr, nbSeq, &prevEntropy->fse, &nextEntropy->fse, op, oend, strategy, count, @@ -2564,6 +2868,7 @@ ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, *seqHead = (BYTE)((stats.LLtype<<6) + (stats.Offtype<<4) + (stats.MLtype<<2)); lastCountSize = stats.lastCountSize; op += stats.size; + longOffsets = stats.longOffsets; } { size_t const bitstreamSize = ZSTD_encodeSequences( @@ -2598,14 +2903,15 @@ ZSTD_entropyCompressSeqStore_internal(seqStore_t* seqStorePtr, } MEM_STATIC size_t -ZSTD_entropyCompressSeqStore(seqStore_t* seqStorePtr, - const ZSTD_entropyCTables_t* prevEntropy, - ZSTD_entropyCTables_t* nextEntropy, - const ZSTD_CCtx_params* cctxParams, - void* dst, size_t dstCapacity, - size_t srcSize, - void* entropyWorkspace, size_t entropyWkspSize, - int bmi2) +ZSTD_entropyCompressSeqStore( + const seqStore_t* seqStorePtr, + const ZSTD_entropyCTables_t* prevEntropy, + ZSTD_entropyCTables_t* nextEntropy, + const ZSTD_CCtx_params* cctxParams, + void* dst, size_t dstCapacity, + size_t srcSize, + void* entropyWorkspace, size_t entropyWkspSize, + int bmi2) { size_t const cSize = ZSTD_entropyCompressSeqStore_internal( seqStorePtr, prevEntropy, nextEntropy, cctxParams, @@ -2615,15 +2921,21 @@ ZSTD_entropyCompressSeqStore(seqStore_t* seqStorePtr, /* When srcSize <= dstCapacity, there is enough space to write a raw uncompressed block. * Since we ran out of space, block must be not compressible, so fall back to raw uncompressed block. */ - if ((cSize == ERROR(dstSize_tooSmall)) & (srcSize <= dstCapacity)) + if ((cSize == ERROR(dstSize_tooSmall)) & (srcSize <= dstCapacity)) { + DEBUGLOG(4, "not enough dstCapacity (%zu) for ZSTD_entropyCompressSeqStore_internal()=> do not compress block", dstCapacity); return 0; /* block not compressed */ + } FORWARD_IF_ERROR(cSize, "ZSTD_entropyCompressSeqStore_internal failed"); /* Check compressibility */ { size_t const maxCSize = srcSize - ZSTD_minGain(srcSize, cctxParams->cParams.strategy); if (cSize >= maxCSize) return 0; /* block not compressed */ } - DEBUGLOG(4, "ZSTD_entropyCompressSeqStore() cSize: %zu", cSize); + DEBUGLOG(5, "ZSTD_entropyCompressSeqStore() cSize: %zu", cSize); + /* libzstd decoder before > v1.5.4 is not compatible with compressed blocks of size ZSTD_BLOCKSIZE_MAX exactly. + * This restriction is indirectly already fulfilled by respecting ZSTD_minGain() condition above. + */ + assert(cSize < ZSTD_BLOCKSIZE_MAX); return cSize; } @@ -2635,40 +2947,43 @@ ZSTD_blockCompressor ZSTD_selectBlockCompressor(ZSTD_strategy strat, ZSTD_paramS static const ZSTD_blockCompressor blockCompressor[4][ZSTD_STRATEGY_MAX+1] = { { ZSTD_compressBlock_fast /* default for 0 */, ZSTD_compressBlock_fast, - ZSTD_compressBlock_doubleFast, - ZSTD_compressBlock_greedy, - ZSTD_compressBlock_lazy, - ZSTD_compressBlock_lazy2, - ZSTD_compressBlock_btlazy2, - ZSTD_compressBlock_btopt, - ZSTD_compressBlock_btultra, - ZSTD_compressBlock_btultra2 }, + ZSTD_COMPRESSBLOCK_DOUBLEFAST, + ZSTD_COMPRESSBLOCK_GREEDY, + ZSTD_COMPRESSBLOCK_LAZY, + ZSTD_COMPRESSBLOCK_LAZY2, + ZSTD_COMPRESSBLOCK_BTLAZY2, + ZSTD_COMPRESSBLOCK_BTOPT, + ZSTD_COMPRESSBLOCK_BTULTRA, + ZSTD_COMPRESSBLOCK_BTULTRA2 + }, { ZSTD_compressBlock_fast_extDict /* default for 0 */, ZSTD_compressBlock_fast_extDict, - ZSTD_compressBlock_doubleFast_extDict, - ZSTD_compressBlock_greedy_extDict, - ZSTD_compressBlock_lazy_extDict, - ZSTD_compressBlock_lazy2_extDict, - ZSTD_compressBlock_btlazy2_extDict, - ZSTD_compressBlock_btopt_extDict, - ZSTD_compressBlock_btultra_extDict, - ZSTD_compressBlock_btultra_extDict }, + ZSTD_COMPRESSBLOCK_DOUBLEFAST_EXTDICT, + ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT, + ZSTD_COMPRESSBLOCK_LAZY_EXTDICT, + ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT, + ZSTD_COMPRESSBLOCK_BTLAZY2_EXTDICT, + ZSTD_COMPRESSBLOCK_BTOPT_EXTDICT, + ZSTD_COMPRESSBLOCK_BTULTRA_EXTDICT, + ZSTD_COMPRESSBLOCK_BTULTRA_EXTDICT + }, { ZSTD_compressBlock_fast_dictMatchState /* default for 0 */, ZSTD_compressBlock_fast_dictMatchState, - ZSTD_compressBlock_doubleFast_dictMatchState, - ZSTD_compressBlock_greedy_dictMatchState, - ZSTD_compressBlock_lazy_dictMatchState, - ZSTD_compressBlock_lazy2_dictMatchState, - ZSTD_compressBlock_btlazy2_dictMatchState, - ZSTD_compressBlock_btopt_dictMatchState, - ZSTD_compressBlock_btultra_dictMatchState, - ZSTD_compressBlock_btultra_dictMatchState }, + ZSTD_COMPRESSBLOCK_DOUBLEFAST_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_BTLAZY2_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_BTOPT_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_BTULTRA_DICTMATCHSTATE, + ZSTD_COMPRESSBLOCK_BTULTRA_DICTMATCHSTATE + }, { NULL /* default for 0 */, NULL, NULL, - ZSTD_compressBlock_greedy_dedicatedDictSearch, - ZSTD_compressBlock_lazy_dedicatedDictSearch, - ZSTD_compressBlock_lazy2_dedicatedDictSearch, + ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH, + ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH, + ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH, NULL, NULL, NULL, @@ -2681,18 +2996,26 @@ ZSTD_blockCompressor ZSTD_selectBlockCompressor(ZSTD_strategy strat, ZSTD_paramS DEBUGLOG(4, "Selected block compressor: dictMode=%d strat=%d rowMatchfinder=%d", (int)dictMode, (int)strat, (int)useRowMatchFinder); if (ZSTD_rowMatchFinderUsed(strat, useRowMatchFinder)) { static const ZSTD_blockCompressor rowBasedBlockCompressors[4][3] = { - { ZSTD_compressBlock_greedy_row, - ZSTD_compressBlock_lazy_row, - ZSTD_compressBlock_lazy2_row }, - { ZSTD_compressBlock_greedy_extDict_row, - ZSTD_compressBlock_lazy_extDict_row, - ZSTD_compressBlock_lazy2_extDict_row }, - { ZSTD_compressBlock_greedy_dictMatchState_row, - ZSTD_compressBlock_lazy_dictMatchState_row, - ZSTD_compressBlock_lazy2_dictMatchState_row }, - { ZSTD_compressBlock_greedy_dedicatedDictSearch_row, - ZSTD_compressBlock_lazy_dedicatedDictSearch_row, - ZSTD_compressBlock_lazy2_dedicatedDictSearch_row } + { + ZSTD_COMPRESSBLOCK_GREEDY_ROW, + ZSTD_COMPRESSBLOCK_LAZY_ROW, + ZSTD_COMPRESSBLOCK_LAZY2_ROW + }, + { + ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT_ROW, + ZSTD_COMPRESSBLOCK_LAZY_EXTDICT_ROW, + ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT_ROW + }, + { + ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE_ROW, + ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE_ROW, + ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE_ROW + }, + { + ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH_ROW, + ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH_ROW, + ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH_ROW + } }; DEBUGLOG(4, "Selecting a row-based matchfinder"); assert(useRowMatchFinder != ZSTD_ps_auto); @@ -2718,6 +3041,72 @@ void ZSTD_resetSeqStore(seqStore_t* ssPtr) ssPtr->longLengthType = ZSTD_llt_none; } +/* ZSTD_postProcessSequenceProducerResult() : + * Validates and post-processes sequences obtained through the external matchfinder API: + * - Checks whether nbExternalSeqs represents an error condition. + * - Appends a block delimiter to outSeqs if one is not already present. + * See zstd.h for context regarding block delimiters. + * Returns the number of sequences after post-processing, or an error code. */ +static size_t ZSTD_postProcessSequenceProducerResult( + ZSTD_Sequence* outSeqs, size_t nbExternalSeqs, size_t outSeqsCapacity, size_t srcSize +) { + RETURN_ERROR_IF( + nbExternalSeqs > outSeqsCapacity, + sequenceProducer_failed, + "External sequence producer returned error code %lu", + (unsigned long)nbExternalSeqs + ); + + RETURN_ERROR_IF( + nbExternalSeqs == 0 && srcSize > 0, + sequenceProducer_failed, + "Got zero sequences from external sequence producer for a non-empty src buffer!" + ); + + if (srcSize == 0) { + ZSTD_memset(&outSeqs[0], 0, sizeof(ZSTD_Sequence)); + return 1; + } + + { + ZSTD_Sequence const lastSeq = outSeqs[nbExternalSeqs - 1]; + + /* We can return early if lastSeq is already a block delimiter. */ + if (lastSeq.offset == 0 && lastSeq.matchLength == 0) { + return nbExternalSeqs; + } + + /* This error condition is only possible if the external matchfinder + * produced an invalid parse, by definition of ZSTD_sequenceBound(). */ + RETURN_ERROR_IF( + nbExternalSeqs == outSeqsCapacity, + sequenceProducer_failed, + "nbExternalSeqs == outSeqsCapacity but lastSeq is not a block delimiter!" + ); + + /* lastSeq is not a block delimiter, so we need to append one. */ + ZSTD_memset(&outSeqs[nbExternalSeqs], 0, sizeof(ZSTD_Sequence)); + return nbExternalSeqs + 1; + } +} + +/* ZSTD_fastSequenceLengthSum() : + * Returns sum(litLen) + sum(matchLen) + lastLits for *seqBuf*. + * Similar to another function in zstd_compress.c (determine_blockSize), + * except it doesn't check for a block delimiter to end summation. + * Removing the early exit allows the compiler to auto-vectorize (https://godbolt.org/z/cY1cajz9P). + * This function can be deleted and replaced by determine_blockSize after we resolve issue #3456. */ +static size_t ZSTD_fastSequenceLengthSum(ZSTD_Sequence const* seqBuf, size_t seqBufSize) { + size_t matchLenSum, litLenSum, i; + matchLenSum = 0; + litLenSum = 0; + for (i = 0; i < seqBufSize; i++) { + litLenSum += seqBuf[i].litLength; + matchLenSum += seqBuf[i].matchLength; + } + return litLenSum + matchLenSum; +} + typedef enum { ZSTDbss_compress, ZSTDbss_noCompress } ZSTD_buildSeqStore_e; static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) @@ -2727,7 +3116,9 @@ static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) assert(srcSize <= ZSTD_BLOCKSIZE_MAX); /* Assert that we have correctly flushed the ctx params into the ms's copy */ ZSTD_assertEqualCParams(zc->appliedParams.cParams, ms->cParams); - if (srcSize < MIN_CBLOCK_SIZE+ZSTD_blockHeaderSize+1) { + /* TODO: See 3090. We reduced MIN_CBLOCK_SIZE from 3 to 2 so to compensate we are adding + * additional 1. We need to revisit and change this logic to be more consistent */ + if (srcSize < MIN_CBLOCK_SIZE+ZSTD_blockHeaderSize+1+1) { if (zc->appliedParams.cParams.strategy >= ZSTD_btopt) { ZSTD_ldm_skipRawSeqStoreBytes(&zc->externSeqStore, srcSize); } else { @@ -2763,6 +3154,15 @@ static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) } if (zc->externSeqStore.pos < zc->externSeqStore.size) { assert(zc->appliedParams.ldmParams.enableLdm == ZSTD_ps_disable); + + /* External matchfinder + LDM is technically possible, just not implemented yet. + * We need to revisit soon and implement it. */ + RETURN_ERROR_IF( + ZSTD_hasExtSeqProd(&zc->appliedParams), + parameter_combination_unsupported, + "Long-distance matching with external sequence producer enabled is not currently supported." + ); + /* Updates ldmSeqStore.pos */ lastLLSize = ZSTD_ldm_blockCompress(&zc->externSeqStore, @@ -2774,6 +3174,14 @@ static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) } else if (zc->appliedParams.ldmParams.enableLdm == ZSTD_ps_enable) { rawSeqStore_t ldmSeqStore = kNullRawSeqStore; + /* External matchfinder + LDM is technically possible, just not implemented yet. + * We need to revisit soon and implement it. */ + RETURN_ERROR_IF( + ZSTD_hasExtSeqProd(&zc->appliedParams), + parameter_combination_unsupported, + "Long-distance matching with external sequence producer enabled is not currently supported." + ); + ldmSeqStore.seq = zc->ldmSequences; ldmSeqStore.capacity = zc->maxNbLdmSequences; /* Updates ldmSeqStore.size */ @@ -2788,10 +3196,74 @@ static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) zc->appliedParams.useRowMatchFinder, src, srcSize); assert(ldmSeqStore.pos == ldmSeqStore.size); - } else { /* not long range mode */ - ZSTD_blockCompressor const blockCompressor = ZSTD_selectBlockCompressor(zc->appliedParams.cParams.strategy, - zc->appliedParams.useRowMatchFinder, - dictMode); + } else if (ZSTD_hasExtSeqProd(&zc->appliedParams)) { + assert( + zc->extSeqBufCapacity >= ZSTD_sequenceBound(srcSize) + ); + assert(zc->appliedParams.extSeqProdFunc != NULL); + + { U32 const windowSize = (U32)1 << zc->appliedParams.cParams.windowLog; + + size_t const nbExternalSeqs = (zc->appliedParams.extSeqProdFunc)( + zc->appliedParams.extSeqProdState, + zc->extSeqBuf, + zc->extSeqBufCapacity, + src, srcSize, + NULL, 0, /* dict and dictSize, currently not supported */ + zc->appliedParams.compressionLevel, + windowSize + ); + + size_t const nbPostProcessedSeqs = ZSTD_postProcessSequenceProducerResult( + zc->extSeqBuf, + nbExternalSeqs, + zc->extSeqBufCapacity, + srcSize + ); + + /* Return early if there is no error, since we don't need to worry about last literals */ + if (!ZSTD_isError(nbPostProcessedSeqs)) { + ZSTD_sequencePosition seqPos = {0,0,0}; + size_t const seqLenSum = ZSTD_fastSequenceLengthSum(zc->extSeqBuf, nbPostProcessedSeqs); + RETURN_ERROR_IF(seqLenSum > srcSize, externalSequences_invalid, "External sequences imply too large a block!"); + FORWARD_IF_ERROR( + ZSTD_copySequencesToSeqStoreExplicitBlockDelim( + zc, &seqPos, + zc->extSeqBuf, nbPostProcessedSeqs, + src, srcSize, + zc->appliedParams.searchForExternalRepcodes + ), + "Failed to copy external sequences to seqStore!" + ); + ms->ldmSeqStore = NULL; + DEBUGLOG(5, "Copied %lu sequences from external sequence producer to internal seqStore.", (unsigned long)nbExternalSeqs); + return ZSTDbss_compress; + } + + /* Propagate the error if fallback is disabled */ + if (!zc->appliedParams.enableMatchFinderFallback) { + return nbPostProcessedSeqs; + } + + /* Fallback to software matchfinder */ + { ZSTD_blockCompressor const blockCompressor = + ZSTD_selectBlockCompressor( + zc->appliedParams.cParams.strategy, + zc->appliedParams.useRowMatchFinder, + dictMode); + ms->ldmSeqStore = NULL; + DEBUGLOG( + 5, + "External sequence producer returned error code %lu. Falling back to internal parser.", + (unsigned long)nbExternalSeqs + ); + lastLLSize = blockCompressor(ms, &zc->seqStore, zc->blockState.nextCBlock->rep, src, srcSize); + } } + } else { /* not long range mode and no external matchfinder */ + ZSTD_blockCompressor const blockCompressor = ZSTD_selectBlockCompressor( + zc->appliedParams.cParams.strategy, + zc->appliedParams.useRowMatchFinder, + dictMode); ms->ldmSeqStore = NULL; lastLLSize = blockCompressor(ms, &zc->seqStore, zc->blockState.nextCBlock->rep, src, srcSize); } @@ -2801,29 +3273,38 @@ static size_t ZSTD_buildSeqStore(ZSTD_CCtx* zc, const void* src, size_t srcSize) return ZSTDbss_compress; } -static void ZSTD_copyBlockSequences(ZSTD_CCtx* zc) +static size_t ZSTD_copyBlockSequences(SeqCollector* seqCollector, const seqStore_t* seqStore, const U32 prevRepcodes[ZSTD_REP_NUM]) { - const seqStore_t* seqStore = ZSTD_getSeqStore(zc); - const seqDef* seqStoreSeqs = seqStore->sequencesStart; - size_t seqStoreSeqSize = seqStore->sequences - seqStoreSeqs; - size_t seqStoreLiteralsSize = (size_t)(seqStore->lit - seqStore->litStart); - size_t literalsRead = 0; - size_t lastLLSize; + const seqDef* inSeqs = seqStore->sequencesStart; + const size_t nbInSequences = seqStore->sequences - inSeqs; + const size_t nbInLiterals = (size_t)(seqStore->lit - seqStore->litStart); - ZSTD_Sequence* outSeqs = &zc->seqCollector.seqStart[zc->seqCollector.seqIndex]; + ZSTD_Sequence* outSeqs = seqCollector->seqIndex == 0 ? seqCollector->seqStart : seqCollector->seqStart + seqCollector->seqIndex; + const size_t nbOutSequences = nbInSequences + 1; + size_t nbOutLiterals = 0; + repcodes_t repcodes; size_t i; - repcodes_t updatedRepcodes; - assert(zc->seqCollector.seqIndex + 1 < zc->seqCollector.maxSequences); - /* Ensure we have enough space for last literals "sequence" */ - assert(zc->seqCollector.maxSequences >= seqStoreSeqSize + 1); - ZSTD_memcpy(updatedRepcodes.rep, zc->blockState.prevCBlock->rep, sizeof(repcodes_t)); - for (i = 0; i < seqStoreSeqSize; ++i) { - U32 rawOffset = seqStoreSeqs[i].offBase - ZSTD_REP_NUM; - outSeqs[i].litLength = seqStoreSeqs[i].litLength; - outSeqs[i].matchLength = seqStoreSeqs[i].mlBase + MINMATCH; + /* Bounds check that we have enough space for every input sequence + * and the block delimiter + */ + assert(seqCollector->seqIndex <= seqCollector->maxSequences); + RETURN_ERROR_IF( + nbOutSequences > (size_t)(seqCollector->maxSequences - seqCollector->seqIndex), + dstSize_tooSmall, + "Not enough space to copy sequences"); + + ZSTD_memcpy(&repcodes, prevRepcodes, sizeof(repcodes)); + for (i = 0; i < nbInSequences; ++i) { + U32 rawOffset; + outSeqs[i].litLength = inSeqs[i].litLength; + outSeqs[i].matchLength = inSeqs[i].mlBase + MINMATCH; outSeqs[i].rep = 0; + /* Handle the possible single length >= 64K + * There can only be one because we add MINMATCH to every match length, + * and blocks are at most 128K. + */ if (i == seqStore->longLengthPos) { if (seqStore->longLengthType == ZSTD_llt_literalLength) { outSeqs[i].litLength += 0x10000; @@ -2832,37 +3313,55 @@ static void ZSTD_copyBlockSequences(ZSTD_CCtx* zc) } } - if (seqStoreSeqs[i].offBase <= ZSTD_REP_NUM) { - /* Derive the correct offset corresponding to a repcode */ - outSeqs[i].rep = seqStoreSeqs[i].offBase; + /* Determine the raw offset given the offBase, which may be a repcode. */ + if (OFFBASE_IS_REPCODE(inSeqs[i].offBase)) { + const U32 repcode = OFFBASE_TO_REPCODE(inSeqs[i].offBase); + assert(repcode > 0); + outSeqs[i].rep = repcode; if (outSeqs[i].litLength != 0) { - rawOffset = updatedRepcodes.rep[outSeqs[i].rep - 1]; + rawOffset = repcodes.rep[repcode - 1]; } else { - if (outSeqs[i].rep == 3) { - rawOffset = updatedRepcodes.rep[0] - 1; + if (repcode == 3) { + assert(repcodes.rep[0] > 1); + rawOffset = repcodes.rep[0] - 1; } else { - rawOffset = updatedRepcodes.rep[outSeqs[i].rep]; + rawOffset = repcodes.rep[repcode]; } } + } else { + rawOffset = OFFBASE_TO_OFFSET(inSeqs[i].offBase); } outSeqs[i].offset = rawOffset; - /* seqStoreSeqs[i].offset == offCode+1, and ZSTD_updateRep() expects offCode - so we provide seqStoreSeqs[i].offset - 1 */ - ZSTD_updateRep(updatedRepcodes.rep, - seqStoreSeqs[i].offBase - 1, - seqStoreSeqs[i].litLength == 0); - literalsRead += outSeqs[i].litLength; + + /* Update repcode history for the sequence */ + ZSTD_updateRep(repcodes.rep, + inSeqs[i].offBase, + inSeqs[i].litLength == 0); + + nbOutLiterals += outSeqs[i].litLength; } /* Insert last literals (if any exist) in the block as a sequence with ml == off == 0. * If there are no last literals, then we'll emit (of: 0, ml: 0, ll: 0), which is a marker * for the block boundary, according to the API. */ - assert(seqStoreLiteralsSize >= literalsRead); - lastLLSize = seqStoreLiteralsSize - literalsRead; - outSeqs[i].litLength = (U32)lastLLSize; - outSeqs[i].matchLength = outSeqs[i].offset = outSeqs[i].rep = 0; - seqStoreSeqSize++; - zc->seqCollector.seqIndex += seqStoreSeqSize; + assert(nbInLiterals >= nbOutLiterals); + { + const size_t lastLLSize = nbInLiterals - nbOutLiterals; + outSeqs[nbInSequences].litLength = (U32)lastLLSize; + outSeqs[nbInSequences].matchLength = 0; + outSeqs[nbInSequences].offset = 0; + assert(nbOutSequences == nbInSequences + 1); + } + seqCollector->seqIndex += nbOutSequences; + assert(seqCollector->seqIndex <= seqCollector->maxSequences); + + return 0; +} + +size_t ZSTD_sequenceBound(size_t srcSize) { + const size_t maxNbSeq = (srcSize / ZSTD_MINMATCH_MIN) + 1; + const size_t maxNbDelims = (srcSize / ZSTD_BLOCKSIZE_MAX_MIN) + 1; + return maxNbSeq + maxNbDelims; } size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* outSeqs, @@ -2871,6 +3370,16 @@ size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* outSeqs, const size_t dstCapacity = ZSTD_compressBound(srcSize); void* dst = ZSTD_customMalloc(dstCapacity, ZSTD_defaultCMem); SeqCollector seqCollector; + { + int targetCBlockSize; + FORWARD_IF_ERROR(ZSTD_CCtx_getParameter(zc, ZSTD_c_targetCBlockSize, &targetCBlockSize), ""); + RETURN_ERROR_IF(targetCBlockSize != 0, parameter_unsupported, "targetCBlockSize != 0"); + } + { + int nbWorkers; + FORWARD_IF_ERROR(ZSTD_CCtx_getParameter(zc, ZSTD_c_nbWorkers, &nbWorkers), ""); + RETURN_ERROR_IF(nbWorkers != 0, parameter_unsupported, "nbWorkers != 0"); + } RETURN_ERROR_IF(dst == NULL, memory_allocation, "NULL pointer!"); @@ -2880,8 +3389,12 @@ size_t ZSTD_generateSequences(ZSTD_CCtx* zc, ZSTD_Sequence* outSeqs, seqCollector.maxSequences = outSeqsSize; zc->seqCollector = seqCollector; - ZSTD_compress2(zc, dst, dstCapacity, src, srcSize); - ZSTD_customFree(dst, ZSTD_defaultCMem); + { + const size_t ret = ZSTD_compress2(zc, dst, dstCapacity, src, srcSize); + ZSTD_customFree(dst, ZSTD_defaultCMem); + FORWARD_IF_ERROR(ret, "ZSTD_compress2 failed"); + } + assert(zc->seqCollector.seqIndex <= ZSTD_sequenceBound(srcSize)); return zc->seqCollector.seqIndex; } @@ -2910,19 +3423,17 @@ static int ZSTD_isRLE(const BYTE* src, size_t length) { const size_t unrollMask = unrollSize - 1; const size_t prefixLength = length & unrollMask; size_t i; - size_t u; if (length == 1) return 1; /* Check if prefix is RLE first before using unrolled loop */ if (prefixLength && ZSTD_count(ip+1, ip, ip+prefixLength) != prefixLength-1) { return 0; } for (i = prefixLength; i != length; i += unrollSize) { + size_t u; for (u = 0; u < unrollSize; u += sizeof(size_t)) { if (MEM_readST(ip + i + u) != valueST) { return 0; - } - } - } + } } } return 1; } @@ -2938,7 +3449,8 @@ static int ZSTD_maybeRLE(seqStore_t const* seqStore) return nbSeqs < 4 && nbLits < 10; } -static void ZSTD_blockState_confirmRepcodesAndEntropyTables(ZSTD_blockState_t* const bs) +static void +ZSTD_blockState_confirmRepcodesAndEntropyTables(ZSTD_blockState_t* const bs) { ZSTD_compressedBlockState_t* const tmp = bs->prevCBlock; bs->prevCBlock = bs->nextCBlock; @@ -2946,7 +3458,9 @@ static void ZSTD_blockState_confirmRepcodesAndEntropyTables(ZSTD_blockState_t* c } /* Writes the block header */ -static void writeBlockHeader(void* op, size_t cSize, size_t blockSize, U32 lastBlock) { +static void +writeBlockHeader(void* op, size_t cSize, size_t blockSize, U32 lastBlock) +{ U32 const cBlockHeader = cSize == 1 ? lastBlock + (((U32)bt_rle)<<1) + (U32)(blockSize << 3) : lastBlock + (((U32)bt_compressed)<<1) + (U32)(cSize << 3); @@ -2959,13 +3473,16 @@ static void writeBlockHeader(void* op, size_t cSize, size_t blockSize, U32 lastB * Stores literals block type (raw, rle, compressed, repeat) and * huffman description table to hufMetadata. * Requires ENTROPY_WORKSPACE_SIZE workspace - * @return : size of huffman description table or error code */ -static size_t ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSize, - const ZSTD_hufCTables_t* prevHuf, - ZSTD_hufCTables_t* nextHuf, - ZSTD_hufCTablesMetadata_t* hufMetadata, - const int literalsCompressionIsDisabled, - void* workspace, size_t wkspSize) + * @return : size of huffman description table, or an error code + */ +static size_t +ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSize, + const ZSTD_hufCTables_t* prevHuf, + ZSTD_hufCTables_t* nextHuf, + ZSTD_hufCTablesMetadata_t* hufMetadata, + const int literalsCompressionIsDisabled, + void* workspace, size_t wkspSize, + int hufFlags) { BYTE* const wkspStart = (BYTE*)workspace; BYTE* const wkspEnd = wkspStart + wkspSize; @@ -2973,9 +3490,9 @@ static size_t ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSi unsigned* const countWksp = (unsigned*)workspace; const size_t countWkspSize = (HUF_SYMBOLVALUE_MAX + 1) * sizeof(unsigned); BYTE* const nodeWksp = countWkspStart + countWkspSize; - const size_t nodeWkspSize = wkspEnd-nodeWksp; + const size_t nodeWkspSize = (size_t)(wkspEnd - nodeWksp); unsigned maxSymbolValue = HUF_SYMBOLVALUE_MAX; - unsigned huffLog = HUF_TABLELOG_DEFAULT; + unsigned huffLog = LitHufLog; HUF_repeat repeat = prevHuf->repeatMode; DEBUGLOG(5, "ZSTD_buildBlockEntropyStats_literals (srcSize=%zu)", srcSize); @@ -2990,73 +3507,77 @@ static size_t ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSi /* small ? don't even attempt compression (speed opt) */ #ifndef COMPRESS_LITERALS_SIZE_MIN -#define COMPRESS_LITERALS_SIZE_MIN 63 +# define COMPRESS_LITERALS_SIZE_MIN 63 /* heuristic */ #endif { size_t const minLitSize = (prevHuf->repeatMode == HUF_repeat_valid) ? 6 : COMPRESS_LITERALS_SIZE_MIN; if (srcSize <= minLitSize) { DEBUGLOG(5, "set_basic - too small"); hufMetadata->hType = set_basic; return 0; - } - } + } } /* Scan input and build symbol stats */ - { size_t const largest = HIST_count_wksp (countWksp, &maxSymbolValue, (const BYTE*)src, srcSize, workspace, wkspSize); + { size_t const largest = + HIST_count_wksp (countWksp, &maxSymbolValue, + (const BYTE*)src, srcSize, + workspace, wkspSize); FORWARD_IF_ERROR(largest, "HIST_count_wksp failed"); if (largest == srcSize) { + /* only one literal symbol */ DEBUGLOG(5, "set_rle"); hufMetadata->hType = set_rle; return 0; } if (largest <= (srcSize >> 7)+4) { + /* heuristic: likely not compressible */ DEBUGLOG(5, "set_basic - no gain"); hufMetadata->hType = set_basic; return 0; - } - } + } } /* Validate the previous Huffman table */ - if (repeat == HUF_repeat_check && !HUF_validateCTable((HUF_CElt const*)prevHuf->CTable, countWksp, maxSymbolValue)) { + if (repeat == HUF_repeat_check + && !HUF_validateCTable((HUF_CElt const*)prevHuf->CTable, countWksp, maxSymbolValue)) { repeat = HUF_repeat_none; } /* Build Huffman Tree */ ZSTD_memset(nextHuf->CTable, 0, sizeof(nextHuf->CTable)); - huffLog = HUF_optimalTableLog(huffLog, srcSize, maxSymbolValue); + huffLog = HUF_optimalTableLog(huffLog, srcSize, maxSymbolValue, nodeWksp, nodeWkspSize, nextHuf->CTable, countWksp, hufFlags); + assert(huffLog <= LitHufLog); { size_t const maxBits = HUF_buildCTable_wksp((HUF_CElt*)nextHuf->CTable, countWksp, maxSymbolValue, huffLog, nodeWksp, nodeWkspSize); FORWARD_IF_ERROR(maxBits, "HUF_buildCTable_wksp"); huffLog = (U32)maxBits; - { /* Build and write the CTable */ - size_t const newCSize = HUF_estimateCompressedSize( - (HUF_CElt*)nextHuf->CTable, countWksp, maxSymbolValue); - size_t const hSize = HUF_writeCTable_wksp( - hufMetadata->hufDesBuffer, sizeof(hufMetadata->hufDesBuffer), - (HUF_CElt*)nextHuf->CTable, maxSymbolValue, huffLog, - nodeWksp, nodeWkspSize); - /* Check against repeating the previous CTable */ - if (repeat != HUF_repeat_none) { - size_t const oldCSize = HUF_estimateCompressedSize( - (HUF_CElt const*)prevHuf->CTable, countWksp, maxSymbolValue); - if (oldCSize < srcSize && (oldCSize <= hSize + newCSize || hSize + 12 >= srcSize)) { - DEBUGLOG(5, "set_repeat - smaller"); - ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); - hufMetadata->hType = set_repeat; - return 0; - } - } - if (newCSize + hSize >= srcSize) { - DEBUGLOG(5, "set_basic - no gains"); + } + { /* Build and write the CTable */ + size_t const newCSize = HUF_estimateCompressedSize( + (HUF_CElt*)nextHuf->CTable, countWksp, maxSymbolValue); + size_t const hSize = HUF_writeCTable_wksp( + hufMetadata->hufDesBuffer, sizeof(hufMetadata->hufDesBuffer), + (HUF_CElt*)nextHuf->CTable, maxSymbolValue, huffLog, + nodeWksp, nodeWkspSize); + /* Check against repeating the previous CTable */ + if (repeat != HUF_repeat_none) { + size_t const oldCSize = HUF_estimateCompressedSize( + (HUF_CElt const*)prevHuf->CTable, countWksp, maxSymbolValue); + if (oldCSize < srcSize && (oldCSize <= hSize + newCSize || hSize + 12 >= srcSize)) { + DEBUGLOG(5, "set_repeat - smaller"); ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); - hufMetadata->hType = set_basic; + hufMetadata->hType = set_repeat; return 0; - } - DEBUGLOG(5, "set_compressed (hSize=%u)", (U32)hSize); - hufMetadata->hType = set_compressed; - nextHuf->repeatMode = HUF_repeat_check; - return hSize; + } } + if (newCSize + hSize >= srcSize) { + DEBUGLOG(5, "set_basic - no gains"); + ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); + hufMetadata->hType = set_basic; + return 0; } + DEBUGLOG(5, "set_compressed (hSize=%u)", (U32)hSize); + hufMetadata->hType = set_compressed; + nextHuf->repeatMode = HUF_repeat_check; + return hSize; } } @@ -3066,8 +3587,9 @@ static size_t ZSTD_buildBlockEntropyStats_literals(void* const src, size_t srcSi * and updates nextEntropy to the appropriate repeatMode. */ static ZSTD_symbolEncodingTypeStats_t -ZSTD_buildDummySequencesStatistics(ZSTD_fseCTables_t* nextEntropy) { - ZSTD_symbolEncodingTypeStats_t stats = {set_basic, set_basic, set_basic, 0, 0}; +ZSTD_buildDummySequencesStatistics(ZSTD_fseCTables_t* nextEntropy) +{ + ZSTD_symbolEncodingTypeStats_t stats = {set_basic, set_basic, set_basic, 0, 0, 0}; nextEntropy->litlength_repeatMode = FSE_repeat_none; nextEntropy->offcode_repeatMode = FSE_repeat_none; nextEntropy->matchlength_repeatMode = FSE_repeat_none; @@ -3078,16 +3600,18 @@ ZSTD_buildDummySequencesStatistics(ZSTD_fseCTables_t* nextEntropy) { * Builds entropy for the sequences. * Stores symbol compression modes and fse table to fseMetadata. * Requires ENTROPY_WORKSPACE_SIZE wksp. - * @return : size of fse tables or error code */ -static size_t ZSTD_buildBlockEntropyStats_sequences(seqStore_t* seqStorePtr, - const ZSTD_fseCTables_t* prevEntropy, - ZSTD_fseCTables_t* nextEntropy, - const ZSTD_CCtx_params* cctxParams, - ZSTD_fseCTablesMetadata_t* fseMetadata, - void* workspace, size_t wkspSize) + * @return : size of fse tables or error code */ +static size_t +ZSTD_buildBlockEntropyStats_sequences( + const seqStore_t* seqStorePtr, + const ZSTD_fseCTables_t* prevEntropy, + ZSTD_fseCTables_t* nextEntropy, + const ZSTD_CCtx_params* cctxParams, + ZSTD_fseCTablesMetadata_t* fseMetadata, + void* workspace, size_t wkspSize) { ZSTD_strategy const strategy = cctxParams->cParams.strategy; - size_t const nbSeq = seqStorePtr->sequences - seqStorePtr->sequencesStart; + size_t const nbSeq = (size_t)(seqStorePtr->sequences - seqStorePtr->sequencesStart); BYTE* const ostart = fseMetadata->fseTablesBuffer; BYTE* const oend = ostart + sizeof(fseMetadata->fseTablesBuffer); BYTE* op = ostart; @@ -3114,23 +3638,28 @@ static size_t ZSTD_buildBlockEntropyStats_sequences(seqStore_t* seqStorePtr, /* ZSTD_buildBlockEntropyStats() : * Builds entropy for the block. * Requires workspace size ENTROPY_WORKSPACE_SIZE - * - * @return : 0 on success or error code + * @return : 0 on success, or an error code + * Note : also employed in superblock */ -size_t ZSTD_buildBlockEntropyStats(seqStore_t* seqStorePtr, - const ZSTD_entropyCTables_t* prevEntropy, - ZSTD_entropyCTables_t* nextEntropy, - const ZSTD_CCtx_params* cctxParams, - ZSTD_entropyCTablesMetadata_t* entropyMetadata, - void* workspace, size_t wkspSize) -{ - size_t const litSize = seqStorePtr->lit - seqStorePtr->litStart; +size_t ZSTD_buildBlockEntropyStats( + const seqStore_t* seqStorePtr, + const ZSTD_entropyCTables_t* prevEntropy, + ZSTD_entropyCTables_t* nextEntropy, + const ZSTD_CCtx_params* cctxParams, + ZSTD_entropyCTablesMetadata_t* entropyMetadata, + void* workspace, size_t wkspSize) +{ + size_t const litSize = (size_t)(seqStorePtr->lit - seqStorePtr->litStart); + int const huf_useOptDepth = (cctxParams->cParams.strategy >= HUF_OPTIMAL_DEPTH_THRESHOLD); + int const hufFlags = huf_useOptDepth ? HUF_flags_optimalDepth : 0; + entropyMetadata->hufMetadata.hufDesSize = ZSTD_buildBlockEntropyStats_literals(seqStorePtr->litStart, litSize, &prevEntropy->huf, &nextEntropy->huf, &entropyMetadata->hufMetadata, ZSTD_literalsCompressionIsDisabled(cctxParams), - workspace, wkspSize); + workspace, wkspSize, hufFlags); + FORWARD_IF_ERROR(entropyMetadata->hufMetadata.hufDesSize, "ZSTD_buildBlockEntropyStats_literals failed"); entropyMetadata->fseMetadata.fseTablesSize = ZSTD_buildBlockEntropyStats_sequences(seqStorePtr, @@ -3143,11 +3672,12 @@ size_t ZSTD_buildBlockEntropyStats(seqStore_t* seqStorePtr, } /* Returns the size estimate for the literals section (header + content) of a block */ -static size_t ZSTD_estimateBlockSize_literal(const BYTE* literals, size_t litSize, - const ZSTD_hufCTables_t* huf, - const ZSTD_hufCTablesMetadata_t* hufMetadata, - void* workspace, size_t wkspSize, - int writeEntropy) +static size_t +ZSTD_estimateBlockSize_literal(const BYTE* literals, size_t litSize, + const ZSTD_hufCTables_t* huf, + const ZSTD_hufCTablesMetadata_t* hufMetadata, + void* workspace, size_t wkspSize, + int writeEntropy) { unsigned* const countWksp = (unsigned*)workspace; unsigned maxSymbolValue = HUF_SYMBOLVALUE_MAX; @@ -3169,12 +3699,13 @@ static size_t ZSTD_estimateBlockSize_literal(const BYTE* literals, size_t litSiz } /* Returns the size estimate for the FSE-compressed symbols (of, ml, ll) of a block */ -static size_t ZSTD_estimateBlockSize_symbolType(symbolEncodingType_e type, - const BYTE* codeTable, size_t nbSeq, unsigned maxCode, - const FSE_CTable* fseCTable, - const U8* additionalBits, - short const* defaultNorm, U32 defaultNormLog, U32 defaultMax, - void* workspace, size_t wkspSize) +static size_t +ZSTD_estimateBlockSize_symbolType(symbolEncodingType_e type, + const BYTE* codeTable, size_t nbSeq, unsigned maxCode, + const FSE_CTable* fseCTable, + const U8* additionalBits, + short const* defaultNorm, U32 defaultNormLog, U32 defaultMax, + void* workspace, size_t wkspSize) { unsigned* const countWksp = (unsigned*)workspace; const BYTE* ctp = codeTable; @@ -3206,99 +3737,107 @@ static size_t ZSTD_estimateBlockSize_symbolType(symbolEncodingType_e type, } /* Returns the size estimate for the sequences section (header + content) of a block */ -static size_t ZSTD_estimateBlockSize_sequences(const BYTE* ofCodeTable, - const BYTE* llCodeTable, - const BYTE* mlCodeTable, - size_t nbSeq, - const ZSTD_fseCTables_t* fseTables, - const ZSTD_fseCTablesMetadata_t* fseMetadata, - void* workspace, size_t wkspSize, - int writeEntropy) +static size_t +ZSTD_estimateBlockSize_sequences(const BYTE* ofCodeTable, + const BYTE* llCodeTable, + const BYTE* mlCodeTable, + size_t nbSeq, + const ZSTD_fseCTables_t* fseTables, + const ZSTD_fseCTablesMetadata_t* fseMetadata, + void* workspace, size_t wkspSize, + int writeEntropy) { size_t sequencesSectionHeaderSize = 1 /* seqHead */ + 1 /* min seqSize size */ + (nbSeq >= 128) + (nbSeq >= LONGNBSEQ); size_t cSeqSizeEstimate = 0; cSeqSizeEstimate += ZSTD_estimateBlockSize_symbolType(fseMetadata->ofType, ofCodeTable, nbSeq, MaxOff, - fseTables->offcodeCTable, NULL, - OF_defaultNorm, OF_defaultNormLog, DefaultMaxOff, - workspace, wkspSize); + fseTables->offcodeCTable, NULL, + OF_defaultNorm, OF_defaultNormLog, DefaultMaxOff, + workspace, wkspSize); cSeqSizeEstimate += ZSTD_estimateBlockSize_symbolType(fseMetadata->llType, llCodeTable, nbSeq, MaxLL, - fseTables->litlengthCTable, LL_bits, - LL_defaultNorm, LL_defaultNormLog, MaxLL, - workspace, wkspSize); + fseTables->litlengthCTable, LL_bits, + LL_defaultNorm, LL_defaultNormLog, MaxLL, + workspace, wkspSize); cSeqSizeEstimate += ZSTD_estimateBlockSize_symbolType(fseMetadata->mlType, mlCodeTable, nbSeq, MaxML, - fseTables->matchlengthCTable, ML_bits, - ML_defaultNorm, ML_defaultNormLog, MaxML, - workspace, wkspSize); + fseTables->matchlengthCTable, ML_bits, + ML_defaultNorm, ML_defaultNormLog, MaxML, + workspace, wkspSize); if (writeEntropy) cSeqSizeEstimate += fseMetadata->fseTablesSize; return cSeqSizeEstimate + sequencesSectionHeaderSize; } /* Returns the size estimate for a given stream of literals, of, ll, ml */ -static size_t ZSTD_estimateBlockSize(const BYTE* literals, size_t litSize, - const BYTE* ofCodeTable, - const BYTE* llCodeTable, - const BYTE* mlCodeTable, - size_t nbSeq, - const ZSTD_entropyCTables_t* entropy, - const ZSTD_entropyCTablesMetadata_t* entropyMetadata, - void* workspace, size_t wkspSize, - int writeLitEntropy, int writeSeqEntropy) { +static size_t +ZSTD_estimateBlockSize(const BYTE* literals, size_t litSize, + const BYTE* ofCodeTable, + const BYTE* llCodeTable, + const BYTE* mlCodeTable, + size_t nbSeq, + const ZSTD_entropyCTables_t* entropy, + const ZSTD_entropyCTablesMetadata_t* entropyMetadata, + void* workspace, size_t wkspSize, + int writeLitEntropy, int writeSeqEntropy) +{ size_t const literalsSize = ZSTD_estimateBlockSize_literal(literals, litSize, - &entropy->huf, &entropyMetadata->hufMetadata, - workspace, wkspSize, writeLitEntropy); + &entropy->huf, &entropyMetadata->hufMetadata, + workspace, wkspSize, writeLitEntropy); size_t const seqSize = ZSTD_estimateBlockSize_sequences(ofCodeTable, llCodeTable, mlCodeTable, - nbSeq, &entropy->fse, &entropyMetadata->fseMetadata, - workspace, wkspSize, writeSeqEntropy); + nbSeq, &entropy->fse, &entropyMetadata->fseMetadata, + workspace, wkspSize, writeSeqEntropy); return seqSize + literalsSize + ZSTD_blockHeaderSize; } /* Builds entropy statistics and uses them for blocksize estimation. * - * Returns the estimated compressed size of the seqStore, or a zstd error. + * @return: estimated compressed size of the seqStore, or a zstd error. */ -static size_t ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(seqStore_t* seqStore, ZSTD_CCtx* zc) { - ZSTD_entropyCTablesMetadata_t* entropyMetadata = &zc->blockSplitCtx.entropyMetadata; +static size_t +ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(seqStore_t* seqStore, ZSTD_CCtx* zc) +{ + ZSTD_entropyCTablesMetadata_t* const entropyMetadata = &zc->blockSplitCtx.entropyMetadata; DEBUGLOG(6, "ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize()"); FORWARD_IF_ERROR(ZSTD_buildBlockEntropyStats(seqStore, &zc->blockState.prevCBlock->entropy, &zc->blockState.nextCBlock->entropy, &zc->appliedParams, entropyMetadata, - zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE /* statically allocated in resetCCtx */), ""); - return ZSTD_estimateBlockSize(seqStore->litStart, (size_t)(seqStore->lit - seqStore->litStart), + zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE), ""); + return ZSTD_estimateBlockSize( + seqStore->litStart, (size_t)(seqStore->lit - seqStore->litStart), seqStore->ofCode, seqStore->llCode, seqStore->mlCode, (size_t)(seqStore->sequences - seqStore->sequencesStart), - &zc->blockState.nextCBlock->entropy, entropyMetadata, zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE, + &zc->blockState.nextCBlock->entropy, + entropyMetadata, + zc->entropyWorkspace, ENTROPY_WORKSPACE_SIZE, (int)(entropyMetadata->hufMetadata.hType == set_compressed), 1); } /* Returns literals bytes represented in a seqStore */ -static size_t ZSTD_countSeqStoreLiteralsBytes(const seqStore_t* const seqStore) { +static size_t ZSTD_countSeqStoreLiteralsBytes(const seqStore_t* const seqStore) +{ size_t literalsBytes = 0; - size_t const nbSeqs = seqStore->sequences - seqStore->sequencesStart; + size_t const nbSeqs = (size_t)(seqStore->sequences - seqStore->sequencesStart); size_t i; for (i = 0; i < nbSeqs; ++i) { - seqDef seq = seqStore->sequencesStart[i]; + seqDef const seq = seqStore->sequencesStart[i]; literalsBytes += seq.litLength; if (i == seqStore->longLengthPos && seqStore->longLengthType == ZSTD_llt_literalLength) { literalsBytes += 0x10000; - } - } + } } return literalsBytes; } /* Returns match bytes represented in a seqStore */ -static size_t ZSTD_countSeqStoreMatchBytes(const seqStore_t* const seqStore) { +static size_t ZSTD_countSeqStoreMatchBytes(const seqStore_t* const seqStore) +{ size_t matchBytes = 0; - size_t const nbSeqs = seqStore->sequences - seqStore->sequencesStart; + size_t const nbSeqs = (size_t)(seqStore->sequences - seqStore->sequencesStart); size_t i; for (i = 0; i < nbSeqs; ++i) { seqDef seq = seqStore->sequencesStart[i]; matchBytes += seq.mlBase + MINMATCH; if (i == seqStore->longLengthPos && seqStore->longLengthType == ZSTD_llt_matchLength) { matchBytes += 0x10000; - } - } + } } return matchBytes; } @@ -3307,15 +3846,12 @@ static size_t ZSTD_countSeqStoreMatchBytes(const seqStore_t* const seqStore) { */ static void ZSTD_deriveSeqStoreChunk(seqStore_t* resultSeqStore, const seqStore_t* originalSeqStore, - size_t startIdx, size_t endIdx) { - BYTE* const litEnd = originalSeqStore->lit; - size_t literalsBytes; - size_t literalsBytesPreceding = 0; - + size_t startIdx, size_t endIdx) +{ *resultSeqStore = *originalSeqStore; if (startIdx > 0) { resultSeqStore->sequences = originalSeqStore->sequencesStart + startIdx; - literalsBytesPreceding = ZSTD_countSeqStoreLiteralsBytes(resultSeqStore); + resultSeqStore->litStart += ZSTD_countSeqStoreLiteralsBytes(resultSeqStore); } /* Move longLengthPos into the correct position if necessary */ @@ -3328,13 +3864,12 @@ static void ZSTD_deriveSeqStoreChunk(seqStore_t* resultSeqStore, } resultSeqStore->sequencesStart = originalSeqStore->sequencesStart + startIdx; resultSeqStore->sequences = originalSeqStore->sequencesStart + endIdx; - literalsBytes = ZSTD_countSeqStoreLiteralsBytes(resultSeqStore); - resultSeqStore->litStart += literalsBytesPreceding; if (endIdx == (size_t)(originalSeqStore->sequences - originalSeqStore->sequencesStart)) { /* This accounts for possible last literals if the derived chunk reaches the end of the block */ - resultSeqStore->lit = litEnd; + assert(resultSeqStore->lit == originalSeqStore->lit); } else { - resultSeqStore->lit = resultSeqStore->litStart+literalsBytes; + size_t const literalsBytes = ZSTD_countSeqStoreLiteralsBytes(resultSeqStore); + resultSeqStore->lit = resultSeqStore->litStart + literalsBytes; } resultSeqStore->llCode += startIdx; resultSeqStore->mlCode += startIdx; @@ -3342,20 +3877,26 @@ static void ZSTD_deriveSeqStoreChunk(seqStore_t* resultSeqStore, } /* - * Returns the raw offset represented by the combination of offCode, ll0, and repcode history. - * offCode must represent a repcode in the numeric representation of ZSTD_storeSeq(). + * Returns the raw offset represented by the combination of offBase, ll0, and repcode history. + * offBase must represent a repcode in the numeric representation of ZSTD_storeSeq(). */ static U32 -ZSTD_resolveRepcodeToRawOffset(const U32 rep[ZSTD_REP_NUM], const U32 offCode, const U32 ll0) -{ - U32 const adjustedOffCode = STORED_REPCODE(offCode) - 1 + ll0; /* [ 0 - 3 ] */ - assert(STORED_IS_REPCODE(offCode)); - if (adjustedOffCode == ZSTD_REP_NUM) { - /* litlength == 0 and offCode == 2 implies selection of first repcode - 1 */ - assert(rep[0] > 0); +ZSTD_resolveRepcodeToRawOffset(const U32 rep[ZSTD_REP_NUM], const U32 offBase, const U32 ll0) +{ + U32 const adjustedRepCode = OFFBASE_TO_REPCODE(offBase) - 1 + ll0; /* [ 0 - 3 ] */ + assert(OFFBASE_IS_REPCODE(offBase)); + if (adjustedRepCode == ZSTD_REP_NUM) { + assert(ll0); + /* litlength == 0 and offCode == 2 implies selection of first repcode - 1 + * This is only valid if it results in a valid offset value, aka > 0. + * Note : it may happen that `rep[0]==1` in exceptional circumstances. + * In which case this function will return 0, which is an invalid offset. + * It's not an issue though, since this value will be + * compared and discarded within ZSTD_seqStore_resolveOffCodes(). + */ return rep[0] - 1; } - return rep[adjustedOffCode]; + return rep[adjustedRepCode]; } /* @@ -3371,30 +3912,33 @@ ZSTD_resolveRepcodeToRawOffset(const U32 rep[ZSTD_REP_NUM], const U32 offCode, c * 1-3 : repcode 1-3 * 4+ : real_offset+3 */ -static void ZSTD_seqStore_resolveOffCodes(repcodes_t* const dRepcodes, repcodes_t* const cRepcodes, - seqStore_t* const seqStore, U32 const nbSeq) { +static void +ZSTD_seqStore_resolveOffCodes(repcodes_t* const dRepcodes, repcodes_t* const cRepcodes, + const seqStore_t* const seqStore, U32 const nbSeq) +{ U32 idx = 0; + U32 const longLitLenIdx = seqStore->longLengthType == ZSTD_llt_literalLength ? seqStore->longLengthPos : nbSeq; for (; idx < nbSeq; ++idx) { seqDef* const seq = seqStore->sequencesStart + idx; - U32 const ll0 = (seq->litLength == 0); - U32 const offCode = OFFBASE_TO_STORED(seq->offBase); - assert(seq->offBase > 0); - if (STORED_IS_REPCODE(offCode)) { - U32 const dRawOffset = ZSTD_resolveRepcodeToRawOffset(dRepcodes->rep, offCode, ll0); - U32 const cRawOffset = ZSTD_resolveRepcodeToRawOffset(cRepcodes->rep, offCode, ll0); + U32 const ll0 = (seq->litLength == 0) && (idx != longLitLenIdx); + U32 const offBase = seq->offBase; + assert(offBase > 0); + if (OFFBASE_IS_REPCODE(offBase)) { + U32 const dRawOffset = ZSTD_resolveRepcodeToRawOffset(dRepcodes->rep, offBase, ll0); + U32 const cRawOffset = ZSTD_resolveRepcodeToRawOffset(cRepcodes->rep, offBase, ll0); /* Adjust simulated decompression repcode history if we come across a mismatch. Replace * the repcode with the offset it actually references, determined by the compression * repcode history. */ if (dRawOffset != cRawOffset) { - seq->offBase = cRawOffset + ZSTD_REP_NUM; + seq->offBase = OFFSET_TO_OFFBASE(cRawOffset); } } /* Compression repcode history is always updated with values directly from the unmodified seqStore. * Decompression repcode history may use modified seq->offset value taken from compression repcode history. */ - ZSTD_updateRep(dRepcodes->rep, OFFBASE_TO_STORED(seq->offBase), ll0); - ZSTD_updateRep(cRepcodes->rep, offCode, ll0); + ZSTD_updateRep(dRepcodes->rep, seq->offBase, ll0); + ZSTD_updateRep(cRepcodes->rep, offBase, ll0); } } @@ -3404,10 +3948,11 @@ static void ZSTD_seqStore_resolveOffCodes(repcodes_t* const dRepcodes, repcodes_ * Returns the total size of that block (including header) or a ZSTD error code. */ static size_t -ZSTD_compressSeqStore_singleBlock(ZSTD_CCtx* zc, seqStore_t* const seqStore, +ZSTD_compressSeqStore_singleBlock(ZSTD_CCtx* zc, + const seqStore_t* const seqStore, repcodes_t* const dRep, repcodes_t* const cRep, void* dst, size_t dstCapacity, - const void* src, size_t srcSize, + const void* src, size_t srcSize, U32 lastBlock, U32 isPartition) { const U32 rleMaxLength = 25; @@ -3442,8 +3987,9 @@ ZSTD_compressSeqStore_singleBlock(ZSTD_CCtx* zc, seqStore_t* const seqStore, cSeqsSize = 1; } + /* Sequence collection not supported when block splitting */ if (zc->seqCollector.collectSequences) { - ZSTD_copyBlockSequences(zc); + FORWARD_IF_ERROR(ZSTD_copyBlockSequences(&zc->seqCollector, seqStore, dRepOriginal.rep), "copyBlockSequences failed"); ZSTD_blockState_confirmRepcodesAndEntropyTables(&zc->blockState); return 0; } @@ -3481,45 +4027,49 @@ typedef struct { /* Helper function to perform the recursive search for block splits. * Estimates the cost of seqStore prior to split, and estimates the cost of splitting the sequences in half. - * If advantageous to split, then we recurse down the two sub-blocks. If not, or if an error occurred in estimation, then - * we do not recurse. + * If advantageous to split, then we recurse down the two sub-blocks. + * If not, or if an error occurred in estimation, then we do not recurse. * - * Note: The recursion depth is capped by a heuristic minimum number of sequences, defined by MIN_SEQUENCES_BLOCK_SPLITTING. + * Note: The recursion depth is capped by a heuristic minimum number of sequences, + * defined by MIN_SEQUENCES_BLOCK_SPLITTING. * In theory, this means the absolute largest recursion depth is 10 == log2(maxNbSeqInBlock/MIN_SEQUENCES_BLOCK_SPLITTING). * In practice, recursion depth usually doesn't go beyond 4. * - * Furthermore, the number of splits is capped by ZSTD_MAX_NB_BLOCK_SPLITS. At ZSTD_MAX_NB_BLOCK_SPLITS == 196 with the current existing blockSize + * Furthermore, the number of splits is capped by ZSTD_MAX_NB_BLOCK_SPLITS. + * At ZSTD_MAX_NB_BLOCK_SPLITS == 196 with the current existing blockSize * maximum of 128 KB, this value is actually impossible to reach. */ static void ZSTD_deriveBlockSplitsHelper(seqStoreSplits* splits, size_t startIdx, size_t endIdx, ZSTD_CCtx* zc, const seqStore_t* origSeqStore) { - seqStore_t* fullSeqStoreChunk = &zc->blockSplitCtx.fullSeqStoreChunk; - seqStore_t* firstHalfSeqStore = &zc->blockSplitCtx.firstHalfSeqStore; - seqStore_t* secondHalfSeqStore = &zc->blockSplitCtx.secondHalfSeqStore; + seqStore_t* const fullSeqStoreChunk = &zc->blockSplitCtx.fullSeqStoreChunk; + seqStore_t* const firstHalfSeqStore = &zc->blockSplitCtx.firstHalfSeqStore; + seqStore_t* const secondHalfSeqStore = &zc->blockSplitCtx.secondHalfSeqStore; size_t estimatedOriginalSize; size_t estimatedFirstHalfSize; size_t estimatedSecondHalfSize; size_t midIdx = (startIdx + endIdx)/2; + DEBUGLOG(5, "ZSTD_deriveBlockSplitsHelper: startIdx=%zu endIdx=%zu", startIdx, endIdx); + assert(endIdx >= startIdx); if (endIdx - startIdx < MIN_SEQUENCES_BLOCK_SPLITTING || splits->idx >= ZSTD_MAX_NB_BLOCK_SPLITS) { - DEBUGLOG(6, "ZSTD_deriveBlockSplitsHelper: Too few sequences"); + DEBUGLOG(6, "ZSTD_deriveBlockSplitsHelper: Too few sequences (%zu)", endIdx - startIdx); return; } - DEBUGLOG(4, "ZSTD_deriveBlockSplitsHelper: startIdx=%zu endIdx=%zu", startIdx, endIdx); ZSTD_deriveSeqStoreChunk(fullSeqStoreChunk, origSeqStore, startIdx, endIdx); ZSTD_deriveSeqStoreChunk(firstHalfSeqStore, origSeqStore, startIdx, midIdx); ZSTD_deriveSeqStoreChunk(secondHalfSeqStore, origSeqStore, midIdx, endIdx); estimatedOriginalSize = ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(fullSeqStoreChunk, zc); estimatedFirstHalfSize = ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(firstHalfSeqStore, zc); estimatedSecondHalfSize = ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(secondHalfSeqStore, zc); - DEBUGLOG(4, "Estimated original block size: %zu -- First half split: %zu -- Second half split: %zu", + DEBUGLOG(5, "Estimated original block size: %zu -- First half split: %zu -- Second half split: %zu", estimatedOriginalSize, estimatedFirstHalfSize, estimatedSecondHalfSize); if (ZSTD_isError(estimatedOriginalSize) || ZSTD_isError(estimatedFirstHalfSize) || ZSTD_isError(estimatedSecondHalfSize)) { return; } if (estimatedFirstHalfSize + estimatedSecondHalfSize < estimatedOriginalSize) { + DEBUGLOG(5, "split decided at seqNb:%zu", midIdx); ZSTD_deriveBlockSplitsHelper(splits, startIdx, midIdx, zc, origSeqStore); splits->splitLocations[splits->idx] = (U32)midIdx; splits->idx++; @@ -3527,14 +4077,18 @@ ZSTD_deriveBlockSplitsHelper(seqStoreSplits* splits, size_t startIdx, size_t end } } -/* Base recursive function. Populates a table with intra-block partition indices that can improve compression ratio. +/* Base recursive function. + * Populates a table with intra-block partition indices that can improve compression ratio. * - * Returns the number of splits made (which equals the size of the partition table - 1). + * @return: number of splits made (which equals the size of the partition table - 1). */ -static size_t ZSTD_deriveBlockSplits(ZSTD_CCtx* zc, U32 partitions[], U32 nbSeq) { - seqStoreSplits splits = {partitions, 0}; +static size_t ZSTD_deriveBlockSplits(ZSTD_CCtx* zc, U32 partitions[], U32 nbSeq) +{ + seqStoreSplits splits; + splits.splitLocations = partitions; + splits.idx = 0; if (nbSeq <= 4) { - DEBUGLOG(4, "ZSTD_deriveBlockSplits: Too few sequences to split"); + DEBUGLOG(5, "ZSTD_deriveBlockSplits: Too few sequences to split (%u <= 4)", nbSeq); /* Refuse to try and split anything with less than 4 sequences */ return 0; } @@ -3550,18 +4104,20 @@ static size_t ZSTD_deriveBlockSplits(ZSTD_CCtx* zc, U32 partitions[], U32 nbSeq) * Returns combined size of all blocks (which includes headers), or a ZSTD error code. */ static size_t -ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapacity, - const void* src, size_t blockSize, U32 lastBlock, U32 nbSeq) +ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, + void* dst, size_t dstCapacity, + const void* src, size_t blockSize, + U32 lastBlock, U32 nbSeq) { size_t cSize = 0; const BYTE* ip = (const BYTE*)src; BYTE* op = (BYTE*)dst; size_t i = 0; size_t srcBytesTotal = 0; - U32* partitions = zc->blockSplitCtx.partitions; /* size == ZSTD_MAX_NB_BLOCK_SPLITS */ - seqStore_t* nextSeqStore = &zc->blockSplitCtx.nextSeqStore; - seqStore_t* currSeqStore = &zc->blockSplitCtx.currSeqStore; - size_t numSplits = ZSTD_deriveBlockSplits(zc, partitions, nbSeq); + U32* const partitions = zc->blockSplitCtx.partitions; /* size == ZSTD_MAX_NB_BLOCK_SPLITS */ + seqStore_t* const nextSeqStore = &zc->blockSplitCtx.nextSeqStore; + seqStore_t* const currSeqStore = &zc->blockSplitCtx.currSeqStore; + size_t const numSplits = ZSTD_deriveBlockSplits(zc, partitions, nbSeq); /* If a block is split and some partitions are emitted as RLE/uncompressed, then repcode history * may become invalid. In order to reconcile potentially invalid repcodes, we keep track of two @@ -3583,30 +4139,31 @@ ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapac ZSTD_memcpy(cRep.rep, zc->blockState.prevCBlock->rep, sizeof(repcodes_t)); ZSTD_memset(nextSeqStore, 0, sizeof(seqStore_t)); - DEBUGLOG(4, "ZSTD_compressBlock_splitBlock_internal (dstCapacity=%u, dictLimit=%u, nextToUpdate=%u)", + DEBUGLOG(5, "ZSTD_compressBlock_splitBlock_internal (dstCapacity=%u, dictLimit=%u, nextToUpdate=%u)", (unsigned)dstCapacity, (unsigned)zc->blockState.matchState.window.dictLimit, (unsigned)zc->blockState.matchState.nextToUpdate); if (numSplits == 0) { - size_t cSizeSingleBlock = ZSTD_compressSeqStore_singleBlock(zc, &zc->seqStore, - &dRep, &cRep, - op, dstCapacity, - ip, blockSize, - lastBlock, 0 /* isPartition */); + size_t cSizeSingleBlock = + ZSTD_compressSeqStore_singleBlock(zc, &zc->seqStore, + &dRep, &cRep, + op, dstCapacity, + ip, blockSize, + lastBlock, 0 /* isPartition */); FORWARD_IF_ERROR(cSizeSingleBlock, "Compressing single block from splitBlock_internal() failed!"); DEBUGLOG(5, "ZSTD_compressBlock_splitBlock_internal: No splits"); - assert(cSizeSingleBlock <= ZSTD_BLOCKSIZE_MAX + ZSTD_blockHeaderSize); + assert(zc->blockSize <= ZSTD_BLOCKSIZE_MAX); + assert(cSizeSingleBlock <= zc->blockSize + ZSTD_blockHeaderSize); return cSizeSingleBlock; } ZSTD_deriveSeqStoreChunk(currSeqStore, &zc->seqStore, 0, partitions[0]); for (i = 0; i <= numSplits; ++i) { - size_t srcBytes; size_t cSizeChunk; U32 const lastPartition = (i == numSplits); U32 lastBlockEntireSrc = 0; - srcBytes = ZSTD_countSeqStoreLiteralsBytes(currSeqStore) + ZSTD_countSeqStoreMatchBytes(currSeqStore); + size_t srcBytes = ZSTD_countSeqStoreLiteralsBytes(currSeqStore) + ZSTD_countSeqStoreMatchBytes(currSeqStore); srcBytesTotal += srcBytes; if (lastPartition) { /* This is the final partition, need to account for possible last literals */ @@ -3621,7 +4178,8 @@ ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapac op, dstCapacity, ip, srcBytes, lastBlockEntireSrc, 1 /* isPartition */); - DEBUGLOG(5, "Estimated size: %zu actual size: %zu", ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(currSeqStore, zc), cSizeChunk); + DEBUGLOG(5, "Estimated size: %zu vs %zu : actual size", + ZSTD_buildEntropyStatisticsAndEstimateSubBlockSize(currSeqStore, zc), cSizeChunk); FORWARD_IF_ERROR(cSizeChunk, "Compressing chunk failed!"); ip += srcBytes; @@ -3629,10 +4187,10 @@ ZSTD_compressBlock_splitBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapac dstCapacity -= cSizeChunk; cSize += cSizeChunk; *currSeqStore = *nextSeqStore; - assert(cSizeChunk <= ZSTD_BLOCKSIZE_MAX + ZSTD_blockHeaderSize); + assert(cSizeChunk <= zc->blockSize + ZSTD_blockHeaderSize); } - /* cRep and dRep may have diverged during the compression. If so, we use the dRep repcodes - * for the next block. + /* cRep and dRep may have diverged during the compression. + * If so, we use the dRep repcodes for the next block. */ ZSTD_memcpy(zc->blockState.prevCBlock->rep, dRep.rep, sizeof(repcodes_t)); return cSize; @@ -3643,8 +4201,6 @@ ZSTD_compressBlock_splitBlock(ZSTD_CCtx* zc, void* dst, size_t dstCapacity, const void* src, size_t srcSize, U32 lastBlock) { - const BYTE* ip = (const BYTE*)src; - BYTE* op = (BYTE*)dst; U32 nbSeq; size_t cSize; DEBUGLOG(4, "ZSTD_compressBlock_splitBlock"); @@ -3655,7 +4211,8 @@ ZSTD_compressBlock_splitBlock(ZSTD_CCtx* zc, if (bss == ZSTDbss_noCompress) { if (zc->blockState.prevCBlock->entropy.fse.offcode_repeatMode == FSE_repeat_valid) zc->blockState.prevCBlock->entropy.fse.offcode_repeatMode = FSE_repeat_check; - cSize = ZSTD_noCompressBlock(op, dstCapacity, ip, srcSize, lastBlock); + RETURN_ERROR_IF(zc->seqCollector.collectSequences, sequenceProducer_failed, "Uncompressible block"); + cSize = ZSTD_noCompressBlock(dst, dstCapacity, src, srcSize, lastBlock); FORWARD_IF_ERROR(cSize, "ZSTD_noCompressBlock failed"); DEBUGLOG(4, "ZSTD_compressBlock_splitBlock: Nocompress block"); return cSize; @@ -3673,9 +4230,9 @@ ZSTD_compressBlock_internal(ZSTD_CCtx* zc, void* dst, size_t dstCapacity, const void* src, size_t srcSize, U32 frame) { - /* This the upper bound for the length of an rle block. - * This isn't the actual upper bound. Finding the real threshold - * needs further investigation. + /* This is an estimated upper bound for the length of an rle block. + * This isn't the actual upper bound. + * Finding the real threshold needs further investigation. */ const U32 rleMaxLength = 25; size_t cSize; @@ -3687,11 +4244,15 @@ ZSTD_compressBlock_internal(ZSTD_CCtx* zc, { const size_t bss = ZSTD_buildSeqStore(zc, src, srcSize); FORWARD_IF_ERROR(bss, "ZSTD_buildSeqStore failed"); - if (bss == ZSTDbss_noCompress) { cSize = 0; goto out; } + if (bss == ZSTDbss_noCompress) { + RETURN_ERROR_IF(zc->seqCollector.collectSequences, sequenceProducer_failed, "Uncompressible block"); + cSize = 0; + goto out; + } } if (zc->seqCollector.collectSequences) { - ZSTD_copyBlockSequences(zc); + FORWARD_IF_ERROR(ZSTD_copyBlockSequences(&zc->seqCollector, ZSTD_getSeqStore(zc), zc->blockState.prevCBlock->rep), "copyBlockSequences failed"); ZSTD_blockState_confirmRepcodesAndEntropyTables(&zc->blockState); return 0; } @@ -3767,10 +4328,11 @@ static size_t ZSTD_compressBlock_targetCBlockSize_body(ZSTD_CCtx* zc, * * cSize >= blockBound(srcSize): We have expanded the block too much so * emit an uncompressed block. */ - { - size_t const cSize = ZSTD_compressSuperBlock(zc, dst, dstCapacity, src, srcSize, lastBlock); + { size_t const cSize = + ZSTD_compressSuperBlock(zc, dst, dstCapacity, src, srcSize, lastBlock); if (cSize != ERROR(dstSize_tooSmall)) { - size_t const maxCSize = srcSize - ZSTD_minGain(srcSize, zc->appliedParams.cParams.strategy); + size_t const maxCSize = + srcSize - ZSTD_minGain(srcSize, zc->appliedParams.cParams.strategy); FORWARD_IF_ERROR(cSize, "ZSTD_compressSuperBlock failed"); if (cSize != 0 && cSize < maxCSize + ZSTD_blockHeaderSize) { ZSTD_blockState_confirmRepcodesAndEntropyTables(&zc->blockState); @@ -3778,7 +4340,7 @@ static size_t ZSTD_compressBlock_targetCBlockSize_body(ZSTD_CCtx* zc, } } } - } + } /* if (bss == ZSTDbss_compress)*/ DEBUGLOG(6, "Resorting to ZSTD_noCompressBlock()"); /* Superblock compression failed, attempt to emit a single no compress block. @@ -3836,7 +4398,7 @@ static void ZSTD_overflowCorrectIfNeeded(ZSTD_matchState_t* ms, * All blocks will be terminated, all input will be consumed. * Function will issue an error if there is not enough `dstCapacity` to hold the compressed content. * Frame is supposed already started (header already produced) -* @return : compressed size, or an error code +* @return : compressed size, or an error code */ static size_t ZSTD_compress_frameChunk(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, @@ -3860,7 +4422,9 @@ static size_t ZSTD_compress_frameChunk(ZSTD_CCtx* cctx, ZSTD_matchState_t* const ms = &cctx->blockState.matchState; U32 const lastBlock = lastFrameChunk & (blockSize >= remaining); - RETURN_ERROR_IF(dstCapacity < ZSTD_blockHeaderSize + MIN_CBLOCK_SIZE, + /* TODO: See 3090. We reduced MIN_CBLOCK_SIZE from 3 to 2 so to compensate we are adding + * additional 1. We need to revisit and change this logic to be more consistent */ + RETURN_ERROR_IF(dstCapacity < ZSTD_blockHeaderSize + MIN_CBLOCK_SIZE + 1, dstSize_tooSmall, "not enough space to store compressed block"); if (remaining < blockSize) blockSize = remaining; @@ -3899,7 +4463,7 @@ static size_t ZSTD_compress_frameChunk(ZSTD_CCtx* cctx, MEM_writeLE24(op, cBlockHeader); cSize += ZSTD_blockHeaderSize; } - } + } /* if (ZSTD_useTargetCBlockSize(&cctx->appliedParams))*/ ip += blockSize; @@ -4001,19 +4565,15 @@ size_t ZSTD_writeLastEmptyBlock(void* dst, size_t dstCapacity) } } -size_t ZSTD_referenceExternalSequences(ZSTD_CCtx* cctx, rawSeq* seq, size_t nbSeq) +void ZSTD_referenceExternalSequences(ZSTD_CCtx* cctx, rawSeq* seq, size_t nbSeq) { - RETURN_ERROR_IF(cctx->stage != ZSTDcs_init, stage_wrong, - "wrong cctx stage"); - RETURN_ERROR_IF(cctx->appliedParams.ldmParams.enableLdm == ZSTD_ps_enable, - parameter_unsupported, - "incompatible with ldm"); + assert(cctx->stage == ZSTDcs_init); + assert(nbSeq == 0 || cctx->appliedParams.ldmParams.enableLdm != ZSTD_ps_enable); cctx->externSeqStore.seq = seq; cctx->externSeqStore.size = nbSeq; cctx->externSeqStore.capacity = nbSeq; cctx->externSeqStore.pos = 0; cctx->externSeqStore.posInSequence = 0; - return 0; } @@ -4078,31 +4638,51 @@ static size_t ZSTD_compressContinue_internal (ZSTD_CCtx* cctx, } } -size_t ZSTD_compressContinue (ZSTD_CCtx* cctx, - void* dst, size_t dstCapacity, - const void* src, size_t srcSize) +size_t ZSTD_compressContinue_public(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize) { DEBUGLOG(5, "ZSTD_compressContinue (srcSize=%u)", (unsigned)srcSize); return ZSTD_compressContinue_internal(cctx, dst, dstCapacity, src, srcSize, 1 /* frame mode */, 0 /* last chunk */); } +/* NOTE: Must just wrap ZSTD_compressContinue_public() */ +size_t ZSTD_compressContinue(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize) +{ + return ZSTD_compressContinue_public(cctx, dst, dstCapacity, src, srcSize); +} -size_t ZSTD_getBlockSize(const ZSTD_CCtx* cctx) +static size_t ZSTD_getBlockSize_deprecated(const ZSTD_CCtx* cctx) { ZSTD_compressionParameters const cParams = cctx->appliedParams.cParams; assert(!ZSTD_checkCParams(cParams)); - return MIN (ZSTD_BLOCKSIZE_MAX, (U32)1 << cParams.windowLog); + return MIN(cctx->appliedParams.maxBlockSize, (size_t)1 << cParams.windowLog); } -size_t ZSTD_compressBlock(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize) +/* NOTE: Must just wrap ZSTD_getBlockSize_deprecated() */ +size_t ZSTD_getBlockSize(const ZSTD_CCtx* cctx) +{ + return ZSTD_getBlockSize_deprecated(cctx); +} + +/* NOTE: Must just wrap ZSTD_compressBlock_deprecated() */ +size_t ZSTD_compressBlock_deprecated(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize) { DEBUGLOG(5, "ZSTD_compressBlock: srcSize = %u", (unsigned)srcSize); - { size_t const blockSizeMax = ZSTD_getBlockSize(cctx); + { size_t const blockSizeMax = ZSTD_getBlockSize_deprecated(cctx); RETURN_ERROR_IF(srcSize > blockSizeMax, srcSize_wrong, "input is larger than a block"); } return ZSTD_compressContinue_internal(cctx, dst, dstCapacity, src, srcSize, 0 /* frame mode */, 0 /* last chunk */); } +/* NOTE: Must just wrap ZSTD_compressBlock_deprecated() */ +size_t ZSTD_compressBlock(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize) +{ + return ZSTD_compressBlock_deprecated(cctx, dst, dstCapacity, src, srcSize); +} + /*! ZSTD_loadDictionaryContent() : * @return : 0, or an error code */ @@ -4111,25 +4691,36 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms, ZSTD_cwksp* ws, ZSTD_CCtx_params const* params, const void* src, size_t srcSize, - ZSTD_dictTableLoadMethod_e dtlm) + ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp) { const BYTE* ip = (const BYTE*) src; const BYTE* const iend = ip + srcSize; int const loadLdmDict = params->ldmParams.enableLdm == ZSTD_ps_enable && ls != NULL; - /* Assert that we the ms params match the params we're being given */ + /* Assert that the ms params match the params we're being given */ ZSTD_assertEqualCParams(params->cParams, ms->cParams); - if (srcSize > ZSTD_CHUNKSIZE_MAX) { + { /* Ensure large dictionaries can't cause index overflow */ + /* Allow the dictionary to set indices up to exactly ZSTD_CURRENT_MAX. * Dictionaries right at the edge will immediately trigger overflow * correction, but I don't want to insert extra constraints here. */ - U32 const maxDictSize = ZSTD_CURRENT_MAX - 1; - /* We must have cleared our windows when our source is this large. */ - assert(ZSTD_window_isEmpty(ms->window)); - if (loadLdmDict) - assert(ZSTD_window_isEmpty(ls->window)); + U32 maxDictSize = ZSTD_CURRENT_MAX - ZSTD_WINDOW_START_INDEX; + + int const CDictTaggedIndices = ZSTD_CDictIndicesAreTagged(¶ms->cParams); + if (CDictTaggedIndices && tfp == ZSTD_tfp_forCDict) { + /* Some dictionary matchfinders in zstd use "short cache", + * which treats the lower ZSTD_SHORT_CACHE_TAG_BITS of each + * CDict hashtable entry as a tag rather than as part of an index. + * When short cache is used, we need to truncate the dictionary + * so that its indices don't overlap with the tag. */ + U32 const shortCacheMaxDictSize = (1u << (32 - ZSTD_SHORT_CACHE_TAG_BITS)) - ZSTD_WINDOW_START_INDEX; + maxDictSize = MIN(maxDictSize, shortCacheMaxDictSize); + assert(!loadLdmDict); + } + /* If the dictionary is too large, only load the suffix of the dictionary. */ if (srcSize > maxDictSize) { ip = iend - maxDictSize; @@ -4138,35 +4729,58 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms, } } - DEBUGLOG(4, "ZSTD_loadDictionaryContent(): useRowMatchFinder=%d", (int)params->useRowMatchFinder); + if (srcSize > ZSTD_CHUNKSIZE_MAX) { + /* We must have cleared our windows when our source is this large. */ + assert(ZSTD_window_isEmpty(ms->window)); + if (loadLdmDict) assert(ZSTD_window_isEmpty(ls->window)); + } ZSTD_window_update(&ms->window, src, srcSize, /* forceNonContiguous */ 0); - ms->loadedDictEnd = params->forceWindow ? 0 : (U32)(iend - ms->window.base); - ms->forceNonContiguous = params->deterministicRefPrefix; - if (loadLdmDict) { + DEBUGLOG(4, "ZSTD_loadDictionaryContent(): useRowMatchFinder=%d", (int)params->useRowMatchFinder); + + if (loadLdmDict) { /* Load the entire dict into LDM matchfinders. */ ZSTD_window_update(&ls->window, src, srcSize, /* forceNonContiguous */ 0); ls->loadedDictEnd = params->forceWindow ? 0 : (U32)(iend - ls->window.base); + ZSTD_ldm_fillHashTable(ls, ip, iend, ¶ms->ldmParams); } + /* If the dict is larger than we can reasonably index in our tables, only load the suffix. */ + if (params->cParams.strategy < ZSTD_btultra) { + U32 maxDictSize = 8U << MIN(MAX(params->cParams.hashLog, params->cParams.chainLog), 28); + if (srcSize > maxDictSize) { + ip = iend - maxDictSize; + src = ip; + srcSize = maxDictSize; + } + } + + ms->nextToUpdate = (U32)(ip - ms->window.base); + ms->loadedDictEnd = params->forceWindow ? 0 : (U32)(iend - ms->window.base); + ms->forceNonContiguous = params->deterministicRefPrefix; + if (srcSize <= HASH_READ_SIZE) return 0; ZSTD_overflowCorrectIfNeeded(ms, ws, params, ip, iend); - if (loadLdmDict) - ZSTD_ldm_fillHashTable(ls, ip, iend, ¶ms->ldmParams); - switch(params->cParams.strategy) { case ZSTD_fast: - ZSTD_fillHashTable(ms, iend, dtlm); + ZSTD_fillHashTable(ms, iend, dtlm, tfp); break; case ZSTD_dfast: - ZSTD_fillDoubleHashTable(ms, iend, dtlm); +#ifndef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR + ZSTD_fillDoubleHashTable(ms, iend, dtlm, tfp); +#else + assert(0); /* shouldn't be called: cparams should've been adjusted. */ +#endif break; case ZSTD_greedy: case ZSTD_lazy: case ZSTD_lazy2: +#if !defined(ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR) assert(srcSize >= HASH_READ_SIZE); if (ms->dedicatedDictSearch) { assert(ms->chainTable != NULL); @@ -4174,7 +4788,7 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms, } else { assert(params->useRowMatchFinder != ZSTD_ps_auto); if (params->useRowMatchFinder == ZSTD_ps_enable) { - size_t const tagTableSize = ((size_t)1 << params->cParams.hashLog) * sizeof(U16); + size_t const tagTableSize = ((size_t)1 << params->cParams.hashLog); ZSTD_memset(ms->tagTable, 0, tagTableSize); ZSTD_row_update(ms, iend-HASH_READ_SIZE); DEBUGLOG(4, "Using row-based hash table for lazy dict"); @@ -4183,14 +4797,23 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms, DEBUGLOG(4, "Using chain-based hash table for lazy dict"); } } +#else + assert(0); /* shouldn't be called: cparams should've been adjusted. */ +#endif break; case ZSTD_btlazy2: /* we want the dictionary table fully sorted */ case ZSTD_btopt: case ZSTD_btultra: case ZSTD_btultra2: +#if !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR) assert(srcSize >= HASH_READ_SIZE); ZSTD_updateTree(ms, iend-HASH_READ_SIZE, iend); +#else + assert(0); /* shouldn't be called: cparams should've been adjusted. */ +#endif break; default: @@ -4237,11 +4860,10 @@ size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace, /* We only set the loaded table as valid if it contains all non-zero * weights. Otherwise, we set it to check */ - if (!hasZeroWeights) + if (!hasZeroWeights && maxSymbolValue == 255) bs->entropy.huf.repeatMode = HUF_repeat_valid; RETURN_ERROR_IF(HUF_isError(hufHeaderSize), dictionary_corrupted, ""); - RETURN_ERROR_IF(maxSymbolValue < 255, dictionary_corrupted, ""); dictPtr += hufHeaderSize; } @@ -4327,6 +4949,7 @@ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, ZSTD_CCtx_params const* params, const void* dict, size_t dictSize, ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp, void* workspace) { const BYTE* dictPtr = (const BYTE*)dict; @@ -4345,7 +4968,7 @@ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs, { size_t const dictContentSize = (size_t)(dictEnd - dictPtr); FORWARD_IF_ERROR(ZSTD_loadDictionaryContent( - ms, NULL, ws, params, dictPtr, dictContentSize, dtlm), ""); + ms, NULL, ws, params, dictPtr, dictContentSize, dtlm, tfp), ""); } return dictID; } @@ -4361,6 +4984,7 @@ ZSTD_compress_insertDictionary(ZSTD_compressedBlockState_t* bs, const void* dict, size_t dictSize, ZSTD_dictContentType_e dictContentType, ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp, void* workspace) { DEBUGLOG(4, "ZSTD_compress_insertDictionary (dictSize=%u)", (U32)dictSize); @@ -4373,13 +4997,13 @@ ZSTD_compress_insertDictionary(ZSTD_compressedBlockState_t* bs, /* dict restricted modes */ if (dictContentType == ZSTD_dct_rawContent) - return ZSTD_loadDictionaryContent(ms, ls, ws, params, dict, dictSize, dtlm); + return ZSTD_loadDictionaryContent(ms, ls, ws, params, dict, dictSize, dtlm, tfp); if (MEM_readLE32(dict) != ZSTD_MAGIC_DICTIONARY) { if (dictContentType == ZSTD_dct_auto) { DEBUGLOG(4, "raw content dictionary detected"); return ZSTD_loadDictionaryContent( - ms, ls, ws, params, dict, dictSize, dtlm); + ms, ls, ws, params, dict, dictSize, dtlm, tfp); } RETURN_ERROR_IF(dictContentType == ZSTD_dct_fullDict, dictionary_wrong, ""); assert(0); /* impossible */ @@ -4387,13 +5011,14 @@ ZSTD_compress_insertDictionary(ZSTD_compressedBlockState_t* bs, /* dict as full zstd dictionary */ return ZSTD_loadZstdDictionary( - bs, ms, ws, params, dict, dictSize, dtlm, workspace); + bs, ms, ws, params, dict, dictSize, dtlm, tfp, workspace); } #define ZSTD_USE_CDICT_PARAMS_SRCSIZE_CUTOFF (128 KB) #define ZSTD_USE_CDICT_PARAMS_DICTSIZE_MULTIPLIER (6ULL) /*! ZSTD_compressBegin_internal() : + * Assumption : either @dict OR @cdict (or none) is non-NULL, never both * @return : 0, or an error code */ static size_t ZSTD_compressBegin_internal(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, @@ -4426,11 +5051,11 @@ static size_t ZSTD_compressBegin_internal(ZSTD_CCtx* cctx, cctx->blockState.prevCBlock, &cctx->blockState.matchState, &cctx->ldmState, &cctx->workspace, &cctx->appliedParams, cdict->dictContent, cdict->dictContentSize, cdict->dictContentType, dtlm, - cctx->entropyWorkspace) + ZSTD_tfp_forCCtx, cctx->entropyWorkspace) : ZSTD_compress_insertDictionary( cctx->blockState.prevCBlock, &cctx->blockState.matchState, &cctx->ldmState, &cctx->workspace, &cctx->appliedParams, dict, dictSize, - dictContentType, dtlm, cctx->entropyWorkspace); + dictContentType, dtlm, ZSTD_tfp_forCCtx, cctx->entropyWorkspace); FORWARD_IF_ERROR(dictID, "ZSTD_compress_insertDictionary failed"); assert(dictID <= UINT_MAX); cctx->dictID = (U32)dictID; @@ -4471,11 +5096,11 @@ size_t ZSTD_compressBegin_advanced(ZSTD_CCtx* cctx, &cctxParams, pledgedSrcSize); } -size_t ZSTD_compressBegin_usingDict(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, int compressionLevel) +static size_t +ZSTD_compressBegin_usingDict_deprecated(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, int compressionLevel) { ZSTD_CCtx_params cctxParams; - { - ZSTD_parameters const params = ZSTD_getParams_internal(compressionLevel, ZSTD_CONTENTSIZE_UNKNOWN, dictSize, ZSTD_cpm_noAttachDict); + { ZSTD_parameters const params = ZSTD_getParams_internal(compressionLevel, ZSTD_CONTENTSIZE_UNKNOWN, dictSize, ZSTD_cpm_noAttachDict); ZSTD_CCtxParams_init_internal(&cctxParams, ¶ms, (compressionLevel == 0) ? ZSTD_CLEVEL_DEFAULT : compressionLevel); } DEBUGLOG(4, "ZSTD_compressBegin_usingDict (dictSize=%u)", (unsigned)dictSize); @@ -4483,9 +5108,15 @@ size_t ZSTD_compressBegin_usingDict(ZSTD_CCtx* cctx, const void* dict, size_t di &cctxParams, ZSTD_CONTENTSIZE_UNKNOWN, ZSTDb_not_buffered); } +size_t +ZSTD_compressBegin_usingDict(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, int compressionLevel) +{ + return ZSTD_compressBegin_usingDict_deprecated(cctx, dict, dictSize, compressionLevel); +} + size_t ZSTD_compressBegin(ZSTD_CCtx* cctx, int compressionLevel) { - return ZSTD_compressBegin_usingDict(cctx, NULL, 0, compressionLevel); + return ZSTD_compressBegin_usingDict_deprecated(cctx, NULL, 0, compressionLevel); } @@ -4496,14 +5127,13 @@ static size_t ZSTD_writeEpilogue(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity) { BYTE* const ostart = (BYTE*)dst; BYTE* op = ostart; - size_t fhSize = 0; DEBUGLOG(4, "ZSTD_writeEpilogue"); RETURN_ERROR_IF(cctx->stage == ZSTDcs_created, stage_wrong, "init missing"); /* special case : empty frame */ if (cctx->stage == ZSTDcs_init) { - fhSize = ZSTD_writeFrameHeader(dst, dstCapacity, &cctx->appliedParams, 0, 0); + size_t fhSize = ZSTD_writeFrameHeader(dst, dstCapacity, &cctx->appliedParams, 0, 0); FORWARD_IF_ERROR(fhSize, "ZSTD_writeFrameHeader failed"); dstCapacity -= fhSize; op += fhSize; @@ -4513,8 +5143,9 @@ static size_t ZSTD_writeEpilogue(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity) if (cctx->stage != ZSTDcs_ending) { /* write one last empty block, make it the "last" block */ U32 const cBlockHeader24 = 1 /* last block */ + (((U32)bt_raw)<<1) + 0; - RETURN_ERROR_IF(dstCapacity<4, dstSize_tooSmall, "no room for epilogue"); - MEM_writeLE32(op, cBlockHeader24); + ZSTD_STATIC_ASSERT(ZSTD_BLOCKHEADERSIZE == 3); + RETURN_ERROR_IF(dstCapacity<3, dstSize_tooSmall, "no room for epilogue"); + MEM_writeLE24(op, cBlockHeader24); op += ZSTD_blockHeaderSize; dstCapacity -= ZSTD_blockHeaderSize; } @@ -4537,9 +5168,9 @@ void ZSTD_CCtx_trace(ZSTD_CCtx* cctx, size_t extraCSize) (void)extraCSize; } -size_t ZSTD_compressEnd (ZSTD_CCtx* cctx, - void* dst, size_t dstCapacity, - const void* src, size_t srcSize) +size_t ZSTD_compressEnd_public(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize) { size_t endResult; size_t const cSize = ZSTD_compressContinue_internal(cctx, @@ -4563,6 +5194,14 @@ size_t ZSTD_compressEnd (ZSTD_CCtx* cctx, return cSize + endResult; } +/* NOTE: Must just wrap ZSTD_compressEnd_public() */ +size_t ZSTD_compressEnd(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize) +{ + return ZSTD_compressEnd_public(cctx, dst, dstCapacity, src, srcSize); +} + size_t ZSTD_compress_advanced (ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize, @@ -4591,7 +5230,7 @@ size_t ZSTD_compress_advanced_internal( FORWARD_IF_ERROR( ZSTD_compressBegin_internal(cctx, dict, dictSize, ZSTD_dct_auto, ZSTD_dtlm_fast, NULL, params, srcSize, ZSTDb_not_buffered) , ""); - return ZSTD_compressEnd(cctx, dst, dstCapacity, src, srcSize); + return ZSTD_compressEnd_public(cctx, dst, dstCapacity, src, srcSize); } size_t ZSTD_compress_usingDict(ZSTD_CCtx* cctx, @@ -4709,7 +5348,7 @@ static size_t ZSTD_initCDict_internal( { size_t const dictID = ZSTD_compress_insertDictionary( &cdict->cBlockState, &cdict->matchState, NULL, &cdict->workspace, ¶ms, cdict->dictContent, cdict->dictContentSize, - dictContentType, ZSTD_dtlm_full, cdict->entropyWorkspace); + dictContentType, ZSTD_dtlm_full, ZSTD_tfp_forCDict, cdict->entropyWorkspace); FORWARD_IF_ERROR(dictID, "ZSTD_compress_insertDictionary failed"); assert(dictID <= (size_t)(U32)-1); cdict->dictID = (U32)dictID; @@ -4813,7 +5452,7 @@ ZSTD_CDict* ZSTD_createCDict_advanced2( if (!cdict) return NULL; - if (ZSTD_isError( ZSTD_initCDict_internal(cdict, + if (!cdict || ZSTD_isError( ZSTD_initCDict_internal(cdict, dict, dictSize, dictLoadMethod, dictContentType, cctxParams) )) { @@ -4908,6 +5547,7 @@ const ZSTD_CDict* ZSTD_initStaticCDict( params.cParams = cParams; params.useRowMatchFinder = useRowMatchFinder; cdict->useRowMatchFinder = useRowMatchFinder; + cdict->compressionLevel = ZSTD_NO_CLEVEL; if (ZSTD_isError( ZSTD_initCDict_internal(cdict, dict, dictSize, @@ -4987,12 +5627,17 @@ size_t ZSTD_compressBegin_usingCDict_advanced( /* ZSTD_compressBegin_usingCDict() : * cdict must be != NULL */ -size_t ZSTD_compressBegin_usingCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict) +size_t ZSTD_compressBegin_usingCDict_deprecated(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict) { ZSTD_frameParameters const fParams = { 0 /*content*/, 0 /*checksum*/, 0 /*noDictID*/ }; return ZSTD_compressBegin_usingCDict_internal(cctx, cdict, fParams, ZSTD_CONTENTSIZE_UNKNOWN); } +size_t ZSTD_compressBegin_usingCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict) +{ + return ZSTD_compressBegin_usingCDict_deprecated(cctx, cdict); +} + /*! ZSTD_compress_usingCDict_internal(): * Implementation of various ZSTD_compress_usingCDict* functions. */ @@ -5002,7 +5647,7 @@ static size_t ZSTD_compress_usingCDict_internal(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict, ZSTD_frameParameters fParams) { FORWARD_IF_ERROR(ZSTD_compressBegin_usingCDict_internal(cctx, cdict, fParams, srcSize), ""); /* will check if cdict != NULL */ - return ZSTD_compressEnd(cctx, dst, dstCapacity, src, srcSize); + return ZSTD_compressEnd_public(cctx, dst, dstCapacity, src, srcSize); } /*! ZSTD_compress_usingCDict_advanced(): @@ -5199,30 +5844,41 @@ size_t ZSTD_initCStream(ZSTD_CStream* zcs, int compressionLevel) static size_t ZSTD_nextInputSizeHint(const ZSTD_CCtx* cctx) { - size_t hintInSize = cctx->inBuffTarget - cctx->inBuffPos; - if (hintInSize==0) hintInSize = cctx->blockSize; - return hintInSize; + if (cctx->appliedParams.inBufferMode == ZSTD_bm_stable) { + return cctx->blockSize - cctx->stableIn_notConsumed; + } + assert(cctx->appliedParams.inBufferMode == ZSTD_bm_buffered); + { size_t hintInSize = cctx->inBuffTarget - cctx->inBuffPos; + if (hintInSize==0) hintInSize = cctx->blockSize; + return hintInSize; + } } /* ZSTD_compressStream_generic(): * internal function for all *compressStream*() variants - * non-static, because can be called from zstdmt_compress.c - * @return : hint size for next input */ + * @return : hint size for next input to complete ongoing block */ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, ZSTD_outBuffer* output, ZSTD_inBuffer* input, ZSTD_EndDirective const flushMode) { - const char* const istart = (const char*)input->src; - const char* const iend = input->size != 0 ? istart + input->size : istart; - const char* ip = input->pos != 0 ? istart + input->pos : istart; - char* const ostart = (char*)output->dst; - char* const oend = output->size != 0 ? ostart + output->size : ostart; - char* op = output->pos != 0 ? ostart + output->pos : ostart; + const char* const istart = (assert(input != NULL), (const char*)input->src); + const char* const iend = (istart != NULL) ? istart + input->size : istart; + const char* ip = (istart != NULL) ? istart + input->pos : istart; + char* const ostart = (assert(output != NULL), (char*)output->dst); + char* const oend = (ostart != NULL) ? ostart + output->size : ostart; + char* op = (ostart != NULL) ? ostart + output->pos : ostart; U32 someMoreWork = 1; /* check expectations */ - DEBUGLOG(5, "ZSTD_compressStream_generic, flush=%u", (unsigned)flushMode); + DEBUGLOG(5, "ZSTD_compressStream_generic, flush=%i, srcSize = %zu", (int)flushMode, input->size - input->pos); + assert(zcs != NULL); + if (zcs->appliedParams.inBufferMode == ZSTD_bm_stable) { + assert(input->pos >= zcs->stableIn_notConsumed); + input->pos -= zcs->stableIn_notConsumed; + if (ip) ip -= zcs->stableIn_notConsumed; + zcs->stableIn_notConsumed = 0; + } if (zcs->appliedParams.inBufferMode == ZSTD_bm_buffered) { assert(zcs->inBuff != NULL); assert(zcs->inBuffSize > 0); @@ -5231,8 +5887,10 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, assert(zcs->outBuff != NULL); assert(zcs->outBuffSize > 0); } - assert(output->pos <= output->size); + if (input->src == NULL) assert(input->size == 0); assert(input->pos <= input->size); + if (output->dst == NULL) assert(output->size == 0); + assert(output->pos <= output->size); assert((U32)flushMode <= (U32)ZSTD_e_end); while (someMoreWork) { @@ -5247,7 +5905,7 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, || zcs->appliedParams.outBufferMode == ZSTD_bm_stable) /* OR we are allowed to return dstSizeTooSmall */ && (zcs->inBuffPos == 0) ) { /* shortcut to compression pass directly into output buffer */ - size_t const cSize = ZSTD_compressEnd(zcs, + size_t const cSize = ZSTD_compressEnd_public(zcs, op, oend-op, ip, iend-ip); DEBUGLOG(4, "ZSTD_compressEnd : cSize=%u", (unsigned)cSize); FORWARD_IF_ERROR(cSize, "ZSTD_compressEnd failed"); @@ -5264,8 +5922,7 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, zcs->inBuff + zcs->inBuffPos, toLoad, ip, iend-ip); zcs->inBuffPos += loaded; - if (loaded != 0) - ip += loaded; + if (ip) ip += loaded; if ( (flushMode == ZSTD_e_continue) && (zcs->inBuffPos < zcs->inBuffTarget) ) { /* not enough input to fill full block : stop here */ @@ -5276,6 +5933,20 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, /* empty */ someMoreWork = 0; break; } + } else { + assert(zcs->appliedParams.inBufferMode == ZSTD_bm_stable); + if ( (flushMode == ZSTD_e_continue) + && ( (size_t)(iend - ip) < zcs->blockSize) ) { + /* can't compress a full block : stop here */ + zcs->stableIn_notConsumed = (size_t)(iend - ip); + ip = iend; /* pretend to have consumed input */ + someMoreWork = 0; break; + } + if ( (flushMode == ZSTD_e_flush) + && (ip == iend) ) { + /* empty */ + someMoreWork = 0; break; + } } /* compress current block (note : this stage cannot be stopped in the middle) */ DEBUGLOG(5, "stream compression stage (flushMode==%u)", flushMode); @@ -5283,9 +5954,8 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, void* cDst; size_t cSize; size_t oSize = oend-op; - size_t const iSize = inputBuffered - ? zcs->inBuffPos - zcs->inToCompress - : MIN((size_t)(iend - ip), zcs->blockSize); + size_t const iSize = inputBuffered ? zcs->inBuffPos - zcs->inToCompress + : MIN((size_t)(iend - ip), zcs->blockSize); if (oSize >= ZSTD_compressBound(iSize) || zcs->appliedParams.outBufferMode == ZSTD_bm_stable) cDst = op; /* compress into output buffer, to skip flush stage */ else @@ -5293,9 +5963,9 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, if (inputBuffered) { unsigned const lastBlock = (flushMode == ZSTD_e_end) && (ip==iend); cSize = lastBlock ? - ZSTD_compressEnd(zcs, cDst, oSize, + ZSTD_compressEnd_public(zcs, cDst, oSize, zcs->inBuff + zcs->inToCompress, iSize) : - ZSTD_compressContinue(zcs, cDst, oSize, + ZSTD_compressContinue_public(zcs, cDst, oSize, zcs->inBuff + zcs->inToCompress, iSize); FORWARD_IF_ERROR(cSize, "%s", lastBlock ? "ZSTD_compressEnd failed" : "ZSTD_compressContinue failed"); zcs->frameEnded = lastBlock; @@ -5308,19 +5978,16 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs, if (!lastBlock) assert(zcs->inBuffTarget <= zcs->inBuffSize); zcs->inToCompress = zcs->inBuffPos; - } else { - unsigned const lastBlock = (ip + iSize == iend); - assert(flushMode == ZSTD_e_end /* Already validated */); + } else { /* !inputBuffered, hence ZSTD_bm_stable */ + unsigned const lastBlock = (flushMode == ZSTD_e_end) && (ip + iSize == iend); cSize = lastBlock ? - ZSTD_compressEnd(zcs, cDst, oSize, ip, iSize) : - ZSTD_compressContinue(zcs, cDst, oSize, ip, iSize); + ZSTD_compressEnd_public(zcs, cDst, oSize, ip, iSize) : + ZSTD_compressContinue_public(zcs, cDst, oSize, ip, iSize); /* Consume the input prior to error checking to mirror buffered mode. */ - if (iSize > 0) - ip += iSize; + if (ip) ip += iSize; FORWARD_IF_ERROR(cSize, "%s", lastBlock ? "ZSTD_compressEnd failed" : "ZSTD_compressContinue failed"); zcs->frameEnded = lastBlock; - if (lastBlock) - assert(ip == iend); + if (lastBlock) assert(ip == iend); } if (cDst == op) { /* no need to flush */ op += cSize; @@ -5390,8 +6057,10 @@ size_t ZSTD_compressStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output, ZSTD_inBuf /* After a compression call set the expected input/output buffer. * This is validated at the start of the next compression call. */ -static void ZSTD_setBufferExpectations(ZSTD_CCtx* cctx, ZSTD_outBuffer const* output, ZSTD_inBuffer const* input) +static void +ZSTD_setBufferExpectations(ZSTD_CCtx* cctx, const ZSTD_outBuffer* output, const ZSTD_inBuffer* input) { + DEBUGLOG(5, "ZSTD_setBufferExpectations (for advanced stable in/out modes)"); if (cctx->appliedParams.inBufferMode == ZSTD_bm_stable) { cctx->expectedInBuffer = *input; } @@ -5410,22 +6079,22 @@ static size_t ZSTD_checkBufferStability(ZSTD_CCtx const* cctx, { if (cctx->appliedParams.inBufferMode == ZSTD_bm_stable) { ZSTD_inBuffer const expect = cctx->expectedInBuffer; - if (expect.src != input->src || expect.pos != input->pos || expect.size != input->size) - RETURN_ERROR(srcBuffer_wrong, "ZSTD_c_stableInBuffer enabled but input differs!"); - if (endOp != ZSTD_e_end) - RETURN_ERROR(srcBuffer_wrong, "ZSTD_c_stableInBuffer can only be used with ZSTD_e_end!"); + if (expect.src != input->src || expect.pos != input->pos) + RETURN_ERROR(stabilityCondition_notRespected, "ZSTD_c_stableInBuffer enabled but input differs!"); } + (void)endOp; if (cctx->appliedParams.outBufferMode == ZSTD_bm_stable) { size_t const outBufferSize = output->size - output->pos; if (cctx->expectedOutBufferSize != outBufferSize) - RETURN_ERROR(dstBuffer_wrong, "ZSTD_c_stableOutBuffer enabled but output size differs!"); + RETURN_ERROR(stabilityCondition_notRespected, "ZSTD_c_stableOutBuffer enabled but output size differs!"); } return 0; } static size_t ZSTD_CCtx_init_compressStream2(ZSTD_CCtx* cctx, ZSTD_EndDirective endOp, - size_t inSize) { + size_t inSize) +{ ZSTD_CCtx_params params = cctx->requestedParams; ZSTD_prefixDict const prefixDict = cctx->prefixDict; FORWARD_IF_ERROR( ZSTD_initLocalDict(cctx) , ""); /* Init the local dict if present. */ @@ -5439,9 +6108,9 @@ static size_t ZSTD_CCtx_init_compressStream2(ZSTD_CCtx* cctx, params.compressionLevel = cctx->cdict->compressionLevel; } DEBUGLOG(4, "ZSTD_compressStream2 : transparent init stage"); - if (endOp == ZSTD_e_end) cctx->pledgedSrcSizePlusOne = inSize + 1; /* auto-fix pledgedSrcSize */ - { - size_t const dictSize = prefixDict.dict + if (endOp == ZSTD_e_end) cctx->pledgedSrcSizePlusOne = inSize + 1; /* auto-determine pledgedSrcSize */ + + { size_t const dictSize = prefixDict.dict ? prefixDict.dictSize : (cctx->cdict ? cctx->cdict->dictContentSize : 0); ZSTD_cParamMode_e const mode = ZSTD_getCParamMode(cctx->cdict, ¶ms, cctx->pledgedSrcSizePlusOne - 1); @@ -5453,6 +6122,9 @@ static size_t ZSTD_CCtx_init_compressStream2(ZSTD_CCtx* cctx, params.useBlockSplitter = ZSTD_resolveBlockSplitterMode(params.useBlockSplitter, ¶ms.cParams); params.ldmParams.enableLdm = ZSTD_resolveEnableLdm(params.ldmParams.enableLdm, ¶ms.cParams); params.useRowMatchFinder = ZSTD_resolveRowMatchFinderMode(params.useRowMatchFinder, ¶ms.cParams); + params.validateSequences = ZSTD_resolveExternalSequenceValidation(params.validateSequences); + params.maxBlockSize = ZSTD_resolveMaxBlockSize(params.maxBlockSize); + params.searchForExternalRepcodes = ZSTD_resolveExternalRepcodeSearch(params.searchForExternalRepcodes, params.compressionLevel); { U64 const pledgedSrcSize = cctx->pledgedSrcSizePlusOne - 1; assert(!ZSTD_isError(ZSTD_checkCParams(params.cParams))); @@ -5479,6 +6151,8 @@ static size_t ZSTD_CCtx_init_compressStream2(ZSTD_CCtx* cctx, return 0; } +/* @return provides a minimum amount of data remaining to be flushed from internal buffers + */ size_t ZSTD_compressStream2( ZSTD_CCtx* cctx, ZSTD_outBuffer* output, ZSTD_inBuffer* input, @@ -5493,8 +6167,27 @@ size_t ZSTD_compressStream2( ZSTD_CCtx* cctx, /* transparent initialization stage */ if (cctx->streamStage == zcss_init) { - FORWARD_IF_ERROR(ZSTD_CCtx_init_compressStream2(cctx, endOp, input->size), "CompressStream2 initialization failed"); - ZSTD_setBufferExpectations(cctx, output, input); /* Set initial buffer expectations now that we've initialized */ + size_t const inputSize = input->size - input->pos; /* no obligation to start from pos==0 */ + size_t const totalInputSize = inputSize + cctx->stableIn_notConsumed; + if ( (cctx->requestedParams.inBufferMode == ZSTD_bm_stable) /* input is presumed stable, across invocations */ + && (endOp == ZSTD_e_continue) /* no flush requested, more input to come */ + && (totalInputSize < ZSTD_BLOCKSIZE_MAX) ) { /* not even reached one block yet */ + if (cctx->stableIn_notConsumed) { /* not the first time */ + /* check stable source guarantees */ + RETURN_ERROR_IF(input->src != cctx->expectedInBuffer.src, stabilityCondition_notRespected, "stableInBuffer condition not respected: wrong src pointer"); + RETURN_ERROR_IF(input->pos != cctx->expectedInBuffer.size, stabilityCondition_notRespected, "stableInBuffer condition not respected: externally modified pos"); + } + /* pretend input was consumed, to give a sense forward progress */ + input->pos = input->size; + /* save stable inBuffer, for later control, and flush/end */ + cctx->expectedInBuffer = *input; + /* but actually input wasn't consumed, so keep track of position from where compression shall resume */ + cctx->stableIn_notConsumed += inputSize; + /* don't initialize yet, wait for the first block of flush() order, for better parameters adaptation */ + return ZSTD_FRAMEHEADERSIZE_MIN(cctx->requestedParams.format); /* at least some header to produce */ + } + FORWARD_IF_ERROR(ZSTD_CCtx_init_compressStream2(cctx, endOp, totalInputSize), "compressStream2 initialization failed"); + ZSTD_setBufferExpectations(cctx, output, input); /* Set initial buffer expectations now that we've initialized */ } /* end of transparent initialization stage */ @@ -5512,13 +6205,20 @@ size_t ZSTD_compressStream2_simpleArgs ( const void* src, size_t srcSize, size_t* srcPos, ZSTD_EndDirective endOp) { - ZSTD_outBuffer output = { dst, dstCapacity, *dstPos }; - ZSTD_inBuffer input = { src, srcSize, *srcPos }; + ZSTD_outBuffer output; + ZSTD_inBuffer input; + output.dst = dst; + output.size = dstCapacity; + output.pos = *dstPos; + input.src = src; + input.size = srcSize; + input.pos = *srcPos; /* ZSTD_compressStream2() will check validity of dstPos and srcPos */ - size_t const cErr = ZSTD_compressStream2(cctx, &output, &input, endOp); - *dstPos = output.pos; - *srcPos = input.pos; - return cErr; + { size_t const cErr = ZSTD_compressStream2(cctx, &output, &input, endOp); + *dstPos = output.pos; + *srcPos = input.pos; + return cErr; + } } size_t ZSTD_compress2(ZSTD_CCtx* cctx, @@ -5541,6 +6241,7 @@ size_t ZSTD_compress2(ZSTD_CCtx* cctx, /* Reset to the original values. */ cctx->requestedParams.inBufferMode = originalInBufferMode; cctx->requestedParams.outBufferMode = originalOutBufferMode; + FORWARD_IF_ERROR(result, "ZSTD_compressStream2_simpleArgs failed"); if (result != 0) { /* compression not completed, due to lack of output space */ assert(oPos == dstCapacity); @@ -5551,64 +6252,61 @@ size_t ZSTD_compress2(ZSTD_CCtx* cctx, } } -typedef struct { - U32 idx; /* Index in array of ZSTD_Sequence */ - U32 posInSequence; /* Position within sequence at idx */ - size_t posInSrc; /* Number of bytes given by sequences provided so far */ -} ZSTD_sequencePosition; - /* ZSTD_validateSequence() : * @offCode : is presumed to follow format required by ZSTD_storeSeq() * @returns a ZSTD error code if sequence is not valid */ static size_t -ZSTD_validateSequence(U32 offCode, U32 matchLength, - size_t posInSrc, U32 windowLog, size_t dictSize) +ZSTD_validateSequence(U32 offCode, U32 matchLength, U32 minMatch, + size_t posInSrc, U32 windowLog, size_t dictSize, int useSequenceProducer) { - U32 const windowSize = 1 << windowLog; + U32 const windowSize = 1u << windowLog; /* posInSrc represents the amount of data the decoder would decode up to this point. * As long as the amount of data decoded is less than or equal to window size, offsets may be * larger than the total length of output decoded in order to reference the dict, even larger than * window size. After output surpasses windowSize, we're limited to windowSize offsets again. */ size_t const offsetBound = posInSrc > windowSize ? (size_t)windowSize : posInSrc + (size_t)dictSize; - RETURN_ERROR_IF(offCode > STORE_OFFSET(offsetBound), corruption_detected, "Offset too large!"); - RETURN_ERROR_IF(matchLength < MINMATCH, corruption_detected, "Matchlength too small"); + size_t const matchLenLowerBound = (minMatch == 3 || useSequenceProducer) ? 3 : 4; + RETURN_ERROR_IF(offCode > OFFSET_TO_OFFBASE(offsetBound), externalSequences_invalid, "Offset too large!"); + /* Validate maxNbSeq is large enough for the given matchLength and minMatch */ + RETURN_ERROR_IF(matchLength < matchLenLowerBound, externalSequences_invalid, "Matchlength too small for the minMatch"); return 0; } /* Returns an offset code, given a sequence's raw offset, the ongoing repcode array, and whether litLength == 0 */ -static U32 ZSTD_finalizeOffCode(U32 rawOffset, const U32 rep[ZSTD_REP_NUM], U32 ll0) +static U32 ZSTD_finalizeOffBase(U32 rawOffset, const U32 rep[ZSTD_REP_NUM], U32 ll0) { - U32 offCode = STORE_OFFSET(rawOffset); + U32 offBase = OFFSET_TO_OFFBASE(rawOffset); if (!ll0 && rawOffset == rep[0]) { - offCode = STORE_REPCODE_1; + offBase = REPCODE1_TO_OFFBASE; } else if (rawOffset == rep[1]) { - offCode = STORE_REPCODE(2 - ll0); + offBase = REPCODE_TO_OFFBASE(2 - ll0); } else if (rawOffset == rep[2]) { - offCode = STORE_REPCODE(3 - ll0); + offBase = REPCODE_TO_OFFBASE(3 - ll0); } else if (ll0 && rawOffset == rep[0] - 1) { - offCode = STORE_REPCODE_3; + offBase = REPCODE3_TO_OFFBASE; } - return offCode; + return offBase; } -/* Returns 0 on success, and a ZSTD_error otherwise. This function scans through an array of - * ZSTD_Sequence, storing the sequences it finds, until it reaches a block delimiter. - */ -static size_t +size_t ZSTD_copySequencesToSeqStoreExplicitBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* seqPos, const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, - const void* src, size_t blockSize) + const void* src, size_t blockSize, + ZSTD_paramSwitch_e externalRepSearch) { U32 idx = seqPos->idx; + U32 const startIdx = idx; BYTE const* ip = (BYTE const*)(src); const BYTE* const iend = ip + blockSize; repcodes_t updatedRepcodes; U32 dictSize; + DEBUGLOG(5, "ZSTD_copySequencesToSeqStoreExplicitBlockDelim (blockSize = %zu)", blockSize); + if (cctx->cdict) { dictSize = (U32)cctx->cdict->dictContentSize; } else if (cctx->prefixDict.dict) { @@ -5617,25 +6315,55 @@ ZSTD_copySequencesToSeqStoreExplicitBlockDelim(ZSTD_CCtx* cctx, dictSize = 0; } ZSTD_memcpy(updatedRepcodes.rep, cctx->blockState.prevCBlock->rep, sizeof(repcodes_t)); - for (; (inSeqs[idx].matchLength != 0 || inSeqs[idx].offset != 0) && idx < inSeqsSize; ++idx) { + for (; idx < inSeqsSize && (inSeqs[idx].matchLength != 0 || inSeqs[idx].offset != 0); ++idx) { U32 const litLength = inSeqs[idx].litLength; - U32 const ll0 = (litLength == 0); U32 const matchLength = inSeqs[idx].matchLength; - U32 const offCode = ZSTD_finalizeOffCode(inSeqs[idx].offset, updatedRepcodes.rep, ll0); - ZSTD_updateRep(updatedRepcodes.rep, offCode, ll0); + U32 offBase; + + if (externalRepSearch == ZSTD_ps_disable) { + offBase = OFFSET_TO_OFFBASE(inSeqs[idx].offset); + } else { + U32 const ll0 = (litLength == 0); + offBase = ZSTD_finalizeOffBase(inSeqs[idx].offset, updatedRepcodes.rep, ll0); + ZSTD_updateRep(updatedRepcodes.rep, offBase, ll0); + } - DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offCode, matchLength, litLength); + DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offBase, matchLength, litLength); if (cctx->appliedParams.validateSequences) { seqPos->posInSrc += litLength + matchLength; - FORWARD_IF_ERROR(ZSTD_validateSequence(offCode, matchLength, seqPos->posInSrc, - cctx->appliedParams.cParams.windowLog, dictSize), + FORWARD_IF_ERROR(ZSTD_validateSequence(offBase, matchLength, cctx->appliedParams.cParams.minMatch, seqPos->posInSrc, + cctx->appliedParams.cParams.windowLog, dictSize, ZSTD_hasExtSeqProd(&cctx->appliedParams)), "Sequence validation failed"); } - RETURN_ERROR_IF(idx - seqPos->idx > cctx->seqStore.maxNbSeq, memory_allocation, + RETURN_ERROR_IF(idx - seqPos->idx >= cctx->seqStore.maxNbSeq, externalSequences_invalid, "Not enough memory allocated. Try adjusting ZSTD_c_minMatch."); - ZSTD_storeSeq(&cctx->seqStore, litLength, ip, iend, offCode, matchLength); + ZSTD_storeSeq(&cctx->seqStore, litLength, ip, iend, offBase, matchLength); ip += matchLength + litLength; } + + /* If we skipped repcode search while parsing, we need to update repcodes now */ + assert(externalRepSearch != ZSTD_ps_auto); + assert(idx >= startIdx); + if (externalRepSearch == ZSTD_ps_disable && idx != startIdx) { + U32* const rep = updatedRepcodes.rep; + U32 lastSeqIdx = idx - 1; /* index of last non-block-delimiter sequence */ + + if (lastSeqIdx >= startIdx + 2) { + rep[2] = inSeqs[lastSeqIdx - 2].offset; + rep[1] = inSeqs[lastSeqIdx - 1].offset; + rep[0] = inSeqs[lastSeqIdx].offset; + } else if (lastSeqIdx == startIdx + 1) { + rep[2] = rep[0]; + rep[1] = inSeqs[lastSeqIdx - 1].offset; + rep[0] = inSeqs[lastSeqIdx].offset; + } else { + assert(lastSeqIdx == startIdx); + rep[2] = rep[1]; + rep[1] = rep[0]; + rep[0] = inSeqs[lastSeqIdx].offset; + } + } + ZSTD_memcpy(cctx->blockState.nextCBlock->rep, updatedRepcodes.rep, sizeof(repcodes_t)); if (inSeqs[idx].litLength) { @@ -5644,26 +6372,15 @@ ZSTD_copySequencesToSeqStoreExplicitBlockDelim(ZSTD_CCtx* cctx, ip += inSeqs[idx].litLength; seqPos->posInSrc += inSeqs[idx].litLength; } - RETURN_ERROR_IF(ip != iend, corruption_detected, "Blocksize doesn't agree with block delimiter!"); + RETURN_ERROR_IF(ip != iend, externalSequences_invalid, "Blocksize doesn't agree with block delimiter!"); seqPos->idx = idx+1; return 0; } -/* Returns the number of bytes to move the current read position back by. Only non-zero - * if we ended up splitting a sequence. Otherwise, it may return a ZSTD error if something - * went wrong. - * - * This function will attempt to scan through blockSize bytes represented by the sequences - * in inSeqs, storing any (partial) sequences. - * - * Occasionally, we may want to change the actual number of bytes we consumed from inSeqs to - * avoid splitting a match, or to avoid splitting a match such that it would produce a match - * smaller than MINMATCH. In this case, we return the number of bytes that we didn't read from this block. - */ -static size_t +size_t ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* seqPos, const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, - const void* src, size_t blockSize) + const void* src, size_t blockSize, ZSTD_paramSwitch_e externalRepSearch) { U32 idx = seqPos->idx; U32 startPosInSequence = seqPos->posInSequence; @@ -5675,6 +6392,9 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* U32 bytesAdjustment = 0; U32 finalMatchSplit = 0; + /* TODO(embg) support fast parsing mode in noBlockDelim mode */ + (void)externalRepSearch; + if (cctx->cdict) { dictSize = cctx->cdict->dictContentSize; } else if (cctx->prefixDict.dict) { @@ -5682,7 +6402,7 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* } else { dictSize = 0; } - DEBUGLOG(5, "ZSTD_copySequencesToSeqStore: idx: %u PIS: %u blockSize: %zu", idx, startPosInSequence, blockSize); + DEBUGLOG(5, "ZSTD_copySequencesToSeqStoreNoBlockDelim: idx: %u PIS: %u blockSize: %zu", idx, startPosInSequence, blockSize); DEBUGLOG(5, "Start seq: idx: %u (of: %u ml: %u ll: %u)", idx, inSeqs[idx].offset, inSeqs[idx].matchLength, inSeqs[idx].litLength); ZSTD_memcpy(updatedRepcodes.rep, cctx->blockState.prevCBlock->rep, sizeof(repcodes_t)); while (endPosInSequence && idx < inSeqsSize && !finalMatchSplit) { @@ -5690,7 +6410,7 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* U32 litLength = currSeq.litLength; U32 matchLength = currSeq.matchLength; U32 const rawOffset = currSeq.offset; - U32 offCode; + U32 offBase; /* Modify the sequence depending on where endPosInSequence lies */ if (endPosInSequence >= currSeq.litLength + currSeq.matchLength) { @@ -5704,7 +6424,6 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* /* Move to the next sequence */ endPosInSequence -= currSeq.litLength + currSeq.matchLength; startPosInSequence = 0; - idx++; } else { /* This is the final (partial) sequence we're adding from inSeqs, and endPosInSequence does not reach the end of the match. So, we have to split the sequence */ @@ -5744,21 +6463,23 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* } /* Check if this offset can be represented with a repcode */ { U32 const ll0 = (litLength == 0); - offCode = ZSTD_finalizeOffCode(rawOffset, updatedRepcodes.rep, ll0); - ZSTD_updateRep(updatedRepcodes.rep, offCode, ll0); + offBase = ZSTD_finalizeOffBase(rawOffset, updatedRepcodes.rep, ll0); + ZSTD_updateRep(updatedRepcodes.rep, offBase, ll0); } if (cctx->appliedParams.validateSequences) { seqPos->posInSrc += litLength + matchLength; - FORWARD_IF_ERROR(ZSTD_validateSequence(offCode, matchLength, seqPos->posInSrc, - cctx->appliedParams.cParams.windowLog, dictSize), + FORWARD_IF_ERROR(ZSTD_validateSequence(offBase, matchLength, cctx->appliedParams.cParams.minMatch, seqPos->posInSrc, + cctx->appliedParams.cParams.windowLog, dictSize, ZSTD_hasExtSeqProd(&cctx->appliedParams)), "Sequence validation failed"); } - DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offCode, matchLength, litLength); - RETURN_ERROR_IF(idx - seqPos->idx > cctx->seqStore.maxNbSeq, memory_allocation, + DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offBase, matchLength, litLength); + RETURN_ERROR_IF(idx - seqPos->idx >= cctx->seqStore.maxNbSeq, externalSequences_invalid, "Not enough memory allocated. Try adjusting ZSTD_c_minMatch."); - ZSTD_storeSeq(&cctx->seqStore, litLength, ip, iend, offCode, matchLength); + ZSTD_storeSeq(&cctx->seqStore, litLength, ip, iend, offBase, matchLength); ip += matchLength + litLength; + if (!finalMatchSplit) + idx++; /* Next Sequence */ } DEBUGLOG(5, "Ending seq: idx: %u (of: %u ml: %u ll: %u)", idx, inSeqs[idx].offset, inSeqs[idx].matchLength, inSeqs[idx].litLength); assert(idx == inSeqsSize || endPosInSequence <= inSeqs[idx].litLength + inSeqs[idx].matchLength); @@ -5781,7 +6502,7 @@ ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* typedef size_t (*ZSTD_sequenceCopier) (ZSTD_CCtx* cctx, ZSTD_sequencePosition* seqPos, const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, - const void* src, size_t blockSize); + const void* src, size_t blockSize, ZSTD_paramSwitch_e externalRepSearch); static ZSTD_sequenceCopier ZSTD_selectSequenceCopier(ZSTD_sequenceFormat_e mode) { ZSTD_sequenceCopier sequenceCopier = NULL; @@ -5795,6 +6516,57 @@ static ZSTD_sequenceCopier ZSTD_selectSequenceCopier(ZSTD_sequenceFormat_e mode) return sequenceCopier; } +/* Discover the size of next block by searching for the delimiter. + * Note that a block delimiter **must** exist in this mode, + * otherwise it's an input error. + * The block size retrieved will be later compared to ensure it remains within bounds */ +static size_t +blockSize_explicitDelimiter(const ZSTD_Sequence* inSeqs, size_t inSeqsSize, ZSTD_sequencePosition seqPos) +{ + int end = 0; + size_t blockSize = 0; + size_t spos = seqPos.idx; + DEBUGLOG(6, "blockSize_explicitDelimiter : seq %zu / %zu", spos, inSeqsSize); + assert(spos <= inSeqsSize); + while (spos < inSeqsSize) { + end = (inSeqs[spos].offset == 0); + blockSize += inSeqs[spos].litLength + inSeqs[spos].matchLength; + if (end) { + if (inSeqs[spos].matchLength != 0) + RETURN_ERROR(externalSequences_invalid, "delimiter format error : both matchlength and offset must be == 0"); + break; + } + spos++; + } + if (!end) + RETURN_ERROR(externalSequences_invalid, "Reached end of sequences without finding a block delimiter"); + return blockSize; +} + +/* More a "target" block size */ +static size_t blockSize_noDelimiter(size_t blockSize, size_t remaining) +{ + int const lastBlock = (remaining <= blockSize); + return lastBlock ? remaining : blockSize; +} + +static size_t determine_blockSize(ZSTD_sequenceFormat_e mode, + size_t blockSize, size_t remaining, + const ZSTD_Sequence* inSeqs, size_t inSeqsSize, ZSTD_sequencePosition seqPos) +{ + DEBUGLOG(6, "determine_blockSize : remainingSize = %zu", remaining); + if (mode == ZSTD_sf_noBlockDelimiters) + return blockSize_noDelimiter(blockSize, remaining); + { size_t const explicitBlockSize = blockSize_explicitDelimiter(inSeqs, inSeqsSize, seqPos); + FORWARD_IF_ERROR(explicitBlockSize, "Error while determining block size with explicit delimiters"); + if (explicitBlockSize > blockSize) + RETURN_ERROR(externalSequences_invalid, "sequences incorrectly define a too large block"); + if (explicitBlockSize > remaining) + RETURN_ERROR(externalSequences_invalid, "sequences define a frame longer than source"); + return explicitBlockSize; + } +} + /* Compress, block-by-block, all of the sequences given. * * Returns the cumulative size of all compressed blocks (including their headers), @@ -5807,9 +6579,6 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, const void* src, size_t srcSize) { size_t cSize = 0; - U32 lastBlock; - size_t blockSize; - size_t compressedSeqsSize; size_t remaining = srcSize; ZSTD_sequencePosition seqPos = {0, 0, 0}; @@ -5829,22 +6598,29 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, } while (remaining) { + size_t compressedSeqsSize; size_t cBlockSize; size_t additionalByteAdjustment; - lastBlock = remaining <= cctx->blockSize; - blockSize = lastBlock ? (U32)remaining : (U32)cctx->blockSize; + size_t blockSize = determine_blockSize(cctx->appliedParams.blockDelimiters, + cctx->blockSize, remaining, + inSeqs, inSeqsSize, seqPos); + U32 const lastBlock = (blockSize == remaining); + FORWARD_IF_ERROR(blockSize, "Error while trying to determine block size"); + assert(blockSize <= remaining); ZSTD_resetSeqStore(&cctx->seqStore); - DEBUGLOG(4, "Working on new block. Blocksize: %zu", blockSize); + DEBUGLOG(5, "Working on new block. Blocksize: %zu (total:%zu)", blockSize, (ip - (const BYTE*)src) + blockSize); - additionalByteAdjustment = sequenceCopier(cctx, &seqPos, inSeqs, inSeqsSize, ip, blockSize); + additionalByteAdjustment = sequenceCopier(cctx, &seqPos, inSeqs, inSeqsSize, ip, blockSize, cctx->appliedParams.searchForExternalRepcodes); FORWARD_IF_ERROR(additionalByteAdjustment, "Bad sequence copy"); blockSize -= additionalByteAdjustment; /* If blocks are too small, emit as a nocompress block */ - if (blockSize < MIN_CBLOCK_SIZE+ZSTD_blockHeaderSize+1) { + /* TODO: See 3090. We reduced MIN_CBLOCK_SIZE from 3 to 2 so to compensate we are adding + * additional 1. We need to revisit and change this logic to be more consistent */ + if (blockSize < MIN_CBLOCK_SIZE+ZSTD_blockHeaderSize+1+1) { cBlockSize = ZSTD_noCompressBlock(op, dstCapacity, ip, blockSize, lastBlock); FORWARD_IF_ERROR(cBlockSize, "Nocompress block failed"); - DEBUGLOG(4, "Block too small, writing out nocompress block: cSize: %zu", cBlockSize); + DEBUGLOG(5, "Block too small, writing out nocompress block: cSize: %zu", cBlockSize); cSize += cBlockSize; ip += blockSize; op += cBlockSize; @@ -5853,6 +6629,7 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, continue; } + RETURN_ERROR_IF(dstCapacity < ZSTD_blockHeaderSize, dstSize_tooSmall, "not enough dstCapacity to write a new compressed block"); compressedSeqsSize = ZSTD_entropyCompressSeqStore(&cctx->seqStore, &cctx->blockState.prevCBlock->entropy, &cctx->blockState.nextCBlock->entropy, &cctx->appliedParams, @@ -5861,11 +6638,11 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, cctx->entropyWorkspace, ENTROPY_WORKSPACE_SIZE /* statically allocated in resetCCtx */, cctx->bmi2); FORWARD_IF_ERROR(compressedSeqsSize, "Compressing sequences of block failed"); - DEBUGLOG(4, "Compressed sequences size: %zu", compressedSeqsSize); + DEBUGLOG(5, "Compressed sequences size: %zu", compressedSeqsSize); if (!cctx->isFirstBlock && ZSTD_maybeRLE(&cctx->seqStore) && - ZSTD_isRLE((BYTE const*)src, srcSize)) { + ZSTD_isRLE(ip, blockSize)) { /* We don't want to emit our first block as a RLE even if it qualifies because * doing so will cause the decoder (cli only) to throw a "should consume all input error." * This is only an issue for zstd <= v1.4.3 @@ -5876,12 +6653,12 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, if (compressedSeqsSize == 0) { /* ZSTD_noCompressBlock writes the block header as well */ cBlockSize = ZSTD_noCompressBlock(op, dstCapacity, ip, blockSize, lastBlock); - FORWARD_IF_ERROR(cBlockSize, "Nocompress block failed"); - DEBUGLOG(4, "Writing out nocompress block, size: %zu", cBlockSize); + FORWARD_IF_ERROR(cBlockSize, "ZSTD_noCompressBlock failed"); + DEBUGLOG(5, "Writing out nocompress block, size: %zu", cBlockSize); } else if (compressedSeqsSize == 1) { cBlockSize = ZSTD_rleCompressBlock(op, dstCapacity, *ip, blockSize, lastBlock); - FORWARD_IF_ERROR(cBlockSize, "RLE compress block failed"); - DEBUGLOG(4, "Writing out RLE block, size: %zu", cBlockSize); + FORWARD_IF_ERROR(cBlockSize, "ZSTD_rleCompressBlock failed"); + DEBUGLOG(5, "Writing out RLE block, size: %zu", cBlockSize); } else { U32 cBlockHeader; /* Error checking and repcodes update */ @@ -5893,11 +6670,10 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, cBlockHeader = lastBlock + (((U32)bt_compressed)<<1) + (U32)(compressedSeqsSize << 3); MEM_writeLE24(op, cBlockHeader); cBlockSize = ZSTD_blockHeaderSize + compressedSeqsSize; - DEBUGLOG(4, "Writing out compressed block, size: %zu", cBlockSize); + DEBUGLOG(5, "Writing out compressed block, size: %zu", cBlockSize); } cSize += cBlockSize; - DEBUGLOG(4, "cSize running total: %zu", cSize); if (lastBlock) { break; @@ -5908,12 +6684,15 @@ ZSTD_compressSequences_internal(ZSTD_CCtx* cctx, dstCapacity -= cBlockSize; cctx->isFirstBlock = 0; } + DEBUGLOG(5, "cSize running total: %zu (remaining dstCapacity=%zu)", cSize, dstCapacity); } + DEBUGLOG(4, "cSize final total: %zu", cSize); return cSize; } -size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* dst, size_t dstCapacity, +size_t ZSTD_compressSequences(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, const ZSTD_Sequence* inSeqs, size_t inSeqsSize, const void* src, size_t srcSize) { @@ -5923,7 +6702,7 @@ size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* dst, size_t dstCapaci size_t frameHeaderSize = 0; /* Transparent initialization stage, same as compressStream2() */ - DEBUGLOG(3, "ZSTD_compressSequences()"); + DEBUGLOG(4, "ZSTD_compressSequences (dstCapacity=%zu)", dstCapacity); assert(cctx != NULL); FORWARD_IF_ERROR(ZSTD_CCtx_init_compressStream2(cctx, ZSTD_e_end, srcSize), "CCtx initialization failed"); /* Begin writing output, starting with frame header */ @@ -5951,26 +6730,34 @@ size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* dst, size_t dstCapaci cSize += 4; } - DEBUGLOG(3, "Final compressed size: %zu", cSize); + DEBUGLOG(4, "Final compressed size: %zu", cSize); return cSize; } /*====== Finalize ======*/ +static ZSTD_inBuffer inBuffer_forEndFlush(const ZSTD_CStream* zcs) +{ + const ZSTD_inBuffer nullInput = { NULL, 0, 0 }; + const int stableInput = (zcs->appliedParams.inBufferMode == ZSTD_bm_stable); + return stableInput ? zcs->expectedInBuffer : nullInput; +} + /*! ZSTD_flushStream() : * @return : amount of data remaining to flush */ size_t ZSTD_flushStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output) { - ZSTD_inBuffer input = { NULL, 0, 0 }; + ZSTD_inBuffer input = inBuffer_forEndFlush(zcs); + input.size = input.pos; /* do not ingest more input during flush */ return ZSTD_compressStream2(zcs, output, &input, ZSTD_e_flush); } size_t ZSTD_endStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output) { - ZSTD_inBuffer input = { NULL, 0, 0 }; + ZSTD_inBuffer input = inBuffer_forEndFlush(zcs); size_t const remainingToFlush = ZSTD_compressStream2(zcs, output, &input, ZSTD_e_end); - FORWARD_IF_ERROR( remainingToFlush , "ZSTD_compressStream2 failed"); + FORWARD_IF_ERROR(remainingToFlush , "ZSTD_compressStream2(,,ZSTD_e_end) failed"); if (zcs->appliedParams.nbWorkers > 0) return remainingToFlush; /* minimal estimation */ /* single thread mode : attempt to calculate remaining to flush more precisely */ { size_t const lastBlockSize = zcs->frameEnded ? 0 : ZSTD_BLOCKHEADERSIZE; @@ -6092,7 +6879,7 @@ static ZSTD_compressionParameters ZSTD_getCParams_internal(int compressionLevel, cp.targetLength = (unsigned)(-clampedCompressionLevel); } /* refine parameters based on srcSize & dictSize */ - return ZSTD_adjustCParams_internal(cp, srcSizeHint, dictSize, mode); + return ZSTD_adjustCParams_internal(cp, srcSizeHint, dictSize, mode, ZSTD_ps_auto); } } @@ -6127,3 +6914,29 @@ ZSTD_parameters ZSTD_getParams(int compressionLevel, unsigned long long srcSizeH if (srcSizeHint == 0) srcSizeHint = ZSTD_CONTENTSIZE_UNKNOWN; return ZSTD_getParams_internal(compressionLevel, srcSizeHint, dictSize, ZSTD_cpm_unknown); } + +void ZSTD_registerSequenceProducer( + ZSTD_CCtx* zc, + void* extSeqProdState, + ZSTD_sequenceProducer_F extSeqProdFunc +) { + assert(zc != NULL); + ZSTD_CCtxParams_registerSequenceProducer( + &zc->requestedParams, extSeqProdState, extSeqProdFunc + ); +} + +void ZSTD_CCtxParams_registerSequenceProducer( + ZSTD_CCtx_params* params, + void* extSeqProdState, + ZSTD_sequenceProducer_F extSeqProdFunc +) { + assert(params != NULL); + if (extSeqProdFunc != NULL) { + params->extSeqProdFunc = extSeqProdFunc; + params->extSeqProdState = extSeqProdState; + } else { + params->extSeqProdFunc = NULL; + params->extSeqProdState = NULL; + } +} diff --git a/lib/zstd/compress/zstd_compress_internal.h b/lib/zstd/compress/zstd_compress_internal.h index 71697a11ae30..53cb582a8d2b 100644 --- a/lib/zstd/compress/zstd_compress_internal.h +++ b/lib/zstd/compress/zstd_compress_internal.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -20,6 +21,7 @@ ***************************************/ #include "../common/zstd_internal.h" #include "zstd_cwksp.h" +#include "../common/bits.h" /* ZSTD_highbit32, ZSTD_NbCommonBytes */ /*-************************************* @@ -32,7 +34,7 @@ It's not a big deal though : candidate will just be sorted again. Additionally, candidate position 1 will be lost. But candidate 1 cannot hide a large tree of candidates, so it's a minimal loss. - The benefit is that ZSTD_DUBT_UNSORTED_MARK cannot be mishandled after table re-use with a different strategy. + The benefit is that ZSTD_DUBT_UNSORTED_MARK cannot be mishandled after table reuse with a different strategy. This constant is required by ZSTD_compressBlock_btlazy2() and ZSTD_reduceTable_internal() */ @@ -111,12 +113,13 @@ typedef struct { /* ZSTD_buildBlockEntropyStats() : * Builds entropy for the block. * @return : 0 on success or error code */ -size_t ZSTD_buildBlockEntropyStats(seqStore_t* seqStorePtr, - const ZSTD_entropyCTables_t* prevEntropy, - ZSTD_entropyCTables_t* nextEntropy, - const ZSTD_CCtx_params* cctxParams, - ZSTD_entropyCTablesMetadata_t* entropyMetadata, - void* workspace, size_t wkspSize); +size_t ZSTD_buildBlockEntropyStats( + const seqStore_t* seqStorePtr, + const ZSTD_entropyCTables_t* prevEntropy, + ZSTD_entropyCTables_t* nextEntropy, + const ZSTD_CCtx_params* cctxParams, + ZSTD_entropyCTablesMetadata_t* entropyMetadata, + void* workspace, size_t wkspSize); /* ******************************* * Compression internals structs * @@ -142,26 +145,33 @@ typedef struct { size_t capacity; /* The capacity starting from `seq` pointer */ } rawSeqStore_t; +typedef struct { + U32 idx; /* Index in array of ZSTD_Sequence */ + U32 posInSequence; /* Position within sequence at idx */ + size_t posInSrc; /* Number of bytes given by sequences provided so far */ +} ZSTD_sequencePosition; + UNUSED_ATTR static const rawSeqStore_t kNullRawSeqStore = {NULL, 0, 0, 0, 0}; typedef struct { - int price; - U32 off; - U32 mlen; - U32 litlen; - U32 rep[ZSTD_REP_NUM]; + int price; /* price from beginning of segment to this position */ + U32 off; /* offset of previous match */ + U32 mlen; /* length of previous match */ + U32 litlen; /* nb of literals since previous match */ + U32 rep[ZSTD_REP_NUM]; /* offset history after previous match */ } ZSTD_optimal_t; typedef enum { zop_dynamic=0, zop_predef } ZSTD_OptPrice_e; +#define ZSTD_OPT_SIZE (ZSTD_OPT_NUM+3) typedef struct { /* All tables are allocated inside cctx->workspace by ZSTD_resetCCtx_internal() */ unsigned* litFreq; /* table of literals statistics, of size 256 */ unsigned* litLengthFreq; /* table of litLength statistics, of size (MaxLL+1) */ unsigned* matchLengthFreq; /* table of matchLength statistics, of size (MaxML+1) */ unsigned* offCodeFreq; /* table of offCode statistics, of size (MaxOff+1) */ - ZSTD_match_t* matchTable; /* list of found matches, of size ZSTD_OPT_NUM+1 */ - ZSTD_optimal_t* priceTable; /* All positions tracked by optimal parser, of size ZSTD_OPT_NUM+1 */ + ZSTD_match_t* matchTable; /* list of found matches, of size ZSTD_OPT_SIZE */ + ZSTD_optimal_t* priceTable; /* All positions tracked by optimal parser, of size ZSTD_OPT_SIZE */ U32 litSum; /* nb of literals */ U32 litLengthSum; /* nb of litLength codes */ @@ -212,8 +222,10 @@ struct ZSTD_matchState_t { U32 hashLog3; /* dispatch table for matches of len==3 : larger == faster, more memory */ U32 rowHashLog; /* For row-based matchfinder: Hashlog based on nb of rows in the hashTable.*/ - U16* tagTable; /* For row-based matchFinder: A row-based table containing the hashes and head index. */ + BYTE* tagTable; /* For row-based matchFinder: A row-based table containing the hashes and head index. */ U32 hashCache[ZSTD_ROW_HASH_CACHE_SIZE]; /* For row-based matchFinder: a cache of hashes to improve speed */ + U64 hashSalt; /* For row-based matchFinder: salts the hash for reuse of tag table */ + U32 hashSaltEntropy; /* For row-based matchFinder: collects entropy for salt generation */ U32* hashTable; U32* hashTable3; @@ -228,6 +240,18 @@ struct ZSTD_matchState_t { const ZSTD_matchState_t* dictMatchState; ZSTD_compressionParameters cParams; const rawSeqStore_t* ldmSeqStore; + + /* Controls prefetching in some dictMatchState matchfinders. + * This behavior is controlled from the cctx ms. + * This parameter has no effect in the cdict ms. */ + int prefetchCDictTables; + + /* When == 0, lazy match finders insert every position. + * When != 0, lazy match finders only insert positions they search. + * This allows them to skip much faster over incompressible data, + * at a small cost to compression ratio. + */ + int lazySkipping; }; typedef struct { @@ -324,6 +348,25 @@ struct ZSTD_CCtx_params_s { /* Internal use, for createCCtxParams() and freeCCtxParams() only */ ZSTD_customMem customMem; + + /* Controls prefetching in some dictMatchState matchfinders */ + ZSTD_paramSwitch_e prefetchCDictTables; + + /* Controls whether zstd will fall back to an internal matchfinder + * if the external matchfinder returns an error code. */ + int enableMatchFinderFallback; + + /* Parameters for the external sequence producer API. + * Users set these parameters through ZSTD_registerSequenceProducer(). + * It is not possible to set these parameters individually through the public API. */ + void* extSeqProdState; + ZSTD_sequenceProducer_F extSeqProdFunc; + + /* Adjust the max block size*/ + size_t maxBlockSize; + + /* Controls repcode search in external sequence parsing */ + ZSTD_paramSwitch_e searchForExternalRepcodes; }; /* typedef'd to ZSTD_CCtx_params within "zstd.h" */ #define COMPRESS_SEQUENCES_WORKSPACE_SIZE (sizeof(unsigned) * (MaxSeq + 2)) @@ -404,6 +447,7 @@ struct ZSTD_CCtx_s { /* Stable in/out buffer verification */ ZSTD_inBuffer expectedInBuffer; + size_t stableIn_notConsumed; /* nb bytes within stable input buffer that are said to be consumed but are not */ size_t expectedOutBufferSize; /* Dictionary */ @@ -417,9 +461,14 @@ struct ZSTD_CCtx_s { /* Workspace for block splitter */ ZSTD_blockSplitCtx blockSplitCtx; + + /* Buffer for output from external sequence producer */ + ZSTD_Sequence* extSeqBuf; + size_t extSeqBufCapacity; }; typedef enum { ZSTD_dtlm_fast, ZSTD_dtlm_full } ZSTD_dictTableLoadMethod_e; +typedef enum { ZSTD_tfp_forCCtx, ZSTD_tfp_forCDict } ZSTD_tableFillPurpose_e; typedef enum { ZSTD_noDict = 0, @@ -441,7 +490,7 @@ typedef enum { * In this mode we take both the source size and the dictionary size * into account when selecting and adjusting the parameters. */ - ZSTD_cpm_unknown = 3, /* ZSTD_getCParams, ZSTD_getParams, ZSTD_adjustParams. + ZSTD_cpm_unknown = 3 /* ZSTD_getCParams, ZSTD_getParams, ZSTD_adjustParams. * We don't know what these parameters are for. We default to the legacy * behavior of taking both the source size and the dict size into account * when selecting and adjusting parameters. @@ -500,9 +549,11 @@ MEM_STATIC int ZSTD_cParam_withinBounds(ZSTD_cParameter cParam, int value) /* ZSTD_noCompressBlock() : * Writes uncompressed block to dst buffer from given src. * Returns the size of the block */ -MEM_STATIC size_t ZSTD_noCompressBlock (void* dst, size_t dstCapacity, const void* src, size_t srcSize, U32 lastBlock) +MEM_STATIC size_t +ZSTD_noCompressBlock(void* dst, size_t dstCapacity, const void* src, size_t srcSize, U32 lastBlock) { U32 const cBlockHeader24 = lastBlock + (((U32)bt_raw)<<1) + (U32)(srcSize << 3); + DEBUGLOG(5, "ZSTD_noCompressBlock (srcSize=%zu, dstCapacity=%zu)", srcSize, dstCapacity); RETURN_ERROR_IF(srcSize + ZSTD_blockHeaderSize > dstCapacity, dstSize_tooSmall, "dst buf too small for uncompressed block"); MEM_writeLE24(dst, cBlockHeader24); @@ -510,7 +561,8 @@ MEM_STATIC size_t ZSTD_noCompressBlock (void* dst, size_t dstCapacity, const voi return ZSTD_blockHeaderSize + srcSize; } -MEM_STATIC size_t ZSTD_rleCompressBlock (void* dst, size_t dstCapacity, BYTE src, size_t srcSize, U32 lastBlock) +MEM_STATIC size_t +ZSTD_rleCompressBlock(void* dst, size_t dstCapacity, BYTE src, size_t srcSize, U32 lastBlock) { BYTE* const op = (BYTE*)dst; U32 const cBlockHeader = lastBlock + (((U32)bt_rle)<<1) + (U32)(srcSize << 3); @@ -529,7 +581,7 @@ MEM_STATIC size_t ZSTD_minGain(size_t srcSize, ZSTD_strategy strat) { U32 const minlog = (strat>=ZSTD_btultra) ? (U32)(strat) - 1 : 6; ZSTD_STATIC_ASSERT(ZSTD_btultra == 8); - assert(ZSTD_cParam_withinBounds(ZSTD_c_strategy, strat)); + assert(ZSTD_cParam_withinBounds(ZSTD_c_strategy, (int)strat)); return (srcSize >> minlog) + 2; } @@ -565,29 +617,27 @@ ZSTD_safecopyLiterals(BYTE* op, BYTE const* ip, BYTE const* const iend, BYTE con while (ip < iend) *op++ = *ip++; } -#define ZSTD_REP_MOVE (ZSTD_REP_NUM-1) -#define STORE_REPCODE_1 STORE_REPCODE(1) -#define STORE_REPCODE_2 STORE_REPCODE(2) -#define STORE_REPCODE_3 STORE_REPCODE(3) -#define STORE_REPCODE(r) (assert((r)>=1), assert((r)<=3), (r)-1) -#define STORE_OFFSET(o) (assert((o)>0), o + ZSTD_REP_MOVE) -#define STORED_IS_OFFSET(o) ((o) > ZSTD_REP_MOVE) -#define STORED_IS_REPCODE(o) ((o) <= ZSTD_REP_MOVE) -#define STORED_OFFSET(o) (assert(STORED_IS_OFFSET(o)), (o)-ZSTD_REP_MOVE) -#define STORED_REPCODE(o) (assert(STORED_IS_REPCODE(o)), (o)+1) /* returns ID 1,2,3 */ -#define STORED_TO_OFFBASE(o) ((o)+1) -#define OFFBASE_TO_STORED(o) ((o)-1) + +#define REPCODE1_TO_OFFBASE REPCODE_TO_OFFBASE(1) +#define REPCODE2_TO_OFFBASE REPCODE_TO_OFFBASE(2) +#define REPCODE3_TO_OFFBASE REPCODE_TO_OFFBASE(3) +#define REPCODE_TO_OFFBASE(r) (assert((r)>=1), assert((r)<=ZSTD_REP_NUM), (r)) /* accepts IDs 1,2,3 */ +#define OFFSET_TO_OFFBASE(o) (assert((o)>0), o + ZSTD_REP_NUM) +#define OFFBASE_IS_OFFSET(o) ((o) > ZSTD_REP_NUM) +#define OFFBASE_IS_REPCODE(o) ( 1 <= (o) && (o) <= ZSTD_REP_NUM) +#define OFFBASE_TO_OFFSET(o) (assert(OFFBASE_IS_OFFSET(o)), (o) - ZSTD_REP_NUM) +#define OFFBASE_TO_REPCODE(o) (assert(OFFBASE_IS_REPCODE(o)), (o)) /* returns ID 1,2,3 */ /*! ZSTD_storeSeq() : - * Store a sequence (litlen, litPtr, offCode and matchLength) into seqStore_t. - * @offBase_minus1 : Users should use employ macros STORE_REPCODE_X and STORE_OFFSET(). + * Store a sequence (litlen, litPtr, offBase and matchLength) into seqStore_t. + * @offBase : Users should employ macros REPCODE_TO_OFFBASE() and OFFSET_TO_OFFBASE(). * @matchLength : must be >= MINMATCH - * Allowed to overread literals up to litLimit. + * Allowed to over-read literals up to litLimit. */ HINT_INLINE UNUSED_ATTR void ZSTD_storeSeq(seqStore_t* seqStorePtr, size_t litLength, const BYTE* literals, const BYTE* litLimit, - U32 offBase_minus1, + U32 offBase, size_t matchLength) { BYTE const* const litLimit_w = litLimit - WILDCOPY_OVERLENGTH; @@ -596,8 +646,8 @@ ZSTD_storeSeq(seqStore_t* seqStorePtr, static const BYTE* g_start = NULL; if (g_start==NULL) g_start = (const BYTE*)literals; /* note : index only works for compression within a single segment */ { U32 const pos = (U32)((const BYTE*)literals - g_start); - DEBUGLOG(6, "Cpos%7u :%3u literals, match%4u bytes at offCode%7u", - pos, (U32)litLength, (U32)matchLength, (U32)offBase_minus1); + DEBUGLOG(6, "Cpos%7u :%3u literals, match%4u bytes at offBase%7u", + pos, (U32)litLength, (U32)matchLength, (U32)offBase); } #endif assert((size_t)(seqStorePtr->sequences - seqStorePtr->sequencesStart) < seqStorePtr->maxNbSeq); @@ -607,9 +657,9 @@ ZSTD_storeSeq(seqStore_t* seqStorePtr, assert(literals + litLength <= litLimit); if (litEnd <= litLimit_w) { /* Common case we can use wildcopy. - * First copy 16 bytes, because literals are likely short. - */ - assert(WILDCOPY_OVERLENGTH >= 16); + * First copy 16 bytes, because literals are likely short. + */ + ZSTD_STATIC_ASSERT(WILDCOPY_OVERLENGTH >= 16); ZSTD_copy16(seqStorePtr->lit, literals); if (litLength > 16) { ZSTD_wildcopy(seqStorePtr->lit+16, literals+16, (ptrdiff_t)litLength-16, ZSTD_no_overlap); @@ -628,7 +678,7 @@ ZSTD_storeSeq(seqStore_t* seqStorePtr, seqStorePtr->sequences[0].litLength = (U16)litLength; /* match offset */ - seqStorePtr->sequences[0].offBase = STORED_TO_OFFBASE(offBase_minus1); + seqStorePtr->sequences[0].offBase = offBase; /* match Length */ assert(matchLength >= MINMATCH); @@ -646,17 +696,17 @@ ZSTD_storeSeq(seqStore_t* seqStorePtr, /* ZSTD_updateRep() : * updates in-place @rep (array of repeat offsets) - * @offBase_minus1 : sum-type, with same numeric representation as ZSTD_storeSeq() + * @offBase : sum-type, using numeric representation of ZSTD_storeSeq() */ MEM_STATIC void -ZSTD_updateRep(U32 rep[ZSTD_REP_NUM], U32 const offBase_minus1, U32 const ll0) +ZSTD_updateRep(U32 rep[ZSTD_REP_NUM], U32 const offBase, U32 const ll0) { - if (STORED_IS_OFFSET(offBase_minus1)) { /* full offset */ + if (OFFBASE_IS_OFFSET(offBase)) { /* full offset */ rep[2] = rep[1]; rep[1] = rep[0]; - rep[0] = STORED_OFFSET(offBase_minus1); + rep[0] = OFFBASE_TO_OFFSET(offBase); } else { /* repcode */ - U32 const repCode = STORED_REPCODE(offBase_minus1) - 1 + ll0; + U32 const repCode = OFFBASE_TO_REPCODE(offBase) - 1 + ll0; if (repCode > 0) { /* note : if repCode==0, no change */ U32 const currentOffset = (repCode==ZSTD_REP_NUM) ? (rep[0] - 1) : rep[repCode]; rep[2] = (repCode >= 2) ? rep[1] : rep[2]; @@ -673,11 +723,11 @@ typedef struct repcodes_s { } repcodes_t; MEM_STATIC repcodes_t -ZSTD_newRep(U32 const rep[ZSTD_REP_NUM], U32 const offBase_minus1, U32 const ll0) +ZSTD_newRep(U32 const rep[ZSTD_REP_NUM], U32 const offBase, U32 const ll0) { repcodes_t newReps; ZSTD_memcpy(&newReps, rep, sizeof(newReps)); - ZSTD_updateRep(newReps.rep, offBase_minus1, ll0); + ZSTD_updateRep(newReps.rep, offBase, ll0); return newReps; } @@ -685,59 +735,6 @@ ZSTD_newRep(U32 const rep[ZSTD_REP_NUM], U32 const offBase_minus1, U32 const ll0 /*-************************************* * Match length counter ***************************************/ -static unsigned ZSTD_NbCommonBytes (size_t val) -{ - if (MEM_isLittleEndian()) { - if (MEM_64bits()) { -# if (__GNUC__ >= 4) - return (__builtin_ctzll((U64)val) >> 3); -# else - static const int DeBruijnBytePos[64] = { 0, 0, 0, 0, 0, 1, 1, 2, - 0, 3, 1, 3, 1, 4, 2, 7, - 0, 2, 3, 6, 1, 5, 3, 5, - 1, 3, 4, 4, 2, 5, 6, 7, - 7, 0, 1, 2, 3, 3, 4, 6, - 2, 6, 5, 5, 3, 4, 5, 6, - 7, 1, 2, 4, 6, 4, 4, 5, - 7, 2, 6, 5, 7, 6, 7, 7 }; - return DeBruijnBytePos[((U64)((val & -(long long)val) * 0x0218A392CDABBD3FULL)) >> 58]; -# endif - } else { /* 32 bits */ -# if (__GNUC__ >= 3) - return (__builtin_ctz((U32)val) >> 3); -# else - static const int DeBruijnBytePos[32] = { 0, 0, 3, 0, 3, 1, 3, 0, - 3, 2, 2, 1, 3, 2, 0, 1, - 3, 3, 1, 2, 2, 2, 2, 0, - 3, 1, 2, 0, 1, 0, 1, 1 }; - return DeBruijnBytePos[((U32)((val & -(S32)val) * 0x077CB531U)) >> 27]; -# endif - } - } else { /* Big Endian CPU */ - if (MEM_64bits()) { -# if (__GNUC__ >= 4) - return (__builtin_clzll(val) >> 3); -# else - unsigned r; - const unsigned n32 = sizeof(size_t)*4; /* calculate this way due to compiler complaining in 32-bits mode */ - if (!(val>>n32)) { r=4; } else { r=0; val>>=n32; } - if (!(val>>16)) { r+=2; val>>=8; } else { val>>=24; } - r += (!val); - return r; -# endif - } else { /* 32 bits */ -# if (__GNUC__ >= 3) - return (__builtin_clz((U32)val) >> 3); -# else - unsigned r; - if (!(val>>16)) { r=2; val>>=8; } else { r=0; val>>=24; } - r += (!val); - return r; -# endif - } } -} - - MEM_STATIC size_t ZSTD_count(const BYTE* pIn, const BYTE* pMatch, const BYTE* const pInLimit) { const BYTE* const pStart = pIn; @@ -783,32 +780,43 @@ ZSTD_count_2segments(const BYTE* ip, const BYTE* match, * Hashes ***************************************/ static const U32 prime3bytes = 506832829U; -static U32 ZSTD_hash3(U32 u, U32 h) { return ((u << (32-24)) * prime3bytes) >> (32-h) ; } -MEM_STATIC size_t ZSTD_hash3Ptr(const void* ptr, U32 h) { return ZSTD_hash3(MEM_readLE32(ptr), h); } /* only in zstd_opt.h */ +static U32 ZSTD_hash3(U32 u, U32 h, U32 s) { assert(h <= 32); return (((u << (32-24)) * prime3bytes) ^ s) >> (32-h) ; } +MEM_STATIC size_t ZSTD_hash3Ptr(const void* ptr, U32 h) { return ZSTD_hash3(MEM_readLE32(ptr), h, 0); } /* only in zstd_opt.h */ +MEM_STATIC size_t ZSTD_hash3PtrS(const void* ptr, U32 h, U32 s) { return ZSTD_hash3(MEM_readLE32(ptr), h, s); } static const U32 prime4bytes = 2654435761U; -static U32 ZSTD_hash4(U32 u, U32 h) { return (u * prime4bytes) >> (32-h) ; } -static size_t ZSTD_hash4Ptr(const void* ptr, U32 h) { return ZSTD_hash4(MEM_read32(ptr), h); } +static U32 ZSTD_hash4(U32 u, U32 h, U32 s) { assert(h <= 32); return ((u * prime4bytes) ^ s) >> (32-h) ; } +static size_t ZSTD_hash4Ptr(const void* ptr, U32 h) { return ZSTD_hash4(MEM_readLE32(ptr), h, 0); } +static size_t ZSTD_hash4PtrS(const void* ptr, U32 h, U32 s) { return ZSTD_hash4(MEM_readLE32(ptr), h, s); } static const U64 prime5bytes = 889523592379ULL; -static size_t ZSTD_hash5(U64 u, U32 h) { return (size_t)(((u << (64-40)) * prime5bytes) >> (64-h)) ; } -static size_t ZSTD_hash5Ptr(const void* p, U32 h) { return ZSTD_hash5(MEM_readLE64(p), h); } +static size_t ZSTD_hash5(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u << (64-40)) * prime5bytes) ^ s) >> (64-h)) ; } +static size_t ZSTD_hash5Ptr(const void* p, U32 h) { return ZSTD_hash5(MEM_readLE64(p), h, 0); } +static size_t ZSTD_hash5PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash5(MEM_readLE64(p), h, s); } static const U64 prime6bytes = 227718039650203ULL; -static size_t ZSTD_hash6(U64 u, U32 h) { return (size_t)(((u << (64-48)) * prime6bytes) >> (64-h)) ; } -static size_t ZSTD_hash6Ptr(const void* p, U32 h) { return ZSTD_hash6(MEM_readLE64(p), h); } +static size_t ZSTD_hash6(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u << (64-48)) * prime6bytes) ^ s) >> (64-h)) ; } +static size_t ZSTD_hash6Ptr(const void* p, U32 h) { return ZSTD_hash6(MEM_readLE64(p), h, 0); } +static size_t ZSTD_hash6PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash6(MEM_readLE64(p), h, s); } static const U64 prime7bytes = 58295818150454627ULL; -static size_t ZSTD_hash7(U64 u, U32 h) { return (size_t)(((u << (64-56)) * prime7bytes) >> (64-h)) ; } -static size_t ZSTD_hash7Ptr(const void* p, U32 h) { return ZSTD_hash7(MEM_readLE64(p), h); } +static size_t ZSTD_hash7(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u << (64-56)) * prime7bytes) ^ s) >> (64-h)) ; } +static size_t ZSTD_hash7Ptr(const void* p, U32 h) { return ZSTD_hash7(MEM_readLE64(p), h, 0); } +static size_t ZSTD_hash7PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash7(MEM_readLE64(p), h, s); } static const U64 prime8bytes = 0xCF1BBCDCB7A56463ULL; -static size_t ZSTD_hash8(U64 u, U32 h) { return (size_t)(((u) * prime8bytes) >> (64-h)) ; } -static size_t ZSTD_hash8Ptr(const void* p, U32 h) { return ZSTD_hash8(MEM_readLE64(p), h); } +static size_t ZSTD_hash8(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u) * prime8bytes) ^ s) >> (64-h)) ; } +static size_t ZSTD_hash8Ptr(const void* p, U32 h) { return ZSTD_hash8(MEM_readLE64(p), h, 0); } +static size_t ZSTD_hash8PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash8(MEM_readLE64(p), h, s); } + MEM_STATIC FORCE_INLINE_ATTR size_t ZSTD_hashPtr(const void* p, U32 hBits, U32 mls) { + /* Although some of these hashes do support hBits up to 64, some do not. + * To be on the safe side, always avoid hBits > 32. */ + assert(hBits <= 32); + switch(mls) { default: @@ -820,6 +828,24 @@ size_t ZSTD_hashPtr(const void* p, U32 hBits, U32 mls) } } +MEM_STATIC FORCE_INLINE_ATTR +size_t ZSTD_hashPtrSalted(const void* p, U32 hBits, U32 mls, const U64 hashSalt) { + /* Although some of these hashes do support hBits up to 64, some do not. + * To be on the safe side, always avoid hBits > 32. */ + assert(hBits <= 32); + + switch(mls) + { + default: + case 4: return ZSTD_hash4PtrS(p, hBits, (U32)hashSalt); + case 5: return ZSTD_hash5PtrS(p, hBits, hashSalt); + case 6: return ZSTD_hash6PtrS(p, hBits, hashSalt); + case 7: return ZSTD_hash7PtrS(p, hBits, hashSalt); + case 8: return ZSTD_hash8PtrS(p, hBits, hashSalt); + } +} + + /* ZSTD_ipow() : * Return base^exponent. */ @@ -1011,7 +1037,9 @@ MEM_STATIC U32 ZSTD_window_needOverflowCorrection(ZSTD_window_t const window, * The least significant cycleLog bits of the indices must remain the same, * which may be 0. Every index up to maxDist in the past must be valid. */ -MEM_STATIC U32 ZSTD_window_correctOverflow(ZSTD_window_t* window, U32 cycleLog, +MEM_STATIC +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_window_correctOverflow(ZSTD_window_t* window, U32 cycleLog, U32 maxDist, void const* src) { /* preemptive overflow correction: @@ -1167,10 +1195,15 @@ ZSTD_checkDictValidity(const ZSTD_window_t* window, (unsigned)blockEndIdx, (unsigned)maxDist, (unsigned)loadedDictEnd); assert(blockEndIdx >= loadedDictEnd); - if (blockEndIdx > loadedDictEnd + maxDist) { + if (blockEndIdx > loadedDictEnd + maxDist || loadedDictEnd != window->dictLimit) { /* On reaching window size, dictionaries are invalidated. * For simplification, if window size is reached anywhere within next block, * the dictionary is invalidated for the full block. + * + * We also have to invalidate the dictionary if ZSTD_window_update() has detected + * non-contiguous segments, which means that loadedDictEnd != window->dictLimit. + * loadedDictEnd may be 0, if forceWindow is true, but in that case we never use + * dictMatchState, so setting it to NULL is not a problem. */ DEBUGLOG(6, "invalidating dictionary for current block (distance > windowSize)"); *loadedDictEndPtr = 0; @@ -1199,7 +1232,9 @@ MEM_STATIC void ZSTD_window_init(ZSTD_window_t* window) { * forget about the extDict. Handles overlap of the prefix and extDict. * Returns non-zero if the segment is contiguous. */ -MEM_STATIC U32 ZSTD_window_update(ZSTD_window_t* window, +MEM_STATIC +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_window_update(ZSTD_window_t* window, void const* src, size_t srcSize, int forceNonContiguous) { @@ -1302,6 +1337,42 @@ MEM_STATIC void ZSTD_debugTable(const U32* table, U32 max) #endif +/* Short Cache */ + +/* Normally, zstd matchfinders follow this flow: + * 1. Compute hash at ip + * 2. Load index from hashTable[hash] + * 3. Check if *ip == *(base + index) + * In dictionary compression, loading *(base + index) is often an L2 or even L3 miss. + * + * Short cache is an optimization which allows us to avoid step 3 most of the time + * when the data doesn't actually match. With short cache, the flow becomes: + * 1. Compute (hash, currentTag) at ip. currentTag is an 8-bit independent hash at ip. + * 2. Load (index, matchTag) from hashTable[hash]. See ZSTD_writeTaggedIndex to understand how this works. + * 3. Only if currentTag == matchTag, check *ip == *(base + index). Otherwise, continue. + * + * Currently, short cache is only implemented in CDict hashtables. Thus, its use is limited to + * dictMatchState matchfinders. + */ +#define ZSTD_SHORT_CACHE_TAG_BITS 8 +#define ZSTD_SHORT_CACHE_TAG_MASK ((1u << ZSTD_SHORT_CACHE_TAG_BITS) - 1) + +/* Helper function for ZSTD_fillHashTable and ZSTD_fillDoubleHashTable. + * Unpacks hashAndTag into (hash, tag), then packs (index, tag) into hashTable[hash]. */ +MEM_STATIC void ZSTD_writeTaggedIndex(U32* const hashTable, size_t hashAndTag, U32 index) { + size_t const hash = hashAndTag >> ZSTD_SHORT_CACHE_TAG_BITS; + U32 const tag = (U32)(hashAndTag & ZSTD_SHORT_CACHE_TAG_MASK); + assert(index >> (32 - ZSTD_SHORT_CACHE_TAG_BITS) == 0); + hashTable[hash] = (index << ZSTD_SHORT_CACHE_TAG_BITS) | tag; +} + +/* Helper function for short cache matchfinders. + * Unpacks tag1 and tag2 from lower bits of packedTag1 and packedTag2, then checks if the tags match. */ +MEM_STATIC int ZSTD_comparePackedTags(size_t packedTag1, size_t packedTag2) { + U32 const tag1 = packedTag1 & ZSTD_SHORT_CACHE_TAG_MASK; + U32 const tag2 = packedTag2 & ZSTD_SHORT_CACHE_TAG_MASK; + return tag1 == tag2; +} /* =============================================================== @@ -1381,11 +1452,10 @@ size_t ZSTD_writeLastEmptyBlock(void* dst, size_t dstCapacity); * This cannot be used when long range matching is enabled. * Zstd will use these sequences, and pass the literals to a secondary block * compressor. - * @return : An error code on failure. * NOTE: seqs are not verified! Invalid sequences can cause out-of-bounds memory * access and data corruption. */ -size_t ZSTD_referenceExternalSequences(ZSTD_CCtx* cctx, rawSeq* seq, size_t nbSeq); +void ZSTD_referenceExternalSequences(ZSTD_CCtx* cctx, rawSeq* seq, size_t nbSeq); /* ZSTD_cycleLog() : * condition for correct operation : hashLog > 1 */ @@ -1396,4 +1466,55 @@ U32 ZSTD_cycleLog(U32 hashLog, ZSTD_strategy strat); */ void ZSTD_CCtx_trace(ZSTD_CCtx* cctx, size_t extraCSize); +/* Returns 0 on success, and a ZSTD_error otherwise. This function scans through an array of + * ZSTD_Sequence, storing the sequences it finds, until it reaches a block delimiter. + * Note that the block delimiter must include the last literals of the block. + */ +size_t +ZSTD_copySequencesToSeqStoreExplicitBlockDelim(ZSTD_CCtx* cctx, + ZSTD_sequencePosition* seqPos, + const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, + const void* src, size_t blockSize, ZSTD_paramSwitch_e externalRepSearch); + +/* Returns the number of bytes to move the current read position back by. + * Only non-zero if we ended up splitting a sequence. + * Otherwise, it may return a ZSTD error if something went wrong. + * + * This function will attempt to scan through blockSize bytes + * represented by the sequences in @inSeqs, + * storing any (partial) sequences. + * + * Occasionally, we may want to change the actual number of bytes we consumed from inSeqs to + * avoid splitting a match, or to avoid splitting a match such that it would produce a match + * smaller than MINMATCH. In this case, we return the number of bytes that we didn't read from this block. + */ +size_t +ZSTD_copySequencesToSeqStoreNoBlockDelim(ZSTD_CCtx* cctx, ZSTD_sequencePosition* seqPos, + const ZSTD_Sequence* const inSeqs, size_t inSeqsSize, + const void* src, size_t blockSize, ZSTD_paramSwitch_e externalRepSearch); + +/* Returns 1 if an external sequence producer is registered, otherwise returns 0. */ +MEM_STATIC int ZSTD_hasExtSeqProd(const ZSTD_CCtx_params* params) { + return params->extSeqProdFunc != NULL; +} + +/* =============================================================== + * Deprecated definitions that are still used internally to avoid + * deprecation warnings. These functions are exactly equivalent to + * their public variants, but avoid the deprecation warnings. + * =============================================================== */ + +size_t ZSTD_compressBegin_usingCDict_deprecated(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict); + +size_t ZSTD_compressContinue_public(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize); + +size_t ZSTD_compressEnd_public(ZSTD_CCtx* cctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize); + +size_t ZSTD_compressBlock_deprecated(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); + + #endif /* ZSTD_COMPRESS_H */ diff --git a/lib/zstd/compress/zstd_compress_literals.c b/lib/zstd/compress/zstd_compress_literals.c index 52b0a8059aba..3e9ea46a670a 100644 --- a/lib/zstd/compress/zstd_compress_literals.c +++ b/lib/zstd/compress/zstd_compress_literals.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -13,11 +14,36 @@ ***************************************/ #include "zstd_compress_literals.h" + +/* ************************************************************** +* Debug Traces +****************************************************************/ +#if DEBUGLEVEL >= 2 + +static size_t showHexa(const void* src, size_t srcSize) +{ + const BYTE* const ip = (const BYTE*)src; + size_t u; + for (u=0; u31) + (srcSize>4095); + DEBUGLOG(5, "ZSTD_noCompressLiterals: srcSize=%zu, dstCapacity=%zu", srcSize, dstCapacity); + RETURN_ERROR_IF(srcSize + flSize > dstCapacity, dstSize_tooSmall, ""); switch(flSize) @@ -36,16 +62,30 @@ size_t ZSTD_noCompressLiterals (void* dst, size_t dstCapacity, const void* src, } ZSTD_memcpy(ostart + flSize, src, srcSize); - DEBUGLOG(5, "Raw literals: %u -> %u", (U32)srcSize, (U32)(srcSize + flSize)); + DEBUGLOG(5, "Raw (uncompressed) literals: %u -> %u", (U32)srcSize, (U32)(srcSize + flSize)); return srcSize + flSize; } +static int allBytesIdentical(const void* src, size_t srcSize) +{ + assert(srcSize >= 1); + assert(src != NULL); + { const BYTE b = ((const BYTE*)src)[0]; + size_t p; + for (p=1; p31) + (srcSize>4095); - (void)dstCapacity; /* dstCapacity already guaranteed to be >=4, hence large enough */ + assert(dstCapacity >= 4); (void)dstCapacity; + assert(allBytesIdentical(src, srcSize)); switch(flSize) { @@ -63,28 +103,51 @@ size_t ZSTD_compressRleLiteralsBlock (void* dst, size_t dstCapacity, const void* } ostart[flSize] = *(const BYTE*)src; - DEBUGLOG(5, "RLE literals: %u -> %u", (U32)srcSize, (U32)flSize + 1); + DEBUGLOG(5, "RLE : Repeated Literal (%02X: %u times) -> %u bytes encoded", ((const BYTE*)src)[0], (U32)srcSize, (U32)flSize + 1); return flSize+1; } -size_t ZSTD_compressLiterals (ZSTD_hufCTables_t const* prevHuf, - ZSTD_hufCTables_t* nextHuf, - ZSTD_strategy strategy, int disableLiteralCompression, - void* dst, size_t dstCapacity, - const void* src, size_t srcSize, - void* entropyWorkspace, size_t entropyWorkspaceSize, - const int bmi2, - unsigned suspectUncompressible) +/* ZSTD_minLiteralsToCompress() : + * returns minimal amount of literals + * for literal compression to even be attempted. + * Minimum is made tighter as compression strategy increases. + */ +static size_t +ZSTD_minLiteralsToCompress(ZSTD_strategy strategy, HUF_repeat huf_repeat) +{ + assert((int)strategy >= 0); + assert((int)strategy <= 9); + /* btultra2 : min 8 bytes; + * then 2x larger for each successive compression strategy + * max threshold 64 bytes */ + { int const shift = MIN(9-(int)strategy, 3); + size_t const mintc = (huf_repeat == HUF_repeat_valid) ? 6 : (size_t)8 << shift; + DEBUGLOG(7, "minLiteralsToCompress = %zu", mintc); + return mintc; + } +} + +size_t ZSTD_compressLiterals ( + void* dst, size_t dstCapacity, + const void* src, size_t srcSize, + void* entropyWorkspace, size_t entropyWorkspaceSize, + const ZSTD_hufCTables_t* prevHuf, + ZSTD_hufCTables_t* nextHuf, + ZSTD_strategy strategy, + int disableLiteralCompression, + int suspectUncompressible, + int bmi2) { - size_t const minGain = ZSTD_minGain(srcSize, strategy); size_t const lhSize = 3 + (srcSize >= 1 KB) + (srcSize >= 16 KB); BYTE* const ostart = (BYTE*)dst; U32 singleStream = srcSize < 256; symbolEncodingType_e hType = set_compressed; size_t cLitSize; - DEBUGLOG(5,"ZSTD_compressLiterals (disableLiteralCompression=%i srcSize=%u)", - disableLiteralCompression, (U32)srcSize); + DEBUGLOG(5,"ZSTD_compressLiterals (disableLiteralCompression=%i, srcSize=%u, dstCapacity=%zu)", + disableLiteralCompression, (U32)srcSize, dstCapacity); + + DEBUGLOG(6, "Completed literals listing (%zu bytes)", showHexa(src, srcSize)); /* Prepare nextEntropy assuming reusing the existing table */ ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); @@ -92,40 +155,51 @@ size_t ZSTD_compressLiterals (ZSTD_hufCTables_t const* prevHuf, if (disableLiteralCompression) return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); - /* small ? don't even attempt compression (speed opt) */ -# define COMPRESS_LITERALS_SIZE_MIN 63 - { size_t const minLitSize = (prevHuf->repeatMode == HUF_repeat_valid) ? 6 : COMPRESS_LITERALS_SIZE_MIN; - if (srcSize <= minLitSize) return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); - } + /* if too small, don't even attempt compression (speed opt) */ + if (srcSize < ZSTD_minLiteralsToCompress(strategy, prevHuf->repeatMode)) + return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); RETURN_ERROR_IF(dstCapacity < lhSize+1, dstSize_tooSmall, "not enough space for compression"); { HUF_repeat repeat = prevHuf->repeatMode; - int const preferRepeat = strategy < ZSTD_lazy ? srcSize <= 1024 : 0; + int const flags = 0 + | (bmi2 ? HUF_flags_bmi2 : 0) + | (strategy < ZSTD_lazy && srcSize <= 1024 ? HUF_flags_preferRepeat : 0) + | (strategy >= HUF_OPTIMAL_DEPTH_THRESHOLD ? HUF_flags_optimalDepth : 0) + | (suspectUncompressible ? HUF_flags_suspectUncompressible : 0); + + typedef size_t (*huf_compress_f)(void*, size_t, const void*, size_t, unsigned, unsigned, void*, size_t, HUF_CElt*, HUF_repeat*, int); + huf_compress_f huf_compress; if (repeat == HUF_repeat_valid && lhSize == 3) singleStream = 1; - cLitSize = singleStream ? - HUF_compress1X_repeat( - ostart+lhSize, dstCapacity-lhSize, src, srcSize, - HUF_SYMBOLVALUE_MAX, HUF_TABLELOG_DEFAULT, entropyWorkspace, entropyWorkspaceSize, - (HUF_CElt*)nextHuf->CTable, &repeat, preferRepeat, bmi2, suspectUncompressible) : - HUF_compress4X_repeat( - ostart+lhSize, dstCapacity-lhSize, src, srcSize, - HUF_SYMBOLVALUE_MAX, HUF_TABLELOG_DEFAULT, entropyWorkspace, entropyWorkspaceSize, - (HUF_CElt*)nextHuf->CTable, &repeat, preferRepeat, bmi2, suspectUncompressible); + huf_compress = singleStream ? HUF_compress1X_repeat : HUF_compress4X_repeat; + cLitSize = huf_compress(ostart+lhSize, dstCapacity-lhSize, + src, srcSize, + HUF_SYMBOLVALUE_MAX, LitHufLog, + entropyWorkspace, entropyWorkspaceSize, + (HUF_CElt*)nextHuf->CTable, + &repeat, flags); + DEBUGLOG(5, "%zu literals compressed into %zu bytes (before header)", srcSize, cLitSize); if (repeat != HUF_repeat_none) { /* reused the existing table */ - DEBUGLOG(5, "Reusing previous huffman table"); + DEBUGLOG(5, "reusing statistics from previous huffman block"); hType = set_repeat; } } - if ((cLitSize==0) || (cLitSize >= srcSize - minGain) || ERR_isError(cLitSize)) { - ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); - return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); - } + { size_t const minGain = ZSTD_minGain(srcSize, strategy); + if ((cLitSize==0) || (cLitSize >= srcSize - minGain) || ERR_isError(cLitSize)) { + ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); + return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize); + } } if (cLitSize==1) { - ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); - return ZSTD_compressRleLiteralsBlock(dst, dstCapacity, src, srcSize); - } + /* A return value of 1 signals that the alphabet consists of a single symbol. + * However, in some rare circumstances, it could be the compressed size (a single byte). + * For that outcome to have a chance to happen, it's necessary that `srcSize < 8`. + * (it's also necessary to not generate statistics). + * Therefore, in such a case, actively check that all bytes are identical. */ + if ((srcSize >= 8) || allBytesIdentical(src, srcSize)) { + ZSTD_memcpy(nextHuf, prevHuf, sizeof(*prevHuf)); + return ZSTD_compressRleLiteralsBlock(dst, dstCapacity, src, srcSize); + } } if (hType == set_compressed) { /* using a newly constructed table */ @@ -136,16 +210,19 @@ size_t ZSTD_compressLiterals (ZSTD_hufCTables_t const* prevHuf, switch(lhSize) { case 3: /* 2 - 2 - 10 - 10 */ - { U32 const lhc = hType + ((!singleStream) << 2) + ((U32)srcSize<<4) + ((U32)cLitSize<<14); + if (!singleStream) assert(srcSize >= MIN_LITERALS_FOR_4_STREAMS); + { U32 const lhc = hType + ((U32)(!singleStream) << 2) + ((U32)srcSize<<4) + ((U32)cLitSize<<14); MEM_writeLE24(ostart, lhc); break; } case 4: /* 2 - 2 - 14 - 14 */ + assert(srcSize >= MIN_LITERALS_FOR_4_STREAMS); { U32 const lhc = hType + (2 << 2) + ((U32)srcSize<<4) + ((U32)cLitSize<<18); MEM_writeLE32(ostart, lhc); break; } case 5: /* 2 - 2 - 18 - 18 */ + assert(srcSize >= MIN_LITERALS_FOR_4_STREAMS); { U32 const lhc = hType + (3 << 2) + ((U32)srcSize<<4) + ((U32)cLitSize<<22); MEM_writeLE32(ostart, lhc); ostart[4] = (BYTE)(cLitSize >> 10); diff --git a/lib/zstd/compress/zstd_compress_literals.h b/lib/zstd/compress/zstd_compress_literals.h index 9775fb97cb70..a2a85d6b69e5 100644 --- a/lib/zstd/compress/zstd_compress_literals.h +++ b/lib/zstd/compress/zstd_compress_literals.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -16,16 +17,24 @@ size_t ZSTD_noCompressLiterals (void* dst, size_t dstCapacity, const void* src, size_t srcSize); +/* ZSTD_compressRleLiteralsBlock() : + * Conditions : + * - All bytes in @src are identical + * - dstCapacity >= 4 */ size_t ZSTD_compressRleLiteralsBlock (void* dst, size_t dstCapacity, const void* src, size_t srcSize); -/* If suspectUncompressible then some sampling checks will be run to potentially skip huffman coding */ -size_t ZSTD_compressLiterals (ZSTD_hufCTables_t const* prevHuf, - ZSTD_hufCTables_t* nextHuf, - ZSTD_strategy strategy, int disableLiteralCompression, - void* dst, size_t dstCapacity, +/* ZSTD_compressLiterals(): + * @entropyWorkspace: must be aligned on 4-bytes boundaries + * @entropyWorkspaceSize : must be >= HUF_WORKSPACE_SIZE + * @suspectUncompressible: sampling checks, to potentially skip huffman coding + */ +size_t ZSTD_compressLiterals (void* dst, size_t dstCapacity, const void* src, size_t srcSize, void* entropyWorkspace, size_t entropyWorkspaceSize, - const int bmi2, - unsigned suspectUncompressible); + const ZSTD_hufCTables_t* prevHuf, + ZSTD_hufCTables_t* nextHuf, + ZSTD_strategy strategy, int disableLiteralCompression, + int suspectUncompressible, + int bmi2); #endif /* ZSTD_COMPRESS_LITERALS_H */ diff --git a/lib/zstd/compress/zstd_compress_sequences.c b/lib/zstd/compress/zstd_compress_sequences.c index 21ddc1b37acf..5c028c78d889 100644 --- a/lib/zstd/compress/zstd_compress_sequences.c +++ b/lib/zstd/compress/zstd_compress_sequences.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -58,7 +59,7 @@ static unsigned ZSTD_useLowProbCount(size_t const nbSeq) { /* Heuristic: This should cover most blocks <= 16K and * start to fade out after 16K to about 32K depending on - * comprssibility. + * compressibility. */ return nbSeq >= 2048; } @@ -166,7 +167,7 @@ ZSTD_selectEncodingType( if (mostFrequent == nbSeq) { *repeatMode = FSE_repeat_none; if (isDefaultAllowed && nbSeq <= 2) { - /* Prefer set_basic over set_rle when there are 2 or less symbols, + /* Prefer set_basic over set_rle when there are 2 or fewer symbols, * since RLE uses 1 byte, but set_basic uses 5-6 bits per symbol. * If basic encoding isn't possible, always choose RLE. */ diff --git a/lib/zstd/compress/zstd_compress_sequences.h b/lib/zstd/compress/zstd_compress_sequences.h index 7991364c2f71..7fe6f4ff5cf2 100644 --- a/lib/zstd/compress/zstd_compress_sequences.h +++ b/lib/zstd/compress/zstd_compress_sequences.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/compress/zstd_compress_superblock.c b/lib/zstd/compress/zstd_compress_superblock.c index 17d836cc84e8..41f6521b27cd 100644 --- a/lib/zstd/compress/zstd_compress_superblock.c +++ b/lib/zstd/compress/zstd_compress_superblock.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -36,13 +37,14 @@ * If it is set_compressed, first sub-block's literals section will be Treeless_Literals_Block * and the following sub-blocks' literals sections will be Treeless_Literals_Block. * @return : compressed size of literals section of a sub-block - * Or 0 if it unable to compress. + * Or 0 if unable to compress. * Or error code */ -static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, - const ZSTD_hufCTablesMetadata_t* hufMetadata, - const BYTE* literals, size_t litSize, - void* dst, size_t dstSize, - const int bmi2, int writeEntropy, int* entropyWritten) +static size_t +ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, + const ZSTD_hufCTablesMetadata_t* hufMetadata, + const BYTE* literals, size_t litSize, + void* dst, size_t dstSize, + const int bmi2, int writeEntropy, int* entropyWritten) { size_t const header = writeEntropy ? 200 : 0; size_t const lhSize = 3 + (litSize >= (1 KB - header)) + (litSize >= (16 KB - header)); @@ -53,8 +55,6 @@ static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, symbolEncodingType_e hType = writeEntropy ? hufMetadata->hType : set_repeat; size_t cLitSize = 0; - (void)bmi2; /* TODO bmi2... */ - DEBUGLOG(5, "ZSTD_compressSubBlock_literal (litSize=%zu, lhSize=%zu, writeEntropy=%d)", litSize, lhSize, writeEntropy); *entropyWritten = 0; @@ -76,9 +76,9 @@ static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, DEBUGLOG(5, "ZSTD_compressSubBlock_literal (hSize=%zu)", hufMetadata->hufDesSize); } - /* TODO bmi2 */ - { const size_t cSize = singleStream ? HUF_compress1X_usingCTable(op, oend-op, literals, litSize, hufTable) - : HUF_compress4X_usingCTable(op, oend-op, literals, litSize, hufTable); + { int const flags = bmi2 ? HUF_flags_bmi2 : 0; + const size_t cSize = singleStream ? HUF_compress1X_usingCTable(op, (size_t)(oend-op), literals, litSize, hufTable, flags) + : HUF_compress4X_usingCTable(op, (size_t)(oend-op), literals, litSize, hufTable, flags); op += cSize; cLitSize += cSize; if (cSize == 0 || ERR_isError(cSize)) { @@ -103,7 +103,7 @@ static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, switch(lhSize) { case 3: /* 2 - 2 - 10 - 10 */ - { U32 const lhc = hType + ((!singleStream) << 2) + ((U32)litSize<<4) + ((U32)cLitSize<<14); + { U32 const lhc = hType + ((U32)(!singleStream) << 2) + ((U32)litSize<<4) + ((U32)cLitSize<<14); MEM_writeLE24(ostart, lhc); break; } @@ -123,26 +123,30 @@ static size_t ZSTD_compressSubBlock_literal(const HUF_CElt* hufTable, } *entropyWritten = 1; DEBUGLOG(5, "Compressed literals: %u -> %u", (U32)litSize, (U32)(op-ostart)); - return op-ostart; + return (size_t)(op-ostart); } -static size_t ZSTD_seqDecompressedSize(seqStore_t const* seqStore, const seqDef* sequences, size_t nbSeq, size_t litSize, int lastSequence) { - const seqDef* const sstart = sequences; - const seqDef* const send = sequences + nbSeq; - const seqDef* sp = sstart; +static size_t +ZSTD_seqDecompressedSize(seqStore_t const* seqStore, + const seqDef* sequences, size_t nbSeqs, + size_t litSize, int lastSubBlock) +{ size_t matchLengthSum = 0; size_t litLengthSum = 0; - (void)(litLengthSum); /* suppress unused variable warning on some environments */ - while (send-sp > 0) { - ZSTD_sequenceLength const seqLen = ZSTD_getSequenceLength(seqStore, sp); + size_t n; + for (n=0; ncParams.windowLog > STREAM_ACCUMULATOR_MIN; BYTE* const ostart = (BYTE*)dst; @@ -176,14 +181,14 @@ static size_t ZSTD_compressSubBlock_sequences(const ZSTD_fseCTables_t* fseTables /* Sequences Header */ RETURN_ERROR_IF((oend-op) < 3 /*max nbSeq Size*/ + 1 /*seqHead*/, dstSize_tooSmall, ""); - if (nbSeq < 0x7F) + if (nbSeq < 128) *op++ = (BYTE)nbSeq; else if (nbSeq < LONGNBSEQ) op[0] = (BYTE)((nbSeq>>8) + 0x80), op[1] = (BYTE)nbSeq, op+=2; else op[0]=0xFF, MEM_writeLE16(op+1, (U16)(nbSeq - LONGNBSEQ)), op+=3; if (nbSeq==0) { - return op - ostart; + return (size_t)(op - ostart); } /* seqHead : flags for FSE encoding type */ @@ -205,7 +210,7 @@ static size_t ZSTD_compressSubBlock_sequences(const ZSTD_fseCTables_t* fseTables } { size_t const bitstreamSize = ZSTD_encodeSequences( - op, oend - op, + op, (size_t)(oend - op), fseTables->matchlengthCTable, mlCode, fseTables->offcodeCTable, ofCode, fseTables->litlengthCTable, llCode, @@ -249,7 +254,7 @@ static size_t ZSTD_compressSubBlock_sequences(const ZSTD_fseCTables_t* fseTables #endif *entropyWritten = 1; - return op - ostart; + return (size_t)(op - ostart); } /* ZSTD_compressSubBlock() : @@ -275,7 +280,8 @@ static size_t ZSTD_compressSubBlock(const ZSTD_entropyCTables_t* entropy, litSize, nbSeq, writeLitEntropy, writeSeqEntropy, lastBlock); { size_t cLitSize = ZSTD_compressSubBlock_literal((const HUF_CElt*)entropy->huf.CTable, &entropyMetadata->hufMetadata, literals, litSize, - op, oend-op, bmi2, writeLitEntropy, litEntropyWritten); + op, (size_t)(oend-op), + bmi2, writeLitEntropy, litEntropyWritten); FORWARD_IF_ERROR(cLitSize, "ZSTD_compressSubBlock_literal failed"); if (cLitSize == 0) return 0; op += cLitSize; @@ -285,18 +291,18 @@ static size_t ZSTD_compressSubBlock(const ZSTD_entropyCTables_t* entropy, sequences, nbSeq, llCode, mlCode, ofCode, cctxParams, - op, oend-op, + op, (size_t)(oend-op), bmi2, writeSeqEntropy, seqEntropyWritten); FORWARD_IF_ERROR(cSeqSize, "ZSTD_compressSubBlock_sequences failed"); if (cSeqSize == 0) return 0; op += cSeqSize; } /* Write block header */ - { size_t cSize = (op-ostart)-ZSTD_blockHeaderSize; + { size_t cSize = (size_t)(op-ostart) - ZSTD_blockHeaderSize; U32 const cBlockHeader24 = lastBlock + (((U32)bt_compressed)<<1) + (U32)(cSize << 3); MEM_writeLE24(ostart, cBlockHeader24); } - return op-ostart; + return (size_t)(op-ostart); } static size_t ZSTD_estimateSubBlockSize_literal(const BYTE* literals, size_t litSize, @@ -385,7 +391,11 @@ static size_t ZSTD_estimateSubBlockSize_sequences(const BYTE* ofCodeTable, return cSeqSizeEstimate + sequencesSectionHeaderSize; } -static size_t ZSTD_estimateSubBlockSize(const BYTE* literals, size_t litSize, +typedef struct { + size_t estLitSize; + size_t estBlockSize; +} EstimatedBlockSize; +static EstimatedBlockSize ZSTD_estimateSubBlockSize(const BYTE* literals, size_t litSize, const BYTE* ofCodeTable, const BYTE* llCodeTable, const BYTE* mlCodeTable, @@ -393,15 +403,17 @@ static size_t ZSTD_estimateSubBlockSize(const BYTE* literals, size_t litSize, const ZSTD_entropyCTables_t* entropy, const ZSTD_entropyCTablesMetadata_t* entropyMetadata, void* workspace, size_t wkspSize, - int writeLitEntropy, int writeSeqEntropy) { - size_t cSizeEstimate = 0; - cSizeEstimate += ZSTD_estimateSubBlockSize_literal(literals, litSize, - &entropy->huf, &entropyMetadata->hufMetadata, - workspace, wkspSize, writeLitEntropy); - cSizeEstimate += ZSTD_estimateSubBlockSize_sequences(ofCodeTable, llCodeTable, mlCodeTable, + int writeLitEntropy, int writeSeqEntropy) +{ + EstimatedBlockSize ebs; + ebs.estLitSize = ZSTD_estimateSubBlockSize_literal(literals, litSize, + &entropy->huf, &entropyMetadata->hufMetadata, + workspace, wkspSize, writeLitEntropy); + ebs.estBlockSize = ZSTD_estimateSubBlockSize_sequences(ofCodeTable, llCodeTable, mlCodeTable, nbSeq, &entropy->fse, &entropyMetadata->fseMetadata, workspace, wkspSize, writeSeqEntropy); - return cSizeEstimate + ZSTD_blockHeaderSize; + ebs.estBlockSize += ebs.estLitSize + ZSTD_blockHeaderSize; + return ebs; } static int ZSTD_needSequenceEntropyTables(ZSTD_fseCTablesMetadata_t const* fseMetadata) @@ -415,13 +427,56 @@ static int ZSTD_needSequenceEntropyTables(ZSTD_fseCTablesMetadata_t const* fseMe return 0; } +static size_t countLiterals(seqStore_t const* seqStore, const seqDef* sp, size_t seqCount) +{ + size_t n, total = 0; + assert(sp != NULL); + for (n=0; n %zu bytes", seqCount, (const void*)sp, total); + return total; +} + +#define BYTESCALE 256 + +static size_t sizeBlockSequences(const seqDef* sp, size_t nbSeqs, + size_t targetBudget, size_t avgLitCost, size_t avgSeqCost, + int firstSubBlock) +{ + size_t n, budget = 0, inSize=0; + /* entropy headers */ + size_t const headerSize = (size_t)firstSubBlock * 120 * BYTESCALE; /* generous estimate */ + assert(firstSubBlock==0 || firstSubBlock==1); + budget += headerSize; + + /* first sequence => at least one sequence*/ + budget += sp[0].litLength * avgLitCost + avgSeqCost; + if (budget > targetBudget) return 1; + inSize = sp[0].litLength + (sp[0].mlBase+MINMATCH); + + /* loop over sequences */ + for (n=1; n targetBudget) + /* though continue to expand until the sub-block is deemed compressible */ + && (budget < inSize * BYTESCALE) ) + break; + } + + return n; +} + /* ZSTD_compressSubBlock_multi() : * Breaks super-block into multiple sub-blocks and compresses them. - * Entropy will be written to the first block. - * The following blocks will use repeat mode to compress. - * All sub-blocks are compressed blocks (no raw or rle blocks). - * @return : compressed size of the super block (which is multiple ZSTD blocks) - * Or 0 if it failed to compress. */ + * Entropy will be written into the first block. + * The following blocks use repeat_mode to compress. + * Sub-blocks are all compressed, except the last one when beneficial. + * @return : compressed size of the super block (which features multiple ZSTD blocks) + * or 0 if it failed to compress. */ static size_t ZSTD_compressSubBlock_multi(const seqStore_t* seqStorePtr, const ZSTD_compressedBlockState_t* prevCBlock, ZSTD_compressedBlockState_t* nextCBlock, @@ -434,10 +489,12 @@ static size_t ZSTD_compressSubBlock_multi(const seqStore_t* seqStorePtr, { const seqDef* const sstart = seqStorePtr->sequencesStart; const seqDef* const send = seqStorePtr->sequences; - const seqDef* sp = sstart; + const seqDef* sp = sstart; /* tracks progresses within seqStorePtr->sequences */ + size_t const nbSeqs = (size_t)(send - sstart); const BYTE* const lstart = seqStorePtr->litStart; const BYTE* const lend = seqStorePtr->lit; const BYTE* lp = lstart; + size_t const nbLiterals = (size_t)(lend - lstart); BYTE const* ip = (BYTE const*)src; BYTE const* const iend = ip + srcSize; BYTE* const ostart = (BYTE*)dst; @@ -446,112 +503,171 @@ static size_t ZSTD_compressSubBlock_multi(const seqStore_t* seqStorePtr, const BYTE* llCodePtr = seqStorePtr->llCode; const BYTE* mlCodePtr = seqStorePtr->mlCode; const BYTE* ofCodePtr = seqStorePtr->ofCode; - size_t targetCBlockSize = cctxParams->targetCBlockSize; - size_t litSize, seqCount; - int writeLitEntropy = entropyMetadata->hufMetadata.hType == set_compressed; + size_t const minTarget = ZSTD_TARGETCBLOCKSIZE_MIN; /* enforce minimum size, to reduce undesirable side effects */ + size_t const targetCBlockSize = MAX(minTarget, cctxParams->targetCBlockSize); + int writeLitEntropy = (entropyMetadata->hufMetadata.hType == set_compressed); int writeSeqEntropy = 1; - int lastSequence = 0; - - DEBUGLOG(5, "ZSTD_compressSubBlock_multi (litSize=%u, nbSeq=%u)", - (unsigned)(lend-lp), (unsigned)(send-sstart)); - - litSize = 0; - seqCount = 0; - do { - size_t cBlockSizeEstimate = 0; - if (sstart == send) { - lastSequence = 1; - } else { - const seqDef* const sequence = sp + seqCount; - lastSequence = sequence == send - 1; - litSize += ZSTD_getSequenceLength(seqStorePtr, sequence).litLength; - seqCount++; - } - if (lastSequence) { - assert(lp <= lend); - assert(litSize <= (size_t)(lend - lp)); - litSize = (size_t)(lend - lp); + + DEBUGLOG(5, "ZSTD_compressSubBlock_multi (srcSize=%u, litSize=%u, nbSeq=%u)", + (unsigned)srcSize, (unsigned)(lend-lstart), (unsigned)(send-sstart)); + + /* let's start by a general estimation for the full block */ + if (nbSeqs > 0) { + EstimatedBlockSize const ebs = + ZSTD_estimateSubBlockSize(lp, nbLiterals, + ofCodePtr, llCodePtr, mlCodePtr, nbSeqs, + &nextCBlock->entropy, entropyMetadata, + workspace, wkspSize, + writeLitEntropy, writeSeqEntropy); + /* quick estimation */ + size_t const avgLitCost = nbLiterals ? (ebs.estLitSize * BYTESCALE) / nbLiterals : BYTESCALE; + size_t const avgSeqCost = ((ebs.estBlockSize - ebs.estLitSize) * BYTESCALE) / nbSeqs; + const size_t nbSubBlocks = MAX((ebs.estBlockSize + (targetCBlockSize/2)) / targetCBlockSize, 1); + size_t n, avgBlockBudget, blockBudgetSupp=0; + avgBlockBudget = (ebs.estBlockSize * BYTESCALE) / nbSubBlocks; + DEBUGLOG(5, "estimated fullblock size=%u bytes ; avgLitCost=%.2f ; avgSeqCost=%.2f ; targetCBlockSize=%u, nbSubBlocks=%u ; avgBlockBudget=%.0f bytes", + (unsigned)ebs.estBlockSize, (double)avgLitCost/BYTESCALE, (double)avgSeqCost/BYTESCALE, + (unsigned)targetCBlockSize, (unsigned)nbSubBlocks, (double)avgBlockBudget/BYTESCALE); + /* simplification: if estimates states that the full superblock doesn't compress, just bail out immediately + * this will result in the production of a single uncompressed block covering @srcSize.*/ + if (ebs.estBlockSize > srcSize) return 0; + + /* compress and write sub-blocks */ + assert(nbSubBlocks>0); + for (n=0; n < nbSubBlocks-1; n++) { + /* determine nb of sequences for current sub-block + nbLiterals from next sequence */ + size_t const seqCount = sizeBlockSequences(sp, (size_t)(send-sp), + avgBlockBudget + blockBudgetSupp, avgLitCost, avgSeqCost, n==0); + /* if reached last sequence : break to last sub-block (simplification) */ + assert(seqCount <= (size_t)(send-sp)); + if (sp + seqCount == send) break; + assert(seqCount > 0); + /* compress sub-block */ + { int litEntropyWritten = 0; + int seqEntropyWritten = 0; + size_t litSize = countLiterals(seqStorePtr, sp, seqCount); + const size_t decompressedSize = + ZSTD_seqDecompressedSize(seqStorePtr, sp, seqCount, litSize, 0); + size_t const cSize = ZSTD_compressSubBlock(&nextCBlock->entropy, entropyMetadata, + sp, seqCount, + lp, litSize, + llCodePtr, mlCodePtr, ofCodePtr, + cctxParams, + op, (size_t)(oend-op), + bmi2, writeLitEntropy, writeSeqEntropy, + &litEntropyWritten, &seqEntropyWritten, + 0); + FORWARD_IF_ERROR(cSize, "ZSTD_compressSubBlock failed"); + + /* check compressibility, update state components */ + if (cSize > 0 && cSize < decompressedSize) { + DEBUGLOG(5, "Committed sub-block compressing %u bytes => %u bytes", + (unsigned)decompressedSize, (unsigned)cSize); + assert(ip + decompressedSize <= iend); + ip += decompressedSize; + lp += litSize; + op += cSize; + llCodePtr += seqCount; + mlCodePtr += seqCount; + ofCodePtr += seqCount; + /* Entropy only needs to be written once */ + if (litEntropyWritten) { + writeLitEntropy = 0; + } + if (seqEntropyWritten) { + writeSeqEntropy = 0; + } + sp += seqCount; + blockBudgetSupp = 0; + } } + /* otherwise : do not compress yet, coalesce current sub-block with following one */ } - /* I think there is an optimization opportunity here. - * Calling ZSTD_estimateSubBlockSize for every sequence can be wasteful - * since it recalculates estimate from scratch. - * For example, it would recount literal distribution and symbol codes every time. - */ - cBlockSizeEstimate = ZSTD_estimateSubBlockSize(lp, litSize, ofCodePtr, llCodePtr, mlCodePtr, seqCount, - &nextCBlock->entropy, entropyMetadata, - workspace, wkspSize, writeLitEntropy, writeSeqEntropy); - if (cBlockSizeEstimate > targetCBlockSize || lastSequence) { - int litEntropyWritten = 0; - int seqEntropyWritten = 0; - const size_t decompressedSize = ZSTD_seqDecompressedSize(seqStorePtr, sp, seqCount, litSize, lastSequence); - const size_t cSize = ZSTD_compressSubBlock(&nextCBlock->entropy, entropyMetadata, - sp, seqCount, - lp, litSize, - llCodePtr, mlCodePtr, ofCodePtr, - cctxParams, - op, oend-op, - bmi2, writeLitEntropy, writeSeqEntropy, - &litEntropyWritten, &seqEntropyWritten, - lastBlock && lastSequence); - FORWARD_IF_ERROR(cSize, "ZSTD_compressSubBlock failed"); - if (cSize > 0 && cSize < decompressedSize) { - DEBUGLOG(5, "Committed the sub-block"); - assert(ip + decompressedSize <= iend); - ip += decompressedSize; - sp += seqCount; - lp += litSize; - op += cSize; - llCodePtr += seqCount; - mlCodePtr += seqCount; - ofCodePtr += seqCount; - litSize = 0; - seqCount = 0; - /* Entropy only needs to be written once */ - if (litEntropyWritten) { - writeLitEntropy = 0; - } - if (seqEntropyWritten) { - writeSeqEntropy = 0; - } + } /* if (nbSeqs > 0) */ + + /* write last block */ + DEBUGLOG(5, "Generate last sub-block: %u sequences remaining", (unsigned)(send - sp)); + { int litEntropyWritten = 0; + int seqEntropyWritten = 0; + size_t litSize = (size_t)(lend - lp); + size_t seqCount = (size_t)(send - sp); + const size_t decompressedSize = + ZSTD_seqDecompressedSize(seqStorePtr, sp, seqCount, litSize, 1); + size_t const cSize = ZSTD_compressSubBlock(&nextCBlock->entropy, entropyMetadata, + sp, seqCount, + lp, litSize, + llCodePtr, mlCodePtr, ofCodePtr, + cctxParams, + op, (size_t)(oend-op), + bmi2, writeLitEntropy, writeSeqEntropy, + &litEntropyWritten, &seqEntropyWritten, + lastBlock); + FORWARD_IF_ERROR(cSize, "ZSTD_compressSubBlock failed"); + + /* update pointers, the nb of literals borrowed from next sequence must be preserved */ + if (cSize > 0 && cSize < decompressedSize) { + DEBUGLOG(5, "Last sub-block compressed %u bytes => %u bytes", + (unsigned)decompressedSize, (unsigned)cSize); + assert(ip + decompressedSize <= iend); + ip += decompressedSize; + lp += litSize; + op += cSize; + llCodePtr += seqCount; + mlCodePtr += seqCount; + ofCodePtr += seqCount; + /* Entropy only needs to be written once */ + if (litEntropyWritten) { + writeLitEntropy = 0; } + if (seqEntropyWritten) { + writeSeqEntropy = 0; + } + sp += seqCount; } - } while (!lastSequence); + } + + if (writeLitEntropy) { - DEBUGLOG(5, "ZSTD_compressSubBlock_multi has literal entropy tables unwritten"); + DEBUGLOG(5, "Literal entropy tables were never written"); ZSTD_memcpy(&nextCBlock->entropy.huf, &prevCBlock->entropy.huf, sizeof(prevCBlock->entropy.huf)); } if (writeSeqEntropy && ZSTD_needSequenceEntropyTables(&entropyMetadata->fseMetadata)) { /* If we haven't written our entropy tables, then we've violated our contract and * must emit an uncompressed block. */ - DEBUGLOG(5, "ZSTD_compressSubBlock_multi has sequence entropy tables unwritten"); + DEBUGLOG(5, "Sequence entropy tables were never written => cancel, emit an uncompressed block"); return 0; } + if (ip < iend) { - size_t const cSize = ZSTD_noCompressBlock(op, oend - op, ip, iend - ip, lastBlock); - DEBUGLOG(5, "ZSTD_compressSubBlock_multi last sub-block uncompressed, %zu bytes", (size_t)(iend - ip)); + /* some data left : last part of the block sent uncompressed */ + size_t const rSize = (size_t)((iend - ip)); + size_t const cSize = ZSTD_noCompressBlock(op, (size_t)(oend - op), ip, rSize, lastBlock); + DEBUGLOG(5, "Generate last uncompressed sub-block of %u bytes", (unsigned)(rSize)); FORWARD_IF_ERROR(cSize, "ZSTD_noCompressBlock failed"); assert(cSize != 0); op += cSize; /* We have to regenerate the repcodes because we've skipped some sequences */ if (sp < send) { - seqDef const* seq; + const seqDef* seq; repcodes_t rep; ZSTD_memcpy(&rep, prevCBlock->rep, sizeof(rep)); for (seq = sstart; seq < sp; ++seq) { - ZSTD_updateRep(rep.rep, seq->offBase - 1, ZSTD_getSequenceLength(seqStorePtr, seq).litLength == 0); + ZSTD_updateRep(rep.rep, seq->offBase, ZSTD_getSequenceLength(seqStorePtr, seq).litLength == 0); } ZSTD_memcpy(nextCBlock->rep, &rep, sizeof(rep)); } } - DEBUGLOG(5, "ZSTD_compressSubBlock_multi compressed"); - return op-ostart; + + DEBUGLOG(5, "ZSTD_compressSubBlock_multi compressed all subBlocks: total compressed size = %u", + (unsigned)(op-ostart)); + return (size_t)(op-ostart); } size_t ZSTD_compressSuperBlock(ZSTD_CCtx* zc, void* dst, size_t dstCapacity, - void const* src, size_t srcSize, - unsigned lastBlock) { + const void* src, size_t srcSize, + unsigned lastBlock) +{ ZSTD_entropyCTablesMetadata_t entropyMetadata; FORWARD_IF_ERROR(ZSTD_buildBlockEntropyStats(&zc->seqStore, diff --git a/lib/zstd/compress/zstd_compress_superblock.h b/lib/zstd/compress/zstd_compress_superblock.h index 224ece79546e..826bbc9e029b 100644 --- a/lib/zstd/compress/zstd_compress_superblock.h +++ b/lib/zstd/compress/zstd_compress_superblock.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/compress/zstd_cwksp.h b/lib/zstd/compress/zstd_cwksp.h index 349fc923c355..86bc3c2c23c7 100644 --- a/lib/zstd/compress/zstd_cwksp.h +++ b/lib/zstd/compress/zstd_cwksp.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -14,7 +15,9 @@ /*-************************************* * Dependencies ***************************************/ +#include "../common/allocations.h" /* ZSTD_customMalloc, ZSTD_customFree */ #include "../common/zstd_internal.h" +#include "../common/portability_macros.h" /*-************************************* @@ -41,8 +44,9 @@ ***************************************/ typedef enum { ZSTD_cwksp_alloc_objects, - ZSTD_cwksp_alloc_buffers, - ZSTD_cwksp_alloc_aligned + ZSTD_cwksp_alloc_aligned_init_once, + ZSTD_cwksp_alloc_aligned, + ZSTD_cwksp_alloc_buffers } ZSTD_cwksp_alloc_phase_e; /* @@ -95,8 +99,8 @@ typedef enum { * * Workspace Layout: * - * [ ... workspace ... ] - * [objects][tables ... ->] free space [<- ... aligned][<- ... buffers] + * [ ... workspace ... ] + * [objects][tables ->] free space [<- buffers][<- aligned][<- init once] * * The various objects that live in the workspace are divided into the * following categories, and are allocated separately: @@ -120,9 +124,18 @@ typedef enum { * uint32_t arrays, all of whose values are between 0 and (nextSrc - base). * Their sizes depend on the cparams. These tables are 64-byte aligned. * - * - Aligned: these buffers are used for various purposes that require 4 byte - * alignment, but don't require any initialization before they're used. These - * buffers are each aligned to 64 bytes. + * - Init once: these buffers require to be initialized at least once before + * use. They should be used when we want to skip memory initialization + * while not triggering memory checkers (like Valgrind) when reading from + * from this memory without writing to it first. + * These buffers should be used carefully as they might contain data + * from previous compressions. + * Buffers are aligned to 64 bytes. + * + * - Aligned: these buffers don't require any initialization before they're + * used. The user of the buffer should make sure they write into a buffer + * location before reading from it. + * Buffers are aligned to 64 bytes. * * - Buffers: these buffers are used for various purposes that don't require * any alignment or initialization before they're used. This means they can @@ -134,8 +147,9 @@ typedef enum { * correctly packed into the workspace buffer. That order is: * * 1. Objects - * 2. Buffers - * 3. Aligned/Tables + * 2. Init once / Tables + * 3. Aligned / Tables + * 4. Buffers / Tables * * Attempts to reserve objects of different types out of order will fail. */ @@ -147,6 +161,7 @@ typedef struct { void* tableEnd; void* tableValidEnd; void* allocStart; + void* initOnceStart; BYTE allocFailed; int workspaceOversizedDuration; @@ -159,6 +174,7 @@ typedef struct { ***************************************/ MEM_STATIC size_t ZSTD_cwksp_available_space(ZSTD_cwksp* ws); +MEM_STATIC void* ZSTD_cwksp_initialAllocStart(ZSTD_cwksp* ws); MEM_STATIC void ZSTD_cwksp_assert_internal_consistency(ZSTD_cwksp* ws) { (void)ws; @@ -168,6 +184,8 @@ MEM_STATIC void ZSTD_cwksp_assert_internal_consistency(ZSTD_cwksp* ws) { assert(ws->tableEnd <= ws->allocStart); assert(ws->tableValidEnd <= ws->allocStart); assert(ws->allocStart <= ws->workspaceEnd); + assert(ws->initOnceStart <= ZSTD_cwksp_initialAllocStart(ws)); + assert(ws->workspace <= ws->initOnceStart); } /* @@ -210,14 +228,10 @@ MEM_STATIC size_t ZSTD_cwksp_aligned_alloc_size(size_t size) { * for internal purposes (currently only alignment). */ MEM_STATIC size_t ZSTD_cwksp_slack_space_required(void) { - /* For alignment, the wksp will always allocate an additional n_1=[1, 64] bytes - * to align the beginning of tables section, as well as another n_2=[0, 63] bytes - * to align the beginning of the aligned section. - * - * n_1 + n_2 == 64 bytes if the cwksp is freshly allocated, due to tables and - * aligneds being sized in multiples of 64 bytes. + /* For alignment, the wksp will always allocate an additional 2*ZSTD_CWKSP_ALIGNMENT_BYTES + * bytes to align the beginning of tables section and end of buffers; */ - size_t const slackSpace = ZSTD_CWKSP_ALIGNMENT_BYTES; + size_t const slackSpace = ZSTD_CWKSP_ALIGNMENT_BYTES * 2; return slackSpace; } @@ -230,10 +244,18 @@ MEM_STATIC size_t ZSTD_cwksp_bytes_to_align_ptr(void* ptr, const size_t alignByt size_t const alignBytesMask = alignBytes - 1; size_t const bytes = (alignBytes - ((size_t)ptr & (alignBytesMask))) & alignBytesMask; assert((alignBytes & alignBytesMask) == 0); - assert(bytes != ZSTD_CWKSP_ALIGNMENT_BYTES); + assert(bytes < alignBytes); return bytes; } +/* + * Returns the initial value for allocStart which is used to determine the position from + * which we can allocate from the end of the workspace. + */ +MEM_STATIC void* ZSTD_cwksp_initialAllocStart(ZSTD_cwksp* ws) { + return (void*)((size_t)ws->workspaceEnd & ~(ZSTD_CWKSP_ALIGNMENT_BYTES-1)); +} + /* * Internal function. Do not use directly. * Reserves the given number of bytes within the aligned/buffer segment of the wksp, @@ -274,27 +296,16 @@ ZSTD_cwksp_internal_advance_phase(ZSTD_cwksp* ws, ZSTD_cwksp_alloc_phase_e phase { assert(phase >= ws->phase); if (phase > ws->phase) { - /* Going from allocating objects to allocating buffers */ - if (ws->phase < ZSTD_cwksp_alloc_buffers && - phase >= ZSTD_cwksp_alloc_buffers) { + /* Going from allocating objects to allocating initOnce / tables */ + if (ws->phase < ZSTD_cwksp_alloc_aligned_init_once && + phase >= ZSTD_cwksp_alloc_aligned_init_once) { ws->tableValidEnd = ws->objectEnd; - } + ws->initOnceStart = ZSTD_cwksp_initialAllocStart(ws); - /* Going from allocating buffers to allocating aligneds/tables */ - if (ws->phase < ZSTD_cwksp_alloc_aligned && - phase >= ZSTD_cwksp_alloc_aligned) { - { /* Align the start of the "aligned" to 64 bytes. Use [1, 64] bytes. */ - size_t const bytesToAlign = - ZSTD_CWKSP_ALIGNMENT_BYTES - ZSTD_cwksp_bytes_to_align_ptr(ws->allocStart, ZSTD_CWKSP_ALIGNMENT_BYTES); - DEBUGLOG(5, "reserving aligned alignment addtl space: %zu", bytesToAlign); - ZSTD_STATIC_ASSERT((ZSTD_CWKSP_ALIGNMENT_BYTES & (ZSTD_CWKSP_ALIGNMENT_BYTES - 1)) == 0); /* power of 2 */ - RETURN_ERROR_IF(!ZSTD_cwksp_reserve_internal_buffer_space(ws, bytesToAlign), - memory_allocation, "aligned phase - alignment initial allocation failed!"); - } { /* Align the start of the tables to 64 bytes. Use [0, 63] bytes */ - void* const alloc = ws->objectEnd; + void *const alloc = ws->objectEnd; size_t const bytesToAlign = ZSTD_cwksp_bytes_to_align_ptr(alloc, ZSTD_CWKSP_ALIGNMENT_BYTES); - void* const objectEnd = (BYTE*)alloc + bytesToAlign; + void *const objectEnd = (BYTE *) alloc + bytesToAlign; DEBUGLOG(5, "reserving table alignment addtl space: %zu", bytesToAlign); RETURN_ERROR_IF(objectEnd > ws->workspaceEnd, memory_allocation, "table phase - alignment initial allocation failed!"); @@ -302,7 +313,9 @@ ZSTD_cwksp_internal_advance_phase(ZSTD_cwksp* ws, ZSTD_cwksp_alloc_phase_e phase ws->tableEnd = objectEnd; /* table area starts being empty */ if (ws->tableValidEnd < ws->tableEnd) { ws->tableValidEnd = ws->tableEnd; - } } } + } + } + } ws->phase = phase; ZSTD_cwksp_assert_internal_consistency(ws); } @@ -314,7 +327,7 @@ ZSTD_cwksp_internal_advance_phase(ZSTD_cwksp* ws, ZSTD_cwksp_alloc_phase_e phase */ MEM_STATIC int ZSTD_cwksp_owns_buffer(const ZSTD_cwksp* ws, const void* ptr) { - return (ptr != NULL) && (ws->workspace <= ptr) && (ptr <= ws->workspaceEnd); + return (ptr != NULL) && (ws->workspace <= ptr) && (ptr < ws->workspaceEnd); } /* @@ -343,6 +356,33 @@ MEM_STATIC BYTE* ZSTD_cwksp_reserve_buffer(ZSTD_cwksp* ws, size_t bytes) return (BYTE*)ZSTD_cwksp_reserve_internal(ws, bytes, ZSTD_cwksp_alloc_buffers); } +/* + * Reserves and returns memory sized on and aligned on ZSTD_CWKSP_ALIGNMENT_BYTES (64 bytes). + * This memory has been initialized at least once in the past. + * This doesn't mean it has been initialized this time, and it might contain data from previous + * operations. + * The main usage is for algorithms that might need read access into uninitialized memory. + * The algorithm must maintain safety under these conditions and must make sure it doesn't + * leak any of the past data (directly or in side channels). + */ +MEM_STATIC void* ZSTD_cwksp_reserve_aligned_init_once(ZSTD_cwksp* ws, size_t bytes) +{ + size_t const alignedBytes = ZSTD_cwksp_align(bytes, ZSTD_CWKSP_ALIGNMENT_BYTES); + void* ptr = ZSTD_cwksp_reserve_internal(ws, alignedBytes, ZSTD_cwksp_alloc_aligned_init_once); + assert(((size_t)ptr & (ZSTD_CWKSP_ALIGNMENT_BYTES-1))== 0); + if(ptr && ptr < ws->initOnceStart) { + /* We assume the memory following the current allocation is either: + * 1. Not usable as initOnce memory (end of workspace) + * 2. Another initOnce buffer that has been allocated before (and so was previously memset) + * 3. An ASAN redzone, in which case we don't want to write on it + * For these reasons it should be fine to not explicitly zero every byte up to ws->initOnceStart. + * Note that we assume here that MSAN and ASAN cannot run in the same time. */ + ZSTD_memset(ptr, 0, MIN((size_t)((U8*)ws->initOnceStart - (U8*)ptr), alignedBytes)); + ws->initOnceStart = ptr; + } + return ptr; +} + /* * Reserves and returns memory sized on and aligned on ZSTD_CWKSP_ALIGNMENT_BYTES (64 bytes). */ @@ -356,18 +396,22 @@ MEM_STATIC void* ZSTD_cwksp_reserve_aligned(ZSTD_cwksp* ws, size_t bytes) /* * Aligned on 64 bytes. These buffers have the special property that - * their values remain constrained, allowing us to re-use them without + * their values remain constrained, allowing us to reuse them without * memset()-ing them. */ MEM_STATIC void* ZSTD_cwksp_reserve_table(ZSTD_cwksp* ws, size_t bytes) { - const ZSTD_cwksp_alloc_phase_e phase = ZSTD_cwksp_alloc_aligned; + const ZSTD_cwksp_alloc_phase_e phase = ZSTD_cwksp_alloc_aligned_init_once; void* alloc; void* end; void* top; - if (ZSTD_isError(ZSTD_cwksp_internal_advance_phase(ws, phase))) { - return NULL; + /* We can only start allocating tables after we are done reserving space for objects at the + * start of the workspace */ + if(ws->phase < phase) { + if (ZSTD_isError(ZSTD_cwksp_internal_advance_phase(ws, phase))) { + return NULL; + } } alloc = ws->tableEnd; end = (BYTE *)alloc + bytes; @@ -451,7 +495,7 @@ MEM_STATIC void ZSTD_cwksp_clean_tables(ZSTD_cwksp* ws) { assert(ws->tableValidEnd >= ws->objectEnd); assert(ws->tableValidEnd <= ws->allocStart); if (ws->tableValidEnd < ws->tableEnd) { - ZSTD_memset(ws->tableValidEnd, 0, (BYTE*)ws->tableEnd - (BYTE*)ws->tableValidEnd); + ZSTD_memset(ws->tableValidEnd, 0, (size_t)((BYTE*)ws->tableEnd - (BYTE*)ws->tableValidEnd)); } ZSTD_cwksp_mark_tables_clean(ws); } @@ -478,14 +522,23 @@ MEM_STATIC void ZSTD_cwksp_clear(ZSTD_cwksp* ws) { ws->tableEnd = ws->objectEnd; - ws->allocStart = ws->workspaceEnd; + ws->allocStart = ZSTD_cwksp_initialAllocStart(ws); ws->allocFailed = 0; - if (ws->phase > ZSTD_cwksp_alloc_buffers) { - ws->phase = ZSTD_cwksp_alloc_buffers; + if (ws->phase > ZSTD_cwksp_alloc_aligned_init_once) { + ws->phase = ZSTD_cwksp_alloc_aligned_init_once; } ZSTD_cwksp_assert_internal_consistency(ws); } +MEM_STATIC size_t ZSTD_cwksp_sizeof(const ZSTD_cwksp* ws) { + return (size_t)((BYTE*)ws->workspaceEnd - (BYTE*)ws->workspace); +} + +MEM_STATIC size_t ZSTD_cwksp_used(const ZSTD_cwksp* ws) { + return (size_t)((BYTE*)ws->tableEnd - (BYTE*)ws->workspace) + + (size_t)((BYTE*)ws->workspaceEnd - (BYTE*)ws->allocStart); +} + /* * The provided workspace takes ownership of the buffer [start, start+size). * Any existing values in the workspace are ignored (the previously managed @@ -498,6 +551,7 @@ MEM_STATIC void ZSTD_cwksp_init(ZSTD_cwksp* ws, void* start, size_t size, ZSTD_c ws->workspaceEnd = (BYTE*)start + size; ws->objectEnd = ws->workspace; ws->tableValidEnd = ws->objectEnd; + ws->initOnceStart = ZSTD_cwksp_initialAllocStart(ws); ws->phase = ZSTD_cwksp_alloc_objects; ws->isStatic = isStatic; ZSTD_cwksp_clear(ws); @@ -529,15 +583,6 @@ MEM_STATIC void ZSTD_cwksp_move(ZSTD_cwksp* dst, ZSTD_cwksp* src) { ZSTD_memset(src, 0, sizeof(ZSTD_cwksp)); } -MEM_STATIC size_t ZSTD_cwksp_sizeof(const ZSTD_cwksp* ws) { - return (size_t)((BYTE*)ws->workspaceEnd - (BYTE*)ws->workspace); -} - -MEM_STATIC size_t ZSTD_cwksp_used(const ZSTD_cwksp* ws) { - return (size_t)((BYTE*)ws->tableEnd - (BYTE*)ws->workspace) - + (size_t)((BYTE*)ws->workspaceEnd - (BYTE*)ws->allocStart); -} - MEM_STATIC int ZSTD_cwksp_reserve_failed(const ZSTD_cwksp* ws) { return ws->allocFailed; } @@ -550,17 +595,11 @@ MEM_STATIC int ZSTD_cwksp_reserve_failed(const ZSTD_cwksp* ws) { * Returns if the estimated space needed for a wksp is within an acceptable limit of the * actual amount of space used. */ -MEM_STATIC int ZSTD_cwksp_estimated_space_within_bounds(const ZSTD_cwksp* const ws, - size_t const estimatedSpace, int resizedWorkspace) { - if (resizedWorkspace) { - /* Resized/newly allocated wksp should have exact bounds */ - return ZSTD_cwksp_used(ws) == estimatedSpace; - } else { - /* Due to alignment, when reusing a workspace, we can actually consume 63 fewer or more bytes - * than estimatedSpace. See the comments in zstd_cwksp.h for details. - */ - return (ZSTD_cwksp_used(ws) >= estimatedSpace - 63) && (ZSTD_cwksp_used(ws) <= estimatedSpace + 63); - } +MEM_STATIC int ZSTD_cwksp_estimated_space_within_bounds(const ZSTD_cwksp *const ws, size_t const estimatedSpace) { + /* We have an alignment space between objects and tables between tables and buffers, so we can have up to twice + * the alignment bytes difference between estimation and actual usage */ + return (estimatedSpace - ZSTD_cwksp_slack_space_required()) <= ZSTD_cwksp_used(ws) && + ZSTD_cwksp_used(ws) <= estimatedSpace; } diff --git a/lib/zstd/compress/zstd_double_fast.c b/lib/zstd/compress/zstd_double_fast.c index 76933dea2624..5ff54f17d92f 100644 --- a/lib/zstd/compress/zstd_double_fast.c +++ b/lib/zstd/compress/zstd_double_fast.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,8 +12,49 @@ #include "zstd_compress_internal.h" #include "zstd_double_fast.h" +#ifndef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR -void ZSTD_fillDoubleHashTable(ZSTD_matchState_t* ms, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_fillDoubleHashTableForCDict(ZSTD_matchState_t* ms, + void const* end, ZSTD_dictTableLoadMethod_e dtlm) +{ + const ZSTD_compressionParameters* const cParams = &ms->cParams; + U32* const hashLarge = ms->hashTable; + U32 const hBitsL = cParams->hashLog + ZSTD_SHORT_CACHE_TAG_BITS; + U32 const mls = cParams->minMatch; + U32* const hashSmall = ms->chainTable; + U32 const hBitsS = cParams->chainLog + ZSTD_SHORT_CACHE_TAG_BITS; + const BYTE* const base = ms->window.base; + const BYTE* ip = base + ms->nextToUpdate; + const BYTE* const iend = ((const BYTE*)end) - HASH_READ_SIZE; + const U32 fastHashFillStep = 3; + + /* Always insert every fastHashFillStep position into the hash tables. + * Insert the other positions into the large hash table if their entry + * is empty. + */ + for (; ip + fastHashFillStep - 1 <= iend; ip += fastHashFillStep) { + U32 const curr = (U32)(ip - base); + U32 i; + for (i = 0; i < fastHashFillStep; ++i) { + size_t const smHashAndTag = ZSTD_hashPtr(ip + i, hBitsS, mls); + size_t const lgHashAndTag = ZSTD_hashPtr(ip + i, hBitsL, 8); + if (i == 0) { + ZSTD_writeTaggedIndex(hashSmall, smHashAndTag, curr + i); + } + if (i == 0 || hashLarge[lgHashAndTag >> ZSTD_SHORT_CACHE_TAG_BITS] == 0) { + ZSTD_writeTaggedIndex(hashLarge, lgHashAndTag, curr + i); + } + /* Only load extra positions for ZSTD_dtlm_full */ + if (dtlm == ZSTD_dtlm_fast) + break; + } } +} + +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_fillDoubleHashTableForCCtx(ZSTD_matchState_t* ms, void const* end, ZSTD_dictTableLoadMethod_e dtlm) { const ZSTD_compressionParameters* const cParams = &ms->cParams; @@ -43,11 +85,24 @@ void ZSTD_fillDoubleHashTable(ZSTD_matchState_t* ms, /* Only load extra positions for ZSTD_dtlm_full */ if (dtlm == ZSTD_dtlm_fast) break; - } } + } } +} + +void ZSTD_fillDoubleHashTable(ZSTD_matchState_t* ms, + const void* const end, + ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp) +{ + if (tfp == ZSTD_tfp_forCDict) { + ZSTD_fillDoubleHashTableForCDict(ms, end, dtlm); + } else { + ZSTD_fillDoubleHashTableForCCtx(ms, end, dtlm); + } } FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_compressBlock_doubleFast_noDict_generic( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize, U32 const mls /* template */) @@ -67,7 +122,7 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( const BYTE* const iend = istart + srcSize; const BYTE* const ilimit = iend - HASH_READ_SIZE; U32 offset_1=rep[0], offset_2=rep[1]; - U32 offsetSaved = 0; + U32 offsetSaved1 = 0, offsetSaved2 = 0; size_t mLength; U32 offset; @@ -100,8 +155,8 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( U32 const current = (U32)(ip - base); U32 const windowLow = ZSTD_getLowestPrefixIndex(ms, current, cParams->windowLog); U32 const maxRep = current - windowLow; - if (offset_2 > maxRep) offsetSaved = offset_2, offset_2 = 0; - if (offset_1 > maxRep) offsetSaved = offset_1, offset_1 = 0; + if (offset_2 > maxRep) offsetSaved2 = offset_2, offset_2 = 0; + if (offset_1 > maxRep) offsetSaved1 = offset_1, offset_1 = 0; } /* Outer Loop: one iteration per match found and stored */ @@ -131,7 +186,7 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( if ((offset_1 > 0) & (MEM_read32(ip+1-offset_1) == MEM_read32(ip+1))) { mLength = ZSTD_count(ip+1+4, ip+1+4-offset_1, iend) + 4; ip++; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, REPCODE1_TO_OFFBASE, mLength); goto _match_stored; } @@ -175,9 +230,13 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( } while (ip1 <= ilimit); _cleanup: + /* If offset_1 started invalid (offsetSaved1 != 0) and became valid (offset_1 != 0), + * rotate saved offsets. See comment in ZSTD_compressBlock_fast_noDict for more context. */ + offsetSaved2 = ((offsetSaved1 != 0) && (offset_1 != 0)) ? offsetSaved1 : offsetSaved2; + /* save reps for next block */ - rep[0] = offset_1 ? offset_1 : offsetSaved; - rep[1] = offset_2 ? offset_2 : offsetSaved; + rep[0] = offset_1 ? offset_1 : offsetSaved1; + rep[1] = offset_2 ? offset_2 : offsetSaved2; /* Return the last literals size */ return (size_t)(iend - anchor); @@ -217,7 +276,7 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( hashLong[hl1] = (U32)(ip1 - base); } - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); _match_stored: /* match found */ @@ -243,7 +302,7 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( U32 const tmpOff = offset_2; offset_2 = offset_1; offset_1 = tmpOff; /* swap offset_2 <=> offset_1 */ hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = (U32)(ip-base); hashLong[ZSTD_hashPtr(ip, hBitsL, 8)] = (U32)(ip-base); - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, rLength); + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, rLength); ip += rLength; anchor = ip; continue; /* faster when present ... (?) */ @@ -254,6 +313,7 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic( FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize, @@ -275,7 +335,6 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( const BYTE* const iend = istart + srcSize; const BYTE* const ilimit = iend - HASH_READ_SIZE; U32 offset_1=rep[0], offset_2=rep[1]; - U32 offsetSaved = 0; const ZSTD_matchState_t* const dms = ms->dictMatchState; const ZSTD_compressionParameters* const dictCParams = &dms->cParams; @@ -286,8 +345,8 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( const BYTE* const dictStart = dictBase + dictStartIndex; const BYTE* const dictEnd = dms->window.nextSrc; const U32 dictIndexDelta = prefixLowestIndex - (U32)(dictEnd - dictBase); - const U32 dictHBitsL = dictCParams->hashLog; - const U32 dictHBitsS = dictCParams->chainLog; + const U32 dictHBitsL = dictCParams->hashLog + ZSTD_SHORT_CACHE_TAG_BITS; + const U32 dictHBitsS = dictCParams->chainLog + ZSTD_SHORT_CACHE_TAG_BITS; const U32 dictAndPrefixLength = (U32)((ip - prefixLowest) + (dictEnd - dictStart)); DEBUGLOG(5, "ZSTD_compressBlock_doubleFast_dictMatchState_generic"); @@ -295,6 +354,13 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( /* if a dictionary is attached, it must be within window range */ assert(ms->window.dictLimit + (1U << cParams->windowLog) >= endIndex); + if (ms->prefetchCDictTables) { + size_t const hashTableBytes = (((size_t)1) << dictCParams->hashLog) * sizeof(U32); + size_t const chainTableBytes = (((size_t)1) << dictCParams->chainLog) * sizeof(U32); + PREFETCH_AREA(dictHashLong, hashTableBytes); + PREFETCH_AREA(dictHashSmall, chainTableBytes); + } + /* init */ ip += (dictAndPrefixLength == 0); @@ -309,8 +375,12 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( U32 offset; size_t const h2 = ZSTD_hashPtr(ip, hBitsL, 8); size_t const h = ZSTD_hashPtr(ip, hBitsS, mls); - size_t const dictHL = ZSTD_hashPtr(ip, dictHBitsL, 8); - size_t const dictHS = ZSTD_hashPtr(ip, dictHBitsS, mls); + size_t const dictHashAndTagL = ZSTD_hashPtr(ip, dictHBitsL, 8); + size_t const dictHashAndTagS = ZSTD_hashPtr(ip, dictHBitsS, mls); + U32 const dictMatchIndexAndTagL = dictHashLong[dictHashAndTagL >> ZSTD_SHORT_CACHE_TAG_BITS]; + U32 const dictMatchIndexAndTagS = dictHashSmall[dictHashAndTagS >> ZSTD_SHORT_CACHE_TAG_BITS]; + int const dictTagsMatchL = ZSTD_comparePackedTags(dictMatchIndexAndTagL, dictHashAndTagL); + int const dictTagsMatchS = ZSTD_comparePackedTags(dictMatchIndexAndTagS, dictHashAndTagS); U32 const curr = (U32)(ip-base); U32 const matchIndexL = hashLong[h2]; U32 matchIndexS = hashSmall[h]; @@ -328,7 +398,7 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( const BYTE* repMatchEnd = repIndex < prefixLowestIndex ? dictEnd : iend; mLength = ZSTD_count_2segments(ip+1+4, repMatch+4, iend, repMatchEnd, prefixLowest) + 4; ip++; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, REPCODE1_TO_OFFBASE, mLength); goto _match_stored; } @@ -340,9 +410,9 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( while (((ip>anchor) & (matchLong>prefixLowest)) && (ip[-1] == matchLong[-1])) { ip--; matchLong--; mLength++; } /* catch up */ goto _match_found; } - } else { + } else if (dictTagsMatchL) { /* check dictMatchState long match */ - U32 const dictMatchIndexL = dictHashLong[dictHL]; + U32 const dictMatchIndexL = dictMatchIndexAndTagL >> ZSTD_SHORT_CACHE_TAG_BITS; const BYTE* dictMatchL = dictBase + dictMatchIndexL; assert(dictMatchL < dictEnd); @@ -358,9 +428,9 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( if (MEM_read32(match) == MEM_read32(ip)) { goto _search_next_long; } - } else { + } else if (dictTagsMatchS) { /* check dictMatchState short match */ - U32 const dictMatchIndexS = dictHashSmall[dictHS]; + U32 const dictMatchIndexS = dictMatchIndexAndTagS >> ZSTD_SHORT_CACHE_TAG_BITS; match = dictBase + dictMatchIndexS; matchIndexS = dictMatchIndexS + dictIndexDelta; @@ -375,10 +445,11 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( continue; _search_next_long: - { size_t const hl3 = ZSTD_hashPtr(ip+1, hBitsL, 8); - size_t const dictHLNext = ZSTD_hashPtr(ip+1, dictHBitsL, 8); + size_t const dictHashAndTagL3 = ZSTD_hashPtr(ip+1, dictHBitsL, 8); U32 const matchIndexL3 = hashLong[hl3]; + U32 const dictMatchIndexAndTagL3 = dictHashLong[dictHashAndTagL3 >> ZSTD_SHORT_CACHE_TAG_BITS]; + int const dictTagsMatchL3 = ZSTD_comparePackedTags(dictMatchIndexAndTagL3, dictHashAndTagL3); const BYTE* matchL3 = base + matchIndexL3; hashLong[hl3] = curr + 1; @@ -391,9 +462,9 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( while (((ip>anchor) & (matchL3>prefixLowest)) && (ip[-1] == matchL3[-1])) { ip--; matchL3--; mLength++; } /* catch up */ goto _match_found; } - } else { + } else if (dictTagsMatchL3) { /* check dict long +1 match */ - U32 const dictMatchIndexL3 = dictHashLong[dictHLNext]; + U32 const dictMatchIndexL3 = dictMatchIndexAndTagL3 >> ZSTD_SHORT_CACHE_TAG_BITS; const BYTE* dictMatchL3 = dictBase + dictMatchIndexL3; assert(dictMatchL3 < dictEnd); if (dictMatchL3 > dictStart && MEM_read64(dictMatchL3) == MEM_read64(ip+1)) { @@ -419,7 +490,7 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( offset_2 = offset_1; offset_1 = offset; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); _match_stored: /* match found */ @@ -448,7 +519,7 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( const BYTE* const repEnd2 = repIndex2 < prefixLowestIndex ? dictEnd : iend; size_t const repLength2 = ZSTD_count_2segments(ip+4, repMatch2+4, iend, repEnd2, prefixLowest) + 4; U32 tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; /* swap offset_2 <=> offset_1 */ - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, repLength2); + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, repLength2); hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = current2; hashLong[ZSTD_hashPtr(ip, hBitsL, 8)] = current2; ip += repLength2; @@ -461,8 +532,8 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState_generic( } /* while (ip < ilimit) */ /* save reps for next block */ - rep[0] = offset_1 ? offset_1 : offsetSaved; - rep[1] = offset_2 ? offset_2 : offsetSaved; + rep[0] = offset_1; + rep[1] = offset_2; /* Return the last literals size */ return (size_t)(iend - anchor); @@ -527,7 +598,9 @@ size_t ZSTD_compressBlock_doubleFast_dictMatchState( } -static size_t ZSTD_compressBlock_doubleFast_extDict_generic( +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_compressBlock_doubleFast_extDict_generic( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize, U32 const mls /* template */) @@ -585,7 +658,7 @@ static size_t ZSTD_compressBlock_doubleFast_extDict_generic( const BYTE* repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend; mLength = ZSTD_count_2segments(ip+1+4, repMatch+4, iend, repMatchEnd, prefixStart) + 4; ip++; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, REPCODE1_TO_OFFBASE, mLength); } else { if ((matchLongIndex > dictStartIndex) && (MEM_read64(matchLong) == MEM_read64(ip))) { const BYTE* const matchEnd = matchLongIndex < prefixStartIndex ? dictEnd : iend; @@ -596,7 +669,7 @@ static size_t ZSTD_compressBlock_doubleFast_extDict_generic( while (((ip>anchor) & (matchLong>lowMatchPtr)) && (ip[-1] == matchLong[-1])) { ip--; matchLong--; mLength++; } /* catch up */ offset_2 = offset_1; offset_1 = offset; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); } else if ((matchIndex > dictStartIndex) && (MEM_read32(match) == MEM_read32(ip))) { size_t const h3 = ZSTD_hashPtr(ip+1, hBitsL, 8); @@ -621,7 +694,7 @@ static size_t ZSTD_compressBlock_doubleFast_extDict_generic( } offset_2 = offset_1; offset_1 = offset; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); + ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); } else { ip += ((ip-anchor) >> kSearchStrength) + 1; @@ -653,7 +726,7 @@ static size_t ZSTD_compressBlock_doubleFast_extDict_generic( const BYTE* const repEnd2 = repIndex2 < prefixStartIndex ? dictEnd : iend; size_t const repLength2 = ZSTD_count_2segments(ip+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; U32 const tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; /* swap offset_2 <=> offset_1 */ - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, repLength2); + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, repLength2); hashSmall[ZSTD_hashPtr(ip, hBitsS, mls)] = current2; hashLong[ZSTD_hashPtr(ip, hBitsL, 8)] = current2; ip += repLength2; @@ -694,3 +767,5 @@ size_t ZSTD_compressBlock_doubleFast_extDict( return ZSTD_compressBlock_doubleFast_extDict_7(ms, seqStore, rep, src, srcSize); } } + +#endif /* ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR */ diff --git a/lib/zstd/compress/zstd_double_fast.h b/lib/zstd/compress/zstd_double_fast.h index 6822bde65a1d..b7ddc714f13e 100644 --- a/lib/zstd/compress/zstd_double_fast.h +++ b/lib/zstd/compress/zstd_double_fast.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -15,8 +16,12 @@ #include "../common/mem.h" /* U32 */ #include "zstd_compress_internal.h" /* ZSTD_CCtx, size_t */ +#ifndef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR + void ZSTD_fillDoubleHashTable(ZSTD_matchState_t* ms, - void const* end, ZSTD_dictTableLoadMethod_e dtlm); + void const* end, ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp); + size_t ZSTD_compressBlock_doubleFast( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); @@ -27,6 +32,14 @@ size_t ZSTD_compressBlock_doubleFast_extDict( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); +#define ZSTD_COMPRESSBLOCK_DOUBLEFAST ZSTD_compressBlock_doubleFast +#define ZSTD_COMPRESSBLOCK_DOUBLEFAST_DICTMATCHSTATE ZSTD_compressBlock_doubleFast_dictMatchState +#define ZSTD_COMPRESSBLOCK_DOUBLEFAST_EXTDICT ZSTD_compressBlock_doubleFast_extDict +#else +#define ZSTD_COMPRESSBLOCK_DOUBLEFAST NULL +#define ZSTD_COMPRESSBLOCK_DOUBLEFAST_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_DOUBLEFAST_EXTDICT NULL +#endif /* ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR */ #endif /* ZSTD_DOUBLE_FAST_H */ diff --git a/lib/zstd/compress/zstd_fast.c b/lib/zstd/compress/zstd_fast.c index a752e6beab52..b7a63ba4ce56 100644 --- a/lib/zstd/compress/zstd_fast.c +++ b/lib/zstd/compress/zstd_fast.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -11,8 +12,46 @@ #include "zstd_compress_internal.h" /* ZSTD_hashPtr, ZSTD_count, ZSTD_storeSeq */ #include "zstd_fast.h" +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_fillHashTableForCDict(ZSTD_matchState_t* ms, + const void* const end, + ZSTD_dictTableLoadMethod_e dtlm) +{ + const ZSTD_compressionParameters* const cParams = &ms->cParams; + U32* const hashTable = ms->hashTable; + U32 const hBits = cParams->hashLog + ZSTD_SHORT_CACHE_TAG_BITS; + U32 const mls = cParams->minMatch; + const BYTE* const base = ms->window.base; + const BYTE* ip = base + ms->nextToUpdate; + const BYTE* const iend = ((const BYTE*)end) - HASH_READ_SIZE; + const U32 fastHashFillStep = 3; -void ZSTD_fillHashTable(ZSTD_matchState_t* ms, + /* Currently, we always use ZSTD_dtlm_full for filling CDict tables. + * Feel free to remove this assert if there's a good reason! */ + assert(dtlm == ZSTD_dtlm_full); + + /* Always insert every fastHashFillStep position into the hash table. + * Insert the other positions if their hash entry is empty. + */ + for ( ; ip + fastHashFillStep < iend + 2; ip += fastHashFillStep) { + U32 const curr = (U32)(ip - base); + { size_t const hashAndTag = ZSTD_hashPtr(ip, hBits, mls); + ZSTD_writeTaggedIndex(hashTable, hashAndTag, curr); } + + if (dtlm == ZSTD_dtlm_fast) continue; + /* Only load extra positions for ZSTD_dtlm_full */ + { U32 p; + for (p = 1; p < fastHashFillStep; ++p) { + size_t const hashAndTag = ZSTD_hashPtr(ip + p, hBits, mls); + if (hashTable[hashAndTag >> ZSTD_SHORT_CACHE_TAG_BITS] == 0) { /* not yet filled */ + ZSTD_writeTaggedIndex(hashTable, hashAndTag, curr + p); + } } } } +} + +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_fillHashTableForCCtx(ZSTD_matchState_t* ms, const void* const end, ZSTD_dictTableLoadMethod_e dtlm) { @@ -25,6 +64,10 @@ void ZSTD_fillHashTable(ZSTD_matchState_t* ms, const BYTE* const iend = ((const BYTE*)end) - HASH_READ_SIZE; const U32 fastHashFillStep = 3; + /* Currently, we always use ZSTD_dtlm_fast for filling CCtx tables. + * Feel free to remove this assert if there's a good reason! */ + assert(dtlm == ZSTD_dtlm_fast); + /* Always insert every fastHashFillStep position into the hash table. * Insert the other positions if their hash entry is empty. */ @@ -42,6 +85,18 @@ void ZSTD_fillHashTable(ZSTD_matchState_t* ms, } } } } } +void ZSTD_fillHashTable(ZSTD_matchState_t* ms, + const void* const end, + ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp) +{ + if (tfp == ZSTD_tfp_forCDict) { + ZSTD_fillHashTableForCDict(ms, end, dtlm); + } else { + ZSTD_fillHashTableForCCtx(ms, end, dtlm); + } +} + /* * If you squint hard enough (and ignore repcodes), the search operation at any @@ -89,8 +144,9 @@ void ZSTD_fillHashTable(ZSTD_matchState_t* ms, * * This is also the work we do at the beginning to enter the loop initially. */ -FORCE_INLINE_TEMPLATE size_t -ZSTD_compressBlock_fast_noDict_generic( +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_compressBlock_fast_noDict_generic( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize, U32 const mls, U32 const hasStep) @@ -117,7 +173,7 @@ ZSTD_compressBlock_fast_noDict_generic( U32 rep_offset1 = rep[0]; U32 rep_offset2 = rep[1]; - U32 offsetSaved = 0; + U32 offsetSaved1 = 0, offsetSaved2 = 0; size_t hash0; /* hash for ip0 */ size_t hash1; /* hash for ip1 */ @@ -141,8 +197,8 @@ ZSTD_compressBlock_fast_noDict_generic( { U32 const curr = (U32)(ip0 - base); U32 const windowLow = ZSTD_getLowestPrefixIndex(ms, curr, cParams->windowLog); U32 const maxRep = curr - windowLow; - if (rep_offset2 > maxRep) offsetSaved = rep_offset2, rep_offset2 = 0; - if (rep_offset1 > maxRep) offsetSaved = rep_offset1, rep_offset1 = 0; + if (rep_offset2 > maxRep) offsetSaved2 = rep_offset2, rep_offset2 = 0; + if (rep_offset1 > maxRep) offsetSaved1 = rep_offset1, rep_offset1 = 0; } /* start each op */ @@ -180,8 +236,14 @@ ZSTD_compressBlock_fast_noDict_generic( mLength = ip0[-1] == match0[-1]; ip0 -= mLength; match0 -= mLength; - offcode = STORE_REPCODE_1; + offcode = REPCODE1_TO_OFFBASE; mLength += 4; + + /* First write next hash table entry; we've already calculated it. + * This write is known to be safe because the ip1 is before the + * repcode (ip2). */ + hashTable[hash1] = (U32)(ip1 - base); + goto _match; } @@ -195,6 +257,12 @@ ZSTD_compressBlock_fast_noDict_generic( /* check match at ip[0] */ if (MEM_read32(ip0) == mval) { /* found a match! */ + + /* First write next hash table entry; we've already calculated it. + * This write is known to be safe because the ip1 == ip0 + 1, so + * we know we will resume searching after ip1 */ + hashTable[hash1] = (U32)(ip1 - base); + goto _offset; } @@ -224,6 +292,21 @@ ZSTD_compressBlock_fast_noDict_generic( /* check match at ip[0] */ if (MEM_read32(ip0) == mval) { /* found a match! */ + + /* first write next hash table entry; we've already calculated it */ + if (step <= 4) { + /* We need to avoid writing an index into the hash table >= the + * position at which we will pick up our searching after we've + * taken this match. + * + * The minimum possible match has length 4, so the earliest ip0 + * can be after we take this match will be the current ip0 + 4. + * ip1 is ip0 + step - 1. If ip1 is >= ip0 + 4, we can't safely + * write this position. + */ + hashTable[hash1] = (U32)(ip1 - base); + } + goto _offset; } @@ -254,9 +337,24 @@ ZSTD_compressBlock_fast_noDict_generic( * However, it seems to be a meaningful performance hit to try to search * them. So let's not. */ + /* When the repcodes are outside of the prefix, we set them to zero before the loop. + * When the offsets are still zero, we need to restore them after the block to have a correct + * repcode history. If only one offset was invalid, it is easy. The tricky case is when both + * offsets were invalid. We need to figure out which offset to refill with. + * - If both offsets are zero they are in the same order. + * - If both offsets are non-zero, we won't restore the offsets from `offsetSaved[12]`. + * - If only one is zero, we need to decide which offset to restore. + * - If rep_offset1 is non-zero, then rep_offset2 must be offsetSaved1. + * - It is impossible for rep_offset2 to be non-zero. + * + * So if rep_offset1 started invalid (offsetSaved1 != 0) and became valid (rep_offset1 != 0), then + * set rep[0] = rep_offset1 and rep[1] = offsetSaved1. + */ + offsetSaved2 = ((offsetSaved1 != 0) && (rep_offset1 != 0)) ? offsetSaved1 : offsetSaved2; + /* save reps for next block */ - rep[0] = rep_offset1 ? rep_offset1 : offsetSaved; - rep[1] = rep_offset2 ? rep_offset2 : offsetSaved; + rep[0] = rep_offset1 ? rep_offset1 : offsetSaved1; + rep[1] = rep_offset2 ? rep_offset2 : offsetSaved2; /* Return the last literals size */ return (size_t)(iend - anchor); @@ -267,7 +365,7 @@ ZSTD_compressBlock_fast_noDict_generic( match0 = base + idx; rep_offset2 = rep_offset1; rep_offset1 = (U32)(ip0-match0); - offcode = STORE_OFFSET(rep_offset1); + offcode = OFFSET_TO_OFFBASE(rep_offset1); mLength = 4; /* Count the backwards match length. */ @@ -287,11 +385,6 @@ ZSTD_compressBlock_fast_noDict_generic( ip0 += mLength; anchor = ip0; - /* write next hash table entry */ - if (ip1 < ip0) { - hashTable[hash1] = (U32)(ip1 - base); - } - /* Fill table and check for immediate repcode. */ if (ip0 <= ilimit) { /* Fill Table */ @@ -306,7 +399,7 @@ ZSTD_compressBlock_fast_noDict_generic( { U32 const tmpOff = rep_offset2; rep_offset2 = rep_offset1; rep_offset1 = tmpOff; } /* swap rep_offset2 <=> rep_offset1 */ hashTable[ZSTD_hashPtr(ip0, hlog, mls)] = (U32)(ip0-base); ip0 += rLength; - ZSTD_storeSeq(seqStore, 0 /*litLen*/, anchor, iend, STORE_REPCODE_1, rLength); + ZSTD_storeSeq(seqStore, 0 /*litLen*/, anchor, iend, REPCODE1_TO_OFFBASE, rLength); anchor = ip0; continue; /* faster when present (confirmed on gcc-8) ... (?) */ } } } @@ -369,6 +462,7 @@ size_t ZSTD_compressBlock_fast( } FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_compressBlock_fast_dictMatchState_generic( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize, U32 const mls, U32 const hasStep) @@ -380,14 +474,14 @@ size_t ZSTD_compressBlock_fast_dictMatchState_generic( U32 const stepSize = cParams->targetLength + !(cParams->targetLength); const BYTE* const base = ms->window.base; const BYTE* const istart = (const BYTE*)src; - const BYTE* ip = istart; + const BYTE* ip0 = istart; + const BYTE* ip1 = ip0 + stepSize; /* we assert below that stepSize >= 1 */ const BYTE* anchor = istart; const U32 prefixStartIndex = ms->window.dictLimit; const BYTE* const prefixStart = base + prefixStartIndex; const BYTE* const iend = istart + srcSize; const BYTE* const ilimit = iend - HASH_READ_SIZE; U32 offset_1=rep[0], offset_2=rep[1]; - U32 offsetSaved = 0; const ZSTD_matchState_t* const dms = ms->dictMatchState; const ZSTD_compressionParameters* const dictCParams = &dms->cParams ; @@ -397,13 +491,13 @@ size_t ZSTD_compressBlock_fast_dictMatchState_generic( const BYTE* const dictStart = dictBase + dictStartIndex; const BYTE* const dictEnd = dms->window.nextSrc; const U32 dictIndexDelta = prefixStartIndex - (U32)(dictEnd - dictBase); - const U32 dictAndPrefixLength = (U32)(ip - prefixStart + dictEnd - dictStart); - const U32 dictHLog = dictCParams->hashLog; + const U32 dictAndPrefixLength = (U32)(istart - prefixStart + dictEnd - dictStart); + const U32 dictHBits = dictCParams->hashLog + ZSTD_SHORT_CACHE_TAG_BITS; /* if a dictionary is still attached, it necessarily means that * it is within window size. So we just check it. */ const U32 maxDistance = 1U << cParams->windowLog; - const U32 endIndex = (U32)((size_t)(ip - base) + srcSize); + const U32 endIndex = (U32)((size_t)(istart - base) + srcSize); assert(endIndex - prefixStartIndex <= maxDistance); (void)maxDistance; (void)endIndex; /* these variables are not used when assert() is disabled */ @@ -413,106 +507,155 @@ size_t ZSTD_compressBlock_fast_dictMatchState_generic( * when translating a dict index into a local index */ assert(prefixStartIndex >= (U32)(dictEnd - dictBase)); + if (ms->prefetchCDictTables) { + size_t const hashTableBytes = (((size_t)1) << dictCParams->hashLog) * sizeof(U32); + PREFETCH_AREA(dictHashTable, hashTableBytes); + } + /* init */ DEBUGLOG(5, "ZSTD_compressBlock_fast_dictMatchState_generic"); - ip += (dictAndPrefixLength == 0); + ip0 += (dictAndPrefixLength == 0); /* dictMatchState repCode checks don't currently handle repCode == 0 * disabling. */ assert(offset_1 <= dictAndPrefixLength); assert(offset_2 <= dictAndPrefixLength); - /* Main Search Loop */ - while (ip < ilimit) { /* < instead of <=, because repcode check at (ip+1) */ + /* Outer search loop */ + assert(stepSize >= 1); + while (ip1 <= ilimit) { /* repcode check at (ip0 + 1) is safe because ip0 < ip1 */ size_t mLength; - size_t const h = ZSTD_hashPtr(ip, hlog, mls); - U32 const curr = (U32)(ip-base); - U32 const matchIndex = hashTable[h]; - const BYTE* match = base + matchIndex; - const U32 repIndex = curr + 1 - offset_1; - const BYTE* repMatch = (repIndex < prefixStartIndex) ? - dictBase + (repIndex - dictIndexDelta) : - base + repIndex; - hashTable[h] = curr; /* update hash table */ - - if ( ((U32)((prefixStartIndex-1) - repIndex) >= 3) /* intentional underflow : ensure repIndex isn't overlapping dict + prefix */ - && (MEM_read32(repMatch) == MEM_read32(ip+1)) ) { - const BYTE* const repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend; - mLength = ZSTD_count_2segments(ip+1+4, repMatch+4, iend, repMatchEnd, prefixStart) + 4; - ip++; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, mLength); - } else if ( (matchIndex <= prefixStartIndex) ) { - size_t const dictHash = ZSTD_hashPtr(ip, dictHLog, mls); - U32 const dictMatchIndex = dictHashTable[dictHash]; - const BYTE* dictMatch = dictBase + dictMatchIndex; - if (dictMatchIndex <= dictStartIndex || - MEM_read32(dictMatch) != MEM_read32(ip)) { - assert(stepSize >= 1); - ip += ((ip-anchor) >> kSearchStrength) + stepSize; - continue; - } else { - /* found a dict match */ - U32 const offset = (U32)(curr-dictMatchIndex-dictIndexDelta); - mLength = ZSTD_count_2segments(ip+4, dictMatch+4, iend, dictEnd, prefixStart) + 4; - while (((ip>anchor) & (dictMatch>dictStart)) - && (ip[-1] == dictMatch[-1])) { - ip--; dictMatch--; mLength++; + size_t hash0 = ZSTD_hashPtr(ip0, hlog, mls); + + size_t const dictHashAndTag0 = ZSTD_hashPtr(ip0, dictHBits, mls); + U32 dictMatchIndexAndTag = dictHashTable[dictHashAndTag0 >> ZSTD_SHORT_CACHE_TAG_BITS]; + int dictTagsMatch = ZSTD_comparePackedTags(dictMatchIndexAndTag, dictHashAndTag0); + + U32 matchIndex = hashTable[hash0]; + U32 curr = (U32)(ip0 - base); + size_t step = stepSize; + const size_t kStepIncr = 1 << kSearchStrength; + const BYTE* nextStep = ip0 + kStepIncr; + + /* Inner search loop */ + while (1) { + const BYTE* match = base + matchIndex; + const U32 repIndex = curr + 1 - offset_1; + const BYTE* repMatch = (repIndex < prefixStartIndex) ? + dictBase + (repIndex - dictIndexDelta) : + base + repIndex; + const size_t hash1 = ZSTD_hashPtr(ip1, hlog, mls); + size_t const dictHashAndTag1 = ZSTD_hashPtr(ip1, dictHBits, mls); + hashTable[hash0] = curr; /* update hash table */ + + if (((U32) ((prefixStartIndex - 1) - repIndex) >= + 3) /* intentional underflow : ensure repIndex isn't overlapping dict + prefix */ + && (MEM_read32(repMatch) == MEM_read32(ip0 + 1))) { + const BYTE* const repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend; + mLength = ZSTD_count_2segments(ip0 + 1 + 4, repMatch + 4, iend, repMatchEnd, prefixStart) + 4; + ip0++; + ZSTD_storeSeq(seqStore, (size_t) (ip0 - anchor), anchor, iend, REPCODE1_TO_OFFBASE, mLength); + break; + } + + if (dictTagsMatch) { + /* Found a possible dict match */ + const U32 dictMatchIndex = dictMatchIndexAndTag >> ZSTD_SHORT_CACHE_TAG_BITS; + const BYTE* dictMatch = dictBase + dictMatchIndex; + if (dictMatchIndex > dictStartIndex && + MEM_read32(dictMatch) == MEM_read32(ip0)) { + /* To replicate extDict parse behavior, we only use dict matches when the normal matchIndex is invalid */ + if (matchIndex <= prefixStartIndex) { + U32 const offset = (U32) (curr - dictMatchIndex - dictIndexDelta); + mLength = ZSTD_count_2segments(ip0 + 4, dictMatch + 4, iend, dictEnd, prefixStart) + 4; + while (((ip0 > anchor) & (dictMatch > dictStart)) + && (ip0[-1] == dictMatch[-1])) { + ip0--; + dictMatch--; + mLength++; + } /* catch up */ + offset_2 = offset_1; + offset_1 = offset; + ZSTD_storeSeq(seqStore, (size_t) (ip0 - anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); + break; + } + } + } + + if (matchIndex > prefixStartIndex && MEM_read32(match) == MEM_read32(ip0)) { + /* found a regular match */ + U32 const offset = (U32) (ip0 - match); + mLength = ZSTD_count(ip0 + 4, match + 4, iend) + 4; + while (((ip0 > anchor) & (match > prefixStart)) + && (ip0[-1] == match[-1])) { + ip0--; + match--; + mLength++; } /* catch up */ offset_2 = offset_1; offset_1 = offset; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); + ZSTD_storeSeq(seqStore, (size_t) (ip0 - anchor), anchor, iend, OFFSET_TO_OFFBASE(offset), mLength); + break; } - } else if (MEM_read32(match) != MEM_read32(ip)) { - /* it's not a match, and we're not going to check the dictionary */ - assert(stepSize >= 1); - ip += ((ip-anchor) >> kSearchStrength) + stepSize; - continue; - } else { - /* found a regular match */ - U32 const offset = (U32)(ip-match); - mLength = ZSTD_count(ip+4, match+4, iend) + 4; - while (((ip>anchor) & (match>prefixStart)) - && (ip[-1] == match[-1])) { ip--; match--; mLength++; } /* catch up */ - offset_2 = offset_1; - offset_1 = offset; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); - } + + /* Prepare for next iteration */ + dictMatchIndexAndTag = dictHashTable[dictHashAndTag1 >> ZSTD_SHORT_CACHE_TAG_BITS]; + dictTagsMatch = ZSTD_comparePackedTags(dictMatchIndexAndTag, dictHashAndTag1); + matchIndex = hashTable[hash1]; + + if (ip1 >= nextStep) { + step++; + nextStep += kStepIncr; + } + ip0 = ip1; + ip1 = ip1 + step; + if (ip1 > ilimit) goto _cleanup; + + curr = (U32)(ip0 - base); + hash0 = hash1; + } /* end inner search loop */ /* match found */ - ip += mLength; - anchor = ip; + assert(mLength); + ip0 += mLength; + anchor = ip0; - if (ip <= ilimit) { + if (ip0 <= ilimit) { /* Fill Table */ assert(base+curr+2 > istart); /* check base overflow */ hashTable[ZSTD_hashPtr(base+curr+2, hlog, mls)] = curr+2; /* here because curr+2 could be > iend-8 */ - hashTable[ZSTD_hashPtr(ip-2, hlog, mls)] = (U32)(ip-2-base); + hashTable[ZSTD_hashPtr(ip0-2, hlog, mls)] = (U32)(ip0-2-base); /* check immediate repcode */ - while (ip <= ilimit) { - U32 const current2 = (U32)(ip-base); + while (ip0 <= ilimit) { + U32 const current2 = (U32)(ip0-base); U32 const repIndex2 = current2 - offset_2; const BYTE* repMatch2 = repIndex2 < prefixStartIndex ? dictBase - dictIndexDelta + repIndex2 : base + repIndex2; if ( ((U32)((prefixStartIndex-1) - (U32)repIndex2) >= 3 /* intentional overflow */) - && (MEM_read32(repMatch2) == MEM_read32(ip)) ) { + && (MEM_read32(repMatch2) == MEM_read32(ip0))) { const BYTE* const repEnd2 = repIndex2 < prefixStartIndex ? dictEnd : iend; - size_t const repLength2 = ZSTD_count_2segments(ip+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; + size_t const repLength2 = ZSTD_count_2segments(ip0+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; U32 tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; /* swap offset_2 <=> offset_1 */ - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, repLength2); - hashTable[ZSTD_hashPtr(ip, hlog, mls)] = current2; - ip += repLength2; - anchor = ip; + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, repLength2); + hashTable[ZSTD_hashPtr(ip0, hlog, mls)] = current2; + ip0 += repLength2; + anchor = ip0; continue; } break; } } + + /* Prepare for next iteration */ + assert(ip0 == anchor); + ip1 = ip0 + stepSize; } +_cleanup: /* save reps for next block */ - rep[0] = offset_1 ? offset_1 : offsetSaved; - rep[1] = offset_2 ? offset_2 : offsetSaved; + rep[0] = offset_1; + rep[1] = offset_2; /* Return the last literals size */ return (size_t)(iend - anchor); @@ -545,7 +688,9 @@ size_t ZSTD_compressBlock_fast_dictMatchState( } -static size_t ZSTD_compressBlock_fast_extDict_generic( +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_compressBlock_fast_extDict_generic( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize, U32 const mls, U32 const hasStep) { @@ -553,11 +698,10 @@ static size_t ZSTD_compressBlock_fast_extDict_generic( U32* const hashTable = ms->hashTable; U32 const hlog = cParams->hashLog; /* support stepSize of 0 */ - U32 const stepSize = cParams->targetLength + !(cParams->targetLength); + size_t const stepSize = cParams->targetLength + !(cParams->targetLength) + 1; const BYTE* const base = ms->window.base; const BYTE* const dictBase = ms->window.dictBase; const BYTE* const istart = (const BYTE*)src; - const BYTE* ip = istart; const BYTE* anchor = istart; const U32 endIndex = (U32)((size_t)(istart - base) + srcSize); const U32 lowLimit = ZSTD_getLowestMatchIndex(ms, endIndex, cParams->windowLog); @@ -570,6 +714,28 @@ static size_t ZSTD_compressBlock_fast_extDict_generic( const BYTE* const iend = istart + srcSize; const BYTE* const ilimit = iend - 8; U32 offset_1=rep[0], offset_2=rep[1]; + U32 offsetSaved1 = 0, offsetSaved2 = 0; + + const BYTE* ip0 = istart; + const BYTE* ip1; + const BYTE* ip2; + const BYTE* ip3; + U32 current0; + + + size_t hash0; /* hash for ip0 */ + size_t hash1; /* hash for ip1 */ + U32 idx; /* match idx for ip0 */ + const BYTE* idxBase; /* base pointer for idx */ + + U32 offcode; + const BYTE* match0; + size_t mLength; + const BYTE* matchEnd = 0; /* initialize to avoid warning, assert != 0 later */ + + size_t step; + const BYTE* nextStep; + const size_t kStepIncr = (1 << (kSearchStrength - 1)); (void)hasStep; /* not currently specialized on whether it's accelerated */ @@ -579,75 +745,202 @@ static size_t ZSTD_compressBlock_fast_extDict_generic( if (prefixStartIndex == dictStartIndex) return ZSTD_compressBlock_fast(ms, seqStore, rep, src, srcSize); - /* Search Loop */ - while (ip < ilimit) { /* < instead of <=, because (ip+1) */ - const size_t h = ZSTD_hashPtr(ip, hlog, mls); - const U32 matchIndex = hashTable[h]; - const BYTE* const matchBase = matchIndex < prefixStartIndex ? dictBase : base; - const BYTE* match = matchBase + matchIndex; - const U32 curr = (U32)(ip-base); - const U32 repIndex = curr + 1 - offset_1; - const BYTE* const repBase = repIndex < prefixStartIndex ? dictBase : base; - const BYTE* const repMatch = repBase + repIndex; - hashTable[h] = curr; /* update hash table */ - DEBUGLOG(7, "offset_1 = %u , curr = %u", offset_1, curr); - - if ( ( ((U32)((prefixStartIndex-1) - repIndex) >= 3) /* intentional underflow */ - & (offset_1 <= curr+1 - dictStartIndex) ) /* note: we are searching at curr+1 */ - && (MEM_read32(repMatch) == MEM_read32(ip+1)) ) { - const BYTE* const repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend; - size_t const rLength = ZSTD_count_2segments(ip+1 +4, repMatch +4, iend, repMatchEnd, prefixStart) + 4; - ip++; - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_REPCODE_1, rLength); - ip += rLength; - anchor = ip; - } else { - if ( (matchIndex < dictStartIndex) || - (MEM_read32(match) != MEM_read32(ip)) ) { - assert(stepSize >= 1); - ip += ((ip-anchor) >> kSearchStrength) + stepSize; - continue; + { U32 const curr = (U32)(ip0 - base); + U32 const maxRep = curr - dictStartIndex; + if (offset_2 >= maxRep) offsetSaved2 = offset_2, offset_2 = 0; + if (offset_1 >= maxRep) offsetSaved1 = offset_1, offset_1 = 0; + } + + /* start each op */ +_start: /* Requires: ip0 */ + + step = stepSize; + nextStep = ip0 + kStepIncr; + + /* calculate positions, ip0 - anchor == 0, so we skip step calc */ + ip1 = ip0 + 1; + ip2 = ip0 + step; + ip3 = ip2 + 1; + + if (ip3 >= ilimit) { + goto _cleanup; + } + + hash0 = ZSTD_hashPtr(ip0, hlog, mls); + hash1 = ZSTD_hashPtr(ip1, hlog, mls); + + idx = hashTable[hash0]; + idxBase = idx < prefixStartIndex ? dictBase : base; + + do { + { /* load repcode match for ip[2] */ + U32 const current2 = (U32)(ip2 - base); + U32 const repIndex = current2 - offset_1; + const BYTE* const repBase = repIndex < prefixStartIndex ? dictBase : base; + U32 rval; + if ( ((U32)(prefixStartIndex - repIndex) >= 4) /* intentional underflow */ + & (offset_1 > 0) ) { + rval = MEM_read32(repBase + repIndex); + } else { + rval = MEM_read32(ip2) ^ 1; /* guaranteed to not match. */ } - { const BYTE* const matchEnd = matchIndex < prefixStartIndex ? dictEnd : iend; - const BYTE* const lowMatchPtr = matchIndex < prefixStartIndex ? dictStart : prefixStart; - U32 const offset = curr - matchIndex; - size_t mLength = ZSTD_count_2segments(ip+4, match+4, iend, matchEnd, prefixStart) + 4; - while (((ip>anchor) & (match>lowMatchPtr)) && (ip[-1] == match[-1])) { ip--; match--; mLength++; } /* catch up */ - offset_2 = offset_1; offset_1 = offset; /* update offset history */ - ZSTD_storeSeq(seqStore, (size_t)(ip-anchor), anchor, iend, STORE_OFFSET(offset), mLength); - ip += mLength; - anchor = ip; + + /* write back hash table entry */ + current0 = (U32)(ip0 - base); + hashTable[hash0] = current0; + + /* check repcode at ip[2] */ + if (MEM_read32(ip2) == rval) { + ip0 = ip2; + match0 = repBase + repIndex; + matchEnd = repIndex < prefixStartIndex ? dictEnd : iend; + assert((match0 != prefixStart) & (match0 != dictStart)); + mLength = ip0[-1] == match0[-1]; + ip0 -= mLength; + match0 -= mLength; + offcode = REPCODE1_TO_OFFBASE; + mLength += 4; + goto _match; } } - if (ip <= ilimit) { - /* Fill Table */ - hashTable[ZSTD_hashPtr(base+curr+2, hlog, mls)] = curr+2; - hashTable[ZSTD_hashPtr(ip-2, hlog, mls)] = (U32)(ip-2-base); - /* check immediate repcode */ - while (ip <= ilimit) { - U32 const current2 = (U32)(ip-base); - U32 const repIndex2 = current2 - offset_2; - const BYTE* const repMatch2 = repIndex2 < prefixStartIndex ? dictBase + repIndex2 : base + repIndex2; - if ( (((U32)((prefixStartIndex-1) - repIndex2) >= 3) & (offset_2 <= curr - dictStartIndex)) /* intentional overflow */ - && (MEM_read32(repMatch2) == MEM_read32(ip)) ) { - const BYTE* const repEnd2 = repIndex2 < prefixStartIndex ? dictEnd : iend; - size_t const repLength2 = ZSTD_count_2segments(ip+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; - { U32 const tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; } /* swap offset_2 <=> offset_1 */ - ZSTD_storeSeq(seqStore, 0 /*litlen*/, anchor, iend, STORE_REPCODE_1, repLength2); - hashTable[ZSTD_hashPtr(ip, hlog, mls)] = current2; - ip += repLength2; - anchor = ip; - continue; - } - break; - } } } + { /* load match for ip[0] */ + U32 const mval = idx >= dictStartIndex ? + MEM_read32(idxBase + idx) : + MEM_read32(ip0) ^ 1; /* guaranteed not to match */ + + /* check match at ip[0] */ + if (MEM_read32(ip0) == mval) { + /* found a match! */ + goto _offset; + } } + + /* lookup ip[1] */ + idx = hashTable[hash1]; + idxBase = idx < prefixStartIndex ? dictBase : base; + + /* hash ip[2] */ + hash0 = hash1; + hash1 = ZSTD_hashPtr(ip2, hlog, mls); + + /* advance to next positions */ + ip0 = ip1; + ip1 = ip2; + ip2 = ip3; + + /* write back hash table entry */ + current0 = (U32)(ip0 - base); + hashTable[hash0] = current0; + + { /* load match for ip[0] */ + U32 const mval = idx >= dictStartIndex ? + MEM_read32(idxBase + idx) : + MEM_read32(ip0) ^ 1; /* guaranteed not to match */ + + /* check match at ip[0] */ + if (MEM_read32(ip0) == mval) { + /* found a match! */ + goto _offset; + } } + + /* lookup ip[1] */ + idx = hashTable[hash1]; + idxBase = idx < prefixStartIndex ? dictBase : base; + + /* hash ip[2] */ + hash0 = hash1; + hash1 = ZSTD_hashPtr(ip2, hlog, mls); + + /* advance to next positions */ + ip0 = ip1; + ip1 = ip2; + ip2 = ip0 + step; + ip3 = ip1 + step; + + /* calculate step */ + if (ip2 >= nextStep) { + step++; + PREFETCH_L1(ip1 + 64); + PREFETCH_L1(ip1 + 128); + nextStep += kStepIncr; + } + } while (ip3 < ilimit); + +_cleanup: + /* Note that there are probably still a couple positions we could search. + * However, it seems to be a meaningful performance hit to try to search + * them. So let's not. */ + + /* If offset_1 started invalid (offsetSaved1 != 0) and became valid (offset_1 != 0), + * rotate saved offsets. See comment in ZSTD_compressBlock_fast_noDict for more context. */ + offsetSaved2 = ((offsetSaved1 != 0) && (offset_1 != 0)) ? offsetSaved1 : offsetSaved2; /* save reps for next block */ - rep[0] = offset_1; - rep[1] = offset_2; + rep[0] = offset_1 ? offset_1 : offsetSaved1; + rep[1] = offset_2 ? offset_2 : offsetSaved2; /* Return the last literals size */ return (size_t)(iend - anchor); + +_offset: /* Requires: ip0, idx, idxBase */ + + /* Compute the offset code. */ + { U32 const offset = current0 - idx; + const BYTE* const lowMatchPtr = idx < prefixStartIndex ? dictStart : prefixStart; + matchEnd = idx < prefixStartIndex ? dictEnd : iend; + match0 = idxBase + idx; + offset_2 = offset_1; + offset_1 = offset; + offcode = OFFSET_TO_OFFBASE(offset); + mLength = 4; + + /* Count the backwards match length. */ + while (((ip0>anchor) & (match0>lowMatchPtr)) && (ip0[-1] == match0[-1])) { + ip0--; + match0--; + mLength++; + } } + +_match: /* Requires: ip0, match0, offcode, matchEnd */ + + /* Count the forward length. */ + assert(matchEnd != 0); + mLength += ZSTD_count_2segments(ip0 + mLength, match0 + mLength, iend, matchEnd, prefixStart); + + ZSTD_storeSeq(seqStore, (size_t)(ip0 - anchor), anchor, iend, offcode, mLength); + + ip0 += mLength; + anchor = ip0; + + /* write next hash table entry */ + if (ip1 < ip0) { + hashTable[hash1] = (U32)(ip1 - base); + } + + /* Fill table and check for immediate repcode. */ + if (ip0 <= ilimit) { + /* Fill Table */ + assert(base+current0+2 > istart); /* check base overflow */ + hashTable[ZSTD_hashPtr(base+current0+2, hlog, mls)] = current0+2; /* here because current+2 could be > iend-8 */ + hashTable[ZSTD_hashPtr(ip0-2, hlog, mls)] = (U32)(ip0-2-base); + + while (ip0 <= ilimit) { + U32 const repIndex2 = (U32)(ip0-base) - offset_2; + const BYTE* const repMatch2 = repIndex2 < prefixStartIndex ? dictBase + repIndex2 : base + repIndex2; + if ( (((U32)((prefixStartIndex-1) - repIndex2) >= 3) & (offset_2 > 0)) /* intentional underflow */ + && (MEM_read32(repMatch2) == MEM_read32(ip0)) ) { + const BYTE* const repEnd2 = repIndex2 < prefixStartIndex ? dictEnd : iend; + size_t const repLength2 = ZSTD_count_2segments(ip0+4, repMatch2+4, iend, repEnd2, prefixStart) + 4; + { U32 const tmpOffset = offset_2; offset_2 = offset_1; offset_1 = tmpOffset; } /* swap offset_2 <=> offset_1 */ + ZSTD_storeSeq(seqStore, 0 /*litlen*/, anchor, iend, REPCODE1_TO_OFFBASE, repLength2); + hashTable[ZSTD_hashPtr(ip0, hlog, mls)] = (U32)(ip0-base); + ip0 += repLength2; + anchor = ip0; + continue; + } + break; + } } + + goto _start; } ZSTD_GEN_FAST_FN(extDict, 4, 0) @@ -660,6 +953,7 @@ size_t ZSTD_compressBlock_fast_extDict( void const* src, size_t srcSize) { U32 const mls = ms->cParams.minMatch; + assert(ms->dictMatchState == NULL); switch(mls) { default: /* includes case 3 */ diff --git a/lib/zstd/compress/zstd_fast.h b/lib/zstd/compress/zstd_fast.h index fddc2f532d21..e64d9e1b2d39 100644 --- a/lib/zstd/compress/zstd_fast.h +++ b/lib/zstd/compress/zstd_fast.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -16,7 +17,8 @@ #include "zstd_compress_internal.h" void ZSTD_fillHashTable(ZSTD_matchState_t* ms, - void const* end, ZSTD_dictTableLoadMethod_e dtlm); + void const* end, ZSTD_dictTableLoadMethod_e dtlm, + ZSTD_tableFillPurpose_e tfp); size_t ZSTD_compressBlock_fast( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); diff --git a/lib/zstd/compress/zstd_lazy.c b/lib/zstd/compress/zstd_lazy.c index 0298a01a7504..3e88d8a1a136 100644 --- a/lib/zstd/compress/zstd_lazy.c +++ b/lib/zstd/compress/zstd_lazy.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -10,14 +11,23 @@ #include "zstd_compress_internal.h" #include "zstd_lazy.h" +#include "../common/bits.h" /* ZSTD_countTrailingZeros64 */ + +#if !defined(ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) + +#define kLazySkippingStep 8 /*-************************************* * Binary Tree search ***************************************/ -static void -ZSTD_updateDUBT(ZSTD_matchState_t* ms, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_updateDUBT(ZSTD_matchState_t* ms, const BYTE* ip, const BYTE* iend, U32 mls) { @@ -60,8 +70,9 @@ ZSTD_updateDUBT(ZSTD_matchState_t* ms, * sort one already inserted but unsorted position * assumption : curr >= btlow == (curr - btmask) * doesn't fail */ -static void -ZSTD_insertDUBT1(const ZSTD_matchState_t* ms, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_insertDUBT1(const ZSTD_matchState_t* ms, U32 curr, const BYTE* inputEnd, U32 nbCompares, U32 btLow, const ZSTD_dictMode_e dictMode) @@ -149,8 +160,9 @@ ZSTD_insertDUBT1(const ZSTD_matchState_t* ms, } -static size_t -ZSTD_DUBT_findBetterDictMatch ( +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_DUBT_findBetterDictMatch ( const ZSTD_matchState_t* ms, const BYTE* const ip, const BYTE* const iend, size_t* offsetPtr, @@ -197,8 +209,8 @@ ZSTD_DUBT_findBetterDictMatch ( U32 matchIndex = dictMatchIndex + dictIndexDelta; if ( (4*(int)(matchLength-bestLength)) > (int)(ZSTD_highbit32(curr-matchIndex+1) - ZSTD_highbit32((U32)offsetPtr[0]+1)) ) { DEBUGLOG(9, "ZSTD_DUBT_findBetterDictMatch(%u) : found better match length %u -> %u and offsetCode %u -> %u (dictMatchIndex %u, matchIndex %u)", - curr, (U32)bestLength, (U32)matchLength, (U32)*offsetPtr, STORE_OFFSET(curr - matchIndex), dictMatchIndex, matchIndex); - bestLength = matchLength, *offsetPtr = STORE_OFFSET(curr - matchIndex); + curr, (U32)bestLength, (U32)matchLength, (U32)*offsetPtr, OFFSET_TO_OFFBASE(curr - matchIndex), dictMatchIndex, matchIndex); + bestLength = matchLength, *offsetPtr = OFFSET_TO_OFFBASE(curr - matchIndex); } if (ip+matchLength == iend) { /* reached end of input : ip[matchLength] is not valid, no way to know if it's larger or smaller than match */ break; /* drop, to guarantee consistency (miss a little bit of compression) */ @@ -218,7 +230,7 @@ ZSTD_DUBT_findBetterDictMatch ( } if (bestLength >= MINMATCH) { - U32 const mIndex = curr - (U32)STORED_OFFSET(*offsetPtr); (void)mIndex; + U32 const mIndex = curr - (U32)OFFBASE_TO_OFFSET(*offsetPtr); (void)mIndex; DEBUGLOG(8, "ZSTD_DUBT_findBetterDictMatch(%u) : found match of length %u and offsetCode %u (pos %u)", curr, (U32)bestLength, (U32)*offsetPtr, mIndex); } @@ -227,10 +239,11 @@ ZSTD_DUBT_findBetterDictMatch ( } -static size_t -ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, const BYTE* const ip, const BYTE* const iend, - size_t* offsetPtr, + size_t* offBasePtr, U32 const mls, const ZSTD_dictMode_e dictMode) { @@ -327,8 +340,8 @@ ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, if (matchLength > bestLength) { if (matchLength > matchEndIdx - matchIndex) matchEndIdx = matchIndex + (U32)matchLength; - if ( (4*(int)(matchLength-bestLength)) > (int)(ZSTD_highbit32(curr-matchIndex+1) - ZSTD_highbit32((U32)offsetPtr[0]+1)) ) - bestLength = matchLength, *offsetPtr = STORE_OFFSET(curr - matchIndex); + if ( (4*(int)(matchLength-bestLength)) > (int)(ZSTD_highbit32(curr - matchIndex + 1) - ZSTD_highbit32((U32)*offBasePtr)) ) + bestLength = matchLength, *offBasePtr = OFFSET_TO_OFFBASE(curr - matchIndex); if (ip+matchLength == iend) { /* equal : no way to know if inf or sup */ if (dictMode == ZSTD_dictMatchState) { nbCompares = 0; /* in addition to avoiding checking any @@ -361,16 +374,16 @@ ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, if (dictMode == ZSTD_dictMatchState && nbCompares) { bestLength = ZSTD_DUBT_findBetterDictMatch( ms, ip, iend, - offsetPtr, bestLength, nbCompares, + offBasePtr, bestLength, nbCompares, mls, dictMode); } assert(matchEndIdx > curr+8); /* ensure nextToUpdate is increased */ ms->nextToUpdate = matchEndIdx - 8; /* skip repetitive patterns */ if (bestLength >= MINMATCH) { - U32 const mIndex = curr - (U32)STORED_OFFSET(*offsetPtr); (void)mIndex; + U32 const mIndex = curr - (U32)OFFBASE_TO_OFFSET(*offBasePtr); (void)mIndex; DEBUGLOG(8, "ZSTD_DUBT_findBestMatch(%u) : found match of length %u and offsetCode %u (pos %u)", - curr, (U32)bestLength, (U32)*offsetPtr, mIndex); + curr, (U32)bestLength, (U32)*offBasePtr, mIndex); } return bestLength; } @@ -378,17 +391,18 @@ ZSTD_DUBT_findBestMatch(ZSTD_matchState_t* ms, /* ZSTD_BtFindBestMatch() : Tree updater, providing best match */ -FORCE_INLINE_TEMPLATE size_t -ZSTD_BtFindBestMatch( ZSTD_matchState_t* ms, +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_BtFindBestMatch( ZSTD_matchState_t* ms, const BYTE* const ip, const BYTE* const iLimit, - size_t* offsetPtr, + size_t* offBasePtr, const U32 mls /* template */, const ZSTD_dictMode_e dictMode) { DEBUGLOG(7, "ZSTD_BtFindBestMatch"); if (ip < ms->window.base + ms->nextToUpdate) return 0; /* skipped area */ ZSTD_updateDUBT(ms, ip, iLimit, mls); - return ZSTD_DUBT_findBestMatch(ms, ip, iLimit, offsetPtr, mls, dictMode); + return ZSTD_DUBT_findBestMatch(ms, ip, iLimit, offBasePtr, mls, dictMode); } /* ********************************* @@ -561,7 +575,7 @@ size_t ZSTD_dedicatedDictSearch_lazy_search(size_t* offsetPtr, size_t ml, U32 nb /* save best solution */ if (currentMl > ml) { ml = currentMl; - *offsetPtr = STORE_OFFSET(curr - (matchIndex + ddsIndexDelta)); + *offsetPtr = OFFSET_TO_OFFBASE(curr - (matchIndex + ddsIndexDelta)); if (ip+currentMl == iLimit) { /* best possible, avoids read overflow on next attempt */ return ml; @@ -598,7 +612,7 @@ size_t ZSTD_dedicatedDictSearch_lazy_search(size_t* offsetPtr, size_t ml, U32 nb /* save best solution */ if (currentMl > ml) { ml = currentMl; - *offsetPtr = STORE_OFFSET(curr - (matchIndex + ddsIndexDelta)); + *offsetPtr = OFFSET_TO_OFFBASE(curr - (matchIndex + ddsIndexDelta)); if (ip+currentMl == iLimit) break; /* best possible, avoids read overflow on next attempt */ } } @@ -614,10 +628,12 @@ size_t ZSTD_dedicatedDictSearch_lazy_search(size_t* offsetPtr, size_t ml, U32 nb /* Update chains up to ip (excluded) Assumption : always within prefix (i.e. not within extDict) */ -FORCE_INLINE_TEMPLATE U32 ZSTD_insertAndFindFirstIndex_internal( +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_insertAndFindFirstIndex_internal( ZSTD_matchState_t* ms, const ZSTD_compressionParameters* const cParams, - const BYTE* ip, U32 const mls) + const BYTE* ip, U32 const mls, U32 const lazySkipping) { U32* const hashTable = ms->hashTable; const U32 hashLog = cParams->hashLog; @@ -632,6 +648,9 @@ FORCE_INLINE_TEMPLATE U32 ZSTD_insertAndFindFirstIndex_internal( NEXT_IN_CHAIN(idx, chainMask) = hashTable[h]; hashTable[h] = idx; idx++; + /* Stop inserting every position when in the lazy skipping mode. */ + if (lazySkipping) + break; } ms->nextToUpdate = target; @@ -640,11 +659,12 @@ FORCE_INLINE_TEMPLATE U32 ZSTD_insertAndFindFirstIndex_internal( U32 ZSTD_insertAndFindFirstIndex(ZSTD_matchState_t* ms, const BYTE* ip) { const ZSTD_compressionParameters* const cParams = &ms->cParams; - return ZSTD_insertAndFindFirstIndex_internal(ms, cParams, ip, ms->cParams.minMatch); + return ZSTD_insertAndFindFirstIndex_internal(ms, cParams, ip, ms->cParams.minMatch, /* lazySkipping*/ 0); } /* inlining is important to hardwire a hot branch (template emulation) */ FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_HcFindBestMatch( ZSTD_matchState_t* ms, const BYTE* const ip, const BYTE* const iLimit, @@ -684,14 +704,15 @@ size_t ZSTD_HcFindBestMatch( } /* HC4 match finder */ - matchIndex = ZSTD_insertAndFindFirstIndex_internal(ms, cParams, ip, mls); + matchIndex = ZSTD_insertAndFindFirstIndex_internal(ms, cParams, ip, mls, ms->lazySkipping); for ( ; (matchIndex>=lowLimit) & (nbAttempts>0) ; nbAttempts--) { size_t currentMl=0; if ((dictMode != ZSTD_extDict) || matchIndex >= dictLimit) { const BYTE* const match = base + matchIndex; assert(matchIndex >= dictLimit); /* ensures this is true if dictMode != ZSTD_extDict */ - if (match[ml] == ip[ml]) /* potentially better */ + /* read 4B starting from (match + ml + 1 - sizeof(U32)) */ + if (MEM_read32(match + ml - 3) == MEM_read32(ip + ml - 3)) /* potentially better */ currentMl = ZSTD_count(ip, match, iLimit); } else { const BYTE* const match = dictBase + matchIndex; @@ -703,7 +724,7 @@ size_t ZSTD_HcFindBestMatch( /* save best solution */ if (currentMl > ml) { ml = currentMl; - *offsetPtr = STORE_OFFSET(curr - matchIndex); + *offsetPtr = OFFSET_TO_OFFBASE(curr - matchIndex); if (ip+currentMl == iLimit) break; /* best possible, avoids read overflow on next attempt */ } @@ -739,7 +760,7 @@ size_t ZSTD_HcFindBestMatch( if (currentMl > ml) { ml = currentMl; assert(curr > matchIndex + dmsIndexDelta); - *offsetPtr = STORE_OFFSET(curr - (matchIndex + dmsIndexDelta)); + *offsetPtr = OFFSET_TO_OFFBASE(curr - (matchIndex + dmsIndexDelta)); if (ip+currentMl == iLimit) break; /* best possible, avoids read overflow on next attempt */ } @@ -756,8 +777,6 @@ size_t ZSTD_HcFindBestMatch( * (SIMD) Row-based matchfinder ***********************************/ /* Constants for row-based hash */ -#define ZSTD_ROW_HASH_TAG_OFFSET 16 /* byte offset of hashes in the match state's tagTable from the beginning of a row */ -#define ZSTD_ROW_HASH_TAG_BITS 8 /* nb bits to use for the tag */ #define ZSTD_ROW_HASH_TAG_MASK ((1u << ZSTD_ROW_HASH_TAG_BITS) - 1) #define ZSTD_ROW_HASH_MAX_ENTRIES 64 /* absolute maximum number of entries per row, for all configurations */ @@ -769,64 +788,19 @@ typedef U64 ZSTD_VecMask; /* Clarifies when we are interacting with a U64 repr * Starting from the LSB, returns the idx of the next non-zero bit. * Basically counting the nb of trailing zeroes. */ -static U32 ZSTD_VecMask_next(ZSTD_VecMask val) { - assert(val != 0); -# if (defined(__GNUC__) && ((__GNUC__ > 3) || ((__GNUC__ == 3) && (__GNUC_MINOR__ >= 4)))) - if (sizeof(size_t) == 4) { - U32 mostSignificantWord = (U32)(val >> 32); - U32 leastSignificantWord = (U32)val; - if (leastSignificantWord == 0) { - return 32 + (U32)__builtin_ctz(mostSignificantWord); - } else { - return (U32)__builtin_ctz(leastSignificantWord); - } - } else { - return (U32)__builtin_ctzll(val); - } -# else - /* Software ctz version: http://aggregate.org/MAGIC/#Trailing%20Zero%20Count - * and: https://stackoverflow.com/questions/2709430/count-number-of-bits-in-a-64-bit-long-big-integer - */ - val = ~val & (val - 1ULL); /* Lowest set bit mask */ - val = val - ((val >> 1) & 0x5555555555555555); - val = (val & 0x3333333333333333ULL) + ((val >> 2) & 0x3333333333333333ULL); - return (U32)((((val + (val >> 4)) & 0xF0F0F0F0F0F0F0FULL) * 0x101010101010101ULL) >> 56); -# endif -} - -/* ZSTD_rotateRight_*(): - * Rotates a bitfield to the right by "count" bits. - * https://en.wikipedia.org/w/index.php?title=Circular_shift&oldid=991635599#Implementing_circular_shifts - */ -FORCE_INLINE_TEMPLATE -U64 ZSTD_rotateRight_U64(U64 const value, U32 count) { - assert(count < 64); - count &= 0x3F; /* for fickle pattern recognition */ - return (value >> count) | (U64)(value << ((0U - count) & 0x3F)); -} - -FORCE_INLINE_TEMPLATE -U32 ZSTD_rotateRight_U32(U32 const value, U32 count) { - assert(count < 32); - count &= 0x1F; /* for fickle pattern recognition */ - return (value >> count) | (U32)(value << ((0U - count) & 0x1F)); -} - -FORCE_INLINE_TEMPLATE -U16 ZSTD_rotateRight_U16(U16 const value, U32 count) { - assert(count < 16); - count &= 0x0F; /* for fickle pattern recognition */ - return (value >> count) | (U16)(value << ((0U - count) & 0x0F)); +MEM_STATIC U32 ZSTD_VecMask_next(ZSTD_VecMask val) { + return ZSTD_countTrailingZeros64(val); } /* ZSTD_row_nextIndex(): * Returns the next index to insert at within a tagTable row, and updates the "head" - * value to reflect the update. Essentially cycles backwards from [0, {entries per row}) + * value to reflect the update. Essentially cycles backwards from [1, {entries per row}) */ FORCE_INLINE_TEMPLATE U32 ZSTD_row_nextIndex(BYTE* const tagRow, U32 const rowMask) { - U32 const next = (*tagRow - 1) & rowMask; - *tagRow = (BYTE)next; - return next; + U32 next = (*tagRow-1) & rowMask; + next += (next == 0) ? rowMask : 0; /* skip first position */ + *tagRow = (BYTE)next; + return next; } /* ZSTD_isAligned(): @@ -840,7 +814,7 @@ MEM_STATIC int ZSTD_isAligned(void const* ptr, size_t align) { /* ZSTD_row_prefetch(): * Performs prefetching for the hashTable and tagTable at a given row. */ -FORCE_INLINE_TEMPLATE void ZSTD_row_prefetch(U32 const* hashTable, U16 const* tagTable, U32 const relRow, U32 const rowLog) { +FORCE_INLINE_TEMPLATE void ZSTD_row_prefetch(U32 const* hashTable, BYTE const* tagTable, U32 const relRow, U32 const rowLog) { PREFETCH_L1(hashTable + relRow); if (rowLog >= 5) { PREFETCH_L1(hashTable + relRow + 16); @@ -859,18 +833,20 @@ FORCE_INLINE_TEMPLATE void ZSTD_row_prefetch(U32 const* hashTable, U16 const* ta * Fill up the hash cache starting at idx, prefetching up to ZSTD_ROW_HASH_CACHE_SIZE entries, * but not beyond iLimit. */ -FORCE_INLINE_TEMPLATE void ZSTD_row_fillHashCache(ZSTD_matchState_t* ms, const BYTE* base, +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_row_fillHashCache(ZSTD_matchState_t* ms, const BYTE* base, U32 const rowLog, U32 const mls, U32 idx, const BYTE* const iLimit) { U32 const* const hashTable = ms->hashTable; - U16 const* const tagTable = ms->tagTable; + BYTE const* const tagTable = ms->tagTable; U32 const hashLog = ms->rowHashLog; U32 const maxElemsToPrefetch = (base + idx) > iLimit ? 0 : (U32)(iLimit - (base + idx) + 1); U32 const lim = idx + MIN(ZSTD_ROW_HASH_CACHE_SIZE, maxElemsToPrefetch); for (; idx < lim; ++idx) { - U32 const hash = (U32)ZSTD_hashPtr(base + idx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls); + U32 const hash = (U32)ZSTD_hashPtrSalted(base + idx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, ms->hashSalt); U32 const row = (hash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; ZSTD_row_prefetch(hashTable, tagTable, row, rowLog); ms->hashCache[idx & ZSTD_ROW_HASH_CACHE_MASK] = hash; @@ -885,12 +861,15 @@ FORCE_INLINE_TEMPLATE void ZSTD_row_fillHashCache(ZSTD_matchState_t* ms, const B * Returns the hash of base + idx, and replaces the hash in the hash cache with the byte at * base + idx + ZSTD_ROW_HASH_CACHE_SIZE. Also prefetches the appropriate rows from hashTable and tagTable. */ -FORCE_INLINE_TEMPLATE U32 ZSTD_row_nextCachedHash(U32* cache, U32 const* hashTable, - U16 const* tagTable, BYTE const* base, +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_row_nextCachedHash(U32* cache, U32 const* hashTable, + BYTE const* tagTable, BYTE const* base, U32 idx, U32 const hashLog, - U32 const rowLog, U32 const mls) + U32 const rowLog, U32 const mls, + U64 const hashSalt) { - U32 const newHash = (U32)ZSTD_hashPtr(base+idx+ZSTD_ROW_HASH_CACHE_SIZE, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls); + U32 const newHash = (U32)ZSTD_hashPtrSalted(base+idx+ZSTD_ROW_HASH_CACHE_SIZE, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, hashSalt); U32 const row = (newHash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; ZSTD_row_prefetch(hashTable, tagTable, row, rowLog); { U32 const hash = cache[idx & ZSTD_ROW_HASH_CACHE_MASK]; @@ -902,28 +881,29 @@ FORCE_INLINE_TEMPLATE U32 ZSTD_row_nextCachedHash(U32* cache, U32 const* hashTab /* ZSTD_row_update_internalImpl(): * Updates the hash table with positions starting from updateStartIdx until updateEndIdx. */ -FORCE_INLINE_TEMPLATE void ZSTD_row_update_internalImpl(ZSTD_matchState_t* ms, - U32 updateStartIdx, U32 const updateEndIdx, - U32 const mls, U32 const rowLog, - U32 const rowMask, U32 const useCache) +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_row_update_internalImpl(ZSTD_matchState_t* ms, + U32 updateStartIdx, U32 const updateEndIdx, + U32 const mls, U32 const rowLog, + U32 const rowMask, U32 const useCache) { U32* const hashTable = ms->hashTable; - U16* const tagTable = ms->tagTable; + BYTE* const tagTable = ms->tagTable; U32 const hashLog = ms->rowHashLog; const BYTE* const base = ms->window.base; DEBUGLOG(6, "ZSTD_row_update_internalImpl(): updateStartIdx=%u, updateEndIdx=%u", updateStartIdx, updateEndIdx); for (; updateStartIdx < updateEndIdx; ++updateStartIdx) { - U32 const hash = useCache ? ZSTD_row_nextCachedHash(ms->hashCache, hashTable, tagTable, base, updateStartIdx, hashLog, rowLog, mls) - : (U32)ZSTD_hashPtr(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls); + U32 const hash = useCache ? ZSTD_row_nextCachedHash(ms->hashCache, hashTable, tagTable, base, updateStartIdx, hashLog, rowLog, mls, ms->hashSalt) + : (U32)ZSTD_hashPtrSalted(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, ms->hashSalt); U32 const relRow = (hash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; U32* const row = hashTable + relRow; - BYTE* tagRow = (BYTE*)(tagTable + relRow); /* Though tagTable is laid out as a table of U16, each tag is only 1 byte. - Explicit cast allows us to get exact desired position within each row */ + BYTE* tagRow = tagTable + relRow; U32 const pos = ZSTD_row_nextIndex(tagRow, rowMask); - assert(hash == ZSTD_hashPtr(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls)); - ((BYTE*)tagRow)[pos + ZSTD_ROW_HASH_TAG_OFFSET] = hash & ZSTD_ROW_HASH_TAG_MASK; + assert(hash == ZSTD_hashPtrSalted(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, ms->hashSalt)); + tagRow[pos] = hash & ZSTD_ROW_HASH_TAG_MASK; row[pos] = updateStartIdx; } } @@ -932,9 +912,11 @@ FORCE_INLINE_TEMPLATE void ZSTD_row_update_internalImpl(ZSTD_matchState_t* ms, * Inserts the byte at ip into the appropriate position in the hash table, and updates ms->nextToUpdate. * Skips sections of long matches as is necessary. */ -FORCE_INLINE_TEMPLATE void ZSTD_row_update_internal(ZSTD_matchState_t* ms, const BYTE* ip, - U32 const mls, U32 const rowLog, - U32 const rowMask, U32 const useCache) +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_row_update_internal(ZSTD_matchState_t* ms, const BYTE* ip, + U32 const mls, U32 const rowLog, + U32 const rowMask, U32 const useCache) { U32 idx = ms->nextToUpdate; const BYTE* const base = ms->window.base; @@ -971,7 +953,35 @@ void ZSTD_row_update(ZSTD_matchState_t* const ms, const BYTE* ip) { const U32 mls = MIN(ms->cParams.minMatch, 6 /* mls caps out at 6 */); DEBUGLOG(5, "ZSTD_row_update(), rowLog=%u", rowLog); - ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 0 /* dont use cache */); + ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 0 /* don't use cache */); +} + +/* Returns the mask width of bits group of which will be set to 1. Given not all + * architectures have easy movemask instruction, this helps to iterate over + * groups of bits easier and faster. + */ +FORCE_INLINE_TEMPLATE U32 +ZSTD_row_matchMaskGroupWidth(const U32 rowEntries) +{ + assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64); + assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES); + (void)rowEntries; +#if defined(ZSTD_ARCH_ARM_NEON) + /* NEON path only works for little endian */ + if (!MEM_isLittleEndian()) { + return 1; + } + if (rowEntries == 16) { + return 4; + } + if (rowEntries == 32) { + return 2; + } + if (rowEntries == 64) { + return 1; + } +#endif + return 1; } #if defined(ZSTD_ARCH_X86_SSE2) @@ -994,71 +1004,82 @@ ZSTD_row_getSSEMask(int nbChunks, const BYTE* const src, const BYTE tag, const U } #endif -/* Returns a ZSTD_VecMask (U32) that has the nth bit set to 1 if the newly-computed "tag" matches - * the hash at the nth position in a row of the tagTable. - * Each row is a circular buffer beginning at the value of "head". So we must rotate the "matches" bitfield - * to match up with the actual layout of the entries within the hashTable */ +#if defined(ZSTD_ARCH_ARM_NEON) +FORCE_INLINE_TEMPLATE ZSTD_VecMask +ZSTD_row_getNEONMask(const U32 rowEntries, const BYTE* const src, const BYTE tag, const U32 headGrouped) +{ + assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64); + if (rowEntries == 16) { + /* vshrn_n_u16 shifts by 4 every u16 and narrows to 8 lower bits. + * After that groups of 4 bits represent the equalMask. We lower + * all bits except the highest in these groups by doing AND with + * 0x88 = 0b10001000. + */ + const uint8x16_t chunk = vld1q_u8(src); + const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag))); + const uint8x8_t res = vshrn_n_u16(equalMask, 4); + const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0); + return ZSTD_rotateRight_U64(matches, headGrouped) & 0x8888888888888888ull; + } else if (rowEntries == 32) { + /* Same idea as with rowEntries == 16 but doing AND with + * 0x55 = 0b01010101. + */ + const uint16x8x2_t chunk = vld2q_u16((const uint16_t*)(const void*)src); + const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]); + const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]); + const uint8x16_t dup = vdupq_n_u8(tag); + const uint8x8_t t0 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk0, dup)), 6); + const uint8x8_t t1 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk1, dup)), 6); + const uint8x8_t res = vsli_n_u8(t0, t1, 4); + const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0) ; + return ZSTD_rotateRight_U64(matches, headGrouped) & 0x5555555555555555ull; + } else { /* rowEntries == 64 */ + const uint8x16x4_t chunk = vld4q_u8(src); + const uint8x16_t dup = vdupq_n_u8(tag); + const uint8x16_t cmp0 = vceqq_u8(chunk.val[0], dup); + const uint8x16_t cmp1 = vceqq_u8(chunk.val[1], dup); + const uint8x16_t cmp2 = vceqq_u8(chunk.val[2], dup); + const uint8x16_t cmp3 = vceqq_u8(chunk.val[3], dup); + + const uint8x16_t t0 = vsriq_n_u8(cmp1, cmp0, 1); + const uint8x16_t t1 = vsriq_n_u8(cmp3, cmp2, 1); + const uint8x16_t t2 = vsriq_n_u8(t1, t0, 2); + const uint8x16_t t3 = vsriq_n_u8(t2, t2, 4); + const uint8x8_t t4 = vshrn_n_u16(vreinterpretq_u16_u8(t3), 4); + const U64 matches = vget_lane_u64(vreinterpret_u64_u8(t4), 0); + return ZSTD_rotateRight_U64(matches, headGrouped); + } +} +#endif + +/* Returns a ZSTD_VecMask (U64) that has the nth group (determined by + * ZSTD_row_matchMaskGroupWidth) of bits set to 1 if the newly-computed "tag" + * matches the hash at the nth position in a row of the tagTable. + * Each row is a circular buffer beginning at the value of "headGrouped". So we + * must rotate the "matches" bitfield to match up with the actual layout of the + * entries within the hashTable */ FORCE_INLINE_TEMPLATE ZSTD_VecMask -ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, const U32 rowEntries) +ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 headGrouped, const U32 rowEntries) { - const BYTE* const src = tagRow + ZSTD_ROW_HASH_TAG_OFFSET; + const BYTE* const src = tagRow; assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64); assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES); + assert(ZSTD_row_matchMaskGroupWidth(rowEntries) * rowEntries <= sizeof(ZSTD_VecMask) * 8); #if defined(ZSTD_ARCH_X86_SSE2) - return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, head); + return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, headGrouped); #else /* SW or NEON-LE */ # if defined(ZSTD_ARCH_ARM_NEON) /* This NEON path only works for little endian - otherwise use SWAR below */ if (MEM_isLittleEndian()) { - if (rowEntries == 16) { - const uint8x16_t chunk = vld1q_u8(src); - const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag))); - const uint16x8_t t0 = vshlq_n_u16(equalMask, 7); - const uint32x4_t t1 = vreinterpretq_u32_u16(vsriq_n_u16(t0, t0, 14)); - const uint64x2_t t2 = vreinterpretq_u64_u32(vshrq_n_u32(t1, 14)); - const uint8x16_t t3 = vreinterpretq_u8_u64(vsraq_n_u64(t2, t2, 28)); - const U16 hi = (U16)vgetq_lane_u8(t3, 8); - const U16 lo = (U16)vgetq_lane_u8(t3, 0); - return ZSTD_rotateRight_U16((hi << 8) | lo, head); - } else if (rowEntries == 32) { - const uint16x8x2_t chunk = vld2q_u16((const U16*)(const void*)src); - const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]); - const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]); - const uint8x16_t equalMask0 = vceqq_u8(chunk0, vdupq_n_u8(tag)); - const uint8x16_t equalMask1 = vceqq_u8(chunk1, vdupq_n_u8(tag)); - const int8x8_t pack0 = vqmovn_s16(vreinterpretq_s16_u8(equalMask0)); - const int8x8_t pack1 = vqmovn_s16(vreinterpretq_s16_u8(equalMask1)); - const uint8x8_t t0 = vreinterpret_u8_s8(pack0); - const uint8x8_t t1 = vreinterpret_u8_s8(pack1); - const uint8x8_t t2 = vsri_n_u8(t1, t0, 2); - const uint8x8x2_t t3 = vuzp_u8(t2, t0); - const uint8x8_t t4 = vsri_n_u8(t3.val[1], t3.val[0], 4); - const U32 matches = vget_lane_u32(vreinterpret_u32_u8(t4), 0); - return ZSTD_rotateRight_U32(matches, head); - } else { /* rowEntries == 64 */ - const uint8x16x4_t chunk = vld4q_u8(src); - const uint8x16_t dup = vdupq_n_u8(tag); - const uint8x16_t cmp0 = vceqq_u8(chunk.val[0], dup); - const uint8x16_t cmp1 = vceqq_u8(chunk.val[1], dup); - const uint8x16_t cmp2 = vceqq_u8(chunk.val[2], dup); - const uint8x16_t cmp3 = vceqq_u8(chunk.val[3], dup); - - const uint8x16_t t0 = vsriq_n_u8(cmp1, cmp0, 1); - const uint8x16_t t1 = vsriq_n_u8(cmp3, cmp2, 1); - const uint8x16_t t2 = vsriq_n_u8(t1, t0, 2); - const uint8x16_t t3 = vsriq_n_u8(t2, t2, 4); - const uint8x8_t t4 = vshrn_n_u16(vreinterpretq_u16_u8(t3), 4); - const U64 matches = vget_lane_u64(vreinterpret_u64_u8(t4), 0); - return ZSTD_rotateRight_U64(matches, head); - } + return ZSTD_row_getNEONMask(rowEntries, src, tag, headGrouped); } # endif /* ZSTD_ARCH_ARM_NEON */ /* SWAR */ - { const size_t chunkSize = sizeof(size_t); + { const int chunkSize = sizeof(size_t); const size_t shiftAmount = ((chunkSize * 8) - chunkSize); const size_t xFF = ~((size_t)0); const size_t x01 = xFF / 0xFF; @@ -1091,11 +1112,11 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, } matches = ~matches; if (rowEntries == 16) { - return ZSTD_rotateRight_U16((U16)matches, head); + return ZSTD_rotateRight_U16((U16)matches, headGrouped); } else if (rowEntries == 32) { - return ZSTD_rotateRight_U32((U32)matches, head); + return ZSTD_rotateRight_U32((U32)matches, headGrouped); } else { - return ZSTD_rotateRight_U64((U64)matches, head); + return ZSTD_rotateRight_U64((U64)matches, headGrouped); } } #endif @@ -1103,20 +1124,21 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, /* The high-level approach of the SIMD row based match finder is as follows: * - Figure out where to insert the new entry: - * - Generate a hash from a byte along with an additional 1-byte "short hash". The additional byte is our "tag" - * - The hashTable is effectively split into groups or "rows" of 16 or 32 entries of U32, and the hash determines + * - Generate a hash for current input posistion and split it into a one byte of tag and `rowHashLog` bits of index. + * - The hash is salted by a value that changes on every contex reset, so when the same table is used + * we will avoid collisions that would otherwise slow us down by intorducing phantom matches. + * - The hashTable is effectively split into groups or "rows" of 15 or 31 entries of U32, and the index determines * which row to insert into. - * - Determine the correct position within the row to insert the entry into. Each row of 16 or 32 can - * be considered as a circular buffer with a "head" index that resides in the tagTable. - * - Also insert the "tag" into the equivalent row and position in the tagTable. - * - Note: The tagTable has 17 or 33 1-byte entries per row, due to 16 or 32 tags, and 1 "head" entry. - * The 17 or 33 entry rows are spaced out to occur every 32 or 64 bytes, respectively, - * for alignment/performance reasons, leaving some bytes unused. - * - Use SIMD to efficiently compare the tags in the tagTable to the 1-byte "short hash" and + * - Determine the correct position within the row to insert the entry into. Each row of 15 or 31 can + * be considered as a circular buffer with a "head" index that resides in the tagTable (overall 16 or 32 bytes + * per row). + * - Use SIMD to efficiently compare the tags in the tagTable to the 1-byte tag calculated for the position and * generate a bitfield that we can cycle through to check the collisions in the hash table. * - Pick the longest match. + * - Insert the tag into the equivalent row and position in the tagTable. */ FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_RowFindBestMatch( ZSTD_matchState_t* ms, const BYTE* const ip, const BYTE* const iLimit, @@ -1125,7 +1147,7 @@ size_t ZSTD_RowFindBestMatch( const U32 rowLog) { U32* const hashTable = ms->hashTable; - U16* const tagTable = ms->tagTable; + BYTE* const tagTable = ms->tagTable; U32* const hashCache = ms->hashCache; const U32 hashLog = ms->rowHashLog; const ZSTD_compressionParameters* const cParams = &ms->cParams; @@ -1143,8 +1165,11 @@ size_t ZSTD_RowFindBestMatch( const U32 rowEntries = (1U << rowLog); const U32 rowMask = rowEntries - 1; const U32 cappedSearchLog = MIN(cParams->searchLog, rowLog); /* nb of searches is capped at nb entries per row */ + const U32 groupWidth = ZSTD_row_matchMaskGroupWidth(rowEntries); + const U64 hashSalt = ms->hashSalt; U32 nbAttempts = 1U << cappedSearchLog; size_t ml=4-1; + U32 hash; /* DMS/DDS variables that may be referenced laster */ const ZSTD_matchState_t* const dms = ms->dictMatchState; @@ -1168,7 +1193,7 @@ size_t ZSTD_RowFindBestMatch( if (dictMode == ZSTD_dictMatchState) { /* Prefetch DMS rows */ U32* const dmsHashTable = dms->hashTable; - U16* const dmsTagTable = dms->tagTable; + BYTE* const dmsTagTable = dms->tagTable; U32 const dmsHash = (U32)ZSTD_hashPtr(ip, dms->rowHashLog + ZSTD_ROW_HASH_TAG_BITS, mls); U32 const dmsRelRow = (dmsHash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; dmsTag = dmsHash & ZSTD_ROW_HASH_TAG_MASK; @@ -1178,23 +1203,34 @@ size_t ZSTD_RowFindBestMatch( } /* Update the hashTable and tagTable up to (but not including) ip */ - ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 1 /* useCache */); + if (!ms->lazySkipping) { + ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 1 /* useCache */); + hash = ZSTD_row_nextCachedHash(hashCache, hashTable, tagTable, base, curr, hashLog, rowLog, mls, hashSalt); + } else { + /* Stop inserting every position when in the lazy skipping mode. + * The hash cache is also not kept up to date in this mode. + */ + hash = (U32)ZSTD_hashPtrSalted(ip, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, hashSalt); + ms->nextToUpdate = curr; + } + ms->hashSaltEntropy += hash; /* collect salt entropy */ + { /* Get the hash for ip, compute the appropriate row */ - U32 const hash = ZSTD_row_nextCachedHash(hashCache, hashTable, tagTable, base, curr, hashLog, rowLog, mls); U32 const relRow = (hash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog; U32 const tag = hash & ZSTD_ROW_HASH_TAG_MASK; U32* const row = hashTable + relRow; BYTE* tagRow = (BYTE*)(tagTable + relRow); - U32 const head = *tagRow & rowMask; + U32 const headGrouped = (*tagRow & rowMask) * groupWidth; U32 matchBuffer[ZSTD_ROW_HASH_MAX_ENTRIES]; size_t numMatches = 0; size_t currMatch = 0; - ZSTD_VecMask matches = ZSTD_row_getMatchMask(tagRow, (BYTE)tag, head, rowEntries); + ZSTD_VecMask matches = ZSTD_row_getMatchMask(tagRow, (BYTE)tag, headGrouped, rowEntries); /* Cycle through the matches and prefetch */ - for (; (matches > 0) && (nbAttempts > 0); --nbAttempts, matches &= (matches - 1)) { - U32 const matchPos = (head + ZSTD_VecMask_next(matches)) & rowMask; + for (; (matches > 0) && (nbAttempts > 0); matches &= (matches - 1)) { + U32 const matchPos = ((headGrouped + ZSTD_VecMask_next(matches)) / groupWidth) & rowMask; U32 const matchIndex = row[matchPos]; + if(matchPos == 0) continue; assert(numMatches < rowEntries); if (matchIndex < lowLimit) break; @@ -1204,13 +1240,14 @@ size_t ZSTD_RowFindBestMatch( PREFETCH_L1(dictBase + matchIndex); } matchBuffer[numMatches++] = matchIndex; + --nbAttempts; } /* Speed opt: insert current byte into hashtable too. This allows us to avoid one iteration of the loop in ZSTD_row_update_internal() at the next search. */ { U32 const pos = ZSTD_row_nextIndex(tagRow, rowMask); - tagRow[pos + ZSTD_ROW_HASH_TAG_OFFSET] = (BYTE)tag; + tagRow[pos] = (BYTE)tag; row[pos] = ms->nextToUpdate++; } @@ -1224,7 +1261,8 @@ size_t ZSTD_RowFindBestMatch( if ((dictMode != ZSTD_extDict) || matchIndex >= dictLimit) { const BYTE* const match = base + matchIndex; assert(matchIndex >= dictLimit); /* ensures this is true if dictMode != ZSTD_extDict */ - if (match[ml] == ip[ml]) /* potentially better */ + /* read 4B starting from (match + ml + 1 - sizeof(U32)) */ + if (MEM_read32(match + ml - 3) == MEM_read32(ip + ml - 3)) /* potentially better */ currentMl = ZSTD_count(ip, match, iLimit); } else { const BYTE* const match = dictBase + matchIndex; @@ -1236,7 +1274,7 @@ size_t ZSTD_RowFindBestMatch( /* Save best solution */ if (currentMl > ml) { ml = currentMl; - *offsetPtr = STORE_OFFSET(curr - matchIndex); + *offsetPtr = OFFSET_TO_OFFBASE(curr - matchIndex); if (ip+currentMl == iLimit) break; /* best possible, avoids read overflow on next attempt */ } } @@ -1254,19 +1292,21 @@ size_t ZSTD_RowFindBestMatch( const U32 dmsSize = (U32)(dmsEnd - dmsBase); const U32 dmsIndexDelta = dictLimit - dmsSize; - { U32 const head = *dmsTagRow & rowMask; + { U32 const headGrouped = (*dmsTagRow & rowMask) * groupWidth; U32 matchBuffer[ZSTD_ROW_HASH_MAX_ENTRIES]; size_t numMatches = 0; size_t currMatch = 0; - ZSTD_VecMask matches = ZSTD_row_getMatchMask(dmsTagRow, (BYTE)dmsTag, head, rowEntries); + ZSTD_VecMask matches = ZSTD_row_getMatchMask(dmsTagRow, (BYTE)dmsTag, headGrouped, rowEntries); - for (; (matches > 0) && (nbAttempts > 0); --nbAttempts, matches &= (matches - 1)) { - U32 const matchPos = (head + ZSTD_VecMask_next(matches)) & rowMask; + for (; (matches > 0) && (nbAttempts > 0); matches &= (matches - 1)) { + U32 const matchPos = ((headGrouped + ZSTD_VecMask_next(matches)) / groupWidth) & rowMask; U32 const matchIndex = dmsRow[matchPos]; + if(matchPos == 0) continue; if (matchIndex < dmsLowestIndex) break; PREFETCH_L1(dmsBase + matchIndex); matchBuffer[numMatches++] = matchIndex; + --nbAttempts; } /* Return the longest match */ @@ -1285,7 +1325,7 @@ size_t ZSTD_RowFindBestMatch( if (currentMl > ml) { ml = currentMl; assert(curr > matchIndex + dmsIndexDelta); - *offsetPtr = STORE_OFFSET(curr - (matchIndex + dmsIndexDelta)); + *offsetPtr = OFFSET_TO_OFFBASE(curr - (matchIndex + dmsIndexDelta)); if (ip+currentMl == iLimit) break; } } @@ -1472,8 +1512,9 @@ FORCE_INLINE_TEMPLATE size_t ZSTD_searchMax( * Common parser - lazy strategy *********************************/ -FORCE_INLINE_TEMPLATE size_t -ZSTD_compressBlock_lazy_generic( +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_compressBlock_lazy_generic( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize, @@ -1491,7 +1532,8 @@ ZSTD_compressBlock_lazy_generic( const U32 mls = BOUNDED(4, ms->cParams.minMatch, 6); const U32 rowLog = BOUNDED(4, ms->cParams.searchLog, 6); - U32 offset_1 = rep[0], offset_2 = rep[1], savedOffset=0; + U32 offset_1 = rep[0], offset_2 = rep[1]; + U32 offsetSaved1 = 0, offsetSaved2 = 0; const int isDMS = dictMode == ZSTD_dictMatchState; const int isDDS = dictMode == ZSTD_dedicatedDictSearch; @@ -1512,8 +1554,8 @@ ZSTD_compressBlock_lazy_generic( U32 const curr = (U32)(ip - base); U32 const windowLow = ZSTD_getLowestPrefixIndex(ms, curr, ms->cParams.windowLog); U32 const maxRep = curr - windowLow; - if (offset_2 > maxRep) savedOffset = offset_2, offset_2 = 0; - if (offset_1 > maxRep) savedOffset = offset_1, offset_1 = 0; + if (offset_2 > maxRep) offsetSaved2 = offset_2, offset_2 = 0; + if (offset_1 > maxRep) offsetSaved1 = offset_1, offset_1 = 0; } if (isDxS) { /* dictMatchState repCode checks don't currently handle repCode == 0 @@ -1522,10 +1564,11 @@ ZSTD_compressBlock_lazy_generic( assert(offset_2 <= dictAndPrefixLength); } + /* Reset the lazy skipping state */ + ms->lazySkipping = 0; + if (searchMethod == search_rowHash) { - ZSTD_row_fillHashCache(ms, base, rowLog, - MIN(ms->cParams.minMatch, 6 /* mls caps out at 6 */), - ms->nextToUpdate, ilimit); + ZSTD_row_fillHashCache(ms, base, rowLog, mls, ms->nextToUpdate, ilimit); } /* Match Loop */ @@ -1537,7 +1580,7 @@ ZSTD_compressBlock_lazy_generic( #endif while (ip < ilimit) { size_t matchLength=0; - size_t offcode=STORE_REPCODE_1; + size_t offBase = REPCODE1_TO_OFFBASE; const BYTE* start=ip+1; DEBUGLOG(7, "search baseline (depth 0)"); @@ -1562,14 +1605,23 @@ ZSTD_compressBlock_lazy_generic( } /* first search (depth 0) */ - { size_t offsetFound = 999999999; - size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offsetFound, mls, rowLog, searchMethod, dictMode); + { size_t offbaseFound = 999999999; + size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offbaseFound, mls, rowLog, searchMethod, dictMode); if (ml2 > matchLength) - matchLength = ml2, start = ip, offcode=offsetFound; + matchLength = ml2, start = ip, offBase = offbaseFound; } if (matchLength < 4) { - ip += ((ip-anchor) >> kSearchStrength) + 1; /* jump faster over incompressible sections */ + size_t const step = ((size_t)(ip-anchor) >> kSearchStrength) + 1; /* jump faster over incompressible sections */; + ip += step; + /* Enter the lazy skipping mode once we are skipping more than 8 bytes at a time. + * In this mode we stop inserting every position into our tables, and only insert + * positions that we search, which is one in step positions. + * The exact cutoff is flexible, I've just chosen a number that is reasonably high, + * so we minimize the compression ratio loss in "normal" scenarios. This mode gets + * triggered once we've gone 2KB without finding any matches. + */ + ms->lazySkipping = step > kLazySkippingStep; continue; } @@ -1579,12 +1631,12 @@ ZSTD_compressBlock_lazy_generic( DEBUGLOG(7, "search depth 1"); ip ++; if ( (dictMode == ZSTD_noDict) - && (offcode) && ((offset_1>0) & (MEM_read32(ip) == MEM_read32(ip - offset_1)))) { + && (offBase) && ((offset_1>0) & (MEM_read32(ip) == MEM_read32(ip - offset_1)))) { size_t const mlRep = ZSTD_count(ip+4, ip+4-offset_1, iend) + 4; int const gain2 = (int)(mlRep * 3); - int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); + int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)offBase) + 1); if ((mlRep >= 4) && (gain2 > gain1)) - matchLength = mlRep, offcode = STORE_REPCODE_1, start = ip; + matchLength = mlRep, offBase = REPCODE1_TO_OFFBASE, start = ip; } if (isDxS) { const U32 repIndex = (U32)(ip - base) - offset_1; @@ -1596,17 +1648,17 @@ ZSTD_compressBlock_lazy_generic( const BYTE* repMatchEnd = repIndex < prefixLowestIndex ? dictEnd : iend; size_t const mlRep = ZSTD_count_2segments(ip+4, repMatch+4, iend, repMatchEnd, prefixLowest) + 4; int const gain2 = (int)(mlRep * 3); - int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); + int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)offBase) + 1); if ((mlRep >= 4) && (gain2 > gain1)) - matchLength = mlRep, offcode = STORE_REPCODE_1, start = ip; + matchLength = mlRep, offBase = REPCODE1_TO_OFFBASE, start = ip; } } - { size_t offset2=999999999; - size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offset2, mls, rowLog, searchMethod, dictMode); - int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offset2))); /* raw approx */ - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 4); + { size_t ofbCandidate=999999999; + size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, dictMode); + int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)ofbCandidate)); /* raw approx */ + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 4); if ((ml2 >= 4) && (gain2 > gain1)) { - matchLength = ml2, offcode = offset2, start = ip; + matchLength = ml2, offBase = ofbCandidate, start = ip; continue; /* search a better one */ } } @@ -1615,12 +1667,12 @@ ZSTD_compressBlock_lazy_generic( DEBUGLOG(7, "search depth 2"); ip ++; if ( (dictMode == ZSTD_noDict) - && (offcode) && ((offset_1>0) & (MEM_read32(ip) == MEM_read32(ip - offset_1)))) { + && (offBase) && ((offset_1>0) & (MEM_read32(ip) == MEM_read32(ip - offset_1)))) { size_t const mlRep = ZSTD_count(ip+4, ip+4-offset_1, iend) + 4; int const gain2 = (int)(mlRep * 4); - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 1); if ((mlRep >= 4) && (gain2 > gain1)) - matchLength = mlRep, offcode = STORE_REPCODE_1, start = ip; + matchLength = mlRep, offBase = REPCODE1_TO_OFFBASE, start = ip; } if (isDxS) { const U32 repIndex = (U32)(ip - base) - offset_1; @@ -1632,17 +1684,17 @@ ZSTD_compressBlock_lazy_generic( const BYTE* repMatchEnd = repIndex < prefixLowestIndex ? dictEnd : iend; size_t const mlRep = ZSTD_count_2segments(ip+4, repMatch+4, iend, repMatchEnd, prefixLowest) + 4; int const gain2 = (int)(mlRep * 4); - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 1); if ((mlRep >= 4) && (gain2 > gain1)) - matchLength = mlRep, offcode = STORE_REPCODE_1, start = ip; + matchLength = mlRep, offBase = REPCODE1_TO_OFFBASE, start = ip; } } - { size_t offset2=999999999; - size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offset2, mls, rowLog, searchMethod, dictMode); - int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offset2))); /* raw approx */ - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 7); + { size_t ofbCandidate=999999999; + size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, dictMode); + int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)ofbCandidate)); /* raw approx */ + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 7); if ((ml2 >= 4) && (gain2 > gain1)) { - matchLength = ml2, offcode = offset2, start = ip; + matchLength = ml2, offBase = ofbCandidate, start = ip; continue; } } } break; /* nothing found : store previous solution */ @@ -1653,26 +1705,33 @@ ZSTD_compressBlock_lazy_generic( * notably if `value` is unsigned, resulting in a large positive `-value`. */ /* catch up */ - if (STORED_IS_OFFSET(offcode)) { + if (OFFBASE_IS_OFFSET(offBase)) { if (dictMode == ZSTD_noDict) { - while ( ((start > anchor) & (start - STORED_OFFSET(offcode) > prefixLowest)) - && (start[-1] == (start-STORED_OFFSET(offcode))[-1]) ) /* only search for offset within prefix */ + while ( ((start > anchor) & (start - OFFBASE_TO_OFFSET(offBase) > prefixLowest)) + && (start[-1] == (start-OFFBASE_TO_OFFSET(offBase))[-1]) ) /* only search for offset within prefix */ { start--; matchLength++; } } if (isDxS) { - U32 const matchIndex = (U32)((size_t)(start-base) - STORED_OFFSET(offcode)); + U32 const matchIndex = (U32)((size_t)(start-base) - OFFBASE_TO_OFFSET(offBase)); const BYTE* match = (matchIndex < prefixLowestIndex) ? dictBase + matchIndex - dictIndexDelta : base + matchIndex; const BYTE* const mStart = (matchIndex < prefixLowestIndex) ? dictLowest : prefixLowest; while ((start>anchor) && (match>mStart) && (start[-1] == match[-1])) { start--; match--; matchLength++; } /* catch up */ } - offset_2 = offset_1; offset_1 = (U32)STORED_OFFSET(offcode); + offset_2 = offset_1; offset_1 = (U32)OFFBASE_TO_OFFSET(offBase); } /* store sequence */ _storeSequence: { size_t const litLength = (size_t)(start - anchor); - ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offcode, matchLength); + ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offBase, matchLength); anchor = ip = start + matchLength; } + if (ms->lazySkipping) { + /* We've found a match, disable lazy skipping mode, and refill the hash cache. */ + if (searchMethod == search_rowHash) { + ZSTD_row_fillHashCache(ms, base, rowLog, mls, ms->nextToUpdate, ilimit); + } + ms->lazySkipping = 0; + } /* check immediate repcode */ if (isDxS) { @@ -1686,8 +1745,8 @@ ZSTD_compressBlock_lazy_generic( && (MEM_read32(repMatch) == MEM_read32(ip)) ) { const BYTE* const repEnd2 = repIndex < prefixLowestIndex ? dictEnd : iend; matchLength = ZSTD_count_2segments(ip+4, repMatch+4, iend, repEnd2, prefixLowest) + 4; - offcode = offset_2; offset_2 = offset_1; offset_1 = (U32)offcode; /* swap offset_2 <=> offset_1 */ - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, matchLength); + offBase = offset_2; offset_2 = offset_1; offset_1 = (U32)offBase; /* swap offset_2 <=> offset_1 */ + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, matchLength); ip += matchLength; anchor = ip; continue; @@ -1701,166 +1760,181 @@ ZSTD_compressBlock_lazy_generic( && (MEM_read32(ip) == MEM_read32(ip - offset_2)) ) { /* store sequence */ matchLength = ZSTD_count(ip+4, ip+4-offset_2, iend) + 4; - offcode = offset_2; offset_2 = offset_1; offset_1 = (U32)offcode; /* swap repcodes */ - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, matchLength); + offBase = offset_2; offset_2 = offset_1; offset_1 = (U32)offBase; /* swap repcodes */ + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, matchLength); ip += matchLength; anchor = ip; continue; /* faster when present ... (?) */ } } } - /* Save reps for next block */ - rep[0] = offset_1 ? offset_1 : savedOffset; - rep[1] = offset_2 ? offset_2 : savedOffset; + /* If offset_1 started invalid (offsetSaved1 != 0) and became valid (offset_1 != 0), + * rotate saved offsets. See comment in ZSTD_compressBlock_fast_noDict for more context. */ + offsetSaved2 = ((offsetSaved1 != 0) && (offset_1 != 0)) ? offsetSaved1 : offsetSaved2; + + /* save reps for next block */ + rep[0] = offset_1 ? offset_1 : offsetSaved1; + rep[1] = offset_2 ? offset_2 : offsetSaved2; /* Return the last literals size */ return (size_t)(iend - anchor); } +#endif /* build exclusions */ -size_t ZSTD_compressBlock_btlazy2( +#ifndef ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_greedy( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_noDict); } -size_t ZSTD_compressBlock_lazy2( +size_t ZSTD_compressBlock_greedy_dictMatchState( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_dictMatchState); } -size_t ZSTD_compressBlock_lazy( +size_t ZSTD_compressBlock_greedy_dedicatedDictSearch( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_dedicatedDictSearch); } -size_t ZSTD_compressBlock_greedy( +size_t ZSTD_compressBlock_greedy_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_noDict); } -size_t ZSTD_compressBlock_btlazy2_dictMatchState( +size_t ZSTD_compressBlock_greedy_dictMatchState_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_dictMatchState); } -size_t ZSTD_compressBlock_lazy2_dictMatchState( +size_t ZSTD_compressBlock_greedy_dedicatedDictSearch_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_dedicatedDictSearch); } +#endif -size_t ZSTD_compressBlock_lazy_dictMatchState( +#ifndef ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_lazy( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_noDict); } -size_t ZSTD_compressBlock_greedy_dictMatchState( +size_t ZSTD_compressBlock_lazy_dictMatchState( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_dictMatchState); } - -size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch( +size_t ZSTD_compressBlock_lazy_dedicatedDictSearch( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_dedicatedDictSearch); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_dedicatedDictSearch); } -size_t ZSTD_compressBlock_lazy_dedicatedDictSearch( +size_t ZSTD_compressBlock_lazy_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1, ZSTD_dedicatedDictSearch); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_noDict); } -size_t ZSTD_compressBlock_greedy_dedicatedDictSearch( +size_t ZSTD_compressBlock_lazy_dictMatchState_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0, ZSTD_dedicatedDictSearch); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_dictMatchState); } -/* Row-based matchfinder */ -size_t ZSTD_compressBlock_lazy2_row( +size_t ZSTD_compressBlock_lazy_dedicatedDictSearch_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_dedicatedDictSearch); } +#endif -size_t ZSTD_compressBlock_lazy_row( +#ifndef ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_lazy2( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_noDict); } -size_t ZSTD_compressBlock_greedy_row( +size_t ZSTD_compressBlock_lazy2_dictMatchState( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_noDict); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_dictMatchState); } -size_t ZSTD_compressBlock_lazy2_dictMatchState_row( +size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2, ZSTD_dedicatedDictSearch); } -size_t ZSTD_compressBlock_lazy_dictMatchState_row( +size_t ZSTD_compressBlock_lazy2_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_noDict); } -size_t ZSTD_compressBlock_greedy_dictMatchState_row( +size_t ZSTD_compressBlock_lazy2_dictMatchState_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_dictMatchState); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_dictMatchState); } - size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2, ZSTD_dedicatedDictSearch); } +#endif -size_t ZSTD_compressBlock_lazy_dedicatedDictSearch_row( +#ifndef ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_btlazy2( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1, ZSTD_dedicatedDictSearch); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2, ZSTD_noDict); } -size_t ZSTD_compressBlock_greedy_dedicatedDictSearch_row( +size_t ZSTD_compressBlock_btlazy2_dictMatchState( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0, ZSTD_dedicatedDictSearch); + return ZSTD_compressBlock_lazy_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2, ZSTD_dictMatchState); } +#endif +#if !defined(ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_compressBlock_lazy_extDict_generic( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], @@ -1886,12 +1960,13 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( DEBUGLOG(5, "ZSTD_compressBlock_lazy_extDict_generic (searchFunc=%u)", (U32)searchMethod); + /* Reset the lazy skipping state */ + ms->lazySkipping = 0; + /* init */ ip += (ip == prefixStart); if (searchMethod == search_rowHash) { - ZSTD_row_fillHashCache(ms, base, rowLog, - MIN(ms->cParams.minMatch, 6 /* mls caps out at 6 */), - ms->nextToUpdate, ilimit); + ZSTD_row_fillHashCache(ms, base, rowLog, mls, ms->nextToUpdate, ilimit); } /* Match Loop */ @@ -1903,7 +1978,7 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( #endif while (ip < ilimit) { size_t matchLength=0; - size_t offcode=STORE_REPCODE_1; + size_t offBase = REPCODE1_TO_OFFBASE; const BYTE* start=ip+1; U32 curr = (U32)(ip-base); @@ -1922,14 +1997,23 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( } } /* first search (depth 0) */ - { size_t offsetFound = 999999999; - size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offsetFound, mls, rowLog, searchMethod, ZSTD_extDict); + { size_t ofbCandidate = 999999999; + size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, ZSTD_extDict); if (ml2 > matchLength) - matchLength = ml2, start = ip, offcode=offsetFound; + matchLength = ml2, start = ip, offBase = ofbCandidate; } if (matchLength < 4) { - ip += ((ip-anchor) >> kSearchStrength) + 1; /* jump faster over incompressible sections */ + size_t const step = ((size_t)(ip-anchor) >> kSearchStrength); + ip += step + 1; /* jump faster over incompressible sections */ + /* Enter the lazy skipping mode once we are skipping more than 8 bytes at a time. + * In this mode we stop inserting every position into our tables, and only insert + * positions that we search, which is one in step positions. + * The exact cutoff is flexible, I've just chosen a number that is reasonably high, + * so we minimize the compression ratio loss in "normal" scenarios. This mode gets + * triggered once we've gone 2KB without finding any matches. + */ + ms->lazySkipping = step > kLazySkippingStep; continue; } @@ -1939,7 +2023,7 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( ip ++; curr++; /* check repCode */ - if (offcode) { + if (offBase) { const U32 windowLow = ZSTD_getLowestMatchIndex(ms, curr, windowLog); const U32 repIndex = (U32)(curr - offset_1); const BYTE* const repBase = repIndex < dictLimit ? dictBase : base; @@ -1951,18 +2035,18 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( const BYTE* const repEnd = repIndex < dictLimit ? dictEnd : iend; size_t const repLength = ZSTD_count_2segments(ip+4, repMatch+4, iend, repEnd, prefixStart) + 4; int const gain2 = (int)(repLength * 3); - int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); + int const gain1 = (int)(matchLength*3 - ZSTD_highbit32((U32)offBase) + 1); if ((repLength >= 4) && (gain2 > gain1)) - matchLength = repLength, offcode = STORE_REPCODE_1, start = ip; + matchLength = repLength, offBase = REPCODE1_TO_OFFBASE, start = ip; } } /* search match, depth 1 */ - { size_t offset2=999999999; - size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offset2, mls, rowLog, searchMethod, ZSTD_extDict); - int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offset2))); /* raw approx */ - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 4); + { size_t ofbCandidate = 999999999; + size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, ZSTD_extDict); + int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)ofbCandidate)); /* raw approx */ + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 4); if ((ml2 >= 4) && (gain2 > gain1)) { - matchLength = ml2, offcode = offset2, start = ip; + matchLength = ml2, offBase = ofbCandidate, start = ip; continue; /* search a better one */ } } @@ -1971,7 +2055,7 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( ip ++; curr++; /* check repCode */ - if (offcode) { + if (offBase) { const U32 windowLow = ZSTD_getLowestMatchIndex(ms, curr, windowLog); const U32 repIndex = (U32)(curr - offset_1); const BYTE* const repBase = repIndex < dictLimit ? dictBase : base; @@ -1983,38 +2067,45 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( const BYTE* const repEnd = repIndex < dictLimit ? dictEnd : iend; size_t const repLength = ZSTD_count_2segments(ip+4, repMatch+4, iend, repEnd, prefixStart) + 4; int const gain2 = (int)(repLength * 4); - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 1); + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 1); if ((repLength >= 4) && (gain2 > gain1)) - matchLength = repLength, offcode = STORE_REPCODE_1, start = ip; + matchLength = repLength, offBase = REPCODE1_TO_OFFBASE, start = ip; } } /* search match, depth 2 */ - { size_t offset2=999999999; - size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &offset2, mls, rowLog, searchMethod, ZSTD_extDict); - int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offset2))); /* raw approx */ - int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)STORED_TO_OFFBASE(offcode)) + 7); + { size_t ofbCandidate = 999999999; + size_t const ml2 = ZSTD_searchMax(ms, ip, iend, &ofbCandidate, mls, rowLog, searchMethod, ZSTD_extDict); + int const gain2 = (int)(ml2*4 - ZSTD_highbit32((U32)ofbCandidate)); /* raw approx */ + int const gain1 = (int)(matchLength*4 - ZSTD_highbit32((U32)offBase) + 7); if ((ml2 >= 4) && (gain2 > gain1)) { - matchLength = ml2, offcode = offset2, start = ip; + matchLength = ml2, offBase = ofbCandidate, start = ip; continue; } } } break; /* nothing found : store previous solution */ } /* catch up */ - if (STORED_IS_OFFSET(offcode)) { - U32 const matchIndex = (U32)((size_t)(start-base) - STORED_OFFSET(offcode)); + if (OFFBASE_IS_OFFSET(offBase)) { + U32 const matchIndex = (U32)((size_t)(start-base) - OFFBASE_TO_OFFSET(offBase)); const BYTE* match = (matchIndex < dictLimit) ? dictBase + matchIndex : base + matchIndex; const BYTE* const mStart = (matchIndex < dictLimit) ? dictStart : prefixStart; while ((start>anchor) && (match>mStart) && (start[-1] == match[-1])) { start--; match--; matchLength++; } /* catch up */ - offset_2 = offset_1; offset_1 = (U32)STORED_OFFSET(offcode); + offset_2 = offset_1; offset_1 = (U32)OFFBASE_TO_OFFSET(offBase); } /* store sequence */ _storeSequence: { size_t const litLength = (size_t)(start - anchor); - ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offcode, matchLength); + ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offBase, matchLength); anchor = ip = start + matchLength; } + if (ms->lazySkipping) { + /* We've found a match, disable lazy skipping mode, and refill the hash cache. */ + if (searchMethod == search_rowHash) { + ZSTD_row_fillHashCache(ms, base, rowLog, mls, ms->nextToUpdate, ilimit); + } + ms->lazySkipping = 0; + } /* check immediate repcode */ while (ip <= ilimit) { @@ -2029,8 +2120,8 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( /* repcode detected we should take it */ const BYTE* const repEnd = repIndex < dictLimit ? dictEnd : iend; matchLength = ZSTD_count_2segments(ip+4, repMatch+4, iend, repEnd, prefixStart) + 4; - offcode = offset_2; offset_2 = offset_1; offset_1 = (U32)offcode; /* swap offset history */ - ZSTD_storeSeq(seqStore, 0, anchor, iend, STORE_REPCODE_1, matchLength); + offBase = offset_2; offset_2 = offset_1; offset_1 = (U32)offBase; /* swap offset history */ + ZSTD_storeSeq(seqStore, 0, anchor, iend, REPCODE1_TO_OFFBASE, matchLength); ip += matchLength; anchor = ip; continue; /* faster when present ... (?) */ @@ -2045,8 +2136,9 @@ size_t ZSTD_compressBlock_lazy_extDict_generic( /* Return the last literals size */ return (size_t)(iend - anchor); } +#endif /* build exclusions */ - +#ifndef ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR size_t ZSTD_compressBlock_greedy_extDict( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) @@ -2054,49 +2146,55 @@ size_t ZSTD_compressBlock_greedy_extDict( return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 0); } -size_t ZSTD_compressBlock_lazy_extDict( +size_t ZSTD_compressBlock_greedy_extDict_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) - { - return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1); + return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0); } +#endif -size_t ZSTD_compressBlock_lazy2_extDict( +#ifndef ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_lazy_extDict( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2); + return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 1); } -size_t ZSTD_compressBlock_btlazy2_extDict( +size_t ZSTD_compressBlock_lazy_extDict_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2); + return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1); } +#endif -size_t ZSTD_compressBlock_greedy_extDict_row( +#ifndef ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_lazy2_extDict( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) + { - return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 0); + return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_hashChain, 2); } -size_t ZSTD_compressBlock_lazy_extDict_row( +size_t ZSTD_compressBlock_lazy2_extDict_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) - { - return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 1); + return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2); } +#endif -size_t ZSTD_compressBlock_lazy2_extDict_row( +#ifndef ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_btlazy2_extDict( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize) { - return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_rowHash, 2); + return ZSTD_compressBlock_lazy_extDict_generic(ms, seqStore, rep, src, srcSize, search_binaryTree, 2); } +#endif diff --git a/lib/zstd/compress/zstd_lazy.h b/lib/zstd/compress/zstd_lazy.h index e5bdf4df8dde..22c9201f4e63 100644 --- a/lib/zstd/compress/zstd_lazy.h +++ b/lib/zstd/compress/zstd_lazy.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -22,98 +23,175 @@ */ #define ZSTD_LAZY_DDSS_BUCKET_LOG 2 +#define ZSTD_ROW_HASH_TAG_BITS 8 /* nb bits to use for the tag */ + +#if !defined(ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) U32 ZSTD_insertAndFindFirstIndex(ZSTD_matchState_t* ms, const BYTE* ip); void ZSTD_row_update(ZSTD_matchState_t* const ms, const BYTE* ip); void ZSTD_dedicatedDictSearch_lazy_loadDictionary(ZSTD_matchState_t* ms, const BYTE* const ip); void ZSTD_preserveUnsortedMark (U32* const table, U32 const size, U32 const reducerValue); /*! used in ZSTD_reduceIndex(). preemptively increase value of ZSTD_DUBT_UNSORTED_MARK */ +#endif -size_t ZSTD_compressBlock_btlazy2( +#ifndef ZSTD_EXCLUDE_GREEDY_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_greedy( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy2( +size_t ZSTD_compressBlock_greedy_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy( +size_t ZSTD_compressBlock_greedy_dictMatchState( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy( +size_t ZSTD_compressBlock_greedy_dictMatchState_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy2_row( +size_t ZSTD_compressBlock_greedy_dedicatedDictSearch( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy_row( +size_t ZSTD_compressBlock_greedy_dedicatedDictSearch_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy_row( +size_t ZSTD_compressBlock_greedy_extDict( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); - -size_t ZSTD_compressBlock_btlazy2_dictMatchState( +size_t ZSTD_compressBlock_greedy_extDict_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy2_dictMatchState( + +#define ZSTD_COMPRESSBLOCK_GREEDY ZSTD_compressBlock_greedy +#define ZSTD_COMPRESSBLOCK_GREEDY_ROW ZSTD_compressBlock_greedy_row +#define ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE ZSTD_compressBlock_greedy_dictMatchState +#define ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE_ROW ZSTD_compressBlock_greedy_dictMatchState_row +#define ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH ZSTD_compressBlock_greedy_dedicatedDictSearch +#define ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH_ROW ZSTD_compressBlock_greedy_dedicatedDictSearch_row +#define ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT ZSTD_compressBlock_greedy_extDict +#define ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT_ROW ZSTD_compressBlock_greedy_extDict_row +#else +#define ZSTD_COMPRESSBLOCK_GREEDY NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_ROW NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_DICTMATCHSTATE_ROW NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_DEDICATEDDICTSEARCH_ROW NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT NULL +#define ZSTD_COMPRESSBLOCK_GREEDY_EXTDICT_ROW NULL +#endif + +#ifndef ZSTD_EXCLUDE_LAZY_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_lazy( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy_dictMatchState( +size_t ZSTD_compressBlock_lazy_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy_dictMatchState( +size_t ZSTD_compressBlock_lazy_dictMatchState( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy2_dictMatchState_row( +size_t ZSTD_compressBlock_lazy_dictMatchState_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy_dictMatchState_row( +size_t ZSTD_compressBlock_lazy_dedicatedDictSearch( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy_dictMatchState_row( +size_t ZSTD_compressBlock_lazy_dedicatedDictSearch_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); - -size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch( +size_t ZSTD_compressBlock_lazy_extDict( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy_dedicatedDictSearch( +size_t ZSTD_compressBlock_lazy_extDict_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy_dedicatedDictSearch( + +#define ZSTD_COMPRESSBLOCK_LAZY ZSTD_compressBlock_lazy +#define ZSTD_COMPRESSBLOCK_LAZY_ROW ZSTD_compressBlock_lazy_row +#define ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE ZSTD_compressBlock_lazy_dictMatchState +#define ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE_ROW ZSTD_compressBlock_lazy_dictMatchState_row +#define ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH ZSTD_compressBlock_lazy_dedicatedDictSearch +#define ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH_ROW ZSTD_compressBlock_lazy_dedicatedDictSearch_row +#define ZSTD_COMPRESSBLOCK_LAZY_EXTDICT ZSTD_compressBlock_lazy_extDict +#define ZSTD_COMPRESSBLOCK_LAZY_EXTDICT_ROW ZSTD_compressBlock_lazy_extDict_row +#else +#define ZSTD_COMPRESSBLOCK_LAZY NULL +#define ZSTD_COMPRESSBLOCK_LAZY_ROW NULL +#define ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_LAZY_DICTMATCHSTATE_ROW NULL +#define ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH NULL +#define ZSTD_COMPRESSBLOCK_LAZY_DEDICATEDDICTSEARCH_ROW NULL +#define ZSTD_COMPRESSBLOCK_LAZY_EXTDICT NULL +#define ZSTD_COMPRESSBLOCK_LAZY_EXTDICT_ROW NULL +#endif + +#ifndef ZSTD_EXCLUDE_LAZY2_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_lazy2( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch_row( +size_t ZSTD_compressBlock_lazy2_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy_dedicatedDictSearch_row( +size_t ZSTD_compressBlock_lazy2_dictMatchState( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy_dedicatedDictSearch_row( +size_t ZSTD_compressBlock_lazy2_dictMatchState_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); - -size_t ZSTD_compressBlock_greedy_extDict( +size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy_extDict( +size_t ZSTD_compressBlock_lazy2_dedicatedDictSearch_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); size_t ZSTD_compressBlock_lazy2_extDict( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_greedy_extDict_row( +size_t ZSTD_compressBlock_lazy2_extDict_row( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy_extDict_row( + +#define ZSTD_COMPRESSBLOCK_LAZY2 ZSTD_compressBlock_lazy2 +#define ZSTD_COMPRESSBLOCK_LAZY2_ROW ZSTD_compressBlock_lazy2_row +#define ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE ZSTD_compressBlock_lazy2_dictMatchState +#define ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE_ROW ZSTD_compressBlock_lazy2_dictMatchState_row +#define ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH ZSTD_compressBlock_lazy2_dedicatedDictSearch +#define ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH_ROW ZSTD_compressBlock_lazy2_dedicatedDictSearch_row +#define ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT ZSTD_compressBlock_lazy2_extDict +#define ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT_ROW ZSTD_compressBlock_lazy2_extDict_row +#else +#define ZSTD_COMPRESSBLOCK_LAZY2 NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_ROW NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_DICTMATCHSTATE_ROW NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_DEDICATEDDICTSEARCH_ROW NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT NULL +#define ZSTD_COMPRESSBLOCK_LAZY2_EXTDICT_ROW NULL +#endif + +#ifndef ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_btlazy2( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_lazy2_extDict_row( +size_t ZSTD_compressBlock_btlazy2_dictMatchState( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); size_t ZSTD_compressBlock_btlazy2_extDict( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); - + +#define ZSTD_COMPRESSBLOCK_BTLAZY2 ZSTD_compressBlock_btlazy2 +#define ZSTD_COMPRESSBLOCK_BTLAZY2_DICTMATCHSTATE ZSTD_compressBlock_btlazy2_dictMatchState +#define ZSTD_COMPRESSBLOCK_BTLAZY2_EXTDICT ZSTD_compressBlock_btlazy2_extDict +#else +#define ZSTD_COMPRESSBLOCK_BTLAZY2 NULL +#define ZSTD_COMPRESSBLOCK_BTLAZY2_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_BTLAZY2_EXTDICT NULL +#endif + #endif /* ZSTD_LAZY_H */ diff --git a/lib/zstd/compress/zstd_ldm.c b/lib/zstd/compress/zstd_ldm.c index dd86fc83e7dd..07f3bc6437ce 100644 --- a/lib/zstd/compress/zstd_ldm.c +++ b/lib/zstd/compress/zstd_ldm.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -242,11 +243,15 @@ static size_t ZSTD_ldm_fillFastTables(ZSTD_matchState_t* ms, switch(ms->cParams.strategy) { case ZSTD_fast: - ZSTD_fillHashTable(ms, iend, ZSTD_dtlm_fast); + ZSTD_fillHashTable(ms, iend, ZSTD_dtlm_fast, ZSTD_tfp_forCCtx); break; case ZSTD_dfast: - ZSTD_fillDoubleHashTable(ms, iend, ZSTD_dtlm_fast); +#ifndef ZSTD_EXCLUDE_DFAST_BLOCK_COMPRESSOR + ZSTD_fillDoubleHashTable(ms, iend, ZSTD_dtlm_fast, ZSTD_tfp_forCCtx); +#else + assert(0); /* shouldn't be called: cparams should've been adjusted. */ +#endif break; case ZSTD_greedy: @@ -318,7 +323,9 @@ static void ZSTD_ldm_limitTableUpdate(ZSTD_matchState_t* ms, const BYTE* anchor) } } -static size_t ZSTD_ldm_generateSequences_internal( +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_ldm_generateSequences_internal( ldmState_t* ldmState, rawSeqStore_t* rawSeqStore, ldmParams_t const* params, void const* src, size_t srcSize) { @@ -549,7 +556,7 @@ size_t ZSTD_ldm_generateSequences( * the window through early invalidation. * TODO: * Test the chunk size. * * Try invalidation after the sequence generation and test the - * the offset against maxDist directly. + * offset against maxDist directly. * * NOTE: Because of dictionaries + sequence splitting we MUST make sure * that any offset used is valid at the END of the sequence, since it may @@ -689,7 +696,6 @@ size_t ZSTD_ldm_blockCompress(rawSeqStore_t* rawSeqStore, /* maybeSplitSequence updates rawSeqStore->pos */ rawSeq const sequence = maybeSplitSequence(rawSeqStore, (U32)(iend - ip), minMatch); - int i; /* End signal */ if (sequence.offset == 0) break; @@ -702,6 +708,7 @@ size_t ZSTD_ldm_blockCompress(rawSeqStore_t* rawSeqStore, /* Run the block compressor */ DEBUGLOG(5, "pos %u : calling block compressor on segment of size %u", (unsigned)(ip-istart), sequence.litLength); { + int i; size_t const newLitLength = blockCompressor(ms, seqStore, rep, ip, sequence.litLength); ip += sequence.litLength; @@ -711,7 +718,7 @@ size_t ZSTD_ldm_blockCompress(rawSeqStore_t* rawSeqStore, rep[0] = sequence.offset; /* Store the sequence */ ZSTD_storeSeq(seqStore, newLitLength, ip - newLitLength, iend, - STORE_OFFSET(sequence.offset), + OFFSET_TO_OFFBASE(sequence.offset), sequence.matchLength); ip += sequence.matchLength; } diff --git a/lib/zstd/compress/zstd_ldm.h b/lib/zstd/compress/zstd_ldm.h index fbc6a5e88fd7..c540731abde7 100644 --- a/lib/zstd/compress/zstd_ldm.h +++ b/lib/zstd/compress/zstd_ldm.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/compress/zstd_ldm_geartab.h b/lib/zstd/compress/zstd_ldm_geartab.h index 647f865be290..cfccfc46f6f7 100644 --- a/lib/zstd/compress/zstd_ldm_geartab.h +++ b/lib/zstd/compress/zstd_ldm_geartab.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/compress/zstd_opt.c b/lib/zstd/compress/zstd_opt.c index fd82acfda62f..a87b66ac8d24 100644 --- a/lib/zstd/compress/zstd_opt.c +++ b/lib/zstd/compress/zstd_opt.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Przemyslaw Skibinski, Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -12,11 +13,14 @@ #include "hist.h" #include "zstd_opt.h" +#if !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR) #define ZSTD_LITFREQ_ADD 2 /* scaling factor for litFreq, so that frequencies adapt faster to new stats */ #define ZSTD_MAX_PRICE (1<<30) -#define ZSTD_PREDEF_THRESHOLD 1024 /* if srcSize < ZSTD_PREDEF_THRESHOLD, symbols' cost is assumed static, directly determined by pre-defined distributions */ +#define ZSTD_PREDEF_THRESHOLD 8 /* if srcSize < ZSTD_PREDEF_THRESHOLD, symbols' cost is assumed static, directly determined by pre-defined distributions */ /*-************************************* @@ -26,27 +30,35 @@ #if 0 /* approximation at bit level (for tests) */ # define BITCOST_ACCURACY 0 # define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY) -# define WEIGHT(stat, opt) ((void)opt, ZSTD_bitWeight(stat)) +# define WEIGHT(stat, opt) ((void)(opt), ZSTD_bitWeight(stat)) #elif 0 /* fractional bit accuracy (for tests) */ # define BITCOST_ACCURACY 8 # define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY) -# define WEIGHT(stat,opt) ((void)opt, ZSTD_fracWeight(stat)) +# define WEIGHT(stat,opt) ((void)(opt), ZSTD_fracWeight(stat)) #else /* opt==approx, ultra==accurate */ # define BITCOST_ACCURACY 8 # define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY) -# define WEIGHT(stat,opt) (opt ? ZSTD_fracWeight(stat) : ZSTD_bitWeight(stat)) +# define WEIGHT(stat,opt) ((opt) ? ZSTD_fracWeight(stat) : ZSTD_bitWeight(stat)) #endif +/* ZSTD_bitWeight() : + * provide estimated "cost" of a stat in full bits only */ MEM_STATIC U32 ZSTD_bitWeight(U32 stat) { return (ZSTD_highbit32(stat+1) * BITCOST_MULTIPLIER); } +/* ZSTD_fracWeight() : + * provide fractional-bit "cost" of a stat, + * using linear interpolation approximation */ MEM_STATIC U32 ZSTD_fracWeight(U32 rawStat) { U32 const stat = rawStat + 1; U32 const hb = ZSTD_highbit32(stat); U32 const BWeight = hb * BITCOST_MULTIPLIER; + /* Fweight was meant for "Fractional weight" + * but it's effectively a value between 1 and 2 + * using fixed point arithmetic */ U32 const FWeight = (stat << BITCOST_ACCURACY) >> hb; U32 const weight = BWeight + FWeight; assert(hb + BITCOST_ACCURACY < 31); @@ -57,7 +69,7 @@ MEM_STATIC U32 ZSTD_fracWeight(U32 rawStat) /* debugging function, * @return price in bytes as fractional value * for debug messages only */ -MEM_STATIC double ZSTD_fCost(U32 price) +MEM_STATIC double ZSTD_fCost(int price) { return (double)price / (BITCOST_MULTIPLIER*8); } @@ -88,20 +100,26 @@ static U32 sum_u32(const unsigned table[], size_t nbElts) return total; } -static U32 ZSTD_downscaleStats(unsigned* table, U32 lastEltIndex, U32 shift) +typedef enum { base_0possible=0, base_1guaranteed=1 } base_directive_e; + +static U32 +ZSTD_downscaleStats(unsigned* table, U32 lastEltIndex, U32 shift, base_directive_e base1) { U32 s, sum=0; - DEBUGLOG(5, "ZSTD_downscaleStats (nbElts=%u, shift=%u)", (unsigned)lastEltIndex+1, (unsigned)shift); + DEBUGLOG(5, "ZSTD_downscaleStats (nbElts=%u, shift=%u)", + (unsigned)lastEltIndex+1, (unsigned)shift ); assert(shift < 30); for (s=0; s> shift); - sum += table[s]; + unsigned const base = base1 ? 1 : (table[s]>0); + unsigned const newStat = base + (table[s] >> shift); + sum += newStat; + table[s] = newStat; } return sum; } /* ZSTD_scaleStats() : - * reduce all elements in table is sum too large + * reduce all elt frequencies in table if sum too large * return the resulting sum of elements */ static U32 ZSTD_scaleStats(unsigned* table, U32 lastEltIndex, U32 logTarget) { @@ -110,7 +128,7 @@ static U32 ZSTD_scaleStats(unsigned* table, U32 lastEltIndex, U32 logTarget) DEBUGLOG(5, "ZSTD_scaleStats (nbElts=%u, target=%u)", (unsigned)lastEltIndex+1, (unsigned)logTarget); assert(logTarget < 30); if (factor <= 1) return prevsum; - return ZSTD_downscaleStats(table, lastEltIndex, ZSTD_highbit32(factor)); + return ZSTD_downscaleStats(table, lastEltIndex, ZSTD_highbit32(factor), base_1guaranteed); } /* ZSTD_rescaleFreqs() : @@ -129,18 +147,22 @@ ZSTD_rescaleFreqs(optState_t* const optPtr, DEBUGLOG(5, "ZSTD_rescaleFreqs (srcSize=%u)", (unsigned)srcSize); optPtr->priceType = zop_dynamic; - if (optPtr->litLengthSum == 0) { /* first block : init */ - if (srcSize <= ZSTD_PREDEF_THRESHOLD) { /* heuristic */ - DEBUGLOG(5, "(srcSize <= ZSTD_PREDEF_THRESHOLD) => zop_predef"); + if (optPtr->litLengthSum == 0) { /* no literals stats collected -> first block assumed -> init */ + + /* heuristic: use pre-defined stats for too small inputs */ + if (srcSize <= ZSTD_PREDEF_THRESHOLD) { + DEBUGLOG(5, "srcSize <= %i : use predefined stats", ZSTD_PREDEF_THRESHOLD); optPtr->priceType = zop_predef; } assert(optPtr->symbolCosts != NULL); if (optPtr->symbolCosts->huf.repeatMode == HUF_repeat_valid) { - /* huffman table presumed generated by dictionary */ + + /* huffman stats covering the full value set : table presumed generated by dictionary */ optPtr->priceType = zop_dynamic; if (compressedLiterals) { + /* generate literals statistics from huffman table */ unsigned lit; assert(optPtr->litFreq != NULL); optPtr->litSum = 0; @@ -188,13 +210,14 @@ ZSTD_rescaleFreqs(optState_t* const optPtr, optPtr->offCodeSum += optPtr->offCodeFreq[of]; } } - } else { /* not a dictionary */ + } else { /* first block, no dictionary */ assert(optPtr->litFreq != NULL); if (compressedLiterals) { + /* base initial cost of literals on direct frequency within src */ unsigned lit = MaxLit; HIST_count_simple(optPtr->litFreq, &lit, src, srcSize); /* use raw first block to init statistics */ - optPtr->litSum = ZSTD_downscaleStats(optPtr->litFreq, MaxLit, 8); + optPtr->litSum = ZSTD_downscaleStats(optPtr->litFreq, MaxLit, 8, base_0possible); } { unsigned const baseLLfreqs[MaxLL+1] = { @@ -224,10 +247,9 @@ ZSTD_rescaleFreqs(optState_t* const optPtr, optPtr->offCodeSum = sum_u32(baseOFCfreqs, MaxOff+1); } - } - } else { /* new block : re-use previous statistics, scaled down */ + } else { /* new block : scale down accumulated statistics */ if (compressedLiterals) optPtr->litSum = ZSTD_scaleStats(optPtr->litFreq, MaxLit, 12); @@ -246,6 +268,7 @@ static U32 ZSTD_rawLiteralsCost(const BYTE* const literals, U32 const litLength, const optState_t* const optPtr, int optLevel) { + DEBUGLOG(8, "ZSTD_rawLiteralsCost (%u literals)", litLength); if (litLength == 0) return 0; if (!ZSTD_compressedLiterals(optPtr)) @@ -255,11 +278,14 @@ static U32 ZSTD_rawLiteralsCost(const BYTE* const literals, U32 const litLength, return (litLength*6) * BITCOST_MULTIPLIER; /* 6 bit per literal - no statistic used */ /* dynamic statistics */ - { U32 price = litLength * optPtr->litSumBasePrice; + { U32 price = optPtr->litSumBasePrice * litLength; + U32 const litPriceMax = optPtr->litSumBasePrice - BITCOST_MULTIPLIER; U32 u; + assert(optPtr->litSumBasePrice >= BITCOST_MULTIPLIER); for (u=0; u < litLength; u++) { - assert(WEIGHT(optPtr->litFreq[literals[u]], optLevel) <= optPtr->litSumBasePrice); /* literal cost should never be negative */ - price -= WEIGHT(optPtr->litFreq[literals[u]], optLevel); + U32 litPrice = WEIGHT(optPtr->litFreq[literals[u]], optLevel); + if (UNLIKELY(litPrice > litPriceMax)) litPrice = litPriceMax; + price -= litPrice; } return price; } @@ -272,10 +298,11 @@ static U32 ZSTD_litLengthPrice(U32 const litLength, const optState_t* const optP assert(litLength <= ZSTD_BLOCKSIZE_MAX); if (optPtr->priceType == zop_predef) return WEIGHT(litLength, optLevel); - /* We can't compute the litLength price for sizes >= ZSTD_BLOCKSIZE_MAX - * because it isn't representable in the zstd format. So instead just - * call it 1 bit more than ZSTD_BLOCKSIZE_MAX - 1. In this case the block - * would be all literals. + + /* ZSTD_LLcode() can't compute litLength price for sizes >= ZSTD_BLOCKSIZE_MAX + * because it isn't representable in the zstd format. + * So instead just pretend it would cost 1 bit more than ZSTD_BLOCKSIZE_MAX - 1. + * In such a case, the block would be all literals. */ if (litLength == ZSTD_BLOCKSIZE_MAX) return BITCOST_MULTIPLIER + ZSTD_litLengthPrice(ZSTD_BLOCKSIZE_MAX - 1, optPtr, optLevel); @@ -289,24 +316,25 @@ static U32 ZSTD_litLengthPrice(U32 const litLength, const optState_t* const optP } /* ZSTD_getMatchPrice() : - * Provides the cost of the match part (offset + matchLength) of a sequence + * Provides the cost of the match part (offset + matchLength) of a sequence. * Must be combined with ZSTD_fullLiteralsCost() to get the full cost of a sequence. - * @offcode : expects a scale where 0,1,2 are repcodes 1-3, and 3+ are real_offsets+2 + * @offBase : sumtype, representing an offset or a repcode, and using numeric representation of ZSTD_storeSeq() * @optLevel: when <2, favors small offset for decompression speed (improved cache efficiency) */ FORCE_INLINE_TEMPLATE U32 -ZSTD_getMatchPrice(U32 const offcode, +ZSTD_getMatchPrice(U32 const offBase, U32 const matchLength, const optState_t* const optPtr, int const optLevel) { U32 price; - U32 const offCode = ZSTD_highbit32(STORED_TO_OFFBASE(offcode)); + U32 const offCode = ZSTD_highbit32(offBase); U32 const mlBase = matchLength - MINMATCH; assert(matchLength >= MINMATCH); - if (optPtr->priceType == zop_predef) /* fixed scheme, do not use statistics */ - return WEIGHT(mlBase, optLevel) + ((16 + offCode) * BITCOST_MULTIPLIER); + if (optPtr->priceType == zop_predef) /* fixed scheme, does not use statistics */ + return WEIGHT(mlBase, optLevel) + + ((16 + offCode) * BITCOST_MULTIPLIER); /* emulated offset cost */ /* dynamic statistics */ price = (offCode * BITCOST_MULTIPLIER) + (optPtr->offCodeSumBasePrice - WEIGHT(optPtr->offCodeFreq[offCode], optLevel)); @@ -325,10 +353,10 @@ ZSTD_getMatchPrice(U32 const offcode, } /* ZSTD_updateStats() : - * assumption : literals + litLengtn <= iend */ + * assumption : literals + litLength <= iend */ static void ZSTD_updateStats(optState_t* const optPtr, U32 litLength, const BYTE* literals, - U32 offsetCode, U32 matchLength) + U32 offBase, U32 matchLength) { /* literals */ if (ZSTD_compressedLiterals(optPtr)) { @@ -344,8 +372,8 @@ static void ZSTD_updateStats(optState_t* const optPtr, optPtr->litLengthSum++; } - /* offset code : expected to follow storeSeq() numeric representation */ - { U32 const offCode = ZSTD_highbit32(STORED_TO_OFFBASE(offsetCode)); + /* offset code : follows storeSeq() numeric representation */ + { U32 const offCode = ZSTD_highbit32(offBase); assert(offCode <= MaxOff); optPtr->offCodeFreq[offCode]++; optPtr->offCodeSum++; @@ -379,9 +407,11 @@ MEM_STATIC U32 ZSTD_readMINMATCH(const void* memPtr, U32 length) /* Update hashTable3 up to ip (excluded) Assumption : always within prefix (i.e. not within extDict) */ -static U32 ZSTD_insertAndFindFirstIndexHash3 (const ZSTD_matchState_t* ms, - U32* nextToUpdate3, - const BYTE* const ip) +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_insertAndFindFirstIndexHash3 (const ZSTD_matchState_t* ms, + U32* nextToUpdate3, + const BYTE* const ip) { U32* const hashTable3 = ms->hashTable3; U32 const hashLog3 = ms->hashLog3; @@ -408,7 +438,9 @@ static U32 ZSTD_insertAndFindFirstIndexHash3 (const ZSTD_matchState_t* ms, * @param ip assumed <= iend-8 . * @param target The target of ZSTD_updateTree_internal() - we are filling to this position * @return : nb of positions added */ -static U32 ZSTD_insertBt1( +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_insertBt1( const ZSTD_matchState_t* ms, const BYTE* const ip, const BYTE* const iend, U32 const target, @@ -527,6 +559,7 @@ static U32 ZSTD_insertBt1( } FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR void ZSTD_updateTree_internal( ZSTD_matchState_t* ms, const BYTE* const ip, const BYTE* const iend, @@ -535,7 +568,7 @@ void ZSTD_updateTree_internal( const BYTE* const base = ms->window.base; U32 const target = (U32)(ip - base); U32 idx = ms->nextToUpdate; - DEBUGLOG(6, "ZSTD_updateTree_internal, from %u to %u (dictMode:%u)", + DEBUGLOG(7, "ZSTD_updateTree_internal, from %u to %u (dictMode:%u)", idx, target, dictMode); while(idx < target) { @@ -553,15 +586,18 @@ void ZSTD_updateTree(ZSTD_matchState_t* ms, const BYTE* ip, const BYTE* iend) { } FORCE_INLINE_TEMPLATE -U32 ZSTD_insertBtAndGetAllMatches ( - ZSTD_match_t* matches, /* store result (found matches) in this table (presumed large enough) */ - ZSTD_matchState_t* ms, - U32* nextToUpdate3, - const BYTE* const ip, const BYTE* const iLimit, const ZSTD_dictMode_e dictMode, - const U32 rep[ZSTD_REP_NUM], - U32 const ll0, /* tells if associated literal length is 0 or not. This value must be 0 or 1 */ - const U32 lengthToBeat, - U32 const mls /* template */) +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 +ZSTD_insertBtAndGetAllMatches ( + ZSTD_match_t* matches, /* store result (found matches) in this table (presumed large enough) */ + ZSTD_matchState_t* ms, + U32* nextToUpdate3, + const BYTE* const ip, const BYTE* const iLimit, + const ZSTD_dictMode_e dictMode, + const U32 rep[ZSTD_REP_NUM], + const U32 ll0, /* tells if associated literal length is 0 or not. This value must be 0 or 1 */ + const U32 lengthToBeat, + const U32 mls /* template */) { const ZSTD_compressionParameters* const cParams = &ms->cParams; U32 const sufficient_len = MIN(cParams->targetLength, ZSTD_OPT_NUM -1); @@ -644,7 +680,7 @@ U32 ZSTD_insertBtAndGetAllMatches ( DEBUGLOG(8, "found repCode %u (ll0:%u, offset:%u) of length %u", repCode, ll0, repOffset, repLen); bestLength = repLen; - matches[mnum].off = STORE_REPCODE(repCode - ll0 + 1); /* expect value between 1 and 3 */ + matches[mnum].off = REPCODE_TO_OFFBASE(repCode - ll0 + 1); /* expect value between 1 and 3 */ matches[mnum].len = (U32)repLen; mnum++; if ( (repLen > sufficient_len) @@ -673,7 +709,7 @@ U32 ZSTD_insertBtAndGetAllMatches ( bestLength = mlen; assert(curr > matchIndex3); assert(mnum==0); /* no prior solution */ - matches[0].off = STORE_OFFSET(curr - matchIndex3); + matches[0].off = OFFSET_TO_OFFBASE(curr - matchIndex3); matches[0].len = (U32)mlen; mnum = 1; if ( (mlen > sufficient_len) | @@ -706,13 +742,13 @@ U32 ZSTD_insertBtAndGetAllMatches ( } if (matchLength > bestLength) { - DEBUGLOG(8, "found match of length %u at distance %u (offCode=%u)", - (U32)matchLength, curr - matchIndex, STORE_OFFSET(curr - matchIndex)); + DEBUGLOG(8, "found match of length %u at distance %u (offBase=%u)", + (U32)matchLength, curr - matchIndex, OFFSET_TO_OFFBASE(curr - matchIndex)); assert(matchEndIdx > matchIndex); if (matchLength > matchEndIdx - matchIndex) matchEndIdx = matchIndex + (U32)matchLength; bestLength = matchLength; - matches[mnum].off = STORE_OFFSET(curr - matchIndex); + matches[mnum].off = OFFSET_TO_OFFBASE(curr - matchIndex); matches[mnum].len = (U32)matchLength; mnum++; if ( (matchLength > ZSTD_OPT_NUM) @@ -754,12 +790,12 @@ U32 ZSTD_insertBtAndGetAllMatches ( if (matchLength > bestLength) { matchIndex = dictMatchIndex + dmsIndexDelta; - DEBUGLOG(8, "found dms match of length %u at distance %u (offCode=%u)", - (U32)matchLength, curr - matchIndex, STORE_OFFSET(curr - matchIndex)); + DEBUGLOG(8, "found dms match of length %u at distance %u (offBase=%u)", + (U32)matchLength, curr - matchIndex, OFFSET_TO_OFFBASE(curr - matchIndex)); if (matchLength > matchEndIdx - matchIndex) matchEndIdx = matchIndex + (U32)matchLength; bestLength = matchLength; - matches[mnum].off = STORE_OFFSET(curr - matchIndex); + matches[mnum].off = OFFSET_TO_OFFBASE(curr - matchIndex); matches[mnum].len = (U32)matchLength; mnum++; if ( (matchLength > ZSTD_OPT_NUM) @@ -792,7 +828,9 @@ typedef U32 (*ZSTD_getAllMatchesFn)( U32 const ll0, U32 const lengthToBeat); -FORCE_INLINE_TEMPLATE U32 ZSTD_btGetAllMatches_internal( +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +U32 ZSTD_btGetAllMatches_internal( ZSTD_match_t* matches, ZSTD_matchState_t* ms, U32* nextToUpdate3, @@ -960,7 +998,7 @@ static void ZSTD_optLdm_maybeAddMatch(ZSTD_match_t* matches, U32* nbMatches, const ZSTD_optLdm_t* optLdm, U32 currPosInBlock) { U32 const posDiff = currPosInBlock - optLdm->startPosInBlock; - /* Note: ZSTD_match_t actually contains offCode and matchLength (before subtracting MINMATCH) */ + /* Note: ZSTD_match_t actually contains offBase and matchLength (before subtracting MINMATCH) */ U32 const candidateMatchLength = optLdm->endPosInBlock - optLdm->startPosInBlock - posDiff; /* Ensure that current block position is not outside of the match */ @@ -971,11 +1009,11 @@ static void ZSTD_optLdm_maybeAddMatch(ZSTD_match_t* matches, U32* nbMatches, } if (*nbMatches == 0 || ((candidateMatchLength > matches[*nbMatches-1].len) && *nbMatches < ZSTD_OPT_NUM)) { - U32 const candidateOffCode = STORE_OFFSET(optLdm->offset); - DEBUGLOG(6, "ZSTD_optLdm_maybeAddMatch(): Adding ldm candidate match (offCode: %u matchLength %u) at block position=%u", - candidateOffCode, candidateMatchLength, currPosInBlock); + U32 const candidateOffBase = OFFSET_TO_OFFBASE(optLdm->offset); + DEBUGLOG(6, "ZSTD_optLdm_maybeAddMatch(): Adding ldm candidate match (offBase: %u matchLength %u) at block position=%u", + candidateOffBase, candidateMatchLength, currPosInBlock); matches[*nbMatches].len = candidateMatchLength; - matches[*nbMatches].off = candidateOffCode; + matches[*nbMatches].off = candidateOffBase; (*nbMatches)++; } } @@ -1011,11 +1049,6 @@ ZSTD_optLdm_processMatchCandidate(ZSTD_optLdm_t* optLdm, * Optimal parser *********************************/ -static U32 ZSTD_totalLen(ZSTD_optimal_t sol) -{ - return sol.litlen + sol.mlen; -} - #if 0 /* debug */ static void @@ -1033,7 +1066,13 @@ listStats(const U32* table, int lastEltID) #endif -FORCE_INLINE_TEMPLATE size_t +#define LIT_PRICE(_p) (int)ZSTD_rawLiteralsCost(_p, 1, optStatePtr, optLevel) +#define LL_PRICE(_l) (int)ZSTD_litLengthPrice(_l, optStatePtr, optLevel) +#define LL_INCPRICE(_l) (LL_PRICE(_l) - LL_PRICE(_l-1)) + +FORCE_INLINE_TEMPLATE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], @@ -1059,9 +1098,11 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, ZSTD_optimal_t* const opt = optStatePtr->priceTable; ZSTD_match_t* const matches = optStatePtr->matchTable; - ZSTD_optimal_t lastSequence; + ZSTD_optimal_t lastStretch; ZSTD_optLdm_t optLdm; + ZSTD_memset(&lastStretch, 0, sizeof(ZSTD_optimal_t)); + optLdm.seqStore = ms->ldmSeqStore ? *ms->ldmSeqStore : kNullRawSeqStore; optLdm.endPosInBlock = optLdm.startPosInBlock = optLdm.offset = 0; ZSTD_opt_getNextMatchAndUpdateSeqStore(&optLdm, (U32)(ip-istart), (U32)(iend-ip)); @@ -1082,103 +1123,139 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, U32 const ll0 = !litlen; U32 nbMatches = getAllMatches(matches, ms, &nextToUpdate3, ip, iend, rep, ll0, minMatch); ZSTD_optLdm_processMatchCandidate(&optLdm, matches, &nbMatches, - (U32)(ip-istart), (U32)(iend - ip)); - if (!nbMatches) { ip++; continue; } + (U32)(ip-istart), (U32)(iend-ip)); + if (!nbMatches) { + DEBUGLOG(8, "no match found at cPos %u", (unsigned)(ip-istart)); + ip++; + continue; + } + + /* Match found: let's store this solution, and eventually find more candidates. + * During this forward pass, @opt is used to store stretches, + * defined as "a match followed by N literals". + * Note how this is different from a Sequence, which is "N literals followed by a match". + * Storing stretches allows us to store different match predecessors + * for each literal position part of a literals run. */ /* initialize opt[0] */ - { U32 i ; for (i=0; i immediate encoding */ { U32 const maxML = matches[nbMatches-1].len; - U32 const maxOffcode = matches[nbMatches-1].off; - DEBUGLOG(6, "found %u matches of maxLength=%u and maxOffCode=%u at cPos=%u => start new series", - nbMatches, maxML, maxOffcode, (U32)(ip-prefixStart)); + U32 const maxOffBase = matches[nbMatches-1].off; + DEBUGLOG(6, "found %u matches of maxLength=%u and maxOffBase=%u at cPos=%u => start new series", + nbMatches, maxML, maxOffBase, (U32)(ip-prefixStart)); if (maxML > sufficient_len) { - lastSequence.litlen = litlen; - lastSequence.mlen = maxML; - lastSequence.off = maxOffcode; - DEBUGLOG(6, "large match (%u>%u), immediate encoding", + lastStretch.litlen = 0; + lastStretch.mlen = maxML; + lastStretch.off = maxOffBase; + DEBUGLOG(6, "large match (%u>%u) => immediate encoding", maxML, sufficient_len); cur = 0; - last_pos = ZSTD_totalLen(lastSequence); + last_pos = maxML; goto _shortestPath; } } /* set prices for first matches starting position == 0 */ assert(opt[0].price >= 0); - { U32 const literalsPrice = (U32)opt[0].price + ZSTD_litLengthPrice(0, optStatePtr, optLevel); - U32 pos; + { U32 pos; U32 matchNb; for (pos = 1; pos < minMatch; pos++) { - opt[pos].price = ZSTD_MAX_PRICE; /* mlen, litlen and price will be fixed during forward scanning */ + opt[pos].price = ZSTD_MAX_PRICE; + opt[pos].mlen = 0; + opt[pos].litlen = litlen + pos; } for (matchNb = 0; matchNb < nbMatches; matchNb++) { - U32 const offcode = matches[matchNb].off; + U32 const offBase = matches[matchNb].off; U32 const end = matches[matchNb].len; for ( ; pos <= end ; pos++ ) { - U32 const matchPrice = ZSTD_getMatchPrice(offcode, pos, optStatePtr, optLevel); - U32 const sequencePrice = literalsPrice + matchPrice; + int const matchPrice = (int)ZSTD_getMatchPrice(offBase, pos, optStatePtr, optLevel); + int const sequencePrice = opt[0].price + matchPrice; DEBUGLOG(7, "rPos:%u => set initial price : %.2f", pos, ZSTD_fCost(sequencePrice)); opt[pos].mlen = pos; - opt[pos].off = offcode; - opt[pos].litlen = litlen; - opt[pos].price = (int)sequencePrice; - } } + opt[pos].off = offBase; + opt[pos].litlen = 0; /* end of match */ + opt[pos].price = sequencePrice + LL_PRICE(0); + } + } last_pos = pos-1; + opt[pos].price = ZSTD_MAX_PRICE; } } /* check further positions */ for (cur = 1; cur <= last_pos; cur++) { const BYTE* const inr = ip + cur; - assert(cur < ZSTD_OPT_NUM); - DEBUGLOG(7, "cPos:%zi==rPos:%u", inr-istart, cur) + assert(cur <= ZSTD_OPT_NUM); + DEBUGLOG(7, "cPos:%zi==rPos:%u", inr-istart, cur); /* Fix current position with one literal if cheaper */ - { U32 const litlen = (opt[cur-1].mlen == 0) ? opt[cur-1].litlen + 1 : 1; + { U32 const litlen = opt[cur-1].litlen + 1; int const price = opt[cur-1].price - + (int)ZSTD_rawLiteralsCost(ip+cur-1, 1, optStatePtr, optLevel) - + (int)ZSTD_litLengthPrice(litlen, optStatePtr, optLevel) - - (int)ZSTD_litLengthPrice(litlen-1, optStatePtr, optLevel); + + LIT_PRICE(ip+cur-1) + + LL_INCPRICE(litlen); assert(price < 1000000000); /* overflow check */ if (price <= opt[cur].price) { + ZSTD_optimal_t const prevMatch = opt[cur]; DEBUGLOG(7, "cPos:%zi==rPos:%u : better price (%.2f<=%.2f) using literal (ll==%u) (hist:%u,%u,%u)", inr-istart, cur, ZSTD_fCost(price), ZSTD_fCost(opt[cur].price), litlen, opt[cur-1].rep[0], opt[cur-1].rep[1], opt[cur-1].rep[2]); - opt[cur].mlen = 0; - opt[cur].off = 0; + opt[cur] = opt[cur-1]; opt[cur].litlen = litlen; opt[cur].price = price; + if ( (optLevel >= 1) /* additional check only for higher modes */ + && (prevMatch.litlen == 0) /* replace a match */ + && (LL_INCPRICE(1) < 0) /* ll1 is cheaper than ll0 */ + && LIKELY(ip + cur < iend) + ) { + /* check next position, in case it would be cheaper */ + int with1literal = prevMatch.price + LIT_PRICE(ip+cur) + LL_INCPRICE(1); + int withMoreLiterals = price + LIT_PRICE(ip+cur) + LL_INCPRICE(litlen+1); + DEBUGLOG(7, "then at next rPos %u : match+1lit %.2f vs %ulits %.2f", + cur+1, ZSTD_fCost(with1literal), litlen+1, ZSTD_fCost(withMoreLiterals)); + if ( (with1literal < withMoreLiterals) + && (with1literal < opt[cur+1].price) ) { + /* update offset history - before it disappears */ + U32 const prev = cur - prevMatch.mlen; + repcodes_t const newReps = ZSTD_newRep(opt[prev].rep, prevMatch.off, opt[prev].litlen==0); + assert(cur >= prevMatch.mlen); + DEBUGLOG(7, "==> match+1lit is cheaper (%.2f < %.2f) (hist:%u,%u,%u) !", + ZSTD_fCost(with1literal), ZSTD_fCost(withMoreLiterals), + newReps.rep[0], newReps.rep[1], newReps.rep[2] ); + opt[cur+1] = prevMatch; /* mlen & offbase */ + ZSTD_memcpy(opt[cur+1].rep, &newReps, sizeof(repcodes_t)); + opt[cur+1].litlen = 1; + opt[cur+1].price = with1literal; + if (last_pos < cur+1) last_pos = cur+1; + } + } } else { - DEBUGLOG(7, "cPos:%zi==rPos:%u : literal would cost more (%.2f>%.2f) (hist:%u,%u,%u)", - inr-istart, cur, ZSTD_fCost(price), ZSTD_fCost(opt[cur].price), - opt[cur].rep[0], opt[cur].rep[1], opt[cur].rep[2]); + DEBUGLOG(7, "cPos:%zi==rPos:%u : literal would cost more (%.2f>%.2f)", + inr-istart, cur, ZSTD_fCost(price), ZSTD_fCost(opt[cur].price)); } } - /* Set the repcodes of the current position. We must do it here - * because we rely on the repcodes of the 2nd to last sequence being - * correct to set the next chunks repcodes during the backward - * traversal. + /* Offset history is not updated during match comparison. + * Do it here, now that the match is selected and confirmed. */ ZSTD_STATIC_ASSERT(sizeof(opt[cur].rep) == sizeof(repcodes_t)); assert(cur >= opt[cur].mlen); - if (opt[cur].mlen != 0) { + if (opt[cur].litlen == 0) { + /* just finished a match => alter offset history */ U32 const prev = cur - opt[cur].mlen; - repcodes_t const newReps = ZSTD_newRep(opt[prev].rep, opt[cur].off, opt[cur].litlen==0); + repcodes_t const newReps = ZSTD_newRep(opt[prev].rep, opt[cur].off, opt[prev].litlen==0); ZSTD_memcpy(opt[cur].rep, &newReps, sizeof(repcodes_t)); - } else { - ZSTD_memcpy(opt[cur].rep, opt[cur - 1].rep, sizeof(repcodes_t)); } /* last match must start at a minimum distance of 8 from oend */ @@ -1188,15 +1265,14 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, if ( (optLevel==0) /*static_test*/ && (opt[cur+1].price <= opt[cur].price + (BITCOST_MULTIPLIER/2)) ) { - DEBUGLOG(7, "move to next rPos:%u : price is <=", cur+1); + DEBUGLOG(7, "skip current position : next rPos(%u) price is cheaper", cur+1); continue; /* skip unpromising positions; about ~+6% speed, -0.01 ratio */ } assert(opt[cur].price >= 0); - { U32 const ll0 = (opt[cur].mlen != 0); - U32 const litlen = (opt[cur].mlen == 0) ? opt[cur].litlen : 0; - U32 const previousPrice = (U32)opt[cur].price; - U32 const basePrice = previousPrice + ZSTD_litLengthPrice(0, optStatePtr, optLevel); + { U32 const ll0 = (opt[cur].litlen == 0); + int const previousPrice = opt[cur].price; + int const basePrice = previousPrice + LL_PRICE(0); U32 nbMatches = getAllMatches(matches, ms, &nextToUpdate3, inr, iend, opt[cur].rep, ll0, minMatch); U32 matchNb; @@ -1208,18 +1284,17 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, continue; } - { U32 const maxML = matches[nbMatches-1].len; - DEBUGLOG(7, "cPos:%zi==rPos:%u, found %u matches, of maxLength=%u", - inr-istart, cur, nbMatches, maxML); - - if ( (maxML > sufficient_len) - || (cur + maxML >= ZSTD_OPT_NUM) ) { - lastSequence.mlen = maxML; - lastSequence.off = matches[nbMatches-1].off; - lastSequence.litlen = litlen; - cur -= (opt[cur].mlen==0) ? opt[cur].litlen : 0; /* last sequence is actually only literals, fix cur to last match - note : may underflow, in which case, it's first sequence, and it's okay */ - last_pos = cur + ZSTD_totalLen(lastSequence); - if (cur > ZSTD_OPT_NUM) cur = 0; /* underflow => first match */ + { U32 const longestML = matches[nbMatches-1].len; + DEBUGLOG(7, "cPos:%zi==rPos:%u, found %u matches, of longest ML=%u", + inr-istart, cur, nbMatches, longestML); + + if ( (longestML > sufficient_len) + || (cur + longestML >= ZSTD_OPT_NUM) + || (ip + cur + longestML >= iend) ) { + lastStretch.mlen = longestML; + lastStretch.off = matches[nbMatches-1].off; + lastStretch.litlen = 0; + last_pos = cur + longestML; goto _shortestPath; } } @@ -1230,20 +1305,25 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, U32 const startML = (matchNb>0) ? matches[matchNb-1].len+1 : minMatch; U32 mlen; - DEBUGLOG(7, "testing match %u => offCode=%4u, mlen=%2u, llen=%2u", - matchNb, matches[matchNb].off, lastML, litlen); + DEBUGLOG(7, "testing match %u => offBase=%4u, mlen=%2u, llen=%2u", + matchNb, matches[matchNb].off, lastML, opt[cur].litlen); for (mlen = lastML; mlen >= startML; mlen--) { /* scan downward */ U32 const pos = cur + mlen; - int const price = (int)basePrice + (int)ZSTD_getMatchPrice(offset, mlen, optStatePtr, optLevel); + int const price = basePrice + (int)ZSTD_getMatchPrice(offset, mlen, optStatePtr, optLevel); if ((pos > last_pos) || (price < opt[pos].price)) { DEBUGLOG(7, "rPos:%u (ml=%2u) => new better price (%.2f<%.2f)", pos, mlen, ZSTD_fCost(price), ZSTD_fCost(opt[pos].price)); - while (last_pos < pos) { opt[last_pos+1].price = ZSTD_MAX_PRICE; last_pos++; } /* fill empty positions */ + while (last_pos < pos) { + /* fill empty positions, for future comparisons */ + last_pos++; + opt[last_pos].price = ZSTD_MAX_PRICE; + opt[last_pos].litlen = !0; /* just needs to be != 0, to mean "not an end of match" */ + } opt[pos].mlen = mlen; opt[pos].off = offset; - opt[pos].litlen = litlen; + opt[pos].litlen = 0; opt[pos].price = price; } else { DEBUGLOG(7, "rPos:%u (ml=%2u) => new price is worse (%.2f>=%.2f)", @@ -1251,52 +1331,86 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, if (optLevel==0) break; /* early update abort; gets ~+10% speed for about -0.01 ratio loss */ } } } } + opt[last_pos+1].price = ZSTD_MAX_PRICE; } /* for (cur = 1; cur <= last_pos; cur++) */ - lastSequence = opt[last_pos]; - cur = last_pos > ZSTD_totalLen(lastSequence) ? last_pos - ZSTD_totalLen(lastSequence) : 0; /* single sequence, and it starts before `ip` */ - assert(cur < ZSTD_OPT_NUM); /* control overflow*/ + lastStretch = opt[last_pos]; + assert(cur >= lastStretch.mlen); + cur = last_pos - lastStretch.mlen; _shortestPath: /* cur, last_pos, best_mlen, best_off have to be set */ assert(opt[0].mlen == 0); + assert(last_pos >= lastStretch.mlen); + assert(cur == last_pos - lastStretch.mlen); - /* Set the next chunk's repcodes based on the repcodes of the beginning - * of the last match, and the last sequence. This avoids us having to - * update them while traversing the sequences. - */ - if (lastSequence.mlen != 0) { - repcodes_t const reps = ZSTD_newRep(opt[cur].rep, lastSequence.off, lastSequence.litlen==0); - ZSTD_memcpy(rep, &reps, sizeof(reps)); + if (lastStretch.mlen==0) { + /* no solution : all matches have been converted into literals */ + assert(lastStretch.litlen == (ip - anchor) + last_pos); + ip += last_pos; + continue; + } + assert(lastStretch.off > 0); + + /* Update offset history */ + if (lastStretch.litlen == 0) { + /* finishing on a match : update offset history */ + repcodes_t const reps = ZSTD_newRep(opt[cur].rep, lastStretch.off, opt[cur].litlen==0); + ZSTD_memcpy(rep, &reps, sizeof(repcodes_t)); } else { - ZSTD_memcpy(rep, opt[cur].rep, sizeof(repcodes_t)); + ZSTD_memcpy(rep, lastStretch.rep, sizeof(repcodes_t)); + assert(cur >= lastStretch.litlen); + cur -= lastStretch.litlen; } - { U32 const storeEnd = cur + 1; + /* Let's write the shortest path solution. + * It is stored in @opt in reverse order, + * starting from @storeEnd (==cur+2), + * effectively partially @opt overwriting. + * Content is changed too: + * - So far, @opt stored stretches, aka a match followed by literals + * - Now, it will store sequences, aka literals followed by a match + */ + { U32 const storeEnd = cur + 2; U32 storeStart = storeEnd; - U32 seqPos = cur; + U32 stretchPos = cur; DEBUGLOG(6, "start reverse traversal (last_pos:%u, cur:%u)", last_pos, cur); (void)last_pos; - assert(storeEnd < ZSTD_OPT_NUM); - DEBUGLOG(6, "last sequence copied into pos=%u (llen=%u,mlen=%u,ofc=%u)", - storeEnd, lastSequence.litlen, lastSequence.mlen, lastSequence.off); - opt[storeEnd] = lastSequence; - while (seqPos > 0) { - U32 const backDist = ZSTD_totalLen(opt[seqPos]); + assert(storeEnd < ZSTD_OPT_SIZE); + DEBUGLOG(6, "last stretch copied into pos=%u (llen=%u,mlen=%u,ofc=%u)", + storeEnd, lastStretch.litlen, lastStretch.mlen, lastStretch.off); + if (lastStretch.litlen > 0) { + /* last "sequence" is unfinished: just a bunch of literals */ + opt[storeEnd].litlen = lastStretch.litlen; + opt[storeEnd].mlen = 0; + storeStart = storeEnd-1; + opt[storeStart] = lastStretch; + } { + opt[storeEnd] = lastStretch; /* note: litlen will be fixed */ + storeStart = storeEnd; + } + while (1) { + ZSTD_optimal_t nextStretch = opt[stretchPos]; + opt[storeStart].litlen = nextStretch.litlen; + DEBUGLOG(6, "selected sequence (llen=%u,mlen=%u,ofc=%u)", + opt[storeStart].litlen, opt[storeStart].mlen, opt[storeStart].off); + if (nextStretch.mlen == 0) { + /* reaching beginning of segment */ + break; + } storeStart--; - DEBUGLOG(6, "sequence from rPos=%u copied into pos=%u (llen=%u,mlen=%u,ofc=%u)", - seqPos, storeStart, opt[seqPos].litlen, opt[seqPos].mlen, opt[seqPos].off); - opt[storeStart] = opt[seqPos]; - seqPos = (seqPos > backDist) ? seqPos - backDist : 0; + opt[storeStart] = nextStretch; /* note: litlen will be fixed */ + assert(nextStretch.litlen + nextStretch.mlen <= stretchPos); + stretchPos -= nextStretch.litlen + nextStretch.mlen; } /* save sequences */ - DEBUGLOG(6, "sending selected sequences into seqStore") + DEBUGLOG(6, "sending selected sequences into seqStore"); { U32 storePos; for (storePos=storeStart; storePos <= storeEnd; storePos++) { U32 const llen = opt[storePos].litlen; U32 const mlen = opt[storePos].mlen; - U32 const offCode = opt[storePos].off; + U32 const offBase = opt[storePos].off; U32 const advance = llen + mlen; DEBUGLOG(6, "considering seq starting at %zi, llen=%u, mlen=%u", anchor - istart, (unsigned)llen, (unsigned)mlen); @@ -1308,11 +1422,14 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, } assert(anchor + llen <= iend); - ZSTD_updateStats(optStatePtr, llen, anchor, offCode, mlen); - ZSTD_storeSeq(seqStore, llen, anchor, iend, offCode, mlen); + ZSTD_updateStats(optStatePtr, llen, anchor, offBase, mlen); + ZSTD_storeSeq(seqStore, llen, anchor, iend, offBase, mlen); anchor += advance; ip = anchor; } } + DEBUGLOG(7, "new offset history : %u, %u, %u", rep[0], rep[1], rep[2]); + + /* update all costs */ ZSTD_setBasePrices(optStatePtr, optLevel); } } /* while (ip < ilimit) */ @@ -1320,21 +1437,27 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms, /* Return the last literals size */ return (size_t)(iend - anchor); } +#endif /* build exclusions */ +#ifndef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR static size_t ZSTD_compressBlock_opt0( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize, const ZSTD_dictMode_e dictMode) { return ZSTD_compressBlock_opt_generic(ms, seqStore, rep, src, srcSize, 0 /* optLevel */, dictMode); } +#endif +#ifndef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR static size_t ZSTD_compressBlock_opt2( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize, const ZSTD_dictMode_e dictMode) { return ZSTD_compressBlock_opt_generic(ms, seqStore, rep, src, srcSize, 2 /* optLevel */, dictMode); } +#endif +#ifndef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR size_t ZSTD_compressBlock_btopt( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize) @@ -1342,20 +1465,23 @@ size_t ZSTD_compressBlock_btopt( DEBUGLOG(5, "ZSTD_compressBlock_btopt"); return ZSTD_compressBlock_opt0(ms, seqStore, rep, src, srcSize, ZSTD_noDict); } +#endif +#ifndef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR /* ZSTD_initStats_ultra(): * make a first compression pass, just to seed stats with more accurate starting values. * only works on first block, with no dictionary and no ldm. - * this function cannot error, hence its contract must be respected. + * this function cannot error out, its narrow contract must be respected. */ -static void -ZSTD_initStats_ultra(ZSTD_matchState_t* ms, - seqStore_t* seqStore, - U32 rep[ZSTD_REP_NUM], - const void* src, size_t srcSize) +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +void ZSTD_initStats_ultra(ZSTD_matchState_t* ms, + seqStore_t* seqStore, + U32 rep[ZSTD_REP_NUM], + const void* src, size_t srcSize) { U32 tmpRep[ZSTD_REP_NUM]; /* updated rep codes will sink here */ ZSTD_memcpy(tmpRep, rep, sizeof(tmpRep)); @@ -1368,7 +1494,7 @@ ZSTD_initStats_ultra(ZSTD_matchState_t* ms, ZSTD_compressBlock_opt2(ms, seqStore, tmpRep, src, srcSize, ZSTD_noDict); /* generate stats into ms->opt*/ - /* invalidate first scan from history */ + /* invalidate first scan from history, only keep entropy stats */ ZSTD_resetSeqStore(seqStore); ms->window.base -= srcSize; ms->window.dictLimit += (U32)srcSize; @@ -1392,10 +1518,10 @@ size_t ZSTD_compressBlock_btultra2( U32 const curr = (U32)((const BYTE*)src - ms->window.base); DEBUGLOG(5, "ZSTD_compressBlock_btultra2 (srcSize=%zu)", srcSize); - /* 2-pass strategy: + /* 2-passes strategy: * this strategy makes a first pass over first block to collect statistics - * and seed next round's statistics with it. - * After 1st pass, function forgets everything, and starts a new block. + * in order to seed next round's statistics with it. + * After 1st pass, function forgets history, and starts a new block. * Consequently, this can only work if no data has been previously loaded in tables, * aka, no dictionary, no prefix, no ldm preprocessing. * The compression ratio gain is generally small (~0.5% on first block), @@ -1404,15 +1530,17 @@ size_t ZSTD_compressBlock_btultra2( if ( (ms->opt.litLengthSum==0) /* first block */ && (seqStore->sequences == seqStore->sequencesStart) /* no ldm */ && (ms->window.dictLimit == ms->window.lowLimit) /* no dictionary */ - && (curr == ms->window.dictLimit) /* start of frame, nothing already loaded nor skipped */ - && (srcSize > ZSTD_PREDEF_THRESHOLD) + && (curr == ms->window.dictLimit) /* start of frame, nothing already loaded nor skipped */ + && (srcSize > ZSTD_PREDEF_THRESHOLD) /* input large enough to not employ default stats */ ) { ZSTD_initStats_ultra(ms, seqStore, rep, src, srcSize); } return ZSTD_compressBlock_opt2(ms, seqStore, rep, src, srcSize, ZSTD_noDict); } +#endif +#ifndef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR size_t ZSTD_compressBlock_btopt_dictMatchState( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize) @@ -1420,18 +1548,20 @@ size_t ZSTD_compressBlock_btopt_dictMatchState( return ZSTD_compressBlock_opt0(ms, seqStore, rep, src, srcSize, ZSTD_dictMatchState); } -size_t ZSTD_compressBlock_btultra_dictMatchState( +size_t ZSTD_compressBlock_btopt_extDict( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize) { - return ZSTD_compressBlock_opt2(ms, seqStore, rep, src, srcSize, ZSTD_dictMatchState); + return ZSTD_compressBlock_opt0(ms, seqStore, rep, src, srcSize, ZSTD_extDict); } +#endif -size_t ZSTD_compressBlock_btopt_extDict( +#ifndef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_btultra_dictMatchState( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], const void* src, size_t srcSize) { - return ZSTD_compressBlock_opt0(ms, seqStore, rep, src, srcSize, ZSTD_extDict); + return ZSTD_compressBlock_opt2(ms, seqStore, rep, src, srcSize, ZSTD_dictMatchState); } size_t ZSTD_compressBlock_btultra_extDict( @@ -1440,6 +1570,7 @@ size_t ZSTD_compressBlock_btultra_extDict( { return ZSTD_compressBlock_opt2(ms, seqStore, rep, src, srcSize, ZSTD_extDict); } +#endif /* note : no btultra2 variant for extDict nor dictMatchState, * because btultra2 is not meant to work with dictionaries diff --git a/lib/zstd/compress/zstd_opt.h b/lib/zstd/compress/zstd_opt.h index 22b862858ba7..ac1b743d27cd 100644 --- a/lib/zstd/compress/zstd_opt.h +++ b/lib/zstd/compress/zstd_opt.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -14,30 +15,40 @@ #include "zstd_compress_internal.h" +#if !defined(ZSTD_EXCLUDE_BTLAZY2_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR) \ + || !defined(ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR) /* used in ZSTD_loadDictionaryContent() */ void ZSTD_updateTree(ZSTD_matchState_t* ms, const BYTE* ip, const BYTE* iend); +#endif +#ifndef ZSTD_EXCLUDE_BTOPT_BLOCK_COMPRESSOR size_t ZSTD_compressBlock_btopt( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_btultra( +size_t ZSTD_compressBlock_btopt_dictMatchState( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); -size_t ZSTD_compressBlock_btultra2( +size_t ZSTD_compressBlock_btopt_extDict( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); +#define ZSTD_COMPRESSBLOCK_BTOPT ZSTD_compressBlock_btopt +#define ZSTD_COMPRESSBLOCK_BTOPT_DICTMATCHSTATE ZSTD_compressBlock_btopt_dictMatchState +#define ZSTD_COMPRESSBLOCK_BTOPT_EXTDICT ZSTD_compressBlock_btopt_extDict +#else +#define ZSTD_COMPRESSBLOCK_BTOPT NULL +#define ZSTD_COMPRESSBLOCK_BTOPT_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_BTOPT_EXTDICT NULL +#endif -size_t ZSTD_compressBlock_btopt_dictMatchState( +#ifndef ZSTD_EXCLUDE_BTULTRA_BLOCK_COMPRESSOR +size_t ZSTD_compressBlock_btultra( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); size_t ZSTD_compressBlock_btultra_dictMatchState( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); - -size_t ZSTD_compressBlock_btopt_extDict( - ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], - void const* src, size_t srcSize); size_t ZSTD_compressBlock_btultra_extDict( ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], void const* src, size_t srcSize); @@ -45,6 +56,20 @@ size_t ZSTD_compressBlock_btultra_extDict( /* note : no btultra2 variant for extDict nor dictMatchState, * because btultra2 is not meant to work with dictionaries * and is only specific for the first block (no prefix) */ +size_t ZSTD_compressBlock_btultra2( + ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], + void const* src, size_t srcSize); + +#define ZSTD_COMPRESSBLOCK_BTULTRA ZSTD_compressBlock_btultra +#define ZSTD_COMPRESSBLOCK_BTULTRA_DICTMATCHSTATE ZSTD_compressBlock_btultra_dictMatchState +#define ZSTD_COMPRESSBLOCK_BTULTRA_EXTDICT ZSTD_compressBlock_btultra_extDict +#define ZSTD_COMPRESSBLOCK_BTULTRA2 ZSTD_compressBlock_btultra2 +#else +#define ZSTD_COMPRESSBLOCK_BTULTRA NULL +#define ZSTD_COMPRESSBLOCK_BTULTRA_DICTMATCHSTATE NULL +#define ZSTD_COMPRESSBLOCK_BTULTRA_EXTDICT NULL +#define ZSTD_COMPRESSBLOCK_BTULTRA2 NULL +#endif #endif /* ZSTD_OPT_H */ diff --git a/lib/zstd/decompress/huf_decompress.c b/lib/zstd/decompress/huf_decompress.c index 60958afebc41..ac8b87f48f84 100644 --- a/lib/zstd/decompress/huf_decompress.c +++ b/lib/zstd/decompress/huf_decompress.c @@ -1,7 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* ****************************************************************** * huff0 huffman decoder, * part of Finite State Entropy library - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy @@ -19,10 +20,10 @@ #include "../common/compiler.h" #include "../common/bitstream.h" /* BIT_* */ #include "../common/fse.h" /* to compress headers */ -#define HUF_STATIC_LINKING_ONLY #include "../common/huf.h" #include "../common/error_private.h" #include "../common/zstd_internal.h" +#include "../common/bits.h" /* ZSTD_highbit32, ZSTD_countTrailingZeros64 */ /* ************************************************************** * Constants @@ -34,6 +35,12 @@ * Macros ****************************************************************/ +#ifdef HUF_DISABLE_FAST_DECODE +# define HUF_ENABLE_FAST_DECODE 0 +#else +# define HUF_ENABLE_FAST_DECODE 1 +#endif + /* These two optional macros force the use one way or another of the two * Huffman decompression implementations. You can't force in both directions * at the same time. @@ -43,27 +50,25 @@ #error "Cannot force the use of the X1 and X2 decoders at the same time!" #endif -#if ZSTD_ENABLE_ASM_X86_64_BMI2 && DYNAMIC_BMI2 -# define HUF_ASM_X86_64_BMI2_ATTRS BMI2_TARGET_ATTRIBUTE +/* When DYNAMIC_BMI2 is enabled, fast decoders are only called when bmi2 is + * supported at runtime, so we can add the BMI2 target attribute. + * When it is disabled, we will still get BMI2 if it is enabled statically. + */ +#if DYNAMIC_BMI2 +# define HUF_FAST_BMI2_ATTRS BMI2_TARGET_ATTRIBUTE #else -# define HUF_ASM_X86_64_BMI2_ATTRS +# define HUF_FAST_BMI2_ATTRS #endif #define HUF_EXTERN_C #define HUF_ASM_DECL HUF_EXTERN_C -#if DYNAMIC_BMI2 || (ZSTD_ENABLE_ASM_X86_64_BMI2 && defined(__BMI2__)) +#if DYNAMIC_BMI2 # define HUF_NEED_BMI2_FUNCTION 1 #else # define HUF_NEED_BMI2_FUNCTION 0 #endif -#if !(ZSTD_ENABLE_ASM_X86_64_BMI2 && defined(__BMI2__)) -# define HUF_NEED_DEFAULT_FUNCTION 1 -#else -# define HUF_NEED_DEFAULT_FUNCTION 0 -#endif - /* ************************************************************** * Error Management ****************************************************************/ @@ -80,6 +85,11 @@ /* ************************************************************** * BMI2 Variant Wrappers ****************************************************************/ +typedef size_t (*HUF_DecompressUsingDTableFn)(void *dst, size_t dstSize, + const void *cSrc, + size_t cSrcSize, + const HUF_DTable *DTable); + #if DYNAMIC_BMI2 #define HUF_DGEN(fn) \ @@ -101,9 +111,9 @@ } \ \ static size_t fn(void* dst, size_t dstSize, void const* cSrc, \ - size_t cSrcSize, HUF_DTable const* DTable, int bmi2) \ + size_t cSrcSize, HUF_DTable const* DTable, int flags) \ { \ - if (bmi2) { \ + if (flags & HUF_flags_bmi2) { \ return fn##_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); \ } \ return fn##_default(dst, dstSize, cSrc, cSrcSize, DTable); \ @@ -113,9 +123,9 @@ #define HUF_DGEN(fn) \ static size_t fn(void* dst, size_t dstSize, void const* cSrc, \ - size_t cSrcSize, HUF_DTable const* DTable, int bmi2) \ + size_t cSrcSize, HUF_DTable const* DTable, int flags) \ { \ - (void)bmi2; \ + (void)flags; \ return fn##_body(dst, dstSize, cSrc, cSrcSize, DTable); \ } @@ -134,43 +144,66 @@ static DTableDesc HUF_getDTableDesc(const HUF_DTable* table) return dtd; } -#if ZSTD_ENABLE_ASM_X86_64_BMI2 - -static size_t HUF_initDStream(BYTE const* ip) { +static size_t HUF_initFastDStream(BYTE const* ip) { BYTE const lastByte = ip[7]; - size_t const bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; + size_t const bitsConsumed = lastByte ? 8 - ZSTD_highbit32(lastByte) : 0; size_t const value = MEM_readLEST(ip) | 1; assert(bitsConsumed <= 8); + assert(sizeof(size_t) == 8); return value << bitsConsumed; } + + +/* + * The input/output arguments to the Huffman fast decoding loop: + * + * ip [in/out] - The input pointers, must be updated to reflect what is consumed. + * op [in/out] - The output pointers, must be updated to reflect what is written. + * bits [in/out] - The bitstream containers, must be updated to reflect the current state. + * dt [in] - The decoding table. + * ilowest [in] - The beginning of the valid range of the input. Decoders may read + * down to this pointer. It may be below iend[0]. + * oend [in] - The end of the output stream. op[3] must not cross oend. + * iend [in] - The end of each input stream. ip[i] may cross iend[i], + * as long as it is above ilowest, but that indicates corruption. + */ typedef struct { BYTE const* ip[4]; BYTE* op[4]; U64 bits[4]; void const* dt; - BYTE const* ilimit; + BYTE const* ilowest; BYTE* oend; BYTE const* iend[4]; -} HUF_DecompressAsmArgs; +} HUF_DecompressFastArgs; + +typedef void (*HUF_DecompressFastLoopFn)(HUF_DecompressFastArgs*); /* - * Initializes args for the asm decoding loop. - * @returns 0 on success - * 1 if the fallback implementation should be used. + * Initializes args for the fast decoding loop. + * @returns 1 on success + * 0 if the fallback implementation should be used. * Or an error code on failure. */ -static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, size_t dstSize, void const* src, size_t srcSize, const HUF_DTable* DTable) +static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* dst, size_t dstSize, void const* src, size_t srcSize, const HUF_DTable* DTable) { void const* dt = DTable + 1; U32 const dtLog = HUF_getDTableDesc(DTable).tableLog; - const BYTE* const ilimit = (const BYTE*)src + 6 + 8; + const BYTE* const istart = (const BYTE*)src; - BYTE* const oend = (BYTE*)dst + dstSize; + BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize); - /* The following condition is false on x32 platform, - * but HUF_asm is not compatible with this ABI */ - if (!(MEM_isLittleEndian() && !MEM_32bits())) return 1; + /* The fast decoding loop assumes 64-bit little-endian. + * This condition is false on x32. + */ + if (!MEM_isLittleEndian() || MEM_32bits()) + return 0; + + /* Avoid nullptr addition */ + if (dstSize == 0) + return 0; + assert(dst != NULL); /* strict minimum : jump table + 1 byte per stream */ if (srcSize < 10) @@ -181,11 +214,10 @@ static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, * On small inputs we don't have enough data to trigger the fast loop, so use the old decoder. */ if (dtLog != HUF_DECODER_FAST_TABLELOG) - return 1; + return 0; /* Read the jump table. */ { - const BYTE* const istart = (const BYTE*)src; size_t const length1 = MEM_readLE16(istart); size_t const length2 = MEM_readLE16(istart+2); size_t const length3 = MEM_readLE16(istart+4); @@ -195,13 +227,11 @@ static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, args->iend[2] = args->iend[1] + length2; args->iend[3] = args->iend[2] + length3; - /* HUF_initDStream() requires this, and this small of an input + /* HUF_initFastDStream() requires this, and this small of an input * won't benefit from the ASM loop anyways. - * length1 must be >= 16 so that ip[0] >= ilimit before the loop - * starts. */ - if (length1 < 16 || length2 < 8 || length3 < 8 || length4 < 8) - return 1; + if (length1 < 8 || length2 < 8 || length3 < 8 || length4 < 8) + return 0; if (length4 > srcSize) return ERROR(corruption_detected); /* overflow */ } /* ip[] contains the position that is currently loaded into bits[]. */ @@ -218,7 +248,7 @@ static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, /* No point to call the ASM loop for tiny outputs. */ if (args->op[3] >= oend) - return 1; + return 0; /* bits[] is the bit container. * It is read from the MSB down to the LSB. @@ -227,24 +257,25 @@ static size_t HUF_DecompressAsmArgs_init(HUF_DecompressAsmArgs* args, void* dst, * set, so that CountTrailingZeros(bits[]) can be used * to count how many bits we've consumed. */ - args->bits[0] = HUF_initDStream(args->ip[0]); - args->bits[1] = HUF_initDStream(args->ip[1]); - args->bits[2] = HUF_initDStream(args->ip[2]); - args->bits[3] = HUF_initDStream(args->ip[3]); - - /* If ip[] >= ilimit, it is guaranteed to be safe to - * reload bits[]. It may be beyond its section, but is - * guaranteed to be valid (>= istart). - */ - args->ilimit = ilimit; + args->bits[0] = HUF_initFastDStream(args->ip[0]); + args->bits[1] = HUF_initFastDStream(args->ip[1]); + args->bits[2] = HUF_initFastDStream(args->ip[2]); + args->bits[3] = HUF_initFastDStream(args->ip[3]); + + /* The decoders must be sure to never read beyond ilowest. + * This is lower than iend[0], but allowing decoders to read + * down to ilowest can allow an extra iteration or two in the + * fast loop. + */ + args->ilowest = istart; args->oend = oend; args->dt = dt; - return 0; + return 1; } -static size_t HUF_initRemainingDStream(BIT_DStream_t* bit, HUF_DecompressAsmArgs const* args, int stream, BYTE* segmentEnd) +static size_t HUF_initRemainingDStream(BIT_DStream_t* bit, HUF_DecompressFastArgs const* args, int stream, BYTE* segmentEnd) { /* Validate that we haven't overwritten. */ if (args->op[stream] > segmentEnd) @@ -258,15 +289,33 @@ static size_t HUF_initRemainingDStream(BIT_DStream_t* bit, HUF_DecompressAsmArgs return ERROR(corruption_detected); /* Construct the BIT_DStream_t. */ - bit->bitContainer = MEM_readLE64(args->ip[stream]); - bit->bitsConsumed = ZSTD_countTrailingZeros((size_t)args->bits[stream]); - bit->start = (const char*)args->iend[0]; + assert(sizeof(size_t) == 8); + bit->bitContainer = MEM_readLEST(args->ip[stream]); + bit->bitsConsumed = ZSTD_countTrailingZeros64(args->bits[stream]); + bit->start = (const char*)args->ilowest; bit->limitPtr = bit->start + sizeof(size_t); bit->ptr = (const char*)args->ip[stream]; return 0; } -#endif + +/* Calls X(N) for each stream 0, 1, 2, 3. */ +#define HUF_4X_FOR_EACH_STREAM(X) \ + do { \ + X(0); \ + X(1); \ + X(2); \ + X(3); \ + } while (0) + +/* Calls X(N, var) for each stream 0, 1, 2, 3. */ +#define HUF_4X_FOR_EACH_STREAM_WITH_VAR(X, var) \ + do { \ + X(0, (var)); \ + X(1, (var)); \ + X(2, (var)); \ + X(3, (var)); \ + } while (0) #ifndef HUF_FORCE_DECOMPRESS_X2 @@ -283,10 +332,11 @@ typedef struct { BYTE nbBits; BYTE byte; } HUF_DEltX1; /* single-symbol decodi static U64 HUF_DEltX1_set4(BYTE symbol, BYTE nbBits) { U64 D4; if (MEM_isLittleEndian()) { - D4 = (symbol << 8) + nbBits; + D4 = (U64)((symbol << 8) + nbBits); } else { - D4 = symbol + (nbBits << 8); + D4 = (U64)(symbol + (nbBits << 8)); } + assert(D4 < (1U << 16)); D4 *= 0x0001000100010001ULL; return D4; } @@ -329,13 +379,7 @@ typedef struct { BYTE huffWeight[HUF_SYMBOLVALUE_MAX + 1]; } HUF_ReadDTableX1_Workspace; - -size_t HUF_readDTableX1_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize) -{ - return HUF_readDTableX1_wksp_bmi2(DTable, src, srcSize, workSpace, wkspSize, /* bmi2 */ 0); -} - -size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int bmi2) +size_t HUF_readDTableX1_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int flags) { U32 tableLog = 0; U32 nbSymbols = 0; @@ -350,7 +394,7 @@ size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t sr DEBUG_STATIC_ASSERT(sizeof(DTableDesc) == sizeof(HUF_DTable)); /* ZSTD_memset(huffWeight, 0, sizeof(huffWeight)); */ /* is not necessary, even though some analyzer complain ... */ - iSize = HUF_readStats_wksp(wksp->huffWeight, HUF_SYMBOLVALUE_MAX + 1, wksp->rankVal, &nbSymbols, &tableLog, src, srcSize, wksp->statsWksp, sizeof(wksp->statsWksp), bmi2); + iSize = HUF_readStats_wksp(wksp->huffWeight, HUF_SYMBOLVALUE_MAX + 1, wksp->rankVal, &nbSymbols, &tableLog, src, srcSize, wksp->statsWksp, sizeof(wksp->statsWksp), flags); if (HUF_isError(iSize)) return iSize; @@ -377,9 +421,8 @@ size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t sr * rankStart[0] is not filled because there are no entries in the table for * weight 0. */ - { - int n; - int nextRankStart = 0; + { int n; + U32 nextRankStart = 0; int const unroll = 4; int const nLimit = (int)nbSymbols - unroll + 1; for (n=0; n<(int)tableLog+1; n++) { @@ -406,10 +449,9 @@ size_t HUF_readDTableX1_wksp_bmi2(HUF_DTable* DTable, const void* src, size_t sr * We can switch based on the length to a different inner loop which is * optimized for that particular case. */ - { - U32 w; - int symbol=wksp->rankVal[0]; - int rankStart=0; + { U32 w; + int symbol = wksp->rankVal[0]; + int rankStart = 0; for (w=1; wrankVal[w]; int const length = (1 << w) >> 1; @@ -483,15 +525,19 @@ HUF_decodeSymbolX1(BIT_DStream_t* Dstream, const HUF_DEltX1* dt, const U32 dtLog } #define HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr) \ - *ptr++ = HUF_decodeSymbolX1(DStreamPtr, dt, dtLog) + do { *ptr++ = HUF_decodeSymbolX1(DStreamPtr, dt, dtLog); } while (0) -#define HUF_DECODE_SYMBOLX1_1(ptr, DStreamPtr) \ - if (MEM_64bits() || (HUF_TABLELOG_MAX<=12)) \ - HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr) +#define HUF_DECODE_SYMBOLX1_1(ptr, DStreamPtr) \ + do { \ + if (MEM_64bits() || (HUF_TABLELOG_MAX<=12)) \ + HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr); \ + } while (0) -#define HUF_DECODE_SYMBOLX1_2(ptr, DStreamPtr) \ - if (MEM_64bits()) \ - HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr) +#define HUF_DECODE_SYMBOLX1_2(ptr, DStreamPtr) \ + do { \ + if (MEM_64bits()) \ + HUF_DECODE_SYMBOLX1_0(ptr, DStreamPtr); \ + } while (0) HINT_INLINE size_t HUF_decodeStreamX1(BYTE* p, BIT_DStream_t* const bitDPtr, BYTE* const pEnd, const HUF_DEltX1* const dt, const U32 dtLog) @@ -519,7 +565,7 @@ HUF_decodeStreamX1(BYTE* p, BIT_DStream_t* const bitDPtr, BYTE* const pEnd, cons while (p < pEnd) HUF_DECODE_SYMBOLX1_0(p, bitDPtr); - return pEnd-pStart; + return (size_t)(pEnd-pStart); } FORCE_INLINE_TEMPLATE size_t @@ -529,7 +575,7 @@ HUF_decompress1X1_usingDTable_internal_body( const HUF_DTable* DTable) { BYTE* op = (BYTE*)dst; - BYTE* const oend = op + dstSize; + BYTE* const oend = ZSTD_maybeNullPtrAdd(op, dstSize); const void* dtPtr = DTable + 1; const HUF_DEltX1* const dt = (const HUF_DEltX1*)dtPtr; BIT_DStream_t bitD; @@ -545,6 +591,10 @@ HUF_decompress1X1_usingDTable_internal_body( return dstSize; } +/* HUF_decompress4X1_usingDTable_internal_body(): + * Conditions : + * @dstSize >= 6 + */ FORCE_INLINE_TEMPLATE size_t HUF_decompress4X1_usingDTable_internal_body( void* dst, size_t dstSize, @@ -553,6 +603,7 @@ HUF_decompress4X1_usingDTable_internal_body( { /* Check */ if (cSrcSize < 10) return ERROR(corruption_detected); /* strict minimum : jump table + 1 byte per stream */ + if (dstSize < 6) return ERROR(corruption_detected); /* stream 4-split doesn't work */ { const BYTE* const istart = (const BYTE*) cSrc; BYTE* const ostart = (BYTE*) dst; @@ -588,6 +639,7 @@ HUF_decompress4X1_usingDTable_internal_body( if (length4 > cSrcSize) return ERROR(corruption_detected); /* overflow */ if (opStart4 > oend) return ERROR(corruption_detected); /* overflow */ + assert(dstSize >= 6); /* validated above */ CHECK_F( BIT_initDStream(&bitD1, istart1, length1) ); CHECK_F( BIT_initDStream(&bitD2, istart2, length2) ); CHECK_F( BIT_initDStream(&bitD3, istart3, length3) ); @@ -650,52 +702,173 @@ size_t HUF_decompress4X1_usingDTable_internal_bmi2(void* dst, size_t dstSize, vo } #endif -#if HUF_NEED_DEFAULT_FUNCTION static size_t HUF_decompress4X1_usingDTable_internal_default(void* dst, size_t dstSize, void const* cSrc, size_t cSrcSize, HUF_DTable const* DTable) { return HUF_decompress4X1_usingDTable_internal_body(dst, dstSize, cSrc, cSrcSize, DTable); } -#endif #if ZSTD_ENABLE_ASM_X86_64_BMI2 -HUF_ASM_DECL void HUF_decompress4X1_usingDTable_internal_bmi2_asm_loop(HUF_DecompressAsmArgs* args) ZSTDLIB_HIDDEN; +HUF_ASM_DECL void HUF_decompress4X1_usingDTable_internal_fast_asm_loop(HUF_DecompressFastArgs* args) ZSTDLIB_HIDDEN; + +#endif + +static HUF_FAST_BMI2_ATTRS +void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* args) +{ + U64 bits[4]; + BYTE const* ip[4]; + BYTE* op[4]; + U16 const* const dtable = (U16 const*)args->dt; + BYTE* const oend = args->oend; + BYTE const* const ilowest = args->ilowest; + + /* Copy the arguments to local variables */ + ZSTD_memcpy(&bits, &args->bits, sizeof(bits)); + ZSTD_memcpy((void*)(&ip), &args->ip, sizeof(ip)); + ZSTD_memcpy(&op, &args->op, sizeof(op)); + + assert(MEM_isLittleEndian()); + assert(!MEM_32bits()); + + for (;;) { + BYTE* olimit; + int stream; + + /* Assert loop preconditions */ +#ifndef NDEBUG + for (stream = 0; stream < 4; ++stream) { + assert(op[stream] <= (stream == 3 ? oend : op[stream + 1])); + assert(ip[stream] >= ilowest); + } +#endif + /* Compute olimit */ + { + /* Each iteration produces 5 output symbols per stream */ + size_t const oiters = (size_t)(oend - op[3]) / 5; + /* Each iteration consumes up to 11 bits * 5 = 55 bits < 7 bytes + * per stream. + */ + size_t const iiters = (size_t)(ip[0] - ilowest) / 7; + /* We can safely run iters iterations before running bounds checks */ + size_t const iters = MIN(oiters, iiters); + size_t const symbols = iters * 5; + + /* We can simply check that op[3] < olimit, instead of checking all + * of our bounds, since we can't hit the other bounds until we've run + * iters iterations, which only happens when op[3] == olimit. + */ + olimit = op[3] + symbols; + + /* Exit fast decoding loop once we reach the end. */ + if (op[3] == olimit) + break; + + /* Exit the decoding loop if any input pointer has crossed the + * previous one. This indicates corruption, and a precondition + * to our loop is that ip[i] >= ip[0]. + */ + for (stream = 1; stream < 4; ++stream) { + if (ip[stream] < ip[stream - 1]) + goto _out; + } + } + +#ifndef NDEBUG + for (stream = 1; stream < 4; ++stream) { + assert(ip[stream] >= ip[stream - 1]); + } +#endif + +#define HUF_4X1_DECODE_SYMBOL(_stream, _symbol) \ + do { \ + int const index = (int)(bits[(_stream)] >> 53); \ + int const entry = (int)dtable[index]; \ + bits[(_stream)] <<= (entry & 0x3F); \ + op[(_stream)][(_symbol)] = (BYTE)((entry >> 8) & 0xFF); \ + } while (0) + +#define HUF_4X1_RELOAD_STREAM(_stream) \ + do { \ + int const ctz = ZSTD_countTrailingZeros64(bits[(_stream)]); \ + int const nbBits = ctz & 7; \ + int const nbBytes = ctz >> 3; \ + op[(_stream)] += 5; \ + ip[(_stream)] -= nbBytes; \ + bits[(_stream)] = MEM_read64(ip[(_stream)]) | 1; \ + bits[(_stream)] <<= nbBits; \ + } while (0) + + /* Manually unroll the loop because compilers don't consistently + * unroll the inner loops, which destroys performance. + */ + do { + /* Decode 5 symbols in each of the 4 streams */ + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 0); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 1); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 2); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 3); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X1_DECODE_SYMBOL, 4); + + /* Reload each of the 4 the bitstreams */ + HUF_4X_FOR_EACH_STREAM(HUF_4X1_RELOAD_STREAM); + } while (op[3] < olimit); + +#undef HUF_4X1_DECODE_SYMBOL +#undef HUF_4X1_RELOAD_STREAM + } -static HUF_ASM_X86_64_BMI2_ATTRS +_out: + + /* Save the final values of each of the state variables back to args. */ + ZSTD_memcpy(&args->bits, &bits, sizeof(bits)); + ZSTD_memcpy((void*)(&args->ip), &ip, sizeof(ip)); + ZSTD_memcpy(&args->op, &op, sizeof(op)); +} + +/* + * @returns @p dstSize on success (>= 6) + * 0 if the fallback implementation should be used + * An error if an error occurred + */ +static HUF_FAST_BMI2_ATTRS size_t -HUF_decompress4X1_usingDTable_internal_bmi2_asm( +HUF_decompress4X1_usingDTable_internal_fast( void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) + const HUF_DTable* DTable, + HUF_DecompressFastLoopFn loopFn) { void const* dt = DTable + 1; - const BYTE* const iend = (const BYTE*)cSrc + 6; - BYTE* const oend = (BYTE*)dst + dstSize; - HUF_DecompressAsmArgs args; - { - size_t const ret = HUF_DecompressAsmArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable); - FORWARD_IF_ERROR(ret, "Failed to init asm args"); - if (ret != 0) - return HUF_decompress4X1_usingDTable_internal_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); + BYTE const* const ilowest = (BYTE const*)cSrc; + BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize); + HUF_DecompressFastArgs args; + { size_t const ret = HUF_DecompressFastArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable); + FORWARD_IF_ERROR(ret, "Failed to init fast loop args"); + if (ret == 0) + return 0; } - assert(args.ip[0] >= args.ilimit); - HUF_decompress4X1_usingDTable_internal_bmi2_asm_loop(&args); + assert(args.ip[0] >= args.ilowest); + loopFn(&args); - /* Our loop guarantees that ip[] >= ilimit and that we haven't + /* Our loop guarantees that ip[] >= ilowest and that we haven't * overwritten any op[]. */ - assert(args.ip[0] >= iend); - assert(args.ip[1] >= iend); - assert(args.ip[2] >= iend); - assert(args.ip[3] >= iend); + assert(args.ip[0] >= ilowest); + assert(args.ip[0] >= ilowest); + assert(args.ip[1] >= ilowest); + assert(args.ip[2] >= ilowest); + assert(args.ip[3] >= ilowest); assert(args.op[3] <= oend); - (void)iend; + + assert(ilowest == args.ilowest); + assert(ilowest + 6 == args.iend[0]); + (void)ilowest; /* finish bit streams one by one. */ - { - size_t const segmentSize = (dstSize+3) / 4; + { size_t const segmentSize = (dstSize+3) / 4; BYTE* segmentEnd = (BYTE*)dst; int i; for (i = 0; i < 4; ++i) { @@ -712,97 +885,59 @@ HUF_decompress4X1_usingDTable_internal_bmi2_asm( } /* decoded size */ + assert(dstSize != 0); return dstSize; } -#endif /* ZSTD_ENABLE_ASM_X86_64_BMI2 */ - -typedef size_t (*HUF_decompress_usingDTable_t)(void *dst, size_t dstSize, - const void *cSrc, - size_t cSrcSize, - const HUF_DTable *DTable); HUF_DGEN(HUF_decompress1X1_usingDTable_internal) static size_t HUF_decompress4X1_usingDTable_internal(void* dst, size_t dstSize, void const* cSrc, - size_t cSrcSize, HUF_DTable const* DTable, int bmi2) + size_t cSrcSize, HUF_DTable const* DTable, int flags) { + HUF_DecompressUsingDTableFn fallbackFn = HUF_decompress4X1_usingDTable_internal_default; + HUF_DecompressFastLoopFn loopFn = HUF_decompress4X1_usingDTable_internal_fast_c_loop; + #if DYNAMIC_BMI2 - if (bmi2) { + if (flags & HUF_flags_bmi2) { + fallbackFn = HUF_decompress4X1_usingDTable_internal_bmi2; # if ZSTD_ENABLE_ASM_X86_64_BMI2 - return HUF_decompress4X1_usingDTable_internal_bmi2_asm(dst, dstSize, cSrc, cSrcSize, DTable); -# else - return HUF_decompress4X1_usingDTable_internal_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); + if (!(flags & HUF_flags_disableAsm)) { + loopFn = HUF_decompress4X1_usingDTable_internal_fast_asm_loop; + } # endif + } else { + return fallbackFn(dst, dstSize, cSrc, cSrcSize, DTable); } -#else - (void)bmi2; #endif #if ZSTD_ENABLE_ASM_X86_64_BMI2 && defined(__BMI2__) - return HUF_decompress4X1_usingDTable_internal_bmi2_asm(dst, dstSize, cSrc, cSrcSize, DTable); -#else - return HUF_decompress4X1_usingDTable_internal_default(dst, dstSize, cSrc, cSrcSize, DTable); + if (!(flags & HUF_flags_disableAsm)) { + loopFn = HUF_decompress4X1_usingDTable_internal_fast_asm_loop; + } #endif -} - - -size_t HUF_decompress1X1_usingDTable( - void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) -{ - DTableDesc dtd = HUF_getDTableDesc(DTable); - if (dtd.tableType != 0) return ERROR(GENERIC); - return HUF_decompress1X1_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -} -size_t HUF_decompress1X1_DCtx_wksp(HUF_DTable* DCtx, void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize) -{ - const BYTE* ip = (const BYTE*) cSrc; - - size_t const hSize = HUF_readDTableX1_wksp(DCtx, cSrc, cSrcSize, workSpace, wkspSize); - if (HUF_isError(hSize)) return hSize; - if (hSize >= cSrcSize) return ERROR(srcSize_wrong); - ip += hSize; cSrcSize -= hSize; - - return HUF_decompress1X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, DCtx, /* bmi2 */ 0); -} - - -size_t HUF_decompress4X1_usingDTable( - void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) -{ - DTableDesc dtd = HUF_getDTableDesc(DTable); - if (dtd.tableType != 0) return ERROR(GENERIC); - return HUF_decompress4X1_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); + if (HUF_ENABLE_FAST_DECODE && !(flags & HUF_flags_disableFast)) { + size_t const ret = HUF_decompress4X1_usingDTable_internal_fast(dst, dstSize, cSrc, cSrcSize, DTable, loopFn); + if (ret != 0) + return ret; + } + return fallbackFn(dst, dstSize, cSrc, cSrcSize, DTable); } -static size_t HUF_decompress4X1_DCtx_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, +static size_t HUF_decompress4X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize, int bmi2) + void* workSpace, size_t wkspSize, int flags) { const BYTE* ip = (const BYTE*) cSrc; - size_t const hSize = HUF_readDTableX1_wksp_bmi2(dctx, cSrc, cSrcSize, workSpace, wkspSize, bmi2); + size_t const hSize = HUF_readDTableX1_wksp(dctx, cSrc, cSrcSize, workSpace, wkspSize, flags); if (HUF_isError(hSize)) return hSize; if (hSize >= cSrcSize) return ERROR(srcSize_wrong); ip += hSize; cSrcSize -= hSize; - return HUF_decompress4X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, bmi2); -} - -size_t HUF_decompress4X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize) -{ - return HUF_decompress4X1_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, 0); + return HUF_decompress4X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, flags); } - #endif /* HUF_FORCE_DECOMPRESS_X2 */ @@ -985,7 +1120,7 @@ static void HUF_fillDTableX2Level2(HUF_DEltX2* DTable, U32 targetLog, const U32 static void HUF_fillDTableX2(HUF_DEltX2* DTable, const U32 targetLog, const sortedSymbol_t* sortedList, - const U32* rankStart, rankValCol_t *rankValOrigin, const U32 maxWeight, + const U32* rankStart, rankValCol_t* rankValOrigin, const U32 maxWeight, const U32 nbBitsBaseline) { U32* const rankVal = rankValOrigin[0]; @@ -1040,14 +1175,7 @@ typedef struct { size_t HUF_readDTableX2_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, - void* workSpace, size_t wkspSize) -{ - return HUF_readDTableX2_wksp_bmi2(DTable, src, srcSize, workSpace, wkspSize, /* bmi2 */ 0); -} - -size_t HUF_readDTableX2_wksp_bmi2(HUF_DTable* DTable, - const void* src, size_t srcSize, - void* workSpace, size_t wkspSize, int bmi2) + void* workSpace, size_t wkspSize, int flags) { U32 tableLog, maxW, nbSymbols; DTableDesc dtd = HUF_getDTableDesc(DTable); @@ -1069,7 +1197,7 @@ size_t HUF_readDTableX2_wksp_bmi2(HUF_DTable* DTable, if (maxTableLog > HUF_TABLELOG_MAX) return ERROR(tableLog_tooLarge); /* ZSTD_memset(weightList, 0, sizeof(weightList)); */ /* is not necessary, even though some analyzer complain ... */ - iSize = HUF_readStats_wksp(wksp->weightList, HUF_SYMBOLVALUE_MAX + 1, wksp->rankStats, &nbSymbols, &tableLog, src, srcSize, wksp->calleeWksp, sizeof(wksp->calleeWksp), bmi2); + iSize = HUF_readStats_wksp(wksp->weightList, HUF_SYMBOLVALUE_MAX + 1, wksp->rankStats, &nbSymbols, &tableLog, src, srcSize, wksp->calleeWksp, sizeof(wksp->calleeWksp), flags); if (HUF_isError(iSize)) return iSize; /* check result */ @@ -1159,15 +1287,19 @@ HUF_decodeLastSymbolX2(void* op, BIT_DStream_t* DStream, const HUF_DEltX2* dt, c } #define HUF_DECODE_SYMBOLX2_0(ptr, DStreamPtr) \ - ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog) + do { ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog); } while (0) -#define HUF_DECODE_SYMBOLX2_1(ptr, DStreamPtr) \ - if (MEM_64bits() || (HUF_TABLELOG_MAX<=12)) \ - ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog) +#define HUF_DECODE_SYMBOLX2_1(ptr, DStreamPtr) \ + do { \ + if (MEM_64bits() || (HUF_TABLELOG_MAX<=12)) \ + ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog); \ + } while (0) -#define HUF_DECODE_SYMBOLX2_2(ptr, DStreamPtr) \ - if (MEM_64bits()) \ - ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog) +#define HUF_DECODE_SYMBOLX2_2(ptr, DStreamPtr) \ + do { \ + if (MEM_64bits()) \ + ptr += HUF_decodeSymbolX2(ptr, DStreamPtr, dt, dtLog); \ + } while (0) HINT_INLINE size_t HUF_decodeStreamX2(BYTE* p, BIT_DStream_t* bitDPtr, BYTE* const pEnd, @@ -1227,7 +1359,7 @@ HUF_decompress1X2_usingDTable_internal_body( /* decode */ { BYTE* const ostart = (BYTE*) dst; - BYTE* const oend = ostart + dstSize; + BYTE* const oend = ZSTD_maybeNullPtrAdd(ostart, dstSize); const void* const dtPtr = DTable+1; /* force compiler to not use strict-aliasing */ const HUF_DEltX2* const dt = (const HUF_DEltX2*)dtPtr; DTableDesc const dtd = HUF_getDTableDesc(DTable); @@ -1240,6 +1372,11 @@ HUF_decompress1X2_usingDTable_internal_body( /* decoded size */ return dstSize; } + +/* HUF_decompress4X2_usingDTable_internal_body(): + * Conditions: + * @dstSize >= 6 + */ FORCE_INLINE_TEMPLATE size_t HUF_decompress4X2_usingDTable_internal_body( void* dst, size_t dstSize, @@ -1247,6 +1384,7 @@ HUF_decompress4X2_usingDTable_internal_body( const HUF_DTable* DTable) { if (cSrcSize < 10) return ERROR(corruption_detected); /* strict minimum : jump table + 1 byte per stream */ + if (dstSize < 6) return ERROR(corruption_detected); /* stream 4-split doesn't work */ { const BYTE* const istart = (const BYTE*) cSrc; BYTE* const ostart = (BYTE*) dst; @@ -1280,8 +1418,9 @@ HUF_decompress4X2_usingDTable_internal_body( DTableDesc const dtd = HUF_getDTableDesc(DTable); U32 const dtLog = dtd.tableLog; - if (length4 > cSrcSize) return ERROR(corruption_detected); /* overflow */ - if (opStart4 > oend) return ERROR(corruption_detected); /* overflow */ + if (length4 > cSrcSize) return ERROR(corruption_detected); /* overflow */ + if (opStart4 > oend) return ERROR(corruption_detected); /* overflow */ + assert(dstSize >= 6 /* validated above */); CHECK_F( BIT_initDStream(&bitD1, istart1, length1) ); CHECK_F( BIT_initDStream(&bitD2, istart2, length2) ); CHECK_F( BIT_initDStream(&bitD3, istart3, length3) ); @@ -1366,44 +1505,191 @@ size_t HUF_decompress4X2_usingDTable_internal_bmi2(void* dst, size_t dstSize, vo } #endif -#if HUF_NEED_DEFAULT_FUNCTION static size_t HUF_decompress4X2_usingDTable_internal_default(void* dst, size_t dstSize, void const* cSrc, size_t cSrcSize, HUF_DTable const* DTable) { return HUF_decompress4X2_usingDTable_internal_body(dst, dstSize, cSrc, cSrcSize, DTable); } -#endif #if ZSTD_ENABLE_ASM_X86_64_BMI2 -HUF_ASM_DECL void HUF_decompress4X2_usingDTable_internal_bmi2_asm_loop(HUF_DecompressAsmArgs* args) ZSTDLIB_HIDDEN; +HUF_ASM_DECL void HUF_decompress4X2_usingDTable_internal_fast_asm_loop(HUF_DecompressFastArgs* args) ZSTDLIB_HIDDEN; + +#endif + +static HUF_FAST_BMI2_ATTRS +void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs* args) +{ + U64 bits[4]; + BYTE const* ip[4]; + BYTE* op[4]; + BYTE* oend[4]; + HUF_DEltX2 const* const dtable = (HUF_DEltX2 const*)args->dt; + BYTE const* const ilowest = args->ilowest; + + /* Copy the arguments to local registers. */ + ZSTD_memcpy(&bits, &args->bits, sizeof(bits)); + ZSTD_memcpy((void*)(&ip), &args->ip, sizeof(ip)); + ZSTD_memcpy(&op, &args->op, sizeof(op)); + + oend[0] = op[1]; + oend[1] = op[2]; + oend[2] = op[3]; + oend[3] = args->oend; + + assert(MEM_isLittleEndian()); + assert(!MEM_32bits()); + + for (;;) { + BYTE* olimit; + int stream; + + /* Assert loop preconditions */ +#ifndef NDEBUG + for (stream = 0; stream < 4; ++stream) { + assert(op[stream] <= oend[stream]); + assert(ip[stream] >= ilowest); + } +#endif + /* Compute olimit */ + { + /* Each loop does 5 table lookups for each of the 4 streams. + * Each table lookup consumes up to 11 bits of input, and produces + * up to 2 bytes of output. + */ + /* We can consume up to 7 bytes of input per iteration per stream. + * We also know that each input pointer is >= ip[0]. So we can run + * iters loops before running out of input. + */ + size_t iters = (size_t)(ip[0] - ilowest) / 7; + /* Each iteration can produce up to 10 bytes of output per stream. + * Each output stream my advance at different rates. So take the + * minimum number of safe iterations among all the output streams. + */ + for (stream = 0; stream < 4; ++stream) { + size_t const oiters = (size_t)(oend[stream] - op[stream]) / 10; + iters = MIN(iters, oiters); + } + + /* Each iteration produces at least 5 output symbols. So until + * op[3] crosses olimit, we know we haven't executed iters + * iterations yet. This saves us maintaining an iters counter, + * at the expense of computing the remaining # of iterations + * more frequently. + */ + olimit = op[3] + (iters * 5); + + /* Exit the fast decoding loop once we reach the end. */ + if (op[3] == olimit) + break; + + /* Exit the decoding loop if any input pointer has crossed the + * previous one. This indicates corruption, and a precondition + * to our loop is that ip[i] >= ip[0]. + */ + for (stream = 1; stream < 4; ++stream) { + if (ip[stream] < ip[stream - 1]) + goto _out; + } + } + +#ifndef NDEBUG + for (stream = 1; stream < 4; ++stream) { + assert(ip[stream] >= ip[stream - 1]); + } +#endif -static HUF_ASM_X86_64_BMI2_ATTRS size_t -HUF_decompress4X2_usingDTable_internal_bmi2_asm( +#define HUF_4X2_DECODE_SYMBOL(_stream, _decode3) \ + do { \ + if ((_decode3) || (_stream) != 3) { \ + int const index = (int)(bits[(_stream)] >> 53); \ + HUF_DEltX2 const entry = dtable[index]; \ + MEM_write16(op[(_stream)], entry.sequence); \ + bits[(_stream)] <<= (entry.nbBits) & 0x3F; \ + op[(_stream)] += (entry.length); \ + } \ + } while (0) + +#define HUF_4X2_RELOAD_STREAM(_stream) \ + do { \ + HUF_4X2_DECODE_SYMBOL(3, 1); \ + { \ + int const ctz = ZSTD_countTrailingZeros64(bits[(_stream)]); \ + int const nbBits = ctz & 7; \ + int const nbBytes = ctz >> 3; \ + ip[(_stream)] -= nbBytes; \ + bits[(_stream)] = MEM_read64(ip[(_stream)]) | 1; \ + bits[(_stream)] <<= nbBits; \ + } \ + } while (0) + + /* Manually unroll the loop because compilers don't consistently + * unroll the inner loops, which destroys performance. + */ + do { + /* Decode 5 symbols from each of the first 3 streams. + * The final stream will be decoded during the reload phase + * to reduce register pressure. + */ + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); + HUF_4X_FOR_EACH_STREAM_WITH_VAR(HUF_4X2_DECODE_SYMBOL, 0); + + /* Decode one symbol from the final stream */ + HUF_4X2_DECODE_SYMBOL(3, 1); + + /* Decode 4 symbols from the final stream & reload bitstreams. + * The final stream is reloaded last, meaning that all 5 symbols + * are decoded from the final stream before it is reloaded. + */ + HUF_4X_FOR_EACH_STREAM(HUF_4X2_RELOAD_STREAM); + } while (op[3] < olimit); + } + +#undef HUF_4X2_DECODE_SYMBOL +#undef HUF_4X2_RELOAD_STREAM + +_out: + + /* Save the final values of each of the state variables back to args. */ + ZSTD_memcpy(&args->bits, &bits, sizeof(bits)); + ZSTD_memcpy((void*)(&args->ip), &ip, sizeof(ip)); + ZSTD_memcpy(&args->op, &op, sizeof(op)); +} + + +static HUF_FAST_BMI2_ATTRS size_t +HUF_decompress4X2_usingDTable_internal_fast( void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) { + const HUF_DTable* DTable, + HUF_DecompressFastLoopFn loopFn) { void const* dt = DTable + 1; - const BYTE* const iend = (const BYTE*)cSrc + 6; - BYTE* const oend = (BYTE*)dst + dstSize; - HUF_DecompressAsmArgs args; + const BYTE* const ilowest = (const BYTE*)cSrc; + BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize); + HUF_DecompressFastArgs args; { - size_t const ret = HUF_DecompressAsmArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable); + size_t const ret = HUF_DecompressFastArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable); FORWARD_IF_ERROR(ret, "Failed to init asm args"); - if (ret != 0) - return HUF_decompress4X2_usingDTable_internal_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); + if (ret == 0) + return 0; } - assert(args.ip[0] >= args.ilimit); - HUF_decompress4X2_usingDTable_internal_bmi2_asm_loop(&args); + assert(args.ip[0] >= args.ilowest); + loopFn(&args); /* note : op4 already verified within main loop */ - assert(args.ip[0] >= iend); - assert(args.ip[1] >= iend); - assert(args.ip[2] >= iend); - assert(args.ip[3] >= iend); + assert(args.ip[0] >= ilowest); + assert(args.ip[1] >= ilowest); + assert(args.ip[2] >= ilowest); + assert(args.ip[3] >= ilowest); assert(args.op[3] <= oend); - (void)iend; + + assert(ilowest == args.ilowest); + assert(ilowest + 6 == args.iend[0]); + (void)ilowest; /* finish bitStreams one by one */ { @@ -1426,91 +1712,72 @@ HUF_decompress4X2_usingDTable_internal_bmi2_asm( /* decoded size */ return dstSize; } -#endif /* ZSTD_ENABLE_ASM_X86_64_BMI2 */ static size_t HUF_decompress4X2_usingDTable_internal(void* dst, size_t dstSize, void const* cSrc, - size_t cSrcSize, HUF_DTable const* DTable, int bmi2) + size_t cSrcSize, HUF_DTable const* DTable, int flags) { + HUF_DecompressUsingDTableFn fallbackFn = HUF_decompress4X2_usingDTable_internal_default; + HUF_DecompressFastLoopFn loopFn = HUF_decompress4X2_usingDTable_internal_fast_c_loop; + #if DYNAMIC_BMI2 - if (bmi2) { + if (flags & HUF_flags_bmi2) { + fallbackFn = HUF_decompress4X2_usingDTable_internal_bmi2; # if ZSTD_ENABLE_ASM_X86_64_BMI2 - return HUF_decompress4X2_usingDTable_internal_bmi2_asm(dst, dstSize, cSrc, cSrcSize, DTable); -# else - return HUF_decompress4X2_usingDTable_internal_bmi2(dst, dstSize, cSrc, cSrcSize, DTable); + if (!(flags & HUF_flags_disableAsm)) { + loopFn = HUF_decompress4X2_usingDTable_internal_fast_asm_loop; + } # endif + } else { + return fallbackFn(dst, dstSize, cSrc, cSrcSize, DTable); } -#else - (void)bmi2; #endif #if ZSTD_ENABLE_ASM_X86_64_BMI2 && defined(__BMI2__) - return HUF_decompress4X2_usingDTable_internal_bmi2_asm(dst, dstSize, cSrc, cSrcSize, DTable); -#else - return HUF_decompress4X2_usingDTable_internal_default(dst, dstSize, cSrc, cSrcSize, DTable); + if (!(flags & HUF_flags_disableAsm)) { + loopFn = HUF_decompress4X2_usingDTable_internal_fast_asm_loop; + } #endif + + if (HUF_ENABLE_FAST_DECODE && !(flags & HUF_flags_disableFast)) { + size_t const ret = HUF_decompress4X2_usingDTable_internal_fast(dst, dstSize, cSrc, cSrcSize, DTable, loopFn); + if (ret != 0) + return ret; + } + return fallbackFn(dst, dstSize, cSrc, cSrcSize, DTable); } HUF_DGEN(HUF_decompress1X2_usingDTable_internal) -size_t HUF_decompress1X2_usingDTable( - void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) -{ - DTableDesc dtd = HUF_getDTableDesc(DTable); - if (dtd.tableType != 1) return ERROR(GENERIC); - return HUF_decompress1X2_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -} - size_t HUF_decompress1X2_DCtx_wksp(HUF_DTable* DCtx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize) + void* workSpace, size_t wkspSize, int flags) { const BYTE* ip = (const BYTE*) cSrc; size_t const hSize = HUF_readDTableX2_wksp(DCtx, cSrc, cSrcSize, - workSpace, wkspSize); + workSpace, wkspSize, flags); if (HUF_isError(hSize)) return hSize; if (hSize >= cSrcSize) return ERROR(srcSize_wrong); ip += hSize; cSrcSize -= hSize; - return HUF_decompress1X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, DCtx, /* bmi2 */ 0); + return HUF_decompress1X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, DCtx, flags); } - -size_t HUF_decompress4X2_usingDTable( - void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) -{ - DTableDesc dtd = HUF_getDTableDesc(DTable); - if (dtd.tableType != 1) return ERROR(GENERIC); - return HUF_decompress4X2_usingDTable_internal(dst, dstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -} - -static size_t HUF_decompress4X2_DCtx_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, +static size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize, int bmi2) + void* workSpace, size_t wkspSize, int flags) { const BYTE* ip = (const BYTE*) cSrc; size_t hSize = HUF_readDTableX2_wksp(dctx, cSrc, cSrcSize, - workSpace, wkspSize); + workSpace, wkspSize, flags); if (HUF_isError(hSize)) return hSize; if (hSize >= cSrcSize) return ERROR(srcSize_wrong); ip += hSize; cSrcSize -= hSize; - return HUF_decompress4X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, bmi2); + return HUF_decompress4X2_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, flags); } -size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, - const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize) -{ - return HUF_decompress4X2_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, /* bmi2 */ 0); -} - - #endif /* HUF_FORCE_DECOMPRESS_X1 */ @@ -1518,44 +1785,6 @@ size_t HUF_decompress4X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, /* Universal decompression selectors */ /* ***********************************/ -size_t HUF_decompress1X_usingDTable(void* dst, size_t maxDstSize, - const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) -{ - DTableDesc const dtd = HUF_getDTableDesc(DTable); -#if defined(HUF_FORCE_DECOMPRESS_X1) - (void)dtd; - assert(dtd.tableType == 0); - return HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -#elif defined(HUF_FORCE_DECOMPRESS_X2) - (void)dtd; - assert(dtd.tableType == 1); - return HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -#else - return dtd.tableType ? HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0) : - HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -#endif -} - -size_t HUF_decompress4X_usingDTable(void* dst, size_t maxDstSize, - const void* cSrc, size_t cSrcSize, - const HUF_DTable* DTable) -{ - DTableDesc const dtd = HUF_getDTableDesc(DTable); -#if defined(HUF_FORCE_DECOMPRESS_X1) - (void)dtd; - assert(dtd.tableType == 0); - return HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -#elif defined(HUF_FORCE_DECOMPRESS_X2) - (void)dtd; - assert(dtd.tableType == 1); - return HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -#else - return dtd.tableType ? HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0) : - HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, /* bmi2 */ 0); -#endif -} - #if !defined(HUF_FORCE_DECOMPRESS_X1) && !defined(HUF_FORCE_DECOMPRESS_X2) typedef struct { U32 tableTime; U32 decode256Time; } algo_time_t; @@ -1610,36 +1839,9 @@ U32 HUF_selectDecoder (size_t dstSize, size_t cSrcSize) #endif } - -size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable* dctx, void* dst, - size_t dstSize, const void* cSrc, - size_t cSrcSize, void* workSpace, - size_t wkspSize) -{ - /* validation checks */ - if (dstSize == 0) return ERROR(dstSize_tooSmall); - if (cSrcSize == 0) return ERROR(corruption_detected); - - { U32 const algoNb = HUF_selectDecoder(dstSize, cSrcSize); -#if defined(HUF_FORCE_DECOMPRESS_X1) - (void)algoNb; - assert(algoNb == 0); - return HUF_decompress4X1_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize); -#elif defined(HUF_FORCE_DECOMPRESS_X2) - (void)algoNb; - assert(algoNb == 1); - return HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize); -#else - return algoNb ? HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, - cSrcSize, workSpace, wkspSize): - HUF_decompress4X1_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize); -#endif - } -} - size_t HUF_decompress1X_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, - void* workSpace, size_t wkspSize) + void* workSpace, size_t wkspSize, int flags) { /* validation checks */ if (dstSize == 0) return ERROR(dstSize_tooSmall); @@ -1652,71 +1854,71 @@ size_t HUF_decompress1X_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, (void)algoNb; assert(algoNb == 0); return HUF_decompress1X1_DCtx_wksp(dctx, dst, dstSize, cSrc, - cSrcSize, workSpace, wkspSize); + cSrcSize, workSpace, wkspSize, flags); #elif defined(HUF_FORCE_DECOMPRESS_X2) (void)algoNb; assert(algoNb == 1); return HUF_decompress1X2_DCtx_wksp(dctx, dst, dstSize, cSrc, - cSrcSize, workSpace, wkspSize); + cSrcSize, workSpace, wkspSize, flags); #else return algoNb ? HUF_decompress1X2_DCtx_wksp(dctx, dst, dstSize, cSrc, - cSrcSize, workSpace, wkspSize): + cSrcSize, workSpace, wkspSize, flags): HUF_decompress1X1_DCtx_wksp(dctx, dst, dstSize, cSrc, - cSrcSize, workSpace, wkspSize); + cSrcSize, workSpace, wkspSize, flags); #endif } } -size_t HUF_decompress1X_usingDTable_bmi2(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int bmi2) +size_t HUF_decompress1X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags) { DTableDesc const dtd = HUF_getDTableDesc(DTable); #if defined(HUF_FORCE_DECOMPRESS_X1) (void)dtd; assert(dtd.tableType == 0); - return HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); + return HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); #elif defined(HUF_FORCE_DECOMPRESS_X2) (void)dtd; assert(dtd.tableType == 1); - return HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); + return HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); #else - return dtd.tableType ? HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2) : - HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); + return dtd.tableType ? HUF_decompress1X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags) : + HUF_decompress1X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); #endif } #ifndef HUF_FORCE_DECOMPRESS_X2 -size_t HUF_decompress1X1_DCtx_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int bmi2) +size_t HUF_decompress1X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags) { const BYTE* ip = (const BYTE*) cSrc; - size_t const hSize = HUF_readDTableX1_wksp_bmi2(dctx, cSrc, cSrcSize, workSpace, wkspSize, bmi2); + size_t const hSize = HUF_readDTableX1_wksp(dctx, cSrc, cSrcSize, workSpace, wkspSize, flags); if (HUF_isError(hSize)) return hSize; if (hSize >= cSrcSize) return ERROR(srcSize_wrong); ip += hSize; cSrcSize -= hSize; - return HUF_decompress1X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, bmi2); + return HUF_decompress1X1_usingDTable_internal(dst, dstSize, ip, cSrcSize, dctx, flags); } #endif -size_t HUF_decompress4X_usingDTable_bmi2(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int bmi2) +size_t HUF_decompress4X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags) { DTableDesc const dtd = HUF_getDTableDesc(DTable); #if defined(HUF_FORCE_DECOMPRESS_X1) (void)dtd; assert(dtd.tableType == 0); - return HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); + return HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); #elif defined(HUF_FORCE_DECOMPRESS_X2) (void)dtd; assert(dtd.tableType == 1); - return HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); + return HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); #else - return dtd.tableType ? HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2) : - HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, bmi2); + return dtd.tableType ? HUF_decompress4X2_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags) : + HUF_decompress4X1_usingDTable_internal(dst, maxDstSize, cSrc, cSrcSize, DTable, flags); #endif } -size_t HUF_decompress4X_hufOnly_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int bmi2) +size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags) { /* validation checks */ if (dstSize == 0) return ERROR(dstSize_tooSmall); @@ -1726,15 +1928,14 @@ size_t HUF_decompress4X_hufOnly_wksp_bmi2(HUF_DTable* dctx, void* dst, size_t ds #if defined(HUF_FORCE_DECOMPRESS_X1) (void)algoNb; assert(algoNb == 0); - return HUF_decompress4X1_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, bmi2); + return HUF_decompress4X1_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, flags); #elif defined(HUF_FORCE_DECOMPRESS_X2) (void)algoNb; assert(algoNb == 1); - return HUF_decompress4X2_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, bmi2); + return HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, flags); #else - return algoNb ? HUF_decompress4X2_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, bmi2) : - HUF_decompress4X1_DCtx_wksp_bmi2(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, bmi2); + return algoNb ? HUF_decompress4X2_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, flags) : + HUF_decompress4X1_DCtx_wksp(dctx, dst, dstSize, cSrc, cSrcSize, workSpace, wkspSize, flags); #endif } } - diff --git a/lib/zstd/decompress/zstd_ddict.c b/lib/zstd/decompress/zstd_ddict.c index dbbc7919de53..30ef65e1ab5c 100644 --- a/lib/zstd/decompress/zstd_ddict.c +++ b/lib/zstd/decompress/zstd_ddict.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -14,12 +15,12 @@ /*-******************************************************* * Dependencies *********************************************************/ +#include "../common/allocations.h" /* ZSTD_customMalloc, ZSTD_customFree */ #include "../common/zstd_deps.h" /* ZSTD_memcpy, ZSTD_memmove, ZSTD_memset */ #include "../common/cpu.h" /* bmi2 */ #include "../common/mem.h" /* low level memory routines */ #define FSE_STATIC_LINKING_ONLY #include "../common/fse.h" -#define HUF_STATIC_LINKING_ONLY #include "../common/huf.h" #include "zstd_decompress_internal.h" #include "zstd_ddict.h" @@ -131,7 +132,7 @@ static size_t ZSTD_initDDict_internal(ZSTD_DDict* ddict, ZSTD_memcpy(internalBuffer, dict, dictSize); } ddict->dictSize = dictSize; - ddict->entropy.hufTable[0] = (HUF_DTable)((HufLog)*0x1000001); /* cover both little and big endian */ + ddict->entropy.hufTable[0] = (HUF_DTable)((ZSTD_HUFFDTABLE_CAPACITY_LOG)*0x1000001); /* cover both little and big endian */ /* parse dictionary content */ FORWARD_IF_ERROR( ZSTD_loadEntropy_intoDDict(ddict, dictContentType) , ""); @@ -237,5 +238,5 @@ size_t ZSTD_sizeof_DDict(const ZSTD_DDict* ddict) unsigned ZSTD_getDictID_fromDDict(const ZSTD_DDict* ddict) { if (ddict==NULL) return 0; - return ZSTD_getDictID_fromDict(ddict->dictContent, ddict->dictSize); + return ddict->dictID; } diff --git a/lib/zstd/decompress/zstd_ddict.h b/lib/zstd/decompress/zstd_ddict.h index 8c1a79d666f8..de459a0dacd1 100644 --- a/lib/zstd/decompress/zstd_ddict.h +++ b/lib/zstd/decompress/zstd_ddict.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/decompress/zstd_decompress.c b/lib/zstd/decompress/zstd_decompress.c index 6b3177c94711..c9cbc45f6ed9 100644 --- a/lib/zstd/decompress/zstd_decompress.c +++ b/lib/zstd/decompress/zstd_decompress.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -53,13 +54,15 @@ * Dependencies *********************************************************/ #include "../common/zstd_deps.h" /* ZSTD_memcpy, ZSTD_memmove, ZSTD_memset */ +#include "../common/allocations.h" /* ZSTD_customMalloc, ZSTD_customCalloc, ZSTD_customFree */ +#include "../common/error_private.h" +#include "../common/zstd_internal.h" /* blockProperties_t */ #include "../common/mem.h" /* low level memory routines */ +#include "../common/bits.h" /* ZSTD_highbit32 */ #define FSE_STATIC_LINKING_ONLY #include "../common/fse.h" -#define HUF_STATIC_LINKING_ONLY #include "../common/huf.h" #include /* xxh64_reset, xxh64_update, xxh64_digest, XXH64 */ -#include "../common/zstd_internal.h" /* blockProperties_t */ #include "zstd_decompress_internal.h" /* ZSTD_DCtx */ #include "zstd_ddict.h" /* ZSTD_DDictDictContent */ #include "zstd_decompress_block.h" /* ZSTD_decompressBlock_internal */ @@ -72,11 +75,11 @@ *************************************/ #define DDICT_HASHSET_MAX_LOAD_FACTOR_COUNT_MULT 4 -#define DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT 3 /* These two constants represent SIZE_MULT/COUNT_MULT load factor without using a float. - * Currently, that means a 0.75 load factor. - * So, if count * COUNT_MULT / size * SIZE_MULT != 0, then we've exceeded - * the load factor of the ddict hash set. - */ +#define DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT 3 /* These two constants represent SIZE_MULT/COUNT_MULT load factor without using a float. + * Currently, that means a 0.75 load factor. + * So, if count * COUNT_MULT / size * SIZE_MULT != 0, then we've exceeded + * the load factor of the ddict hash set. + */ #define DDICT_HASHSET_TABLE_BASE_SIZE 64 #define DDICT_HASHSET_RESIZE_FACTOR 2 @@ -237,6 +240,8 @@ static void ZSTD_DCtx_resetParameters(ZSTD_DCtx* dctx) dctx->outBufferMode = ZSTD_bm_buffered; dctx->forceIgnoreChecksum = ZSTD_d_validateChecksum; dctx->refMultipleDDicts = ZSTD_rmd_refSingleDDict; + dctx->disableHufAsm = 0; + dctx->maxBlockSizeParam = 0; } static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx) @@ -253,6 +258,7 @@ static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx) dctx->streamStage = zdss_init; dctx->noForwardProgress = 0; dctx->oversizedDuration = 0; + dctx->isFrameDecompression = 1; #if DYNAMIC_BMI2 dctx->bmi2 = ZSTD_cpuSupportsBmi2(); #endif @@ -421,16 +427,40 @@ size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize) * note : only works for formats ZSTD_f_zstd1 and ZSTD_f_zstd1_magicless * @return : 0, `zfhPtr` is correctly filled, * >0, `srcSize` is too small, value is wanted `srcSize` amount, - * or an error code, which can be tested using ZSTD_isError() */ +** or an error code, which can be tested using ZSTD_isError() */ size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format) { const BYTE* ip = (const BYTE*)src; size_t const minInputSize = ZSTD_startingInputLength(format); - ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr)); /* not strictly necessary, but static analyzer do not understand that zfhPtr is only going to be read only if return value is zero, since they are 2 different signals */ - if (srcSize < minInputSize) return minInputSize; - RETURN_ERROR_IF(src==NULL, GENERIC, "invalid parameter"); + DEBUGLOG(5, "ZSTD_getFrameHeader_advanced: minInputSize = %zu, srcSize = %zu", minInputSize, srcSize); + + if (srcSize > 0) { + /* note : technically could be considered an assert(), since it's an invalid entry */ + RETURN_ERROR_IF(src==NULL, GENERIC, "invalid parameter : src==NULL, but srcSize>0"); + } + if (srcSize < minInputSize) { + if (srcSize > 0 && format != ZSTD_f_zstd1_magicless) { + /* when receiving less than @minInputSize bytes, + * control these bytes at least correspond to a supported magic number + * in order to error out early if they don't. + **/ + size_t const toCopy = MIN(4, srcSize); + unsigned char hbuf[4]; MEM_writeLE32(hbuf, ZSTD_MAGICNUMBER); + assert(src != NULL); + ZSTD_memcpy(hbuf, src, toCopy); + if ( MEM_readLE32(hbuf) != ZSTD_MAGICNUMBER ) { + /* not a zstd frame : let's check if it's a skippable frame */ + MEM_writeLE32(hbuf, ZSTD_MAGIC_SKIPPABLE_START); + ZSTD_memcpy(hbuf, src, toCopy); + if ((MEM_readLE32(hbuf) & ZSTD_MAGIC_SKIPPABLE_MASK) != ZSTD_MAGIC_SKIPPABLE_START) { + RETURN_ERROR(prefix_unknown, + "first bytes don't correspond to any supported magic number"); + } } } + return minInputSize; + } + ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr)); /* not strictly necessary, but static analyzers may not understand that zfhPtr will be read only if return value is zero, since they are 2 different signals */ if ( (format != ZSTD_f_zstd1_magicless) && (MEM_readLE32(src) != ZSTD_MAGICNUMBER) ) { if ((MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { @@ -540,61 +570,62 @@ static size_t readSkippableFrameSize(void const* src, size_t srcSize) sizeU32 = MEM_readLE32((BYTE const*)src + ZSTD_FRAMEIDSIZE); RETURN_ERROR_IF((U32)(sizeU32 + ZSTD_SKIPPABLEHEADERSIZE) < sizeU32, frameParameter_unsupported, ""); - { - size_t const skippableSize = skippableHeaderSize + sizeU32; + { size_t const skippableSize = skippableHeaderSize + sizeU32; RETURN_ERROR_IF(skippableSize > srcSize, srcSize_wrong, ""); return skippableSize; } } /*! ZSTD_readSkippableFrame() : - * Retrieves a zstd skippable frame containing data given by src, and writes it to dst buffer. + * Retrieves content of a skippable frame, and writes it to dst buffer. * * The parameter magicVariant will receive the magicVariant that was supplied when the frame was written, * i.e. magicNumber - ZSTD_MAGIC_SKIPPABLE_START. This can be NULL if the caller is not interested * in the magicVariant. * - * Returns an error if destination buffer is not large enough, or if the frame is not skippable. + * Returns an error if destination buffer is not large enough, or if this is not a valid skippable frame. * * @return : number of bytes written or a ZSTD error. */ -ZSTDLIB_API size_t ZSTD_readSkippableFrame(void* dst, size_t dstCapacity, unsigned* magicVariant, - const void* src, size_t srcSize) +size_t ZSTD_readSkippableFrame(void* dst, size_t dstCapacity, + unsigned* magicVariant, /* optional, can be NULL */ + const void* src, size_t srcSize) { - U32 const magicNumber = MEM_readLE32(src); - size_t skippableFrameSize = readSkippableFrameSize(src, srcSize); - size_t skippableContentSize = skippableFrameSize - ZSTD_SKIPPABLEHEADERSIZE; - - /* check input validity */ - RETURN_ERROR_IF(!ZSTD_isSkippableFrame(src, srcSize), frameParameter_unsupported, ""); - RETURN_ERROR_IF(skippableFrameSize < ZSTD_SKIPPABLEHEADERSIZE || skippableFrameSize > srcSize, srcSize_wrong, ""); - RETURN_ERROR_IF(skippableContentSize > dstCapacity, dstSize_tooSmall, ""); + RETURN_ERROR_IF(srcSize < ZSTD_SKIPPABLEHEADERSIZE, srcSize_wrong, ""); - /* deliver payload */ - if (skippableContentSize > 0 && dst != NULL) - ZSTD_memcpy(dst, (const BYTE *)src + ZSTD_SKIPPABLEHEADERSIZE, skippableContentSize); - if (magicVariant != NULL) - *magicVariant = magicNumber - ZSTD_MAGIC_SKIPPABLE_START; - return skippableContentSize; + { U32 const magicNumber = MEM_readLE32(src); + size_t skippableFrameSize = readSkippableFrameSize(src, srcSize); + size_t skippableContentSize = skippableFrameSize - ZSTD_SKIPPABLEHEADERSIZE; + + /* check input validity */ + RETURN_ERROR_IF(!ZSTD_isSkippableFrame(src, srcSize), frameParameter_unsupported, ""); + RETURN_ERROR_IF(skippableFrameSize < ZSTD_SKIPPABLEHEADERSIZE || skippableFrameSize > srcSize, srcSize_wrong, ""); + RETURN_ERROR_IF(skippableContentSize > dstCapacity, dstSize_tooSmall, ""); + + /* deliver payload */ + if (skippableContentSize > 0 && dst != NULL) + ZSTD_memcpy(dst, (const BYTE *)src + ZSTD_SKIPPABLEHEADERSIZE, skippableContentSize); + if (magicVariant != NULL) + *magicVariant = magicNumber - ZSTD_MAGIC_SKIPPABLE_START; + return skippableContentSize; + } } /* ZSTD_findDecompressedSize() : - * compatible with legacy mode * `srcSize` must be the exact length of some number of ZSTD compressed and/or * skippable frames - * @return : decompressed size of the frames contained */ + * note: compatible with legacy mode + * @return : decompressed size of the frames contained */ unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize) { - unsigned long long totalDstSize = 0; + U64 totalDstSize = 0; while (srcSize >= ZSTD_startingInputLength(ZSTD_f_zstd1)) { U32 const magicNumber = MEM_readLE32(src); if ((magicNumber & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { size_t const skippableSize = readSkippableFrameSize(src, srcSize); - if (ZSTD_isError(skippableSize)) { - return ZSTD_CONTENTSIZE_ERROR; - } + if (ZSTD_isError(skippableSize)) return ZSTD_CONTENTSIZE_ERROR; assert(skippableSize <= srcSize); src = (const BYTE *)src + skippableSize; @@ -602,17 +633,17 @@ unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize) continue; } - { unsigned long long const ret = ZSTD_getFrameContentSize(src, srcSize); - if (ret >= ZSTD_CONTENTSIZE_ERROR) return ret; + { unsigned long long const fcs = ZSTD_getFrameContentSize(src, srcSize); + if (fcs >= ZSTD_CONTENTSIZE_ERROR) return fcs; - /* check for overflow */ - if (totalDstSize + ret < totalDstSize) return ZSTD_CONTENTSIZE_ERROR; - totalDstSize += ret; + if (U64_MAX - totalDstSize < fcs) + return ZSTD_CONTENTSIZE_ERROR; /* check for overflow */ + totalDstSize += fcs; } + /* skip to next frame */ { size_t const frameSrcSize = ZSTD_findFrameCompressedSize(src, srcSize); - if (ZSTD_isError(frameSrcSize)) { - return ZSTD_CONTENTSIZE_ERROR; - } + if (ZSTD_isError(frameSrcSize)) return ZSTD_CONTENTSIZE_ERROR; + assert(frameSrcSize <= srcSize); src = (const BYTE *)src + frameSrcSize; srcSize -= frameSrcSize; @@ -676,13 +707,13 @@ static ZSTD_frameSizeInfo ZSTD_errorFrameSizeInfo(size_t ret) return frameSizeInfo; } -static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize) +static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize, ZSTD_format_e format) { ZSTD_frameSizeInfo frameSizeInfo; ZSTD_memset(&frameSizeInfo, 0, sizeof(ZSTD_frameSizeInfo)); - if ((srcSize >= ZSTD_SKIPPABLEHEADERSIZE) + if (format == ZSTD_f_zstd1 && (srcSize >= ZSTD_SKIPPABLEHEADERSIZE) && (MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { frameSizeInfo.compressedSize = readSkippableFrameSize(src, srcSize); assert(ZSTD_isError(frameSizeInfo.compressedSize) || @@ -696,7 +727,7 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize ZSTD_frameHeader zfh; /* Extract Frame Header */ - { size_t const ret = ZSTD_getFrameHeader(&zfh, src, srcSize); + { size_t const ret = ZSTD_getFrameHeader_advanced(&zfh, src, srcSize, format); if (ZSTD_isError(ret)) return ZSTD_errorFrameSizeInfo(ret); if (ret > 0) @@ -730,23 +761,26 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize ip += 4; } + frameSizeInfo.nbBlocks = nbBlocks; frameSizeInfo.compressedSize = (size_t)(ip - ipstart); frameSizeInfo.decompressedBound = (zfh.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN) ? zfh.frameContentSize - : nbBlocks * zfh.blockSizeMax; + : (unsigned long long)nbBlocks * zfh.blockSizeMax; return frameSizeInfo; } } +static size_t ZSTD_findFrameCompressedSize_advanced(const void *src, size_t srcSize, ZSTD_format_e format) { + ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, format); + return frameSizeInfo.compressedSize; +} + /* ZSTD_findFrameCompressedSize() : - * compatible with legacy mode - * `src` must point to the start of a ZSTD frame, ZSTD legacy frame, or skippable frame - * `srcSize` must be at least as large as the frame contained - * @return : the compressed size of the frame starting at `src` */ + * See docs in zstd.h + * Note: compatible with legacy mode */ size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize) { - ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); - return frameSizeInfo.compressedSize; + return ZSTD_findFrameCompressedSize_advanced(src, srcSize, ZSTD_f_zstd1); } /* ZSTD_decompressBound() : @@ -760,7 +794,7 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize) unsigned long long bound = 0; /* Iterate over each frame */ while (srcSize > 0) { - ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); + ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1); size_t const compressedSize = frameSizeInfo.compressedSize; unsigned long long const decompressedBound = frameSizeInfo.decompressedBound; if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR) @@ -773,6 +807,48 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize) return bound; } +size_t ZSTD_decompressionMargin(void const* src, size_t srcSize) +{ + size_t margin = 0; + unsigned maxBlockSize = 0; + + /* Iterate over each frame */ + while (srcSize > 0) { + ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1); + size_t const compressedSize = frameSizeInfo.compressedSize; + unsigned long long const decompressedBound = frameSizeInfo.decompressedBound; + ZSTD_frameHeader zfh; + + FORWARD_IF_ERROR(ZSTD_getFrameHeader(&zfh, src, srcSize), ""); + if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR) + return ERROR(corruption_detected); + + if (zfh.frameType == ZSTD_frame) { + /* Add the frame header to our margin */ + margin += zfh.headerSize; + /* Add the checksum to our margin */ + margin += zfh.checksumFlag ? 4 : 0; + /* Add 3 bytes per block */ + margin += 3 * frameSizeInfo.nbBlocks; + + /* Compute the max block size */ + maxBlockSize = MAX(maxBlockSize, zfh.blockSizeMax); + } else { + assert(zfh.frameType == ZSTD_skippableFrame); + /* Add the entire skippable frame size to our margin. */ + margin += compressedSize; + } + + assert(srcSize >= compressedSize); + src = (const BYTE*)src + compressedSize; + srcSize -= compressedSize; + } + + /* Add the max block size back to the margin. */ + margin += maxBlockSize; + + return margin; +} /*-************************************************************* * Frame decoding @@ -856,6 +932,10 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, ip += frameHeaderSize; remainingSrcSize -= frameHeaderSize; } + /* Shrink the blockSizeMax if enabled */ + if (dctx->maxBlockSizeParam != 0) + dctx->fParams.blockSizeMax = MIN(dctx->fParams.blockSizeMax, (unsigned)dctx->maxBlockSizeParam); + /* Loop on each block */ while (1) { BYTE* oBlockEnd = oend; @@ -888,7 +968,8 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, switch(blockProperties.blockType) { case bt_compressed: - decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oBlockEnd-op), ip, cBlockSize, /* frame */ 1, not_streaming); + assert(dctx->isFrameDecompression == 1); + decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oBlockEnd-op), ip, cBlockSize, not_streaming); break; case bt_raw : /* Use oend instead of oBlockEnd because this function is safe to overlap. It uses memmove. */ @@ -901,12 +982,14 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, default: RETURN_ERROR(corruption_detected, "invalid block type"); } - - if (ZSTD_isError(decodedSize)) return decodedSize; - if (dctx->validateChecksum) + FORWARD_IF_ERROR(decodedSize, "Block decompression failure"); + DEBUGLOG(5, "Decompressed block of dSize = %u", (unsigned)decodedSize); + if (dctx->validateChecksum) { xxh64_update(&dctx->xxhState, op, decodedSize); - if (decodedSize != 0) + } + if (decodedSize) /* support dst = NULL,0 */ { op += decodedSize; + } assert(ip != NULL); ip += cBlockSize; remainingSrcSize -= cBlockSize; @@ -930,12 +1013,15 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, } ZSTD_DCtx_trace_end(dctx, (U64)(op-ostart), (U64)(ip-istart), /* streaming */ 0); /* Allow caller to get size read */ + DEBUGLOG(4, "ZSTD_decompressFrame: decompressed frame of size %zi, consuming %zi bytes of input", op-ostart, ip - (const BYTE*)*srcPtr); *srcPtr = ip; *srcSizePtr = remainingSrcSize; return (size_t)(op-ostart); } -static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx, +static +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR +size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize, const void* dict, size_t dictSize, @@ -955,17 +1041,18 @@ static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx, while (srcSize >= ZSTD_startingInputLength(dctx->format)) { - { U32 const magicNumber = MEM_readLE32(src); - DEBUGLOG(4, "reading magic number %08X (expecting %08X)", - (unsigned)magicNumber, ZSTD_MAGICNUMBER); + if (dctx->format == ZSTD_f_zstd1 && srcSize >= 4) { + U32 const magicNumber = MEM_readLE32(src); + DEBUGLOG(5, "reading magic number %08X", (unsigned)magicNumber); if ((magicNumber & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { + /* skippable frame detected : skip it */ size_t const skippableSize = readSkippableFrameSize(src, srcSize); - FORWARD_IF_ERROR(skippableSize, "readSkippableFrameSize failed"); + FORWARD_IF_ERROR(skippableSize, "invalid skippable frame"); assert(skippableSize <= srcSize); src = (const BYTE *)src + skippableSize; srcSize -= skippableSize; - continue; + continue; /* check next frame */ } } if (ddict) { @@ -1061,8 +1148,8 @@ size_t ZSTD_decompress(void* dst, size_t dstCapacity, const void* src, size_t sr size_t ZSTD_nextSrcSizeToDecompress(ZSTD_DCtx* dctx) { return dctx->expected; } /* - * Similar to ZSTD_nextSrcSizeToDecompress(), but when a block input can be streamed, - * we allow taking a partial block as the input. Currently only raw uncompressed blocks can + * Similar to ZSTD_nextSrcSizeToDecompress(), but when a block input can be streamed, we + * allow taking a partial block as the input. Currently only raw uncompressed blocks can * be streamed. * * For blocks that can be streamed, this allows us to reduce the latency until we produce @@ -1181,7 +1268,8 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c { case bt_compressed: DEBUGLOG(5, "ZSTD_decompressContinue: case bt_compressed"); - rSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, /* frame */ 1, is_streaming); + assert(dctx->isFrameDecompression == 1); + rSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, is_streaming); dctx->expected = 0; /* Streaming not supported */ break; case bt_raw : @@ -1250,6 +1338,7 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c case ZSTDds_decodeSkippableHeader: assert(src != NULL); assert(srcSize <= ZSTD_SKIPPABLEHEADERSIZE); + assert(dctx->format != ZSTD_f_zstd1_magicless); ZSTD_memcpy(dctx->headerBuffer + (ZSTD_SKIPPABLEHEADERSIZE - srcSize), src, srcSize); /* complete skippable header */ dctx->expected = MEM_readLE32(dctx->headerBuffer + ZSTD_FRAMEIDSIZE); /* note : dctx->expected can grow seriously large, beyond local buffer size */ dctx->stage = ZSTDds_skipFrame; @@ -1262,7 +1351,7 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c default: assert(0); /* impossible */ - RETURN_ERROR(GENERIC, "impossible to reach"); /* some compiler require default to do something */ + RETURN_ERROR(GENERIC, "impossible to reach"); /* some compilers require default to do something */ } } @@ -1303,11 +1392,11 @@ ZSTD_loadDEntropy(ZSTD_entropyDTables_t* entropy, /* in minimal huffman, we always use X1 variants */ size_t const hSize = HUF_readDTableX1_wksp(entropy->hufTable, dictPtr, dictEnd - dictPtr, - workspace, workspaceSize); + workspace, workspaceSize, /* flags */ 0); #else size_t const hSize = HUF_readDTableX2_wksp(entropy->hufTable, dictPtr, (size_t)(dictEnd - dictPtr), - workspace, workspaceSize); + workspace, workspaceSize, /* flags */ 0); #endif RETURN_ERROR_IF(HUF_isError(hSize), dictionary_corrupted, ""); dictPtr += hSize; @@ -1403,10 +1492,11 @@ size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx) dctx->prefixStart = NULL; dctx->virtualStart = NULL; dctx->dictEnd = NULL; - dctx->entropy.hufTable[0] = (HUF_DTable)((HufLog)*0x1000001); /* cover both little and big endian */ + dctx->entropy.hufTable[0] = (HUF_DTable)((ZSTD_HUFFDTABLE_CAPACITY_LOG)*0x1000001); /* cover both little and big endian */ dctx->litEntropy = dctx->fseEntropy = 0; dctx->dictID = 0; dctx->bType = bt_reserved; + dctx->isFrameDecompression = 1; ZSTD_STATIC_ASSERT(sizeof(dctx->entropy.rep) == sizeof(repStartValue)); ZSTD_memcpy(dctx->entropy.rep, repStartValue, sizeof(repStartValue)); /* initial repcodes */ dctx->LLTptr = dctx->entropy.LLTable; @@ -1465,7 +1555,7 @@ unsigned ZSTD_getDictID_fromDict(const void* dict, size_t dictSize) * This could for one of the following reasons : * - The frame does not require a dictionary (most common case). * - The frame was built with dictID intentionally removed. - * Needed dictionary is a hidden information. + * Needed dictionary is a hidden piece of information. * Note : this use case also happens when using a non-conformant dictionary. * - `srcSize` is too small, and as a result, frame header could not be decoded. * Note : possible if `srcSize < ZSTD_FRAMEHEADERSIZE_MAX`. @@ -1474,7 +1564,7 @@ unsigned ZSTD_getDictID_fromDict(const void* dict, size_t dictSize) * ZSTD_getFrameHeader(), which will provide a more precise error code. */ unsigned ZSTD_getDictID_fromFrame(const void* src, size_t srcSize) { - ZSTD_frameHeader zfp = { 0, 0, 0, ZSTD_frame, 0, 0, 0 }; + ZSTD_frameHeader zfp = { 0, 0, 0, ZSTD_frame, 0, 0, 0, 0, 0 }; size_t const hError = ZSTD_getFrameHeader(&zfp, src, srcSize); if (ZSTD_isError(hError)) return 0; return zfp.dictID; @@ -1581,7 +1671,9 @@ size_t ZSTD_initDStream_usingDict(ZSTD_DStream* zds, const void* dict, size_t di size_t ZSTD_initDStream(ZSTD_DStream* zds) { DEBUGLOG(4, "ZSTD_initDStream"); - return ZSTD_initDStream_usingDDict(zds, NULL); + FORWARD_IF_ERROR(ZSTD_DCtx_reset(zds, ZSTD_reset_session_only), ""); + FORWARD_IF_ERROR(ZSTD_DCtx_refDDict(zds, NULL), ""); + return ZSTD_startingInputLength(zds->format); } /* ZSTD_initDStream_usingDDict() : @@ -1589,6 +1681,7 @@ size_t ZSTD_initDStream(ZSTD_DStream* zds) * this function cannot fail */ size_t ZSTD_initDStream_usingDDict(ZSTD_DStream* dctx, const ZSTD_DDict* ddict) { + DEBUGLOG(4, "ZSTD_initDStream_usingDDict"); FORWARD_IF_ERROR( ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only) , ""); FORWARD_IF_ERROR( ZSTD_DCtx_refDDict(dctx, ddict) , ""); return ZSTD_startingInputLength(dctx->format); @@ -1599,6 +1692,7 @@ size_t ZSTD_initDStream_usingDDict(ZSTD_DStream* dctx, const ZSTD_DDict* ddict) * this function cannot fail */ size_t ZSTD_resetDStream(ZSTD_DStream* dctx) { + DEBUGLOG(4, "ZSTD_resetDStream"); FORWARD_IF_ERROR(ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only), ""); return ZSTD_startingInputLength(dctx->format); } @@ -1670,6 +1764,15 @@ ZSTD_bounds ZSTD_dParam_getBounds(ZSTD_dParameter dParam) bounds.lowerBound = (int)ZSTD_rmd_refSingleDDict; bounds.upperBound = (int)ZSTD_rmd_refMultipleDDicts; return bounds; + case ZSTD_d_disableHuffmanAssembly: + bounds.lowerBound = 0; + bounds.upperBound = 1; + return bounds; + case ZSTD_d_maxBlockSize: + bounds.lowerBound = ZSTD_BLOCKSIZE_MAX_MIN; + bounds.upperBound = ZSTD_BLOCKSIZE_MAX; + return bounds; + default:; } bounds.error = ERROR(parameter_unsupported); @@ -1710,6 +1813,12 @@ size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParameter param, int* value case ZSTD_d_refMultipleDDicts: *value = (int)dctx->refMultipleDDicts; return 0; + case ZSTD_d_disableHuffmanAssembly: + *value = (int)dctx->disableHufAsm; + return 0; + case ZSTD_d_maxBlockSize: + *value = dctx->maxBlockSizeParam; + return 0; default:; } RETURN_ERROR(parameter_unsupported, ""); @@ -1743,6 +1852,14 @@ size_t ZSTD_DCtx_setParameter(ZSTD_DCtx* dctx, ZSTD_dParameter dParam, int value } dctx->refMultipleDDicts = (ZSTD_refMultipleDDicts_e)value; return 0; + case ZSTD_d_disableHuffmanAssembly: + CHECK_DBOUNDS(ZSTD_d_disableHuffmanAssembly, value); + dctx->disableHufAsm = value != 0; + return 0; + case ZSTD_d_maxBlockSize: + if (value != 0) CHECK_DBOUNDS(ZSTD_d_maxBlockSize, value); + dctx->maxBlockSizeParam = value; + return 0; default:; } RETURN_ERROR(parameter_unsupported, ""); @@ -1754,6 +1871,7 @@ size_t ZSTD_DCtx_reset(ZSTD_DCtx* dctx, ZSTD_ResetDirective reset) || (reset == ZSTD_reset_session_and_parameters) ) { dctx->streamStage = zdss_init; dctx->noForwardProgress = 0; + dctx->isFrameDecompression = 1; } if ( (reset == ZSTD_reset_parameters) || (reset == ZSTD_reset_session_and_parameters) ) { @@ -1770,11 +1888,17 @@ size_t ZSTD_sizeof_DStream(const ZSTD_DStream* dctx) return ZSTD_sizeof_DCtx(dctx); } -size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long long frameContentSize) +static size_t ZSTD_decodingBufferSize_internal(unsigned long long windowSize, unsigned long long frameContentSize, size_t blockSizeMax) { - size_t const blockSize = (size_t) MIN(windowSize, ZSTD_BLOCKSIZE_MAX); - /* space is needed to store the litbuffer after the output of a given block without stomping the extDict of a previous run, as well as to cover both windows against wildcopy*/ - unsigned long long const neededRBSize = windowSize + blockSize + ZSTD_BLOCKSIZE_MAX + (WILDCOPY_OVERLENGTH * 2); + size_t const blockSize = MIN((size_t)MIN(windowSize, ZSTD_BLOCKSIZE_MAX), blockSizeMax); + /* We need blockSize + WILDCOPY_OVERLENGTH worth of buffer so that if a block + * ends at windowSize + WILDCOPY_OVERLENGTH + 1 bytes, we can start writing + * the block at the beginning of the output buffer, and maintain a full window. + * + * We need another blockSize worth of buffer so that we can store split + * literals at the end of the block without overwriting the extDict window. + */ + unsigned long long const neededRBSize = windowSize + (blockSize * 2) + (WILDCOPY_OVERLENGTH * 2); unsigned long long const neededSize = MIN(frameContentSize, neededRBSize); size_t const minRBSize = (size_t) neededSize; RETURN_ERROR_IF((unsigned long long)minRBSize != neededSize, @@ -1782,6 +1906,11 @@ size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long return minRBSize; } +size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long long frameContentSize) +{ + return ZSTD_decodingBufferSize_internal(windowSize, frameContentSize, ZSTD_BLOCKSIZE_MAX); +} + size_t ZSTD_estimateDStreamSize(size_t windowSize) { size_t const blockSize = MIN(windowSize, ZSTD_BLOCKSIZE_MAX); @@ -1918,7 +2047,6 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB if (zds->refMultipleDDicts && zds->ddictSet) { ZSTD_DCtx_selectFrameDDict(zds); } - DEBUGLOG(5, "header size : %u", (U32)hSize); if (ZSTD_isError(hSize)) { return hSize; /* error */ } @@ -1932,6 +2060,11 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB zds->lhSize += remainingInput; } input->pos = input->size; + /* check first few bytes */ + FORWARD_IF_ERROR( + ZSTD_getFrameHeader_advanced(&zds->fParams, zds->headerBuffer, zds->lhSize, zds->format), + "First few bytes detected incorrect" ); + /* return hint input size */ return (MAX((size_t)ZSTD_FRAMEHEADERSIZE_MIN(zds->format), hSize) - zds->lhSize) + ZSTD_blockHeaderSize; /* remaining header bytes + next block header */ } assert(ip != NULL); @@ -1943,14 +2076,15 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB if (zds->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN && zds->fParams.frameType != ZSTD_skippableFrame && (U64)(size_t)(oend-op) >= zds->fParams.frameContentSize) { - size_t const cSize = ZSTD_findFrameCompressedSize(istart, (size_t)(iend-istart)); + size_t const cSize = ZSTD_findFrameCompressedSize_advanced(istart, (size_t)(iend-istart), zds->format); if (cSize <= (size_t)(iend-istart)) { /* shortcut : using single-pass mode */ size_t const decompressedSize = ZSTD_decompress_usingDDict(zds, op, (size_t)(oend-op), istart, cSize, ZSTD_getDDict(zds)); if (ZSTD_isError(decompressedSize)) return decompressedSize; - DEBUGLOG(4, "shortcut to single-pass ZSTD_decompress_usingDDict()") + DEBUGLOG(4, "shortcut to single-pass ZSTD_decompress_usingDDict()"); + assert(istart != NULL); ip = istart + cSize; - op += decompressedSize; + op = op ? op + decompressedSize : op; /* can occur if frameContentSize = 0 (empty frame) */ zds->expected = 0; zds->streamStage = zdss_init; someMoreWork = 0; @@ -1969,7 +2103,8 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB DEBUGLOG(4, "Consume header"); FORWARD_IF_ERROR(ZSTD_decompressBegin_usingDDict(zds, ZSTD_getDDict(zds)), ""); - if ((MEM_readLE32(zds->headerBuffer) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { /* skippable frame */ + if (zds->format == ZSTD_f_zstd1 + && (MEM_readLE32(zds->headerBuffer) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) { /* skippable frame */ zds->expected = MEM_readLE32(zds->headerBuffer + ZSTD_FRAMEIDSIZE); zds->stage = ZSTDds_skipFrame; } else { @@ -1985,11 +2120,13 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB zds->fParams.windowSize = MAX(zds->fParams.windowSize, 1U << ZSTD_WINDOWLOG_ABSOLUTEMIN); RETURN_ERROR_IF(zds->fParams.windowSize > zds->maxWindowSize, frameParameter_windowTooLarge, ""); + if (zds->maxBlockSizeParam != 0) + zds->fParams.blockSizeMax = MIN(zds->fParams.blockSizeMax, (unsigned)zds->maxBlockSizeParam); /* Adapt buffer sizes to frame header instructions */ { size_t const neededInBuffSize = MAX(zds->fParams.blockSizeMax, 4 /* frame checksum */); size_t const neededOutBuffSize = zds->outBufferMode == ZSTD_bm_buffered - ? ZSTD_decodingBufferSize_min(zds->fParams.windowSize, zds->fParams.frameContentSize) + ? ZSTD_decodingBufferSize_internal(zds->fParams.windowSize, zds->fParams.frameContentSize, zds->fParams.blockSizeMax) : 0; ZSTD_DCtx_updateOversizedDuration(zds, neededInBuffSize, neededOutBuffSize); @@ -2034,6 +2171,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB } if ((size_t)(iend-ip) >= neededInSize) { /* decode directly from src */ FORWARD_IF_ERROR(ZSTD_decompressContinueStream(zds, &op, oend, ip, neededInSize), ""); + assert(ip != NULL); ip += neededInSize; /* Function modifies the stage so we must break */ break; @@ -2048,7 +2186,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB int const isSkipFrame = ZSTD_isSkipFrame(zds); size_t loadedSize; /* At this point we shouldn't be decompressing a block that we can stream. */ - assert(neededInSize == ZSTD_nextSrcSizeToDecompressWithInputSize(zds, iend - ip)); + assert(neededInSize == ZSTD_nextSrcSizeToDecompressWithInputSize(zds, (size_t)(iend - ip))); if (isSkipFrame) { loadedSize = MIN(toLoad, (size_t)(iend-ip)); } else { @@ -2057,8 +2195,11 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB "should never happen"); loadedSize = ZSTD_limitCopy(zds->inBuff + zds->inPos, toLoad, ip, (size_t)(iend-ip)); } - ip += loadedSize; - zds->inPos += loadedSize; + if (loadedSize != 0) { + /* ip may be NULL */ + ip += loadedSize; + zds->inPos += loadedSize; + } if (loadedSize < toLoad) { someMoreWork = 0; break; } /* not enough input, wait for more */ /* decode loaded input */ @@ -2068,14 +2209,17 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB break; } case zdss_flush: - { size_t const toFlushSize = zds->outEnd - zds->outStart; + { + size_t const toFlushSize = zds->outEnd - zds->outStart; size_t const flushedSize = ZSTD_limitCopy(op, (size_t)(oend-op), zds->outBuff + zds->outStart, toFlushSize); - op += flushedSize; + + op = op ? op + flushedSize : op; + zds->outStart += flushedSize; if (flushedSize == toFlushSize) { /* flush completed */ zds->streamStage = zdss_read; if ( (zds->outBuffSize < zds->fParams.frameContentSize) - && (zds->outStart + zds->fParams.blockSizeMax > zds->outBuffSize) ) { + && (zds->outStart + zds->fParams.blockSizeMax > zds->outBuffSize) ) { DEBUGLOG(5, "restart filling outBuff from beginning (left:%i, needed:%u)", (int)(zds->outBuffSize - zds->outStart), (U32)zds->fParams.blockSizeMax); @@ -2089,7 +2233,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB default: assert(0); /* impossible */ - RETURN_ERROR(GENERIC, "impossible to reach"); /* some compiler require default to do something */ + RETURN_ERROR(GENERIC, "impossible to reach"); /* some compilers require default to do something */ } } /* result */ @@ -2102,8 +2246,8 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB if ((ip==istart) && (op==ostart)) { /* no forward progress */ zds->noForwardProgress ++; if (zds->noForwardProgress >= ZSTD_NO_FORWARD_PROGRESS_MAX) { - RETURN_ERROR_IF(op==oend, dstSize_tooSmall, ""); - RETURN_ERROR_IF(ip==iend, srcSize_wrong, ""); + RETURN_ERROR_IF(op==oend, noForwardProgress_destFull, ""); + RETURN_ERROR_IF(ip==iend, noForwardProgress_inputEmpty, ""); assert(0); } } else { @@ -2140,11 +2284,17 @@ size_t ZSTD_decompressStream_simpleArgs ( void* dst, size_t dstCapacity, size_t* dstPos, const void* src, size_t srcSize, size_t* srcPos) { - ZSTD_outBuffer output = { dst, dstCapacity, *dstPos }; - ZSTD_inBuffer input = { src, srcSize, *srcPos }; - /* ZSTD_compress_generic() will check validity of dstPos and srcPos */ - size_t const cErr = ZSTD_decompressStream(dctx, &output, &input); - *dstPos = output.pos; - *srcPos = input.pos; - return cErr; + ZSTD_outBuffer output; + ZSTD_inBuffer input; + output.dst = dst; + output.size = dstCapacity; + output.pos = *dstPos; + input.src = src; + input.size = srcSize; + input.pos = *srcPos; + { size_t const cErr = ZSTD_decompressStream(dctx, &output, &input); + *dstPos = output.pos; + *srcPos = input.pos; + return cErr; + } } diff --git a/lib/zstd/decompress/zstd_decompress_block.c b/lib/zstd/decompress/zstd_decompress_block.c index c1913b8e7c89..9fe9a12c8a2c 100644 --- a/lib/zstd/decompress/zstd_decompress_block.c +++ b/lib/zstd/decompress/zstd_decompress_block.c @@ -1,5 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -20,12 +21,12 @@ #include "../common/mem.h" /* low level memory routines */ #define FSE_STATIC_LINKING_ONLY #include "../common/fse.h" -#define HUF_STATIC_LINKING_ONLY #include "../common/huf.h" #include "../common/zstd_internal.h" #include "zstd_decompress_internal.h" /* ZSTD_DCtx */ #include "zstd_ddict.h" /* ZSTD_DDictDictContent */ #include "zstd_decompress_block.h" +#include "../common/bits.h" /* ZSTD_highbit32 */ /*_******************************************************* * Macros @@ -51,6 +52,13 @@ static void ZSTD_copy4(void* dst, const void* src) { ZSTD_memcpy(dst, src, 4); } * Block decoding ***************************************************************/ +static size_t ZSTD_blockSizeMax(ZSTD_DCtx const* dctx) +{ + size_t const blockSizeMax = dctx->isFrameDecompression ? dctx->fParams.blockSizeMax : ZSTD_BLOCKSIZE_MAX; + assert(blockSizeMax <= ZSTD_BLOCKSIZE_MAX); + return blockSizeMax; +} + /*! ZSTD_getcBlockSize() : * Provides the size of compressed block from block header `src` */ size_t ZSTD_getcBlockSize(const void* src, size_t srcSize, @@ -73,41 +81,49 @@ size_t ZSTD_getcBlockSize(const void* src, size_t srcSize, static void ZSTD_allocateLiteralsBuffer(ZSTD_DCtx* dctx, void* const dst, const size_t dstCapacity, const size_t litSize, const streaming_operation streaming, const size_t expectedWriteSize, const unsigned splitImmediately) { - if (streaming == not_streaming && dstCapacity > ZSTD_BLOCKSIZE_MAX + WILDCOPY_OVERLENGTH + litSize + WILDCOPY_OVERLENGTH) - { - /* room for litbuffer to fit without read faulting */ - dctx->litBuffer = (BYTE*)dst + ZSTD_BLOCKSIZE_MAX + WILDCOPY_OVERLENGTH; + size_t const blockSizeMax = ZSTD_blockSizeMax(dctx); + assert(litSize <= blockSizeMax); + assert(dctx->isFrameDecompression || streaming == not_streaming); + assert(expectedWriteSize <= blockSizeMax); + if (streaming == not_streaming && dstCapacity > blockSizeMax + WILDCOPY_OVERLENGTH + litSize + WILDCOPY_OVERLENGTH) { + /* If we aren't streaming, we can just put the literals after the output + * of the current block. We don't need to worry about overwriting the + * extDict of our window, because it doesn't exist. + * So if we have space after the end of the block, just put it there. + */ + dctx->litBuffer = (BYTE*)dst + blockSizeMax + WILDCOPY_OVERLENGTH; dctx->litBufferEnd = dctx->litBuffer + litSize; dctx->litBufferLocation = ZSTD_in_dst; - } - else if (litSize > ZSTD_LITBUFFEREXTRASIZE) - { - /* won't fit in litExtraBuffer, so it will be split between end of dst and extra buffer */ + } else if (litSize <= ZSTD_LITBUFFEREXTRASIZE) { + /* Literals fit entirely within the extra buffer, put them there to avoid + * having to split the literals. + */ + dctx->litBuffer = dctx->litExtraBuffer; + dctx->litBufferEnd = dctx->litBuffer + litSize; + dctx->litBufferLocation = ZSTD_not_in_dst; + } else { + assert(blockSizeMax > ZSTD_LITBUFFEREXTRASIZE); + /* Literals must be split between the output block and the extra lit + * buffer. We fill the extra lit buffer with the tail of the literals, + * and put the rest of the literals at the end of the block, with + * WILDCOPY_OVERLENGTH of buffer room to allow for overreads. + * This MUST not write more than our maxBlockSize beyond dst, because in + * streaming mode, that could overwrite part of our extDict window. + */ if (splitImmediately) { /* won't fit in litExtraBuffer, so it will be split between end of dst and extra buffer */ dctx->litBuffer = (BYTE*)dst + expectedWriteSize - litSize + ZSTD_LITBUFFEREXTRASIZE - WILDCOPY_OVERLENGTH; dctx->litBufferEnd = dctx->litBuffer + litSize - ZSTD_LITBUFFEREXTRASIZE; - } - else { - /* initially this will be stored entirely in dst during huffman decoding, it will partially shifted to litExtraBuffer after */ + } else { + /* initially this will be stored entirely in dst during huffman decoding, it will partially be shifted to litExtraBuffer after */ dctx->litBuffer = (BYTE*)dst + expectedWriteSize - litSize; dctx->litBufferEnd = (BYTE*)dst + expectedWriteSize; } dctx->litBufferLocation = ZSTD_split; - } - else - { - /* fits entirely within litExtraBuffer, so no split is necessary */ - dctx->litBuffer = dctx->litExtraBuffer; - dctx->litBufferEnd = dctx->litBuffer + litSize; - dctx->litBufferLocation = ZSTD_not_in_dst; + assert(dctx->litBufferEnd <= (BYTE*)dst + expectedWriteSize); } } -/* Hidden declaration for fullbench */ -size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, - const void* src, size_t srcSize, - void* dst, size_t dstCapacity, const streaming_operation streaming); /*! ZSTD_decodeLiteralsBlock() : * Where it is possible to do so without being stomped by the output during decompression, the literals block will be stored * in the dstBuffer. If there is room to do so, it will be stored in full in the excess dst space after where the current @@ -116,7 +132,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, * * @return : nb of bytes read from src (< srcSize ) * note : symbol not declared but exposed for fullbench */ -size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, +static size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, const void* src, size_t srcSize, /* note : srcSize < BLOCKSIZE */ void* dst, size_t dstCapacity, const streaming_operation streaming) { @@ -125,6 +141,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, { const BYTE* const istart = (const BYTE*) src; symbolEncodingType_e const litEncType = (symbolEncodingType_e)(istart[0] & 3); + size_t const blockSizeMax = ZSTD_blockSizeMax(dctx); switch(litEncType) { @@ -134,13 +151,16 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, ZSTD_FALLTHROUGH; case set_compressed: - RETURN_ERROR_IF(srcSize < 5, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 3; here we need up to 5 for case 3"); + RETURN_ERROR_IF(srcSize < 5, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 2; here we need up to 5 for case 3"); { size_t lhSize, litSize, litCSize; U32 singleStream=0; U32 const lhlCode = (istart[0] >> 2) & 3; U32 const lhc = MEM_readLE32(istart); size_t hufSuccess; - size_t expectedWriteSize = MIN(ZSTD_BLOCKSIZE_MAX, dstCapacity); + size_t expectedWriteSize = MIN(blockSizeMax, dstCapacity); + int const flags = 0 + | (ZSTD_DCtx_get_bmi2(dctx) ? HUF_flags_bmi2 : 0) + | (dctx->disableHufAsm ? HUF_flags_disableAsm : 0); switch(lhlCode) { case 0: case 1: default: /* note : default is impossible, since lhlCode into [0..3] */ @@ -164,7 +184,11 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, break; } RETURN_ERROR_IF(litSize > 0 && dst == NULL, dstSize_tooSmall, "NULL not handled"); - RETURN_ERROR_IF(litSize > ZSTD_BLOCKSIZE_MAX, corruption_detected, ""); + RETURN_ERROR_IF(litSize > blockSizeMax, corruption_detected, ""); + if (!singleStream) + RETURN_ERROR_IF(litSize < MIN_LITERALS_FOR_4_STREAMS, literals_headerWrong, + "Not enough literals (%zu) for the 4-streams mode (min %u)", + litSize, MIN_LITERALS_FOR_4_STREAMS); RETURN_ERROR_IF(litCSize + lhSize > srcSize, corruption_detected, ""); RETURN_ERROR_IF(expectedWriteSize < litSize , dstSize_tooSmall, ""); ZSTD_allocateLiteralsBuffer(dctx, dst, dstCapacity, litSize, streaming, expectedWriteSize, 0); @@ -176,13 +200,14 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, if (litEncType==set_repeat) { if (singleStream) { - hufSuccess = HUF_decompress1X_usingDTable_bmi2( + hufSuccess = HUF_decompress1X_usingDTable( dctx->litBuffer, litSize, istart+lhSize, litCSize, - dctx->HUFptr, ZSTD_DCtx_get_bmi2(dctx)); + dctx->HUFptr, flags); } else { - hufSuccess = HUF_decompress4X_usingDTable_bmi2( + assert(litSize >= MIN_LITERALS_FOR_4_STREAMS); + hufSuccess = HUF_decompress4X_usingDTable( dctx->litBuffer, litSize, istart+lhSize, litCSize, - dctx->HUFptr, ZSTD_DCtx_get_bmi2(dctx)); + dctx->HUFptr, flags); } } else { if (singleStream) { @@ -190,26 +215,28 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, hufSuccess = HUF_decompress1X_DCtx_wksp( dctx->entropy.hufTable, dctx->litBuffer, litSize, istart+lhSize, litCSize, dctx->workspace, - sizeof(dctx->workspace)); + sizeof(dctx->workspace), flags); #else - hufSuccess = HUF_decompress1X1_DCtx_wksp_bmi2( + hufSuccess = HUF_decompress1X1_DCtx_wksp( dctx->entropy.hufTable, dctx->litBuffer, litSize, istart+lhSize, litCSize, dctx->workspace, - sizeof(dctx->workspace), ZSTD_DCtx_get_bmi2(dctx)); + sizeof(dctx->workspace), flags); #endif } else { - hufSuccess = HUF_decompress4X_hufOnly_wksp_bmi2( + hufSuccess = HUF_decompress4X_hufOnly_wksp( dctx->entropy.hufTable, dctx->litBuffer, litSize, istart+lhSize, litCSize, dctx->workspace, - sizeof(dctx->workspace), ZSTD_DCtx_get_bmi2(dctx)); + sizeof(dctx->workspace), flags); } } if (dctx->litBufferLocation == ZSTD_split) { + assert(litSize > ZSTD_LITBUFFEREXTRASIZE); ZSTD_memcpy(dctx->litExtraBuffer, dctx->litBufferEnd - ZSTD_LITBUFFEREXTRASIZE, ZSTD_LITBUFFEREXTRASIZE); ZSTD_memmove(dctx->litBuffer + ZSTD_LITBUFFEREXTRASIZE - WILDCOPY_OVERLENGTH, dctx->litBuffer, litSize - ZSTD_LITBUFFEREXTRASIZE); dctx->litBuffer += ZSTD_LITBUFFEREXTRASIZE - WILDCOPY_OVERLENGTH; dctx->litBufferEnd -= WILDCOPY_OVERLENGTH; + assert(dctx->litBufferEnd <= (BYTE*)dst + blockSizeMax); } RETURN_ERROR_IF(HUF_isError(hufSuccess), corruption_detected, ""); @@ -224,7 +251,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, case set_basic: { size_t litSize, lhSize; U32 const lhlCode = ((istart[0]) >> 2) & 3; - size_t expectedWriteSize = MIN(ZSTD_BLOCKSIZE_MAX, dstCapacity); + size_t expectedWriteSize = MIN(blockSizeMax, dstCapacity); switch(lhlCode) { case 0: case 2: default: /* note : default is impossible, since lhlCode into [0..3] */ @@ -237,11 +264,13 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, break; case 3: lhSize = 3; + RETURN_ERROR_IF(srcSize<3, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 2; here we need lhSize = 3"); litSize = MEM_readLE24(istart) >> 4; break; } RETURN_ERROR_IF(litSize > 0 && dst == NULL, dstSize_tooSmall, "NULL not handled"); + RETURN_ERROR_IF(litSize > blockSizeMax, corruption_detected, ""); RETURN_ERROR_IF(expectedWriteSize < litSize, dstSize_tooSmall, ""); ZSTD_allocateLiteralsBuffer(dctx, dst, dstCapacity, litSize, streaming, expectedWriteSize, 1); if (lhSize+litSize+WILDCOPY_OVERLENGTH > srcSize) { /* risk reading beyond src buffer with wildcopy */ @@ -270,7 +299,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, case set_rle: { U32 const lhlCode = ((istart[0]) >> 2) & 3; size_t litSize, lhSize; - size_t expectedWriteSize = MIN(ZSTD_BLOCKSIZE_MAX, dstCapacity); + size_t expectedWriteSize = MIN(blockSizeMax, dstCapacity); switch(lhlCode) { case 0: case 2: default: /* note : default is impossible, since lhlCode into [0..3] */ @@ -279,16 +308,17 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, break; case 1: lhSize = 2; + RETURN_ERROR_IF(srcSize<3, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 2; here we need lhSize+1 = 3"); litSize = MEM_readLE16(istart) >> 4; break; case 3: lhSize = 3; + RETURN_ERROR_IF(srcSize<4, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 2; here we need lhSize+1 = 4"); litSize = MEM_readLE24(istart) >> 4; - RETURN_ERROR_IF(srcSize<4, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 3; here we need lhSize+1 = 4"); break; } RETURN_ERROR_IF(litSize > 0 && dst == NULL, dstSize_tooSmall, "NULL not handled"); - RETURN_ERROR_IF(litSize > ZSTD_BLOCKSIZE_MAX, corruption_detected, ""); + RETURN_ERROR_IF(litSize > blockSizeMax, corruption_detected, ""); RETURN_ERROR_IF(expectedWriteSize < litSize, dstSize_tooSmall, ""); ZSTD_allocateLiteralsBuffer(dctx, dst, dstCapacity, litSize, streaming, expectedWriteSize, 1); if (dctx->litBufferLocation == ZSTD_split) @@ -310,6 +340,18 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx, } } +/* Hidden declaration for fullbench */ +size_t ZSTD_decodeLiteralsBlock_wrapper(ZSTD_DCtx* dctx, + const void* src, size_t srcSize, + void* dst, size_t dstCapacity); +size_t ZSTD_decodeLiteralsBlock_wrapper(ZSTD_DCtx* dctx, + const void* src, size_t srcSize, + void* dst, size_t dstCapacity) +{ + dctx->isFrameDecompression = 0; + return ZSTD_decodeLiteralsBlock(dctx, src, srcSize, dst, dstCapacity, not_streaming); +} + /* Default FSE distribution tables. * These are pre-calculated FSE decoding tables using default distributions as defined in specification : * https://github.com/facebook/zstd/blob/release/doc/zstd_compression_format.md#default-distributions @@ -506,14 +548,15 @@ void ZSTD_buildFSETable_body(ZSTD_seqSymbol* dt, for (i = 8; i < n; i += 8) { MEM_write64(spread + pos + i, sv); } - pos += n; + assert(n>=0); + pos += (size_t)n; } } /* Now we spread those positions across the table. - * The benefit of doing it in two stages is that we avoid the the + * The benefit of doing it in two stages is that we avoid the * variable size inner loop, which caused lots of branch misses. * Now we can run through all the positions without any branch misses. - * We unroll the loop twice, since that is what emperically worked best. + * We unroll the loop twice, since that is what empirically worked best. */ { size_t position = 0; @@ -540,7 +583,7 @@ void ZSTD_buildFSETable_body(ZSTD_seqSymbol* dt, for (i=0; i highThreshold) position = (position + step) & tableMask; /* lowprob area */ + while (UNLIKELY(position > highThreshold)) position = (position + step) & tableMask; /* lowprob area */ } } assert(position == 0); /* position must reach all cells once, otherwise normalizedCounter is incorrect */ } @@ -551,7 +594,7 @@ void ZSTD_buildFSETable_body(ZSTD_seqSymbol* dt, for (u=0; u 0x7F) { if (nbSeq == 0xFF) { RETURN_ERROR_IF(ip+2 > iend, srcSize_wrong, ""); @@ -681,8 +719,16 @@ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr, } *nbSeqPtr = nbSeq; + if (nbSeq == 0) { + /* No sequence : section ends immediately */ + RETURN_ERROR_IF(ip != iend, corruption_detected, + "extraneous data present in the Sequences section"); + return (size_t)(ip - istart); + } + /* FSE table descriptors */ RETURN_ERROR_IF(ip+1 > iend, srcSize_wrong, ""); /* minimum possible size: 1 byte for symbol encoding types */ + RETURN_ERROR_IF(*ip & 3, corruption_detected, ""); /* The last field, Reserved, must be all-zeroes. */ { symbolEncodingType_e const LLtype = (symbolEncodingType_e)(*ip >> 6); symbolEncodingType_e const OFtype = (symbolEncodingType_e)((*ip >> 4) & 3); symbolEncodingType_e const MLtype = (symbolEncodingType_e)((*ip >> 2) & 3); @@ -829,7 +875,7 @@ static void ZSTD_safecopy(BYTE* op, const BYTE* const oend_w, BYTE const* ip, pt /* ZSTD_safecopyDstBeforeSrc(): * This version allows overlap with dst before src, or handles the non-overlap case with dst after src * Kept separate from more common ZSTD_safecopy case to avoid performance impact to the safecopy common case */ -static void ZSTD_safecopyDstBeforeSrc(BYTE* op, BYTE const* ip, ptrdiff_t length) { +static void ZSTD_safecopyDstBeforeSrc(BYTE* op, const BYTE* ip, ptrdiff_t length) { ptrdiff_t const diff = op - ip; BYTE* const oend = op + length; @@ -858,6 +904,7 @@ static void ZSTD_safecopyDstBeforeSrc(BYTE* op, BYTE const* ip, ptrdiff_t length * to be optimized for many small sequences, since those fall into ZSTD_execSequence(). */ FORCE_NOINLINE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_execSequenceEnd(BYTE* op, BYTE* const oend, seq_t sequence, const BYTE** litPtr, const BYTE* const litLimit, @@ -905,6 +952,7 @@ size_t ZSTD_execSequenceEnd(BYTE* op, * This version is intended to be used during instances where the litBuffer is still split. It is kept separate to avoid performance impact for the good case. */ FORCE_NOINLINE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_execSequenceEndSplitLitBuffer(BYTE* op, BYTE* const oend, const BYTE* const oend_w, seq_t sequence, const BYTE** litPtr, const BYTE* const litLimit, @@ -950,6 +998,7 @@ size_t ZSTD_execSequenceEndSplitLitBuffer(BYTE* op, } HINT_INLINE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_execSequence(BYTE* op, BYTE* const oend, seq_t sequence, const BYTE** litPtr, const BYTE* const litLimit, @@ -964,6 +1013,11 @@ size_t ZSTD_execSequence(BYTE* op, assert(op != NULL /* Precondition */); assert(oend_w < oend /* No underflow */); + +#if defined(__aarch64__) + /* prefetch sequence starting from match that will be used for copy later */ + PREFETCH_L1(match); +#endif /* Handle edge cases in a slow path: * - Read beyond end of literals * - Match end is within WILDCOPY_OVERLIMIT of oend @@ -1043,6 +1097,7 @@ size_t ZSTD_execSequence(BYTE* op, } HINT_INLINE +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR size_t ZSTD_execSequenceSplitLitBuffer(BYTE* op, BYTE* const oend, const BYTE* const oend_w, seq_t sequence, const BYTE** litPtr, const BYTE* const litLimit, @@ -1154,7 +1209,7 @@ ZSTD_updateFseStateWithDInfo(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD, U16 } /* We need to add at most (ZSTD_WINDOWLOG_MAX_32 - 1) bits to read the maximum - * offset bits. But we can only read at most (STREAM_ACCUMULATOR_MIN_32 - 1) + * offset bits. But we can only read at most STREAM_ACCUMULATOR_MIN_32 * bits before reloading. This value is the maximum number of bytes we read * after reloading when we are decoding long offsets. */ @@ -1165,13 +1220,37 @@ ZSTD_updateFseStateWithDInfo(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD, U16 typedef enum { ZSTD_lo_isRegularOffset, ZSTD_lo_isLongOffset=1 } ZSTD_longOffset_e; +/* + * ZSTD_decodeSequence(): + * @p longOffsets : tells the decoder to reload more bit while decoding large offsets + * only used in 32-bit mode + * @return : Sequence (litL + matchL + offset) + */ FORCE_INLINE_TEMPLATE seq_t -ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) +ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets, const int isLastSeq) { seq_t seq; + /* + * ZSTD_seqSymbol is a 64 bits wide structure. + * It can be loaded in one operation + * and its fields extracted by simply shifting or bit-extracting on aarch64. + * GCC doesn't recognize this and generates more unnecessary ldr/ldrb/ldrh + * operations that cause performance drop. This can be avoided by using this + * ZSTD_memcpy hack. + */ +#if defined(__aarch64__) && (defined(__GNUC__) && !defined(__clang__)) + ZSTD_seqSymbol llDInfoS, mlDInfoS, ofDInfoS; + ZSTD_seqSymbol* const llDInfo = &llDInfoS; + ZSTD_seqSymbol* const mlDInfo = &mlDInfoS; + ZSTD_seqSymbol* const ofDInfo = &ofDInfoS; + ZSTD_memcpy(llDInfo, seqState->stateLL.table + seqState->stateLL.state, sizeof(ZSTD_seqSymbol)); + ZSTD_memcpy(mlDInfo, seqState->stateML.table + seqState->stateML.state, sizeof(ZSTD_seqSymbol)); + ZSTD_memcpy(ofDInfo, seqState->stateOffb.table + seqState->stateOffb.state, sizeof(ZSTD_seqSymbol)); +#else const ZSTD_seqSymbol* const llDInfo = seqState->stateLL.table + seqState->stateLL.state; const ZSTD_seqSymbol* const mlDInfo = seqState->stateML.table + seqState->stateML.state; const ZSTD_seqSymbol* const ofDInfo = seqState->stateOffb.table + seqState->stateOffb.state; +#endif seq.matchLength = mlDInfo->baseValue; seq.litLength = llDInfo->baseValue; { U32 const ofBase = ofDInfo->baseValue; @@ -1186,28 +1265,31 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) U32 const llnbBits = llDInfo->nbBits; U32 const mlnbBits = mlDInfo->nbBits; U32 const ofnbBits = ofDInfo->nbBits; + + assert(llBits <= MaxLLBits); + assert(mlBits <= MaxMLBits); + assert(ofBits <= MaxOff); /* * As gcc has better branch and block analyzers, sometimes it is only - * valuable to mark likelyness for clang, it gives around 3-4% of + * valuable to mark likeliness for clang, it gives around 3-4% of * performance. */ /* sequence */ { size_t offset; - #if defined(__clang__) - if (LIKELY(ofBits > 1)) { - #else if (ofBits > 1) { - #endif ZSTD_STATIC_ASSERT(ZSTD_lo_isLongOffset == 1); ZSTD_STATIC_ASSERT(LONG_OFFSETS_MAX_EXTRA_BITS_32 == 5); - assert(ofBits <= MaxOff); + ZSTD_STATIC_ASSERT(STREAM_ACCUMULATOR_MIN_32 > LONG_OFFSETS_MAX_EXTRA_BITS_32); + ZSTD_STATIC_ASSERT(STREAM_ACCUMULATOR_MIN_32 - LONG_OFFSETS_MAX_EXTRA_BITS_32 >= MaxMLBits); if (MEM_32bits() && longOffsets && (ofBits >= STREAM_ACCUMULATOR_MIN_32)) { - U32 const extraBits = ofBits - MIN(ofBits, 32 - seqState->DStream.bitsConsumed); + /* Always read extra bits, this keeps the logic simple, + * avoids branches, and avoids accidentally reading 0 bits. + */ + U32 const extraBits = LONG_OFFSETS_MAX_EXTRA_BITS_32; offset = ofBase + (BIT_readBitsFast(&seqState->DStream, ofBits - extraBits) << extraBits); BIT_reloadDStream(&seqState->DStream); - if (extraBits) offset += BIT_readBitsFast(&seqState->DStream, extraBits); - assert(extraBits <= LONG_OFFSETS_MAX_EXTRA_BITS_32); /* to avoid another reload */ + offset += BIT_readBitsFast(&seqState->DStream, extraBits); } else { offset = ofBase + BIT_readBitsFast(&seqState->DStream, ofBits/*>0*/); /* <= (ZSTD_WINDOWLOG_MAX-1) bits */ if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream); @@ -1224,7 +1306,7 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) } else { offset = ofBase + ll0 + BIT_readBitsFast(&seqState->DStream, 1); { size_t temp = (offset==3) ? seqState->prevOffset[0] - 1 : seqState->prevOffset[offset]; - temp += !temp; /* 0 is not valid; input is corrupted; force offset to 1 */ + temp -= !temp; /* 0 is not valid: input corrupted => force offset to -1 => corruption detected at execSequence */ if (offset != 1) seqState->prevOffset[2] = seqState->prevOffset[1]; seqState->prevOffset[1] = seqState->prevOffset[0]; seqState->prevOffset[0] = offset = temp; @@ -1232,11 +1314,7 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) seq.offset = offset; } - #if defined(__clang__) - if (UNLIKELY(mlBits > 0)) - #else if (mlBits > 0) - #endif seq.matchLength += BIT_readBitsFast(&seqState->DStream, mlBits/*>0*/); if (MEM_32bits() && (mlBits+llBits >= STREAM_ACCUMULATOR_MIN_32-LONG_OFFSETS_MAX_EXTRA_BITS_32)) @@ -1246,11 +1324,7 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) /* Ensure there are enough bits to read the rest of data in 64-bit mode. */ ZSTD_STATIC_ASSERT(16+LLFSELog+MLFSELog+OffFSELog < STREAM_ACCUMULATOR_MIN_64); - #if defined(__clang__) - if (UNLIKELY(llBits > 0)) - #else if (llBits > 0) - #endif seq.litLength += BIT_readBitsFast(&seqState->DStream, llBits/*>0*/); if (MEM_32bits()) @@ -1259,17 +1333,22 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets) DEBUGLOG(6, "seq: litL=%u, matchL=%u, offset=%u", (U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset); - ZSTD_updateFseStateWithDInfo(&seqState->stateLL, &seqState->DStream, llNext, llnbBits); /* <= 9 bits */ - ZSTD_updateFseStateWithDInfo(&seqState->stateML, &seqState->DStream, mlNext, mlnbBits); /* <= 9 bits */ - if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream); /* <= 18 bits */ - ZSTD_updateFseStateWithDInfo(&seqState->stateOffb, &seqState->DStream, ofNext, ofnbBits); /* <= 8 bits */ + if (!isLastSeq) { + /* don't update FSE state for last Sequence */ + ZSTD_updateFseStateWithDInfo(&seqState->stateLL, &seqState->DStream, llNext, llnbBits); /* <= 9 bits */ + ZSTD_updateFseStateWithDInfo(&seqState->stateML, &seqState->DStream, mlNext, mlnbBits); /* <= 9 bits */ + if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream); /* <= 18 bits */ + ZSTD_updateFseStateWithDInfo(&seqState->stateOffb, &seqState->DStream, ofNext, ofnbBits); /* <= 8 bits */ + BIT_reloadDStream(&seqState->DStream); + } } return seq; } -#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION -MEM_STATIC int ZSTD_dictionaryIsActive(ZSTD_DCtx const* dctx, BYTE const* prefixStart, BYTE const* oLitEnd) +#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) +#if DEBUGLEVEL >= 1 +static int ZSTD_dictionaryIsActive(ZSTD_DCtx const* dctx, BYTE const* prefixStart, BYTE const* oLitEnd) { size_t const windowSize = dctx->fParams.windowSize; /* No dictionary used. */ @@ -1283,30 +1362,33 @@ MEM_STATIC int ZSTD_dictionaryIsActive(ZSTD_DCtx const* dctx, BYTE const* prefix /* Dictionary is active. */ return 1; } +#endif -MEM_STATIC void ZSTD_assertValidSequence( +static void ZSTD_assertValidSequence( ZSTD_DCtx const* dctx, BYTE const* op, BYTE const* oend, seq_t const seq, BYTE const* prefixStart, BYTE const* virtualStart) { #if DEBUGLEVEL >= 1 - size_t const windowSize = dctx->fParams.windowSize; - size_t const sequenceSize = seq.litLength + seq.matchLength; - BYTE const* const oLitEnd = op + seq.litLength; - DEBUGLOG(6, "Checking sequence: litL=%u matchL=%u offset=%u", - (U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset); - assert(op <= oend); - assert((size_t)(oend - op) >= sequenceSize); - assert(sequenceSize <= ZSTD_BLOCKSIZE_MAX); - if (ZSTD_dictionaryIsActive(dctx, prefixStart, oLitEnd)) { - size_t const dictSize = (size_t)((char const*)dctx->dictContentEndForFuzzing - (char const*)dctx->dictContentBeginForFuzzing); - /* Offset must be within the dictionary. */ - assert(seq.offset <= (size_t)(oLitEnd - virtualStart)); - assert(seq.offset <= windowSize + dictSize); - } else { - /* Offset must be within our window. */ - assert(seq.offset <= windowSize); + if (dctx->isFrameDecompression) { + size_t const windowSize = dctx->fParams.windowSize; + size_t const sequenceSize = seq.litLength + seq.matchLength; + BYTE const* const oLitEnd = op + seq.litLength; + DEBUGLOG(6, "Checking sequence: litL=%u matchL=%u offset=%u", + (U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset); + assert(op <= oend); + assert((size_t)(oend - op) >= sequenceSize); + assert(sequenceSize <= ZSTD_blockSizeMax(dctx)); + if (ZSTD_dictionaryIsActive(dctx, prefixStart, oLitEnd)) { + size_t const dictSize = (size_t)((char const*)dctx->dictContentEndForFuzzing - (char const*)dctx->dictContentBeginForFuzzing); + /* Offset must be within the dictionary. */ + assert(seq.offset <= (size_t)(oLitEnd - virtualStart)); + assert(seq.offset <= windowSize + dictSize); + } else { + /* Offset must be within our window. */ + assert(seq.offset <= windowSize); + } } #else (void)dctx, (void)op, (void)oend, (void)seq, (void)prefixStart, (void)virtualStart; @@ -1322,23 +1404,21 @@ DONT_VECTORIZE ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { const BYTE* ip = (const BYTE*)seqStart; const BYTE* const iend = ip + seqSize; BYTE* const ostart = (BYTE*)dst; - BYTE* const oend = ostart + maxDstSize; + BYTE* const oend = ZSTD_maybeNullPtrAdd(ostart, maxDstSize); BYTE* op = ostart; const BYTE* litPtr = dctx->litPtr; const BYTE* litBufferEnd = dctx->litBufferEnd; const BYTE* const prefixStart = (const BYTE*) (dctx->prefixStart); const BYTE* const vBase = (const BYTE*) (dctx->virtualStart); const BYTE* const dictEnd = (const BYTE*) (dctx->dictEnd); - DEBUGLOG(5, "ZSTD_decompressSequences_bodySplitLitBuffer"); - (void)frame; + DEBUGLOG(5, "ZSTD_decompressSequences_bodySplitLitBuffer (%i seqs)", nbSeq); - /* Regen sequences */ + /* Literals are split between internal buffer & output buffer */ if (nbSeq) { seqState_t seqState; dctx->fseEntropy = 1; @@ -1357,8 +1437,7 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, BIT_DStream_completed < BIT_DStream_overflow); /* decompress without overrunning litPtr begins */ - { - seq_t sequence = ZSTD_decodeSequence(&seqState, isLongOffset); + { seq_t sequence = {0,0,0}; /* some static analyzer believe that @sequence is not initialized (it necessarily is, since for(;;) loop as at least one iteration) */ /* Align the decompression loop to 32 + 16 bytes. * * zstd compiled with gcc-9 on an Intel i9-9900k shows 10% decompression @@ -1420,27 +1499,26 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, #endif /* Handle the initial state where litBuffer is currently split between dst and litExtraBuffer */ - for (; litPtr + sequence.litLength <= dctx->litBufferEnd; ) { - size_t const oneSeqSize = ZSTD_execSequenceSplitLitBuffer(op, oend, litPtr + sequence.litLength - WILDCOPY_OVERLENGTH, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); + for ( ; nbSeq; nbSeq--) { + sequence = ZSTD_decodeSequence(&seqState, isLongOffset, nbSeq==1); + if (litPtr + sequence.litLength > dctx->litBufferEnd) break; + { size_t const oneSeqSize = ZSTD_execSequenceSplitLitBuffer(op, oend, litPtr + sequence.litLength - WILDCOPY_OVERLENGTH, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) - assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); + assert(!ZSTD_isError(oneSeqSize)); + ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); #endif - if (UNLIKELY(ZSTD_isError(oneSeqSize))) - return oneSeqSize; - DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); - op += oneSeqSize; - if (UNLIKELY(!--nbSeq)) - break; - BIT_reloadDStream(&(seqState.DStream)); - sequence = ZSTD_decodeSequence(&seqState, isLongOffset); - } + if (UNLIKELY(ZSTD_isError(oneSeqSize))) + return oneSeqSize; + DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); + op += oneSeqSize; + } } + DEBUGLOG(6, "reached: (litPtr + sequence.litLength > dctx->litBufferEnd)"); /* If there are more sequences, they will need to read literals from litExtraBuffer; copy over the remainder from dst and update litPtr and litEnd */ if (nbSeq > 0) { const size_t leftoverLit = dctx->litBufferEnd - litPtr; - if (leftoverLit) - { + DEBUGLOG(6, "There are %i sequences left, and %zu/%zu literals left in buffer", nbSeq, leftoverLit, sequence.litLength); + if (leftoverLit) { RETURN_ERROR_IF(leftoverLit > (size_t)(oend - op), dstSize_tooSmall, "remaining lit must fit within dstBuffer"); ZSTD_safecopyDstBeforeSrc(op, litPtr, leftoverLit); sequence.litLength -= leftoverLit; @@ -1449,24 +1527,22 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, litPtr = dctx->litExtraBuffer; litBufferEnd = dctx->litExtraBuffer + ZSTD_LITBUFFEREXTRASIZE; dctx->litBufferLocation = ZSTD_not_in_dst; - { - size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); + { size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); + ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); #endif if (UNLIKELY(ZSTD_isError(oneSeqSize))) return oneSeqSize; DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); op += oneSeqSize; - if (--nbSeq) - BIT_reloadDStream(&(seqState.DStream)); } + nbSeq--; } } - if (nbSeq > 0) /* there is remaining lit from extra buffer */ - { + if (nbSeq > 0) { + /* there is remaining lit from extra buffer */ #if defined(__x86_64__) __asm__(".p2align 6"); @@ -1485,35 +1561,34 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, # endif #endif - for (; ; ) { - seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset); + for ( ; nbSeq ; nbSeq--) { + seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset, nbSeq==1); size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); + ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); #endif if (UNLIKELY(ZSTD_isError(oneSeqSize))) return oneSeqSize; DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); op += oneSeqSize; - if (UNLIKELY(!--nbSeq)) - break; - BIT_reloadDStream(&(seqState.DStream)); } } /* check if reached exact end */ DEBUGLOG(5, "ZSTD_decompressSequences_bodySplitLitBuffer: after decode loop, remaining nbSeq : %i", nbSeq); RETURN_ERROR_IF(nbSeq, corruption_detected, ""); - RETURN_ERROR_IF(BIT_reloadDStream(&seqState.DStream) < BIT_DStream_completed, corruption_detected, ""); + DEBUGLOG(5, "bitStream : start=%p, ptr=%p, bitsConsumed=%u", seqState.DStream.start, seqState.DStream.ptr, seqState.DStream.bitsConsumed); + RETURN_ERROR_IF(!BIT_endOfDStream(&seqState.DStream), corruption_detected, ""); /* save reps for next block */ { U32 i; for (i=0; ientropy.rep[i] = (U32)(seqState.prevOffset[i]); } } /* last literal segment */ - if (dctx->litBufferLocation == ZSTD_split) /* split hasn't been reached yet, first get dst then copy litExtraBuffer */ - { - size_t const lastLLSize = litBufferEnd - litPtr; + if (dctx->litBufferLocation == ZSTD_split) { + /* split hasn't been reached yet, first get dst then copy litExtraBuffer */ + size_t const lastLLSize = (size_t)(litBufferEnd - litPtr); + DEBUGLOG(6, "copy last literals from segment : %u", (U32)lastLLSize); RETURN_ERROR_IF(lastLLSize > (size_t)(oend - op), dstSize_tooSmall, ""); if (op != NULL) { ZSTD_memmove(op, litPtr, lastLLSize); @@ -1523,15 +1598,17 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx, litBufferEnd = dctx->litExtraBuffer + ZSTD_LITBUFFEREXTRASIZE; dctx->litBufferLocation = ZSTD_not_in_dst; } - { size_t const lastLLSize = litBufferEnd - litPtr; + /* copy last literals from internal buffer */ + { size_t const lastLLSize = (size_t)(litBufferEnd - litPtr); + DEBUGLOG(6, "copy last literals from internal buffer : %u", (U32)lastLLSize); RETURN_ERROR_IF(lastLLSize > (size_t)(oend-op), dstSize_tooSmall, ""); if (op != NULL) { ZSTD_memcpy(op, litPtr, lastLLSize); op += lastLLSize; - } - } + } } - return op-ostart; + DEBUGLOG(6, "decoded block of size %u bytes", (U32)(op - ostart)); + return (size_t)(op - ostart); } FORCE_INLINE_TEMPLATE size_t @@ -1539,21 +1616,19 @@ DONT_VECTORIZE ZSTD_decompressSequences_body(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { const BYTE* ip = (const BYTE*)seqStart; const BYTE* const iend = ip + seqSize; BYTE* const ostart = (BYTE*)dst; - BYTE* const oend = dctx->litBufferLocation == ZSTD_not_in_dst ? ostart + maxDstSize : dctx->litBuffer; + BYTE* const oend = dctx->litBufferLocation == ZSTD_not_in_dst ? ZSTD_maybeNullPtrAdd(ostart, maxDstSize) : dctx->litBuffer; BYTE* op = ostart; const BYTE* litPtr = dctx->litPtr; const BYTE* const litEnd = litPtr + dctx->litSize; const BYTE* const prefixStart = (const BYTE*)(dctx->prefixStart); const BYTE* const vBase = (const BYTE*)(dctx->virtualStart); const BYTE* const dictEnd = (const BYTE*)(dctx->dictEnd); - DEBUGLOG(5, "ZSTD_decompressSequences_body"); - (void)frame; + DEBUGLOG(5, "ZSTD_decompressSequences_body: nbSeq = %d", nbSeq); /* Regen sequences */ if (nbSeq) { @@ -1568,11 +1643,6 @@ ZSTD_decompressSequences_body(ZSTD_DCtx* dctx, ZSTD_initFseState(&seqState.stateML, &seqState.DStream, dctx->MLTptr); assert(dst != NULL); - ZSTD_STATIC_ASSERT( - BIT_DStream_unfinished < BIT_DStream_completed && - BIT_DStream_endOfBuffer < BIT_DStream_completed && - BIT_DStream_completed < BIT_DStream_overflow); - #if defined(__x86_64__) __asm__(".p2align 6"); __asm__("nop"); @@ -1587,73 +1657,70 @@ ZSTD_decompressSequences_body(ZSTD_DCtx* dctx, # endif #endif - for ( ; ; ) { - seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset); + for ( ; nbSeq ; nbSeq--) { + seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset, nbSeq==1); size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litEnd, prefixStart, vBase, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); + ZSTD_assertValidSequence(dctx, op, oend, sequence, prefixStart, vBase); #endif if (UNLIKELY(ZSTD_isError(oneSeqSize))) return oneSeqSize; DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize); op += oneSeqSize; - if (UNLIKELY(!--nbSeq)) - break; - BIT_reloadDStream(&(seqState.DStream)); } /* check if reached exact end */ - DEBUGLOG(5, "ZSTD_decompressSequences_body: after decode loop, remaining nbSeq : %i", nbSeq); - RETURN_ERROR_IF(nbSeq, corruption_detected, ""); - RETURN_ERROR_IF(BIT_reloadDStream(&seqState.DStream) < BIT_DStream_completed, corruption_detected, ""); + assert(nbSeq == 0); + RETURN_ERROR_IF(!BIT_endOfDStream(&seqState.DStream), corruption_detected, ""); /* save reps for next block */ { U32 i; for (i=0; ientropy.rep[i] = (U32)(seqState.prevOffset[i]); } } /* last literal segment */ - { size_t const lastLLSize = litEnd - litPtr; + { size_t const lastLLSize = (size_t)(litEnd - litPtr); + DEBUGLOG(6, "copy last literals : %u", (U32)lastLLSize); RETURN_ERROR_IF(lastLLSize > (size_t)(oend-op), dstSize_tooSmall, ""); if (op != NULL) { ZSTD_memcpy(op, litPtr, lastLLSize); op += lastLLSize; - } - } + } } - return op-ostart; + DEBUGLOG(6, "decoded block of size %u bytes", (U32)(op - ostart)); + return (size_t)(op - ostart); } static size_t ZSTD_decompressSequences_default(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { - return ZSTD_decompressSequences_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } static size_t ZSTD_decompressSequencesSplitLitBuffer_default(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { - return ZSTD_decompressSequences_bodySplitLitBuffer(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences_bodySplitLitBuffer(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG */ #ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT -FORCE_INLINE_TEMPLATE size_t -ZSTD_prefetchMatch(size_t prefetchPos, seq_t const sequence, +FORCE_INLINE_TEMPLATE + +size_t ZSTD_prefetchMatch(size_t prefetchPos, seq_t const sequence, const BYTE* const prefixStart, const BYTE* const dictEnd) { prefetchPos += sequence.litLength; { const BYTE* const matchBase = (sequence.offset > prefetchPos) ? dictEnd : prefixStart; - const BYTE* const match = matchBase + prefetchPos - sequence.offset; /* note : this operation can overflow when seq.offset is really too large, which can only happen when input is corrupted. - * No consequence though : memory address is only used for prefetching, not for dereferencing */ + /* note : this operation can overflow when seq.offset is really too large, which can only happen when input is corrupted. + * No consequence though : memory address is only used for prefetching, not for dereferencing */ + const BYTE* const match = ZSTD_wrappedPtrSub(ZSTD_wrappedPtrAdd(matchBase, prefetchPos), sequence.offset); PREFETCH_L1(match); PREFETCH_L1(match+CACHELINE_SIZE); /* note : it's safe to invoke PREFETCH() on any memory address, including invalid ones */ } return prefetchPos + sequence.matchLength; @@ -1668,20 +1735,18 @@ ZSTD_decompressSequencesLong_body( ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { const BYTE* ip = (const BYTE*)seqStart; const BYTE* const iend = ip + seqSize; BYTE* const ostart = (BYTE*)dst; - BYTE* const oend = dctx->litBufferLocation == ZSTD_in_dst ? dctx->litBuffer : ostart + maxDstSize; + BYTE* const oend = dctx->litBufferLocation == ZSTD_in_dst ? dctx->litBuffer : ZSTD_maybeNullPtrAdd(ostart, maxDstSize); BYTE* op = ostart; const BYTE* litPtr = dctx->litPtr; const BYTE* litBufferEnd = dctx->litBufferEnd; const BYTE* const prefixStart = (const BYTE*) (dctx->prefixStart); const BYTE* const dictStart = (const BYTE*) (dctx->virtualStart); const BYTE* const dictEnd = (const BYTE*) (dctx->dictEnd); - (void)frame; /* Regen sequences */ if (nbSeq) { @@ -1706,20 +1771,17 @@ ZSTD_decompressSequencesLong_body( ZSTD_initFseState(&seqState.stateML, &seqState.DStream, dctx->MLTptr); /* prepare in advance */ - for (seqNb=0; (BIT_reloadDStream(&seqState.DStream) <= BIT_DStream_completed) && (seqNblitBufferLocation == ZSTD_split && litPtr + sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK].litLength > dctx->litBufferEnd) - { + if (dctx->litBufferLocation == ZSTD_split && litPtr + sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK].litLength > dctx->litBufferEnd) { /* lit buffer is reaching split point, empty out the first buffer and transition to litExtraBuffer */ const size_t leftoverLit = dctx->litBufferEnd - litPtr; if (leftoverLit) @@ -1732,26 +1794,26 @@ ZSTD_decompressSequencesLong_body( litPtr = dctx->litExtraBuffer; litBufferEnd = dctx->litExtraBuffer + ZSTD_LITBUFFEREXTRASIZE; dctx->litBufferLocation = ZSTD_not_in_dst; - oneSeqSize = ZSTD_execSequence(op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); + { size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) - assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart); + assert(!ZSTD_isError(oneSeqSize)); + ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart); #endif - if (ZSTD_isError(oneSeqSize)) return oneSeqSize; + if (ZSTD_isError(oneSeqSize)) return oneSeqSize; - prefetchPos = ZSTD_prefetchMatch(prefetchPos, sequence, prefixStart, dictEnd); - sequences[seqNb & STORED_SEQS_MASK] = sequence; - op += oneSeqSize; - } + prefetchPos = ZSTD_prefetchMatch(prefetchPos, sequence, prefixStart, dictEnd); + sequences[seqNb & STORED_SEQS_MASK] = sequence; + op += oneSeqSize; + } } else { /* lit buffer is either wholly contained in first or second split, or not split at all*/ - oneSeqSize = dctx->litBufferLocation == ZSTD_split ? + size_t const oneSeqSize = dctx->litBufferLocation == ZSTD_split ? ZSTD_execSequenceSplitLitBuffer(op, oend, litPtr + sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK].litLength - WILDCOPY_OVERLENGTH, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd) : ZSTD_execSequence(op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart); + ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart); #endif if (ZSTD_isError(oneSeqSize)) return oneSeqSize; @@ -1760,17 +1822,15 @@ ZSTD_decompressSequencesLong_body( op += oneSeqSize; } } - RETURN_ERROR_IF(seqNblitBufferLocation == ZSTD_split && litPtr + sequence->litLength > dctx->litBufferEnd) - { + if (dctx->litBufferLocation == ZSTD_split && litPtr + sequence->litLength > dctx->litBufferEnd) { const size_t leftoverLit = dctx->litBufferEnd - litPtr; - if (leftoverLit) - { + if (leftoverLit) { RETURN_ERROR_IF(leftoverLit > (size_t)(oend - op), dstSize_tooSmall, "remaining lit must fit within dstBuffer"); ZSTD_safecopyDstBeforeSrc(op, litPtr, leftoverLit); sequence->litLength -= leftoverLit; @@ -1779,11 +1839,10 @@ ZSTD_decompressSequencesLong_body( litPtr = dctx->litExtraBuffer; litBufferEnd = dctx->litExtraBuffer + ZSTD_LITBUFFEREXTRASIZE; dctx->litBufferLocation = ZSTD_not_in_dst; - { - size_t const oneSeqSize = ZSTD_execSequence(op, oend, *sequence, &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); + { size_t const oneSeqSize = ZSTD_execSequence(op, oend, *sequence, &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[seqNb&STORED_SEQS_MASK], prefixStart, dictStart); + ZSTD_assertValidSequence(dctx, op, oend, sequences[seqNb&STORED_SEQS_MASK], prefixStart, dictStart); #endif if (ZSTD_isError(oneSeqSize)) return oneSeqSize; op += oneSeqSize; @@ -1796,7 +1855,7 @@ ZSTD_decompressSequencesLong_body( ZSTD_execSequence(op, oend, *sequence, &litPtr, litBufferEnd, prefixStart, dictStart, dictEnd); #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) assert(!ZSTD_isError(oneSeqSize)); - if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[seqNb&STORED_SEQS_MASK], prefixStart, dictStart); + ZSTD_assertValidSequence(dctx, op, oend, sequences[seqNb&STORED_SEQS_MASK], prefixStart, dictStart); #endif if (ZSTD_isError(oneSeqSize)) return oneSeqSize; op += oneSeqSize; @@ -1808,8 +1867,7 @@ ZSTD_decompressSequencesLong_body( } /* last literal segment */ - if (dctx->litBufferLocation == ZSTD_split) /* first deplete literal buffer in dst, then copy litExtraBuffer */ - { + if (dctx->litBufferLocation == ZSTD_split) { /* first deplete literal buffer in dst, then copy litExtraBuffer */ size_t const lastLLSize = litBufferEnd - litPtr; RETURN_ERROR_IF(lastLLSize > (size_t)(oend - op), dstSize_tooSmall, ""); if (op != NULL) { @@ -1827,17 +1885,16 @@ ZSTD_decompressSequencesLong_body( } } - return op-ostart; + return (size_t)(op - ostart); } static size_t ZSTD_decompressSequencesLong_default(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { - return ZSTD_decompressSequencesLong_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesLong_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT */ @@ -1851,20 +1908,18 @@ DONT_VECTORIZE ZSTD_decompressSequences_bmi2(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { - return ZSTD_decompressSequences_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } static BMI2_TARGET_ATTRIBUTE size_t DONT_VECTORIZE ZSTD_decompressSequencesSplitLitBuffer_bmi2(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { - return ZSTD_decompressSequences_bodySplitLitBuffer(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences_bodySplitLitBuffer(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG */ @@ -1873,10 +1928,9 @@ static BMI2_TARGET_ATTRIBUTE size_t ZSTD_decompressSequencesLong_bmi2(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { - return ZSTD_decompressSequencesLong_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesLong_body(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT */ @@ -1886,37 +1940,34 @@ typedef size_t (*ZSTD_decompressSequences_t)( ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame); + const ZSTD_longOffset_e isLongOffset); #ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG static size_t ZSTD_decompressSequences(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { DEBUGLOG(5, "ZSTD_decompressSequences"); #if DYNAMIC_BMI2 if (ZSTD_DCtx_get_bmi2(dctx)) { - return ZSTD_decompressSequences_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif - return ZSTD_decompressSequences_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } static size_t ZSTD_decompressSequencesSplitLitBuffer(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { DEBUGLOG(5, "ZSTD_decompressSequencesSplitLitBuffer"); #if DYNAMIC_BMI2 if (ZSTD_DCtx_get_bmi2(dctx)) { - return ZSTD_decompressSequencesSplitLitBuffer_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesSplitLitBuffer_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif - return ZSTD_decompressSequencesSplitLitBuffer_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesSplitLitBuffer_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG */ @@ -1931,69 +1982,114 @@ static size_t ZSTD_decompressSequencesLong(ZSTD_DCtx* dctx, void* dst, size_t maxDstSize, const void* seqStart, size_t seqSize, int nbSeq, - const ZSTD_longOffset_e isLongOffset, - const int frame) + const ZSTD_longOffset_e isLongOffset) { DEBUGLOG(5, "ZSTD_decompressSequencesLong"); #if DYNAMIC_BMI2 if (ZSTD_DCtx_get_bmi2(dctx)) { - return ZSTD_decompressSequencesLong_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesLong_bmi2(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif - return ZSTD_decompressSequencesLong_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesLong_default(dctx, dst, maxDstSize, seqStart, seqSize, nbSeq, isLongOffset); } #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT */ +/* + * @returns The total size of the history referenceable by zstd, including + * both the prefix and the extDict. At @p op any offset larger than this + * is invalid. + */ +static size_t ZSTD_totalHistorySize(BYTE* op, BYTE const* virtualStart) +{ + return (size_t)(op - virtualStart); +} + +typedef struct { + unsigned longOffsetShare; + unsigned maxNbAdditionalBits; +} ZSTD_OffsetInfo; -#if !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT) && \ - !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG) -/* ZSTD_getLongOffsetsShare() : +/* ZSTD_getOffsetInfo() : * condition : offTable must be valid * @return : "share" of long offsets (arbitrarily defined as > (1<<23)) - * compared to maximum possible of (1< 22) total += 1; + ZSTD_OffsetInfo info = {0, 0}; + /* If nbSeq == 0, then the offTable is uninitialized, but we have + * no sequences, so both values should be 0. + */ + if (nbSeq != 0) { + const void* ptr = offTable; + U32 const tableLog = ((const ZSTD_seqSymbol_header*)ptr)[0].tableLog; + const ZSTD_seqSymbol* table = offTable + 1; + U32 const max = 1 << tableLog; + U32 u; + DEBUGLOG(5, "ZSTD_getLongOffsetsShare: (tableLog=%u)", tableLog); + + assert(max <= (1 << OffFSELog)); /* max not too large */ + for (u=0; u 22) info.longOffsetShare += 1; + } + + assert(tableLog <= OffFSELog); + info.longOffsetShare <<= (OffFSELog - tableLog); /* scale to OffFSELog */ } - assert(tableLog <= OffFSELog); - total <<= (OffFSELog - tableLog); /* scale to OffFSELog */ + return info; +} - return total; +/* + * @returns The maximum offset we can decode in one read of our bitstream, without + * reloading more bits in the middle of the offset bits read. Any offsets larger + * than this must use the long offset decoder. + */ +static size_t ZSTD_maxShortOffset(void) +{ + if (MEM_64bits()) { + /* We can decode any offset without reloading bits. + * This might change if the max window size grows. + */ + ZSTD_STATIC_ASSERT(ZSTD_WINDOWLOG_MAX <= 31); + return (size_t)-1; + } else { + /* The maximum offBase is (1 << (STREAM_ACCUMULATOR_MIN + 1)) - 1. + * This offBase would require STREAM_ACCUMULATOR_MIN extra bits. + * Then we have to subtract ZSTD_REP_NUM to get the maximum possible offset. + */ + size_t const maxOffbase = ((size_t)1 << (STREAM_ACCUMULATOR_MIN + 1)) - 1; + size_t const maxOffset = maxOffbase - ZSTD_REP_NUM; + assert(ZSTD_highbit32((U32)maxOffbase) == STREAM_ACCUMULATOR_MIN); + return maxOffset; + } } -#endif size_t ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, - const void* src, size_t srcSize, const int frame, const streaming_operation streaming) + const void* src, size_t srcSize, const streaming_operation streaming) { /* blockType == blockCompressed */ const BYTE* ip = (const BYTE*)src; - /* isLongOffset must be true if there are long offsets. - * Offsets are long if they are larger than 2^STREAM_ACCUMULATOR_MIN. - * We don't expect that to be the case in 64-bit mode. - * In block mode, window size is not known, so we have to be conservative. - * (note: but it could be evaluated from current-lowLimit) - */ - ZSTD_longOffset_e const isLongOffset = (ZSTD_longOffset_e)(MEM_32bits() && (!frame || (dctx->fParams.windowSize > (1ULL << STREAM_ACCUMULATOR_MIN)))); - DEBUGLOG(5, "ZSTD_decompressBlock_internal (size : %u)", (U32)srcSize); - - RETURN_ERROR_IF(srcSize >= ZSTD_BLOCKSIZE_MAX, srcSize_wrong, ""); + DEBUGLOG(5, "ZSTD_decompressBlock_internal (cSize : %u)", (unsigned)srcSize); + + /* Note : the wording of the specification + * allows compressed block to be sized exactly ZSTD_blockSizeMax(dctx). + * This generally does not happen, as it makes little sense, + * since an uncompressed block would feature same size and have no decompression cost. + * Also, note that decoder from reference libzstd before < v1.5.4 + * would consider this edge case as an error. + * As a consequence, avoid generating compressed blocks of size ZSTD_blockSizeMax(dctx) + * for broader compatibility with the deployed ecosystem of zstd decoders */ + RETURN_ERROR_IF(srcSize > ZSTD_blockSizeMax(dctx), srcSize_wrong, ""); /* Decode literals section */ { size_t const litCSize = ZSTD_decodeLiteralsBlock(dctx, src, srcSize, dst, dstCapacity, streaming); - DEBUGLOG(5, "ZSTD_decodeLiteralsBlock : %u", (U32)litCSize); + DEBUGLOG(5, "ZSTD_decodeLiteralsBlock : cSize=%u, nbLiterals=%zu", (U32)litCSize, dctx->litSize); if (ZSTD_isError(litCSize)) return litCSize; ip += litCSize; srcSize -= litCSize; @@ -2001,6 +2097,23 @@ ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, /* Build Decoding Tables */ { + /* Compute the maximum block size, which must also work when !frame and fParams are unset. + * Additionally, take the min with dstCapacity to ensure that the totalHistorySize fits in a size_t. + */ + size_t const blockSizeMax = MIN(dstCapacity, ZSTD_blockSizeMax(dctx)); + size_t const totalHistorySize = ZSTD_totalHistorySize(ZSTD_maybeNullPtrAdd((BYTE*)dst, blockSizeMax), (BYTE const*)dctx->virtualStart); + /* isLongOffset must be true if there are long offsets. + * Offsets are long if they are larger than ZSTD_maxShortOffset(). + * We don't expect that to be the case in 64-bit mode. + * + * We check here to see if our history is large enough to allow long offsets. + * If it isn't, then we can't possible have (valid) long offsets. If the offset + * is invalid, then it is okay to read it incorrectly. + * + * If isLongOffsets is true, then we will later check our decoding table to see + * if it is even possible to generate long offsets. + */ + ZSTD_longOffset_e isLongOffset = (ZSTD_longOffset_e)(MEM_32bits() && (totalHistorySize > ZSTD_maxShortOffset())); /* These macros control at build-time which decompressor implementation * we use. If neither is defined, we do some inspection and dispatch at * runtime. @@ -2008,6 +2121,11 @@ ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, #if !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT) && \ !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG) int usePrefetchDecoder = dctx->ddictIsCold; +#else + /* Set to 1 to avoid computing offset info if we don't need to. + * Otherwise this value is ignored. + */ + int usePrefetchDecoder = 1; #endif int nbSeq; size_t const seqHSize = ZSTD_decodeSeqHeaders(dctx, &nbSeq, ip, srcSize); @@ -2015,40 +2133,55 @@ ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, ip += seqHSize; srcSize -= seqHSize; - RETURN_ERROR_IF(dst == NULL && nbSeq > 0, dstSize_tooSmall, "NULL not handled"); + RETURN_ERROR_IF((dst == NULL || dstCapacity == 0) && nbSeq > 0, dstSize_tooSmall, "NULL not handled"); + RETURN_ERROR_IF(MEM_64bits() && sizeof(size_t) == sizeof(void*) && (size_t)(-1) - (size_t)dst < (size_t)(1 << 20), dstSize_tooSmall, + "invalid dst"); -#if !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT) && \ - !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG) - if ( !usePrefetchDecoder - && (!frame || (dctx->fParams.windowSize > (1<<24))) - && (nbSeq>ADVANCED_SEQS) ) { /* could probably use a larger nbSeq limit */ - U32 const shareLongOffsets = ZSTD_getLongOffsetsShare(dctx->OFTptr); - U32 const minShare = MEM_64bits() ? 7 : 20; /* heuristic values, correspond to 2.73% and 7.81% */ - usePrefetchDecoder = (shareLongOffsets >= minShare); + /* If we could potentially have long offsets, or we might want to use the prefetch decoder, + * compute information about the share of long offsets, and the maximum nbAdditionalBits. + * NOTE: could probably use a larger nbSeq limit + */ + if (isLongOffset || (!usePrefetchDecoder && (totalHistorySize > (1u << 24)) && (nbSeq > 8))) { + ZSTD_OffsetInfo const info = ZSTD_getOffsetInfo(dctx->OFTptr, nbSeq); + if (isLongOffset && info.maxNbAdditionalBits <= STREAM_ACCUMULATOR_MIN) { + /* If isLongOffset, but the maximum number of additional bits that we see in our table is small + * enough, then we know it is impossible to have too long an offset in this block, so we can + * use the regular offset decoder. + */ + isLongOffset = ZSTD_lo_isRegularOffset; + } + if (!usePrefetchDecoder) { + U32 const minShare = MEM_64bits() ? 7 : 20; /* heuristic values, correspond to 2.73% and 7.81% */ + usePrefetchDecoder = (info.longOffsetShare >= minShare); + } } -#endif dctx->ddictIsCold = 0; #if !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT) && \ !defined(ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG) - if (usePrefetchDecoder) + if (usePrefetchDecoder) { +#else + (void)usePrefetchDecoder; + { #endif #ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT - return ZSTD_decompressSequencesLong(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesLong(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset); #endif + } #ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG /* else */ if (dctx->litBufferLocation == ZSTD_split) - return ZSTD_decompressSequencesSplitLitBuffer(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequencesSplitLitBuffer(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset); else - return ZSTD_decompressSequences(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset, frame); + return ZSTD_decompressSequences(dctx, dst, dstCapacity, ip, srcSize, nbSeq, isLongOffset); #endif } } +ZSTD_ALLOW_POINTER_OVERFLOW_ATTR void ZSTD_checkContinuity(ZSTD_DCtx* dctx, const void* dst, size_t dstSize) { if (dst != dctx->previousDstEnd && dstSize > 0) { /* not contiguous */ @@ -2060,13 +2193,24 @@ void ZSTD_checkContinuity(ZSTD_DCtx* dctx, const void* dst, size_t dstSize) } -size_t ZSTD_decompressBlock(ZSTD_DCtx* dctx, - void* dst, size_t dstCapacity, - const void* src, size_t srcSize) +size_t ZSTD_decompressBlock_deprecated(ZSTD_DCtx* dctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize) { size_t dSize; + dctx->isFrameDecompression = 0; ZSTD_checkContinuity(dctx, dst, dstCapacity); - dSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, /* frame */ 0, not_streaming); + dSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, not_streaming); + FORWARD_IF_ERROR(dSize, ""); dctx->previousDstEnd = (char*)dst + dSize; return dSize; } + + +/* NOTE: Must just wrap ZSTD_decompressBlock_deprecated() */ +size_t ZSTD_decompressBlock(ZSTD_DCtx* dctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize) +{ + return ZSTD_decompressBlock_deprecated(dctx, dst, dstCapacity, src, srcSize); +} diff --git a/lib/zstd/decompress/zstd_decompress_block.h b/lib/zstd/decompress/zstd_decompress_block.h index 3d2d57a5d25a..becffbd89364 100644 --- a/lib/zstd/decompress/zstd_decompress_block.h +++ b/lib/zstd/decompress/zstd_decompress_block.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -47,7 +48,7 @@ typedef enum { */ size_t ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, - const void* src, size_t srcSize, const int frame, const streaming_operation streaming); + const void* src, size_t srcSize, const streaming_operation streaming); /* ZSTD_buildFSETable() : * generate FSE decoding table for one symbol (ll, ml or off) @@ -64,5 +65,10 @@ void ZSTD_buildFSETable(ZSTD_seqSymbol* dt, unsigned tableLog, void* wksp, size_t wkspSize, int bmi2); +/* Internal definition of ZSTD_decompressBlock() to avoid deprecation warnings. */ +size_t ZSTD_decompressBlock_deprecated(ZSTD_DCtx* dctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize); + #endif /* ZSTD_DEC_BLOCK_H */ diff --git a/lib/zstd/decompress/zstd_decompress_internal.h b/lib/zstd/decompress/zstd_decompress_internal.h index 98102edb6a83..0f02526be774 100644 --- a/lib/zstd/decompress/zstd_decompress_internal.h +++ b/lib/zstd/decompress/zstd_decompress_internal.h @@ -1,5 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Yann Collet, Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -75,12 +76,13 @@ static UNUSED_ATTR const U32 ML_base[MaxML+1] = { #define ZSTD_BUILD_FSE_TABLE_WKSP_SIZE (sizeof(S16) * (MaxSeq + 1) + (1u << MaxFSELog) + sizeof(U64)) #define ZSTD_BUILD_FSE_TABLE_WKSP_SIZE_U32 ((ZSTD_BUILD_FSE_TABLE_WKSP_SIZE + sizeof(U32) - 1) / sizeof(U32)) +#define ZSTD_HUFFDTABLE_CAPACITY_LOG 12 typedef struct { ZSTD_seqSymbol LLTable[SEQSYMBOL_TABLE_SIZE(LLFSELog)]; /* Note : Space reserved for FSE Tables */ ZSTD_seqSymbol OFTable[SEQSYMBOL_TABLE_SIZE(OffFSELog)]; /* is also used as temporary workspace while building hufTable during DDict creation */ ZSTD_seqSymbol MLTable[SEQSYMBOL_TABLE_SIZE(MLFSELog)]; /* and therefore must be at least HUF_DECOMPRESS_WORKSPACE_SIZE large */ - HUF_DTable hufTable[HUF_DTABLE_SIZE(HufLog)]; /* can accommodate HUF_decompress4X */ + HUF_DTable hufTable[HUF_DTABLE_SIZE(ZSTD_HUFFDTABLE_CAPACITY_LOG)]; /* can accommodate HUF_decompress4X */ U32 rep[ZSTD_REP_NUM]; U32 workspace[ZSTD_BUILD_FSE_TABLE_WKSP_SIZE_U32]; } ZSTD_entropyDTables_t; @@ -152,6 +154,7 @@ struct ZSTD_DCtx_s size_t litSize; size_t rleSize; size_t staticSize; + int isFrameDecompression; #if DYNAMIC_BMI2 != 0 int bmi2; /* == 1 if the CPU supports BMI2 and 0 otherwise. CPU support is determined dynamically once per context lifetime. */ #endif @@ -164,6 +167,8 @@ struct ZSTD_DCtx_s ZSTD_dictUses_e dictUses; ZSTD_DDictHashSet* ddictSet; /* Hash set for multiple ddicts */ ZSTD_refMultipleDDicts_e refMultipleDDicts; /* User specified: if == 1, will allow references to multiple DDicts. Default == 0 (disabled) */ + int disableHufAsm; + int maxBlockSizeParam; /* streaming */ ZSTD_dStreamStage streamStage; diff --git a/lib/zstd/decompress_sources.h b/lib/zstd/decompress_sources.h index a06ca187aab5..8a47eb2a4514 100644 --- a/lib/zstd/decompress_sources.h +++ b/lib/zstd/decompress_sources.h @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */ /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/zstd_common_module.c b/lib/zstd/zstd_common_module.c index 22686e367e6f..466828e35752 100644 --- a/lib/zstd/zstd_common_module.c +++ b/lib/zstd/zstd_common_module.c @@ -1,6 +1,6 @@ // SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -24,9 +24,6 @@ EXPORT_SYMBOL_GPL(HUF_readStats_wksp); EXPORT_SYMBOL_GPL(ZSTD_isError); EXPORT_SYMBOL_GPL(ZSTD_getErrorName); EXPORT_SYMBOL_GPL(ZSTD_getErrorCode); -EXPORT_SYMBOL_GPL(ZSTD_customMalloc); -EXPORT_SYMBOL_GPL(ZSTD_customCalloc); -EXPORT_SYMBOL_GPL(ZSTD_customFree); MODULE_LICENSE("Dual BSD/GPL"); MODULE_DESCRIPTION("Zstd Common"); diff --git a/lib/zstd/zstd_compress_module.c b/lib/zstd/zstd_compress_module.c index bd8784449b31..ceaf352d03e2 100644 --- a/lib/zstd/zstd_compress_module.c +++ b/lib/zstd/zstd_compress_module.c @@ -1,6 +1,6 @@ // SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the diff --git a/lib/zstd/zstd_decompress_module.c b/lib/zstd/zstd_decompress_module.c index 469fc3059be0..0ae819f0c927 100644 --- a/lib/zstd/zstd_decompress_module.c +++ b/lib/zstd/zstd_decompress_module.c @@ -1,6 +1,6 @@ // SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause /* - * Copyright (c) Facebook, Inc. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -113,7 +113,7 @@ EXPORT_SYMBOL(zstd_init_dstream); size_t zstd_reset_dstream(zstd_dstream *dstream) { - return ZSTD_resetDStream(dstream); + return ZSTD_DCtx_reset(dstream, ZSTD_reset_session_only); } EXPORT_SYMBOL(zstd_reset_dstream); -- 2.47.0