github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/chat/s3/mem.go (about)

     1  package s3
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"encoding/hex"
     7  	"fmt"
     8  	"io"
     9  	"sort"
    10  	"sync"
    11  
    12  	"github.com/keybase/client/go/libkb"
    13  
    14  	"golang.org/x/net/context"
    15  )
    16  
    17  type Mem struct {
    18  	mc     *MemConn
    19  	mcMake sync.Once
    20  }
    21  
    22  var _ Root = &Mem{}
    23  
    24  func (m *Mem) New(g *libkb.GlobalContext, signer Signer, region Region) Connection {
    25  	return m.NewMemConn()
    26  }
    27  
    28  type MemConn struct {
    29  	buckets map[string]*MemBucket
    30  	sync.Mutex
    31  }
    32  
    33  func (m *Mem) NewMemConn() *MemConn {
    34  	m.mcMake.Do(func() {
    35  		m.mc = &MemConn{
    36  			buckets: make(map[string]*MemBucket),
    37  		}
    38  	})
    39  	return m.mc
    40  }
    41  
    42  var _ Connection = &MemConn{}
    43  
    44  func (s *MemConn) SetAccessKey(key string) {}
    45  
    46  func (s *MemConn) Bucket(name string) BucketInt {
    47  	s.Lock()
    48  	defer s.Unlock()
    49  	b, ok := s.buckets[name]
    50  	if ok {
    51  		return b
    52  	}
    53  	b = NewMemBucket(s, name)
    54  	s.buckets[name] = b
    55  	return b
    56  }
    57  
    58  func (s *MemConn) AllMultis() []*MemMulti {
    59  	s.Lock()
    60  	defer s.Unlock()
    61  	var all []*MemMulti
    62  	for _, b := range s.buckets {
    63  		for _, m := range b.multis {
    64  			all = append(all, m)
    65  		}
    66  	}
    67  	return all
    68  }
    69  
    70  type MemBucket struct {
    71  	conn    *MemConn
    72  	name    string
    73  	objects map[string][]byte
    74  	multis  map[string]*MemMulti
    75  	sync.Mutex
    76  }
    77  
    78  func NewMemBucket(conn *MemConn, name string) *MemBucket {
    79  	return &MemBucket{
    80  		conn:    conn,
    81  		name:    name,
    82  		objects: make(map[string][]byte),
    83  		multis:  make(map[string]*MemMulti),
    84  	}
    85  }
    86  
    87  var _ BucketInt = &MemBucket{}
    88  
    89  func (b *MemBucket) GetReader(ctx context.Context, path string) (io.ReadCloser, error) {
    90  	b.Lock()
    91  	defer b.Unlock()
    92  	obj, ok := b.objects[path]
    93  	if !ok {
    94  		return nil, fmt.Errorf("bucket %q, path %q does not exist", b.name, path)
    95  	}
    96  	return io.NopCloser(bytes.NewBuffer(obj)), nil
    97  }
    98  
    99  func (b *MemBucket) GetReaderWithRange(ctx context.Context, path string, begin, end int64) (io.ReadCloser, error) {
   100  	b.Lock()
   101  	defer b.Unlock()
   102  	obj, ok := b.objects[path]
   103  	if !ok {
   104  		return nil, fmt.Errorf("bucket %q, path %q does not exist", b.name, path)
   105  	}
   106  	if end >= int64(len(obj)) {
   107  		end = int64(len(obj))
   108  	}
   109  	return io.NopCloser(bytes.NewBuffer(obj[begin:end])), nil
   110  }
   111  
   112  func (b *MemBucket) PutReader(ctx context.Context, path string, r io.Reader, length int64, contType string, perm ACL, options Options) error {
   113  	b.Lock()
   114  	defer b.Unlock()
   115  
   116  	var buf bytes.Buffer
   117  	_, err := buf.ReadFrom(r)
   118  	if err != nil {
   119  		return err
   120  	}
   121  	b.objects[path] = buf.Bytes()
   122  
   123  	return nil
   124  }
   125  
   126  func (b *MemBucket) Del(ctx context.Context, path string) error {
   127  	b.Lock()
   128  	defer b.Unlock()
   129  
   130  	delete(b.objects, path)
   131  	return nil
   132  }
   133  
   134  func (b *MemBucket) setObject(path string, data []byte) {
   135  	b.Lock()
   136  	defer b.Unlock()
   137  	b.objects[path] = data
   138  }
   139  
   140  func (b *MemBucket) Multi(ctx context.Context, key, contType string, perm ACL) (MultiInt, error) {
   141  	b.Lock()
   142  	defer b.Unlock()
   143  	m, ok := b.multis[key]
   144  	if ok {
   145  		return m, nil
   146  	}
   147  	m = NewMemMulti(b, key)
   148  	b.multis[key] = m
   149  	return m, nil
   150  }
   151  
   152  type MemMulti struct {
   153  	bucket      *MemBucket
   154  	path        string
   155  	parts       map[int]*part
   156  	numPutParts int
   157  	sync.Mutex
   158  }
   159  
   160  var _ MultiInt = &MemMulti{}
   161  
   162  func NewMemMulti(b *MemBucket, path string) *MemMulti {
   163  	return &MemMulti{
   164  		bucket: b,
   165  		path:   path,
   166  		parts:  make(map[int]*part),
   167  	}
   168  }
   169  
   170  func (m *MemMulti) ListParts(ctx context.Context) ([]Part, error) {
   171  	m.Lock()
   172  	defer m.Unlock()
   173  
   174  	var ps []Part
   175  	for _, p := range m.parts {
   176  		ps = append(ps, p.export())
   177  	}
   178  	return ps, nil
   179  }
   180  
   181  func (m *MemMulti) Complete(ctx context.Context, parts []Part) error {
   182  	m.Lock()
   183  	defer m.Unlock()
   184  
   185  	// match parts coming in with existing parts
   186  	var scratch partList
   187  	for _, p := range parts {
   188  		if pp, ok := m.parts[p.N]; ok {
   189  			scratch = append(scratch, pp)
   190  		}
   191  	}
   192  
   193  	// assemble into one block
   194  	sort.Sort(scratch)
   195  	var buf bytes.Buffer
   196  	for _, p := range scratch {
   197  		buf.Write(p.data)
   198  	}
   199  
   200  	// store in bucket
   201  	m.bucket.setObject(m.path, buf.Bytes())
   202  
   203  	return nil
   204  }
   205  
   206  func (m *MemMulti) PutPart(ctx context.Context, index int, r io.ReadSeeker) (Part, error) {
   207  	m.Lock()
   208  	defer m.Unlock()
   209  
   210  	var buf bytes.Buffer
   211  	_, err := buf.ReadFrom(r)
   212  	if err != nil {
   213  		return Part{}, err
   214  	}
   215  	p := newPart(index, buf)
   216  	m.parts[index] = p
   217  
   218  	m.numPutParts++
   219  
   220  	return p.export(), nil
   221  }
   222  
   223  // NumPutParts returns the number of times PutPart was called.
   224  func (m *MemMulti) NumPutParts() int {
   225  	m.Lock()
   226  	defer m.Unlock()
   227  
   228  	return m.numPutParts
   229  }
   230  
   231  type part struct {
   232  	index int
   233  	hash  string
   234  	data  []byte
   235  }
   236  
   237  func newPart(index int, buf bytes.Buffer) *part {
   238  	p := &part{
   239  		index: index,
   240  		data:  buf.Bytes(),
   241  	}
   242  	h := md5.Sum(p.data)
   243  	p.hash = hex.EncodeToString(h[:])
   244  	return p
   245  }
   246  
   247  func (p *part) export() Part {
   248  	return Part{N: p.index, ETag: `"` + p.hash + `"`, Size: int64(len(p.data))}
   249  }
   250  
   251  type partList []*part
   252  
   253  func (x partList) Len() int           { return len(x) }
   254  func (x partList) Less(a, b int) bool { return x[a].index < x[b].index }
   255  func (x partList) Swap(a, b int)      { x[a], x[b] = x[b], x[a] }