github.com/aristanetworks/goarista@v0.0.0-20240514173732-cca2755bbd44/gnmi/client.go (about) 1 // Copyright (c) 2017 Arista Networks, Inc. 2 // Use of this source code is governed by the Apache License 2.0 3 // that can be found in the COPYING file. 4 5 package gnmi 6 7 import ( 8 "context" 9 "crypto/tls" 10 "crypto/x509" 11 "flag" 12 "fmt" 13 "math" 14 "net" 15 "os" 16 "regexp" 17 "slices" 18 "strings" 19 20 "github.com/aristanetworks/goarista/netns" 21 pb "github.com/openconfig/gnmi/proto/gnmi" 22 "github.com/openconfig/gnmi/proto/gnmi_ext" 23 "google.golang.org/grpc" 24 "google.golang.org/grpc/credentials" 25 "google.golang.org/grpc/encoding/gzip" 26 "google.golang.org/grpc/metadata" 27 "google.golang.org/protobuf/proto" 28 ) 29 30 const ( 31 defaultPort = "6030" 32 // HostnameArg is the value to be replaced by the actual hostname 33 HostnameArg = "HOSTNAME" 34 ) 35 36 type tlsVersionMap map[string]uint16 37 38 func (m tlsVersionMap) String() string { 39 r := make([]string, 0, len(m)) 40 for k := range m { 41 r = append(r, k) 42 } 43 slices.Sort(r) 44 return strings.Join(r, ", ") 45 } 46 47 // TLSVersions is a map from TLS version strings to the tls version 48 // constants in the crypto/tls package 49 var TLSVersions = getTLSVersions() 50 51 // PublishFunc is the method to publish responses 52 type PublishFunc func(addr string, message proto.Message) 53 54 // ParseHostnames parses a comma-separated list of names and replaces HOSTNAME with the current 55 // hostname in it 56 func ParseHostnames(list string) ([]string, error) { 57 items := strings.Split(list, ",") 58 hostname, err := os.Hostname() 59 if err != nil { 60 return nil, err 61 } 62 names := make([]string, len(items)) 63 for i, name := range items { 64 if name == HostnameArg { 65 name = hostname 66 } 67 names[i] = name 68 } 69 return names, nil 70 } 71 72 // Config is the gnmi.Client config 73 type Config struct { 74 Addr string 75 76 // File path to load data or raw cert data. Alternatively, raw data can be provided below. 77 CAFile string 78 CertFile string 79 KeyFile string 80 81 // Raw certificate data. If respective file is provided above, that is used instead. 82 CAData []byte 83 CertData []byte 84 KeyData []byte 85 86 Password string 87 Username string 88 TLS bool 89 TLSMinVersion string 90 TLSMaxVersion string 91 Compression string 92 BDP bool 93 DialOptions []grpc.DialOption 94 Token string 95 GRPCMetadata map[string]string 96 } 97 98 // SubscribeOptions is the gNMI subscription request options 99 type SubscribeOptions struct { 100 UpdatesOnly bool 101 Prefix string 102 Mode string 103 StreamMode string 104 SampleInterval uint64 105 SuppressRedundant bool 106 HeartbeatInterval uint64 107 Paths [][]string 108 Origin string 109 Target string 110 Extensions []*gnmi_ext.Extension 111 } 112 113 // ParseFlags reads arguments from stdin and returns a populated Config object and a list of 114 // paths to subscribe to 115 func ParseFlags() (*Config, []string) { 116 // flags 117 var ( 118 addrsFlag = flag.String("addrs", "localhost:6030", 119 "Comma-separated list of addresses of OpenConfig gRPC servers. The address 'HOSTNAME' "+ 120 "is replaced by the current hostname.") 121 122 caFileFlag = flag.String("cafile", "", 123 "Path to server TLS certificate file") 124 125 certFileFlag = flag.String("certfile", "", 126 "Path to client TLS certificate file") 127 128 keyFileFlag = flag.String("keyfile", "", 129 "Path to client TLS private key file") 130 131 passwordFlag = flag.String("password", "", 132 "Password to authenticate with") 133 134 usernameFlag = flag.String("username", "", 135 "Username to authenticate with") 136 137 tlsFlag = flag.Bool("tls", false, 138 "Enable TLS") 139 tlsMinVersion = flag.String("tls-min-version", "", 140 fmt.Sprintf("Set minimum TLS version for connection (%s)", TLSVersions)) 141 tlsMaxVersion = flag.String("tls-max-version", "", 142 fmt.Sprintf("Set minimum TLS version for connection (%s)", TLSVersions)) 143 144 compressionFlag = flag.String("compression", "", 145 "Type of compression to use") 146 147 subscribeFlag = flag.String("subscribe", "", 148 "Comma-separated list of paths to subscribe to upon connecting to the server") 149 150 token = flag.String("token", "", 151 "Authentication token") 152 ) 153 flag.Parse() 154 cfg := &Config{ 155 Addr: *addrsFlag, 156 CAFile: *caFileFlag, 157 CertFile: *certFileFlag, 158 KeyFile: *keyFileFlag, 159 Password: *passwordFlag, 160 Username: *usernameFlag, 161 TLS: *tlsFlag, 162 TLSMinVersion: *tlsMinVersion, 163 TLSMaxVersion: *tlsMaxVersion, 164 Compression: *compressionFlag, 165 Token: *token, 166 } 167 subscriptions := strings.Split(*subscribeFlag, ",") 168 return cfg, subscriptions 169 170 } 171 172 // accessTokenCred implements credentials.PerRPCCredentials, the gRPC 173 // interface for credentials that need to attach security information 174 // to every RPC. 175 type accessTokenCred struct { 176 bearerToken string 177 } 178 179 // newAccessTokenCredential constructs a new per-RPC credential from a token. 180 func newAccessTokenCredential(token string) credentials.PerRPCCredentials { 181 bearerFmt := "Bearer %s" 182 return &accessTokenCred{bearerToken: fmt.Sprintf(bearerFmt, token)} 183 } 184 185 func (a *accessTokenCred) GetRequestMetadata(ctx context.Context, 186 uri ...string) (map[string]string, error) { 187 authHeader := "Authorization" 188 return map[string]string{ 189 authHeader: a.bearerToken, 190 }, nil 191 } 192 193 func (a *accessTokenCred) RequireTransportSecurity() bool { return true } 194 195 // DialContextConn connects to a gnmi service and return a client connection 196 func DialContextConn(ctx context.Context, cfg *Config) (*grpc.ClientConn, error) { 197 opts := append([]grpc.DialOption(nil), cfg.DialOptions...) 198 199 if !cfg.BDP { 200 // By default, the client and server will dynamically adjust the connection's 201 // window size using the Bandwidth Delay Product (BDP). 202 // See: https://grpc.io/blog/grpc-go-perf-improvements/ 203 // The default values for InitialWindowSize and InitialConnWindowSize are 65535. 204 // If values less than 65535 are used, then BDP and dynamic windows are enabled. 205 // Here, we disable the BDP and dynamic windows by setting these values >= 65535. 206 // We set these values to (1 << 20) * 16 as this is the largest window size that 207 // the BDP estimator could ever use. 208 // See: https://github.com/grpc/grpc-go/blob/master/internal/transport/bdp_estimator.go 209 const maxWindowSize int32 = (1 << 20) * 16 210 opts = append(opts, 211 grpc.WithInitialWindowSize(maxWindowSize), 212 grpc.WithInitialConnWindowSize(maxWindowSize), 213 ) 214 } 215 216 switch cfg.Compression { 217 case "": 218 case "gzip": 219 opts = append(opts, grpc.WithDefaultCallOptions(grpc.UseCompressor(gzip.Name))) 220 default: 221 return nil, fmt.Errorf("unsupported compression option: %q", cfg.Compression) 222 } 223 224 var err error 225 caData := cfg.CAData 226 certData := cfg.CertData 227 keyData := cfg.KeyData 228 if cfg.CAFile != "" { 229 if caData, err = os.ReadFile(cfg.CAFile); err != nil { 230 return nil, err 231 } 232 } 233 if cfg.CertFile != "" { 234 if certData, err = os.ReadFile(cfg.CertFile); err != nil { 235 return nil, err 236 } 237 } 238 if cfg.KeyFile != "" { 239 if keyData, err = os.ReadFile(cfg.KeyFile); err != nil { 240 return nil, err 241 } 242 } 243 244 if cfg.TLS || len(caData) > 0 || len(certData) > 0 || cfg.Token != "" { 245 tlsConfig := &tls.Config{} 246 if len(caData) > 0 { 247 cp := x509.NewCertPool() 248 if !cp.AppendCertsFromPEM(caData) { 249 return nil, fmt.Errorf("credentials: failed to append certificates") 250 } 251 tlsConfig.RootCAs = cp 252 } else { 253 tlsConfig.InsecureSkipVerify = true 254 } 255 if len(certData) > 0 { 256 if len(keyData) == 0 { 257 return nil, fmt.Errorf("no key provided for client certificate") 258 } 259 cert, err := tls.X509KeyPair(certData, keyData) 260 if err != nil { 261 return nil, err 262 } 263 tlsConfig.Certificates = []tls.Certificate{cert} 264 } 265 if cfg.Token != "" { 266 opts = append(opts, 267 grpc.WithPerRPCCredentials(newAccessTokenCredential(cfg.Token))) 268 } 269 if cfg.TLSMaxVersion != "" { 270 var ok bool 271 tlsConfig.MaxVersion, ok = TLSVersions[cfg.TLSMaxVersion] 272 if !ok { 273 return nil, fmt.Errorf("unrecognised TLS max version."+ 274 " Supported TLS versions are %s", TLSVersions) 275 } 276 } 277 if cfg.TLSMinVersion != "" { 278 var ok bool 279 tlsConfig.MinVersion, ok = TLSVersions[cfg.TLSMinVersion] 280 if !ok { 281 return nil, fmt.Errorf("unrecognised TLS min version."+ 282 " Supported TLS versions are %s", TLSVersions) 283 } 284 } 285 if cfg.TLSMinVersion != "" && cfg.TLSMaxVersion != "" && 286 tlsConfig.MinVersion > tlsConfig.MaxVersion { 287 return nil, fmt.Errorf( 288 "TLS min version was greater than TLS max version") 289 } 290 291 opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) 292 } else { 293 opts = append(opts, grpc.WithInsecure()) 294 } 295 296 dial := func(ctx context.Context, addrIn string) (conn net.Conn, err error) { 297 var network, nsName, addr string 298 299 split := strings.Split(addrIn, "://") 300 if l := len(split); l == 2 { 301 network = split[0] 302 addr = split[1] 303 } else { 304 network = "tcp" 305 addr = split[0] 306 } 307 308 if !strings.HasPrefix(network, "unix") { 309 if !strings.ContainsRune(addr, ':') { 310 addr += ":" + defaultPort 311 } 312 313 nsName, addr, err = netns.ParseAddress(addr) 314 if err != nil { 315 return nil, err 316 } 317 } 318 319 err = netns.Do(nsName, func() (err error) { 320 conn, err = (&net.Dialer{}).DialContext(ctx, network, addr) 321 return 322 }) 323 return 324 } 325 326 opts = append(opts, 327 grpc.WithContextDialer(dial), 328 329 // Allows received protobuf messages to be larger than 4MB 330 grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)), 331 ) 332 333 return grpc.DialContext(ctx, cfg.Addr, opts...) 334 } 335 336 // DialContext connects to a gnmi service and returns a client 337 func DialContext(ctx context.Context, cfg *Config) (pb.GNMIClient, error) { 338 grpcconn, err := DialContextConn(ctx, cfg) 339 if err != nil { 340 return nil, fmt.Errorf("failed to dial: %s", err) 341 } 342 return pb.NewGNMIClient(grpcconn), nil 343 } 344 345 // Dial connects to a gnmi service and returns a client 346 func Dial(cfg *Config) (pb.GNMIClient, error) { 347 return DialContext(context.Background(), cfg) 348 } 349 350 // NewContext returns a new context with username and password 351 // metadata if they are set in cfg, as well as any other metadata 352 // provided. 353 func NewContext(ctx context.Context, cfg *Config) context.Context { 354 md := map[string]string{} 355 for k, v := range cfg.GRPCMetadata { 356 md[k] = v 357 } 358 if cfg.Username != "" { 359 md["username"] = cfg.Username 360 md["password"] = cfg.Password 361 } 362 if len(md) > 0 { 363 ctx = metadata.NewOutgoingContext(ctx, metadata.New(md)) 364 } 365 return ctx 366 } 367 368 // NewGetRequest returns a GetRequest for the given paths 369 func NewGetRequest(paths [][]string, origin string) (*pb.GetRequest, error) { 370 req := &pb.GetRequest{ 371 Path: make([]*pb.Path, len(paths)), 372 } 373 for i, p := range paths { 374 gnmiPath, err := ParseGNMIElements(p) 375 if err != nil { 376 return nil, err 377 } 378 req.Path[i] = gnmiPath 379 req.Path[i].Origin = origin 380 } 381 return req, nil 382 } 383 384 // NewSubscribeRequest returns a SubscribeRequest for the given paths 385 func NewSubscribeRequest(subscribeOptions *SubscribeOptions) (*pb.SubscribeRequest, error) { 386 var mode pb.SubscriptionList_Mode 387 switch subscribeOptions.Mode { 388 case "once": 389 mode = pb.SubscriptionList_ONCE 390 case "poll": 391 mode = pb.SubscriptionList_POLL 392 case "": 393 fallthrough 394 case "stream": 395 mode = pb.SubscriptionList_STREAM 396 default: 397 return nil, fmt.Errorf("subscribe mode (%s) invalid", subscribeOptions.Mode) 398 } 399 400 var streamMode pb.SubscriptionMode 401 switch subscribeOptions.StreamMode { 402 case "on_change": 403 streamMode = pb.SubscriptionMode_ON_CHANGE 404 case "sample": 405 streamMode = pb.SubscriptionMode_SAMPLE 406 case "": 407 fallthrough 408 case "target_defined": 409 streamMode = pb.SubscriptionMode_TARGET_DEFINED 410 default: 411 return nil, fmt.Errorf("subscribe stream mode (%s) invalid", subscribeOptions.StreamMode) 412 } 413 414 prefixPath, err := ParseGNMIElements(SplitPath(subscribeOptions.Prefix)) 415 if err != nil { 416 return nil, err 417 } 418 subList := &pb.SubscriptionList{ 419 Subscription: make([]*pb.Subscription, len(subscribeOptions.Paths)), 420 Mode: mode, 421 UpdatesOnly: subscribeOptions.UpdatesOnly, 422 Prefix: prefixPath, 423 } 424 if subscribeOptions.Target != "" { 425 if subList.Prefix == nil { 426 subList.Prefix = &pb.Path{} 427 } 428 subList.Prefix.Target = subscribeOptions.Target 429 } 430 for i, p := range subscribeOptions.Paths { 431 gnmiPath, err := ParseGNMIElements(p) 432 if err != nil { 433 return nil, err 434 } 435 gnmiPath.Origin = subscribeOptions.Origin 436 subList.Subscription[i] = &pb.Subscription{ 437 Path: gnmiPath, 438 Mode: streamMode, 439 SampleInterval: subscribeOptions.SampleInterval, 440 SuppressRedundant: subscribeOptions.SuppressRedundant, 441 HeartbeatInterval: subscribeOptions.HeartbeatInterval, 442 } 443 } 444 return &pb.SubscribeRequest{ 445 Extension: subscribeOptions.Extensions, 446 Request: &pb.SubscribeRequest_Subscribe{ 447 Subscribe: subList, 448 }, 449 }, nil 450 } 451 452 // HistorySnapshotExtension returns an Extension_History for the given 453 // time. 454 func HistorySnapshotExtension(t int64) *gnmi_ext.Extension_History { 455 return &gnmi_ext.Extension_History{ 456 History: &gnmi_ext.History{ 457 Request: &gnmi_ext.History_SnapshotTime{ 458 SnapshotTime: t, 459 }, 460 }, 461 } 462 } 463 464 // HistoryRangeExtension returns an Extension_History for the the 465 // specified start and end times. 466 func HistoryRangeExtension(s, e int64) *gnmi_ext.Extension_History { 467 return &gnmi_ext.Extension_History{ 468 History: &gnmi_ext.History{ 469 Request: &gnmi_ext.History_Range{ 470 Range: &gnmi_ext.TimeRange{ 471 Start: s, 472 End: e, 473 }, 474 }, 475 }, 476 } 477 } 478 479 // getTLSVersions generates a map of TLS version name to tls version, based on the versions 480 // available in the crypto/tls package 481 func getTLSVersions(testHook ...func(uint16, *regexp.Regexp)) tlsVersionMap { 482 cipherSuites := tls.CipherSuites() 483 allSupportedVersions := make(map[uint16]struct{}) 484 485 for _, cipherSuite := range cipherSuites { 486 for _, version := range cipherSuite.SupportedVersions { 487 allSupportedVersions[version] = struct{}{} 488 } 489 } 490 491 // match TLS versions in dot format like X.Y or X.Y.Z etc (right now everything is X.Y) 492 re := regexp.MustCompile(`[\d.]+`) 493 494 nameToVersion := make(map[string]uint16, len(allSupportedVersions)) 495 for version := range allSupportedVersions { 496 // tls.VersionName(version) will be something like "TLS 1.3" 497 name := re.FindString(tls.VersionName(version)) 498 // check if the regex either failed to match, or if it is not specific enough 499 // (matching something which was already found) 500 if _, ok := nameToVersion[name]; ok || name == "" { 501 // if we ever fail to match a regex we shouldn't do anything in production 502 // but let's make a test fail so we can investigate and update the regex 503 for _, f := range testHook { 504 f(version, re) 505 } 506 continue 507 } 508 509 nameToVersion[name] = version 510 511 } 512 return nameToVersion 513 }