Skip to content

Commit

Permalink
Fixes + cleaning - Cmd and profiles (#19)
Browse files Browse the repository at this point in the history
Fixed multiple minor issues surrounding cmd mode which were bugging me.

Specifically:
* Upgraded default cmd prompt
* Added compatibility for cmd mode with profile, prompted _hard_ to make llm only suggest a cmd
* Output of cmd is now streamed to stdout and stderr, instead of printed afterward
* Single quotes are no longer removed as these doesn't seem to be missinterpreted as shell escaped as often
* Removed full printout of tools call json, as this was verbose, ugly and unreadable
  • Loading branch information
baalimago authored Jul 23, 2024
1 parent e7910e2 commit 1bb3aa6
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 40 deletions.
2 changes: 1 addition & 1 deletion internal/text/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type Profile struct {
var DEFAULT = Configurations{
Model: "gpt-4o",
SystemPrompt: "You are an assistant for a CLI tool. Answer concisely and informatively. Prefer markdown if possible.",
CmdModePrompt: "You are an assistant for a CLI tool aiding with cli tool suggestions. Write ONLY the command and nothing else.",
CmdModePrompt: "You are an assistant for a CLI tool aiding with cli tool suggestions. Write ONLY the command and nothing else. Disregard any queries asking for anything except a bash command. Do not shell escape single or double quotes.",
Raw: false,
UseTools: false,
// Aproximately $1 for the worst input rates as of 2024-05
Expand Down
9 changes: 7 additions & 2 deletions internal/text/conf_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,13 @@ func (c *Configurations) ProfileOverrides() error {
return fmt.Errorf("failed to find profile: %w", err)
}
c.Model = profile.Model
c.SystemPrompt = profile.Prompt
c.UseTools = profile.UseTools
newPrompt := profile.Prompt
if c.CmdMode {
// SystmePrompt here is CmdPrompt, keep it and remoind llm to only suggest cmd
newPrompt = fmt.Sprintf("You will get this pattern: || <cmd-prompt> | <custom guided profile> ||. It is VERY vital that you DO NOT disobey the <cmd-prompt> with whatever is posted in <custom guided profile. || %v| %v ||", c.CmdModePrompt, profile.Prompt)
}
c.SystemPrompt = newPrompt
c.UseTools = profile.UseTools && !c.CmdMode
c.Tools = profile.Tools
c.SaveReplyAsConv = profile.SaveReplyAsConv
return nil
Expand Down
9 changes: 8 additions & 1 deletion internal/text/querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type Querier[C models.StreamCompleter] struct {
Model C
tokenWarnLimit int
cmdMode bool
execErr error
}

// Query using the underlying model to stream completions and then print the output
Expand Down Expand Up @@ -75,6 +76,8 @@ func (q *Querier[C]) Query(ctx context.Context) error {
}
return nil
}
// Only add error if its not EOF or context.Canceled
q.execErr = err

if q.debug {
ancli.PrintOK("exiting querier due to EOF error\n")
Expand Down Expand Up @@ -181,6 +184,7 @@ func (q *Querier[C]) postProcessOutput(newSysMsg models.Message) {
}

func (q *Querier[C]) reset() {
q.execErr = nil
q.fullMsg = ""
q.line = ""
q.lineCount = 0
Expand Down Expand Up @@ -225,6 +229,9 @@ func (q *Querier[C]) handleCompletion(ctx context.Context, completion models.Com

// handleFunctionCall by invoking the call, and then resondng to the ai with the output
func (q *Querier[C]) handleFunctionCall(ctx context.Context, call tools.Call) error {
if q.cmdMode {
return errors.New("cant call tools in cmd mode")
}
// Whatever is in q.fullMessage now is what the AI has streamed before the function call
// which normally is handeled by the supercallee of Query, now we need to handle it here.
// There's room for improvement of this system..
Expand Down Expand Up @@ -261,7 +268,7 @@ func (q *Querier[C]) handleFunctionCall(ctx context.Context, call tools.Call) er
}
assistantToolsCall := models.Message{
Role: "assistant",
Content: fmt.Sprintf("tool_calls:\n%v", call.Json()),
Content: fmt.Sprintf("tool calls for: %v", call.Name),
ToolCalls: []tools.Call{call},
}
q.reset()
Expand Down
38 changes: 16 additions & 22 deletions internal/text/querier_cmd_mode.go
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
package text

import (
"bytes"
"errors"
"fmt"
"os"
"os/exec"
"strings"

"github.com/baalimago/clai/internal/utils"
"github.com/baalimago/go_away_boilerplate/pkg/ancli"
)

var (
errFormat = "code: %v, stderr: '%v', stdout: '%v'\n"
okFormat = "stdout on new line:\n%v\n"
)
var errFormat = "code: %v, stderr: '%v'\n"

func (q *Querier[C]) handleCmdMode() error {
// Tokens stream end without endline
fmt.Println()
var input string

if q.execErr != nil {
return nil
}

for {
fmt.Print("Do you want to [e]xecute cmd, [q]uit?: ")
fmt.Scanln(&input)
switch strings.ToLower(input) {
case "q":
return nil
case "e":
out, err := q.executeLlmCmd()
err := q.executeLlmCmd()
if err == nil {
ancli.PrintOK(fmt.Sprintf("%v\n", out))
return nil
} else {
return fmt.Errorf("failed to execute cmd: %v", err)
Expand All @@ -41,45 +41,39 @@ func (q *Querier[C]) handleCmdMode() error {
}
}

func (q *Querier[C]) executeLlmCmd() (string, error) {
func (q *Querier[C]) executeLlmCmd() error {
fullMsg, err := utils.ReplaceTildeWithHome(q.fullMsg)
if err != nil {
return "", fmt.Errorf("parseGlob, ReplaceTildeWithHome: %w", err)
return fmt.Errorf("parseGlob, ReplaceTildeWithHome: %w", err)
}
// Quotes are, in 99% of the time, expanded by the shell in
// different ways and then passed into the shell. So when LLM
// suggests a command, executeAiCmd needs to act the same (meaning)
// remove/expand the quotes
fullMsg = strings.ReplaceAll(fullMsg, "\"", "")
fullMsg = strings.ReplaceAll(fullMsg, "'", "")
split := strings.Split(fullMsg, " ")
if len(split) < 1 {
return "", errors.New("Querier.executeAiCmd: too few tokens in q.fullMsg")
return errors.New("Querier.executeAiCmd: too few tokens in q.fullMsg")
}
cmd := split[0]
args := split[1:]

if len(cmd) == 0 {
return "", errors.New("Querier.executeAiCmd: command is empty")
return errors.New("Querier.executeAiCmd: command is empty")
}

command := exec.Command(cmd, args...)
var stdout bytes.Buffer
var stderr bytes.Buffer
command.Stdout = &stdout
command.Stderr = &stderr
command.Stdout = os.Stdout
command.Stderr = os.Stderr
err = command.Run()
outStr := stdout.String()
errStr := stderr.String()

if err != nil {
cast := &exec.ExitError{}
if errors.As(err, &cast) {
return "", fmt.Errorf(errFormat, cast.ExitCode(), errStr, outStr)
return fmt.Errorf(errFormat, cast.ExitCode())
} else {
return "", fmt.Errorf("Querier.executeAiCmd - run error: %w", err)
return fmt.Errorf("Querier.executeAiCmd - run error: %w", err)
}
}

return fmt.Sprintf(okFormat, outStr), nil
return nil
}
28 changes: 15 additions & 13 deletions internal/text/querier_cmd_mode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package text

import (
"context"
"fmt"
"os"
"path/filepath"
"testing"
Expand Down Expand Up @@ -32,7 +31,7 @@ func Test_executeAiCmd(t *testing.T) {
{
description: "it should run shell cmd",
given: "printf 'test'",
want: fmt.Sprintf(okFormat, "test"),
want: "'test'",
wantErr: nil,
},
{
Expand All @@ -42,7 +41,7 @@ func Test_executeAiCmd(t *testing.T) {
os.Chdir(filepath.Dir(testboil.CreateTestFile(t, "testfile").Name()))
},
given: "find ./ -name \"testfile\"",
want: fmt.Sprintf(okFormat, "./testfile\n"),
want: "./testfile\n",
wantErr: nil,
},
{
Expand All @@ -52,22 +51,25 @@ func Test_executeAiCmd(t *testing.T) {
os.Chdir(filepath.Dir(testboil.CreateTestFile(t, "testfile").Name()))
},
given: "find ./ -name testfile",
want: fmt.Sprintf(okFormat, "./testfile\n"),
want: "./testfile\n",
wantErr: nil,
},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
q := Querier[mockCompleter]{}
if tc.setup != nil {
tc.setup(t)
}
q.fullMsg = tc.given
gotFormated, gotErr := q.executeLlmCmd()

if gotFormated != tc.want {
t.Fatalf("expected: %v, got: %v", tc.want, gotFormated)
var gotErr error
got := testboil.CaptureStdout(t, func(t *testing.T) {
q := Querier[mockCompleter]{}
if tc.setup != nil {
tc.setup(t)
}
q.fullMsg = tc.given
tmp := q.executeLlmCmd()
gotErr = tmp
})
if got != tc.want {
t.Fatalf("expected: %v, got: %v", tc.want, got)
}

if gotErr != tc.wantErr {
Expand Down
2 changes: 1 addition & 1 deletion internal/text/querier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func Test_Querier_eventHandling(t *testing.T) {
},
"CLOSE",
},
want: "tool_calls",
want: "tool calls",
},
}
for _, tC := range testCases {
Expand Down

0 comments on commit 1bb3aa6

Please sign in to comment.