github.com/Mrs4s/MiraiGo@v0.0.0-20240226124653-54bdd873e3fe/client/internal/highway/bdh.go (about)

     1  package highway
     2  
     3  import (
     4  	"crypto/md5"
     5  	"io"
     6  	"sync"
     7  	"sync/atomic"
     8  
     9  	"github.com/pkg/errors"
    10  	"golang.org/x/sync/errgroup"
    11  
    12  	"github.com/Mrs4s/MiraiGo/binary"
    13  	"github.com/Mrs4s/MiraiGo/client/pb"
    14  	"github.com/Mrs4s/MiraiGo/internal/proto"
    15  )
    16  
// Transaction describes a single highway upload: the payload to send plus the
// values copied into the request header of every chunk sent for it.
type Transaction struct {
	CommandID int32     // sent as DataHighwayHead.CommandId
	Body      io.Reader // payload, read sequentially in chunks
	Sum       []byte    // md5 sum of body (sent as SegHead.FileMd5)
	Size      int64     // body size (sent as SegHead.Filesize; also picks the thread count)
	Ticket    []byte    // service ticket; uploadSingle replaces it with one returned by the server
	Ext       []byte    // request extend info; replaced by its encryption when Encrypt is set
	Encrypt   bool      // whether Ext is encrypted with the session key before uploading
}
    26  
    27  func (bdh *Transaction) encrypt(key []byte) error {
    28  	if !bdh.Encrypt {
    29  		return nil
    30  	}
    31  	if len(key) == 0 {
    32  		return errors.New("session key not found. maybe miss some packet?")
    33  	}
    34  	bdh.Ext = binary.NewTeaCipher(key).Encrypt(bdh.Ext)
    35  	return nil
    36  }
    37  
    38  func (s *Session) uploadSingle(trans Transaction) ([]byte, error) {
    39  	pc, err := s.selectConn()
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  	defer s.putIdleConn(pc)
    44  
    45  	reader := binary.NewNetworkReader(pc.conn)
    46  	const chunkSize = 128 * 1024
    47  	var rspExt []byte
    48  	offset := 0
    49  	chunk := make([]byte, chunkSize)
    50  	for {
    51  		chunk = chunk[:cap(chunk)]
    52  		rl, err := io.ReadFull(trans.Body, chunk)
    53  		if rl == 0 {
    54  			break
    55  		}
    56  		if errors.Is(err, io.ErrUnexpectedEOF) {
    57  			chunk = chunk[:rl]
    58  		}
    59  		ch := md5.Sum(chunk)
    60  		head, _ := proto.Marshal(&pb.ReqDataHighwayHead{
    61  			MsgBasehead: &pb.DataHighwayHead{
    62  				Version:   1,
    63  				Uin:       s.Uin,
    64  				Command:   _REQ_CMD_DATA,
    65  				Seq:       s.nextSeq(),
    66  				Appid:     s.AppID,
    67  				Dataflag:  4096,
    68  				CommandId: trans.CommandID,
    69  				LocaleId:  2052,
    70  			},
    71  			MsgSeghead: &pb.SegHead{
    72  				Filesize:      trans.Size,
    73  				Dataoffset:    int64(offset),
    74  				Datalength:    int32(rl),
    75  				Serviceticket: trans.Ticket,
    76  				Md5:           ch[:],
    77  				FileMd5:       trans.Sum,
    78  			},
    79  			ReqExtendinfo: trans.Ext,
    80  		})
    81  		offset += rl
    82  		buffers := frame(head, chunk)
    83  		_, err = buffers.WriteTo(pc.conn)
    84  		if err != nil {
    85  			return nil, errors.Wrap(err, "write conn error")
    86  		}
    87  		rspHead, err := readResponse(reader)
    88  		if err != nil {
    89  			return nil, errors.Wrap(err, "highway upload error")
    90  		}
    91  		if rspHead.ErrorCode != 0 {
    92  			return nil, errors.Errorf("upload failed: %d", rspHead.ErrorCode)
    93  		}
    94  		if rspHead.RspExtendinfo != nil {
    95  			rspExt = rspHead.RspExtendinfo
    96  		}
    97  		if rspHead.MsgSeghead != nil && rspHead.MsgSeghead.Serviceticket != nil {
    98  			trans.Ticket = rspHead.MsgSeghead.Serviceticket
    99  		}
   100  	}
   101  	return rspExt, nil
   102  }
   103  
   104  func (s *Session) Upload(trans Transaction) ([]byte, error) {
   105  	// encrypt ext data
   106  	if err := trans.encrypt(s.SessionKey); err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	const maxThreadCount = 4
   111  	threadCount := int(trans.Size) / (3 * 512 * 1024) // 1 thread upload 1.5 MB
   112  	if threadCount > maxThreadCount {
   113  		threadCount = maxThreadCount
   114  	}
   115  	if threadCount < 2 {
   116  		// single thread upload
   117  		return s.uploadSingle(trans)
   118  	}
   119  
   120  	// pick a address
   121  	// TODO: pick smarter
   122  	pc, err := s.selectConn()
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  	addr := pc.addr
   127  	s.putIdleConn(pc)
   128  
   129  	const blockSize int64 = 256 * 1024
   130  	var (
   131  		rspExt          []byte
   132  		completedThread uint32
   133  		cond            = sync.NewCond(&sync.Mutex{})
   134  		offset          = int64(0)
   135  		count           = (trans.Size + blockSize - 1) / blockSize
   136  		id              = 0
   137  	)
   138  	doUpload := func() error {
   139  		// send signal complete uploading
   140  		defer func() {
   141  			atomic.AddUint32(&completedThread, 1)
   142  			cond.Signal()
   143  		}()
   144  
   145  		// todo: get from pool?
   146  		pc, err := s.connect(addr)
   147  		if err != nil {
   148  			return err
   149  		}
   150  		defer s.putIdleConn(pc)
   151  
   152  		reader := binary.NewNetworkReader(pc.conn)
   153  		chunk := make([]byte, blockSize)
   154  		for {
   155  			cond.L.Lock() // lock protect reading
   156  			off := offset
   157  			offset += blockSize
   158  			id++
   159  			last := int64(id) == count
   160  			if last { // last
   161  				for atomic.LoadUint32(&completedThread) != uint32(threadCount-1) {
   162  					cond.Wait()
   163  				}
   164  			} else if int64(id) > count {
   165  				cond.L.Unlock()
   166  				break
   167  			}
   168  			chunk = chunk[:blockSize]
   169  			n, err := io.ReadFull(trans.Body, chunk)
   170  			cond.L.Unlock()
   171  
   172  			if n == 0 {
   173  				break
   174  			}
   175  			if errors.Is(err, io.ErrUnexpectedEOF) {
   176  				chunk = chunk[:n]
   177  			}
   178  			ch := md5.Sum(chunk)
   179  			head, _ := proto.Marshal(&pb.ReqDataHighwayHead{
   180  				MsgBasehead: &pb.DataHighwayHead{
   181  					Version:   1,
   182  					Uin:       s.Uin,
   183  					Command:   _REQ_CMD_DATA,
   184  					Seq:       s.nextSeq(),
   185  					Appid:     s.AppID,
   186  					Dataflag:  4096,
   187  					CommandId: trans.CommandID,
   188  					LocaleId:  2052,
   189  				},
   190  				MsgSeghead: &pb.SegHead{
   191  					Filesize:      trans.Size,
   192  					Dataoffset:    off,
   193  					Datalength:    int32(n),
   194  					Serviceticket: trans.Ticket,
   195  					Md5:           ch[:],
   196  					FileMd5:       trans.Sum,
   197  				},
   198  				ReqExtendinfo: trans.Ext,
   199  			})
   200  			buffers := frame(head, chunk)
   201  			_, err = buffers.WriteTo(pc.conn)
   202  			if err != nil {
   203  				return errors.Wrap(err, "write conn error")
   204  			}
   205  			rspHead, err := readResponse(reader)
   206  			if err != nil {
   207  				return errors.Wrap(err, "highway upload error")
   208  			}
   209  			if rspHead.ErrorCode != 0 {
   210  				return errors.Errorf("upload failed: %d", rspHead.ErrorCode)
   211  			}
   212  			if last && rspHead.RspExtendinfo != nil {
   213  				rspExt = rspHead.RspExtendinfo
   214  			}
   215  		}
   216  		return nil
   217  	}
   218  
   219  	group := errgroup.Group{}
   220  	for i := 0; i < threadCount; i++ {
   221  		group.Go(doUpload)
   222  	}
   223  	return rspExt, group.Wait()
   224  }