github.com/blend/go-sdk@v1.20220411.3/webutil/posted_files_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package webutil
     9  
    10  import (
    11  	"fmt"
    12  	"io"
    13  	"net/http"
    14  	"net/http/httptest"
    15  	"sort"
    16  	"strings"
    17  	"testing"
    18  
    19  	"github.com/blend/go-sdk/assert"
    20  )
    21  
    22  func Test_PostedFiles(t *testing.T) {
    23  	its := assert.New(t)
    24  
    25  	file0 := PostedFile{
    26  		Key:      "file0",
    27  		FileName: "file0.txt",
    28  		Contents: []byte("file0-contents"),
    29  	}
    30  	file1 := PostedFile{
    31  		Key:      "file1",
    32  		FileName: "file1.txt",
    33  		Contents: []byte(strings.Repeat("a", 1<<20)),
    34  	}
    35  
    36  	server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
    37  		files, err := PostedFiles(req)
    38  		if err != nil {
    39  			http.Error(rw, err.Error(), http.StatusBadRequest)
    40  			return
    41  		}
    42  		if len(files) != 2 {
    43  			http.Error(rw, "invalid file count", http.StatusBadRequest)
    44  			return
    45  		}
    46  		sort.Slice(files, func(i, j int) bool {
    47  			return files[i].Key < files[j].Key
    48  		})
    49  
    50  		if files[0].Key != file0.Key {
    51  			http.Error(rw, fmt.Sprintf("invalid file0 key: %s, expectd: %s", files[0].Key, file0.Key), http.StatusBadRequest)
    52  			return
    53  		}
    54  		if files[0].FileName != file0.FileName {
    55  			http.Error(rw, fmt.Sprintf("invalid file0 key: %s, expectd: %s", files[0].FileName, file0.FileName), http.StatusBadRequest)
    56  			return
    57  		}
    58  		if string(files[0].Contents) != string(file0.Contents) {
    59  			http.Error(rw, fmt.Sprintf("invalid file0 contents: %s, expectd: %s", files[0].Contents, file0.Contents), http.StatusBadRequest)
    60  			return
    61  		}
    62  
    63  		if files[1].Key != file1.Key {
    64  			http.Error(rw, fmt.Sprintf("invalid file1 key: %s, expectd: %s", files[1].Key, file1.Key), http.StatusBadRequest)
    65  			return
    66  		}
    67  		if files[1].FileName != file1.FileName {
    68  			http.Error(rw, fmt.Sprintf("invalid file1 key: %s, expectd: %s", files[1].FileName, file1.FileName), http.StatusBadRequest)
    69  			return
    70  		}
    71  		if string(files[1].Contents) != string(file1.Contents) {
    72  			http.Error(rw, "invalid file1 contents", http.StatusBadRequest)
    73  			return
    74  		}
    75  		rw.WriteHeader(http.StatusOK)
    76  		fmt.Fprintf(rw, "OK!")
    77  		return
    78  	}))
    79  	defer server.Close()
    80  
    81  	r, err := http.NewRequest(http.MethodPost, server.URL, nil)
    82  	its.Nil(err)
    83  	err = OptPostedFiles(file0, file1)(r)
    84  	its.Nil(err)
    85  	res, err := http.DefaultClient.Do(r)
    86  	its.Nil(err)
    87  	defer res.Body.Close()
    88  	bodyContents, _ := io.ReadAll(res.Body)
    89  	its.Equal(http.StatusOK, res.StatusCode, string(bodyContents))
    90  }
    91  
    92  func Test_PostedFiles_onlyParseForm(t *testing.T) {
    93  	its := assert.New(t)
    94  
    95  	file0 := PostedFile{
    96  		Key:      "file0",
    97  		FileName: "file0.txt",
    98  		Contents: []byte("file0-contents"),
    99  	}
   100  	file1 := PostedFile{
   101  		Key:      "file1",
   102  		FileName: "file1.txt",
   103  		Contents: []byte(strings.Repeat("a", 1<<20)),
   104  	}
   105  
   106  	server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   107  		files, err := PostedFiles(req,
   108  			OptPostedFilesParseMultipartForm(false),
   109  			OptPostedFilesParseForm(true),
   110  		)
   111  		if err != nil {
   112  			http.Error(rw, err.Error(), http.StatusBadRequest)
   113  			return
   114  		}
   115  		if len(files) != 0 {
   116  			http.Error(rw, fmt.Sprintf("invalid file count: %d", len(files)), http.StatusBadRequest)
   117  			return
   118  		}
   119  		rw.WriteHeader(http.StatusOK)
   120  		fmt.Fprintf(rw, "OK!")
   121  		return
   122  	}))
   123  	defer server.Close()
   124  
   125  	r, err := http.NewRequest(http.MethodPost, server.URL, nil)
   126  	its.Nil(err)
   127  	err = OptPostedFiles(file0, file1)(r)
   128  	its.Nil(err)
   129  	res, err := http.DefaultClient.Do(r)
   130  	its.Nil(err)
   131  	defer res.Body.Close()
   132  	contents, err := io.ReadAll(res.Body)
   133  	its.Nil(err)
   134  	its.Equal(http.StatusOK, res.StatusCode, string(contents))
   135  }
   136  
   137  func Test_PostedFiles_maxMemory(t *testing.T) {
   138  	its := assert.New(t)
   139  
   140  	file0 := PostedFile{
   141  		Key:      "file0",
   142  		FileName: "file0.txt",
   143  		Contents: []byte("file0-contents"),
   144  	}
   145  	file1 := PostedFile{
   146  		Key:      "file1",
   147  		FileName: "file1.txt",
   148  		Contents: []byte(strings.Repeat("a", 1<<20)),
   149  	}
   150  
   151  	server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   152  		files, err := PostedFiles(req,
   153  			OptPostedFilesMaxMemory(1<<10),
   154  			OptPostedFilesParseMultipartForm(true),
   155  			OptPostedFilesParseForm(false),
   156  		)
   157  		if err != nil {
   158  			http.Error(rw, err.Error(), http.StatusBadRequest)
   159  			return
   160  		}
   161  		if len(files) != 2 {
   162  			http.Error(rw, fmt.Sprintf("invalid file count: %d", len(files)), http.StatusBadRequest)
   163  			return
   164  		}
   165  		rw.WriteHeader(http.StatusOK)
   166  		fmt.Fprintf(rw, "OK!")
   167  		return
   168  	}))
   169  	defer server.Close()
   170  
   171  	r, err := http.NewRequest(http.MethodPost, server.URL, nil)
   172  	its.Nil(err)
   173  	err = OptPostedFiles(file0, file1)(r)
   174  	its.Nil(err)
   175  	res, err := http.DefaultClient.Do(r)
   176  	its.Nil(err)
   177  	defer res.Body.Close()
   178  	contents, err := io.ReadAll(res.Body)
   179  	its.Nil(err)
   180  	its.Equal(http.StatusOK, res.StatusCode, string(contents))
   181  }