github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/libkb/client.go (about)

     1  // Copyright 2015 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     3  
     4  package libkb
     5  
     6  import (
     7  	"bytes"
     8  	"compress/gzip"
     9  	"crypto/tls"
    10  	"crypto/x509"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"net/http"
    15  	"net/http/cookiejar"
    16  	"net/url"
    17  	"regexp"
    18  	"strconv"
    19  	"strings"
    20  	"sync"
    21  	"time"
    22  
    23  	"github.com/keybase/go-framed-msgpack-rpc/rpc"
    24  	"github.com/keybase/go-framed-msgpack-rpc/rpc/resinit"
    25  	"golang.org/x/net/context"
    26  )
    27  
// ClientConfig holds the settings NewClient uses to build a Client.
// Keep the field order stable: unkeyed composite literals may depend on it.
type ClientConfig struct {
	// Host is the server host name, without any port.
	Host string
	// Port is the server port; 0 when the server URL carries no explicit port.
	Port int
	// UseTLS is true when the server URL scheme is "https".
	// NOTE(review): NewClient does not read this field.
	UseTLS bool // XXX unused?
	// URL is the parsed server URL.
	URL *url.URL
	// RootCAs, when non-nil, replaces the system roots for TLS verification.
	RootCAs *x509.CertPool
	// Prefix is the path component of the server URL.
	Prefix string
	// UseCookies gates cookie-jar creation in NewClient; when false no jar is attached.
	UseCookies bool
	// Timeout is the overall http.Client timeout; zero means HTTPDefaultTimeout.
	Timeout time.Duration
}
    38  
