Skip to content

Commit cc64ebb

Browse files
committed
feat: Add include_sql to Search Pipeline Run API
1 parent 8d6118b commit cc64ebb

2 files changed

Lines changed: 128 additions & 10 deletions

File tree

cloud_pipelines_backend/api_server_sql.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class GetPipelineRunResponse(PipelineRunResponse):
8585
class ListPipelineJobsResponse:
8686
pipeline_runs: list[PipelineRunResponse]
8787
next_page_token: str | None = None
88+
sql: str | None = None
8889

8990

9091
class PipelineRunsApiService_Sql:
@@ -196,6 +197,36 @@ def terminate(
196197
execution_node.extra_data["desired_state"] = "TERMINATED"
197198
session.commit()
198199

200+
@staticmethod
201+
def _compile_sql_string(
202+
stmt: sql.Select,
203+
dialect: sql.engine.Dialect,
204+
) -> str:
205+
"""Compile a SQLAlchemy statement to a SQL string for debugging.
206+
207+
Uses ``literal_binds=True`` to inline bound parameters as literal
208+
values, producing a self-contained query string::
209+
210+
SELECT ... WHERE key = 'environment' AND created_at < '2024-01-15' LIMIT 10
211+
212+
If a column type lacks a ``literal_processor`` (raises CompileError or
213+
NotImplementedError), falls back to placeholder syntax with a params
214+
comment::
215+
216+
SELECT ... WHERE key = :key_1 AND created_at < :created_at_1 LIMIT :param_1
217+
-- params: {'key_1': 'environment', 'created_at_1': '2024-01-15', 'param_1': 10}
218+
"""
219+
try:
220+
compiled = stmt.compile(
221+
dialect=dialect,
222+
compile_kwargs={"literal_binds": True},
223+
)
224+
return str(compiled)
225+
except (sql.exc.CompileError, NotImplementedError):
226+
compiled = stmt.compile(dialect=dialect)
227+
params_suffix = f"\n-- params: {compiled.params}" if compiled.params else ""
228+
return str(compiled) + params_suffix
229+
199230
# Note: This method must be last to not shadow the "list" type
200231
def list(
201232
self,
@@ -207,6 +238,7 @@ def list(
207238
current_user: str | None = None,
208239
include_pipeline_names: bool = False,
209240
include_execution_stats: bool = False,
241+
include_sql: bool = False,
210242
) -> ListPipelineJobsResponse:
211243
where_clauses = filter_query_sql.build_list_filters(
212244
filter_value=filter,
@@ -215,18 +247,22 @@ def list(
215247
current_user=current_user,
216248
)
217249

218-
pipeline_runs = list(
219-
session.scalars(
220-
sql.select(bts.PipelineRun)
221-
.where(*where_clauses)
222-
.order_by(
223-
bts.PipelineRun.created_at.desc(),
224-
bts.PipelineRun.id.desc(),
225-
)
226-
.limit(self._DEFAULT_PAGE_SIZE)
227-
).all()
250+
stmt = (
251+
sql.select(bts.PipelineRun)
252+
.where(*where_clauses)
253+
.order_by(
254+
bts.PipelineRun.created_at.desc(),
255+
bts.PipelineRun.id.desc(),
256+
)
257+
.limit(self._DEFAULT_PAGE_SIZE)
228258
)
229259

260+
sql_string = None
261+
if include_sql:
262+
sql_string = self._compile_sql_string(stmt, session.bind.dialect)
263+
264+
pipeline_runs = list(session.scalars(stmt).all())
265+
230266
next_page_token = filter_query_sql.maybe_next_page_token(
231267
rows=pipeline_runs, page_size=self._DEFAULT_PAGE_SIZE
232268
)
@@ -242,6 +278,7 @@ def list(
242278
for pipeline_run in pipeline_runs
243279
],
244280
next_page_token=next_page_token,
281+
sql=sql_string,
245282
)
246283

247284
def _create_pipeline_run_response(

tests/test_api_server_sql.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,87 @@ def test_list_filter_created_by_me(self, session_factory, service):
295295
assert len(result.pipeline_runs) == 1
296296
assert result.pipeline_runs[0].created_by == "alice@example.com"
297297

298+
def test_list_include_sql_default_none(self, session_factory, service):
299+
_create_run(session_factory, service, root_task=_make_task_spec())
300+
301+
with session_factory() as session:
302+
result = service.list(session=session)
303+
assert result.sql is None
304+
305+
def test_list_include_sql_true(self, session_factory, service):
306+
_create_run(session_factory, service, root_task=_make_task_spec())
307+
308+
with session_factory() as session:
309+
result = service.list(session=session, include_sql=True)
310+
expected = (
311+
"SELECT pipeline_run.id, pipeline_run.root_execution_id,"
312+
" pipeline_run.annotations, pipeline_run.created_by,"
313+
" pipeline_run.created_at, pipeline_run.updated_at,"
314+
" pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n"
315+
"FROM pipeline_run"
316+
" ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n"
317+
" LIMIT 10 OFFSET 0"
318+
)
319+
assert result.sql == expected
320+
321+
def test_list_include_sql_with_filter_query(self, session_factory, service):
322+
run = _create_run(session_factory, service, root_task=_make_task_spec())
323+
with session_factory() as session:
324+
service.set_annotation(session=session, id=run.id, key="team", value="ml")
325+
326+
fq = json.dumps({"and": [{"key_exists": {"key": "team"}}]})
327+
with session_factory() as session:
328+
result = service.list(session=session, filter_query=fq, include_sql=True)
329+
expected = (
330+
"SELECT pipeline_run.id, pipeline_run.root_execution_id,"
331+
" pipeline_run.annotations, pipeline_run.created_by,"
332+
" pipeline_run.created_at, pipeline_run.updated_at,"
333+
" pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n"
334+
"FROM pipeline_run \n"
335+
"WHERE EXISTS (SELECT pipeline_run_annotation.pipeline_run_id \n"
336+
"FROM pipeline_run_annotation \n"
337+
"WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id"
338+
" AND pipeline_run_annotation.\"key\" = 'team')"
339+
" ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n"
340+
" LIMIT 10 OFFSET 0"
341+
)
342+
assert result.sql == expected
343+
344+
def test_list_include_sql_with_cursor(self, session_factory, service):
345+
for i in range(12):
346+
_create_run(
347+
session_factory,
348+
service,
349+
root_task=_make_task_spec(f"pipeline-{i}"),
350+
)
351+
352+
with session_factory() as session:
353+
page1 = service.list(session=session)
354+
assert page1.next_page_token is not None
355+
356+
with session_factory() as session:
357+
page2 = service.list(
358+
session=session,
359+
page_token=page1.next_page_token,
360+
include_sql=True,
361+
)
362+
363+
cursor_dt_iso, cursor_id = page1.next_page_token.split("~")
364+
cursor_dt = datetime.datetime.fromisoformat(cursor_dt_iso)
365+
sql_dt = cursor_dt.strftime("%Y-%m-%d %H:%M:%S.%f")
366+
expected = (
367+
"SELECT pipeline_run.id, pipeline_run.root_execution_id,"
368+
" pipeline_run.annotations, pipeline_run.created_by,"
369+
" pipeline_run.created_at, pipeline_run.updated_at,"
370+
" pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n"
371+
"FROM pipeline_run \n"
372+
f"WHERE (pipeline_run.created_at, pipeline_run.id)"
373+
f" < ('{sql_dt}', '{cursor_id}')"
374+
" ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n"
375+
" LIMIT 10 OFFSET 0"
376+
)
377+
assert page2.sql == expected
378+
298379

299380
class TestCreatePipelineRunResponse:
300381
def test_base_response(self, session_factory, service):

0 commit comments

Comments
 (0)