github.com/dshekhar95/sub_dgraph@v0.0.0-20230424164411-6be28e40bbf1/dgraph/cmd/live/batch.go (about)

     1  /*
     2   * Copyright 2017-2022 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package live
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"math"
    23  	"math/rand"
    24  	"strconv"
    25  	"strings"
    26  	"sync"
    27  	"sync/atomic"
    28  	"time"
    29  
    30  	"github.com/dgryski/go-farm"
    31  	"github.com/dustin/go-humanize/english"
    32  	"google.golang.org/grpc"
    33  	"google.golang.org/grpc/codes"
    34  	"google.golang.org/grpc/status"
    35  
    36  	"github.com/dgraph-io/badger/v3"
    37  	"github.com/dgraph-io/dgo/v210"
    38  	"github.com/dgraph-io/dgo/v210/protos/api"
    39  	"github.com/dgraph-io/dgraph/dql"
    40  	"github.com/dgraph-io/dgraph/protos/pb"
    41  	"github.com/dgraph-io/dgraph/tok"
    42  	"github.com/dgraph-io/dgraph/types"
    43  	"github.com/dgraph-io/dgraph/x"
    44  	"github.com/dgraph-io/dgraph/xidmap"
    45  )
    46  
// batchMutationOptions configures the client's batch mode: Pending buffers, each of Size
// (NOTE(review): Size and Pending are consumed outside this file — confirm their exact use).
// Running counters of the number of N-Quads processed, total time and N-Quads per second are
// printed if PrintCounters is set to true. See Counter.
type batchMutationOptions struct {
	Size          int
	Pending       int
	PrintCounters bool
	// MaxRetries is not read in this file; presumably it bounds retries elsewhere — TODO confirm.
	MaxRetries uint32
	// bufferSize is the number of requests that a live loader thread (makeRequests) can hold
	// back because of conflicting keys before it stops and drains them.
	bufferSize int
	// Ctx is passed to every transaction's Do call (see mutate). The user can pass a context
	// so that we can stop retrying requests once the context is done.
	Ctx context.Context
}
    60  
// loader is the data structure held by the user program for all interactions with the Dgraph
// server. After making a gRPC connection, a new Dgraph client is created by function
// NewDgraphClient.
//
// Within this file, loader schedules mutation requests (makeRequests), tracks the conflict
// keys of in-flight requests so that conflicting requests are not sent at the same time,
// retries failed requests in the background (infinitelyRetry), and keeps the progress
// counters reported by Counter.
type loader struct {
	opts batchMutationOptions

	dc     *dgo.Dgraph
	alloc  *xidmap.XidMap
	ticker *time.Ticker
	db     *badger.DB
	// makeRequests calls requestsWg.Done when it returns.
	requestsWg sync.WaitGroup
	// If we retry a request, we add one to retryRequestsWg.
	retryRequestsWg sync.WaitGroup

	// Miscellaneous information to print counters. The counters below are updated and read
	// with sync/atomic.
	// Num of N-Quads in successfully committed txns
	nquads uint64
	// Num of successfully committed txns
	txns uint64
	// Num of failed mutation attempts (every failure is counted, not only aborts)
	aborts uint64
	// To get time elapsed (see Counter)
	start time.Time

	// conflicts is the set of conflict keys currently held by requests that are being sent or
	// retried; it is guarded by uidsLock.
	conflicts map[uint64]struct{}
	uidsLock  sync.RWMutex

	// reqNum counts calls to request (updated atomically); reqs is consumed by makeRequests;
	// schema provides the predicate definitions used to compute conflict keys.
	reqNum     uint64
	reqs       chan *request
	zeroconn   *grpc.ClientConn
	schema     *schema
	namespaces map[uint64]struct{}

	upsertLock sync.RWMutex
}
    95  
// Counter keeps track of various parameters about a batch mutation. Running totals are printed
// if batchMutationOptions.PrintCounters is set to true.
type Counter struct {
	// Number of N-Quads processed by server.
	Nquads uint64
	// Number of mutations processed by the server.
	TxnsDone uint64
	// Number of Aborts
	Aborts uint64
	// Time elapsed since the batch started.
	Elapsed time.Duration
}
   108  
   109  // handleError inspects errors and terminates if the errors are non-recoverable.
   110  // A gRPC code is Internal if there is an unforeseen issue that needs attention.
   111  // A gRPC code is Unavailable when we can't possibly reach the remote server, most likely the
   112  // server expects TLS and our certificate does not match or the host name is not verified. When
   113  // the node certificate is created the name much match the request host name. e.g., localhost not
   114  // 127.0.0.1.
   115  func handleError(err error, isRetry bool) {
   116  	s := status.Convert(err)
   117  	switch {
   118  	case s.Code() == codes.Internal, s.Code() == codes.Unavailable:
   119  		// Let us not crash live loader due to this. Instead, we should infinitely retry to
   120  		// reconnect and retry the request.
   121  		//nolint:gosec // random generator in closed set does not require cryptographic precision
   122  		dur := time.Duration(1+rand.Intn(60)) * time.Second
   123  		fmt.Printf("Connection has been possibly interrupted. Got error: %v."+
   124  			" Will retry after %s.\n", err, dur.Round(time.Second))
   125  		time.Sleep(dur)
   126  	case strings.Contains(s.Message(), "x509"):
   127  		x.Fatalf(s.Message())
   128  	case s.Code() == codes.Aborted:
   129  		if !isRetry && opt.verbose {
   130  			fmt.Printf("Transaction aborted. Will retry in background.\n")
   131  		}
   132  	case strings.Contains(s.Message(), "Server overloaded."):
   133  		//nolint:gosec // random generator in closed set does not require cryptographic precision
   134  		dur := time.Duration(1+rand.Intn(10)) * time.Minute
   135  		fmt.Printf("Server is overloaded. Will retry after %s.\n", dur.Round(time.Minute))
   136  		time.Sleep(dur)
   137  	case err != x.ErrConflict && err != dgo.ErrAborted:
   138  		fmt.Printf("Error while mutating: %v s.Code %v\n", s.Message(), s.Code())
   139  	}
   140  }
   141  
   142  func (l *loader) infinitelyRetry(req *request) {
   143  	defer l.retryRequestsWg.Done()
   144  	defer l.deregister(req)
   145  	nretries := 1
   146  	for i := time.Millisecond; ; i *= 2 {
   147  		err := l.mutate(req)
   148  		if err == nil {
   149  			if opt.verbose {
   150  				fmt.Printf("Transaction succeeded after %s.\n",
   151  					english.Plural(nretries, "retry", "retries"))
   152  			}
   153  			atomic.AddUint64(&l.nquads, uint64(len(req.Set)))
   154  			atomic.AddUint64(&l.txns, 1)
   155  			return
   156  		}
   157  		nretries++
   158  		handleError(err, true)
   159  		atomic.AddUint64(&l.aborts, 1)
   160  		if i >= 10*time.Second {
   161  			i = 10 * time.Second
   162  		}
   163  		time.Sleep(i)
   164  	}
   165  }
   166  
   167  func (l *loader) mutate(req *request) error {
   168  	txn := l.dc.NewTxn()
   169  	req.CommitNow = true
   170  	request := &api.Request{
   171  		CommitNow: true,
   172  		Mutations: []*api.Mutation{req.Mutation},
   173  	}
   174  	_, err := txn.Do(l.opts.Ctx, request)
   175  	return err
   176  }
   177  
   178  func (l *loader) request(req *request) {
   179  	atomic.AddUint64(&l.reqNum, 1)
   180  	err := l.mutate(req)
   181  	if err == nil {
   182  		atomic.AddUint64(&l.nquads, uint64(len(req.Set)))
   183  		atomic.AddUint64(&l.txns, 1)
   184  		l.deregister(req)
   185  		return
   186  	}
   187  	handleError(err, false)
   188  	atomic.AddUint64(&l.aborts, 1)
   189  	l.retryRequestsWg.Add(1)
   190  	go l.infinitelyRetry(req)
   191  }
   192  
   193  func getTypeVal(val *api.Value) (types.Val, error) {
   194  	p := dql.TypeValFrom(val)
   195  	//Convert value to bytes
   196  
   197  	if p.Tid == types.GeoID || p.Tid == types.DateTimeID {
   198  		// Already in bytes format
   199  		p.Value = p.Value.([]byte)
   200  		return p, nil
   201  	}
   202  
   203  	p1 := types.ValueForType(types.BinaryID)
   204  	if err := types.Marshal(p, &p1); err != nil {
   205  		return p1, err
   206  	}
   207  
   208  	p1.Value = p1.Value.([]byte)
   209  	p1.Tid = p.Tid
   210  	return p1, nil
   211  }
   212  
   213  func createUidEdge(nq *api.NQuad, sid, oid uint64) *pb.DirectedEdge {
   214  	return &pb.DirectedEdge{
   215  		Entity:    sid,
   216  		Attr:      nq.Predicate,
   217  		Namespace: nq.Namespace,
   218  		Lang:      nq.Lang,
   219  		Facets:    nq.Facets,
   220  		ValueId:   oid,
   221  		ValueType: pb.Posting_UID,
   222  	}
   223  }
   224  
   225  func createValueEdge(nq *api.NQuad, sid uint64) (*pb.DirectedEdge, error) {
   226  	p := &pb.DirectedEdge{
   227  		Entity:    sid,
   228  		Attr:      nq.Predicate,
   229  		Namespace: nq.Namespace,
   230  		Lang:      nq.Lang,
   231  		Facets:    nq.Facets,
   232  	}
   233  	val, err := getTypeVal(nq.ObjectValue)
   234  	if err != nil {
   235  		return p, err
   236  	}
   237  
   238  	p.Value = val.Value.([]byte)
   239  	p.ValueType = val.Tid.Enum()
   240  	return p, nil
   241  }
   242  
   243  func fingerprintEdge(t *pb.DirectedEdge, pred *predicate) uint64 {
   244  	var id uint64 = math.MaxUint64
   245  
   246  	// Value with a lang type.
   247  	if len(t.Lang) > 0 {
   248  		id = farm.Fingerprint64([]byte(t.Lang))
   249  	} else if pred.List {
   250  		id = farm.Fingerprint64(t.Value)
   251  	}
   252  	return id
   253  }
   254  
   255  func (l *loader) conflictKeysForNQuad(nq *api.NQuad) ([]uint64, error) {
   256  	attr := x.NamespaceAttr(nq.Namespace, nq.Predicate)
   257  	pred, found := l.schema.preds[attr]
   258  
   259  	// We dont' need to generate conflict keys for predicate with noconflict directive.
   260  	if found && pred.NoConflict {
   261  		return nil, nil
   262  	}
   263  
   264  	keys := make([]uint64, 0)
   265  
   266  	// Calculates the conflict keys, inspired by the logic in
   267  	// addMutationInteration in posting/list.go.
   268  	sid, err := strconv.ParseUint(nq.Subject, 0, 64)
   269  	if err != nil {
   270  		return nil, err
   271  	}
   272  
   273  	var oid uint64
   274  	var de *pb.DirectedEdge
   275  
   276  	if nq.ObjectValue == nil {
   277  		oid, _ = strconv.ParseUint(nq.ObjectId, 0, 64)
   278  		de = createUidEdge(nq, sid, oid)
   279  	} else {
   280  		var err error
   281  		de, err = createValueEdge(nq, sid)
   282  		x.Check(err)
   283  	}
   284  
   285  	// If the predicate is not found in schema then we don't have to generate any more keys.
   286  	if !found {
   287  		return keys, nil
   288  	}
   289  
   290  	if pred.List {
   291  		key := fingerprintEdge(de, pred)
   292  		keys = append(keys, farm.Fingerprint64(x.DataKey(attr, sid))^key)
   293  	} else {
   294  		keys = append(keys, farm.Fingerprint64(x.DataKey(attr, sid)))
   295  	}
   296  
   297  	if pred.Reverse {
   298  		oi, err := strconv.ParseUint(nq.ObjectId, 0, 64)
   299  		if err != nil {
   300  			return keys, err
   301  		}
   302  		keys = append(keys, farm.Fingerprint64(x.DataKey(attr, oi)))
   303  	}
   304  
   305  	if nq.ObjectValue == nil || !(pred.Count || pred.Index) {
   306  		return keys, nil
   307  	}
   308  
   309  	errs := make([]string, 0)
   310  	for _, tokName := range pred.Tokenizer {
   311  		token, ok := tok.GetTokenizer(tokName)
   312  		if !ok {
   313  			fmt.Printf("unknown tokenizer %q", tokName)
   314  			continue
   315  		}
   316  
   317  		storageVal := types.Val{
   318  			Tid:   types.TypeID(de.GetValueType()),
   319  			Value: de.GetValue(),
   320  		}
   321  
   322  		schemaVal, err := types.Convert(storageVal, types.TypeID(pred.ValueType))
   323  		if err != nil {
   324  			errs = append(errs, err.Error())
   325  		}
   326  		toks, err := tok.BuildTokens(schemaVal.Value, tok.GetTokenizerForLang(token, nq.Lang))
   327  		if err != nil {
   328  			errs = append(errs, err.Error())
   329  		}
   330  
   331  		for _, t := range toks {
   332  			keys = append(keys, farm.Fingerprint64(x.IndexKey(attr, t))^sid)
   333  		}
   334  
   335  	}
   336  
   337  	if len(errs) > 0 {
   338  		return keys, fmt.Errorf(strings.Join(errs, "\n"))
   339  	}
   340  	return keys, nil
   341  }
   342  
   343  func (l *loader) conflictKeysForReq(req *request) []uint64 {
   344  	// Live loader only needs to look at sets and not deletes
   345  	keys := make([]uint64, 0, len(req.Set))
   346  	for _, nq := range req.Set {
   347  		conflicts, err := l.conflictKeysForNQuad(nq)
   348  		if err != nil {
   349  			fmt.Println(err)
   350  			continue
   351  		}
   352  		keys = append(keys, conflicts...)
   353  	}
   354  	return keys
   355  }
   356  
   357  func (l *loader) addConflictKeys(req *request) bool {
   358  	l.uidsLock.Lock()
   359  	defer l.uidsLock.Unlock()
   360  
   361  	for _, key := range req.conflicts {
   362  		if _, ok := l.conflicts[key]; ok {
   363  			return false
   364  		}
   365  	}
   366  
   367  	for _, key := range req.conflicts {
   368  		l.conflicts[key] = struct{}{}
   369  	}
   370  
   371  	return true
   372  }
   373  
   374  func (l *loader) deregister(req *request) {
   375  	l.uidsLock.Lock()
   376  	defer l.uidsLock.Unlock()
   377  
   378  	for _, i := range req.conflicts {
   379  		delete(l.conflicts, i)
   380  	}
   381  }
   382  
   383  // makeRequests can receive requests from batchNquads or directly from BatchSetWithMark.
   384  // It doesn't need to batch the requests anymore. Batching is already done for it by the
   385  // caller functions.
   386  func (l *loader) makeRequests() {
   387  	defer l.requestsWg.Done()
   388  
   389  	buffer := make([]*request, 0, l.opts.bufferSize)
   390  	drain := func(maxSize int) {
   391  		for len(buffer) > maxSize {
   392  			i := 0
   393  			for _, req := range buffer {
   394  				// If there is no conflict in req, we will use it
   395  				// and then it would shift all the other reqs in buffer
   396  				if !l.addConflictKeys(req) {
   397  					buffer[i] = req
   398  					i++
   399  					continue
   400  				}
   401  				// Req will no longer be part of a buffer
   402  				l.request(req)
   403  			}
   404  			buffer = buffer[:i]
   405  		}
   406  	}
   407  
   408  	for req := range l.reqs {
   409  		req.conflicts = l.conflictKeysForReq(req)
   410  		if l.addConflictKeys(req) {
   411  			l.request(req)
   412  		} else {
   413  			buffer = append(buffer, req)
   414  		}
   415  		drain(l.opts.bufferSize - 1)
   416  	}
   417  
   418  	drain(0)
   419  }
   420  
   421  func (l *loader) printCounters() {
   422  	period := 5 * time.Second
   423  	l.ticker = time.NewTicker(period)
   424  	start := time.Now()
   425  
   426  	var last Counter
   427  	for range l.ticker.C {
   428  		counter := l.Counter()
   429  		rate := float64(counter.Nquads-last.Nquads) / period.Seconds()
   430  		elapsed := time.Since(start).Round(time.Second)
   431  		timestamp := time.Now().Format("15:04:05Z0700")
   432  		fmt.Printf("[%s] Elapsed: %s Txns: %d N-Quads: %d N-Quads/s [last 5s]: %5.0f Aborts: %d\n",
   433  			timestamp, x.FixedDuration(elapsed), counter.TxnsDone, counter.Nquads, rate, counter.Aborts)
   434  		last = counter
   435  	}
   436  }
   437  
   438  // Counter returns the current state of the BatchMutation.
   439  func (l *loader) Counter() Counter {
   440  	return Counter{
   441  		Nquads:   atomic.LoadUint64(&l.nquads),
   442  		TxnsDone: atomic.LoadUint64(&l.txns),
   443  		Elapsed:  time.Since(l.start),
   444  		Aborts:   atomic.LoadUint64(&l.aborts),
   445  	}
   446  }