From f5f3359a1c4e7721538fbb5c7dfbe36550e5ae2b Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Thu, 8 Dec 2022 13:03:50 +0100 Subject: [PATCH 01/20] cachy Signed-off-by: Peter Jung --- .gitignore | 1 + .../admin-guide/kernel-parameters.txt | 11 +- Documentation/mm/index.rst | 1 + Documentation/mm/zblock.rst | 31 + MAINTAINERS | 7 + Makefile | 137 +++- arch/arc/configs/axs101_defconfig | 1 + arch/arc/configs/axs103_defconfig | 1 + arch/arc/configs/axs103_smp_defconfig | 1 + arch/arc/configs/haps_hs_defconfig | 1 + arch/arc/configs/haps_hs_smp_defconfig | 1 + arch/arc/configs/hsdk_defconfig | 1 + arch/arc/configs/nsim_700_defconfig | 1 + arch/arc/configs/nsimosci_defconfig | 1 + arch/arc/configs/nsimosci_hs_defconfig | 1 + arch/arc/configs/nsimosci_hs_smp_defconfig | 1 + arch/arc/configs/tb10x_defconfig | 1 + arch/arc/configs/vdk_hs38_defconfig | 1 + arch/arc/configs/vdk_hs38_smp_defconfig | 1 + arch/arm/Makefile | 43 -- arch/x86/Kconfig.cpu | 332 ++++++++- arch/x86/Makefile | 11 +- arch/x86/Makefile.postlink | 41 ++ arch/x86/Makefile_32.cpu | 41 -- arch/x86/boot/compressed/Makefile | 10 +- arch/x86/include/asm/vermagic.h | 66 ++ block/bfq-iosched.c | 6 + block/elevator.c | 7 +- drivers/i2c/busses/Kconfig | 9 + drivers/i2c/busses/Makefile | 1 + drivers/i2c/busses/i2c-nct6775.c | 647 ++++++++++++++++++ drivers/i2c/busses/i2c-piix4.c | 4 +- drivers/md/dm-crypt.c | 5 + drivers/pci/quirks.c | 101 +++ include/linux/pagemap.h | 2 +- include/linux/user_namespace.h | 4 + init/Kconfig | 39 ++ kernel/Kconfig.hz | 24 + kernel/fork.c | 14 + kernel/module/Kconfig | 25 + kernel/rcu/Kconfig | 4 +- kernel/rcu/rcutorture.c | 2 +- kernel/rcu/tree.c | 6 +- kernel/rcu/tree_nocb.h | 4 +- kernel/rcu/tree_plugin.h | 4 +- kernel/sched/core.c | 20 +- kernel/sched/fair.c | 28 +- kernel/sched/pelt.c | 60 ++ kernel/sched/pelt.h | 42 +- kernel/sched/sched.h | 1 + kernel/sysctl.c | 12 + kernel/user_namespace.c | 7 + lib/Kconfig.debug | 29 +- lib/string.c | 62 +- mm/Kconfig | 17 + mm/Makefile | 1 + mm/compaction.c | 6 +- mm/page-writeback.c | 8 + mm/swap.c | 5 + mm/vmpressure.c | 4 + mm/vmscan.c | 4 + mm/zblock.c | 642 +++++++++++++++++ scripts/Makefile.debug | 6 +- scripts/Makefile.lib | 13 +- scripts/Makefile.modinst | 7 +- 65 files changed, 2457 insertions(+), 170 deletions(-) create mode 100644 Documentation/mm/zblock.rst create mode 100644 arch/x86/Makefile.postlink create mode 100644 drivers/i2c/busses/i2c-nct6775.c create mode 100644 mm/zblock.c diff --git a/.gitignore b/.gitignore index 265959544978..cd4ef88584ea 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,7 @@ *.o *.o.* *.patch +*.relocs *.s *.so *.so.dbg diff --git a/Documentation/admin-guide/kernel-parameters.txt b/Documentation/admin-guide/kernel-parameters.txt index 2bc11a61c4d0..74053d8b1287 100644 --- a/Documentation/admin-guide/kernel-parameters.txt +++ b/Documentation/admin-guide/kernel-parameters.txt @@ -4128,6 +4128,15 @@ nomsi [MSI] If the PCI_MSI kernel config parameter is enabled, this kernel boot option can be used to disable the use of MSI interrupts system-wide. + pcie_acs_override = + [PCIE] Override missing PCIe ACS support for: + downstream + All downstream ports - full ACS capabilities + multfunction + All multifunction devices - multifunction ACS subset + id:nnnn:nnnn + Specfic device - full ACS capabilities + Specified as vid:did (vendor/device ID) in hex noioapicquirk [APIC] Disable all boot interrupt quirks. Safety option to keep boot IRQs enabled. This should never be necessary. @@ -4703,7 +4712,7 @@ overwritten. rcutree.kthread_prio= [KNL,BOOT] - Set the SCHED_FIFO priority of the RCU per-CPU + Set the SCHED_RR priority of the RCU per-CPU kthreads (rcuc/N). This value is also used for the priority of the RCU boost threads (rcub/N) and for the RCU grace-period kthreads (rcu_bh, diff --git a/Documentation/mm/index.rst b/Documentation/mm/index.rst index 575ccd40e30c..25561b95780f 100644 --- a/Documentation/mm/index.rst +++ b/Documentation/mm/index.rst @@ -65,4 +65,5 @@ above structured documentation, or deleted if it has served its purpose. vmalloced-kernel-stacks vmemmap_dedup z3fold + zblock zsmalloc diff --git a/Documentation/mm/zblock.rst b/Documentation/mm/zblock.rst new file mode 100644 index 000000000000..fa4f8f24a5fd --- /dev/null +++ b/Documentation/mm/zblock.rst @@ -0,0 +1,31 @@ +.. SPDX-License-Identifier: GPL-2.0 + +.. _zblock: + +====== +zblock +====== + +Zblock stores integer number of compressed objects per block. These blocks +consist of several consecutive physical pages (from 1 to 8) and are arranged +in lists. The range from 0 to PAGE_SIZE is divided into the number of intervals +corresponding to the number of lists and these only operate on objects of size +from its interval. Thus the block lists are isolated from each other, which +makes it possible to simultaneously perform actions with several objects +from different lists. + +With zlock, it is possible to densely arrange objects of various sizes, +resulting in low internal fragmentation. Also this allocator tries to fill +incomplete blocks instead of adding new ones. As a result, in many cases it +provides a compression ratio substantially higher than z3fold and zbud. Zblock +does not require MMU and also is superior to zsmalloc with regard to the worst +execution times, thus allowing for better response time and real-time +characteristics of the whole system. + +Like similar allocation method from z3fold and zsmalloc, zblock_alloc() does +not return a dereferenceable pointer. Instead, it returns an unsigned long +handle which encodes actual location of the allocated object. + +Unlike zbud and z3fold, zblock works well with objects of various sizes - +including but not limited to highly compressed and poorly compressed, as well +as cases where both object types exist. diff --git a/MAINTAINERS b/MAINTAINERS index 72b9654f764c..e3f4a32f28e4 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -22471,6 +22471,13 @@ L: linux-mm@kvack.org S: Maintained F: mm/z3fold.c +ZBLOCK COMPRESSED PAGE ALLOCATOR +M: Ananda Badmaev +M: Vitaly Wool +L: linux-mm@kvack.org +S: Maintained +F: mm/zblock.c + ZD1211RW WIRELESS DRIVER M: Ulrich Kunitz L: linux-wireless@vger.kernel.org diff --git a/Makefile b/Makefile index 46c6eb57b354..5d4cf9796f91 100644 --- a/Makefile +++ b/Makefile @@ -756,8 +756,140 @@ KBUILD_CFLAGS += $(call cc-disable-warning, format-truncation) KBUILD_CFLAGS += $(call cc-disable-warning, format-overflow) KBUILD_CFLAGS += $(call cc-disable-warning, address-of-packed-member) +# This selects which ARM instruction set is used. +# Note that GCC does not numerically define an architecture version +# macro, but instead defines a whole series of macros which makes +# testing for a specific architecture or later rather impossible. +arch-$(CONFIG_CPU_32v7M) =-D__LINUX_ARM_ARCH__=7 -march=armv7-m +arch-$(CONFIG_CPU_32v7) =-D__LINUX_ARM_ARCH__=7 -march=armv7-a +arch-$(CONFIG_CPU_32v6) =-D__LINUX_ARM_ARCH__=6 -march=armv6 +# Only override the compiler option if ARMv6. The ARMv6K extensions are +# always available in ARMv7 +ifeq ($(CONFIG_CPU_32v6),y) +arch-$(CONFIG_CPU_32v6K) =-D__LINUX_ARM_ARCH__=6 -march=armv6k +endif +arch-$(CONFIG_CPU_32v5) =-D__LINUX_ARM_ARCH__=5 -march=armv5te +arch-$(CONFIG_CPU_32v4T) =-D__LINUX_ARM_ARCH__=4 -march=armv4t +arch-$(CONFIG_CPU_32v4) =-D__LINUX_ARM_ARCH__=4 -march=armv4 +arch-$(CONFIG_CPU_32v3) =-D__LINUX_ARM_ARCH__=3 -march=armv3m + +# Evaluate arch cc-option calls now +arch-y := $(arch-y) + +# This selects how we optimise for the processor. +tune-$(CONFIG_CPU_ARM7TDMI) =-mtune=arm7tdmi +tune-$(CONFIG_CPU_ARM720T) =-mtune=arm7tdmi +tune-$(CONFIG_CPU_ARM740T) =-mtune=arm7tdmi +tune-$(CONFIG_CPU_ARM9TDMI) =-mtune=arm9tdmi +tune-$(CONFIG_CPU_ARM940T) =-mtune=arm9tdmi +tune-$(CONFIG_CPU_ARM946E) =-mtune=arm9e +tune-$(CONFIG_CPU_ARM920T) =-mtune=arm9tdmi +tune-$(CONFIG_CPU_ARM922T) =-mtune=arm9tdmi +tune-$(CONFIG_CPU_ARM925T) =-mtune=arm9tdmi +tune-$(CONFIG_CPU_ARM926T) =-mtune=arm9tdmi +tune-$(CONFIG_CPU_FA526) =-mtune=arm9tdmi +tune-$(CONFIG_CPU_SA110) =-mtune=strongarm110 +tune-$(CONFIG_CPU_SA1100) =-mtune=strongarm1100 +tune-$(CONFIG_CPU_XSCALE) =-mtune=xscale +tune-$(CONFIG_CPU_XSC3) =-mtune=xscale +tune-$(CONFIG_CPU_FEROCEON) =-mtune=xscale +tune-$(CONFIG_CPU_V6) =-mtune=arm1136j-s +tune-$(CONFIG_CPU_V6K) =-mtune=arm1136j-s + +# Evaluate tune cc-option calls now +tune-y := $(tune-y) + +# This selects which x86 instruction set is used. +cflags-$(CONFIG_M486SX) += -march=i486 +cflags-$(CONFIG_M486) += -march=i486 +cflags-$(CONFIG_M586) += -march=i586 +cflags-$(CONFIG_M586TSC) += -march=i586 +cflags-$(CONFIG_M586MMX) += -march=pentium-mmx +cflags-$(CONFIG_M686) += -march=i686 +cflags-$(CONFIG_MPENTIUMII) += -march=i686 $(call tune,pentium2) +cflags-$(CONFIG_MPENTIUMIII) += -march=i686 $(call tune,pentium3) +cflags-$(CONFIG_MPENTIUMM) += -march=i686 $(call tune,pentium3) +cflags-$(CONFIG_MPENTIUM4) += -march=i686 $(call tune,pentium4) +cflags-$(CONFIG_MK6) += -march=k6 +# Please note, that patches that add -march=athlon-xp and friends are pointless. +# They make zero difference whatsosever to performance at this time. +cflags-$(CONFIG_MK7) += -march=athlon +cflags-$(CONFIG_MK8) += $(call cc-option,-march=k8,-march=athlon) +cflags-$(CONFIG_MCRUSOE) += -march=i686 $(align) +cflags-$(CONFIG_MEFFICEON) += -march=i686 $(call tune,pentium3) $(align) +cflags-$(CONFIG_MWINCHIPC6) += $(call cc-option,-march=winchip-c6,-march=i586) +cflags-$(CONFIG_MWINCHIP3D) += $(call cc-option,-march=winchip2,-march=i586) +cflags-$(CONFIG_MCYRIXIII) += $(call cc-option,-march=c3,-march=i486) $(align) +cflags-$(CONFIG_MVIAC3_2) += $(call cc-option,-march=c3-2,-march=i686) +cflags-$(CONFIG_MVIAC7) += -march=i686 +cflags-$(CONFIG_MCORE2) += -march=i686 $(call tune,core2) +cflags-$(CONFIG_MATOM) += $(call cc-option,-march=atom,$(call cc-option,-march=core2,-march=i686)) \ +$(call cc-option,-mtune=atom,$(call cc-option,-mtune=generic)) + +# AMD Elan support +cflags-$(CONFIG_MELAN) += -march=i486 + +# Geode GX1 support +cflags-$(CONFIG_MGEODEGX1) += -march=pentium-mmx +cflags-$(CONFIG_MGEODE_LX) += $(call cc-option,-march=geode,-march=pentium-mmx) +# add at the end to overwrite eventual tuning options from earlier +# cpu entries +cflags-$(CONFIG_X86_GENERIC) += $(call tune,generic,$(call tune,i686)) + +# Bug fix for binutils: this option is required in order to keep +# binutils from generating NOPL instructions against our will. +ifneq ($(CONFIG_X86_P6_NOP),y) +cflags-y += $(call cc-option,-Wa$(comma)-mtune=generic32,) +endif + +# x86_64 instruction set +cflags64-$(CONFIG_MK8) += -march=k8 +cflags64-$(CONFIG_MPSC) += -march=nocona +cflags64-$(CONFIG_MK8SSE3) += -march=k8-sse3 +cflags64-$(CONFIG_MK10) += -march=amdfam10 +cflags64-$(CONFIG_MBARCELONA) += -march=barcelona +cflags64-$(CONFIG_MBOBCAT) += -march=btver1 +cflags64-$(CONFIG_MJAGUAR) += -march=btver2 +cflags64-$(CONFIG_MBULLDOZER) += -march=bdver1 +cflags64-$(CONFIG_MPILEDRIVER) += -march=bdver2 -mno-tbm +cflags64-$(CONFIG_MSTEAMROLLER) += -march=bdver3 -mno-tbm +cflags64-$(CONFIG_MEXCAVATOR) += -march=bdver4 -mno-tbm +cflags64-$(CONFIG_MZEN) += -march=znver1 +cflags64-$(CONFIG_MZEN2) += -march=znver2 +cflags64-$(CONFIG_MZEN3) += -march=znver3 +cflags64-$(CONFIG_MNATIVE_INTEL) += -march=native +cflags64-$(CONFIG_MNATIVE_AMD) += -march=native +cflags64-$(CONFIG_MATOM) += -march=bonnell +cflags64-$(CONFIG_MCORE2) += -march=core2 +cflags64-$(CONFIG_MNEHALEM) += -march=nehalem +cflags64-$(CONFIG_MWESTMERE) += -march=westmere +cflags64-$(CONFIG_MSILVERMONT) += -march=silvermont +cflags64-$(CONFIG_MGOLDMONT) += -march=goldmont +cflags64-$(CONFIG_MGOLDMONTPLUS) += -march=goldmont-plus +cflags64-$(CONFIG_MSANDYBRIDGE) += -march=sandybridge +cflags64-$(CONFIG_MIVYBRIDGE) += -march=ivybridge +cflags64-$(CONFIG_MHASWELL) += -march=haswell +cflags64-$(CONFIG_MBROADWELL) += -march=broadwell +cflags64-$(CONFIG_MSKYLAKE) += -march=skylake +cflags64-$(CONFIG_MSKYLAKEX) += -march=skylake-avx512 +cflags64-$(CONFIG_MCANNONLAKE) += -march=cannonlake +cflags64-$(CONFIG_MICELAKE) += -march=icelake-client +cflags64-$(CONFIG_MCASCADELAKE) += -march=cascadelake +cflags64-$(CONFIG_MCOOPERLAKE) += -march=cooperlake +cflags64-$(CONFIG_MTIGERLAKE) += -march=tigerlake +cflags64-$(CONFIG_MSAPPHIRERAPIDS) += -march=sapphirerapids +cflags64-$(CONFIG_MROCKETLAKE) += -march=rocketlake +cflags64-$(CONFIG_MALDERLAKE) += -march=alderlake +cflags64-$(CONFIG_GENERIC_CPU2) += -march=x86-64-v2 +cflags64-$(CONFIG_GENERIC_CPU3) += -march=x86-64-v3 +cflags64-$(CONFIG_GENERIC_CPU4) += -march=x86-64-v4 +cflags64-$(CONFIG_GENERIC_CPU) += -mtune=generic +KBUILD_CFLAGS += $(cflags64-y) + ifdef CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE KBUILD_CFLAGS += -O2 +else ifdef CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3 +KBUILD_CFLAGS += -O3 else ifdef CONFIG_CC_OPTIMIZE_FOR_SIZE KBUILD_CFLAGS += -Os endif @@ -993,11 +1125,6 @@ KBUILD_CFLAGS += -fno-strict-overflow # Make sure -fstack-check isn't enabled (like gentoo apparently did) KBUILD_CFLAGS += -fno-stack-check -# conserve stack if available -ifdef CONFIG_CC_IS_GCC -KBUILD_CFLAGS += -fconserve-stack -endif - # Prohibit date/time macros, which would make the build non-deterministic KBUILD_CFLAGS += -Werror=date-time diff --git a/arch/arc/configs/axs101_defconfig b/arch/arc/configs/axs101_defconfig index e31a8ebc3ecc..0016149f9583 100644 --- a/arch/arc/configs/axs101_defconfig +++ b/arch/arc/configs/axs101_defconfig @@ -9,6 +9,7 @@ CONFIG_NAMESPACES=y # CONFIG_UTS_NS is not set # CONFIG_PID_NS is not set CONFIG_BLK_DEV_INITRD=y +CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3=y CONFIG_EMBEDDED=y CONFIG_PERF_EVENTS=y # CONFIG_VM_EVENT_COUNTERS is not set diff --git a/arch/arc/configs/axs103_defconfig b/arch/arc/configs/axs103_defconfig index e0e8567f0d75..5b031582a1cf 100644 --- a/arch/arc/configs/axs103_defconfig +++ b/arch/arc/configs/axs103_defconfig @@ -9,6 +9,7 @@ CONFIG_NAMESPACES=y # CONFIG_UTS_NS is not set # CONFIG_PID_NS is not set CONFIG_BLK_DEV_INITRD=y +CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3=y CONFIG_EMBEDDED=y CONFIG_PERF_EVENTS=y # CONFIG_VM_EVENT_COUNTERS is not set diff --git a/arch/arc/configs/axs103_smp_defconfig b/arch/arc/configs/axs103_smp_defconfig index fcbc952bc75b..d4eec39e0112 100644 --- a/arch/arc/configs/axs103_smp_defconfig +++ b/arch/arc/configs/axs103_smp_defconfig @@ -9,6 +9,7 @@ CONFIG_NAMESPACES=y # CONFIG_UTS_NS is not set # CONFIG_PID_NS is not set CONFIG_BLK_DEV_INITRD=y +CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3=y CONFIG_EMBEDDED=y CONFIG_PERF_EVENTS=y # CONFIG_VM_EVENT_COUNTERS is not set diff --git a/arch/arc/configs/haps_hs_defconfig b/arch/arc/configs/haps_hs_defconfig index d87ad7e88d62..7337cdf4ffdd 100644 --- a/arch/arc/configs/haps_hs_defconfig +++ b/arch/arc/configs/haps_hs_defconfig @@ -11,6 +11,7 @@ CONFIG_NAMESPACES=y # CONFIG_UTS_NS is not set # CONFIG_PID_NS is not set CONFIG_BLK_DEV_INITRD=y +CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3=y CONFIG_EXPERT=y CONFIG_PERF_EVENTS=y # CONFIG_COMPAT_BRK is not set diff --git a/arch/arc/configs/haps_hs_smp_defconfig b/arch/arc/configs/haps_hs_smp_defconfig index 8d82cdb7f86a..bc927221afc0 100644 --- a/arch/arc/configs/haps_hs_smp_defconfig +++ b/arch/arc/configs/haps_hs_smp_defconfig @@ -11,6 +11,7 @@ CONFIG_NAMESPACES=y # CONFIG_UTS_NS is not set # CONFIG_PID_NS is not set CONFIG_BLK_DEV_INITRD=y +CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3=y CONFIG_EMBEDDED=y CONFIG_PERF_EVENTS=y # CONFIG_VM_EVENT_COUNTERS is not set diff --git a/arch/arc/configs/hsdk_defconfig b/arch/arc/configs/hsdk_defconfig index f856b03e0fb5..aa000075a575 100644 --- a/arch/arc/configs/hsdk_defconfig +++ b/arch/arc/configs/hsdk_defconfig @@ -9,6 +9,7 @@ CONFIG_NAMESPACES=y # CONFIG_PID_NS is not set CONFIG_BLK_DEV_INITRD=y CONFIG_BLK_DEV_RAM=y +CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3=y CONFIG_EMBEDDED=y CONFIG_PERF_EVENTS=y # CONFIG_VM_EVENT_COUNTERS is not set diff --git a/arch/arc/configs/nsim_700_defconfig b/arch/arc/configs/nsim_700_defconfig index a1ce12bf5b16..326f6cde7826 100644 --- a/arch/arc/configs/nsim_700_defconfig +++ b/arch/arc/configs/nsim_700_defconfig @@ -11,6 +11,7 @@ CONFIG_NAMESPACES=y # CONFIG_UTS_NS is not set # CONFIG_PID_NS is not set CONFIG_BLK_DEV_INITRD=y +CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3=y CONFIG_KALLSYMS_ALL=y CONFIG_EMBEDDED=y CONFIG_PERF_EVENTS=y diff --git a/arch/arc/configs/nsimosci_defconfig b/arch/arc/configs/nsimosci_defconfig index ca10f4a2c823..bf39a0091679 100644 --- a/arch/arc/configs/nsimosci_defconfig +++ b/arch/arc/configs/nsimosci_defconfig @@ -10,6 +10,7 @@ CONFIG_NAMESPACES=y # CONFIG_UTS_NS is not set # CONFIG_PID_NS is not set CONFIG_BLK_DEV_INITRD=y +CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3=y CONFIG_KALLSYMS_ALL=y CONFIG_EMBEDDED=y CONFIG_PERF_EVENTS=y diff --git a/arch/arc/configs/nsimosci_hs_defconfig b/arch/arc/configs/nsimosci_hs_defconfig index 31b6ec3683c6..7121bd71c543 100644 --- a/arch/arc/configs/nsimosci_hs_defconfig +++ b/arch/arc/configs/nsimosci_hs_defconfig @@ -10,6 +10,7 @@ CONFIG_NAMESPACES=y # CONFIG_UTS_NS is not set # CONFIG_PID_NS is not set CONFIG_BLK_DEV_INITRD=y +CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3=y CONFIG_KALLSYMS_ALL=y CONFIG_EMBEDDED=y CONFIG_PERF_EVENTS=y diff --git a/arch/arc/configs/nsimosci_hs_smp_defconfig b/arch/arc/configs/nsimosci_hs_smp_defconfig index 41a0037f48a5..f9863b294a70 100644 --- a/arch/arc/configs/nsimosci_hs_smp_defconfig +++ b/arch/arc/configs/nsimosci_hs_smp_defconfig @@ -8,6 +8,7 @@ CONFIG_IKCONFIG_PROC=y # CONFIG_UTS_NS is not set # CONFIG_PID_NS is not set CONFIG_BLK_DEV_INITRD=y +CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3=y CONFIG_PERF_EVENTS=y # CONFIG_COMPAT_BRK is not set CONFIG_KPROBES=y diff --git a/arch/arc/configs/tb10x_defconfig b/arch/arc/configs/tb10x_defconfig index d93b65008d4a..a12656ec0072 100644 --- a/arch/arc/configs/tb10x_defconfig +++ b/arch/arc/configs/tb10x_defconfig @@ -14,6 +14,7 @@ CONFIG_INITRAMFS_SOURCE="../tb10x-rootfs.cpio" CONFIG_INITRAMFS_ROOT_UID=2100 CONFIG_INITRAMFS_ROOT_GID=501 # CONFIG_RD_GZIP is not set +CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3=y CONFIG_KALLSYMS_ALL=y # CONFIG_AIO is not set CONFIG_EMBEDDED=y diff --git a/arch/arc/configs/vdk_hs38_defconfig b/arch/arc/configs/vdk_hs38_defconfig index 0c3b21416819..d7c858df520c 100644 --- a/arch/arc/configs/vdk_hs38_defconfig +++ b/arch/arc/configs/vdk_hs38_defconfig @@ -4,6 +4,7 @@ CONFIG_HIGH_RES_TIMERS=y CONFIG_IKCONFIG=y CONFIG_IKCONFIG_PROC=y CONFIG_BLK_DEV_INITRD=y +CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3=y CONFIG_EMBEDDED=y CONFIG_PERF_EVENTS=y # CONFIG_VM_EVENT_COUNTERS is not set diff --git a/arch/arc/configs/vdk_hs38_smp_defconfig b/arch/arc/configs/vdk_hs38_smp_defconfig index f9ad9d3ee702..015c1d43889e 100644 --- a/arch/arc/configs/vdk_hs38_smp_defconfig +++ b/arch/arc/configs/vdk_hs38_smp_defconfig @@ -4,6 +4,7 @@ CONFIG_HIGH_RES_TIMERS=y CONFIG_IKCONFIG=y CONFIG_IKCONFIG_PROC=y CONFIG_BLK_DEV_INITRD=y +CONFIG_CC_OPTIMIZE_FOR_PERFORMANCE_O3=y CONFIG_EMBEDDED=y CONFIG_PERF_EVENTS=y # CONFIG_VM_EVENT_COUNTERS is not set diff --git a/arch/arm/Makefile b/arch/arm/Makefile index 56f655deebb1..5b8ad4df790e 100644 --- a/arch/arm/Makefile +++ b/arch/arm/Makefile @@ -56,49 +56,6 @@ endif # KBUILD_CFLAGS += $(call cc-option,-fno-ipa-sra) -# This selects which instruction set is used. -# Note that GCC does not numerically define an architecture version -# macro, but instead defines a whole series of macros which makes -# testing for a specific architecture or later rather impossible. -arch-$(CONFIG_CPU_32v7M) =-D__LINUX_ARM_ARCH__=7 -march=armv7-m -arch-$(CONFIG_CPU_32v7) =-D__LINUX_ARM_ARCH__=7 -march=armv7-a -arch-$(CONFIG_CPU_32v6) =-D__LINUX_ARM_ARCH__=6 -march=armv6 -# Only override the compiler option if ARMv6. The ARMv6K extensions are -# always available in ARMv7 -ifeq ($(CONFIG_CPU_32v6),y) -arch-$(CONFIG_CPU_32v6K) =-D__LINUX_ARM_ARCH__=6 -march=armv6k -endif -arch-$(CONFIG_CPU_32v5) =-D__LINUX_ARM_ARCH__=5 -march=armv5te -arch-$(CONFIG_CPU_32v4T) =-D__LINUX_ARM_ARCH__=4 -march=armv4t -arch-$(CONFIG_CPU_32v4) =-D__LINUX_ARM_ARCH__=4 -march=armv4 -arch-$(CONFIG_CPU_32v3) =-D__LINUX_ARM_ARCH__=3 -march=armv3m - -# Evaluate arch cc-option calls now -arch-y := $(arch-y) - -# This selects how we optimise for the processor. -tune-$(CONFIG_CPU_ARM7TDMI) =-mtune=arm7tdmi -tune-$(CONFIG_CPU_ARM720T) =-mtune=arm7tdmi -tune-$(CONFIG_CPU_ARM740T) =-mtune=arm7tdmi -tune-$(CONFIG_CPU_ARM9TDMI) =-mtune=arm9tdmi -tune-$(CONFIG_CPU_ARM940T) =-mtune=arm9tdmi -tune-$(CONFIG_CPU_ARM946E) =-mtune=arm9e -tune-$(CONFIG_CPU_ARM920T) =-mtune=arm9tdmi -tune-$(CONFIG_CPU_ARM922T) =-mtune=arm9tdmi -tune-$(CONFIG_CPU_ARM925T) =-mtune=arm9tdmi -tune-$(CONFIG_CPU_ARM926T) =-mtune=arm9tdmi -tune-$(CONFIG_CPU_FA526) =-mtune=arm9tdmi -tune-$(CONFIG_CPU_SA110) =-mtune=strongarm110 -tune-$(CONFIG_CPU_SA1100) =-mtune=strongarm1100 -tune-$(CONFIG_CPU_XSCALE) =-mtune=xscale -tune-$(CONFIG_CPU_XSC3) =-mtune=xscale -tune-$(CONFIG_CPU_FEROCEON) =-mtune=xscale -tune-$(CONFIG_CPU_V6) =-mtune=arm1136j-s -tune-$(CONFIG_CPU_V6K) =-mtune=arm1136j-s - -# Evaluate tune cc-option calls now -tune-y := $(tune-y) - ifeq ($(CONFIG_AEABI),y) CFLAGS_ABI :=-mabi=aapcs-linux -mfpu=vfp else diff --git a/arch/x86/Kconfig.cpu b/arch/x86/Kconfig.cpu index 542377cd419d..22b919cdb6d1 100644 --- a/arch/x86/Kconfig.cpu +++ b/arch/x86/Kconfig.cpu @@ -157,7 +157,7 @@ config MPENTIUM4 config MK6 - bool "K6/K6-II/K6-III" + bool "AMD K6/K6-II/K6-III" depends on X86_32 help Select this for an AMD K6-family processor. Enables use of @@ -165,7 +165,7 @@ config MK6 flags to GCC. config MK7 - bool "Athlon/Duron/K7" + bool "AMD Athlon/Duron/K7" depends on X86_32 help Select this for an AMD Athlon K7-family processor. Enables use of @@ -173,12 +173,98 @@ config MK7 flags to GCC. config MK8 - bool "Opteron/Athlon64/Hammer/K8" + bool "AMD Opteron/Athlon64/Hammer/K8" help Select this for an AMD Opteron or Athlon64 Hammer-family processor. Enables use of some extended instructions, and passes appropriate optimization flags to GCC. +config MK8SSE3 + bool "AMD Opteron/Athlon64/Hammer/K8 with SSE3" + help + Select this for improved AMD Opteron or Athlon64 Hammer-family processors. + Enables use of some extended instructions, and passes appropriate + optimization flags to GCC. + +config MK10 + bool "AMD 61xx/7x50/PhenomX3/X4/II/K10" + help + Select this for an AMD 61xx Eight-Core Magny-Cours, Athlon X2 7x50, + Phenom X3/X4/II, Athlon II X2/X3/X4, or Turion II-family processor. + Enables use of some extended instructions, and passes appropriate + optimization flags to GCC. + +config MBARCELONA + bool "AMD Barcelona" + help + Select this for AMD Family 10h Barcelona processors. + + Enables -march=barcelona + +config MBOBCAT + bool "AMD Bobcat" + help + Select this for AMD Family 14h Bobcat processors. + + Enables -march=btver1 + +config MJAGUAR + bool "AMD Jaguar" + help + Select this for AMD Family 16h Jaguar processors. + + Enables -march=btver2 + +config MBULLDOZER + bool "AMD Bulldozer" + help + Select this for AMD Family 15h Bulldozer processors. + + Enables -march=bdver1 + +config MPILEDRIVER + bool "AMD Piledriver" + help + Select this for AMD Family 15h Piledriver processors. + + Enables -march=bdver2 + +config MSTEAMROLLER + bool "AMD Steamroller" + help + Select this for AMD Family 15h Steamroller processors. + + Enables -march=bdver3 + +config MEXCAVATOR + bool "AMD Excavator" + help + Select this for AMD Family 15h Excavator processors. + + Enables -march=bdver4 + +config MZEN + bool "AMD Zen" + help + Select this for AMD Family 17h Zen processors. + + Enables -march=znver1 + +config MZEN2 + bool "AMD Zen 2" + help + Select this for AMD Family 17h Zen 2 processors. + + Enables -march=znver2 + +config MZEN3 + bool "AMD Zen 3" + depends on (CC_IS_GCC && GCC_VERSION >= 100300) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + help + Select this for AMD Family 19h Zen 3 processors. + + Enables -march=znver3 + config MCRUSOE bool "Crusoe" depends on X86_32 @@ -270,7 +356,7 @@ config MPSC in /proc/cpuinfo. Family 15 is an older Xeon, Family 6 a newer one. config MCORE2 - bool "Core 2/newer Xeon" + bool "Intel Core 2" help Select this for Intel Core 2 and newer Core 2 Xeons (Xeon 51xx and @@ -278,6 +364,8 @@ config MCORE2 family in /proc/cpuinfo. Newer ones have 6 and older ones 15 (not a typo) + Enables -march=core2 + config MATOM bool "Intel Atom" help @@ -287,6 +375,182 @@ config MATOM accordingly optimized code. Use a recent GCC with specific Atom support in order to fully benefit from selecting this option. +config MNEHALEM + bool "Intel Nehalem" + select X86_P6_NOP + help + + Select this for 1st Gen Core processors in the Nehalem family. + + Enables -march=nehalem + +config MWESTMERE + bool "Intel Westmere" + select X86_P6_NOP + help + + Select this for the Intel Westmere formerly Nehalem-C family. + + Enables -march=westmere + +config MSILVERMONT + bool "Intel Silvermont" + select X86_P6_NOP + help + + Select this for the Intel Silvermont platform. + + Enables -march=silvermont + +config MGOLDMONT + bool "Intel Goldmont" + select X86_P6_NOP + help + + Select this for the Intel Goldmont platform including Apollo Lake and Denverton. + + Enables -march=goldmont + +config MGOLDMONTPLUS + bool "Intel Goldmont Plus" + select X86_P6_NOP + help + + Select this for the Intel Goldmont Plus platform including Gemini Lake. + + Enables -march=goldmont-plus + +config MSANDYBRIDGE + bool "Intel Sandy Bridge" + select X86_P6_NOP + help + + Select this for 2nd Gen Core processors in the Sandy Bridge family. + + Enables -march=sandybridge + +config MIVYBRIDGE + bool "Intel Ivy Bridge" + select X86_P6_NOP + help + + Select this for 3rd Gen Core processors in the Ivy Bridge family. + + Enables -march=ivybridge + +config MHASWELL + bool "Intel Haswell" + select X86_P6_NOP + help + + Select this for 4th Gen Core processors in the Haswell family. + + Enables -march=haswell + +config MBROADWELL + bool "Intel Broadwell" + select X86_P6_NOP + help + + Select this for 5th Gen Core processors in the Broadwell family. + + Enables -march=broadwell + +config MSKYLAKE + bool "Intel Skylake" + select X86_P6_NOP + help + + Select this for 6th Gen Core processors in the Skylake family. + + Enables -march=skylake + +config MSKYLAKEX + bool "Intel Skylake X" + select X86_P6_NOP + help + + Select this for 6th Gen Core processors in the Skylake X family. + + Enables -march=skylake-avx512 + +config MCANNONLAKE + bool "Intel Cannon Lake" + select X86_P6_NOP + help + + Select this for 8th Gen Core processors + + Enables -march=cannonlake + +config MICELAKE + bool "Intel Ice Lake" + select X86_P6_NOP + help + + Select this for 10th Gen Core processors in the Ice Lake family. + + Enables -march=icelake-client + +config MCASCADELAKE + bool "Intel Cascade Lake" + select X86_P6_NOP + help + + Select this for Xeon processors in the Cascade Lake family. + + Enables -march=cascadelake + +config MCOOPERLAKE + bool "Intel Cooper Lake" + depends on (CC_IS_GCC && GCC_VERSION > 100100) || (CC_IS_CLANG && CLANG_VERSION >= 100000) + select X86_P6_NOP + help + + Select this for Xeon processors in the Cooper Lake family. + + Enables -march=cooperlake + +config MTIGERLAKE + bool "Intel Tiger Lake" + depends on (CC_IS_GCC && GCC_VERSION > 100100) || (CC_IS_CLANG && CLANG_VERSION >= 100000) + select X86_P6_NOP + help + + Select this for third-generation 10 nm process processors in the Tiger Lake family. + + Enables -march=tigerlake + +config MSAPPHIRERAPIDS + bool "Intel Sapphire Rapids" + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + select X86_P6_NOP + help + + Select this for third-generation 10 nm process processors in the Sapphire Rapids family. + + Enables -march=sapphirerapids + +config MROCKETLAKE + bool "Intel Rocket Lake" + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + select X86_P6_NOP + help + + Select this for eleventh-generation processors in the Rocket Lake family. + + Enables -march=rocketlake + +config MALDERLAKE + bool "Intel Alder Lake" + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + select X86_P6_NOP + help + + Select this for twelfth-generation processors in the Alder Lake family. + + Enables -march=alderlake + config GENERIC_CPU bool "Generic-x86-64" depends on X86_64 @@ -294,6 +558,50 @@ config GENERIC_CPU Generic x86-64 CPU. Run equally well on all x86-64 CPUs. +config GENERIC_CPU2 + bool "Generic-x86-64-v2" + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + depends on X86_64 + help + Generic x86-64 CPU. + Run equally well on all x86-64 CPUs with min support of x86-64-v2. + +config GENERIC_CPU3 + bool "Generic-x86-64-v3" + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + depends on X86_64 + help + Generic x86-64-v3 CPU with v3 instructions. + Run equally well on all x86-64 CPUs with min support of x86-64-v3. + +config GENERIC_CPU4 + bool "Generic-x86-64-v4" + depends on (CC_IS_GCC && GCC_VERSION > 110000) || (CC_IS_CLANG && CLANG_VERSION >= 120000) + depends on X86_64 + help + Generic x86-64 CPU with v4 instructions. + Run equally well on all x86-64 CPUs with min support of x86-64-v4. + +config MNATIVE_INTEL + bool "Intel-Native optimizations autodetected by the compiler" + help + + Clang 3.8, GCC 4.2 and above support -march=native, which automatically detects + the optimum settings to use based on your processor. Do NOT use this + for AMD CPUs. Intel Only! + + Enables -march=native + +config MNATIVE_AMD + bool "AMD-Native optimizations autodetected by the compiler" + help + + Clang 3.8, GCC 4.2 and above support -march=native, which automatically detects + the optimum settings to use based on your processor. Do NOT use this + for Intel CPUs. AMD Only! + + Enables -march=native + endchoice config X86_GENERIC @@ -318,7 +626,7 @@ config X86_INTERNODE_CACHE_SHIFT config X86_L1_CACHE_SHIFT int default "7" if MPENTIUM4 || MPSC - default "6" if MK7 || MK8 || MPENTIUMM || MCORE2 || MATOM || MVIAC7 || X86_GENERIC || GENERIC_CPU + default "6" if MK7 || MK8 || MPENTIUMM || MCORE2 || MATOM || MVIAC7 || MK8SSE3 || MK10 || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MNATIVE_INTEL || MNATIVE_AMD || X86_GENERIC || GENERIC_CPU || GENERIC_CPU2 || GENERIC_CPU3 || GENERIC_CPU4 default "4" if MELAN || M486SX || M486 || MGEODEGX1 default "5" if MWINCHIP3D || MWINCHIPC6 || MCRUSOE || MEFFICEON || MCYRIXIII || MK6 || MPENTIUMIII || MPENTIUMII || M686 || M586MMX || M586TSC || M586 || MVIAC3_2 || MGEODE_LX @@ -336,11 +644,11 @@ config X86_ALIGNMENT_16 config X86_INTEL_USERCOPY def_bool y - depends on MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M586MMX || X86_GENERIC || MK8 || MK7 || MEFFICEON || MCORE2 + depends on MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M586MMX || X86_GENERIC || MK8 || MK7 || MEFFICEON || MCORE2 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MNATIVE_INTEL config X86_USE_PPRO_CHECKSUM def_bool y - depends on MWINCHIP3D || MWINCHIPC6 || MCYRIXIII || MK7 || MK6 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || MK8 || MVIAC3_2 || MVIAC7 || MEFFICEON || MGEODE_LX || MCORE2 || MATOM + depends on MWINCHIP3D || MWINCHIPC6 || MCYRIXIII || MK7 || MK6 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || MK8 || MVIAC3_2 || MVIAC7 || MEFFICEON || MGEODE_LX || MCORE2 || MATOM || MK8SSE3 || MK10 || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MNATIVE_INTEL || MNATIVE_AMD # # P6_NOPs are a relatively minor optimization that require a family >= @@ -356,26 +664,26 @@ config X86_USE_PPRO_CHECKSUM config X86_P6_NOP def_bool y depends on X86_64 - depends on (MCORE2 || MPENTIUM4 || MPSC) + depends on (MCORE2 || MPENTIUM4 || MPSC || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MNATIVE_INTEL) config X86_TSC def_bool y - depends on (MWINCHIP3D || MCRUSOE || MEFFICEON || MCYRIXIII || MK7 || MK6 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || M586MMX || M586TSC || MK8 || MVIAC3_2 || MVIAC7 || MGEODEGX1 || MGEODE_LX || MCORE2 || MATOM) || X86_64 + depends on (MWINCHIP3D || MCRUSOE || MEFFICEON || MCYRIXIII || MK7 || MK6 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || M586MMX || M586TSC || MK8 || MVIAC3_2 || MVIAC7 || MGEODEGX1 || MGEODE_LX || MCORE2 || MATOM || MK8SSE3 || MK10 || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MNATIVE_INTEL || MNATIVE_AMD) || X86_64 config X86_CMPXCHG64 def_bool y - depends on X86_PAE || X86_64 || MCORE2 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || M586TSC || M586MMX || MATOM || MGEODE_LX || MGEODEGX1 || MK6 || MK7 || MK8 + depends on X86_PAE || X86_64 || MCORE2 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || M586TSC || M586MMX || MATOM || MGEODE_LX || MGEODEGX1 || MK6 || MK7 || MK8 || MK8SSE3 || MK10 || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MNATIVE_INTEL || MNATIVE_AMD # this should be set for all -march=.. options where the compiler # generates cmov. config X86_CMOV def_bool y - depends on (MK8 || MK7 || MCORE2 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || MVIAC3_2 || MVIAC7 || MCRUSOE || MEFFICEON || X86_64 || MATOM || MGEODE_LX) + depends on (MK8 || MK7 || MCORE2 || MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || MVIAC3_2 || MVIAC7 || MCRUSOE || MEFFICEON || X86_64 || MATOM || MGEODE_LX || MK8SSE3 || MK10 || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MNATIVE_INTEL || MNATIVE_AMD) config X86_MINIMUM_CPU_FAMILY int default "64" if X86_64 - default "6" if X86_32 && (MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || MVIAC3_2 || MVIAC7 || MEFFICEON || MATOM || MCRUSOE || MCORE2 || MK7 || MK8) + default "6" if X86_32 && (MPENTIUM4 || MPENTIUMM || MPENTIUMIII || MPENTIUMII || M686 || MVIAC3_2 || MVIAC7 || MEFFICEON || MATOM || MCRUSOE || MCORE2 || MK7 || MK8 || MK8SSE3 || MK10 || MBARCELONA || MBOBCAT || MJAGUAR || MBULLDOZER || MPILEDRIVER || MSTEAMROLLER || MEXCAVATOR || MZEN || MZEN2 || MZEN3 || MNEHALEM || MWESTMERE || MSILVERMONT || MGOLDMONT || MGOLDMONTPLUS || MSANDYBRIDGE || MIVYBRIDGE || MHASWELL || MBROADWELL || MSKYLAKE || MSKYLAKEX || MCANNONLAKE || MICELAKE || MCASCADELAKE || MCOOPERLAKE || MTIGERLAKE || MSAPPHIRERAPIDS || MROCKETLAKE || MALDERLAKE || MNATIVE_INTEL || MNATIVE_AMD) default "5" if X86_32 && X86_CMPXCHG64 default "4" diff --git a/arch/x86/Makefile b/arch/x86/Makefile index bafbd905e6e7..bc336020df59 100644 --- a/arch/x86/Makefile +++ b/arch/x86/Makefile @@ -67,7 +67,8 @@ export BITS # # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=53383 # -KBUILD_CFLAGS += -mno-sse -mno-mmx -mno-sse2 -mno-3dnow -mno-avx +KBUILD_CFLAGS += -mno-sse -mno-mmx -mno-sse2 -mno-3dnow -mno-avx -mno-avx2 \ + -mno-avx512f -O3 ifeq ($(CONFIG_X86_KERNEL_IBT),y) # @@ -147,14 +148,6 @@ else # Use -mskip-rax-setup if supported. KBUILD_CFLAGS += $(call cc-option,-mskip-rax-setup) - # FIXME - should be integrated in Makefile.cpu (Makefile_32.cpu) - cflags-$(CONFIG_MK8) += -march=k8 - cflags-$(CONFIG_MPSC) += -march=nocona - cflags-$(CONFIG_MCORE2) += -march=core2 - cflags-$(CONFIG_MATOM) += -march=atom - cflags-$(CONFIG_GENERIC_CPU) += -mtune=generic - KBUILD_CFLAGS += $(cflags-y) - KBUILD_CFLAGS += -mno-red-zone KBUILD_CFLAGS += -mcmodel=kernel endif diff --git a/arch/x86/Makefile.postlink b/arch/x86/Makefile.postlink new file mode 100644 index 000000000000..b38ffa4defb3 --- /dev/null +++ b/arch/x86/Makefile.postlink @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: GPL-2.0 +# =========================================================================== +# Post-link x86 pass +# =========================================================================== +# +# 1. Separate relocations from vmlinux into vmlinux.relocs. +# 2. Strip relocations from vmlinux. + +PHONY := __archpost +__archpost: + +-include include/config/auto.conf +include scripts/Kbuild.include + +CMD_RELOCS = arch/x86/tools/relocs +quiet_cmd_relocs = RELOCS $@.relocs + cmd_relocs = $(CMD_RELOCS) $@ > $@.relocs;$(CMD_RELOCS) --abs-relocs $@ + +quiet_cmd_strip_relocs = RSTRIP $@ + cmd_strip_relocs = $(OBJCOPY) --remove-section='.rel.*' --remove-section='.rel__*' --remove-section='.rela.*' --remove-section='.rela__*' $@ + +# `@true` prevents complaint when there is nothing to be done + +vmlinux: FORCE + @true +ifeq ($(CONFIG_X86_NEED_RELOCS),y) + $(call cmd,relocs) + $(call cmd,strip_relocs) +endif + +%.ko: FORCE + @true + +clean: + @rm -f vmlinux.relocs + +PHONY += FORCE clean + +FORCE: + +.PHONY: $(PHONY) diff --git a/arch/x86/Makefile_32.cpu b/arch/x86/Makefile_32.cpu index 94834c4b5e5e..81923b4afdf8 100644 --- a/arch/x86/Makefile_32.cpu +++ b/arch/x86/Makefile_32.cpu @@ -10,44 +10,3 @@ else align := -falign-functions=0 -falign-jumps=0 -falign-loops=0 endif -cflags-$(CONFIG_M486SX) += -march=i486 -cflags-$(CONFIG_M486) += -march=i486 -cflags-$(CONFIG_M586) += -march=i586 -cflags-$(CONFIG_M586TSC) += -march=i586 -cflags-$(CONFIG_M586MMX) += -march=pentium-mmx -cflags-$(CONFIG_M686) += -march=i686 -cflags-$(CONFIG_MPENTIUMII) += -march=i686 $(call tune,pentium2) -cflags-$(CONFIG_MPENTIUMIII) += -march=i686 $(call tune,pentium3) -cflags-$(CONFIG_MPENTIUMM) += -march=i686 $(call tune,pentium3) -cflags-$(CONFIG_MPENTIUM4) += -march=i686 $(call tune,pentium4) -cflags-$(CONFIG_MK6) += -march=k6 -# Please note, that patches that add -march=athlon-xp and friends are pointless. -# They make zero difference whatsosever to performance at this time. -cflags-$(CONFIG_MK7) += -march=athlon -cflags-$(CONFIG_MK8) += $(call cc-option,-march=k8,-march=athlon) -cflags-$(CONFIG_MCRUSOE) += -march=i686 $(align) -cflags-$(CONFIG_MEFFICEON) += -march=i686 $(call tune,pentium3) $(align) -cflags-$(CONFIG_MWINCHIPC6) += $(call cc-option,-march=winchip-c6,-march=i586) -cflags-$(CONFIG_MWINCHIP3D) += $(call cc-option,-march=winchip2,-march=i586) -cflags-$(CONFIG_MCYRIXIII) += $(call cc-option,-march=c3,-march=i486) $(align) -cflags-$(CONFIG_MVIAC3_2) += $(call cc-option,-march=c3-2,-march=i686) -cflags-$(CONFIG_MVIAC7) += -march=i686 -cflags-$(CONFIG_MCORE2) += -march=i686 $(call tune,core2) -cflags-$(CONFIG_MATOM) += $(call cc-option,-march=atom,$(call cc-option,-march=core2,-march=i686)) \ - $(call cc-option,-mtune=atom,$(call cc-option,-mtune=generic)) - -# AMD Elan support -cflags-$(CONFIG_MELAN) += -march=i486 - -# Geode GX1 support -cflags-$(CONFIG_MGEODEGX1) += -march=pentium-mmx -cflags-$(CONFIG_MGEODE_LX) += $(call cc-option,-march=geode,-march=pentium-mmx) -# add at the end to overwrite eventual tuning options from earlier -# cpu entries -cflags-$(CONFIG_X86_GENERIC) += $(call tune,generic,$(call tune,i686)) - -# Bug fix for binutils: this option is required in order to keep -# binutils from generating NOPL instructions against our will. -ifneq ($(CONFIG_X86_P6_NOP),y) -cflags-y += $(call cc-option,-Wa$(comma)-mtune=generic32,) -endif diff --git a/arch/x86/boot/compressed/Makefile b/arch/x86/boot/compressed/Makefile index 35ce1a64068b..eba7709d75ae 100644 --- a/arch/x86/boot/compressed/Makefile +++ b/arch/x86/boot/compressed/Makefile @@ -120,14 +120,12 @@ $(obj)/vmlinux.bin: vmlinux FORCE targets += $(patsubst $(obj)/%,%,$(vmlinux-objs-y)) vmlinux.bin.all vmlinux.relocs -CMD_RELOCS = arch/x86/tools/relocs -quiet_cmd_relocs = RELOCS $@ - cmd_relocs = $(CMD_RELOCS) $< > $@;$(CMD_RELOCS) --abs-relocs $< -$(obj)/vmlinux.relocs: vmlinux FORCE - $(call if_changed,relocs) +# vmlinux.relocs is created by the vmlinux postlink step. +vmlinux.relocs: vmlinux + @true vmlinux.bin.all-y := $(obj)/vmlinux.bin -vmlinux.bin.all-$(CONFIG_X86_NEED_RELOCS) += $(obj)/vmlinux.relocs +vmlinux.bin.all-$(CONFIG_X86_NEED_RELOCS) += vmlinux.relocs $(obj)/vmlinux.bin.gz: $(vmlinux.bin.all-y) FORCE $(call if_changed,gzip) diff --git a/arch/x86/include/asm/vermagic.h b/arch/x86/include/asm/vermagic.h index 75884d2cdec3..4e6a08d4c7e5 100644 --- a/arch/x86/include/asm/vermagic.h +++ b/arch/x86/include/asm/vermagic.h @@ -17,6 +17,48 @@ #define MODULE_PROC_FAMILY "586MMX " #elif defined CONFIG_MCORE2 #define MODULE_PROC_FAMILY "CORE2 " +#elif defined CONFIG_MNATIVE_INTEL +#define MODULE_PROC_FAMILY "NATIVE_INTEL " +#elif defined CONFIG_MNATIVE_AMD +#define MODULE_PROC_FAMILY "NATIVE_AMD " +#elif defined CONFIG_MNEHALEM +#define MODULE_PROC_FAMILY "NEHALEM " +#elif defined CONFIG_MWESTMERE +#define MODULE_PROC_FAMILY "WESTMERE " +#elif defined CONFIG_MSILVERMONT +#define MODULE_PROC_FAMILY "SILVERMONT " +#elif defined CONFIG_MGOLDMONT +#define MODULE_PROC_FAMILY "GOLDMONT " +#elif defined CONFIG_MGOLDMONTPLUS +#define MODULE_PROC_FAMILY "GOLDMONTPLUS " +#elif defined CONFIG_MSANDYBRIDGE +#define MODULE_PROC_FAMILY "SANDYBRIDGE " +#elif defined CONFIG_MIVYBRIDGE +#define MODULE_PROC_FAMILY "IVYBRIDGE " +#elif defined CONFIG_MHASWELL +#define MODULE_PROC_FAMILY "HASWELL " +#elif defined CONFIG_MBROADWELL +#define MODULE_PROC_FAMILY "BROADWELL " +#elif defined CONFIG_MSKYLAKE +#define MODULE_PROC_FAMILY "SKYLAKE " +#elif defined CONFIG_MSKYLAKEX +#define MODULE_PROC_FAMILY "SKYLAKEX " +#elif defined CONFIG_MCANNONLAKE +#define MODULE_PROC_FAMILY "CANNONLAKE " +#elif defined CONFIG_MICELAKE +#define MODULE_PROC_FAMILY "ICELAKE " +#elif defined CONFIG_MCASCADELAKE +#define MODULE_PROC_FAMILY "CASCADELAKE " +#elif defined CONFIG_MCOOPERLAKE +#define MODULE_PROC_FAMILY "COOPERLAKE " +#elif defined CONFIG_MTIGERLAKE +#define MODULE_PROC_FAMILY "TIGERLAKE " +#elif defined CONFIG_MSAPPHIRERAPIDS +#define MODULE_PROC_FAMILY "SAPPHIRERAPIDS " +#elif defined CONFIG_ROCKETLAKE +#define MODULE_PROC_FAMILY "ROCKETLAKE " +#elif defined CONFIG_MALDERLAKE +#define MODULE_PROC_FAMILY "ALDERLAKE " #elif defined CONFIG_MATOM #define MODULE_PROC_FAMILY "ATOM " #elif defined CONFIG_M686 @@ -35,6 +77,30 @@ #define MODULE_PROC_FAMILY "K7 " #elif defined CONFIG_MK8 #define MODULE_PROC_FAMILY "K8 " +#elif defined CONFIG_MK8SSE3 +#define MODULE_PROC_FAMILY "K8SSE3 " +#elif defined CONFIG_MK10 +#define MODULE_PROC_FAMILY "K10 " +#elif defined CONFIG_MBARCELONA +#define MODULE_PROC_FAMILY "BARCELONA " +#elif defined CONFIG_MBOBCAT +#define MODULE_PROC_FAMILY "BOBCAT " +#elif defined CONFIG_MBULLDOZER +#define MODULE_PROC_FAMILY "BULLDOZER " +#elif defined CONFIG_MPILEDRIVER +#define MODULE_PROC_FAMILY "PILEDRIVER " +#elif defined CONFIG_MSTEAMROLLER +#define MODULE_PROC_FAMILY "STEAMROLLER " +#elif defined CONFIG_MJAGUAR +#define MODULE_PROC_FAMILY "JAGUAR " +#elif defined CONFIG_MEXCAVATOR +#define MODULE_PROC_FAMILY "EXCAVATOR " +#elif defined CONFIG_MZEN +#define MODULE_PROC_FAMILY "ZEN " +#elif defined CONFIG_MZEN2 +#define MODULE_PROC_FAMILY "ZEN2 " +#elif defined CONFIG_MZEN3 +#define MODULE_PROC_FAMILY "ZEN3 " #elif defined CONFIG_MELAN #define MODULE_PROC_FAMILY "ELAN " #elif defined CONFIG_MCRUSOE diff --git a/block/bfq-iosched.c b/block/bfq-iosched.c index c740b41fe0a4..ac24b8222bd1 100644 --- a/block/bfq-iosched.c +++ b/block/bfq-iosched.c @@ -7463,6 +7463,7 @@ MODULE_ALIAS("bfq-iosched"); static int __init bfq_init(void) { int ret; + char msg[60] = "BFQ I/O-scheduler: BFQ-CachyOS v6.0"; #ifdef CONFIG_BFQ_GROUP_IOSCHED ret = blkcg_policy_register(&blkcg_policy_bfq); @@ -7494,6 +7495,11 @@ static int __init bfq_init(void) if (ret) goto slab_kill; +#ifdef CONFIG_BFQ_GROUP_IOSCHED + strcat(msg, " (with cgroups support)"); +#endif + pr_info("%s", msg); + return 0; slab_kill: diff --git a/block/elevator.c b/block/elevator.c index bd71f0fc4e4b..389cb51389af 100644 --- a/block/elevator.c +++ b/block/elevator.c @@ -640,8 +640,13 @@ static struct elevator_type *elevator_get_default(struct request_queue *q) if (q->nr_hw_queues != 1 && !blk_mq_is_shared_tags(q->tag_set->flags)) +#if defined(CONFIG_CACHY) && defined(CONFIG_MQ_IOSCHED_KYBER) + return elevator_get(q, "kyber", false); +#elif defined(CONFIG_CACHY) + return elevator_get(q, "mq-deadline", false); +#else return NULL; - +#endif return elevator_get(q, "mq-deadline", false); } diff --git a/drivers/i2c/busses/Kconfig b/drivers/i2c/busses/Kconfig index 7284206b278b..6849fe2af246 100644 --- a/drivers/i2c/busses/Kconfig +++ b/drivers/i2c/busses/Kconfig @@ -229,6 +229,15 @@ config I2C_CHT_WC combined with a FUSB302 Type-C port-controller as such it is advised to also select CONFIG_TYPEC_FUSB302=m. +config I2C_NCT6775 + tristate "Nuvoton NCT6775 and compatible SMBus controller" + help + If you say yes to this option, support will be included for the + Nuvoton NCT6775 and compatible SMBus controllers. + + This driver can also be built as a module. If so, the module + will be called i2c-nct6775. + config I2C_NFORCE2 tristate "Nvidia nForce2, nForce3 and nForce4" depends on PCI diff --git a/drivers/i2c/busses/Makefile b/drivers/i2c/busses/Makefile index c5cac15f075c..245012554359 100644 --- a/drivers/i2c/busses/Makefile +++ b/drivers/i2c/busses/Makefile @@ -20,6 +20,7 @@ obj-$(CONFIG_I2C_CHT_WC) += i2c-cht-wc.o obj-$(CONFIG_I2C_I801) += i2c-i801.o obj-$(CONFIG_I2C_ISCH) += i2c-isch.o obj-$(CONFIG_I2C_ISMT) += i2c-ismt.o +obj-$(CONFIG_I2C_NCT6775) += i2c-nct6775.o obj-$(CONFIG_I2C_NFORCE2) += i2c-nforce2.o obj-$(CONFIG_I2C_NFORCE2_S4985) += i2c-nforce2-s4985.o obj-$(CONFIG_I2C_NVIDIA_GPU) += i2c-nvidia-gpu.o diff --git a/drivers/i2c/busses/i2c-nct6775.c b/drivers/i2c/busses/i2c-nct6775.c new file mode 100644 index 000000000000..0462f0952043 --- /dev/null +++ b/drivers/i2c/busses/i2c-nct6775.c @@ -0,0 +1,647 @@ +/* + * i2c-nct6775 - Driver for the SMBus master functionality of + * Nuvoton NCT677x Super-I/O chips + * + * Copyright (C) 2019 Adam Honse + * + * Derived from nct6775 hwmon driver + * Copyright (C) 2012 Guenter Roeck + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define DRVNAME "i2c-nct6775" + +/* Nuvoton SMBus address offsets */ +#define SMBHSTDAT (0 + nuvoton_nct6793d_smba) +#define SMBBLKSZ (1 + nuvoton_nct6793d_smba) +#define SMBHSTCMD (2 + nuvoton_nct6793d_smba) +#define SMBHSTIDX (3 + nuvoton_nct6793d_smba) //Index field is the Command field on other controllers +#define SMBHSTCTL (4 + nuvoton_nct6793d_smba) +#define SMBHSTADD (5 + nuvoton_nct6793d_smba) +#define SMBHSTERR (9 + nuvoton_nct6793d_smba) +#define SMBHSTSTS (0xE + nuvoton_nct6793d_smba) + +/* Command register */ +#define NCT6793D_READ_BYTE 0 +#define NCT6793D_READ_WORD 1 +#define NCT6793D_READ_BLOCK 2 +#define NCT6793D_BLOCK_WRITE_READ_PROC_CALL 3 +#define NCT6793D_PROC_CALL 4 +#define NCT6793D_WRITE_BYTE 8 +#define NCT6793D_WRITE_WORD 9 +#define NCT6793D_WRITE_BLOCK 10 + +/* Control register */ +#define NCT6793D_MANUAL_START 128 +#define NCT6793D_SOFT_RESET 64 + +/* Error register */ +#define NCT6793D_NO_ACK 32 + +/* Status register */ +#define NCT6793D_FIFO_EMPTY 1 +#define NCT6793D_FIFO_FULL 2 +#define NCT6793D_MANUAL_ACTIVE 4 + +#define NCT6775_LD_SMBUS 0x0B + +/* Other settings */ +#define MAX_RETRIES 400 + +enum kinds { nct6106, nct6775, nct6776, nct6779, nct6791, nct6792, nct6793, + nct6795, nct6796, nct6798 }; + +struct nct6775_sio_data { + int sioreg; + enum kinds kind; +}; + +/* used to set data->name = nct6775_device_names[data->sio_kind] */ +static const char * const nct6775_device_names[] = { + "nct6106", + "nct6775", + "nct6776", + "nct6779", + "nct6791", + "nct6792", + "nct6793", + "nct6795", + "nct6796", + "nct6798", +}; + +static const char * const nct6775_sio_names[] __initconst = { + "NCT6106D", + "NCT6775F", + "NCT6776D/F", + "NCT6779D", + "NCT6791D", + "NCT6792D", + "NCT6793D", + "NCT6795D", + "NCT6796D", + "NCT6798D", +}; + +#define SIO_REG_LDSEL 0x07 /* Logical device select */ +#define SIO_REG_DEVID 0x20 /* Device ID (2 bytes) */ +#define SIO_REG_SMBA 0x62 /* SMBus base address register */ + +#define SIO_NCT6106_ID 0xc450 +#define SIO_NCT6775_ID 0xb470 +#define SIO_NCT6776_ID 0xc330 +#define SIO_NCT6779_ID 0xc560 +#define SIO_NCT6791_ID 0xc800 +#define SIO_NCT6792_ID 0xc910 +#define SIO_NCT6793_ID 0xd120 +#define SIO_NCT6795_ID 0xd350 +#define SIO_NCT6796_ID 0xd420 +#define SIO_NCT6798_ID 0xd428 +#define SIO_ID_MASK 0xFFF0 + +static inline void +superio_outb(int ioreg, int reg, int val) +{ + outb(reg, ioreg); + outb(val, ioreg + 1); +} + +static inline int +superio_inb(int ioreg, int reg) +{ + outb(reg, ioreg); + return inb(ioreg + 1); +} + +static inline void +superio_select(int ioreg, int ld) +{ + outb(SIO_REG_LDSEL, ioreg); + outb(ld, ioreg + 1); +} + +static inline int +superio_enter(int ioreg) +{ + /* + * Try to reserve and for exclusive access. + */ + if (!request_muxed_region(ioreg, 2, DRVNAME)) + return -EBUSY; + + outb(0x87, ioreg); + outb(0x87, ioreg); + + return 0; +} + +static inline void +superio_exit(int ioreg) +{ + outb(0xaa, ioreg); + outb(0x02, ioreg); + outb(0x02, ioreg + 1); + release_region(ioreg, 2); +} + +/* + * ISA constants + */ + +#define IOREGION_ALIGNMENT (~7) +#define IOREGION_LENGTH 2 +#define ADDR_REG_OFFSET 0 +#define DATA_REG_OFFSET 1 + +#define NCT6775_REG_BANK 0x4E +#define NCT6775_REG_CONFIG 0x40 + +static struct i2c_adapter *nct6775_adapter; + +struct i2c_nct6775_adapdata { + unsigned short smba; +}; + +/* Return negative errno on error. */ +static s32 nct6775_access(struct i2c_adapter * adap, u16 addr, + unsigned short flags, char read_write, + u8 command, int size, union i2c_smbus_data * data) +{ + struct i2c_nct6775_adapdata *adapdata = i2c_get_adapdata(adap); + unsigned short nuvoton_nct6793d_smba = adapdata->smba; + int i, len, cnt; + union i2c_smbus_data tmp_data; + int timeout = 0; + + tmp_data.word = 0; + cnt = 0; + len = 0; + + outb_p(NCT6793D_SOFT_RESET, SMBHSTCTL); + + switch (size) { + case I2C_SMBUS_QUICK: + outb_p((addr << 1) | read_write, + SMBHSTADD); + break; + case I2C_SMBUS_BYTE_DATA: + tmp_data.byte = data->byte; + case I2C_SMBUS_BYTE: + outb_p((addr << 1) | read_write, + SMBHSTADD); + outb_p(command, SMBHSTIDX); + if (read_write == I2C_SMBUS_WRITE) { + outb_p(tmp_data.byte, SMBHSTDAT); + outb_p(NCT6793D_WRITE_BYTE, SMBHSTCMD); + } + else { + outb_p(NCT6793D_READ_BYTE, SMBHSTCMD); + } + break; + case I2C_SMBUS_WORD_DATA: + outb_p((addr << 1) | read_write, + SMBHSTADD); + outb_p(command, SMBHSTIDX); + if (read_write == I2C_SMBUS_WRITE) { + outb_p(data->word & 0xff, SMBHSTDAT); + outb_p((data->word & 0xff00) >> 8, SMBHSTDAT); + outb_p(NCT6793D_WRITE_WORD, SMBHSTCMD); + } + else { + outb_p(NCT6793D_READ_WORD, SMBHSTCMD); + } + break; + case I2C_SMBUS_BLOCK_DATA: + outb_p((addr << 1) | read_write, + SMBHSTADD); + outb_p(command, SMBHSTIDX); + if (read_write == I2C_SMBUS_WRITE) { + len = data->block[0]; + if (len == 0 || len > I2C_SMBUS_BLOCK_MAX) + return -EINVAL; + outb_p(len, SMBBLKSZ); + + cnt = 1; + if (len >= 4) { + for (i = cnt; i <= 4; i++) { + outb_p(data->block[i], SMBHSTDAT); + } + + len -= 4; + cnt += 4; + } + else { + for (i = cnt; i <= len; i++ ) { + outb_p(data->block[i], SMBHSTDAT); + } + + len = 0; + } + + outb_p(NCT6793D_WRITE_BLOCK, SMBHSTCMD); + } + else { + return -ENOTSUPP; + } + break; + default: + dev_warn(&adap->dev, "Unsupported transaction %d\n", size); + return -EOPNOTSUPP; + } + + outb_p(NCT6793D_MANUAL_START, SMBHSTCTL); + + while ((size == I2C_SMBUS_BLOCK_DATA) && (len > 0)) { + if (read_write == I2C_SMBUS_WRITE) { + timeout = 0; + while ((inb_p(SMBHSTSTS) & NCT6793D_FIFO_EMPTY) == 0) + { + if(timeout > MAX_RETRIES) + { + return -ETIMEDOUT; + } + usleep_range(250, 500); + timeout++; + } + + //Load more bytes into FIFO + if (len >= 4) { + for (i = cnt; i <= (cnt + 4); i++) { + outb_p(data->block[i], SMBHSTDAT); + } + + len -= 4; + cnt += 4; + } + else { + for (i = cnt; i <= (cnt + len); i++) { + outb_p(data->block[i], SMBHSTDAT); + } + + len = 0; + } + } + else { + return -ENOTSUPP; + } + + } + + //wait for manual mode to complete + timeout = 0; + while ((inb_p(SMBHSTSTS) & NCT6793D_MANUAL_ACTIVE) != 0) + { + if(timeout > MAX_RETRIES) + { + return -ETIMEDOUT; + } + usleep_range(250, 500); + timeout++; + } + + if ((inb_p(SMBHSTERR) & NCT6793D_NO_ACK) != 0) { + return -ENXIO; + } + else if ((read_write == I2C_SMBUS_WRITE) || (size == I2C_SMBUS_QUICK)) { + return 0; + } + + switch (size) { + case I2C_SMBUS_QUICK: + case I2C_SMBUS_BYTE_DATA: + data->byte = inb_p(SMBHSTDAT); + break; + case I2C_SMBUS_WORD_DATA: + data->word = inb_p(SMBHSTDAT) + (inb_p(SMBHSTDAT) << 8); + break; + } + return 0; +} + +static u32 nct6775_func(struct i2c_adapter *adapter) +{ + return I2C_FUNC_SMBUS_QUICK | I2C_FUNC_SMBUS_BYTE | + I2C_FUNC_SMBUS_BYTE_DATA | I2C_FUNC_SMBUS_WORD_DATA | + I2C_FUNC_SMBUS_BLOCK_DATA; +} + +static const struct i2c_algorithm smbus_algorithm = { + .smbus_xfer = nct6775_access, + .functionality = nct6775_func, +}; + +static int nct6775_add_adapter(unsigned short smba, const char *name, struct i2c_adapter **padap) +{ + struct i2c_adapter *adap; + struct i2c_nct6775_adapdata *adapdata; + int retval; + + adap = kzalloc(sizeof(*adap), GFP_KERNEL); + if (adap == NULL) { + return -ENOMEM; + } + + adap->owner = THIS_MODULE; + adap->class = I2C_CLASS_HWMON | I2C_CLASS_SPD; + adap->algo = &smbus_algorithm; + + adapdata = kzalloc(sizeof(*adapdata), GFP_KERNEL); + if (adapdata == NULL) { + kfree(adap); + return -ENOMEM; + } + + adapdata->smba = smba; + + snprintf(adap->name, sizeof(adap->name), + "SMBus NCT67xx adapter%s at %04x", name, smba); + + i2c_set_adapdata(adap, adapdata); + + retval = i2c_add_adapter(adap); + if (retval) { + kfree(adapdata); + kfree(adap); + return retval; + } + + *padap = adap; + return 0; +} + +static void nct6775_remove_adapter(struct i2c_adapter *adap) +{ + struct i2c_nct6775_adapdata *adapdata = i2c_get_adapdata(adap); + + if (adapdata->smba) { + i2c_del_adapter(adap); + kfree(adapdata); + kfree(adap); + } +} + +//static SIMPLE_DEV_PM_OPS(nct6775_dev_pm_ops, nct6775_suspend, nct6775_resume); + +/* + * when Super-I/O functions move to a separate file, the Super-I/O + * bus will manage the lifetime of the device and this module will only keep + * track of the nct6775 driver. But since we use platform_device_alloc(), we + * must keep track of the device + */ +static struct platform_device *pdev[2]; + +static int nct6775_probe(struct platform_device *pdev) +{ + struct device *dev = &pdev->dev; + struct nct6775_sio_data *sio_data = dev_get_platdata(dev); + struct resource *res; + + res = platform_get_resource(pdev, IORESOURCE_IO, 0); + if (!devm_request_region(&pdev->dev, res->start, IOREGION_LENGTH, + DRVNAME)) + return -EBUSY; + + switch (sio_data->kind) { + case nct6791: + case nct6792: + case nct6793: + case nct6795: + case nct6796: + case nct6798: + nct6775_add_adapter(res->start, "", &nct6775_adapter); + break; + default: + return -ENODEV; + } + + return 0; +} +/* +static void nct6791_enable_io_mapping(int sioaddr) +{ + int val; + + val = superio_inb(sioaddr, NCT6791_REG_HM_IO_SPACE_LOCK_ENABLE); + if (val & 0x10) { + pr_info("Enabling hardware monitor logical device mappings.\n"); + superio_outb(sioaddr, NCT6791_REG_HM_IO_SPACE_LOCK_ENABLE, + val & ~0x10); + } +}*/ + +static struct platform_driver i2c_nct6775_driver = { + .driver = { + .name = DRVNAME, +// .pm = &nct6775_dev_pm_ops, + }, + .probe = nct6775_probe, +}; + +static void __exit i2c_nct6775_exit(void) +{ + int i; + + if(nct6775_adapter) + nct6775_remove_adapter(nct6775_adapter); + + for (i = 0; i < ARRAY_SIZE(pdev); i++) { + if (pdev[i]) + platform_device_unregister(pdev[i]); + } + platform_driver_unregister(&i2c_nct6775_driver); +} + +/* nct6775_find() looks for a '627 in the Super-I/O config space */ +static int __init nct6775_find(int sioaddr, struct nct6775_sio_data *sio_data) +{ + u16 val; + int err; + int addr; + + err = superio_enter(sioaddr); + if (err) + return err; + + val = (superio_inb(sioaddr, SIO_REG_DEVID) << 8) | + superio_inb(sioaddr, SIO_REG_DEVID + 1); + + switch (val & SIO_ID_MASK) { + case SIO_NCT6106_ID: + sio_data->kind = nct6106; + break; + case SIO_NCT6775_ID: + sio_data->kind = nct6775; + break; + case SIO_NCT6776_ID: + sio_data->kind = nct6776; + break; + case SIO_NCT6779_ID: + sio_data->kind = nct6779; + break; + case SIO_NCT6791_ID: + sio_data->kind = nct6791; + break; + case SIO_NCT6792_ID: + sio_data->kind = nct6792; + break; + case SIO_NCT6793_ID: + sio_data->kind = nct6793; + break; + case SIO_NCT6795_ID: + sio_data->kind = nct6795; + break; + case SIO_NCT6796_ID: + sio_data->kind = nct6796; + break; + case SIO_NCT6798_ID: + sio_data->kind = nct6798; + break; + default: + if (val != 0xffff) + pr_debug("unsupported chip ID: 0x%04x\n", val); + superio_exit(sioaddr); + return -ENODEV; + } + + /* We have a known chip, find the SMBus I/O address */ + superio_select(sioaddr, NCT6775_LD_SMBUS); + val = (superio_inb(sioaddr, SIO_REG_SMBA) << 8) + | superio_inb(sioaddr, SIO_REG_SMBA + 1); + addr = val & IOREGION_ALIGNMENT; + if (addr == 0) { + pr_err("Refusing to enable a Super-I/O device with a base I/O port 0\n"); + superio_exit(sioaddr); + return -ENODEV; + } + + //if (sio_data->kind == nct6791 || sio_data->kind == nct6792 || + // sio_data->kind == nct6793 || sio_data->kind == nct6795 || + // sio_data->kind == nct6796) + // nct6791_enable_io_mapping(sioaddr); + + superio_exit(sioaddr); + pr_info("Found %s or compatible chip at %#x:%#x\n", + nct6775_sio_names[sio_data->kind], sioaddr, addr); + sio_data->sioreg = sioaddr; + + return addr; +} + +static int __init i2c_nct6775_init(void) +{ + int i, err; + bool found = false; + int address; + struct resource res; + struct nct6775_sio_data sio_data; + int sioaddr[2] = { 0x2e, 0x4e }; + + err = platform_driver_register(&i2c_nct6775_driver); + if (err) + return err; + + /* + * initialize sio_data->kind and sio_data->sioreg. + * + * when Super-I/O functions move to a separate file, the Super-I/O + * driver will probe 0x2e and 0x4e and auto-detect the presence of a + * nct6775 hardware monitor, and call probe() + */ + for (i = 0; i < ARRAY_SIZE(pdev); i++) { + address = nct6775_find(sioaddr[i], &sio_data); + if (address <= 0) + continue; + + found = true; + + pdev[i] = platform_device_alloc(DRVNAME, address); + if (!pdev[i]) { + err = -ENOMEM; + goto exit_device_unregister; + } + + err = platform_device_add_data(pdev[i], &sio_data, + sizeof(struct nct6775_sio_data)); + if (err) + goto exit_device_put; + + memset(&res, 0, sizeof(res)); + res.name = DRVNAME; + res.start = address; + res.end = address + IOREGION_LENGTH - 1; + res.flags = IORESOURCE_IO; + + err = acpi_check_resource_conflict(&res); + if (err) { + platform_device_put(pdev[i]); + pdev[i] = NULL; + continue; + } + + err = platform_device_add_resources(pdev[i], &res, 1); + if (err) + goto exit_device_put; + + /* platform_device_add calls probe() */ + err = platform_device_add(pdev[i]); + if (err) + goto exit_device_put; + } + if (!found) { + err = -ENODEV; + goto exit_unregister; + } + + return 0; + +exit_device_put: + platform_device_put(pdev[i]); +exit_device_unregister: + while (--i >= 0) { + if (pdev[i]) + platform_device_unregister(pdev[i]); + } +exit_unregister: + platform_driver_unregister(&i2c_nct6775_driver); + return err; +} + +MODULE_AUTHOR("Adam Honse "); +MODULE_DESCRIPTION("SMBus driver for NCT6775F and compatible chips"); +MODULE_LICENSE("GPL"); + +module_init(i2c_nct6775_init); +module_exit(i2c_nct6775_exit); diff --git a/drivers/i2c/busses/i2c-piix4.c b/drivers/i2c/busses/i2c-piix4.c index 809fbd014cd6..d54b35b147ee 100644 --- a/drivers/i2c/busses/i2c-piix4.c +++ b/drivers/i2c/busses/i2c-piix4.c @@ -568,11 +568,11 @@ static int piix4_transaction(struct i2c_adapter *piix4_adapter) if (srvrworks_csb5_delay) /* Extra delay for SERVERWORKS_CSB5 */ usleep_range(2000, 2100); else - usleep_range(250, 500); + usleep_range(25, 50); while ((++timeout < MAX_TIMEOUT) && ((temp = inb_p(SMBHSTSTS)) & 0x01)) - usleep_range(250, 500); + usleep_range(25, 50); /* If the SMBus is still busy, we give up */ if (timeout == MAX_TIMEOUT) { diff --git a/drivers/md/dm-crypt.c b/drivers/md/dm-crypt.c index 2653516bcdef..cdf9d8c7b556 100644 --- a/drivers/md/dm-crypt.c +++ b/drivers/md/dm-crypt.c @@ -3137,6 +3137,11 @@ static int crypt_ctr_optional(struct dm_target *ti, unsigned int argc, char **ar } } +#ifdef CONFIG_CACHY + set_bit(DM_CRYPT_NO_READ_WORKQUEUE, &cc->flags); + set_bit(DM_CRYPT_NO_WRITE_WORKQUEUE, &cc->flags); +#endif + return 0; } diff --git a/drivers/pci/quirks.c b/drivers/pci/quirks.c index 4944798e75b5..d11a9aeb1096 100644 --- a/drivers/pci/quirks.c +++ b/drivers/pci/quirks.c @@ -3612,6 +3612,106 @@ static void quirk_no_bus_reset(struct pci_dev *dev) dev->dev_flags |= PCI_DEV_FLAGS_NO_BUS_RESET; } +static bool acs_on_downstream; +static bool acs_on_multifunction; + +#define NUM_ACS_IDS 16 +struct acs_on_id { + unsigned short vendor; + unsigned short device; +}; +static struct acs_on_id acs_on_ids[NUM_ACS_IDS]; +static u8 max_acs_id; + +static __init int pcie_acs_override_setup(char *p) +{ + if (!p) + return -EINVAL; + + while (*p) { + if (!strncmp(p, "downstream", 10)) + acs_on_downstream = true; + if (!strncmp(p, "multifunction", 13)) + acs_on_multifunction = true; + if (!strncmp(p, "id:", 3)) { + char opt[5]; + int ret; + long val; + + if (max_acs_id >= NUM_ACS_IDS - 1) { + pr_warn("Out of PCIe ACS override slots (%d)\n", + NUM_ACS_IDS); + goto next; + } + + p += 3; + snprintf(opt, 5, "%s", p); + ret = kstrtol(opt, 16, &val); + if (ret) { + pr_warn("PCIe ACS ID parse error %d\n", ret); + goto next; + } + acs_on_ids[max_acs_id].vendor = val; + + p += strcspn(p, ":"); + if (*p != ':') { + pr_warn("PCIe ACS invalid ID\n"); + goto next; + } + + p++; + snprintf(opt, 5, "%s", p); + ret = kstrtol(opt, 16, &val); + if (ret) { + pr_warn("PCIe ACS ID parse error %d\n", ret); + goto next; + } + acs_on_ids[max_acs_id].device = val; + max_acs_id++; + } +next: + p += strcspn(p, ","); + if (*p == ',') + p++; + } + + if (acs_on_downstream || acs_on_multifunction || max_acs_id) + pr_warn("Warning: PCIe ACS overrides enabled; This may allow non-IOMMU protected peer-to-peer DMA\n"); + + return 0; +} +early_param("pcie_acs_override", pcie_acs_override_setup); + +static int pcie_acs_overrides(struct pci_dev *dev, u16 acs_flags) +{ + int i; + + /* Never override ACS for legacy devices or devices with ACS caps */ + if (!pci_is_pcie(dev) || + pci_find_ext_capability(dev, PCI_EXT_CAP_ID_ACS)) + return -ENOTTY; + + for (i = 0; i < max_acs_id; i++) + if (acs_on_ids[i].vendor == dev->vendor && + acs_on_ids[i].device == dev->device) + return 1; + + switch (pci_pcie_type(dev)) { + case PCI_EXP_TYPE_DOWNSTREAM: + case PCI_EXP_TYPE_ROOT_PORT: + if (acs_on_downstream) + return 1; + break; + case PCI_EXP_TYPE_ENDPOINT: + case PCI_EXP_TYPE_UPSTREAM: + case PCI_EXP_TYPE_LEG_END: + case PCI_EXP_TYPE_RC_END: + if (acs_on_multifunction && dev->multifunction) + return 1; + } + + return -ENOTTY; +} /* * Some NVIDIA GPU devices do not work with bus reset, SBR needs to be * prevented for those affected devices. @@ -4980,6 +5080,7 @@ static const struct pci_dev_acs_enabled { { PCI_VENDOR_ID_NXP, 0x8d9b, pci_quirk_nxp_rp_acs }, /* Zhaoxin Root/Downstream Ports */ { PCI_VENDOR_ID_ZHAOXIN, PCI_ANY_ID, pci_quirk_zhaoxin_pcie_ports_acs }, + { PCI_ANY_ID, PCI_ANY_ID, pcie_acs_overrides }, { 0 } }; diff --git a/include/linux/pagemap.h b/include/linux/pagemap.h index 0178b2040ea3..39f05e00dac4 100644 --- a/include/linux/pagemap.h +++ b/include/linux/pagemap.h @@ -1183,7 +1183,7 @@ struct readahead_control { ._index = i, \ } -#define VM_READAHEAD_PAGES (SZ_128K / PAGE_SIZE) +#define VM_READAHEAD_PAGES (SZ_8M / PAGE_SIZE) void page_cache_ra_unbounded(struct readahead_control *, unsigned long nr_to_read, unsigned long lookahead_count); diff --git a/include/linux/user_namespace.h b/include/linux/user_namespace.h index 33a4240e6a6f..82213f9c4c17 100644 --- a/include/linux/user_namespace.h +++ b/include/linux/user_namespace.h @@ -139,6 +139,8 @@ static inline void set_rlimit_ucount_max(struct user_namespace *ns, #ifdef CONFIG_USER_NS +extern int unprivileged_userns_clone; + static inline struct user_namespace *get_user_ns(struct user_namespace *ns) { if (ns) @@ -172,6 +174,8 @@ extern bool current_in_userns(const struct user_namespace *target_ns); struct ns_common *ns_get_owner(struct ns_common *ns); #else +#define unprivileged_userns_clone 0 + static inline struct user_namespace *get_user_ns(struct user_namespace *ns) { return &init_user_ns; diff --git a/init/Kconfig b/init/Kconfig index d1d779d6ba43..96d0f05a4a7e 100644 --- a/init/Kconfig +++ b/init/Kconfig @@ -112,6 +112,10 @@ config THREAD_INFO_IN_TASK menu "General setup" +config CACHY + bool "Some kernel tweaks by CachyOS" + default y + config BROKEN bool @@ -334,6 +338,19 @@ config KERNEL_UNCOMPRESSED endchoice +menu "ZSTD compression options" + depends on KERNEL_ZSTD + +config ZSTD_COMP_VAL + int "Compression level (1-22)" + range 1 22 + default "22" + help + Choose a compression level for zstd kernel compression. + Default is 22, which is the maximum. + +endmenu + config DEFAULT_INIT string "Default init path" default "" @@ -1241,6 +1258,22 @@ config USER_NS If unsure, say N. +config USER_NS_UNPRIVILEGED + bool "Allow unprivileged users to create namespaces" + default y + depends on USER_NS + help + When disabled, unprivileged users will not be able to create + new namespaces. Allowing users to create their own namespaces + has been part of several recent local privilege escalation + exploits, so if you need user namespaces but are + paranoid^Wsecurity-conscious you want to disable this. + + This setting can be overridden at runtime via the + kernel.unprivileged_userns_clone sysctl. + + If unsure, say Y. + config PID_NS bool "PID Namespaces" default y @@ -1407,6 +1440,12 @@ config CC_OPTIMIZE_FOR_PERFORMANCE with the "-O2" compiler flag for best performance and most helpful compile-time warnings. +config CC_OPTIMIZE_FOR_PERFORMANCE_O3 + bool "Optimize more for performance (-O3)" + help + Choosing this option will pass "-O3" to your compiler to optimize + the kernel yet more for performance. + config CC_OPTIMIZE_FOR_SIZE bool "Optimize for size (-Os)" help diff --git a/kernel/Kconfig.hz b/kernel/Kconfig.hz index 38ef6d06888e..0f78364efd4f 100644 --- a/kernel/Kconfig.hz +++ b/kernel/Kconfig.hz @@ -40,6 +40,27 @@ choice on SMP and NUMA systems and exactly dividing by both PAL and NTSC frame rates for video and multimedia work. + config HZ_500 + bool "500 HZ" + help + 500 Hz is a balanced timer frequency. Provides fast interactivity + on desktops with good smoothness without increasing CPU power + consumption and sacrificing the battery life on laptops. + + config HZ_600 + bool "600 HZ" + help + 600 Hz is a balanced timer frequency. Provides fast interactivity + on desktops with good smoothness without increasing CPU power + consumption and sacrificing the battery life on laptops. + + config HZ_750 + bool "750 HZ" + help + 750 Hz is a balanced timer frequency. Provides fast interactivity + on desktops with good smoothness without increasing CPU power + consumption and sacrificing the battery life on laptops. + config HZ_1000 bool "1000 HZ" help @@ -53,6 +74,9 @@ config HZ default 100 if HZ_100 default 250 if HZ_250 default 300 if HZ_300 + default 500 if HZ_500 + default 600 if HZ_600 + default 750 if HZ_750 default 1000 if HZ_1000 config SCHED_HRTICK diff --git a/kernel/fork.c b/kernel/fork.c index 2b6bd511c6ed..704fe6bc9cb4 100644 --- a/kernel/fork.c +++ b/kernel/fork.c @@ -99,6 +99,10 @@ #include #include +#ifdef CONFIG_USER_NS +#include +#endif + #include #include #include @@ -2009,6 +2013,10 @@ static __latent_entropy struct task_struct *copy_process( if ((clone_flags & (CLONE_NEWUSER|CLONE_FS)) == (CLONE_NEWUSER|CLONE_FS)) return ERR_PTR(-EINVAL); + if ((clone_flags & CLONE_NEWUSER) && !unprivileged_userns_clone) + if (!capable(CAP_SYS_ADMIN)) + return ERR_PTR(-EPERM); + /* * Thread groups must share signals as well, and detached threads * can only be started up within the thread group. @@ -3159,6 +3167,12 @@ int ksys_unshare(unsigned long unshare_flags) if (unshare_flags & CLONE_NEWNS) unshare_flags |= CLONE_FS; + if ((unshare_flags & CLONE_NEWUSER) && !unprivileged_userns_clone) { + err = -EPERM; + if (!capable(CAP_SYS_ADMIN)) + goto bad_unshare_out; + } + err = check_unshare_flags(unshare_flags); if (err) goto bad_unshare_out; diff --git a/kernel/module/Kconfig b/kernel/module/Kconfig index 26ea5d04f56c..e5311101b93d 100644 --- a/kernel/module/Kconfig +++ b/kernel/module/Kconfig @@ -219,6 +219,31 @@ config MODULE_COMPRESS_ZSTD endchoice +menu "ZSTD module compression options" + depends on MODULE_COMPRESS_ZSTD + +config MODULE_COMPRESS_ZSTD_LEVEL + int "Compression level (1-19)" + range 1 19 + default 9 + help + Compression level used by zstd for compressing modules. + +config MODULE_COMPRESS_ZSTD_ULTRA + bool "Enable ZSTD ultra compression" + help + Compress modules with ZSTD using the highest possible compression. + +config MODULE_COMPRESS_ZSTD_LEVEL_ULTRA + int "Compression level (20-22)" + depends on MODULE_COMPRESS_ZSTD_ULTRA + range 20 22 + default 20 + help + Ultra compression level used by zstd for compressing modules. + +endmenu + config MODULE_DECOMPRESS bool "Support in-kernel module decompression" depends on MODULE_COMPRESS_GZIP || MODULE_COMPRESS_XZ diff --git a/kernel/rcu/Kconfig b/kernel/rcu/Kconfig index d471d22a5e21..e945e522969e 100644 --- a/kernel/rcu/Kconfig +++ b/kernel/rcu/Kconfig @@ -282,9 +282,9 @@ config RCU_NOCB_CPU_CB_BOOST depends on RCU_NOCB_CPU && RCU_BOOST default y if PREEMPT_RT help - Use this option to invoke offloaded callbacks as SCHED_FIFO + Use this option to invoke offloaded callbacks as SCHED_RR to avoid starvation by heavy SCHED_OTHER background load. - Of course, running as SCHED_FIFO during callback floods will + Of course, running as SCHED_RR during callback floods will cause the rcuo[ps] kthreads to monopolize the CPU for hundreds of milliseconds or more. Therefore, when enabling this option, it is your responsibility to ensure that latency-sensitive diff --git a/kernel/rcu/rcutorture.c b/kernel/rcu/rcutorture.c index d8e1b270a065..c543c348e7c0 100644 --- a/kernel/rcu/rcutorture.c +++ b/kernel/rcu/rcutorture.c @@ -2156,7 +2156,7 @@ static int rcutorture_booster_init(unsigned int cpu) t = per_cpu(ksoftirqd, cpu); WARN_ON_ONCE(!t); sp.sched_priority = 2; - sched_setscheduler_nocheck(t, SCHED_FIFO, &sp); + sched_setscheduler_nocheck(t, SCHED_RR, &sp); } /* Don't allow time recalculation while creating a new task. */ diff --git a/kernel/rcu/tree.c b/kernel/rcu/tree.c index 5b52727dcc1c..759d73a3411c 100644 --- a/kernel/rcu/tree.c +++ b/kernel/rcu/tree.c @@ -4239,8 +4239,8 @@ static void __init rcu_start_exp_gp_kworkers(void) return; } - sched_setscheduler_nocheck(rcu_exp_gp_kworker->task, SCHED_FIFO, ¶m); - sched_setscheduler_nocheck(rcu_exp_par_gp_kworker->task, SCHED_FIFO, + sched_setscheduler_nocheck(rcu_exp_gp_kworker->task, SCHED_RR, ¶m); + sched_setscheduler_nocheck(rcu_exp_par_gp_kworker->task, SCHED_RR, ¶m); } @@ -4278,7 +4278,7 @@ static int __init rcu_spawn_gp_kthread(void) return 0; if (kthread_prio) { sp.sched_priority = kthread_prio; - sched_setscheduler_nocheck(t, SCHED_FIFO, &sp); + sched_setscheduler_nocheck(t, SCHED_RR, &sp); } rnp = rcu_get_root(); raw_spin_lock_irqsave_rcu_node(rnp, flags); diff --git a/kernel/rcu/tree_nocb.h b/kernel/rcu/tree_nocb.h index a8f574d8850d..34325fbc0cf0 100644 --- a/kernel/rcu/tree_nocb.h +++ b/kernel/rcu/tree_nocb.h @@ -1319,7 +1319,7 @@ static void rcu_spawn_cpu_nocb_kthread(int cpu) } WRITE_ONCE(rdp_gp->nocb_gp_kthread, t); if (kthread_prio) - sched_setscheduler_nocheck(t, SCHED_FIFO, &sp); + sched_setscheduler_nocheck(t, SCHED_RR, &sp); } mutex_unlock(&rdp_gp->nocb_gp_kthread_mutex); @@ -1330,7 +1330,7 @@ static void rcu_spawn_cpu_nocb_kthread(int cpu) goto end; if (IS_ENABLED(CONFIG_RCU_NOCB_CPU_CB_BOOST) && kthread_prio) - sched_setscheduler_nocheck(t, SCHED_FIFO, &sp); + sched_setscheduler_nocheck(t, SCHED_RR, &sp); WRITE_ONCE(rdp->nocb_cb_kthread, t); WRITE_ONCE(rdp->nocb_gp_kthread, rdp_gp->nocb_gp_kthread); diff --git a/kernel/rcu/tree_plugin.h b/kernel/rcu/tree_plugin.h index 49468b4d1b43..304c8c20acbc 100644 --- a/kernel/rcu/tree_plugin.h +++ b/kernel/rcu/tree_plugin.h @@ -1007,7 +1007,7 @@ static void rcu_cpu_kthread_setup(unsigned int cpu) struct sched_param sp; sp.sched_priority = kthread_prio; - sched_setscheduler_nocheck(current, SCHED_FIFO, &sp); + sched_setscheduler_nocheck(current, SCHED_RR, &sp); #endif /* #ifdef CONFIG_RCU_BOOST */ WRITE_ONCE(rdp->rcuc_activity, jiffies); @@ -1206,7 +1206,7 @@ static void rcu_spawn_one_boost_kthread(struct rcu_node *rnp) rnp->boost_kthread_task = t; raw_spin_unlock_irqrestore_rcu_node(rnp, flags); sp.sched_priority = kthread_prio; - sched_setscheduler_nocheck(t, SCHED_FIFO, &sp); + sched_setscheduler_nocheck(t, SCHED_RR, &sp); wake_up_process(t); /* get to TASK_INTERRUPTIBLE quickly. */ out: diff --git a/kernel/sched/core.c b/kernel/sched/core.c index ee28253c9ac0..6a4417178679 100644 --- a/kernel/sched/core.c +++ b/kernel/sched/core.c @@ -728,7 +728,7 @@ static void update_rq_clock_task(struct rq *rq, s64 delta) if ((irq_delta + steal) && sched_feat(NONTASK_CAPACITY)) update_irq_load_avg(rq, irq_delta + steal); #endif - update_rq_clock_pelt(rq, delta); + update_rq_clock_task_mult(rq, delta); } void update_rq_clock(struct rq *rq) @@ -3725,13 +3725,6 @@ void sched_ttwu_pending(void *arg) if (!llist) return; - /* - * rq::ttwu_pending racy indication of out-standing wakeups. - * Races such that false-negatives are possible, since they - * are shorter lived that false-positives would be. - */ - WRITE_ONCE(rq->ttwu_pending, 0); - rq_lock_irqsave(rq, &rf); update_rq_clock(rq); @@ -3745,6 +3738,17 @@ void sched_ttwu_pending(void *arg) ttwu_do_activate(rq, p, p->sched_remote_wakeup ? WF_MIGRATED : 0, &rf); } + /* + * Must be after enqueueing at least once task such that + * idle_cpu() does not observe a false-negative -- if it does, + * it is possible for select_idle_siblings() to stack a number + * of tasks on this CPU during that window. + * + * It is ok to clear ttwu_pending when another task pending. + * We will receive IPI after local irq enabled and then enqueue it. + * Since now nr_running > 0, idle_cpu() will always get correct result. + */ + WRITE_ONCE(rq->ttwu_pending, 0); rq_unlock_irqrestore(rq, &rf); } diff --git a/kernel/sched/fair.c b/kernel/sched/fair.c index 914096c5b1ae..c4fd77e7e8f3 100644 --- a/kernel/sched/fair.c +++ b/kernel/sched/fair.c @@ -68,9 +68,13 @@ * * (default: 6ms * (1 + ilog(ncpus)), units: nanoseconds) */ +#ifdef CONFIG_CACHY +unsigned int sysctl_sched_latency = 3000000ULL; +static unsigned int normalized_sysctl_sched_latency = 3000000ULL; +#else unsigned int sysctl_sched_latency = 6000000ULL; static unsigned int normalized_sysctl_sched_latency = 6000000ULL; - +#endif /* * The initial- and re-scaling of tunables is configurable * @@ -89,8 +93,13 @@ unsigned int sysctl_sched_tunable_scaling = SCHED_TUNABLESCALING_LOG; * * (default: 0.75 msec * (1 + ilog(ncpus)), units: nanoseconds) */ +#ifdef CONFIG_CACHY +unsigned int sysctl_sched_min_granularity = 400000ULL; +static unsigned int normalized_sysctl_sched_min_granularity = 400000ULL; +#else unsigned int sysctl_sched_min_granularity = 750000ULL; static unsigned int normalized_sysctl_sched_min_granularity = 750000ULL; +#endif /* * Minimal preemption granularity for CPU-bound SCHED_IDLE tasks. @@ -120,8 +129,13 @@ unsigned int sysctl_sched_child_runs_first __read_mostly; * * (default: 1 msec * (1 + ilog(ncpus)), units: nanoseconds) */ +#ifdef CONFIG_CACHY +unsigned int sysctl_sched_wakeup_granularity = 500000UL; +static unsigned int normalized_sysctl_sched_wakeup_granularity = 500000UL; +#else unsigned int sysctl_sched_wakeup_granularity = 1000000UL; static unsigned int normalized_sysctl_sched_wakeup_granularity = 1000000UL; +#endif const_debug unsigned int sysctl_sched_migration_cost = 500000UL; @@ -174,8 +188,12 @@ int __weak arch_asym_cpu_priority(int cpu) * * (default: 5 msec, units: microseconds) */ +#ifdef CONFIG_CACHY +static unsigned int sysctl_sched_cfs_bandwidth_slice = 3000UL; +#else static unsigned int sysctl_sched_cfs_bandwidth_slice = 5000UL; #endif +#endif #ifdef CONFIG_SYSCTL static struct ctl_table sched_fair_sysctls[] = { @@ -4573,7 +4591,13 @@ check_preempt_tick(struct cfs_rq *cfs_rq, struct sched_entity *curr) struct sched_entity *se; s64 delta; - ideal_runtime = sched_slice(cfs_rq, curr); + /* + * When many tasks blow up the sched_period; it is possible that + * sched_slice() reports unusually large results (when many tasks are + * very light for example). Therefore impose a maximum. + */ + ideal_runtime = min_t(u64, sched_slice(cfs_rq, curr), sysctl_sched_latency); + delta_exec = curr->sum_exec_runtime - curr->prev_sum_exec_runtime; if (delta_exec > ideal_runtime) { resched_curr(rq_of(cfs_rq)); diff --git a/kernel/sched/pelt.c b/kernel/sched/pelt.c index 0f310768260c..036b0e2cd2b4 100644 --- a/kernel/sched/pelt.c +++ b/kernel/sched/pelt.c @@ -467,3 +467,63 @@ int update_irq_load_avg(struct rq *rq, u64 running) return ret; } #endif + +__read_mostly unsigned int sched_pelt_lshift; + +#ifdef CONFIG_SYSCTL +static unsigned int sysctl_sched_pelt_multiplier = 1; + +int sched_pelt_multiplier(struct ctl_table *table, int write, void *buffer, + size_t *lenp, loff_t *ppos) +{ + static DEFINE_MUTEX(mutex); + unsigned int old; + int ret; + + mutex_lock(&mutex); + old = sysctl_sched_pelt_multiplier; + ret = proc_dointvec(table, write, buffer, lenp, ppos); + if (ret) + goto undo; + if (!write) + goto done; + + switch (sysctl_sched_pelt_multiplier) { + case 1: + fallthrough; + case 2: + fallthrough; + case 4: + WRITE_ONCE(sched_pelt_lshift, + sysctl_sched_pelt_multiplier >> 1); + goto done; + default: + ret = -EINVAL; + } + +undo: + sysctl_sched_pelt_multiplier = old; +done: + mutex_unlock(&mutex); + + return ret; +} + +static struct ctl_table sched_pelt_sysctls[] = { + { + .procname = "sched_pelt_multiplier", + .data = &sysctl_sched_pelt_multiplier, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = sched_pelt_multiplier, + }, + {} +}; + +static int __init sched_pelt_sysctl_init(void) +{ + register_sysctl_init("kernel", sched_pelt_sysctls); + return 0; +} +late_initcall(sched_pelt_sysctl_init); +#endif diff --git a/kernel/sched/pelt.h b/kernel/sched/pelt.h index 3a0e0dc28721..9b35b5072bae 100644 --- a/kernel/sched/pelt.h +++ b/kernel/sched/pelt.h @@ -61,6 +61,14 @@ static inline void cfs_se_util_change(struct sched_avg *avg) WRITE_ONCE(avg->util_est.enqueued, enqueued); } +static inline u64 rq_clock_task_mult(struct rq *rq) +{ + lockdep_assert_rq_held(rq); + assert_clock_updated(rq); + + return rq->clock_task_mult; +} + static inline u64 rq_clock_pelt(struct rq *rq) { lockdep_assert_rq_held(rq); @@ -72,7 +80,7 @@ static inline u64 rq_clock_pelt(struct rq *rq) /* The rq is idle, we can sync to clock_task */ static inline void _update_idle_rq_clock_pelt(struct rq *rq) { - rq->clock_pelt = rq_clock_task(rq); + rq->clock_pelt = rq_clock_task_mult(rq); u64_u32_store(rq->clock_idle, rq_clock(rq)); /* Paired with smp_rmb in migrate_se_pelt_lag() */ @@ -121,6 +129,27 @@ static inline void update_rq_clock_pelt(struct rq *rq, s64 delta) rq->clock_pelt += delta; } +extern unsigned int sched_pelt_lshift; + +/* + * absolute time |1 |2 |3 |4 |5 |6 | + * @ mult = 1 --------****************--------****************- + * @ mult = 2 --------********----------------********--------- + * @ mult = 4 --------****--------------------****------------- + * clock task mult + * @ mult = 2 | | |2 |3 | | | | |5 |6 | | | + * @ mult = 4 | | | | |2|3| | | | | | | | | | |5|6| | | | | | | + * + */ +static inline void update_rq_clock_task_mult(struct rq *rq, s64 delta) +{ + delta <<= READ_ONCE(sched_pelt_lshift); + + rq->clock_task_mult += delta; + + update_rq_clock_pelt(rq, delta); +} + /* * When rq becomes idle, we have to check if it has lost idle time * because it was fully busy. A rq is fully used when the /Sum util_sum @@ -147,7 +176,7 @@ static inline void update_idle_rq_clock_pelt(struct rq *rq) * rq's clock_task. */ if (util_sum >= divider) - rq->lost_idle_time += rq_clock_task(rq) - rq->clock_pelt; + rq->lost_idle_time += rq_clock_task_mult(rq) - rq->clock_pelt; _update_idle_rq_clock_pelt(rq); } @@ -218,13 +247,18 @@ update_irq_load_avg(struct rq *rq, u64 running) return 0; } -static inline u64 rq_clock_pelt(struct rq *rq) +static inline u64 rq_clock_task_mult(struct rq *rq) { return rq_clock_task(rq); } +static inline u64 rq_clock_pelt(struct rq *rq) +{ + return rq_clock_task_mult(rq); +} + static inline void -update_rq_clock_pelt(struct rq *rq, s64 delta) { } +update_rq_clock_task_mult(struct rq *rq, s64 delta) { } static inline void update_idle_rq_clock_pelt(struct rq *rq) { } diff --git a/kernel/sched/sched.h b/kernel/sched/sched.h index f34b489636ff..8485f5407eb8 100644 --- a/kernel/sched/sched.h +++ b/kernel/sched/sched.h @@ -1024,6 +1024,7 @@ struct rq { u64 clock; /* Ensure that all clocks are in the same cache line */ u64 clock_task ____cacheline_aligned; + u64 clock_task_mult; u64 clock_pelt; unsigned long lost_idle_time; u64 clock_pelt_idle; diff --git a/kernel/sysctl.c b/kernel/sysctl.c index e9a3094c52e5..9623c2fc774b 100644 --- a/kernel/sysctl.c +++ b/kernel/sysctl.c @@ -89,6 +89,9 @@ #ifdef CONFIG_PERF_EVENTS static const int six_hundred_forty_kb = 640 * 1024; #endif +#ifdef CONFIG_USER_NS +#include +#endif static const int ngroups_max = NGROUPS_MAX; @@ -1649,6 +1652,15 @@ static struct ctl_table kern_table[] = { .mode = 0644, .proc_handler = proc_dointvec, }, +#ifdef CONFIG_USER_NS + { + .procname = "unprivileged_userns_clone", + .data = &unprivileged_userns_clone, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#endif #ifdef CONFIG_PROC_SYSCTL { .procname = "tainted", diff --git a/kernel/user_namespace.c b/kernel/user_namespace.c index 5481ba44a8d6..423ab2563ad7 100644 --- a/kernel/user_namespace.c +++ b/kernel/user_namespace.c @@ -21,6 +21,13 @@ #include #include +/* sysctl */ +#ifdef CONFIG_USER_NS_UNPRIVILEGED +int unprivileged_userns_clone = 1; +#else +int unprivileged_userns_clone; +#endif + static struct kmem_cache *user_ns_cachep __read_mostly; static DEFINE_MUTEX(userns_state_mutex); diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index 88f34cdeef02..2c579d965266 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -312,8 +312,21 @@ config DEBUG_INFO_REDUCED DEBUG_INFO build and compile times are reduced too. Only works with newer gcc versions. -config DEBUG_INFO_COMPRESSED - bool "Compressed debugging information" +choice + prompt "Compressed Debug information" + help + Compress the resulting debug info. Results in smaller debug info sections, + but requires that consumers are able to decompress the results. + + If unsure, choose DEBUG_INFO_COMPRESSED_NONE. + +config DEBUG_INFO_COMPRESSED_NONE + bool "Don't compress debug information" + help + Don't compress debug info sections. + +config DEBUG_INFO_COMPRESSED_ZLIB + bool "Compress debugging information with zlib" depends on $(cc-option,-gz=zlib) depends on $(ld-option,--compress-debug-sections=zlib) help @@ -327,6 +340,18 @@ config DEBUG_INFO_COMPRESSED preferable to setting $KDEB_COMPRESS to "none" which would be even larger. +config DEBUG_INFO_COMPRESSED_ZSTD + bool "Compress debugging information with zstd" + depends on $(cc-option,-gz=zstd) + depends on $(ld-option,--compress-debug-sections=zstd) + help + Compress the debug information using zstd. This may provide better + compression than zlib, for about the same time costs, but requires newer + toolchain support. Requires GCC 13.0+ or Clang 16.0+, binutils 2.40+, and + zstd. + +endchoice # "Compressed Debug information" + config DEBUG_INFO_SPLIT bool "Produce split debuginfo in .dwo files" depends on $(cc-option,-gsplit-dwarf) diff --git a/lib/string.c b/lib/string.c index 6f334420f687..2e60e9e40535 100644 --- a/lib/string.c +++ b/lib/string.c @@ -866,24 +866,61 @@ char *strnstr(const char *s1, const char *s2, size_t len) EXPORT_SYMBOL(strnstr); #endif +#if defined(CONFIG_ARCH_HAS_FAST_MULTIPLIER) && BITS_PER_LONG == 64 + +#define MEMCHR_MASK_GEN(mask) (mask *= 0x0101010101010101ULL) + +#elif defined(CONFIG_ARCH_HAS_FAST_MULTIPLIER) + +#define MEMCHR_MASK_GEN(mask) \ + do { \ + mask *= 0x01010101; \ + mask |= mask << 32; \ + } while (0) + +#else + +#define MEMCHR_MASK_GEN(mask) \ + do { \ + mask |= mask << 8; \ + mask |= mask << 16; \ + mask |= mask << 32; \ + } while (0) + +#endif + #ifndef __HAVE_ARCH_MEMCHR /** * memchr - Find a character in an area of memory. - * @s: The memory area + * @p: The memory area * @c: The byte to search for - * @n: The size of the area. + * @length: The size of the area. * * returns the address of the first occurrence of @c, or %NULL * if @c is not found */ -void *memchr(const void *s, int c, size_t n) +void *memchr(const void *p, int c, unsigned long length) { - const unsigned char *p = s; - while (n-- != 0) { - if ((unsigned char)c == *p++) { - return (void *)(p - 1); + u64 mask, val; + const void *end = p + length; + + c &= 0xff; + if (p <= end - 8) { + mask = c; + MEMCHR_MASK_GEN(mask); + + for (; p <= end - 8; p += 8) { + val = *(u64 *)p ^ mask; + if ((val + 0xfefefefefefefeffu) & + (~val & 0x8080808080808080u)) + break; } } + + for (; p < end; p++) + if (*(unsigned char *)p == c) + return (void *)p; + return NULL; } EXPORT_SYMBOL(memchr); @@ -919,16 +956,7 @@ void *memchr_inv(const void *start, int c, size_t bytes) return check_bytes8(start, value, bytes); value64 = value; -#if defined(CONFIG_ARCH_HAS_FAST_MULTIPLIER) && BITS_PER_LONG == 64 - value64 *= 0x0101010101010101ULL; -#elif defined(CONFIG_ARCH_HAS_FAST_MULTIPLIER) - value64 *= 0x01010101; - value64 |= value64 << 32; -#else - value64 |= value64 << 8; - value64 |= value64 << 16; - value64 |= value64 << 32; -#endif + MEMCHR_MASK_GEN(value64); prefix = (unsigned long)start % 8; if (prefix) { diff --git a/mm/Kconfig b/mm/Kconfig index 0331f1461f81..d5373a23902c 100644 --- a/mm/Kconfig +++ b/mm/Kconfig @@ -149,6 +149,12 @@ config ZSWAP_ZPOOL_DEFAULT_ZSMALLOC select ZSMALLOC help Use the zsmalloc allocator as the default allocator. + +config ZSWAP_ZPOOL_DEFAULT_ZBLOCK + bool "zblock" + select ZBLOCK + help + Use the zblock allocator as the default allocator. endchoice config ZSWAP_ZPOOL_DEFAULT @@ -157,6 +163,7 @@ config ZSWAP_ZPOOL_DEFAULT default "zbud" if ZSWAP_ZPOOL_DEFAULT_ZBUD default "z3fold" if ZSWAP_ZPOOL_DEFAULT_Z3FOLD default "zsmalloc" if ZSWAP_ZPOOL_DEFAULT_ZSMALLOC + default "zblock" if ZSWAP_ZPOOL_DEFAULT_ZBLOCK default "" config ZBUD @@ -187,6 +194,16 @@ config ZSMALLOC pages of various compression levels efficiently. It achieves the highest storage density with the least amount of fragmentation. +config ZBLOCK + tristate "Simple block allocator (zblock)" + depends on ZPOOL + help + A special purpose allocator for storing compressed pages. + It stores integer number of compressed pages per block and + each block consists of number of physical pages being a power of 2. + zblock provides fast read/write, limited worst case time for + operations and good compression ratio in most scenarios. + config ZSMALLOC_STAT bool "Export zsmalloc statistics" depends on ZSMALLOC diff --git a/mm/Makefile b/mm/Makefile index 9a564f836403..eb7235da6e61 100644 --- a/mm/Makefile +++ b/mm/Makefile @@ -110,6 +110,7 @@ obj-$(CONFIG_ZPOOL) += zpool.o obj-$(CONFIG_ZBUD) += zbud.o obj-$(CONFIG_ZSMALLOC) += zsmalloc.o obj-$(CONFIG_Z3FOLD) += z3fold.o +obj-$(CONFIG_ZBLOCK) += zblock.o obj-$(CONFIG_GENERIC_EARLY_IOREMAP) += early_ioremap.o obj-$(CONFIG_CMA) += cma.o obj-$(CONFIG_MEMORY_BALLOON) += balloon_compaction.o diff --git a/mm/compaction.c b/mm/compaction.c index 88fea74c3a86..9745eb6cfb57 100644 --- a/mm/compaction.c +++ b/mm/compaction.c @@ -1727,7 +1727,7 @@ typedef enum { * Allow userspace to control policy on scanning the unevictable LRU for * compactable pages. */ -#ifdef CONFIG_PREEMPT_RT +#if defined(CONFIG_PREEMPT_RT) || defined(CONFIG_CACHY) int sysctl_compact_unevictable_allowed __read_mostly = 0; #else int sysctl_compact_unevictable_allowed __read_mostly = 1; @@ -2719,7 +2719,11 @@ static void compact_nodes(void) * aggressively the kernel should compact memory in the * background. It takes values in the range [0, 100]. */ +#ifdef CONFIG_CACHY +unsigned int __read_mostly sysctl_compaction_proactiveness; +#else unsigned int __read_mostly sysctl_compaction_proactiveness = 20; +#endif int compaction_proactiveness_sysctl_handler(struct ctl_table *table, int write, void *buffer, size_t *length, loff_t *ppos) diff --git a/mm/page-writeback.c b/mm/page-writeback.c index 032a7bf8d259..cbcecd565c16 100644 --- a/mm/page-writeback.c +++ b/mm/page-writeback.c @@ -70,7 +70,11 @@ static long ratelimit_pages = 32; /* * Start background writeback (via writeback threads) at this percentage */ +#ifdef CONFIG_CACHY +static int dirty_background_ratio = 5; +#else static int dirty_background_ratio = 10; +#endif /* * dirty_background_bytes starts at 0 (disabled) so that it is a function of @@ -98,7 +102,11 @@ static unsigned long vm_dirty_bytes; /* * The interval between `kupdate'-style writebacks */ +#ifdef CONFIG_CACHY +unsigned int dirty_writeback_interval = 10 * 100; /* centiseconds */ +#else unsigned int dirty_writeback_interval = 5 * 100; /* centiseconds */ +#endif EXPORT_SYMBOL_GPL(dirty_writeback_interval); diff --git a/mm/swap.c b/mm/swap.c index 9cee7f6a3809..186b4e5dcecf 100644 --- a/mm/swap.c +++ b/mm/swap.c @@ -1070,6 +1070,10 @@ EXPORT_SYMBOL(pagevec_lookup_range_tag); */ void __init swap_setup(void) { +#ifdef CONFIG_CACHY + /* Only swap-in pages requested, avoid readahead */ + page_cluster = 0; +#else unsigned long megs = totalram_pages() >> (20 - PAGE_SHIFT); /* Use a smaller cluster for small-memory machines */ @@ -1081,4 +1085,5 @@ void __init swap_setup(void) * Right now other parts of the system means that we * _really_ don't want to cluster much more */ +#endif } diff --git a/mm/vmpressure.c b/mm/vmpressure.c index b52644771cc4..11a4b0e3b583 100644 --- a/mm/vmpressure.c +++ b/mm/vmpressure.c @@ -43,7 +43,11 @@ static const unsigned long vmpressure_win = SWAP_CLUSTER_MAX * 16; * essence, they are percents: the higher the value, the more number * unsuccessful reclaims there were. */ +#ifdef CONFIG_CACHY +static const unsigned int vmpressure_level_med = 65; +#else static const unsigned int vmpressure_level_med = 60; +#endif static const unsigned int vmpressure_level_critical = 95; /* diff --git a/mm/vmscan.c b/mm/vmscan.c index a69f32dd5ce5..52699ada5be2 100644 --- a/mm/vmscan.c +++ b/mm/vmscan.c @@ -178,7 +178,11 @@ struct scan_control { /* * From 0 .. 200. Higher means more swappy. */ +#ifdef CONFIG_CACHY +int vm_swappiness = 20; +#else int vm_swappiness = 60; +#endif static void set_task_reclaim_state(struct task_struct *task, struct reclaim_state *rs) diff --git a/mm/zblock.c b/mm/zblock.c new file mode 100644 index 000000000000..4767a367fa71 --- /dev/null +++ b/mm/zblock.c @@ -0,0 +1,642 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * zblock.c + * + * Author: Ananda Badmaev + * Copyright (C) 2022, Konsulko AB. + * + * This implementation is based on z3fold written by Vitaly Wool. + * Zblock is a small object allocator with the intention to serve as a + * zpool backend. It operates on page blocks which consist of number + * of physical pages being a power of 2 and store integer number of + * compressed pages per block which results in determinism and simplicity. + * + * zblock doesn't export any API and is meant to be used via zpool API. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include + +#define SLOT_FREE 0 +#define SLOT_OCCUPIED 1 +#define SLOT_MAPPED 2 +#define SLOT_UNMAPPED 3 + +#define SLOT_BITS 5 +#define MAX_SLOTS (1 << SLOT_BITS) +#define SLOT_MASK ((0x1UL << SLOT_BITS) - 1) + +#define BLOCK_DATA_SIZE(order) ((PAGE_SIZE << order) - sizeof(struct zblock_block)) +#define SLOT_SIZE(nslots, order) (round_down((BLOCK_DATA_SIZE(order) / nslots), sizeof(long))) + +#define BLOCK_CACHE_SIZE 32 + +struct zblock_pool; + +struct zblock_ops { + int (*evict)(struct zblock_pool *pool, unsigned long handle); +}; + +/** + * struct zblock_block - block metadata + * Block consists of several (1/2/4/8) pages and contains fixed + * integer number of slots for allocating compressed pages. + * + * lock: protects block + * block_node: links block into the relevant list in the pool + * slot_info: contains data about free/occupied slots + * free_slots: number of free slots in the block + * under_reclaim: if true shows that block is being evicted + */ +struct zblock_block { + spinlock_t lock; + struct list_head block_node; + u8 slot_info[MAX_SLOTS]; + unsigned int free_slots; + bool under_reclaim; +}; +/** + * struct block_desc - general metadata for block lists + * Each block list stores only blocks of corresponding type which means + * that all blocks in it have the same number and size of slots. + * All slots are aligned to size of long. + * + * slot_size: size of slot for this list + * slots_per_block: number of slots per block for this list + * order: order for __get_free_pages + */ +static const struct block_desc { + const unsigned int slot_size; + const unsigned short slots_per_block; + const unsigned short order; +} block_desc[] = { + { SLOT_SIZE(32, 0), 32, 0 }, + { SLOT_SIZE(22, 0), 22, 0 }, + { SLOT_SIZE(17, 0), 17, 0 }, + { SLOT_SIZE(13, 0), 13, 0 }, + { SLOT_SIZE(11, 0), 11, 0 }, + { SLOT_SIZE(9, 0), 9, 0 }, + { SLOT_SIZE(8, 0), 8, 0 }, + { SLOT_SIZE(14, 1), 14, 1 }, + { SLOT_SIZE(12, 1), 12, 1 }, + { SLOT_SIZE(11, 1), 11, 1 }, + { SLOT_SIZE(10, 1), 10, 1 }, + { SLOT_SIZE(9, 1), 9, 1 }, + { SLOT_SIZE(8, 1), 8, 1 }, + { SLOT_SIZE(15, 2), 15, 2 }, + { SLOT_SIZE(14, 2), 14, 2 }, + { SLOT_SIZE(13, 2), 13, 2 }, + { SLOT_SIZE(12, 2), 12, 2 }, + { SLOT_SIZE(11, 2), 11, 2 }, + { SLOT_SIZE(10, 2), 10, 2 }, + { SLOT_SIZE(9, 2), 9, 2 }, + { SLOT_SIZE(8, 2), 8, 2 }, + { SLOT_SIZE(15, 3), 15, 3 }, + { SLOT_SIZE(14, 3), 14, 3 }, + { SLOT_SIZE(13, 3), 13, 3 }, + { SLOT_SIZE(12, 3), 12, 3 }, + { SLOT_SIZE(11, 3), 11, 3 }, + { SLOT_SIZE(10, 3), 10, 3 }, + { SLOT_SIZE(9, 3), 9, 3 }, + { SLOT_SIZE(7, 3), 7, 3 } +}; + +/** + * struct block_list - stores metadata of particular list + * lock: protects list + * head: head of this list + * block_cache: blocks with free slots + * block_count: total number of blocks in the list + */ +struct block_list { + spinlock_t lock; + struct list_head head; + struct zblock_block *block_cache[BLOCK_CACHE_SIZE]; + unsigned long block_count; +}; + +/** + * struct zblock_pool - stores metadata for each zblock pool + * @block_lists: array of block lists + * @ops: pointer to a structure of user defined operations specified at + * pool creation time. + * @zpool: zpool driver + * @zpool_ops: zpool operations structure with an evict callback + * @alloc_flag: protects block allocation from memory leak + * + * This structure is allocated at pool creation time and maintains metadata + * for a particular zblock pool. + */ +struct zblock_pool { + struct block_list block_lists[ARRAY_SIZE(block_desc)]; + const struct zblock_ops *ops; + struct zpool *zpool; + const struct zpool_ops *zpool_ops; + atomic_t alloc_flag; +}; + +/***************** + * Helpers + *****************/ + +static void cache_insert_block(struct zblock_block *block, struct block_list *list) +{ + unsigned int i, min_free_slots, min_index; + + min_free_slots = MAX_SLOTS; + for (i = 0; i < BLOCK_CACHE_SIZE; i++) { + if (!list->block_cache[i] || !(list->block_cache[i])->free_slots) { + list->block_cache[i] = block; + return; + } + if ((list->block_cache[i])->free_slots < min_free_slots) { + min_free_slots = (list->block_cache[i])->free_slots; + min_index = i; + } + } + list->block_cache[min_index] = block; +} + +static struct zblock_block *cache_find_block(struct block_list *list) +{ + int i; + + for (i = 0; i < BLOCK_CACHE_SIZE; i++) { + if (list->block_cache[i] && (list->block_cache[i])->free_slots) + return list->block_cache[i]; + } + return NULL; +} + +static int is_in_cache(struct zblock_block *block, struct block_list *list) +{ + int i; + + for (i = 0; i < BLOCK_CACHE_SIZE; i++) { + if (block == list->block_cache[i]) + return i; + } + return -1; +} + +/* + * allocate new block and add it to corresponding block list + */ +static struct zblock_block *alloc_block(struct zblock_pool *pool, + int block_type, gfp_t gfp) +{ + struct zblock_block *block; + struct block_list *list; + + block = (void *)__get_free_pages(gfp, block_desc[block_type].order); + if (!block) + return NULL; + + list = &(pool->block_lists)[block_type]; + + /* init block data */ + spin_lock_init(&block->lock); + memset(block->slot_info, SLOT_FREE, block_desc[block_type].slots_per_block); + block->free_slots = block_desc[block_type].slots_per_block; + block->under_reclaim = false; + + spin_lock(&list->lock); + /* inserting block into list */ + INIT_LIST_HEAD(&block->block_node); + list_add(&block->block_node, &list->head); + cache_insert_block(block, list); + list->block_count++; + spin_unlock(&list->lock); + return block; +} + +/* + * Encodes the handle of a particular slot in the pool using metadata + */ +static inline unsigned long metadata_to_handle(struct zblock_block *block, + unsigned int block_type, unsigned int slot) +{ + return (unsigned long)(block) + (block_type << SLOT_BITS) + slot; +} + +/* Returns block, block type and slot in the pool corresponding to handle */ +static inline struct zblock_block *handle_to_metadata(unsigned long handle, + unsigned int *block_type, unsigned int *slot) +{ + *block_type = (handle & (PAGE_SIZE - 1)) >> SLOT_BITS; + *slot = handle & SLOT_MASK; + return (struct zblock_block *)(handle & PAGE_MASK); +} + + +/***************** + * API Functions + *****************/ +/** + * zblock_create_pool() - create a new zblock pool + * @gfp: gfp flags when allocating the zblock pool structure + * @ops: user-defined operations for the zblock pool + * + * Return: pointer to the new zblock pool or NULL if the metadata allocation + * failed. + */ +static struct zblock_pool *zblock_create_pool(gfp_t gfp, const struct zblock_ops *ops) +{ + struct zblock_pool *pool; + struct block_list *list; + int i, j, arr_sz; + + pool = kmalloc(sizeof(struct zblock_pool), gfp); + if (!pool) + return NULL; + + arr_sz = ARRAY_SIZE(block_desc); + if (block_desc[arr_sz - 1].slot_size < PAGE_SIZE) + return NULL; + + /* init each block list */ + for (i = 0; i < arr_sz; i++) { + list = &(pool->block_lists)[i]; + spin_lock_init(&list->lock); + INIT_LIST_HEAD(&list->head); + for (j = 0; j < BLOCK_CACHE_SIZE; j++) + list->block_cache[j] = NULL; + list->block_count = 0; + } + pool->ops = ops; + atomic_set(&pool->alloc_flag, 0); + return pool; +} + +/** + * zblock_destroy_pool() - destroys an existing zblock pool + * @pool: the zblock pool to be destroyed + * + */ +static void zblock_destroy_pool(struct zblock_pool *pool) +{ + kfree(pool); +} + + +/** + * zblock_alloc() - allocates a slot of appropriate size + * @pool: zblock pool from which to allocate + * @size: size in bytes of the desired allocation + * @gfp: gfp flags used if the pool needs to grow + * @handle: handle of the new allocation + * + * Return: 0 if success and handle is set, otherwise -EINVAL if the size or + * gfp arguments are invalid or -ENOMEM if the pool was unable to allocate + * a new slot. + */ +static int zblock_alloc(struct zblock_pool *pool, size_t size, gfp_t gfp, + unsigned long *handle) +{ + unsigned int block_type, slot; + struct zblock_block *block; + struct block_list *list; + + if (!size) + return -EINVAL; + + if (size > PAGE_SIZE) + return -ENOSPC; + + /* find basic block type with suitable slot size */ + for (block_type = 0; block_type < ARRAY_SIZE(block_desc); block_type++) { + if (size <= block_desc[block_type].slot_size) + break; + } + list = &(pool->block_lists[block_type]); + +check: + spin_lock(&list->lock); + /* check if there are free slots in cache */ + block = cache_find_block(list); + if (block) + goto found; + spin_unlock(&list->lock); + + /* not found block with free slots try to allocate new empty block */ + if (atomic_cmpxchg(&pool->alloc_flag, 0, 1)) + goto check; + block = alloc_block(pool, block_type, gfp & ~(__GFP_HIGHMEM | __GFP_MOVABLE)); + if (block) { + spin_lock(&list->lock); + goto found; + } + atomic_set(&pool->alloc_flag, 0); + return -ENOMEM; + +found: + spin_lock(&block->lock); + block->free_slots--; + spin_unlock(&list->lock); + /* find the first free slot in block */ + for (slot = 0; slot < block_desc[block_type].slots_per_block; slot++) { + if (block->slot_info[slot] == SLOT_FREE) + break; + } + block->slot_info[slot] = SLOT_OCCUPIED; + spin_unlock(&block->lock); + *handle = metadata_to_handle(block, block_type, slot); + atomic_set(&pool->alloc_flag, 0); + return 0; +} + +/** + * zblock_free() - frees the allocation associated with the given handle + * @pool: pool in which the allocation resided + * @handle: handle associated with the allocation returned by zblock_alloc() + * + */ +static void zblock_free(struct zblock_pool *pool, unsigned long handle) +{ + unsigned int slot, block_type; + struct zblock_block *block; + struct block_list *list; + int i; + + block = handle_to_metadata(handle, &block_type, &slot); + list = &(pool->block_lists[block_type]); + + if (block->under_reclaim) + return; + spin_lock(&list->lock); + i = is_in_cache(block, list); + block->free_slots++; + /* if all slots in block are empty delete whole block */ + if (block->free_slots == block_desc[block_type].slots_per_block) { + list_del(&block->block_node); + list->block_count--; + + /* if cached block to be deleted */ + if (i != -1) + list->block_cache[i] = NULL; + spin_unlock(&list->lock); + free_pages((unsigned long)block, block_desc[block_type].order); + return; + } + /* if block is not cached update cache */ + if (i == -1) + cache_insert_block(block, list); + + spin_lock(&block->lock); + spin_unlock(&list->lock); + block->slot_info[slot] = SLOT_FREE; + spin_unlock(&block->lock); +} + +/** + * zblock_reclaim_block() - evicts allocations from block and frees it + * @pool: pool from which a block will attempt to be evicted + * + * Returns: pages reclaimed count if block is successfully freed + * otherwise -EINVAL if there are no blocks to evict + */ +static int zblock_reclaim_block(struct zblock_pool *pool) +{ + struct zblock_block *block; + struct block_list *list; + unsigned long handle; + int ret, i, reclaimed, block_type, slot; + + /* start with list storing blocks with the worst compression and try + * to evict the first added (oldest) block in this list + */ + for (block_type = ARRAY_SIZE(block_desc) - 1; block_type >= 0; --block_type) { + list = &(pool->block_lists[block_type]); + spin_lock(&list->lock); + + /* find the oldest block in list */ + block = list_last_entry(&list->head, struct zblock_block, block_node); + + if (!block) { + spin_unlock(&list->lock); + continue; + } + i = is_in_cache(block, list); + /* skip iteration if this block is cached */ + if (i != -1) { + spin_unlock(&list->lock); + continue; + } + block->under_reclaim = true; + spin_unlock(&list->lock); + reclaimed = 0; + + /* try to evict all OCCUPIED and UNMAPPED slots in block */ + for (slot = 0; slot < block_desc[block_type].slots_per_block; ++slot) { + if (block->slot_info[slot] == SLOT_OCCUPIED || + block->slot_info[slot] == SLOT_UNMAPPED) { + handle = metadata_to_handle(block, block_type, slot); + ret = pool->ops->evict(pool, handle); + if (ret) + break; + + ++reclaimed; + spin_lock(&block->lock); + block->slot_info[slot] = SLOT_FREE; + spin_unlock(&block->lock); + block->free_slots++; + } + } + spin_lock(&list->lock); + /* some occupied slots remained - insert block */ + if (block->free_slots != block_desc[block_type].slots_per_block) { + block->under_reclaim = false; + cache_insert_block(block, list); + spin_unlock(&list->lock); + } else { + /* all slots are free - delete this block */ + list_del(&block->block_node); + list->block_count--; + spin_unlock(&list->lock); + free_pages((unsigned long)block, block_desc[block_type].order); + } + if (reclaimed != 0) + return reclaimed; + return -EAGAIN; + } + return -EINVAL; +} + + +/** + * zblock_map() - maps the allocation associated with the given handle + * @pool: pool in which the allocation resides + * @handle: handle associated with the allocation to be mapped + * + * + * Returns: a pointer to the mapped allocation + */ +static void *zblock_map(struct zblock_pool *pool, unsigned long handle) +{ + unsigned int block_type, slot; + struct zblock_block *block; + + block = handle_to_metadata(handle, &block_type, &slot); + spin_lock(&block->lock); + block->slot_info[slot] = SLOT_MAPPED; + spin_unlock(&block->lock); + return (void *)(block + 1) + slot * block_desc[block_type].slot_size; +} + +/** + * zblock_unmap() - unmaps the allocation associated with the given handle + * @pool: pool in which the allocation resides + * @handle: handle associated with the allocation to be unmapped + */ +static void zblock_unmap(struct zblock_pool *pool, unsigned long handle) +{ + unsigned int block_type, slot; + struct zblock_block *block; + + block = handle_to_metadata(handle, &block_type, &slot); + spin_lock(&block->lock); + block->slot_info[slot] = SLOT_UNMAPPED; + spin_unlock(&block->lock); +} + +/** + * zblock_get_pool_size() - gets the zblock pool size in bytes + * @pool: pool whose size is being queried + * + * Returns: size in bytes of the given pool. + */ +static u64 zblock_get_pool_size(struct zblock_pool *pool) +{ + u64 total_size; + int i; + + total_size = 0; + for (i = 0; i < ARRAY_SIZE(block_desc); i++) { + total_size += (pool->block_lists)[i].block_count + * (PAGE_SIZE << block_desc[i].order); + } + return total_size; +} + +/***************** + * zpool + ****************/ + +static int zblock_zpool_evict(struct zblock_pool *pool, unsigned long handle) +{ + if (pool->zpool && pool->zpool_ops && pool->zpool_ops->evict) + return pool->zpool_ops->evict(pool->zpool, handle); + else + return -ENOENT; +} + +static const struct zblock_ops zblock_zpool_ops = { + .evict = zblock_zpool_evict +}; + +static void *zblock_zpool_create(const char *name, gfp_t gfp, + const struct zpool_ops *zpool_ops, + struct zpool *zpool) +{ + struct zblock_pool *pool; + + pool = zblock_create_pool(gfp, &zblock_zpool_ops); + if (pool) { + pool->zpool = zpool; + pool->zpool_ops = zpool_ops; + } + return pool; +} + +static void zblock_zpool_destroy(void *pool) +{ + zblock_destroy_pool(pool); +} + +static int zblock_zpool_malloc(void *pool, size_t size, gfp_t gfp, + unsigned long *handle) +{ + return zblock_alloc(pool, size, gfp, handle); +} + +static void zblock_zpool_free(void *pool, unsigned long handle) +{ + zblock_free(pool, handle); +} + +static int zblock_zpool_shrink(void *pool, unsigned int pages, + unsigned int *reclaimed) +{ + unsigned int total = 0; + int ret = -EINVAL; + + while (total < pages) { + ret = zblock_reclaim_block(pool); + if (ret < 0) + break; + total += ret; + } + if (reclaimed) + *reclaimed = total; + + return ret; +} + +static void *zblock_zpool_map(void *pool, unsigned long handle, + enum zpool_mapmode mm) +{ + return zblock_map(pool, handle); +} + +static void zblock_zpool_unmap(void *pool, unsigned long handle) +{ + zblock_unmap(pool, handle); +} + +static u64 zblock_zpool_total_size(void *pool) +{ + return zblock_get_pool_size(pool); +} + +static struct zpool_driver zblock_zpool_driver = { + .type = "zblock", + .owner = THIS_MODULE, + .create = zblock_zpool_create, + .destroy = zblock_zpool_destroy, + .malloc = zblock_zpool_malloc, + .free = zblock_zpool_free, + .shrink = zblock_zpool_shrink, + .map = zblock_zpool_map, + .unmap = zblock_zpool_unmap, + .total_size = zblock_zpool_total_size, +}; + +MODULE_ALIAS("zpool-zblock"); + +static int __init init_zblock(void) +{ + pr_info("loaded\n"); + zpool_register_driver(&zblock_zpool_driver); + return 0; +} + +static void __exit exit_zblock(void) +{ + zpool_unregister_driver(&zblock_zpool_driver); + pr_info("unloaded\n"); +} + +module_init(init_zblock); +module_exit(exit_zblock); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Ananda Badmaeb "); +MODULE_DESCRIPTION("Block allocator for compressed pages"); diff --git a/scripts/Makefile.debug b/scripts/Makefile.debug index 8cf1cb22dd93..f831e836bd99 100644 --- a/scripts/Makefile.debug +++ b/scripts/Makefile.debug @@ -22,10 +22,14 @@ DEBUG_CFLAGS += -femit-struct-debug-baseonly endif endif -ifdef CONFIG_DEBUG_INFO_COMPRESSED +ifdef CONFIG_DEBUG_INFO_COMPRESSED_ZLIB DEBUG_CFLAGS += -gz=zlib KBUILD_AFLAGS += -gz=zlib KBUILD_LDFLAGS += --compress-debug-sections=zlib +else ifdef CONFIG_DEBUG_INFO_COMPRESSED_ZSTD +DEBUG_CFLAGS += -gz=zstd +KBUILD_AFLAGS += -gz=zstd +KBUILD_LDFLAGS += --compress-debug-sections=zstd endif KBUILD_CFLAGS += $(DEBUG_CFLAGS) diff --git a/scripts/Makefile.lib b/scripts/Makefile.lib index 3fb6a99e78c4..f62770a0a84f 100644 --- a/scripts/Makefile.lib +++ b/scripts/Makefile.lib @@ -504,14 +504,21 @@ quiet_cmd_xzmisc = XZMISC $@ # decompression is used, like initramfs decompression, zstd22 should likely not # be used because it would require zstd to allocate a 128 MB buffer. +ifdef CONFIG_ZSTD_COMP_VAL +zstd_comp_val := $(CONFIG_ZSTD_COMP_VAL) +ifeq ($(shell test $(zstd_comp_val) -gt 19; echo $$?),0) +zstd_comp_val += --ultra +endif +endif + quiet_cmd_zstd = ZSTD $@ - cmd_zstd = cat $(real-prereqs) | $(ZSTD) -19 > $@ + cmd_zstd = cat $(real-prereqs) | $(ZSTD) -T0 -19 > $@ quiet_cmd_zstd22 = ZSTD22 $@ - cmd_zstd22 = cat $(real-prereqs) | $(ZSTD) -22 --ultra > $@ + cmd_zstd22 = cat $(real-prereqs) | $(ZSTD) -T0 -22 --ultra > $@ quiet_cmd_zstd22_with_size = ZSTD22 $@ - cmd_zstd22_with_size = { cat $(real-prereqs) | $(ZSTD) -22 --ultra; $(size_append); } > $@ + cmd_zstd22_with_size = { cat $(real-prereqs) | $(ZSTD) -T0 -$(zstd_comp_val); $(size_append); } > $@ # ASM offsets # --------------------------------------------------------------------------- diff --git a/scripts/Makefile.modinst b/scripts/Makefile.modinst index a4c987c23750..132863cf3183 100644 --- a/scripts/Makefile.modinst +++ b/scripts/Makefile.modinst @@ -96,8 +96,13 @@ quiet_cmd_gzip = GZIP $@ cmd_gzip = $(KGZIP) -n -f $< quiet_cmd_xz = XZ $@ cmd_xz = $(XZ) --lzma2=dict=2MiB -f $< +ifdef CONFIG_MODULE_COMPRESS_ZSTD_ULTRA quiet_cmd_zstd = ZSTD $@ - cmd_zstd = $(ZSTD) -T0 --rm -f -q $< + cmd_zstd = $(ZSTD) -$(CONFIG_MODULE_COMPRESS_ZSTD_LEVEL_ULTRA) --ultra --zstd=wlog=21 -T0 --rm -f -q $< +else +quiet_cmd_zstd = ZSTD $@ + cmd_zstd = $(ZSTD) -$(CONFIG_MODULE_COMPRESS_ZSTD_LEVEL) --zstd=wlog=21 -T0 --rm -f -q $< +endif $(dst)/%.ko.gz: $(dst)/%.ko FORCE $(call cmd,gzip) -- 2.39.0.rc2 From 8d6a9d09ac08488d996a5d3a9ba1a3f18b60232f Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Mon, 5 Sep 2022 08:34:43 +0200 Subject: [PATCH 02/20] bbr2 Signed-off-by: Peter Jung --- Documentation/networking/ip-sysctl.rst | 58 + include/linux/tcp.h | 6 +- include/net/inet_connection_sock.h | 3 +- include/net/netns/ipv4.h | 5 + include/net/tcp.h | 58 +- include/uapi/linux/inet_diag.h | 33 + include/uapi/linux/snmp.h | 1 + net/ipv4/Kconfig | 22 + net/ipv4/Makefile | 3 +- net/ipv4/bpf_tcp_ca.c | 12 + net/ipv4/proc.c | 1 + net/ipv4/sysctl_net_ipv4.c | 43 + net/ipv4/tcp.c | 1 + net/ipv4/tcp_bbr.c | 38 +- net/ipv4/tcp_bbr2.c | 2692 ++++++++++++++++++++++++ net/ipv4/tcp_cong.c | 1 + net/ipv4/tcp_input.c | 27 +- net/ipv4/tcp_ipv4.c | 7 + net/ipv4/tcp_output.c | 26 +- net/ipv4/tcp_plb.c | 100 + net/ipv4/tcp_rate.c | 30 +- net/ipv4/tcp_timer.c | 1 + 22 files changed, 3133 insertions(+), 35 deletions(-) create mode 100644 net/ipv4/tcp_bbr2.c create mode 100644 net/ipv4/tcp_plb.c diff --git a/Documentation/networking/ip-sysctl.rst b/Documentation/networking/ip-sysctl.rst index a759872a2883..f2372f8f860b 100644 --- a/Documentation/networking/ip-sysctl.rst +++ b/Documentation/networking/ip-sysctl.rst @@ -1040,6 +1040,64 @@ tcp_challenge_ack_limit - INTEGER TCP stack implements per TCP socket limits anyway. Default: INT_MAX (unlimited) +tcp_plb_enabled - BOOLEAN + If set, TCP PLB (Protective Load Balancing) is enabled. PLB is + described in the following paper: + https://doi.org/10.1145/3544216.3544226. Based on PLB parameters, + upon sensing sustained congestion, TCP triggers a change in + flow label field for outgoing IPv6 packets. A change in flow label + field potentially changes the path of outgoing packets for switches + that use ECMP/WCMP for routing. + + Default: 0 + +tcp_plb_cong_thresh - INTEGER + Fraction of packets marked with congestion over a round (RTT) to + tag that round as congested. This is referred to as K in the PLB paper: + https://doi.org/10.1145/3544216.3544226. + + The 0-1 fraction range is mapped to 0-256 range to avoid floating + point operations. For example, 128 means that if at least 50% of + the packets in a round were marked as congested then the round + will be tagged as congested. + + Possible Values: 0 - 256 + + Default: 128 + +tcp_plb_idle_rehash_rounds - INTEGER + Number of consecutive congested rounds (RTT) seen after which + a rehash can be performed, given there are no packets in flight. + This is referred to as M in PLB paper: + https://doi.org/10.1145/3544216.3544226. + + Possible Values: 0 - 31 + + Default: 3 + +tcp_plb_rehash_rounds - INTEGER + Number of consecutive congested rounds (RTT) seen after which + a forced rehash can be performed. Be careful when setting this + parameter, as a small value increases the risk of retransmissions. + This is referred to as N in PLB paper: + https://doi.org/10.1145/3544216.3544226. + + Possible Values: 0 - 31 + + Default: 12 + +tcp_plb_suspend_rto_sec - INTEGER + Time, in seconds, to suspend PLB in event of an RTO. In order to avoid + having PLB repath onto a connectivity "black hole", after an RTO a TCP + connection suspends PLB repathing for a random duration between 1x and + 2x of this parameter. Randomness is added to avoid concurrent rehashing + of multiple TCP connections. This should be set corresponding to the + amount of time it takes to repair a failed link. + + Possible Values: 0 - 255 + + Default: 60 + UDP variables ============= diff --git a/include/linux/tcp.h b/include/linux/tcp.h index 4791fd801945..bbfc0a3537e9 100644 --- a/include/linux/tcp.h +++ b/include/linux/tcp.h @@ -255,7 +255,8 @@ struct tcp_sock { u8 compressed_ack; u8 dup_ack_counter:2, tlp_retrans:1, /* TLP is a retransmission */ - unused:5; + fast_ack_mode:2, /* which fast ack mode ? */ + unused:3; u32 chrono_start; /* Start time in jiffies of a TCP chrono */ u32 chrono_stat[3]; /* Time in jiffies for chrono_stat stats */ u8 chrono_type:2, /* current chronograph type */ @@ -443,6 +444,9 @@ struct tcp_sock { */ struct request_sock __rcu *fastopen_rsk; struct saved_syn *saved_syn; + +/* Rerouting information */ + u16 ecn_rehash; /* PLB triggered rehash attempts */ }; enum tsq_enum { diff --git a/include/net/inet_connection_sock.h b/include/net/inet_connection_sock.h index ee88f0f1350f..e3075b3f1ece 100644 --- a/include/net/inet_connection_sock.h +++ b/include/net/inet_connection_sock.h @@ -132,7 +132,8 @@ struct inet_connection_sock { u32 icsk_probes_tstamp; u32 icsk_user_timeout; - u64 icsk_ca_priv[104 / sizeof(u64)]; +/* XXX inflated by temporary internal debugging info */ + u64 icsk_ca_priv[224 / sizeof(u64)]; #define ICSK_CA_PRIV_SIZE sizeof_field(struct inet_connection_sock, icsk_ca_priv) }; diff --git a/include/net/netns/ipv4.h b/include/net/netns/ipv4.h index 6320a76cefdc..2e39e07ed41f 100644 --- a/include/net/netns/ipv4.h +++ b/include/net/netns/ipv4.h @@ -181,6 +181,11 @@ struct netns_ipv4 { unsigned long tfo_active_disable_stamp; u32 tcp_challenge_timestamp; u32 tcp_challenge_count; + u8 sysctl_tcp_plb_enabled; + int sysctl_tcp_plb_cong_thresh; + u8 sysctl_tcp_plb_idle_rehash_rounds; + u8 sysctl_tcp_plb_rehash_rounds; + u8 sysctl_tcp_plb_suspend_rto_sec; int sysctl_udp_wmem_min; int sysctl_udp_rmem_min; diff --git a/include/net/tcp.h b/include/net/tcp.h index 95c1d51393ac..498b3d133ec0 100644 --- a/include/net/tcp.h +++ b/include/net/tcp.h @@ -372,6 +372,7 @@ static inline void tcp_dec_quickack_mode(struct sock *sk, #define TCP_ECN_QUEUE_CWR 2 #define TCP_ECN_DEMAND_CWR 4 #define TCP_ECN_SEEN 8 +#define TCP_ECN_ECT_PERMANENT 16 enum tcp_tw_status { TCP_TW_SUCCESS = 0, @@ -816,6 +817,11 @@ static inline u32 tcp_stamp_us_delta(u64 t1, u64 t0) return max_t(s64, t1 - t0, 0); } +static inline u32 tcp_stamp32_us_delta(u32 t1, u32 t0) +{ + return max_t(s32, t1 - t0, 0); +} + static inline u32 tcp_skb_timestamp(const struct sk_buff *skb) { return tcp_ns_to_ts(skb->skb_mstamp_ns); @@ -891,9 +897,14 @@ struct tcp_skb_cb { /* pkts S/ACKed so far upon tx of skb, incl retrans: */ __u32 delivered; /* start of send pipeline phase */ - u64 first_tx_mstamp; + u32 first_tx_mstamp; /* when we reached the "delivered" count */ - u64 delivered_mstamp; + u32 delivered_mstamp; +#define TCPCB_IN_FLIGHT_BITS 20 +#define TCPCB_IN_FLIGHT_MAX ((1U << TCPCB_IN_FLIGHT_BITS) - 1) + u32 in_flight:20, /* packets in flight at transmit */ + unused2:12; + u32 lost; /* packets lost so far upon tx of skb */ } tx; /* only used for outgoing skbs */ union { struct inet_skb_parm h4; @@ -1019,7 +1030,11 @@ enum tcp_ca_ack_event_flags { #define TCP_CONG_NON_RESTRICTED 0x1 /* Requires ECN/ECT set on all packets */ #define TCP_CONG_NEEDS_ECN 0x2 -#define TCP_CONG_MASK (TCP_CONG_NON_RESTRICTED | TCP_CONG_NEEDS_ECN) +/* Wants notification of CE events (CA_EVENT_ECN_IS_CE, CA_EVENT_ECN_NO_CE). */ +#define TCP_CONG_WANTS_CE_EVENTS 0x4 +#define TCP_CONG_MASK (TCP_CONG_NON_RESTRICTED | \ + TCP_CONG_NEEDS_ECN | \ + TCP_CONG_WANTS_CE_EVENTS) union tcp_cc_info; @@ -1039,8 +1054,11 @@ struct ack_sample { */ struct rate_sample { u64 prior_mstamp; /* starting timestamp for interval */ + u32 prior_lost; /* tp->lost at "prior_mstamp" */ u32 prior_delivered; /* tp->delivered at "prior_mstamp" */ u32 prior_delivered_ce;/* tp->delivered_ce at "prior_mstamp" */ + u32 tx_in_flight; /* packets in flight at starting timestamp */ + s32 lost; /* number of packets lost over interval */ s32 delivered; /* number of packets delivered over interval */ s32 delivered_ce; /* number of packets delivered w/ CE marks*/ long interval_us; /* time for tp->delivered to incr "delivered" */ @@ -1054,6 +1072,7 @@ struct rate_sample { bool is_app_limited; /* is sample from packet with bubble in pipe? */ bool is_retrans; /* is sample from retransmission? */ bool is_ack_delayed; /* is this (likely) a delayed ACK? */ + bool is_ece; /* did this ACK have ECN marked? */ }; struct tcp_congestion_ops { @@ -1077,8 +1096,11 @@ struct tcp_congestion_ops { /* hook for packet ack accounting (optional) */ void (*pkts_acked)(struct sock *sk, const struct ack_sample *sample); - /* override sysctl_tcp_min_tso_segs */ - u32 (*min_tso_segs)(struct sock *sk); + /* pick target number of segments per TSO/GSO skb (optional): */ + u32 (*tso_segs)(struct sock *sk, unsigned int mss_now); + + /* react to a specific lost skb (optional) */ + void (*skb_marked_lost)(struct sock *sk, const struct sk_buff *skb); /* call when packets are delivered to update cwnd and pacing rate, * after all the ca_state processing. (optional) @@ -1141,6 +1163,14 @@ static inline char *tcp_ca_get_name_by_key(u32 key, char *buffer) } #endif +static inline bool tcp_ca_wants_ce_events(const struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + + return icsk->icsk_ca_ops->flags & (TCP_CONG_NEEDS_ECN | + TCP_CONG_WANTS_CE_EVENTS); +} + static inline bool tcp_ca_needs_ecn(const struct sock *sk) { const struct inet_connection_sock *icsk = inet_csk(sk); @@ -1160,6 +1190,7 @@ static inline void tcp_ca_event(struct sock *sk, const enum tcp_ca_event event) void tcp_set_ca_state(struct sock *sk, const u8 ca_state); /* From tcp_rate.c */ +void tcp_set_tx_in_flight(struct sock *sk, struct sk_buff *skb); void tcp_rate_skb_sent(struct sock *sk, struct sk_buff *skb); void tcp_rate_skb_delivered(struct sock *sk, struct sk_buff *skb, struct rate_sample *rs); @@ -2133,6 +2164,23 @@ extern void tcp_rack_advance(struct tcp_sock *tp, u8 sacked, u32 end_seq, extern void tcp_rack_reo_timeout(struct sock *sk); extern void tcp_rack_update_reo_wnd(struct sock *sk, struct rate_sample *rs); +/* tcp_plb.c */ + +#define TCP_PLB_SCALE 8 /* scaling factor for fractions in PLB (e.g. ce_ratio) */ + +/* State for PLB (Protective Load Balancing) for a single TCP connection. */ +struct tcp_plb_state { + u8 consec_cong_rounds:5, /* consecutive congested rounds */ + enabled:1, /* Check if PLB is enabled */ + unused:2; + u32 pause_until; /* jiffies32 when PLB can resume repathing */ +}; + +void tcp_plb_update_state(const struct sock *sk, struct tcp_plb_state *plb, + const int cong_ratio); +void tcp_plb_check_rehash(struct sock *sk, struct tcp_plb_state *plb); +void tcp_plb_update_state_upon_rto(struct sock *sk, struct tcp_plb_state *plb); + /* At how many usecs into the future should the RTO fire? */ static inline s64 tcp_rto_delta_us(const struct sock *sk) { diff --git a/include/uapi/linux/inet_diag.h b/include/uapi/linux/inet_diag.h index 50655de04c9b..0e24f11627d5 100644 --- a/include/uapi/linux/inet_diag.h +++ b/include/uapi/linux/inet_diag.h @@ -231,9 +231,42 @@ struct tcp_bbr_info { __u32 bbr_cwnd_gain; /* cwnd gain shifted left 8 bits */ }; +/* Phase as reported in netlink/ss stats. */ +enum tcp_bbr2_phase { + BBR2_PHASE_INVALID = 0, + BBR2_PHASE_STARTUP = 1, + BBR2_PHASE_DRAIN = 2, + BBR2_PHASE_PROBE_RTT = 3, + BBR2_PHASE_PROBE_BW_UP = 4, + BBR2_PHASE_PROBE_BW_DOWN = 5, + BBR2_PHASE_PROBE_BW_CRUISE = 6, + BBR2_PHASE_PROBE_BW_REFILL = 7 +}; + +struct tcp_bbr2_info { + /* u64 bw: bandwidth (app throughput) estimate in Byte per sec: */ + __u32 bbr_bw_lsb; /* lower 32 bits of bw */ + __u32 bbr_bw_msb; /* upper 32 bits of bw */ + __u32 bbr_min_rtt; /* min-filtered RTT in uSec */ + __u32 bbr_pacing_gain; /* pacing gain shifted left 8 bits */ + __u32 bbr_cwnd_gain; /* cwnd gain shifted left 8 bits */ + __u32 bbr_bw_hi_lsb; /* lower 32 bits of bw_hi */ + __u32 bbr_bw_hi_msb; /* upper 32 bits of bw_hi */ + __u32 bbr_bw_lo_lsb; /* lower 32 bits of bw_lo */ + __u32 bbr_bw_lo_msb; /* upper 32 bits of bw_lo */ + __u8 bbr_mode; /* current bbr_mode in state machine */ + __u8 bbr_phase; /* current state machine phase */ + __u8 unused1; /* alignment padding; not used yet */ + __u8 bbr_version; /* MUST be at this offset in struct */ + __u32 bbr_inflight_lo; /* lower/short-term data volume bound */ + __u32 bbr_inflight_hi; /* higher/long-term data volume bound */ + __u32 bbr_extra_acked; /* max excess packets ACKed in epoch */ +}; + union tcp_cc_info { struct tcpvegas_info vegas; struct tcp_dctcp_info dctcp; struct tcp_bbr_info bbr; + struct tcp_bbr2_info bbr2; }; #endif /* _UAPI_INET_DIAG_H_ */ diff --git a/include/uapi/linux/snmp.h b/include/uapi/linux/snmp.h index 4d7470036a8b..8ce035f1c874 100644 --- a/include/uapi/linux/snmp.h +++ b/include/uapi/linux/snmp.h @@ -292,6 +292,7 @@ enum LINUX_MIB_TCPDSACKIGNOREDDUBIOUS, /* TCPDSACKIgnoredDubious */ LINUX_MIB_TCPMIGRATEREQSUCCESS, /* TCPMigrateReqSuccess */ LINUX_MIB_TCPMIGRATEREQFAILURE, /* TCPMigrateReqFailure */ + LINUX_MIB_TCPECNREHASH, /* TCPECNRehash */ __LINUX_MIB_MAX }; diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig index 2dfb12230f08..b6bec331a82e 100644 --- a/net/ipv4/Kconfig +++ b/net/ipv4/Kconfig @@ -678,6 +678,24 @@ config TCP_CONG_BBR AQM schemes that do not provide a delay signal. It requires the fq ("Fair Queue") pacing packet scheduler. +config TCP_CONG_BBR2 + tristate "BBR2 TCP" + default n + help + + BBR2 TCP congestion control is a model-based congestion control + algorithm that aims to maximize network utilization, keep queues and + retransmit rates low, and to be able to coexist with Reno/CUBIC in + common scenarios. It builds an explicit model of the network path. It + tolerates a targeted degree of random packet loss and delay that are + unrelated to congestion. It can operate over LAN, WAN, cellular, wifi, + or cable modem links, and can use DCTCP-L4S-style ECN signals. It can + coexist with flows that use loss-based congestion control, and can + operate with shallow buffers, deep buffers, bufferbloat, policers, or + AQM schemes that do not provide a delay signal. It requires pacing, + using either TCP internal pacing or the fq ("Fair Queue") pacing packet + scheduler. + choice prompt "Default TCP congestion control" default DEFAULT_CUBIC @@ -715,6 +733,9 @@ choice config DEFAULT_BBR bool "BBR" if TCP_CONG_BBR=y + config DEFAULT_BBR2 + bool "BBR2" if TCP_CONG_BBR2=y + config DEFAULT_RENO bool "Reno" endchoice @@ -739,6 +760,7 @@ config DEFAULT_TCP_CONG default "dctcp" if DEFAULT_DCTCP default "cdg" if DEFAULT_CDG default "bbr" if DEFAULT_BBR + default "bbr2" if DEFAULT_BBR2 default "cubic" config TCP_MD5SIG diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile index bbdd9c44f14e..e7a86a50838a 100644 --- a/net/ipv4/Makefile +++ b/net/ipv4/Makefile @@ -10,7 +10,7 @@ obj-y := route.o inetpeer.o protocol.o \ tcp.o tcp_input.o tcp_output.o tcp_timer.o tcp_ipv4.o \ tcp_minisocks.o tcp_cong.o tcp_metrics.o tcp_fastopen.o \ tcp_rate.o tcp_recovery.o tcp_ulp.o \ - tcp_offload.o datagram.o raw.o udp.o udplite.o \ + tcp_offload.o tcp_plb.o datagram.o raw.o udp.o udplite.o \ udp_offload.o arp.o icmp.o devinet.o af_inet.o igmp.o \ fib_frontend.o fib_semantics.o fib_trie.o fib_notifier.o \ inet_fragment.o ping.o ip_tunnel_core.o gre_offload.o \ @@ -46,6 +46,7 @@ obj-$(CONFIG_INET_TCP_DIAG) += tcp_diag.o obj-$(CONFIG_INET_UDP_DIAG) += udp_diag.o obj-$(CONFIG_INET_RAW_DIAG) += raw_diag.o obj-$(CONFIG_TCP_CONG_BBR) += tcp_bbr.o +obj-$(CONFIG_TCP_CONG_BBR2) += tcp_bbr2.o obj-$(CONFIG_TCP_CONG_BIC) += tcp_bic.o obj-$(CONFIG_TCP_CONG_CDG) += tcp_cdg.o obj-$(CONFIG_TCP_CONG_CUBIC) += tcp_cubic.o diff --git a/net/ipv4/bpf_tcp_ca.c b/net/ipv4/bpf_tcp_ca.c index 85a9e500c42d..24fcbac984fe 100644 --- a/net/ipv4/bpf_tcp_ca.c +++ b/net/ipv4/bpf_tcp_ca.c @@ -14,6 +14,18 @@ /* "extern" is to avoid sparse warning. It is only used in bpf_struct_ops.c. */ extern struct bpf_struct_ops bpf_tcp_congestion_ops; +static u32 optional_ops[] = { + offsetof(struct tcp_congestion_ops, init), + offsetof(struct tcp_congestion_ops, release), + offsetof(struct tcp_congestion_ops, set_state), + offsetof(struct tcp_congestion_ops, cwnd_event), + offsetof(struct tcp_congestion_ops, in_ack_event), + offsetof(struct tcp_congestion_ops, pkts_acked), + offsetof(struct tcp_congestion_ops, tso_segs), + offsetof(struct tcp_congestion_ops, sndbuf_expand), + offsetof(struct tcp_congestion_ops, cong_control), +}; + static u32 unsupported_ops[] = { offsetof(struct tcp_congestion_ops, get_info), }; diff --git a/net/ipv4/proc.c b/net/ipv4/proc.c index 0088a4c64d77..e0a664b467e0 100644 --- a/net/ipv4/proc.c +++ b/net/ipv4/proc.c @@ -297,6 +297,7 @@ static const struct snmp_mib snmp4_net_list[] = { SNMP_MIB_ITEM("TCPDSACKIgnoredDubious", LINUX_MIB_TCPDSACKIGNOREDDUBIOUS), SNMP_MIB_ITEM("TCPMigrateReqSuccess", LINUX_MIB_TCPMIGRATEREQSUCCESS), SNMP_MIB_ITEM("TCPMigrateReqFailure", LINUX_MIB_TCPMIGRATEREQFAILURE), + SNMP_MIB_ITEM("TCPECNRehash", LINUX_MIB_TCPECNREHASH), SNMP_MIB_SENTINEL }; diff --git a/net/ipv4/sysctl_net_ipv4.c b/net/ipv4/sysctl_net_ipv4.c index 5490c285668b..ed35a9485e71 100644 --- a/net/ipv4/sysctl_net_ipv4.c +++ b/net/ipv4/sysctl_net_ipv4.c @@ -39,6 +39,8 @@ static u32 u32_max_div_HZ = UINT_MAX / HZ; static int one_day_secs = 24 * 3600; static u32 fib_multipath_hash_fields_all_mask __maybe_unused = FIB_MULTIPATH_HASH_FIELD_ALL_MASK; +static int tcp_plb_max_rounds = 31; +static int tcp_plb_max_cong_thresh = 256; /* obsolete */ static int sysctl_tcp_low_latency __read_mostly; @@ -1346,6 +1348,47 @@ static struct ctl_table ipv4_net_table[] = { .extra1 = SYSCTL_ZERO, .extra2 = SYSCTL_TWO, }, + { + .procname = "tcp_plb_enabled", + .data = &init_net.ipv4.sysctl_tcp_plb_enabled, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "tcp_plb_cong_thresh", + .data = &init_net.ipv4.sysctl_tcp_plb_cong_thresh, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &tcp_plb_max_cong_thresh, + }, + { + .procname = "tcp_plb_idle_rehash_rounds", + .data = &init_net.ipv4.sysctl_tcp_plb_idle_rehash_rounds, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra2 = &tcp_plb_max_rounds, + }, + { + .procname = "tcp_plb_rehash_rounds", + .data = &init_net.ipv4.sysctl_tcp_plb_rehash_rounds, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra2 = &tcp_plb_max_rounds, + }, + { + .procname = "tcp_plb_suspend_rto_sec", + .data = &init_net.ipv4.sysctl_tcp_plb_suspend_rto_sec, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, { } }; diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index cdd4f2f60f0c..7f591bdcae9e 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -3190,6 +3190,7 @@ int tcp_disconnect(struct sock *sk, int flags) tp->rx_opt.dsack = 0; tp->rx_opt.num_sacks = 0; tp->rcv_ooopack = 0; + tp->fast_ack_mode = 0; /* Clean up fastopen related fields */ diff --git a/net/ipv4/tcp_bbr.c b/net/ipv4/tcp_bbr.c index 54eec33c6e1c..bfbf158c71f4 100644 --- a/net/ipv4/tcp_bbr.c +++ b/net/ipv4/tcp_bbr.c @@ -294,26 +294,40 @@ static void bbr_set_pacing_rate(struct sock *sk, u32 bw, int gain) sk->sk_pacing_rate = rate; } -/* override sysctl_tcp_min_tso_segs */ static u32 bbr_min_tso_segs(struct sock *sk) { return sk->sk_pacing_rate < (bbr_min_tso_rate >> 3) ? 1 : 2; } +/* Return the number of segments BBR would like in a TSO/GSO skb, given + * a particular max gso size as a constraint. + */ +static u32 bbr_tso_segs_generic(struct sock *sk, unsigned int mss_now, + u32 gso_max_size) +{ + u32 segs; + u64 bytes; + + /* Budget a TSO/GSO burst size allowance based on bw (pacing_rate). */ + bytes = sk->sk_pacing_rate >> sk->sk_pacing_shift; + + bytes = min_t(u32, bytes, gso_max_size - 1 - MAX_TCP_HEADER); + segs = max_t(u32, div_u64(bytes, mss_now), bbr_min_tso_segs(sk)); + return segs; +} + +/* Custom tcp_tso_autosize() for BBR, used at transmit time to cap skb size. */ +static u32 bbr_tso_segs(struct sock *sk, unsigned int mss_now) +{ + return bbr_tso_segs_generic(sk, mss_now, sk->sk_gso_max_size); +} + +/* Like bbr_tso_segs(), using mss_cache, ignoring driver's sk_gso_max_size. */ static u32 bbr_tso_segs_goal(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); - u32 segs, bytes; - - /* Sort of tcp_tso_autosize() but ignoring - * driver provided sk_gso_max_size. - */ - bytes = min_t(unsigned long, - sk->sk_pacing_rate >> READ_ONCE(sk->sk_pacing_shift), - GSO_LEGACY_MAX_SIZE - 1 - MAX_TCP_HEADER); - segs = max_t(u32, bytes / tp->mss_cache, bbr_min_tso_segs(sk)); - return min(segs, 0x7FU); + return bbr_tso_segs_generic(sk, tp->mss_cache, GSO_LEGACY_MAX_SIZE); } /* Save "last known good" cwnd so we can restore it after losses or PROBE_RTT */ @@ -1149,7 +1163,7 @@ static struct tcp_congestion_ops tcp_bbr_cong_ops __read_mostly = { .undo_cwnd = bbr_undo_cwnd, .cwnd_event = bbr_cwnd_event, .ssthresh = bbr_ssthresh, - .min_tso_segs = bbr_min_tso_segs, + .tso_segs = bbr_tso_segs, .get_info = bbr_get_info, .set_state = bbr_set_state, }; diff --git a/net/ipv4/tcp_bbr2.c b/net/ipv4/tcp_bbr2.c new file mode 100644 index 000000000000..2e39f7a353be --- /dev/null +++ b/net/ipv4/tcp_bbr2.c @@ -0,0 +1,2692 @@ +/* BBR (Bottleneck Bandwidth and RTT) congestion control, v2 + * + * BBRv2 is a model-based congestion control algorithm that aims for low + * queues, low loss, and (bounded) Reno/CUBIC coexistence. To maintain a model + * of the network path, it uses measurements of bandwidth and RTT, as well as + * (if they occur) packet loss and/or DCTCP/L4S-style ECN signals. Note that + * although it can use ECN or loss signals explicitly, it does not require + * either; it can bound its in-flight data based on its estimate of the BDP. + * + * The model has both higher and lower bounds for the operating range: + * lo: bw_lo, inflight_lo: conservative short-term lower bound + * hi: bw_hi, inflight_hi: robust long-term upper bound + * The bandwidth-probing time scale is (a) extended dynamically based on + * estimated BDP to improve coexistence with Reno/CUBIC; (b) bounded by + * an interactive wall-clock time-scale to be more scalable and responsive + * than Reno and CUBIC. + * + * Here is a state transition diagram for BBR: + * + * | + * V + * +---> STARTUP ----+ + * | | | + * | V | + * | DRAIN ----+ + * | | | + * | V | + * +---> PROBE_BW ----+ + * | ^ | | + * | | | | + * | +----+ | + * | | + * +---- PROBE_RTT <--+ + * + * A BBR flow starts in STARTUP, and ramps up its sending rate quickly. + * When it estimates the pipe is full, it enters DRAIN to drain the queue. + * In steady state a BBR flow only uses PROBE_BW and PROBE_RTT. + * A long-lived BBR flow spends the vast majority of its time remaining + * (repeatedly) in PROBE_BW, fully probing and utilizing the pipe's bandwidth + * in a fair manner, with a small, bounded queue. *If* a flow has been + * continuously sending for the entire min_rtt window, and hasn't seen an RTT + * sample that matches or decreases its min_rtt estimate for 10 seconds, then + * it briefly enters PROBE_RTT to cut inflight to a minimum value to re-probe + * the path's two-way propagation delay (min_rtt). When exiting PROBE_RTT, if + * we estimated that we reached the full bw of the pipe then we enter PROBE_BW; + * otherwise we enter STARTUP to try to fill the pipe. + * + * BBR is described in detail in: + * "BBR: Congestion-Based Congestion Control", + * Neal Cardwell, Yuchung Cheng, C. Stephen Gunn, Soheil Hassas Yeganeh, + * Van Jacobson. ACM Queue, Vol. 14 No. 5, September-October 2016. + * + * There is a public e-mail list for discussing BBR development and testing: + * https://groups.google.com/forum/#!forum/bbr-dev + * + * NOTE: BBR might be used with the fq qdisc ("man tc-fq") with pacing enabled, + * otherwise TCP stack falls back to an internal pacing using one high + * resolution timer per TCP socket and may use more resources. + */ +#include +#include +#include +#include +#include + +#include "tcp_dctcp.h" + +/* Scale factor for rate in pkt/uSec unit to avoid truncation in bandwidth + * estimation. The rate unit ~= (1500 bytes / 1 usec / 2^24) ~= 715 bps. + * This handles bandwidths from 0.06pps (715bps) to 256Mpps (3Tbps) in a u32. + * Since the minimum window is >=4 packets, the lower bound isn't + * an issue. The upper bound isn't an issue with existing technologies. + */ +#define BW_SCALE 24 +#define BW_UNIT (1 << BW_SCALE) + +#define BBR_SCALE 8 /* scaling factor for fractions in BBR (e.g. gains) */ +#define BBR_UNIT (1 << BBR_SCALE) + +#define FLAG_DEBUG_VERBOSE 0x1 /* Verbose debugging messages */ +#define FLAG_DEBUG_LOOPBACK 0x2 /* Do NOT skip loopback addr */ + +#define CYCLE_LEN 8 /* number of phases in a pacing gain cycle */ + +/* BBR has the following modes for deciding how fast to send: */ +enum bbr_mode { + BBR_STARTUP, /* ramp up sending rate rapidly to fill pipe */ + BBR_DRAIN, /* drain any queue created during startup */ + BBR_PROBE_BW, /* discover, share bw: pace around estimated bw */ + BBR_PROBE_RTT, /* cut inflight to min to probe min_rtt */ +}; + +/* How does the incoming ACK stream relate to our bandwidth probing? */ +enum bbr_ack_phase { + BBR_ACKS_INIT, /* not probing; not getting probe feedback */ + BBR_ACKS_REFILLING, /* sending at est. bw to fill pipe */ + BBR_ACKS_PROBE_STARTING, /* inflight rising to probe bw */ + BBR_ACKS_PROBE_FEEDBACK, /* getting feedback from bw probing */ + BBR_ACKS_PROBE_STOPPING, /* stopped probing; still getting feedback */ +}; + +/* BBR congestion control block */ +struct bbr { + u32 min_rtt_us; /* min RTT in min_rtt_win_sec window */ + u32 min_rtt_stamp; /* timestamp of min_rtt_us */ + u32 probe_rtt_done_stamp; /* end time for BBR_PROBE_RTT mode */ + u32 probe_rtt_min_us; /* min RTT in bbr_probe_rtt_win_ms window */ + u32 probe_rtt_min_stamp; /* timestamp of probe_rtt_min_us*/ + u32 next_rtt_delivered; /* scb->tx.delivered at end of round */ + u32 prior_rcv_nxt; /* tp->rcv_nxt when CE state last changed */ + u64 cycle_mstamp; /* time of this cycle phase start */ + u32 mode:3, /* current bbr_mode in state machine */ + prev_ca_state:3, /* CA state on previous ACK */ + packet_conservation:1, /* use packet conservation? */ + round_start:1, /* start of packet-timed tx->ack round? */ + ce_state:1, /* If most recent data has CE bit set */ + bw_probe_up_rounds:5, /* cwnd-limited rounds in PROBE_UP */ + try_fast_path:1, /* can we take fast path? */ + unused2:11, + idle_restart:1, /* restarting after idle? */ + probe_rtt_round_done:1, /* a BBR_PROBE_RTT round at 4 pkts? */ + cycle_idx:3, /* current index in pacing_gain cycle array */ + has_seen_rtt:1; /* have we seen an RTT sample yet? */ + u32 pacing_gain:11, /* current gain for setting pacing rate */ + cwnd_gain:11, /* current gain for setting cwnd */ + full_bw_reached:1, /* reached full bw in Startup? */ + full_bw_cnt:2, /* number of rounds without large bw gains */ + init_cwnd:7; /* initial cwnd */ + u32 prior_cwnd; /* prior cwnd upon entering loss recovery */ + u32 full_bw; /* recent bw, to estimate if pipe is full */ + + /* For tracking ACK aggregation: */ + u64 ack_epoch_mstamp; /* start of ACK sampling epoch */ + u16 extra_acked[2]; /* max excess data ACKed in epoch */ + u32 ack_epoch_acked:20, /* packets (S)ACKed in sampling epoch */ + extra_acked_win_rtts:5, /* age of extra_acked, in round trips */ + extra_acked_win_idx:1, /* current index in extra_acked array */ + /* BBR v2 state: */ + unused1:2, + startup_ecn_rounds:2, /* consecutive hi ECN STARTUP rounds */ + loss_in_cycle:1, /* packet loss in this cycle? */ + ecn_in_cycle:1; /* ECN in this cycle? */ + u32 loss_round_delivered; /* scb->tx.delivered ending loss round */ + u32 undo_bw_lo; /* bw_lo before latest losses */ + u32 undo_inflight_lo; /* inflight_lo before latest losses */ + u32 undo_inflight_hi; /* inflight_hi before latest losses */ + u32 bw_latest; /* max delivered bw in last round trip */ + u32 bw_lo; /* lower bound on sending bandwidth */ + u32 bw_hi[2]; /* upper bound of sending bandwidth range*/ + u32 inflight_latest; /* max delivered data in last round trip */ + u32 inflight_lo; /* lower bound of inflight data range */ + u32 inflight_hi; /* upper bound of inflight data range */ + u32 bw_probe_up_cnt; /* packets delivered per inflight_hi incr */ + u32 bw_probe_up_acks; /* packets (S)ACKed since inflight_hi incr */ + u32 probe_wait_us; /* PROBE_DOWN until next clock-driven probe */ + u32 ecn_eligible:1, /* sender can use ECN (RTT, handshake)? */ + ecn_alpha:9, /* EWMA delivered_ce/delivered; 0..256 */ + bw_probe_samples:1, /* rate samples reflect bw probing? */ + prev_probe_too_high:1, /* did last PROBE_UP go too high? */ + stopped_risky_probe:1, /* last PROBE_UP stopped due to risk? */ + rounds_since_probe:8, /* packet-timed rounds since probed bw */ + loss_round_start:1, /* loss_round_delivered round trip? */ + loss_in_round:1, /* loss marked in this round trip? */ + ecn_in_round:1, /* ECN marked in this round trip? */ + ack_phase:3, /* bbr_ack_phase: meaning of ACKs */ + loss_events_in_round:4,/* losses in STARTUP round */ + initialized:1; /* has bbr_init() been called? */ + u32 alpha_last_delivered; /* tp->delivered at alpha update */ + u32 alpha_last_delivered_ce; /* tp->delivered_ce at alpha update */ + struct tcp_plb_state plb; + + /* Params configurable using setsockopt. Refer to correspoding + * module param for detailed description of params. + */ + struct bbr_params { + u32 high_gain:11, /* max allowed value: 2047 */ + drain_gain:10, /* max allowed value: 1023 */ + cwnd_gain:11; /* max allowed value: 2047 */ + u32 cwnd_min_target:4, /* max allowed value: 15 */ + min_rtt_win_sec:5, /* max allowed value: 31 */ + probe_rtt_mode_ms:9, /* max allowed value: 511 */ + full_bw_cnt:3, /* max allowed value: 7 */ + cwnd_tso_budget:1, /* allowed values: {0, 1} */ + unused3:6, + drain_to_target:1, /* boolean */ + precise_ece_ack:1, /* boolean */ + extra_acked_in_startup:1, /* allowed values: {0, 1} */ + fast_path:1; /* boolean */ + u32 full_bw_thresh:10, /* max allowed value: 1023 */ + startup_cwnd_gain:11, /* max allowed value: 2047 */ + bw_probe_pif_gain:9, /* max allowed value: 511 */ + usage_based_cwnd:1, /* boolean */ + unused2:1; + u16 probe_rtt_win_ms:14, /* max allowed value: 16383 */ + refill_add_inc:2; /* max allowed value: 3 */ + u16 extra_acked_gain:11, /* max allowed value: 2047 */ + extra_acked_win_rtts:5; /* max allowed value: 31*/ + u16 pacing_gain[CYCLE_LEN]; /* max allowed value: 1023 */ + /* Mostly BBR v2 parameters below here: */ + u32 ecn_alpha_gain:8, /* max allowed value: 255 */ + ecn_factor:8, /* max allowed value: 255 */ + ecn_thresh:8, /* max allowed value: 255 */ + beta:8; /* max allowed value: 255 */ + u32 ecn_max_rtt_us:19, /* max allowed value: 524287 */ + bw_probe_reno_gain:9, /* max allowed value: 511 */ + full_loss_cnt:4; /* max allowed value: 15 */ + u32 probe_rtt_cwnd_gain:8, /* max allowed value: 255 */ + inflight_headroom:8, /* max allowed value: 255 */ + loss_thresh:8, /* max allowed value: 255 */ + bw_probe_max_rounds:8; /* max allowed value: 255 */ + u32 bw_probe_rand_rounds:4, /* max allowed value: 15 */ + bw_probe_base_us:26, /* usecs: 0..2^26-1 (67 secs) */ + full_ecn_cnt:2; /* max allowed value: 3 */ + u32 bw_probe_rand_us:26, /* usecs: 0..2^26-1 (67 secs) */ + undo:1, /* boolean */ + tso_rtt_shift:4, /* max allowed value: 15 */ + unused5:1; + u32 ecn_reprobe_gain:9, /* max allowed value: 511 */ + unused1:14, + ecn_alpha_init:9; /* max allowed value: 256 */ + } params; + + struct { + u32 snd_isn; /* Initial sequence number */ + u32 rs_bw; /* last valid rate sample bw */ + u32 target_cwnd; /* target cwnd, based on BDP */ + u8 undo:1, /* Undo even happened but not yet logged */ + unused:7; + char event; /* single-letter event debug codes */ + u16 unused2; + } debug; +}; + +struct bbr_context { + u32 sample_bw; + u32 target_cwnd; + u32 log:1; +}; + +/* Window length of min_rtt filter (in sec). Max allowed value is 31 (0x1F) */ +static u32 bbr_min_rtt_win_sec = 10; +/* Minimum time (in ms) spent at bbr_cwnd_min_target in BBR_PROBE_RTT mode. + * Max allowed value is 511 (0x1FF). + */ +static u32 bbr_probe_rtt_mode_ms = 200; +/* Window length of probe_rtt_min_us filter (in ms), and consequently the + * typical interval between PROBE_RTT mode entries. + * Note that bbr_probe_rtt_win_ms must be <= bbr_min_rtt_win_sec * MSEC_PER_SEC + */ +static u32 bbr_probe_rtt_win_ms = 5000; +/* Skip TSO below the following bandwidth (bits/sec): */ +static int bbr_min_tso_rate = 1200000; + +/* Use min_rtt to help adapt TSO burst size, with smaller min_rtt resulting + * in bigger TSO bursts. By default we cut the RTT-based allowance in half + * for every 2^9 usec (aka 512 us) of RTT, so that the RTT-based allowance + * is below 1500 bytes after 6 * ~500 usec = 3ms. + */ +static u32 bbr_tso_rtt_shift = 9; /* halve allowance per 2^9 usecs, 512us */ + +/* Select cwnd TSO budget approach: + * 0: padding + * 1: flooring + */ +static uint bbr_cwnd_tso_budget = 1; + +/* Pace at ~1% below estimated bw, on average, to reduce queue at bottleneck. + * In order to help drive the network toward lower queues and low latency while + * maintaining high utilization, the average pacing rate aims to be slightly + * lower than the estimated bandwidth. This is an important aspect of the + * design. + */ +static const int bbr_pacing_margin_percent = 1; + +/* We use a high_gain value of 2/ln(2) because it's the smallest pacing gain + * that will allow a smoothly increasing pacing rate that will double each RTT + * and send the same number of packets per RTT that an un-paced, slow-starting + * Reno or CUBIC flow would. Max allowed value is 2047 (0x7FF). + */ +static int bbr_high_gain = BBR_UNIT * 2885 / 1000 + 1; +/* The gain for deriving startup cwnd. Max allowed value is 2047 (0x7FF). */ +static int bbr_startup_cwnd_gain = BBR_UNIT * 2885 / 1000 + 1; +/* The pacing gain of 1/high_gain in BBR_DRAIN is calculated to typically drain + * the queue created in BBR_STARTUP in a single round. Max allowed value + * is 1023 (0x3FF). + */ +static int bbr_drain_gain = BBR_UNIT * 1000 / 2885; +/* The gain for deriving steady-state cwnd tolerates delayed/stretched ACKs. + * Max allowed value is 2047 (0x7FF). + */ +static int bbr_cwnd_gain = BBR_UNIT * 2; +/* The pacing_gain values for the PROBE_BW gain cycle, to discover/share bw. + * Max allowed value for each element is 1023 (0x3FF). + */ +enum bbr_pacing_gain_phase { + BBR_BW_PROBE_UP = 0, /* push up inflight to probe for bw/vol */ + BBR_BW_PROBE_DOWN = 1, /* drain excess inflight from the queue */ + BBR_BW_PROBE_CRUISE = 2, /* use pipe, w/ headroom in queue/pipe */ + BBR_BW_PROBE_REFILL = 3, /* v2: refill the pipe again to 100% */ +}; +static int bbr_pacing_gain[] = { + BBR_UNIT * 5 / 4, /* probe for more available bw */ + BBR_UNIT * 3 / 4, /* drain queue and/or yield bw to other flows */ + BBR_UNIT, BBR_UNIT, BBR_UNIT, /* cruise at 1.0*bw to utilize pipe, */ + BBR_UNIT, BBR_UNIT, BBR_UNIT /* without creating excess queue... */ +}; + +/* Try to keep at least this many packets in flight, if things go smoothly. For + * smooth functioning, a sliding window protocol ACKing every other packet + * needs at least 4 packets in flight. Max allowed value is 15 (0xF). + */ +static u32 bbr_cwnd_min_target = 4; + +/* Cwnd to BDP proportion in PROBE_RTT mode scaled by BBR_UNIT. Default: 50%. + * Use 0 to disable. Max allowed value is 255. + */ +static u32 bbr_probe_rtt_cwnd_gain = BBR_UNIT * 1 / 2; + +/* To estimate if BBR_STARTUP mode (i.e. high_gain) has filled pipe... */ +/* If bw has increased significantly (1.25x), there may be more bw available. + * Max allowed value is 1023 (0x3FF). + */ +static u32 bbr_full_bw_thresh = BBR_UNIT * 5 / 4; +/* But after 3 rounds w/o significant bw growth, estimate pipe is full. + * Max allowed value is 7 (0x7). + */ +static u32 bbr_full_bw_cnt = 3; + +static u32 bbr_flags; /* Debugging related stuff */ + +/* Whether to debug using printk. + */ +static bool bbr_debug_with_printk; + +/* Whether to debug using ftrace event tcp:tcp_bbr_event. + * Ignored when bbr_debug_with_printk is set. + */ +static bool bbr_debug_ftrace; + +/* Experiment: each cycle, try to hold sub-unity gain until inflight <= BDP. */ +static bool bbr_drain_to_target = true; /* default: enabled */ + +/* Experiment: Flags to control BBR with ECN behavior. + */ +static bool bbr_precise_ece_ack = true; /* default: enabled */ + +/* The max rwin scaling shift factor is 14 (RFC 1323), so the max sane rwin is + * (2^(16+14) B)/(1024 B/packet) = 1M packets. + */ +static u32 bbr_cwnd_warn_val = 1U << 20; + +static u16 bbr_debug_port_mask; + +/* BBR module parameters. These are module parameters only in Google prod. + * Upstream these are intentionally not module parameters. + */ +static int bbr_pacing_gain_size = CYCLE_LEN; + +/* Gain factor for adding extra_acked to target cwnd: */ +static int bbr_extra_acked_gain = 256; + +/* Window length of extra_acked window. Max allowed val is 31. */ +static u32 bbr_extra_acked_win_rtts = 5; + +/* Max allowed val for ack_epoch_acked, after which sampling epoch is reset */ +static u32 bbr_ack_epoch_acked_reset_thresh = 1U << 20; + +/* Time period for clamping cwnd increment due to ack aggregation */ +static u32 bbr_extra_acked_max_us = 100 * 1000; + +/* Use extra acked in startup ? + * 0: disabled + * 1: use latest extra_acked value from 1-2 rtt in startup + */ +static int bbr_extra_acked_in_startup = 1; /* default: enabled */ + +/* Experiment: don't grow cwnd beyond twice of what we just probed. */ +static bool bbr_usage_based_cwnd; /* default: disabled */ + +/* For lab testing, researchers can enable BBRv2 ECN support with this flag, + * when they know that any ECN marks that the connections experience will be + * DCTCP/L4S-style ECN marks, rather than RFC3168 ECN marks. + * TODO(ncardwell): Production use of the BBRv2 ECN functionality depends on + * negotiation or configuration that is outside the scope of the BBRv2 + * alpha release. + */ +static bool bbr_ecn_enable = false; + +module_param_named(min_tso_rate, bbr_min_tso_rate, int, 0644); +module_param_named(tso_rtt_shift, bbr_tso_rtt_shift, int, 0644); +module_param_named(high_gain, bbr_high_gain, int, 0644); +module_param_named(drain_gain, bbr_drain_gain, int, 0644); +module_param_named(startup_cwnd_gain, bbr_startup_cwnd_gain, int, 0644); +module_param_named(cwnd_gain, bbr_cwnd_gain, int, 0644); +module_param_array_named(pacing_gain, bbr_pacing_gain, int, + &bbr_pacing_gain_size, 0644); +module_param_named(cwnd_min_target, bbr_cwnd_min_target, uint, 0644); +module_param_named(probe_rtt_cwnd_gain, + bbr_probe_rtt_cwnd_gain, uint, 0664); +module_param_named(cwnd_warn_val, bbr_cwnd_warn_val, uint, 0664); +module_param_named(debug_port_mask, bbr_debug_port_mask, ushort, 0644); +module_param_named(flags, bbr_flags, uint, 0644); +module_param_named(debug_ftrace, bbr_debug_ftrace, bool, 0644); +module_param_named(debug_with_printk, bbr_debug_with_printk, bool, 0644); +module_param_named(min_rtt_win_sec, bbr_min_rtt_win_sec, uint, 0644); +module_param_named(probe_rtt_mode_ms, bbr_probe_rtt_mode_ms, uint, 0644); +module_param_named(probe_rtt_win_ms, bbr_probe_rtt_win_ms, uint, 0644); +module_param_named(full_bw_thresh, bbr_full_bw_thresh, uint, 0644); +module_param_named(full_bw_cnt, bbr_full_bw_cnt, uint, 0644); +module_param_named(cwnd_tso_bduget, bbr_cwnd_tso_budget, uint, 0664); +module_param_named(extra_acked_gain, bbr_extra_acked_gain, int, 0664); +module_param_named(extra_acked_win_rtts, + bbr_extra_acked_win_rtts, uint, 0664); +module_param_named(extra_acked_max_us, + bbr_extra_acked_max_us, uint, 0664); +module_param_named(ack_epoch_acked_reset_thresh, + bbr_ack_epoch_acked_reset_thresh, uint, 0664); +module_param_named(drain_to_target, bbr_drain_to_target, bool, 0664); +module_param_named(precise_ece_ack, bbr_precise_ece_ack, bool, 0664); +module_param_named(extra_acked_in_startup, + bbr_extra_acked_in_startup, int, 0664); +module_param_named(usage_based_cwnd, bbr_usage_based_cwnd, bool, 0664); +module_param_named(ecn_enable, bbr_ecn_enable, bool, 0664); + +static void bbr2_exit_probe_rtt(struct sock *sk); +static void bbr2_reset_congestion_signals(struct sock *sk); + +static void bbr_check_probe_rtt_done(struct sock *sk); + +/* Do we estimate that STARTUP filled the pipe? */ +static bool bbr_full_bw_reached(const struct sock *sk) +{ + const struct bbr *bbr = inet_csk_ca(sk); + + return bbr->full_bw_reached; +} + +/* Return the windowed max recent bandwidth sample, in pkts/uS << BW_SCALE. */ +static u32 bbr_max_bw(const struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + return max(bbr->bw_hi[0], bbr->bw_hi[1]); +} + +/* Return the estimated bandwidth of the path, in pkts/uS << BW_SCALE. */ +static u32 bbr_bw(const struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + return min(bbr_max_bw(sk), bbr->bw_lo); +} + +/* Return maximum extra acked in past k-2k round trips, + * where k = bbr_extra_acked_win_rtts. + */ +static u16 bbr_extra_acked(const struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + return max(bbr->extra_acked[0], bbr->extra_acked[1]); +} + +/* Return rate in bytes per second, optionally with a gain. + * The order here is chosen carefully to avoid overflow of u64. This should + * work for input rates of up to 2.9Tbit/sec and gain of 2.89x. + */ +static u64 bbr_rate_bytes_per_sec(struct sock *sk, u64 rate, int gain, + int margin) +{ + unsigned int mss = tcp_sk(sk)->mss_cache; + + rate *= mss; + rate *= gain; + rate >>= BBR_SCALE; + rate *= USEC_PER_SEC / 100 * (100 - margin); + rate >>= BW_SCALE; + rate = max(rate, 1ULL); + return rate; +} + +static u64 bbr_bw_bytes_per_sec(struct sock *sk, u64 rate) +{ + return bbr_rate_bytes_per_sec(sk, rate, BBR_UNIT, 0); +} + +static u64 bbr_rate_kbps(struct sock *sk, u64 rate) +{ + rate = bbr_bw_bytes_per_sec(sk, rate); + rate *= 8; + do_div(rate, 1000); + return rate; +} + +static u32 bbr_tso_segs_goal(struct sock *sk); +static void bbr_debug(struct sock *sk, u32 acked, + const struct rate_sample *rs, struct bbr_context *ctx) +{ + static const char ca_states[] = { + [TCP_CA_Open] = 'O', + [TCP_CA_Disorder] = 'D', + [TCP_CA_CWR] = 'C', + [TCP_CA_Recovery] = 'R', + [TCP_CA_Loss] = 'L', + }; + static const char mode[] = { + 'G', /* Growing - BBR_STARTUP */ + 'D', /* Drain - BBR_DRAIN */ + 'W', /* Window - BBR_PROBE_BW */ + 'M', /* Min RTT - BBR_PROBE_RTT */ + }; + static const char ack_phase[] = { /* bbr_ack_phase strings */ + 'I', /* BBR_ACKS_INIT - 'Init' */ + 'R', /* BBR_ACKS_REFILLING - 'Refilling' */ + 'B', /* BBR_ACKS_PROBE_STARTING - 'Before' */ + 'F', /* BBR_ACKS_PROBE_FEEDBACK - 'Feedback' */ + 'A', /* BBR_ACKS_PROBE_STOPPING - 'After' */ + }; + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + const u32 una = tp->snd_una - bbr->debug.snd_isn; + const u32 fack = tcp_highest_sack_seq(tp); + const u16 dport = ntohs(inet_sk(sk)->inet_dport); + bool is_port_match = (bbr_debug_port_mask && + ((dport & bbr_debug_port_mask) == 0)); + char debugmsg[320]; + + if (sk->sk_state == TCP_SYN_SENT) + return; /* no bbr_init() yet if SYN retransmit -> CA_Loss */ + + if (!tp->snd_cwnd || tp->snd_cwnd > bbr_cwnd_warn_val) { + char addr[INET6_ADDRSTRLEN + 10] = { 0 }; + + if (sk->sk_family == AF_INET) + snprintf(addr, sizeof(addr), "%pI4:%u", + &inet_sk(sk)->inet_daddr, dport); + else if (sk->sk_family == AF_INET6) + snprintf(addr, sizeof(addr), "%pI6:%u", + &sk->sk_v6_daddr, dport); + + WARN_ONCE(1, + "BBR %s cwnd alert: %u " + "snd_una: %u ca: %d pacing_gain: %u cwnd_gain: %u " + "bw: %u rtt: %u min_rtt: %u " + "acked: %u tso_segs: %u " + "bw: %d %ld %d pif: %u\n", + addr, tp->snd_cwnd, + una, inet_csk(sk)->icsk_ca_state, + bbr->pacing_gain, bbr->cwnd_gain, + bbr_max_bw(sk), (tp->srtt_us >> 3), bbr->min_rtt_us, + acked, bbr_tso_segs_goal(sk), + rs->delivered, rs->interval_us, rs->is_retrans, + tcp_packets_in_flight(tp)); + } + + if (likely(!bbr_debug_with_printk && !bbr_debug_ftrace)) + return; + + if (!sock_flag(sk, SOCK_DBG) && !is_port_match) + return; + + if (!ctx->log && !tp->app_limited && !(bbr_flags & FLAG_DEBUG_VERBOSE)) + return; + + if (ipv4_is_loopback(inet_sk(sk)->inet_daddr) && + !(bbr_flags & FLAG_DEBUG_LOOPBACK)) + return; + + snprintf(debugmsg, sizeof(debugmsg) - 1, + "BBR %pI4:%-5u %5u,%03u:%-7u %c " + "%c %2u br %2u cr %2d rtt %5ld d %2d i %5ld mrtt %d %cbw %llu " + "bw %llu lb %llu ib %llu qb %llu " + "a %u if %2u %c %c dl %u l %u al %u # %u t %u %c %c " + "lr %d er %d ea %d bwl %lld il %d ih %d c %d " + "v %d %c %u %c %s\n", + &inet_sk(sk)->inet_daddr, dport, + una / 1000, una % 1000, fack - tp->snd_una, + ca_states[inet_csk(sk)->icsk_ca_state], + bbr->debug.undo ? '@' : mode[bbr->mode], + tp->snd_cwnd, + bbr_extra_acked(sk), /* br (legacy): extra_acked */ + rs->tx_in_flight, /* cr (legacy): tx_inflight */ + rs->rtt_us, + rs->delivered, + rs->interval_us, + bbr->min_rtt_us, + rs->is_app_limited ? '_' : 'l', + bbr_rate_kbps(sk, ctx->sample_bw), /* lbw: latest sample bw */ + bbr_rate_kbps(sk, bbr_max_bw(sk)), /* bw: max bw */ + 0ULL, /* lb: [obsolete] */ + 0ULL, /* ib: [obsolete] */ + div_u64((u64)sk->sk_pacing_rate * 8, 1000), + acked, + tcp_packets_in_flight(tp), + rs->is_ack_delayed ? 'd' : '.', + bbr->round_start ? '*' : '.', + tp->delivered, tp->lost, + tp->app_limited, + 0, /* #: [obsolete] */ + ctx->target_cwnd, + tp->reord_seen ? 'r' : '.', /* r: reordering seen? */ + ca_states[bbr->prev_ca_state], + (rs->lost + rs->delivered) > 0 ? + (1000 * rs->lost / + (rs->lost + rs->delivered)) : 0, /* lr: loss rate x1000 */ + (rs->delivered) > 0 ? + (1000 * rs->delivered_ce / + (rs->delivered)) : 0, /* er: ECN rate x1000 */ + 1000 * bbr->ecn_alpha >> BBR_SCALE, /* ea: ECN alpha x1000 */ + bbr->bw_lo == ~0U ? + -1 : (s64)bbr_rate_kbps(sk, bbr->bw_lo), /* bwl */ + bbr->inflight_lo, /* il */ + bbr->inflight_hi, /* ih */ + bbr->bw_probe_up_cnt, /* c */ + 2, /* v: version */ + bbr->debug.event, + bbr->cycle_idx, + ack_phase[bbr->ack_phase], + bbr->bw_probe_samples ? "Y" : "N"); + debugmsg[sizeof(debugmsg) - 1] = 0; + + /* printk takes a higher precedence. */ + if (bbr_debug_with_printk) + printk(KERN_DEBUG "%s", debugmsg); + + if (unlikely(bbr->debug.undo)) + bbr->debug.undo = 0; +} + +/* Convert a BBR bw and gain factor to a pacing rate in bytes per second. */ +static unsigned long bbr_bw_to_pacing_rate(struct sock *sk, u32 bw, int gain) +{ + u64 rate = bw; + + rate = bbr_rate_bytes_per_sec(sk, rate, gain, + bbr_pacing_margin_percent); + rate = min_t(u64, rate, sk->sk_max_pacing_rate); + return rate; +} + +/* Initialize pacing rate to: high_gain * init_cwnd / RTT. */ +static void bbr_init_pacing_rate_from_rtt(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u64 bw; + u32 rtt_us; + + if (tp->srtt_us) { /* any RTT sample yet? */ + rtt_us = max(tp->srtt_us >> 3, 1U); + bbr->has_seen_rtt = 1; + } else { /* no RTT sample yet */ + rtt_us = USEC_PER_MSEC; /* use nominal default RTT */ + } + bw = (u64)tp->snd_cwnd * BW_UNIT; + do_div(bw, rtt_us); + sk->sk_pacing_rate = bbr_bw_to_pacing_rate(sk, bw, bbr->params.high_gain); +} + +/* Pace using current bw estimate and a gain factor. */ +static void bbr_set_pacing_rate(struct sock *sk, u32 bw, int gain) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + unsigned long rate = bbr_bw_to_pacing_rate(sk, bw, gain); + + if (unlikely(!bbr->has_seen_rtt && tp->srtt_us)) + bbr_init_pacing_rate_from_rtt(sk); + if (bbr_full_bw_reached(sk) || rate > sk->sk_pacing_rate) + sk->sk_pacing_rate = rate; +} + +static u32 bbr_min_tso_segs(struct sock *sk) +{ + return sk->sk_pacing_rate < (bbr_min_tso_rate >> 3) ? 1 : 2; +} + +/* Return the number of segments BBR would like in a TSO/GSO skb, given + * a particular max gso size as a constraint. + */ +static u32 bbr_tso_segs_generic(struct sock *sk, unsigned int mss_now, + u32 gso_max_size) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 segs, r; + u64 bytes; + + /* Budget a TSO/GSO burst size allowance based on bw (pacing_rate). */ + bytes = sk->sk_pacing_rate >> sk->sk_pacing_shift; + + /* Budget a TSO/GSO burst size allowance based on min_rtt. For every + * K = 2^tso_rtt_shift microseconds of min_rtt, halve the burst. + * The min_rtt-based burst allowance is: 64 KBytes / 2^(min_rtt/K) + */ + if (bbr->params.tso_rtt_shift) { + r = bbr->min_rtt_us >> bbr->params.tso_rtt_shift; + if (r < BITS_PER_TYPE(u32)) /* prevent undefined behavior */ + bytes += GSO_MAX_SIZE >> r; + } + + bytes = min_t(u32, bytes, gso_max_size - 1 - MAX_TCP_HEADER); + segs = max_t(u32, div_u64(bytes, mss_now), bbr_min_tso_segs(sk)); + return segs; +} + +/* Custom tcp_tso_autosize() for BBR, used at transmit time to cap skb size. */ +static u32 bbr_tso_segs(struct sock *sk, unsigned int mss_now) +{ + return bbr_tso_segs_generic(sk, mss_now, sk->sk_gso_max_size); +} + +/* Like bbr_tso_segs(), using mss_cache, ignoring driver's sk_gso_max_size. */ +static u32 bbr_tso_segs_goal(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + return bbr_tso_segs_generic(sk, tp->mss_cache, GSO_MAX_SIZE); +} + +/* Save "last known good" cwnd so we can restore it after losses or PROBE_RTT */ +static void bbr_save_cwnd(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr->prev_ca_state < TCP_CA_Recovery && bbr->mode != BBR_PROBE_RTT) + bbr->prior_cwnd = tp->snd_cwnd; /* this cwnd is good enough */ + else /* loss recovery or BBR_PROBE_RTT have temporarily cut cwnd */ + bbr->prior_cwnd = max(bbr->prior_cwnd, tp->snd_cwnd); +} + +static void bbr_cwnd_event(struct sock *sk, enum tcp_ca_event event) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + if (event == CA_EVENT_TX_START) { + tcp_plb_check_rehash(sk, &bbr->plb); + + if (!tp->app_limited) + return; + bbr->idle_restart = 1; + bbr->ack_epoch_mstamp = tp->tcp_mstamp; + bbr->ack_epoch_acked = 0; + /* Avoid pointless buffer overflows: pace at est. bw if we don't + * need more speed (we're restarting from idle and app-limited). + */ + if (bbr->mode == BBR_PROBE_BW) + bbr_set_pacing_rate(sk, bbr_bw(sk), BBR_UNIT); + else if (bbr->mode == BBR_PROBE_RTT) + bbr_check_probe_rtt_done(sk); + } else if ((event == CA_EVENT_ECN_IS_CE || + event == CA_EVENT_ECN_NO_CE) && + bbr_ecn_enable && + bbr->params.precise_ece_ack) { + u32 state = bbr->ce_state; + dctcp_ece_ack_update(sk, event, &bbr->prior_rcv_nxt, &state); + bbr->ce_state = state; + if (tp->fast_ack_mode == 2 && event == CA_EVENT_ECN_IS_CE) + tcp_enter_quickack_mode(sk, TCP_MAX_QUICKACKS); + } +} + +/* Calculate bdp based on min RTT and the estimated bottleneck bandwidth: + * + * bdp = ceil(bw * min_rtt * gain) + * + * The key factor, gain, controls the amount of queue. While a small gain + * builds a smaller queue, it becomes more vulnerable to noise in RTT + * measurements (e.g., delayed ACKs or other ACK compression effects). This + * noise may cause BBR to under-estimate the rate. + */ +static u32 bbr_bdp(struct sock *sk, u32 bw, int gain) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 bdp; + u64 w; + + /* If we've never had a valid RTT sample, cap cwnd at the initial + * default. This should only happen when the connection is not using TCP + * timestamps and has retransmitted all of the SYN/SYNACK/data packets + * ACKed so far. In this case, an RTO can cut cwnd to 1, in which + * case we need to slow-start up toward something safe: initial cwnd. + */ + if (unlikely(bbr->min_rtt_us == ~0U)) /* no valid RTT samples yet? */ + return bbr->init_cwnd; /* be safe: cap at initial cwnd */ + + w = (u64)bw * bbr->min_rtt_us; + + /* Apply a gain to the given value, remove the BW_SCALE shift, and + * round the value up to avoid a negative feedback loop. + */ + bdp = (((w * gain) >> BBR_SCALE) + BW_UNIT - 1) / BW_UNIT; + + return bdp; +} + +/* To achieve full performance in high-speed paths, we budget enough cwnd to + * fit full-sized skbs in-flight on both end hosts to fully utilize the path: + * - one skb in sending host Qdisc, + * - one skb in sending host TSO/GSO engine + * - one skb being received by receiver host LRO/GRO/delayed-ACK engine + * Don't worry, at low rates (bbr_min_tso_rate) this won't bloat cwnd because + * in such cases tso_segs_goal is 1. The minimum cwnd is 4 packets, + * which allows 2 outstanding 2-packet sequences, to try to keep pipe + * full even with ACK-every-other-packet delayed ACKs. + */ +static u32 bbr_quantization_budget(struct sock *sk, u32 cwnd) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 tso_segs_goal; + + tso_segs_goal = 3 * bbr_tso_segs_goal(sk); + + /* Allow enough full-sized skbs in flight to utilize end systems. */ + if (bbr->params.cwnd_tso_budget == 1) { + cwnd = max_t(u32, cwnd, tso_segs_goal); + cwnd = max_t(u32, cwnd, bbr->params.cwnd_min_target); + } else { + cwnd += tso_segs_goal; + cwnd = (cwnd + 1) & ~1U; + } + /* Ensure gain cycling gets inflight above BDP even for small BDPs. */ + if (bbr->mode == BBR_PROBE_BW && bbr->cycle_idx == BBR_BW_PROBE_UP) + cwnd += 2; + + return cwnd; +} + +/* Find inflight based on min RTT and the estimated bottleneck bandwidth. */ +static u32 bbr_inflight(struct sock *sk, u32 bw, int gain) +{ + u32 inflight; + + inflight = bbr_bdp(sk, bw, gain); + inflight = bbr_quantization_budget(sk, inflight); + + return inflight; +} + +/* With pacing at lower layers, there's often less data "in the network" than + * "in flight". With TSQ and departure time pacing at lower layers (e.g. fq), + * we often have several skbs queued in the pacing layer with a pre-scheduled + * earliest departure time (EDT). BBR adapts its pacing rate based on the + * inflight level that it estimates has already been "baked in" by previous + * departure time decisions. We calculate a rough estimate of the number of our + * packets that might be in the network at the earliest departure time for the + * next skb scheduled: + * in_network_at_edt = inflight_at_edt - (EDT - now) * bw + * If we're increasing inflight, then we want to know if the transmit of the + * EDT skb will push inflight above the target, so inflight_at_edt includes + * bbr_tso_segs_goal() from the skb departing at EDT. If decreasing inflight, + * then estimate if inflight will sink too low just before the EDT transmit. + */ +static u32 bbr_packets_in_net_at_edt(struct sock *sk, u32 inflight_now) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u64 now_ns, edt_ns, interval_us; + u32 interval_delivered, inflight_at_edt; + + now_ns = tp->tcp_clock_cache; + edt_ns = max(tp->tcp_wstamp_ns, now_ns); + interval_us = div_u64(edt_ns - now_ns, NSEC_PER_USEC); + interval_delivered = (u64)bbr_bw(sk) * interval_us >> BW_SCALE; + inflight_at_edt = inflight_now; + if (bbr->pacing_gain > BBR_UNIT) /* increasing inflight */ + inflight_at_edt += bbr_tso_segs_goal(sk); /* include EDT skb */ + if (interval_delivered >= inflight_at_edt) + return 0; + return inflight_at_edt - interval_delivered; +} + +/* Find the cwnd increment based on estimate of ack aggregation */ +static u32 bbr_ack_aggregation_cwnd(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 max_aggr_cwnd, aggr_cwnd = 0; + + if (bbr->params.extra_acked_gain && + (bbr_full_bw_reached(sk) || bbr->params.extra_acked_in_startup)) { + max_aggr_cwnd = ((u64)bbr_bw(sk) * bbr_extra_acked_max_us) + / BW_UNIT; + aggr_cwnd = (bbr->params.extra_acked_gain * bbr_extra_acked(sk)) + >> BBR_SCALE; + aggr_cwnd = min(aggr_cwnd, max_aggr_cwnd); + } + + return aggr_cwnd; +} + +/* Returns the cwnd for PROBE_RTT mode. */ +static u32 bbr_probe_rtt_cwnd(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr->params.probe_rtt_cwnd_gain == 0) + return bbr->params.cwnd_min_target; + return max_t(u32, bbr->params.cwnd_min_target, + bbr_bdp(sk, bbr_bw(sk), bbr->params.probe_rtt_cwnd_gain)); +} + +/* Slow-start up toward target cwnd (if bw estimate is growing, or packet loss + * has drawn us down below target), or snap down to target if we're above it. + */ +static void bbr_set_cwnd(struct sock *sk, const struct rate_sample *rs, + u32 acked, u32 bw, int gain, u32 cwnd, + struct bbr_context *ctx) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u32 target_cwnd = 0, prev_cwnd = tp->snd_cwnd, max_probe; + + if (!acked) + goto done; /* no packet fully ACKed; just apply caps */ + + target_cwnd = bbr_bdp(sk, bw, gain); + + /* Increment the cwnd to account for excess ACKed data that seems + * due to aggregation (of data and/or ACKs) visible in the ACK stream. + */ + target_cwnd += bbr_ack_aggregation_cwnd(sk); + target_cwnd = bbr_quantization_budget(sk, target_cwnd); + + /* If we're below target cwnd, slow start cwnd toward target cwnd. */ + bbr->debug.target_cwnd = target_cwnd; + + /* Update cwnd and enable fast path if cwnd reaches target_cwnd. */ + bbr->try_fast_path = 0; + if (bbr_full_bw_reached(sk)) { /* only cut cwnd if we filled the pipe */ + cwnd += acked; + if (cwnd >= target_cwnd) { + cwnd = target_cwnd; + bbr->try_fast_path = 1; + } + } else if (cwnd < target_cwnd || cwnd < 2 * bbr->init_cwnd) { + cwnd += acked; + } else { + bbr->try_fast_path = 1; + } + + /* When growing cwnd, don't grow beyond twice what we just probed. */ + if (bbr->params.usage_based_cwnd) { + max_probe = max(2 * tp->max_packets_out, tp->snd_cwnd); + cwnd = min(cwnd, max_probe); + } + + cwnd = max_t(u32, cwnd, bbr->params.cwnd_min_target); +done: + tp->snd_cwnd = min(cwnd, tp->snd_cwnd_clamp); /* apply global cap */ + if (bbr->mode == BBR_PROBE_RTT) /* drain queue, refresh min_rtt */ + tp->snd_cwnd = min_t(u32, tp->snd_cwnd, bbr_probe_rtt_cwnd(sk)); + + ctx->target_cwnd = target_cwnd; + ctx->log = (tp->snd_cwnd != prev_cwnd); +} + +/* See if we have reached next round trip */ +static void bbr_update_round_start(struct sock *sk, + const struct rate_sample *rs, struct bbr_context *ctx) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr->round_start = 0; + + /* See if we've reached the next RTT */ + if (rs->interval_us > 0 && + !before(rs->prior_delivered, bbr->next_rtt_delivered)) { + bbr->next_rtt_delivered = tp->delivered; + bbr->round_start = 1; + } +} + +/* Calculate the bandwidth based on how fast packets are delivered */ +static void bbr_calculate_bw_sample(struct sock *sk, + const struct rate_sample *rs, struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + u64 bw = 0; + + /* Divide delivered by the interval to find a (lower bound) bottleneck + * bandwidth sample. Delivered is in packets and interval_us in uS and + * ratio will be <<1 for most connections. So delivered is first scaled. + * Round up to allow growth at low rates, even with integer division. + */ + if (rs->interval_us > 0) { + if (WARN_ONCE(rs->delivered < 0, + "negative delivered: %d interval_us: %ld\n", + rs->delivered, rs->interval_us)) + return; + + bw = DIV_ROUND_UP_ULL((u64)rs->delivered * BW_UNIT, rs->interval_us); + } + + ctx->sample_bw = bw; + bbr->debug.rs_bw = bw; +} + +/* Estimates the windowed max degree of ack aggregation. + * This is used to provision extra in-flight data to keep sending during + * inter-ACK silences. + * + * Degree of ack aggregation is estimated as extra data acked beyond expected. + * + * max_extra_acked = "maximum recent excess data ACKed beyond max_bw * interval" + * cwnd += max_extra_acked + * + * Max extra_acked is clamped by cwnd and bw * bbr_extra_acked_max_us (100 ms). + * Max filter is an approximate sliding window of 5-10 (packet timed) round + * trips for non-startup phase, and 1-2 round trips for startup. + */ +static void bbr_update_ack_aggregation(struct sock *sk, + const struct rate_sample *rs) +{ + u32 epoch_us, expected_acked, extra_acked; + struct bbr *bbr = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); + u32 extra_acked_win_rtts_thresh = bbr->params.extra_acked_win_rtts; + + if (!bbr->params.extra_acked_gain || rs->acked_sacked <= 0 || + rs->delivered < 0 || rs->interval_us <= 0) + return; + + if (bbr->round_start) { + bbr->extra_acked_win_rtts = min(0x1F, + bbr->extra_acked_win_rtts + 1); + if (bbr->params.extra_acked_in_startup && + !bbr_full_bw_reached(sk)) + extra_acked_win_rtts_thresh = 1; + if (bbr->extra_acked_win_rtts >= + extra_acked_win_rtts_thresh) { + bbr->extra_acked_win_rtts = 0; + bbr->extra_acked_win_idx = bbr->extra_acked_win_idx ? + 0 : 1; + bbr->extra_acked[bbr->extra_acked_win_idx] = 0; + } + } + + /* Compute how many packets we expected to be delivered over epoch. */ + epoch_us = tcp_stamp_us_delta(tp->delivered_mstamp, + bbr->ack_epoch_mstamp); + expected_acked = ((u64)bbr_bw(sk) * epoch_us) / BW_UNIT; + + /* Reset the aggregation epoch if ACK rate is below expected rate or + * significantly large no. of ack received since epoch (potentially + * quite old epoch). + */ + if (bbr->ack_epoch_acked <= expected_acked || + (bbr->ack_epoch_acked + rs->acked_sacked >= + bbr_ack_epoch_acked_reset_thresh)) { + bbr->ack_epoch_acked = 0; + bbr->ack_epoch_mstamp = tp->delivered_mstamp; + expected_acked = 0; + } + + /* Compute excess data delivered, beyond what was expected. */ + bbr->ack_epoch_acked = min_t(u32, 0xFFFFF, + bbr->ack_epoch_acked + rs->acked_sacked); + extra_acked = bbr->ack_epoch_acked - expected_acked; + extra_acked = min(extra_acked, tp->snd_cwnd); + if (extra_acked > bbr->extra_acked[bbr->extra_acked_win_idx]) + bbr->extra_acked[bbr->extra_acked_win_idx] = extra_acked; +} + +/* Estimate when the pipe is full, using the change in delivery rate: BBR + * estimates that STARTUP filled the pipe if the estimated bw hasn't changed by + * at least bbr_full_bw_thresh (25%) after bbr_full_bw_cnt (3) non-app-limited + * rounds. Why 3 rounds: 1: rwin autotuning grows the rwin, 2: we fill the + * higher rwin, 3: we get higher delivery rate samples. Or transient + * cross-traffic or radio noise can go away. CUBIC Hystart shares a similar + * design goal, but uses delay and inter-ACK spacing instead of bandwidth. + */ +static void bbr_check_full_bw_reached(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 bw_thresh; + + if (bbr_full_bw_reached(sk) || !bbr->round_start || rs->is_app_limited) + return; + + bw_thresh = (u64)bbr->full_bw * bbr->params.full_bw_thresh >> BBR_SCALE; + if (bbr_max_bw(sk) >= bw_thresh) { + bbr->full_bw = bbr_max_bw(sk); + bbr->full_bw_cnt = 0; + return; + } + ++bbr->full_bw_cnt; + bbr->full_bw_reached = bbr->full_bw_cnt >= bbr->params.full_bw_cnt; +} + +/* If pipe is probably full, drain the queue and then enter steady-state. */ +static bool bbr_check_drain(struct sock *sk, const struct rate_sample *rs, + struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr->mode == BBR_STARTUP && bbr_full_bw_reached(sk)) { + bbr->mode = BBR_DRAIN; /* drain queue we created */ + tcp_sk(sk)->snd_ssthresh = + bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT); + bbr2_reset_congestion_signals(sk); + } /* fall through to check if in-flight is already small: */ + if (bbr->mode == BBR_DRAIN && + bbr_packets_in_net_at_edt(sk, tcp_packets_in_flight(tcp_sk(sk))) <= + bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT)) + return true; /* exiting DRAIN now */ + return false; +} + +static void bbr_check_probe_rtt_done(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + if (!(bbr->probe_rtt_done_stamp && + after(tcp_jiffies32, bbr->probe_rtt_done_stamp))) + return; + + bbr->probe_rtt_min_stamp = tcp_jiffies32; /* schedule next PROBE_RTT */ + tp->snd_cwnd = max(tp->snd_cwnd, bbr->prior_cwnd); + bbr2_exit_probe_rtt(sk); +} + +/* The goal of PROBE_RTT mode is to have BBR flows cooperatively and + * periodically drain the bottleneck queue, to converge to measure the true + * min_rtt (unloaded propagation delay). This allows the flows to keep queues + * small (reducing queuing delay and packet loss) and achieve fairness among + * BBR flows. + * + * The min_rtt filter window is 10 seconds. When the min_rtt estimate expires, + * we enter PROBE_RTT mode and cap the cwnd at bbr_cwnd_min_target=4 packets. + * After at least bbr_probe_rtt_mode_ms=200ms and at least one packet-timed + * round trip elapsed with that flight size <= 4, we leave PROBE_RTT mode and + * re-enter the previous mode. BBR uses 200ms to approximately bound the + * performance penalty of PROBE_RTT's cwnd capping to roughly 2% (200ms/10s). + * + * Note that flows need only pay 2% if they are busy sending over the last 10 + * seconds. Interactive applications (e.g., Web, RPCs, video chunks) often have + * natural silences or low-rate periods within 10 seconds where the rate is low + * enough for long enough to drain its queue in the bottleneck. We pick up + * these min RTT measurements opportunistically with our min_rtt filter. :-) + */ +static void bbr_update_min_rtt(struct sock *sk, const struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + bool probe_rtt_expired, min_rtt_expired; + u32 expire; + + /* Track min RTT in probe_rtt_win_ms to time next PROBE_RTT state. */ + expire = bbr->probe_rtt_min_stamp + + msecs_to_jiffies(bbr->params.probe_rtt_win_ms); + probe_rtt_expired = after(tcp_jiffies32, expire); + if (rs->rtt_us >= 0 && + (rs->rtt_us <= bbr->probe_rtt_min_us || + (probe_rtt_expired && !rs->is_ack_delayed))) { + bbr->probe_rtt_min_us = rs->rtt_us; + bbr->probe_rtt_min_stamp = tcp_jiffies32; + } + /* Track min RTT seen in the min_rtt_win_sec filter window: */ + expire = bbr->min_rtt_stamp + bbr->params.min_rtt_win_sec * HZ; + min_rtt_expired = after(tcp_jiffies32, expire); + if (bbr->probe_rtt_min_us <= bbr->min_rtt_us || + min_rtt_expired) { + bbr->min_rtt_us = bbr->probe_rtt_min_us; + bbr->min_rtt_stamp = bbr->probe_rtt_min_stamp; + } + + if (bbr->params.probe_rtt_mode_ms > 0 && probe_rtt_expired && + !bbr->idle_restart && bbr->mode != BBR_PROBE_RTT) { + bbr->mode = BBR_PROBE_RTT; /* dip, drain queue */ + bbr_save_cwnd(sk); /* note cwnd so we can restore it */ + bbr->probe_rtt_done_stamp = 0; + bbr->ack_phase = BBR_ACKS_PROBE_STOPPING; + bbr->next_rtt_delivered = tp->delivered; + } + + if (bbr->mode == BBR_PROBE_RTT) { + /* Ignore low rate samples during this mode. */ + tp->app_limited = + (tp->delivered + tcp_packets_in_flight(tp)) ? : 1; + /* Maintain min packets in flight for max(200 ms, 1 round). */ + if (!bbr->probe_rtt_done_stamp && + tcp_packets_in_flight(tp) <= bbr_probe_rtt_cwnd(sk)) { + bbr->probe_rtt_done_stamp = tcp_jiffies32 + + msecs_to_jiffies(bbr->params.probe_rtt_mode_ms); + bbr->probe_rtt_round_done = 0; + bbr->next_rtt_delivered = tp->delivered; + } else if (bbr->probe_rtt_done_stamp) { + if (bbr->round_start) + bbr->probe_rtt_round_done = 1; + if (bbr->probe_rtt_round_done) + bbr_check_probe_rtt_done(sk); + } + } + /* Restart after idle ends only once we process a new S/ACK for data */ + if (rs->delivered > 0) + bbr->idle_restart = 0; +} + +static void bbr_update_gains(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + switch (bbr->mode) { + case BBR_STARTUP: + bbr->pacing_gain = bbr->params.high_gain; + bbr->cwnd_gain = bbr->params.startup_cwnd_gain; + break; + case BBR_DRAIN: + bbr->pacing_gain = bbr->params.drain_gain; /* slow, to drain */ + bbr->cwnd_gain = bbr->params.startup_cwnd_gain; /* keep cwnd */ + break; + case BBR_PROBE_BW: + bbr->pacing_gain = bbr->params.pacing_gain[bbr->cycle_idx]; + bbr->cwnd_gain = bbr->params.cwnd_gain; + break; + case BBR_PROBE_RTT: + bbr->pacing_gain = BBR_UNIT; + bbr->cwnd_gain = BBR_UNIT; + break; + default: + WARN_ONCE(1, "BBR bad mode: %u\n", bbr->mode); + break; + } +} + +static void bbr_init(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + int i; + + WARN_ON_ONCE(tp->snd_cwnd >= bbr_cwnd_warn_val); + + bbr->initialized = 1; + bbr->params.high_gain = min(0x7FF, bbr_high_gain); + bbr->params.drain_gain = min(0x3FF, bbr_drain_gain); + bbr->params.startup_cwnd_gain = min(0x7FF, bbr_startup_cwnd_gain); + bbr->params.cwnd_gain = min(0x7FF, bbr_cwnd_gain); + bbr->params.cwnd_tso_budget = min(0x1U, bbr_cwnd_tso_budget); + bbr->params.cwnd_min_target = min(0xFU, bbr_cwnd_min_target); + bbr->params.min_rtt_win_sec = min(0x1FU, bbr_min_rtt_win_sec); + bbr->params.probe_rtt_mode_ms = min(0x1FFU, bbr_probe_rtt_mode_ms); + bbr->params.full_bw_cnt = min(0x7U, bbr_full_bw_cnt); + bbr->params.full_bw_thresh = min(0x3FFU, bbr_full_bw_thresh); + bbr->params.extra_acked_gain = min(0x7FF, bbr_extra_acked_gain); + bbr->params.extra_acked_win_rtts = min(0x1FU, bbr_extra_acked_win_rtts); + bbr->params.drain_to_target = bbr_drain_to_target ? 1 : 0; + bbr->params.precise_ece_ack = bbr_precise_ece_ack ? 1 : 0; + bbr->params.extra_acked_in_startup = bbr_extra_acked_in_startup ? 1 : 0; + bbr->params.probe_rtt_cwnd_gain = min(0xFFU, bbr_probe_rtt_cwnd_gain); + bbr->params.probe_rtt_win_ms = + min(0x3FFFU, + min_t(u32, bbr_probe_rtt_win_ms, + bbr->params.min_rtt_win_sec * MSEC_PER_SEC)); + for (i = 0; i < CYCLE_LEN; i++) + bbr->params.pacing_gain[i] = min(0x3FF, bbr_pacing_gain[i]); + bbr->params.usage_based_cwnd = bbr_usage_based_cwnd ? 1 : 0; + bbr->params.tso_rtt_shift = min(0xFU, bbr_tso_rtt_shift); + + bbr->debug.snd_isn = tp->snd_una; + bbr->debug.target_cwnd = 0; + bbr->debug.undo = 0; + + bbr->init_cwnd = min(0x7FU, tp->snd_cwnd); + bbr->prior_cwnd = tp->prior_cwnd; + tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; + bbr->next_rtt_delivered = 0; + bbr->prev_ca_state = TCP_CA_Open; + bbr->packet_conservation = 0; + + bbr->probe_rtt_done_stamp = 0; + bbr->probe_rtt_round_done = 0; + bbr->probe_rtt_min_us = tcp_min_rtt(tp); + bbr->probe_rtt_min_stamp = tcp_jiffies32; + bbr->min_rtt_us = tcp_min_rtt(tp); + bbr->min_rtt_stamp = tcp_jiffies32; + + bbr->has_seen_rtt = 0; + bbr_init_pacing_rate_from_rtt(sk); + + bbr->round_start = 0; + bbr->idle_restart = 0; + bbr->full_bw_reached = 0; + bbr->full_bw = 0; + bbr->full_bw_cnt = 0; + bbr->cycle_mstamp = 0; + bbr->cycle_idx = 0; + bbr->mode = BBR_STARTUP; + bbr->debug.rs_bw = 0; + + bbr->ack_epoch_mstamp = tp->tcp_mstamp; + bbr->ack_epoch_acked = 0; + bbr->extra_acked_win_rtts = 0; + bbr->extra_acked_win_idx = 0; + bbr->extra_acked[0] = 0; + bbr->extra_acked[1] = 0; + + bbr->ce_state = 0; + bbr->prior_rcv_nxt = tp->rcv_nxt; + bbr->try_fast_path = 0; + + cmpxchg(&sk->sk_pacing_status, SK_PACING_NONE, SK_PACING_NEEDED); +} + +static u32 bbr_sndbuf_expand(struct sock *sk) +{ + /* Provision 3 * cwnd since BBR may slow-start even during recovery. */ + return 3; +} + +/* __________________________________________________________________________ + * + * Functions new to BBR v2 ("bbr") congestion control are below here. + * __________________________________________________________________________ + */ + +/* Incorporate a new bw sample into the current window of our max filter. */ +static void bbr2_take_bw_hi_sample(struct sock *sk, u32 bw) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->bw_hi[1] = max(bw, bbr->bw_hi[1]); +} + +/* Keep max of last 1-2 cycles. Each PROBE_BW cycle, flip filter window. */ +static void bbr2_advance_bw_hi_filter(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (!bbr->bw_hi[1]) + return; /* no samples in this window; remember old window */ + bbr->bw_hi[0] = bbr->bw_hi[1]; + bbr->bw_hi[1] = 0; +} + +/* How much do we want in flight? Our BDP, unless congestion cut cwnd. */ +static u32 bbr2_target_inflight(struct sock *sk) +{ + u32 bdp = bbr_inflight(sk, bbr_bw(sk), BBR_UNIT); + + return min(bdp, tcp_sk(sk)->snd_cwnd); +} + +static bool bbr2_is_probing_bandwidth(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + return (bbr->mode == BBR_STARTUP) || + (bbr->mode == BBR_PROBE_BW && + (bbr->cycle_idx == BBR_BW_PROBE_REFILL || + bbr->cycle_idx == BBR_BW_PROBE_UP)); +} + +/* Has the given amount of time elapsed since we marked the phase start? */ +static bool bbr2_has_elapsed_in_phase(const struct sock *sk, u32 interval_us) +{ + const struct tcp_sock *tp = tcp_sk(sk); + const struct bbr *bbr = inet_csk_ca(sk); + + return tcp_stamp_us_delta(tp->tcp_mstamp, + bbr->cycle_mstamp + interval_us) > 0; +} + +static void bbr2_handle_queue_too_high_in_startup(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->full_bw_reached = 1; + bbr->inflight_hi = bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT); +} + +/* Exit STARTUP upon N consecutive rounds with ECN mark rate > ecn_thresh. */ +static void bbr2_check_ecn_too_high_in_startup(struct sock *sk, u32 ce_ratio) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr_full_bw_reached(sk) || !bbr->ecn_eligible || + !bbr->params.full_ecn_cnt || !bbr->params.ecn_thresh) + return; + + if (ce_ratio >= bbr->params.ecn_thresh) + bbr->startup_ecn_rounds++; + else + bbr->startup_ecn_rounds = 0; + + if (bbr->startup_ecn_rounds >= bbr->params.full_ecn_cnt) { + bbr->debug.event = 'E'; /* ECN caused STARTUP exit */ + bbr2_handle_queue_too_high_in_startup(sk); + return; + } +} + +static int bbr2_update_ecn_alpha(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + s32 delivered, delivered_ce; + u64 alpha, ce_ratio; + u32 gain; + + if (bbr->params.ecn_factor == 0) + return -1; + + delivered = tp->delivered - bbr->alpha_last_delivered; + delivered_ce = tp->delivered_ce - bbr->alpha_last_delivered_ce; + + if (delivered == 0 || /* avoid divide by zero */ + WARN_ON_ONCE(delivered < 0 || delivered_ce < 0)) /* backwards? */ + return -1; + + /* See if we should use ECN sender logic for this connection. */ + if (!bbr->ecn_eligible && bbr_ecn_enable && + (bbr->min_rtt_us <= bbr->params.ecn_max_rtt_us || + !bbr->params.ecn_max_rtt_us)) + bbr->ecn_eligible = 1; + + ce_ratio = (u64)delivered_ce << BBR_SCALE; + do_div(ce_ratio, delivered); + gain = bbr->params.ecn_alpha_gain; + alpha = ((BBR_UNIT - gain) * bbr->ecn_alpha) >> BBR_SCALE; + alpha += (gain * ce_ratio) >> BBR_SCALE; + bbr->ecn_alpha = min_t(u32, alpha, BBR_UNIT); + + bbr->alpha_last_delivered = tp->delivered; + bbr->alpha_last_delivered_ce = tp->delivered_ce; + + bbr2_check_ecn_too_high_in_startup(sk, ce_ratio); + return (int)ce_ratio; +} + +/* Each round trip of BBR_BW_PROBE_UP, double volume of probing data. */ +static void bbr2_raise_inflight_hi_slope(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u32 growth_this_round, cnt; + + /* Calculate "slope": packets S/Acked per inflight_hi increment. */ + growth_this_round = 1 << bbr->bw_probe_up_rounds; + bbr->bw_probe_up_rounds = min(bbr->bw_probe_up_rounds + 1, 30); + cnt = tp->snd_cwnd / growth_this_round; + cnt = max(cnt, 1U); + bbr->bw_probe_up_cnt = cnt; + bbr->debug.event = 'G'; /* Grow inflight_hi slope */ +} + +/* In BBR_BW_PROBE_UP, not seeing high loss/ECN/queue, so raise inflight_hi. */ +static void bbr2_probe_inflight_hi_upward(struct sock *sk, + const struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u32 delta; + + if (!tp->is_cwnd_limited || tp->snd_cwnd < bbr->inflight_hi) { + bbr->bw_probe_up_acks = 0; /* don't accmulate unused credits */ + return; /* not fully using inflight_hi, so don't grow it */ + } + + /* For each bw_probe_up_cnt packets ACKed, increase inflight_hi by 1. */ + bbr->bw_probe_up_acks += rs->acked_sacked; + if (bbr->bw_probe_up_acks >= bbr->bw_probe_up_cnt) { + delta = bbr->bw_probe_up_acks / bbr->bw_probe_up_cnt; + bbr->bw_probe_up_acks -= delta * bbr->bw_probe_up_cnt; + bbr->inflight_hi += delta; + bbr->debug.event = 'I'; /* Increment inflight_hi */ + } + + if (bbr->round_start) + bbr2_raise_inflight_hi_slope(sk); +} + +/* Does loss/ECN rate for this sample say inflight is "too high"? + * This is used by both the bbr_check_loss_too_high_in_startup() function, + * which can be used in either v1 or v2, and the PROBE_UP phase of v2, which + * uses it to notice when loss/ECN rates suggest inflight is too high. + */ +static bool bbr2_is_inflight_too_high(const struct sock *sk, + const struct rate_sample *rs) +{ + const struct bbr *bbr = inet_csk_ca(sk); + u32 loss_thresh, ecn_thresh; + + if (rs->lost > 0 && rs->tx_in_flight) { + loss_thresh = (u64)rs->tx_in_flight * bbr->params.loss_thresh >> + BBR_SCALE; + if (rs->lost > loss_thresh) + return true; + } + + if (rs->delivered_ce > 0 && rs->delivered > 0 && + bbr->ecn_eligible && bbr->params.ecn_thresh) { + ecn_thresh = (u64)rs->delivered * bbr->params.ecn_thresh >> + BBR_SCALE; + if (rs->delivered_ce >= ecn_thresh) + return true; + } + + return false; +} + +/* Calculate the tx_in_flight level that corresponded to excessive loss. + * We find "lost_prefix" segs of the skb where loss rate went too high, + * by solving for "lost_prefix" in the following equation: + * lost / inflight >= loss_thresh + * (lost_prev + lost_prefix) / (inflight_prev + lost_prefix) >= loss_thresh + * Then we take that equation, convert it to fixed point, and + * round up to the nearest packet. + */ +static u32 bbr2_inflight_hi_from_lost_skb(const struct sock *sk, + const struct rate_sample *rs, + const struct sk_buff *skb) +{ + const struct bbr *bbr = inet_csk_ca(sk); + u32 loss_thresh = bbr->params.loss_thresh; + u32 pcount, divisor, inflight_hi; + s32 inflight_prev, lost_prev; + u64 loss_budget, lost_prefix; + + pcount = tcp_skb_pcount(skb); + + /* How much data was in flight before this skb? */ + inflight_prev = rs->tx_in_flight - pcount; + if (WARN_ONCE(inflight_prev < 0, + "tx_in_flight: %u pcount: %u reneg: %u", + rs->tx_in_flight, pcount, tcp_sk(sk)->is_sack_reneg)) + return ~0U; + + /* How much inflight data was marked lost before this skb? */ + lost_prev = rs->lost - pcount; + if (WARN_ON_ONCE(lost_prev < 0)) + return ~0U; + + /* At what prefix of this lost skb did losss rate exceed loss_thresh? */ + loss_budget = (u64)inflight_prev * loss_thresh + BBR_UNIT - 1; + loss_budget >>= BBR_SCALE; + if (lost_prev >= loss_budget) { + lost_prefix = 0; /* previous losses crossed loss_thresh */ + } else { + lost_prefix = loss_budget - lost_prev; + lost_prefix <<= BBR_SCALE; + divisor = BBR_UNIT - loss_thresh; + if (WARN_ON_ONCE(!divisor)) /* loss_thresh is 8 bits */ + return ~0U; + do_div(lost_prefix, divisor); + } + + inflight_hi = inflight_prev + lost_prefix; + return inflight_hi; +} + +/* If loss/ECN rates during probing indicated we may have overfilled a + * buffer, return an operating point that tries to leave unutilized headroom in + * the path for other flows, for fairness convergence and lower RTTs and loss. + */ +static u32 bbr2_inflight_with_headroom(const struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 headroom, headroom_fraction; + + if (bbr->inflight_hi == ~0U) + return ~0U; + + headroom_fraction = bbr->params.inflight_headroom; + headroom = ((u64)bbr->inflight_hi * headroom_fraction) >> BBR_SCALE; + headroom = max(headroom, 1U); + return max_t(s32, bbr->inflight_hi - headroom, + bbr->params.cwnd_min_target); +} + +/* Bound cwnd to a sensible level, based on our current probing state + * machine phase and model of a good inflight level (inflight_lo, inflight_hi). + */ +static void bbr2_bound_cwnd_for_inflight_model(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u32 cap; + + /* tcp_rcv_synsent_state_process() currently calls tcp_ack() + * and thus cong_control() without first initializing us(!). + */ + if (!bbr->initialized) + return; + + cap = ~0U; + if (bbr->mode == BBR_PROBE_BW && + bbr->cycle_idx != BBR_BW_PROBE_CRUISE) { + /* Probe to see if more packets fit in the path. */ + cap = bbr->inflight_hi; + } else { + if (bbr->mode == BBR_PROBE_RTT || + (bbr->mode == BBR_PROBE_BW && + bbr->cycle_idx == BBR_BW_PROBE_CRUISE)) + cap = bbr2_inflight_with_headroom(sk); + } + /* Adapt to any loss/ECN since our last bw probe. */ + cap = min(cap, bbr->inflight_lo); + + cap = max_t(u32, cap, bbr->params.cwnd_min_target); + tp->snd_cwnd = min(cap, tp->snd_cwnd); +} + +/* Estimate a short-term lower bound on the capacity available now, based + * on measurements of the current delivery process and recent history. When we + * are seeing loss/ECN at times when we are not probing bw, then conservatively + * move toward flow balance by multiplicatively cutting our short-term + * estimated safe rate and volume of data (bw_lo and inflight_lo). We use a + * multiplicative decrease in order to converge to a lower capacity in time + * logarithmic in the magnitude of the decrease. + * + * However, we do not cut our short-term estimates lower than the current rate + * and volume of delivered data from this round trip, since from the current + * delivery process we can estimate the measured capacity available now. + * + * Anything faster than that approach would knowingly risk high loss, which can + * cause low bw for Reno/CUBIC and high loss recovery latency for + * request/response flows using any congestion control. + */ +static void bbr2_adapt_lower_bounds(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u32 ecn_cut, ecn_inflight_lo, beta; + + /* We only use lower-bound estimates when not probing bw. + * When probing we need to push inflight higher to probe bw. + */ + if (bbr2_is_probing_bandwidth(sk)) + return; + + /* ECN response. */ + if (bbr->ecn_in_round && bbr->ecn_eligible && bbr->params.ecn_factor) { + /* Reduce inflight to (1 - alpha*ecn_factor). */ + ecn_cut = (BBR_UNIT - + ((bbr->ecn_alpha * bbr->params.ecn_factor) >> + BBR_SCALE)); + if (bbr->inflight_lo == ~0U) + bbr->inflight_lo = tp->snd_cwnd; + ecn_inflight_lo = (u64)bbr->inflight_lo * ecn_cut >> BBR_SCALE; + } else { + ecn_inflight_lo = ~0U; + } + + /* Loss response. */ + if (bbr->loss_in_round) { + /* Reduce bw and inflight to (1 - beta). */ + if (bbr->bw_lo == ~0U) + bbr->bw_lo = bbr_max_bw(sk); + if (bbr->inflight_lo == ~0U) + bbr->inflight_lo = tp->snd_cwnd; + beta = bbr->params.beta; + bbr->bw_lo = + max_t(u32, bbr->bw_latest, + (u64)bbr->bw_lo * + (BBR_UNIT - beta) >> BBR_SCALE); + bbr->inflight_lo = + max_t(u32, bbr->inflight_latest, + (u64)bbr->inflight_lo * + (BBR_UNIT - beta) >> BBR_SCALE); + } + + /* Adjust to the lower of the levels implied by loss or ECN. */ + bbr->inflight_lo = min(bbr->inflight_lo, ecn_inflight_lo); +} + +/* Reset any short-term lower-bound adaptation to congestion, so that we can + * push our inflight up. + */ +static void bbr2_reset_lower_bounds(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->bw_lo = ~0U; + bbr->inflight_lo = ~0U; +} + +/* After bw probing (STARTUP/PROBE_UP), reset signals before entering a state + * machine phase where we adapt our lower bound based on congestion signals. + */ +static void bbr2_reset_congestion_signals(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->loss_in_round = 0; + bbr->ecn_in_round = 0; + bbr->loss_in_cycle = 0; + bbr->ecn_in_cycle = 0; + bbr->bw_latest = 0; + bbr->inflight_latest = 0; +} + +/* Update (most of) our congestion signals: track the recent rate and volume of + * delivered data, presence of loss, and EWMA degree of ECN marking. + */ +static void bbr2_update_congestion_signals( + struct sock *sk, const struct rate_sample *rs, struct bbr_context *ctx) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u64 bw; + + bbr->loss_round_start = 0; + if (rs->interval_us <= 0 || !rs->acked_sacked) + return; /* Not a valid observation */ + bw = ctx->sample_bw; + + if (!rs->is_app_limited || bw >= bbr_max_bw(sk)) + bbr2_take_bw_hi_sample(sk, bw); + + bbr->loss_in_round |= (rs->losses > 0); + + /* Update rate and volume of delivered data from latest round trip: */ + bbr->bw_latest = max_t(u32, bbr->bw_latest, ctx->sample_bw); + bbr->inflight_latest = max_t(u32, bbr->inflight_latest, rs->delivered); + + if (before(rs->prior_delivered, bbr->loss_round_delivered)) + return; /* skip the per-round-trip updates */ + /* Now do per-round-trip updates. */ + bbr->loss_round_delivered = tp->delivered; /* mark round trip */ + bbr->loss_round_start = 1; + bbr2_adapt_lower_bounds(sk); + + /* Update windowed "latest" (single-round-trip) filters. */ + bbr->loss_in_round = 0; + bbr->ecn_in_round = 0; + bbr->bw_latest = ctx->sample_bw; + bbr->inflight_latest = rs->delivered; +} + +/* Bandwidth probing can cause loss. To help coexistence with loss-based + * congestion control we spread out our probing in a Reno-conscious way. Due to + * the shape of the Reno sawtooth, the time required between loss epochs for an + * idealized Reno flow is a number of round trips that is the BDP of that + * flow. We count packet-timed round trips directly, since measured RTT can + * vary widely, and Reno is driven by packet-timed round trips. + */ +static bool bbr2_is_reno_coexistence_probe_time(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 inflight, rounds, reno_gain, reno_rounds; + + /* Random loss can shave some small percentage off of our inflight + * in each round. To survive this, flows need robust periodic probes. + */ + rounds = bbr->params.bw_probe_max_rounds; + + reno_gain = bbr->params.bw_probe_reno_gain; + if (reno_gain) { + inflight = bbr2_target_inflight(sk); + reno_rounds = ((u64)inflight * reno_gain) >> BBR_SCALE; + rounds = min(rounds, reno_rounds); + } + return bbr->rounds_since_probe >= rounds; +} + +/* How long do we want to wait before probing for bandwidth (and risking + * loss)? We randomize the wait, for better mixing and fairness convergence. + * + * We bound the Reno-coexistence inter-bw-probe time to be 62-63 round trips. + * This is calculated to allow fairness with a 25Mbps, 30ms Reno flow, + * (eg 4K video to a broadband user): + * BDP = 25Mbps * .030sec /(1514bytes) = 61.9 packets + * + * We bound the BBR-native inter-bw-probe wall clock time to be: + * (a) higher than 2 sec: to try to avoid causing loss for a long enough time + * to allow Reno at 30ms to get 4K video bw, the inter-bw-probe time must + * be at least: 25Mbps * .030sec / (1514bytes) * 0.030sec = 1.9secs + * (b) lower than 3 sec: to ensure flows can start probing in a reasonable + * amount of time to discover unutilized bw on human-scale interactive + * time-scales (e.g. perhaps traffic from a web page download that we + * were competing with is now complete). + */ +static void bbr2_pick_probe_wait(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + /* Decide the random round-trip bound for wait until probe: */ + bbr->rounds_since_probe = + prandom_u32_max(bbr->params.bw_probe_rand_rounds); + /* Decide the random wall clock bound for wait until probe: */ + bbr->probe_wait_us = bbr->params.bw_probe_base_us + + prandom_u32_max(bbr->params.bw_probe_rand_us); +} + +static void bbr2_set_cycle_idx(struct sock *sk, int cycle_idx) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->cycle_idx = cycle_idx; + /* New phase, so need to update cwnd and pacing rate. */ + bbr->try_fast_path = 0; +} + +/* Send at estimated bw to fill the pipe, but not queue. We need this phase + * before PROBE_UP, because as soon as we send faster than the available bw + * we will start building a queue, and if the buffer is shallow we can cause + * loss. If we do not fill the pipe before we cause this loss, our bw_hi and + * inflight_hi estimates will underestimate. + */ +static void bbr2_start_bw_probe_refill(struct sock *sk, u32 bw_probe_up_rounds) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr2_reset_lower_bounds(sk); + if (bbr->inflight_hi != ~0U) + bbr->inflight_hi += bbr->params.refill_add_inc; + bbr->bw_probe_up_rounds = bw_probe_up_rounds; + bbr->bw_probe_up_acks = 0; + bbr->stopped_risky_probe = 0; + bbr->ack_phase = BBR_ACKS_REFILLING; + bbr->next_rtt_delivered = tp->delivered; + bbr2_set_cycle_idx(sk, BBR_BW_PROBE_REFILL); +} + +/* Now probe max deliverable data rate and volume. */ +static void bbr2_start_bw_probe_up(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr->ack_phase = BBR_ACKS_PROBE_STARTING; + bbr->next_rtt_delivered = tp->delivered; + bbr->cycle_mstamp = tp->tcp_mstamp; + bbr2_set_cycle_idx(sk, BBR_BW_PROBE_UP); + bbr2_raise_inflight_hi_slope(sk); +} + +/* Start a new PROBE_BW probing cycle of some wall clock length. Pick a wall + * clock time at which to probe beyond an inflight that we think to be + * safe. This will knowingly risk packet loss, so we want to do this rarely, to + * keep packet loss rates low. Also start a round-trip counter, to probe faster + * if we estimate a Reno flow at our BDP would probe faster. + */ +static void bbr2_start_bw_probe_down(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr2_reset_congestion_signals(sk); + bbr->bw_probe_up_cnt = ~0U; /* not growing inflight_hi any more */ + bbr2_pick_probe_wait(sk); + bbr->cycle_mstamp = tp->tcp_mstamp; /* start wall clock */ + bbr->ack_phase = BBR_ACKS_PROBE_STOPPING; + bbr->next_rtt_delivered = tp->delivered; + bbr2_set_cycle_idx(sk, BBR_BW_PROBE_DOWN); +} + +/* Cruise: maintain what we estimate to be a neutral, conservative + * operating point, without attempting to probe up for bandwidth or down for + * RTT, and only reducing inflight in response to loss/ECN signals. + */ +static void bbr2_start_bw_probe_cruise(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr->inflight_lo != ~0U) + bbr->inflight_lo = min(bbr->inflight_lo, bbr->inflight_hi); + + bbr2_set_cycle_idx(sk, BBR_BW_PROBE_CRUISE); +} + +/* Loss and/or ECN rate is too high while probing. + * Adapt (once per bw probe) by cutting inflight_hi and then restarting cycle. + */ +static void bbr2_handle_inflight_too_high(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + const u32 beta = bbr->params.beta; + + bbr->prev_probe_too_high = 1; + bbr->bw_probe_samples = 0; /* only react once per probe */ + bbr->debug.event = 'L'; /* Loss/ECN too high */ + /* If we are app-limited then we are not robustly + * probing the max volume of inflight data we think + * might be safe (analogous to how app-limited bw + * samples are not known to be robustly probing bw). + */ + if (!rs->is_app_limited) + bbr->inflight_hi = max_t(u32, rs->tx_in_flight, + (u64)bbr2_target_inflight(sk) * + (BBR_UNIT - beta) >> BBR_SCALE); + if (bbr->mode == BBR_PROBE_BW && bbr->cycle_idx == BBR_BW_PROBE_UP) + bbr2_start_bw_probe_down(sk); +} + +/* If we're seeing bw and loss samples reflecting our bw probing, adapt + * using the signals we see. If loss or ECN mark rate gets too high, then adapt + * inflight_hi downward. If we're able to push inflight higher without such + * signals, push higher: adapt inflight_hi upward. + */ +static bool bbr2_adapt_upper_bounds(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + + /* Track when we'll see bw/loss samples resulting from our bw probes. */ + if (bbr->ack_phase == BBR_ACKS_PROBE_STARTING && bbr->round_start) + bbr->ack_phase = BBR_ACKS_PROBE_FEEDBACK; + if (bbr->ack_phase == BBR_ACKS_PROBE_STOPPING && bbr->round_start) { + /* End of samples from bw probing phase. */ + bbr->bw_probe_samples = 0; + bbr->ack_phase = BBR_ACKS_INIT; + /* At this point in the cycle, our current bw sample is also + * our best recent chance at finding the highest available bw + * for this flow. So now is the best time to forget the bw + * samples from the previous cycle, by advancing the window. + */ + if (bbr->mode == BBR_PROBE_BW && !rs->is_app_limited) + bbr2_advance_bw_hi_filter(sk); + /* If we had an inflight_hi, then probed and pushed inflight all + * the way up to hit that inflight_hi without seeing any + * high loss/ECN in all the resulting ACKs from that probing, + * then probe up again, this time letting inflight persist at + * inflight_hi for a round trip, then accelerating beyond. + */ + if (bbr->mode == BBR_PROBE_BW && + bbr->stopped_risky_probe && !bbr->prev_probe_too_high) { + bbr->debug.event = 'R'; /* reprobe */ + bbr2_start_bw_probe_refill(sk, 0); + return true; /* yes, decided state transition */ + } + } + + if (bbr2_is_inflight_too_high(sk, rs)) { + if (bbr->bw_probe_samples) /* sample is from bw probing? */ + bbr2_handle_inflight_too_high(sk, rs); + } else { + /* Loss/ECN rate is declared safe. Adjust upper bound upward. */ + if (bbr->inflight_hi == ~0U) /* no excess queue signals yet? */ + return false; + + /* To be resilient to random loss, we must raise inflight_hi + * if we observe in any phase that a higher level is safe. + */ + if (rs->tx_in_flight > bbr->inflight_hi) { + bbr->inflight_hi = rs->tx_in_flight; + bbr->debug.event = 'U'; /* raise up inflight_hi */ + } + + if (bbr->mode == BBR_PROBE_BW && + bbr->cycle_idx == BBR_BW_PROBE_UP) + bbr2_probe_inflight_hi_upward(sk, rs); + } + + return false; +} + +/* Check if it's time to probe for bandwidth now, and if so, kick it off. */ +static bool bbr2_check_time_to_probe_bw(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 n; + + /* If we seem to be at an operating point where we are not seeing loss + * but we are seeing ECN marks, then when the ECN marks cease we reprobe + * quickly (in case a burst of cross-traffic has ceased and freed up bw, + * or in case we are sharing with multiplicatively probing traffic). + */ + if (bbr->params.ecn_reprobe_gain && bbr->ecn_eligible && + bbr->ecn_in_cycle && !bbr->loss_in_cycle && + inet_csk(sk)->icsk_ca_state == TCP_CA_Open) { + bbr->debug.event = 'A'; /* *A*ll clear to probe *A*gain */ + /* Calculate n so that when bbr2_raise_inflight_hi_slope() + * computes growth_this_round as 2^n it will be roughly the + * desired volume of data (inflight_hi*ecn_reprobe_gain). + */ + n = ilog2((((u64)bbr->inflight_hi * + bbr->params.ecn_reprobe_gain) >> BBR_SCALE)); + bbr2_start_bw_probe_refill(sk, n); + return true; + } + + if (bbr2_has_elapsed_in_phase(sk, bbr->probe_wait_us) || + bbr2_is_reno_coexistence_probe_time(sk)) { + bbr2_start_bw_probe_refill(sk, 0); + return true; + } + return false; +} + +/* Is it time to transition from PROBE_DOWN to PROBE_CRUISE? */ +static bool bbr2_check_time_to_cruise(struct sock *sk, u32 inflight, u32 bw) +{ + struct bbr *bbr = inet_csk_ca(sk); + bool is_under_bdp, is_long_enough; + + /* Always need to pull inflight down to leave headroom in queue. */ + if (inflight > bbr2_inflight_with_headroom(sk)) + return false; + + is_under_bdp = inflight <= bbr_inflight(sk, bw, BBR_UNIT); + if (bbr->params.drain_to_target) + return is_under_bdp; + + is_long_enough = bbr2_has_elapsed_in_phase(sk, bbr->min_rtt_us); + return is_under_bdp || is_long_enough; +} + +/* PROBE_BW state machine: cruise, refill, probe for bw, or drain? */ +static void bbr2_update_cycle_phase(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + bool is_risky = false, is_queuing = false; + u32 inflight, bw; + + if (!bbr_full_bw_reached(sk)) + return; + + /* In DRAIN, PROBE_BW, or PROBE_RTT, adjust upper bounds. */ + if (bbr2_adapt_upper_bounds(sk, rs)) + return; /* already decided state transition */ + + if (bbr->mode != BBR_PROBE_BW) + return; + + inflight = bbr_packets_in_net_at_edt(sk, rs->prior_in_flight); + bw = bbr_max_bw(sk); + + switch (bbr->cycle_idx) { + /* First we spend most of our time cruising with a pacing_gain of 1.0, + * which paces at the estimated bw, to try to fully use the pipe + * without building queue. If we encounter loss/ECN marks, we adapt + * by slowing down. + */ + case BBR_BW_PROBE_CRUISE: + if (bbr2_check_time_to_probe_bw(sk)) + return; /* already decided state transition */ + break; + + /* After cruising, when it's time to probe, we first "refill": we send + * at the estimated bw to fill the pipe, before probing higher and + * knowingly risking overflowing the bottleneck buffer (causing loss). + */ + case BBR_BW_PROBE_REFILL: + if (bbr->round_start) { + /* After one full round trip of sending in REFILL, we + * start to see bw samples reflecting our REFILL, which + * may be putting too much data in flight. + */ + bbr->bw_probe_samples = 1; + bbr2_start_bw_probe_up(sk); + } + break; + + /* After we refill the pipe, we probe by using a pacing_gain > 1.0, to + * probe for bw. If we have not seen loss/ECN, we try to raise inflight + * to at least pacing_gain*BDP; note that this may take more than + * min_rtt if min_rtt is small (e.g. on a LAN). + * + * We terminate PROBE_UP bandwidth probing upon any of the following: + * + * (1) We've pushed inflight up to hit the inflight_hi target set in the + * most recent previous bw probe phase. Thus we want to start + * draining the queue immediately because it's very likely the most + * recently sent packets will fill the queue and cause drops. + * (checked here) + * (2) We have probed for at least 1*min_rtt_us, and the + * estimated queue is high enough (inflight > 1.25 * estimated_bdp). + * (checked here) + * (3) Loss filter says loss rate is "too high". + * (checked in bbr_is_inflight_too_high()) + * (4) ECN filter says ECN mark rate is "too high". + * (checked in bbr_is_inflight_too_high()) + */ + case BBR_BW_PROBE_UP: + if (bbr->prev_probe_too_high && + inflight >= bbr->inflight_hi) { + bbr->stopped_risky_probe = 1; + is_risky = true; + bbr->debug.event = 'D'; /* D for danger */ + } else if (bbr2_has_elapsed_in_phase(sk, bbr->min_rtt_us) && + inflight >= + bbr_inflight(sk, bw, + bbr->params.bw_probe_pif_gain)) { + is_queuing = true; + bbr->debug.event = 'Q'; /* building Queue */ + } + if (is_risky || is_queuing) { + bbr->prev_probe_too_high = 0; /* no loss/ECN (yet) */ + bbr2_start_bw_probe_down(sk); /* restart w/ down */ + } + break; + + /* After probing in PROBE_UP, we have usually accumulated some data in + * the bottleneck buffer (if bw probing didn't find more bw). We next + * enter PROBE_DOWN to try to drain any excess data from the queue. To + * do this, we use a pacing_gain < 1.0. We hold this pacing gain until + * our inflight is less then that target cruising point, which is the + * minimum of (a) the amount needed to leave headroom, and (b) the + * estimated BDP. Once inflight falls to match the target, we estimate + * the queue is drained; persisting would underutilize the pipe. + */ + case BBR_BW_PROBE_DOWN: + if (bbr2_check_time_to_probe_bw(sk)) + return; /* already decided state transition */ + if (bbr2_check_time_to_cruise(sk, inflight, bw)) + bbr2_start_bw_probe_cruise(sk); + break; + + default: + WARN_ONCE(1, "BBR invalid cycle index %u\n", bbr->cycle_idx); + } +} + +/* Exiting PROBE_RTT, so return to bandwidth probing in STARTUP or PROBE_BW. */ +static void bbr2_exit_probe_rtt(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr2_reset_lower_bounds(sk); + if (bbr_full_bw_reached(sk)) { + bbr->mode = BBR_PROBE_BW; + /* Raising inflight after PROBE_RTT may cause loss, so reset + * the PROBE_BW clock and schedule the next bandwidth probe for + * a friendly and randomized future point in time. + */ + bbr2_start_bw_probe_down(sk); + /* Since we are exiting PROBE_RTT, we know inflight is + * below our estimated BDP, so it is reasonable to cruise. + */ + bbr2_start_bw_probe_cruise(sk); + } else { + bbr->mode = BBR_STARTUP; + } +} + +/* Exit STARTUP based on loss rate > 1% and loss gaps in round >= N. Wait until + * the end of the round in recovery to get a good estimate of how many packets + * have been lost, and how many we need to drain with a low pacing rate. + */ +static void bbr2_check_loss_too_high_in_startup(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr_full_bw_reached(sk)) + return; + + /* For STARTUP exit, check the loss rate at the end of each round trip + * of Recovery episodes in STARTUP. We check the loss rate at the end + * of the round trip to filter out noisy/low loss and have a better + * sense of inflight (extent of loss), so we can drain more accurately. + */ + if (rs->losses && bbr->loss_events_in_round < 0xf) + bbr->loss_events_in_round++; /* update saturating counter */ + if (bbr->params.full_loss_cnt && bbr->loss_round_start && + inet_csk(sk)->icsk_ca_state == TCP_CA_Recovery && + bbr->loss_events_in_round >= bbr->params.full_loss_cnt && + bbr2_is_inflight_too_high(sk, rs)) { + bbr->debug.event = 'P'; /* Packet loss caused STARTUP exit */ + bbr2_handle_queue_too_high_in_startup(sk); + return; + } + if (bbr->loss_round_start) + bbr->loss_events_in_round = 0; +} + +/* If we are done draining, advance into steady state operation in PROBE_BW. */ +static void bbr2_check_drain(struct sock *sk, const struct rate_sample *rs, + struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr_check_drain(sk, rs, ctx)) { + bbr->mode = BBR_PROBE_BW; + bbr2_start_bw_probe_down(sk); + } +} + +static void bbr2_update_model(struct sock *sk, const struct rate_sample *rs, + struct bbr_context *ctx) +{ + bbr2_update_congestion_signals(sk, rs, ctx); + bbr_update_ack_aggregation(sk, rs); + bbr2_check_loss_too_high_in_startup(sk, rs); + bbr_check_full_bw_reached(sk, rs); + bbr2_check_drain(sk, rs, ctx); + bbr2_update_cycle_phase(sk, rs); + bbr_update_min_rtt(sk, rs); +} + +/* Fast path for app-limited case. + * + * On each ack, we execute bbr state machine, which primarily consists of: + * 1) update model based on new rate sample, and + * 2) update control based on updated model or state change. + * + * There are certain workload/scenarios, e.g. app-limited case, where + * either we can skip updating model or we can skip update of both model + * as well as control. This provides signifcant softirq cpu savings for + * processing incoming acks. + * + * In case of app-limited, if there is no congestion (loss/ecn) and + * if observed bw sample is less than current estimated bw, then we can + * skip some of the computation in bbr state processing: + * + * - if there is no rtt/mode/phase change: In this case, since all the + * parameters of the network model are constant, we can skip model + * as well control update. + * + * - else we can skip rest of the model update. But we still need to + * update the control to account for the new rtt/mode/phase. + * + * Returns whether we can take fast path or not. + */ +static bool bbr2_fast_path(struct sock *sk, bool *update_model, + const struct rate_sample *rs, struct bbr_context *ctx) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 prev_min_rtt_us, prev_mode; + + if (bbr->params.fast_path && bbr->try_fast_path && + rs->is_app_limited && ctx->sample_bw < bbr_max_bw(sk) && + !bbr->loss_in_round && !bbr->ecn_in_round) { + prev_mode = bbr->mode; + prev_min_rtt_us = bbr->min_rtt_us; + bbr2_check_drain(sk, rs, ctx); + bbr2_update_cycle_phase(sk, rs); + bbr_update_min_rtt(sk, rs); + + if (bbr->mode == prev_mode && + bbr->min_rtt_us == prev_min_rtt_us && + bbr->try_fast_path) + return true; + + /* Skip model update, but control still needs to be updated */ + *update_model = false; + } + return false; +} + +static void bbr2_main(struct sock *sk, const struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + struct bbr_context ctx = { 0 }; + bool update_model = true; + u32 bw; + int ce_ratio = -1; + + bbr->debug.event = '.'; /* init to default NOP (no event yet) */ + + bbr_update_round_start(sk, rs, &ctx); + if (bbr->round_start) { + bbr->rounds_since_probe = + min_t(s32, bbr->rounds_since_probe + 1, 0xFF); + ce_ratio = bbr2_update_ecn_alpha(sk); + tcp_plb_update_state(sk, &bbr->plb, ce_ratio); + tcp_plb_check_rehash(sk, &bbr->plb); + } + + bbr->ecn_in_round |= rs->is_ece; + bbr_calculate_bw_sample(sk, rs, &ctx); + + if (bbr2_fast_path(sk, &update_model, rs, &ctx)) + goto out; + + if (update_model) + bbr2_update_model(sk, rs, &ctx); + + bbr_update_gains(sk); + bw = bbr_bw(sk); + bbr_set_pacing_rate(sk, bw, bbr->pacing_gain); + bbr_set_cwnd(sk, rs, rs->acked_sacked, bw, bbr->cwnd_gain, + tp->snd_cwnd, &ctx); + bbr2_bound_cwnd_for_inflight_model(sk); + +out: + bbr->prev_ca_state = inet_csk(sk)->icsk_ca_state; + bbr->loss_in_cycle |= rs->lost > 0; + bbr->ecn_in_cycle |= rs->delivered_ce > 0; + + bbr_debug(sk, rs->acked_sacked, rs, &ctx); +} + +/* Module parameters that are settable by TCP_CONGESTION_PARAMS are declared + * down here, so that the algorithm functions that use the parameters must use + * the per-socket parameters; if they accidentally use the global version + * then there will be a compile error. + * TODO(ncardwell): move all per-socket parameters down to this section. + */ + +/* On losses, scale down inflight and pacing rate by beta scaled by BBR_SCALE. + * No loss response when 0. Max allwed value is 255. + */ +static u32 bbr_beta = BBR_UNIT * 30 / 100; + +/* Gain factor for ECN mark ratio samples, scaled by BBR_SCALE. + * Max allowed value is 255. + */ +static u32 bbr_ecn_alpha_gain = BBR_UNIT * 1 / 16; /* 1/16 = 6.25% */ + +/* The initial value for the ecn_alpha state variable. Default and max + * BBR_UNIT (256), representing 1.0. This allows a flow to respond quickly + * to congestion if the bottleneck is congested when the flow starts up. + */ +static u32 bbr_ecn_alpha_init = BBR_UNIT; /* 1.0, to respond quickly */ + +/* On ECN, cut inflight_lo to (1 - ecn_factor * ecn_alpha) scaled by BBR_SCALE. + * No ECN based bounding when 0. Max allwed value is 255. + */ +static u32 bbr_ecn_factor = BBR_UNIT * 1 / 3; /* 1/3 = 33% */ + +/* Estimate bw probing has gone too far if CE ratio exceeds this threshold. + * Scaled by BBR_SCALE. Disabled when 0. Max allowed is 255. + */ +static u32 bbr_ecn_thresh = BBR_UNIT * 1 / 2; /* 1/2 = 50% */ + +/* Max RTT (in usec) at which to use sender-side ECN logic. + * Disabled when 0 (ECN allowed at any RTT). + * Max allowed for the parameter is 524287 (0x7ffff) us, ~524 ms. + */ +static u32 bbr_ecn_max_rtt_us = 5000; + +/* If non-zero, if in a cycle with no losses but some ECN marks, after ECN + * clears then use a multiplicative increase to quickly reprobe bw by + * starting inflight probing at the given multiple of inflight_hi. + * Default for this experimental knob is 0 (disabled). + * Planned value for experiments: BBR_UNIT * 1 / 2 = 128, representing 0.5. + */ +static u32 bbr_ecn_reprobe_gain; + +/* Estimate bw probing has gone too far if loss rate exceeds this level. */ +static u32 bbr_loss_thresh = BBR_UNIT * 2 / 100; /* 2% loss */ + +/* Exit STARTUP if number of loss marking events in a Recovery round is >= N, + * and loss rate is higher than bbr_loss_thresh. + * Disabled if 0. Max allowed value is 15 (0xF). + */ +static u32 bbr_full_loss_cnt = 8; + +/* Exit STARTUP if number of round trips with ECN mark rate above ecn_thresh + * meets this count. Max allowed value is 3. + */ +static u32 bbr_full_ecn_cnt = 2; + +/* Fraction of unutilized headroom to try to leave in path upon high loss. */ +static u32 bbr_inflight_headroom = BBR_UNIT * 15 / 100; + +/* Multiplier to get target inflight (as multiple of BDP) for PROBE_UP phase. + * Default is 1.25x, as in BBR v1. Max allowed is 511. + */ +static u32 bbr_bw_probe_pif_gain = BBR_UNIT * 5 / 4; + +/* Multiplier to get Reno-style probe epoch duration as: k * BDP round trips. + * If zero, disables this BBR v2 Reno-style BDP-scaled coexistence mechanism. + * Max allowed is 511. + */ +static u32 bbr_bw_probe_reno_gain = BBR_UNIT; + +/* Max number of packet-timed rounds to wait before probing for bandwidth. If + * we want to tolerate 1% random loss per round, and not have this cut our + * inflight too much, we must probe for bw periodically on roughly this scale. + * If low, limits Reno/CUBIC coexistence; if high, limits loss tolerance. + * We aim to be fair with Reno/CUBIC up to a BDP of at least: + * BDP = 25Mbps * .030sec /(1514bytes) = 61.9 packets + */ +static u32 bbr_bw_probe_max_rounds = 63; + +/* Max amount of randomness to inject in round counting for Reno-coexistence. + * Max value is 15. + */ +static u32 bbr_bw_probe_rand_rounds = 2; + +/* Use BBR-native probe time scale starting at this many usec. + * We aim to be fair with Reno/CUBIC up to an inter-loss time epoch of at least: + * BDP*RTT = 25Mbps * .030sec /(1514bytes) * 0.030sec = 1.9 secs + */ +static u32 bbr_bw_probe_base_us = 2 * USEC_PER_SEC; /* 2 secs */ + +/* Use BBR-native probes spread over this many usec: */ +static u32 bbr_bw_probe_rand_us = 1 * USEC_PER_SEC; /* 1 secs */ + +/* Undo the model changes made in loss recovery if recovery was spurious? */ +static bool bbr_undo = true; + +/* Use fast path if app-limited, no loss/ECN, and target cwnd was reached? */ +static bool bbr_fast_path = true; /* default: enabled */ + +/* Use fast ack mode ? */ +static int bbr_fast_ack_mode = 1; /* default: rwnd check off */ + +/* How much to additively increase inflight_hi when entering REFILL? */ +static u32 bbr_refill_add_inc; /* default: disabled */ + +module_param_named(beta, bbr_beta, uint, 0644); +module_param_named(ecn_alpha_gain, bbr_ecn_alpha_gain, uint, 0644); +module_param_named(ecn_alpha_init, bbr_ecn_alpha_init, uint, 0644); +module_param_named(ecn_factor, bbr_ecn_factor, uint, 0644); +module_param_named(ecn_thresh, bbr_ecn_thresh, uint, 0644); +module_param_named(ecn_max_rtt_us, bbr_ecn_max_rtt_us, uint, 0644); +module_param_named(ecn_reprobe_gain, bbr_ecn_reprobe_gain, uint, 0644); +module_param_named(loss_thresh, bbr_loss_thresh, uint, 0664); +module_param_named(full_loss_cnt, bbr_full_loss_cnt, uint, 0664); +module_param_named(full_ecn_cnt, bbr_full_ecn_cnt, uint, 0664); +module_param_named(inflight_headroom, bbr_inflight_headroom, uint, 0664); +module_param_named(bw_probe_pif_gain, bbr_bw_probe_pif_gain, uint, 0664); +module_param_named(bw_probe_reno_gain, bbr_bw_probe_reno_gain, uint, 0664); +module_param_named(bw_probe_max_rounds, bbr_bw_probe_max_rounds, uint, 0664); +module_param_named(bw_probe_rand_rounds, bbr_bw_probe_rand_rounds, uint, 0664); +module_param_named(bw_probe_base_us, bbr_bw_probe_base_us, uint, 0664); +module_param_named(bw_probe_rand_us, bbr_bw_probe_rand_us, uint, 0664); +module_param_named(undo, bbr_undo, bool, 0664); +module_param_named(fast_path, bbr_fast_path, bool, 0664); +module_param_named(fast_ack_mode, bbr_fast_ack_mode, uint, 0664); +module_param_named(refill_add_inc, bbr_refill_add_inc, uint, 0664); + +static void bbr2_init(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + const struct net *net = sock_net(sk); + + bbr_init(sk); /* run shared init code for v1 and v2 */ + + /* BBR v2 parameters: */ + bbr->params.beta = min_t(u32, 0xFFU, bbr_beta); + bbr->params.ecn_alpha_gain = min_t(u32, 0xFFU, bbr_ecn_alpha_gain); + bbr->params.ecn_alpha_init = min_t(u32, BBR_UNIT, bbr_ecn_alpha_init); + bbr->params.ecn_factor = min_t(u32, 0xFFU, bbr_ecn_factor); + bbr->params.ecn_thresh = min_t(u32, 0xFFU, bbr_ecn_thresh); + bbr->params.ecn_max_rtt_us = min_t(u32, 0x7ffffU, bbr_ecn_max_rtt_us); + bbr->params.ecn_reprobe_gain = min_t(u32, 0x1FF, bbr_ecn_reprobe_gain); + bbr->params.loss_thresh = min_t(u32, 0xFFU, bbr_loss_thresh); + bbr->params.full_loss_cnt = min_t(u32, 0xFU, bbr_full_loss_cnt); + bbr->params.full_ecn_cnt = min_t(u32, 0x3U, bbr_full_ecn_cnt); + bbr->params.inflight_headroom = + min_t(u32, 0xFFU, bbr_inflight_headroom); + bbr->params.bw_probe_pif_gain = + min_t(u32, 0x1FFU, bbr_bw_probe_pif_gain); + bbr->params.bw_probe_reno_gain = + min_t(u32, 0x1FFU, bbr_bw_probe_reno_gain); + bbr->params.bw_probe_max_rounds = + min_t(u32, 0xFFU, bbr_bw_probe_max_rounds); + bbr->params.bw_probe_rand_rounds = + min_t(u32, 0xFU, bbr_bw_probe_rand_rounds); + bbr->params.bw_probe_base_us = + min_t(u32, (1 << 26) - 1, bbr_bw_probe_base_us); + bbr->params.bw_probe_rand_us = + min_t(u32, (1 << 26) - 1, bbr_bw_probe_rand_us); + bbr->params.undo = bbr_undo; + bbr->params.fast_path = bbr_fast_path ? 1 : 0; + bbr->params.refill_add_inc = min_t(u32, 0x3U, bbr_refill_add_inc); + + /* BBR v2 state: */ + bbr->initialized = 1; + /* Start sampling ECN mark rate after first full flight is ACKed: */ + bbr->loss_round_delivered = tp->delivered + 1; + bbr->loss_round_start = 0; + bbr->undo_bw_lo = 0; + bbr->undo_inflight_lo = 0; + bbr->undo_inflight_hi = 0; + bbr->loss_events_in_round = 0; + bbr->startup_ecn_rounds = 0; + bbr2_reset_congestion_signals(sk); + bbr->bw_lo = ~0U; + bbr->bw_hi[0] = 0; + bbr->bw_hi[1] = 0; + bbr->inflight_lo = ~0U; + bbr->inflight_hi = ~0U; + bbr->bw_probe_up_cnt = ~0U; + bbr->bw_probe_up_acks = 0; + bbr->bw_probe_up_rounds = 0; + bbr->probe_wait_us = 0; + bbr->stopped_risky_probe = 0; + bbr->ack_phase = BBR_ACKS_INIT; + bbr->rounds_since_probe = 0; + bbr->bw_probe_samples = 0; + bbr->prev_probe_too_high = 0; + bbr->ecn_eligible = 0; + bbr->ecn_alpha = bbr->params.ecn_alpha_init; + bbr->alpha_last_delivered = 0; + bbr->alpha_last_delivered_ce = 0; + + bbr->plb.enabled = 0; + bbr->plb.consec_cong_rounds = 0; + bbr->plb.pause_until = 0; + if ((tp->ecn_flags & TCP_ECN_OK) && + net->ipv4.sysctl_tcp_plb_enabled) + bbr->plb.enabled = 1; + + tp->fast_ack_mode = min_t(u32, 0x2U, bbr_fast_ack_mode); + + if ((tp->ecn_flags & TCP_ECN_OK) && bbr_ecn_enable) + tp->ecn_flags |= TCP_ECN_ECT_PERMANENT; +} + +/* Core TCP stack informs us that the given skb was just marked lost. */ +static void bbr2_skb_marked_lost(struct sock *sk, const struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + struct tcp_skb_cb *scb = TCP_SKB_CB(skb); + struct rate_sample rs; + + /* Capture "current" data over the full round trip of loss, + * to have a better chance to see the full capacity of the path. + */ + if (!bbr->loss_in_round) /* first loss in this round trip? */ + bbr->loss_round_delivered = tp->delivered; /* set round trip */ + bbr->loss_in_round = 1; + bbr->loss_in_cycle = 1; + + if (!bbr->bw_probe_samples) + return; /* not an skb sent while probing for bandwidth */ + if (unlikely(!scb->tx.delivered_mstamp)) + return; /* skb was SACKed, reneged, marked lost; ignore it */ + /* We are probing for bandwidth. Construct a rate sample that + * estimates what happened in the flight leading up to this lost skb, + * then see if the loss rate went too high, and if so at which packet. + */ + memset(&rs, 0, sizeof(rs)); + rs.tx_in_flight = scb->tx.in_flight; + rs.lost = tp->lost - scb->tx.lost; + rs.is_app_limited = scb->tx.is_app_limited; + if (bbr2_is_inflight_too_high(sk, &rs)) { + rs.tx_in_flight = bbr2_inflight_hi_from_lost_skb(sk, &rs, skb); + bbr2_handle_inflight_too_high(sk, &rs); + } +} + +/* Revert short-term model if current loss recovery event was spurious. */ +static u32 bbr2_undo_cwnd(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr->debug.undo = 1; + bbr->full_bw = 0; /* spurious slow-down; reset full pipe detection */ + bbr->full_bw_cnt = 0; + bbr->loss_in_round = 0; + + if (!bbr->params.undo) + return tp->snd_cwnd; + + /* Revert to cwnd and other state saved before loss episode. */ + bbr->bw_lo = max(bbr->bw_lo, bbr->undo_bw_lo); + bbr->inflight_lo = max(bbr->inflight_lo, bbr->undo_inflight_lo); + bbr->inflight_hi = max(bbr->inflight_hi, bbr->undo_inflight_hi); + return bbr->prior_cwnd; +} + +/* Entering loss recovery, so save state for when we undo recovery. */ +static u32 bbr2_ssthresh(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr_save_cwnd(sk); + /* For undo, save state that adapts based on loss signal. */ + bbr->undo_bw_lo = bbr->bw_lo; + bbr->undo_inflight_lo = bbr->inflight_lo; + bbr->undo_inflight_hi = bbr->inflight_hi; + return tcp_sk(sk)->snd_ssthresh; +} + +static enum tcp_bbr2_phase bbr2_get_phase(struct bbr *bbr) +{ + switch (bbr->mode) { + case BBR_STARTUP: + return BBR2_PHASE_STARTUP; + case BBR_DRAIN: + return BBR2_PHASE_DRAIN; + case BBR_PROBE_BW: + break; + case BBR_PROBE_RTT: + return BBR2_PHASE_PROBE_RTT; + default: + return BBR2_PHASE_INVALID; + } + switch (bbr->cycle_idx) { + case BBR_BW_PROBE_UP: + return BBR2_PHASE_PROBE_BW_UP; + case BBR_BW_PROBE_DOWN: + return BBR2_PHASE_PROBE_BW_DOWN; + case BBR_BW_PROBE_CRUISE: + return BBR2_PHASE_PROBE_BW_CRUISE; + case BBR_BW_PROBE_REFILL: + return BBR2_PHASE_PROBE_BW_REFILL; + default: + return BBR2_PHASE_INVALID; + } +} + +static size_t bbr2_get_info(struct sock *sk, u32 ext, int *attr, + union tcp_cc_info *info) +{ + if (ext & (1 << (INET_DIAG_BBRINFO - 1)) || + ext & (1 << (INET_DIAG_VEGASINFO - 1))) { + struct bbr *bbr = inet_csk_ca(sk); + u64 bw = bbr_bw_bytes_per_sec(sk, bbr_bw(sk)); + u64 bw_hi = bbr_bw_bytes_per_sec(sk, bbr_max_bw(sk)); + u64 bw_lo = bbr->bw_lo == ~0U ? + ~0ULL : bbr_bw_bytes_per_sec(sk, bbr->bw_lo); + + memset(&info->bbr2, 0, sizeof(info->bbr2)); + info->bbr2.bbr_bw_lsb = (u32)bw; + info->bbr2.bbr_bw_msb = (u32)(bw >> 32); + info->bbr2.bbr_min_rtt = bbr->min_rtt_us; + info->bbr2.bbr_pacing_gain = bbr->pacing_gain; + info->bbr2.bbr_cwnd_gain = bbr->cwnd_gain; + info->bbr2.bbr_bw_hi_lsb = (u32)bw_hi; + info->bbr2.bbr_bw_hi_msb = (u32)(bw_hi >> 32); + info->bbr2.bbr_bw_lo_lsb = (u32)bw_lo; + info->bbr2.bbr_bw_lo_msb = (u32)(bw_lo >> 32); + info->bbr2.bbr_mode = bbr->mode; + info->bbr2.bbr_phase = (__u8)bbr2_get_phase(bbr); + info->bbr2.bbr_version = (__u8)2; + info->bbr2.bbr_inflight_lo = bbr->inflight_lo; + info->bbr2.bbr_inflight_hi = bbr->inflight_hi; + info->bbr2.bbr_extra_acked = bbr_extra_acked(sk); + *attr = INET_DIAG_BBRINFO; + return sizeof(info->bbr2); + } + return 0; +} + +static void bbr2_set_state(struct sock *sk, u8 new_state) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + if (new_state == TCP_CA_Loss) { + struct rate_sample rs = { .losses = 1 }; + struct bbr_context ctx = { 0 }; + + tcp_plb_update_state_upon_rto(sk, &bbr->plb); + bbr->prev_ca_state = TCP_CA_Loss; + bbr->full_bw = 0; + if (!bbr2_is_probing_bandwidth(sk) && bbr->inflight_lo == ~0U) { + /* bbr_adapt_lower_bounds() needs cwnd before + * we suffered an RTO, to update inflight_lo: + */ + bbr->inflight_lo = + max(tp->snd_cwnd, bbr->prior_cwnd); + } + bbr_debug(sk, 0, &rs, &ctx); + } else if (bbr->prev_ca_state == TCP_CA_Loss && + new_state != TCP_CA_Loss) { + tp->snd_cwnd = max(tp->snd_cwnd, bbr->prior_cwnd); + bbr->try_fast_path = 0; /* bound cwnd using latest model */ + } +} + +static struct tcp_congestion_ops tcp_bbr2_cong_ops __read_mostly = { + .flags = TCP_CONG_NON_RESTRICTED | TCP_CONG_WANTS_CE_EVENTS, + .name = "bbr2", + .owner = THIS_MODULE, + .init = bbr2_init, + .cong_control = bbr2_main, + .sndbuf_expand = bbr_sndbuf_expand, + .skb_marked_lost = bbr2_skb_marked_lost, + .undo_cwnd = bbr2_undo_cwnd, + .cwnd_event = bbr_cwnd_event, + .ssthresh = bbr2_ssthresh, + .tso_segs = bbr_tso_segs, + .get_info = bbr2_get_info, + .set_state = bbr2_set_state, +}; + +static int __init bbr_register(void) +{ + BUILD_BUG_ON(sizeof(struct bbr) > ICSK_CA_PRIV_SIZE); + return tcp_register_congestion_control(&tcp_bbr2_cong_ops); +} + +static void __exit bbr_unregister(void) +{ + tcp_unregister_congestion_control(&tcp_bbr2_cong_ops); +} + +module_init(bbr_register); +module_exit(bbr_unregister); + +MODULE_AUTHOR("Van Jacobson "); +MODULE_AUTHOR("Neal Cardwell "); +MODULE_AUTHOR("Yuchung Cheng "); +MODULE_AUTHOR("Soheil Hassas Yeganeh "); +MODULE_AUTHOR("Priyaranjan Jha "); +MODULE_AUTHOR("Yousuk Seung "); +MODULE_AUTHOR("Kevin Yang "); +MODULE_AUTHOR("Arjun Roy "); + +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_DESCRIPTION("TCP BBR (Bottleneck Bandwidth and RTT)"); diff --git a/net/ipv4/tcp_cong.c b/net/ipv4/tcp_cong.c index d3cae40749e8..0f268f2ff2e9 100644 --- a/net/ipv4/tcp_cong.c +++ b/net/ipv4/tcp_cong.c @@ -189,6 +189,7 @@ void tcp_init_congestion_control(struct sock *sk) struct inet_connection_sock *icsk = inet_csk(sk); tcp_sk(sk)->prior_ssthresh = 0; + tcp_sk(sk)->fast_ack_mode = 0; if (icsk->icsk_ca_ops->init) icsk->icsk_ca_ops->init(sk); if (tcp_ca_needs_ecn(sk)) diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c index 0640453fce54..8a455eb0c552 100644 --- a/net/ipv4/tcp_input.c +++ b/net/ipv4/tcp_input.c @@ -349,7 +349,7 @@ static void __tcp_ecn_check_ce(struct sock *sk, const struct sk_buff *skb) tcp_enter_quickack_mode(sk, 2); break; case INET_ECN_CE: - if (tcp_ca_needs_ecn(sk)) + if (tcp_ca_wants_ce_events(sk)) tcp_ca_event(sk, CA_EVENT_ECN_IS_CE); if (!(tp->ecn_flags & TCP_ECN_DEMAND_CWR)) { @@ -360,7 +360,7 @@ static void __tcp_ecn_check_ce(struct sock *sk, const struct sk_buff *skb) tp->ecn_flags |= TCP_ECN_SEEN; break; default: - if (tcp_ca_needs_ecn(sk)) + if (tcp_ca_wants_ce_events(sk)) tcp_ca_event(sk, CA_EVENT_ECN_NO_CE); tp->ecn_flags |= TCP_ECN_SEEN; break; @@ -1079,7 +1079,12 @@ static void tcp_verify_retransmit_hint(struct tcp_sock *tp, struct sk_buff *skb) */ static void tcp_notify_skb_loss_event(struct tcp_sock *tp, const struct sk_buff *skb) { + struct sock *sk = (struct sock *)tp; + const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops; + tp->lost += tcp_skb_pcount(skb); + if (ca_ops->skb_marked_lost) + ca_ops->skb_marked_lost(sk, skb); } void tcp_mark_skb_lost(struct sock *sk, struct sk_buff *skb) @@ -1460,6 +1465,17 @@ static bool tcp_shifted_skb(struct sock *sk, struct sk_buff *prev, WARN_ON_ONCE(tcp_skb_pcount(skb) < pcount); tcp_skb_pcount_add(skb, -pcount); + /* Adjust tx.in_flight as pcount is shifted from skb to prev. */ + if (WARN_ONCE(TCP_SKB_CB(skb)->tx.in_flight < pcount, + "prev in_flight: %u skb in_flight: %u pcount: %u", + TCP_SKB_CB(prev)->tx.in_flight, + TCP_SKB_CB(skb)->tx.in_flight, + pcount)) + TCP_SKB_CB(skb)->tx.in_flight = 0; + else + TCP_SKB_CB(skb)->tx.in_flight -= pcount; + TCP_SKB_CB(prev)->tx.in_flight += pcount; + /* When we're adding to gso_segs == 1, gso_size will be zero, * in theory this shouldn't be necessary but as long as DSACK * code can come after this skb later on it's better to keep @@ -3812,6 +3828,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) prior_fack = tcp_is_sack(tp) ? tcp_highest_sack_seq(tp) : tp->snd_una; rs.prior_in_flight = tcp_packets_in_flight(tp); + tcp_rate_check_app_limited(sk); /* ts_recent update must be made after we are sure that the packet * is in window. @@ -3910,6 +3927,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) delivered = tcp_newly_delivered(sk, delivered, flag); lost = tp->lost - lost; /* freshly marked lost */ rs.is_ack_delayed = !!(flag & FLAG_ACK_MAYBE_DELAYED); + rs.is_ece = !!(flag & FLAG_ECE); tcp_rate_gen(sk, delivered, lost, is_sack_reneg, sack_state.rate); tcp_cong_control(sk, ack, delivered, flag, sack_state.rate); tcp_xmit_recovery(sk, rexmit); @@ -5509,13 +5527,14 @@ static void __tcp_ack_snd_check(struct sock *sk, int ofo_possible) /* More than one full frame received... */ if (((tp->rcv_nxt - tp->rcv_wup) > inet_csk(sk)->icsk_ack.rcv_mss && + (tp->fast_ack_mode == 1 || /* ... and right edge of window advances far enough. * (tcp_recvmsg() will send ACK otherwise). * If application uses SO_RCVLOWAT, we want send ack now if * we have not received enough bytes to satisfy the condition. */ - (tp->rcv_nxt - tp->copied_seq < sk->sk_rcvlowat || - __tcp_select_window(sk) >= tp->rcv_wnd)) || + (tp->rcv_nxt - tp->copied_seq < sk->sk_rcvlowat || + __tcp_select_window(sk) >= tp->rcv_wnd))) || /* We ACK each frame or... */ tcp_in_quickack_mode(sk) || /* Protocol state mandates a one-time immediate ACK */ diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index ef8013e2134f..1685356b7045 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -3169,6 +3169,13 @@ static int __net_init tcp_sk_init(struct net *net) net->ipv4.sysctl_tcp_fastopen_blackhole_timeout = 0; atomic_set(&net->ipv4.tfo_active_disable_times, 0); + /* Set default values for PLB */ + net->ipv4.sysctl_tcp_plb_enabled = 0; /* Disabled by default */ + net->ipv4.sysctl_tcp_plb_cong_thresh = 128; /* 50% congestion */ + net->ipv4.sysctl_tcp_plb_idle_rehash_rounds = 3; + net->ipv4.sysctl_tcp_plb_rehash_rounds = 12; + net->ipv4.sysctl_tcp_plb_suspend_rto_sec = 60; + /* Reno is always built in */ if (!net_eq(net, &init_net) && bpf_try_module_get(init_net.ipv4.tcp_congestion_control, diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c index c69f4d966024..a9ceec2702b2 100644 --- a/net/ipv4/tcp_output.c +++ b/net/ipv4/tcp_output.c @@ -375,7 +375,8 @@ static void tcp_ecn_send(struct sock *sk, struct sk_buff *skb, th->cwr = 1; skb_shinfo(skb)->gso_type |= SKB_GSO_TCP_ECN; } - } else if (!tcp_ca_needs_ecn(sk)) { + } else if (!(tp->ecn_flags & TCP_ECN_ECT_PERMANENT) && + !tcp_ca_needs_ecn(sk)) { /* ACK or retransmitted segment: clear ECT|CE */ INET_ECN_dontxmit(sk); } @@ -1533,7 +1534,7 @@ int tcp_fragment(struct sock *sk, enum tcp_queue tcp_queue, { struct tcp_sock *tp = tcp_sk(sk); struct sk_buff *buff; - int nsize, old_factor; + int nsize, old_factor, inflight_prev; long limit; int nlen; u8 flags; @@ -1610,6 +1611,15 @@ int tcp_fragment(struct sock *sk, enum tcp_queue tcp_queue, if (diff) tcp_adjust_pcount(sk, skb, diff); + + /* Set buff tx.in_flight as if buff were sent by itself. */ + inflight_prev = TCP_SKB_CB(skb)->tx.in_flight - old_factor; + if (WARN_ONCE(inflight_prev < 0, + "inconsistent: tx.in_flight: %u old_factor: %d", + TCP_SKB_CB(skb)->tx.in_flight, old_factor)) + inflight_prev = 0; + TCP_SKB_CB(buff)->tx.in_flight = inflight_prev + + tcp_skb_pcount(buff); } /* Link BUFF into the send queue. */ @@ -1993,13 +2003,12 @@ static u32 tcp_tso_autosize(const struct sock *sk, unsigned int mss_now, static u32 tcp_tso_segs(struct sock *sk, unsigned int mss_now) { const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops; - u32 min_tso, tso_segs; - - min_tso = ca_ops->min_tso_segs ? - ca_ops->min_tso_segs(sk) : - READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs); + u32 tso_segs; - tso_segs = tcp_tso_autosize(sk, mss_now, min_tso); + tso_segs = ca_ops->tso_segs ? + ca_ops->tso_segs(sk, mss_now) : + tcp_tso_autosize(sk, mss_now, + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs)); return min_t(u32, tso_segs, sk->sk_gso_max_segs); } @@ -2635,6 +2644,7 @@ static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, skb_set_delivery_time(skb, tp->tcp_wstamp_ns, true); list_move_tail(&skb->tcp_tsorted_anchor, &tp->tsorted_sent_queue); tcp_init_tso_segs(skb, mss_now); + tcp_set_tx_in_flight(sk, skb); goto repair; /* Skip network transmission */ } diff --git a/net/ipv4/tcp_plb.c b/net/ipv4/tcp_plb.c new file mode 100644 index 000000000000..71b02c0404ce --- /dev/null +++ b/net/ipv4/tcp_plb.c @@ -0,0 +1,100 @@ +/* Protective Load Balancing (PLB) + * + * PLB was designed to reduce link load imbalance across datacenter + * switches. PLB is a host-based optimization; it leverages congestion + * signals from the transport layer to randomly change the path of the + * connection experiencing sustained congestion. PLB prefers to repath + * after idle periods to minimize packet reordering. It repaths by + * changing the IPv6 Flow Label on the packets of a connection, which + * datacenter switches include as part of ECMP/WCMP hashing. + * + * PLB is described in detail in: + * + * Mubashir Adnan Qureshi, Yuchung Cheng, Qianwen Yin, Qiaobin Fu, + * Gautam Kumar, Masoud Moshref, Junhua Yan, Van Jacobson, + * David Wetherall,Abdul Kabbani: + * "PLB: Congestion Signals are Simple and Effective for + * Network Load Balancing" + * In ACM SIGCOMM 2022, Amsterdam Netherlands. + * + */ + +#include + +/* Called once per round-trip to update PLB state for a connection. */ +void tcp_plb_update_state(const struct sock *sk, struct tcp_plb_state *plb, + const int cong_ratio) +{ + struct net *net = sock_net(sk); + + if (!plb->enabled) + return; + + if (cong_ratio >= 0) { + if (cong_ratio < net->ipv4.sysctl_tcp_plb_cong_thresh) + plb->consec_cong_rounds = 0; + else if (plb->consec_cong_rounds < + net->ipv4.sysctl_tcp_plb_rehash_rounds) + plb->consec_cong_rounds++; + } +} +EXPORT_SYMBOL_GPL(tcp_plb_update_state); + +/* Check whether recent congestion has been persistent enough to warrant + * a load balancing decision that switches the connection to another path. + */ +void tcp_plb_check_rehash(struct sock *sk, struct tcp_plb_state *plb) +{ + struct net *net = sock_net(sk); + bool can_idle_rehash, can_force_rehash; + + if (!plb->enabled) + return; + + /* Note that tcp_jiffies32 can wrap, so we clear pause_until + * to 0 to indicate there is no recent RTO event that constrains + * PLB rehashing. + */ + if (plb->pause_until && + !before(tcp_jiffies32, plb->pause_until)) + plb->pause_until = 0; + + can_idle_rehash = net->ipv4.sysctl_tcp_plb_idle_rehash_rounds && + !tcp_sk(sk)->packets_out && + plb->consec_cong_rounds >= + net->ipv4.sysctl_tcp_plb_idle_rehash_rounds; + can_force_rehash = plb->consec_cong_rounds >= + net->ipv4.sysctl_tcp_plb_rehash_rounds; + + if (!plb->pause_until && (can_idle_rehash || can_force_rehash)) { + sk_rethink_txhash(sk); + plb->consec_cong_rounds = 0; + tcp_sk(sk)->ecn_rehash++; + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPECNREHASH); + } +} +EXPORT_SYMBOL_GPL(tcp_plb_check_rehash); + +/* Upon RTO, disallow load balancing for a while, to avoid having load + * balancing decisions switch traffic to a black-holed path that was + * previously avoided with a sk_rethink_txhash() call at RTO time. + */ +void tcp_plb_update_state_upon_rto(struct sock *sk, struct tcp_plb_state *plb) +{ + struct net *net = sock_net(sk); + u32 pause; + + if (!plb->enabled) + return; + + pause = net->ipv4.sysctl_tcp_plb_suspend_rto_sec * HZ; + pause += prandom_u32_max(pause); + plb->pause_until = tcp_jiffies32 + pause; + + /* Reset PLB state upon RTO, since an RTO causes a sk_rethink_txhash() call + * that may switch this connection to a path with completely different + * congestion characteristics. + */ + plb->consec_cong_rounds = 0; +} +EXPORT_SYMBOL_GPL(tcp_plb_update_state_upon_rto); diff --git a/net/ipv4/tcp_rate.c b/net/ipv4/tcp_rate.c index a8f6d9d06f2e..a8b4c9504570 100644 --- a/net/ipv4/tcp_rate.c +++ b/net/ipv4/tcp_rate.c @@ -34,6 +34,24 @@ * ready to send in the write queue. */ +void tcp_set_tx_in_flight(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 in_flight; + + /* Check, sanitize, and record packets in flight after skb was sent. */ + in_flight = tcp_packets_in_flight(tp) + tcp_skb_pcount(skb); + if (WARN_ONCE(in_flight > TCPCB_IN_FLIGHT_MAX, + "insane in_flight %u cc %s mss %u " + "cwnd %u pif %u %u %u %u\n", + in_flight, inet_csk(sk)->icsk_ca_ops->name, + tp->mss_cache, tp->snd_cwnd, + tp->packets_out, tp->retrans_out, + tp->sacked_out, tp->lost_out)) + in_flight = TCPCB_IN_FLIGHT_MAX; + TCP_SKB_CB(skb)->tx.in_flight = in_flight; +} + /* Snapshot the current delivery information in the skb, to generate * a rate sample later when the skb is (s)acked in tcp_rate_skb_delivered(). */ @@ -66,7 +84,9 @@ void tcp_rate_skb_sent(struct sock *sk, struct sk_buff *skb) TCP_SKB_CB(skb)->tx.delivered_mstamp = tp->delivered_mstamp; TCP_SKB_CB(skb)->tx.delivered = tp->delivered; TCP_SKB_CB(skb)->tx.delivered_ce = tp->delivered_ce; + TCP_SKB_CB(skb)->tx.lost = tp->lost; TCP_SKB_CB(skb)->tx.is_app_limited = tp->app_limited ? 1 : 0; + tcp_set_tx_in_flight(sk, skb); } /* When an skb is sacked or acked, we fill in the rate sample with the (prior) @@ -91,18 +111,21 @@ void tcp_rate_skb_delivered(struct sock *sk, struct sk_buff *skb, if (!rs->prior_delivered || tcp_skb_sent_after(tx_tstamp, tp->first_tx_mstamp, scb->end_seq, rs->last_end_seq)) { + rs->prior_lost = scb->tx.lost; rs->prior_delivered_ce = scb->tx.delivered_ce; rs->prior_delivered = scb->tx.delivered; rs->prior_mstamp = scb->tx.delivered_mstamp; rs->is_app_limited = scb->tx.is_app_limited; rs->is_retrans = scb->sacked & TCPCB_RETRANS; rs->last_end_seq = scb->end_seq; + rs->tx_in_flight = scb->tx.in_flight; /* Record send time of most recently ACKed packet: */ tp->first_tx_mstamp = tx_tstamp; /* Find the duration of the "send phase" of this window: */ - rs->interval_us = tcp_stamp_us_delta(tp->first_tx_mstamp, - scb->tx.first_tx_mstamp); + rs->interval_us = tcp_stamp32_us_delta( + tp->first_tx_mstamp, + scb->tx.first_tx_mstamp); } /* Mark off the skb delivered once it's sacked to avoid being @@ -144,6 +167,7 @@ void tcp_rate_gen(struct sock *sk, u32 delivered, u32 lost, return; } rs->delivered = tp->delivered - rs->prior_delivered; + rs->lost = tp->lost - rs->prior_lost; rs->delivered_ce = tp->delivered_ce - rs->prior_delivered_ce; /* delivered_ce occupies less than 32 bits in the skb control block */ @@ -155,7 +179,7 @@ void tcp_rate_gen(struct sock *sk, u32 delivered, u32 lost, * longer phase. */ snd_us = rs->interval_us; /* send phase */ - ack_us = tcp_stamp_us_delta(tp->tcp_mstamp, + ack_us = tcp_stamp32_us_delta(tp->tcp_mstamp, rs->prior_mstamp); /* ack phase */ rs->interval_us = max(snd_us, ack_us); diff --git a/net/ipv4/tcp_timer.c b/net/ipv4/tcp_timer.c index b4dfb82d6ecb..bb613fc8948a 100644 --- a/net/ipv4/tcp_timer.c +++ b/net/ipv4/tcp_timer.c @@ -605,6 +605,7 @@ void tcp_write_timer_handler(struct sock *sk) return; } + tcp_rate_check_app_limited(sk); tcp_mstamp_refresh(tcp_sk(sk)); event = icsk->icsk_pending; -- 2.39.0.rc2 From 7f653a9c345e2d778cce9cf0ecdd8edb0b61cd8b Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Wed, 5 Oct 2022 19:08:58 +0200 Subject: [PATCH 03/20] futex-winesync Signed-off-by: Peter Jung --- Documentation/admin-guide/devices.txt | 3 +- Documentation/userspace-api/index.rst | 1 + .../userspace-api/ioctl/ioctl-number.rst | 2 + Documentation/userspace-api/winesync.rst | 444 +++++ MAINTAINERS | 9 + drivers/misc/Kconfig | 11 + drivers/misc/Makefile | 1 + drivers/misc/winesync.c | 1212 ++++++++++++++ include/linux/miscdevice.h | 1 + include/uapi/linux/futex.h | 13 + include/uapi/linux/winesync.h | 71 + kernel/futex/syscalls.c | 75 +- tools/testing/selftests/Makefile | 1 + .../selftests/drivers/winesync/Makefile | 8 + .../testing/selftests/drivers/winesync/config | 1 + .../selftests/drivers/winesync/winesync.c | 1479 +++++++++++++++++ 16 files changed, 3330 insertions(+), 2 deletions(-) create mode 100644 Documentation/userspace-api/winesync.rst create mode 100644 drivers/misc/winesync.c create mode 100644 include/uapi/linux/winesync.h create mode 100644 tools/testing/selftests/drivers/winesync/Makefile create mode 100644 tools/testing/selftests/drivers/winesync/config create mode 100644 tools/testing/selftests/drivers/winesync/winesync.c diff --git a/Documentation/admin-guide/devices.txt b/Documentation/admin-guide/devices.txt index 9764d6edb189..a4696d3b4a5a 100644 --- a/Documentation/admin-guide/devices.txt +++ b/Documentation/admin-guide/devices.txt @@ -376,8 +376,9 @@ 240 = /dev/userio Serio driver testing device 241 = /dev/vhost-vsock Host kernel driver for virtio vsock 242 = /dev/rfkill Turning off radio transmissions (rfkill) + 243 = /dev/winesync Wine synchronization primitive device - 243-254 Reserved for local use + 244-254 Reserved for local use 255 Reserved for MISC_DYNAMIC_MINOR 11 char Raw keyboard device (Linux/SPARC only) diff --git a/Documentation/userspace-api/index.rst b/Documentation/userspace-api/index.rst index a61eac0c73f8..0bf697ddcb09 100644 --- a/Documentation/userspace-api/index.rst +++ b/Documentation/userspace-api/index.rst @@ -29,6 +29,7 @@ place where this information is gathered. sysfs-platform_profile vduse futex2 + winesync .. only:: subproject and html diff --git a/Documentation/userspace-api/ioctl/ioctl-number.rst b/Documentation/userspace-api/ioctl/ioctl-number.rst index 3b985b19f39d..3f313fd4338c 100644 --- a/Documentation/userspace-api/ioctl/ioctl-number.rst +++ b/Documentation/userspace-api/ioctl/ioctl-number.rst @@ -375,6 +375,8 @@ Code Seq# Include File Comments 0xF6 all LTTng Linux Trace Toolkit Next Generation +0xF7 00-0F uapi/linux/winesync.h Wine synchronization primitives + 0xF8 all arch/x86/include/uapi/asm/amd_hsmp.h AMD HSMP EPYC system management interface driver 0xFD all linux/dm-ioctl.h diff --git a/Documentation/userspace-api/winesync.rst b/Documentation/userspace-api/winesync.rst new file mode 100644 index 000000000000..d7055bf41820 --- /dev/null +++ b/Documentation/userspace-api/winesync.rst @@ -0,0 +1,444 @@ +===================================== +Wine synchronization primitive driver +===================================== + +This page documents the user-space API for the winesync driver. + +winesync is a support driver for emulation of NT synchronization +primitives by the Wine project or other NT emulators. It exists +because implementation in user-space, using existing tools, cannot +simultaneously satisfy performance, correctness, and security +constraints. It is implemented entirely in software, and does not +drive any hardware device. + +This interface is meant as a compatibility tool only, and should not +be used for general synchronization. Instead use generic, versatile +interfaces such as futex(2) and poll(2). + +Synchronization primitives +========================== + +The winesync driver exposes three types of synchronization primitives: +semaphores, mutexes, and events. + +A semaphore holds a single volatile 32-bit counter, and a static +32-bit integer denoting the maximum value. It is considered signaled +when the counter is nonzero. The counter is decremented by one when a +wait is satisfied. Both the initial and maximum count are established +when the semaphore is created. + +A mutex holds a volatile 32-bit recursion count, and a volatile 32-bit +identifier denoting its owner. A mutex is considered signaled when its +owner is zero (indicating that it is not owned). The recursion count +is incremented when a wait is satisfied, and ownership is set to the +given identifier. + +A mutex also holds an internal flag denoting whether its previous +owner has died; such a mutex is said to be inconsistent. Owner death +is not tracked automatically based on thread death, but rather must be +communicated using ``WINESYNC_IOC_KILL_OWNER``. An inconsistent mutex +is inherently considered unowned. + +Except for the "unowned" semantics of zero, the actual value of the +owner identifier is not interpreted by the winesync driver at all. The +intended use is to store a thread identifier; however, the winesync +driver does not actually validate that a calling thread provides +consistent or unique identifiers. + +An event holds a volatile boolean state denoting whether it is +signaled or not. There are two types of events, auto-reset and +manual-reset. An auto-reset event is designaled when a wait is +satisfied; a manual-reset event is not. The event type is specified +when the event is created. + +Unless specified otherwise, all operations on an object are atomic and +totally ordered with respect to other operations on the same object. + +Objects are represented by unsigned 32-bit integers. + +Char device +=========== + +The winesync driver creates a single char device /dev/winesync. Each +file description opened on the device represents a unique namespace. +That is, objects created on one open file description are shared +across all its individual descriptors, but are not shared with other +open() calls on the same device. The same file description may be +shared across multiple processes. + +ioctl reference +=============== + +All operations on the device are done through ioctls. There are three +structures used in ioctl calls:: + + struct winesync_sem_args { + __u32 sem; + __u32 count; + __u32 max; + }; + + struct winesync_mutex_args { + __u32 mutex; + __u32 owner; + __u32 count; + }; + + struct winesync_event_args { + __u32 event; + __u32 signaled; + __u32 manual; + }; + + struct winesync_wait_args { + __u64 timeout; + __u64 objs; + __u32 count; + __u32 owner; + __u32 index; + __u32 pad; + }; + +Depending on the ioctl, members of the structure may be used as input, +output, or not at all. All ioctls return 0 on success. + +The ioctls are as follows: + +.. c:macro:: WINESYNC_IOC_CREATE_SEM + + Create a semaphore object. Takes a pointer to struct + :c:type:`winesync_sem_args`, which is used as follows: + + .. list-table:: + + * - ``sem`` + - On output, contains the identifier of the created semaphore. + * - ``count`` + - Initial count of the semaphore. + * - ``max`` + - Maximum count of the semaphore. + + Fails with ``EINVAL`` if ``count`` is greater than ``max``. + +.. c:macro:: WINESYNC_IOC_CREATE_MUTEX + + Create a mutex object. Takes a pointer to struct + :c:type:`winesync_mutex_args`, which is used as follows: + + .. list-table:: + + * - ``mutex`` + - On output, contains the identifier of the created mutex. + * - ``count`` + - Initial recursion count of the mutex. + * - ``owner`` + - Initial owner of the mutex. + + If ``owner`` is nonzero and ``count`` is zero, or if ``owner`` is + zero and ``count`` is nonzero, the function fails with ``EINVAL``. + +.. c:macro:: WINESYNC_IOC_CREATE_EVENT + + Create an event object. Takes a pointer to struct + :c:type:`winesync_event_args`, which is used as follows: + + .. list-table:: + + * - ``event`` + - On output, contains the identifier of the created event. + * - ``signaled`` + - If nonzero, the event is initially signaled, otherwise + nonsignaled. + * - ``manual`` + - If nonzero, the event is a manual-reset event, otherwise + auto-reset. + +.. c:macro:: WINESYNC_IOC_DELETE + + Delete an object of any type. Takes an input-only pointer to a + 32-bit integer denoting the object to delete. + + Wait ioctls currently in progress are not interrupted, and behave as + if the object remains valid. + +.. c:macro:: WINESYNC_IOC_PUT_SEM + + Post to a semaphore object. Takes a pointer to struct + :c:type:`winesync_sem_args`, which is used as follows: + + .. list-table:: + + * - ``sem`` + - Semaphore object to post to. + * - ``count`` + - Count to add to the semaphore. On output, contains the + previous count of the semaphore. + * - ``max`` + - Not used. + + If adding ``count`` to the semaphore's current count would raise the + latter past the semaphore's maximum count, the ioctl fails with + ``EOVERFLOW`` and the semaphore is not affected. If raising the + semaphore's count causes it to become signaled, eligible threads + waiting on this semaphore will be woken and the semaphore's count + decremented appropriately. + +.. c:macro:: WINESYNC_IOC_PUT_MUTEX + + Release a mutex object. Takes a pointer to struct + :c:type:`winesync_mutex_args`, which is used as follows: + + .. list-table:: + + * - ``mutex`` + - Mutex object to release. + * - ``owner`` + - Mutex owner identifier. + * - ``count`` + - On output, contains the previous recursion count. + + If ``owner`` is zero, the ioctl fails with ``EINVAL``. If ``owner`` + is not the current owner of the mutex, the ioctl fails with + ``EPERM``. + + The mutex's count will be decremented by one. If decrementing the + mutex's count causes it to become zero, the mutex is marked as + unowned and signaled, and eligible threads waiting on it will be + woken as appropriate. + +.. c:macro:: WINESYNC_IOC_SET_EVENT + + Signal an event object. Takes a pointer to struct + :c:type:`winesync_event_args`, which is used as follows: + + .. list-table:: + + * - ``event`` + - Event object to set. + * - ``signaled`` + - On output, contains the previous state of the event. + * - ``manual`` + - Unused. + + Eligible threads will be woken, and auto-reset events will be + designaled appropriately. + +.. c:macro:: WINESYNC_IOC_RESET_EVENT + + Designal an event object. Takes a pointer to struct + :c:type:`winesync_event_args`, which is used as follows: + + .. list-table:: + + * - ``event`` + - Event object to reset. + * - ``signaled`` + - On output, contains the previous state of the event. + * - ``manual`` + - Unused. + +.. c:macro:: WINESYNC_IOC_PULSE_EVENT + + Wake threads waiting on an event object without leaving it in a + signaled state. Takes a pointer to struct + :c:type:`winesync_event_args`, which is used as follows: + + .. list-table:: + + * - ``event`` + - Event object to pulse. + * - ``signaled`` + - On output, contains the previous state of the event. + * - ``manual`` + - Unused. + + A pulse operation can be thought of as a set followed by a reset, + performed as a single atomic operation. If two threads are waiting + on an auto-reset event which is pulsed, only one will be woken. If + two threads are waiting a manual-reset event which is pulsed, both + will be woken. However, in both cases, the event will be unsignaled + afterwards, and a simultaneous read operation will always report the + event as unsignaled. + +.. c:macro:: WINESYNC_IOC_READ_SEM + + Read the current state of a semaphore object. Takes a pointer to + struct :c:type:`winesync_sem_args`, which is used as follows: + + .. list-table:: + + * - ``sem`` + - Semaphore object to read. + * - ``count`` + - On output, contains the current count of the semaphore. + * - ``max`` + - On output, contains the maximum count of the semaphore. + +.. c:macro:: WINESYNC_IOC_READ_MUTEX + + Read the current state of a mutex object. Takes a pointer to struct + :c:type:`winesync_mutex_args`, which is used as follows: + + .. list-table:: + + * - ``mutex`` + - Mutex object to read. + * - ``owner`` + - On output, contains the current owner of the mutex, or zero + if the mutex is not currently owned. + * - ``count`` + - On output, contains the current recursion count of the mutex. + + If the mutex is marked as inconsistent, the function fails with + ``EOWNERDEAD``. In this case, ``count`` and ``owner`` are set to + zero. + +.. c:macro:: WINESYNC_IOC_READ_EVENT + + Read the current state of an event object. Takes a pointer to struct + :c:type:`winesync_event_args`, which is used as follows: + + .. list-table:: + + * - ``event`` + - Event object. + * - ``signaled`` + - On output, contains the current state of the event. + * - ``manual`` + - On output, contains 1 if the event is a manual-reset event, + and 0 otherwise. + +.. c:macro:: WINESYNC_IOC_KILL_OWNER + + Mark any mutexes owned by the given owner as unowned and + inconsistent. Takes an input-only pointer to a 32-bit integer + denoting the owner. If the owner is zero, the ioctl fails with + ``EINVAL``. + + For each mutex currently owned by the given owner, eligible threads + waiting on said mutex will be woken as appropriate (and such waits + will fail with ``EOWNERDEAD``, as described below). + + The operation as a whole is not atomic; however, the modification of + each mutex is atomic and totally ordered with respect to other + operations on the same mutex. + +.. c:macro:: WINESYNC_IOC_WAIT_ANY + + Poll on any of a list of objects, atomically acquiring at most one. + Takes a pointer to struct :c:type:`winesync_wait_args`, which is + used as follows: + + .. list-table:: + + * - ``timeout`` + - Optional pointer to a 64-bit struct :c:type:`timespec` + (specified as an integer so that the structure has the same + size regardless of architecture). The timeout is specified in + absolute format, as measured against the MONOTONIC clock. If + the timeout is equal to or earlier than the current time, the + function returns immediately without sleeping. If ``timeout`` + is zero, i.e. NULL, the function will sleep until an object + is signaled, and will not fail with ``ETIMEDOUT``. + * - ``objs`` + - Pointer to an array of ``count`` 32-bit object identifiers + (specified as an integer so that the structure has the same + size regardless of architecture). If any identifier is + invalid, the function fails with ``EINVAL``. + * - ``count`` + - Number of object identifiers specified in the ``objs`` array. + * - ``owner`` + - Mutex owner identifier. If any object in ``objs`` is a mutex, + the ioctl will attempt to acquire that mutex on behalf of + ``owner``. If ``owner`` is zero, the ioctl fails with + ``EINVAL``. + * - ``index`` + - On success, contains the index (into ``objs``) of the object + which was signaled. If ``alert`` was signaled instead, + this contains ``count``. + * - ``alert`` + - Optional event object identifier. If nonzero, this specifies + an "alert" event object which, if signaled, will terminate + the wait. If nonzero, the identifier must point to a valid + event. + + This function attempts to acquire one of the given objects. If + unable to do so, it sleeps until an object becomes signaled, + subsequently acquiring it, or the timeout expires. In the latter + case the ioctl fails with ``ETIMEDOUT``. The function only acquires + one object, even if multiple objects are signaled. + + A semaphore is considered to be signaled if its count is nonzero, + and is acquired by decrementing its count by one. A mutex is + considered to be signaled if it is unowned or if its owner matches + the ``owner`` argument, and is acquired by incrementing its + recursion count by one and setting its owner to the ``owner`` + argument. An auto-reset event is acquired by designaling it; a + manual-reset event is not affected by acquisition. + + Acquisition is atomic and totally ordered with respect to other + operations on the same object. If two wait operations (with + different ``owner`` identifiers) are queued on the same mutex, only + one is signaled. If two wait operations are queued on the same + semaphore, and a value of one is posted to it, only one is signaled. + The order in which threads are signaled is not specified. + + If an inconsistent mutex is acquired, the ioctl fails with + ``EOWNERDEAD``. Although this is a failure return, the function may + otherwise be considered successful. The mutex is marked as owned by + the given owner (with a recursion count of 1) and as no longer + inconsistent, and ``index`` is still set to the index of the mutex. + + The ``alert`` argument is an "extra" event which can terminate the + wait, independently of all other objects. If members of ``objs`` and + ``alert`` are both simultaneously signaled, a member of ``objs`` + will always be given priority and acquired first. Aside from this, + for "any" waits, there is no difference between passing an event as + this parameter, and passing it as an additional object at the end of + the ``objs`` array. For "all" waits, there is an additional + difference, as described below. + + It is valid to pass the same object more than once, including by + passing the same event in the ``objs`` array and in ``alert``. If a + wakeup occurs due to that object being signaled, ``index`` is set to + the lowest index corresponding to that object. + + The function may fail with ``EINTR`` if a signal is received. + +.. c:macro:: WINESYNC_IOC_WAIT_ALL + + Poll on a list of objects, atomically acquiring all of them. Takes a + pointer to struct :c:type:`winesync_wait_args`, which is used + identically to ``WINESYNC_IOC_WAIT_ANY``, except that ``index`` is + always filled with zero on success if not woken via alert. + + This function attempts to simultaneously acquire all of the given + objects. If unable to do so, it sleeps until all objects become + simultaneously signaled, subsequently acquiring them, or the timeout + expires. In the latter case the ioctl fails with ``ETIMEDOUT`` and + no objects are modified. + + Objects may become signaled and subsequently designaled (through + acquisition by other threads) while this thread is sleeping. Only + once all objects are simultaneously signaled does the ioctl acquire + them and return. The entire acquisition is atomic and totally + ordered with respect to other operations on any of the given + objects. + + If an inconsistent mutex is acquired, the ioctl fails with + ``EOWNERDEAD``. Similarly to ``WINESYNC_IOC_WAIT_ANY``, all objects + are nevertheless marked as acquired. Note that if multiple mutex + objects are specified, there is no way to know which were marked as + inconsistent. + + As with "any" waits, the ``alert`` argument is an "extra" event + which can terminate the wait. Critically, however, an "all" wait + will succeed if all members in ``objs`` are signaled, *or* if + ``alert`` is signaled. In the latter case ``index`` will be set to + ``count``. As with "any" waits, if both conditions are filled, the + former takes priority, and objects in ``objs`` will be acquired. + + Unlike ``WINESYNC_IOC_WAIT_ANY``, it is not valid to pass the same + object more than once, nor is it valid to pass the same object in + ``objs`` and in ``alert`` If this is attempted, the function fails + with ``EINVAL``. diff --git a/MAINTAINERS b/MAINTAINERS index e3f4a32f28e4..c7461d56676c 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -21920,6 +21920,15 @@ M: David Härdeman S: Maintained F: drivers/media/rc/winbond-cir.c +WINESYNC SYNCHRONIZATION PRIMITIVE DRIVER +M: Zebediah Figura +L: wine-devel@winehq.org +S: Supported +F: Documentation/userspace-api/winesync.rst +F: drivers/misc/winesync.c +F: include/uapi/linux/winesync.h +F: tools/testing/selftests/drivers/winesync/ + WINSYSTEMS EBC-C384 WATCHDOG DRIVER M: William Breathitt Gray L: linux-watchdog@vger.kernel.org diff --git a/drivers/misc/Kconfig b/drivers/misc/Kconfig index 94e9fb4cdd76..bdf56cd530b7 100644 --- a/drivers/misc/Kconfig +++ b/drivers/misc/Kconfig @@ -483,6 +483,17 @@ config OPEN_DICE If unsure, say N. +config WINESYNC + tristate "Synchronization primitives for Wine" + help + This module provides kernel support for synchronization primitives + used by Wine. It is not a hardware driver. + + To compile this driver as a module, choose M here: the + module will be called winesync. + + If unsure, say N. + config VCPU_STALL_DETECTOR tristate "Guest vCPU stall detector" depends on OF && HAS_IOMEM diff --git a/drivers/misc/Makefile b/drivers/misc/Makefile index 2be8542616dd..d061fe45407b 100644 --- a/drivers/misc/Makefile +++ b/drivers/misc/Makefile @@ -58,6 +58,7 @@ obj-$(CONFIG_HABANA_AI) += habanalabs/ obj-$(CONFIG_UACCE) += uacce/ obj-$(CONFIG_XILINX_SDFEC) += xilinx_sdfec.o obj-$(CONFIG_HISI_HIKEY_USB) += hisi_hikey_usb.o +obj-$(CONFIG_WINESYNC) += winesync.o obj-$(CONFIG_HI6421V600_IRQ) += hi6421v600-irq.o obj-$(CONFIG_OPEN_DICE) += open-dice.o obj-$(CONFIG_VCPU_STALL_DETECTOR) += vcpu_stall_detector.o \ No newline at end of file diff --git a/drivers/misc/winesync.c b/drivers/misc/winesync.c new file mode 100644 index 000000000000..7a28f58dbbf2 --- /dev/null +++ b/drivers/misc/winesync.c @@ -0,0 +1,1212 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * winesync.c - Kernel driver for Wine synchronization primitives + * + * Copyright (C) 2021 Zebediah Figura + */ + +#include +#include +#include +#include +#include +#include + +#define WINESYNC_NAME "winesync" + +enum winesync_type { + WINESYNC_TYPE_SEM, + WINESYNC_TYPE_MUTEX, + WINESYNC_TYPE_EVENT, +}; + +struct winesync_obj { + struct rcu_head rhead; + struct kref refcount; + spinlock_t lock; + + /* + * any_waiters is protected by the object lock, but all_waiters is + * protected by the device wait_all_lock. + */ + struct list_head any_waiters; + struct list_head all_waiters; + + /* + * Hint describing how many tasks are queued on this object in a + * wait-all operation. + * + * Any time we do a wake, we may need to wake "all" waiters as well as + * "any" waiters. In order to atomically wake "all" waiters, we must + * lock all of the objects, and that means grabbing the wait_all_lock + * below (and, due to lock ordering rules, before locking this object). + * However, wait-all is a rare operation, and grabbing the wait-all + * lock for every wake would create unnecessary contention. Therefore we + * first check whether all_hint is zero, and, if it is, we skip trying + * to wake "all" waiters. + * + * This hint isn't protected by any lock. It might change during the + * course of a wake, but there's no meaningful race there; it's only a + * hint. + * + * Since wait requests must originate from user-space threads, we're + * limited here by PID_MAX_LIMIT, so there's no risk of saturation. + */ + atomic_t all_hint; + + enum winesync_type type; + + /* The following fields are protected by the object lock. */ + union { + struct { + __u32 count; + __u32 max; + } sem; + struct { + __u32 count; + __u32 owner; + bool ownerdead; + } mutex; + struct { + bool manual; + bool signaled; + } event; + } u; +}; + +struct winesync_q_entry { + struct list_head node; + struct winesync_q *q; + struct winesync_obj *obj; + __u32 index; +}; + +struct winesync_q { + struct task_struct *task; + __u32 owner; + + /* + * Protected via atomic_cmpxchg(). Only the thread that wins the + * compare-and-swap may actually change object states and wake this + * task. + */ + atomic_t signaled; + + bool all; + bool ownerdead; + __u32 count; + struct winesync_q_entry entries[]; +}; + +struct winesync_device { + /* + * Wait-all operations must atomically grab all objects, and be totally + * ordered with respect to each other and wait-any operations. If one + * thread is trying to acquire several objects, another thread cannot + * touch the object at the same time. + * + * We achieve this by grabbing multiple object locks at the same time. + * However, this creates a lock ordering problem. To solve that problem, + * wait_all_lock is taken first whenever multiple objects must be locked + * at the same time. + */ + spinlock_t wait_all_lock; + + struct xarray objects; +}; + +static struct winesync_obj *get_obj(struct winesync_device *dev, __u32 id) +{ + struct winesync_obj *obj; + + rcu_read_lock(); + obj = xa_load(&dev->objects, id); + if (obj && !kref_get_unless_zero(&obj->refcount)) + obj = NULL; + rcu_read_unlock(); + + return obj; +} + +static void destroy_obj(struct kref *ref) +{ + struct winesync_obj *obj = container_of(ref, struct winesync_obj, refcount); + + kfree_rcu(obj, rhead); +} + +static void put_obj(struct winesync_obj *obj) +{ + kref_put(&obj->refcount, destroy_obj); +} + +static struct winesync_obj *get_obj_typed(struct winesync_device *dev, __u32 id, + enum winesync_type type) +{ + struct winesync_obj *obj = get_obj(dev, id); + + if (obj && obj->type != type) { + put_obj(obj); + return NULL; + } + return obj; +} + +static int winesync_char_open(struct inode *inode, struct file *file) +{ + struct winesync_device *dev; + + dev = kzalloc(sizeof(*dev), GFP_KERNEL); + if (!dev) + return -ENOMEM; + + spin_lock_init(&dev->wait_all_lock); + + xa_init_flags(&dev->objects, XA_FLAGS_ALLOC); + + file->private_data = dev; + return nonseekable_open(inode, file); +} + +static int winesync_char_release(struct inode *inode, struct file *file) +{ + struct winesync_device *dev = file->private_data; + struct winesync_obj *obj; + unsigned long id; + + xa_for_each(&dev->objects, id, obj) + put_obj(obj); + + xa_destroy(&dev->objects); + + kfree(dev); + + return 0; +} + +static void init_obj(struct winesync_obj *obj) +{ + kref_init(&obj->refcount); + atomic_set(&obj->all_hint, 0); + spin_lock_init(&obj->lock); + INIT_LIST_HEAD(&obj->any_waiters); + INIT_LIST_HEAD(&obj->all_waiters); +} + +static bool is_signaled(struct winesync_obj *obj, __u32 owner) +{ + lockdep_assert_held(&obj->lock); + + switch (obj->type) { + case WINESYNC_TYPE_SEM: + return !!obj->u.sem.count; + case WINESYNC_TYPE_MUTEX: + if (obj->u.mutex.owner && obj->u.mutex.owner != owner) + return false; + return obj->u.mutex.count < UINT_MAX; + case WINESYNC_TYPE_EVENT: + return obj->u.event.signaled; + } + + WARN(1, "bad object type %#x\n", obj->type); + return false; +} + +/* + * "locked_obj" is an optional pointer to an object which is already locked and + * should not be locked again. This is necessary so that changing an object's + * state and waking it can be a single atomic operation. + */ +static void try_wake_all(struct winesync_device *dev, struct winesync_q *q, + struct winesync_obj *locked_obj) +{ + __u32 count = q->count; + bool can_wake = true; + __u32 i; + + lockdep_assert_held(&dev->wait_all_lock); + if (locked_obj) + lockdep_assert_held(&locked_obj->lock); + + for (i = 0; i < count; i++) { + if (q->entries[i].obj != locked_obj) + spin_lock(&q->entries[i].obj->lock); + } + + for (i = 0; i < count; i++) { + if (!is_signaled(q->entries[i].obj, q->owner)) { + can_wake = false; + break; + } + } + + if (can_wake && atomic_cmpxchg(&q->signaled, -1, 0) == -1) { + for (i = 0; i < count; i++) { + struct winesync_obj *obj = q->entries[i].obj; + + switch (obj->type) { + case WINESYNC_TYPE_SEM: + obj->u.sem.count--; + break; + case WINESYNC_TYPE_MUTEX: + if (obj->u.mutex.ownerdead) + q->ownerdead = true; + obj->u.mutex.ownerdead = false; + obj->u.mutex.count++; + obj->u.mutex.owner = q->owner; + break; + case WINESYNC_TYPE_EVENT: + if (!obj->u.event.manual) + obj->u.event.signaled = false; + break; + } + } + wake_up_process(q->task); + } + + for (i = 0; i < count; i++) { + if (q->entries[i].obj != locked_obj) + spin_unlock(&q->entries[i].obj->lock); + } +} + +static void try_wake_all_obj(struct winesync_device *dev, + struct winesync_obj *obj) +{ + struct winesync_q_entry *entry; + + lockdep_assert_held(&dev->wait_all_lock); + lockdep_assert_held(&obj->lock); + + list_for_each_entry(entry, &obj->all_waiters, node) + try_wake_all(dev, entry->q, obj); +} + +static void try_wake_any_sem(struct winesync_obj *sem) +{ + struct winesync_q_entry *entry; + + lockdep_assert_held(&sem->lock); + + list_for_each_entry(entry, &sem->any_waiters, node) { + struct winesync_q *q = entry->q; + + if (!sem->u.sem.count) + break; + + if (atomic_cmpxchg(&q->signaled, -1, entry->index) == -1) { + sem->u.sem.count--; + wake_up_process(q->task); + } + } +} + +static void try_wake_any_mutex(struct winesync_obj *mutex) +{ + struct winesync_q_entry *entry; + + lockdep_assert_held(&mutex->lock); + + list_for_each_entry(entry, &mutex->any_waiters, node) { + struct winesync_q *q = entry->q; + + if (mutex->u.mutex.count == UINT_MAX) + break; + if (mutex->u.mutex.owner && mutex->u.mutex.owner != q->owner) + continue; + + if (atomic_cmpxchg(&q->signaled, -1, entry->index) == -1) { + if (mutex->u.mutex.ownerdead) + q->ownerdead = true; + mutex->u.mutex.ownerdead = false; + mutex->u.mutex.count++; + mutex->u.mutex.owner = q->owner; + wake_up_process(q->task); + } + } +} + +static void try_wake_any_event(struct winesync_obj *event) +{ + struct winesync_q_entry *entry; + + lockdep_assert_held(&event->lock); + + list_for_each_entry(entry, &event->any_waiters, node) { + struct winesync_q *q = entry->q; + + if (!event->u.event.signaled) + break; + + if (atomic_cmpxchg(&q->signaled, -1, entry->index) == -1) { + if (!event->u.event.manual) + event->u.event.signaled = false; + wake_up_process(q->task); + } + } +} + +static int winesync_create_sem(struct winesync_device *dev, void __user *argp) +{ + struct winesync_sem_args __user *user_args = argp; + struct winesync_sem_args args; + struct winesync_obj *sem; + __u32 id; + int ret; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + if (args.count > args.max) + return -EINVAL; + + sem = kzalloc(sizeof(*sem), GFP_KERNEL); + if (!sem) + return -ENOMEM; + + init_obj(sem); + sem->type = WINESYNC_TYPE_SEM; + sem->u.sem.count = args.count; + sem->u.sem.max = args.max; + + ret = xa_alloc(&dev->objects, &id, sem, xa_limit_32b, GFP_KERNEL); + if (ret < 0) { + kfree(sem); + return ret; + } + + return put_user(id, &user_args->sem); +} + +static int winesync_create_mutex(struct winesync_device *dev, void __user *argp) +{ + struct winesync_mutex_args __user *user_args = argp; + struct winesync_mutex_args args; + struct winesync_obj *mutex; + __u32 id; + int ret; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + if (!args.owner != !args.count) + return -EINVAL; + + mutex = kzalloc(sizeof(*mutex), GFP_KERNEL); + if (!mutex) + return -ENOMEM; + + init_obj(mutex); + mutex->type = WINESYNC_TYPE_MUTEX; + mutex->u.mutex.count = args.count; + mutex->u.mutex.owner = args.owner; + + ret = xa_alloc(&dev->objects, &id, mutex, xa_limit_32b, GFP_KERNEL); + if (ret < 0) { + kfree(mutex); + return ret; + } + + return put_user(id, &user_args->mutex); +} + +static int winesync_create_event(struct winesync_device *dev, void __user *argp) +{ + struct winesync_event_args __user *user_args = argp; + struct winesync_event_args args; + struct winesync_obj *event; + __u32 id; + int ret; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + event = kzalloc(sizeof(*event), GFP_KERNEL); + if (!event) + return -ENOMEM; + + init_obj(event); + event->type = WINESYNC_TYPE_EVENT; + event->u.event.manual = args.manual; + event->u.event.signaled = args.signaled; + + ret = xa_alloc(&dev->objects, &id, event, xa_limit_32b, GFP_KERNEL); + if (ret < 0) { + kfree(event); + return ret; + } + + return put_user(id, &user_args->event); +} + +static int winesync_delete(struct winesync_device *dev, void __user *argp) +{ + struct winesync_obj *obj; + __u32 id; + + if (get_user(id, (__u32 __user *)argp)) + return -EFAULT; + + obj = xa_erase(&dev->objects, id); + if (!obj) + return -EINVAL; + + put_obj(obj); + return 0; +} + +/* + * Actually change the semaphore state, returning -EOVERFLOW if it is made + * invalid. + */ +static int put_sem_state(struct winesync_obj *sem, __u32 count) +{ + lockdep_assert_held(&sem->lock); + + if (sem->u.sem.count + count < sem->u.sem.count || + sem->u.sem.count + count > sem->u.sem.max) + return -EOVERFLOW; + + sem->u.sem.count += count; + return 0; +} + +static int winesync_put_sem(struct winesync_device *dev, void __user *argp) +{ + struct winesync_sem_args __user *user_args = argp; + struct winesync_sem_args args; + struct winesync_obj *sem; + __u32 prev_count; + int ret; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + sem = get_obj_typed(dev, args.sem, WINESYNC_TYPE_SEM); + if (!sem) + return -EINVAL; + + if (atomic_read(&sem->all_hint) > 0) { + spin_lock(&dev->wait_all_lock); + spin_lock(&sem->lock); + + prev_count = sem->u.sem.count; + ret = put_sem_state(sem, args.count); + if (!ret) { + try_wake_all_obj(dev, sem); + try_wake_any_sem(sem); + } + + spin_unlock(&sem->lock); + spin_unlock(&dev->wait_all_lock); + } else { + spin_lock(&sem->lock); + + prev_count = sem->u.sem.count; + ret = put_sem_state(sem, args.count); + if (!ret) + try_wake_any_sem(sem); + + spin_unlock(&sem->lock); + } + + put_obj(sem); + + if (!ret && put_user(prev_count, &user_args->count)) + ret = -EFAULT; + + return ret; +} + +/* + * Actually change the mutex state, returning -EPERM if not the owner. + */ +static int put_mutex_state(struct winesync_obj *mutex, + const struct winesync_mutex_args *args) +{ + lockdep_assert_held(&mutex->lock); + + if (mutex->u.mutex.owner != args->owner) + return -EPERM; + + if (!--mutex->u.mutex.count) + mutex->u.mutex.owner = 0; + return 0; +} + +static int winesync_put_mutex(struct winesync_device *dev, void __user *argp) +{ + struct winesync_mutex_args __user *user_args = argp; + struct winesync_mutex_args args; + struct winesync_obj *mutex; + __u32 prev_count; + int ret; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + if (!args.owner) + return -EINVAL; + + mutex = get_obj_typed(dev, args.mutex, WINESYNC_TYPE_MUTEX); + if (!mutex) + return -EINVAL; + + if (atomic_read(&mutex->all_hint) > 0) { + spin_lock(&dev->wait_all_lock); + spin_lock(&mutex->lock); + + prev_count = mutex->u.mutex.count; + ret = put_mutex_state(mutex, &args); + if (!ret) { + try_wake_all_obj(dev, mutex); + try_wake_any_mutex(mutex); + } + + spin_unlock(&mutex->lock); + spin_unlock(&dev->wait_all_lock); + } else { + spin_lock(&mutex->lock); + + prev_count = mutex->u.mutex.count; + ret = put_mutex_state(mutex, &args); + if (!ret) + try_wake_any_mutex(mutex); + + spin_unlock(&mutex->lock); + } + + put_obj(mutex); + + if (!ret && put_user(prev_count, &user_args->count)) + ret = -EFAULT; + + return ret; +} + +static int winesync_read_sem(struct winesync_device *dev, void __user *argp) +{ + struct winesync_sem_args __user *user_args = argp; + struct winesync_sem_args args; + struct winesync_obj *sem; + __u32 id; + + if (get_user(id, &user_args->sem)) + return -EFAULT; + + sem = get_obj_typed(dev, id, WINESYNC_TYPE_SEM); + if (!sem) + return -EINVAL; + + args.sem = id; + spin_lock(&sem->lock); + args.count = sem->u.sem.count; + args.max = sem->u.sem.max; + spin_unlock(&sem->lock); + + put_obj(sem); + + if (copy_to_user(user_args, &args, sizeof(args))) + return -EFAULT; + return 0; +} + +static int winesync_read_mutex(struct winesync_device *dev, void __user *argp) +{ + struct winesync_mutex_args __user *user_args = argp; + struct winesync_mutex_args args; + struct winesync_obj *mutex; + __u32 id; + int ret; + + if (get_user(id, &user_args->mutex)) + return -EFAULT; + + mutex = get_obj_typed(dev, id, WINESYNC_TYPE_MUTEX); + if (!mutex) + return -EINVAL; + + args.mutex = id; + spin_lock(&mutex->lock); + args.count = mutex->u.mutex.count; + args.owner = mutex->u.mutex.owner; + ret = mutex->u.mutex.ownerdead ? -EOWNERDEAD : 0; + spin_unlock(&mutex->lock); + + put_obj(mutex); + + if (copy_to_user(user_args, &args, sizeof(args))) + return -EFAULT; + return ret; +} + +static int winesync_read_event(struct winesync_device *dev, void __user *argp) +{ + struct winesync_event_args __user *user_args = argp; + struct winesync_event_args args; + struct winesync_obj *event; + __u32 id; + + if (get_user(id, &user_args->event)) + return -EFAULT; + + event = get_obj_typed(dev, id, WINESYNC_TYPE_EVENT); + if (!event) + return -EINVAL; + + args.event = id; + spin_lock(&event->lock); + args.manual = event->u.event.manual; + args.signaled = event->u.event.signaled; + spin_unlock(&event->lock); + + put_obj(event); + + if (copy_to_user(user_args, &args, sizeof(args))) + return -EFAULT; + return 0; +} + +/* + * Actually change the mutex state to mark its owner as dead. + */ +static void put_mutex_ownerdead_state(struct winesync_obj *mutex) +{ + lockdep_assert_held(&mutex->lock); + + mutex->u.mutex.ownerdead = true; + mutex->u.mutex.owner = 0; + mutex->u.mutex.count = 0; +} + +static int winesync_kill_owner(struct winesync_device *dev, void __user *argp) +{ + struct winesync_obj *obj; + unsigned long id; + __u32 owner; + + if (get_user(owner, (__u32 __user *)argp)) + return -EFAULT; + if (!owner) + return -EINVAL; + + rcu_read_lock(); + + xa_for_each(&dev->objects, id, obj) { + if (!kref_get_unless_zero(&obj->refcount)) + continue; + + if (obj->type != WINESYNC_TYPE_MUTEX) { + put_obj(obj); + continue; + } + + if (atomic_read(&obj->all_hint) > 0) { + spin_lock(&dev->wait_all_lock); + spin_lock(&obj->lock); + + if (obj->u.mutex.owner == owner) { + put_mutex_ownerdead_state(obj); + try_wake_all_obj(dev, obj); + try_wake_any_mutex(obj); + } + + spin_unlock(&obj->lock); + spin_unlock(&dev->wait_all_lock); + } else { + spin_lock(&obj->lock); + + if (obj->u.mutex.owner == owner) { + put_mutex_ownerdead_state(obj); + try_wake_any_mutex(obj); + } + + spin_unlock(&obj->lock); + } + + put_obj(obj); + } + + rcu_read_unlock(); + + return 0; +} + +static int winesync_set_event(struct winesync_device *dev, void __user *argp, + bool pulse) +{ + struct winesync_event_args __user *user_args = argp; + struct winesync_event_args args; + struct winesync_obj *event; + bool prev_state; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + event = get_obj_typed(dev, args.event, WINESYNC_TYPE_EVENT); + if (!event) + return -EINVAL; + + if (atomic_read(&event->all_hint) > 0) { + spin_lock(&dev->wait_all_lock); + spin_lock(&event->lock); + + prev_state = event->u.event.signaled; + event->u.event.signaled = true; + try_wake_all_obj(dev, event); + try_wake_any_event(event); + if (pulse) + event->u.event.signaled = false; + + spin_unlock(&event->lock); + spin_unlock(&dev->wait_all_lock); + } else { + spin_lock(&event->lock); + + prev_state = event->u.event.signaled; + event->u.event.signaled = true; + try_wake_any_event(event); + if (pulse) + event->u.event.signaled = false; + + spin_unlock(&event->lock); + } + + put_obj(event); + + if (put_user(prev_state, &user_args->signaled)) + return -EFAULT; + + return 0; +} + +static int winesync_reset_event(struct winesync_device *dev, void __user *argp) +{ + struct winesync_event_args __user *user_args = argp; + struct winesync_event_args args; + struct winesync_obj *event; + bool prev_state; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + event = get_obj_typed(dev, args.event, WINESYNC_TYPE_EVENT); + if (!event) + return -EINVAL; + + spin_lock(&event->lock); + + prev_state = event->u.event.signaled; + event->u.event.signaled = false; + + spin_unlock(&event->lock); + + put_obj(event); + + if (put_user(prev_state, &user_args->signaled)) + return -EFAULT; + + return 0; +} + +static int winesync_schedule(const struct winesync_q *q, ktime_t *timeout) +{ + int ret = 0; + + do { + if (signal_pending(current)) { + ret = -ERESTARTSYS; + break; + } + + set_current_state(TASK_INTERRUPTIBLE); + if (atomic_read(&q->signaled) != -1) { + ret = 0; + break; + } + ret = schedule_hrtimeout(timeout, HRTIMER_MODE_ABS); + } while (ret < 0); + __set_current_state(TASK_RUNNING); + + return ret; +} + +/* + * Allocate and initialize the winesync_q structure, but do not queue us yet. + * Also, calculate the relative timeout. + */ +static int setup_wait(struct winesync_device *dev, + const struct winesync_wait_args *args, bool all, + ktime_t *ret_timeout, struct winesync_q **ret_q) +{ + const __u32 count = args->count; + struct winesync_q *q; + ktime_t timeout = 0; + __u32 total_count; + __u32 *ids; + __u32 i, j; + + if (!args->owner) + return -EINVAL; + + if (args->timeout) { + struct timespec64 to; + + if (get_timespec64(&to, u64_to_user_ptr(args->timeout))) + return -EFAULT; + if (!timespec64_valid(&to)) + return -EINVAL; + + timeout = timespec64_to_ns(&to); + } + + total_count = count; + if (args->alert) + total_count++; + + ids = kmalloc_array(total_count, sizeof(*ids), GFP_KERNEL); + if (!ids) + return -ENOMEM; + if (copy_from_user(ids, u64_to_user_ptr(args->objs), + array_size(count, sizeof(*ids)))) { + kfree(ids); + return -EFAULT; + } + if (args->alert) + ids[count] = args->alert; + + q = kmalloc(struct_size(q, entries, total_count), GFP_KERNEL); + if (!q) { + kfree(ids); + return -ENOMEM; + } + q->task = current; + q->owner = args->owner; + atomic_set(&q->signaled, -1); + q->all = all; + q->ownerdead = false; + q->count = count; + + for (i = 0; i < total_count; i++) { + struct winesync_q_entry *entry = &q->entries[i]; + struct winesync_obj *obj = get_obj(dev, ids[i]); + + if (!obj) + goto err; + + if (all) { + /* Check that the objects are all distinct. */ + for (j = 0; j < i; j++) { + if (obj == q->entries[j].obj) { + put_obj(obj); + goto err; + } + } + } + + entry->obj = obj; + entry->q = q; + entry->index = i; + } + + kfree(ids); + + *ret_q = q; + *ret_timeout = timeout; + return 0; + +err: + for (j = 0; j < i; j++) + put_obj(q->entries[j].obj); + kfree(ids); + kfree(q); + return -EINVAL; +} + +static void try_wake_any_obj(struct winesync_obj *obj) +{ + switch (obj->type) { + case WINESYNC_TYPE_SEM: + try_wake_any_sem(obj); + break; + case WINESYNC_TYPE_MUTEX: + try_wake_any_mutex(obj); + break; + case WINESYNC_TYPE_EVENT: + try_wake_any_event(obj); + break; + } +} + +static int winesync_wait_any(struct winesync_device *dev, void __user *argp) +{ + struct winesync_wait_args args; + struct winesync_q *q; + __u32 i, total_count; + ktime_t timeout; + int signaled; + int ret; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + ret = setup_wait(dev, &args, false, &timeout, &q); + if (ret < 0) + return ret; + + total_count = args.count; + if (args.alert) + total_count++; + + /* queue ourselves */ + + for (i = 0; i < total_count; i++) { + struct winesync_q_entry *entry = &q->entries[i]; + struct winesync_obj *obj = entry->obj; + + spin_lock(&obj->lock); + list_add_tail(&entry->node, &obj->any_waiters); + spin_unlock(&obj->lock); + } + + /* + * Check if we are already signaled. + * + * Note that the API requires that normal objects are checked before + * the alert event. Hence we queue the alert event last, and check + * objects in order. + */ + + for (i = 0; i < total_count; i++) { + struct winesync_obj *obj = q->entries[i].obj; + + if (atomic_read(&q->signaled) != -1) + break; + + spin_lock(&obj->lock); + try_wake_any_obj(obj); + spin_unlock(&obj->lock); + } + + /* sleep */ + + ret = winesync_schedule(q, args.timeout ? &timeout : NULL); + + /* and finally, unqueue */ + + for (i = 0; i < total_count; i++) { + struct winesync_q_entry *entry = &q->entries[i]; + struct winesync_obj *obj = entry->obj; + + spin_lock(&obj->lock); + list_del(&entry->node); + spin_unlock(&obj->lock); + + put_obj(obj); + } + + signaled = atomic_read(&q->signaled); + if (signaled != -1) { + struct winesync_wait_args __user *user_args = argp; + + /* even if we caught a signal, we need to communicate success */ + ret = q->ownerdead ? -EOWNERDEAD : 0; + + if (put_user(signaled, &user_args->index)) + ret = -EFAULT; + } else if (!ret) { + ret = -ETIMEDOUT; + } + + kfree(q); + return ret; +} + +static int winesync_wait_all(struct winesync_device *dev, void __user *argp) +{ + struct winesync_wait_args args; + struct winesync_q *q; + ktime_t timeout; + int signaled; + __u32 i; + int ret; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + ret = setup_wait(dev, &args, true, &timeout, &q); + if (ret < 0) + return ret; + + /* queue ourselves */ + + spin_lock(&dev->wait_all_lock); + + for (i = 0; i < args.count; i++) { + struct winesync_q_entry *entry = &q->entries[i]; + struct winesync_obj *obj = entry->obj; + + atomic_inc(&obj->all_hint); + + /* + * obj->all_waiters is protected by dev->wait_all_lock rather + * than obj->lock, so there is no need to acquire it here. + */ + list_add_tail(&entry->node, &obj->all_waiters); + } + if (args.alert) { + struct winesync_q_entry *entry = &q->entries[args.count]; + struct winesync_obj *obj = entry->obj; + + spin_lock(&obj->lock); + list_add_tail(&entry->node, &obj->any_waiters); + spin_unlock(&obj->lock); + } + + /* check if we are already signaled */ + + try_wake_all(dev, q, NULL); + + spin_unlock(&dev->wait_all_lock); + + /* + * Check if the alert event is signaled, making sure to do so only + * after checking if the other objects are signaled. + */ + + if (args.alert) { + struct winesync_obj *obj = q->entries[args.count].obj; + + if (atomic_read(&q->signaled) == -1) { + spin_lock(&obj->lock); + try_wake_any_obj(obj); + spin_unlock(&obj->lock); + } + } + + /* sleep */ + + ret = winesync_schedule(q, args.timeout ? &timeout : NULL); + + /* and finally, unqueue */ + + spin_lock(&dev->wait_all_lock); + + for (i = 0; i < args.count; i++) { + struct winesync_q_entry *entry = &q->entries[i]; + struct winesync_obj *obj = entry->obj; + + /* + * obj->all_waiters is protected by dev->wait_all_lock rather + * than obj->lock, so there is no need to acquire it here. + */ + list_del(&entry->node); + + atomic_dec(&obj->all_hint); + + put_obj(obj); + } + if (args.alert) { + struct winesync_q_entry *entry = &q->entries[args.count]; + struct winesync_obj *obj = entry->obj; + + spin_lock(&obj->lock); + list_del(&entry->node); + spin_unlock(&obj->lock); + + put_obj(obj); + } + + spin_unlock(&dev->wait_all_lock); + + signaled = atomic_read(&q->signaled); + if (signaled != -1) { + struct winesync_wait_args __user *user_args = argp; + + /* even if we caught a signal, we need to communicate success */ + ret = q->ownerdead ? -EOWNERDEAD : 0; + + if (put_user(signaled, &user_args->index)) + ret = -EFAULT; + } else if (!ret) { + ret = -ETIMEDOUT; + } + + kfree(q); + return ret; +} + +static long winesync_char_ioctl(struct file *file, unsigned int cmd, + unsigned long parm) +{ + struct winesync_device *dev = file->private_data; + void __user *argp = (void __user *)parm; + + switch (cmd) { + case WINESYNC_IOC_CREATE_EVENT: + return winesync_create_event(dev, argp); + case WINESYNC_IOC_CREATE_MUTEX: + return winesync_create_mutex(dev, argp); + case WINESYNC_IOC_CREATE_SEM: + return winesync_create_sem(dev, argp); + case WINESYNC_IOC_DELETE: + return winesync_delete(dev, argp); + case WINESYNC_IOC_KILL_OWNER: + return winesync_kill_owner(dev, argp); + case WINESYNC_IOC_PULSE_EVENT: + return winesync_set_event(dev, argp, true); + case WINESYNC_IOC_PUT_MUTEX: + return winesync_put_mutex(dev, argp); + case WINESYNC_IOC_PUT_SEM: + return winesync_put_sem(dev, argp); + case WINESYNC_IOC_READ_EVENT: + return winesync_read_event(dev, argp); + case WINESYNC_IOC_READ_MUTEX: + return winesync_read_mutex(dev, argp); + case WINESYNC_IOC_READ_SEM: + return winesync_read_sem(dev, argp); + case WINESYNC_IOC_RESET_EVENT: + return winesync_reset_event(dev, argp); + case WINESYNC_IOC_SET_EVENT: + return winesync_set_event(dev, argp, false); + case WINESYNC_IOC_WAIT_ALL: + return winesync_wait_all(dev, argp); + case WINESYNC_IOC_WAIT_ANY: + return winesync_wait_any(dev, argp); + default: + return -ENOSYS; + } +} + +static const struct file_operations winesync_fops = { + .owner = THIS_MODULE, + .open = winesync_char_open, + .release = winesync_char_release, + .unlocked_ioctl = winesync_char_ioctl, + .compat_ioctl = winesync_char_ioctl, + .llseek = no_llseek, +}; + +static struct miscdevice winesync_misc = { + .minor = WINESYNC_MINOR, + .name = WINESYNC_NAME, + .fops = &winesync_fops, +}; + +static int __init winesync_init(void) +{ + return misc_register(&winesync_misc); +} + +static void __exit winesync_exit(void) +{ + misc_deregister(&winesync_misc); +} + +module_init(winesync_init); +module_exit(winesync_exit); + +MODULE_AUTHOR("Zebediah Figura"); +MODULE_DESCRIPTION("Kernel driver for Wine synchronization primitives"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("devname:" WINESYNC_NAME); +MODULE_ALIAS_MISCDEV(WINESYNC_MINOR); diff --git a/include/linux/miscdevice.h b/include/linux/miscdevice.h index c0fea6ca5076..36fc5d5315a4 100644 --- a/include/linux/miscdevice.h +++ b/include/linux/miscdevice.h @@ -71,6 +71,7 @@ #define USERIO_MINOR 240 #define VHOST_VSOCK_MINOR 241 #define RFKILL_MINOR 242 +#define WINESYNC_MINOR 243 #define MISC_DYNAMIC_MINOR 255 struct device; diff --git a/include/uapi/linux/futex.h b/include/uapi/linux/futex.h index 71a5df8d2689..d375ab21cbf8 100644 --- a/include/uapi/linux/futex.h +++ b/include/uapi/linux/futex.h @@ -22,6 +22,7 @@ #define FUTEX_WAIT_REQUEUE_PI 11 #define FUTEX_CMP_REQUEUE_PI 12 #define FUTEX_LOCK_PI2 13 +#define FUTEX_WAIT_MULTIPLE 31 #define FUTEX_PRIVATE_FLAG 128 #define FUTEX_CLOCK_REALTIME 256 @@ -68,6 +69,18 @@ struct futex_waitv { __u32 __reserved; }; +/** + * struct futex_wait_block - Block of futexes to be waited for + * @uaddr: User address of the futex + * @val: Futex value expected by userspace + * @bitset: Bitset for the optional bitmasked wakeup + */ +struct futex_wait_block { + __u32 __user *uaddr; + __u32 val; + __u32 bitset; +}; + /* * Support for robust futexes: the kernel cleans up held futexes at * thread exit time. diff --git a/include/uapi/linux/winesync.h b/include/uapi/linux/winesync.h new file mode 100644 index 000000000000..5b4e369f7469 --- /dev/null +++ b/include/uapi/linux/winesync.h @@ -0,0 +1,71 @@ +/* SPDX-License-Identifier: GPL-2.0 WITH Linux-syscall-note */ +/* + * Kernel support for Wine synchronization primitives + * + * Copyright (C) 2021 Zebediah Figura + */ + +#ifndef __LINUX_WINESYNC_H +#define __LINUX_WINESYNC_H + +#include + +struct winesync_sem_args { + __u32 sem; + __u32 count; + __u32 max; +}; + +struct winesync_mutex_args { + __u32 mutex; + __u32 owner; + __u32 count; +}; + +struct winesync_event_args { + __u32 event; + __u32 manual; + __u32 signaled; +}; + +struct winesync_wait_args { + __u64 timeout; + __u64 objs; + __u32 count; + __u32 owner; + __u32 index; + __u32 alert; +}; + +#define WINESYNC_IOC_BASE 0xf7 + +#define WINESYNC_IOC_CREATE_SEM _IOWR(WINESYNC_IOC_BASE, 0, \ + struct winesync_sem_args) +#define WINESYNC_IOC_DELETE _IOW (WINESYNC_IOC_BASE, 1, __u32) +#define WINESYNC_IOC_PUT_SEM _IOWR(WINESYNC_IOC_BASE, 2, \ + struct winesync_sem_args) +#define WINESYNC_IOC_WAIT_ANY _IOWR(WINESYNC_IOC_BASE, 3, \ + struct winesync_wait_args) +#define WINESYNC_IOC_WAIT_ALL _IOWR(WINESYNC_IOC_BASE, 4, \ + struct winesync_wait_args) +#define WINESYNC_IOC_CREATE_MUTEX _IOWR(WINESYNC_IOC_BASE, 5, \ + struct winesync_mutex_args) +#define WINESYNC_IOC_PUT_MUTEX _IOWR(WINESYNC_IOC_BASE, 6, \ + struct winesync_mutex_args) +#define WINESYNC_IOC_KILL_OWNER _IOW (WINESYNC_IOC_BASE, 7, __u32) +#define WINESYNC_IOC_READ_SEM _IOWR(WINESYNC_IOC_BASE, 8, \ + struct winesync_sem_args) +#define WINESYNC_IOC_READ_MUTEX _IOWR(WINESYNC_IOC_BASE, 9, \ + struct winesync_mutex_args) +#define WINESYNC_IOC_CREATE_EVENT _IOWR(WINESYNC_IOC_BASE, 10, \ + struct winesync_event_args) +#define WINESYNC_IOC_SET_EVENT _IOWR(WINESYNC_IOC_BASE, 11, \ + struct winesync_event_args) +#define WINESYNC_IOC_RESET_EVENT _IOWR(WINESYNC_IOC_BASE, 12, \ + struct winesync_event_args) +#define WINESYNC_IOC_PULSE_EVENT _IOWR(WINESYNC_IOC_BASE, 13, \ + struct winesync_event_args) +#define WINESYNC_IOC_READ_EVENT _IOWR(WINESYNC_IOC_BASE, 14, \ + struct winesync_event_args) + +#endif diff --git a/kernel/futex/syscalls.c b/kernel/futex/syscalls.c index 086a22d1adb7..c6f5f1e84e09 100644 --- a/kernel/futex/syscalls.c +++ b/kernel/futex/syscalls.c @@ -142,6 +142,7 @@ static __always_inline bool futex_cmd_has_timeout(u32 cmd) case FUTEX_LOCK_PI2: case FUTEX_WAIT_BITSET: case FUTEX_WAIT_REQUEUE_PI: + case FUTEX_WAIT_MULTIPLE: return true; } return false; @@ -154,13 +155,79 @@ futex_init_timeout(u32 cmd, u32 op, struct timespec64 *ts, ktime_t *t) return -EINVAL; *t = timespec64_to_ktime(*ts); - if (cmd == FUTEX_WAIT) + if (cmd == FUTEX_WAIT || cmd == FUTEX_WAIT_MULTIPLE) *t = ktime_add_safe(ktime_get(), *t); else if (cmd != FUTEX_LOCK_PI && !(op & FUTEX_CLOCK_REALTIME)) *t = timens_ktime_to_host(CLOCK_MONOTONIC, *t); return 0; } +/** + * futex_read_wait_block - Read an array of futex_wait_block from userspace + * @uaddr: Userspace address of the block + * @count: Number of blocks to be read + * + * This function creates and allocate an array of futex_q (we zero it to + * initialize the fields) and then, for each futex_wait_block element from + * userspace, fill a futex_q element with proper values. + */ +inline struct futex_vector *futex_read_wait_block(u32 __user *uaddr, u32 count) +{ + unsigned int i; + struct futex_vector *futexv; + struct futex_wait_block fwb; + struct futex_wait_block __user *entry = + (struct futex_wait_block __user *)uaddr; + + if (!count || count > FUTEX_WAITV_MAX) + return ERR_PTR(-EINVAL); + + futexv = kcalloc(count, sizeof(*futexv), GFP_KERNEL); + if (!futexv) + return ERR_PTR(-ENOMEM); + + for (i = 0; i < count; i++) { + if (copy_from_user(&fwb, &entry[i], sizeof(fwb))) { + kfree(futexv); + return ERR_PTR(-EFAULT); + } + + futexv[i].w.flags = FUTEX_32; + futexv[i].w.val = fwb.val; + futexv[i].w.uaddr = (uintptr_t) (fwb.uaddr); + futexv[i].q = futex_q_init; + } + + return futexv; +} + +int futex_wait_multiple(struct futex_vector *vs, unsigned int count, + struct hrtimer_sleeper *to); + +int futex_opcode_31(ktime_t *abs_time, u32 __user *uaddr, int count) +{ + int ret; + struct futex_vector *vs; + struct hrtimer_sleeper *to = NULL, timeout; + + to = futex_setup_timer(abs_time, &timeout, 0, 0); + + vs = futex_read_wait_block(uaddr, count); + + if (IS_ERR(vs)) + return PTR_ERR(vs); + + ret = futex_wait_multiple(vs, count, abs_time ? to : NULL); + kfree(vs); + + if (to) { + hrtimer_cancel(&to->timer); + destroy_hrtimer_on_stack(&to->timer); + } + + return ret; +} + SYSCALL_DEFINE6(futex, u32 __user *, uaddr, int, op, u32, val, const struct __kernel_timespec __user *, utime, u32 __user *, uaddr2, u32, val3) @@ -180,6 +247,9 @@ SYSCALL_DEFINE6(futex, u32 __user *, uaddr, int, op, u32, val, tp = &t; } + if (cmd == FUTEX_WAIT_MULTIPLE) + return futex_opcode_31(tp, uaddr, val); + return do_futex(uaddr, op, val, tp, uaddr2, (unsigned long)utime, val3); } @@ -370,6 +440,9 @@ SYSCALL_DEFINE6(futex_time32, u32 __user *, uaddr, int, op, u32, val, tp = &t; } + if (cmd == FUTEX_WAIT_MULTIPLE) + return futex_opcode_31(tp, uaddr, val); + return do_futex(uaddr, op, val, tp, uaddr2, (unsigned long)utime, val3); } #endif /* CONFIG_COMPAT_32BIT_TIME */ diff --git a/tools/testing/selftests/Makefile b/tools/testing/selftests/Makefile index 1fc89b8ef433..c7d3d9f5e0a0 100644 --- a/tools/testing/selftests/Makefile +++ b/tools/testing/selftests/Makefile @@ -14,6 +14,7 @@ TARGETS += drivers/dma-buf TARGETS += drivers/s390x/uvdevice TARGETS += drivers/net/bonding TARGETS += drivers/net/team +TARGETS += drivers/winesync TARGETS += efivarfs TARGETS += exec TARGETS += filesystems diff --git a/tools/testing/selftests/drivers/winesync/Makefile b/tools/testing/selftests/drivers/winesync/Makefile new file mode 100644 index 000000000000..43b39fdeea10 --- /dev/null +++ b/tools/testing/selftests/drivers/winesync/Makefile @@ -0,0 +1,8 @@ +# SPDX-LICENSE-IDENTIFIER: GPL-2.0-only +TEST_GEN_PROGS := winesync + +top_srcdir =../../../../.. +CFLAGS += -I$(top_srcdir)/usr/include +LDLIBS += -lpthread + +include ../../lib.mk diff --git a/tools/testing/selftests/drivers/winesync/config b/tools/testing/selftests/drivers/winesync/config new file mode 100644 index 000000000000..60539c826d06 --- /dev/null +++ b/tools/testing/selftests/drivers/winesync/config @@ -0,0 +1 @@ +CONFIG_WINESYNC=y diff --git a/tools/testing/selftests/drivers/winesync/winesync.c b/tools/testing/selftests/drivers/winesync/winesync.c new file mode 100644 index 000000000000..169e922484b0 --- /dev/null +++ b/tools/testing/selftests/drivers/winesync/winesync.c @@ -0,0 +1,1479 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Various unit tests for the "winesync" synchronization primitive driver. + * + * Copyright (C) 2021 Zebediah Figura + */ + +#define _GNU_SOURCE +#include +#include +#include +#include +#include +#include +#include "../../kselftest_harness.h" + +static int read_sem_state(int fd, __u32 sem, __u32 *count, __u32 *max) +{ + struct winesync_sem_args args; + int ret; + + args.sem = sem; + args.count = 0xdeadbeef; + args.max = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_READ_SEM, &args); + *count = args.count; + *max = args.max; + return ret; +} + +#define check_sem_state(fd, sem, count, max) \ + ({ \ + __u32 __count, __max; \ + int ret = read_sem_state((fd), (sem), &__count, &__max); \ + EXPECT_EQ(0, ret); \ + EXPECT_EQ((count), __count); \ + EXPECT_EQ((max), __max); \ + }) + +static int put_sem(int fd, __u32 sem, __u32 *count) +{ + struct winesync_sem_args args; + int ret; + + args.sem = sem; + args.count = *count; + ret = ioctl(fd, WINESYNC_IOC_PUT_SEM, &args); + *count = args.count; + return ret; +} + +static int read_mutex_state(int fd, __u32 mutex, __u32 *count, __u32 *owner) +{ + struct winesync_mutex_args args; + int ret; + + args.mutex = mutex; + args.count = 0xdeadbeef; + args.owner = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_READ_MUTEX, &args); + *count = args.count; + *owner = args.owner; + return ret; +} + +#define check_mutex_state(fd, mutex, count, owner) \ + ({ \ + __u32 __count, __owner; \ + int ret = read_mutex_state((fd), (mutex), &__count, &__owner); \ + EXPECT_EQ(0, ret); \ + EXPECT_EQ((count), __count); \ + EXPECT_EQ((owner), __owner); \ + }) + +static int put_mutex(int fd, __u32 mutex, __u32 owner, __u32 *count) +{ + struct winesync_mutex_args args; + int ret; + + args.mutex = mutex; + args.owner = owner; + args.count = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_PUT_MUTEX, &args); + *count = args.count; + return ret; +} + +static int read_event_state(int fd, __u32 event, __u32 *signaled, __u32 *manual) +{ + struct winesync_event_args args; + int ret; + + args.event = event; + args.signaled = 0xdeadbeef; + args.manual = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_READ_EVENT, &args); + *signaled = args.signaled; + *manual = args.manual; + return ret; +} + +#define check_event_state(fd, event, signaled, manual) \ + ({ \ + __u32 __signaled, __manual; \ + int ret = read_event_state((fd), (event), \ + &__signaled, &__manual); \ + EXPECT_EQ(0, ret); \ + EXPECT_EQ((signaled), __signaled); \ + EXPECT_EQ((manual), __manual); \ + }) + +static int wait_objs(int fd, unsigned long request, __u32 count, + const __u32 *objs, __u32 owner, __u32 alert, __u32 *index) +{ + struct winesync_wait_args args = {0}; + struct timespec timeout; + int ret; + + clock_gettime(CLOCK_MONOTONIC, &timeout); + + args.timeout = (uintptr_t)&timeout; + args.count = count; + args.objs = (uintptr_t)objs; + args.owner = owner; + args.index = 0xdeadbeef; + args.alert = alert; + ret = ioctl(fd, request, &args); + *index = args.index; + return ret; +} + +static int wait_any(int fd, __u32 count, const __u32 *objs, + __u32 owner, __u32 *index) +{ + return wait_objs(fd, WINESYNC_IOC_WAIT_ANY, + count, objs, owner, 0, index); +} + +static int wait_all(int fd, __u32 count, const __u32 *objs, + __u32 owner, __u32 *index) +{ + return wait_objs(fd, WINESYNC_IOC_WAIT_ALL, + count, objs, owner, 0, index); +} + +static int wait_any_alert(int fd, __u32 count, const __u32 *objs, + __u32 owner, __u32 alert, __u32 *index) +{ + return wait_objs(fd, WINESYNC_IOC_WAIT_ANY, + count, objs, owner, alert, index); +} + +static int wait_all_alert(int fd, __u32 count, const __u32 *objs, + __u32 owner, __u32 alert, __u32 *index) +{ + return wait_objs(fd, WINESYNC_IOC_WAIT_ALL, + count, objs, owner, alert, index); +} + +TEST(semaphore_state) +{ + struct winesync_sem_args sem_args; + struct timespec timeout; + __u32 sem, count, index; + int fd, ret; + + clock_gettime(CLOCK_MONOTONIC, &timeout); + + fd = open("/dev/winesync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 3; + sem_args.max = 2; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + sem_args.count = 2; + sem_args.max = 2; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + check_sem_state(fd, sem, 2, 2); + + count = 0; + ret = put_sem(fd, sem, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, count); + check_sem_state(fd, sem, 2, 2); + + count = 1; + ret = put_sem(fd, sem, &count); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOVERFLOW, errno); + check_sem_state(fd, sem, 2, 2); + + ret = wait_any(fd, 1, &sem, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(fd, sem, 1, 2); + + ret = wait_any(fd, 1, &sem, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(fd, sem, 0, 2); + + ret = wait_any(fd, 1, &sem, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + count = 3; + ret = put_sem(fd, sem, &count); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOVERFLOW, errno); + check_sem_state(fd, sem, 0, 2); + + count = 2; + ret = put_sem(fd, sem, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, count); + check_sem_state(fd, sem, 2, 2); + + ret = wait_any(fd, 1, &sem, 123, &index); + EXPECT_EQ(0, ret); + ret = wait_any(fd, 1, &sem, 123, &index); + EXPECT_EQ(0, ret); + + count = 1; + ret = put_sem(fd, sem, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, count); + check_sem_state(fd, sem, 1, 2); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &sem); + EXPECT_EQ(0, ret); + + close(fd); +} + +TEST(mutex_state) +{ + struct winesync_mutex_args mutex_args; + __u32 mutex, owner, count, index; + struct timespec timeout; + int fd, ret; + + clock_gettime(CLOCK_MONOTONIC, &timeout); + + fd = open("/dev/winesync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + mutex_args.owner = 123; + mutex_args.count = 0; + ret = ioctl(fd, WINESYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + mutex_args.owner = 0; + mutex_args.count = 2; + ret = ioctl(fd, WINESYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + mutex_args.owner = 123; + mutex_args.count = 2; + mutex_args.mutex = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, mutex_args.mutex); + mutex = mutex_args.mutex; + check_mutex_state(fd, mutex, 2, 123); + + ret = put_mutex(fd, mutex, 0, &count); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = put_mutex(fd, mutex, 456, &count); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EPERM, errno); + check_mutex_state(fd, mutex, 2, 123); + + ret = put_mutex(fd, mutex, 123, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, count); + check_mutex_state(fd, mutex, 1, 123); + + ret = put_mutex(fd, mutex, 123, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, count); + check_mutex_state(fd, mutex, 0, 0); + + ret = put_mutex(fd, mutex, 123, &count); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EPERM, errno); + + ret = wait_any(fd, 1, &mutex, 456, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_mutex_state(fd, mutex, 1, 456); + + ret = wait_any(fd, 1, &mutex, 456, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_mutex_state(fd, mutex, 2, 456); + + ret = put_mutex(fd, mutex, 456, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, count); + check_mutex_state(fd, mutex, 1, 456); + + ret = wait_any(fd, 1, &mutex, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + owner = 0; + ret = ioctl(fd, WINESYNC_IOC_KILL_OWNER, &owner); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + owner = 123; + ret = ioctl(fd, WINESYNC_IOC_KILL_OWNER, &owner); + EXPECT_EQ(0, ret); + check_mutex_state(fd, mutex, 1, 456); + + owner = 456; + ret = ioctl(fd, WINESYNC_IOC_KILL_OWNER, &owner); + EXPECT_EQ(0, ret); + + mutex_args.count = 0xdeadbeef; + mutex_args.owner = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_READ_MUTEX, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + EXPECT_EQ(0, mutex_args.count); + EXPECT_EQ(0, mutex_args.owner); + + mutex_args.count = 0xdeadbeef; + mutex_args.owner = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_READ_MUTEX, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + EXPECT_EQ(0, mutex_args.count); + EXPECT_EQ(0, mutex_args.owner); + + ret = wait_any(fd, 1, &mutex, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + EXPECT_EQ(0, index); + check_mutex_state(fd, mutex, 1, 123); + + owner = 123; + ret = ioctl(fd, WINESYNC_IOC_KILL_OWNER, &owner); + EXPECT_EQ(0, ret); + + mutex_args.count = 0xdeadbeef; + mutex_args.owner = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_READ_MUTEX, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + EXPECT_EQ(0, mutex_args.count); + EXPECT_EQ(0, mutex_args.owner); + + ret = wait_any(fd, 1, &mutex, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + EXPECT_EQ(0, index); + check_mutex_state(fd, mutex, 1, 123); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &mutex); + EXPECT_EQ(0, ret); + + mutex_args.owner = 0; + mutex_args.count = 0; + mutex_args.mutex = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, mutex_args.mutex); + mutex = mutex_args.mutex; + check_mutex_state(fd, mutex, 0, 0); + + ret = wait_any(fd, 1, &mutex, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_mutex_state(fd, mutex, 1, 123); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &mutex_args.mutex); + EXPECT_EQ(0, ret); + + close(fd); +} + +TEST(manual_event_state) +{ + struct winesync_event_args event_args; + __u32 index; + int fd, ret; + + fd = open("/dev/winesync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + event_args.manual = 1; + event_args.signaled = 0; + event_args.event = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, event_args.event); + check_event_state(fd, event_args.event, 0, 1); + + event_args.signaled = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, event_args.signaled); + check_event_state(fd, event_args.event, 1, 1); + + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, event_args.signaled); + check_event_state(fd, event_args.event, 1, 1); + + ret = wait_any(fd, 1, &event_args.event, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_event_state(fd, event_args.event, 1, 1); + + event_args.signaled = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_RESET_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, event_args.signaled); + check_event_state(fd, event_args.event, 0, 1); + + ret = ioctl(fd, WINESYNC_IOC_RESET_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, event_args.signaled); + check_event_state(fd, event_args.event, 0, 1); + + ret = wait_any(fd, 1, &event_args.event, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, event_args.signaled); + + ret = ioctl(fd, WINESYNC_IOC_PULSE_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, event_args.signaled); + check_event_state(fd, event_args.event, 0, 1); + + ret = ioctl(fd, WINESYNC_IOC_PULSE_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, event_args.signaled); + check_event_state(fd, event_args.event, 0, 1); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &event_args.event); + EXPECT_EQ(0, ret); + + close(fd); +} + +TEST(auto_event_state) +{ + struct winesync_event_args event_args; + __u32 index; + int fd, ret; + + fd = open("/dev/winesync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + event_args.manual = 0; + event_args.signaled = 1; + event_args.event = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, event_args.event); + + check_event_state(fd, event_args.event, 1, 0); + + event_args.signaled = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, event_args.signaled); + check_event_state(fd, event_args.event, 1, 0); + + ret = wait_any(fd, 1, &event_args.event, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_event_state(fd, event_args.event, 0, 0); + + event_args.signaled = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_RESET_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, event_args.signaled); + check_event_state(fd, event_args.event, 0, 0); + + ret = wait_any(fd, 1, &event_args.event, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, event_args.signaled); + + ret = ioctl(fd, WINESYNC_IOC_PULSE_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, event_args.signaled); + check_event_state(fd, event_args.event, 0, 0); + + ret = ioctl(fd, WINESYNC_IOC_PULSE_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, event_args.signaled); + check_event_state(fd, event_args.event, 0, 0); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &event_args.event); + EXPECT_EQ(0, ret); + + close(fd); +} + +TEST(test_wait_any) +{ + struct winesync_mutex_args mutex_args = {0}; + struct winesync_wait_args wait_args = {0}; + struct winesync_sem_args sem_args = {0}; + __u32 objs[2], owner, index; + struct timespec timeout; + int fd, ret; + + clock_gettime(CLOCK_MONOTONIC, &timeout); + + fd = open("/dev/winesync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 2; + sem_args.max = 3; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + + mutex_args.owner = 0; + mutex_args.count = 0; + mutex_args.mutex = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, mutex_args.mutex); + + objs[0] = sem_args.sem; + objs[1] = mutex_args.mutex; + + ret = wait_any(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(fd, sem_args.sem, 1, 3); + check_mutex_state(fd, mutex_args.mutex, 0, 0); + + ret = wait_any(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(fd, sem_args.sem, 0, 3); + check_mutex_state(fd, mutex_args.mutex, 0, 0); + + ret = wait_any(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, index); + check_sem_state(fd, sem_args.sem, 0, 3); + check_mutex_state(fd, mutex_args.mutex, 1, 123); + + sem_args.count = 1; + ret = ioctl(fd, WINESYNC_IOC_PUT_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, sem_args.count); + + ret = wait_any(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(fd, sem_args.sem, 0, 3); + check_mutex_state(fd, mutex_args.mutex, 1, 123); + + ret = wait_any(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, index); + check_sem_state(fd, sem_args.sem, 0, 3); + check_mutex_state(fd, mutex_args.mutex, 2, 123); + + ret = wait_any(fd, 2, objs, 456, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + owner = 123; + ret = ioctl(fd, WINESYNC_IOC_KILL_OWNER, &owner); + EXPECT_EQ(0, ret); + + ret = wait_any(fd, 2, objs, 456, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + EXPECT_EQ(1, index); + + ret = wait_any(fd, 2, objs, 456, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, index); + + /* test waiting on the same object twice */ + sem_args.count = 2; + ret = ioctl(fd, WINESYNC_IOC_PUT_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, sem_args.count); + + objs[0] = objs[1] = sem_args.sem; + ret = wait_any(fd, 2, objs, 456, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, wait_args.index); + check_sem_state(fd, sem_args.sem, 1, 3); + + ret = wait_any(fd, 0, NULL, 456, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &sem_args.sem); + EXPECT_EQ(0, ret); + ret = ioctl(fd, WINESYNC_IOC_DELETE, &mutex_args.mutex); + EXPECT_EQ(0, ret); + + close(fd); +} + +TEST(test_wait_all) +{ + struct winesync_event_args event_args = {0}; + struct winesync_mutex_args mutex_args = {0}; + struct winesync_sem_args sem_args = {0}; + __u32 objs[2], owner, index; + int fd, ret; + + fd = open("/dev/winesync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 2; + sem_args.max = 3; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + + mutex_args.owner = 0; + mutex_args.count = 0; + mutex_args.mutex = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, mutex_args.mutex); + + event_args.manual = true; + event_args.signaled = true; + ret = ioctl(fd, WINESYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + objs[0] = sem_args.sem; + objs[1] = mutex_args.mutex; + + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(fd, sem_args.sem, 1, 3); + check_mutex_state(fd, mutex_args.mutex, 1, 123); + + ret = wait_all(fd, 2, objs, 456, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + check_sem_state(fd, sem_args.sem, 1, 3); + check_mutex_state(fd, mutex_args.mutex, 1, 123); + + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(fd, sem_args.sem, 0, 3); + check_mutex_state(fd, mutex_args.mutex, 2, 123); + + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + check_sem_state(fd, sem_args.sem, 0, 3); + check_mutex_state(fd, mutex_args.mutex, 2, 123); + + sem_args.count = 3; + ret = ioctl(fd, WINESYNC_IOC_PUT_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, sem_args.count); + + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(fd, sem_args.sem, 2, 3); + check_mutex_state(fd, mutex_args.mutex, 3, 123); + + owner = 123; + ret = ioctl(fd, WINESYNC_IOC_KILL_OWNER, &owner); + EXPECT_EQ(0, ret); + + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EOWNERDEAD, errno); + check_sem_state(fd, sem_args.sem, 1, 3); + check_mutex_state(fd, mutex_args.mutex, 1, 123); + + objs[0] = sem_args.sem; + objs[1] = event_args.event; + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + check_sem_state(fd, sem_args.sem, 0, 3); + check_event_state(fd, event_args.event, 1, 1); + + /* test waiting on the same object twice */ + objs[0] = objs[1] = sem_args.sem; + ret = wait_all(fd, 2, objs, 123, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &sem_args.sem); + EXPECT_EQ(0, ret); + ret = ioctl(fd, WINESYNC_IOC_DELETE, &mutex_args.mutex); + EXPECT_EQ(0, ret); + ret = ioctl(fd, WINESYNC_IOC_DELETE, &event_args.event); + EXPECT_EQ(0, ret); + + close(fd); +} + +TEST(invalid_objects) +{ + struct winesync_event_args event_args = {0}; + struct winesync_mutex_args mutex_args = {0}; + struct winesync_wait_args wait_args = {0}; + struct winesync_sem_args sem_args = {0}; + __u32 objs[2] = {0}; + int fd, ret; + + fd = open("/dev/winesync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + ret = ioctl(fd, WINESYNC_IOC_PUT_SEM, &sem_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_READ_SEM, &sem_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_PUT_MUTEX, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_READ_MUTEX, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &event_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_RESET_EVENT, &event_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_PULSE_EVENT, &event_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_READ_EVENT, &event_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + wait_args.objs = (uintptr_t)objs; + wait_args.count = 1; + ret = ioctl(fd, WINESYNC_IOC_WAIT_ANY, &wait_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + ret = ioctl(fd, WINESYNC_IOC_WAIT_ALL, &wait_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &objs[0]); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + sem_args.max = 1; + ret = ioctl(fd, WINESYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + + mutex_args.mutex = sem_args.sem; + ret = ioctl(fd, WINESYNC_IOC_PUT_MUTEX, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_READ_MUTEX, &mutex_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + event_args.event = sem_args.sem; + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &event_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_RESET_EVENT, &event_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_PULSE_EVENT, &event_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_READ_EVENT, &event_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + objs[0] = sem_args.sem; + objs[1] = sem_args.sem + 1; + wait_args.count = 2; + ret = ioctl(fd, WINESYNC_IOC_WAIT_ANY, &wait_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + ret = ioctl(fd, WINESYNC_IOC_WAIT_ALL, &wait_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + objs[0] = sem_args.sem + 1; + objs[1] = sem_args.sem; + ret = ioctl(fd, WINESYNC_IOC_WAIT_ANY, &wait_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + ret = ioctl(fd, WINESYNC_IOC_WAIT_ALL, &wait_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &sem_args.sem); + EXPECT_EQ(0, ret); + + ret = ioctl(fd, WINESYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + + sem_args.sem = mutex_args.mutex; + ret = ioctl(fd, WINESYNC_IOC_PUT_SEM, &sem_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_READ_SEM, &sem_args); + EXPECT_EQ(-1, ret); + EXPECT_EQ(EINVAL, errno); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &mutex_args.mutex); + EXPECT_EQ(0, ret); + + close(fd); +} + +struct wake_args +{ + int fd; + __u32 obj; +}; + +struct wait_args +{ + int fd; + unsigned long request; + struct winesync_wait_args *args; + int ret; + int err; +}; + +static void *wait_thread(void *arg) +{ + struct wait_args *args = arg; + + args->ret = ioctl(args->fd, args->request, args->args); + args->err = errno; + return NULL; +} + +static void get_abs_timeout(struct timespec *timeout, clockid_t clock, + unsigned int ms) +{ + clock_gettime(clock, timeout); + timeout->tv_nsec += ms * 1000000; + timeout->tv_sec += (timeout->tv_nsec / 1000000000); + timeout->tv_nsec %= 1000000000; +} + +static int wait_for_thread(pthread_t thread, unsigned int ms) +{ + struct timespec timeout; + get_abs_timeout(&timeout, CLOCK_REALTIME, ms); + return pthread_timedjoin_np(thread, NULL, &timeout); +} + +TEST(wake_any) +{ + struct winesync_event_args event_args = {0}; + struct winesync_mutex_args mutex_args = {0}; + struct winesync_wait_args wait_args = {0}; + struct winesync_sem_args sem_args = {0}; + struct wait_args thread_args; + __u32 objs[2], count, index; + struct timespec timeout; + pthread_t thread; + int fd, ret; + + fd = open("/dev/winesync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 0; + sem_args.max = 3; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + + mutex_args.owner = 123; + mutex_args.count = 1; + mutex_args.mutex = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, mutex_args.mutex); + + objs[0] = sem_args.sem; + objs[1] = mutex_args.mutex; + + /* test waking the semaphore */ + + get_abs_timeout(&timeout, CLOCK_MONOTONIC, 1000); + wait_args.timeout = (uintptr_t)&timeout; + wait_args.objs = (uintptr_t)objs; + wait_args.count = 2; + wait_args.owner = 456; + wait_args.index = 0xdeadbeef; + thread_args.fd = fd; + thread_args.args = &wait_args; + thread_args.request = WINESYNC_IOC_WAIT_ANY; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + sem_args.count = 1; + ret = ioctl(fd, WINESYNC_IOC_PUT_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, sem_args.count); + check_sem_state(fd, sem_args.sem, 0, 3); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(0, wait_args.index); + + /* test waking the mutex */ + + /* first grab it again for owner 123 */ + ret = wait_any(fd, 1, &mutex_args.mutex, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + + get_abs_timeout(&timeout, CLOCK_MONOTONIC, 1000); + wait_args.owner = 456; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = put_mutex(fd, mutex_args.mutex, 123, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, count); + + ret = pthread_tryjoin_np(thread, NULL); + EXPECT_EQ(EBUSY, ret); + + ret = put_mutex(fd, mutex_args.mutex, 123, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, mutex_args.count); + check_mutex_state(fd, mutex_args.mutex, 1, 456); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(1, wait_args.index); + + /* test waking events */ + + event_args.manual = false; + event_args.signaled = false; + ret = ioctl(fd, WINESYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + objs[1] = event_args.event; + get_abs_timeout(&timeout, CLOCK_MONOTONIC, 1000); + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, event_args.signaled); + check_event_state(fd, event_args.event, 0, 0); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(1, wait_args.index); + + get_abs_timeout(&timeout, CLOCK_MONOTONIC, 1000); + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(fd, WINESYNC_IOC_PULSE_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, event_args.signaled); + check_event_state(fd, event_args.event, 0, 0); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(1, wait_args.index); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &event_args.event); + EXPECT_EQ(0, ret); + + event_args.manual = true; + event_args.signaled = false; + ret = ioctl(fd, WINESYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + objs[1] = event_args.event; + get_abs_timeout(&timeout, CLOCK_MONOTONIC, 1000); + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, event_args.signaled); + check_event_state(fd, event_args.event, 1, 1); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(1, wait_args.index); + + ret = ioctl(fd, WINESYNC_IOC_RESET_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, event_args.signaled); + + get_abs_timeout(&timeout, CLOCK_MONOTONIC, 1000); + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(fd, WINESYNC_IOC_PULSE_EVENT, &event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, event_args.signaled); + check_event_state(fd, event_args.event, 0, 1); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(1, wait_args.index); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &event_args.event); + EXPECT_EQ(0, ret); + + /* delete an object while it's being waited on */ + + get_abs_timeout(&timeout, CLOCK_MONOTONIC, 200); + wait_args.owner = 123; + objs[1] = mutex_args.mutex; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &sem_args.sem); + EXPECT_EQ(0, ret); + ret = ioctl(fd, WINESYNC_IOC_DELETE, &mutex_args.mutex); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 200); + EXPECT_EQ(0, ret); + EXPECT_EQ(-1, thread_args.ret); + EXPECT_EQ(ETIMEDOUT, thread_args.err); + + close(fd); +} + +TEST(wake_all) +{ + struct winesync_event_args manual_event_args = {0}; + struct winesync_event_args auto_event_args = {0}; + struct winesync_mutex_args mutex_args = {0}; + struct winesync_wait_args wait_args = {0}; + struct winesync_sem_args sem_args = {0}; + struct wait_args thread_args; + __u32 objs[4], count, index; + struct timespec timeout; + pthread_t thread; + int fd, ret; + + fd = open("/dev/winesync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 0; + sem_args.max = 3; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + + mutex_args.owner = 123; + mutex_args.count = 1; + mutex_args.mutex = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_MUTEX, &mutex_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, mutex_args.mutex); + + manual_event_args.manual = true; + manual_event_args.signaled = true; + ret = ioctl(fd, WINESYNC_IOC_CREATE_EVENT, &manual_event_args); + EXPECT_EQ(0, ret); + + auto_event_args.manual = false; + auto_event_args.signaled = true; + ret = ioctl(fd, WINESYNC_IOC_CREATE_EVENT, &auto_event_args); + EXPECT_EQ(0, ret); + + objs[0] = sem_args.sem; + objs[1] = mutex_args.mutex; + objs[2] = manual_event_args.event; + objs[3] = auto_event_args.event; + + get_abs_timeout(&timeout, CLOCK_MONOTONIC, 1000); + wait_args.timeout = (uintptr_t)&timeout; + wait_args.objs = (uintptr_t)objs; + wait_args.count = 4; + wait_args.owner = 456; + thread_args.fd = fd; + thread_args.args = &wait_args; + thread_args.request = WINESYNC_IOC_WAIT_ALL; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + sem_args.count = 1; + ret = ioctl(fd, WINESYNC_IOC_PUT_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, sem_args.count); + + ret = pthread_tryjoin_np(thread, NULL); + EXPECT_EQ(EBUSY, ret); + + check_sem_state(fd, sem_args.sem, 1, 3); + + ret = wait_any(fd, 1, &sem_args.sem, 123, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + + ret = put_mutex(fd, mutex_args.mutex, 123, &count); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, count); + + ret = pthread_tryjoin_np(thread, NULL); + EXPECT_EQ(EBUSY, ret); + + check_mutex_state(fd, mutex_args.mutex, 0, 0); + + ret = ioctl(fd, WINESYNC_IOC_RESET_EVENT, &manual_event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, manual_event_args.signaled); + + sem_args.count = 2; + ret = ioctl(fd, WINESYNC_IOC_PUT_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, sem_args.count); + check_sem_state(fd, sem_args.sem, 2, 3); + + ret = ioctl(fd, WINESYNC_IOC_RESET_EVENT, &auto_event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, auto_event_args.signaled); + + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &manual_event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, manual_event_args.signaled); + + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &auto_event_args); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, auto_event_args.signaled); + + check_sem_state(fd, sem_args.sem, 1, 3); + check_mutex_state(fd, mutex_args.mutex, 1, 456); + check_event_state(fd, manual_event_args.event, 1, 1); + check_event_state(fd, auto_event_args.event, 0, 0); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + + /* delete an object while it's being waited on */ + + get_abs_timeout(&timeout, CLOCK_MONOTONIC, 200); + wait_args.owner = 123; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &sem_args.sem); + EXPECT_EQ(0, ret); + ret = ioctl(fd, WINESYNC_IOC_DELETE, &mutex_args.mutex); + EXPECT_EQ(0, ret); + ret = ioctl(fd, WINESYNC_IOC_DELETE, &manual_event_args.event); + EXPECT_EQ(0, ret); + ret = ioctl(fd, WINESYNC_IOC_DELETE, &auto_event_args.event); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 200); + EXPECT_EQ(0, ret); + EXPECT_EQ(-1, thread_args.ret); + EXPECT_EQ(ETIMEDOUT, thread_args.err); + + close(fd); +} + +TEST(alert_any) +{ + struct winesync_event_args event_args = {0}; + struct winesync_wait_args wait_args = {0}; + struct winesync_sem_args sem_args = {0}; + struct wait_args thread_args; + struct timespec timeout; + __u32 objs[2], index; + pthread_t thread; + int fd, ret; + + fd = open("/dev/winesync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 0; + sem_args.max = 2; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + objs[0] = sem_args.sem; + + sem_args.count = 1; + sem_args.max = 2; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + objs[1] = sem_args.sem; + + event_args.manual = true; + event_args.signaled = true; + ret = ioctl(fd, WINESYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + ret = wait_any_alert(fd, 0, NULL, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + + ret = ioctl(fd, WINESYNC_IOC_RESET_EVENT, &event_args); + EXPECT_EQ(0, ret); + + ret = wait_any_alert(fd, 0, NULL, 123, event_args.event, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &event_args); + EXPECT_EQ(0, ret); + + ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(1, index); + + ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, index); + + /* test wakeup via alert */ + + ret = ioctl(fd, WINESYNC_IOC_RESET_EVENT, &event_args); + EXPECT_EQ(0, ret); + + get_abs_timeout(&timeout, CLOCK_MONOTONIC, 1000); + wait_args.timeout = (uintptr_t)&timeout; + wait_args.objs = (uintptr_t)objs; + wait_args.count = 2; + wait_args.owner = 123; + wait_args.index = 0xdeadbeef; + wait_args.alert = event_args.event; + thread_args.fd = fd; + thread_args.args = &wait_args; + thread_args.request = WINESYNC_IOC_WAIT_ANY; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &event_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(2, wait_args.index); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &event_args.event); + EXPECT_EQ(0, ret); + + /* test with an auto-reset event */ + + event_args.manual = false; + event_args.signaled = true; + ret = ioctl(fd, WINESYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + sem_args.sem = objs[0]; + sem_args.count = 1; + ret = ioctl(fd, WINESYNC_IOC_PUT_SEM, &sem_args); + EXPECT_EQ(0, ret); + + ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + + ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, index); + + ret = wait_any_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &event_args.event); + EXPECT_EQ(0, ret); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &objs[0]); + EXPECT_EQ(0, ret); + ret = ioctl(fd, WINESYNC_IOC_DELETE, &objs[1]); + EXPECT_EQ(0, ret); + + close(fd); +} + +TEST(alert_all) +{ + struct winesync_event_args event_args = {0}; + struct winesync_wait_args wait_args = {0}; + struct winesync_sem_args sem_args = {0}; + struct wait_args thread_args; + struct timespec timeout; + __u32 objs[2], index; + pthread_t thread; + int fd, ret; + + fd = open("/dev/winesync", O_CLOEXEC | O_RDONLY); + ASSERT_LE(0, fd); + + sem_args.count = 2; + sem_args.max = 2; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + objs[0] = sem_args.sem; + + sem_args.count = 1; + sem_args.max = 2; + sem_args.sem = 0xdeadbeef; + ret = ioctl(fd, WINESYNC_IOC_CREATE_SEM, &sem_args); + EXPECT_EQ(0, ret); + EXPECT_NE(0xdeadbeef, sem_args.sem); + objs[1] = sem_args.sem; + + event_args.manual = true; + event_args.signaled = true; + ret = ioctl(fd, WINESYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + + ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, index); + + /* test wakeup via alert */ + + ret = ioctl(fd, WINESYNC_IOC_RESET_EVENT, &event_args); + EXPECT_EQ(0, ret); + + get_abs_timeout(&timeout, CLOCK_MONOTONIC, 1000); + wait_args.timeout = (uintptr_t)&timeout; + wait_args.objs = (uintptr_t)objs; + wait_args.count = 2; + wait_args.owner = 123; + wait_args.index = 0xdeadbeef; + wait_args.alert = event_args.event; + thread_args.fd = fd; + thread_args.args = &wait_args; + thread_args.request = WINESYNC_IOC_WAIT_ALL; + ret = pthread_create(&thread, NULL, wait_thread, &thread_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(ETIMEDOUT, ret); + + ret = ioctl(fd, WINESYNC_IOC_SET_EVENT, &event_args); + EXPECT_EQ(0, ret); + + ret = wait_for_thread(thread, 100); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, thread_args.ret); + EXPECT_EQ(2, wait_args.index); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &event_args.event); + EXPECT_EQ(0, ret); + + /* test with an auto-reset event */ + + event_args.manual = false; + event_args.signaled = true; + ret = ioctl(fd, WINESYNC_IOC_CREATE_EVENT, &event_args); + EXPECT_EQ(0, ret); + + sem_args.sem = objs[1]; + sem_args.count = 2; + ret = ioctl(fd, WINESYNC_IOC_PUT_SEM, &sem_args); + EXPECT_EQ(0, ret); + + ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(0, index); + + ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, index); + + ret = wait_all_alert(fd, 2, objs, 123, event_args.event, &index); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &event_args.event); + EXPECT_EQ(0, ret); + + ret = ioctl(fd, WINESYNC_IOC_DELETE, &objs[0]); + EXPECT_EQ(0, ret); + ret = ioctl(fd, WINESYNC_IOC_DELETE, &objs[1]); + EXPECT_EQ(0, ret); + + close(fd); +} + +TEST_HARNESS_MAIN -- 2.39.0.rc2 From f4227dcccd7b662ef6d686d05a57766634798d83 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Fri, 5 Aug 2022 19:33:47 +0200 Subject: [PATCH 04/20] Introducing-OpenVPN-Data-Channel-Offload Signed-off-by: Peter Jung --- MAINTAINERS | 8 + drivers/net/Kconfig | 19 + drivers/net/Makefile | 1 + drivers/net/ovpn-dco/Makefile | 21 + drivers/net/ovpn-dco/addr.h | 41 + drivers/net/ovpn-dco/bind.c | 62 ++ drivers/net/ovpn-dco/bind.h | 67 ++ drivers/net/ovpn-dco/crypto.c | 154 ++++ drivers/net/ovpn-dco/crypto.h | 144 ++++ drivers/net/ovpn-dco/crypto_aead.c | 367 +++++++++ drivers/net/ovpn-dco/crypto_aead.h | 27 + drivers/net/ovpn-dco/main.c | 271 +++++++ drivers/net/ovpn-dco/main.h | 32 + drivers/net/ovpn-dco/netlink.c | 1143 ++++++++++++++++++++++++++++ drivers/net/ovpn-dco/netlink.h | 22 + drivers/net/ovpn-dco/ovpn.c | 600 +++++++++++++++ drivers/net/ovpn-dco/ovpn.h | 43 ++ drivers/net/ovpn-dco/ovpnstruct.h | 59 ++ drivers/net/ovpn-dco/peer.c | 906 ++++++++++++++++++++++ drivers/net/ovpn-dco/peer.h | 168 ++++ drivers/net/ovpn-dco/pktid.c | 127 ++++ drivers/net/ovpn-dco/pktid.h | 116 +++ drivers/net/ovpn-dco/proto.h | 101 +++ drivers/net/ovpn-dco/rcu.h | 21 + drivers/net/ovpn-dco/skb.h | 54 ++ drivers/net/ovpn-dco/sock.c | 134 ++++ drivers/net/ovpn-dco/sock.h | 54 ++ drivers/net/ovpn-dco/stats.c | 20 + drivers/net/ovpn-dco/stats.h | 67 ++ drivers/net/ovpn-dco/tcp.c | 326 ++++++++ drivers/net/ovpn-dco/tcp.h | 38 + drivers/net/ovpn-dco/udp.c | 343 +++++++++ drivers/net/ovpn-dco/udp.h | 25 + include/net/netlink.h | 1 + include/uapi/linux/ovpn_dco.h | 265 +++++++ include/uapi/linux/udp.h | 1 + 36 files changed, 5848 insertions(+) create mode 100644 drivers/net/ovpn-dco/Makefile create mode 100644 drivers/net/ovpn-dco/addr.h create mode 100644 drivers/net/ovpn-dco/bind.c create mode 100644 drivers/net/ovpn-dco/bind.h create mode 100644 drivers/net/ovpn-dco/crypto.c create mode 100644 drivers/net/ovpn-dco/crypto.h create mode 100644 drivers/net/ovpn-dco/crypto_aead.c create mode 100644 drivers/net/ovpn-dco/crypto_aead.h create mode 100644 drivers/net/ovpn-dco/main.c create mode 100644 drivers/net/ovpn-dco/main.h create mode 100644 drivers/net/ovpn-dco/netlink.c create mode 100644 drivers/net/ovpn-dco/netlink.h create mode 100644 drivers/net/ovpn-dco/ovpn.c create mode 100644 drivers/net/ovpn-dco/ovpn.h create mode 100644 drivers/net/ovpn-dco/ovpnstruct.h create mode 100644 drivers/net/ovpn-dco/peer.c create mode 100644 drivers/net/ovpn-dco/peer.h create mode 100644 drivers/net/ovpn-dco/pktid.c create mode 100644 drivers/net/ovpn-dco/pktid.h create mode 100644 drivers/net/ovpn-dco/proto.h create mode 100644 drivers/net/ovpn-dco/rcu.h create mode 100644 drivers/net/ovpn-dco/skb.h create mode 100644 drivers/net/ovpn-dco/sock.c create mode 100644 drivers/net/ovpn-dco/sock.h create mode 100644 drivers/net/ovpn-dco/stats.c create mode 100644 drivers/net/ovpn-dco/stats.h create mode 100644 drivers/net/ovpn-dco/tcp.c create mode 100644 drivers/net/ovpn-dco/tcp.h create mode 100644 drivers/net/ovpn-dco/udp.c create mode 100644 drivers/net/ovpn-dco/udp.h create mode 100644 include/uapi/linux/ovpn_dco.h diff --git a/MAINTAINERS b/MAINTAINERS index c7461d56676c..603920f452d4 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -15319,6 +15319,14 @@ T: git git://git.kernel.org/pub/scm/linux/kernel/git/mszeredi/vfs.git F: Documentation/filesystems/overlayfs.rst F: fs/overlayfs/ +OVPN-DCO NETWORK DRIVER +M: Antonio Quartulli +L: openvpn-devel@lists.sourceforge.net (moderated for non-subscribers) +L: netdev@vger.kernel.org +S: Maintained +F: drivers/net/ovpn-dco/ +F: include/uapi/linux/ovpn_dco.h + P54 WIRELESS DRIVER M: Christian Lamparter L: linux-wireless@vger.kernel.org diff --git a/drivers/net/Kconfig b/drivers/net/Kconfig index 94c889802566..349866bd4448 100644 --- a/drivers/net/Kconfig +++ b/drivers/net/Kconfig @@ -116,6 +116,25 @@ config WIREGUARD_DEBUG Say N here unless you know what you're doing. +config OVPN_DCO + tristate "OpenVPN data channel offload" + depends on NET && INET + select NET_UDP_TUNNEL + select DST_CACHE + select CRYPTO + select CRYPTO_AES + select CRYPTO_GCM + select CRYPTO_CHACHA20POLY1305 + help + This module enhances the performance of an OpenVPN connection by + allowing the user to offload the data channel processing to + kernelspace. + Connection handshake, parameters negotiation and other non-data + related mechanisms are still performed in userspace. + + The OpenVPN userspace software at version 2.6 or higher is required + to use this functionality. + config EQUALIZER tristate "EQL (serial line load balancing) support" help diff --git a/drivers/net/Makefile b/drivers/net/Makefile index 3f1192d3c52d..8ed151e8d233 100644 --- a/drivers/net/Makefile +++ b/drivers/net/Makefile @@ -11,6 +11,7 @@ obj-$(CONFIG_IPVLAN) += ipvlan/ obj-$(CONFIG_IPVTAP) += ipvlan/ obj-$(CONFIG_DUMMY) += dummy.o obj-$(CONFIG_WIREGUARD) += wireguard/ +obj-$(CONFIG_OVPN_DCO) += ovpn-dco/ obj-$(CONFIG_EQUALIZER) += eql.o obj-$(CONFIG_IFB) += ifb.o obj-$(CONFIG_MACSEC) += macsec.o diff --git a/drivers/net/ovpn-dco/Makefile b/drivers/net/ovpn-dco/Makefile new file mode 100644 index 000000000000..7efefe8f13a9 --- /dev/null +++ b/drivers/net/ovpn-dco/Makefile @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# ovpn-dco -- OpenVPN data channel offload +# +# Copyright (C) 2020-2022 OpenVPN, Inc. +# +# Author: Antonio Quartulli + +obj-$(CONFIG_OVPN_DCO) += ovpn-dco.o +ovpn-dco-y += main.o +ovpn-dco-y += bind.o +ovpn-dco-y += crypto.o +ovpn-dco-y += ovpn.o +ovpn-dco-y += peer.o +ovpn-dco-y += sock.o +ovpn-dco-y += stats.o +ovpn-dco-y += netlink.o +ovpn-dco-y += crypto_aead.o +ovpn-dco-y += pktid.o +ovpn-dco-y += tcp.o +ovpn-dco-y += udp.o diff --git a/drivers/net/ovpn-dco/addr.h b/drivers/net/ovpn-dco/addr.h new file mode 100644 index 000000000000..3d6ad0fc15af --- /dev/null +++ b/drivers/net/ovpn-dco/addr.h @@ -0,0 +1,41 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_DCO_OVPNADDR_H_ +#define _NET_OVPN_DCO_OVPNADDR_H_ + +#include "crypto.h" + +#include +#include +#include +#include + +/* our basic transport layer address */ +struct ovpn_sockaddr { + union { + struct sockaddr_in in4; + struct sockaddr_in6 in6; + }; +}; + +/* Translate skb->protocol value to AF_INET or AF_INET6 */ +static inline unsigned short skb_protocol_to_family(const struct sk_buff *skb) +{ + switch (skb->protocol) { + case htons(ETH_P_IP): + return AF_INET; + case htons(ETH_P_IPV6): + return AF_INET6; + default: + return 0; + } +} + +#endif /* _NET_OVPN_DCO_OVPNADDR_H_ */ diff --git a/drivers/net/ovpn-dco/bind.c b/drivers/net/ovpn-dco/bind.c new file mode 100644 index 000000000000..107697ea983e --- /dev/null +++ b/drivers/net/ovpn-dco/bind.c @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel accelerator + * + * Copyright (C) 2012-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include "ovpn.h" +#include "bind.h" +#include "peer.h" + +#include +#include +#include +#include + +/* Given a remote sockaddr, compute the skb hash + * and get a dst_entry so we can send packets to the remote. + * Called from process context or softirq (must be indicated with + * process_context bool). + */ +struct ovpn_bind *ovpn_bind_from_sockaddr(const struct sockaddr_storage *ss) +{ + struct ovpn_bind *bind; + size_t sa_len; + + if (ss->ss_family == AF_INET) + sa_len = sizeof(struct sockaddr_in); + else if (ss->ss_family == AF_INET6) + sa_len = sizeof(struct sockaddr_in6); + else + return ERR_PTR(-EAFNOSUPPORT); + + bind = kzalloc(sizeof(*bind), GFP_ATOMIC); + if (unlikely(!bind)) + return ERR_PTR(-ENOMEM); + + memcpy(&bind->sa, ss, sa_len); + + return bind; +} + +static void ovpn_bind_release_rcu(struct rcu_head *head) +{ + struct ovpn_bind *bind = container_of(head, struct ovpn_bind, rcu); + + kfree(bind); +} + +void ovpn_bind_reset(struct ovpn_peer *peer, struct ovpn_bind *new) +{ + struct ovpn_bind *old; + + spin_lock_bh(&peer->lock); + old = rcu_replace_pointer(peer->bind, new, true); + spin_unlock_bh(&peer->lock); + + if (old) + call_rcu(&old->rcu, ovpn_bind_release_rcu); +} diff --git a/drivers/net/ovpn-dco/bind.h b/drivers/net/ovpn-dco/bind.h new file mode 100644 index 000000000000..a562e471acae --- /dev/null +++ b/drivers/net/ovpn-dco/bind.h @@ -0,0 +1,67 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OVPN -- OpenVPN protocol accelerator for Linux + * Copyright (C) 2012-2022 OpenVPN, Inc. + * All rights reserved. + * Author: James Yonan + */ + +#ifndef _NET_OVPN_DCO_OVPNBIND_H_ +#define _NET_OVPN_DCO_OVPNBIND_H_ + +#include "addr.h" +#include "rcu.h" + +#include +#include +#include + +struct ovpn_peer; + +struct ovpn_bind { + struct ovpn_sockaddr sa; /* remote sockaddr */ + + union { + struct in_addr ipv4; + struct in6_addr ipv6; + } local; + + struct rcu_head rcu; +}; + +static inline bool ovpn_bind_skb_src_match(const struct ovpn_bind *bind, struct sk_buff *skb) +{ + const unsigned short family = skb_protocol_to_family(skb); + const struct ovpn_sockaddr *sa = &bind->sa; + + if (unlikely(!bind)) + return false; + + if (unlikely(sa->in4.sin_family != family)) + return false; + + switch (family) { + case AF_INET: + if (unlikely(sa->in4.sin_addr.s_addr != ip_hdr(skb)->saddr)) + return false; + + if (unlikely(sa->in4.sin_port != udp_hdr(skb)->source)) + return false; + break; + case AF_INET6: + if (unlikely(!ipv6_addr_equal(&sa->in6.sin6_addr, &ipv6_hdr(skb)->saddr))) + return false; + + if (unlikely(sa->in6.sin6_port != udp_hdr(skb)->source)) + return false; + break; + default: + return false; + } + + return true; +} + +struct ovpn_bind *ovpn_bind_from_sockaddr(const struct sockaddr_storage *sa); +void ovpn_bind_reset(struct ovpn_peer *peer, struct ovpn_bind *bind); + +#endif /* _NET_OVPN_DCO_OVPNBIND_H_ */ diff --git a/drivers/net/ovpn-dco/crypto.c b/drivers/net/ovpn-dco/crypto.c new file mode 100644 index 000000000000..fcc3a351ba9d --- /dev/null +++ b/drivers/net/ovpn-dco/crypto.c @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include "main.h" +#include "crypto_aead.h" +#include "crypto.h" + +#include + +static void ovpn_ks_destroy_rcu(struct rcu_head *head) +{ + struct ovpn_crypto_key_slot *ks; + + ks = container_of(head, struct ovpn_crypto_key_slot, rcu); + ovpn_aead_crypto_key_slot_destroy(ks); +} + +void ovpn_crypto_key_slot_release(struct kref *kref) +{ + struct ovpn_crypto_key_slot *ks; + + ks = container_of(kref, struct ovpn_crypto_key_slot, refcount); + call_rcu(&ks->rcu, ovpn_ks_destroy_rcu); +} + +/* can only be invoked when all peer references have been dropped (i.e. RCU + * release routine) + */ +void ovpn_crypto_state_release(struct ovpn_crypto_state *cs) +{ + struct ovpn_crypto_key_slot *ks; + + ks = rcu_access_pointer(cs->primary); + if (ks) { + RCU_INIT_POINTER(cs->primary, NULL); + ovpn_crypto_key_slot_put(ks); + } + + ks = rcu_access_pointer(cs->secondary); + if (ks) { + RCU_INIT_POINTER(cs->secondary, NULL); + ovpn_crypto_key_slot_put(ks); + } + + mutex_destroy(&cs->mutex); +} + +/* removes the primary key from the crypto context */ +void ovpn_crypto_kill_primary(struct ovpn_crypto_state *cs) +{ + struct ovpn_crypto_key_slot *ks; + + mutex_lock(&cs->mutex); + ks = rcu_replace_pointer(cs->primary, NULL, lockdep_is_held(&cs->mutex)); + ovpn_crypto_key_slot_put(ks); + mutex_unlock(&cs->mutex); +} + +/* Reset the ovpn_crypto_state object in a way that is atomic + * to RCU readers. + */ +int ovpn_crypto_state_reset(struct ovpn_crypto_state *cs, + const struct ovpn_peer_key_reset *pkr) + __must_hold(cs->mutex) +{ + struct ovpn_crypto_key_slot *old = NULL; + struct ovpn_crypto_key_slot *new; + + lockdep_assert_held(&cs->mutex); + + new = ovpn_aead_crypto_key_slot_new(&pkr->key); + if (IS_ERR(new)) + return PTR_ERR(new); + + switch (pkr->slot) { + case OVPN_KEY_SLOT_PRIMARY: + old = rcu_replace_pointer(cs->primary, new, + lockdep_is_held(&cs->mutex)); + break; + case OVPN_KEY_SLOT_SECONDARY: + old = rcu_replace_pointer(cs->secondary, new, + lockdep_is_held(&cs->mutex)); + break; + default: + goto free_key; + } + + if (old) + ovpn_crypto_key_slot_put(old); + + return 0; +free_key: + ovpn_crypto_key_slot_put(new); + return -EINVAL; +} + +void ovpn_crypto_key_slot_delete(struct ovpn_crypto_state *cs, + enum ovpn_key_slot slot) +{ + struct ovpn_crypto_key_slot *ks = NULL; + + mutex_lock(&cs->mutex); + switch (slot) { + case OVPN_KEY_SLOT_PRIMARY: + ks = rcu_replace_pointer(cs->primary, NULL, + lockdep_is_held(&cs->mutex)); + break; + case OVPN_KEY_SLOT_SECONDARY: + ks = rcu_replace_pointer(cs->secondary, NULL, + lockdep_is_held(&cs->mutex)); + break; + default: + pr_warn("Invalid slot to release: %u\n", slot); + break; + } + mutex_unlock(&cs->mutex); + + if (!ks) { + pr_debug("Key slot already released: %u\n", slot); + return; + } + pr_debug("deleting key slot %u, key_id=%u\n", slot, ks->key_id); + + ovpn_crypto_key_slot_put(ks); +} + +/* this swap is not atomic, but there will be a very short time frame where the + * old_secondary key won't be available. This should not be a big deal as most + * likely both peers are already using the new primary at this point. + */ +void ovpn_crypto_key_slots_swap(struct ovpn_crypto_state *cs) +{ + const struct ovpn_crypto_key_slot *old_primary, *old_secondary; + + mutex_lock(&cs->mutex); + + old_secondary = rcu_dereference_protected(cs->secondary, + lockdep_is_held(&cs->mutex)); + old_primary = rcu_replace_pointer(cs->primary, old_secondary, + lockdep_is_held(&cs->mutex)); + rcu_assign_pointer(cs->secondary, old_primary); + + pr_debug("key swapped: %u <-> %u\n", + old_primary ? old_primary->key_id : 0, + old_secondary ? old_secondary->key_id : 0); + + mutex_unlock(&cs->mutex); +} diff --git a/drivers/net/ovpn-dco/crypto.h b/drivers/net/ovpn-dco/crypto.h new file mode 100644 index 000000000000..79f580e54a63 --- /dev/null +++ b/drivers/net/ovpn-dco/crypto.h @@ -0,0 +1,144 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_DCO_OVPNCRYPTO_H_ +#define _NET_OVPN_DCO_OVPNCRYPTO_H_ + +#include "main.h" +#include "pktid.h" + +#include +#include + +struct ovpn_peer; +struct ovpn_crypto_key_slot; + +/* info needed for both encrypt and decrypt directions */ +struct ovpn_key_direction { + const u8 *cipher_key; + size_t cipher_key_size; + const u8 *nonce_tail; /* only needed for GCM modes */ + size_t nonce_tail_size; /* only needed for GCM modes */ +}; + +/* all info for a particular symmetric key (primary or secondary) */ +struct ovpn_key_config { + enum ovpn_cipher_alg cipher_alg; + u8 key_id; + struct ovpn_key_direction encrypt; + struct ovpn_key_direction decrypt; +}; + +/* used to pass settings from netlink to the crypto engine */ +struct ovpn_peer_key_reset { + enum ovpn_key_slot slot; + struct ovpn_key_config key; +}; + +struct ovpn_crypto_key_slot { + u8 key_id; + + struct crypto_aead *encrypt; + struct crypto_aead *decrypt; + struct ovpn_nonce_tail nonce_tail_xmit; + struct ovpn_nonce_tail nonce_tail_recv; + + struct ovpn_pktid_recv pid_recv ____cacheline_aligned_in_smp; + struct ovpn_pktid_xmit pid_xmit ____cacheline_aligned_in_smp; + struct kref refcount; + struct rcu_head rcu; +}; + +struct ovpn_crypto_state { + struct ovpn_crypto_key_slot __rcu *primary; + struct ovpn_crypto_key_slot __rcu *secondary; + + /* protects primary and secondary slots */ + struct mutex mutex; +}; + +static inline bool ovpn_crypto_key_slot_hold(struct ovpn_crypto_key_slot *ks) +{ + return kref_get_unless_zero(&ks->refcount); +} + +static inline void ovpn_crypto_state_init(struct ovpn_crypto_state *cs) +{ + RCU_INIT_POINTER(cs->primary, NULL); + RCU_INIT_POINTER(cs->secondary, NULL); + mutex_init(&cs->mutex); +} + +static inline struct ovpn_crypto_key_slot * +ovpn_crypto_key_id_to_slot(const struct ovpn_crypto_state *cs, u8 key_id) +{ + struct ovpn_crypto_key_slot *ks; + + if (unlikely(!cs)) + return NULL; + + rcu_read_lock(); + ks = rcu_dereference(cs->primary); + if (ks && ks->key_id == key_id) { + if (unlikely(!ovpn_crypto_key_slot_hold(ks))) + ks = NULL; + goto out; + } + + ks = rcu_dereference(cs->secondary); + if (ks && ks->key_id == key_id) { + if (unlikely(!ovpn_crypto_key_slot_hold(ks))) + ks = NULL; + goto out; + } + + /* when both key slots are occupied but no matching key ID is found, ks has to be reset to + * NULL to avoid carrying a stale pointer + */ + ks = NULL; +out: + rcu_read_unlock(); + + return ks; +} + +static inline struct ovpn_crypto_key_slot * +ovpn_crypto_key_slot_primary(const struct ovpn_crypto_state *cs) +{ + struct ovpn_crypto_key_slot *ks; + + rcu_read_lock(); + ks = rcu_dereference(cs->primary); + if (unlikely(ks && !ovpn_crypto_key_slot_hold(ks))) + ks = NULL; + rcu_read_unlock(); + + return ks; +} + +void ovpn_crypto_key_slot_release(struct kref *kref); + +static inline void ovpn_crypto_key_slot_put(struct ovpn_crypto_key_slot *ks) +{ + kref_put(&ks->refcount, ovpn_crypto_key_slot_release); +} + +int ovpn_crypto_state_reset(struct ovpn_crypto_state *cs, + const struct ovpn_peer_key_reset *pkr); + +void ovpn_crypto_key_slot_delete(struct ovpn_crypto_state *cs, + enum ovpn_key_slot slot); + +void ovpn_crypto_state_release(struct ovpn_crypto_state *cs); + +void ovpn_crypto_key_slots_swap(struct ovpn_crypto_state *cs); + +void ovpn_crypto_kill_primary(struct ovpn_crypto_state *cs); + +#endif /* _NET_OVPN_DCO_OVPNCRYPTO_H_ */ diff --git a/drivers/net/ovpn-dco/crypto_aead.c b/drivers/net/ovpn-dco/crypto_aead.c new file mode 100644 index 000000000000..c21bff90d748 --- /dev/null +++ b/drivers/net/ovpn-dco/crypto_aead.c @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include "crypto_aead.h" +#include "crypto.h" +#include "pktid.h" +#include "proto.h" +#include "skb.h" + +#include +#include +#include + +#define AUTH_TAG_SIZE 16 + +static int ovpn_aead_encap_overhead(const struct ovpn_crypto_key_slot *ks) +{ + return OVPN_OP_SIZE_V2 + /* OP header size */ + 4 + /* Packet ID */ + crypto_aead_authsize(ks->encrypt); /* Auth Tag */ +} + +int ovpn_aead_encrypt(struct ovpn_crypto_key_slot *ks, struct sk_buff *skb, u32 peer_id) +{ + const unsigned int tag_size = crypto_aead_authsize(ks->encrypt); + const unsigned int head_size = ovpn_aead_encap_overhead(ks); + struct scatterlist sg[MAX_SKB_FRAGS + 2]; + DECLARE_CRYPTO_WAIT(wait); + struct aead_request *req; + struct sk_buff *trailer; + u8 iv[NONCE_SIZE]; + int nfrags, ret; + u32 pktid, op; + + /* Sample AEAD header format: + * 48000001 00000005 7e7046bd 444a7e28 cc6387b1 64a4d6c1 380275a... + * [ OP32 ] [seq # ] [ auth tag ] [ payload ... ] + * [4-byte + * IV head] + */ + + /* check that there's enough headroom in the skb for packet + * encapsulation, after adding network header and encryption overhead + */ + if (unlikely(skb_cow_head(skb, OVPN_HEAD_ROOM + head_size))) + return -ENOBUFS; + + /* get number of skb frags and ensure that packet data is writable */ + nfrags = skb_cow_data(skb, 0, &trailer); + if (unlikely(nfrags < 0)) + return nfrags; + + if (unlikely(nfrags + 2 > ARRAY_SIZE(sg))) + return -ENOSPC; + + req = aead_request_alloc(ks->encrypt, GFP_KERNEL); + if (unlikely(!req)) + return -ENOMEM; + + /* sg table: + * 0: op, wire nonce (AD, len=OVPN_OP_SIZE_V2+NONCE_WIRE_SIZE), + * 1, 2, 3, ..., n: payload, + * n+1: auth_tag (len=tag_size) + */ + sg_init_table(sg, nfrags + 2); + + /* build scatterlist to encrypt packet payload */ + ret = skb_to_sgvec_nomark(skb, sg + 1, 0, skb->len); + if (unlikely(nfrags != ret)) { + ret = -EINVAL; + goto free_req; + } + + /* append auth_tag onto scatterlist */ + __skb_push(skb, tag_size); + sg_set_buf(sg + nfrags + 1, skb->data, tag_size); + + /* obtain packet ID, which is used both as a first + * 4 bytes of nonce and last 4 bytes of associated data. + */ + ret = ovpn_pktid_xmit_next(&ks->pid_xmit, &pktid); + if (unlikely(ret < 0)) + goto free_req; + + /* concat 4 bytes packet id and 8 bytes nonce tail into 12 bytes nonce */ + ovpn_pktid_aead_write(pktid, &ks->nonce_tail_xmit, iv); + + /* make space for packet id and push it to the front */ + __skb_push(skb, NONCE_WIRE_SIZE); + memcpy(skb->data, iv, NONCE_WIRE_SIZE); + + /* add packet op as head of additional data */ + op = ovpn_opcode_compose(OVPN_DATA_V2, ks->key_id, peer_id); + __skb_push(skb, OVPN_OP_SIZE_V2); + BUILD_BUG_ON(sizeof(op) != OVPN_OP_SIZE_V2); + *((__force __be32 *)skb->data) = htonl(op); + + /* AEAD Additional data */ + sg_set_buf(sg, skb->data, OVPN_OP_SIZE_V2 + NONCE_WIRE_SIZE); + + /* setup async crypto operation */ + aead_request_set_tfm(req, ks->encrypt); + aead_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG | + CRYPTO_TFM_REQ_MAY_SLEEP, + crypto_req_done, &wait); + aead_request_set_crypt(req, sg, sg, skb->len - head_size, iv); + aead_request_set_ad(req, OVPN_OP_SIZE_V2 + NONCE_WIRE_SIZE); + + /* encrypt it */ + ret = crypto_wait_req(crypto_aead_encrypt(req), &wait); + if (ret < 0) + net_err_ratelimited("%s: encrypt failed: %d\n", __func__, ret); + +free_req: + aead_request_free(req); + return ret; +} + +int ovpn_aead_decrypt(struct ovpn_crypto_key_slot *ks, struct sk_buff *skb) +{ + const unsigned int tag_size = crypto_aead_authsize(ks->decrypt); + struct scatterlist sg[MAX_SKB_FRAGS + 2]; + int ret, payload_len, nfrags; + u8 *sg_data, iv[NONCE_SIZE]; + unsigned int payload_offset; + DECLARE_CRYPTO_WAIT(wait); + struct aead_request *req; + struct sk_buff *trailer; + unsigned int sg_len; + __be32 *pid; + + payload_offset = OVPN_OP_SIZE_V2 + NONCE_WIRE_SIZE + tag_size; + payload_len = skb->len - payload_offset; + + /* sanity check on packet size, payload size must be >= 0 */ + if (unlikely(payload_len < 0)) + return -EINVAL; + + /* Prepare the skb data buffer to be accessed up until the auth tag. + * This is required because this area is directly mapped into the sg list. + */ + if (unlikely(!pskb_may_pull(skb, payload_offset))) + return -ENODATA; + + /* get number of skb frags and ensure that packet data is writable */ + nfrags = skb_cow_data(skb, 0, &trailer); + if (unlikely(nfrags < 0)) + return nfrags; + + if (unlikely(nfrags + 2 > ARRAY_SIZE(sg))) + return -ENOSPC; + + req = aead_request_alloc(ks->decrypt, GFP_KERNEL); + if (unlikely(!req)) + return -ENOMEM; + + /* sg table: + * 0: op, wire nonce (AD, len=OVPN_OP_SIZE_V2+NONCE_WIRE_SIZE), + * 1, 2, 3, ..., n: payload, + * n+1: auth_tag (len=tag_size) + */ + sg_init_table(sg, nfrags + 2); + + /* packet op is head of additional data */ + sg_data = skb->data; + sg_len = OVPN_OP_SIZE_V2 + NONCE_WIRE_SIZE; + sg_set_buf(sg, sg_data, sg_len); + + /* build scatterlist to decrypt packet payload */ + ret = skb_to_sgvec_nomark(skb, sg + 1, payload_offset, payload_len); + if (unlikely(nfrags != ret)) { + ret = -EINVAL; + goto free_req; + } + + /* append auth_tag onto scatterlist */ + sg_set_buf(sg + nfrags + 1, skb->data + sg_len, tag_size); + + /* copy nonce into IV buffer */ + memcpy(iv, skb->data + OVPN_OP_SIZE_V2, NONCE_WIRE_SIZE); + memcpy(iv + NONCE_WIRE_SIZE, ks->nonce_tail_recv.u8, + sizeof(struct ovpn_nonce_tail)); + + /* setup async crypto operation */ + aead_request_set_tfm(req, ks->decrypt); + aead_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG | + CRYPTO_TFM_REQ_MAY_SLEEP, + crypto_req_done, &wait); + aead_request_set_crypt(req, sg, sg, payload_len + tag_size, iv); + + aead_request_set_ad(req, NONCE_WIRE_SIZE + OVPN_OP_SIZE_V2); + + /* decrypt it */ + ret = crypto_wait_req(crypto_aead_decrypt(req), &wait); + if (ret < 0) { + net_err_ratelimited("%s: decrypt failed: %d\n", __func__, ret); + goto free_req; + } + + /* PID sits after the op */ + pid = (__force __be32 *)(skb->data + OVPN_OP_SIZE_V2); + ret = ovpn_pktid_recv(&ks->pid_recv, ntohl(*pid), 0); + if (unlikely(ret < 0)) + goto free_req; + + /* point to encapsulated IP packet */ + __skb_pull(skb, payload_offset); + +free_req: + aead_request_free(req); + return ret; +} + +/* Initialize a struct crypto_aead object */ +struct crypto_aead *ovpn_aead_init(const char *title, const char *alg_name, + const unsigned char *key, unsigned int keylen) +{ + struct crypto_aead *aead; + int ret; + + aead = crypto_alloc_aead(alg_name, 0, 0); + if (IS_ERR(aead)) { + ret = PTR_ERR(aead); + pr_err("%s crypto_alloc_aead failed, err=%d\n", title, ret); + aead = NULL; + goto error; + } + + ret = crypto_aead_setkey(aead, key, keylen); + if (ret) { + pr_err("%s crypto_aead_setkey size=%u failed, err=%d\n", title, keylen, ret); + goto error; + } + + ret = crypto_aead_setauthsize(aead, AUTH_TAG_SIZE); + if (ret) { + pr_err("%s crypto_aead_setauthsize failed, err=%d\n", title, ret); + goto error; + } + + /* basic AEAD assumption */ + if (crypto_aead_ivsize(aead) != NONCE_SIZE) { + pr_err("%s IV size must be %d\n", title, NONCE_SIZE); + ret = -EINVAL; + goto error; + } + + pr_debug("********* Cipher %s (%s)\n", alg_name, title); + pr_debug("*** IV size=%u\n", crypto_aead_ivsize(aead)); + pr_debug("*** req size=%u\n", crypto_aead_reqsize(aead)); + pr_debug("*** block size=%u\n", crypto_aead_blocksize(aead)); + pr_debug("*** auth size=%u\n", crypto_aead_authsize(aead)); + pr_debug("*** alignmask=0x%x\n", crypto_aead_alignmask(aead)); + + return aead; + +error: + crypto_free_aead(aead); + return ERR_PTR(ret); +} + +void ovpn_aead_crypto_key_slot_destroy(struct ovpn_crypto_key_slot *ks) +{ + if (!ks) + return; + + crypto_free_aead(ks->encrypt); + crypto_free_aead(ks->decrypt); + kfree(ks); +} + +static struct ovpn_crypto_key_slot * +ovpn_aead_crypto_key_slot_init(enum ovpn_cipher_alg alg, + const unsigned char *encrypt_key, + unsigned int encrypt_keylen, + const unsigned char *decrypt_key, + unsigned int decrypt_keylen, + const unsigned char *encrypt_nonce_tail, + unsigned int encrypt_nonce_tail_len, + const unsigned char *decrypt_nonce_tail, + unsigned int decrypt_nonce_tail_len, + u16 key_id) +{ + struct ovpn_crypto_key_slot *ks = NULL; + const char *alg_name; + int ret; + + /* validate crypto alg */ + switch (alg) { + case OVPN_CIPHER_ALG_AES_GCM: + alg_name = "gcm(aes)"; + break; + case OVPN_CIPHER_ALG_CHACHA20_POLY1305: + alg_name = "rfc7539(chacha20,poly1305)"; + break; + default: + return ERR_PTR(-EOPNOTSUPP); + } + + /* build the key slot */ + ks = kmalloc(sizeof(*ks), GFP_KERNEL); + if (!ks) + return ERR_PTR(-ENOMEM); + + ks->encrypt = NULL; + ks->decrypt = NULL; + kref_init(&ks->refcount); + ks->key_id = key_id; + + ks->encrypt = ovpn_aead_init("encrypt", alg_name, encrypt_key, + encrypt_keylen); + if (IS_ERR(ks->encrypt)) { + ret = PTR_ERR(ks->encrypt); + ks->encrypt = NULL; + goto destroy_ks; + } + + ks->decrypt = ovpn_aead_init("decrypt", alg_name, decrypt_key, + decrypt_keylen); + if (IS_ERR(ks->decrypt)) { + ret = PTR_ERR(ks->decrypt); + ks->decrypt = NULL; + goto destroy_ks; + } + + if (sizeof(struct ovpn_nonce_tail) != encrypt_nonce_tail_len || + sizeof(struct ovpn_nonce_tail) != decrypt_nonce_tail_len) { + ret = -EINVAL; + goto destroy_ks; + } + + memcpy(ks->nonce_tail_xmit.u8, encrypt_nonce_tail, + sizeof(struct ovpn_nonce_tail)); + memcpy(ks->nonce_tail_recv.u8, decrypt_nonce_tail, + sizeof(struct ovpn_nonce_tail)); + + /* init packet ID generation/validation */ + ovpn_pktid_xmit_init(&ks->pid_xmit); + ovpn_pktid_recv_init(&ks->pid_recv); + + return ks; + +destroy_ks: + ovpn_aead_crypto_key_slot_destroy(ks); + return ERR_PTR(ret); +} + +struct ovpn_crypto_key_slot * +ovpn_aead_crypto_key_slot_new(const struct ovpn_key_config *kc) +{ + return ovpn_aead_crypto_key_slot_init(kc->cipher_alg, + kc->encrypt.cipher_key, + kc->encrypt.cipher_key_size, + kc->decrypt.cipher_key, + kc->decrypt.cipher_key_size, + kc->encrypt.nonce_tail, + kc->encrypt.nonce_tail_size, + kc->decrypt.nonce_tail, + kc->decrypt.nonce_tail_size, + kc->key_id); +} diff --git a/drivers/net/ovpn-dco/crypto_aead.h b/drivers/net/ovpn-dco/crypto_aead.h new file mode 100644 index 000000000000..1e3054e7d5a4 --- /dev/null +++ b/drivers/net/ovpn-dco/crypto_aead.h @@ -0,0 +1,27 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_DCO_OVPNAEAD_H_ +#define _NET_OVPN_DCO_OVPNAEAD_H_ + +#include "crypto.h" + +#include +#include + +struct crypto_aead *ovpn_aead_init(const char *title, const char *alg_name, + const unsigned char *key, unsigned int keylen); + +int ovpn_aead_encrypt(struct ovpn_crypto_key_slot *ks, struct sk_buff *skb, u32 peer_id); +int ovpn_aead_decrypt(struct ovpn_crypto_key_slot *ks, struct sk_buff *skb); + +struct ovpn_crypto_key_slot *ovpn_aead_crypto_key_slot_new(const struct ovpn_key_config *kc); +void ovpn_aead_crypto_key_slot_destroy(struct ovpn_crypto_key_slot *ks); + +#endif /* _NET_OVPN_DCO_OVPNAEAD_H_ */ diff --git a/drivers/net/ovpn-dco/main.c b/drivers/net/ovpn-dco/main.c new file mode 100644 index 000000000000..4eb90ea7a500 --- /dev/null +++ b/drivers/net/ovpn-dco/main.c @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: Antonio Quartulli + * James Yonan + */ + +#include "main.h" + +#include "ovpn.h" +#include "ovpnstruct.h" +#include "netlink.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +/* Driver info */ +#define DRV_NAME "ovpn-dco" +#define DRV_VERSION OVPN_DCO_VERSION +#define DRV_DESCRIPTION "OpenVPN data channel offload (ovpn-dco)" +#define DRV_COPYRIGHT "(C) 2020-2022 OpenVPN, Inc." + +static void ovpn_struct_free(struct net_device *net) +{ + struct ovpn_struct *ovpn = netdev_priv(net); + + security_tun_dev_free_security(ovpn->security); + free_percpu(net->tstats); + flush_workqueue(ovpn->crypto_wq); + flush_workqueue(ovpn->events_wq); + destroy_workqueue(ovpn->crypto_wq); + destroy_workqueue(ovpn->events_wq); + rcu_barrier(); +} + +/* Net device open */ +static int ovpn_net_open(struct net_device *dev) +{ + struct in_device *dev_v4 = __in_dev_get_rtnl(dev); + + if (dev_v4) { + /* disable redirects as Linux gets confused by ovpn-dco handling same-LAN routing */ + IN_DEV_CONF_SET(dev_v4, SEND_REDIRECTS, false); + IPV4_DEVCONF_ALL(dev_net(dev), SEND_REDIRECTS) = false; + } + + netif_tx_start_all_queues(dev); + return 0; +} + +/* Net device stop -- called prior to device unload */ +static int ovpn_net_stop(struct net_device *dev) +{ + netif_tx_stop_all_queues(dev); + return 0; +} + +/******************************************* + * ovpn ethtool ops + *******************************************/ + +static int ovpn_get_link_ksettings(struct net_device *dev, + struct ethtool_link_ksettings *cmd) +{ + ethtool_convert_legacy_u32_to_link_mode(cmd->link_modes.supported, 0); + ethtool_convert_legacy_u32_to_link_mode(cmd->link_modes.advertising, 0); + cmd->base.speed = SPEED_1000; + cmd->base.duplex = DUPLEX_FULL; + cmd->base.port = PORT_TP; + cmd->base.phy_address = 0; + cmd->base.transceiver = XCVR_INTERNAL; + cmd->base.autoneg = AUTONEG_DISABLE; + + return 0; +} + +static void ovpn_get_drvinfo(struct net_device *dev, + struct ethtool_drvinfo *info) +{ + strscpy(info->driver, DRV_NAME, sizeof(info->driver)); + strscpy(info->version, DRV_VERSION, sizeof(info->version)); + strscpy(info->bus_info, "ovpn", sizeof(info->bus_info)); +} + +bool ovpn_dev_is_valid(const struct net_device *dev) +{ + return dev->netdev_ops->ndo_start_xmit == ovpn_net_xmit; +} + +/******************************************* + * ovpn exported methods + *******************************************/ + +static const struct net_device_ops ovpn_netdev_ops = { + .ndo_open = ovpn_net_open, + .ndo_stop = ovpn_net_stop, + .ndo_start_xmit = ovpn_net_xmit, + .ndo_get_stats64 = dev_get_tstats64, +}; + +static const struct ethtool_ops ovpn_ethtool_ops = { + .get_link_ksettings = ovpn_get_link_ksettings, + .get_drvinfo = ovpn_get_drvinfo, + .get_link = ethtool_op_get_link, + .get_ts_info = ethtool_op_get_ts_info, +}; + +static void ovpn_setup(struct net_device *dev) +{ + /* compute the overhead considering AEAD encryption */ + const int overhead = sizeof(u32) + NONCE_WIRE_SIZE + 16 + sizeof(struct udphdr) + + max(sizeof(struct ipv6hdr), sizeof(struct iphdr)); + + netdev_features_t feat = NETIF_F_SG | NETIF_F_LLTX | + NETIF_F_HW_CSUM | NETIF_F_RXCSUM | NETIF_F_GSO | + NETIF_F_GSO_SOFTWARE | NETIF_F_HIGHDMA; + + dev->ethtool_ops = &ovpn_ethtool_ops; + dev->needs_free_netdev = true; + + dev->netdev_ops = &ovpn_netdev_ops; + + dev->priv_destructor = ovpn_struct_free; + + /* Point-to-Point TUN Device */ + dev->hard_header_len = 0; + dev->addr_len = 0; + dev->mtu = ETH_DATA_LEN - overhead; + dev->min_mtu = IPV4_MIN_MTU; + dev->max_mtu = IP_MAX_MTU - overhead; + + /* Zero header length */ + dev->type = ARPHRD_NONE; + dev->flags = IFF_POINTOPOINT | IFF_NOARP | IFF_MULTICAST; + + dev->features |= feat; + dev->hw_features |= feat; + dev->hw_enc_features |= feat; + + dev->needed_headroom = OVPN_HEAD_ROOM; + dev->needed_tailroom = OVPN_MAX_PADDING; +} + +static const struct nla_policy ovpn_policy[IFLA_OVPN_MAX + 1] = { + [IFLA_OVPN_MODE] = NLA_POLICY_RANGE(NLA_U8, __OVPN_MODE_FIRST, + __OVPN_MODE_AFTER_LAST - 1), +}; + +static int ovpn_newlink(struct net *src_net, struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], struct netlink_ext_ack *extack) +{ + struct ovpn_struct *ovpn = netdev_priv(dev); + int ret; + + ret = security_tun_dev_create(); + if (ret < 0) + return ret; + + ret = ovpn_struct_init(dev); + if (ret < 0) + return ret; + + ovpn->mode = OVPN_MODE_P2P; + if (data && data[IFLA_OVPN_MODE]) { + ovpn->mode = nla_get_u8(data[IFLA_OVPN_MODE]); + netdev_dbg(dev, "%s: setting device (%s) mode: %u\n", __func__, dev->name, + ovpn->mode); + } + + return register_netdevice(dev); +} + +static void ovpn_dellink(struct net_device *dev, struct list_head *head) +{ + struct ovpn_struct *ovpn = netdev_priv(dev); + + switch (ovpn->mode) { + case OVPN_MODE_P2P: + ovpn_peer_release_p2p(ovpn); + break; + default: + ovpn_peers_free(ovpn); + break; + } + + unregister_netdevice_queue(dev, head); +} + +/** + * ovpn_num_queues - define number of queues to allocate per device + * + * The value returned by this function is used to decide how many RX and TX + * queues to allocate when creating the netdev object + * + * Return the number of queues to allocate + */ +static unsigned int ovpn_num_queues(void) +{ + return num_online_cpus(); +} + +static struct rtnl_link_ops ovpn_link_ops __read_mostly = { + .kind = DRV_NAME, + .priv_size = sizeof(struct ovpn_struct), + .setup = ovpn_setup, + .policy = ovpn_policy, + .maxtype = IFLA_OVPN_MAX, + .newlink = ovpn_newlink, + .dellink = ovpn_dellink, + .get_num_tx_queues = ovpn_num_queues, + .get_num_rx_queues = ovpn_num_queues, +}; + +static int __init ovpn_init(void) +{ + int err = 0; + + pr_info("%s %s -- %s\n", DRV_DESCRIPTION, DRV_VERSION, DRV_COPYRIGHT); + + /* init RTNL link ops */ + err = rtnl_link_register(&ovpn_link_ops); + if (err) { + pr_err("ovpn: can't register RTNL link ops\n"); + goto err; + } + + err = ovpn_netlink_register(); + if (err) { + pr_err("ovpn: can't register netlink family\n"); + goto err_rtnl_unregister; + } + + return 0; + +err_rtnl_unregister: + rtnl_link_unregister(&ovpn_link_ops); +err: + pr_err("ovpn: initialization failed, error status=%d\n", err); + return err; +} + +static __exit void ovpn_cleanup(void) +{ + rtnl_link_unregister(&ovpn_link_ops); + ovpn_netlink_unregister(); + rcu_barrier(); /* because we use call_rcu */ +} + +module_init(ovpn_init); +module_exit(ovpn_cleanup); + +MODULE_DESCRIPTION(DRV_DESCRIPTION); +MODULE_AUTHOR(DRV_COPYRIGHT); +MODULE_LICENSE("GPL"); +MODULE_VERSION(DRV_VERSION); +MODULE_ALIAS_RTNL_LINK(DRV_NAME); +MODULE_ALIAS_GENL_FAMILY(OVPN_NL_NAME); diff --git a/drivers/net/ovpn-dco/main.h b/drivers/net/ovpn-dco/main.h new file mode 100644 index 000000000000..c4ef200b30f4 --- /dev/null +++ b/drivers/net/ovpn-dco/main.h @@ -0,0 +1,32 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2019-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_DCO_MAIN_H_ +#define _NET_OVPN_DCO_MAIN_H_ + +#include +#include +#include +#include + +#define OVPN_DCO_VERSION "2.0.0" + +struct net_device; +bool ovpn_dev_is_valid(const struct net_device *dev); + +#define SKB_HEADER_LEN \ + (max(sizeof(struct iphdr), sizeof(struct ipv6hdr)) + \ + sizeof(struct udphdr) + NET_SKB_PAD) + +#define OVPN_HEAD_ROOM ALIGN(16 + SKB_HEADER_LEN, 4) +#define OVPN_MAX_PADDING 16 +#define OVPN_QUEUE_LEN 1024 +#define OVPN_MAX_TUN_QUEUE_LEN 0x10000 + +#endif /* _NET_OVPN_DCO_OVPN_DCO_H_ */ diff --git a/drivers/net/ovpn-dco/netlink.c b/drivers/net/ovpn-dco/netlink.c new file mode 100644 index 000000000000..ee5c943e7db4 --- /dev/null +++ b/drivers/net/ovpn-dco/netlink.c @@ -0,0 +1,1143 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: Antonio Quartulli + */ + +#include "main.h" +#include "ovpn.h" +#include "peer.h" +#include "proto.h" +#include "netlink.h" +#include "ovpnstruct.h" +#include "udp.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/** The ovpn-dco netlink family */ +static struct genl_family ovpn_netlink_family; + +enum ovpn_netlink_multicast_groups { + OVPN_MCGRP_PEERS, +}; + +static const struct genl_multicast_group ovpn_netlink_mcgrps[] = { + [OVPN_MCGRP_PEERS] = { .name = OVPN_NL_MULTICAST_GROUP_PEERS }, +}; + +/** Key direction policy. Can be used for configuring an encryption and a decryption key */ +static const struct nla_policy ovpn_netlink_policy_key_dir[OVPN_KEY_DIR_ATTR_MAX + 1] = { + [OVPN_KEY_DIR_ATTR_CIPHER_KEY] = NLA_POLICY_MAX_LEN(U8_MAX), + [OVPN_KEY_DIR_ATTR_NONCE_TAIL] = NLA_POLICY_EXACT_LEN(NONCE_TAIL_SIZE), +}; + +/** CMD_NEW_KEY policy */ +static const struct nla_policy ovpn_netlink_policy_new_key[OVPN_NEW_KEY_ATTR_MAX + 1] = { + [OVPN_NEW_KEY_ATTR_PEER_ID] = { .type = NLA_U32 }, + [OVPN_NEW_KEY_ATTR_KEY_SLOT] = NLA_POLICY_RANGE(NLA_U8, __OVPN_KEY_SLOT_FIRST, + __OVPN_KEY_SLOT_AFTER_LAST - 1), + [OVPN_NEW_KEY_ATTR_KEY_ID] = { .type = NLA_U8 }, + [OVPN_NEW_KEY_ATTR_CIPHER_ALG] = { .type = NLA_U16 }, + [OVPN_NEW_KEY_ATTR_ENCRYPT_KEY] = NLA_POLICY_NESTED(ovpn_netlink_policy_key_dir), + [OVPN_NEW_KEY_ATTR_DECRYPT_KEY] = NLA_POLICY_NESTED(ovpn_netlink_policy_key_dir), +}; + +/** CMD_DEL_KEY policy */ +static const struct nla_policy ovpn_netlink_policy_del_key[OVPN_DEL_KEY_ATTR_MAX + 1] = { + [OVPN_DEL_KEY_ATTR_PEER_ID] = { .type = NLA_U32 }, + [OVPN_DEL_KEY_ATTR_KEY_SLOT] = NLA_POLICY_RANGE(NLA_U8, __OVPN_KEY_SLOT_FIRST, + __OVPN_KEY_SLOT_AFTER_LAST - 1), +}; + +/** CMD_SWAP_KEYS policy */ +static const struct nla_policy ovpn_netlink_policy_swap_keys[OVPN_SWAP_KEYS_ATTR_MAX + 1] = { + [OVPN_SWAP_KEYS_ATTR_PEER_ID] = { .type = NLA_U32 }, +}; + +/** CMD_NEW_PEER policy */ +static const struct nla_policy ovpn_netlink_policy_new_peer[OVPN_NEW_PEER_ATTR_MAX + 1] = { + [OVPN_NEW_PEER_ATTR_PEER_ID] = { .type = NLA_U32 }, + [OVPN_NEW_PEER_ATTR_SOCKADDR_REMOTE] = NLA_POLICY_MIN_LEN(sizeof(struct sockaddr)), + [OVPN_NEW_PEER_ATTR_SOCKET] = { .type = NLA_U32 }, + [OVPN_NEW_PEER_ATTR_IPV4] = { .type = NLA_U32 }, + [OVPN_NEW_PEER_ATTR_IPV6] = NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)), + [OVPN_NEW_PEER_ATTR_LOCAL_IP] = NLA_POLICY_MAX_LEN(sizeof(struct in6_addr)), +}; + +/** CMD_SET_PEER policy */ +static const struct nla_policy ovpn_netlink_policy_set_peer[OVPN_SET_PEER_ATTR_MAX + 1] = { + [OVPN_SET_PEER_ATTR_PEER_ID] = { .type = NLA_U32 }, + [OVPN_SET_PEER_ATTR_KEEPALIVE_INTERVAL] = { .type = NLA_U32 }, + [OVPN_SET_PEER_ATTR_KEEPALIVE_TIMEOUT] = { .type = NLA_U32 }, +}; + +/** CMD_DEL_PEER policy */ +static const struct nla_policy ovpn_netlink_policy_del_peer[OVPN_DEL_PEER_ATTR_MAX + 1] = { + [OVPN_DEL_PEER_ATTR_REASON] = NLA_POLICY_RANGE(NLA_U8, __OVPN_DEL_PEER_REASON_FIRST, + __OVPN_DEL_PEER_REASON_AFTER_LAST - 1), + [OVPN_DEL_PEER_ATTR_PEER_ID] = { .type = NLA_U32 }, +}; + +/** CMD_GET_PEER policy */ +static const struct nla_policy ovpn_netlink_policy_get_peer[OVPN_GET_PEER_ATTR_MAX + 1] = { + [OVPN_GET_PEER_ATTR_PEER_ID] = { .type = NLA_U32 }, +}; + +/** CMD_PACKET polocy */ +static const struct nla_policy ovpn_netlink_policy_packet[OVPN_PACKET_ATTR_MAX + 1] = { + [OVPN_PACKET_ATTR_PEER_ID] = { .type = NLA_U32 }, + [OVPN_PACKET_ATTR_PACKET] = NLA_POLICY_MAX_LEN(U16_MAX), +}; + +/** Generic message container policy */ +static const struct nla_policy ovpn_netlink_policy[OVPN_ATTR_MAX + 1] = { + [OVPN_ATTR_IFINDEX] = { .type = NLA_U32 }, + [OVPN_ATTR_NEW_PEER] = NLA_POLICY_NESTED(ovpn_netlink_policy_new_peer), + [OVPN_ATTR_SET_PEER] = NLA_POLICY_NESTED(ovpn_netlink_policy_set_peer), + [OVPN_ATTR_DEL_PEER] = NLA_POLICY_NESTED(ovpn_netlink_policy_del_peer), + [OVPN_ATTR_GET_PEER] = NLA_POLICY_NESTED(ovpn_netlink_policy_get_peer), + [OVPN_ATTR_NEW_KEY] = NLA_POLICY_NESTED(ovpn_netlink_policy_new_key), + [OVPN_ATTR_SWAP_KEYS] = NLA_POLICY_NESTED(ovpn_netlink_policy_swap_keys), + [OVPN_ATTR_DEL_KEY] = NLA_POLICY_NESTED(ovpn_netlink_policy_del_key), + [OVPN_ATTR_PACKET] = NLA_POLICY_NESTED(ovpn_netlink_policy_packet), +}; + +static struct net_device * +ovpn_get_dev_from_attrs(struct net *net, struct nlattr **attrs) +{ + struct net_device *dev; + int ifindex; + + if (!attrs[OVPN_ATTR_IFINDEX]) + return ERR_PTR(-EINVAL); + + ifindex = nla_get_u32(attrs[OVPN_ATTR_IFINDEX]); + + dev = dev_get_by_index(net, ifindex); + if (!dev) + return ERR_PTR(-ENODEV); + + if (!ovpn_dev_is_valid(dev)) + goto err_put_dev; + + return dev; + +err_put_dev: + dev_put(dev); + + return ERR_PTR(-EINVAL); +} + +/** + * ovpn_pre_doit() - Prepare ovpn genl doit request + * @ops: requested netlink operation + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int ovpn_pre_doit(const struct genl_ops *ops, struct sk_buff *skb, + struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct net_device *dev; + + dev = ovpn_get_dev_from_attrs(net, info->attrs); + if (IS_ERR(dev)) + return PTR_ERR(dev); + + info->user_ptr[0] = netdev_priv(dev); + + return 0; +} + +/** + * ovpn_post_doit() - complete ovpn genl doit request + * @ops: requested netlink operation + * @skb: Netlink message with request data + * @info: receiver information + */ +static void ovpn_post_doit(const struct genl_ops *ops, struct sk_buff *skb, + struct genl_info *info) +{ + struct ovpn_struct *ovpn; + + ovpn = info->user_ptr[0]; + dev_put(ovpn->dev); +} + +static int ovpn_netlink_get_key_dir(struct genl_info *info, struct nlattr *key, + enum ovpn_cipher_alg cipher, + struct ovpn_key_direction *dir) +{ + struct nlattr *attr, *attrs[OVPN_KEY_DIR_ATTR_MAX + 1]; + int ret; + + ret = nla_parse_nested(attrs, OVPN_KEY_DIR_ATTR_MAX, key, NULL, info->extack); + if (ret) + return ret; + + switch (cipher) { + case OVPN_CIPHER_ALG_AES_GCM: + case OVPN_CIPHER_ALG_CHACHA20_POLY1305: + attr = attrs[OVPN_KEY_DIR_ATTR_CIPHER_KEY]; + if (!attr) + return -EINVAL; + + dir->cipher_key = nla_data(attr); + dir->cipher_key_size = nla_len(attr); + + attr = attrs[OVPN_KEY_DIR_ATTR_NONCE_TAIL]; + /* These algorithms require a 96bit nonce, + * Construct it by combining 4-bytes packet id and + * 8-bytes nonce-tail from userspace + */ + if (!attr) + return -EINVAL; + + dir->nonce_tail = nla_data(attr); + dir->nonce_tail_size = nla_len(attr); + break; + default: + return -EINVAL; + } + + return 0; +} + +static int ovpn_netlink_new_key(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_NEW_KEY_ATTR_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + struct ovpn_peer_key_reset pkr; + struct ovpn_peer *peer; + u32 peer_id; + int ret; + + if (!info->attrs[OVPN_ATTR_NEW_KEY]) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_NEW_KEY_ATTR_MAX, info->attrs[OVPN_ATTR_NEW_KEY], + NULL, info->extack); + if (ret) + return ret; + + if (!attrs[OVPN_NEW_KEY_ATTR_PEER_ID] || + !attrs[OVPN_NEW_KEY_ATTR_KEY_SLOT] || + !attrs[OVPN_NEW_KEY_ATTR_KEY_ID] || + !attrs[OVPN_NEW_KEY_ATTR_CIPHER_ALG] || + !attrs[OVPN_NEW_KEY_ATTR_ENCRYPT_KEY] || + !attrs[OVPN_NEW_KEY_ATTR_DECRYPT_KEY]) + return -EINVAL; + + peer_id = nla_get_u32(attrs[OVPN_NEW_KEY_ATTR_PEER_ID]); + pkr.slot = nla_get_u8(attrs[OVPN_NEW_KEY_ATTR_KEY_SLOT]); + pkr.key.key_id = nla_get_u16(attrs[OVPN_NEW_KEY_ATTR_KEY_ID]); + + pkr.key.cipher_alg = nla_get_u16(attrs[OVPN_NEW_KEY_ATTR_CIPHER_ALG]); + + ret = ovpn_netlink_get_key_dir(info, attrs[OVPN_NEW_KEY_ATTR_ENCRYPT_KEY], + pkr.key.cipher_alg, &pkr.key.encrypt); + if (ret < 0) + return ret; + + ret = ovpn_netlink_get_key_dir(info, attrs[OVPN_NEW_KEY_ATTR_DECRYPT_KEY], + pkr.key.cipher_alg, &pkr.key.decrypt); + if (ret < 0) + return ret; + + peer = ovpn_peer_lookup_id(ovpn, peer_id); + if (!peer) { + netdev_dbg(ovpn->dev, "%s: no peer with id %u to set key for\n", __func__, peer_id); + return -ENOENT; + } + + mutex_lock(&peer->crypto.mutex); + ret = ovpn_crypto_state_reset(&peer->crypto, &pkr); + if (ret < 0) { + netdev_dbg(ovpn->dev, "%s: cannot install new key for peer %u\n", __func__, + peer_id); + goto unlock; + } + + netdev_dbg(ovpn->dev, "%s: new key installed (id=%u) for peer %u\n", __func__, + pkr.key.key_id, peer_id); +unlock: + mutex_unlock(&peer->crypto.mutex); + ovpn_peer_put(peer); + return ret; +} + +static int ovpn_netlink_del_key(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_DEL_KEY_ATTR_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + enum ovpn_key_slot slot; + struct ovpn_peer *peer; + u32 peer_id; + int ret; + + if (!info->attrs[OVPN_ATTR_DEL_KEY]) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_DEL_KEY_ATTR_MAX, info->attrs[OVPN_ATTR_DEL_KEY], NULL, + info->extack); + if (ret) + return ret; + + if (!attrs[OVPN_DEL_KEY_ATTR_PEER_ID] || !attrs[OVPN_DEL_KEY_ATTR_KEY_SLOT]) + return -EINVAL; + + peer_id = nla_get_u32(attrs[OVPN_DEL_KEY_ATTR_PEER_ID]); + slot = nla_get_u8(attrs[OVPN_DEL_KEY_ATTR_KEY_SLOT]); + + peer = ovpn_peer_lookup_id(ovpn, peer_id); + if (!peer) + return -ENOENT; + + ovpn_crypto_key_slot_delete(&peer->crypto, slot); + ovpn_peer_put(peer); + + return 0; +} + +static int ovpn_netlink_swap_keys(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_SWAP_KEYS_ATTR_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + struct ovpn_peer *peer; + u32 peer_id; + int ret; + + if (!info->attrs[OVPN_ATTR_SWAP_KEYS]) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_SWAP_KEYS_ATTR_MAX, info->attrs[OVPN_ATTR_SWAP_KEYS], + NULL, info->extack); + if (ret) + return ret; + + if (!attrs[OVPN_SWAP_KEYS_ATTR_PEER_ID]) + return -EINVAL; + + peer_id = nla_get_u32(attrs[OVPN_SWAP_KEYS_ATTR_PEER_ID]); + + peer = ovpn_peer_lookup_id(ovpn, peer_id); + if (!peer) + return -ENOENT; + + ovpn_crypto_key_slots_swap(&peer->crypto); + ovpn_peer_put(peer); + + return 0; +} + +static int ovpn_netlink_new_peer(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_NEW_PEER_ATTR_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + struct sockaddr_storage *ss = NULL; + struct sockaddr_in mapped; + struct sockaddr_in6 *in6; + struct ovpn_peer *peer; + size_t sa_len, ip_len; + struct socket *sock; + u8 *local_ip = NULL; + u32 sockfd, id; + int ret; + + if (!info->attrs[OVPN_ATTR_NEW_PEER]) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_NEW_PEER_ATTR_MAX, info->attrs[OVPN_ATTR_NEW_PEER], NULL, + info->extack); + if (ret) + return ret; + + if (!attrs[OVPN_NEW_PEER_ATTR_PEER_ID] || !attrs[OVPN_NEW_PEER_ATTR_SOCKET] || + (!attrs[OVPN_NEW_PEER_ATTR_IPV4] && !attrs[OVPN_NEW_PEER_ATTR_IPV6])) { + netdev_dbg(ovpn->dev, "%s: basic attributes missing\n", __func__); + return -EINVAL; + } + + /* lookup the fd in the kernel table and extract the socket object */ + sockfd = nla_get_u32(attrs[OVPN_NEW_PEER_ATTR_SOCKET]); + /* sockfd_lookup() increases sock's refcounter */ + sock = sockfd_lookup(sockfd, &ret); + if (!sock) { + netdev_dbg(ovpn->dev, "%s: cannot lookup peer socket (fd=%u): %d\n", __func__, + sockfd, ret); + return -ENOTSOCK; + } + + /* Only when using UDP as transport protocol the remote endpoint must be configured + * so that ovpn-dco knows where to send packets to. + * + * In case of TCP, the socket is connected to the peer and ovpn-dco will just send bytes + * over it, without the need to specify a destination. + */ + if (sock->sk->sk_protocol == IPPROTO_UDP) { + ret = -EINVAL; + + if (!attrs[OVPN_NEW_PEER_ATTR_SOCKADDR_REMOTE]) { + netdev_err(ovpn->dev, "%s: cannot add UDP peer with no remote endpoint\n", + __func__); + goto sockfd_release; + } + + ss = nla_data(attrs[OVPN_NEW_PEER_ATTR_SOCKADDR_REMOTE]); + sa_len = nla_len(attrs[OVPN_NEW_PEER_ATTR_SOCKADDR_REMOTE]); + switch (sa_len) { + case sizeof(struct sockaddr_in): + if (ss->ss_family == AF_INET) + /* valid sockaddr */ + break; + + netdev_err(ovpn->dev, "%s: remote sockaddr_in has invalid family\n", + __func__); + goto sockfd_release; + case sizeof(struct sockaddr_in6): + if (ss->ss_family == AF_INET6) + /* valid sockaddr */ + break; + + netdev_err(ovpn->dev, "%s: remote sockaddr_in6 has invalid family\n", + __func__); + goto sockfd_release; + default: + netdev_err(ovpn->dev, "%s: invalid size for sockaddr\n", __func__); + goto sockfd_release; + } + + if (ss->ss_family == AF_INET6) { + in6 = (struct sockaddr_in6 *)ss; + + if (ipv6_addr_type(&in6->sin6_addr) & IPV6_ADDR_MAPPED) { + mapped.sin_family = AF_INET; + mapped.sin_addr.s_addr = in6->sin6_addr.s6_addr32[3]; + mapped.sin_port = in6->sin6_port; + ss = (struct sockaddr_storage *)&mapped; + } + } + + /* When using UDP we may be talking over socket bound to 0.0.0.0/::. + * In this case, if the host has multiple IPs, we need to make sure + * that outgoing traffic has as source IP the same address that the + * peer is using to reach us. + * + * Since early control packets were all forwarded to userspace, we + * need the latter to tell us what IP has to be used. + */ + if (attrs[OVPN_NEW_PEER_ATTR_LOCAL_IP]) { + ip_len = nla_len(attrs[OVPN_NEW_PEER_ATTR_LOCAL_IP]); + local_ip = nla_data(attrs[OVPN_NEW_PEER_ATTR_LOCAL_IP]); + + if (ip_len == sizeof(struct in_addr)) { + if (ss->ss_family != AF_INET) { + netdev_dbg(ovpn->dev, + "%s: the specified local IP is IPv4, but the peer endpoint is not\n", + __func__); + goto sockfd_release; + } + } else if (ip_len == sizeof(struct in6_addr)) { + bool is_mapped = ipv6_addr_type((struct in6_addr *)local_ip) & + IPV6_ADDR_MAPPED; + + if (ss->ss_family != AF_INET6 && !is_mapped) { + netdev_dbg(ovpn->dev, + "%s: the specified local IP is IPv6, but the peer endpoint is not\n", + __func__); + goto sockfd_release; + } + + if (is_mapped) + /* this is an IPv6-mapped IPv4 address, therefore extract + * the actual v4 address from the last 4 bytes + */ + local_ip += 12; + } else { + netdev_dbg(ovpn->dev, + "%s: invalid length %zu for local IP\n", __func__, + ip_len); + goto sockfd_release; + } + } + + /* sanity checks passed */ + ret = 0; + } + + id = nla_get_u32(attrs[OVPN_NEW_PEER_ATTR_PEER_ID]); + peer = ovpn_peer_new(ovpn, ss, sock, id, local_ip); + if (IS_ERR(peer)) { + netdev_err(ovpn->dev, "%s: cannot create new peer object for peer %u %pIScp\n", + __func__, id, ss); + ret = PTR_ERR(peer); + goto sockfd_release; + } + + if (attrs[OVPN_NEW_PEER_ATTR_IPV4]) { + if (nla_len(attrs[OVPN_NEW_PEER_ATTR_IPV4]) != sizeof(struct in_addr)) { + ret = -EINVAL; + goto peer_release; + } + + peer->vpn_addrs.ipv4.s_addr = nla_get_be32(attrs[OVPN_NEW_PEER_ATTR_IPV4]); + } + + if (attrs[OVPN_NEW_PEER_ATTR_IPV6]) { + if (nla_len(attrs[OVPN_NEW_PEER_ATTR_IPV6]) != sizeof(struct in6_addr)) { + ret = -EINVAL; + goto peer_release; + } + + memcpy(&peer->vpn_addrs.ipv6, nla_data(attrs[OVPN_NEW_PEER_ATTR_IPV6]), + sizeof(struct in6_addr)); + } + + netdev_dbg(ovpn->dev, + "%s: adding peer with endpoint=%pIScp/%s id=%u VPN-IPv4=%pI4 VPN-IPv6=%pI6c\n", + __func__, ss, sock->sk->sk_prot_creator->name, peer->id, + &peer->vpn_addrs.ipv4.s_addr, &peer->vpn_addrs.ipv6); + + ret = ovpn_peer_add(ovpn, peer); + if (ret < 0) { + netdev_err(ovpn->dev, "%s: cannot add new peer (id=%u) to hashtable: %d\n", + __func__, peer->id, ret); + goto peer_release; + } + + return 0; + +peer_release: + /* release right away because peer is not really used in any context */ + ovpn_peer_release(peer); + return ret; + +sockfd_release: + sockfd_put(sock); + return ret; +} + +static int ovpn_netlink_set_peer(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_SET_PEER_ATTR_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + u32 peer_id, interv, timeout; + bool keepalive_set = false; + struct ovpn_peer *peer; + int ret; + + if (!info->attrs[OVPN_ATTR_SET_PEER]) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_SET_PEER_ATTR_MAX, info->attrs[OVPN_ATTR_SET_PEER], NULL, + info->extack); + if (ret) + return ret; + + if (!attrs[OVPN_SET_PEER_ATTR_PEER_ID]) + return -EINVAL; + + peer_id = nla_get_u32(attrs[OVPN_SET_PEER_ATTR_PEER_ID]); + + peer = ovpn_peer_lookup_id(ovpn, peer_id); + if (!peer) + return -ENOENT; + + /* when setting the keepalive, both parameters have to be configured */ + if (attrs[OVPN_SET_PEER_ATTR_KEEPALIVE_INTERVAL] && + attrs[OVPN_SET_PEER_ATTR_KEEPALIVE_TIMEOUT]) { + keepalive_set = true; + interv = nla_get_u32(attrs[OVPN_SET_PEER_ATTR_KEEPALIVE_INTERVAL]); + timeout = nla_get_u32(attrs[OVPN_SET_PEER_ATTR_KEEPALIVE_TIMEOUT]); + } + + if (keepalive_set) + ovpn_peer_keepalive_set(peer, interv, timeout); + + ovpn_peer_put(peer); + return 0; +} + +static int ovpn_netlink_send_peer(struct sk_buff *skb, const struct ovpn_peer *peer, u32 portid, + u32 seq, int flags) +{ + const struct ovpn_bind *bind; + struct nlattr *attr; + void *hdr; + + hdr = genlmsg_put(skb, portid, seq, &ovpn_netlink_family, flags, OVPN_CMD_GET_PEER); + if (!hdr) { + netdev_dbg(peer->ovpn->dev, "%s: cannot create message header\n", __func__); + return -EMSGSIZE; + } + + attr = nla_nest_start(skb, OVPN_ATTR_GET_PEER); + if (!attr) { + netdev_dbg(peer->ovpn->dev, "%s: cannot create submessage\n", __func__); + goto err; + } + + if (nla_put_u32(skb, OVPN_GET_PEER_RESP_ATTR_PEER_ID, peer->id)) + goto err; + + if (peer->vpn_addrs.ipv4.s_addr != htonl(INADDR_ANY)) + if (nla_put(skb, OVPN_GET_PEER_RESP_ATTR_IPV4, sizeof(peer->vpn_addrs.ipv4), + &peer->vpn_addrs.ipv4)) + goto err; + + if (memcmp(&peer->vpn_addrs.ipv6, &in6addr_any, sizeof(peer->vpn_addrs.ipv6))) + if (nla_put(skb, OVPN_GET_PEER_RESP_ATTR_IPV6, sizeof(peer->vpn_addrs.ipv6), + &peer->vpn_addrs.ipv6)) + goto err; + + if (nla_put_u32(skb, OVPN_GET_PEER_RESP_ATTR_KEEPALIVE_INTERVAL, + peer->keepalive_interval) || + nla_put_u32(skb, OVPN_GET_PEER_RESP_ATTR_KEEPALIVE_TIMEOUT, + peer->keepalive_timeout)) + goto err; + + rcu_read_lock(); + bind = rcu_dereference(peer->bind); + if (bind) { + if (bind->sa.in4.sin_family == AF_INET) { + if (nla_put(skb, OVPN_GET_PEER_RESP_ATTR_SOCKADDR_REMOTE, + sizeof(bind->sa.in4), &bind->sa.in4) || + nla_put(skb, OVPN_GET_PEER_RESP_ATTR_LOCAL_IP, + sizeof(bind->local.ipv4), &bind->local.ipv4)) + goto err_unlock; + } else if (bind->sa.in4.sin_family == AF_INET6) { + if (nla_put(skb, OVPN_GET_PEER_RESP_ATTR_SOCKADDR_REMOTE, + sizeof(bind->sa.in6), &bind->sa.in6) || + nla_put(skb, OVPN_GET_PEER_RESP_ATTR_LOCAL_IP, + sizeof(bind->local.ipv6), &bind->local.ipv6)) + goto err_unlock; + } + } + rcu_read_unlock(); + + if (nla_put_net16(skb, OVPN_GET_PEER_RESP_ATTR_LOCAL_PORT, + inet_sk(peer->sock->sock->sk)->inet_sport) || + /* RX stats */ + nla_put_u64_64bit(skb, OVPN_GET_PEER_RESP_ATTR_RX_BYTES, + atomic64_read(&peer->stats.rx.bytes), + OVPN_GET_PEER_RESP_ATTR_UNSPEC) || + nla_put_u32(skb, OVPN_GET_PEER_RESP_ATTR_RX_PACKETS, + atomic_read(&peer->stats.rx.packets)) || + /* TX stats */ + nla_put_u64_64bit(skb, OVPN_GET_PEER_RESP_ATTR_TX_BYTES, + atomic64_read(&peer->stats.tx.bytes), + OVPN_GET_PEER_RESP_ATTR_UNSPEC) || + nla_put_u32(skb, OVPN_GET_PEER_RESP_ATTR_TX_PACKETS, + atomic_read(&peer->stats.tx.packets))) + goto err; + + nla_nest_end(skb, attr); + genlmsg_end(skb, hdr); + + return 0; +err_unlock: + rcu_read_unlock(); +err: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +static int ovpn_netlink_get_peer(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_SET_PEER_ATTR_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + struct ovpn_peer *peer; + struct sk_buff *msg; + u32 peer_id; + int ret; + + if (!info->attrs[OVPN_ATTR_GET_PEER]) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_GET_PEER_ATTR_MAX, info->attrs[OVPN_ATTR_GET_PEER], NULL, + info->extack); + if (ret) + return ret; + + if (!attrs[OVPN_GET_PEER_ATTR_PEER_ID]) + return -EINVAL; + + peer_id = nla_get_u32(attrs[OVPN_GET_PEER_ATTR_PEER_ID]); + peer = ovpn_peer_lookup_id(ovpn, peer_id); + if (!peer) + return -ENOENT; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = ovpn_netlink_send_peer(msg, peer, info->snd_portid, info->snd_seq, 0); + if (ret < 0) { + nlmsg_free(msg); + goto err; + } + + ret = genlmsg_reply(msg, info); +err: + ovpn_peer_put(peer); + return ret; +} + +static int ovpn_netlink_dump_done(struct netlink_callback *cb) +{ + struct ovpn_struct *ovpn = (struct ovpn_struct *)cb->args[0]; + + dev_put(ovpn->dev); + return 0; +} + +static int ovpn_netlink_dump_prepare(struct netlink_callback *cb) +{ + struct net *netns = sock_net(cb->skb->sk); + struct nlattr **attrbuf; + struct net_device *dev; + int ret; + + attrbuf = kcalloc(OVPN_ATTR_MAX + 1, sizeof(*attrbuf), GFP_KERNEL); + if (!attrbuf) + return -ENOMEM; + + ret = nlmsg_parse_deprecated(cb->nlh, GENL_HDRLEN, attrbuf, OVPN_ATTR_MAX, + ovpn_netlink_policy, NULL); + if (ret < 0) + goto err; + + dev = ovpn_get_dev_from_attrs(netns, attrbuf); + if (IS_ERR(dev)) { + ret = PTR_ERR(dev); + goto err; + } + + cb->args[0] = (long)netdev_priv(dev); + ret = 0; +err: + kfree(attrbuf); + return ret; +} + +static int ovpn_netlink_dump_peers(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct ovpn_struct *ovpn = (struct ovpn_struct *)cb->args[0]; + int ret, bkt, last_idx = cb->args[1], dumped = 0; + struct ovpn_peer *peer; + + if (!ovpn) { + ret = ovpn_netlink_dump_prepare(cb); + if (ret < 0) { + netdev_dbg(ovpn->dev, "%s: cannot prepare for dump: %d\n", __func__, ret); + return ret; + } + + ovpn = (struct ovpn_struct *)cb->args[0]; + } + + rcu_read_lock(); + hash_for_each_rcu(ovpn->peers.by_id, bkt, peer, hash_entry_id) { + /* skip already dumped peers that were dumped by previous invocations */ + if (last_idx > 0) { + last_idx--; + continue; + } + + if (ovpn_netlink_send_peer(skb, peer, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI) < 0) + break; + + /* count peers being dumped during this invocation */ + dumped++; + } + rcu_read_unlock(); + + /* sum up peers dumped in this message, so that at the next invocation + * we can continue from where we left + */ + cb->args[1] += dumped; + + return skb->len; +} + +static int ovpn_netlink_del_peer(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_SET_PEER_ATTR_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + struct ovpn_peer *peer; + u32 peer_id; + int ret; + + if (!info->attrs[OVPN_ATTR_DEL_PEER]) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_DEL_PEER_ATTR_MAX, info->attrs[OVPN_ATTR_DEL_PEER], NULL, + info->extack); + if (ret) + return ret; + + if (!attrs[OVPN_DEL_PEER_ATTR_PEER_ID]) + return -EINVAL; + + peer_id = nla_get_u32(attrs[OVPN_DEL_PEER_ATTR_PEER_ID]); + + peer = ovpn_peer_lookup_id(ovpn, peer_id); + if (!peer) + return -ENOENT; + + netdev_dbg(ovpn->dev, "%s: peer id=%u\n", __func__, peer->id); + ret = ovpn_peer_del(peer, OVPN_DEL_PEER_REASON_USERSPACE); + ovpn_peer_put(peer); + + return ret; +} + +static int ovpn_netlink_register_packet(struct sk_buff *skb, + struct genl_info *info) +{ + struct ovpn_struct *ovpn = info->user_ptr[0]; + + /* only one registered process per interface is allowed for now */ + if (ovpn->registered_nl_portid_set) { + netdev_dbg(ovpn->dev, "%s: userspace listener already registered\n", __func__); + return -EBUSY; + } + + netdev_dbg(ovpn->dev, "%s: registering userspace at %u\n", __func__, info->snd_portid); + + ovpn->registered_nl_portid = info->snd_portid; + ovpn->registered_nl_portid_set = true; + + return 0; +} + +static int ovpn_netlink_packet(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[OVPN_PACKET_ATTR_MAX + 1]; + struct ovpn_struct *ovpn = info->user_ptr[0]; + const u8 *packet; + u32 peer_id; + size_t len; + u8 opcode; + int ret; + + if (!info->attrs[OVPN_ATTR_PACKET]) + return -EINVAL; + + ret = nla_parse_nested(attrs, OVPN_PACKET_ATTR_MAX, info->attrs[OVPN_ATTR_PACKET], + NULL, info->extack); + if (ret) + return ret; + + if (!attrs[OVPN_PACKET_ATTR_PACKET] || !attrs[OVPN_PACKET_ATTR_PEER_ID]) { + netdev_dbg(ovpn->dev, "received netlink packet with no payload\n"); + return -EINVAL; + } + + peer_id = nla_get_u32(attrs[OVPN_PACKET_ATTR_PEER_ID]); + + len = nla_len(attrs[OVPN_PACKET_ATTR_PACKET]); + + if (len < 4 || len > ovpn->dev->mtu) { + netdev_dbg(ovpn->dev, "%s: invalid packet size %zu (min is 4, max is MTU: %u)\n", + __func__, len, ovpn->dev->mtu); + return -EINVAL; + } + + packet = nla_data(attrs[OVPN_PACKET_ATTR_PACKET]); + opcode = ovpn_opcode_from_byte(packet[0]); + + /* reject data packets from userspace as they could lead to IV reuse */ + if (opcode == OVPN_DATA_V1 || opcode == OVPN_DATA_V2) { + netdev_dbg(ovpn->dev, "%s: rejecting data packet from userspace (opcode=%u)\n", + __func__, opcode); + return -EINVAL; + } + + netdev_dbg(ovpn->dev, "%s: sending userspace packet to peer %u...\n", __func__, peer_id); + + return ovpn_send_data(ovpn, peer_id, packet, len); +} + +static const struct genl_ops ovpn_netlink_ops[] = { + { + .cmd = OVPN_CMD_NEW_PEER, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ovpn_netlink_new_peer, + }, + { + .cmd = OVPN_CMD_SET_PEER, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ovpn_netlink_set_peer, + }, + { + .cmd = OVPN_CMD_DEL_PEER, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ovpn_netlink_del_peer, + }, + { + .cmd = OVPN_CMD_GET_PEER, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ovpn_netlink_get_peer, + .dumpit = ovpn_netlink_dump_peers, + .done = ovpn_netlink_dump_done, + }, + { + .cmd = OVPN_CMD_NEW_KEY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ovpn_netlink_new_key, + }, + { + .cmd = OVPN_CMD_DEL_KEY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ovpn_netlink_del_key, + }, + { + .cmd = OVPN_CMD_SWAP_KEYS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ovpn_netlink_swap_keys, + }, + { + .cmd = OVPN_CMD_REGISTER_PACKET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ovpn_netlink_register_packet, + }, + { + .cmd = OVPN_CMD_PACKET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ovpn_netlink_packet, + }, +}; + +static struct genl_family ovpn_netlink_family __ro_after_init = { + .hdrsize = 0, + .name = OVPN_NL_NAME, + .version = 1, + .maxattr = OVPN_ATTR_MAX, + .policy = ovpn_netlink_policy, + .netnsok = true, + .pre_doit = ovpn_pre_doit, + .post_doit = ovpn_post_doit, + .module = THIS_MODULE, + .ops = ovpn_netlink_ops, + .n_ops = ARRAY_SIZE(ovpn_netlink_ops), + .mcgrps = ovpn_netlink_mcgrps, + .n_mcgrps = ARRAY_SIZE(ovpn_netlink_mcgrps), +}; + +int ovpn_netlink_notify_del_peer(struct ovpn_peer *peer) +{ + struct sk_buff *msg; + struct nlattr *attr; + void *hdr; + int ret; + + netdev_info(peer->ovpn->dev, "%s: deleting peer with id %u, reason %d\n", + peer->ovpn->dev->name, peer->id, peer->delete_reason); + + msg = nlmsg_new(100, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &ovpn_netlink_family, 0, + OVPN_CMD_DEL_PEER); + if (!hdr) { + ret = -ENOBUFS; + goto err_free_msg; + } + + if (nla_put_u32(msg, OVPN_ATTR_IFINDEX, peer->ovpn->dev->ifindex)) { + ret = -EMSGSIZE; + goto err_free_msg; + } + + attr = nla_nest_start(msg, OVPN_ATTR_DEL_PEER); + if (!attr) { + ret = -EMSGSIZE; + goto err_free_msg; + } + + if (nla_put_u8(msg, OVPN_DEL_PEER_ATTR_REASON, peer->delete_reason)) { + ret = -EMSGSIZE; + goto err_free_msg; + } + + if (nla_put_u32(msg, OVPN_DEL_PEER_ATTR_PEER_ID, peer->id)) { + ret = -EMSGSIZE; + goto err_free_msg; + } + + nla_nest_end(msg, attr); + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&ovpn_netlink_family, dev_net(peer->ovpn->dev), + msg, 0, OVPN_MCGRP_PEERS, GFP_KERNEL); + + return 0; + +err_free_msg: + nlmsg_free(msg); + return ret; +} + +int ovpn_netlink_send_packet(struct ovpn_struct *ovpn, const struct ovpn_peer *peer, + const u8 *buf, size_t len) +{ + struct nlattr *attr; + struct sk_buff *msg; + void *hdr; + int ret; + + if (!ovpn->registered_nl_portid_set) { + net_warn_ratelimited("%s: no userspace listener\n", __func__); + return 0; + } + + netdev_dbg(ovpn->dev, "%s: sending packet to userspace, len: %zd\n", __func__, len); + + msg = nlmsg_new(100 + len, GFP_ATOMIC); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &ovpn_netlink_family, 0, + OVPN_CMD_PACKET); + if (!hdr) { + ret = -ENOBUFS; + goto err_free_msg; + } + + if (nla_put_u32(msg, OVPN_ATTR_IFINDEX, ovpn->dev->ifindex)) { + ret = -EMSGSIZE; + goto err_free_msg; + } + + attr = nla_nest_start(msg, OVPN_ATTR_PACKET); + if (!attr) { + ret = -EMSGSIZE; + goto err_free_msg; + } + + if (nla_put(msg, OVPN_PACKET_ATTR_PACKET, len, buf)) { + ret = -EMSGSIZE; + goto err_free_msg; + } + + if (nla_put_u32(msg, OVPN_PACKET_ATTR_PEER_ID, peer->id)) { + ret = -EMSGSIZE; + goto err_free_msg; + } + + nla_nest_end(msg, attr); + + genlmsg_end(msg, hdr); + + return genlmsg_unicast(dev_net(ovpn->dev), msg, + ovpn->registered_nl_portid); + +err_free_msg: + nlmsg_free(msg); + return ret; +} + +static int ovpn_netlink_notify(struct notifier_block *nb, unsigned long state, + void *_notify) +{ + struct netlink_notify *notify = _notify; + struct ovpn_struct *ovpn; + struct net_device *dev; + struct net *netns; + bool found = false; + + if (state != NETLINK_URELEASE || notify->protocol != NETLINK_GENERIC) + return NOTIFY_DONE; + + rcu_read_lock(); + for_each_net_rcu(netns) { + for_each_netdev_rcu(netns, dev) { + if (!ovpn_dev_is_valid(dev)) + continue; + + ovpn = netdev_priv(dev); + if (notify->portid != ovpn->registered_nl_portid) + continue; + + found = true; + netdev_dbg(ovpn->dev, "%s: deregistering userspace listener\n", __func__); + ovpn->registered_nl_portid_set = false; + break; + } + } + rcu_read_unlock(); + + /* if no interface matched our purposes, pass the notification along */ + if (!found) + return NOTIFY_DONE; + + return NOTIFY_OK; +} + +static struct notifier_block ovpn_netlink_notifier = { + .notifier_call = ovpn_netlink_notify, +}; + +int ovpn_netlink_init(struct ovpn_struct *ovpn) +{ + ovpn->registered_nl_portid_set = false; + + return 0; +} + +/** + * ovpn_netlink_register() - register the ovpn genl netlink family + */ +int __init ovpn_netlink_register(void) +{ + int ret; + + ret = genl_register_family(&ovpn_netlink_family); + if (ret) + return ret; + + ret = netlink_register_notifier(&ovpn_netlink_notifier); + if (ret) + goto err; + + return 0; +err: + genl_unregister_family(&ovpn_netlink_family); + return ret; +} + +/** + * ovpn_netlink_unregister() - unregister the ovpn genl netlink family + */ +void __exit ovpn_netlink_unregister(void) +{ + netlink_unregister_notifier(&ovpn_netlink_notifier); + genl_unregister_family(&ovpn_netlink_family); +} diff --git a/drivers/net/ovpn-dco/netlink.h b/drivers/net/ovpn-dco/netlink.h new file mode 100644 index 000000000000..843daf052c03 --- /dev/null +++ b/drivers/net/ovpn-dco/netlink.h @@ -0,0 +1,22 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: Antonio Quartulli + */ + +#ifndef _NET_OVPN_DCO_NETLINK_H_ +#define _NET_OVPN_DCO_NETLINK_H_ + +struct ovpn_struct; +struct ovpn_peer; + +int ovpn_netlink_init(struct ovpn_struct *ovpn); +int ovpn_netlink_register(void); +void ovpn_netlink_unregister(void); +int ovpn_netlink_send_packet(struct ovpn_struct *ovpn, const struct ovpn_peer *peer, + const u8 *buf, size_t len); +int ovpn_netlink_notify_del_peer(struct ovpn_peer *peer); + +#endif /* _NET_OVPN_DCO_NETLINK_H_ */ diff --git a/drivers/net/ovpn-dco/ovpn.c b/drivers/net/ovpn-dco/ovpn.c new file mode 100644 index 000000000000..66c019174f5e --- /dev/null +++ b/drivers/net/ovpn-dco/ovpn.c @@ -0,0 +1,600 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel accelerator + * + * Copyright (C) 2019-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include "main.h" +#include "bind.h" +#include "netlink.h" +#include "ovpn.h" +#include "sock.h" +#include "peer.h" +#include "stats.h" +#include "proto.h" +#include "crypto.h" +#include "crypto_aead.h" +#include "skb.h" +#include "tcp.h" +#include "udp.h" + +#include +#include + +static const unsigned char ovpn_keepalive_message[] = { + 0x2a, 0x18, 0x7b, 0xf3, 0x64, 0x1e, 0xb4, 0xcb, + 0x07, 0xed, 0x2d, 0x0a, 0x98, 0x1f, 0xc7, 0x48 +}; + +static const unsigned char ovpn_explicit_exit_notify_message[] = { + 0x28, 0x7f, 0x34, 0x6b, 0xd4, 0xef, 0x7a, 0x81, + 0x2d, 0x56, 0xb8, 0xd3, 0xaf, 0xc5, 0x45, 0x9c, + 6 // OCC_EXIT +}; + +/* Is keepalive message? + * Assumes that single byte at skb->data is defined. + */ +static bool ovpn_is_keepalive(struct sk_buff *skb) +{ + if (*skb->data != OVPN_KEEPALIVE_FIRST_BYTE) + return false; + + if (!pskb_may_pull(skb, sizeof(ovpn_keepalive_message))) + return false; + + return !memcmp(skb->data, ovpn_keepalive_message, + sizeof(ovpn_keepalive_message)); +} + +int ovpn_struct_init(struct net_device *dev) +{ + struct ovpn_struct *ovpn = netdev_priv(dev); + int err; + + memset(ovpn, 0, sizeof(*ovpn)); + + ovpn->dev = dev; + + err = ovpn_netlink_init(ovpn); + if (err < 0) + return err; + + spin_lock_init(&ovpn->lock); + spin_lock_init(&ovpn->peers.lock); + + ovpn->crypto_wq = alloc_workqueue("ovpn-crypto-wq-%s", + WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 0, + dev->name); + if (!ovpn->crypto_wq) + return -ENOMEM; + + ovpn->events_wq = alloc_workqueue("ovpn-event-wq-%s", WQ_MEM_RECLAIM, 0, dev->name); + if (!ovpn->events_wq) + return -ENOMEM; + + dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!dev->tstats) + return -ENOMEM; + + err = security_tun_dev_alloc_security(&ovpn->security); + if (err < 0) + return err; + + /* kernel -> userspace tun queue length */ + ovpn->max_tun_queue_len = OVPN_MAX_TUN_QUEUE_LEN; + + return 0; +} + +/* Called after decrypt to write IP packet to tun netdev. + * This method is expected to manage/free skb. + */ +static void tun_netdev_write(struct ovpn_peer *peer, struct sk_buff *skb) +{ + /* packet integrity was verified on the VPN layer - no need to perform + * any additional check along the stack + */ + skb->ip_summed = CHECKSUM_UNNECESSARY; + skb->csum_level = ~0; + + /* skb hash for transport packet no longer valid after decapsulation */ + skb_clear_hash(skb); + + /* post-decrypt scrub -- prepare to inject encapsulated packet onto tun + * interface, based on __skb_tunnel_rx() in dst.h + */ + skb->dev = peer->ovpn->dev; + skb_set_queue_mapping(skb, 0); + skb_scrub_packet(skb, true); + + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + skb_probe_transport_header(skb); + skb_reset_inner_headers(skb); + + /* update per-cpu RX stats with the stored size of encrypted packet */ + + /* we are in softirq context - hence no locking nor disable preemption needed */ + dev_sw_netstats_rx_add(peer->ovpn->dev, OVPN_SKB_CB(skb)->rx_stats_size); + + /* cause packet to be "received" by tun interface */ + napi_gro_receive(&peer->napi, skb); +} + +int ovpn_napi_poll(struct napi_struct *napi, int budget) +{ + struct ovpn_peer *peer = container_of(napi, struct ovpn_peer, napi); + struct sk_buff *skb; + int work_done = 0; + + if (unlikely(budget <= 0)) + return 0; + /* this function should schedule at most 'budget' number of + * packets for delivery to the tun interface. + * If in the queue we have more packets than what allowed by the + * budget, the next polling will take care of those + */ + while ((work_done < budget) && + (skb = ptr_ring_consume_bh(&peer->netif_rx_ring))) { + tun_netdev_write(peer, skb); + work_done++; + } + + if (work_done < budget) + napi_complete_done(napi, work_done); + + return work_done; +} + +static int ovpn_transport_to_userspace(struct ovpn_struct *ovpn, const struct ovpn_peer *peer, + struct sk_buff *skb) +{ + int ret; + + ret = skb_linearize(skb); + if (ret < 0) + return ret; + + ret = ovpn_netlink_send_packet(ovpn, peer, skb->data, skb->len); + if (ret < 0) + return ret; + + consume_skb(skb); + return 0; +} + +/* Entry point for processing an incoming packet (in skb form) + * + * Enqueue the packet and schedule RX consumer. + * Reference to peer is dropped only in case of success. + * + * Return 0 if the packet was handled (and consumed) + * Return <0 in case of error (return value is error code) + */ +int ovpn_recv(struct ovpn_struct *ovpn, struct ovpn_peer *peer, struct sk_buff *skb) +{ + int ret; + + /* At this point we know the packet is from a configured peer. + * DATA_V2 packets are handled in kernel space, the rest goes to user space. + * + * Packets are sent to userspace via netlink API in order to be consistenbt across + * UDP and TCP. + */ + if (unlikely(ovpn_opcode_from_skb(skb, 0) != OVPN_DATA_V2)) { + ret = ovpn_transport_to_userspace(ovpn, peer, skb); + if (ret < 0) + return ret; + + ovpn_peer_put(peer); + return 0; + } + + ret = ptr_ring_produce_bh(&peer->rx_ring, skb); + if (unlikely(ret < 0)) + return -ENOSPC; + + if (!queue_work(ovpn->crypto_wq, &peer->decrypt_work)) + ovpn_peer_put(peer); + + return 0; +} + +static int ovpn_decrypt_one(struct ovpn_peer *peer, struct sk_buff *skb) +{ + struct ovpn_peer *allowed_peer = NULL; + struct ovpn_crypto_key_slot *ks; + unsigned int rx_stats_size; + __be16 proto; + int ret = -1; + u8 key_id; + + /* save original packet size for stats accounting */ + OVPN_SKB_CB(skb)->rx_stats_size = skb->len; + + /* get the key slot matching the key Id in the received packet */ + key_id = ovpn_key_id_from_skb(skb); + ks = ovpn_crypto_key_id_to_slot(&peer->crypto, key_id); + if (unlikely(!ks)) { + net_info_ratelimited("%s: no available key for peer %u, key-id: %u\n", __func__, + peer->id, key_id); + goto drop; + } + + /* decrypt */ + ret = ovpn_aead_decrypt(ks, skb); + + ovpn_crypto_key_slot_put(ks); + + if (unlikely(ret < 0)) { + net_err_ratelimited("%s: error during decryption for peer %u, key-id %u: %d\n", + __func__, peer->id, key_id, ret); + goto drop; + } + + /* note event of authenticated packet received for keepalive */ + ovpn_peer_keepalive_recv_reset(peer); + + /* update source and destination endpoint for this peer */ + if (peer->sock->sock->sk->sk_protocol == IPPROTO_UDP) + ovpn_peer_update_local_endpoint(peer, skb); + + /* increment RX stats */ + rx_stats_size = OVPN_SKB_CB(skb)->rx_stats_size; + ovpn_peer_stats_increment_rx(&peer->stats, rx_stats_size); + + /* check if this is a valid datapacket that has to be delivered to the + * tun interface + */ + skb_reset_network_header(skb); + proto = ovpn_ip_check_protocol(skb); + if (unlikely(!proto)) { + /* check if null packet */ + if (unlikely(!pskb_may_pull(skb, 1))) { + ret = -EINVAL; + goto drop; + } + + /* check if special OpenVPN message */ + if (ovpn_is_keepalive(skb)) { + netdev_dbg(peer->ovpn->dev, "%s: ping received from peer with id %u\n", + __func__, peer->id); + /* not an error */ + consume_skb(skb); + /* inform the caller that NAPI should not be scheduled + * for this packet + */ + return -1; + } + + ret = -EPROTONOSUPPORT; + goto drop; + } + skb->protocol = proto; + + /* perform Reverse Path Filtering (RPF) */ + allowed_peer = ovpn_peer_lookup_vpn_addr(peer->ovpn, skb, true); + if (unlikely(allowed_peer != peer)) { + ret = -EPERM; + goto drop; + } + + ret = ptr_ring_produce_bh(&peer->netif_rx_ring, skb); +drop: + if (likely(allowed_peer)) + ovpn_peer_put(allowed_peer); + + if (unlikely(ret < 0)) + kfree_skb(skb); + + return ret; +} + +/* pick packet from RX queue, decrypt and forward it to the tun device */ +void ovpn_decrypt_work(struct work_struct *work) +{ + struct ovpn_peer *peer; + struct sk_buff *skb; + + peer = container_of(work, struct ovpn_peer, decrypt_work); + while ((skb = ptr_ring_consume_bh(&peer->rx_ring))) { + if (likely(ovpn_decrypt_one(peer, skb) == 0)) { + /* if a packet has been enqueued for NAPI, signal + * availability to the networking stack + */ + local_bh_disable(); + napi_schedule(&peer->napi); + local_bh_enable(); + } + + /* give a chance to be rescheduled if needed */ + cond_resched(); + } + ovpn_peer_put(peer); +} + +static bool ovpn_encrypt_one(struct ovpn_peer *peer, struct sk_buff *skb) +{ + struct ovpn_crypto_key_slot *ks; + bool success = false; + int ret; + + /* get primary key to be used for encrypting data */ + ks = ovpn_crypto_key_slot_primary(&peer->crypto); + if (unlikely(!ks)) { + net_info_ratelimited("%s: error while retrieving primary key slot\n", __func__); + return false; + } + + if (unlikely(skb->ip_summed == CHECKSUM_PARTIAL && + skb_checksum_help(skb))) { + net_err_ratelimited("%s: cannot compute checksum for outgoing packet\n", __func__); + goto err; + } + + ovpn_peer_stats_increment_tx(&peer->stats, skb->len); + + /* encrypt */ + ret = ovpn_aead_encrypt(ks, skb, peer->id); + if (unlikely(ret < 0)) { + /* if we ran out of IVs we must kill the key as it can't be used anymore */ + if (ret == -ERANGE) { + netdev_warn(peer->ovpn->dev, + "%s: killing primary key as we ran out of IVs\n", __func__); + ovpn_crypto_kill_primary(&peer->crypto); + goto err; + } + net_err_ratelimited("%s: error during encryption for peer %u, key-id %u: %d\n", + __func__, peer->id, ks->key_id, ret); + goto err; + } + + success = true; +err: + ovpn_crypto_key_slot_put(ks); + return success; +} + +/* Process packets in TX queue in a transport-specific way. + * + * UDP transport - encrypt and send across the tunnel. + * TCP transport - encrypt and put into TCP TX queue. + */ +void ovpn_encrypt_work(struct work_struct *work) +{ + struct sk_buff *skb, *curr, *next; + struct ovpn_peer *peer; + + peer = container_of(work, struct ovpn_peer, encrypt_work); + while ((skb = ptr_ring_consume_bh(&peer->tx_ring))) { + /* this might be a GSO-segmented skb list: process each skb + * independently + */ + skb_list_walk_safe(skb, curr, next) { + /* if one segment fails encryption, we drop the entire + * packet, because it does not really make sense to send + * only part of it at this point + */ + if (unlikely(!ovpn_encrypt_one(peer, curr))) { + kfree_skb_list(skb); + skb = NULL; + break; + } + } + + /* successful encryption */ + if (skb) { + skb_list_walk_safe(skb, curr, next) { + skb_mark_not_on_list(curr); + + switch (peer->sock->sock->sk->sk_protocol) { + case IPPROTO_UDP: + ovpn_udp_send_skb(peer->ovpn, peer, curr); + break; + case IPPROTO_TCP: + ovpn_tcp_send_skb(peer, curr); + break; + default: + /* no transport configured yet */ + consume_skb(skb); + break; + } + } + + /* note event of authenticated packet xmit for keepalive */ + ovpn_peer_keepalive_xmit_reset(peer); + } + + /* give a chance to be rescheduled if needed */ + cond_resched(); + } + ovpn_peer_put(peer); +} + +/* Put skb into TX queue and schedule a consumer */ +static void ovpn_queue_skb(struct ovpn_struct *ovpn, struct sk_buff *skb, struct ovpn_peer *peer) +{ + int ret; + + if (likely(!peer)) + peer = ovpn_peer_lookup_vpn_addr(ovpn, skb, false); + if (unlikely(!peer)) { + net_dbg_ratelimited("%s: no peer to send data to\n", ovpn->dev->name); + goto drop; + } + + ret = ptr_ring_produce_bh(&peer->tx_ring, skb); + if (unlikely(ret < 0)) { + net_err_ratelimited("%s: cannot queue packet to TX ring\n", __func__); + goto drop; + } + + if (!queue_work(ovpn->crypto_wq, &peer->encrypt_work)) + ovpn_peer_put(peer); + + return; +drop: + if (peer) + ovpn_peer_put(peer); + kfree_skb_list(skb); +} + +/* Net device start xmit + */ +netdev_tx_t ovpn_net_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct ovpn_struct *ovpn = netdev_priv(dev); + struct sk_buff *segments, *tmp, *curr, *next; + struct sk_buff_head skb_list; + __be16 proto; + int ret; + + /* reset netfilter state */ + nf_reset_ct(skb); + + /* verify IP header size in network packet */ + proto = ovpn_ip_check_protocol(skb); + if (unlikely(!proto || skb->protocol != proto)) { + net_dbg_ratelimited("%s: dropping malformed payload packet\n", + dev->name); + goto drop; + } + + if (skb_is_gso(skb)) { + segments = skb_gso_segment(skb, 0); + if (IS_ERR(segments)) { + ret = PTR_ERR(segments); + net_dbg_ratelimited("%s: cannot segment packet: %d\n", dev->name, ret); + goto drop; + } + + consume_skb(skb); + skb = segments; + } + + /* from this moment on, "skb" might be a list */ + + __skb_queue_head_init(&skb_list); + skb_list_walk_safe(skb, curr, next) { + skb_mark_not_on_list(curr); + + tmp = skb_share_check(curr, GFP_ATOMIC); + if (unlikely(!tmp)) { + kfree_skb_list(next); + net_dbg_ratelimited("%s: skb_share_check failed\n", dev->name); + goto drop_list; + } + + __skb_queue_tail(&skb_list, tmp); + } + skb_list.prev->next = NULL; + + ovpn_queue_skb(ovpn, skb_list.next, NULL); + + return NETDEV_TX_OK; + +drop_list: + skb_queue_walk_safe(&skb_list, curr, next) + kfree_skb(curr); +drop: + skb_tx_error(skb); + kfree_skb_list(skb); + return NET_XMIT_DROP; +} + +/* Encrypt and transmit a special message to peer, such as keepalive + * or explicit-exit-notify. Called from softirq context. + * Assumes that caller holds a reference to peer. + */ +static void ovpn_xmit_special(struct ovpn_peer *peer, const void *data, + const unsigned int len) +{ + struct ovpn_struct *ovpn; + struct sk_buff *skb; + + ovpn = peer->ovpn; + if (unlikely(!ovpn)) + return; + + skb = alloc_skb(256 + len, GFP_ATOMIC); + if (unlikely(!skb)) + return; + + skb_reserve(skb, 128); + skb->priority = TC_PRIO_BESTEFFORT; + memcpy(__skb_put(skb, len), data, len); + + /* increase reference counter when passing peer to sending queue */ + if (!ovpn_peer_hold(peer)) { + netdev_dbg(ovpn->dev, "%s: cannot hold peer reference for sending special packet\n", + __func__); + kfree_skb(skb); + return; + } + + ovpn_queue_skb(ovpn, skb, peer); +} + +void ovpn_keepalive_xmit(struct ovpn_peer *peer) +{ + ovpn_xmit_special(peer, ovpn_keepalive_message, + sizeof(ovpn_keepalive_message)); +} + +/* Transmit explicit exit notification. + * Called from process context. + */ +void ovpn_explicit_exit_notify_xmit(struct ovpn_peer *peer) +{ + ovpn_xmit_special(peer, ovpn_explicit_exit_notify_message, + sizeof(ovpn_explicit_exit_notify_message)); +} + +/* Copy buffer into skb and send it across the tunnel. + * + * For UDP transport: just sent the skb to peer + * For TCP transport: put skb into TX queue + */ +int ovpn_send_data(struct ovpn_struct *ovpn, u32 peer_id, const u8 *data, size_t len) +{ + u16 skb_len = SKB_HEADER_LEN + len; + struct ovpn_peer *peer; + struct sk_buff *skb; + bool tcp = false; + int ret = 0; + + peer = ovpn_peer_lookup_id(ovpn, peer_id); + if (unlikely(!peer)) { + netdev_dbg(ovpn->dev, "no peer to send data to\n"); + return -EHOSTUNREACH; + } + + if (peer->sock->sock->sk->sk_protocol == IPPROTO_TCP) { + skb_len += sizeof(u16); + tcp = true; + } + + skb = alloc_skb(skb_len, GFP_ATOMIC); + if (unlikely(!skb)) { + ret = -ENOMEM; + goto out; + } + + skb_reserve(skb, SKB_HEADER_LEN); + skb_put_data(skb, data, len); + + /* prepend TCP packet with size, as required by OpenVPN protocol */ + if (tcp) { + *(__be16 *)__skb_push(skb, sizeof(u16)) = htons(len); + ovpn_queue_tcp_skb(peer, skb); + } else { + ovpn_udp_send_skb(ovpn, peer, skb); + } +out: + ovpn_peer_put(peer); + return ret; +} diff --git a/drivers/net/ovpn-dco/ovpn.h b/drivers/net/ovpn-dco/ovpn.h new file mode 100644 index 000000000000..9364fd5dd309 --- /dev/null +++ b/drivers/net/ovpn-dco/ovpn.h @@ -0,0 +1,43 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2019-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_DCO_OVPN_H_ +#define _NET_OVPN_DCO_OVPN_H_ + +#include "main.h" +#include "peer.h" +#include "sock.h" +#include "ovpnstruct.h" + +#include +#include +#include + +struct ovpn_struct; +struct net_device; + +int ovpn_struct_init(struct net_device *dev); + +u16 ovpn_select_queue(struct net_device *dev, struct sk_buff *skb, + struct net_device *sb_dev); + +void ovpn_keepalive_xmit(struct ovpn_peer *peer); +void ovpn_explicit_exit_notify_xmit(struct ovpn_peer *peer); + +netdev_tx_t ovpn_net_xmit(struct sk_buff *skb, struct net_device *dev); + +int ovpn_recv(struct ovpn_struct *ovpn, struct ovpn_peer *peer, struct sk_buff *skb); + +void ovpn_encrypt_work(struct work_struct *work); +void ovpn_decrypt_work(struct work_struct *work); +int ovpn_napi_poll(struct napi_struct *napi, int budget); + +int ovpn_send_data(struct ovpn_struct *ovpn, u32 peer_id, const u8 *data, size_t len); + +#endif /* _NET_OVPN_DCO_OVPN_H_ */ diff --git a/drivers/net/ovpn-dco/ovpnstruct.h b/drivers/net/ovpn-dco/ovpnstruct.h new file mode 100644 index 000000000000..f9bc559609cd --- /dev/null +++ b/drivers/net/ovpn-dco/ovpnstruct.h @@ -0,0 +1,59 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2019-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_DCO_OVPNSTRUCT_H_ +#define _NET_OVPN_DCO_OVPNSTRUCT_H_ + +#include "peer.h" + +#include +#include +#include + +/* Our state per ovpn interface */ +struct ovpn_struct { + /* read-mostly objects in this section */ + struct net_device *dev; + + /* device operation mode (i.e. P2P, MP) */ + enum ovpn_mode mode; + + /* protect writing to the ovpn_struct object */ + spinlock_t lock; + + /* workqueue used to schedule crypto work that may sleep */ + struct workqueue_struct *crypto_wq; + /* workqueue used to schedule generic event that may sleep or that need + * to be performed out of softirq context + */ + struct workqueue_struct *events_wq; + + /* list of known peers */ + struct { + DECLARE_HASHTABLE(by_id, 12); + DECLARE_HASHTABLE(by_transp_addr, 12); + DECLARE_HASHTABLE(by_vpn_addr, 12); + /* protects write access to any of the hashtables above */ + spinlock_t lock; + } peers; + + /* for p2p mode */ + struct ovpn_peer __rcu *peer; + + unsigned int max_tun_queue_len; + + netdev_features_t set_features; + + void *security; + + u32 registered_nl_portid; + bool registered_nl_portid_set; +}; + +#endif /* _NET_OVPN_DCO_OVPNSTRUCT_H_ */ diff --git a/drivers/net/ovpn-dco/peer.c b/drivers/net/ovpn-dco/peer.c new file mode 100644 index 000000000000..87d3f1b34c4d --- /dev/null +++ b/drivers/net/ovpn-dco/peer.c @@ -0,0 +1,906 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include "ovpn.h" +#include "bind.h" +#include "crypto.h" +#include "peer.h" +#include "netlink.h" +#include "tcp.h" + +#include +#include +#include +#include +#include + +static void ovpn_peer_ping(struct timer_list *t) +{ + struct ovpn_peer *peer = from_timer(peer, t, keepalive_xmit); + + netdev_dbg(peer->ovpn->dev, "%s: sending ping to peer %u\n", __func__, peer->id); + ovpn_keepalive_xmit(peer); +} + +static void ovpn_peer_expire(struct timer_list *t) +{ + struct ovpn_peer *peer = from_timer(peer, t, keepalive_recv); + + netdev_dbg(peer->ovpn->dev, "%s: peer %u expired\n", __func__, peer->id); + ovpn_peer_del(peer, OVPN_DEL_PEER_REASON_EXPIRED); +} + +/* Construct a new peer */ +static struct ovpn_peer *ovpn_peer_create(struct ovpn_struct *ovpn, u32 id) +{ + struct ovpn_peer *peer; + int ret; + + /* alloc and init peer object */ + peer = kzalloc(sizeof(*peer), GFP_KERNEL); + if (!peer) + return ERR_PTR(-ENOMEM); + + peer->id = id; + peer->halt = false; + peer->ovpn = ovpn; + + peer->vpn_addrs.ipv4.s_addr = htonl(INADDR_ANY); + peer->vpn_addrs.ipv6 = in6addr_any; + + RCU_INIT_POINTER(peer->bind, NULL); + ovpn_crypto_state_init(&peer->crypto); + spin_lock_init(&peer->lock); + kref_init(&peer->refcount); + ovpn_peer_stats_init(&peer->stats); + + INIT_WORK(&peer->encrypt_work, ovpn_encrypt_work); + INIT_WORK(&peer->decrypt_work, ovpn_decrypt_work); + + ret = dst_cache_init(&peer->dst_cache, GFP_KERNEL); + if (ret < 0) { + netdev_err(ovpn->dev, "%s: cannot initialize dst cache\n", __func__); + goto err; + } + + ret = ptr_ring_init(&peer->tx_ring, OVPN_QUEUE_LEN, GFP_KERNEL); + if (ret < 0) { + netdev_err(ovpn->dev, "%s: cannot allocate TX ring\n", __func__); + goto err_dst_cache; + } + + ret = ptr_ring_init(&peer->rx_ring, OVPN_QUEUE_LEN, GFP_KERNEL); + if (ret < 0) { + netdev_err(ovpn->dev, "%s: cannot allocate RX ring\n", __func__); + goto err_tx_ring; + } + + ret = ptr_ring_init(&peer->netif_rx_ring, OVPN_QUEUE_LEN, GFP_KERNEL); + if (ret < 0) { + netdev_err(ovpn->dev, "%s: cannot allocate NETIF RX ring\n", __func__); + goto err_rx_ring; + } + + /* configure and start NAPI */ + netif_tx_napi_add(ovpn->dev, &peer->napi, ovpn_napi_poll, + NAPI_POLL_WEIGHT); + napi_enable(&peer->napi); + + dev_hold(ovpn->dev); + + timer_setup(&peer->keepalive_xmit, ovpn_peer_ping, 0); + timer_setup(&peer->keepalive_recv, ovpn_peer_expire, 0); + + return peer; +err_rx_ring: + ptr_ring_cleanup(&peer->rx_ring, NULL); +err_tx_ring: + ptr_ring_cleanup(&peer->tx_ring, NULL); +err_dst_cache: + dst_cache_destroy(&peer->dst_cache); +err: + kfree(peer); + return ERR_PTR(ret); +} + +/* Reset the ovpn_sockaddr associated with a peer */ +static int ovpn_peer_reset_sockaddr(struct ovpn_peer *peer, const struct sockaddr_storage *ss, + const u8 *local_ip) +{ + struct ovpn_bind *bind; + size_t ip_len; + + /* create new ovpn_bind object */ + bind = ovpn_bind_from_sockaddr(ss); + if (IS_ERR(bind)) + return PTR_ERR(bind); + + if (local_ip) { + if (ss->ss_family == AF_INET) { + ip_len = sizeof(struct in_addr); + } else if (ss->ss_family == AF_INET6) { + ip_len = sizeof(struct in6_addr); + } else { + netdev_dbg(peer->ovpn->dev, "%s: invalid family for remote endpoint\n", + __func__); + kfree(bind); + return -EINVAL; + } + + memcpy(&bind->local, local_ip, ip_len); + } + + /* set binding */ + ovpn_bind_reset(peer, bind); + + return 0; +} + +void ovpn_peer_float(struct ovpn_peer *peer, struct sk_buff *skb) +{ + struct sockaddr_storage ss; + const u8 *local_ip = NULL; + struct sockaddr_in6 *sa6; + struct sockaddr_in *sa; + struct ovpn_bind *bind; + sa_family_t family; + + rcu_read_lock(); + bind = rcu_dereference(peer->bind); + if (unlikely(!bind)) + goto unlock; + + if (likely(ovpn_bind_skb_src_match(bind, skb))) + goto unlock; + + family = skb_protocol_to_family(skb); + + if (bind->sa.in4.sin_family == family) + local_ip = (u8 *)&bind->local; + + switch (family) { + case AF_INET: + sa = (struct sockaddr_in *)&ss; + sa->sin_family = AF_INET; + sa->sin_addr.s_addr = ip_hdr(skb)->saddr; + sa->sin_port = udp_hdr(skb)->source; + break; + case AF_INET6: + sa6 = (struct sockaddr_in6 *)&ss; + sa6->sin6_family = AF_INET6; + sa6->sin6_addr = ipv6_hdr(skb)->saddr; + sa6->sin6_port = udp_hdr(skb)->source; + sa6->sin6_scope_id = ipv6_iface_scope_id(&ipv6_hdr(skb)->saddr, skb->skb_iif); + break; + default: + goto unlock; + } + + netdev_dbg(peer->ovpn->dev, "%s: peer %d floated to %pIScp", __func__, peer->id, &ss); + ovpn_peer_reset_sockaddr(peer, (struct sockaddr_storage *)&ss, local_ip); +unlock: + rcu_read_unlock(); +} + +static void ovpn_peer_timer_delete_all(struct ovpn_peer *peer) +{ + del_timer_sync(&peer->keepalive_xmit); + del_timer_sync(&peer->keepalive_recv); +} + +static void ovpn_peer_free(struct ovpn_peer *peer) +{ + ovpn_bind_reset(peer, NULL); + ovpn_peer_timer_delete_all(peer); + + WARN_ON(!__ptr_ring_empty(&peer->tx_ring)); + ptr_ring_cleanup(&peer->tx_ring, NULL); + WARN_ON(!__ptr_ring_empty(&peer->rx_ring)); + ptr_ring_cleanup(&peer->rx_ring, NULL); + WARN_ON(!__ptr_ring_empty(&peer->netif_rx_ring)); + ptr_ring_cleanup(&peer->netif_rx_ring, NULL); + + dst_cache_destroy(&peer->dst_cache); + + dev_put(peer->ovpn->dev); + + kfree(peer); +} + +static void ovpn_peer_release_rcu(struct rcu_head *head) +{ + struct ovpn_peer *peer = container_of(head, struct ovpn_peer, rcu); + + ovpn_crypto_state_release(&peer->crypto); + ovpn_peer_free(peer); +} + +void ovpn_peer_release(struct ovpn_peer *peer) +{ + napi_disable(&peer->napi); + netif_napi_del(&peer->napi); + + if (peer->sock) + ovpn_socket_put(peer->sock); + + call_rcu(&peer->rcu, ovpn_peer_release_rcu); +} + +static void ovpn_peer_delete_work(struct work_struct *work) +{ + struct ovpn_peer *peer = container_of(work, struct ovpn_peer, + delete_work); + ovpn_peer_release(peer); + ovpn_netlink_notify_del_peer(peer); +} + +/* Use with kref_put calls, when releasing refcount + * on ovpn_peer objects. This method should only + * be called from process context with config_mutex held. + */ +void ovpn_peer_release_kref(struct kref *kref) +{ + struct ovpn_peer *peer = container_of(kref, struct ovpn_peer, refcount); + + INIT_WORK(&peer->delete_work, ovpn_peer_delete_work); + queue_work(peer->ovpn->events_wq, &peer->delete_work); +} + +struct ovpn_peer *ovpn_peer_new(struct ovpn_struct *ovpn, const struct sockaddr_storage *sa, + struct socket *sock, u32 id, uint8_t *local_ip) +{ + struct ovpn_peer *peer; + int ret; + + /* create new peer */ + peer = ovpn_peer_create(ovpn, id); + if (IS_ERR(peer)) + return peer; + + if (sock->sk->sk_protocol == IPPROTO_UDP) { + /* a UDP peer must have a remote endpoint */ + if (!sa) { + ovpn_peer_release(peer); + return ERR_PTR(-EINVAL); + } + + /* set peer sockaddr */ + ret = ovpn_peer_reset_sockaddr(peer, sa, local_ip); + if (ret < 0) { + ovpn_peer_release(peer); + return ERR_PTR(ret); + } + } + + peer->sock = ovpn_socket_new(sock, peer); + if (IS_ERR(peer->sock)) { + peer->sock = NULL; + ovpn_peer_release(peer); + return ERR_PTR(-ENOTSOCK); + } + + /* schedule initial TCP RX work only after having assigned peer->sock */ + if (peer->sock->sock->sk->sk_protocol == IPPROTO_TCP) + queue_work(peer->ovpn->events_wq, &peer->tcp.rx_work); + + return peer; +} + +/* Configure keepalive parameters */ +void ovpn_peer_keepalive_set(struct ovpn_peer *peer, u32 interval, u32 timeout) +{ + u32 delta; + + netdev_dbg(peer->ovpn->dev, + "%s: scheduling keepalive for peer %u: interval=%u timeout=%u\n", __func__, + peer->id, interval, timeout); + + peer->keepalive_interval = interval; + if (interval > 0) { + delta = msecs_to_jiffies(interval * MSEC_PER_SEC); + mod_timer(&peer->keepalive_xmit, jiffies + delta); + } else { + del_timer(&peer->keepalive_xmit); + } + + peer->keepalive_timeout = timeout; + if (timeout) { + delta = msecs_to_jiffies(timeout * MSEC_PER_SEC); + mod_timer(&peer->keepalive_recv, jiffies + delta); + } else { + del_timer(&peer->keepalive_recv); + } +} + +#define ovpn_peer_index(_tbl, _key, _key_len) \ + (jhash(_key, _key_len, 0) % HASH_SIZE(_tbl)) \ + +static struct ovpn_peer *ovpn_peer_lookup_vpn_addr4(struct hlist_head *head, __be32 *addr) +{ + struct ovpn_peer *tmp, *peer = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp, head, hash_entry_addr4) { + if (*addr != tmp->vpn_addrs.ipv4.s_addr) + continue; + + if (!ovpn_peer_hold(tmp)) + continue; + + peer = tmp; + break; + } + rcu_read_unlock(); + + return peer; +} + +static struct ovpn_peer *ovpn_peer_lookup_vpn_addr6(struct hlist_head *head, struct in6_addr *addr) +{ + struct ovpn_peer *tmp, *peer = NULL; + int i; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp, head, hash_entry_addr6) { + for (i = 0; i < 4; i++) { + if (addr->s6_addr32[i] != tmp->vpn_addrs.ipv6.s6_addr32[i]) + continue; + } + + if (!ovpn_peer_hold(tmp)) + continue; + + peer = tmp; + break; + } + rcu_read_unlock(); + + return peer; +} + +/** + * ovpn_nexthop4() - looks up the IP of the nexthop for the given destination + * + * Looks up in the IPv4 system routing table the IO of the nexthop to be used + * to reach the destination passed as argument. IF no nexthop can be found, the + * destination itself is returned as it probably has to be used as nexthop. + * + * @ovpn: the private data representing the current VPN session + * @dst: the destination to be looked up + * + * Return the IP of the next hop if found or the dst itself otherwise + */ +static __be32 ovpn_nexthop4(struct ovpn_struct *ovpn, __be32 dst) +{ + struct rtable *rt; + struct flowi4 fl = { + .daddr = dst + }; + + rt = ip_route_output_flow(dev_net(ovpn->dev), &fl, NULL); + if (IS_ERR(rt)) { + net_dbg_ratelimited("%s: no route to host %pI4\n", __func__, &dst); + /* if we end up here this packet is probably going to be thrown away later */ + return dst; + } + + if (!rt->rt_uses_gateway) + goto out; + + dst = rt->rt_gw4; +out: + ip_rt_put(rt); + return dst; +} + +/** + * ovpn_nexthop6() - looks up the IPv6 of the nexthop for the given destination + * + * Looks up in the IPv6 system routing table the IO of the nexthop to be used + * to reach the destination passed as argument. IF no nexthop can be found, the + * destination itself is returned as it probably has to be used as nexthop. + * + * @ovpn: the private data representing the current VPN session + * @dst: the destination to be looked up + * + * Return the IP of the next hop if found or the dst itself otherwise + */ +static struct in6_addr ovpn_nexthop6(struct ovpn_struct *ovpn, struct in6_addr dst) +{ +#if IS_ENABLED(CONFIG_IPV6) + struct rt6_info *rt; + struct flowi6 fl = { + .daddr = dst, + }; + + rt = (struct rt6_info *)ipv6_stub->ipv6_dst_lookup_flow(dev_net(ovpn->dev), NULL, &fl, + NULL); + if (IS_ERR(rt)) { + net_dbg_ratelimited("%s: no route to host %pI6\n", __func__, &dst); + /* if we end up here this packet is probably going to be thrown away later */ + return dst; + } + + if (!(rt->rt6i_flags & RTF_GATEWAY)) + goto out; + + dst = rt->rt6i_gateway; +out: + dst_release((struct dst_entry *)rt); +#endif + return dst; +} + +/** + * ovpn_peer_lookup_vpn_addr() - Lookup peer to send skb to + * + * This function takes a tunnel packet and looks up the peer to send it to + * after encapsulation. The skb is expected to be the in-tunnel packet, without + * any OpenVPN related header. + * + * Assume that the IP header is accessible in the skb data. + * + * @ovpn: the private data representing the current VPN session + * @skb: the skb to extract the destination address from + * + * Return the peer if found or NULL otherwise. + */ +struct ovpn_peer *ovpn_peer_lookup_vpn_addr(struct ovpn_struct *ovpn, struct sk_buff *skb, + bool use_src) +{ + struct ovpn_peer *tmp, *peer = NULL; + struct hlist_head *head; + struct rt6_info *rt6i = NULL; + struct rtable *rt = NULL; + sa_family_t sa_fam; + struct in6_addr addr6; + __be32 addr4; + u32 index; + + /* in P2P mode, no matter the destination, packets are always sent to the single peer + * listening on the other side + */ + if (ovpn->mode == OVPN_MODE_P2P) { + rcu_read_lock(); + tmp = rcu_dereference(ovpn->peer); + if (likely(tmp && ovpn_peer_hold(tmp))) + peer = tmp; + rcu_read_unlock(); + return peer; + } + + sa_fam = skb_protocol_to_family(skb); + + switch (sa_fam) { + case AF_INET: + if (use_src) + addr4 = ip_hdr(skb)->saddr; + else + addr4 = ip_hdr(skb)->daddr; + addr4 = ovpn_nexthop4(ovpn, addr4); + + index = ovpn_peer_index(ovpn->peers.by_vpn_addr, &addr4, sizeof(addr4)); + head = &ovpn->peers.by_vpn_addr[index]; + + peer = ovpn_peer_lookup_vpn_addr4(head, &addr4); + break; + case AF_INET6: + if (use_src) + addr6 = ipv6_hdr(skb)->saddr; + else + addr6 = ipv6_hdr(skb)->daddr; + addr6 = ovpn_nexthop6(ovpn, addr6); + + index = ovpn_peer_index(ovpn->peers.by_vpn_addr, &addr6, sizeof(addr6)); + head = &ovpn->peers.by_vpn_addr[index]; + + peer = ovpn_peer_lookup_vpn_addr6(head, &addr6); + break; + } + + if (rt) + ip_rt_put(rt); + if (rt6i) + dst_release((struct dst_entry *)rt6i); + + return peer; +} + +static bool ovpn_peer_transp_match(struct ovpn_peer *peer, struct sockaddr_storage *ss) +{ + struct ovpn_bind *bind = rcu_dereference(peer->bind); + struct sockaddr_in6 *sa6; + struct sockaddr_in *sa4; + + if (unlikely(!bind)) + return false; + + if (ss->ss_family != bind->sa.in4.sin_family) + return false; + + switch (ss->ss_family) { + case AF_INET: + sa4 = (struct sockaddr_in *)ss; + if (sa4->sin_addr.s_addr != bind->sa.in4.sin_addr.s_addr) + return false; + if (sa4->sin_port != bind->sa.in4.sin_port) + return false; + break; + case AF_INET6: + sa6 = (struct sockaddr_in6 *)ss; + if (memcmp(&sa6->sin6_addr, &bind->sa.in6.sin6_addr, sizeof(struct in6_addr))) + return false; + if (sa6->sin6_port != bind->sa.in6.sin6_port) + return false; + break; + default: + return false; + } + + return true; +} + +static bool ovpn_peer_skb_to_sockaddr(struct sk_buff *skb, struct sockaddr_storage *ss) +{ + struct sockaddr_in6 *sa6; + struct sockaddr_in *sa4; + + ss->ss_family = skb_protocol_to_family(skb); + switch (ss->ss_family) { + case AF_INET: + sa4 = (struct sockaddr_in *)ss; + sa4->sin_family = AF_INET; + sa4->sin_addr.s_addr = ip_hdr(skb)->saddr; + sa4->sin_port = udp_hdr(skb)->source; + break; + case AF_INET6: + sa6 = (struct sockaddr_in6 *)ss; + sa6->sin6_family = AF_INET6; + sa6->sin6_addr = ipv6_hdr(skb)->saddr; + sa6->sin6_port = udp_hdr(skb)->source; + break; + default: + return false; + } + + return true; +} + +static struct ovpn_peer *ovpn_peer_lookup_transp_addr_p2p(struct ovpn_struct *ovpn, + struct sockaddr_storage *ss) +{ + struct ovpn_peer *tmp, *peer = NULL; + + rcu_read_lock(); + tmp = rcu_dereference(ovpn->peer); + if (likely(tmp && ovpn_peer_transp_match(tmp, ss) && ovpn_peer_hold(tmp))) + peer = tmp; + rcu_read_unlock(); + + return peer; +} + +struct ovpn_peer *ovpn_peer_lookup_transp_addr(struct ovpn_struct *ovpn, struct sk_buff *skb) +{ + struct ovpn_peer *peer = NULL, *tmp; + struct sockaddr_storage ss = { 0 }; + struct hlist_head *head; + size_t sa_len; + bool found; + u32 index; + + if (unlikely(!ovpn_peer_skb_to_sockaddr(skb, &ss))) + return NULL; + + if (ovpn->mode == OVPN_MODE_P2P) + return ovpn_peer_lookup_transp_addr_p2p(ovpn, &ss); + + switch (ss.ss_family) { + case AF_INET: + sa_len = sizeof(struct sockaddr_in); + break; + case AF_INET6: + sa_len = sizeof(struct sockaddr_in6); + break; + default: + return NULL; + } + + index = ovpn_peer_index(ovpn->peers.by_transp_addr, &ss, sa_len); + head = &ovpn->peers.by_transp_addr[index]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp, head, hash_entry_transp_addr) { + found = ovpn_peer_transp_match(tmp, &ss); + if (!found) + continue; + + if (!ovpn_peer_hold(tmp)) + continue; + + peer = tmp; + break; + } + rcu_read_unlock(); + + return peer; +} + +static struct ovpn_peer *ovpn_peer_lookup_id_p2p(struct ovpn_struct *ovpn, u32 peer_id) +{ + struct ovpn_peer *tmp, *peer = NULL; + + rcu_read_lock(); + tmp = rcu_dereference(ovpn->peer); + if (likely(tmp && tmp->id == peer_id && ovpn_peer_hold(tmp))) + peer = tmp; + rcu_read_unlock(); + + return peer; +} + +struct ovpn_peer *ovpn_peer_lookup_id(struct ovpn_struct *ovpn, u32 peer_id) +{ + struct ovpn_peer *tmp, *peer = NULL; + struct hlist_head *head; + u32 index; + + if (ovpn->mode == OVPN_MODE_P2P) + return ovpn_peer_lookup_id_p2p(ovpn, peer_id); + + index = ovpn_peer_index(ovpn->peers.by_id, &peer_id, sizeof(peer_id)); + head = &ovpn->peers.by_id[index]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp, head, hash_entry_id) { + if (tmp->id != peer_id) + continue; + + if (!ovpn_peer_hold(tmp)) + continue; + + peer = tmp; + break; + } + rcu_read_unlock(); + + return peer; +} + +void ovpn_peer_update_local_endpoint(struct ovpn_peer *peer, struct sk_buff *skb) +{ + struct ovpn_bind *bind; + + rcu_read_lock(); + bind = rcu_dereference(peer->bind); + if (unlikely(!bind)) + goto unlock; + + switch (skb_protocol_to_family(skb)) { + case AF_INET: + if (unlikely(bind->local.ipv4.s_addr != ip_hdr(skb)->daddr)) { + netdev_dbg(peer->ovpn->dev, + "%s: learning local IPv4 for peer %d (%pI4 -> %pI4)\n", __func__, + peer->id, &bind->local.ipv4.s_addr, &ip_hdr(skb)->daddr); + bind->local.ipv4.s_addr = ip_hdr(skb)->daddr; + } + break; + case AF_INET6: + if (unlikely(memcmp(&bind->local.ipv6, &ipv6_hdr(skb)->daddr, + sizeof(bind->local.ipv6)))) { + netdev_dbg(peer->ovpn->dev, + "%s: learning local IPv6 for peer %d (%pI6c -> %pI6c\n", + __func__, peer->id, &bind->local.ipv6, &ipv6_hdr(skb)->daddr); + bind->local.ipv6 = ipv6_hdr(skb)->daddr; + } + break; + default: + break; + } +unlock: + rcu_read_unlock(); +} + +static int ovpn_peer_add_mp(struct ovpn_struct *ovpn, struct ovpn_peer *peer) +{ + struct sockaddr_storage sa = { 0 }; + struct sockaddr_in6 *sa6; + struct sockaddr_in *sa4; + struct ovpn_bind *bind; + struct ovpn_peer *tmp; + size_t salen; + int ret = 0; + u32 index; + + spin_lock_bh(&ovpn->peers.lock); + /* do not add duplicates */ + tmp = ovpn_peer_lookup_id(ovpn, peer->id); + if (tmp) { + ovpn_peer_put(tmp); + ret = -EEXIST; + goto unlock; + } + + hlist_del_init_rcu(&peer->hash_entry_transp_addr); + bind = rcu_dereference_protected(peer->bind, true); + /* peers connected via UDP have bind == NULL */ + if (bind) { + switch (bind->sa.in4.sin_family) { + case AF_INET: + sa4 = (struct sockaddr_in *)&sa; + + sa4->sin_family = AF_INET; + sa4->sin_addr.s_addr = bind->sa.in4.sin_addr.s_addr; + sa4->sin_port = bind->sa.in4.sin_port; + salen = sizeof(*sa4); + break; + case AF_INET6: + sa6 = (struct sockaddr_in6 *)&sa; + + sa6->sin6_family = AF_INET6; + sa6->sin6_addr = bind->sa.in6.sin6_addr; + sa6->sin6_port = bind->sa.in6.sin6_port; + salen = sizeof(*sa6); + break; + default: + ret = -EPROTONOSUPPORT; + goto unlock; + } + + index = ovpn_peer_index(ovpn->peers.by_transp_addr, &sa, salen); + hlist_add_head_rcu(&peer->hash_entry_transp_addr, + &ovpn->peers.by_transp_addr[index]); + } + + index = ovpn_peer_index(ovpn->peers.by_id, &peer->id, sizeof(peer->id)); + hlist_add_head_rcu(&peer->hash_entry_id, &ovpn->peers.by_id[index]); + + if (peer->vpn_addrs.ipv4.s_addr != htonl(INADDR_ANY)) { + index = ovpn_peer_index(ovpn->peers.by_vpn_addr, &peer->vpn_addrs.ipv4, + sizeof(peer->vpn_addrs.ipv4)); + hlist_add_head_rcu(&peer->hash_entry_addr4, &ovpn->peers.by_vpn_addr[index]); + } + + hlist_del_init_rcu(&peer->hash_entry_addr6); + if (memcmp(&peer->vpn_addrs.ipv6, &in6addr_any, sizeof(peer->vpn_addrs.ipv6))) { + index = ovpn_peer_index(ovpn->peers.by_vpn_addr, &peer->vpn_addrs.ipv6, + sizeof(peer->vpn_addrs.ipv6)); + hlist_add_head_rcu(&peer->hash_entry_addr6, &ovpn->peers.by_vpn_addr[index]); + } + +unlock: + spin_unlock_bh(&ovpn->peers.lock); + + return ret; +} + +static int ovpn_peer_add_p2p(struct ovpn_struct *ovpn, struct ovpn_peer *peer) +{ + struct ovpn_peer *tmp; + + spin_lock_bh(&ovpn->lock); + /* in p2p mode it is possible to have a single peer only, therefore the + * old one is released and substituted by the new one + */ + tmp = rcu_dereference(ovpn->peer); + if (tmp) { + tmp->delete_reason = OVPN_DEL_PEER_REASON_TEARDOWN; + ovpn_peer_put(tmp); + } + + rcu_assign_pointer(ovpn->peer, peer); + spin_unlock_bh(&ovpn->lock); + + return 0; +} + +/* assume refcounter was increased by caller */ +int ovpn_peer_add(struct ovpn_struct *ovpn, struct ovpn_peer *peer) +{ + switch (ovpn->mode) { + case OVPN_MODE_MP: + return ovpn_peer_add_mp(ovpn, peer); + case OVPN_MODE_P2P: + return ovpn_peer_add_p2p(ovpn, peer); + default: + return -EOPNOTSUPP; + } +} + +static void ovpn_peer_unhash(struct ovpn_peer *peer, enum ovpn_del_peer_reason reason) +{ + hlist_del_init_rcu(&peer->hash_entry_id); + hlist_del_init_rcu(&peer->hash_entry_addr4); + hlist_del_init_rcu(&peer->hash_entry_addr6); + hlist_del_init_rcu(&peer->hash_entry_transp_addr); + + ovpn_peer_put(peer); + peer->delete_reason = reason; +} + +static int ovpn_peer_del_mp(struct ovpn_peer *peer, enum ovpn_del_peer_reason reason) +{ + struct ovpn_peer *tmp; + int ret = 0; + + spin_lock_bh(&peer->ovpn->peers.lock); + tmp = ovpn_peer_lookup_id(peer->ovpn, peer->id); + if (tmp != peer) { + ret = -ENOENT; + goto unlock; + } + ovpn_peer_unhash(peer, reason); + +unlock: + spin_unlock_bh(&peer->ovpn->peers.lock); + + if (tmp) + ovpn_peer_put(tmp); + + return ret; +} + +static int ovpn_peer_del_p2p(struct ovpn_peer *peer, enum ovpn_del_peer_reason reason) +{ + struct ovpn_peer *tmp; + int ret = -ENOENT; + + spin_lock_bh(&peer->ovpn->lock); + tmp = rcu_dereference(peer->ovpn->peer); + if (tmp != peer) + goto unlock; + + ovpn_peer_put(tmp); + tmp->delete_reason = reason; + RCU_INIT_POINTER(peer->ovpn->peer, NULL); + ret = 0; + +unlock: + spin_unlock_bh(&peer->ovpn->lock); + + return ret; +} + +void ovpn_peer_release_p2p(struct ovpn_struct *ovpn) +{ + struct ovpn_peer *tmp; + + rcu_read_lock(); + tmp = rcu_dereference(ovpn->peer); + if (!tmp) + goto unlock; + + ovpn_peer_del_p2p(tmp, OVPN_DEL_PEER_REASON_TEARDOWN); +unlock: + rcu_read_unlock(); +} + +int ovpn_peer_del(struct ovpn_peer *peer, enum ovpn_del_peer_reason reason) +{ + switch (peer->ovpn->mode) { + case OVPN_MODE_MP: + return ovpn_peer_del_mp(peer, reason); + case OVPN_MODE_P2P: + return ovpn_peer_del_p2p(peer, reason); + default: + return -EOPNOTSUPP; + } +} + +void ovpn_peers_free(struct ovpn_struct *ovpn) +{ + struct hlist_node *tmp; + struct ovpn_peer *peer; + int bkt; + + spin_lock_bh(&ovpn->peers.lock); + hash_for_each_safe(ovpn->peers.by_id, bkt, tmp, peer, hash_entry_id) + ovpn_peer_unhash(peer, OVPN_DEL_PEER_REASON_TEARDOWN); + spin_unlock_bh(&ovpn->peers.lock); +} diff --git a/drivers/net/ovpn-dco/peer.h b/drivers/net/ovpn-dco/peer.h new file mode 100644 index 000000000000..b759c2c9da48 --- /dev/null +++ b/drivers/net/ovpn-dco/peer.h @@ -0,0 +1,168 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_DCO_OVPNPEER_H_ +#define _NET_OVPN_DCO_OVPNPEER_H_ + +#include "addr.h" +#include "bind.h" +#include "sock.h" +#include "stats.h" + +#include +#include +#include + +struct ovpn_peer { + struct ovpn_struct *ovpn; + + u32 id; + + struct { + struct in_addr ipv4; + struct in6_addr ipv6; + } vpn_addrs; + + struct hlist_node hash_entry_id; + struct hlist_node hash_entry_addr4; + struct hlist_node hash_entry_addr6; + struct hlist_node hash_entry_transp_addr; + + /* work objects to handle encryption/decryption of packets. + * these works are queued on the ovpn->crypt_wq workqueue. + */ + struct work_struct encrypt_work; + struct work_struct decrypt_work; + + struct ptr_ring tx_ring; + struct ptr_ring rx_ring; + struct ptr_ring netif_rx_ring; + + struct napi_struct napi; + + struct ovpn_socket *sock; + + /* state of the TCP reading. Needed to keep track of how much of a single packet has already + * been read from the stream and how much is missing + */ + struct { + struct ptr_ring tx_ring; + struct work_struct tx_work; + struct work_struct rx_work; + + u8 raw_len[sizeof(u16)]; + struct sk_buff *skb; + u16 offset; + u16 data_len; + struct { + void (*sk_state_change)(struct sock *sk); + void (*sk_data_ready)(struct sock *sk); + void (*sk_write_space)(struct sock *sk); + } sk_cb; + } tcp; + + struct dst_cache dst_cache; + + /* our crypto state */ + struct ovpn_crypto_state crypto; + + /* our binding to peer, protected by spinlock */ + struct ovpn_bind __rcu *bind; + + /* timer used to send periodic ping messages to the other peer, if no + * other data was sent within the past keepalive_interval seconds + */ + struct timer_list keepalive_xmit; + /* keepalive interval in seconds */ + unsigned long keepalive_interval; + + /* timer used to mark a peer as expired when no data is received for + * keepalive_timeout seconds + */ + struct timer_list keepalive_recv; + /* keepalive timeout in seconds */ + unsigned long keepalive_timeout; + + /* true if ovpn_peer_mark_delete was called */ + bool halt; + + /* per-peer rx/tx stats */ + struct ovpn_peer_stats stats; + + /* why peer was deleted - keepalive timeout, module removed etc */ + enum ovpn_del_peer_reason delete_reason; + + /* protects binding to peer (bind) and timers + * (keepalive_xmit, keepalive_expire) + */ + spinlock_t lock; + + /* needed because crypto methods can go async */ + struct kref refcount; + + /* needed to free a peer in an RCU safe way */ + struct rcu_head rcu; + + /* needed to notify userspace about deletion */ + struct work_struct delete_work; +}; + +void ovpn_peer_release_kref(struct kref *kref); +void ovpn_peer_release(struct ovpn_peer *peer); + +static inline bool ovpn_peer_hold(struct ovpn_peer *peer) +{ + return kref_get_unless_zero(&peer->refcount); +} + +static inline void ovpn_peer_put(struct ovpn_peer *peer) +{ + kref_put(&peer->refcount, ovpn_peer_release_kref); +} + +static inline void ovpn_peer_keepalive_recv_reset(struct ovpn_peer *peer) +{ + u32 delta = msecs_to_jiffies(peer->keepalive_timeout * MSEC_PER_SEC); + + if (unlikely(!delta)) + return; + + mod_timer(&peer->keepalive_recv, jiffies + delta); +} + +static inline void ovpn_peer_keepalive_xmit_reset(struct ovpn_peer *peer) +{ + u32 delta = msecs_to_jiffies(peer->keepalive_interval * MSEC_PER_SEC); + + if (unlikely(!delta)) + return; + + mod_timer(&peer->keepalive_xmit, jiffies + delta); +} + +struct ovpn_peer *ovpn_peer_new(struct ovpn_struct *ovpn, const struct sockaddr_storage *sa, + struct socket *sock, u32 id, uint8_t *local_ip); + +void ovpn_peer_keepalive_set(struct ovpn_peer *peer, u32 interval, u32 timeout); + +int ovpn_peer_add(struct ovpn_struct *ovpn, struct ovpn_peer *peer); +int ovpn_peer_del(struct ovpn_peer *peer, enum ovpn_del_peer_reason reason); +struct ovpn_peer *ovpn_peer_find(struct ovpn_struct *ovpn, u32 peer_id); +void ovpn_peer_release_p2p(struct ovpn_struct *ovpn); +void ovpn_peers_free(struct ovpn_struct *ovpn); + +struct ovpn_peer *ovpn_peer_lookup_transp_addr(struct ovpn_struct *ovpn, struct sk_buff *skb); +struct ovpn_peer *ovpn_peer_lookup_vpn_addr(struct ovpn_struct *ovpn, struct sk_buff *skb, + bool use_src); +struct ovpn_peer *ovpn_peer_lookup_id(struct ovpn_struct *ovpn, u32 peer_id); + +void ovpn_peer_update_local_endpoint(struct ovpn_peer *peer, struct sk_buff *skb); +void ovpn_peer_float(struct ovpn_peer *peer, struct sk_buff *skb); + +#endif /* _NET_OVPN_DCO_OVPNPEER_H_ */ diff --git a/drivers/net/ovpn-dco/pktid.c b/drivers/net/ovpn-dco/pktid.c new file mode 100644 index 000000000000..fcde8fba5156 --- /dev/null +++ b/drivers/net/ovpn-dco/pktid.c @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: Antonio Quartulli + * James Yonan + */ + +#include "pktid.h" + +#include +#include + +void ovpn_pktid_xmit_init(struct ovpn_pktid_xmit *pid) +{ + atomic64_set(&pid->seq_num, 1); + pid->tcp_linear = NULL; +} + +void ovpn_pktid_recv_init(struct ovpn_pktid_recv *pr) +{ + memset(pr, 0, sizeof(*pr)); + spin_lock_init(&pr->lock); +} + +/* Packet replay detection. + * Allows ID backtrack of up to REPLAY_WINDOW_SIZE - 1. + */ +int ovpn_pktid_recv(struct ovpn_pktid_recv *pr, u32 pkt_id, u32 pkt_time) +{ + const unsigned long now = jiffies; + int ret; + + spin_lock(&pr->lock); + + /* expire backtracks at or below pr->id after PKTID_RECV_EXPIRE time */ + if (unlikely(time_after_eq(now, pr->expire))) + pr->id_floor = pr->id; + + /* ID must not be zero */ + if (unlikely(pkt_id == 0)) { + ret = -EINVAL; + goto out; + } + + /* time changed? */ + if (unlikely(pkt_time != pr->time)) { + if (pkt_time > pr->time) { + /* time moved forward, accept */ + pr->base = 0; + pr->extent = 0; + pr->id = 0; + pr->time = pkt_time; + pr->id_floor = 0; + } else { + /* time moved backward, reject */ + ret = -ETIME; + goto out; + } + } + + if (likely(pkt_id == pr->id + 1)) { + /* well-formed ID sequence (incremented by 1) */ + pr->base = REPLAY_INDEX(pr->base, -1); + pr->history[pr->base / 8] |= (1 << (pr->base % 8)); + if (pr->extent < REPLAY_WINDOW_SIZE) + ++pr->extent; + pr->id = pkt_id; + } else if (pkt_id > pr->id) { + /* ID jumped forward by more than one */ + const unsigned int delta = pkt_id - pr->id; + + if (delta < REPLAY_WINDOW_SIZE) { + unsigned int i; + + pr->base = REPLAY_INDEX(pr->base, -delta); + pr->history[pr->base / 8] |= (1 << (pr->base % 8)); + pr->extent += delta; + if (pr->extent > REPLAY_WINDOW_SIZE) + pr->extent = REPLAY_WINDOW_SIZE; + for (i = 1; i < delta; ++i) { + unsigned int newb = REPLAY_INDEX(pr->base, i); + + pr->history[newb / 8] &= ~BIT(newb % 8); + } + } else { + pr->base = 0; + pr->extent = REPLAY_WINDOW_SIZE; + memset(pr->history, 0, sizeof(pr->history)); + pr->history[0] = 1; + } + pr->id = pkt_id; + } else { + /* ID backtrack */ + const unsigned int delta = pr->id - pkt_id; + + if (delta > pr->max_backtrack) + pr->max_backtrack = delta; + if (delta < pr->extent) { + if (pkt_id > pr->id_floor) { + const unsigned int ri = REPLAY_INDEX(pr->base, + delta); + u8 *p = &pr->history[ri / 8]; + const u8 mask = (1 << (ri % 8)); + + if (*p & mask) { + ret = -EINVAL; + goto out; + } + *p |= mask; + } else { + ret = -EINVAL; + goto out; + } + } else { + ret = -EINVAL; + goto out; + } + } + + pr->expire = now + PKTID_RECV_EXPIRE; + ret = 0; +out: + spin_unlock(&pr->lock); + return ret; +} diff --git a/drivers/net/ovpn-dco/pktid.h b/drivers/net/ovpn-dco/pktid.h new file mode 100644 index 000000000000..2447bb37ba55 --- /dev/null +++ b/drivers/net/ovpn-dco/pktid.h @@ -0,0 +1,116 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: Antonio Quartulli + * James Yonan + */ + +#ifndef _NET_OVPN_DCO_OVPNPKTID_H_ +#define _NET_OVPN_DCO_OVPNPKTID_H_ + +#include "main.h" + +/* When the OpenVPN protocol is run in AEAD mode, use + * the OpenVPN packet ID as the AEAD nonce: + * + * 00000005 521c3b01 4308c041 + * [seq # ] [ nonce_tail ] + * [ 12-byte full IV ] -> NONCE_SIZE + * [4-bytes -> NONCE_WIRE_SIZE + * on wire] + */ + +/* OpenVPN nonce size */ +#define NONCE_SIZE 12 +/* amount of bytes of the nonce received from user space */ +#define NONCE_TAIL_SIZE 8 + +/* OpenVPN nonce size reduced by 8-byte nonce tail -- this is the + * size of the AEAD Associated Data (AD) sent over the wire + * and is normally the head of the IV + */ +#define NONCE_WIRE_SIZE (NONCE_SIZE - sizeof(struct ovpn_nonce_tail)) + +/* If no packets received for this length of time, set a backtrack floor + * at highest received packet ID thus far. + */ +#define PKTID_RECV_EXPIRE (30 * HZ) + +/* Last 8 bytes of AEAD nonce + * Provided by userspace and usually derived from + * key material generated during TLS handshake + */ +struct ovpn_nonce_tail { + u8 u8[NONCE_TAIL_SIZE]; +}; + +/* Packet-ID state for transmitter */ +struct ovpn_pktid_xmit { + atomic64_t seq_num; + struct ovpn_tcp_linear *tcp_linear; +}; + +/* replay window sizing in bytes = 2^REPLAY_WINDOW_ORDER */ +#define REPLAY_WINDOW_ORDER 8 + +#define REPLAY_WINDOW_BYTES BIT(REPLAY_WINDOW_ORDER) +#define REPLAY_WINDOW_SIZE (REPLAY_WINDOW_BYTES * 8) +#define REPLAY_INDEX(base, i) (((base) + (i)) & (REPLAY_WINDOW_SIZE - 1)) + +/* Packet-ID state for receiver. + * Other than lock member, can be zeroed to initialize. + */ +struct ovpn_pktid_recv { + /* "sliding window" bitmask of recent packet IDs received */ + u8 history[REPLAY_WINDOW_BYTES]; + /* bit position of deque base in history */ + unsigned int base; + /* extent (in bits) of deque in history */ + unsigned int extent; + /* expiration of history in jiffies */ + unsigned long expire; + /* highest sequence number received */ + u32 id; + /* highest time stamp received */ + u32 time; + /* we will only accept backtrack IDs > id_floor */ + u32 id_floor; + unsigned int max_backtrack; + /* protects entire pktd ID state */ + spinlock_t lock; +}; + +/* Get the next packet ID for xmit */ +static inline int ovpn_pktid_xmit_next(struct ovpn_pktid_xmit *pid, u32 *pktid) +{ + const s64 seq_num = atomic64_fetch_add_unless(&pid->seq_num, 1, + 0x100000000LL); + /* when the 32bit space is over, we return an error because the packet ID is used to create + * the cipher IV and we do not want to re-use the same value more than once + */ + if (unlikely(seq_num == 0x100000000LL)) + return -ERANGE; + + *pktid = (u32)seq_num; + + return 0; +} + +/* Write 12-byte AEAD IV to dest */ +static inline void ovpn_pktid_aead_write(const u32 pktid, + const struct ovpn_nonce_tail *nt, + unsigned char *dest) +{ + *(__force __be32 *)(dest) = htonl(pktid); + BUILD_BUG_ON(4 + sizeof(struct ovpn_nonce_tail) != NONCE_SIZE); + memcpy(dest + 4, nt->u8, sizeof(struct ovpn_nonce_tail)); +} + +void ovpn_pktid_xmit_init(struct ovpn_pktid_xmit *pid); +void ovpn_pktid_recv_init(struct ovpn_pktid_recv *pr); + +int ovpn_pktid_recv(struct ovpn_pktid_recv *pr, u32 pkt_id, u32 pkt_time); + +#endif /* _NET_OVPN_DCO_OVPNPKTID_H_ */ diff --git a/drivers/net/ovpn-dco/proto.h b/drivers/net/ovpn-dco/proto.h new file mode 100644 index 000000000000..875529021c1b --- /dev/null +++ b/drivers/net/ovpn-dco/proto.h @@ -0,0 +1,101 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: Antonio Quartulli + * James Yonan + */ + +#ifndef _NET_OVPN_DCO_OVPNPROTO_H_ +#define _NET_OVPN_DCO_OVPNPROTO_H_ + +#include "main.h" + +#include + +/* Methods for operating on the initial command + * byte of the OpenVPN protocol. + */ + +/* packet opcode (high 5 bits) and key-id (low 3 bits) are combined in + * one byte + */ +#define OVPN_KEY_ID_MASK 0x07 +#define OVPN_OPCODE_SHIFT 3 +#define OVPN_OPCODE_MASK 0x1F +/* upper bounds on opcode and key ID */ +#define OVPN_KEY_ID_MAX (OVPN_KEY_ID_MASK + 1) +#define OVPN_OPCODE_MAX (OVPN_OPCODE_MASK + 1) +/* packet opcodes of interest to us */ +#define OVPN_DATA_V1 6 /* data channel V1 packet */ +#define OVPN_DATA_V2 9 /* data channel V2 packet */ +/* size of initial packet opcode */ +#define OVPN_OP_SIZE_V1 1 +#define OVPN_OP_SIZE_V2 4 +#define OVPN_PEER_ID_MASK 0x00FFFFFF +#define OVPN_PEER_ID_UNDEF 0x00FFFFFF +/* first byte of keepalive message */ +#define OVPN_KEEPALIVE_FIRST_BYTE 0x2a +/* first byte of exit message */ +#define OVPN_EXPLICIT_EXIT_NOTIFY_FIRST_BYTE 0x28 + +/** + * Extract the OP code from the specified byte + * + * Return the OP code + */ +static inline u8 ovpn_opcode_from_byte(u8 byte) +{ + return byte >> OVPN_OPCODE_SHIFT; +} + +/** + * Extract the OP code from the skb head. + * + * Note: this function assumes that the skb head was pulled enough + * to access the first byte. + * + * Return the OP code + */ +static inline u8 ovpn_opcode_from_skb(const struct sk_buff *skb, u16 offset) +{ + return ovpn_opcode_from_byte(*(skb->data + offset)); +} + +/** + * Extract the key ID from the skb head. + * + * Note: this function assumes that the skb head was pulled enough + * to access the first byte. + * + * Return the key ID + */ + +static inline u8 ovpn_key_id_from_skb(const struct sk_buff *skb) +{ + return *skb->data & OVPN_KEY_ID_MASK; +} + +/** + * Extract the peer ID from the skb head. + * + * Note: this function assumes that the skb head was pulled enough + * to access the first 4 bytes. + * + * Return the peer ID. + */ + +static inline u32 ovpn_peer_id_from_skb(const struct sk_buff *skb, u16 offset) +{ + return ntohl(*(__be32 *)(skb->data + offset)) & OVPN_PEER_ID_MASK; +} + +static inline u32 ovpn_opcode_compose(u8 opcode, u8 key_id, u32 peer_id) +{ + const u8 op = (opcode << OVPN_OPCODE_SHIFT) | (key_id & OVPN_KEY_ID_MASK); + + return (op << 24) | (peer_id & OVPN_PEER_ID_MASK); +} + +#endif /* _NET_OVPN_DCO_OVPNPROTO_H_ */ diff --git a/drivers/net/ovpn-dco/rcu.h b/drivers/net/ovpn-dco/rcu.h new file mode 100644 index 000000000000..02a50f49ba2e --- /dev/null +++ b/drivers/net/ovpn-dco/rcu.h @@ -0,0 +1,21 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2019-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_DCO_OVPNRCU_H_ +#define _NET_OVPN_DCO_OVPNRCU_H_ + +static inline void ovpn_rcu_lockdep_assert_held(void) +{ +#ifdef CONFIG_PROVE_RCU + RCU_LOCKDEP_WARN(!rcu_read_lock_held(), + "ovpn-dco RCU read lock not held"); +#endif +} + +#endif /* _NET_OVPN_DCO_OVPNRCU_H_ */ diff --git a/drivers/net/ovpn-dco/skb.h b/drivers/net/ovpn-dco/skb.h new file mode 100644 index 000000000000..d38dc2da01df --- /dev/null +++ b/drivers/net/ovpn-dco/skb.h @@ -0,0 +1,54 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: Antonio Quartulli + * James Yonan + */ + +#ifndef _NET_OVPN_DCO_SKB_H_ +#define _NET_OVPN_DCO_SKB_H_ + +#include +#include +#include +#include +#include +#include + +#define OVPN_SKB_CB(skb) ((struct ovpn_skb_cb *)&((skb)->cb)) + +struct ovpn_skb_cb { + /* original recv packet size for stats accounting */ + unsigned int rx_stats_size; + + union { + struct in_addr ipv4; + struct in6_addr ipv6; + } local; + sa_family_t sa_fam; +}; + +/* Return IP protocol version from skb header. + * Return 0 if protocol is not IPv4/IPv6 or cannot be read. + */ +static inline __be16 ovpn_ip_check_protocol(struct sk_buff *skb) +{ + __be16 proto = 0; + + /* skb could be non-linear, + * make sure IP header is in non-fragmented part + */ + if (!pskb_network_may_pull(skb, sizeof(struct iphdr))) + return 0; + + if (ip_hdr(skb)->version == 4) + proto = htons(ETH_P_IP); + else if (ip_hdr(skb)->version == 6) + proto = htons(ETH_P_IPV6); + + return proto; +} + +#endif /* _NET_OVPN_DCO_SKB_H_ */ diff --git a/drivers/net/ovpn-dco/sock.c b/drivers/net/ovpn-dco/sock.c new file mode 100644 index 000000000000..e92a4a9b952e --- /dev/null +++ b/drivers/net/ovpn-dco/sock.c @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include "main.h" +#include "ovpn.h" +#include "peer.h" +#include "sock.h" +#include "rcu.h" +#include "tcp.h" +#include "udp.h" + +#include +#include + +/* Finalize release of socket, called after RCU grace period */ +static void ovpn_socket_detach(struct socket *sock) +{ + if (!sock) + return; + + if (sock->sk->sk_protocol == IPPROTO_UDP) + ovpn_udp_socket_detach(sock); + else if (sock->sk->sk_protocol == IPPROTO_TCP) + ovpn_tcp_socket_detach(sock); + + sockfd_put(sock); +} + +void ovpn_socket_release_kref(struct kref *kref) +{ + struct ovpn_socket *sock = container_of(kref, struct ovpn_socket, refcount); + + ovpn_socket_detach(sock->sock); + kfree_rcu(sock, rcu); +} + +static bool ovpn_socket_hold(struct ovpn_socket *sock) +{ + return kref_get_unless_zero(&sock->refcount); +} + +static struct ovpn_socket *ovpn_socket_get(struct socket *sock) +{ + struct ovpn_socket *ovpn_sock; + + rcu_read_lock(); + ovpn_sock = rcu_dereference_sk_user_data(sock->sk); + if (!ovpn_socket_hold(ovpn_sock)) { + pr_warn("%s: found ovpn_socket with ref = 0\n", __func__); + ovpn_sock = NULL; + } + rcu_read_unlock(); + + return ovpn_sock; +} + +/* Finalize release of socket, called after RCU grace period */ +static int ovpn_socket_attach(struct socket *sock, struct ovpn_peer *peer) +{ + int ret = -EOPNOTSUPP; + + if (!sock || !peer) + return -EINVAL; + + if (sock->sk->sk_protocol == IPPROTO_UDP) + ret = ovpn_udp_socket_attach(sock, peer->ovpn); + else if (sock->sk->sk_protocol == IPPROTO_TCP) + ret = ovpn_tcp_socket_attach(sock, peer); + + return ret; +} + +struct ovpn_struct *ovpn_from_udp_sock(struct sock *sk) +{ + struct ovpn_socket *ovpn_sock; + + ovpn_rcu_lockdep_assert_held(); + + if (unlikely(READ_ONCE(udp_sk(sk)->encap_type) != UDP_ENCAP_OVPNINUDP)) + return NULL; + + ovpn_sock = rcu_dereference_sk_user_data(sk); + if (unlikely(!ovpn_sock)) + return NULL; + + /* make sure that sk matches our stored transport socket */ + if (unlikely(!ovpn_sock->sock || sk != ovpn_sock->sock->sk)) + return NULL; + + return ovpn_sock->ovpn; +} + +struct ovpn_socket *ovpn_socket_new(struct socket *sock, struct ovpn_peer *peer) +{ + struct ovpn_socket *ovpn_sock; + int ret; + + ret = ovpn_socket_attach(sock, peer); + if (ret < 0 && ret != -EALREADY) + return ERR_PTR(ret); + + /* if this socket is already owned by this interface, just increase the refcounter */ + if (ret == -EALREADY) { + /* caller is expected to increase the sock refcounter before passing it to this + * function. For this reason we drop it if not needed, like when this socket is + * already owned. + */ + ovpn_sock = ovpn_socket_get(sock); + sockfd_put(sock); + return ovpn_sock; + } + + ovpn_sock = kzalloc(sizeof(*ovpn_sock), GFP_KERNEL); + if (!ovpn_sock) + return ERR_PTR(-ENOMEM); + + ovpn_sock->ovpn = peer->ovpn; + ovpn_sock->sock = sock; + kref_init(&ovpn_sock->refcount); + + /* TCP sockets are per-peer, therefore they are linked to their unique peer */ + if (sock->sk->sk_protocol == IPPROTO_TCP) + ovpn_sock->peer = peer; + + rcu_assign_sk_user_data(sock->sk, ovpn_sock); + + return ovpn_sock; +} diff --git a/drivers/net/ovpn-dco/sock.h b/drivers/net/ovpn-dco/sock.h new file mode 100644 index 000000000000..9e79c1b5fe04 --- /dev/null +++ b/drivers/net/ovpn-dco/sock.h @@ -0,0 +1,54 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _NET_OVPN_DCO_SOCK_H_ +#define _NET_OVPN_DCO_SOCK_H_ + +#include +#include +#include + +#include "peer.h" + +struct ovpn_struct; + +/** + * struct ovpn_socket - a kernel socket referenced in the ovpn-dco code + */ +struct ovpn_socket { + union { + /** @ovpn: the VPN session object owning this socket (UDP only) */ + struct ovpn_struct *ovpn; + + /** @peer: the unique peer transmitting over this socket (TCP only) */ + struct ovpn_peer *peer; + }; + + /** @sock: the kernel socket */ + struct socket *sock; + + /** @refcount: amount of contexts currently referencing this object */ + struct kref refcount; + + /** @rcu: member used to schedule RCU destructor callback */ + struct rcu_head rcu; +}; + +struct ovpn_struct *ovpn_from_udp_sock(struct sock *sk); + +void ovpn_socket_release_kref(struct kref *kref); + +static inline void ovpn_socket_put(struct ovpn_socket *sock) +{ + kref_put(&sock->refcount, ovpn_socket_release_kref); +} + +struct ovpn_socket *ovpn_socket_new(struct socket *sock, struct ovpn_peer *peer); + +#endif /* _NET_OVPN_DCO_SOCK_H_ */ diff --git a/drivers/net/ovpn-dco/stats.c b/drivers/net/ovpn-dco/stats.c new file mode 100644 index 000000000000..ee000b2a2177 --- /dev/null +++ b/drivers/net/ovpn-dco/stats.c @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#include "main.h" +#include "stats.h" + +void ovpn_peer_stats_init(struct ovpn_peer_stats *ps) +{ + atomic64_set(&ps->rx.bytes, 0); + atomic_set(&ps->rx.packets, 0); + + atomic64_set(&ps->tx.bytes, 0); + atomic_set(&ps->tx.packets, 0); +} diff --git a/drivers/net/ovpn-dco/stats.h b/drivers/net/ovpn-dco/stats.h new file mode 100644 index 000000000000..3aa6bdc049c6 --- /dev/null +++ b/drivers/net/ovpn-dco/stats.h @@ -0,0 +1,67 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2020-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + * Lev Stipakov + */ + +#ifndef _NET_OVPN_DCO_OVPNSTATS_H_ +#define _NET_OVPN_DCO_OVPNSTATS_H_ + +#include +#include + +struct ovpn_struct; + +/* per-peer stats, measured on transport layer */ + +/* one stat */ +struct ovpn_peer_stat { + atomic64_t bytes; + atomic_t packets; +}; + +/* rx and tx stats, enabled by notify_per != 0 or period != 0 */ +struct ovpn_peer_stats { + struct ovpn_peer_stat rx; + struct ovpn_peer_stat tx; +}; + +/* struct for OVPN_ERR_STATS */ + +struct ovpn_err_stat { + unsigned int category; + int errcode; + u64 count; +}; + +struct ovpn_err_stats { + /* total stats, returned by kovpn */ + unsigned int total_stats; + /* number of stats dimensioned below */ + unsigned int n_stats; + struct ovpn_err_stat stats[]; +}; + +void ovpn_peer_stats_init(struct ovpn_peer_stats *ps); + +static inline void ovpn_peer_stats_increment(struct ovpn_peer_stat *stat, const unsigned int n) +{ + atomic64_add(n, &stat->bytes); + atomic_inc(&stat->packets); +} + +static inline void ovpn_peer_stats_increment_rx(struct ovpn_peer_stats *stats, const unsigned int n) +{ + ovpn_peer_stats_increment(&stats->rx, n); +} + +static inline void ovpn_peer_stats_increment_tx(struct ovpn_peer_stats *stats, const unsigned int n) +{ + ovpn_peer_stats_increment(&stats->tx, n); +} + +#endif /* _NET_OVPN_DCO_OVPNSTATS_H_ */ diff --git a/drivers/net/ovpn-dco/tcp.c b/drivers/net/ovpn-dco/tcp.c new file mode 100644 index 000000000000..7e6690fee6e7 --- /dev/null +++ b/drivers/net/ovpn-dco/tcp.c @@ -0,0 +1,326 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel accelerator + * + * Copyright (C) 2019-2022 OpenVPN, Inc. + * + * Author: Antonio Quartulli + */ + +#include "main.h" +#include "ovpnstruct.h" +#include "ovpn.h" +#include "peer.h" +#include "skb.h" +#include "tcp.h" + +#include +#include +#include + +static void ovpn_tcp_state_change(struct sock *sk) +{ +} + +static void ovpn_tcp_data_ready(struct sock *sk) +{ + struct ovpn_socket *sock; + + rcu_read_lock(); + sock = rcu_dereference_sk_user_data(sk); + rcu_read_unlock(); + + if (!sock || !sock->peer) + return; + + queue_work(sock->peer->ovpn->events_wq, &sock->peer->tcp.rx_work); +} + +static void ovpn_tcp_write_space(struct sock *sk) +{ + struct ovpn_socket *sock; + + rcu_read_lock(); + sock = rcu_dereference_sk_user_data(sk); + rcu_read_unlock(); + + if (!sock || !sock->peer) + return; + + queue_work(sock->peer->ovpn->events_wq, &sock->peer->tcp.tx_work); +} + +static void ovpn_destroy_skb(void *skb) +{ + consume_skb(skb); +} + +void ovpn_tcp_socket_detach(struct socket *sock) +{ + struct ovpn_socket *ovpn_sock; + struct ovpn_peer *peer; + + if (!sock) + return; + + rcu_read_lock(); + ovpn_sock = rcu_dereference_sk_user_data(sock->sk); + rcu_read_unlock(); + + if (!ovpn_sock->peer) + return; + + peer = ovpn_sock->peer; + + /* restore CBs that were saved in ovpn_sock_set_tcp_cb() */ + write_lock_bh(&sock->sk->sk_callback_lock); + sock->sk->sk_state_change = peer->tcp.sk_cb.sk_state_change; + sock->sk->sk_data_ready = peer->tcp.sk_cb.sk_data_ready; + sock->sk->sk_write_space = peer->tcp.sk_cb.sk_write_space; + rcu_assign_sk_user_data(sock->sk, NULL); + write_unlock_bh(&sock->sk->sk_callback_lock); + + /* cancel any ongoing work. Done after removing the CBs so that these workers cannot be + * re-armed + */ + cancel_work_sync(&peer->tcp.tx_work); + cancel_work_sync(&peer->tcp.rx_work); + + ptr_ring_cleanup(&peer->tcp.tx_ring, ovpn_destroy_skb); +} + +/* Try to send one skb (or part of it) over the TCP stream. + * + * Return 0 on success or a negative error code otherwise. + * + * Note that the skb is modified by putting away the data being sent, therefore + * the caller should check if skb->len is zero to understand if the full skb was + * sent or not. + */ +static int ovpn_tcp_send_one(struct ovpn_peer *peer, struct sk_buff *skb) +{ + struct msghdr msg = { .msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL }; + struct kvec iv = { 0 }; + int ret; + + if (skb_linearize(skb) < 0) { + net_err_ratelimited("%s: can't linearize packet\n", __func__); + return -ENOMEM; + } + + /* initialize iv structure now as skb_linearize() may have changed skb->data */ + iv.iov_base = skb->data; + iv.iov_len = skb->len; + + ret = kernel_sendmsg(peer->sock->sock, &msg, &iv, 1, iv.iov_len); + if (ret > 0) { + __skb_pull(skb, ret); + + /* since we update per-cpu stats in process context, + * we need to disable softirqs + */ + local_bh_disable(); + dev_sw_netstats_tx_add(peer->ovpn->dev, 1, ret); + local_bh_enable(); + + return 0; + } + + return ret; +} + +/* Process packets in TCP TX queue */ +static void ovpn_tcp_tx_work(struct work_struct *work) +{ + struct ovpn_peer *peer; + struct sk_buff *skb; + int ret; + + peer = container_of(work, struct ovpn_peer, tcp.tx_work); + while ((skb = __ptr_ring_peek(&peer->tcp.tx_ring))) { + ret = ovpn_tcp_send_one(peer, skb); + if (ret < 0 && ret != -EAGAIN) { + net_warn_ratelimited("%s: cannot send TCP packet to peer %u: %d\n", __func__, + peer->id, ret); + /* in case of TCP error stop sending loop and delete peer */ + ovpn_peer_del(peer, OVPN_DEL_PEER_REASON_TRANSPORT_ERROR); + break; + } else if (!skb->len) { + /* skb was entirely consumed and can now be removed from the ring */ + __ptr_ring_discard_one(&peer->tcp.tx_ring); + consume_skb(skb); + } + + /* give a chance to be rescheduled if needed */ + cond_resched(); + } +} + +static int ovpn_tcp_rx_one(struct ovpn_peer *peer) +{ + struct msghdr msg = { .msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL }; + struct ovpn_skb_cb *cb; + int status, ret; + + /* no skb allocated means that we have to read (or finish reading) the 2 bytes prefix + * containing the actual packet size. + */ + if (!peer->tcp.skb) { + struct kvec iv = { + .iov_base = peer->tcp.raw_len + peer->tcp.offset, + .iov_len = sizeof(u16) - peer->tcp.offset, + }; + + ret = kernel_recvmsg(peer->sock->sock, &msg, &iv, 1, iv.iov_len, msg.msg_flags); + if (ret <= 0) + return ret; + + peer->tcp.offset += ret; + /* the entire packet size was read, prepare skb for reading data */ + if (peer->tcp.offset == sizeof(u16)) { + u16 len = ntohs(*(__be16 *)peer->tcp.raw_len); + /* invalid packet length: this is a fatal TCP error */ + if (!len) { + netdev_err(peer->ovpn->dev, "%s: received invalid packet length\n", + __func__); + return -EINVAL; + } + + peer->tcp.skb = netdev_alloc_skb_ip_align(peer->ovpn->dev, len); + peer->tcp.offset = 0; + peer->tcp.data_len = len; + } + } else { + struct kvec iv = { + .iov_base = peer->tcp.skb->data + peer->tcp.offset, + .iov_len = peer->tcp.data_len - peer->tcp.offset, + }; + + ret = kernel_recvmsg(peer->sock->sock, &msg, &iv, 1, iv.iov_len, msg.msg_flags); + if (ret <= 0) + return ret; + + peer->tcp.offset += ret; + /* full packet received, send it up for processing */ + if (peer->tcp.offset == peer->tcp.data_len) { + /* update the skb data structure with the amount of data written by + * kernel_recvmsg() + */ + skb_put(peer->tcp.skb, peer->tcp.data_len); + + /* do not perform IP caching for TCP connections */ + cb = OVPN_SKB_CB(peer->tcp.skb); + cb->sa_fam = AF_UNSPEC; + + /* hold reference to peer as requird by ovpn_recv() */ + ovpn_peer_hold(peer); + status = ovpn_recv(peer->ovpn, peer, peer->tcp.skb); + /* skb not consumed - free it now */ + if (unlikely(status < 0)) + kfree_skb(peer->tcp.skb); + + peer->tcp.skb = NULL; + peer->tcp.offset = 0; + peer->tcp.data_len = 0; + } + } + + return ret; +} + +static void ovpn_tcp_rx_work(struct work_struct *work) +{ + struct ovpn_peer *peer = container_of(work, struct ovpn_peer, tcp.rx_work); + int ret; + + while (true) { + /* give a chance to be rescheduled if needed */ + cond_resched(); + + ret = ovpn_tcp_rx_one(peer); + if (ret <= 0) + break; + } + + if (ret < 0 && ret != -EAGAIN) + netdev_err(peer->ovpn->dev, "%s: TCP socket error: %d\n", __func__, ret); +} + +/* Put packet into TCP TX queue and schedule a consumer */ +void ovpn_queue_tcp_skb(struct ovpn_peer *peer, struct sk_buff *skb) +{ + int ret; + + ret = ptr_ring_produce_bh(&peer->tcp.tx_ring, skb); + if (ret < 0) { + kfree_skb_list(skb); + return; + } + + queue_work(peer->ovpn->events_wq, &peer->tcp.tx_work); +} + +/* Set TCP encapsulation callbacks */ +int ovpn_tcp_socket_attach(struct socket *sock, struct ovpn_peer *peer) +{ + void *old_data; + int ret; + + INIT_WORK(&peer->tcp.tx_work, ovpn_tcp_tx_work); + INIT_WORK(&peer->tcp.rx_work, ovpn_tcp_rx_work); + + ret = ptr_ring_init(&peer->tcp.tx_ring, OVPN_QUEUE_LEN, GFP_KERNEL); + if (ret < 0) { + netdev_err(peer->ovpn->dev, "cannot allocate TCP TX ring\n"); + return ret; + } + + peer->tcp.skb = NULL; + peer->tcp.offset = 0; + peer->tcp.data_len = 0; + + write_lock_bh(&sock->sk->sk_callback_lock); + + /* make sure no pre-existing encapsulation handler exists */ + rcu_read_lock(); + old_data = rcu_dereference_sk_user_data(sock->sk); + rcu_read_unlock(); + if (old_data) { + netdev_err(peer->ovpn->dev, "provided socket already taken by other user\n"); + ret = -EBUSY; + goto err; + } + + /* sanity check */ + if (sock->sk->sk_protocol != IPPROTO_TCP) { + netdev_err(peer->ovpn->dev, "expected TCP socket\n"); + ret = -EINVAL; + goto err; + } + + /* only a fully connected socket are expected. Connection should be handled in userspace */ + if (sock->sk->sk_state != TCP_ESTABLISHED) { + netdev_err(peer->ovpn->dev, "unexpected state for TCP socket: %d\n", + sock->sk->sk_state); + ret = -EINVAL; + goto err; + } + + /* save current CBs so that they can be restored upon socket release */ + peer->tcp.sk_cb.sk_state_change = sock->sk->sk_state_change; + peer->tcp.sk_cb.sk_data_ready = sock->sk->sk_data_ready; + peer->tcp.sk_cb.sk_write_space = sock->sk->sk_write_space; + + /* assign our static CBs */ + sock->sk->sk_state_change = ovpn_tcp_state_change; + sock->sk->sk_data_ready = ovpn_tcp_data_ready; + sock->sk->sk_write_space = ovpn_tcp_write_space; + + write_unlock_bh(&sock->sk->sk_callback_lock); + + return 0; +err: + write_unlock_bh(&sock->sk->sk_callback_lock); + ptr_ring_cleanup(&peer->tcp.tx_ring, NULL); + + return ret; +} diff --git a/drivers/net/ovpn-dco/tcp.h b/drivers/net/ovpn-dco/tcp.h new file mode 100644 index 000000000000..d243a8e1c34e --- /dev/null +++ b/drivers/net/ovpn-dco/tcp.h @@ -0,0 +1,38 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2019-2022 OpenVPN, Inc. + * + * Author: Antonio Quartulli + */ + +#ifndef _NET_OVPN_DCO_TCP_H_ +#define _NET_OVPN_DCO_TCP_H_ + +#include "peer.h" + +#include +#include +#include +#include + +void ovpn_queue_tcp_skb(struct ovpn_peer *peer, struct sk_buff *skb); + +int ovpn_tcp_socket_attach(struct socket *sock, struct ovpn_peer *peer); +void ovpn_tcp_socket_detach(struct socket *sock); + +/* Prepare skb and enqueue it for sending to peer. + * + * Preparation consist in prepending the skb payload with its size. + * Required by the OpenVPN protocol in order to extract packets from + * the TCP stream on the receiver side. + */ +static inline void ovpn_tcp_send_skb(struct ovpn_peer *peer, struct sk_buff *skb) +{ + u16 len = skb->len; + + *(__be16 *)__skb_push(skb, sizeof(u16)) = htons(len); + ovpn_queue_tcp_skb(peer, skb); +} + +#endif /* _NET_OVPN_DCO_TCP_H_ */ diff --git a/drivers/net/ovpn-dco/udp.c b/drivers/net/ovpn-dco/udp.c new file mode 100644 index 000000000000..afa236d1f15c --- /dev/null +++ b/drivers/net/ovpn-dco/udp.c @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: GPL-2.0 +/* OpenVPN data channel accelerator + * + * Copyright (C) 2019-2022 OpenVPN, Inc. + * + * Author: Antonio Quartulli + */ + +#include "main.h" +#include "bind.h" +#include "ovpn.h" +#include "ovpnstruct.h" +#include "peer.h" +#include "proto.h" +#include "skb.h" +#include "udp.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * ovpn_udp_encap_recv() - Start processing a received UDP packet. + * If the first byte of the payload is DATA_V2, the packet is further processed, + * otherwise it is forwarded to the UDP stack for delivery to user space. + * + * @sk: the socket the packet was received on + * @skb: the sk_buff containing the actual packet + * + * Return codes: + * 0 : we consumed or dropped packet + * >0 : skb should be passed up to userspace as UDP (packet not consumed) + * <0 : skb should be resubmitted as proto -N (packet not consumed) + */ +static int ovpn_udp_encap_recv(struct sock *sk, struct sk_buff *skb) +{ + struct ovpn_peer *peer = NULL; + struct ovpn_struct *ovpn; + u32 peer_id; + u8 opcode; + int ret; + + ovpn = ovpn_from_udp_sock(sk); + if (unlikely(!ovpn)) { + net_err_ratelimited("%s: cannot obtain ovpn object from UDP socket\n", __func__); + goto drop; + } + + /* Make sure the first 4 bytes of the skb data buffer after the UDP header are accessible. + * They are required to fetch the OP code, the key ID and the peer ID. + */ + if (unlikely(!pskb_may_pull(skb, sizeof(struct udphdr) + 4))) { + net_dbg_ratelimited("%s: packet too small\n", __func__); + goto drop; + } + + opcode = ovpn_opcode_from_skb(skb, sizeof(struct udphdr)); + if (likely(opcode == OVPN_DATA_V2)) { + peer_id = ovpn_peer_id_from_skb(skb, sizeof(struct udphdr)); + /* some OpenVPN server implementations send data packets with the peer-id set to + * undef. In this case we skip the peer lookup by peer-id and we try with the + * transport address + */ + if (peer_id != OVPN_PEER_ID_UNDEF) { + peer = ovpn_peer_lookup_id(ovpn, peer_id); + if (!peer) { + net_err_ratelimited("%s: received data from unknown peer (id: %d)\n", + __func__, peer_id); + goto drop; + } + + /* check if this peer changed it's IP address and update state */ + ovpn_peer_float(peer, skb); + } + } + + if (!peer) { + /* might be a control packet or a data packet with undef peer-id */ + peer = ovpn_peer_lookup_transp_addr(ovpn, skb); + if (unlikely(!peer)) { + if (opcode != OVPN_DATA_V2) { + netdev_dbg(ovpn->dev, + "%s: control packet from unknown peer, sending to userspace", + __func__); + return 1; + } + + netdev_dbg(ovpn->dev, + "%s: received data with undef peer-id from unknown source\n", + __func__); + goto drop; + } + } + + /* pop off outer UDP header */ + __skb_pull(skb, sizeof(struct udphdr)); + + ret = ovpn_recv(ovpn, peer, skb); + if (unlikely(ret < 0)) { + net_err_ratelimited("%s: cannot handle incoming packet: %d\n", __func__, ret); + goto drop; + } + + /* should this be a non DATA_V2 packet, ret will be >0 and this will instruct the UDP + * stack to continue processing this packet as usual (i.e. deliver to user space) + */ + return ret; + +drop: + if (peer) + ovpn_peer_put(peer); + kfree_skb(skb); + return 0; +} + +static int ovpn_udp4_output(struct ovpn_struct *ovpn, struct ovpn_bind *bind, + struct dst_cache *cache, struct sock *sk, + struct sk_buff *skb) +{ + struct rtable *rt; + struct flowi4 fl = { + .saddr = bind->local.ipv4.s_addr, + .daddr = bind->sa.in4.sin_addr.s_addr, + .fl4_sport = inet_sk(sk)->inet_sport, + .fl4_dport = bind->sa.in4.sin_port, + .flowi4_proto = sk->sk_protocol, + .flowi4_mark = sk->sk_mark, + }; + int ret; + + local_bh_disable(); + rt = dst_cache_get_ip4(cache, &fl.saddr); + if (rt) + goto transmit; + + if (unlikely(!inet_confirm_addr(sock_net(sk), NULL, 0, fl.saddr, RT_SCOPE_HOST))) { + /* we may end up here when the cached address is not usable anymore. + * In this case we reset address/cache and perform a new look up + */ + fl.saddr = 0; + bind->local.ipv4.s_addr = 0; + dst_cache_reset(cache); + } + + rt = ip_route_output_flow(sock_net(sk), &fl, sk); + if (IS_ERR(rt) && PTR_ERR(rt) == -EINVAL) { + fl.saddr = 0; + bind->local.ipv4.s_addr = 0; + dst_cache_reset(cache); + + rt = ip_route_output_flow(sock_net(sk), &fl, sk); + } + + if (IS_ERR(rt)) { + ret = PTR_ERR(rt); + net_dbg_ratelimited("%s: no route to host %pISpc: %d\n", ovpn->dev->name, + &bind->sa.in4, ret); + goto err; + } + dst_cache_set_ip4(cache, &rt->dst, fl.saddr); + +transmit: + udp_tunnel_xmit_skb(rt, sk, skb, fl.saddr, fl.daddr, 0, + ip4_dst_hoplimit(&rt->dst), 0, fl.fl4_sport, + fl.fl4_dport, false, sk->sk_no_check_tx); + ret = 0; +err: + local_bh_enable(); + return ret; +} + +#if IS_ENABLED(CONFIG_IPV6) +static int ovpn_udp6_output(struct ovpn_struct *ovpn, struct ovpn_bind *bind, + struct dst_cache *cache, struct sock *sk, + struct sk_buff *skb) +{ + struct dst_entry *dst; + int ret; + + struct flowi6 fl = { + .saddr = bind->local.ipv6, + .daddr = bind->sa.in6.sin6_addr, + .fl6_sport = inet_sk(sk)->inet_sport, + .fl6_dport = bind->sa.in6.sin6_port, + .flowi6_proto = sk->sk_protocol, + .flowi6_mark = sk->sk_mark, + .flowi6_oif = bind->sa.in6.sin6_scope_id, + }; + + local_bh_disable(); + dst = dst_cache_get_ip6(cache, &fl.saddr); + if (dst) + goto transmit; + + if (unlikely(!ipv6_chk_addr(sock_net(sk), &fl.saddr, NULL, 0))) { + /* we may end up here when the cached address is not usable anymore. + * In this case we reset address/cache and perform a new look up + */ + fl.saddr = in6addr_any; + bind->local.ipv6 = in6addr_any; + dst_cache_reset(cache); + } + + dst = ipv6_stub->ipv6_dst_lookup_flow(sock_net(sk), sk, &fl, NULL); + if (IS_ERR(dst)) { + ret = PTR_ERR(dst); + net_dbg_ratelimited("%s: no route to host %pISpc: %d\n", ovpn->dev->name, + &bind->sa.in6, ret); + goto err; + } + dst_cache_set_ip6(cache, dst, &fl.saddr); + +transmit: + udp_tunnel6_xmit_skb(dst, sk, skb, skb->dev, &fl.saddr, &fl.daddr, 0, + ip6_dst_hoplimit(dst), 0, fl.fl6_sport, + fl.fl6_dport, udp_get_no_check6_tx(sk)); + ret = 0; +err: + local_bh_enable(); + return ret; +} +#endif + +/* Transmit skb utilizing kernel-provided UDP tunneling framework. + * + * rcu_read_lock should be held on entry. + * On return, the skb is consumed. + */ +static int ovpn_udp_output(struct ovpn_struct *ovpn, struct ovpn_bind *bind, + struct dst_cache *cache, struct sock *sk, + struct sk_buff *skb) +{ + int ret; + + ovpn_rcu_lockdep_assert_held(); + + /* set sk to null if skb is already orphaned */ + if (!skb->destructor) + skb->sk = NULL; + + switch (bind->sa.in4.sin_family) { + case AF_INET: + ret = ovpn_udp4_output(ovpn, bind, cache, sk, skb); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + ret = ovpn_udp6_output(ovpn, bind, cache, sk, skb); + break; +#endif + default: + ret = -EAFNOSUPPORT; + break; + } + + return ret; +} + +void ovpn_udp_send_skb(struct ovpn_struct *ovpn, struct ovpn_peer *peer, + struct sk_buff *skb) +{ + struct ovpn_bind *bind; + struct socket *sock; + int ret = -1; + + skb->dev = ovpn->dev; + /* no checksum performed at this layer */ + skb->ip_summed = CHECKSUM_NONE; + + /* get socket info */ + sock = peer->sock->sock; + if (unlikely(!sock)) { + net_dbg_ratelimited("%s: no sock for remote peer\n", __func__); + goto out; + } + + rcu_read_lock(); + /* get binding */ + bind = rcu_dereference(peer->bind); + if (unlikely(!bind)) { + net_dbg_ratelimited("%s: no bind for remote peer\n", __func__); + goto out_unlock; + } + + /* crypto layer -> transport (UDP) */ + ret = ovpn_udp_output(ovpn, bind, &peer->dst_cache, sock->sk, skb); + +out_unlock: + rcu_read_unlock(); +out: + if (ret < 0) + kfree_skb(skb); +} + +/* Set UDP encapsulation callbacks */ +int ovpn_udp_socket_attach(struct socket *sock, struct ovpn_struct *ovpn) +{ + struct udp_tunnel_sock_cfg cfg = { + .sk_user_data = ovpn, + .encap_type = UDP_ENCAP_OVPNINUDP, + .encap_rcv = ovpn_udp_encap_recv, + }; + struct ovpn_socket *old_data; + + /* sanity check */ + if (sock->sk->sk_protocol != IPPROTO_UDP) { + netdev_err(ovpn->dev, "%s: expected UDP socket\n", __func__); + return -EINVAL; + } + + /* make sure no pre-existing encapsulation handler exists */ + rcu_read_lock(); + old_data = rcu_dereference_sk_user_data(sock->sk); + rcu_read_unlock(); + if (old_data) { + if (old_data->ovpn == ovpn) { + netdev_dbg(ovpn->dev, + "%s: provided socket already owned by this interface\n", + __func__); + return -EALREADY; + } + + netdev_err(ovpn->dev, "%s: provided socket already taken by other user\n", + __func__); + return -EBUSY; + } + + setup_udp_tunnel_sock(sock_net(sock->sk), sock, &cfg); + + return 0; +} + +/* Detach socket from encapsulation handler and/or other callbacks */ +void ovpn_udp_socket_detach(struct socket *sock) +{ + struct udp_tunnel_sock_cfg cfg = { }; + + setup_udp_tunnel_sock(sock_net(sock->sk), sock, &cfg); +} diff --git a/drivers/net/ovpn-dco/udp.h b/drivers/net/ovpn-dco/udp.h new file mode 100644 index 000000000000..be94fb74669b --- /dev/null +++ b/drivers/net/ovpn-dco/udp.h @@ -0,0 +1,25 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* OpenVPN data channel accelerator + * + * Copyright (C) 2019-2022 OpenVPN, Inc. + * + * Author: Antonio Quartulli + */ + +#ifndef _NET_OVPN_DCO_UDP_H_ +#define _NET_OVPN_DCO_UDP_H_ + +#include "peer.h" +#include "ovpnstruct.h" + +#include +#include +#include +#include + +int ovpn_udp_socket_attach(struct socket *sock, struct ovpn_struct *ovpn); +void ovpn_udp_socket_detach(struct socket *sock); +void ovpn_udp_send_skb(struct ovpn_struct *ovpn, struct ovpn_peer *peer, + struct sk_buff *skb); + +#endif /* _NET_OVPN_DCO_UDP_H_ */ diff --git a/include/net/netlink.h b/include/net/netlink.h index 7a2a9d3144ba..335f44871529 100644 --- a/include/net/netlink.h +++ b/include/net/netlink.h @@ -441,6 +441,7 @@ struct nla_policy { .max = _len \ } #define NLA_POLICY_MIN_LEN(_len) NLA_POLICY_MIN(NLA_BINARY, _len) +#define NLA_POLICY_MAX_LEN(_len) NLA_POLICY_MAX(NLA_BINARY, _len) /** * struct nl_info - netlink source information diff --git a/include/uapi/linux/ovpn_dco.h b/include/uapi/linux/ovpn_dco.h new file mode 100644 index 000000000000..6afee8b3fedd --- /dev/null +++ b/include/uapi/linux/ovpn_dco.h @@ -0,0 +1,265 @@ +/* SPDX-License-Identifier: GPL-2.0 WITH Linux-syscall-note */ +/* + * OpenVPN data channel accelerator + * + * Copyright (C) 2019-2022 OpenVPN, Inc. + * + * Author: James Yonan + * Antonio Quartulli + */ + +#ifndef _UAPI_LINUX_OVPN_DCO_H_ +#define _UAPI_LINUX_OVPN_DCO_H_ + +#define OVPN_NL_NAME "ovpn-dco" + +#define OVPN_NL_MULTICAST_GROUP_PEERS "peers" + +/** + * enum ovpn_nl_commands - supported netlink commands + */ +enum ovpn_nl_commands { + /** + * @OVPN_CMD_UNSPEC: unspecified command to catch errors + */ + OVPN_CMD_UNSPEC = 0, + + /** + * @OVPN_CMD_NEW_PEER: Configure peer with its crypto keys + */ + OVPN_CMD_NEW_PEER, + + /** + * @OVPN_CMD_SET_PEER: Tweak parameters for an existing peer + */ + OVPN_CMD_SET_PEER, + + /** + * @OVPN_CMD_DEL_PEER: Remove peer from internal table + */ + OVPN_CMD_DEL_PEER, + + OVPN_CMD_NEW_KEY, + + OVPN_CMD_SWAP_KEYS, + + OVPN_CMD_DEL_KEY, + + /** + * @OVPN_CMD_REGISTER_PACKET: Register for specific packet types to be + * forwarded to userspace + */ + OVPN_CMD_REGISTER_PACKET, + + /** + * @OVPN_CMD_PACKET: Send a packet from userspace to kernelspace. Also + * used to send to userspace packets for which a process had registered + * with OVPN_CMD_REGISTER_PACKET + */ + OVPN_CMD_PACKET, + + /** + * @OVPN_CMD_GET_PEER: Retrieve the status of a peer or all peers + */ + OVPN_CMD_GET_PEER, +}; + +enum ovpn_cipher_alg { + /** + * @OVPN_CIPHER_ALG_NONE: No encryption - reserved for debugging only + */ + OVPN_CIPHER_ALG_NONE = 0, + /** + * @OVPN_CIPHER_ALG_AES_GCM: AES-GCM AEAD cipher with any allowed key size + */ + OVPN_CIPHER_ALG_AES_GCM, + /** + * @OVPN_CIPHER_ALG_CHACHA20_POLY1305: ChaCha20Poly1305 AEAD cipher + */ + OVPN_CIPHER_ALG_CHACHA20_POLY1305, +}; + +enum ovpn_del_peer_reason { + __OVPN_DEL_PEER_REASON_FIRST, + OVPN_DEL_PEER_REASON_TEARDOWN = __OVPN_DEL_PEER_REASON_FIRST, + OVPN_DEL_PEER_REASON_USERSPACE, + OVPN_DEL_PEER_REASON_EXPIRED, + OVPN_DEL_PEER_REASON_TRANSPORT_ERROR, + __OVPN_DEL_PEER_REASON_AFTER_LAST +}; + +enum ovpn_key_slot { + __OVPN_KEY_SLOT_FIRST, + OVPN_KEY_SLOT_PRIMARY = __OVPN_KEY_SLOT_FIRST, + OVPN_KEY_SLOT_SECONDARY, + __OVPN_KEY_SLOT_AFTER_LAST, +}; + +enum ovpn_netlink_attrs { + OVPN_ATTR_UNSPEC = 0, + OVPN_ATTR_IFINDEX, + OVPN_ATTR_NEW_PEER, + OVPN_ATTR_SET_PEER, + OVPN_ATTR_DEL_PEER, + OVPN_ATTR_NEW_KEY, + OVPN_ATTR_SWAP_KEYS, + OVPN_ATTR_DEL_KEY, + OVPN_ATTR_PACKET, + OVPN_ATTR_GET_PEER, + + __OVPN_ATTR_AFTER_LAST, + OVPN_ATTR_MAX = __OVPN_ATTR_AFTER_LAST - 1, +}; + +enum ovpn_netlink_key_dir_attrs { + OVPN_KEY_DIR_ATTR_UNSPEC = 0, + OVPN_KEY_DIR_ATTR_CIPHER_KEY, + OVPN_KEY_DIR_ATTR_NONCE_TAIL, + + __OVPN_KEY_DIR_ATTR_AFTER_LAST, + OVPN_KEY_DIR_ATTR_MAX = __OVPN_KEY_DIR_ATTR_AFTER_LAST - 1, +}; + +enum ovpn_netlink_new_key_attrs { + OVPN_NEW_KEY_ATTR_UNSPEC = 0, + OVPN_NEW_KEY_ATTR_PEER_ID, + OVPN_NEW_KEY_ATTR_KEY_SLOT, + OVPN_NEW_KEY_ATTR_KEY_ID, + OVPN_NEW_KEY_ATTR_CIPHER_ALG, + OVPN_NEW_KEY_ATTR_ENCRYPT_KEY, + OVPN_NEW_KEY_ATTR_DECRYPT_KEY, + + __OVPN_NEW_KEY_ATTR_AFTER_LAST, + OVPN_NEW_KEY_ATTR_MAX = __OVPN_NEW_KEY_ATTR_AFTER_LAST - 1, +}; + +enum ovpn_netlink_del_key_attrs { + OVPN_DEL_KEY_ATTR_UNSPEC = 0, + OVPN_DEL_KEY_ATTR_PEER_ID, + OVPN_DEL_KEY_ATTR_KEY_SLOT, + + __OVPN_DEL_KEY_ATTR_AFTER_LAST, + OVPN_DEL_KEY_ATTR_MAX = __OVPN_DEL_KEY_ATTR_AFTER_LAST - 1, +}; + +enum ovpn_netlink_swap_keys_attrs { + OVPN_SWAP_KEYS_ATTR_UNSPEC = 0, + OVPN_SWAP_KEYS_ATTR_PEER_ID, + + __OVPN_SWAP_KEYS_ATTR_AFTER_LAST, + OVPN_SWAP_KEYS_ATTR_MAX = __OVPN_SWAP_KEYS_ATTR_AFTER_LAST - 1, + +}; + +enum ovpn_netlink_new_peer_attrs { + OVPN_NEW_PEER_ATTR_UNSPEC = 0, + OVPN_NEW_PEER_ATTR_PEER_ID, + OVPN_NEW_PEER_ATTR_SOCKADDR_REMOTE, + OVPN_NEW_PEER_ATTR_SOCKET, + OVPN_NEW_PEER_ATTR_IPV4, + OVPN_NEW_PEER_ATTR_IPV6, + OVPN_NEW_PEER_ATTR_LOCAL_IP, + + __OVPN_NEW_PEER_ATTR_AFTER_LAST, + OVPN_NEW_PEER_ATTR_MAX = __OVPN_NEW_PEER_ATTR_AFTER_LAST - 1, +}; + +enum ovpn_netlink_set_peer_attrs { + OVPN_SET_PEER_ATTR_UNSPEC = 0, + OVPN_SET_PEER_ATTR_PEER_ID, + OVPN_SET_PEER_ATTR_KEEPALIVE_INTERVAL, + OVPN_SET_PEER_ATTR_KEEPALIVE_TIMEOUT, + + __OVPN_SET_PEER_ATTR_AFTER_LAST, + OVPN_SET_PEER_ATTR_MAX = __OVPN_SET_PEER_ATTR_AFTER_LAST - 1, +}; + +enum ovpn_netlink_del_peer_attrs { + OVPN_DEL_PEER_ATTR_UNSPEC = 0, + OVPN_DEL_PEER_ATTR_REASON, + OVPN_DEL_PEER_ATTR_PEER_ID, + + __OVPN_DEL_PEER_ATTR_AFTER_LAST, + OVPN_DEL_PEER_ATTR_MAX = __OVPN_DEL_PEER_ATTR_AFTER_LAST - 1, +}; + +enum ovpn_netlink_get_peer_attrs { + OVPN_GET_PEER_ATTR_UNSPEC = 0, + OVPN_GET_PEER_ATTR_PEER_ID, + + __OVPN_GET_PEER_ATTR_AFTER_LAST, + OVPN_GET_PEER_ATTR_MAX = __OVPN_GET_PEER_ATTR_AFTER_LAST - 1, +}; + +enum ovpn_netlink_get_peer_response_attrs { + OVPN_GET_PEER_RESP_ATTR_UNSPEC = 0, + OVPN_GET_PEER_RESP_ATTR_PEER_ID, + OVPN_GET_PEER_RESP_ATTR_SOCKADDR_REMOTE, + OVPN_GET_PEER_RESP_ATTR_IPV4, + OVPN_GET_PEER_RESP_ATTR_IPV6, + OVPN_GET_PEER_RESP_ATTR_LOCAL_IP, + OVPN_GET_PEER_RESP_ATTR_LOCAL_PORT, + OVPN_GET_PEER_RESP_ATTR_KEEPALIVE_INTERVAL, + OVPN_GET_PEER_RESP_ATTR_KEEPALIVE_TIMEOUT, + OVPN_GET_PEER_RESP_ATTR_RX_BYTES, + OVPN_GET_PEER_RESP_ATTR_TX_BYTES, + OVPN_GET_PEER_RESP_ATTR_RX_PACKETS, + OVPN_GET_PEER_RESP_ATTR_TX_PACKETS, + + __OVPN_GET_PEER_RESP_ATTR_AFTER_LAST, + OVPN_GET_PEER_RESP_ATTR_MAX = __OVPN_GET_PEER_RESP_ATTR_AFTER_LAST - 1, +}; + +enum ovpn_netlink_peer_stats_attrs { + OVPN_PEER_STATS_ATTR_UNSPEC = 0, + OVPN_PEER_STATS_BYTES, + OVPN_PEER_STATS_PACKETS, + + __OVPN_PEER_STATS_ATTR_AFTER_LAST, + OVPN_PEER_STATS_ATTR_MAX = __OVPN_PEER_STATS_ATTR_AFTER_LAST - 1, +}; + +enum ovpn_netlink_peer_attrs { + OVPN_PEER_ATTR_UNSPEC = 0, + OVPN_PEER_ATTR_PEER_ID, + OVPN_PEER_ATTR_SOCKADDR_REMOTE, + OVPN_PEER_ATTR_IPV4, + OVPN_PEER_ATTR_IPV6, + OVPN_PEER_ATTR_LOCAL_IP, + OVPN_PEER_ATTR_KEEPALIVE_INTERVAL, + OVPN_PEER_ATTR_KEEPALIVE_TIMEOUT, + OVPN_PEER_ATTR_ENCRYPT_KEY, + OVPN_PEER_ATTR_DECRYPT_KEY, + OVPN_PEER_ATTR_RX_STATS, + OVPN_PEER_ATTR_TX_STATS, + + __OVPN_PEER_ATTR_AFTER_LAST, + OVPN_PEER_ATTR_MAX = __OVPN_PEER_ATTR_AFTER_LAST - 1, +}; + +enum ovpn_netlink_packet_attrs { + OVPN_PACKET_ATTR_UNSPEC = 0, + OVPN_PACKET_ATTR_PACKET, + OVPN_PACKET_ATTR_PEER_ID, + + __OVPN_PACKET_ATTR_AFTER_LAST, + OVPN_PACKET_ATTR_MAX = __OVPN_PACKET_ATTR_AFTER_LAST - 1, +}; + +enum ovpn_ifla_attrs { + IFLA_OVPN_UNSPEC = 0, + IFLA_OVPN_MODE, + + __IFLA_OVPN_AFTER_LAST, + IFLA_OVPN_MAX = __IFLA_OVPN_AFTER_LAST - 1, +}; + +enum ovpn_mode { + __OVPN_MODE_FIRST = 0, + OVPN_MODE_P2P = __OVPN_MODE_FIRST, + OVPN_MODE_MP, + + __OVPN_MODE_AFTER_LAST, +}; + +#endif /* _UAPI_LINUX_OVPN_DCO_H_ */ diff --git a/include/uapi/linux/udp.h b/include/uapi/linux/udp.h index 4828794efcf8..8008c762e6b8 100644 --- a/include/uapi/linux/udp.h +++ b/include/uapi/linux/udp.h @@ -43,5 +43,6 @@ struct udphdr { #define UDP_ENCAP_GTP1U 5 /* 3GPP TS 29.060 */ #define UDP_ENCAP_RXRPC 6 #define TCP_ENCAP_ESPINTCP 7 /* Yikes, this is really xfrm encap types. */ +#define UDP_ENCAP_OVPNINUDP 8 /* OpenVPN over UDP connection */ #endif /* _UAPI_LINUX_UDP_H */ -- 2.39.0.rc2 From 17aabe9600bb9fe36db28463c3c8b1ee42f06b52 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Thu, 8 Dec 2022 13:22:12 +0100 Subject: [PATCH 05/20] mm-Introduce the maple tree with MG-LRU Signed-off-by: Peter Jung --- Documentation/admin-guide/mm/index.rst | 1 + Documentation/admin-guide/mm/multigen_lru.rst | 162 + Documentation/core-api/index.rst | 1 + Documentation/core-api/maple_tree.rst | 217 + Documentation/mm/index.rst | 1 + Documentation/mm/multigen_lru.rst | 159 + MAINTAINERS | 12 + arch/Kconfig | 8 + arch/arm64/include/asm/pgtable.h | 15 +- arch/arm64/kernel/elfcore.c | 16 +- arch/arm64/kernel/vdso.c | 3 +- arch/parisc/kernel/cache.c | 9 +- arch/powerpc/kernel/vdso.c | 6 +- arch/powerpc/mm/book3s32/tlb.c | 11 +- arch/powerpc/mm/book3s64/subpage_prot.c | 13 +- arch/riscv/kernel/vdso.c | 3 +- arch/s390/kernel/vdso.c | 3 +- arch/s390/mm/gmap.c | 6 +- arch/um/kernel/tlb.c | 14 +- arch/x86/Kconfig | 1 + arch/x86/entry/vdso/vma.c | 9 +- arch/x86/include/asm/pgtable.h | 9 +- arch/x86/kernel/tboot.c | 2 +- arch/x86/mm/pgtable.c | 5 +- arch/xtensa/kernel/syscall.c | 18 +- drivers/firmware/efi/efi.c | 2 +- drivers/gpu/drm/i915/gem/i915_gem_userptr.c | 14 +- drivers/misc/cxl/fault.c | 45 +- drivers/tee/optee/call.c | 18 +- drivers/xen/privcmd.c | 2 +- fs/coredump.c | 34 +- fs/exec.c | 14 +- fs/fuse/dev.c | 3 +- fs/proc/base.c | 5 +- fs/proc/internal.h | 2 +- fs/proc/task_mmu.c | 74 +- fs/proc/task_nommu.c | 45 +- fs/userfaultfd.c | 65 +- include/linux/cgroup.h | 15 +- include/linux/maple_tree.h | 692 + include/linux/memcontrol.h | 36 + include/linux/mm.h | 83 +- include/linux/mm_inline.h | 231 +- include/linux/mm_types.h | 119 +- include/linux/mm_types_task.h | 12 - include/linux/mmzone.h | 216 + include/linux/nodemask.h | 1 + include/linux/page-flags-layout.h | 16 +- include/linux/page-flags.h | 4 +- include/linux/pgtable.h | 17 +- include/linux/sched.h | 5 +- include/linux/swap.h | 4 + include/linux/userfaultfd_k.h | 7 +- include/linux/vm_event_item.h | 4 - include/linux/vmacache.h | 28 - include/linux/vmstat.h | 6 - include/trace/events/maple_tree.h | 123 + include/trace/events/mmap.h | 73 + init/main.c | 2 + ipc/shm.c | 21 +- kernel/acct.c | 11 +- kernel/bounds.c | 7 + kernel/bpf/task_iter.c | 10 +- kernel/cgroup/cgroup-internal.h | 1 - kernel/debug/debug_core.c | 12 - kernel/events/core.c | 3 +- kernel/events/uprobes.c | 9 +- kernel/exit.c | 1 + kernel/fork.c | 71 +- kernel/sched/core.c | 1 + kernel/sched/fair.c | 10 +- lib/Kconfig.debug | 21 +- lib/Makefile | 3 +- lib/maple_tree.c | 7158 +++ lib/test_maple_tree.c | 2767 ++ mm/Kconfig | 26 + mm/Makefile | 2 +- mm/damon/vaddr-test.h | 36 +- mm/damon/vaddr.c | 53 +- mm/debug.c | 14 +- mm/gup.c | 7 +- mm/huge_memory.c | 7 +- mm/init-mm.c | 4 +- mm/internal.h | 9 +- mm/khugepaged.c | 11 +- mm/ksm.c | 18 +- mm/madvise.c | 2 +- mm/memcontrol.c | 34 +- mm/memory.c | 72 +- mm/mempolicy.c | 55 +- mm/mlock.c | 35 +- mm/mm_init.c | 6 +- mm/mmap.c | 2162 +- mm/mmzone.c | 2 + mm/mprotect.c | 8 +- mm/mremap.c | 22 +- mm/msync.c | 2 +- mm/nommu.c | 260 +- mm/oom_kill.c | 3 +- mm/pagewalk.c | 2 +- mm/rmap.c | 6 + mm/swap.c | 54 +- mm/swapfile.c | 4 +- mm/util.c | 32 - mm/vmacache.c | 117 - mm/vmscan.c | 3231 +- mm/vmstat.c | 4 - mm/workingset.c | 110 +- tools/include/linux/slab.h | 4 + tools/testing/radix-tree/.gitignore | 3 + tools/testing/radix-tree/Makefile | 28 +- tools/testing/radix-tree/generated/autoconf.h | 1 + tools/testing/radix-tree/linux.c | 164 +- tools/testing/radix-tree/linux/kernel.h | 1 + tools/testing/radix-tree/linux/lockdep.h | 2 + tools/testing/radix-tree/linux/maple_tree.h | 7 + tools/testing/radix-tree/maple.c | 35830 ++++++++++++++++ .../radix-tree/trace/events/maple_tree.h | 5 + 118 files changed, 53118 insertions(+), 2164 deletions(-) create mode 100644 Documentation/admin-guide/mm/multigen_lru.rst create mode 100644 Documentation/core-api/maple_tree.rst create mode 100644 Documentation/mm/multigen_lru.rst create mode 100644 include/linux/maple_tree.h delete mode 100644 include/linux/vmacache.h create mode 100644 include/trace/events/maple_tree.h create mode 100644 lib/maple_tree.c create mode 100644 lib/test_maple_tree.c delete mode 100644 mm/vmacache.c create mode 100644 tools/testing/radix-tree/linux/maple_tree.h create mode 100644 tools/testing/radix-tree/maple.c create mode 100644 tools/testing/radix-tree/trace/events/maple_tree.h diff --git a/Documentation/admin-guide/mm/index.rst b/Documentation/admin-guide/mm/index.rst index 1bd11118dfb1..d1064e0ba34a 100644 --- a/Documentation/admin-guide/mm/index.rst +++ b/Documentation/admin-guide/mm/index.rst @@ -32,6 +32,7 @@ the Linux memory management. idle_page_tracking ksm memory-hotplug + multigen_lru nommu-mmap numa_memory_policy numaperf diff --git a/Documentation/admin-guide/mm/multigen_lru.rst b/Documentation/admin-guide/mm/multigen_lru.rst new file mode 100644 index 000000000000..33e068830497 --- /dev/null +++ b/Documentation/admin-guide/mm/multigen_lru.rst @@ -0,0 +1,162 @@ +.. SPDX-License-Identifier: GPL-2.0 + +============= +Multi-Gen LRU +============= +The multi-gen LRU is an alternative LRU implementation that optimizes +page reclaim and improves performance under memory pressure. Page +reclaim decides the kernel's caching policy and ability to overcommit +memory. It directly impacts the kswapd CPU usage and RAM efficiency. + +Quick start +=========== +Build the kernel with the following configurations. + +* ``CONFIG_LRU_GEN=y`` +* ``CONFIG_LRU_GEN_ENABLED=y`` + +All set! + +Runtime options +=============== +``/sys/kernel/mm/lru_gen/`` contains stable ABIs described in the +following subsections. + +Kill switch +----------- +``enabled`` accepts different values to enable or disable the +following components. Its default value depends on +``CONFIG_LRU_GEN_ENABLED``. All the components should be enabled +unless some of them have unforeseen side effects. Writing to +``enabled`` has no effect when a component is not supported by the +hardware, and valid values will be accepted even when the main switch +is off. + +====== =============================================================== +Values Components +====== =============================================================== +0x0001 The main switch for the multi-gen LRU. +0x0002 Clearing the accessed bit in leaf page table entries in large + batches, when MMU sets it (e.g., on x86). This behavior can + theoretically worsen lock contention (mmap_lock). If it is + disabled, the multi-gen LRU will suffer a minor performance + degradation for workloads that contiguously map hot pages, + whose accessed bits can be otherwise cleared by fewer larger + batches. +0x0004 Clearing the accessed bit in non-leaf page table entries as + well, when MMU sets it (e.g., on x86). This behavior was not + verified on x86 varieties other than Intel and AMD. If it is + disabled, the multi-gen LRU will suffer a negligible + performance degradation. +[yYnN] Apply to all the components above. +====== =============================================================== + +E.g., +:: + + echo y >/sys/kernel/mm/lru_gen/enabled + cat /sys/kernel/mm/lru_gen/enabled + 0x0007 + echo 5 >/sys/kernel/mm/lru_gen/enabled + cat /sys/kernel/mm/lru_gen/enabled + 0x0005 + +Thrashing prevention +-------------------- +Personal computers are more sensitive to thrashing because it can +cause janks (lags when rendering UI) and negatively impact user +experience. The multi-gen LRU offers thrashing prevention to the +majority of laptop and desktop users who do not have ``oomd``. + +Users can write ``N`` to ``min_ttl_ms`` to prevent the working set of +``N`` milliseconds from getting evicted. The OOM killer is triggered +if this working set cannot be kept in memory. In other words, this +option works as an adjustable pressure relief valve, and when open, it +terminates applications that are hopefully not being used. + +Based on the average human detectable lag (~100ms), ``N=1000`` usually +eliminates intolerable janks due to thrashing. Larger values like +``N=3000`` make janks less noticeable at the risk of premature OOM +kills. + +The default value ``0`` means disabled. + +Experimental features +===================== +``/sys/kernel/debug/lru_gen`` accepts commands described in the +following subsections. Multiple command lines are supported, so does +concatenation with delimiters ``,`` and ``;``. + +``/sys/kernel/debug/lru_gen_full`` provides additional stats for +debugging. ``CONFIG_LRU_GEN_STATS=y`` keeps historical stats from +evicted generations in this file. + +Working set estimation +---------------------- +Working set estimation measures how much memory an application needs +in a given time interval, and it is usually done with little impact on +the performance of the application. E.g., data centers want to +optimize job scheduling (bin packing) to improve memory utilizations. +When a new job comes in, the job scheduler needs to find out whether +each server it manages can allocate a certain amount of memory for +this new job before it can pick a candidate. To do so, the job +scheduler needs to estimate the working sets of the existing jobs. + +When it is read, ``lru_gen`` returns a histogram of numbers of pages +accessed over different time intervals for each memcg and node. +``MAX_NR_GENS`` decides the number of bins for each histogram. The +histograms are noncumulative. +:: + + memcg memcg_id memcg_path + node node_id + min_gen_nr age_in_ms nr_anon_pages nr_file_pages + ... + max_gen_nr age_in_ms nr_anon_pages nr_file_pages + +Each bin contains an estimated number of pages that have been accessed +within ``age_in_ms``. E.g., ``min_gen_nr`` contains the coldest pages +and ``max_gen_nr`` contains the hottest pages, since ``age_in_ms`` of +the former is the largest and that of the latter is the smallest. + +Users can write the following command to ``lru_gen`` to create a new +generation ``max_gen_nr+1``: + + ``+ memcg_id node_id max_gen_nr [can_swap [force_scan]]`` + +``can_swap`` defaults to the swap setting and, if it is set to ``1``, +it forces the scan of anon pages when swap is off, and vice versa. +``force_scan`` defaults to ``1`` and, if it is set to ``0``, it +employs heuristics to reduce the overhead, which is likely to reduce +the coverage as well. + +A typical use case is that a job scheduler runs this command at a +certain time interval to create new generations, and it ranks the +servers it manages based on the sizes of their cold pages defined by +this time interval. + +Proactive reclaim +----------------- +Proactive reclaim induces page reclaim when there is no memory +pressure. It usually targets cold pages only. E.g., when a new job +comes in, the job scheduler wants to proactively reclaim cold pages on +the server it selected, to improve the chance of successfully landing +this new job. + +Users can write the following command to ``lru_gen`` to evict +generations less than or equal to ``min_gen_nr``. + + ``- memcg_id node_id min_gen_nr [swappiness [nr_to_reclaim]]`` + +``min_gen_nr`` should be less than ``max_gen_nr-1``, since +``max_gen_nr`` and ``max_gen_nr-1`` are not fully aged (equivalent to +the active list) and therefore cannot be evicted. ``swappiness`` +overrides the default value in ``/proc/sys/vm/swappiness``. +``nr_to_reclaim`` limits the number of pages to evict. + +A typical use case is that a job scheduler runs this command before it +tries to land a new job on a server. If it fails to materialize enough +cold pages because of the overestimation, it retries on the next +server according to the ranking result obtained from the working set +estimation step. This less forceful approach limits the impacts on the +existing jobs. diff --git a/Documentation/core-api/index.rst b/Documentation/core-api/index.rst index dc95df462eea..1da6a4fac664 100644 --- a/Documentation/core-api/index.rst +++ b/Documentation/core-api/index.rst @@ -36,6 +36,7 @@ Library functionality that is used throughout the kernel. kref assoc_array xarray + maple_tree idr circular-buffers rbtree diff --git a/Documentation/core-api/maple_tree.rst b/Documentation/core-api/maple_tree.rst new file mode 100644 index 000000000000..45defcf15da7 --- /dev/null +++ b/Documentation/core-api/maple_tree.rst @@ -0,0 +1,217 @@ +.. SPDX-License-Identifier: GPL-2.0+ + + +========== +Maple Tree +========== + +:Author: Liam R. Howlett + +Overview +======== + +The Maple Tree is a B-Tree data type which is optimized for storing +non-overlapping ranges, including ranges of size 1. The tree was designed to +be simple to use and does not require a user written search method. It +supports iterating over a range of entries and going to the previous or next +entry in a cache-efficient manner. The tree can also be put into an RCU-safe +mode of operation which allows reading and writing concurrently. Writers must +synchronize on a lock, which can be the default spinlock, or the user can set +the lock to an external lock of a different type. + +The Maple Tree maintains a small memory footprint and was designed to use +modern processor cache efficiently. The majority of the users will be able to +use the normal API. An :ref:`maple-tree-advanced-api` exists for more complex +scenarios. The most important usage of the Maple Tree is the tracking of the +virtual memory areas. + +The Maple Tree can store values between ``0`` and ``ULONG_MAX``. The Maple +Tree reserves values with the bottom two bits set to '10' which are below 4096 +(ie 2, 6, 10 .. 4094) for internal use. If the entries may use reserved +entries then the users can convert the entries using xa_mk_value() and convert +them back by calling xa_to_value(). If the user needs to use a reserved +value, then the user can convert the value when using the +:ref:`maple-tree-advanced-api`, but are blocked by the normal API. + +The Maple Tree can also be configured to support searching for a gap of a given +size (or larger). + +Pre-allocating of nodes is also supported using the +:ref:`maple-tree-advanced-api`. This is useful for users who must guarantee a +successful store operation within a given +code segment when allocating cannot be done. Allocations of nodes are +relatively small at around 256 bytes. + +.. _maple-tree-normal-api: + +Normal API +========== + +Start by initialising a maple tree, either with DEFINE_MTREE() for statically +allocated maple trees or mt_init() for dynamically allocated ones. A +freshly-initialised maple tree contains a ``NULL`` pointer for the range ``0`` +- ``ULONG_MAX``. There are currently two types of maple trees supported: the +allocation tree and the regular tree. The regular tree has a higher branching +factor for internal nodes. The allocation tree has a lower branching factor +but allows the user to search for a gap of a given size or larger from either +``0`` upwards or ``ULONG_MAX`` down. An allocation tree can be used by +passing in the ``MT_FLAGS_ALLOC_RANGE`` flag when initialising the tree. + +You can then set entries using mtree_store() or mtree_store_range(). +mtree_store() will overwrite any entry with the new entry and return 0 on +success or an error code otherwise. mtree_store_range() works in the same way +but takes a range. mtree_load() is used to retrieve the entry stored at a +given index. You can use mtree_erase() to erase an entire range by only +knowing one value within that range, or mtree_store() call with an entry of +NULL may be used to partially erase a range or many ranges at once. + +If you want to only store a new entry to a range (or index) if that range is +currently ``NULL``, you can use mtree_insert_range() or mtree_insert() which +return -EEXIST if the range is not empty. + +You can search for an entry from an index upwards by using mt_find(). + +You can walk each entry within a range by calling mt_for_each(). You must +provide a temporary variable to store a cursor. If you want to walk each +element of the tree then ``0`` and ``ULONG_MAX`` may be used as the range. If +the caller is going to hold the lock for the duration of the walk then it is +worth looking at the mas_for_each() API in the :ref:`maple-tree-advanced-api` +section. + +Sometimes it is necessary to ensure the next call to store to a maple tree does +not allocate memory, please see :ref:`maple-tree-advanced-api` for this use case. + +Finally, you can remove all entries from a maple tree by calling +mtree_destroy(). If the maple tree entries are pointers, you may wish to free +the entries first. + +Allocating Nodes +---------------- + +The allocations are handled by the internal tree code. See +:ref:`maple-tree-advanced-alloc` for other options. + +Locking +------- + +You do not have to worry about locking. See :ref:`maple-tree-advanced-locks` +for other options. + +The Maple Tree uses RCU and an internal spinlock to synchronise access: + +Takes RCU read lock: + * mtree_load() + * mt_find() + * mt_for_each() + * mt_next() + * mt_prev() + +Takes ma_lock internally: + * mtree_store() + * mtree_store_range() + * mtree_insert() + * mtree_insert_range() + * mtree_erase() + * mtree_destroy() + * mt_set_in_rcu() + * mt_clear_in_rcu() + +If you want to take advantage of the internal lock to protect the data +structures that you are storing in the Maple Tree, you can call mtree_lock() +before calling mtree_load(), then take a reference count on the object you +have found before calling mtree_unlock(). This will prevent stores from +removing the object from the tree between looking up the object and +incrementing the refcount. You can also use RCU to avoid dereferencing +freed memory, but an explanation of that is beyond the scope of this +document. + +.. _maple-tree-advanced-api: + +Advanced API +============ + +The advanced API offers more flexibility and better performance at the +cost of an interface which can be harder to use and has fewer safeguards. +You must take care of your own locking while using the advanced API. +You can use the ma_lock, RCU or an external lock for protection. +You can mix advanced and normal operations on the same array, as long +as the locking is compatible. The :ref:`maple-tree-normal-api` is implemented +in terms of the advanced API. + +The advanced API is based around the ma_state, this is where the 'mas' +prefix originates. The ma_state struct keeps track of tree operations to make +life easier for both internal and external tree users. + +Initialising the maple tree is the same as in the :ref:`maple-tree-normal-api`. +Please see above. + +The maple state keeps track of the range start and end in mas->index and +mas->last, respectively. + +mas_walk() will walk the tree to the location of mas->index and set the +mas->index and mas->last according to the range for the entry. + +You can set entries using mas_store(). mas_store() will overwrite any entry +with the new entry and return the first existing entry that is overwritten. +The range is passed in as members of the maple state: index and last. + +You can use mas_erase() to erase an entire range by setting index and +last of the maple state to the desired range to erase. This will erase +the first range that is found in that range, set the maple state index +and last as the range that was erased and return the entry that existed +at that location. + +You can walk each entry within a range by using mas_for_each(). If you want +to walk each element of the tree then ``0`` and ``ULONG_MAX`` may be used as +the range. If the lock needs to be periodically dropped, see the locking +section mas_pause(). + +Using a maple state allows mas_next() and mas_prev() to function as if the +tree was a linked list. With such a high branching factor the amortized +performance penalty is outweighed by cache optimization. mas_next() will +return the next entry which occurs after the entry at index. mas_prev() +will return the previous entry which occurs before the entry at index. + +mas_find() will find the first entry which exists at or above index on +the first call, and the next entry from every subsequent calls. + +mas_find_rev() will find the fist entry which exists at or below the last on +the first call, and the previous entry from every subsequent calls. + +If the user needs to yield the lock during an operation, then the maple state +must be paused using mas_pause(). + +There are a few extra interfaces provided when using an allocation tree. +If you wish to search for a gap within a range, then mas_empty_area() +or mas_empty_area_rev() can be used. mas_empty_area() searches for a gap +starting at the lowest index given up to the maximum of the range. +mas_empty_area_rev() searches for a gap starting at the highest index given +and continues downward to the lower bound of the range. + +.. _maple-tree-advanced-alloc: + +Advanced Allocating Nodes +------------------------- + +Allocations are usually handled internally to the tree, however if allocations +need to occur before a write occurs then calling mas_expected_entries() will +allocate the worst-case number of needed nodes to insert the provided number of +ranges. This also causes the tree to enter mass insertion mode. Once +insertions are complete calling mas_destroy() on the maple state will free the +unused allocations. + +.. _maple-tree-advanced-locks: + +Advanced Locking +---------------- + +The maple tree uses a spinlock by default, but external locks can be used for +tree updates as well. To use an external lock, the tree must be initialized +with the ``MT_FLAGS_LOCK_EXTERN flag``, this is usually done with the +MTREE_INIT_EXT() #define, which takes an external lock as an argument. + +Functions and structures +======================== + +.. kernel-doc:: include/linux/maple_tree.h +.. kernel-doc:: lib/maple_tree.c diff --git a/Documentation/mm/index.rst b/Documentation/mm/index.rst index 25561b95780f..3e538e098eb6 100644 --- a/Documentation/mm/index.rst +++ b/Documentation/mm/index.rst @@ -51,6 +51,7 @@ above structured documentation, or deleted if it has served its purpose. ksm memory-model mmu_notifier + multigen_lru numa overcommit-accounting page_migration diff --git a/Documentation/mm/multigen_lru.rst b/Documentation/mm/multigen_lru.rst new file mode 100644 index 000000000000..d7062c6a8946 --- /dev/null +++ b/Documentation/mm/multigen_lru.rst @@ -0,0 +1,159 @@ +.. SPDX-License-Identifier: GPL-2.0 + +============= +Multi-Gen LRU +============= +The multi-gen LRU is an alternative LRU implementation that optimizes +page reclaim and improves performance under memory pressure. Page +reclaim decides the kernel's caching policy and ability to overcommit +memory. It directly impacts the kswapd CPU usage and RAM efficiency. + +Design overview +=============== +Objectives +---------- +The design objectives are: + +* Good representation of access recency +* Try to profit from spatial locality +* Fast paths to make obvious choices +* Simple self-correcting heuristics + +The representation of access recency is at the core of all LRU +implementations. In the multi-gen LRU, each generation represents a +group of pages with similar access recency. Generations establish a +(time-based) common frame of reference and therefore help make better +choices, e.g., between different memcgs on a computer or different +computers in a data center (for job scheduling). + +Exploiting spatial locality improves efficiency when gathering the +accessed bit. A rmap walk targets a single page and does not try to +profit from discovering a young PTE. A page table walk can sweep all +the young PTEs in an address space, but the address space can be too +sparse to make a profit. The key is to optimize both methods and use +them in combination. + +Fast paths reduce code complexity and runtime overhead. Unmapped pages +do not require TLB flushes; clean pages do not require writeback. +These facts are only helpful when other conditions, e.g., access +recency, are similar. With generations as a common frame of reference, +additional factors stand out. But obvious choices might not be good +choices; thus self-correction is necessary. + +The benefits of simple self-correcting heuristics are self-evident. +Again, with generations as a common frame of reference, this becomes +attainable. Specifically, pages in the same generation can be +categorized based on additional factors, and a feedback loop can +statistically compare the refault percentages across those categories +and infer which of them are better choices. + +Assumptions +----------- +The protection of hot pages and the selection of cold pages are based +on page access channels and patterns. There are two access channels: + +* Accesses through page tables +* Accesses through file descriptors + +The protection of the former channel is by design stronger because: + +1. The uncertainty in determining the access patterns of the former + channel is higher due to the approximation of the accessed bit. +2. The cost of evicting the former channel is higher due to the TLB + flushes required and the likelihood of encountering the dirty bit. +3. The penalty of underprotecting the former channel is higher because + applications usually do not prepare themselves for major page + faults like they do for blocked I/O. E.g., GUI applications + commonly use dedicated I/O threads to avoid blocking rendering + threads. + +There are also two access patterns: + +* Accesses exhibiting temporal locality +* Accesses not exhibiting temporal locality + +For the reasons listed above, the former channel is assumed to follow +the former pattern unless ``VM_SEQ_READ`` or ``VM_RAND_READ`` is +present, and the latter channel is assumed to follow the latter +pattern unless outlying refaults have been observed. + +Workflow overview +================= +Evictable pages are divided into multiple generations for each +``lruvec``. The youngest generation number is stored in +``lrugen->max_seq`` for both anon and file types as they are aged on +an equal footing. The oldest generation numbers are stored in +``lrugen->min_seq[]`` separately for anon and file types as clean file +pages can be evicted regardless of swap constraints. These three +variables are monotonically increasing. + +Generation numbers are truncated into ``order_base_2(MAX_NR_GENS+1)`` +bits in order to fit into the gen counter in ``folio->flags``. Each +truncated generation number is an index to ``lrugen->lists[]``. The +sliding window technique is used to track at least ``MIN_NR_GENS`` and +at most ``MAX_NR_GENS`` generations. The gen counter stores a value +within ``[1, MAX_NR_GENS]`` while a page is on one of +``lrugen->lists[]``; otherwise it stores zero. + +Each generation is divided into multiple tiers. A page accessed ``N`` +times through file descriptors is in tier ``order_base_2(N)``. Unlike +generations, tiers do not have dedicated ``lrugen->lists[]``. In +contrast to moving across generations, which requires the LRU lock, +moving across tiers only involves atomic operations on +``folio->flags`` and therefore has a negligible cost. A feedback loop +modeled after the PID controller monitors refaults over all the tiers +from anon and file types and decides which tiers from which types to +evict or protect. + +There are two conceptually independent procedures: the aging and the +eviction. They form a closed-loop system, i.e., the page reclaim. + +Aging +----- +The aging produces young generations. Given an ``lruvec``, it +increments ``max_seq`` when ``max_seq-min_seq+1`` approaches +``MIN_NR_GENS``. The aging promotes hot pages to the youngest +generation when it finds them accessed through page tables; the +demotion of cold pages happens consequently when it increments +``max_seq``. The aging uses page table walks and rmap walks to find +young PTEs. For the former, it iterates ``lruvec_memcg()->mm_list`` +and calls ``walk_page_range()`` with each ``mm_struct`` on this list +to scan PTEs, and after each iteration, it increments ``max_seq``. For +the latter, when the eviction walks the rmap and finds a young PTE, +the aging scans the adjacent PTEs. For both, on finding a young PTE, +the aging clears the accessed bit and updates the gen counter of the +page mapped by this PTE to ``(max_seq%MAX_NR_GENS)+1``. + +Eviction +-------- +The eviction consumes old generations. Given an ``lruvec``, it +increments ``min_seq`` when ``lrugen->lists[]`` indexed by +``min_seq%MAX_NR_GENS`` becomes empty. To select a type and a tier to +evict from, it first compares ``min_seq[]`` to select the older type. +If both types are equally old, it selects the one whose first tier has +a lower refault percentage. The first tier contains single-use +unmapped clean pages, which are the best bet. The eviction sorts a +page according to its gen counter if the aging has found this page +accessed through page tables and updated its gen counter. It also +moves a page to the next generation, i.e., ``min_seq+1``, if this page +was accessed multiple times through file descriptors and the feedback +loop has detected outlying refaults from the tier this page is in. To +this end, the feedback loop uses the first tier as the baseline, for +the reason stated earlier. + +Summary +------- +The multi-gen LRU can be disassembled into the following parts: + +* Generations +* Rmap walks +* Page table walks +* Bloom filters +* PID controller + +The aging and the eviction form a producer-consumer model; +specifically, the latter drives the former by the sliding window over +generations. Within the aging, rmap walks drive page table walks by +inserting hot densely populated page tables to the Bloom filters. +Within the eviction, the PID controller uses refaults as the feedback +to select types to evict and tiers to protect. diff --git a/MAINTAINERS b/MAINTAINERS index 603920f452d4..ac600cd66d22 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -12093,6 +12093,18 @@ L: linux-man@vger.kernel.org S: Maintained W: http://www.kernel.org/doc/man-pages +MAPLE TREE +M: Liam R. Howlett +L: linux-mm@kvack.org +S: Supported +F: Documentation/core-api/maple_tree.rst +F: include/linux/maple_tree.h +F: include/trace/events/maple_tree.h +F: lib/maple_tree.c +F: lib/test_maple_tree.c +F: tools/testing/radix-tree/linux/maple_tree.h +F: tools/testing/radix-tree/maple.c + MARDUK (CREATOR CI40) DEVICE TREE SUPPORT M: Rahul Bedarkar L: linux-mips@vger.kernel.org diff --git a/arch/Kconfig b/arch/Kconfig index 8b311e400ec1..bf19a84fffa2 100644 --- a/arch/Kconfig +++ b/arch/Kconfig @@ -1418,6 +1418,14 @@ config DYNAMIC_SIGFRAME config HAVE_ARCH_NODE_DEV_GROUP bool +config ARCH_HAS_NONLEAF_PMD_YOUNG + bool + help + Architectures that select this option are capable of setting the + accessed bit in non-leaf PMD entries when using them as part of linear + address translations. Page table walkers that clear the accessed bit + may use this capability to reduce their search space. + source "kernel/gcov/Kconfig" source "scripts/gcc-plugins/Kconfig" diff --git a/arch/arm64/include/asm/pgtable.h b/arch/arm64/include/asm/pgtable.h index d78e69293d12..edf6625ce965 100644 --- a/arch/arm64/include/asm/pgtable.h +++ b/arch/arm64/include/asm/pgtable.h @@ -1082,24 +1082,13 @@ static inline void update_mmu_cache(struct vm_area_struct *vma, * page after fork() + CoW for pfn mappings. We don't always have a * hardware-managed access flag on arm64. */ -static inline bool arch_faults_on_old_pte(void) -{ - /* The register read below requires a stable CPU to make any sense */ - cant_migrate(); - - return !cpu_has_hw_af(); -} -#define arch_faults_on_old_pte arch_faults_on_old_pte +#define arch_has_hw_pte_young cpu_has_hw_af /* * Experimentally, it's cheap to set the access flag in hardware and we * benefit from prefaulting mappings as 'old' to start with. */ -static inline bool arch_wants_old_prefaulted_pte(void) -{ - return !arch_faults_on_old_pte(); -} -#define arch_wants_old_prefaulted_pte arch_wants_old_prefaulted_pte +#define arch_wants_old_prefaulted_pte cpu_has_hw_af static inline bool pud_sect_supported(void) { diff --git a/arch/arm64/kernel/elfcore.c b/arch/arm64/kernel/elfcore.c index 98d67444a5b6..27ef7ad3ffd2 100644 --- a/arch/arm64/kernel/elfcore.c +++ b/arch/arm64/kernel/elfcore.c @@ -8,9 +8,9 @@ #include #include -#define for_each_mte_vma(tsk, vma) \ +#define for_each_mte_vma(vmi, vma) \ if (system_supports_mte()) \ - for (vma = tsk->mm->mmap; vma; vma = vma->vm_next) \ + for_each_vma(vmi, vma) \ if (vma->vm_flags & VM_MTE) static unsigned long mte_vma_tag_dump_size(struct vm_area_struct *vma) @@ -81,8 +81,9 @@ Elf_Half elf_core_extra_phdrs(void) { struct vm_area_struct *vma; int vma_count = 0; + VMA_ITERATOR(vmi, current->mm, 0); - for_each_mte_vma(current, vma) + for_each_mte_vma(vmi, vma) vma_count++; return vma_count; @@ -91,8 +92,9 @@ Elf_Half elf_core_extra_phdrs(void) int elf_core_write_extra_phdrs(struct coredump_params *cprm, loff_t offset) { struct vm_area_struct *vma; + VMA_ITERATOR(vmi, current->mm, 0); - for_each_mte_vma(current, vma) { + for_each_mte_vma(vmi, vma) { struct elf_phdr phdr; phdr.p_type = PT_AARCH64_MEMTAG_MTE; @@ -116,8 +118,9 @@ size_t elf_core_extra_data_size(void) { struct vm_area_struct *vma; size_t data_size = 0; + VMA_ITERATOR(vmi, current->mm, 0); - for_each_mte_vma(current, vma) + for_each_mte_vma(vmi, vma) data_size += mte_vma_tag_dump_size(vma); return data_size; @@ -126,8 +129,9 @@ size_t elf_core_extra_data_size(void) int elf_core_write_extra_data(struct coredump_params *cprm) { struct vm_area_struct *vma; + VMA_ITERATOR(vmi, current->mm, 0); - for_each_mte_vma(current, vma) { + for_each_mte_vma(vmi, vma) { if (vma->vm_flags & VM_DONTDUMP) continue; diff --git a/arch/arm64/kernel/vdso.c b/arch/arm64/kernel/vdso.c index a61fc4f989b3..a8388af62b99 100644 --- a/arch/arm64/kernel/vdso.c +++ b/arch/arm64/kernel/vdso.c @@ -136,10 +136,11 @@ int vdso_join_timens(struct task_struct *task, struct time_namespace *ns) { struct mm_struct *mm = task->mm; struct vm_area_struct *vma; + VMA_ITERATOR(vmi, mm, 0); mmap_read_lock(mm); - for (vma = mm->mmap; vma; vma = vma->vm_next) { + for_each_vma(vmi, vma) { unsigned long size = vma->vm_end - vma->vm_start; if (vma_is_special_mapping(vma, vdso_info[VDSO_ABI_AA64].dm)) diff --git a/arch/parisc/kernel/cache.c b/arch/parisc/kernel/cache.c index 3feb7694e0ca..1d3b8bc8a623 100644 --- a/arch/parisc/kernel/cache.c +++ b/arch/parisc/kernel/cache.c @@ -657,15 +657,20 @@ static inline unsigned long mm_total_size(struct mm_struct *mm) { struct vm_area_struct *vma; unsigned long usize = 0; + VMA_ITERATOR(vmi, mm, 0); - for (vma = mm->mmap; vma && usize < parisc_cache_flush_threshold; vma = vma->vm_next) + for_each_vma(vmi, vma) { + if (usize >= parisc_cache_flush_threshold) + break; usize += vma->vm_end - vma->vm_start; + } return usize; } void flush_cache_mm(struct mm_struct *mm) { struct vm_area_struct *vma; + VMA_ITERATOR(vmi, mm, 0); /* * Flushing the whole cache on each cpu takes forever on @@ -685,7 +690,7 @@ void flush_cache_mm(struct mm_struct *mm) } /* Flush mm */ - for (vma = mm->mmap; vma; vma = vma->vm_next) + for_each_vma(vmi, vma) flush_cache_pages(vma, vma->vm_start, vma->vm_end); } diff --git a/arch/powerpc/kernel/vdso.c b/arch/powerpc/kernel/vdso.c index 0da287544054..94a8fa5017c3 100644 --- a/arch/powerpc/kernel/vdso.c +++ b/arch/powerpc/kernel/vdso.c @@ -113,18 +113,18 @@ struct vdso_data *arch_get_vdso_data(void *vvar_page) int vdso_join_timens(struct task_struct *task, struct time_namespace *ns) { struct mm_struct *mm = task->mm; + VMA_ITERATOR(vmi, mm, 0); struct vm_area_struct *vma; mmap_read_lock(mm); - - for (vma = mm->mmap; vma; vma = vma->vm_next) { + for_each_vma(vmi, vma) { unsigned long size = vma->vm_end - vma->vm_start; if (vma_is_special_mapping(vma, &vvar_spec)) zap_page_range(vma, vma->vm_start, size); } - mmap_read_unlock(mm); + return 0; } diff --git a/arch/powerpc/mm/book3s32/tlb.c b/arch/powerpc/mm/book3s32/tlb.c index 19f0ef950d77..9ad6b56bfec9 100644 --- a/arch/powerpc/mm/book3s32/tlb.c +++ b/arch/powerpc/mm/book3s32/tlb.c @@ -81,14 +81,15 @@ EXPORT_SYMBOL(hash__flush_range); void hash__flush_tlb_mm(struct mm_struct *mm) { struct vm_area_struct *mp; + VMA_ITERATOR(vmi, mm, 0); /* - * It is safe to go down the mm's list of vmas when called - * from dup_mmap, holding mmap_lock. It would also be safe from - * unmap_region or exit_mmap, but not from vmtruncate on SMP - - * but it seems dup_mmap is the only SMP case which gets here. + * It is safe to iterate the vmas when called from dup_mmap, + * holding mmap_lock. It would also be safe from unmap_region + * or exit_mmap, but not from vmtruncate on SMP - but it seems + * dup_mmap is the only SMP case which gets here. */ - for (mp = mm->mmap; mp != NULL; mp = mp->vm_next) + for_each_vma(vmi, mp) hash__flush_range(mp->vm_mm, mp->vm_start, mp->vm_end); } EXPORT_SYMBOL(hash__flush_tlb_mm); diff --git a/arch/powerpc/mm/book3s64/subpage_prot.c b/arch/powerpc/mm/book3s64/subpage_prot.c index 60c6ea16a972..d73b3b4176e8 100644 --- a/arch/powerpc/mm/book3s64/subpage_prot.c +++ b/arch/powerpc/mm/book3s64/subpage_prot.c @@ -149,24 +149,15 @@ static void subpage_mark_vma_nohuge(struct mm_struct *mm, unsigned long addr, unsigned long len) { struct vm_area_struct *vma; + VMA_ITERATOR(vmi, mm, addr); /* * We don't try too hard, we just mark all the vma in that range * VM_NOHUGEPAGE and split them. */ - vma = find_vma(mm, addr); - /* - * If the range is in unmapped range, just return - */ - if (vma && ((addr + len) <= vma->vm_start)) - return; - - while (vma) { - if (vma->vm_start >= (addr + len)) - break; + for_each_vma_range(vmi, vma, addr + len) { vma->vm_flags |= VM_NOHUGEPAGE; walk_page_vma(vma, &subpage_walk_ops, NULL); - vma = vma->vm_next; } } #else diff --git a/arch/riscv/kernel/vdso.c b/arch/riscv/kernel/vdso.c index 4abc9aebdfae..123d05255fcf 100644 --- a/arch/riscv/kernel/vdso.c +++ b/arch/riscv/kernel/vdso.c @@ -119,10 +119,11 @@ int vdso_join_timens(struct task_struct *task, struct time_namespace *ns) { struct mm_struct *mm = task->mm; struct vm_area_struct *vma; + VMA_ITERATOR(vmi, mm, 0); mmap_read_lock(mm); - for (vma = mm->mmap; vma; vma = vma->vm_next) { + for_each_vma(vmi, vma) { unsigned long size = vma->vm_end - vma->vm_start; if (vma_is_special_mapping(vma, vdso_info.dm)) diff --git a/arch/s390/kernel/vdso.c b/arch/s390/kernel/vdso.c index 5075cde77b29..535099f2736d 100644 --- a/arch/s390/kernel/vdso.c +++ b/arch/s390/kernel/vdso.c @@ -69,10 +69,11 @@ static struct page *find_timens_vvar_page(struct vm_area_struct *vma) int vdso_join_timens(struct task_struct *task, struct time_namespace *ns) { struct mm_struct *mm = task->mm; + VMA_ITERATOR(vmi, mm, 0); struct vm_area_struct *vma; mmap_read_lock(mm); - for (vma = mm->mmap; vma; vma = vma->vm_next) { + for_each_vma(vmi, vma) { unsigned long size = vma->vm_end - vma->vm_start; if (!vma_is_special_mapping(vma, &vvar_mapping)) diff --git a/arch/s390/mm/gmap.c b/arch/s390/mm/gmap.c index 62758cb5872f..02d15c8dc92e 100644 --- a/arch/s390/mm/gmap.c +++ b/arch/s390/mm/gmap.c @@ -2515,8 +2515,9 @@ static const struct mm_walk_ops thp_split_walk_ops = { static inline void thp_split_mm(struct mm_struct *mm) { struct vm_area_struct *vma; + VMA_ITERATOR(vmi, mm, 0); - for (vma = mm->mmap; vma != NULL; vma = vma->vm_next) { + for_each_vma(vmi, vma) { vma->vm_flags &= ~VM_HUGEPAGE; vma->vm_flags |= VM_NOHUGEPAGE; walk_page_vma(vma, &thp_split_walk_ops, NULL); @@ -2584,8 +2585,9 @@ int gmap_mark_unmergeable(void) struct mm_struct *mm = current->mm; struct vm_area_struct *vma; int ret; + VMA_ITERATOR(vmi, mm, 0); - for (vma = mm->mmap; vma; vma = vma->vm_next) { + for_each_vma(vmi, vma) { ret = ksm_madvise(vma, vma->vm_start, vma->vm_end, MADV_UNMERGEABLE, &vma->vm_flags); if (ret) diff --git a/arch/um/kernel/tlb.c b/arch/um/kernel/tlb.c index bc38f79ca3a3..ad449173a1a1 100644 --- a/arch/um/kernel/tlb.c +++ b/arch/um/kernel/tlb.c @@ -584,21 +584,19 @@ void flush_tlb_mm_range(struct mm_struct *mm, unsigned long start, void flush_tlb_mm(struct mm_struct *mm) { - struct vm_area_struct *vma = mm->mmap; + struct vm_area_struct *vma; + VMA_ITERATOR(vmi, mm, 0); - while (vma != NULL) { + for_each_vma(vmi, vma) fix_range(mm, vma->vm_start, vma->vm_end, 0); - vma = vma->vm_next; - } } void force_flush_all(void) { struct mm_struct *mm = current->mm; - struct vm_area_struct *vma = mm->mmap; + struct vm_area_struct *vma; + VMA_ITERATOR(vmi, mm, 0); - while (vma != NULL) { + for_each_vma(vmi, vma) fix_range(mm, vma->vm_start, vma->vm_end, 1); - vma = vma->vm_next; - } } diff --git a/arch/x86/Kconfig b/arch/x86/Kconfig index 4728d3f5d5c4..23f622578950 100644 --- a/arch/x86/Kconfig +++ b/arch/x86/Kconfig @@ -85,6 +85,7 @@ config X86 select ARCH_HAS_PMEM_API if X86_64 select ARCH_HAS_PTE_DEVMAP if X86_64 select ARCH_HAS_PTE_SPECIAL + select ARCH_HAS_NONLEAF_PMD_YOUNG if PGTABLE_LEVELS > 2 select ARCH_HAS_UACCESS_FLUSHCACHE if X86_64 select ARCH_HAS_COPY_MC if X86_64 select ARCH_HAS_SET_MEMORY diff --git a/arch/x86/entry/vdso/vma.c b/arch/x86/entry/vdso/vma.c index 1000d457c332..6292b960037b 100644 --- a/arch/x86/entry/vdso/vma.c +++ b/arch/x86/entry/vdso/vma.c @@ -127,17 +127,17 @@ int vdso_join_timens(struct task_struct *task, struct time_namespace *ns) { struct mm_struct *mm = task->mm; struct vm_area_struct *vma; + VMA_ITERATOR(vmi, mm, 0); mmap_read_lock(mm); - - for (vma = mm->mmap; vma; vma = vma->vm_next) { + for_each_vma(vmi, vma) { unsigned long size = vma->vm_end - vma->vm_start; if (vma_is_special_mapping(vma, &vvar_mapping)) zap_page_range(vma, vma->vm_start, size); } - mmap_read_unlock(mm); + return 0; } #else @@ -354,6 +354,7 @@ int map_vdso_once(const struct vdso_image *image, unsigned long addr) { struct mm_struct *mm = current->mm; struct vm_area_struct *vma; + VMA_ITERATOR(vmi, mm, 0); mmap_write_lock(mm); /* @@ -363,7 +364,7 @@ int map_vdso_once(const struct vdso_image *image, unsigned long addr) * We could search vma near context.vdso, but it's a slowpath, * so let's explicitly check all VMAs to be completely sure. */ - for (vma = mm->mmap; vma; vma = vma->vm_next) { + for_each_vma(vmi, vma) { if (vma_is_special_mapping(vma, &vdso_mapping) || vma_is_special_mapping(vma, &vvar_mapping)) { mmap_write_unlock(mm); diff --git a/arch/x86/include/asm/pgtable.h b/arch/x86/include/asm/pgtable.h index 44e2d6f1dbaa..5059799bebe3 100644 --- a/arch/x86/include/asm/pgtable.h +++ b/arch/x86/include/asm/pgtable.h @@ -815,7 +815,8 @@ static inline unsigned long pmd_page_vaddr(pmd_t pmd) static inline int pmd_bad(pmd_t pmd) { - return (pmd_flags(pmd) & ~_PAGE_USER) != _KERNPG_TABLE; + return (pmd_flags(pmd) & ~(_PAGE_USER | _PAGE_ACCESSED)) != + (_KERNPG_TABLE & ~_PAGE_ACCESSED); } static inline unsigned long pages_to_mb(unsigned long npg) @@ -1431,10 +1432,10 @@ static inline bool arch_has_pfn_modify_check(void) return boot_cpu_has_bug(X86_BUG_L1TF); } -#define arch_faults_on_old_pte arch_faults_on_old_pte -static inline bool arch_faults_on_old_pte(void) +#define arch_has_hw_pte_young arch_has_hw_pte_young +static inline bool arch_has_hw_pte_young(void) { - return false; + return true; } #ifdef CONFIG_PAGE_TABLE_CHECK diff --git a/arch/x86/kernel/tboot.c b/arch/x86/kernel/tboot.c index 3bacd935f840..4c1bcb6053fc 100644 --- a/arch/x86/kernel/tboot.c +++ b/arch/x86/kernel/tboot.c @@ -95,7 +95,7 @@ void __init tboot_probe(void) static pgd_t *tboot_pg_dir; static struct mm_struct tboot_mm = { - .mm_rb = RB_ROOT, + .mm_mt = MTREE_INIT_EXT(mm_mt, MM_MT_FLAGS, tboot_mm.mmap_lock), .pgd = swapper_pg_dir, .mm_users = ATOMIC_INIT(2), .mm_count = ATOMIC_INIT(1), diff --git a/arch/x86/mm/pgtable.c b/arch/x86/mm/pgtable.c index a932d7712d85..8525f2876fb4 100644 --- a/arch/x86/mm/pgtable.c +++ b/arch/x86/mm/pgtable.c @@ -550,7 +550,7 @@ int ptep_test_and_clear_young(struct vm_area_struct *vma, return ret; } -#ifdef CONFIG_TRANSPARENT_HUGEPAGE +#if defined(CONFIG_TRANSPARENT_HUGEPAGE) || defined(CONFIG_ARCH_HAS_NONLEAF_PMD_YOUNG) int pmdp_test_and_clear_young(struct vm_area_struct *vma, unsigned long addr, pmd_t *pmdp) { @@ -562,6 +562,9 @@ int pmdp_test_and_clear_young(struct vm_area_struct *vma, return ret; } +#endif + +#ifdef CONFIG_TRANSPARENT_HUGEPAGE int pudp_test_and_clear_young(struct vm_area_struct *vma, unsigned long addr, pud_t *pudp) { diff --git a/arch/xtensa/kernel/syscall.c b/arch/xtensa/kernel/syscall.c index 201356faa7e6..b3c2450d6f23 100644 --- a/arch/xtensa/kernel/syscall.c +++ b/arch/xtensa/kernel/syscall.c @@ -58,6 +58,7 @@ unsigned long arch_get_unmapped_area(struct file *filp, unsigned long addr, unsigned long len, unsigned long pgoff, unsigned long flags) { struct vm_area_struct *vmm; + struct vma_iterator vmi; if (flags & MAP_FIXED) { /* We do not accept a shared mapping if it would violate @@ -79,15 +80,20 @@ unsigned long arch_get_unmapped_area(struct file *filp, unsigned long addr, else addr = PAGE_ALIGN(addr); - for (vmm = find_vma(current->mm, addr); ; vmm = vmm->vm_next) { - /* At this point: (!vmm || addr < vmm->vm_end). */ - if (TASK_SIZE - len < addr) - return -ENOMEM; - if (!vmm || addr + len <= vm_start_gap(vmm)) - return addr; + vma_iter_init(&vmi, current->mm, addr); + for_each_vma(vmi, vmm) { + /* At this point: (addr < vmm->vm_end). */ + if (addr + len <= vm_start_gap(vmm)) + break; + addr = vmm->vm_end; if (flags & MAP_SHARED) addr = COLOUR_ALIGN(addr, pgoff); } + + if (TASK_SIZE - len < addr) + return -ENOMEM; + + return addr; } #endif diff --git a/drivers/firmware/efi/efi.c b/drivers/firmware/efi/efi.c index a06decee51e0..1ffc27c5be36 100644 --- a/drivers/firmware/efi/efi.c +++ b/drivers/firmware/efi/efi.c @@ -57,7 +57,7 @@ static unsigned long __initdata mem_reserve = EFI_INVALID_TABLE_ADDR; static unsigned long __initdata rt_prop = EFI_INVALID_TABLE_ADDR; struct mm_struct efi_mm = { - .mm_rb = RB_ROOT, + .mm_mt = MTREE_INIT_EXT(mm_mt, MM_MT_FLAGS, efi_mm.mmap_lock), .mm_users = ATOMIC_INIT(2), .mm_count = ATOMIC_INIT(1), .write_protect_seq = SEQCNT_ZERO(efi_mm.write_protect_seq), diff --git a/drivers/gpu/drm/i915/gem/i915_gem_userptr.c b/drivers/gpu/drm/i915/gem/i915_gem_userptr.c index e4515d6acd43..f34e01a7fefb 100644 --- a/drivers/gpu/drm/i915/gem/i915_gem_userptr.c +++ b/drivers/gpu/drm/i915/gem/i915_gem_userptr.c @@ -426,12 +426,11 @@ static const struct drm_i915_gem_object_ops i915_gem_userptr_ops = { static int probe_range(struct mm_struct *mm, unsigned long addr, unsigned long len) { - const unsigned long end = addr + len; + VMA_ITERATOR(vmi, mm, addr); struct vm_area_struct *vma; - int ret = -EFAULT; mmap_read_lock(mm); - for (vma = find_vma(mm, addr); vma; vma = vma->vm_next) { + for_each_vma_range(vmi, vma, addr + len) { /* Check for holes, note that we also update the addr below */ if (vma->vm_start > addr) break; @@ -439,16 +438,13 @@ probe_range(struct mm_struct *mm, unsigned long addr, unsigned long len) if (vma->vm_flags & (VM_PFNMAP | VM_MIXEDMAP)) break; - if (vma->vm_end >= end) { - ret = 0; - break; - } - addr = vma->vm_end; } mmap_read_unlock(mm); - return ret; + if (vma) + return -EFAULT; + return 0; } /* diff --git a/drivers/misc/cxl/fault.c b/drivers/misc/cxl/fault.c index 60c829113299..2c64f55cf01f 100644 --- a/drivers/misc/cxl/fault.c +++ b/drivers/misc/cxl/fault.c @@ -280,22 +280,6 @@ void cxl_handle_fault(struct work_struct *fault_work) mmput(mm); } -static void cxl_prefault_one(struct cxl_context *ctx, u64 ea) -{ - struct mm_struct *mm; - - mm = get_mem_context(ctx); - if (mm == NULL) { - pr_devel("cxl_prefault_one unable to get mm %i\n", - pid_nr(ctx->pid)); - return; - } - - cxl_fault_segment(ctx, mm, ea); - - mmput(mm); -} - static u64 next_segment(u64 ea, u64 vsid) { if (vsid & SLB_VSID_B_1T) @@ -306,23 +290,16 @@ static u64 next_segment(u64 ea, u64 vsid) return ea + 1; } -static void cxl_prefault_vma(struct cxl_context *ctx) +static void cxl_prefault_vma(struct cxl_context *ctx, struct mm_struct *mm) { u64 ea, last_esid = 0; struct copro_slb slb; + VMA_ITERATOR(vmi, mm, 0); struct vm_area_struct *vma; int rc; - struct mm_struct *mm; - - mm = get_mem_context(ctx); - if (mm == NULL) { - pr_devel("cxl_prefault_vm unable to get mm %i\n", - pid_nr(ctx->pid)); - return; - } mmap_read_lock(mm); - for (vma = mm->mmap; vma; vma = vma->vm_next) { + for_each_vma(vmi, vma) { for (ea = vma->vm_start; ea < vma->vm_end; ea = next_segment(ea, slb.vsid)) { rc = copro_calculate_slb(mm, ea, &slb); @@ -337,20 +314,28 @@ static void cxl_prefault_vma(struct cxl_context *ctx) } } mmap_read_unlock(mm); - - mmput(mm); } void cxl_prefault(struct cxl_context *ctx, u64 wed) { + struct mm_struct *mm = get_mem_context(ctx); + + if (mm == NULL) { + pr_devel("cxl_prefault unable to get mm %i\n", + pid_nr(ctx->pid)); + return; + } + switch (ctx->afu->prefault_mode) { case CXL_PREFAULT_WED: - cxl_prefault_one(ctx, wed); + cxl_fault_segment(ctx, mm, wed); break; case CXL_PREFAULT_ALL: - cxl_prefault_vma(ctx); + cxl_prefault_vma(ctx, mm); break; default: break; } + + mmput(mm); } diff --git a/drivers/tee/optee/call.c b/drivers/tee/optee/call.c index 28f87cd8b3ed..290b1bb0e9cd 100644 --- a/drivers/tee/optee/call.c +++ b/drivers/tee/optee/call.c @@ -492,15 +492,18 @@ static bool is_normal_memory(pgprot_t p) #endif } -static int __check_mem_type(struct vm_area_struct *vma, unsigned long end) +static int __check_mem_type(struct mm_struct *mm, unsigned long start, + unsigned long end) { - while (vma && is_normal_memory(vma->vm_page_prot)) { - if (vma->vm_end >= end) - return 0; - vma = vma->vm_next; + struct vm_area_struct *vma; + VMA_ITERATOR(vmi, mm, start); + + for_each_vma_range(vmi, vma, end) { + if (!is_normal_memory(vma->vm_page_prot)) + return -EINVAL; } - return -EINVAL; + return 0; } int optee_check_mem_type(unsigned long start, size_t num_pages) @@ -516,8 +519,7 @@ int optee_check_mem_type(unsigned long start, size_t num_pages) return 0; mmap_read_lock(mm); - rc = __check_mem_type(find_vma(mm, start), - start + num_pages * PAGE_SIZE); + rc = __check_mem_type(mm, start, start + num_pages * PAGE_SIZE); mmap_read_unlock(mm); return rc; diff --git a/drivers/xen/privcmd.c b/drivers/xen/privcmd.c index e88e8f6f0a33..fae50a24630b 100644 --- a/drivers/xen/privcmd.c +++ b/drivers/xen/privcmd.c @@ -282,7 +282,7 @@ static long privcmd_ioctl_mmap(struct file *file, void __user *udata) struct page, lru); struct privcmd_mmap_entry *msg = page_address(page); - vma = find_vma(mm, msg->va); + vma = vma_lookup(mm, msg->va); rc = -EINVAL; if (!vma || (msg->va != vma->vm_start) || vma->vm_private_data) diff --git a/fs/coredump.c b/fs/coredump.c index 3538f3a63965..e69f87211839 100644 --- a/fs/coredump.c +++ b/fs/coredump.c @@ -1101,30 +1101,20 @@ static unsigned long vma_dump_size(struct vm_area_struct *vma, return vma->vm_end - vma->vm_start; } -static struct vm_area_struct *first_vma(struct task_struct *tsk, - struct vm_area_struct *gate_vma) -{ - struct vm_area_struct *ret = tsk->mm->mmap; - - if (ret) - return ret; - return gate_vma; -} - /* * Helper function for iterating across a vma list. It ensures that the caller * will visit `gate_vma' prior to terminating the search. */ -static struct vm_area_struct *next_vma(struct vm_area_struct *this_vma, +static struct vm_area_struct *coredump_next_vma(struct ma_state *mas, + struct vm_area_struct *vma, struct vm_area_struct *gate_vma) { - struct vm_area_struct *ret; - - ret = this_vma->vm_next; - if (ret) - return ret; - if (this_vma == gate_vma) + if (gate_vma && (vma == gate_vma)) return NULL; + + vma = mas_next(mas, ULONG_MAX); + if (vma) + return vma; return gate_vma; } @@ -1148,9 +1138,10 @@ static void free_vma_snapshot(struct coredump_params *cprm) */ static bool dump_vma_snapshot(struct coredump_params *cprm) { - struct vm_area_struct *vma, *gate_vma; + struct vm_area_struct *gate_vma, *vma = NULL; struct mm_struct *mm = current->mm; - int i; + MA_STATE(mas, &mm->mm_mt, 0, 0); + int i = 0; /* * Once the stack expansion code is fixed to not change VMA bounds @@ -1170,8 +1161,7 @@ static bool dump_vma_snapshot(struct coredump_params *cprm) return false; } - for (i = 0, vma = first_vma(current, gate_vma); vma != NULL; - vma = next_vma(vma, gate_vma), i++) { + while ((vma = coredump_next_vma(&mas, vma, gate_vma)) != NULL) { struct core_vma_metadata *m = cprm->vma_meta + i; m->start = vma->vm_start; @@ -1179,10 +1169,10 @@ static bool dump_vma_snapshot(struct coredump_params *cprm) m->flags = vma->vm_flags; m->dump_size = vma_dump_size(vma, cprm->mm_flags); m->pgoff = vma->vm_pgoff; - m->file = vma->vm_file; if (m->file) get_file(m->file); + i++; } mmap_write_unlock(mm); diff --git a/fs/exec.c b/fs/exec.c index 6f3d6fce178f..47aa02fda159 100644 --- a/fs/exec.c +++ b/fs/exec.c @@ -28,7 +28,6 @@ #include #include #include -#include #include #include #include @@ -683,6 +682,8 @@ static int shift_arg_pages(struct vm_area_struct *vma, unsigned long shift) unsigned long length = old_end - old_start; unsigned long new_start = old_start - shift; unsigned long new_end = old_end - shift; + VMA_ITERATOR(vmi, mm, new_start); + struct vm_area_struct *next; struct mmu_gather tlb; BUG_ON(new_start > new_end); @@ -691,7 +692,7 @@ static int shift_arg_pages(struct vm_area_struct *vma, unsigned long shift) * ensure there are no vmas between where we want to go * and where we are */ - if (vma != find_vma(mm, new_start)) + if (vma != vma_next(&vmi)) return -EFAULT; /* @@ -710,12 +711,13 @@ static int shift_arg_pages(struct vm_area_struct *vma, unsigned long shift) lru_add_drain(); tlb_gather_mmu(&tlb, mm); + next = vma_next(&vmi); if (new_end > old_start) { /* * when the old and new regions overlap clear from new_end. */ free_pgd_range(&tlb, new_end, old_end, new_end, - vma->vm_next ? vma->vm_next->vm_start : USER_PGTABLES_CEILING); + next ? next->vm_start : USER_PGTABLES_CEILING); } else { /* * otherwise, clean from old_start; this is done to not touch @@ -724,7 +726,7 @@ static int shift_arg_pages(struct vm_area_struct *vma, unsigned long shift) * for the others its just a little faster. */ free_pgd_range(&tlb, old_start, old_end, new_end, - vma->vm_next ? vma->vm_next->vm_start : USER_PGTABLES_CEILING); + next ? next->vm_start : USER_PGTABLES_CEILING); } tlb_finish_mmu(&tlb); @@ -1023,9 +1025,9 @@ static int exec_mmap(struct mm_struct *mm) activate_mm(active_mm, mm); if (IS_ENABLED(CONFIG_ARCH_WANT_IRQS_OFF_ACTIVATE_MM)) local_irq_enable(); - tsk->mm->vmacache_seqnum = 0; - vmacache_flush(tsk); + lru_gen_add_mm(mm); task_unlock(tsk); + lru_gen_use_mm(mm); if (old_mm) { mmap_read_unlock(old_mm); BUG_ON(active_mm != old_mm); diff --git a/fs/fuse/dev.c b/fs/fuse/dev.c index 51897427a534..b4a6e0a1b945 100644 --- a/fs/fuse/dev.c +++ b/fs/fuse/dev.c @@ -776,7 +776,8 @@ static int fuse_check_page(struct page *page) 1 << PG_active | 1 << PG_workingset | 1 << PG_reclaim | - 1 << PG_waiters))) { + 1 << PG_waiters | + LRU_GEN_MASK | LRU_REFS_MASK))) { dump_page(page, "fuse: trying to steal weird page"); return 1; } diff --git a/fs/proc/base.c b/fs/proc/base.c index 93f7e3d971e4..12885a75913f 100644 --- a/fs/proc/base.c +++ b/fs/proc/base.c @@ -2350,6 +2350,7 @@ proc_map_files_readdir(struct file *file, struct dir_context *ctx) GENRADIX(struct map_files_info) fa; struct map_files_info *p; int ret; + struct vma_iterator vmi; genradix_init(&fa); @@ -2388,7 +2389,9 @@ proc_map_files_readdir(struct file *file, struct dir_context *ctx) * routine might require mmap_lock taken in might_fault(). */ - for (vma = mm->mmap, pos = 2; vma; vma = vma->vm_next) { + pos = 2; + vma_iter_init(&vmi, mm, 0); + for_each_vma(vmi, vma) { if (!vma->vm_file) continue; if (++pos <= ctx->pos) diff --git a/fs/proc/internal.h b/fs/proc/internal.h index 06a80f78433d..f03000764ce5 100644 --- a/fs/proc/internal.h +++ b/fs/proc/internal.h @@ -285,7 +285,7 @@ struct proc_maps_private { struct task_struct *task; struct mm_struct *mm; #ifdef CONFIG_MMU - struct vm_area_struct *tail_vma; + struct vma_iterator iter; #endif #ifdef CONFIG_NUMA struct mempolicy *task_mempolicy; diff --git a/fs/proc/task_mmu.c b/fs/proc/task_mmu.c index 1e7bbc0873a4..72a02b563e64 100644 --- a/fs/proc/task_mmu.c +++ b/fs/proc/task_mmu.c @@ -1,6 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 #include -#include #include #include #include @@ -124,12 +123,26 @@ static void release_task_mempolicy(struct proc_maps_private *priv) } #endif +static struct vm_area_struct *proc_get_vma(struct proc_maps_private *priv, + loff_t *ppos) +{ + struct vm_area_struct *vma = vma_next(&priv->iter); + + if (vma) { + *ppos = vma->vm_start; + } else { + *ppos = -2UL; + vma = get_gate_vma(priv->mm); + } + + return vma; +} + static void *m_start(struct seq_file *m, loff_t *ppos) { struct proc_maps_private *priv = m->private; unsigned long last_addr = *ppos; struct mm_struct *mm; - struct vm_area_struct *vma; /* See m_next(). Zero at the start or after lseek. */ if (last_addr == -1UL) @@ -153,31 +166,21 @@ static void *m_start(struct seq_file *m, loff_t *ppos) return ERR_PTR(-EINTR); } + vma_iter_init(&priv->iter, mm, last_addr); hold_task_mempolicy(priv); - priv->tail_vma = get_gate_vma(mm); - - vma = find_vma(mm, last_addr); - if (vma) - return vma; + if (last_addr == -2UL) + return get_gate_vma(mm); - return priv->tail_vma; + return proc_get_vma(priv, ppos); } static void *m_next(struct seq_file *m, void *v, loff_t *ppos) { - struct proc_maps_private *priv = m->private; - struct vm_area_struct *next, *vma = v; - - if (vma == priv->tail_vma) - next = NULL; - else if (vma->vm_next) - next = vma->vm_next; - else - next = priv->tail_vma; - - *ppos = next ? next->vm_start : -1UL; - - return next; + if (*ppos == -2UL) { + *ppos = -1UL; + return NULL; + } + return proc_get_vma(m->private, ppos); } static void m_stop(struct seq_file *m, void *v) @@ -877,16 +880,16 @@ static int show_smaps_rollup(struct seq_file *m, void *v) { struct proc_maps_private *priv = m->private; struct mem_size_stats mss; - struct mm_struct *mm; + struct mm_struct *mm = priv->mm; struct vm_area_struct *vma; - unsigned long last_vma_end = 0; + unsigned long vma_start = 0, last_vma_end = 0; int ret = 0; + MA_STATE(mas, &mm->mm_mt, 0, 0); priv->task = get_proc_task(priv->inode); if (!priv->task) return -ESRCH; - mm = priv->mm; if (!mm || !mmget_not_zero(mm)) { ret = -ESRCH; goto out_put_task; @@ -899,8 +902,13 @@ static int show_smaps_rollup(struct seq_file *m, void *v) goto out_put_mm; hold_task_mempolicy(priv); + vma = mas_find(&mas, 0); + + if (unlikely(!vma)) + goto empty_set; - for (vma = priv->mm->mmap; vma;) { + vma_start = vma->vm_start; + do { smap_gather_stats(vma, &mss, 0); last_vma_end = vma->vm_end; @@ -909,6 +917,7 @@ static int show_smaps_rollup(struct seq_file *m, void *v) * access it for write request. */ if (mmap_lock_is_contended(mm)) { + mas_pause(&mas); mmap_read_unlock(mm); ret = mmap_read_lock_killable(mm); if (ret) { @@ -952,7 +961,7 @@ static int show_smaps_rollup(struct seq_file *m, void *v) * contains last_vma_end. * Iterate VMA' from last_vma_end. */ - vma = find_vma(mm, last_vma_end - 1); + vma = mas_find(&mas, ULONG_MAX); /* Case 3 above */ if (!vma) break; @@ -966,11 +975,10 @@ static int show_smaps_rollup(struct seq_file *m, void *v) smap_gather_stats(vma, &mss, last_vma_end); } /* Case 2 above */ - vma = vma->vm_next; - } + } while ((vma = mas_find(&mas, ULONG_MAX)) != NULL); - show_vma_header_prefix(m, priv->mm->mmap ? priv->mm->mmap->vm_start : 0, - last_vma_end, 0, 0, 0, 0); +empty_set: + show_vma_header_prefix(m, vma_start, last_vma_end, 0, 0, 0, 0); seq_pad(m, ' '); seq_puts(m, "[rollup]\n"); @@ -1263,6 +1271,7 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf, return -ESRCH; mm = get_task_mm(task); if (mm) { + MA_STATE(mas, &mm->mm_mt, 0, 0); struct mmu_notifier_range range; struct clear_refs_private cp = { .type = type, @@ -1282,7 +1291,7 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf, } if (type == CLEAR_REFS_SOFT_DIRTY) { - for (vma = mm->mmap; vma; vma = vma->vm_next) { + mas_for_each(&mas, vma, ULONG_MAX) { if (!(vma->vm_flags & VM_SOFTDIRTY)) continue; vma->vm_flags &= ~VM_SOFTDIRTY; @@ -1294,8 +1303,7 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf, 0, NULL, mm, 0, -1UL); mmu_notifier_invalidate_range_start(&range); } - walk_page_range(mm, 0, mm->highest_vm_end, &clear_refs_walk_ops, - &cp); + walk_page_range(mm, 0, -1, &clear_refs_walk_ops, &cp); if (type == CLEAR_REFS_SOFT_DIRTY) { mmu_notifier_invalidate_range_end(&range); flush_tlb_mm(mm); diff --git a/fs/proc/task_nommu.c b/fs/proc/task_nommu.c index a6d21fc0033c..2fd06f52b6a4 100644 --- a/fs/proc/task_nommu.c +++ b/fs/proc/task_nommu.c @@ -20,15 +20,13 @@ */ void task_mem(struct seq_file *m, struct mm_struct *mm) { + VMA_ITERATOR(vmi, mm, 0); struct vm_area_struct *vma; struct vm_region *region; - struct rb_node *p; unsigned long bytes = 0, sbytes = 0, slack = 0, size; - - mmap_read_lock(mm); - for (p = rb_first(&mm->mm_rb); p; p = rb_next(p)) { - vma = rb_entry(p, struct vm_area_struct, vm_rb); + mmap_read_lock(mm); + for_each_vma(vmi, vma) { bytes += kobjsize(vma); region = vma->vm_region; @@ -82,15 +80,13 @@ void task_mem(struct seq_file *m, struct mm_struct *mm) unsigned long task_vsize(struct mm_struct *mm) { + VMA_ITERATOR(vmi, mm, 0); struct vm_area_struct *vma; - struct rb_node *p; unsigned long vsize = 0; mmap_read_lock(mm); - for (p = rb_first(&mm->mm_rb); p; p = rb_next(p)) { - vma = rb_entry(p, struct vm_area_struct, vm_rb); + for_each_vma(vmi, vma) vsize += vma->vm_end - vma->vm_start; - } mmap_read_unlock(mm); return vsize; } @@ -99,14 +95,13 @@ unsigned long task_statm(struct mm_struct *mm, unsigned long *shared, unsigned long *text, unsigned long *data, unsigned long *resident) { + VMA_ITERATOR(vmi, mm, 0); struct vm_area_struct *vma; struct vm_region *region; - struct rb_node *p; unsigned long size = kobjsize(mm); mmap_read_lock(mm); - for (p = rb_first(&mm->mm_rb); p; p = rb_next(p)) { - vma = rb_entry(p, struct vm_area_struct, vm_rb); + for_each_vma(vmi, vma) { size += kobjsize(vma); region = vma->vm_region; if (region) { @@ -190,17 +185,19 @@ static int nommu_vma_show(struct seq_file *m, struct vm_area_struct *vma) */ static int show_map(struct seq_file *m, void *_p) { - struct rb_node *p = _p; - - return nommu_vma_show(m, rb_entry(p, struct vm_area_struct, vm_rb)); + return nommu_vma_show(m, _p); } static void *m_start(struct seq_file *m, loff_t *pos) { struct proc_maps_private *priv = m->private; struct mm_struct *mm; - struct rb_node *p; - loff_t n = *pos; + struct vm_area_struct *vma; + unsigned long addr = *pos; + + /* See m_next(). Zero at the start or after lseek. */ + if (addr == -1UL) + return NULL; /* pin the task and mm whilst we play with them */ priv->task = get_proc_task(priv->inode); @@ -216,10 +213,10 @@ static void *m_start(struct seq_file *m, loff_t *pos) return ERR_PTR(-EINTR); } - /* start from the Nth VMA */ - for (p = rb_first(&mm->mm_rb); p; p = rb_next(p)) - if (n-- == 0) - return p; + /* start the next element from addr */ + vma = find_vma(mm, addr); + if (vma) + return vma; mmap_read_unlock(mm); mmput(mm); @@ -242,10 +239,10 @@ static void m_stop(struct seq_file *m, void *_vml) static void *m_next(struct seq_file *m, void *_p, loff_t *pos) { - struct rb_node *p = _p; + struct vm_area_struct *vma = _p; - (*pos)++; - return p ? rb_next(p) : NULL; + *pos = vma->vm_end; + return find_vma(vma->vm_mm, vma->vm_end); } static const struct seq_operations proc_pid_maps_ops = { diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c index 0c1d33c4f74c..2e125af0d7a2 100644 --- a/fs/userfaultfd.c +++ b/fs/userfaultfd.c @@ -615,14 +615,16 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx, if (release_new_ctx) { struct vm_area_struct *vma; struct mm_struct *mm = release_new_ctx->mm; + VMA_ITERATOR(vmi, mm, 0); /* the various vma->vm_userfaultfd_ctx still points to it */ mmap_write_lock(mm); - for (vma = mm->mmap; vma; vma = vma->vm_next) + for_each_vma(vmi, vma) { if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) { vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; vma->vm_flags &= ~__VM_UFFD_FLAGS; } + } mmap_write_unlock(mm); userfaultfd_ctx_put(release_new_ctx); @@ -803,11 +805,13 @@ static bool has_unmap_ctx(struct userfaultfd_ctx *ctx, struct list_head *unmaps, return false; } -int userfaultfd_unmap_prep(struct vm_area_struct *vma, - unsigned long start, unsigned long end, - struct list_head *unmaps) +int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start, + unsigned long end, struct list_head *unmaps) { - for ( ; vma && vma->vm_start < end; vma = vma->vm_next) { + VMA_ITERATOR(vmi, mm, start); + struct vm_area_struct *vma; + + for_each_vma_range(vmi, vma, end) { struct userfaultfd_unmap_ctx *unmap_ctx; struct userfaultfd_ctx *ctx = vma->vm_userfaultfd_ctx.ctx; @@ -857,6 +861,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) /* len == 0 means wake all */ struct userfaultfd_wake_range range = { .len = 0, }; unsigned long new_flags; + MA_STATE(mas, &mm->mm_mt, 0, 0); WRITE_ONCE(ctx->released, true); @@ -873,7 +878,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) */ mmap_write_lock(mm); prev = NULL; - for (vma = mm->mmap; vma; vma = vma->vm_next) { + mas_for_each(&mas, vma, ULONG_MAX) { cond_resched(); BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^ !!(vma->vm_flags & __VM_UFFD_FLAGS)); @@ -887,10 +892,13 @@ static int userfaultfd_release(struct inode *inode, struct file *file) vma->vm_file, vma->vm_pgoff, vma_policy(vma), NULL_VM_UFFD_CTX, anon_vma_name(vma)); - if (prev) + if (prev) { + mas_pause(&mas); vma = prev; - else + } else { prev = vma; + } + vma->vm_flags = new_flags; vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; } @@ -1272,6 +1280,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, bool found; bool basic_ioctls; unsigned long start, end, vma_end; + MA_STATE(mas, &mm->mm_mt, 0, 0); user_uffdio_register = (struct uffdio_register __user *) arg; @@ -1314,7 +1323,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, goto out; mmap_write_lock(mm); - vma = find_vma_prev(mm, start, &prev); + mas_set(&mas, start); + vma = mas_find(&mas, ULONG_MAX); if (!vma) goto out_unlock; @@ -1339,7 +1349,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, */ found = false; basic_ioctls = false; - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { + for (cur = vma; cur; cur = mas_next(&mas, end - 1)) { cond_resched(); BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ @@ -1399,8 +1409,10 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, } BUG_ON(!found); - if (vma->vm_start < start) - prev = vma; + mas_set(&mas, start); + prev = mas_prev(&mas, 0); + if (prev != vma) + mas_next(&mas, ULONG_MAX); ret = 0; do { @@ -1430,6 +1442,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, ((struct vm_userfaultfd_ctx){ ctx }), anon_vma_name(vma)); if (prev) { + /* vma_merge() invalidated the mas */ + mas_pause(&mas); vma = prev; goto next; } @@ -1437,11 +1451,15 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, ret = split_vma(mm, vma, start, 1); if (ret) break; + /* split_vma() invalidated the mas */ + mas_pause(&mas); } if (vma->vm_end > end) { ret = split_vma(mm, vma, end, 0); if (ret) break; + /* split_vma() invalidated the mas */ + mas_pause(&mas); } next: /* @@ -1458,8 +1476,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, skip: prev = vma; start = vma->vm_end; - vma = vma->vm_next; - } while (vma && vma->vm_start < end); + vma = mas_next(&mas, end - 1); + } while (vma); out_unlock: mmap_write_unlock(mm); mmput(mm); @@ -1503,6 +1521,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, bool found; unsigned long start, end, vma_end; const void __user *buf = (void __user *)arg; + MA_STATE(mas, &mm->mm_mt, 0, 0); ret = -EFAULT; if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister))) @@ -1521,7 +1540,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, goto out; mmap_write_lock(mm); - vma = find_vma_prev(mm, start, &prev); + mas_set(&mas, start); + vma = mas_find(&mas, ULONG_MAX); if (!vma) goto out_unlock; @@ -1546,7 +1566,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, */ found = false; ret = -EINVAL; - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { + for (cur = vma; cur; cur = mas_next(&mas, end - 1)) { cond_resched(); BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ @@ -1566,8 +1586,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, } BUG_ON(!found); - if (vma->vm_start < start) - prev = vma; + mas_set(&mas, start); + prev = mas_prev(&mas, 0); + if (prev != vma) + mas_next(&mas, ULONG_MAX); ret = 0; do { @@ -1612,17 +1634,20 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, NULL_VM_UFFD_CTX, anon_vma_name(vma)); if (prev) { vma = prev; + mas_pause(&mas); goto next; } if (vma->vm_start < start) { ret = split_vma(mm, vma, start, 1); if (ret) break; + mas_pause(&mas); } if (vma->vm_end > end) { ret = split_vma(mm, vma, end, 0); if (ret) break; + mas_pause(&mas); } next: /* @@ -1636,8 +1661,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, skip: prev = vma; start = vma->vm_end; - vma = vma->vm_next; - } while (vma && vma->vm_start < end); + vma = mas_next(&mas, end - 1); + } while (vma); out_unlock: mmap_write_unlock(mm); mmput(mm); diff --git a/include/linux/cgroup.h b/include/linux/cgroup.h index ac5d0515680e..9179463c3c9f 100644 --- a/include/linux/cgroup.h +++ b/include/linux/cgroup.h @@ -432,6 +432,18 @@ static inline void cgroup_put(struct cgroup *cgrp) css_put(&cgrp->self); } +extern struct mutex cgroup_mutex; + +static inline void cgroup_lock(void) +{ + mutex_lock(&cgroup_mutex); +} + +static inline void cgroup_unlock(void) +{ + mutex_unlock(&cgroup_mutex); +} + /** * task_css_set_check - obtain a task's css_set with extra access conditions * @task: the task to obtain css_set for @@ -446,7 +458,6 @@ static inline void cgroup_put(struct cgroup *cgrp) * as locks used during the cgroup_subsys::attach() methods. */ #ifdef CONFIG_PROVE_RCU -extern struct mutex cgroup_mutex; extern spinlock_t css_set_lock; #define task_css_set_check(task, __c) \ rcu_dereference_check((task)->cgroups, \ @@ -708,6 +719,8 @@ struct cgroup; static inline u64 cgroup_id(const struct cgroup *cgrp) { return 1; } static inline void css_get(struct cgroup_subsys_state *css) {} static inline void css_put(struct cgroup_subsys_state *css) {} +static inline void cgroup_lock(void) {} +static inline void cgroup_unlock(void) {} static inline int cgroup_attach_task_all(struct task_struct *from, struct task_struct *t) { return 0; } static inline int cgroupstats_build(struct cgroupstats *stats, diff --git a/include/linux/maple_tree.h b/include/linux/maple_tree.h new file mode 100644 index 000000000000..e594db58a0f1 --- /dev/null +++ b/include/linux/maple_tree.h @@ -0,0 +1,692 @@ +/* SPDX-License-Identifier: GPL-2.0+ */ +#ifndef _LINUX_MAPLE_TREE_H +#define _LINUX_MAPLE_TREE_H +/* + * Maple Tree - An RCU-safe adaptive tree for storing ranges + * Copyright (c) 2018-2022 Oracle + * Authors: Liam R. Howlett + * Matthew Wilcox + */ + +#include +#include +#include +/* #define CONFIG_MAPLE_RCU_DISABLED */ +/* #define CONFIG_DEBUG_MAPLE_TREE_VERBOSE */ + +/* + * Allocated nodes are mutable until they have been inserted into the tree, + * at which time they cannot change their type until they have been removed + * from the tree and an RCU grace period has passed. + * + * Removed nodes have their ->parent set to point to themselves. RCU readers + * check ->parent before relying on the value that they loaded from the + * slots array. This lets us reuse the slots array for the RCU head. + * + * Nodes in the tree point to their parent unless bit 0 is set. + */ +#if defined(CONFIG_64BIT) || defined(BUILD_VDSO32_64) +/* 64bit sizes */ +#define MAPLE_NODE_SLOTS 31 /* 256 bytes including ->parent */ +#define MAPLE_RANGE64_SLOTS 16 /* 256 bytes */ +#define MAPLE_ARANGE64_SLOTS 10 /* 240 bytes */ +#define MAPLE_ARANGE64_META_MAX 15 /* Out of range for metadata */ +#define MAPLE_ALLOC_SLOTS (MAPLE_NODE_SLOTS - 1) +#else +/* 32bit sizes */ +#define MAPLE_NODE_SLOTS 63 /* 256 bytes including ->parent */ +#define MAPLE_RANGE64_SLOTS 32 /* 256 bytes */ +#define MAPLE_ARANGE64_SLOTS 21 /* 240 bytes */ +#define MAPLE_ARANGE64_META_MAX 31 /* Out of range for metadata */ +#define MAPLE_ALLOC_SLOTS (MAPLE_NODE_SLOTS - 2) +#endif /* defined(CONFIG_64BIT) || defined(BUILD_VDSO32_64) */ + +#define MAPLE_NODE_MASK 255UL + +/* + * The node->parent of the root node has bit 0 set and the rest of the pointer + * is a pointer to the tree itself. No more bits are available in this pointer + * (on m68k, the data structure may only be 2-byte aligned). + * + * Internal non-root nodes can only have maple_range_* nodes as parents. The + * parent pointer is 256B aligned like all other tree nodes. When storing a 32 + * or 64 bit values, the offset can fit into 4 bits. The 16 bit values need an + * extra bit to store the offset. This extra bit comes from a reuse of the last + * bit in the node type. This is possible by using bit 1 to indicate if bit 2 + * is part of the type or the slot. + * + * Once the type is decided, the decision of an allocation range type or a range + * type is done by examining the immutable tree flag for the MAPLE_ALLOC_RANGE + * flag. + * + * Node types: + * 0x??1 = Root + * 0x?00 = 16 bit nodes + * 0x010 = 32 bit nodes + * 0x110 = 64 bit nodes + * + * Slot size and location in the parent pointer: + * type : slot location + * 0x??1 : Root + * 0x?00 : 16 bit values, type in 0-1, slot in 2-6 + * 0x010 : 32 bit values, type in 0-2, slot in 3-6 + * 0x110 : 64 bit values, type in 0-2, slot in 3-6 + */ + +/* + * This metadata is used to optimize the gap updating code and in reverse + * searching for gaps or any other code that needs to find the end of the data. + */ +struct maple_metadata { + unsigned char end; + unsigned char gap; +}; + +/* + * Leaf nodes do not store pointers to nodes, they store user data. Users may + * store almost any bit pattern. As noted above, the optimisation of storing an + * entry at 0 in the root pointer cannot be done for data which have the bottom + * two bits set to '10'. We also reserve values with the bottom two bits set to + * '10' which are below 4096 (ie 2, 6, 10 .. 4094) for internal use. Some APIs + * return errnos as a negative errno shifted right by two bits and the bottom + * two bits set to '10', and while choosing to store these values in the array + * is not an error, it may lead to confusion if you're testing for an error with + * mas_is_err(). + * + * Non-leaf nodes store the type of the node pointed to (enum maple_type in bits + * 3-6), bit 2 is reserved. That leaves bits 0-1 unused for now. + * + * In regular B-Tree terms, pivots are called keys. The term pivot is used to + * indicate that the tree is specifying ranges, Pivots may appear in the + * subtree with an entry attached to the value whereas keys are unique to a + * specific position of a B-tree. Pivot values are inclusive of the slot with + * the same index. + */ + +struct maple_range_64 { + struct maple_pnode *parent; + unsigned long pivot[MAPLE_RANGE64_SLOTS - 1]; + union { + void __rcu *slot[MAPLE_RANGE64_SLOTS]; + struct { + void __rcu *pad[MAPLE_RANGE64_SLOTS - 1]; + struct maple_metadata meta; + }; + }; +}; + +/* + * At tree creation time, the user can specify that they're willing to trade off + * storing fewer entries in a tree in return for storing more information in + * each node. + * + * The maple tree supports recording the largest range of NULL entries available + * in this node, also called gaps. This optimises the tree for allocating a + * range. + */ +struct maple_arange_64 { + struct maple_pnode *parent; + unsigned long pivot[MAPLE_ARANGE64_SLOTS - 1]; + void __rcu *slot[MAPLE_ARANGE64_SLOTS]; + unsigned long gap[MAPLE_ARANGE64_SLOTS]; + struct maple_metadata meta; +}; + +struct maple_alloc { + unsigned long total; + unsigned char node_count; + unsigned int request_count; + struct maple_alloc *slot[MAPLE_ALLOC_SLOTS]; +}; + +struct maple_topiary { + struct maple_pnode *parent; + struct maple_enode *next; /* Overlaps the pivot */ +}; + +enum maple_type { + maple_dense, + maple_leaf_64, + maple_range_64, + maple_arange_64, +}; + + +/** + * DOC: Maple tree flags + * + * * MT_FLAGS_ALLOC_RANGE - Track gaps in this tree + * * MT_FLAGS_USE_RCU - Operate in RCU mode + * * MT_FLAGS_HEIGHT_OFFSET - The position of the tree height in the flags + * * MT_FLAGS_HEIGHT_MASK - The mask for the maple tree height value + * * MT_FLAGS_LOCK_MASK - How the mt_lock is used + * * MT_FLAGS_LOCK_IRQ - Acquired irq-safe + * * MT_FLAGS_LOCK_BH - Acquired bh-safe + * * MT_FLAGS_LOCK_EXTERN - mt_lock is not used + * + * MAPLE_HEIGHT_MAX The largest height that can be stored + */ +#define MT_FLAGS_ALLOC_RANGE 0x01 +#define MT_FLAGS_USE_RCU 0x02 +#define MT_FLAGS_HEIGHT_OFFSET 0x02 +#define MT_FLAGS_HEIGHT_MASK 0x7C +#define MT_FLAGS_LOCK_MASK 0x300 +#define MT_FLAGS_LOCK_IRQ 0x100 +#define MT_FLAGS_LOCK_BH 0x200 +#define MT_FLAGS_LOCK_EXTERN 0x300 + +#define MAPLE_HEIGHT_MAX 31 + + +#define MAPLE_NODE_TYPE_MASK 0x0F +#define MAPLE_NODE_TYPE_SHIFT 0x03 + +#define MAPLE_RESERVED_RANGE 4096 + +#ifdef CONFIG_LOCKDEP +typedef struct lockdep_map *lockdep_map_p; +#define mt_lock_is_held(mt) lock_is_held(mt->ma_external_lock) +#define mt_set_external_lock(mt, lock) \ + (mt)->ma_external_lock = &(lock)->dep_map +#else +typedef struct { /* nothing */ } lockdep_map_p; +#define mt_lock_is_held(mt) 1 +#define mt_set_external_lock(mt, lock) do { } while (0) +#endif + +/* + * If the tree contains a single entry at index 0, it is usually stored in + * tree->ma_root. To optimise for the page cache, an entry which ends in '00', + * '01' or '11' is stored in the root, but an entry which ends in '10' will be + * stored in a node. Bits 3-6 are used to store enum maple_type. + * + * The flags are used both to store some immutable information about this tree + * (set at tree creation time) and dynamic information set under the spinlock. + * + * Another use of flags are to indicate global states of the tree. This is the + * case with the MAPLE_USE_RCU flag, which indicates the tree is currently in + * RCU mode. This mode was added to allow the tree to reuse nodes instead of + * re-allocating and RCU freeing nodes when there is a single user. + */ +struct maple_tree { + union { + spinlock_t ma_lock; + lockdep_map_p ma_external_lock; + }; + void __rcu *ma_root; + unsigned int ma_flags; +}; + +/** + * MTREE_INIT() - Initialize a maple tree + * @name: The maple tree name + * @__flags: The maple tree flags + * + */ +#define MTREE_INIT(name, __flags) { \ + .ma_lock = __SPIN_LOCK_UNLOCKED((name).ma_lock), \ + .ma_flags = __flags, \ + .ma_root = NULL, \ +} + +/** + * MTREE_INIT_EXT() - Initialize a maple tree with an external lock. + * @name: The tree name + * @__flags: The maple tree flags + * @__lock: The external lock + */ +#ifdef CONFIG_LOCKDEP +#define MTREE_INIT_EXT(name, __flags, __lock) { \ + .ma_external_lock = &(__lock).dep_map, \ + .ma_flags = (__flags), \ + .ma_root = NULL, \ +} +#else +#define MTREE_INIT_EXT(name, __flags, __lock) MTREE_INIT(name, __flags) +#endif + +#define DEFINE_MTREE(name) \ + struct maple_tree name = MTREE_INIT(name, 0) + +#define mtree_lock(mt) spin_lock((&(mt)->ma_lock)) +#define mtree_unlock(mt) spin_unlock((&(mt)->ma_lock)) + +/* + * The Maple Tree squeezes various bits in at various points which aren't + * necessarily obvious. Usually, this is done by observing that pointers are + * N-byte aligned and thus the bottom log_2(N) bits are available for use. We + * don't use the high bits of pointers to store additional information because + * we don't know what bits are unused on any given architecture. + * + * Nodes are 256 bytes in size and are also aligned to 256 bytes, giving us 8 + * low bits for our own purposes. Nodes are currently of 4 types: + * 1. Single pointer (Range is 0-0) + * 2. Non-leaf Allocation Range nodes + * 3. Non-leaf Range nodes + * 4. Leaf Range nodes All nodes consist of a number of node slots, + * pivots, and a parent pointer. + */ + +struct maple_node { + union { + struct { + struct maple_pnode *parent; + void __rcu *slot[MAPLE_NODE_SLOTS]; + }; + struct { + void *pad; + struct rcu_head rcu; + struct maple_enode *piv_parent; + unsigned char parent_slot; + enum maple_type type; + unsigned char slot_len; + unsigned int ma_flags; + }; + struct maple_range_64 mr64; + struct maple_arange_64 ma64; + struct maple_alloc alloc; + }; +}; + +/* + * More complicated stores can cause two nodes to become one or three and + * potentially alter the height of the tree. Either half of the tree may need + * to be rebalanced against the other. The ma_topiary struct is used to track + * which nodes have been 'cut' from the tree so that the change can be done + * safely at a later date. This is done to support RCU. + */ +struct ma_topiary { + struct maple_enode *head; + struct maple_enode *tail; + struct maple_tree *mtree; +}; + +void *mtree_load(struct maple_tree *mt, unsigned long index); + +int mtree_insert(struct maple_tree *mt, unsigned long index, + void *entry, gfp_t gfp); +int mtree_insert_range(struct maple_tree *mt, unsigned long first, + unsigned long last, void *entry, gfp_t gfp); +int mtree_alloc_range(struct maple_tree *mt, unsigned long *startp, + void *entry, unsigned long size, unsigned long min, + unsigned long max, gfp_t gfp); +int mtree_alloc_rrange(struct maple_tree *mt, unsigned long *startp, + void *entry, unsigned long size, unsigned long min, + unsigned long max, gfp_t gfp); + +int mtree_store_range(struct maple_tree *mt, unsigned long first, + unsigned long last, void *entry, gfp_t gfp); +int mtree_store(struct maple_tree *mt, unsigned long index, + void *entry, gfp_t gfp); +void *mtree_erase(struct maple_tree *mt, unsigned long index); + +void mtree_destroy(struct maple_tree *mt); +void __mt_destroy(struct maple_tree *mt); + +/** + * mtree_empty() - Determine if a tree has any present entries. + * @mt: Maple Tree. + * + * Context: Any context. + * Return: %true if the tree contains only NULL pointers. + */ +static inline bool mtree_empty(const struct maple_tree *mt) +{ + return mt->ma_root == NULL; +} + +/* Advanced API */ + +/* + * The maple state is defined in the struct ma_state and is used to keep track + * of information during operations, and even between operations when using the + * advanced API. + * + * If state->node has bit 0 set then it references a tree location which is not + * a node (eg the root). If bit 1 is set, the rest of the bits are a negative + * errno. Bit 2 (the 'unallocated slots' bit) is clear. Bits 3-6 indicate the + * node type. + * + * state->alloc either has a request number of nodes or an allocated node. If + * stat->alloc has a requested number of nodes, the first bit will be set (0x1) + * and the remaining bits are the value. If state->alloc is a node, then the + * node will be of type maple_alloc. maple_alloc has MAPLE_NODE_SLOTS - 1 for + * storing more allocated nodes, a total number of nodes allocated, and the + * node_count in this node. node_count is the number of allocated nodes in this + * node. The scaling beyond MAPLE_NODE_SLOTS - 1 is handled by storing further + * nodes into state->alloc->slot[0]'s node. Nodes are taken from state->alloc + * by removing a node from the state->alloc node until state->alloc->node_count + * is 1, when state->alloc is returned and the state->alloc->slot[0] is promoted + * to state->alloc. Nodes are pushed onto state->alloc by putting the current + * state->alloc into the pushed node's slot[0]. + * + * The state also contains the implied min/max of the state->node, the depth of + * this search, and the offset. The implied min/max are either from the parent + * node or are 0-oo for the root node. The depth is incremented or decremented + * every time a node is walked down or up. The offset is the slot/pivot of + * interest in the node - either for reading or writing. + * + * When returning a value the maple state index and last respectively contain + * the start and end of the range for the entry. Ranges are inclusive in the + * Maple Tree. + */ +struct ma_state { + struct maple_tree *tree; /* The tree we're operating in */ + unsigned long index; /* The index we're operating on - range start */ + unsigned long last; /* The last index we're operating on - range end */ + struct maple_enode *node; /* The node containing this entry */ + unsigned long min; /* The minimum index of this node - implied pivot min */ + unsigned long max; /* The maximum index of this node - implied pivot max */ + struct maple_alloc *alloc; /* Allocated nodes for this operation */ + unsigned char depth; /* depth of tree descent during write */ + unsigned char offset; + unsigned char mas_flags; +}; + +struct ma_wr_state { + struct ma_state *mas; + struct maple_node *node; /* Decoded mas->node */ + unsigned long r_min; /* range min */ + unsigned long r_max; /* range max */ + enum maple_type type; /* mas->node type */ + unsigned char offset_end; /* The offset where the write ends */ + unsigned char node_end; /* mas->node end */ + unsigned long *pivots; /* mas->node->pivots pointer */ + unsigned long end_piv; /* The pivot at the offset end */ + void __rcu **slots; /* mas->node->slots pointer */ + void *entry; /* The entry to write */ + void *content; /* The existing entry that is being overwritten */ +}; + +#define mas_lock(mas) spin_lock(&((mas)->tree->ma_lock)) +#define mas_unlock(mas) spin_unlock(&((mas)->tree->ma_lock)) + + +/* + * Special values for ma_state.node. + * MAS_START means we have not searched the tree. + * MAS_ROOT means we have searched the tree and the entry we found lives in + * the root of the tree (ie it has index 0, length 1 and is the only entry in + * the tree). + * MAS_NONE means we have searched the tree and there is no node in the + * tree for this entry. For example, we searched for index 1 in an empty + * tree. Or we have a tree which points to a full leaf node and we + * searched for an entry which is larger than can be contained in that + * leaf node. + * MA_ERROR represents an errno. After dropping the lock and attempting + * to resolve the error, the walk would have to be restarted from the + * top of the tree as the tree may have been modified. + */ +#define MAS_START ((struct maple_enode *)1UL) +#define MAS_ROOT ((struct maple_enode *)5UL) +#define MAS_NONE ((struct maple_enode *)9UL) +#define MAS_PAUSE ((struct maple_enode *)17UL) +#define MA_ERROR(err) \ + ((struct maple_enode *)(((unsigned long)err << 2) | 2UL)) + +#define MA_STATE(name, mt, first, end) \ + struct ma_state name = { \ + .tree = mt, \ + .index = first, \ + .last = end, \ + .node = MAS_START, \ + .min = 0, \ + .max = ULONG_MAX, \ + .alloc = NULL, \ + } + +#define MA_WR_STATE(name, ma_state, wr_entry) \ + struct ma_wr_state name = { \ + .mas = ma_state, \ + .content = NULL, \ + .entry = wr_entry, \ + } + +#define MA_TOPIARY(name, tree) \ + struct ma_topiary name = { \ + .head = NULL, \ + .tail = NULL, \ + .mtree = tree, \ + } + +void *mas_walk(struct ma_state *mas); +void *mas_store(struct ma_state *mas, void *entry); +void *mas_erase(struct ma_state *mas); +int mas_store_gfp(struct ma_state *mas, void *entry, gfp_t gfp); +void mas_store_prealloc(struct ma_state *mas, void *entry); +void *mas_find(struct ma_state *mas, unsigned long max); +void *mas_find_rev(struct ma_state *mas, unsigned long min); +int mas_preallocate(struct ma_state *mas, void *entry, gfp_t gfp); +bool mas_is_err(struct ma_state *mas); + +bool mas_nomem(struct ma_state *mas, gfp_t gfp); +void mas_pause(struct ma_state *mas); +void maple_tree_init(void); +void mas_destroy(struct ma_state *mas); +int mas_expected_entries(struct ma_state *mas, unsigned long nr_entries); + +void *mas_prev(struct ma_state *mas, unsigned long min); +void *mas_next(struct ma_state *mas, unsigned long max); + +int mas_empty_area(struct ma_state *mas, unsigned long min, unsigned long max, + unsigned long size); + +/* Checks if a mas has not found anything */ +static inline bool mas_is_none(struct ma_state *mas) +{ + return mas->node == MAS_NONE; +} + +/* Checks if a mas has been paused */ +static inline bool mas_is_paused(struct ma_state *mas) +{ + return mas->node == MAS_PAUSE; +} + +void mas_dup_tree(struct ma_state *oldmas, struct ma_state *mas); +void mas_dup_store(struct ma_state *mas, void *entry); + +/* + * This finds an empty area from the highest address to the lowest. + * AKA "Topdown" version, + */ +int mas_empty_area_rev(struct ma_state *mas, unsigned long min, + unsigned long max, unsigned long size); +/** + * mas_reset() - Reset a Maple Tree operation state. + * @mas: Maple Tree operation state. + * + * Resets the error or walk state of the @mas so future walks of the + * array will start from the root. Use this if you have dropped the + * lock and want to reuse the ma_state. + * + * Context: Any context. + */ +static inline void mas_reset(struct ma_state *mas) +{ + mas->node = MAS_START; +} + +/** + * mas_for_each() - Iterate over a range of the maple tree. + * @__mas: Maple Tree operation state (maple_state) + * @__entry: Entry retrieved from the tree + * @__max: maximum index to retrieve from the tree + * + * When returned, mas->index and mas->last will hold the entire range for the + * entry. + * + * Note: may return the zero entry. + * + */ +#define mas_for_each(__mas, __entry, __max) \ + while (((__entry) = mas_find((__mas), (__max))) != NULL) + + +/** + * mas_set_range() - Set up Maple Tree operation state for a different index. + * @mas: Maple Tree operation state. + * @start: New start of range in the Maple Tree. + * @last: New end of range in the Maple Tree. + * + * Move the operation state to refer to a different range. This will + * have the effect of starting a walk from the top; see mas_next() + * to move to an adjacent index. + */ +static inline +void mas_set_range(struct ma_state *mas, unsigned long start, unsigned long last) +{ + mas->index = start; + mas->last = last; + mas->node = MAS_START; +} + +/** + * mas_set() - Set up Maple Tree operation state for a different index. + * @mas: Maple Tree operation state. + * @index: New index into the Maple Tree. + * + * Move the operation state to refer to a different index. This will + * have the effect of starting a walk from the top; see mas_next() + * to move to an adjacent index. + */ +static inline void mas_set(struct ma_state *mas, unsigned long index) +{ + + mas_set_range(mas, index, index); +} + +static inline bool mt_external_lock(const struct maple_tree *mt) +{ + return (mt->ma_flags & MT_FLAGS_LOCK_MASK) == MT_FLAGS_LOCK_EXTERN; +} + +/** + * mt_init_flags() - Initialise an empty maple tree with flags. + * @mt: Maple Tree + * @flags: maple tree flags. + * + * If you need to initialise a Maple Tree with special flags (eg, an + * allocation tree), use this function. + * + * Context: Any context. + */ +static inline void mt_init_flags(struct maple_tree *mt, unsigned int flags) +{ + mt->ma_flags = flags; + if (!mt_external_lock(mt)) + spin_lock_init(&mt->ma_lock); + rcu_assign_pointer(mt->ma_root, NULL); +} + +/** + * mt_init() - Initialise an empty maple tree. + * @mt: Maple Tree + * + * An empty Maple Tree. + * + * Context: Any context. + */ +static inline void mt_init(struct maple_tree *mt) +{ + mt_init_flags(mt, 0); +} + +static inline bool mt_in_rcu(struct maple_tree *mt) +{ +#ifdef CONFIG_MAPLE_RCU_DISABLED + return false; +#endif + return mt->ma_flags & MT_FLAGS_USE_RCU; +} + +/** + * mt_clear_in_rcu() - Switch the tree to non-RCU mode. + * @mt: The Maple Tree + */ +static inline void mt_clear_in_rcu(struct maple_tree *mt) +{ + if (!mt_in_rcu(mt)) + return; + + if (mt_external_lock(mt)) { + BUG_ON(!mt_lock_is_held(mt)); + mt->ma_flags &= ~MT_FLAGS_USE_RCU; + } else { + mtree_lock(mt); + mt->ma_flags &= ~MT_FLAGS_USE_RCU; + mtree_unlock(mt); + } +} + +/** + * mt_set_in_rcu() - Switch the tree to RCU safe mode. + * @mt: The Maple Tree + */ +static inline void mt_set_in_rcu(struct maple_tree *mt) +{ + if (mt_in_rcu(mt)) + return; + + if (mt_external_lock(mt)) { + BUG_ON(!mt_lock_is_held(mt)); + mt->ma_flags |= MT_FLAGS_USE_RCU; + } else { + mtree_lock(mt); + mt->ma_flags |= MT_FLAGS_USE_RCU; + mtree_unlock(mt); + } +} + +static inline unsigned int mt_height(const struct maple_tree *mt) + +{ + return (mt->ma_flags & MT_FLAGS_HEIGHT_MASK) >> MT_FLAGS_HEIGHT_OFFSET; +} + +void *mt_find(struct maple_tree *mt, unsigned long *index, unsigned long max); +void *mt_find_after(struct maple_tree *mt, unsigned long *index, + unsigned long max); +void *mt_prev(struct maple_tree *mt, unsigned long index, unsigned long min); +void *mt_next(struct maple_tree *mt, unsigned long index, unsigned long max); + +/** + * mt_for_each - Iterate over each entry starting at index until max. + * @__tree: The Maple Tree + * @__entry: The current entry + * @__index: The index to update to track the location in the tree + * @__max: The maximum limit for @index + * + * Note: Will not return the zero entry. + */ +#define mt_for_each(__tree, __entry, __index, __max) \ + for (__entry = mt_find(__tree, &(__index), __max); \ + __entry; __entry = mt_find_after(__tree, &(__index), __max)) + + +#ifdef CONFIG_DEBUG_MAPLE_TREE +extern atomic_t maple_tree_tests_run; +extern atomic_t maple_tree_tests_passed; + +void mt_dump(const struct maple_tree *mt); +void mt_validate(struct maple_tree *mt); +void mt_cache_shrink(void); +#define MT_BUG_ON(__tree, __x) do { \ + atomic_inc(&maple_tree_tests_run); \ + if (__x) { \ + pr_info("BUG at %s:%d (%u)\n", \ + __func__, __LINE__, __x); \ + mt_dump(__tree); \ + pr_info("Pass: %u Run:%u\n", \ + atomic_read(&maple_tree_tests_passed), \ + atomic_read(&maple_tree_tests_run)); \ + dump_stack(); \ + } else { \ + atomic_inc(&maple_tree_tests_passed); \ + } \ +} while (0) +#else +#define MT_BUG_ON(__tree, __x) BUG_ON(__x) +#endif /* CONFIG_DEBUG_MAPLE_TREE */ + +#endif /*_LINUX_MAPLE_TREE_H */ diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h index 567f12323f55..877cbcbc6ed9 100644 --- a/include/linux/memcontrol.h +++ b/include/linux/memcontrol.h @@ -350,6 +350,11 @@ struct mem_cgroup { struct deferred_split deferred_split_queue; #endif +#ifdef CONFIG_LRU_GEN + /* per-memcg mm_struct list */ + struct lru_gen_mm_list mm_list; +#endif + struct mem_cgroup_per_node *nodeinfo[]; }; @@ -444,6 +449,7 @@ static inline struct obj_cgroup *__folio_objcg(struct folio *folio) * - LRU isolation * - lock_page_memcg() * - exclusive reference + * - mem_cgroup_trylock_pages() * * For a kmem folio a caller should hold an rcu read lock to protect memcg * associated with a kmem folio from being released. @@ -505,6 +511,7 @@ static inline struct mem_cgroup *folio_memcg_rcu(struct folio *folio) * - LRU isolation * - lock_page_memcg() * - exclusive reference + * - mem_cgroup_trylock_pages() * * For a kmem page a caller should hold an rcu read lock to protect memcg * associated with a kmem page from being released. @@ -959,6 +966,23 @@ void unlock_page_memcg(struct page *page); void __mod_memcg_state(struct mem_cgroup *memcg, int idx, int val); +/* try to stablize folio_memcg() for all the pages in a memcg */ +static inline bool mem_cgroup_trylock_pages(struct mem_cgroup *memcg) +{ + rcu_read_lock(); + + if (mem_cgroup_disabled() || !atomic_read(&memcg->moving_account)) + return true; + + rcu_read_unlock(); + return false; +} + +static inline void mem_cgroup_unlock_pages(void) +{ + rcu_read_unlock(); +} + /* idx can be of type enum memcg_stat_item or node_stat_item */ static inline void mod_memcg_state(struct mem_cgroup *memcg, int idx, int val) @@ -1433,6 +1457,18 @@ static inline void folio_memcg_unlock(struct folio *folio) { } +static inline bool mem_cgroup_trylock_pages(struct mem_cgroup *memcg) +{ + /* to match folio_memcg_rcu() */ + rcu_read_lock(); + return true; +} + +static inline void mem_cgroup_unlock_pages(void) +{ + rcu_read_unlock(); +} + static inline void mem_cgroup_handle_over_high(void) { } diff --git a/include/linux/mm.h b/include/linux/mm.h index 21f8b27bd9fd..9ac0e02e2238 100644 --- a/include/linux/mm.h +++ b/include/linux/mm.h @@ -661,6 +661,38 @@ static inline bool vma_is_accessible(struct vm_area_struct *vma) return vma->vm_flags & VM_ACCESS_FLAGS; } +static inline +struct vm_area_struct *vma_find(struct vma_iterator *vmi, unsigned long max) +{ + return mas_find(&vmi->mas, max); +} + +static inline struct vm_area_struct *vma_next(struct vma_iterator *vmi) +{ + /* + * Uses vma_find() to get the first VMA when the iterator starts. + * Calling mas_next() could skip the first entry. + */ + return vma_find(vmi, ULONG_MAX); +} + +static inline struct vm_area_struct *vma_prev(struct vma_iterator *vmi) +{ + return mas_prev(&vmi->mas, 0); +} + +static inline unsigned long vma_iter_addr(struct vma_iterator *vmi) +{ + return vmi->mas.index; +} + +#define for_each_vma(__vmi, __vma) \ + while (((__vma) = vma_next(&(__vmi))) != NULL) + +/* The MM code likes to work with exclusive end addresses */ +#define for_each_vma_range(__vmi, __vma, __end) \ + while (((__vma) = vma_find(&(__vmi), (__end) - 1)) != NULL) + #ifdef CONFIG_SHMEM /* * The vma_is_shmem is not inline because it is used only by slow @@ -1465,6 +1497,11 @@ static inline unsigned long folio_pfn(struct folio *folio) return page_to_pfn(&folio->page); } +static inline struct folio *pfn_folio(unsigned long pfn) +{ + return page_folio(pfn_to_page(pfn)); +} + static inline atomic_t *folio_pincount_ptr(struct folio *folio) { return &folio_page(folio, 1)->compound_pincount; @@ -1795,8 +1832,9 @@ void zap_vma_ptes(struct vm_area_struct *vma, unsigned long address, unsigned long size); void zap_page_range(struct vm_area_struct *vma, unsigned long address, unsigned long size); -void unmap_vmas(struct mmu_gather *tlb, struct vm_area_struct *start_vma, - unsigned long start, unsigned long end); +void unmap_vmas(struct mmu_gather *tlb, struct maple_tree *mt, + struct vm_area_struct *start_vma, unsigned long start, + unsigned long end); struct mmu_notifier_range; @@ -2593,14 +2631,15 @@ extern int __split_vma(struct mm_struct *, struct vm_area_struct *, extern int split_vma(struct mm_struct *, struct vm_area_struct *, unsigned long addr, int new_below); extern int insert_vm_struct(struct mm_struct *, struct vm_area_struct *); -extern void __vma_link_rb(struct mm_struct *, struct vm_area_struct *, - struct rb_node **, struct rb_node *); extern void unlink_file_vma(struct vm_area_struct *); extern struct vm_area_struct *copy_vma(struct vm_area_struct **, unsigned long addr, unsigned long len, pgoff_t pgoff, bool *need_rmap_locks); extern void exit_mmap(struct mm_struct *); +void vma_mas_store(struct vm_area_struct *vma, struct ma_state *mas); +void vma_mas_remove(struct vm_area_struct *vma, struct ma_state *mas); + static inline int check_data_rlimit(unsigned long rlim, unsigned long new, unsigned long start, @@ -2648,8 +2687,9 @@ extern unsigned long mmap_region(struct file *file, unsigned long addr, extern unsigned long do_mmap(struct file *file, unsigned long addr, unsigned long len, unsigned long prot, unsigned long flags, unsigned long pgoff, unsigned long *populate, struct list_head *uf); -extern int __do_munmap(struct mm_struct *, unsigned long, size_t, - struct list_head *uf, bool downgrade); +extern int do_mas_munmap(struct ma_state *mas, struct mm_struct *mm, + unsigned long start, size_t len, struct list_head *uf, + bool downgrade); extern int do_munmap(struct mm_struct *, unsigned long, size_t, struct list_head *uf); extern int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in, int behavior); @@ -2716,26 +2756,12 @@ extern struct vm_area_struct * find_vma(struct mm_struct * mm, unsigned long add extern struct vm_area_struct * find_vma_prev(struct mm_struct * mm, unsigned long addr, struct vm_area_struct **pprev); -/** - * find_vma_intersection() - Look up the first VMA which intersects the interval - * @mm: The process address space. - * @start_addr: The inclusive start user address. - * @end_addr: The exclusive end user address. - * - * Returns: The first VMA within the provided range, %NULL otherwise. Assumes - * start_addr < end_addr. +/* + * Look up the first VMA which intersects the interval [start_addr, end_addr) + * NULL if none. Assume start_addr < end_addr. */ -static inline struct vm_area_struct *find_vma_intersection(struct mm_struct *mm, - unsigned long start_addr, - unsigned long end_addr) -{ - struct vm_area_struct *vma = find_vma(mm, start_addr); - - if (vma && end_addr <= vma->vm_start) - vma = NULL; - return vma; -} + unsigned long start_addr, unsigned long end_addr); /** * vma_lookup() - Find a VMA at a specific address @@ -2747,12 +2773,7 @@ struct vm_area_struct *find_vma_intersection(struct mm_struct *mm, static inline struct vm_area_struct *vma_lookup(struct mm_struct *mm, unsigned long addr) { - struct vm_area_struct *vma = find_vma(mm, addr); - - if (vma && addr < vma->vm_start) - vma = NULL; - - return vma; + return mtree_load(&mm->mm_mt, addr); } static inline unsigned long vm_start_gap(struct vm_area_struct *vma) @@ -2788,7 +2809,7 @@ static inline unsigned long vma_pages(struct vm_area_struct *vma) static inline struct vm_area_struct *find_exact_vma(struct mm_struct *mm, unsigned long vm_start, unsigned long vm_end) { - struct vm_area_struct *vma = find_vma(mm, vm_start); + struct vm_area_struct *vma = vma_lookup(mm, vm_start); if (vma && (vma->vm_start != vm_start || vma->vm_end != vm_end)) vma = NULL; diff --git a/include/linux/mm_inline.h b/include/linux/mm_inline.h index 7b25b53c474a..4949eda9a9a2 100644 --- a/include/linux/mm_inline.h +++ b/include/linux/mm_inline.h @@ -34,15 +34,25 @@ static inline int page_is_file_lru(struct page *page) return folio_is_file_lru(page_folio(page)); } -static __always_inline void update_lru_size(struct lruvec *lruvec, +static __always_inline void __update_lru_size(struct lruvec *lruvec, enum lru_list lru, enum zone_type zid, long nr_pages) { struct pglist_data *pgdat = lruvec_pgdat(lruvec); + lockdep_assert_held(&lruvec->lru_lock); + WARN_ON_ONCE(nr_pages != (int)nr_pages); + __mod_lruvec_state(lruvec, NR_LRU_BASE + lru, nr_pages); __mod_zone_page_state(&pgdat->node_zones[zid], NR_ZONE_LRU_BASE + lru, nr_pages); +} + +static __always_inline void update_lru_size(struct lruvec *lruvec, + enum lru_list lru, enum zone_type zid, + long nr_pages) +{ + __update_lru_size(lruvec, lru, zid, nr_pages); #ifdef CONFIG_MEMCG mem_cgroup_update_lru_size(lruvec, lru, zid, nr_pages); #endif @@ -94,11 +104,224 @@ static __always_inline enum lru_list folio_lru_list(struct folio *folio) return lru; } +#ifdef CONFIG_LRU_GEN + +#ifdef CONFIG_LRU_GEN_ENABLED +static inline bool lru_gen_enabled(void) +{ + DECLARE_STATIC_KEY_TRUE(lru_gen_caps[NR_LRU_GEN_CAPS]); + + return static_branch_likely(&lru_gen_caps[LRU_GEN_CORE]); +} +#else +static inline bool lru_gen_enabled(void) +{ + DECLARE_STATIC_KEY_FALSE(lru_gen_caps[NR_LRU_GEN_CAPS]); + + return static_branch_unlikely(&lru_gen_caps[LRU_GEN_CORE]); +} +#endif + +static inline bool lru_gen_in_fault(void) +{ + return current->in_lru_fault; +} + +static inline int lru_gen_from_seq(unsigned long seq) +{ + return seq % MAX_NR_GENS; +} + +static inline int lru_hist_from_seq(unsigned long seq) +{ + return seq % NR_HIST_GENS; +} + +static inline int lru_tier_from_refs(int refs) +{ + VM_WARN_ON_ONCE(refs > BIT(LRU_REFS_WIDTH)); + + /* see the comment in folio_lru_refs() */ + return order_base_2(refs + 1); +} + +static inline int folio_lru_refs(struct folio *folio) +{ + unsigned long flags = READ_ONCE(folio->flags); + bool workingset = flags & BIT(PG_workingset); + + /* + * Return the number of accesses beyond PG_referenced, i.e., N-1 if the + * total number of accesses is N>1, since N=0,1 both map to the first + * tier. lru_tier_from_refs() will account for this off-by-one. Also see + * the comment on MAX_NR_TIERS. + */ + return ((flags & LRU_REFS_MASK) >> LRU_REFS_PGOFF) + workingset; +} + +static inline int folio_lru_gen(struct folio *folio) +{ + unsigned long flags = READ_ONCE(folio->flags); + + return ((flags & LRU_GEN_MASK) >> LRU_GEN_PGOFF) - 1; +} + +static inline bool lru_gen_is_active(struct lruvec *lruvec, int gen) +{ + unsigned long max_seq = lruvec->lrugen.max_seq; + + VM_WARN_ON_ONCE(gen >= MAX_NR_GENS); + + /* see the comment on MIN_NR_GENS */ + return gen == lru_gen_from_seq(max_seq) || gen == lru_gen_from_seq(max_seq - 1); +} + +static inline void lru_gen_update_size(struct lruvec *lruvec, struct folio *folio, + int old_gen, int new_gen) +{ + int type = folio_is_file_lru(folio); + int zone = folio_zonenum(folio); + int delta = folio_nr_pages(folio); + enum lru_list lru = type * LRU_INACTIVE_FILE; + struct lru_gen_struct *lrugen = &lruvec->lrugen; + + VM_WARN_ON_ONCE(old_gen != -1 && old_gen >= MAX_NR_GENS); + VM_WARN_ON_ONCE(new_gen != -1 && new_gen >= MAX_NR_GENS); + VM_WARN_ON_ONCE(old_gen == -1 && new_gen == -1); + + if (old_gen >= 0) + WRITE_ONCE(lrugen->nr_pages[old_gen][type][zone], + lrugen->nr_pages[old_gen][type][zone] - delta); + if (new_gen >= 0) + WRITE_ONCE(lrugen->nr_pages[new_gen][type][zone], + lrugen->nr_pages[new_gen][type][zone] + delta); + + /* addition */ + if (old_gen < 0) { + if (lru_gen_is_active(lruvec, new_gen)) + lru += LRU_ACTIVE; + __update_lru_size(lruvec, lru, zone, delta); + return; + } + + /* deletion */ + if (new_gen < 0) { + if (lru_gen_is_active(lruvec, old_gen)) + lru += LRU_ACTIVE; + __update_lru_size(lruvec, lru, zone, -delta); + return; + } + + /* promotion */ + if (!lru_gen_is_active(lruvec, old_gen) && lru_gen_is_active(lruvec, new_gen)) { + __update_lru_size(lruvec, lru, zone, -delta); + __update_lru_size(lruvec, lru + LRU_ACTIVE, zone, delta); + } + + /* demotion requires isolation, e.g., lru_deactivate_fn() */ + VM_WARN_ON_ONCE(lru_gen_is_active(lruvec, old_gen) && !lru_gen_is_active(lruvec, new_gen)); +} + +static inline bool lru_gen_add_folio(struct lruvec *lruvec, struct folio *folio, bool reclaiming) +{ + unsigned long seq; + unsigned long flags; + int gen = folio_lru_gen(folio); + int type = folio_is_file_lru(folio); + int zone = folio_zonenum(folio); + struct lru_gen_struct *lrugen = &lruvec->lrugen; + + VM_WARN_ON_ONCE_FOLIO(gen != -1, folio); + + if (folio_test_unevictable(folio) || !lrugen->enabled) + return false; + /* + * There are three common cases for this page: + * 1. If it's hot, e.g., freshly faulted in or previously hot and + * migrated, add it to the youngest generation. + * 2. If it's cold but can't be evicted immediately, i.e., an anon page + * not in swapcache or a dirty page pending writeback, add it to the + * second oldest generation. + * 3. Everything else (clean, cold) is added to the oldest generation. + */ + if (folio_test_active(folio)) + seq = lrugen->max_seq; + else if ((type == LRU_GEN_ANON && !folio_test_swapcache(folio)) || + (folio_test_reclaim(folio) && + (folio_test_dirty(folio) || folio_test_writeback(folio)))) + seq = lrugen->min_seq[type] + 1; + else + seq = lrugen->min_seq[type]; + + gen = lru_gen_from_seq(seq); + flags = (gen + 1UL) << LRU_GEN_PGOFF; + /* see the comment on MIN_NR_GENS about PG_active */ + set_mask_bits(&folio->flags, LRU_GEN_MASK | BIT(PG_active), flags); + + lru_gen_update_size(lruvec, folio, -1, gen); + /* for folio_rotate_reclaimable() */ + if (reclaiming) + list_add_tail(&folio->lru, &lrugen->lists[gen][type][zone]); + else + list_add(&folio->lru, &lrugen->lists[gen][type][zone]); + + return true; +} + +static inline bool lru_gen_del_folio(struct lruvec *lruvec, struct folio *folio, bool reclaiming) +{ + unsigned long flags; + int gen = folio_lru_gen(folio); + + if (gen < 0) + return false; + + VM_WARN_ON_ONCE_FOLIO(folio_test_active(folio), folio); + VM_WARN_ON_ONCE_FOLIO(folio_test_unevictable(folio), folio); + + /* for folio_migrate_flags() */ + flags = !reclaiming && lru_gen_is_active(lruvec, gen) ? BIT(PG_active) : 0; + flags = set_mask_bits(&folio->flags, LRU_GEN_MASK, flags); + gen = ((flags & LRU_GEN_MASK) >> LRU_GEN_PGOFF) - 1; + + lru_gen_update_size(lruvec, folio, gen, -1); + list_del(&folio->lru); + + return true; +} + +#else /* !CONFIG_LRU_GEN */ + +static inline bool lru_gen_enabled(void) +{ + return false; +} + +static inline bool lru_gen_in_fault(void) +{ + return false; +} + +static inline bool lru_gen_add_folio(struct lruvec *lruvec, struct folio *folio, bool reclaiming) +{ + return false; +} + +static inline bool lru_gen_del_folio(struct lruvec *lruvec, struct folio *folio, bool reclaiming) +{ + return false; +} + +#endif /* CONFIG_LRU_GEN */ + static __always_inline void lruvec_add_folio(struct lruvec *lruvec, struct folio *folio) { enum lru_list lru = folio_lru_list(folio); + if (lru_gen_add_folio(lruvec, folio, false)) + return; + update_lru_size(lruvec, lru, folio_zonenum(folio), folio_nr_pages(folio)); if (lru != LRU_UNEVICTABLE) @@ -116,6 +339,9 @@ void lruvec_add_folio_tail(struct lruvec *lruvec, struct folio *folio) { enum lru_list lru = folio_lru_list(folio); + if (lru_gen_add_folio(lruvec, folio, true)) + return; + update_lru_size(lruvec, lru, folio_zonenum(folio), folio_nr_pages(folio)); /* This is not expected to be used on LRU_UNEVICTABLE */ @@ -133,6 +359,9 @@ void lruvec_del_folio(struct lruvec *lruvec, struct folio *folio) { enum lru_list lru = folio_lru_list(folio); + if (lru_gen_del_folio(lruvec, folio, false)) + return; + if (lru != LRU_UNEVICTABLE) list_del(&folio->lru); update_lru_size(lruvec, lru, folio_zonenum(folio), diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h index cf97f3884fda..5e32211cb5a9 100644 --- a/include/linux/mm_types.h +++ b/include/linux/mm_types.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -407,21 +408,6 @@ struct vm_area_struct { unsigned long vm_end; /* The first byte after our end address within vm_mm. */ - /* linked list of VM areas per task, sorted by address */ - struct vm_area_struct *vm_next, *vm_prev; - - struct rb_node vm_rb; - - /* - * Largest free memory gap in bytes to the left of this VMA. - * Either between this VMA and vma->vm_prev, or between one of the - * VMAs below us in the VMA rbtree and its ->vm_prev. This helps - * get_unmapped_area find a free area of the right size. - */ - unsigned long rb_subtree_gap; - - /* Second cache line starts here. */ - struct mm_struct *vm_mm; /* The address space we belong to. */ /* @@ -485,9 +471,7 @@ struct vm_area_struct { struct kioctx_table; struct mm_struct { struct { - struct vm_area_struct *mmap; /* list of VMAs */ - struct rb_root mm_rb; - u64 vmacache_seqnum; /* per-thread vmacache */ + struct maple_tree mm_mt; #ifdef CONFIG_MMU unsigned long (*get_unmapped_area) (struct file *filp, unsigned long addr, unsigned long len, @@ -501,7 +485,6 @@ struct mm_struct { unsigned long mmap_compat_legacy_base; #endif unsigned long task_size; /* size of task vm space */ - unsigned long highest_vm_end; /* highest vma end address */ pgd_t * pgd; #ifdef CONFIG_MEMBARRIER @@ -672,6 +655,22 @@ struct mm_struct { */ unsigned long ksm_merging_pages; #endif +#ifdef CONFIG_LRU_GEN + struct { + /* this mm_struct is on lru_gen_mm_list */ + struct list_head list; + /* + * Set when switching to this mm_struct, as a hint of + * whether it has been used since the last time per-node + * page table walkers cleared the corresponding bits. + */ + unsigned long bitmap; +#ifdef CONFIG_MEMCG + /* points to the memcg of "owner" above */ + struct mem_cgroup *memcg; +#endif + } lru_gen; +#endif /* CONFIG_LRU_GEN */ } __randomize_layout; /* @@ -681,6 +680,7 @@ struct mm_struct { unsigned long cpu_bitmap[]; }; +#define MM_MT_FLAGS (MT_FLAGS_ALLOC_RANGE | MT_FLAGS_LOCK_EXTERN) extern struct mm_struct init_mm; /* Pointer magic because the dynamic array size confuses some compilers. */ @@ -698,6 +698,87 @@ static inline cpumask_t *mm_cpumask(struct mm_struct *mm) return (struct cpumask *)&mm->cpu_bitmap; } +#ifdef CONFIG_LRU_GEN + +struct lru_gen_mm_list { + /* mm_struct list for page table walkers */ + struct list_head fifo; + /* protects the list above */ + spinlock_t lock; +}; + +void lru_gen_add_mm(struct mm_struct *mm); +void lru_gen_del_mm(struct mm_struct *mm); +#ifdef CONFIG_MEMCG +void lru_gen_migrate_mm(struct mm_struct *mm); +#endif + +static inline void lru_gen_init_mm(struct mm_struct *mm) +{ + INIT_LIST_HEAD(&mm->lru_gen.list); + mm->lru_gen.bitmap = 0; +#ifdef CONFIG_MEMCG + mm->lru_gen.memcg = NULL; +#endif +} + +static inline void lru_gen_use_mm(struct mm_struct *mm) +{ + /* + * When the bitmap is set, page reclaim knows this mm_struct has been + * used since the last time it cleared the bitmap. So it might be worth + * walking the page tables of this mm_struct to clear the accessed bit. + */ + WRITE_ONCE(mm->lru_gen.bitmap, -1); +} + +#else /* !CONFIG_LRU_GEN */ + +static inline void lru_gen_add_mm(struct mm_struct *mm) +{ +} + +static inline void lru_gen_del_mm(struct mm_struct *mm) +{ +} + +#ifdef CONFIG_MEMCG +static inline void lru_gen_migrate_mm(struct mm_struct *mm) +{ +} +#endif + +static inline void lru_gen_init_mm(struct mm_struct *mm) +{ +} + +static inline void lru_gen_use_mm(struct mm_struct *mm) +{ +} + +#endif /* CONFIG_LRU_GEN */ + +struct vma_iterator { + struct ma_state mas; +}; + +#define VMA_ITERATOR(name, __mm, __addr) \ + struct vma_iterator name = { \ + .mas = { \ + .tree = &(__mm)->mm_mt, \ + .index = __addr, \ + .node = MAS_START, \ + }, \ + } + +static inline void vma_iter_init(struct vma_iterator *vmi, + struct mm_struct *mm, unsigned long addr) +{ + vmi->mas.tree = &mm->mm_mt; + vmi->mas.index = addr; + vmi->mas.node = MAS_START; +} + struct mmu_gather; extern void tlb_gather_mmu(struct mmu_gather *tlb, struct mm_struct *mm); extern void tlb_gather_mmu_fullmm(struct mmu_gather *tlb, struct mm_struct *mm); diff --git a/include/linux/mm_types_task.h b/include/linux/mm_types_task.h index c1bc6731125c..0bb4b6da9993 100644 --- a/include/linux/mm_types_task.h +++ b/include/linux/mm_types_task.h @@ -24,18 +24,6 @@ IS_ENABLED(CONFIG_ARCH_ENABLE_SPLIT_PMD_PTLOCK)) #define ALLOC_SPLIT_PTLOCKS (SPINLOCK_SIZE > BITS_PER_LONG/8) -/* - * The per task VMA cache array: - */ -#define VMACACHE_BITS 2 -#define VMACACHE_SIZE (1U << VMACACHE_BITS) -#define VMACACHE_MASK (VMACACHE_SIZE - 1) - -struct vmacache { - u64 seqnum; - struct vm_area_struct *vmas[VMACACHE_SIZE]; -}; - /* * When updating this, please also update struct resident_page_types[] in * kernel/fork.c diff --git a/include/linux/mmzone.h b/include/linux/mmzone.h index e24b40c52468..1543001feba9 100644 --- a/include/linux/mmzone.h +++ b/include/linux/mmzone.h @@ -306,6 +306,8 @@ static inline bool is_active_lru(enum lru_list lru) return (lru == LRU_ACTIVE_ANON || lru == LRU_ACTIVE_FILE); } +#define WORKINGSET_ANON 0 +#define WORKINGSET_FILE 1 #define ANON_AND_FILE 2 enum lruvec_flags { @@ -314,6 +316,207 @@ enum lruvec_flags { */ }; +#endif /* !__GENERATING_BOUNDS_H */ + +/* + * Evictable pages are divided into multiple generations. The youngest and the + * oldest generation numbers, max_seq and min_seq, are monotonically increasing. + * They form a sliding window of a variable size [MIN_NR_GENS, MAX_NR_GENS]. An + * offset within MAX_NR_GENS, i.e., gen, indexes the LRU list of the + * corresponding generation. The gen counter in folio->flags stores gen+1 while + * a page is on one of lrugen->lists[]. Otherwise it stores 0. + * + * A page is added to the youngest generation on faulting. The aging needs to + * check the accessed bit at least twice before handing this page over to the + * eviction. The first check takes care of the accessed bit set on the initial + * fault; the second check makes sure this page hasn't been used since then. + * This process, AKA second chance, requires a minimum of two generations, + * hence MIN_NR_GENS. And to maintain ABI compatibility with the active/inactive + * LRU, e.g., /proc/vmstat, these two generations are considered active; the + * rest of generations, if they exist, are considered inactive. See + * lru_gen_is_active(). + * + * PG_active is always cleared while a page is on one of lrugen->lists[] so that + * the aging needs not to worry about it. And it's set again when a page + * considered active is isolated for non-reclaiming purposes, e.g., migration. + * See lru_gen_add_folio() and lru_gen_del_folio(). + * + * MAX_NR_GENS is set to 4 so that the multi-gen LRU can support twice the + * number of categories of the active/inactive LRU when keeping track of + * accesses through page tables. This requires order_base_2(MAX_NR_GENS+1) bits + * in folio->flags. + */ +#define MIN_NR_GENS 2U +#define MAX_NR_GENS 4U + +/* + * Each generation is divided into multiple tiers. A page accessed N times + * through file descriptors is in tier order_base_2(N). A page in the first tier + * (N=0,1) is marked by PG_referenced unless it was faulted in through page + * tables or read ahead. A page in any other tier (N>1) is marked by + * PG_referenced and PG_workingset. This implies a minimum of two tiers is + * supported without using additional bits in folio->flags. + * + * In contrast to moving across generations which requires the LRU lock, moving + * across tiers only involves atomic operations on folio->flags and therefore + * has a negligible cost in the buffered access path. In the eviction path, + * comparisons of refaulted/(evicted+protected) from the first tier and the + * rest infer whether pages accessed multiple times through file descriptors + * are statistically hot and thus worth protecting. + * + * MAX_NR_TIERS is set to 4 so that the multi-gen LRU can support twice the + * number of categories of the active/inactive LRU when keeping track of + * accesses through file descriptors. This uses MAX_NR_TIERS-2 spare bits in + * folio->flags. + */ +#define MAX_NR_TIERS 4U + +#ifndef __GENERATING_BOUNDS_H + +struct lruvec; +struct page_vma_mapped_walk; + +#define LRU_GEN_MASK ((BIT(LRU_GEN_WIDTH) - 1) << LRU_GEN_PGOFF) +#define LRU_REFS_MASK ((BIT(LRU_REFS_WIDTH) - 1) << LRU_REFS_PGOFF) + +#ifdef CONFIG_LRU_GEN + +enum { + LRU_GEN_ANON, + LRU_GEN_FILE, +}; + +enum { + LRU_GEN_CORE, + LRU_GEN_MM_WALK, + LRU_GEN_NONLEAF_YOUNG, + NR_LRU_GEN_CAPS +}; + +#define MIN_LRU_BATCH BITS_PER_LONG +#define MAX_LRU_BATCH (MIN_LRU_BATCH * 64) + +/* whether to keep historical stats from evicted generations */ +#ifdef CONFIG_LRU_GEN_STATS +#define NR_HIST_GENS MAX_NR_GENS +#else +#define NR_HIST_GENS 1U +#endif + +/* + * The youngest generation number is stored in max_seq for both anon and file + * types as they are aged on an equal footing. The oldest generation numbers are + * stored in min_seq[] separately for anon and file types as clean file pages + * can be evicted regardless of swap constraints. + * + * Normally anon and file min_seq are in sync. But if swapping is constrained, + * e.g., out of swap space, file min_seq is allowed to advance and leave anon + * min_seq behind. + * + * The number of pages in each generation is eventually consistent and therefore + * can be transiently negative when reset_batch_size() is pending. + */ +struct lru_gen_struct { + /* the aging increments the youngest generation number */ + unsigned long max_seq; + /* the eviction increments the oldest generation numbers */ + unsigned long min_seq[ANON_AND_FILE]; + /* the birth time of each generation in jiffies */ + unsigned long timestamps[MAX_NR_GENS]; + /* the multi-gen LRU lists, lazily sorted on eviction */ + struct list_head lists[MAX_NR_GENS][ANON_AND_FILE][MAX_NR_ZONES]; + /* the multi-gen LRU sizes, eventually consistent */ + long nr_pages[MAX_NR_GENS][ANON_AND_FILE][MAX_NR_ZONES]; + /* the exponential moving average of refaulted */ + unsigned long avg_refaulted[ANON_AND_FILE][MAX_NR_TIERS]; + /* the exponential moving average of evicted+protected */ + unsigned long avg_total[ANON_AND_FILE][MAX_NR_TIERS]; + /* the first tier doesn't need protection, hence the minus one */ + unsigned long protected[NR_HIST_GENS][ANON_AND_FILE][MAX_NR_TIERS - 1]; + /* can be modified without holding the LRU lock */ + atomic_long_t evicted[NR_HIST_GENS][ANON_AND_FILE][MAX_NR_TIERS]; + atomic_long_t refaulted[NR_HIST_GENS][ANON_AND_FILE][MAX_NR_TIERS]; + /* whether the multi-gen LRU is enabled */ + bool enabled; +}; + +enum { + MM_LEAF_TOTAL, /* total leaf entries */ + MM_LEAF_OLD, /* old leaf entries */ + MM_LEAF_YOUNG, /* young leaf entries */ + MM_NONLEAF_TOTAL, /* total non-leaf entries */ + MM_NONLEAF_FOUND, /* non-leaf entries found in Bloom filters */ + MM_NONLEAF_ADDED, /* non-leaf entries added to Bloom filters */ + NR_MM_STATS +}; + +/* double-buffering Bloom filters */ +#define NR_BLOOM_FILTERS 2 + +struct lru_gen_mm_state { + /* set to max_seq after each iteration */ + unsigned long seq; + /* where the current iteration continues (inclusive) */ + struct list_head *head; + /* where the last iteration ended (exclusive) */ + struct list_head *tail; + /* to wait for the last page table walker to finish */ + struct wait_queue_head wait; + /* Bloom filters flip after each iteration */ + unsigned long *filters[NR_BLOOM_FILTERS]; + /* the mm stats for debugging */ + unsigned long stats[NR_HIST_GENS][NR_MM_STATS]; + /* the number of concurrent page table walkers */ + int nr_walkers; +}; + +struct lru_gen_mm_walk { + /* the lruvec under reclaim */ + struct lruvec *lruvec; + /* unstable max_seq from lru_gen_struct */ + unsigned long max_seq; + /* the next address within an mm to scan */ + unsigned long next_addr; + /* to batch promoted pages */ + int nr_pages[MAX_NR_GENS][ANON_AND_FILE][MAX_NR_ZONES]; + /* to batch the mm stats */ + int mm_stats[NR_MM_STATS]; + /* total batched items */ + int batched; + bool can_swap; + bool force_scan; +}; + +void lru_gen_init_lruvec(struct lruvec *lruvec); +void lru_gen_look_around(struct page_vma_mapped_walk *pvmw); + +#ifdef CONFIG_MEMCG +void lru_gen_init_memcg(struct mem_cgroup *memcg); +void lru_gen_exit_memcg(struct mem_cgroup *memcg); +#endif + +#else /* !CONFIG_LRU_GEN */ + +static inline void lru_gen_init_lruvec(struct lruvec *lruvec) +{ +} + +static inline void lru_gen_look_around(struct page_vma_mapped_walk *pvmw) +{ +} + +#ifdef CONFIG_MEMCG +static inline void lru_gen_init_memcg(struct mem_cgroup *memcg) +{ +} + +static inline void lru_gen_exit_memcg(struct mem_cgroup *memcg) +{ +} +#endif + +#endif /* CONFIG_LRU_GEN */ + struct lruvec { struct list_head lists[NR_LRU_LISTS]; /* per lruvec lru_lock for memcg */ @@ -331,6 +534,12 @@ struct lruvec { unsigned long refaults[ANON_AND_FILE]; /* Various lruvec state flags (enum lruvec_flags) */ unsigned long flags; +#ifdef CONFIG_LRU_GEN + /* evictable pages divided into generations */ + struct lru_gen_struct lrugen; + /* to concurrently iterate lru_gen_mm_list */ + struct lru_gen_mm_state mm_state; +#endif #ifdef CONFIG_MEMCG struct pglist_data *pgdat; #endif @@ -746,6 +955,8 @@ static inline bool zone_is_empty(struct zone *zone) #define ZONES_PGOFF (NODES_PGOFF - ZONES_WIDTH) #define LAST_CPUPID_PGOFF (ZONES_PGOFF - LAST_CPUPID_WIDTH) #define KASAN_TAG_PGOFF (LAST_CPUPID_PGOFF - KASAN_TAG_WIDTH) +#define LRU_GEN_PGOFF (KASAN_TAG_PGOFF - LRU_GEN_WIDTH) +#define LRU_REFS_PGOFF (LRU_GEN_PGOFF - LRU_REFS_WIDTH) /* * Define the bit shifts to access each section. For non-existent @@ -1007,6 +1218,11 @@ typedef struct pglist_data { unsigned long flags; +#ifdef CONFIG_LRU_GEN + /* kswap mm walk data */ + struct lru_gen_mm_walk mm_walk; +#endif + ZONE_PADDING(_pad2_) /* Per-node vmstats */ diff --git a/include/linux/nodemask.h b/include/linux/nodemask.h index 4b71a96190a8..3a0eec9f2faa 100644 --- a/include/linux/nodemask.h +++ b/include/linux/nodemask.h @@ -493,6 +493,7 @@ static inline int num_node_state(enum node_states state) #define first_online_node 0 #define first_memory_node 0 #define next_online_node(nid) (MAX_NUMNODES) +#define next_memory_node(nid) (MAX_NUMNODES) #define nr_node_ids 1U #define nr_online_nodes 1U diff --git a/include/linux/page-flags-layout.h b/include/linux/page-flags-layout.h index ef1e3e736e14..7d79818dc065 100644 --- a/include/linux/page-flags-layout.h +++ b/include/linux/page-flags-layout.h @@ -55,7 +55,8 @@ #define SECTIONS_WIDTH 0 #endif -#if ZONES_WIDTH + SECTIONS_WIDTH + NODES_SHIFT <= BITS_PER_LONG - NR_PAGEFLAGS +#if ZONES_WIDTH + LRU_GEN_WIDTH + SECTIONS_WIDTH + NODES_SHIFT \ + <= BITS_PER_LONG - NR_PAGEFLAGS #define NODES_WIDTH NODES_SHIFT #elif defined(CONFIG_SPARSEMEM_VMEMMAP) #error "Vmemmap: No space for nodes field in page flags" @@ -89,8 +90,8 @@ #define LAST_CPUPID_SHIFT 0 #endif -#if ZONES_WIDTH + SECTIONS_WIDTH + NODES_WIDTH + KASAN_TAG_WIDTH + LAST_CPUPID_SHIFT \ - <= BITS_PER_LONG - NR_PAGEFLAGS +#if ZONES_WIDTH + LRU_GEN_WIDTH + SECTIONS_WIDTH + NODES_WIDTH + \ + KASAN_TAG_WIDTH + LAST_CPUPID_SHIFT <= BITS_PER_LONG - NR_PAGEFLAGS #define LAST_CPUPID_WIDTH LAST_CPUPID_SHIFT #else #define LAST_CPUPID_WIDTH 0 @@ -100,10 +101,15 @@ #define LAST_CPUPID_NOT_IN_PAGE_FLAGS #endif -#if ZONES_WIDTH + SECTIONS_WIDTH + NODES_WIDTH + KASAN_TAG_WIDTH + LAST_CPUPID_WIDTH \ - > BITS_PER_LONG - NR_PAGEFLAGS +#if ZONES_WIDTH + LRU_GEN_WIDTH + SECTIONS_WIDTH + NODES_WIDTH + \ + KASAN_TAG_WIDTH + LAST_CPUPID_WIDTH > BITS_PER_LONG - NR_PAGEFLAGS #error "Not enough bits in page flags" #endif +/* see the comment on MAX_NR_TIERS */ +#define LRU_REFS_WIDTH min(__LRU_REFS_WIDTH, BITS_PER_LONG - NR_PAGEFLAGS - \ + ZONES_WIDTH - LRU_GEN_WIDTH - SECTIONS_WIDTH - \ + NODES_WIDTH - KASAN_TAG_WIDTH - LAST_CPUPID_WIDTH) + #endif #endif /* _LINUX_PAGE_FLAGS_LAYOUT */ diff --git a/include/linux/page-flags.h b/include/linux/page-flags.h index 465ff35a8c00..0b0ae5084e60 100644 --- a/include/linux/page-flags.h +++ b/include/linux/page-flags.h @@ -1058,7 +1058,7 @@ static __always_inline void __ClearPageAnonExclusive(struct page *page) 1UL << PG_private | 1UL << PG_private_2 | \ 1UL << PG_writeback | 1UL << PG_reserved | \ 1UL << PG_slab | 1UL << PG_active | \ - 1UL << PG_unevictable | __PG_MLOCKED) + 1UL << PG_unevictable | __PG_MLOCKED | LRU_GEN_MASK) /* * Flags checked when a page is prepped for return by the page allocator. @@ -1069,7 +1069,7 @@ static __always_inline void __ClearPageAnonExclusive(struct page *page) * alloc-free cycle to prevent from reusing the page. */ #define PAGE_FLAGS_CHECK_AT_PREP \ - (PAGEFLAGS_MASK & ~__PG_HWPOISON) + ((PAGEFLAGS_MASK & ~__PG_HWPOISON) | LRU_GEN_MASK | LRU_REFS_MASK) #define PAGE_FLAGS_PRIVATE \ (1UL << PG_private | 1UL << PG_private_2) diff --git a/include/linux/pgtable.h b/include/linux/pgtable.h index 014ee8f0fbaa..d9095251bffd 100644 --- a/include/linux/pgtable.h +++ b/include/linux/pgtable.h @@ -213,7 +213,7 @@ static inline int ptep_test_and_clear_young(struct vm_area_struct *vma, #endif #ifndef __HAVE_ARCH_PMDP_TEST_AND_CLEAR_YOUNG -#ifdef CONFIG_TRANSPARENT_HUGEPAGE +#if defined(CONFIG_TRANSPARENT_HUGEPAGE) || defined(CONFIG_ARCH_HAS_NONLEAF_PMD_YOUNG) static inline int pmdp_test_and_clear_young(struct vm_area_struct *vma, unsigned long address, pmd_t *pmdp) @@ -234,7 +234,7 @@ static inline int pmdp_test_and_clear_young(struct vm_area_struct *vma, BUILD_BUG(); return 0; } -#endif /* CONFIG_TRANSPARENT_HUGEPAGE */ +#endif /* CONFIG_TRANSPARENT_HUGEPAGE || CONFIG_ARCH_HAS_NONLEAF_PMD_YOUNG */ #endif #ifndef __HAVE_ARCH_PTEP_CLEAR_YOUNG_FLUSH @@ -260,6 +260,19 @@ static inline int pmdp_clear_flush_young(struct vm_area_struct *vma, #endif /* CONFIG_TRANSPARENT_HUGEPAGE */ #endif +#ifndef arch_has_hw_pte_young +/* + * Return whether the accessed bit is supported on the local CPU. + * + * This stub assumes accessing through an old PTE triggers a page fault. + * Architectures that automatically set the access bit should overwrite it. + */ +static inline bool arch_has_hw_pte_young(void) +{ + return false; +} +#endif + #ifndef __HAVE_ARCH_PTEP_GET_AND_CLEAR static inline pte_t ptep_get_and_clear(struct mm_struct *mm, unsigned long address, diff --git a/include/linux/sched.h b/include/linux/sched.h index 8d82d6d32670..d929238024bd 100644 --- a/include/linux/sched.h +++ b/include/linux/sched.h @@ -861,7 +861,6 @@ struct task_struct { struct mm_struct *active_mm; /* Per-thread vma caching: */ - struct vmacache vmacache; #ifdef SPLIT_RSS_COUNTING struct task_rss_stat rss_stat; @@ -914,6 +913,10 @@ struct task_struct { #ifdef CONFIG_MEMCG unsigned in_user_fault:1; #endif +#ifdef CONFIG_LRU_GEN + /* whether the LRU algorithm may apply to this access */ + unsigned in_lru_fault:1; +#endif #ifdef CONFIG_COMPAT_BRK unsigned brk_randomized:1; #endif diff --git a/include/linux/swap.h b/include/linux/swap.h index 43150b9bbc5c..6308150b234a 100644 --- a/include/linux/swap.h +++ b/include/linux/swap.h @@ -162,6 +162,10 @@ union swap_header { */ struct reclaim_state { unsigned long reclaimed_slab; +#ifdef CONFIG_LRU_GEN + /* per-thread mm walk data */ + struct lru_gen_mm_walk *mm_walk; +#endif }; #ifdef __KERNEL__ diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h index 31d86b8c0634..9df0b9a762cc 100644 --- a/include/linux/userfaultfd_k.h +++ b/include/linux/userfaultfd_k.h @@ -175,9 +175,8 @@ extern bool userfaultfd_remove(struct vm_area_struct *vma, unsigned long start, unsigned long end); -extern int userfaultfd_unmap_prep(struct vm_area_struct *vma, - unsigned long start, unsigned long end, - struct list_head *uf); +extern int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start, + unsigned long end, struct list_head *uf); extern void userfaultfd_unmap_complete(struct mm_struct *mm, struct list_head *uf); @@ -258,7 +257,7 @@ static inline bool userfaultfd_remove(struct vm_area_struct *vma, return true; } -static inline int userfaultfd_unmap_prep(struct vm_area_struct *vma, +static inline int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start, unsigned long end, struct list_head *uf) { diff --git a/include/linux/vm_event_item.h b/include/linux/vm_event_item.h index f3fc36cd2276..3518dba1e02f 100644 --- a/include/linux/vm_event_item.h +++ b/include/linux/vm_event_item.h @@ -129,10 +129,6 @@ enum vm_event_item { PGPGIN, PGPGOUT, PSWPIN, PSWPOUT, NR_TLB_LOCAL_FLUSH_ALL, NR_TLB_LOCAL_FLUSH_ONE, #endif /* CONFIG_DEBUG_TLBFLUSH */ -#ifdef CONFIG_DEBUG_VM_VMACACHE - VMACACHE_FIND_CALLS, - VMACACHE_FIND_HITS, -#endif #ifdef CONFIG_SWAP SWAP_RA, SWAP_RA_HIT, diff --git a/include/linux/vmacache.h b/include/linux/vmacache.h deleted file mode 100644 index 6fce268a4588..000000000000 --- a/include/linux/vmacache.h +++ /dev/null @@ -1,28 +0,0 @@ -/* SPDX-License-Identifier: GPL-2.0 */ -#ifndef __LINUX_VMACACHE_H -#define __LINUX_VMACACHE_H - -#include -#include - -static inline void vmacache_flush(struct task_struct *tsk) -{ - memset(tsk->vmacache.vmas, 0, sizeof(tsk->vmacache.vmas)); -} - -extern void vmacache_update(unsigned long addr, struct vm_area_struct *newvma); -extern struct vm_area_struct *vmacache_find(struct mm_struct *mm, - unsigned long addr); - -#ifndef CONFIG_MMU -extern struct vm_area_struct *vmacache_find_exact(struct mm_struct *mm, - unsigned long start, - unsigned long end); -#endif - -static inline void vmacache_invalidate(struct mm_struct *mm) -{ - mm->vmacache_seqnum++; -} - -#endif /* __LINUX_VMACACHE_H */ diff --git a/include/linux/vmstat.h b/include/linux/vmstat.h index bfe38869498d..19cf5b6892ce 100644 --- a/include/linux/vmstat.h +++ b/include/linux/vmstat.h @@ -125,12 +125,6 @@ static inline void vm_events_fold_cpu(int cpu) #define count_vm_tlb_events(x, y) do { (void)(y); } while (0) #endif -#ifdef CONFIG_DEBUG_VM_VMACACHE -#define count_vm_vmacache_event(x) count_vm_event(x) -#else -#define count_vm_vmacache_event(x) do {} while (0) -#endif - #define __count_zid_vm_events(item, zid, delta) \ __count_vm_events(item##_NORMAL - ZONE_NORMAL + zid, delta) diff --git a/include/trace/events/maple_tree.h b/include/trace/events/maple_tree.h new file mode 100644 index 000000000000..2be403bdc2bd --- /dev/null +++ b/include/trace/events/maple_tree.h @@ -0,0 +1,123 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#undef TRACE_SYSTEM +#define TRACE_SYSTEM maple_tree + +#if !defined(_TRACE_MM_H) || defined(TRACE_HEADER_MULTI_READ) +#define _TRACE_MM_H + + +#include + +struct ma_state; + +TRACE_EVENT(ma_op, + + TP_PROTO(const char *fn, struct ma_state *mas), + + TP_ARGS(fn, mas), + + TP_STRUCT__entry( + __field(const char *, fn) + __field(unsigned long, min) + __field(unsigned long, max) + __field(unsigned long, index) + __field(unsigned long, last) + __field(void *, node) + ), + + TP_fast_assign( + __entry->fn = fn; + __entry->min = mas->min; + __entry->max = mas->max; + __entry->index = mas->index; + __entry->last = mas->last; + __entry->node = mas->node; + ), + + TP_printk("%s\tNode: %p (%lu %lu) range: %lu-%lu", + __entry->fn, + (void *) __entry->node, + (unsigned long) __entry->min, + (unsigned long) __entry->max, + (unsigned long) __entry->index, + (unsigned long) __entry->last + ) +) +TRACE_EVENT(ma_read, + + TP_PROTO(const char *fn, struct ma_state *mas), + + TP_ARGS(fn, mas), + + TP_STRUCT__entry( + __field(const char *, fn) + __field(unsigned long, min) + __field(unsigned long, max) + __field(unsigned long, index) + __field(unsigned long, last) + __field(void *, node) + ), + + TP_fast_assign( + __entry->fn = fn; + __entry->min = mas->min; + __entry->max = mas->max; + __entry->index = mas->index; + __entry->last = mas->last; + __entry->node = mas->node; + ), + + TP_printk("%s\tNode: %p (%lu %lu) range: %lu-%lu", + __entry->fn, + (void *) __entry->node, + (unsigned long) __entry->min, + (unsigned long) __entry->max, + (unsigned long) __entry->index, + (unsigned long) __entry->last + ) +) + +TRACE_EVENT(ma_write, + + TP_PROTO(const char *fn, struct ma_state *mas, unsigned long piv, + void *val), + + TP_ARGS(fn, mas, piv, val), + + TP_STRUCT__entry( + __field(const char *, fn) + __field(unsigned long, min) + __field(unsigned long, max) + __field(unsigned long, index) + __field(unsigned long, last) + __field(unsigned long, piv) + __field(void *, val) + __field(void *, node) + ), + + TP_fast_assign( + __entry->fn = fn; + __entry->min = mas->min; + __entry->max = mas->max; + __entry->index = mas->index; + __entry->last = mas->last; + __entry->piv = piv; + __entry->val = val; + __entry->node = mas->node; + ), + + TP_printk("%s\tNode %p (%lu %lu) range:%lu-%lu piv (%lu) val %p", + __entry->fn, + (void *) __entry->node, + (unsigned long) __entry->min, + (unsigned long) __entry->max, + (unsigned long) __entry->index, + (unsigned long) __entry->last, + (unsigned long) __entry->piv, + (void *) __entry->val + ) +) +#endif /* _TRACE_MM_H */ + +/* This part must be outside protection */ +#include diff --git a/include/trace/events/mmap.h b/include/trace/events/mmap.h index 4661f7ba07c0..216de5f03621 100644 --- a/include/trace/events/mmap.h +++ b/include/trace/events/mmap.h @@ -42,6 +42,79 @@ TRACE_EVENT(vm_unmapped_area, __entry->low_limit, __entry->high_limit, __entry->align_mask, __entry->align_offset) ); + +TRACE_EVENT(vma_mas_szero, + TP_PROTO(struct maple_tree *mt, unsigned long start, + unsigned long end), + + TP_ARGS(mt, start, end), + + TP_STRUCT__entry( + __field(struct maple_tree *, mt) + __field(unsigned long, start) + __field(unsigned long, end) + ), + + TP_fast_assign( + __entry->mt = mt; + __entry->start = start; + __entry->end = end; + ), + + TP_printk("mt_mod %p, (NULL), SNULL, %lu, %lu,", + __entry->mt, + (unsigned long) __entry->start, + (unsigned long) __entry->end + ) +); + +TRACE_EVENT(vma_store, + TP_PROTO(struct maple_tree *mt, struct vm_area_struct *vma), + + TP_ARGS(mt, vma), + + TP_STRUCT__entry( + __field(struct maple_tree *, mt) + __field(struct vm_area_struct *, vma) + __field(unsigned long, vm_start) + __field(unsigned long, vm_end) + ), + + TP_fast_assign( + __entry->mt = mt; + __entry->vma = vma; + __entry->vm_start = vma->vm_start; + __entry->vm_end = vma->vm_end - 1; + ), + + TP_printk("mt_mod %p, (%p), STORE, %lu, %lu,", + __entry->mt, __entry->vma, + (unsigned long) __entry->vm_start, + (unsigned long) __entry->vm_end + ) +); + + +TRACE_EVENT(exit_mmap, + TP_PROTO(struct mm_struct *mm), + + TP_ARGS(mm), + + TP_STRUCT__entry( + __field(struct mm_struct *, mm) + __field(struct maple_tree *, mt) + ), + + TP_fast_assign( + __entry->mm = mm; + __entry->mt = &mm->mm_mt; + ), + + TP_printk("mt_mod %p, DESTROY\n", + __entry->mt + ) +); + #endif /* This part must be outside protection */ diff --git a/init/main.c b/init/main.c index 1fe7942f5d4a..df800fc61b2a 100644 --- a/init/main.c +++ b/init/main.c @@ -117,6 +117,7 @@ static int kernel_init(void *); extern void init_IRQ(void); extern void radix_tree_init(void); +extern void maple_tree_init(void); /* * Debug helper: via this flag we know that we are in 'early bootup code' @@ -1002,6 +1003,7 @@ asmlinkage __visible void __init __no_sanitize_address start_kernel(void) "Interrupts were enabled *very* early, fixing it\n")) local_irq_disable(); radix_tree_init(); + maple_tree_init(); /* * Set up housekeeping before setting up workqueues to allow the unbound diff --git a/ipc/shm.c b/ipc/shm.c index b3048ebd5c31..7d86f058fb86 100644 --- a/ipc/shm.c +++ b/ipc/shm.c @@ -1721,7 +1721,7 @@ long ksys_shmdt(char __user *shmaddr) #ifdef CONFIG_MMU loff_t size = 0; struct file *file; - struct vm_area_struct *next; + VMA_ITERATOR(vmi, mm, addr); #endif if (addr & ~PAGE_MASK) @@ -1751,12 +1751,9 @@ long ksys_shmdt(char __user *shmaddr) * match the usual checks anyway. So assume all vma's are * above the starting address given. */ - vma = find_vma(mm, addr); #ifdef CONFIG_MMU - while (vma) { - next = vma->vm_next; - + for_each_vma(vmi, vma) { /* * Check if the starting address would match, i.e. it's * a fragment created by mprotect() and/or munmap(), or it @@ -1774,6 +1771,7 @@ long ksys_shmdt(char __user *shmaddr) file = vma->vm_file; size = i_size_read(file_inode(vma->vm_file)); do_munmap(mm, vma->vm_start, vma->vm_end - vma->vm_start, NULL); + mas_pause(&vmi.mas); /* * We discovered the size of the shm segment, so * break out of here and fall through to the next @@ -1781,10 +1779,9 @@ long ksys_shmdt(char __user *shmaddr) * searching for matching vma's. */ retval = 0; - vma = next; + vma = vma_next(&vmi); break; } - vma = next; } /* @@ -1794,17 +1791,19 @@ long ksys_shmdt(char __user *shmaddr) */ size = PAGE_ALIGN(size); while (vma && (loff_t)(vma->vm_end - addr) <= size) { - next = vma->vm_next; - /* finding a matching vma now does not alter retval */ if ((vma->vm_ops == &shm_vm_ops) && ((vma->vm_start - addr)/PAGE_SIZE == vma->vm_pgoff) && - (vma->vm_file == file)) + (vma->vm_file == file)) { do_munmap(mm, vma->vm_start, vma->vm_end - vma->vm_start, NULL); - vma = next; + mas_pause(&vmi.mas); + } + + vma = vma_next(&vmi); } #else /* CONFIG_MMU */ + vma = vma_lookup(mm, addr); /* under NOMMU conditions, the exact address to be destroyed must be * given */ diff --git a/kernel/acct.c b/kernel/acct.c index 13706356ec54..62200d799b9b 100644 --- a/kernel/acct.c +++ b/kernel/acct.c @@ -555,15 +555,14 @@ void acct_collect(long exitcode, int group_dead) unsigned long vsize = 0; if (group_dead && current->mm) { + struct mm_struct *mm = current->mm; + VMA_ITERATOR(vmi, mm, 0); struct vm_area_struct *vma; - mmap_read_lock(current->mm); - vma = current->mm->mmap; - while (vma) { + mmap_read_lock(mm); + for_each_vma(vmi, vma) vsize += vma->vm_end - vma->vm_start; - vma = vma->vm_next; - } - mmap_read_unlock(current->mm); + mmap_read_unlock(mm); } spin_lock_irq(¤t->sighand->siglock); diff --git a/kernel/bounds.c b/kernel/bounds.c index 9795d75b09b2..b529182e8b04 100644 --- a/kernel/bounds.c +++ b/kernel/bounds.c @@ -22,6 +22,13 @@ int main(void) DEFINE(NR_CPUS_BITS, ilog2(CONFIG_NR_CPUS)); #endif DEFINE(SPINLOCK_SIZE, sizeof(spinlock_t)); +#ifdef CONFIG_LRU_GEN + DEFINE(LRU_GEN_WIDTH, order_base_2(MAX_NR_GENS + 1)); + DEFINE(__LRU_REFS_WIDTH, MAX_NR_TIERS - 2); +#else + DEFINE(LRU_GEN_WIDTH, 0); + DEFINE(__LRU_REFS_WIDTH, 0); +#endif /* End of constants */ return 0; diff --git a/kernel/bpf/task_iter.c b/kernel/bpf/task_iter.c index 8c921799def4..1c8debd42dc9 100644 --- a/kernel/bpf/task_iter.c +++ b/kernel/bpf/task_iter.c @@ -299,8 +299,8 @@ struct bpf_iter_seq_task_vma_info { }; enum bpf_task_vma_iter_find_op { - task_vma_iter_first_vma, /* use mm->mmap */ - task_vma_iter_next_vma, /* use curr_vma->vm_next */ + task_vma_iter_first_vma, /* use find_vma() with addr 0 */ + task_vma_iter_next_vma, /* use vma_next() with curr_vma */ task_vma_iter_find_vma, /* use find_vma() to find next vma */ }; @@ -400,10 +400,10 @@ task_vma_seq_get_next(struct bpf_iter_seq_task_vma_info *info) switch (op) { case task_vma_iter_first_vma: - curr_vma = curr_task->mm->mmap; + curr_vma = find_vma(curr_task->mm, 0); break; case task_vma_iter_next_vma: - curr_vma = curr_vma->vm_next; + curr_vma = find_vma(curr_task->mm, curr_vma->vm_end); break; case task_vma_iter_find_vma: /* We dropped mmap_lock so it is necessary to use find_vma @@ -417,7 +417,7 @@ task_vma_seq_get_next(struct bpf_iter_seq_task_vma_info *info) if (curr_vma && curr_vma->vm_start == info->prev_vm_start && curr_vma->vm_end == info->prev_vm_end) - curr_vma = curr_vma->vm_next; + curr_vma = find_vma(curr_task->mm, curr_vma->vm_end); break; } if (!curr_vma) { diff --git a/kernel/cgroup/cgroup-internal.h b/kernel/cgroup/cgroup-internal.h index 36b740cb3d59..63dc3e82be4f 100644 --- a/kernel/cgroup/cgroup-internal.h +++ b/kernel/cgroup/cgroup-internal.h @@ -164,7 +164,6 @@ struct cgroup_mgctx { #define DEFINE_CGROUP_MGCTX(name) \ struct cgroup_mgctx name = CGROUP_MGCTX_INIT(name) -extern struct mutex cgroup_mutex; extern spinlock_t css_set_lock; extern struct cgroup_subsys *cgroup_subsys[]; extern struct list_head cgroup_roots; diff --git a/kernel/debug/debug_core.c b/kernel/debug/debug_core.c index 7beceb447211..d5e9ccde3ab8 100644 --- a/kernel/debug/debug_core.c +++ b/kernel/debug/debug_core.c @@ -50,7 +50,6 @@ #include #include #include -#include #include #include #include @@ -283,17 +282,6 @@ static void kgdb_flush_swbreak_addr(unsigned long addr) if (!CACHE_FLUSH_IS_SAFE) return; - if (current->mm) { - int i; - - for (i = 0; i < VMACACHE_SIZE; i++) { - if (!current->vmacache.vmas[i]) - continue; - flush_cache_range(current->vmacache.vmas[i], - addr, addr + BREAK_INSTR_SIZE); - } - } - /* Force flush instruction cache if it was outside the mm */ flush_icache_range(addr, addr + BREAK_INSTR_SIZE); } diff --git a/kernel/events/core.c b/kernel/events/core.c index 8dcbefd90b7f..e1f2182d00eb 100644 --- a/kernel/events/core.c +++ b/kernel/events/core.c @@ -10329,8 +10329,9 @@ static void perf_addr_filter_apply(struct perf_addr_filter *filter, struct perf_addr_filter_range *fr) { struct vm_area_struct *vma; + VMA_ITERATOR(vmi, mm, 0); - for (vma = mm->mmap; vma; vma = vma->vm_next) { + for_each_vma(vmi, vma) { if (!vma->vm_file) continue; diff --git a/kernel/events/uprobes.c b/kernel/events/uprobes.c index 2eaa327f8158..401bc2d24ce0 100644 --- a/kernel/events/uprobes.c +++ b/kernel/events/uprobes.c @@ -349,9 +349,10 @@ static bool valid_ref_ctr_vma(struct uprobe *uprobe, static struct vm_area_struct * find_ref_ctr_vma(struct uprobe *uprobe, struct mm_struct *mm) { + VMA_ITERATOR(vmi, mm, 0); struct vm_area_struct *tmp; - for (tmp = mm->mmap; tmp; tmp = tmp->vm_next) + for_each_vma(vmi, tmp) if (valid_ref_ctr_vma(uprobe, tmp)) return tmp; @@ -1231,11 +1232,12 @@ int uprobe_apply(struct inode *inode, loff_t offset, static int unapply_uprobe(struct uprobe *uprobe, struct mm_struct *mm) { + VMA_ITERATOR(vmi, mm, 0); struct vm_area_struct *vma; int err = 0; mmap_read_lock(mm); - for (vma = mm->mmap; vma; vma = vma->vm_next) { + for_each_vma(vmi, vma) { unsigned long vaddr; loff_t offset; @@ -1983,9 +1985,10 @@ bool uprobe_deny_signal(void) static void mmf_recalc_uprobes(struct mm_struct *mm) { + VMA_ITERATOR(vmi, mm, 0); struct vm_area_struct *vma; - for (vma = mm->mmap; vma; vma = vma->vm_next) { + for_each_vma(vmi, vma) { if (!valid_vma(vma, false)) continue; /* diff --git a/kernel/exit.c b/kernel/exit.c index 84021b24f79e..98a33bd7c25c 100644 --- a/kernel/exit.c +++ b/kernel/exit.c @@ -466,6 +466,7 @@ void mm_update_next_owner(struct mm_struct *mm) goto retry; } WRITE_ONCE(mm->owner, c); + lru_gen_migrate_mm(mm); task_unlock(c); put_task_struct(c); } diff --git a/kernel/fork.c b/kernel/fork.c index 704fe6bc9cb4..1add76edca04 100644 --- a/kernel/fork.c +++ b/kernel/fork.c @@ -43,7 +43,6 @@ #include #include #include -#include #include #include #include @@ -479,7 +478,6 @@ struct vm_area_struct *vm_area_dup(struct vm_area_struct *orig) */ *new = data_race(*orig); INIT_LIST_HEAD(&new->anon_vma_chain); - new->vm_next = new->vm_prev = NULL; dup_anon_vma_name(orig, new); } return new; @@ -584,11 +582,12 @@ static void dup_mm_exe_file(struct mm_struct *mm, struct mm_struct *oldmm) static __latent_entropy int dup_mmap(struct mm_struct *mm, struct mm_struct *oldmm) { - struct vm_area_struct *mpnt, *tmp, *prev, **pprev; - struct rb_node **rb_link, *rb_parent; + struct vm_area_struct *mpnt, *tmp; int retval; - unsigned long charge; + unsigned long charge = 0; LIST_HEAD(uf); + MA_STATE(old_mas, &oldmm->mm_mt, 0, 0); + MA_STATE(mas, &mm->mm_mt, 0, 0); uprobe_start_dup_mmap(); if (mmap_write_lock_killable(oldmm)) { @@ -610,16 +609,16 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm, mm->exec_vm = oldmm->exec_vm; mm->stack_vm = oldmm->stack_vm; - rb_link = &mm->mm_rb.rb_node; - rb_parent = NULL; - pprev = &mm->mmap; retval = ksm_fork(mm, oldmm); if (retval) goto out; khugepaged_fork(mm, oldmm); - prev = NULL; - for (mpnt = oldmm->mmap; mpnt; mpnt = mpnt->vm_next) { + retval = mas_expected_entries(&mas, oldmm->map_count); + if (retval) + goto out; + + mas_for_each(&old_mas, mpnt, ULONG_MAX) { struct file *file; if (mpnt->vm_flags & VM_DONTCOPY) { @@ -633,7 +632,7 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm, */ if (fatal_signal_pending(current)) { retval = -EINTR; - goto out; + goto loop_out; } if (mpnt->vm_flags & VM_ACCOUNT) { unsigned long len = vma_pages(mpnt); @@ -686,17 +685,12 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm, if (is_vm_hugetlb_page(tmp)) reset_vma_resv_huge_pages(tmp); - /* - * Link in the new vma and copy the page table entries. - */ - *pprev = tmp; - pprev = &tmp->vm_next; - tmp->vm_prev = prev; - prev = tmp; - - __vma_link_rb(mm, tmp, rb_link, rb_parent); - rb_link = &tmp->vm_rb.rb_right; - rb_parent = &tmp->vm_rb; + /* Link the vma into the MT */ + mas.index = tmp->vm_start; + mas.last = tmp->vm_end - 1; + mas_store(&mas, tmp); + if (mas_is_err(&mas)) + goto fail_nomem_mas_store; mm->map_count++; if (!(tmp->vm_flags & VM_WIPEONFORK)) @@ -706,10 +700,12 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm, tmp->vm_ops->open(tmp); if (retval) - goto out; + goto loop_out; } /* a new mm has just been created */ retval = arch_dup_mmap(oldmm, mm); +loop_out: + mas_destroy(&mas); out: mmap_write_unlock(mm); flush_tlb_mm(oldmm); @@ -718,6 +714,9 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm, fail_uprobe_end: uprobe_end_dup_mmap(); return retval; + +fail_nomem_mas_store: + unlink_anon_vmas(tmp); fail_nomem_anon_vma_fork: mpol_put(vma_policy(tmp)); fail_nomem_policy: @@ -725,7 +724,7 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm, fail_nomem: retval = -ENOMEM; vm_unacct_memory(charge); - goto out; + goto loop_out; } static inline int mm_alloc_pgd(struct mm_struct *mm) @@ -1113,9 +1112,8 @@ static void mm_init_uprobes_state(struct mm_struct *mm) static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p, struct user_namespace *user_ns) { - mm->mmap = NULL; - mm->mm_rb = RB_ROOT; - mm->vmacache_seqnum = 0; + mt_init_flags(&mm->mm_mt, MM_MT_FLAGS); + mt_set_external_lock(&mm->mm_mt, &mm->mmap_lock); atomic_set(&mm->mm_users, 1); atomic_set(&mm->mm_count, 1); seqcount_init(&mm->write_protect_seq); @@ -1156,6 +1154,7 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p, goto fail_nocontext; mm->user_ns = get_user_ns(user_ns); + lru_gen_init_mm(mm); return mm; fail_nocontext: @@ -1198,6 +1197,7 @@ static inline void __mmput(struct mm_struct *mm) } if (mm->binfmt) module_put(mm->binfmt->module); + lru_gen_del_mm(mm); mmdrop(mm); } @@ -1289,13 +1289,16 @@ int replace_mm_exe_file(struct mm_struct *mm, struct file *new_exe_file) /* Forbid mm->exe_file change if old file still mapped. */ old_exe_file = get_mm_exe_file(mm); if (old_exe_file) { + VMA_ITERATOR(vmi, mm, 0); mmap_read_lock(mm); - for (vma = mm->mmap; vma && !ret; vma = vma->vm_next) { + for_each_vma(vmi, vma) { if (!vma->vm_file) continue; if (path_equal(&vma->vm_file->f_path, - &old_exe_file->f_path)) + &old_exe_file->f_path)) { ret = -EBUSY; + break; + } } mmap_read_unlock(mm); fput(old_exe_file); @@ -1571,9 +1574,6 @@ static int copy_mm(unsigned long clone_flags, struct task_struct *tsk) if (!oldmm) return 0; - /* initialize the new vmacache entries */ - vmacache_flush(tsk); - if (clone_flags & CLONE_VM) { mmget(oldmm); mm = oldmm; @@ -2700,6 +2700,13 @@ pid_t kernel_clone(struct kernel_clone_args *args) get_task_struct(p); } + if (IS_ENABLED(CONFIG_LRU_GEN) && !(clone_flags & CLONE_VM)) { + /* lock the task to synchronize with memcg migration */ + task_lock(p); + lru_gen_add_mm(p->mm); + task_unlock(p); + } + wake_up_new_task(p); /* forking complete and child started to run, tell ptracer */ diff --git a/kernel/sched/core.c b/kernel/sched/core.c index 6a4417178679..3403b9ac2ea4 100644 --- a/kernel/sched/core.c +++ b/kernel/sched/core.c @@ -5170,6 +5170,7 @@ context_switch(struct rq *rq, struct task_struct *prev, * finish_task_switch()'s mmdrop(). */ switch_mm_irqs_off(prev->active_mm, next->mm, next); + lru_gen_use_mm(next->mm); if (!prev->mm) { // from kernel /* will mmdrop() in finish_task_switch(). */ diff --git a/kernel/sched/fair.c b/kernel/sched/fair.c index c4fd77e7e8f3..67b8e93ea742 100644 --- a/kernel/sched/fair.c +++ b/kernel/sched/fair.c @@ -2783,6 +2783,7 @@ static void task_numa_work(struct callback_head *work) struct task_struct *p = current; struct mm_struct *mm = p->mm; u64 runtime = p->se.sum_exec_runtime; + MA_STATE(mas, &mm->mm_mt, 0, 0); struct vm_area_struct *vma; unsigned long start, end; unsigned long nr_pte_updates = 0; @@ -2839,13 +2840,16 @@ static void task_numa_work(struct callback_head *work) if (!mmap_read_trylock(mm)) return; - vma = find_vma(mm, start); + mas_set(&mas, start); + vma = mas_find(&mas, ULONG_MAX); if (!vma) { reset_ptenuma_scan(p); start = 0; - vma = mm->mmap; + mas_set(&mas, start); + vma = mas_find(&mas, ULONG_MAX); } - for (; vma; vma = vma->vm_next) { + + for (; vma; vma = mas_find(&mas, ULONG_MAX)) { if (!vma_migratable(vma) || !vma_policy_mof(vma) || is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_MIXEDMAP)) { continue; diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index 2c579d965266..9e30af03352a 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -846,13 +846,12 @@ config DEBUG_VM If unsure, say N. -config DEBUG_VM_VMACACHE - bool "Debug VMA caching" +config DEBUG_VM_MAPLE_TREE + bool "Debug VM maple trees" depends on DEBUG_VM + select DEBUG_MAPLE_TREE help - Enable this to turn on VMA caching debug information. Doing so - can cause significant overhead, so only enable it in non-production - environments. + Enable VM maple tree debugging information and extra validations. If unsure, say N. @@ -1669,6 +1668,14 @@ config BUG_ON_DATA_CORRUPTION If unsure, say N. +config DEBUG_MAPLE_TREE + bool "Debug maple trees" + depends on DEBUG_KERNEL + help + Enable maple tree debugging information and extra validations. + + If unsure, say N. + endmenu config DEBUG_CREDENTIALS @@ -2262,6 +2269,10 @@ config TEST_UUID config TEST_XARRAY tristate "Test the XArray code at runtime" +config TEST_MAPLE_TREE + select DEBUG_MAPLE_TREE + tristate "Test the Maple Tree code at runtime" + config TEST_RHASHTABLE tristate "Perform selftest on resizable hash table" help diff --git a/lib/Makefile b/lib/Makefile index ffabc30a27d4..b4a76a657e53 100644 --- a/lib/Makefile +++ b/lib/Makefile @@ -29,7 +29,7 @@ endif lib-y := ctype.o string.o vsprintf.o cmdline.o \ rbtree.o radix-tree.o timerqueue.o xarray.o \ - idr.o extable.o irq_regs.o argv_split.o \ + maple_tree.o idr.o extable.o irq_regs.o argv_split.o \ flex_proportions.o ratelimit.o show_mem.o \ is_single_threaded.o plist.o decompress.o kobject_uevent.o \ earlycpio.o seq_buf.o siphash.o dec_and_lock.o \ @@ -89,6 +89,7 @@ obj-$(CONFIG_TEST_BITMAP) += test_bitmap.o obj-$(CONFIG_TEST_STRSCPY) += test_strscpy.o obj-$(CONFIG_TEST_UUID) += test_uuid.o obj-$(CONFIG_TEST_XARRAY) += test_xarray.o +obj-$(CONFIG_TEST_MAPLE_TREE) += test_maple_tree.o obj-$(CONFIG_TEST_PARMAN) += test_parman.o obj-$(CONFIG_TEST_KMOD) += test_kmod.o obj-$(CONFIG_TEST_DEBUG_VIRTUAL) += test_debug_virtual.o diff --git a/lib/maple_tree.c b/lib/maple_tree.c new file mode 100644 index 000000000000..5e5f3ce44a6c --- /dev/null +++ b/lib/maple_tree.c @@ -0,0 +1,7158 @@ +// SPDX-License-Identifier: GPL-2.0+ +/* + * Maple Tree implementation + * Copyright (c) 2018-2022 Oracle Corporation + * Authors: Liam R. Howlett + * Matthew Wilcox + */ + +/* + * DOC: Interesting implementation details of the Maple Tree + * + * Each node type has a number of slots for entries and a number of slots for + * pivots. In the case of dense nodes, the pivots are implied by the position + * and are simply the slot index + the minimum of the node. + * + * In regular B-Tree terms, pivots are called keys. The term pivot is used to + * indicate that the tree is specifying ranges, Pivots may appear in the + * subtree with an entry attached to the value where as keys are unique to a + * specific position of a B-tree. Pivot values are inclusive of the slot with + * the same index. + * + * + * The following illustrates the layout of a range64 nodes slots and pivots. + * + * + * Slots -> | 0 | 1 | 2 | ... | 12 | 13 | 14 | 15 | + * ┬ ┬ ┬ ┬ ┬ ┬ ┬ ┬ ┬ + * │ │ │ │ │ │ │ │ └─ Implied maximum + * │ │ │ │ │ │ │ └─ Pivot 14 + * │ │ │ │ │ │ └─ Pivot 13 + * │ │ │ │ │ └─ Pivot 12 + * │ │ │ │ └─ Pivot 11 + * │ │ │ └─ Pivot 2 + * │ │ └─ Pivot 1 + * │ └─ Pivot 0 + * └─ Implied minimum + * + * Slot contents: + * Internal (non-leaf) nodes contain pointers to other nodes. + * Leaf nodes contain entries. + * + * The location of interest is often referred to as an offset. All offsets have + * a slot, but the last offset has an implied pivot from the node above (or + * UINT_MAX for the root node. + * + * Ranges complicate certain write activities. When modifying any of + * the B-tree variants, it is known that one entry will either be added or + * deleted. When modifying the Maple Tree, one store operation may overwrite + * the entire data set, or one half of the tree, or the middle half of the tree. + * + */ + + +#include +#include +#include +#include +#include +#include +#include + +#define CREATE_TRACE_POINTS +#include + +#define MA_ROOT_PARENT 1 + +/* + * Maple state flags + * * MA_STATE_BULK - Bulk insert mode + * * MA_STATE_REBALANCE - Indicate a rebalance during bulk insert + * * MA_STATE_PREALLOC - Preallocated nodes, WARN_ON allocation + */ +#define MA_STATE_BULK 1 +#define MA_STATE_REBALANCE 2 +#define MA_STATE_PREALLOC 4 + +#define ma_parent_ptr(x) ((struct maple_pnode *)(x)) +#define ma_mnode_ptr(x) ((struct maple_node *)(x)) +#define ma_enode_ptr(x) ((struct maple_enode *)(x)) +static struct kmem_cache *maple_node_cache; + +#ifdef CONFIG_DEBUG_MAPLE_TREE +static const unsigned long mt_max[] = { + [maple_dense] = MAPLE_NODE_SLOTS, + [maple_leaf_64] = ULONG_MAX, + [maple_range_64] = ULONG_MAX, + [maple_arange_64] = ULONG_MAX, +}; +#define mt_node_max(x) mt_max[mte_node_type(x)] +#endif + +static const unsigned char mt_slots[] = { + [maple_dense] = MAPLE_NODE_SLOTS, + [maple_leaf_64] = MAPLE_RANGE64_SLOTS, + [maple_range_64] = MAPLE_RANGE64_SLOTS, + [maple_arange_64] = MAPLE_ARANGE64_SLOTS, +}; +#define mt_slot_count(x) mt_slots[mte_node_type(x)] + +static const unsigned char mt_pivots[] = { + [maple_dense] = 0, + [maple_leaf_64] = MAPLE_RANGE64_SLOTS - 1, + [maple_range_64] = MAPLE_RANGE64_SLOTS - 1, + [maple_arange_64] = MAPLE_ARANGE64_SLOTS - 1, +}; +#define mt_pivot_count(x) mt_pivots[mte_node_type(x)] + +static const unsigned char mt_min_slots[] = { + [maple_dense] = MAPLE_NODE_SLOTS / 2, + [maple_leaf_64] = (MAPLE_RANGE64_SLOTS / 2) - 2, + [maple_range_64] = (MAPLE_RANGE64_SLOTS / 2) - 2, + [maple_arange_64] = (MAPLE_ARANGE64_SLOTS / 2) - 1, +}; +#define mt_min_slot_count(x) mt_min_slots[mte_node_type(x)] + +#define MAPLE_BIG_NODE_SLOTS (MAPLE_RANGE64_SLOTS * 2 + 2) +#define MAPLE_BIG_NODE_GAPS (MAPLE_ARANGE64_SLOTS * 2 + 1) + +struct maple_big_node { + struct maple_pnode *parent; + unsigned long pivot[MAPLE_BIG_NODE_SLOTS - 1]; + union { + struct maple_enode *slot[MAPLE_BIG_NODE_SLOTS]; + struct { + unsigned long padding[MAPLE_BIG_NODE_GAPS]; + unsigned long gap[MAPLE_BIG_NODE_GAPS]; + }; + }; + unsigned char b_end; + enum maple_type type; +}; + +/* + * The maple_subtree_state is used to build a tree to replace a segment of an + * existing tree in a more atomic way. Any walkers of the older tree will hit a + * dead node and restart on updates. + */ +struct maple_subtree_state { + struct ma_state *orig_l; /* Original left side of subtree */ + struct ma_state *orig_r; /* Original right side of subtree */ + struct ma_state *l; /* New left side of subtree */ + struct ma_state *m; /* New middle of subtree (rare) */ + struct ma_state *r; /* New right side of subtree */ + struct ma_topiary *free; /* nodes to be freed */ + struct ma_topiary *destroy; /* Nodes to be destroyed (walked and freed) */ + struct maple_big_node *bn; +}; + +/* Functions */ +static inline struct maple_node *mt_alloc_one(gfp_t gfp) +{ + return kmem_cache_alloc(maple_node_cache, gfp | __GFP_ZERO); +} + +static inline int mt_alloc_bulk(gfp_t gfp, size_t size, void **nodes) +{ + return kmem_cache_alloc_bulk(maple_node_cache, gfp | __GFP_ZERO, size, + nodes); +} + +static inline void mt_free_bulk(size_t size, void __rcu **nodes) +{ + kmem_cache_free_bulk(maple_node_cache, size, (void **)nodes); +} + +static void mt_free_rcu(struct rcu_head *head) +{ + struct maple_node *node = container_of(head, struct maple_node, rcu); + + kmem_cache_free(maple_node_cache, node); +} + +/* + * ma_free_rcu() - Use rcu callback to free a maple node + * @node: The node to free + * + * The maple tree uses the parent pointer to indicate this node is no longer in + * use and will be freed. + */ +static void ma_free_rcu(struct maple_node *node) +{ + node->parent = ma_parent_ptr(node); + call_rcu(&node->rcu, mt_free_rcu); +} + + +static void mas_set_height(struct ma_state *mas) +{ + unsigned int new_flags = mas->tree->ma_flags; + + new_flags &= ~MT_FLAGS_HEIGHT_MASK; + BUG_ON(mas->depth > MAPLE_HEIGHT_MAX); + new_flags |= mas->depth << MT_FLAGS_HEIGHT_OFFSET; + mas->tree->ma_flags = new_flags; +} + +static unsigned int mas_mt_height(struct ma_state *mas) +{ + return mt_height(mas->tree); +} + +static inline enum maple_type mte_node_type(const struct maple_enode *entry) +{ + return ((unsigned long)entry >> MAPLE_NODE_TYPE_SHIFT) & + MAPLE_NODE_TYPE_MASK; +} + +static inline bool ma_is_dense(const enum maple_type type) +{ + return type < maple_leaf_64; +} + +static inline bool ma_is_leaf(const enum maple_type type) +{ + return type < maple_range_64; +} + +static inline bool mte_is_leaf(const struct maple_enode *entry) +{ + return ma_is_leaf(mte_node_type(entry)); +} + +/* + * We also reserve values with the bottom two bits set to '10' which are + * below 4096 + */ +static inline bool mt_is_reserved(const void *entry) +{ + return ((unsigned long)entry < MAPLE_RESERVED_RANGE) && + xa_is_internal(entry); +} + +static inline void mas_set_err(struct ma_state *mas, long err) +{ + mas->node = MA_ERROR(err); +} + +static inline bool mas_is_ptr(struct ma_state *mas) +{ + return mas->node == MAS_ROOT; +} + +static inline bool mas_is_start(struct ma_state *mas) +{ + return mas->node == MAS_START; +} + +bool mas_is_err(struct ma_state *mas) +{ + return xa_is_err(mas->node); +} + +static inline bool mas_searchable(struct ma_state *mas) +{ + if (mas_is_none(mas)) + return false; + + if (mas_is_ptr(mas)) + return false; + + return true; +} + +static inline struct maple_node *mte_to_node(const struct maple_enode *entry) +{ + return (struct maple_node *)((unsigned long)entry & ~MAPLE_NODE_MASK); +} + +/* + * mte_to_mat() - Convert a maple encoded node to a maple topiary node. + * @entry: The maple encoded node + * + * Return: a maple topiary pointer + */ +static inline struct maple_topiary *mte_to_mat(const struct maple_enode *entry) +{ + return (struct maple_topiary *) + ((unsigned long)entry & ~MAPLE_NODE_MASK); +} + +/* + * mas_mn() - Get the maple state node. + * @mas: The maple state + * + * Return: the maple node (not encoded - bare pointer). + */ +static inline struct maple_node *mas_mn(const struct ma_state *mas) +{ + return mte_to_node(mas->node); +} + +/* + * mte_set_node_dead() - Set a maple encoded node as dead. + * @mn: The maple encoded node. + */ +static inline void mte_set_node_dead(struct maple_enode *mn) +{ + mte_to_node(mn)->parent = ma_parent_ptr(mte_to_node(mn)); + smp_wmb(); /* Needed for RCU */ +} + +/* Bit 1 indicates the root is a node */ +#define MAPLE_ROOT_NODE 0x02 +/* maple_type stored bit 3-6 */ +#define MAPLE_ENODE_TYPE_SHIFT 0x03 +/* Bit 2 means a NULL somewhere below */ +#define MAPLE_ENODE_NULL 0x04 + +static inline struct maple_enode *mt_mk_node(const struct maple_node *node, + enum maple_type type) +{ + return (void *)((unsigned long)node | + (type << MAPLE_ENODE_TYPE_SHIFT) | MAPLE_ENODE_NULL); +} + +static inline void *mte_mk_root(const struct maple_enode *node) +{ + return (void *)((unsigned long)node | MAPLE_ROOT_NODE); +} + +static inline void *mte_safe_root(const struct maple_enode *node) +{ + return (void *)((unsigned long)node & ~MAPLE_ROOT_NODE); +} + +static inline void *mte_set_full(const struct maple_enode *node) +{ + return (void *)((unsigned long)node & ~MAPLE_ENODE_NULL); +} + +static inline void *mte_clear_full(const struct maple_enode *node) +{ + return (void *)((unsigned long)node | MAPLE_ENODE_NULL); +} + +static inline bool mte_has_null(const struct maple_enode *node) +{ + return (unsigned long)node & MAPLE_ENODE_NULL; +} + +static inline bool ma_is_root(struct maple_node *node) +{ + return ((unsigned long)node->parent & MA_ROOT_PARENT); +} + +static inline bool mte_is_root(const struct maple_enode *node) +{ + return ma_is_root(mte_to_node(node)); +} + +static inline bool mas_is_root_limits(const struct ma_state *mas) +{ + return !mas->min && mas->max == ULONG_MAX; +} + +static inline bool mt_is_alloc(struct maple_tree *mt) +{ + return (mt->ma_flags & MT_FLAGS_ALLOC_RANGE); +} + +/* + * The Parent Pointer + * Excluding root, the parent pointer is 256B aligned like all other tree nodes. + * When storing a 32 or 64 bit values, the offset can fit into 5 bits. The 16 + * bit values need an extra bit to store the offset. This extra bit comes from + * a reuse of the last bit in the node type. This is possible by using bit 1 to + * indicate if bit 2 is part of the type or the slot. + * + * Note types: + * 0x??1 = Root + * 0x?00 = 16 bit nodes + * 0x010 = 32 bit nodes + * 0x110 = 64 bit nodes + * + * Slot size and alignment + * 0b??1 : Root + * 0b?00 : 16 bit values, type in 0-1, slot in 2-7 + * 0b010 : 32 bit values, type in 0-2, slot in 3-7 + * 0b110 : 64 bit values, type in 0-2, slot in 3-7 + */ + +#define MAPLE_PARENT_ROOT 0x01 + +#define MAPLE_PARENT_SLOT_SHIFT 0x03 +#define MAPLE_PARENT_SLOT_MASK 0xF8 + +#define MAPLE_PARENT_16B_SLOT_SHIFT 0x02 +#define MAPLE_PARENT_16B_SLOT_MASK 0xFC + +#define MAPLE_PARENT_RANGE64 0x06 +#define MAPLE_PARENT_RANGE32 0x04 +#define MAPLE_PARENT_NOT_RANGE16 0x02 + +/* + * mte_parent_shift() - Get the parent shift for the slot storage. + * @parent: The parent pointer cast as an unsigned long + * Return: The shift into that pointer to the star to of the slot + */ +static inline unsigned long mte_parent_shift(unsigned long parent) +{ + /* Note bit 1 == 0 means 16B */ + if (likely(parent & MAPLE_PARENT_NOT_RANGE16)) + return MAPLE_PARENT_SLOT_SHIFT; + + return MAPLE_PARENT_16B_SLOT_SHIFT; +} + +/* + * mte_parent_slot_mask() - Get the slot mask for the parent. + * @parent: The parent pointer cast as an unsigned long. + * Return: The slot mask for that parent. + */ +static inline unsigned long mte_parent_slot_mask(unsigned long parent) +{ + /* Note bit 1 == 0 means 16B */ + if (likely(parent & MAPLE_PARENT_NOT_RANGE16)) + return MAPLE_PARENT_SLOT_MASK; + + return MAPLE_PARENT_16B_SLOT_MASK; +} + +/* + * mas_parent_enum() - Return the maple_type of the parent from the stored + * parent type. + * @mas: The maple state + * @node: The maple_enode to extract the parent's enum + * Return: The node->parent maple_type + */ +static inline +enum maple_type mte_parent_enum(struct maple_enode *p_enode, + struct maple_tree *mt) +{ + unsigned long p_type; + + p_type = (unsigned long)p_enode; + if (p_type & MAPLE_PARENT_ROOT) + return 0; /* Validated in the caller. */ + + p_type &= MAPLE_NODE_MASK; + p_type = p_type & ~mte_parent_slot_mask(p_type); + + switch (p_type) { + case MAPLE_PARENT_RANGE64: /* or MAPLE_PARENT_ARANGE64 */ + if (mt_is_alloc(mt)) + return maple_arange_64; + return maple_range_64; + } + + return 0; +} + +static inline +enum maple_type mas_parent_enum(struct ma_state *mas, struct maple_enode *enode) +{ + return mte_parent_enum(ma_enode_ptr(mte_to_node(enode)->parent), mas->tree); +} + +/* + * mte_set_parent() - Set the parent node and encode the slot + * @enode: The encoded maple node. + * @parent: The encoded maple node that is the parent of @enode. + * @slot: The slot that @enode resides in @parent. + * + * Slot number is encoded in the enode->parent bit 3-6 or 2-6, depending on the + * parent type. + */ +static inline +void mte_set_parent(struct maple_enode *enode, const struct maple_enode *parent, + unsigned char slot) +{ + unsigned long val = (unsigned long) parent; + unsigned long shift; + unsigned long type; + enum maple_type p_type = mte_node_type(parent); + + BUG_ON(p_type == maple_dense); + BUG_ON(p_type == maple_leaf_64); + + switch (p_type) { + case maple_range_64: + case maple_arange_64: + shift = MAPLE_PARENT_SLOT_SHIFT; + type = MAPLE_PARENT_RANGE64; + break; + default: + case maple_dense: + case maple_leaf_64: + shift = type = 0; + break; + } + + val &= ~MAPLE_NODE_MASK; /* Clear all node metadata in parent */ + val |= (slot << shift) | type; + mte_to_node(enode)->parent = ma_parent_ptr(val); +} + +/* + * mte_parent_slot() - get the parent slot of @enode. + * @enode: The encoded maple node. + * + * Return: The slot in the parent node where @enode resides. + */ +static inline unsigned int mte_parent_slot(const struct maple_enode *enode) +{ + unsigned long val = (unsigned long) mte_to_node(enode)->parent; + + /* Root. */ + if (val & 1) + return 0; + + /* + * Okay to use MAPLE_PARENT_16B_SLOT_MASK as the last bit will be lost + * by shift if the parent shift is MAPLE_PARENT_SLOT_SHIFT + */ + return (val & MAPLE_PARENT_16B_SLOT_MASK) >> mte_parent_shift(val); +} + +/* + * mte_parent() - Get the parent of @node. + * @node: The encoded maple node. + * + * Return: The parent maple node. + */ +static inline struct maple_node *mte_parent(const struct maple_enode *enode) +{ + return (void *)((unsigned long) + (mte_to_node(enode)->parent) & ~MAPLE_NODE_MASK); +} + +/* + * ma_dead_node() - check if the @enode is dead. + * @enode: The encoded maple node + * + * Return: true if dead, false otherwise. + */ +static inline bool ma_dead_node(const struct maple_node *node) +{ + struct maple_node *parent = (void *)((unsigned long) + node->parent & ~MAPLE_NODE_MASK); + + return (parent == node); +} +/* + * mte_dead_node() - check if the @enode is dead. + * @enode: The encoded maple node + * + * Return: true if dead, false otherwise. + */ +static inline bool mte_dead_node(const struct maple_enode *enode) +{ + struct maple_node *parent, *node; + + node = mte_to_node(enode); + parent = mte_parent(enode); + return (parent == node); +} + +/* + * mas_allocated() - Get the number of nodes allocated in a maple state. + * @mas: The maple state + * + * The ma_state alloc member is overloaded to hold a pointer to the first + * allocated node or to the number of requested nodes to allocate. If bit 0 is + * set, then the alloc contains the number of requested nodes. If there is an + * allocated node, then the total allocated nodes is in that node. + * + * Return: The total number of nodes allocated + */ +static inline unsigned long mas_allocated(const struct ma_state *mas) +{ + if (!mas->alloc || ((unsigned long)mas->alloc & 0x1)) + return 0; + + return mas->alloc->total; +} + +/* + * mas_set_alloc_req() - Set the requested number of allocations. + * @mas: the maple state + * @count: the number of allocations. + * + * The requested number of allocations is either in the first allocated node, + * located in @mas->alloc->request_count, or directly in @mas->alloc if there is + * no allocated node. Set the request either in the node or do the necessary + * encoding to store in @mas->alloc directly. + */ +static inline void mas_set_alloc_req(struct ma_state *mas, unsigned long count) +{ + if (!mas->alloc || ((unsigned long)mas->alloc & 0x1)) { + if (!count) + mas->alloc = NULL; + else + mas->alloc = (struct maple_alloc *)(((count) << 1U) | 1U); + return; + } + + mas->alloc->request_count = count; +} + +/* + * mas_alloc_req() - get the requested number of allocations. + * @mas: The maple state + * + * The alloc count is either stored directly in @mas, or in + * @mas->alloc->request_count if there is at least one node allocated. Decode + * the request count if it's stored directly in @mas->alloc. + * + * Return: The allocation request count. + */ +static inline unsigned int mas_alloc_req(const struct ma_state *mas) +{ + if ((unsigned long)mas->alloc & 0x1) + return (unsigned long)(mas->alloc) >> 1; + else if (mas->alloc) + return mas->alloc->request_count; + return 0; +} + +/* + * ma_pivots() - Get a pointer to the maple node pivots. + * @node - the maple node + * @type - the node type + * + * Return: A pointer to the maple node pivots + */ +static inline unsigned long *ma_pivots(struct maple_node *node, + enum maple_type type) +{ + switch (type) { + case maple_arange_64: + return node->ma64.pivot; + case maple_range_64: + case maple_leaf_64: + return node->mr64.pivot; + case maple_dense: + return NULL; + } + return NULL; +} + +/* + * ma_gaps() - Get a pointer to the maple node gaps. + * @node - the maple node + * @type - the node type + * + * Return: A pointer to the maple node gaps + */ +static inline unsigned long *ma_gaps(struct maple_node *node, + enum maple_type type) +{ + switch (type) { + case maple_arange_64: + return node->ma64.gap; + case maple_range_64: + case maple_leaf_64: + case maple_dense: + return NULL; + } + return NULL; +} + +/* + * mte_pivot() - Get the pivot at @piv of the maple encoded node. + * @mn: The maple encoded node. + * @piv: The pivot. + * + * Return: the pivot at @piv of @mn. + */ +static inline unsigned long mte_pivot(const struct maple_enode *mn, + unsigned char piv) +{ + struct maple_node *node = mte_to_node(mn); + enum maple_type type = mte_node_type(mn); + + if (piv >= mt_pivots[type]) { + WARN_ON(1); + return 0; + } + switch (type) { + case maple_arange_64: + return node->ma64.pivot[piv]; + case maple_range_64: + case maple_leaf_64: + return node->mr64.pivot[piv]; + case maple_dense: + return 0; + } + return 0; +} + +/* + * mas_safe_pivot() - get the pivot at @piv or mas->max. + * @mas: The maple state + * @pivots: The pointer to the maple node pivots + * @piv: The pivot to fetch + * @type: The maple node type + * + * Return: The pivot at @piv within the limit of the @pivots array, @mas->max + * otherwise. + */ +static inline unsigned long +mas_safe_pivot(const struct ma_state *mas, unsigned long *pivots, + unsigned char piv, enum maple_type type) +{ + if (piv >= mt_pivots[type]) + return mas->max; + + return pivots[piv]; +} + +/* + * mas_safe_min() - Return the minimum for a given offset. + * @mas: The maple state + * @pivots: The pointer to the maple node pivots + * @offset: The offset into the pivot array + * + * Return: The minimum range value that is contained in @offset. + */ +static inline unsigned long +mas_safe_min(struct ma_state *mas, unsigned long *pivots, unsigned char offset) +{ + if (likely(offset)) + return pivots[offset - 1] + 1; + + return mas->min; +} + +/* + * mas_logical_pivot() - Get the logical pivot of a given offset. + * @mas: The maple state + * @pivots: The pointer to the maple node pivots + * @offset: The offset into the pivot array + * @type: The maple node type + * + * When there is no value at a pivot (beyond the end of the data), then the + * pivot is actually @mas->max. + * + * Return: the logical pivot of a given @offset. + */ +static inline unsigned long +mas_logical_pivot(struct ma_state *mas, unsigned long *pivots, + unsigned char offset, enum maple_type type) +{ + unsigned long lpiv = mas_safe_pivot(mas, pivots, offset, type); + + if (likely(lpiv)) + return lpiv; + + if (likely(offset)) + return mas->max; + + return lpiv; +} + +/* + * mte_set_pivot() - Set a pivot to a value in an encoded maple node. + * @mn: The encoded maple node + * @piv: The pivot offset + * @val: The value of the pivot + */ +static inline void mte_set_pivot(struct maple_enode *mn, unsigned char piv, + unsigned long val) +{ + struct maple_node *node = mte_to_node(mn); + enum maple_type type = mte_node_type(mn); + + BUG_ON(piv >= mt_pivots[type]); + switch (type) { + default: + case maple_range_64: + case maple_leaf_64: + node->mr64.pivot[piv] = val; + break; + case maple_arange_64: + node->ma64.pivot[piv] = val; + break; + case maple_dense: + break; + } + +} + +/* + * ma_slots() - Get a pointer to the maple node slots. + * @mn: The maple node + * @mt: The maple node type + * + * Return: A pointer to the maple node slots + */ +static inline void __rcu **ma_slots(struct maple_node *mn, enum maple_type mt) +{ + switch (mt) { + default: + case maple_arange_64: + return mn->ma64.slot; + case maple_range_64: + case maple_leaf_64: + return mn->mr64.slot; + case maple_dense: + return mn->slot; + } +} + +static inline bool mt_locked(const struct maple_tree *mt) +{ + return mt_external_lock(mt) ? mt_lock_is_held(mt) : + lockdep_is_held(&mt->ma_lock); +} + +static inline void *mt_slot(const struct maple_tree *mt, + void __rcu **slots, unsigned char offset) +{ + return rcu_dereference_check(slots[offset], mt_locked(mt)); +} + +/* + * mas_slot_locked() - Get the slot value when holding the maple tree lock. + * @mas: The maple state + * @slots: The pointer to the slots + * @offset: The offset into the slots array to fetch + * + * Return: The entry stored in @slots at the @offset. + */ +static inline void *mas_slot_locked(struct ma_state *mas, void __rcu **slots, + unsigned char offset) +{ + return rcu_dereference_protected(slots[offset], mt_locked(mas->tree)); +} + +/* + * mas_slot() - Get the slot value when not holding the maple tree lock. + * @mas: The maple state + * @slots: The pointer to the slots + * @offset: The offset into the slots array to fetch + * + * Return: The entry stored in @slots at the @offset + */ +static inline void *mas_slot(struct ma_state *mas, void __rcu **slots, + unsigned char offset) +{ + return mt_slot(mas->tree, slots, offset); +} + +/* + * mas_root() - Get the maple tree root. + * @mas: The maple state. + * + * Return: The pointer to the root of the tree + */ +static inline void *mas_root(struct ma_state *mas) +{ + return rcu_dereference_check(mas->tree->ma_root, mt_locked(mas->tree)); +} + +static inline void *mt_root_locked(struct maple_tree *mt) +{ + return rcu_dereference_protected(mt->ma_root, mt_locked(mt)); +} + +/* + * mas_root_locked() - Get the maple tree root when holding the maple tree lock. + * @mas: The maple state. + * + * Return: The pointer to the root of the tree + */ +static inline void *mas_root_locked(struct ma_state *mas) +{ + return mt_root_locked(mas->tree); +} + +static inline struct maple_metadata *ma_meta(struct maple_node *mn, + enum maple_type mt) +{ + switch (mt) { + case maple_arange_64: + return &mn->ma64.meta; + default: + return &mn->mr64.meta; + } +} + +/* + * ma_set_meta() - Set the metadata information of a node. + * @mn: The maple node + * @mt: The maple node type + * @offset: The offset of the highest sub-gap in this node. + * @end: The end of the data in this node. + */ +static inline void ma_set_meta(struct maple_node *mn, enum maple_type mt, + unsigned char offset, unsigned char end) +{ + struct maple_metadata *meta = ma_meta(mn, mt); + + meta->gap = offset; + meta->end = end; +} + +/* + * ma_meta_end() - Get the data end of a node from the metadata + * @mn: The maple node + * @mt: The maple node type + */ +static inline unsigned char ma_meta_end(struct maple_node *mn, + enum maple_type mt) +{ + struct maple_metadata *meta = ma_meta(mn, mt); + + return meta->end; +} + +/* + * ma_meta_gap() - Get the largest gap location of a node from the metadata + * @mn: The maple node + * @mt: The maple node type + */ +static inline unsigned char ma_meta_gap(struct maple_node *mn, + enum maple_type mt) +{ + BUG_ON(mt != maple_arange_64); + + return mn->ma64.meta.gap; +} + +/* + * ma_set_meta_gap() - Set the largest gap location in a nodes metadata + * @mn: The maple node + * @mn: The maple node type + * @offset: The location of the largest gap. + */ +static inline void ma_set_meta_gap(struct maple_node *mn, enum maple_type mt, + unsigned char offset) +{ + + struct maple_metadata *meta = ma_meta(mn, mt); + + meta->gap = offset; +} + +/* + * mat_add() - Add a @dead_enode to the ma_topiary of a list of dead nodes. + * @mat - the ma_topiary, a linked list of dead nodes. + * @dead_enode - the node to be marked as dead and added to the tail of the list + * + * Add the @dead_enode to the linked list in @mat. + */ +static inline void mat_add(struct ma_topiary *mat, + struct maple_enode *dead_enode) +{ + mte_set_node_dead(dead_enode); + mte_to_mat(dead_enode)->next = NULL; + if (!mat->tail) { + mat->tail = mat->head = dead_enode; + return; + } + + mte_to_mat(mat->tail)->next = dead_enode; + mat->tail = dead_enode; +} + +static void mte_destroy_walk(struct maple_enode *, struct maple_tree *); +static inline void mas_free(struct ma_state *mas, struct maple_enode *used); + +/* + * mas_mat_free() - Free all nodes in a dead list. + * @mas - the maple state + * @mat - the ma_topiary linked list of dead nodes to free. + * + * Free walk a dead list. + */ +static void mas_mat_free(struct ma_state *mas, struct ma_topiary *mat) +{ + struct maple_enode *next; + + while (mat->head) { + next = mte_to_mat(mat->head)->next; + mas_free(mas, mat->head); + mat->head = next; + } +} + +/* + * mas_mat_destroy() - Free all nodes and subtrees in a dead list. + * @mas - the maple state + * @mat - the ma_topiary linked list of dead nodes to free. + * + * Destroy walk a dead list. + */ +static void mas_mat_destroy(struct ma_state *mas, struct ma_topiary *mat) +{ + struct maple_enode *next; + + while (mat->head) { + next = mte_to_mat(mat->head)->next; + mte_destroy_walk(mat->head, mat->mtree); + mat->head = next; + } +} +/* + * mas_descend() - Descend into the slot stored in the ma_state. + * @mas - the maple state. + * + * Note: Not RCU safe, only use in write side or debug code. + */ +static inline void mas_descend(struct ma_state *mas) +{ + enum maple_type type; + unsigned long *pivots; + struct maple_node *node; + void __rcu **slots; + + node = mas_mn(mas); + type = mte_node_type(mas->node); + pivots = ma_pivots(node, type); + slots = ma_slots(node, type); + + if (mas->offset) + mas->min = pivots[mas->offset - 1] + 1; + mas->max = mas_safe_pivot(mas, pivots, mas->offset, type); + mas->node = mas_slot(mas, slots, mas->offset); +} + +/* + * mte_set_gap() - Set a maple node gap. + * @mn: The encoded maple node + * @gap: The offset of the gap to set + * @val: The gap value + */ +static inline void mte_set_gap(const struct maple_enode *mn, + unsigned char gap, unsigned long val) +{ + switch (mte_node_type(mn)) { + default: + break; + case maple_arange_64: + mte_to_node(mn)->ma64.gap[gap] = val; + break; + } +} + +/* + * mas_ascend() - Walk up a level of the tree. + * @mas: The maple state + * + * Sets the @mas->max and @mas->min to the correct values when walking up. This + * may cause several levels of walking up to find the correct min and max. + * May find a dead node which will cause a premature return. + * Return: 1 on dead node, 0 otherwise + */ +static int mas_ascend(struct ma_state *mas) +{ + struct maple_enode *p_enode; /* parent enode. */ + struct maple_enode *a_enode; /* ancestor enode. */ + struct maple_node *a_node; /* ancestor node. */ + struct maple_node *p_node; /* parent node. */ + unsigned char a_slot; + enum maple_type a_type; + unsigned long min, max; + unsigned long *pivots; + unsigned char offset; + bool set_max = false, set_min = false; + + a_node = mas_mn(mas); + if (ma_is_root(a_node)) { + mas->offset = 0; + return 0; + } + + p_node = mte_parent(mas->node); + if (unlikely(a_node == p_node)) + return 1; + a_type = mas_parent_enum(mas, mas->node); + offset = mte_parent_slot(mas->node); + a_enode = mt_mk_node(p_node, a_type); + + /* Check to make sure all parent information is still accurate */ + if (p_node != mte_parent(mas->node)) + return 1; + + mas->node = a_enode; + mas->offset = offset; + + if (mte_is_root(a_enode)) { + mas->max = ULONG_MAX; + mas->min = 0; + return 0; + } + + min = 0; + max = ULONG_MAX; + do { + p_enode = a_enode; + a_type = mas_parent_enum(mas, p_enode); + a_node = mte_parent(p_enode); + a_slot = mte_parent_slot(p_enode); + pivots = ma_pivots(a_node, a_type); + a_enode = mt_mk_node(a_node, a_type); + + if (!set_min && a_slot) { + set_min = true; + min = pivots[a_slot - 1] + 1; + } + + if (!set_max && a_slot < mt_pivots[a_type]) { + set_max = true; + max = pivots[a_slot]; + } + + if (unlikely(ma_dead_node(a_node))) + return 1; + + if (unlikely(ma_is_root(a_node))) + break; + + } while (!set_min || !set_max); + + mas->max = max; + mas->min = min; + return 0; +} + +/* + * mas_pop_node() - Get a previously allocated maple node from the maple state. + * @mas: The maple state + * + * Return: A pointer to a maple node. + */ +static inline struct maple_node *mas_pop_node(struct ma_state *mas) +{ + struct maple_alloc *ret, *node = mas->alloc; + unsigned long total = mas_allocated(mas); + + /* nothing or a request pending. */ + if (unlikely(!total)) + return NULL; + + if (total == 1) { + /* single allocation in this ma_state */ + mas->alloc = NULL; + ret = node; + goto single_node; + } + + if (!node->node_count) { + /* Single allocation in this node. */ + mas->alloc = node->slot[0]; + node->slot[0] = NULL; + mas->alloc->total = node->total - 1; + ret = node; + goto new_head; + } + + node->total--; + ret = node->slot[node->node_count]; + node->slot[node->node_count--] = NULL; + +single_node: +new_head: + ret->total = 0; + ret->node_count = 0; + if (ret->request_count) { + mas_set_alloc_req(mas, ret->request_count + 1); + ret->request_count = 0; + } + return (struct maple_node *)ret; +} + +/* + * mas_push_node() - Push a node back on the maple state allocation. + * @mas: The maple state + * @used: The used maple node + * + * Stores the maple node back into @mas->alloc for reuse. Updates allocated and + * requested node count as necessary. + */ +static inline void mas_push_node(struct ma_state *mas, struct maple_node *used) +{ + struct maple_alloc *reuse = (struct maple_alloc *)used; + struct maple_alloc *head = mas->alloc; + unsigned long count; + unsigned int requested = mas_alloc_req(mas); + + memset(reuse, 0, sizeof(*reuse)); + count = mas_allocated(mas); + + if (count && (head->node_count < MAPLE_ALLOC_SLOTS - 1)) { + if (head->slot[0]) + head->node_count++; + head->slot[head->node_count] = reuse; + head->total++; + goto done; + } + + reuse->total = 1; + if ((head) && !((unsigned long)head & 0x1)) { + head->request_count = 0; + reuse->slot[0] = head; + reuse->total += head->total; + } + + mas->alloc = reuse; +done: + if (requested > 1) + mas_set_alloc_req(mas, requested - 1); +} + +/* + * mas_alloc_nodes() - Allocate nodes into a maple state + * @mas: The maple state + * @gfp: The GFP Flags + */ +static inline void mas_alloc_nodes(struct ma_state *mas, gfp_t gfp) +{ + struct maple_alloc *node; + unsigned long allocated = mas_allocated(mas); + unsigned long success = allocated; + unsigned int requested = mas_alloc_req(mas); + unsigned int count; + void **slots = NULL; + unsigned int max_req = 0; + + if (!requested) + return; + + mas_set_alloc_req(mas, 0); + if (mas->mas_flags & MA_STATE_PREALLOC) { + if (allocated) + return; + WARN_ON(!allocated); + } + + if (!allocated || mas->alloc->node_count == MAPLE_ALLOC_SLOTS - 1) { + node = (struct maple_alloc *)mt_alloc_one(gfp); + if (!node) + goto nomem_one; + + if (allocated) + node->slot[0] = mas->alloc; + + success++; + mas->alloc = node; + requested--; + } + + node = mas->alloc; + while (requested) { + max_req = MAPLE_ALLOC_SLOTS; + if (node->slot[0]) { + unsigned int offset = node->node_count + 1; + + slots = (void **)&node->slot[offset]; + max_req -= offset; + } else { + slots = (void **)&node->slot; + } + + max_req = min(requested, max_req); + count = mt_alloc_bulk(gfp, max_req, slots); + if (!count) + goto nomem_bulk; + + node->node_count += count; + /* zero indexed. */ + if (slots == (void **)&node->slot) + node->node_count--; + + success += count; + node = node->slot[0]; + requested -= count; + } + mas->alloc->total = success; + return; + +nomem_bulk: + /* Clean up potential freed allocations on bulk failure */ + memset(slots, 0, max_req * sizeof(unsigned long)); +nomem_one: + mas_set_alloc_req(mas, requested); + if (mas->alloc && !(((unsigned long)mas->alloc & 0x1))) + mas->alloc->total = success; + mas_set_err(mas, -ENOMEM); + return; + +} + +/* + * mas_free() - Free an encoded maple node + * @mas: The maple state + * @used: The encoded maple node to free. + * + * Uses rcu free if necessary, pushes @used back on the maple state allocations + * otherwise. + */ +static inline void mas_free(struct ma_state *mas, struct maple_enode *used) +{ + struct maple_node *tmp = mte_to_node(used); + + if (mt_in_rcu(mas->tree)) + ma_free_rcu(tmp); + else + mas_push_node(mas, tmp); +} + +/* + * mas_node_count() - Check if enough nodes are allocated and request more if + * there is not enough nodes. + * @mas: The maple state + * @count: The number of nodes needed + * @gfp: the gfp flags + */ +static void mas_node_count_gfp(struct ma_state *mas, int count, gfp_t gfp) +{ + unsigned long allocated = mas_allocated(mas); + + if (allocated < count) { + mas_set_alloc_req(mas, count - allocated); + mas_alloc_nodes(mas, gfp); + } +} + +/* + * mas_node_count() - Check if enough nodes are allocated and request more if + * there is not enough nodes. + * @mas: The maple state + * @count: The number of nodes needed + * + * Note: Uses GFP_NOWAIT | __GFP_NOWARN for gfp flags. + */ +static void mas_node_count(struct ma_state *mas, int count) +{ + return mas_node_count_gfp(mas, count, GFP_NOWAIT | __GFP_NOWARN); +} + +/* + * mas_start() - Sets up maple state for operations. + * @mas: The maple state. + * + * If mas->node == MAS_START, then set the min, max, depth, and offset to + * defaults. + * + * Return: + * - If mas->node is an error or not MAS_START, return NULL. + * - If it's an empty tree: NULL & mas->node == MAS_NONE + * - If it's a single entry: The entry & mas->node == MAS_ROOT + * - If it's a tree: NULL & mas->node == safe root node. + */ +static inline struct maple_enode *mas_start(struct ma_state *mas) +{ + if (likely(mas_is_start(mas))) { + struct maple_enode *root; + + mas->node = MAS_NONE; + mas->min = 0; + mas->max = ULONG_MAX; + mas->depth = 0; + mas->offset = 0; + + root = mas_root(mas); + /* Tree with nodes */ + if (likely(xa_is_node(root))) { + mas->depth = 1; + mas->node = mte_safe_root(root); + return NULL; + } + + /* empty tree */ + if (unlikely(!root)) { + mas->offset = MAPLE_NODE_SLOTS; + return NULL; + } + + /* Single entry tree */ + mas->node = MAS_ROOT; + mas->offset = MAPLE_NODE_SLOTS; + + /* Single entry tree. */ + if (mas->index > 0) + return NULL; + + return root; + } + + return NULL; +} + +/* + * ma_data_end() - Find the end of the data in a node. + * @node: The maple node + * @type: The maple node type + * @pivots: The array of pivots in the node + * @max: The maximum value in the node + * + * Uses metadata to find the end of the data when possible. + * Return: The zero indexed last slot with data (may be null). + */ +static inline unsigned char ma_data_end(struct maple_node *node, + enum maple_type type, + unsigned long *pivots, + unsigned long max) +{ + unsigned char offset; + + if (type == maple_arange_64) + return ma_meta_end(node, type); + + offset = mt_pivots[type] - 1; + if (likely(!pivots[offset])) + return ma_meta_end(node, type); + + if (likely(pivots[offset] == max)) + return offset; + + return mt_pivots[type]; +} + +/* + * mas_data_end() - Find the end of the data (slot). + * @mas: the maple state + * + * This method is optimized to check the metadata of a node if the node type + * supports data end metadata. + * + * Return: The zero indexed last slot with data (may be null). + */ +static inline unsigned char mas_data_end(struct ma_state *mas) +{ + enum maple_type type; + struct maple_node *node; + unsigned char offset; + unsigned long *pivots; + + type = mte_node_type(mas->node); + node = mas_mn(mas); + if (type == maple_arange_64) + return ma_meta_end(node, type); + + pivots = ma_pivots(node, type); + offset = mt_pivots[type] - 1; + if (likely(!pivots[offset])) + return ma_meta_end(node, type); + + if (likely(pivots[offset] == mas->max)) + return offset; + + return mt_pivots[type]; +} + +/* + * mas_leaf_max_gap() - Returns the largest gap in a leaf node + * @mas - the maple state + * + * Return: The maximum gap in the leaf. + */ +static unsigned long mas_leaf_max_gap(struct ma_state *mas) +{ + enum maple_type mt; + unsigned long pstart, gap, max_gap; + struct maple_node *mn; + unsigned long *pivots; + void __rcu **slots; + unsigned char i; + unsigned char max_piv; + + mt = mte_node_type(mas->node); + mn = mas_mn(mas); + slots = ma_slots(mn, mt); + max_gap = 0; + if (unlikely(ma_is_dense(mt))) { + gap = 0; + for (i = 0; i < mt_slots[mt]; i++) { + if (slots[i]) { + if (gap > max_gap) + max_gap = gap; + gap = 0; + } else { + gap++; + } + } + if (gap > max_gap) + max_gap = gap; + return max_gap; + } + + /* + * Check the first implied pivot optimizes the loop below and slot 1 may + * be skipped if there is a gap in slot 0. + */ + pivots = ma_pivots(mn, mt); + if (likely(!slots[0])) { + max_gap = pivots[0] - mas->min + 1; + i = 2; + } else { + i = 1; + } + + /* reduce max_piv as the special case is checked before the loop */ + max_piv = ma_data_end(mn, mt, pivots, mas->max) - 1; + /* + * Check end implied pivot which can only be a gap on the right most + * node. + */ + if (unlikely(mas->max == ULONG_MAX) && !slots[max_piv + 1]) { + gap = ULONG_MAX - pivots[max_piv]; + if (gap > max_gap) + max_gap = gap; + } + + for (; i <= max_piv; i++) { + /* data == no gap. */ + if (likely(slots[i])) + continue; + + pstart = pivots[i - 1]; + gap = pivots[i] - pstart; + if (gap > max_gap) + max_gap = gap; + + /* There cannot be two gaps in a row. */ + i++; + } + return max_gap; +} + +/* + * ma_max_gap() - Get the maximum gap in a maple node (non-leaf) + * @node: The maple node + * @gaps: The pointer to the gaps + * @mt: The maple node type + * @*off: Pointer to store the offset location of the gap. + * + * Uses the metadata data end to scan backwards across set gaps. + * + * Return: The maximum gap value + */ +static inline unsigned long +ma_max_gap(struct maple_node *node, unsigned long *gaps, enum maple_type mt, + unsigned char *off) +{ + unsigned char offset, i; + unsigned long max_gap = 0; + + i = offset = ma_meta_end(node, mt); + do { + if (gaps[i] > max_gap) { + max_gap = gaps[i]; + offset = i; + } + } while (i--); + + *off = offset; + return max_gap; +} + +/* + * mas_max_gap() - find the largest gap in a non-leaf node and set the slot. + * @mas: The maple state. + * + * If the metadata gap is set to MAPLE_ARANGE64_META_MAX, there is no gap. + * + * Return: The gap value. + */ +static inline unsigned long mas_max_gap(struct ma_state *mas) +{ + unsigned long *gaps; + unsigned char offset; + enum maple_type mt; + struct maple_node *node; + + mt = mte_node_type(mas->node); + if (ma_is_leaf(mt)) + return mas_leaf_max_gap(mas); + + node = mas_mn(mas); + offset = ma_meta_gap(node, mt); + if (offset == MAPLE_ARANGE64_META_MAX) + return 0; + + gaps = ma_gaps(node, mt); + return gaps[offset]; +} + +/* + * mas_parent_gap() - Set the parent gap and any gaps above, as needed + * @mas: The maple state + * @offset: The gap offset in the parent to set + * @new: The new gap value. + * + * Set the parent gap then continue to set the gap upwards, using the metadata + * of the parent to see if it is necessary to check the node above. + */ +static inline void mas_parent_gap(struct ma_state *mas, unsigned char offset, + unsigned long new) +{ + unsigned long meta_gap = 0; + struct maple_node *pnode; + struct maple_enode *penode; + unsigned long *pgaps; + unsigned char meta_offset; + enum maple_type pmt; + + pnode = mte_parent(mas->node); + pmt = mas_parent_enum(mas, mas->node); + penode = mt_mk_node(pnode, pmt); + pgaps = ma_gaps(pnode, pmt); + +ascend: + meta_offset = ma_meta_gap(pnode, pmt); + if (meta_offset == MAPLE_ARANGE64_META_MAX) + meta_gap = 0; + else + meta_gap = pgaps[meta_offset]; + + pgaps[offset] = new; + + if (meta_gap == new) + return; + + if (offset != meta_offset) { + if (meta_gap > new) + return; + + ma_set_meta_gap(pnode, pmt, offset); + } else if (new < meta_gap) { + meta_offset = 15; + new = ma_max_gap(pnode, pgaps, pmt, &meta_offset); + ma_set_meta_gap(pnode, pmt, meta_offset); + } + + if (ma_is_root(pnode)) + return; + + /* Go to the parent node. */ + pnode = mte_parent(penode); + pmt = mas_parent_enum(mas, penode); + pgaps = ma_gaps(pnode, pmt); + offset = mte_parent_slot(penode); + penode = mt_mk_node(pnode, pmt); + goto ascend; +} + +/* + * mas_update_gap() - Update a nodes gaps and propagate up if necessary. + * @mas - the maple state. + */ +static inline void mas_update_gap(struct ma_state *mas) +{ + unsigned char pslot; + unsigned long p_gap; + unsigned long max_gap; + + if (!mt_is_alloc(mas->tree)) + return; + + if (mte_is_root(mas->node)) + return; + + max_gap = mas_max_gap(mas); + + pslot = mte_parent_slot(mas->node); + p_gap = ma_gaps(mte_parent(mas->node), + mas_parent_enum(mas, mas->node))[pslot]; + + if (p_gap != max_gap) + mas_parent_gap(mas, pslot, max_gap); +} + +/* + * mas_adopt_children() - Set the parent pointer of all nodes in @parent to + * @parent with the slot encoded. + * @mas - the maple state (for the tree) + * @parent - the maple encoded node containing the children. + */ +static inline void mas_adopt_children(struct ma_state *mas, + struct maple_enode *parent) +{ + enum maple_type type = mte_node_type(parent); + struct maple_node *node = mas_mn(mas); + void __rcu **slots = ma_slots(node, type); + unsigned long *pivots = ma_pivots(node, type); + struct maple_enode *child; + unsigned char offset; + + offset = ma_data_end(node, type, pivots, mas->max); + do { + child = mas_slot_locked(mas, slots, offset); + mte_set_parent(child, parent, offset); + } while (offset--); +} + +/* + * mas_replace() - Replace a maple node in the tree with mas->node. Uses the + * parent encoding to locate the maple node in the tree. + * @mas - the ma_state to use for operations. + * @advanced - boolean to adopt the child nodes and free the old node (false) or + * leave the node (true) and handle the adoption and free elsewhere. + */ +static inline void mas_replace(struct ma_state *mas, bool advanced) + __must_hold(mas->tree->lock) +{ + struct maple_node *mn = mas_mn(mas); + struct maple_enode *old_enode; + unsigned char offset = 0; + void __rcu **slots = NULL; + + if (ma_is_root(mn)) { + old_enode = mas_root_locked(mas); + } else { + offset = mte_parent_slot(mas->node); + slots = ma_slots(mte_parent(mas->node), + mas_parent_enum(mas, mas->node)); + old_enode = mas_slot_locked(mas, slots, offset); + } + + if (!advanced && !mte_is_leaf(mas->node)) + mas_adopt_children(mas, mas->node); + + if (mte_is_root(mas->node)) { + mn->parent = ma_parent_ptr( + ((unsigned long)mas->tree | MA_ROOT_PARENT)); + rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node)); + mas_set_height(mas); + } else { + rcu_assign_pointer(slots[offset], mas->node); + } + + if (!advanced) + mas_free(mas, old_enode); +} + +/* + * mas_new_child() - Find the new child of a node. + * @mas: the maple state + * @child: the maple state to store the child. + */ +static inline bool mas_new_child(struct ma_state *mas, struct ma_state *child) + __must_hold(mas->tree->lock) +{ + enum maple_type mt; + unsigned char offset; + unsigned char end; + unsigned long *pivots; + struct maple_enode *entry; + struct maple_node *node; + void __rcu **slots; + + mt = mte_node_type(mas->node); + node = mas_mn(mas); + slots = ma_slots(node, mt); + pivots = ma_pivots(node, mt); + end = ma_data_end(node, mt, pivots, mas->max); + for (offset = mas->offset; offset <= end; offset++) { + entry = mas_slot_locked(mas, slots, offset); + if (mte_parent(entry) == node) { + *child = *mas; + mas->offset = offset + 1; + child->offset = offset; + mas_descend(child); + child->offset = 0; + return true; + } + } + return false; +} + +/* + * mab_shift_right() - Shift the data in mab right. Note, does not clean out the + * old data or set b_node->b_end. + * @b_node: the maple_big_node + * @shift: the shift count + */ +static inline void mab_shift_right(struct maple_big_node *b_node, + unsigned char shift) +{ + unsigned long size = b_node->b_end * sizeof(unsigned long); + + memmove(b_node->pivot + shift, b_node->pivot, size); + memmove(b_node->slot + shift, b_node->slot, size); + if (b_node->type == maple_arange_64) + memmove(b_node->gap + shift, b_node->gap, size); +} + +/* + * mab_middle_node() - Check if a middle node is needed (unlikely) + * @b_node: the maple_big_node that contains the data. + * @size: the amount of data in the b_node + * @split: the potential split location + * @slot_count: the size that can be stored in a single node being considered. + * + * Return: true if a middle node is required. + */ +static inline bool mab_middle_node(struct maple_big_node *b_node, int split, + unsigned char slot_count) +{ + unsigned char size = b_node->b_end; + + if (size >= 2 * slot_count) + return true; + + if (!b_node->slot[split] && (size >= 2 * slot_count - 1)) + return true; + + return false; +} + +/* + * mab_no_null_split() - ensure the split doesn't fall on a NULL + * @b_node: the maple_big_node with the data + * @split: the suggested split location + * @slot_count: the number of slots in the node being considered. + * + * Return: the split location. + */ +static inline int mab_no_null_split(struct maple_big_node *b_node, + unsigned char split, unsigned char slot_count) +{ + if (!b_node->slot[split]) { + /* + * If the split is less than the max slot && the right side will + * still be sufficient, then increment the split on NULL. + */ + if ((split < slot_count - 1) && + (b_node->b_end - split) > (mt_min_slots[b_node->type])) + split++; + else + split--; + } + return split; +} + +/* + * mab_calc_split() - Calculate the split location and if there needs to be two + * splits. + * @bn: The maple_big_node with the data + * @mid_split: The second split, if required. 0 otherwise. + * + * Return: The first split location. The middle split is set in @mid_split. + */ +static inline int mab_calc_split(struct ma_state *mas, + struct maple_big_node *bn, unsigned char *mid_split, unsigned long min) +{ + unsigned char b_end = bn->b_end; + int split = b_end / 2; /* Assume equal split. */ + unsigned char slot_min, slot_count = mt_slots[bn->type]; + + /* + * To support gap tracking, all NULL entries are kept together and a node cannot + * end on a NULL entry, with the exception of the left-most leaf. The + * limitation means that the split of a node must be checked for this condition + * and be able to put more data in one direction or the other. + */ + if (unlikely((mas->mas_flags & MA_STATE_BULK))) { + *mid_split = 0; + split = b_end - mt_min_slots[bn->type]; + + if (!ma_is_leaf(bn->type)) + return split; + + mas->mas_flags |= MA_STATE_REBALANCE; + if (!bn->slot[split]) + split--; + return split; + } + + /* + * Although extremely rare, it is possible to enter what is known as the 3-way + * split scenario. The 3-way split comes about by means of a store of a range + * that overwrites the end and beginning of two full nodes. The result is a set + * of entries that cannot be stored in 2 nodes. Sometimes, these two nodes can + * also be located in different parent nodes which are also full. This can + * carry upwards all the way to the root in the worst case. + */ + if (unlikely(mab_middle_node(bn, split, slot_count))) { + split = b_end / 3; + *mid_split = split * 2; + } else { + slot_min = mt_min_slots[bn->type]; + + *mid_split = 0; + /* + * Avoid having a range less than the slot count unless it + * causes one node to be deficient. + * NOTE: mt_min_slots is 1 based, b_end and split are zero. + */ + while (((bn->pivot[split] - min) < slot_count - 1) && + (split < slot_count - 1) && (b_end - split > slot_min)) + split++; + } + + /* Avoid ending a node on a NULL entry */ + split = mab_no_null_split(bn, split, slot_count); + if (!(*mid_split)) + return split; + + *mid_split = mab_no_null_split(bn, *mid_split, slot_count); + + return split; +} + +/* + * mas_mab_cp() - Copy data from a maple state inclusively to a maple_big_node + * and set @b_node->b_end to the next free slot. + * @mas: The maple state + * @mas_start: The starting slot to copy + * @mas_end: The end slot to copy (inclusively) + * @b_node: The maple_big_node to place the data + * @mab_start: The starting location in maple_big_node to store the data. + */ +static inline void mas_mab_cp(struct ma_state *mas, unsigned char mas_start, + unsigned char mas_end, struct maple_big_node *b_node, + unsigned char mab_start) +{ + enum maple_type mt; + struct maple_node *node; + void __rcu **slots; + unsigned long *pivots, *gaps; + int i = mas_start, j = mab_start; + unsigned char piv_end; + + node = mas_mn(mas); + mt = mte_node_type(mas->node); + pivots = ma_pivots(node, mt); + if (!i) { + b_node->pivot[j] = pivots[i++]; + if (unlikely(i > mas_end)) + goto complete; + j++; + } + + piv_end = min(mas_end, mt_pivots[mt]); + for (; i < piv_end; i++, j++) { + b_node->pivot[j] = pivots[i]; + if (unlikely(!b_node->pivot[j])) + break; + + if (unlikely(mas->max == b_node->pivot[j])) + goto complete; + } + + if (likely(i <= mas_end)) + b_node->pivot[j] = mas_safe_pivot(mas, pivots, i, mt); + +complete: + b_node->b_end = ++j; + j -= mab_start; + slots = ma_slots(node, mt); + memcpy(b_node->slot + mab_start, slots + mas_start, sizeof(void *) * j); + if (!ma_is_leaf(mt) && mt_is_alloc(mas->tree)) { + gaps = ma_gaps(node, mt); + memcpy(b_node->gap + mab_start, gaps + mas_start, + sizeof(unsigned long) * j); + } +} + +/* + * mas_leaf_set_meta() - Set the metadata of a leaf if possible. + * @mas: The maple state + * @node: The maple node + * @pivots: pointer to the maple node pivots + * @mt: The maple type + * @end: The assumed end + * + * Note, end may be incremented within this function but not modified at the + * source. This is fine since the metadata is the last thing to be stored in a + * node during a write. + */ +static inline void mas_leaf_set_meta(struct ma_state *mas, + struct maple_node *node, unsigned long *pivots, + enum maple_type mt, unsigned char end) +{ + /* There is no room for metadata already */ + if (mt_pivots[mt] <= end) + return; + + if (pivots[end] && pivots[end] < mas->max) + end++; + + if (end < mt_slots[mt] - 1) + ma_set_meta(node, mt, 0, end); +} + +/* + * mab_mas_cp() - Copy data from maple_big_node to a maple encoded node. + * @b_node: the maple_big_node that has the data + * @mab_start: the start location in @b_node. + * @mab_end: The end location in @b_node (inclusively) + * @mas: The maple state with the maple encoded node. + */ +static inline void mab_mas_cp(struct maple_big_node *b_node, + unsigned char mab_start, unsigned char mab_end, + struct ma_state *mas, bool new_max) +{ + int i, j = 0; + enum maple_type mt = mte_node_type(mas->node); + struct maple_node *node = mte_to_node(mas->node); + void __rcu **slots = ma_slots(node, mt); + unsigned long *pivots = ma_pivots(node, mt); + unsigned long *gaps = NULL; + unsigned char end; + + if (mab_end - mab_start > mt_pivots[mt]) + mab_end--; + + if (!pivots[mt_pivots[mt] - 1]) + slots[mt_pivots[mt]] = NULL; + + i = mab_start; + do { + pivots[j++] = b_node->pivot[i++]; + } while (i <= mab_end && likely(b_node->pivot[i])); + + memcpy(slots, b_node->slot + mab_start, + sizeof(void *) * (i - mab_start)); + + if (new_max) + mas->max = b_node->pivot[i - 1]; + + end = j - 1; + if (likely(!ma_is_leaf(mt) && mt_is_alloc(mas->tree))) { + unsigned long max_gap = 0; + unsigned char offset = 15; + + gaps = ma_gaps(node, mt); + do { + gaps[--j] = b_node->gap[--i]; + if (gaps[j] > max_gap) { + offset = j; + max_gap = gaps[j]; + } + } while (j); + + ma_set_meta(node, mt, offset, end); + } else { + mas_leaf_set_meta(mas, node, pivots, mt, end); + } +} + +/* + * mas_descend_adopt() - Descend through a sub-tree and adopt children. + * @mas: the maple state with the maple encoded node of the sub-tree. + * + * Descend through a sub-tree and adopt children who do not have the correct + * parents set. Follow the parents which have the correct parents as they are + * the new entries which need to be followed to find other incorrectly set + * parents. + */ +static inline void mas_descend_adopt(struct ma_state *mas) +{ + struct ma_state list[3], next[3]; + int i, n; + + /* + * At each level there may be up to 3 correct parent pointers which indicates + * the new nodes which need to be walked to find any new nodes at a lower level. + */ + + for (i = 0; i < 3; i++) { + list[i] = *mas; + list[i].offset = 0; + next[i].offset = 0; + } + next[0] = *mas; + + while (!mte_is_leaf(list[0].node)) { + n = 0; + for (i = 0; i < 3; i++) { + if (mas_is_none(&list[i])) + continue; + + if (i && list[i-1].node == list[i].node) + continue; + + while ((n < 3) && (mas_new_child(&list[i], &next[n]))) + n++; + + mas_adopt_children(&list[i], list[i].node); + } + + while (n < 3) + next[n++].node = MAS_NONE; + + /* descend by setting the list to the children */ + for (i = 0; i < 3; i++) + list[i] = next[i]; + } +} + +/* + * mas_bulk_rebalance() - Rebalance the end of a tree after a bulk insert. + * @mas: The maple state + * @end: The maple node end + * @mt: The maple node type + */ +static inline void mas_bulk_rebalance(struct ma_state *mas, unsigned char end, + enum maple_type mt) +{ + if (!(mas->mas_flags & MA_STATE_BULK)) + return; + + if (mte_is_root(mas->node)) + return; + + if (end > mt_min_slots[mt]) { + mas->mas_flags &= ~MA_STATE_REBALANCE; + return; + } +} + +/* + * mas_store_b_node() - Store an @entry into the b_node while also copying the + * data from a maple encoded node. + * @wr_mas: the maple write state + * @b_node: the maple_big_node to fill with data + * @offset_end: the offset to end copying + * + * Return: The actual end of the data stored in @b_node + */ +static inline void mas_store_b_node(struct ma_wr_state *wr_mas, + struct maple_big_node *b_node, unsigned char offset_end) +{ + unsigned char slot; + unsigned char b_end; + /* Possible underflow of piv will wrap back to 0 before use. */ + unsigned long piv; + struct ma_state *mas = wr_mas->mas; + + b_node->type = wr_mas->type; + b_end = 0; + slot = mas->offset; + if (slot) { + /* Copy start data up to insert. */ + mas_mab_cp(mas, 0, slot - 1, b_node, 0); + b_end = b_node->b_end; + piv = b_node->pivot[b_end - 1]; + } else + piv = mas->min - 1; + + if (piv + 1 < mas->index) { + /* Handle range starting after old range */ + b_node->slot[b_end] = wr_mas->content; + if (!wr_mas->content) + b_node->gap[b_end] = mas->index - 1 - piv; + b_node->pivot[b_end++] = mas->index - 1; + } + + /* Store the new entry. */ + mas->offset = b_end; + b_node->slot[b_end] = wr_mas->entry; + b_node->pivot[b_end] = mas->last; + + /* Appended. */ + if (mas->last >= mas->max) + goto b_end; + + /* Handle new range ending before old range ends */ + piv = mas_logical_pivot(mas, wr_mas->pivots, offset_end, wr_mas->type); + if (piv > mas->last) { + if (piv == ULONG_MAX) + mas_bulk_rebalance(mas, b_node->b_end, wr_mas->type); + + if (offset_end != slot) + wr_mas->content = mas_slot_locked(mas, wr_mas->slots, + offset_end); + + b_node->slot[++b_end] = wr_mas->content; + if (!wr_mas->content) + b_node->gap[b_end] = piv - mas->last + 1; + b_node->pivot[b_end] = piv; + } + + slot = offset_end + 1; + if (slot > wr_mas->node_end) + goto b_end; + + /* Copy end data to the end of the node. */ + mas_mab_cp(mas, slot, wr_mas->node_end + 1, b_node, ++b_end); + b_node->b_end--; + return; + +b_end: + b_node->b_end = b_end; +} + +/* + * mas_prev_sibling() - Find the previous node with the same parent. + * @mas: the maple state + * + * Return: True if there is a previous sibling, false otherwise. + */ +static inline bool mas_prev_sibling(struct ma_state *mas) +{ + unsigned int p_slot = mte_parent_slot(mas->node); + + if (mte_is_root(mas->node)) + return false; + + if (!p_slot) + return false; + + mas_ascend(mas); + mas->offset = p_slot - 1; + mas_descend(mas); + return true; +} + +/* + * mas_next_sibling() - Find the next node with the same parent. + * @mas: the maple state + * + * Return: true if there is a next sibling, false otherwise. + */ +static inline bool mas_next_sibling(struct ma_state *mas) +{ + MA_STATE(parent, mas->tree, mas->index, mas->last); + + if (mte_is_root(mas->node)) + return false; + + parent = *mas; + mas_ascend(&parent); + parent.offset = mte_parent_slot(mas->node) + 1; + if (parent.offset > mas_data_end(&parent)) + return false; + + *mas = parent; + mas_descend(mas); + return true; +} + +/* + * mte_node_or_node() - Return the encoded node or MAS_NONE. + * @enode: The encoded maple node. + * + * Shorthand to avoid setting %NULLs in the tree or maple_subtree_state. + * + * Return: @enode or MAS_NONE + */ +static inline struct maple_enode *mte_node_or_none(struct maple_enode *enode) +{ + if (enode) + return enode; + + return ma_enode_ptr(MAS_NONE); +} + +/* + * mas_wr_node_walk() - Find the correct offset for the index in the @mas. + * @wr_mas: The maple write state + * + * Uses mas_slot_locked() and does not need to worry about dead nodes. + */ +static inline void mas_wr_node_walk(struct ma_wr_state *wr_mas) +{ + struct ma_state *mas = wr_mas->mas; + unsigned char count; + unsigned char offset; + unsigned long index, min, max; + + if (unlikely(ma_is_dense(wr_mas->type))) { + wr_mas->r_max = wr_mas->r_min = mas->index; + mas->offset = mas->index = mas->min; + return; + } + + wr_mas->node = mas_mn(wr_mas->mas); + wr_mas->pivots = ma_pivots(wr_mas->node, wr_mas->type); + count = wr_mas->node_end = ma_data_end(wr_mas->node, wr_mas->type, + wr_mas->pivots, mas->max); + offset = mas->offset; + min = mas_safe_min(mas, wr_mas->pivots, offset); + if (unlikely(offset == count)) + goto max; + + max = wr_mas->pivots[offset]; + index = mas->index; + if (unlikely(index <= max)) + goto done; + + if (unlikely(!max && offset)) + goto max; + + min = max + 1; + while (++offset < count) { + max = wr_mas->pivots[offset]; + if (index <= max) + goto done; + else if (unlikely(!max)) + break; + + min = max + 1; + } + +max: + max = mas->max; +done: + wr_mas->r_max = max; + wr_mas->r_min = min; + wr_mas->offset_end = mas->offset = offset; +} + +/* + * mas_topiary_range() - Add a range of slots to the topiary. + * @mas: The maple state + * @destroy: The topiary to add the slots (usually destroy) + * @start: The starting slot inclusively + * @end: The end slot inclusively + */ +static inline void mas_topiary_range(struct ma_state *mas, + struct ma_topiary *destroy, unsigned char start, unsigned char end) +{ + void __rcu **slots; + unsigned char offset; + + MT_BUG_ON(mas->tree, mte_is_leaf(mas->node)); + slots = ma_slots(mas_mn(mas), mte_node_type(mas->node)); + for (offset = start; offset <= end; offset++) { + struct maple_enode *enode = mas_slot_locked(mas, slots, offset); + + if (mte_dead_node(enode)) + continue; + + mat_add(destroy, enode); + } +} + +/* + * mast_topiary() - Add the portions of the tree to the removal list; either to + * be freed or discarded (destroy walk). + * @mast: The maple_subtree_state. + */ +static inline void mast_topiary(struct maple_subtree_state *mast) +{ + MA_WR_STATE(wr_mas, mast->orig_l, NULL); + unsigned char r_start, r_end; + unsigned char l_start, l_end; + void __rcu **l_slots, **r_slots; + + wr_mas.type = mte_node_type(mast->orig_l->node); + mast->orig_l->index = mast->orig_l->last; + mas_wr_node_walk(&wr_mas); + l_start = mast->orig_l->offset + 1; + l_end = mas_data_end(mast->orig_l); + r_start = 0; + r_end = mast->orig_r->offset; + + if (r_end) + r_end--; + + l_slots = ma_slots(mas_mn(mast->orig_l), + mte_node_type(mast->orig_l->node)); + + r_slots = ma_slots(mas_mn(mast->orig_r), + mte_node_type(mast->orig_r->node)); + + if ((l_start < l_end) && + mte_dead_node(mas_slot_locked(mast->orig_l, l_slots, l_start))) { + l_start++; + } + + if (mte_dead_node(mas_slot_locked(mast->orig_r, r_slots, r_end))) { + if (r_end) + r_end--; + } + + if ((l_start > r_end) && (mast->orig_l->node == mast->orig_r->node)) + return; + + /* At the node where left and right sides meet, add the parts between */ + if (mast->orig_l->node == mast->orig_r->node) { + return mas_topiary_range(mast->orig_l, mast->destroy, + l_start, r_end); + } + + /* mast->orig_r is different and consumed. */ + if (mte_is_leaf(mast->orig_r->node)) + return; + + if (mte_dead_node(mas_slot_locked(mast->orig_l, l_slots, l_end))) + l_end--; + + + if (l_start <= l_end) + mas_topiary_range(mast->orig_l, mast->destroy, l_start, l_end); + + if (mte_dead_node(mas_slot_locked(mast->orig_r, r_slots, r_start))) + r_start++; + + if (r_start <= r_end) + mas_topiary_range(mast->orig_r, mast->destroy, 0, r_end); +} + +/* + * mast_rebalance_next() - Rebalance against the next node + * @mast: The maple subtree state + * @old_r: The encoded maple node to the right (next node). + */ +static inline void mast_rebalance_next(struct maple_subtree_state *mast) +{ + unsigned char b_end = mast->bn->b_end; + + mas_mab_cp(mast->orig_r, 0, mt_slot_count(mast->orig_r->node), + mast->bn, b_end); + mast->orig_r->last = mast->orig_r->max; +} + +/* + * mast_rebalance_prev() - Rebalance against the previous node + * @mast: The maple subtree state + * @old_l: The encoded maple node to the left (previous node) + */ +static inline void mast_rebalance_prev(struct maple_subtree_state *mast) +{ + unsigned char end = mas_data_end(mast->orig_l) + 1; + unsigned char b_end = mast->bn->b_end; + + mab_shift_right(mast->bn, end); + mas_mab_cp(mast->orig_l, 0, end - 1, mast->bn, 0); + mast->l->min = mast->orig_l->min; + mast->orig_l->index = mast->orig_l->min; + mast->bn->b_end = end + b_end; + mast->l->offset += end; +} + +/* + * mast_spanning_rebalance() - Rebalance nodes with nearest neighbour favouring + * the node to the right. Checking the nodes to the right then the left at each + * level upwards until root is reached. Free and destroy as needed. + * Data is copied into the @mast->bn. + * @mast: The maple_subtree_state. + */ +static inline +bool mast_spanning_rebalance(struct maple_subtree_state *mast) +{ + struct ma_state r_tmp = *mast->orig_r; + struct ma_state l_tmp = *mast->orig_l; + struct maple_enode *ancestor = NULL; + unsigned char start, end; + unsigned char depth = 0; + + r_tmp = *mast->orig_r; + l_tmp = *mast->orig_l; + do { + mas_ascend(mast->orig_r); + mas_ascend(mast->orig_l); + depth++; + if (!ancestor && + (mast->orig_r->node == mast->orig_l->node)) { + ancestor = mast->orig_r->node; + end = mast->orig_r->offset - 1; + start = mast->orig_l->offset + 1; + } + + if (mast->orig_r->offset < mas_data_end(mast->orig_r)) { + if (!ancestor) { + ancestor = mast->orig_r->node; + start = 0; + } + + mast->orig_r->offset++; + do { + mas_descend(mast->orig_r); + mast->orig_r->offset = 0; + depth--; + } while (depth); + + mast_rebalance_next(mast); + do { + unsigned char l_off = 0; + struct maple_enode *child = r_tmp.node; + + mas_ascend(&r_tmp); + if (ancestor == r_tmp.node) + l_off = start; + + if (r_tmp.offset) + r_tmp.offset--; + + if (l_off < r_tmp.offset) + mas_topiary_range(&r_tmp, mast->destroy, + l_off, r_tmp.offset); + + if (l_tmp.node != child) + mat_add(mast->free, child); + + } while (r_tmp.node != ancestor); + + *mast->orig_l = l_tmp; + return true; + + } else if (mast->orig_l->offset != 0) { + if (!ancestor) { + ancestor = mast->orig_l->node; + end = mas_data_end(mast->orig_l); + } + + mast->orig_l->offset--; + do { + mas_descend(mast->orig_l); + mast->orig_l->offset = + mas_data_end(mast->orig_l); + depth--; + } while (depth); + + mast_rebalance_prev(mast); + do { + unsigned char r_off; + struct maple_enode *child = l_tmp.node; + + mas_ascend(&l_tmp); + if (ancestor == l_tmp.node) + r_off = end; + else + r_off = mas_data_end(&l_tmp); + + if (l_tmp.offset < r_off) + l_tmp.offset++; + + if (l_tmp.offset < r_off) + mas_topiary_range(&l_tmp, mast->destroy, + l_tmp.offset, r_off); + + if (r_tmp.node != child) + mat_add(mast->free, child); + + } while (l_tmp.node != ancestor); + + *mast->orig_r = r_tmp; + return true; + } + } while (!mte_is_root(mast->orig_r->node)); + + *mast->orig_r = r_tmp; + *mast->orig_l = l_tmp; + return false; +} + +/* + * mast_ascend_free() - Add current original maple state nodes to the free list + * and ascend. + * @mast: the maple subtree state. + * + * Ascend the original left and right sides and add the previous nodes to the + * free list. Set the slots to point to the correct location in the new nodes. + */ +static inline void +mast_ascend_free(struct maple_subtree_state *mast) +{ + MA_WR_STATE(wr_mas, mast->orig_r, NULL); + struct maple_enode *left = mast->orig_l->node; + struct maple_enode *right = mast->orig_r->node; + + mas_ascend(mast->orig_l); + mas_ascend(mast->orig_r); + mat_add(mast->free, left); + + if (left != right) + mat_add(mast->free, right); + + mast->orig_r->offset = 0; + mast->orig_r->index = mast->r->max; + /* last should be larger than or equal to index */ + if (mast->orig_r->last < mast->orig_r->index) + mast->orig_r->last = mast->orig_r->index; + /* + * The node may not contain the value so set slot to ensure all + * of the nodes contents are freed or destroyed. + */ + wr_mas.type = mte_node_type(mast->orig_r->node); + mas_wr_node_walk(&wr_mas); + /* Set up the left side of things */ + mast->orig_l->offset = 0; + mast->orig_l->index = mast->l->min; + wr_mas.mas = mast->orig_l; + wr_mas.type = mte_node_type(mast->orig_l->node); + mas_wr_node_walk(&wr_mas); + + mast->bn->type = wr_mas.type; +} + +/* + * mas_new_ma_node() - Create and return a new maple node. Helper function. + * @mas: the maple state with the allocations. + * @b_node: the maple_big_node with the type encoding. + * + * Use the node type from the maple_big_node to allocate a new node from the + * ma_state. This function exists mainly for code readability. + * + * Return: A new maple encoded node + */ +static inline struct maple_enode +*mas_new_ma_node(struct ma_state *mas, struct maple_big_node *b_node) +{ + return mt_mk_node(ma_mnode_ptr(mas_pop_node(mas)), b_node->type); +} + +/* + * mas_mab_to_node() - Set up right and middle nodes + * + * @mas: the maple state that contains the allocations. + * @b_node: the node which contains the data. + * @left: The pointer which will have the left node + * @right: The pointer which may have the right node + * @middle: the pointer which may have the middle node (rare) + * @mid_split: the split location for the middle node + * + * Return: the split of left. + */ +static inline unsigned char mas_mab_to_node(struct ma_state *mas, + struct maple_big_node *b_node, struct maple_enode **left, + struct maple_enode **right, struct maple_enode **middle, + unsigned char *mid_split, unsigned long min) +{ + unsigned char split = 0; + unsigned char slot_count = mt_slots[b_node->type]; + + *left = mas_new_ma_node(mas, b_node); + *right = NULL; + *middle = NULL; + *mid_split = 0; + + if (b_node->b_end < slot_count) { + split = b_node->b_end; + } else { + split = mab_calc_split(mas, b_node, mid_split, min); + *right = mas_new_ma_node(mas, b_node); + } + + if (*mid_split) + *middle = mas_new_ma_node(mas, b_node); + + return split; + +} + +/* + * mab_set_b_end() - Add entry to b_node at b_node->b_end and increment the end + * pointer. + * @b_node - the big node to add the entry + * @mas - the maple state to get the pivot (mas->max) + * @entry - the entry to add, if NULL nothing happens. + */ +static inline void mab_set_b_end(struct maple_big_node *b_node, + struct ma_state *mas, + void *entry) +{ + if (!entry) + return; + + b_node->slot[b_node->b_end] = entry; + if (mt_is_alloc(mas->tree)) + b_node->gap[b_node->b_end] = mas_max_gap(mas); + b_node->pivot[b_node->b_end++] = mas->max; +} + +/* + * mas_set_split_parent() - combine_then_separate helper function. Sets the parent + * of @mas->node to either @left or @right, depending on @slot and @split + * + * @mas - the maple state with the node that needs a parent + * @left - possible parent 1 + * @right - possible parent 2 + * @slot - the slot the mas->node was placed + * @split - the split location between @left and @right + */ +static inline void mas_set_split_parent(struct ma_state *mas, + struct maple_enode *left, + struct maple_enode *right, + unsigned char *slot, unsigned char split) +{ + if (mas_is_none(mas)) + return; + + if ((*slot) <= split) + mte_set_parent(mas->node, left, *slot); + else if (right) + mte_set_parent(mas->node, right, (*slot) - split - 1); + + (*slot)++; +} + +/* + * mte_mid_split_check() - Check if the next node passes the mid-split + * @**l: Pointer to left encoded maple node. + * @**m: Pointer to middle encoded maple node. + * @**r: Pointer to right encoded maple node. + * @slot: The offset + * @*split: The split location. + * @mid_split: The middle split. + */ +static inline void mte_mid_split_check(struct maple_enode **l, + struct maple_enode **r, + struct maple_enode *right, + unsigned char slot, + unsigned char *split, + unsigned char mid_split) +{ + if (*r == right) + return; + + if (slot < mid_split) + return; + + *l = *r; + *r = right; + *split = mid_split; +} + +/* + * mast_set_split_parents() - Helper function to set three nodes parents. Slot + * is taken from @mast->l. + * @mast - the maple subtree state + * @left - the left node + * @right - the right node + * @split - the split location. + */ +static inline void mast_set_split_parents(struct maple_subtree_state *mast, + struct maple_enode *left, + struct maple_enode *middle, + struct maple_enode *right, + unsigned char split, + unsigned char mid_split) +{ + unsigned char slot; + struct maple_enode *l = left; + struct maple_enode *r = right; + + if (mas_is_none(mast->l)) + return; + + if (middle) + r = middle; + + slot = mast->l->offset; + + mte_mid_split_check(&l, &r, right, slot, &split, mid_split); + mas_set_split_parent(mast->l, l, r, &slot, split); + + mte_mid_split_check(&l, &r, right, slot, &split, mid_split); + mas_set_split_parent(mast->m, l, r, &slot, split); + + mte_mid_split_check(&l, &r, right, slot, &split, mid_split); + mas_set_split_parent(mast->r, l, r, &slot, split); +} + +/* + * mas_wmb_replace() - Write memory barrier and replace + * @mas: The maple state + * @free: the maple topiary list of nodes to free + * @destroy: The maple topiary list of nodes to destroy (walk and free) + * + * Updates gap as necessary. + */ +static inline void mas_wmb_replace(struct ma_state *mas, + struct ma_topiary *free, + struct ma_topiary *destroy) +{ + /* All nodes must see old data as dead prior to replacing that data */ + smp_wmb(); /* Needed for RCU */ + + /* Insert the new data in the tree */ + mas_replace(mas, true); + + if (!mte_is_leaf(mas->node)) + mas_descend_adopt(mas); + + mas_mat_free(mas, free); + + if (destroy) + mas_mat_destroy(mas, destroy); + + if (mte_is_leaf(mas->node)) + return; + + mas_update_gap(mas); +} + +/* + * mast_new_root() - Set a new tree root during subtree creation + * @mast: The maple subtree state + * @mas: The maple state + */ +static inline void mast_new_root(struct maple_subtree_state *mast, + struct ma_state *mas) +{ + mas_mn(mast->l)->parent = + ma_parent_ptr(((unsigned long)mas->tree | MA_ROOT_PARENT)); + if (!mte_dead_node(mast->orig_l->node) && + !mte_is_root(mast->orig_l->node)) { + do { + mast_ascend_free(mast); + mast_topiary(mast); + } while (!mte_is_root(mast->orig_l->node)); + } + if ((mast->orig_l->node != mas->node) && + (mast->l->depth > mas_mt_height(mas))) { + mat_add(mast->free, mas->node); + } +} + +/* + * mast_cp_to_nodes() - Copy data out to nodes. + * @mast: The maple subtree state + * @left: The left encoded maple node + * @middle: The middle encoded maple node + * @right: The right encoded maple node + * @split: The location to split between left and (middle ? middle : right) + * @mid_split: The location to split between middle and right. + */ +static inline void mast_cp_to_nodes(struct maple_subtree_state *mast, + struct maple_enode *left, struct maple_enode *middle, + struct maple_enode *right, unsigned char split, unsigned char mid_split) +{ + bool new_lmax = true; + + mast->l->node = mte_node_or_none(left); + mast->m->node = mte_node_or_none(middle); + mast->r->node = mte_node_or_none(right); + + mast->l->min = mast->orig_l->min; + if (split == mast->bn->b_end) { + mast->l->max = mast->orig_r->max; + new_lmax = false; + } + + mab_mas_cp(mast->bn, 0, split, mast->l, new_lmax); + + if (middle) { + mab_mas_cp(mast->bn, 1 + split, mid_split, mast->m, true); + mast->m->min = mast->bn->pivot[split] + 1; + split = mid_split; + } + + mast->r->max = mast->orig_r->max; + if (right) { + mab_mas_cp(mast->bn, 1 + split, mast->bn->b_end, mast->r, false); + mast->r->min = mast->bn->pivot[split] + 1; + } +} + +/* + * mast_combine_cp_left - Copy in the original left side of the tree into the + * combined data set in the maple subtree state big node. + * @mast: The maple subtree state + */ +static inline void mast_combine_cp_left(struct maple_subtree_state *mast) +{ + unsigned char l_slot = mast->orig_l->offset; + + if (!l_slot) + return; + + mas_mab_cp(mast->orig_l, 0, l_slot - 1, mast->bn, 0); +} + +/* + * mast_combine_cp_right: Copy in the original right side of the tree into the + * combined data set in the maple subtree state big node. + * @mast: The maple subtree state + */ +static inline void mast_combine_cp_right(struct maple_subtree_state *mast) +{ + if (mast->bn->pivot[mast->bn->b_end - 1] >= mast->orig_r->max) + return; + + mas_mab_cp(mast->orig_r, mast->orig_r->offset + 1, + mt_slot_count(mast->orig_r->node), mast->bn, + mast->bn->b_end); + mast->orig_r->last = mast->orig_r->max; +} + +/* + * mast_sufficient: Check if the maple subtree state has enough data in the big + * node to create at least one sufficient node + * @mast: the maple subtree state + */ +static inline bool mast_sufficient(struct maple_subtree_state *mast) +{ + if (mast->bn->b_end > mt_min_slot_count(mast->orig_l->node)) + return true; + + return false; +} + +/* + * mast_overflow: Check if there is too much data in the subtree state for a + * single node. + * @mast: The maple subtree state + */ +static inline bool mast_overflow(struct maple_subtree_state *mast) +{ + if (mast->bn->b_end >= mt_slot_count(mast->orig_l->node)) + return true; + + return false; +} + +static inline void *mtree_range_walk(struct ma_state *mas) +{ + unsigned long *pivots; + unsigned char offset; + struct maple_node *node; + struct maple_enode *next, *last; + enum maple_type type; + void __rcu **slots; + unsigned char end; + unsigned long max, min; + unsigned long prev_max, prev_min; + + next = mas->node; + min = mas->min; + max = mas->max; + do { + offset = 0; + last = next; + node = mte_to_node(next); + type = mte_node_type(next); + pivots = ma_pivots(node, type); + end = ma_data_end(node, type, pivots, max); + if (unlikely(ma_dead_node(node))) + goto dead_node; + + if (pivots[offset] >= mas->index) { + prev_max = max; + prev_min = min; + max = pivots[offset]; + goto next; + } + + do { + offset++; + } while ((offset < end) && (pivots[offset] < mas->index)); + + prev_min = min; + min = pivots[offset - 1] + 1; + prev_max = max; + if (likely(offset < end && pivots[offset])) + max = pivots[offset]; + +next: + slots = ma_slots(node, type); + next = mt_slot(mas->tree, slots, offset); + if (unlikely(ma_dead_node(node))) + goto dead_node; + } while (!ma_is_leaf(type)); + + mas->offset = offset; + mas->index = min; + mas->last = max; + mas->min = prev_min; + mas->max = prev_max; + mas->node = last; + return (void *) next; + +dead_node: + mas_reset(mas); + return NULL; +} + +/* + * mas_spanning_rebalance() - Rebalance across two nodes which may not be peers. + * @mas: The starting maple state + * @mast: The maple_subtree_state, keeps track of 4 maple states. + * @count: The estimated count of iterations needed. + * + * Follow the tree upwards from @l_mas and @r_mas for @count, or until the root + * is hit. First @b_node is split into two entries which are inserted into the + * next iteration of the loop. @b_node is returned populated with the final + * iteration. @mas is used to obtain allocations. orig_l_mas keeps track of the + * nodes that will remain active by using orig_l_mas->index and orig_l_mas->last + * to account of what has been copied into the new sub-tree. The update of + * orig_l_mas->last is used in mas_consume to find the slots that will need to + * be either freed or destroyed. orig_l_mas->depth keeps track of the height of + * the new sub-tree in case the sub-tree becomes the full tree. + * + * Return: the number of elements in b_node during the last loop. + */ +static int mas_spanning_rebalance(struct ma_state *mas, + struct maple_subtree_state *mast, unsigned char count) +{ + unsigned char split, mid_split; + unsigned char slot = 0; + struct maple_enode *left = NULL, *middle = NULL, *right = NULL; + + MA_STATE(l_mas, mas->tree, mas->index, mas->index); + MA_STATE(r_mas, mas->tree, mas->index, mas->last); + MA_STATE(m_mas, mas->tree, mas->index, mas->index); + MA_TOPIARY(free, mas->tree); + MA_TOPIARY(destroy, mas->tree); + + /* + * The tree needs to be rebalanced and leaves need to be kept at the same level. + * Rebalancing is done by use of the ``struct maple_topiary``. + */ + mast->l = &l_mas; + mast->m = &m_mas; + mast->r = &r_mas; + mast->free = &free; + mast->destroy = &destroy; + l_mas.node = r_mas.node = m_mas.node = MAS_NONE; + if (!(mast->orig_l->min && mast->orig_r->max == ULONG_MAX) && + unlikely(mast->bn->b_end <= mt_min_slots[mast->bn->type])) + mast_spanning_rebalance(mast); + + mast->orig_l->depth = 0; + + /* + * Each level of the tree is examined and balanced, pushing data to the left or + * right, or rebalancing against left or right nodes is employed to avoid + * rippling up the tree to limit the amount of churn. Once a new sub-section of + * the tree is created, there may be a mix of new and old nodes. The old nodes + * will have the incorrect parent pointers and currently be in two trees: the + * original tree and the partially new tree. To remedy the parent pointers in + * the old tree, the new data is swapped into the active tree and a walk down + * the tree is performed and the parent pointers are updated. + * See mas_descend_adopt() for more information.. + */ + while (count--) { + mast->bn->b_end--; + mast->bn->type = mte_node_type(mast->orig_l->node); + split = mas_mab_to_node(mas, mast->bn, &left, &right, &middle, + &mid_split, mast->orig_l->min); + mast_set_split_parents(mast, left, middle, right, split, + mid_split); + mast_cp_to_nodes(mast, left, middle, right, split, mid_split); + + /* + * Copy data from next level in the tree to mast->bn from next + * iteration + */ + memset(mast->bn, 0, sizeof(struct maple_big_node)); + mast->bn->type = mte_node_type(left); + mast->orig_l->depth++; + + /* Root already stored in l->node. */ + if (mas_is_root_limits(mast->l)) + goto new_root; + + mast_ascend_free(mast); + mast_combine_cp_left(mast); + l_mas.offset = mast->bn->b_end; + mab_set_b_end(mast->bn, &l_mas, left); + mab_set_b_end(mast->bn, &m_mas, middle); + mab_set_b_end(mast->bn, &r_mas, right); + + /* Copy anything necessary out of the right node. */ + mast_combine_cp_right(mast); + mast_topiary(mast); + mast->orig_l->last = mast->orig_l->max; + + if (mast_sufficient(mast)) + continue; + + if (mast_overflow(mast)) + continue; + + /* May be a new root stored in mast->bn */ + if (mas_is_root_limits(mast->orig_l)) + break; + + mast_spanning_rebalance(mast); + + /* rebalancing from other nodes may require another loop. */ + if (!count) + count++; + } + + l_mas.node = mt_mk_node(ma_mnode_ptr(mas_pop_node(mas)), + mte_node_type(mast->orig_l->node)); + mast->orig_l->depth++; + mab_mas_cp(mast->bn, 0, mt_slots[mast->bn->type] - 1, &l_mas, true); + mte_set_parent(left, l_mas.node, slot); + if (middle) + mte_set_parent(middle, l_mas.node, ++slot); + + if (right) + mte_set_parent(right, l_mas.node, ++slot); + + if (mas_is_root_limits(mast->l)) { +new_root: + mast_new_root(mast, mas); + } else { + mas_mn(&l_mas)->parent = mas_mn(mast->orig_l)->parent; + } + + if (!mte_dead_node(mast->orig_l->node)) + mat_add(&free, mast->orig_l->node); + + mas->depth = mast->orig_l->depth; + *mast->orig_l = l_mas; + mte_set_node_dead(mas->node); + + /* Set up mas for insertion. */ + mast->orig_l->depth = mas->depth; + mast->orig_l->alloc = mas->alloc; + *mas = *mast->orig_l; + mas_wmb_replace(mas, &free, &destroy); + mtree_range_walk(mas); + return mast->bn->b_end; +} + +/* + * mas_rebalance() - Rebalance a given node. + * @mas: The maple state + * @b_node: The big maple node. + * + * Rebalance two nodes into a single node or two new nodes that are sufficient. + * Continue upwards until tree is sufficient. + * + * Return: the number of elements in b_node during the last loop. + */ +static inline int mas_rebalance(struct ma_state *mas, + struct maple_big_node *b_node) +{ + char empty_count = mas_mt_height(mas); + struct maple_subtree_state mast; + unsigned char shift, b_end = ++b_node->b_end; + + MA_STATE(l_mas, mas->tree, mas->index, mas->last); + MA_STATE(r_mas, mas->tree, mas->index, mas->last); + + trace_ma_op(__func__, mas); + + /* + * Rebalancing occurs if a node is insufficient. Data is rebalanced + * against the node to the right if it exists, otherwise the node to the + * left of this node is rebalanced against this node. If rebalancing + * causes just one node to be produced instead of two, then the parent + * is also examined and rebalanced if it is insufficient. Every level + * tries to combine the data in the same way. If one node contains the + * entire range of the tree, then that node is used as a new root node. + */ + mas_node_count(mas, 1 + empty_count * 3); + if (mas_is_err(mas)) + return 0; + + mast.orig_l = &l_mas; + mast.orig_r = &r_mas; + mast.bn = b_node; + mast.bn->type = mte_node_type(mas->node); + + l_mas = r_mas = *mas; + + if (mas_next_sibling(&r_mas)) { + mas_mab_cp(&r_mas, 0, mt_slot_count(r_mas.node), b_node, b_end); + r_mas.last = r_mas.index = r_mas.max; + } else { + mas_prev_sibling(&l_mas); + shift = mas_data_end(&l_mas) + 1; + mab_shift_right(b_node, shift); + mas->offset += shift; + mas_mab_cp(&l_mas, 0, shift - 1, b_node, 0); + b_node->b_end = shift + b_end; + l_mas.index = l_mas.last = l_mas.min; + } + + return mas_spanning_rebalance(mas, &mast, empty_count); +} + +/* + * mas_destroy_rebalance() - Rebalance left-most node while destroying the maple + * state. + * @mas: The maple state + * @end: The end of the left-most node. + * + * During a mass-insert event (such as forking), it may be necessary to + * rebalance the left-most node when it is not sufficient. + */ +static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end) +{ + enum maple_type mt = mte_node_type(mas->node); + struct maple_node reuse, *newnode, *parent, *new_left, *left, *node; + struct maple_enode *eparent; + unsigned char offset, tmp, split = mt_slots[mt] / 2; + void __rcu **l_slots, **slots; + unsigned long *l_pivs, *pivs, gap; + bool in_rcu = mt_in_rcu(mas->tree); + + MA_STATE(l_mas, mas->tree, mas->index, mas->last); + + l_mas = *mas; + mas_prev_sibling(&l_mas); + + /* set up node. */ + if (in_rcu) { + /* Allocate for both left and right as well as parent. */ + mas_node_count(mas, 3); + if (mas_is_err(mas)) + return; + + newnode = mas_pop_node(mas); + } else { + newnode = &reuse; + } + + node = mas_mn(mas); + newnode->parent = node->parent; + slots = ma_slots(newnode, mt); + pivs = ma_pivots(newnode, mt); + left = mas_mn(&l_mas); + l_slots = ma_slots(left, mt); + l_pivs = ma_pivots(left, mt); + if (!l_slots[split]) + split++; + tmp = mas_data_end(&l_mas) - split; + + memcpy(slots, l_slots + split + 1, sizeof(void *) * tmp); + memcpy(pivs, l_pivs + split + 1, sizeof(unsigned long) * tmp); + pivs[tmp] = l_mas.max; + memcpy(slots + tmp, ma_slots(node, mt), sizeof(void *) * end); + memcpy(pivs + tmp, ma_pivots(node, mt), sizeof(unsigned long) * end); + + l_mas.max = l_pivs[split]; + mas->min = l_mas.max + 1; + eparent = mt_mk_node(mte_parent(l_mas.node), + mas_parent_enum(&l_mas, l_mas.node)); + tmp += end; + if (!in_rcu) { + unsigned char max_p = mt_pivots[mt]; + unsigned char max_s = mt_slots[mt]; + + if (tmp < max_p) + memset(pivs + tmp, 0, + sizeof(unsigned long *) * (max_p - tmp)); + + if (tmp < mt_slots[mt]) + memset(slots + tmp, 0, sizeof(void *) * (max_s - tmp)); + + memcpy(node, newnode, sizeof(struct maple_node)); + ma_set_meta(node, mt, 0, tmp - 1); + mte_set_pivot(eparent, mte_parent_slot(l_mas.node), + l_pivs[split]); + + /* Remove data from l_pivs. */ + tmp = split + 1; + memset(l_pivs + tmp, 0, sizeof(unsigned long) * (max_p - tmp)); + memset(l_slots + tmp, 0, sizeof(void *) * (max_s - tmp)); + ma_set_meta(left, mt, 0, split); + + goto done; + } + + /* RCU requires replacing both l_mas, mas, and parent. */ + mas->node = mt_mk_node(newnode, mt); + ma_set_meta(newnode, mt, 0, tmp); + + new_left = mas_pop_node(mas); + new_left->parent = left->parent; + mt = mte_node_type(l_mas.node); + slots = ma_slots(new_left, mt); + pivs = ma_pivots(new_left, mt); + memcpy(slots, l_slots, sizeof(void *) * split); + memcpy(pivs, l_pivs, sizeof(unsigned long) * split); + ma_set_meta(new_left, mt, 0, split); + l_mas.node = mt_mk_node(new_left, mt); + + /* replace parent. */ + offset = mte_parent_slot(mas->node); + mt = mas_parent_enum(&l_mas, l_mas.node); + parent = mas_pop_node(mas); + slots = ma_slots(parent, mt); + pivs = ma_pivots(parent, mt); + memcpy(parent, mte_to_node(eparent), sizeof(struct maple_node)); + rcu_assign_pointer(slots[offset], mas->node); + rcu_assign_pointer(slots[offset - 1], l_mas.node); + pivs[offset - 1] = l_mas.max; + eparent = mt_mk_node(parent, mt); +done: + gap = mas_leaf_max_gap(mas); + mte_set_gap(eparent, mte_parent_slot(mas->node), gap); + gap = mas_leaf_max_gap(&l_mas); + mte_set_gap(eparent, mte_parent_slot(l_mas.node), gap); + mas_ascend(mas); + + if (in_rcu) + mas_replace(mas, false); + + mas_update_gap(mas); +} + +/* + * mas_split_final_node() - Split the final node in a subtree operation. + * @mast: the maple subtree state + * @mas: The maple state + * @height: The height of the tree in case it's a new root. + */ +static inline bool mas_split_final_node(struct maple_subtree_state *mast, + struct ma_state *mas, int height) +{ + struct maple_enode *ancestor; + + if (mte_is_root(mas->node)) { + if (mt_is_alloc(mas->tree)) + mast->bn->type = maple_arange_64; + else + mast->bn->type = maple_range_64; + mas->depth = height; + } + /* + * Only a single node is used here, could be root. + * The Big_node data should just fit in a single node. + */ + ancestor = mas_new_ma_node(mas, mast->bn); + mte_set_parent(mast->l->node, ancestor, mast->l->offset); + mte_set_parent(mast->r->node, ancestor, mast->r->offset); + mte_to_node(ancestor)->parent = mas_mn(mas)->parent; + + mast->l->node = ancestor; + mab_mas_cp(mast->bn, 0, mt_slots[mast->bn->type] - 1, mast->l, true); + mas->offset = mast->bn->b_end - 1; + return true; +} + +/* + * mast_fill_bnode() - Copy data into the big node in the subtree state + * @mast: The maple subtree state + * @mas: the maple state + * @skip: The number of entries to skip for new nodes insertion. + */ +static inline void mast_fill_bnode(struct maple_subtree_state *mast, + struct ma_state *mas, + unsigned char skip) +{ + bool cp = true; + struct maple_enode *old = mas->node; + unsigned char split; + + memset(mast->bn->gap, 0, sizeof(unsigned long) * ARRAY_SIZE(mast->bn->gap)); + memset(mast->bn->slot, 0, sizeof(unsigned long) * ARRAY_SIZE(mast->bn->slot)); + memset(mast->bn->pivot, 0, sizeof(unsigned long) * ARRAY_SIZE(mast->bn->pivot)); + mast->bn->b_end = 0; + + if (mte_is_root(mas->node)) { + cp = false; + } else { + mas_ascend(mas); + mat_add(mast->free, old); + mas->offset = mte_parent_slot(mas->node); + } + + if (cp && mast->l->offset) + mas_mab_cp(mas, 0, mast->l->offset - 1, mast->bn, 0); + + split = mast->bn->b_end; + mab_set_b_end(mast->bn, mast->l, mast->l->node); + mast->r->offset = mast->bn->b_end; + mab_set_b_end(mast->bn, mast->r, mast->r->node); + if (mast->bn->pivot[mast->bn->b_end - 1] == mas->max) + cp = false; + + if (cp) + mas_mab_cp(mas, split + skip, mt_slot_count(mas->node) - 1, + mast->bn, mast->bn->b_end); + + mast->bn->b_end--; + mast->bn->type = mte_node_type(mas->node); +} + +/* + * mast_split_data() - Split the data in the subtree state big node into regular + * nodes. + * @mast: The maple subtree state + * @mas: The maple state + * @split: The location to split the big node + */ +static inline void mast_split_data(struct maple_subtree_state *mast, + struct ma_state *mas, unsigned char split) +{ + unsigned char p_slot; + + mab_mas_cp(mast->bn, 0, split, mast->l, true); + mte_set_pivot(mast->r->node, 0, mast->r->max); + mab_mas_cp(mast->bn, split + 1, mast->bn->b_end, mast->r, false); + mast->l->offset = mte_parent_slot(mas->node); + mast->l->max = mast->bn->pivot[split]; + mast->r->min = mast->l->max + 1; + if (mte_is_leaf(mas->node)) + return; + + p_slot = mast->orig_l->offset; + mas_set_split_parent(mast->orig_l, mast->l->node, mast->r->node, + &p_slot, split); + mas_set_split_parent(mast->orig_r, mast->l->node, mast->r->node, + &p_slot, split); +} + +/* + * mas_push_data() - Instead of splitting a node, it is beneficial to push the + * data to the right or left node if there is room. + * @mas: The maple state + * @height: The current height of the maple state + * @mast: The maple subtree state + * @left: Push left or not. + * + * Keeping the height of the tree low means faster lookups. + * + * Return: True if pushed, false otherwise. + */ +static inline bool mas_push_data(struct ma_state *mas, int height, + struct maple_subtree_state *mast, bool left) +{ + unsigned char slot_total = mast->bn->b_end; + unsigned char end, space, split; + + MA_STATE(tmp_mas, mas->tree, mas->index, mas->last); + tmp_mas = *mas; + tmp_mas.depth = mast->l->depth; + + if (left && !mas_prev_sibling(&tmp_mas)) + return false; + else if (!left && !mas_next_sibling(&tmp_mas)) + return false; + + end = mas_data_end(&tmp_mas); + slot_total += end; + space = 2 * mt_slot_count(mas->node) - 2; + /* -2 instead of -1 to ensure there isn't a triple split */ + if (ma_is_leaf(mast->bn->type)) + space--; + + if (mas->max == ULONG_MAX) + space--; + + if (slot_total >= space) + return false; + + /* Get the data; Fill mast->bn */ + mast->bn->b_end++; + if (left) { + mab_shift_right(mast->bn, end + 1); + mas_mab_cp(&tmp_mas, 0, end, mast->bn, 0); + mast->bn->b_end = slot_total + 1; + } else { + mas_mab_cp(&tmp_mas, 0, end, mast->bn, mast->bn->b_end); + } + + /* Configure mast for splitting of mast->bn */ + split = mt_slots[mast->bn->type] - 2; + if (left) { + /* Switch mas to prev node */ + mat_add(mast->free, mas->node); + *mas = tmp_mas; + /* Start using mast->l for the left side. */ + tmp_mas.node = mast->l->node; + *mast->l = tmp_mas; + } else { + mat_add(mast->free, tmp_mas.node); + tmp_mas.node = mast->r->node; + *mast->r = tmp_mas; + split = slot_total - split; + } + split = mab_no_null_split(mast->bn, split, mt_slots[mast->bn->type]); + /* Update parent slot for split calculation. */ + if (left) + mast->orig_l->offset += end + 1; + + mast_split_data(mast, mas, split); + mast_fill_bnode(mast, mas, 2); + mas_split_final_node(mast, mas, height + 1); + return true; +} + +/* + * mas_split() - Split data that is too big for one node into two. + * @mas: The maple state + * @b_node: The maple big node + * Return: 1 on success, 0 on failure. + */ +static int mas_split(struct ma_state *mas, struct maple_big_node *b_node) +{ + + struct maple_subtree_state mast; + int height = 0; + unsigned char mid_split, split = 0; + + /* + * Splitting is handled differently from any other B-tree; the Maple + * Tree splits upwards. Splitting up means that the split operation + * occurs when the walk of the tree hits the leaves and not on the way + * down. The reason for splitting up is that it is impossible to know + * how much space will be needed until the leaf is (or leaves are) + * reached. Since overwriting data is allowed and a range could + * overwrite more than one range or result in changing one entry into 3 + * entries, it is impossible to know if a split is required until the + * data is examined. + * + * Splitting is a balancing act between keeping allocations to a minimum + * and avoiding a 'jitter' event where a tree is expanded to make room + * for an entry followed by a contraction when the entry is removed. To + * accomplish the balance, there are empty slots remaining in both left + * and right nodes after a split. + */ + MA_STATE(l_mas, mas->tree, mas->index, mas->last); + MA_STATE(r_mas, mas->tree, mas->index, mas->last); + MA_STATE(prev_l_mas, mas->tree, mas->index, mas->last); + MA_STATE(prev_r_mas, mas->tree, mas->index, mas->last); + MA_TOPIARY(mat, mas->tree); + + trace_ma_op(__func__, mas); + mas->depth = mas_mt_height(mas); + /* Allocation failures will happen early. */ + mas_node_count(mas, 1 + mas->depth * 2); + if (mas_is_err(mas)) + return 0; + + mast.l = &l_mas; + mast.r = &r_mas; + mast.orig_l = &prev_l_mas; + mast.orig_r = &prev_r_mas; + mast.free = &mat; + mast.bn = b_node; + + while (height++ <= mas->depth) { + if (mt_slots[b_node->type] > b_node->b_end) { + mas_split_final_node(&mast, mas, height); + break; + } + + l_mas = r_mas = *mas; + l_mas.node = mas_new_ma_node(mas, b_node); + r_mas.node = mas_new_ma_node(mas, b_node); + /* + * Another way that 'jitter' is avoided is to terminate a split up early if the + * left or right node has space to spare. This is referred to as "pushing left" + * or "pushing right" and is similar to the B* tree, except the nodes left or + * right can rarely be reused due to RCU, but the ripple upwards is halted which + * is a significant savings. + */ + /* Try to push left. */ + if (mas_push_data(mas, height, &mast, true)) + break; + + /* Try to push right. */ + if (mas_push_data(mas, height, &mast, false)) + break; + + split = mab_calc_split(mas, b_node, &mid_split, prev_l_mas.min); + mast_split_data(&mast, mas, split); + /* + * Usually correct, mab_mas_cp in the above call overwrites + * r->max. + */ + mast.r->max = mas->max; + mast_fill_bnode(&mast, mas, 1); + prev_l_mas = *mast.l; + prev_r_mas = *mast.r; + } + + /* Set the original node as dead */ + mat_add(mast.free, mas->node); + mas->node = l_mas.node; + mas_wmb_replace(mas, mast.free, NULL); + mtree_range_walk(mas); + return 1; +} + +/* + * mas_reuse_node() - Reuse the node to store the data. + * @wr_mas: The maple write state + * @bn: The maple big node + * @end: The end of the data. + * + * Will always return false in RCU mode. + * + * Return: True if node was reused, false otherwise. + */ +static inline bool mas_reuse_node(struct ma_wr_state *wr_mas, + struct maple_big_node *bn, unsigned char end) +{ + /* Need to be rcu safe. */ + if (mt_in_rcu(wr_mas->mas->tree)) + return false; + + if (end > bn->b_end) { + int clear = mt_slots[wr_mas->type] - bn->b_end; + + memset(wr_mas->slots + bn->b_end, 0, sizeof(void *) * clear--); + memset(wr_mas->pivots + bn->b_end, 0, sizeof(void *) * clear); + } + mab_mas_cp(bn, 0, bn->b_end, wr_mas->mas, false); + return true; +} + +/* + * mas_commit_b_node() - Commit the big node into the tree. + * @wr_mas: The maple write state + * @b_node: The maple big node + * @end: The end of the data. + */ +static inline int mas_commit_b_node(struct ma_wr_state *wr_mas, + struct maple_big_node *b_node, unsigned char end) +{ + struct maple_node *node; + unsigned char b_end = b_node->b_end; + enum maple_type b_type = b_node->type; + + if ((b_end < mt_min_slots[b_type]) && + (!mte_is_root(wr_mas->mas->node)) && + (mas_mt_height(wr_mas->mas) > 1)) + return mas_rebalance(wr_mas->mas, b_node); + + if (b_end >= mt_slots[b_type]) + return mas_split(wr_mas->mas, b_node); + + if (mas_reuse_node(wr_mas, b_node, end)) + goto reuse_node; + + mas_node_count(wr_mas->mas, 1); + if (mas_is_err(wr_mas->mas)) + return 0; + + node = mas_pop_node(wr_mas->mas); + node->parent = mas_mn(wr_mas->mas)->parent; + wr_mas->mas->node = mt_mk_node(node, b_type); + mab_mas_cp(b_node, 0, b_end, wr_mas->mas, false); + mas_replace(wr_mas->mas, false); +reuse_node: + mas_update_gap(wr_mas->mas); + return 1; +} + +/* + * mas_root_expand() - Expand a root to a node + * @mas: The maple state + * @entry: The entry to store into the tree + */ +static inline int mas_root_expand(struct ma_state *mas, void *entry) +{ + void *contents = mas_root_locked(mas); + enum maple_type type = maple_leaf_64; + struct maple_node *node; + void __rcu **slots; + unsigned long *pivots; + int slot = 0; + + mas_node_count(mas, 1); + if (unlikely(mas_is_err(mas))) + return 0; + + node = mas_pop_node(mas); + pivots = ma_pivots(node, type); + slots = ma_slots(node, type); + node->parent = ma_parent_ptr( + ((unsigned long)mas->tree | MA_ROOT_PARENT)); + mas->node = mt_mk_node(node, type); + + if (mas->index) { + if (contents) { + rcu_assign_pointer(slots[slot], contents); + if (likely(mas->index > 1)) + slot++; + } + pivots[slot++] = mas->index - 1; + } + + rcu_assign_pointer(slots[slot], entry); + mas->offset = slot; + pivots[slot] = mas->last; + if (mas->last != ULONG_MAX) + slot++; + mas->depth = 1; + mas_set_height(mas); + + /* swap the new root into the tree */ + rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node)); + ma_set_meta(node, maple_leaf_64, 0, slot); + return slot; +} + +static inline void mas_store_root(struct ma_state *mas, void *entry) +{ + if (likely((mas->last != 0) || (mas->index != 0))) + mas_root_expand(mas, entry); + else if (((unsigned long) (entry) & 3) == 2) + mas_root_expand(mas, entry); + else { + rcu_assign_pointer(mas->tree->ma_root, entry); + mas->node = MAS_START; + } +} + +/* + * mas_is_span_wr() - Check if the write needs to be treated as a write that + * spans the node. + * @mas: The maple state + * @piv: The pivot value being written + * @type: The maple node type + * @entry: The data to write + * + * Spanning writes are writes that start in one node and end in another OR if + * the write of a %NULL will cause the node to end with a %NULL. + * + * Return: True if this is a spanning write, false otherwise. + */ +static bool mas_is_span_wr(struct ma_wr_state *wr_mas) +{ + unsigned long max; + unsigned long last = wr_mas->mas->last; + unsigned long piv = wr_mas->r_max; + enum maple_type type = wr_mas->type; + void *entry = wr_mas->entry; + + /* Contained in this pivot */ + if (piv > last) + return false; + + max = wr_mas->mas->max; + if (unlikely(ma_is_leaf(type))) { + /* Fits in the node, but may span slots. */ + if (last < max) + return false; + + /* Writes to the end of the node but not null. */ + if ((last == max) && entry) + return false; + + /* + * Writing ULONG_MAX is not a spanning write regardless of the + * value being written as long as the range fits in the node. + */ + if ((last == ULONG_MAX) && (last == max)) + return false; + } else if (piv == last) { + if (entry) + return false; + + /* Detect spanning store wr walk */ + if (last == ULONG_MAX) + return false; + } + + trace_ma_write(__func__, wr_mas->mas, piv, entry); + + return true; +} + +static inline void mas_wr_walk_descend(struct ma_wr_state *wr_mas) +{ + wr_mas->type = mte_node_type(wr_mas->mas->node); + mas_wr_node_walk(wr_mas); + wr_mas->slots = ma_slots(wr_mas->node, wr_mas->type); +} + +static inline void mas_wr_walk_traverse(struct ma_wr_state *wr_mas) +{ + wr_mas->mas->max = wr_mas->r_max; + wr_mas->mas->min = wr_mas->r_min; + wr_mas->mas->node = wr_mas->content; + wr_mas->mas->offset = 0; + wr_mas->mas->depth++; +} +/* + * mas_wr_walk() - Walk the tree for a write. + * @wr_mas: The maple write state + * + * Uses mas_slot_locked() and does not need to worry about dead nodes. + * + * Return: True if it's contained in a node, false on spanning write. + */ +static bool mas_wr_walk(struct ma_wr_state *wr_mas) +{ + struct ma_state *mas = wr_mas->mas; + + while (true) { + mas_wr_walk_descend(wr_mas); + if (unlikely(mas_is_span_wr(wr_mas))) + return false; + + wr_mas->content = mas_slot_locked(mas, wr_mas->slots, + mas->offset); + if (ma_is_leaf(wr_mas->type)) + return true; + + mas_wr_walk_traverse(wr_mas); + } + + return true; +} + +static bool mas_wr_walk_index(struct ma_wr_state *wr_mas) +{ + struct ma_state *mas = wr_mas->mas; + + while (true) { + mas_wr_walk_descend(wr_mas); + wr_mas->content = mas_slot_locked(mas, wr_mas->slots, + mas->offset); + if (ma_is_leaf(wr_mas->type)) + return true; + mas_wr_walk_traverse(wr_mas); + + } + return true; +} +/* + * mas_extend_spanning_null() - Extend a store of a %NULL to include surrounding %NULLs. + * @l_wr_mas: The left maple write state + * @r_wr_mas: The right maple write state + */ +static inline void mas_extend_spanning_null(struct ma_wr_state *l_wr_mas, + struct ma_wr_state *r_wr_mas) +{ + struct ma_state *r_mas = r_wr_mas->mas; + struct ma_state *l_mas = l_wr_mas->mas; + unsigned char l_slot; + + l_slot = l_mas->offset; + if (!l_wr_mas->content) + l_mas->index = l_wr_mas->r_min; + + if ((l_mas->index == l_wr_mas->r_min) && + (l_slot && + !mas_slot_locked(l_mas, l_wr_mas->slots, l_slot - 1))) { + if (l_slot > 1) + l_mas->index = l_wr_mas->pivots[l_slot - 2] + 1; + else + l_mas->index = l_mas->min; + + l_mas->offset = l_slot - 1; + } + + if (!r_wr_mas->content) { + if (r_mas->last < r_wr_mas->r_max) + r_mas->last = r_wr_mas->r_max; + r_mas->offset++; + } else if ((r_mas->last == r_wr_mas->r_max) && + (r_mas->last < r_mas->max) && + !mas_slot_locked(r_mas, r_wr_mas->slots, r_mas->offset + 1)) { + r_mas->last = mas_safe_pivot(r_mas, r_wr_mas->pivots, + r_wr_mas->type, r_mas->offset + 1); + r_mas->offset++; + } +} + +static inline void *mas_state_walk(struct ma_state *mas) +{ + void *entry; + + entry = mas_start(mas); + if (mas_is_none(mas)) + return NULL; + + if (mas_is_ptr(mas)) + return entry; + + return mtree_range_walk(mas); +} + +/* + * mtree_lookup_walk() - Internal quick lookup that does not keep maple state up + * to date. + * + * @mas: The maple state. + * + * Note: Leaves mas in undesirable state. + * Return: The entry for @mas->index or %NULL on dead node. + */ +static inline void *mtree_lookup_walk(struct ma_state *mas) +{ + unsigned long *pivots; + unsigned char offset; + struct maple_node *node; + struct maple_enode *next; + enum maple_type type; + void __rcu **slots; + unsigned char end; + unsigned long max; + + next = mas->node; + max = ULONG_MAX; + do { + offset = 0; + node = mte_to_node(next); + type = mte_node_type(next); + pivots = ma_pivots(node, type); + end = ma_data_end(node, type, pivots, max); + if (unlikely(ma_dead_node(node))) + goto dead_node; + + if (pivots[offset] >= mas->index) + goto next; + + do { + offset++; + } while ((offset < end) && (pivots[offset] < mas->index)); + + if (likely(offset > end)) + max = pivots[offset]; + +next: + slots = ma_slots(node, type); + next = mt_slot(mas->tree, slots, offset); + if (unlikely(ma_dead_node(node))) + goto dead_node; + } while (!ma_is_leaf(type)); + + return (void *) next; + +dead_node: + mas_reset(mas); + return NULL; +} + +/* + * mas_new_root() - Create a new root node that only contains the entry passed + * in. + * @mas: The maple state + * @entry: The entry to store. + * + * Only valid when the index == 0 and the last == ULONG_MAX + * + * Return 0 on error, 1 on success. + */ +static inline int mas_new_root(struct ma_state *mas, void *entry) +{ + struct maple_enode *root = mas_root_locked(mas); + enum maple_type type = maple_leaf_64; + struct maple_node *node; + void __rcu **slots; + unsigned long *pivots; + + if (!entry && !mas->index && mas->last == ULONG_MAX) { + mas->depth = 0; + mas_set_height(mas); + rcu_assign_pointer(mas->tree->ma_root, entry); + mas->node = MAS_START; + goto done; + } + + mas_node_count(mas, 1); + if (mas_is_err(mas)) + return 0; + + node = mas_pop_node(mas); + pivots = ma_pivots(node, type); + slots = ma_slots(node, type); + node->parent = ma_parent_ptr( + ((unsigned long)mas->tree | MA_ROOT_PARENT)); + mas->node = mt_mk_node(node, type); + rcu_assign_pointer(slots[0], entry); + pivots[0] = mas->last; + mas->depth = 1; + mas_set_height(mas); + rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node)); + +done: + if (xa_is_node(root)) + mte_destroy_walk(root, mas->tree); + + return 1; +} +/* + * mas_wr_spanning_store() - Create a subtree with the store operation completed + * and new nodes where necessary, then place the sub-tree in the actual tree. + * Note that mas is expected to point to the node which caused the store to + * span. + * @wr_mas: The maple write state + * + * Return: 0 on error, positive on success. + */ +static inline int mas_wr_spanning_store(struct ma_wr_state *wr_mas) +{ + struct maple_subtree_state mast; + struct maple_big_node b_node; + struct ma_state *mas; + unsigned char height; + + /* Left and Right side of spanning store */ + MA_STATE(l_mas, NULL, 0, 0); + MA_STATE(r_mas, NULL, 0, 0); + + MA_WR_STATE(r_wr_mas, &r_mas, wr_mas->entry); + MA_WR_STATE(l_wr_mas, &l_mas, wr_mas->entry); + + /* + * A store operation that spans multiple nodes is called a spanning + * store and is handled early in the store call stack by the function + * mas_is_span_wr(). When a spanning store is identified, the maple + * state is duplicated. The first maple state walks the left tree path + * to ``index``, the duplicate walks the right tree path to ``last``. + * The data in the two nodes are combined into a single node, two nodes, + * or possibly three nodes (see the 3-way split above). A ``NULL`` + * written to the last entry of a node is considered a spanning store as + * a rebalance is required for the operation to complete and an overflow + * of data may happen. + */ + mas = wr_mas->mas; + trace_ma_op(__func__, mas); + + if (unlikely(!mas->index && mas->last == ULONG_MAX)) + return mas_new_root(mas, wr_mas->entry); + /* + * Node rebalancing may occur due to this store, so there may be three new + * entries per level plus a new root. + */ + height = mas_mt_height(mas); + mas_node_count(mas, 1 + height * 3); + if (mas_is_err(mas)) + return 0; + + /* + * Set up right side. Need to get to the next offset after the spanning + * store to ensure it's not NULL and to combine both the next node and + * the node with the start together. + */ + r_mas = *mas; + /* Avoid overflow, walk to next slot in the tree. */ + if (r_mas.last + 1) + r_mas.last++; + + r_mas.index = r_mas.last; + mas_wr_walk_index(&r_wr_mas); + r_mas.last = r_mas.index = mas->last; + + /* Set up left side. */ + l_mas = *mas; + mas_wr_walk_index(&l_wr_mas); + + if (!wr_mas->entry) { + mas_extend_spanning_null(&l_wr_mas, &r_wr_mas); + mas->offset = l_mas.offset; + mas->index = l_mas.index; + mas->last = l_mas.last = r_mas.last; + } + + /* expanding NULLs may make this cover the entire range */ + if (!l_mas.index && r_mas.last == ULONG_MAX) { + mas_set_range(mas, 0, ULONG_MAX); + return mas_new_root(mas, wr_mas->entry); + } + + memset(&b_node, 0, sizeof(struct maple_big_node)); + /* Copy l_mas and store the value in b_node. */ + mas_store_b_node(&l_wr_mas, &b_node, l_wr_mas.node_end); + /* Copy r_mas into b_node. */ + if (r_mas.offset <= r_wr_mas.node_end) + mas_mab_cp(&r_mas, r_mas.offset, r_wr_mas.node_end, + &b_node, b_node.b_end + 1); + else + b_node.b_end++; + + /* Stop spanning searches by searching for just index. */ + l_mas.index = l_mas.last = mas->index; + + mast.bn = &b_node; + mast.orig_l = &l_mas; + mast.orig_r = &r_mas; + /* Combine l_mas and r_mas and split them up evenly again. */ + return mas_spanning_rebalance(mas, &mast, height + 1); +} + +/* + * mas_wr_node_store() - Attempt to store the value in a node + * @wr_mas: The maple write state + * + * Attempts to reuse the node, but may allocate. + * + * Return: True if stored, false otherwise + */ +static inline bool mas_wr_node_store(struct ma_wr_state *wr_mas) +{ + struct ma_state *mas = wr_mas->mas; + void __rcu **dst_slots; + unsigned long *dst_pivots; + unsigned char dst_offset; + unsigned char new_end = wr_mas->node_end; + unsigned char offset; + unsigned char node_slots = mt_slots[wr_mas->type]; + struct maple_node reuse, *newnode; + unsigned char copy_size, max_piv = mt_pivots[wr_mas->type]; + bool in_rcu = mt_in_rcu(mas->tree); + + offset = mas->offset; + if (mas->last == wr_mas->r_max) { + /* runs right to the end of the node */ + if (mas->last == mas->max) + new_end = offset; + /* don't copy this offset */ + wr_mas->offset_end++; + } else if (mas->last < wr_mas->r_max) { + /* new range ends in this range */ + if (unlikely(wr_mas->r_max == ULONG_MAX)) + mas_bulk_rebalance(mas, wr_mas->node_end, wr_mas->type); + + new_end++; + } else { + if (wr_mas->end_piv == mas->last) + wr_mas->offset_end++; + + new_end -= wr_mas->offset_end - offset - 1; + } + + /* new range starts within a range */ + if (wr_mas->r_min < mas->index) + new_end++; + + /* Not enough room */ + if (new_end >= node_slots) + return false; + + /* Not enough data. */ + if (!mte_is_root(mas->node) && (new_end <= mt_min_slots[wr_mas->type]) && + !(mas->mas_flags & MA_STATE_BULK)) + return false; + + /* set up node. */ + if (in_rcu) { + mas_node_count(mas, 1); + if (mas_is_err(mas)) + return false; + + newnode = mas_pop_node(mas); + } else { + memset(&reuse, 0, sizeof(struct maple_node)); + newnode = &reuse; + } + + newnode->parent = mas_mn(mas)->parent; + dst_pivots = ma_pivots(newnode, wr_mas->type); + dst_slots = ma_slots(newnode, wr_mas->type); + /* Copy from start to insert point */ + memcpy(dst_pivots, wr_mas->pivots, sizeof(unsigned long) * (offset + 1)); + memcpy(dst_slots, wr_mas->slots, sizeof(void *) * (offset + 1)); + dst_offset = offset; + + /* Handle insert of new range starting after old range */ + if (wr_mas->r_min < mas->index) { + mas->offset++; + rcu_assign_pointer(dst_slots[dst_offset], wr_mas->content); + dst_pivots[dst_offset++] = mas->index - 1; + } + + /* Store the new entry and range end. */ + if (dst_offset < max_piv) + dst_pivots[dst_offset] = mas->last; + mas->offset = dst_offset; + rcu_assign_pointer(dst_slots[dst_offset], wr_mas->entry); + + /* + * this range wrote to the end of the node or it overwrote the rest of + * the data + */ + if (wr_mas->offset_end > wr_mas->node_end || mas->last >= mas->max) { + new_end = dst_offset; + goto done; + } + + dst_offset++; + /* Copy to the end of node if necessary. */ + copy_size = wr_mas->node_end - wr_mas->offset_end + 1; + memcpy(dst_slots + dst_offset, wr_mas->slots + wr_mas->offset_end, + sizeof(void *) * copy_size); + if (dst_offset < max_piv) { + if (copy_size > max_piv - dst_offset) + copy_size = max_piv - dst_offset; + + memcpy(dst_pivots + dst_offset, + wr_mas->pivots + wr_mas->offset_end, + sizeof(unsigned long) * copy_size); + } + + if ((wr_mas->node_end == node_slots - 1) && (new_end < node_slots - 1)) + dst_pivots[new_end] = mas->max; + +done: + mas_leaf_set_meta(mas, newnode, dst_pivots, maple_leaf_64, new_end); + if (in_rcu) { + mas->node = mt_mk_node(newnode, wr_mas->type); + mas_replace(mas, false); + } else { + memcpy(wr_mas->node, newnode, sizeof(struct maple_node)); + } + trace_ma_write(__func__, mas, 0, wr_mas->entry); + mas_update_gap(mas); + return true; +} + +/* + * mas_wr_slot_store: Attempt to store a value in a slot. + * @wr_mas: the maple write state + * + * Return: True if stored, false otherwise + */ +static inline bool mas_wr_slot_store(struct ma_wr_state *wr_mas) +{ + struct ma_state *mas = wr_mas->mas; + unsigned long lmax; /* Logical max. */ + unsigned char offset = mas->offset; + + if ((wr_mas->r_max > mas->last) && ((wr_mas->r_min != mas->index) || + (offset != wr_mas->node_end))) + return false; + + if (offset == wr_mas->node_end - 1) + lmax = mas->max; + else + lmax = wr_mas->pivots[offset + 1]; + + /* going to overwrite too many slots. */ + if (lmax < mas->last) + return false; + + if (wr_mas->r_min == mas->index) { + /* overwriting two or more ranges with one. */ + if (lmax == mas->last) + return false; + + /* Overwriting all of offset and a portion of offset + 1. */ + rcu_assign_pointer(wr_mas->slots[offset], wr_mas->entry); + wr_mas->pivots[offset] = mas->last; + goto done; + } + + /* Doesn't end on the next range end. */ + if (lmax != mas->last) + return false; + + /* Overwriting a portion of offset and all of offset + 1 */ + if ((offset + 1 < mt_pivots[wr_mas->type]) && + (wr_mas->entry || wr_mas->pivots[offset + 1])) + wr_mas->pivots[offset + 1] = mas->last; + + rcu_assign_pointer(wr_mas->slots[offset + 1], wr_mas->entry); + wr_mas->pivots[offset] = mas->index - 1; + mas->offset++; /* Keep mas accurate. */ + +done: + trace_ma_write(__func__, mas, 0, wr_mas->entry); + mas_update_gap(mas); + return true; +} + +static inline void mas_wr_end_piv(struct ma_wr_state *wr_mas) +{ + while ((wr_mas->mas->last > wr_mas->end_piv) && + (wr_mas->offset_end < wr_mas->node_end)) + wr_mas->end_piv = wr_mas->pivots[++wr_mas->offset_end]; + + if (wr_mas->mas->last > wr_mas->end_piv) + wr_mas->end_piv = wr_mas->mas->max; +} + +static inline void mas_wr_extend_null(struct ma_wr_state *wr_mas) +{ + struct ma_state *mas = wr_mas->mas; + + if (mas->last < wr_mas->end_piv && !wr_mas->slots[wr_mas->offset_end]) + mas->last = wr_mas->end_piv; + + /* Check next slot(s) if we are overwriting the end */ + if ((mas->last == wr_mas->end_piv) && + (wr_mas->node_end != wr_mas->offset_end) && + !wr_mas->slots[wr_mas->offset_end + 1]) { + wr_mas->offset_end++; + if (wr_mas->offset_end == wr_mas->node_end) + mas->last = mas->max; + else + mas->last = wr_mas->pivots[wr_mas->offset_end]; + wr_mas->end_piv = mas->last; + } + + if (!wr_mas->content) { + /* If this one is null, the next and prev are not */ + mas->index = wr_mas->r_min; + } else { + /* Check prev slot if we are overwriting the start */ + if (mas->index == wr_mas->r_min && mas->offset && + !wr_mas->slots[mas->offset - 1]) { + mas->offset--; + wr_mas->r_min = mas->index = + mas_safe_min(mas, wr_mas->pivots, mas->offset); + wr_mas->r_max = wr_mas->pivots[mas->offset]; + } + } +} + +static inline bool mas_wr_append(struct ma_wr_state *wr_mas) +{ + unsigned char end = wr_mas->node_end; + unsigned char new_end = end + 1; + struct ma_state *mas = wr_mas->mas; + unsigned char node_pivots = mt_pivots[wr_mas->type]; + + if ((mas->index != wr_mas->r_min) && (mas->last == wr_mas->r_max)) { + if (new_end < node_pivots) + wr_mas->pivots[new_end] = wr_mas->pivots[end]; + + if (new_end < node_pivots) + ma_set_meta(wr_mas->node, maple_leaf_64, 0, new_end); + + rcu_assign_pointer(wr_mas->slots[new_end], wr_mas->entry); + mas->offset = new_end; + wr_mas->pivots[end] = mas->index - 1; + + return true; + } + + if ((mas->index == wr_mas->r_min) && (mas->last < wr_mas->r_max)) { + if (new_end < node_pivots) + wr_mas->pivots[new_end] = wr_mas->pivots[end]; + + rcu_assign_pointer(wr_mas->slots[new_end], wr_mas->content); + if (new_end < node_pivots) + ma_set_meta(wr_mas->node, maple_leaf_64, 0, new_end); + + wr_mas->pivots[end] = mas->last; + rcu_assign_pointer(wr_mas->slots[end], wr_mas->entry); + return true; + } + + return false; +} + +/* + * mas_wr_bnode() - Slow path for a modification. + * @wr_mas: The write maple state + * + * This is where split, rebalance end up. + */ +static void mas_wr_bnode(struct ma_wr_state *wr_mas) +{ + struct maple_big_node b_node; + + trace_ma_write(__func__, wr_mas->mas, 0, wr_mas->entry); + memset(&b_node, 0, sizeof(struct maple_big_node)); + mas_store_b_node(wr_mas, &b_node, wr_mas->offset_end); + mas_commit_b_node(wr_mas, &b_node, wr_mas->node_end); +} + +static inline void mas_wr_modify(struct ma_wr_state *wr_mas) +{ + unsigned char node_slots; + unsigned char node_size; + struct ma_state *mas = wr_mas->mas; + + /* Direct replacement */ + if (wr_mas->r_min == mas->index && wr_mas->r_max == mas->last) { + rcu_assign_pointer(wr_mas->slots[mas->offset], wr_mas->entry); + if (!!wr_mas->entry ^ !!wr_mas->content) + mas_update_gap(mas); + return; + } + + /* Attempt to append */ + node_slots = mt_slots[wr_mas->type]; + node_size = wr_mas->node_end - wr_mas->offset_end + mas->offset + 2; + if (mas->max == ULONG_MAX) + node_size++; + + /* slot and node store will not fit, go to the slow path */ + if (unlikely(node_size >= node_slots)) + goto slow_path; + + if (wr_mas->entry && (wr_mas->node_end < node_slots - 1) && + (mas->offset == wr_mas->node_end) && mas_wr_append(wr_mas)) { + if (!wr_mas->content || !wr_mas->entry) + mas_update_gap(mas); + return; + } + + if ((wr_mas->offset_end - mas->offset <= 1) && mas_wr_slot_store(wr_mas)) + return; + else if (mas_wr_node_store(wr_mas)) + return; + + if (mas_is_err(mas)) + return; + +slow_path: + mas_wr_bnode(wr_mas); +} + +/* + * mas_wr_store_entry() - Internal call to store a value + * @mas: The maple state + * @entry: The entry to store. + * + * Return: The contents that was stored at the index. + */ +static inline void *mas_wr_store_entry(struct ma_wr_state *wr_mas) +{ + struct ma_state *mas = wr_mas->mas; + + wr_mas->content = mas_start(mas); + if (mas_is_none(mas) || mas_is_ptr(mas)) { + mas_store_root(mas, wr_mas->entry); + return wr_mas->content; + } + + if (unlikely(!mas_wr_walk(wr_mas))) { + mas_wr_spanning_store(wr_mas); + return wr_mas->content; + } + + /* At this point, we are at the leaf node that needs to be altered. */ + wr_mas->end_piv = wr_mas->r_max; + mas_wr_end_piv(wr_mas); + + if (!wr_mas->entry) + mas_wr_extend_null(wr_mas); + + /* New root for a single pointer */ + if (unlikely(!mas->index && mas->last == ULONG_MAX)) { + mas_new_root(mas, wr_mas->entry); + return wr_mas->content; + } + + mas_wr_modify(wr_mas); + return wr_mas->content; +} + +/** + * mas_insert() - Internal call to insert a value + * @mas: The maple state + * @entry: The entry to store + * + * Return: %NULL or the contents that already exists at the requested index + * otherwise. The maple state needs to be checked for error conditions. + */ +static inline void *mas_insert(struct ma_state *mas, void *entry) +{ + MA_WR_STATE(wr_mas, mas, entry); + + /* + * Inserting a new range inserts either 0, 1, or 2 pivots within the + * tree. If the insert fits exactly into an existing gap with a value + * of NULL, then the slot only needs to be written with the new value. + * If the range being inserted is adjacent to another range, then only a + * single pivot needs to be inserted (as well as writing the entry). If + * the new range is within a gap but does not touch any other ranges, + * then two pivots need to be inserted: the start - 1, and the end. As + * usual, the entry must be written. Most operations require a new node + * to be allocated and replace an existing node to ensure RCU safety, + * when in RCU mode. The exception to requiring a newly allocated node + * is when inserting at the end of a node (appending). When done + * carefully, appending can reuse the node in place. + */ + wr_mas.content = mas_start(mas); + if (wr_mas.content) + goto exists; + + if (mas_is_none(mas) || mas_is_ptr(mas)) { + mas_store_root(mas, entry); + return NULL; + } + + /* spanning writes always overwrite something */ + if (!mas_wr_walk(&wr_mas)) + goto exists; + + /* At this point, we are at the leaf node that needs to be altered. */ + wr_mas.offset_end = mas->offset; + wr_mas.end_piv = wr_mas.r_max; + + if (wr_mas.content || (mas->last > wr_mas.r_max)) + goto exists; + + if (!entry) + return NULL; + + mas_wr_modify(&wr_mas); + return wr_mas.content; + +exists: + mas_set_err(mas, -EEXIST); + return wr_mas.content; + +} + +/* + * mas_prev_node() - Find the prev non-null entry at the same level in the + * tree. The prev value will be mas->node[mas->offset] or MAS_NONE. + * @mas: The maple state + * @min: The lower limit to search + * + * The prev node value will be mas->node[mas->offset] or MAS_NONE. + * Return: 1 if the node is dead, 0 otherwise. + */ +static inline int mas_prev_node(struct ma_state *mas, unsigned long min) +{ + enum maple_type mt; + int offset, level; + void __rcu **slots; + struct maple_node *node; + struct maple_enode *enode; + unsigned long *pivots; + + if (mas_is_none(mas)) + return 0; + + level = 0; + do { + node = mas_mn(mas); + if (ma_is_root(node)) + goto no_entry; + + /* Walk up. */ + if (unlikely(mas_ascend(mas))) + return 1; + offset = mas->offset; + level++; + } while (!offset); + + offset--; + mt = mte_node_type(mas->node); + node = mas_mn(mas); + slots = ma_slots(node, mt); + pivots = ma_pivots(node, mt); + mas->max = pivots[offset]; + if (offset) + mas->min = pivots[offset - 1] + 1; + if (unlikely(ma_dead_node(node))) + return 1; + + if (mas->max < min) + goto no_entry_min; + + while (level > 1) { + level--; + enode = mas_slot(mas, slots, offset); + if (unlikely(ma_dead_node(node))) + return 1; + + mas->node = enode; + mt = mte_node_type(mas->node); + node = mas_mn(mas); + slots = ma_slots(node, mt); + pivots = ma_pivots(node, mt); + offset = ma_data_end(node, mt, pivots, mas->max); + if (offset) + mas->min = pivots[offset - 1] + 1; + + if (offset < mt_pivots[mt]) + mas->max = pivots[offset]; + + if (mas->max < min) + goto no_entry; + } + + mas->node = mas_slot(mas, slots, offset); + if (unlikely(ma_dead_node(node))) + return 1; + + mas->offset = mas_data_end(mas); + if (unlikely(mte_dead_node(mas->node))) + return 1; + + return 0; + +no_entry_min: + mas->offset = offset; + if (offset) + mas->min = pivots[offset - 1] + 1; +no_entry: + if (unlikely(ma_dead_node(node))) + return 1; + + mas->node = MAS_NONE; + return 0; +} + +/* + * mas_next_node() - Get the next node at the same level in the tree. + * @mas: The maple state + * @max: The maximum pivot value to check. + * + * The next value will be mas->node[mas->offset] or MAS_NONE. + * Return: 1 on dead node, 0 otherwise. + */ +static inline int mas_next_node(struct ma_state *mas, struct maple_node *node, + unsigned long max) +{ + unsigned long min, pivot; + unsigned long *pivots; + struct maple_enode *enode; + int level = 0; + unsigned char offset; + enum maple_type mt; + void __rcu **slots; + + if (mas->max >= max) + goto no_entry; + + level = 0; + do { + if (ma_is_root(node)) + goto no_entry; + + min = mas->max + 1; + if (min > max) + goto no_entry; + + if (unlikely(mas_ascend(mas))) + return 1; + + offset = mas->offset; + level++; + node = mas_mn(mas); + mt = mte_node_type(mas->node); + pivots = ma_pivots(node, mt); + } while (unlikely(offset == ma_data_end(node, mt, pivots, mas->max))); + + slots = ma_slots(node, mt); + pivot = mas_safe_pivot(mas, pivots, ++offset, mt); + while (unlikely(level > 1)) { + /* Descend, if necessary */ + enode = mas_slot(mas, slots, offset); + if (unlikely(ma_dead_node(node))) + return 1; + + mas->node = enode; + level--; + node = mas_mn(mas); + mt = mte_node_type(mas->node); + slots = ma_slots(node, mt); + pivots = ma_pivots(node, mt); + offset = 0; + pivot = pivots[0]; + } + + enode = mas_slot(mas, slots, offset); + if (unlikely(ma_dead_node(node))) + return 1; + + mas->node = enode; + mas->min = min; + mas->max = pivot; + return 0; + +no_entry: + if (unlikely(ma_dead_node(node))) + return 1; + + mas->node = MAS_NONE; + return 0; +} + +/* + * mas_next_nentry() - Get the next node entry + * @mas: The maple state + * @max: The maximum value to check + * @*range_start: Pointer to store the start of the range. + * + * Sets @mas->offset to the offset of the next node entry, @mas->last to the + * pivot of the entry. + * + * Return: The next entry, %NULL otherwise + */ +static inline void *mas_next_nentry(struct ma_state *mas, + struct maple_node *node, unsigned long max, enum maple_type type) +{ + unsigned char count; + unsigned long pivot; + unsigned long *pivots; + void __rcu **slots; + void *entry; + + if (mas->last == mas->max) { + mas->index = mas->max; + return NULL; + } + + pivots = ma_pivots(node, type); + slots = ma_slots(node, type); + mas->index = mas_safe_min(mas, pivots, mas->offset); + if (ma_dead_node(node)) + return NULL; + + if (mas->index > max) + return NULL; + + count = ma_data_end(node, type, pivots, mas->max); + if (mas->offset > count) + return NULL; + + while (mas->offset < count) { + pivot = pivots[mas->offset]; + entry = mas_slot(mas, slots, mas->offset); + if (ma_dead_node(node)) + return NULL; + + if (entry) + goto found; + + if (pivot >= max) + return NULL; + + mas->index = pivot + 1; + mas->offset++; + } + + if (mas->index > mas->max) { + mas->index = mas->last; + return NULL; + } + + pivot = mas_safe_pivot(mas, pivots, mas->offset, type); + entry = mas_slot(mas, slots, mas->offset); + if (ma_dead_node(node)) + return NULL; + + if (!pivot) + return NULL; + + if (!entry) + return NULL; + +found: + mas->last = pivot; + return entry; +} + +static inline void mas_rewalk(struct ma_state *mas, unsigned long index) +{ + +retry: + mas_set(mas, index); + mas_state_walk(mas); + if (mas_is_start(mas)) + goto retry; + + return; + +} + +/* + * mas_next_entry() - Internal function to get the next entry. + * @mas: The maple state + * @limit: The maximum range start. + * + * Set the @mas->node to the next entry and the range_start to + * the beginning value for the entry. Does not check beyond @limit. + * Sets @mas->index and @mas->last to the limit if it is hit. + * Restarts on dead nodes. + * + * Return: the next entry or %NULL. + */ +static inline void *mas_next_entry(struct ma_state *mas, unsigned long limit) +{ + void *entry = NULL; + struct maple_enode *prev_node; + struct maple_node *node; + unsigned char offset; + unsigned long last; + enum maple_type mt; + + last = mas->last; +retry: + offset = mas->offset; + prev_node = mas->node; + node = mas_mn(mas); + mt = mte_node_type(mas->node); + mas->offset++; + if (unlikely(mas->offset >= mt_slots[mt])) { + mas->offset = mt_slots[mt] - 1; + goto next_node; + } + + while (!mas_is_none(mas)) { + entry = mas_next_nentry(mas, node, limit, mt); + if (unlikely(ma_dead_node(node))) { + mas_rewalk(mas, last); + goto retry; + } + + if (likely(entry)) + return entry; + + if (unlikely((mas->index > limit))) + break; + +next_node: + prev_node = mas->node; + offset = mas->offset; + if (unlikely(mas_next_node(mas, node, limit))) { + mas_rewalk(mas, last); + goto retry; + } + mas->offset = 0; + node = mas_mn(mas); + mt = mte_node_type(mas->node); + } + + mas->index = mas->last = limit; + mas->offset = offset; + mas->node = prev_node; + return NULL; +} + +/* + * mas_prev_nentry() - Get the previous node entry. + * @mas: The maple state. + * @limit: The lower limit to check for a value. + * + * Return: the entry, %NULL otherwise. + */ +static inline void *mas_prev_nentry(struct ma_state *mas, unsigned long limit, + unsigned long index) +{ + unsigned long pivot, min; + unsigned char offset; + struct maple_node *mn; + enum maple_type mt; + unsigned long *pivots; + void __rcu **slots; + void *entry; + +retry: + if (!mas->offset) + return NULL; + + mn = mas_mn(mas); + mt = mte_node_type(mas->node); + offset = mas->offset - 1; + if (offset >= mt_slots[mt]) + offset = mt_slots[mt] - 1; + + slots = ma_slots(mn, mt); + pivots = ma_pivots(mn, mt); + if (offset == mt_pivots[mt]) + pivot = mas->max; + else + pivot = pivots[offset]; + + if (unlikely(ma_dead_node(mn))) { + mas_rewalk(mas, index); + goto retry; + } + + while (offset && ((!mas_slot(mas, slots, offset) && pivot >= limit) || + !pivot)) + pivot = pivots[--offset]; + + min = mas_safe_min(mas, pivots, offset); + entry = mas_slot(mas, slots, offset); + if (unlikely(ma_dead_node(mn))) { + mas_rewalk(mas, index); + goto retry; + } + + if (likely(entry)) { + mas->offset = offset; + mas->last = pivot; + mas->index = min; + } + return entry; +} + +static inline void *mas_prev_entry(struct ma_state *mas, unsigned long min) +{ + void *entry; + +retry: + while (likely(!mas_is_none(mas))) { + entry = mas_prev_nentry(mas, min, mas->index); + if (unlikely(mas->last < min)) + goto not_found; + + if (likely(entry)) + return entry; + + if (unlikely(mas_prev_node(mas, min))) { + mas_rewalk(mas, mas->index); + goto retry; + } + + mas->offset++; + } + + mas->offset--; +not_found: + mas->index = mas->last = min; + return NULL; +} + +/* + * mas_rev_awalk() - Internal function. Reverse allocation walk. Find the + * highest gap address of a given size in a given node and descend. + * @mas: The maple state + * @size: The needed size. + * + * Return: True if found in a leaf, false otherwise. + * + */ +static bool mas_rev_awalk(struct ma_state *mas, unsigned long size) +{ + enum maple_type type = mte_node_type(mas->node); + struct maple_node *node = mas_mn(mas); + unsigned long *pivots, *gaps; + void __rcu **slots; + unsigned long gap = 0; + unsigned long max, min, index; + unsigned char offset; + + if (unlikely(mas_is_err(mas))) + return true; + + if (ma_is_dense(type)) { + /* dense nodes. */ + mas->offset = (unsigned char)(mas->index - mas->min); + return true; + } + + pivots = ma_pivots(node, type); + slots = ma_slots(node, type); + gaps = ma_gaps(node, type); + offset = mas->offset; + min = mas_safe_min(mas, pivots, offset); + /* Skip out of bounds. */ + while (mas->last < min) + min = mas_safe_min(mas, pivots, --offset); + + max = mas_safe_pivot(mas, pivots, offset, type); + index = mas->index; + while (index <= max) { + gap = 0; + if (gaps) + gap = gaps[offset]; + else if (!mas_slot(mas, slots, offset)) + gap = max - min + 1; + + if (gap) { + if ((size <= gap) && (size <= mas->last - min + 1)) + break; + + if (!gaps) { + /* Skip the next slot, it cannot be a gap. */ + if (offset < 2) + goto ascend; + + offset -= 2; + max = pivots[offset]; + min = mas_safe_min(mas, pivots, offset); + continue; + } + } + + if (!offset) + goto ascend; + + offset--; + max = min - 1; + min = mas_safe_min(mas, pivots, offset); + } + + if (unlikely(index > max)) { + mas_set_err(mas, -EBUSY); + return false; + } + + if (unlikely(ma_is_leaf(type))) { + mas->offset = offset; + mas->min = min; + mas->max = min + gap - 1; + return true; + } + + /* descend, only happens under lock. */ + mas->node = mas_slot(mas, slots, offset); + mas->min = min; + mas->max = max; + mas->offset = mas_data_end(mas); + return false; + +ascend: + if (mte_is_root(mas->node)) + mas_set_err(mas, -EBUSY); + + return false; +} + +static inline bool mas_anode_descend(struct ma_state *mas, unsigned long size) +{ + enum maple_type type = mte_node_type(mas->node); + unsigned long pivot, min, gap = 0; + unsigned char offset; + unsigned long *gaps; + unsigned long *pivots = ma_pivots(mas_mn(mas), type); + void __rcu **slots = ma_slots(mas_mn(mas), type); + bool found = false; + + if (ma_is_dense(type)) { + mas->offset = (unsigned char)(mas->index - mas->min); + return true; + } + + gaps = ma_gaps(mte_to_node(mas->node), type); + offset = mas->offset; + min = mas_safe_min(mas, pivots, offset); + for (; offset < mt_slots[type]; offset++) { + pivot = mas_safe_pivot(mas, pivots, offset, type); + if (offset && !pivot) + break; + + /* Not within lower bounds */ + if (mas->index > pivot) + goto next_slot; + + if (gaps) + gap = gaps[offset]; + else if (!mas_slot(mas, slots, offset)) + gap = min(pivot, mas->last) - max(mas->index, min) + 1; + else + goto next_slot; + + if (gap >= size) { + if (ma_is_leaf(type)) { + found = true; + goto done; + } + if (mas->index <= pivot) { + mas->node = mas_slot(mas, slots, offset); + mas->min = min; + mas->max = pivot; + offset = 0; + break; + } + } +next_slot: + min = pivot + 1; + if (mas->last <= pivot) { + mas_set_err(mas, -EBUSY); + return true; + } + } + + if (mte_is_root(mas->node)) + found = true; +done: + mas->offset = offset; + return found; +} + +/** + * mas_walk() - Search for @mas->index in the tree. + * @mas: The maple state. + * + * mas->index and mas->last will be set to the range if there is a value. If + * mas->node is MAS_NONE, reset to MAS_START. + * + * Return: the entry at the location or %NULL. + */ +void *mas_walk(struct ma_state *mas) +{ + void *entry; + +retry: + entry = mas_state_walk(mas); + if (mas_is_start(mas)) + goto retry; + + if (mas_is_ptr(mas)) { + if (!mas->index) { + mas->last = 0; + } else { + mas->index = 1; + mas->last = ULONG_MAX; + } + return entry; + } + + if (mas_is_none(mas)) { + mas->index = 0; + mas->last = ULONG_MAX; + } + + return entry; +} +EXPORT_SYMBOL_GPL(mas_walk); + +static inline bool mas_rewind_node(struct ma_state *mas) +{ + unsigned char slot; + + do { + if (mte_is_root(mas->node)) { + slot = mas->offset; + if (!slot) + return false; + } else { + mas_ascend(mas); + slot = mas->offset; + } + } while (!slot); + + mas->offset = --slot; + return true; +} + +/* + * mas_skip_node() - Internal function. Skip over a node. + * @mas: The maple state. + * + * Return: true if there is another node, false otherwise. + */ +static inline bool mas_skip_node(struct ma_state *mas) +{ + unsigned char slot, slot_count; + unsigned long *pivots; + enum maple_type mt; + + mt = mte_node_type(mas->node); + slot_count = mt_slots[mt] - 1; + do { + if (mte_is_root(mas->node)) { + slot = mas->offset; + if (slot > slot_count) { + mas_set_err(mas, -EBUSY); + return false; + } + } else { + mas_ascend(mas); + slot = mas->offset; + mt = mte_node_type(mas->node); + slot_count = mt_slots[mt] - 1; + } + } while (slot > slot_count); + + mas->offset = ++slot; + pivots = ma_pivots(mas_mn(mas), mt); + if (slot > 0) + mas->min = pivots[slot - 1] + 1; + + if (slot <= slot_count) + mas->max = pivots[slot]; + + return true; +} + +/* + * mas_awalk() - Allocation walk. Search from low address to high, for a gap of + * @size + * @mas: The maple state + * @size: The size of the gap required + * + * Search between @mas->index and @mas->last for a gap of @size. + */ +static inline void mas_awalk(struct ma_state *mas, unsigned long size) +{ + struct maple_enode *last = NULL; + + /* + * There are 4 options: + * go to child (descend) + * go back to parent (ascend) + * no gap found. (return, slot == MAPLE_NODE_SLOTS) + * found the gap. (return, slot != MAPLE_NODE_SLOTS) + */ + while (!mas_is_err(mas) && !mas_anode_descend(mas, size)) { + if (last == mas->node) + mas_skip_node(mas); + else + last = mas->node; + } +} + +/* + * mas_fill_gap() - Fill a located gap with @entry. + * @mas: The maple state + * @entry: The value to store + * @slot: The offset into the node to store the @entry + * @size: The size of the entry + * @index: The start location + */ +static inline void mas_fill_gap(struct ma_state *mas, void *entry, + unsigned char slot, unsigned long size, unsigned long *index) +{ + MA_WR_STATE(wr_mas, mas, entry); + unsigned char pslot = mte_parent_slot(mas->node); + struct maple_enode *mn = mas->node; + unsigned long *pivots; + enum maple_type ptype; + /* + * mas->index is the start address for the search + * which may no longer be needed. + * mas->last is the end address for the search + */ + + *index = mas->index; + mas->last = mas->index + size - 1; + + /* + * It is possible that using mas->max and mas->min to correctly + * calculate the index and last will cause an issue in the gap + * calculation, so fix the ma_state here + */ + mas_ascend(mas); + ptype = mte_node_type(mas->node); + pivots = ma_pivots(mas_mn(mas), ptype); + mas->max = mas_safe_pivot(mas, pivots, pslot, ptype); + mas->min = mas_safe_min(mas, pivots, pslot); + mas->node = mn; + mas->offset = slot; + mas_wr_store_entry(&wr_mas); +} + +/* + * mas_sparse_area() - Internal function. Return upper or lower limit when + * searching for a gap in an empty tree. + * @mas: The maple state + * @min: the minimum range + * @max: The maximum range + * @size: The size of the gap + * @fwd: Searching forward or back + */ +static inline void mas_sparse_area(struct ma_state *mas, unsigned long min, + unsigned long max, unsigned long size, bool fwd) +{ + unsigned long start = 0; + + if (!unlikely(mas_is_none(mas))) + start++; + /* mas_is_ptr */ + + if (start < min) + start = min; + + if (fwd) { + mas->index = start; + mas->last = start + size - 1; + return; + } + + mas->index = max; +} + +/* + * mas_empty_area() - Get the lowest address within the range that is + * sufficient for the size requested. + * @mas: The maple state + * @min: The lowest value of the range + * @max: The highest value of the range + * @size: The size needed + */ +int mas_empty_area(struct ma_state *mas, unsigned long min, + unsigned long max, unsigned long size) +{ + unsigned char offset; + unsigned long *pivots; + enum maple_type mt; + + if (mas_is_start(mas)) + mas_start(mas); + else if (mas->offset >= 2) + mas->offset -= 2; + else if (!mas_skip_node(mas)) + return -EBUSY; + + /* Empty set */ + if (mas_is_none(mas) || mas_is_ptr(mas)) { + mas_sparse_area(mas, min, max, size, true); + return 0; + } + + /* The start of the window can only be within these values */ + mas->index = min; + mas->last = max; + mas_awalk(mas, size); + + if (unlikely(mas_is_err(mas))) + return xa_err(mas->node); + + offset = mas->offset; + if (unlikely(offset == MAPLE_NODE_SLOTS)) + return -EBUSY; + + mt = mte_node_type(mas->node); + pivots = ma_pivots(mas_mn(mas), mt); + if (offset) + mas->min = pivots[offset - 1] + 1; + + if (offset < mt_pivots[mt]) + mas->max = pivots[offset]; + + if (mas->index < mas->min) + mas->index = mas->min; + + mas->last = mas->index + size - 1; + return 0; +} +EXPORT_SYMBOL_GPL(mas_empty_area); + +/* + * mas_empty_area_rev() - Get the highest address within the range that is + * sufficient for the size requested. + * @mas: The maple state + * @min: The lowest value of the range + * @max: The highest value of the range + * @size: The size needed + */ +int mas_empty_area_rev(struct ma_state *mas, unsigned long min, + unsigned long max, unsigned long size) +{ + struct maple_enode *last = mas->node; + + if (mas_is_start(mas)) { + mas_start(mas); + mas->offset = mas_data_end(mas); + } else if (mas->offset >= 2) { + mas->offset -= 2; + } else if (!mas_rewind_node(mas)) { + return -EBUSY; + } + + /* Empty set. */ + if (mas_is_none(mas) || mas_is_ptr(mas)) { + mas_sparse_area(mas, min, max, size, false); + return 0; + } + + /* The start of the window can only be within these values. */ + mas->index = min; + mas->last = max; + + while (!mas_rev_awalk(mas, size)) { + if (last == mas->node) { + if (!mas_rewind_node(mas)) + return -EBUSY; + } else { + last = mas->node; + } + } + + if (mas_is_err(mas)) + return xa_err(mas->node); + + if (unlikely(mas->offset == MAPLE_NODE_SLOTS)) + return -EBUSY; + + /* + * mas_rev_awalk() has set mas->min and mas->max to the gap values. If + * the maximum is outside the window we are searching, then use the last + * location in the search. + * mas->max and mas->min is the range of the gap. + * mas->index and mas->last are currently set to the search range. + */ + + /* Trim the upper limit to the max. */ + if (mas->max <= mas->last) + mas->last = mas->max; + + mas->index = mas->last - size + 1; + return 0; +} +EXPORT_SYMBOL_GPL(mas_empty_area_rev); + +static inline int mas_alloc(struct ma_state *mas, void *entry, + unsigned long size, unsigned long *index) +{ + unsigned long min; + + mas_start(mas); + if (mas_is_none(mas) || mas_is_ptr(mas)) { + mas_root_expand(mas, entry); + if (mas_is_err(mas)) + return xa_err(mas->node); + + if (!mas->index) + return mte_pivot(mas->node, 0); + return mte_pivot(mas->node, 1); + } + + /* Must be walking a tree. */ + mas_awalk(mas, size); + if (mas_is_err(mas)) + return xa_err(mas->node); + + if (mas->offset == MAPLE_NODE_SLOTS) + goto no_gap; + + /* + * At this point, mas->node points to the right node and we have an + * offset that has a sufficient gap. + */ + min = mas->min; + if (mas->offset) + min = mte_pivot(mas->node, mas->offset - 1) + 1; + + if (mas->index < min) + mas->index = min; + + mas_fill_gap(mas, entry, mas->offset, size, index); + return 0; + +no_gap: + return -EBUSY; +} + +static inline int mas_rev_alloc(struct ma_state *mas, unsigned long min, + unsigned long max, void *entry, + unsigned long size, unsigned long *index) +{ + int ret = 0; + + ret = mas_empty_area_rev(mas, min, max, size); + if (ret) + return ret; + + if (mas_is_err(mas)) + return xa_err(mas->node); + + if (mas->offset == MAPLE_NODE_SLOTS) + goto no_gap; + + mas_fill_gap(mas, entry, mas->offset, size, index); + return 0; + +no_gap: + return -EBUSY; +} + +/* + * mas_dead_leaves() - Mark all leaves of a node as dead. + * @mas: The maple state + * @slots: Pointer to the slot array + * + * Must hold the write lock. + * + * Return: The number of leaves marked as dead. + */ +static inline +unsigned char mas_dead_leaves(struct ma_state *mas, void __rcu **slots) +{ + struct maple_node *node; + enum maple_type type; + void *entry; + int offset; + + for (offset = 0; offset < mt_slot_count(mas->node); offset++) { + entry = mas_slot_locked(mas, slots, offset); + type = mte_node_type(entry); + node = mte_to_node(entry); + /* Use both node and type to catch LE & BE metadata */ + if (!node || !type) + break; + + mte_set_node_dead(entry); + smp_wmb(); /* Needed for RCU */ + node->type = type; + rcu_assign_pointer(slots[offset], node); + } + + return offset; +} + +static void __rcu **mas_dead_walk(struct ma_state *mas, unsigned char offset) +{ + struct maple_node *node, *next; + void __rcu **slots = NULL; + + next = mas_mn(mas); + do { + mas->node = ma_enode_ptr(next); + node = mas_mn(mas); + slots = ma_slots(node, node->type); + next = mas_slot_locked(mas, slots, offset); + offset = 0; + } while (!ma_is_leaf(next->type)); + + return slots; +} + +static void mt_free_walk(struct rcu_head *head) +{ + void __rcu **slots; + struct maple_node *node, *start; + struct maple_tree mt; + unsigned char offset; + enum maple_type type; + MA_STATE(mas, &mt, 0, 0); + + node = container_of(head, struct maple_node, rcu); + + if (ma_is_leaf(node->type)) + goto free_leaf; + + mt_init_flags(&mt, node->ma_flags); + mas_lock(&mas); + start = node; + mas.node = mt_mk_node(node, node->type); + slots = mas_dead_walk(&mas, 0); + node = mas_mn(&mas); + do { + mt_free_bulk(node->slot_len, slots); + offset = node->parent_slot + 1; + mas.node = node->piv_parent; + if (mas_mn(&mas) == node) + goto start_slots_free; + + type = mte_node_type(mas.node); + slots = ma_slots(mte_to_node(mas.node), type); + if ((offset < mt_slots[type]) && (slots[offset])) + slots = mas_dead_walk(&mas, offset); + + node = mas_mn(&mas); + } while ((node != start) || (node->slot_len < offset)); + + slots = ma_slots(node, node->type); + mt_free_bulk(node->slot_len, slots); + +start_slots_free: + mas_unlock(&mas); +free_leaf: + mt_free_rcu(&node->rcu); +} + +static inline void __rcu **mas_destroy_descend(struct ma_state *mas, + struct maple_enode *prev, unsigned char offset) +{ + struct maple_node *node; + struct maple_enode *next = mas->node; + void __rcu **slots = NULL; + + do { + mas->node = next; + node = mas_mn(mas); + slots = ma_slots(node, mte_node_type(mas->node)); + next = mas_slot_locked(mas, slots, 0); + if ((mte_dead_node(next))) + next = mas_slot_locked(mas, slots, 1); + + mte_set_node_dead(mas->node); + node->type = mte_node_type(mas->node); + node->piv_parent = prev; + node->parent_slot = offset; + offset = 0; + prev = mas->node; + } while (!mte_is_leaf(next)); + + return slots; +} + +static void mt_destroy_walk(struct maple_enode *enode, unsigned char ma_flags, + bool free) +{ + void __rcu **slots; + struct maple_node *node = mte_to_node(enode); + struct maple_enode *start; + struct maple_tree mt; + + MA_STATE(mas, &mt, 0, 0); + + if (mte_is_leaf(enode)) + goto free_leaf; + + mt_init_flags(&mt, ma_flags); + mas_lock(&mas); + + mas.node = start = enode; + slots = mas_destroy_descend(&mas, start, 0); + node = mas_mn(&mas); + do { + enum maple_type type; + unsigned char offset; + struct maple_enode *parent, *tmp; + + node->slot_len = mas_dead_leaves(&mas, slots); + if (free) + mt_free_bulk(node->slot_len, slots); + offset = node->parent_slot + 1; + mas.node = node->piv_parent; + if (mas_mn(&mas) == node) + goto start_slots_free; + + type = mte_node_type(mas.node); + slots = ma_slots(mte_to_node(mas.node), type); + if (offset >= mt_slots[type]) + goto next; + + tmp = mas_slot_locked(&mas, slots, offset); + if (mte_node_type(tmp) && mte_to_node(tmp)) { + parent = mas.node; + mas.node = tmp; + slots = mas_destroy_descend(&mas, parent, offset); + } +next: + node = mas_mn(&mas); + } while (start != mas.node); + + node = mas_mn(&mas); + node->slot_len = mas_dead_leaves(&mas, slots); + if (free) + mt_free_bulk(node->slot_len, slots); + +start_slots_free: + mas_unlock(&mas); + +free_leaf: + if (free) + mt_free_rcu(&node->rcu); +} + +/* + * mte_destroy_walk() - Free a tree or sub-tree. + * @enode - the encoded maple node (maple_enode) to start + * @mn - the tree to free - needed for node types. + * + * Must hold the write lock. + */ +static inline void mte_destroy_walk(struct maple_enode *enode, + struct maple_tree *mt) +{ + struct maple_node *node = mte_to_node(enode); + + if (mt_in_rcu(mt)) { + mt_destroy_walk(enode, mt->ma_flags, false); + call_rcu(&node->rcu, mt_free_walk); + } else { + mt_destroy_walk(enode, mt->ma_flags, true); + } +} + +static void mas_wr_store_setup(struct ma_wr_state *wr_mas) +{ + if (!mas_is_start(wr_mas->mas)) { + if (mas_is_none(wr_mas->mas)) { + mas_reset(wr_mas->mas); + } else { + wr_mas->r_max = wr_mas->mas->max; + wr_mas->type = mte_node_type(wr_mas->mas->node); + if (mas_is_span_wr(wr_mas)) + mas_reset(wr_mas->mas); + } + } + +} + +/* Interface */ + +/** + * mas_store() - Store an @entry. + * @mas: The maple state. + * @entry: The entry to store. + * + * The @mas->index and @mas->last is used to set the range for the @entry. + * Note: The @mas should have pre-allocated entries to ensure there is memory to + * store the entry. Please see mas_expected_entries()/mas_destroy() for more details. + * + * Return: the first entry between mas->index and mas->last or %NULL. + */ +void *mas_store(struct ma_state *mas, void *entry) +{ + MA_WR_STATE(wr_mas, mas, entry); + + trace_ma_write(__func__, mas, 0, entry); +#ifdef CONFIG_DEBUG_MAPLE_TREE + if (mas->index > mas->last) + pr_err("Error %lu > %lu %p\n", mas->index, mas->last, entry); + MT_BUG_ON(mas->tree, mas->index > mas->last); + if (mas->index > mas->last) { + mas_set_err(mas, -EINVAL); + return NULL; + } + +#endif + + /* + * Storing is the same operation as insert with the added caveat that it + * can overwrite entries. Although this seems simple enough, one may + * want to examine what happens if a single store operation was to + * overwrite multiple entries within a self-balancing B-Tree. + */ + mas_wr_store_setup(&wr_mas); + mas_wr_store_entry(&wr_mas); + return wr_mas.content; +} +EXPORT_SYMBOL_GPL(mas_store); + +/** + * mas_store_gfp() - Store a value into the tree. + * @mas: The maple state + * @entry: The entry to store + * @gfp: The GFP_FLAGS to use for allocations if necessary. + * + * Return: 0 on success, -EINVAL on invalid request, -ENOMEM if memory could not + * be allocated. + */ +int mas_store_gfp(struct ma_state *mas, void *entry, gfp_t gfp) +{ + MA_WR_STATE(wr_mas, mas, entry); + + mas_wr_store_setup(&wr_mas); + trace_ma_write(__func__, mas, 0, entry); +retry: + mas_wr_store_entry(&wr_mas); + if (unlikely(mas_nomem(mas, gfp))) + goto retry; + + if (unlikely(mas_is_err(mas))) + return xa_err(mas->node); + + return 0; +} +EXPORT_SYMBOL_GPL(mas_store_gfp); + +/** + * mas_store_prealloc() - Store a value into the tree using memory + * preallocated in the maple state. + * @mas: The maple state + * @entry: The entry to store. + */ +void mas_store_prealloc(struct ma_state *mas, void *entry) +{ + MA_WR_STATE(wr_mas, mas, entry); + + mas_wr_store_setup(&wr_mas); + trace_ma_write(__func__, mas, 0, entry); + mas_wr_store_entry(&wr_mas); + BUG_ON(mas_is_err(mas)); + mas_destroy(mas); +} +EXPORT_SYMBOL_GPL(mas_store_prealloc); + +/** + * mas_preallocate() - Preallocate enough nodes for a store operation + * @mas: The maple state + * @entry: The entry that will be stored + * @gfp: The GFP_FLAGS to use for allocations. + * + * Return: 0 on success, -ENOMEM if memory could not be allocated. + */ +int mas_preallocate(struct ma_state *mas, void *entry, gfp_t gfp) +{ + int ret; + + mas_node_count_gfp(mas, 1 + mas_mt_height(mas) * 3, gfp); + mas->mas_flags |= MA_STATE_PREALLOC; + if (likely(!mas_is_err(mas))) + return 0; + + mas_set_alloc_req(mas, 0); + ret = xa_err(mas->node); + mas_reset(mas); + mas_destroy(mas); + mas_reset(mas); + return ret; +} + +/* + * mas_destroy() - destroy a maple state. + * @mas: The maple state + * + * Upon completion, check the left-most node and rebalance against the node to + * the right if necessary. Frees any allocated nodes associated with this maple + * state. + */ +void mas_destroy(struct ma_state *mas) +{ + struct maple_alloc *node; + + /* + * When using mas_for_each() to insert an expected number of elements, + * it is possible that the number inserted is less than the expected + * number. To fix an invalid final node, a check is performed here to + * rebalance the previous node with the final node. + */ + if (mas->mas_flags & MA_STATE_REBALANCE) { + unsigned char end; + + if (mas_is_start(mas)) + mas_start(mas); + + mtree_range_walk(mas); + end = mas_data_end(mas) + 1; + if (end < mt_min_slot_count(mas->node) - 1) + mas_destroy_rebalance(mas, end); + + mas->mas_flags &= ~MA_STATE_REBALANCE; + } + mas->mas_flags &= ~(MA_STATE_BULK|MA_STATE_PREALLOC); + + while (mas->alloc && !((unsigned long)mas->alloc & 0x1)) { + node = mas->alloc; + mas->alloc = node->slot[0]; + if (node->node_count > 0) + mt_free_bulk(node->node_count, + (void __rcu **)&node->slot[1]); + kmem_cache_free(maple_node_cache, node); + } + mas->alloc = NULL; +} +EXPORT_SYMBOL_GPL(mas_destroy); + +/* + * mas_expected_entries() - Set the expected number of entries that will be inserted. + * @mas: The maple state + * @nr_entries: The number of expected entries. + * + * This will attempt to pre-allocate enough nodes to store the expected number + * of entries. The allocations will occur using the bulk allocator interface + * for speed. Please call mas_destroy() on the @mas after inserting the entries + * to ensure any unused nodes are freed. + * + * Return: 0 on success, -ENOMEM if memory could not be allocated. + */ +int mas_expected_entries(struct ma_state *mas, unsigned long nr_entries) +{ + int nonleaf_cap = MAPLE_ARANGE64_SLOTS - 2; + struct maple_enode *enode = mas->node; + int nr_nodes; + int ret; + + /* + * Sometimes it is necessary to duplicate a tree to a new tree, such as + * forking a process and duplicating the VMAs from one tree to a new + * tree. When such a situation arises, it is known that the new tree is + * not going to be used until the entire tree is populated. For + * performance reasons, it is best to use a bulk load with RCU disabled. + * This allows for optimistic splitting that favours the left and reuse + * of nodes during the operation. + */ + + /* Optimize splitting for bulk insert in-order */ + mas->mas_flags |= MA_STATE_BULK; + + /* + * Avoid overflow, assume a gap between each entry and a trailing null. + * If this is wrong, it just means allocation can happen during + * insertion of entries. + */ + nr_nodes = max(nr_entries, nr_entries * 2 + 1); + if (!mt_is_alloc(mas->tree)) + nonleaf_cap = MAPLE_RANGE64_SLOTS - 2; + + /* Leaves; reduce slots to keep space for expansion */ + nr_nodes = DIV_ROUND_UP(nr_nodes, MAPLE_RANGE64_SLOTS - 2); + /* Internal nodes */ + nr_nodes += DIV_ROUND_UP(nr_nodes, nonleaf_cap); + /* Add working room for split (2 nodes) + new parents */ + mas_node_count(mas, nr_nodes + 3); + + /* Detect if allocations run out */ + mas->mas_flags |= MA_STATE_PREALLOC; + + if (!mas_is_err(mas)) + return 0; + + ret = xa_err(mas->node); + mas->node = enode; + mas_destroy(mas); + return ret; + +} +EXPORT_SYMBOL_GPL(mas_expected_entries); + +/** + * mas_next() - Get the next entry. + * @mas: The maple state + * @max: The maximum index to check. + * + * Returns the next entry after @mas->index. + * Must hold rcu_read_lock or the write lock. + * Can return the zero entry. + * + * Return: The next entry or %NULL + */ +void *mas_next(struct ma_state *mas, unsigned long max) +{ + if (mas_is_none(mas) || mas_is_paused(mas)) + mas->node = MAS_START; + + if (mas_is_start(mas)) + mas_walk(mas); /* Retries on dead nodes handled by mas_walk */ + + if (mas_is_ptr(mas)) { + if (!mas->index) { + mas->index = 1; + mas->last = ULONG_MAX; + } + return NULL; + } + + if (mas->last == ULONG_MAX) + return NULL; + + /* Retries on dead nodes handled by mas_next_entry */ + return mas_next_entry(mas, max); +} +EXPORT_SYMBOL_GPL(mas_next); + +/** + * mt_next() - get the next value in the maple tree + * @mt: The maple tree + * @index: The start index + * @max: The maximum index to check + * + * Return: The entry at @index or higher, or %NULL if nothing is found. + */ +void *mt_next(struct maple_tree *mt, unsigned long index, unsigned long max) +{ + void *entry = NULL; + MA_STATE(mas, mt, index, index); + + rcu_read_lock(); + entry = mas_next(&mas, max); + rcu_read_unlock(); + return entry; +} +EXPORT_SYMBOL_GPL(mt_next); + +/** + * mas_prev() - Get the previous entry + * @mas: The maple state + * @min: The minimum value to check. + * + * Must hold rcu_read_lock or the write lock. + * Will reset mas to MAS_START if the node is MAS_NONE. Will stop on not + * searchable nodes. + * + * Return: the previous value or %NULL. + */ +void *mas_prev(struct ma_state *mas, unsigned long min) +{ + if (!mas->index) { + /* Nothing comes before 0 */ + mas->last = 0; + return NULL; + } + + if (unlikely(mas_is_ptr(mas))) + return NULL; + + if (mas_is_none(mas) || mas_is_paused(mas)) + mas->node = MAS_START; + + if (mas_is_start(mas)) { + mas_walk(mas); + if (!mas->index) + return NULL; + } + + if (mas_is_ptr(mas)) { + if (!mas->index) { + mas->last = 0; + return NULL; + } + + mas->index = mas->last = 0; + return mas_root_locked(mas); + } + return mas_prev_entry(mas, min); +} +EXPORT_SYMBOL_GPL(mas_prev); + +/** + * mt_prev() - get the previous value in the maple tree + * @mt: The maple tree + * @index: The start index + * @min: The minimum index to check + * + * Return: The entry at @index or lower, or %NULL if nothing is found. + */ +void *mt_prev(struct maple_tree *mt, unsigned long index, unsigned long min) +{ + void *entry = NULL; + MA_STATE(mas, mt, index, index); + + rcu_read_lock(); + entry = mas_prev(&mas, min); + rcu_read_unlock(); + return entry; +} +EXPORT_SYMBOL_GPL(mt_prev); + +/** + * mas_pause() - Pause a mas_find/mas_for_each to drop the lock. + * @mas: The maple state to pause + * + * Some users need to pause a walk and drop the lock they're holding in + * order to yield to a higher priority thread or carry out an operation + * on an entry. Those users should call this function before they drop + * the lock. It resets the @mas to be suitable for the next iteration + * of the loop after the user has reacquired the lock. If most entries + * found during a walk require you to call mas_pause(), the mt_for_each() + * iterator may be more appropriate. + * + */ +void mas_pause(struct ma_state *mas) +{ + mas->node = MAS_PAUSE; +} +EXPORT_SYMBOL_GPL(mas_pause); + +/** + * mas_find() - On the first call, find the entry at or after mas->index up to + * %max. Otherwise, find the entry after mas->index. + * @mas: The maple state + * @max: The maximum value to check. + * + * Must hold rcu_read_lock or the write lock. + * If an entry exists, last and index are updated accordingly. + * May set @mas->node to MAS_NONE. + * + * Return: The entry or %NULL. + */ +void *mas_find(struct ma_state *mas, unsigned long max) +{ + if (unlikely(mas_is_paused(mas))) { + if (unlikely(mas->last == ULONG_MAX)) { + mas->node = MAS_NONE; + return NULL; + } + mas->node = MAS_START; + mas->index = ++mas->last; + } + + if (unlikely(mas_is_start(mas))) { + /* First run or continue */ + void *entry; + + if (mas->index > max) + return NULL; + + entry = mas_walk(mas); + if (entry) + return entry; + } + + if (unlikely(!mas_searchable(mas))) + return NULL; + + /* Retries on dead nodes handled by mas_next_entry */ + return mas_next_entry(mas, max); +} +EXPORT_SYMBOL_GPL(mas_find); + +/** + * mas_find_rev: On the first call, find the first non-null entry at or below + * mas->index down to %min. Otherwise find the first non-null entry below + * mas->index down to %min. + * @mas: The maple state + * @min: The minimum value to check. + * + * Must hold rcu_read_lock or the write lock. + * If an entry exists, last and index are updated accordingly. + * May set @mas->node to MAS_NONE. + * + * Return: The entry or %NULL. + */ +void *mas_find_rev(struct ma_state *mas, unsigned long min) +{ + if (unlikely(mas_is_paused(mas))) { + if (unlikely(mas->last == ULONG_MAX)) { + mas->node = MAS_NONE; + return NULL; + } + mas->node = MAS_START; + mas->last = --mas->index; + } + + if (unlikely(mas_is_start(mas))) { + /* First run or continue */ + void *entry; + + if (mas->index < min) + return NULL; + + entry = mas_walk(mas); + if (entry) + return entry; + } + + if (unlikely(!mas_searchable(mas))) + return NULL; + + if (mas->index < min) + return NULL; + + /* Retries on dead nodes handled by mas_prev_entry */ + return mas_prev_entry(mas, min); +} +EXPORT_SYMBOL_GPL(mas_find_rev); + +/** + * mas_erase() - Find the range in which index resides and erase the entire + * range. + * @mas: The maple state + * + * Must hold the write lock. + * Searches for @mas->index, sets @mas->index and @mas->last to the range and + * erases that range. + * + * Return: the entry that was erased or %NULL, @mas->index and @mas->last are updated. + */ +void *mas_erase(struct ma_state *mas) +{ + void *entry; + MA_WR_STATE(wr_mas, mas, NULL); + + if (mas_is_none(mas) || mas_is_paused(mas)) + mas->node = MAS_START; + + /* Retry unnecessary when holding the write lock. */ + entry = mas_state_walk(mas); + if (!entry) + return NULL; + +write_retry: + /* Must reset to ensure spanning writes of last slot are detected */ + mas_reset(mas); + mas_wr_store_setup(&wr_mas); + mas_wr_store_entry(&wr_mas); + if (mas_nomem(mas, GFP_KERNEL)) + goto write_retry; + + return entry; +} +EXPORT_SYMBOL_GPL(mas_erase); + +/** + * mas_nomem() - Check if there was an error allocating and do the allocation + * if necessary If there are allocations, then free them. + * @mas: The maple state + * @gfp: The GFP_FLAGS to use for allocations + * Return: true on allocation, false otherwise. + */ +bool mas_nomem(struct ma_state *mas, gfp_t gfp) + __must_hold(mas->tree->lock) +{ + if (likely(mas->node != MA_ERROR(-ENOMEM))) { + mas_destroy(mas); + return false; + } + + if (gfpflags_allow_blocking(gfp) && !mt_external_lock(mas->tree)) { + mtree_unlock(mas->tree); + mas_alloc_nodes(mas, gfp); + mtree_lock(mas->tree); + } else { + mas_alloc_nodes(mas, gfp); + } + + if (!mas_allocated(mas)) + return false; + + mas->node = MAS_START; + return true; +} + +void __init maple_tree_init(void) +{ + maple_node_cache = kmem_cache_create("maple_node", + sizeof(struct maple_node), sizeof(struct maple_node), + SLAB_PANIC, NULL); +} + +/** + * mtree_load() - Load a value stored in a maple tree + * @mt: The maple tree + * @index: The index to load + * + * Return: the entry or %NULL + */ +void *mtree_load(struct maple_tree *mt, unsigned long index) +{ + MA_STATE(mas, mt, index, index); + void *entry; + + trace_ma_read(__func__, &mas); + rcu_read_lock(); +retry: + entry = mas_start(&mas); + if (unlikely(mas_is_none(&mas))) + goto unlock; + + if (unlikely(mas_is_ptr(&mas))) { + if (index) + entry = NULL; + + goto unlock; + } + + entry = mtree_lookup_walk(&mas); + if (!entry && unlikely(mas_is_start(&mas))) + goto retry; +unlock: + rcu_read_unlock(); + if (xa_is_zero(entry)) + return NULL; + + return entry; +} +EXPORT_SYMBOL(mtree_load); + +/** + * mtree_store_range() - Store an entry at a given range. + * @mt: The maple tree + * @index: The start of the range + * @last: The end of the range + * @entry: The entry to store + * @gfp: The GFP_FLAGS to use for allocations + * + * Return: 0 on success, -EINVAL on invalid request, -ENOMEM if memory could not + * be allocated. + */ +int mtree_store_range(struct maple_tree *mt, unsigned long index, + unsigned long last, void *entry, gfp_t gfp) +{ + MA_STATE(mas, mt, index, last); + MA_WR_STATE(wr_mas, &mas, entry); + + trace_ma_write(__func__, &mas, 0, entry); + if (WARN_ON_ONCE(xa_is_advanced(entry))) + return -EINVAL; + + if (index > last) + return -EINVAL; + + mtree_lock(mt); +retry: + mas_wr_store_entry(&wr_mas); + if (mas_nomem(&mas, gfp)) + goto retry; + + mtree_unlock(mt); + if (mas_is_err(&mas)) + return xa_err(mas.node); + + return 0; +} +EXPORT_SYMBOL(mtree_store_range); + +/** + * mtree_store() - Store an entry at a given index. + * @mt: The maple tree + * @index: The index to store the value + * @entry: The entry to store + * @gfp: The GFP_FLAGS to use for allocations + * + * Return: 0 on success, -EINVAL on invalid request, -ENOMEM if memory could not + * be allocated. + */ +int mtree_store(struct maple_tree *mt, unsigned long index, void *entry, + gfp_t gfp) +{ + return mtree_store_range(mt, index, index, entry, gfp); +} +EXPORT_SYMBOL(mtree_store); + +/** + * mtree_insert_range() - Insert an entry at a give range if there is no value. + * @mt: The maple tree + * @first: The start of the range + * @last: The end of the range + * @entry: The entry to store + * @gfp: The GFP_FLAGS to use for allocations. + * + * Return: 0 on success, -EEXISTS if the range is occupied, -EINVAL on invalid + * request, -ENOMEM if memory could not be allocated. + */ +int mtree_insert_range(struct maple_tree *mt, unsigned long first, + unsigned long last, void *entry, gfp_t gfp) +{ + MA_STATE(ms, mt, first, last); + + if (WARN_ON_ONCE(xa_is_advanced(entry))) + return -EINVAL; + + if (first > last) + return -EINVAL; + + mtree_lock(mt); +retry: + mas_insert(&ms, entry); + if (mas_nomem(&ms, gfp)) + goto retry; + + mtree_unlock(mt); + if (mas_is_err(&ms)) + return xa_err(ms.node); + + return 0; +} +EXPORT_SYMBOL(mtree_insert_range); + +/** + * mtree_insert() - Insert an entry at a give index if there is no value. + * @mt: The maple tree + * @index : The index to store the value + * @entry: The entry to store + * @gfp: The FGP_FLAGS to use for allocations. + * + * Return: 0 on success, -EEXISTS if the range is occupied, -EINVAL on invalid + * request, -ENOMEM if memory could not be allocated. + */ +int mtree_insert(struct maple_tree *mt, unsigned long index, void *entry, + gfp_t gfp) +{ + return mtree_insert_range(mt, index, index, entry, gfp); +} +EXPORT_SYMBOL(mtree_insert); + +int mtree_alloc_range(struct maple_tree *mt, unsigned long *startp, + void *entry, unsigned long size, unsigned long min, + unsigned long max, gfp_t gfp) +{ + int ret = 0; + + MA_STATE(mas, mt, min, max - size); + if (!mt_is_alloc(mt)) + return -EINVAL; + + if (WARN_ON_ONCE(mt_is_reserved(entry))) + return -EINVAL; + + if (min > max) + return -EINVAL; + + if (max < size) + return -EINVAL; + + if (!size) + return -EINVAL; + + mtree_lock(mt); +retry: + mas.offset = 0; + mas.index = min; + mas.last = max - size; + ret = mas_alloc(&mas, entry, size, startp); + if (mas_nomem(&mas, gfp)) + goto retry; + + mtree_unlock(mt); + return ret; +} +EXPORT_SYMBOL(mtree_alloc_range); + +int mtree_alloc_rrange(struct maple_tree *mt, unsigned long *startp, + void *entry, unsigned long size, unsigned long min, + unsigned long max, gfp_t gfp) +{ + int ret = 0; + + MA_STATE(mas, mt, min, max - size); + if (!mt_is_alloc(mt)) + return -EINVAL; + + if (WARN_ON_ONCE(mt_is_reserved(entry))) + return -EINVAL; + + if (min >= max) + return -EINVAL; + + if (max < size - 1) + return -EINVAL; + + if (!size) + return -EINVAL; + + mtree_lock(mt); +retry: + ret = mas_rev_alloc(&mas, min, max, entry, size, startp); + if (mas_nomem(&mas, gfp)) + goto retry; + + mtree_unlock(mt); + return ret; +} +EXPORT_SYMBOL(mtree_alloc_rrange); + +/** + * mtree_erase() - Find an index and erase the entire range. + * @mt: The maple tree + * @index: The index to erase + * + * Erasing is the same as a walk to an entry then a store of a NULL to that + * ENTIRE range. In fact, it is implemented as such using the advanced API. + * + * Return: The entry stored at the @index or %NULL + */ +void *mtree_erase(struct maple_tree *mt, unsigned long index) +{ + void *entry = NULL; + + MA_STATE(mas, mt, index, index); + trace_ma_op(__func__, &mas); + + mtree_lock(mt); + entry = mas_erase(&mas); + mtree_unlock(mt); + + return entry; +} +EXPORT_SYMBOL(mtree_erase); + +/** + * __mt_destroy() - Walk and free all nodes of a locked maple tree. + * @mt: The maple tree + * + * Note: Does not handle locking. + */ +void __mt_destroy(struct maple_tree *mt) +{ + void *root = mt_root_locked(mt); + + rcu_assign_pointer(mt->ma_root, NULL); + if (xa_is_node(root)) + mte_destroy_walk(root, mt); + + mt->ma_flags = 0; +} +EXPORT_SYMBOL_GPL(__mt_destroy); + +/** + * mtree_destroy() - Destroy a maple tree + * @mt: The maple tree + * + * Frees all resources used by the tree. Handles locking. + */ +void mtree_destroy(struct maple_tree *mt) +{ + mtree_lock(mt); + __mt_destroy(mt); + mtree_unlock(mt); +} +EXPORT_SYMBOL(mtree_destroy); + +/** + * mt_find() - Search from the start up until an entry is found. + * @mt: The maple tree + * @index: Pointer which contains the start location of the search + * @max: The maximum value to check + * + * Handles locking. @index will be incremented to one beyond the range. + * + * Return: The entry at or after the @index or %NULL + */ +void *mt_find(struct maple_tree *mt, unsigned long *index, unsigned long max) +{ + MA_STATE(mas, mt, *index, *index); + void *entry; +#ifdef CONFIG_DEBUG_MAPLE_TREE + unsigned long copy = *index; +#endif + + trace_ma_read(__func__, &mas); + + if ((*index) > max) + return NULL; + + rcu_read_lock(); +retry: + entry = mas_state_walk(&mas); + if (mas_is_start(&mas)) + goto retry; + + if (unlikely(xa_is_zero(entry))) + entry = NULL; + + if (entry) + goto unlock; + + while (mas_searchable(&mas) && (mas.index < max)) { + entry = mas_next_entry(&mas, max); + if (likely(entry && !xa_is_zero(entry))) + break; + } + + if (unlikely(xa_is_zero(entry))) + entry = NULL; +unlock: + rcu_read_unlock(); + if (likely(entry)) { + *index = mas.last + 1; +#ifdef CONFIG_DEBUG_MAPLE_TREE + if ((*index) && (*index) <= copy) + pr_err("index not increased! %lx <= %lx\n", + *index, copy); + MT_BUG_ON(mt, (*index) && ((*index) <= copy)); +#endif + } + + return entry; +} +EXPORT_SYMBOL(mt_find); + +/** + * mt_find_after() - Search from the start up until an entry is found. + * @mt: The maple tree + * @index: Pointer which contains the start location of the search + * @max: The maximum value to check + * + * Handles locking, detects wrapping on index == 0 + * + * Return: The entry at or after the @index or %NULL + */ +void *mt_find_after(struct maple_tree *mt, unsigned long *index, + unsigned long max) +{ + if (!(*index)) + return NULL; + + return mt_find(mt, index, max); +} +EXPORT_SYMBOL(mt_find_after); + +#ifdef CONFIG_DEBUG_MAPLE_TREE +atomic_t maple_tree_tests_run; +EXPORT_SYMBOL_GPL(maple_tree_tests_run); +atomic_t maple_tree_tests_passed; +EXPORT_SYMBOL_GPL(maple_tree_tests_passed); + +#ifndef __KERNEL__ +extern void kmem_cache_set_non_kernel(struct kmem_cache *, unsigned int); +void mt_set_non_kernel(unsigned int val) +{ + kmem_cache_set_non_kernel(maple_node_cache, val); +} + +extern unsigned long kmem_cache_get_alloc(struct kmem_cache *); +unsigned long mt_get_alloc_size(void) +{ + return kmem_cache_get_alloc(maple_node_cache); +} + +extern void kmem_cache_zero_nr_tallocated(struct kmem_cache *); +void mt_zero_nr_tallocated(void) +{ + kmem_cache_zero_nr_tallocated(maple_node_cache); +} + +extern unsigned int kmem_cache_nr_tallocated(struct kmem_cache *); +unsigned int mt_nr_tallocated(void) +{ + return kmem_cache_nr_tallocated(maple_node_cache); +} + +extern unsigned int kmem_cache_nr_allocated(struct kmem_cache *); +unsigned int mt_nr_allocated(void) +{ + return kmem_cache_nr_allocated(maple_node_cache); +} + +/* + * mas_dead_node() - Check if the maple state is pointing to a dead node. + * @mas: The maple state + * @index: The index to restore in @mas. + * + * Used in test code. + * Return: 1 if @mas has been reset to MAS_START, 0 otherwise. + */ +static inline int mas_dead_node(struct ma_state *mas, unsigned long index) +{ + if (unlikely(!mas_searchable(mas) || mas_is_start(mas))) + return 0; + + if (likely(!mte_dead_node(mas->node))) + return 0; + + mas_rewalk(mas, index); + return 1; +} + +void mt_cache_shrink(void) +{ +} +#else +/* + * mt_cache_shrink() - For testing, don't use this. + * + * Certain testcases can trigger an OOM when combined with other memory + * debugging configuration options. This function is used to reduce the + * possibility of an out of memory even due to kmem_cache objects remaining + * around for longer than usual. + */ +void mt_cache_shrink(void) +{ + kmem_cache_shrink(maple_node_cache); + +} +EXPORT_SYMBOL_GPL(mt_cache_shrink); + +#endif /* not defined __KERNEL__ */ +/* + * mas_get_slot() - Get the entry in the maple state node stored at @offset. + * @mas: The maple state + * @offset: The offset into the slot array to fetch. + * + * Return: The entry stored at @offset. + */ +static inline struct maple_enode *mas_get_slot(struct ma_state *mas, + unsigned char offset) +{ + return mas_slot(mas, ma_slots(mas_mn(mas), mte_node_type(mas->node)), + offset); +} + + +/* + * mas_first_entry() - Go the first leaf and find the first entry. + * @mas: the maple state. + * @limit: the maximum index to check. + * @*r_start: Pointer to set to the range start. + * + * Sets mas->offset to the offset of the entry, r_start to the range minimum. + * + * Return: The first entry or MAS_NONE. + */ +static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn, + unsigned long limit, enum maple_type mt) + +{ + unsigned long max; + unsigned long *pivots; + void __rcu **slots; + void *entry = NULL; + + mas->index = mas->min; + if (mas->index > limit) + goto none; + + max = mas->max; + mas->offset = 0; + while (likely(!ma_is_leaf(mt))) { + MT_BUG_ON(mas->tree, mte_dead_node(mas->node)); + slots = ma_slots(mn, mt); + pivots = ma_pivots(mn, mt); + max = pivots[0]; + entry = mas_slot(mas, slots, 0); + if (unlikely(ma_dead_node(mn))) + return NULL; + mas->node = entry; + mn = mas_mn(mas); + mt = mte_node_type(mas->node); + } + MT_BUG_ON(mas->tree, mte_dead_node(mas->node)); + + mas->max = max; + slots = ma_slots(mn, mt); + entry = mas_slot(mas, slots, 0); + if (unlikely(ma_dead_node(mn))) + return NULL; + + /* Slot 0 or 1 must be set */ + if (mas->index > limit) + goto none; + + if (likely(entry)) + return entry; + + pivots = ma_pivots(mn, mt); + mas->index = pivots[0] + 1; + mas->offset = 1; + entry = mas_slot(mas, slots, 1); + if (unlikely(ma_dead_node(mn))) + return NULL; + + if (mas->index > limit) + goto none; + + if (likely(entry)) + return entry; + +none: + if (likely(!ma_dead_node(mn))) + mas->node = MAS_NONE; + return NULL; +} + +/* Depth first search, post-order */ +static void mas_dfs_postorder(struct ma_state *mas, unsigned long max) +{ + + struct maple_enode *p = MAS_NONE, *mn = mas->node; + unsigned long p_min, p_max; + + mas_next_node(mas, mas_mn(mas), max); + if (!mas_is_none(mas)) + return; + + if (mte_is_root(mn)) + return; + + mas->node = mn; + mas_ascend(mas); + while (mas->node != MAS_NONE) { + p = mas->node; + p_min = mas->min; + p_max = mas->max; + mas_prev_node(mas, 0); + } + + if (p == MAS_NONE) + return; + + mas->node = p; + mas->max = p_max; + mas->min = p_min; +} + +/* Tree validations */ +static void mt_dump_node(const struct maple_tree *mt, void *entry, + unsigned long min, unsigned long max, unsigned int depth); +static void mt_dump_range(unsigned long min, unsigned long max, + unsigned int depth) +{ + static const char spaces[] = " "; + + if (min == max) + pr_info("%.*s%lu: ", depth * 2, spaces, min); + else + pr_info("%.*s%lu-%lu: ", depth * 2, spaces, min, max); +} + +static void mt_dump_entry(void *entry, unsigned long min, unsigned long max, + unsigned int depth) +{ + mt_dump_range(min, max, depth); + + if (xa_is_value(entry)) + pr_cont("value %ld (0x%lx) [%p]\n", xa_to_value(entry), + xa_to_value(entry), entry); + else if (xa_is_zero(entry)) + pr_cont("zero (%ld)\n", xa_to_internal(entry)); + else if (mt_is_reserved(entry)) + pr_cont("UNKNOWN ENTRY (%p)\n", entry); + else + pr_cont("%p\n", entry); +} + +static void mt_dump_range64(const struct maple_tree *mt, void *entry, + unsigned long min, unsigned long max, unsigned int depth) +{ + struct maple_range_64 *node = &mte_to_node(entry)->mr64; + bool leaf = mte_is_leaf(entry); + unsigned long first = min; + int i; + + pr_cont(" contents: "); + for (i = 0; i < MAPLE_RANGE64_SLOTS - 1; i++) + pr_cont("%p %lu ", node->slot[i], node->pivot[i]); + pr_cont("%p\n", node->slot[i]); + for (i = 0; i < MAPLE_RANGE64_SLOTS; i++) { + unsigned long last = max; + + if (i < (MAPLE_RANGE64_SLOTS - 1)) + last = node->pivot[i]; + else if (!node->slot[i] && max != mt_max[mte_node_type(entry)]) + break; + if (last == 0 && i > 0) + break; + if (leaf) + mt_dump_entry(mt_slot(mt, node->slot, i), + first, last, depth + 1); + else if (node->slot[i]) + mt_dump_node(mt, mt_slot(mt, node->slot, i), + first, last, depth + 1); + + if (last == max) + break; + if (last > max) { + pr_err("node %p last (%lu) > max (%lu) at pivot %d!\n", + node, last, max, i); + break; + } + first = last + 1; + } +} + +static void mt_dump_arange64(const struct maple_tree *mt, void *entry, + unsigned long min, unsigned long max, unsigned int depth) +{ + struct maple_arange_64 *node = &mte_to_node(entry)->ma64; + bool leaf = mte_is_leaf(entry); + unsigned long first = min; + int i; + + pr_cont(" contents: "); + for (i = 0; i < MAPLE_ARANGE64_SLOTS; i++) + pr_cont("%lu ", node->gap[i]); + pr_cont("| %02X %02X| ", node->meta.end, node->meta.gap); + for (i = 0; i < MAPLE_ARANGE64_SLOTS - 1; i++) + pr_cont("%p %lu ", node->slot[i], node->pivot[i]); + pr_cont("%p\n", node->slot[i]); + for (i = 0; i < MAPLE_ARANGE64_SLOTS; i++) { + unsigned long last = max; + + if (i < (MAPLE_ARANGE64_SLOTS - 1)) + last = node->pivot[i]; + else if (!node->slot[i]) + break; + if (last == 0 && i > 0) + break; + if (leaf) + mt_dump_entry(mt_slot(mt, node->slot, i), + first, last, depth + 1); + else if (node->slot[i]) + mt_dump_node(mt, mt_slot(mt, node->slot, i), + first, last, depth + 1); + + if (last == max) + break; + if (last > max) { + pr_err("node %p last (%lu) > max (%lu) at pivot %d!\n", + node, last, max, i); + break; + } + first = last + 1; + } +} + +static void mt_dump_node(const struct maple_tree *mt, void *entry, + unsigned long min, unsigned long max, unsigned int depth) +{ + struct maple_node *node = mte_to_node(entry); + unsigned int type = mte_node_type(entry); + unsigned int i; + + mt_dump_range(min, max, depth); + + pr_cont("node %p depth %d type %d parent %p", node, depth, type, + node ? node->parent : NULL); + switch (type) { + case maple_dense: + pr_cont("\n"); + for (i = 0; i < MAPLE_NODE_SLOTS; i++) { + if (min + i > max) + pr_cont("OUT OF RANGE: "); + mt_dump_entry(mt_slot(mt, node->slot, i), + min + i, min + i, depth); + } + break; + case maple_leaf_64: + case maple_range_64: + mt_dump_range64(mt, entry, min, max, depth); + break; + case maple_arange_64: + mt_dump_arange64(mt, entry, min, max, depth); + break; + + default: + pr_cont(" UNKNOWN TYPE\n"); + } +} + +void mt_dump(const struct maple_tree *mt) +{ + void *entry = rcu_dereference_check(mt->ma_root, mt_locked(mt)); + + pr_info("maple_tree(%p) flags %X, height %u root %p\n", + mt, mt->ma_flags, mt_height(mt), entry); + if (!xa_is_node(entry)) + mt_dump_entry(entry, 0, 0, 0); + else if (entry) + mt_dump_node(mt, entry, 0, mt_max[mte_node_type(entry)], 0); +} +EXPORT_SYMBOL_GPL(mt_dump); + +/* + * Calculate the maximum gap in a node and check if that's what is reported in + * the parent (unless root). + */ +static void mas_validate_gaps(struct ma_state *mas) +{ + struct maple_enode *mte = mas->node; + struct maple_node *p_mn; + unsigned long gap = 0, max_gap = 0; + unsigned long p_end, p_start = mas->min; + unsigned char p_slot; + unsigned long *gaps = NULL; + unsigned long *pivots = ma_pivots(mte_to_node(mte), mte_node_type(mte)); + int i; + + if (ma_is_dense(mte_node_type(mte))) { + for (i = 0; i < mt_slot_count(mte); i++) { + if (mas_get_slot(mas, i)) { + if (gap > max_gap) + max_gap = gap; + gap = 0; + continue; + } + gap++; + } + goto counted; + } + + gaps = ma_gaps(mte_to_node(mte), mte_node_type(mte)); + for (i = 0; i < mt_slot_count(mte); i++) { + p_end = mas_logical_pivot(mas, pivots, i, mte_node_type(mte)); + + if (!gaps) { + if (mas_get_slot(mas, i)) { + gap = 0; + goto not_empty; + } + + gap += p_end - p_start + 1; + } else { + void *entry = mas_get_slot(mas, i); + + gap = gaps[i]; + if (!entry) { + if (gap != p_end - p_start + 1) { + pr_err("%p[%u] -> %p %lu != %lu - %lu + 1\n", + mas_mn(mas), i, + mas_get_slot(mas, i), gap, + p_end, p_start); + mt_dump(mas->tree); + + MT_BUG_ON(mas->tree, + gap != p_end - p_start + 1); + } + } else { + if (gap > p_end - p_start + 1) { + pr_err("%p[%u] %lu >= %lu - %lu + 1 (%lu)\n", + mas_mn(mas), i, gap, p_end, p_start, + p_end - p_start + 1); + MT_BUG_ON(mas->tree, + gap > p_end - p_start + 1); + } + } + } + + if (gap > max_gap) + max_gap = gap; +not_empty: + p_start = p_end + 1; + if (p_end >= mas->max) + break; + } + +counted: + if (mte_is_root(mte)) + return; + + p_slot = mte_parent_slot(mas->node); + p_mn = mte_parent(mte); + MT_BUG_ON(mas->tree, max_gap > mas->max); + if (ma_gaps(p_mn, mas_parent_enum(mas, mte))[p_slot] != max_gap) { + pr_err("gap %p[%u] != %lu\n", p_mn, p_slot, max_gap); + mt_dump(mas->tree); + } + + MT_BUG_ON(mas->tree, + ma_gaps(p_mn, mas_parent_enum(mas, mte))[p_slot] != max_gap); +} + +static void mas_validate_parent_slot(struct ma_state *mas) +{ + struct maple_node *parent; + struct maple_enode *node; + enum maple_type p_type = mas_parent_enum(mas, mas->node); + unsigned char p_slot = mte_parent_slot(mas->node); + void __rcu **slots; + int i; + + if (mte_is_root(mas->node)) + return; + + parent = mte_parent(mas->node); + slots = ma_slots(parent, p_type); + MT_BUG_ON(mas->tree, mas_mn(mas) == parent); + + /* Check prev/next parent slot for duplicate node entry */ + + for (i = 0; i < mt_slots[p_type]; i++) { + node = mas_slot(mas, slots, i); + if (i == p_slot) { + if (node != mas->node) + pr_err("parent %p[%u] does not have %p\n", + parent, i, mas_mn(mas)); + MT_BUG_ON(mas->tree, node != mas->node); + } else if (node == mas->node) { + pr_err("Invalid child %p at parent %p[%u] p_slot %u\n", + mas_mn(mas), parent, i, p_slot); + MT_BUG_ON(mas->tree, node == mas->node); + } + } +} + +static void mas_validate_child_slot(struct ma_state *mas) +{ + enum maple_type type = mte_node_type(mas->node); + void __rcu **slots = ma_slots(mte_to_node(mas->node), type); + unsigned long *pivots = ma_pivots(mte_to_node(mas->node), type); + struct maple_enode *child; + unsigned char i; + + if (mte_is_leaf(mas->node)) + return; + + for (i = 0; i < mt_slots[type]; i++) { + child = mas_slot(mas, slots, i); + if (!pivots[i] || pivots[i] == mas->max) + break; + + if (!child) + break; + + if (mte_parent_slot(child) != i) { + pr_err("Slot error at %p[%u]: child %p has pslot %u\n", + mas_mn(mas), i, mte_to_node(child), + mte_parent_slot(child)); + MT_BUG_ON(mas->tree, 1); + } + + if (mte_parent(child) != mte_to_node(mas->node)) { + pr_err("child %p has parent %p not %p\n", + mte_to_node(child), mte_parent(child), + mte_to_node(mas->node)); + MT_BUG_ON(mas->tree, 1); + } + } +} + +/* + * Validate all pivots are within mas->min and mas->max. + */ +static void mas_validate_limits(struct ma_state *mas) +{ + int i; + unsigned long prev_piv = 0; + enum maple_type type = mte_node_type(mas->node); + void __rcu **slots = ma_slots(mte_to_node(mas->node), type); + unsigned long *pivots = ma_pivots(mas_mn(mas), type); + + /* all limits are fine here. */ + if (mte_is_root(mas->node)) + return; + + for (i = 0; i < mt_slots[type]; i++) { + unsigned long piv; + + piv = mas_safe_pivot(mas, pivots, i, type); + + if (!piv && (i != 0)) + break; + + if (!mte_is_leaf(mas->node)) { + void *entry = mas_slot(mas, slots, i); + + if (!entry) + pr_err("%p[%u] cannot be null\n", + mas_mn(mas), i); + + MT_BUG_ON(mas->tree, !entry); + } + + if (prev_piv > piv) { + pr_err("%p[%u] piv %lu < prev_piv %lu\n", + mas_mn(mas), i, piv, prev_piv); + MT_BUG_ON(mas->tree, piv < prev_piv); + } + + if (piv < mas->min) { + pr_err("%p[%u] %lu < %lu\n", mas_mn(mas), i, + piv, mas->min); + MT_BUG_ON(mas->tree, piv < mas->min); + } + if (piv > mas->max) { + pr_err("%p[%u] %lu > %lu\n", mas_mn(mas), i, + piv, mas->max); + MT_BUG_ON(mas->tree, piv > mas->max); + } + prev_piv = piv; + if (piv == mas->max) + break; + } + for (i += 1; i < mt_slots[type]; i++) { + void *entry = mas_slot(mas, slots, i); + + if (entry && (i != mt_slots[type] - 1)) { + pr_err("%p[%u] should not have entry %p\n", mas_mn(mas), + i, entry); + MT_BUG_ON(mas->tree, entry != NULL); + } + + if (i < mt_pivots[type]) { + unsigned long piv = pivots[i]; + + if (!piv) + continue; + + pr_err("%p[%u] should not have piv %lu\n", + mas_mn(mas), i, piv); + MT_BUG_ON(mas->tree, i < mt_pivots[type] - 1); + } + } +} + +static void mt_validate_nulls(struct maple_tree *mt) +{ + void *entry, *last = (void *)1; + unsigned char offset = 0; + void __rcu **slots; + MA_STATE(mas, mt, 0, 0); + + mas_start(&mas); + if (mas_is_none(&mas) || (mas.node == MAS_ROOT)) + return; + + while (!mte_is_leaf(mas.node)) + mas_descend(&mas); + + slots = ma_slots(mte_to_node(mas.node), mte_node_type(mas.node)); + do { + entry = mas_slot(&mas, slots, offset); + if (!last && !entry) { + pr_err("Sequential nulls end at %p[%u]\n", + mas_mn(&mas), offset); + } + MT_BUG_ON(mt, !last && !entry); + last = entry; + if (offset == mas_data_end(&mas)) { + mas_next_node(&mas, mas_mn(&mas), ULONG_MAX); + if (mas_is_none(&mas)) + return; + offset = 0; + slots = ma_slots(mte_to_node(mas.node), + mte_node_type(mas.node)); + } else { + offset++; + } + + } while (!mas_is_none(&mas)); +} + +/* + * validate a maple tree by checking: + * 1. The limits (pivots are within mas->min to mas->max) + * 2. The gap is correctly set in the parents + */ +void mt_validate(struct maple_tree *mt) +{ + unsigned char end; + + MA_STATE(mas, mt, 0, 0); + rcu_read_lock(); + mas_start(&mas); + if (!mas_searchable(&mas)) + goto done; + + mas_first_entry(&mas, mas_mn(&mas), ULONG_MAX, mte_node_type(mas.node)); + while (!mas_is_none(&mas)) { + MT_BUG_ON(mas.tree, mte_dead_node(mas.node)); + if (!mte_is_root(mas.node)) { + end = mas_data_end(&mas); + if ((end < mt_min_slot_count(mas.node)) && + (mas.max != ULONG_MAX)) { + pr_err("Invalid size %u of %p\n", end, + mas_mn(&mas)); + MT_BUG_ON(mas.tree, 1); + } + + } + mas_validate_parent_slot(&mas); + mas_validate_child_slot(&mas); + mas_validate_limits(&mas); + if (mt_is_alloc(mt)) + mas_validate_gaps(&mas); + mas_dfs_postorder(&mas, ULONG_MAX); + } + mt_validate_nulls(mt); +done: + rcu_read_unlock(); + +} +EXPORT_SYMBOL_GPL(mt_validate); + +#endif /* CONFIG_DEBUG_MAPLE_TREE */ diff --git a/lib/test_maple_tree.c b/lib/test_maple_tree.c new file mode 100644 index 000000000000..f425f169ef08 --- /dev/null +++ b/lib/test_maple_tree.c @@ -0,0 +1,2767 @@ +// SPDX-License-Identifier: GPL-2.0+ +/* + * test_maple_tree.c: Test the maple tree API + * Copyright (c) 2018-2022 Oracle Corporation + * Author: Liam R. Howlett + * + * Any tests that only require the interface of the tree. + */ + +#include +#include + +#define MTREE_ALLOC_MAX 0x2000000000000Ul +#ifndef CONFIG_DEBUG_MAPLE_TREE +#define CONFIG_DEBUG_MAPLE_TREE +#endif +#define CONFIG_MAPLE_SEARCH +#define MAPLE_32BIT (MAPLE_NODE_SLOTS > 31) + +/* #define BENCH_SLOT_STORE */ +/* #define BENCH_NODE_STORE */ +/* #define BENCH_AWALK */ +/* #define BENCH_WALK */ +/* #define BENCH_MT_FOR_EACH */ +/* #define BENCH_FORK */ + +#ifdef __KERNEL__ +#define mt_set_non_kernel(x) do {} while (0) +#define mt_zero_nr_tallocated(x) do {} while (0) +#else +#define cond_resched() do {} while (0) +#endif +static +int mtree_insert_index(struct maple_tree *mt, unsigned long index, gfp_t gfp) +{ + return mtree_insert(mt, index, xa_mk_value(index & LONG_MAX), gfp); +} + +static void mtree_erase_index(struct maple_tree *mt, unsigned long index) +{ + MT_BUG_ON(mt, mtree_erase(mt, index) != xa_mk_value(index & LONG_MAX)); + MT_BUG_ON(mt, mtree_load(mt, index) != NULL); +} + +static int mtree_test_insert(struct maple_tree *mt, unsigned long index, + void *ptr) +{ + return mtree_insert(mt, index, ptr, GFP_KERNEL); +} + +static int mtree_test_store_range(struct maple_tree *mt, unsigned long start, + unsigned long end, void *ptr) +{ + return mtree_store_range(mt, start, end, ptr, GFP_KERNEL); +} + +static int mtree_test_store(struct maple_tree *mt, unsigned long start, + void *ptr) +{ + return mtree_test_store_range(mt, start, start, ptr); +} + +static int mtree_test_insert_range(struct maple_tree *mt, unsigned long start, + unsigned long end, void *ptr) +{ + return mtree_insert_range(mt, start, end, ptr, GFP_KERNEL); +} + +static void *mtree_test_load(struct maple_tree *mt, unsigned long index) +{ + return mtree_load(mt, index); +} + +static void *mtree_test_erase(struct maple_tree *mt, unsigned long index) +{ + return mtree_erase(mt, index); +} + +#if defined(CONFIG_64BIT) +static noinline void check_mtree_alloc_range(struct maple_tree *mt, + unsigned long start, unsigned long end, unsigned long size, + unsigned long expected, int eret, void *ptr) +{ + + unsigned long result = expected + 1; + int ret; + + ret = mtree_alloc_range(mt, &result, ptr, size, start, end, + GFP_KERNEL); + MT_BUG_ON(mt, ret != eret); + if (ret) + return; + + MT_BUG_ON(mt, result != expected); +} + +static noinline void check_mtree_alloc_rrange(struct maple_tree *mt, + unsigned long start, unsigned long end, unsigned long size, + unsigned long expected, int eret, void *ptr) +{ + + unsigned long result = expected + 1; + int ret; + + ret = mtree_alloc_rrange(mt, &result, ptr, size, start, end - 1, + GFP_KERNEL); + MT_BUG_ON(mt, ret != eret); + if (ret) + return; + + MT_BUG_ON(mt, result != expected); +} +#endif + +static noinline void check_load(struct maple_tree *mt, unsigned long index, + void *ptr) +{ + void *ret = mtree_test_load(mt, index); + + if (ret != ptr) + pr_err("Load %lu returned %p expect %p\n", index, ret, ptr); + MT_BUG_ON(mt, ret != ptr); +} + +static noinline void check_store_range(struct maple_tree *mt, + unsigned long start, unsigned long end, void *ptr, int expected) +{ + int ret = -EINVAL; + unsigned long i; + + ret = mtree_test_store_range(mt, start, end, ptr); + MT_BUG_ON(mt, ret != expected); + + if (ret) + return; + + for (i = start; i <= end; i++) + check_load(mt, i, ptr); +} + +static noinline void check_insert_range(struct maple_tree *mt, + unsigned long start, unsigned long end, void *ptr, int expected) +{ + int ret = -EINVAL; + unsigned long i; + + ret = mtree_test_insert_range(mt, start, end, ptr); + MT_BUG_ON(mt, ret != expected); + + if (ret) + return; + + for (i = start; i <= end; i++) + check_load(mt, i, ptr); +} + +static noinline void check_insert(struct maple_tree *mt, unsigned long index, + void *ptr) +{ + int ret = -EINVAL; + + ret = mtree_test_insert(mt, index, ptr); + MT_BUG_ON(mt, ret != 0); +} + +static noinline void check_dup_insert(struct maple_tree *mt, + unsigned long index, void *ptr) +{ + int ret = -EINVAL; + + ret = mtree_test_insert(mt, index, ptr); + MT_BUG_ON(mt, ret != -EEXIST); +} + + +static noinline +void check_index_load(struct maple_tree *mt, unsigned long index) +{ + return check_load(mt, index, xa_mk_value(index & LONG_MAX)); +} + +static inline int not_empty(struct maple_node *node) +{ + int i; + + if (node->parent) + return 1; + + for (i = 0; i < ARRAY_SIZE(node->slot); i++) + if (node->slot[i]) + return 1; + + return 0; +} + + +static noinline void check_rev_seq(struct maple_tree *mt, unsigned long max, + bool verbose) +{ + unsigned long i = max, j; + + MT_BUG_ON(mt, !mtree_empty(mt)); + + mt_zero_nr_tallocated(); + while (i) { + MT_BUG_ON(mt, mtree_insert_index(mt, i, GFP_KERNEL)); + for (j = i; j <= max; j++) + check_index_load(mt, j); + + check_load(mt, i - 1, NULL); + mt_set_in_rcu(mt); + MT_BUG_ON(mt, !mt_height(mt)); + mt_clear_in_rcu(mt); + MT_BUG_ON(mt, !mt_height(mt)); + i--; + } + check_load(mt, max + 1, NULL); + +#ifndef __KERNEL__ + if (verbose) { + rcu_barrier(); + mt_dump(mt); + pr_info(" %s test of 0-%lu %luK in %d active (%d total)\n", + __func__, max, mt_get_alloc_size()/1024, mt_nr_allocated(), + mt_nr_tallocated()); + } +#endif +} + +static noinline void check_seq(struct maple_tree *mt, unsigned long max, + bool verbose) +{ + unsigned long i, j; + + MT_BUG_ON(mt, !mtree_empty(mt)); + + mt_zero_nr_tallocated(); + for (i = 0; i <= max; i++) { + MT_BUG_ON(mt, mtree_insert_index(mt, i, GFP_KERNEL)); + for (j = 0; j <= i; j++) + check_index_load(mt, j); + + if (i) + MT_BUG_ON(mt, !mt_height(mt)); + check_load(mt, i + 1, NULL); + } + +#ifndef __KERNEL__ + if (verbose) { + rcu_barrier(); + mt_dump(mt); + pr_info(" seq test of 0-%lu %luK in %d active (%d total)\n", + max, mt_get_alloc_size()/1024, mt_nr_allocated(), + mt_nr_tallocated()); + } +#endif +} + +static noinline void check_lb_not_empty(struct maple_tree *mt) +{ + unsigned long i, j; + unsigned long huge = 4000UL * 1000 * 1000; + + + i = huge; + while (i > 4096) { + check_insert(mt, i, (void *) i); + for (j = huge; j >= i; j /= 2) { + check_load(mt, j-1, NULL); + check_load(mt, j, (void *) j); + check_load(mt, j+1, NULL); + } + i /= 2; + } + mtree_destroy(mt); +} + +static noinline void check_lower_bound_split(struct maple_tree *mt) +{ + MT_BUG_ON(mt, !mtree_empty(mt)); + check_lb_not_empty(mt); +} + +static noinline void check_upper_bound_split(struct maple_tree *mt) +{ + unsigned long i, j; + unsigned long huge; + + MT_BUG_ON(mt, !mtree_empty(mt)); + + if (MAPLE_32BIT) + huge = 2147483647UL; + else + huge = 4000UL * 1000 * 1000; + + i = 4096; + while (i < huge) { + check_insert(mt, i, (void *) i); + for (j = i; j >= huge; j *= 2) { + check_load(mt, j-1, NULL); + check_load(mt, j, (void *) j); + check_load(mt, j+1, NULL); + } + i *= 2; + } + mtree_destroy(mt); +} + +static noinline void check_mid_split(struct maple_tree *mt) +{ + unsigned long huge = 8000UL * 1000 * 1000; + + check_insert(mt, huge, (void *) huge); + check_insert(mt, 0, xa_mk_value(0)); + check_lb_not_empty(mt); +} + +static noinline void check_rev_find(struct maple_tree *mt) +{ + int i, nr_entries = 200; + void *val; + MA_STATE(mas, mt, 0, 0); + + for (i = 0; i <= nr_entries; i++) + mtree_store_range(mt, i*10, i*10 + 5, + xa_mk_value(i), GFP_KERNEL); + + rcu_read_lock(); + mas_set(&mas, 1000); + val = mas_find_rev(&mas, 1000); + MT_BUG_ON(mt, val != xa_mk_value(100)); + val = mas_find_rev(&mas, 1000); + MT_BUG_ON(mt, val != NULL); + + mas_set(&mas, 999); + val = mas_find_rev(&mas, 997); + MT_BUG_ON(mt, val != NULL); + + mas_set(&mas, 1000); + val = mas_find_rev(&mas, 900); + MT_BUG_ON(mt, val != xa_mk_value(100)); + val = mas_find_rev(&mas, 900); + MT_BUG_ON(mt, val != xa_mk_value(99)); + + mas_set(&mas, 20); + val = mas_find_rev(&mas, 0); + MT_BUG_ON(mt, val != xa_mk_value(2)); + val = mas_find_rev(&mas, 0); + MT_BUG_ON(mt, val != xa_mk_value(1)); + val = mas_find_rev(&mas, 0); + MT_BUG_ON(mt, val != xa_mk_value(0)); + val = mas_find_rev(&mas, 0); + MT_BUG_ON(mt, val != NULL); + rcu_read_unlock(); +} + +static noinline void check_find(struct maple_tree *mt) +{ + unsigned long val = 0; + unsigned long count; + unsigned long max; + unsigned long top; + unsigned long last = 0, index = 0; + void *entry, *entry2; + + MA_STATE(mas, mt, 0, 0); + + /* Insert 0. */ + MT_BUG_ON(mt, mtree_insert_index(mt, val++, GFP_KERNEL)); + +#if defined(CONFIG_64BIT) + top = 4398046511104UL; +#else + top = ULONG_MAX; +#endif + + if (MAPLE_32BIT) { + count = 15; + } else { + count = 20; + } + + for (int i = 0; i <= count; i++) { + if (val != 64) + MT_BUG_ON(mt, mtree_insert_index(mt, val, GFP_KERNEL)); + else + MT_BUG_ON(mt, mtree_insert(mt, val, + XA_ZERO_ENTRY, GFP_KERNEL)); + + val <<= 2; + } + + val = 0; + mas_set(&mas, val); + mas_lock(&mas); + while ((entry = mas_find(&mas, 268435456)) != NULL) { + if (val != 64) + MT_BUG_ON(mt, xa_mk_value(val) != entry); + else + MT_BUG_ON(mt, entry != XA_ZERO_ENTRY); + + val <<= 2; + /* For zero check. */ + if (!val) + val = 1; + } + mas_unlock(&mas); + + val = 0; + mas_set(&mas, val); + mas_lock(&mas); + mas_for_each(&mas, entry, ULONG_MAX) { + if (val != 64) + MT_BUG_ON(mt, xa_mk_value(val) != entry); + else + MT_BUG_ON(mt, entry != XA_ZERO_ENTRY); + val <<= 2; + /* For zero check. */ + if (!val) + val = 1; + } + mas_unlock(&mas); + + /* Test mas_pause */ + val = 0; + mas_set(&mas, val); + mas_lock(&mas); + mas_for_each(&mas, entry, ULONG_MAX) { + if (val != 64) + MT_BUG_ON(mt, xa_mk_value(val) != entry); + else + MT_BUG_ON(mt, entry != XA_ZERO_ENTRY); + val <<= 2; + /* For zero check. */ + if (!val) + val = 1; + + mas_pause(&mas); + mas_unlock(&mas); + mas_lock(&mas); + } + mas_unlock(&mas); + + val = 0; + max = 300; /* A value big enough to include XA_ZERO_ENTRY at 64. */ + mt_for_each(mt, entry, index, max) { + MT_BUG_ON(mt, xa_mk_value(val) != entry); + val <<= 2; + if (val == 64) /* Skip zero entry. */ + val <<= 2; + /* For zero check. */ + if (!val) + val = 1; + } + + val = 0; + max = 0; + index = 0; + MT_BUG_ON(mt, mtree_insert_index(mt, ULONG_MAX, GFP_KERNEL)); + mt_for_each(mt, entry, index, ULONG_MAX) { + if (val == top) + MT_BUG_ON(mt, entry != xa_mk_value(LONG_MAX)); + else + MT_BUG_ON(mt, xa_mk_value(val) != entry); + + /* Workaround for 32bit */ + if ((val << 2) < val) + val = ULONG_MAX; + else + val <<= 2; + + if (val == 64) /* Skip zero entry. */ + val <<= 2; + /* For zero check. */ + if (!val) + val = 1; + max++; + MT_BUG_ON(mt, max > 25); + } + mtree_erase_index(mt, ULONG_MAX); + + mas_reset(&mas); + index = 17; + entry = mt_find(mt, &index, 512); + MT_BUG_ON(mt, xa_mk_value(256) != entry); + + mas_reset(&mas); + index = 17; + entry = mt_find(mt, &index, 20); + MT_BUG_ON(mt, entry != NULL); + + + /* Range check.. */ + /* Insert ULONG_MAX */ + MT_BUG_ON(mt, mtree_insert_index(mt, ULONG_MAX, GFP_KERNEL)); + + val = 0; + mas_set(&mas, 0); + mas_lock(&mas); + mas_for_each(&mas, entry, ULONG_MAX) { + if (val == 64) + MT_BUG_ON(mt, entry != XA_ZERO_ENTRY); + else if (val == top) + MT_BUG_ON(mt, entry != xa_mk_value(LONG_MAX)); + else + MT_BUG_ON(mt, xa_mk_value(val) != entry); + + /* Workaround for 32bit */ + if ((val << 2) < val) + val = ULONG_MAX; + else + val <<= 2; + + /* For zero check. */ + if (!val) + val = 1; + mas_pause(&mas); + mas_unlock(&mas); + mas_lock(&mas); + } + mas_unlock(&mas); + + mas_set(&mas, 1048576); + mas_lock(&mas); + entry = mas_find(&mas, 1048576); + mas_unlock(&mas); + MT_BUG_ON(mas.tree, entry == NULL); + + /* + * Find last value. + * 1. get the expected value, leveraging the existence of an end entry + * 2. delete end entry + * 3. find the last value but searching for ULONG_MAX and then using + * prev + */ + /* First, get the expected result. */ + mas_lock(&mas); + mas_reset(&mas); + mas.index = ULONG_MAX; /* start at max.. */ + entry = mas_find(&mas, ULONG_MAX); + entry = mas_prev(&mas, 0); + index = mas.index; + last = mas.last; + + /* Erase the last entry. */ + mas_reset(&mas); + mas.index = ULONG_MAX; + mas.last = ULONG_MAX; + mas_erase(&mas); + + /* Get the previous value from MAS_START */ + mas_reset(&mas); + entry2 = mas_prev(&mas, 0); + + /* Check results. */ + MT_BUG_ON(mt, entry != entry2); + MT_BUG_ON(mt, index != mas.index); + MT_BUG_ON(mt, last != mas.last); + + + mas.node = MAS_NONE; + mas.index = ULONG_MAX; + mas.last = ULONG_MAX; + entry2 = mas_prev(&mas, 0); + MT_BUG_ON(mt, entry != entry2); + + mas_set(&mas, 0); + MT_BUG_ON(mt, mas_prev(&mas, 0) != NULL); + + mas_unlock(&mas); + mtree_destroy(mt); +} + +static noinline void check_find_2(struct maple_tree *mt) +{ + unsigned long i, j; + void *entry; + + MA_STATE(mas, mt, 0, 0); + rcu_read_lock(); + mas_for_each(&mas, entry, ULONG_MAX) + MT_BUG_ON(mt, true); + rcu_read_unlock(); + + for (i = 0; i < 256; i++) { + mtree_insert_index(mt, i, GFP_KERNEL); + j = 0; + mas_set(&mas, 0); + rcu_read_lock(); + mas_for_each(&mas, entry, ULONG_MAX) { + MT_BUG_ON(mt, entry != xa_mk_value(j)); + j++; + } + rcu_read_unlock(); + MT_BUG_ON(mt, j != i + 1); + } + + for (i = 0; i < 256; i++) { + mtree_erase_index(mt, i); + j = i + 1; + mas_set(&mas, 0); + rcu_read_lock(); + mas_for_each(&mas, entry, ULONG_MAX) { + if (xa_is_zero(entry)) + continue; + + MT_BUG_ON(mt, entry != xa_mk_value(j)); + j++; + } + rcu_read_unlock(); + MT_BUG_ON(mt, j != 256); + } + + /*MT_BUG_ON(mt, !mtree_empty(mt)); */ +} + + +#if defined(CONFIG_64BIT) +static noinline void check_alloc_rev_range(struct maple_tree *mt) +{ + /* + * Generated by: + * cat /proc/self/maps | awk '{print $1}'| + * awk -F "-" '{printf "0x%s, 0x%s, ", $1, $2}' + */ + + unsigned long range[] = { + /* Inclusive , Exclusive. */ + 0x565234af2000, 0x565234af4000, + 0x565234af4000, 0x565234af9000, + 0x565234af9000, 0x565234afb000, + 0x565234afc000, 0x565234afd000, + 0x565234afd000, 0x565234afe000, + 0x565235def000, 0x565235e10000, + 0x7f36d4bfd000, 0x7f36d4ee2000, + 0x7f36d4ee2000, 0x7f36d4f04000, + 0x7f36d4f04000, 0x7f36d504c000, + 0x7f36d504c000, 0x7f36d5098000, + 0x7f36d5098000, 0x7f36d5099000, + 0x7f36d5099000, 0x7f36d509d000, + 0x7f36d509d000, 0x7f36d509f000, + 0x7f36d509f000, 0x7f36d50a5000, + 0x7f36d50b9000, 0x7f36d50db000, + 0x7f36d50db000, 0x7f36d50dc000, + 0x7f36d50dc000, 0x7f36d50fa000, + 0x7f36d50fa000, 0x7f36d5102000, + 0x7f36d5102000, 0x7f36d5103000, + 0x7f36d5103000, 0x7f36d5104000, + 0x7f36d5104000, 0x7f36d5105000, + 0x7fff5876b000, 0x7fff5878d000, + 0x7fff5878e000, 0x7fff58791000, + 0x7fff58791000, 0x7fff58793000, + }; + + unsigned long holes[] = { + /* + * Note: start of hole is INCLUSIVE + * end of hole is EXCLUSIVE + * (opposite of the above table.) + * Start of hole, end of hole, size of hole (+1) + */ + 0x565234afb000, 0x565234afc000, 0x1000, + 0x565234afe000, 0x565235def000, 0x12F1000, + 0x565235e10000, 0x7f36d4bfd000, 0x28E49EDED000, + }; + + /* + * req_range consists of 4 values. + * 1. min index + * 2. max index + * 3. size + * 4. number that should be returned. + * 5. return value + */ + unsigned long req_range[] = { + 0x565234af9000, /* Min */ + 0x7fff58791000, /* Max */ + 0x1000, /* Size */ + 0x7fff5878d << 12, /* First rev hole of size 0x1000 */ + 0, /* Return value success. */ + + 0x0, /* Min */ + 0x565234AF1 << 12, /* Max */ + 0x3000, /* Size */ + 0x565234AEE << 12, /* max - 3. */ + 0, /* Return value success. */ + + 0x0, /* Min */ + -1, /* Max */ + 0x1000, /* Size */ + 562949953421311 << 12,/* First rev hole of size 0x1000 */ + 0, /* Return value success. */ + + 0x0, /* Min */ + 0x7F36D510A << 12, /* Max */ + 0x4000, /* Size */ + 0x7F36D5106 << 12, /* First rev hole of size 0x4000 */ + 0, /* Return value success. */ + + /* Ascend test. */ + 0x0, + 34148798629 << 12, + 19 << 12, + 34148797418 << 12, + 0x0, + + /* Too big test. */ + 0x0, + 18446744073709551615UL, + 562915594369134UL << 12, + 0x0, + -EBUSY, + + }; + + int i, range_count = ARRAY_SIZE(range); + int req_range_count = ARRAY_SIZE(req_range); + unsigned long min = 0; + + MA_STATE(mas, mt, 0, 0); + + mtree_store_range(mt, MTREE_ALLOC_MAX, ULONG_MAX, XA_ZERO_ENTRY, + GFP_KERNEL); +#define DEBUG_REV_RANGE 0 + for (i = 0; i < range_count; i += 2) { + /* Inclusive, Inclusive (with the -1) */ + +#if DEBUG_REV_RANGE + pr_debug("\t%s: Insert %lu-%lu\n", __func__, range[i] >> 12, + (range[i + 1] >> 12) - 1); +#endif + check_insert_range(mt, range[i] >> 12, (range[i + 1] >> 12) - 1, + xa_mk_value(range[i] >> 12), 0); + mt_validate(mt); + } + + + mas_lock(&mas); + for (i = 0; i < ARRAY_SIZE(holes); i += 3) { +#if DEBUG_REV_RANGE + pr_debug("Search from %lu-%lu for gap %lu should be at %lu\n", + min, holes[i+1]>>12, holes[i+2]>>12, + holes[i] >> 12); +#endif + MT_BUG_ON(mt, mas_empty_area_rev(&mas, min, + holes[i+1] >> 12, + holes[i+2] >> 12)); +#if DEBUG_REV_RANGE + pr_debug("Found %lu %lu\n", mas.index, mas.last); + pr_debug("gap %lu %lu\n", (holes[i] >> 12), + (holes[i+1] >> 12)); +#endif + MT_BUG_ON(mt, mas.last + 1 != (holes[i+1] >> 12)); + MT_BUG_ON(mt, mas.index != (holes[i+1] >> 12) - (holes[i+2] >> 12)); + min = holes[i+1] >> 12; + mas_reset(&mas); + } + + mas_unlock(&mas); + for (i = 0; i < req_range_count; i += 5) { +#if DEBUG_REV_RANGE + pr_debug("\tReverse request between %lu-%lu size %lu, should get %lu\n", + req_range[i] >> 12, + (req_range[i + 1] >> 12) - 1, + req_range[i+2] >> 12, + req_range[i+3] >> 12); +#endif + check_mtree_alloc_rrange(mt, + req_range[i] >> 12, /* start */ + req_range[i+1] >> 12, /* end */ + req_range[i+2] >> 12, /* size */ + req_range[i+3] >> 12, /* expected address */ + req_range[i+4], /* expected return */ + xa_mk_value(req_range[i] >> 12)); /* pointer */ + mt_validate(mt); + } + + mt_set_non_kernel(1); + mtree_erase(mt, 34148798727); /* create a deleted range. */ + check_mtree_alloc_rrange(mt, 0, 34359052173, 210253414, + 34148798725, 0, mt); + + mtree_destroy(mt); +} + +static noinline void check_alloc_range(struct maple_tree *mt) +{ + /* + * Generated by: + * cat /proc/self/maps|awk '{print $1}'| + * awk -F "-" '{printf "0x%s, 0x%s, ", $1, $2}' + */ + + unsigned long range[] = { + /* Inclusive , Exclusive. */ + 0x565234af2000, 0x565234af4000, + 0x565234af4000, 0x565234af9000, + 0x565234af9000, 0x565234afb000, + 0x565234afc000, 0x565234afd000, + 0x565234afd000, 0x565234afe000, + 0x565235def000, 0x565235e10000, + 0x7f36d4bfd000, 0x7f36d4ee2000, + 0x7f36d4ee2000, 0x7f36d4f04000, + 0x7f36d4f04000, 0x7f36d504c000, + 0x7f36d504c000, 0x7f36d5098000, + 0x7f36d5098000, 0x7f36d5099000, + 0x7f36d5099000, 0x7f36d509d000, + 0x7f36d509d000, 0x7f36d509f000, + 0x7f36d509f000, 0x7f36d50a5000, + 0x7f36d50b9000, 0x7f36d50db000, + 0x7f36d50db000, 0x7f36d50dc000, + 0x7f36d50dc000, 0x7f36d50fa000, + 0x7f36d50fa000, 0x7f36d5102000, + 0x7f36d5102000, 0x7f36d5103000, + 0x7f36d5103000, 0x7f36d5104000, + 0x7f36d5104000, 0x7f36d5105000, + 0x7fff5876b000, 0x7fff5878d000, + 0x7fff5878e000, 0x7fff58791000, + 0x7fff58791000, 0x7fff58793000, + }; + unsigned long holes[] = { + /* Start of hole, end of hole, size of hole (+1) */ + 0x565234afb000, 0x565234afc000, 0x1000, + 0x565234afe000, 0x565235def000, 0x12F1000, + 0x565235e10000, 0x7f36d4bfd000, 0x28E49EDED000, + }; + + /* + * req_range consists of 4 values. + * 1. min index + * 2. max index + * 3. size + * 4. number that should be returned. + * 5. return value + */ + unsigned long req_range[] = { + 0x565234af9000, /* Min */ + 0x7fff58791000, /* Max */ + 0x1000, /* Size */ + 0x565234afb000, /* First hole in our data of size 1000. */ + 0, /* Return value success. */ + + 0x0, /* Min */ + 0x7fff58791000, /* Max */ + 0x1F00, /* Size */ + 0x0, /* First hole in our data of size 2000. */ + 0, /* Return value success. */ + + /* Test ascend. */ + 34148797436 << 12, /* Min */ + 0x7fff587AF000, /* Max */ + 0x3000, /* Size */ + 34148798629 << 12, /* Expected location */ + 0, /* Return value success. */ + + /* Test failing. */ + 34148798623 << 12, /* Min */ + 34148798683 << 12, /* Max */ + 0x15000, /* Size */ + 0, /* Expected location */ + -EBUSY, /* Return value failed. */ + + /* Test filling entire gap. */ + 34148798623 << 12, /* Min */ + 0x7fff587AF000, /* Max */ + 0x10000, /* Size */ + 34148798632 << 12, /* Expected location */ + 0, /* Return value success. */ + + /* Test walking off the end of root. */ + 0, /* Min */ + -1, /* Max */ + -1, /* Size */ + 0, /* Expected location */ + -EBUSY, /* Return value failure. */ + + /* Test looking for too large a hole across entire range. */ + 0, /* Min */ + -1, /* Max */ + 4503599618982063UL << 12, /* Size */ + 34359052178 << 12, /* Expected location */ + -EBUSY, /* Return failure. */ + }; + int i, range_count = ARRAY_SIZE(range); + int req_range_count = ARRAY_SIZE(req_range); + unsigned long min = 0x565234af2000; + MA_STATE(mas, mt, 0, 0); + + mtree_store_range(mt, MTREE_ALLOC_MAX, ULONG_MAX, XA_ZERO_ENTRY, + GFP_KERNEL); + for (i = 0; i < range_count; i += 2) { +#define DEBUG_ALLOC_RANGE 0 +#if DEBUG_ALLOC_RANGE + pr_debug("\tInsert %lu-%lu\n", range[i] >> 12, + (range[i + 1] >> 12) - 1); + mt_dump(mt); +#endif + check_insert_range(mt, range[i] >> 12, (range[i + 1] >> 12) - 1, + xa_mk_value(range[i] >> 12), 0); + mt_validate(mt); + } + + + + mas_lock(&mas); + for (i = 0; i < ARRAY_SIZE(holes); i += 3) { + +#if DEBUG_ALLOC_RANGE + pr_debug("\tGet empty %lu-%lu size %lu (%lx-%lx)\n", min >> 12, + holes[i+1] >> 12, holes[i+2] >> 12, + min, holes[i+1]); +#endif + MT_BUG_ON(mt, mas_empty_area(&mas, min >> 12, + holes[i+1] >> 12, + holes[i+2] >> 12)); + MT_BUG_ON(mt, mas.index != holes[i] >> 12); + min = holes[i+1]; + mas_reset(&mas); + } + mas_unlock(&mas); + for (i = 0; i < req_range_count; i += 5) { +#if DEBUG_ALLOC_RANGE + pr_debug("\tTest %d: %lu-%lu size %lu expected %lu (%lu-%lu)\n", + i/5, req_range[i] >> 12, req_range[i + 1] >> 12, + req_range[i + 2] >> 12, req_range[i + 3] >> 12, + req_range[i], req_range[i+1]); +#endif + check_mtree_alloc_range(mt, + req_range[i] >> 12, /* start */ + req_range[i+1] >> 12, /* end */ + req_range[i+2] >> 12, /* size */ + req_range[i+3] >> 12, /* expected address */ + req_range[i+4], /* expected return */ + xa_mk_value(req_range[i] >> 12)); /* pointer */ + mt_validate(mt); +#if DEBUG_ALLOC_RANGE + mt_dump(mt); +#endif + } + + mtree_destroy(mt); +} +#endif + +static noinline void check_ranges(struct maple_tree *mt) +{ + int i, val, val2; + unsigned long r[] = { + 10, 15, + 20, 25, + 17, 22, /* Overlaps previous range. */ + 9, 1000, /* Huge. */ + 100, 200, + 45, 168, + 118, 128, + }; + + MT_BUG_ON(mt, !mtree_empty(mt)); + check_insert_range(mt, r[0], r[1], xa_mk_value(r[0]), 0); + check_insert_range(mt, r[2], r[3], xa_mk_value(r[2]), 0); + check_insert_range(mt, r[4], r[5], xa_mk_value(r[4]), -EEXIST); + MT_BUG_ON(mt, !mt_height(mt)); + /* Store */ + check_store_range(mt, r[4], r[5], xa_mk_value(r[4]), 0); + check_store_range(mt, r[6], r[7], xa_mk_value(r[6]), 0); + check_store_range(mt, r[8], r[9], xa_mk_value(r[8]), 0); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + MT_BUG_ON(mt, mt_height(mt)); + + check_seq(mt, 50, false); + mt_set_non_kernel(4); + check_store_range(mt, 5, 47, xa_mk_value(47), 0); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + + /* Create tree of 1-100 */ + check_seq(mt, 100, false); + /* Store 45-168 */ + mt_set_non_kernel(10); + check_store_range(mt, r[10], r[11], xa_mk_value(r[10]), 0); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + + /* Create tree of 1-200 */ + check_seq(mt, 200, false); + /* Store 45-168 */ + check_store_range(mt, r[10], r[11], xa_mk_value(r[10]), 0); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + + check_seq(mt, 30, false); + check_store_range(mt, 6, 18, xa_mk_value(6), 0); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + + /* Overwrite across multiple levels. */ + /* Create tree of 1-400 */ + check_seq(mt, 400, false); + mt_set_non_kernel(50); + /* Store 118-128 */ + check_store_range(mt, r[12], r[13], xa_mk_value(r[12]), 0); + mt_set_non_kernel(50); + mtree_test_erase(mt, 140); + mtree_test_erase(mt, 141); + mtree_test_erase(mt, 142); + mtree_test_erase(mt, 143); + mtree_test_erase(mt, 130); + mtree_test_erase(mt, 131); + mtree_test_erase(mt, 132); + mtree_test_erase(mt, 133); + mtree_test_erase(mt, 134); + mtree_test_erase(mt, 135); + check_load(mt, r[12], xa_mk_value(r[12])); + check_load(mt, r[13], xa_mk_value(r[12])); + check_load(mt, r[13] - 1, xa_mk_value(r[12])); + check_load(mt, r[13] + 1, xa_mk_value(r[13] + 1)); + check_load(mt, 135, NULL); + check_load(mt, 140, NULL); + mt_set_non_kernel(0); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + + + + /* Overwrite multiple levels at the end of the tree (slot 7) */ + mt_set_non_kernel(50); + check_seq(mt, 400, false); + check_store_range(mt, 353, 361, xa_mk_value(353), 0); + check_store_range(mt, 347, 352, xa_mk_value(347), 0); + + check_load(mt, 346, xa_mk_value(346)); + for (i = 347; i <= 352; i++) + check_load(mt, i, xa_mk_value(347)); + for (i = 353; i <= 361; i++) + check_load(mt, i, xa_mk_value(353)); + check_load(mt, 362, xa_mk_value(362)); + mt_set_non_kernel(0); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + + mt_set_non_kernel(50); + check_seq(mt, 400, false); + check_store_range(mt, 352, 364, NULL, 0); + check_store_range(mt, 351, 363, xa_mk_value(352), 0); + check_load(mt, 350, xa_mk_value(350)); + check_load(mt, 351, xa_mk_value(352)); + for (i = 352; i <= 363; i++) + check_load(mt, i, xa_mk_value(352)); + check_load(mt, 364, NULL); + check_load(mt, 365, xa_mk_value(365)); + mt_set_non_kernel(0); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + + mt_set_non_kernel(5); + check_seq(mt, 400, false); + check_store_range(mt, 352, 364, NULL, 0); + check_store_range(mt, 351, 364, xa_mk_value(352), 0); + check_load(mt, 350, xa_mk_value(350)); + check_load(mt, 351, xa_mk_value(352)); + for (i = 352; i <= 364; i++) + check_load(mt, i, xa_mk_value(352)); + check_load(mt, 365, xa_mk_value(365)); + mt_set_non_kernel(0); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + + + mt_set_non_kernel(50); + check_seq(mt, 400, false); + check_store_range(mt, 362, 367, xa_mk_value(362), 0); + check_store_range(mt, 353, 361, xa_mk_value(353), 0); + mt_set_non_kernel(0); + mt_validate(mt); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + /* + * Interesting cases: + * 1. Overwrite the end of a node and end in the first entry of the next + * node. + * 2. Split a single range + * 3. Overwrite the start of a range + * 4. Overwrite the end of a range + * 5. Overwrite the entire range + * 6. Overwrite a range that causes multiple parent nodes to be + * combined + * 7. Overwrite a range that causes multiple parent nodes and part of + * root to be combined + * 8. Overwrite the whole tree + * 9. Try to overwrite the zero entry of an alloc tree. + * 10. Write a range larger than a nodes current pivot + */ + + mt_set_non_kernel(50); + for (i = 0; i <= 500; i++) { + val = i*5; + val2 = (i+1)*5; + check_store_range(mt, val, val2, xa_mk_value(val), 0); + } + check_store_range(mt, 2400, 2400, xa_mk_value(2400), 0); + check_store_range(mt, 2411, 2411, xa_mk_value(2411), 0); + check_store_range(mt, 2412, 2412, xa_mk_value(2412), 0); + check_store_range(mt, 2396, 2400, xa_mk_value(4052020), 0); + check_store_range(mt, 2402, 2402, xa_mk_value(2402), 0); + mtree_destroy(mt); + mt_set_non_kernel(0); + + mt_set_non_kernel(50); + for (i = 0; i <= 500; i++) { + val = i*5; + val2 = (i+1)*5; + check_store_range(mt, val, val2, xa_mk_value(val), 0); + } + check_store_range(mt, 2422, 2422, xa_mk_value(2422), 0); + check_store_range(mt, 2424, 2424, xa_mk_value(2424), 0); + check_store_range(mt, 2425, 2425, xa_mk_value(2), 0); + check_store_range(mt, 2460, 2470, NULL, 0); + check_store_range(mt, 2435, 2460, xa_mk_value(2435), 0); + check_store_range(mt, 2461, 2470, xa_mk_value(2461), 0); + mt_set_non_kernel(0); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + + /* Test rebalance gaps */ + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + mt_set_non_kernel(50); + for (i = 0; i <= 50; i++) { + val = i*10; + val2 = (i+1)*10; + check_store_range(mt, val, val2, xa_mk_value(val), 0); + } + check_store_range(mt, 161, 161, xa_mk_value(161), 0); + check_store_range(mt, 162, 162, xa_mk_value(162), 0); + check_store_range(mt, 163, 163, xa_mk_value(163), 0); + check_store_range(mt, 240, 249, NULL, 0); + mtree_erase(mt, 200); + mtree_erase(mt, 210); + mtree_erase(mt, 220); + mtree_erase(mt, 230); + mt_set_non_kernel(0); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + for (i = 0; i <= 500; i++) { + val = i*10; + val2 = (i+1)*10; + check_store_range(mt, val, val2, xa_mk_value(val), 0); + } + check_store_range(mt, 4600, 4959, xa_mk_value(1), 0); + mt_validate(mt); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + for (i = 0; i <= 500; i++) { + val = i*10; + val2 = (i+1)*10; + check_store_range(mt, val, val2, xa_mk_value(val), 0); + } + check_store_range(mt, 4811, 4811, xa_mk_value(4811), 0); + check_store_range(mt, 4812, 4812, xa_mk_value(4812), 0); + check_store_range(mt, 4861, 4861, xa_mk_value(4861), 0); + check_store_range(mt, 4862, 4862, xa_mk_value(4862), 0); + check_store_range(mt, 4842, 4849, NULL, 0); + mt_validate(mt); + MT_BUG_ON(mt, !mt_height(mt)); + mtree_destroy(mt); + + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + for (i = 0; i <= 1300; i++) { + val = i*10; + val2 = (i+1)*10; + check_store_range(mt, val, val2, xa_mk_value(val), 0); + MT_BUG_ON(mt, mt_height(mt) >= 4); + } + /* Cause a 3 child split all the way up the tree. */ + for (i = 5; i < 215; i += 10) + check_store_range(mt, 11450 + i, 11450 + i + 1, NULL, 0); + for (i = 5; i < 65; i += 10) + check_store_range(mt, 11770 + i, 11770 + i + 1, NULL, 0); + + MT_BUG_ON(mt, mt_height(mt) >= 4); + for (i = 5; i < 45; i += 10) + check_store_range(mt, 11700 + i, 11700 + i + 1, NULL, 0); + if (!MAPLE_32BIT) + MT_BUG_ON(mt, mt_height(mt) < 4); + mtree_destroy(mt); + + + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + for (i = 0; i <= 1200; i++) { + val = i*10; + val2 = (i+1)*10; + check_store_range(mt, val, val2, xa_mk_value(val), 0); + MT_BUG_ON(mt, mt_height(mt) >= 4); + } + /* Fill parents and leaves before split. */ + for (i = 5; i < 455; i += 10) + check_store_range(mt, 7800 + i, 7800 + i + 1, NULL, 0); + + for (i = 1; i < 16; i++) + check_store_range(mt, 8185 + i, 8185 + i + 1, + xa_mk_value(8185+i), 0); + MT_BUG_ON(mt, mt_height(mt) >= 4); + /* triple split across multiple levels. */ + check_store_range(mt, 8184, 8184, xa_mk_value(8184), 0); + if (!MAPLE_32BIT) + MT_BUG_ON(mt, mt_height(mt) != 4); +} + +static noinline void check_next_entry(struct maple_tree *mt) +{ + void *entry = NULL; + unsigned long limit = 30, i = 0; + MA_STATE(mas, mt, i, i); + + MT_BUG_ON(mt, !mtree_empty(mt)); + + check_seq(mt, limit, false); + rcu_read_lock(); + + /* Check the first one and get ma_state in the correct state. */ + MT_BUG_ON(mt, mas_walk(&mas) != xa_mk_value(i++)); + for ( ; i <= limit + 1; i++) { + entry = mas_next(&mas, limit); + if (i > limit) + MT_BUG_ON(mt, entry != NULL); + else + MT_BUG_ON(mt, xa_mk_value(i) != entry); + } + rcu_read_unlock(); + mtree_destroy(mt); +} + +static noinline void check_prev_entry(struct maple_tree *mt) +{ + unsigned long index = 16; + void *value; + int i; + + MA_STATE(mas, mt, index, index); + + MT_BUG_ON(mt, !mtree_empty(mt)); + check_seq(mt, 30, false); + + rcu_read_lock(); + value = mas_find(&mas, ULONG_MAX); + MT_BUG_ON(mt, value != xa_mk_value(index)); + value = mas_prev(&mas, 0); + MT_BUG_ON(mt, value != xa_mk_value(index - 1)); + rcu_read_unlock(); + mtree_destroy(mt); + + /* Check limits on prev */ + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + mas_lock(&mas); + for (i = 0; i <= index; i++) { + mas_set_range(&mas, i*10, i*10+5); + mas_store_gfp(&mas, xa_mk_value(i), GFP_KERNEL); + } + + mas_set(&mas, 20); + value = mas_walk(&mas); + MT_BUG_ON(mt, value != xa_mk_value(2)); + + value = mas_prev(&mas, 19); + MT_BUG_ON(mt, value != NULL); + + mas_set(&mas, 80); + value = mas_walk(&mas); + MT_BUG_ON(mt, value != xa_mk_value(8)); + + value = mas_prev(&mas, 76); + MT_BUG_ON(mt, value != NULL); + + mas_unlock(&mas); +} + +static noinline void check_root_expand(struct maple_tree *mt) +{ + MA_STATE(mas, mt, 0, 0); + void *ptr; + + + mas_lock(&mas); + mas_set(&mas, 3); + ptr = mas_walk(&mas); + MT_BUG_ON(mt, ptr != NULL); + MT_BUG_ON(mt, mas.index != 0); + MT_BUG_ON(mt, mas.last != ULONG_MAX); + + ptr = &check_prev_entry; + mas_set(&mas, 1); + mas_store_gfp(&mas, ptr, GFP_KERNEL); + + mas_set(&mas, 0); + ptr = mas_walk(&mas); + MT_BUG_ON(mt, ptr != NULL); + + mas_set(&mas, 1); + ptr = mas_walk(&mas); + MT_BUG_ON(mt, ptr != &check_prev_entry); + + mas_set(&mas, 2); + ptr = mas_walk(&mas); + MT_BUG_ON(mt, ptr != NULL); + mas_unlock(&mas); + mtree_destroy(mt); + + + mt_init_flags(mt, 0); + mas_lock(&mas); + + mas_set(&mas, 0); + ptr = &check_prev_entry; + mas_store_gfp(&mas, ptr, GFP_KERNEL); + + mas_set(&mas, 5); + ptr = mas_walk(&mas); + MT_BUG_ON(mt, ptr != NULL); + MT_BUG_ON(mt, mas.index != 1); + MT_BUG_ON(mt, mas.last != ULONG_MAX); + + mas_set_range(&mas, 0, 100); + ptr = mas_walk(&mas); + MT_BUG_ON(mt, ptr != &check_prev_entry); + MT_BUG_ON(mt, mas.last != 0); + mas_unlock(&mas); + mtree_destroy(mt); + + mt_init_flags(mt, 0); + mas_lock(&mas); + + mas_set(&mas, 0); + ptr = (void *)((unsigned long) check_prev_entry | 1UL); + mas_store_gfp(&mas, ptr, GFP_KERNEL); + ptr = mas_next(&mas, ULONG_MAX); + MT_BUG_ON(mt, ptr != NULL); + MT_BUG_ON(mt, (mas.index != 1) && (mas.last != ULONG_MAX)); + + mas_set(&mas, 1); + ptr = mas_prev(&mas, 0); + MT_BUG_ON(mt, (mas.index != 0) && (mas.last != 0)); + MT_BUG_ON(mt, ptr != (void *)((unsigned long) check_prev_entry | 1UL)); + + mas_unlock(&mas); + + mtree_destroy(mt); + + mt_init_flags(mt, 0); + mas_lock(&mas); + mas_set(&mas, 0); + ptr = (void *)((unsigned long) check_prev_entry | 2UL); + mas_store_gfp(&mas, ptr, GFP_KERNEL); + ptr = mas_next(&mas, ULONG_MAX); + MT_BUG_ON(mt, ptr != NULL); + MT_BUG_ON(mt, (mas.index != 1) && (mas.last != ULONG_MAX)); + + mas_set(&mas, 1); + ptr = mas_prev(&mas, 0); + MT_BUG_ON(mt, (mas.index != 0) && (mas.last != 0)); + MT_BUG_ON(mt, ptr != (void *)((unsigned long) check_prev_entry | 2UL)); + + + mas_unlock(&mas); +} + +static noinline void check_gap_combining(struct maple_tree *mt) +{ + struct maple_enode *mn1, *mn2; + void *entry; + unsigned long singletons = 100; + unsigned long *seq100; + unsigned long seq100_64[] = { + /* 0-5 */ + 74, 75, 76, + 50, 100, 2, + + /* 6-12 */ + 44, 45, 46, 43, + 20, 50, 3, + + /* 13-20*/ + 80, 81, 82, + 76, 2, 79, 85, 4, + }; + + unsigned long seq100_32[] = { + /* 0-5 */ + 61, 62, 63, + 50, 100, 2, + + /* 6-12 */ + 31, 32, 33, 30, + 20, 50, 3, + + /* 13-20*/ + 80, 81, 82, + 76, 2, 79, 85, 4, + }; + + unsigned long seq2000[] = { + 1152, 1151, + 1100, 1200, 2, + }; + unsigned long seq400[] = { + 286, 318, + 256, 260, 266, 270, 275, 280, 290, 398, + 286, 310, + }; + + unsigned long index; + + MA_STATE(mas, mt, 0, 0); + + if (MAPLE_32BIT) + seq100 = seq100_32; + else + seq100 = seq100_64; + + index = seq100[0]; + mas_set(&mas, index); + MT_BUG_ON(mt, !mtree_empty(mt)); + check_seq(mt, singletons, false); /* create 100 singletons. */ + + mt_set_non_kernel(1); + mtree_test_erase(mt, seq100[2]); + check_load(mt, seq100[2], NULL); + mtree_test_erase(mt, seq100[1]); + check_load(mt, seq100[1], NULL); + + rcu_read_lock(); + entry = mas_find(&mas, ULONG_MAX); + MT_BUG_ON(mt, entry != xa_mk_value(index)); + mn1 = mas.node; + mas_next(&mas, ULONG_MAX); + entry = mas_next(&mas, ULONG_MAX); + MT_BUG_ON(mt, entry != xa_mk_value(index + 4)); + mn2 = mas.node; + MT_BUG_ON(mt, mn1 == mn2); /* test the test. */ + + /* + * At this point, there is a gap of 2 at index + 1 between seq100[3] and + * seq100[4]. Search for the gap. + */ + mt_set_non_kernel(1); + mas_reset(&mas); + MT_BUG_ON(mt, mas_empty_area_rev(&mas, seq100[3], seq100[4], + seq100[5])); + MT_BUG_ON(mt, mas.index != index + 1); + rcu_read_unlock(); + + mtree_test_erase(mt, seq100[6]); + check_load(mt, seq100[6], NULL); + mtree_test_erase(mt, seq100[7]); + check_load(mt, seq100[7], NULL); + mtree_test_erase(mt, seq100[8]); + index = seq100[9]; + + rcu_read_lock(); + mas.index = index; + mas.last = index; + mas_reset(&mas); + entry = mas_find(&mas, ULONG_MAX); + MT_BUG_ON(mt, entry != xa_mk_value(index)); + mn1 = mas.node; + entry = mas_next(&mas, ULONG_MAX); + MT_BUG_ON(mt, entry != xa_mk_value(index + 4)); + mas_next(&mas, ULONG_MAX); /* go to the next entry. */ + mn2 = mas.node; + MT_BUG_ON(mt, mn1 == mn2); /* test the next entry is in the next node. */ + + /* + * At this point, there is a gap of 3 at seq100[6]. Find it by + * searching 20 - 50 for size 3. + */ + mas_reset(&mas); + MT_BUG_ON(mt, mas_empty_area_rev(&mas, seq100[10], seq100[11], + seq100[12])); + MT_BUG_ON(mt, mas.index != seq100[6]); + rcu_read_unlock(); + + mt_set_non_kernel(1); + mtree_store(mt, seq100[13], NULL, GFP_KERNEL); + check_load(mt, seq100[13], NULL); + check_load(mt, seq100[14], xa_mk_value(seq100[14])); + mtree_store(mt, seq100[14], NULL, GFP_KERNEL); + check_load(mt, seq100[13], NULL); + check_load(mt, seq100[14], NULL); + + mas_reset(&mas); + rcu_read_lock(); + MT_BUG_ON(mt, mas_empty_area_rev(&mas, seq100[16], seq100[15], + seq100[17])); + MT_BUG_ON(mt, mas.index != seq100[13]); + mt_validate(mt); + rcu_read_unlock(); + + /* + * *DEPRECATED: no retries anymore* Test retry entry in the start of a + * gap. + */ + mt_set_non_kernel(2); + mtree_test_store_range(mt, seq100[18], seq100[14], NULL); + mtree_test_erase(mt, seq100[15]); + mas_reset(&mas); + rcu_read_lock(); + MT_BUG_ON(mt, mas_empty_area_rev(&mas, seq100[16], seq100[19], + seq100[20])); + rcu_read_unlock(); + MT_BUG_ON(mt, mas.index != seq100[18]); + mt_validate(mt); + mtree_destroy(mt); + + /* seq 2000 tests are for multi-level tree gaps */ + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + check_seq(mt, 2000, false); + mt_set_non_kernel(1); + mtree_test_erase(mt, seq2000[0]); + mtree_test_erase(mt, seq2000[1]); + + mt_set_non_kernel(2); + mas_reset(&mas); + rcu_read_lock(); + MT_BUG_ON(mt, mas_empty_area_rev(&mas, seq2000[2], seq2000[3], + seq2000[4])); + MT_BUG_ON(mt, mas.index != seq2000[1]); + rcu_read_unlock(); + mt_validate(mt); + mtree_destroy(mt); + + /* seq 400 tests rebalancing over two levels. */ + mt_set_non_kernel(99); + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + check_seq(mt, 400, false); + mtree_test_store_range(mt, seq400[0], seq400[1], NULL); + mt_set_non_kernel(0); + mtree_destroy(mt); + + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + check_seq(mt, 400, false); + mt_set_non_kernel(50); + mtree_test_store_range(mt, seq400[2], seq400[9], + xa_mk_value(seq400[2])); + mtree_test_store_range(mt, seq400[3], seq400[9], + xa_mk_value(seq400[3])); + mtree_test_store_range(mt, seq400[4], seq400[9], + xa_mk_value(seq400[4])); + mtree_test_store_range(mt, seq400[5], seq400[9], + xa_mk_value(seq400[5])); + mtree_test_store_range(mt, seq400[0], seq400[9], + xa_mk_value(seq400[0])); + mtree_test_store_range(mt, seq400[6], seq400[9], + xa_mk_value(seq400[6])); + mtree_test_store_range(mt, seq400[7], seq400[9], + xa_mk_value(seq400[7])); + mtree_test_store_range(mt, seq400[8], seq400[9], + xa_mk_value(seq400[8])); + mtree_test_store_range(mt, seq400[10], seq400[11], + xa_mk_value(seq400[10])); + mt_validate(mt); + mt_set_non_kernel(0); + mtree_destroy(mt); +} +static noinline void check_node_overwrite(struct maple_tree *mt) +{ + int i, max = 4000; + + for (i = 0; i < max; i++) + mtree_test_store_range(mt, i*100, i*100 + 50, xa_mk_value(i*100)); + + mtree_test_store_range(mt, 319951, 367950, NULL); + /*mt_dump(mt); */ + mt_validate(mt); +} + +#if defined(BENCH_SLOT_STORE) +static noinline void bench_slot_store(struct maple_tree *mt) +{ + int i, brk = 105, max = 1040, brk_start = 100, count = 20000000; + + for (i = 0; i < max; i += 10) + mtree_store_range(mt, i, i + 5, xa_mk_value(i), GFP_KERNEL); + + for (i = 0; i < count; i++) { + mtree_store_range(mt, brk, brk, NULL, GFP_KERNEL); + mtree_store_range(mt, brk_start, brk, xa_mk_value(brk), + GFP_KERNEL); + } +} +#endif + +#if defined(BENCH_NODE_STORE) +static noinline void bench_node_store(struct maple_tree *mt) +{ + int i, overwrite = 76, max = 240, count = 20000000; + + for (i = 0; i < max; i += 10) + mtree_store_range(mt, i, i + 5, xa_mk_value(i), GFP_KERNEL); + + for (i = 0; i < count; i++) { + mtree_store_range(mt, overwrite, overwrite + 15, + xa_mk_value(overwrite), GFP_KERNEL); + + overwrite += 5; + if (overwrite >= 135) + overwrite = 76; + } +} +#endif + +#if defined(BENCH_AWALK) +static noinline void bench_awalk(struct maple_tree *mt) +{ + int i, max = 2500, count = 50000000; + MA_STATE(mas, mt, 1470, 1470); + + for (i = 0; i < max; i += 10) + mtree_store_range(mt, i, i + 5, xa_mk_value(i), GFP_KERNEL); + + mtree_store_range(mt, 1470, 1475, NULL, GFP_KERNEL); + + for (i = 0; i < count; i++) { + mas_empty_area_rev(&mas, 0, 2000, 10); + mas_reset(&mas); + } +} +#endif +#if defined(BENCH_WALK) +static noinline void bench_walk(struct maple_tree *mt) +{ + int i, max = 2500, count = 550000000; + MA_STATE(mas, mt, 1470, 1470); + + for (i = 0; i < max; i += 10) + mtree_store_range(mt, i, i + 5, xa_mk_value(i), GFP_KERNEL); + + for (i = 0; i < count; i++) { + mas_walk(&mas); + mas_reset(&mas); + } + +} +#endif + +#if defined(BENCH_MT_FOR_EACH) +static noinline void bench_mt_for_each(struct maple_tree *mt) +{ + int i, count = 1000000; + unsigned long max = 2500, index = 0; + void *entry; + + for (i = 0; i < max; i += 5) + mtree_store_range(mt, i, i + 4, xa_mk_value(i), GFP_KERNEL); + + for (i = 0; i < count; i++) { + unsigned long j = 0; + + mt_for_each(mt, entry, index, max) { + MT_BUG_ON(mt, entry != xa_mk_value(j)); + j += 5; + } + + index = 0; + } + +} +#endif + +/* check_forking - simulate the kernel forking sequence with the tree. */ +static noinline void check_forking(struct maple_tree *mt) +{ + + struct maple_tree newmt; + int i, nr_entries = 134; + void *val; + MA_STATE(mas, mt, 0, 0); + MA_STATE(newmas, mt, 0, 0); + + for (i = 0; i <= nr_entries; i++) + mtree_store_range(mt, i*10, i*10 + 5, + xa_mk_value(i), GFP_KERNEL); + + mt_set_non_kernel(99999); + mt_init_flags(&newmt, MT_FLAGS_ALLOC_RANGE); + newmas.tree = &newmt; + mas_reset(&newmas); + mas_reset(&mas); + mas_lock(&newmas); + mas.index = 0; + mas.last = 0; + if (mas_expected_entries(&newmas, nr_entries)) { + pr_err("OOM!"); + BUG_ON(1); + } + rcu_read_lock(); + mas_for_each(&mas, val, ULONG_MAX) { + newmas.index = mas.index; + newmas.last = mas.last; + mas_store(&newmas, val); + } + rcu_read_unlock(); + mas_destroy(&newmas); + mas_unlock(&newmas); + mt_validate(&newmt); + mt_set_non_kernel(0); + mtree_destroy(&newmt); +} + +static noinline void check_mas_store_gfp(struct maple_tree *mt) +{ + + struct maple_tree newmt; + int i, nr_entries = 135; + void *val; + MA_STATE(mas, mt, 0, 0); + MA_STATE(newmas, mt, 0, 0); + + for (i = 0; i <= nr_entries; i++) + mtree_store_range(mt, i*10, i*10 + 5, + xa_mk_value(i), GFP_KERNEL); + + mt_set_non_kernel(99999); + mt_init_flags(&newmt, MT_FLAGS_ALLOC_RANGE); + newmas.tree = &newmt; + rcu_read_lock(); + mas_lock(&newmas); + mas_reset(&newmas); + mas_set(&mas, 0); + mas_for_each(&mas, val, ULONG_MAX) { + newmas.index = mas.index; + newmas.last = mas.last; + mas_store_gfp(&newmas, val, GFP_KERNEL); + } + mas_unlock(&newmas); + rcu_read_unlock(); + mt_validate(&newmt); + mt_set_non_kernel(0); + mtree_destroy(&newmt); +} + +#if defined(BENCH_FORK) +static noinline void bench_forking(struct maple_tree *mt) +{ + + struct maple_tree newmt; + int i, nr_entries = 134, nr_fork = 80000; + void *val; + MA_STATE(mas, mt, 0, 0); + MA_STATE(newmas, mt, 0, 0); + + for (i = 0; i <= nr_entries; i++) + mtree_store_range(mt, i*10, i*10 + 5, + xa_mk_value(i), GFP_KERNEL); + + for (i = 0; i < nr_fork; i++) { + mt_set_non_kernel(99999); + mt_init_flags(&newmt, MT_FLAGS_ALLOC_RANGE); + newmas.tree = &newmt; + mas_reset(&newmas); + mas_reset(&mas); + mas.index = 0; + mas.last = 0; + rcu_read_lock(); + mas_lock(&newmas); + if (mas_expected_entries(&newmas, nr_entries)) { + printk("OOM!"); + BUG_ON(1); + } + mas_for_each(&mas, val, ULONG_MAX) { + newmas.index = mas.index; + newmas.last = mas.last; + mas_store(&newmas, val); + } + mas_destroy(&newmas); + mas_unlock(&newmas); + rcu_read_unlock(); + mt_validate(&newmt); + mt_set_non_kernel(0); + mtree_destroy(&newmt); + } +} +#endif + +static noinline void next_prev_test(struct maple_tree *mt) +{ + int i, nr_entries; + void *val; + MA_STATE(mas, mt, 0, 0); + struct maple_enode *mn; + unsigned long *level2; + unsigned long level2_64[] = {707, 1000, 710, 715, 720, 725}; + unsigned long level2_32[] = {1747, 2000, 1750, 1755, 1760, 1765}; + + if (MAPLE_32BIT) { + nr_entries = 500; + level2 = level2_32; + } else { + nr_entries = 200; + level2 = level2_64; + } + + for (i = 0; i <= nr_entries; i++) + mtree_store_range(mt, i*10, i*10 + 5, + xa_mk_value(i), GFP_KERNEL); + + mas_lock(&mas); + for (i = 0; i <= nr_entries / 2; i++) { + mas_next(&mas, 1000); + if (mas_is_none(&mas)) + break; + + } + mas_reset(&mas); + mas_set(&mas, 0); + i = 0; + mas_for_each(&mas, val, 1000) { + i++; + } + + mas_reset(&mas); + mas_set(&mas, 0); + i = 0; + mas_for_each(&mas, val, 1000) { + mas_pause(&mas); + i++; + } + + /* + * 680 - 685 = 0x61a00001930c + * 686 - 689 = NULL; + * 690 - 695 = 0x61a00001930c + * Check simple next/prev + */ + mas_set(&mas, 686); + val = mas_walk(&mas); + MT_BUG_ON(mt, val != NULL); + + val = mas_next(&mas, 1000); + MT_BUG_ON(mt, val != xa_mk_value(690 / 10)); + MT_BUG_ON(mt, mas.index != 690); + MT_BUG_ON(mt, mas.last != 695); + + val = mas_prev(&mas, 0); + MT_BUG_ON(mt, val != xa_mk_value(680 / 10)); + MT_BUG_ON(mt, mas.index != 680); + MT_BUG_ON(mt, mas.last != 685); + + val = mas_next(&mas, 1000); + MT_BUG_ON(mt, val != xa_mk_value(690 / 10)); + MT_BUG_ON(mt, mas.index != 690); + MT_BUG_ON(mt, mas.last != 695); + + val = mas_next(&mas, 1000); + MT_BUG_ON(mt, val != xa_mk_value(700 / 10)); + MT_BUG_ON(mt, mas.index != 700); + MT_BUG_ON(mt, mas.last != 705); + + /* Check across node boundaries of the tree */ + mas_set(&mas, 70); + val = mas_walk(&mas); + MT_BUG_ON(mt, val != xa_mk_value(70 / 10)); + MT_BUG_ON(mt, mas.index != 70); + MT_BUG_ON(mt, mas.last != 75); + + val = mas_next(&mas, 1000); + MT_BUG_ON(mt, val != xa_mk_value(80 / 10)); + MT_BUG_ON(mt, mas.index != 80); + MT_BUG_ON(mt, mas.last != 85); + + val = mas_prev(&mas, 70); + MT_BUG_ON(mt, val != xa_mk_value(70 / 10)); + MT_BUG_ON(mt, mas.index != 70); + MT_BUG_ON(mt, mas.last != 75); + + /* Check across two levels of the tree */ + mas_reset(&mas); + mas_set(&mas, level2[0]); + val = mas_walk(&mas); + MT_BUG_ON(mt, val != NULL); + val = mas_next(&mas, level2[1]); + MT_BUG_ON(mt, val != xa_mk_value(level2[2] / 10)); + MT_BUG_ON(mt, mas.index != level2[2]); + MT_BUG_ON(mt, mas.last != level2[3]); + mn = mas.node; + + val = mas_next(&mas, level2[1]); + MT_BUG_ON(mt, val != xa_mk_value(level2[4] / 10)); + MT_BUG_ON(mt, mas.index != level2[4]); + MT_BUG_ON(mt, mas.last != level2[5]); + MT_BUG_ON(mt, mn == mas.node); + + val = mas_prev(&mas, 0); + MT_BUG_ON(mt, val != xa_mk_value(level2[2] / 10)); + MT_BUG_ON(mt, mas.index != level2[2]); + MT_BUG_ON(mt, mas.last != level2[3]); + + /* Check running off the end and back on */ + mas_set(&mas, nr_entries * 10); + val = mas_walk(&mas); + MT_BUG_ON(mt, val != xa_mk_value(nr_entries)); + MT_BUG_ON(mt, mas.index != (nr_entries * 10)); + MT_BUG_ON(mt, mas.last != (nr_entries * 10 + 5)); + + val = mas_next(&mas, ULONG_MAX); + MT_BUG_ON(mt, val != NULL); + MT_BUG_ON(mt, mas.index != ULONG_MAX); + MT_BUG_ON(mt, mas.last != ULONG_MAX); + + val = mas_prev(&mas, 0); + MT_BUG_ON(mt, val != xa_mk_value(nr_entries)); + MT_BUG_ON(mt, mas.index != (nr_entries * 10)); + MT_BUG_ON(mt, mas.last != (nr_entries * 10 + 5)); + + /* Check running off the start and back on */ + mas_reset(&mas); + mas_set(&mas, 10); + val = mas_walk(&mas); + MT_BUG_ON(mt, val != xa_mk_value(1)); + MT_BUG_ON(mt, mas.index != 10); + MT_BUG_ON(mt, mas.last != 15); + + val = mas_prev(&mas, 0); + MT_BUG_ON(mt, val != xa_mk_value(0)); + MT_BUG_ON(mt, mas.index != 0); + MT_BUG_ON(mt, mas.last != 5); + + val = mas_prev(&mas, 0); + MT_BUG_ON(mt, val != NULL); + MT_BUG_ON(mt, mas.index != 0); + MT_BUG_ON(mt, mas.last != 0); + + mas.index = 0; + mas.last = 5; + mas_store(&mas, NULL); + mas_reset(&mas); + mas_set(&mas, 10); + mas_walk(&mas); + + val = mas_prev(&mas, 0); + MT_BUG_ON(mt, val != NULL); + MT_BUG_ON(mt, mas.index != 0); + MT_BUG_ON(mt, mas.last != 0); + mas_unlock(&mas); + + mtree_destroy(mt); + + mt_init(mt); + mtree_store_range(mt, 0, 0, xa_mk_value(0), GFP_KERNEL); + mtree_store_range(mt, 5, 5, xa_mk_value(5), GFP_KERNEL); + rcu_read_lock(); + mas_set(&mas, 5); + val = mas_prev(&mas, 4); + MT_BUG_ON(mt, val != NULL); + rcu_read_unlock(); +} + + + +/* Test spanning writes that require balancing right sibling or right cousin */ +static noinline void check_spanning_relatives(struct maple_tree *mt) +{ + + unsigned long i, nr_entries = 1000; + + for (i = 0; i <= nr_entries; i++) + mtree_store_range(mt, i*10, i*10 + 5, + xa_mk_value(i), GFP_KERNEL); + + + mtree_store_range(mt, 9365, 9955, NULL, GFP_KERNEL); +} + +static noinline void check_fuzzer(struct maple_tree *mt) +{ + /* + * 1. Causes a spanning rebalance of a single root node. + * Fixed by setting the correct limit in mast_cp_to_nodes() when the + * entire right side is consumed. + */ + mtree_test_insert(mt, 88, (void *)0xb1); + mtree_test_insert(mt, 84, (void *)0xa9); + mtree_test_insert(mt, 2, (void *)0x5); + mtree_test_insert(mt, 4, (void *)0x9); + mtree_test_insert(mt, 14, (void *)0x1d); + mtree_test_insert(mt, 7, (void *)0xf); + mtree_test_insert(mt, 12, (void *)0x19); + mtree_test_insert(mt, 18, (void *)0x25); + mtree_test_store_range(mt, 8, 18, (void *)0x11); + mtree_destroy(mt); + + + /* + * 2. Cause a spanning rebalance of two nodes in root. + * Fixed by setting mast->r->max correctly. + */ + mt_init_flags(mt, 0); + mtree_test_store(mt, 87, (void *)0xaf); + mtree_test_store(mt, 0, (void *)0x1); + mtree_test_load(mt, 4); + mtree_test_insert(mt, 4, (void *)0x9); + mtree_test_store(mt, 8, (void *)0x11); + mtree_test_store(mt, 44, (void *)0x59); + mtree_test_store(mt, 68, (void *)0x89); + mtree_test_store(mt, 2, (void *)0x5); + mtree_test_insert(mt, 43, (void *)0x57); + mtree_test_insert(mt, 24, (void *)0x31); + mtree_test_insert(mt, 844, (void *)0x699); + mtree_test_store(mt, 84, (void *)0xa9); + mtree_test_store(mt, 4, (void *)0x9); + mtree_test_erase(mt, 4); + mtree_test_load(mt, 5); + mtree_test_erase(mt, 0); + mtree_destroy(mt); + + /* + * 3. Cause a node overflow on copy + * Fixed by using the correct check for node size in mas_wr_modify() + * Also discovered issue with metadata setting. + */ + mt_init_flags(mt, 0); + mtree_test_store_range(mt, 0, ULONG_MAX, (void *)0x1); + mtree_test_store(mt, 4, (void *)0x9); + mtree_test_erase(mt, 5); + mtree_test_erase(mt, 0); + mtree_test_erase(mt, 4); + mtree_test_store(mt, 5, (void *)0xb); + mtree_test_erase(mt, 5); + mtree_test_store(mt, 5, (void *)0xb); + mtree_test_erase(mt, 5); + mtree_test_erase(mt, 4); + mtree_test_store(mt, 4, (void *)0x9); + mtree_test_store(mt, 444, (void *)0x379); + mtree_test_store(mt, 0, (void *)0x1); + mtree_test_load(mt, 0); + mtree_test_store(mt, 5, (void *)0xb); + mtree_test_erase(mt, 0); + mtree_destroy(mt); + + /* + * 4. spanning store failure due to writing incorrect pivot value at + * last slot. + * Fixed by setting mast->r->max correctly in mast_cp_to_nodes() + * + */ + mt_init_flags(mt, 0); + mtree_test_insert(mt, 261, (void *)0x20b); + mtree_test_store(mt, 516, (void *)0x409); + mtree_test_store(mt, 6, (void *)0xd); + mtree_test_insert(mt, 5, (void *)0xb); + mtree_test_insert(mt, 1256, (void *)0x9d1); + mtree_test_store(mt, 4, (void *)0x9); + mtree_test_erase(mt, 1); + mtree_test_store(mt, 56, (void *)0x71); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_store(mt, 24, (void *)0x31); + mtree_test_erase(mt, 1); + mtree_test_insert(mt, 2263, (void *)0x11af); + mtree_test_insert(mt, 446, (void *)0x37d); + mtree_test_store_range(mt, 6, 45, (void *)0xd); + mtree_test_store_range(mt, 3, 446, (void *)0x7); + mtree_destroy(mt); + + /* + * 5. mas_wr_extend_null() may overflow slots. + * Fix by checking against wr_mas->node_end. + */ + mt_init_flags(mt, 0); + mtree_test_store(mt, 48, (void *)0x61); + mtree_test_store(mt, 3, (void *)0x7); + mtree_test_load(mt, 0); + mtree_test_store(mt, 88, (void *)0xb1); + mtree_test_store(mt, 81, (void *)0xa3); + mtree_test_insert(mt, 0, (void *)0x1); + mtree_test_insert(mt, 8, (void *)0x11); + mtree_test_insert(mt, 4, (void *)0x9); + mtree_test_insert(mt, 2480, (void *)0x1361); + mtree_test_insert(mt, ULONG_MAX, + (void *)0xffffffffffffffff); + mtree_test_erase(mt, ULONG_MAX); + mtree_destroy(mt); + + /* + * 6. When reusing a node with an implied pivot and the node is + * shrinking, old data would be left in the implied slot + * Fixed by checking the last pivot for the mas->max and clear + * accordingly. This only affected the left-most node as that node is + * the only one allowed to end in NULL. + */ + mt_init_flags(mt, 0); + mtree_test_erase(mt, 3); + mtree_test_insert(mt, 22, (void *)0x2d); + mtree_test_insert(mt, 15, (void *)0x1f); + mtree_test_load(mt, 2); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_insert(mt, 5, (void *)0xb); + mtree_test_erase(mt, 1); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_insert(mt, 4, (void *)0x9); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_erase(mt, 1); + mtree_test_insert(mt, 2, (void *)0x5); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_erase(mt, 3); + mtree_test_insert(mt, 22, (void *)0x2d); + mtree_test_insert(mt, 15, (void *)0x1f); + mtree_test_insert(mt, 2, (void *)0x5); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_insert(mt, 8, (void *)0x11); + mtree_test_load(mt, 2); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_store(mt, 1, (void *)0x3); + mtree_test_insert(mt, 5, (void *)0xb); + mtree_test_erase(mt, 1); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_insert(mt, 4, (void *)0x9); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_erase(mt, 1); + mtree_test_insert(mt, 2, (void *)0x5); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_erase(mt, 3); + mtree_test_insert(mt, 22, (void *)0x2d); + mtree_test_insert(mt, 15, (void *)0x1f); + mtree_test_insert(mt, 2, (void *)0x5); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_insert(mt, 8, (void *)0x11); + mtree_test_insert(mt, 12, (void *)0x19); + mtree_test_erase(mt, 1); + mtree_test_store_range(mt, 4, 62, (void *)0x9); + mtree_test_erase(mt, 62); + mtree_test_store_range(mt, 1, 0, (void *)0x3); + mtree_test_insert(mt, 11, (void *)0x17); + mtree_test_insert(mt, 3, (void *)0x7); + mtree_test_insert(mt, 3, (void *)0x7); + mtree_test_store(mt, 62, (void *)0x7d); + mtree_test_erase(mt, 62); + mtree_test_store_range(mt, 1, 15, (void *)0x3); + mtree_test_erase(mt, 1); + mtree_test_insert(mt, 22, (void *)0x2d); + mtree_test_insert(mt, 12, (void *)0x19); + mtree_test_erase(mt, 1); + mtree_test_insert(mt, 3, (void *)0x7); + mtree_test_store(mt, 62, (void *)0x7d); + mtree_test_erase(mt, 62); + mtree_test_insert(mt, 122, (void *)0xf5); + mtree_test_store(mt, 3, (void *)0x7); + mtree_test_insert(mt, 0, (void *)0x1); + mtree_test_store_range(mt, 0, 1, (void *)0x1); + mtree_test_insert(mt, 85, (void *)0xab); + mtree_test_insert(mt, 72, (void *)0x91); + mtree_test_insert(mt, 81, (void *)0xa3); + mtree_test_insert(mt, 726, (void *)0x5ad); + mtree_test_insert(mt, 0, (void *)0x1); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_store(mt, 51, (void *)0x67); + mtree_test_insert(mt, 611, (void *)0x4c7); + mtree_test_insert(mt, 485, (void *)0x3cb); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_erase(mt, 1); + mtree_test_insert(mt, 0, (void *)0x1); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_insert_range(mt, 26, 1, (void *)0x35); + mtree_test_load(mt, 1); + mtree_test_store_range(mt, 1, 22, (void *)0x3); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_erase(mt, 1); + mtree_test_load(mt, 53); + mtree_test_load(mt, 1); + mtree_test_store_range(mt, 1, 1, (void *)0x3); + mtree_test_insert(mt, 222, (void *)0x1bd); + mtree_test_insert(mt, 485, (void *)0x3cb); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_erase(mt, 1); + mtree_test_load(mt, 0); + mtree_test_insert(mt, 21, (void *)0x2b); + mtree_test_insert(mt, 3, (void *)0x7); + mtree_test_store(mt, 621, (void *)0x4db); + mtree_test_insert(mt, 0, (void *)0x1); + mtree_test_erase(mt, 5); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_store(mt, 62, (void *)0x7d); + mtree_test_erase(mt, 62); + mtree_test_store_range(mt, 1, 0, (void *)0x3); + mtree_test_insert(mt, 22, (void *)0x2d); + mtree_test_insert(mt, 12, (void *)0x19); + mtree_test_erase(mt, 1); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_store_range(mt, 4, 62, (void *)0x9); + mtree_test_erase(mt, 62); + mtree_test_erase(mt, 1); + mtree_test_load(mt, 1); + mtree_test_store_range(mt, 1, 22, (void *)0x3); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_erase(mt, 1); + mtree_test_load(mt, 53); + mtree_test_load(mt, 1); + mtree_test_store_range(mt, 1, 1, (void *)0x3); + mtree_test_insert(mt, 222, (void *)0x1bd); + mtree_test_insert(mt, 485, (void *)0x3cb); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_erase(mt, 1); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_load(mt, 0); + mtree_test_load(mt, 0); + mtree_destroy(mt); + + /* + * 7. Previous fix was incomplete, fix mas_resuse_node() clearing of old + * data by overwriting it first - that way metadata is of no concern. + */ + mt_init_flags(mt, 0); + mtree_test_load(mt, 1); + mtree_test_insert(mt, 102, (void *)0xcd); + mtree_test_erase(mt, 2); + mtree_test_erase(mt, 0); + mtree_test_load(mt, 0); + mtree_test_insert(mt, 4, (void *)0x9); + mtree_test_insert(mt, 2, (void *)0x5); + mtree_test_insert(mt, 110, (void *)0xdd); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_insert_range(mt, 5, 0, (void *)0xb); + mtree_test_erase(mt, 2); + mtree_test_store(mt, 0, (void *)0x1); + mtree_test_store(mt, 112, (void *)0xe1); + mtree_test_insert(mt, 21, (void *)0x2b); + mtree_test_store(mt, 1, (void *)0x3); + mtree_test_insert_range(mt, 110, 2, (void *)0xdd); + mtree_test_store(mt, 2, (void *)0x5); + mtree_test_load(mt, 22); + mtree_test_erase(mt, 2); + mtree_test_store(mt, 210, (void *)0x1a5); + mtree_test_store_range(mt, 0, 2, (void *)0x1); + mtree_test_store(mt, 2, (void *)0x5); + mtree_test_erase(mt, 2); + mtree_test_erase(mt, 22); + mtree_test_erase(mt, 1); + mtree_test_erase(mt, 2); + mtree_test_store(mt, 0, (void *)0x1); + mtree_test_load(mt, 112); + mtree_test_insert(mt, 2, (void *)0x5); + mtree_test_erase(mt, 2); + mtree_test_store(mt, 1, (void *)0x3); + mtree_test_insert_range(mt, 1, 2, (void *)0x3); + mtree_test_erase(mt, 0); + mtree_test_erase(mt, 2); + mtree_test_store(mt, 2, (void *)0x5); + mtree_test_erase(mt, 0); + mtree_test_erase(mt, 2); + mtree_test_store(mt, 0, (void *)0x1); + mtree_test_store(mt, 0, (void *)0x1); + mtree_test_erase(mt, 2); + mtree_test_store(mt, 2, (void *)0x5); + mtree_test_erase(mt, 2); + mtree_test_insert(mt, 2, (void *)0x5); + mtree_test_insert_range(mt, 1, 2, (void *)0x3); + mtree_test_erase(mt, 0); + mtree_test_erase(mt, 2); + mtree_test_store(mt, 0, (void *)0x1); + mtree_test_load(mt, 112); + mtree_test_store_range(mt, 110, 12, (void *)0xdd); + mtree_test_store(mt, 2, (void *)0x5); + mtree_test_load(mt, 110); + mtree_test_insert_range(mt, 4, 71, (void *)0x9); + mtree_test_load(mt, 2); + mtree_test_store(mt, 2, (void *)0x5); + mtree_test_insert_range(mt, 11, 22, (void *)0x17); + mtree_test_erase(mt, 12); + mtree_test_store(mt, 2, (void *)0x5); + mtree_test_load(mt, 22); + mtree_destroy(mt); + + + /* + * 8. When rebalancing or spanning_rebalance(), the max of the new node + * may be set incorrectly to the final pivot and not the right max. + * Fix by setting the left max to orig right max if the entire node is + * consumed. + */ + mt_init_flags(mt, 0); + mtree_test_store(mt, 6, (void *)0xd); + mtree_test_store(mt, 67, (void *)0x87); + mtree_test_insert(mt, 15, (void *)0x1f); + mtree_test_insert(mt, 6716, (void *)0x3479); + mtree_test_store(mt, 61, (void *)0x7b); + mtree_test_insert(mt, 13, (void *)0x1b); + mtree_test_store(mt, 8, (void *)0x11); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_load(mt, 0); + mtree_test_erase(mt, 67167); + mtree_test_insert_range(mt, 6, 7167, (void *)0xd); + mtree_test_insert(mt, 6, (void *)0xd); + mtree_test_erase(mt, 67); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_erase(mt, 667167); + mtree_test_insert(mt, 6, (void *)0xd); + mtree_test_store(mt, 67, (void *)0x87); + mtree_test_insert(mt, 5, (void *)0xb); + mtree_test_erase(mt, 1); + mtree_test_insert(mt, 6, (void *)0xd); + mtree_test_erase(mt, 67); + mtree_test_insert(mt, 15, (void *)0x1f); + mtree_test_insert(mt, 67167, (void *)0x20cbf); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_load(mt, 7); + mtree_test_insert(mt, 16, (void *)0x21); + mtree_test_insert(mt, 36, (void *)0x49); + mtree_test_store(mt, 67, (void *)0x87); + mtree_test_store(mt, 6, (void *)0xd); + mtree_test_insert(mt, 367, (void *)0x2df); + mtree_test_insert(mt, 115, (void *)0xe7); + mtree_test_store(mt, 0, (void *)0x1); + mtree_test_store_range(mt, 1, 3, (void *)0x3); + mtree_test_store(mt, 1, (void *)0x3); + mtree_test_erase(mt, 67167); + mtree_test_insert_range(mt, 6, 47, (void *)0xd); + mtree_test_store(mt, 1, (void *)0x3); + mtree_test_insert_range(mt, 1, 67, (void *)0x3); + mtree_test_load(mt, 67); + mtree_test_insert(mt, 1, (void *)0x3); + mtree_test_erase(mt, 67167); + mtree_destroy(mt); + + /* + * 9. spanning store to the end of data caused an invalid metadata + * length which resulted in a crash eventually. + * Fix by checking if there is a value in pivot before incrementing the + * metadata end in mab_mas_cp(). To ensure this doesn't happen again, + * abstract the two locations this happens into a function called + * mas_leaf_set_meta(). + */ + mt_init_flags(mt, 0); + mtree_test_insert(mt, 21, (void *)0x2b); + mtree_test_insert(mt, 12, (void *)0x19); + mtree_test_insert(mt, 6, (void *)0xd); + mtree_test_insert(mt, 8, (void *)0x11); + mtree_test_insert(mt, 2, (void *)0x5); + mtree_test_insert(mt, 91, (void *)0xb7); + mtree_test_insert(mt, 18, (void *)0x25); + mtree_test_insert(mt, 81, (void *)0xa3); + mtree_test_store_range(mt, 0, 128, (void *)0x1); + mtree_test_store(mt, 1, (void *)0x3); + mtree_test_erase(mt, 8); + mtree_test_insert(mt, 11, (void *)0x17); + mtree_test_insert(mt, 8, (void *)0x11); + mtree_test_insert(mt, 21, (void *)0x2b); + mtree_test_insert(mt, 2, (void *)0x5); + mtree_test_insert(mt, ULONG_MAX - 10, (void *)0xffffffffffffffeb); + mtree_test_erase(mt, ULONG_MAX - 10); + mtree_test_store_range(mt, 0, 281, (void *)0x1); + mtree_test_erase(mt, 2); + mtree_test_insert(mt, 1211, (void *)0x977); + mtree_test_insert(mt, 111, (void *)0xdf); + mtree_test_insert(mt, 13, (void *)0x1b); + mtree_test_insert(mt, 211, (void *)0x1a7); + mtree_test_insert(mt, 11, (void *)0x17); + mtree_test_insert(mt, 5, (void *)0xb); + mtree_test_insert(mt, 1218, (void *)0x985); + mtree_test_insert(mt, 61, (void *)0x7b); + mtree_test_store(mt, 1, (void *)0x3); + mtree_test_insert(mt, 121, (void *)0xf3); + mtree_test_insert(mt, 8, (void *)0x11); + mtree_test_insert(mt, 21, (void *)0x2b); + mtree_test_insert(mt, 2, (void *)0x5); + mtree_test_insert(mt, ULONG_MAX - 10, (void *)0xffffffffffffffeb); + mtree_test_erase(mt, ULONG_MAX - 10); +} + +/* duplicate the tree with a specific gap */ +static noinline void check_dup_gaps(struct maple_tree *mt, + unsigned long nr_entries, bool zero_start, + unsigned long gap) +{ + unsigned long i = 0; + struct maple_tree newmt; + int ret; + void *tmp; + MA_STATE(mas, mt, 0, 0); + MA_STATE(newmas, &newmt, 0, 0); + + if (!zero_start) + i = 1; + + mt_zero_nr_tallocated(); + for (; i <= nr_entries; i++) + mtree_store_range(mt, i*10, (i+1)*10 - gap, + xa_mk_value(i), GFP_KERNEL); + + mt_init_flags(&newmt, MT_FLAGS_ALLOC_RANGE); + mt_set_non_kernel(99999); + mas_lock(&newmas); + ret = mas_expected_entries(&newmas, nr_entries); + mt_set_non_kernel(0); + MT_BUG_ON(mt, ret != 0); + + rcu_read_lock(); + mas_for_each(&mas, tmp, ULONG_MAX) { + newmas.index = mas.index; + newmas.last = mas.last; + mas_store(&newmas, tmp); + } + rcu_read_unlock(); + mas_destroy(&newmas); + mas_unlock(&newmas); + + mtree_destroy(&newmt); +} + +/* Duplicate many sizes of trees. Mainly to test expected entry values */ +static noinline void check_dup(struct maple_tree *mt) +{ + int i; + int big_start = 100010; + + /* Check with a value at zero */ + for (i = 10; i < 1000; i++) { + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + check_dup_gaps(mt, i, true, 5); + mtree_destroy(mt); + rcu_barrier(); + } + + cond_resched(); + mt_cache_shrink(); + /* Check with a value at zero, no gap */ + for (i = 1000; i < 2000; i++) { + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + check_dup_gaps(mt, i, true, 0); + mtree_destroy(mt); + rcu_barrier(); + } + + cond_resched(); + mt_cache_shrink(); + /* Check with a value at zero and unreasonably large */ + for (i = big_start; i < big_start + 10; i++) { + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + check_dup_gaps(mt, i, true, 5); + mtree_destroy(mt); + rcu_barrier(); + } + + cond_resched(); + mt_cache_shrink(); + /* Small to medium size not starting at zero*/ + for (i = 200; i < 1000; i++) { + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + check_dup_gaps(mt, i, false, 5); + mtree_destroy(mt); + rcu_barrier(); + } + + cond_resched(); + mt_cache_shrink(); + /* Unreasonably large not starting at zero*/ + for (i = big_start; i < big_start + 10; i++) { + mt_init_flags(mt, MT_FLAGS_ALLOC_RANGE); + check_dup_gaps(mt, i, false, 5); + mtree_destroy(mt); + rcu_barrier(); + cond_resched(); + mt_cache_shrink(); + } + + /* Check non-allocation tree not starting at zero */ + for (i = 1500; i < 3000; i++) { + mt_init_flags(mt, 0); + check_dup_gaps(mt, i, false, 5); + mtree_destroy(mt); + rcu_barrier(); + cond_resched(); + if (i % 2 == 0) + mt_cache_shrink(); + } + + mt_cache_shrink(); + /* Check non-allocation tree starting at zero */ + for (i = 200; i < 1000; i++) { + mt_init_flags(mt, 0); + check_dup_gaps(mt, i, true, 5); + mtree_destroy(mt); + rcu_barrier(); + cond_resched(); + } + + mt_cache_shrink(); + /* Unreasonably large */ + for (i = big_start + 5; i < big_start + 10; i++) { + mt_init_flags(mt, 0); + check_dup_gaps(mt, i, true, 5); + mtree_destroy(mt); + rcu_barrier(); + mt_cache_shrink(); + cond_resched(); + } +} + +static DEFINE_MTREE(tree); +static int maple_tree_seed(void) +{ + unsigned long set[] = {5015, 5014, 5017, 25, 1000, + 1001, 1002, 1003, 1005, 0, + 5003, 5002}; + void *ptr = &set; + + pr_info("\nTEST STARTING\n\n"); + + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + check_root_expand(&tree); + mtree_destroy(&tree); + +#if defined(BENCH_SLOT_STORE) +#define BENCH + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + bench_slot_store(&tree); + mtree_destroy(&tree); + goto skip; +#endif +#if defined(BENCH_NODE_STORE) +#define BENCH + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + bench_node_store(&tree); + mtree_destroy(&tree); + goto skip; +#endif +#if defined(BENCH_AWALK) +#define BENCH + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + bench_awalk(&tree); + mtree_destroy(&tree); + goto skip; +#endif +#if defined(BENCH_WALK) +#define BENCH + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + bench_walk(&tree); + mtree_destroy(&tree); + goto skip; +#endif +#if defined(BENCH_FORK) +#define BENCH + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + bench_forking(&tree); + mtree_destroy(&tree); + goto skip; +#endif +#if defined(BENCH_MT_FOR_EACH) +#define BENCH + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + bench_mt_for_each(&tree); + mtree_destroy(&tree); + goto skip; +#endif + + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + check_forking(&tree); + mtree_destroy(&tree); + + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + check_mas_store_gfp(&tree); + mtree_destroy(&tree); + + /* Test ranges (store and insert) */ + mt_init_flags(&tree, 0); + check_ranges(&tree); + mtree_destroy(&tree); + +#if defined(CONFIG_64BIT) + /* These tests have ranges outside of 4GB */ + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + check_alloc_range(&tree); + mtree_destroy(&tree); + + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + check_alloc_rev_range(&tree); + mtree_destroy(&tree); +#endif + + mt_init_flags(&tree, 0); + + check_load(&tree, set[0], NULL); /* See if 5015 -> NULL */ + + check_insert(&tree, set[9], &tree); /* Insert 0 */ + check_load(&tree, set[9], &tree); /* See if 0 -> &tree */ + check_load(&tree, set[0], NULL); /* See if 5015 -> NULL */ + + check_insert(&tree, set[10], ptr); /* Insert 5003 */ + check_load(&tree, set[9], &tree); /* See if 0 -> &tree */ + check_load(&tree, set[11], NULL); /* See if 5002 -> NULL */ + check_load(&tree, set[10], ptr); /* See if 5003 -> ptr */ + + /* Clear out the tree */ + mtree_destroy(&tree); + + /* Try to insert, insert a dup, and load back what was inserted. */ + mt_init_flags(&tree, 0); + check_insert(&tree, set[0], &tree); /* Insert 5015 */ + check_dup_insert(&tree, set[0], &tree); /* Insert 5015 again */ + check_load(&tree, set[0], &tree); /* See if 5015 -> &tree */ + + /* + * Second set of tests try to load a value that doesn't exist, inserts + * a second value, then loads the value again + */ + check_load(&tree, set[1], NULL); /* See if 5014 -> NULL */ + check_insert(&tree, set[1], ptr); /* insert 5014 -> ptr */ + check_load(&tree, set[1], ptr); /* See if 5014 -> ptr */ + check_load(&tree, set[0], &tree); /* See if 5015 -> &tree */ + /* + * Tree currently contains: + * p[0]: 14 -> (nil) p[1]: 15 -> ptr p[2]: 16 -> &tree p[3]: 0 -> (nil) + */ + check_insert(&tree, set[6], ptr); /* insert 1002 -> ptr */ + check_insert(&tree, set[7], &tree); /* insert 1003 -> &tree */ + + check_load(&tree, set[0], &tree); /* See if 5015 -> &tree */ + check_load(&tree, set[1], ptr); /* See if 5014 -> ptr */ + check_load(&tree, set[6], ptr); /* See if 1002 -> ptr */ + check_load(&tree, set[7], &tree); /* 1003 = &tree ? */ + + /* Clear out tree */ + mtree_destroy(&tree); + + mt_init_flags(&tree, 0); + /* Test inserting into a NULL hole. */ + check_insert(&tree, set[5], ptr); /* insert 1001 -> ptr */ + check_insert(&tree, set[7], &tree); /* insert 1003 -> &tree */ + check_insert(&tree, set[6], ptr); /* insert 1002 -> ptr */ + check_load(&tree, set[5], ptr); /* See if 1001 -> ptr */ + check_load(&tree, set[6], ptr); /* See if 1002 -> ptr */ + check_load(&tree, set[7], &tree); /* See if 1003 -> &tree */ + + /* Clear out the tree */ + mtree_destroy(&tree); + + mt_init_flags(&tree, 0); + /* + * set[] = {5015, 5014, 5017, 25, 1000, + * 1001, 1002, 1003, 1005, 0, + * 5003, 5002}; + */ + + check_insert(&tree, set[0], ptr); /* 5015 */ + check_insert(&tree, set[1], &tree); /* 5014 */ + check_insert(&tree, set[2], ptr); /* 5017 */ + check_insert(&tree, set[3], &tree); /* 25 */ + check_load(&tree, set[0], ptr); + check_load(&tree, set[1], &tree); + check_load(&tree, set[2], ptr); + check_load(&tree, set[3], &tree); + check_insert(&tree, set[4], ptr); /* 1000 < Should split. */ + check_load(&tree, set[0], ptr); + check_load(&tree, set[1], &tree); + check_load(&tree, set[2], ptr); + check_load(&tree, set[3], &tree); /*25 */ + check_load(&tree, set[4], ptr); + check_insert(&tree, set[5], &tree); /* 1001 */ + check_load(&tree, set[0], ptr); + check_load(&tree, set[1], &tree); + check_load(&tree, set[2], ptr); + check_load(&tree, set[3], &tree); + check_load(&tree, set[4], ptr); + check_load(&tree, set[5], &tree); + check_insert(&tree, set[6], ptr); + check_load(&tree, set[0], ptr); + check_load(&tree, set[1], &tree); + check_load(&tree, set[2], ptr); + check_load(&tree, set[3], &tree); + check_load(&tree, set[4], ptr); + check_load(&tree, set[5], &tree); + check_load(&tree, set[6], ptr); + check_insert(&tree, set[7], &tree); + check_load(&tree, set[0], ptr); + check_insert(&tree, set[8], ptr); + + check_insert(&tree, set[9], &tree); + + check_load(&tree, set[0], ptr); + check_load(&tree, set[1], &tree); + check_load(&tree, set[2], ptr); + check_load(&tree, set[3], &tree); + check_load(&tree, set[4], ptr); + check_load(&tree, set[5], &tree); + check_load(&tree, set[6], ptr); + check_load(&tree, set[9], &tree); + mtree_destroy(&tree); + + mt_init_flags(&tree, 0); + check_seq(&tree, 16, false); + mtree_destroy(&tree); + + mt_init_flags(&tree, 0); + check_seq(&tree, 1000, true); + mtree_destroy(&tree); + + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + check_rev_seq(&tree, 1000, true); + mtree_destroy(&tree); + + check_lower_bound_split(&tree); + check_upper_bound_split(&tree); + check_mid_split(&tree); + + mt_init_flags(&tree, 0); + check_next_entry(&tree); + check_find(&tree); + check_find_2(&tree); + mtree_destroy(&tree); + + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + check_prev_entry(&tree); + mtree_destroy(&tree); + + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + check_gap_combining(&tree); + mtree_destroy(&tree); + + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + check_node_overwrite(&tree); + mtree_destroy(&tree); + + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + next_prev_test(&tree); + mtree_destroy(&tree); + + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + check_spanning_relatives(&tree); + mtree_destroy(&tree); + + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + check_rev_find(&tree); + mtree_destroy(&tree); + + mt_init_flags(&tree, 0); + check_fuzzer(&tree); + mtree_destroy(&tree); + + mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE); + check_dup(&tree); + mtree_destroy(&tree); + +#if defined(BENCH) +skip: +#endif + rcu_barrier(); + pr_info("maple_tree: %u of %u tests passed\n", + atomic_read(&maple_tree_tests_passed), + atomic_read(&maple_tree_tests_run)); + if (atomic_read(&maple_tree_tests_run) == + atomic_read(&maple_tree_tests_passed)) + return 0; + + return -EINVAL; +} + +static void maple_tree_harvest(void) +{ + +} + +module_init(maple_tree_seed); +module_exit(maple_tree_harvest); +MODULE_AUTHOR("Liam R. Howlett "); +MODULE_LICENSE("GPL"); diff --git a/mm/Kconfig b/mm/Kconfig index d5373a23902c..e8b147872458 100644 --- a/mm/Kconfig +++ b/mm/Kconfig @@ -1141,6 +1141,32 @@ config PTE_MARKER_UFFD_WP purposes. It is required to enable userfaultfd write protection on file-backed memory types like shmem and hugetlbfs. +# multi-gen LRU { +config LRU_GEN + bool "Multi-Gen LRU" + depends on MMU + # make sure folio->flags has enough spare bits + depends on 64BIT || !SPARSEMEM || SPARSEMEM_VMEMMAP + help + A high performance LRU implementation to overcommit memory. See + Documentation/admin-guide/mm/multigen_lru.rst for details. + +config LRU_GEN_ENABLED + bool "Enable by default" + depends on LRU_GEN + help + This option enables the multi-gen LRU by default. + +config LRU_GEN_STATS + bool "Full stats for debugging" + depends on LRU_GEN + help + Do not enable this option unless you plan to look at historical stats + from evicted generations for debugging purpose. + + This option has a per-memcg and per-node memory overhead. +# } + source "mm/damon/Kconfig" endmenu diff --git a/mm/Makefile b/mm/Makefile index eb7235da6e61..a8b9a2b913d3 100644 --- a/mm/Makefile +++ b/mm/Makefile @@ -52,7 +52,7 @@ obj-y := filemap.o mempool.o oom_kill.o fadvise.o \ readahead.o swap.o truncate.o vmscan.o shmem.o \ util.o mmzone.o vmstat.o backing-dev.o \ mm_init.o percpu.o slab_common.o \ - compaction.o vmacache.o \ + compaction.o \ interval_tree.o list_lru.o workingset.o \ debug.o gup.o mmap_lock.o $(mmu-y) diff --git a/mm/damon/vaddr-test.h b/mm/damon/vaddr-test.h index d4f55f349100..bce37c487540 100644 --- a/mm/damon/vaddr-test.h +++ b/mm/damon/vaddr-test.h @@ -14,33 +14,19 @@ #include -static void __link_vmas(struct vm_area_struct *vmas, ssize_t nr_vmas) +static void __link_vmas(struct maple_tree *mt, struct vm_area_struct *vmas, + ssize_t nr_vmas) { - int i, j; - unsigned long largest_gap, gap; + int i; + MA_STATE(mas, mt, 0, 0); if (!nr_vmas) return; - for (i = 0; i < nr_vmas - 1; i++) { - vmas[i].vm_next = &vmas[i + 1]; - - vmas[i].vm_rb.rb_left = NULL; - vmas[i].vm_rb.rb_right = &vmas[i + 1].vm_rb; - - largest_gap = 0; - for (j = i; j < nr_vmas; j++) { - if (j == 0) - continue; - gap = vmas[j].vm_start - vmas[j - 1].vm_end; - if (gap > largest_gap) - largest_gap = gap; - } - vmas[i].rb_subtree_gap = largest_gap; - } - vmas[i].vm_next = NULL; - vmas[i].vm_rb.rb_right = NULL; - vmas[i].rb_subtree_gap = 0; + mas_lock(&mas); + for (i = 0; i < nr_vmas; i++) + vma_mas_store(&vmas[i], &mas); + mas_unlock(&mas); } /* @@ -72,6 +58,7 @@ static void __link_vmas(struct vm_area_struct *vmas, ssize_t nr_vmas) */ static void damon_test_three_regions_in_vmas(struct kunit *test) { + static struct mm_struct mm; struct damon_addr_range regions[3] = {0,}; /* 10-20-25, 200-210-220, 300-305, 307-330 */ struct vm_area_struct vmas[] = { @@ -83,9 +70,10 @@ static void damon_test_three_regions_in_vmas(struct kunit *test) (struct vm_area_struct) {.vm_start = 307, .vm_end = 330}, }; - __link_vmas(vmas, 6); + mt_init_flags(&mm.mm_mt, MM_MT_FLAGS); + __link_vmas(&mm.mm_mt, vmas, ARRAY_SIZE(vmas)); - __damon_va_three_regions(&vmas[0], regions); + __damon_va_three_regions(&mm, regions); KUNIT_EXPECT_EQ(test, 10ul, regions[0].start); KUNIT_EXPECT_EQ(test, 25ul, regions[0].end); diff --git a/mm/damon/vaddr.c b/mm/damon/vaddr.c index 1d16c6c79638..95769ba14262 100644 --- a/mm/damon/vaddr.c +++ b/mm/damon/vaddr.c @@ -113,37 +113,38 @@ static unsigned long sz_range(struct damon_addr_range *r) * * Returns 0 if success, or negative error code otherwise. */ -static int __damon_va_three_regions(struct vm_area_struct *vma, +static int __damon_va_three_regions(struct mm_struct *mm, struct damon_addr_range regions[3]) { - struct damon_addr_range gap = {0}, first_gap = {0}, second_gap = {0}; - struct vm_area_struct *last_vma = NULL; - unsigned long start = 0; - struct rb_root rbroot; - - /* Find two biggest gaps so that first_gap > second_gap > others */ - for (; vma; vma = vma->vm_next) { - if (!last_vma) { - start = vma->vm_start; - goto next; - } + struct damon_addr_range first_gap = {0}, second_gap = {0}; + VMA_ITERATOR(vmi, mm, 0); + struct vm_area_struct *vma, *prev = NULL; + unsigned long start; - if (vma->rb_subtree_gap <= sz_range(&second_gap)) { - rbroot.rb_node = &vma->vm_rb; - vma = rb_entry(rb_last(&rbroot), - struct vm_area_struct, vm_rb); + /* + * Find the two biggest gaps so that first_gap > second_gap > others. + * If this is too slow, it can be optimised to examine the maple + * tree gaps. + */ + for_each_vma(vmi, vma) { + unsigned long gap; + + if (!prev) { + start = vma->vm_start; goto next; } - - gap.start = last_vma->vm_end; - gap.end = vma->vm_start; - if (sz_range(&gap) > sz_range(&second_gap)) { - swap(gap, second_gap); - if (sz_range(&second_gap) > sz_range(&first_gap)) - swap(second_gap, first_gap); + gap = vma->vm_start - prev->vm_end; + + if (gap > sz_range(&first_gap)) { + second_gap = first_gap; + first_gap.start = prev->vm_end; + first_gap.end = vma->vm_start; + } else if (gap > sz_range(&second_gap)) { + second_gap.start = prev->vm_end; + second_gap.end = vma->vm_start; } next: - last_vma = vma; + prev = vma; } if (!sz_range(&second_gap) || !sz_range(&first_gap)) @@ -159,7 +160,7 @@ static int __damon_va_three_regions(struct vm_area_struct *vma, regions[1].start = ALIGN(first_gap.end, DAMON_MIN_REGION); regions[1].end = ALIGN(second_gap.start, DAMON_MIN_REGION); regions[2].start = ALIGN(second_gap.end, DAMON_MIN_REGION); - regions[2].end = ALIGN(last_vma->vm_end, DAMON_MIN_REGION); + regions[2].end = ALIGN(prev->vm_end, DAMON_MIN_REGION); return 0; } @@ -180,7 +181,7 @@ static int damon_va_three_regions(struct damon_target *t, return -EINVAL; mmap_read_lock(mm); - rc = __damon_va_three_regions(mm->mmap, regions); + rc = __damon_va_three_regions(mm, regions); mmap_read_unlock(mm); mmput(mm); diff --git a/mm/debug.c b/mm/debug.c index bef329bf28f0..0fd15ba70d16 100644 --- a/mm/debug.c +++ b/mm/debug.c @@ -139,13 +139,11 @@ EXPORT_SYMBOL(dump_page); void dump_vma(const struct vm_area_struct *vma) { - pr_emerg("vma %px start %px end %px\n" - "next %px prev %px mm %px\n" + pr_emerg("vma %px start %px end %px mm %px\n" "prot %lx anon_vma %px vm_ops %px\n" "pgoff %lx file %px private_data %px\n" "flags: %#lx(%pGv)\n", - vma, (void *)vma->vm_start, (void *)vma->vm_end, vma->vm_next, - vma->vm_prev, vma->vm_mm, + vma, (void *)vma->vm_start, (void *)vma->vm_end, vma->vm_mm, (unsigned long)pgprot_val(vma->vm_page_prot), vma->anon_vma, vma->vm_ops, vma->vm_pgoff, vma->vm_file, vma->vm_private_data, @@ -155,11 +153,11 @@ EXPORT_SYMBOL(dump_vma); void dump_mm(const struct mm_struct *mm) { - pr_emerg("mm %px mmap %px seqnum %llu task_size %lu\n" + pr_emerg("mm %px task_size %lu\n" #ifdef CONFIG_MMU "get_unmapped_area %px\n" #endif - "mmap_base %lu mmap_legacy_base %lu highest_vm_end %lu\n" + "mmap_base %lu mmap_legacy_base %lu\n" "pgd %px mm_users %d mm_count %d pgtables_bytes %lu map_count %d\n" "hiwater_rss %lx hiwater_vm %lx total_vm %lx locked_vm %lx\n" "pinned_vm %llx data_vm %lx exec_vm %lx stack_vm %lx\n" @@ -183,11 +181,11 @@ void dump_mm(const struct mm_struct *mm) "tlb_flush_pending %d\n" "def_flags: %#lx(%pGv)\n", - mm, mm->mmap, (long long) mm->vmacache_seqnum, mm->task_size, + mm, mm->task_size, #ifdef CONFIG_MMU mm->get_unmapped_area, #endif - mm->mmap_base, mm->mmap_legacy_base, mm->highest_vm_end, + mm->mmap_base, mm->mmap_legacy_base, mm->pgd, atomic_read(&mm->mm_users), atomic_read(&mm->mm_count), mm_pgtables_bytes(mm), diff --git a/mm/gup.c b/mm/gup.c index 251cb6a10bc0..89f25b31b6bf 100644 --- a/mm/gup.c +++ b/mm/gup.c @@ -1679,10 +1679,11 @@ int __mm_populate(unsigned long start, unsigned long len, int ignore_errors) if (!locked) { locked = 1; mmap_read_lock(mm); - vma = find_vma(mm, nstart); + vma = find_vma_intersection(mm, nstart, end); } else if (nstart >= vma->vm_end) - vma = vma->vm_next; - if (!vma || vma->vm_start >= end) + vma = find_vma_intersection(mm, vma->vm_end, end); + + if (!vma) break; /* * Set [nstart; nend) to intersection of desired address diff --git a/mm/huge_memory.c b/mm/huge_memory.c index 9558dbf3954c..27ffc26cc95d 100644 --- a/mm/huge_memory.c +++ b/mm/huge_memory.c @@ -2334,11 +2334,11 @@ void vma_adjust_trans_huge(struct vm_area_struct *vma, split_huge_pmd_if_needed(vma, end); /* - * If we're also updating the vma->vm_next->vm_start, + * If we're also updating the next vma vm_start, * check if we need to split it. */ if (adjust_next > 0) { - struct vm_area_struct *next = vma->vm_next; + struct vm_area_struct *next = find_vma(vma->vm_mm, vma->vm_end); unsigned long nstart = next->vm_start; nstart += adjust_next; split_huge_pmd_if_needed(next, nstart); @@ -2438,7 +2438,8 @@ static void __split_huge_page_tail(struct page *head, int tail, #ifdef CONFIG_64BIT (1L << PG_arch_2) | #endif - (1L << PG_dirty))); + (1L << PG_dirty) | + LRU_GEN_MASK | LRU_REFS_MASK)); /* ->mapping in first tail page is compound_mapcount */ VM_BUG_ON_PAGE(tail > 2 && page_tail->mapping != TAIL_MAPPING, diff --git a/mm/init-mm.c b/mm/init-mm.c index fbe7844d0912..c9327abb771c 100644 --- a/mm/init-mm.c +++ b/mm/init-mm.c @@ -1,6 +1,6 @@ // SPDX-License-Identifier: GPL-2.0 #include -#include +#include #include #include #include @@ -28,7 +28,7 @@ * and size this cpu_bitmask to NR_CPUS. */ struct mm_struct init_mm = { - .mm_rb = RB_ROOT, + .mm_mt = MTREE_INIT_EXT(mm_mt, MM_MT_FLAGS, init_mm.mmap_lock), .pgd = swapper_pg_dir, .mm_users = ATOMIC_INIT(2), .mm_count = ATOMIC_INIT(1), diff --git a/mm/internal.h b/mm/internal.h index 785409805ed7..1adf57d9713e 100644 --- a/mm/internal.h +++ b/mm/internal.h @@ -83,9 +83,11 @@ vm_fault_t do_swap_page(struct vm_fault *vmf); void folio_rotate_reclaimable(struct folio *folio); bool __folio_end_writeback(struct folio *folio); void deactivate_file_folio(struct folio *folio); +void folio_activate(struct folio *folio); -void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *start_vma, - unsigned long floor, unsigned long ceiling); +void free_pgtables(struct mmu_gather *tlb, struct maple_tree *mt, + struct vm_area_struct *start_vma, unsigned long floor, + unsigned long ceiling); void pmd_install(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte); struct zap_details; @@ -479,9 +481,6 @@ static inline bool is_data_mapping(vm_flags_t flags) } /* mm/util.c */ -void __vma_link_list(struct mm_struct *mm, struct vm_area_struct *vma, - struct vm_area_struct *prev); -void __vma_unlink_list(struct mm_struct *mm, struct vm_area_struct *vma); struct anon_vma *folio_anon_vma(struct folio *folio); #ifdef CONFIG_MMU diff --git a/mm/khugepaged.c b/mm/khugepaged.c index 70b7ac66411c..ef1f78d45e64 100644 --- a/mm/khugepaged.c +++ b/mm/khugepaged.c @@ -1389,7 +1389,7 @@ static void collapse_and_free_pmd(struct mm_struct *mm, struct vm_area_struct *v void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr) { unsigned long haddr = addr & HPAGE_PMD_MASK; - struct vm_area_struct *vma = find_vma(mm, haddr); + struct vm_area_struct *vma = vma_lookup(mm, haddr); struct page *hpage; pte_t *start_pte, *pte; pmd_t *pmd; @@ -2058,6 +2058,7 @@ static unsigned int khugepaged_scan_mm_slot(unsigned int pages, __releases(&khugepaged_mm_lock) __acquires(&khugepaged_mm_lock) { + struct vma_iterator vmi; struct mm_slot *mm_slot; struct mm_struct *mm; struct vm_area_struct *vma; @@ -2085,11 +2086,13 @@ static unsigned int khugepaged_scan_mm_slot(unsigned int pages, vma = NULL; if (unlikely(!mmap_read_trylock(mm))) goto breakouterloop_mmap_lock; - if (likely(!khugepaged_test_exit(mm))) - vma = find_vma(mm, khugepaged_scan.address); progress++; - for (; vma; vma = vma->vm_next) { + if (unlikely(khugepaged_test_exit(mm))) + goto breakouterloop; + + vma_iter_init(&vmi, mm, khugepaged_scan.address); + for_each_vma(vmi, vma) { unsigned long hstart, hend; cond_resched(); diff --git a/mm/ksm.c b/mm/ksm.c index 42ab153335a2..63b4b9d71597 100644 --- a/mm/ksm.c +++ b/mm/ksm.c @@ -981,11 +981,13 @@ static int unmerge_and_remove_all_rmap_items(void) struct mm_slot, mm_list); spin_unlock(&ksm_mmlist_lock); - for (mm_slot = ksm_scan.mm_slot; - mm_slot != &ksm_mm_head; mm_slot = ksm_scan.mm_slot) { + for (mm_slot = ksm_scan.mm_slot; mm_slot != &ksm_mm_head; + mm_slot = ksm_scan.mm_slot) { + VMA_ITERATOR(vmi, mm_slot->mm, 0); + mm = mm_slot->mm; mmap_read_lock(mm); - for (vma = mm->mmap; vma; vma = vma->vm_next) { + for_each_vma(vmi, vma) { if (ksm_test_exit(mm)) break; if (!(vma->vm_flags & VM_MERGEABLE) || !vma->anon_vma) @@ -2232,6 +2234,7 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page) struct mm_slot *slot; struct vm_area_struct *vma; struct rmap_item *rmap_item; + struct vma_iterator vmi; int nid; if (list_empty(&ksm_mm_head.mm_list)) @@ -2290,13 +2293,13 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page) } mm = slot->mm; + vma_iter_init(&vmi, mm, ksm_scan.address); + mmap_read_lock(mm); if (ksm_test_exit(mm)) - vma = NULL; - else - vma = find_vma(mm, ksm_scan.address); + goto no_vmas; - for (; vma; vma = vma->vm_next) { + for_each_vma(vmi, vma) { if (!(vma->vm_flags & VM_MERGEABLE)) continue; if (ksm_scan.address < vma->vm_start) @@ -2334,6 +2337,7 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page) } if (ksm_test_exit(mm)) { +no_vmas: ksm_scan.address = 0; ksm_scan.rmap_list = &slot->rmap_list; } diff --git a/mm/madvise.c b/mm/madvise.c index 98ed17a4471a..a5728a7975a2 100644 --- a/mm/madvise.c +++ b/mm/madvise.c @@ -1251,7 +1251,7 @@ int madvise_walk_vmas(struct mm_struct *mm, unsigned long start, if (start >= end) break; if (prev) - vma = prev->vm_next; + vma = find_vma(mm, prev->vm_end); else /* madvise_remove dropped mmap_lock */ vma = find_vma(mm, start); } diff --git a/mm/memcontrol.c b/mm/memcontrol.c index 6a95ea7c5ee7..05b6af366ca4 100644 --- a/mm/memcontrol.c +++ b/mm/memcontrol.c @@ -2789,6 +2789,7 @@ static void commit_charge(struct folio *folio, struct mem_cgroup *memcg) * - LRU isolation * - lock_page_memcg() * - exclusive reference + * - mem_cgroup_trylock_pages() */ folio->memcg_data = (unsigned long)memcg; } @@ -5170,6 +5171,7 @@ static void __mem_cgroup_free(struct mem_cgroup *memcg) static void mem_cgroup_free(struct mem_cgroup *memcg) { + lru_gen_exit_memcg(memcg); memcg_wb_domain_exit(memcg); __mem_cgroup_free(memcg); } @@ -5228,6 +5230,7 @@ static struct mem_cgroup *mem_cgroup_alloc(void) memcg->deferred_split_queue.split_queue_len = 0; #endif idr_replace(&mem_cgroup_idr, memcg, memcg->id.id); + lru_gen_init_memcg(memcg); return memcg; fail: mem_cgroup_id_remove(memcg); @@ -5871,7 +5874,7 @@ static unsigned long mem_cgroup_count_precharge(struct mm_struct *mm) unsigned long precharge; mmap_read_lock(mm); - walk_page_range(mm, 0, mm->highest_vm_end, &precharge_walk_ops, NULL); + walk_page_range(mm, 0, ULONG_MAX, &precharge_walk_ops, NULL); mmap_read_unlock(mm); precharge = mc.precharge; @@ -6169,9 +6172,7 @@ static void mem_cgroup_move_charge(void) * When we have consumed all precharges and failed in doing * additional charge, the page walk just aborts. */ - walk_page_range(mc.mm, 0, mc.mm->highest_vm_end, &charge_walk_ops, - NULL); - + walk_page_range(mc.mm, 0, ULONG_MAX, &charge_walk_ops, NULL); mmap_read_unlock(mc.mm); atomic_dec(&mc.from->moving_account); } @@ -6196,6 +6197,30 @@ static void mem_cgroup_move_task(void) } #endif +#ifdef CONFIG_LRU_GEN +static void mem_cgroup_attach(struct cgroup_taskset *tset) +{ + struct task_struct *task; + struct cgroup_subsys_state *css; + + /* find the first leader if there is any */ + cgroup_taskset_for_each_leader(task, css, tset) + break; + + if (!task) + return; + + task_lock(task); + if (task->mm && READ_ONCE(task->mm->owner) == task) + lru_gen_migrate_mm(task->mm); + task_unlock(task); +} +#else +static void mem_cgroup_attach(struct cgroup_taskset *tset) +{ +} +#endif /* CONFIG_LRU_GEN */ + static int seq_puts_memcg_tunable(struct seq_file *m, unsigned long value) { if (value == PAGE_COUNTER_MAX) @@ -6601,6 +6626,7 @@ struct cgroup_subsys memory_cgrp_subsys = { .css_reset = mem_cgroup_css_reset, .css_rstat_flush = mem_cgroup_css_rstat_flush, .can_attach = mem_cgroup_can_attach, + .attach = mem_cgroup_attach, .cancel_attach = mem_cgroup_cancel_attach, .post_attach = mem_cgroup_move_task, .dfl_cftypes = memory_files, diff --git a/mm/memory.c b/mm/memory.c index de0dbe09b013..a2c19d411ad3 100644 --- a/mm/memory.c +++ b/mm/memory.c @@ -125,18 +125,6 @@ int randomize_va_space __read_mostly = 2; #endif -#ifndef arch_faults_on_old_pte -static inline bool arch_faults_on_old_pte(void) -{ - /* - * Those arches which don't have hw access flag feature need to - * implement their own helper. By default, "true" means pagefault - * will be hit on old pte. - */ - return true; -} -#endif - #ifndef arch_wants_old_prefaulted_pte static inline bool arch_wants_old_prefaulted_pte(void) { @@ -402,12 +390,21 @@ void free_pgd_range(struct mmu_gather *tlb, } while (pgd++, addr = next, addr != end); } -void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *vma, - unsigned long floor, unsigned long ceiling) +void free_pgtables(struct mmu_gather *tlb, struct maple_tree *mt, + struct vm_area_struct *vma, unsigned long floor, + unsigned long ceiling) { - while (vma) { - struct vm_area_struct *next = vma->vm_next; + MA_STATE(mas, mt, vma->vm_end, vma->vm_end); + + do { unsigned long addr = vma->vm_start; + struct vm_area_struct *next; + + /* + * Note: USER_PGTABLES_CEILING may be passed as ceiling and may + * be 0. This will underflow and is okay. + */ + next = mas_find(&mas, ceiling - 1); /* * Hide vma from rmap and truncate_pagecache before freeing @@ -426,7 +423,7 @@ void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *vma, while (next && next->vm_start <= vma->vm_end + PMD_SIZE && !is_vm_hugetlb_page(next)) { vma = next; - next = vma->vm_next; + next = mas_find(&mas, ceiling - 1); unlink_anon_vmas(vma); unlink_file_vma(vma); } @@ -434,7 +431,7 @@ void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *vma, floor, next ? next->vm_start : ceiling); } vma = next; - } + } while (vma); } void pmd_install(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte) @@ -1715,7 +1712,7 @@ static void unmap_single_vma(struct mmu_gather *tlb, * ensure that any thus-far unmapped pages are flushed before unmap_vmas() * drops the lock and schedules. */ -void unmap_vmas(struct mmu_gather *tlb, +void unmap_vmas(struct mmu_gather *tlb, struct maple_tree *mt, struct vm_area_struct *vma, unsigned long start_addr, unsigned long end_addr) { @@ -1725,12 +1722,14 @@ void unmap_vmas(struct mmu_gather *tlb, /* Careful - we need to zap private pages too! */ .even_cows = true, }; + MA_STATE(mas, mt, vma->vm_end, vma->vm_end); mmu_notifier_range_init(&range, MMU_NOTIFY_UNMAP, 0, vma, vma->vm_mm, start_addr, end_addr); mmu_notifier_invalidate_range_start(&range); - for ( ; vma && vma->vm_start < end_addr; vma = vma->vm_next) + do { unmap_single_vma(tlb, vma, start_addr, end_addr, &details); + } while ((vma = mas_find(&mas, end_addr - 1)) != NULL); mmu_notifier_invalidate_range_end(&range); } @@ -1745,8 +1744,11 @@ void unmap_vmas(struct mmu_gather *tlb, void zap_page_range(struct vm_area_struct *vma, unsigned long start, unsigned long size) { + struct maple_tree *mt = &vma->vm_mm->mm_mt; + unsigned long end = start + size; struct mmu_notifier_range range; struct mmu_gather tlb; + MA_STATE(mas, mt, vma->vm_end, vma->vm_end); lru_add_drain(); mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, vma, vma->vm_mm, @@ -1754,8 +1756,9 @@ void zap_page_range(struct vm_area_struct *vma, unsigned long start, tlb_gather_mmu(&tlb, vma->vm_mm); update_hiwater_rss(vma->vm_mm); mmu_notifier_invalidate_range_start(&range); - for ( ; vma && vma->vm_start < range.end; vma = vma->vm_next) + do { unmap_single_vma(&tlb, vma, start, range.end, NULL); + } while ((vma = mas_find(&mas, end - 1)) != NULL); mmu_notifier_invalidate_range_end(&range); tlb_finish_mmu(&tlb); } @@ -2872,7 +2875,7 @@ static inline bool __wp_page_copy_user(struct page *dst, struct page *src, * On architectures with software "accessed" bits, we would * take a double page fault, so mark it accessed here. */ - if (arch_faults_on_old_pte() && !pte_young(vmf->orig_pte)) { + if (!arch_has_hw_pte_young() && !pte_young(vmf->orig_pte)) { pte_t entry; vmf->pte = pte_offset_map_lock(mm, vmf->pmd, addr, &vmf->ptl); @@ -5122,6 +5125,27 @@ static inline void mm_account_fault(struct pt_regs *regs, perf_sw_event(PERF_COUNT_SW_PAGE_FAULTS_MIN, 1, regs, address); } +#ifdef CONFIG_LRU_GEN +static void lru_gen_enter_fault(struct vm_area_struct *vma) +{ + /* the LRU algorithm doesn't apply to sequential or random reads */ + current->in_lru_fault = !(vma->vm_flags & (VM_SEQ_READ | VM_RAND_READ)); +} + +static void lru_gen_exit_fault(void) +{ + current->in_lru_fault = false; +} +#else +static void lru_gen_enter_fault(struct vm_area_struct *vma) +{ +} + +static void lru_gen_exit_fault(void) +{ +} +#endif /* CONFIG_LRU_GEN */ + /* * By the time we get here, we already hold the mm semaphore * @@ -5153,11 +5177,15 @@ vm_fault_t handle_mm_fault(struct vm_area_struct *vma, unsigned long address, if (flags & FAULT_FLAG_USER) mem_cgroup_enter_user_fault(); + lru_gen_enter_fault(vma); + if (unlikely(is_vm_hugetlb_page(vma))) ret = hugetlb_fault(vma->vm_mm, vma, address, flags); else ret = __handle_mm_fault(vma, address, flags); + lru_gen_exit_fault(); + if (flags & FAULT_FLAG_USER) { mem_cgroup_exit_user_fault(); /* diff --git a/mm/mempolicy.c b/mm/mempolicy.c index b73d3248d976..7e877c24afb5 100644 --- a/mm/mempolicy.c +++ b/mm/mempolicy.c @@ -381,9 +381,10 @@ void mpol_rebind_task(struct task_struct *tsk, const nodemask_t *new) void mpol_rebind_mm(struct mm_struct *mm, nodemask_t *new) { struct vm_area_struct *vma; + VMA_ITERATOR(vmi, mm, 0); mmap_write_lock(mm); - for (vma = mm->mmap; vma; vma = vma->vm_next) + for_each_vma(vmi, vma) mpol_rebind_policy(vma->vm_policy, new); mmap_write_unlock(mm); } @@ -654,7 +655,7 @@ static unsigned long change_prot_numa(struct vm_area_struct *vma, static int queue_pages_test_walk(unsigned long start, unsigned long end, struct mm_walk *walk) { - struct vm_area_struct *vma = walk->vma; + struct vm_area_struct *next, *vma = walk->vma; struct queue_pages *qp = walk->private; unsigned long endvma = vma->vm_end; unsigned long flags = qp->flags; @@ -669,9 +670,10 @@ static int queue_pages_test_walk(unsigned long start, unsigned long end, /* hole at head side of range */ return -EFAULT; } + next = find_vma(vma->vm_mm, vma->vm_end); if (!(flags & MPOL_MF_DISCONTIG_OK) && ((vma->vm_end < qp->end) && - (!vma->vm_next || vma->vm_end < vma->vm_next->vm_start))) + (!next || vma->vm_end < next->vm_start))) /* hole at middle or tail of range */ return -EFAULT; @@ -785,26 +787,29 @@ static int vma_replace_policy(struct vm_area_struct *vma, static int mbind_range(struct mm_struct *mm, unsigned long start, unsigned long end, struct mempolicy *new_pol) { + MA_STATE(mas, &mm->mm_mt, start, start); struct vm_area_struct *prev; struct vm_area_struct *vma; int err = 0; pgoff_t pgoff; - unsigned long vmstart; - unsigned long vmend; - vma = find_vma(mm, start); - VM_BUG_ON(!vma); + prev = mas_prev(&mas, 0); + if (unlikely(!prev)) + mas_set(&mas, start); + + vma = mas_find(&mas, end - 1); + if (WARN_ON(!vma)) + return 0; - prev = vma->vm_prev; if (start > vma->vm_start) prev = vma; - for (; vma && vma->vm_start < end; prev = vma, vma = vma->vm_next) { - vmstart = max(start, vma->vm_start); - vmend = min(end, vma->vm_end); + for (; vma; vma = mas_next(&mas, end - 1)) { + unsigned long vmstart = max(start, vma->vm_start); + unsigned long vmend = min(end, vma->vm_end); if (mpol_equal(vma_policy(vma), new_pol)) - continue; + goto next; pgoff = vma->vm_pgoff + ((vmstart - vma->vm_start) >> PAGE_SHIFT); @@ -813,6 +818,8 @@ static int mbind_range(struct mm_struct *mm, unsigned long start, new_pol, vma->vm_userfaultfd_ctx, anon_vma_name(vma)); if (prev) { + /* vma_merge() invalidated the mas */ + mas_pause(&mas); vma = prev; goto replace; } @@ -820,19 +827,25 @@ static int mbind_range(struct mm_struct *mm, unsigned long start, err = split_vma(vma->vm_mm, vma, vmstart, 1); if (err) goto out; + /* split_vma() invalidated the mas */ + mas_pause(&mas); } if (vma->vm_end != vmend) { err = split_vma(vma->vm_mm, vma, vmend, 0); if (err) goto out; + /* split_vma() invalidated the mas */ + mas_pause(&mas); } - replace: +replace: err = vma_replace_policy(vma, new_pol); if (err) goto out; +next: + prev = vma; } - out: +out: return err; } @@ -1047,6 +1060,7 @@ static int migrate_to_node(struct mm_struct *mm, int source, int dest, int flags) { nodemask_t nmask; + struct vm_area_struct *vma; LIST_HEAD(pagelist); int err = 0; struct migration_target_control mtc = { @@ -1062,8 +1076,9 @@ static int migrate_to_node(struct mm_struct *mm, int source, int dest, * need migration. Between passing in the full user address * space range and MPOL_MF_DISCONTIG_OK, this call can not fail. */ + vma = find_vma(mm, 0); VM_BUG_ON(!(flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL))); - queue_pages_range(mm, mm->mmap->vm_start, mm->task_size, &nmask, + queue_pages_range(mm, vma->vm_start, mm->task_size, &nmask, flags | MPOL_MF_DISCONTIG_OK, &pagelist); if (!list_empty(&pagelist)) { @@ -1193,14 +1208,13 @@ static struct page *new_page(struct page *page, unsigned long start) struct folio *dst, *src = page_folio(page); struct vm_area_struct *vma; unsigned long address; + VMA_ITERATOR(vmi, current->mm, start); gfp_t gfp = GFP_HIGHUSER_MOVABLE | __GFP_RETRY_MAYFAIL; - vma = find_vma(current->mm, start); - while (vma) { + for_each_vma(vmi, vma) { address = page_address_in_vma(page, vma); if (address != -EFAULT) break; - vma = vma->vm_next; } if (folio_test_hugetlb(src)) @@ -1478,6 +1492,7 @@ SYSCALL_DEFINE4(set_mempolicy_home_node, unsigned long, start, unsigned long, le unsigned long vmend; unsigned long end; int err = -ENOENT; + VMA_ITERATOR(vmi, mm, start); start = untagged_addr(start); if (start & ~PAGE_MASK) @@ -1503,9 +1518,7 @@ SYSCALL_DEFINE4(set_mempolicy_home_node, unsigned long, start, unsigned long, le if (end == start) return 0; mmap_write_lock(mm); - vma = find_vma(mm, start); - for (; vma && vma->vm_start < end; vma = vma->vm_next) { - + for_each_vma_range(vmi, vma, end) { vmstart = max(start, vma->vm_start); vmend = min(end, vma->vm_end); new = mpol_dup(vma_policy(vma)); diff --git a/mm/mlock.c b/mm/mlock.c index b14e929084cc..43d19a1f28eb 100644 --- a/mm/mlock.c +++ b/mm/mlock.c @@ -471,6 +471,7 @@ static int apply_vma_lock_flags(unsigned long start, size_t len, unsigned long nstart, end, tmp; struct vm_area_struct *vma, *prev; int error; + MA_STATE(mas, ¤t->mm->mm_mt, start, start); VM_BUG_ON(offset_in_page(start)); VM_BUG_ON(len != PAGE_ALIGN(len)); @@ -479,13 +480,14 @@ static int apply_vma_lock_flags(unsigned long start, size_t len, return -EINVAL; if (end == start) return 0; - vma = find_vma(current->mm, start); - if (!vma || vma->vm_start > start) + vma = mas_walk(&mas); + if (!vma) return -ENOMEM; - prev = vma->vm_prev; if (start > vma->vm_start) prev = vma; + else + prev = mas_prev(&mas, 0); for (nstart = start ; ; ) { vm_flags_t newflags = vma->vm_flags & VM_LOCKED_CLEAR_MASK; @@ -505,7 +507,7 @@ static int apply_vma_lock_flags(unsigned long start, size_t len, if (nstart >= end) break; - vma = prev->vm_next; + vma = find_vma(prev->vm_mm, prev->vm_end); if (!vma || vma->vm_start != nstart) { error = -ENOMEM; break; @@ -526,24 +528,23 @@ static unsigned long count_mm_mlocked_page_nr(struct mm_struct *mm, { struct vm_area_struct *vma; unsigned long count = 0; + unsigned long end; + VMA_ITERATOR(vmi, mm, start); if (mm == NULL) mm = current->mm; - vma = find_vma(mm, start); - if (vma == NULL) - return 0; - - for (; vma ; vma = vma->vm_next) { - if (start >= vma->vm_end) - continue; - if (start + len <= vma->vm_start) - break; + /* Don't overflow past ULONG_MAX */ + if (unlikely(ULONG_MAX - len < start)) + end = ULONG_MAX; + else + end = start + len; + for_each_vma_range(vmi, vma, end) { if (vma->vm_flags & VM_LOCKED) { if (start > vma->vm_start) count -= (start - vma->vm_start); - if (start + len < vma->vm_end) { - count += start + len - vma->vm_start; + if (end < vma->vm_end) { + count += end - vma->vm_start; break; } count += vma->vm_end - vma->vm_start; @@ -659,6 +660,7 @@ SYSCALL_DEFINE2(munlock, unsigned long, start, size_t, len) */ static int apply_mlockall_flags(int flags) { + MA_STATE(mas, ¤t->mm->mm_mt, 0, 0); struct vm_area_struct *vma, *prev = NULL; vm_flags_t to_add = 0; @@ -679,7 +681,7 @@ static int apply_mlockall_flags(int flags) to_add |= VM_LOCKONFAULT; } - for (vma = current->mm->mmap; vma ; vma = prev->vm_next) { + mas_for_each(&mas, vma, ULONG_MAX) { vm_flags_t newflags; newflags = vma->vm_flags & VM_LOCKED_CLEAR_MASK; @@ -687,6 +689,7 @@ static int apply_mlockall_flags(int flags) /* Ignore errors */ mlock_fixup(vma, &prev, vma->vm_start, vma->vm_end, newflags); + mas_pause(&mas); cond_resched(); } out: diff --git a/mm/mm_init.c b/mm/mm_init.c index 9ddaf0e1b0ab..0d7b2bd2454a 100644 --- a/mm/mm_init.c +++ b/mm/mm_init.c @@ -65,14 +65,16 @@ void __init mminit_verify_pageflags_layout(void) shift = 8 * sizeof(unsigned long); width = shift - SECTIONS_WIDTH - NODES_WIDTH - ZONES_WIDTH - - LAST_CPUPID_SHIFT - KASAN_TAG_WIDTH; + - LAST_CPUPID_SHIFT - KASAN_TAG_WIDTH - LRU_GEN_WIDTH - LRU_REFS_WIDTH; mminit_dprintk(MMINIT_TRACE, "pageflags_layout_widths", - "Section %d Node %d Zone %d Lastcpupid %d Kasantag %d Flags %d\n", + "Section %d Node %d Zone %d Lastcpupid %d Kasantag %d Gen %d Tier %d Flags %d\n", SECTIONS_WIDTH, NODES_WIDTH, ZONES_WIDTH, LAST_CPUPID_WIDTH, KASAN_TAG_WIDTH, + LRU_GEN_WIDTH, + LRU_REFS_WIDTH, NR_PAGEFLAGS); mminit_dprintk(MMINIT_TRACE, "pageflags_layout_shifts", "Section %d Node %d Zone %d Lastcpupid %d Kasantag %d\n", diff --git a/mm/mmap.c b/mm/mmap.c index 36c08e2c78da..1b7867aa2737 100644 --- a/mm/mmap.c +++ b/mm/mmap.c @@ -14,7 +14,6 @@ #include #include #include -#include #include #include #include @@ -39,7 +38,6 @@ #include #include #include -#include #include #include #include @@ -77,9 +75,10 @@ int mmap_rnd_compat_bits __read_mostly = CONFIG_ARCH_MMAP_RND_COMPAT_BITS; static bool ignore_rlimit_data; core_param(ignore_rlimit_data, ignore_rlimit_data, bool, 0644); -static void unmap_region(struct mm_struct *mm, +static void unmap_region(struct mm_struct *mm, struct maple_tree *mt, struct vm_area_struct *vma, struct vm_area_struct *prev, - unsigned long start, unsigned long end); + struct vm_area_struct *next, unsigned long start, + unsigned long end); static pgprot_t vm_pgprot_modify(pgprot_t oldprot, unsigned long vm_flags) { @@ -132,12 +131,10 @@ void unlink_file_vma(struct vm_area_struct *vma) } /* - * Close a vm structure and free it, returning the next. + * Close a vm structure and free it. */ -static struct vm_area_struct *remove_vma(struct vm_area_struct *vma) +static void remove_vma(struct vm_area_struct *vma) { - struct vm_area_struct *next = vma->vm_next; - might_sleep(); if (vma->vm_ops && vma->vm_ops->close) vma->vm_ops->close(vma); @@ -145,20 +142,41 @@ static struct vm_area_struct *remove_vma(struct vm_area_struct *vma) fput(vma->vm_file); mpol_put(vma_policy(vma)); vm_area_free(vma); - return next; } -static int do_brk_flags(unsigned long addr, unsigned long request, unsigned long flags, - struct list_head *uf); +/* + * check_brk_limits() - Use platform specific check of range & verify mlock + * limits. + * @addr: The address to check + * @len: The size of increase. + * + * Return: 0 on success. + */ +static int check_brk_limits(unsigned long addr, unsigned long len) +{ + unsigned long mapped_addr; + + mapped_addr = get_unmapped_area(NULL, addr, len, 0, MAP_FIXED); + if (IS_ERR_VALUE(mapped_addr)) + return mapped_addr; + + return mlock_future_check(current->mm, current->mm->def_flags, len); +} +static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, + unsigned long newbrk, unsigned long oldbrk, + struct list_head *uf); +static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *brkvma, + unsigned long addr, unsigned long request, unsigned long flags); SYSCALL_DEFINE1(brk, unsigned long, brk) { unsigned long newbrk, oldbrk, origbrk; struct mm_struct *mm = current->mm; - struct vm_area_struct *next; + struct vm_area_struct *brkvma, *next = NULL; unsigned long min_brk; bool populate; bool downgraded = false; LIST_HEAD(uf); + MA_STATE(mas, &mm->mm_mt, 0, 0); if (mmap_write_lock_killable(mm)) return -EINTR; @@ -200,35 +218,51 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) /* * Always allow shrinking brk. - * __do_munmap() may downgrade mmap_lock to read. + * do_brk_munmap() may downgrade mmap_lock to read. */ if (brk <= mm->brk) { int ret; + /* Search one past newbrk */ + mas_set(&mas, newbrk); + brkvma = mas_find(&mas, oldbrk); + BUG_ON(brkvma == NULL); + if (brkvma->vm_start >= oldbrk) + goto out; /* mapping intersects with an existing non-brk vma. */ /* - * mm->brk must to be protected by write mmap_lock so update it - * before downgrading mmap_lock. When __do_munmap() fails, - * mm->brk will be restored from origbrk. + * mm->brk must be protected by write mmap_lock. + * do_brk_munmap() may downgrade the lock, so update it + * before calling do_brk_munmap(). */ mm->brk = brk; - ret = __do_munmap(mm, newbrk, oldbrk-newbrk, &uf, true); - if (ret < 0) { - mm->brk = origbrk; - goto out; - } else if (ret == 1) { + ret = do_brk_munmap(&mas, brkvma, newbrk, oldbrk, &uf); + if (ret == 1) { downgraded = true; - } - goto success; + goto success; + } else if (!ret) + goto success; + + mm->brk = origbrk; + goto out; } - /* Check against existing mmap mappings. */ - next = find_vma(mm, oldbrk); + if (check_brk_limits(oldbrk, newbrk - oldbrk)) + goto out; + + /* + * Only check if the next VMA is within the stack_guard_gap of the + * expansion area + */ + mas_set(&mas, oldbrk); + next = mas_find(&mas, newbrk - 1 + PAGE_SIZE + stack_guard_gap); if (next && newbrk + PAGE_SIZE > vm_start_gap(next)) goto out; + brkvma = mas_prev(&mas, mm->start_brk); /* Ok, looks good - let it rip. */ - if (do_brk_flags(oldbrk, newbrk-oldbrk, 0, &uf) < 0) + if (do_brk_flags(&mas, brkvma, oldbrk, newbrk - oldbrk, 0) < 0) goto out; + mm->brk = brk; success: @@ -247,104 +281,45 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) return origbrk; } -static inline unsigned long vma_compute_gap(struct vm_area_struct *vma) -{ - unsigned long gap, prev_end; - - /* - * Note: in the rare case of a VM_GROWSDOWN above a VM_GROWSUP, we - * allow two stack_guard_gaps between them here, and when choosing - * an unmapped area; whereas when expanding we only require one. - * That's a little inconsistent, but keeps the code here simpler. - */ - gap = vm_start_gap(vma); - if (vma->vm_prev) { - prev_end = vm_end_gap(vma->vm_prev); - if (gap > prev_end) - gap -= prev_end; - else - gap = 0; - } - return gap; -} - -#ifdef CONFIG_DEBUG_VM_RB -static unsigned long vma_compute_subtree_gap(struct vm_area_struct *vma) -{ - unsigned long max = vma_compute_gap(vma), subtree_gap; - if (vma->vm_rb.rb_left) { - subtree_gap = rb_entry(vma->vm_rb.rb_left, - struct vm_area_struct, vm_rb)->rb_subtree_gap; - if (subtree_gap > max) - max = subtree_gap; - } - if (vma->vm_rb.rb_right) { - subtree_gap = rb_entry(vma->vm_rb.rb_right, - struct vm_area_struct, vm_rb)->rb_subtree_gap; - if (subtree_gap > max) - max = subtree_gap; - } - return max; -} - -static int browse_rb(struct mm_struct *mm) -{ - struct rb_root *root = &mm->mm_rb; - int i = 0, j, bug = 0; - struct rb_node *nd, *pn = NULL; - unsigned long prev = 0, pend = 0; - - for (nd = rb_first(root); nd; nd = rb_next(nd)) { - struct vm_area_struct *vma; - vma = rb_entry(nd, struct vm_area_struct, vm_rb); - if (vma->vm_start < prev) { - pr_emerg("vm_start %lx < prev %lx\n", - vma->vm_start, prev); - bug = 1; - } - if (vma->vm_start < pend) { - pr_emerg("vm_start %lx < pend %lx\n", - vma->vm_start, pend); - bug = 1; - } - if (vma->vm_start > vma->vm_end) { - pr_emerg("vm_start %lx > vm_end %lx\n", - vma->vm_start, vma->vm_end); - bug = 1; - } - spin_lock(&mm->page_table_lock); - if (vma->rb_subtree_gap != vma_compute_subtree_gap(vma)) { - pr_emerg("free gap %lx, correct %lx\n", - vma->rb_subtree_gap, - vma_compute_subtree_gap(vma)); - bug = 1; +#if defined(CONFIG_DEBUG_VM_MAPLE_TREE) +extern void mt_validate(struct maple_tree *mt); +extern void mt_dump(const struct maple_tree *mt); + +/* Validate the maple tree */ +static void validate_mm_mt(struct mm_struct *mm) +{ + struct maple_tree *mt = &mm->mm_mt; + struct vm_area_struct *vma_mt; + + MA_STATE(mas, mt, 0, 0); + + mt_validate(&mm->mm_mt); + mas_for_each(&mas, vma_mt, ULONG_MAX) { + if ((vma_mt->vm_start != mas.index) || + (vma_mt->vm_end - 1 != mas.last)) { + pr_emerg("issue in %s\n", current->comm); + dump_stack(); + dump_vma(vma_mt); + pr_emerg("mt piv: %p %lu - %lu\n", vma_mt, + mas.index, mas.last); + pr_emerg("mt vma: %p %lu - %lu\n", vma_mt, + vma_mt->vm_start, vma_mt->vm_end); + + mt_dump(mas.tree); + if (vma_mt->vm_end != mas.last + 1) { + pr_err("vma: %p vma_mt %lu-%lu\tmt %lu-%lu\n", + mm, vma_mt->vm_start, vma_mt->vm_end, + mas.index, mas.last); + mt_dump(mas.tree); + } + VM_BUG_ON_MM(vma_mt->vm_end != mas.last + 1, mm); + if (vma_mt->vm_start != mas.index) { + pr_err("vma: %p vma_mt %p %lu - %lu doesn't match\n", + mm, vma_mt, vma_mt->vm_start, vma_mt->vm_end); + mt_dump(mas.tree); + } + VM_BUG_ON_MM(vma_mt->vm_start != mas.index, mm); } - spin_unlock(&mm->page_table_lock); - i++; - pn = nd; - prev = vma->vm_start; - pend = vma->vm_end; - } - j = 0; - for (nd = pn; nd; nd = rb_prev(nd)) - j++; - if (i != j) { - pr_emerg("backwards %d, forwards %d\n", j, i); - bug = 1; - } - return bug ? -1 : i; -} - -static void validate_mm_rb(struct rb_root *root, struct vm_area_struct *ignore) -{ - struct rb_node *nd; - - for (nd = rb_first(root); nd; nd = rb_next(nd)) { - struct vm_area_struct *vma; - vma = rb_entry(nd, struct vm_area_struct, vm_rb); - VM_BUG_ON_VMA(vma != ignore && - vma->rb_subtree_gap != vma_compute_subtree_gap(vma), - vma); } } @@ -352,10 +327,13 @@ static void validate_mm(struct mm_struct *mm) { int bug = 0; int i = 0; - unsigned long highest_address = 0; - struct vm_area_struct *vma = mm->mmap; + struct vm_area_struct *vma; + MA_STATE(mas, &mm->mm_mt, 0, 0); - while (vma) { + validate_mm_mt(mm); + + mas_for_each(&mas, vma, ULONG_MAX) { +#ifdef CONFIG_DEBUG_VM_RB struct anon_vma *anon_vma = vma->anon_vma; struct anon_vma_chain *avc; @@ -365,93 +343,20 @@ static void validate_mm(struct mm_struct *mm) anon_vma_interval_tree_verify(avc); anon_vma_unlock_read(anon_vma); } - - highest_address = vm_end_gap(vma); - vma = vma->vm_next; +#endif i++; } if (i != mm->map_count) { - pr_emerg("map_count %d vm_next %d\n", mm->map_count, i); - bug = 1; - } - if (highest_address != mm->highest_vm_end) { - pr_emerg("mm->highest_vm_end %lx, found %lx\n", - mm->highest_vm_end, highest_address); - bug = 1; - } - i = browse_rb(mm); - if (i != mm->map_count) { - if (i != -1) - pr_emerg("map_count %d rb %d\n", mm->map_count, i); + pr_emerg("map_count %d mas_for_each %d\n", mm->map_count, i); bug = 1; } VM_BUG_ON_MM(bug, mm); } -#else -#define validate_mm_rb(root, ignore) do { } while (0) -#define validate_mm(mm) do { } while (0) -#endif - -RB_DECLARE_CALLBACKS_MAX(static, vma_gap_callbacks, - struct vm_area_struct, vm_rb, - unsigned long, rb_subtree_gap, vma_compute_gap) - -/* - * Update augmented rbtree rb_subtree_gap values after vma->vm_start or - * vma->vm_prev->vm_end values changed, without modifying the vma's position - * in the rbtree. - */ -static void vma_gap_update(struct vm_area_struct *vma) -{ - /* - * As it turns out, RB_DECLARE_CALLBACKS_MAX() already created - * a callback function that does exactly what we want. - */ - vma_gap_callbacks_propagate(&vma->vm_rb, NULL); -} - -static inline void vma_rb_insert(struct vm_area_struct *vma, - struct rb_root *root) -{ - /* All rb_subtree_gap values must be consistent prior to insertion */ - validate_mm_rb(root, NULL); - - rb_insert_augmented(&vma->vm_rb, root, &vma_gap_callbacks); -} - -static void __vma_rb_erase(struct vm_area_struct *vma, struct rb_root *root) -{ - /* - * Note rb_erase_augmented is a fairly large inline function, - * so make sure we instantiate it only once with our desired - * augmented rbtree callbacks. - */ - rb_erase_augmented(&vma->vm_rb, root, &vma_gap_callbacks); -} - -static __always_inline void vma_rb_erase_ignore(struct vm_area_struct *vma, - struct rb_root *root, - struct vm_area_struct *ignore) -{ - /* - * All rb_subtree_gap values must be consistent prior to erase, - * with the possible exception of - * - * a. the "next" vma being erased if next->vm_start was reduced in - * __vma_adjust() -> __vma_unlink() - * b. the vma being erased in detach_vmas_to_be_unmapped() -> - * vma_rb_erase() - */ - validate_mm_rb(root, ignore); - - __vma_rb_erase(vma, root); -} -static __always_inline void vma_rb_erase(struct vm_area_struct *vma, - struct rb_root *root) -{ - vma_rb_erase_ignore(vma, root, vma); -} +#else /* !CONFIG_DEBUG_VM_MAPLE_TREE */ +#define validate_mm_mt(root) do { } while (0) +#define validate_mm(mm) do { } while (0) +#endif /* CONFIG_DEBUG_VM_MAPLE_TREE */ /* * vma has some anon_vma assigned, and is already inserted on that @@ -485,208 +390,220 @@ anon_vma_interval_tree_post_update_vma(struct vm_area_struct *vma) anon_vma_interval_tree_insert(avc, &avc->anon_vma->rb_root); } -static int find_vma_links(struct mm_struct *mm, unsigned long addr, - unsigned long end, struct vm_area_struct **pprev, - struct rb_node ***rb_link, struct rb_node **rb_parent) +static unsigned long count_vma_pages_range(struct mm_struct *mm, + unsigned long addr, unsigned long end) { - struct rb_node **__rb_link, *__rb_parent, *rb_prev; + VMA_ITERATOR(vmi, mm, addr); + struct vm_area_struct *vma; + unsigned long nr_pages = 0; - mmap_assert_locked(mm); - __rb_link = &mm->mm_rb.rb_node; - rb_prev = __rb_parent = NULL; + for_each_vma_range(vmi, vma, end) { + unsigned long vm_start = max(addr, vma->vm_start); + unsigned long vm_end = min(end, vma->vm_end); - while (*__rb_link) { - struct vm_area_struct *vma_tmp; + nr_pages += PHYS_PFN(vm_end - vm_start); + } - __rb_parent = *__rb_link; - vma_tmp = rb_entry(__rb_parent, struct vm_area_struct, vm_rb); + return nr_pages; +} - if (vma_tmp->vm_end > addr) { - /* Fail if an existing vma overlaps the area */ - if (vma_tmp->vm_start < end) - return -ENOMEM; - __rb_link = &__rb_parent->rb_left; - } else { - rb_prev = __rb_parent; - __rb_link = &__rb_parent->rb_right; - } - } +static void __vma_link_file(struct vm_area_struct *vma, + struct address_space *mapping) +{ + if (vma->vm_flags & VM_SHARED) + mapping_allow_writable(mapping); - *pprev = NULL; - if (rb_prev) - *pprev = rb_entry(rb_prev, struct vm_area_struct, vm_rb); - *rb_link = __rb_link; - *rb_parent = __rb_parent; - return 0; + flush_dcache_mmap_lock(mapping); + vma_interval_tree_insert(vma, &mapping->i_mmap); + flush_dcache_mmap_unlock(mapping); } /* - * vma_next() - Get the next VMA. - * @mm: The mm_struct. - * @vma: The current vma. + * vma_mas_store() - Store a VMA in the maple tree. + * @vma: The vm_area_struct + * @mas: The maple state * - * If @vma is NULL, return the first vma in the mm. + * Efficient way to store a VMA in the maple tree when the @mas has already + * walked to the correct location. * - * Returns: The next VMA after @vma. + * Note: the end address is inclusive in the maple tree. */ -static inline struct vm_area_struct *vma_next(struct mm_struct *mm, - struct vm_area_struct *vma) +void vma_mas_store(struct vm_area_struct *vma, struct ma_state *mas) { - if (!vma) - return mm->mmap; - - return vma->vm_next; + trace_vma_store(mas->tree, vma); + mas_set_range(mas, vma->vm_start, vma->vm_end - 1); + mas_store_prealloc(mas, vma); } /* - * munmap_vma_range() - munmap VMAs that overlap a range. - * @mm: The mm struct - * @start: The start of the range. - * @len: The length of the range. - * @pprev: pointer to the pointer that will be set to previous vm_area_struct - * @rb_link: the rb_node - * @rb_parent: the parent rb_node - * - * Find all the vm_area_struct that overlap from @start to - * @end and munmap them. Set @pprev to the previous vm_area_struct. + * vma_mas_remove() - Remove a VMA from the maple tree. + * @vma: The vm_area_struct + * @mas: The maple state * - * Returns: -ENOMEM on munmap failure or 0 on success. + * Efficient way to remove a VMA from the maple tree when the @mas has already + * been established and points to the correct location. + * Note: the end address is inclusive in the maple tree. */ -static inline int -munmap_vma_range(struct mm_struct *mm, unsigned long start, unsigned long len, - struct vm_area_struct **pprev, struct rb_node ***link, - struct rb_node **parent, struct list_head *uf) +void vma_mas_remove(struct vm_area_struct *vma, struct ma_state *mas) { - - while (find_vma_links(mm, start, start + len, pprev, link, parent)) - if (do_munmap(mm, start, len, uf)) - return -ENOMEM; - - return 0; + trace_vma_mas_szero(mas->tree, vma->vm_start, vma->vm_end - 1); + mas->index = vma->vm_start; + mas->last = vma->vm_end - 1; + mas_store_prealloc(mas, NULL); } -static unsigned long count_vma_pages_range(struct mm_struct *mm, - unsigned long addr, unsigned long end) + +/* + * vma_mas_szero() - Set a given range to zero. Used when modifying a + * vm_area_struct start or end. + * + * @mm: The struct_mm + * @start: The start address to zero + * @end: The end address to zero. + */ +static inline void vma_mas_szero(struct ma_state *mas, unsigned long start, + unsigned long end) { - unsigned long nr_pages = 0; - struct vm_area_struct *vma; + trace_vma_mas_szero(mas->tree, start, end - 1); + mas_set_range(mas, start, end - 1); + mas_store_prealloc(mas, NULL); +} - /* Find first overlapping mapping */ - vma = find_vma_intersection(mm, addr, end); - if (!vma) - return 0; +static int vma_link(struct mm_struct *mm, struct vm_area_struct *vma) +{ + MA_STATE(mas, &mm->mm_mt, 0, 0); + struct address_space *mapping = NULL; - nr_pages = (min(end, vma->vm_end) - - max(addr, vma->vm_start)) >> PAGE_SHIFT; + if (mas_preallocate(&mas, vma, GFP_KERNEL)) + return -ENOMEM; - /* Iterate over the rest of the overlaps */ - for (vma = vma->vm_next; vma; vma = vma->vm_next) { - unsigned long overlap_len; + if (vma->vm_file) { + mapping = vma->vm_file->f_mapping; + i_mmap_lock_write(mapping); + } - if (vma->vm_start > end) - break; + vma_mas_store(vma, &mas); - overlap_len = min(end, vma->vm_end) - vma->vm_start; - nr_pages += overlap_len >> PAGE_SHIFT; + if (mapping) { + __vma_link_file(vma, mapping); + i_mmap_unlock_write(mapping); } - return nr_pages; + mm->map_count++; + validate_mm(mm); + return 0; } -void __vma_link_rb(struct mm_struct *mm, struct vm_area_struct *vma, - struct rb_node **rb_link, struct rb_node *rb_parent) +/* + * vma_expand - Expand an existing VMA + * + * @mas: The maple state + * @vma: The vma to expand + * @start: The start of the vma + * @end: The exclusive end of the vma + * @pgoff: The page offset of vma + * @next: The current of next vma. + * + * Expand @vma to @start and @end. Can expand off the start and end. Will + * expand over @next if it's different from @vma and @end == @next->vm_end. + * Checking if the @vma can expand and merge with @next needs to be handled by + * the caller. + * + * Returns: 0 on success + */ +inline int vma_expand(struct ma_state *mas, struct vm_area_struct *vma, + unsigned long start, unsigned long end, pgoff_t pgoff, + struct vm_area_struct *next) { - /* Update tracking information for the gap following the new vma. */ - if (vma->vm_next) - vma_gap_update(vma->vm_next); - else - mm->highest_vm_end = vm_end_gap(vma); + struct mm_struct *mm = vma->vm_mm; + struct address_space *mapping = NULL; + struct rb_root_cached *root = NULL; + struct anon_vma *anon_vma = vma->anon_vma; + struct file *file = vma->vm_file; + bool remove_next = false; - /* - * vma->vm_prev wasn't known when we followed the rbtree to find the - * correct insertion point for that vma. As a result, we could not - * update the vma vm_rb parents rb_subtree_gap values on the way down. - * So, we first insert the vma with a zero rb_subtree_gap value - * (to be consistent with what we did on the way down), and then - * immediately update the gap to the correct value. Finally we - * rebalance the rbtree after all augmented values have been set. - */ - rb_link_node(&vma->vm_rb, rb_parent, rb_link); - vma->rb_subtree_gap = 0; - vma_gap_update(vma); - vma_rb_insert(vma, &mm->mm_rb); -} + if (next && (vma != next) && (end == next->vm_end)) { + remove_next = true; + if (next->anon_vma && !vma->anon_vma) { + int error; -static void __vma_link_file(struct vm_area_struct *vma) -{ - struct file *file; + anon_vma = next->anon_vma; + vma->anon_vma = anon_vma; + error = anon_vma_clone(vma, next); + if (error) + return error; + } + } + + /* Not merging but overwriting any part of next is not handled. */ + VM_BUG_ON(next && !remove_next && next != vma && end > next->vm_start); + /* Only handles expanding */ + VM_BUG_ON(vma->vm_start < start || vma->vm_end > end); + + if (mas_preallocate(mas, vma, GFP_KERNEL)) + goto nomem; + + vma_adjust_trans_huge(vma, start, end, 0); - file = vma->vm_file; if (file) { - struct address_space *mapping = file->f_mapping; + mapping = file->f_mapping; + root = &mapping->i_mmap; + uprobe_munmap(vma, vma->vm_start, vma->vm_end); + i_mmap_lock_write(mapping); + } - if (vma->vm_flags & VM_SHARED) - mapping_allow_writable(mapping); + if (anon_vma) { + anon_vma_lock_write(anon_vma); + anon_vma_interval_tree_pre_update_vma(vma); + } + if (file) { flush_dcache_mmap_lock(mapping); - vma_interval_tree_insert(vma, &mapping->i_mmap); - flush_dcache_mmap_unlock(mapping); + vma_interval_tree_remove(vma, root); } -} -static void -__vma_link(struct mm_struct *mm, struct vm_area_struct *vma, - struct vm_area_struct *prev, struct rb_node **rb_link, - struct rb_node *rb_parent) -{ - __vma_link_list(mm, vma, prev); - __vma_link_rb(mm, vma, rb_link, rb_parent); -} + vma->vm_start = start; + vma->vm_end = end; + vma->vm_pgoff = pgoff; + /* Note: mas must be pointing to the expanding VMA */ + vma_mas_store(vma, mas); -static void vma_link(struct mm_struct *mm, struct vm_area_struct *vma, - struct vm_area_struct *prev, struct rb_node **rb_link, - struct rb_node *rb_parent) -{ - struct address_space *mapping = NULL; + if (file) { + vma_interval_tree_insert(vma, root); + flush_dcache_mmap_unlock(mapping); + } - if (vma->vm_file) { - mapping = vma->vm_file->f_mapping; - i_mmap_lock_write(mapping); + /* Expanding over the next vma */ + if (remove_next && file) { + __remove_shared_vm_struct(next, file, mapping); } - __vma_link(mm, vma, prev, rb_link, rb_parent); - __vma_link_file(vma); + if (anon_vma) { + anon_vma_interval_tree_post_update_vma(vma); + anon_vma_unlock_write(anon_vma); + } - if (mapping) + if (file) { i_mmap_unlock_write(mapping); + uprobe_mmap(vma); + } - mm->map_count++; - validate_mm(mm); -} - -/* - * Helper for vma_adjust() in the split_vma insert case: insert a vma into the - * mm's list and rbtree. It has already been inserted into the interval tree. - */ -static void __insert_vm_struct(struct mm_struct *mm, struct vm_area_struct *vma) -{ - struct vm_area_struct *prev; - struct rb_node **rb_link, *rb_parent; + if (remove_next) { + if (file) { + uprobe_munmap(next, next->vm_start, next->vm_end); + fput(file); + } + if (next->anon_vma) + anon_vma_merge(vma, next); + mm->map_count--; + mpol_put(vma_policy(next)); + vm_area_free(next); + } - if (find_vma_links(mm, vma->vm_start, vma->vm_end, - &prev, &rb_link, &rb_parent)) - BUG(); - __vma_link(mm, vma, prev, rb_link, rb_parent); - mm->map_count++; -} + validate_mm(mm); + return 0; -static __always_inline void __vma_unlink(struct mm_struct *mm, - struct vm_area_struct *vma, - struct vm_area_struct *ignore) -{ - vma_rb_erase_ignore(vma, &mm->mm_rb, ignore); - __vma_unlink_list(mm, vma); - /* Kill the cache */ - vmacache_invalidate(mm); +nomem: + return -ENOMEM; } /* @@ -701,18 +618,20 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start, struct vm_area_struct *expand) { struct mm_struct *mm = vma->vm_mm; - struct vm_area_struct *next = vma->vm_next, *orig_vma = vma; + struct vm_area_struct *next_next = NULL; /* uninit var warning */ + struct vm_area_struct *next = find_vma(mm, vma->vm_end); + struct vm_area_struct *orig_vma = vma; struct address_space *mapping = NULL; struct rb_root_cached *root = NULL; struct anon_vma *anon_vma = NULL; struct file *file = vma->vm_file; - bool start_changed = false, end_changed = false; + bool vma_changed = false; long adjust_next = 0; int remove_next = 0; + MA_STATE(mas, &mm->mm_mt, 0, 0); + struct vm_area_struct *exporter = NULL, *importer = NULL; if (next && !insert) { - struct vm_area_struct *exporter = NULL, *importer = NULL; - if (end >= next->vm_end) { /* * vma expands, overlapping all the next, and @@ -741,10 +660,11 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start, * remove_next == 1 is case 1 or 7. */ remove_next = 1 + (end > next->vm_end); + if (remove_next == 2) + next_next = find_vma(mm, next->vm_end); + VM_WARN_ON(remove_next == 2 && - end != next->vm_next->vm_end); - /* trim end to next, for case 6 first pass */ - end = next->vm_end; + end != next_next->vm_end); } exporter = next; @@ -755,7 +675,7 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start, * next, if the vma overlaps with it. */ if (remove_next == 2 && !next->anon_vma) - exporter = next->vm_next; + exporter = next_next; } else if (end > next->vm_start) { /* @@ -792,9 +712,11 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start, return error; } } -again: - vma_adjust_trans_huge(orig_vma, start, end, adjust_next); + if (mas_preallocate(&mas, vma, GFP_KERNEL)) + return -ENOMEM; + + vma_adjust_trans_huge(orig_vma, start, end, adjust_next); if (file) { mapping = file->f_mapping; root = &mapping->i_mmap; @@ -804,14 +726,14 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start, uprobe_munmap(next, next->vm_start, next->vm_end); i_mmap_lock_write(mapping); - if (insert) { + if (insert && insert->vm_file) { /* * Put into interval tree now, so instantiated pages * are visible to arm/parisc __flush_dcache_page * throughout; but we cannot insert into address * space until vma start or end is updated. */ - __vma_link_file(insert); + __vma_link_file(insert, insert->vm_file->f_mapping); } } @@ -835,17 +757,37 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start, } if (start != vma->vm_start) { + if ((vma->vm_start < start) && + (!insert || (insert->vm_end != start))) { + vma_mas_szero(&mas, vma->vm_start, start); + VM_WARN_ON(insert && insert->vm_start > vma->vm_start); + } else { + vma_changed = true; + } vma->vm_start = start; - start_changed = true; } if (end != vma->vm_end) { + if (vma->vm_end > end) { + if (!insert || (insert->vm_start != end)) { + vma_mas_szero(&mas, end, vma->vm_end); + mas_reset(&mas); + VM_WARN_ON(insert && + insert->vm_end < vma->vm_end); + } + } else { + vma_changed = true; + } vma->vm_end = end; - end_changed = true; } + + if (vma_changed) + vma_mas_store(vma, &mas); + vma->vm_pgoff = pgoff; if (adjust_next) { next->vm_start += adjust_next; next->vm_pgoff += adjust_next >> PAGE_SHIFT; + vma_mas_store(next, &mas); } if (file) { @@ -855,42 +797,19 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start, flush_dcache_mmap_unlock(mapping); } - if (remove_next) { - /* - * vma_merge has merged next into vma, and needs - * us to remove next before dropping the locks. - */ - if (remove_next != 3) - __vma_unlink(mm, next, next); - else - /* - * vma is not before next if they've been - * swapped. - * - * pre-swap() next->vm_start was reduced so - * tell validate_mm_rb to ignore pre-swap() - * "next" (which is stored in post-swap() - * "vma"). - */ - __vma_unlink(mm, next, vma); - if (file) - __remove_shared_vm_struct(next, file, mapping); + if (remove_next && file) { + __remove_shared_vm_struct(next, file, mapping); + if (remove_next == 2) + __remove_shared_vm_struct(next_next, file, mapping); } else if (insert) { /* * split_vma has split insert from vma, and needs * us to insert it before dropping the locks * (it may either follow vma or precede it). */ - __insert_vm_struct(mm, insert); - } else { - if (start_changed) - vma_gap_update(vma); - if (end_changed) { - if (!next) - mm->highest_vm_end = vm_end_gap(vma); - else if (!adjust_next) - vma_gap_update(next); - } + mas_reset(&mas); + vma_mas_store(insert, &mas); + mm->map_count++; } if (anon_vma) { @@ -909,6 +828,7 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start, } if (remove_next) { +again: if (file) { uprobe_munmap(next, next->vm_start, next->vm_end); fput(file); @@ -917,66 +837,24 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start, anon_vma_merge(vma, next); mm->map_count--; mpol_put(vma_policy(next)); + if (remove_next != 2) + BUG_ON(vma->vm_end < next->vm_end); vm_area_free(next); + /* * In mprotect's case 6 (see comments on vma_merge), - * we must remove another next too. It would clutter - * up the code too much to do both in one go. + * we must remove next_next too. */ - if (remove_next != 3) { - /* - * If "next" was removed and vma->vm_end was - * expanded (up) over it, in turn - * "next->vm_prev->vm_end" changed and the - * "vma->vm_next" gap must be updated. - */ - next = vma->vm_next; - } else { - /* - * For the scope of the comment "next" and - * "vma" considered pre-swap(): if "vma" was - * removed, next->vm_start was expanded (down) - * over it and the "next" gap must be updated. - * Because of the swap() the post-swap() "vma" - * actually points to pre-swap() "next" - * (post-swap() "next" as opposed is now a - * dangling pointer). - */ - next = vma; - } if (remove_next == 2) { remove_next = 1; - end = next->vm_end; + next = next_next; goto again; } - else if (next) - vma_gap_update(next); - else { - /* - * If remove_next == 2 we obviously can't - * reach this path. - * - * If remove_next == 3 we can't reach this - * path because pre-swap() next is always not - * NULL. pre-swap() "next" is not being - * removed and its next->vm_end is not altered - * (and furthermore "end" already matches - * next->vm_end in remove_next == 3). - * - * We reach this only in the remove_next == 1 - * case if the "next" vma that was removed was - * the highest vma of the mm. However in such - * case next->vm_end == "end" and the extended - * "vma" has vma->vm_end == next->vm_end so - * mm->highest_vm_end doesn't need any update - * in remove_next == 1 case. - */ - VM_WARN_ON(mm->highest_vm_end != vm_end_gap(vma)); - } } if (insert && file) uprobe_mmap(insert); + mas_destroy(&mas); validate_mm(mm); return 0; @@ -1138,10 +1016,10 @@ struct vm_area_struct *vma_merge(struct mm_struct *mm, if (vm_flags & VM_SPECIAL) return NULL; - next = vma_next(mm, prev); + next = find_vma(mm, prev ? prev->vm_end : 0); area = next; if (area && area->vm_end == end) /* cases 6, 7, 8 */ - next = next->vm_next; + next = find_vma(mm, next->vm_end); /* verify some invariant that must be enforced by the caller */ VM_WARN_ON(prev && addr <= prev->vm_start); @@ -1275,18 +1153,24 @@ static struct anon_vma *reusable_anon_vma(struct vm_area_struct *old, struct vm_ */ struct anon_vma *find_mergeable_anon_vma(struct vm_area_struct *vma) { + MA_STATE(mas, &vma->vm_mm->mm_mt, vma->vm_end, vma->vm_end); struct anon_vma *anon_vma = NULL; + struct vm_area_struct *prev, *next; /* Try next first. */ - if (vma->vm_next) { - anon_vma = reusable_anon_vma(vma->vm_next, vma, vma->vm_next); + next = mas_walk(&mas); + if (next) { + anon_vma = reusable_anon_vma(next, vma, next); if (anon_vma) return anon_vma; } + prev = mas_prev(&mas, 0); + VM_BUG_ON_VMA(prev != vma, vma); + prev = mas_prev(&mas, 0); /* Try prev next. */ - if (vma->vm_prev) - anon_vma = reusable_anon_vma(vma->vm_prev, vma->vm_prev, vma); + if (prev) + anon_vma = reusable_anon_vma(prev, prev, vma); /* * We might reach here with anon_vma == NULL if we can't find @@ -1375,6 +1259,7 @@ unsigned long do_mmap(struct file *file, unsigned long addr, vm_flags_t vm_flags; int pkey = 0; + validate_mm(mm); *populate = 0; if (!len) @@ -1678,391 +1563,63 @@ static inline int accountable_mapping(struct file *file, vm_flags_t vm_flags) return (vm_flags & (VM_NORESERVE | VM_SHARED | VM_WRITE)) == VM_WRITE; } -unsigned long mmap_region(struct file *file, unsigned long addr, - unsigned long len, vm_flags_t vm_flags, unsigned long pgoff, - struct list_head *uf) +/** + * unmapped_area() - Find an area between the low_limit and the high_limit with + * the correct alignment and offset, all from @info. Note: current->mm is used + * for the search. + * + * @info: The unmapped area information including the range (low_limit - + * hight_limit), the alignment offset and mask. + * + * Return: A memory address or -ENOMEM. + */ +static unsigned long unmapped_area(struct vm_unmapped_area_info *info) { - struct mm_struct *mm = current->mm; - struct vm_area_struct *vma, *prev, *merge; - int error; - struct rb_node **rb_link, *rb_parent; - unsigned long charged = 0; - - /* Check against address space limit. */ - if (!may_expand_vm(mm, vm_flags, len >> PAGE_SHIFT)) { - unsigned long nr_pages; - - /* - * MAP_FIXED may remove pages of mappings that intersects with - * requested mapping. Account for the pages it would unmap. - */ - nr_pages = count_vma_pages_range(mm, addr, addr + len); - - if (!may_expand_vm(mm, vm_flags, - (len >> PAGE_SHIFT) - nr_pages)) - return -ENOMEM; - } - - /* Clear old maps, set up prev, rb_link, rb_parent, and uf */ - if (munmap_vma_range(mm, addr, len, &prev, &rb_link, &rb_parent, uf)) - return -ENOMEM; - /* - * Private writable mapping: check memory availability - */ - if (accountable_mapping(file, vm_flags)) { - charged = len >> PAGE_SHIFT; - if (security_vm_enough_memory_mm(mm, charged)) - return -ENOMEM; - vm_flags |= VM_ACCOUNT; - } - - /* - * Can we just expand an old mapping? - */ - vma = vma_merge(mm, prev, addr, addr + len, vm_flags, - NULL, file, pgoff, NULL, NULL_VM_UFFD_CTX, NULL); - if (vma) - goto out; - - /* - * Determine the object being mapped and call the appropriate - * specific mapper. the address has already been validated, but - * not unmapped, but the maps are removed from the list. - */ - vma = vm_area_alloc(mm); - if (!vma) { - error = -ENOMEM; - goto unacct_error; - } - - vma->vm_start = addr; - vma->vm_end = addr + len; - vma->vm_flags = vm_flags; - vma->vm_page_prot = vm_get_page_prot(vm_flags); - vma->vm_pgoff = pgoff; - - if (file) { - if (vm_flags & VM_SHARED) { - error = mapping_map_writable(file->f_mapping); - if (error) - goto free_vma; - } - - vma->vm_file = get_file(file); - error = call_mmap(file, vma); - if (error) - goto unmap_and_free_vma; - - /* Can addr have changed?? - * - * Answer: Yes, several device drivers can do it in their - * f_op->mmap method. -DaveM - * Bug: If addr is changed, prev, rb_link, rb_parent should - * be updated for vma_link() - */ - WARN_ON_ONCE(addr != vma->vm_start); - - addr = vma->vm_start; - - /* If vm_flags changed after call_mmap(), we should try merge vma again - * as we may succeed this time. - */ - if (unlikely(vm_flags != vma->vm_flags && prev)) { - merge = vma_merge(mm, prev, vma->vm_start, vma->vm_end, vma->vm_flags, - NULL, vma->vm_file, vma->vm_pgoff, NULL, NULL_VM_UFFD_CTX, NULL); - if (merge) { - /* ->mmap() can change vma->vm_file and fput the original file. So - * fput the vma->vm_file here or we would add an extra fput for file - * and cause general protection fault ultimately. - */ - fput(vma->vm_file); - vm_area_free(vma); - vma = merge; - /* Update vm_flags to pick up the change. */ - vm_flags = vma->vm_flags; - goto unmap_writable; - } - } - - vm_flags = vma->vm_flags; - } else if (vm_flags & VM_SHARED) { - error = shmem_zero_setup(vma); - if (error) - goto free_vma; - } else { - vma_set_anonymous(vma); - } - - /* Allow architectures to sanity-check the vm_flags */ - if (!arch_validate_flags(vma->vm_flags)) { - error = -EINVAL; - if (file) - goto close_and_free_vma; - else - goto free_vma; - } - - vma_link(mm, vma, prev, rb_link, rb_parent); - - /* - * vma_merge() calls khugepaged_enter_vma() either, the below - * call covers the non-merge case. - */ - khugepaged_enter_vma(vma, vma->vm_flags); - - /* Once vma denies write, undo our temporary denial count */ -unmap_writable: - if (file && vm_flags & VM_SHARED) - mapping_unmap_writable(file->f_mapping); - file = vma->vm_file; -out: - perf_event_mmap(vma); - - vm_stat_account(mm, vm_flags, len >> PAGE_SHIFT); - if (vm_flags & VM_LOCKED) { - if ((vm_flags & VM_SPECIAL) || vma_is_dax(vma) || - is_vm_hugetlb_page(vma) || - vma == get_gate_vma(current->mm)) - vma->vm_flags &= VM_LOCKED_CLEAR_MASK; - else - mm->locked_vm += (len >> PAGE_SHIFT); - } - - if (file) - uprobe_mmap(vma); - - /* - * New (or expanded) vma always get soft dirty status. - * Otherwise user-space soft-dirty page tracker won't - * be able to distinguish situation when vma area unmapped, - * then new mapped in-place (which must be aimed as - * a completely new data area). - */ - vma->vm_flags |= VM_SOFTDIRTY; - - vma_set_page_prot(vma); - - return addr; - -close_and_free_vma: - if (vma->vm_ops && vma->vm_ops->close) - vma->vm_ops->close(vma); -unmap_and_free_vma: - fput(vma->vm_file); - vma->vm_file = NULL; - - /* Undo any partial mapping done by a device driver. */ - unmap_region(mm, vma, prev, vma->vm_start, vma->vm_end); - if (vm_flags & VM_SHARED) - mapping_unmap_writable(file->f_mapping); -free_vma: - vm_area_free(vma); -unacct_error: - if (charged) - vm_unacct_memory(charged); - return error; -} - -static unsigned long unmapped_area(struct vm_unmapped_area_info *info) -{ - /* - * We implement the search by looking for an rbtree node that - * immediately follows a suitable gap. That is, - * - gap_start = vma->vm_prev->vm_end <= info->high_limit - length; - * - gap_end = vma->vm_start >= info->low_limit + length; - * - gap_end - gap_start >= length - */ + unsigned long length, gap; - struct mm_struct *mm = current->mm; - struct vm_area_struct *vma; - unsigned long length, low_limit, high_limit, gap_start, gap_end; + MA_STATE(mas, ¤t->mm->mm_mt, 0, 0); /* Adjust search length to account for worst case alignment overhead */ length = info->length + info->align_mask; if (length < info->length) return -ENOMEM; - /* Adjust search limits by the desired length */ - if (info->high_limit < length) - return -ENOMEM; - high_limit = info->high_limit - length; - - if (info->low_limit > high_limit) - return -ENOMEM; - low_limit = info->low_limit + length; - - /* Check if rbtree root looks promising */ - if (RB_EMPTY_ROOT(&mm->mm_rb)) - goto check_highest; - vma = rb_entry(mm->mm_rb.rb_node, struct vm_area_struct, vm_rb); - if (vma->rb_subtree_gap < length) - goto check_highest; - - while (true) { - /* Visit left subtree if it looks promising */ - gap_end = vm_start_gap(vma); - if (gap_end >= low_limit && vma->vm_rb.rb_left) { - struct vm_area_struct *left = - rb_entry(vma->vm_rb.rb_left, - struct vm_area_struct, vm_rb); - if (left->rb_subtree_gap >= length) { - vma = left; - continue; - } - } - - gap_start = vma->vm_prev ? vm_end_gap(vma->vm_prev) : 0; -check_current: - /* Check if current node has a suitable gap */ - if (gap_start > high_limit) - return -ENOMEM; - if (gap_end >= low_limit && - gap_end > gap_start && gap_end - gap_start >= length) - goto found; - - /* Visit right subtree if it looks promising */ - if (vma->vm_rb.rb_right) { - struct vm_area_struct *right = - rb_entry(vma->vm_rb.rb_right, - struct vm_area_struct, vm_rb); - if (right->rb_subtree_gap >= length) { - vma = right; - continue; - } - } - - /* Go back up the rbtree to find next candidate node */ - while (true) { - struct rb_node *prev = &vma->vm_rb; - if (!rb_parent(prev)) - goto check_highest; - vma = rb_entry(rb_parent(prev), - struct vm_area_struct, vm_rb); - if (prev == vma->vm_rb.rb_left) { - gap_start = vm_end_gap(vma->vm_prev); - gap_end = vm_start_gap(vma); - goto check_current; - } - } - } - -check_highest: - /* Check highest gap, which does not precede any rbtree node */ - gap_start = mm->highest_vm_end; - gap_end = ULONG_MAX; /* Only for VM_BUG_ON below */ - if (gap_start > high_limit) + if (mas_empty_area(&mas, info->low_limit, info->high_limit - 1, + length)) return -ENOMEM; -found: - /* We found a suitable gap. Clip it with the original low_limit. */ - if (gap_start < info->low_limit) - gap_start = info->low_limit; - - /* Adjust gap address to the desired alignment */ - gap_start += (info->align_offset - gap_start) & info->align_mask; - - VM_BUG_ON(gap_start + info->length > info->high_limit); - VM_BUG_ON(gap_start + info->length > gap_end); - return gap_start; + gap = mas.index; + gap += (info->align_offset - gap) & info->align_mask; + return gap; } +/** + * unmapped_area_topdown() - Find an area between the low_limit and the + * high_limit with * the correct alignment and offset at the highest available + * address, all from @info. Note: current->mm is used for the search. + * + * @info: The unmapped area information including the range (low_limit - + * hight_limit), the alignment offset and mask. + * + * Return: A memory address or -ENOMEM. + */ static unsigned long unmapped_area_topdown(struct vm_unmapped_area_info *info) { - struct mm_struct *mm = current->mm; - struct vm_area_struct *vma; - unsigned long length, low_limit, high_limit, gap_start, gap_end; + unsigned long length, gap; + MA_STATE(mas, ¤t->mm->mm_mt, 0, 0); /* Adjust search length to account for worst case alignment overhead */ length = info->length + info->align_mask; if (length < info->length) return -ENOMEM; - /* - * Adjust search limits by the desired length. - * See implementation comment at top of unmapped_area(). - */ - gap_end = info->high_limit; - if (gap_end < length) - return -ENOMEM; - high_limit = gap_end - length; - - if (info->low_limit > high_limit) - return -ENOMEM; - low_limit = info->low_limit + length; - - /* Check highest gap, which does not precede any rbtree node */ - gap_start = mm->highest_vm_end; - if (gap_start <= high_limit) - goto found_highest; - - /* Check if rbtree root looks promising */ - if (RB_EMPTY_ROOT(&mm->mm_rb)) + if (mas_empty_area_rev(&mas, info->low_limit, info->high_limit - 1, + length)) return -ENOMEM; - vma = rb_entry(mm->mm_rb.rb_node, struct vm_area_struct, vm_rb); - if (vma->rb_subtree_gap < length) - return -ENOMEM; - - while (true) { - /* Visit right subtree if it looks promising */ - gap_start = vma->vm_prev ? vm_end_gap(vma->vm_prev) : 0; - if (gap_start <= high_limit && vma->vm_rb.rb_right) { - struct vm_area_struct *right = - rb_entry(vma->vm_rb.rb_right, - struct vm_area_struct, vm_rb); - if (right->rb_subtree_gap >= length) { - vma = right; - continue; - } - } - -check_current: - /* Check if current node has a suitable gap */ - gap_end = vm_start_gap(vma); - if (gap_end < low_limit) - return -ENOMEM; - if (gap_start <= high_limit && - gap_end > gap_start && gap_end - gap_start >= length) - goto found; - - /* Visit left subtree if it looks promising */ - if (vma->vm_rb.rb_left) { - struct vm_area_struct *left = - rb_entry(vma->vm_rb.rb_left, - struct vm_area_struct, vm_rb); - if (left->rb_subtree_gap >= length) { - vma = left; - continue; - } - } - /* Go back up the rbtree to find next candidate node */ - while (true) { - struct rb_node *prev = &vma->vm_rb; - if (!rb_parent(prev)) - return -ENOMEM; - vma = rb_entry(rb_parent(prev), - struct vm_area_struct, vm_rb); - if (prev == vma->vm_rb.rb_right) { - gap_start = vma->vm_prev ? - vm_end_gap(vma->vm_prev) : 0; - goto check_current; - } - } - } - -found: - /* We found a suitable gap. Clip it with the original high_limit. */ - if (gap_end > info->high_limit) - gap_end = info->high_limit; - -found_highest: - /* Compute highest gap address at the desired alignment */ - gap_end -= info->length; - gap_end -= (gap_end - info->align_offset) & info->align_mask; - - VM_BUG_ON(gap_end < info->low_limit); - VM_BUG_ON(gap_end < gap_start); - return gap_end; + gap = mas.last + 1 - info->length; + gap -= (gap - info->align_offset) & info->align_mask; + return gap; } /* @@ -2252,58 +1809,67 @@ get_unmapped_area(struct file *file, unsigned long addr, unsigned long len, EXPORT_SYMBOL(get_unmapped_area); -/* Look up the first VMA which satisfies addr < vm_end, NULL if none. */ -struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr) +/** + * find_vma_intersection() - Look up the first VMA which intersects the interval + * @mm: The process address space. + * @start_addr: The inclusive start user address. + * @end_addr: The exclusive end user address. + * + * Returns: The first VMA within the provided range, %NULL otherwise. Assumes + * start_addr < end_addr. + */ +struct vm_area_struct *find_vma_intersection(struct mm_struct *mm, + unsigned long start_addr, + unsigned long end_addr) { - struct rb_node *rb_node; - struct vm_area_struct *vma; + unsigned long index = start_addr; mmap_assert_locked(mm); - /* Check the cache first. */ - vma = vmacache_find(mm, addr); - if (likely(vma)) - return vma; - - rb_node = mm->mm_rb.rb_node; - - while (rb_node) { - struct vm_area_struct *tmp; - - tmp = rb_entry(rb_node, struct vm_area_struct, vm_rb); + return mt_find(&mm->mm_mt, &index, end_addr - 1); +} +EXPORT_SYMBOL(find_vma_intersection); - if (tmp->vm_end > addr) { - vma = tmp; - if (tmp->vm_start <= addr) - break; - rb_node = rb_node->rb_left; - } else - rb_node = rb_node->rb_right; - } +/** + * find_vma() - Find the VMA for a given address, or the next VMA. + * @mm: The mm_struct to check + * @addr: The address + * + * Returns: The VMA associated with addr, or the next VMA. + * May return %NULL in the case of no VMA at addr or above. + */ +struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr) +{ + unsigned long index = addr; - if (vma) - vmacache_update(addr, vma); - return vma; + mmap_assert_locked(mm); + return mt_find(&mm->mm_mt, &index, ULONG_MAX); } - EXPORT_SYMBOL(find_vma); -/* - * Same as find_vma, but also return a pointer to the previous VMA in *pprev. +/** + * find_vma_prev() - Find the VMA for a given address, or the next vma and + * set %pprev to the previous VMA, if any. + * @mm: The mm_struct to check + * @addr: The address + * @pprev: The pointer to set to the previous VMA + * + * Note that RCU lock is missing here since the external mmap_lock() is used + * instead. + * + * Returns: The VMA associated with @addr, or the next vma. + * May return %NULL in the case of no vma at addr or above. */ struct vm_area_struct * find_vma_prev(struct mm_struct *mm, unsigned long addr, struct vm_area_struct **pprev) { struct vm_area_struct *vma; + MA_STATE(mas, &mm->mm_mt, addr, addr); - vma = find_vma(mm, addr); - if (vma) { - *pprev = vma->vm_prev; - } else { - struct rb_node *rb_node = rb_last(&mm->mm_rb); - - *pprev = rb_node ? rb_entry(rb_node, struct vm_area_struct, vm_rb) : NULL; - } + vma = mas_walk(&mas); + *pprev = mas_prev(&mas, 0); + if (!vma) + vma = mas_next(&mas, ULONG_MAX); return vma; } @@ -2357,6 +1923,7 @@ int expand_upwards(struct vm_area_struct *vma, unsigned long address) struct vm_area_struct *next; unsigned long gap_addr; int error = 0; + MA_STATE(mas, &mm->mm_mt, 0, 0); if (!(vma->vm_flags & VM_GROWSUP)) return -EFAULT; @@ -2374,16 +1941,21 @@ int expand_upwards(struct vm_area_struct *vma, unsigned long address) if (gap_addr < address || gap_addr > TASK_SIZE) gap_addr = TASK_SIZE; - next = vma->vm_next; - if (next && next->vm_start < gap_addr && vma_is_accessible(next)) { + next = find_vma_intersection(mm, vma->vm_end, gap_addr); + if (next && vma_is_accessible(next)) { if (!(next->vm_flags & VM_GROWSUP)) return -ENOMEM; /* Check that both stack segments have the same anon_vma? */ } + if (mas_preallocate(&mas, vma, GFP_KERNEL)) + return -ENOMEM; + /* We must make sure the anon_vma is allocated. */ - if (unlikely(anon_vma_prepare(vma))) + if (unlikely(anon_vma_prepare(vma))) { + mas_destroy(&mas); return -ENOMEM; + } /* * vma->vm_start/vm_end cannot change under us because the caller @@ -2404,15 +1976,13 @@ int expand_upwards(struct vm_area_struct *vma, unsigned long address) error = acct_stack_growth(vma, size, grow); if (!error) { /* - * vma_gap_update() doesn't support concurrent - * updates, but we only hold a shared mmap_lock - * lock here, so we need to protect against - * concurrent vma expansions. - * anon_vma_lock_write() doesn't help here, as - * we don't guarantee that all growable vmas - * in a mm share the same root anon vma. - * So, we reuse mm->page_table_lock to guard - * against concurrent vma expansions. + * We only hold a shared mmap_lock lock here, so + * we need to protect against concurrent vma + * expansions. anon_vma_lock_write() doesn't + * help here, as we don't guarantee that all + * growable vmas in a mm share the same root + * anon vma. So, we reuse mm->page_table_lock + * to guard against concurrent vma expansions. */ spin_lock(&mm->page_table_lock); if (vma->vm_flags & VM_LOCKED) @@ -2420,11 +1990,9 @@ int expand_upwards(struct vm_area_struct *vma, unsigned long address) vm_stat_account(mm, vma->vm_flags, grow); anon_vma_interval_tree_pre_update_vma(vma); vma->vm_end = address; + /* Overwrite old entry in mtree. */ + vma_mas_store(vma, &mas); anon_vma_interval_tree_post_update_vma(vma); - if (vma->vm_next) - vma_gap_update(vma->vm_next); - else - mm->highest_vm_end = vm_end_gap(vma); spin_unlock(&mm->page_table_lock); perf_event_mmap(vma); @@ -2433,7 +2001,7 @@ int expand_upwards(struct vm_area_struct *vma, unsigned long address) } anon_vma_unlock_write(vma->anon_vma); khugepaged_enter_vma(vma, vma->vm_flags); - validate_mm(mm); + mas_destroy(&mas); return error; } #endif /* CONFIG_STACK_GROWSUP || CONFIG_IA64 */ @@ -2441,10 +2009,10 @@ int expand_upwards(struct vm_area_struct *vma, unsigned long address) /* * vma is the first one with address < vma->vm_start. Have to extend vma. */ -int expand_downwards(struct vm_area_struct *vma, - unsigned long address) +int expand_downwards(struct vm_area_struct *vma, unsigned long address) { struct mm_struct *mm = vma->vm_mm; + MA_STATE(mas, &mm->mm_mt, vma->vm_start, vma->vm_start); struct vm_area_struct *prev; int error = 0; @@ -2453,7 +2021,7 @@ int expand_downwards(struct vm_area_struct *vma, return -EPERM; /* Enforce stack_guard_gap */ - prev = vma->vm_prev; + prev = mas_prev(&mas, 0); /* Check that both stack segments have the same anon_vma? */ if (prev && !(prev->vm_flags & VM_GROWSDOWN) && vma_is_accessible(prev)) { @@ -2461,9 +2029,14 @@ int expand_downwards(struct vm_area_struct *vma, return -ENOMEM; } + if (mas_preallocate(&mas, vma, GFP_KERNEL)) + return -ENOMEM; + /* We must make sure the anon_vma is allocated. */ - if (unlikely(anon_vma_prepare(vma))) + if (unlikely(anon_vma_prepare(vma))) { + mas_destroy(&mas); return -ENOMEM; + } /* * vma->vm_start/vm_end cannot change under us because the caller @@ -2484,15 +2057,13 @@ int expand_downwards(struct vm_area_struct *vma, error = acct_stack_growth(vma, size, grow); if (!error) { /* - * vma_gap_update() doesn't support concurrent - * updates, but we only hold a shared mmap_lock - * lock here, so we need to protect against - * concurrent vma expansions. - * anon_vma_lock_write() doesn't help here, as - * we don't guarantee that all growable vmas - * in a mm share the same root anon vma. - * So, we reuse mm->page_table_lock to guard - * against concurrent vma expansions. + * We only hold a shared mmap_lock lock here, so + * we need to protect against concurrent vma + * expansions. anon_vma_lock_write() doesn't + * help here, as we don't guarantee that all + * growable vmas in a mm share the same root + * anon vma. So, we reuse mm->page_table_lock + * to guard against concurrent vma expansions. */ spin_lock(&mm->page_table_lock); if (vma->vm_flags & VM_LOCKED) @@ -2501,8 +2072,9 @@ int expand_downwards(struct vm_area_struct *vma, anon_vma_interval_tree_pre_update_vma(vma); vma->vm_start = address; vma->vm_pgoff -= grow; + /* Overwrite old entry in mtree. */ + vma_mas_store(vma, &mas); anon_vma_interval_tree_post_update_vma(vma); - vma_gap_update(vma); spin_unlock(&mm->page_table_lock); perf_event_mmap(vma); @@ -2511,7 +2083,7 @@ int expand_downwards(struct vm_area_struct *vma, } anon_vma_unlock_write(vma->anon_vma); khugepaged_enter_vma(vma, vma->vm_flags); - validate_mm(mm); + mas_destroy(&mas); return error; } @@ -2584,25 +2156,26 @@ find_extend_vma(struct mm_struct *mm, unsigned long addr) EXPORT_SYMBOL_GPL(find_extend_vma); /* - * Ok - we have the memory areas we should free on the vma list, - * so release them, and do the vma updates. + * Ok - we have the memory areas we should free on a maple tree so release them, + * and do the vma updates. * * Called with the mm semaphore held. */ -static void remove_vma_list(struct mm_struct *mm, struct vm_area_struct *vma) +static inline void remove_mt(struct mm_struct *mm, struct ma_state *mas) { unsigned long nr_accounted = 0; + struct vm_area_struct *vma; /* Update high watermark before we lower total_vm */ update_hiwater_vm(mm); - do { + mas_for_each(mas, vma, ULONG_MAX) { long nrpages = vma_pages(vma); if (vma->vm_flags & VM_ACCOUNT) nr_accounted += nrpages; vm_stat_account(mm, vma->vm_flags, -nrpages); - vma = remove_vma(vma); - } while (vma); + remove_vma(vma); + } vm_unacct_memory(nr_accounted); validate_mm(mm); } @@ -2612,75 +2185,32 @@ static void remove_vma_list(struct mm_struct *mm, struct vm_area_struct *vma) * * Called with the mm semaphore held. */ -static void unmap_region(struct mm_struct *mm, +static void unmap_region(struct mm_struct *mm, struct maple_tree *mt, struct vm_area_struct *vma, struct vm_area_struct *prev, + struct vm_area_struct *next, unsigned long start, unsigned long end) { - struct vm_area_struct *next = vma_next(mm, prev); struct mmu_gather tlb; lru_add_drain(); tlb_gather_mmu(&tlb, mm); update_hiwater_rss(mm); - unmap_vmas(&tlb, vma, start, end); - free_pgtables(&tlb, vma, prev ? prev->vm_end : FIRST_USER_ADDRESS, + unmap_vmas(&tlb, mt, vma, start, end); + free_pgtables(&tlb, mt, vma, prev ? prev->vm_end : FIRST_USER_ADDRESS, next ? next->vm_start : USER_PGTABLES_CEILING); tlb_finish_mmu(&tlb); } /* - * Create a list of vma's touched by the unmap, removing them from the mm's - * vma list as we go.. + * __split_vma() bypasses sysctl_max_map_count checking. We use this where it + * has already been checked or doesn't make sense to fail. */ -static bool -detach_vmas_to_be_unmapped(struct mm_struct *mm, struct vm_area_struct *vma, - struct vm_area_struct *prev, unsigned long end) +int __split_vma(struct mm_struct *mm, struct vm_area_struct *vma, + unsigned long addr, int new_below) { - struct vm_area_struct **insertion_point; - struct vm_area_struct *tail_vma = NULL; - - insertion_point = (prev ? &prev->vm_next : &mm->mmap); - vma->vm_prev = NULL; - do { - vma_rb_erase(vma, &mm->mm_rb); - if (vma->vm_flags & VM_LOCKED) - mm->locked_vm -= vma_pages(vma); - mm->map_count--; - tail_vma = vma; - vma = vma->vm_next; - } while (vma && vma->vm_start < end); - *insertion_point = vma; - if (vma) { - vma->vm_prev = prev; - vma_gap_update(vma); - } else - mm->highest_vm_end = prev ? vm_end_gap(prev) : 0; - tail_vma->vm_next = NULL; - - /* Kill the cache */ - vmacache_invalidate(mm); - - /* - * Do not downgrade mmap_lock if we are next to VM_GROWSDOWN or - * VM_GROWSUP VMA. Such VMAs can change their size under - * down_read(mmap_lock) and collide with the VMA we are about to unmap. - */ - if (vma && (vma->vm_flags & VM_GROWSDOWN)) - return false; - if (prev && (prev->vm_flags & VM_GROWSUP)) - return false; - return true; -} - -/* - * __split_vma() bypasses sysctl_max_map_count checking. We use this where it - * has already been checked or doesn't make sense to fail. - */ -int __split_vma(struct mm_struct *mm, struct vm_area_struct *vma, - unsigned long addr, int new_below) -{ - struct vm_area_struct *new; - int err; + struct vm_area_struct *new; + int err; + validate_mm_mt(mm); if (vma->vm_ops && vma->vm_ops->may_split) { err = vma->vm_ops->may_split(vma, addr); @@ -2723,6 +2253,9 @@ int __split_vma(struct mm_struct *mm, struct vm_area_struct *vma, if (!err) return 0; + /* Avoid vm accounting in close() operation */ + new->vm_start = new->vm_end; + new->vm_pgoff = 0; /* Clean everything up if vma_adjust failed. */ if (new->vm_ops && new->vm_ops->close) new->vm_ops->close(new); @@ -2733,6 +2266,7 @@ int __split_vma(struct mm_struct *mm, struct vm_area_struct *vma, mpol_put(vma_policy(new)); out_free_vma: vm_area_free(new); + validate_mm_mt(mm); return err; } @@ -2749,38 +2283,48 @@ int split_vma(struct mm_struct *mm, struct vm_area_struct *vma, return __split_vma(mm, vma, addr, new_below); } -/* Munmap is split into 2 main parts -- this part which finds - * what needs doing, and the areas themselves, which do the - * work. This now handles partial unmappings. - * Jeremy Fitzhardinge - */ -int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len, - struct list_head *uf, bool downgrade) +static inline int munmap_sidetree(struct vm_area_struct *vma, + struct ma_state *mas_detach) { - unsigned long end; - struct vm_area_struct *vma, *prev, *last; - - if ((offset_in_page(start)) || start > TASK_SIZE || len > TASK_SIZE-start) - return -EINVAL; + mas_set_range(mas_detach, vma->vm_start, vma->vm_end - 1); + if (mas_store_gfp(mas_detach, vma, GFP_KERNEL)) + return -ENOMEM; - len = PAGE_ALIGN(len); - end = start + len; - if (len == 0) - return -EINVAL; + if (vma->vm_flags & VM_LOCKED) + vma->vm_mm->locked_vm -= vma_pages(vma); - /* - * arch_unmap() might do unmaps itself. It must be called - * and finish any rbtree manipulation before this code - * runs and also starts to manipulate the rbtree. - */ - arch_unmap(mm, start, end); + return 0; +} - /* Find the first overlapping VMA where start < vma->vm_end */ - vma = find_vma_intersection(mm, start, end); - if (!vma) - return 0; - prev = vma->vm_prev; +/* + * do_mas_align_munmap() - munmap the aligned region from @start to @end. + * @mas: The maple_state, ideally set up to alter the correct tree location. + * @vma: The starting vm_area_struct + * @mm: The mm_struct + * @start: The aligned start address to munmap. + * @end: The aligned end address to munmap. + * @uf: The userfaultfd list_head + * @downgrade: Set to true to attempt a write downgrade of the mmap_sem + * + * If @downgrade is true, check return code for potential release of the lock. + */ +static int +do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma, + struct mm_struct *mm, unsigned long start, + unsigned long end, struct list_head *uf, bool downgrade) +{ + struct vm_area_struct *prev, *next = NULL; + struct maple_tree mt_detach; + int count = 0; + int error = -ENOMEM; + MA_STATE(mas_detach, &mt_detach, 0, 0); + mt_init_flags(&mt_detach, MT_FLAGS_LOCK_EXTERN); + mt_set_external_lock(&mt_detach, &mm->mmap_lock); + + if (mas_preallocate(mas, vma, GFP_KERNEL)) + return -ENOMEM; + mas->last = end - 1; /* * If we need to split any vma, do it now to save pain later. * @@ -2788,8 +2332,9 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len, * unmapped vm_area_struct will remain in use: so lower split_vma * places tmp vma above, and higher split_vma places tmp vma below. */ + + /* Does it split the first one? */ if (start > vma->vm_start) { - int error; /* * Make sure that map_count on return from munmap() will @@ -2797,22 +2342,61 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len, * its limit temporarily, to help free resources as expected. */ if (end < vma->vm_end && mm->map_count >= sysctl_max_map_count) - return -ENOMEM; + goto map_count_exceeded; + /* + * mas_pause() is not needed since mas->index needs to be set + * differently than vma->vm_end anyways. + */ error = __split_vma(mm, vma, start, 0); if (error) - return error; - prev = vma; + goto start_split_failed; + + mas_set(mas, start); + vma = mas_walk(mas); } - /* Does it split the last one? */ - last = find_vma(mm, end); - if (last && end > last->vm_start) { - int error = __split_vma(mm, last, end, 1); + prev = mas_prev(mas, 0); + if (unlikely((!prev))) + mas_set(mas, start); + + /* + * Detach a range of VMAs from the mm. Using next as a temp variable as + * it is always overwritten. + */ + mas_for_each(mas, next, end - 1) { + /* Does it split the end? */ + if (next->vm_end > end) { + struct vm_area_struct *split; + + error = __split_vma(mm, next, end, 1); + if (error) + goto end_split_failed; + + mas_set(mas, end); + split = mas_prev(mas, 0); + error = munmap_sidetree(split, &mas_detach); + if (error) + goto munmap_sidetree_failed; + + count++; + if (vma == next) + vma = split; + break; + } + error = munmap_sidetree(next, &mas_detach); if (error) - return error; + goto munmap_sidetree_failed; + + count++; +#ifdef CONFIG_DEBUG_VM_MAPLE_TREE + BUG_ON(next->vm_start < start); + BUG_ON(next->vm_start > end); +#endif } - vma = vma_next(mm, prev); + + if (!next) + next = mas_next(mas, ULONG_MAX); if (unlikely(uf)) { /* @@ -2824,30 +2408,372 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len, * split, despite we could. This is unlikely enough * failure that it's not worth optimizing it for. */ - int error = userfaultfd_unmap_prep(vma, start, end, uf); + error = userfaultfd_unmap_prep(mm, start, end, uf); + if (error) - return error; + goto userfaultfd_error; + } + + /* Point of no return */ + mas_set_range(mas, start, end - 1); +#if defined(CONFIG_DEBUG_VM_MAPLE_TREE) + /* Make sure no VMAs are about to be lost. */ + { + MA_STATE(test, &mt_detach, start, end - 1); + struct vm_area_struct *vma_mas, *vma_test; + int test_count = 0; + + rcu_read_lock(); + vma_test = mas_find(&test, end - 1); + mas_for_each(mas, vma_mas, end - 1) { + BUG_ON(vma_mas != vma_test); + test_count++; + vma_test = mas_next(&test, end - 1); + } + rcu_read_unlock(); + BUG_ON(count != test_count); + mas_set_range(mas, start, end - 1); + } +#endif + mas_store_prealloc(mas, NULL); + mm->map_count -= count; + /* + * Do not downgrade mmap_lock if we are next to VM_GROWSDOWN or + * VM_GROWSUP VMA. Such VMAs can change their size under + * down_read(mmap_lock) and collide with the VMA we are about to unmap. + */ + if (downgrade) { + if (next && (next->vm_flags & VM_GROWSDOWN)) + downgrade = false; + else if (prev && (prev->vm_flags & VM_GROWSUP)) + downgrade = false; + else + mmap_write_downgrade(mm); } - /* Detach vmas from rbtree */ - if (!detach_vmas_to_be_unmapped(mm, vma, prev, end)) - downgrade = false; + unmap_region(mm, &mt_detach, vma, prev, next, start, end); + /* Statistics and freeing VMAs */ + mas_set(&mas_detach, start); + remove_mt(mm, &mas_detach); + __mt_destroy(&mt_detach); - if (downgrade) - mmap_write_downgrade(mm); - unmap_region(mm, vma, prev, start, end); + validate_mm(mm); + return downgrade ? 1 : 0; - /* Fix up all other VM information */ - remove_vma_list(mm, vma); +userfaultfd_error: +munmap_sidetree_failed: +end_split_failed: + __mt_destroy(&mt_detach); +start_split_failed: +map_count_exceeded: + mas_destroy(mas); + return error; +} - return downgrade ? 1 : 0; +/* + * do_mas_munmap() - munmap a given range. + * @mas: The maple state + * @mm: The mm_struct + * @start: The start address to munmap + * @len: The length of the range to munmap + * @uf: The userfaultfd list_head + * @downgrade: set to true if the user wants to attempt to write_downgrade the + * mmap_sem + * + * This function takes a @mas that is either pointing to the previous VMA or set + * to MA_START and sets it up to remove the mapping(s). The @len will be + * aligned and any arch_unmap work will be preformed. + * + * Returns: -EINVAL on failure, 1 on success and unlock, 0 otherwise. + */ +int do_mas_munmap(struct ma_state *mas, struct mm_struct *mm, + unsigned long start, size_t len, struct list_head *uf, + bool downgrade) +{ + unsigned long end; + struct vm_area_struct *vma; + + if ((offset_in_page(start)) || start > TASK_SIZE || len > TASK_SIZE-start) + return -EINVAL; + + end = start + PAGE_ALIGN(len); + if (end == start) + return -EINVAL; + + /* arch_unmap() might do unmaps itself. */ + arch_unmap(mm, start, end); + + /* Find the first overlapping VMA */ + vma = mas_find(mas, end - 1); + if (!vma) + return 0; + + return do_mas_align_munmap(mas, vma, mm, start, end, uf, downgrade); } +/* do_munmap() - Wrapper function for non-maple tree aware do_munmap() calls. + * @mm: The mm_struct + * @start: The start address to munmap + * @len: The length to be munmapped. + * @uf: The userfaultfd list_head + */ int do_munmap(struct mm_struct *mm, unsigned long start, size_t len, struct list_head *uf) { - return __do_munmap(mm, start, len, uf, false); + MA_STATE(mas, &mm->mm_mt, start, start); + + return do_mas_munmap(&mas, mm, start, len, uf, false); +} + +unsigned long mmap_region(struct file *file, unsigned long addr, + unsigned long len, vm_flags_t vm_flags, unsigned long pgoff, + struct list_head *uf) +{ + struct mm_struct *mm = current->mm; + struct vm_area_struct *vma = NULL; + struct vm_area_struct *next, *prev, *merge; + pgoff_t pglen = len >> PAGE_SHIFT; + unsigned long charged = 0; + unsigned long end = addr + len; + unsigned long merge_start = addr, merge_end = end; + pgoff_t vm_pgoff; + int error; + MA_STATE(mas, &mm->mm_mt, addr, end - 1); + + /* Check against address space limit. */ + if (!may_expand_vm(mm, vm_flags, len >> PAGE_SHIFT)) { + unsigned long nr_pages; + + /* + * MAP_FIXED may remove pages of mappings that intersects with + * requested mapping. Account for the pages it would unmap. + */ + nr_pages = count_vma_pages_range(mm, addr, end); + + if (!may_expand_vm(mm