github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/mutualauth/main.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"crypto/rsa"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"crypto/x509/pkix"
    10  	"encoding/pem"
    11  	"flag"
    12  	"fmt"
    13  	"io/ioutil"
    14  	"log"
    15  	"math/big"
    16  	"net"
    17  	"net/http"
    18  	"path"
    19  	"path/filepath"
    20  	"time"
    21  )
    22  
    23  var (
    24  	isClient = flag.Bool("client", false, "Is the client")
    25  	addr     = flag.String("addr", "localhost:8080", "`address`")
    26  )
    27  
    28  func main() {
    29  	flag.Parse()
    30  
    31  	if filepath.Ext(flag.Arg(0)) == ".pem" {
    32  		if *isClient {
    33  			log.Fatal(NewCert(flag.Arg(0)))
    34  		} else {
    35  			log.Fatal(NewCert(flag.Arg(0)))
    36  		}
    37  	}
    38  
    39  	if *isClient {
    40  		log.Fatal(Client())
    41  	} else {
    42  		log.Fatal(Server())
    43  	}
    44  }
    45  
    46  func echo(rw http.ResponseWriter, req *http.Request) {
    47  	fmt.Fprintf(rw, "%v", req.URL)
    48  }
    49  
    50  func Client() error {
    51  	config := &tls.Config{
    52  		RootCAs: x509.NewCertPool(),
    53  	}
    54  
    55  	// load the client certificates
    56  	clientCert, err := tls.LoadX509KeyPair("client.pem", "client.pem")
    57  	if err != nil {
    58  		return err
    59  	}
    60  	config.Certificates = append(config.Certificates, clientCert)
    61  
    62  	// load list of verified servers
    63  	validServers, err := ioutil.ReadFile("server.pem")
    64  	if err != nil {
    65  		return err
    66  	}
    67  	config.RootCAs.AppendCertsFromPEM(validServers)
    68  
    69  	client := &http.Client{
    70  		Transport: &http.Transport{
    71  			TLSClientConfig: config,
    72  		},
    73  	}
    74  
    75  	resp, err := client.Get("https://" + path.Join(*addr, "hello", "world"))
    76  	if err != nil {
    77  		return err
    78  	}
    79  	defer resp.Body.Close()
    80  	body, err := ioutil.ReadAll(resp.Body)
    81  	if err != nil {
    82  		return err
    83  	}
    84  	log.Println("RESPONSE", string(body))
    85  	return nil
    86  }
    87  
    88  func Server() error {
    89  	config := &tls.Config{
    90  		ClientAuth: tls.RequireAndVerifyClientCert,
    91  
    92  		ClientCAs: x509.NewCertPool(),
    93  	}
    94  
    95  	// load the server certificates
    96  	serverCert, err := tls.LoadX509KeyPair("server.pem", "server.pem")
    97  	if err != nil {
    98  		return err
    99  	}
   100  	config.Certificates = append(config.Certificates, serverCert)
   101  
   102  	// load the verified clients
   103  	clientCAs, err := ioutil.ReadFile("client.pem")
   104  	if err != nil {
   105  		return err
   106  	}
   107  	config.ClientCAs.AppendCertsFromPEM(clientCAs)
   108  
   109  	// create the server
   110  	server := &http.Server{
   111  		Addr:      *addr,
   112  		Handler:   http.HandlerFunc(echo),
   113  		TLSConfig: config,
   114  	}
   115  
   116  	// setup tcp server
   117  	inner, err := net.Listen("tcp", server.Addr)
   118  	if err != nil {
   119  		return err
   120  	}
   121  
   122  	// start serving
   123  	listener := tls.NewListener(tcpKeepAliveListener{inner.(*net.TCPListener)}, config)
   124  	return server.Serve(listener)
   125  }
   126  
   127  // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
   128  // connections. It's used by ListenAndServe and ListenAndServeTLS so
   129  // dead TCP connections (e.g. closing laptop mid-download) eventually
   130  // go away.
   131  type tcpKeepAliveListener struct {
   132  	*net.TCPListener
   133  }
   134  
   135  func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
   136  	tc, err := ln.AcceptTCP()
   137  	if err != nil {
   138  		return
   139  	}
   140  	tc.SetKeepAlive(true)
   141  	tc.SetKeepAlivePeriod(3 * time.Minute)
   142  	return tc, nil
   143  }
   144  
   145  func NewCert(pemfile string) error {
   146  	priv, err := rsa.GenerateKey(rand.Reader, 4096)
   147  	if err != nil {
   148  		return fmt.Errorf("error generating new key: %s", err)
   149  	}
   150  
   151  	notBefore := time.Now()
   152  	notAfter := notBefore.Add(5 * 365 * 24 * time.Hour) // 5 years
   153  
   154  	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
   155  	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
   156  	if err != nil {
   157  		return fmt.Errorf("failed to generate serial number: %s", err)
   158  	}
   159  
   160  	template := x509.Certificate{
   161  		SerialNumber: serialNumber,
   162  		Subject: pkix.Name{
   163  			Organization: []string{"Hello"},
   164  		},
   165  		NotBefore: notBefore,
   166  		NotAfter:  notAfter,
   167  
   168  		KeyUsage: x509.KeyUsageKeyEncipherment |
   169  			x509.KeyUsageDigitalSignature |
   170  			x509.KeyUsageDataEncipherment |
   171  			x509.KeyUsageCertSign,
   172  		ExtKeyUsage: []x509.ExtKeyUsage{
   173  			x509.ExtKeyUsageServerAuth,
   174  			x509.ExtKeyUsageClientAuth,
   175  		},
   176  
   177  		IsCA: true,
   178  		BasicConstraintsValid: true,
   179  
   180  		DNSNames: []string{"localhost"},
   181  	}
   182  
   183  	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
   184  	if err != nil {
   185  		return fmt.Errorf("Failed to create certificate: %s", err)
   186  	}
   187  
   188  	var pemOut bytes.Buffer
   189  	pem.Encode(&pemOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
   190  	pem.Encode(&pemOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
   191  
   192  	return ioutil.WriteFile(pemfile, pemOut.Bytes(), 0644)
   193  }