367 lines
8.8 KiB
Go
367 lines
8.8 KiB
Go
package main
|
|
|
|
import (
|
|
luau "Luau/binding"
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
|
|
c "github.com/TwiN/go-color"
|
|
sitter "github.com/smacker/go-tree-sitter"
|
|
)
|
|
|
|
func formatCode(sourceCode []byte, main *sitter.Node) string {
|
|
if main.HasError() {
|
|
fmt.Println("Error parsing code")
|
|
return string(sourceCode)
|
|
}
|
|
|
|
getContent := func(node sitter.Node) string {
|
|
return node.Content(sourceCode)
|
|
}
|
|
|
|
indent := 0
|
|
|
|
var formatExpr func(node sitter.Node) string
|
|
formatExpr = func(node sitter.Node) string {
|
|
var formatted string
|
|
|
|
writeIndent := func() {
|
|
formatted += strings.Repeat("\t", indent)
|
|
}
|
|
|
|
ntype := node.Type()
|
|
|
|
switch ntype {
|
|
case "number":
|
|
formatted += getContent(node)
|
|
case "string":
|
|
content := getContent(node)
|
|
// strings can be surrounded by '', "", [[]], [=[]=], [==[]==], etc.
|
|
|
|
var prefix, suffix string
|
|
for _, c := range content {
|
|
if strings.Contains(`"[='`, string(c)) {
|
|
prefix += string(c)
|
|
if c == '"' || c == '\'' {
|
|
break
|
|
}
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
for i := len(content) - 1; i >= 0; i-- {
|
|
c := content[i]
|
|
if strings.Contains(`"]'=`, string(c)) {
|
|
suffix = string(c) + suffix
|
|
if c == '"' || c == '\'' {
|
|
break
|
|
}
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
main := content[len(prefix) : len(content)-len(suffix)]
|
|
|
|
hasEscapedStr := func(str string, escd string) bool {
|
|
// ensure the string has an escd with an odd number of \ before it
|
|
pos := strings.Index(str, escd)
|
|
|
|
// count the number of \ before the escd
|
|
count := 0
|
|
for i := pos - 1; i >= 0; i-- {
|
|
if string(str[i]) != `\` {
|
|
break
|
|
}
|
|
count++
|
|
}
|
|
return count%2 == 1
|
|
}
|
|
|
|
// Dued, I had no idea string literals could be so complicated
|
|
if !(hasEscapedStr(main, `'`) && hasEscapedStr(main, `"`)) {
|
|
if prefix == `"` {
|
|
if hasEscapedStr(main, `'`) {
|
|
main = strings.ReplaceAll(main, `\'`, `'`)
|
|
} else if hasEscapedStr(main, `"`) && !strings.Contains(main, `'`) {
|
|
main = strings.ReplaceAll(main, `\"`, `"`)
|
|
prefix, suffix = `'`, `'`
|
|
}
|
|
} else if prefix == `'` {
|
|
if hasEscapedStr(main, `"`) {
|
|
main = strings.ReplaceAll(main, `\"`, `"`)
|
|
} else if hasEscapedStr(main, `'`) && !strings.Contains(main, `"`) {
|
|
main = strings.ReplaceAll(main, `\'`, `'`)
|
|
prefix, suffix = `"`, `"`
|
|
}
|
|
}
|
|
} else if hasEscapedStr(main, `'`) && hasEscapedStr(main, `"`) {
|
|
// if both are escaped, unescape the one with the least escapes
|
|
// (default to unescaping single quotes if equal)
|
|
if strings.Count(main, `\"`) > strings.Count(main, `\'`) {
|
|
main = strings.ReplaceAll(main, `\"`, `"`)
|
|
prefix, suffix = `'`, `'`
|
|
} else {
|
|
main = strings.ReplaceAll(main, `\'`, `'`)
|
|
prefix, suffix = `"`, `"`
|
|
}
|
|
}
|
|
|
|
// "christ all mighty" - ezio4322, 1 August 2022
|
|
if prefix == `'` && hasEscapedStr(main, `'`) && !hasEscapedStr(main, `"`) &&
|
|
strings.Contains(main, `"`) && strings.Count(main, `\'`) > strings.Count(main, `"`) {
|
|
// There are escaped single quotes, unescaped double quotes, and more single quotes than double quotes
|
|
// Swap the quotes
|
|
main = strings.ReplaceAll(main, `\'`, `'`)
|
|
main = strings.ReplaceAll(main, `"`, `\"`)
|
|
prefix, suffix = `"`, `"`
|
|
} else if prefix == `"` && hasEscapedStr(main, `"`) && !hasEscapedStr(main, `'`) &&
|
|
strings.Contains(main, `'`) && strings.Count(main, `\"`) > strings.Count(main, `'`) {
|
|
// that but the other way around
|
|
main = strings.ReplaceAll(main, `\"`, `"`)
|
|
main = strings.ReplaceAll(main, `'`, `\'`)
|
|
prefix, suffix = `'`, `'`
|
|
}
|
|
|
|
formatted += prefix + main + suffix
|
|
case "var":
|
|
formatted += getContent(node)
|
|
|
|
case "string_interp":
|
|
for i := range int(node.ChildCount()) {
|
|
child := node.Child(i)
|
|
switch child.Type() {
|
|
case "interp_start", "interp_end":
|
|
formatted += "`"
|
|
case "interp_content":
|
|
formatted += getContent(*child)
|
|
case "interp_exp":
|
|
for j := range int(child.ChildCount()) {
|
|
switch child.Child(j).Type() {
|
|
case "interp_brace_open":
|
|
formatted += "{"
|
|
case "interp_brace_close":
|
|
formatted += "}"
|
|
default:
|
|
formatted += formatExpr(*child.Child(j))
|
|
}
|
|
}
|
|
|
|
default:
|
|
panic(c.InRed("Unknown string interpolation child type ") + c.InYellow(child.Type()))
|
|
}
|
|
}
|
|
|
|
case "table":
|
|
childCount := node.ChildCount()
|
|
for i := range int(childCount) {
|
|
child := node.Child(i)
|
|
switch child.Type() {
|
|
case "{":
|
|
formatted += "{"
|
|
if childCount > 2 {
|
|
formatted += "\n"
|
|
indent++
|
|
}
|
|
case "}":
|
|
// I love me some trailing commas
|
|
lastChild := node.Child(i - 1)
|
|
if lastChild.Type() == "fieldlist" &&
|
|
lastChild.Child(int(lastChild.ChildCount()-1)).Type() != "," {
|
|
formatted += ",\n"
|
|
}
|
|
|
|
if childCount > 2 {
|
|
indent--
|
|
writeIndent()
|
|
}
|
|
formatted += "}"
|
|
case "fieldlist":
|
|
for j := range int(child.ChildCount()) {
|
|
child := child.Child(j)
|
|
switch child.Type() {
|
|
case ",":
|
|
formatted += ",\n"
|
|
case "field":
|
|
for k := range int(child.ChildCount()) {
|
|
field := child.Child(k)
|
|
switch field.Type() {
|
|
case "name":
|
|
writeIndent()
|
|
formatted += getContent(*field)
|
|
case "=":
|
|
formatted += " = "
|
|
case "[":
|
|
writeIndent()
|
|
formatted += "["
|
|
case "]":
|
|
formatted += "]"
|
|
default:
|
|
formatted += formatExpr(*field)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
default:
|
|
panic(c.InRed("Unknown expression type ") + c.InYellow(ntype))
|
|
}
|
|
|
|
return formatted
|
|
}
|
|
|
|
formatStmt := func(node sitter.Node) string {
|
|
var formatted string
|
|
|
|
writeIndent := func() {
|
|
formatted += strings.Repeat("\t", indent)
|
|
}
|
|
|
|
ntype := node.Type()
|
|
|
|
// parent := node.Parent()
|
|
|
|
switch ntype {
|
|
case "local_var_stmt":
|
|
writeIndent()
|
|
|
|
for i := range int(node.ChildCount()) {
|
|
child := node.Child(i)
|
|
switch child.Type() {
|
|
case "local":
|
|
formatted += "local "
|
|
case "=":
|
|
formatted += " = "
|
|
case ",":
|
|
formatted += ", "
|
|
case "binding":
|
|
formatted += getContent(*child)
|
|
default:
|
|
formatted += formatExpr(*child)
|
|
}
|
|
}
|
|
|
|
formatted += "\n"
|
|
case "assign_stmt":
|
|
writeIndent()
|
|
|
|
for i := range int(node.ChildCount()) {
|
|
child := node.Child(i)
|
|
switch child.Type() {
|
|
case "=":
|
|
formatted += " = "
|
|
case "varlist":
|
|
for j := range int(child.ChildCount()) {
|
|
varchild := child.Child(j)
|
|
switch varchild.Type() {
|
|
case "var":
|
|
formatted += getContent(*varchild)
|
|
case ",":
|
|
formatted += ", "
|
|
default:
|
|
panic(c.InRed("unknown varlist child type ") + c.InYellow(varchild.Type()))
|
|
}
|
|
}
|
|
case "explist":
|
|
for j := range int(child.ChildCount()) {
|
|
child := child.Child(j)
|
|
switch child.Type() {
|
|
case ",":
|
|
formatted += ", "
|
|
default:
|
|
formatted += formatExpr(*child)
|
|
}
|
|
}
|
|
default:
|
|
formatted += formatExpr(*child)
|
|
}
|
|
}
|
|
|
|
formatted += "\n"
|
|
case "comment":
|
|
writeIndent()
|
|
|
|
formatted += getContent(node) + "\n"
|
|
case "call_stmt":
|
|
writeIndent()
|
|
|
|
for i := range int(node.ChildCount()) {
|
|
child := node.Child(i)
|
|
switch child.Type() {
|
|
case "var":
|
|
formatted += getContent(*child)
|
|
case "arglist":
|
|
for j := range int(child.ChildCount()) {
|
|
argchild := child.Child(j)
|
|
switch argchild.Type() {
|
|
case "(":
|
|
nextType := child.Child(j + 1).Type()
|
|
if nextType == "string" && child.ChildCount() <= 3 { // function call with single string argument
|
|
formatted += " "
|
|
} else {
|
|
formatted += "("
|
|
}
|
|
case ")":
|
|
prevType := child.Child(j - 1).Type()
|
|
if prevType != "string" || child.ChildCount() > 3 {
|
|
formatted += ")"
|
|
}
|
|
case ",":
|
|
formatted += ", "
|
|
default:
|
|
if j == 0 {
|
|
formatted += " "
|
|
}
|
|
formatted += formatExpr(*argchild)
|
|
}
|
|
}
|
|
default:
|
|
panic(c.InRed("Unknown call statement child type ") + c.InYellow(child.Type()))
|
|
}
|
|
}
|
|
|
|
formatted += "\n"
|
|
|
|
default:
|
|
panic(c.InRed("Unknown statement type ") + c.InYellow(ntype))
|
|
}
|
|
|
|
return formatted
|
|
}
|
|
|
|
var formatted string
|
|
|
|
for i := range int(main.ChildCount()) {
|
|
formatted += formatStmt(*main.Child(i))
|
|
}
|
|
|
|
return formatted
|
|
}
|
|
|
|
func formatFile(filename string) {
|
|
parser := sitter.NewParser()
|
|
parser.SetLanguage(luau.GetLuau())
|
|
|
|
sourceCode, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
return
|
|
}
|
|
|
|
tree, _ := parser.ParseCtx(context.Background(), nil, sourceCode)
|
|
formatted := formatCode(sourceCode, tree.RootNode())
|
|
|
|
// replace all ending newlines with a single newline
|
|
formatted = strings.Trim(formatted, "\n") + "\n"
|
|
|
|
// write back to file
|
|
err = os.WriteFile(filename, []byte(formatted), 0o644)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
return
|
|
}
|
|
}
|