github.com/gagliardetto/solana-go@v1.11.0/nativetypes.go (about)

     1  // Copyright 2021 github.com/gagliardetto
     2  // This file has been modified by github.com/gagliardetto
     3  //
     4  // Copyright 2020 dfuse Platform Inc.
     5  //
     6  // Licensed under the Apache License, Version 2.0 (the "License");
     7  // you may not use this file except in compliance with the License.
     8  // You may obtain a copy of the License at
     9  //
    10  //      http://www.apache.org/licenses/LICENSE-2.0
    11  //
    12  // Unless required by applicable law or agreed to in writing, software
    13  // distributed under the License is distributed on an "AS IS" BASIS,
    14  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  // See the License for the specific language governing permissions and
    16  // limitations under the License.
    17  
    18  package solana
    19  
    20  import (
    21  	"crypto/ed25519"
    22  	"encoding/base64"
    23  	"fmt"
    24  	"io"
    25  
    26  	bin "github.com/gagliardetto/binary"
    27  	"github.com/mostynb/zstdpool-freelist"
    28  	"github.com/mr-tron/base58"
    29  )
    30  
    31  type Padding []byte
    32  
    33  type Hash PublicKey
    34  
    35  // MustHashFromBase58 decodes a base58 string into a Hash.
    36  // Panics on error.
    37  func MustHashFromBase58(in string) Hash {
    38  	return Hash(MustPublicKeyFromBase58(in))
    39  }
    40  
    41  // HashFromBase58 decodes a base58 string into a Hash.
    42  func HashFromBase58(in string) (Hash, error) {
    43  	tmp, err := PublicKeyFromBase58(in)
    44  	if err != nil {
    45  		return Hash{}, err
    46  	}
    47  	return Hash(tmp), nil
    48  }
    49  
    50  // HashFromBytes decodes a byte slice into a Hash.
    51  func HashFromBytes(in []byte) Hash {
    52  	return Hash(PublicKeyFromBytes(in))
    53  }
    54  
    55  // MarshalText implements encoding.TextMarshaler.
    56  func (ha Hash) MarshalText() ([]byte, error) {
    57  	s := base58.Encode(ha[:])
    58  	return []byte(s), nil
    59  }
    60  
    61  // UnmarshalText implements encoding.TextUnmarshaler.
    62  func (ha *Hash) UnmarshalText(data []byte) (err error) {
    63  	tmp, err := HashFromBase58(string(data))
    64  	if err != nil {
    65  		return fmt.Errorf("invalid hash %q: %w", string(data), err)
    66  	}
    67  	*ha = tmp
    68  	return
    69  }
    70  
    71  func (ha Hash) MarshalJSON() ([]byte, error) {
    72  	return json.Marshal(base58.Encode(ha[:]))
    73  }
    74  
    75  func (ha *Hash) UnmarshalJSON(data []byte) (err error) {
    76  	var s string
    77  	if err := json.Unmarshal(data, &s); err != nil {
    78  		return err
    79  	}
    80  
    81  	tmp, err := HashFromBase58(s)
    82  	if err != nil {
    83  		return fmt.Errorf("invalid hash %q: %w", s, err)
    84  	}
    85  	*ha = tmp
    86  	return
    87  }
    88  
    89  func (ha Hash) Equals(pb Hash) bool {
    90  	return ha == pb
    91  }
    92  
    93  var zeroHash = Hash{}
    94  
    95  func (ha Hash) IsZero() bool {
    96  	return ha == zeroHash
    97  }
    98  
    99  func (ha Hash) String() string {
   100  	return base58.Encode(ha[:])
   101  }
   102  
   103  type Signature [64]byte
   104  
   105  var zeroSignature = Signature{}
   106  
   107  func (sig Signature) IsZero() bool {
   108  	return sig == zeroSignature
   109  }
   110  
   111  func (sig Signature) Equals(pb Signature) bool {
   112  	return sig == pb
   113  }
   114  
   115  // SignatureFromBase58 decodes a base58 string into a Signature.
   116  func SignatureFromBase58(in string) (out Signature, err error) {
   117  	val, err := base58.Decode(in)
   118  	if err != nil {
   119  		return
   120  	}
   121  
   122  	if len(val) != SignatureLength {
   123  		err = fmt.Errorf("invalid length, expected 64, got %d", len(val))
   124  		return
   125  	}
   126  	copy(out[:], val)
   127  	return
   128  }
   129  
   130  // MustSignatureFromBase58 decodes a base58 string into a Signature.
   131  // Panics on error.
   132  func MustSignatureFromBase58(in string) Signature {
   133  	out, err := SignatureFromBase58(in)
   134  	if err != nil {
   135  		panic(err)
   136  	}
   137  	return out
   138  }
   139  
   140  // SignatureFromBytes decodes a byte slice into a Signature.
   141  func SignatureFromBytes(in []byte) (out Signature) {
   142  	byteCount := len(in)
   143  	if byteCount == 0 {
   144  		return
   145  	}
   146  
   147  	max := SignatureLength
   148  	if byteCount < max {
   149  		max = byteCount
   150  	}
   151  
   152  	copy(out[:], in[0:max])
   153  	return
   154  }
   155  
   156  func (p Signature) MarshalText() ([]byte, error) {
   157  	s := base58.Encode(p[:])
   158  	return []byte(s), nil
   159  }
   160  
   161  func (p *Signature) UnmarshalText(data []byte) (err error) {
   162  	tmp, err := SignatureFromBase58(string(data))
   163  	if err != nil {
   164  		return fmt.Errorf("invalid signature %q: %w", string(data), err)
   165  	}
   166  	*p = tmp
   167  	return
   168  }
   169  
   170  func (p Signature) MarshalJSON() ([]byte, error) {
   171  	return json.Marshal(base58.Encode(p[:]))
   172  }
   173  
   174  func (p *Signature) UnmarshalJSON(data []byte) (err error) {
   175  	var s string
   176  	err = json.Unmarshal(data, &s)
   177  	if err != nil {
   178  		return
   179  	}
   180  
   181  	dat, err := base58.Decode(s)
   182  	if err != nil {
   183  		return err
   184  	}
   185  
   186  	if len(dat) != SignatureLength {
   187  		return fmt.Errorf("invalid length for Signature, expected 64, got %d", len(dat))
   188  	}
   189  
   190  	target := Signature{}
   191  	copy(target[:], dat)
   192  	*p = target
   193  	return
   194  }
   195  
   196  // Verify checks that the signature is valid for the given public key and message.
   197  func (s Signature) Verify(pubkey PublicKey, msg []byte) bool {
   198  	return ed25519.Verify(pubkey[:], msg, s[:])
   199  }
   200  
   201  func (p Signature) String() string {
   202  	return base58.Encode(p[:])
   203  }
   204  
   205  type Base58 []byte
   206  
   207  func (t Base58) MarshalJSON() ([]byte, error) {
   208  	return json.Marshal(base58.Encode(t))
   209  }
   210  
   211  func (t *Base58) UnmarshalJSON(data []byte) (err error) {
   212  	var s string
   213  	err = json.Unmarshal(data, &s)
   214  	if err != nil {
   215  		return
   216  	}
   217  
   218  	if s == "" {
   219  		*t = []byte{}
   220  		return nil
   221  	}
   222  
   223  	*t, err = base58.Decode(s)
   224  	return
   225  }
   226  
   227  func (t Base58) String() string {
   228  	return base58.Encode(t)
   229  }
   230  
   231  type Data struct {
   232  	Content  []byte
   233  	Encoding EncodingType
   234  }
   235  
   236  func (t Data) MarshalJSON() ([]byte, error) {
   237  	return json.Marshal(
   238  		[]interface{}{
   239  			t.String(),
   240  			t.Encoding,
   241  		})
   242  }
   243  
   244  var zstdDecoderPool = zstdpool.NewDecoderPool()
   245  
   246  func (t *Data) UnmarshalJSON(data []byte) (err error) {
   247  	var in []string
   248  	if err := json.Unmarshal(data, &in); err != nil {
   249  		return err
   250  	}
   251  
   252  	if len(in) != 2 {
   253  		return fmt.Errorf("invalid length for solana.Data, expected 2, found %d", len(in))
   254  	}
   255  
   256  	contentString := in[0]
   257  	encodingString := in[1]
   258  	t.Encoding = EncodingType(encodingString)
   259  
   260  	if contentString == "" {
   261  		t.Content = []byte{}
   262  		return nil
   263  	}
   264  
   265  	switch t.Encoding {
   266  	case EncodingBase58:
   267  		var err error
   268  		t.Content, err = base58.Decode(contentString)
   269  		if err != nil {
   270  			return err
   271  		}
   272  	case EncodingBase64:
   273  		var err error
   274  		t.Content, err = base64.StdEncoding.DecodeString(contentString)
   275  		if err != nil {
   276  			return err
   277  		}
   278  	case EncodingBase64Zstd:
   279  		rawBytes, err := base64.StdEncoding.DecodeString(contentString)
   280  		if err != nil {
   281  			return err
   282  		}
   283  		dec, err := zstdDecoderPool.Get(nil)
   284  		if err != nil {
   285  			return err
   286  		}
   287  		defer zstdDecoderPool.Put(dec)
   288  
   289  		t.Content, err = dec.DecodeAll(rawBytes, nil)
   290  		if err != nil {
   291  			return err
   292  		}
   293  	default:
   294  		return fmt.Errorf("unsupported encoding %s", encodingString)
   295  	}
   296  	return
   297  }
   298  
   299  var zstdEncoderPool = zstdpool.NewEncoderPool()
   300  
   301  func (t Data) String() string {
   302  	switch EncodingType(t.Encoding) {
   303  	case EncodingBase58:
   304  		return base58.Encode(t.Content)
   305  	case EncodingBase64:
   306  		return base64.StdEncoding.EncodeToString(t.Content)
   307  	case EncodingBase64Zstd:
   308  		enc, err := zstdEncoderPool.Get(nil)
   309  		if err != nil {
   310  			// TODO: remove panic?
   311  			panic(err)
   312  		}
   313  		defer zstdEncoderPool.Put(enc)
   314  		return base64.StdEncoding.EncodeToString(enc.EncodeAll(t.Content, nil))
   315  	default:
   316  		// TODO
   317  		return ""
   318  	}
   319  }
   320  
   321  func (obj Data) MarshalWithEncoder(encoder *bin.Encoder) (err error) {
   322  	err = encoder.WriteBytes(obj.Content, true)
   323  	if err != nil {
   324  		return err
   325  	}
   326  	err = encoder.WriteString(string(obj.Encoding))
   327  	if err != nil {
   328  		return err
   329  	}
   330  	return nil
   331  }
   332  
   333  func (obj *Data) UnmarshalWithDecoder(decoder *bin.Decoder) (err error) {
   334  	obj.Content, err = decoder.ReadByteSlice()
   335  	if err != nil {
   336  		return err
   337  	}
   338  	{
   339  		enc, err := decoder.ReadString()
   340  		if err != nil {
   341  			return err
   342  		}
   343  		obj.Encoding = EncodingType(enc)
   344  	}
   345  	return nil
   346  }
   347  
   348  type ByteWrapper struct {
   349  	io.Reader
   350  }
   351  
   352  func (w *ByteWrapper) ReadByte() (byte, error) {
   353  	var b [1]byte
   354  	// NOTE: w.Read() gives no guaranties about the number of bytes actually read.
   355  	// Using io.ReadFull reads exactly len(buf) bytes from r into buf.
   356  	_, err := io.ReadFull(w, b[:])
   357  	return b[0], err
   358  }
   359  
   360  type EncodingType string
   361  
   362  const (
   363  	EncodingBase58     EncodingType = "base58"      // limited to Account data of less than 129 bytes
   364  	EncodingBase64     EncodingType = "base64"      // will return base64 encoded data for Account data of any size
   365  	EncodingBase64Zstd EncodingType = "base64+zstd" // compresses the Account data using Zstandard and base64-encodes the result
   366  
   367  	// attempts to use program-specific state parsers to
   368  	// return more human-readable and explicit account state data.
   369  	// If "jsonParsed" is requested but a parser cannot be found,
   370  	// the field falls back to "base64" encoding, detectable when the data field is type <string>.
   371  	// Cannot be used if specifying dataSlice parameters (offset, length).
   372  	EncodingJSONParsed EncodingType = "jsonParsed"
   373  
   374  	EncodingJSON EncodingType = "json" // NOTE: you're probably looking for EncodingJSONParsed
   375  )
   376  
   377  // IsAnyOfEncodingType checks whether the provided `candidate` is any of the `allowed`.
   378  func IsAnyOfEncodingType(candidate EncodingType, allowed ...EncodingType) bool {
   379  	for _, v := range allowed {
   380  		if candidate == v {
   381  			return true
   382  		}
   383  	}
   384  	return false
   385  }