go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/quota/internal/datatool/main.go (about)

     1  // Copyright 2022 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Datatool is a program which allows you to encode/decode quotapb protobuf
    16  // messages to/from a variety of codecs.
    17  //
    18  // In particular, this allows you to decode lua's crazy decimal escape codes
    19  // and also visualize the msgpackpb codec data (which is not JSON compatible,
    20  // so most msgpack viewers will show something inaccurate).
    21  //
    22  // This is meant to be an internal debugging tool for server/quota developers.
    23  package main
    24  
    25  import (
    26  	"bufio"
    27  	"bytes"
    28  	"context"
    29  	"flag"
    30  	"fmt"
    31  	"io"
    32  	"os"
    33  	"regexp"
    34  	"sort"
    35  	"strconv"
    36  	"strings"
    37  
    38  	"github.com/vmihailenco/msgpack/v5"
    39  	"google.golang.org/protobuf/encoding/protojson"
    40  	"google.golang.org/protobuf/proto"
    41  	"google.golang.org/protobuf/reflect/protoreflect"
    42  	"google.golang.org/protobuf/reflect/protoregistry"
    43  
    44  	"go.chromium.org/luci/common/errors"
    45  	"go.chromium.org/luci/common/iotools"
    46  	"go.chromium.org/luci/common/logging"
    47  	"go.chromium.org/luci/common/logging/gologger"
    48  	"go.chromium.org/luci/common/proto/msgpackpb"
    49  
    50  	_ "go.chromium.org/luci/server/quota/quotapb"
    51  )
    52  
    53  var typeName = flag.String("type", "",
    54  	"The quotapb message name to encode/decode.\nThis can either be a full proto message name (like 'google.protobuf.Duration'),\nor it will be looked up relative to the quotapb package (like 'Account').")
    55  
    56  var inCodec = flag.String("in", "jsonpb", "The format for reading data from stdin.")
    57  var outCodec = flag.String("out", "jsonpb", "The format for writing data to stdout.")
    58  
    59  var forceRoundTrip = flag.Bool("force", false, "Force round-trip through proto.")
    60  
    61  type codecImpl interface {
    62  	encode(proto.Message, io.Writer) error
    63  	decode(proto.Message, io.Reader) error
    64  }
    65  
    66  type msgpackDecoder interface {
    67  	decodeToMsgpack(io.Reader) (msgpack.RawMessage, error)
    68  }
    69  type msgpackEncoder interface {
    70  	encodeFromMsgpack(msgpack.RawMessage, io.Writer) error
    71  }
    72  
    73  type codec struct {
    74  	blurb string
    75  	impl  codecImpl
    76  }
    77  
    78  var codecs = map[string]codec{}
    79  
    80  type jsonpbCodec struct{}
    81  
    82  func (jsonpbCodec) encode(msg proto.Message, w io.Writer) error {
    83  	raw, err := protojson.MarshalOptions{
    84  		Multiline:     true,
    85  		UseProtoNames: true,
    86  	}.Marshal(msg)
    87  	if err == nil {
    88  		_, err = w.Write(raw)
    89  	}
    90  	return err
    91  }
    92  func (jsonpbCodec) decode(msg proto.Message, r io.Reader) error {
    93  	raw, err := io.ReadAll(r)
    94  	if err == nil {
    95  		err = protojson.Unmarshal(raw, msg)
    96  	}
    97  	return err
    98  }
    99  func init() {
   100  	codecs["jsonpb"] = codec{"JSONPB encoding", jsonpbCodec{}}
   101  }
   102  
   103  type pbCodec struct{}
   104  
   105  func (pbCodec) encode(msg proto.Message, w io.Writer) error {
   106  	dat, err := proto.Marshal(msg)
   107  	if err != nil {
   108  		return err
   109  	}
   110  	_, err = w.Write(dat)
   111  	return err
   112  }
   113  func (pbCodec) decode(msg proto.Message, r io.Reader) error {
   114  	dat, err := io.ReadAll(r)
   115  	if err != nil {
   116  		return err
   117  	}
   118  	return proto.Unmarshal(dat, msg)
   119  }
   120  func init() {
   121  	codecs["pb"] = codec{"protobuf encoding (binary)", pbCodec{}}
   122  }
   123  
   124  type msgpackpbCodec struct{}
   125  
   126  func (msgpackpbCodec) encode(msg proto.Message, w io.Writer) error {
   127  	return msgpackpb.MarshalStream(w, msg, msgpackpb.Deterministic)
   128  }
   129  func (msgpackpbCodec) decode(msg proto.Message, r io.Reader) error {
   130  	return msgpackpb.UnmarshalStream(r, msg)
   131  }
   132  func (msgpackpbCodec) decodeToMsgpack(r io.Reader) (msgpack.RawMessage, error) {
   133  	ret, err := io.ReadAll(r)
   134  	if err == nil {
   135  		return msgpack.RawMessage(ret), nil
   136  	}
   137  	return nil, err
   138  }
   139  func init() {
   140  	codecs["msgpackpb"] = codec{"msgpackpb encoding (binary).", msgpackpbCodec{}}
   141  }
   142  
   143  type msgpackPrettyCodec struct{}
   144  
   145  func prettyPrintMsgpack(w io.Writer, indent string, obj any) {
   146  	ws := func(fmtStr string, args ...any) {
   147  		fmt.Fprintf(w, fmtStr, args...)
   148  	}
   149  
   150  	switch x := obj.(type) {
   151  	case []any:
   152  		ws("[\n")
   153  		newIndent := indent + "  "
   154  		for _, itm := range x {
   155  			ws(newIndent)
   156  			prettyPrintMsgpack(w, newIndent, itm)
   157  			ws(",\n")
   158  		}
   159  		ws("%s]", indent)
   160  	case map[any]any:
   161  		ws("{\n")
   162  		newIndent := indent + "  "
   163  		for key, itm := range x {
   164  			ws(newIndent)
   165  			prettyPrintMsgpack(w, newIndent, key)
   166  			ws(": ")
   167  			prettyPrintMsgpack(w, newIndent, itm)
   168  			ws(",\n")
   169  		}
   170  		ws("%s}", indent)
   171  	case uint8:
   172  		ws("8u%d", x)
   173  	case uint16:
   174  		ws("16u%d", x)
   175  	case uint32:
   176  		ws("32u%d", x)
   177  	case uint64:
   178  		ws("64u%d", x)
   179  	case int8:
   180  		ws("8i%d", x)
   181  	case int16:
   182  		ws("16i%d", x)
   183  	case int32:
   184  		ws("32i%d", x)
   185  	case int64:
   186  		ws("64i%d", x)
   187  	case bool:
   188  		ws("%t", x)
   189  	case float32:
   190  		ws("32f%f", x)
   191  	case float64:
   192  		ws("64f%f", x)
   193  	case []byte:
   194  		ws("b%q", x)
   195  	case string:
   196  		ws("%q", x)
   197  	default:
   198  		panic(fmt.Sprintf("unknown msgback primitive: %T", x))
   199  	}
   200  }
   201  
   202  func (m msgpackPrettyCodec) encode(msg proto.Message, w io.Writer) error {
   203  	raw, err := msgpackpb.Marshal(msg, msgpackpb.Deterministic)
   204  	if err == nil {
   205  		err = m.encodeFromMsgpack(raw, w)
   206  	}
   207  	return err
   208  }
   209  func (msgpackPrettyCodec) encodeFromMsgpack(raw msgpack.RawMessage, w io.Writer) error {
   210  	dec := msgpack.NewDecoder(bytes.NewReader([]byte(raw)))
   211  	dec.SetMapDecoder(func(d *msgpack.Decoder) (any, error) {
   212  		return d.DecodeUntypedMap()
   213  	})
   214  
   215  	var iface any
   216  	iface, err := dec.DecodeInterface()
   217  	if err != nil {
   218  		return err
   219  	}
   220  
   221  	prettyPrintMsgpack(w, "", iface)
   222  	_, _ = w.Write([]byte("\n"))
   223  	return nil
   224  }
   225  func (msgpackPrettyCodec) decode(msg proto.Message, r io.Reader) error {
   226  	return errors.New("msgpackpb+pretty does not support input")
   227  }
   228  func init() {
   229  	codecs["msgpackpb+pretty"] = codec{"Output only; For debugging msgpack structure detail.", msgpackPrettyCodec{}}
   230  }
   231  
   232  type msgpackpbLuaCodec struct{}
   233  
   234  type luaWriter struct{ w io.Writer }
   235  
   236  var _ io.Writer = &luaWriter{}
   237  
   238  func (l *luaWriter) Write(buf []byte) (n int, err error) {
   239  	for _, b := range buf {
   240  		if b == '\\' {
   241  			_, _ = l.w.Write([]byte(`\\`))
   242  		} else if b >= ' ' && b <= '~' {
   243  			_, _ = l.w.Write([]byte{b})
   244  		} else {
   245  			fmt.Fprintf(l.w, "\\%03d", b)
   246  		}
   247  		n++
   248  	}
   249  	return
   250  }
   251  
   252  func (msgpackpbLuaCodec) encode(msg proto.Message, w io.Writer) error {
   253  	return msgpackpb.MarshalStream(&luaWriter{w}, msg, msgpackpb.Deterministic)
   254  }
   255  
   256  var escapes = regexp.MustCompile(`(\\\\)|(\\\d\d\d)`)
   257  var literalSlash = []byte(`\\`)
   258  
   259  func (m msgpackpbLuaCodec) decode(msg proto.Message, r io.Reader) error {
   260  	raw, err := m.decodeToMsgpack(r)
   261  	if err == nil {
   262  		err = msgpackpb.Unmarshal(raw, msg)
   263  	}
   264  	return err
   265  }
   266  
   267  func (msgpackpbLuaCodec) decodeToMsgpack(r io.Reader) (msgpack.RawMessage, error) {
   268  	raw, err := io.ReadAll(r)
   269  	if err != nil {
   270  		return nil, err
   271  	}
   272  	raw = bytes.Trim(raw, `"'`)
   273  	raw = escapes.ReplaceAllFunc(raw, func(b []byte) []byte {
   274  		if err != nil {
   275  			return nil
   276  		}
   277  		if bytes.Equal(b, literalSlash) {
   278  			return []byte(`\`)
   279  		}
   280  		var byt uint64
   281  		byt, err = strconv.ParseUint(string(b[1:]), 10, 8)
   282  		return []byte{byte(byt)}
   283  	})
   284  	if err != nil {
   285  		return nil, err
   286  	}
   287  	return msgpack.RawMessage(raw), nil
   288  }
   289  func init() {
   290  	codecs["msgpackpb+lua"] = codec{"msgpackpb encoding (decimal lua string).", msgpackpbLuaCodec{}}
   291  }
   292  
   293  type msgpackpbGoBytesCodec struct{}
   294  
   295  func (msgpackpbGoBytesCodec) encode(msg proto.Message, w io.Writer) error {
   296  	raw, err := msgpackpb.Marshal(msg, msgpackpb.Deterministic)
   297  	if err != nil {
   298  		return err
   299  	}
   300  	fmt.Fprint(w, []byte(raw))
   301  	return nil
   302  }
   303  func (m msgpackpbGoBytesCodec) decode(msg proto.Message, r io.Reader) error {
   304  	raw, err := m.decodeToMsgpack(r)
   305  	if err != nil {
   306  		return err
   307  	}
   308  	return msgpackpb.Unmarshal(raw, msg)
   309  }
   310  func (msgpackpbGoBytesCodec) decodeToMsgpack(r io.Reader) (msgpack.RawMessage, error) {
   311  	scn := bufio.NewScanner(r)
   312  	scn.Split(bufio.ScanWords)
   313  
   314  	buf := []byte{}
   315  
   316  	for scn.Scan() {
   317  		tok := strings.Trim(scn.Text(), "[]")
   318  		val, err := strconv.ParseUint(tok, 10, 8)
   319  		if err != nil {
   320  			return nil, err
   321  		}
   322  		buf = append(buf, byte(val))
   323  	}
   324  
   325  	return msgpack.RawMessage(buf), nil
   326  }
   327  func init() {
   328  	codecs["msgpackpb+gobytes"] = codec{"msgpackpb encoding (Go `[]byte` decimal format).", msgpackpbGoBytesCodec{}}
   329  }
   330  
   331  type msgpackpbHexStringCodec struct{}
   332  
   333  func (msgpackpbHexStringCodec) encode(msg proto.Message, w io.Writer) error {
   334  	raw, err := msgpackpb.Marshal(msg, msgpackpb.Deterministic)
   335  	if err != nil {
   336  		return err
   337  	}
   338  	fmt.Fprintf(w, "%q", string(raw))
   339  	return nil
   340  }
   341  func (m msgpackpbHexStringCodec) decode(msg proto.Message, r io.Reader) error {
   342  	raw, err := m.decodeToMsgpack(r)
   343  	if err != nil {
   344  		return err
   345  	}
   346  	return msgpackpb.Unmarshal(raw, msg)
   347  }
   348  func (msgpackpbHexStringCodec) decodeToMsgpack(r io.Reader) (msgpack.RawMessage, error) {
   349  	raw, err := io.ReadAll(r)
   350  	if err != nil {
   351  		return nil, err
   352  	}
   353  	rawS, err := strconv.Unquote(string(raw))
   354  	if err != nil {
   355  		return nil, err
   356  	}
   357  	return msgpack.RawMessage(rawS), nil
   358  }
   359  func init() {
   360  	codecs["msgpackpb+hex"] = codec{"msgpackpb encoding (raw redis string).", msgpackpbHexStringCodec{}}
   361  }
   362  
   363  func init() {
   364  	codecNames := []string{}
   365  	for k := range codecs {
   366  		codecNames = append(codecNames, k)
   367  	}
   368  	sort.Strings(codecNames)
   369  
   370  	flag.Usage = func() {
   371  		fmt.Fprintf(flag.CommandLine.Output(), "Usage of %s:\n", os.Args[0])
   372  		fmt.Fprintf(flag.CommandLine.Output(), "\n")
   373  		fmt.Fprintf(flag.CommandLine.Output(),
   374  			"This program transforms quotapb proto messages from one form to another.\n")
   375  		fmt.Fprintf(flag.CommandLine.Output(), "Valid formats are:\n")
   376  		for _, name := range codecNames {
   377  			fmt.Fprintf(flag.CommandLine.Output(), "  %s - %s\n", name, codecs[name].blurb)
   378  		}
   379  		fmt.Fprintf(flag.CommandLine.Output(), "\n")
   380  		fmt.Fprintf(flag.CommandLine.Output(), "Flags:\n")
   381  		flag.PrintDefaults()
   382  	}
   383  }
   384  
   385  func main() {
   386  	flag.Parse()
   387  
   388  	ctx := gologger.StdConfig.Use(context.Background())
   389  
   390  	if *typeName == "" {
   391  		flag.Usage()
   392  		logging.Errorf(ctx, "-type is required")
   393  		os.Exit(1)
   394  	}
   395  	fullName := protoreflect.FullName("go.chromium.org.luci.server.quota.quotapb." + *typeName)
   396  	mt, err := protoregistry.GlobalTypes.FindMessageByName(fullName)
   397  	if err == protoregistry.NotFound {
   398  		mt, err = protoregistry.GlobalTypes.FindMessageByName(protoreflect.FullName(*typeName))
   399  		if err != nil {
   400  			logging.Errorf(ctx, "could not load type %q: %s", *typeName, err)
   401  			os.Exit(1)
   402  		}
   403  	} else if err != nil {
   404  		logging.Errorf(ctx, "could not load type %q: %s", fullName, err)
   405  		os.Exit(1)
   406  	}
   407  
   408  	cIn, ok := codecs[*inCodec]
   409  	if !ok {
   410  		flag.Usage()
   411  		logging.Errorf(ctx, "invalid -in codec: %q", *inCodec)
   412  		os.Exit(1)
   413  	}
   414  	cOut, ok := codecs[*outCodec]
   415  	if !ok {
   416  		flag.Usage()
   417  		logging.Errorf(ctx, "invalid -out codec: %q", *outCodec)
   418  		os.Exit(1)
   419  	}
   420  
   421  	stdout := bufio.NewWriter(os.Stdout)
   422  	defer func() {
   423  		if err := stdout.Flush(); err != nil {
   424  			panic(err)
   425  		}
   426  	}()
   427  
   428  	mpd, _ := cIn.impl.(msgpackDecoder)
   429  	mpe, _ := cOut.impl.(msgpackEncoder)
   430  	if !*forceRoundTrip && mpd != nil && mpe != nil {
   431  		// In this case we can skip the proto message.
   432  		// going on.
   433  		mpk, err := mpd.decodeToMsgpack(bufio.NewReader(os.Stdin))
   434  		if err != nil {
   435  			logging.Errorf(ctx, "parsing stdin as encoded messagepack: %s", err)
   436  			os.Exit(1)
   437  		}
   438  
   439  		_, err = iotools.WriteTracker(stdout, func(w io.Writer) error {
   440  			return mpe.encodeFromMsgpack(mpk, w)
   441  		})
   442  		if err != nil {
   443  			logging.Errorf(ctx, "encoding messagepack: %s", err)
   444  			os.Exit(1)
   445  		}
   446  	} else {
   447  		msg := mt.New().Interface()
   448  		if err := cIn.impl.decode(msg, bufio.NewReader(os.Stdin)); err != nil {
   449  			logging.Errorf(ctx, "failed to decode stdin: %s", err)
   450  			os.Exit(1)
   451  		}
   452  
   453  		_, err = iotools.WriteTracker(stdout, func(w io.Writer) error {
   454  			return cOut.impl.encode(msg, w)
   455  		})
   456  		if err != nil {
   457  			logging.Errorf(ctx, "failed to encode to stdout: %s", err)
   458  			os.Exit(1)
   459  		}
   460  	}
   461  }