-
Notifications
You must be signed in to change notification settings - Fork 28
Expand file tree
/
Copy pathcompositeconfig.go
More file actions
137 lines (121 loc) · 3.52 KB
/
compositeconfig.go
File metadata and controls
137 lines (121 loc) · 3.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package rig
import (
	"errors"
	"fmt"
	"sort"

	"github.com/k0sproject/rig/v2/protocol"
	"github.com/k0sproject/rig/v2/protocol/localhost"
	"github.com/k0sproject/rig/v2/protocol/openssh"
	"github.com/k0sproject/rig/v2/protocol/ssh"
	"github.com/k0sproject/rig/v2/protocol/winrm"
)
// Compile-time assertion that *CompositeConfig implements ConnectionFactory;
// the build fails here if a required method is removed or its signature drifts.
var _ ConnectionFactory = (*CompositeConfig)(nil)
// LocalhostConfig is a bool-valued type that also accepts the v0.x YAML form
// "localhost:\n  enabled: true" for backward compatibility. To assign a bool
// variable, use an explicit cast: LocalhostConfig(b).
type LocalhostConfig bool

// UnmarshalYAML implements yaml.Unmarshaler, accepting both:
//
//	localhost: true     (current form)
//	localhost:          (v0.x form)
//	  enabled: true
//
// In the mapping form only the "enabled" key is allowed and its value must be
// a bool; a mapping without "enabled" yields false. Any other shape, an
// unknown key, or a non-bool "enabled" value returns an error wrapping
// protocol.ErrValidationFailed. When several unknown keys are present, the
// lexically smallest one is reported so the error message is deterministic.
func (l *LocalhostConfig) UnmarshalYAML(unmarshal func(any) error) error {
	// Reset first so a rejected value never leaves a previous true behind.
	*l = false

	var b bool
	if err := unmarshal(&b); err == nil {
		*l = LocalhostConfig(b)
		return nil
	}

	var m map[string]any
	if err := unmarshal(&m); err != nil {
		return errors.Join(fmt.Errorf("%w: localhost must be a bool or {enabled: bool}", protocol.ErrValidationFailed), err)
	}

	var unknown []string
	for k := range m {
		if k != "enabled" {
			unknown = append(unknown, k)
		}
	}
	if len(unknown) > 0 {
		// Map iteration order is randomized; sort so the same document always
		// produces the same error.
		sort.Strings(unknown)
		return fmt.Errorf("%w: localhost mapping has unknown key %q (only 'enabled' is allowed)", protocol.ErrValidationFailed, unknown[0])
	}

	if v, ok := m["enabled"]; ok {
		enabled, isBool := v.(bool)
		if !isBool {
			return fmt.Errorf("%w: localhost 'enabled' must be a bool, got %T", protocol.ErrValidationFailed, v)
		}
		*l = LocalhostConfig(enabled)
	}
	return nil
}
// CompositeConfig is a composite configuration of all the protocols supported out of the box by rig.
// It is intended to be embedded into host structs that are unmarshaled from configuration files.
//
// Only one protocol may be configured at a time: setting none, or more than
// one, makes Validate, Connection and String report the configuration as
// invalid.
type CompositeConfig struct {
	// SSH, when non-nil, selects the protocol/ssh implementation.
	SSH *ssh.Config `yaml:"ssh,omitempty"`
	// WinRM, when non-nil, selects the protocol/winrm implementation.
	WinRM *winrm.Config `yaml:"winRM,omitempty"`
	// OpenSSH, when non-nil, selects the protocol/openssh implementation.
	OpenSSH *openssh.Config `yaml:"openSSH,omitempty"`
	// Localhost, when true, selects a connection created by
	// localhost.NewConnection.
	Localhost LocalhostConfig `yaml:"localhost,omitempty"`
}
// configuredConfig resolves the single protocol configuration set on c.
//
// Exactly one of WinRM, SSH, OpenSSH or Localhost must be set. If none or more
// than one is set, it returns an error wrapping protocol.ErrValidationFailed.
// For Localhost, the returned factory is a newly created localhost connection.
// That connection is created only after the single-protocol check passes. As a
// result, a misconfiguration is reported as such instead of being masked by a
// localhost creation error, and no connection is created for a configuration
// that is going to be rejected anyway.
func (c *CompositeConfig) configuredConfig() (ConnectionFactory, error) {
	var factory ConnectionFactory
	count := 0
	if c.WinRM != nil {
		factory = c.WinRM
		count++
	}
	if c.SSH != nil {
		factory = c.SSH
		count++
	}
	if c.OpenSSH != nil {
		factory = c.OpenSSH
		count++
	}
	if c.Localhost {
		count++
	}

	switch count {
	case 0:
		return nil, fmt.Errorf("%w: no protocol configuration", protocol.ErrValidationFailed)
	case 1:
		// exactly one protocol: fall through to build/return it below
	default:
		return nil, fmt.Errorf("%w: multiple protocols configured for a single client", protocol.ErrValidationFailed)
	}

	if c.Localhost {
		conn, err := localhost.NewConnection()
		if err != nil {
			return nil, fmt.Errorf("create localhost connection: %w", err)
		}
		return conn, nil
	}
	return factory, nil
}
// validatable is implemented by protocol configurations that know how to
// check their own settings.
type validatable interface {
	Validate() error
}

// Validate checks that exactly one protocol is configured and, if that
// protocol's configuration has a Validate method, runs it. Errors from the
// protocol's own validation are wrapped with its concrete type name.
func (c *CompositeConfig) Validate() error {
	cfg, err := c.configuredConfig()
	if err != nil {
		return err
	}

	checker, canValidate := cfg.(validatable)
	if !canValidate {
		// Nothing further to check for this protocol.
		return nil
	}

	if vErr := checker.Validate(); vErr != nil {
		return fmt.Errorf("validate %T: %w", cfg, vErr)
	}
	return nil
}
// Connection returns a connection for the single configured protocol.
//
// It fails with an error wrapping protocol.ErrValidationFailed when no
// protocol, or more than one, is configured. Errors from resolving the
// configuration are returned as-is. Errors from creating the connection are
// wrapped with the concrete configuration type.
func (c *CompositeConfig) Connection() (protocol.Connection, error) {
	cfg, err := c.configuredConfig()
	if err != nil {
		return nil, err
	}
	conn, err := cfg.Connection()
	if err != nil {
		return nil, fmt.Errorf("create connection for %T: %w", cfg, err)
	}
	return conn, nil
}
// String returns the string representation of the single configured protocol
// configuration.
//
// If the configuration cannot be resolved, it returns "unknown{}". This covers
// no protocol configured, several protocols configured, or any other error from
// configuredConfig.
func (c *CompositeConfig) String() string {
	cfg, err := c.configuredConfig()
	if err != nil {
		return "unknown{}"
	}
	return cfg.String()
}