github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/chat/s3/mem.go (about)

     1  package s3
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"encoding/hex"
     7  	"fmt"
     8  	"io"
     9  	"sort"
    10  	"sync"
    11  
    12  	"github.com/keybase/client/go/libkb"
    13  
    14  	"golang.org/x/net/context"
    15  )
    16  
    17  type Mem struct {
    18  	mc     *MemConn
    19  	mcMake sync.Once
    20  }
    21  
    22  var _ Root = &Mem{}
    23  
    24  func (m *Mem) New(g *libkb.GlobalContext, signer Signer, region Region) Connection {
    25  	return m.NewMemConn()
    26  }
    27  
    28  type MemConn struct {
    29  	buckets map[string]*MemBucket
    30  	sync.Mutex
    31  }
    32  
    33  func (m *Mem) NewMemConn() *MemConn {
    34  	m.mcMake.Do(func() {
    35  		m.mc = &MemConn{
    36  			buckets: make(map[string]*MemBucket),
    37  		}
    38  	})
    39  	return m.mc
    40  }
    41  
    42  var _ Connection = &MemConn{}
    43  
    44  func (s *MemConn) SetAccessKey(key string)    {}
    45  func (s *MemConn) SetSessionToken(key string) {}
    46  
    47  func (s *MemConn) Bucket(name string) BucketInt {
    48  	s.Lock()
    49  	defer s.Unlock()
    50  	b, ok := s.buckets[name]
    51  	if ok {
    52  		return b
    53  	}
    54  	b = NewMemBucket(s, name)
    55  	s.buckets[name] = b
    56  	return b
    57  }
    58  
    59  func (s *MemConn) AllMultis() []*MemMulti {
    60  	s.Lock()
    61  	defer s.Unlock()
    62  	var all []*MemMulti
    63  	for _, b := range s.buckets {
    64  		for _, m := range b.multis {
    65  			all = append(all, m)
    66  		}
    67  	}
    68  	return all
    69  }
    70  
    71  type MemBucket struct {
    72  	conn    *MemConn
    73  	name    string
    74  	objects map[string][]byte
    75  	multis  map[string]*MemMulti
    76  	sync.Mutex
    77  }
    78  
    79  func NewMemBucket(conn *MemConn, name string) *MemBucket {
    80  	return &MemBucket{
    81  		conn:    conn,
    82  		name:    name,
    83  		objects: make(map[string][]byte),
    84  		multis:  make(map[string]*MemMulti),
    85  	}
    86  }
    87  
    88  var _ BucketInt = &MemBucket{}
    89  
    90  func (b *MemBucket) GetReader(ctx context.Context, path string) (io.ReadCloser, error) {
    91  	b.Lock()
    92  	defer b.Unlock()
    93  	obj, ok := b.objects[path]
    94  	if !ok {
    95  		return nil, fmt.Errorf("bucket %q, path %q does not exist", b.name, path)
    96  	}
    97  	return io.NopCloser(bytes.NewBuffer(obj)), nil
    98  }
    99  
   100  func (b *MemBucket) GetReaderWithRange(ctx context.Context, path string, begin, end int64) (io.ReadCloser, error) {
   101  	b.Lock()
   102  	defer b.Unlock()
   103  	obj, ok := b.objects[path]
   104  	if !ok {
   105  		return nil, fmt.Errorf("bucket %q, path %q does not exist", b.name, path)
   106  	}
   107  	if end >= int64(len(obj)) {
   108  		end = int64(len(obj))
   109  	}
   110  	return io.NopCloser(bytes.NewBuffer(obj[begin:end])), nil
   111  }
   112  
   113  func (b *MemBucket) PutReader(ctx context.Context, path string, r io.Reader, length int64, contType string, perm ACL, options Options) error {
   114  	b.Lock()
   115  	defer b.Unlock()
   116  
   117  	var buf bytes.Buffer
   118  	_, err := buf.ReadFrom(r)
   119  	if err != nil {
   120  		return err
   121  	}
   122  	b.objects[path] = buf.Bytes()
   123  
   124  	return nil
   125  }
   126  
   127  func (b *MemBucket) Del(ctx context.Context, path string) error {
   128  	b.Lock()
   129  	defer b.Unlock()
   130  
   131  	delete(b.objects, path)
   132  	return nil
   133  }
   134  
   135  func (b *MemBucket) setObject(path string, data []byte) {
   136  	b.Lock()
   137  	defer b.Unlock()
   138  	b.objects[path] = data
   139  }
   140  
   141  func (b *MemBucket) Multi(ctx context.Context, key, contType string, perm ACL) (MultiInt, error) {
   142  	b.Lock()
   143  	defer b.Unlock()
   144  	m, ok := b.multis[key]
   145  	if ok {
   146  		return m, nil
   147  	}
   148  	m = NewMemMulti(b, key)
   149  	b.multis[key] = m
   150  	return m, nil
   151  }
   152  
   153  type MemMulti struct {
   154  	bucket      *MemBucket
   155  	path        string
   156  	parts       map[int]*part
   157  	numPutParts int
   158  	sync.Mutex
   159  }
   160  
   161  var _ MultiInt = &MemMulti{}
   162  
   163  func NewMemMulti(b *MemBucket, path string) *MemMulti {
   164  	return &MemMulti{
   165  		bucket: b,
   166  		path:   path,
   167  		parts:  make(map[int]*part),
   168  	}
   169  }
   170  
   171  func (m *MemMulti) ListParts(ctx context.Context) ([]Part, error) {
   172  	m.Lock()
   173  	defer m.Unlock()
   174  
   175  	var ps []Part
   176  	for _, p := range m.parts {
   177  		ps = append(ps, p.export())
   178  	}
   179  	return ps, nil
   180  }
   181  
   182  func (m *MemMulti) Complete(ctx context.Context, parts []Part) error {
   183  	m.Lock()
   184  	defer m.Unlock()
   185  
   186  	// match parts coming in with existing parts
   187  	var scratch partList
   188  	for _, p := range parts {
   189  		if pp, ok := m.parts[p.N]; ok {
   190  			scratch = append(scratch, pp)
   191  		}
   192  	}
   193  
   194  	// assemble into one block
   195  	sort.Sort(scratch)
   196  	var buf bytes.Buffer
   197  	for _, p := range scratch {
   198  		buf.Write(p.data)
   199  	}
   200  
   201  	// store in bucket
   202  	m.bucket.setObject(m.path, buf.Bytes())
   203  
   204  	return nil
   205  }
   206  
   207  func (m *MemMulti) PutPart(ctx context.Context, index int, r io.ReadSeeker) (Part, error) {
   208  	m.Lock()
   209  	defer m.Unlock()
   210  
   211  	var buf bytes.Buffer
   212  	_, err := buf.ReadFrom(r)
   213  	if err != nil {
   214  		return Part{}, err
   215  	}
   216  	p := newPart(index, buf)
   217  	m.parts[index] = p
   218  
   219  	m.numPutParts++
   220  
   221  	return p.export(), nil
   222  }
   223  
   224  // NumPutParts returns the number of times PutPart was called.
   225  func (m *MemMulti) NumPutParts() int {
   226  	m.Lock()
   227  	defer m.Unlock()
   228  
   229  	return m.numPutParts
   230  }
   231  
   232  type part struct {
   233  	index int
   234  	hash  string
   235  	data  []byte
   236  }
   237  
   238  func newPart(index int, buf bytes.Buffer) *part {
   239  	p := &part{
   240  		index: index,
   241  		data:  buf.Bytes(),
   242  	}
   243  	h := md5.Sum(p.data)
   244  	p.hash = hex.EncodeToString(h[:])
   245  	return p
   246  }
   247  
   248  func (p *part) export() Part {
   249  	return Part{N: p.index, ETag: `"` + p.hash + `"`, Size: int64(len(p.data))}
   250  }
   251  
   252  type partList []*part
   253  
   254  func (x partList) Len() int           { return len(x) }
   255  func (x partList) Less(a, b int) bool { return x[a].index < x[b].index }
   256  func (x partList) Swap(a, b int)      { x[a], x[b] = x[b], x[a] }