
Commit 8bad5af

jtmer, OswinGuai, and ycycse authored
[AINode] Introduce built-in timer_xl model for forecasting (#15468)
* latest timerxl code
* add dependency of einops in ainode/pyproject.toml
* Update pyproject.toml

Co-authored-by: OswinGuai <peizhyi@gmail.com>
Co-authored-by: YangCaiyin <wiycy@foxmail.com>
1 parent 6df744a commit 8bad5af

17 files changed

Lines changed: 1951 additions & 3 deletions
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import abc
import math
import torch
from einops import rearrange
from torch import nn


class AttentionBias(nn.Module, abc.ABC):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        assert num_heads > 0 and dim % num_heads == 0

        self.num_heads = num_heads
        self.head_dim = dim // num_heads

    @abc.abstractmethod
    def forward(self, query_id, kv_id): ...


class BinaryAttentionBias(AttentionBias):
    def __init__(self, dim: int, num_heads: int):
        super().__init__(dim, num_heads)
        self.emb = nn.Embedding(num_embeddings=2, embedding_dim=self.num_heads)

    def forward(self, query_id, kv_id):
        ind = torch.eq(query_id.unsqueeze(-1), kv_id.unsqueeze(-2))
        weight = rearrange(self.emb.weight, "two num_heads -> two num_heads 1 1")
        bias = ~ind * weight[:1] + ind * weight[1:]
        return bias

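A minimal usage sketch of BinaryAttentionBias (illustrative values, not part of this commit). It assumes query_id/kv_id are integer variate-id tensors with a leading broadcast dimension, so one learned per-head scalar is applied to same-variate token pairs and another to cross-variate pairs:

    # hypothetical shapes: ids (1, seq_len) -> bias (1, num_heads, seq_len, seq_len)
    bias_layer = BinaryAttentionBias(dim=512, num_heads=8)
    variate_id = torch.tensor([[0, 0, 0, 1, 1, 1]])   # two variates, three tokens each
    bias = bias_layer(variate_id, variate_id)
    # bias broadcasts over the batch and is added to the pre-softmax attention scores
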
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    relative_buckets = 0
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))

    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large,
        torch.full_like(relative_position_if_large, num_buckets - 1),
    )

    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets


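As a quick sanity check of the bucketing (a sketch using the values T5AttentionBias passes in below, not code from this commit):

    # with bidirectional=False, num_buckets=32, max_distance=32: distances 0..15 map
    # to buckets 0..15 exactly, larger distances land logarithmically in buckets
    # 16..31, and anything at or beyond max_distance is clamped to bucket 31
    distance = torch.arange(0, 64)
    buckets = _relative_position_bucket(
        -distance, bidirectional=False, num_buckets=32, max_distance=32
    )
    assert buckets.max().item() == 31
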
class T5AttentionBias(AttentionBias):
    def __init__(self, dim: int, num_heads: int):
        super().__init__(dim, num_heads)
        self.num_buckets = 32
        self.max_distance = 32
        self.relative_attention_bias = nn.Embedding(self.num_buckets, 1)

    def forward(self, n_vars, n_tokens):
        context_position = torch.arange(n_tokens, dtype=torch.long)[:, None]
        memory_position = torch.arange(n_tokens, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        bucket = _relative_position_bucket(
            relative_position=relative_position,
            bidirectional=False,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        ).to(self.relative_attention_bias.weight.device)
        bias = self.relative_attention_bias(bucket).squeeze(-1)
        bias = bias.reshape(1, 1, bias.shape[0], bias.shape[1])
        mask1 = torch.ones((n_vars, n_vars), dtype=torch.bool).to(bias.device)
        final_bias = torch.kron(mask1, bias)
        return final_bias
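A usage sketch for the T5-style bias above (illustrative sizes, not taken from this commit): n_vars and n_tokens are the number of variates and tokens per variate; because the embedding has a single output channel, the same bias is broadcast across attention heads:

    t5_bias = T5AttentionBias(dim=512, num_heads=8)
    bias = t5_bias(n_vars=2, n_tokens=4)
    # bias.shape == (1, 1, 8, 8): a causal relative-position bias over the 4 tokens,
    # tiled by torch.kron across every pair of the 2 variates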
Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import abc
import torch
from functools import cached_property
from einops import einsum, rearrange, repeat
from torch import nn


class Projection(nn.Module, abc.ABC):
    def __init__(self, proj_width: int, num_heads: int, **kwargs):
        super().__init__()
        self.proj_width = proj_width
        self.num_heads = num_heads

    @abc.abstractmethod
    def forward(self, x, seq_id): ...


class RotaryProjection(Projection):
    def __init__(self, *, proj_width: int, num_heads: int, max_len: int = 512, base: int = 10000):
        super().__init__(proj_width, num_heads)
        assert (
            self.proj_width % 2 == 0
        ), f"proj_width must be even, got {self.proj_width}"
        self.register_buffer(
            "theta",
            1.0
            / torch.pow(
                base,
                torch.arange(0, self.proj_width, 2, dtype=torch.float) / self.proj_width,
            ),
            persistent=False,
        )
        self.register_buffer("cos", None, persistent=False)
        self.register_buffer("sin", None, persistent=False)
        self._init_freq(max_len=max_len)

    def _init_freq(self, max_len: int):
        if self.cos is None or self.cos.size(-2) < max_len:
            position = torch.arange(max_len, device=self.theta.device, dtype=self.theta.dtype)
            m_theta = einsum(position, self.theta, "length, width -> length width")
            m_theta = repeat(m_theta, "length width -> length (width 2)")
            self.register_buffer("cos", torch.cos(m_theta), persistent=False)
            self.register_buffer("sin", torch.sin(m_theta), persistent=False)

    @staticmethod
    def _rotate(x):
        x1, x2 = rearrange(x, "... (dim r) -> r ... dim", r=2)
        return rearrange([-x2, x1], "r ... dim -> ... (dim r)", r=2)  # noqa

    def forward(self, x, seq_id):
        self._init_freq(max_len=seq_id.max() + 1)
        rot_cos = self.cos[seq_id]
        rot_sin = self.sin[seq_id]
        return rot_cos * x + rot_sin * self._rotate(x)

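A minimal sketch of RotaryProjection on its own (illustrative shapes, not part of the commit); here x stands for a single head's query or key slice and seq_id holds the token positions that index the cached cos/sin tables:

    rope = RotaryProjection(proj_width=64, num_heads=8)   # proj_width must be even
    seq_id = torch.arange(16)                              # positions 0..15
    x = torch.randn(16, 64)
    x_rot = rope(x, seq_id)                                # same shape, pairs rotated by position
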
class QueryKeyProjection(nn.Module):
    def __init__(self, dim: int, num_heads: int, proj_layer, kwargs=None, partial_factor=None):
        super().__init__()
        if partial_factor is not None:
            assert (
                0.0 <= partial_factor[0] < partial_factor[1] <= 1.0
            ), f"got {partial_factor[0]}, {partial_factor[1]}"
        assert num_heads > 0 and dim % num_heads == 0

        self.head_dim = dim // num_heads
        self.partial_factor = partial_factor
        self.query_proj = proj_layer(
            proj_width=self.proj_width,
            num_heads=num_heads,
            **(kwargs or {}),
        )
        self.key_proj = self.query_proj

    @cached_property
    def proj_width(self) -> int:
        if self.partial_factor is None:
            return self.head_dim
        return int(self.head_dim * (self.partial_factor[1] - self.partial_factor[0]))

    @cached_property
    def split_sizes(self):
        if self.partial_factor is None:
            return 0, self.head_dim, 0
        return (
            int(self.partial_factor[0] * self.head_dim),
            self.proj_width,
            int((1.0 - self.partial_factor[1]) * self.head_dim),
        )

    def forward(self, query, key, query_id, kv_id):
        if self.partial_factor is not None:
            queries = list(query.split(self.split_sizes, dim=-1))
            keys = list(key.split(self.split_sizes, dim=-1))
            queries[1] = self.query_proj(queries[1], seq_id=query_id)
            keys[1] = self.key_proj(keys[1], seq_id=kv_id)
            query = torch.cat(queries, dim=-1)
            key = torch.cat(keys, dim=-1)
        else:
            query = self.query_proj(query, seq_id=query_id)
            key = self.key_proj(key, seq_id=kv_id)
        return query, key
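Finally, a sketch of how QueryKeyProjection might be driven (hypothetical hyperparameters, not taken from this commit): with partial_factor=(0.0, 0.5), rotary encoding is applied only to the first half of each head dimension, and query and key share the same projection instance:

    qk_proj = QueryKeyProjection(
        dim=512,
        num_heads=8,
        proj_layer=RotaryProjection,
        kwargs={"max_len": 512},
        partial_factor=(0.0, 0.5),
    )
    seq_id = torch.arange(24)
    q = torch.randn(24, 64)    # one head's queries, head_dim = 512 // 8
    k = torch.randn(24, 64)
    q_rot, k_rot = qk_proj(q, k, query_id=seq_id, kv_id=seq_id)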
