github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/env/grpc_dial_provider.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package env
    16  
    17  import (
    18  	"crypto/tls"
    19  	"errors"
    20  	"net"
    21  	"net/http"
    22  	"os"
    23  	"runtime"
    24  	"strings"
    25  	"unicode"
    26  
    27  	"google.golang.org/grpc"
    28  	"google.golang.org/grpc/credentials"
    29  
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/creds"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/dconfig"
    33  	"github.com/dolthub/dolt/go/libraries/doltcore/grpcendpoint"
    34  )
    35  
    36  // GRPCDialProvider implements dbfactory.GRPCDialProvider. By default, it is not able to use custom user credentials, but
    37  // if it is initialized with a DoltEnv, it will load custom user credentials from it.
    38  type GRPCDialProvider struct {
    39  	dEnv *DoltEnv
    40  }
    41  
    42  var _ dbfactory.GRPCDialProvider = GRPCDialProvider{}
    43  
    44  // NewGRPCDialProvider returns a new GRPCDialProvider, with no DoltEnv configured and without supporting
    45  // custom user credentials.
    46  func NewGRPCDialProvider() *GRPCDialProvider {
    47  	return &GRPCDialProvider{}
    48  }
    49  
    50  // NewGRPCDialProviderFromDoltEnv returns a new GRPCDialProvider, configured with the specified DoltEnv
    51  // and uses that DoltEnv to load custom user credentials.
    52  func NewGRPCDialProviderFromDoltEnv(dEnv *DoltEnv) *GRPCDialProvider {
    53  	return &GRPCDialProvider{
    54  		dEnv: dEnv,
    55  	}
    56  }
    57  
    58  // GetGRPCDialParams implements dbfactory.GRPCDialProvider
    59  func (p GRPCDialProvider) GetGRPCDialParams(config grpcendpoint.Config) (dbfactory.GRPCRemoteConfig, error) {
    60  	endpoint := config.Endpoint
    61  	if strings.IndexRune(endpoint, ':') == -1 {
    62  		if config.Insecure {
    63  			endpoint += ":80"
    64  		} else {
    65  			endpoint += ":443"
    66  		}
    67  	}
    68  
    69  	var httpfetcher grpcendpoint.HTTPFetcher = http.DefaultClient
    70  
    71  	var opts []grpc.DialOption
    72  	if config.TLSConfig != nil {
    73  		tc := credentials.NewTLS(config.TLSConfig)
    74  		opts = append(opts, grpc.WithTransportCredentials(tc))
    75  
    76  		httpfetcher = &http.Client{
    77  			Transport: &http.Transport{
    78  				TLSClientConfig:   config.TLSConfig,
    79  				ForceAttemptHTTP2: true,
    80  			},
    81  		}
    82  	} else if config.Insecure {
    83  		opts = append(opts, grpc.WithInsecure())
    84  	} else {
    85  		tc := credentials.NewTLS(&tls.Config{})
    86  		opts = append(opts, grpc.WithTransportCredentials(tc))
    87  	}
    88  
    89  	opts = append(opts, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(128*1024*1024)))
    90  	opts = append(opts, grpc.WithUserAgent(p.getUserAgentString()))
    91  
    92  	if config.Creds != nil {
    93  		opts = append(opts, grpc.WithPerRPCCredentials(config.Creds))
    94  	} else if config.WithEnvCreds {
    95  		var rpcCreds credentials.PerRPCCredentials
    96  		var err error
    97  		if config.UserIdForOsEnvAuth != "" {
    98  			rpcCreds, err = p.getRPCCredsFromOSEnv(config.UserIdForOsEnvAuth)
    99  			if err != nil {
   100  				return dbfactory.GRPCRemoteConfig{}, err
   101  			}
   102  		} else {
   103  			rpcCreds, err = p.getRPCCreds(endpoint)
   104  			if err != nil {
   105  				return dbfactory.GRPCRemoteConfig{}, err
   106  			}
   107  		}
   108  		if rpcCreds != nil {
   109  			opts = append(opts, grpc.WithPerRPCCredentials(rpcCreds))
   110  		}
   111  	}
   112  	return dbfactory.GRPCRemoteConfig{
   113  		Endpoint:    endpoint,
   114  		DialOptions: opts,
   115  		HTTPFetcher: httpfetcher,
   116  	}, nil
   117  }
   118  
   119  // getRPCCredsFromOSEnv returns RPC Credentials for the specified username, using the DOLT_REMOTE_PASSWORD
   120  func (p GRPCDialProvider) getRPCCredsFromOSEnv(username string) (credentials.PerRPCCredentials, error) {
   121  	if username == "" {
   122  		return nil, errors.New("Runtime error: username must be provided to getRPCCredsFromOSEnv")
   123  	}
   124  
   125  	pass, found := os.LookupEnv(dconfig.EnvDoltRemotePassword)
   126  	if !found {
   127  		return nil, errors.New("error: must set DOLT_REMOTE_PASSWORD environment variable to use --user param")
   128  	}
   129  	c := creds.DoltCredsForPass{
   130  		Username: username,
   131  		Password: pass,
   132  	}
   133  
   134  	return c.RPCCreds(), nil
   135  }
   136  
   137  // getRPCCreds returns any RPC credentials available to this dial provider. If a DoltEnv has been configured
   138  // in this dial provider, it will be used to load custom user credentials, otherwise nil will be returned.
   139  func (p GRPCDialProvider) getRPCCreds(endpoint string) (credentials.PerRPCCredentials, error) {
   140  	if p.dEnv == nil {
   141  		return nil, nil
   142  	}
   143  
   144  	if p.dEnv.UserPassConfig != nil {
   145  		return p.dEnv.UserPassConfig.RPCCreds(), nil
   146  	}
   147  
   148  	dCreds, valid, err := p.dEnv.UserDoltCreds()
   149  	if err != nil {
   150  		return nil, ErrInvalidCredsFile
   151  	}
   152  	if !valid {
   153  		return nil, nil
   154  	}
   155  
   156  	return dCreds.RPCCreds(getHostFromEndpoint(endpoint)), nil
   157  }
   158  
   159  func getHostFromEndpoint(endpoint string) string {
   160  	host, _, err := net.SplitHostPort(endpoint)
   161  	if err != nil {
   162  		return DefaultRemotesApiHost
   163  	}
   164  	return host
   165  }
   166  
   167  // getUserAgentString returns a user agent string to use in GRPC requests.
   168  func (p GRPCDialProvider) getUserAgentString() string {
   169  	version := ""
   170  	if p.dEnv != nil {
   171  		version = p.dEnv.Version
   172  	}
   173  
   174  	tokens := []string{
   175  		"dolt_cli",
   176  		version,
   177  		runtime.GOOS,
   178  		runtime.GOARCH,
   179  	}
   180  
   181  	for i, t := range tokens {
   182  		tokens[i] = strings.Map(func(r rune) rune {
   183  			if unicode.IsSpace(r) {
   184  				return '_'
   185  			}
   186  
   187  			return r
   188  		}, strings.TrimSpace(t))
   189  	}
   190  
   191  	return strings.Join(tokens, " ")
   192  }