Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/auto_scheduler/auto_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ TVM_REGISTER_GLOBAL("auto_scheduler.TuningOptions")

TVM_REGISTER_GLOBAL("auto_scheduler.AutoSchedule")
.set_body_typed([](SearchPolicy search_policy, TuningOptions tuning_options) {
te::Schedule sch;
Array<te::Tensor> return_tensors;
std::tie(sch, return_tensors) = AutoSchedule(search_policy, tuning_options);
auto [sch, return_tensors] = AutoSchedule(search_policy, tuning_options);
return Array<ObjectRef>{sch, return_tensors};
});
} // namespace auto_scheduler
Expand Down
17 changes: 6 additions & 11 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1325,10 +1325,9 @@ State ComputeDAG::InferBound(const State& state) const {

Array<te::Stage> stages;
StageToAxesMap stage_to_axes;
te::Schedule sch;
Array<te::Tensor> tensors;
// Replay steps to tvm::Schedule
std::tie(sch, tensors) = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes);
auto [sch, tensors] = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes);
(void)tensors; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
sch = sch.normalize_for_feature_extraction();
// Get bound information from TVM schedule
Map<IterVar, Range> bounds = te::InferBound(sch);
Expand Down Expand Up @@ -1382,9 +1381,8 @@ Array<State> ComputeDAG::InferBound(const Array<State>& states) const {
}

ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array<Step>& transform_steps) const {
te::Schedule sch;
Array<te::Tensor> old_tensors;
std::tie(sch, old_tensors) = ApplySteps(transform_steps);
auto [sch, old_tensors] = ApplySteps(transform_steps);
(void)old_tensors; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
return ComputeDAG(sch);
}

Expand Down Expand Up @@ -1481,11 +1479,8 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG")

TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGApplyStepsFromState")
.set_body_typed([](const ComputeDAG& dag, const State& state, int layout_rewrite) {
te::Schedule sch;
Array<te::Tensor> return_tensors;
std::tie(sch, return_tensors) =
dag.ApplySteps(state->transform_steps, nullptr, nullptr,
static_cast<LayoutRewriteOption>(layout_rewrite));
auto [sch, return_tensors] = dag.ApplySteps(state->transform_steps, nullptr, nullptr,
static_cast<LayoutRewriteOption>(layout_rewrite));
return Array<ObjectRef>{sch, return_tensors};
});

Expand Down
9 changes: 2 additions & 7 deletions src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -952,9 +952,7 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
unique_lines = std::max(unique_lines, 1.0f);
}

ReuseType reuse_type;
float reuse_dis_iter, reuse_dis_bytes, reuse_ct;
std::tie(reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct) =
auto [reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct] =
ComputeReuse(t, acc.indices, for_loop_stack_, for_touch_regions_, ana_);

acc_feas.emplace_back();
Expand Down Expand Up @@ -1356,10 +1354,7 @@ void GetPerStoreFeatureName(int max_n_bufs, std::vector<std::string>* ret) {

void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, int max_n_bufs,
std::vector<float>* feature, std::atomic<int>* error_ct) {
te::Schedule sch;
Array<te::Tensor> tensors;

std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps);
auto [sch, tensors] = task->compute_dag.ApplySteps(state->transform_steps);

// When inlining, replace const matrices with const values.
// Produces wrong IR, but good enough for feature extraction, and
Expand Down
4 changes: 1 addition & 3 deletions src/auto_scheduler/search_policy/search_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyRunCallbacks")

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyContinueSearchOneRound")
.set_body_typed([](SearchPolicy policy, int num_measure, ProgramMeasurer measurer) {
Array<MeasureInput> inputs;
Array<MeasureResult> results;
std::tie(inputs, results) = policy->ContinueSearchOneRound(num_measure, measurer);
auto [inputs, results] = policy->ContinueSearchOneRound(num_measure, measurer);
return Array<ObjectRef>{inputs, results};
});

