github.com/swishcloud/filesync@v0.0.0-20231002120458-6ade2feed6f9/session/session.go (about) 1 package session 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io" 9 "log" 10 "net" 11 "os" 12 "strconv" 13 "time" 14 15 "github.com/swishcloud/filesync/message" 16 "github.com/swishcloud/filesync/x" 17 "github.com/swishcloud/gostudy/common" 18 ) 19 20 type Session struct { 21 _c net.Conn 22 written int64 23 read int64 24 last_msg_rest int64 25 presentWriteProgress PresentWriteProgress 26 write_speed_counter int64 27 write_speed int64 28 write_speed_clear_time time.Time 29 closed bool 30 } 31 type PresentWriteProgress func(n int) 32 33 func NewSession(c net.Conn) *Session { 34 s := new(Session) 35 s._c = c 36 go s.speedTimer() 37 return s 38 } 39 func (s *Session) speedTimer() { 40 for !s.closed { 41 time.Sleep(time.Second * 1) 42 s.write_speed = s.write_speed_counter 43 s.write_speed_counter = 0 44 s.write_speed_clear_time = time.Now() 45 } 46 } 47 func (s *Session) Close() { 48 err := s._c.Close() 49 if err != nil { 50 log.Fatal(err) 51 } 52 s.closed = true 53 } 54 func (s *Session) ReadMessage() (*message.Message, error) { 55 if s.last_msg_rest != 0 { 56 return nil, errors.New("last read not completed") 57 } 58 var size_b []byte 59 size_buf := new(bytes.Buffer) 60 for { 61 _, err := io.CopyN(size_buf, s, 1) 62 if err != nil { 63 return nil, err 64 } 65 size_b = size_buf.Bytes() 66 if size_b[len(size_b)-1] == 0 { 67 size_b = size_b[:len(size_b)-1] 68 break 69 } 70 } 71 size, err := strconv.ParseInt(string(size_b), 16, 64) 72 if err != nil { 73 return nil, err 74 } 75 76 msg_buf := new(bytes.Buffer) 77 _, err = io.CopyN(msg_buf, s, int64(size)) 78 if err != nil { 79 return nil, err 80 } 81 msg_b := msg_buf.Bytes() 82 msg, err := message.ReadMessage(bytes.NewReader(msg_b)) 83 if err != nil { 84 return nil, err 85 } 86 s.last_msg_rest = msg.BodySize 87 return msg, nil 88 } 89 func (s *Session) Write(p []byte) (n int, err error) { 90 t := time.Now() 91 n, err = s._c.Write(p) 92 if s.presentWriteProgress != nil { 93 s.presentWriteProgress(n) 94 } 95 s.written += int64(n) 96 if !t.Before(s.write_speed_clear_time) { 97 s.write_speed_counter += int64(n) 98 } 99 return n, err 100 } 101 func (s *Session) Read(p []byte) (n int, err error) { 102 n, err = s._c.Read(p) 103 s.read += int64(n) 104 s.last_msg_rest -= int64(n) 105 return n, err 106 } 107 func (s *Session) SendMessage(msg *message.Message, payload io.Reader, payload_size int64) error { 108 if msg.MsgType == 0 { 109 return errors.New("message type value must be non zero") 110 } 111 if payload == nil && payload_size != 0 { 112 return errors.New("parameter error") 113 } 114 115 msg.BodySize = payload_size 116 117 msg_b, err := x.Encode(msg) 118 if err != nil { 119 return err 120 } 121 122 size_b := []byte(strconv.FormatInt((int64)(len(msg_b)), 16)) 123 124 _, err = io.CopyN(s, bytes.NewReader(size_b), int64(len(size_b))) 125 if err != nil { 126 return err 127 } 128 129 _, err = io.CopyN(s, bytes.NewReader([]byte{0}), 1) 130 if err != nil { 131 return err 132 } 133 _, err = io.CopyN(s, bytes.NewReader(msg_b), int64(len(msg_b))) 134 if err != nil { 135 return err 136 } 137 if payload != nil && payload_size > 0 { 138 written := int64(0) 139 total := msg.BodySize 140 s.presentWriteProgress = func(n int) { 141 written += int64(n) 142 percent := int(float64(written) / float64(total) * 100) 143 s, u := common.FormatByteSize(s.write_speed) 144 fmt.Print("\r") 145 info := fmt.Sprintf("sent %d/%d bytes %d%%,%s %s/s ", written, total, percent, s, u) 146 info = common.StringLimitLen(info, 50) 147 fmt.Print(info) 148 } 149 n, err := io.CopyN(s, payload, msg.BodySize) 150 fmt.Println() 151 s.presentWriteProgress = nil 152 if err != nil { 153 return err 154 } 155 if n != payload_size { 156 return errors.New(fmt.Sprintf("unexpected error:payload size is %d bytes,but written %d bytes", payload_size, n)) 157 } 158 } 159 return nil 160 } 161 func (s *Session) Ack() error { 162 msg := message.NewMessage(message.MT_ACK) 163 return s.Send(msg, nil) 164 } 165 func (s *Session) SendFile(file_path string, pre_send func(filename string, md5 string, size int64) (offset int64, send bool)) error { 166 msg := message.NewMessage(message.MT_FILE) 167 md5, err := x.Hash_file_md5(file_path) 168 if err != nil { 169 return err 170 } 171 msg.Header["md5"] = md5 172 f, err := os.Open(file_path) 173 defer f.Close() 174 if err != nil { 175 return err 176 } 177 file_info, err := f.Stat() 178 if err != nil { 179 return err 180 } 181 msg.Header["file_name"] = file_info.Name() 182 payload_size := file_info.Size() 183 if pre_send != nil { 184 offset, ok := pre_send(file_info.Name(), md5, file_info.Size()) 185 if !ok { 186 return nil 187 } 188 _, err := f.Seek(offset, 1) 189 if err != nil { 190 return err 191 } 192 payload_size -= offset 193 } 194 return s.SendMessage(msg, f, payload_size) 195 } 196 func (s *Session) Send(msg *message.Message, data interface{}) error { 197 var payload io.Reader = nil 198 var payload_size = 0 199 if data != nil { 200 b, err := json.Marshal(data) 201 if err != nil { 202 return err 203 } 204 payload = bytes.NewReader(b) 205 payload_size = len(b) 206 } 207 return s.SendMessage(msg, payload, int64(payload_size)) 208 } 209 func (s *Session) Fetch(msg *message.Message, data interface{}) (*message.Message, error) { 210 err := s.Send(msg, data) 211 if err != nil { 212 return nil, err 213 } 214 return s.ReadMessage() 215 } 216 func (s *Session) ReadJson(size int, v interface{}) error { 217 buf := new(bytes.Buffer) 218 _, err := io.CopyN(buf, s, int64(size)) 219 if err != nil { 220 return err 221 } 222 return json.Unmarshal(buf.Bytes(), v) 223 } 224 func (s *Session) ReadFile(filepath string, md5 string, size int64) (written int64, err error) { 225 f, err := os.Create(filepath) 226 if err != nil { 227 return 0, err 228 } 229 written, err = io.CopyN(f, s, size) 230 f.Close() 231 hash, err := x.Hash_file_md5(filepath) 232 if hash != md5 { 233 return written, errors.New("md5 is inconsistent") 234 } 235 return written, err 236 } 237 func (s *Session) String() string { 238 return fmt.Sprintf("remote_addr:%s", s._c.RemoteAddr()) 239 }