import warnings

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data.distributed import DistributedSampler
from typing import Optional, Tuple, Union, Any, Dict, List

from transformers import Seq2SeqTrainer, GPT2LMHeadModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
from transformers.generation.utils import (
    GreedySearchOutput,
    GreedySearchDecoderOnlyOutput,
    GreedySearchEncoderDecoderOutput,
    BeamSearchOutput,
    BeamSearchDecoderOnlyOutput,
    BeamSearchEncoderDecoderOutput,
)
from transformers.generation.beam_search import BeamScorer

try:
    from torch_geometric.loader import DataLoader
    from torch_geometric.data import Dataset
except ImportError:
    raise ImportError(
        'You need to install torch_geometric and its dependencies to use this model. '
        'Please refer to https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html'
    )

class _GPT2LMHeadModel(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, encoder_outputs=None, **kwargs):
        '''
        This function is an edited version of the prepare_inputs_for_generation function from HuggingFace's transformers
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
        '''
        token_type_ids = kwargs.get("token_type_ids", None)

        # With cached past key values, only the last token needs to be fed in.
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)
        if self.config.prot2text_version in ("1.1", "1.2"):
            encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
        else:  # version "1.0" (and any other version) does not use an encoder attention mask
            encoder_attention_mask = None

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        model_specific_kwargs = {
            "encoder_hidden_states": encoder_outputs['hidden_states'],
        }

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "encoder_attention_mask": encoder_attention_mask,
            **model_specific_kwargs,
        }

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        '''
        This function is an edited version of the greedy_search function from HuggingFace's transformers
        https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
        '''
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the while loop stops only if all peers finished generating.
                # Send 0.0 if this peer finished, 1.0 otherwise; if the reduced sum is 0.0,
                # every peer is done and we can break.
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get the next-token logits
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process the distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            # store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_tokens_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if not self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # greedy decoding: pick the highest-scoring token
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for the next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

            # if eos_token was found in a sentence, mark that sentence as finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

            # stop when every sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

            # depending on the transformers version, stopping_criteria returns either
            # a bool or a per-sequence tensor; handle both
            try:
                if stopping_criteria(input_ids, scores):
                    this_peer_finished = True
            except Exception:
                if all(stopping_criteria(input_ids, scores)):
                    this_peer_finished = True

            if this_peer_finished and not synced_gpus:
                break

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GreedySearchEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return GreedySearchDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids

    def _greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:

        return self.greedy_search(
            input_ids,
            logits_processor,
            stopping_criteria,
            max_length,
            pad_token_id,
            eos_token_id,
            output_attentions,
            output_hidden_states,
            output_scores,
            return_dict_in_generate,
            synced_gpus,
            streamer,
            **model_kwargs,
        )
    def _beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:

        return self.beam_search(
            input_ids,
            beam_scorer,
            logits_processor,
            stopping_criteria,
            max_length,
            pad_token_id,
            eos_token_id,
            output_attentions,
            output_hidden_states,
            output_scores,
            return_dict_in_generate,
            synced_gpus,
            **model_kwargs,
        )

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        '''
        This function is an edited version of the beam_search function from HuggingFace's transformers
        https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
        '''
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You haven't defined any stopping_criteria; this will likely loop forever.", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # initialise the score of the first beam of each batch with 0 and the rest with -1e9,
        # so that all beams of a batch do not pick the same token on the first step
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the while loop stops only if all peers finished generating.
                # Send 0.0 if this peer finished, 1.0 otherwise; if the reduced sum is 0.0,
                # every peer is done and we can break.
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            # log_softmax turns the logits into log-probabilities so that beam scores
            # can be accumulated by addition
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            # add the running beam scores to the score of each candidate continuation
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                next_token_scores_processed
            )

            # store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores_processed,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if not self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            # keep 2 * num_beams candidates so that at least num_beams survive after
            # finished beams are filtered out
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )

            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
            )

            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1

            # depending on the transformers version, stopping_criteria returns either
            # a bool or a per-sequence tensor; handle both
            try:
                if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                    if not synced_gpus:
                        break
                    else:
                        this_peer_finished = True
            except Exception:
                if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
                    if not synced_gpus:
                        break
                    else:
                        this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

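
# A minimal sketch, not part of the original module: it shows how the override
# above routes graph-encoder states into the decoder inputs. The tiny config,
# the `prot2text_version` value, and the random encoder states are illustrative
# assumptions, not values from the original pipeline.
def _prepare_inputs_for_generation_demo():
    from transformers import GPT2Config

    config = GPT2Config(n_layer=2, n_head=2, n_embd=64, add_cross_attention=True)
    config.prot2text_version = "1.1"  # assumed value for the version flag used above
    model = _GPT2LMHeadModel(config)

    input_ids = torch.tensor([[50256]])  # a single BOS-like token
    encoder_outputs = {"hidden_states": torch.randn(1, 4, 64)}  # fake encoder states
    inputs = model.prepare_inputs_for_generation(
        input_ids,
        encoder_outputs=encoder_outputs,
        attention_mask=torch.ones(1, 1, dtype=torch.long),
    )
    # The encoder states are handed to the decoder as cross-attention context.
    assert inputs["encoder_hidden_states"].shape == (1, 4, 64)
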
class CABlock(nn.Module):
    '''
    This class is an edited version of the GPT-2 decoder block from HuggingFace's transformers,
    keeping only the cross-attention and MLP sub-layers:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
    '''
    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
        self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = GPT2MLP(inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:

        # cross-attention sub-layer with pre-LayerNorm and residual connection
        residual = hidden_states
        hidden_states = self.ln_cross_attn(hidden_states)
        cross_attn_outputs = self.crossattention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )
        attn_output = cross_attn_outputs[0]
        hidden_states = residual + attn_output

        # feed-forward sub-layer with pre-LayerNorm and residual connection
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = residual + feed_forward_hidden_states

        return (hidden_states,)

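
# A minimal shape check, not part of the original module: decoder states attend
# over encoder states in CABlock and keep their own shape. The config sizes are
# illustrative assumptions.
def _cablock_demo():
    from transformers import GPT2Config

    config = GPT2Config(n_embd=64, n_head=2)
    block = CABlock(config, layer_idx=0)
    hidden = torch.randn(2, 5, 64)          # (batch, target_len, hidden_size)
    encoder_states = torch.randn(2, 7, 64)  # (batch, source_len, hidden_size)
    (out,) = block(hidden, encoder_hidden_states=encoder_states)
    assert out.shape == hidden.shape
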
class Prot2TextTrainer(Seq2SeqTrainer):
    '''
    This class is an edited version of the Seq2SeqTrainer from HuggingFace's transformers,
    using torch_geometric dataloaders so that graph batches are collated correctly.
    '''
    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        if self.args.world_size > 1:
            eval_sampler = DistributedSampler(
                self.eval_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
            )
        else:
            eval_sampler = None
        return DataLoader(
            self.eval_dataset,
            batch_size=self.args.eval_batch_size,
            collate_fn=None,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            sampler=eval_sampler,
        )

    def get_train_dataloader(self) -> DataLoader:
        if self.args.world_size > 1:
            train_sampler = DistributedSampler(
                self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
            )
        else:
            train_sampler = None
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            collate_fn=None,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            sampler=train_sampler,
        )

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        inputs = inputs.to_dict()
        # concatenate the per-graph one-hot edge-type matrices across the batch,
        # then convert them to integer class indices via argmax
        inputs['edge_type'] = torch.cat(
            [torch.tensor(inputs['edge_type'][i]) for i in range(len(inputs['edge_type']))], dim=0
        )
        inputs['edge_type'] = torch.argmax(inputs['edge_type'], dim=1)
        inputs = {
            k: v.to(device=self.args.device, non_blocking=True) if hasattr(v, 'to') else v
            for k, v in inputs.items()
        }
        return inputs

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under
                the argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """
        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        gen_kwargs = self._gen_kwargs.copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.model.config.max_length
        gen_kwargs["num_beams"] = (
            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
        )
        default_synced_gpus = is_deepspeed_zero3_enabled()
        gen_kwargs["synced_gpus"] = (
            gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
        )

        if "attention_mask" in inputs:
            gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
        if "global_attention_mask" in inputs:
            gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)

        # pass the graph and sequence tensors through gen_kwargs; generation starts
        # from the first decoder token only
        generation_inputs = None
        gen_kwargs['x'] = inputs.get('x', None)
        gen_kwargs['edge_index'] = inputs.get('edge_index', None)
        gen_kwargs['edge_type'] = inputs.get('edge_type', None)
        gen_kwargs['batch'] = inputs.get('batch', None)
        gen_kwargs['encoder_input_ids'] = inputs.get('encoder_input_ids', None)
        gen_kwargs['decoder_input_ids'] = inputs.get('decoder_input_ids', None)[:, 0:1]
        gen_kwargs["decoder_attention_mask"] = torch.ones(gen_kwargs['decoder_input_ids'].shape[0], 1).to(
            self.args.device
        )

        generated_tokens = self.model.generate(
            generation_inputs,
            **gen_kwargs,
        )
        # in case the batch is shorter than the maximum length, pad the output
        if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
        elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
            gen_kwargs["max_new_tokens"] + 1
        ):
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)

        with torch.no_grad():
            if has_labels:
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)
                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return (loss, None, None)

        if has_labels:
            labels = inputs["labels"]
            if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
                labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
            elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
                gen_kwargs["max_new_tokens"] + 1
            ):
                labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
        else:
            labels = None

        return (loss, generated_tokens, labels)
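
# A minimal wiring sketch, not part of the original module. The argument values
# are illustrative assumptions; `model`, `train_ds`, and `eval_ds` are expected
# to be a Prot2Text-style graph+text model and torch_geometric datasets built
# elsewhere.
def _trainer_wiring_demo(model, train_ds, eval_ds):
    from transformers import Seq2SeqTrainingArguments

    args = Seq2SeqTrainingArguments(
        output_dir="./checkpoints",   # assumed output path
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        predict_with_generate=True,   # route evaluation through generate()
    )
    # The trainer substitutes torch_geometric dataloaders (see the overrides above),
    # so graph batches are collated correctly, including in the distributed setting.
    return Prot2TextTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
    )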