Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1725,6 +1725,20 @@ def _squeeze(self, node: fx.Node) -> relax.Var:
# Support both "dim" and "dims" parameters
if dim is None:
dim = node.kwargs.get("dims", None)

# If dims is a list, filter out axes where dimension is not 1
# This is needed because PyTorch decomposition may pass all axes
if isinstance(dim, (list, tuple)) and len(dim) > 0:
shape = self.shape_of(x)
# Filter to only include axes where the dimension is 1
valid_dims = []
for d in dim:
axis = d if d >= 0 else len(shape) + d
if axis < len(shape) and shape[axis] == 1:
valid_dims.append(d)
# If no valid dims, use None to squeeze all size-1 dimensions
dim = valid_dims if valid_dims else None

return self.block_builder.emit(relax.op.squeeze(x, dim))

def _stack(self, node: fx.Node) -> relax.Var:
Expand Down
20 changes: 16 additions & 4 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,11 +701,23 @@ def _select(self, node: fx.Node) -> relax.Var:
return self.block_builder.emit(relax.op.take(x, index, dim))

def _slice(self, node: fx.Node) -> relax.Var:
import sys

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For style consistency and to avoid potential repeated import overhead, import statements should be at the top of the file. Please move import sys to the module's top-level imports.


x = self.env[node.args[0]]
axes = [node.args[1]]
begin = [node.args[2]]
end = [node.args[3]]
stride = [node.args[4] if len(node.args) > 4 else 1]
dim = node.args[1] if len(node.args) > 1 else 0
start = node.args[2] if len(node.args) > 2 else None
end_val = node.args[3] if len(node.args) > 3 else None
step = node.args[4] if len(node.args) > 4 else 1

if start is None:
start = 0
if end_val is None:
end_val = sys.maxsize

axes = [dim]
begin = [start]
end = [end_val]
stride = [step]
return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))

def _unflatten(self, node: fx.Node) -> relax.Var:
Expand Down
Loading
Loading