Skip to content

Commit 69e3386

Browse files
authored
Add tokenizer_name to base huggingface inference engines (#1862)
1 parent a0cfd56 commit 69e3386

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

src/unitxt/inference.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -488,6 +488,7 @@ class HFInferenceEngineBase(
 488  488        TorchDeviceMixin,
 489  489    ):
 490  490        model_name: str
      491  +     tokenizer_name: Optional[str] = None
 491  492        label: str
 492  493
 493  494        n_top_tokens: int = 5
@@ -710,8 +711,9 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
 710  711        def _init_processor(self):
 711  712            from transformers import AutoTokenizer
 712  713
      714  +         tokenizer_name = self.tokenizer_name or self.model_name
 713  715            self.processor = AutoTokenizer.from_pretrained(
 714       -             pretrained_model_name_or_path=self.model_name,
      716  +             pretrained_model_name_or_path=tokenizer_name,
 715  717                use_fast=self.use_fast_tokenizer,
 716  718            )
 717  719
@@ -1120,6 +1122,7 @@ class HFPipelineBasedInferenceEngine(
1120 1122        TorchDeviceMixin,
1121 1123    ):
1122 1124        model_name: str
     1125  +     tokenizer_name: Optional[str] = None
1123 1126        label: str = "hf_pipeline_inference_engine"
1124 1127
1125 1128        use_fast_tokenizer: bool = True
@@ -1217,8 +1220,8 @@ def _create_pipeline(self, model_args: Dict[str, Any]):
1217 1220            path = self.model_name
1218 1221            if settings.hf_offline_models_path is not None:
1219 1222                path = os.path.join(settings.hf_offline_models_path, path)
1220       -
1221       -         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
     1223  +         tokenizer_name = self.tokenizer_name or self.model_name
     1224  +         tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
1222 1225            self.model = pipeline(
1223 1226                model=path,
1224 1227                task=self.task,

0 commit comments

Comments
 (0)