github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/smartcontract/callflag/call_flags.go (about)

     1  package callflag
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"strings"
     7  )
     8  
     9  // CallFlag represents a call flag.
    10  type CallFlag byte
    11  
    12  // Default flags.
    13  const (
    14  	ReadStates CallFlag = 1 << iota
    15  	WriteStates
    16  	AllowCall
    17  	AllowNotify
    18  
    19  	States            = ReadStates | WriteStates
    20  	ReadOnly          = ReadStates | AllowCall
    21  	All               = States | AllowCall | AllowNotify
    22  	NoneFlag CallFlag = 0
    23  )
    24  
    25  var flagString = map[CallFlag]string{
    26  	ReadStates:  "ReadStates",
    27  	WriteStates: "WriteStates",
    28  	AllowCall:   "AllowCall",
    29  	AllowNotify: "AllowNotify",
    30  	States:      "States",
    31  	ReadOnly:    "ReadOnly",
    32  	All:         "All",
    33  	NoneFlag:    "None",
    34  }
    35  
    36  // basicFlags are all flags except All and None. It's used to stringify CallFlag
    37  // where its bits are matched against these values from the values with sets of bits
    38  // to simple flags, which is important to produce proper string representation
    39  // matching C# Enum handling.
    40  var basicFlags = []CallFlag{ReadOnly, States, ReadStates, WriteStates, AllowCall, AllowNotify}
    41  
    42  // FromString parses an input string and returns a corresponding CallFlag.
    43  func FromString(s string) (CallFlag, error) {
    44  	flags := strings.Split(s, ",")
    45  	if len(flags) == 0 {
    46  		return NoneFlag, errors.New("empty flags")
    47  	}
    48  	if len(flags) == 1 {
    49  		for f, str := range flagString {
    50  			if s == str {
    51  				return f, nil
    52  			}
    53  		}
    54  		return NoneFlag, errors.New("unknown flag")
    55  	}
    56  
    57  	var res CallFlag
    58  
    59  	for _, flag := range flags {
    60  		var knownFlag bool
    61  
    62  		flag = strings.TrimSpace(flag)
    63  		for _, f := range basicFlags {
    64  			if flag == flagString[f] {
    65  				res |= f
    66  				knownFlag = true
    67  				break
    68  			}
    69  		}
    70  		if !knownFlag {
    71  			return NoneFlag, errors.New("unknown/inappropriate flag")
    72  		}
    73  	}
    74  	return res, nil
    75  }
    76  
    77  // Has returns true iff all bits set in cf are also set in f.
    78  func (f CallFlag) Has(cf CallFlag) bool {
    79  	return f&cf == cf
    80  }
    81  
    82  // String implements Stringer interface.
    83  func (f CallFlag) String() string {
    84  	if flagString[f] != "" {
    85  		return flagString[f]
    86  	}
    87  
    88  	var res string
    89  
    90  	for _, flag := range basicFlags {
    91  		if f.Has(flag) {
    92  			if len(res) != 0 {
    93  				res += ", "
    94  			}
    95  			res += flagString[flag]
    96  			f &= ^flag // Some "States" shouldn't be combined with "ReadStates".
    97  		}
    98  	}
    99  	return res
   100  }
   101  
   102  // MarshalJSON implements the json.Marshaler interface.
   103  func (f CallFlag) MarshalJSON() ([]byte, error) {
   104  	return []byte(`"` + f.String() + `"`), nil
   105  }
   106  
   107  // UnmarshalJSON implements the json.Unmarshaler interface.
   108  func (f *CallFlag) UnmarshalJSON(data []byte) error {
   109  	var js string
   110  	if err := json.Unmarshal(data, &js); err != nil {
   111  		return err
   112  	}
   113  	flag, err := FromString(js)
   114  	if err != nil {
   115  		return err
   116  	}
   117  	*f = flag
   118  	return nil
   119  }
   120  
   121  // MarshalYAML implements the YAML marshaler interface.
   122  func (f CallFlag) MarshalYAML() (any, error) {
   123  	return f.String(), nil
   124  }
   125  
   126  // UnmarshalYAML implements the YAML unmarshaler interface.
   127  func (f *CallFlag) UnmarshalYAML(unmarshal func(any) error) error {
   128  	var s string
   129  
   130  	err := unmarshal(&s)
   131  	if err != nil {
   132  		return err
   133  	}
   134  
   135  	*f, err = FromString(s)
   136  	return err
   137  }