github.com/ActiveState/cli@v0.0.0-20240508170324-6801f60cd051/internal/gqlclient/gqlclient.go (about)

     1  package gqlclient
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"mime/multipart"
    10  	"net/http"
    11  	"os"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/ActiveState/cli/internal/constants"
    16  	"github.com/ActiveState/cli/internal/errs"
    17  	"github.com/ActiveState/cli/internal/logging"
    18  	"github.com/ActiveState/cli/internal/profile"
    19  	"github.com/ActiveState/cli/internal/singleton/uniqid"
    20  	"github.com/ActiveState/cli/internal/strutils"
    21  	"github.com/ActiveState/cli/pkg/platform/api"
    22  	"github.com/ActiveState/graphql"
    23  	"github.com/pkg/errors"
    24  )
    25  
    26  type File struct {
    27  	Field string
    28  	Name  string
    29  	R     io.Reader
    30  }
    31  
    32  type Request0 interface {
    33  	Query() string
    34  	Vars() map[string]interface{}
    35  }
    36  
    37  type Request interface {
    38  	Query() string
    39  	Vars() (map[string]interface{}, error)
    40  }
    41  
    42  type RequestWithFiles interface {
    43  	Request
    44  	Files() []File
    45  }
    46  
    47  type Header map[string][]string
    48  
    49  type graphqlClient = graphql.Client
    50  
    51  // StandardizedErrors works around API's that don't follow the graphql standard
    52  // It looks redundant because it needs to address two different API responses.
    53  // https://activestatef.atlassian.net/browse/PB-4291
    54  type StandardizedErrors struct {
    55  	Message string
    56  	Error   string
    57  	Errors  []graphErr
    58  }
    59  
    60  func (e StandardizedErrors) HasErrors() bool {
    61  	return len(e.Errors) > 0 || e.Error != ""
    62  }
    63  
    64  // Values tells us all the relevant error messages returned.
    65  // We don't include e.Error because it's an unhelpful generic error code redundant with the message.
    66  func (e StandardizedErrors) Values() []string {
    67  	var errs []string
    68  	for _, err := range e.Errors {
    69  		errs = append(errs, err.Message)
    70  	}
    71  	if e.Message != "" {
    72  		errs = append(errs, e.Message)
    73  	}
    74  	return errs
    75  }
    76  
    77  type graphResponse struct {
    78  	Data    interface{}
    79  	Error   string
    80  	Message string
    81  	Errors  []graphErr
    82  }
    83  
    84  type graphErr struct {
    85  	Message string
    86  }
    87  
    88  func (e graphErr) Error() string {
    89  	return "graphql: " + e.Message
    90  }
    91  
    92  type BearerTokenProvider interface {
    93  	BearerToken() string
    94  }
    95  
    96  type Client struct {
    97  	*graphqlClient
    98  	url           string
    99  	tokenProvider BearerTokenProvider
   100  	timeout       time.Duration
   101  }
   102  
   103  func NewWithOpts(url string, timeout time.Duration, opts ...graphql.ClientOption) *Client {
   104  	if timeout == 0 {
   105  		timeout = time.Second * 60
   106  	}
   107  
   108  	client := &Client{
   109  		graphqlClient: graphql.NewClient(url, opts...),
   110  		timeout:       timeout,
   111  		url:           url,
   112  	}
   113  	if os.Getenv(constants.DebugServiceRequestsEnvVarName) == "true" {
   114  		client.EnableDebugLog()
   115  	}
   116  	return client
   117  }
   118  
   119  func New(url string, timeout time.Duration) *Client {
   120  	return NewWithOpts(url, timeout, graphql.WithHTTPClient(api.NewHTTPClient()))
   121  }
   122  
   123  // EnableDebugLog turns on debug logging
   124  func (c *Client) EnableDebugLog() {
   125  	c.graphqlClient.Log = func(s string) { logging.Debug("graphqlClient log message: %s", s) }
   126  }
   127  
   128  func (c *Client) SetTokenProvider(tokenProvider BearerTokenProvider) {
   129  	c.tokenProvider = tokenProvider
   130  }
   131  
   132  func (c *Client) SetDebug(b bool) {
   133  	c.graphqlClient.Log = func(string) {}
   134  	if b {
   135  		c.graphqlClient.Log = func(s string) {
   136  			fmt.Fprintln(os.Stderr, s)
   137  		}
   138  	}
   139  }
   140  
   141  func (c *Client) Run(request Request, response interface{}) error {
   142  	ctx := context.Background()
   143  	if c.timeout != 0 {
   144  		var cancel context.CancelFunc
   145  		ctx, cancel = context.WithTimeout(ctx, c.timeout)
   146  		defer cancel()
   147  	}
   148  	err := c.RunWithContext(ctx, request, response)
   149  	return err // Needs var so the cancel defer triggers at the right time
   150  }
   151  
   152  type PostProcessor interface {
   153  	PostProcess() error
   154  }
   155  
   156  func (c *Client) RunWithContext(ctx context.Context, request Request, response interface{}) (rerr error) {
   157  	defer func() {
   158  		if rerr != nil {
   159  			return
   160  		}
   161  		if postProcessor, ok := response.(PostProcessor); ok {
   162  			rerr = postProcessor.PostProcess()
   163  		}
   164  	}()
   165  	name := strutils.Summarize(request.Query(), 25)
   166  	defer profile.Measure(fmt.Sprintf("gqlclient:RunWithContext:(%s)", name), time.Now())
   167  
   168  	if fileRequest, ok := request.(RequestWithFiles); ok {
   169  		return c.runWithFiles(ctx, fileRequest, response)
   170  	}
   171  
   172  	vars, err := request.Vars()
   173  	if err != nil {
   174  		return errs.Wrap(err, "Could not get variables")
   175  	}
   176  
   177  	graphRequest := graphql.NewRequest(request.Query())
   178  	for key, value := range vars {
   179  		graphRequest.Var(key, value)
   180  	}
   181  
   182  	if fileRequest, ok := request.(RequestWithFiles); ok {
   183  		for _, file := range fileRequest.Files() {
   184  			graphRequest.File(file.Field, file.Name, file.R)
   185  		}
   186  	}
   187  
   188  	var bearerToken string
   189  	if c.tokenProvider != nil {
   190  		bearerToken = c.tokenProvider.BearerToken()
   191  		if bearerToken != "" {
   192  			graphRequest.Header.Set("Authorization", "Bearer "+bearerToken)
   193  		}
   194  	}
   195  
   196  	graphRequest.Header.Set("X-Requestor", uniqid.Text())
   197  
   198  	if err := c.graphqlClient.Run(ctx, graphRequest, &response); err != nil {
   199  		return NewRequestError(err, request)
   200  	}
   201  
   202  	return nil
   203  }
   204  
   205  type JsonRequest struct {
   206  	Query     string                 `json:"query"`
   207  	Variables map[string]interface{} `json:"variables"`
   208  }
   209  
   210  func (c *Client) runWithFiles(ctx context.Context, gqlReq RequestWithFiles, response interface{}) error {
   211  	// Construct the multi-part request.
   212  	bodyReader, bodyWriter := io.Pipe()
   213  
   214  	req, err := http.NewRequest("POST", c.url, bodyReader)
   215  	if err != nil {
   216  		return errs.Wrap(err, "Could not create http request")
   217  	}
   218  
   219  	req.Body = bodyReader
   220  
   221  	mw := multipart.NewWriter(bodyWriter)
   222  	req.Header.Set("Content-Type", "multipart/form-data; boundary="+mw.Boundary())
   223  
   224  	vars, err := gqlReq.Vars()
   225  	if err != nil {
   226  		return errs.Wrap(err, "Could not get variables")
   227  	}
   228  
   229  	varJson, err := json.Marshal(vars)
   230  	if err != nil {
   231  		return errs.Wrap(err, "Could not marshal vars")
   232  	}
   233  
   234  	reqErrChan := make(chan error)
   235  	go func() {
   236  		defer bodyWriter.Close()
   237  		defer mw.Close()
   238  		defer close(reqErrChan)
   239  
   240  		// Operations
   241  		operations, err := mw.CreateFormField("operations")
   242  		if err != nil {
   243  			reqErrChan <- errs.Wrap(err, "Could not create form field operations")
   244  			return
   245  		}
   246  
   247  		jsonReq := JsonRequest{
   248  			Query:     gqlReq.Query(),
   249  			Variables: vars,
   250  		}
   251  		jsonReqV, err := json.Marshal(jsonReq)
   252  		if err != nil {
   253  			reqErrChan <- errs.Wrap(err, "Could not marshal json request")
   254  			return
   255  		}
   256  		if _, err := operations.Write(jsonReqV); err != nil {
   257  			reqErrChan <- errs.Wrap(err, "Could not write json request")
   258  			return
   259  		}
   260  
   261  		// Map
   262  		if len(gqlReq.Files()) > 0 {
   263  			mapField, err := mw.CreateFormField("map")
   264  			if err != nil {
   265  				reqErrChan <- errs.Wrap(err, "Could not create form field map")
   266  				return
   267  			}
   268  			for n, f := range gqlReq.Files() {
   269  				if _, err := mapField.Write([]byte(fmt.Sprintf(`{"%d": ["%s"]}`, n, f.Field))); err != nil {
   270  					reqErrChan <- errs.Wrap(err, "Could not write map field")
   271  					return
   272  				}
   273  			}
   274  			// File upload
   275  			for n, file := range gqlReq.Files() {
   276  				part, err := mw.CreateFormFile(fmt.Sprintf("%d", n), file.Name)
   277  				if err != nil {
   278  					reqErrChan <- errs.Wrap(err, "Could not create form file")
   279  					return
   280  				}
   281  
   282  				_, err = io.Copy(part, file.R)
   283  				if err != nil {
   284  					reqErrChan <- errs.Wrap(err, "Could not read file")
   285  					return
   286  				}
   287  			}
   288  		}
   289  	}()
   290  
   291  	c.Log(fmt.Sprintf(">> query: %s", gqlReq.Query()))
   292  	c.Log(fmt.Sprintf(">> variables: %s", string(varJson)))
   293  	fnames := []string{}
   294  	for _, file := range gqlReq.Files() {
   295  		fnames = append(fnames, file.Name)
   296  	}
   297  	c.Log(fmt.Sprintf(">> files: %v", fnames))
   298  
   299  	// Run the request.
   300  	var bearerToken string
   301  	if c.tokenProvider != nil {
   302  		bearerToken = c.tokenProvider.BearerToken()
   303  		if bearerToken != "" {
   304  			req.Header.Set("Authorization", "Bearer "+bearerToken)
   305  		}
   306  	}
   307  	if os.Getenv(constants.DebugServiceRequestsEnvVarName) == "true" {
   308  		responseData, err := json.MarshalIndent(response, "", "  ")
   309  		if err != nil {
   310  			return errs.Wrap(err, "failed to marshal response")
   311  		}
   312  		logging.Debug("gqlclient: response: %s", responseData)
   313  	}
   314  
   315  	gr := &graphResponse{
   316  		Data: response,
   317  	}
   318  	req = req.WithContext(ctx)
   319  	c.Log(fmt.Sprintf(">> Raw Request: %s\n", req.URL.String()))
   320  
   321  	var res *http.Response
   322  	resErrChan := make(chan error)
   323  	go func() {
   324  		var err error
   325  		res, err = http.DefaultClient.Do(req)
   326  		resErrChan <- err
   327  	}()
   328  
   329  	// Due to the streaming uploads the request error can happen both before and after the http request itself, hence
   330  	// the creative select case you see before you.
   331  	wait := true
   332  	for wait {
   333  		select {
   334  		case err := <-reqErrChan:
   335  			if err != nil {
   336  				c.Log(fmt.Sprintf("Request Error: %s", err))
   337  				return err
   338  			}
   339  		case err := <-resErrChan:
   340  			wait = false
   341  			if err != nil {
   342  				c.Log(fmt.Sprintf("Response Error: %s", err))
   343  				return err
   344  			}
   345  		}
   346  	}
   347  
   348  	if res == nil {
   349  		return errs.New("Received empty response")
   350  	}
   351  
   352  	defer res.Body.Close()
   353  	var buf bytes.Buffer
   354  	if _, err := io.Copy(&buf, res.Body); err != nil {
   355  		c.Log(fmt.Sprintf("Read Error: %s", err))
   356  		return errors.Wrap(err, "reading body")
   357  	}
   358  	resp := buf.Bytes()
   359  	c.Log(fmt.Sprintf("<< Response code: %d, body: %s\n", res.StatusCode, string(resp)))
   360  
   361  	// Work around API's that don't follow the graphql standard
   362  	// https://activestatef.atlassian.net/browse/PB-4291
   363  	standardizedErrors := StandardizedErrors{}
   364  	if err := json.Unmarshal(resp, &standardizedErrors); err != nil {
   365  		return errors.Wrap(err, "decoding error response")
   366  	}
   367  	if standardizedErrors.HasErrors() {
   368  		return errs.New(strings.Join(standardizedErrors.Values(), "\n"))
   369  	}
   370  
   371  	if err := json.Unmarshal(resp, &gr); err != nil {
   372  		return errors.Wrap(err, "decoding response")
   373  	}
   374  	return nil
   375  }