From f119a25d5849ff9d2bac947228ccbb6aae515b9f Mon Sep 17 00:00:00 2001 From: Aaron Zuspan Date: Sun, 11 Aug 2024 22:01:48 -0700 Subject: [PATCH] Implement finding all references --- src/spinasm_lsp/parser.py | 18 +++++++--- src/spinasm_lsp/server.py | 26 ++++++++++++-- tests/conftest.py | 74 +++++++++++++++++++++++++++++++++++++++ tests/test_server.py | 19 ++++++++++ 4 files changed, 131 insertions(+), 6 deletions(-) diff --git a/src/spinasm_lsp/parser.py b/src/spinasm_lsp/parser.py index a429247..5173480 100644 --- a/src/spinasm_lsp/parser.py +++ b/src/spinasm_lsp/parser.py @@ -193,10 +193,19 @@ class SPINAsmParser(fv1parse): def __init__(self, source: str): self.diagnostics: list[lsp.Diagnostic] = [] + """A list of diagnostic messages generated during parsing.""" + self.definitions: dict[str, lsp.Range] = {} + """A dictionary mapping symbol names to their definition location.""" + self.current_character: int = 0 + """The current column in the source file.""" + self.previous_character: int = 0 + """The last visitied column in the source file.""" + self.token_registry = TokenRegistry() + """A registry of tokens and their positions in the source file.""" super().__init__( source=source, @@ -301,11 +310,12 @@ def __next__(self): self._update_column() - token_start = lsp.Position( - line=self.current_line, character=self.current_character + token = Token( + symbol=self.sym, + start=lsp.Position( + line=self.current_line, character=self.current_character + ), ) - - token = Token(self.sym, start=token_start) self.token_registry.register_token(token) base_token = token.without_address_modifier() diff --git a/src/spinasm_lsp/server.py b/src/spinasm_lsp/server.py index b43d9c7..09a292d 100644 --- a/src/spinasm_lsp/server.py +++ b/src/spinasm_lsp/server.py @@ -213,7 +213,7 @@ async def definition( @server.feature(lsp.TEXT_DOCUMENT_DOCUMENT_SYMBOL) async def document_symbol_definitions( ls: SPINAsmLanguageServer, params: lsp.DocumentSymbolParams -) -> lsp.DocumentSymbol | None: +) -> list[lsp.DocumentSymbol]: """Returns the definitions of all symbols in the document.""" parser = await ls.get_parser(params.text_document.uri) @@ -255,7 +255,9 @@ async def prepare_rename(ls: SPINAsmLanguageServer, params: lsp.PrepareRenamePar @server.feature( lsp.TEXT_DOCUMENT_RENAME, options=lsp.RenameOptions(prepare_provider=True) ) -async def rename(ls: SPINAsmLanguageServer, params: lsp.RenameParams): +async def rename( + ls: SPINAsmLanguageServer, params: lsp.RenameParams +) -> lsp.WorkspaceEdit: parser = await ls.get_parser(params.text_document.uri) if (token := parser.token_registry.get_token_at_position(params.position)) is None: @@ -269,6 +271,26 @@ async def rename(ls: SPINAsmLanguageServer, params: lsp.RenameParams): return lsp.WorkspaceEdit(changes={params.text_document.uri: edits}) +@server.feature(lsp.TEXT_DOCUMENT_REFERENCES) +async def references( + ls: SPINAsmLanguageServer, params: lsp.ReferenceParams +) -> list[lsp.Location]: + parser = await ls.get_parser(params.text_document.uri) + + if (token := parser.token_registry.get_token_at_position(params.position)) is None: + return [] + + # Ignore address modifiers so that e.g. we can find all variations of addresses, + # e.g. `Delay` and `Delay#` + base_token = token.without_address_modifier() + matching_tokens = parser.token_registry.get_matching_tokens(str(base_token)) + + return [ + lsp.Location(uri=params.text_document.uri, range=t.range) + for t in matching_tokens + ] + + def start() -> None: server.start_io() diff --git a/tests/conftest.py b/tests/conftest.py index 207952b..8c8bd6f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -43,6 +43,14 @@ class PrepareRenameDict(TypedDict): message: str | None +class ReferenceDict(TypedDict): + """A dictionary to record reference locations for a symbol.""" + + symbol: str + position: lsp.Position + references: list[lsp.Location] + + class RenameDict(TypedDict): """A dictionary to record rename results for a symbol.""" @@ -52,6 +60,72 @@ class RenameDict(TypedDict): changes: list[lsp.TextEdit] +REFERENCES: list[ReferenceDict] = [ + { + # Variable + "symbol": "apout", + "position": lsp.Position(line=23, character=4), + "references": [ + lsp.Location( + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + range=lsp.Range( + start=lsp.Position(line=23, character=4), + end=lsp.Position(line=23, character=9), + ), + ), + lsp.Location( + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + range=lsp.Range( + start=lsp.Position(line=57, character=5), + end=lsp.Position(line=57, character=10), + ), + ), + lsp.Location( + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + range=lsp.Range( + start=lsp.Position(line=60, character=5), + end=lsp.Position(line=60, character=10), + ), + ), + lsp.Location( + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + range=lsp.Range( + start=lsp.Position(line=70, character=5), + end=lsp.Position(line=70, character=10), + ), + ), + ], + }, + { + "symbol": "ap1", + "position": lsp.Position(line=8, character=4), + "references": [ + lsp.Location( + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + range=lsp.Range( + start=lsp.Position(line=8, character=4), + end=lsp.Position(line=8, character=7), + ), + ), + lsp.Location( + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + range=lsp.Range( + start=lsp.Position(line=51, character=4), + end=lsp.Position(line=51, character=7), + ), + ), + lsp.Location( + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + range=lsp.Range( + start=lsp.Position(line=52, character=5), + end=lsp.Position(line=52, character=8), + ), + ), + ], + }, +] + + SYMBOL_DEFINITIONS: list[SymbolDefinitionDict] = [ { # Variable diff --git a/tests/test_server.py b/tests/test_server.py index 9003ee5..608d770 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -8,10 +8,12 @@ HOVERS, PATCH_DIR, PREPARE_RENAMES, + REFERENCES, RENAMES, SYMBOL_DEFINITIONS, DefinitionDict, PrepareRenameDict, + ReferenceDict, RenameDict, SymbolDefinitionDict, ) @@ -247,3 +249,20 @@ async def test_symbol_definitions(symbol: SymbolDefinitionDict, client: Language item = matching[0] assert item.kind == symbol["kind"] assert item.range == symbol["range"] + + +@pytest.mark.parametrize("reference", REFERENCES, ids=lambda x: x["symbol"]) +@pytest.mark.asyncio() +async def test_references(reference: ReferenceDict, client: LanguageClient): + """Test that references to a symbol are correctly found.""" + patch = PATCH_DIR / "Basic.spn" + + result = await client.text_document_references_async( + params=lsp.ReferenceParams( + context=lsp.ReferenceContext(include_declaration=False), + position=reference["position"], + text_document=lsp.TextDocumentIdentifier(uri=f"file:///{patch.absolute()}"), + ) + ) + + assert result == reference["references"]