Skip to content

[tosa] : Add option to enable/disable patterns selectively.#4485

Merged
sahas3 merged 2 commits into
llvm:mainfrom
sahas3:enablePatterns
Mar 6, 2026
Merged

[tosa] : Add option to enable/disable patterns selectively.#4485
sahas3 merged 2 commits into
llvm:mainfrom
sahas3:enablePatterns

Conversation

@sahas3

@sahas3 sahas3 commented Mar 4, 2026

Copy link
Copy Markdown
Member

Consider the source IR:

func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) -> !torch.vtensor<[1,512,1,1],f32> {
  %int7 = torch.constant.int 7
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %false = torch.constant.bool false
  %none = torch.constant.none
  %kernel = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
  %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
  %0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %false, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
  return %0 : !torch.vtensor<[1,512,1,1],f32>
}

When lowered through TOSA path we get

❯ torch-mlir-opt --convert-torch-to-tosa /tmp/torch.mlir --mlir-print-op-generic | mlir-opt -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,canonicalize,cse))" --allow-unregistered-dialect                                     
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>                                                                                                                                                                                                       
module {                                                                                                                                                                                                                                                      
  func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> {                                                                                                                                         
    %cst = arith.constant 4.900000e+01 : f32                                                                                                                                                                                                                  
    %cst_0 = arith.constant 0.000000e+00 : f32                                                                                                                                                                                                                
    %0 = "torch_c.to_builtin_tensor"(%arg0) : (!torch.vtensor<[1,512,7,7],f32>) -> tensor<1x512x7x7xf32>                                                                                                                                                      
    %1 = "torch.constant.int"() <{value = 7 : i64}> : () -> !torch.int                                                                                                                                                                                        
    %2 = "torch.constant.int"() <{value = 1 : i64}> : () -> !torch.int                                                                                                                                                                                        
    %3 = "torch.constant.int"() <{value = 0 : i64}> : () -> !torch.int                                                                                                                                                                                        
    %4 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool                                                                                                                                                                                        
    %5 = "torch.constant.none"() : () -> !torch.none                                                                                                                                                                                                          
    %6 = "torch.prim.ListConstruct"(%1, %1) : (!torch.int, !torch.int) -> !torch.list<int>                                                                                                                                                                    
    %7 = "torch.prim.ListConstruct"(%2, %2) : (!torch.int, !torch.int) -> !torch.list<int>                                                                                                                                                                    
    %8 = "torch.prim.ListConstruct"(%3, %3) : (!torch.int, !torch.int) -> !torch.list<int>                                                                                                                                                                    
    %9 = tensor.empty() : tensor<1x7x7x512xf32>                                                                                                                                                                                                               
    %transposed = linalg.transpose ins(%0 : tensor<1x512x7x7xf32>) outs(%9 : tensor<1x7x7x512xf32>) permutation = [0, 2, 3, 1]                                                                                                                                
    %10 = tensor.empty() : tensor<1x1x1x512xf32>                                                                                                                                                                                                              
    %11 = linalg.fill ins(%cst_0 : f32) outs(%10 : tensor<1x1x1x512xf32>) -> tensor<1x1x1x512xf32>                                                                                                                                                            
    %12 = tensor.empty() : tensor<7x7xf32>                                                                                                                                                                                                                    
    %13 = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%transposed, %12 : tensor<1x7x7x512xf32>, tensor<7x7xf32>) outs(%11 : tensor<1x1x1x512xf32>) -> tensor<1x1x1x512xf32>                        
    %14 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%13 : tensor<1x1x1x512xf32>) outs(%10 : tensor<1x1x1x512xf32>) {                                                               
    ^bb0(%in: f32, %out: f32):                                                                                                                                                                                                                                
      %17 = arith.divf %in, %cst : f32
      linalg.yield %17 : f32
    } -> tensor<1x1x1x512xf32>
    %15 = tensor.empty() : tensor<1x512x1x1xf32>
    %transposed_1 = linalg.transpose ins(%14 : tensor<1x1x1x512xf32>) outs(%15 : tensor<1x512x1x1xf32>) permutation = [0, 3, 1, 2] 
    %16 = "torch_c.from_builtin_tensor"(%transposed_1) : (tensor<1x512x1x1xf32>) -> !torch.vtensor<[1,512,1,1],f32>
    return %16 : !torch.vtensor<[1,512,1,1],f32>
  }
}

