goyave.dev/goyave/v5@v5.0.0-rc9.0.20240517145003-d3f977d0b9f3/util/testutil/testutil.go (about)

     1  package testutil
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"io"
     7  	"io/fs"
     8  	"math"
     9  	"mime/multipart"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"os"
    13  	"path"
    14  	"path/filepath"
    15  	"testing"
    16  
    17  	"goyave.dev/goyave/v5"
    18  	"goyave.dev/goyave/v5/config"
    19  	"goyave.dev/goyave/v5/slog"
    20  	"goyave.dev/goyave/v5/util/errors"
    21  	"goyave.dev/goyave/v5/util/fsutil"
    22  	"goyave.dev/goyave/v5/util/fsutil/osfs"
    23  )
    24  
    25  type copyRequestMiddleware struct {
    26  	goyave.Component
    27  	request *goyave.Request
    28  }
    29  
    30  func (m *copyRequestMiddleware) Handle(next goyave.Handler) goyave.Handler {
    31  	return func(response *goyave.Response, request *goyave.Request) {
    32  		request.Now = m.request.Now
    33  		request.Data = m.request.Data
    34  		request.Extra = m.request.Extra
    35  		request.Lang = m.request.Lang
    36  		request.Query = m.request.Query
    37  		request.RouteParams = m.request.RouteParams
    38  		request.User = m.request.User
    39  		request.Route = m.request.Route
    40  		next(response, request.WithContext(m.request.Context()))
    41  	}
    42  }
    43  
    44  // TestServer extension of `goyave.Server` providing useful functions for testing.
    45  type TestServer struct {
    46  	*goyave.Server
    47  }
    48  
    49  // NewTestServer creates a new server using the given config file. The config path is relative to
    50  // the project's directory. If not nil, the given `routeRegistrer` function is called to register
    51  // routes without starting the server.
    52  //
    53  // A default logger redirecting the output to `testing.T.Log()` is used.
    54  //
    55  // Automatically closes the DB connection (if there is one) using a test `Cleanup` function.
    56  func NewTestServer(t *testing.T, configFileName string) *TestServer {
    57  	rootDirectory := FindRootDirectory()
    58  	cfgPath := path.Join(rootDirectory, configFileName)
    59  	cfg, err := config.LoadFrom(cfgPath)
    60  	if err != nil {
    61  		panic(errors.New(err))
    62  	}
    63  
    64  	return NewTestServerWithOptions(t, goyave.Options{Config: cfg})
    65  }
    66  
    67  // NewTestServerWithOptions creates a new server using the given options.
    68  // If not nil, the given `routeRegistrer` function is called to register
    69  // routes without starting the server.
    70  //
    71  // By default, if no `Logger` is given in the options, a default logger redirecting the
    72  // output to `testing.T.Log()` is used.
    73  //
    74  // Automatically closes the DB connection (if there is one) using a test `Cleanup` function.
    75  func NewTestServerWithOptions(t *testing.T, opts goyave.Options) *TestServer {
    76  	if opts.Config == nil {
    77  		cfg, err := config.Load()
    78  		if err != nil {
    79  			panic(errors.New(err))
    80  		}
    81  		opts.Config = cfg
    82  	}
    83  
    84  	if opts.Logger == nil {
    85  		opts.Logger = slog.New(slog.NewHandler(opts.Config.GetBool("app.debug"), &LogWriter{t: t}))
    86  	}
    87  
    88  	srv, err := goyave.New(opts)
    89  	if err != nil {
    90  		panic(err)
    91  	}
    92  
    93  	langDirectory := path.Join(FindRootDirectory(), "resources", "lang")
    94  	if err := srv.Lang.LoadDirectory(&osfs.FS{}, langDirectory); err != nil {
    95  		panic(err)
    96  	}
    97  
    98  	s := &TestServer{srv}
    99  	if t != nil {
   100  		t.Cleanup(func() { s.CloseDB() })
   101  	}
   102  	return s
   103  }
   104  
   105  // TestRequest execute a request by calling the root Router's `ServeHTTP()` implementation.
   106  func (s *TestServer) TestRequest(request *http.Request) *http.Response {
   107  	recorder := httptest.NewRecorder()
   108  	s.Router().ServeHTTP(recorder, request)
   109  	return recorder.Result()
   110  }
   111  
   112  // TestMiddleware executes with the given request and returns the response.
   113  // The `procedure` parameter is the `next` handler passed to the middleware and can be used to
   114  // make assertions. Keep in mind that this procedure won't be executed if your middleware is blocking.
   115  //
   116  // The request will go through the entire lifecycle like a regular request.
   117  //
   118  // The given request is cloned. If the middleware alters the request object, these changes won't be reflected on the input request.
   119  // You can do your assertions inside the `procedure`.
   120  func (s *TestServer) TestMiddleware(middleware goyave.Middleware, request *goyave.Request, procedure goyave.Handler) *http.Response {
   121  	recorder := httptest.NewRecorder()
   122  	router := goyave.NewRouter(s.Server)
   123  	router.GlobalMiddleware(&copyRequestMiddleware{request: request})
   124  	router.Route([]string{request.Method()}, request.Request().URL.Path, procedure).Middleware(middleware)
   125  	router.ServeHTTP(recorder, request.Request())
   126  	return recorder.Result()
   127  }
   128  
   129  // CloseDB close the server DB if one is open. It is a good practice to always
   130  // call this in a test `Cleanup` function when using a database.
   131  func (s *TestServer) CloseDB() {
   132  	if err := s.Server.CloseDB(); err != nil {
   133  		s.Logger.Error(err)
   134  	}
   135  }
   136  
   137  // FindRootDirectory find relative path to the project's root directory based on the
   138  // existence of a `go.mod` file. The returned path is a rooted path.
   139  // Returns an empty string if not found.
   140  func FindRootDirectory() string {
   141  	wd, err := os.Getwd()
   142  	if err != nil {
   143  		return ""
   144  	}
   145  	directory := wd
   146  	fs := &osfs.FS{}
   147  	for !fs.FileExists(path.Join(directory, "go.mod")) {
   148  		directory = path.Join(directory, "..")
   149  		if !fs.IsDirectory(directory) {
   150  			return ""
   151  		}
   152  	}
   153  	return directory
   154  }
   155  
   156  // NewTestRequest create a new `goyave.Request` with an underlying HTTP request created
   157  // usin the `httptest` package.
   158  func NewTestRequest(method, uri string, body io.Reader) *goyave.Request {
   159  	req := httptest.NewRequest(method, uri, body)
   160  	return goyave.NewRequest(req)
   161  }
   162  
   163  // NewTestRequest create a new `goyave.Request` with an underlying HTTP request created
   164  // usin the `httptest` package. This function sets the request language using the default
   165  // language of the server.
   166  func (s *TestServer) NewTestRequest(method, uri string, body io.Reader) *goyave.Request {
   167  	req := NewTestRequest(method, uri, body)
   168  	req.Lang = s.Lang.GetDefault()
   169  	return req
   170  }
   171  
   172  // NewTestResponse create a new `goyave.Response` with an underlying HTTP response recorder created
   173  // using the `httptest` package. This function uses a temporary `goyave.Server` with all defaults values loaded
   174  // so all functions of `*goyave.Response` can be used safely.
   175  func NewTestResponse(request *goyave.Request) (*goyave.Response, *httptest.ResponseRecorder) {
   176  	recorder := httptest.NewRecorder()
   177  	return goyave.NewResponse(NewTestServerWithOptions(nil, goyave.Options{Config: config.LoadDefault()}).Server, request, recorder), recorder
   178  }
   179  
   180  // NewTestResponse create a new `goyave.Response` with an underlying HTTP response recorder created
   181  // using the `httptest` package.
   182  func (s *TestServer) NewTestResponse(request *goyave.Request) (*goyave.Response, *httptest.ResponseRecorder) {
   183  	recorder := httptest.NewRecorder()
   184  	return goyave.NewResponse(s.Server, request, recorder), recorder
   185  }
   186  
   187  // ReadJSONBody decodes the given body reader into a new variable of type `*T`.
   188  func ReadJSONBody[T any](body io.Reader) (T, error) {
   189  	var data T
   190  	err := json.NewDecoder(body).Decode(&data)
   191  	return data, errors.New(err)
   192  }
   193  
   194  // WriteMultipartFile reads a file from the given FS and writes it to the given multipart writer.
   195  func WriteMultipartFile(writer *multipart.Writer, filesystem fs.FS, path, fieldName, fileName string) (err error) {
   196  	var file fs.File
   197  	file, err = filesystem.Open(path)
   198  	if err != nil {
   199  		err = errors.New(err)
   200  		return
   201  	}
   202  	defer func() {
   203  		e := file.Close()
   204  		if err == nil && e != nil {
   205  			err = errors.New(e)
   206  		}
   207  	}()
   208  	part, err := writer.CreateFormFile(fieldName, fileName)
   209  	if err != nil {
   210  		err = errors.New(err)
   211  		return
   212  	}
   213  	_, err = io.Copy(part, file)
   214  	if err != nil {
   215  		err = errors.New(err)
   216  	}
   217  	return
   218  }
   219  
   220  // CreateTestFiles create a slice of "fsutil.File" from the given FS.
   221  // To reproduce the way the files are obtained in real scenarios,
   222  // files are first encoded in a multipart form, then decoded with
   223  // a multipart form reader.
   224  //
   225  // Paths are relative to the caller, not relative to the project's root directory.
   226  func CreateTestFiles(fs fs.FS, paths ...string) ([]fsutil.File, error) {
   227  	fieldName := "file"
   228  	body := &bytes.Buffer{}
   229  	writer := multipart.NewWriter(body)
   230  	for _, p := range paths {
   231  		if err := WriteMultipartFile(writer, fs, p, fieldName, filepath.Base(p)); err != nil {
   232  			return nil, errors.New(err)
   233  		}
   234  	}
   235  	err := writer.Close()
   236  	if err != nil {
   237  		return nil, errors.New(err)
   238  	}
   239  
   240  	reader := multipart.NewReader(body, writer.Boundary())
   241  	form, err := reader.ReadForm(math.MaxInt64 - 1)
   242  	if err != nil {
   243  		return nil, errors.New(err)
   244  	}
   245  	return fsutil.ParseMultipartFiles(form.File[fieldName])
   246  }
   247  
   248  // ToJSON marshals the given data and creates a bytes reader from the result.
   249  // Panics on error.
   250  func ToJSON(data any) *bytes.Reader {
   251  	res, err := json.Marshal(data)
   252  	if err != nil {
   253  		panic(errors.New(err))
   254  	}
   255  	return bytes.NewReader(res)
   256  }