github.com/pingcap/tidb-lightning@v5.0.0-rc.0.20210428090220-84b649866577+incompatible/lightning/common/security.go (about)

     1  // Copyright 2020 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package common
    15  
    16  import (
    17  	"context"
    18  	"crypto/tls"
    19  	"crypto/x509"
    20  	"io/ioutil"
    21  	"net"
    22  	"net/http"
    23  	"net/http/httptest"
    24  
    25  	"github.com/pingcap/errors"
    26  	pd "github.com/tikv/pd/client"
    27  	"google.golang.org/grpc"
    28  	"google.golang.org/grpc/credentials"
    29  )
    30  
// TLS describes how to reach one endpoint, with or without TLS.
//
// It records:
//   - the CA/certificate/key paths it was configured with (empty when TLS is
//     disabled);
//   - the parsed configuration `inner`, which is nil when TLS is disabled;
//   - the HTTP client used by GetJSON;
//   - the base URL ("http://host" or "https://host") that GetJSON prefixes to
//     request paths.
//
// A nil `inner` selects plain HTTP in WithHost, insecure gRPC in
// ToGRPCDialOption and a pass-through listener in WrapListener.
type TLS struct {
	caPath   string
	certPath string
	keyPath  string
	inner    *tls.Config
	client   *http.Client
	url      string
}
    40  
    41  // ToTLSConfig constructs a `*tls.Config` from the CA, certification and key
    42  // paths.
    43  //
    44  // If the CA path is empty, returns nil.
    45  func ToTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
    46  	if len(caPath) == 0 {
    47  		return nil, nil
    48  	}
    49  
    50  	// Load the client certificates from disk
    51  	var certificates []tls.Certificate
    52  	if len(certPath) != 0 && len(keyPath) != 0 {
    53  		cert, err := tls.LoadX509KeyPair(certPath, keyPath)
    54  		if err != nil {
    55  			return nil, errors.Annotate(err, "could not load client key pair")
    56  		}
    57  		certificates = []tls.Certificate{cert}
    58  	}
    59  
    60  	// Create a certificate pool from CA
    61  	certPool := x509.NewCertPool()
    62  	ca, err := ioutil.ReadFile(caPath)
    63  	if err != nil {
    64  		return nil, errors.Annotate(err, "could not read ca certificate")
    65  	}
    66  
    67  	// Append the certificates from the CA
    68  	if !certPool.AppendCertsFromPEM(ca) {
    69  		return nil, errors.New("failed to append ca certs")
    70  	}
    71  
    72  	return &tls.Config{
    73  		Certificates: certificates,
    74  		RootCAs:      certPool,
    75  		NextProtos:   []string{"h2", "http/1.1"}, // specify `h2` to let Go use HTTP/2.
    76  	}, nil
    77  }
    78  
    79  // NewTLS constructs a new HTTP client with TLS configured with the CA,
    80  // certificate and key paths.
    81  //
    82  // If the CA path is empty, returns an instance where TLS is disabled.
    83  func NewTLS(caPath, certPath, keyPath, host string) (*TLS, error) {
    84  	if len(caPath) == 0 {
    85  		return &TLS{
    86  			inner:  nil,
    87  			client: &http.Client{},
    88  			url:    "http://" + host,
    89  		}, nil
    90  	}
    91  	inner, err := ToTLSConfig(caPath, certPath, keyPath)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	transport := http.DefaultTransport.(*http.Transport).Clone()
    96  	transport.TLSClientConfig = inner
    97  	return &TLS{
    98  		caPath:   caPath,
    99  		certPath: certPath,
   100  		keyPath:  keyPath,
   101  		inner:    inner,
   102  		client:   &http.Client{Transport: transport},
   103  		url:      "https://" + host,
   104  	}, nil
   105  }
   106  
   107  // NewTLSFromMockServer constructs a new TLS instance from the certificates of
   108  // an *httptest.Server.
   109  func NewTLSFromMockServer(server *httptest.Server) *TLS {
   110  	return &TLS{
   111  		inner:  server.TLS,
   112  		client: server.Client(),
   113  		url:    server.URL,
   114  	}
   115  }
   116  
   117  // WithHost creates a new TLS instance with the host replaced.
   118  func (tc *TLS) WithHost(host string) *TLS {
   119  	var url string
   120  	if tc.inner != nil {
   121  		url = "https://" + host
   122  	} else {
   123  		url = "http://" + host
   124  	}
   125  	return &TLS{
   126  		inner:  tc.inner,
   127  		client: tc.client,
   128  		url:    url,
   129  	}
   130  }
   131  
   132  // ToGRPCDialOption constructs a gRPC dial option.
   133  func (tc *TLS) ToGRPCDialOption() grpc.DialOption {
   134  	if tc.inner != nil {
   135  		return grpc.WithTransportCredentials(credentials.NewTLS(tc.inner))
   136  	}
   137  	return grpc.WithInsecure()
   138  }
   139  
   140  // WrapListener places a TLS layer on top of the existing listener.
   141  func (tc *TLS) WrapListener(l net.Listener) net.Listener {
   142  	if tc.inner == nil {
   143  		return l
   144  	}
   145  	return tls.NewListener(l, tc.inner)
   146  }
   147  
   148  func (tc *TLS) GetJSON(ctx context.Context, path string, v interface{}) error {
   149  	return GetJSON(ctx, tc.client, tc.url+path, v)
   150  }
   151  
   152  func (tc *TLS) ToPDSecurityOption() pd.SecurityOption {
   153  	return pd.SecurityOption{
   154  		CAPath:   tc.caPath,
   155  		CertPath: tc.certPath,
   156  		KeyPath:  tc.keyPath,
   157  	}
   158  }
   159  
   160  func (tc *TLS) TLSConfig() *tls.Config {
   161  	return tc.inner
   162  }