"""
Mean Reversion Envelope Strategy.

Logic:
    1. Compute a short-term price average (SMA, EMA, or Donchian midline)
    2. Create envelope bands at configured percentage offsets
    3. Enter long when price dips below lower bands; enter short when price
       spikes above upper bands
    4. Exit when price returns to the average
    5. Stop-loss and close-all rules for risk management
"""

import logging
from typing import Optional

import pandas as pd

from src.models import (
    AverageType,
    CoinConfig,
    Position,
    Side,
    Signal,
    SignalType,
    TradeDirection,
)

logger = logging.getLogger("crypto_bot")


class EnvelopeStrategy:
    """Mean reversion envelope strategy implementation.

    Builds a price average with percentage envelopes above and below it and
    turns the latest candle into ENTRY / EXIT / STOP_LOSS / CLOSE_ALL signals.
    """

    def __init__(self, config: CoinConfig):
        """Copy the strategy parameters out of *config*.

        Args:
            config: Per-coin settings. Fields read: ``symbol``,
                ``average_type``, ``average_period``, ``envelopes``,
                ``stop_loss``, ``price_jump_pct``, ``trade_direction``.
        """
        self.config = config
        self.symbol = config.symbol
        self.avg_type = config.average_type
        self.avg_period = config.average_period
        # Ascending percentages: band index 0 is the band nearest the average,
        # higher indices sit further out.
        self.envelopes = sorted(config.envelopes)
        # NOTE(review): percentages/limits appear to be fractions (0.05 == 5%);
        # the reason strings multiply them by 100 for display — confirm.
        self.stop_loss = config.stop_loss
        self.price_jump_pct = config.price_jump_pct  # None disables close-all
        self.trade_direction = config.trade_direction
        self.num_envelopes = len(self.envelopes)

    # ── Price Average Calculation ──────────────────────────────────

    def compute_average(self, df: pd.DataFrame) -> pd.Series:
        """Compute the price average based on the configured type.

        Args:
            df: Candle frame. SMA/EMA use ``close``; Donchian uses
                ``high``/``low``.

        Returns:
            A Series aligned with ``df``'s index. SMA and Donchian are NaN until
            the rolling window is full (first ``avg_period - 1`` rows); the EMA
            (``adjust=False``) produces values from the first close onward.

        Raises:
            ValueError: If ``avg_type`` is not SMA, EMA or DONCHIAN.
        """
        close = df["close"]
        if self.avg_type == AverageType.SMA:
            return close.rolling(window=self.avg_period).mean()
        if self.avg_type == AverageType.EMA:
            return close.ewm(span=self.avg_period, adjust=False).mean()
        if self.avg_type == AverageType.DONCHIAN:
            high_rolling = df["high"].rolling(window=self.avg_period).max()
            low_rolling = df["low"].rolling(window=self.avg_period).min()
            # Midline of the Donchian channel.
            return (high_rolling + low_rolling) / 2
        raise ValueError(f"Unknown average type: {self.avg_type}")

    def compute_envelope_bands(self, average: pd.Series) -> dict:
        """
        Compute upper and lower envelope bands.

        Returns dict with 'upper' and 'lower' lists of Series, one Series per
        configured envelope, in the same (ascending) order as ``self.envelopes``.
        """
        bands = {"upper": [], "lower": []}
        for pct in self.envelopes:
            bands["upper"].append(average * (1 + pct))
            bands["lower"].append(average * (1 - pct))
        return bands

    # ── Signal Generation ──────────────────────────────────────────

    def generate_signals(self, df: pd.DataFrame, position: Optional[Position] = None) -> list:
        """
        Analyze latest candle data and generate trading signals.

        Evaluation order (the first rule that fires wins):
            1. Stop-loss (open position only).
            2. Close-all / price-jump rule (open position, if configured).
            3. Normal exit when price has returned to the average.
            4. Re-entry block: may clear ``position.blocked_reentry`` once price
               is back across the average; otherwise stays blocked.
            5. Entries: additional envelopes for an open position, or new
               entries when flat.

        Rules 1-2 do not depend on the average, so they are evaluated even when
        the current average is NaN (e.g. a gap inside the rolling window);
        rules 3-5 are skipped in that case.

        Args:
            df: OHLCV DataFrame (needs at least ``avg_period + 2`` rows)
            position: Current open position (if any). Its ``blocked_reentry``
                flag may be mutated in place.

        Returns:
            List of Signal objects (possibly empty).
        """
        if len(df) < self.avg_period + 2:
            logger.warning(
                "Not enough data for %s (need %s, got %s)",
                self.symbol, self.avg_period + 2, len(df),
            )
            return []

        avg = self.compute_average(df)
        bands = self.compute_envelope_bands(avg)

        current_close = df["close"].iloc[-1]
        current_low = df["low"].iloc[-1]
        current_high = df["high"].iloc[-1]
        prev_close = df["close"].iloc[-2]
        current_avg = avg.iloc[-1]

        # ── Risk exits first: these must not be skipped for a NaN average ──
        if position and position.is_open:
            stop_signal = self._check_stop_loss(position, current_close, current_avg)
            if stop_signal:
                return [stop_signal]  # Stop-loss overrides everything

            if self.price_jump_pct is not None:
                close_all_signal = self._check_price_jump(position, prev_close)
                if close_all_signal:
                    return [close_all_signal]

        # Everything below compares price against the average.
        if pd.isna(current_avg):
            return []

        # ── Normal exit: price returned to average ─────────────────
        if position and position.is_open:
            exit_signal = self._check_exit(position, current_close, current_avg)
            if exit_signal:
                return [exit_signal]

        # ── Check re-entry block ───────────────────────────────────
        if position and position.blocked_reentry:
            # Unblock once price has crossed back over the average.
            if position.side == Side.LONG and current_close >= current_avg:
                position.blocked_reentry = False
                logger.info("%s: Re-entry unblocked (price above average)", self.symbol)
            elif position.side == Side.SHORT and current_close <= current_avg:
                position.blocked_reentry = False
                logger.info("%s: Re-entry unblocked (price below average)", self.symbol)

            if position.blocked_reentry:
                return []  # Still blocked

        # ── Check for entry signals ────────────────────────────────
        if position and position.is_open:
            # Already in a position — check for additional envelope entries
            return self._check_additional_entries(
                position, current_low, current_high, bands, current_avg
            )
        # No position — check for new entries
        return self._check_new_entries(current_low, current_high, bands, current_avg)

    # ── Entry Checks ───────────────────────────────────────────────

    def _band_hits(self, side: Side, band_series: list, extreme: float,
                   action: str, start: int = 0) -> list:
        """Return ENTRY signals for each band from *start* on that *extreme* reached.

        Args:
            side: ``Side.LONG`` scans lower bands (hit when ``extreme <= band``);
                any other side scans upper bands (hit when ``extreme >= band``).
            band_series: ``bands["lower"]`` or ``bands["upper"]``.
            extreme: Candle low (long) or high (short).
            action: Word used in the reason text, e.g. ``"entry"`` or ``"add"``.
            start: First band index to consider.
        """
        is_long = side == Side.LONG
        direction = "Long" if is_long else "Short"
        band_name = "lower" if is_long else "upper"

        signals = []
        for i in range(start, len(band_series)):
            band_val = band_series[i].iloc[-1]
            if pd.isna(band_val):
                continue
            hit = extreme <= band_val if is_long else extreme >= band_val
            if hit:
                signals.append(Signal(
                    type=SignalType.ENTRY,
                    side=side,
                    symbol=self.symbol,
                    price=band_val,
                    envelope_index=i,
                    reason=(
                        f"{direction} {action}: price hit {band_name} band {i+1} "
                        f"({self.envelopes[i]*100:.1f}%)"
                    ),
                ))
        return signals

    def _check_new_entries(self, low: float, high: float, bands: dict, avg: float) -> list:
        """Check if price has hit any envelope bands for a new position.

        Every band reached on this candle yields its own signal; long signals
        (if allowed by ``trade_direction``) come before short signals. ``avg``
        is accepted for interface compatibility but not used.
        """
        signals = []
        # Long entries (price dipped below lower bands)
        if self.trade_direction in (TradeDirection.BOTH, TradeDirection.LONG):
            signals.extend(self._band_hits(Side.LONG, bands["lower"], low, "entry"))
        # Short entries (price spiked above upper bands)
        if self.trade_direction in (TradeDirection.BOTH, TradeDirection.SHORT):
            signals.extend(self._band_hits(Side.SHORT, bands["upper"], high, "entry"))
        return signals

    def _check_additional_entries(self, position: Position, low: float, high: float,
                                  bands: dict, avg: float) -> list:
        """Check for additional envelope entries on an existing position.

        Only bands at index ``position.envelopes_hit`` and beyond are scanned
        (presumably the count of envelopes already filled — confirm against
        ``Position``). Returns ``[]`` for a side other than LONG/SHORT. ``avg``
        is accepted for interface compatibility but not used.
        """
        already_hit = position.envelopes_hit
        if position.side == Side.LONG:
            return self._band_hits(Side.LONG, bands["lower"], low, "add", start=already_hit)
        if position.side == Side.SHORT:
            return self._band_hits(Side.SHORT, bands["upper"], high, "add", start=already_hit)
        return []

    # ── Exit Checks ────────────────────────────────────────────────

    def _check_exit(self, position: Position, close: float, avg: float) -> Optional[Signal]:
        """Check if price has returned to the average for a normal exit.

        A long exits when ``close >= avg``; a short when ``close <= avg``. The
        exit signal is priced at the average.
        """
        returned = (
            (position.side == Side.LONG and close >= avg)
            or (position.side == Side.SHORT and close <= avg)
        )
        if not returned:
            return None
        return Signal(
            type=SignalType.EXIT,
            side=position.side,
            symbol=self.symbol,
            price=avg,
            reason="Exit: price returned to average",
        )

    def _check_stop_loss(self, position: Position, close: float, avg: float) -> Optional[Signal]:
        """Check if stop-loss should trigger.

        Loss is measured from ``position.avg_entry_price`` to the current close
        (adverse direction positive). Skipped when the entry price is 0.
        ``avg`` is accepted for interface compatibility but not used.
        """
        if position.avg_entry_price == 0:
            return None

        if position.side == Side.LONG:
            loss_pct = (position.avg_entry_price - close) / position.avg_entry_price
        else:
            loss_pct = (close - position.avg_entry_price) / position.avg_entry_price

        if loss_pct >= self.stop_loss:
            return Signal(
                type=SignalType.STOP_LOSS,
                side=position.side,
                symbol=self.symbol,
                price=close,
                reason=f"Stop-loss triggered at {loss_pct*100:.1f}% loss (limit: {self.stop_loss*100:.1f}%)",
            )
        return None

    def _check_price_jump(self, position: Position, prev_close: float) -> Optional[Signal]:
        """Check if the close-all/price-jump rule should trigger.

        Uses the *previous* candle's close, measured adversely from
        ``position.avg_entry_price``. Skipped when the entry price is 0.
        """
        if position.avg_entry_price == 0:
            return None

        if position.side == Side.LONG:
            jump_pct = (position.avg_entry_price - prev_close) / position.avg_entry_price
        else:
            jump_pct = (prev_close - position.avg_entry_price) / position.avg_entry_price

        if jump_pct >= self.price_jump_pct:
            return Signal(
                type=SignalType.CLOSE_ALL,
                side=position.side,
                symbol=self.symbol,
                price=prev_close,
                reason=f"Close-all rule: prev candle {jump_pct*100:.1f}% from entry (limit: {self.price_jump_pct*100:.1f}%)",
            )
        return None

    # ── Utilities ──────────────────────────────────────────────────

    def get_position_size(self, capital: float, envelope_index: int) -> float:
        """Calculate position size for a given envelope entry.

        Capital is split equally across envelopes; ``envelope_index`` is
        currently not used. Raises ZeroDivisionError if no envelopes are
        configured.
        """
        return capital / self.num_envelopes

    def get_current_bands(self, df: pd.DataFrame) -> dict:
        """Get current envelope band values (for display/logging).

        Returns ``{"average": float, "upper": [...], "lower": [...]}`` using the
        last row of each series (values may be NaN during warm-up). Raises
        IndexError on an empty ``df``.
        """
        avg = self.compute_average(df)
        bands = self.compute_envelope_bands(avg)

        result = {"average": avg.iloc[-1], "upper": [], "lower": []}
        for i in range(self.num_envelopes):
            result["upper"].append(bands["upper"][i].iloc[-1])
            result["lower"].append(bands["lower"][i].iloc[-1])
        return result