github.com/Files-com/files-sdk-go/v3@v3.1.81/file/mockserver.go (about)

     1  package file
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"encoding/xml"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"net/url"
    13  	"path/filepath"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"testing"
    18  	"time"
    19  
    20  	files_sdk "github.com/Files-com/files-sdk-go/v3"
    21  	"github.com/Files-com/files-sdk-go/v3/lib"
    22  	"github.com/chilts/sid"
    23  	"github.com/gin-gonic/gin"
    24  	"github.com/samber/lo"
    25  )
    26  
    27  type randomReader struct {
    28  	n int
    29  }
    30  
    31  func (r *randomReader) Read(p []byte) (int, error) {
    32  	if r.n <= 0 {
    33  		return 0, io.EOF
    34  	}
    35  	if len(p) > r.n {
    36  		p = p[:r.n]
    37  	}
    38  
    39  	_, err := rand.Read(p)
    40  	if err != nil {
    41  		return 0, err
    42  	}
    43  
    44  	r.n -= len(p)
    45  	return len(p), nil
    46  }
    47  
// CustomResponse describes a canned HTTP response: a status code, a raw body
// and the body's content type.
// NOTE(review): nothing in this file references CustomResponse; it is
// presumably consumed by tests elsewhere in the package — confirm.
type CustomResponse struct {
	// Status is the HTTP status code to send.
	Status int
	// Body is the raw response payload.
	Body []byte
	// ContentType is the value for the Content-Type header.
	ContentType string
}
    53  
