github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/libkb/stream.go (about)

     1  // Copyright 2015 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     3  
     4  package libkb
     5  
import (
	"bufio"
	"fmt"
	"io"
	"sync"

	"golang.org/x/net/context"

	keybase1 "github.com/keybase/client/go/protocol/keybase1"
)
    15  
    16  type ReadCloser struct{}
    17  
    18  type ReadCloseSeeker interface {
    19  	io.ReadCloser
    20  	io.Seeker
    21  }
    22  
    23  type ExportedStream struct {
    24  	// r io.ReadCloser
    25  	r ReadCloseSeeker
    26  	w io.WriteCloser
    27  	i int
    28  }
    29  
    30  type ExportedStreams struct {
    31  	m map[int]*ExportedStream
    32  	i int
    33  	sync.Mutex
    34  }
    35  
    36  func NewExportedStreams() *ExportedStreams {
    37  	return &ExportedStreams{
    38  		m: make(map[int]*ExportedStream),
    39  		i: 0,
    40  	}
    41  }
    42  
    43  func (s *ExportedStreams) ExportWriter(w io.WriteCloser) keybase1.Stream {
    44  	es := s.alloc()
    45  	es.w = w
    46  	return es.Export()
    47  }
    48  
    49  func (s *ExportedStreams) ExportReader(r ReadCloseSeeker) keybase1.Stream {
    50  	es := s.alloc()
    51  	es.r = r
    52  	return es.Export()
    53  }
    54  
    55  func (s *ExportedStreams) alloc() (ret *ExportedStream) {
    56  	s.Lock()
    57  	defer s.Unlock()
    58  	s.i++
    59  	i := s.i
    60  	ret = &ExportedStream{i: i}
    61  	s.m[i] = ret
    62  	return ret
    63  }
    64  
    65  func (s *ExportedStream) Export() keybase1.Stream {
    66  	return keybase1.Stream{Fd: s.i}
    67  }
    68  
    69  func (s *ExportedStreams) GetWriter(st keybase1.Stream) (ret io.WriteCloser, err error) {
    70  	s.Lock()
    71  	defer s.Unlock()
    72  	if obj, found := s.m[st.Fd]; !found {
    73  		err = StreamNotFoundError{}
    74  	} else if obj.w == nil {
    75  		err = StreamWrongKindError{}
    76  	} else {
    77  		ret = obj.w
    78  	}
    79  	return
    80  }
    81  
    82  func (s *ExportedStreams) GetReader(st keybase1.Stream) (ret io.ReadCloser, err error) {
    83  	s.Lock()
    84  	defer s.Unlock()
    85  	if obj, found := s.m[st.Fd]; !found {
    86  		err = StreamNotFoundError{}
    87  	} else if obj.r == nil {
    88  		err = StreamWrongKindError{}
    89  	} else {
    90  		ret = obj.r
    91  	}
    92  	return
    93  }
    94  
    95  func (s *ExportedStreams) Close(_ context.Context, a keybase1.CloseArg) (err error) {
    96  	s.Lock()
    97  	defer s.Unlock()
    98  	if obj, found := s.m[a.S.Fd]; !found {
    99  		err = StreamNotFoundError{}
   100  	} else {
   101  		if obj.w != nil {
   102  			err = obj.w.Close()
   103  		}
   104  		if obj.r != nil {
   105  			tmp := obj.r.Close()
   106  			if tmp != nil && err == nil {
   107  				err = tmp
   108  			}
   109  		}
   110  		delete(s.m, a.S.Fd)
   111  	}
   112  	return err
   113  }
   114  
   115  func (s *ExportedStreams) Read(_ context.Context, a keybase1.ReadArg) (buf []byte, err error) {
   116  	var r io.ReadCloser
   117  	if r, err = s.GetReader(a.S); err != nil {
   118  		return
   119  	}
   120  	var n int
   121  	buf = make([]byte, a.Sz)
   122  	n, err = r.Read(buf)
   123  	buf = buf[0:n]
   124  	return
   125  }
   126  
   127  func (s *ExportedStreams) Write(_ context.Context, a keybase1.WriteArg) (n int, err error) {
   128  	var w io.WriteCloser
   129  	if w, err = s.GetWriter(a.S); err != nil {
   130  		return
   131  	}
   132  	n, err = w.Write(a.Buf)
   133  	return
   134  }
   135  
   136  func (s *ExportedStreams) Reset(_ context.Context, a keybase1.ResetArg) error {
   137  	s.Lock()
   138  	defer s.Unlock()
   139  
   140  	obj, found := s.m[a.S.Fd]
   141  	if !found || obj.r == nil {
   142  		return StreamNotFoundError{}
   143  	}
   144  
   145  	_, err := obj.r.Seek(0, io.SeekStart)
   146  	return err
   147  }
   148  
