from pathlib import Path
import types
from typing import Optional, Union, Any
import gc

import openvino as ov

try:
    from openvino import opset13
except ImportError:
    from openvino.runtime import opset13
import nncf
import numpy as np
import torch
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor, AutoConfig
from transformers.generation import GenerationConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioCausalLMOutputWithPast


def model_has_state(ov_model: ov.Model):
    return len(ov_model.get_sinks()) > 0


def model_has_input_output_name(ov_model: ov.Model, name: str):
    """
    Helper function for checking that model has specified input or output name

    Parameters:
      ov_model (ov.Model):
          model for checking
      name (str):
          name of input or output

    Returns:
      True if input or output with requested name exists else False
    """
    return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], [])


def fuse_cache_reorder(
    ov_model: ov.Model,
    not_kv_inputs: list[str],
    key_value_input_names: list[str],
    gather_dim: int,
):
    """
    Fuses reorder_cache during generate cycle into ov.Model. Used with stateful models, because we can not modify model state directly.

    Adds a new beam_idx parameter and Gather op per each kv-cache input in a given model.
    Should be run before make_stateful. Implements optimum's _reorder_cache
    inside the model in the beginning of each iteration.
    Gather works along given gather_dim dimension that may vary from model to model.
    KV-cache inputs are identified based on names in key_value_input_names.
    Appends the new beam_idx parameter to not_kv_inputs.

    Parameters:
      ov_model (`ov.Model`):
          openvino model for processing
      not_kv_inputs (`list[str]`):
          list of input nodes in model that are not related to past key values
      key_value_input_names (`list[str]`):
          list of names for key value input layers
      gather_dim (int):
          dimension for gathering cache during reorder pass
    """
    if model_has_input_output_name(ov_model, "beam_idx"):
        raise ValueError("Model already has fused cache")
    input_batch = ov_model.input("inputs_embeds").get_partial_shape()[0]
    beam_idx = opset13.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch]))
    beam_idx.output(0).get_tensor().add_names({"beam_idx"})  # why list is not accepted?
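    # Register the new beam_idx parameter on the model and route every kv-cache input through a
    # Gather node indexed by it, so beam reordering happens inside the graph at each iteration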
    ov_model.add_parameters([beam_idx])
    not_kv_inputs.append(ov_model.inputs[-1])
    # Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
    for input_name in key_value_input_names:
        parameter_output_port = ov_model.input(input_name)
        consumers = parameter_output_port.get_target_inputs()
        gather = opset13.gather(parameter_output_port, beam_idx, opset13.constant(gather_dim))
        for consumer in consumers:
            consumer.replace_source_output(gather.output(0))
    ov_model.validate_nodes_and_infer_types()


def build_state_initializer(ov_model: ov.Model, batch_dim: int):
    """
    Build initialization ShapeOf Expression for all ReadValue ops

    Parameters:
      ov_model (ov.Model):
          openvino model
      batch_dim (int):
          index of dimension corresponding to batch size
    """
    input_ids = ov_model.input("inputs_embeds")
    batch = opset13.gather(
        opset13.shape_of(input_ids, output_type="i64"),
        opset13.constant([0]),
        opset13.constant(0),
    )
    for op in ov_model.get_ops():
        if op.get_type_name() == "ReadValue":
            dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))]
            dims[batch_dim] = batch
            dims = [(opset13.constant(np.array([dim], dtype=np.int64)) if isinstance(dim, int) else dim) for dim in dims]
            shape = opset13.concat(dims, axis=0)
            broadcast = opset13.broadcast(opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape)
            op.set_arguments([broadcast])
    ov_model.validate_nodes_and_infer_types()


def make_stateful(
    ov_model: ov.Model,
    not_kv_inputs: list[str],
    key_value_input_names: list[str],
    key_value_output_names: list[str],
    batch_dim: int,
    num_attention_heads: int,
    num_beams_and_batch: int = None,
):
    """
    Hides kv-cache inputs and outputs inside the model as variables.

    Parameters:
        ov_model (ov.Model):
            openvino model
        not_kv_inputs (`list[str]`):
            list of input nodes in model that are not related to past key values
        key_value_input_names (`list[str]`):
            list of names for key value input layers
        key_value_output_names (`list[str]`):
            list of names for key value output layers
        batch_dim (int):
            index of batch dimension in key value layers
        num_attention_heads (int):
            number of attention heads for batch dimension initialization
        num_beams_and_batch (int):
            precalculated number of beams and batch for shapes initialization
    """
    from openvino._offline_transformations import apply_make_stateful_transformation

    input_output_map = {}

    if num_beams_and_batch is not None:
        # Set batch size for inputs_embeds and attention mask to avoid a dynamic dimension being propagated from the end of the model back to ReadValue
        for input in not_kv_inputs:
            shape = input.get_partial_shape()
            if shape.rank.get_length() <= 2:  # == 1 for beam_index
                shape[0] = num_beams_and_batch
                input.get_node().set_partial_shape(shape)
    for kv_name_pair in zip(key_value_input_names, key_value_output_names):
        input_output_map[kv_name_pair[0]] = kv_name_pair[1]
        if num_beams_and_batch is not None:
            input = ov_model.input(kv_name_pair[0])
            shape = input.get_partial_shape()
            shape[batch_dim] = num_beams_and_batch * num_attention_heads
            input.get_node().set_partial_shape(shape)

    if num_beams_and_batch is not None:
        # Re-validate the model if shapes were altered above
        ov_model.validate_nodes_and_infer_types()

    apply_make_stateful_transformation(ov_model, input_output_map)
    if num_beams_and_batch is None:
        build_state_initializer(ov_model, batch_dim)


def patch_stateful(ov_model):
    key_value_input_names = [key.get_any_name() for key in ov_model.inputs[2:-1]]
    key_value_output_names = [key.get_any_name() for key in ov_model.outputs[1:]]
    not_kv_inputs = [input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())]
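    # Nothing to hide in model state if the exported model exposes no kv-cache inputs/outputs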
    if not key_value_input_names or not key_value_output_names:
        return
    batch_dim = 0
    num_attention_heads = 1

    fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)
    make_stateful(
        ov_model,
        not_kv_inputs,
        key_value_input_names,
        key_value_output_names,
        batch_dim,
        num_attention_heads,
        None,
    )


core = ov.Core()


def cleanup_torchscript_cache():
    """
    Helper for removing cached model representation
    """
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()


LANGUAGE_MODEL_NAME = "openvino_language_model.xml"
AUDIO_EMBEDING_NAME = "openvino_audio_embedding.xml"
MULTIMODAL_PROJECTOR_NAME = "openvino_mulimodal_projection_model.xml"
TEXT_EMBEDDING_NAME = "openvino_text_embedding_model.xml"


def convert_qwen2audio_model(model_id, output_dir, quantization_config):
    output_dir = Path(output_dir)
    lang_model_path = output_dir / LANGUAGE_MODEL_NAME
    audio_embed_path = output_dir / AUDIO_EMBEDING_NAME
    projection_path = output_dir / MULTIMODAL_PROJECTOR_NAME
    embed_token_path = output_dir / TEXT_EMBEDDING_NAME

    if all(
        [
            lang_model_path.exists(),
            audio_embed_path.exists(),
            projection_path.exists(),
            embed_token_path.exists(),
        ]
    ):
        print(f"✅ {model_id} model already converted. You can find results in {output_dir}")
        return
    print(f"⌛ {model_id} conversion started. Be patient, it may take some time.")
    print("⌛ Load Original model")
    model = Qwen2AudioForConditionalGeneration.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model.config.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)
    print("✅ Original model successfully loaded")

    if not embed_token_path.exists():
        print("⌛ Convert Input embedding model")
        ov_model = ov.convert_model(
            model.get_input_embeddings(),
            example_input=torch.ones([2, 2], dtype=torch.int64),
        )
        ov.save_model(ov_model, embed_token_path)
        del ov_model
        cleanup_torchscript_cache()
        gc.collect()
        print("✅ Input embedding model successfully converted")

    if not audio_embed_path.exists():
        print("⌛ Convert Audio embedding model")
        ov_model = ov.convert_model(
            model.audio_tower,
            example_input={
                "input_features": torch.randn([2, 128, 3000]),
                "attention_mask": torch.ones([2, 1, 1500, 1500]),
            },
        )
        ov.save_model(ov_model, audio_embed_path)
        del ov_model
        cleanup_torchscript_cache()
        gc.collect()
        print("✅ Audio embedding model successfully converted")

    if not projection_path.exists():
        print("⌛ Convert Multimodal projector model")
        ov_model = ov.convert_model(
            model.multi_modal_projector,
            example_input=torch.ones([2, 750, 1280]),
        )
        ov.save_model(ov_model, projection_path)
        del ov_model
        cleanup_torchscript_cache()
        gc.collect()
        print("✅ Multimodal projector model successfully converted")

    if not lang_model_path.exists():
        print("⌛ Convert Language model")

        def forward_wrap(
            self,
            attention_mask,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=None,
        ):
            from transformers.cache_utils import DynamicCache

            # Wrap the legacy kv-cache tuples into the DynamicCache expected by the original forward
            pkv = DynamicCache.from_legacy_cache(past_key_values) if past_key_values is not None else None
            result = self._orig_forward(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=pkv,
                inputs_embeds=inputs_embeds,
            )
            return (result.logits, result.past_key_values.to_legacy_cache())

        lang_model = model.language_model
        print(lang_model.config)
        lang_model._orig_forward = lang_model.forward
        lang_model.forward = types.MethodType(forward_wrap, lang_model)
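
        # Dummy inputs below are used only for tracing: a batch of 2 sequences with 2 new token embeddings,
        # a 4-token attention mask (2 cached + 2 new positions) and one random key/value pair per layer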
        hidden_size = lang_model.config.hidden_size
        num_pkv = lang_model.config.num_hidden_layers
        pkv_shape = (2, lang_model.config.num_key_value_heads, 2, hidden_size // lang_model.config.num_attention_heads)
        input_embeds = torch.randn((2, 2, hidden_size))
        attention_mask = torch.ones([2, 4], dtype=torch.long)
        position_ids = torch.tensor([[2, 3], [2, 3]], dtype=torch.long)
        input_names = ["attention_mask", "position_ids"]
        output_names = ["logits"]

        past_key_values = []
        for i in range(num_pkv):
            kv = [torch.randn(pkv_shape) for _ in range(2)]
            past_key_values.append(kv)
            input_names.extend([f"past_key_values.{i}.key", f"past_key_values.{i}.value"])
            output_names.extend([f"present.{i}.key", f"present.{i}.value"])
        input_names.append("inputs_embeds")

        example_input = {"inputs_embeds": input_embeds, "attention_mask": attention_mask, "position_ids": position_ids, "past_key_values": past_key_values}

        ov_model = ov.convert_model(lang_model, example_input=example_input)

        for input, input_name in zip(ov_model.inputs, input_names):
            input.get_tensor().set_names({input_name})

        for output, output_name in zip(ov_model.outputs, output_names):
            output.get_tensor().set_names({output_name})
        patch_stateful(ov_model)
        print("✅ Language model successfully converted")

        if quantization_config is not None:
            print(f"⌛ Weights compression with {quantization_config['mode']} mode started")
            ov_model = nncf.compress_weights(ov_model, **quantization_config)
            print("✅ Weights compression finished")

        ov.save_model(ov_model, lang_model_path)
        del ov_model
        cleanup_torchscript_cache()
        del model
        gc.collect()
    print(f"✅ {model_id} model conversion finished. You can find results in {output_dir}")
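

# Example usage (a minimal sketch; the model id, output directory and NNCF weight-compression
# settings below are illustrative assumptions, not values required by this module):
#
#     convert_qwen2audio_model(
#         "Qwen/Qwen2-Audio-7B-Instruct",
#         "qwen2-audio-ov",
#         {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128, "ratio": 1.0},
#     )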


class OvModelForCausalLMWithEmb(GenerationMixin):
    def __init__(self, model_dir, device="CPU", config=None, ov_config=None, compile=True) -> None:
        self._supports_cache_class = False
        self.config = AutoConfig.from_pretrained(model_dir).text_config if config is None else config
        self.config.is_decoder = True
        self.config.is_encoder_decoder = False
        self.generation_config = GenerationConfig.from_model_config(self.config)
        model_dir = Path(model_dir)
        self.model = core.read_model(model_dir / LANGUAGE_MODEL_NAME)
        self.token_emb = core.read_model(model_dir / TEXT_EMBEDDING_NAME)
        self.request = None
        self.token_emb_request = None
        self._device = device.upper()
        self.device = torch.device("cpu")
        self.ov_config = ov_config
        self.next_beam_idx = None
        self._past_length = None
        self.input_names = [input_t.get_any_name() for input_t in self.model.inputs]
        self.main_input_name = "input_ids"
        if compile:
            self.compile()

    def compile(self):
        if self.request is None:
            self.request = core.compile_model(self.model, self._device, self.ov_config).create_infer_request()
        self._compile_token_emb()

    def _compile_token_emb(self):
        if self.token_emb_request is None:
            self.token_emb_request = core.compile_model(self.token_emb, self._device, self.ov_config)

    def to(self, device: str):
        if isinstance(device, str):
            self._device = device.upper()
            self.clear_requests()

        return self

    def clear_requests(self):
        del self.request
        del self.token_emb_request
        self.request = None
        self.token_emb_request = None

    def embed_tokens(self, input_ids: torch.LongTensor):
        self._compile_token_emb()
        res = self.token_emb_request(input_ids, share_inputs=True)
        return res[0]

    def prepare_inputs(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        inputs = {}
        # past_key_values are not used explicitly, instead they are handled inside the model
        if past_key_values is None:
            # This is the first iteration in a sequence, reset all states
            if self.request is not None:
                self.request.reset_state()
                # Set initial value for the next beam_idx input that will be used at the current iteration
                # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
                self.next_beam_idx = np.arange(batch_size, dtype=int)
                self._past_length = 0
        past_len = self._get_past_length(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids if past_key_values is None else input_ids[:, -1:])
            if hasattr(self.config, "scale_emb"):
                inputs_embeds = inputs_embeds * self.config.scale_emb
        inputs["inputs_embeds"] = inputs_embeds

        # Add the attention_mask inputs when needed
        if "attention_mask" in self.input_names or "position_ids" in self.input_names:
            if attention_mask is not None:
                attention_mask = np.array(attention_mask)
            else:
                attention_mask = np.ones((inputs_embeds.shape[0], inputs_embeds.shape[1] + past_len), dtype=int)

        if "attention_mask" in self.input_names:
            inputs["attention_mask"] = attention_mask

        if "position_ids" in self.input_names:
            if position_ids is not None:
                position_ids = np.array(position_ids)
            else:
                position_ids = np.cumsum(attention_mask, axis=1) - 1
                position_ids[attention_mask == 0] = 1
                if past_key_values:
                    position_ids = position_ids[:, -input_ids.shape[1] :]

            inputs["position_ids"] = position_ids

        if "beam_idx" in self.input_names:
            inputs["beam_idx"] = self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)

        return inputs

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        self.compile()

        inputs = self.prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        # Run inference
        self.request.start_async(inputs, share_inputs=True)
        self.request.wait()
        logits = self.request.get_tensor("logits").data
        logits = torch.from_numpy(logits).to(self.device)
        # The kv-cache lives inside the stateful OpenVINO model, so a dummy tuple is returned
        # only to signal GenerationMixin that a cache is in use
        past_key_values = ((),)
        self._past_length += inputs["inputs_embeds"].shape[1]

        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

    # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        attention_mask = kwargs.get("attention_mask", None)
        use_cache = kwargs.get("use_cache", None)

        if past_key_values is not None:
            past_len = self._get_past_length(past_key_values)
            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g.
            # when passing input_embeds as input)
            if attention_mask is not None and input_ids is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif input_ids is not None and past_len < input_ids.shape[1]:
                input_ids = input_ids[:, past_len:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values and input_ids is not None:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "inputs_embeds": inputs_embeds if past_key_values is None else None,
        }

        return model_inputs

    def _get_past_length(self, past_key_values=None):
        if past_key_values is None:
            return 0
        return self._past_length

    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
    def _reorder_cache(self, past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor) -> tuple[tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called.
        This is required to match `past_key_values` with the correct beam_idx at every generation step.
""" self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration return past_key_values def can_generate(self): """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" return True def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) class OVQwen2AudioForConditionalGeneration(GenerationMixin): def __init__(self, model_dir, device="CPU", ov_config=None): self.config = AutoConfig.from_pretrained(model_dir) self.generation_config = GenerationConfig.from_model_config(self.config) self.audio_tower = core.compile_model(model_dir / AUDIO_EMBEDING_NAME, device, ov_config) self.multi_modal_projector = core.compile_model(model_dir / MULTIMODAL_PROJECTOR_NAME, device, ov_config) self.vocab_size = self.config.text_config.vocab_size self.language_model = OvModelForCausalLMWithEmb(model_dir, device, self.config.text_config, ov_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides self.config.is_decoder = True self.config.is_encoder_decoder = False self._supports_cache_class = False self.main_input_name = "input_ids" self.device = torch.device("cpu") def can_generate(self): return True def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) @property def padding_side(self): return self._padding_side @padding_side.setter def padding_side(self, padding_side: str): if padding_side not in ["left", "right"]: raise ValueError(f"{padding_side} is not `left` or `right`.") self._padding_side = padding_side def _merge_input_ids_with_audio_features(self, audio_features, num_audio_tokens, inputs_embeds, input_ids, attention_mask): """ Merge input_ids with with audio features into final embeddings Args: audio_features (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`): All audio vectors of all audios in the batch num_audio_tokens (`torch.LongTensor` of shape `(num_audios)`): The length of audio embeddings of each audio as stacked in `audio_features` inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`): Token embeddings before merging with audio embeddings input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Input_ids of tokens, possibly filled with audio token attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Mask to avoid performing attention on padding token indices. Returns: final_embedding, final_attention_mask, position_ids, final_input_ids """ num_audios, max_audio_tokens, embed_dim = audio_features.shape audio_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens) < num_audio_tokens.unsqueeze(1) masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim) batch_size, sequence_length = input_ids.shape _left_padding = torch.any(attention_mask[:, 0] == 0) _right_padding = torch.any(attention_mask[:, -1] == 0) left_padding = True if batch_size > 1: if _left_padding and not _right_padding: left_padding = True elif not _left_padding and _right_padding: left_padding = False elif not _left_padding and not _right_padding: # both side is 1, so cannot tell left_padding = self.padding_side == "left" else: # invalid attention_mask raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") # 1. 
        special_audio_token_mask = input_ids == self.config.audio_token_index
        num_special_audio_tokens = torch.sum(special_audio_token_mask, dim=-1)

        batch_indices, non_audio_indices = torch.where((input_ids != self.config.audio_token_index) & (attention_mask == 1))

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged audio-text sequence.
        # `special_audio_token_mask` identifies audio tokens. Each audio token will be replaced by `audio_feat_lengths - 1` text tokens.
        # `torch.cumsum` computes how each audio token shifts subsequent text token positions.
        token_placeholder_num = torch.zeros_like(input_ids)
        token_placeholder_num[special_audio_token_mask] = num_audio_tokens.long() - 1
        token_placeholder_num = token_placeholder_num + 1
        new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
        max_token_num = token_placeholder_num.sum(-1).max()
        nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_audio_pad[:, None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(batch_size, max_token_num, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device)
        final_attention_mask = torch.zeros(batch_size, max_token_num, dtype=attention_mask.dtype, device=inputs_embeds.device)
        final_input_ids = torch.full((batch_size, max_token_num), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device)

        # 4. Fill the embeddings based on the mask. If we have ["hey" "