"""
CVE-2026-26198 — Patched Implementation

This module demonstrates the fix: validate column names against the model's
known fields BEFORE passing them to any SQL construction.

The fix applies the same validation that sum()/avg() already had to min() and
max(), ensuring consistency across all aggregate methods.

Column names are SQL identifiers. Identifiers cannot be bound as query
parameters (placeholders only bind values), so validation is what keeps them
safe. Two layers are applied:

1. Whitelist validation against the model's fields (primary; simple and
   definitive).
2. Identifier-syntax check on every name segment (defense-in-depth; still
   rejects injection payloads if the whitelist is misconfigured).
"""

import re
import sqlite3
from contextlib import closing
from typing import Any

# A plain, unquoted SQL identifier. Always used with ``fullmatch``: the older
# ``match(r"^...$")`` form is unsafe here because ``$`` also matches just
# before a trailing newline, so it would accept e.g. "price\n".
_IDENTIFIER_RE = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")

# Aggregate SQL functions this query set is allowed to emit.
_AGGREGATE_FUNCTIONS = frozenset({"max", "min", "sum", "avg"})


class ColumnValidationError(Exception):
    """Raised when a column name fails validation."""


class PatchedQuerySet:
    """
    Fixed QuerySet that validates all column names before query construction.

    The key change: every aggregate method routes through _validate_column()
    before the column name can reach any SQL expression.
    """

    def __init__(self, db_path: str, table_name: str, model_fields: list[str]):
        """
        Args:
            db_path: Database path passed to ``sqlite3.connect`` for each query.
            table_name: Table the aggregates run against. It is interpolated
                into SQL verbatim, so it must come from trusted model
                configuration, never from user input.
            model_fields: Whitelist of column names callers may aggregate.
        """
        self.db_path = db_path
        self.table_name = table_name
        # Copy, so later mutation of the caller's list cannot silently widen
        # the whitelist.
        self.model_fields = list(model_fields)

    def _execute(self, sql: str) -> Any:
        """
        Run *sql* on a fresh connection and return a single scalar.

        Returns:
            The first column of the first result row, or None if the query
            produced no row.
        """
        # closing() guarantees the connection is closed even if execute raises.
        with closing(sqlite3.connect(self.db_path)) as conn:
            row = conn.execute(sql).fetchone()
        return row[0] if row else None

    def _validate_column(self, column: str) -> str:
        """
        Validate that *column* refers to an actual model field.

        This is the core of the fix. The approach is a WHITELIST: only column
        names that exist in the model's field definitions are allowed.

        Why whitelist instead of blacklist?
        - Blacklisting SQL keywords is brittle (new keywords, encoding tricks).
        - Whitelisting is simple and definitive: if it's not a known field,
          reject it.
        - This matches what Ormar's sum()/avg() already did with is_numeric.

        Ormar-style "prefix__fieldname" names are accepted. Every segment,
        including the prefix that is dropped, must be a plain SQL identifier.
        Only the final segment is checked against the whitelist and returned.

        Args:
            column: Column name as supplied by the caller (untrusted).

        Returns:
            The bare field name, safe to interpolate into SQL.

        Raises:
            ColumnValidationError: if *column* is not a string, if any
                ``__``-separated segment is not a valid identifier, or if the
                field name is not in ``model_fields``.
        """
        if not isinstance(column, str):
            raise ColumnValidationError(
                f"Invalid column name: {column!r}. "
                f"Column names must be strings."
            )

        # Strip optional table prefix (Ormar uses "tablename__fieldname" syntax).
        parts = column.split("__")
        field_name = parts[-1]

        # Defense-in-depth: every segment must be a valid SQL identifier. This
        # catches injection attempts even if the whitelist is misconfigured.
        # The discarded prefix is checked too, so payloads such as
        # "(SELECT ...)__price" are rejected rather than silently accepted.
        if not all(_IDENTIFIER_RE.fullmatch(part) for part in parts):
            raise ColumnValidationError(
                f"Invalid column name: {column!r}. "
                f"Column names must be valid identifiers."
            )

        # Primary defense: whitelist against known model fields.
        if field_name not in self.model_fields:
            raise ColumnValidationError(
                f"Unknown column: {column!r}. "
                f"Valid columns are: {', '.join(self.model_fields)}"
            )
        return field_name

    def _aggregate(self, func: str, column: str) -> Any:
        """
        Validate *column*, then return ``func(column)`` over the table.

        Args:
            func: One of the names in ``_AGGREGATE_FUNCTIONS``. It is
                interpolated into SQL, so it is checked here as well.
            column: Untrusted column name; see _validate_column().

        Raises:
            ValueError: if *func* is not a supported aggregate.
            ColumnValidationError: if *column* fails validation.
        """
        if func not in _AGGREGATE_FUNCTIONS:
            raise ValueError(f"Unsupported aggregate function: {func!r}")
        safe_column = self._validate_column(column)
        sql = f"SELECT {func}({safe_column}) FROM {self.table_name}"
        return self._execute(sql)

    # --- ALL aggregate methods validate before query construction ---

    def max(self, column: str) -> Any:
        """
        PATCHED: Return the maximum of *column*, validating it first.

        Only known field names can reach the SQL layer. Injection payloads
        like "(SELECT password FROM users)" are rejected with
        ColumnValidationError because they are not valid identifiers and do
        not match any field in the model.
        """
        return self._aggregate("max", column)

    def min(self, column: str) -> Any:
        """PATCHED: Return the minimum of *column*; same validation as max()."""
        return self._aggregate("min", column)

    def sum(self, column: str) -> Any:
        """Return the sum of *column*, using the shared _validate_column()."""
        return self._aggregate("sum", column)

    def avg(self, column: str) -> Any:
        """Return the average of *column*, using the shared _validate_column()."""
        return self._aggregate("avg", column)