github.com/Cloud-Foundations/Dominator@v0.3.4/lib/srpc/setupserver/impl.go (about)

     1  package setupserver
     2  
     3  import (
     4  	"crypto/tls"
     5  	"crypto/x509"
     6  	"errors"
     7  	"flag"
     8  	"fmt"
     9  	"io/ioutil"
    10  	"os"
    11  	"path"
    12  	"time"
    13  
    14  	"github.com/Cloud-Foundations/Dominator/lib/format"
    15  	"github.com/Cloud-Foundations/Dominator/lib/log/nulllogger"
    16  	"github.com/Cloud-Foundations/Dominator/lib/srpc"
    17  )
    18  
    19  var (
    20  	caFile = flag.String("CAfile", "/etc/ssl/CA.pem",
    21  		"Name of file containing the root of trust for identity and methods")
    22  	certFile = flag.String("certFile",
    23  		path.Join("/etc/ssl", getDirname(), "cert.pem"),
    24  		"Name of file containing the SSL certificate")
    25  	identityCaFile = flag.String("identityCAfile", "/etc/ssl/IdentityCA.pem",
    26  		"Name of file containing the root of trust for identity only")
    27  	keyFile = flag.String("keyFile",
    28  		path.Join("/etc/ssl", getDirname(), "key.pem"),
    29  		"Name of file containing the SSL key")
    30  )
    31  
    32  func getDirname() string {
    33  	return path.Base(os.Args[0])
    34  }
    35  
    36  func getSleepInterval(cert *x509.Certificate) time.Duration {
    37  	day := 24 * time.Hour
    38  	week := 7 * day
    39  	lifetime := cert.NotAfter.Sub(cert.NotBefore)
    40  	refreshIn := time.Until(cert.NotBefore.Add(7 * lifetime >> 3))
    41  	if refreshIn > 0 {
    42  		return refreshIn
    43  	}
    44  	expiresIn := time.Until(cert.NotAfter)
    45  	if expiresIn > 2*week {
    46  		return week
    47  	} else if expiresIn > 2*day {
    48  		return day
    49  	} else if expiresIn > 2*time.Hour {
    50  		return time.Hour
    51  	} else if expiresIn > 2*time.Minute {
    52  		return time.Minute
    53  	} else {
    54  		return 5 * time.Second
    55  	}
    56  }
    57  
    58  func loadClientCert(params Params) (*tls.Certificate, error) {
    59  	// Load certificate and key.
    60  	if *certFile == "" || *keyFile == "" {
    61  		cert, err := srpc.LoadCertificatesFromMetadata(100*time.Millisecond,
    62  			true, false)
    63  		if err != nil {
    64  			return nil, err
    65  		}
    66  		params.Logger.Debugln(0,
    67  			"Loaded certifcate and key from metadata service\n")
    68  		return cert, nil
    69  	}
    70  	cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
    71  	if err != nil {
    72  		if os.IsNotExist(err) {
    73  			cert, e := srpc.LoadCertificatesFromMetadata(100*time.Millisecond,
    74  				true, false)
    75  			if e != nil {
    76  				return nil, err
    77  			}
    78  			params.Logger.Debugln(0,
    79  				"Loaded certifcate and key from metadata service\n")
    80  			return cert, nil
    81  		}
    82  		return nil, fmt.Errorf("unable to load keypair: %s", err)
    83  	}
    84  	x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  	cert.Leaf = x509Cert
    89  	params.Logger.Debugf(0, "Loaded certifcate and key from: %s and %s\n",
    90  		*certFile, *keyFile)
    91  	return &cert, nil
    92  }
    93  
    94  func loadLoop(params Params, cert *x509.Certificate) {
    95  	params.FailIfExpired = true
    96  	for {
    97  		time.Sleep(getSleepInterval(cert))
    98  		if c, err := setupTlsOnce(params); err != nil {
    99  			params.Logger.Println(err)
   100  		} else {
   101  			cert = c
   102  		}
   103  	}
   104  }
   105  
   106  func setupTls(params Params) error {
   107  	if params.Logger == nil {
   108  		params.Logger = nulllogger.New()
   109  	}
   110  	cert, err := setupTlsOnce(params)
   111  	if err != nil {
   112  		return err
   113  	}
   114  	go loadLoop(params, cert)
   115  	return nil
   116  }
   117  
   118  func setupTlsOnce(params Params) (*x509.Certificate, error) {
   119  	// Setup client.
   120  	tlsCert, err := loadClientCert(params)
   121  	if err != nil {
   122  		return nil, fmt.Errorf("unable to load keypair: %s", err)
   123  	}
   124  	now := time.Now()
   125  	x509Cert := tlsCert.Leaf
   126  	if notYet := x509Cert.NotBefore.Sub(now); notYet > 0 {
   127  		msg := fmt.Sprintf("%s will not be valid for %s",
   128  			*certFile, format.Duration(notYet))
   129  		if params.FailIfExpired {
   130  			return nil, errors.New(msg)
   131  		}
   132  		params.Logger.Println(msg)
   133  	} else if expired := now.Sub(x509Cert.NotAfter); expired > 0 {
   134  		msg := fmt.Sprintf("%s expired %s ago",
   135  			*certFile, format.Duration(expired))
   136  		if params.FailIfExpired {
   137  			return nil, errors.New(msg)
   138  		}
   139  		params.Logger.Println(msg)
   140  	} else {
   141  		params.Logger.Debugf(0, "Certificate expires at: %s (%s)\n",
   142  			x509Cert.NotAfter.Local(),
   143  			format.Duration(time.Until(x509Cert.NotAfter)))
   144  	}
   145  	clientConfig := new(tls.Config)
   146  	clientConfig.InsecureSkipVerify = true
   147  	clientConfig.MinVersion = tls.VersionTLS12
   148  	clientConfig.Certificates = append(clientConfig.Certificates, *tlsCert)
   149  	srpc.RegisterClientTlsConfig(clientConfig)
   150  	if !params.ClientOnly {
   151  		if *caFile == "" {
   152  			return nil, srpc.ErrorMissingCA
   153  		}
   154  		caData, err := ioutil.ReadFile(*caFile)
   155  		if err != nil {
   156  			if os.IsNotExist(err) {
   157  				return nil, srpc.ErrorMissingCA
   158  			}
   159  			return nil, fmt.Errorf("unable to load CA file: \"%s\": %s",
   160  				*caFile, err)
   161  		}
   162  		caCertPool := x509.NewCertPool()
   163  		if !caCertPool.AppendCertsFromPEM(caData) {
   164  			return nil, fmt.Errorf("unable to parse CA file")
   165  		}
   166  		serverConfig := new(tls.Config)
   167  		serverConfig.ClientAuth = tls.RequireAndVerifyClientCert
   168  		serverConfig.MinVersion = tls.VersionTLS12
   169  		serverConfig.ClientCAs = caCertPool
   170  		serverConfig.Certificates = append(serverConfig.Certificates, *tlsCert)
   171  		if *identityCaFile != "" {
   172  			identityCaData, err := ioutil.ReadFile(*identityCaFile)
   173  			if err != nil {
   174  				if !os.IsNotExist(err) {
   175  					return nil, fmt.Errorf("unable to load CA file: \"%s\": %s",
   176  						*caFile, err)
   177  				}
   178  			} else {
   179  				srpc.RegisterFullAuthCA(caCertPool)
   180  				caCertPool := x509.NewCertPool()
   181  				if !caCertPool.AppendCertsFromPEM(caData) {
   182  					return nil, fmt.Errorf("unable to parse CA file")
   183  				}
   184  				if !caCertPool.AppendCertsFromPEM(identityCaData) {
   185  					return nil, fmt.Errorf("unable to parse identity CA file")
   186  				}
   187  				serverConfig.ClientCAs = caCertPool
   188  			}
   189  		}
   190  		srpc.RegisterServerTlsConfig(serverConfig, true)
   191  	}
   192  	return tlsCert.Leaf, nil
   193  }