Instructions to use Vikhrmodels/Borealis with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Vikhrmodels/Borealis with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="Vikhrmodels/Borealis", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Vikhrmodels/Borealis", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import WhisperModel, PreTrainedModel, WhisperFeatureExtractor | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from huggingface_hub import PyTorchModelHubMixin | |
| from .configuration_borealis import BorealisConfig | |
| from huggingface_hub import hf_hub_download | |
| import os | |
class AudioLanguageAdapter(nn.Module):
    """Two-layer bias-free MLP that projects audio features to width ``dim``.

    Computes ``w_out(GELU(w_in(x)))``. The submodule names ``w_in``, ``gelu``
    and ``w_out`` are part of the state-dict layout and must not be renamed.

    Args:
        hidden_size: Size of the last dimension of the input features.
        dim: Size of the last dimension of the output (and of the hidden
            layer).
    """

    def __init__(self, hidden_size: int, dim: int) -> None:
        super().__init__()
        self.w_in = nn.Linear(hidden_size, dim, bias=False)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (last dimension ``hidden_size``) to last dimension ``dim``."""
        return self.w_out(self.gelu(self.w_in(x)))
class BorealisForConditionalGeneration(PreTrainedModel, PyTorchModelHubMixin):
    """Speech-to-text model: frozen Whisper encoder -> adapter -> causal LLM.

    Encoder frames are downsampled by concatenating ``downsample_factor``
    consecutive frames, projected to the LLM hidden size by
    :class:`AudioLanguageAdapter`, and spliced into the text embedding
    sequence between the ``<|start_of_audio|>`` and ``<|end_of_audio|>``
    tokens before the LLM is called on ``inputs_embeds``.
    """

    config_class = BorealisConfig

    def __init__(self, config: BorealisConfig, language_model=None, tokenizer=None):
        """Build the model around an already-loaded LLM and tokenizer.

        Args:
            config: Provides ``whisper_encoder_name`` and ``downsample_factor``.
            language_model: Causal LM; must expose ``resize_token_embeddings``,
                ``get_input_embeddings``, ``config.hidden_size`` and
                ``generate``. Required despite the ``None`` default.
            tokenizer: Tokenizer containing the ``<|im_start|>``,
                ``<|start_of_audio|>`` and ``<|end_of_audio|>`` tokens.

        Raises:
            AssertionError: If ``tokenizer`` is ``None``.
        """
        super().__init__(config)
        assert tokenizer is not None, "Tokenizer надо передать в модельку"

        # Only the encoder half of Whisper is kept; it is frozen and run in
        # bfloat16.
        self.encoder: WhisperModel = WhisperModel.from_pretrained(
            config.whisper_encoder_name
        ).encoder
        self.encoder.to(torch.bfloat16)
        self.encoder.eval()
        for p in self.encoder.parameters():
            p.requires_grad = False

        self.llm = language_model
        self.tokenizer = tokenizer
        # Make room in the embedding table for any tokens added to the
        # tokenizer (e.g. the audio boundary markers).
        self.llm.resize_token_embeddings(len(tokenizer))

        self.downsample_factor = config.downsample_factor
        self.adapter = AudioLanguageAdapter(
            hidden_size=self.encoder.config.d_model * self.downsample_factor,
            dim=self.llm.config.hidden_size,
        )
        self.adapter.to(torch.bfloat16)

        self.bos_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        self.audio_start_id = tokenizer.convert_tokens_to_ids("<|start_of_audio|>")
        self.audio_end_id = tokenizer.convert_tokens_to_ids("<|end_of_audio|>")

    def _downsample(self, seq: torch.Tensor) -> torch.Tensor:
        """Concatenate every ``downsample_factor`` consecutive frames.

        Args:
            seq: 2-D tensor ``(T, d)``.

        Returns:
            Tensor of shape ``(ceil(T / k), d * k)``; the tail is zero-padded
            when ``T`` is not a multiple of ``k``.
        """
        k, (T, d) = self.downsample_factor, seq.shape
        target = k * math.ceil(T / k)
        if target != T:
            # Pad the time axis (second-to-last pair) at the end only.
            seq = F.pad(seq, (0, 0, 0, target - T))
        return seq.contiguous().view(target // k, d * k)

    def _tok_embed(self, tok_id: int, batch: int, device) -> torch.Tensor:
        """Return the LLM input embedding of ``tok_id`` repeated as ``(batch, 1, hidden)``."""
        idx = torch.full((batch, 1), tok_id, dtype=torch.long, device=device)
        return self.llm.get_input_embeddings()(idx)

    def _encode_audio(self, mel: torch.Tensor):
        """Encode, downsample, pad and project a batch of mel features.

        Shared by :meth:`forward` and :meth:`generate`.

        Returns:
            ``(audio_embeddings, audio_mask, max_T)`` where
            ``audio_embeddings`` is the adapter output for the padded
            ``(B, max_T, d * k)`` batch, ``audio_mask`` is a ``(B, max_T)``
            long tensor with 1 on real frames and 0 on padding, and ``max_T``
            is the longest downsampled length in the batch.
        """
        device = mel.device
        enc_out = self.encoder(
            input_features=mel, attention_mask=None, return_dict=True
        ).last_hidden_state

        audio_embs = [self._downsample(seq) for seq in enc_out]
        max_T = max((ds.size(0) for ds in audio_embs), default=0)

        # Build the mask from the true lengths, then zero-pad every sequence
        # to max_T. (The previous forward() computed the padded tensor but
        # never stored it back into the list.)
        lengths = torch.tensor(
            [ds.size(0) for ds in audio_embs], dtype=torch.long, device=device
        )
        audio_mask = (
            torch.arange(max_T, device=device).unsqueeze(0) < lengths.unsqueeze(1)
        ).long()
        padded = torch.nn.utils.rnn.pad_sequence(
            audio_embs, batch_first=True, padding_value=0.0
        )
        return self.adapter(padded), audio_mask, max_T

    @staticmethod
    def _splice_audio(text_emb, text_mask, audio_emb, audio_mask, sa_idx, ea_idx):
        """Insert one sample's audio embeddings between its audio boundary tokens.

        Keeps text positions ``[0, sa_idx]`` and ``[ea_idx, end)``; any text
        strictly between the two markers is dropped.

        Returns:
            ``(embeddings, attention_mask)`` for the spliced sequence.
        """
        emb = torch.cat([text_emb[: sa_idx + 1], audio_emb, text_emb[ea_idx:]], dim=0)
        mask = torch.cat(
            [text_mask[: sa_idx + 1], audio_mask, text_mask[ea_idx:]], dim=0
        )
        return emb, mask

    def forward(
        self,
        mel: torch.Tensor,
        audio_att_mask: torch.Tensor,
        labels: torch.Tensor,
        text_att_mask: torch.Tensor,
    ):
        """Compute the LM loss on the assistant turn given audio and a chat prompt.

        Args:
            mel: Batched encoder input features (first dimension is batch).
            audio_att_mask: Accepted for interface compatibility; not read.
            labels: ``(B, L)`` token ids of the full chat sequence. Each row
                must contain exactly one ``<|start_of_audio|>`` and one
                ``<|end_of_audio|>`` token and the assistant header
                ``"<|im_start|>assistant\\n"``.
            text_att_mask: ``(B, L)`` attention mask matching ``labels``.

        Returns:
            ``(loss, logits)`` from the wrapped LLM.

        Raises:
            ValueError: If the assistant header is not found in a row.
        """
        B, device = mel.size(0), mel.device
        audio_embeddings, audio_mask, max_T = self._encode_audio(mel)

        text_embeddings = self.llm.get_input_embeddings()(labels)
        sa_positions = (labels == self.audio_start_id).nonzero(as_tuple=True)
        ea_positions = (labels == self.audio_end_id).nonzero(as_tuple=True)

        inputs_embeds = []
        att_mask = []
        # Per-sample displacement of every text position at or after the
        # end-of-audio marker: max_T audio frames are inserted and the
        # (ea_idx - sa_idx - 1) tokens between the markers are removed.
        shifts = []
        for b in range(B):
            sa_idx = sa_positions[1][sa_positions[0] == b].item()
            ea_idx = ea_positions[1][ea_positions[0] == b].item()
            emb, full_mask = self._splice_audio(
                text_embeddings[b],
                text_att_mask[b],
                audio_embeddings[b],
                audio_mask[b],
                sa_idx,
                ea_idx,
            )
            inputs_embeds.append(emb)
            att_mask.append(full_mask)
            shifts.append(max_T - (ea_idx - sa_idx - 1))

        inputs_embeds = torch.nn.utils.rnn.pad_sequence(
            inputs_embeds, batch_first=True, padding_value=0.0
        )
        att_mask = torch.nn.utils.rnn.pad_sequence(
            att_mask, batch_first=True, padding_value=0
        )

        # Locate where the assistant's reply starts in each original row.
        assistant_prompt = self.tokenizer(
            "<|im_start|>assistant\n", add_special_tokens=False
        ).input_ids
        prompt_len = len(assistant_prompt)
        prompt_ids = torch.tensor(assistant_prompt, device=device)  # hoisted
        orig_starts = []
        for b in range(B):
            seq = labels[b]
            for i in range(len(seq) - prompt_len):
                if torch.equal(seq[i : i + prompt_len], prompt_ids):
                    orig_starts.append(i + prompt_len)
                    break
            else:
                raise ValueError("Assistant prompt not found")

        # Supervise only the assistant reply; everything else is -100.
        max_len = inputs_embeds.size(1)
        loss_labels = labels.new_full((B, max_len), -100)
        for b in range(B):
            src_start = orig_starts[b]
            dst_start = src_start + shifts[b]
            content_len = len(labels[b]) - src_start
            loss_labels[b, dst_start : dst_start + content_len] = labels[b, src_start:]

        if self.tokenizer.pad_token_id is not None:
            loss_labels[loss_labels == self.tokenizer.pad_token_id] = -100

        out = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=att_mask,
            labels=loss_labels,
            return_dict=True,
        )
        return out.loss, out.logits

    def generate(
        self,
        mel: torch.Tensor,
        att_mask: torch.Tensor,
        max_new_tokens: int = 512,
        **kwargs,
    ):
        """Transcribe audio using a fixed Russian ASR chat prompt.

        Args:
            mel: Encoder input features; a 2-D tensor is treated as a single
                example and gets a batch dimension added.
            att_mask: Accepted for interface compatibility; only unsqueezed
                alongside ``mel`` and otherwise not read.
            max_new_tokens: Forwarded to ``llm.generate``.
            **kwargs: Extra arguments for ``llm.generate``. The key
                ``return_tokens`` (default ``False``) is consumed here and
                selects token ids instead of decoded text.

        Returns:
            Decoded text (``str`` for a single example, ``list[str]`` for a
            batch), or the generated ids when ``return_tokens`` is true.
        """
        return_tokens = kwargs.pop("return_tokens", False)
        single = mel.dim() == 2
        if single:
            mel, att_mask = mel.unsqueeze(0), att_mask.unsqueeze(0)
        mel = mel.to(torch.bfloat16)
        B, device = mel.size(0), mel.device

        audio_embeddings, audio_mask, _ = self._encode_audio(mel)

        messages = [
            {
                "role": "system",
                "content": "Вы полезный помощник по автоматическому распознаванию речи. Точно транскрибируйте аудио в текст.",
            },
            {
                "role": "user",
                "content": "Транскрибируйте это аудио: <|start_of_audio|><|end_of_audio|>",
            },
        ]
        chat_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        model_inputs = self.tokenizer(chat_text, return_tensors="pt").to(device)
        # The prompt is identical for every example, so tokenize once and tile.
        input_ids = model_inputs.input_ids.repeat(B, 1)
        text_att_mask = model_inputs.attention_mask.repeat(B, 1)
        text_embeddings = self.llm.get_input_embeddings()(input_ids)

        sa_idx = (input_ids[0] == self.audio_start_id).nonzero(as_tuple=True)[0].item()
        ea_idx = (input_ids[0] == self.audio_end_id).nonzero(as_tuple=True)[0].item()

        inputs_embeds = []
        full_att_mask = []
        for b in range(B):
            emb, mask = self._splice_audio(
                text_embeddings[b],
                text_att_mask[b],
                audio_embeddings[b],
                audio_mask[b],
                sa_idx,
                ea_idx,
            )
            inputs_embeds.append(emb)
            full_att_mask.append(mask)

        inputs_embeds = torch.nn.utils.rnn.pad_sequence(
            inputs_embeds, batch_first=True, padding_value=0.0
        )
        att_mask = torch.nn.utils.rnn.pad_sequence(
            full_att_mask, batch_first=True, padding_value=0
        )

        gen_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=att_mask,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.tokenizer.eos_token_id,
            **kwargs,
        )

        if return_tokens:
            return gen_ids[0] if single else gen_ids
        txt = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        return txt[0] if single else list(txt)

    def save_pretrained(self, save_directory, **kwargs):
        """Write config, full state dict, tokenizer and feature extractor.

        The weights go to ``<save_directory>/pytorch_model.bin`` via
        ``torch.save``. ``**kwargs`` is accepted and ignored.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        state_dict = self.state_dict()
        torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
        self.tokenizer.save_pretrained(save_directory)
        # The feature extractor is re-fetched from the Whisper checkpoint so
        # the saved directory is self-contained for preprocessing.
        extractor = WhisperFeatureExtractor.from_pretrained(
            self.config.whisper_encoder_name
        )
        extractor.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load config, tokenizer, base LLM and the saved state dict.

        Args:
            pretrained_model_name_or_path: Hub repo id, or a local directory
                produced by :meth:`save_pretrained`.
            *model_args, **kwargs: Accepted for signature compatibility and
                ignored.

        Returns:
            A model instance with ``pytorch_model.bin`` loaded.
        """
        config = BorealisConfig.from_pretrained(pretrained_model_name_or_path)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        language_model = AutoModelForCausalLM.from_pretrained(config.llm_name)
        model = cls(config, language_model=language_model, tokenizer=tokenizer)

        if os.path.isdir(pretrained_model_name_or_path):
            # Local checkpoint directory: read the weights file directly
            # instead of treating the path as a Hub repo id.
            state_dict_path = os.path.join(
                pretrained_model_name_or_path, "pytorch_model.bin"
            )
        else:
            state_dict_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename="pytorch_model.bin"
            )
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources (consider weights_only=True where supported).
        state_dict = torch.load(state_dict_path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model
# Register the class so that AutoModelForCausalLM.from_pretrained(...,
# trust_remote_code=True) can resolve to it from this repository's code.
BorealisForConditionalGeneration.register_for_auto_class("AutoModelForCausalLM")