github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/x/x.go (about)

     1  /*
     2   * Copyright 2015-2018 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package x
    18  
    19  import (
    20  	"bufio"
    21  	"bytes"
    22  	"context"
    23  	"crypto/tls"
    24  	"encoding/json"
    25  	"fmt"
    26  	"math"
    27  	"math/rand"
    28  	"net"
    29  	"net/http"
    30  	"os"
    31  	"regexp"
    32  	"sort"
    33  	"strconv"
    34  	"strings"
    35  	"syscall"
    36  	"time"
    37  
    38  	"golang.org/x/crypto/ssh/terminal"
    39  
    40  	"github.com/dgraph-io/dgo"
    41  	"github.com/dgraph-io/dgo/protos/api"
    42  	"github.com/golang/glog"
    43  	"github.com/pkg/errors"
    44  	"github.com/spf13/viper"
    45  	"go.opencensus.io/trace"
    46  	"google.golang.org/grpc"
    47  	"google.golang.org/grpc/credentials"
    48  	"google.golang.org/grpc/encoding/gzip"
    49  )
    50  
// Error values returned by this package.
var (
	// ErrNotSupported is returned when an enterprise feature is requested in the open source version.
	ErrNotSupported = errors.Errorf("Feature available only in Dgraph Enterprise Edition")
)

// Status codes, default ports, limits and reserved names shared across Dgraph.
const (
	// Success is equivalent to the HTTP 200 status code.
	Success = "Success"
	// ErrorUnauthorized is equivalent to the HTTP 401 status code.
	ErrorUnauthorized = "ErrorUnauthorized"
	// ErrorInvalidMethod is equivalent to the HTTP 405 status code.
	ErrorInvalidMethod = "ErrorInvalidMethod"
	// ErrorInvalidRequest is equivalent to the HTTP 400 status code.
	ErrorInvalidRequest = "ErrorInvalidRequest"
	// Error is a general error code.
	Error = "Error"
	// ErrorNoData is an error returned when the requested data cannot be returned.
	ErrorNoData = "ErrorNoData"
	// ValidHostnameRegex matches a hostname made of one or more dot-separated
	// labels, where each label consists of ASCII letters and digits and may
	// contain (but not start or end with) hyphens.
	ValidHostnameRegex = "^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\\-]*[a-zA-Z0-9])\\.)*([A-Za-z0-9]" +
		"|[A-Za-z0-9][A-Za-z0-9\\-]*[A-Za-z0-9])$"
	// Star is equivalent to using * in a mutation.
	// When changing this value also remember to change it in client/client.go:DeleteEdges.
	Star = "_STAR_ALL"

	// GrpcMaxSize is the maximum possible size for a gRPC message.
	// Dgraph uses the maximum size for the most flexibility (4GB - equal
	// to the max grpc frame size). Users will still need to set the max
	// message sizes allowable on the client side when dialing.
	GrpcMaxSize = 4 << 30

	// PortZeroGrpc is the default gRPC port for zero.
	PortZeroGrpc = 5080
	// PortZeroHTTP is the default HTTP port for zero.
	PortZeroHTTP = 6080
	// PortInternal is the default port for internal use.
	PortInternal = 7080
	// PortHTTP is the default HTTP port for alpha.
	PortHTTP = 8080
	// PortGrpc is the default gRPC port for alpha.
	PortGrpc = 9080
	// ForceAbortDifference is the maximum allowed difference between
	// AppliedUntil - TxnMarks.DoneUntil() before old transactions start getting aborted.
	ForceAbortDifference = 5000

	// FacetDelimeter is the symbol used to distinguish predicate names from facets.
	// (The misspelling of "Delimiter" is kept because the name is exported.)
	FacetDelimeter = "|"

	// GrootId is the ID of the admin user for ACLs.
	GrootId = "groot"
	// AclPredicates is the JSON representation of the predicates reserved for use
	// by the ACL system.
	AclPredicates = `
{"predicate":"dgraph.xid","type":"string", "index": true, "tokenizer":["exact"], "upsert": true},
{"predicate":"dgraph.password","type":"password"},
{"predicate":"dgraph.user.group","list":true, "reverse": true, "type": "uid"},
{"predicate":"dgraph.group.acl","type":"string"}
`
)
   111  
var (
	// regExpHostName is the compiled form of ValidHostnameRegex; ValidateAddress
	// uses it to check the host part of an address that is not an IP literal.
	regExpHostName = regexp.MustCompile(ValidHostnameRegex)
	// Nilbyte is a nil (zero-value) byte slice.
	// NOTE(review): the original comment was truncated ("Used ..."); callers
	// presumably use it where an explicit nil []byte is needed — confirm.
	Nilbyte []byte
)
   118  
   119  // ShouldCrash returns true if the error should cause the process to crash.
   120  func ShouldCrash(err error) bool {
   121  	if err == nil {
   122  		return false
   123  	}
   124  	errStr := grpc.ErrorDesc(err)
   125  	return strings.Contains(errStr, "REUSE_RAFTID") ||
   126  		strings.Contains(errStr, "REUSE_ADDR") ||
   127  		strings.Contains(errStr, "NO_ADDR") ||
   128  		strings.Contains(errStr, "ENTERPRISE_LIMIT_REACHED")
   129  }
   130  
   131  // WhiteSpace Replacer removes spaces and tabs from a string.
   132  var WhiteSpace = strings.NewReplacer(" ", "", "\t", "")
   133  
   134  // GqlError is a GraphQL spec compliant error structure.  See GraphQL spec on
   135  // errors here: https://graphql.github.io/graphql-spec/June2018/#sec-Errors
   136  //
   137  // Note: "Every error must contain an entry with the key message with a string
   138  // description of the error intended for the developer as a guide to understand
   139  // and correct the error."
   140  //
   141  // "If an error can be associated to a particular point in the request [the error]
   142  // should contain an entry with the key locations with a list of locations"
   143  //
   144  // Path is about GraphQL results and Errors for GraphQL layer.
   145  //
   146  // Extensions is for everything else.
   147  type GqlError struct {
   148  	Message    string                 `json:"message"`
   149  	Locations  []Location             `json:"locations,omitempty"`
   150  	Path       []interface{}          `json:"path,omitempty"`
   151  	Extensions map[string]interface{} `json:"extensions,omitempty"`
   152  }
   153  
   154  // A Location is the Line+Column index of an error in a request.
   155  type Location struct {
   156  	Line   int `json:"line,omitempty"`
   157  	Column int `json:"column,omitempty"`
   158  }
   159  
   160  type queryRes struct {
   161  	Errors []GqlError `json:"errors"`
   162  }
   163  
   164  // SetStatus sets the error code, message and the newly assigned uids
   165  // in the http response.
   166  func SetStatus(w http.ResponseWriter, code, msg string) {
   167  	var qr queryRes
   168  	ext := make(map[string]interface{})
   169  	ext["code"] = code
   170  	qr.Errors = append(qr.Errors, GqlError{Message: msg, Extensions: ext})
   171  	if js, err := json.Marshal(qr); err == nil {
   172  		if _, err := w.Write(js); err != nil {
   173  			glog.Errorf("Error while writing: %+v", err)
   174  		}
   175  	} else {
   176  		panic(fmt.Sprintf("Unable to marshal: %+v", qr))
   177  	}
   178  }
   179  
   180  // SetHttpStatus is similar to SetStatus but sets a proper HTTP status code
   181  // in the response instead of always returning HTTP 200 (OK).
   182  func SetHttpStatus(w http.ResponseWriter, code int, msg string) {
   183  	w.WriteHeader(code)
   184  	SetStatus(w, "error", msg)
   185  }
   186  
   187  // AddCorsHeaders adds the CORS headers to an HTTP response.
   188  func AddCorsHeaders(w http.ResponseWriter) {
   189  	w.Header().Set("Access-Control-Allow-Origin", "*")
   190  	w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
   191  	w.Header().Set("Access-Control-Allow-Headers", "X-Dgraph-AccessToken, "+
   192  		"Content-Type, Content-Length, Accept-Encoding, Cache-Control, "+
   193  		"X-CSRF-Token, X-Auth-Token, X-Requested-With")
   194  	w.Header().Set("Access-Control-Allow-Credentials", "true")
   195  	w.Header().Set("Connection", "close")
   196  }
   197  
   198  // QueryResWithData represents a response that holds errors as well as data.
   199  type QueryResWithData struct {
   200  	Errors []GqlError `json:"errors"`
   201  	Data   *string    `json:"data"`
   202  }
   203  
   204  // SetStatusWithData sets the errors in the response and ensures that the data key
   205  // in the data is present with value nil.
   206  // In case an error was encountered after the query execution started, we have to return data
   207  // key with null value according to GraphQL spec.
   208  func SetStatusWithData(w http.ResponseWriter, code, msg string) {
   209  	var qr QueryResWithData
   210  	ext := make(map[string]interface{})
   211  	ext["code"] = code
   212  	qr.Errors = append(qr.Errors, GqlError{Message: msg, Extensions: ext})
   213  	// This would ensure that data key is present with value null.
   214  	if js, err := json.Marshal(qr); err == nil {
   215  		if _, err := w.Write(js); err != nil {
   216  			glog.Errorf("Error while writing: %+v", err)
   217  		}
   218  	} else {
   219  		panic(fmt.Sprintf("Unable to marshal: %+v", qr))
   220  	}
   221  }
   222  
   223  // Reply sets the body of an HTTP response to the JSON representation of the given reply.
   224  func Reply(w http.ResponseWriter, rep interface{}) {
   225  	if js, err := json.Marshal(rep); err == nil {
   226  		w.Header().Set("Content-Type", "application/json")
   227  		fmt.Fprint(w, string(js))
   228  	} else {
   229  		SetStatus(w, Error, "Internal server error")
   230  	}
   231  }
   232  
   233  // ParseRequest parses the body of the given request.
   234  func ParseRequest(w http.ResponseWriter, r *http.Request, data interface{}) bool {
   235  	defer r.Body.Close()
   236  	decoder := json.NewDecoder(r.Body)
   237  	if err := decoder.Decode(&data); err != nil {
   238  		SetStatus(w, Error, fmt.Sprintf("While parsing request: %v", err))
   239  		return false
   240  	}
   241  	return true
   242  }
   243  
   244  // Min returns the minimum of the two given numbers.
   245  func Min(a, b uint64) uint64 {
   246  	if a < b {
   247  		return a
   248  	}
   249  	return b
   250  }
   251  
   252  // Max returns the maximum of the two given numbers.
   253  func Max(a, b uint64) uint64 {
   254  	if a > b {
   255  		return a
   256  	}
   257  	return b
   258  }
   259  
   260  // RetryUntilSuccess runs the given function until it succeeds or can no longer be retried.
   261  func RetryUntilSuccess(maxRetries int, waitAfterFailure time.Duration,
   262  	f func() error) error {
   263  	var err error
   264  	for retry := maxRetries; retry != 0; retry-- {
   265  		if err = f(); err == nil {
   266  			return nil
   267  		}
   268  		if waitAfterFailure > 0 {
   269  			time.Sleep(waitAfterFailure)
   270  		}
   271  	}
   272  	return err
   273  }
   274  
   275  // HasString returns whether the slice contains the given string.
   276  func HasString(a []string, b string) bool {
   277  	for _, k := range a {
   278  		if k == b {
   279  			return true
   280  		}
   281  	}
   282  	return false
   283  }
   284  
   285  // ReadLine reads a single line from a buffered reader. The line is read into the
   286  // passed in buffer to minimize allocations. This is the preferred
   287  // method for loading long lines which could be longer than the buffer
   288  // size of bufio.Scanner.
   289  func ReadLine(r *bufio.Reader, buf *bytes.Buffer) error {
   290  	isPrefix := true
   291  	var err error
   292  	buf.Reset()
   293  	for isPrefix && err == nil {
   294  		var line []byte
   295  		// The returned line is an pb.buffer in bufio and is only
   296  		// valid until the next call to ReadLine. It needs to be copied
   297  		// over to our own buffer.
   298  		line, isPrefix, err = r.ReadLine()
   299  		if err == nil {
   300  			buf.Write(line)
   301  		}
   302  	}
   303  	return err
   304  }
   305  
   306  // FixedDuration returns the given duration as a string of fixed length.
   307  func FixedDuration(d time.Duration) string {
   308  	str := fmt.Sprintf("%02ds", int(d.Seconds())%60)
   309  	if d >= time.Minute {
   310  		str = fmt.Sprintf("%02dm", int(d.Minutes())%60) + str
   311  	}
   312  	if d >= time.Hour {
   313  		str = fmt.Sprintf("%02dh", int(d.Hours())) + str
   314  	}
   315  	return str
   316  }
   317  
   318  // PageRange returns start and end indices given pagination params. Note that n
   319  // is the size of the input list.
   320  func PageRange(count, offset, n int) (int, int) {
   321  	if n == 0 {
   322  		return 0, 0
   323  	}
   324  	if count == 0 && offset == 0 {
   325  		return 0, n
   326  	}
   327  	if count < 0 {
   328  		// Items from the back of the array, like Python arrays. Do a positive mod n.
   329  		if count*-1 > n {
   330  			count = -n
   331  		}
   332  		return (((n + count) % n) + n) % n, n
   333  	}
   334  	start := offset
   335  	if start < 0 {
   336  		start = 0
   337  	}
   338  	if start > n {
   339  		return n, n
   340  	}
   341  	if count == 0 { // No count specified. Just take the offset parameter.
   342  		return start, n
   343  	}
   344  	end := start + count
   345  	if end > n {
   346  		end = n
   347  	}
   348  	return start, end
   349  }
   350  
   351  // ValidateAddress checks whether given address can be used with grpc dial function
   352  func ValidateAddress(addr string) bool {
   353  	host, port, err := net.SplitHostPort(addr)
   354  	if err != nil {
   355  		return false
   356  	}
   357  	if p, err := strconv.Atoi(port); err != nil || p <= 0 || p >= 65536 {
   358  		return false
   359  	}
   360  	if ip := net.ParseIP(host); ip != nil {
   361  		return true
   362  	}
   363  	// try to parse as hostname as per hostname RFC
   364  	if len(strings.Replace(host, ".", "", -1)) > 255 {
   365  		return false
   366  	}
   367  	return regExpHostName.MatchString(host)
   368  }
   369  
   370  // RemoveDuplicates sorts the slice of strings and removes duplicates. changes the input slice.
   371  // This function should be called like: someSlice = RemoveDuplicates(someSlice)
   372  func RemoveDuplicates(s []string) (out []string) {
   373  	sort.Strings(s)
   374  	out = s[:0]
   375  	for i := range s {
   376  		if i > 0 && s[i] == s[i-1] {
   377  			continue
   378  		}
   379  		out = append(out, s[i])
   380  	}
   381  	return
   382  }
   383  
   384  // BytesBuffer provides a buffer backed by byte slices.
   385  type BytesBuffer struct {
   386  	data [][]byte
   387  	off  int
   388  	sz   int
   389  }
   390  
   391  func (b *BytesBuffer) grow(n int) {
   392  	if n < 128 {
   393  		n = 128
   394  	}
   395  	if len(b.data) == 0 {
   396  		b.data = append(b.data, make([]byte, n))
   397  	}
   398  
   399  	last := len(b.data) - 1
   400  	// Return if we have sufficient space
   401  	if len(b.data[last])-b.off >= n {
   402  		return
   403  	}
   404  	sz := len(b.data[last]) * 2
   405  	if sz > 512<<10 {
   406  		sz = 512 << 10 // 512 KB
   407  	}
   408  	if sz < n {
   409  		sz = n
   410  	}
   411  	b.data[last] = b.data[last][:b.off]
   412  	b.sz += len(b.data[last])
   413  	b.data = append(b.data, make([]byte, sz))
   414  	b.off = 0
   415  }
   416  
   417  // Slice returns a slice of length n to be used for writing.
   418  func (b *BytesBuffer) Slice(n int) []byte {
   419  	b.grow(n)
   420  	last := len(b.data) - 1
   421  	b.off += n
   422  	b.sz += n
   423  	return b.data[last][b.off-n : b.off]
   424  }
   425  
   426  // Length returns the size of the buffer.
   427  func (b *BytesBuffer) Length() int {
   428  	return b.sz
   429  }
   430  
   431  // CopyTo copies the contents of the buffer to the given byte slice.
   432  // Caller should ensure that o is of appropriate length.
   433  func (b *BytesBuffer) CopyTo(o []byte) int {
   434  	offset := 0
   435  	for i, d := range b.data {
   436  		if i == len(b.data)-1 {
   437  			copy(o[offset:], d[:b.off])
   438  			offset += b.off
   439  		} else {
   440  			copy(o[offset:], d)
   441  			offset += len(d)
   442  		}
   443  	}
   444  	return offset
   445  }
   446  
   447  // TruncateBy reduces the size of the bugger by the given amount.
   448  // Always give back <= touched bytes.
   449  func (b *BytesBuffer) TruncateBy(n int) {
   450  	b.off -= n
   451  	b.sz -= n
   452  	AssertTrue(b.off >= 0 && b.sz >= 0)
   453  }
   454  
   455  type record struct {
   456  	Name string
   457  	Dur  time.Duration
   458  }
   459  
   460  // Timer implements a timer that supports recording the duration of events.
   461  type Timer struct {
   462  	start   time.Time
   463  	last    time.Time
   464  	records []record
   465  }
   466  
   467  // Start starts the timer and clears the list of records.
   468  func (t *Timer) Start() {
   469  	t.start = time.Now()
   470  	t.last = t.start
   471  	t.records = t.records[:0]
   472  }
   473  
   474  // Record records an event and assigns it the given name.
   475  func (t *Timer) Record(name string) {
   476  	now := time.Now()
   477  	t.records = append(t.records, record{
   478  		Name: name,
   479  		Dur:  now.Sub(t.last).Round(time.Millisecond),
   480  	})
   481  	t.last = now
   482  }
   483  
   484  // Total returns the duration since the timer was started.
   485  func (t *Timer) Total() time.Duration {
   486  	return time.Since(t.start).Round(time.Millisecond)
   487  }
   488  
   489  func (t *Timer) String() string {
   490  	sort.Slice(t.records, func(i, j int) bool {
   491  		return t.records[i].Dur > t.records[j].Dur
   492  	})
   493  	return fmt.Sprintf("Timer Total: %s. Breakdown: %v", t.Total(), t.records)
   494  }
   495  
   496  // PredicateLang extracts the language from a predicate (or facet) name.
   497  // Returns the predicate and the language tag, if any.
   498  func PredicateLang(s string) (string, string) {
   499  	i := strings.LastIndex(s, "@")
   500  	if i <= 0 {
   501  		return s, ""
   502  	}
   503  	return s[0:i], s[i+1:]
   504  }
   505  
   506  // DivideAndRule is used to divide a number of tasks among multiple go routines.
   507  func DivideAndRule(num int) (numGo, width int) {
   508  	numGo, width = 64, 0
   509  	for ; numGo >= 1; numGo /= 2 {
   510  		widthF := math.Ceil(float64(num) / float64(numGo))
   511  		if numGo == 1 || widthF >= 256.0 {
   512  			width = int(widthF)
   513  			return
   514  		}
   515  	}
   516  	return
   517  }
   518  
   519  // SetupConnection starts a secure gRPC connection to the given host.
   520  func SetupConnection(host string, tlsCfg *tls.Config, useGz bool) (*grpc.ClientConn, error) {
   521  	callOpts := append([]grpc.CallOption{},
   522  		grpc.MaxCallRecvMsgSize(GrpcMaxSize),
   523  		grpc.MaxCallSendMsgSize(GrpcMaxSize))
   524  
   525  	if useGz {
   526  		fmt.Fprintf(os.Stderr, "Using compression with %s\n", host)
   527  		callOpts = append(callOpts, grpc.UseCompressor(gzip.Name))
   528  	}
   529  
   530  	dialOpts := append([]grpc.DialOption{},
   531  		grpc.WithDefaultCallOptions(callOpts...),
   532  		grpc.WithBlock())
   533  
   534  	if tlsCfg != nil {
   535  		dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)))
   536  	} else {
   537  		dialOpts = append(dialOpts, grpc.WithInsecure())
   538  	}
   539  
   540  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   541  	defer cancel()
   542  
   543  	return grpc.DialContext(ctx, host, dialOpts...)
   544  }
   545  
   546  // Diff computes the difference between the keys of the two given maps.
   547  func Diff(dst map[string]struct{}, src map[string]struct{}) ([]string, []string) {
   548  	var add []string
   549  	var del []string
   550  
   551  	for g := range dst {
   552  		if _, ok := src[g]; !ok {
   553  			add = append(add, g)
   554  		}
   555  	}
   556  	for g := range src {
   557  		if _, ok := dst[g]; !ok {
   558  			del = append(del, g)
   559  		}
   560  	}
   561  
   562  	return add, del
   563  }
   564  
   565  // SpanTimer returns a function used to record the duration of the given span.
   566  func SpanTimer(span *trace.Span, name string) func() {
   567  	if span == nil {
   568  		return func() {}
   569  	}
   570  	uniq := int64(rand.Int31())
   571  	attrs := []trace.Attribute{
   572  		trace.Int64Attribute("funcId", uniq),
   573  		trace.StringAttribute("funcName", name),
   574  	}
   575  	span.Annotate(attrs, "Start.")
   576  	start := time.Now()
   577  
   578  	return func() {
   579  		span.Annotatef(attrs, "End. Took %s", time.Since(start))
   580  		// TODO: We can look into doing a latency record here.
   581  	}
   582  }
   583  
   584  // CloseFunc needs to be called to close all the client connections.
   585  type CloseFunc func()
   586  
   587  // CredOpt stores the options for logging in, including the password and user.
   588  type CredOpt struct {
   589  	Conf        *viper.Viper
   590  	UserID      string
   591  	PasswordOpt string
   592  }
   593  
   594  // GetDgraphClient creates a Dgraph client based on the following options in the configuration:
   595  // --alpha specifies a comma separated list of endpoints to connect to
   596  // --tls_cacert, --tls_cert, --tls_key etc specify the TLS configuration of the connection
   597  // --retries specifies how many times we should retry the connection to each endpoint upon failures
   598  // --user and --password specify the credentials we should use to login with the server
   599  func GetDgraphClient(conf *viper.Viper, login bool) (*dgo.Dgraph, CloseFunc) {
   600  	alphas := conf.GetString("alpha")
   601  	if len(alphas) == 0 {
   602  		glog.Fatalf("The --alpha option must be set in order to connect to Dgraph")
   603  	}
   604  
   605  	fmt.Printf("\nRunning transaction with dgraph endpoint: %v\n", alphas)
   606  	tlsCfg, err := LoadClientTLSConfig(conf)
   607  	Checkf(err, "While loading TLS configuration")
   608  
   609  	ds := strings.Split(alphas, ",")
   610  	var conns []*grpc.ClientConn
   611  	var clients []api.DgraphClient
   612  
   613  	retries := 1
   614  	if conf.IsSet("retries") {
   615  		retries = conf.GetInt("retries")
   616  		if retries < 1 {
   617  			retries = 1
   618  		}
   619  	}
   620  
   621  	for _, d := range ds {
   622  		var conn *grpc.ClientConn
   623  		for i := 0; i < retries; retries++ {
   624  			conn, err = SetupConnection(d, tlsCfg, false)
   625  			if err == nil {
   626  				break
   627  			}
   628  			fmt.Printf("While trying to setup connection: %v. Retrying...\n", err)
   629  			time.Sleep(time.Second)
   630  		}
   631  		if conn == nil {
   632  			Fatalf("Could not setup connection after %d retries", retries)
   633  		}
   634  
   635  		conns = append(conns, conn)
   636  		dc := api.NewDgraphClient(conn)
   637  		clients = append(clients, dc)
   638  	}
   639  
   640  	dg := dgo.NewDgraphClient(clients...)
   641  	user := conf.GetString("user")
   642  	if login && len(user) > 0 {
   643  		err = GetPassAndLogin(dg, &CredOpt{
   644  			Conf:        conf,
   645  			UserID:      user,
   646  			PasswordOpt: "password",
   647  		})
   648  		Checkf(err, "While retrieving password and logging in")
   649  	}
   650  
   651  	closeFunc := func() {
   652  		for _, c := range conns {
   653  			c.Close()
   654  		}
   655  	}
   656  	return dg, closeFunc
   657  }
   658  
   659  // AskUserPassword prompts the user to enter the password for the given user ID.
   660  func AskUserPassword(userid string, pwdType string, times int) (string, error) {
   661  	AssertTrue(times == 1 || times == 2)
   662  	AssertTrue(pwdType == "Current" || pwdType == "New")
   663  	// ask for the user's password
   664  	fmt.Printf("%s password for %v:", pwdType, userid)
   665  	pd, err := terminal.ReadPassword(int(syscall.Stdin))
   666  	if err != nil {
   667  		return "", errors.Wrapf(err, "while reading password")
   668  	}
   669  	fmt.Println()
   670  	password := string(pd)
   671  
   672  	if times == 2 {
   673  		fmt.Printf("Retype %s password for %v:", strings.ToLower(pwdType), userid)
   674  		pd2, err := terminal.ReadPassword(int(syscall.Stdin))
   675  		if err != nil {
   676  			return "", errors.Wrapf(err, "while reading password")
   677  		}
   678  		fmt.Println()
   679  
   680  		password2 := string(pd2)
   681  		if password2 != password {
   682  			return "", errors.Errorf("the two typed passwords do not match")
   683  		}
   684  	}
   685  	return password, nil
   686  }
   687  
   688  // GetPassAndLogin uses the given credentials and client to perform the login operation.
   689  func GetPassAndLogin(dg *dgo.Dgraph, opt *CredOpt) error {
   690  	password := opt.Conf.GetString(opt.PasswordOpt)
   691  	if len(password) == 0 {
   692  		var err error
   693  		password, err = AskUserPassword(opt.UserID, "Current", 1)
   694  		if err != nil {
   695  			return err
   696  		}
   697  	}
   698  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   699  	defer cancel()
   700  	if err := dg.Login(ctx, opt.UserID, password); err != nil {
   701  		return errors.Wrapf(err, "unable to login to the %v account", opt.UserID)
   702  	}
   703  	fmt.Println("Login successful.")
   704  	// update the context so that it has the admin jwt token
   705  	return nil
   706  }