github.com/aristanetworks/goarista@v0.0.0-20240514173732-cca2755bbd44/gnmi/client.go (about)

     1  // Copyright (c) 2017 Arista Networks, Inc.
     2  // Use of this source code is governed by the Apache License 2.0
     3  // that can be found in the COPYING file.
     4  
     5  package gnmi
     6  
     7  import (
     8  	"context"
     9  	"crypto/tls"
    10  	"crypto/x509"
    11  	"flag"
    12  	"fmt"
    13  	"math"
    14  	"net"
    15  	"os"
    16  	"regexp"
    17  	"slices"
    18  	"strings"
    19  
    20  	"github.com/aristanetworks/goarista/netns"
    21  	pb "github.com/openconfig/gnmi/proto/gnmi"
    22  	"github.com/openconfig/gnmi/proto/gnmi_ext"
    23  	"google.golang.org/grpc"
    24  	"google.golang.org/grpc/credentials"
    25  	"google.golang.org/grpc/encoding/gzip"
    26  	"google.golang.org/grpc/metadata"
    27  	"google.golang.org/protobuf/proto"
    28  )
    29  
    30  const (
    31  	defaultPort = "6030"
    32  	// HostnameArg is the value to be replaced by the actual hostname
    33  	HostnameArg = "HOSTNAME"
    34  )
    35  
    36  type tlsVersionMap map[string]uint16
    37  
    38  func (m tlsVersionMap) String() string {
    39  	r := make([]string, 0, len(m))
    40  	for k := range m {
    41  		r = append(r, k)
    42  	}
    43  	slices.Sort(r)
    44  	return strings.Join(r, ", ")
    45  }
    46  
// TLSVersions is a map from TLS version strings to the tls version
// constants in the crypto/tls package. It is built once, when the package is
// initialized, by getTLSVersions; the keys are the dotted version numbers that
// getTLSVersions extracts from tls.VersionName (for example "1.3").
var TLSVersions = getTLSVersions()
    50  
// PublishFunc is the method to publish responses. It receives an address
// string and the protobuf message to publish.
// NOTE(review): addr is presumably the address of the server the message was
// received from — confirm against the callers, which are not in this file.
type PublishFunc func(addr string, message proto.Message)
    53  
    54  // ParseHostnames parses a comma-separated list of names and replaces HOSTNAME with the current
    55  // hostname in it
    56  func ParseHostnames(list string) ([]string, error) {
    57  	items := strings.Split(list, ",")
    58  	hostname, err := os.Hostname()
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	names := make([]string, len(items))
    63  	for i, name := range items {
    64  		if name == HostnameArg {
    65  			name = hostname
    66  		}
    67  		names[i] = name
    68  	}
    69  	return names, nil
    70  }
    71  
// Config is the gnmi.Client config. DialContextConn reads the address, TLS,
// compression, window-size, token and dial-option settings; NewContext reads
// Username, Password and GRPCMetadata.
type Config struct {
	// Addr is the target passed to grpc.DialContext. The dialer installed by
	// DialContextConn accepts an optional "network://" prefix (default "tcp")
	// and, for networks not starting with "unix", appends the default port
	// 6030 when the address contains no ':' before handing it to
	// netns.ParseAddress.
	Addr string

	// File path to load data or raw cert data. Alternatively, raw data can be provided below.
	// Each non-empty path is read with os.ReadFile by DialContextConn.
	CAFile   string
	CertFile string
	KeyFile  string

	// Raw certificate data. If respective file is provided above, that is used instead.
	CAData   []byte
	CertData []byte
	KeyData  []byte

	// Password and Username are added by NewContext as the "password" and
	// "username" metadata entries; both are only added when Username is
	// non-empty.
	// TLS requests a TLS connection. TLS is also used whenever CA data,
	// certificate data or a Token is present. Without CA data the TLS config
	// is created with InsecureSkipVerify set.
	// TLSMinVersion and TLSMaxVersion, when non-empty, must be keys of
	// TLSVersions, and the minimum must not exceed the maximum.
	// Compression must be "" or "gzip"; any other value makes DialContextConn
	// return an error.
	// BDP, when false, makes DialContextConn set fixed initial window sizes
	// (see the comment in DialContextConn).
	// DialOptions are copied ahead of the options DialContextConn adds itself.
	// Token, when non-empty, is sent as an "Authorization" per-RPC credential
	// with the value "Bearer <Token>".
	// GRPCMetadata is copied into the outgoing metadata by NewContext; its
	// "username"/"password" entries are overwritten when Username is set.
	Password      string
	Username      string
	TLS           bool
	TLSMinVersion string
	TLSMaxVersion string
	Compression   string
	BDP           bool
	DialOptions   []grpc.DialOption
	Token         string
	GRPCMetadata  map[string]string
}
    97  
