github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/x/tls_helper.go (about)

     1  /*
     2   * Copyright 2017-2018 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package x
    18  
    19  import (
    20  	"crypto/tls"
    21  	"crypto/x509"
    22  	"io/ioutil"
    23  	"path"
    24  	"strings"
    25  
    26  	"github.com/pkg/errors"
    27  	"github.com/spf13/pflag"
    28  	"github.com/spf13/viper"
    29  )
    30  
    31  const (
    32  	tlsRootCert = "ca.crt"
    33  )
    34  
    35  // TLSHelperConfig define params used to create a tls.Config
    36  type TLSHelperConfig struct {
    37  	CertDir          string
    38  	CertRequired     bool
    39  	Cert             string
    40  	Key              string
    41  	ServerName       string
    42  	RootCACert       string
    43  	ClientAuth       string
    44  	UseSystemCACerts bool
    45  }
    46  
    47  // RegisterClientTLSFlags registers the required flags to set up a TLS client.
    48  func RegisterClientTLSFlags(flag *pflag.FlagSet) {
    49  	flag.String("tls_cacert", "",
    50  		"The CA Cert file used to verify server certificates. Required for enabling TLS.")
    51  	flag.Bool("tls_use_system_ca", true, "Include System CA into CA Certs.")
    52  	flag.String("tls_server_name", "", "Used to verify the server hostname.")
    53  	flag.String("tls_cert", "", "(optional) The Cert file provided by the client to the server.")
    54  	flag.String("tls_key", "", "(optional) The private key file "+
    55  		"provided by the client to the server.")
    56  }
    57  
    58  // LoadServerTLSConfig loads the TLS config into the server with the given parameters.
    59  func LoadServerTLSConfig(v *viper.Viper, tlsCertFile string, tlsKeyFile string) (*tls.Config,
    60  	error) {
    61  	conf := TLSHelperConfig{}
    62  	conf.CertDir = v.GetString("tls_dir")
    63  	if conf.CertDir != "" {
    64  		conf.CertRequired = true
    65  		conf.RootCACert = path.Join(conf.CertDir, tlsRootCert)
    66  		conf.Cert = path.Join(conf.CertDir, tlsCertFile)
    67  		conf.Key = path.Join(conf.CertDir, tlsKeyFile)
    68  		conf.ClientAuth = v.GetString("tls_client_auth")
    69  	}
    70  	conf.UseSystemCACerts = v.GetBool("tls_use_system_ca")
    71  
    72  	return GenerateServerTLSConfig(&conf)
    73  }
    74  
    75  // LoadClientTLSConfig loads the TLS config into the client with the given parameters.
    76  func LoadClientTLSConfig(v *viper.Viper) (*tls.Config, error) {
    77  	// When the --tls_cacert option is pecified, the connection will be set up using TLS instead of
    78  	// plaintext. However the client cert files are optional, depending on whether the server
    79  	// requires a client certificate.
    80  	caCert := v.GetString("tls_cacert")
    81  	if caCert != "" {
    82  		tlsCfg := tls.Config{}
    83  
    84  		// 1. set up the root CA
    85  		pool, err := generateCertPool(caCert, v.GetBool("tls_use_system_ca"))
    86  		if err != nil {
    87  			return nil, err
    88  		}
    89  		tlsCfg.RootCAs = pool
    90  
    91  		// 2. set up the server name for verification
    92  		tlsCfg.ServerName = v.GetString("tls_server_name")
    93  
    94  		// 3. optionally load the client cert files
    95  		certFile := v.GetString("tls_cert")
    96  		keyFile := v.GetString("tls_key")
    97  		if certFile != "" && keyFile != "" {
    98  			cert, err := tls.LoadX509KeyPair(certFile, keyFile)
    99  			if err != nil {
   100  				return nil, err
   101  			}
   102  			tlsCfg.Certificates = []tls.Certificate{cert}
   103  		}
   104  
   105  		return &tlsCfg, nil
   106  	} else
   107  	// Attempt to determine if user specified *any* TLS option. Unfortunately and contrary to
   108  	// Viper's own documentation, there's no way to tell whether an option value came from a
   109  	// command-line option or a built-it default.
   110  	if v.GetString("tls_server_name") != "" ||
   111  		v.GetString("tls_cert") != "" ||
   112  		v.GetString("tls_key") != "" {
   113  		return nil, errors.Errorf("--tls_cacert is required for enabling TLS")
   114  	}
   115  	return nil, nil
   116  }
   117  
   118  func generateCertPool(certPath string, useSystemCA bool) (*x509.CertPool, error) {
   119  	var pool *x509.CertPool
   120  	if useSystemCA {
   121  		var err error
   122  		if pool, err = x509.SystemCertPool(); err != nil {
   123  			return nil, err
   124  		}
   125  	} else {
   126  		pool = x509.NewCertPool()
   127  	}
   128  
   129  	if len(certPath) > 0 {
   130  		caFile, err := ioutil.ReadFile(certPath)
   131  		if err != nil {
   132  			return nil, err
   133  		}
   134  		if !pool.AppendCertsFromPEM(caFile) {
   135  			return nil, errors.Errorf("error reading CA file %q", certPath)
   136  		}
   137  	}
   138  
   139  	return pool, nil
   140  }
   141  
   142  func setupClientAuth(authType string) (tls.ClientAuthType, error) {
   143  	auth := map[string]tls.ClientAuthType{
   144  		"REQUEST":          tls.RequestClientCert,
   145  		"REQUIREANY":       tls.RequireAnyClientCert,
   146  		"VERIFYIFGIVEN":    tls.VerifyClientCertIfGiven,
   147  		"REQUIREANDVERIFY": tls.RequireAndVerifyClientCert,
   148  	}
   149  
   150  	if len(authType) > 0 {
   151  		if v, has := auth[strings.ToUpper(authType)]; has {
   152  			return v, nil
   153  		}
   154  		return tls.NoClientCert, errors.Errorf("Invalid client auth. Valid values " +
   155  			"[REQUEST, REQUIREANY, VERIFYIFGIVEN, REQUIREANDVERIFY]")
   156  	}
   157  
   158  	return tls.NoClientCert, nil
   159  }
   160  
   161  // GenerateServerTLSConfig creates and returns a new *tls.Config with the
   162  // configuration provided.
   163  func GenerateServerTLSConfig(config *TLSHelperConfig) (tlsCfg *tls.Config, err error) {
   164  	if config.CertRequired {
   165  		tlsCfg = new(tls.Config)
   166  		cert, err := tls.LoadX509KeyPair(config.Cert, config.Key)
   167  		if err != nil {
   168  			return nil, err
   169  		}
   170  		tlsCfg.Certificates = []tls.Certificate{cert}
   171  
   172  		pool, err := generateCertPool(config.RootCACert, config.UseSystemCACerts)
   173  		if err != nil {
   174  			return nil, err
   175  		}
   176  		tlsCfg.ClientCAs = pool
   177  
   178  		auth, err := setupClientAuth(config.ClientAuth)
   179  		if err != nil {
   180  			return nil, err
   181  		}
   182  		tlsCfg.ClientAuth = auth
   183  
   184  		tlsCfg.MinVersion = tls.VersionTLS11
   185  		tlsCfg.MaxVersion = tls.VersionTLS12
   186  
   187  		return tlsCfg, nil
   188  	}
   189  	return nil, nil
   190  }
   191  
   192  // GenerateClientTLSConfig creates and returns a new client side *tls.Config with the
   193  // configuration provided.
   194  func GenerateClientTLSConfig(config *TLSHelperConfig) (tlsCfg *tls.Config, err error) {
   195  	pool, err := generateCertPool(config.RootCACert, config.UseSystemCACerts)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  
   200  	return &tls.Config{RootCAs: pool, ServerName: config.ServerName}, nil
   201  }