Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Untracked files handled differently #528

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from mentat.errors import PathValidationError
from mentat.feature_filters.default_filter import DefaultFilter
from mentat.feature_filters.embedding_similarity_filter import EmbeddingSimilarityFilter
from mentat.git_handler import get_paths_with_git_diffs
from mentat.include_files import (
PathType,
get_code_features_for_path,
Expand Down Expand Up @@ -45,6 +44,7 @@ class ContextStreamMessage(TypedDict):
features: List[str]
auto_features: List[str]
git_diff_paths: List[str]
git_untracked_paths: List[str]
total_tokens: int
total_cost: float

Expand Down Expand Up @@ -91,10 +91,12 @@ def refresh_context_display(self):
]
)
auto_features = get_consolidated_feature_refs(self.auto_features)
git_diff_paths = (
list(get_paths_with_git_diffs(self.git_root)) if self.git_root else []
)

if self.diff_context:
git_diff_paths = [str(p) for p in self.diff_context.diff_files()]
git_untracked_paths = [str(p) for p in self.diff_context.untracked_files()]
else:
git_diff_paths = []
git_untracked_paths = []
messages = ctx.conversation.get_messages()
code_message = get_code_message_from_features(
[
Expand Down Expand Up @@ -122,7 +124,8 @@ def refresh_context_display(self):
auto_context_tokens=ctx.config.auto_context_tokens,
features=features,
auto_features=auto_features,
git_diff_paths=[str(p) for p in git_diff_paths],
git_diff_paths=git_diff_paths,
git_untracked_paths=git_untracked_paths,
total_tokens=total_tokens,
total_cost=total_cost,
)
Expand Down Expand Up @@ -151,7 +154,7 @@ async def get_code_message(
if self.diff_context:
# Since there is no way of knowing when the git diff changes,
# we just refresh the cache every time get_code_message is called
self.diff_context.refresh_diff_files()
self.diff_context.refresh()
if self.diff_context.diff_files():
code_message += [
"Diff References:",
Expand Down
20 changes: 15 additions & 5 deletions mentat/diff_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
get_diff_for_file,
get_files_in_diff,
get_treeish_metadata,
get_untracked_files,
)
from mentat.interval import Interval
from mentat.session_context import SESSION_CONTEXT
Expand Down Expand Up @@ -153,21 +154,30 @@ def __init__(
self.name = name

_diff_files: List[Path] | None = None
_untracked_files: List[Path] | None = None

def diff_files(self) -> List[Path]:
if self._diff_files is None:
self.refresh_diff_files()
self.refresh()
return self._diff_files # pyright: ignore

def refresh_diff_files(self):
session_context = SESSION_CONTEXT.get()
def untracked_files(self) -> List[Path]:
if self._untracked_files is None:
self.refresh()
return self._untracked_files # pyright: ignore

def refresh(self):
ctx = SESSION_CONTEXT.get()

if self.target == "HEAD" and not check_head_exists():
self._diff_files = [] # A new repo without any commits
self._untracked_files = []
else:
self._diff_files = [
(session_context.cwd / f).resolve()
for f in get_files_in_diff(self.target)
(ctx.cwd / f).resolve() for f in get_files_in_diff(self.target)
]
self._untracked_files = [
(ctx.cwd / f).resolve() for f in get_untracked_files(ctx.cwd)
]

def get_annotations(self, rel_path: Path) -> list[DiffAnnotation]:
Expand Down
40 changes: 19 additions & 21 deletions mentat/git_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,25 @@
from mentat.utils import is_file_text_encoded


def get_untracked_files(root: Path, paths: list[Path] = []) -> list[str]:
"""Returns untracked files. --directory flag is used to show only directories
and not files in the output. A significant performance improvement when things
like node_modules are present."""
command = [
"git",
"ls-files",
"--exclude-standard",
"--others",
"--directory",
] + paths
result = subprocess.run(command, cwd=root, stdout=subprocess.PIPE)
untracked = result.stdout.decode("utf-8").strip()
if untracked != "":
return untracked.split("\n")
else:
return []


def get_non_gitignored_files(root: Path, visited: set[Path] = set()) -> Set[Path]:
paths = set(
# git returns / separated paths even on windows, convert so we can remove
Expand Down Expand Up @@ -49,27 +68,6 @@ def get_non_gitignored_files(root: Path, visited: set[Path] = set()) -> Set[Path
return file_paths


def get_paths_with_git_diffs(git_root: Path) -> set[Path]:
changed = subprocess.check_output(
["git", "diff", "--name-only"],
cwd=git_root,
text=True,
stderr=subprocess.DEVNULL,
).split("\n")
new = subprocess.check_output(
["git", "ls-files", "-o", "--exclude-standard"],
cwd=git_root,
text=True,
stderr=subprocess.DEVNULL,
).split("\n")
return set(
map(
lambda path: Path(os.path.realpath(os.path.join(git_root, Path(path)))),
changed + new,
)
)


def get_git_root_for_path(path: Path, raise_error: bool = True) -> Optional[Path]:
if os.path.isdir(path):
dir_path = path
Expand Down
3 changes: 3 additions & 0 deletions mentat/terminal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ async def _listen_for_context_updates(self):
features,
auto_features,
git_diff_paths,
git_untracked_paths,
total_tokens,
total_cost,
) = (
Expand All @@ -97,6 +98,7 @@ async def _listen_for_context_updates(self):
data["features"],
data["auto_features"],
set(Path(path) for path in data["git_diff_paths"]),
set(Path(path) for path in data["git_untracked_paths"]),
data["total_tokens"],
data["total_cost"],
)
Expand All @@ -107,6 +109,7 @@ async def _listen_for_context_updates(self):
features,
auto_features,
git_diff_paths,
git_untracked_paths,
total_tokens,
total_cost,
)
Expand Down
37 changes: 29 additions & 8 deletions mentat/terminal/terminal_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,28 +146,42 @@ def _build_sub_tree(
root: TreeNode[Any],
children: Dict[str, Any],
git_diff_paths: Set[Path],
untracked_paths: Set[Path],
untracked: bool = False,
):
for child, grandchildren in children.items():
new_path = cur_path / child
path_untracked = new_path in untracked_paths or untracked
if path_untracked:
label = f"[red]! {child}[/red]"
else:
label = child
if not grandchildren:
if new_path in git_diff_paths:
label = f"[green]* {child}[/green]"
else:
label = child
root.add_leaf(label)
else:
child_node = root.add(child, expand=True)
child_node = root.add(label, expand=True)
self._build_sub_tree(
new_path, child_node, grandchildren, git_diff_paths
new_path,
child_node,
grandchildren,
git_diff_paths,
untracked_paths,
path_untracked,
)

def _build_tree_widget(
self, files: list[str], cwd: Path, git_diff_paths: Set[Path]
self,
files: list[str],
cwd: Path,
git_diff_paths: Set[Path],
untracked_paths: Set[Path],
) -> Tree[Any]:
path_tree = self._build_path_tree(files, cwd)
tree: Tree[Any] = Tree(f"[blue]{cwd.name}[/blue]")
tree.root.expand()
self._build_sub_tree(cwd, tree.root, path_tree, git_diff_paths)
self._build_sub_tree(cwd, tree.root, path_tree, git_diff_paths, untracked_paths)
return tree

def update_context(
Expand All @@ -178,11 +192,16 @@ def update_context(
features: List[str],
auto_features: List[str],
git_diff_paths: Set[Path],
git_untracked_paths: Set[Path],
total_tokens: int,
total_cost: float,
):
feature_tree = self._build_tree_widget(features, cwd, git_diff_paths)
auto_feature_tree = self._build_tree_widget(auto_features, cwd, git_diff_paths)
feature_tree = self._build_tree_widget(
features, cwd, git_diff_paths, git_untracked_paths
)
auto_feature_tree = self._build_tree_widget(
auto_features, cwd, git_diff_paths, git_untracked_paths
)

context_header = ""
context_header += "[blue bold]Code Context:[/blue bold]"
Expand Down Expand Up @@ -266,6 +285,7 @@ def update_context(
features: List[str],
auto_features: List[str],
git_diff_paths: Set[Path],
git_untracked_paths: Set[Path],
total_tokens: int,
total_cost: float,
):
Expand All @@ -277,6 +297,7 @@ def update_context(
features,
auto_features,
git_diff_paths,
git_untracked_paths,
total_tokens,
total_cost,
)
Expand Down
Loading