From 040cc2d2085fcffbabe9655ea57c3da2c0e8b049 Mon Sep 17 00:00:00 2001 From: Aaron Zuspan <50475791+aazuspan@users.noreply.github.com> Date: Tue, 20 Aug 2024 00:10:20 -0700 Subject: [PATCH] Semantic highlighting (#29) * Evaluate tokens and refactor token lookups * Refactor * WIP working semantic highlighting (no tests) * WIP refactor (token mixins) * Fix semantic encoding bug with empty encodings * WIP Refactor parser into position and diagnostic classes (parsing failing at EOF) * Fix EOF token bug * Get tests passing with refactored parser * Refactor tests to use dataclass for TestCase * Semantic token testing * Fix mutable dataclass default * Fix 3.8 typing incompatibility --- README.md | 5 +- src/spinasm_lsp/parser.py | 353 +++++----------- src/spinasm_lsp/server.py | 149 ++++--- src/spinasm_lsp/tokens.py | 398 +++++++++++++++++++ tests/conftest.py | 11 +- tests/server_tests/test_completion.py | 87 ++-- tests/server_tests/test_definition.py | 68 ++-- tests/server_tests/test_diagnostics.py | 119 +++--- tests/server_tests/test_hover.py | 132 +++--- tests/server_tests/test_prepare_rename.py | 75 ++-- tests/server_tests/test_reference.py | 40 +- tests/server_tests/test_rename.py | 59 +-- tests/server_tests/test_semantics.py | 79 ++++ tests/server_tests/test_signature_help.py | 139 ++++--- tests/server_tests/test_symbol_definition.py | 58 +-- tests/test_parser.py | 102 +---- tests/test_tokens.py | 169 ++++++++ 17 files changed, 1246 insertions(+), 797 deletions(-) create mode 100644 src/spinasm_lsp/tokens.py create mode 100644 tests/server_tests/test_semantics.py create mode 100644 tests/test_tokens.py diff --git a/README.md b/README.md index 4c9bb8a..097a497 100644 --- a/README.md +++ b/README.md @@ -10,10 +10,11 @@ A Language Server Protocol (LSP) server to provide language support for the [SPI - **Diagnostics**: Reports the location of syntax errors and warnings. - **Signature help**: Shows parameter hints as instructions are entered. -- **Hover**: Shows documentation and values on hover. +- **Hover**: Shows documentation and assigned values on hover. - **Completion**: Provides suggestions for opcodes, labels, and variables. -- **Renaming**: Allows renaming labels and variables. +- **Renaming**: Renames matching labels or variables. - **Go to definition**: Jumps to the definition of a label, memory address, or variable. +- **Semantic highlighting**: Color codes variables, constants, instructions, etc. based on program semantics. ------ diff --git a/src/spinasm_lsp/parser.py b/src/spinasm_lsp/parser.py index 7c1e71b..70955e6 100644 --- a/src/spinasm_lsp/parser.py +++ b/src/spinasm_lsp/parser.py @@ -2,231 +2,93 @@ from __future__ import annotations -import bisect -import copy -from typing import Literal, TypedDict - import lsprotocol.types as lsp from asfv1 import fv1parse +from spinasm_lsp.tokens import ASFV1Token, LSPToken, ParsedToken, TokenLookup -class Symbol(TypedDict): - """ - The token specification used by asfv1. - - Note that we exclude EOF tokens, as they are ignored by the LSP. - """ - - type: Literal[ - "ASSEMBLER", - "INTEGER", - "LABEL", - "TARGET", - "MNEMONIC", - "OPERATOR", - "FLOAT", - "ARGSEP", - ] - txt: str - stxt: str - val: int | float | None - - -class Token: - """ - A parsed token. - - Parameters - ---------- - symbol : Symbol - The symbol parsed by asfv1 representing the token. - start : lsp.Position - The start position of the token in the source file. - end : lsp.Position, optional - The end position of the token in the source file. 
If not provided, the end - position is calculated based on the width of the symbol's stxt. - - Attributes - ---------- - symbol : Symbol - The symbol parsed by asfv1 representing the token. - range : lsp.Range - The location range of the token in the source file. - """ - - def __init__( - self, symbol: Symbol, start: lsp.Position, end: lsp.Position | None = None - ): - if end is None: - width = len(symbol["stxt"]) - end = lsp.Position(line=start.line, character=start.character + width) - - self.symbol: Symbol = symbol - self.range: lsp.Range = lsp.Range(start=start, end=end) - - def __repr__(self) -> str: - return self.symbol["stxt"] - - def concatenate(self, other: Token) -> Token: - """ - Concatenate by merging with another token, in place. - - In practice, this is used for the multi-word opcodes that are parsed as separate - tokens: CHO RDA, CHO RDAL, and CHO SOF. - """ - if any( - symbol_type not in ("MNEMONIC", "LABEL") - for symbol_type in (self.symbol["type"], other.symbol["type"]) - ): - raise TypeError("Only MNEMONIC and LABEL symbols can be concatenated.") - self.symbol["txt"] += f" {other.symbol['txt']}" - self.symbol["stxt"] += f" {other.symbol['stxt']}" - self.range.end = other.range.end - return self +class SPINAsmPositionParser(fv1parse): + """An SPINAsm parser that tracks zero-indexed parsing position.""" - def _clone(self) -> Token: - """Return a clone of the token to avoid mutating the original.""" - return copy.deepcopy(self) + def __init__(self, *args, **kwargs): + # Current position during parsing + self._current_character: int = 0 + self._previous_character: int = 0 - def without_address_modifier(self) -> Token: - """ - Create a clone of the token with the address modifier removed. - """ - if not str(self).endswith("#") and not str(self).endswith("^"): - return self + super().__init__(*args, **kwargs) - token = self._clone() - token.symbol["stxt"] = token.symbol["stxt"][:-1] - token.range.end.character -= 1 + # Store an unmodified version of the source for future reference + self._source: list[str] = self.source.copy() - return token + @property + def sline(self) -> int: + return self._sline + @sline.setter + def sline(self, value): + """Update the current line and reset the column.""" + self._sline = value -class TokenRegistry: - """A registry of tokens and their positions in a source file.""" + # Reset the column to 0 when we move to a new line + self._previous_character = self._current_character + self._current_character = 0 - def __init__(self, tokens: list[Token] | None = None) -> None: - self._prev_token: Token | None = None + @property + def _current_line(self) -> int: + """Get the zero-indexed current line.""" + return self.sline - 1 - """A dictionary mapping program lines to all Tokens on that line.""" - self._tokens_by_line: dict[int, list[Token]] = {} + @property + def position(self) -> lsp.Position: + """The current position of the parser in the source code.""" + return lsp.Position(line=self._current_line, character=self._current_character) - """A dictionary mapping token names to all matching Tokens in the program.""" - self._tokens_by_name: dict[str, list[Token]] = {} + @property + def parsed_symbol(self) -> ASFV1Token: + """Get the last parsed symbol.""" + return ASFV1Token(**self.sym) - for token in tokens or []: - self.register_token(token) + def __next__(self) -> None: + """Parse the next token and update the current character and line.""" + super().__next__() - def register_token(self, token: Token) -> None: - """Add a token to the registry.""" - # 
Handle multi-word CHO instructions by merging the second token with the first - # and skipping the second token. - if str(self._prev_token) == "CHO" and str(token) in ("RDA", "RDAL", "SOF"): - self._prev_token.concatenate(token) # type: ignore + # Don't advance position on EOF token, since we're done parsing + if self.parsed_symbol.type == "EOF": return - if token.range.start.line not in self._tokens_by_line: - self._tokens_by_line[token.range.start.line] = [] - - # Store the token on its line - self._tokens_by_line[token.range.start.line].append(token) - self._prev_token = token - - # Store user-defined tokens together by name. Other token types could be stored, - # but currently there's no use case for retrieving their positions. - if token.symbol["type"] in ("LABEL", "TARGET"): - # Tokens are stored by name without address modifiers, so that e.g. Delay# - # and Delay can be retrieved with the same query. This allows for renaming - # all instances of a memory token. - token = token.without_address_modifier() - - if str(token) not in self._tokens_by_name: - self._tokens_by_name[str(token)] = [] - - self._tokens_by_name[str(token)].append(token) - - def get_matching_tokens(self, token_name: str) -> list[Token]: - """Retrieve all tokens with a given name in the program.""" - return self._tokens_by_name.get(token_name.upper(), []) - - def get_token_at_position(self, position: lsp.Position) -> Token | None: - """Retrieve the token at the given position.""" - if position.line not in self._tokens_by_line: - return None - - line_tokens = self._tokens_by_line[position.line] - token_starts = [t.range.start.character for t in line_tokens] - token_ends = [t.range.end.character for t in line_tokens] - - idx = bisect.bisect_left(token_starts, position.character) - - # The index returned by bisect_left points to the start value >= character. This - # will either be the first character of the token or the start of the next - # token. First check if we're out of bounds, then shift left unless we're at the - # first character of the token. - if idx == len(line_tokens) or token_starts[idx] != position.character: - idx -= 1 + current_line_txt = self._source[self._current_line] + current_symbol = self.parsed_symbol.txt - # If the col falls after the end of the token, we're not inside a token. 
- if position.character > token_ends[idx]: - return None - - return line_tokens[idx] - - -class SPINAsmParser(fv1parse): - """A modified version of fv1parse optimized for use with LSP.""" - - sym: Symbol | None - - def __init__(self, source: str): - self.diagnostics: list[lsp.Diagnostic] = [] - """A list of diagnostic messages generated during parsing.""" - - self.definitions: dict[str, lsp.Range] = {} - """A dictionary mapping symbol names to their definition location.""" - - self.current_character: int = 0 - """The current column in the source file.""" + self._previous_character = self._current_character + # Start at the current column to skip previous duplicates of the symbol + self._current_character = current_line_txt.index( + current_symbol, self._current_character + ) - self.previous_character: int = 0 - """The last visitied column in the source file.""" - self.token_registry = TokenRegistry() - """A registry of tokens and their positions in the source file.""" +class SPINAsmDiagnosticParser(SPINAsmPositionParser): + """An SPINAsm parser that logs warnings and errors as LSP diagnostics.""" + def __init__(self, *args, **kwargs): super().__init__( - source=source, - clamp=True, - spinreals=False, + *args, # Ignore the callbacks in favor of overriding their callers wfunc=lambda *args, **kwargs: None, efunc=lambda *args, **kwargs: None, + **kwargs, ) - # Track which symbols were defined at initialization, e.g. registers and LFOs - self.constants: list[str] = list(self.symtbl.keys()) - # Keep an unchanged copy of the original source - self._source: list[str] = self.source.copy() - - def __mkopcodes__(self): - """ - No-op. - - Generating opcodes isn't needed for LSP functionality, so we'll skip it. - """ + self.diagnostics: list[lsp.Diagnostic] = [] + """A list of diagnostic messages generated during parsing.""" def _record_diagnostic( - self, msg: str, line: int, character: int, severity: lsp.DiagnosticSeverity + self, msg: str, *, position: lsp.Position, severity: lsp.DiagnosticSeverity ): """Record a diagnostic message for the LSP.""" self.diagnostics.append( lsp.Diagnostic( - range=lsp.Range( - start=lsp.Position(line, character=character), - end=lsp.Position(line, character=character), - ), + range=lsp.Range(start=position, end=position), message=msg, severity=severity, source="SPINAsm", @@ -241,8 +103,7 @@ def parseerror(self, msg: str, line: int | None = None): # Offset the line from the parser's 1-indexed line to the 0-indexed line self._record_diagnostic( msg, - line=line - 1, - character=self.current_character, + position=lsp.Position(line=line - 1, character=self._current_character), severity=lsp.DiagnosticSeverity.Error, ) @@ -250,8 +111,9 @@ def scanerror(self, msg: str): """Override to record scanning errors as LSP diagnostics.""" self._record_diagnostic( msg, - line=self.current_line, - character=self.current_character, + position=lsp.Position( + line=self._current_line, character=self._current_character + ), severity=lsp.DiagnosticSeverity.Error, ) @@ -263,78 +125,83 @@ def parsewarn(self, msg: str, line: int | None = None): # Offset the line from the parser's 1-indexed line to the 0-indexed line self._record_diagnostic( msg, - line=line - 1, - character=self.current_character, + position=lsp.Position(line=line - 1, character=self._current_character), severity=lsp.DiagnosticSeverity.Warning, ) - @property - def sline(self): - return self._sline - @sline.setter - def sline(self, value): - """Update the current line and reset the column.""" - self._sline = value +class 
SPINAsmParser(SPINAsmDiagnosticParser): + """An SPINAsm parser with position, diagnostics, and additional LSP features.""" - # Reset the column to 0 when we move to a new line - self.previous_character = self.current_character - self.current_character = 0 + def __init__(self, source: str): + # Intermediate token definitions and lookups set during parsing + self._definitions: dict[str, lsp.Range] = {} + self._parsed_tokens: TokenLookup[ParsedToken] = TokenLookup() - @property - def current_line(self): - """Get the zero-indexed current line.""" - return self.sline - 1 + super().__init__( + source=source, + clamp=True, + spinreals=False, + ) - @property - def previous_line(self): - """Get the zero-indexed previous line.""" - return self.prevline - 1 + # Store built-in constants that were defined at initialization. + self._constants: list[str] = list(self.symtbl.keys()) + + self.evaluated_tokens: TokenLookup[LSPToken] = TokenLookup() + """Tokens with additional metadata after evaluation.""" + + def __mkopcodes__(self): + """ + No-op. + + Generating opcodes isn't needed for LSP functionality, so we'll skip it. + """ def __next__(self): """Parse the next symbol and update the column and definitions.""" super().__next__() - if self.sym["type"] == "EOF": - return - self._update_column() + # Don't store the EOF token + if self.parsed_symbol.type == "EOF": + return - token = Token( - symbol=self.sym, - start=lsp.Position( - line=self.current_line, character=self.current_character - ), + token = self.parsed_symbol.at_position( + start=lsp.Position(self._current_line, character=self._current_character), ) - self.token_registry.register_token(token) + self._parsed_tokens.add_token(token) base_token = token.without_address_modifier() - is_user_definable = base_token.symbol["type"] in ("LABEL", "TARGET") - is_defined = str(base_token) in self.jmptbl or str(base_token) in self.symtbl + is_user_definable = base_token.type in ("LABEL", "TARGET") + is_defined = base_token.stxt in self.jmptbl or base_token.stxt in self.symtbl if ( is_user_definable and not is_defined # Labels appear before their target definition, so override when the target # is defined. 
- or base_token.symbol["type"] == "TARGET" + or base_token.type == "TARGET" ): - self.definitions[str(base_token)] = base_token.range - - def _update_column(self): - """Set the current column based on the last parsed symbol.""" - current_line_txt = self._source[self.current_line] - current_symbol = self.sym.get("txt", None) or "" - - self.previous_character = self.current_character - try: - # Start at the current column to skip previous duplicates of the symbol - self.current_character = current_line_txt.index( - current_symbol, self.current_character - ) - except ValueError: - self.current_character = 0 + self._definitions[base_token.stxt] = base_token.range + + def _evaluate_token(self, token: ParsedToken) -> LSPToken: + """Evaluate a parsed token to determine its value and metadata.""" + value = self.jmptbl.get(token.stxt, self.symtbl.get(token.stxt, None)) + defined_range = self._definitions.get(token.without_address_modifier().stxt) + + return LSPToken.from_parsed_token( + token=token, + value=value, + defined=defined_range, + is_constant=token.stxt in self._constants, + is_label=token.stxt in self.jmptbl, + ) def parse(self) -> SPINAsmParser: - """Parse and return the parser.""" + """Parse and evaluate all tokens.""" super().parse() + + for token in self._parsed_tokens: + evaluated_token = self._evaluate_token(token) + self.evaluated_tokens.add_token(evaluated_token) + return self diff --git a/src/spinasm_lsp/server.py b/src/spinasm_lsp/server.py index d7510e1..4090fa5 100644 --- a/src/spinasm_lsp/server.py +++ b/src/spinasm_lsp/server.py @@ -11,6 +11,7 @@ from spinasm_lsp import __version__ from spinasm_lsp.docs import MULTI_WORD_INSTRUCTIONS, DocumentationManager from spinasm_lsp.parser import SPINAsmParser +from spinasm_lsp.tokens import SEMANTIC_MODIFIER_LEGEND, SEMANTIC_TYPE_LEGEND @lru_cache(maxsize=1) @@ -92,47 +93,24 @@ def did_close( ls.publish_diagnostics(params.text_document.uri, []) -def _get_defined_hover(stxt: str, parser: SPINAsmParser) -> str: - """Get a hover message with the value of a defined variable or label.""" - # Check jmptbl first since labels are also defined in symtbl - if stxt in parser.jmptbl: - hover_definition = parser.jmptbl[stxt] - return f"(label) {stxt}: Offset[{hover_definition}]" - # Check constants next since they are also defined in symtbl - if stxt in parser.constants: - hover_definition = parser.symtbl[stxt] - return f"(constant) {stxt}: Literal[{hover_definition}]" - if stxt in parser.symtbl: - hover_definition = parser.symtbl[stxt] - return f"(variable) {stxt}: Literal[{hover_definition}]" - - return "" - - @server.feature(lsp.TEXT_DOCUMENT_HOVER) async def hover(ls: SPINAsmLanguageServer, params: lsp.HoverParams) -> lsp.Hover | None: """Retrieve documentation from symbols on hover.""" parser = await ls.get_parser(params.text_document.uri) - if (token := parser.token_registry.get_token_at_position(params.position)) is None: + if (token := parser.evaluated_tokens.get(position=params.position)) is None: return None - if token.symbol["type"] in ("LABEL", "TARGET"): - hover_msg = _get_defined_hover(str(token), parser=parser) - - return ( - None - if not hover_msg - else lsp.Hover( - # Java markdown formatting happens to give the best color-coding for - # hover messages - contents={"language": "java", "value": f"{hover_msg}"}, - range=token.range, - ) + if token.type in ("LABEL", "TARGET"): + return lsp.Hover( + # Java markdown formatting happens to give the best color-coding for + # hover messages + contents={"language": "java", "value": 
token.completion_detail}, + range=token.range, ) - if token.symbol["type"] in ("ASSEMBLER", "MNEMONIC"): - hover_msg = ls.documentation.get_markdown(str(token)) + if token.type in ("ASSEMBLER", "MNEMONIC"): + hover_msg = ls.documentation.get_markdown(token.stxt) return ( None @@ -155,19 +133,19 @@ async def completions( """Returns completion items.""" parser = await ls.get_parser(params.text_document.uri) - symbol_completions = [ - lsp.CompletionItem( - label=symbol, - kind=lsp.CompletionItemKind.Constant - if symbol in parser.constants - else lsp.CompletionItemKind.Variable - if symbol in parser.symtbl - else lsp.CompletionItemKind.Module, - detail=_get_defined_hover(symbol, parser=parser), - ) - for symbol in {**parser.symtbl, **parser.jmptbl} - ] - + # Get completions for all unique tokens (by their stxt) in the document + seen_tokens = set() + symbol_completions = [] + for token in parser.evaluated_tokens: + # Temporary fix until we can get completions for all tokens at once. + if token.type not in ("LABEL", "TARGET"): + continue + if token.stxt not in seen_tokens: + symbol_completions.append(token.completion_item) + seen_tokens.add(token.stxt) + + # TODO: If possible, get this from the completion item itself. This will require + # tokens to be able to query documentation. opcode_completions = [ lsp.CompletionItem( label=opcode, @@ -209,19 +187,15 @@ async def definition( document = ls.workspace.get_text_document(params.text_document.uri) - if (token := parser.token_registry.get_token_at_position(params.position)) is None: + if (token := parser.evaluated_tokens.get(position=params.position)) is None: return None - # Definitions should be checked against the base token name, ignoring address - # modifiers. - base_token = token.without_address_modifier() - - if str(base_token) not in parser.definitions: + if not token.defined: return None return lsp.Location( uri=document.uri, - range=parser.definitions[str(base_token)], + range=token.defined, ) @@ -229,22 +203,9 @@ async def definition( async def document_symbol_definitions( ls: SPINAsmLanguageServer, params: lsp.DocumentSymbolParams ) -> list[lsp.DocumentSymbol]: - """Returns the definitions of all symbols in the document.""" + """Returns the definition location of all symbols in the document.""" parser = await ls.get_parser(params.text_document.uri) - - return [ - lsp.DocumentSymbol( - name=symbol, - kind=lsp.SymbolKind.Module - if symbol in parser.jmptbl - # There's no need to check for constants here since they aren't included - # in the parser definitions. - else lsp.SymbolKind.Variable, - range=definition, - selection_range=definition, - ) - for symbol, definition in parser.definitions.items() - ] + return [t.document_symbol for t in parser.evaluated_tokens if t.defined] @server.feature(lsp.TEXT_DOCUMENT_PREPARE_RENAME) @@ -253,15 +214,15 @@ async def prepare_rename(ls: SPINAsmLanguageServer, params: lsp.PrepareRenamePar is a valid operation.""" parser = await ls.get_parser(params.text_document.uri) - if (token := parser.token_registry.get_token_at_position(params.position)) is None: + if (token := parser.evaluated_tokens.get(position=params.position)) is None: return None # Renaming is checked against the base token name, ignoring address modifiers. 
base_token = token.without_address_modifier() # Only user-defined labels should support renaming - if str(base_token) not in parser.definitions: - ls.info(f"Can't rename non-user defined token {base_token}.") + if not base_token.defined: + ls.info(f"Can't rename non-user defined token {base_token.stxt}.") return None return lsp.PrepareRenameResult_Type2(default_behavior=True) @@ -275,12 +236,12 @@ async def rename( ) -> lsp.WorkspaceEdit: parser = await ls.get_parser(params.text_document.uri) - if (token := parser.token_registry.get_token_at_position(params.position)) is None: + if (token := parser.evaluated_tokens.get(position=params.position)) is None: return None # Ignore address modifiers so that e.g. we can rename `Delay` by renaming `Delay#` base_token = token.without_address_modifier() - matching_tokens = parser.token_registry.get_matching_tokens(str(base_token)) + matching_tokens = parser.evaluated_tokens.get(name=base_token.stxt) edits = [lsp.TextEdit(t.range, new_text=params.new_name) for t in matching_tokens] return lsp.WorkspaceEdit(changes={params.text_document.uri: edits}) @@ -292,13 +253,13 @@ async def references( ) -> list[lsp.Location]: parser = await ls.get_parser(params.text_document.uri) - if (token := parser.token_registry.get_token_at_position(params.position)) is None: + if (token := parser.evaluated_tokens.get(position=params.position)) is None: return [] # Ignore address modifiers so that e.g. we can find all variations of addresses, # e.g. `Delay` and `Delay#` base_token = token.without_address_modifier() - matching_tokens = parser.token_registry.get_matching_tokens(str(base_token)) + matching_tokens = parser.evaluated_tokens.get(name=base_token.stxt) return [ lsp.Location(uri=params.text_document.uri, range=t.range) @@ -317,12 +278,11 @@ async def signature_help( # Find all opcodes on the line that could have triggered the signature help. Ignore # opcodes that appear after the cursor, to avoid showing signature help prematurely. - line_tokens = parser.token_registry._tokens_by_line.get(params.position.line, []) + line_tokens = parser.evaluated_tokens.get(line=params.position.line) opcodes = [ t for t in line_tokens - if t.symbol["type"] == "MNEMONIC" - and t.range.end.character < params.position.character + if t.is_opcode and t.range.end.character < params.position.character ] if not opcodes: return None @@ -330,17 +290,17 @@ async def signature_help( # We should never have more than one opcode on a line, but just in case, grab the # last one entered before the cursor. triggered_opcode = opcodes[-1] - opcode = ls.documentation.get_instruction(str(triggered_opcode)) + opcode = ls.documentation.get_instruction(triggered_opcode.stxt) if opcode is None: return None # Get all argument separators after the opcode remaining_tokens = line_tokens[line_tokens.index(triggered_opcode) + 1 :] - argseps = [t for t in remaining_tokens if t.symbol["type"] == "ARGSEP"] + argseps = [t for t in remaining_tokens if t.type == "ARGSEP"] # The first argument of multi-word instructions like CHO RDAL is treated as part of # the opcode, so we should skip the first separator when counting arguments. 
- if str(triggered_opcode) in MULTI_WORD_INSTRUCTIONS: + if triggered_opcode.stxt in MULTI_WORD_INSTRUCTIONS: argseps = argseps[1:] # Count how many parameters are left of the cursor to see which argument we're @@ -371,6 +331,35 @@ async def signature_help( ) +@server.feature( + lsp.TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL, + lsp.SemanticTokensLegend( + token_types=[x.value for x in SEMANTIC_TYPE_LEGEND], + token_modifiers=[x.value for x in SEMANTIC_MODIFIER_LEGEND], + ), +) +async def semantic_tokens( + ls: SPINAsmLanguageServer, params: lsp.SemanticTokensParams +) -> lsp.SemanticTokens: + parser = await ls.get_parser(params.text_document.uri) + + encoding: list[int] = [] + prev_token_position = lsp.Position(0, 0) + for token in parser.evaluated_tokens: + token_encoding = token.semantic_encoding(prev_token_position) + + # Tokens without semantic encoding (e.g. operators) should be ignored so that + # the next encoding is relative to the last encoded token. Otherwise, character + # offsets would be incorrect. + if not token_encoding: + continue + + encoding += token_encoding + prev_token_position = token.range.start + + return lsp.SemanticTokens(data=encoding) + + def start() -> None: server.start_io() diff --git a/src/spinasm_lsp/tokens.py b/src/spinasm_lsp/tokens.py new file mode 100644 index 0000000..34340c0 --- /dev/null +++ b/src/spinasm_lsp/tokens.py @@ -0,0 +1,398 @@ +"""Data structures for storing and retrieving parsed tokens.""" + +from __future__ import annotations + +import bisect +import copy +from dataclasses import dataclass +from typing import Generator, Generic, Literal, TypeVar, overload + +import lsprotocol.types as lsp + +_ParsedTokenT = TypeVar("_ParsedTokenT", bound="ParsedToken") +_EvaluatedTokenT = TypeVar("_EvaluatedTokenT", bound="EvaluatedToken") + +# Token types assigned by asfv1. Note that we exclude EOF tokens, as they are ignored by +# the LSP. +TokenType = Literal[ + "ASSEMBLER", + "INTEGER", + "LABEL", + "TARGET", + "MNEMONIC", + "OPERATOR", + "FLOAT", + "ARGSEP", +] + +# Map semantic type enums to integer encodings +SEMANTIC_TYPE_LEGEND = {k: i for i, k in enumerate(lsp.SemanticTokenTypes)} +SEMANTIC_MODIFIER_LEGEND = {k: i for i, k in enumerate(lsp.SemanticTokenModifiers)} + + +@dataclass +class ASFV1Token: + """Raw token metadata parsed by asfv1.""" + + type: TokenType + txt: str + stxt: str + val: int | float | None + + def at_position( + self, start: lsp.Position, end: lsp.Position | None = None + ) -> ParsedToken: + """Create a parsed token with this token's metadata at a position.""" + if end is None: + width = len(self.stxt) + end = lsp.Position(line=start.line, character=start.character + width) + + return ParsedToken( + type=self.type, + stxt=self.stxt, + range=lsp.Range(start=start, end=end), + ) + + +class ParsedToken: + """ + Token metadata including its position. + + Parameters + ---------- + type : TokenType + The type of token identified by asfv1. + stxt : str + The name assigned to the token, always uppercase. + range : lsp.Range + The position of the token in the source code. + """ + + def __init__(self, type: TokenType, stxt: str, range: lsp.Range): + self.type = type + self.stxt = stxt + self.range = range + + def _clone(self: _ParsedTokenT) -> _ParsedTokenT: + """Return a clone of the token to avoid mutating the original.""" + return copy.deepcopy(self) + + def without_address_modifier(self: _ParsedTokenT) -> _ParsedTokenT: + """ + Create a clone of the token with the address modifier removed. 
+ """ + if not self.stxt.endswith("#") and not self.stxt.endswith("^"): + return self + + clone = self._clone() + clone.stxt = clone.stxt[:-1] + clone.range.end.character -= 1 + + return clone + + def concatenate(self: _ParsedTokenT, other: _ParsedTokenT) -> _ParsedTokenT: + """ + Concatenate by merging with another token, in place. + + In practice, this is used for the multi-word opcodes that are parsed as separate + tokens: CHO RDA, CHO RDAL, and CHO SOF. + """ + self.stxt += f" {other.stxt}" + self.range.end = other.range.end + return self + + +class EvaluatedToken(ParsedToken): + """ + A parsed token that has been evaluated to determine its value and other metadata. + """ + + def __init__( + self, + type: TokenType, + stxt: str, + range: lsp.Range, + value: float | int | None = None, + defined: lsp.Range | None = None, + is_constant: bool = False, + is_label: bool = False, + ): + super().__init__(type=type, stxt=stxt, range=range) + + self.value = value + """The numeric value of the evaluated token, if applicable.""" + + self.defined = defined + """The range where the token is defined, if applicable.""" + + self.is_constant = is_constant + self.is_label = is_label + self.is_opcode = self.type == "MNEMONIC" + + @classmethod + def from_parsed_token( + cls: type[_EvaluatedTokenT], + token: ParsedToken, + *, + value: float | int | None = None, + defined: lsp.Range | None = None, + is_constant: bool = False, + is_label: bool = False, + ) -> _EvaluatedTokenT: + """Create an evaluated token from a parsed token.""" + return cls( + type=token.type, + stxt=token.stxt, + range=token.range, + value=value, + defined=defined, + is_constant=is_constant, + is_label=is_label, + ) + + +class SemanticTokenMixin(EvaluatedToken): + """A mixin for evaluated tokens with semantic information.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.semantic_type, self.semantic_modifiers = self._infer_semantics() + + def _infer_semantics( + self, + ) -> tuple[lsp.SemanticTokenTypes, list[lsp.SemanticTokenModifiers]]: + """Infer the semantic type and modifiers for the token.""" + # Crosswalk asfv1 token types to LSP semantic token types + type_semantics = { + "MNEMONIC": lsp.SemanticTokenTypes.Function, + "INTEGER": lsp.SemanticTokenTypes.Number, + "FLOAT": lsp.SemanticTokenTypes.Number, + "ASSEMBLER": lsp.SemanticTokenTypes.Operator, + "ARGSEP": lsp.SemanticTokenTypes.Operator, + "LABEL": lsp.SemanticTokenTypes.Variable, + "TARGET": lsp.SemanticTokenTypes.Namespace, + } + + semantic_type = type_semantics.get(self.type) + if self.is_label: + semantic_type = lsp.SemanticTokenTypes.Namespace + + semantic_modifiers = [] + if self.is_constant and self.type != "MNEMONIC": + semantic_modifiers += [ + lsp.SemanticTokenModifiers.Readonly, + lsp.SemanticTokenModifiers.DefaultLibrary, + ] + + if self.stxt.endswith("#") or self.stxt.endswith("^"): + semantic_modifiers.append(lsp.SemanticTokenModifiers.Modification) + + if self.defined == self.range: + semantic_modifiers.append(lsp.SemanticTokenModifiers.Definition) + + return semantic_type, semantic_modifiers + + def semantic_encoding(self, prev_token_start: lsp.Position) -> list[int]: + """ + Encode the token's semantic information for the LSP. 
+ + The output is a list of 5 ints representing: + - The delta line from the previous token + - The delta character from the previous token + - The length of the token + - The semantic type index + - The encoded semantic modifiers + + See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_semanticTokens + """ + # Set the token's position relative to the previous token. If we're on a new + # line, set the character relative to zero. + delta_line = self.range.start.line - prev_token_start.line + delta_start_char = ( + self.range.start.character + if delta_line + else self.range.start.character - prev_token_start.character + ) + + token_type = SEMANTIC_TYPE_LEGEND.get(self.semantic_type) + token_modifiers = [ + SEMANTIC_MODIFIER_LEGEND.get(mod) for mod in self.semantic_modifiers + ] + # Return an empty semantic encoding if type or modifiers are unrecognized + if token_type is None or None in token_modifiers: + return [] + + # The index of each modifier is encoded into a bitmask + modifier_bitmask = sum(1 << i for i in token_modifiers) # type: ignore + + return [ + delta_line, + delta_start_char, + len(self.stxt), + token_type, + modifier_bitmask, + ] + + +class LSPTokenMixin(EvaluatedToken): + """A mixin for evaluated tokens with LSP information.""" + + @property + def completion_detail(self) -> str: + """A description of the token used in completions and hover.""" + type_str = ( + "opcode" + if self.is_opcode + else "label" + if self.is_label + else "constant" + if self.is_constant + else "variable" + ) + value_type = "Offset" if self.is_label else "Literal" + + return ( + f"({type_str})" + f" {self.stxt}: {value_type}[{self.value}]" + if not self.is_opcode + else "" + ) + + @property + def completion_kind(self) -> lsp.CompletionItemKind: + return ( + lsp.CompletionItemKind.Function + if self.is_opcode + else lsp.CompletionItemKind.Constant + if self.is_constant + else lsp.CompletionItemKind.Module + if self.is_label + else lsp.CompletionItemKind.Variable + ) + + @property + def completion_item(self) -> lsp.CompletionItem: + """Create a completion item for the token.""" + + return lsp.CompletionItem( + label=self.stxt, + kind=self.completion_kind, + detail=self.completion_detail, + documentation=None, + ) + + @property + def symbol_kind(self) -> lsp.SymbolKind: + return ( + lsp.SymbolKind.Function + if self.is_opcode + else lsp.SymbolKind.Constant + if self.is_constant + else lsp.SymbolKind.Module + if self.is_label + else lsp.SymbolKind.Variable + ) + + @property + def document_symbol(self) -> lsp.DocumentSymbol: + """Create a document symbol for the token.""" + return lsp.DocumentSymbol( + name=self.stxt, + kind=self.symbol_kind, + range=self.defined, + selection_range=self.defined, + ) + + +class LSPToken(LSPTokenMixin, SemanticTokenMixin): + """An evaluated token with semantic and LSP information.""" + + +class TokenLookup(Generic[_ParsedTokenT]): + """A lookup table for tokens by position and name.""" + + def __init__(self): + self._prev_token: _ParsedTokenT | None = None + self._line_lookup: dict[int, list[_ParsedTokenT]] = {} + self._name_lookup: dict[str, list[_ParsedTokenT]] = {} + + def __iter__(self) -> Generator[_ParsedTokenT, None, None]: + """Yield all tokens in order.""" + for line in self._line_lookup.values(): + yield from line + + @overload + def get(self, *, position: lsp.Position) -> _ParsedTokenT | None: ... + @overload + def get(self, *, name: str) -> list[_ParsedTokenT]: ... 
+    @overload
+    def get(self, *, line: int) -> list[_ParsedTokenT]: ...
+
+    def get(
+        self,
+        *,
+        position: lsp.Position | None = None,
+        name: str | None = None,
+        line: int | None = None,
+    ) -> _ParsedTokenT | list[_ParsedTokenT] | None:
+        """Retrieve a token by position, name, or line."""
+        # Raise if more than one argument is provided
+        if sum(arg is not None for arg in (position, name, line)) > 1:
+            raise ValueError("Only one of position, name, or line may be provided")
+
+        if position is not None:
+            return self._token_at_position(position)
+        if line is not None:
+            return self._line_lookup.get(line, [])
+        if name is not None:
+            return self._name_lookup.get(name.upper(), [])
+        raise ValueError("Either a position, name, or line must be provided.")
+
+    def add_token(self, token: _ParsedTokenT) -> None:
+        """Store a token for future lookup."""
+        # Handle multi-word CHO instructions by merging the second token with the first
+        # and skipping the second token.
+        if (
+            self._prev_token
+            and self._prev_token.stxt == "CHO"
+            and token.stxt in ("RDA", "RDAL", "SOF")
+        ):
+            self._prev_token.concatenate(token)  # type: ignore
+            return
+
+        # Store the token on its line
+        self._line_lookup.setdefault(token.range.start.line, []).append(token)
+        self._prev_token = token
+
+        # Store user-defined tokens together by name. Other token types could be stored,
+        # but currently there's no use case for retrieving their positions.
+        if token.type in ("LABEL", "TARGET"):
+            # Tokens are stored by name without address modifiers, so that e.g. Delay#
+            # and Delay can be retrieved with the same query. This allows for renaming
+            # all instances of a memory token.
+            base_token = token.without_address_modifier()
+            self._name_lookup.setdefault(base_token.stxt, []).append(base_token)
+
+    def _token_at_position(self, position: lsp.Position) -> _ParsedTokenT | None:
+        """Retrieve the token at the given position."""
+        if position.line not in self._line_lookup:
+            return None
+
+        line_tokens = self._line_lookup[position.line]
+        token_starts = [t.range.start.character for t in line_tokens]
+        token_ends = [t.range.end.character for t in line_tokens]
+
+        idx = bisect.bisect_left(token_starts, position.character)
+
+        # The index returned by bisect_left points to the start value >= character. This
+        # will either be the first character of the token or the start of the next
+        # token. First check if we're out of bounds, then shift left unless we're at the
+        # first character of the token.
+        if idx == len(line_tokens) or token_starts[idx] != position.character:
+            idx -= 1
+
+        # If the col falls after the end of the token, we're not inside a token.
+ if position.character > token_ends[idx]: + return None + + return line_tokens[idx] diff --git a/tests/conftest.py b/tests/conftest.py index bf9d2e6..0b0631a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,10 @@ from __future__ import annotations +from dataclasses import dataclass from pathlib import Path -from typing import TypedDict import lsprotocol.types as lsp +import pytest import pytest_lsp from pytest_lsp import ClientServerConfig, LanguageClient @@ -28,10 +29,16 @@ async def client(request, lsp_client: LanguageClient): await lsp_client.shutdown_session() -class TestCase(TypedDict): +@dataclass +class TestCase: """The inputs and outputs of a test case.""" __test__ = False name: str """The name used to identify the test case.""" + + +def parametrize_cases(test_cases: list[TestCase]): + """A decorator to parametrize a test function with test cases.""" + return pytest.mark.parametrize("test_case", test_cases, ids=lambda x: x.name) diff --git a/tests/server_tests/test_completion.py b/tests/server_tests/test_completion.py index 8b90351..db5aa9e 100644 --- a/tests/server_tests/test_completion.py +++ b/tests/server_tests/test_completion.py @@ -1,12 +1,15 @@ from __future__ import annotations +from dataclasses import dataclass + import lsprotocol.types as lsp import pytest from pytest_lsp import LanguageClient -from ..conftest import PATCH_DIR, TestCase +from ..conftest import PATCH_DIR, TestCase, parametrize_cases +@dataclass class CompletionTestCase(TestCase): """A dictionary to track an expected completion result.""" @@ -17,62 +20,62 @@ class CompletionTestCase(TestCase): uri: str -COMPLETIONS: list[CompletionTestCase] = [ - { - "name": "APOUT", - "label": "APOUT", - "detail": "(variable) APOUT: Literal[33]", - "kind": lsp.CompletionItemKind.Variable, - "doc_contains": None, - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - "name": "REG0", - "label": "REG0", - "detail": "(constant) REG0: Literal[32]", - "kind": lsp.CompletionItemKind.Constant, - "doc_contains": None, - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - "name": "CHO RDA", - "label": "CHO RDA", - "detail": "(opcode)", - "kind": lsp.CompletionItemKind.Function, - "doc_contains": "`CHO RDA, N, C, D`", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - "name": "EQU", - "label": "EQU", - "detail": "(assembler)", - "kind": lsp.CompletionItemKind.Operator, - "doc_contains": "**`EQU`** allows one to define symbolic operands", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, +TEST_CASES: list[CompletionTestCase] = [ + CompletionTestCase( + name="variable", + label="APOUT", + detail="(variable) APOUT: Literal[33]", + kind=lsp.CompletionItemKind.Variable, + doc_contains=None, + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + CompletionTestCase( + name="constant", + label="REG0", + detail="(constant) REG0: Literal[32]", + kind=lsp.CompletionItemKind.Constant, + doc_contains=None, + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + CompletionTestCase( + name="multi-word opcode", + label="CHO RDA", + detail="(opcode)", + kind=lsp.CompletionItemKind.Function, + doc_contains="`CHO RDA, N, C, D`", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + CompletionTestCase( + name="assembler", + label="EQU", + detail="(assembler)", + kind=lsp.CompletionItemKind.Operator, + doc_contains="**`EQU`** allows one to define symbolic operands", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), ] -@pytest.mark.parametrize("test_case", COMPLETIONS, ids=lambda x: x["name"]) +@parametrize_cases(TEST_CASES) 
@pytest.mark.asyncio() async def test_completions(test_case: CompletionTestCase, client: LanguageClient): """Test that expected completions are shown with details and documentation.""" results = await client.text_document_completion_async( params=lsp.CompletionParams( position=lsp.Position(line=0, character=0), - text_document=lsp.TextDocumentIdentifier(uri=test_case["uri"]), + text_document=lsp.TextDocumentIdentifier(uri=test_case.uri), ) ) assert results is not None, "Expected completions" - matches = [item for item in results.items if item.label == test_case["label"]] + matches = [item for item in results.items if item.label == test_case.label] assert ( len(matches) == 1 - ), f"Expected 1 matching label `{test_case['label']}, got {len(matches)}." + ), f"Expected 1 matching label `{test_case.label}`, got {len(matches)}." match = matches[0] - assert match.detail == test_case["detail"] - assert match.kind == test_case["kind"] - if test_case["doc_contains"] is not None: - assert test_case["doc_contains"] in str(match.documentation) + assert match.detail == test_case.detail + assert match.kind == test_case.kind + if test_case.doc_contains is not None: + assert test_case.doc_contains in str(match.documentation) diff --git a/tests/server_tests/test_definition.py b/tests/server_tests/test_definition.py index 5126ce7..460c65f 100644 --- a/tests/server_tests/test_definition.py +++ b/tests/server_tests/test_definition.py @@ -1,12 +1,15 @@ from __future__ import annotations +from dataclasses import dataclass + import lsprotocol.types as lsp import pytest from pytest_lsp import LanguageClient -from ..conftest import PATCH_DIR, TestCase +from ..conftest import PATCH_DIR, TestCase, parametrize_cases +@dataclass class DefinitionTestCase(TestCase): """A dictionary track where a symbol is referenced and defined.""" @@ -15,72 +18,67 @@ class DefinitionTestCase(TestCase): uri: str -DEFINITIONS: list[DefinitionTestCase] = [ - { - # Variable - "name": "apout", - "referenced": lsp.Position(line=57, character=7), - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - "defined": lsp.Location( +TEST_CASES: list[DefinitionTestCase] = [ + DefinitionTestCase( + name="apout", + referenced=lsp.Position(line=57, character=7), + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + defined=lsp.Location( uri=f"file:///{PATCH_DIR / 'Basic.spn'}", range=lsp.Range( start=lsp.Position(line=23, character=4), end=lsp.Position(line=23, character=9), ), ), - }, - { - # Memory - "name": "lap2a", - "referenced": lsp.Position(line=72, character=7), - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - "defined": lsp.Location( + ), + DefinitionTestCase( + name="lap2a", + referenced=lsp.Position(line=72, character=7), + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + defined=lsp.Location( uri=f"file:///{PATCH_DIR / 'Basic.spn'}", range=lsp.Range( start=lsp.Position(line=16, character=4), end=lsp.Position(line=16, character=9), ), ), - }, - { - # Memory. Note that this has an address modifier, but still points to the - # original definition. 
-        "name": "lap2a#",
-        "referenced": lsp.Position(line=71, character=7),
-        "uri": f"file:///{PATCH_DIR / 'Basic.spn'}",
-        "defined": lsp.Location(
+    ),
+    DefinitionTestCase(
+        name="lap2a#",
+        referenced=lsp.Position(line=71, character=7),
+        uri=f"file:///{PATCH_DIR / 'Basic.spn'}",
+        defined=lsp.Location(
             uri=f"file:///{PATCH_DIR / 'Basic.spn'}",
             range=lsp.Range(
                 start=lsp.Position(line=16, character=4),
                 end=lsp.Position(line=16, character=9),
             ),
         ),
-    },
-    {
-        # Label
-        "name": "endclr",
-        "referenced": lsp.Position(line=37, character=9),
-        "uri": f"file:///{PATCH_DIR / 'Basic.spn'}",
-        "defined": lsp.Location(
+    ),
+    DefinitionTestCase(
+        name="endclr",
+        referenced=lsp.Position(line=37, character=9),
+        uri=f"file:///{PATCH_DIR / 'Basic.spn'}",
+        defined=lsp.Location(
             uri=f"file:///{PATCH_DIR / 'Basic.spn'}",
             range=lsp.Range(
                 start=lsp.Position(line=41, character=0),
                 end=lsp.Position(line=41, character=6),
             ),
         ),
-    },
+    ),
 ]
 
 
 @pytest.mark.asyncio()
-@pytest.mark.parametrize("test_case", DEFINITIONS, ids=lambda x: x["name"])
+@parametrize_cases(TEST_CASES)
 async def test_definition(test_case: DefinitionTestCase, client: LanguageClient):
     """Test that the definition location of different assignments is correct."""
     result = await client.text_document_definition_async(
         params=lsp.DefinitionParams(
-            position=test_case["referenced"],
-            text_document=lsp.TextDocumentIdentifier(uri=test_case["uri"]),
+            position=test_case.referenced,
+            text_document=lsp.TextDocumentIdentifier(uri=test_case.uri),
         )
     )
 
-    assert result == test_case["defined"]
+    assert result == test_case.defined
diff --git a/tests/server_tests/test_diagnostics.py b/tests/server_tests/test_diagnostics.py
index 16610e8..af8b3ff 100644
--- a/tests/server_tests/test_diagnostics.py
+++ b/tests/server_tests/test_diagnostics.py
@@ -1,70 +1,97 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
 import lsprotocol.types as lsp
 import pytest
 from pytest_lsp import LanguageClient
 
+from ..conftest import TestCase, parametrize_cases
 
-@pytest.mark.asyncio()
-async def test_diagnostic_parsing_errors(client: LanguageClient):
-    """Test that parsing errors and warnings are correctly reported by the server."""
-    source_with_errors = """
-; Undefined symbol a
-SOF 0,a
 
-; Label REG0 re-defined
-REG0 EQU 4
+@dataclass
+class DiagnosticTestCase(TestCase):
+    """A dataclass recording the diagnostics expected for a source snippet."""
+
+    source: str
+    expected: list[lsp.Diagnostic]
 
-; Register out of range
-MULX 100
-"""
+TEST_CASES: list[DiagnosticTestCase] = [
+    DiagnosticTestCase(
+        name="undefined label",
+        source="""SOF 0, a\n""",
+        expected=[
+            lsp.Diagnostic(
+                range=lsp.Range(
+                    start=lsp.Position(line=0, character=7),
+                    end=lsp.Position(line=0, character=7),
+                ),
+                message="Undefined label a",
+                severity=lsp.DiagnosticSeverity.Error,
+                source="SPINAsm",
+            ),
+        ],
+    ),
+    DiagnosticTestCase(
+        name="redefined constant",
+        source="""REG0 EQU 4\n""",
+        expected=[
+            lsp.Diagnostic(
+                range=lsp.Range(
+                    start=lsp.Position(line=0, character=9),
+                    end=lsp.Position(line=0, character=9),
+                ),
+                message="Label REG0 re-defined",
+                severity=lsp.DiagnosticSeverity.Warning,
+                source="SPINAsm",
+            ),
+        ],
+    ),
+    DiagnosticTestCase(
+        name="out of range",
+        source="""MULX 100\n""",
+        expected=[
+            lsp.Diagnostic(
+                range=lsp.Range(
+                    start=lsp.Position(line=0, character=0),
+                    end=lsp.Position(line=0, character=0),
+                ),
+                message="Register 0x64 out of range for MULX",
+                severity=lsp.DiagnosticSeverity.Error,
+                source="SPINAsm",
+            ),
+        ],
+    ),
+]
+
+
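
These cases exercise the diagnostics end-to-end through the client, but the positions come straight from the parser: SPINAsmDiagnosticParser records asfv1's warnings and errors as lsp.Diagnostic objects, converting asfv1's 1-indexed report lines to 0-indexed LSP positions. A minimal sketch of checking the same diagnostics without a running server, assuming the spinasm_lsp package from this patch is importable:

    from spinasm_lsp.parser import SPINAsmParser

    # The parser replaces asfv1's print-style warning/error callbacks and
    # collects lsp.Diagnostic objects on `diagnostics`; parse() returns self.
    parser = SPINAsmParser("MULX 100\n").parse()
    for diag in parser.diagnostics:
        print(diag.range.start, diag.severity, diag.message)
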
+@parametrize_cases(TEST_CASES) +@pytest.mark.asyncio() +async def test_diagnostic_parsing_errors( + test_case: DiagnosticTestCase, client: LanguageClient +): + """Test that parsing errors and warnings are correctly reported by the server.""" # We need a URI to associate with the source, but it doesn't need to be a real file. test_uri = "dummy_uri" + client.text_document_did_open( lsp.DidOpenTextDocumentParams( text_document=lsp.TextDocumentItem( uri=test_uri, language_id="spinasm", version=1, - text=source_with_errors, + text=test_case.source, ) ) ) await client.wait_for_notification(lsp.TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS) - expected = [ - lsp.Diagnostic( - range=lsp.Range( - start=lsp.Position(line=2, character=6), - end=lsp.Position(line=2, character=6), - ), - message="Undefined label a", - severity=lsp.DiagnosticSeverity.Error, - source="SPINAsm", - ), - lsp.Diagnostic( - range=lsp.Range( - start=lsp.Position(line=5, character=9), - end=lsp.Position(line=5, character=9), - ), - message="Label REG0 re-defined", - severity=lsp.DiagnosticSeverity.Warning, - source="SPINAsm", - ), - lsp.Diagnostic( - range=lsp.Range( - start=lsp.Position(line=8, character=0), - end=lsp.Position(line=8, character=0), - ), - message="Register 0x64 out of range for MULX", - severity=lsp.DiagnosticSeverity.Error, - source="SPINAsm", - ), - ] - returned = client.diagnostics[test_uri] - extra = len(returned) - len(expected) - assert extra == 0, f"Expected {len(expected)} diagnostics, got {len(returned)}." + assert len(returned) == len( + test_case.expected + ), "Expected number of diagnostics does not match" - for i, diag in enumerate(expected): - assert diag == returned[i], f"Diagnostic {i} does not match expected" + for expected, actual in zip(test_case.expected, returned): + assert actual == expected, "Diagnostic does not match expected" diff --git a/tests/server_tests/test_hover.py b/tests/server_tests/test_hover.py index 49ce7dc..a0ac7cf 100644 --- a/tests/server_tests/test_hover.py +++ b/tests/server_tests/test_hover.py @@ -1,12 +1,15 @@ from __future__ import annotations +from dataclasses import dataclass + import lsprotocol.types as lsp import pytest from pytest_lsp import LanguageClient -from ..conftest import PATCH_DIR, TestCase +from ..conftest import PATCH_DIR, TestCase, parametrize_cases +@dataclass class HoverTestCase(TestCase): """A dictionary to record hover information for a symbol.""" @@ -15,79 +18,76 @@ class HoverTestCase(TestCase): uri: str -HOVERS: list[HoverTestCase] = [ - { - "name": "mem", - "position": lsp.Position(line=8, character=0), - "contains": "`MEM`", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - "name": "skp", - "position": lsp.Position(line=37, character=2), - "contains": "`SKP CMASK, N`", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - "name": "endclr", - "position": lsp.Position(line=37, character=13), - "contains": "(label) ENDCLR: Offset[4]", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - "name": "mono", - "position": lsp.Position(line=47, character=5), - "contains": "(variable) MONO: Literal[32]", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - "name": "reg0", - "position": lsp.Position(line=22, character=9), - "contains": "(constant) REG0: Literal[32]", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - "name": "lap2b#", - "position": lsp.Position(line=73, character=4), - "contains": "(variable) LAP2B#: Literal[9802]", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - # CHO RDA, hovering over CHO - "name": 
"CHO_rda", - "position": lsp.Position(line=85, character=0), - "contains": "`CHO RDA, N, C, D`", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - # CHO RDA, hovering over RDA - "name": "cho_RDA", - "position": lsp.Position(line=85, character=4), - "contains": "`CHO RDA, N, C, D`", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - # Hovering over an int, which should return no hover info - "name": "None", - "position": lsp.Position(line=8, character=8), - "contains": None, - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, +TEST_CASES: list[HoverTestCase] = [ + HoverTestCase( + name="mem", + position=lsp.Position(line=8, character=0), + contains="`MEM`", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + HoverTestCase( + name="skp", + position=lsp.Position(line=37, character=2), + contains="`SKP CMASK, N`", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + HoverTestCase( + name="endclr", + position=lsp.Position(line=37, character=13), + contains="(label) ENDCLR: Offset[4]", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + HoverTestCase( + name="mono", + position=lsp.Position(line=47, character=5), + contains="(variable) MONO: Literal[32]", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + HoverTestCase( + name="reg0", + position=lsp.Position(line=22, character=9), + contains="(constant) REG0: Literal[32]", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + HoverTestCase( + name="lap2b#", + position=lsp.Position(line=73, character=4), + contains="(variable) LAP2B#: Literal[9802]", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + HoverTestCase( + name="CHO_rda", + position=lsp.Position(line=85, character=0), + contains="`CHO RDA, N, C, D`", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + HoverTestCase( + name="cho_RDA", + position=lsp.Position(line=85, character=4), + contains="`CHO RDA, N, C, D`", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + HoverTestCase( + name="None", + position=lsp.Position(line=8, character=8), + contains=None, + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), ] -@pytest.mark.parametrize("test_case", HOVERS, ids=lambda x: x["name"]) +@parametrize_cases(TEST_CASES) @pytest.mark.asyncio() -async def test_hover(test_case: dict, client: LanguageClient): +async def test_hover(test_case: HoverTestCase, client: LanguageClient): result = await client.text_document_hover_async( params=lsp.CompletionParams( - position=test_case["position"], - text_document=lsp.TextDocumentIdentifier(uri=test_case["uri"]), + position=test_case.position, + text_document=lsp.TextDocumentIdentifier(uri=test_case.uri), ) ) - if test_case["contains"] is None: + if test_case.contains is None: assert result is None, "Expected no hover result" else: - msg = f"Hover does not contain `{test_case['contains']}`" - assert test_case["contains"] in result.contents.value, msg + msg = f"Hover does not contain `{test_case.contains}`" + assert test_case.contains in result.contents.value, msg diff --git a/tests/server_tests/test_prepare_rename.py b/tests/server_tests/test_prepare_rename.py index 9206c7a..ea524fb 100644 --- a/tests/server_tests/test_prepare_rename.py +++ b/tests/server_tests/test_prepare_rename.py @@ -1,12 +1,15 @@ from __future__ import annotations +from dataclasses import dataclass + import lsprotocol.types as lsp import pytest from pytest_lsp import LanguageClient -from ..conftest import PATCH_DIR, TestCase +from ..conftest import PATCH_DIR, TestCase, parametrize_cases +@dataclass class PrepareRenameTestCase(TestCase): """A dictionary to record prepare rename results 
for a symbol.""" @@ -16,53 +19,53 @@ class PrepareRenameTestCase(TestCase): uri: str -PREPARE_RENAMES: list[PrepareRenameTestCase] = [ - { - "name": "mem", - "position": lsp.Position(line=8, character=0), - "result": None, - "message": "Can't rename non-user defined token MEM.", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - "name": "reg0", - "position": lsp.Position(line=22, character=10), - "result": None, - "message": "Can't rename non-user defined token REG0.", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - "name": "ap1", - "position": lsp.Position(line=8, character=4), - "result": lsp.PrepareRenameResult_Type2(default_behavior=True), - "message": None, - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - "name": "endclr", - "position": lsp.Position(line=37, character=10), - "result": lsp.PrepareRenameResult_Type2(default_behavior=True), - "message": None, - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, +TEST_CASES: list[PrepareRenameTestCase] = [ + PrepareRenameTestCase( + name="mem", + position=lsp.Position(line=8, character=0), + result=None, + message="Can't rename non-user defined token MEM.", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + PrepareRenameTestCase( + name="reg0", + position=lsp.Position(line=22, character=10), + result=None, + message="Can't rename non-user defined token REG0.", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + PrepareRenameTestCase( + name="ap1", + position=lsp.Position(line=8, character=4), + result=lsp.PrepareRenameResult_Type2(default_behavior=True), + message=None, + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + PrepareRenameTestCase( + name="endclr", + position=lsp.Position(line=37, character=10), + result=lsp.PrepareRenameResult_Type2(default_behavior=True), + message=None, + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), ] -@pytest.mark.parametrize("test_case", PREPARE_RENAMES, ids=lambda x: x["name"]) +@parametrize_cases(TEST_CASES) @pytest.mark.asyncio() async def test_prepare_rename(test_case: PrepareRenameTestCase, client: LanguageClient): """Test that prepare rename prevents renaming non-user defined tokens.""" result = await client.text_document_prepare_rename_async( params=lsp.PrepareRenameParams( - position=test_case["position"], - text_document=lsp.TextDocumentIdentifier(uri=test_case["uri"]), + position=test_case.position, + text_document=lsp.TextDocumentIdentifier(uri=test_case.uri), ) ) - assert result == test_case["result"] + assert result == test_case.result - if test_case["message"]: - assert test_case["message"] in client.log_messages[0].message + if test_case.message: + assert test_case.message in client.log_messages[0].message assert client.log_messages[0].type == lsp.MessageType.Info else: assert not client.log_messages diff --git a/tests/server_tests/test_reference.py b/tests/server_tests/test_reference.py index 080028e..12fe6e8 100644 --- a/tests/server_tests/test_reference.py +++ b/tests/server_tests/test_reference.py @@ -1,12 +1,15 @@ from __future__ import annotations +from dataclasses import dataclass + import lsprotocol.types as lsp import pytest from pytest_lsp import LanguageClient -from ..conftest import PATCH_DIR, TestCase +from ..conftest import PATCH_DIR, TestCase, parametrize_cases +@dataclass class ReferenceTestCase(TestCase): """A dictionary to record reference locations for a symbol.""" @@ -15,13 +18,12 @@ class ReferenceTestCase(TestCase): uri: str -REFERENCES: list[ReferenceTestCase] = [ - { - # Variable - "name": "apout", - "position": lsp.Position(line=23, 
character=4), - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - "references": [ +TEST_CASES: list[ReferenceTestCase] = [ + ReferenceTestCase( + name="apout", + position=lsp.Position(line=23, character=4), + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + references=[ lsp.Location( uri=f"file:///{PATCH_DIR / 'Basic.spn'}", range=lsp.Range( @@ -51,12 +53,12 @@ class ReferenceTestCase(TestCase): ), ), ], - }, - { - "name": "ap1", - "position": lsp.Position(line=8, character=4), - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - "references": [ + ), + ReferenceTestCase( + name="ap1", + position=lsp.Position(line=8, character=4), + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + references=[ lsp.Location( uri=f"file:///{PATCH_DIR / 'Basic.spn'}", range=lsp.Range( @@ -79,20 +81,20 @@ class ReferenceTestCase(TestCase): ), ), ], - }, + ), ] -@pytest.mark.parametrize("test_case", REFERENCES, ids=lambda x: x["name"]) +@parametrize_cases(TEST_CASES) @pytest.mark.asyncio() async def test_references(test_case: ReferenceTestCase, client: LanguageClient): """Test that references to a symbol are correctly found.""" result = await client.text_document_references_async( params=lsp.ReferenceParams( context=lsp.ReferenceContext(include_declaration=False), - position=test_case["position"], - text_document=lsp.TextDocumentIdentifier(uri=test_case["uri"]), + position=test_case.position, + text_document=lsp.TextDocumentIdentifier(uri=test_case.uri), ) ) - assert result == test_case["references"] + assert result == test_case.references diff --git a/tests/server_tests/test_rename.py b/tests/server_tests/test_rename.py index 3c65ffa..c657af5 100644 --- a/tests/server_tests/test_rename.py +++ b/tests/server_tests/test_rename.py @@ -1,12 +1,15 @@ from __future__ import annotations +from dataclasses import dataclass + import lsprotocol.types as lsp import pytest from pytest_lsp import LanguageClient -from ..conftest import PATCH_DIR, TestCase +from ..conftest import PATCH_DIR, TestCase, parametrize_cases +@dataclass class RenameTestCase(TestCase): """A dictionary to record rename results for a symbol.""" @@ -16,13 +19,13 @@ class RenameTestCase(TestCase): uri: str -RENAMES: list[RenameTestCase] = [ - { - "name": "ap1", - "rename_to": "FOO", - "position": lsp.Position(line=8, character=4), - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - "changes": [ +TEST_CASES: list[RenameTestCase] = [ + RenameTestCase( + name="ap1", + rename_to="FOO", + position=lsp.Position(line=8, character=4), + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + changes=[ lsp.TextEdit( range=lsp.Range(start=lsp.Position(8, 4), end=lsp.Position(8, 7)), new_text="FOO", @@ -37,13 +40,13 @@ class RenameTestCase(TestCase): new_text="FOO", ), ], - }, - { - "name": "endclr", - "rename_to": "END", - "position": lsp.Position(line=41, character=0), - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - "changes": [ + ), + RenameTestCase( + name="endclr", + rename_to="END", + position=lsp.Position(line=41, character=0), + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + changes=[ lsp.TextEdit( range=lsp.Range(start=lsp.Position(37, 8), end=lsp.Position(37, 14)), new_text="END", @@ -53,13 +56,13 @@ class RenameTestCase(TestCase): new_text="END", ), ], - }, - { - "name": "lap1a#", - "rename_to": "FOO", - "position": lsp.Position(line=61, character=4), - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - "changes": [ + ), + RenameTestCase( + name="lap1a#", + rename_to="FOO", + position=lsp.Position(line=61, character=4), + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + changes=[ # Renaming 
`lap1a#` should also rename `lap1a`
             lsp.TextEdit(
                 range=lsp.Range(start=lsp.Position(12, 4), end=lsp.Position(12, 9)),
@@ -74,20 +77,20 @@ class RenameTestCase(TestCase):
                 new_text="FOO",
             ),
         ],
-    },
+    ),
 ]
 
 
-@pytest.mark.parametrize("test_case", RENAMES, ids=lambda x: x["name"])
+@parametrize_cases(TEST_CASES)
 @pytest.mark.asyncio()
 async def test_rename(test_case: RenameTestCase, client: LanguageClient):
     """Test that renaming a symbol suggests the correct edits."""
     result = await client.text_document_rename_async(
         params=lsp.RenameParams(
-            position=test_case["position"],
-            new_name=test_case["rename_to"],
-            text_document=lsp.TextDocumentIdentifier(uri=test_case["uri"]),
+            position=test_case.position,
+            new_name=test_case.rename_to,
+            text_document=lsp.TextDocumentIdentifier(uri=test_case.uri),
         )
     )
 
-    assert result.changes[test_case["uri"]] == test_case["changes"]
+    assert result.changes[test_case.uri] == test_case.changes
diff --git a/tests/server_tests/test_semantics.py b/tests/server_tests/test_semantics.py
new file mode 100644
index 0000000..80c27dd
--- /dev/null
+++ b/tests/server_tests/test_semantics.py
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+import itertools
+import tempfile
+from dataclasses import dataclass
+
+import lsprotocol.types as lsp
+import pytest
+from pytest_lsp import LanguageClient
+
+from ..conftest import TestCase, parametrize_cases
+
+
+@dataclass
+class SemanticTestCase(TestCase):
+    """A dataclass recording the expected semantic token encoding of a source."""
+
+    source: str
+    encoding: list[int]
+
+
+# fmt: off
+TEST_CASES: list[SemanticTestCase] = [
+    SemanticTestCase(
+        name="variable definition",
+        source="""Delay MEM REG0""",
+        encoding=[
+            0, 0, 5, 8, 0b10, # variable, definition
+            0, 6, 3, 21, 0b0, # operator
+            0, 4, 4, 8, 0b1000000100, # variable, constant readonly
+        ],
+    ),
+    SemanticTestCase(
+        name="label and opcode",
+        source="""start:\nsof 0,0""",
+        encoding=[
+            0, 0, 5, 0, 0b10, # namespace, definition
+            1, 0, 3, 12, 0b0, # function
+            0, 4, 1, 19, 0b0, # number
+            0, 1, 1, 21, 0b0, # argsep
+            0, 1, 1, 19, 0b0, # number
+        ],
+    ),
+]
+# fmt: on
+
+
+@parametrize_cases(TEST_CASES)
+@pytest.mark.asyncio()
+async def test_semantic_tokens(
+    test_case: SemanticTestCase, client: LanguageClient
+) -> None:
+    def batched(iterable, n):
+        """
+        Partial backport of itertools.batched from Python 3.12.
+ + https://docs.python.org/3/library/itertools.html#itertools.batched + """ + iterator = iter(iterable) + while batch := tuple(itertools.islice(iterator, n)): + yield batch + + tmp = tempfile.NamedTemporaryFile() + with open(tmp.name, "w") as dst: + dst.write(test_case.source) + + response = await client.text_document_semantic_tokens_full_async( + params=lsp.SemanticTokensParams( + text_document=lsp.TextDocumentIdentifier( + uri=f"file:///{tmp.name}", + ), + ) + ) + + assert len(response.data) == len(test_case.encoding), "Unexpected encoding length" + + # Compare encodings 1 token at a time to make it easier to diagnose issues + for got, expected in zip(batched(response.data, 5), batched(test_case.encoding, 5)): + assert got == expected diff --git a/tests/server_tests/test_signature_help.py b/tests/server_tests/test_signature_help.py index b752e46..014a4e2 100644 --- a/tests/server_tests/test_signature_help.py +++ b/tests/server_tests/test_signature_help.py @@ -1,12 +1,15 @@ from __future__ import annotations +from dataclasses import dataclass + import lsprotocol.types as lsp import pytest from pytest_lsp import LanguageClient -from ..conftest import PATCH_DIR, TestCase +from ..conftest import PATCH_DIR, TestCase, parametrize_cases +@dataclass class SignatureHelpTestCase(TestCase): """A dictionary to record signature help information for at a position.""" @@ -17,73 +20,67 @@ class SignatureHelpTestCase(TestCase): uri: str -SIGNATURE_HELPS: list[SignatureHelpTestCase] = [ - { - # No opcode on this line, so the signature help should be None - "name": "no_opcode", - "position": lsp.Position(line=8, character=3), - "active_parameter": None, - "doc_contains": None, - "param_contains": None, - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - "name": "skp_first_arg", - "position": lsp.Position(line=37, character=4), - "active_parameter": 0, - "doc_contains": "**`SKP CMASK, N`** allows conditional program execution", - "param_contains": "CMASK: Binary | Hex ($00-$1F) | Symbolic", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - "name": "skp_second_arg", - "position": lsp.Position(line=37, character=8), - "active_parameter": 1, - "doc_contains": "**`SKP CMASK, N`** allows conditional program execution", - "param_contains": "N: Decimal (1-63) | Label", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - # You should still get the last argument even if you're well beyond it - "name": "skp_out_of_bounds", - "position": lsp.Position(line=37, character=45), - "active_parameter": 1, - "doc_contains": "**`SKP CMASK, N`** allows conditional program execution", - "param_contains": "N: Decimal (1-63) | Label", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - # The "first" argument of CHO RDA should be N, not RDA - "name": "cho_rda", - "position": lsp.Position(line=85, character=8), - "active_parameter": 0, - "doc_contains": "**`CHO RDA, N, C, D`**, like the `RDA` instruction", - "param_contains": "N: LFO select: SIN0,SIN1,RMP0,RMP1", - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - # Triggering signature help before finishing the opcode should return None - "name": "cho_rda_unfinished", - "position": lsp.Position(line=85, character=0), - "active_parameter": None, - "doc_contains": None, - "param_contains": None, - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, - { - # Triggering signature help after finishing, before the comma in a multi-word - # instruction should return none - "name": "cho_rda_before_comma", - "position": lsp.Position(line=85, character=7), - 
"active_parameter": None, - "doc_contains": None, - "param_contains": None, - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - }, +TEST_CASES: list[SignatureHelpTestCase] = [ + SignatureHelpTestCase( + name="no_opcode", + position=lsp.Position(line=8, character=3), + active_parameter=None, + doc_contains=None, + param_contains=None, + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + SignatureHelpTestCase( + name="skp_first_arg", + position=lsp.Position(line=37, character=4), + active_parameter=0, + doc_contains="**`SKP CMASK, N`** allows conditional program execution", + param_contains="CMASK: Binary | Hex ($00-$1F) | Symbolic", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + SignatureHelpTestCase( + name="skp_second_arg", + position=lsp.Position(line=37, character=8), + active_parameter=1, + doc_contains="**`SKP CMASK, N`** allows conditional program execution", + param_contains="N: Decimal (1-63) | Label", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + SignatureHelpTestCase( + name="skp_out_of_bounds", + position=lsp.Position(line=37, character=45), + active_parameter=1, + doc_contains="**`SKP CMASK, N`** allows conditional program execution", + param_contains="N: Decimal (1-63) | Label", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + SignatureHelpTestCase( + name="cho_rda", + position=lsp.Position(line=85, character=8), + active_parameter=0, + doc_contains="**`CHO RDA, N, C, D`**, like the `RDA` instruction", + param_contains="N: LFO select: SIN0,SIN1,RMP0,RMP1", + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + SignatureHelpTestCase( + name="cho_rda_unfinished", + position=lsp.Position(line=85, character=0), + active_parameter=None, + doc_contains=None, + param_contains=None, + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), + SignatureHelpTestCase( + name="cho_rda_before_comma", + position=lsp.Position(line=85, character=7), + active_parameter=None, + doc_contains=None, + param_contains=None, + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + ), ] -@pytest.mark.parametrize("test_case", SIGNATURE_HELPS, ids=lambda x: x["name"]) +@parametrize_cases(TEST_CASES) @pytest.mark.asyncio() async def test_signature_help(test_case: SignatureHelpTestCase, client: LanguageClient): result = await client.text_document_signature_help_async( @@ -92,18 +89,18 @@ async def test_signature_help(test_case: SignatureHelpTestCase, client: Language trigger_kind=lsp.SignatureHelpTriggerKind.TriggerCharacter, is_retrigger=False, ), - position=test_case["position"], - text_document=lsp.TextDocumentIdentifier(uri=test_case["uri"]), + position=test_case.position, + text_document=lsp.TextDocumentIdentifier(uri=test_case.uri), ) ) - if test_case["active_parameter"] is None: + if test_case.active_parameter is None: assert not result return sig: lsp.SignatureInformation = result.signatures[result.active_signature] param: lsp.ParameterInformation = sig.parameters[result.active_parameter] - assert test_case["active_parameter"] == result.active_parameter - assert test_case["doc_contains"] in str(sig.documentation) - assert test_case["param_contains"] in param.label + assert test_case.active_parameter == result.active_parameter + assert test_case.doc_contains in str(sig.documentation) + assert test_case.param_contains in param.label diff --git a/tests/server_tests/test_symbol_definition.py b/tests/server_tests/test_symbol_definition.py index 1e469df..0b1270a 100644 --- a/tests/server_tests/test_symbol_definition.py +++ b/tests/server_tests/test_symbol_definition.py @@ -1,12 +1,15 @@ from __future__ import annotations 
+from dataclasses import dataclass + import lsprotocol.types as lsp import pytest from pytest_lsp import LanguageClient -from ..conftest import PATCH_DIR, TestCase +from ..conftest import PATCH_DIR, TestCase, parametrize_cases +@dataclass class SymbolDefinitionTestCase(TestCase): """A dictionary to record definition locations for a symbol.""" @@ -15,41 +18,38 @@ class SymbolDefinitionTestCase(TestCase): uri: str -SYMBOL_DEFINITIONS: list[SymbolDefinitionTestCase] = [ - { - # Variable - "name": "apout", - "kind": lsp.SymbolKind.Variable, - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - "range": lsp.Range( +TEST_CASES: list[SymbolDefinitionTestCase] = [ + SymbolDefinitionTestCase( + name="apout", + kind=lsp.SymbolKind.Variable, + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + range=lsp.Range( start=lsp.Position(line=23, character=4), end=lsp.Position(line=23, character=9), ), - }, - { - # Memory - "name": "lap2a", - "kind": lsp.SymbolKind.Variable, - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - "range": lsp.Range( + ), + SymbolDefinitionTestCase( + name="lap2a", + kind=lsp.SymbolKind.Variable, + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + range=lsp.Range( start=lsp.Position(line=16, character=4), end=lsp.Position(line=16, character=9), ), - }, - { - # Label - "name": "endclr", - "kind": lsp.SymbolKind.Module, - "uri": f"file:///{PATCH_DIR / 'Basic.spn'}", - "range": lsp.Range( + ), + SymbolDefinitionTestCase( + name="endclr", + kind=lsp.SymbolKind.Module, + uri=f"file:///{PATCH_DIR / 'Basic.spn'}", + range=lsp.Range( start=lsp.Position(line=41, character=0), end=lsp.Position(line=41, character=6), ), - }, + ), ] -@pytest.mark.parametrize("test_case", SYMBOL_DEFINITIONS, ids=lambda x: x["name"]) +@parametrize_cases(TEST_CASES) @pytest.mark.asyncio() async def test_symbol_definitions( test_case: SymbolDefinitionTestCase, client: LanguageClient @@ -57,13 +57,13 @@ async def test_symbol_definitions( """Test that the definitions of all symbols in the document are returned.""" result = await client.text_document_document_symbol_async( params=lsp.DocumentSymbolParams( - text_document=lsp.TextDocumentIdentifier(uri=test_case["uri"]), + text_document=lsp.TextDocumentIdentifier(uri=test_case.uri), ) ) - matching = [item for item in result if item.name == test_case["name"].upper()] - assert matching, f"Symbol {test_case['name'].upper()} not in document symbols" + matching = [item for item in result if item.name == test_case.name.upper()] + assert matching, f"Symbol {test_case.name.upper()} not in document symbols" item = matching[0] - assert item.kind == test_case["kind"] - assert item.range == test_case["range"] + assert item.kind == test_case.kind + assert item.range == test_case.range diff --git a/tests/test_parser.py b/tests/test_parser.py index 5742703..b404fc9 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -2,111 +2,17 @@ from __future__ import annotations -import lsprotocol.types as lsp import pytest -from spinasm_lsp.parser import SPINAsmParser, Token, TokenRegistry +from spinasm_lsp.parser import SPINAsmParser -from .conftest import PATCH_DIR, TEST_PATCHES +from .conftest import TEST_PATCHES @pytest.mark.parametrize("patch", TEST_PATCHES, ids=lambda x: x.stem) def test_example_patches(patch): """Test that the example patches from SPINAsm are parsable.""" with open(patch, encoding="utf-8") as f: - assert SPINAsmParser(f.read()).parse() + parser = SPINAsmParser(f.read()).parse() - -@pytest.fixture() -def sentence_token_registry() -> tuple[str, TokenRegistry]: - """A sentence with 
a token registry for each word.""" - sentence = "This is a line with words." - - # Build a list of word tokens, ignoring whitespace. We'll build the tokens - # consistently with asfv1 parsed tokens. - words = list(filter(lambda x: x, sentence.split(" "))) - token_vals = [{"type": "LABEL", "txt": w, "stxt": w, "val": None} for w in words] - tokens = [] - col = 0 - - for t in token_vals: - start = sentence.index(t["txt"], col) - token = Token(t, start=lsp.Position(line=0, character=start)) - col = token.range.end.character + 1 - - tokens.append(token) - - return sentence, TokenRegistry(tokens) - - -def test_get_token_from_registry(sentence_token_registry): - """Test that tokens are correctly retrieved by position from a registry.""" - sentence, reg = sentence_token_registry - - # Manually build a mapping of column indexes to expected token words. Note that - # each word includes the whitespace immediately after it, which is consistent with - # other LSPs, and that all other whitespace is None. - token_positions = {i: None for i in range(len(sentence))} - for i in range(0, 5): - token_positions[i] = "This" - for i in range(7, 10): - token_positions[i] = "is" - for i in range(10, 12): - token_positions[i] = "a" - for i in range(12, 17): - token_positions[i] = "line" - for i in range(20, 25): - token_positions[i] = "with" - for i in range(25, 32): - token_positions[i] = "words." - - for i, word in token_positions.items(): - found_tok = reg.get_token_at_position(lsp.Position(line=0, character=i)) - found_val = found_tok.symbol["txt"] if found_tok is not None else found_tok - msg = f"Expected token `{word}` at col {i}, found `{found_val}`" - assert found_val == word, msg - - -def test_get_token_at_invalid_position_returns_none(sentence_token_registry): - """Test that retrieving tokens from out of bounds always returns None.""" - _, reg = sentence_token_registry - - assert reg.get_token_at_position(lsp.Position(line=99, character=99)) is None - - -def test_get_token_positions(): - """Test getting all positions of a token from a registry.""" - patch = PATCH_DIR / "Basic.spn" - with open(patch) as fp: - source = fp.read() - - parser = SPINAsmParser(source).parse() - - all_matches = parser.token_registry.get_matching_tokens("apout") - assert len(all_matches) == 4 - assert [t.range.start.line for t in all_matches] == [23, 57, 60, 70] - - -def test_concatenate_cho_rdal_tokens(): - """Test that CHO and RDAL tokens are concatenated correctly into CHO RDAL.""" - cho_rdal = Token( - symbol={"type": "MNEMONIC", "txt": "cho", "stxt": "CHO", "val": None}, - start=lsp.Position(line=0, character=0), - ).concatenate( - Token( - symbol={"type": "LABEL", "txt": "rdal", "stxt": "RDAL", "val": None}, - # Put whitespace between CHO and RDAL to test that range is calculated - start=lsp.Position(line=0, character=10), - ) - ) - - assert cho_rdal.symbol == { - "type": "MNEMONIC", - "txt": "cho rdal", - "stxt": "CHO RDAL", - "val": None, - } - - assert cho_rdal.range == lsp.Range( - start=lsp.Position(line=0, character=0), end=lsp.Position(line=0, character=14) - ) + assert list(parser.evaluated_tokens) diff --git a/tests/test_tokens.py b/tests/test_tokens.py new file mode 100644 index 0000000..f62b744 --- /dev/null +++ b/tests/test_tokens.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + +import lsprotocol.types as lsp +import pytest + +from spinasm_lsp.parser import SPINAsmParser +from spinasm_lsp.tokens import ASFV1Token, LSPToken, TokenLookup + +from .conftest 
import PATCH_DIR, TestCase, parametrize_cases
+
+
+@dataclass
+class TokenSemanticsTestCase(TestCase):
+    """A dataclass recording the expected semantic encoding of a token."""
+
+    token: LSPToken
+    encoding: list[int]
+    type: lsp.SemanticTokenTypes
+    modifiers: list[lsp.SemanticTokenModifiers] = field(default_factory=list)
+    prev_token_start: lsp.Position = field(
+        default_factory=lambda: lsp.Position(line=0, character=0)
+    )
+
+
+TOKEN_SEMANTICS: list[TokenSemanticsTestCase] = [
+    TokenSemanticsTestCase(
+        name="skp at start",
+        token=LSPToken(
+            type="MNEMONIC",
+            stxt="SKP",
+            range=lsp.Range(lsp.Position(0, 0), lsp.Position(0, 3)),
+        ),
+        encoding=[0, 0, 3, 12, 0b0],
+        type=lsp.SemanticTokenTypes.Function,
+    ),
+    TokenSemanticsTestCase(
+        name="variable on newline",
+        token=LSPToken(
+            type="LABEL",
+            stxt="TMP",
+            range=lsp.Range(lsp.Position(10, 0), lsp.Position(10, 3)),
+        ),
+        encoding=[9, 0, 3, 8, 0b0],
+        type=lsp.SemanticTokenTypes.Variable,
+        prev_token_start=lsp.Position(line=1, character=8),
+    ),
+    TokenSemanticsTestCase(
+        name="constant after token",
+        token=LSPToken(
+            type="LABEL",
+            stxt="REG0",
+            range=lsp.Range(lsp.Position(3, 15), lsp.Position(3, 19)),
+            is_constant=True,
+        ),
+        encoding=[0, 5, 4, 8, 0b1000000100],
+        type=lsp.SemanticTokenTypes.Variable,
+        modifiers=[
+            lsp.SemanticTokenModifiers.Readonly,
+            lsp.SemanticTokenModifiers.DefaultLibrary,
+        ],
+        prev_token_start=lsp.Position(line=3, character=10),
+    ),
+]
+
+
+@parametrize_cases(TOKEN_SEMANTICS)
+def test_semantic_tokens(test_case: TokenSemanticsTestCase):
+    """Test that the semantic tokens are correctly generated."""
+    encoding = test_case.token.semantic_encoding(test_case.prev_token_start)
+
+    assert test_case.token.semantic_type == test_case.type
+    assert test_case.token.semantic_modifiers == test_case.modifiers
+    assert encoding == test_case.encoding
+
+
+@pytest.fixture()
+def sentence_token_lookup() -> tuple[str, TokenLookup]:
+    """A sentence with a token lookup for each word."""
+    sentence = "This is a line with words."
+
+    # Build a list of word tokens, ignoring whitespace. We'll build the tokens
+    # consistently with asfv1 parsed tokens.
+    words = list(filter(lambda x: x, sentence.split(" ")))
+    token_vals = [ASFV1Token(type="LABEL", txt=w, stxt=w, val=None) for w in words]
+    tokens = []
+    col = 0
+
+    lookup = TokenLookup()
+    for t in token_vals:
+        start = sentence.index(t.txt, col)
+        parsed_token = t.at_position(lsp.Position(line=0, character=start))
+        eval_token = LSPToken.from_parsed_token(parsed_token)
+
+        col = eval_token.range.end.character + 1
+
+        tokens.append(eval_token)
+        lookup.add_token(parsed_token)
+
+    return sentence, lookup
+
+
+def test_get_token_from_registry(sentence_token_lookup: tuple[str, TokenLookup]):
+    """Test that tokens are correctly retrieved by position from a lookup."""
+    sentence, lookup = sentence_token_lookup
+
+    # Manually build a mapping of column indexes to expected token words. Note that
+    # each word includes the whitespace immediately after it, which is consistent with
+    # other LSPs, and that all other whitespace is None.
+    token_positions = {i: None for i in range(len(sentence))}
+    for i in range(0, 5):
+        token_positions[i] = "This"
+    for i in range(7, 10):
+        token_positions[i] = "is"
+    for i in range(10, 12):
+        token_positions[i] = "a"
+    for i in range(12, 17):
+        token_positions[i] = "line"
+    for i in range(20, 25):
+        token_positions[i] = "with"
+    for i in range(25, 32):
+        token_positions[i] = "words."
+
+    for i, word in token_positions.items():
+        found_tok = lookup.get(position=lsp.Position(line=0, character=i))
+        found_val = found_tok.stxt if found_tok is not None else found_tok
+        msg = f"Expected token `{word}` at col {i}, found `{found_val}`"
+        assert found_val == word, msg
+
+
+def test_get_token_at_invalid_position_returns_none(sentence_token_lookup):
+    """Test that retrieving tokens from out of bounds always returns None."""
+    _, lookup = sentence_token_lookup
+
+    assert lookup.get(position=lsp.Position(line=99, character=99)) is None
+
+
+def test_get_token_positions():
+    """Test getting all positions of a token from a lookup."""
+    patch = PATCH_DIR / "Basic.spn"
+    with open(patch) as fp:
+        source = fp.read()
+
+    parser = SPINAsmParser(source).parse()
+
+    all_matches = parser.evaluated_tokens.get(name="apout")
+    assert len(all_matches) == 4
+    assert [t.range.start.line for t in all_matches] == [23, 57, 60, 70]
+
+
+def test_concatenate_cho_rdal_tokens():
+    """Test that CHO and RDAL tokens are concatenated correctly into CHO RDAL."""
+    cho = ASFV1Token(type="MNEMONIC", txt="CHO", stxt="CHO", val=None).at_position(
+        start=lsp.Position(line=0, character=0)
+    )
+
+    # Put whitespace between CHO and RDAL to test that the merged range is calculated
+    rdal = ASFV1Token(type="LABEL", txt="RDAL", stxt="RDAL", val=None).at_position(
+        start=lsp.Position(line=0, character=10)
+    )
+
+    cho_rdal = cho.concatenate(rdal)
+
+    assert cho_rdal.stxt == "CHO RDAL"
+    assert cho_rdal.type == "MNEMONIC"
+    assert cho_rdal.range == lsp.Range(
+        start=lsp.Position(line=0, character=0), end=lsp.Position(line=0, character=14)
+    )
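
Note on the wire format for reviewers: the five-integer groups asserted in
test_semantics.py and test_tokens.py follow the LSP textDocument/semanticTokens/full
encoding: line delta, start-character delta, token length, token-type index, and a
modifier bitmask. The sketch below is not part of this patch; it decodes that data
against partial legends inferred from the comments in the tests above (the actual
legend is whatever the server registers in its capabilities):

# Hypothetical decoder; the legend indices are assumptions inferred from the
# test comments, not the server's registered legend.
TOKEN_TYPES = {0: "namespace", 8: "variable", 12: "function", 19: "number", 21: "operator"}
TOKEN_MODIFIERS = {1: "definition", 2: "readonly", 9: "defaultLibrary"}


def decode(data: list[int]) -> list[tuple[int, int, int, str, list[str]]]:
    """Decode flat semantic token data into (line, col, length, type, modifiers)."""
    tokens: list[tuple[int, int, int, str, list[str]]] = []
    line = col = 0
    for i in range(0, len(data), 5):
        delta_line, delta_start, length, type_idx, mod_mask = data[i : i + 5]
        line += delta_line
        # Start deltas are relative to the previous token only on the same line;
        # after a line delta, the start is an absolute column.
        col = col + delta_start if delta_line == 0 else delta_start
        mods = [name for bit, name in TOKEN_MODIFIERS.items() if mod_mask & (1 << bit)]
        tokens.append((line, col, length, TOKEN_TYPES.get(type_idx, "?"), mods))
    return tokens


# The "label and opcode" case decodes to the absolute positions of "start", "sof",
# "0", ",", and "0" in the source "start:\nsof 0,0":
assert decode(
    [0, 0, 5, 0, 0b10, 1, 0, 3, 12, 0b0, 0, 4, 1, 19, 0b0, 0, 1, 1, 21, 0b0, 0, 1, 1, 19, 0b0]
) == [
    (0, 0, 5, "namespace", ["definition"]),
    (1, 0, 3, "function", []),
    (1, 4, 1, "number", []),
    (1, 5, 1, "operator", []),
    (1, 6, 1, "number", []),
]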
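The semantic_encoding assertions in test_tokens.py produce the same groups in the
other direction: each token is delta-encoded against the start of the previous
token. A minimal sketch of that computation, using a hypothetical helper name
rather than this patch's implementation:

import lsprotocol.types as lsp


def delta_encode(
    start: lsp.Position,
    length: int,
    type_index: int,
    modifier_mask: int,
    prev_start: lsp.Position,
) -> list[int]:
    """Delta-encode one token relative to the previous token's start position."""
    delta_line = start.line - prev_start.line
    # Character offsets are relative only when both tokens share a line.
    delta_char = (
        start.character - prev_start.character if delta_line == 0 else start.character
    )
    return [delta_line, delta_char, length, type_index, modifier_mask]


# Matches the "variable on newline" case: TMP at (10, 0) after a token at (1, 8).
assert delta_encode(lsp.Position(10, 0), 3, 8, 0b0, lsp.Position(1, 8)) == [9, 0, 3, 8, 0]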