diff --git a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py index 1945fa25a2e69..a7baf68d1d55e 100644 --- a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py +++ b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py @@ -247,6 +247,18 @@ def forward(self, x, max_new_tokens: int = 96): self.eval() self.device = next(self.model.parameters()).device + if len(x.shape) == 2: + batch_size, cur_len = x.shape + if cur_len < self.config.input_token_len: + raise ValueError( + f"Input length must be at least {self.config.input_token_len}") + elif cur_len % self.config.input_token_len != 0: + new_len = (cur_len // self.config.input_token_len) * \ + self.config.input_token_len + x = x[:, -new_len:] + else: + raise ValueError('Input shape must be: [batch_size, seq_len]') + use_cache = self.config.use_cache all_input_ids = x