Gim/internal/syntax/treesitter.go
2026-04-07 11:01:07 -07:00

521 lines
12 KiB
Go

package syntax
import (
"bytes"
"sort"
"strings"
"git.gophernest.net/azpect/TextEditor/internal/core"
"git.gophernest.net/azpect/TextEditor/internal/style"
"github.com/charmbracelet/lipgloss"
sitter "github.com/tree-sitter/go-tree-sitter"
)
// TreeSitterEngine: Syntax highlighting engine backed by tree-sitter. It keeps one
// bufferCache per buffer and hands out per-rune styles for individual lines.
type TreeSitterEngine struct {
// Style set used for highlighting; styles.LineStyle is the base style every rune
// starts with before captures are painted on top.
styles style.Styles
// Resolves a buffer's filetype/filename to a language. Re-created lazily by
// resolveBufferLanguage if it is nil.
registry *languageRegistry
// Per-buffer highlighting state, keyed by buffer pointer. Entries are created on
// demand by getCache.
cache map[*core.Buffer]*bufferCache
}
// bufferCache: Highlighting state the engine keeps for a single buffer.
type bufferCache struct {
// True when lines is ready to be read without any further work.
built bool
// Cached styles per line index, one style per rune of that line.
lines map[int][]lipgloss.Style
// Line count of the buffer at the time of the last completed build. PrepareBuffer
// compares it against the current line count to detect structural changes.
count int
// Parser for this buffer, created lazily and configured with language.
parser *sitter.Parser
// Most recent syntax tree, or nil if nothing has been parsed yet.
tree *sitter.Tree
// Source bytes that tree was parsed from.
source []byte
// When true, the next build restyles every line.
dirtyAll bool
// Line ranges whose styles must be recomputed on the next build.
dirty []lineRange
// Identifier of the language currently applied to this cache.
langID string
// Language handed to the parser.
language *sitter.Language
// Query whose captures are turned into styles.
query *sitter.Query
}
// lineRange: A span of line indices. Both ends are inclusive.
type lineRange struct {
// First line of the span.
start int
// Last line of the span (inclusive).
end int
}
// captureRange: Position and name of one query capture, copied out of the captured
// node's start and end positions by collectCaptures.
type captureRange struct {
// Row on which the captured node starts.
startRow uint
// Start column on startRow; buildFullBuffer treats it as a byte offset into the line.
startCol uint
// Row on which the captured node ends.
endRow uint
// End column on endRow; treated as an exclusive byte offset when styles are painted.
endCol uint
// Capture name from the query; passed to style.CaptureStyle to pick the style.
name string
}
// NewTreeSitterEngine: Builds a tree-sitter engine that highlights with the
// provided styles.
//
// The engine starts with a fresh language registry and an empty per-buffer
// cache. Which languages can be highlighted is decided by that registry.
func NewTreeSitterEngine(styles style.Styles) *TreeSitterEngine {
	engine := new(TreeSitterEngine)
	engine.styles = styles
	engine.registry = newLanguageRegistry()
	engine.cache = make(map[*core.Buffer]*bufferCache)
	return engine
}
// TreeSitterEngine.PrepareBuffer: Makes sure the cached styles for the buffer are ready to
// be read. If the buffer provided is nil, this function does nothing.
//
// A buffer whose language cannot be resolved gets an empty style cache and is marked as
// built, so callers fall back to their own default styling.
func (e *TreeSitterEngine) PrepareBuffer(buf *core.Buffer) {
	if buf == nil {
		return
	}

	cache := e.getCache(buf)

	// A line count that differs from the last build invalidates every cached line.
	if buf.LineCount() != cache.count {
		cache.dirtyAll = true
	}
	if cache.dirtyAll {
		cache.built = false
	}
	if cache.built {
		// Nothing changed since the last build.
		return
	}

	_, supported, err := e.resolveBufferLanguage(buf, cache)
	if err == nil && supported {
		e.buildFullBuffer(buf, cache)
		return
	}

	// Unsupported buffer (or the lookup failed): publish an empty cache.
	// NOTE(review): this path does not touch count or dirtyAll, so it appears to be
	// taken again on the next call whenever they still differ - confirm that is intended.
	cache.lines = make(map[int][]lipgloss.Style)
	cache.built = true
}
// TreeSitterEngine.LineStyleMap: Returns the styles for one line of the buffer, one entry
// per rune. Returns nil if the buffer provided is nil.
//
// The buffer is prepared first. If the cache holds no entry for the line afterwards, the
// line is filled with the base line style and that result is stored in the cache.
func (e *TreeSitterEngine) LineStyleMap(buf *core.Buffer, line int) []lipgloss.Style {
	if buf == nil {
		return nil
	}

	e.PrepareBuffer(buf)
	cache := e.getCache(buf)

	rowStyles, cached := cache.lines[line]
	if !cached {
		rowStyles = make([]lipgloss.Style, len([]rune(buf.Line(line))))
		for i := range rowStyles {
			rowStyles[i] = e.styles.LineStyle
		}
		cache.lines[line] = rowStyles
	}
	return rowStyles
}
// TreeSitterEngine.ApplyEdit: Applies a single buffer edit to the cached syntax tree,
// re-parses the buffer using the edited tree as the base, and records which lines need
// their styles recomputed. If the buffer or the edit is nil, this function does nothing.
//
// Whenever the incremental path cannot be used (unsupported language, no previous tree or
// source, or the parser returns no tree) the whole buffer is marked dirty instead.
func (e *TreeSitterEngine) ApplyEdit(buf *core.Buffer, edit *core.BufferEdit) {
	if buf == nil || edit == nil {
		return
	}

	cache := e.getCache(buf)
	if _, supported, err := e.resolveBufferLanguage(buf, cache); err != nil || !supported {
		cache.built = false
		cache.dirtyAll = true
		return
	}

	if cache.parser == nil {
		cache.parser = sitter.NewParser()
		cache.parser.SetLanguage(cache.language)
	}

	// Without a previous tree and source there is nothing to edit incrementally.
	if cache.tree == nil || len(cache.source) == 0 {
		cache.dirtyAll = true
		return
	}

	// Tell the old tree what moved so it can serve as the base for the re-parse.
	inputEdit := sitter.InputEdit{
		StartByte:      edit.StartByte,
		OldEndByte:     edit.OldEndByte,
		NewEndByte:     edit.NewEndByte,
		StartPosition:  sitter.NewPoint(edit.StartPoint.Row, edit.StartPoint.Column),
		OldEndPosition: sitter.NewPoint(edit.OldEndPoint.Row, edit.OldEndPoint.Column),
		NewEndPosition: sitter.NewPoint(edit.NewEndPoint.Row, edit.NewEndPoint.Column),
	}
	cache.tree.Edit(&inputEdit)

	updatedSource := buildBufferSource(buf)
	updatedTree := cache.parser.Parse(updatedSource, cache.tree)
	if updatedTree == nil {
		cache.dirtyAll = true
		return
	}

	changedRanges := cache.tree.ChangedRanges(updatedTree)
	if buf.LineCount() != cache.count {
		// Lines were added or removed, so cached line indices no longer line up.
		cache.dirtyAll = true
		cache.dirty = nil
	} else {
		// Same number of lines: only the edited rows and the rows the parser reports
		// as changed need new styles.
		firstRow := int(edit.StartPoint.Row)
		lastRow := int(max(edit.OldEndPoint.Row, edit.NewEndPoint.Row))
		addDirtyRange(cache, firstRow, lastRow)
		for _, changed := range changedRanges {
			addDirtyRange(cache, int(changed.StartPoint.Row), int(changed.EndPoint.Row))
		}
	}

	// Swap in the new parse result and release the old tree.
	cache.source = updatedSource
	cache.tree.Close()
	cache.tree = updatedTree
	cache.built = false
}
// TreeSitterEngine.InvalidateBuffer: Marks the entire cache of the buffer as stale so the
// next prepare restyles every line. Pending per-line dirty ranges are dropped, since the
// full rebuild covers them. If the buffer provided is nil, this function does nothing.
func (e *TreeSitterEngine) InvalidateBuffer(buf *core.Buffer) {
	if buf == nil {
		return
	}

	cache := e.getCache(buf)
	cache.dirty = nil
	cache.dirtyAll = true
	cache.built = false
}
// TreeSitterEngine.InvalidateLines: Marks the lines between start and end (inclusive) as
// dirty in the buffers cache, then marks the cache as "unbuilt" so the next prepare
// recomputes them. If the buffer provided is nil, this function does nothing.
func (e *TreeSitterEngine) InvalidateLines(buf *core.Buffer, startLine, endLine int) {
	if buf == nil {
		return
	}

	cache := e.getCache(buf)
	cache.built = false
	addDirtyRange(cache, startLine, endLine)
}
// TreeSitterEngine.resolveBufferLanguage: Looks up the language for the buffer in the
// registry (by buf.Filetype and buf.Filename) and syncs the result into the buffers cache.
//
// Returns the resolved language and true when the buffer can be parsed and highlighted by
// the engine. When the lookup fails or the buffer is unsupported, the language is nil and
// the registry's ok/err values are passed through; there should be a fallback.
//
// When the resolved language differs from the one the cache was built with, the cache is
// switched over: the parser (if any) gets the new language, the previous syntax tree and
// its source are discarded, and the whole buffer is marked dirty.
func (e *TreeSitterEngine) resolveBufferLanguage(buf *core.Buffer, bc *bufferCache) (*resolvedLanguage, bool, error) {
	if e.registry == nil {
		e.registry = newLanguageRegistry()
	}

	resolved, ok, err := e.registry.resolve(buf.Filetype, buf.Filename)
	if err != nil || !ok {
		return nil, ok, err
	}
	if bc.langID == resolved.id {
		return resolved, true, nil
	}

	bc.langID = resolved.id
	bc.language = resolved.language
	bc.query = resolved.query
	if bc.parser != nil {
		bc.parser.SetLanguage(bc.language)
	}

	// The cached tree was produced by the previous language's grammar. Running the new
	// query against it, or using it as the base of an incremental parse, is not
	// meaningful, so drop it and force a parse from scratch.
	if bc.tree != nil {
		bc.tree.Close()
		bc.tree = nil
	}
	bc.source = nil

	bc.dirtyAll = true
	bc.built = false
	return resolved, true, nil
}
// TreeSitterEngine.getCache: Returns the cache belonging to the buffer. A buffer that has
// no cache yet gets a new one with an empty line map, which is stored in the engines cache
// map before it is returned.
func (e *TreeSitterEngine) getCache(buf *core.Buffer) *bufferCache {
	cached, found := e.cache[buf]
	if !found {
		cached = &bufferCache{lines: make(map[int][]lipgloss.Style)}
		e.cache[buf] = cached
	}
	return cached
}
// TreeSitterEngine.buildFullBuffer: Recomputes the cached styles of the buffer from its
// syntax tree and marks the cache as built.
//
// Every line is restyled when the cache is flagged dirtyAll, holds no lines, or has no
// dirty ranges recorded. Otherwise only the recorded dirty ranges are restyled. The steps:
//
//  1. Reset the affected lines to the base line style.
//  2. Make sure bc.tree matches the current buffer text, parsing if needed.
//  3. Collect the query captures (whole tree, or one row of context around each dirty range).
//  4. Paint the captures over the affected lines; later captures overwrite earlier ones.
//
// If the parser returns no tree, the lines keep their base styles and the cache is marked
// built without clearing the dirty state.
func (e *TreeSitterEngine) buildFullBuffer(buf *core.Buffer, bc *bufferCache) {
	lineCount := buf.LineCount()

	// Load the lines into memory. There is no method for this due to the buffers
	// internal implementation using a gap buffer.
	lines := make([]string, lineCount)
	for i := range lineCount {
		lines[i] = buf.Line(i)
	}

	fullRebuild := bc.dirtyAll || len(bc.lines) == 0 || len(bc.dirty) == 0

	// Clamped and merged dirty ranges. Computed once; only used for partial rebuilds.
	var dirty []lineRange
	if fullRebuild {
		bc.lines = make(map[int][]lipgloss.Style, lineCount)
		for i, line := range lines {
			bc.lines[i] = defaultLineStyles(line, e.styles.LineStyle)
		}
	} else {
		dirty = normalizedDirtyRanges(bc.dirty, lineCount)
		for _, r := range dirty {
			for i := r.start; i <= r.end; i++ {
				bc.lines[i] = defaultLineStyles(lines[i], e.styles.LineStyle)
			}
		}
	}

	source := buildBufferSource(buf)
	if bc.parser == nil {
		bc.parser = sitter.NewParser()
		bc.parser.SetLanguage(bc.language)
	}

	if bc.tree == nil || !bytes.Equal(bc.source, source) {
		// The cached tree does not describe the current text. Nothing in this function
		// edits the tree to match, and tree-sitter only accepts an old tree as the
		// incremental base once it has been edited to match the new source. Parse from
		// scratch rather than risk reusing nodes at the wrong offsets.
		tree := bc.parser.Parse(source, nil)
		if tree == nil {
			bc.built = true
			return
		}
		if bc.tree != nil {
			bc.tree.Close()
		}
		bc.tree = tree
		bc.source = source
	}

	root := bc.tree.RootNode()

	var captures []captureRange
	if fullRebuild {
		cursor := sitter.NewQueryCursor()
		iter := cursor.Captures(bc.query, root, source)
		captures = append(captures, collectCaptures(iter, bc.query)...)
		cursor.Close()
	} else {
		for _, r := range dirty {
			// Query one extra row on each side of the dirty range.
			queryStart := max(0, r.start-1)
			queryEnd := min(lineCount-1, r.end+1)
			cursor := sitter.NewQueryCursor()
			cursor.SetPointRange(
				sitter.NewPoint(uint(queryStart), 0),
				sitter.NewPoint(uint(queryEnd+1), 0),
			)
			iter := cursor.Captures(bc.query, root, source)
			captures = append(captures, collectCaptures(iter, bc.query)...)
			cursor.Close()
		}
	}

	// Order captures by start position; for equal starts the longer capture comes first
	// so that captures nested inside it are painted on top. The sort is stable so that
	// captures covering the exact same range keep the order they were collected in, which
	// makes "the last capture wins" well defined.
	sort.SliceStable(captures, func(i, j int) bool {
		a, b := captures[i], captures[j]
		switch {
		case a.startRow != b.startRow:
			return a.startRow < b.startRow
		case a.startCol != b.startCol:
			return a.startCol < b.startCol
		case a.endRow != b.endRow:
			return a.endRow > b.endRow
		default:
			return a.endCol > b.endCol
		}
	})

	// Paint every capture over the rows it touches. Overlapping captures simply rewrite
	// the same runes, so the last capture in sort order wins.
	for _, c := range captures {
		sty := style.CaptureStyle(e.styles.LineStyle, c.name)
		for row := c.startRow; row <= c.endRow; row++ {
			if int(row) >= len(lines) {
				break
			}
			// Partial rebuilds must not touch rows outside of the dirty ranges.
			if !fullRebuild && !rowInRanges(int(row), dirty) {
				continue
			}

			// Capture columns are byte offsets; the style slice is indexed by rune.
			lineBytes := []byte(lines[row])
			startByteCol := uint(0)
			if row == c.startRow {
				startByteCol = c.startCol
			}
			endByteCol := uint(len(lineBytes))
			if row == c.endRow {
				endByteCol = min(c.endCol, uint(len(lineBytes)))
			}

			rowStyles := bc.lines[int(row)]
			startRune := max(byteColToRuneIndex(lineBytes, int(startByteCol)), 0)
			endRune := min(byteColToRuneIndex(lineBytes, int(endByteCol)), len(rowStyles))
			if startRune >= endRune {
				continue
			}
			for i := startRune; i < endRune; i++ {
				rowStyles[i] = sty
			}
			bc.lines[int(row)] = rowStyles
		}
	}

	bc.dirtyAll = false
	bc.dirty = nil
	bc.count = lineCount
	bc.built = true
}
// addDirtyRange: Records the lines between start and end (inclusive) as dirty on the
// cache. The bounds may be given in either order and negative values are raised to 0.
// The stored ranges are merged after every insert. A nil cache is ignored.
func addDirtyRange(bc *bufferCache, start, end int) {
	if bc == nil {
		return
	}

	lo, hi := start, end
	if hi < lo {
		lo, hi = hi, lo
	}
	lo = max(lo, 0)
	hi = max(hi, 0)

	bc.dirty = mergeRanges(append(bc.dirty, lineRange{start: lo, end: hi}))
}
// normalizedDirtyRanges: Returns the given ranges clamped to the valid line indices
// [0, lineCount-1] and merged. Ranges that fall entirely outside of the buffer are
// dropped. Returns nil when there are no ranges or the buffer has no lines. The input
// slice is not modified.
func normalizedDirtyRanges(ranges []lineRange, lineCount int) []lineRange {
	if len(ranges) == 0 || lineCount <= 0 {
		return nil
	}

	lastLine := lineCount - 1
	inBounds := make([]lineRange, 0, len(ranges))
	for _, r := range ranges {
		lo, hi := max(r.start, 0), min(r.end, lastLine)
		if lo <= hi {
			inBounds = append(inBounds, lineRange{start: lo, end: hi})
		}
	}
	return mergeRanges(inBounds)
}
// mergeRanges: Returns the ranges sorted by start (then end) with overlapping and directly
// adjacent ranges collapsed into one. Returns nil for an empty input.
//
// The input slice is sorted in place; the merged result is a new slice.
func mergeRanges(ranges []lineRange) []lineRange {
	if len(ranges) == 0 {
		return nil
	}

	sort.Slice(ranges, func(a, b int) bool {
		if ranges[a].start != ranges[b].start {
			return ranges[a].start < ranges[b].start
		}
		return ranges[a].end < ranges[b].end
	})

	out := make([]lineRange, 0, len(ranges))
	out = append(out, ranges[0])
	for _, next := range ranges[1:] {
		last := &out[len(out)-1]
		if next.start > last.end+1 {
			// There is a gap of at least one line, so a new range starts here.
			out = append(out, next)
			continue
		}
		last.end = max(last.end, next.end)
	}
	return out
}
// rowInRanges: Reports whether the row lies inside at least one of the ranges. Both ends
// of a range are inclusive.
func rowInRanges(row int, ranges []lineRange) bool {
	for i := range ranges {
		if lo, hi := ranges[i].start, ranges[i].end; lo <= row && row <= hi {
			return true
		}
	}
	return false
}
// defaultLineStyles: Returns one copy of the base style for every rune in the line.
func defaultLineStyles(line string, base lipgloss.Style) []lipgloss.Style {
	runeCount := len([]rune(line))
	styles := make([]lipgloss.Style, 0, runeCount)
	for len(styles) < runeCount {
		styles = append(styles, base)
	}
	return styles
}
// collectCaptures: Drains the capture iterator and returns the position and name of every
// capture it yields. Returns nil if the query provided is nil.
//
// Captures whose index has no matching name in the query are skipped, as are captures
// named "spell".
func collectCaptures(iter sitter.QueryCaptures, query *sitter.Query) []captureRange {
	if query == nil {
		return nil
	}

	captureNames := query.CaptureNames()
	found := make([]captureRange, 0)
	for {
		match, idx := iter.Next()
		if match == nil {
			break
		}

		capture := match.Captures[idx]
		if int(capture.Index) >= len(captureNames) {
			continue
		}
		captureName := captureNames[capture.Index]
		if captureName == "spell" {
			continue
		}

		node := capture.Node
		from, to := node.StartPosition(), node.EndPosition()
		found = append(found, captureRange{
			startRow: from.Row,
			startCol: from.Column,
			endRow:   to.Row,
			endCol:   to.Column,
			name:     captureName,
		})
	}
	return found
}
// buildBufferSource: Returns the text of the buffer as bytes, with the lines separated by
// a single "\n" and no trailing newline. A buffer without lines yields an empty slice.
func buildBufferSource(buf *core.Buffer) []byte {
	total := buf.LineCount()
	if total == 0 {
		return []byte{}
	}

	var text strings.Builder
	for i := 0; i < total; i++ {
		if i > 0 {
			text.WriteByte('\n')
		}
		text.WriteString(buf.Line(i))
	}
	return []byte(text.String())
}
// byteColToRuneIndex: Converts a byte offset within the line into a rune index, i.e. the
// number of runes that come before that offset. Offsets at or below zero map to 0 and
// offsets at or past the end of the line map to the rune count of the whole line.
func byteColToRuneIndex(line []byte, byteCol int) int {
	if byteCol <= 0 {
		return 0
	}

	prefix := line
	if byteCol < len(line) {
		prefix = line[:byteCol]
	}

	// Ranging over a string advances one rune per iteration.
	count := 0
	for range string(prefix) {
		count++
	}
	return count
}