github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/x/x.go (about) 1 /* 2 * Copyright 2015-2018 Dgraph Labs, Inc. and Contributors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package x 18 19 import ( 20 "bufio" 21 "bytes" 22 "context" 23 "crypto/tls" 24 "encoding/json" 25 "fmt" 26 "math" 27 "math/rand" 28 "net" 29 "net/http" 30 "os" 31 "regexp" 32 "sort" 33 "strconv" 34 "strings" 35 "syscall" 36 "time" 37 38 "golang.org/x/crypto/ssh/terminal" 39 40 "github.com/dgraph-io/dgo" 41 "github.com/dgraph-io/dgo/protos/api" 42 "github.com/golang/glog" 43 "github.com/pkg/errors" 44 "github.com/spf13/viper" 45 "go.opencensus.io/trace" 46 "google.golang.org/grpc" 47 "google.golang.org/grpc/credentials" 48 "google.golang.org/grpc/encoding/gzip" 49 ) 50 51 // Error constants representing different types of errors. 52 var ( 53 // ErrNotSupported is thrown when an enterprise feature is requested in the open source version. 54 ErrNotSupported = errors.Errorf("Feature available only in Dgraph Enterprise Edition") 55 ) 56 57 const ( 58 // Success is equivalent to the HTTP 200 error code. 59 Success = "Success" 60 // ErrorUnauthorized is equivalent to the HTTP 401 error code. 61 ErrorUnauthorized = "ErrorUnauthorized" 62 // ErrorInvalidMethod is equivalent to the HTTP 405 error code. 63 ErrorInvalidMethod = "ErrorInvalidMethod" 64 // ErrorInvalidRequest is equivalent to the HTTP 400 error code. 
65 ErrorInvalidRequest = "ErrorInvalidRequest" 66 // Error is a general error code. 67 Error = "Error" 68 // ErrorNoData is an error returned when the requested data cannot be returned. 69 ErrorNoData = "ErrorNoData" 70 // ValidHostnameRegex is a regex that accepts our expected hostname format. 71 ValidHostnameRegex = "^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\\-]*[a-zA-Z0-9])\\.)*([A-Za-z0-9]" + 72 "|[A-Za-z0-9][A-Za-z0-9\\-]*[A-Za-z0-9])$" 73 // Star is equivalent to using * in a mutation. 74 // When changing this value also remember to change in in client/client.go:DeleteEdges. 75 Star = "_STAR_ALL" 76 77 // GrpcMaxSize is the maximum possible size for a gRPC message. 78 // Dgraph uses the maximum size for the most flexibility (4GB - equal 79 // to the max grpc frame size). Users will still need to set the max 80 // message sizes allowable on the client size when dialing. 81 GrpcMaxSize = 4 << 30 82 83 // PortZeroGrpc is the default gRPC port for zero. 84 PortZeroGrpc = 5080 85 // PortZeroHTTP is the default HTTP port for zero. 86 PortZeroHTTP = 6080 87 // PortInternal is the default port for internal use. 88 PortInternal = 7080 89 // PortHTTP is the default HTTP port for alpha. 90 PortHTTP = 8080 91 // PortGrpc is the default gRPC port for alpha. 92 PortGrpc = 9080 93 // ForceAbortDifference is the maximum allowed difference between 94 // AppliedUntil - TxnMarks.DoneUntil() before old transactions start getting aborted. 95 ForceAbortDifference = 5000 96 97 // FacetDelimeter is the symbol used to distinguish predicate names from facets. 98 FacetDelimeter = "|" 99 100 // GrootId is the ID of the admin user for ACLs. 101 GrootId = "groot" 102 // AclPredicates is the JSON representation of the predicates reserved for use 103 // by the ACL system. 
104 AclPredicates = ` 105 {"predicate":"dgraph.xid","type":"string", "index": true, "tokenizer":["exact"], "upsert": true}, 106 {"predicate":"dgraph.password","type":"password"}, 107 {"predicate":"dgraph.user.group","list":true, "reverse": true, "type": "uid"}, 108 {"predicate":"dgraph.group.acl","type":"string"} 109 ` 110 ) 111 112 var ( 113 // Useful for running multiple servers on the same machine. 114 regExpHostName = regexp.MustCompile(ValidHostnameRegex) 115 // Nilbyte is a nil byte slice. Used 116 Nilbyte []byte 117 ) 118 119 // ShouldCrash returns true if the error should cause the process to crash. 120 func ShouldCrash(err error) bool { 121 if err == nil { 122 return false 123 } 124 errStr := grpc.ErrorDesc(err) 125 return strings.Contains(errStr, "REUSE_RAFTID") || 126 strings.Contains(errStr, "REUSE_ADDR") || 127 strings.Contains(errStr, "NO_ADDR") || 128 strings.Contains(errStr, "ENTERPRISE_LIMIT_REACHED") 129 } 130 131 // WhiteSpace Replacer removes spaces and tabs from a string. 132 var WhiteSpace = strings.NewReplacer(" ", "", "\t", "") 133 134 // GqlError is a GraphQL spec compliant error structure. See GraphQL spec on 135 // errors here: https://graphql.github.io/graphql-spec/June2018/#sec-Errors 136 // 137 // Note: "Every error must contain an entry with the key message with a string 138 // description of the error intended for the developer as a guide to understand 139 // and correct the error." 140 // 141 // "If an error can be associated to a particular point in the request [the error] 142 // should contain an entry with the key locations with a list of locations" 143 // 144 // Path is about GraphQL results and Errors for GraphQL layer. 145 // 146 // Extensions is for everything else. 
147 type GqlError struct { 148 Message string `json:"message"` 149 Locations []Location `json:"locations,omitempty"` 150 Path []interface{} `json:"path,omitempty"` 151 Extensions map[string]interface{} `json:"extensions,omitempty"` 152 } 153 154 // A Location is the Line+Column index of an error in a request. 155 type Location struct { 156 Line int `json:"line,omitempty"` 157 Column int `json:"column,omitempty"` 158 } 159 160 type queryRes struct { 161 Errors []GqlError `json:"errors"` 162 } 163 164 // SetStatus sets the error code, message and the newly assigned uids 165 // in the http response. 166 func SetStatus(w http.ResponseWriter, code, msg string) { 167 var qr queryRes 168 ext := make(map[string]interface{}) 169 ext["code"] = code 170 qr.Errors = append(qr.Errors, GqlError{Message: msg, Extensions: ext}) 171 if js, err := json.Marshal(qr); err == nil { 172 if _, err := w.Write(js); err != nil { 173 glog.Errorf("Error while writing: %+v", err) 174 } 175 } else { 176 panic(fmt.Sprintf("Unable to marshal: %+v", qr)) 177 } 178 } 179 180 // SetHttpStatus is similar to SetStatus but sets a proper HTTP status code 181 // in the response instead of always returning HTTP 200 (OK). 182 func SetHttpStatus(w http.ResponseWriter, code int, msg string) { 183 w.WriteHeader(code) 184 SetStatus(w, "error", msg) 185 } 186 187 // AddCorsHeaders adds the CORS headers to an HTTP response. 188 func AddCorsHeaders(w http.ResponseWriter) { 189 w.Header().Set("Access-Control-Allow-Origin", "*") 190 w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") 191 w.Header().Set("Access-Control-Allow-Headers", "X-Dgraph-AccessToken, "+ 192 "Content-Type, Content-Length, Accept-Encoding, Cache-Control, "+ 193 "X-CSRF-Token, X-Auth-Token, X-Requested-With") 194 w.Header().Set("Access-Control-Allow-Credentials", "true") 195 w.Header().Set("Connection", "close") 196 } 197 198 // QueryResWithData represents a response that holds errors as well as data. 
199 type QueryResWithData struct { 200 Errors []GqlError `json:"errors"` 201 Data *string `json:"data"` 202 } 203 204 // SetStatusWithData sets the errors in the response and ensures that the data key 205 // in the data is present with value nil. 206 // In case an error was encountered after the query execution started, we have to return data 207 // key with null value according to GraphQL spec. 208 func SetStatusWithData(w http.ResponseWriter, code, msg string) { 209 var qr QueryResWithData 210 ext := make(map[string]interface{}) 211 ext["code"] = code 212 qr.Errors = append(qr.Errors, GqlError{Message: msg, Extensions: ext}) 213 // This would ensure that data key is present with value null. 214 if js, err := json.Marshal(qr); err == nil { 215 if _, err := w.Write(js); err != nil { 216 glog.Errorf("Error while writing: %+v", err) 217 } 218 } else { 219 panic(fmt.Sprintf("Unable to marshal: %+v", qr)) 220 } 221 } 222 223 // Reply sets the body of an HTTP response to the JSON representation of the given reply. 224 func Reply(w http.ResponseWriter, rep interface{}) { 225 if js, err := json.Marshal(rep); err == nil { 226 w.Header().Set("Content-Type", "application/json") 227 fmt.Fprint(w, string(js)) 228 } else { 229 SetStatus(w, Error, "Internal server error") 230 } 231 } 232 233 // ParseRequest parses the body of the given request. 234 func ParseRequest(w http.ResponseWriter, r *http.Request, data interface{}) bool { 235 defer r.Body.Close() 236 decoder := json.NewDecoder(r.Body) 237 if err := decoder.Decode(&data); err != nil { 238 SetStatus(w, Error, fmt.Sprintf("While parsing request: %v", err)) 239 return false 240 } 241 return true 242 } 243 244 // Min returns the minimum of the two given numbers. 245 func Min(a, b uint64) uint64 { 246 if a < b { 247 return a 248 } 249 return b 250 } 251 252 // Max returns the maximum of the two given numbers. 
func Max(a, b uint64) uint64 {
	if a > b {
		return a
	}
	return b
}

// RetryUntilSuccess runs the given function until it succeeds or can no longer be retried.
// f is called at most maxRetries times, waiting waitAfterFailure between a failed
// attempt and the next one. It returns nil on the first success, otherwise the
// error from the last attempt. A negative maxRetries keeps retrying until f
// succeeds; a maxRetries of zero returns nil without calling f.
func RetryUntilSuccess(maxRetries int, waitAfterFailure time.Duration,
	f func() error) error {
	var err error
	for retry := maxRetries; retry != 0; retry-- {
		if err = f(); err == nil {
			return nil
		}
		// Only wait when another attempt follows (retry == 1 is the final one).
		// Sleeping after the last failure would just delay returning the error.
		if waitAfterFailure > 0 && retry != 1 {
			time.Sleep(waitAfterFailure)
		}
	}
	return err
}

// HasString returns whether the slice contains the given string.
func HasString(a []string, b string) bool {
	for _, k := range a {
		if k == b {
			return true
		}
	}
	return false
}

// ReadLine reads a single line from a buffered reader. The line is read into the
// passed in buffer to minimize allocations. This is the preferred
// method for loading long lines which could be longer than the buffer
// size of bufio.Scanner. buf is reset first and, on success, holds the line
// without its line terminator. The reader's error (e.g. io.EOF) is returned as-is.
func ReadLine(r *bufio.Reader, buf *bytes.Buffer) error {
	isPrefix := true
	var err error
	buf.Reset()
	for isPrefix && err == nil {
		var line []byte
		// The returned line is an internal buffer in bufio and is only
		// valid until the next call to ReadLine. It needs to be copied
		// over to our own buffer.
		line, isPrefix, err = r.ReadLine()
		if err == nil {
			buf.Write(line)
		}
	}
	return err
}

// FixedDuration returns the given duration as a string of fixed length,
// e.g. "05s", "01m30s" or "01h01m01s". Hours are not wrapped, so 25h is
// rendered as "25h00m00s".
func FixedDuration(d time.Duration) string {
	str := fmt.Sprintf("%02ds", int(d.Seconds())%60)
	if d >= time.Minute {
		str = fmt.Sprintf("%02dm", int(d.Minutes())%60) + str
	}
	if d >= time.Hour {
		str = fmt.Sprintf("%02dh", int(d.Hours())) + str
	}
	return str
}

// PageRange returns start and end indices given pagination params. Note that n
// is the size of the input list.
320 func PageRange(count, offset, n int) (int, int) { 321 if n == 0 { 322 return 0, 0 323 } 324 if count == 0 && offset == 0 { 325 return 0, n 326 } 327 if count < 0 { 328 // Items from the back of the array, like Python arrays. Do a positive mod n. 329 if count*-1 > n { 330 count = -n 331 } 332 return (((n + count) % n) + n) % n, n 333 } 334 start := offset 335 if start < 0 { 336 start = 0 337 } 338 if start > n { 339 return n, n 340 } 341 if count == 0 { // No count specified. Just take the offset parameter. 342 return start, n 343 } 344 end := start + count 345 if end > n { 346 end = n 347 } 348 return start, end 349 } 350 351 // ValidateAddress checks whether given address can be used with grpc dial function 352 func ValidateAddress(addr string) bool { 353 host, port, err := net.SplitHostPort(addr) 354 if err != nil { 355 return false 356 } 357 if p, err := strconv.Atoi(port); err != nil || p <= 0 || p >= 65536 { 358 return false 359 } 360 if ip := net.ParseIP(host); ip != nil { 361 return true 362 } 363 // try to parse as hostname as per hostname RFC 364 if len(strings.Replace(host, ".", "", -1)) > 255 { 365 return false 366 } 367 return regExpHostName.MatchString(host) 368 } 369 370 // RemoveDuplicates sorts the slice of strings and removes duplicates. changes the input slice. 371 // This function should be called like: someSlice = RemoveDuplicates(someSlice) 372 func RemoveDuplicates(s []string) (out []string) { 373 sort.Strings(s) 374 out = s[:0] 375 for i := range s { 376 if i > 0 && s[i] == s[i-1] { 377 continue 378 } 379 out = append(out, s[i]) 380 } 381 return 382 } 383 384 // BytesBuffer provides a buffer backed by byte slices. 
385 type BytesBuffer struct { 386 data [][]byte 387 off int 388 sz int 389 } 390 391 func (b *BytesBuffer) grow(n int) { 392 if n < 128 { 393 n = 128 394 } 395 if len(b.data) == 0 { 396 b.data = append(b.data, make([]byte, n)) 397 } 398 399 last := len(b.data) - 1 400 // Return if we have sufficient space 401 if len(b.data[last])-b.off >= n { 402 return 403 } 404 sz := len(b.data[last]) * 2 405 if sz > 512<<10 { 406 sz = 512 << 10 // 512 KB 407 } 408 if sz < n { 409 sz = n 410 } 411 b.data[last] = b.data[last][:b.off] 412 b.sz += len(b.data[last]) 413 b.data = append(b.data, make([]byte, sz)) 414 b.off = 0 415 } 416 417 // Slice returns a slice of length n to be used for writing. 418 func (b *BytesBuffer) Slice(n int) []byte { 419 b.grow(n) 420 last := len(b.data) - 1 421 b.off += n 422 b.sz += n 423 return b.data[last][b.off-n : b.off] 424 } 425 426 // Length returns the size of the buffer. 427 func (b *BytesBuffer) Length() int { 428 return b.sz 429 } 430 431 // CopyTo copies the contents of the buffer to the given byte slice. 432 // Caller should ensure that o is of appropriate length. 433 func (b *BytesBuffer) CopyTo(o []byte) int { 434 offset := 0 435 for i, d := range b.data { 436 if i == len(b.data)-1 { 437 copy(o[offset:], d[:b.off]) 438 offset += b.off 439 } else { 440 copy(o[offset:], d) 441 offset += len(d) 442 } 443 } 444 return offset 445 } 446 447 // TruncateBy reduces the size of the bugger by the given amount. 448 // Always give back <= touched bytes. 449 func (b *BytesBuffer) TruncateBy(n int) { 450 b.off -= n 451 b.sz -= n 452 AssertTrue(b.off >= 0 && b.sz >= 0) 453 } 454 455 type record struct { 456 Name string 457 Dur time.Duration 458 } 459 460 // Timer implements a timer that supports recording the duration of events. 461 type Timer struct { 462 start time.Time 463 last time.Time 464 records []record 465 } 466 467 // Start starts the timer and clears the list of records. 
468 func (t *Timer) Start() { 469 t.start = time.Now() 470 t.last = t.start 471 t.records = t.records[:0] 472 } 473 474 // Record records an event and assigns it the given name. 475 func (t *Timer) Record(name string) { 476 now := time.Now() 477 t.records = append(t.records, record{ 478 Name: name, 479 Dur: now.Sub(t.last).Round(time.Millisecond), 480 }) 481 t.last = now 482 } 483 484 // Total returns the duration since the timer was started. 485 func (t *Timer) Total() time.Duration { 486 return time.Since(t.start).Round(time.Millisecond) 487 } 488 489 func (t *Timer) String() string { 490 sort.Slice(t.records, func(i, j int) bool { 491 return t.records[i].Dur > t.records[j].Dur 492 }) 493 return fmt.Sprintf("Timer Total: %s. Breakdown: %v", t.Total(), t.records) 494 } 495 496 // PredicateLang extracts the language from a predicate (or facet) name. 497 // Returns the predicate and the language tag, if any. 498 func PredicateLang(s string) (string, string) { 499 i := strings.LastIndex(s, "@") 500 if i <= 0 { 501 return s, "" 502 } 503 return s[0:i], s[i+1:] 504 } 505 506 // DivideAndRule is used to divide a number of tasks among multiple go routines. 507 func DivideAndRule(num int) (numGo, width int) { 508 numGo, width = 64, 0 509 for ; numGo >= 1; numGo /= 2 { 510 widthF := math.Ceil(float64(num) / float64(numGo)) 511 if numGo == 1 || widthF >= 256.0 { 512 width = int(widthF) 513 return 514 } 515 } 516 return 517 } 518 519 // SetupConnection starts a secure gRPC connection to the given host. 
520 func SetupConnection(host string, tlsCfg *tls.Config, useGz bool) (*grpc.ClientConn, error) { 521 callOpts := append([]grpc.CallOption{}, 522 grpc.MaxCallRecvMsgSize(GrpcMaxSize), 523 grpc.MaxCallSendMsgSize(GrpcMaxSize)) 524 525 if useGz { 526 fmt.Fprintf(os.Stderr, "Using compression with %s\n", host) 527 callOpts = append(callOpts, grpc.UseCompressor(gzip.Name)) 528 } 529 530 dialOpts := append([]grpc.DialOption{}, 531 grpc.WithDefaultCallOptions(callOpts...), 532 grpc.WithBlock()) 533 534 if tlsCfg != nil { 535 dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg))) 536 } else { 537 dialOpts = append(dialOpts, grpc.WithInsecure()) 538 } 539 540 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 541 defer cancel() 542 543 return grpc.DialContext(ctx, host, dialOpts...) 544 } 545 546 // Diff computes the difference between the keys of the two given maps. 547 func Diff(dst map[string]struct{}, src map[string]struct{}) ([]string, []string) { 548 var add []string 549 var del []string 550 551 for g := range dst { 552 if _, ok := src[g]; !ok { 553 add = append(add, g) 554 } 555 } 556 for g := range src { 557 if _, ok := dst[g]; !ok { 558 del = append(del, g) 559 } 560 } 561 562 return add, del 563 } 564 565 // SpanTimer returns a function used to record the duration of the given span. 566 func SpanTimer(span *trace.Span, name string) func() { 567 if span == nil { 568 return func() {} 569 } 570 uniq := int64(rand.Int31()) 571 attrs := []trace.Attribute{ 572 trace.Int64Attribute("funcId", uniq), 573 trace.StringAttribute("funcName", name), 574 } 575 span.Annotate(attrs, "Start.") 576 start := time.Now() 577 578 return func() { 579 span.Annotatef(attrs, "End. Took %s", time.Since(start)) 580 // TODO: We can look into doing a latency record here. 581 } 582 } 583 584 // CloseFunc needs to be called to close all the client connections. 
585 type CloseFunc func() 586 587 // CredOpt stores the options for logging in, including the password and user. 588 type CredOpt struct { 589 Conf *viper.Viper 590 UserID string 591 PasswordOpt string 592 } 593 594 // GetDgraphClient creates a Dgraph client based on the following options in the configuration: 595 // --alpha specifies a comma separated list of endpoints to connect to 596 // --tls_cacert, --tls_cert, --tls_key etc specify the TLS configuration of the connection 597 // --retries specifies how many times we should retry the connection to each endpoint upon failures 598 // --user and --password specify the credentials we should use to login with the server 599 func GetDgraphClient(conf *viper.Viper, login bool) (*dgo.Dgraph, CloseFunc) { 600 alphas := conf.GetString("alpha") 601 if len(alphas) == 0 { 602 glog.Fatalf("The --alpha option must be set in order to connect to Dgraph") 603 } 604 605 fmt.Printf("\nRunning transaction with dgraph endpoint: %v\n", alphas) 606 tlsCfg, err := LoadClientTLSConfig(conf) 607 Checkf(err, "While loading TLS configuration") 608 609 ds := strings.Split(alphas, ",") 610 var conns []*grpc.ClientConn 611 var clients []api.DgraphClient 612 613 retries := 1 614 if conf.IsSet("retries") { 615 retries = conf.GetInt("retries") 616 if retries < 1 { 617 retries = 1 618 } 619 } 620 621 for _, d := range ds { 622 var conn *grpc.ClientConn 623 for i := 0; i < retries; retries++ { 624 conn, err = SetupConnection(d, tlsCfg, false) 625 if err == nil { 626 break 627 } 628 fmt.Printf("While trying to setup connection: %v. Retrying...\n", err) 629 time.Sleep(time.Second) 630 } 631 if conn == nil { 632 Fatalf("Could not setup connection after %d retries", retries) 633 } 634 635 conns = append(conns, conn) 636 dc := api.NewDgraphClient(conn) 637 clients = append(clients, dc) 638 } 639 640 dg := dgo.NewDgraphClient(clients...) 
641 user := conf.GetString("user") 642 if login && len(user) > 0 { 643 err = GetPassAndLogin(dg, &CredOpt{ 644 Conf: conf, 645 UserID: user, 646 PasswordOpt: "password", 647 }) 648 Checkf(err, "While retrieving password and logging in") 649 } 650 651 closeFunc := func() { 652 for _, c := range conns { 653 c.Close() 654 } 655 } 656 return dg, closeFunc 657 } 658 659 // AskUserPassword prompts the user to enter the password for the given user ID. 660 func AskUserPassword(userid string, pwdType string, times int) (string, error) { 661 AssertTrue(times == 1 || times == 2) 662 AssertTrue(pwdType == "Current" || pwdType == "New") 663 // ask for the user's password 664 fmt.Printf("%s password for %v:", pwdType, userid) 665 pd, err := terminal.ReadPassword(int(syscall.Stdin)) 666 if err != nil { 667 return "", errors.Wrapf(err, "while reading password") 668 } 669 fmt.Println() 670 password := string(pd) 671 672 if times == 2 { 673 fmt.Printf("Retype %s password for %v:", strings.ToLower(pwdType), userid) 674 pd2, err := terminal.ReadPassword(int(syscall.Stdin)) 675 if err != nil { 676 return "", errors.Wrapf(err, "while reading password") 677 } 678 fmt.Println() 679 680 password2 := string(pd2) 681 if password2 != password { 682 return "", errors.Errorf("the two typed passwords do not match") 683 } 684 } 685 return password, nil 686 } 687 688 // GetPassAndLogin uses the given credentials and client to perform the login operation. 
689 func GetPassAndLogin(dg *dgo.Dgraph, opt *CredOpt) error { 690 password := opt.Conf.GetString(opt.PasswordOpt) 691 if len(password) == 0 { 692 var err error 693 password, err = AskUserPassword(opt.UserID, "Current", 1) 694 if err != nil { 695 return err 696 } 697 } 698 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 699 defer cancel() 700 if err := dg.Login(ctx, opt.UserID, password); err != nil { 701 return errors.Wrapf(err, "unable to login to the %v account", opt.UserID) 702 } 703 fmt.Println("Login successful.") 704 // update the context so that it has the admin jwt token 705 return nil 706 }