github.com/kaptinlin/jsonschema@v0.4.6/compiler.go (about)

     1  package jsonschema
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"encoding/xml"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/goccy/go-yaml"
    15  )
    16  
    17  // FormatDef defines a custom format validation rule
    18  type FormatDef struct {
    19  	// Type specifies which JSON Schema type this format applies to (optional)
    20  	// Supported values: "string", "number", "integer", "boolean", "array", "object"
    21  	// Empty string means applies to all types
    22  	Type string
    23  
    24  	// Validate is the validation function
    25  	Validate func(interface{}) bool
    26  }
    27  
    28  // Compiler represents a JSON Schema compiler that manages schema compilation and caching.
    29  type Compiler struct {
    30  	mu             sync.RWMutex                                       // Protects concurrent access to schemas map
    31  	schemas        map[string]*Schema                                 // Cache of compiled schemas.
    32  	allSchemas     []*Schema                                          // All compiled schemas, including those without IDs
    33  	unresolvedRefs map[string][]*Schema                               // Track schemas that have unresolved references by URI
    34  	Decoders       map[string]func(string) ([]byte, error)            // Decoders for various encoding formats.
    35  	MediaTypes     map[string]func([]byte) (interface{}, error)       // Media type handlers for unmarshalling data.
    36  	Loaders        map[string]func(url string) (io.ReadCloser, error) // Functions to load schemas from URLs.
    37  	DefaultBaseURI string                                             // Base URI used to resolve relative references.
    38  	AssertFormat   bool                                               // Flag to enforce format validation.
    39  
    40  	// JSON encoder/decoder configuration
    41  	jsonEncoder func(v interface{}) ([]byte, error)
    42  	jsonDecoder func(data []byte, v interface{}) error
    43  
    44  	// Default function registry
    45  	defaultFuncs map[string]DefaultFunc // Registry for dynamic default value functions
    46  
    47  	// Custom format registry
    48  	customFormats   map[string]*FormatDef // Registry for custom format definitions
    49  	customFormatsRW sync.RWMutex          // Protects concurrent access to custom formats
    50  }
    51  
