code-intelligence.com/cifuzz@v0.40.0/internal/api/api.go (about)

     1  package api
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"mime/multipart"
     9  	"net"
    10  	"net/http"
    11  	"net/url"
    12  	"os"
    13  	"os/signal"
    14  	"runtime"
    15  	"strings"
    16  	"syscall"
    17  	"time"
    18  
    19  	"github.com/pkg/errors"
    20  	"golang.org/x/net/proxy"
    21  	"golang.org/x/sync/errgroup"
    22  	"golang.org/x/term"
    23  
    24  	"code-intelligence.com/cifuzz/internal/cmd/remoterun/progress"
    25  	"code-intelligence.com/cifuzz/internal/cmdutils"
    26  	"code-intelligence.com/cifuzz/pkg/log"
    27  	"code-intelligence.com/cifuzz/util/stringutil"
    28  )
    29  
    30  // APIError is returned when a REST request returns a status code other
    31  // than 200 OK
    32  type APIError struct {
    33  	err        error
    34  	StatusCode int
    35  }
    36  
    37  func (e APIError) Error() string {
    38  	return e.err.Error()
    39  }
    40  
    41  func (e APIError) Format(s fmt.State, verb rune) {
    42  	if formatter, ok := e.err.(fmt.Formatter); ok {
    43  		formatter.Format(s, verb)
    44  	} else {
    45  		_, _ = io.WriteString(s, e.Error())
    46  	}
    47  }
    48  
    49  func (e APIError) Unwrap() error {
    50  	return e.err
    51  }
    52  
    53  func responseToAPIError(resp *http.Response) error {
    54  	msg := resp.Status
    55  	body, err := io.ReadAll(resp.Body)
    56  	if err != nil {
    57  		return &APIError{StatusCode: resp.StatusCode, err: errors.New(msg)}
    58  	}
    59  	apiResp := struct {
    60  		Code    int
    61  		Message string
    62  	}{}
    63  	err = json.Unmarshal(body, &apiResp)
    64  	if err != nil {
    65  		return &APIError{StatusCode: resp.StatusCode, err: errors.Errorf("%s: %s", msg, string(body))}
    66  	}
    67  	return &APIError{StatusCode: resp.StatusCode, err: errors.Errorf("%s: %s", msg, apiResp.Message)}
    68  }
    69  
    70  // ConnectionError is returned when a REST request fails to connect to the API
    71  type ConnectionError struct {
    72  	err error
    73  }
    74  
    75  func (e ConnectionError) Error() string {
    76  	return e.err.Error()
    77  }
    78  
    79  func (e ConnectionError) Unwrap() error {
    80  	return e.err
    81  }
    82  
    83  // WrapConnectionError wraps an error returned by the API client in a
    84  // ConnectionError to avoid having the error message printed when the error is
    85  // handled.
    86  func WrapConnectionError(err error) error {
    87  	return &ConnectionError{err}
    88  }
    89  
// APIClient holds the settings used for requests to the API server.
type APIClient struct {
	Server    string // base URL onto which the request paths are joined
	UserAgent string // value sent in the User-Agent header of requests
}
    94  
// FeaturedProjectsOrganization holds the organization resource name
// "organizations/1".
// NOTE(review): presumably the organization that owns the featured
// projects; it is not used in this file, so confirm against its callers.
var FeaturedProjectsOrganization = "organizations/1"
    96  
