github.com/argoproj/argo-cd@v1.8.7/pkg/apiclient/grpcproxy.go (about)

     1  package apiclient
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"os"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"google.golang.org/grpc"
    15  	"google.golang.org/grpc/codes"
    16  	"google.golang.org/grpc/metadata"
    17  	"google.golang.org/grpc/status"
    18  
    19  	argocderrors "github.com/argoproj/argo-cd/util/errors"
    20  	argoio "github.com/argoproj/argo-cd/util/io"
    21  	"github.com/argoproj/argo-cd/util/rand"
    22  )
    23  
    24  const (
    25  	frameHeaderLength = 5
    26  	endOfStreamFlag   = 128
    27  )
    28  
    29  type noopCodec struct{}
    30  
    31  func (noopCodec) Marshal(v interface{}) ([]byte, error) {
    32  	return v.([]byte), nil
    33  }
    34  
    35  func (noopCodec) Unmarshal(data []byte, v interface{}) error {
    36  	pointer := v.(*[]byte)
    37  	*pointer = data
    38  	return nil
    39  }
    40  
    41  func (noopCodec) String() string {
    42  	return "bytes"
    43  }
    44  
    45  func toFrame(msg []byte) []byte {
    46  	frame := append([]byte{0, 0, 0, 0}, msg...)
    47  	binary.BigEndian.PutUint32(frame, uint32(len(msg)))
    48  	frame = append([]byte{0}, frame...)
    49  	return frame
    50  }
    51  
    52  func (c *client) executeRequest(fullMethodName string, msg []byte, md metadata.MD) (*http.Response, error) {
    53  	schema := "https"
    54  	if c.PlainText {
    55  		schema = "http"
    56  	}
    57  	rootPath := strings.TrimRight(strings.TrimLeft(c.GRPCWebRootPath, "/"), "/")
    58  
    59  	var requestURL string
    60  	if rootPath != "" {
    61  		requestURL = fmt.Sprintf("%s://%s/%s%s", schema, c.ServerAddr, rootPath, fullMethodName)
    62  	} else {
    63  		requestURL = fmt.Sprintf("%s://%s%s", schema, c.ServerAddr, fullMethodName)
    64  	}
    65  	req, err := http.NewRequest(http.MethodPost, requestURL, bytes.NewReader(toFrame(msg)))
    66  
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  	for k, v := range md {
    71  		if strings.HasPrefix(k, ":") {
    72  			continue
    73  		}
    74  		for i := range v {
    75  			req.Header.Set(k, v[i])
    76  		}
    77  	}
    78  	req.Header.Set("content-type", "application/grpc-web+proto")
    79  
    80  	client := &http.Client{}
    81  	if !c.PlainText {
    82  		tlsConfig, err := c.tlsConfig()
    83  		if err != nil {
    84  			return nil, err
    85  		}
    86  		client.Transport = &http.Transport{
    87  			TLSClientConfig: tlsConfig,
    88  		}
    89  	}
    90  
    91  	resp, err := client.Do(req)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	if resp.StatusCode != http.StatusOK {
    96  		return nil, fmt.Errorf("%s %s failed with status code %d", req.Method, req.URL, resp.StatusCode)
    97  	}
    98  	var code codes.Code
    99  	if statusStr := resp.Header.Get("Grpc-Status"); statusStr != "" {
   100  		statusInt, err := strconv.Atoi(statusStr)
   101  		if err != nil {
   102  			code = codes.Unknown
   103  		} else {
   104  			code = codes.Code(statusInt)
   105  		}
   106  		if code != codes.OK {
   107  			return nil, status.Error(code, resp.Header.Get("Grpc-Message"))
   108  		}
   109  	}
   110  	return resp, nil
   111  }
   112  
   113  func (c *client) startGRPCProxy() (*grpc.Server, net.Listener, error) {
   114  	serverAddr := fmt.Sprintf("%s/argocd-%s.sock", os.TempDir(), rand.RandString(16))
   115  	ln, err := net.Listen("unix", serverAddr)
   116  
   117  	if err != nil {
   118  		return nil, nil, err
   119  	}
   120  	proxySrv := grpc.NewServer(
   121  		grpc.CustomCodec(&noopCodec{}),
   122  		grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
   123  			fullMethodName, ok := grpc.MethodFromServerStream(stream)
   124  			if !ok {
   125  				return fmt.Errorf("Unable to get method name from stream context.")
   126  			}
   127  			msg := make([]byte, 0)
   128  			err = stream.RecvMsg(&msg)
   129  			if err != nil {
   130  				return err
   131  			}
   132  
   133  			md, _ := metadata.FromIncomingContext(stream.Context())
   134  
   135  			for _, kv := range c.Headers {
   136  				if len(strings.Split(kv, ":"))%2 == 1 {
   137  					return fmt.Errorf("additional headers key/values must be separated by a colon(:): %s", kv)
   138  				}
   139  				md.Append(strings.Split(kv, ":")[0], strings.Split(kv, ":")[1])
   140  			}
   141  
   142  			resp, err := c.executeRequest(fullMethodName, msg, md)
   143  			if err != nil {
   144  				return err
   145  			}
   146  
   147  			go func() {
   148  				<-stream.Context().Done()
   149  				argoio.Close(resp.Body)
   150  			}()
   151  			defer argoio.Close(resp.Body)
   152  
   153  			for {
   154  				header := make([]byte, frameHeaderLength)
   155  				if _, err := resp.Body.Read(header); err != nil {
   156  					if err == io.EOF {
   157  						err = io.ErrUnexpectedEOF
   158  					}
   159  					return err
   160  				}
   161  
   162  				if header[0] == endOfStreamFlag {
   163  					return nil
   164  				}
   165  				length := int(binary.BigEndian.Uint32(header[1:frameHeaderLength]))
   166  				data := make([]byte, length)
   167  
   168  				if read, err := io.ReadAtLeast(resp.Body, data, length); err != nil {
   169  					if err != io.EOF {
   170  						return err
   171  					} else if read < length {
   172  						return io.ErrUnexpectedEOF
   173  					} else {
   174  						return nil
   175  					}
   176  				}
   177  
   178  				if err := stream.SendMsg(data); err != nil {
   179  					return err
   180  				}
   181  
   182  			}
   183  		}))
   184  	go func() {
   185  		err := proxySrv.Serve(ln)
   186  		argocderrors.CheckError(err)
   187  	}()
   188  	return proxySrv, ln, nil
   189  }
   190  
   191  // useGRPCProxy ensures that grpc proxy server is started and return closer which stops server when no one uses it
   192  func (c *client) useGRPCProxy() (net.Addr, io.Closer, error) {
   193  	c.proxyMutex.Lock()
   194  	defer c.proxyMutex.Unlock()
   195  
   196  	if c.proxyListener == nil {
   197  		var err error
   198  		c.proxyServer, c.proxyListener, err = c.startGRPCProxy()
   199  		if err != nil {
   200  			return nil, nil, err
   201  		}
   202  	}
   203  	c.proxyUsersCount = c.proxyUsersCount + 1
   204  
   205  	return c.proxyListener.Addr(), argoio.NewCloser(func() error {
   206  		c.proxyMutex.Lock()
   207  		defer c.proxyMutex.Unlock()
   208  		c.proxyUsersCount = c.proxyUsersCount - 1
   209  		if c.proxyUsersCount == 0 {
   210  			c.proxyServer.Stop()
   211  			c.proxyListener = nil
   212  			c.proxyServer = nil
   213  			return nil
   214  		}
   215  		return nil
   216  	}), nil
   217  }