[FRONTEND] Add max_trials_per_task parameter to static_shape_tuning_pipeline#18159

Closed
ConvolutedDog wants to merge 3 commits into apache:main from ConvolutedDog:patch-2

Conversation

ConvolutedDog commented Jul 21, 2025

Contributor

This PR introduces the max_trials_per_task parameter to the static_shape_tuning_pipeline, giving users explicit control over the maximum number of trials allocated to each task during MetaSchedule tuning.

Currently, the number of tuning trials per task per iteration is determined by min(max_trials_per_iter=64, total_trials). This leads to potential issues when tuning on GPUs:

When the user-specified total_trials is:

  • less than (number_of_tasks * 64) but greater than max_trials_per_iter=64, or
  • less than max_trials_per_iter=64 itself,

some TIR functions will not be tuned at all and will fail to be bound to threads properly.
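
The arithmetic behind this can be illustrated with a toy model. The sketch below is not MetaSchedule's task scheduler: it assumes, purely for illustration, that tasks are visited once in order and that each receives min(per-task cap, remaining budget) trials. The function name allocate_trials and its single-pass policy are hypothetical.

```python
def allocate_trials(num_tasks, total_trials, max_trials_per_iter=64,
                    max_trials_per_task=None):
    """Toy single-pass allocation of a trial budget across tasks.

    Assumption (not TVM behavior): each task is visited once, in order,
    and receives min(per-task cap, remaining budget) trials.
    """
    # Without an explicit per-task cap, a task may take a full iteration batch.
    cap = max_trials_per_iter
    if max_trials_per_task is not None:
        cap = min(cap, max_trials_per_task)

    remaining = total_trials
    allocation = []
    for _ in range(num_tasks):
        granted = min(cap, remaining)
        allocation.append(granted)
        remaining -= granted
    return allocation


# Case 1: 64 < total_trials < num_tasks * 64, so the last tasks are starved.
print(allocate_trials(num_tasks=4, total_trials=100))
# Case 2: total_trials < 64, so only the first task gets any trials.
print(allocate_trials(num_tasks=4, total_trials=32))
# With a per-task cap, the same budget of 100 reaches every task.
print(allocate_trials(num_tasks=4, total_trials=100, max_trials_per_task=25))
```

Under these assumptions, the first two calls leave at least two of the four tasks with zero trials, while the capped call spreads the budget evenly. How the real scheduler orders and revisits tasks may differ.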

…ipeline

This commit introduces the max_trials_per_task parameter to the static_shape_tuning_pipeline, giving users explicit control over the maximum number of trials allocated to each task during MetaSchedule tuning.
Comment thread on docs/how_to/tutorials/e2e_opt_model.py (outdated)
gpu_out = vm["main"](gpu_data, *gpu_params).numpy()
gpu_out = vm["main"](gpu_data, *gpu_params)

gpu_out = np.array([gpu_out[0].numpy()[0][j] for j in range(1000)])

Member

What's the purpose of this change?

Contributor Author

I have simplified the numpy conversion.

Contributor Author

After the new FFI was introduced, vm["main"](...) returns an object of type <class 'tvm.ffi.container.Array'>, so a [0] index needs to be added.
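
One way to write calling code that tolerates both return shapes is sketched below. The helper first_output_to_numpy and the FakeTensor stand-in are hypothetical and exist only to make the example runnable without TVM; the only assumption about the real objects is the one stated in the comment above, namely that the result is either a tensor exposing .numpy() or an array-like container whose first element does.

```python
def first_output_to_numpy(result):
    """Convert a VM call result to a host array (hypothetical helper).

    If the result itself exposes .numpy(), it is treated as a single tensor.
    Otherwise it is treated as a container and its first element is converted.
    """
    if hasattr(result, "numpy"):
        return result.numpy()
    return result[0].numpy()


class FakeTensor:
    """Stand-in for a device tensor; not a TVM class."""

    def __init__(self, values):
        self._values = values

    def numpy(self):
        # A real tensor would return a numpy.ndarray; a list suffices here.
        return self._values


# Old-style result: a bare tensor.
print(first_output_to_numpy(FakeTensor([0.1, 0.9])))
# New-style result: a container holding the tensor.
print(first_output_to_numpy([FakeTensor([0.1, 0.9])]))
```

Both calls yield the same values, so downstream code does not need to know which form the runtime returned.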
