From d0e64b35bd25f0ab5892e26ff5e1e97f63d460c9 Mon Sep 17 00:00:00 2001 From: Lewin Kelly Date: Sat, 27 Apr 2024 23:42:40 +0100 Subject: [PATCH] Improvements to Luau formatting --- Luau/compatibility.go | 2 +- Luau/format.go | 532 ++++++++++++++++++++++++------------------ Luau/main.go | 4 +- 3 files changed, 306 insertions(+), 232 deletions(-) diff --git a/Luau/compatibility.go b/Luau/compatibility.go index c2f48f3..4b5c372 100644 --- a/Luau/compatibility.go +++ b/Luau/compatibility.go @@ -132,7 +132,7 @@ func compatifyCode(sourceCode []byte, tree *sitter.Tree) string { return sourceString } -func compatify(filename string) { +func compatifyFile(filename string) { binding := luau.GetLuau() sourceCode, err := os.ReadFile(filename) if err != nil { diff --git a/Luau/format.go b/Luau/format.go index 859a082..d1c3dab 100644 --- a/Luau/format.go +++ b/Luau/format.go @@ -11,263 +11,337 @@ import ( sitter "github.com/smacker/go-tree-sitter" ) -func isInside(ntype string, node *sitter.Node) bool { - for node != nil { - if node.Type() == ntype { - return true - } - node = node.Parent() - } - return false -} - -func formatCode(sourceCode []byte, tree *sitter.Tree) string { - node := tree.RootNode() - if node.HasError() { +func formatCode(sourceCode []byte, main *sitter.Node) string { + if main.HasError() { fmt.Println("Error parsing code") return string(sourceCode) } - var formatted strings.Builder - indent := 0 - // -1 means at the start of the file - // 0 means outside of a local var statement - // 1 means inside a local var statement - var insideLocal []bool - var insideComment []bool - - writeIndent := func() { - formatted.WriteString(strings.Repeat("\t", indent)) + getContent := func(node sitter.Node) string { + return node.Content(sourceCode) } - writeFormatted := func(node sitter.Node) { - // Write the formatted content to the string builder - content := node.Content(sourceCode) + indent := 0 + + var formatExpr func(node sitter.Node) string + formatExpr = func(node 
sitter.Node) string { + var formatted string + + writeIndent := func() { + formatted += strings.Repeat("\t", indent) + } + ntype := node.Type() - insideLocal = append(insideLocal, isInside("local_var_stmt", &node)) - insideComment = append(insideComment, isInside("comment", &node)) + switch ntype { + case "number": + formatted += getContent(node) + case "string": + content := getContent(node) + // strings can be surrounded by '', "", [[]], [=[]=], [==[]==], etc. - parent := node.Parent() + var prefix, suffix string + for _, c := range content { + if strings.Contains(`"[='`, string(c)) { + prefix += string(c) + if c == '"' || c == '\'' { + break + } + } else { + break + } + } + for i := len(content) - 1; i >= 0; i-- { + c := content[i] + if strings.Contains(`"]'=`, string(c)) { + suffix = string(c) + suffix + if c == '"' || c == '\'' { + break + } + } else { + break + } + } + main := content[len(prefix) : len(content)-len(suffix)] + + hasEscapedStr := func(str string, escd string) bool { + // ensure the string has an escd with an odd number of \ before it + pos := strings.Index(str, escd) + + // count the number of \ before the escd + count := 0 + for i := pos - 1; i >= 0; i-- { + if string(str[i]) != `\` { + break + } + count++ + } + return count%2 == 1 + } + + // Dude, I had no idea string literals could be so complicated + if !(hasEscapedStr(main, `'`) && hasEscapedStr(main, `"`)) { + if prefix == `"` { + if hasEscapedStr(main, `'`) { + main = strings.ReplaceAll(main, `\'`, `'`) + } else if hasEscapedStr(main, `"`) && !strings.Contains(main, `'`) { + main = strings.ReplaceAll(main, `\"`, `"`) + prefix, suffix = `'`, `'` + } + } else if prefix == `'` { + if hasEscapedStr(main, `"`) { + main = strings.ReplaceAll(main, `\"`, `"`) + } else if hasEscapedStr(main, `'`) && !strings.Contains(main, `"`) { + main = strings.ReplaceAll(main, `\'`, `'`) + prefix, suffix = `"`, `"` + } + } + } else if hasEscapedStr(main, `'`) && hasEscapedStr(main, `"`) { + // if both are 
escaped, unescape the one with the least escapes + // (default to unescaping single quotes if equal) + if strings.Count(main, `\"`) > strings.Count(main, `\'`) { + main = strings.ReplaceAll(main, `\"`, `"`) + prefix, suffix = `'`, `'` + } else { + main = strings.ReplaceAll(main, `\'`, `'`) + prefix, suffix = `"`, `"` + } + } + + // "christ all mighty" - ezio4322, 1 August 2022 + if prefix == `'` && hasEscapedStr(main, `'`) && !hasEscapedStr(main, `"`) && + strings.Contains(main, `"`) && strings.Count(main, `\'`) > strings.Count(main, `"`) { + // There are escaped single quotes, unescaped double quotes, and more single quotes than double quotes + // Swap the quotes + main = strings.ReplaceAll(main, `\'`, `'`) + main = strings.ReplaceAll(main, `"`, `\"`) + prefix, suffix = `"`, `"` + } else if prefix == `"` && hasEscapedStr(main, `"`) && !hasEscapedStr(main, `'`) && + strings.Contains(main, `'`) && strings.Count(main, `\"`) > strings.Count(main, `'`) { + // that but the other way around + main = strings.ReplaceAll(main, `\"`, `"`) + main = strings.ReplaceAll(main, `'`, `\'`) + prefix, suffix = `'`, `'` + } + + formatted += prefix + main + suffix + case "var": + formatted += getContent(node) + + case "string_interp": + for i := range int(node.ChildCount()) { + child := node.Child(i) + switch child.Type() { + case "interp_start", "interp_end": + formatted += "`" + case "interp_content": + formatted += getContent(*child) + case "interp_exp": + for j := range int(child.ChildCount()) { + switch child.Child(j).Type() { + case "interp_brace_open": + formatted += "{" + case "interp_brace_close": + formatted += "}" + default: + formatted += formatExpr(*child.Child(j)) + } + } + + default: + panic(c.InRed("Unknown string interpolation child type ") + c.InYellow(child.Type())) + } + } + + case "table": + childCount := node.ChildCount() + for i := range int(childCount) { + child := node.Child(i) + switch child.Type() { + case "{": + formatted += "{" + if childCount > 2 { + 
formatted += "\n" + indent++ + } + case "}": + // I love me some trailing commas + lastChild := node.Child(i - 1) + if lastChild.Type() == "fieldlist" && + lastChild.Child(int(lastChild.ChildCount()-1)).Type() != "," { + formatted += ",\n" + } + + if childCount > 2 { + indent-- + writeIndent() + } + formatted += "}" + case "fieldlist": + for j := range int(child.ChildCount()) { + child := child.Child(j) + switch child.Type() { + case ",": + formatted += ",\n" + case "field": + for k := range int(child.ChildCount()) { + field := child.Child(k) + switch field.Type() { + case "name": + writeIndent() + formatted += getContent(*field) + case "=": + formatted += " = " + case "[": + writeIndent() + formatted += "[" + case "]": + formatted += "]" + default: + formatted += formatExpr(*field) + } + } + } + } + } + } + + default: + panic(c.InRed("Unknown expression type ") + c.InYellow(ntype)) + } + + return formatted + } + + formatStmt := func(node sitter.Node) string { + var formatted string + + writeIndent := func() { + formatted += strings.Repeat("\t", indent) + } + + ntype := node.Type() + + // parent := node.Parent() switch ntype { - case "comment": - if len(insideComment) > 1 && !insideComment[len(insideComment)-2] { - formatted.WriteString("\n") - } - formatted.WriteString("\n") + case "local_var_stmt": writeIndent() - formatted.WriteString(content) - case "local": - if len(insideLocal) > 1 && !insideLocal[len(insideLocal)-2] && - len(insideComment) > 1 && !insideComment[len(insideComment)-2] { - formatted.WriteString("\n") - } - formatted.WriteString("\n") - writeIndent() - formatted.WriteString("local ") - - case "name": - fmt.Println(parent.Parent().Type(), node.Content(sourceCode)) - if ((parent.Parent().Type() == "call_stmt" && *parent.Child(0) == node) || - parent.Parent().Type() == "var_stmt" || - parent.Parent().Parent().Type() == "assign_stmt") && - !isInside("local_var_stmt", &node) && - !isInside("ifexp", &node) && - !isInside("binexp", &node) { - if 
len(insideLocal) > 1 && insideLocal[len(insideLocal)-2] { - formatted.WriteString("\n") + for i := range int(node.ChildCount()) { + child := node.Child(i) + switch child.Type() { + case "local": + formatted += "local " + case "=": + formatted += " = " + case ",": + formatted += ", " + case "binding": + formatted += getContent(*child) + default: + formatted += formatExpr(*child) } - formatted.WriteString("\n") - writeIndent() } - formatted.WriteString(content) - case "=": - formatted.WriteString(" = ") - case "==": - formatted.WriteString(" == ") - case "number": - formatted.WriteString(content) - case "return": - formatted.WriteString("\n") + formatted += "\n" + case "assign_stmt": writeIndent() - formatted.WriteString("return ") - case "if": - ifType := parent.Type() - switch ifType { - case "if_stmt": - formatted.WriteString("\n") - writeIndent() - fallthrough - case "ifexp": - formatted.WriteString("if ") - default: - // damn better be unreachable - panic(c.InRed("Unknown if type ") + c.InYellow(ifType)) - } - case "then": - ifType := parent.Type() - switch ifType { - case "if_stmt": - formatted.WriteString(" then") - indent++ - case "ifexp": - formatted.WriteString(" then ") - case "elseif_clause": - formatted.WriteString(" then") - default: - panic(c.InRed("Unknown if type ") + c.InYellow(ifType)) - } - case "elseif": - pType := parent.Type() - // if it's in a statement, the parent will be the elseif clause - // if it's in an expression, the parent will be the if expression - var ifType string - switch pType { - case "elseif_clause": - ifType = parent.Parent().Type() - case "ifexp": - ifType = pType - default: - panic(c.InRed("Unknown parent type ") + c.InYellow(pType)) + + for i := range int(node.ChildCount()) { + child := node.Child(i) + switch child.Type() { + case "=": + formatted += " = " + case "varlist": + for j := range int(child.ChildCount()) { + varchild := child.Child(j) + switch varchild.Type() { + case "var": + formatted += getContent(*varchild) + 
case ",": + formatted += ", " + default: + panic(c.InRed("unknown varlist child type ") + c.InYellow(varchild.Type())) + } + } + case "explist": + for j := range int(child.ChildCount()) { + child := child.Child(j) + switch child.Type() { + case ",": + formatted += ", " + default: + formatted += formatExpr(*child) + } + } + default: + formatted += formatExpr(*child) + } } - switch ifType { - case "if_stmt": - formatted.WriteString("\n") - indent-- - writeIndent() - formatted.WriteString("elseif ") - indent++ - case "ifexp": - formatted.WriteString(" elseif ") - default: - panic(c.InRed("Unknown if type ") + c.InYellow(ifType)) - } - case "else": - pType := parent.Type() - // if it's in a statement, the parent will be the else clause - // if it's in an expression, the parent will be the if expression - var ifType string - switch pType { - case "else_clause": - ifType = parent.Parent().Type() - case "ifexp": - ifType = pType - default: - panic(c.InRed("Unknown parent type ") + c.InYellow(pType)) - } - - switch ifType { - case "if_stmt": - formatted.WriteString("\n") - indent-- - writeIndent() - formatted.WriteString("else") - indent++ - case "ifexp": - formatted.WriteString(" else ") - default: - panic(c.InRed("Unknown if type ") + c.InYellow(ifType)) - } - case "end": - formatted.WriteString("\n") - indent-- + formatted += "\n" + case "comment": writeIndent() - formatted.WriteString("end") - case "string": - if parent.Type() == "arglist" && parent.ChildCount() == 1 { - // `print"whatever"` -> `print "whatever"` - formatted.WriteString(" ") + + formatted += getContent(node) + "\n" + case "call_stmt": + writeIndent() + + for i := range int(node.ChildCount()) { + child := node.Child(i) + switch child.Type() { + case "var": + formatted += getContent(*child) + case "arglist": + for j := range int(child.ChildCount()) { + argchild := child.Child(j) + switch argchild.Type() { + case "(": + nextType := child.Child(j + 1).Type() + if nextType == "string" && child.ChildCount() 
<= 3 { // function call with single string argument + formatted += " " + } else { + formatted += "(" + } + case ")": + prevType := child.Child(j - 1).Type() + if prevType != "string" || child.ChildCount() > 3 { + formatted += ")" + } + case ",": + formatted += ", " + default: + if j == 0 { + formatted += " " + } + formatted += formatExpr(*argchild) + } + } + default: + panic(c.InRed("Unknown call statement child type ") + c.InYellow(child.Type())) + } } - formatted.WriteString(content) - case "interp_start", "interp_end": - formatted.WriteString("`") - case "interp_content": - formatted.WriteString(content) - case "interp_brace_open": - formatted.WriteString("{") - case "interp_brace_close": - formatted.WriteString("}") - case ":": - formatted.WriteString(":") - case ".": - formatted.WriteString(".") - case "(": - argType := parent.Child(1).Type() - // `print("whatever")` -> `print "whatever"` - if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" { - formatted.WriteString("(") - } else { - formatted.WriteString(" ") - } - case ")": - argType := parent.Child(1).Type() - // `print("whatever")` -> `print "whatever"` - if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" { - formatted.WriteString(")") - } - case ",": - formatted.WriteString(", ") - case "true": - formatted.WriteString("true") - case "false": - formatted.WriteString("false") - case "+": - formatted.WriteString(" + ") - case "-": - formatted.WriteString(" - ") - case "*": - formatted.WriteString(" * ") - case "/": - formatted.WriteString(" / ") - case "%": - formatted.WriteString(" % ") - case "^": - formatted.WriteString(" ^ ") - case "//": - formatted.WriteString(" // ") - case "+=": - formatted.WriteString(" += ") - case "-=": - formatted.WriteString(" -= ") - case "*=": - formatted.WriteString(" *= ") - case "/=": - formatted.WriteString(" /= ") - case "%=": - formatted.WriteString(" %= ") - case "^=": 
- formatted.WriteString(" ^= ") - case "//=": - formatted.WriteString(" //= ") - case ";": - // nothing + + formatted += "\n" + default: - panic(c.InRed("Unknown node type ") + c.InYellow(ntype)) + panic(c.InRed("Unknown statement type ") + c.InYellow(ntype)) } + + return formatted } - var appendLeaf func(node sitter.Node) - appendLeaf = func(node sitter.Node) { - // Print only the leaf nodes - if node.ChildCount() == 0 { - writeFormatted(node) - } else { - for i := 0; i < int(node.ChildCount()); i++ { - appendLeaf(*(node.Child(i))) - } - } + var formatted string + + for i := range int(main.ChildCount()) { + formatted += formatStmt(*main.Child(i)) } - appendLeaf(*node) - - return formatted.String() + return formatted } -func format(filename string) { +func formatFile(filename string) { parser := sitter.NewParser() parser.SetLanguage(luau.GetLuau()) @@ -278,7 +352,7 @@ func format(filename string) { } tree, _ := parser.ParseCtx(context.Background(), nil, sourceCode) - formatted := formatCode(sourceCode, tree) + formatted := formatCode(sourceCode, tree.RootNode()) // replace all ending newlines with a single newline formatted = strings.Trim(formatted, "\n") + "\n" diff --git a/Luau/main.go b/Luau/main.go index c288f27..2ef6ffd 100644 --- a/Luau/main.go +++ b/Luau/main.go @@ -27,11 +27,11 @@ func main() { if len(args) < 3 { Error("No file specified.") } - format(args[2]) + formatFile(args[2]) case "c", "compatibility": if len(args) < 3 { Error("No file specified.") } - compatify(args[2]) + compatifyFile(args[2]) } }