package syntax import ( "bytes" "sort" "strings" "git.gophernest.net/azpect/TextEditor/internal/core" "git.gophernest.net/azpect/TextEditor/internal/style" "github.com/charmbracelet/lipgloss" sitter "github.com/tree-sitter/go-tree-sitter" ) type TreeSitterEngine struct { styles style.Styles registry *languageRegistry cache map[*core.Buffer]*bufferCache } type bufferCache struct { built bool lines map[int][]lipgloss.Style count int parser *sitter.Parser tree *sitter.Tree source []byte dirtyAll bool dirty []lineRange langID string language *sitter.Language query *sitter.Query } type lineRange struct { start int end int } type captureRange struct { startRow uint startCol uint endRow uint endCol uint name string } // NewTreeSitterEngine: Creates a new tree sitter engine with the styles // provided attached. // // Currently, this engine only support GoLang. But more languages can be // added with easy. func NewTreeSitterEngine(styles style.Styles) *TreeSitterEngine { return &TreeSitterEngine{ styles: styles, registry: newLanguageRegistry(), cache: map[*core.Buffer]*bufferCache{}, } } func (e *TreeSitterEngine) PrepareBuffer(buf *core.Buffer) { // Cannot prepare a nil buffer if buf == nil { return } // Get the buffers cache and return if we are already "built" (ready to render). bc := e.getCache(buf) if bc.count != buf.LineCount() { bc.dirtyAll = true } if bc.dirtyAll { bc.built = false } if bc.built { return } // If we do no support the buffer, load empty styles into the cache lang, ok, err := e.resolveBufferLanguage(buf, bc) if err != nil || !ok { bc.lines = map[int][]lipgloss.Style{} bc.built = true return } _ = lang e.buildFullBuffer(buf, bc) } func (e *TreeSitterEngine) LineStyleMap(buf *core.Buffer, line int) []lipgloss.Style { if buf == nil { return nil } e.PrepareBuffer(buf) bc := e.getCache(buf) if s, ok := bc.lines[line]; ok { return s } runes := []rune(buf.Line(line)) out := make([]lipgloss.Style, len(runes)) for i := range out { out[i] = e.styles.LineStyle } bc.lines[line] = out return out } func (e *TreeSitterEngine) ApplyEdit(buf *core.Buffer, edit *core.BufferEdit) { if buf == nil || edit == nil { return } bc := e.getCache(buf) lang, ok, err := e.resolveBufferLanguage(buf, bc) if err != nil || !ok { bc.built = false bc.dirtyAll = true return } _ = lang if bc.parser == nil { bc.parser = sitter.NewParser() bc.parser.SetLanguage(bc.language) } if bc.tree == nil || len(bc.source) == 0 { bc.dirtyAll = true return } bc.tree.Edit(&sitter.InputEdit{ StartByte: edit.StartByte, OldEndByte: edit.OldEndByte, NewEndByte: edit.NewEndByte, StartPosition: sitter.NewPoint(edit.StartPoint.Row, edit.StartPoint.Column), OldEndPosition: sitter.NewPoint(edit.OldEndPoint.Row, edit.OldEndPoint.Column), NewEndPosition: sitter.NewPoint(edit.NewEndPoint.Row, edit.NewEndPoint.Column), }) newSource := buildBufferSource(buf) newTree := bc.parser.Parse(newSource, bc.tree) if newTree == nil { bc.dirtyAll = true return } changed := bc.tree.ChangedRanges(newTree) newLineCount := buf.LineCount() if newLineCount != bc.count { bc.dirtyAll = true bc.dirty = nil } else { startRow := int(edit.StartPoint.Row) endRow := int(max(edit.OldEndPoint.Row, edit.NewEndPoint.Row)) addDirtyRange(bc, startRow, endRow) for _, r := range changed { addDirtyRange(bc, int(r.StartPoint.Row), int(r.EndPoint.Row)) } } bc.source = newSource bc.tree.Close() bc.tree = newTree bc.built = false } // TreeSitterEngine.InvalidateBuffer: Deletes the entire buffers cache from the engine. If the // buffer provided is nil, this function does nothing. func (e *TreeSitterEngine) InvalidateBuffer(buf *core.Buffer) { if buf == nil { return } bc := e.getCache(buf) bc.built = false bc.dirtyAll = true bc.dirty = nil } // TreeSitterEngine.InvalidateLines: Deletes lines between start and end (inclusive) from the // buffers cache. Then marks the cache as "unbuilt." If the buffer provided is nil, this function // does nothing. func (e *TreeSitterEngine) InvalidateLines(buf *core.Buffer, startLine, endLine int) { if buf == nil { return } bc := e.getCache(buf) addDirtyRange(bc, startLine, endLine) bc.built = false } // TreeSitterEngine.supportsBuffer: Returns whether the buffer can be parsed and highlighted // by the engine. When false, there should be a fallback. func (e *TreeSitterEngine) resolveBufferLanguage(buf *core.Buffer, bc *bufferCache) (*resolvedLanguage, bool, error) { if e.registry == nil { e.registry = newLanguageRegistry() } resolved, ok, err := e.registry.resolve(buf.Filetype, buf.Filename) if err != nil || !ok { return nil, ok, err } if bc.langID != resolved.id { bc.langID = resolved.id bc.language = resolved.language bc.query = resolved.query if bc.parser != nil { bc.parser.SetLanguage(bc.language) } bc.dirtyAll = true bc.built = false } return resolved, true, nil } // TreeSitterEngine.getCache: Returns the buffers cache. If the cache does not exist, a new one // is created and applied to the engines cache map. func (e *TreeSitterEngine) getCache(buf *core.Buffer) *bufferCache { if bc, ok := e.cache[buf]; ok { return bc } bc := &bufferCache{lines: map[int][]lipgloss.Style{}} e.cache[buf] = bc return bc } func (e *TreeSitterEngine) buildFullBuffer(buf *core.Buffer, bc *bufferCache) { lineCount := buf.LineCount() // Load the lines into memory. There is no method for this due to the buffers // internal implementation using a gap buffer. So the "Lines" property is of // type []*GapBuffer. lines := make([]string, lineCount) for i := range lineCount { lines[i] = buf.Line(i) } fullRebuild := bc.dirtyAll || len(bc.lines) == 0 || len(bc.dirty) == 0 if fullRebuild { bc.lines = map[int][]lipgloss.Style{} for i := range lineCount { bc.lines[i] = defaultLineStyles(lines[i], e.styles.LineStyle) } } else { dirty := normalizedDirtyRanges(bc.dirty, lineCount) for _, r := range dirty { for i := r.start; i <= r.end; i++ { bc.lines[i] = defaultLineStyles(lines[i], e.styles.LineStyle) } } } source := buildBufferSource(buf) useCurrentTree := bc.tree != nil && bytes.Equal(bc.source, source) if bc.parser == nil { bc.parser = sitter.NewParser() bc.parser.SetLanguage(bc.language) } if !useCurrentTree { var baseTree *sitter.Tree if bc.tree != nil { baseTree = bc.tree } tree := bc.parser.Parse(source, baseTree) if tree == nil { bc.built = true return } if bc.tree != nil { bc.tree.Close() } bc.tree = tree bc.source = source } root := bc.tree.RootNode() cursor := sitter.NewQueryCursor() defer cursor.Close() var captures []captureRange if fullRebuild { iter := cursor.Captures(bc.query, root, source) captures = append(captures, collectCaptures(iter, bc.query)...) } else { dirty := normalizedDirtyRanges(bc.dirty, lineCount) for _, r := range dirty { queryStart := max(0, r.start-1) queryEnd := min(lineCount-1, r.end+1) rangeCursor := sitter.NewQueryCursor() rangeCursor.SetPointRange( sitter.NewPoint(uint(queryStart), 0), sitter.NewPoint(uint(queryEnd+1), 0), ) iter := rangeCursor.Captures(bc.query, root, source) captures = append(captures, collectCaptures(iter, bc.query)...) rangeCursor.Close() } } // Sort the captures in order of their character occurrence in the file sort.Slice(captures, func(i, j int) bool { if captures[i].startRow == captures[j].startRow { if captures[i].startCol == captures[j].startCol { if captures[i].endRow == captures[j].endRow { return captures[i].endCol > captures[j].endCol } return captures[i].endRow > captures[j].endRow } return captures[i].startCol < captures[j].startCol } return captures[i].startRow < captures[j].startRow }) // Basically, this code works by rewriting the same range and the last capture wins. // This is a great spot for optimization: No need to draw many times, just pick the best one. // Or maybe when we sort, if we find ones that are the same, remove the first one, and then // we just keep the last one. Then this code can stay the same but will not suffer so many // rewrites. targetDirty := normalizedDirtyRanges(bc.dirty, lineCount) for _, c := range captures { sty := style.CaptureStyle(e.styles.LineStyle, c.name) for row := c.startRow; row <= c.endRow; row++ { if int(row) >= len(lines) { break } if !fullRebuild && !rowInRanges(int(row), targetDirty) { continue } lineBytes := []byte(lines[row]) startByteCol := uint(0) if row == c.startRow { startByteCol = c.startCol } endByteCol := uint(len(lineBytes)) if row == c.endRow { endByteCol = min(c.endCol, uint(len(lineBytes))) } startRune := byteColToRuneIndex(lineBytes, int(startByteCol)) endRune := byteColToRuneIndex(lineBytes, int(endByteCol)) rowStyles := bc.lines[int(row)] if startRune < 0 { startRune = 0 } if endRune > len(rowStyles) { endRune = len(rowStyles) } if startRune >= endRune { continue } for i := startRune; i < endRune; i++ { rowStyles[i] = sty } bc.lines[int(row)] = rowStyles } } bc.dirtyAll = false bc.dirty = nil bc.count = lineCount bc.built = true } func addDirtyRange(bc *bufferCache, start, end int) { if bc == nil { return } if end < start { start, end = end, start } if start < 0 { start = 0 } if end < 0 { end = 0 } bc.dirty = append(bc.dirty, lineRange{start: start, end: end}) bc.dirty = mergeRanges(bc.dirty) } func normalizedDirtyRanges(ranges []lineRange, lineCount int) []lineRange { if lineCount <= 0 || len(ranges) == 0 { return nil } clamped := make([]lineRange, 0, len(ranges)) for _, r := range ranges { start := max(0, r.start) end := min(lineCount-1, r.end) if start > end { continue } clamped = append(clamped, lineRange{start: start, end: end}) } return mergeRanges(clamped) } func mergeRanges(ranges []lineRange) []lineRange { if len(ranges) == 0 { return nil } sort.Slice(ranges, func(i, j int) bool { if ranges[i].start == ranges[j].start { return ranges[i].end < ranges[j].end } return ranges[i].start < ranges[j].start }) merged := make([]lineRange, 0, len(ranges)) cur := ranges[0] for i := 1; i < len(ranges); i++ { n := ranges[i] if n.start <= cur.end+1 { if n.end > cur.end { cur.end = n.end } continue } merged = append(merged, cur) cur = n } merged = append(merged, cur) return merged } func rowInRanges(row int, ranges []lineRange) bool { for _, r := range ranges { if row >= r.start && row <= r.end { return true } } return false } func defaultLineStyles(line string, base lipgloss.Style) []lipgloss.Style { runes := []rune(line) row := make([]lipgloss.Style, len(runes)) for i := range row { row[i] = base } return row } func collectCaptures(iter sitter.QueryCaptures, query *sitter.Query) []captureRange { if query == nil { return nil } names := query.CaptureNames() out := []captureRange{} for match, captureIdx := iter.Next(); match != nil; match, captureIdx = iter.Next() { capture := match.Captures[captureIdx] if int(capture.Index) >= len(names) { continue } name := names[capture.Index] if name == "spell" { continue } node := capture.Node start := node.StartPosition() end := node.EndPosition() out = append(out, captureRange{ startRow: start.Row, startCol: start.Column, endRow: end.Row, endCol: end.Column, name: name, }) } return out } func buildBufferSource(buf *core.Buffer) []byte { lineCount := buf.LineCount() if lineCount == 0 { return []byte{} } lines := make([]string, lineCount) for i := range lineCount { lines[i] = buf.Line(i) } return []byte(strings.Join(lines, "\n")) } func byteColToRuneIndex(line []byte, byteCol int) int { if byteCol <= 0 { return 0 } if byteCol >= len(line) { return len([]rune(string(line))) } prefix := line[:byteCol] return len([]rune(string(prefix))) }