github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/testutil/client.go (about)

     1  /*
     2   * Copyright 2019 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package testutil
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"encoding/json"
    23  	"fmt"
    24  	"io/ioutil"
    25  	"net/http"
    26  	"os"
    27  	"os/exec"
    28  	"strconv"
    29  	"strings"
    30  	"testing"
    31  	"time"
    32  
    33  	"github.com/dgraph-io/dgo"
    34  	"github.com/dgraph-io/dgo/protos/api"
    35  	"github.com/dgraph-io/dgraph/x"
    36  	"github.com/pkg/errors"
    37  	"github.com/spf13/viper"
    38  	"github.com/stretchr/testify/require"
    39  	"google.golang.org/grpc"
    40  	"google.golang.org/grpc/credentials"
    41  )
    42  
    43  // socket addr = IP address and port number
    44  var (
    45  
    46  	// SockAddr is the address to the gRPC endpoint of the alpha used during tests.
    47  	SockAddr string
    48  	// SockAddrHttp is the address to the HTTP of alpha used during tests.
    49  	SockAddrHttp string
    50  	// SockAddrZero is the address to the gRPC endpoint of the zero used during tests.
    51  	SockAddrZero string
    52  	// SockAddrZeroHttp is the address to the HTTP endpoint of the zero used during tests.
    53  	SockAddrZeroHttp string
    54  )
    55  
    56  // This allows running (most) tests against dgraph running on the default ports, for example.
    57  // Only the GRPC ports are needed and the others are deduced.
    58  func init() {
    59  	var grpcPort int
    60  
    61  	getPort := func(envVar string, dfault int) int {
    62  		p := os.Getenv(envVar)
    63  		if p == "" {
    64  			return dfault
    65  		}
    66  		port, _ := strconv.Atoi(p)
    67  		return port
    68  	}
    69  
    70  	grpcPort = getPort("TEST_PORT_ALPHA", 9180)
    71  	SockAddr = fmt.Sprintf("localhost:%d", grpcPort)
    72  	SockAddrHttp = fmt.Sprintf("localhost:%d", grpcPort-1000)
    73  
    74  	grpcPort = getPort("TEST_PORT_ZERO", 5180)
    75  	SockAddrZero = fmt.Sprintf("localhost:%d", grpcPort)
    76  	SockAddrZeroHttp = fmt.Sprintf("localhost:%d", grpcPort+1000)
    77  }
    78  
    79  // DgraphClientDropAll creates a Dgraph client and drops all existing data.
    80  // It is intended to be called from TestMain() to establish a Dgraph connection shared
    81  // by all tests, so there is no testing.T instance for it to use.
    82  func DgraphClientDropAll(serviceAddr string) *dgo.Dgraph {
    83  	dg := DgraphClient(serviceAddr)
    84  	var err error
    85  	for {
    86  		// keep retrying until we succeed or receive a non-retriable error
    87  		err := dg.Alter(context.Background(), &api.Operation{DropAll: true})
    88  		if err == nil || !strings.Contains(err.Error(), "Please retry") {
    89  			break
    90  		}
    91  		time.Sleep(time.Second)
    92  	}
    93  	x.CheckfNoTrace(err)
    94  
    95  	return dg
    96  }
    97  
    98  // DgraphClientWithGroot creates a Dgraph client with groot permissions set up.
    99  // It is intended to be called from TestMain() to establish a Dgraph connection shared
   100  // by all tests, so there is no testing.T instance for it to use.
   101  func DgraphClientWithGroot(serviceAddr string) *dgo.Dgraph {
   102  	dg := DgraphClient(serviceAddr)
   103  
   104  	var err error
   105  	ctx := context.Background()
   106  	for {
   107  		// keep retrying until we succeed or receive a non-retriable error
   108  		err = dg.Login(ctx, x.GrootId, "password")
   109  		if err == nil || !strings.Contains(err.Error(), "Please retry") {
   110  			break
   111  		}
   112  		time.Sleep(time.Second)
   113  	}
   114  	x.CheckfNoTrace(err)
   115  
   116  	return dg
   117  }
   118  
   119  // DgraphClient creates a Dgraph client.
   120  // It is intended to be called from TestMain() to establish a Dgraph connection shared
   121  // by all tests, so there is no testing.T instance for it to use.
   122  func DgraphClient(serviceAddr string) *dgo.Dgraph {
   123  	conn, err := grpc.Dial(serviceAddr, grpc.WithInsecure())
   124  	x.CheckfNoTrace(err)
   125  
   126  	dg := dgo.NewDgraphClient(api.NewDgraphClient(conn))
   127  	x.CheckfNoTrace(err)
   128  
   129  	return dg
   130  }
   131  
   132  // DgraphClientWithCerts creates a Dgraph client with TLS configured using the given
   133  // viper configuration.
   134  // It is intended to be called from TestMain() to establish a Dgraph connection shared
   135  // by all tests, so there is no testing.T instance for it to use.
   136  func DgraphClientWithCerts(serviceAddr string, conf *viper.Viper) (*dgo.Dgraph, error) {
   137  	tlsCfg, err := x.LoadClientTLSConfig(conf)
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	dialOpts := []grpc.DialOption{}
   143  	if tlsCfg != nil {
   144  		dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)))
   145  	} else {
   146  		dialOpts = append(dialOpts, grpc.WithInsecure())
   147  	}
   148  	conn, err := grpc.Dial(serviceAddr, dialOpts...)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  	dg := dgo.NewDgraphClient(api.NewDgraphClient(conn))
   153  	return dg, nil
   154  }
   155  
   156  // DropAll drops all the data in the Dgraph instance associated with the given client.
   157  func DropAll(t *testing.T, dg *dgo.Dgraph) {
   158  	err := dg.Alter(context.Background(), &api.Operation{DropAll: true})
   159  	require.NoError(t, err)
   160  }
   161  
   162  // RetryQuery will retry a query until it succeeds or a non-retryable error is received.
   163  func RetryQuery(dg *dgo.Dgraph, q string) (*api.Response, error) {
   164  	for {
   165  		resp, err := dg.NewTxn().Query(context.Background(), q)
   166  		if err != nil && strings.Contains(err.Error(), "Please retry") {
   167  			time.Sleep(10 * time.Millisecond)
   168  			continue
   169  		}
   170  		return resp, err
   171  	}
   172  }
   173  
   174  // RetryMutation will retry a mutation until it succeeds or a non-retryable error is received.
   175  // The mutation should have CommitNow set to true.
   176  func RetryMutation(dg *dgo.Dgraph, mu *api.Mutation) error {
   177  	for {
   178  		_, err := dg.NewTxn().Mutate(context.Background(), mu)
   179  		if err != nil && (strings.Contains(err.Error(), "Please retry") ||
   180  			strings.Contains(err.Error(), "Tablet isn't being served by this instance")) {
   181  			time.Sleep(10 * time.Millisecond)
   182  			continue
   183  		}
   184  		return err
   185  	}
   186  }
   187  
// LoginParams stores the information needed to perform a login request
// with HttpLogin.
type LoginParams struct {
	// Endpoint is the URL the login request is POSTed to.
	Endpoint string
	// UserID and Passwd are the credentials sent when RefreshJwt is empty.
	UserID string
	Passwd string
	// RefreshJwt, when non-empty, is sent as the refresh token instead of
	// UserID and Passwd.
	RefreshJwt string
}
   195  
   196  // HttpLogin sends a HTTP request to the server
   197  // and returns the access JWT and refresh JWT extracted from
   198  // the HTTP response
   199  func HttpLogin(params *LoginParams) (string, string, error) {
   200  	loginPayload := api.LoginRequest{}
   201  	if len(params.RefreshJwt) > 0 {
   202  		loginPayload.RefreshToken = params.RefreshJwt
   203  	} else {
   204  		loginPayload.Userid = params.UserID
   205  		loginPayload.Password = params.Passwd
   206  	}
   207  
   208  	body, err := json.Marshal(&loginPayload)
   209  	if err != nil {
   210  		return "", "", errors.Wrapf(err, "unable to marshal body")
   211  	}
   212  
   213  	req, err := http.NewRequest("POST", params.Endpoint, bytes.NewBuffer(body))
   214  	if err != nil {
   215  		return "", "", errors.Wrapf(err, "unable to create request")
   216  	}
   217  
   218  	client := &http.Client{}
   219  	resp, err := client.Do(req)
   220  	if err != nil {
   221  		return "", "", errors.Wrapf(err, "login through curl failed")
   222  	}
   223  	defer resp.Body.Close()
   224  
   225  	respBody, err := ioutil.ReadAll(resp.Body)
   226  	if err != nil {
   227  		return "", "", errors.Wrapf(err, "unable to read from response")
   228  	}
   229  
   230  	var outputJson map[string]map[string]string
   231  	if err := json.Unmarshal(respBody, &outputJson); err != nil {
   232  		var errOutputJson map[string]interface{}
   233  		if err := json.Unmarshal(respBody, &errOutputJson); err == nil {
   234  			if _, ok := errOutputJson["errors"]; ok {
   235  				return "", "", errors.Errorf("response error: %v", string(respBody))
   236  			}
   237  		}
   238  		return "", "", errors.Wrapf(err, "unable to unmarshal the output to get JWTs")
   239  	}
   240  
   241  	data, found := outputJson["data"]
   242  	if !found {
   243  		return "", "", errors.Wrapf(err, "data entry found in the output")
   244  	}
   245  
   246  	newAccessJwt, found := data["accessJWT"]
   247  	if !found {
   248  		return "", "", errors.Errorf("no access JWT found in the output")
   249  	}
   250  	newRefreshJwt, found := data["refreshJWT"]
   251  	if !found {
   252  		return "", "", errors.Errorf("no refresh JWT found in the output")
   253  	}
   254  
   255  	return newAccessJwt, newRefreshJwt, nil
   256  }
   257  
   258  // GrootHttpLogin logins using the groot account with the default password
   259  // and returns the access JWT
   260  func GrootHttpLogin(endpoint string) (string, string) {
   261  	accessJwt, refreshJwt, err := HttpLogin(&LoginParams{
   262  		Endpoint: endpoint,
   263  		UserID:   x.GrootId,
   264  		Passwd:   "password",
   265  	})
   266  	x.Check(err)
   267  	return accessJwt, refreshJwt
   268  }
   269  
// CurlFailureConfig stores information about the expected failure of a curl test.
type CurlFailureConfig struct {
	// ShouldFail selects what the JSON output must contain: at least one
	// error entry when true, a non-empty data entry when false.
	ShouldFail bool
	// CurlErrMsg, when non-empty, means the curl command itself is expected
	// to exit with an error whose stderr contains this text.
	CurlErrMsg string
	// DgraphErrMsg, when non-empty and ShouldFail is set, must be contained
	// in the message of the first error entry of the output.
	DgraphErrMsg string
}
   276  
// curlErrorEntry is one element of the "errors" array in the JSON output
// of a curl command.
type curlErrorEntry struct {
	Code    string `json:"code"`    // value of the entry's "code" key
	Message string `json:"message"` // value of the entry's "message" key
}
   281  
// curlOutput is the JSON document that verifyOutput decodes from the output
// of a curl command.
type curlOutput struct {
	Data   map[string]interface{} `json:"data"`   // the "data" object of the response
	Errors []curlErrorEntry       `json:"errors"` // the "errors" array of the response
}
   286  
   287  func verifyOutput(t *testing.T, bytes []byte, failureConfig *CurlFailureConfig) {
   288  	output := curlOutput{}
   289  	require.NoError(t, json.Unmarshal(bytes, &output),
   290  		"unable to unmarshal the curl output")
   291  
   292  	if failureConfig.ShouldFail {
   293  		require.True(t, len(output.Errors) > 0, "no error entry found")
   294  		if len(failureConfig.DgraphErrMsg) > 0 {
   295  			errorEntry := output.Errors[0]
   296  			require.True(t, strings.Contains(errorEntry.Message, failureConfig.DgraphErrMsg),
   297  				fmt.Sprintf("the failure msg\n%s\nis not part of the curl error output:%s\n",
   298  					failureConfig.DgraphErrMsg, errorEntry.Message))
   299  		}
   300  	} else {
   301  		require.True(t, len(output.Data) > 0,
   302  			fmt.Sprintf("no data entry found in the output:%+v", output))
   303  	}
   304  }
   305  
   306  // VerifyCurlCmd executes the curl command with the given arguments and verifies
   307  // the result against the expected output.
   308  func VerifyCurlCmd(t *testing.T, args []string,
   309  	failureConfig *CurlFailureConfig) {
   310  	queryCmd := exec.Command("curl", args...)
   311  
   312  	output, err := queryCmd.Output()
   313  	if len(failureConfig.CurlErrMsg) > 0 {
   314  		// the curl command should have returned an non-zero code
   315  		require.Error(t, err, "the curl command should have failed")
   316  		if ee, ok := err.(*exec.ExitError); ok {
   317  			require.True(t, strings.Contains(string(ee.Stderr), failureConfig.CurlErrMsg),
   318  				"the curl output does not contain the expected output")
   319  		}
   320  	} else {
   321  		require.NoError(t, err, "the curl command should have succeeded")
   322  		verifyOutput(t, output, failureConfig)
   323  	}
   324  }
   325  
   326  // AssignUids talks to zero to assign the given number of uids.
   327  func AssignUids(num uint64) {
   328  	_, err := http.Get(fmt.Sprintf("http://"+SockAddrZeroHttp+"/assign?what=uids&num=%d", num))
   329  	if err != nil {
   330  		panic(fmt.Sprintf("Could not assign uids. Got error %v", err.Error()))
   331  	}
   332  }