Skip to content

Commit 9980cd3

Browse files
authored
feat: add Local field to FlagMetadata for granular flag inheritance control (#12)
1 parent b6285cf commit 9980cd3

5 files changed

Lines changed: 225 additions & 24 deletions

File tree

command.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ type FlagMetadata struct {
8383

8484
// Required indicates whether the flag is required.
8585
Required bool
86+
87+
// Local indicates that the flag should not be inherited by child commands. When true, the
88+
// flag is only available on the command that defines it.
89+
Local bool
8690
}
8791

8892
// FlagsFunc is a helper function that creates a new [flag.FlagSet] and applies the given function

parse.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,21 @@ func resolveCommandPath(root *Command, argsToParse []string) (*Command, error) {
110110
name := strings.TrimLeft(arg, "-")
111111
skipValue := false
112112
for _, cmd := range root.state.path {
113+
localFlags := localFlagSet(cmd.FlagsMetadata)
114+
// Skip local flags on ancestor commands (any command already in the
115+
// path is an ancestor of the not-yet-resolved terminal command).
116+
if localFlags[name] {
117+
continue
118+
}
113119
// First try direct lookup.
114120
f := cmd.Flags.Lookup(name)
115121
// If not found, check if it's a short alias.
116122
if f == nil {
117123
for _, fm := range cmd.FlagsMetadata {
118124
if fm.Short == name {
125+
if localFlags[fm.Name] {
126+
break
127+
}
119128
f = cmd.Flags.Lookup(fm.Name)
120129
break
121130
}
@@ -161,13 +170,20 @@ func resolveCommandPath(root *Command, argsToParse []string) (*Command, error) {
161170
func combineFlags(path []*Command) *flag.FlagSet {
162171
combined := flag.NewFlagSet(path[0].Name, flag.ContinueOnError)
163172
combined.SetOutput(io.Discard)
164-
for i := len(path) - 1; i >= 0; i-- {
173+
terminalIdx := len(path) - 1
174+
for i := terminalIdx; i >= 0; i-- {
165175
cmd := path[i]
166176
if cmd.Flags == nil {
167177
continue
168178
}
179+
localFlags := localFlagSet(cmd.FlagsMetadata)
169180
shortMap := shortFlagMap(cmd.FlagsMetadata)
181+
isAncestor := i < terminalIdx
170182
cmd.Flags.VisitAll(func(f *flag.Flag) {
183+
// Skip local flags from ancestor commands — they are not inherited.
184+
if isAncestor && localFlags[f.Name] {
185+
return
186+
}
171187
if combined.Lookup(f.Name) == nil {
172188
combined.Var(f.Value, f.Name, f.Usage)
173189
}
@@ -182,6 +198,17 @@ func combineFlags(path []*Command) *flag.FlagSet {
182198
return combined
183199
}
184200

201+
// localFlagSet builds a set of flag names that are marked as local in FlagsMetadata.
202+
func localFlagSet(metadata []FlagMetadata) map[string]bool {
203+
m := make(map[string]bool, len(metadata))
204+
for _, fm := range metadata {
205+
if fm.Local {
206+
m[fm.Name] = true
207+
}
208+
}
209+
return m
210+
}
211+
185212
// shortFlagMap builds a map from long flag name to short alias from FlagsMetadata.
186213
func shortFlagMap(metadata []FlagMetadata) map[string]string {
187214
m := make(map[string]string, len(metadata))
@@ -203,12 +230,17 @@ func checkRequiredFlags(path []*Command, combined *flag.FlagSet) error {
203230
setFlags[f.Name] = struct{}{}
204231
})
205232

233+
terminalIdx := len(path) - 1
206234
var missingFlags []string
207-
for _, cmd := range path {
235+
for i, cmd := range path {
208236
for _, flagMetadata := range cmd.FlagsMetadata {
209237
if !flagMetadata.Required {
210238
continue
211239
}
240+
// Skip required-flag checks for local flags on ancestor commands.
241+
if flagMetadata.Local && i < terminalIdx {
242+
continue
243+
}
212244
if combined.Lookup(flagMetadata.Name) == nil {
213245
return fmt.Errorf("command %q: internal error: required flag %s not found in flag set", getCommandPath(path), formatFlagName(flagMetadata.Name))
214246
}

parse_test.go

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,164 @@ func TestShortFlags(t *testing.T) {
836836
})
837837
}
838838

839+
func TestLocalFlags(t *testing.T) {
840+
t.Parallel()
841+
842+
t.Run("local flag on parent not available to child", func(t *testing.T) {
843+
t.Parallel()
844+
child := &Command{
845+
Name: "child",
846+
Exec: func(ctx context.Context, s *State) error { return nil },
847+
}
848+
root := &Command{
849+
Name: "root",
850+
Flags: FlagsFunc(func(f *flag.FlagSet) {
851+
f.Bool("version", false, "show version")
852+
f.Bool("verbose", false, "enable verbose output")
853+
}),
854+
FlagsMetadata: []FlagMetadata{
855+
{Name: "version", Local: true},
856+
},
857+
SubCommands: []*Command{child},
858+
Exec: func(ctx context.Context, s *State) error { return nil },
859+
}
860+
// --version on child should fail because it's local to root
861+
err := Parse(root, []string{"child", "--version"})
862+
require.Error(t, err)
863+
require.ErrorContains(t, err, "flag provided but not defined")
864+
865+
// --verbose on child should still work (not local)
866+
root2 := &Command{
867+
Name: "root",
868+
Flags: FlagsFunc(func(f *flag.FlagSet) {
869+
f.Bool("version", false, "show version")
870+
f.Bool("verbose", false, "enable verbose output")
871+
}),
872+
FlagsMetadata: []FlagMetadata{
873+
{Name: "version", Local: true},
874+
},
875+
SubCommands: []*Command{{
876+
Name: "child",
877+
Exec: func(ctx context.Context, s *State) error { return nil },
878+
}},
879+
Exec: func(ctx context.Context, s *State) error { return nil },
880+
}
881+
err = Parse(root2, []string{"child", "--verbose"})
882+
require.NoError(t, err)
883+
assert.True(t, GetFlag[bool](root2.state, "verbose"))
884+
})
885+
886+
t.Run("local flag works on defining command", func(t *testing.T) {
887+
t.Parallel()
888+
root := &Command{
889+
Name: "root",
890+
Flags: FlagsFunc(func(f *flag.FlagSet) {
891+
f.Bool("version", false, "show version")
892+
}),
893+
FlagsMetadata: []FlagMetadata{
894+
{Name: "version", Local: true},
895+
},
896+
Exec: func(ctx context.Context, s *State) error { return nil },
897+
}
898+
err := Parse(root, []string{"--version"})
899+
require.NoError(t, err)
900+
assert.True(t, GetFlag[bool](root.state, "version"))
901+
})
902+
903+
t.Run("local required flag only enforced on defining command", func(t *testing.T) {
904+
t.Parallel()
905+
child := &Command{
906+
Name: "child",
907+
Exec: func(ctx context.Context, s *State) error { return nil },
908+
}
909+
root := &Command{
910+
Name: "root",
911+
Flags: FlagsFunc(func(f *flag.FlagSet) {
912+
f.String("token", "", "auth token")
913+
}),
914+
FlagsMetadata: []FlagMetadata{
915+
{Name: "token", Required: true, Local: true},
916+
},
917+
SubCommands: []*Command{child},
918+
Exec: func(ctx context.Context, s *State) error { return nil },
919+
}
920+
// Child command should not require parent's local required flag
921+
err := Parse(root, []string{"child"})
922+
require.NoError(t, err)
923+
924+
// But root command itself should still require it
925+
root2 := &Command{
926+
Name: "root",
927+
Flags: FlagsFunc(func(f *flag.FlagSet) {
928+
f.String("token", "", "auth token")
929+
}),
930+
FlagsMetadata: []FlagMetadata{
931+
{Name: "token", Required: true, Local: true},
932+
},
933+
Exec: func(ctx context.Context, s *State) error { return nil },
934+
}
935+
err = Parse(root2, []string{})
936+
require.Error(t, err)
937+
require.ErrorContains(t, err, "required flag")
938+
})
939+
940+
t.Run("usage excludes local parent flags from inherited flags", func(t *testing.T) {
941+
t.Parallel()
942+
child := &Command{
943+
Name: "child",
944+
Flags: FlagsFunc(func(f *flag.FlagSet) {
945+
f.Bool("dry-run", false, "dry run mode")
946+
}),
947+
Exec: func(ctx context.Context, s *State) error { return nil },
948+
}
949+
root := &Command{
950+
Name: "root",
951+
Flags: FlagsFunc(func(f *flag.FlagSet) {
952+
f.Bool("version", false, "show version")
953+
f.Bool("verbose", false, "enable verbose output")
954+
}),
955+
FlagsMetadata: []FlagMetadata{
956+
{Name: "version", Local: true},
957+
},
958+
SubCommands: []*Command{child},
959+
Exec: func(ctx context.Context, s *State) error { return nil },
960+
}
961+
err := Parse(root, []string{"child", "--help"})
962+
require.ErrorIs(t, err, flag.ErrHelp)
963+
964+
usage := DefaultUsage(root)
965+
// --verbose should appear in inherited flags (not local)
966+
assert.Contains(t, usage, "--verbose")
967+
// --version should NOT appear (local to root, not inherited)
968+
assert.NotContains(t, usage, "--version")
969+
// --dry-run should appear in local flags
970+
assert.Contains(t, usage, "--dry-run")
971+
})
972+
973+
t.Run("local flag with short alias not inherited", func(t *testing.T) {
974+
t.Parallel()
975+
child := &Command{
976+
Name: "child",
977+
Exec: func(ctx context.Context, s *State) error { return nil },
978+
}
979+
root := &Command{
980+
Name: "root",
981+
Flags: FlagsFunc(func(f *flag.FlagSet) {
982+
f.Bool("version", false, "show version")
983+
}),
984+
FlagsMetadata: []FlagMetadata{
985+
{Name: "version", Short: "V", Local: true},
986+
},
987+
SubCommands: []*Command{child},
988+
Exec: func(ctx context.Context, s *State) error { return nil },
989+
}
990+
// Short alias -V should also not work on child
991+
err := Parse(root, []string{"child", "-V"})
992+
require.Error(t, err)
993+
require.ErrorContains(t, err, "flag provided but not defined")
994+
})
995+
}
996+
839997
func getCommand(t *testing.T, c *Command) *Command {
840998
require.NotNil(t, c)
841999
require.NotNil(t, c.state)

usage.go

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -90,19 +90,26 @@ func DefaultUsage(root *Command) string {
9090

9191
var flags []flagInfo
9292
if root.state != nil && len(root.state.path) > 0 {
93+
terminalIdx := len(root.state.path) - 1
9394
for i, cmd := range root.state.path {
9495
if cmd.Flags == nil {
9596
continue
9697
}
97-
isGlobal := i < len(root.state.path)-1
98+
isInherited := i < terminalIdx
9899
metaMap := flagMetadataMap(cmd.FlagsMetadata)
99100
cmd.Flags.VisitAll(func(f *flag.Flag) {
101+
// Skip local flags from ancestor commands — they don't appear in child help.
102+
if isInherited {
103+
if m, ok := metaMap[f.Name]; ok && m.Local {
104+
return
105+
}
106+
}
100107
fi := flagInfo{
101-
name: "--" + f.Name,
102-
usage: f.Usage,
103-
defval: f.DefValue,
104-
typeName: flagTypeName(f),
105-
global: isGlobal,
108+
name: "--" + f.Name,
109+
usage: f.Usage,
110+
defval: f.DefValue,
111+
typeName: flagTypeName(f),
112+
inherited: isInherited,
106113
}
107114
if m, ok := metaMap[f.Name]; ok {
108115
fi.required = m.Required
@@ -150,10 +157,10 @@ func DefaultUsage(root *Command) string {
150157
}
151158

152159
hasLocal := false
153-
hasGlobal := false
160+
hasInherited := false
154161
for _, f := range flags {
155-
if f.global {
156-
hasGlobal = true
162+
if f.inherited {
163+
hasInherited = true
157164
} else {
158165
hasLocal = true
159166
}
@@ -165,8 +172,8 @@ func DefaultUsage(root *Command) string {
165172
b.WriteString("\n")
166173
}
167174

168-
if hasGlobal {
169-
b.WriteString("Global Flags:\n")
175+
if hasInherited {
176+
b.WriteString("Inherited Flags:\n")
170177
writeFlagSection(&b, flags, maxFlagLen, true, hasAnyShort)
171178
b.WriteString("\n")
172179
}
@@ -184,12 +191,12 @@ func DefaultUsage(root *Command) string {
184191
}
185192

186193
// writeFlagSection handles the formatting of flag descriptions
187-
func writeFlagSection(b *strings.Builder, flags []flagInfo, maxLen int, global, hasAnyShort bool) {
194+
func writeFlagSection(b *strings.Builder, flags []flagInfo, maxLen int, inherited, hasAnyShort bool) {
188195
nameWidth := maxLen + 4
189196
wrapWidth := defaultTerminalWidth - nameWidth
190197

191198
for _, f := range flags {
192-
if f.global != global {
199+
if f.inherited != inherited {
193200
continue
194201
}
195202

@@ -222,13 +229,13 @@ func flagMetadataMap(metadata []FlagMetadata) map[string]FlagMetadata {
222229
}
223230

224231
type flagInfo struct {
225-
name string
226-
short string
227-
usage string
228-
defval string
229-
typeName string
230-
global bool
231-
required bool
232+
name string
233+
short string
234+
usage string
235+
defval string
236+
typeName string
237+
inherited bool
238+
required bool
232239
}
233240

234241
// displayName returns the flag name with optional short alias and type hint. When hasAnyShort is

usage_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ func TestUsageGeneration(t *testing.T) {
305305
require.Contains(t, output, "custom [options] <file>")
306306
})
307307

308-
t.Run("usage with global and local flags", func(t *testing.T) {
308+
t.Run("usage with inherited and local flags", func(t *testing.T) {
309309
t.Parallel()
310310

311311
child := &Command{
@@ -487,6 +487,6 @@ func TestWriteFlagSection(t *testing.T) {
487487

488488
output := DefaultUsage(cmd)
489489
require.NotContains(t, output, "Flags:")
490-
require.NotContains(t, output, "Global Flags:")
490+
require.NotContains(t, output, "Inherited Flags:")
491491
})
492492
}

0 commit comments

Comments
 (0)