Skip to content

Commit 2c40555

Browse files
yinghsienwucopybara-github
authored andcommitted
fix: base_url and global location parsing
PiperOrigin-RevId: 867792895
1 parent bf8c29b commit 2c40555

5 files changed

Lines changed: 114 additions & 66 deletions

File tree

google/genai/_api_client.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -584,14 +584,15 @@ def __init__(
584584
validated_http_options = HttpOptions.model_validate(http_options)
585585
except ValidationError as e:
586586
raise ValueError('Invalid http_options') from e
587-
elif isinstance(http_options, HttpOptions):
587+
elif http_options and _common.is_duck_type_of(http_options, HttpOptions):
588588
validated_http_options = http_options
589589

590590
if validated_http_options.base_url_resource_scope and not validated_http_options.base_url:
591591
# base_url_resource_scope is only valid when base_url is set.
592592
raise ValueError(
593593
'base_url must be set when base_url_resource_scope is set.'
594594
)
595+
print('validated_http_options: ', validated_http_options)
595596

596597
# Retrieve implicitly set values from the environment.
597598
env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
@@ -649,7 +650,13 @@ def __init__(
649650
else None
650651
)
651652

652-
if not self.location and not self.api_key and not self.custom_base_url:
653+
if (
654+
not self.location
655+
and not self.api_key
656+
):
657+
if not self.custom_base_url:
658+
self.location = 'global'
659+
elif self.custom_base_url.endswith('.googleapis.com'):
653660
self.location = 'global'
654661

655662
# Skip fetching project from ADC if base url is provided in http options.
@@ -667,12 +674,16 @@ def __init__(
667674
if not has_sufficient_auth and not self.custom_base_url:
668675
# Skip sufficient auth check if base url is provided in http options.
669676
raise ValueError(
670-
'Project or API key must be set when using the Vertex '
671-
'AI API.'
677+
'Project or API key must be set when using the Vertex AI API.'
672678
)
673-
if self.api_key or self.location == 'global':
679+
if (
680+
self.api_key or self.location == 'global'
681+
) and not self.custom_base_url:
674682
self._http_options.base_url = f'https://aiplatform.googleapis.com/'
675-
elif self.custom_base_url and not ((project and location) or api_key):
683+
elif (
684+
self.custom_base_url
685+
and not self.custom_base_url.endswith('.googleapis.com')
686+
) and not ((project and location) or api_key):
676687
# Avoid setting default base url and api version if base_url provided.
677688
# API gateway proxy can use the auth in custom headers, not url.
678689
# Enable custom url if auth is not sufficient.

google/genai/_common.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,3 +814,34 @@ def recursive_dict_update(
814814
target_dict[key] = value
815815
else:
816816
target_dict[key] = value
817+
818+
819+
def is_duck_type_of(obj: Any, cls: type[pydantic.BaseModel]) -> bool:
820+
"""Checks if an object has all of the fields of a Pydantic model.
821+
822+
This is a duck-typing alternative to `isinstance` to solve dual-import
823+
problems. It returns False for dictionaries, which should be handled by
824+
`isinstance(obj, dict)`.
825+
826+
Args:
827+
obj: The object to check.
828+
cls: The Pydantic model class to duck-type against.
829+
830+
Returns:
831+
True if the object has all the fields defined in the Pydantic model, False
832+
otherwise.
833+
"""
834+
if isinstance(obj, dict) or not hasattr(cls, 'model_fields'):
835+
return False
836+
837+
# Check if the object has all of the Pydantic model's defined fields.
838+
all_matched = all(hasattr(obj, field) for field in cls.model_fields)
839+
if not all_matched and isinstance(obj, pydantic.BaseModel):
840+
# Check the other way around if obj is a Pydantic model.
841+
# Check if the Pydantic model has all of the object's defined fields.
842+
try:
843+
obj_private = cls()
844+
all_matched = all(hasattr(obj_private, f) for f in type(obj).model_fields)
845+
except ValueError:
846+
return False
847+
return all_matched

google/genai/_transformers.py

Lines changed: 22 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from typing import Any, GenericAlias, List, Optional, Sequence, Union # type: ignore[attr-defined]
3030
from ._mcp_utils import mcp_to_gemini_tool
3131
from ._common import get_value_by_path as getv
32+
from ._common import is_duck_type_of
3233

3334
if typing.TYPE_CHECKING:
3435
import PIL.Image
@@ -72,37 +73,6 @@
7273
metric_name_api_sdk_map = {v: k for k, v in metric_name_sdk_api_map.items()}
7374

7475

75-
def _is_duck_type_of(obj: Any, cls: type[pydantic.BaseModel]) -> bool:
76-
"""Checks if an object has all of the fields of a Pydantic model.
77-
78-
This is a duck-typing alternative to `isinstance` to solve dual-import
79-
problems. It returns False for dictionaries, which should be handled by
80-
`isinstance(obj, dict)`.
81-
82-
Args:
83-
obj: The object to check.
84-
cls: The Pydantic model class to duck-type against.
85-
86-
Returns:
87-
True if the object has all the fields defined in the Pydantic model, False
88-
otherwise.
89-
"""
90-
if isinstance(obj, dict) or not hasattr(cls, 'model_fields'):
91-
return False
92-
93-
# Check if the object has all of the Pydantic model's defined fields.
94-
all_matched = all(hasattr(obj, field) for field in cls.model_fields)
95-
if not all_matched and isinstance(obj, pydantic.BaseModel):
96-
# Check the other way around if obj is a Pydantic model.
97-
# Check if the Pydantic model has all of the object's defined fields.
98-
try:
99-
obj_private = cls()
100-
all_matched = all(hasattr(obj_private, f) for f in type(obj).model_fields)
101-
except ValueError:
102-
return False
103-
return all_matched
104-
105-
10676
def _resource_name(
10777
client: _api_client.BaseApiClient,
10878
resource_name: str,
@@ -311,7 +281,7 @@ def t_function_response(
311281
raise ValueError('function_response is required.')
312282
if isinstance(function_response, dict):
313283
return types.FunctionResponse.model_validate(function_response)
314-
elif _is_duck_type_of(function_response, types.FunctionResponse):
284+
elif is_duck_type_of(function_response, types.FunctionResponse):
315285
return function_response
316286
else:
317287
raise TypeError(
@@ -347,7 +317,7 @@ def t_blob(blob: types.BlobImageUnionDict) -> types.Blob:
347317
if not blob:
348318
raise ValueError('blob is required.')
349319

350-
if _is_duck_type_of(blob, types.Blob):
320+
if is_duck_type_of(blob, types.Blob):
351321
return blob # type: ignore[return-value]
352322

353323
if isinstance(blob, dict):
@@ -388,7 +358,7 @@ def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
388358
raise ValueError('content part is required.')
389359
if isinstance(part, str):
390360
return types.Part(text=part)
391-
if _is_duck_type_of(part, types.File):
361+
if is_duck_type_of(part, types.File):
392362
if not part.uri or not part.mime_type: # type: ignore[union-attr]
393363
raise ValueError('file uri and mime_type are required.')
394364
return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type) # type: ignore[union-attr]
@@ -397,7 +367,7 @@ def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
397367
return types.Part.model_validate(part)
398368
except pydantic.ValidationError:
399369
return types.Part(file_data=types.FileData.model_validate(part))
400-
if _is_duck_type_of(part, types.Part):
370+
if is_duck_type_of(part, types.Part):
401371
return part # type: ignore[return-value]
402372

403373
if 'image' in part.__class__.__name__.lower():
@@ -454,7 +424,7 @@ def t_content(
454424
) -> types.Content:
455425
if content is None:
456426
raise ValueError('content is required.')
457-
if _is_duck_type_of(content, types.Content):
427+
if is_duck_type_of(content, types.Content):
458428
return content # type: ignore[return-value]
459429
if isinstance(content, dict):
460430
try:
@@ -466,9 +436,9 @@ def t_content(
466436
if possible_part.function_call
467437
else types.UserContent(parts=[possible_part])
468438
)
469-
if _is_duck_type_of(content, types.File):
439+
if is_duck_type_of(content, types.File):
470440
return types.UserContent(parts=[t_part(content)]) # type: ignore[arg-type]
471-
if _is_duck_type_of(content, types.Part):
441+
if is_duck_type_of(content, types.Part):
472442
return (
473443
types.ModelContent(parts=[content]) # type: ignore[arg-type]
474444
if content.function_call # type: ignore[union-attr]
@@ -521,8 +491,8 @@ def _is_part(
521491
) -> TypeGuard[types.PartUnionDict]:
522492
if (
523493
isinstance(part, str)
524-
or _is_duck_type_of(part, types.File)
525-
or _is_duck_type_of(part, types.Part)
494+
or is_duck_type_of(part, types.File)
495+
or is_duck_type_of(part, types.Part)
526496
):
527497
return True
528498

@@ -592,7 +562,7 @@ def _handle_current_part(
592562
# append to result
593563
# if list, we only accept a list of types.PartUnion
594564
for content in contents:
595-
if _is_duck_type_of(content, types.Content) or isinstance(content, list):
565+
if is_duck_type_of(content, types.Content) or isinstance(content, list):
596566
_append_accumulated_parts_as_content(result, accumulated_parts)
597567
if isinstance(content, list):
598568
result.append(types.UserContent(parts=content)) # type: ignore[arg-type]
@@ -889,7 +859,7 @@ def t_schema(
889859
return types.Schema.model_validate(origin)
890860
if isinstance(origin, EnumMeta):
891861
return _process_enum(origin, client)
892-
if _is_duck_type_of(origin, types.Schema):
862+
if is_duck_type_of(origin, types.Schema):
893863
if dict(origin) == dict(types.Schema()): # type: ignore [arg-type]
894864
# response_schema value was coerced to an empty Schema instance because
895865
# it did not adhere to the Schema field annotation
@@ -931,7 +901,7 @@ def t_speech_config(
931901
) -> Optional[types.SpeechConfig]:
932902
if not origin:
933903
return None
934-
if _is_duck_type_of(origin, types.SpeechConfig):
904+
if is_duck_type_of(origin, types.SpeechConfig):
935905
return origin # type: ignore[return-value]
936906
if isinstance(origin, str):
937907
return types.SpeechConfig(
@@ -948,7 +918,7 @@ def t_speech_config(
948918
def t_live_speech_config(
949919
origin: types.SpeechConfigOrDict,
950920
) -> Optional[types.SpeechConfig]:
951-
if _is_duck_type_of(origin, types.SpeechConfig):
921+
if is_duck_type_of(origin, types.SpeechConfig):
952922
speech_config = origin
953923
if isinstance(origin, dict):
954924
speech_config = types.SpeechConfig.model_validate(origin)
@@ -974,7 +944,7 @@ def t_tool(
974944
)
975945
]
976946
)
977-
elif McpTool is not None and _is_duck_type_of(origin, McpTool):
947+
elif McpTool is not None and is_duck_type_of(origin, McpTool):
978948
return mcp_to_gemini_tool(origin)
979949
elif isinstance(origin, dict):
980950
return types.Tool.model_validate(origin)
@@ -1017,7 +987,7 @@ def t_batch_job_source(
1017987
) -> types.BatchJobSource:
1018988
if isinstance(src, dict):
1019989
src = types.BatchJobSource(**src)
1020-
if _is_duck_type_of(src, types.BatchJobSource):
990+
if is_duck_type_of(src, types.BatchJobSource):
1021991
vertex_sources = sum(
1022992
[src.gcs_uri is not None, src.bigquery_uri is not None] # type: ignore[union-attr]
1023993
)
@@ -1068,7 +1038,7 @@ def t_embedding_batch_job_source(
10681038
if isinstance(src, dict):
10691039
src = types.EmbeddingsBatchJobSource(**src)
10701040

1071-
if _is_duck_type_of(src, types.EmbeddingsBatchJobSource):
1041+
if is_duck_type_of(src, types.EmbeddingsBatchJobSource):
10721042
mldev_sources = sum([
10731043
src.inlined_requests is not None,
10741044
src.file_name is not None,
@@ -1103,7 +1073,7 @@ def t_batch_job_destination(
11031073
)
11041074
else:
11051075
raise ValueError(f'Unsupported destination: {dest}')
1106-
elif _is_duck_type_of(dest, types.BatchJobDestination):
1076+
elif is_duck_type_of(dest, types.BatchJobDestination):
11071077
return dest
11081078
else:
11091079
raise ValueError(f'Unsupported destination: {dest}')
@@ -1203,11 +1173,11 @@ def t_file_name(
12031173
name: Optional[Union[str, types.File, types.Video, types.GeneratedVideo]],
12041174
) -> str:
12051175
# Remove the files/ prefix since it's added to the url path.
1206-
if _is_duck_type_of(name, types.File):
1176+
if is_duck_type_of(name, types.File):
12071177
name = name.name # type: ignore[union-attr]
1208-
elif _is_duck_type_of(name, types.Video):
1178+
elif is_duck_type_of(name, types.Video):
12091179
name = name.uri # type: ignore[union-attr]
1210-
elif _is_duck_type_of(name, types.GeneratedVideo):
1180+
elif is_duck_type_of(name, types.GeneratedVideo):
12111181
if name.video is not None: # type: ignore[union-attr]
12121182
name = name.video.uri # type: ignore[union-attr]
12131183
else:
@@ -1252,7 +1222,7 @@ def t_tuning_job_status(status: str) -> Union[types.JobState, str]:
12521222
def t_content_strict(content: types.ContentOrDict) -> types.Content:
12531223
if isinstance(content, dict):
12541224
return types.Content.model_validate(content)
1255-
elif _is_duck_type_of(content, types.Content):
1225+
elif is_duck_type_of(content, types.Content):
12561226
return content
12571227
else:
12581228
raise ValueError(

google/genai/tests/client/test_client_initialization.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,42 @@ def test_vertexai_default_location_to_global_with_explicit_project_and_env_apike
449449
assert not client.models._api_client.api_key
450450

451451

452+
def test_vertexai_default_location_to_global_with_vertexai_base_url(
453+
monkeypatch,
454+
):
455+
# Test case 4: When project and vertex base url are set
456+
project_id = "env_project_id"
457+
458+
with monkeypatch.context() as m:
459+
m.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
460+
m.setenv("GOOGLE_CLOUD_PROJECT", project_id)
461+
client = Client(
462+
vertexai=True,
463+
http_options={'base_url': 'https://fake-url.googleapis.com'},
464+
)
465+
# Implicit project takes precedence over implicit api_key
466+
assert client.models._api_client.location == "global"
467+
assert client.models._api_client.project == project_id
468+
469+
470+
def test_vertexai_default_location_to_global_with_arbitrary_base_url(
471+
monkeypatch,
472+
):
473+
# Test case 5: When project and arbitrary base url (proxy) are set
474+
project_id = "env_project_id"
475+
476+
with monkeypatch.context() as m:
477+
m.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
478+
m.setenv("GOOGLE_CLOUD_PROJECT", project_id)
479+
client = Client(
480+
vertexai=True,
481+
http_options={'base_url': 'https://fake-url.com'},
482+
)
483+
# Implicit project takes precedence over implicit api_key
484+
assert not client.models._api_client.location
485+
assert not client.models._api_client.project
486+
487+
452488
def test_vertexai_default_location_to_global_with_env_project_and_env_apikey(
453489
monkeypatch,
454490
):

0 commit comments

Comments
 (0)