github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/dgraph/cmd/counter/increment.go (about)

     1  /*
     2   * Copyright 2018 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  // Package counter builds a tool that retrieves the value of the counter predicate
    18  // (flag "pred", default `counter.val`) and increments it by 1. If successful, it prints
    19  // out the incremented value. It assumes that the predicate is of type int.
    20  package counter
    21  
    22  import (
    23  	"context"
    24  	"encoding/json"
    25  	"fmt"
    26  	"time"
    27  
    28  	"github.com/dgraph-io/dgo"
    29  	"github.com/dgraph-io/dgo/protos/api"
    30  	"github.com/dgraph-io/dgraph/x"
    31  	"github.com/pkg/errors"
    32  	"github.com/spf13/cobra"
    33  	"github.com/spf13/viper"
    34  )
    35  
    36  // Increment is the sub-command invoked when calling "dgraph increment".
    37  var Increment x.SubCommand
    38  
    39  func init() {
    40  	Increment.Cmd = &cobra.Command{
    41  		Use:   "increment",
    42  		Short: "Increment a counter transactionally",
    43  		Run: func(cmd *cobra.Command, args []string) {
    44  			run(Increment.Conf)
    45  		},
    46  	}
    47  	Increment.EnvPrefix = "DGRAPH_INCREMENT"
    48  
    49  	flag := Increment.Cmd.Flags()
    50  	flag.String("alpha", "localhost:9080", "Address of Dgraph Alpha.")
    51  	flag.Int("num", 1, "How many times to run.")
    52  	flag.Int("retries", 10, "How many times to retry setting up the connection.")
    53  	flag.Duration("wait", 0*time.Second, "How long to wait.")
    54  	flag.String("user", "", "Username if login is required.")
    55  	flag.String("password", "", "Password of the user.")
    56  	flag.String("pred", "counter.val",
    57  		"Predicate to use for storing the counter.")
    58  	flag.Bool("ro", false,
    59  		"Read-only. Read the counter value without updating it.")
    60  	flag.Bool("be", false,
    61  		"Best-effort. Read counter value without retrieving timestamp from Zero.")
    62  	// TLS configuration
    63  	x.RegisterClientTLSFlags(flag)
    64  }
    65  
    66  // Counter stores information about the value being incremented by this tool.
    67  type Counter struct {
    68  	Uid string `json:"uid"`
    69  	Val int    `json:"val"`
    70  
    71  	startTs  uint64 // Only used for internal testing.
    72  	qLatency time.Duration
    73  	mLatency time.Duration
    74  }
    75  
    76  func queryCounter(txn *dgo.Txn, pred string) (Counter, error) {
    77  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    78  	defer cancel()
    79  
    80  	var counter Counter
    81  	query := fmt.Sprintf("{ q(func: has(%s)) { uid, val: %s }}", pred, pred)
    82  	resp, err := txn.Query(ctx, query)
    83  	if err != nil {
    84  		return counter, errors.Wrapf(err, "while doing query")
    85  	}
    86  
    87  	// Total query latency is sum of encoding, parsing and processing latencies.
    88  	queryLatency := resp.Latency.GetEncodingNs() +
    89  		resp.Latency.GetParsingNs() + resp.Latency.GetProcessingNs()
    90  
    91  	m := make(map[string][]Counter)
    92  	if err := json.Unmarshal(resp.Json, &m); err != nil {
    93  		return counter, err
    94  	}
    95  	if len(m["q"]) == 0 {
    96  		// Do nothing.
    97  	} else if len(m["q"]) == 1 {
    98  		counter = m["q"][0]
    99  	} else {
   100  		panic(fmt.Sprintf("Invalid response: %q", resp.Json))
   101  	}
   102  	counter.startTs = resp.GetTxn().GetStartTs()
   103  	counter.qLatency = time.Duration(queryLatency).Round(time.Millisecond)
   104  	return counter, nil
   105  }
   106  
   107  func process(dg *dgo.Dgraph, conf *viper.Viper) (Counter, error) {
   108  	ro := conf.GetBool("ro")
   109  	be := conf.GetBool("be")
   110  	pred := conf.GetString("pred")
   111  	var txn *dgo.Txn
   112  
   113  	switch {
   114  	case be:
   115  		txn = dg.NewReadOnlyTxn().BestEffort()
   116  	case ro:
   117  		txn = dg.NewReadOnlyTxn()
   118  	default:
   119  		txn = dg.NewTxn()
   120  	}
   121  	defer func() {
   122  		if err := txn.Discard(nil); err != nil {
   123  			fmt.Printf("Discarding transaction failed: %+v\n", err)
   124  		}
   125  	}()
   126  
   127  	counter, err := queryCounter(txn, pred)
   128  	if err != nil {
   129  		return Counter{}, err
   130  	}
   131  	if be || ro {
   132  		return counter, nil
   133  	}
   134  
   135  	counter.Val++
   136  	var mu api.Mutation
   137  	mu.CommitNow = true
   138  	if len(counter.Uid) == 0 {
   139  		counter.Uid = "_:new"
   140  	}
   141  	mu.SetNquads = []byte(fmt.Sprintf(`<%s> <%s> "%d"^^<xs:int> .`, counter.Uid, pred, counter.Val))
   142  
   143  	// Don't put any timeout for mutation.
   144  	resp, err := txn.Mutate(context.Background(), &mu)
   145  	if err != nil {
   146  		return Counter{}, err
   147  	}
   148  
   149  	mutationLatency := resp.Latency.GetProcessingNs() +
   150  		resp.Latency.GetParsingNs() + resp.Latency.GetEncodingNs()
   151  	counter.mLatency = time.Duration(mutationLatency).Round(time.Millisecond)
   152  	return counter, nil
   153  }
   154  
   155  func run(conf *viper.Viper) {
   156  	startTime := time.Now()
   157  	defer func() { fmt.Println("Total:", time.Since(startTime).Round(time.Millisecond)) }()
   158  
   159  	waitDur := conf.GetDuration("wait")
   160  	num := conf.GetInt("num")
   161  	format := "0102 03:04:05.999"
   162  
   163  	dg, closeFunc := x.GetDgraphClient(Increment.Conf, true)
   164  	defer closeFunc()
   165  
   166  	for num > 0 {
   167  		txnStart := time.Now() // Start time of transaction
   168  		cnt, err := process(dg, conf)
   169  		now := time.Now().UTC().Format(format)
   170  		if err != nil {
   171  			fmt.Printf("%-17s While trying to process counter: %v. Retrying...\n", now, err)
   172  			time.Sleep(time.Second)
   173  			continue
   174  		}
   175  		serverLat := cnt.qLatency + cnt.mLatency
   176  		clientLat := time.Since(txnStart).Round(time.Millisecond)
   177  		fmt.Printf("%-17s Counter VAL: %d   [ Ts: %d ] Latency: Q %s M %s S %s C %s D %s\n", now, cnt.Val,
   178  			cnt.startTs, cnt.qLatency, cnt.mLatency, serverLat, clientLat, clientLat-serverLat)
   179  		num--
   180  		time.Sleep(waitDur)
   181  	}
   182  }