diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py
index 853c13b13..2314f350e 100644
--- a/vllm/model_executor/models/bailing_moe.py
+++ b/vllm/model_executor/models/bailing_moe.py
@@ -32,6 +32,7 @@ from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 
 from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -66,6 +67,7 @@ class BailingAttention(nn.Module):
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
         prefix: str = "",
     ):
         super().__init__()
@@ -82,10 +84,11 @@ class BailingAttention(nn.Module):
         self.head_dim = config.head_dim or (self.hidden_size //
                                             self.total_num_heads)
         self.q_size_per_rank = self.head_dim * self.num_heads
-
         self.num_kv_heads = self.total_kv_heads // tp_size
         self.kv_size_per_rank = self.num_kv_heads * self.head_dim
         self.scale = self.head_dim**-0.5
+        self.use_qk_norm = getattr(config, "use_qk_norm", False)
+        self.use_rmsnorm = getattr(config, "use_rmsnorm", False)
 
         self.query_key_value = QKVParallelLinear(
             self.hidden_size,
@@ -97,30 +100,46 @@ class BailingAttention(nn.Module):
             prefix=f"{prefix}.query_key_value",
         )
 
+        if self.use_qk_norm:
+            self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm \
+                else nn.LayerNorm(self.head_dim, eps=1e-6)
+            self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm \
+                else nn.LayerNorm(self.head_dim, eps=1e-6)
+
         self.dense = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=config.use_bias,
             quant_config=quant_config,
+            reduce_results=reduce_results,
             prefix=f"{prefix}.dense",
         )
 
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scale,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              prefix=f"{prefix}.attn")
+        if hasattr(config, "partial_rotary_factor"):
+            self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
+        elif hasattr(config, "rotary_dim"):
+            self.rotary_dim = config.rotary_dim
+        else:
+            self.rotary_dim = self.head_dim
 
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.head_dim,
+            rotary_dim=self.rotary_dim,
             max_position=config.max_position_embeddings,
             base=config.rope_theta,
             is_neox_style=True,
             rope_scaling=config.rope_scaling,
         )
 
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            prefix=f"{prefix}.attn",
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -133,6 +152,14 @@ class BailingAttention(nn.Module):
         ],
                             dim=-1)
 
+        if self.use_qk_norm:
+            q = q.view(-1, self.num_heads, self.head_dim)
+            k = k.view(-1, self.num_kv_heads, self.head_dim)
+            q = self.query_layernorm(q)
+            k = self.key_layernorm(k)
+            q = q.view(-1, self.q_size_per_rank)
+            k = k.view(-1, self.kv_size_per_rank)
+
         q, k = self.rotary_emb(position_ids, q, k)
 
         context_layer = self.attn(q, k, v)
@@ -196,44 +223,95 @@ class BailingMoE(nn.Module):
         self.hidden_size = config.hidden_size
         self.quant_config = quant_config
         self.num_shared_experts = config.num_shared_experts
-        # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_experts,
-                                     bias=False,
-                                     quant_config=None)
-
-        self.experts = FusedMoE(num_experts=self.num_experts,
-                                top_k=self.top_k,
-                                hidden_size=self.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=self.norm_expert_prob,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.experts")
+        self.score_function = getattr(config, "score_function", None)
+        self.n_group = getattr(config, "n_group", None)
+        self.topk_group = getattr(config, "topk_group", None)
+        self.use_grouped_topk = (self.n_group is not None
+                                 and self.topk_group is not None)
+        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
+
+        router_dtype = getattr(config, "router_dtype", None)
+        if router_dtype is None:
+            self.router_dtype = None
+        elif router_dtype == "fp32":
+            self.router_dtype = torch.float32
+        else:
+            self.router_dtype = torch.bfloat16
+
+        self.gate = ReplicatedLinear(
+            self.hidden_size,
+            self.num_experts,
+            bias=False,
+            quant_config=None,
+            params_dtype=self.router_dtype,
+        )
+
+        if getattr(config, "moe_router_enable_expert_bias", False):
+            self.gate.expert_bias = nn.Parameter(torch.empty((config.num_experts,), dtype=torch.float32))
+        else:
+            self.gate.expert_bias = None
+
+        self.correction_bias = (
+            self.gate.expert_bias.data if self.gate.expert_bias is not None else None
+        )
+
+        if self.score_function is not None:
+            assert (
+                self.score_function == "softmax" and self.correction_bias is None
+            ) or (
+                self.score_function == "sigmoid" and self.correction_bias is not None
+            ), "score_function and correction_bias should be one of the two combinations: (softmax, None) or (sigmoid, not None)"
+        else:
+            # default value for scoring_func
+            self.score_function = "softmax"
+
+        self.experts = FusedMoE(
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            hidden_size=self.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=self.norm_expert_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            scoring_func=self.score_function,
+            e_score_correction_bias=self.gate.expert_bias,
+            num_expert_group=self.n_group,
+            topk_group=self.topk_group,
+            use_grouped_topk=self.use_grouped_topk,
+        )
 
         if self.num_shared_experts > 0:
-            intermediate_size = (config.moe_intermediate_size *
-                                 self.num_shared_experts)
+            if hasattr(config, "moe_shared_expert_intermediate_size"):
+                intermediate_size = config.moe_shared_expert_intermediate_size
+            else:
+                intermediate_size = config.moe_intermediate_size
+                intermediate_size *= config.num_shared_experts
             self.shared_experts = BailingMLP(
                 intermediate_size=intermediate_size,
                 config=config,
                 quant_config=quant_config,
                 reduce_results=False,
-                prefix=f"{prefix}.shared_experts")
+                prefix=f"{prefix}.shared_experts"
+            )
         else:
             self.shared_experts = None
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_size)
-        if self.num_shared_experts > 0:
+        if self.shared_experts:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits, _ = self.gate(hidden_states.to(self.router_dtype))
+        router_logits = router_logits.to(hidden_states.dtype)
+
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
 
-        if self.num_shared_experts > 0:
+        final_hidden_states *= self.routed_scaling_factor
+
+        if self.shared_experts:
             final_hidden_states = final_hidden_states + shared_output
 
         if self.tp_size > 1:
@@ -252,20 +330,28 @@ class BailingMoeBlock(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
+        layer_idx = int(prefix.split('.')[-1])
+        self.config = config
         hidden_size = config.hidden_size
         intermediate_size = config.intermediate_size
+
         self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
-        self.attention = BailingAttention(config,
-                                          cache_config,
-                                          quant_config,
-                                          prefix=f"{prefix}.attention")
-        self.post_attention_layernorm = RMSNorm(hidden_size,
-                                                eps=config.rms_norm_eps)
-        self.mlp = BailingMoE(intermediate_size,
-                              config,
-                              quant_config,
-                              True,
-                              prefix=f"{prefix}.mlp")
+
+        self.attention = BailingAttention(
+            config,
+            cache_config,
+            quant_config,
+            prefix=f"{prefix}.attention",
+        )
+
+        self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
+
+        # Choose the MLP class based on the layer index (dense first, then MoE)
+        if layer_idx < config.first_k_dense_replace:
+            mlp_class = BailingMLP
+        else:
+            mlp_class = BailingMoE
+        self.mlp = mlp_class(intermediate_size, config, quant_config, True, prefix=f"{prefix}.mlp")
 
     def forward(
         self,
@@ -291,6 +377,7 @@ class BailingMoeBlock(nn.Module):
         return hidden_states, residual
 
 
+@support_torch_compile
 class BailingMoeModel(nn.Module):
 
     def __init__(
@@ -307,11 +394,16 @@ class BailingMoeModel(nn.Module):
         self.config = config
         self.vocab_size = config.vocab_size
         self.embed_dim = config.hidden_size
+        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
 
-        if get_pp_group().is_first_rank or (config.tie_word_embeddings
-                                            and get_pp_group().is_last_rank):
+        if get_pp_group().is_first_rank or (self.tie_word_embeddings and
+                                            get_pp_group().is_last_rank):
             self.word_embeddings = VocabParallelEmbedding(
-                self.vocab_size, self.embed_dim)
+                self.vocab_size,
+                self.embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.word_embeddings",
+            )
         else:
             self.word_embeddings = PPMissingLayer()
 
@@ -325,11 +417,14 @@ class BailingMoeModel(nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
             ),
-            prefix=f"{prefix}.layers")
+            prefix=f"{prefix}.layers"
+        )
 
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
-                ["hidden_states", "residual"], config.hidden_size))
+                ["hidden_states", "residual"], config.hidden_size
+            )
+        )
 
         if get_pp_group().is_last_rank:
             self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
@@ -370,8 +465,11 @@ class BailingMoeModel(nn.Module):
                 "hidden_states": hidden_states,
                 "residual": residual
             })
-
-        hidden_states, _ = self.norm(hidden_states, residual)
+        else:
+            if residual is None:
+                hidden_states = self.norm(hidden_states)
+            else:
+                hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@@ -394,7 +492,11 @@ class BailingMoeModel(nn.Module):
         loaded_params: set[str] = set()
         expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
-            if self.config.norm_head and "lm_head.weight" in name:
+            if (
+                hasattr(self.config, "norm_head")
+                and self.config.norm_head
+                and "lm_head.weight" in name
+            ):
                 loaded_weight = F.normalize(loaded_weight,
                                             dim=0,
                                             p=2,
@@ -428,13 +530,17 @@ class BailingMoeModel(nn.Module):
                 if is_pp_missing_parameter(name, self):
                     continue
 
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(param,
-                              loaded_weight,
-                              name,
-                              shard_id=shard_id,
-                              expert_id=expert_id)
+                weight_loader(
+                    param,
+                    loaded_weight,
+                    name,
+                    shard_id=shard_id,
+                    expert_id=expert_id,
+                )
                 break
             else:
                 if name.endswith(".bias") and name not in params_dict:
@@ -472,24 +578,37 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
         super().__init__()
 
         config = vllm_config.model_config.hf_config
+        if hasattr(config, "llm_config"):
+            config = config.llm_config
+            vllm_config.model_config.hf_config = config
         quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
 
         self.config = config
+        self.lora_config = lora_config
         self.quant_config = quant_config
         self.max_position_embeddings = config.max_position_embeddings
         self.model = BailingMoeModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))
+        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
+
         if get_pp_group().is_last_rank:
-            self.lm_head = (self.word_embeddings if config.tie_word_embeddings
-                            else ParallelLMHead(config.vocab_size,
-                                                config.hidden_size,
-                                                quant_config=quant_config))
+            if self.tie_word_embeddings:
+                self.lm_head = self.model.word_embeddings
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.lm_head",
+                )
             self.logits_processor = LogitsProcessor(config.vocab_size)
         else:
            self.lm_head = PPMissingLayer()
 
         self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
+            self.model.make_empty_intermediate_tensors
+        )
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
@@ -519,9 +638,12 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
+                           if self.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()
+
+class BailingMoeV2ForCausalLM(BailingMoeForCausalLM):
+    pass
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 2aaac7798..fcf295e50 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -43,6 +43,7 @@ _TEXT_GENERATION_MODELS = {
     # baichuan-13b, lower case 'c' in the class name
     "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
     "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
+    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
     "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
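
A quick smoke test for reviewers of the new BailingMoeV2ForCausalLM registry entry. This is a minimal sketch and not part of the patch: the checkpoint path is hypothetical, and it assumes a checkpoint whose config.json declares "architectures": ["BailingMoeV2ForCausalLM"] so the registry resolves it to bailing_moe.

from vllm import LLM, SamplingParams

# Hypothetical checkpoint path; substitute a real BailingMoeV2 checkpoint.
llm = LLM(model="path/to/bailing-moe-v2", trust_remote_code=True)

# Greedy-decode a short prompt to confirm the model loads and generates.
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)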