Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 47 additions & 12 deletions backend/backend/core/scheduler/celery_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,30 @@ def _trigger_chained_job(user_task: UserTaskDetails, user_id: int, organization_
acks_late=True,
max_retries=0, # we handle retries ourselves
)
def trigger_scheduled_run(self, *, user_task_id: int, user_id: int, organization_id: str = None):
def trigger_scheduled_run(
self,
*,
user_task_id: int,
user_id: int,
organization_id: str = None,
models_override: list = None,
trigger: str = "scheduled",
):
"""Execute a scheduled Visitran run.

This is the Celery task wired to ``Task.SCHEDULER_JOB``.

Args:
models_override: If provided, execute only these model names (plus
their downstream dependents) instead of every model in
``user_task.model_configs``. Used by the Quick Deploy flow to
run a single model against the job's environment.
trigger: "scheduled" (default, used by Celery beat) or "manual"
(used by ad-hoc dispatch from trigger_task_once*). Stored in
TaskRunHistory.kwargs alongside ``scope`` so Run History can
distinguish scheduled vs on-demand runs.
"""
scope = "model" if models_override else "job"
from backend.application.context.application import ApplicationContext
from backend.utils.tenant_context import _get_tenant_context
from backend.core.models.user_model import User
Expand Down Expand Up @@ -228,17 +247,23 @@ def trigger_scheduled_run(self, *, user_task_id: int, user_id: int, organization

# ── Create run-history entry ──────────────────────────────────────
# Note: organization is automatically set by DefaultOrganizationMixin from tenant context
run_kwargs = {
"user_task_id": user_task_id,
"user_id": user_id,
"model_configs": user_task.model_configs,
"trigger": trigger,
"scope": scope,
}
if models_override:
run_kwargs["models_override"] = list(models_override)

run = TaskRunHistory.objects.create(
task_id=self.request.id or f"manual-{user_task_id}-{uuid.uuid4().hex[:8]}",
retry_num=retry_num,
status="STARTED",
start_time=timezone.now(),
user_task_detail=user_task,
kwargs={
"user_task_id": user_task_id,
"user_id": user_id,
"model_configs": user_task.model_configs,
},
kwargs=run_kwargs,
)

# ── Mark task as running ──────────────────────────────────────────
Expand Down Expand Up @@ -288,7 +313,13 @@ def trigger_scheduled_run(self, *, user_task_id: int, user_id: int, organization
timeout = user_task.run_timeout_seconds or 0

with _timeout_guard(timeout):
app_context.execute_visitran_run_command(environment_id=environment_id)
if models_override:
app_context.execute_visitran_run_command(
environment_id=environment_id,
current_models=list(models_override),
)
else:
app_context.execute_visitran_run_command(environment_id=environment_id)

# ── Mark success ──────────────────────────────────────────────
success = True
Expand Down Expand Up @@ -325,12 +356,16 @@ def trigger_scheduled_run(self, *, user_task_id: int, user_id: int, organization
retry_num + 1,
user_task.max_retries,
)
retry_kwargs = {
"user_task_id": user_task_id,
"user_id": user_id,
"organization_id": organization_id,
"trigger": trigger,
}
if models_override:
retry_kwargs["models_override"] = list(models_override)
trigger_scheduled_run.apply_async(
kwargs={
"user_task_id": user_task_id,
"user_id": user_id,
"organization_id": organization_id,
},
kwargs=retry_kwargs,
countdown=30 * (retry_num + 1), # progressive backoff
)
return
Expand Down
18 changes: 18 additions & 0 deletions backend/backend/core/scheduler/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
update_periodic_task,
task_run_history,
trigger_task_once,
trigger_task_once_for_model,
list_deploy_candidates,
list_recent_runs_for_model,
get_periodic_task,
get_model_columns,
)
Expand All @@ -32,6 +35,21 @@
trigger_task_once,
name="trigger_task_once",
),
path(
"/trigger-periodic-task/<int:user_task_id>/model/<str:model_name>",
trigger_task_once_for_model,
name="trigger_task_once_for_model",
),
path(
"/quick-deploy/candidates/<str:model_name>",
list_deploy_candidates,
name="list_deploy_candidates",
),
path(
"/quick-deploy/recent-runs/<str:model_name>",
list_recent_runs_for_model,
name="list_recent_runs_for_model",
),
# Model columns endpoint for incremental job configuration
path(
"/model/<str:model_name>/columns",
Expand Down
172 changes: 154 additions & 18 deletions backend/backend/core/scheduler/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,30 +605,22 @@ def task_run_history(request, project_id, user_task_id):
)


