Skip to content

Commit e3e24b6

Browse files
authored
[ML] Harden pytorch_inference with TorchScript model graph validation (#2936) (#2987)
Add a static TorchScript graph validation layer that rejects models containing operations not observed in supported transformer architectures. This reduces the attack surface by ensuring only known-safe operation sets are permitted, complementing the existing Sandbox2/seccomp defenses. Backports #2936
1 parent 65d677a commit e3e24b6

37 files changed

+2790
-17
lines changed

bin/pytorch_inference/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ ml_add_executable(pytorch_inference
3535
CBufferedIStreamAdapter.cc
3636
CCmdLineParser.cc
3737
CCommandParser.cc
38+
CModelGraphValidator.cc
3839
CResultWriter.cc
40+
CSupportedOperations.cc
3941
CThreadSettings.cc
4042
)
4143

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0 and the following additional limitation. Functionality enabled by the
5+
* files subject to the Elastic License 2.0 may only be used in production when
6+
* invoked by an Elasticsearch process with a license key installed that permits
7+
* use of machine learning features. You may not use this file except in
8+
* compliance with the Elastic License 2.0 and the foregoing additional
9+
* limitation.
10+
*/
11+
12+
#include "CModelGraphValidator.h"
13+
14+
#include "CSupportedOperations.h"
15+
16+
#include <core/CLogger.h>
17+
18+
#include <torch/csrc/jit/passes/inliner.h>
19+
20+
#include <algorithm>
21+
22+
namespace ml {
23+
namespace torch {
24+
25+
// Validate the operations found in the graphs of the module's methods
// against CSupportedOperations::ALLOWED_OPERATIONS and
// CSupportedOperations::FORBIDDEN_OPERATIONS.
//
// The returned SResult always carries the number of nodes counted. If that
// number exceeds MAX_NODE_COUNT the model is rejected straight away and no
// forbidden/unrecognised operations are reported.
CModelGraphValidator::SResult CModelGraphValidator::validate(const ::torch::jit::Module& module) {

    TStringSet distinctOps;
    std::size_t totalNodes{0};
    collectModuleOps(module, distinctOps, totalNodes);

    const bool withinLimit{totalNodes <= MAX_NODE_COUNT};
    if (withinLimit == false) {
        LOG_ERROR(<< "Model graph is too large: " << totalNodes
                  << " nodes exceeds limit of " << MAX_NODE_COUNT);
        // Collection stopped early, so the set of observed operations is
        // incomplete: report only the size violation.
        SResult rejected;
        rejected.s_IsValid = false;
        rejected.s_NodeCount = totalNodes;
        return rejected;
    }

    LOG_DEBUG(<< "Model graph contains " << distinctOps.size()
              << " distinct operations across " << totalNodes << " nodes");
    for (const auto& opName : distinctOps) {
        LOG_DEBUG(<< " observed op: " << opName);
    }

    // Delegate the actual matching to the set-based overload, then attach
    // the node count which that overload knows nothing about.
    SResult outcome = validate(distinctOps, CSupportedOperations::ALLOWED_OPERATIONS,
                               CSupportedOperations::FORBIDDEN_OPERATIONS);
    outcome.s_NodeCount = totalNodes;
    return outcome;
}
48+
49+
// Match a set of observed operation names against explicit allowed and
// forbidden sets.
//
// observedOps  - the distinct operation names found in a model graph.
// allowedOps   - names that are acceptable.
// forbiddenOps - names that must never appear.
//
// Returns an SResult whose s_IsValid is true only if no observed operation
// is forbidden and every observed operation is allowed. Both name lists in
// the result are sorted. s_NodeCount is left at its default.
CModelGraphValidator::SResult
CModelGraphValidator::validate(const TStringSet& observedOps,
                               const std::unordered_set<std::string_view>& allowedOps,
                               const std::unordered_set<std::string_view>& forbiddenOps) {

    SResult verdict;

    // First pass: gather every forbidden operation that is present.
    for (const auto& opName : observedOps) {
        if (forbiddenOps.contains(opName)) {
            verdict.s_ForbiddenOps.push_back(opName);
        }
    }

    // Second pass: only when nothing forbidden was found do we look for
    // operations that are missing from the allowlist. A model that already
    // contains a known-dangerous operation is rejected regardless, so the
    // extra scan would be wasted work.
    if (verdict.s_ForbiddenOps.empty()) {
        for (const auto& opName : observedOps) {
            if (allowedOps.contains(opName)) {
                continue;
            }
            verdict.s_UnrecognisedOps.push_back(opName);
        }
    }

    // Any entry in either list invalidates the model.
    verdict.s_IsValid = verdict.s_ForbiddenOps.empty() &&
                        verdict.s_UnrecognisedOps.empty();

    // The input is an unordered set; sort so that the output is stable.
    std::sort(verdict.s_ForbiddenOps.begin(), verdict.s_ForbiddenOps.end());
    std::sort(verdict.s_UnrecognisedOps.begin(), verdict.s_UnrecognisedOps.end());

    return verdict;
}
80+
81+
// Collect the qualified name of every node in the given block and in all
// blocks nested beneath it.
//
// block     - the block to start from.
// ops       - receives the distinct operation names (existing contents are
//             kept, so the same set can be shared across several calls).
// nodeCount - running total of nodes counted; incremented for each node
//             visited here. As soon as it exceeds MAX_NODE_COUNT the walk
//             stops, leaving nodeCount above the limit so that callers can
//             detect the truncation.
//
// The graph comes from an untrusted model, so nested blocks are processed
// from an explicit worklist rather than by recursion: the depth of block
// nesting therefore cannot exhaust the call stack. The order in which nodes
// are visited is not significant because names are gathered into a set.
void CModelGraphValidator::collectBlockOps(const ::torch::jit::Block& block,
                                           TStringSet& ops,
                                           std::size_t& nodeCount) {
    std::vector<const ::torch::jit::Block*> pendingBlocks{&block};

    while (pendingBlocks.empty() == false) {
        const ::torch::jit::Block* current{pendingBlocks.back()};
        pendingBlocks.pop_back();

        for (const auto* node : current->nodes()) {
            if (++nodeCount > MAX_NODE_COUNT) {
                // Give up immediately; the caller rejects oversized graphs.
                return;
            }
            ops.emplace(node->kind().toQualString());

            // Defer nested blocks (e.g. the bodies of control flow nodes)
            // to later iterations of the outer loop.
            for (const auto* subBlock : node->blocks()) {
                pendingBlocks.push_back(subBlock);
            }
        }
    }
}
97+
98+
// Collect the operations used by every method of the given module.
//
// module    - the module whose methods are inspected.
// ops       - receives the distinct operation names.
// nodeCount - running total of graph nodes counted; collection stops as
//             soon as it exceeds MAX_NODE_COUNT.
void CModelGraphValidator::collectModuleOps(const ::torch::jit::Module& module,
                                            TStringSet& ops,
                                            std::size_t& nodeCount) {
    for (const auto& method : module.get_methods()) {
        // Work on a copy so the module's own graph is left untouched, and
        // inline every call in that copy so that operations hidden behind
        // prim::CallMethod are surfaced. Any prim::CallMethod that is still
        // present afterwards is a call that could not be resolved
        // statically; it is collected like any other operation and is
        // rejected by validation because it is listed in
        // CSupportedOperations::FORBIDDEN_OPERATIONS.
        auto inlinedGraph = method.graph()->copy();
        ::torch::jit::Inline(*inlinedGraph);

        collectBlockOps(*inlinedGraph->block(), ops, nodeCount);

        const bool limitExceeded{nodeCount > MAX_NODE_COUNT};
        if (limitExceeded) {
            break;
        }
    }
}
114+
}
115+
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0 and the following additional limitation. Functionality enabled by the
5+
* files subject to the Elastic License 2.0 may only be used in production when
6+
* invoked by an Elasticsearch process with a license key installed that permits
7+
* use of machine learning features. You may not use this file except in
8+
* compliance with the Elastic License 2.0 and the foregoing additional
9+
* limitation.
10+
*/
11+
12+
#ifndef INCLUDED_ml_torch_CModelGraphValidator_h
13+
#define INCLUDED_ml_torch_CModelGraphValidator_h
14+
15+
#include <torch/script.h>
16+
17+
#include <string>
18+
#include <string_view>
19+
#include <unordered_set>
20+
#include <vector>
21+
22+
namespace ml {
23+
namespace torch {
24+
25+
//! \brief
26+
//! Validates TorchScript model computation graphs against a set of
27+
//! allowed operations.
28+
//!
29+
//! DESCRIPTION:\n
30+
//! Provides defense-in-depth by statically inspecting the TorchScript
31+
//! graph of a loaded model and rejecting any model that contains
32+
//! operations not present in the allowlist derived from supported
33+
//! transformer architectures.
34+
//!
35+
//! IMPLEMENTATION DECISIONS:\n
36+
//! For each method of the module passed in, the validation copies the
37+
//! method's graph, inlines all calls in the copy and then visits every
38+
//! node, including nodes in nested blocks, collecting each distinct
39+
//! operation name. If any collected operation is in the forbidden set
40+
//! the model is rejected and only the forbidden operations are reported.
41+
//! Otherwise any operation not in the allowed set is collected and
//! reported. Inlining is what brings operations buried in helper
//! methods or nested submodules into view.
//! NOTE(review): only the methods of the module passed in are used as
//! starting points, so a submodule method that is never called from one
//! of them does not appear to be inspected - confirm this is intended.
42+
//!
43+
class CModelGraphValidator {
44+
public:
45+
using TStringSet = std::unordered_set<std::string>;
46+
using TStringVec = std::vector<std::string>;
47+
48+
//! Upper bound on the number of graph nodes we are willing to inspect.
49+
//! Transformer models typically have O(10k) nodes after inlining; a
50+
//! limit of 1M provides generous headroom while preventing a
51+
//! pathologically large graph from consuming unbounded memory or CPU.
//! The count is cumulative across all methods of the module.
52+
static constexpr std::size_t MAX_NODE_COUNT{1000000};
53+
54+
//! Result of validating a model graph.
//!
//! The member order is relied upon by aggregate initialisation in the
//! implementation, so do not reorder the members.
55+
struct SResult {
56+
//! False if any forbidden or unrecognised operation was found, or if
//! the graph exceeded MAX_NODE_COUNT.
bool s_IsValid{true};
57+
//! Observed operations that are in the forbidden set, sorted.
TStringVec s_ForbiddenOps;
58+
//! Observed operations that are not in the allowed set, sorted. Left
//! empty when s_ForbiddenOps is non-empty, because the allowlist check
//! is skipped in that case.
TStringVec s_UnrecognisedOps;
59+
//! Number of graph nodes counted. Only filled in by the overload of
//! validate that takes a module; otherwise it stays 0.
std::size_t s_NodeCount{0};
60+
};
61+
62+
public:
63+
//! Validate the computation graph of the given module against the
64+
//! supported operation allowlist and forbidden list. Inspects the
65+
//! inlined graph of every method of \p module. A graph with more than
//! MAX_NODE_COUNT nodes is rejected without reporting any operations.
66+
static SResult validate(const ::torch::jit::Module& module);
67+
68+
//! Validate a pre-collected set of operation names. Useful for
69+
//! unit testing the matching logic without requiring a real model.
//!
//! \param observedOps The distinct operation names to check.
//! \param allowedOps Names that are acceptable.
//! \param forbiddenOps Names that must never appear; these are checked
//! first and, if any is present, \p allowedOps is not consulted.
70+
static SResult validate(const TStringSet& observedOps,
71+
const std::unordered_set<std::string_view>& allowedOps,
72+
const std::unordered_set<std::string_view>& forbiddenOps);
73+
74+
private:
75+
//! Collect all operation names from a block, including all nested
//! sub-blocks. \p nodeCount is incremented for every node visited and
//! collection stops once it exceeds MAX_NODE_COUNT.
76+
static void collectBlockOps(const ::torch::jit::Block& block,
77+
TStringSet& ops,
78+
std::size_t& nodeCount);
79+
80+
//! Inline all method calls and collect ops from the flattened graph.
81+
//! After inlining, prim::CallMethod should not appear; if it does,
82+
//! the call could not be resolved statically. It is then collected
83+
//! like any other operation and rejected by validate because it is
//! listed in CSupportedOperations::FORBIDDEN_OPERATIONS.
84+
static void collectModuleOps(const ::torch::jit::Module& module,
85+
TStringSet& ops,
86+
std::size_t& nodeCount);
87+
};
88+
}
89+
}
90+
91+
#endif // INCLUDED_ml_torch_CModelGraphValidator_h
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0 and the following additional limitation. Functionality enabled by the
5+
* files subject to the Elastic License 2.0 may only be used in production when
6+
* invoked by an Elasticsearch process with a license key installed that permits
7+
* use of machine learning features. You may not use this file except in
8+
* compliance with the Elastic License 2.0 and the foregoing additional
9+
* limitation.
10+
*/
11+
12+
#include "CSupportedOperations.h"
13+
14+
namespace ml {
15+
namespace torch {
16+
17+
using namespace std::string_view_literals;
18+
19+
// Operations that cause a model to be rejected whenever they are observed.
// CModelGraphValidator checks this set before it consults
// ALLOWED_OPERATIONS, so an entry here takes precedence over the allowlist.
const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERATIONS = {
20+
// Arbitrary memory access — enables heap scanning, address leaks, and
21+
// ROP chain construction.
// NOTE(review): this rationale clearly fits aten::as_strided; judging by
// their names aten::from_file and aten::save look like file system access
// rather than memory access — confirm and document their rationale
// separately if so.
22+
"aten::as_strided"sv,
23+
"aten::from_file"sv,
24+
"aten::save"sv,
25+
// After graph inlining, method and function calls should be resolved.
26+
// Their presence indicates an opaque call that cannot be validated.
27+
"prim::CallFunction"sv,
28+
"prim::CallMethod"sv,
29+
};
30+
31+
// Generated by dev-tools/extract_model_ops/extract_model_ops.py against PyTorch 2.7.1.
32+
// Reference models: bert-base-uncased, roberta-base, distilbert-base-uncased,
33+
// google/electra-small-discriminator, microsoft/mpnet-base,
34+
// microsoft/deberta-base, facebook/dpr-ctx_encoder-single-nq-base,
35+
// google/mobilebert-uncased, xlm-roberta-base, elastic/bge-m3,
36+
// elastic/distilbert-base-{cased,uncased}-finetuned-conll03-english,
37+
// elastic/eis-elser-v2, elastic/elser-v2, elastic/hugging-face-elser,
38+
// elastic/multilingual-e5-small-optimized, elastic/splade-v3,
39+
// elastic/test-elser-v2.
40+
// Additional ops from Elasticsearch integration test models
41+
// (PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT).
42+
// The allowlist used by CModelGraphValidator: any observed operation whose
// qualified name is not in this set is reported as unrecognised and makes
// the model invalid. Lookup is by the full "namespace::name" string.
// Entries are listed in ASCII order within each group (upper case before
// '_' before lower case); keep that order when adding entries so that
// regenerated lists diff cleanly.
const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATIONS = {
43+
// aten operations (core tensor computations)
44+
"aten::Int"sv,
45+
"aten::IntImplicit"sv,
46+
"aten::ScalarImplicit"sv,
47+
"aten::__and__"sv,
48+
"aten::abs"sv,
49+
"aten::add"sv,
50+
"aten::add_"sv,
51+
"aten::arange"sv,
52+
"aten::bitwise_not"sv,
53+
"aten::cat"sv,
54+
"aten::chunk"sv,
55+
"aten::clamp"sv,
56+
"aten::contiguous"sv,
57+
"aten::cumsum"sv,
58+
"aten::div"sv,
59+
"aten::div_"sv,
60+
"aten::dropout"sv,
61+
"aten::embedding"sv,
62+
"aten::expand"sv,
63+
"aten::full_like"sv,
64+
"aten::gather"sv,
65+
"aten::ge"sv,
66+
"aten::gelu"sv,
67+
"aten::hash"sv,
68+
"aten::index"sv,
69+
"aten::index_put_"sv,
70+
"aten::layer_norm"sv,
71+
"aten::len"sv,
72+
"aten::linear"sv,
73+
"aten::log"sv,
74+
"aten::lt"sv,
75+
"aten::manual_seed"sv,
76+
"aten::masked_fill"sv,
77+
"aten::matmul"sv,
78+
"aten::max"sv,
79+
"aten::mean"sv,
80+
"aten::min"sv,
81+
"aten::mul"sv,
82+
"aten::ne"sv,
83+
"aten::neg"sv,
84+
"aten::new_ones"sv,
85+
"aten::ones"sv,
86+
"aten::pad"sv,
87+
"aten::permute"sv,
88+
"aten::pow"sv,
89+
"aten::rand"sv,
90+
"aten::relu"sv,
91+
"aten::repeat"sv,
92+
"aten::reshape"sv,
93+
"aten::rsub"sv,
94+
"aten::scaled_dot_product_attention"sv,
95+
"aten::select"sv,
96+
"aten::size"sv,
97+
"aten::slice"sv,
98+
"aten::softmax"sv,
99+
"aten::sqrt"sv,
100+
"aten::squeeze"sv,
101+
"aten::str"sv,
102+
"aten::sub"sv,
103+
"aten::tanh"sv,
104+
"aten::tensor"sv,
105+
"aten::to"sv,
106+
"aten::transpose"sv,
107+
"aten::type_as"sv,
108+
"aten::unsqueeze"sv,
109+
"aten::view"sv,
110+
"aten::where"sv,
111+
"aten::zeros"sv,
112+
// prim operations (TorchScript graph infrastructure)
113+
"prim::Constant"sv,
114+
"prim::DictConstruct"sv,
115+
"prim::GetAttr"sv,
116+
"prim::If"sv,
117+
"prim::ListConstruct"sv,
118+
"prim::ListUnpack"sv,
119+
"prim::Loop"sv,
120+
"prim::NumToTensor"sv,
121+
"prim::TupleConstruct"sv,
122+
"prim::TupleUnpack"sv,
123+
"prim::device"sv,
124+
"prim::dtype"sv,
125+
"prim::max"sv,
126+
"prim::min"sv,
127+
};
128+
}
129+
}

0 commit comments

Comments
 (0)