github.com/LagrangeDev/LagrangeGo@v0.0.0-20240512064304-ad4a85e10cb4/client/internal/highway/bdh.go (about) 1 package highway 2 3 // from https://github.com/Mrs4s/MiraiGo/tree/master/client/internal/highway/bdh.go 4 5 import ( 6 "crypto/md5" 7 "io" 8 "strconv" 9 "sync" 10 "sync/atomic" 11 12 ftea "github.com/fumiama/gofastTEA" 13 "github.com/pkg/errors" 14 "golang.org/x/sync/errgroup" 15 16 "github.com/LagrangeDev/LagrangeGo/client/packets/pb/service/highway" 17 "github.com/LagrangeDev/LagrangeGo/internal/proto" 18 "github.com/LagrangeDev/LagrangeGo/utils/binary" 19 ) 20 21 const BlockSize = 256 * 1024 22 23 type Transaction struct { 24 CommandID uint32 25 Body io.Reader 26 Sum []byte // md5 sum of body 27 Size uint64 // body size 28 Ticket []byte 29 LoginSig []byte 30 Ext []byte 31 Encrypt bool 32 } 33 34 func (trans *Transaction) encrypt(key []byte) error { 35 if !trans.Encrypt { 36 return nil 37 } 38 if len(key) == 0 { 39 return errors.New("session key not found. maybe miss some packet?") 40 } 41 trans.Ext = ftea.NewTeaCipher(key).Encrypt(trans.Ext) 42 return nil 43 } 44 45 func (trans *Transaction) Build(s *Session, offset uint64, length uint32, md5hash []byte) *highway.ReqDataHighwayHead { 46 return &highway.ReqDataHighwayHead{ 47 MsgBaseHead: &highway.DataHighwayHead{ 48 Version: 1, 49 Uin: proto.Some(strconv.Itoa(int(*s.Uin))), 50 Command: proto.Some(_REQ_CMD_DATA), 51 Seq: proto.Some(s.NextSeq()), 52 RetryTimes: proto.Some(uint32(0)), 53 AppId: s.SubAppID, 54 DataFlag: 16, 55 CommandId: trans.CommandID, 56 // LocaleId: 2052, 57 }, 58 MsgSegHead: &highway.SegHead{ 59 ServiceId: proto.Some(uint32(0)), 60 Filesize: trans.Size, 61 DataOffset: proto.Some(offset), 62 DataLength: length, 63 RetCode: proto.Some(uint32(0)), 64 ServiceTicket: trans.Ticket, 65 Md5: md5hash, 66 FileMd5: trans.Sum, 67 CacheAddr: proto.Some(uint32(0)), 68 CachePort: proto.Some(uint32(0)), 69 }, 70 BytesReqExtendInfo: trans.Ext, 71 MsgLoginSigHead: &highway.LoginSigHead{ 72 Uint32LoginSigType: 8, 73 BytesLoginSig: trans.LoginSig, 74 AppId: s.AppID, 75 }, 76 } 77 } 78 79 func (s *Session) uploadSingle(trans *Transaction) ([]byte, error) { 80 pc, err := s.selectConn() 81 if err != nil { 82 return nil, err 83 } 84 defer s.putIdleConn(pc) 85 86 reader := binary.NewNetworkReader(pc.conn) 87 var rspExt []byte 88 offset := 0 89 chunk := make([]byte, BlockSize) 90 for { 91 chunk = chunk[:cap(chunk)] 92 rl, err := io.ReadFull(trans.Body, chunk) 93 if rl == 0 { 94 break 95 } 96 if errors.Is(err, io.ErrUnexpectedEOF) { 97 chunk = chunk[:rl] 98 } 99 ch := md5.Sum(chunk) 100 head, _ := proto.Marshal(trans.Build(s, uint64(offset), uint32(rl), ch[:])) 101 offset += rl 102 buffers := Frame(head, chunk) 103 _, err = buffers.WriteTo(pc.conn) 104 if err != nil { 105 return nil, errors.Wrap(err, "write conn error") 106 } 107 rspHead, err := readResponse(reader) 108 if err != nil { 109 return nil, errors.Wrap(err, "highway upload error") 110 } 111 if rspHead.ErrorCode != 0 { 112 return nil, errors.Errorf("upload failed: %d", rspHead.ErrorCode) 113 } 114 if rspHead.BytesRspExtendInfo != nil { 115 rspExt = rspHead.BytesRspExtendInfo 116 } 117 if rspHead.MsgSegHead != nil && rspHead.MsgSegHead.ServiceTicket != nil { 118 trans.Ticket = rspHead.MsgSegHead.ServiceTicket 119 } 120 } 121 return rspExt, nil 122 } 123 124 func (s *Session) Upload(trans *Transaction) ([]byte, error) { 125 // encrypt ext data 126 if err := trans.encrypt(s.SessionKey); err != nil { 127 return nil, err 128 } 129 130 const maxThreadCount = 4 131 threadCount := int(trans.Size) / (6 * BlockSize) // 1 thread upload 1.5 MB 132 if threadCount > maxThreadCount { 133 threadCount = maxThreadCount 134 } 135 if threadCount < 2 { 136 // single thread upload 137 return s.uploadSingle(trans) 138 } 139 140 // pick a address 141 // TODO: pick smarter 142 pc, err := s.selectConn() 143 if err != nil { 144 return nil, err 145 } 146 addr := pc.addr 147 s.putIdleConn(pc) 148 149 var ( 150 rspExt []byte 151 completedThread uint32 152 cond = sync.NewCond(&sync.Mutex{}) 153 offset = uint64(0) 154 count = (trans.Size + BlockSize - 1) / BlockSize 155 id = 0 156 ) 157 doUpload := func() error { 158 // send signal complete uploading 159 defer func() { 160 atomic.AddUint32(&completedThread, 1) 161 cond.Signal() 162 }() 163 164 // todo: get from pool? 165 pc, err := s.connect(addr) 166 if err != nil { 167 return err 168 } 169 defer s.putIdleConn(pc) 170 171 reader := binary.NewNetworkReader(pc.conn) 172 chunk := make([]byte, BlockSize) 173 for { 174 cond.L.Lock() // lock protect reading 175 off := offset 176 offset += BlockSize 177 id++ 178 last := uint64(id) == count 179 if last { // last 180 for atomic.LoadUint32(&completedThread) != uint32(threadCount-1) { 181 cond.Wait() 182 } 183 } else if uint64(id) > count { 184 cond.L.Unlock() 185 break 186 } 187 chunk = chunk[:BlockSize] 188 n, err := io.ReadFull(trans.Body, chunk) 189 cond.L.Unlock() 190 191 if n == 0 { 192 break 193 } 194 if errors.Is(err, io.ErrUnexpectedEOF) { 195 chunk = chunk[:n] 196 } 197 ch := md5.Sum(chunk) 198 head, _ := proto.Marshal(trans.Build(s, off, uint32(n), ch[:])) 199 buffers := Frame(head, chunk) 200 _, err = buffers.WriteTo(pc.conn) 201 if err != nil { 202 return errors.Wrap(err, "write conn error") 203 } 204 rspHead, err := readResponse(reader) 205 if err != nil { 206 return errors.Wrap(err, "highway upload error") 207 } 208 if rspHead.ErrorCode != 0 { 209 return errors.Errorf("upload failed: %d", rspHead.ErrorCode) 210 } 211 if last && rspHead.BytesRspExtendInfo != nil { 212 rspExt = rspHead.BytesRspExtendInfo 213 } 214 } 215 return nil 216 } 217 218 group := errgroup.Group{} 219 for i := 0; i < threadCount; i++ { 220 group.Go(doUpload) 221 } 222 return rspExt, group.Wait() 223 }