@@ -488,6 +488,7 @@ class HFInferenceEngineBase(
488 488     TorchDeviceMixin,
489 489 ):
490 490     model_name: str
    491+    tokenizer_name: Optional[str] = None
491 492     label: str
492 493
493 494     n_top_tokens: int = 5
@@ -710,8 +711,9 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
710 711     def _init_processor(self):
711 712         from transformers import AutoTokenizer
712 713
    714+        tokenizer_name = self.tokenizer_name or self.model_name
713 715         self.processor = AutoTokenizer.from_pretrained(
714    -            pretrained_model_name_or_path=self.model_name,
    716+            pretrained_model_name_or_path=tokenizer_name,
715 717             use_fast=self.use_fast_tokenizer,
716 718         )
717719
@@ -1120,6 +1122,7 @@ class HFPipelineBasedInferenceEngine(
1120 1122     TorchDeviceMixin,
1121 1123 ):
1122 1124     model_name: str
     1125+    tokenizer_name: Optional[str] = None
1123 1126     label: str = "hf_pipeline_inference_engine"
1124 1127
1125 1128     use_fast_tokenizer: bool = True
@@ -1217,8 +1220,8 @@ def _create_pipeline(self, model_args: Dict[str, Any]):
1217 1220         path = self.model_name
1218 1221         if settings.hf_offline_models_path is not None:
1219 1222             path = os.path.join(settings.hf_offline_models_path, path)
1220     -
1221     -        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
     1223+        tokenizer_name = self.tokenizer_name or self.model_name
     1224+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
1222 1225         self.model = pipeline(
1223 1226             model=path,
1224 1227             task=self.task,
0 commit comments