github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/tlstest/acl/acl_over_tls_test.go (about) 1 package acl 2 3 import ( 4 "context" 5 "crypto/tls" 6 "crypto/x509" 7 "io/ioutil" 8 9 "github.com/dgraph-io/dgo" 10 "github.com/dgraph-io/dgo/protos/api" 11 "github.com/dgraph-io/dgraph/testutil" 12 "github.com/golang/glog" 13 "github.com/pkg/errors" 14 "github.com/spf13/viper" 15 "google.golang.org/grpc" 16 "google.golang.org/grpc/credentials" 17 ) 18 19 func generateCertPool(certPath string, useSystemCA bool) (*x509.CertPool, error) { 20 var pool *x509.CertPool 21 if useSystemCA { 22 var err error 23 if pool, err = x509.SystemCertPool(); err != nil { 24 return nil, err 25 } 26 } else { 27 pool = x509.NewCertPool() 28 } 29 30 if len(certPath) > 0 { 31 caFile, err := ioutil.ReadFile(certPath) 32 if err != nil { 33 return nil, err 34 } 35 if !pool.AppendCertsFromPEM(caFile) { 36 return nil, errors.Errorf("error reading CA file %q", certPath) 37 } 38 } 39 40 return pool, nil 41 } 42 43 func loadClientTLSConfig(v *viper.Viper) (*tls.Config, error) { 44 // When the --tls_cacert option is pecified, the connection will be set up using TLS instead of 45 // plaintext. However the client cert files are optional, depending on whether the server is 46 // requiring a client certificate. 47 caCert := v.GetString("tls_cacert") 48 if caCert != "" { 49 tlsCfg := tls.Config{} 50 51 // 1. set up the root CA 52 pool, err := generateCertPool(caCert, v.GetBool("tls_use_system_ca")) 53 if err != nil { 54 return nil, err 55 } 56 tlsCfg.RootCAs = pool 57 58 // 2. set up the server name for verification 59 tlsCfg.ServerName = v.GetString("tls_server_name") 60 61 // 3. optionally load the client cert files 62 certFile := v.GetString("tls_cert") 63 keyFile := v.GetString("tls_key") 64 if certFile != "" && keyFile != "" { 65 cert, err := tls.LoadX509KeyPair(certFile, keyFile) 66 if err != nil { 67 return nil, err 68 } 69 tlsCfg.Certificates = []tls.Certificate{cert} 70 } 71 72 return &tlsCfg, nil 73 } 74 return nil, nil 75 } 76 77 func dgraphClientWithCerts(serviceAddr string, conf *viper.Viper) (*dgo.Dgraph, error) { 78 tlsCfg, err := loadClientTLSConfig(conf) 79 if err != nil { 80 return nil, err 81 } 82 83 dialOpts := []grpc.DialOption{} 84 if tlsCfg != nil { 85 dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg))) 86 } else { 87 dialOpts = append(dialOpts, grpc.WithInsecure()) 88 } 89 conn, err := grpc.Dial(serviceAddr, dialOpts...) 90 if err != nil { 91 return nil, err 92 } 93 dg := dgo.NewDgraphClient(api.NewDgraphClient(conn)) 94 return dg, nil 95 } 96 97 func ExampleLoginOverTLS() { 98 conf := viper.New() 99 conf.Set("tls_cacert", "../tls/ca.crt") 100 conf.Set("tls_server_name", "node") 101 102 dg, err := dgraphClientWithCerts(testutil.SockAddr, conf) 103 if err != nil { 104 glog.Fatalf("Unable to get dgraph client: %v", err) 105 } 106 if err := dg.Login(context.Background(), "groot", "password"); err != nil { 107 glog.Fatalf("Unable to login using the groot account: %v", err) 108 } 109 110 // Output: 111 }