| | |
| | |
| | |
| |
|
| | from torch import Tensor |
| | from dataclasses import dataclass, field |
| | from typing import Optional |
| |
|
| |
|
| | |
@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    # Maximum sequence length the cache is sized for.
    max_seqlen: int
    # Maximum batch size the cache is sized for.
    max_batch_size: int
    # Number of tokens already consumed (current fill position of the cache).
    seqlen_offset: int = 0
    # Offset into the batch dimension; presumably used when batches are
    # packed/rotated across requests — confirm against callers.
    batch_size_offset: int = 0
    # Cached key/value memory tensors; keys are presumably layer indices —
    # verify against the model code that populates it.
    key_value_memory_dict: dict = field(default_factory=dict)
    # Optional per-sample lengths tensor; zeroed in place by reset().
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen: int, max_batch_size: int) -> None:
        """Re-arm this cache for a new generation pass.

        Updates the capacity fields, rewinds the sequence offset, and zeroes
        ``lengths_per_sample`` in place if it was allocated.

        NOTE(review): ``batch_size_offset`` and ``key_value_memory_dict`` are
        left untouched — presumably so already-allocated cache tensors can be
        reused across calls; confirm this is intentional.

        Args:
            max_seqlen: new maximum sequence length.
            max_batch_size: new maximum batch size.
        """
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()
| |
|
| |
|
@dataclass
class RecurrentInferenceParams:
    """Inference parameters passed to blocks with recurrent mode."""

    # Length of the FIR filter state kept per block.
    fir_filter_length: int = 3
    # Dimension of the recurrent state.
    state_dim: int = 16
    # How many tokens have been processed so far.
    seqlen_offset: int = 0
    # Per-block FIR filter states.
    fir_state_dict: dict = field(default_factory=dict)
    # Per-block recurrent states.
    state_dict: dict = field(default_factory=dict)

    def reset(self) -> None:
        """Restore the scalar parameters to their class defaults.

        NOTE(review): the values are re-set to the hard-coded defaults
        (3, 16, 0) rather than to whatever the instance was constructed
        with, and the state dicts are left untouched — confirm both are
        intentional.
        """
        self.fir_filter_length, self.state_dim, self.seqlen_offset = 3, 16, 0
| |
|