github.com/ipld/go-ipld-prime@v0.21.0/node/bindnode/infer.go (about)

     1  package bindnode
     2  
     3  import (
     4  	"fmt"
     5  	"go/token"
     6  	"reflect"
     7  	"strings"
     8  
     9  	"github.com/ipfs/go-cid"
    10  	"github.com/ipld/go-ipld-prime/datamodel"
    11  	cidlink "github.com/ipld/go-ipld-prime/linking/cid"
    12  	"github.com/ipld/go-ipld-prime/schema"
    13  )
    14  
    15  var (
    16  	goTypeBool    = reflect.TypeOf(false)
    17  	goTypeInt     = reflect.TypeOf(int(0))
    18  	goTypeFloat   = reflect.TypeOf(0.0)
    19  	goTypeString  = reflect.TypeOf("")
    20  	goTypeBytes   = reflect.TypeOf([]byte{})
    21  	goTypeLink    = reflect.TypeOf((*datamodel.Link)(nil)).Elem()
    22  	goTypeNode    = reflect.TypeOf((*datamodel.Node)(nil)).Elem()
    23  	goTypeCidLink = reflect.TypeOf((*cidlink.Link)(nil)).Elem()
    24  	goTypeCid     = reflect.TypeOf((*cid.Cid)(nil)).Elem()
    25  
    26  	schemaTypeBool   = schema.SpawnBool("Bool")
    27  	schemaTypeInt    = schema.SpawnInt("Int")
    28  	schemaTypeFloat  = schema.SpawnFloat("Float")
    29  	schemaTypeString = schema.SpawnString("String")
    30  	schemaTypeBytes  = schema.SpawnBytes("Bytes")
    31  	schemaTypeLink   = schema.SpawnLink("Link")
    32  	schemaTypeAny    = schema.SpawnAny("Any")
    33  )
    34  
    35  // Consider exposing these APIs later, if they might be useful.
    36  
// seenEntry pairs a Go type with a schema type. verifyCompatibility uses it as
// the key of its "seen" map to remember which pairs it has already visited.
type seenEntry struct {
	goType     reflect.Type // Go side of the pair
	schemaType schema.Type  // schema side of the pair
}
    41  
    42  // verifyCompatibility is the primary way we check that the schema type(s)
    43  // matches the Go type(s); so we do this before we can proceed operating on it.
    44  // verifyCompatibility doesn't return an error, it panics—the errors here are
    45  // not runtime errors, they're programmer errors because your schema doesn't
    46  // match your Go type
    47  func verifyCompatibility(cfg config, seen map[seenEntry]bool, goType reflect.Type, schemaType schema.Type) {
    48  	// TODO(mvdan): support **T as well?
    49  	if goType.Kind() == reflect.Ptr {
    50  		goType = goType.Elem()
    51  	}
    52  
    53  	// Avoid endless loops.
    54  	//
    55  	// TODO(mvdan): this is easy but fairly allocation-happy.
    56  	// Plus, one map per call means we don't reuse work.
    57  	if seen[seenEntry{goType, schemaType}] {
    58  		return
    59  	}
    60  	seen[seenEntry{goType, schemaType}] = true
    61  
    62  	doPanic := func(format string, args ...interface{}) {
    63  		panicFormat := "bindnode: schema type %s is not compatible with Go type %s"
    64  		panicArgs := []interface{}{schemaType.Name(), goType.String()}
    65  
    66  		if format != "" {
    67  			panicFormat += ": " + format
    68  		}
    69  		panicArgs = append(panicArgs, args...)
    70  		panic(fmt.Sprintf(panicFormat, panicArgs...))
    71  	}
    72  	switch schemaType := schemaType.(type) {
    73  	case *schema.TypeBool:
    74  		if customConverter := cfg.converterForType(goType); customConverter != nil {
    75  			if customConverter.kind != schema.TypeKind_Bool {
    76  				doPanic("kind mismatch; custom converter for type is not for Bool")
    77  			}
    78  		} else if goType.Kind() != reflect.Bool {
    79  			doPanic("kind mismatch; need boolean")
    80  		}
    81  	case *schema.TypeInt:
    82  		if customConverter := cfg.converterForType(goType); customConverter != nil {
    83  			if customConverter.kind != schema.TypeKind_Int {
    84  				doPanic("kind mismatch; custom converter for type is not for Int")
    85  			}
    86  		} else if kind := goType.Kind(); !kindInt[kind] && !kindUint[kind] {
    87  			doPanic("kind mismatch; need integer")
    88  		}
    89  	case *schema.TypeFloat:
    90  		if customConverter := cfg.converterForType(goType); customConverter != nil {
    91  			if customConverter.kind != schema.TypeKind_Float {
    92  				doPanic("kind mismatch; custom converter for type is not for Float")
    93  			}
    94  		} else {
    95  			switch goType.Kind() {
    96  			case reflect.Float32, reflect.Float64:
    97  			default:
    98  				doPanic("kind mismatch; need float")
    99  			}
   100  		}
   101  	case *schema.TypeString:
   102  		// TODO: allow []byte?
   103  		if customConverter := cfg.converterForType(goType); customConverter != nil {
   104  			if customConverter.kind != schema.TypeKind_String {
   105  				doPanic("kind mismatch; custom converter for type is not for String")
   106  			}
   107  		} else if goType.Kind() != reflect.String {
   108  			doPanic("kind mismatch; need string")
   109  		}
   110  	case *schema.TypeBytes:
   111  		// TODO: allow string?
   112  		if customConverter := cfg.converterForType(goType); customConverter != nil {
   113  			if customConverter.kind != schema.TypeKind_Bytes {
   114  				doPanic("kind mismatch; custom converter for type is not for Bytes")
   115  			}
   116  		} else if goType.Kind() != reflect.Slice {
   117  			doPanic("kind mismatch; need slice of bytes")
   118  		} else if goType.Elem().Kind() != reflect.Uint8 {
   119  			doPanic("kind mismatch; need slice of bytes")
   120  		}
   121  	case *schema.TypeEnum:
   122  		if _, ok := schemaType.RepresentationStrategy().(schema.EnumRepresentation_Int); ok {
   123  			if kind := goType.Kind(); kind != reflect.String && !kindInt[kind] && !kindUint[kind] {
   124  				doPanic("kind mismatch; need string or integer")
   125  			}
   126  		} else {
   127  			if goType.Kind() != reflect.String {
   128  				doPanic("kind mismatch; need string")
   129  			}
   130  		}
   131  	case *schema.TypeList:
   132  		if goType.Kind() != reflect.Slice {
   133  			doPanic("kind mismatch; need slice")
   134  		}
   135  		goType = goType.Elem()
   136  		if schemaType.ValueIsNullable() {
   137  			if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable {
   138  				doPanic("nullable types must be nilable")
   139  			} else if ptr {
   140  				goType = goType.Elem()
   141  			}
   142  		}
   143  		verifyCompatibility(cfg, seen, goType, schemaType.ValueType())
   144  	case *schema.TypeMap:
   145  		//	struct {
   146  		//		Keys   []K
   147  		//		Values map[K]V
   148  		//	}
   149  		if goType.Kind() != reflect.Struct {
   150  			doPanic("kind mismatch; need struct{Keys []K; Values map[K]V}")
   151  		}
   152  		if goType.NumField() != 2 {
   153  			doPanic("%d vs 2 fields", goType.NumField())
   154  		}
   155  
   156  		fieldKeys := goType.Field(0)
   157  		if fieldKeys.Type.Kind() != reflect.Slice {
   158  			doPanic("kind mismatch; need struct{Keys []K; Values map[K]V}")
   159  		}
   160  		verifyCompatibility(cfg, seen, fieldKeys.Type.Elem(), schemaType.KeyType())
   161  
   162  		fieldValues := goType.Field(1)
   163  		if fieldValues.Type.Kind() != reflect.Map {
   164  			doPanic("kind mismatch; need struct{Keys []K; Values map[K]V}")
   165  		}
   166  		keyType := fieldValues.Type.Key()
   167  		verifyCompatibility(cfg, seen, keyType, schemaType.KeyType())
   168  
   169  		elemType := fieldValues.Type.Elem()
   170  		if schemaType.ValueIsNullable() {
   171  			if ptr, nilable := ptrOrNilable(elemType.Kind()); !nilable {
   172  				doPanic("nullable types must be nilable")
   173  			} else if ptr {
   174  				elemType = elemType.Elem()
   175  			}
   176  		}
   177  		verifyCompatibility(cfg, seen, elemType, schemaType.ValueType())
   178  	case *schema.TypeStruct:
   179  		if goType.Kind() != reflect.Struct {
   180  			doPanic("kind mismatch; need struct")
   181  		}
   182  
   183  		schemaFields := schemaType.Fields()
   184  		if goType.NumField() != len(schemaFields) {
   185  			doPanic("%d vs %d fields", goType.NumField(), len(schemaFields))
   186  		}
   187  		for i, schemaField := range schemaFields {
   188  			schemaType := schemaField.Type()
   189  			goType := goType.Field(i).Type
   190  			switch {
   191  			case schemaField.IsOptional() && schemaField.IsNullable():
   192  				// TODO: https://github.com/ipld/go-ipld-prime/issues/340 will
   193  				// help here, to avoid the double pointer. We can't use nilable
   194  				// but non-pointer types because that's just one "nil" state.
   195  				// TODO: deal with custom converters in this case
   196  				if goType.Kind() != reflect.Ptr {
   197  					doPanic("optional and nullable fields must use double pointers (**)")
   198  				}
   199  				goType = goType.Elem()
   200  				if goType.Kind() != reflect.Ptr {
   201  					doPanic("optional and nullable fields must use double pointers (**)")
   202  				}
   203  				goType = goType.Elem()
   204  			case schemaField.IsOptional():
   205  				if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable {
   206  					doPanic("optional fields must be nilable")
   207  				} else if ptr {
   208  					goType = goType.Elem()
   209  				}
   210  			case schemaField.IsNullable():
   211  				if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable {
   212  					if customConverter := cfg.converterForType(goType); customConverter == nil {
   213  						doPanic("nullable fields must be nilable")
   214  					}
   215  				} else if ptr {
   216  					goType = goType.Elem()
   217  				}
   218  			}
   219  			verifyCompatibility(cfg, seen, goType, schemaType)
   220  		}
   221  	case *schema.TypeUnion:
   222  		if goType.Kind() != reflect.Struct {
   223  			doPanic("kind mismatch; need struct for an union")
   224  		}
   225  
   226  		schemaMembers := schemaType.Members()
   227  		if goType.NumField() != len(schemaMembers) {
   228  			doPanic("%d vs %d members", goType.NumField(), len(schemaMembers))
   229  		}
   230  
   231  		for i, schemaType := range schemaMembers {
   232  			goType := goType.Field(i).Type
   233  			if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable {
   234  				doPanic("union members must be nilable")
   235  			} else if ptr {
   236  				goType = goType.Elem()
   237  			}
   238  			verifyCompatibility(cfg, seen, goType, schemaType)
   239  		}
   240  	case *schema.TypeLink:
   241  		if customConverter := cfg.converterForType(goType); customConverter != nil {
   242  			if customConverter.kind != schema.TypeKind_Link {
   243  				doPanic("kind mismatch; custom converter for type is not for Link")
   244  			}
   245  		} else if goType != goTypeLink && goType != goTypeCidLink && goType != goTypeCid {
   246  			doPanic("links in Go must be datamodel.Link, cidlink.Link, or cid.Cid")
   247  		}
   248  	case *schema.TypeAny:
   249  		if customConverter := cfg.converterForType(goType); customConverter != nil {
   250  			if customConverter.kind != schema.TypeKind_Any {
   251  				doPanic("kind mismatch; custom converter for type is not for Any")
   252  			}
   253  		} else if goType != goTypeNode {
   254  			doPanic("Any in Go must be datamodel.Node")
   255  		}
   256  	default:
   257  		panic(fmt.Sprintf("%T", schemaType))
   258  	}
   259  }
   260  
   261  func ptrOrNilable(kind reflect.Kind) (ptr, nilable bool) {
   262  	switch kind {
   263  	case reflect.Ptr:
   264  		return true, true
   265  	case reflect.Interface, reflect.Map, reflect.Slice:
   266  		return false, true
   267  	default:
   268  		return false, false
   269  	}
   270  }
   271  
   272  // If we recurse past a large number of levels, we're mostly stuck in a loop.
   273  // Prevent burning CPU or causing OOM crashes.
   274  // If a user really wrote an IPLD schema or Go type with such deep nesting,
   275  // it's likely they are trying to abuse the system as well.
   276  const maxRecursionLevel = 1 << 10
   277  
   278  type inferredStatus int
   279  
   280  const (
   281  	_ inferredStatus = iota
   282  	inferringInProcess
   283  	inferringDone
   284  )
   285  
   286  // inferGoType can build a Go type given a schema
   287  func inferGoType(typ schema.Type, status map[schema.TypeName]inferredStatus, level int) reflect.Type {
   288  	if level > maxRecursionLevel {
   289  		panic(fmt.Sprintf("inferGoType: refusing to recurse past %d levels", maxRecursionLevel))
   290  	}
   291  	name := typ.Name()
   292  	if status[name] == inferringInProcess {
   293  		panic("bindnode: inferring Go types from cyclic schemas is not supported since Go reflection does not support creating named types")
   294  	}
   295  	status[name] = inferringInProcess
   296  	defer func() { status[name] = inferringDone }()
   297  	switch typ := typ.(type) {
   298  	case *schema.TypeBool:
   299  		return goTypeBool
   300  	case *schema.TypeInt:
   301  		return goTypeInt
   302  	case *schema.TypeFloat:
   303  		return goTypeFloat
   304  	case *schema.TypeString:
   305  		return goTypeString
   306  	case *schema.TypeBytes:
   307  		return goTypeBytes
   308  	case *schema.TypeStruct:
   309  		fields := typ.Fields()
   310  		fieldsGo := make([]reflect.StructField, len(fields))
   311  		for i, field := range fields {
   312  			ftypGo := inferGoType(field.Type(), status, level+1)
   313  			if field.IsNullable() {
   314  				ftypGo = reflect.PtrTo(ftypGo)
   315  			}
   316  			if field.IsOptional() {
   317  				ftypGo = reflect.PtrTo(ftypGo)
   318  			}
   319  			fieldsGo[i] = reflect.StructField{
   320  				Name: fieldNameFromSchema(field.Name()),
   321  				Type: ftypGo,
   322  			}
   323  		}
   324  		return reflect.StructOf(fieldsGo)
   325  	case *schema.TypeMap:
   326  		ktyp := inferGoType(typ.KeyType(), status, level+1)
   327  		vtyp := inferGoType(typ.ValueType(), status, level+1)
   328  		if typ.ValueIsNullable() {
   329  			vtyp = reflect.PtrTo(vtyp)
   330  		}
   331  		// We need an extra field to keep the map ordered,
   332  		// since IPLD maps must have stable iteration order.
   333  		// We could sort when iterating, but that's expensive.
   334  		// Keeping the insertion order is easy and intuitive.
   335  		//
   336  		//	struct {
   337  		//		Keys   []K
   338  		//		Values map[K]V
   339  		//	}
   340  		fieldsGo := []reflect.StructField{
   341  			{
   342  				Name: "Keys",
   343  				Type: reflect.SliceOf(ktyp),
   344  			},
   345  			{
   346  				Name: "Values",
   347  				Type: reflect.MapOf(ktyp, vtyp),
   348  			},
   349  		}
   350  		return reflect.StructOf(fieldsGo)
   351  	case *schema.TypeList:
   352  		etyp := inferGoType(typ.ValueType(), status, level+1)
   353  		if typ.ValueIsNullable() {
   354  			etyp = reflect.PtrTo(etyp)
   355  		}
   356  		return reflect.SliceOf(etyp)
   357  	case *schema.TypeUnion:
   358  		// type goUnion struct {
   359  		// 	Type1 *Type1
   360  		// 	Type2 *Type2
   361  		// 	...
   362  		// }
   363  		members := typ.Members()
   364  		fieldsGo := make([]reflect.StructField, len(members))
   365  		for i, ftyp := range members {
   366  			ftypGo := inferGoType(ftyp, status, level+1)
   367  			fieldsGo[i] = reflect.StructField{
   368  				Name: fieldNameFromSchema(ftyp.Name()),
   369  				Type: reflect.PtrTo(ftypGo),
   370  			}
   371  		}
   372  		return reflect.StructOf(fieldsGo)
   373  	case *schema.TypeLink:
   374  		return goTypeLink
   375  	case *schema.TypeEnum:
   376  		// TODO: generate int for int reprs by default?
   377  		return goTypeString
   378  	case *schema.TypeAny:
   379  		return goTypeNode
   380  	case nil:
   381  		panic("bindnode: unexpected nil schema.Type")
   382  	}
   383  	panic(fmt.Sprintf("%T", typ))
   384  }
   385  
   386  // from IPLD Schema field names like "foo" to Go field names like "Foo".
   387  func fieldNameFromSchema(name string) string {
   388  	fieldName := strings.Title(name) //lint:ignore SA1019 cases.Title doesn't work for this
   389  	if !token.IsIdentifier(fieldName) {
   390  		panic(fmt.Sprintf("bindnode: inferred field name %q is not a valid Go identifier", fieldName))
   391  	}
   392  	return fieldName
   393  }
   394  
   395  var defaultTypeSystem schema.TypeSystem
   396  
   397  func init() {
   398  	defaultTypeSystem.Init()
   399  
   400  	defaultTypeSystem.Accumulate(schemaTypeBool)
   401  	defaultTypeSystem.Accumulate(schemaTypeInt)
   402  	defaultTypeSystem.Accumulate(schemaTypeFloat)
   403  	defaultTypeSystem.Accumulate(schemaTypeString)
   404  	defaultTypeSystem.Accumulate(schemaTypeBytes)
   405  	defaultTypeSystem.Accumulate(schemaTypeLink)
   406  	defaultTypeSystem.Accumulate(schemaTypeAny)
   407  }
   408  
   409  // TODO: support IPLD maps and unions in inferSchema
   410  
   411  // TODO: support bringing your own TypeSystem?
   412  
   413  // TODO: we should probably avoid re-spawning the same types if the TypeSystem
   414  // has them, and test that that works as expected
   415  
   416  // inferSchema can build a schema from a Go type
   417  func inferSchema(typ reflect.Type, level int) schema.Type {
   418  	if level > maxRecursionLevel {
   419  		panic(fmt.Sprintf("inferSchema: refusing to recurse past %d levels", maxRecursionLevel))
   420  	}
   421  	switch typ.Kind() {
   422  	case reflect.Bool:
   423  		return schemaTypeBool
   424  	case reflect.Int64:
   425  		return schemaTypeInt
   426  	case reflect.Float64:
   427  		return schemaTypeFloat
   428  	case reflect.String:
   429  		return schemaTypeString
   430  	case reflect.Struct:
   431  		// these types must match exactly since we need symmetry of being able to
   432  		// get the values an also assign values to them
   433  		if typ == goTypeCid || typ == goTypeCidLink {
   434  			return schemaTypeLink
   435  		}
   436  
   437  		fieldsSchema := make([]schema.StructField, typ.NumField())
   438  		for i := range fieldsSchema {
   439  			field := typ.Field(i)
   440  			ftyp := field.Type
   441  			ftypSchema := inferSchema(ftyp, level+1)
   442  			fieldsSchema[i] = schema.SpawnStructField(
   443  				field.Name, // TODO: allow configuring the name with tags
   444  				ftypSchema.Name(),
   445  
   446  				// TODO: support nullable/optional with tags
   447  				false,
   448  				false,
   449  			)
   450  		}
   451  		name := typ.Name()
   452  		if name == "" {
   453  			panic("TODO: anonymous composite types")
   454  		}
   455  		typSchema := schema.SpawnStruct(name, fieldsSchema, nil)
   456  		defaultTypeSystem.Accumulate(typSchema)
   457  		return typSchema
   458  	case reflect.Slice:
   459  		if typ.Elem().Kind() == reflect.Uint8 {
   460  			// Special case for []byte.
   461  			return schemaTypeBytes
   462  		}
   463  
   464  		nullable := false
   465  		if typ.Elem().Kind() == reflect.Ptr {
   466  			nullable = true
   467  		}
   468  		etypSchema := inferSchema(typ.Elem(), level+1)
   469  		name := typ.Name()
   470  		if name == "" {
   471  			name = "List_" + etypSchema.Name()
   472  		}
   473  		typSchema := schema.SpawnList(name, etypSchema.Name(), nullable)
   474  		defaultTypeSystem.Accumulate(typSchema)
   475  		return typSchema
   476  	case reflect.Interface:
   477  		// these types must match exactly since we need symmetry of being able to
   478  		// get the values an also assign values to them
   479  		if typ == goTypeLink {
   480  			return schemaTypeLink
   481  		}
   482  		if typ == goTypeNode {
   483  			return schemaTypeAny
   484  		}
   485  		panic("bindnode: unable to infer from interface")
   486  	}
   487  	panic(fmt.Sprintf("bindnode: unable to infer from type %s", typ.Kind().String()))
   488  }
   489  
   490  // There are currently 27 reflect.Kind iota values,
   491  // so 32 should be plenty to ensure we don't panic in practice.
   492  
   493  var kindInt = [32]bool{
   494  	reflect.Int:   true,
   495  	reflect.Int8:  true,
   496  	reflect.Int16: true,
   497  	reflect.Int32: true,
   498  	reflect.Int64: true,
   499  }
   500  
   501  var kindUint = [32]bool{
   502  	reflect.Uint:   true,
   503  	reflect.Uint8:  true,
   504  	reflect.Uint16: true,
   505  	reflect.Uint32: true,
   506  	reflect.Uint64: true,
   507  }