diff --git a/cmd/run/run.go b/cmd/run/run.go index e380de5b..1fe574b2 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -16,6 +16,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/briandowns/spinner" "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/internal/modelkey" "github.com/github/gh-models/internal/sse" "github.com/github/gh-models/pkg/command" "github.com/github/gh-models/pkg/prompt" @@ -513,9 +514,21 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st return "", errors.New(noMatchErrorMessage) } + parsedModel, err := modelkey.ParseModelKey(modelName) + if err != nil { + return "", fmt.Errorf("invalid model format: %w", err) + } + + if parsedModel.Provider == "custom" { + // Skip validation for custom provider + return parsedModel.String(), nil + } + + // For non-custom providers, validate the model exists + expectedModelID := azuremodels.FormatIdentifier(parsedModel.Publisher, parsedModel.ModelName) foundMatch := false for _, model := range models { - if model.HasName(modelName) { + if model.HasName(expectedModelID) { foundMatch = true break } @@ -525,7 +538,7 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st return "", errors.New(noMatchErrorMessage) } - return modelName, nil + return expectedModelID, nil } func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions, org string) (sse.Reader[azuremodels.ChatCompletion], error) { diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go index 43ef6a1c..eb10649c 100644 --- a/cmd/run/run_test.go +++ b/cmd/run/run_test.go @@ -403,3 +403,56 @@ func TestParseTemplateVariables(t *testing.T) { }) } } + +func TestValidateModelName(t *testing.T) { + tests := []struct { + name string + modelName string + expectedModel string + expectError bool + }{ + { + name: "custom provider skips validation", + modelName: "custom/mycompany/custom-model", + expectedModel: "custom/mycompany/custom-model", + expectError: false, + }, + { + name: "azureml provider requires validation", + modelName: "openai/gpt-4", + expectedModel: "openai/gpt-4", + expectError: false, + }, + { + name: "invalid model format", + modelName: "invalid-format", + expectError: true, + }, + { + name: "nonexistent azureml model", + modelName: "nonexistent/model", + expectError: true, + }, + } + + // Create a mock model for testing + mockModel := &azuremodels.ModelSummary{ + Name: "gpt-4", + Publisher: "openai", + Task: "chat-completion", + } + models := []*azuremodels.ModelSummary{mockModel} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := validateModelName(tt.modelName, models) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedModel, result) + } + }) + } +} diff --git a/internal/azuremodels/model_details.go b/internal/azuremodels/model_details.go index ecd135ac..53289cf0 100644 --- a/internal/azuremodels/model_details.go +++ b/internal/azuremodels/model_details.go @@ -2,7 +2,8 @@ package azuremodels import ( "fmt" - "strings" + + "github.com/github/gh-models/internal/modelkey" ) // ModelDetails includes detailed information about a model. @@ -28,12 +29,5 @@ func (m *ModelDetails) ContextLimits() string { // FormatIdentifier formats the model identifier based on the publisher and model name. func FormatIdentifier(publisher, name string) string { - formatPart := func(s string) string { - // Replace spaces with dashes and convert to lowercase - result := strings.ToLower(s) - result = strings.ReplaceAll(result, " ", "-") - return result - } - - return fmt.Sprintf("%s/%s", formatPart(publisher), formatPart(name)) + return modelkey.FormatIdentifier("azureml", publisher, name) } diff --git a/internal/modelkey/modelkey.go b/internal/modelkey/modelkey.go new file mode 100644 index 00000000..bd18562d --- /dev/null +++ b/internal/modelkey/modelkey.go @@ -0,0 +1,76 @@ +package modelkey + +import ( + "fmt" + "strings" +) + +type ModelKey struct { + Provider string + Publisher string + ModelName string +} + +func ParseModelKey(modelKey string) (*ModelKey, error) { + if modelKey == "" { + return nil, fmt.Errorf("invalid model key format: %s", modelKey) + } + + parts := strings.Split(modelKey, "/") + + // Check for empty parts + for _, part := range parts { + if part == "" { + return nil, fmt.Errorf("invalid model key format: %s", modelKey) + } + } + + switch len(parts) { + case 2: + // Format: publisher/model-name (provider defaults to "azureml") + return &ModelKey{ + Provider: "azureml", + Publisher: parts[0], + ModelName: parts[1], + }, nil + case 3: + // Format: provider/publisher/model-name + return &ModelKey{ + Provider: parts[0], + Publisher: parts[1], + ModelName: parts[2], + }, nil + default: + return nil, fmt.Errorf("invalid model key format: %s", modelKey) + } +} + +// String returns the string representation of the ModelKey. +func (mk *ModelKey) String() string { + provider := formatPart(mk.Provider) + publisher := formatPart(mk.Publisher) + modelName := formatPart(mk.ModelName) + + if provider == "azureml" { + return fmt.Sprintf("%s/%s", publisher, modelName) + } + + return fmt.Sprintf("%s/%s/%s", provider, publisher, modelName) +} + +func formatPart(s string) string { + s = strings.ToLower(s) + s = strings.ReplaceAll(s, " ", "-") + + return s +} + +func FormatIdentifier(provider, publisher, name string) string { + mk := &ModelKey{ + Provider: provider, + Publisher: publisher, + ModelName: name, + } + + return mk.String() +} diff --git a/internal/modelkey/modelkey_test.go b/internal/modelkey/modelkey_test.go new file mode 100644 index 00000000..f4d13410 --- /dev/null +++ b/internal/modelkey/modelkey_test.go @@ -0,0 +1,202 @@ +package modelkey + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseModelKey(t *testing.T) { + tests := []struct { + name string + input string + expected *ModelKey + expectError bool + }{ + { + name: "valid format with provider", + input: "custom/openai/gpt-4", + expected: &ModelKey{ + Provider: "custom", + Publisher: "openai", + ModelName: "gpt-4", + }, + expectError: false, + }, + { + name: "valid format without provider (defaults to azureml)", + input: "openai/gpt-4", + expected: &ModelKey{ + Provider: "azureml", + Publisher: "openai", + ModelName: "gpt-4", + }, + expectError: false, + }, + { + name: "valid format with azureml provider explicitly", + input: "azureml/microsoft/phi-3", + expected: &ModelKey{ + Provider: "azureml", + Publisher: "microsoft", + ModelName: "phi-3", + }, + expectError: false, + }, + { + name: "valid format with hyphens in model name", + input: "cohere/command-r-plus", + expected: &ModelKey{ + Provider: "azureml", + Publisher: "cohere", + ModelName: "command-r-plus", + }, + expectError: false, + }, + { + name: "valid format with underscores in model name", + input: "ai21/jamba_instruct", + expected: &ModelKey{ + Provider: "azureml", + Publisher: "ai21", + ModelName: "jamba_instruct", + }, + expectError: false, + }, + { + name: "invalid format with only one part", + input: "gpt-4", + expected: nil, + expectError: true, + }, + { + name: "invalid format with four parts", + input: "provider/publisher/model/extra", + expected: nil, + expectError: true, + }, + { + name: "invalid format with empty string", + input: "", + expected: nil, + expectError: true, + }, + { + name: "invalid format with only slashes", + input: "//", + expected: nil, + expectError: true, + }, + { + name: "invalid format with empty parts", + input: "provider//model", + expected: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ParseModelKey(tt.input) + + if tt.expectError { + require.Error(t, err) + require.Nil(t, result) + } else { + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, tt.expected.Provider, result.Provider) + require.Equal(t, tt.expected.Publisher, result.Publisher) + require.Equal(t, tt.expected.ModelName, result.ModelName) + } + }) + } +} + +func TestModelKey_String(t *testing.T) { + tests := []struct { + name string + modelKey *ModelKey + expected string + }{ + { + name: "standard format with azureml provider - should omit provider", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "openai", + ModelName: "gpt-4", + }, + expected: "openai/gpt-4", + }, + { + name: "custom provider - should include provider", + modelKey: &ModelKey{ + Provider: "custom", + Publisher: "microsoft", + ModelName: "phi-3", + }, + expected: "custom/microsoft/phi-3", + }, + { + name: "azureml provider with hyphens - should omit provider", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "cohere", + ModelName: "command-r-plus", + }, + expected: "cohere/command-r-plus", + }, + { + name: "azureml provider with underscores - should omit provider", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "ai21", + ModelName: "jamba_instruct", + }, + expected: "ai21/jamba_instruct", + }, + { + name: "non-azureml provider - should include provider", + modelKey: &ModelKey{ + Provider: "custom-provider", + Publisher: "test-publisher", + ModelName: "test-model", + }, + expected: "custom-provider/test-publisher/test-model", + }, + { + name: "azureml provider with uppercase and spaces - should format and omit provider", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "Open AI", + ModelName: "GPT 4", + }, + expected: "open-ai/gpt-4", + }, + { + name: "non-azureml provider with uppercase and spaces - should format and include provider", + modelKey: &ModelKey{ + Provider: "Custom Provider", + Publisher: "Test Publisher", + ModelName: "Test Model Name", + }, + expected: "custom-provider/test-publisher/test-model-name", + }, + { + name: "mixed case with multiple spaces", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "Microsoft Corporation", + ModelName: "Phi 3 Mini Instruct", + }, + expected: "microsoft-corporation/phi-3-mini-instruct", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.modelKey.String() + require.Equal(t, tt.expected, result) + }) + } +}