Luau code formatter with tree-sitter

Lewin Kelly 2024-04-06 14:02:49 +01:00
parent 9b2c5b4870
commit 97d14d4a23
7 changed files with 28940 additions and 0 deletions

15
Luau/binding/binding.go Normal file

@@ -0,0 +1,15 @@
package luau
//#include "tree_sitter/parser.h"
//TSLanguage *tree_sitter_luau();
import "C"
import (
"unsafe"
sitter "github.com/smacker/go-tree-sitter"
)
// GetLuau returns the tree-sitter Language for Luau, backed by the
// C parser and external scanner compiled into this package.
func GetLuau() *sitter.Language {
ptr := unsafe.Pointer(C.tree_sitter_luau())
return sitter.NewLanguage(ptr)
}
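
The binding is consumed like any other go-tree-sitter grammar; main.go further down does exactly this. A minimal standalone sketch (the source snippet "local x = 1" is just an assumed example input):

package main

import (
	luau "Luau/binding"
	"context"
	"fmt"

	sitter "github.com/smacker/go-tree-sitter"
)

func main() {
	parser := sitter.NewParser()
	parser.SetLanguage(luau.GetLuau())
	src := []byte("local x = 1")
	// ParseCtx builds a syntax tree for the source using the Luau language.
	tree, err := parser.ParseCtx(context.Background(), nil, src)
	if err != nil {
		fmt.Println(err)
		return
	}
	// Print the parse tree as an S-expression for inspection.
	fmt.Println(tree.RootNode().String())
}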

27977
Luau/binding/parser.c Normal file

File diff suppressed because it is too large.

442
Luau/binding/scanner.c Normal file

@@ -0,0 +1,442 @@
#include <tree_sitter/parser.h>
#include <wctype.h>
enum TokenType
{
COMMENT_START,
COMMENT_CONTENT,
COMMENT_END,
STRING_START,
STRING_CONTENT,
STRING_END,
INTERP_START,
INTERP_CONTENT,
INTERP_BRACE_OPEN,
INTERP_BRACE_CLOSE,
INTERP_END
};
static void consume(TSLexer *lexer)
{
lexer->advance(lexer, false);
}
static void skip(TSLexer *lexer)
{
lexer->advance(lexer, true);
}
static bool consume_if(TSLexer *lexer, const int32_t character)
{
if (lexer->lookahead == character)
{
consume(lexer);
return true;
}
return false;
}
static bool skipwspace(TSLexer *lexer)
{
if (iswspace(lexer->lookahead) || lexer->lookahead == '\r')
{
skip(lexer);
return true;
}
return false;
}
const char SQ_STRING_DELIMITER = '\'';
const char DQ_STRING_DELIMITER = '"';
const char TICK_DELIMITER = '`';
enum StartedToken
{
SHORT_COMMENT = 1,
SHORT_SQ_STRING,
SHORT_DQ_STRING,
LONG_COMMENT,
LONG_STRING,
TICK_STRING,
INTERP_EXPRESSION,
};
struct ScannerState
{
enum StartedToken started;
unsigned int depth;
unsigned int idepth;
};
void *tree_sitter_luau_external_scanner_create()
{
// allocate the scanner state and explicitly zero every field
// (malloc leaves the memory uninitialised)
struct ScannerState *state = malloc(sizeof(struct ScannerState));
state->started = 0;
state->depth = 0;
state->idepth = 0;
return state;
}
void tree_sitter_luau_external_scanner_destroy(void *payload)
{
free(payload);
}
unsigned int tree_sitter_luau_external_scanner_serialize(void *payload, char *buffer)
{
struct ScannerState *state = payload;
buffer[0] = state->started;
buffer[1] = state->depth;
buffer[2] = state->idepth;
return 3;
}
void tree_sitter_luau_external_scanner_deserialize(void *payload, const char *buffer, unsigned int length)
{
if (length == 3)
{
struct ScannerState *state = payload;
state->started = buffer[0];
state->depth = buffer[1];
state->idepth = buffer[2];
}
}
static unsigned int get_depth(TSLexer *lexer)
{
unsigned int current_depth = 0;
while (consume_if(lexer, '='))
{
current_depth += 1;
}
return current_depth;
}
static bool scan_depth(TSLexer *lexer, unsigned int remaining_depth)
{
while (remaining_depth > 0 && consume_if(lexer, '='))
{
remaining_depth -= 1;
}
return remaining_depth == 0;
}
static bool escape_handler(TSLexer *lexer)
{
if (consume_if(lexer, '\\') && !lexer->eof(lexer))
{
if (lexer->lookahead == '\r')
{
skip(lexer);
if (!lexer->eof(lexer) && lexer->lookahead == '\n')
{
skip(lexer);
}
}
else if (lexer->lookahead == '\n')
{
skip(lexer);
}
else if (consume_if(lexer, 'z') && !lexer->eof(lexer))
{
while (skipwspace(lexer) && !lexer->eof(lexer))
;
return true;
}
}
return false;
}
bool tree_sitter_luau_external_scanner_scan(void *payload, TSLexer *lexer, const bool *valid_symbols)
{
struct ScannerState *state = payload;
switch (state->started)
{
case SHORT_COMMENT:
{
// try to match the short comment's end (new line or eof)
if (lexer->lookahead == '\n' || lexer->eof(lexer))
{
if (valid_symbols[COMMENT_END])
{
state->started = state->idepth > 0 ? INTERP_EXPRESSION : 0;
lexer->result_symbol = COMMENT_END;
return true;
}
}
else if (valid_symbols[COMMENT_CONTENT])
{
// consume all characters till a short comment's end
do
{
consume(lexer);
} while (lexer->lookahead != '\n' && !lexer->eof(lexer));
lexer->result_symbol = COMMENT_CONTENT;
return true;
}
break;
}
case SHORT_SQ_STRING:
case SHORT_DQ_STRING:
{
// define the short string's delimiter
const char delimiter = state->started == SHORT_SQ_STRING ? SQ_STRING_DELIMITER : DQ_STRING_DELIMITER;
// try to match the short string's end (" or ')
if (consume_if(lexer, delimiter))
{
if (valid_symbols[STRING_END])
{
state->started = state->idepth > 0 ? INTERP_EXPRESSION : 0;
lexer->result_symbol = STRING_END;
return true;
}
}
else if (lexer->lookahead != '\n' && !lexer->eof(lexer))
{
if (valid_symbols[STRING_CONTENT])
{
// consume any character till a short string's end, new line or eof
do
{
escape_handler(lexer);
consume(lexer);
} while (lexer->lookahead != delimiter && lexer->lookahead != '\n' && !lexer->eof(lexer));
lexer->result_symbol = STRING_CONTENT;
return true;
}
}
break;
}
case TICK_STRING:
{
const char delimiter = TICK_DELIMITER;
if (consume_if(lexer, delimiter))
{
if (valid_symbols[INTERP_END])
{
state->started = state->idepth > 0 ? INTERP_EXPRESSION : 0;
lexer->result_symbol = INTERP_END;
return true;
}
}
else if (consume_if(lexer, '{'))
{
state->idepth++;
state->started = INTERP_EXPRESSION;
lexer->result_symbol = INTERP_BRACE_OPEN;
return true;
}
else if (lexer->lookahead != '\n' && !lexer->eof(lexer))
{
if (valid_symbols[INTERP_CONTENT])
{
do
{
if (lexer->lookahead == '{')
{
break;
}
else
{
if (escape_handler(lexer))
continue;
}
consume(lexer);
} while (lexer->lookahead != delimiter && lexer->lookahead != '\n' && !lexer->eof(lexer));
lexer->result_symbol = INTERP_CONTENT;
return true;
}
}
break;
}
case LONG_COMMENT:
case LONG_STRING:
{
const bool is_inside_a_comment = state->started == LONG_COMMENT;
bool some_characters_were_consumed = false;
if (is_inside_a_comment ? valid_symbols[COMMENT_END] : valid_symbols[STRING_END])
{
// try to match the long comment's/string's end (]=*])
if (consume_if(lexer, ']'))
{
if (scan_depth(lexer, state->depth) && consume_if(lexer, ']'))
{
state->started = state->idepth > 0 ? INTERP_EXPRESSION : 0;
state->depth = 0;
lexer->result_symbol = is_inside_a_comment ? COMMENT_END : STRING_END;
return true;
}
some_characters_were_consumed = true;
}
}
if (is_inside_a_comment ? valid_symbols[COMMENT_CONTENT] : valid_symbols[STRING_CONTENT])
{
if (!some_characters_were_consumed)
{
if (lexer->eof(lexer))
{
break;
}
// consume the next character as it can't start a long comment's/string's end (])
consume(lexer);
}
// consume any character till a long comment's/string's end or eof
while (true)
{
lexer->mark_end(lexer);
if (consume_if(lexer, ']'))
{
if (scan_depth(lexer, state->depth))
{
if (consume_if(lexer, ']'))
{
break;
}
}
else
{
continue;
}
}
if (lexer->eof(lexer))
{
break;
}
consume(lexer);
}
lexer->result_symbol = is_inside_a_comment ? COMMENT_CONTENT : STRING_CONTENT;
return true;
}
break;
}
case INTERP_EXPRESSION:
{
while (skipwspace(lexer))
;
if (valid_symbols[INTERP_BRACE_CLOSE])
{
if (consume_if(lexer, '}'))
{
state->idepth--;
state->started = TICK_STRING;
lexer->result_symbol = INTERP_BRACE_CLOSE;
return true;
}
}
}
default:
{
// ignore all whitespace
while (skipwspace(lexer))
;
state->started = 0;
if (valid_symbols[COMMENT_START])
{
// try to match a short comment's start (--)
if (consume_if(lexer, '-'))
{
if (consume_if(lexer, '-'))
{
state->started = SHORT_COMMENT;
// try to match a long comment's start (--[=*[)
lexer->mark_end(lexer);
if (consume_if(lexer, '['))
{
unsigned int possible_depth = get_depth(lexer);
if (consume_if(lexer, '['))
{
state->started = LONG_COMMENT;
state->depth = possible_depth;
lexer->mark_end(lexer);
}
}
lexer->result_symbol = COMMENT_START;
return true;
}
break;
}
}
if (valid_symbols[STRING_START])
{
// try to match a short single-quoted string's start (')
if (consume_if(lexer, SQ_STRING_DELIMITER))
{
state->started = SHORT_SQ_STRING;
}
// try to match a short double-quoted string's start (")
else if (consume_if(lexer, DQ_STRING_DELIMITER))
{
state->started = SHORT_DQ_STRING;
}
// try to match a long string's start ([=*[)
else if (consume_if(lexer, '['))
{
unsigned int possible_depth = get_depth(lexer);
if (consume_if(lexer, '['))
{
state->started = LONG_STRING;
state->depth = possible_depth;
}
}
if (state->started)
{
lexer->result_symbol = STRING_START;
return true;
}
}
if (valid_symbols[INTERP_START])
{
if (consume_if(lexer, TICK_DELIMITER))
{
state->started = TICK_STRING;
lexer->result_symbol = INTERP_START;
return true;
}
}
state->started = state->idepth > 0 ? INTERP_EXPRESSION : 0;
break;
}
}
return false;
}
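
The external tokens above surface as the leaf node types that main.go later switches on (interp_start, interp_content, interp_brace_open, interp_brace_close, interp_end, plus the comment and string pieces). A rough Go sketch for inspecting what the scanner produces — the example source and the exact leaves it yields are assumptions, not taken from this commit:

package main

import (
	luau "Luau/binding"
	"context"
	"fmt"

	sitter "github.com/smacker/go-tree-sitter"
)

// printLeaves walks the tree and prints each leaf's node type and text,
// which makes the scanner's token boundaries visible.
func printLeaves(node *sitter.Node, src []byte) {
	if node.ChildCount() == 0 {
		fmt.Printf("%-20s %q\n", node.Type(), node.Content(src))
		return
	}
	for i := 0; i < int(node.ChildCount()); i++ {
		printLeaves(node.Child(i), src)
	}
}

func main() {
	// A backtick string with an interpolated expression exercises the
	// TICK_STRING / INTERP_EXPRESSION states above (assumed example input).
	src := []byte("local greeting = `hello {name}`")
	parser := sitter.NewParser()
	parser.SetLanguage(luau.GetLuau())
	tree, err := parser.ParseCtx(context.Background(), nil, src)
	if err != nil {
		fmt.Println(err)
		return
	}
	printLeaves(tree.RootNode(), src)
}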

224
Luau/binding/tree_sitter/parser.h Normal file

@@ -0,0 +1,224 @@
#ifndef TREE_SITTER_PARSER_H_
#define TREE_SITTER_PARSER_H_
#ifdef __cplusplus
extern "C" {
#endif
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#define ts_builtin_sym_error ((TSSymbol)-1)
#define ts_builtin_sym_end 0
#define TREE_SITTER_SERIALIZATION_BUFFER_SIZE 1024
typedef uint16_t TSStateId;
#ifndef TREE_SITTER_API_H_
typedef uint16_t TSSymbol;
typedef uint16_t TSFieldId;
typedef struct TSLanguage TSLanguage;
#endif
typedef struct {
TSFieldId field_id;
uint8_t child_index;
bool inherited;
} TSFieldMapEntry;
typedef struct {
uint16_t index;
uint16_t length;
} TSFieldMapSlice;
typedef struct {
bool visible;
bool named;
bool supertype;
} TSSymbolMetadata;
typedef struct TSLexer TSLexer;
struct TSLexer {
int32_t lookahead;
TSSymbol result_symbol;
void (*advance)(TSLexer *, bool);
void (*mark_end)(TSLexer *);
uint32_t (*get_column)(TSLexer *);
bool (*is_at_included_range_start)(const TSLexer *);
bool (*eof)(const TSLexer *);
};
typedef enum {
TSParseActionTypeShift,
TSParseActionTypeReduce,
TSParseActionTypeAccept,
TSParseActionTypeRecover,
} TSParseActionType;
typedef union {
struct {
uint8_t type;
TSStateId state;
bool extra;
bool repetition;
} shift;
struct {
uint8_t type;
uint8_t child_count;
TSSymbol symbol;
int16_t dynamic_precedence;
uint16_t production_id;
} reduce;
uint8_t type;
} TSParseAction;
typedef struct {
uint16_t lex_state;
uint16_t external_lex_state;
} TSLexMode;
typedef union {
TSParseAction action;
struct {
uint8_t count;
bool reusable;
} entry;
} TSParseActionEntry;
struct TSLanguage {
uint32_t version;
uint32_t symbol_count;
uint32_t alias_count;
uint32_t token_count;
uint32_t external_token_count;
uint32_t state_count;
uint32_t large_state_count;
uint32_t production_id_count;
uint32_t field_count;
uint16_t max_alias_sequence_length;
const uint16_t *parse_table;
const uint16_t *small_parse_table;
const uint32_t *small_parse_table_map;
const TSParseActionEntry *parse_actions;
const char * const *symbol_names;
const char * const *field_names;
const TSFieldMapSlice *field_map_slices;
const TSFieldMapEntry *field_map_entries;
const TSSymbolMetadata *symbol_metadata;
const TSSymbol *public_symbol_map;
const uint16_t *alias_map;
const TSSymbol *alias_sequences;
const TSLexMode *lex_modes;
bool (*lex_fn)(TSLexer *, TSStateId);
bool (*keyword_lex_fn)(TSLexer *, TSStateId);
TSSymbol keyword_capture_token;
struct {
const bool *states;
const TSSymbol *symbol_map;
void *(*create)(void);
void (*destroy)(void *);
bool (*scan)(void *, TSLexer *, const bool *symbol_whitelist);
unsigned (*serialize)(void *, char *);
void (*deserialize)(void *, const char *, unsigned);
} external_scanner;
const TSStateId *primary_state_ids;
};
/*
* Lexer Macros
*/
#define START_LEXER() \
bool result = false; \
bool skip = false; \
bool eof = false; \
int32_t lookahead; \
goto start; \
next_state: \
lexer->advance(lexer, skip); \
start: \
skip = false; \
lookahead = lexer->lookahead;
#define ADVANCE(state_value) \
{ \
state = state_value; \
goto next_state; \
}
#define SKIP(state_value) \
{ \
skip = true; \
state = state_value; \
goto next_state; \
}
#define ACCEPT_TOKEN(symbol_value) \
result = true; \
lexer->result_symbol = symbol_value; \
lexer->mark_end(lexer);
#define END_STATE() return result;
/*
* Parse Table Macros
*/
#define SMALL_STATE(id) id - LARGE_STATE_COUNT
#define STATE(id) id
#define ACTIONS(id) id
#define SHIFT(state_value) \
{{ \
.shift = { \
.type = TSParseActionTypeShift, \
.state = state_value \
} \
}}
#define SHIFT_REPEAT(state_value) \
{{ \
.shift = { \
.type = TSParseActionTypeShift, \
.state = state_value, \
.repetition = true \
} \
}}
#define SHIFT_EXTRA() \
{{ \
.shift = { \
.type = TSParseActionTypeShift, \
.extra = true \
} \
}}
#define REDUCE(symbol_val, child_count_val, ...) \
{{ \
.reduce = { \
.type = TSParseActionTypeReduce, \
.symbol = symbol_val, \
.child_count = child_count_val, \
__VA_ARGS__ \
}, \
}}
#define RECOVER() \
{{ \
.type = TSParseActionTypeRecover \
}}
#define ACCEPT_INPUT() \
{{ \
.type = TSParseActionTypeAccept \
}}
#ifdef __cplusplus
}
#endif
#endif // TREE_SITTER_PARSER_H_

8
Luau/go.mod Normal file

@@ -0,0 +1,8 @@
module Luau
go 1.22.1
require (
github.com/TwiN/go-color v1.4.1
github.com/smacker/go-tree-sitter v0.0.0-20240402012804-99ab967cf9b9
)

14
Luau/go.sum Normal file

@@ -0,0 +1,14 @@
github.com/TwiN/go-color v1.4.1 h1:mqG0P/KBgHKVqmtL5ye7K0/Gr4l6hTksPgTgMk3mUzc=
github.com/TwiN/go-color v1.4.1/go.mod h1:WcPf/jtiW95WBIsEeY1Lc/b8aaWoiqQpu5cf8WFxu+s=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/smacker/go-tree-sitter v0.0.0-20240402012804-99ab967cf9b9 h1:5HSGLeLdHwoLEEr794DtHfFD67aF4rPLLQFfbVvEF2w=
github.com/smacker/go-tree-sitter v0.0.0-20240402012804-99ab967cf9b9/go.mod h1:q99oHDsbP0xRwmn7Vmob8gbSMNyvJ83OauXPSuHQuKE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.4/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

260
Luau/main.go Normal file

@@ -0,0 +1,260 @@
package main
import (
luau "Luau/binding"
"context"
"fmt"
"os"
"strings"
c "github.com/TwiN/go-color"
sitter "github.com/smacker/go-tree-sitter"
)
func isInside(ntype string, node *sitter.Node) bool {
for node != nil {
if node.Type() == ntype {
return true
}
node = node.Parent()
}
return false
}
func formatCode(sourceCode []byte, tree *sitter.Tree) string {
node := tree.RootNode()
if node.HasError() {
fmt.Println("Error parsing code")
return string(sourceCode)
}
var formatted strings.Builder
indent := 0
// insideLocal and insideComment record, for each leaf written so far,
// whether that leaf sat inside a local_var_stmt or a comment; the
// previous entry decides where blank lines go between statements
var insideLocal []bool
var insideComment []bool
writeIndent := func() {
formatted.WriteString(strings.Repeat("\t", indent))
}
writeFormatted := func(node sitter.Node) {
// Write the formatted content to the string builder
content := node.Content(sourceCode)
ntype := node.Type()
insideLocal = append(insideLocal, isInside("local_var_stmt", &node))
insideComment = append(insideComment, isInside("comment", &node))
parent := node.Parent()
switch ntype {
case "comment":
if len(insideComment) > 1 && !insideComment[len(insideComment)-2] {
formatted.WriteString("\n")
}
formatted.WriteString("\n")
writeIndent()
formatted.WriteString(content)
case "local":
if len(insideLocal) > 1 && !insideLocal[len(insideLocal)-2] &&
len(insideComment) > 1 && !insideComment[len(insideComment)-2] {
formatted.WriteString("\n")
}
formatted.WriteString("\n")
writeIndent()
formatted.WriteString("local ")
case "name":
if parent.Parent().Type() == "call_stmt" && !isInside("local_var_stmt", &node) {
if len(insideLocal) > 1 && insideLocal[len(insideLocal)-2] {
formatted.WriteString("\n")
}
formatted.WriteString("\n")
writeIndent()
}
formatted.WriteString(content)
case "=":
formatted.WriteString(" = ")
case "==":
formatted.WriteString(" == ")
case "number":
formatted.WriteString(content)
case "return":
formatted.WriteString("\n")
writeIndent()
formatted.WriteString("return ")
case "if":
ifType := parent.Type()
switch ifType {
case "if_stmt":
formatted.WriteString("\n")
writeIndent()
fallthrough
case "ifexp":
formatted.WriteString("if ")
default:
// damn better be unreachable
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "then":
ifType := parent.Type()
switch ifType {
case "if_stmt":
formatted.WriteString(" then")
indent++
case "ifexp":
formatted.WriteString(" then ")
case "elseif_clause":
formatted.WriteString(" then")
default:
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "elseif":
pType := parent.Type()
// if it's in a statement, the parent will be the elseif clause
// if it's in an expression, the parent will be the if expression
var ifType string
switch pType {
case "elseif_clause":
ifType = parent.Parent().Type()
case "ifexp":
ifType = pType
default:
panic(c.InRed("Unknown parent type ") + c.InYellow(pType))
}
switch ifType {
case "if_stmt":
formatted.WriteString("\n")
indent--
writeIndent()
formatted.WriteString("elseif ")
indent++
case "ifexp":
formatted.WriteString(" elseif ")
default:
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "else":
pType := parent.Type()
// if it's in a statement, the parent will be the else clause
// if it's in an expression, the parent will be the if expression
var ifType string
switch pType {
case "else_clause":
ifType = parent.Parent().Type()
case "ifexp":
ifType = pType
default:
panic(c.InRed("Unknown parent type ") + c.InYellow(pType))
}
switch ifType {
case "if_stmt":
formatted.WriteString("\n")
indent--
writeIndent()
formatted.WriteString("else")
indent++
case "ifexp":
formatted.WriteString(" else ")
default:
panic(c.InRed("Unknown if type ") + c.InYellow(ifType))
}
case "end":
formatted.WriteString("\n")
indent--
writeIndent()
formatted.WriteString("end")
case "string":
if parent.Type() == "arglist" && parent.ChildCount() == 1 {
// `print"whatever"` -> `print "whatever"`
formatted.WriteString(" ")
}
formatted.WriteString(content)
case "interp_start":
fallthrough
case "interp_end":
formatted.WriteString("`")
case "interp_content":
formatted.WriteString(content)
case "interp_brace_open":
formatted.WriteString("{")
case "interp_brace_close":
formatted.WriteString("}")
case ":":
formatted.WriteString(":")
case ".":
formatted.WriteString(".")
case "(":
argType := parent.Child(1).Type()
// `print("whatever")` -> `print "whatever"`
if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" {
formatted.WriteString("(")
} else {
formatted.WriteString(" ")
}
case ")":
argType := parent.Child(1).Type()
// `print("whatever")` -> `print "whatever"`
if parent.Type() == "arglist" && parent.ChildCount() > 3 || argType != "string" && argType != "table" {
formatted.WriteString(")")
}
case ",":
formatted.WriteString(", ")
case "true":
formatted.WriteString("true")
case "false":
formatted.WriteString("false")
default:
panic(c.InRed("Unknown node type ") + c.InYellow(ntype))
}
}
var appendLeaf func(node sitter.Node)
appendLeaf = func(node sitter.Node) {
// Print only the leaf nodes
if node.ChildCount() == 0 {
writeFormatted(node)
} else {
for i := 0; i < int(node.ChildCount()); i++ {
appendLeaf(*(node.Child(i)))
}
}
}
appendLeaf(*node)
return formatted.String()
}
func main() {
parser := sitter.NewParser()
parser.SetLanguage(luau.GetLuau())
filename := "test.luau"
sourceCode, err := os.ReadFile(filename)
if err != nil {
fmt.Println(err)
return
}
tree, _ := parser.ParseCtx(context.Background(), nil, sourceCode)
formatted := formatCode(sourceCode, tree)
// trim leading and trailing newlines, then end the file with exactly one newline
formatted = strings.Trim(formatted, "\n") + "\n"
// write back to file
err = os.WriteFile(filename, []byte(formatted), 0o644)
if err != nil {
fmt.Println(err)
return
}
}
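
main() only formats the hard-coded test.luau in place. As a sketch of how the same pipeline (parse, formatCode, newline trim) could be exposed to other Go code, here is a hypothetical helper — formatBytes is not part of this commit and is written as if it lived in this file so it can reuse the existing imports and formatCode:

// formatBytes runs the formatter over an in-memory source buffer and
// returns the normalised text, mirroring the steps in main().
func formatBytes(source []byte) (string, error) {
	parser := sitter.NewParser()
	parser.SetLanguage(luau.GetLuau())
	tree, err := parser.ParseCtx(context.Background(), nil, source)
	if err != nil {
		return "", err
	}
	formatted := formatCode(source, tree)
	// Same newline normalisation as main(): trim, then exactly one trailing newline.
	return strings.Trim(formatted, "\n") + "\n", nil
}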