github.com/icyphox/x@v0.0.355-0.20220311094250-029bd783e8b8/osx/file_test.go (about)

     1  package osx
     2  
     3  import (
     4  	"encoding/base64"
     5  	"fmt"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/ory/x/httpx"
    14  )
    15  
    16  var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
    17  	_, _ = w.Write([]byte("hello world"))
    18  }
    19  
    20  func TestReadFileFromAllSources(t *testing.T) {
    21  	ts := httptest.NewServer(handler)
    22  	defer ts.Close()
    23  
    24  	sslTS := httptest.NewTLSServer(handler)
    25  	defer sslTS.Close()
    26  
    27  	for k, tc := range []struct {
    28  		opts         []Option
    29  		src          string
    30  		expectedErr  string
    31  		expectedBody string
    32  	}{
    33  		{src: "base64://aGVsbG8gd29ybGQ", expectedBody: "hello world"},
    34  		{src: "base64://aGVsbG8gd29ybGQ=", expectedBody: "hello world", opts: []Option{WithBase64Encoding(base64.URLEncoding)}},
    35  		{src: "base64://aGVsbG8gd29ybGQ=", expectedErr: "unable to base64 decode the location: illegal base64 data at input byte 15"},
    36  		{src: "base64://aGVsbG8gd29ybGQ", expectedErr: "base64 loader disabled", opts: []Option{WithDisabledBase64Loader()}},
    37  		{src: "base64://notbase64", expectedErr: "unable to base64 decode the location: illegal base64 data at input byte 8"},
    38  
    39  		{src: "file://stub/text.txt", expectedBody: "hello world"},
    40  		{src: "file://stub/text.txt", expectedErr: "file loader disabled", opts: []Option{WithDisabledFileLoader()}},
    41  
    42  		{src: ts.URL, expectedBody: "hello world"},
    43  		{src: sslTS.URL, expectedErr: "unable to load remote file: GET " + sslTS.URL + " giving up after 1 attempt(s): Get \"" + sslTS.URL + "\": x509: certificate signed by unknown authority"},
    44  		{src: sslTS.URL, expectedBody: "hello world", opts: []Option{WithHTTPClient(httpx.NewResilientClient(httpx.ResilientClientWithClient(sslTS.Client())))}},
    45  		{src: sslTS.URL, expectedErr: "http(s) loader disabled", opts: []Option{WithDisabledHTTPLoader()}},
    46  
    47  		{src: "file://stub/text.txt", expectedErr: "file loader disabled", opts: []Option{WithDisabledFileLoader()}},
    48  
    49  		{src: "lmao://stub/text.txt", expectedErr: "unsupported source `lmao`"},
    50  
    51  		{src: "/stub/text.txt", expectedErr: "unsupported source ``"},
    52  	} {
    53  		t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
    54  			body, err := ReadFileFromAllSources(tc.src, tc.opts...)
    55  			if tc.expectedErr != "" {
    56  				require.Error(t, err)
    57  				assert.Equal(t, tc.expectedErr, err.Error())
    58  				return
    59  			}
    60  			require.NoError(t, err)
    61  			assert.Equal(t, tc.expectedBody, string(body))
    62  		})
    63  	}
    64  }
    65  
    66  func TestRestrictedReadFile(t *testing.T) {
    67  	ts := httptest.NewServer(handler)
    68  	defer ts.Close()
    69  
    70  	sslTS := httptest.NewTLSServer(handler)
    71  	defer sslTS.Close()
    72  
    73  	for k, tc := range []struct {
    74  		opts         []Option
    75  		src          string
    76  		expectedErr  string
    77  		expectedBody string
    78  	}{
    79  		{src: "base64://aGVsbG8gd29ybGQ", expectedErr: "base64 loader disabled"},
    80  		{src: "base64://aGVsbG8gd29ybGQ", expectedBody: "hello world", opts: []Option{WithEnabledBase64Loader()}},
    81  
    82  		{src: "file://stub/text.txt", expectedErr: "file loader disabled"},
    83  		{src: "file://stub/text.txt", expectedBody: "hello world", opts: []Option{WithEnabledFileLoader()}},
    84  
    85  		{src: sslTS.URL, expectedErr: "http(s) loader disabled"},
    86  		{src: ts.URL, expectedBody: "hello world", opts: []Option{WithEnabledHTTPLoader()}},
    87  	} {
    88  		t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
    89  			body, err := RestrictedReadFile(tc.src, tc.opts...)
    90  			if tc.expectedErr != "" {
    91  				require.Error(t, err)
    92  				assert.Equal(t, tc.expectedErr, err.Error())
    93  				return
    94  			}
    95  			require.NoError(t, err)
    96  			assert.Equal(t, tc.expectedBody, string(body))
    97  		})
    98  	}
    99  }