github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/applicationproxy/tcp/tcp_test.go (about)

     1  package tcp
     2  
     3  import (
     4  	"crypto/tls"
     5  	"crypto/x509"
     6  	"log"
     7  	"reflect"
     8  	"testing"
     9  
    10  	"go.aporeto.io/enforcerd/trireme-lib/common"
    11  	acommon "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/common"
    12  	"go.aporeto.io/enforcerd/trireme-lib/policy"
    13  )
    14  
    // testTLSCertificate returns a tls.Certificate built from a hard-coded,
// PEM-encoded certificate and its matching EC private key, for use as a
// client-certificate fixture in the tests below.
//
// NOTE(review): the embedded certificate's validity window appears to end in
// 2018; tls.X509KeyPair only pairs the key with the certificate and does not
// check validity, so this is fine as long as no test verifies it — confirm.
func testTLSCertificate() tls.Certificate {
	// certificatePEM is a self-signed certificate issued to "Acme Co".
	const certificatePEM = `-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`

	// privateKeyPEM is the EC private key that pairs with certificatePEM.
	const privateKeyPEM = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`

	pair, parseErr := tls.X509KeyPair([]byte(certificatePEM), []byte(privateKeyPEM))
	if parseErr != nil {
		// The fixture is compiled in, so a parse failure is a programming
		// error; log.Fatal aborts the whole test binary.
		log.Fatal(parseErr)
	}
	return pair
}
    38  
    39  func Test_getClientTLSConfig(t *testing.T) {
    40  	type args struct {
    41  		caPool      *x509.CertPool
    42  		clientCerts []tls.Certificate
    43  		serverName  string
    44  		external    bool
    45  	}
    46  	basicCaPool, _ := x509.SystemCertPool()
    47  	basicTLSCert := testTLSCertificate()
    48  	basicTLSCertList := []tls.Certificate{basicTLSCert}
    49  	tests := []struct {
    50  		name    string
    51  		args    args
    52  		wantT   *tls.Config
    53  		wantErr bool
    54  	}{
    55  		{
    56  			name: "basic external service",
    57  			args: args{
    58  				external:    true,
    59  				caPool:      nil,                 // no caPool => we dont need additional CAs to validate server certs. they might be using digicert/letsencrypt/any std cert.
    60  				clientCerts: []tls.Certificate{}, // no certs. for external service dont use client certs
    61  				serverName:  "www.google.com",
    62  			},
    63  			wantT: &tls.Config{
    64  				PreferServerCipherSuites: true,
    65  				SessionTicketsDisabled:   true,
    66  				MaxVersion:               tls.VersionTLS12,
    67  				ServerName:               "www.google.com",
    68  			},
    69  			wantErr: false,
    70  		},
    71  		{
    72  			name: "basic external service ignored client certs",
    73  			args: args{
    74  				external:    true,
    75  				caPool:      nil,              // no caPool => we dont need additional CAs to validate server certs. they might be using digicert/letsencrypt/any std cert.
    76  				clientCerts: basicTLSCertList, // clientCerts should be ignored for external service.
    77  				serverName:  "www.google.com",
    78  			},
    79  			wantT: &tls.Config{
    80  				PreferServerCipherSuites: true,
    81  				SessionTicketsDisabled:   true,
    82  				MaxVersion:               tls.VersionTLS12,
    83  				ServerName:               "www.google.com",
    84  			},
    85  			wantErr: false,
    86  		},
    87  		{
    88  			name: "basic external service with trusted ca pool and ignored cert list",
    89  			args: args{
    90  				caPool:      basicCaPool,      // caPool should be used to validate server certs
    91  				clientCerts: basicTLSCertList, // should be ignored as we dont provide client certs for external service
    92  				serverName:  "www.google.com",
    93  				external:    true,
    94  			},
    95  			wantT: &tls.Config{
    96  				PreferServerCipherSuites: true,
    97  				SessionTicketsDisabled:   true,
    98  				MaxVersion:               tls.VersionTLS12,
    99  				RootCAs:                  basicCaPool,
   100  				ServerName:               "www.google.com",
   101  			},
   102  			wantErr: false,
   103  		},
   104  	}
   105  	for _, tt := range tests {
   106  		t.Run(tt.name, func(t *testing.T) {
   107  			gotT, err := getClientTLSConfig(tt.args.caPool, tt.args.clientCerts, tt.args.serverName, tt.args.external)
   108  			if (err != nil) != tt.wantErr {
   109  				t.Errorf("getClientTLSConfig() error = %v, wantErr %v", err, tt.wantErr)
   110  				return
   111  			}
   112  			if !reflect.DeepEqual(gotT, tt.wantT) {
   113  				t.Errorf("getClientTLSConfig() = %+v, want %+v", gotT, tt.wantT)
   114  			}
   115  		})
   116  	}
   117  }
   118  
   119  func Test_getTLSServerName(t *testing.T) {
   120  	type args struct {
   121  		addrAndPort string
   122  		service     *policy.ApplicationService
   123  	}
   124  	tests := []struct {
   125  		name     string
   126  		args     args
   127  		wantName string
   128  		wantErr  bool
   129  	}{
   130  		{
   131  			name:     "nil service and bad addr (error)",
   132  			args:     args{},
   133  			wantName: "",
   134  			wantErr:  true,
   135  		},
   136  		{
   137  			name: "service with nil network info and bad addr (error)",
   138  			args: args{
   139  				service: &policy.ApplicationService{},
   140  			},
   141  			wantName: "",
   142  			wantErr:  true,
   143  		},
   144  		{
   145  			name: "no fqdn and bad addr (error)",
   146  			args: args{
   147  				service: &policy.ApplicationService{
   148  					NetworkInfo: &common.Service{
   149  						FQDNs: []string{},
   150  					},
   151  				},
   152  			},
   153  			wantName: "",
   154  			wantErr:  true,
   155  		},
   156  		{
   157  			name: "no fqdn and valid addr (success)",
   158  			args: args{
   159  				addrAndPort: "dns:80",
   160  				service: &policy.ApplicationService{
   161  					NetworkInfo: &common.Service{
   162  						FQDNs: []string{},
   163  					},
   164  				},
   165  			},
   166  			wantName: "dns",
   167  			wantErr:  false,
   168  		},
   169  		{
   170  			name: "fqdn and valid addr use fqdn[0]",
   171  			args: args{
   172  				addrAndPort: "dns:80",
   173  				service: &policy.ApplicationService{
   174  					NetworkInfo: &common.Service{
   175  						FQDNs: []string{"www.google.com", "alt.google.com"},
   176  					},
   177  				},
   178  			},
   179  			wantName: "www.google.com",
   180  			wantErr:  false,
   181  		},
   182  	}
   183  	for _, tt := range tests {
   184  		t.Run(tt.name, func(t *testing.T) {
   185  			gotName, err := acommon.GetTLSServerName(tt.args.addrAndPort, tt.args.service)
   186  			if (err != nil) != tt.wantErr {
   187  				t.Errorf("getTLSServerName() error = %v, wantErr %v", err, tt.wantErr)
   188  				return
   189  			}
   190  			if gotName != tt.wantName {
   191  				t.Errorf("getTLSServerName() = %v, want %v", gotName, tt.wantName)
   192  			}
   193  		})
   194  	}
   195  }