github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/tlstest/acl/acl_over_tls_test.go (about)

     1  package acl
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"io/ioutil"
     8  
     9  	"github.com/dgraph-io/dgo"
    10  	"github.com/dgraph-io/dgo/protos/api"
    11  	"github.com/dgraph-io/dgraph/testutil"
    12  	"github.com/golang/glog"
    13  	"github.com/pkg/errors"
    14  	"github.com/spf13/viper"
    15  	"google.golang.org/grpc"
    16  	"google.golang.org/grpc/credentials"
    17  )
    18  
    19  func generateCertPool(certPath string, useSystemCA bool) (*x509.CertPool, error) {
    20  	var pool *x509.CertPool
    21  	if useSystemCA {
    22  		var err error
    23  		if pool, err = x509.SystemCertPool(); err != nil {
    24  			return nil, err
    25  		}
    26  	} else {
    27  		pool = x509.NewCertPool()
    28  	}
    29  
    30  	if len(certPath) > 0 {
    31  		caFile, err := ioutil.ReadFile(certPath)
    32  		if err != nil {
    33  			return nil, err
    34  		}
    35  		if !pool.AppendCertsFromPEM(caFile) {
    36  			return nil, errors.Errorf("error reading CA file %q", certPath)
    37  		}
    38  	}
    39  
    40  	return pool, nil
    41  }
    42  
    43  func loadClientTLSConfig(v *viper.Viper) (*tls.Config, error) {
    44  	// When the --tls_cacert option is pecified, the connection will be set up using TLS instead of
    45  	// plaintext. However the client cert files are optional, depending on whether the server is
    46  	// requiring a client certificate.
    47  	caCert := v.GetString("tls_cacert")
    48  	if caCert != "" {
    49  		tlsCfg := tls.Config{}
    50  
    51  		// 1. set up the root CA
    52  		pool, err := generateCertPool(caCert, v.GetBool("tls_use_system_ca"))
    53  		if err != nil {
    54  			return nil, err
    55  		}
    56  		tlsCfg.RootCAs = pool
    57  
    58  		// 2. set up the server name for verification
    59  		tlsCfg.ServerName = v.GetString("tls_server_name")
    60  
    61  		// 3. optionally load the client cert files
    62  		certFile := v.GetString("tls_cert")
    63  		keyFile := v.GetString("tls_key")
    64  		if certFile != "" && keyFile != "" {
    65  			cert, err := tls.LoadX509KeyPair(certFile, keyFile)
    66  			if err != nil {
    67  				return nil, err
    68  			}
    69  			tlsCfg.Certificates = []tls.Certificate{cert}
    70  		}
    71  
    72  		return &tlsCfg, nil
    73  	}
    74  	return nil, nil
    75  }
    76  
    77  func dgraphClientWithCerts(serviceAddr string, conf *viper.Viper) (*dgo.Dgraph, error) {
    78  	tlsCfg, err := loadClientTLSConfig(conf)
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	dialOpts := []grpc.DialOption{}
    84  	if tlsCfg != nil {
    85  		dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)))
    86  	} else {
    87  		dialOpts = append(dialOpts, grpc.WithInsecure())
    88  	}
    89  	conn, err := grpc.Dial(serviceAddr, dialOpts...)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  	dg := dgo.NewDgraphClient(api.NewDgraphClient(conn))
    94  	return dg, nil
    95  }
    96  
    97  func ExampleLoginOverTLS() {
    98  	conf := viper.New()
    99  	conf.Set("tls_cacert", "../tls/ca.crt")
   100  	conf.Set("tls_server_name", "node")
   101  
   102  	dg, err := dgraphClientWithCerts(testutil.SockAddr, conf)
   103  	if err != nil {
   104  		glog.Fatalf("Unable to get dgraph client: %v", err)
   105  	}
   106  	if err := dg.Login(context.Background(), "groot", "password"); err != nil {
   107  		glog.Fatalf("Unable to login using the groot account: %v", err)
   108  	}
   109  
   110  	// Output:
   111  }