// Client wraps an *http.Client together with the ClientConfig it was built
// from. The config may be nil (NewClient accepts a nil config).
type Client struct {
	// cli is the configured HTTP client. When built by NewClient, its
	// Transport is an InstrumentedRoundTripper.
	cli *http.Client
	// config is the configuration passed to NewClient, possibly nil.
	config *ClientConfig
}
    43  
    44  var hostRE = regexp.MustCompile("^([^:]+)(:([0-9]+))?$")
    45  
    46  func SplitHost(joined string) (host string, port int, err error) {
    47  	match := hostRE.FindStringSubmatch(joined)
    48  	if match == nil {
    49  		err = fmt.Errorf("Invalid host/port found: %s", joined)
    50  	} else {
    51  		host = match[1]
    52  		port = 0
    53  		if len(match[3]) > 0 {
    54  			port, err = strconv.Atoi(match[3])
    55  			if err != nil {
    56  				err = fmt.Errorf("Could not convert port in host %s", joined)
    57  			}
    58  		}
    59  	}
    60  	return
    61  }
    62  
    63  func ParseCA(raw string) (*x509.CertPool, error) {
    64  	ret := x509.NewCertPool()
    65  	ok := ret.AppendCertsFromPEM([]byte(raw))
    66  	var err error
    67  	if !ok {
    68  		err = fmt.Errorf("Could not read CA for keybase.io")
    69  		ret = nil
    70  	}
    71  	return ret, err
    72  }
    73  
    74  func ShortCA(raw string) string {
    75  	parts := strings.Split(raw, "\n")
    76  	if len(parts) >= 3 {
    77  		parts = parts[0:3]
    78  	}
    79  	return strings.Join(parts, " ") + "..."
    80  }
    81  
    82  // GenClientConfigForInternalAPI pulls the information out of the environment configuration,
    83  // and build a Client config that will be used in all API server
    84  // requests
    85  func genClientConfigForInternalAPI(g *GlobalContext) (*ClientConfig, error) {
    86  	e := g.Env
    87  	serverURI, err := e.GetServerURI()
    88  
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	if e.GetTorMode().Enabled() {
    94  		serverURI = e.GetTorHiddenAddress()
    95  	}
    96  
    97  	if serverURI == "" {
    98  		err := fmt.Errorf("Cannot find a server URL")
    99  		return nil, err
   100  	}
   101  	url, err := url.Parse(serverURI)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	if url.Scheme == "" {
   107  		return nil, fmt.Errorf("Server URL missing Scheme")
   108  	}
   109  
   110  	if url.Host == "" {
   111  		return nil, fmt.Errorf("Server URL missing Host")
   112  	}
   113  
   114  	useTLS := (url.Scheme == "https")
   115  	host, port, e2 := SplitHost(url.Host)
   116  	if e2 != nil {
   117  		return nil, e2
   118  	}
   119  	var rootCAs *x509.CertPool
   120  	if rawCA := e.GetBundledCA(host); len(rawCA) > 0 {
   121  		rootCAs, err = ParseCA(rawCA)
   122  		if err != nil {
   123  			err = fmt.Errorf("In parsing CAs for %s: %s", host, err)
   124  			return nil, err
   125  		}
   126  		g.Log.Debug(fmt.Sprintf("Using special root CA for %s: %s",
   127  			host, ShortCA(rawCA)))
   128  	}
   129  
   130  	// If we're using proxies, they might have their own CAs.
   131  	if rootCAs, err = GetProxyCAs(rootCAs, e.config); err != nil {
   132  		return nil, err
   133  	}
   134  
   135  	ret := &ClientConfig{host, port, useTLS, url, rootCAs, url.Path, true, e.GetAPITimeout()}
   136  	return ret, nil
   137  }
   138  
   139  func genClientConfigForScrapers(e *Env) (*ClientConfig, error) {
   140  	return &ClientConfig{
   141  		UseCookies: true,
   142  		Timeout:    e.GetScraperTimeout(),
   143  	}, nil
   144  }
   145  
   146  func NewClient(g *GlobalContext, config *ClientConfig, needCookie bool) (*Client, error) {
   147  	extraLog := func(ctx context.Context, msg string, args ...interface{}) {}
   148  	if g.Env.GetExtraNetLogging() {
   149  		extraLog = func(ctx context.Context, msg string, args ...interface{}) {
   150  			if ctx == nil {
   151  				g.Log.Debug(msg, args...)
   152  			} else {
   153  				g.Log.CDebugf(ctx, msg, args...)
   154  			}
   155  		}
   156  	}
   157  	extraLog(context.TODO(), "api.Client:%v New", needCookie)
   158  	env := g.Env
   159  	var jar *cookiejar.Jar
   160  	if needCookie && (config == nil || config.UseCookies) && env.GetTorMode().UseCookies() {
   161  		jar, _ = cookiejar.New(nil)
   162  	}
   163  
   164  	// Originally copied from http.DefaultTransport
   165  	dialer := net.Dialer{
   166  		Timeout:   30 * time.Second,
   167  		KeepAlive: 30 * time.Second,
   168  		DualStack: true,
   169  	}
   170  	xprt := http.Transport{
   171  		// Don't change this without re-testing proxy support. Currently the client supports proxies through
   172  		// environment variables that ProxyFromEnvironment picks up
   173  		Proxy:                 http.ProxyFromEnvironment,
   174  		DialContext:           (&dialer).DialContext,
   175  		MaxIdleConns:          200,
   176  		MaxIdleConnsPerHost:   100,
   177  		IdleConnTimeout:       90 * time.Second,
   178  		TLSHandshakeTimeout:   10 * time.Second,
   179  		ExpectContinueTimeout: 1 * time.Second,
   180  	}
   181  
   182  	xprt.DialContext = func(ctx context.Context, network, addr string) (c net.Conn, err error) {
   183  		c, err = dialer.DialContext(ctx, network, addr)
   184  		if err != nil {
   185  			extraLog(ctx, "api.Client:%v transport.Dial err=%v", needCookie, err)
   186  			// If we get a DNS error, it could be because glibc has cached an
   187  			// old version of /etc/resolv.conf. The res_init() libc function
   188  			// busts that cache and keeps us from getting stuck in a state
   189  			// where DNS requests keep failing even though the network is up.
   190  			// This is similar to what the Rust standard library does:
   191  			// https://github.com/rust-lang/rust/blob/028569ab1b/src/libstd/sys_common/net.rs#L186-L190
   192  			resinit.ResInitIfDNSError(err)
   193  			return c, err
   194  		}
   195  		if err = rpc.DisableSigPipe(c); err != nil {
   196  			extraLog(ctx, "api.Client:%v transport.Dial DisableSigPipe err=%v", needCookie, err)
   197  			return c, err
   198  		}
   199  		return c, nil
   200  	}
   201  
   202  	if config != nil && config.RootCAs != nil {
   203  		xprt.TLSClientConfig = &tls.Config{RootCAs: config.RootCAs}
   204  	}
   205  
   206  	xprt.Proxy = MakeProxy(env)
   207  
   208  	if !env.GetTorMode().Enabled() && env.GetRunMode() == DevelRunMode {
   209  		xprt.Proxy = func(req *http.Request) (*url.URL, error) {
   210  			host, port, err := net.SplitHostPort(req.URL.Host)
   211  			if err == nil && host == "localhost" {
   212  				// ProxyFromEnvironment refuses to proxy when the hostname is set to "localhost".
   213  				// So make a fake copy of the request with the url set to "127.0.0.1".
   214  				// This makes localhost requests use proxy settings.
   215  				// The Host could be anything and is only used to != "localhost".
   216  				url2 := *req.URL
   217  				url2.Host = "keybase.io:" + port
   218  				req2 := req
   219  				req2.URL = &url2
   220  				return http.ProxyFromEnvironment(req2)
   221  			}
   222  			return http.ProxyFromEnvironment(req)
   223  		}
   224  	}
   225  
   226  	var timeout time.Duration
   227  	if config == nil || config.Timeout == 0 {
   228  		timeout = HTTPDefaultTimeout
   229  	} else {
   230  		timeout = config.Timeout
   231  	}
   232  
   233  	ret := &Client{
   234  		cli:    &http.Client{Timeout: timeout},
   235  		config: config,
   236  	}
   237  	if jar != nil {
   238  		ret.cli.Jar = jar
   239  	}
   240  	ret.cli.Transport = NewInstrumentedRoundTripper(g, InstrumentationTagFromRequest, &xprt)
   241  	return ret, nil
   242  }
   243  
   244  func ServerLookup(env *Env, mode RunMode) (string, error) {
   245  	if mode == DevelRunMode {
   246  		return DevelServerURI, nil
   247  	}
   248  	if mode == StagingRunMode {
   249  		return StagingServerURI, nil
   250  	}
   251  	if mode == ProductionRunMode {
   252  		if env.IsCertPinningEnabled() {
   253  			// In order to disable SSL pinning we switch to doing requests against keybase.io which has a TLS
   254  			// cert signed by a publicly trusted CA (compared to api-1.keybaseapi.com which has a non-trusted but
   255  			// pinned certificate
   256  			return ProductionServerURI, nil
   257  		}
   258  		return ProductionSiteURI, nil
   259  	}
   260  	return "", fmt.Errorf("Did not find a server to use with the current RunMode!")
   261  }
   262  
