package main

import (
	"fmt"
	htemplate "html/template"
	"io"
	"io/fs"
	"os"
	pathpkg "path"
	"strings"
	"text/template"
	"text/template/parse"
)

// Template is the common interface over the text/template and html/template
// templates held by Templates, so callers can treat both kinds uniformly.
type Template interface {
	// Execute applies the template to data and writes the output to w.
	Execute(w io.Writer, data interface{}) error
	// AddParseTree installs tree as the template's main body.
	AddParseTree(tree *parse.Tree) error
	// Tree returns the template's main parse tree.
	Tree() *parse.Tree
}

// textTemplate adapts a *text/template.Template to the Template interface.
type textTemplate struct {
	tmpl *template.Template
}

// AddParseTree installs tree under the wrapped template's own name, which
// replaces the template's main body; the templates it defines stay associated.
func (t textTemplate) AddParseTree(tree *parse.Tree) error {
	name := t.tmpl.Name()
	if _, err := t.tmpl.AddParseTree(name, tree); err != nil {
		return err
	}
	return nil
}

// Execute applies the wrapped template to dot and writes the result to out.
func (t textTemplate) Execute(out io.Writer, dot interface{}) error {
	return t.tmpl.Execute(out, dot)
}

// Tree returns the wrapped template's current main parse tree.
func (t textTemplate) Tree() *parse.Tree {
	root := t.tmpl.Tree
	return root
}

// htmlTemplate adapts an *html/template.Template to the Template interface.
type htmlTemplate struct {
	tmpl *htemplate.Template
}

// current returns the template registered under the wrapped template's name
// in its namespace, falling back to the wrapped template itself.
//
// html/template's AddParseTree does not modify its receiver: it registers a
// new *Template under the given name and returns that instead. Escaping state
// is then recorded on the registered template, not on the receiver, so
// executing the stale receiver re-escapes on every call and, once an escaping
// error has cleared the underlying tree, can panic instead of returning the
// error. Going through Lookup always reaches the up-to-date template.
func (t htmlTemplate) current() *htemplate.Template {
	if registered := t.tmpl.Lookup(t.tmpl.Name()); registered != nil {
		return registered
	}
	return t.tmpl
}

// AddParseTree installs tree under the wrapped template's own name, making it
// the template's main body; the templates it defines stay associated.
func (t htmlTemplate) AddParseTree(tree *parse.Tree) error {
	_, err := t.tmpl.AddParseTree(t.tmpl.Name(), tree)
	return err
}

// Execute escapes the current template (html/template does this on first
// execution) and applies it to data, writing the output to w.
func (t htmlTemplate) Execute(w io.Writer, data interface{}) error {
	return t.current().Execute(w, data)
}

// Tree returns the current main parse tree, reflecting any AddParseTree call.
func (t htmlTemplate) Tree() *parse.Tree {
	return t.current().Tree
}

// Templates is the site's template registry: parsed templates keyed by path,
// plus the function map handed to templates as they are created. The zero
// value is ready to use; the map of templates is allocated on first load.
type Templates struct {
	// tmpls maps the slash-separated path a template was loaded from
	// (relative to the fs.FS it was read from) to the parsed template.
	tmpls map[string]Template
	// funcs is passed to Funcs on each template created after it is set.
	funcs map[string]interface{}
}

// Funcs replaces the function map used for templates created after this call,
// so it must be called before Load/LoadTemplate to take effect.
//
// The map is retained as-is, not copied. Each newly created template receives
// it through its own Funcs method, which copies the entries, so templates that
// were already loaded are unaffected by this call.
func (t *Templates) Funcs(funcs map[string]interface{}) {
	t.funcs = funcs
}

// LoadTemplate parses the single file at path in fsys and stores it under
// path, replacing any template previously stored there. The extension check
// is case-sensitive: ".html" and ".xml" files are parsed with html/template,
// and every other file is parsed with text/template.
//
// NOTE(review): .xml files go through html/template's HTML escaping; confirm
// that XML output (e.g. a leading <?xml ...?> declaration) renders as intended.
func (t *Templates) LoadTemplate(fsys fs.FS, path string) error {
	if t.tmpls == nil {
		t.tmpls = make(map[string]Template)
	}
	switch pathpkg.Ext(path) {
	case ".html", ".xml":
		return t.loadHTMLTemplate(fsys, path)
	default:
		return t.loadTextTemplate(fsys, path)
	}
}