Expand Down
3 changes: 1 addition & 2 deletions src/auto_scheduler/search_policy/sketch_policy_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,7 @@ SketchGenerationRule::ConditionKind RuleCrossThreadReduction::MeetCondition(
const auto& op = state->stages[stage_id]->op;
if (op->IsInstance<te::ComputeOpNode>()) {
// Compute the product of lengths of all space iters and all reduce iters
int cum_space_len, cum_reduce_len;
std::tie(cum_space_len, cum_reduce_len) =
auto [cum_space_len, cum_reduce_len] =
GetCumulativeSpaceAndReductionLength(state->stages[stage_id]);

if (NeedsMultilevelTiling(policy.search_task, state, stage_id)) {
Expand Down
5 changes: 1 addition & 4 deletions src/ir/instrument.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,7 @@ String RenderPassProfiles() {
os << std::fixed;

while (profiles.size() > 0) {
size_t depth;
PassProfile::Duration parent_duration;
PassProfile* profile;
std::tie(depth, parent_duration, profile) = profiles.top();
auto [depth, parent_duration, profile] = profiles.top();
profiles.pop();

// indent depth
Expand Down
4 changes: 1 addition & 3 deletions src/meta_schedule/database/json_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ class JSONDatabaseNode : public DatabaseNode {

Workload CommitWorkload(const IRModule& mod) {
// Try to insert `mod` into `workloads_`
decltype(this->workloads2idx_)::iterator it;
bool inserted = false;
std::tie(it, inserted) =
auto [it, inserted] =
this->workloads2idx_.emplace(Workload(mod, tvm::StructuralHash()(mod)), -1);
Workload workload = it->first;
// If `mod` is new in `workloads2idx_`, append it to the workload file
Expand Down
4 changes: 1 addition & 3 deletions src/meta_schedule/mutator/mutate_compute_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ std::vector<MutateComputeLocationNode::Candidate> MutateComputeLocationNode::Fin
int old_decision = Downcast<Integer>(decision)->value;

// Step 2. Collect all the compute_at locations.
Array<tir::StmtSRef> location_srefs;
std::vector<int> location_indices;
std::tie(location_srefs, location_indices) = CollectComputeLocation(sch->state(), block_sref);
auto [location_srefs, location_indices] = CollectComputeLocation(sch->state(), block_sref);
// Step 3. Remove the old decision.
auto it = std::find(location_indices.begin(), location_indices.end(), old_decision);
if (it != location_indices.end()) {
Expand Down
6 changes: 1 addition & 5 deletions src/meta_schedule/schedule_rule/cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,11 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
// Step 2. Check the opportunity for block fusion. We say "fusible", if we can compute-at the
// block to its consumers. We want to fuse as much as possible because it results in
// significantly faster schedule.
bool fusible = false;
// `target_loop` is the loop position where the input block will be computed at.
tir::LoopRV target_loop{nullptr};
// `target_block` is the consumer block that we want to compute-at the input block to.
tir::BlockRV target_block{nullptr};
// `tgt_block_innermost_loop` is the innermost loop outside the target block.
tir::LoopRV tgt_block_innermost_loop{nullptr};

std::tie(fusible, target_loop, target_block, tgt_block_innermost_loop) =
auto [fusible, target_loop, target_block, tgt_block_innermost_loop] =
GetComputeTargetLoopAndBlock(tmp_sch, block_rv);

// Step 3. Try block fusion.
Expand Down
4 changes: 1 addition & 3 deletions src/meta_schedule/space_generator/post_order_apply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
result.clear();
while (!stack.empty()) {
// get the stack.top()
tir::Schedule sch;
Array<tir::BlockRV> blocks;
std::tie(sch, blocks) = stack.back();
auto [sch, blocks] = stack.back();
stack.pop_back();
// if all blocks are visited
if (blocks.empty()) {
Expand Down
12 changes: 3 additions & 9 deletions src/relay/collage/partition_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ std::vector<CandidatePartition> DFPatternPartitionRuleNode::AllCandidates(
continue;
}
IndexSet inside = MatcherToIndexSet(matcher);
OpPatternKind kind;
String label;
std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside);
auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside);
SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label));
String rule_name = rule_name_.empty() ? sub_graph->label_ : rule_name_;
CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec);
Expand Down Expand Up @@ -256,9 +254,7 @@ std::vector<CandidatePartition> OpCallByKindPartitionRuleNode::AllCandidates(
auto node = dataflow_graph.index_to_node(index);
Expr sub_expr = node->ref();
if (sub_expr->IsInstance<CallNode>()) {
OpPatternKind kind;
String label;
std::tie(kind, label) = SubExprKindAndLabel(sub_expr);
auto [kind, label] = SubExprKindAndLabel(sub_expr);
if (kind <= kOutEWiseFusable) {
IndexSet inside(dataflow_graph.size(), {index});
SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label));
Expand Down Expand Up @@ -404,9 +400,7 @@ std::vector<CandidatePartition> HostPartitionRuleNode::AllCandidates(
continue;
}
IndexSet inside(dataflow_graph.size(), {index});
OpPatternKind kind;
String label;
std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside);
auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside);
SubGraph sub_graph(dataflow_graph, std::move(inside), kind, label);
String rule_name = NestLabels(rule_name_, sub_graph->label_);
// We'll a zero cost for the candidate since we'll never want to actually estimate the cost
Expand Down
8 changes: 2 additions & 6 deletions src/relay/collage/sub_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,7 @@ std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph&
bool first = true;
OpPatternKind max_kind = kElemWise;
for (PostDfsIndex index : inside) {
OpPatternKind sub_kind;
std::string sub_label;
std::tie(sub_kind, sub_label) = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref());
auto [sub_kind, sub_label] = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref());
if (!sub_label.empty()) {
if (first) {
first = false;
Expand Down Expand Up @@ -995,9 +993,7 @@ transform::Pass PartitionForTesting(Integer max_exits, Bool allow_taps, String c
// Build the overall sub-graph, which will include any "Composite" functions as
// well as any nodes without a label.
IndexSet inside(dataflow_graph.size(), node_indexes);
OpPatternKind kind;
String label;
std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside);
auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside);
SubGraph sub_graph(dataflow_graph, inside, kind, label, std::move(nested_sub_graphs));

// Push the overall sub-graph into the final "Compiler" function.
Expand Down
4 changes: 2 additions & 2 deletions src/relay/qnn/op/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -722,9 +722,9 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
<< "qnn.conv2d supports only OIHW/HWIO/HWOI/OHWI kernel data layout.";
ICHECK(param->kernel_size.defined()) << "qnn.conv2d requires kernel size to be specified.";

int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier;
std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) =
auto [batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier] =
GetWorkload(arg_types, param);
(void)batch_size; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767

// zero points are allowed to be non-scalar. Let's check if that's the case.
bool dynamic_zp = false;
Expand Down
6 changes: 2 additions & 4 deletions src/relay/qnn/op/leaky_relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,11 @@ Expr QnnLeakyReluCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
output_zero_point, input_shape);

// alpha * Q_i'
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(alpha);
auto [fixed_point_multiplier, shift] = GetFixedPointMultiplierShift(alpha);
auto prod = FixedPointMultiply(requantized_expr, fixed_point_multiplier, shift);

// (1 - alpha) * zp_o
int32_t fixed_point_multiplier_z, shift_z;
std::tie(fixed_point_multiplier_z, shift_z) = GetFixedPointMultiplierShift(1 - alpha);
auto [fixed_point_multiplier_z, shift_z] = GetFixedPointMultiplierShift(1 - alpha);
auto scaled_z = FixedPointMultiply(output_zero_point, fixed_point_multiplier_z, shift_z);

// alpha * Q_i' + (1 - alpha) * zp_o
Expand Down
3 changes: 1 addition & 2 deletions src/relay/qnn/op/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,7 @@ Expr RequantizeLowerInt(const Expr& input_tensor, const Expr& input_scale,
static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
// Skip if input and output scales are same.
if (!IsEqualScalar(input_scale, output_scale)) {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
auto [fixed_point_multiplier, shift] = GetFixedPointMultiplierShift(double_multiplier);

const bool is_upward_rounding = (param->rounding == "UPWARD");

Expand Down
6 changes: 2 additions & 4 deletions src/relay/qnn/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
tensor = Cast(tensor, hp_dtype);

// 1) Calculating the integer multiplier and integer shift
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(multiplier);
auto [fixed_point_multiplier, shift] = GetFixedPointMultiplierShift(multiplier);
int left_shift = shift > 0 ? shift : 0;
int right_shift = shift > 0 ? 0 : -shift;

Expand Down Expand Up @@ -128,8 +127,7 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
std::vector<int32_t> fixed_pt_multipliers, lshifts, rshifts;
bool is_lshift_required = false;
for (auto multiplier : multipliers) {
int32_t fixed_pt_multiplier, shift;
std::tie(fixed_pt_multiplier, shift) = GetFixedPointMultiplierShift(multiplier);
auto [fixed_pt_multiplier, shift] = GetFixedPointMultiplierShift(multiplier);
int lshift = shift > 0 ? shift : 0;
int rshift = shift > 0 ? 0 : -shift;
fixed_pt_multipliers.push_back(fixed_pt_multiplier);
Expand Down
6 changes: 2 additions & 4 deletions src/relay/quantize/realize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
return Multiply(data, MakeConstantScalar(dtype, factor));
} else {
if (cfg->rounding == "UPWARD") {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = qnn::GetFixedPointMultiplierShift(factor);
auto [fixed_point_multiplier, shift] = qnn::GetFixedPointMultiplierShift(factor);
data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
} else {
data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape);
Expand Down Expand Up @@ -135,8 +134,7 @@ Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob
} else {
data = Cast(data, DataType::Int(64));
if (cfg->rounding == "UPWARD") {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) =
auto [fixed_point_multiplier, shift] =
qnn::GetFixedPointMultiplierShift(idom_scale_imm / odom_scale_imm);
data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
} else {
Expand Down
4 changes: 1 addition & 3 deletions src/relay/transforms/combine_parallel_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
Call MakeCombinedOp(const Group& branches) {
const Op& conv2d = Op::Get("nn.conv2d");
Expr data = branches[0][0]->args[0];
Expr new_weight;
IndexExpr new_channels;
std::tie(new_weight, new_channels) = TransformWeight(branches);
auto [new_weight, new_channels] = TransformWeight(branches);

const CallNode* group_root = branches[0][0];
const auto* attrs = group_root->attrs.as<Conv2DAttrs>();
Expand Down
4 changes: 1 addition & 3 deletions src/relay/transforms/combine_parallel_dense.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,8 @@ class ParallelDenseToDenseCombiner : public ParallelOpCombiner {
Call MakeCombinedOp(const Group& branches) {
const Op& dense_op = Op::Get("nn.dense");
Expr input = branches[0][0]->args[0];
Expr new_weight;
IndexExpr new_output_dims;
// concat all weights into one
std::tie(new_weight, new_output_dims) = TransformWeight(branches);
auto [new_weight, new_output_dims] = TransformWeight(branches);
const auto* origin_attrs = branches[0][0]->attrs.as<DenseAttrs>();
ICHECK(origin_attrs);
const auto dense_attrs = make_object<DenseAttrs>();
Expand Down
4 changes: 1 addition & 3 deletions src/runtime/graph_executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -674,9 +674,7 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name,
});
} else if (name == "get_input_info") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
GraphExecutor::ShapeInfo shape_info;
GraphExecutor::DtypeInfo dtype_info;
std::tie(shape_info, dtype_info) = this->GetInputInfo();
auto [shape_info, dtype_info] = this->GetInputInfo();
Map<String, ObjectRef> input_info;
input_info.Set("shape", shape_info);
input_info.Set("dtype", dtype_info);
Expand Down
12 changes: 4 additions & 8 deletions src/target/source/ptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,7 @@ class Replacer {
}
std::string rewrite(std::string str) {
for (auto&& rule : _rules) {
std::string pattern, replacement;
std::tie(pattern, replacement) = rule;
auto [pattern, replacement] = rule;
size_t len = pattern.size();
size_t new_len = replacement.size();
size_t pos = str.find(pattern);
Expand Down Expand Up @@ -532,8 +531,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
dtype_c = ptx::DTypeFromString(C_dtype);
ptx::LayoutType layout_a = ptx::LayoutTypeFromString(A_layout),
layout_b = ptx::LayoutTypeFromString(B_layout);
int m, n, k;
std::tie(m, n, k) = ptx::ParseMMAShape(shape);
auto [m, n, k] = ptx::ParseMMAShape(shape);
CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, dtype_c, bit_op, sparse,
saturate);
std::string asm_code = R"(
Expand All @@ -545,8 +543,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
: {inputs});
}
)";
std::string templates_str, inputs_str, outputs_str;
std::tie(templates_str, inputs_str, outputs_str) =
auto [templates_str, inputs_str, outputs_str] =
GetMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse);

// replace patterns
Expand Down Expand Up @@ -622,8 +619,7 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
);
}
)";
std::string templates_str, outputs_str;
std::tie(templates_str, outputs_str) = GetLoadMatrixOperands(num, local_ptr, local_elem_offset);
auto [templates_str, outputs_str] = GetLoadMatrixOperands(num, local_ptr, local_elem_offset);

