Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +33 -3
modeling_esm_plusplus.py
CHANGED
|
@@ -156,6 +156,26 @@ def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]],
|
|
| 156 |
return _collate_fn
|
| 157 |
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
class EmbeddingMixin:
|
| 160 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 161 |
raise NotImplementedError
|
|
@@ -243,7 +263,7 @@ class EmbeddingMixin:
|
|
| 243 |
|
| 244 |
def embed_dataset(
|
| 245 |
self,
|
| 246 |
-
sequences: List[str],
|
| 247 |
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
| 248 |
batch_size: int = 2,
|
| 249 |
max_len: int = 512,
|
|
@@ -256,6 +276,7 @@ class EmbeddingMixin:
|
|
| 256 |
save: bool = True,
|
| 257 |
sql_db_path: str = 'embeddings.db',
|
| 258 |
save_path: str = 'embeddings.pth',
|
|
|
|
| 259 |
**kwargs,
|
| 260 |
) -> Optional[dict[str, torch.Tensor]]:
|
| 261 |
"""
|
|
@@ -264,7 +285,15 @@ class EmbeddingMixin:
|
|
| 264 |
Supports two modes:
|
| 265 |
- Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
|
| 266 |
- Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
|
|
|
|
|
|
|
|
|
|
| 267 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
|
| 269 |
sequences = sorted(sequences, key=len, reverse=True)
|
| 270 |
hidden_size = self.config.hidden_size
|
|
@@ -653,8 +682,9 @@ def get_attention_mask(
|
|
| 653 |
flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
|
| 654 |
return attention_mask_2d, None, flex_block_mask
|
| 655 |
|
| 656 |
-
# SDPA / manual
|
| 657 |
-
|
|
|
|
| 658 |
return attention_mask_2d, attention_mask_4d, None
|
| 659 |
|
| 660 |
|
|
|
|
| 156 |
return _collate_fn
|
| 157 |
|
| 158 |
|
| 159 |
+
def parse_fasta(fasta_path: str) -> List[str]:
    """Parse a FASTA file and return its sequences.

    Lines starting with ``>`` are treated as record headers; every other
    non-empty line is appended to the current record. Blank lines are
    skipped. Headers themselves are discarded — only the concatenated
    sequence bodies are returned, in file order.

    Args:
        fasta_path: Path to an existing FASTA file.

    Returns:
        List of sequence strings (one per FASTA record).

    Raises:
        FileNotFoundError: If ``fasta_path`` does not exist.
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # so they must never be used for input validation.
    if not os.path.exists(fasta_path):
        raise FileNotFoundError(f"FASTA file does not exist: {fasta_path}")
    sequences: List[str] = []
    current_seq: List[str] = []
    with open(fasta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith('>'):
                # New record header: flush any sequence collected so far.
                if current_seq:
                    sequences.append(''.join(current_seq))
                    current_seq = []
            else:
                current_seq.append(line)
    # Flush the final record (a FASTA file does not end with a header).
    if current_seq:
        sequences.append(''.join(current_seq))
    return sequences
|
| 177 |
+
|
| 178 |
+
|
| 179 |
class EmbeddingMixin:
|
| 180 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 181 |
raise NotImplementedError
|
|
|
|
| 263 |
|
| 264 |
def embed_dataset(
|
| 265 |
self,
|
| 266 |
+
sequences: Optional[List[str]] = None,
|
| 267 |
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
| 268 |
batch_size: int = 2,
|
| 269 |
max_len: int = 512,
|
|
|
|
| 276 |
save: bool = True,
|
| 277 |
sql_db_path: str = 'embeddings.db',
|
| 278 |
save_path: str = 'embeddings.pth',
|
| 279 |
+
fasta_path: Optional[str] = None,
|
| 280 |
**kwargs,
|
| 281 |
) -> Optional[dict[str, torch.Tensor]]:
|
| 282 |
"""
|
|
|
|
| 285 |
Supports two modes:
|
| 286 |
- Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
|
| 287 |
- Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
|
| 288 |
+
|
| 289 |
+
Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via
|
| 290 |
+
`fasta_path`, or both (the two sources are combined). At least one must be provided.
|
| 291 |
"""
|
| 292 |
+
if fasta_path is not None:
|
| 293 |
+
fasta_sequences = parse_fasta(fasta_path)
|
| 294 |
+
sequences = list(sequences or []) + fasta_sequences
|
| 295 |
+
assert sequences is not None and len(sequences) > 0, \
|
| 296 |
+
"Must provide at least one sequence via `sequences` or `fasta_path`."
|
| 297 |
sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
|
| 298 |
sequences = sorted(sequences, key=len, reverse=True)
|
| 299 |
hidden_size = self.config.hidden_size
|
|
|
|
| 682 |
flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
|
| 683 |
return attention_mask_2d, None, flex_block_mask
|
| 684 |
|
| 685 |
+
# SDPA / manual — only mask the key dimension so padding query positions attend to
|
| 686 |
+
# real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf).
|
| 687 |
+
attention_mask_4d = attention_mask_2d[:, None, None, :]
|
| 688 |
return attention_mask_2d, attention_mask_4d, None
|
| 689 |
|
| 690 |
|