// SubscribeOptions is the gNMI subscription request options. The notes below
// describe how NewSubscribeRequest interprets each field.
type SubscribeOptions struct {
	// UpdatesOnly is copied to the SubscriptionList.
	// Prefix is split with SplitPath and parsed with ParseGNMIElements to form
	// the SubscriptionList prefix.
	// Mode is "once", "poll" or "stream"; "" means "stream". Any other value
	// is rejected.
	// StreamMode is "on_change", "sample" or "target_defined"; "" means
	// "target_defined". Any other value is rejected.
	// SampleInterval, SuppressRedundant and HeartbeatInterval are copied
	// unchanged into every Subscription.
	// NOTE(review): the units of the two intervals are not visible here —
	// confirm against the gNMI specification.
	// Paths holds one element list per subscription; each is parsed with
	// ParseGNMIElements.
	// Origin is set on every subscription path.
	// Target, when non-empty, is set on the prefix path.
	// Extensions is attached to the SubscribeRequest as is.
	UpdatesOnly       bool
	Prefix            string
	Mode              string
	StreamMode        string
	SampleInterval    uint64
	SuppressRedundant bool
	HeartbeatInterval uint64
	Paths             [][]string
	Origin            string
	Target            string
	Extensions        []*gnmi_ext.Extension
}
   112  
   113  // ParseFlags reads arguments from stdin and returns a populated Config object and a list of
   114  // paths to subscribe to
   115  func ParseFlags() (*Config, []string) {
   116  	// flags
   117  	var (
   118  		addrsFlag = flag.String("addrs", "localhost:6030",
   119  			"Comma-separated list of addresses of OpenConfig gRPC servers. The address 'HOSTNAME' "+
   120  				"is replaced by the current hostname.")
   121  
   122  		caFileFlag = flag.String("cafile", "",
   123  			"Path to server TLS certificate file")
   124  
   125  		certFileFlag = flag.String("certfile", "",
   126  			"Path to client TLS certificate file")
   127  
   128  		keyFileFlag = flag.String("keyfile", "",
   129  			"Path to client TLS private key file")
   130  
   131  		passwordFlag = flag.String("password", "",
   132  			"Password to authenticate with")
   133  
   134  		usernameFlag = flag.String("username", "",
   135  			"Username to authenticate with")
   136  
   137  		tlsFlag = flag.Bool("tls", false,
   138  			"Enable TLS")
   139  		tlsMinVersion = flag.String("tls-min-version", "",
   140  			fmt.Sprintf("Set minimum TLS version for connection (%s)", TLSVersions))
   141  		tlsMaxVersion = flag.String("tls-max-version", "",
   142  			fmt.Sprintf("Set minimum TLS version for connection (%s)", TLSVersions))
   143  
   144  		compressionFlag = flag.String("compression", "",
   145  			"Type of compression to use")
   146  
   147  		subscribeFlag = flag.String("subscribe", "",
   148  			"Comma-separated list of paths to subscribe to upon connecting to the server")
   149  
   150  		token = flag.String("token", "",
   151  			"Authentication token")
   152  	)
   153  	flag.Parse()
   154  	cfg := &Config{
   155  		Addr:          *addrsFlag,
   156  		CAFile:        *caFileFlag,
   157  		CertFile:      *certFileFlag,
   158  		KeyFile:       *keyFileFlag,
   159  		Password:      *passwordFlag,
   160  		Username:      *usernameFlag,
   161  		TLS:           *tlsFlag,
   162  		TLSMinVersion: *tlsMinVersion,
   163  		TLSMaxVersion: *tlsMaxVersion,
   164  		Compression:   *compressionFlag,
   165  		Token:         *token,
   166  	}
   167  	subscriptions := strings.Split(*subscribeFlag, ",")
   168  	return cfg, subscriptions
   169  
   170  }
   171  
