Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Untracked files handled differently (#528)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakethekoenig authored Feb 21, 2024
1 parent 3c16705 commit d6c88bf
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 41 deletions.
17 changes: 10 additions & 7 deletions mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from mentat.errors import PathValidationError
from mentat.feature_filters.default_filter import DefaultFilter
from mentat.feature_filters.embedding_similarity_filter import EmbeddingSimilarityFilter
from mentat.git_handler import get_paths_with_git_diffs
from mentat.include_files import (
PathType,
get_code_features_for_path,
Expand Down Expand Up @@ -45,6 +44,7 @@ class ContextStreamMessage(TypedDict):
features: List[str]
auto_features: List[str]
git_diff_paths: List[str]
git_untracked_paths: List[str]
total_tokens: int
total_cost: float

Expand Down Expand Up @@ -91,10 +91,12 @@ def refresh_context_display(self):
]
)
auto_features = get_consolidated_feature_refs(self.auto_features)
git_diff_paths = (
list(get_paths_with_git_diffs(self.git_root)) if self.git_root else []
)

if self.diff_context:
git_diff_paths = [str(p) for p in self.diff_context.diff_files()]
git_untracked_paths = [str(p) for p in self.diff_context.untracked_files()]
else:
git_diff_paths = []
git_untracked_paths = []
messages = ctx.conversation.get_messages()
code_message = get_code_message_from_features(
[
Expand Down Expand Up @@ -122,7 +124,8 @@ def refresh_context_display(self):
auto_context_tokens=ctx.config.auto_context_tokens,
features=features,
auto_features=auto_features,
git_diff_paths=[str(p) for p in git_diff_paths],
git_diff_paths=git_diff_paths,
git_untracked_paths=git_untracked_paths,
total_tokens=total_tokens,
total_cost=total_cost,
)
Expand Down Expand Up @@ -151,7 +154,7 @@ async def get_code_message(
if self.diff_context:
# Since there is no way of knowing when the git diff changes,
# we just refresh the cache every time get_code_message is called
self.diff_context.refresh_diff_files()
self.diff_context.refresh()
if self.diff_context.diff_files():
code_message += [
"Diff References:",
Expand Down
20 changes: 15 additions & 5 deletions mentat/diff_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
get_diff_for_file,
get_files_in_diff,
get_treeish_metadata,
get_untracked_files,
)
from mentat.interval import Interval
from mentat.session_context import SESSION_CONTEXT
Expand Down Expand Up @@ -153,21 +154,30 @@ def __init__(
self.name = name

_diff_files: List[Path] | None = None
_untracked_files: List[Path] | None = None

def diff_files(self) -> List[Path]:
if self._diff_files is None:
self.refresh_diff_files()
self.refresh()
return self._diff_files # pyright: ignore

def refresh_diff_files(self):
session_context = SESSION_CONTEXT.get()
def untracked_files(self) -> List[Path]:
if self._untracked_files is None:
self.refresh()
return self._untracked_files # pyright: ignore

def refresh(self):
ctx = SESSION_CONTEXT.get()

if self.target == "HEAD" and not check_head_exists():
self._diff_files = [] # A new repo without any commits
self._untracked_files = []
else:
self._diff_files = [
(session_context.cwd / f).resolve()
for f in get_files_in_diff(self.target)
(ctx.cwd / f).resolve() for f in get_files_in_diff(self.target)
]
self._untracked_files = [
(ctx.cwd / f).resolve() for f in get_untracked_files(ctx.cwd)
]

def get_annotations(self, rel_path: Path) -> list[DiffAnnotation]:
Expand Down
40 changes: 19 additions & 21 deletions mentat/git_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,25 @@
from mentat.utils import is_file_text_encoded


def get_untracked_files(root: Path, paths: list[Path] = []) -> list[str]:
"""Returns untracked files. --directory flag is used to show only directories
and not files in the output. A significant performance improvement when things
like node_modules are present."""
command = [
"git",
"ls-files",
"--exclude-standard",
"--others",
"--directory",
] + paths
result = subprocess.run(command, cwd=root, stdout=subprocess.PIPE)
untracked = result.stdout.decode("utf-8").strip()
if untracked != "":
return untracked.split("\n")
else:
return []


def get_non_gitignored_files(root: Path, visited: set[Path] = set()) -> Set[Path]:
paths = set(
# git returns / separated paths even on windows, convert so we can remove
Expand Down Expand Up @@ -49,27 +68,6 @@ def get_non_gitignored_files(root: Path, visited: set[Path] = set()) -> Set[Path
return file_paths


def get_paths_with_git_diffs(git_root: Path) -> set[Path]:
changed = subprocess.check_output(
["git", "diff", "--name-only"],
cwd=git_root,
text=True,
stderr=subprocess.DEVNULL,
).split("\n")
new = subprocess.check_output(
["git", "ls-files", "-o", "--exclude-standard"],
cwd=git_root,
text=True,
stderr=subprocess.DEVNULL,
).split("\n")
return set(
map(
lambda path: Path(os.path.realpath(os.path.join(git_root, Path(path)))),
changed + new,
)
)


def get_git_root_for_path(path: Path, raise_error: bool = True) -> Optional[Path]:
if os.path.isdir(path):
dir_path = path
Expand Down
3 changes: 3 additions & 0 deletions mentat/terminal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ async def _listen_for_context_updates(self):
features,
auto_features,
git_diff_paths,
git_untracked_paths,
total_tokens,
total_cost,
) = (
Expand All @@ -97,6 +98,7 @@ async def _listen_for_context_updates(self):
data["features"],
data["auto_features"],
set(Path(path) for path in data["git_diff_paths"]),
set(Path(path) for path in data["git_untracked_paths"]),
data["total_tokens"],
data["total_cost"],
)
Expand All @@ -107,6 +109,7 @@ async def _listen_for_context_updates(self):
features,
auto_features,
git_diff_paths,
git_untracked_paths,
total_tokens,
total_cost,
)
Expand Down
37 changes: 29 additions & 8 deletions mentat/terminal/terminal_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,28 +146,42 @@ def _build_sub_tree(
root: TreeNode[Any],
children: Dict[str, Any],
git_diff_paths: Set[Path],
untracked_paths: Set[Path],
untracked: bool = False,
):
for child, grandchildren in children.items():
new_path = cur_path / child
path_untracked = new_path in untracked_paths or untracked
if path_untracked:
label = f"[red]! {child}[/red]"
else:
label = child
if not grandchildren:
if new_path in git_diff_paths:
label = f"[green]* {child}[/green]"
else:
label = child
root.add_leaf(label)
else:
child_node = root.add(child, expand=True)
child_node = root.add(label, expand=True)
self._build_sub_tree(
new_path, child_node, grandchildren, git_diff_paths
new_path,
child_node,
grandchildren,
git_diff_paths,
untracked_paths,
path_untracked,
)

def _build_tree_widget(
self, files: list[str], cwd: Path, git_diff_paths: Set[Path]
self,
files: list[str],
cwd: Path,
git_diff_paths: Set[Path],
untracked_paths: Set[Path],
) -> Tree[Any]:
path_tree = self._build_path_tree(files, cwd)
tree: Tree[Any] = Tree(f"[blue]{cwd.name}[/blue]")
tree.root.expand()
self._build_sub_tree(cwd, tree.root, path_tree, git_diff_paths)
self._build_sub_tree(cwd, tree.root, path_tree, git_diff_paths, untracked_paths)
return tree

def update_context(
Expand All @@ -178,11 +192,16 @@ def update_context(
features: List[str],
auto_features: List[str],
git_diff_paths: Set[Path],
git_untracked_paths: Set[Path],
total_tokens: int,
total_cost: float,
):
feature_tree = self._build_tree_widget(features, cwd, git_diff_paths)
auto_feature_tree = self._build_tree_widget(auto_features, cwd, git_diff_paths)
feature_tree = self._build_tree_widget(
features, cwd, git_diff_paths, git_untracked_paths
)
auto_feature_tree = self._build_tree_widget(
auto_features, cwd, git_diff_paths, git_untracked_paths
)

context_header = ""
context_header += "[blue bold]Code Context:[/blue bold]"
Expand Down Expand Up @@ -266,6 +285,7 @@ def update_context(
features: List[str],
auto_features: List[str],
git_diff_paths: Set[Path],
git_untracked_paths: Set[Path],
total_tokens: int,
total_cost: float,
):
Expand All @@ -277,6 +297,7 @@ def update_context(
features,
auto_features,
git_diff_paths,
git_untracked_paths,
total_tokens,
total_cost,
)
Expand Down

0 comments on commit d6c88bf

Please sign in to comment.