github.com/shoshinnikita/budget-manager@v0.7.1-0.20220131195411-8c46ff1c6778/internal/web/pages/template_executor.go (about)

     1  package pages
     2  
import (
	"bytes"
	"context"
	"html/template"
	"io"
	"io/fs"
	"path"
	"path/filepath"
	"sync"
	"time"

	"github.com/ShoshinNikita/budget-manager/internal/logger"
	"github.com/ShoshinNikita/budget-manager/internal/pkg/errors"
	"github.com/ShoshinNikita/budget-manager/internal/pkg/reqid"
	"github.com/ShoshinNikita/budget-manager/templates"
)
    18  
    19  type templateExecutor struct {
    20  	cacheTemplates bool
    21  	fs             fs.ReadDirFS
    22  	log            logger.Logger
    23  	commonFuncs    template.FuncMap
    24  
    25  	mu  sync.Mutex
    26  	tpl *template.Template
    27  }
    28  
    29  func newTemplateExecutor(log logger.Logger, cacheTemplates bool, commonFuncs template.FuncMap) *templateExecutor {
    30  	return &templateExecutor{
    31  		fs:             templates.New(cacheTemplates),
    32  		log:            log,
    33  		cacheTemplates: cacheTemplates,
    34  		commonFuncs:    commonFuncs,
    35  	}
    36  }
    37  
    38  func (e *templateExecutor) Execute(ctx context.Context, w io.Writer, name string, data interface{}) error {
    39  	log := reqid.FromContextToLogger(ctx, e.log)
    40  
    41  	tpl, err := e.loadTemplates()
    42  	if err != nil {
    43  		return errors.Wrap(err, "couldn't load templates")
    44  	}
    45  
    46  	tpl = tpl.Lookup(name)
    47  	if tpl == nil {
    48  		return errors.Errorf("no template with name '%s'", name)
    49  	}
    50  
    51  	if err := executeTemplate(log, tpl, w, data); err != nil {
    52  		return errors.Wrap(err, "couldn't execute template")
    53  	}
    54  
    55  	return nil
    56  }
    57  
    58  // loadTemplates loads all templates from file or returns them from cache according to 'cacheTemplates'
    59  func (e *templateExecutor) loadTemplates() (_ *template.Template, err error) {
    60  	e.mu.Lock()
    61  	defer e.mu.Unlock()
    62  
    63  	if e.cacheTemplates && e.tpl != nil {
    64  		return e.tpl, nil
    65  	}
    66  
    67  	patterns, err := extractAllTemplatePaths(e.fs)
    68  	if err != nil {
    69  		return nil, errors.Wrap(err, "couldn't get template filenames")
    70  	}
    71  
    72  	e.tpl, err = template.New("base").Funcs(e.getCommonFuncs()).ParseFS(e.fs, patterns...)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	return e.tpl, nil
    78  }
    79  
    80  func (e *templateExecutor) getCommonFuncs() template.FuncMap {
    81  	res := make(template.FuncMap, len(e.commonFuncs))
    82  	for k, v := range e.commonFuncs {
    83  		res[k] = v
    84  	}
    85  	return res
    86  }
    87  
    88  // executeTemplate executes passed template. It checks for errors before writing into w: it executes
    89  // template into temporary buffer and copies data if everything is fine
    90  func executeTemplate(log logger.Logger, tpl *template.Template, w io.Writer, data interface{}) error {
    91  	buff := bytes.NewBuffer(nil)
    92  
    93  	now := time.Now()
    94  	if err := tpl.Execute(buff, data); err != nil {
    95  		return err
    96  	}
    97  	log.WithField("time", time.Since(now)).Debug("template was successfully executed")
    98  
    99  	_, err := io.Copy(w, buff)
   100  	return err
   101  }
   102  
   103  func extractAllTemplatePaths(fs fs.ReadDirFS) ([]string, error) {
   104  	const maxDepth = 25
   105  
   106  	var walk func(root string, depth int) ([]string, error)
   107  	walk = func(root string, depth int) (paths []string, err error) {
   108  		if depth >= maxDepth {
   109  			return nil, errors.Errorf("max dir depth is reached: %d", maxDepth)
   110  		}
   111  
   112  		entries, err := fs.ReadDir(root)
   113  		if err != nil {
   114  			return nil, errors.Wrap(err, "couldn't read dir")
   115  		}
   116  
   117  		for _, entry := range entries {
   118  			if !entry.IsDir() {
   119  				if isTemplate(entry.Name()) {
   120  					paths = append(paths, filepath.Join(root, entry.Name()))
   121  				}
   122  				continue
   123  			}
   124  
   125  			nestedPaths, err := walk(filepath.Join(root, entry.Name()), depth+1)
   126  			if err != nil {
   127  				return nil, err
   128  			}
   129  			paths = append(paths, nestedPaths...)
   130  		}
   131  		return paths, nil
   132  	}
   133  
   134  	return walk(".", 0)
   135  }
   136  
   137  func isTemplate(name string) bool {
   138  	ext := filepath.Ext(name)
   139  
   140  	return ext == ".html"
   141  }