// DefaultFunc represents a function that can generate dynamic default values.
// It accepts a variadic list of arguments and returns the generated value or an error.
type DefaultFunc func(args ...any) (any, error)
    54  
    55  // NewCompiler creates a new Compiler instance and initializes it with default settings.
    56  func NewCompiler() *Compiler {
    57  	compiler := &Compiler{
    58  		schemas:        make(map[string]*Schema),
    59  		allSchemas:     make([]*Schema, 0),
    60  		unresolvedRefs: make(map[string][]*Schema),
    61  		Decoders:       make(map[string]func(string) ([]byte, error)),
    62  		MediaTypes:     make(map[string]func([]byte) (interface{}, error)),
    63  		Loaders:        make(map[string]func(url string) (io.ReadCloser, error)),
    64  		DefaultBaseURI: "",
    65  		AssertFormat:   false,
    66  		defaultFuncs:   make(map[string]DefaultFunc),
    67  		customFormats:  make(map[string]*FormatDef),
    68  
    69  		// Default to standard library JSON implementation
    70  		jsonEncoder: json.Marshal,
    71  		jsonDecoder: json.Unmarshal,
    72  	}
    73  	compiler.initDefaults()
    74  	return compiler
    75  }
    76  
    77  // WithEncoderJSON configures custom JSON encoder implementation
    78  func (c *Compiler) WithEncoderJSON(encoder func(v interface{}) ([]byte, error)) *Compiler {
    79  	c.jsonEncoder = encoder
    80  	return c
    81  }
    82  
    83  // WithDecoderJSON configures custom JSON decoder implementation
    84  func (c *Compiler) WithDecoderJSON(decoder func(data []byte, v interface{}) error) *Compiler {
    85  	c.jsonDecoder = decoder
    86  	return c
    87  }
    88  
    89  // Compile compiles a JSON schema and caches it. If an URI is provided, it uses that as the key; otherwise, it generates a hash.
    90  func (c *Compiler) Compile(jsonSchema []byte, uris ...string) (*Schema, error) {
    91  	schema, err := newSchema(jsonSchema)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	uri := schema.ID
    97  	if uri == "" && len(uris) > 0 {
    98  		uri = uris[0]
    99  	}
   100  
   101  	if uri != "" && isValidURI(uri) {
   102  		schema.uri = uri
   103  
   104  		c.mu.RLock()
   105  		existingSchema, exists := c.schemas[uri]
   106  		c.mu.RUnlock()
   107  
   108  		if exists {
   109  			return existingSchema, nil
   110  		}
   111  	}
   112  
   113  	schema.initializeSchema(c, nil)
   114  
   115  	// Track all schemas, whether they have an ID or not
   116  	c.mu.Lock()
   117  	c.allSchemas = append(c.allSchemas, schema)
   118  
   119  	if schema.uri != "" && isValidURI(schema.uri) {
   120  		c.schemas[schema.uri] = schema
   121  	}
   122  
   123  	// Track unresolved references from this schema
   124  	c.trackUnresolvedReferences(schema)
   125  
   126  	// If this schema has a URI, check if any previously compiled schemas were waiting for it
   127  	var schemasToResolve []*Schema
   128  	if schema.uri != "" {
   129  		if waitingSchemas, exists := c.unresolvedRefs[schema.uri]; exists {
   130  			schemasToResolve = make([]*Schema, len(waitingSchemas))
   131  			copy(schemasToResolve, waitingSchemas)
   132  			delete(c.unresolvedRefs, schema.uri) // Clear the waiting list
   133  		}
   134  	}
   135  	c.mu.Unlock()
   136  
   137  	// Only re-resolve schemas that were actually waiting for this URI
   138  	for _, waitingSchema := range schemasToResolve {
   139  		waitingSchema.ResolveUnresolvedReferences()
   140  		// Re-track any still unresolved references
   141  		c.mu.Lock()
   142  		c.trackUnresolvedReferences(waitingSchema)
   143  		c.mu.Unlock()
   144  	}
   145  
   146  	return schema, nil
   147  }
   148  
   149  // trackUnresolvedReferences tracks which schemas have unresolved references to which URIs
   150  // This method should be called with mutex locked
   151  func (c *Compiler) trackUnresolvedReferences(schema *Schema) {
   152  	unresolvedURIs := schema.GetUnresolvedReferenceURIs()
   153  	for _, uri := range unresolvedURIs {
   154  		if c.unresolvedRefs[uri] == nil {
   155  			c.unresolvedRefs[uri] = make([]*Schema, 0)
   156  		}
   157  		// Check if schema is already in the list to avoid duplicates
   158  		found := false
   159  		for _, existing := range c.unresolvedRefs[uri] {
   160  			if existing == schema {
   161  				found = true
   162  				break
   163  			}
   164  		}
   165  		if !found {
   166  			c.unresolvedRefs[uri] = append(c.unresolvedRefs[uri], schema)
   167  		}
   168  	}
   169  }
   170  
   171  // resolveSchemaURL attempts to fetch and compile a schema from a URL.
   172  func (c *Compiler) resolveSchemaURL(url string) (*Schema, error) {
   173  	id, anchor := splitRef(url)
   174  
   175  	c.mu.RLock()
   176  	schema, exists := c.schemas[id]
   177  	c.mu.RUnlock()
   178  
   179  	if exists {
   180  		return schema, nil // Return cached schema if available
   181  	}
   182  
   183  	loader, ok := c.Loaders[getURLScheme(url)]
   184  	if !ok {
   185  		return nil, ErrNoLoaderRegistered
   186  	}
   187  
   188  	body, err := loader(url)
   189  	if err != nil {
   190  		return nil, err
   191  	}
   192  	defer body.Close() //nolint:errcheck
   193  
   194  	data, err := io.ReadAll(body)
   195  	if err != nil {
   196  		return nil, ErrFailedToReadData
   197  	}
   198  
   199  	compiledSchema, err := c.Compile(data, id)
   200  
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  
   205  	if anchor != "" {
   206  		return compiledSchema.resolveAnchor(anchor)
   207  	}
   208  
   209  	return compiledSchema, nil
   210  }
   211  
   212  // SetSchema associates a specific schema with a URI.
   213  func (c *Compiler) SetSchema(uri string, schema *Schema) *Compiler {
   214  	c.mu.Lock()
   215  	c.schemas[uri] = schema
   216  	c.mu.Unlock()
   217  	return c
   218  }
   219  
   220  // GetSchema retrieves a schema by reference. If the schema is not found in the cache and the ref is a URL, it tries to resolve it.
   221  func (c *Compiler) GetSchema(ref string) (*Schema, error) {
   222  	baseURI, anchor := splitRef(ref)
   223  
   224  	c.mu.RLock()
   225  	schema, exists := c.schemas[baseURI]
   226  	c.mu.RUnlock()
   227  
   228  	if exists {
   229  		if baseURI == ref {
   230  			return schema, nil
   231  		}
   232  		return schema.resolveAnchor(anchor)
   233  	}
   234  
   235  	return c.resolveSchemaURL(ref)
   236  }
   237  
// SetDefaultBaseURI sets the default base URI for resolving relative references.
// It stores baseURI in the DefaultBaseURI field and returns the compiler so
// calls can be chained.
func (c *Compiler) SetDefaultBaseURI(baseURI string) *Compiler {
	c.DefaultBaseURI = baseURI
	return c
}
   243  
// SetAssertFormat enables or disables format assertion.
// It stores assert in the AssertFormat field and returns the compiler so calls
// can be chained.
func (c *Compiler) SetAssertFormat(assert bool) *Compiler {
	c.AssertFormat = assert
	return c
}
   249  
   250  // RegisterDecoder adds a new decoder function for a specific encoding.
   251  func (c *Compiler) RegisterDecoder(encodingName string, decoderFunc func(string) ([]byte, error)) *Compiler {
   252  	c.Decoders[encodingName] = decoderFunc
   253  	return c
   254  }
   255  
   256  // RegisterMediaType adds a new unmarshal function for a specific media type.
   257  func (c *Compiler) RegisterMediaType(mediaTypeName string, unmarshalFunc func([]byte) (interface{}, error)) *Compiler {
   258  	c.MediaTypes[mediaTypeName] = unmarshalFunc
   259  	return c
   260  }
   261  
   262  // RegisterLoader adds a new loader function for a specific URI scheme.
   263  func (c *Compiler) RegisterLoader(scheme string, loaderFunc func(url string) (io.ReadCloser, error)) *Compiler {
   264  	c.Loaders[scheme] = loaderFunc
   265  	return c
   266  }
   267  
   268  // RegisterDefaultFunc registers a function for dynamic default value generation
   269  func (c *Compiler) RegisterDefaultFunc(name string, fn DefaultFunc) *Compiler {
   270  	c.mu.Lock()
   271  	defer c.mu.Unlock()
   272  
   273  	if c.defaultFuncs == nil {
   274  		c.defaultFuncs = make(map[string]DefaultFunc)
   275  	}
   276  	c.defaultFuncs[name] = fn
   277  	return c
   278  }
   279  
   280  // getDefaultFunc retrieves a registered default function by name
   281  func (c *Compiler) getDefaultFunc(name string) (DefaultFunc, bool) {
   282  	c.mu.RLock()
   283  	defer c.mu.RUnlock()
   284  
   285  	fn, exists := c.defaultFuncs[name]
   286  	return fn, exists
   287  }
   288  
   289  // initDefaults initializes default values for decoders, media types, and loaders.
   290  func (c *Compiler) initDefaults() {
   291  	c.Decoders["base64"] = base64.StdEncoding.DecodeString
   292  	c.setupMediaTypes()
   293  	c.setupLoaders()
   294  }
   295  
   296  // setupMediaTypes configures default media type handlers.
   297  func (c *Compiler) setupMediaTypes() {
   298  	c.MediaTypes["application/json"] = func(data []byte) (interface{}, error) {
   299  		var temp interface{}
   300  		if err := c.jsonDecoder(data, &temp); err != nil {
   301  			return nil, ErrJSONUnmarshalError
   302  		}
   303  		return temp, nil
   304  	}
   305  
   306  	c.MediaTypes["application/xml"] = func(data []byte) (interface{}, error) {
   307  		var temp interface{}
   308  		if err := xml.Unmarshal(data, &temp); err != nil {
   309  			return nil, ErrXMLUnmarshalError
   310  		}
   311  		return temp, nil
   312  	}
   313  
   314  	c.MediaTypes["application/yaml"] = func(data []byte) (interface{}, error) {
   315  		var temp interface{}
   316  		if err := yaml.Unmarshal(data, &temp); err != nil {
   317  			return nil, ErrYAMLUnmarshalError
   318  		}
   319  		return temp, nil
   320  	}
   321  }
   322  
   323  // setupLoaders configures default loaders for fetching schemas via HTTP/HTTPS.
   324  func (c *Compiler) setupLoaders() {
   325  	client := &http.Client{
   326  		Timeout: 10 * time.Second, // Set a reasonable timeout for network requests.
   327  	}
   328  
   329  	defaultHTTPLoader := func(url string) (io.ReadCloser, error) {
   330  		req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
   331  		if err != nil {
   332  			return nil, err
   333  		}
   334  
   335  		resp, err := client.Do(req)
   336  		if err != nil {
   337  			return nil, ErrFailedToFetch
   338  		}
   339  
   340  		if resp.StatusCode != http.StatusOK {
   341  			err = resp.Body.Close()
   342  			if err != nil {
   343  				return nil, err
   344  			}
   345  			return nil, ErrInvalidHTTPStatusCode
   346  		}
   347  
   348  		return resp.Body, nil
   349  	}
   350  
   351  	c.RegisterLoader("http", defaultHTTPLoader)
   352  	c.RegisterLoader("https", defaultHTTPLoader)
   353  }
   354  
   355  // CompileBatch compiles multiple schemas efficiently by deferring reference resolution
   356  // until all schemas are compiled. This is the most efficient approach when you have
   357  // many schemas with interdependencies.
   358  func (c *Compiler) CompileBatch(schemas map[string][]byte) (map[string]*Schema, error) {
   359  	compiledSchemas := make(map[string]*Schema)
   360  
   361  	// First pass: compile all schemas without resolving references
   362  	for id, schemaBytes := range schemas {
   363  		schema, err := newSchema(schemaBytes)
   364  		if err != nil {
   365  			return nil, fmt.Errorf("failed to compile schema %s: %w", id, err)
   366  		}
   367  
   368  		if schema.ID == "" {
   369  			schema.ID = id
   370  		}
   371  		schema.uri = schema.ID
   372  
   373  		// Initialize schema structure but skip reference resolution
   374  		schema.compiler = c
   375  		// Initialize basic properties without resolving references
   376  		schema.initializeSchemaWithoutReferences(c, nil)
   377  
   378  		compiledSchemas[id] = schema
   379  
   380  		c.mu.Lock()
   381  		c.allSchemas = append(c.allSchemas, schema)
   382  		if schema.uri != "" && isValidURI(schema.uri) {
   383  			c.schemas[schema.uri] = schema
   384  		}
   385  		c.mu.Unlock()
   386  	}
   387  
   388  	// Second pass: resolve all references at once
   389  	for _, schema := range compiledSchemas {
   390  		schema.resolveReferences()
   391  	}
   392  
   393  	return compiledSchemas, nil
   394  }
   395  
   396  // RegisterFormat registers a custom format.
   397  // The optional typeName parameter specifies which JSON Schema type the format applies to
   398  // (e.g., "string", "number"). If omitted, the format applies to all types.
   399  func (c *Compiler) RegisterFormat(name string, validator func(interface{}) bool, typeName ...string) *Compiler {
   400  	c.customFormatsRW.Lock()
   401  	defer c.customFormatsRW.Unlock()
   402  
   403  	var t string
   404  	if len(typeName) > 0 {
   405  		t = typeName[0]
   406  	}
   407  
   408  	c.customFormats[name] = &FormatDef{
   409  		Type:     t,
   410  		Validate: validator,
   411  	}
   412  	return c
   413  }
   414  
   415  // UnregisterFormat removes a custom format.
   416  func (c *Compiler) UnregisterFormat(name string) *Compiler {
   417  	c.customFormatsRW.Lock()
   418  	defer c.customFormatsRW.Unlock()
   419  
   420  	delete(c.customFormats, name)
   421  	return c
   422  }