//go:build ignore

package main

import (
	"bytes"
	"fmt"
	"go/format"
	"log"
	"os"
	"strings"

	"github.com/brimdata/zed/vector"
)

// opToAlpha maps each Go comparison operator handled by this generator to
// the alphabetic mnemonic embedded in the generated function names; for
// example, "<=" produces names that begin with "compareLE".
var opToAlpha = map[string]string{
	"<":  "LT",
	"<=": "LE",
	"==": "EQ",
	"!=": "NE",
	">":  "GT",
	">=": "GE",
}

// main writes comparefuncs.go (package expr) into the current directory.
// For every comparison operator, element type, and pair of left/right vector
// forms, it emits one specialized compare function. It also emits a
// compareFuncs table that maps the corresponding vector.FuncCode value to
// each function.
//
// The generated source is run through gofmt. If formatting fails, the raw
// source is still written so the offending code can be inspected, and the
// program then exits with the formatting error.
func main() {
	const fileName = "comparefuncs.go"

	var src bytes.Buffer
	header := []string{
		"// Code generated by gencomparefuncs.go. DO NOT EDIT.",
		"",
		"package expr",
		"import (",
		`"github.com/brimdata/zed"`,
		`"github.com/brimdata/zed/vector"`,
		")",
	}
	for _, line := range header {
		fmt.Fprintln(&src, line)
	}

	ops := []string{"==", "!=", "<", "<=", ">", ">="}
	types := []string{"Int", "Uint", "Float", "String", "Bytes"}
	var table strings.Builder
	for _, op := range ops {
		for _, typ := range types {
			// Forms 0..3 enumerate every vector.Form the generator supports.
			for l := vector.Form(0); l < 4; l++ {
				for r := vector.Form(0); r < 4; r++ {
					fn := "compare" + opToAlpha[op] + typ + l.String() + r.String()
					fmt.Fprintln(&src, genFunc(fn, op, typ, l, r))
					code := vector.FuncCode(vector.CompareOpFromString(op), vector.KindFromString(typ), l, r)
					fmt.Fprintf(&table, "%d: %s,\n", code, fn)
				}
			}
		}
	}

	fmt.Fprintln(&src, "var compareFuncs = map[int]func(vector.Any, vector.Any) vector.Any{")
	fmt.Fprintln(&src, table.String())
	fmt.Fprintln(&src, "}")

	out, fmtErr := format.Source(src.Bytes())
	if fmtErr != nil {
		// Keep the unformatted text so the syntax error can be located in the file.
		out = src.Bytes()
	}
	if err := os.WriteFile(fileName, out, 0644); err != nil {
		log.Fatal(err)
	}
	if fmtErr != nil {
		log.Fatal(fileName, ":", fmtErr)
	}
}

// genFunc returns the Go source of a function called name. The generated
// function type-asserts its lhs and rhs arguments according to the forms
// lhs and rhs and the element type typ, then compares them with op.
//
// If both forms are FormConst, the generated function returns a constant
// bool vector of length lhs.Len(). Otherwise it creates an empty bool vector
// of length lhs.Len() and calls out.Set(k) for every k where the comparison
// holds. Bytes operands are wrapped in string(...) before comparison.
func genFunc(name, op, typ string, lhs, rhs vector.Form) string {
	var b strings.Builder
	fmt.Fprintf(&b, "func %s(lhs, rhs vector.Any) vector.Any {\n", name)
	b.WriteString(genVarInit("l", typ, lhs))
	b.WriteString(genVarInit("r", typ, rhs))

	left, right := genExpr("l", lhs), genExpr("r", rhs)
	if typ == "Bytes" {
		// Go slices cannot be compared with == or ordered with <, so
		// compare the byte slices as strings.
		left, right = "string("+left+")", "string("+right+")"
	}
	cond := left + " " + op + " " + right

	switch {
	case lhs == vector.FormConst && rhs == vector.FormConst:
		fmt.Fprintf(&b, "return vector.NewConst(zed.NewBool(%s), lhs.Len(), nil)\n", cond)
	default:
		b.WriteString("n := lhs.Len()\n")
		b.WriteString("out := vector.NewBoolEmpty(n, nil)\n")
		fmt.Fprintf(&b, "for k := uint32(0); k < n; k++ { if %s { out.Set(k) }}\n", cond)
		b.WriteString("return out\n")
	}
	b.WriteString("}\n")
	return b.String()
}

// genVarInit returns the Go statements that bind the operand named by which
// (read from the generated parameter which+"hs") for the given form:
//
//   - FormFlat: <which> is the parameter asserted to *vector.<typ>.
//   - FormDict, FormView: <which>d is the parameter asserted to the form's
//     vector type, <which> is its Any field asserted to *vector.<typ>, and
//     <which>x is its Index.
//   - FormConst: <which>const holds the result of As<typ>() on the
//     *vector.Const; the second result is discarded.
//
// It panics with form for any other form.
func genVarInit(which, typ string, form vector.Form) string {
	param := which + "hs"
	if form == vector.FormFlat {
		return fmt.Sprintf("%s := %s.(*vector.%s)\n", which, param, typ)
	}
	if form == vector.FormConst {
		return fmt.Sprintf("%sconst, _ := %s.(*vector.Const).As%s()\n", which, param, typ)
	}
	if form != vector.FormDict && form != vector.FormView {
		panic(form)
	}
	wrapper := which + "d"
	return fmt.Sprintf("%s := %s.(*vector.%s)\n", wrapper, param, form) +
		fmt.Sprintf("%s := %s.Any.(*vector.%s)\n", which, wrapper, typ) +
		fmt.Sprintf("%sx := %s.Index\n", which, wrapper)
}

// genExpr returns the Go expression that yields element k of the operand
// named by which, using the variables that genVarInit binds for the same form:
//
//   - FormFlat: which.Value(k).
//   - FormDict, FormView: which.Value(uint32(whichx[k])), where k is first
//     mapped through the index.
//   - FormConst: whichconst.
//
// It panics with form for any other form.
func genExpr(which string, form vector.Form) string {
	if form == vector.FormConst {
		return which + "const"
	}
	if form == vector.FormFlat {
		return which + ".Value(k)"
	}
	if form == vector.FormDict || form == vector.FormView {
		return which + ".Value(uint32(" + which + "x[k]))"
	}
	panic(form)
}