// InstrumentedBody wraps an HTTP response body, counts the bytes read through
// it, and reports the body size to a NetworkInstrumenter record when closed.
type InstrumentedBody struct {
	// MetaContextified supplies b.M(). Close uses it for debug logging and
	// for the context passed to record.Finish.
	MetaContextified
	// record is the instrumentation record that Close sizes and finishes.
	record *rpc.NetworkInstrumenter
	// body is the wrapped response body.
	body io.ReadCloser
	// track how large the body is: the total bytes returned by body.Read.
	n int
	// uncompressed indicates if the body was compressed on the wire but
	// uncompressed by the http library. In this case we recompress to
	// instrument the gzipped size.
	uncompressed bool
	// gzipBuf holds a copy of every byte read when uncompressed is true, so
	// Close can re-gzip it.
	gzipBuf bytes.Buffer
	// gzipGetter returns a gzip.Writer that writes to the given io.Writer,
	// plus a func that releases the writer once the caller is done with it.
	gzipGetter func(io.Writer) (*gzip.Writer, func())
}

// Compile-time check that *InstrumentedBody satisfies io.ReadCloser.
var _ io.ReadCloser = (*InstrumentedBody)(nil)
   278  
   279  func NewInstrumentedBody(mctx MetaContext, record *rpc.NetworkInstrumenter, body io.ReadCloser, uncompressed bool,
   280  	gzipGetter func(io.Writer) (*gzip.Writer, func())) *InstrumentedBody {
   281  	return &InstrumentedBody{
   282  		MetaContextified: NewMetaContextified(mctx),
   283  		record:           record,
   284  		body:             body,
   285  		gzipGetter:       gzipGetter,
   286  		uncompressed:     uncompressed,
   287  	}
   288  }
   289  
   290  func (b *InstrumentedBody) Read(p []byte) (n int, err error) {
   291  	n, err = b.body.Read(p)
   292  	b.n += n
   293  	if b.uncompressed && n > 0 {
   294  		if n, err := b.gzipBuf.Write(p[:n]); err != nil {
   295  			return n, err
   296  		}
   297  	}
   298  	return n, err
   299  }
   300  
   301  func (b *InstrumentedBody) Close() (err error) {
   302  	// instrument the full body size even if the caller hasn't consumed it.
   303  	_, _ = io.Copy(io.Discard, b.body)
   304  	// Do actual instrumentation in the background
   305  	go func() {
   306  		if b.uncompressed {
   307  			// gzip the body we stored and instrument the compressed size
   308  			var buf bytes.Buffer
   309  			writer, reclaim := b.gzipGetter(&buf)
   310  			defer reclaim()
   311  			if _, err = writer.Write(b.gzipBuf.Bytes()); err != nil {
   312  				b.M().Debug("InstrumentedBody:unable to write gzip %v", err)
   313  				return
   314  			}
   315  			if err = writer.Close(); err != nil {
   316  				b.M().Debug("InstrumentedBody:unable to close gzip %v", err)
   317  				return
   318  			}
   319  			b.record.IncrementSize(int64(buf.Len()))
   320  		} else {
   321  			b.record.IncrementSize(int64(b.n))
   322  		}
   323  		if err := b.record.Finish(b.M().Ctx()); err != nil {
   324  			b.M().Debug("InstrumentedBody: unable to instrument network request: %s, %s", b.record, err)
   325  		}
   326  	}()
   327  	return b.body.Close()
   328  }
   329  