Replacer replacer;
replacer.register_rule("{.shape}", ".m8n8");
Expand Down
10 changes: 4 additions & 6 deletions src/te/autodiff/ad_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1183,21 +1183,19 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A
return RemoveJacobianAndLiftNonzeroCondImpl(new_red, axis, vranges);
}

PrimExpr new_outer_cond, new_reduce_cond;
Array<PrimExpr> new_source = red->source;

// Partially lift conditions from the reduce condition
std::tie(new_outer_cond, new_reduce_cond) =
auto [new_outer_cond, new_reduce_cond] =
LiftConditionsThroughReduction(red->condition, red->axis, axis);

// If it's not sum then we haven't yet lifted nonzeroness cond from the source
if (!is_sum) {
PrimExpr outer_nz_cond, nz_cond, nz_source;
auto nz = NonzeronessCondition(red->source[red->value_index]);
// Append conditions from the reduction
nz_cond = new_reduce_cond && nz.cond;
nz_source = nz.value;
std::tie(outer_nz_cond, nz_cond) = LiftConditionsThroughReduction(nz_cond, red->axis, axis);
PrimExpr nz_source = nz.value;
auto [outer_nz_cond, nz_cond] =
LiftConditionsThroughReduction(new_reduce_cond && nz.cond, red->axis, axis);
new_outer_cond = new_outer_cond && outer_nz_cond;
new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype())));
}
Expand Down
Loading