[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "tabpfn"
version = "7.1.1"
dependencies = [
    "torch>=2.5",
    "numpy>=1.21.6",
    "scikit-learn>=1.2.0",
    "typing_extensions>=4.12.0",
    "scipy>=1.11.1",
    "pandas>=1.4.0",
    "einops>=0.4.0",
    "huggingface-hub>=0.19.0",
    "pydantic>=2.8.0",
    "pydantic-settings>=2.10.1",
    # eval-type-backport is required on Python 3.9 to enable support for "X | Y"
    # notation for union types in Pydantic. Python 3.10+ supports that syntax
    # natively, so the environment marker keeps it out of newer installs.
    # Once Python 3.10 is the minimum version, this can be removed.
    "eval-type-backport>=0.2.2; python_version < '3.10'",
    "joblib>=1.2.0",
    "tqdm>=4.66.0",
    "tabpfn-common-utils[telemetry-interactive]>=0.2.13",
    "filelock>=3.11.0",
    "lightgbm>=3.0",
]
requires-python = ">=3.9"
authors = [
    { name = "Noah Hollmann" },
    { name = "Samuel Müller" },
    { name = "Lennart Purucker" },
    { name = "Arjun Krishnakumar" },
    { name = "Max Körfer" },
    { name = "Shi Bin Hoo" },
    { name = "Robin Tibor Schirrmeister" },
    { name = "Frank Hutter" },
    # Huge thanks to code refactoring contributor Eddie
    { name = "Eddie Bergman" },
    # Prior Labs Contributors
    { name = "Leo Grinsztajn" },
    { name = "Felix Jabloski" },
    { name = "Klemens Flöge" },
    { name = "Oscar Key" },
    { name = "Felix Birkel" },
    { name = "Philipp Jund" },
    { name = "Brendan Roof" },
    { name = "Dominik Safaric" },
    { name = "Benjamin Jaeger" },
    { name = "Alan Arazi" },
]
readme = "README.md"
description = "TabPFN: Foundation model for tabular data"
classifiers = [
    "Intended Audience :: Science/Research",
    "Intended Audience :: Developers",
    "Programming Language :: Python",
    "Topic :: Software Development",
    "Topic :: Scientific/Engineering",
    "Operating System :: POSIX",
    "Operating System :: Unix",
    "Operating System :: MacOS",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
]
license = { file = "LICENSE" }
[project.optional-dependencies]
wandb = ["wandb>=0.25.1"]

[project.urls]
documentation = "https://priorlabs.ai/docs"
source = "https://github.com/priorlabs/tabpfn"

[dependency-groups]
dev = [
    { include-group = "ci" },
    # Lint/format
    "pre-commit>=4.3.0",
    "ruff==0.14.0",  # This version must be the same as in .pre-commit-config.yaml
    "mypy==1.19.1",  # This version must be the same as in .pre-commit-config.yaml
    # Test
    "pytest-xdist>=3.8.0",
    # Changelog
    "towncrier>=24.8.0",
    # Docs
    "mkdocs>=1.6.1",
    "mkdocs-material>=9.6.21",
    "mkdocs-autorefs>=1.4.3",
    "mkdocs-gen-files>=0.5.0",
    "mkdocs-literate-nav>=0.6.2",
    "mkdocs-glightbox>=0.5.1",
    "mkdocstrings[python]>=0.30.1",
    "markdown-exec[ansi]>=1.11.0",
    "mike>=2.1.3",
    # We use Ruff for formatting, but black lets mkdocstrings format signatures
    # in the docs.
    "black>=25.9.0",
]
# The minimum subset of the dev dependencies required to run the tests on CI.
# The idea is to stay as close to the deployment environment as possible.
ci = [
    "licensecheck>=2025.1.0",
    "onnx>=1.19.0",  # We run ONNX export in the tests, but not in the production package.
    "pytest-mock>=3.14.1",
    "pytest>=8.4.2",
]

[tool.uv]
exclude-newer = "7 days"
exclude-newer-package = { torch = false }

[tool.setuptools.package-data]
"tabpfn.architectures.shared" = ["tabpfn_col_embedding.pt"]

[tool.pytest.ini_options]
testpaths = ["tests"]  # Where the tests are located
minversion = "8.0"
empty_parameter_set_mark = "xfail"  # Prevents user error of an empty `parametrize` of a test
log_cli = false
log_level = "DEBUG"
xfail_strict = true
addopts = "--durations=10 -vv"
markers = [
    "slow: slow tests that we run only on merge, rather than in PRs",
]

# https://github.com/astral-sh/ruff
[tool.ruff]
target-version = "py39"
line-length = 88
output-format = "full"
src = ["src", "tests", "examples"]
extend-exclude = ["**/*.ipynb"]

[tool.ruff.lint]
# Extend what ruff is allowed to fix, even if a fix may break code.
# This is okay given we run it all the time and it enforces better practices,
# but it would be dangerous when running ruff for the first time on an
# established project.
extend-safe-fixes = ["ALL"]
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
select = [
    "A",
    "ANN",  # type annotations
    "ARG",
    "B",
    "BLE",
    "COM",
    "C4",
    "D",
    # "DTZ",  # One day I should know how to utilize timezones and dates...
    "E",
    # "EXE",  # Meh
    "ERA",
    "F",
    "FBT",
    "I",
    # "ISC",  # Favours implicit string concatenation
    "INP",
    # "INT",  # I don't understand this one
    "N",
    "NPY",
    "PD",
    "PLC",
    "PLE",
    "PLR",
    "PLW",
    "PIE",
    "PT",
    "PTH",
    # "PYI",  # Specific to .pyi files for type stubs
    "Q",
    "PGH004",
    "RET",
    "RUF",
    "C90",
    "S",
    # "SLF",  # Private member accessed (sure, it's python)
    "SIM",
    # "TRY",  # Good in principle, would take a lot of work to satisfy
    "T10",
    "T20",
    "TID",
    # flake8-type-checking. Ruff remapped this prefix from "TCH" to "TC";
    # "TC003" in `ignore` below already uses the new prefix.
    "TC",
    "UP",
    "W",
    "YTT",
]
ignore = [
    "ERA001",   # commented code?
    "D104",     # Missing docstring in public package
    "D105",     # Missing docstring in magic method
    "D203",     # 1 blank line required before class docstring
    "D205",     # 1 blank line between summary and description
    "D401",     # First line of docstring should be in imperative mood
    "N806",     # Variable X in function should be lowercase
    "E731",     # Do not assign a lambda expression, use a def
    "A002",     # Shadowing a builtin
    "A003",     # Shadowing a builtin
    "S101",     # Use of assert detected.
    "W292",     # No newline at end of file
    "PLC1901",  # "" can be simplified to be falsey
    "TC003",    # Move stdlib import into TYPE_CHECKING
    "PLR2004",  # Magic numbers, gets in the way a lot
    "PLR0915",  # Too many statements
    "N803",     # Argument name `X` should be lowercase
    "N802",     # Function name should be lowercase
    "COM812",   # Trailing comma missing (conflicts with formatter)
    "ANN002",   # No type annotation for *args is okay.
    "ANN003",   # No type annotation for **kwargs is okay.
    "ANN204",   # No type annotation for __init__ is okay.
    "ANN401",   # We do allow Any annotations, but use responsibly.
]
exclude = [
    ".bzr",
    ".direnv",
    ".eggs",
    ".git",
    ".hg",
    ".mypy_cache",
    ".nox",
    ".pants.d",
    ".ruff_cache",
    ".svn",
    ".tox",
    ".venv",
    "__pypackages__",
    "_build",
    "buck-out",
    "build",
    "dist",
    "node_modules",
    "venv",
    "docs",
    "**/*.ipynb",
]

[tool.ruff.lint.per-file-ignores]
# These files are copied without changes from external repositories, so ignore all rules.
"src/tabpfn/misc/debug_versions.py" = ["ALL"]
"src/tabpfn/misc/_sklearn_compat.py" = ["ALL"]
"tests/*.py" = [
    "S101",
    "D101",
    "D102",
    "D103",
    "ANN001",
    "ANN201",
    "FBT001",
    "D100",
    "PLR2004",
    "PD901",  # X is a bad variable name. (pandas)
    "TC",
    "N803",
    "C901",   # Too complex
]
"__init__.py" = ["I002"]
"examples/*" = [
    "INP001",
    "I002",
    "E741",
    "D101",
    "D103",
    "T20",
    "D415",
    "ERA001",
    "E402",
    "E501",
    "BLE001",
]
"docs/*" = ["INP001"]
"src/tabpfn/architectures/base/*.py" = [
    # Documentation
    "D100",
    "D101",
    "D102",
    "D103",
    "D107",
]
# TODO(eddiebergman): There's a lot of typing and ruff problems detected here
"src/tabpfn/architectures/base/multi_head_attention.py" = []
"src/tabpfn/architectures/base/encoders.py" = [
    "PT018",
    "ARG002",
    "E501",
    "ERA001",
    "F821",
    "FBT001",
    "FBT002",
    "A001",
]
"src/tabpfn/architectures/base/preprocessing.py" = ["ANN"]
"src/tabpfn/model_loading.py" = ["C901"]
"src/tabpfn/*.py" = ["D107"]

[tool.ruff.lint.isort]
known-first-party = ["tabpfn"]
known-third-party = ["sklearn"]
no-lines-before = ["future"]
required-imports = ["from __future__ import annotations"]
combine-as-imports = true
extra-standard-library = ["typing_extensions"]
force-wrap-aliases = true

[tool.ruff.lint.pydocstyle]
convention = "google"
ignore-decorators = ["typing.override", "typing_extensions.override", "typing.overload"]

[tool.ruff.lint.pylint]
max-args = 10  # Changed from default of 5

[tool.mypy]
python_version = "3.9"
packages = ["src/tabpfn", "tests"]
show_error_codes = true
warn_unused_configs = true  # warn about unused [tool.mypy] lines
follow_imports = "normal"  # Type check top level api code we use from imports
ignore_missing_imports = false  # prefer explicit ignores
disallow_untyped_defs = true  # All functions must have types
disallow_untyped_decorators = true  # ... even decorators
disallow_incomplete_defs = true  # ...all types
allow_redefinition = true  # Allow redefining types within a scope
no_implicit_optional = true
check_untyped_defs = true
warn_return_any = true

[[tool.mypy.overrides]]
module = ["tests.*"]
disallow_untyped_defs = false  # Sometimes we just want to ignore verbose types
disallow_untyped_decorators = false  # Test decorators are not properly typed
disallow_incomplete_defs = false  # Sometimes we just want to ignore verbose types
disable_error_code = ["var-annotated"]

# TODO(eddiebergman): Too much to deal with right now
[[tool.mypy.overrides]]
# NOTE(review): ruff's per-file-ignores target src/tabpfn/architectures/base/encoders.py,
# while this names tabpfn.architectures.encoders — confirm which module is intended.
module = [
    "tabpfn.architectures.base.multi_head_attention",
    "tabpfn.architectures.encoders",
]
ignore_errors = true

[[tool.mypy.overrides]]
module = [
    "sklearn.*",
    "matplotlib.*",
    "einops.*",
    "networkx.*",
    "scipy.*",
    "pandas.*",
    "huggingface_hub.*",
    "joblib.*",
    "torch.*",
    "kditransform.*",
    "lightgbm.*",
]
ignore_missing_imports = true  # TODO: We don't necessarily need this

[tool.pyright]
include = ["src", "tests"]
pythonVersion = "3.9"
typeCheckingMode = "strict"
strictListInference = true
strictSetInference = true
strictDictionaryInference = false
reportImportCycles = false
reportMissingSuperCall = true
reportMissingTypeArgument = false
reportOverlappingOverload = true
reportIncompatibleVariableOverride = true
reportIncompatibleMethodOverride = true
reportInvalidTypeVarUse = true
reportCallInDefaultInitializer = true
reportImplicitOverride = true
reportUnknownMemberType = false
reportUnknownParameterType = false
reportUnknownVariableType = false
reportUnknownArgumentType = false
reportUnknownLambdaType = false
reportPrivateUsage = false
reportUnnecessaryCast = false
reportUnusedFunction = false
reportMissingTypeStubs = false
reportPrivateImportUsage = false
reportUnnecessaryComparison = false
reportConstantRedefinition = false
reportUntypedFunctionDecorator = false

[tool.licensecheck]
# Dependency licenses we accept.
only_licenses = ["APACHE", "MIT", "BSD", "ISC", "PYTHON", "UNLICENSE"]
# Packages that are skipped by the license check, with the reason for each.
ignore_packages = [
    # Uses MPL, but acceptable because we don't modify it.
    # https://github.com/certifi/python-certifi/blob/fb14ac49a976b1695d84b1ac1307276a20b3aac9/LICENSE
    "certifi",
    # Appears to be MIT, but the tool doesn't know.
    # https://github.com/arogozhnikov/einops/blob/361b11e87da94ead4bd09de636c5dbed73e0e3e0/LICENSE
    "einops",
    # Apache-licensed since 1.2, but the tool doesn't know.
    # https://github.com/calvinmccarter/kditransform/blob/5ee7cfad665bb1078211c0becad8fbd31e78429d/LICENSE
    "kditransform",
    "nvidia*",
    # Our own packages.
    "tabpfn*",
    # LICENSEREF-NVIDIA-SOFTWARE-LICENSE, introduced by torch>=2.10.
    "cuda-bindings",
    "cuda-toolkit",
]

[tool.changelog-bot.towncrier_changelog]
enabled = true
changelog_skip_label = "no changelog needed"

[tool.towncrier]
directory = "changelog"
filename = "CHANGELOG.md"
start_string = "## [Unreleased]\n"
title_format = "## [{version}] - {project_date}"
issue_format = "[#{issue}](https://github.com/PriorLabs/TabPFN/pull/{issue})"
# Changelog fragment categories; the list order is the order sections are rendered.
type = [
    { directory = "breaking", name = "Breaking Changes", showcontent = true },
    { directory = "added", name = "Added", showcontent = true },
    { directory = "changed", name = "Changed", showcontent = true },
    { directory = "fixed", name = "Fixed", showcontent = true },
    { directory = "deprecated", name = "Deprecated", showcontent = true },
]
# A single unnamed section rooted at the fragment directory.
section = [{ path = "", name = "" }]