// InstrumentedRoundTripper is an http.RoundTripper that delegates to
// RoundTripper and records each call with an rpc.NetworkInstrumenter.
// It contains a sync.Pool, which must not be copied after first use, so
// always handle it by pointer.
type InstrumentedRoundTripper struct {
	// Contextified supplies i.G().
	Contextified
	// RoundTripper performs the actual HTTP round trips.
	RoundTripper http.RoundTripper
	// tagger derives the instrumentation tag for a request.
	tagger func(*http.Request) string
	// gzipPool caches gzip.Writers that InstrumentedBody.Close uses to
	// re-compress transparently decompressed bodies.
	gzipPool sync.Pool
}

// Compile-time check that *InstrumentedRoundTripper satisfies http.RoundTripper.
var _ http.RoundTripper = (*InstrumentedRoundTripper)(nil)
   338  
   339  func NewInstrumentedRoundTripper(g *GlobalContext, tagger func(*http.Request) string, xprt http.RoundTripper) *InstrumentedRoundTripper {
   340  	return &InstrumentedRoundTripper{
   341  		Contextified: NewContextified(g),
   342  		RoundTripper: xprt,
   343  		tagger:       tagger,
   344  		gzipPool: sync.Pool{
   345  			New: func() interface{} {
   346  				return gzip.NewWriter(io.Discard)
   347  			},
   348  		},
   349  	}
   350  }
   351  
   352  func (i *InstrumentedRoundTripper) getGzipWriter(writer io.Writer) (*gzip.Writer, func()) {
   353  	gzipWriter := i.gzipPool.Get().(*gzip.Writer)
   354  	gzipWriter.Reset(writer)
   355  	return gzipWriter, func() {
   356  		i.gzipPool.Put(gzipWriter)
   357  	}
   358  }
   359  
   360  func (i *InstrumentedRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) {
   361  	tags := LogTagsFromString(req.Header.Get("X-Keybase-Log-Tags"))
   362  	mctx := NewMetaContextTODO(i.G()).WithLogTags(tags)
   363  	record := rpc.NewNetworkInstrumenter(i.G().RemoteNetworkInstrumenterStorage, i.tagger(req))
   364  	resp, err = i.RoundTripper.RoundTrip(req)
   365  	record.EndCall()
   366  	if err != nil {
   367  		if rerr := record.Finish(mctx.Ctx()); rerr != nil {
   368  			mctx.Debug("InstrumentedTransport: unable to instrument network request %s, %s", record, rerr)
   369  		}
   370  		return resp, err
   371  	}
   372  	resp.Body = NewInstrumentedBody(mctx, record, resp.Body, resp.Uncompressed, i.getGzipWriter)
   373  	return resp, err
   374  }