When lowered through linalg path we get:

❯ torch-mlir-opt --convert-torch-to-linalg /tmp/torch.mlir --mlir-print-op-generic | mlir-opt -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,canonicalize,cse))" --allow-unregistered-dialect
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
  func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> {
    %cst = arith.constant 4.900000e+01 : f32
    %cst_0 = arith.constant 0.000000e+00 : f32
    %0 = "torch_c.to_builtin_tensor"(%arg0) : (!torch.vtensor<[1,512,7,7],f32>) -> tensor<1x512x7x7xf32>
    %1 = "torch.constant.int"() <{value = 7 : i64}> : () -> !torch.int
    %2 = "torch.constant.int"() <{value = 1 : i64}> : () -> !torch.int
    %3 = "torch.constant.int"() <{value = 0 : i64}> : () -> !torch.int
    %4 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool
    %5 = "torch.constant.none"() : () -> !torch.none
    %6 = "torch.prim.ListConstruct"(%1, %1) : (!torch.int, !torch.int) -> !torch.list<int>
    %7 = "torch.prim.ListConstruct"(%2, %2) : (!torch.int, !torch.int) -> !torch.list<int>
    %8 = "torch.prim.ListConstruct"(%3, %3) : (!torch.int, !torch.int) -> !torch.list<int>
    %9 = tensor.empty() : tensor<1x512x1x1xf32>
    %10 = linalg.fill ins(%cst_0 : f32) outs(%9 : tensor<1x512x1x1xf32>) -> tensor<1x512x1x1xf32>
    %11 = tensor.empty() : tensor<7x7xf32>
    %12 = linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%0, %11 : tensor<1x512x7x7xf32>, tensor<7x7xf32>) outs(%10 : tensor<1x512x1x1xf32>) -> tensor<1x512x1x1xf32>
    %13 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%12 : tensor<1x512x1x1xf32>) outs(%9 : tensor<1x512x1x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %15 = arith.divf %in, %cst : f32
      linalg.yield %15 : f32
    } -> tensor<1x512x1x1xf32>
    %14 = "torch_c.from_builtin_tensor"(%13) : (tensor<1x512x1x1xf32>) -> !torch.vtensor<[1,512,1,1],f32>
    return %14 : !torch.vtensor<[1,512,1,1],f32>
  }
}

Because of layout mismatch between PyTorch (NCHW) and TOSA (NHWC), there will be two additional transpose operations in the TOSA path. This requires two additional buffers which leads to a problem for resource-constrained embedded HW which don't have enough memory.

This change adds an option to selectively enable/disable legalizations through the TOSA path, so that for the tosa_linalg path we can choose to not lower some ops (depending on the target HW) through TOSA and instead let it lower through the linalg path that runs after TOSA path.

@sahas3 sahas3 requested review from Lallapallooza and sjarus March 4, 2026 13:34
Comment thread lib/Conversion/TorchToTosa/TorchToTosa.cpp Outdated
Comment thread lib/Conversion/TorchToTosa/TorchToTosa.cpp Outdated
@sjarus

sjarus commented Mar 4, 2026

Copy link
Copy Markdown
Collaborator

@sahas3 I'm curious as to whether https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp can help you here. The original PR lists significant performance gains from this pass, since further optimized by @Hanumanth04 in llvm/llvm-project#148755 .

@sahas3

sahas3 commented Mar 4, 2026

