github.com/rclone/rclone@v1.66.1-0.20240517100346-7b89735ae726/lib/http/server_test.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  	"testing"
    13  
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  func testEmptyHandler() http.Handler {
    18  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    19  }
    20  
    21  func testEchoHandler(data []byte) http.Handler {
    22  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    23  		_, _ = w.Write(data)
    24  	})
    25  }
    26  
    27  func testAuthUserHandler() http.Handler {
    28  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    29  		userID, ok := CtxGetUser(r.Context())
    30  		if !ok {
    31  			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
    32  		}
    33  		_, _ = w.Write([]byte(userID))
    34  	})
    35  }
    36  
    37  func testExpectRespBody(t *testing.T, resp *http.Response, expected []byte) {
    38  	body, err := io.ReadAll(resp.Body)
    39  	require.NoError(t, err)
    40  	require.Equal(t, expected, body)
    41  }
    42  
    43  func testGetServerURL(t *testing.T, s *Server) string {
    44  	urls := s.URLs()
    45  	require.GreaterOrEqual(t, len(urls), 1, "server should return at least one url")
    46  	return urls[0]
    47  }
    48  
    49  func testNewHTTPClientUnix(path string) *http.Client {
    50  	return &http.Client{
    51  		Transport: &http.Transport{
    52  			DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
    53  				return net.Dial("unix", path)
    54  			},
    55  		},
    56  	}
    57  }
    58  
    59  func testReadTestdataFile(t *testing.T, path string) []byte {
    60  	data, err := os.ReadFile(filepath.Join("./testdata", path))
    61  	require.NoError(t, err, "")
    62  	return data
    63  }
    64  
    65  func TestNewServerUnix(t *testing.T) {
    66  	ctx := context.Background()
    67  
    68  	tempDir := t.TempDir()
    69  	path := filepath.Join(tempDir, "rclone.sock")
    70  
    71  	cfg := DefaultCfg()
    72  	cfg.ListenAddr = []string{path}
    73  
    74  	auth := AuthConfig{
    75  		BasicUser: "test",
    76  		BasicPass: "test",
    77  	}
    78  
    79  	s, err := NewServer(ctx, WithConfig(cfg), WithAuth(auth))
    80  	require.NoError(t, err)
    81  	defer func() {
    82  		require.NoError(t, s.Shutdown())
    83  		_, err := os.Stat(path)
    84  		require.ErrorIs(t, err, os.ErrNotExist, "shutdown should remove socket")
    85  	}()
    86  
    87  	require.Empty(t, s.URLs(), "unix socket should not appear in URLs")
    88  
    89  	expected := []byte("hello world")
    90  	s.Router().Mount("/", testEchoHandler(expected))
    91  	s.Serve()
    92  
    93  	client := testNewHTTPClientUnix(path)
    94  	req, err := http.NewRequest("GET", "http://unix", nil)
    95  	require.NoError(t, err)
    96  
    97  	resp, err := client.Do(req)
    98  	require.NoError(t, err)
    99  
   100  	testExpectRespBody(t, resp, expected)
   101  
   102  	require.Equal(t, http.StatusOK, resp.StatusCode, "unix sockets should ignore auth")
   103  
   104  	for _, key := range _testCORSHeaderKeys {
   105  		require.NotContains(t, resp.Header, key, "unix sockets should not be sent CORS headers")
   106  	}
   107  }
   108  
   109  func TestNewServerHTTP(t *testing.T) {
   110  	ctx := context.Background()
   111  
   112  	cfg := DefaultCfg()
   113  	cfg.ListenAddr = []string{"127.0.0.1:0"}
   114  
   115  	auth := AuthConfig{
   116  		BasicUser: "test",
   117  		BasicPass: "test",
   118  	}
   119  
   120  	s, err := NewServer(ctx, WithConfig(cfg), WithAuth(auth))
   121  	require.NoError(t, err)
   122  	defer func() {
   123  		require.NoError(t, s.Shutdown())
   124  	}()
   125  
   126  	url := testGetServerURL(t, s)
   127  	require.True(t, strings.HasPrefix(url, "http://"), "url should have http scheme")
   128  
   129  	expected := []byte("hello world")
   130  	s.Router().Mount("/", testEchoHandler(expected))
   131  	s.Serve()
   132  
   133  	t.Run("StatusUnauthorized", func(t *testing.T) {
   134  		client := &http.Client{}
   135  		req, err := http.NewRequest("GET", url, nil)
   136  		require.NoError(t, err)
   137  
   138  		resp, err := client.Do(req)
   139  		require.NoError(t, err)
   140  		defer func() {
   141  			_ = resp.Body.Close()
   142  		}()
   143  
   144  		require.Equal(t, http.StatusUnauthorized, resp.StatusCode, "no basic auth creds should return unauthorized")
   145  	})
   146  
   147  	t.Run("StatusOK", func(t *testing.T) {
   148  		client := &http.Client{}
   149  		req, err := http.NewRequest("GET", url, nil)
   150  		require.NoError(t, err)
   151  
   152  		req.SetBasicAuth(auth.BasicUser, auth.BasicPass)
   153  
   154  		resp, err := client.Do(req)
   155  		require.NoError(t, err)
   156  		defer func() {
   157  			_ = resp.Body.Close()
   158  		}()
   159  
   160  		require.Equal(t, http.StatusOK, resp.StatusCode, "using basic auth creds should return ok")
   161  
   162  		testExpectRespBody(t, resp, expected)
   163  	})
   164  }
   165  func TestNewServerBaseURL(t *testing.T) {
   166  	servers := []struct {
   167  		name   string
   168  		cfg    Config
   169  		suffix string
   170  	}{
   171  		{
   172  			name: "Empty",
   173  			cfg: Config{
   174  				ListenAddr: []string{"127.0.0.1:0"},
   175  				BaseURL:    "",
   176  			},
   177  			suffix: "/",
   178  		},
   179  		{
   180  			name: "Single/NoTrailingSlash",
   181  			cfg: Config{
   182  				ListenAddr: []string{"127.0.0.1:0"},
   183  				BaseURL:    "/rclone",
   184  			},
   185  			suffix: "/rclone/",
   186  		},
   187  		{
   188  			name: "Single/TrailingSlash",
   189  			cfg: Config{
   190  				ListenAddr: []string{"127.0.0.1:0"},
   191  				BaseURL:    "/rclone/",
   192  			},
   193  			suffix: "/rclone/",
   194  		},
   195  		{
   196  			name: "Multi/NoTrailingSlash",
   197  			cfg: Config{
   198  				ListenAddr: []string{"127.0.0.1:0"},
   199  				BaseURL:    "/rclone/test/base/url",
   200  			},
   201  			suffix: "/rclone/test/base/url/",
   202  		},
   203  		{
   204  			name: "Multi/TrailingSlash",
   205  			cfg: Config{
   206  				ListenAddr: []string{"127.0.0.1:0"},
   207  				BaseURL:    "/rclone/test/base/url/",
   208  			},
   209  			suffix: "/rclone/test/base/url/",
   210  		},
   211  	}
   212  
   213  	for _, ss := range servers {
   214  		t.Run(ss.name, func(t *testing.T) {
   215  			s, err := NewServer(context.Background(), WithConfig(ss.cfg))
   216  			require.NoError(t, err)
   217  			defer func() {
   218  				require.NoError(t, s.Shutdown())
   219  			}()
   220  
   221  			expected := []byte("data")
   222  			s.Router().Get("/", testEchoHandler(expected).ServeHTTP)
   223  			s.Serve()
   224  
   225  			url := testGetServerURL(t, s)
   226  			require.True(t, strings.HasPrefix(url, "http://"), "url should have http scheme")
   227  			require.True(t, strings.HasSuffix(url, ss.suffix), "url should have the expected suffix")
   228  
   229  			client := &http.Client{}
   230  			req, err := http.NewRequest("GET", url, nil)
   231  			require.NoError(t, err)
   232  
   233  			resp, err := client.Do(req)
   234  			require.NoError(t, err)
   235  			defer func() {
   236  				_ = resp.Body.Close()
   237  			}()
   238  
   239  			t.Log(url, resp.Request.URL)
   240  
   241  			require.Equal(t, http.StatusOK, resp.StatusCode, "should return ok")
   242  
   243  			testExpectRespBody(t, resp, expected)
   244  		})
   245  	}
   246  }
   247  
   248  func TestNewServerTLS(t *testing.T) {
   249  	serverCertBytes := testReadTestdataFile(t, "local.crt")
   250  	serverKeyBytes := testReadTestdataFile(t, "local.key")
   251  	clientCertBytes := testReadTestdataFile(t, "client.crt")
   252  	clientKeyBytes := testReadTestdataFile(t, "client.key")
   253  	clientCert, err := tls.X509KeyPair(clientCertBytes, clientKeyBytes)
   254  	require.NoError(t, err)
   255  
   256  	// TODO: generate a proper cert with SAN
   257  
   258  	servers := []struct {
   259  		name          string
   260  		clientCerts   []tls.Certificate
   261  		wantErr       bool
   262  		wantClientErr bool
   263  		err           error
   264  		http          Config
   265  	}{
   266  		{
   267  			name: "FromFile/Valid",
   268  			http: Config{
   269  				ListenAddr:    []string{"127.0.0.1:0"},
   270  				TLSCert:       "./testdata/local.crt",
   271  				TLSKey:        "./testdata/local.key",
   272  				MinTLSVersion: "tls1.0",
   273  			},
   274  		},
   275  		{
   276  			name:    "FromFile/NoCert",
   277  			wantErr: true,
   278  			err:     ErrTLSFileMismatch,
   279  			http: Config{
   280  				ListenAddr:    []string{"127.0.0.1:0"},
   281  				TLSCert:       "",
   282  				TLSKey:        "./testdata/local.key",
   283  				MinTLSVersion: "tls1.0",
   284  			},
   285  		},
   286  		{
   287  			name:    "FromFile/InvalidCert",
   288  			wantErr: true,
   289  			http: Config{
   290  				ListenAddr:    []string{"127.0.0.1:0"},
   291  				TLSCert:       "./testdata/local.crt.invalid",
   292  				TLSKey:        "./testdata/local.key",
   293  				MinTLSVersion: "tls1.0",
   294  			},
   295  		},
   296  		{
   297  			name:    "FromFile/NoKey",
   298  			wantErr: true,
   299  			err:     ErrTLSFileMismatch,
   300  			http: Config{
   301  				ListenAddr:    []string{"127.0.0.1:0"},
   302  				TLSCert:       "./testdata/local.crt",
   303  				TLSKey:        "",
   304  				MinTLSVersion: "tls1.0",
   305  			},
   306  		},
   307  		{
   308  			name:    "FromFile/InvalidKey",
   309  			wantErr: true,
   310  			http: Config{
   311  				ListenAddr:    []string{"127.0.0.1:0"},
   312  				TLSCert:       "./testdata/local.crt",
   313  				TLSKey:        "./testdata/local.key.invalid",
   314  				MinTLSVersion: "tls1.0",
   315  			},
   316  		},
   317  		{
   318  			name: "FromBody/Valid",
   319  			http: Config{
   320  				ListenAddr:    []string{"127.0.0.1:0"},
   321  				TLSCertBody:   serverCertBytes,
   322  				TLSKeyBody:    serverKeyBytes,
   323  				MinTLSVersion: "tls1.0",
   324  			},
   325  		},
   326  		{
   327  			name:    "FromBody/NoCert",
   328  			wantErr: true,
   329  			err:     ErrTLSBodyMismatch,
   330  			http: Config{
   331  				ListenAddr:    []string{"127.0.0.1:0"},
   332  				TLSCertBody:   nil,
   333  				TLSKeyBody:    serverKeyBytes,
   334  				MinTLSVersion: "tls1.0",
   335  			},
   336  		},
   337  		{
   338  			name:    "FromBody/InvalidCert",
   339  			wantErr: true,
   340  			http: Config{
   341  				ListenAddr:    []string{"127.0.0.1:0"},
   342  				TLSCertBody:   []byte("JUNK DATA"),
   343  				TLSKeyBody:    serverKeyBytes,
   344  				MinTLSVersion: "tls1.0",
   345  			},
   346  		},
   347  		{
   348  			name:    "FromBody/NoKey",
   349  			wantErr: true,
   350  			err:     ErrTLSBodyMismatch,
   351  			http: Config{
   352  				ListenAddr:    []string{"127.0.0.1:0"},
   353  				TLSCertBody:   serverCertBytes,
   354  				TLSKeyBody:    nil,
   355  				MinTLSVersion: "tls1.0",
   356  			},
   357  		},
   358  		{
   359  			name:    "FromBody/InvalidKey",
   360  			wantErr: true,
   361  			http: Config{
   362  				ListenAddr:    []string{"127.0.0.1:0"},
   363  				TLSCertBody:   serverCertBytes,
   364  				TLSKeyBody:    []byte("JUNK DATA"),
   365  				MinTLSVersion: "tls1.0",
   366  			},
   367  		},
   368  		{
   369  			name: "MinTLSVersion/Valid/1.1",
   370  			http: Config{
   371  				ListenAddr:    []string{"127.0.0.1:0"},
   372  				TLSCertBody:   serverCertBytes,
   373  				TLSKeyBody:    serverKeyBytes,
   374  				MinTLSVersion: "tls1.1",
   375  			},
   376  		},
   377  		{
   378  			name: "MinTLSVersion/Valid/1.2",
   379  			http: Config{
   380  				ListenAddr:    []string{"127.0.0.1:0"},
   381  				TLSCertBody:   serverCertBytes,
   382  				TLSKeyBody:    serverKeyBytes,
   383  				MinTLSVersion: "tls1.2",
   384  			},
   385  		},
   386  		{
   387  			name: "MinTLSVersion/Valid/1.3",
   388  			http: Config{
   389  				ListenAddr:    []string{"127.0.0.1:0"},
   390  				TLSCertBody:   serverCertBytes,
   391  				TLSKeyBody:    serverKeyBytes,
   392  				MinTLSVersion: "tls1.3",
   393  			},
   394  		},
   395  		{
   396  			name:    "MinTLSVersion/Invalid",
   397  			wantErr: true,
   398  			err:     ErrInvalidMinTLSVersion,
   399  			http: Config{
   400  				ListenAddr:    []string{"127.0.0.1:0"},
   401  				TLSCertBody:   serverCertBytes,
   402  				TLSKeyBody:    serverKeyBytes,
   403  				MinTLSVersion: "tls0.9",
   404  			},
   405  		},
   406  		{
   407  			name:        "MutualTLS/InvalidCA",
   408  			clientCerts: []tls.Certificate{clientCert},
   409  			wantErr:     true,
   410  			http: Config{
   411  				ListenAddr:    []string{"127.0.0.1:0"},
   412  				TLSCertBody:   serverCertBytes,
   413  				TLSKeyBody:    serverKeyBytes,
   414  				MinTLSVersion: "tls1.0",
   415  				ClientCA:      "./testdata/client-ca.crt.invalid",
   416  			},
   417  		},
   418  		{
   419  			name:          "MutualTLS/InvalidClient",
   420  			clientCerts:   []tls.Certificate{},
   421  			wantClientErr: true,
   422  			http: Config{
   423  				ListenAddr:    []string{"127.0.0.1:0"},
   424  				TLSCertBody:   serverCertBytes,
   425  				TLSKeyBody:    serverKeyBytes,
   426  				MinTLSVersion: "tls1.0",
   427  				ClientCA:      "./testdata/client-ca.crt",
   428  			},
   429  		},
   430  		{
   431  			name:        "MutualTLS/Valid",
   432  			clientCerts: []tls.Certificate{clientCert},
   433  			http: Config{
   434  				ListenAddr:    []string{"127.0.0.1:0"},
   435  				TLSCertBody:   serverCertBytes,
   436  				TLSKeyBody:    serverKeyBytes,
   437  				MinTLSVersion: "tls1.0",
   438  				ClientCA:      "./testdata/client-ca.crt",
   439  			},
   440  		},
   441  	}
   442  
   443  	for _, ss := range servers {
   444  		t.Run(ss.name, func(t *testing.T) {
   445  			s, err := NewServer(context.Background(), WithConfig(ss.http))
   446  			if ss.wantErr == true {
   447  				if ss.err != nil {
   448  					require.ErrorIs(t, err, ss.err, "new server should return the expected error")
   449  				} else {
   450  					require.Error(t, err, "new server should return error for invalid TLS config")
   451  				}
   452  				return
   453  			}
   454  
   455  			require.NoError(t, err)
   456  			defer func() {
   457  				require.NoError(t, s.Shutdown())
   458  			}()
   459  
   460  			expected := []byte("secret-page")
   461  			s.Router().Mount("/", testEchoHandler(expected))
   462  			s.Serve()
   463  
   464  			url := testGetServerURL(t, s)
   465  			require.True(t, strings.HasPrefix(url, "https://"), "url should have https scheme")
   466  
   467  			client := &http.Client{
   468  				Transport: &http.Transport{
   469  					DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
   470  						dest := strings.TrimPrefix(url, "https://")
   471  						dest = strings.TrimSuffix(dest, "/")
   472  						return net.Dial("tcp", dest)
   473  					},
   474  					TLSClientConfig: &tls.Config{
   475  						Certificates:       ss.clientCerts,
   476  						InsecureSkipVerify: true,
   477  					},
   478  				},
   479  			}
   480  			req, err := http.NewRequest("GET", "https://dev.rclone.org", nil)
   481  			require.NoError(t, err)
   482  
   483  			resp, err := client.Do(req)
   484  
   485  			if ss.wantClientErr {
   486  				require.Error(t, err, "new server client should return error")
   487  				return
   488  			}
   489  
   490  			require.NoError(t, err)
   491  			defer func() {
   492  				_ = resp.Body.Close()
   493  			}()
   494  
   495  			require.Equal(t, http.StatusOK, resp.StatusCode, "should return ok")
   496  
   497  			testExpectRespBody(t, resp, expected)
   498  		})
   499  	}
   500  }
   501  
   502  func TestHelpPrefixServer(t *testing.T) {
   503  	// This test assumes template variables are placed correctly.
   504  	const testPrefix = "server-help-test"
   505  	helpMessage := Help(testPrefix)
   506  	if !strings.Contains(helpMessage, testPrefix) {
   507  		t.Fatal("flag prefix not found")
   508  	}
   509  }