Modularise and improve Luau transformers

This commit is contained in:
Lewin Kelly 2024-04-07 16:26:30 +01:00
parent f7c6899f2d
commit 64309fbb6f
4 changed files with 509 additions and 378 deletions

167
Luau/compatibility.go Normal file
View File

@ -0,0 +1,167 @@
package main
import (
luau "Luau/binding"
"context"
"fmt"
"os"
"strings"
c "github.com/TwiN/go-color"
sitter "github.com/smacker/go-tree-sitter"
)
// compatifyCode rewrites the Luau-only constructs found in sourceCode into
// plain Lua and returns the result with leading and trailing newlines removed.
//
// tree must be the parse tree of sourceCode. Three constructs are rewritten:
//   - if-expressions (`ifexp`), lowered to `((cond) and a or b)` when the
//     "then" value is a literal that is always truthy and there is no elseif,
//     otherwise to an immediately-invoked function containing an if statement;
//   - floor division (`binexp` whose operator is `//`), lowered to math.floor;
//   - floor-division assignment (`var_stmt` whose operator is `//=`).
//
// Only the outermost match on each path from the root is rewritten in one
// call; matches nested inside a rewritten node are copied verbatim, so the
// caller is expected to re-parse and call again until the text stops changing.
//
// If the root node reports a parse error the source is returned unchanged.
// Diagnostics are written to standard error so they never mix with the
// generated code. The function panics if an if-expression contains an elseif
// in the position where the ternary shortcut was selected.
func compatifyCode(sourceCode []byte, tree *sitter.Tree) string {
	root := tree.RootNode()
	if root.HasError() {
		fmt.Fprintln(os.Stderr, "Error parsing code")
		return string(sourceCode)
	}

	// childType returns the type of node's i-th child, or "" when there is no
	// such child, so callers never dereference a nil node.
	childType := func(node *sitter.Node, i int) string {
		if i < 0 || i >= int(node.ChildCount()) {
			return ""
		}
		child := node.Child(i)
		if child == nil {
			return ""
		}
		return child.Type()
	}

	// Collect the nodes to rewrite in document order. Matched nodes are not
	// descended into, so the collected byte ranges never overlap.
	var targets []*sitter.Node
	var findExprs func(node *sitter.Node)
	findExprs = func(node *sitter.Node) {
		ntype := node.Type()
		if ntype == "ifexp" ||
			(ntype == "binexp" && childType(node, 1) == "//") ||
			(ntype == "var_stmt" && childType(node, 1) == "//=") {
			targets = append(targets, node)
			return
		}
		for i := range int(node.ChildCount()) {
			if child := node.Child(i); child != nil {
				findExprs(child)
			}
		}
	}
	findExprs(root)

	// alwaysTruthy reports whether a node of type t can never evaluate to
	// nil or false, which is what makes `a and b or c` a safe ternary.
	alwaysTruthy := func(t string) bool {
		switch t {
		case "number", "string", "string_interp", "true":
			return true
		}
		return false
	}

	// rewrite produces the Lua replacement text for one collected node.
	rewrite := func(node *sitter.Node) string {
		switch ntype := node.Type(); ntype {
		case "ifexp":
			fmt.Fprintln(os.Stderr, c.InBlue("replacing"))
			var b strings.Builder
			// child 3 is the "then" value, child 4 is `else` or `elseif`
			if childType(node, 4) != "elseif" && alwaysTruthy(childType(node, 3)) {
				// SPECIAL CASE: the "then" value is guaranteed to be truthy,
				// so Lua's sorta-ternary `a and b or c` is equivalent.
				// Nested if expressions are not simplified here.
				// The condition is parenthesised because it may contain `or`,
				// and the whole result is parenthesised because `and`/`or`
				// bind more loosely than any operator that may surround it.
				b.WriteString("(")
				for i := range int(node.ChildCount()) {
					child := node.Child(i)
					fmt.Fprintln(os.Stderr, child.Type())
					switch child.Type() {
					case "then":
						b.WriteString("and ")
					case "else":
						b.WriteString("or ")
					case "elseif":
						panic("elseif in special case")
					case "if":
						// nothing
					default:
						text := child.Content(sourceCode)
						if childType(node, i-1) == "if" {
							text = "(" + text + ")"
						}
						b.WriteString(text + " ")
					}
				}
				return strings.TrimRight(b.String(), " ") + ")"
			}
			// General case: wrap an if statement in an immediately-invoked
			// function and turn every branch value into a return.
			b.WriteString("(function()")
			for i := range int(node.ChildCount()) {
				child := node.Child(i)
				fmt.Fprintln(os.Stderr, child.Type())
				switch child.Type() {
				case "if", "elseif", "then":
					b.WriteString(child.Content(sourceCode) + " ")
				case "else":
					// close the if statement; the else value becomes the
					// function's fall-through return
					b.WriteString("end ")
				default:
					if prev := childType(node, i-1); prev == "then" || prev == "else" {
						b.WriteString("return ")
					}
					b.WriteString(child.Content(sourceCode) + " ")
				}
			}
			b.WriteString("end)()")
			return b.String()
		case "binexp":
			// child 1 is the operator (//)
			left := node.Child(0).Content(sourceCode)
			right := node.Child(2).Content(sourceCode)
			return fmt.Sprintf("math.floor(%s/%s)", left, right)
		case "var_stmt":
			// child 1 is the operator (//=)
			left := node.Child(0).Content(sourceCode)
			right := node.Child(2).Content(sourceCode)
			// The right-hand side is an arbitrary expression, so it must be
			// parenthesised: `x //= a + b` means x // (a + b).
			// NOTE: the left-hand side is evaluated twice.
			return fmt.Sprintf("%s=math.floor(%s/(%s))", left, left, right)
		default:
			panic("unhandled node type " + ntype)
		}
	}

	// Splice the replacements into the source by byte offset.
	var out strings.Builder
	cursor := uint32(0)
	for _, node := range targets {
		out.Write(sourceCode[cursor:node.StartByte()])
		out.WriteString(rewrite(node))
		cursor = node.EndByte()
	}
	out.Write(sourceCode[cursor:])

	// strip newlines from both ends of the result
	return strings.Trim(out.String(), "\n")
}
// compatify reads filename, repeatedly parses it and runs compatifyCode on
// the result until a pass leaves the text unchanged, then prints the final
// source to standard output. Read and parse errors are printed and end the
// function early.
func compatify(filename string) {
	language := luau.GetLuau()

	raw, readErr := os.ReadFile(filename)
	if readErr != nil {
		fmt.Println(readErr)
		return
	}

	current := string(raw)
	// each pass may expose further constructs, so iterate to a fixed point
	for changed := true; changed; {
		parser := sitter.NewParser()
		parser.SetLanguage(language)

		input := []byte(current)
		tree, parseErr := parser.ParseCtx(context.Background(), nil, input)
		if parseErr != nil {
			fmt.Println(parseErr)
			return
		}

		// strip newlines from both ends before comparing with the last pass
		next := strings.Trim(compatifyCode(input, tree), "\n")
		changed = next != current
		current = next
	}

	fmt.Println(current)
}

