lhallee committed on
Commit
c3b74d2
·
verified ·
1 Parent(s): 1e8ace1

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +33 -3
modeling_esm_plusplus.py CHANGED
@@ -156,6 +156,26 @@ def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]],
156
  return _collate_fn
157
 
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  class EmbeddingMixin:
160
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
161
  raise NotImplementedError
@@ -243,7 +263,7 @@ class EmbeddingMixin:
243
 
244
  def embed_dataset(
245
  self,
246
- sequences: List[str],
247
  tokenizer: Optional[PreTrainedTokenizerBase] = None,
248
  batch_size: int = 2,
249
  max_len: int = 512,
@@ -256,6 +276,7 @@ class EmbeddingMixin:
256
  save: bool = True,
257
  sql_db_path: str = 'embeddings.db',
258
  save_path: str = 'embeddings.pth',
 
259
  **kwargs,
260
  ) -> Optional[dict[str, torch.Tensor]]:
261
  """
@@ -264,7 +285,15 @@ class EmbeddingMixin:
264
  Supports two modes:
265
  - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
266
  - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
 
 
 
267
  """
 
 
 
 
 
268
  sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
269
  sequences = sorted(sequences, key=len, reverse=True)
270
  hidden_size = self.config.hidden_size
@@ -653,8 +682,9 @@ def get_attention_mask(
653
  flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
654
  return attention_mask_2d, None, flex_block_mask
655
 
656
- # SDPA / manual
657
- attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]
 
658
  return attention_mask_2d, attention_mask_4d, None
659
 
660
 
 
156
  return _collate_fn
157
 
158
 
159
+ def parse_fasta(fasta_path: str) -> List[str]:
160
+ assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
161
+ sequences = []
162
+ current_seq = []
163
+ with open(fasta_path, 'r') as f:
164
+ for line in f:
165
+ line = line.strip()
166
+ if not line:
167
+ continue
168
+ if line.startswith('>'):
169
+ if current_seq:
170
+ sequences.append(''.join(current_seq))
171
+ current_seq = []
172
+ else:
173
+ current_seq.append(line)
174
+ if current_seq:
175
+ sequences.append(''.join(current_seq))
176
+ return sequences
177
+
178
+
179
  class EmbeddingMixin:
180
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
181
  raise NotImplementedError
 
263
 
264
  def embed_dataset(
265
  self,
266
+ sequences: Optional[List[str]] = None,
267
  tokenizer: Optional[PreTrainedTokenizerBase] = None,
268
  batch_size: int = 2,
269
  max_len: int = 512,
 
276
  save: bool = True,
277
  sql_db_path: str = 'embeddings.db',
278
  save_path: str = 'embeddings.pth',
279
+ fasta_path: Optional[str] = None,
280
  **kwargs,
281
  ) -> Optional[dict[str, torch.Tensor]]:
282
  """
 
285
  Supports two modes:
286
  - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
287
  - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
288
+
289
+ Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via
290
+ `fasta_path`, or both (the two sources are combined). At least one must be provided.
291
  """
292
+ if fasta_path is not None:
293
+ fasta_sequences = parse_fasta(fasta_path)
294
+ sequences = list(sequences or []) + fasta_sequences
295
+ assert sequences is not None and len(sequences) > 0, \
296
+ "Must provide at least one sequence via `sequences` or `fasta_path`."
297
  sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
298
  sequences = sorted(sequences, key=len, reverse=True)
299
  hidden_size = self.config.hidden_size
 
682
  flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
683
  return attention_mask_2d, None, flex_block_mask
684
 
685
+ # SDPA / manual — only mask the key dimension so padding query positions attend to
686
+ # real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf).
687
+ attention_mask_4d = attention_mask_2d[:, None, None, :]
688
  return attention_mask_2d, attention_mask_4d, None
689
 
690