github.com/swishcloud/filesync@v0.0.0-20231002120458-6ade2feed6f9/session/session.go (about)

     1  package session
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"net"
    11  	"os"
    12  	"strconv"
    13  	"time"
    14  
    15  	"github.com/swishcloud/filesync/message"
    16  	"github.com/swishcloud/filesync/x"
    17  	"github.com/swishcloud/gostudy/common"
    18  )
    19  
// Session wraps a network connection and carries the framed message
// protocol implemented by the methods in this file. SendMessage writes the
// encoded message's length in hexadecimal, a NUL byte, the encoded message,
// and then the message body. ReadMessage reads the same frame back.
//
// A Session also keeps byte counters and a per-second write-speed sample.
// NOTE(review): these fields are written both by callers (Read/Write) and by
// the speedTimer goroutine started in NewSession, with no locking.
type Session struct {
	_c                     net.Conn             // underlying connection
	written                int64                // total bytes written through Write
	read                   int64                // total bytes read through Read
	last_msg_rest          int64                // body bytes of the last ReadMessage still unread; must be 0 before the next ReadMessage
	presentWriteProgress   PresentWriteProgress // if non-nil, called after each Write; SendMessage installs it while sending a payload
	write_speed_counter    int64                // bytes written in the current one-second window
	write_speed            int64                // bytes written during the previous one-second window
	write_speed_clear_time time.Time            // when the current window started
	closed                 bool                 // set by Close; ends the speedTimer loop
}

