Improvements to Luau formatting
This commit is contained in:
parent
da41ac09d6
commit
d0e64b35bd
|
|
@ -132,7 +132,7 @@ func compatifyCode(sourceCode []byte, tree *sitter.Tree) string {
|
|||
return sourceString
|
||||
}
|
||||
|
||||
func compatify(filename string) {
|
||||
func compatifyFile(filename string) {
|
||||
binding := luau.GetLuau()
|
||||
sourceCode, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
|
|
|
|||
532
Luau/format.go
532
Luau/format.go
|
|
@ -11,263 +11,337 @@ import (
|
|||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
func isInside(ntype string, node *sitter.Node) bool {
|
||||
for node != nil {
|
||||
if node.Type() == ntype {
|
||||
return true
|
||||
}
|
||||
node = node.Parent()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func formatCode(sourceCode []byte, tree *sitter.Tree) string {
|
||||
node := tree.RootNode()
|
||||
if node.HasError() {
|
||||
func formatCode(sourceCode []byte, main *sitter.Node) string {
|
||||
if main.HasError() {
|
||||
fmt.Println("Error parsing code")
|
||||
return string(sourceCode)
|
||||
}
|
||||
|
||||
var formatted strings.Builder
|
||||
indent := 0
|
||||
// -1 means at the start of the file
|
||||
// 0 means outside of a local var statement
|
||||
// 1 means inside a local var statement
|
||||
var insideLocal []bool
|
||||
var insideComment []bool
|
||||
|
||||
writeIndent := func() {
|
||||
formatted.WriteString(strings.Repeat("\t", indent))
|
||||
getContent := func(node sitter.Node) string {
|
||||
return node.Content(sourceCode)
|
||||
}
|
||||
|
||||
writeFormatted := func(node sitter.Node) {
|
||||
// Write the formatted content to the string builder
|
||||
content := node.Content(sourceCode)
|
||||
indent := 0
|
||||
|
||||
var formatExpr func(node sitter.Node) string
|
||||
formatExpr = func(node sitter.Node) string {
|
||||
var formatted string
|
||||
|
||||
writeIndent := func() {
|
||||
formatted += strings.Repeat("\t", indent)
|
||||
}
|
||||
|
||||
ntype := node.Type()
|
||||
|
||||
insideLocal = append(insideLocal, isInside("local_var_stmt", &node))
|
||||
insideComment = append(insideComment, isInside("comment", &node))
|
||||
switch ntype {
|
||||
case "number":
|
||||
formatted += getContent(node)
|
||||
case "string":
|
||||
content := getContent(node)
|
||||
// strings can be surrounded by '', "", [[]], [=[]=], [==[]==], etc.
|
||||
|
||||
parent := node.Parent()
|
||||
var prefix, suffix string
|
||||
for _, c := range content {
|
||||
if strings.Contains(`"[='`, string(c)) {
|
||||
prefix += string(c)
|
||||
if c == '"' || c == '\'' {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
for i := len(content) - 1; i >= 0; i-- {
|
||||
c := content[i]
|
||||
if strings.Contains(`"]'=`, string(c)) {
|
||||
suffix = string(c) + suffix
|
||||
if c == '"' || c == '\'' {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
main := content[len(prefix) : len(content)-len(suffix)]
|
||||
|
||||
hasEscapedStr := func(str string, escd string) bool {
|
||||
// ensure the string has an escd with an odd number of \ before it
|
||||
pos := strings.Index(str, escd)
|
||||
|
||||
// count the number of \ before the escd
|
||||
count := 0
|
||||
for i := pos - 1; i >= 0; i-- {
|
||||
if string(str[i]) != `\` {
|
||||
break
|
||||
}
|
||||
count++
|
||||
}
|
||||
return count%2 == 1
|
||||
}
|
||||
|
||||
// Dued, I had no idea string literals could be so complicated
|
||||
if !(hasEscapedStr(main, `'`) && hasEscapedStr(main, `"`)) {
|
||||
if prefix == `"` {
|
||||
if hasEscapedStr(main, `'`) {
|
||||
main = strings.ReplaceAll(main, `\'`, `'`)
|
||||
} else if hasEscapedStr(main, `"`) && !strings.Contains(main, `'`) {
|
||||
main = strings.ReplaceAll(main, `\"`, `"`)
|
||||
prefix, suffix = `'`, `'`
|
||||
}
|
||||
} else if prefix == `'` {
|
||||
if hasEscapedStr(main, `"`) {
|
||||
main = strings.ReplaceAll(main, `\"`, `"`)
|
||||
} else if hasEscapedStr(main, `'`) && !strings.Contains(main, `"`) {
|
||||
main = strings.ReplaceAll(main, `\'`, `'`)
|
||||
prefix, suffix = `"`, `"`
|
||||
}
|
||||
}
|
||||
} else if hasEscapedStr(main, `'`) && hasEscapedStr(main, `"`) {
|
||||
// if both are escaped, unescape the one with the least escapes
|
||||
// (default to unescaping single quotes if equal)
|
||||
if strings.Count(main, `\"`) > strings.Count(main, `\'`) {
|
||||
main = strings.ReplaceAll(main, `\"`, `"`)
|
||||
prefix, suffix = `'`, `'`
|
||||
} else {
|
||||
main = strings.ReplaceAll(main, `\'`, `'`)
|
||||
prefix, suffix = `"`, `"`
|
||||
}
|
||||
}
|
||||
|
||||
// "christ all mighty" - ezio4322, 1 August 2022
|
||||
if prefix == `'` && hasEscapedStr(main, `'`) && !hasEscapedStr(main, `"`) &&
|
||||
strings.Contains(main, `"`) && strings.Count(main, `\'`) > strings.Count(main, `"`) {
|
||||
// There are escaped single quotes, unescaped double quotes, and more single quotes than double quotes
|
||||
// Swap the quotes
|
||||
main = strings.ReplaceAll(main, `\'`, `'`)
|
||||
main = strings.ReplaceAll(main, `"`, `\"`)
|
||||
prefix, suffix = `"`, `"`
|
||||
} else if prefix == `"` && hasEscapedStr(main, `"`) && !hasEscapedStr(main, `'`) &&
|
||||
strings.Contains(main, `'`) && strings.Count(main, `\"`) > strings.Count(main, `'`) {
|
||||
// that but the other way around
|
||||
main = strings.ReplaceAll(main, `\"`, `"`)
|
||||
main = strings.ReplaceAll(main, `'`, `\'`)
|
||||
prefix, suffix = `'`, `'`
|
||||
}
|
||||
|
||||
formatted += prefix + main + suffix
|
||||
case "var":
|
||||
formatted += getContent(node)
|
||||
|
||||
case "string_interp":
|
||||
for i := range int(node.ChildCount()) {
|
||||
child := node.Child(i)
|
||||
switch child.Type() {
|
||||
case "interp_start", "interp_end":
|
||||
formatted += "`"
|
||||
case "interp_content":
|
||||
formatted += getContent(*child)
|
||||
case "interp_exp":
|
||||
for j := range int(child.ChildCount()) {
|
||||
switch child.Child(j).Type() {
|
||||
case "interp_brace_open":
|
||||
formatted += "{"
|
||||
case "interp_brace_close":
|
||||
formatted += "}"
|
||||
default:
|
||||
formatted += formatExpr(*child.Child(j))
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
panic(c.InRed("Unknown string interpolation child type ") + c.InYellow(child.Type()))
|
||||
}
|
||||
}
|
||||
|
||||
case "table":
|
||||
childCount := node.ChildCount()
|
||||
for i := range int(childCount) {
|
||||
child := node.Child(i)
|
||||
switch child.Type() {
|
||||
case "{":
|
||||
formatted += "{"
|
||||
if childCount > 2 {
|
||||
formatted += "\n"
|
||||
indent++
|
||||
}
|
||||
case "}":
|
||||
// I love me some trailing commas
|
||||
lastChild := node.Child(i - 1)
|
||||
if lastChild.Type() == "fieldlist" &&
|
||||
lastChild.Child(int(lastChild.ChildCount()-1)).Type() != "," {
|
||||
formatted += ",\n"
|
||||
}
|
||||
|
||||
if childCount > 2 {
|
||||
indent--
|
||||
writeIndent()
|
||||
}
|
||||
formatted += "}"
|
||||
case "fieldlist":
|
||||
for j := range int(child.ChildCount()) {
|
||||
child := child.Child(j)
|
||||
switch child.Type() {
|
||||
case ",":
|
||||
formatted += ",\n"
|
||||
case "field":
|
||||
for k := range int(child.ChildCount()) {
|
||||
field := child.Child(k)
|
||||
switch field.Type() {
|
||||
case "name":
|
||||
writeIndent()
|
||||
formatted += getContent(*field)
|
||||
case "=":
|
||||
formatted += " = "
|
||||
case "[":
|
||||
writeIndent()
|
||||
formatted += "["
|
||||
case "]":
|
||||
formatted += "]"
|
||||
default:
|
||||
formatted += formatExpr(*field)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
panic(c.InRed("Unknown expression type ") + c.InYellow(ntype))
|
||||
}
|
||||
|
||||
return formatted
|
||||
}
|
||||
|
||||
formatStmt := func(node sitter.Node) string {
|
||||
var formatted string
|
||||
|
||||
writeIndent := func() {
|
||||
formatted += strings.Repeat("\t", indent)
|
||||
}
|
||||
|
||||
ntype := node.Type()
|
||||
|
||||
// parent := node.Parent()
|
||||
|
||||
switch ntype {
|
||||
case "comment":
|
||||
if len(insideComment) > 1 && !insideComment[len(insideComment)-2] {
|
||||
formatted.WriteString("\n")
|
||||
}
|
||||
formatted.WriteString("\n")
|
||||
case "local_var_stmt":
|
||||
writeIndent()
|
||||
formatted.WriteString(content)
|
||||
|
||||
case "local":
|
||||
if len(insideLocal) > 1 && !insideLocal[len(insideLocal)-2] &&
|
||||
len(insideComment) > 1 && !insideComment[len(insideComment)-2] {
|
||||
formatted.WriteString("\n")
|
||||
}
|
||||
formatted.WriteString("\n")
|
||||
writeIndent()
|
||||
formatted.WriteString("local ")
|
||||
|
||||
case "name":
|
||||
fmt.Println(parent.Parent().Type(), node.Content(sourceCode))
|
||||
if ((parent.Parent().Type() == "call_stmt" && *parent.Child(0) == node) ||
|
||||
parent.Parent().Type() == "var_stmt" ||
|
||||
parent.Parent().Parent().Type() == "assign_stmt") &&
|
||||
!isInside("local_var_stmt", &node) &&
|
||||
!isInside("ifexp", &node) &&
|
||||
!isInside("binexp", &node) {
|
||||
if len(insideLocal) > 1 && insideLocal[len(insideLocal)-2] {
|
||||
formatted.WriteString("\n")
|
||||
for i := range int(node.ChildCount()) {
|
||||
child := node.Child(i)
|
||||
switch child.Type() {
|
||||
case "local":
|
||||
formatted += "local "
|
||||
case "=":
|
||||
formatted += " = "
|
||||
case ",":
|
||||
formatted += ", "
|
||||
case "binding":
|
||||
formatted += getContent(*child)
|
||||
default:
|
||||
formatted += formatExpr(*child)
|
||||
}
|
||||
formatted.WriteString("\n")
|
||||
writeIndent()
|
||||
}
|
||||
formatted.WriteString(content)
|
||||
|
||||
case "=":
|
||||
formatted.WriteString(" = ")
|
||||
case "==":
|
||||
formatted.WriteString(" == ")
|
||||
case "number":
|
||||
formatted.WriteString(content)
|
||||
case "return":
|
||||
formatted.WriteString("\n")
|
||||
formatted += "\n"
|
||||
case "assign_stmt":
|
||||
writeIndent()
|
||||
formatted.WriteString("return ")
|
||||
case "if":
|
||||
ifType := parent.Type()
|
||||
switch ifType {
|
||||
case "if_stmt":
|
||||
formatted.WriteString("\n")
|
||||
writeIndent()
|
||||
fallthrough
|
||||
case "ifexp":
|
||||
formatted.WriteString("if ")
|
||||
default:
|
||||
// damn better be unreachable
|
||||
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
|
||||
}
|
||||
case "then":
|
||||
ifType := parent.Type()
|
||||
switch ifType {
|
||||
case "if_stmt":
|
||||
formatted.WriteString(" then")
|
||||
indent++
|
||||
case "ifexp":
|
||||
formatted.WriteString(" then ")
|
||||
case "elseif_clause":
|
||||
formatted.WriteString(" then")
|
||||
default:
|
||||
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
|
||||
}
|
||||
case "elseif":
|
||||
pType := parent.Type()
|
||||
// if it's in a statement, the parent will be the elseif clause
|
||||
// if it's in an expression, the parent will be the if expression
|
||||
var ifType string
|
||||
switch pType {
|
||||
case "elseif_clause":
|
||||
ifType = parent.Parent().Type()
|
||||
case "ifexp":
|
||||
ifType = pType
|
||||
default:
|
||||
panic(c.InRed("Unknown parent type ") + c.InYellow(pType))
|
||||
|
||||
for i := range int(node.ChildCount()) {
|
||||
child := node.Child(i)
|
||||
switch child.Type() {
|
||||
case "=":
|
||||
formatted += " = "
|
||||
case "varlist":
|
||||
for j := range int(child.ChildCount()) {
|
||||
varchild := child.Child(j)
|
||||
switch varchild.Type() {
|
||||
case "var":
|
||||
formatted += getContent(*varchild)
|
||||
case ",":
|
||||
formatted += ", "
|
||||
default:
|
||||
panic(c.InRed("unknown varlist child type ") + c.InYellow(varchild.Type()))
|
||||
}
|
||||
}
|
||||
case "explist":
|
||||
for j := range int(child.ChildCount()) {
|
||||
child := child.Child(j)
|
||||
switch child.Type() {
|
||||
case ",":
|
||||
formatted += ", "
|
||||
default:
|
||||
formatted += formatExpr(*child)
|
||||
}
|
||||
}
|
||||
default:
|
||||
formatted += formatExpr(*child)
|
||||
}
|
||||
}
|
||||
|
||||
switch ifType {
|
||||
case "if_stmt":
|
||||
formatted.WriteString("\n")
|
||||
indent--
|
||||
writeIndent()
|
||||
formatted.WriteString("elseif ")
|
||||
indent++
|
||||
case "ifexp":
|
||||
formatted.WriteString(" elseif ")
|
||||
default:
|
||||
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
|
||||
}
|
||||
case "else":
|
||||
pType := parent.Type()
|
||||
// if it's in a statement, the parent will be the else clause
|
||||
// if it's in an expression, the parent will be the if expression
|
||||
var ifType string
|
||||
switch pType {
|
||||
case "else_clause":
|
||||
ifType = parent.Parent().Type()
|
||||
case "ifexp":
|
||||
ifType = pType
|
||||
default:
|
||||
panic(c.InRed("Unknown parent type ") + c.InYellow(pType))
|
||||
}
|
||||
|
||||
switch ifType {
|
||||
case "if_stmt":
|
||||
formatted.WriteString("\n")
|
||||
indent--
|
||||
writeIndent()
|
||||
formatted.WriteString("else")
|
||||
indent++
|
||||
case "ifexp":
|
||||
formatted.WriteString(" else ")
|
||||
default:
|
||||
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
|
||||
}
|
||||
case "end":
|
||||
formatted.WriteString("\n")
|
||||
indent--
|
||||
formatted += "\n"
|
||||
case "comment":
|
||||
writeIndent()
|
||||
formatted.WriteString("end")
|
||||
case "string":
|
||||
if parent.Type() == "arglist" && parent.ChildCount() == 1 {
|
||||
// `print"whatever"` -> `print "whatever"`
|
||||
formatted.WriteString(" ")
|
||||
|
||||
formatted += getContent(node) + "\n"
|
||||
case "call_stmt":
|
||||
writeIndent()
|
||||
|
||||
for i := range int(node.ChildCount()) {
|
||||
child := node.Child(i)
|
||||
switch child.Type() {
|
||||
case "var":
|
||||
formatted += getContent(*child)
|
||||
case "arglist":
|
||||
for j := range int(child.ChildCount()) {
|
||||
argchild := child.Child(j)
|
||||
switch argchild.Type() {
|
||||
case "(":
|
||||
nextType := child.Child(j + 1).Type()
|
||||
if nextType == "string" && child.ChildCount() <= 3 { // function call with single string argument
|
||||
formatted += " "
|
||||
} else {
|
||||
formatted += "("
|
||||
}
|
||||
case ")":
|
||||
prevType := child.Child(j - 1).Type()
|
||||
if prevType != "string" || child.ChildCount() > 3 {
|
||||
formatted += ")"
|
||||
}
|
||||
case ",":
|
||||
formatted += ", "
|
||||
default:
|
||||
if j == 0 {
|
||||
formatted += " "
|
||||
}
|
||||
formatted += formatExpr(*argchild)
|
||||
}
|
||||
}
|
||||
default:
|
||||
panic(c.InRed("Unknown call statement child type ") + c.InYellow(child.Type()))
|
||||
}
|
||||
}
|
||||
formatted.WriteString(content)
|
||||
case "interp_start", "interp_end":
|
||||
formatted.WriteString("`")
|
||||
case "interp_content":
|
||||
formatted.WriteString(content)
|
||||
case "interp_brace_open":
|
||||
formatted.WriteString("{")
|
||||
case "interp_brace_close":
|
||||
formatted.WriteString("}")
|
||||
case ":":
|
||||
formatted.WriteString(":")
|
||||
case ".":
|
||||
formatted.WriteString(".")
|
||||
case "(":
|
||||
argType := parent.Child(1).Type()
|
||||
// `print("whatever")` -> `print "whatever"`
|
||||
if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" {
|
||||
formatted.WriteString("(")
|
||||
} else {
|
||||
formatted.WriteString(" ")
|
||||
}
|
||||
case ")":
|
||||
argType := parent.Child(1).Type()
|
||||
// `print("whatever")` -> `print "whatever"`
|
||||
if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" {
|
||||
formatted.WriteString(")")
|
||||
}
|
||||
case ",":
|
||||
formatted.WriteString(", ")
|
||||
case "true":
|
||||
formatted.WriteString("true")
|
||||
case "false":
|
||||
formatted.WriteString("false")
|
||||
case "+":
|
||||
formatted.WriteString(" + ")
|
||||
case "-":
|
||||
formatted.WriteString(" - ")
|
||||
case "*":
|
||||
formatted.WriteString(" * ")
|
||||
case "/":
|
||||
formatted.WriteString(" / ")
|
||||
case "%":
|
||||
formatted.WriteString(" % ")
|
||||
case "^":
|
||||
formatted.WriteString(" ^ ")
|
||||
case "//":
|
||||
formatted.WriteString(" // ")
|
||||
case "+=":
|
||||
formatted.WriteString(" += ")
|
||||
case "-=":
|
||||
formatted.WriteString(" -= ")
|
||||
case "*=":
|
||||
formatted.WriteString(" *= ")
|
||||
case "/=":
|
||||
formatted.WriteString(" /= ")
|
||||
case "%=":
|
||||
formatted.WriteString(" %= ")
|
||||
case "^=":
|
||||
formatted.WriteString(" ^= ")
|
||||
case "//=":
|
||||
formatted.WriteString(" //= ")
|
||||
case ";":
|
||||
// nothing
|
||||
|
||||
formatted += "\n"
|
||||
|
||||
default:
|
||||
panic(c.InRed("Unknown node type ") + c.InYellow(ntype))
|
||||
panic(c.InRed("Unknown statement type ") + c.InYellow(ntype))
|
||||
}
|
||||
|
||||
return formatted
|
||||
}
|
||||
|
||||
var appendLeaf func(node sitter.Node)
|
||||
appendLeaf = func(node sitter.Node) {
|
||||
// Print only the leaf nodes
|
||||
if node.ChildCount() == 0 {
|
||||
writeFormatted(node)
|
||||
} else {
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
appendLeaf(*(node.Child(i)))
|
||||
}
|
||||
}
|
||||
var formatted string
|
||||
|
||||
for i := range int(main.ChildCount()) {
|
||||
formatted += formatStmt(*main.Child(i))
|
||||
}
|
||||
|
||||
appendLeaf(*node)
|
||||
|
||||
return formatted.String()
|
||||
return formatted
|
||||
}
|
||||
|
||||
func format(filename string) {
|
||||
func formatFile(filename string) {
|
||||
parser := sitter.NewParser()
|
||||
parser.SetLanguage(luau.GetLuau())
|
||||
|
||||
|
|
@ -278,7 +352,7 @@ func format(filename string) {
|
|||
}
|
||||
|
||||
tree, _ := parser.ParseCtx(context.Background(), nil, sourceCode)
|
||||
formatted := formatCode(sourceCode, tree)
|
||||
formatted := formatCode(sourceCode, tree.RootNode())
|
||||
|
||||
// replace all ending newlines with a single newline
|
||||
formatted = strings.Trim(formatted, "\n") + "\n"
|
||||
|
|
|
|||
|
|
@ -27,11 +27,11 @@ func main() {
|
|||
if len(args) < 3 {
|
||||
Error("No file specified.")
|
||||
}
|
||||
format(args[2])
|
||||
formatFile(args[2])
|
||||
case "c", "compatibility":
|
||||
if len(args) < 3 {
|
||||
Error("No file specified.")
|
||||
}
|
||||
compatify(args[2])
|
||||
compatifyFile(args[2])
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue