diff --git a/Luau/compatibility.go b/Luau/compatibility.go new file mode 100644 index 0000000..c2f48f3 --- /dev/null +++ b/Luau/compatibility.go @@ -0,0 +1,167 @@ +package main + +import ( + luau "Luau/binding" + "context" + "fmt" + "os" + "strings" + + c "github.com/TwiN/go-color" + sitter "github.com/smacker/go-tree-sitter" +) + +func compatifyCode(sourceCode []byte, tree *sitter.Tree) string { + node := tree.RootNode() + if node.HasError() { + fmt.Println("Error parsing code") + return string(sourceCode) + } + + // watch out for passing by reference + newSource := make([]byte, len(sourceCode)) + copy(newSource, sourceCode) + + type sub struct { + node *sitter.Node + toReplace string + } + var toSub []sub + + var findExprs func(node sitter.Node) + findExprs = func(node sitter.Node) { + start := node.StartByte() + end := node.EndByte() + rand := randomString(int(end - start)) + ntype := node.Type() + + if ntype == "ifexp" || (ntype == "binexp" && node.Child(1).Type() == "//") || (ntype == "var_stmt" && node.Child(1).Type() == "//=") { + // replace unsupported statements/expressions with a random string + newSource = append(newSource[:start], append([]byte(rand), newSource[end:]...)...) + toSub = append(toSub, sub{&node, rand}) + } else { + for i := range int(node.ChildCount()) { + findExprs(*(node.Child(i))) + } + } + } + + findExprs(*node) + sourceString := string(newSource) + + for _, s := range toSub { + var replacement string + + node := s.node + ntype := node.Type() + + switch ntype { + case "ifexp": + fmt.Println(c.InBlue("replacing")) + secondCond := node.Child(3).Type() + hasElseIf := node.Child(4).Type() == "elseif" + + if !hasElseIf && map[string]bool{ + "number": true, + "string": true, + "string_interp": true, + "true": true, // lelel + }[secondCond] { + // SPECIAL CASE: the second condition is guaranteed to be truthy + // this means it can be simplified to Lua's sorta-ternary operator, a and b or c + + // doesn't simplify nested if expressions. GOOD ENOUGH + for i := range int(node.ChildCount()) { + child := node.Child(i) + fmt.Println(child.Type()) + switch child.Type() { + case "then": + replacement += "and " + case "else": + replacement += "or " + case "elseif": + panic("elseif in special case") + case "if": + // nothing + default: + replacement += child.Content(sourceCode) + " " + } + } + } else { + replacement += "(function()" + for i := range int(node.ChildCount()) { + child := node.Child(i) + fmt.Println(child.Type()) + switch child.Type() { + case "if", "elseif", "then": + replacement += child.Content(sourceCode) + " " + case "else": + replacement += "end " + default: + prev := node.Child(i - 1).Type() + if prev == "then" || prev == "else" { + replacement += "return " + } + replacement += child.Content(sourceCode) + " " + } + } + replacement += "end)()" + } + case "binexp": + // child 1 is the operator (//) + left := node.Child(0).Content(sourceCode) + right := node.Child(2).Content(sourceCode) + replacement = fmt.Sprintf("math.floor(%s/%s)", left, right) + case "var_stmt": + // child 1 is the operator (//=) + left := node.Child(0).Content(sourceCode) + right := node.Child(2).Content(sourceCode) + + // todo: make it so if one of the exprs is a function it don't get evaluated twice + // nah jk im not doing that + replacement = fmt.Sprintf("%s=math.floor(%s/%s)", left, left, right) + default: + panic("unhandled node type " + ntype) + } + + sourceString = strings.Replace(sourceString, s.toReplace, replacement, 1) + } + + // replace all ending newlines with a single newline + sourceString = strings.Trim(sourceString, "\n") + return sourceString +} + +func compatify(filename string) { + binding := luau.GetLuau() + sourceCode, err := os.ReadFile(filename) + if err != nil { + fmt.Println(err) + return + } + + code := string(sourceCode) + + // keep parsing until no changes are made lmao + for { + parser := sitter.NewParser() + parser.SetLanguage(binding) + + newSource := []byte(code) + tree, err := parser.ParseCtx(context.Background(), nil, newSource) + if err != nil { + fmt.Println(err) + return + } + compatible := compatifyCode(newSource, tree) + // replace all ending newlines with a single newline + compatible = strings.Trim(compatible, "\n") + + if compatible == string(code) { + break + } + code = compatible + } + fmt.Println(code) + +} diff --git a/Luau/format.go b/Luau/format.go new file mode 100644 index 0000000..859a082 --- /dev/null +++ b/Luau/format.go @@ -0,0 +1,292 @@ +package main + +import ( + luau "Luau/binding" + "context" + "fmt" + "os" + "strings" + + c "github.com/TwiN/go-color" + sitter "github.com/smacker/go-tree-sitter" +) + +func isInside(ntype string, node *sitter.Node) bool { + for node != nil { + if node.Type() == ntype { + return true + } + node = node.Parent() + } + return false +} + +func formatCode(sourceCode []byte, tree *sitter.Tree) string { + node := tree.RootNode() + if node.HasError() { + fmt.Println("Error parsing code") + return string(sourceCode) + } + + var formatted strings.Builder + indent := 0 + // -1 means at the start of the file + // 0 means outside of a local var statement + // 1 means inside a local var statement + var insideLocal []bool + var insideComment []bool + + writeIndent := func() { + formatted.WriteString(strings.Repeat("\t", indent)) + } + + writeFormatted := func(node sitter.Node) { + // Write the formatted content to the string builder + content := node.Content(sourceCode) + ntype := node.Type() + + insideLocal = append(insideLocal, isInside("local_var_stmt", &node)) + insideComment = append(insideComment, isInside("comment", &node)) + + parent := node.Parent() + + switch ntype { + case "comment": + if len(insideComment) > 1 && !insideComment[len(insideComment)-2] { + formatted.WriteString("\n") + } + formatted.WriteString("\n") + writeIndent() + formatted.WriteString(content) + + case "local": + if len(insideLocal) > 1 && !insideLocal[len(insideLocal)-2] && + len(insideComment) > 1 && !insideComment[len(insideComment)-2] { + formatted.WriteString("\n") + } + formatted.WriteString("\n") + writeIndent() + formatted.WriteString("local ") + + case "name": + fmt.Println(parent.Parent().Type(), node.Content(sourceCode)) + if ((parent.Parent().Type() == "call_stmt" && *parent.Child(0) == node) || + parent.Parent().Type() == "var_stmt" || + parent.Parent().Parent().Type() == "assign_stmt") && + !isInside("local_var_stmt", &node) && + !isInside("ifexp", &node) && + !isInside("binexp", &node) { + if len(insideLocal) > 1 && insideLocal[len(insideLocal)-2] { + formatted.WriteString("\n") + } + formatted.WriteString("\n") + writeIndent() + } + formatted.WriteString(content) + + case "=": + formatted.WriteString(" = ") + case "==": + formatted.WriteString(" == ") + case "number": + formatted.WriteString(content) + case "return": + formatted.WriteString("\n") + writeIndent() + formatted.WriteString("return ") + case "if": + ifType := parent.Type() + switch ifType { + case "if_stmt": + formatted.WriteString("\n") + writeIndent() + fallthrough + case "ifexp": + formatted.WriteString("if ") + default: + // damn better be unreachable + panic(c.InRed("Unknown if type ") + c.InYellow(ifType)) + } + case "then": + ifType := parent.Type() + switch ifType { + case "if_stmt": + formatted.WriteString(" then") + indent++ + case "ifexp": + formatted.WriteString(" then ") + case "elseif_clause": + formatted.WriteString(" then") + default: + panic(c.InRed("Unknown if type ") + c.InYellow(ifType)) + } + case "elseif": + pType := parent.Type() + // if it's in a statement, the parent will be the elseif clause + // if it's in an expression, the parent will be the if expression + var ifType string + switch pType { + case "elseif_clause": + ifType = parent.Parent().Type() + case "ifexp": + ifType = pType + default: + panic(c.InRed("Unknown parent type ") + c.InYellow(pType)) + } + + switch ifType { + case "if_stmt": + formatted.WriteString("\n") + indent-- + writeIndent() + formatted.WriteString("elseif ") + indent++ + case "ifexp": + formatted.WriteString(" elseif ") + default: + panic(c.InRed("Unknown if type ") + c.InYellow(ifType)) + } + case "else": + pType := parent.Type() + // if it's in a statement, the parent will be the else clause + // if it's in an expression, the parent will be the if expression + var ifType string + switch pType { + case "else_clause": + ifType = parent.Parent().Type() + case "ifexp": + ifType = pType + default: + panic(c.InRed("Unknown parent type ") + c.InYellow(pType)) + } + + switch ifType { + case "if_stmt": + formatted.WriteString("\n") + indent-- + writeIndent() + formatted.WriteString("else") + indent++ + case "ifexp": + formatted.WriteString(" else ") + default: + panic(c.InRed("Unknown if type ") + c.InYellow(ifType)) + } + case "end": + formatted.WriteString("\n") + indent-- + writeIndent() + formatted.WriteString("end") + case "string": + if parent.Type() == "arglist" && parent.ChildCount() == 1 { + // `print"whatever"` -> `print "whatever"` + formatted.WriteString(" ") + } + formatted.WriteString(content) + case "interp_start", "interp_end": + formatted.WriteString("`") + case "interp_content": + formatted.WriteString(content) + case "interp_brace_open": + formatted.WriteString("{") + case "interp_brace_close": + formatted.WriteString("}") + case ":": + formatted.WriteString(":") + case ".": + formatted.WriteString(".") + case "(": + argType := parent.Child(1).Type() + // `print("whatever")` -> `print "whatever"` + if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" { + formatted.WriteString("(") + } else { + formatted.WriteString(" ") + } + case ")": + argType := parent.Child(1).Type() + // `print("whatever")` -> `print "whatever"` + if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" { + formatted.WriteString(")") + } + case ",": + formatted.WriteString(", ") + case "true": + formatted.WriteString("true") + case "false": + formatted.WriteString("false") + case "+": + formatted.WriteString(" + ") + case "-": + formatted.WriteString(" - ") + case "*": + formatted.WriteString(" * ") + case "/": + formatted.WriteString(" / ") + case "%": + formatted.WriteString(" % ") + case "^": + formatted.WriteString(" ^ ") + case "//": + formatted.WriteString(" // ") + case "+=": + formatted.WriteString(" += ") + case "-=": + formatted.WriteString(" -= ") + case "*=": + formatted.WriteString(" *= ") + case "/=": + formatted.WriteString(" /= ") + case "%=": + formatted.WriteString(" %= ") + case "^=": + formatted.WriteString(" ^= ") + case "//=": + formatted.WriteString(" //= ") + case ";": + // nothing + default: + panic(c.InRed("Unknown node type ") + c.InYellow(ntype)) + } + } + + var appendLeaf func(node sitter.Node) + appendLeaf = func(node sitter.Node) { + // Print only the leaf nodes + if node.ChildCount() == 0 { + writeFormatted(node) + } else { + for i := 0; i < int(node.ChildCount()); i++ { + appendLeaf(*(node.Child(i))) + } + } + } + + appendLeaf(*node) + + return formatted.String() +} + +func format(filename string) { + parser := sitter.NewParser() + parser.SetLanguage(luau.GetLuau()) + + sourceCode, err := os.ReadFile(filename) + if err != nil { + fmt.Println(err) + return + } + + tree, _ := parser.ParseCtx(context.Background(), nil, sourceCode) + formatted := formatCode(sourceCode, tree) + + // replace all ending newlines with a single newline + formatted = strings.Trim(formatted, "\n") + "\n" + + // write back to file + err = os.WriteFile(filename, []byte(formatted), 0o644) + if err != nil { + fmt.Println(err) + return + } +} diff --git a/Luau/main.go b/Luau/main.go index 6c8b368..c288f27 100644 --- a/Luau/main.go +++ b/Luau/main.go @@ -1,396 +1,37 @@ package main import ( - luau "Luau/binding" - "context" "fmt" - "math/rand" "os" "strings" c "github.com/TwiN/go-color" - sitter "github.com/smacker/go-tree-sitter" ) -func randomString(length int) string { - // generate random unicode string - var str string - for i := 0; i < length; i++ { - str += string(rune(rand.Intn(0x7E-0x21) + 0x21)) - } - return str -} - -func isInside(ntype string, node *sitter.Node) bool { - for node != nil { - if node.Type() == ntype { - return true - } - node = node.Parent() - } - return false -} - -func formatCode(sourceCode []byte, tree *sitter.Tree) string { - node := tree.RootNode() - if node.HasError() { - fmt.Println("Error parsing code") - return string(sourceCode) - } - - var formatted strings.Builder - indent := 0 - // -1 means at the start of the file - // 0 means outside of a local var statement - // 1 means inside a local var statement - var insideLocal []bool - var insideComment []bool - - writeIndent := func() { - formatted.WriteString(strings.Repeat("\t", indent)) - } - - writeFormatted := func(node sitter.Node) { - // Write the formatted content to the string builder - content := node.Content(sourceCode) - ntype := node.Type() - - insideLocal = append(insideLocal, isInside("local_var_stmt", &node)) - insideComment = append(insideComment, isInside("comment", &node)) - - parent := node.Parent() - - switch ntype { - case "comment": - if len(insideComment) > 1 && !insideComment[len(insideComment)-2] { - formatted.WriteString("\n") - } - formatted.WriteString("\n") - writeIndent() - formatted.WriteString(content) - case "local": - if len(insideLocal) > 1 && !insideLocal[len(insideLocal)-2] && - len(insideComment) > 1 && !insideComment[len(insideComment)-2] { - formatted.WriteString("\n") - } - formatted.WriteString("\n") - writeIndent() - formatted.WriteString("local ") - case "name": - if parent.Parent().Type() == "call_stmt" && !isInside("local_var_stmt", &node) { - if len(insideLocal) > 1 && insideLocal[len(insideLocal)-2] { - formatted.WriteString("\n") - } - formatted.WriteString("\n") - writeIndent() - } - formatted.WriteString(content) - case "=": - formatted.WriteString(" = ") - case "==": - formatted.WriteString(" == ") - case "number": - formatted.WriteString(content) - case "return": - formatted.WriteString("\n") - writeIndent() - formatted.WriteString("return ") - case "if": - ifType := parent.Type() - switch ifType { - case "if_stmt": - formatted.WriteString("\n") - writeIndent() - fallthrough - case "ifexp": - formatted.WriteString("if ") - default: - // damn better be unreachable - panic(c.InRed("Unknown if type ") + c.InYellow(ifType)) - } - case "then": - ifType := parent.Type() - switch ifType { - case "if_stmt": - formatted.WriteString(" then") - indent++ - case "ifexp": - formatted.WriteString(" then ") - case "elseif_clause": - formatted.WriteString(" then") - default: - panic(c.InRed("Unknown if type ") + c.InYellow(ifType)) - } - case "elseif": - pType := parent.Type() - // if it's in a statement, the parent will be the elseif clause - // if it's in an expression, the parent will be the if expression - var ifType string - switch pType { - case "elseif_clause": - ifType = parent.Parent().Type() - case "ifexp": - ifType = pType - default: - panic(c.InRed("Unknown parent type ") + c.InYellow(pType)) - } - - switch ifType { - case "if_stmt": - formatted.WriteString("\n") - indent-- - writeIndent() - formatted.WriteString("elseif ") - indent++ - case "ifexp": - formatted.WriteString(" elseif ") - default: - panic(c.InRed("Unknown if type ") + c.InYellow(ifType)) - } - case "else": - pType := parent.Type() - // if it's in a statement, the parent will be the else clause - // if it's in an expression, the parent will be the if expression - var ifType string - switch pType { - case "else_clause": - ifType = parent.Parent().Type() - case "ifexp": - ifType = pType - default: - panic(c.InRed("Unknown parent type ") + c.InYellow(pType)) - } - - switch ifType { - case "if_stmt": - formatted.WriteString("\n") - indent-- - writeIndent() - formatted.WriteString("else") - indent++ - case "ifexp": - formatted.WriteString(" else ") - default: - panic(c.InRed("Unknown if type ") + c.InYellow(ifType)) - } - case "end": - formatted.WriteString("\n") - indent-- - writeIndent() - formatted.WriteString("end") - case "string": - if parent.Type() == "arglist" && parent.ChildCount() == 1 { - // `print"whatever"` -> `print "whatever"` - formatted.WriteString(" ") - } - formatted.WriteString(content) - case "interp_start": - fallthrough - case "interp_end": - formatted.WriteString("`") - case "interp_content": - formatted.WriteString(content) - case "interp_brace_open": - formatted.WriteString("{") - case "interp_brace_close": - formatted.WriteString("}") - case ":": - formatted.WriteString(":") - case ".": - formatted.WriteString(".") - case "(": - argType := parent.Child(1).Type() - // `print("whatever")` -> `print "whatever"` - if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" { - formatted.WriteString("(") - } else { - formatted.WriteString(" ") - } - case ")": - argType := parent.Child(1).Type() - // `print("whatever")` -> `print "whatever"` - if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" { - formatted.WriteString(")") - } - case ",": - formatted.WriteString(", ") - case "true": - formatted.WriteString("true") - case "false": - formatted.WriteString("false") - default: - panic(c.InRed("Unknown node type ") + c.InYellow(ntype)) - } - } - - var appendLeaf func(node sitter.Node) - appendLeaf = func(node sitter.Node) { - // Print only the leaf nodes - if node.ChildCount() == 0 { - writeFormatted(node) - } else { - for i := 0; i < int(node.ChildCount()); i++ { - appendLeaf(*(node.Child(i))) - } - } - } - - appendLeaf(*node) - - return formatted.String() -} - -func compatibility(sourceCode []byte, tree *sitter.Tree) string { - node := tree.RootNode() - if node.HasError() { - fmt.Println("Error parsing code") - return string(sourceCode) - } - - // watch out for passing by reference - newSource := make([]byte, len(sourceCode)) - copy(newSource, sourceCode) - - type sub struct { - node *sitter.Node - toReplace string - } - var toSub []sub - - var findExprs func(node sitter.Node) - findExprs = func(node sitter.Node) { - start := node.StartByte() - end := node.EndByte() - rand := randomString(int(end - start)) - ntype := node.Type() - - if ntype == "ifexp" || (ntype == "binexp" && node.Child(1).Type() == "//") || (ntype == "var_stmt" && node.Child(1).Type() == "//=") { - // replace unsupported statements/expressions with a random string - newSource = append(newSource[:start], append([]byte(rand), newSource[end:]...)...) - toSub = append(toSub, sub{&node, rand}) - } else { - for i := range int(node.ChildCount()) { - findExprs(*(node.Child(i))) - } - } - } - - findExprs(*node) - sourceString := string(newSource) - - for _, s := range toSub { - var replacement string - - node := s.node - ntype := node.Type() - - switch ntype { - case "ifexp": - fmt.Println(c.InBlue("replacing")) - secondCond := node.Child(3).Type() - hasElseIf := node.Child(4).Type() == "elseif" - - if !hasElseIf && map[string]bool{ - "number": true, - "string": true, - "string_interp": true, - "true": true, // lelel - }[secondCond] { - // SPECIAL CASE: the second condition is guaranteed to be truthy - // this means it can be simplified to Lua's sorta-ternary operator, a and b or c - - // doesn't simplify nested if expressions. GOOD ENOUGH - for i := range int(node.ChildCount()) { - child := node.Child(i) - fmt.Println(child.Type()) - switch child.Type() { - case "then": - replacement += "and " - case "else": - replacement += "or " - case "elseif": - panic("elseif in special case") - case "if": - // nothing - default: - replacement += child.Content(sourceCode) + " " - } - } - } else { - replacement += "(function()" - for i := range int(node.ChildCount()) { - child := node.Child(i) - fmt.Println(child.Type()) - switch child.Type() { - case "if", "elseif", "then": - replacement += child.Content(sourceCode) + " " - case "else": - replacement += "end " - default: - prev := node.Child(i - 1).Type() - if prev == "then" || prev == "else" { - replacement += "return " - } - replacement += child.Content(sourceCode) + " " - } - } - replacement += "end)()" - } - case "binexp": - // child 1 is the operator (//) - left := node.Child(0).Content(sourceCode) - right := node.Child(2).Content(sourceCode) - replacement = fmt.Sprintf("math.floor(%s/%s)", left, right) - case "var_stmt": - // child 1 is the operator (//=) - left := node.Child(0).Content(sourceCode) - right := node.Child(2).Content(sourceCode) - - // todo: make it so if one of the exprs is a function it don't get evaluated twice - // nah jk im not doing that - replacement = fmt.Sprintf("%s=math.floor(%s/%s)", left, left, right) - default: - panic("unhandled node type " + ntype) - } - - sourceString = strings.Replace(sourceString, s.toReplace, replacement, 1) - } - - return sourceString -} - func main() { - binding := luau.GetLuau() + args := os.Args - filename := "test.luau" - - sourceCode, err := os.ReadFile(filename) - if err != nil { - fmt.Println(err) - return + if len(args) < 2 { + Error("No command specified. Run with 'help' to see available commands.") } - code := string(sourceCode) - - // keep parsing until no changes are made lmao - for { - parser := sitter.NewParser() - parser.SetLanguage(binding) - - newSource := []byte(code) - tree, err := parser.ParseCtx(context.Background(), nil, newSource) - if err != nil { - fmt.Println(err) - return + switch strings.ToLower(args[1]) { + case "h", "help": + fmt.Println(c.InYellow("Usage")) + fmt.Println(c.InGreen(" [executable] [command] [arguments]\n")) + fmt.Println(c.InYellow("Commands")) + fmt.Println(c.InBlue(" h help") + " Shows this help message") + fmt.Println(c.InBlue(" f format [file]") + " Formats the specified file") + fmt.Println(c.InBlue(" c compatibility [file]") + " Makes the specified file compatible with Lua") + case "f", "format": + if len(args) < 3 { + Error("No file specified.") } - compatible := compatibility(newSource, tree) - // replace all ending newlines with a single newline - compatible = strings.Trim(compatible, "\n") - - if compatible == string(code) { - break + format(args[2]) + case "c", "compatibility": + if len(args) < 3 { + Error("No file specified.") } - code = compatible + compatify(args[2]) } - fmt.Println(code) } diff --git a/Luau/util.go b/Luau/util.go new file mode 100644 index 0000000..53b9290 --- /dev/null +++ b/Luau/util.go @@ -0,0 +1,31 @@ +package main + +import ( + "fmt" + "math/rand" + "os" + + c "github.com/TwiN/go-color" +) + +func Error(txt string) { + fmt.Println(c.InRed("Error: ") + txt) + os.Exit(1) +} + +func Assert(err error, txt string) { + // so that I don't have to write this every time + if err != nil { + fmt.Println(err) + Error(txt) + } +} + +func randomString(length int) string { + // generate random unicode string + var str string + for i := 0; i < length; i++ { + str += string(rune(rand.Intn(0x7E-0x21) + 0x21)) + } + return str +}