// RemoteStream is an unbuffered handle on a stream accessed through Cli.
// Its Write, Read, Close and Reset methods each issue the corresponding Cli
// call with Stream and SessionID, using context.TODO().
// NOTE(review): presumably the peer serves these calls with ExportedStreams;
// confirm against the StreamUi server side.
type RemoteStream struct {
	// Stream identifies the stream to the peer (sent as the S argument).
	Stream    keybase1.Stream
	// Cli is the client every call is issued through.
	Cli       *keybase1.StreamUiClient
	// SessionID is sent with every call.
	SessionID int
}
   154  
   155  func (ewc RemoteStream) Write(buf []byte) (n int, err error) {
   156  	return ewc.Cli.Write(context.TODO(), keybase1.WriteArg{S: ewc.Stream, Buf: buf, SessionID: ewc.SessionID})
   157  }
   158  
   159  func (ewc RemoteStream) Close() (err error) {
   160  	return ewc.Cli.Close(context.TODO(), keybase1.CloseArg{S: ewc.Stream, SessionID: ewc.SessionID})
   161  }
   162  
   163  func (ewc RemoteStream) Read(buf []byte) (n int, err error) {
   164  	var tmp []byte
   165  	if tmp, err = ewc.Cli.Read(context.TODO(), keybase1.ReadArg{S: ewc.Stream, Sz: len(buf), SessionID: ewc.SessionID}); err == nil {
   166  		n = len(tmp)
   167  		copy(buf, tmp)
   168  	}
   169  	return
   170  }
   171  
   172  func (ewc RemoteStream) Reset() (err error) {
   173  	return ewc.Cli.Reset(context.TODO(), keybase1.ResetArg{S: ewc.Stream, SessionID: ewc.SessionID})
   174  }
   175  
// RemoteStreamBuffered wraps a RemoteStream with a bufio.Reader and a
// bufio.Writer, both created by createBufs directly on rs.
type RemoteStreamBuffered struct {
	// rs is the unbuffered stream that both buffers read from / write to.
	rs *RemoteStream
	// r buffers reads; createBufs replaces it (discarding buffered data).
	r  *bufio.Reader
	// w buffers writes; Close and Reset flush it before acting on rs.
	w  *bufio.Writer
}
   181  
   182  func NewRemoteStreamBuffered(s keybase1.Stream, c *keybase1.StreamUiClient, sessionID int) *RemoteStreamBuffered {
   183  	x := &RemoteStreamBuffered{
   184  		rs: &RemoteStream{Stream: s, Cli: c, SessionID: sessionID},
   185  	}
   186  	x.createBufs()
   187  	return x
   188  }
   189  
   190  func (x *RemoteStreamBuffered) Write(p []byte) (int, error) {
   191  	return x.w.Write(p)
   192  }
   193  
   194  func (x *RemoteStreamBuffered) Read(p []byte) (int, error) {
   195  	return x.r.Read(p)
   196  }
   197  
   198  func (x *RemoteStreamBuffered) Close() error {
   199  	x.w.Flush()
   200  	return x.rs.Close()
   201  }
   202  
   203  func (x *RemoteStreamBuffered) Reset() (err error) {
   204  	x.w.Flush()
   205  	if err := x.rs.Reset(); err != nil {
   206  		return err
   207  	}
   208  	x.createBufs()
   209  	return nil
   210  }
   211  
   212  func (x *RemoteStreamBuffered) createBufs() {
   213  	x.r = bufio.NewReader(x.rs)
   214  	x.w = bufio.NewWriter(x.rs)
   215  }