Copy link
Copy Markdown
Member Author

Hi @sjarus, I think it's not possible to optimize this particular scenario in TOSA. Running the full Torch->TOSA pipeline produces:

❯ torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' /tmp/torch.mlir 
module {
  func.func @torch.aten.avg_pool2d$basic(%arg0: tensor<1x512x7x7xf32>) -> tensor<1x512x1x1xf32> {
    %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
    %1 = tosa.const_shape  {values = dense<[1, 512, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
    %2 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x7x7xf32>) -> tensor<1x7x7x512xf32>
    %3 = tosa.avg_pool2d %2, %0, %0 {acc_type = f32, kernel = array<i64: 7, 7>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x7x7x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1x512xf32>
    %4 = tosa.reshape %3, %1 : (tensor<1x1x1x512xf32>, !tosa.shape<4>) -> tensor<1x512x1x1xf32>
    return %4 : tensor<1x512x1x1xf32>
  }
}

Running the TosaReduceTranspose pass doesn't have any transpose ops that it can move around to reduce any further:

❯ torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' /tmp/torch.mlir | mlir-opt --tosa-reduce-transposes
module {
  func.func @torch.aten.avg_pool2d$basic(%arg0: tensor<1x512x7x7xf32>) -> tensor<1x512x1x1xf32> {
    %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
    %1 = tosa.const_shape  {values = dense<[1, 512, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
    %2 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x7x7xf32>) -> tensor<1x7x7x512xf32>
    %3 = tosa.avg_pool2d %2, %0, %0 {acc_type = f32, kernel = array<i64: 7, 7>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x7x7x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1x512xf32>
    %4 = tosa.reshape %3, %1 : (tensor<1x1x1x512xf32>, !tosa.shape<4>) -> tensor<1x512x1x1xf32>
    return %4 : tensor<1x512x1x1xf32>
  }
}

One possibility is to optimize transpose -> linalg.pooling_nhwc_sum -> transpose into linalg.pooling_nchw_sum but that has to be done at the linalg level and also it seems such an optimization will be quite tied with the specific pattern we are trying to match.

@sjarus

sjarus commented Mar 4, 2026

Copy link
Copy Markdown
Collaborator

One possibility is to optimize transpose -> linalg.pooling_nhwc_sum -> transpose into linalg.pooling_nchw_sum but that has to be done at the linalg level .

Yeah, your fundamental problem is that the source and target dialects - Torch and LinAlg - support NCHW but TOSA does not. The reduce_transposes pass does a great job of eliminating most of the transposes, but in smaller testcases they'll persist, and in that case you're probably going to have to implement a transformation at the linalg level to spot transpose->linalg.*_nhwc

@sahas3

sahas3 commented Mar 4, 2026

Copy link
Copy Markdown
Member Author

Yeah, your fundamental problem is that the source and target dialects - Torch and LinAlg - support NCHW but TOSA does not. The reduce_transposes pass does a great job of eliminating most of the transposes, but in smaller testcases they'll persist, and in that case you're probably going to have to implement a transformation at the linalg level to spot transpose->linalg.*_nhwc

Yes, exactly. This is only a problem for small models targeting small hardware where we want to minimize buffers needed as much as possible. We may have to write some optimizations at the linalg level in the long term but the change in this PR enables us to target such small hardware without many pattern specific linalg optimizations.

@sjarus

sjarus commented Mar 4, 2026

Copy link
Copy Markdown
Collaborator

This is also hardware dependent. Not all hardware can handle both NHWC and NCHW. Some of them make a call either way. The TOSA baking in of NHWC was such a call. The idea was that if underlying hardware was NCHW, the backend ought to identify the transpose->conv/pool pair and swap dims.

Of course, in future the hard restriction on NHWC may go away, but that's not currently the case.

@sahas3 sahas3 merged commit 56e635e into llvm:main Mar 6, 2026
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants