Skip to content

Commit 428c9af

Browse files
committed
expose model version api
1 parent 4358e70 commit 428c9af

4 files changed

Lines changed: 54 additions & 0 deletions

File tree

python/rwkv_cpp/rwkv_cpp_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ def __init__(
7070

7171
self._valid: bool = True
7272

73+
@property
74+
def arch_version_major(self) -> int:
75+
return self._library.rwkv_get_arch_version_major(self._ctx)
76+
77+
@property
78+
def arch_version_minor(self) -> int:
79+
return self._library.rwkv_get_arch_version_minor(self._ctx)
80+
7381
@property
7482
def n_vocab(self) -> int:
7583
return self._library.rwkv_get_n_vocab(self._ctx)

python/rwkv_cpp/rwkv_cpp_shared_library.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ def __init__(self, shared_library_path: str) -> None:
7979
]
8080
self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool
8181

82+
self.library.rwkv_get_arch_version_major.argtypes = [ctypes.c_void_p]
83+
self.library.rwkv_get_arch_version_major.restype = ctypes.c_uint32
84+
85+
self.library.rwkv_get_arch_version_minor.argtypes = [ctypes.c_void_p]
86+
self.library.rwkv_get_arch_version_minor.restype = ctypes.c_uint32
87+
8288
self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
8389
self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t
8490

@@ -261,6 +267,30 @@ def rwkv_eval_sequence_in_chunks(
261267
):
262268
raise ValueError('rwkv_eval_sequence_in_chunks failed, check stderr')
263269

270+
def rwkv_get_arch_version_major(self, ctx: RWKVContext) -> int:
271+
"""
272+
Returns the major version used by the given model.
273+
274+
Parameters
275+
----------
276+
ctx : RWKVContext
277+
RWKV context obtained from rwkv_init_from_file.
278+
"""
279+
280+
return self.library.rwkv_get_arch_version_major(ctx.ptr)
281+
282+
def rwkv_get_arch_version_minor(self, ctx: RWKVContext) -> int:
283+
"""
284+
Returns the minor version used by the given model.
285+
286+
Parameters
287+
----------
288+
ctx : RWKVContext
289+
RWKV context obtained from rwkv_init_from_file.
290+
"""
291+
292+
return self.library.rwkv_get_arch_version_minor(ctx.ptr)
293+
264294
def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
265295
"""
266296
Returns the number of tokens in the given model's vocabulary.

rwkv.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,16 @@ extern "C" RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct r
152152
return rwkv_get_logits_len(ctx);
153153
}
154154

155+
// API function.
156+
size_t rwkv_get_arch_version_major(const struct rwkv_context * ctx) {
157+
return (size_t) ctx->model->arch_version_major;
158+
}
159+
160+
// API function.
161+
size_t rwkv_get_arch_version_minor(const struct rwkv_context * ctx) {
162+
return (size_t) ctx->model->arch_version_minor;
163+
}
164+
155165
// API function.
156166
size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) {
157167
return (size_t) ctx->model->header.n_vocab;

rwkv.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,12 @@ extern "C" {
172172
float * logits_out
173173
);
174174

175+
// Returns the major version used by the given model.
176+
RWKV_API size_t rwkv_get_arch_version_major(const struct rwkv_context * ctx);
177+
178+
// Returns the minor version used by the given model.
179+
RWKV_API size_t rwkv_get_arch_version_minor(const struct rwkv_context * ctx);
180+
175181
// Returns the number of tokens in the given model's vocabulary.
176182
// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
177183
RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx);

0 commit comments

Comments
 (0)