// Artifact is decoded from the JSON response of the bundle upload and
// identifies the uploaded bundle in follow-up requests.
type Artifact struct {
	DisplayName  string `json:"display-name"`  // value of the "display-name" JSON key
	ResourceName string `json:"resource-name"` // value of the "resource-name" JSON key; used to build the run endpoint
}
   101  
   102  func NewClient(server string, version string) *APIClient {
   103  	return &APIClient{
   104  		Server:    server,
   105  		UserAgent: "cifuzz/" + version + " " + runtime.GOOS + "-" + runtime.GOARCH,
   106  	}
   107  }
   108  
   109  func (client *APIClient) UploadBundle(path string, projectName string, token string) (*Artifact, error) {
   110  	signalHandlerCtx, cancelSignalHandler := context.WithCancel(context.Background())
   111  	routines, routinesCtx := errgroup.WithContext(context.Background())
   112  
   113  	// Cancel the routines context when receiving a termination signal
   114  	sigs := make(chan os.Signal, 1)
   115  	signal.Notify(sigs, os.Interrupt, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)
   116  	routines.Go(func() error {
   117  		select {
   118  		case <-signalHandlerCtx.Done():
   119  			return nil
   120  		case s := <-sigs:
   121  			log.Warnf("Received %s", s.String())
   122  			return cmdutils.NewSignalError(s.(syscall.Signal))
   123  		}
   124  	})
   125  
   126  	// Use a pipe to avoid reading the artifacts into memory at once
   127  	r, w := io.Pipe()
   128  	m := multipart.NewWriter(w)
   129  
   130  	// Write the artifacts to the pipe
   131  	routines.Go(func() error {
   132  		defer w.Close()
   133  		defer m.Close()
   134  
   135  		part, err := m.CreateFormFile("fuzzing-artifacts", path)
   136  		if err != nil {
   137  			return errors.WithStack(err)
   138  		}
   139  
   140  		fileInfo, err := os.Stat(path)
   141  		if err != nil {
   142  			return errors.WithStack(err)
   143  		}
   144  
   145  		f, err := os.Open(path)
   146  		if err != nil {
   147  			return errors.WithStack(err)
   148  		}
   149  		defer f.Close()
   150  
   151  		var reader io.Reader
   152  		printProgress := term.IsTerminal(int(os.Stdout.Fd()))
   153  		if printProgress {
   154  			fmt.Println("Uploading...")
   155  			reader = progress.NewReader(f, fileInfo.Size(), "Upload complete")
   156  		} else {
   157  			reader = f
   158  		}
   159  
   160  		_, err = io.Copy(part, reader)
   161  		return errors.WithStack(err)
   162  	})
   163  
   164  	// Send a POST request with what we read from the pipe. The request
   165  	// gets cancelled with the routines context is cancelled, which
   166  	// happens if an error occurs in the io.Copy above or the user if
   167  	// cancels the operation.
   168  	var body []byte
   169  	routines.Go(func() error {
   170  		defer r.Close()
   171  		defer cancelSignalHandler()
   172  		url, err := url.JoinPath(client.Server, "v2", projectName, "artifacts", "import")
   173  		if err != nil {
   174  			return errors.WithStack(err)
   175  		}
   176  		req, err := http.NewRequestWithContext(routinesCtx, "POST", url, r)
   177  		if err != nil {
   178  			return errors.WithStack(err)
   179  		}
   180  
   181  		req.Header.Set("User-Agent", client.UserAgent)
   182  		req.Header.Set("Content-Type", m.FormDataContentType())
   183  		req.Header.Add("Authorization", "Bearer "+token)
   184  
   185  		httpClient := &http.Client{Transport: getCustomTransport()}
   186  		resp, err := httpClient.Do(req)
   187  		if err != nil {
   188  			return errors.WithStack(err)
   189  		}
   190  		defer resp.Body.Close()
   191  
   192  		if resp.StatusCode != 200 {
   193  			return responseToAPIError(resp)
   194  		}
   195  
   196  		body, err = io.ReadAll(resp.Body)
   197  		if err != nil {
   198  			return errors.WithStack(err)
   199  		}
   200  
   201  		return nil
   202  	})
   203  
   204  	err := routines.Wait()
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  
   209  	artifact := &Artifact{}
   210  	err = json.Unmarshal(body, artifact)
   211  	if err != nil {
   212  		err = errors.WithStack(err)
   213  		log.Errorf(err, "Failed to parse response from upload bundle API call: %s", err.Error())
   214  		return nil, cmdutils.WrapSilentError(err)
   215  	}
   216  
   217  	return artifact, nil
   218  }
   219  
   220  func (client *APIClient) StartRemoteFuzzingRun(artifact *Artifact, token string) (string, error) {
   221  	url, err := url.JoinPath("/v1", artifact.ResourceName+":run")
   222  	if err != nil {
   223  		return "", err
   224  	}
   225  	resp, err := client.sendRequest("POST", url, nil, token)
   226  	if err != nil {
   227  		return "", err
   228  	}
   229  	defer resp.Body.Close()
   230  
   231  	if resp.StatusCode != 200 {
   232  		return "", responseToAPIError(resp)
   233  	}
   234  
   235  	// Get the campaign run name from the response
   236  	body, err := io.ReadAll(resp.Body)
   237  	if err != nil {
   238  		return "", errors.WithStack(err)
   239  	}
   240  	var objmap map[string]json.RawMessage
   241  	err = json.Unmarshal(body, &objmap)
   242  	if err != nil {
   243  		return "", errors.WithStack(err)
   244  	}
   245  	campaignRunNameJSON, ok := objmap["name"]
   246  	if !ok {
   247  		err = errors.Errorf("Server response doesn't include run name: %v", stringutil.PrettyString(objmap))
   248  		log.Error(err)
   249  		return "", cmdutils.WrapSilentError(err)
   250  	}
   251  	var campaignRunName string
   252  	err = json.Unmarshal(campaignRunNameJSON, &campaignRunName)
   253  	if err != nil {
   254  		return "", errors.WithStack(err)
   255  	}
   256  
   257  	return campaignRunName, nil
   258  }
   259  
   260  // sendRequest sends a request to the API server with a default timeout of 30 seconds.
   261  func (client *APIClient) sendRequest(method string, endpoint string, body io.Reader, token string) (*http.Response, error) {
   262  	// we use 30 seconds as a conservative timeout for the API server to
   263  	// respond to a request. We might have to revisit this value in the future
   264  	// after the rollout of our API features.
   265  	timeout := 30 * time.Second
   266  	return client.sendRequestWithTimeout(method, endpoint, body, token, timeout)
   267  }
   268  
   269  // sendRequestWithTimeout sends a request to the API server with a timeout.
   270  func (client *APIClient) sendRequestWithTimeout(method string, endpoint string, body io.Reader, token string, timeout time.Duration) (*http.Response, error) {
   271  	url, err := url.JoinPath(client.Server, endpoint)
   272  	if err != nil {
   273  		return nil, err
   274  	}
   275  	req, err := http.NewRequestWithContext(context.Background(), method, url, body)
   276  	if err != nil {
   277  		return nil, errors.WithStack(err)
   278  	}
   279  
   280  	req.Header.Set("User-Agent", client.UserAgent)
   281  	req.Header.Add("Authorization", "Bearer "+token)
   282  
   283  	httpClient := &http.Client{Transport: getCustomTransport(), Timeout: timeout}
   284  	resp, err := httpClient.Do(req)
   285  	if err != nil {
   286  		return nil, WrapConnectionError(errors.WithStack(err))
   287  	}
   288  
   289  	return resp, nil
   290  }
   291  
   292  // IsTokenValid checks if the token is valid by querying the API server.
   293  func (client *APIClient) IsTokenValid(token string) (bool, error) {
   294  	// TOOD: Change this to use another check without querying projects
   295  	_, err := client.ListProjects(token)
   296  	if err != nil {
   297  		var apiErr *APIError
   298  		if errors.As(err, &apiErr) {
   299  			if apiErr.StatusCode == 401 {
   300  				log.Warnf("Invalid token: Received 401 Unauthorized from server %s", client.Server)
   301  				return false, nil
   302  			}
   303  		}
   304  		return false, err
   305  	}
   306  	return true, nil
   307  }
   308  
   309  func validateURL(s string) error {
   310  	u, err := url.Parse(s)
   311  	if err != nil {
   312  		return errors.WithStack(err)
   313  	}
   314  	if u.Scheme != "http" && u.Scheme != "https" {
   315  		return errors.Errorf("unsupported protocol scheme %q", u.Scheme)
   316  	}
   317  	return nil
   318  }
   319  
   320  func ValidateAndNormalizeServerURL(server string) (string, error) {
   321  	// Check if the server option is a valid URL
   322  	err := validateURL(server)
   323  	if err != nil {
   324  		// See if prefixing https:// makes it a valid URL
   325  		err = validateURL("https://" + server)
   326  		if err != nil {
   327  			log.Error(err, fmt.Sprintf("server %q is not a valid URL", server))
   328  		}
   329  		server = "https://" + server
   330  	}
   331  
   332  	// normalize server URL by removing trailing slash
   333  	url, err := url.JoinPath(server, "")
   334  	if err != nil {
   335  		return "", err
   336  	}
   337  	url = strings.TrimSuffix(url, "/")
   338  
   339  	return url, nil
   340  }
   341  
   342  func getCustomTransport() *http.Transport {
   343  	// it is not possible to use the default Proxy Environment because
   344  	// of https://github.com/golang/go/issues/24135
   345  	dialer := proxy.FromEnvironment()
   346  	dialContext := func(ctx context.Context, network, address string) (net.Conn, error) {
   347  		return dialer.Dial(network, address)
   348  	}
   349  	return &http.Transport{DialContext: dialContext}
   350  }