github.com/pingcap/br@v5.3.0-alpha.0.20220125034240-ec59c7b6ce30+incompatible/pkg/lightning/common/security.go (about)

     1  // Copyright 2020 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package common
    15  
    16  import (
    17  	"context"
    18  	"crypto/tls"
    19  	"crypto/x509"
    20  	"net"
    21  	"net/http"
    22  	"net/http/httptest"
    23  	"os"
    24  
    25  	"github.com/pingcap/errors"
    26  	pd "github.com/tikv/pd/client"
    27  	"google.golang.org/grpc"
    28  	"google.golang.org/grpc/credentials"
    29  
    30  	"github.com/pingcap/br/pkg/httputil"
    31  )
    32  
    33  // TLS
    34  type TLS struct {
    35  	caPath   string
    36  	certPath string
    37  	keyPath  string
    38  	inner    *tls.Config
    39  	client   *http.Client
    40  	url      string
    41  }
    42  
    43  // ToTLSConfig constructs a `*tls.Config` from the CA, certification and key
    44  // paths.
    45  //
    46  // If the CA path is empty, returns nil.
    47  func ToTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
    48  	if len(caPath) == 0 {
    49  		return nil, nil
    50  	}
    51  
    52  	// Load the client certificates from disk
    53  	var certificates []tls.Certificate
    54  	if len(certPath) != 0 && len(keyPath) != 0 {
    55  		cert, err := tls.LoadX509KeyPair(certPath, keyPath)
    56  		if err != nil {
    57  			return nil, errors.Annotate(err, "could not load client key pair")
    58  		}
    59  		certificates = []tls.Certificate{cert}
    60  	}
    61  
    62  	// Create a certificate pool from CA
    63  	certPool := x509.NewCertPool()
    64  	ca, err := os.ReadFile(caPath)
    65  	if err != nil {
    66  		return nil, errors.Annotate(err, "could not read ca certificate")
    67  	}
    68  
    69  	// Append the certificates from the CA
    70  	if !certPool.AppendCertsFromPEM(ca) {
    71  		return nil, errors.New("failed to append ca certs")
    72  	}
    73  
    74  	return &tls.Config{
    75  		Certificates: certificates,
    76  		RootCAs:      certPool,
    77  		NextProtos:   []string{"h2", "http/1.1"}, // specify `h2` to let Go use HTTP/2.
    78  	}, nil
    79  }
    80  
    81  // NewTLS constructs a new HTTP client with TLS configured with the CA,
    82  // certificate and key paths.
    83  //
    84  // If the CA path is empty, returns an instance where TLS is disabled.
    85  func NewTLS(caPath, certPath, keyPath, host string) (*TLS, error) {
    86  	if len(caPath) == 0 {
    87  		return &TLS{
    88  			inner:  nil,
    89  			client: &http.Client{},
    90  			url:    "http://" + host,
    91  		}, nil
    92  	}
    93  	inner, err := ToTLSConfig(caPath, certPath, keyPath)
    94  	if err != nil {
    95  		return nil, errors.Trace(err)
    96  	}
    97  	return &TLS{
    98  		caPath:   caPath,
    99  		certPath: certPath,
   100  		keyPath:  keyPath,
   101  		inner:    inner,
   102  		client:   httputil.NewClient(inner),
   103  		url:      "https://" + host,
   104  	}, nil
   105  }
   106  
   107  // NewTLSFromMockServer constructs a new TLS instance from the certificates of
   108  // an *httptest.Server.
   109  func NewTLSFromMockServer(server *httptest.Server) *TLS {
   110  	return &TLS{
   111  		inner:  server.TLS,
   112  		client: server.Client(),
   113  		url:    server.URL,
   114  	}
   115  }
   116  
   117  // WithHost creates a new TLS instance with the host replaced.
   118  func (tc *TLS) WithHost(host string) *TLS {
   119  	var url string
   120  	if tc.inner != nil {
   121  		url = "https://" + host
   122  	} else {
   123  		url = "http://" + host
   124  	}
   125  	return &TLS{
   126  		inner:  tc.inner,
   127  		client: tc.client,
   128  		url:    url,
   129  	}
   130  }
   131  
   132  // ToGRPCDialOption constructs a gRPC dial option.
   133  func (tc *TLS) ToGRPCDialOption() grpc.DialOption {
   134  	if tc.inner != nil {
   135  		return grpc.WithTransportCredentials(credentials.NewTLS(tc.inner))
   136  	}
   137  	return grpc.WithInsecure()
   138  }
   139  
   140  // WrapListener places a TLS layer on top of the existing listener.
   141  func (tc *TLS) WrapListener(l net.Listener) net.Listener {
   142  	if tc.inner == nil {
   143  		return l
   144  	}
   145  	return tls.NewListener(l, tc.inner)
   146  }
   147  
   148  func (tc *TLS) GetJSON(ctx context.Context, path string, v interface{}) error {
   149  	return GetJSON(ctx, tc.client, tc.url+path, v)
   150  }
   151  
   152  func (tc *TLS) ToPDSecurityOption() pd.SecurityOption {
   153  	return pd.SecurityOption{
   154  		CAPath:   tc.caPath,
   155  		CertPath: tc.certPath,
   156  		KeyPath:  tc.keyPath,
   157  	}
   158  }
   159  
   160  func (tc *TLS) TLSConfig() *tls.Config {
   161  	return tc.inner
   162  }