Improvements to Luau formatting

This commit is contained in:
Lewin Kelly 2024-04-27 23:42:40 +01:00
parent da41ac09d6
commit d0e64b35bd
3 changed files with 306 additions and 232 deletions

View File

@ -132,7 +132,7 @@ func compatifyCode(sourceCode []byte, tree *sitter.Tree) string {
return sourceString
}
func compatify(filename string) {
func compatifyFile(filename string) {
binding := luau.GetLuau()
sourceCode, err := os.ReadFile(filename)
if err != nil {

View File

@ -11,263 +11,337 @@ import (
sitter "github.com/smacker/go-tree-sitter"
)
func isInside(ntype string, node *sitter.Node) bool {
for node != nil {
if node.Type() == ntype {
return true
}
node = node.Parent()
}
return false
}
func formatCode(sourceCode []byte, tree *sitter.Tree) string {
node := tree.RootNode()
if node.HasError() {
func formatCode(sourceCode []byte, main *sitter.Node) string {
if main.HasError() {
fmt.Println("Error parsing code")
return string(sourceCode)
}
var formatted strings.Builder
indent := 0
// -1 means at the start of the file
// 0 means outside of a local var statement
// 1 means inside a local var statement
var insideLocal []bool
var insideComment []bool
writeIndent := func() {
formatted.WriteString(strings.Repeat("\t", indent))
getContent := func(node sitter.Node) string {
return node.Content(sourceCode)
}
writeFormatted := func(node sitter.Node) {
// Write the formatted content to the string builder
content := node.Content(sourceCode)
indent := 0
var formatExpr func(node sitter.Node) string
formatExpr = func(node sitter.Node) string {
var formatted string
writeIndent := func() {
formatted += strings.Repeat("\t", indent)
}
ntype := node.Type()
insideLocal = append(insideLocal, isInside("local_var_stmt", &node))
insideComment = append(insideComment, isInside("comment", &node))
switch ntype {
case "number":
formatted += getContent(node)
case "string":
content := getContent(node)
// strings can be surrounded by '', "", [[]], [=[]=], [==[]==], etc.
parent := node.Parent()
var prefix, suffix string
for _, c := range content {
if strings.Contains(`"[='`, string(c)) {
prefix += string(c)
if c == '"' || c == '\'' {
break
}
} else {
break
}
}
for i := len(content) - 1; i >= 0; i-- {
c := content[i]
if strings.Contains(`"]'=`, string(c)) {
suffix = string(c) + suffix
if c == '"' || c == '\'' {
break
}
} else {
break
}
}
main := content[len(prefix) : len(content)-len(suffix)]
hasEscapedStr := func(str string, escd string) bool {
// ensure the string has an escd with an odd number of \ before it
pos := strings.Index(str, escd)
// count the number of \ before the escd
count := 0
for i := pos - 1; i >= 0; i-- {
if string(str[i]) != `\` {
break
}
count++
}
return count%2 == 1
}
// Dued, I had no idea string literals could be so complicated
if !(hasEscapedStr(main, `'`) && hasEscapedStr(main, `"`)) {
if prefix == `"` {
if hasEscapedStr(main, `'`) {
main = strings.ReplaceAll(main, `\'`, `'`)
} else if hasEscapedStr(main, `"`) && !strings.Contains(main, `'`) {
main = strings.ReplaceAll(main, `\"`, `"`)
prefix, suffix = `'`, `'`
}
} else if prefix == `'` {
if hasEscapedStr(main, `"`) {
main = strings.ReplaceAll(main, `\"`, `"`)
} else if hasEscapedStr(main, `'`) && !strings.Contains(main, `"`) {
main = strings.ReplaceAll(main, `\'`, `'`)
prefix, suffix = `"`, `"`
}
}
} else if hasEscapedStr(main, `'`) && hasEscapedStr(main, `"`) {
// if both are escaped, unescape the one with the least escapes
// (default to unescaping single quotes if equal)
if strings.Count(main, `\"`) > strings.Count(main, `\'`) {
main = strings.ReplaceAll(main, `\"`, `"`)
prefix, suffix = `'`, `'`
} else {
main = strings.ReplaceAll(main, `\'`, `'`)
prefix, suffix = `"`, `"`
}
}
// "christ all mighty" - ezio4322, 1 August 2022
if prefix == `'` && hasEscapedStr(main, `'`) && !hasEscapedStr(main, `"`) &&
strings.Contains(main, `"`) && strings.Count(main, `\'`) > strings.Count(main, `"`) {
// There are escaped single quotes, unescaped double quotes, and more single quotes than double quotes
// Swap the quotes
main = strings.ReplaceAll(main, `\'`, `'`)
main = strings.ReplaceAll(main, `"`, `\"`)
prefix, suffix = `"`, `"`
} else if prefix == `"` && hasEscapedStr(main, `"`) && !hasEscapedStr(main, `'`) &&
strings.Contains(main, `'`) && strings.Count(main, `\"`) > strings.Count(main, `'`) {
// that but the other way around
main = strings.ReplaceAll(main, `\"`, `"`)
main = strings.ReplaceAll(main, `'`, `\'`)
prefix, suffix = `'`, `'`
}
formatted += prefix + main + suffix
case "var":
formatted += getContent(node)
case "string_interp":
for i := range int(node.ChildCount()) {
child := node.Child(i)
switch child.Type() {
case "interp_start", "interp_end":
formatted += "`"
case "interp_content":
formatted += getContent(*child)
case "interp_exp":
for j := range int(child.ChildCount()) {
switch child.Child(j).Type() {
case "interp_brace_open":
formatted += "{"
case "interp_brace_close":
formatted += "}"
default:
formatted += formatExpr(*child.Child(j))
}
}
default:
panic(c.InRed("Unknown string interpolation child type ") + c.InYellow(child.Type()))
}
}
case "table":
childCount := node.ChildCount()
for i := range int(childCount) {
child := node.Child(i)
switch child.Type() {
case "{":
formatted += "{"
if childCount > 2 {
formatted += "\n"
indent++
}
case "}":
// I love me some trailing commas
lastChild := node.Child(i - 1)
if lastChild.Type() == "fieldlist" &&
lastChild.Child(int(lastChild.ChildCount()-1)).Type() != "," {
formatted += ",\n"
}
if childCount > 2 {
indent--
writeIndent()
}
formatted += "}"
case "fieldlist":
for j := range int(child.ChildCount()) {
child := child.Child(j)
switch child.Type() {
case ",":
formatted += ",\n"
case "field":
for k := range int(child.ChildCount()) {
field := child.Child(k)
switch field.Type() {
case "name":
writeIndent()
formatted += getContent(*field)
case "=":
formatted += " = "
case "[":
writeIndent()
formatted += "["
case "]":
formatted += "]"
default:
formatted += formatExpr(*field)
}
}
}
}
}
}
default:
panic(c.InRed("Unknown expression type ") + c.InYellow(ntype))
}
return formatted
}
formatStmt := func(node sitter.Node) string {
var formatted string
writeIndent := func() {
formatted += strings.Repeat("\t", indent)
}
ntype := node.Type()
// parent := node.Parent()
switch ntype {
case "comment":
if len(insideComment) > 1 && !insideComment[len(insideComment)-2] {
formatted.WriteString("\n")
}
formatted.WriteString("\n")
case "local_var_stmt":
writeIndent()
formatted.WriteString(content)
case "local":
if len(insideLocal) > 1 && !insideLocal[len(insideLocal)-2] &&
len(insideComment) > 1 && !insideComment[len(insideComment)-2] {
formatted.WriteString("\n")
}
formatted.WriteString("\n")
writeIndent()
formatted.WriteString("local ")
case "name":
fmt.Println(parent.Parent().Type(), node.Content(sourceCode))
if ((parent.Parent().Type() == "call_stmt" && *parent.Child(0) == node) ||
parent.Parent().Type() == "var_stmt" ||
parent.Parent().Parent().Type() == "assign_stmt") &&
!isInside("local_var_stmt", &node) &&
!isInside("ifexp", &node) &&
!isInside("binexp", &node) {
if len(insideLocal) > 1 && insideLocal[len(insideLocal)-2] {
formatted.WriteString("\n")
for i := range int(node.ChildCount()) {
child := node.Child(i)
switch child.Type() {
case "local":
formatted += "local "
case "=":
formatted += " = "
case ",":
formatted += ", "
case "binding":
formatted += getContent(*child)
default:
formatted += formatExpr(*child)
}
formatted.WriteString("\n")
writeIndent()
}
formatted.WriteString(content)
case "=":
formatted.WriteString(" = ")
case "==":
formatted.WriteString(" == ")
case "number":
formatted.WriteString(content)
case "return":
formatted.WriteString("\n")
formatted += "\n"
case "assign_stmt":
writeIndent()
formatted.WriteString("return ")
case "if":
ifType := parent.Type()
switch ifType {
case "if_stmt":
formatted.WriteString("\n")
writeIndent()
fallthrough
case "ifexp":
formatted.WriteString("if ")
default:
// damn better be unreachable
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "then":
ifType := parent.Type()
switch ifType {
case "if_stmt":
formatted.WriteString(" then")
indent++
case "ifexp":
formatted.WriteString(" then ")
case "elseif_clause":
formatted.WriteString(" then")
default:
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "elseif":
pType := parent.Type()
// if it's in a statement, the parent will be the elseif clause
// if it's in an expression, the parent will be the if expression
var ifType string
switch pType {
case "elseif_clause":
ifType = parent.Parent().Type()
case "ifexp":
ifType = pType
default:
panic(c.InRed("Unknown parent type ") + c.InYellow(pType))
for i := range int(node.ChildCount()) {
child := node.Child(i)
switch child.Type() {
case "=":
formatted += " = "
case "varlist":
for j := range int(child.ChildCount()) {
varchild := child.Child(j)
switch varchild.Type() {
case "var":
formatted += getContent(*varchild)
case ",":
formatted += ", "
default:
panic(c.InRed("unknown varlist child type ") + c.InYellow(varchild.Type()))
}
}
case "explist":
for j := range int(child.ChildCount()) {
child := child.Child(j)
switch child.Type() {
case ",":
formatted += ", "
default:
formatted += formatExpr(*child)
}
}
default:
formatted += formatExpr(*child)
}
}
switch ifType {
case "if_stmt":
formatted.WriteString("\n")
indent--
writeIndent()
formatted.WriteString("elseif ")
indent++
case "ifexp":
formatted.WriteString(" elseif ")
default:
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "else":
pType := parent.Type()
// if it's in a statement, the parent will be the else clause
// if it's in an expression, the parent will be the if expression
var ifType string
switch pType {
case "else_clause":
ifType = parent.Parent().Type()
case "ifexp":
ifType = pType
default:
panic(c.InRed("Unknown parent type ") + c.InYellow(pType))
}
switch ifType {
case "if_stmt":
formatted.WriteString("\n")
indent--
writeIndent()
formatted.WriteString("else")
indent++
case "ifexp":
formatted.WriteString(" else ")
default:
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "end":
formatted.WriteString("\n")
indent--
formatted += "\n"
case "comment":
writeIndent()
formatted.WriteString("end")
case "string":
if parent.Type() == "arglist" && parent.ChildCount() == 1 {
// `print"whatever"` -> `print "whatever"`
formatted.WriteString(" ")
formatted += getContent(node) + "\n"
case "call_stmt":
writeIndent()
for i := range int(node.ChildCount()) {
child := node.Child(i)
switch child.Type() {
case "var":
formatted += getContent(*child)
case "arglist":
for j := range int(child.ChildCount()) {
argchild := child.Child(j)
switch argchild.Type() {
case "(":
nextType := child.Child(j + 1).Type()
if nextType == "string" && child.ChildCount() <= 3 { // function call with single string argument
formatted += " "
} else {
formatted += "("
}
case ")":
prevType := child.Child(j - 1).Type()
if prevType != "string" || child.ChildCount() > 3 {
formatted += ")"
}
case ",":
formatted += ", "
default:
if j == 0 {
formatted += " "
}
formatted += formatExpr(*argchild)
}
}
default:
panic(c.InRed("Unknown call statement child type ") + c.InYellow(child.Type()))
}
}
formatted.WriteString(content)
case "interp_start", "interp_end":
formatted.WriteString("`")
case "interp_content":
formatted.WriteString(content)
case "interp_brace_open":
formatted.WriteString("{")
case "interp_brace_close":
formatted.WriteString("}")
case ":":
formatted.WriteString(":")
case ".":
formatted.WriteString(".")
case "(":
argType := parent.Child(1).Type()
// `print("whatever")` -> `print "whatever"`
if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" {
formatted.WriteString("(")
} else {
formatted.WriteString(" ")
}
case ")":
argType := parent.Child(1).Type()
// `print("whatever")` -> `print "whatever"`
if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" {
formatted.WriteString(")")
}
case ",":
formatted.WriteString(", ")
case "true":
formatted.WriteString("true")
case "false":
formatted.WriteString("false")
case "+":
formatted.WriteString(" + ")
case "-":
formatted.WriteString(" - ")
case "*":
formatted.WriteString(" * ")
case "/":
formatted.WriteString(" / ")
case "%":
formatted.WriteString(" % ")
case "^":
formatted.WriteString(" ^ ")
case "//":
formatted.WriteString(" // ")
case "+=":
formatted.WriteString(" += ")
case "-=":
formatted.WriteString(" -= ")
case "*=":
formatted.WriteString(" *= ")
case "/=":
formatted.WriteString(" /= ")
case "%=":
formatted.WriteString(" %= ")
case "^=":
formatted.WriteString(" ^= ")
case "//=":
formatted.WriteString(" //= ")
case ";":
// nothing
formatted += "\n"
default:
panic(c.InRed("Unknown node type ") + c.InYellow(ntype))
panic(c.InRed("Unknown statement type ") + c.InYellow(ntype))
}
return formatted
}
var appendLeaf func(node sitter.Node)
appendLeaf = func(node sitter.Node) {
// Print only the leaf nodes
if node.ChildCount() == 0 {
writeFormatted(node)
} else {
for i := 0; i < int(node.ChildCount()); i++ {
appendLeaf(*(node.Child(i)))
}
}
var formatted string
for i := range int(main.ChildCount()) {
formatted += formatStmt(*main.Child(i))
}
appendLeaf(*node)
return formatted.String()
return formatted
}
func format(filename string) {
func formatFile(filename string) {
parser := sitter.NewParser()
parser.SetLanguage(luau.GetLuau())
@ -278,7 +352,7 @@ func format(filename string) {
}
tree, _ := parser.ParseCtx(context.Background(), nil, sourceCode)
formatted := formatCode(sourceCode, tree)
formatted := formatCode(sourceCode, tree.RootNode())
// replace all ending newlines with a single newline
formatted = strings.Trim(formatted, "\n") + "\n"

View File

@ -27,11 +27,11 @@ func main() {
if len(args) < 3 {
Error("No file specified.")
}
format(args[2])
formatFile(args[2])
case "c", "compatibility":
if len(args) < 3 {
Error("No file specified.")
}
compatify(args[2])
compatifyFile(args[2])
}
}