github.com/jlmucb/cloudproxy@v0.0.0-20170830161738-b5aa0b619bc4/go/apps/genauth/cppgen.go (about)

     1  package genauth
     2  
import (
	"fmt"
	"sort"
	"strings"
)
     7  
const (
	// implHeader opens the generated C++ implementation file. It includes
	// auth.h and the protobuf zero-copy stream header, defines a local
	// make_unique (for C++11), brings the protobuf integer and stream names
	// into namespace tao, and leaves an anonymous namespace open for the
	// helper functions that Implementation appends after it.
	implHeader = `#include "auth.h"
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>

namespace tao {
namespace {
// This is the canonical implementation of make_unique for C++11. It is wrapped
// in an anonymous namespace to keep it from conflicting with the real thing if
// it exists.
template<typename T, typename ...Args>
std::unique_ptr<T> make_unique( Args&& ...args )
{
    return std::unique_ptr<T>( new T( std::forward<Args>(args)... ) );
}
}  // namespace

using google::protobuf::uint64;
using google::protobuf::uint32;
using google::protobuf::io::ArrayInputStream;
using google::protobuf::io::CodedInputStream;
using google::protobuf::io::CodedOutputStream;
using std::string;

namespace {
`

	// encodeString is the C++ helper that writes a string as a varint32 length
	// followed by its raw bytes.
	encodeString = `void EncodeString(const string& str, CodedOutputStream* output) {
  output->WriteVarint32(str.size());
  output->WriteString(str);
}

`

	// decodeString is the C++ helper that reads a string in the format written
	// by EncodeString; it returns false if either read fails.
	decodeString = `bool DecodeString(CodedInputStream* input, string* str) {
  uint32 size = 0;
  if (!input->ReadVarint32(&size)) return false;
  return input->ReadString(str, size);
}

`

	// peekTag is the C++ helper that reads the next varint tag without
	// advancing input. It parses the bytes of input's current buffer through a
	// temporary CodedInputStream, so the real Unmarshal can read the tag again.
	// It returns false if GetDirectBufferPointer or the varint read fails.
	// NOTE(review): a tag whose varint straddles the end of the current buffer
	// cannot be peeked this way; single-byte tags are unaffected.
	peekTag = `bool PeekTag(CodedInputStream* input, uint32* tag) {
  const void* ptr = nullptr;
  int size = 0;
  if (!input->GetDirectBufferPointer(&ptr, &size)) return false;

  ArrayInputStream array_stream(ptr, size);
  CodedInputStream temp_input(&array_stream);
  return temp_input.ReadVarint32(tag);
}

`

	// unmarshalTemplate and marshalTemplate are fmt formats for the first line
	// of a class's Unmarshal and Marshal definitions; %s is the class name.
	unmarshalTemplate = "bool %s::Unmarshal(CodedInputStream* input) {"
	marshalTemplate   = "void %s::Marshal(CodedOutputStream* output) {"

	// headerPrefix opens the generated C++ header: the include guard, standard
	// and protobuf includes, namespace tao, and the abstract base classes
	// LogicElement, Form and Term. Header appends the lines that close the
	// namespace and the include guard.
	// NOTE(review): LogicElement has no virtual destructor (Form and Term do),
	// so deleting an object through a LogicElement* would be undefined
	// behavior; consider declaring one.
	headerPrefix = `#ifndef CLOUDPROXY_GO_APPS_GENAUTH_H_
#define CLOUDPROXY_GO_APPS_GENAUTH_H_
#include <memory>
#include <string>
#include <vector>

#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/stubs/common.h>

namespace tao {

class LogicElement {
 public:
  virtual void Marshal(google::protobuf::io::CodedOutputStream* output) = 0;
  virtual bool Unmarshal(google::protobuf::io::CodedInputStream* input) = 0;
};

class Form: public LogicElement {
 public:
  virtual ~Form() = default;
  virtual void Marshal(google::protobuf::io::CodedOutputStream* output) = 0;
  virtual bool Unmarshal(google::protobuf::io::CodedInputStream* input) = 0;
};

class Term: public LogicElement {
 public:
  virtual ~Term() = default;
  virtual void Marshal(google::protobuf::io::CodedOutputStream* output) = 0;
  virtual bool Unmarshal(google::protobuf::io::CodedInputStream* input) = 0;
};
`
)
    96  
