github.com/LagrangeDev/LagrangeGo@v0.0.0-20240512064304-ad4a85e10cb4/client/internal/highway/bdh.go (about)

     1  package highway
     2  
     3  // from https://github.com/Mrs4s/MiraiGo/tree/master/client/internal/highway/bdh.go
     4  
     5  import (
     6  	"crypto/md5"
     7  	"io"
     8  	"strconv"
     9  	"sync"
    10  	"sync/atomic"
    11  
    12  	ftea "github.com/fumiama/gofastTEA"
    13  	"github.com/pkg/errors"
    14  	"golang.org/x/sync/errgroup"
    15  
    16  	"github.com/LagrangeDev/LagrangeGo/client/packets/pb/service/highway"
    17  	"github.com/LagrangeDev/LagrangeGo/internal/proto"
    18  	"github.com/LagrangeDev/LagrangeGo/utils/binary"
    19  )
    20  
// BlockSize is the size in bytes (256 KiB) of the buffer used to read the
// transaction body; each upload request carries at most one such block.
const BlockSize = 256 * 1024
    22  
// Transaction describes one highway upload: the payload to send and the
// values copied into the request head of every chunk by Build.
type Transaction struct {
	// CommandID is sent as CommandId in the base head of every chunk.
	CommandID uint32
	// Body is the payload; it is read sequentially in BlockSize pieces.
	Body io.Reader
	// Sum is the md5 sum of body, sent as FileMd5.
	Sum []byte
	// Size is the body size in bytes, sent as Filesize. Upload also uses it
	// to decide how many parallel workers to start.
	Size uint64
	// Ticket is sent as ServiceTicket. uploadSingle replaces it when a
	// response carries a new service ticket.
	Ticket []byte
	// LoginSig is sent as BytesLoginSig in the login signature head.
	LoginSig []byte
	// Ext is sent as BytesReqExtendInfo. When Encrypt is set, it is replaced
	// in place by its TEA-encrypted form before uploading.
	Ext []byte
	// Encrypt requests encryption of Ext with the session key.
	Encrypt bool
}
    33  
    34  func (trans *Transaction) encrypt(key []byte) error {
    35  	if !trans.Encrypt {
    36  		return nil
    37  	}
    38  	if len(key) == 0 {
    39  		return errors.New("session key not found. maybe miss some packet?")
    40  	}
    41  	trans.Ext = ftea.NewTeaCipher(key).Encrypt(trans.Ext)
    42  	return nil
    43  }
    44  
    45  func (trans *Transaction) Build(s *Session, offset uint64, length uint32, md5hash []byte) *highway.ReqDataHighwayHead {
    46  	return &highway.ReqDataHighwayHead{
    47  		MsgBaseHead: &highway.DataHighwayHead{
    48  			Version:    1,
    49  			Uin:        proto.Some(strconv.Itoa(int(*s.Uin))),
    50  			Command:    proto.Some(_REQ_CMD_DATA),
    51  			Seq:        proto.Some(s.NextSeq()),
    52  			RetryTimes: proto.Some(uint32(0)),
    53  			AppId:      s.SubAppID,
    54  			DataFlag:   16,
    55  			CommandId:  trans.CommandID,
    56  			// LocaleId:  2052,
    57  		},
    58  		MsgSegHead: &highway.SegHead{
    59  			ServiceId:     proto.Some(uint32(0)),
    60  			Filesize:      trans.Size,
    61  			DataOffset:    proto.Some(offset),
    62  			DataLength:    length,
    63  			RetCode:       proto.Some(uint32(0)),
    64  			ServiceTicket: trans.Ticket,
    65  			Md5:           md5hash,
    66  			FileMd5:       trans.Sum,
    67  			CacheAddr:     proto.Some(uint32(0)),
    68  			CachePort:     proto.Some(uint32(0)),
    69  		},
    70  		BytesReqExtendInfo: trans.Ext,
    71  		MsgLoginSigHead: &highway.LoginSigHead{
    72  			Uint32LoginSigType: 8,
    73  			BytesLoginSig:      trans.LoginSig,
    74  			AppId:              s.AppID,
    75  		},
    76  	}
    77  }
    78  
    79  func (s *Session) uploadSingle(trans *Transaction) ([]byte, error) {
    80  	pc, err := s.selectConn()
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  	defer s.putIdleConn(pc)
    85  
    86  	reader := binary.NewNetworkReader(pc.conn)
    87  	var rspExt []byte
    88  	offset := 0
    89  	chunk := make([]byte, BlockSize)
    90  	for {
    91  		chunk = chunk[:cap(chunk)]
    92  		rl, err := io.ReadFull(trans.Body, chunk)
    93  		if rl == 0 {
    94  			break
    95  		}
    96  		if errors.Is(err, io.ErrUnexpectedEOF) {
    97  			chunk = chunk[:rl]
    98  		}
    99  		ch := md5.Sum(chunk)
   100  		head, _ := proto.Marshal(trans.Build(s, uint64(offset), uint32(rl), ch[:]))
   101  		offset += rl
   102  		buffers := Frame(head, chunk)
   103  		_, err = buffers.WriteTo(pc.conn)
   104  		if err != nil {
   105  			return nil, errors.Wrap(err, "write conn error")
   106  		}
   107  		rspHead, err := readResponse(reader)
   108  		if err != nil {
   109  			return nil, errors.Wrap(err, "highway upload error")
   110  		}
   111  		if rspHead.ErrorCode != 0 {
   112  			return nil, errors.Errorf("upload failed: %d", rspHead.ErrorCode)
   113  		}
   114  		if rspHead.BytesRspExtendInfo != nil {
   115  			rspExt = rspHead.BytesRspExtendInfo
   116  		}
   117  		if rspHead.MsgSegHead != nil && rspHead.MsgSegHead.ServiceTicket != nil {
   118  			trans.Ticket = rspHead.MsgSegHead.ServiceTicket
   119  		}
   120  	}
   121  	return rspExt, nil
   122  }
   123  
   124  func (s *Session) Upload(trans *Transaction) ([]byte, error) {
   125  	// encrypt ext data
   126  	if err := trans.encrypt(s.SessionKey); err != nil {
   127  		return nil, err
   128  	}
   129  
   130  	const maxThreadCount = 4
   131  	threadCount := int(trans.Size) / (6 * BlockSize) // 1 thread upload 1.5 MB
   132  	if threadCount > maxThreadCount {
   133  		threadCount = maxThreadCount
   134  	}
   135  	if threadCount < 2 {
   136  		// single thread upload
   137  		return s.uploadSingle(trans)
   138  	}
   139  
   140  	// pick a address
   141  	// TODO: pick smarter
   142  	pc, err := s.selectConn()
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  	addr := pc.addr
   147  	s.putIdleConn(pc)
   148  
   149  	var (
   150  		rspExt          []byte
   151  		completedThread uint32
   152  		cond            = sync.NewCond(&sync.Mutex{})
   153  		offset          = uint64(0)
   154  		count           = (trans.Size + BlockSize - 1) / BlockSize
   155  		id              = 0
   156  	)
   157  	doUpload := func() error {
   158  		// send signal complete uploading
   159  		defer func() {
   160  			atomic.AddUint32(&completedThread, 1)
   161  			cond.Signal()
   162  		}()
   163  
   164  		// todo: get from pool?
   165  		pc, err := s.connect(addr)
   166  		if err != nil {
   167  			return err
   168  		}
   169  		defer s.putIdleConn(pc)
   170  
   171  		reader := binary.NewNetworkReader(pc.conn)
   172  		chunk := make([]byte, BlockSize)
   173  		for {
   174  			cond.L.Lock() // lock protect reading
   175  			off := offset
   176  			offset += BlockSize
   177  			id++
   178  			last := uint64(id) == count
   179  			if last { // last
   180  				for atomic.LoadUint32(&completedThread) != uint32(threadCount-1) {
   181  					cond.Wait()
   182  				}
   183  			} else if uint64(id) > count {
   184  				cond.L.Unlock()
   185  				break
   186  			}
   187  			chunk = chunk[:BlockSize]
   188  			n, err := io.ReadFull(trans.Body, chunk)
   189  			cond.L.Unlock()
   190  
   191  			if n == 0 {
   192  				break
   193  			}
   194  			if errors.Is(err, io.ErrUnexpectedEOF) {
   195  				chunk = chunk[:n]
   196  			}
   197  			ch := md5.Sum(chunk)
   198  			head, _ := proto.Marshal(trans.Build(s, off, uint32(n), ch[:]))
   199  			buffers := Frame(head, chunk)
   200  			_, err = buffers.WriteTo(pc.conn)
   201  			if err != nil {
   202  				return errors.Wrap(err, "write conn error")
   203  			}
   204  			rspHead, err := readResponse(reader)
   205  			if err != nil {
   206  				return errors.Wrap(err, "highway upload error")
   207  			}
   208  			if rspHead.ErrorCode != 0 {
   209  				return errors.Errorf("upload failed: %d", rspHead.ErrorCode)
   210  			}
   211  			if last && rspHead.BytesRspExtendInfo != nil {
   212  				rspExt = rspHead.BytesRspExtendInfo
   213  			}
   214  		}
   215  		return nil
   216  	}
   217  
   218  	group := errgroup.Group{}
   219  	for i := 0; i < threadCount; i++ {
   220  		group.Go(doUpload)
   221  	}
   222  	return rspExt, group.Wait()
   223  }