github.com/Mrs4s/MiraiGo@v0.0.0-20240226124653-54bdd873e3fe/client/internal/highway/bdh.go (about) 1 package highway 2 3 import ( 4 "crypto/md5" 5 "io" 6 "sync" 7 "sync/atomic" 8 9 "github.com/pkg/errors" 10 "golang.org/x/sync/errgroup" 11 12 "github.com/Mrs4s/MiraiGo/binary" 13 "github.com/Mrs4s/MiraiGo/client/pb" 14 "github.com/Mrs4s/MiraiGo/internal/proto" 15 ) 16 17 type Transaction struct { 18 CommandID int32 19 Body io.Reader 20 Sum []byte // md5 sum of body 21 Size int64 // body size 22 Ticket []byte 23 Ext []byte 24 Encrypt bool 25 } 26 27 func (bdh *Transaction) encrypt(key []byte) error { 28 if !bdh.Encrypt { 29 return nil 30 } 31 if len(key) == 0 { 32 return errors.New("session key not found. maybe miss some packet?") 33 } 34 bdh.Ext = binary.NewTeaCipher(key).Encrypt(bdh.Ext) 35 return nil 36 } 37 38 func (s *Session) uploadSingle(trans Transaction) ([]byte, error) { 39 pc, err := s.selectConn() 40 if err != nil { 41 return nil, err 42 } 43 defer s.putIdleConn(pc) 44 45 reader := binary.NewNetworkReader(pc.conn) 46 const chunkSize = 128 * 1024 47 var rspExt []byte 48 offset := 0 49 chunk := make([]byte, chunkSize) 50 for { 51 chunk = chunk[:cap(chunk)] 52 rl, err := io.ReadFull(trans.Body, chunk) 53 if rl == 0 { 54 break 55 } 56 if errors.Is(err, io.ErrUnexpectedEOF) { 57 chunk = chunk[:rl] 58 } 59 ch := md5.Sum(chunk) 60 head, _ := proto.Marshal(&pb.ReqDataHighwayHead{ 61 MsgBasehead: &pb.DataHighwayHead{ 62 Version: 1, 63 Uin: s.Uin, 64 Command: _REQ_CMD_DATA, 65 Seq: s.nextSeq(), 66 Appid: s.AppID, 67 Dataflag: 4096, 68 CommandId: trans.CommandID, 69 LocaleId: 2052, 70 }, 71 MsgSeghead: &pb.SegHead{ 72 Filesize: trans.Size, 73 Dataoffset: int64(offset), 74 Datalength: int32(rl), 75 Serviceticket: trans.Ticket, 76 Md5: ch[:], 77 FileMd5: trans.Sum, 78 }, 79 ReqExtendinfo: trans.Ext, 80 }) 81 offset += rl 82 buffers := frame(head, chunk) 83 _, err = buffers.WriteTo(pc.conn) 84 if err != nil { 85 return nil, errors.Wrap(err, "write conn error") 86 } 87 rspHead, err := readResponse(reader) 88 if err != nil { 89 return nil, errors.Wrap(err, "highway upload error") 90 } 91 if rspHead.ErrorCode != 0 { 92 return nil, errors.Errorf("upload failed: %d", rspHead.ErrorCode) 93 } 94 if rspHead.RspExtendinfo != nil { 95 rspExt = rspHead.RspExtendinfo 96 } 97 if rspHead.MsgSeghead != nil && rspHead.MsgSeghead.Serviceticket != nil { 98 trans.Ticket = rspHead.MsgSeghead.Serviceticket 99 } 100 } 101 return rspExt, nil 102 } 103 104 func (s *Session) Upload(trans Transaction) ([]byte, error) { 105 // encrypt ext data 106 if err := trans.encrypt(s.SessionKey); err != nil { 107 return nil, err 108 } 109 110 const maxThreadCount = 4 111 threadCount := int(trans.Size) / (3 * 512 * 1024) // 1 thread upload 1.5 MB 112 if threadCount > maxThreadCount { 113 threadCount = maxThreadCount 114 } 115 if threadCount < 2 { 116 // single thread upload 117 return s.uploadSingle(trans) 118 } 119 120 // pick a address 121 // TODO: pick smarter 122 pc, err := s.selectConn() 123 if err != nil { 124 return nil, err 125 } 126 addr := pc.addr 127 s.putIdleConn(pc) 128 129 const blockSize int64 = 256 * 1024 130 var ( 131 rspExt []byte 132 completedThread uint32 133 cond = sync.NewCond(&sync.Mutex{}) 134 offset = int64(0) 135 count = (trans.Size + blockSize - 1) / blockSize 136 id = 0 137 ) 138 doUpload := func() error { 139 // send signal complete uploading 140 defer func() { 141 atomic.AddUint32(&completedThread, 1) 142 cond.Signal() 143 }() 144 145 // todo: get from pool? 146 pc, err := s.connect(addr) 147 if err != nil { 148 return err 149 } 150 defer s.putIdleConn(pc) 151 152 reader := binary.NewNetworkReader(pc.conn) 153 chunk := make([]byte, blockSize) 154 for { 155 cond.L.Lock() // lock protect reading 156 off := offset 157 offset += blockSize 158 id++ 159 last := int64(id) == count 160 if last { // last 161 for atomic.LoadUint32(&completedThread) != uint32(threadCount-1) { 162 cond.Wait() 163 } 164 } else if int64(id) > count { 165 cond.L.Unlock() 166 break 167 } 168 chunk = chunk[:blockSize] 169 n, err := io.ReadFull(trans.Body, chunk) 170 cond.L.Unlock() 171 172 if n == 0 { 173 break 174 } 175 if errors.Is(err, io.ErrUnexpectedEOF) { 176 chunk = chunk[:n] 177 } 178 ch := md5.Sum(chunk) 179 head, _ := proto.Marshal(&pb.ReqDataHighwayHead{ 180 MsgBasehead: &pb.DataHighwayHead{ 181 Version: 1, 182 Uin: s.Uin, 183 Command: _REQ_CMD_DATA, 184 Seq: s.nextSeq(), 185 Appid: s.AppID, 186 Dataflag: 4096, 187 CommandId: trans.CommandID, 188 LocaleId: 2052, 189 }, 190 MsgSeghead: &pb.SegHead{ 191 Filesize: trans.Size, 192 Dataoffset: off, 193 Datalength: int32(n), 194 Serviceticket: trans.Ticket, 195 Md5: ch[:], 196 FileMd5: trans.Sum, 197 }, 198 ReqExtendinfo: trans.Ext, 199 }) 200 buffers := frame(head, chunk) 201 _, err = buffers.WriteTo(pc.conn) 202 if err != nil { 203 return errors.Wrap(err, "write conn error") 204 } 205 rspHead, err := readResponse(reader) 206 if err != nil { 207 return errors.Wrap(err, "highway upload error") 208 } 209 if rspHead.ErrorCode != 0 { 210 return errors.Errorf("upload failed: %d", rspHead.ErrorCode) 211 } 212 if last && rspHead.RspExtendinfo != nil { 213 rspExt = rspHead.RspExtendinfo 214 } 215 } 216 return nil 217 } 218 219 group := errgroup.Group{} 220 for i := 0; i < threadCount; i++ { 221 group.Go(doUpload) 222 } 223 return rspExt, group.Wait() 224 }