Skip to content

Commit 9c26498

Browse files
extracted json parsing function, added parametrized test and checked coverage
1 parent 53609c3 commit 9c26498

3 files changed

Lines changed: 488 additions & 782 deletions

File tree

python/packages/core/agent_framework/_mcp.py

Lines changed: 15 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

3-
import json
43
import logging
54
import re
65
import sys
@@ -19,9 +18,9 @@
1918
from mcp.shared.context import RequestContext
2019
from mcp.shared.exceptions import McpError
2120
from mcp.shared.session import RequestResponder
22-
from pydantic import BaseModel, Field, create_model
21+
from pydantic import BaseModel, create_model
2322

24-
from ._tools import AIFunction, HostedMCPSpecificApproval
23+
from ._tools import AIFunction, HostedMCPSpecificApproval, _build_pydantic_model_from_json_schema
2524
from ._types import (
2625
ChatMessage,
2726
Contents,
@@ -274,153 +273,26 @@ def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> type[BaseModel]:
274273
if not prompt.arguments:
275274
return create_model(f"{prompt.name}_input")
276275

277-
field_definitions: dict[str, Any] = {}
278-
for prompt_argument in prompt.arguments:
279-
# For prompts, all arguments are typically required and string type
280-
# unless specified otherwise in the prompt argument
281-
python_type = str # Default type for prompt arguments
276+
# Convert prompt arguments to JSON schema format
277+
properties: dict[str, Any] = {}
278+
required: list[str] = []
282279

283-
# Create field definition for create_model
280+
for prompt_argument in prompt.arguments:
281+
# For prompts, all arguments are typically string type unless specified otherwise
282+
properties[prompt_argument.name] = {
283+
"type": "string",
284+
"description": prompt_argument.description if hasattr(prompt_argument, "description") else "",
285+
}
284286
if prompt_argument.required:
285-
field_definitions[prompt_argument.name] = (python_type, ...)
286-
else:
287-
field_definitions[prompt_argument.name] = (python_type, None)
287+
required.append(prompt_argument.name)
288288

289-
return create_model(f"{prompt.name}_input", **field_definitions)
289+
schema = {"properties": properties, "required": required}
290+
return _build_pydantic_model_from_json_schema(prompt.name, schema)
290291

291292

292293
def _get_input_model_from_mcp_tool(tool: types.Tool) -> type[BaseModel]:
293294
"""Creates a Pydantic model from a tools parameters."""
294-
properties = tool.inputSchema.get("properties", None)
295-
required = tool.inputSchema.get("required", [])
296-
definitions = tool.inputSchema.get("$defs", {})
297-
298-
# Check if 'properties' is missing or not a dictionary
299-
if not properties:
300-
return create_model(f"{tool.name}_input")
301-
302-
def resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type:
303-
"""Resolve JSON Schema type to Python type, handling $ref, nested objects, and typed arrays.
304-
305-
Args:
306-
prop_details: The JSON Schema property details
307-
parent_name: Name to use for creating nested models (for uniqueness)
308-
309-
Returns:
310-
Python type annotation (could be int, str, list[str], or a nested Pydantic model)
311-
"""
312-
# Handle $ref by resolving the reference
313-
if "$ref" in prop_details:
314-
ref = prop_details["$ref"]
315-
# Extract the reference path (e.g., "#/$defs/CustomerIdParam" -> "CustomerIdParam")
316-
if ref.startswith("#/$defs/"):
317-
def_name = ref.split("/")[-1]
318-
if def_name in definitions:
319-
# Resolve the reference and use its type
320-
resolved = definitions[def_name]
321-
return resolve_type(resolved, def_name)
322-
# If we can't resolve the ref, default to dict for safety
323-
return dict
324-
325-
# Map JSON Schema types to Python types
326-
json_type = prop_details.get("type", "string")
327-
match json_type:
328-
case "integer":
329-
return int
330-
case "number":
331-
return float
332-
case "boolean":
333-
return bool
334-
case "array":
335-
# Handle typed arrays
336-
items_schema = prop_details.get("items")
337-
if items_schema and isinstance(items_schema, dict):
338-
# Recursively resolve the item type
339-
item_type = resolve_type(items_schema, f"{parent_name}_item")
340-
# Return list[ItemType] instead of bare list
341-
return list[item_type] # type: ignore
342-
# If no items schema or invalid, return bare list
343-
return list
344-
case "object":
345-
# Handle nested objects by creating a nested Pydantic model
346-
nested_properties = prop_details.get("properties")
347-
nested_required = prop_details.get("required", [])
348-
349-
if nested_properties and isinstance(nested_properties, dict):
350-
# Create the name for the nested model
351-
nested_model_name = f"{parent_name}_nested" if parent_name else "NestedModel"
352-
353-
# Recursively build field definitions for the nested model
354-
nested_field_definitions: dict[str, Any] = {}
355-
for nested_prop_name, nested_prop_details in nested_properties.items():
356-
nested_prop_details = (
357-
json.loads(nested_prop_details)
358-
if isinstance(nested_prop_details, str)
359-
else nested_prop_details
360-
)
361-
362-
nested_python_type = resolve_type(
363-
nested_prop_details, f"{nested_model_name}_{nested_prop_name}"
364-
)
365-
nested_description = nested_prop_details.get("description", "")
366-
367-
# Build field kwargs for nested property
368-
nested_field_kwargs: dict[str, Any] = {}
369-
if nested_description:
370-
nested_field_kwargs["description"] = nested_description
371-
372-
# Create field definition
373-
if nested_prop_name in nested_required:
374-
nested_field_definitions[nested_prop_name] = (
375-
(
376-
nested_python_type,
377-
Field(**nested_field_kwargs),
378-
)
379-
if nested_field_kwargs
380-
else (nested_python_type, ...)
381-
)
382-
else:
383-
nested_field_kwargs["default"] = nested_prop_details.get("default", None)
384-
nested_field_definitions[nested_prop_name] = (
385-
nested_python_type,
386-
Field(**nested_field_kwargs),
387-
)
388-
389-
# Create and return the nested Pydantic model
390-
return create_model(nested_model_name, **nested_field_definitions) # type: ignore
391-
392-
# If no properties defined, return bare dict
393-
return dict
394-
case _:
395-
return str # default
396-
397-
field_definitions: dict[str, Any] = {}
398-
for prop_name, prop_details in properties.items():
399-
prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details
400-
401-
python_type = resolve_type(prop_details, f"{tool.name}_{prop_name}")
402-
description = prop_details.get("description", "")
403-
404-
# Build field kwargs (description, etc.)
405-
field_kwargs: dict[str, Any] = {}
406-
if description:
407-
field_kwargs["description"] = description
408-
409-
# Create field definition for create_model
410-
if prop_name in required:
411-
if field_kwargs:
412-
field_definitions[prop_name] = (python_type, Field(**field_kwargs))
413-
else:
414-
field_definitions[prop_name] = (python_type, ...)
415-
else:
416-
default_value = prop_details.get("default", None)
417-
field_kwargs["default"] = default_value
418-
if field_kwargs and any(k != "default" for k in field_kwargs):
419-
field_definitions[prop_name] = (python_type, Field(**field_kwargs))
420-
else:
421-
field_definitions[prop_name] = (python_type, default_value)
422-
423-
return create_model(f"{tool.name}_input", **field_definitions)
295+
return _build_pydantic_model_from_json_schema(tool.name, tool.inputSchema)
424296

425297

426298
def _normalize_mcp_name(name: str) -> str:

python/packages/core/agent_framework/_tools.py

Lines changed: 146 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from opentelemetry.metrics import Histogram
2727
from pydantic import AnyUrl, BaseModel, Field, ValidationError, create_model
28-
from pydantic.fields import FieldInfo
2928

3029
from ._logging import get_logger
3130
from ._serialization import SerializationMixin
@@ -932,6 +931,151 @@ def _create_input_model_from_func(func: Callable[..., Any], name: str) -> type[B
932931
}
933932

934933

934+
def _build_pydantic_model_from_json_schema(
935+
model_name: str,
936+
schema: Mapping[str, Any],
937+
) -> type[BaseModel]:
938+
"""Creates a Pydantic model from JSON Schema with support for $refs, nested objects, and typed arrays.
939+
940+
Args:
941+
model_name: The name of the model to be created.
942+
schema: The JSON Schema definition (should contain 'properties', 'required', '$defs', etc.).
943+
944+
Returns:
945+
The dynamically created Pydantic model class.
946+
"""
947+
properties = schema.get("properties")
948+
required = schema.get("required", [])
949+
definitions = schema.get("$defs", {})
950+
951+
# Check if 'properties' is missing or not a dictionary
952+
if not properties:
953+
return create_model(f"{model_name}_input")
954+
955+
def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type:
956+
"""Resolve JSON Schema type to Python type, handling $ref, nested objects, and typed arrays.
957+
958+
Args:
959+
prop_details: The JSON Schema property details
960+
parent_name: Name to use for creating nested models (for uniqueness)
961+
962+
Returns:
963+
Python type annotation (could be int, str, list[str], or a nested Pydantic model)
964+
"""
965+
# Handle $ref by resolving the reference
966+
if "$ref" in prop_details:
967+
ref = prop_details["$ref"]
968+
# Extract the reference path (e.g., "#/$defs/CustomerIdParam" -> "CustomerIdParam")
969+
if ref.startswith("#/$defs/"):
970+
def_name = ref.split("/")[-1]
971+
if def_name in definitions:
972+
# Resolve the reference and use its type
973+
resolved = definitions[def_name]
974+
return _resolve_type(resolved, def_name)
975+
# If we can't resolve the ref, default to dict for safety
976+
return dict
977+
978+
# Map JSON Schema types to Python types
979+
json_type = prop_details.get("type", "string")
980+
match json_type:
981+
case "integer":
982+
return int
983+
case "number":
984+
return float
985+
case "boolean":
986+
return bool
987+
case "array":
988+
# Handle typed arrays
989+
items_schema = prop_details.get("items")
990+
if items_schema and isinstance(items_schema, dict):
991+
# Recursively resolve the item type
992+
item_type = _resolve_type(items_schema, f"{parent_name}_item")
993+
# Return list[ItemType] instead of bare list
994+
return list[item_type] # type: ignore
995+
# If no items schema or invalid, return bare list
996+
return list
997+
case "object":
998+
# Handle nested objects by creating a nested Pydantic model
999+
nested_properties = prop_details.get("properties")
1000+
nested_required = prop_details.get("required", [])
1001+
1002+
if nested_properties and isinstance(nested_properties, dict):
1003+
# Create the name for the nested model
1004+
nested_model_name = f"{parent_name}_nested" if parent_name else "NestedModel"
1005+
1006+
# Recursively build field definitions for the nested model
1007+
nested_field_definitions: dict[str, Any] = {}
1008+
for nested_prop_name, nested_prop_details in nested_properties.items():
1009+
nested_prop_details = (
1010+
json.loads(nested_prop_details)
1011+
if isinstance(nested_prop_details, str)
1012+
else nested_prop_details
1013+
)
1014+
1015+
nested_python_type = _resolve_type(
1016+
nested_prop_details, f"{nested_model_name}_{nested_prop_name}"
1017+
)
1018+
nested_description = nested_prop_details.get("description", "")
1019+
1020+
# Build field kwargs for nested property
1021+
nested_field_kwargs: dict[str, Any] = {}
1022+
if nested_description:
1023+
nested_field_kwargs["description"] = nested_description
1024+
1025+
# Create field definition
1026+
if nested_prop_name in nested_required:
1027+
nested_field_definitions[nested_prop_name] = (
1028+
(
1029+
nested_python_type,
1030+
Field(**nested_field_kwargs),
1031+
)
1032+
if nested_field_kwargs
1033+
else (nested_python_type, ...)
1034+
)
1035+
else:
1036+
nested_field_kwargs["default"] = nested_prop_details.get("default", None)
1037+
nested_field_definitions[nested_prop_name] = (
1038+
nested_python_type,
1039+
Field(**nested_field_kwargs),
1040+
)
1041+
1042+
# Create and return the nested Pydantic model
1043+
return create_model(nested_model_name, **nested_field_definitions) # type: ignore
1044+
1045+
# If no properties defined, return bare dict
1046+
return dict
1047+
case _:
1048+
return str # default
1049+
1050+
field_definitions: dict[str, Any] = {}
1051+
for prop_name, prop_details in properties.items():
1052+
prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details
1053+
1054+
python_type = _resolve_type(prop_details, f"{model_name}_{prop_name}")
1055+
description = prop_details.get("description", "")
1056+
1057+
# Build field kwargs (description, etc.)
1058+
field_kwargs: dict[str, Any] = {}
1059+
if description:
1060+
field_kwargs["description"] = description
1061+
1062+
# Create field definition for create_model
1063+
if prop_name in required:
1064+
if field_kwargs:
1065+
field_definitions[prop_name] = (python_type, Field(**field_kwargs))
1066+
else:
1067+
field_definitions[prop_name] = (python_type, ...)
1068+
else:
1069+
default_value = prop_details.get("default", None)
1070+
field_kwargs["default"] = default_value
1071+
if field_kwargs and any(k != "default" for k in field_kwargs):
1072+
field_definitions[prop_name] = (python_type, Field(**field_kwargs))
1073+
else:
1074+
field_definitions[prop_name] = (python_type, default_value)
1075+
1076+
return create_model(f"{model_name}_input", **field_definitions)
1077+
1078+
9351079
def _create_model_from_json_schema(tool_name: str, schema_json: Mapping[str, Any]) -> type[BaseModel]:
9361080
"""Creates a Pydantic model from a given JSON Schema.
9371081
@@ -948,29 +1092,8 @@ def _create_model_from_json_schema(tool_name: str, schema_json: Mapping[str, Any
9481092
f"JSON schema for tool '{tool_name}' must contain a 'properties' key of type dict. "
9491093
f"Got: {schema_json.get('properties', None)}"
9501094
)
951-
# Extract field definitions with type annotations
952-
field_definitions: dict[str, tuple[type, FieldInfo]] = {}
953-
for field_name, field_schema in schema_json["properties"].items():
954-
field_args: dict[str, Any] = {}
955-
if (field_description := field_schema.get("description", None)) is not None:
956-
field_args["description"] = field_description
957-
if (field_default := field_schema.get("default", None)) is not None:
958-
field_args["default"] = field_default
959-
field_type = field_schema.get("type", None)
960-
if field_type is None:
961-
raise ValueError(
962-
f"Missing 'type' for field '{field_name}' in JSON schema. "
963-
f"Got: {field_schema}, Supported types: {list(TYPE_MAPPING.keys())}"
964-
)
965-
python_type = TYPE_MAPPING.get(field_type)
966-
if python_type is None:
967-
raise ValueError(
968-
f"Unsupported type '{field_type}' for field '{field_name}' in JSON schema. "
969-
f"Got: {field_schema}, Supported types: {list(TYPE_MAPPING.keys())}"
970-
)
971-
field_definitions[field_name] = (python_type, Field(**field_args))
9721095

973-
return create_model(f"{tool_name}_input", **field_definitions) # type: ignore[call-overload, no-any-return]
1096+
return _build_pydantic_model_from_json_schema(tool_name, schema_json)
9741097

9751098

9761099
@overload

0 commit comments

Comments
 (0)