Skip to content

Commit 20f9bf1

Browse files
authored
Fix MODEL_TYPE fragment classification in CreateExternalModelStatement (#176) (#210)
1 parent bc884b9 commit 20f9bf1

7 files changed

Lines changed: 73 additions & 24 deletions

File tree

SqlScriptDom/Parser/TSql/Ast.xml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<?xml version="1.0"?>
1+
<?xml version="1.0"?>
22
<!--
33
<copyright file="Ast.xml" company="Microsoft">
44
Copyright (c) Microsoft Corporation. All rights reserved.
@@ -1364,7 +1364,7 @@
13641364
<Member Name="Name" Type="Identifier" Summary="The external model name."/>
13651365
<Member Name="Location" Type="Literal" Summary="The external model location name."/>
13661366
<Member Name="ApiFormat" Type="Literal" Summary="The external model api format name."/>
1367-
<Member Name="ModelType" Type="ExternalModelTypeOption?" GenerateUpdatePositionInfoCall="false" Summary="The external model type name."/>
1367+
<Member Name="ModelType" Type="ExternalModelTypeOption" Summary="The external model type name."/>
13681368
<Member Name="ModelName" Type="Literal" Summary="The external model name to be used to generate embeddings."/>
13691369
<Member Name="Credential" Type="Identifier" Summary="The external model credentials name."/>
13701370
<Member Name="Parameters" Type="Literal" Summary="The external model parameters as key-value pairs."/>
@@ -1380,6 +1380,9 @@
13801380
<Class Name="AlterExternalModelStatement" Base="ExternalModelStatement" Summary="Represents a ALTER EXTERNAL MODEL statement.">
13811381
<InheritedClass Name="ExternalModelStatement"/>
13821382
</Class>
1383+
<Class Name="ExternalModelTypeOption" Base="TSqlFragment" Summary="Represents the MODEL_TYPE option in CREATE/ALTER EXTERNAL MODEL.">
1384+
<Member Name="OptionKind" Type="ExternalModelTypeOptionKind" GenerateUpdatePositionInfoCall="false" Summary="The option kind."/>
1385+
</Class>
13831386

13841387
<Class Name="ExternalFileFormatStatement" Abstract="true" Base="TSqlStatement" Summary="Base class for all external file format statement objects.">
13851388
<Member Name="Name" Type="Identifier" Summary="The external file format name."/>

SqlScriptDom/Parser/TSql/ExternalModelTypeOption.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//------------------------------------------------------------------------------
1+
//------------------------------------------------------------------------------
22
// <copyright file="ExternalModelTypeOption.cs" company="Microsoft">
33
// Copyright (c) Microsoft Corporation. All rights reserved.
44
// </copyright>
@@ -11,14 +11,14 @@ namespace Microsoft.SqlServer.TransactSql.ScriptDom
1111

1212
/// <summary>
1313
/// The enumeration specifies the external model type
14-
/// Currently, we support EMBEDDINGS only.
14+
/// Currently, we support Embeddings only.
1515
/// </summary>
16-
public enum ExternalModelTypeOption
16+
public enum ExternalModelTypeOptionKind
1717
{
1818
/// <summary>
1919
/// MODEL_TYPE = EMBEDDINGS
2020
/// </summary>
21-
EMBEDDINGS = 0,
21+
Embeddings = 0,
2222

2323
}
2424

SqlScriptDom/Parser/TSql/TSql170.g

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24170,20 +24170,26 @@ StringLiteral vApiFormat;
2417024170
;
2417124171

2417224172
externalModelModelType[ExternalModelStatement vParent]
24173+
{
24174+
ExternalModelTypeOption vModelTypeOption = null;
24175+
}
2417324176
:
2417424177
tModelType:Identifier
2417524178
{
2417624179
Match(tModelType, CodeGenerationSupporter.ModelType);
24177-
UpdateTokenInfo(vParent, tModelType);
24180+
vModelTypeOption = this.FragmentFactory.CreateFragment<ExternalModelTypeOption>();
24181+
UpdateTokenInfo(vModelTypeOption, tModelType);
2417824182
}
2417924183
EqualsSign
2418024184
(
2418124185
tEmbeddings:Identifier
2418224186
{
2418324187
if (TryMatch(tEmbeddings, CodeGenerationSupporter.Embeddings))
2418424188
{
24185-
vParent.ModelType = ExternalModelTypeOption.EMBEDDINGS;
24186-
UpdateTokenInfo(vParent, tEmbeddings);
24189+
vModelTypeOption.OptionKind = ExternalModelTypeOptionKind.Embeddings;
24190+
UpdateTokenInfo(vModelTypeOption, tEmbeddings);
24191+
vParent.ModelType = vModelTypeOption;
24192+
vParent.UpdateTokenInfo(vModelTypeOption);
2418724193
}
2418824194
else
2418924195
{

SqlScriptDom/Parser/TSql/TSql180.g

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24170,20 +24170,26 @@ StringLiteral vApiFormat;
2417024170
;
2417124171

2417224172
externalModelModelType[ExternalModelStatement vParent]
24173+
{
24174+
ExternalModelTypeOption vModelTypeOption = null;
24175+
}
2417324176
:
2417424177
tModelType:Identifier
2417524178
{
2417624179
Match(tModelType, CodeGenerationSupporter.ModelType);
24177-
UpdateTokenInfo(vParent, tModelType);
24180+
vModelTypeOption = this.FragmentFactory.CreateFragment<ExternalModelTypeOption>();
24181+
UpdateTokenInfo(vModelTypeOption, tModelType);
2417824182
}
2417924183
EqualsSign
2418024184
(
2418124185
tEmbeddings:Identifier
2418224186
{
2418324187
if (TryMatch(tEmbeddings, CodeGenerationSupporter.Embeddings))
2418424188
{
24185-
vParent.ModelType = ExternalModelTypeOption.EMBEDDINGS;
24186-
UpdateTokenInfo(vParent, tEmbeddings);
24189+
vModelTypeOption.OptionKind = ExternalModelTypeOptionKind.Embeddings;
24190+
UpdateTokenInfo(vModelTypeOption, tEmbeddings);
24191+
vParent.ModelType = vModelTypeOption;
24192+
vParent.UpdateTokenInfo(vModelTypeOption);
2418724193
}
2418824194
else
2418924195
{

SqlScriptDom/ScriptDom/SqlServer/ScriptGenerator/SqlScriptGenerator.AlterExternalModelStatement.cs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//------------------------------------------------------------------------------
1+
//------------------------------------------------------------------------------
22
// <copyright file="SqlScriptGeneratorVisitor.AlterExternalModelStatement.cs" company="Microsoft">
33
// Copyright (c) Microsoft Corporation. All rights reserved.
44
// </copyright>
@@ -53,17 +53,15 @@ protected void GenerateAlterExternalModelStatementBody(AlterExternalModelStateme
5353
}
5454

5555
// external model Model Type options
56-
if (node.ModelType == ExternalModelTypeOption.EMBEDDINGS)
56+
if (node.ModelType != null)
5757
{
5858
if (!ifFirst)
5959
{
6060
GenerateSymbol(TSqlTokenType.Comma);
6161
}
6262
ifFirst = false;
63-
ExternalModelTypeOption typeOption = ExternalModelTypeOption.EMBEDDINGS;
64-
string externalModelTypeOption = GetValueForEnumKey(_externalModelTypeOption, typeOption);
6563
NewLine();
66-
GenerateNameEqualsValue(CodeGenerationSupporter.ModelType, externalModelTypeOption);
64+
GenerateFragmentIfNotNull(node.ModelType);
6765
}
6866

6967
// external model name options

SqlScriptDom/ScriptDom/SqlServer/ScriptGenerator/SqlScriptGeneratorVisitor.CreateExternalModelStatement.cs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//------------------------------------------------------------------------------
1+
//------------------------------------------------------------------------------
22
// <copyright file="SqlScriptGeneratorVisitor.CreateExternalModelStatement.cs" company="Microsoft">
33
// Copyright (c) Microsoft Corporation. All rights reserved.
44
// </copyright>
@@ -17,9 +17,9 @@ public override void ExplicitVisit(CreateExternalModelStatement node)
1717
GenerateSpaceAndIdentifier(CodeGenerationSupporter.Model);
1818
GenerateCreateExternalModelStatementBody(node);
1919
}
20-
protected static Dictionary<ExternalModelTypeOption, string> _externalModelTypeOption = new Dictionary<ExternalModelTypeOption, string>()
20+
protected static Dictionary<ExternalModelTypeOptionKind, string> _externalModelTypeOptionKind = new Dictionary<ExternalModelTypeOptionKind, string>()
2121
{
22-
{ExternalModelTypeOption.EMBEDDINGS, CodeGenerationSupporter.Embeddings}
22+
{ExternalModelTypeOptionKind.Embeddings, CodeGenerationSupporter.Embeddings}
2323
};
2424

2525
protected void GenerateCreateExternalModelStatementBody(CreateExternalModelStatement node)
@@ -54,17 +54,15 @@ protected void GenerateCreateExternalModelStatementBody(CreateExternalModelState
5454
}
5555

5656
// external model Model Type options
57-
if (node.ModelType == ExternalModelTypeOption.EMBEDDINGS)
57+
if (node.ModelType != null)
5858
{
5959
if (!ifFirst)
6060
{
6161
GenerateSymbol(TSqlTokenType.Comma);
6262
}
6363
ifFirst = false;
64-
ExternalModelTypeOption typeOption = ExternalModelTypeOption.EMBEDDINGS;
65-
string externalModelTypeOption = GetValueForEnumKey(_externalModelTypeOption, typeOption);
6664
NewLine();
67-
GenerateNameEqualsValue(CodeGenerationSupporter.ModelType, externalModelTypeOption);
65+
GenerateFragmentIfNotNull(node.ModelType);
6866
}
6967

7068
// external model name options
@@ -118,5 +116,11 @@ protected void GenerateCreateExternalModelStatementBody(CreateExternalModelState
118116
NewLine();
119117
GenerateKeyword(TSqlTokenType.RightParenthesis);
120118
}
119+
120+
public override void ExplicitVisit(ExternalModelTypeOption node)
121+
{
122+
string optionKindString = GetValueForEnumKey(_externalModelTypeOptionKind, node.OptionKind);
123+
GenerateNameEqualsValue(CodeGenerationSupporter.ModelType, optionKindString);
124+
}
121125
}
122126
}

Test/SqlDom/Only170SyntaxTests.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,5 +392,37 @@ public void TSql100SyntaxIn170ParserTest()
392392
ParserTest.ParseAndVerify(parser, scriptGen, ti._scriptFilename, ti._result100);
393393
}
394394
}
395+
396+
[TestMethod]
397+
[Priority(0)]
398+
[SqlStudioTestCategory(Category.UnitTest)]
399+
public void TestExternalModelModelTypeVisitor()
400+
{
401+
string sql = "CREATE EXTERNAL MODEL simple_model WITH (LOCATION = '/models/simple', API_FORMAT = 'OpenAI', MODEL_TYPE = EMBEDDINGS, MODEL = 'gpt-3.5-turbo');";
402+
TSql170Parser parser = new TSql170Parser(true);
403+
IList<ParseError> errors;
404+
TSqlFragment fragment = parser.Parse(new StringReader(sql), out errors);
405+
406+
Assert.AreEqual(0, errors.Count);
407+
Assert.IsNotNull(fragment);
408+
409+
// Collect all visited fragments using a custom visitor
410+
var visitor = new ExternalModelTypeOptionVisitor();
411+
fragment.Accept(visitor);
412+
413+
Assert.AreEqual(1, visitor.VisitedOptions.Count);
414+
Assert.AreEqual(ExternalModelTypeOptionKind.Embeddings, visitor.VisitedOptions[0].OptionKind);
415+
}
416+
417+
private class ExternalModelTypeOptionVisitor : TSqlFragmentVisitor
418+
{
419+
public List<ExternalModelTypeOption> VisitedOptions { get; } = new List<ExternalModelTypeOption>();
420+
421+
public override void Visit(ExternalModelTypeOption node)
422+
{
423+
VisitedOptions.Add(node);
424+
base.Visit(node);
425+
}
426+
}
395427
}
396428
}

0 commit comments

Comments
 (0)