melt/Luau/format.go

367 lines
8.8 KiB
Go

package main
import (
luau "Luau/binding"
"context"
"fmt"
"os"
"strings"
c "github.com/TwiN/go-color"
sitter "github.com/smacker/go-tree-sitter"
)
// formatCode pretty-prints a parsed Luau chunk.
//
// sourceCode is the raw text the tree was parsed from and main is the root
// node of that parse. When the tree contains syntax errors the function
// prints "Error parsing code" and returns sourceCode unchanged.
//
// Only a subset of the grammar is handled: statements of type local_var_stmt,
// assign_stmt, call_stmt and comment, and expressions of type number, var,
// string, string_interp and table. Any other node type panics with a message
// naming the unexpected type.
//
// Output conventions:
//   - one statement per line, indented with tabs;
//   - quoted string literals use whichever quote character needs fewer
//     escapes (see normalizeLuauStringQuotes);
//   - non-empty tables are spread over several lines, one field per line,
//     always ending with a trailing comma;
//   - a call whose only argument is a string literal is written without
//     parentheses, e.g. print "hi".
func formatCode(sourceCode []byte, main *sitter.Node) string {
	if main.HasError() {
		fmt.Println("Error parsing code")
		return string(sourceCode)
	}
	getContent := func(node sitter.Node) string {
		return node.Content(sourceCode)
	}
	// indent is the current nesting depth in tabs. It is shared by every
	// closure below so nested tables indent relative to their parent.
	indent := 0

	var formatExpr func(node sitter.Node) string
	formatExpr = func(node sitter.Node) string {
		var out strings.Builder
		writeIndent := func() {
			out.WriteString(strings.Repeat("\t", indent))
		}
		ntype := node.Type()
		switch ntype {
		case "number", "var":
			out.WriteString(getContent(node))
		case "string":
			out.WriteString(normalizeLuauStringQuotes(getContent(node)))
		case "string_interp":
			for i := range int(node.ChildCount()) {
				child := node.Child(i)
				switch child.Type() {
				case "interp_start", "interp_end":
					out.WriteString("`")
				case "interp_content":
					out.WriteString(getContent(*child))
				case "interp_exp":
					for j := range int(child.ChildCount()) {
						part := child.Child(j)
						switch part.Type() {
						case "interp_brace_open":
							out.WriteString("{")
						case "interp_brace_close":
							out.WriteString("}")
						default:
							out.WriteString(formatExpr(*part))
						}
					}
				default:
					panic(c.InRed("Unknown string interpolation child type ") + c.InYellow(child.Type()))
				}
			}
		case "table":
			childCount := int(node.ChildCount())
			// "{" and "}" alone make 2 children; anything more means the
			// table has content and is laid out over several lines.
			multiline := childCount > 2
			for i := range childCount {
				child := node.Child(i)
				switch child.Type() {
				case "{":
					out.WriteString("{")
					if multiline {
						out.WriteString("\n")
						indent++
					}
				case "}":
					// Always end the last field with a trailing comma, unless
					// the source already had a separator there (which the
					// fieldlist loop below has emitted as ",\n").
					prev := node.Child(i - 1)
					if prev.Type() == "fieldlist" && prev.ChildCount() > 0 {
						last := prev.Child(int(prev.ChildCount()) - 1).Type()
						if last != "," && last != ";" {
							out.WriteString(",\n")
						}
					}
					if multiline {
						indent--
						writeIndent()
					}
					out.WriteString("}")
				case "fieldlist":
					for j := range int(child.ChildCount()) {
						item := child.Child(j)
						switch item.Type() {
						case ",", ";":
							// Lua accepts either separator; normalize to ",".
							out.WriteString(",\n")
						case "field":
							// Indent every field, including positional ones
							// such as {1, 2} that have no name or [key].
							writeIndent()
							for k := range int(item.ChildCount()) {
								part := item.Child(k)
								switch part.Type() {
								case "name":
									out.WriteString(getContent(*part))
								case "=":
									out.WriteString(" = ")
								case "[":
									out.WriteString("[")
								case "]":
									out.WriteString("]")
								default:
									out.WriteString(formatExpr(*part))
								}
							}
						}
					}
				}
			}
		default:
			panic(c.InRed("Unknown expression type ") + c.InYellow(ntype))
		}
		return out.String()
	}

	formatStmt := func(node sitter.Node) string {
		var out strings.Builder
		writeIndent := func() {
			out.WriteString(strings.Repeat("\t", indent))
		}
		ntype := node.Type()
		switch ntype {
		case "local_var_stmt":
			writeIndent()
			for i := range int(node.ChildCount()) {
				child := node.Child(i)
				switch child.Type() {
				case "local":
					out.WriteString("local ")
				case "=":
					out.WriteString(" = ")
				case ",":
					out.WriteString(", ")
				case "binding":
					out.WriteString(getContent(*child))
				default:
					out.WriteString(formatExpr(*child))
				}
			}
			out.WriteString("\n")
		case "assign_stmt":
			writeIndent()
			for i := range int(node.ChildCount()) {
				child := node.Child(i)
				switch child.Type() {
				case "=":
					out.WriteString(" = ")
				case "varlist":
					for j := range int(child.ChildCount()) {
						target := child.Child(j)
						switch target.Type() {
						case "var":
							out.WriteString(getContent(*target))
						case ",":
							out.WriteString(", ")
						default:
							panic(c.InRed("unknown varlist child type ") + c.InYellow(target.Type()))
						}
					}
				case "explist":
					for j := range int(child.ChildCount()) {
						value := child.Child(j)
						if value.Type() == "," {
							out.WriteString(", ")
						} else {
							out.WriteString(formatExpr(*value))
						}
					}
				default:
					out.WriteString(formatExpr(*child))
				}
			}
			out.WriteString("\n")
		case "comment":
			writeIndent()
			out.WriteString(getContent(node) + "\n")
		case "call_stmt":
			writeIndent()
			for i := range int(node.ChildCount()) {
				child := node.Child(i)
				switch child.Type() {
				case "var":
					out.WriteString(getContent(*child))
				case "arglist":
					argCount := int(child.ChildCount())
					// "(" string ")" is exactly 3 children: such a call is
					// written in Lua's paren-less form, e.g. print "hi".
					for j := range argCount {
						arg := child.Child(j)
						switch arg.Type() {
						case "(":
							if child.Child(j+1).Type() == "string" && argCount <= 3 {
								out.WriteString(" ")
							} else {
								out.WriteString("(")
							}
						case ")":
							if child.Child(j-1).Type() != "string" || argCount > 3 {
								out.WriteString(")")
							}
						case ",":
							out.WriteString(", ")
						default:
							if j == 0 {
								// Already paren-less in the source (f "x" or
								// f {...}): keep a separating space.
								out.WriteString(" ")
							}
							out.WriteString(formatExpr(*arg))
						}
					}
				default:
					panic(c.InRed("Unknown call statement child type ") + c.InYellow(child.Type()))
				}
			}
			out.WriteString("\n")
		default:
			panic(c.InRed("Unknown statement type ") + c.InYellow(ntype))
		}
		return out.String()
	}

	var out strings.Builder
	for i := range int(main.ChildCount()) {
		out.WriteString(formatStmt(*main.Child(i)))
	}
	return out.String()
}