// loadTextTemplate reads path from fsys, parses it with text/template using
// the configured function map, and stores the result under path. On a read or
// parse error nothing is stored. t.tmpls must already be non-nil.
func (t *Templates) loadTextTemplate(fsys fs.FS, path string) error {
	tmpl := template.New(path).Funcs(t.funcs)
	src, readErr := fs.ReadFile(fsys, path)
	if readErr != nil {
		return readErr
	}
	parsed, parseErr := tmpl.Parse(string(src))
	if parseErr != nil {
		return parseErr
	}
	t.tmpls[path] = textTemplate{tmpl: parsed}
	return nil
}

// loadHTMLTemplate reads path from fsys, parses it with html/template (whose
// actions are contextually escaped when first executed) using the configured
// function map, and stores the result under path. On a read or parse error
// nothing is stored. t.tmpls must already be non-nil.
func (t *Templates) loadHTMLTemplate(fsys fs.FS, path string) error {
	tmpl := htemplate.New(path).Funcs(t.funcs)
	src, readErr := fs.ReadFile(fsys, path)
	if readErr != nil {
		return readErr
	}
	parsed, parseErr := tmpl.Parse(string(src))
	if parseErr != nil {
		return parseErr
	}
	t.tmpls[path] = htmlTemplate{tmpl: parsed}
	return nil
}

// Load loads every regular file under dir as a template (via LoadTemplate),
// keyed by its slash-separated path relative to dir. A dir that does not exist
// is not an error; it simply contributes no templates.
//
// Afterwards, every loaded template whose extension appears in exts and that
// has a sibling "base"+ext template in the same directory gets a private copy
// of that base template's main parse tree as its own main body. Executing the
// page therefore runs the base layout, which can invoke templates the page
// defines.
func (t *Templates) Load(dir string, exts []string) error {
	fsys := os.DirFS(dir)
	err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			// Only a missing root directory is tolerated. Swallowing a
			// not-exist error from deeper in the walk (e.g. a file removed
			// mid-walk) would silently stop the walk and leave a partial
			// template set, so report it instead.
			if path == "." && os.IsNotExist(err) {
				return fs.SkipDir // WalkDir returns nil for SkipDir on the root.
			}
			return err
		}
		if !d.Type().IsRegular() {
			return nil // directories, symlinks and special files are skipped
		}
		return t.LoadTemplate(fsys, path)
	})
	if err != nil {
		return err
	}

	// Layer base templates onto the templates next to them.
	extsMap := make(map[string]struct{}, len(exts))
	for _, ext := range exts {
		extsMap[ext] = struct{}{}
	}
	for path, tmpl := range t.tmpls {
		ext := pathpkg.Ext(path)
		if _, ok := extsMap[ext]; !ok {
			continue
		}
		basePath := pathpkg.Join(pathpkg.Dir(path), "base"+ext)
		if basePath == path {
			continue // a base template is not layered onto itself
		}
		base, ok := t.tmpls[basePath]
		if !ok {
			continue
		}
		// Give each template its own copy of the base tree: html/template
		// rewrites parse trees in place when it escapes them, so one tree
		// shared by separately escaped templates lets escaping one template
		// corrupt (or race with) the others.
		if err := tmpl.AddParseTree(base.Tree().Copy()); err != nil {
			return err
		}
	}
	return nil
}

// FindTemplate looks up the template named tmpl for the directory path. It
// tries path/tmpl first and falls back to _default/tmpl. The bool reports
// whether either template exists.
func (t *Templates) FindTemplate(path string, tmpl string) (Template, bool) {
	candidates := [...]string{
		pathpkg.Join(path, tmpl),
		pathpkg.Join("_default", tmpl),
	}
	for _, candidate := range candidates {
		if found, ok := t.tmpls[candidate]; ok {
			return found, true
		}
	}
	return nil, false
}

// FindPartial looks up the partial template called name, which is stored
// under the "_partials" directory. The bool reports whether it exists.
func (t *Templates) FindPartial(name string) (Template, bool) {
	partial, found := t.tmpls[pathpkg.Join("_partials", name)]
	if !found {
		return nil, false
	}
	return partial, true
}

// ExecutePartial renders the partial template called name (as found by
// FindPartial) with data and returns the output as a string. It returns an
// error if the partial does not exist or fails to execute.
//
// NOTE(review): the result is a plain string; if this is exposed as a template
// function inside an html/template, the output will be escaped again unless the
// caller marks it safe — confirm that is intended.
func (t *Templates) ExecutePartial(name string, data interface{}) (string, error) {
	tmpl, ok := t.FindPartial(name)
	if !ok {
		// Lowercase and without an "Error:" prefix: this message is usually
		// chained into a template execution error that already says so.
		return "", fmt.Errorf("partial %q not found", name)
	}
	var b strings.Builder
	if err := tmpl.Execute(&b, data); err != nil {
		return "", err
	}
	return b.String(), nil
}