@api_view(["POST"])
@permission_classes([IsAuthenticated])
def trigger_task_once(request, project_id, user_task_id):
"""Trigger a task to run immediately.
def _dispatch_task_run(task, user_id, models_override=None):
"""Shared dispatch: try Celery broker, fall back to synchronous execution.

Tries Celery first; if the broker is unreachable, falls back to
synchronous (in-process) execution so local dev works without Redis.
Always marks the run as ``trigger="manual"`` — only the Celery beat
scheduler path hits ``trigger_scheduled_run`` without this dispatch
wrapper, and it keeps the default ``trigger="scheduled"``.
"""
try:
task = UserTaskDetails.objects.get(
id=user_task_id, project__project_uuid=project_id
)
except UserTaskDetails.DoesNotExist:
return Response(
{"error": "Task not found"}, status=status.HTTP_404_NOT_FOUND
)

run_kwargs = {
"user_task_id": task.id,
"user_id": request.user.id,
"user_id": user_id,
"organization_id": str(task.organization_id) if task.organization_id else None,
"trigger": "manual",
}
if models_override:
run_kwargs["models_override"] = list(models_override)

# Try async dispatch via Celery broker
try:
from backend.core.scheduler.task_constant import Task as TaskConst
from celery import current_app
Expand All @@ -646,7 +638,6 @@ def trigger_task_once(request, project_id, user_task_id):
except Exception as broker_err:
logger.warning("Celery broker unavailable (%s), running synchronously.", broker_err)

# Fallback: run synchronously in-process
try:
from backend.core.scheduler.celery_tasks import trigger_scheduled_run

Expand All @@ -661,3 +652,148 @@ def trigger_task_once(request, project_id, user_task_id):
return Response(
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)


@api_view(["POST"])
@permission_classes([IsAuthenticated])
def trigger_task_once(request, project_id, user_task_id):
    """Run a job immediately, scoped to the whole job.

    Resolves the task within the given project, then hands off to
    ``_dispatch_task_run`` (Celery-first, with a synchronous in-process
    fallback so local development works without a broker/Redis).
    """
    lookup = {"id": user_task_id, "project__project_uuid": project_id}
    try:
        task = UserTaskDetails.objects.get(**lookup)
    except UserTaskDetails.DoesNotExist:
        # 404 rather than 500: the id/project pair simply doesn't match.
        return Response({"error": "Task not found"}, status=status.HTTP_404_NOT_FOUND)

    return _dispatch_task_run(task, request.user.id)


@api_view(["POST"])
@permission_classes([IsAuthenticated])
def trigger_task_once_for_model(request, project_id, user_task_id, model_name):
    """Quick Deploy: trigger a job to run a single model against its configured environment.

    Execution reuses the scheduler pipeline (TaskRunHistory, retries, Slack
    notifications) but scopes the DAG run to ``model_name`` only. The model
    must be present and enabled in the task's ``model_configs``.
    """
    queryset = UserTaskDetails.objects.select_related("project")
    try:
        task = queryset.get(id=user_task_id, project__project_uuid=project_id)
    except UserTaskDetails.DoesNotExist:
        return Response(
            {"error": "Task not found"}, status=status.HTTP_404_NOT_FOUND
        )

    # Guard: the requested model must exist on the job and not be disabled.
    # An absent "enabled" flag counts as enabled.
    configs = task.model_configs or {}
    cfg = configs.get(model_name)
    enabled = bool(cfg) and cfg.get("enabled", True)
    if not enabled:
        return Response(
            {"error": f"Model '{model_name}' is not enabled on this job."},
            status=status.HTTP_400_BAD_REQUEST,
        )

    return _dispatch_task_run(task, request.user.id, models_override=[model_name])


@api_view(["GET"])
@permission_classes([IsAuthenticated])
def list_recent_runs_for_model(request, project_id, model_name):
    """Return recent TaskRunHistory entries for any job in this project that
    includes ``model_name`` in its ``model_configs``.

    Mixes scheduled and quick-deploy runs; each row carries ``trigger``
    ("scheduled"/"manual") and ``scope`` ("job"/"model") so the caller can
    tell them apart. For rows written before the trigger/scope split —
    which only carried ``kwargs.source == "quick_deploy"`` as their
    manual-model marker — both fields are derived for back-compat.

    Query params:
        limit: maximum number of rows to return (default 5, clamped to 1..50).
    """
    try:
        limit = int(request.GET.get("limit", 5))
    except (TypeError, ValueError):
        # Non-numeric limit falls back to the default rather than erroring.
        limit = 5
    limit = max(1, min(limit, 50))

    runs_qs = TaskRunHistory.objects.select_related(
        "user_task_detail", "user_task_detail__environment",
    ).filter(
        user_task_detail__project__project_uuid=project_id,
        user_task_detail__model_configs__has_key=model_name,
    ).order_by("-start_time")[:limit]

    data = []
    for run in runs_qs:
        task = run.user_task_detail
        env = task.environment
        kwargs = run.kwargs or {}
        models_override = kwargs.get("models_override") or []
        # Back-compat: legacy rows lack trigger/scope; derive them from the
        # old kwargs.source == "quick_deploy" marker (and models_override).
        legacy_source = kwargs.get("source")
        trigger = kwargs.get("trigger") or (
            "manual" if legacy_source == "quick_deploy" else "scheduled"
        )
        scope = kwargs.get("scope") or (
            "model" if models_override or legacy_source == "quick_deploy" else "job"
        )
        data.append({
            "run_id": run.id,
            "user_task_id": task.id,
            "task_name": task.task_name,
            "status": run.status,
            "start_time": run.start_time.isoformat() if run.start_time else None,
            "end_time": run.end_time.isoformat() if run.end_time else None,
            "error_message": run.error_message,
            # Environment model naming varies; try both attribute spellings.
            "environment_name": getattr(env, "environment_name", "")
            or getattr(env, "name", ""),
            "trigger": trigger,
            "scope": scope,
            "models_override": models_override,
        })

    return Response({"data": data}, status=status.HTTP_200_OK)


@api_view(["GET"])
@permission_classes([IsAuthenticated])
def list_deploy_candidates(request, project_id, model_name):
    """Return jobs in ``project_id`` that can deploy ``model_name``.

    A job qualifies when ``model_name`` is a key in ``model_configs`` and its
    ``enabled`` flag is truthy (defaults to True if the flag is absent).
    """
    tasks = UserTaskDetails.objects.select_related("environment", "project").filter(
        project__project_uuid=project_id,
        model_configs__has_key=model_name,
    )

    candidates = []
    for task in tasks:
        configs = task.model_configs or {}
        target_cfg = configs.get(model_name)
        if not target_cfg or not target_cfg.get("enabled", True):
            # Model missing or explicitly disabled on this job — skip it.
            continue
        # Count every enabled model on the job (absent flag means enabled),
        # ignoring any non-dict config values.
        enabled_model_count = len([
            cfg
            for cfg in configs.values()
            if isinstance(cfg, dict) and cfg.get("enabled", True)
        ])
        env = task.environment
        if env:
            environment_id = str(env.environment_id)
            # Environment model naming varies; try both attribute spellings.
            environment_name = (
                getattr(env, "environment_name", "") or getattr(env, "name", "")
            )
        else:
            environment_id = ""
            environment_name = ""
        candidates.append({
            "user_task_id": task.id,
            "task_name": task.task_name,
            "environment_id": environment_id,
            "environment_name": environment_name,
            "status": task.status,
            "prev_run_status": task.prev_run_status,
            "task_run_time": task.task_run_time.isoformat() if task.task_run_time else None,
            "next_run_time": task.next_run_time.isoformat() if task.next_run_time else None,
            "enabled_model_count": enabled_model_count,
        })

    return Response({"data": candidates}, status=status.HTTP_200_OK)
Loading
Loading