// normalizeLuauStringQuotes re-quotes a Luau string literal so that it uses
// whichever delimiter needs fewer escaped quotes.
//
// literal is the full source text of the literal, delimiters included. Only
// literals delimited by ' or " are rewritten. Long-bracket strings ([[...]],
// [=[...]=], ...) treat backslashes literally, so they — and anything else
// that does not look like a quoted literal — are returned unchanged.
//
// Every quote character in the body is counted, escaped or not. The
// delimiter whose quote character occurs less often wins, and ties keep the
// original delimiter. Occurrences of the chosen delimiter are then escaped
// and occurrences of the other quote are unescaped. All other escape
// sequences (\\, \n, \x41, \z, ...) are copied verbatim, so the decoded
// value of the string never changes.
func normalizeLuauStringQuotes(literal string) string {
	if len(literal) < 2 {
		return literal
	}
	open := literal[0]
	if (open != '"' && open != '\'') || literal[len(literal)-1] != open {
		return literal
	}
	body := literal[1 : len(literal)-1]

	// A backslash always consumes the following byte, so `\\'` is an escaped
	// backslash followed by a bare quote, not an escaped quote.
	singles, doubles := 0, 0
	for i := 0; i < len(body); i++ {
		ch := body[i]
		if ch == '\\' && i+1 < len(body) {
			i++
			ch = body[i]
		}
		switch ch {
		case '\'':
			singles++
		case '"':
			doubles++
		}
	}

	quote := open
	switch {
	case open == '"' && doubles > singles:
		quote = '\''
	case open == '\'' && singles > doubles:
		quote = '"'
	}

	var out strings.Builder
	out.Grow(len(literal) + singles + doubles)
	writeQuote := func(q byte) {
		if q == quote {
			out.WriteByte('\\')
		}
		out.WriteByte(q)
	}
	out.WriteByte(quote)
	for i := 0; i < len(body); i++ {
		ch := body[i]
		if ch == '\\' && i+1 < len(body) {
			i++
			next := body[i]
			if next == '\'' || next == '"' {
				writeQuote(next)
			} else {
				out.WriteByte('\\')
				out.WriteByte(next)
			}
			continue
		}
		if ch == '\'' || ch == '"' {
			writeQuote(ch)
			continue
		}
		out.WriteByte(ch)
	}
	out.WriteByte(quote)
	return out.String()
}
// formatFile formats the Luau source file at filename in place.
//
// Read, parse and write errors are printed and the function returns without
// touching the file. If the source contains syntax errors, formatCode
// reports "Error parsing code" and the file is also left as it is. Otherwise
// the formatted code, with any run of trailing newlines collapsed to exactly
// one, is written back — but only when it differs from the current contents.
func formatFile(filename string) {
	sourceCode, err := os.ReadFile(filename)
	if err != nil {
		fmt.Println(err)
		return
	}

	parser := sitter.NewParser()
	parser.SetLanguage(luau.GetLuau())
	tree, err := parser.ParseCtx(context.Background(), nil, sourceCode)
	if err != nil {
		// tree is unusable when ParseCtx fails; dereferencing it would panic.
		fmt.Println(err)
		return
	}

	root := tree.RootNode()
	formatted := formatCode(sourceCode, root)
	if root.HasError() {
		// formatCode has already reported the problem and returned the
		// original text; rewriting the file could only alter it.
		return
	}

	// Collapse trailing newlines into a single one. Only the end is trimmed,
	// so leading content is never altered.
	formatted = strings.TrimRight(formatted, "\n") + "\n"
	if formatted == string(sourceCode) {
		return
	}

	if err := os.WriteFile(filename, []byte(formatted), 0o644); err != nil {
		fmt.Println(err)
	}
}