github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/testutil/client.go (about) 1 /* 2 * Copyright 2019 Dgraph Labs, Inc. and Contributors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package testutil 18 19 import ( 20 "bytes" 21 "context" 22 "encoding/json" 23 "fmt" 24 "io/ioutil" 25 "net/http" 26 "os" 27 "os/exec" 28 "strconv" 29 "strings" 30 "testing" 31 "time" 32 33 "github.com/dgraph-io/dgo" 34 "github.com/dgraph-io/dgo/protos/api" 35 "github.com/dgraph-io/dgraph/x" 36 "github.com/pkg/errors" 37 "github.com/spf13/viper" 38 "github.com/stretchr/testify/require" 39 "google.golang.org/grpc" 40 "google.golang.org/grpc/credentials" 41 ) 42 43 // socket addr = IP address and port number 44 var ( 45 46 // SockAddr is the address to the gRPC endpoint of the alpha used during tests. 47 SockAddr string 48 // SockAddrHttp is the address to the HTTP of alpha used during tests. 49 SockAddrHttp string 50 // SockAddrZero is the address to the gRPC endpoint of the zero used during tests. 51 SockAddrZero string 52 // SockAddrZeroHttp is the address to the HTTP endpoint of the zero used during tests. 53 SockAddrZeroHttp string 54 ) 55 56 // This allows running (most) tests against dgraph running on the default ports, for example. 57 // Only the GRPC ports are needed and the others are deduced. 58 func init() { 59 var grpcPort int 60 61 getPort := func(envVar string, dfault int) int { 62 p := os.Getenv(envVar) 63 if p == "" { 64 return dfault 65 } 66 port, _ := strconv.Atoi(p) 67 return port 68 } 69 70 grpcPort = getPort("TEST_PORT_ALPHA", 9180) 71 SockAddr = fmt.Sprintf("localhost:%d", grpcPort) 72 SockAddrHttp = fmt.Sprintf("localhost:%d", grpcPort-1000) 73 74 grpcPort = getPort("TEST_PORT_ZERO", 5180) 75 SockAddrZero = fmt.Sprintf("localhost:%d", grpcPort) 76 SockAddrZeroHttp = fmt.Sprintf("localhost:%d", grpcPort+1000) 77 } 78 79 // DgraphClientDropAll creates a Dgraph client and drops all existing data. 80 // It is intended to be called from TestMain() to establish a Dgraph connection shared 81 // by all tests, so there is no testing.T instance for it to use. 82 func DgraphClientDropAll(serviceAddr string) *dgo.Dgraph { 83 dg := DgraphClient(serviceAddr) 84 var err error 85 for { 86 // keep retrying until we succeed or receive a non-retriable error 87 err := dg.Alter(context.Background(), &api.Operation{DropAll: true}) 88 if err == nil || !strings.Contains(err.Error(), "Please retry") { 89 break 90 } 91 time.Sleep(time.Second) 92 } 93 x.CheckfNoTrace(err) 94 95 return dg 96 } 97 98 // DgraphClientWithGroot creates a Dgraph client with groot permissions set up. 99 // It is intended to be called from TestMain() to establish a Dgraph connection shared 100 // by all tests, so there is no testing.T instance for it to use. 101 func DgraphClientWithGroot(serviceAddr string) *dgo.Dgraph { 102 dg := DgraphClient(serviceAddr) 103 104 var err error 105 ctx := context.Background() 106 for { 107 // keep retrying until we succeed or receive a non-retriable error 108 err = dg.Login(ctx, x.GrootId, "password") 109 if err == nil || !strings.Contains(err.Error(), "Please retry") { 110 break 111 } 112 time.Sleep(time.Second) 113 } 114 x.CheckfNoTrace(err) 115 116 return dg 117 } 118 119 // DgraphClient creates a Dgraph client. 120 // It is intended to be called from TestMain() to establish a Dgraph connection shared 121 // by all tests, so there is no testing.T instance for it to use. 122 func DgraphClient(serviceAddr string) *dgo.Dgraph { 123 conn, err := grpc.Dial(serviceAddr, grpc.WithInsecure()) 124 x.CheckfNoTrace(err) 125 126 dg := dgo.NewDgraphClient(api.NewDgraphClient(conn)) 127 x.CheckfNoTrace(err) 128 129 return dg 130 } 131 132 // DgraphClientWithCerts creates a Dgraph client with TLS configured using the given 133 // viper configuration. 134 // It is intended to be called from TestMain() to establish a Dgraph connection shared 135 // by all tests, so there is no testing.T instance for it to use. 136 func DgraphClientWithCerts(serviceAddr string, conf *viper.Viper) (*dgo.Dgraph, error) { 137 tlsCfg, err := x.LoadClientTLSConfig(conf) 138 if err != nil { 139 return nil, err 140 } 141 142 dialOpts := []grpc.DialOption{} 143 if tlsCfg != nil { 144 dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg))) 145 } else { 146 dialOpts = append(dialOpts, grpc.WithInsecure()) 147 } 148 conn, err := grpc.Dial(serviceAddr, dialOpts...) 149 if err != nil { 150 return nil, err 151 } 152 dg := dgo.NewDgraphClient(api.NewDgraphClient(conn)) 153 return dg, nil 154 } 155 156 // DropAll drops all the data in the Dgraph instance associated with the given client. 157 func DropAll(t *testing.T, dg *dgo.Dgraph) { 158 err := dg.Alter(context.Background(), &api.Operation{DropAll: true}) 159 require.NoError(t, err) 160 } 161 162 // RetryQuery will retry a query until it succeeds or a non-retryable error is received. 163 func RetryQuery(dg *dgo.Dgraph, q string) (*api.Response, error) { 164 for { 165 resp, err := dg.NewTxn().Query(context.Background(), q) 166 if err != nil && strings.Contains(err.Error(), "Please retry") { 167 time.Sleep(10 * time.Millisecond) 168 continue 169 } 170 return resp, err 171 } 172 } 173 174 // RetryMutation will retry a mutation until it succeeds or a non-retryable error is received. 175 // The mutation should have CommitNow set to true. 176 func RetryMutation(dg *dgo.Dgraph, mu *api.Mutation) error { 177 for { 178 _, err := dg.NewTxn().Mutate(context.Background(), mu) 179 if err != nil && (strings.Contains(err.Error(), "Please retry") || 180 strings.Contains(err.Error(), "Tablet isn't being served by this instance")) { 181 time.Sleep(10 * time.Millisecond) 182 continue 183 } 184 return err 185 } 186 } 187 188 // LoginParams stores the information needed to perform a login request. 189 type LoginParams struct { 190 Endpoint string 191 UserID string 192 Passwd string 193 RefreshJwt string 194 } 195 196 // HttpLogin sends a HTTP request to the server 197 // and returns the access JWT and refresh JWT extracted from 198 // the HTTP response 199 func HttpLogin(params *LoginParams) (string, string, error) { 200 loginPayload := api.LoginRequest{} 201 if len(params.RefreshJwt) > 0 { 202 loginPayload.RefreshToken = params.RefreshJwt 203 } else { 204 loginPayload.Userid = params.UserID 205 loginPayload.Password = params.Passwd 206 } 207 208 body, err := json.Marshal(&loginPayload) 209 if err != nil { 210 return "", "", errors.Wrapf(err, "unable to marshal body") 211 } 212 213 req, err := http.NewRequest("POST", params.Endpoint, bytes.NewBuffer(body)) 214 if err != nil { 215 return "", "", errors.Wrapf(err, "unable to create request") 216 } 217 218 client := &http.Client{} 219 resp, err := client.Do(req) 220 if err != nil { 221 return "", "", errors.Wrapf(err, "login through curl failed") 222 } 223 defer resp.Body.Close() 224 225 respBody, err := ioutil.ReadAll(resp.Body) 226 if err != nil { 227 return "", "", errors.Wrapf(err, "unable to read from response") 228 } 229 230 var outputJson map[string]map[string]string 231 if err := json.Unmarshal(respBody, &outputJson); err != nil { 232 var errOutputJson map[string]interface{} 233 if err := json.Unmarshal(respBody, &errOutputJson); err == nil { 234 if _, ok := errOutputJson["errors"]; ok { 235 return "", "", errors.Errorf("response error: %v", string(respBody)) 236 } 237 } 238 return "", "", errors.Wrapf(err, "unable to unmarshal the output to get JWTs") 239 } 240 241 data, found := outputJson["data"] 242 if !found { 243 return "", "", errors.Wrapf(err, "data entry found in the output") 244 } 245 246 newAccessJwt, found := data["accessJWT"] 247 if !found { 248 return "", "", errors.Errorf("no access JWT found in the output") 249 } 250 newRefreshJwt, found := data["refreshJWT"] 251 if !found { 252 return "", "", errors.Errorf("no refresh JWT found in the output") 253 } 254 255 return newAccessJwt, newRefreshJwt, nil 256 } 257 258 // GrootHttpLogin logins using the groot account with the default password 259 // and returns the access JWT 260 func GrootHttpLogin(endpoint string) (string, string) { 261 accessJwt, refreshJwt, err := HttpLogin(&LoginParams{ 262 Endpoint: endpoint, 263 UserID: x.GrootId, 264 Passwd: "password", 265 }) 266 x.Check(err) 267 return accessJwt, refreshJwt 268 } 269 270 // CurlFailureConfig stores information about the expected failure of a curl test. 271 type CurlFailureConfig struct { 272 ShouldFail bool 273 CurlErrMsg string 274 DgraphErrMsg string 275 } 276 277 type curlErrorEntry struct { 278 Code string `json:"code"` 279 Message string `json:"message"` 280 } 281 282 type curlOutput struct { 283 Data map[string]interface{} `json:"data"` 284 Errors []curlErrorEntry `json:"errors"` 285 } 286 287 func verifyOutput(t *testing.T, bytes []byte, failureConfig *CurlFailureConfig) { 288 output := curlOutput{} 289 require.NoError(t, json.Unmarshal(bytes, &output), 290 "unable to unmarshal the curl output") 291 292 if failureConfig.ShouldFail { 293 require.True(t, len(output.Errors) > 0, "no error entry found") 294 if len(failureConfig.DgraphErrMsg) > 0 { 295 errorEntry := output.Errors[0] 296 require.True(t, strings.Contains(errorEntry.Message, failureConfig.DgraphErrMsg), 297 fmt.Sprintf("the failure msg\n%s\nis not part of the curl error output:%s\n", 298 failureConfig.DgraphErrMsg, errorEntry.Message)) 299 } 300 } else { 301 require.True(t, len(output.Data) > 0, 302 fmt.Sprintf("no data entry found in the output:%+v", output)) 303 } 304 } 305 306 // VerifyCurlCmd executes the curl command with the given arguments and verifies 307 // the result against the expected output. 308 func VerifyCurlCmd(t *testing.T, args []string, 309 failureConfig *CurlFailureConfig) { 310 queryCmd := exec.Command("curl", args...) 311 312 output, err := queryCmd.Output() 313 if len(failureConfig.CurlErrMsg) > 0 { 314 // the curl command should have returned an non-zero code 315 require.Error(t, err, "the curl command should have failed") 316 if ee, ok := err.(*exec.ExitError); ok { 317 require.True(t, strings.Contains(string(ee.Stderr), failureConfig.CurlErrMsg), 318 "the curl output does not contain the expected output") 319 } 320 } else { 321 require.NoError(t, err, "the curl command should have succeeded") 322 verifyOutput(t, output, failureConfig) 323 } 324 } 325 326 // AssignUids talks to zero to assign the given number of uids. 327 func AssignUids(num uint64) { 328 _, err := http.Get(fmt.Sprintf("http://"+SockAddrZeroHttp+"/assign?what=uids&num=%d", num)) 329 if err != nil { 330 panic(fmt.Sprintf("Could not assign uids. Got error %v", err.Error())) 331 } 332 }