// CppGenerator generates C++ code from the Go auth types.
type CppGenerator struct {
	// Constants are emitted, in order, as the enumerators of the C++
	// BinaryTags enum. A constant named "kTag<Type>" is the tag that the
	// generated Marshal/Unmarshal/Decode code uses for class <Type>.
	Constants  []Constant
	// Types maps each generated class name to its fields.
	Types      map[string][]Field
	// Interfaces holds the names of interface types (checked by value).
	// Fields of these types are decoded through the generated
	// Decode<Name> function after peeking at the tag; Implementation only
	// emits DecodeForm and DecodeTerm, so presumably these are Form and Term.
	Interfaces map[string]bool
	// FormTypes holds the class names that derive from Form; only key
	// presence is checked.
	FormTypes  map[string]bool
	// TermTypes holds the class names that derive from Term; only key
	// presence is checked.
	TermTypes  map[string]bool
}
   105  
   106  // Constants creates lines that define the constants for auth serialization and
   107  // deserialization.
   108  func (cg *CppGenerator) BinaryConstants() []string {
   109  	header := []string{"enum class BinaryTags {"}
   110  	for i, constant := range cg.Constants {
   111  		value := fmt.Sprintf("  %s = %d", constant.Name, constant.Value)
   112  		if i < len(cg.Constants)-1 {
   113  			value += ","
   114  		}
   115  		header = append(header, value)
   116  	}
   117  	return append(header, "};", "")
   118  }
   119  
   120  // FieldDeclType returns a string that represents the C++ type for this field.
   121  func FieldDeclType(info Field) string {
   122  	typeName := info.TypeName
   123  
   124  	if primitives[typeName] {
   125  		if info.Type == StarType {
   126  			typeName = typeName + "*"
   127  		}
   128  
   129  		if typeName == "string" {
   130  			typeName = "std::string"
   131  		}
   132  
   133  		if typeName == "int64*" {
   134  			// This is not a pointer in the C++ version. The type change is needed to get the CodedInputStream unmarshalling to work.
   135  			typeName = "google::protobuf::uint64"
   136  		}
   137  
   138  		if typeName == "int" {
   139  			// The type change is needed to get the CodedInputStream unmarshalling to work.
   140  			typeName = "google::protobuf::uint32"
   141  		}
   142  
   143  		return typeName
   144  	}
   145  
   146  	switch info.Type {
   147  	case IdentType, StarType:
   148  		return fmt.Sprintf("std::unique_ptr<%s>", typeName)
   149  	case ArrayType:
   150  		if info.TypeName == "byte" {
   151  			return "std::string"
   152  		}
   153  
   154  		return "std::vector<std::unique_ptr<" + typeName + ">>"
   155  	}
   156  
   157  	return ""
   158  }
   159  
   160  // FieldName returns the generated name for a field in an auth class.
   161  func FieldName(info Field) string {
   162  	varName := info.Name + "_"
   163  	if info.Type == ArrayType && info.TypeName != "byte" {
   164  		varName = info.Name + "s_";
   165  	}
   166  
   167  	return varName
   168  }
   169  
   170  // FieldDecl generates the code for a field in a C++ header.
   171  func FieldDecl(info Field) []string {
   172  	field := make([]string, 0)
   173  	fieldType := FieldDeclType(info)
   174  	if info.Type == StarType {
   175  		field = append(field, fmt.Sprintf("  bool %s_present_;", info.Name))
   176  	}
   177  
   178  	return append(field, "  " + fieldType + " " + FieldName(info) + ";")
   179  }
   180  
   181  func (cg *CppGenerator) Class(name string, fields []Field) []string {
   182  	class := fmt.Sprintf("class %s", name)
   183  	isSubclass := false
   184  	if _, isForm := cg.FormTypes[name]; isForm {
   185  		class += ": public Form"
   186  		isSubclass = true
   187  	} else if _, isTerm := cg.TermTypes[name]; isTerm {
   188  		class += ": public Term"
   189  		isSubclass = true
   190  	}
   191  
   192  	class += " {"
   193  	header := []string{class, " public:"}
   194  
   195  	header = append(header, fmt.Sprintf("  %s() = default;", name))
   196  	var override string
   197  	if isSubclass {
   198  		override = " override"
   199  	}
   200  
   201  	rvalue := fmt.Sprintf("  %s(", name)
   202  
   203  	marshal := "  void Marshal(google::protobuf::io::CodedOutputStream* output)"
   204  	unmarshal := "  bool Unmarshal(google::protobuf::io::CodedInputStream* input)"
   205  
   206  	header = append(header, fmt.Sprintf("  ~%s()%s = default;", name, override))
   207  	header = append(header, fmt.Sprintf("%s%s;", marshal, override))
   208  	header = append(header, fmt.Sprintf("%s%s;", unmarshal, override))
   209  
   210  	variables := make([]string, 0)
   211  	for i, info := range fields {
   212  		variables = append(variables, FieldDecl(info)...)
   213  
   214  		rvalue += FieldDeclType(info) + "&& " + FieldName(info)
   215  		if i < len(fields) - 1 {
   216  			rvalue += ", "
   217  		} else {
   218  			rvalue += ");"
   219  		}
   220  	}
   221  
   222  	header = append(header, rvalue)
   223  	header = append(header, variables...)
   224  
   225  	return append(header, "};", "")
   226  }
   227  
   228  // Header creates the header file for the C++ code.
   229  func (cg *CppGenerator) Header() []string {
   230  	header := strings.Split(headerPrefix, "\n")
   231  	header = append(header, cg.BinaryConstants()...)
   232  
   233  	// Append declarations of all the class to avoid ordering problems.
   234  	for name, _ := range cg.Types {
   235  		header = append(header, "class "+name+";")
   236  	}
   237  	header = append(header, "")
   238  
   239  	for name, fields := range cg.Types {
   240  		header = append(header, cg.Class(name, fields)...)
   241  	}
   242  
   243  	return append(header, "}  // namespace tao", "#endif  // CLOUDPROXY_GO_APPS_GENAUTH_H_")
   244  }
   245  
   246  // Decoder generates code that deserializes bytes that might be any subclass of
   247  // a given interface class.
   248  func (cg *CppGenerator) Decoder(interfaceName string, types map[string]bool) []string {
   249  	impl := []string{
   250  		"bool Decode" + interfaceName + "(uint32 tag, CodedInputStream* input, std::unique_ptr<" + interfaceName + ">* value) {",
   251  		"  switch(tag) {",
   252  	}
   253  	for _, constant := range cg.Constants {
   254  		typeName := strings.TrimPrefix(constant.Name, "kTag")
   255  		if _, ok := types[typeName]; !ok {
   256  			continue
   257  		}
   258  
   259  		impl = append(impl, []string{
   260  			"  case static_cast<uint32>(BinaryTags::" + constant.Name + "):",
   261  			"    *value = make_unique<" + typeName + ">();",
   262  			"    break;",
   263  		}...)
   264  	}
   265  
   266  	impl = append(impl, "  default:", "    return false;")
   267  	impl = append(impl, "  }", "  return (*value)->Unmarshal(input);", "}", "")
   268  	return impl
   269  }
   270  
   271  // PrimitiveUnmarshaller generates code that unmarshals primitive types like int.
   272  func PrimitiveUnmarshaller(typeName string, field Field) []string {
   273  	if typeName == "string" {
   274  		return []string{fmt.Sprintf("  if (!DecodeString(input, &%s_)) return false;", field.Name)}
   275  	}
   276  
   277  	if typeName == "int" {
   278  		return []string{fmt.Sprintf("  if (!input->ReadVarint32(&%s_)) return false;", field.Name)}
   279  	}
   280  
   281  	if typeName == "bool" {
   282  		return []string{
   283  			fmt.Sprintf("  uint32 %s_value = 0;", field.Name),
   284  			fmt.Sprintf("  if (!input->ReadVarint32(&%s_value)) return false;", field.Name),
   285  			fmt.Sprintf("  %[1]s_ = !!%[1]s_value;", field.Name),
   286  		}
   287  	}
   288  
   289  	if typeName == "int64" && field.Type == StarType {
   290  		// This has a boolean value that says whether or
   291  		// not to expect an int64 field next.
   292  		return []string{
   293  			fmt.Sprintf("  uint32 %s_present_value = 0;", field.Name),
   294  			fmt.Sprintf("  if (!input->ReadVarint32(&%s_present_value)) return false;", field.Name),
   295  			fmt.Sprintf("  %[1]s_present_ = !!%[1]s_present_value;", field.Name),
   296  			fmt.Sprintf("  if (%s_present_) {", field.Name),
   297  			fmt.Sprintf("    if (!input->ReadVarint64(&%s_)) return false;", field.Name),
   298  			"  }",
   299  		}
   300  	}
   301  
   302  	return nil
   303  }
   304  
   305  // IdentUnmarshaller generates code that unmarshals an ast.Ident.
   306  func (cg *CppGenerator) IdentUnmarshaller(typeName string, field Field) []string {
   307  	if cg.Interfaces[typeName] {
   308  		return []string{
   309  			// Peek at the next tag.
   310  			fmt.Sprintf("  uint32 %s_tag = 0;", field.Name),
   311  			fmt.Sprintf("  if (!PeekTag(input, &%s_tag)) return false;", field.Name),
   312  			fmt.Sprintf("  if (!Decode%s(%[2]s_tag, input, &%[2]s_)) return false;", typeName, field.Name),
   313  		}
   314  	}
   315  	return []string{
   316  		fmt.Sprintf("  %s_ = make_unique<%s>();", field.Name, typeName),
   317  		fmt.Sprintf("  if (!%s_->Unmarshal(input)) return false;", field.Name),
   318  	}
   319  }
   320  
   321  // ArrayUnmarshaller generates code that unmarshals an Array.
   322  func (cg *CppGenerator) ArrayUnmarshaller(typeName string, field Field) []string {
   323  	if typeName == "byte" {
   324  		return []string{fmt.Sprintf("  if (!DecodeString(input, &%s_)) return false;", field.Name)}
   325  	}
   326  
   327  	impl := []string{
   328  		fmt.Sprintf("  uint32 %ss_count = 0;", field.Name),
   329  		fmt.Sprintf("  if (!input->ReadVarint32(&%ss_count)) return false;", field.Name),
   330  		fmt.Sprintf("  for(uint32 i = 0; i < %ss_count; i++) {", field.Name),
   331  	}
   332  
   333  	if cg.Interfaces[typeName] {
   334  		return append(impl, []string{
   335  			// Peek at the next tag.
   336  			fmt.Sprintf("    uint32 %ss_tag = 0;", field.Name),
   337  			fmt.Sprintf("    if (!PeekTag(input, &%ss_tag)) return false;", field.Name),
   338  			fmt.Sprintf("    std::unique_ptr<%s> %ss_obj;", typeName, field.Name),
   339  			fmt.Sprintf("    if (!Decode%s(%[2]ss_tag, input, &%[2]ss_obj)) return false;", typeName, field.Name),
   340  			fmt.Sprintf("    %[1]ss_.emplace_back(std::move(%[1]ss_obj));", field.Name),
   341  			"  }",
   342  		}...)
   343  	}
   344  
   345  	return append(impl, []string{
   346  		fmt.Sprintf("    auto %ss_obj = make_unique<%s>();", field.Name, typeName),
   347  		fmt.Sprintf("    %ss_obj->Unmarshal(input);", field.Name),
   348  		fmt.Sprintf("    %[1]ss_.emplace_back(std::move(%[1]ss_obj));", field.Name),
   349  		"  }",
   350  	}...)
   351  }
   352  
   353  // Unmarshaller generates code that unmarshals bytes to a given class.
   354  func (cg *CppGenerator) Unmarshaller(name string, fields []Field) []string {
   355  	impl := []string{fmt.Sprintf(unmarshalTemplate, name)}
   356  	tag := "BinaryTags::kTag" + name
   357  	impl = append(impl, []string{
   358  		"  uint32 type_tag = 0;",
   359  		"  if (!input->ReadVarint32(&type_tag)) return false;",
   360  		fmt.Sprintf("  if (type_tag != static_cast<uint32>(%s)) return false;", tag),
   361  	}...)
   362  	for _, field := range fields {
   363  		typeName := field.TypeName
   364  		if primitives[typeName] {
   365  			m := PrimitiveUnmarshaller(typeName, field)
   366  			if m != nil {
   367  				impl = append(impl, m...)
   368  			}
   369  			continue
   370  		}
   371  
   372  		switch field.Type {
   373  		case IdentType:
   374  			impl = append(impl, cg.IdentUnmarshaller(typeName, field)...)
   375  		case ArrayType:
   376  			impl = append(impl, cg.ArrayUnmarshaller(typeName, field)...)
   377  		}
   378  	}
   379  
   380  	return append(impl, "  return true;", "}", "")
   381  }
   382  
   383  // PrimitiveMarshaller generates serialization code for primitive types like int.
   384  func (cg *CppGenerator) PrimitiveMarshaller(typeName string, field Field) []string {
   385  	if typeName == "string" {
   386  		return []string{fmt.Sprintf("  EncodeString(%s_, output);", field.Name)}
   387  	}
   388  
   389  	if typeName == "int" {
   390  		return []string{fmt.Sprintf("  output->WriteVarint32(%s_);", field.Name)}
   391  	}
   392  
   393  	if typeName == "bool" {
   394  		return []string{fmt.Sprintf("  output->WriteVarint32(static_cast<uint32>(%s_));", field.Name)}
   395  	}
   396  
   397  	if typeName == "int64" && field.Type == StarType {
   398  		// This has a boolean value that says whether or
   399  		// not to expect an int64 field next.
   400  		return []string{
   401  			fmt.Sprintf("  uint32 %[1]s_value = %[1]s_present_ ? 1 : 0;", field.Name),
   402  			fmt.Sprintf("  output->WriteVarint32(%s_value);", field.Name),
   403  			fmt.Sprintf("  if (%s_present_) {", field.Name),
   404  			fmt.Sprintf("    output->WriteVarint64(%s_);", field.Name),
   405  			"  }",
   406  		}
   407  	}
   408  
   409  	return nil
   410  }
   411  
   412  // MoveConstructor generates a constructor that moves all of the member
   413  // variables through rvalue parameters.
   414  func (cg *CppGenerator) MoveConstructor(name string, fields []Field) []string {
   415  	sig := fmt.Sprintf("%[1]s::%[1]s(", name)
   416  	body := make([]string, 0)
   417  	for i, info := range fields {
   418  		sig += FieldDeclType(info) + "&& " + info.Name
   419  		end := ""
   420  		if i < len(fields) - 1 {
   421  			sig += ", "
   422  			end = ","
   423  		} else {
   424  			sig += ")"
   425  			end = " {}\n"
   426  		}
   427  
   428  		leader := "    "
   429  		if i == 0 {
   430  			leader += ": "
   431  		}
   432  		body = append(body, fmt.Sprintf("%s%s(std::move(%s))%s", leader, FieldName(info), info.Name, end))
   433  	}
   434  
   435  	header := []string{sig}
   436  	return append(header, body...)
   437  }
   438  
   439  // Marshaller generates serialization code for the given auth type.
   440  func (cg *CppGenerator) Marshaller(name string, fields []Field) []string {
   441  	impl := []string{fmt.Sprintf(marshalTemplate, name)}
   442  	tag := "BinaryTags::kTag" + name
   443  	impl = append(impl, fmt.Sprintf("  output->WriteVarint32(static_cast<uint32>(%s));", tag))
   444  	for _, field := range fields {
   445  		typeName := field.TypeName
   446  
   447  		if primitives[typeName] {
   448  			m := cg.PrimitiveMarshaller(typeName, field)
   449  			if m != nil {
   450  				impl = append(impl, m...)
   451  			}
   452  			continue
   453  		}
   454  
   455  		switch field.Type {
   456  		case IdentType:
   457  			impl = append(impl, fmt.Sprintf("  %s_->Marshal(output);", field.Name))
   458  		case ArrayType:
   459  			if typeName == "byte" {
   460  				impl = append(impl, fmt.Sprintf("  EncodeString(%s_, output);", field.Name))
   461  				continue
   462  			}
   463  
   464  			impl = append(impl, []string{
   465  				fmt.Sprintf("  output->WriteVarint32(%ss_.size());", field.Name),
   466  				fmt.Sprintf("  for(auto& elt : %ss_) {", field.Name),
   467  				"    elt->Marshal(output);",
   468  				"  }",
   469  			}...)
   470  		}
   471  	}
   472  
   473  	return append(impl, "}", "")
   474  }
   475  
   476  // Implementation generates the C++ implementation file for the auth classes.
   477  func (cg *CppGenerator) Implementation() []string {
   478  	impl := strings.Split(implHeader, "\n")
   479  	impl = append(impl, strings.Split(encodeString, "\n")...)
   480  	impl = append(impl, strings.Split(decodeString, "\n")...)
   481  	impl = append(impl, strings.Split(peekTag, "\n")...)
   482  
   483  	impl = append(impl, cg.Decoder("Form", cg.FormTypes)...)
   484  	impl = append(impl, cg.Decoder("Term", cg.TermTypes)...)
   485  	impl = append(impl, "}  // namespace", "")
   486  
   487  	for name, fields := range cg.Types {
   488  		impl = append(impl, cg.MoveConstructor(name, fields)...)
   489  		impl = append(impl, cg.Marshaller(name, fields)...)
   490  		impl = append(impl, cg.Unmarshaller(name, fields)...)
   491  	}
   492  
   493  	return append(impl, "}  // namespace tao")
   494  }