github.com/Files-com/files-sdk-go/v2@v2.1.2/file/mockserver.go (about)

     1  package file
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"encoding/xml"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"net/url"
    13  	"path/filepath"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"testing"
    18  	"time"
    19  
    20  	files_sdk "github.com/Files-com/files-sdk-go/v2"
    21  	"github.com/Files-com/files-sdk-go/v2/lib"
    22  	"github.com/chilts/sid"
    23  	"github.com/gin-gonic/gin"
    24  	"github.com/samber/lo"
    25  )
    26  
    27  type randomReader struct {
    28  	n int
    29  }
    30  
    31  func (r *randomReader) Read(p []byte) (int, error) {
    32  	if r.n <= 0 {
    33  		return 0, io.EOF
    34  	}
    35  	if len(p) > r.n {
    36  		p = p[:r.n]
    37  	}
    38  
    39  	_, err := rand.Read(p)
    40  	if err != nil {
    41  		return 0, err
    42  	}
    43  
    44  	r.n -= len(p)
    45  	return len(p), nil
    46  }
    47  
    48  type CustomResponse struct {
    49  	Status      int
    50  	Body        []byte
    51  	ContentType string
    52  }
    53  
    54  type MockAPIServer struct {
    55  	router *gin.Engine
    56  	Addr   string
    57  	*httptest.Server
    58  	downloads       *lib.Map[download]
    59  	MockFiles       map[string]mockFile
    60  	customResponses map[string]func(ctx *gin.Context, model interface{}) bool
    61  	*testing.T
    62  	TrackRequest map[string][]string
    63  	traceMutex   *sync.Mutex
    64  }
    65  
    66  type download struct {
    67  	Id string
    68  	mockFile
    69  	Requests *lib.Map[files_sdk.ResponseError]
    70  }
    71  
    72  func (d download) init() download {
    73  	d.Requests = &lib.Map[files_sdk.ResponseError]{}
    74  	return d
    75  }
    76  
    77  type mockFile struct {
    78  	files_sdk.File
    79  	RealSize *int64
    80  	SizeTrust
    81  	ForceRequestStatus  string
    82  	ForceRequestMessage string
    83  	ServerBytesSent     *int64
    84  	MaxConnections      int
    85  	MaxConnectionsMutex *sync.Mutex
    86  }
    87  
    88  func (m mockFile) Completed() string {
    89  	if m.ForceRequestStatus != "" {
    90  		return m.ForceRequestStatus
    91  	}
    92  	return "completed"
    93  }
    94  
    95  type TestLogger struct {
    96  	*testing.T
    97  }
    98  
    99  func (t TestLogger) Printf(format string, args ...any) {
   100  	t.T.Logf(format, args...)
   101  }
   102  
   103  func (t TestLogger) Write(p []byte) (n int, err error) {
   104  	t.T.Log(string(p))
   105  	return len(p), nil
   106  }
   107  
   108  func (f *MockAPIServer) Do() *MockAPIServer {
   109  	gin.SetMode(gin.TestMode)
   110  	f.MockFiles = make(map[string]mockFile)
   111  	f.customResponses = make(map[string]func(ctx *gin.Context, model interface{}) bool)
   112  	f.TrackRequest = make(map[string][]string)
   113  	f.traceMutex = &sync.Mutex{}
   114  	f.downloads = &lib.Map[download]{}
   115  	f.router = gin.New()
   116  	f.router.Use(gin.LoggerWithWriter(TestLogger{f.T}))
   117  	f.Routes()
   118  	f.Server = httptest.NewServer(f.router)
   119  
   120  	return f
   121  }
   122  
   123  func (f *MockAPIServer) MockRoute(path string, call func(ctx *gin.Context, model interface{}) bool) {
   124  	f.customResponses[path] = call
   125  }
   126  
   127  func (f *MockAPIServer) Client() *Client {
   128  	client := &Client{}
   129  	httpClient := http.Client{}
   130  	if u, err := url.Parse(f.Server.URL); err != nil {
   131  		f.T.Fatal(err.Error())
   132  	} else {
   133  		httpClient.Transport = &CustomTransport{URL: u}
   134  	}
   135  	client.Config.SetHttpClient(&httpClient)
   136  	client.Config.SetLogger(TestLogger{f.T})
   137  	return client
   138  }
   139  
   140  func (f *MockAPIServer) GetFile(file mockFile) (r io.Reader, contentLengthOk bool, contentLength int64, realSize int64, err error) {
   141  	if file.SizeTrust == NullSizeTrust || file.SizeTrust == TrustedSizeValue {
   142  		contentLengthOk = true
   143  	}
   144  
   145  	contentLength = file.File.Size
   146  
   147  	if file.RealSize != nil {
   148  		realSize = *file.RealSize
   149  	} else {
   150  		realSize = contentLength
   151  	}
   152  	r = &randomReader{int(realSize)}
   153  	return
   154  }
   155  
   156  func (f *MockAPIServer) trackRequest(c *gin.Context) {
   157  	f.traceMutex.Lock()
   158  	defer f.traceMutex.Unlock()
   159  	f.TrackRequest[c.FullPath()] = append(f.TrackRequest[c.FullPath()], c.Request.URL.String())
   160  }
   161  
   162  func (f *MockAPIServer) GetRouter() *gin.Engine {
   163  	return f.router
   164  }
   165  
   166  func (f *MockAPIServer) Routes() {
   167  	//Download Context
   168  	f.router.GET("/api/rest/v1/files/*path", func(c *gin.Context) {
   169  		f.trackRequest(c)
   170  		path := strings.TrimPrefix(c.Param("path"), "/")
   171  		if f.customResponse(c, nil) {
   172  			return
   173  		}
   174  		file, ok := f.MockFiles[path]
   175  		if ok {
   176  			if file.Path == "" {
   177  				file.Path = path
   178  				file.DisplayName = filepath.Base(path)
   179  			}
   180  			downloadId := sid.IdHex()
   181  			f.downloads.Store(downloadId, download{Id: downloadId, mockFile: file}.init())
   182  			file.DownloadUri = lib.UrlJoinNoEscape("http://localhost:8080/download", downloadId)
   183  
   184  			if file.MaxConnections != 0 {
   185  				file.MaxConnectionsMutex = &sync.Mutex{}
   186  			}
   187  
   188  			c.JSON(http.StatusOK, file.File)
   189  		} else {
   190  			c.JSON(http.StatusNotFound, nil)
   191  		}
   192  	})
   193  	f.router.GET("/api/rest/v1/folders/*path", func(c *gin.Context) {
   194  		f.trackRequest(c)
   195  		path := strings.TrimPrefix(c.Param("path"), "/")
   196  
   197  		if f.customResponse(c, nil) {
   198  			return
   199  		}
   200  
   201  		var files []files_sdk.File
   202  		for k, v := range f.MockFiles {
   203  			dir, _ := filepath.Split(k)
   204  			if lib.NormalizeForComparison(filepath.Clean(path)) == lib.NormalizeForComparison(filepath.Clean(dir)) {
   205  				if v.Path == "" {
   206  					v.Path = k
   207  					v.DisplayName = filepath.Base(k)
   208  				}
   209  				files = append(files, v.File)
   210  			}
   211  		}
   212  
   213  		if len(files) > 0 {
   214  			c.JSON(http.StatusOK, files)
   215  		} else {
   216  			c.JSON(http.StatusNotFound, nil)
   217  		}
   218  	})
   219  	f.router.GET("/api/rest/v1/file_actions/metadata/*path", func(c *gin.Context) {
   220  		f.trackRequest(c)
   221  		path := strings.TrimPrefix(c.Param("path"), "/")
   222  
   223  		if f.customResponse(c, nil) {
   224  			return
   225  		}
   226  
   227  		file, ok := f.MockFiles[path]
   228  		if ok {
   229  			if file.Path == "" {
   230  				file.Path = path
   231  				file.DisplayName = filepath.Base(path)
   232  			}
   233  			c.JSON(http.StatusOK, file.File)
   234  		} else {
   235  			c.JSON(http.StatusNotFound, nil)
   236  		}
   237  	})
   238  	f.router.GET("/download/:download_id/:download_request_id", func(c *gin.Context) {
   239  		f.trackRequest(c)
   240  		downloadId := c.Param("download_id")
   241  		downloadJob, downloadOk := f.downloads.Load(downloadId)
   242  		if !downloadOk {
   243  			c.JSON(http.StatusNotFound, nil)
   244  			return
   245  		}
   246  		downloadRequestJob, requestOk := downloadJob.Requests.Load(c.Param("download_request_id"))
   247  		if requestOk {
   248  			c.JSON(http.StatusOK, downloadRequestJob)
   249  		} else {
   250  			c.JSON(http.StatusNotFound, nil)
   251  		}
   252  
   253  	})
   254  	f.router.GET("/download/:download_id", func(c *gin.Context) {
   255  		f.trackRequest(c)
   256  		downloadJob, ok := f.downloads.Load(c.Param("download_id"))
   257  		if !ok {
   258  			c.JSON(http.StatusNotFound, nil)
   259  			return
   260  		}
   261  
   262  		if downloadJob.mockFile.MaxConnectionsMutex != nil {
   263  			downloadJob.mockFile.MaxConnectionsMutex.Lock()
   264  		}
   265  
   266  		start, end, okRange := rangeValue(c.Request.Header)
   267  
   268  		reader, contentLengthOk, contentLength, realSize, err := f.GetFile(downloadJob.mockFile)
   269  		if err != nil {
   270  			panic(err)
   271  		}
   272  		status := http.StatusOK
   273  		if okRange {
   274  			if realSize < int64(start) {
   275  				reader = &randomReader{0}
   276  			} else {
   277  				reader = &randomReader{(lo.Min[int]([]int{int(realSize - 1), end}) - start) + 1}
   278  			}
   279  			status = http.StatusPartialContent
   280  		}
   281  		downloadRequestId := sid.IdHex()
   282  		if downloadJob.mockFile.MaxConnections == 0 {
   283  			c.Header("X-Files-Max-Connections", "*")
   284  		} else {
   285  			c.Header("X-Files-Max-Connections", fmt.Sprintf("%v", downloadJob.mockFile.MaxConnections))
   286  		}
   287  
   288  		c.Header("X-Files-Download-Request-Id", downloadRequestId)
   289  		responseError := files_sdk.ResponseError{ErrorMessage: downloadJob.ForceRequestMessage}
   290  		extraHeaders := map[string]string{}
   291  		if contentLengthOk {
   292  			if okRange && contentLength < int64(end) {
   293  				c.Status(http.StatusBadRequest)
   294  			}
   295  
   296  			if okRange {
   297  				extraHeaders["Content-Range"] = fmt.Sprintf("%v-%v/%v", start, end, contentLength)
   298  				contentLength = int64(end-start) + 1
   299  			}
   300  
   301  			c.DataFromReader(status, contentLength, "application/zip, application/octet-stream", reader, extraHeaders)
   302  			downloadJob.Requests.Store(downloadRequestId, responseError)
   303  			if downloadJob.mockFile.MaxConnectionsMutex != nil {
   304  				downloadJob.mockFile.MaxConnectionsMutex.Unlock()
   305  			}
   306  		} else {
   307  			finish := func() {
   308  				if downloadJob.ServerBytesSent != nil {
   309  					responseError.Data.BytesTransferred = *downloadJob.ServerBytesSent
   310  				}
   311  				downloadJob.Requests.Store(downloadRequestId, responseError)
   312  				if downloadJob.mockFile.MaxConnectionsMutex != nil {
   313  					downloadJob.mockFile.MaxConnectionsMutex.Unlock()
   314  				}
   315  			}
   316  			if okRange {
   317  				c.Header("Content-Range", fmt.Sprintf("%v-%v/*", start, end))
   318  			}
   319  			c.Status(status)
   320  			c.Stream(func(w io.Writer) bool {
   321  				buf := make([]byte, 1024*1024)
   322  
   323  				n, err := reader.Read(buf)
   324  				if err == io.EOF {
   325  					responseError.Data.Status = downloadJob.Completed()
   326  					finish()
   327  					return false
   328  				}
   329  				if err != nil && err != io.EOF {
   330  					responseError.Data.Status = "errored"
   331  					finish()
   332  					return false
   333  				}
   334  
   335  				wn, err := w.Write(buf[:n])
   336  				if err != nil {
   337  					responseError.Data.Status = "errored"
   338  					finish()
   339  					return false
   340  				}
   341  
   342  				responseError.Data.BytesTransferred += int64(wn)
   343  
   344  				if err == io.EOF {
   345  					responseError.Data.Status = "errored"
   346  					finish()
   347  					return false
   348  				}
   349  
   350  				return true
   351  			})
   352  		}
   353  	})
   354  	f.router.HEAD("/download/:download_id", func(c *gin.Context) {
   355  		f.trackRequest(c)
   356  		downloadJob, ok := f.downloads.Load(c.Param("download_id"))
   357  		if !ok {
   358  			c.JSON(http.StatusNotFound, nil)
   359  			return
   360  		}
   361  		_, contentLengthOk, contentLength, _, err := f.GetFile(downloadJob.mockFile)
   362  		if err != nil {
   363  			panic(err)
   364  		}
   365  		if contentLengthOk {
   366  			c.Header("Content-Length", fmt.Sprintf("%v", contentLength))
   367  		}
   368  		if downloadJob.mockFile.MaxConnections == 0 {
   369  			c.Header("X-Files-Max-Connections", "*")
   370  		} else {
   371  			c.Header("X-Files-Max-Connections", fmt.Sprintf("%v", downloadJob.mockFile.MaxConnections))
   372  		}
   373  		c.Status(http.StatusOK)
   374  	})
   375  	//	Upload Context
   376  	f.router.POST("/api/rest/v1/files/*path", func(c *gin.Context) {
   377  		f.trackRequest(c)
   378  		path := strings.TrimPrefix(c.Param("path"), "/")
   379  
   380  		var fileCreate files_sdk.FileCreateParams
   381  
   382  		if err := c.BindJSON(&fileCreate); err != nil {
   383  			c.JSON(http.StatusBadRequest, map[string]interface{}{"message": err.Error()})
   384  		}
   385  
   386  		if f.customResponse(c, fileCreate) {
   387  			return
   388  		}
   389  
   390  		file, ok := f.MockFiles[path]
   391  		if ok {
   392  			if file.Path == "" {
   393  				file.Path = path
   394  				file.DisplayName = filepath.Base(path)
   395  			}
   396  			c.JSON(http.StatusOK, file)
   397  		} else {
   398  			c.JSON(http.StatusNotFound, nil)
   399  		}
   400  	})
   401  	f.router.POST("/api/rest/v1/file_actions/begin_upload/*path", func(c *gin.Context) {
   402  		f.trackRequest(c)
   403  		path := strings.TrimPrefix(c.Param("path"), "/")
   404  
   405  		var beginUpload files_sdk.FileBeginUploadParams
   406  
   407  		if err := c.BindJSON(&beginUpload); err != nil {
   408  			c.JSON(http.StatusBadRequest, map[string]interface{}{"message": err.Error()})
   409  		}
   410  
   411  		beginUpload.Path = path
   412  
   413  		if f.customResponse(c, beginUpload) {
   414  			return
   415  		}
   416  
   417  		_, ok := f.MockFiles[path]
   418  		_, parentOk := f.MockFiles[filepath.Dir(path)]
   419  
   420  		if !ok && (filepath.Dir(path) == "." || parentOk || *beginUpload.MkdirParents == true) {
   421  			f.MockFiles[path] = mockFile{File: files_sdk.File{Path: path, DisplayName: filepath.Base(path), Size: beginUpload.Size}}
   422  			ok = true
   423  		}
   424  
   425  		if ok {
   426  			c.JSON(http.StatusOK, files_sdk.FileUploadPartCollection{
   427  				files_sdk.FileUploadPart{
   428  					HttpMethod:    "POST",
   429  					Path:          path,
   430  					UploadUri:     lib.UrlJoinNoEscape(f.Server.URL, "upload", path),
   431  					ParallelParts: lib.Bool(true),
   432  					Expires:       time.Now().Add(time.Hour).Format(time.RFC3339),
   433  				},
   434  			})
   435  		} else {
   436  			c.JSON(http.StatusNotFound, nil)
   437  		}
   438  	})
   439  	f.router.POST("upload/*path", func(c *gin.Context) {
   440  		f.trackRequest(c)
   441  		path := strings.TrimPrefix(c.Param("path"), "/")
   442  
   443  		if f.customResponse(c, nil) {
   444  			return
   445  		}
   446  
   447  		_, ok := f.MockFiles[path]
   448  		if ok {
   449  			ctx, cancel := context.WithTimeout(c, time.Millisecond*100)
   450  			defer cancel()
   451  
   452  			b, err := io.Copy(io.Discard, &readerCtx{r: c.Request.Body, ctx: ctx})
   453  			if err != nil {
   454  				if errors.Is(err, context.DeadlineExceeded) {
   455  					c.XML(http.StatusBadRequest, lib.S3Error{
   456  						XMLName: xml.Name{Local: "Error"},
   457  						Message: "Your socket connection to the server was not read from or written to within the timeout period. Idle connections will be closed.",
   458  						Code:    "RequestTimeout",
   459  					},
   460  					)
   461  				} else {
   462  					c.Data(http.StatusBadRequest, "text", []byte(err.Error()))
   463  				}
   464  			}
   465  			if c.Request.ContentLength != b {
   466  				c.JSON(http.StatusBadRequest, map[string]interface{}{"message": "Content-Length did not match body"})
   467  			}
   468  			c.Header("Etag", sid.IdBase64())
   469  			c.Status(http.StatusOK)
   470  		} else {
   471  			c.JSON(http.StatusNotFound, nil)
   472  		}
   473  	})
   474  	f.router.GET("/ping", func(c *gin.Context) {
   475  		c.Status(http.StatusOK)
   476  	})
   477  }
   478  
   479  type readerCtx struct {
   480  	ctx context.Context
   481  	r   io.Reader
   482  }
   483  
   484  func (r *readerCtx) Read(p []byte) (n int, err error) {
   485  	if err := r.ctx.Err(); err != nil {
   486  		return 0, err
   487  	}
   488  	return r.r.Read(p)
   489  }
   490  
   491  func (f *MockAPIServer) customResponse(c *gin.Context, model interface{}) bool {
   492  	if mock, ok := f.customResponses[c.Request.URL.Path]; ok {
   493  		return mock(c, model)
   494  	}
   495  	return false
   496  }
   497  
   498  func (f *MockAPIServer) Shutdown() {
   499  	f.Server.Close()
   500  }
   501  
   502  type CustomTransport struct {
   503  	http.Transport
   504  	*url.URL
   505  }
   506  
   507  func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   508  	req.URL.Host = t.URL.Host
   509  	req.URL.Scheme = "http"
   510  
   511  	return t.Transport.RoundTrip(req)
   512  }
   513  
   514  func rangeValue(header http.Header) (start, end int, ok bool) {
   515  	r := header.Get("Range")
   516  	if r == "" {
   517  		return
   518  	}
   519  
   520  	r = strings.SplitN(r, "=", 2)[1]
   521  	ranges := strings.Split(r, "-")
   522  	var err error
   523  	start, err = strconv.Atoi(ranges[0])
   524  	if err != nil {
   525  		return
   526  	}
   527  	end, err = strconv.Atoi(ranges[1])
   528  	if err != nil {
   529  		return
   530  	}
   531  
   532  	ok = true
   533  
   534  	return
   535  }