292
Luau/format.go Normal file
View File

@ -0,0 +1,292 @@
package main
import (
luau "Luau/binding"
"context"
"fmt"
"os"
"strings"
c "github.com/TwiN/go-color"
sitter "github.com/smacker/go-tree-sitter"
)
// isInside reports whether node itself, or any of its ancestors, has the
// node type ntype. A nil node yields false.
func isInside(ntype string, node *sitter.Node) bool {
	for cur := node; cur != nil; cur = cur.Parent() {
		if cur.Type() == ntype {
			return true
		}
	}
	return false
}
// formatCode re-emits sourceCode with normalised spacing, line breaks and
// tab indentation by visiting every leaf node of tree (depth-first, children
// in index order) and writing a formatted form of each leaf.
//
// If the root node reports a parse error, a message is printed and
// sourceCode is returned unchanged. The function panics when it meets a leaf
// type, or an if/else parent type, that it has no case for.
func formatCode(sourceCode []byte, tree *sitter.Tree) string {
node := tree.RootNode()
if node.HasError() {
fmt.Println("Error parsing code")
return string(sourceCode)
}
var formatted strings.Builder
// current indentation depth, in tabs
indent := 0
// insideLocal / insideComment get one entry appended per leaf visited:
// whether that leaf is inside a local_var_stmt / comment node.
// Index len-1 is the current leaf and index len-2 the previously visited
// leaf, which is what the blank-line decisions below look at.
var insideLocal []bool
var insideComment []bool
// writeIndent writes one tab per indentation level.
// NOTE(review): indent is decremented on every `end` but only incremented
// on an if_stmt `then`; strings.Repeat panics on a negative count — confirm
// this cannot go below zero for the inputs that reach here.
writeIndent := func() {
formatted.WriteString(strings.Repeat("\t", indent))
}
writeFormatted := func(node sitter.Node) {
// Write the formatted content to the string builder
content := node.Content(sourceCode)
ntype := node.Type()
insideLocal = append(insideLocal, isInside("local_var_stmt", &node))
insideComment = append(insideComment, isInside("comment", &node))
parent := node.Parent()
switch ntype {
case "comment":
// extra blank line when the previous leaf was not itself a comment
if len(insideComment) > 1 && !insideComment[len(insideComment)-2] {
formatted.WriteString("\n")
}
formatted.WriteString("\n")
writeIndent()
formatted.WriteString(content)
case "local":
// extra blank line when the previous leaf was neither part of a local
// statement nor a comment
if len(insideLocal) > 1 && !insideLocal[len(insideLocal)-2] &&
len(insideComment) > 1 && !insideComment[len(insideComment)-2] {
formatted.WriteString("\n")
}
formatted.WriteString("\n")
writeIndent()
formatted.WriteString("local ")
case "name":
// NOTE(review): this prints the grandparent type and the name to stdout
// for every name leaf; looks like leftover debugging output — confirm.
fmt.Println(parent.Parent().Type(), node.Content(sourceCode))
// A name starts a new line when it is the first child of its parent under
// a call_stmt, or sits under a var_stmt / assign_stmt, and is not inside a
// local statement, if-expression or binary expression.
// NOTE(review): parent.Parent().Parent() is dereferenced without a nil
// check — confirm a name leaf always has three ancestors when reached.
if ((parent.Parent().Type() == "call_stmt" && *parent.Child(0) == node) ||
parent.Parent().Type() == "var_stmt" ||
parent.Parent().Parent().Type() == "assign_stmt") &&
!isInside("local_var_stmt", &node) &&
!isInside("ifexp", &node) &&
!isInside("binexp", &node) {
// blank line between a run of locals and the statement that follows
if len(insideLocal) > 1 && insideLocal[len(insideLocal)-2] {
formatted.WriteString("\n")
}
formatted.WriteString("\n")
writeIndent()
}
formatted.WriteString(content)
case "=":
formatted.WriteString(" = ")
case "==":
formatted.WriteString(" == ")
case "number":
formatted.WriteString(content)
case "return":
formatted.WriteString("\n")
writeIndent()
formatted.WriteString("return ")
case "if":
ifType := parent.Type()
switch ifType {
case "if_stmt":
// statement form starts on its own indented line, then shares the
// keyword output with the expression form
formatted.WriteString("\n")
writeIndent()
fallthrough
case "ifexp":
formatted.WriteString("if ")
default:
// damn better be unreachable
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "then":
ifType := parent.Type()
switch ifType {
case "if_stmt":
// the body of an if statement is indented one level deeper
formatted.WriteString(" then")
indent++
case "ifexp":
formatted.WriteString(" then ")
case "elseif_clause":
// indentation was already restored by the "elseif" case
formatted.WriteString(" then")
default:
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "elseif":
pType := parent.Type()
// if it's in a statement, the parent will be the elseif clause
// if it's in an expression, the parent will be the if expression
var ifType string
switch pType {
case "elseif_clause":
ifType = parent.Parent().Type()
case "ifexp":
ifType = pType
default:
panic(c.InRed("Unknown parent type ") + c.InYellow(pType))
}
switch ifType {
case "if_stmt":
// the keyword is written one level shallower than the branch bodies
formatted.WriteString("\n")
indent--
writeIndent()
formatted.WriteString("elseif ")
indent++
case "ifexp":
formatted.WriteString(" elseif ")
default:
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "else":
pType := parent.Type()
// if it's in a statement, the parent will be the else clause
// if it's in an expression, the parent will be the if expression
var ifType string
switch pType {
case "else_clause":
ifType = parent.Parent().Type()
case "ifexp":
ifType = pType
default:
panic(c.InRed("Unknown parent type ") + c.InYellow(pType))
}
switch ifType {
case "if_stmt":
// the keyword is written one level shallower than the branch bodies
formatted.WriteString("\n")
indent--
writeIndent()
formatted.WriteString("else")
indent++
case "ifexp":
formatted.WriteString(" else ")
default:
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "end":
formatted.WriteString("\n")
indent--
writeIndent()
formatted.WriteString("end")
case "string":
if parent.Type() == "arglist" && parent.ChildCount() == 1 {
// `print"whatever"` -> `print "whatever"`
formatted.WriteString(" ")
}
formatted.WriteString(content)
case "interp_start", "interp_end":
formatted.WriteString("`")
case "interp_content":
formatted.WriteString(content)
case "interp_brace_open":
formatted.WriteString("{")
case "interp_brace_close":
formatted.WriteString("}")
case ":":
formatted.WriteString(":")
case ".":
formatted.WriteString(".")
case "(":
// type of the node that follows the opening paren
// NOTE(review): assumes parent has at least two children — confirm
argType := parent.Child(1).Type()
// `print("whatever")` -> `print "whatever"`
// && binds tighter than ||: the paren is kept when the arglist has more
// than three children, or when the first argument is neither a string
// nor a table; otherwise it is replaced by a space
if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" {
formatted.WriteString("(")
} else {
formatted.WriteString(" ")
}
case ")":
// same condition as the "(" case so the parens are dropped as a pair
argType := parent.Child(1).Type()
// `print("whatever")` -> `print "whatever"`
if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" {
formatted.WriteString(")")
}
case ",":
formatted.WriteString(", ")
case "true":
formatted.WriteString("true")
case "false":
formatted.WriteString("false")
// binary and compound-assignment operators are padded with one space on
// each side
case "+":
formatted.WriteString(" + ")
case "-":
formatted.WriteString(" - ")
case "*":
formatted.WriteString(" * ")
case "/":
formatted.WriteString(" / ")
case "%":
formatted.WriteString(" % ")
case "^":
formatted.WriteString(" ^ ")
case "//":
formatted.WriteString(" // ")
case "+=":
formatted.WriteString(" += ")
case "-=":
formatted.WriteString(" -= ")
case "*=":
formatted.WriteString(" *= ")
case "/=":
formatted.WriteString(" /= ")
case "%=":
formatted.WriteString(" %= ")
case "^=":
formatted.WriteString(" ^= ")
case "//=":
formatted.WriteString(" //= ")
case ";":
// nothing
default:
panic(c.InRed("Unknown node type ") + c.InYellow(ntype))
}
}
// appendLeaf recurses through the tree and hands every childless node to
// writeFormatted, in child index order.
var appendLeaf func(node sitter.Node)
appendLeaf = func(node sitter.Node) {
// Print only the leaf nodes
if node.ChildCount() == 0 {
writeFormatted(node)
} else {
for i := 0; i < int(node.ChildCount()); i++ {
appendLeaf(*(node.Child(i)))
}
}
}
appendLeaf(*node)
return formatted.String()
}
// format reads filename, parses it as Luau, formats it with formatCode and
// writes the result back to the same file, ending with exactly one newline.
// Read, parse and write errors are printed and end the function early
// without touching the file.
func format(filename string) {
	sourceCode, err := os.ReadFile(filename)
	if err != nil {
		fmt.Println(err)
		return
	}

	parser := sitter.NewParser()
	parser.SetLanguage(luau.GetLuau())
	// A failed parse must not reach formatCode: the tree would be unusable
	// and the original file must be left as it is.
	tree, err := parser.ParseCtx(context.Background(), nil, sourceCode)
	if err != nil {
		fmt.Println(err)
		return
	}

	// strip newlines from both ends, then terminate with a single newline
	formatted := strings.Trim(formatCode(sourceCode, tree), "\n") + "\n"

	// write back to file
	if err := os.WriteFile(filename, []byte(formatted), 0o644); err != nil {
		fmt.Println(err)
	}
}

View File

@ -1,396 +1,37 @@
package main
import (
luau "Luau/binding"
"context"
"fmt"
"math/rand"
"os"
"strings"
c "github.com/TwiN/go-color"
sitter "github.com/smacker/go-tree-sitter"
)
// randomString returns a string of length random printable ASCII characters
// in the range 0x21 ('!') to 0x7D ('}'). A length of zero or less yields the
// empty string.
func randomString(length int) string {
	if length <= 0 {
		return ""
	}
	// every generated character is a single byte, so fill a byte slice
	// instead of growing a string one concatenation at a time
	buf := make([]byte, length)
	for i := range buf {
		buf[i] = byte(rand.Intn(0x7E-0x21) + 0x21)
	}
	return string(buf)
}
// isInside reports whether node itself, or any of its ancestors, has the
// node type ntype. A nil node yields false.
func isInside(ntype string, node *sitter.Node) bool {
	for cur := node; cur != nil; cur = cur.Parent() {
		if cur.Type() == ntype {
			return true
		}
	}
	return false
}
// formatCode re-emits sourceCode with normalised spacing, line breaks and
// tab indentation by visiting every leaf node of tree (depth-first, children
// in index order) and writing a formatted form of each leaf.
//
// If the root node reports a parse error, a message is printed and
// sourceCode is returned unchanged. The function panics when it meets a leaf
// type, or an if/else parent type, that it has no case for.
func formatCode(sourceCode []byte, tree *sitter.Tree) string {
node := tree.RootNode()
if node.HasError() {
fmt.Println("Error parsing code")
return string(sourceCode)
}
var formatted strings.Builder
// current indentation depth, in tabs
indent := 0
// insideLocal / insideComment get one entry appended per leaf visited:
// whether that leaf is inside a local_var_stmt / comment node.
// Index len-1 is the current leaf and index len-2 the previously visited
// leaf, which is what the blank-line decisions below look at.
var insideLocal []bool
var insideComment []bool
// writeIndent writes one tab per indentation level.
writeIndent := func() {
formatted.WriteString(strings.Repeat("\t", indent))
}
writeFormatted := func(node sitter.Node) {
// Write the formatted content to the string builder
content := node.Content(sourceCode)
ntype := node.Type()
insideLocal = append(insideLocal, isInside("local_var_stmt", &node))
insideComment = append(insideComment, isInside("comment", &node))
parent := node.Parent()
switch ntype {
case "comment":
// extra blank line when the previous leaf was not itself a comment
if len(insideComment) > 1 && !insideComment[len(insideComment)-2] {
formatted.WriteString("\n")
}
formatted.WriteString("\n")
writeIndent()
formatted.WriteString(content)
case "local":
// extra blank line when the previous leaf was neither part of a local
// statement nor a comment
if len(insideLocal) > 1 && !insideLocal[len(insideLocal)-2] &&
len(insideComment) > 1 && !insideComment[len(insideComment)-2] {
formatted.WriteString("\n")
}
formatted.WriteString("\n")
writeIndent()
formatted.WriteString("local ")
case "name":
// a name under a call_stmt that is not inside a local statement starts
// a new line
if parent.Parent().Type() == "call_stmt" && !isInside("local_var_stmt", &node) {
// blank line between a run of locals and the statement that follows
if len(insideLocal) > 1 && insideLocal[len(insideLocal)-2] {
formatted.WriteString("\n")
}
formatted.WriteString("\n")
writeIndent()
}
formatted.WriteString(content)
case "=":
formatted.WriteString(" = ")
case "==":
formatted.WriteString(" == ")
case "number":
formatted.WriteString(content)
case "return":
formatted.WriteString("\n")
writeIndent()
formatted.WriteString("return ")
case "if":
ifType := parent.Type()
switch ifType {
case "if_stmt":
// statement form starts on its own indented line, then shares the
// keyword output with the expression form
formatted.WriteString("\n")
writeIndent()
fallthrough
case "ifexp":
formatted.WriteString("if ")
default:
// damn better be unreachable
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "then":
ifType := parent.Type()
switch ifType {
case "if_stmt":
// the body of an if statement is indented one level deeper
formatted.WriteString(" then")
indent++
case "ifexp":
formatted.WriteString(" then ")
case "elseif_clause":
// indentation was already restored by the "elseif" case
formatted.WriteString(" then")
default:
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "elseif":
pType := parent.Type()
// if it's in a statement, the parent will be the elseif clause
// if it's in an expression, the parent will be the if expression
var ifType string
switch pType {
case "elseif_clause":
ifType = parent.Parent().Type()
case "ifexp":
ifType = pType
default:
panic(c.InRed("Unknown parent type ") + c.InYellow(pType))
}
switch ifType {
case "if_stmt":
// the keyword is written one level shallower than the branch bodies
formatted.WriteString("\n")
indent--
writeIndent()
formatted.WriteString("elseif ")
indent++
case "ifexp":
formatted.WriteString(" elseif ")
default:
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "else":
pType := parent.Type()
// if it's in a statement, the parent will be the else clause
// if it's in an expression, the parent will be the if expression
var ifType string
switch pType {
case "else_clause":
ifType = parent.Parent().Type()
case "ifexp":
ifType = pType
default:
panic(c.InRed("Unknown parent type ") + c.InYellow(pType))
}
switch ifType {
case "if_stmt":
// the keyword is written one level shallower than the branch bodies
formatted.WriteString("\n")
indent--
writeIndent()
formatted.WriteString("else")
indent++
case "ifexp":
formatted.WriteString(" else ")
default:
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "end":
formatted.WriteString("\n")
indent--
writeIndent()
formatted.WriteString("end")
case "string":
if parent.Type() == "arglist" && parent.ChildCount() == 1 {
// `print"whatever"` -> `print "whatever"`
formatted.WriteString(" ")
}
formatted.WriteString(content)
case "interp_start":
// both interpolation delimiters are written as a backtick
fallthrough
case "interp_end":
formatted.WriteString("`")
case "interp_content":
formatted.WriteString(content)
case "interp_brace_open":
formatted.WriteString("{")
case "interp_brace_close":
formatted.WriteString("}")
case ":":
formatted.WriteString(":")
case ".":
formatted.WriteString(".")
case "(":
// type of the node that follows the opening paren
// NOTE(review): assumes parent has at least two children — confirm
argType := parent.Child(1).Type()
// `print("whatever")` -> `print "whatever"`
// && binds tighter than ||: the paren is kept when the arglist has more
// than three children, or when the first argument is neither a string
// nor a table; otherwise it is replaced by a space
if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" {
formatted.WriteString("(")
} else {
formatted.WriteString(" ")
}
case ")":
// same condition as the "(" case so the parens are dropped as a pair
argType := parent.Child(1).Type()
// `print("whatever")` -> `print "whatever"`
if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" {
formatted.WriteString(")")
}
case ",":
formatted.WriteString(", ")
case "true":
formatted.WriteString("true")
case "false":
formatted.WriteString("false")
default:
panic(c.InRed("Unknown node type ") + c.InYellow(ntype))
}
}
// appendLeaf recurses through the tree and hands every childless node to
// writeFormatted, in child index order.
var appendLeaf func(node sitter.Node)
appendLeaf = func(node sitter.Node) {
// Print only the leaf nodes
if node.ChildCount() == 0 {
writeFormatted(node)
} else {
for i := 0; i < int(node.ChildCount()); i++ {
appendLeaf(*(node.Child(i)))
}
}
}
appendLeaf(*node)
return formatted.String()
}
// compatibility rewrites the Luau-only constructs found in sourceCode into
// plain Lua and returns the resulting text.
//
// tree must be the parse tree of sourceCode. Three constructs are rewritten:
//   - if-expressions (`ifexp`), lowered to `((cond) and a or b)` when the
//     "then" value is a literal that is always truthy and there is no elseif,
//     otherwise to an immediately-invoked function containing an if statement;
//   - floor division (`binexp` whose operator is `//`), lowered to math.floor;
//   - floor-division assignment (`var_stmt` whose operator is `//=`).
//
// Only the outermost match on each path from the root is rewritten in one
// call; matches nested inside a rewritten node are copied verbatim, so the
// caller is expected to re-parse and call again until the text stops changing.
//
// If the root node reports a parse error the source is returned unchanged.
// Diagnostics are written to standard error so they never mix with the
// generated code. The function panics if an if-expression contains an elseif
// in the position where the ternary shortcut was selected.
func compatibility(sourceCode []byte, tree *sitter.Tree) string {
	root := tree.RootNode()
	if root.HasError() {
		fmt.Fprintln(os.Stderr, "Error parsing code")
		return string(sourceCode)
	}

	// childType returns the type of node's i-th child, or "" when there is no
	// such child, so callers never dereference a nil node.
	childType := func(node *sitter.Node, i int) string {
		if i < 0 || i >= int(node.ChildCount()) {
			return ""
		}
		child := node.Child(i)
		if child == nil {
			return ""
		}
		return child.Type()
	}

	// Collect the nodes to rewrite in document order. Matched nodes are not
	// descended into, so the collected byte ranges never overlap.
	var targets []*sitter.Node
	var findExprs func(node *sitter.Node)
	findExprs = func(node *sitter.Node) {
		ntype := node.Type()
		if ntype == "ifexp" ||
			(ntype == "binexp" && childType(node, 1) == "//") ||
			(ntype == "var_stmt" && childType(node, 1) == "//=") {
			targets = append(targets, node)
			return
		}
		for i := range int(node.ChildCount()) {
			if child := node.Child(i); child != nil {
				findExprs(child)
			}
		}
	}
	findExprs(root)

	// alwaysTruthy reports whether a node of type t can never evaluate to
	// nil or false, which is what makes `a and b or c` a safe ternary.
	alwaysTruthy := func(t string) bool {
		switch t {
		case "number", "string", "string_interp", "true":
			return true
		}
		return false
	}

	// rewrite produces the Lua replacement text for one collected node.
	rewrite := func(node *sitter.Node) string {
		switch ntype := node.Type(); ntype {
		case "ifexp":
			fmt.Fprintln(os.Stderr, c.InBlue("replacing"))
			var b strings.Builder
			// child 3 is the "then" value, child 4 is `else` or `elseif`
			if childType(node, 4) != "elseif" && alwaysTruthy(childType(node, 3)) {
				// SPECIAL CASE: the "then" value is guaranteed to be truthy,
				// so Lua's sorta-ternary `a and b or c` is equivalent.
				// Nested if expressions are not simplified here.
				// The condition is parenthesised because it may contain `or`,
				// and the whole result is parenthesised because `and`/`or`
				// bind more loosely than any operator that may surround it.
				b.WriteString("(")
				for i := range int(node.ChildCount()) {
					child := node.Child(i)
					fmt.Fprintln(os.Stderr, child.Type())
					switch child.Type() {
					case "then":
						b.WriteString("and ")
					case "else":
						b.WriteString("or ")
					case "elseif":
						panic("elseif in special case")
					case "if":
						// nothing
					default:
						text := child.Content(sourceCode)
						if childType(node, i-1) == "if" {
							text = "(" + text + ")"
						}
						b.WriteString(text + " ")
					}
				}
				return strings.TrimRight(b.String(), " ") + ")"
			}
			// General case: wrap an if statement in an immediately-invoked
			// function and turn every branch value into a return.
			b.WriteString("(function()")
			for i := range int(node.ChildCount()) {
				child := node.Child(i)
				fmt.Fprintln(os.Stderr, child.Type())
				switch child.Type() {
				case "if", "elseif", "then":
					b.WriteString(child.Content(sourceCode) + " ")
				case "else":
					// close the if statement; the else value becomes the
					// function's fall-through return
					b.WriteString("end ")
				default:
					if prev := childType(node, i-1); prev == "then" || prev == "else" {
						b.WriteString("return ")
					}
					b.WriteString(child.Content(sourceCode) + " ")
				}
			}
			b.WriteString("end)()")
			return b.String()
		case "binexp":
			// child 1 is the operator (//)
			left := node.Child(0).Content(sourceCode)
			right := node.Child(2).Content(sourceCode)
			return fmt.Sprintf("math.floor(%s/%s)", left, right)
		case "var_stmt":
			// child 1 is the operator (//=)
			left := node.Child(0).Content(sourceCode)
			right := node.Child(2).Content(sourceCode)
			// The right-hand side is an arbitrary expression, so it must be
			// parenthesised: `x //= a + b` means x // (a + b).
			// NOTE: the left-hand side is evaluated twice.
			return fmt.Sprintf("%s=math.floor(%s/(%s))", left, left, right)
		default:
			panic("unhandled node type " + ntype)
		}
	}

	// Splice the replacements into the source by byte offset.
	var out strings.Builder
	cursor := uint32(0)
	for _, node := range targets {
		out.Write(sourceCode[cursor:node.StartByte()])
		out.WriteString(rewrite(node))
		cursor = node.EndByte()
	}
	out.Write(sourceCode[cursor:])
	return out.String()
}
func main() {
binding := luau.GetLuau()
args := os.Args
filename := "test.luau"
sourceCode, err := os.ReadFile(filename)
if err != nil {
fmt.Println(err)
return
if len(args) < 2 {
Error("No command specified. Run with 'help' to see available commands.")
}
code := string(sourceCode)
// keep parsing until no changes are made lmao
for {
parser := sitter.NewParser()
parser.SetLanguage(binding)
newSource := []byte(code)
tree, err := parser.ParseCtx(context.Background(), nil, newSource)
if err != nil {
fmt.Println(err)
return
switch strings.ToLower(args[1]) {
case "h", "help":
fmt.Println(c.InYellow("Usage"))
fmt.Println(c.InGreen(" [executable] [command] [arguments]\n"))
fmt.Println(c.InYellow("Commands"))
fmt.Println(c.InBlue(" h help") + " Shows this help message")
fmt.Println(c.InBlue(" f format [file]") + " Formats the specified file")
fmt.Println(c.InBlue(" c compatibility [file]") + " Makes the specified file compatible with Lua")
case "f", "format":
if len(args) < 3 {
Error("No file specified.")
}
compatible := compatibility(newSource, tree)
// replace all ending newlines with a single newline
compatible = strings.Trim(compatible, "\n")
if compatible == string(code) {
break
format(args[2])
case "c", "compatibility":
if len(args) < 3 {
Error("No file specified.")
}
code = compatible
compatify(args[2])
}
fmt.Println(code)
}

31
Luau/util.go Normal file
View File

@ -0,0 +1,31 @@
package main
import (
"fmt"
"math/rand"
"os"
c "github.com/TwiN/go-color"
)
// Error prints txt to standard error, prefixed with a red "Error: " label,
// and terminates the process with exit status 1. It never returns.
func Error(txt string) {
	// diagnostics go to stderr so they cannot be mistaken for program output
	fmt.Fprintln(os.Stderr, c.InRed("Error: ")+txt)
	os.Exit(1)
}
// Assert does nothing when err is nil. Otherwise it prints err to standard
// error and calls Error with txt, which terminates the process.
func Assert(err error, txt string) {
	// so that I don't have to write this every time
	if err == nil {
		return
	}
	fmt.Fprintln(os.Stderr, err)
	Error(txt)
}
// randomString returns a string of length random printable ASCII characters
// in the range 0x21 ('!') to 0x7D ('}'). A length of zero or less yields the
// empty string.
func randomString(length int) string {
	if length <= 0 {
		return ""
	}
	// every generated character is a single byte, so fill a byte slice
	// instead of growing a string one concatenation at a time
	buf := make([]byte, length)
	for i := range buf {
		buf[i] = byte(rand.Intn(0x7E-0x21) + 0x21)
	}
	return string(buf)
}