// PresentWriteProgress is a callback that receives the number of bytes
// written by a single Write call.
type PresentWriteProgress func(n int)
    32  
    33  func NewSession(c net.Conn) *Session {
    34  	s := new(Session)
    35  	s._c = c
    36  	go s.speedTimer()
    37  	return s
    38  }
    39  func (s *Session) speedTimer() {
    40  	for !s.closed {
    41  		time.Sleep(time.Second * 1)
    42  		s.write_speed = s.write_speed_counter
    43  		s.write_speed_counter = 0
    44  		s.write_speed_clear_time = time.Now()
    45  	}
    46  }
    47  func (s *Session) Close() {
    48  	err := s._c.Close()
    49  	if err != nil {
    50  		log.Fatal(err)
    51  	}
    52  	s.closed = true
    53  }
    54  func (s *Session) ReadMessage() (*message.Message, error) {
    55  	if s.last_msg_rest != 0 {
    56  		return nil, errors.New("last read not completed")
    57  	}
    58  	var size_b []byte
    59  	size_buf := new(bytes.Buffer)
    60  	for {
    61  		_, err := io.CopyN(size_buf, s, 1)
    62  		if err != nil {
    63  			return nil, err
    64  		}
    65  		size_b = size_buf.Bytes()
    66  		if size_b[len(size_b)-1] == 0 {
    67  			size_b = size_b[:len(size_b)-1]
    68  			break
    69  		}
    70  	}
    71  	size, err := strconv.ParseInt(string(size_b), 16, 64)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	msg_buf := new(bytes.Buffer)
    77  	_, err = io.CopyN(msg_buf, s, int64(size))
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	msg_b := msg_buf.Bytes()
    82  	msg, err := message.ReadMessage(bytes.NewReader(msg_b))
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	s.last_msg_rest = msg.BodySize
    87  	return msg, nil
    88  }
    89  func (s *Session) Write(p []byte) (n int, err error) {
    90  	t := time.Now()
    91  	n, err = s._c.Write(p)
    92  	if s.presentWriteProgress != nil {
    93  		s.presentWriteProgress(n)
    94  	}
    95  	s.written += int64(n)
    96  	if !t.Before(s.write_speed_clear_time) {
    97  		s.write_speed_counter += int64(n)
    98  	}
    99  	return n, err
   100  }
   101  func (s *Session) Read(p []byte) (n int, err error) {
   102  	n, err = s._c.Read(p)
   103  	s.read += int64(n)
   104  	s.last_msg_rest -= int64(n)
   105  	return n, err
   106  }
   107  func (s *Session) SendMessage(msg *message.Message, payload io.Reader, payload_size int64) error {
   108  	if msg.MsgType == 0 {
   109  		return errors.New("message type value must be non zero")
   110  	}
   111  	if payload == nil && payload_size != 0 {
   112  		return errors.New("parameter error")
   113  	}
   114  
   115  	msg.BodySize = payload_size
   116  
   117  	msg_b, err := x.Encode(msg)
   118  	if err != nil {
   119  		return err
   120  	}
   121  
   122  	size_b := []byte(strconv.FormatInt((int64)(len(msg_b)), 16))
   123  
   124  	_, err = io.CopyN(s, bytes.NewReader(size_b), int64(len(size_b)))
   125  	if err != nil {
   126  		return err
   127  	}
   128  
   129  	_, err = io.CopyN(s, bytes.NewReader([]byte{0}), 1)
   130  	if err != nil {
   131  		return err
   132  	}
   133  	_, err = io.CopyN(s, bytes.NewReader(msg_b), int64(len(msg_b)))
   134  	if err != nil {
   135  		return err
   136  	}
   137  	if payload != nil && payload_size > 0 {
   138  		written := int64(0)
   139  		total := msg.BodySize
   140  		s.presentWriteProgress = func(n int) {
   141  			written += int64(n)
   142  			percent := int(float64(written) / float64(total) * 100)
   143  			s, u := common.FormatByteSize(s.write_speed)
   144  			fmt.Print("\r")
   145  			info := fmt.Sprintf("sent %d/%d bytes %d%%,%s %s/s               ", written, total, percent, s, u)
   146  			info = common.StringLimitLen(info, 50)
   147  			fmt.Print(info)
   148  		}
   149  		n, err := io.CopyN(s, payload, msg.BodySize)
   150  		fmt.Println()
   151  		s.presentWriteProgress = nil
   152  		if err != nil {
   153  			return err
   154  		}
   155  		if n != payload_size {
   156  			return errors.New(fmt.Sprintf("unexpected error:payload size is %d bytes,but written %d bytes", payload_size, n))
   157  		}
   158  	}
   159  	return nil
   160  }
   161  func (s *Session) Ack() error {
   162  	msg := message.NewMessage(message.MT_ACK)
   163  	return s.Send(msg, nil)
   164  }
   165  func (s *Session) SendFile(file_path string, pre_send func(filename string, md5 string, size int64) (offset int64, send bool)) error {
   166  	msg := message.NewMessage(message.MT_FILE)
   167  	md5, err := x.Hash_file_md5(file_path)
   168  	if err != nil {
   169  		return err
   170  	}
   171  	msg.Header["md5"] = md5
   172  	f, err := os.Open(file_path)
   173  	defer f.Close()
   174  	if err != nil {
   175  		return err
   176  	}
   177  	file_info, err := f.Stat()
   178  	if err != nil {
   179  		return err
   180  	}
   181  	msg.Header["file_name"] = file_info.Name()
   182  	payload_size := file_info.Size()
   183  	if pre_send != nil {
   184  		offset, ok := pre_send(file_info.Name(), md5, file_info.Size())
   185  		if !ok {
   186  			return nil
   187  		}
   188  		_, err := f.Seek(offset, 1)
   189  		if err != nil {
   190  			return err
   191  		}
   192  		payload_size -= offset
   193  	}
   194  	return s.SendMessage(msg, f, payload_size)
   195  }
   196  func (s *Session) Send(msg *message.Message, data interface{}) error {
   197  	var payload io.Reader = nil
   198  	var payload_size = 0
   199  	if data != nil {
   200  		b, err := json.Marshal(data)
   201  		if err != nil {
   202  			return err
   203  		}
   204  		payload = bytes.NewReader(b)
   205  		payload_size = len(b)
   206  	}
   207  	return s.SendMessage(msg, payload, int64(payload_size))
   208  }
   209  func (s *Session) Fetch(msg *message.Message, data interface{}) (*message.Message, error) {
   210  	err := s.Send(msg, data)
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  	return s.ReadMessage()
   215  }
   216  func (s *Session) ReadJson(size int, v interface{}) error {
   217  	buf := new(bytes.Buffer)
   218  	_, err := io.CopyN(buf, s, int64(size))
   219  	if err != nil {
   220  		return err
   221  	}
   222  	return json.Unmarshal(buf.Bytes(), v)
   223  }
   224  func (s *Session) ReadFile(filepath string, md5 string, size int64) (written int64, err error) {
   225  	f, err := os.Create(filepath)
   226  	if err != nil {
   227  		return 0, err
   228  	}
   229  	written, err = io.CopyN(f, s, size)
   230  	f.Close()
   231  	hash, err := x.Hash_file_md5(filepath)
   232  	if hash != md5 {
   233  		return written, errors.New("md5 is inconsistent")
   234  	}
   235  	return written, err
   236  }
   237  func (s *Session) String() string {
   238  	return fmt.Sprintf("remote_addr:%s", s._c.RemoteAddr())
   239  }