github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/contrib/integration/swap/main.go (about)

     1  /*
     2   * Copyright 2017-2018 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package main
    18  
    19  import (
    20  	"context"
    21  	"encoding/json"
    22  	"flag"
    23  	"fmt"
    24  	"math/rand"
    25  	"reflect"
    26  	"sort"
    27  	"strings"
    28  	"sync/atomic"
    29  	"time"
    30  
    31  	"github.com/dgraph-io/dgo"
    32  	"github.com/dgraph-io/dgo/protos/api"
    33  	"github.com/dgraph-io/dgo/x"
    34  	"github.com/dgraph-io/dgraph/testutil"
    35  )
    36  
    37  var (
    38  	alpha     = flag.String("alpha", "localhost:9180", "Dgraph alpha address")
    39  	timeout   = flag.Int("timeout", 60, "query/mutation timeout")
    40  	numSents  = flag.Int("sentences", 100, "number of sentences")
    41  	numSwaps  = flag.Int("swaps", 1000, "number of swaps to attempt")
    42  	concurr   = flag.Int("concurrency", 10, "number of concurrent swaps to run concurrently")
    43  	invPerSec = flag.Int("inv", 10, "number of times to check invariants per second")
    44  )
    45  
    46  var (
    47  	successCount uint64
    48  	failCount    uint64
    49  	invChecks    uint64
    50  )
    51  
    52  func main() {
    53  	flag.Parse()
    54  
    55  	sents := createSentences(*numSents)
    56  	sort.Strings(sents)
    57  	wordCount := make(map[string]int)
    58  	for _, s := range sents {
    59  		words := strings.Split(s, " ")
    60  		for _, w := range words {
    61  			wordCount[w]++
    62  		}
    63  	}
    64  	type wc struct {
    65  		word  string
    66  		count int
    67  	}
    68  	var wcs []wc
    69  	for w, c := range wordCount {
    70  		wcs = append(wcs, wc{w, c})
    71  	}
    72  	sort.Slice(wcs, func(i, j int) bool {
    73  		wi := wcs[i]
    74  		wj := wcs[j]
    75  		return wi.word < wj.word
    76  	})
    77  	for _, w := range wcs {
    78  		fmt.Printf("%15s: %3d\n", w.word, w.count)
    79  	}
    80  
    81  	c := testutil.DgraphClientWithGroot(*alpha)
    82  	uids := setup(c, sents)
    83  
    84  	// Check invariants before doing any mutations as a sanity check.
    85  	x.Check(checkInvariants(c, uids, sents))
    86  
    87  	go func() {
    88  		ticker := time.NewTicker(time.Second / time.Duration(*invPerSec))
    89  		for range ticker.C {
    90  			for {
    91  				if err := checkInvariants(c, uids, sents); err == nil {
    92  					break
    93  				} else {
    94  					fmt.Printf("Error while running inv: %v\n", err)
    95  				}
    96  			}
    97  			atomic.AddUint64(&invChecks, 1)
    98  		}
    99  	}()
   100  
   101  	done := make(chan struct{})
   102  	go func() {
   103  		pending := make(chan struct{}, *concurr)
   104  		for i := 0; i < *numSwaps; i++ {
   105  			pending <- struct{}{}
   106  			go func() {
   107  				swapSentences(c,
   108  					uids[rand.Intn(len(uids))],
   109  					uids[rand.Intn(len(uids))],
   110  				)
   111  				<-pending
   112  			}()
   113  		}
   114  		for i := 0; i < *concurr; i++ {
   115  			pending <- struct{}{}
   116  		}
   117  		close(done)
   118  	}()
   119  
   120  	for {
   121  		select {
   122  		case <-time.After(time.Second):
   123  			fmt.Printf("Success:%d Fail:%d Check:%d\n",
   124  				atomic.LoadUint64(&successCount),
   125  				atomic.LoadUint64(&failCount),
   126  				atomic.LoadUint64(&invChecks),
   127  			)
   128  		case <-done:
   129  			// One final check for invariants.
   130  			x.Check(checkInvariants(c, uids, sents))
   131  			return
   132  		}
   133  	}
   134  
   135  }
   136  
   137  func createSentences(n int) []string {
   138  	sents := make([]string, n)
   139  	for i := range sents {
   140  		sents[i] = nextWord()
   141  	}
   142  
   143  	// add trailing words -- some will be common between sentences
   144  	same := 2
   145  	for {
   146  		var w string
   147  		var count int
   148  		for i := range sents {
   149  			if i%same == 0 {
   150  				w = nextWord()
   151  				count++
   152  			}
   153  			sents[i] += " " + w
   154  		}
   155  		if count == 1 {
   156  			// Every sentence got the same trailing word, no point going any further.  Sort the
   157  			// words within each sentence.
   158  			for i, one := range sents {
   159  				splits := strings.Split(one, " ")
   160  				sort.Strings(splits)
   161  				sents[i] = strings.Join(splits, " ")
   162  			}
   163  			return sents
   164  		}
   165  		same *= 2
   166  	}
   167  }
   168  
   169  func setup(c *dgo.Dgraph, sentences []string) []string {
   170  	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*timeout)*time.Second)
   171  	defer cancel()
   172  	x.Check(c.Alter(ctx, &api.Operation{
   173  		DropAll: true,
   174  	}))
   175  	x.Check(c.Alter(ctx, &api.Operation{
   176  		Schema: `sentence: string @index(term) .`,
   177  	}))
   178  
   179  	rdfs := ""
   180  	for i, s := range sentences {
   181  		rdfs += fmt.Sprintf("_:s%d <sentence> %q .\n", i, s)
   182  	}
   183  	txn := c.NewTxn()
   184  	defer func() {
   185  		if err := txn.Discard(ctx); err != nil {
   186  			fmt.Printf("Discarding transaction failed: %+v\n", err)
   187  		}
   188  	}()
   189  
   190  	assigned, err := txn.Mutate(ctx, &api.Mutation{
   191  		SetNquads: []byte(rdfs),
   192  	})
   193  	x.Check(err)
   194  	x.Check(txn.Commit(ctx))
   195  
   196  	var uids []string
   197  	for _, uid := range assigned.GetUids() {
   198  		uids = append(uids, uid)
   199  	}
   200  	return uids
   201  }
   202  
   203  func swapSentences(c *dgo.Dgraph, node1, node2 string) {
   204  	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*timeout)*time.Second)
   205  	defer cancel()
   206  
   207  	txn := c.NewTxn()
   208  	defer func() {
   209  		if err := txn.Discard(ctx); err != nil {
   210  			fmt.Printf("Discarding transaction failed: %+v\n", err)
   211  		}
   212  	}()
   213  
   214  	resp, err := txn.Query(ctx, fmt.Sprintf(`
   215  	{
   216  		node1(func: uid(%s)) {
   217  			sentence
   218  		}
   219  		node2(func: uid(%s)) {
   220  			sentence
   221  		}
   222  	}
   223  	`, node1, node2))
   224  	x.Check(err)
   225  
   226  	decode := struct {
   227  		Node1 []struct {
   228  			Sentence *string
   229  		}
   230  		Node2 []struct {
   231  			Sentence *string
   232  		}
   233  	}{}
   234  	err = json.Unmarshal(resp.GetJson(), &decode)
   235  	x.Check(err)
   236  	x.AssertTrue(len(decode.Node1) == 1)
   237  	x.AssertTrue(len(decode.Node2) == 1)
   238  	x.AssertTrue(decode.Node1[0].Sentence != nil)
   239  	x.AssertTrue(decode.Node2[0].Sentence != nil)
   240  
   241  	// Delete sentences as an intermediate step.
   242  	delRDFs := fmt.Sprintf(`
   243  		<%s> <sentence> %q .
   244  		<%s> <sentence> %q .
   245  	`,
   246  		node1, *decode.Node1[0].Sentence,
   247  		node2, *decode.Node2[0].Sentence,
   248  	)
   249  	if _, err := txn.Mutate(ctx, &api.Mutation{
   250  		DelNquads: []byte(delRDFs),
   251  	}); err != nil {
   252  		atomic.AddUint64(&failCount, 1)
   253  		return
   254  	}
   255  
   256  	// Add garbage data as an intermediate step.
   257  	garbageRDFs := fmt.Sprintf(`
   258  		<%s> <sentence> "...garbage..." .
   259  		<%s> <sentence> "...garbage..." .
   260  	`,
   261  		node1, node2,
   262  	)
   263  	if _, err := txn.Mutate(ctx, &api.Mutation{
   264  		SetNquads: []byte(garbageRDFs),
   265  	}); err != nil {
   266  		atomic.AddUint64(&failCount, 1)
   267  		return
   268  	}
   269  
   270  	// Perform swap.
   271  	rdfs := fmt.Sprintf(`
   272  		<%s> <sentence> %q .
   273  		<%s> <sentence> %q .
   274  	`,
   275  		node1, *decode.Node2[0].Sentence,
   276  		node2, *decode.Node1[0].Sentence,
   277  	)
   278  	if _, err := txn.Mutate(ctx, &api.Mutation{
   279  		SetNquads: []byte(rdfs),
   280  	}); err != nil {
   281  		atomic.AddUint64(&failCount, 1)
   282  		return
   283  	}
   284  	if err := txn.Commit(ctx); err != nil {
   285  		atomic.AddUint64(&failCount, 1)
   286  		return
   287  	}
   288  	atomic.AddUint64(&successCount, 1)
   289  }
   290  
   291  func checkInvariants(c *dgo.Dgraph, uids []string, sentences []string) error {
   292  	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*timeout)*time.Second)
   293  	defer cancel()
   294  
   295  	// Get the sentence for each node. Then build (in memory) a term index.
   296  	// Then we can query dgraph for each term, and make sure the posting list
   297  	// is the same.
   298  
   299  	txn := c.NewTxn()
   300  	uidList := strings.Join(uids, ",")
   301  	resp, err := txn.Query(ctx, fmt.Sprintf(`
   302  	{
   303  		q(func: uid(%s)) {
   304  			sentence
   305  			uid
   306  		}
   307  	}
   308  	`, uidList))
   309  	if err != nil {
   310  		return err
   311  	}
   312  	decode := struct {
   313  		Q []struct {
   314  			Sentence *string
   315  			Uid      *string
   316  		}
   317  	}{}
   318  	x.Check(json.Unmarshal(resp.GetJson(), &decode))
   319  	x.AssertTrue(len(decode.Q) == len(sentences))
   320  
   321  	index := map[string][]string{} // term to uid list
   322  	var gotSentences []string
   323  	for _, node := range decode.Q {
   324  		x.AssertTrue(node.Sentence != nil)
   325  		x.AssertTrue(node.Uid != nil)
   326  		gotSentences = append(gotSentences, *node.Sentence)
   327  		for _, word := range strings.Split(*node.Sentence, " ") {
   328  			index[word] = append(index[word], *node.Uid)
   329  		}
   330  	}
   331  	sort.Strings(gotSentences)
   332  	for i := 0; i < len(sentences); i++ {
   333  		if sentences[i] != gotSentences[i] {
   334  			fmt.Printf("Sentence doesn't match. Wanted: %q. Got: %q\n", sentences[i], gotSentences[i])
   335  			fmt.Printf("All sentences: %v\n", sentences)
   336  			fmt.Printf("Got sentences: %v\n", gotSentences)
   337  			x.AssertTrue(false)
   338  		}
   339  	}
   340  
   341  	for word, uids := range index {
   342  		q := fmt.Sprintf(`
   343  		{
   344  			q(func: anyofterms(sentence, %q)) {
   345  				uid
   346  			}
   347  		}
   348  		`, word)
   349  
   350  		resp, err := txn.Query(ctx, q)
   351  		if err != nil {
   352  			return err
   353  		}
   354  		decode := struct {
   355  			Q []struct {
   356  				Uid *string
   357  			}
   358  		}{}
   359  		x.Check(json.Unmarshal(resp.GetJson(), &decode))
   360  		var gotUids []string
   361  		for _, node := range decode.Q {
   362  			x.AssertTrue(node.Uid != nil)
   363  			gotUids = append(gotUids, *node.Uid)
   364  		}
   365  
   366  		sort.Strings(gotUids)
   367  		sort.Strings(uids)
   368  		if !reflect.DeepEqual(gotUids, uids) {
   369  			panic(fmt.Sprintf(`query: %s\n
   370  			Uids in index for %q didn't match
   371  			calculated: %v. Len: %d
   372  				got:        %v
   373  			`, q, word, uids, len(uids), gotUids))
   374  		}
   375  	}
   376  	return nil
   377  }