// MockAPIServer is an in-process stand-in for the Files.com REST API plus its
// download and upload endpoints, served by an httptest.Server over a gin
// router. Call Do to initialise its state and start the server, and Shutdown
// to stop it.
type MockAPIServer struct {
	// router is the gin engine created by Do; all mock routes live on it.
	router *gin.Engine
	// Addr is not read or written by any code in this file.
	Addr string
	*httptest.Server
	// downloads holds download sessions keyed by the id embedded in a file's
	// DownloadUri.
	downloads *lib.Map[download]
	// MockFiles maps a path (without a leading "/") to the file served for it.
	MockFiles map[string]mockFile
	// customResponses maps a request URL path to a handler registered through
	// MockRoute; guarded by traceMutex.
	customResponses map[string]func(ctx *gin.Context, model interface{}) bool
	*testing.T
	// TrackRequest maps a gin route pattern (c.FullPath()) to the URLs of the
	// requests recorded for it, in the order trackRequest saw them; guarded by
	// traceMutex.
	TrackRequest map[string][]string
	// traceMutex guards TrackRequest and customResponses.
	traceMutex *sync.Mutex
}
    65  
    66  type download struct {
    67  	Id string
    68  	mockFile
    69  	Requests *lib.Map[files_sdk.ResponseError]
    70  }
    71  
    72  func (d download) init() download {
    73  	d.Requests = &lib.Map[files_sdk.ResponseError]{}
    74  	return d
    75  }
    76  
    77  type mockFile struct {
    78  	files_sdk.File
    79  	RealSize *int64
    80  	SizeTrust
    81  	ForceRequestStatus  string
    82  	ForceRequestMessage string
    83  	ServerBytesSent     *int64
    84  	MaxConnections      int
    85  	MaxConnectionsMutex *sync.Mutex
    86  }
    87  
    88  func (m mockFile) Completed() string {
    89  	if m.ForceRequestStatus != "" {
    90  		return m.ForceRequestStatus
    91  	}
    92  	return "completed"
    93  }
    94  
    95  type TestLogger struct {
    96  	*testing.T
    97  }
    98  
    99  func (t TestLogger) Printf(format string, args ...any) {
   100  	t.T.Logf(format, args...)
   101  }
   102  
   103  func (t TestLogger) Write(p []byte) (n int, err error) {
   104  	t.T.Log(string(p))
   105  	return len(p), nil
   106  }
   107  
   108  func (f *MockAPIServer) Do() *MockAPIServer {
   109  	gin.SetMode(gin.TestMode)
   110  	f.MockFiles = make(map[string]mockFile)
   111  	f.customResponses = make(map[string]func(ctx *gin.Context, model interface{}) bool)
   112  	f.TrackRequest = make(map[string][]string)
   113  	f.traceMutex = &sync.Mutex{}
   114  	f.downloads = &lib.Map[download]{}
   115  	f.router = gin.New()
   116  	f.router.Use(gin.LoggerWithWriter(TestLogger{f.T}))
   117  	f.Routes()
   118  	f.Server = httptest.NewServer(f.router)
   119  
   120  	return f
   121  }
   122  
   123  func (f *MockAPIServer) MockRoute(path string, call func(ctx *gin.Context, model interface{}) bool) {
   124  	f.traceMutex.Lock()
   125  	defer f.traceMutex.Unlock()
   126  	f.customResponses[path] = call
   127  }
   128  
   129  func (f *MockAPIServer) Client() *Client {
   130  	client := &Client{Config: files_sdk.Config{}.Init()}
   131  	httpClient := http.Client{}
   132  	if u, err := url.Parse(f.Server.URL); err != nil {
   133  		f.T.Fatal(err.Error())
   134  	} else {
   135  		httpClient.Transport = &CustomTransport{URL: u}
   136  	}
   137  	client.Config.Logger = TestLogger{f.T}
   138  	client.Config = client.Config.SetCustomClient(&httpClient)
   139  	return client
   140  }
   141  
   142  func (f *MockAPIServer) GetFile(file mockFile) (r io.Reader, contentLengthOk bool, contentLength int64, realSize int64, err error) {
   143  	if file.SizeTrust == NullSizeTrust || file.SizeTrust == TrustedSizeValue {
   144  		contentLengthOk = true
   145  	}
   146  
   147  	contentLength = file.File.Size
   148  
   149  	if file.RealSize != nil {
   150  		realSize = *file.RealSize
   151  	} else {
   152  		realSize = contentLength
   153  	}
   154  	r = &randomReader{int(realSize)}
   155  	return
   156  }
   157  
   158  func (f *MockAPIServer) trackRequest(c *gin.Context) {
   159  	f.traceMutex.Lock()
   160  	defer f.traceMutex.Unlock()
   161  	f.TrackRequest[c.FullPath()] = append(f.TrackRequest[c.FullPath()], c.Request.URL.String())
   162  }
   163  
   164  func (f *MockAPIServer) GetRouter() *gin.Engine {
   165  	return f.router
   166  }
   167  
   168  func (f *MockAPIServer) Routes() {
   169  	//Download Context
   170  	f.router.GET("/api/rest/v1/files/*path", func(c *gin.Context) {
   171  		f.trackRequest(c)
   172  		path := strings.TrimPrefix(c.Param("path"), "/")
   173  		if f.customResponse(c, nil) {
   174  			return
   175  		}
   176  		file, ok := f.MockFiles[path]
   177  		if ok {
   178  			if file.Path == "" {
   179  				file.Path = path
   180  				file.DisplayName = filepath.Base(path)
   181  			}
   182  			downloadId := sid.IdHex()
   183  			f.downloads.Store(downloadId, download{Id: downloadId, mockFile: file}.init())
   184  			file.DownloadUri = lib.UrlJoinNoEscape("http://localhost:8080/download", downloadId)
   185  
   186  			if file.MaxConnections != 0 {
   187  				file.MaxConnectionsMutex = &sync.Mutex{}
   188  			}
   189  
   190  			c.JSON(http.StatusOK, file.File)
   191  		} else {
   192  			c.JSON(http.StatusNotFound, nil)
   193  		}
   194  	})
   195  	f.router.GET("/api/rest/v1/folders/*path", func(c *gin.Context) {
   196  		f.trackRequest(c)
   197  		path := strings.TrimPrefix(c.Param("path"), "/")
   198  
   199  		if f.customResponse(c, nil) {
   200  			return
   201  		}
   202  
   203  		var files []files_sdk.File
   204  		for k, v := range f.MockFiles {
   205  			dir, _ := filepath.Split(k)
   206  			if lib.NormalizeForComparison(filepath.Clean(path)) == lib.NormalizeForComparison(filepath.Clean(dir)) {
   207  				if v.Path == "" {
   208  					v.Path = k
   209  					v.DisplayName = filepath.Base(k)
   210  				}
   211  				files = append(files, v.File)
   212  			}
   213  		}
   214  
   215  		if len(files) > 0 {
   216  			c.JSON(http.StatusOK, files)
   217  		} else {
   218  			c.JSON(http.StatusNotFound, nil)
   219  		}
   220  	})
   221  	f.router.GET("/api/rest/v1/file_actions/metadata/*path", func(c *gin.Context) {
   222  		f.trackRequest(c)
   223  		path := strings.TrimPrefix(c.Param("path"), "/")
   224  
   225  		if f.customResponse(c, nil) {
   226  			return
   227  		}
   228  
   229  		file, ok := f.MockFiles[path]
   230  		if ok {
   231  			if file.Path == "" {
   232  				file.Path = path
   233  				file.DisplayName = filepath.Base(path)
   234  			}
   235  			c.JSON(http.StatusOK, file.File)
   236  		} else {
   237  			c.JSON(http.StatusNotFound, nil)
   238  		}
   239  	})
   240  	f.router.GET("/download/:download_id/:download_request_id", func(c *gin.Context) {
   241  		f.trackRequest(c)
   242  		downloadId := c.Param("download_id")
   243  		downloadJob, downloadOk := f.downloads.Load(downloadId)
   244  		if !downloadOk {
   245  			c.JSON(http.StatusNotFound, nil)
   246  			return
   247  		}
   248  		downloadRequestJob, requestOk := downloadJob.Requests.Load(c.Param("download_request_id"))
   249  		if requestOk {
   250  			c.JSON(http.StatusOK, downloadRequestJob)
   251  		} else {
   252  			c.JSON(http.StatusNotFound, nil)
   253  		}
   254  
   255  	})
   256  	f.router.GET("/download/:download_id", func(c *gin.Context) {
   257  		f.trackRequest(c)
   258  		downloadJob, ok := f.downloads.Load(c.Param("download_id"))
   259  		if !ok {
   260  			c.JSON(http.StatusNotFound, nil)
   261  			return
   262  		}
   263  
   264  		if downloadJob.mockFile.MaxConnectionsMutex != nil {
   265  			downloadJob.mockFile.MaxConnectionsMutex.Lock()
   266  		}
   267  
   268  		start, end, okRange := rangeValue(c.Request.Header)
   269  
   270  		reader, contentLengthOk, contentLength, realSize, err := f.GetFile(downloadJob.mockFile)
   271  		if err != nil {
   272  			panic(err)
   273  		}
   274  		status := http.StatusOK
   275  		if okRange {
   276  			if realSize < int64(start) {
   277  				reader = &randomReader{0}
   278  			} else {
   279  				reader = &randomReader{(lo.Min[int]([]int{int(realSize - 1), end}) - start) + 1}
   280  			}
   281  			status = http.StatusPartialContent
   282  		}
   283  		downloadRequestId := sid.IdHex()
   284  		if downloadJob.mockFile.MaxConnections == 0 {
   285  			c.Header("X-Files-Max-Connections", "*")
   286  		} else {
   287  			c.Header("X-Files-Max-Connections", fmt.Sprintf("%v", downloadJob.mockFile.MaxConnections))
   288  		}
   289  
   290  		c.Header("X-Files-Download-Request-Id", downloadRequestId)
   291  		responseError := files_sdk.ResponseError{ErrorMessage: downloadJob.ForceRequestMessage}
   292  		extraHeaders := map[string]string{}
   293  		if contentLengthOk {
   294  			if okRange && contentLength < int64(end) {
   295  				c.Status(http.StatusBadRequest)
   296  			}
   297  
   298  			if okRange {
   299  				extraHeaders["Content-Range"] = fmt.Sprintf("%v-%v/%v", start, end, contentLength)
   300  				contentLength = int64(end-start) + 1
   301  			}
   302  
   303  			c.DataFromReader(status, contentLength, "application/zip, application/octet-stream", reader, extraHeaders)
   304  			downloadJob.Requests.Store(downloadRequestId, responseError)
   305  			if downloadJob.mockFile.MaxConnectionsMutex != nil {
   306  				downloadJob.mockFile.MaxConnectionsMutex.Unlock()
   307  			}
   308  		} else {
   309  			finish := func() {
   310  				if downloadJob.ServerBytesSent != nil {
   311  					responseError.Data.BytesTransferred = *downloadJob.ServerBytesSent
   312  				}
   313  				downloadJob.Requests.Store(downloadRequestId, responseError)
   314  				if downloadJob.mockFile.MaxConnectionsMutex != nil {
   315  					downloadJob.mockFile.MaxConnectionsMutex.Unlock()
   316  				}
   317  			}
   318  			if okRange {
   319  				c.Header("Content-Range", fmt.Sprintf("%v-%v/*", start, end))
   320  			}
   321  			c.Status(status)
   322  			c.Stream(func(w io.Writer) bool {
   323  				buf := make([]byte, 1024*1024)
   324  
   325  				n, err := reader.Read(buf)
   326  				if err == io.EOF {
   327  					responseError.Data.Status = downloadJob.Completed()
   328  					finish()
   329  					return false
   330  				}
   331  				if err != nil && err != io.EOF {
   332  					responseError.Data.Status = "errored"
   333  					finish()
   334  					return false
   335  				}
   336  
   337  				wn, err := w.Write(buf[:n])
   338  				if err != nil {
   339  					responseError.Data.Status = "errored"
   340  					finish()
   341  					return false
   342  				}
   343  
   344  				responseError.Data.BytesTransferred += int64(wn)
   345  
   346  				if err == io.EOF {
   347  					responseError.Data.Status = "errored"
   348  					finish()
   349  					return false
   350  				}
   351  
   352  				return true
   353  			})
   354  		}
   355  	})
   356  	f.router.HEAD("/download/:download_id", func(c *gin.Context) {
   357  		f.trackRequest(c)
   358  		downloadJob, ok := f.downloads.Load(c.Param("download_id"))
   359  		if !ok {
   360  			c.JSON(http.StatusNotFound, nil)
   361  			return
   362  		}
   363  		_, contentLengthOk, contentLength, _, err := f.GetFile(downloadJob.mockFile)
   364  		if err != nil {
   365  			panic(err)
   366  		}
   367  		if contentLengthOk {
   368  			c.Header("Content-Length", fmt.Sprintf("%v", contentLength))
   369  		}
   370  		if downloadJob.mockFile.MaxConnections == 0 {
   371  			c.Header("X-Files-Max-Connections", "*")
   372  		} else {
   373  			c.Header("X-Files-Max-Connections", fmt.Sprintf("%v", downloadJob.mockFile.MaxConnections))
   374  		}
   375  		c.Status(http.StatusOK)
   376  	})
   377  	//	Upload Context
   378  	f.router.POST("/api/rest/v1/files/*path", func(c *gin.Context) {
   379  		f.trackRequest(c)
   380  		path := strings.TrimPrefix(c.Param("path"), "/")
   381  
   382  		var fileCreate files_sdk.FileCreateParams
   383  
   384  		if err := c.BindJSON(&fileCreate); err != nil {
   385  			c.JSON(http.StatusBadRequest, map[string]interface{}{"message": err.Error()})
   386  		}
   387  
   388  		if f.customResponse(c, fileCreate) {
   389  			return
   390  		}
   391  
   392  		file, ok := f.MockFiles[path]
   393  		if ok {
   394  			if file.Path == "" {
   395  				file.Path = path
   396  				file.DisplayName = filepath.Base(path)
   397  			}
   398  			c.JSON(http.StatusOK, file)
   399  		} else {
   400  			c.JSON(http.StatusNotFound, nil)
   401  		}
   402  	})
   403  	f.router.POST("/api/rest/v1/file_actions/begin_upload/*path", func(c *gin.Context) {
   404  		f.trackRequest(c)
   405  		path := strings.TrimPrefix(c.Param("path"), "/")
   406  
   407  		var beginUpload files_sdk.FileBeginUploadParams
   408  
   409  		if err := c.BindJSON(&beginUpload); err != nil {
   410  			c.JSON(http.StatusBadRequest, map[string]interface{}{"message": err.Error()})
   411  		}
   412  
   413  		beginUpload.Path = path
   414  
   415  		if f.customResponse(c, beginUpload) {
   416  			return
   417  		}
   418  
   419  		_, ok := f.MockFiles[path]
   420  		_, parentOk := f.MockFiles[filepath.Dir(path)]
   421  
   422  		if !ok && (filepath.Dir(path) == "." || parentOk || *beginUpload.MkdirParents == true) {
   423  			f.MockFiles[path] = mockFile{File: files_sdk.File{Path: path, DisplayName: filepath.Base(path), Size: beginUpload.Size}}
   424  			ok = true
   425  		}
   426  
   427  		if beginUpload.Part == 0 {
   428  			beginUpload.Part = 1
   429  		}
   430  
   431  		if ok {
   432  			c.JSON(http.StatusOK, files_sdk.FileUploadPartCollection{
   433  				files_sdk.FileUploadPart{
   434  					HttpMethod:    "POST",
   435  					Path:          path,
   436  					UploadUri:     fmt.Sprintf("%v?part_number=%v", lib.UrlJoinNoEscape(f.Server.URL, "upload", path), beginUpload.Part),
   437  					ParallelParts: lib.Bool(true),
   438  					Expires:       time.Now().Add(time.Hour).Format(time.RFC3339),
   439  					PartNumber:    beginUpload.Part,
   440  				},
   441  			})
   442  		} else {
   443  			c.JSON(http.StatusNotFound, nil)
   444  		}
   445  	})
   446  	f.router.POST("upload/*path", func(c *gin.Context) {
   447  		f.trackRequest(c)
   448  		path := strings.TrimPrefix(c.Param("path"), "/")
   449  
   450  		if f.customResponse(c, nil) {
   451  			return
   452  		}
   453  
   454  		_, ok := f.MockFiles[path]
   455  		if ok {
   456  			ctx, cancel := context.WithTimeout(c, time.Millisecond*100)
   457  			defer cancel()
   458  
   459  			b, err := io.Copy(io.Discard, &readerCtx{r: c.Request.Body, ctx: ctx})
   460  			if err != nil {
   461  				if errors.Is(err, context.DeadlineExceeded) {
   462  					c.XML(http.StatusBadRequest, lib.S3Error{
   463  						XMLName: xml.Name{Local: "Error"},
   464  						Message: "Your socket connection to the server was not read from or written to within the timeout period. Idle connections will be closed.",
   465  						Code:    "RequestTimeout",
   466  					},
   467  					)
   468  				} else {
   469  					c.Data(http.StatusBadRequest, "text", []byte(err.Error()))
   470  				}
   471  			}
   472  			if c.Request.ContentLength != b {
   473  				c.JSON(http.StatusBadRequest, map[string]interface{}{"message": "Content-Length did not match body"})
   474  			}
   475  			c.Header("Etag", sid.IdBase64())
   476  			c.Status(http.StatusOK)
   477  		} else {
   478  			c.JSON(http.StatusNotFound, nil)
   479  		}
   480  	})
   481  	f.router.GET("/ping", func(c *gin.Context) {
   482  		c.Status(http.StatusOK)
   483  	})
   484  }
   485  
   486  type readerCtx struct {
   487  	ctx context.Context
   488  	r   io.Reader
   489  }
   490  
   491  func (r *readerCtx) Read(p []byte) (n int, err error) {
   492  	if err := r.ctx.Err(); err != nil {
   493  		return 0, err
   494  	}
   495  	return r.r.Read(p)
   496  }
   497  
   498  func (f *MockAPIServer) customResponse(c *gin.Context, model interface{}) bool {
   499  	f.traceMutex.Lock()
   500  	defer f.traceMutex.Unlock()
   501  	if mock, ok := f.customResponses[c.Request.URL.Path]; ok {
   502  		return mock(c, model)
   503  	}
   504  	return false
   505  }
   506  
   507  func (f *MockAPIServer) Shutdown() {
   508  	f.Server.Close()
   509  }
   510  
   511  type CustomTransport struct {
   512  	http.Transport
   513  	*url.URL
   514  }
   515  
   516  func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   517  	req.URL.Host = t.URL.Host
   518  	req.URL.Scheme = "http"
   519  
   520  	return t.Transport.RoundTrip(req)
   521  }
   522  
   523  func rangeValue(header http.Header) (start, end int, ok bool) {
   524  	r := header.Get("Range")
   525  	if r == "" {
   526  		return
   527  	}
   528  
   529  	r = strings.SplitN(r, "=", 2)[1]
   530  	ranges := strings.Split(r, "-")
   531  	var err error
   532  	start, err = strconv.Atoi(ranges[0])
   533  	if err != nil {
   534  		return
   535  	}
   536  	end, err = strconv.Atoi(ranges[1])
   537  	if err != nil {
   538  		return
   539  	}
   540  
   541  	ok = true
   542  
   543  	return
   544  }