Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 72 additions & 9 deletions vertexai/prompts/_prompt_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from google.protobuf import field_mask_pb2 as field_mask

import dataclasses
import datetime
from typing import (
Any,
Dict,
Expand All @@ -57,6 +58,29 @@
"gs://google-cloud-aiplatform/schema/dataset/metadata/text_prompt_1.0.0.yaml"
)

# Delimiter used to embed creator info in a DatasetVersion display_name.
# The Vertex AI DatasetVersion API has no native creator field, so we encode
# it as a suffix: "<display_name>__creator__<email>".
_CREATOR_DELIMITER = "__creator__"


def _encode_display_name(version_name: Optional[str], creator: Optional[str]) -> Optional[str]:
"""Encodes creator into display_name using a structured suffix."""
if not creator:
return version_name
base = version_name or ""
return f"{base}{_CREATOR_DELIMITER}{creator}"


def _decode_display_name(encoded: Optional[str]) -> tuple[Optional[str], Optional[str]]:
"""Returns (display_name, creator) decoded from an encoded display_name string."""
if not encoded or _CREATOR_DELIMITER not in encoded:
return encoded, None
parts = encoded.rsplit(_CREATOR_DELIMITER, 1)
display_name = parts[0] if parts[0] else None
creator = parts[1] if parts[1] else None
return display_name, creator


def _format_function_declaration_parameters(obj: Any):
"""Recursively replaces type_ and format_ fields in-place."""
Expand Down Expand Up @@ -264,17 +288,25 @@ class PromptVersionMetadata:
display_name: The display name of the prompt version.
prompt_id: The id of the prompt.
version_id: The version id of the prompt.
create_time: The timestamp when the prompt version was created.
creator: The user who saved the prompt version. Only populated if the
version was saved using the SDK with a ``creator`` argument. The
Vertex AI DatasetVersion API does not natively track the creator;
this value is encoded in the version display name by the SDK.
"""

display_name: str
prompt_id: str
version_id: str
create_time: Optional[datetime.datetime] = None
creator: Optional[str] = None


def create_version(
prompt: Prompt,
prompt_id: Optional[str] = None,
version_name: Optional[str] = None,
creator: Optional[str] = None,
) -> Prompt:
"""Creates a Prompt or Prompt Version in the online prompt store

Expand All @@ -285,6 +317,12 @@ def create_version(
associated with it, a new prompt resource will be created.
version_name: Optional display name of the new prompt version.
If not specified, a default name including a timestamp will be used.
creator: Optional identifier (e.g. email address) of the user saving
this version. The Vertex AI DatasetVersion API does not have a
native creator field, so this value is encoded in the version's
display name by the SDK and decoded when listing versions. Pass
the authenticated user's email or username to enable team
attribution in prompt version history.

Returns:
A new Prompt object with a reference to the newly created or updated
Expand All @@ -295,11 +333,15 @@ def create_version(
if not (prompt_id or prompt._dataset):
# Case 1: Neither prompt id nor prompt._dataset exists, so we
# create a new prompt resource
return _create_prompt_resource(prompt=prompt, version_name=version_name)
return _create_prompt_resource(
prompt=prompt, version_name=version_name, creator=creator
)

# Case 2: No prompt_id override is given, so we update the existing prompt resource
if not prompt_id:
return _create_prompt_version_resource(prompt=prompt, version_name=version_name)
return _create_prompt_version_resource(
prompt=prompt, version_name=version_name, creator=creator
)

# Case 3: Save a new version to the prompt_id provided as an arg
# prompt_id is guaranteed to exist due to Cases 1 & 2 being handled before
Expand Down Expand Up @@ -328,7 +370,9 @@ def create_version(
prompt._dataset.name = (
f"projects/{project}/locations/{location}/datasets/{prompt_id}"
)
result = _create_prompt_version_resource(prompt=prompt, version_name=version_name)
result = _create_prompt_version_resource(
prompt=prompt, version_name=version_name, creator=creator
)

# Restore the original prompt resource name. This is a no-op if there
# was no original prompt resource name.
Expand Down Expand Up @@ -402,10 +446,13 @@ def _create_dataset(prompt: Prompt, parent: str) -> gca_dataset.Dataset:


def _create_dataset_version(
prompt: Prompt, parent: str, version_name: Optional[str] = None
prompt: Prompt,
parent: str,
version_name: Optional[str] = None,
creator: Optional[str] = None,
):
dataset_version = gca_dataset_version.DatasetVersion(
display_name=version_name,
display_name=_encode_display_name(version_name, creator),
)

dataset_version = prompt._dataset_client.create_dataset_version(
Expand Down Expand Up @@ -435,7 +482,9 @@ def _update_dataset(


def _create_prompt_resource(
prompt: Prompt, version_name: Optional[str] = None
prompt: Prompt,
version_name: Optional[str] = None,
creator: Optional[str] = None,
) -> Prompt:
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
Expand All @@ -449,13 +498,15 @@ def _create_prompt_resource(
prompt=prompt,
parent=dataset.name,
version_name=version_name,
creator=creator,
)

# Step 3: Create new Prompt object to return
new_prompt = Prompt._clone(prompt=prompt)
new_prompt._dataset = dataset
new_prompt._version_id = dataset_version.name.split("/")[-1]
new_prompt._version_name = dataset_version.display_name
display_name, _ = _decode_display_name(dataset_version.display_name)
new_prompt._version_name = display_name
prompt_id = new_prompt._dataset.name.split("/")[5]

_LOGGER.info(
Expand All @@ -467,6 +518,7 @@ def _create_prompt_resource(
def _create_prompt_version_resource(
prompt: Prompt,
version_name: Optional[str] = None,
creator: Optional[str] = None,
) -> Prompt:
# Step 1: Update prompt API call
updated_dataset = _update_dataset(prompt=prompt, dataset=prompt._dataset)
Expand All @@ -476,13 +528,15 @@ def _create_prompt_version_resource(
prompt=prompt,
parent=updated_dataset.name,
version_name=version_name,
creator=creator,
)

# Step 3: Create new Prompt object to return
new_prompt = Prompt._clone(prompt=prompt)
new_prompt._dataset = updated_dataset
new_prompt._version_id = dataset_version.name.split("/")[-1]
new_prompt._version_name = dataset_version.display_name
display_name, _ = _decode_display_name(dataset_version.display_name)
new_prompt._version_name = display_name
prompt_id = prompt._dataset.name.split("/")[5]

_LOGGER.info(
Expand Down Expand Up @@ -700,6 +754,9 @@ def list_versions(prompt_id: str) -> list[PromptVersionMetadata]:

Returns:
A list of PromptVersionMetadata objects for the prompt resource.
Each entry includes ``create_time`` (the timestamp when the version
was saved) and ``creator`` (the user who saved it, if the version was
created via the SDK with the ``creator`` argument).
"""
# Create a temporary Prompt object for a dataset client
temp_prompt = Prompt()
Expand All @@ -712,11 +769,17 @@ def list_versions(prompt_id: str) -> list[PromptVersionMetadata]:
)
version_history = []
for version in versions_pager:
display_name, creator = _decode_display_name(version.display_name)
create_time = None
if version.create_time:
create_time = version.create_time.ToDatetime(tzinfo=datetime.timezone.utc)
version_history.append(
PromptVersionMetadata(
display_name=version.display_name,
display_name=display_name,
prompt_id=version.name.split("/")[5],
version_id=version.name.split("/")[-1],
create_time=create_time,
creator=creator,
)
)
return version_history
Expand Down