// accessTokenCred implements credentials.PerRPCCredentials, the gRPC
// interface for credentials that need to attach security information
// to every RPC.
type accessTokenCred struct {
	// bearerToken is the complete header value ("Bearer <token>") that
	// GetRequestMetadata returns under the "Authorization" key.
	bearerToken string
}
   178  
   179  // newAccessTokenCredential constructs a new per-RPC credential from a token.
   180  func newAccessTokenCredential(token string) credentials.PerRPCCredentials {
   181  	bearerFmt := "Bearer %s"
   182  	return &accessTokenCred{bearerToken: fmt.Sprintf(bearerFmt, token)}
   183  }
   184  
   185  func (a *accessTokenCred) GetRequestMetadata(ctx context.Context,
   186  	uri ...string) (map[string]string, error) {
   187  	authHeader := "Authorization"
   188  	return map[string]string{
   189  		authHeader: a.bearerToken,
   190  	}, nil
   191  }
   192  
   193  func (a *accessTokenCred) RequireTransportSecurity() bool { return true }
   194  
   195  // DialContextConn connects to a gnmi service and return a client connection
   196  func DialContextConn(ctx context.Context, cfg *Config) (*grpc.ClientConn, error) {
   197  	opts := append([]grpc.DialOption(nil), cfg.DialOptions...)
   198  
   199  	if !cfg.BDP {
   200  		// By default, the client and server will dynamically adjust the connection's
   201  		// window size using the Bandwidth Delay Product (BDP).
   202  		// See: https://grpc.io/blog/grpc-go-perf-improvements/
   203  		// The default values for InitialWindowSize and InitialConnWindowSize are 65535.
   204  		// If values less than 65535 are used, then BDP and dynamic windows are enabled.
   205  		// Here, we disable the BDP and dynamic windows by setting these values >= 65535.
   206  		// We set these values to (1 << 20) * 16 as this is the largest window size that
   207  		// the BDP estimator could ever use.
   208  		// See: https://github.com/grpc/grpc-go/blob/master/internal/transport/bdp_estimator.go
   209  		const maxWindowSize int32 = (1 << 20) * 16
   210  		opts = append(opts,
   211  			grpc.WithInitialWindowSize(maxWindowSize),
   212  			grpc.WithInitialConnWindowSize(maxWindowSize),
   213  		)
   214  	}
   215  
   216  	switch cfg.Compression {
   217  	case "":
   218  	case "gzip":
   219  		opts = append(opts, grpc.WithDefaultCallOptions(grpc.UseCompressor(gzip.Name)))
   220  	default:
   221  		return nil, fmt.Errorf("unsupported compression option: %q", cfg.Compression)
   222  	}
   223  
   224  	var err error
   225  	caData := cfg.CAData
   226  	certData := cfg.CertData
   227  	keyData := cfg.KeyData
   228  	if cfg.CAFile != "" {
   229  		if caData, err = os.ReadFile(cfg.CAFile); err != nil {
   230  			return nil, err
   231  		}
   232  	}
   233  	if cfg.CertFile != "" {
   234  		if certData, err = os.ReadFile(cfg.CertFile); err != nil {
   235  			return nil, err
   236  		}
   237  	}
   238  	if cfg.KeyFile != "" {
   239  		if keyData, err = os.ReadFile(cfg.KeyFile); err != nil {
   240  			return nil, err
   241  		}
   242  	}
   243  
   244  	if cfg.TLS || len(caData) > 0 || len(certData) > 0 || cfg.Token != "" {
   245  		tlsConfig := &tls.Config{}
   246  		if len(caData) > 0 {
   247  			cp := x509.NewCertPool()
   248  			if !cp.AppendCertsFromPEM(caData) {
   249  				return nil, fmt.Errorf("credentials: failed to append certificates")
   250  			}
   251  			tlsConfig.RootCAs = cp
   252  		} else {
   253  			tlsConfig.InsecureSkipVerify = true
   254  		}
   255  		if len(certData) > 0 {
   256  			if len(keyData) == 0 {
   257  				return nil, fmt.Errorf("no key provided for client certificate")
   258  			}
   259  			cert, err := tls.X509KeyPair(certData, keyData)
   260  			if err != nil {
   261  				return nil, err
   262  			}
   263  			tlsConfig.Certificates = []tls.Certificate{cert}
   264  		}
   265  		if cfg.Token != "" {
   266  			opts = append(opts,
   267  				grpc.WithPerRPCCredentials(newAccessTokenCredential(cfg.Token)))
   268  		}
   269  		if cfg.TLSMaxVersion != "" {
   270  			var ok bool
   271  			tlsConfig.MaxVersion, ok = TLSVersions[cfg.TLSMaxVersion]
   272  			if !ok {
   273  				return nil, fmt.Errorf("unrecognised TLS max version."+
   274  					" Supported TLS versions are %s", TLSVersions)
   275  			}
   276  		}
   277  		if cfg.TLSMinVersion != "" {
   278  			var ok bool
   279  			tlsConfig.MinVersion, ok = TLSVersions[cfg.TLSMinVersion]
   280  			if !ok {
   281  				return nil, fmt.Errorf("unrecognised TLS min version."+
   282  					" Supported TLS versions are %s", TLSVersions)
   283  			}
   284  		}
   285  		if cfg.TLSMinVersion != "" && cfg.TLSMaxVersion != "" &&
   286  			tlsConfig.MinVersion > tlsConfig.MaxVersion {
   287  			return nil, fmt.Errorf(
   288  				"TLS min version was greater than TLS max version")
   289  		}
   290  
   291  		opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
   292  	} else {
   293  		opts = append(opts, grpc.WithInsecure())
   294  	}
   295  
   296  	dial := func(ctx context.Context, addrIn string) (conn net.Conn, err error) {
   297  		var network, nsName, addr string
   298  
   299  		split := strings.Split(addrIn, "://")
   300  		if l := len(split); l == 2 {
   301  			network = split[0]
   302  			addr = split[1]
   303  		} else {
   304  			network = "tcp"
   305  			addr = split[0]
   306  		}
   307  
   308  		if !strings.HasPrefix(network, "unix") {
   309  			if !strings.ContainsRune(addr, ':') {
   310  				addr += ":" + defaultPort
   311  			}
   312  
   313  			nsName, addr, err = netns.ParseAddress(addr)
   314  			if err != nil {
   315  				return nil, err
   316  			}
   317  		}
   318  
   319  		err = netns.Do(nsName, func() (err error) {
   320  			conn, err = (&net.Dialer{}).DialContext(ctx, network, addr)
   321  			return
   322  		})
   323  		return
   324  	}
   325  
   326  	opts = append(opts,
   327  		grpc.WithContextDialer(dial),
   328  
   329  		// Allows received protobuf messages to be larger than 4MB
   330  		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
   331  	)
   332  
   333  	return grpc.DialContext(ctx, cfg.Addr, opts...)
   334  }
   335  
   336  // DialContext connects to a gnmi service and returns a client
   337  func DialContext(ctx context.Context, cfg *Config) (pb.GNMIClient, error) {
   338  	grpcconn, err := DialContextConn(ctx, cfg)
   339  	if err != nil {
   340  		return nil, fmt.Errorf("failed to dial: %s", err)
   341  	}
   342  	return pb.NewGNMIClient(grpcconn), nil
   343  }
   344  
   345  // Dial connects to a gnmi service and returns a client
   346  func Dial(cfg *Config) (pb.GNMIClient, error) {
   347  	return DialContext(context.Background(), cfg)
   348  }
   349  
   350  // NewContext returns a new context with username and password
   351  // metadata if they are set in cfg, as well as any other metadata
   352  // provided.
   353  func NewContext(ctx context.Context, cfg *Config) context.Context {
   354  	md := map[string]string{}
   355  	for k, v := range cfg.GRPCMetadata {
   356  		md[k] = v
   357  	}
   358  	if cfg.Username != "" {
   359  		md["username"] = cfg.Username
   360  		md["password"] = cfg.Password
   361  	}
   362  	if len(md) > 0 {
   363  		ctx = metadata.NewOutgoingContext(ctx, metadata.New(md))
   364  	}
   365  	return ctx
   366  }
   367  
   368  // NewGetRequest returns a GetRequest for the given paths
   369  func NewGetRequest(paths [][]string, origin string) (*pb.GetRequest, error) {
   370  	req := &pb.GetRequest{
   371  		Path: make([]*pb.Path, len(paths)),
   372  	}
   373  	for i, p := range paths {
   374  		gnmiPath, err := ParseGNMIElements(p)
   375  		if err != nil {
   376  			return nil, err
   377  		}
   378  		req.Path[i] = gnmiPath
   379  		req.Path[i].Origin = origin
   380  	}
   381  	return req, nil
   382  }
   383  
   384  // NewSubscribeRequest returns a SubscribeRequest for the given paths
   385  func NewSubscribeRequest(subscribeOptions *SubscribeOptions) (*pb.SubscribeRequest, error) {
   386  	var mode pb.SubscriptionList_Mode
   387  	switch subscribeOptions.Mode {
   388  	case "once":
   389  		mode = pb.SubscriptionList_ONCE
   390  	case "poll":
   391  		mode = pb.SubscriptionList_POLL
   392  	case "":
   393  		fallthrough
   394  	case "stream":
   395  		mode = pb.SubscriptionList_STREAM
   396  	default:
   397  		return nil, fmt.Errorf("subscribe mode (%s) invalid", subscribeOptions.Mode)
   398  	}
   399  
   400  	var streamMode pb.SubscriptionMode
   401  	switch subscribeOptions.StreamMode {
   402  	case "on_change":
   403  		streamMode = pb.SubscriptionMode_ON_CHANGE
   404  	case "sample":
   405  		streamMode = pb.SubscriptionMode_SAMPLE
   406  	case "":
   407  		fallthrough
   408  	case "target_defined":
   409  		streamMode = pb.SubscriptionMode_TARGET_DEFINED
   410  	default:
   411  		return nil, fmt.Errorf("subscribe stream mode (%s) invalid", subscribeOptions.StreamMode)
   412  	}
   413  
   414  	prefixPath, err := ParseGNMIElements(SplitPath(subscribeOptions.Prefix))
   415  	if err != nil {
   416  		return nil, err
   417  	}
   418  	subList := &pb.SubscriptionList{
   419  		Subscription: make([]*pb.Subscription, len(subscribeOptions.Paths)),
   420  		Mode:         mode,
   421  		UpdatesOnly:  subscribeOptions.UpdatesOnly,
   422  		Prefix:       prefixPath,
   423  	}
   424  	if subscribeOptions.Target != "" {
   425  		if subList.Prefix == nil {
   426  			subList.Prefix = &pb.Path{}
   427  		}
   428  		subList.Prefix.Target = subscribeOptions.Target
   429  	}
   430  	for i, p := range subscribeOptions.Paths {
   431  		gnmiPath, err := ParseGNMIElements(p)
   432  		if err != nil {
   433  			return nil, err
   434  		}
   435  		gnmiPath.Origin = subscribeOptions.Origin
   436  		subList.Subscription[i] = &pb.Subscription{
   437  			Path:              gnmiPath,
   438  			Mode:              streamMode,
   439  			SampleInterval:    subscribeOptions.SampleInterval,
   440  			SuppressRedundant: subscribeOptions.SuppressRedundant,
   441  			HeartbeatInterval: subscribeOptions.HeartbeatInterval,
   442  		}
   443  	}
   444  	return &pb.SubscribeRequest{
   445  		Extension: subscribeOptions.Extensions,
   446  		Request: &pb.SubscribeRequest_Subscribe{
   447  			Subscribe: subList,
   448  		},
   449  	}, nil
   450  }
   451  
   452  // HistorySnapshotExtension returns an Extension_History for the given
   453  // time.
   454  func HistorySnapshotExtension(t int64) *gnmi_ext.Extension_History {
   455  	return &gnmi_ext.Extension_History{
   456  		History: &gnmi_ext.History{
   457  			Request: &gnmi_ext.History_SnapshotTime{
   458  				SnapshotTime: t,
   459  			},
   460  		},
   461  	}
   462  }
   463  
   464  // HistoryRangeExtension returns an Extension_History for the the
   465  // specified start and end times.
   466  func HistoryRangeExtension(s, e int64) *gnmi_ext.Extension_History {
   467  	return &gnmi_ext.Extension_History{
   468  		History: &gnmi_ext.History{
   469  			Request: &gnmi_ext.History_Range{
   470  				Range: &gnmi_ext.TimeRange{
   471  					Start: s,
   472  					End:   e,
   473  				},
   474  			},
   475  		},
   476  	}
   477  }
   478  
   479  // getTLSVersions generates a map of TLS version name to tls version, based on the versions
   480  // available in the crypto/tls package
   481  func getTLSVersions(testHook ...func(uint16, *regexp.Regexp)) tlsVersionMap {
   482  	cipherSuites := tls.CipherSuites()
   483  	allSupportedVersions := make(map[uint16]struct{})
   484  
   485  	for _, cipherSuite := range cipherSuites {
   486  		for _, version := range cipherSuite.SupportedVersions {
   487  			allSupportedVersions[version] = struct{}{}
   488  		}
   489  	}
   490  
   491  	// match TLS versions in dot format like X.Y or X.Y.Z etc (right now everything is X.Y)
   492  	re := regexp.MustCompile(`[\d.]+`)
   493  
   494  	nameToVersion := make(map[string]uint16, len(allSupportedVersions))
   495  	for version := range allSupportedVersions {
   496  		// tls.VersionName(version) will be something like "TLS 1.3"
   497  		name := re.FindString(tls.VersionName(version))
   498  		// check if the regex either failed to match, or if it is not specific enough
   499  		// (matching something which was already found)
   500  		if _, ok := nameToVersion[name]; ok || name == "" {
   501  			// if we ever fail to match a regex we shouldn't do anything in production
   502  			// but let's make a test fail so we can investigate and update the regex
   503  			for _, f := range testHook {
   504  				f(version, re)
   505  			}
   506  			continue
   507  		}
   508  
   509  		nameToVersion[name] = version
   510  
   511  	}
   512  	return nameToVersion
   513  }