github.com/qjfoidnh/BaiduPCS-Go@v0.0.0-20231011165705-caa18a3765f3/requester/multipartreader/multipartreader.go (about)

     1  // Package multipartreader helps you encode large files in MIME multipart format
     2  // without reading the entire content into memory.
     3  package multipartreader
     4  
     5  import (
     6  	"errors"
     7  	"fmt"
     8  	"github.com/qjfoidnh/BaiduPCS-Go/requester/rio"
     9  	"io"
    10  	"mime/multipart"
    11  	"strings"
    12  	"sync"
    13  	"sync/atomic"
    14  )
    15  
    16  type (
    17  	// MultipartReader MIME multipart format
    18  	MultipartReader struct {
    19  		length      int64
    20  		contentType string
    21  		boundary    string
    22  
    23  		formBody  string
    24  		parts     []*part
    25  		part64s   []*part64
    26  		formClose string
    27  
    28  		mu          sync.Mutex
    29  		closed      bool
    30  		multiReader io.Reader
    31  	}
    32  
    33  	part struct {
    34  		form      string
    35  		readerlen rio.ReaderLen
    36  	}
    37  
    38  	part64 struct {
    39  		form        string
    40  		readerlen64 rio.ReaderLen64
    41  	}
    42  )
    43  
    44  // NewMultipartReader 返回初始化的 *MultipartReader
    45  func NewMultipartReader() (mr *MultipartReader) {
    46  	builder := &strings.Builder{}
    47  	writer := multipart.NewWriter(builder)
    48  	mr = &MultipartReader{
    49  		contentType: writer.FormDataContentType(),
    50  		boundary:    writer.Boundary(),
    51  	}
    52  
    53  	mr.length += int64(builder.Len())
    54  	mr.formBody = builder.String()
    55  	return
    56  }
    57  
    58  // AddFormFeild 增加 form 表单
    59  func (mr *MultipartReader) AddFormFeild(fieldname string, readerlen rio.ReaderLen) {
    60  	if readerlen == nil {
    61  		return
    62  	}
    63  
    64  	mpart := &part{
    65  		form:      fmt.Sprintf("--%s\r\nContent-Disposition: form-data; name=\"%s\"\r\n\r\n", mr.boundary, fieldname),
    66  		readerlen: readerlen,
    67  	}
    68  	atomic.AddInt64(&mr.length, int64(len(mpart.form)+mpart.readerlen.Len()))
    69  	mr.parts = append(mr.parts, mpart)
    70  }
    71  
    72  // AddFormFile 增加 form 文件表单
    73  func (mr *MultipartReader) AddFormFile(fieldname, filename string, readerlen64 rio.ReaderLen64) {
    74  	if readerlen64 == nil {
    75  		return
    76  	}
    77  
    78  	mpart64 := &part64{
    79  		form:        fmt.Sprintf("--%s\r\nContent-Disposition: form-data; name=\"%s\"; filename=\"%s\"\r\n\r\n", mr.boundary, fieldname, filename),
    80  		readerlen64: readerlen64,
    81  	}
    82  	atomic.AddInt64(&mr.length, int64(len(mpart64.form))+mpart64.readerlen64.Len())
    83  	mr.part64s = append(mr.part64s, mpart64)
    84  }
    85  
    86  //CloseMultipart 关闭multipartreader
    87  func (mr *MultipartReader) CloseMultipart() error {
    88  	mr.mu.Lock()
    89  	defer mr.mu.Unlock()
    90  	if mr.closed {
    91  		return errors.New("multipartreader already closed")
    92  	}
    93  
    94  	mr.formClose = "\r\n--" + mr.boundary + "--\r\n"
    95  	atomic.AddInt64(&mr.length, int64(len(mr.formClose)))
    96  
    97  	numReaders := 0
    98  	if mr.formBody != "" {
    99  		numReaders++
   100  	}
   101  	numReaders += 2*len(mr.parts) + 2*len(mr.part64s)
   102  	if mr.formClose != "" {
   103  		numReaders++
   104  	}
   105  
   106  	readers := make([]io.Reader, 0, numReaders)
   107  	readers = append(readers, strings.NewReader(mr.formBody))
   108  	for k := range mr.parts {
   109  		readers = append(readers, strings.NewReader(mr.parts[k].form), mr.parts[k].readerlen)
   110  	}
   111  	for k := range mr.part64s {
   112  		readers = append(readers, strings.NewReader(mr.part64s[k].form), mr.part64s[k].readerlen64)
   113  	}
   114  	readers = append(readers, strings.NewReader(mr.formClose))
   115  	mr.multiReader = io.MultiReader(readers...)
   116  
   117  	mr.closed = true
   118  	return nil
   119  }
   120  
   121  //ContentType 返回Content-Type
   122  func (mr *MultipartReader) ContentType() string {
   123  	return mr.contentType
   124  }
   125  
   126  func (mr *MultipartReader) Read(p []byte) (n int, err error) {
   127  	if !mr.closed {
   128  		return 0, errors.New("multipartreader not closed")
   129  	}
   130  	n, err = mr.multiReader.Read(p)
   131  	return n, err
   132  }
   133  
   134  // Len 返回表单内容总长度
   135  func (mr *MultipartReader) Len() int64 {
   136  	return atomic.LoadInt64(&mr.length)
   137  }