mirror of
https://github.com/ollama/ollama.git
synced 2026-01-12 00:06:57 +08:00
1105 lines
30 KiB
Go
1105 lines
30 KiB
Go
// Package agent provides agent loop orchestration and tool approval.
|
|
package agent
|
|
|
|
import (
	"fmt"
	"os"
	"path"
	"path/filepath"
	"slices"
	"strings"
	"sync"

	"golang.org/x/term"
)
|
|
|
|
// ApprovalDecision enumerates the answers a user can give when asked whether
// a tool call may run.
type ApprovalDecision int

// The zero value is ApprovalDeny, so an unset decision never grants execution.
const (
	// ApprovalDeny refuses the tool call.
	ApprovalDeny ApprovalDecision = iota
	// ApprovalOnce permits this single call only.
	ApprovalOnce
	// ApprovalAlways permits the call and adds it to the session allowlist.
	ApprovalAlways
)
|
|
|
|
// ApprovalResult contains the decision and optional deny reason.
|
|
type ApprovalResult struct {
|
|
Decision ApprovalDecision
|
|
DenyReason string
|
|
}
|
|
|
|
// optionLabels holds the selector entries in display order. The leading digit
// of each label is the key that selects it directly, and the slice index is the
// value the selector reports (0 = once, 1 = session, 2 = deny).
var optionLabels = []string{
	0: "1. Execute once",
	1: "2. Allow for this session",
	2: "3. Deny",
}
|
|
|
|
// toolDisplayNames maps internal tool identifiers to the labels shown to the
// user. Tools not listed here get a label derived from their identifier.
var toolDisplayNames = map[string]string{
	"bash":       "Bash",
	"web_search": "Web Search",
}

// ToolDisplayName returns the human-readable label for toolName.
//
// Known tools use the label from toolDisplayNames. Any other name has its
// underscores replaced with spaces and its first character upper-cased
// ("read_file" -> "Read file"). An empty name is returned unchanged.
func ToolDisplayName(toolName string) string {
	if displayName, ok := toolDisplayNames[toolName]; ok {
		return displayName
	}

	name := strings.ReplaceAll(toolName, "_", " ")
	if name == "" {
		return toolName
	}

	// Locate the byte offset where the second character starts so the first
	// one is upper-cased as a whole; slicing name[:1] would cut a multi-byte
	// UTF-8 character in half and corrupt it.
	firstLen := len(name)
	for i := range name {
		if i > 0 {
			firstLen = i
			break
		}
	}
	return strings.ToUpper(name[:firstLen]) + name[firstLen:]
}
|
|
|
|
// autoAllowCommands is the set of command names that IsAutoAllowed accepts
// without prompting, matched against the first word of the command line.
// The entries are simple informational commands.
var autoAllowCommands = map[string]bool{
	"date":     true,
	"echo":     true,
	"hostname": true,
	"pwd":      true,
	"uname":    true,
	"whoami":   true,
}
|
|
|
|
// autoAllowPrefixes lists command-line prefixes that IsAutoAllowed accepts
// without prompting: read-only git queries, package-manager script runners,
// package information commands and common build/test commands.
var autoAllowPrefixes = []string{
	// git: inspection only
	"git status",
	"git log",
	"git diff",
	"git branch",
	"git show",
	"git remote -v",
	"git tag",
	"git stash list",

	// package managers: run project scripts
	"npm run", "npm test", "npm start",
	"bun run", "bun test",
	"uv run",
	"yarn run", "yarn test",
	"pnpm run", "pnpm test",

	// package information
	"go list", "go version", "go env",
	"npm list", "npm ls", "npm version",
	"pip list", "pip show",
	"cargo tree", "cargo version",

	// build and test
	"go build", "go test", "go fmt", "go vet",
	"make", "cmake",
	"cargo build", "cargo test", "cargo check",
}
|
|
|
|
// denyPatterns lists command fragments that IsDenied blocks outright. A
// command is rejected when it contains any entry, compared case-insensitively.
// Order matters: IsDenied reports the first entry that matches.
var denyPatterns = []string{
	// destructive file and device operations
	"rm -rf",
	"rm -fr",
	"mkfs",
	"dd if=",
	"dd of=",
	"shred",
	"> /dev/",
	">/dev/",

	// privilege escalation and ownership/permission changes
	"sudo ",
	"su ",
	"doas ",
	"chmod 777",
	"chmod -R 777",
	"chown ",
	"chgrp ",

	// sending data over the network
	"curl -d",
	"curl --data",
	"curl -X POST",
	"curl -X PUT",
	"wget --post",
	"nc ",
	"netcat ",
	"scp ",
	"rsync ",

	// shell history and credential stores
	"history",
	".bash_history",
	".zsh_history",
	".ssh/id_rsa",
	".ssh/id_dsa",
	".ssh/id_ecdsa",
	".ssh/id_ed25519",
	".ssh/config",
	".aws/credentials",
	".aws/config",
	".gnupg/",
	"/etc/shadow",
	"/etc/passwd",

	// miscellaneous dangerous constructs
	":(){ :|:& };:", // fork bomb
	"chmod +s",      // setuid
	"mkfifo",
}
|
|
|
|
// denyPathPatterns lists file-name fragments for secrets that commands must
// not touch. IsDenied rejects any command containing one of them
// (case-insensitive substring match), so ".key" also matches "server.key".
var denyPathPatterns = []string{
	// environment files
	".env", ".env.local", ".env.production",
	// credential and secret stores
	"credentials.json", "secrets.json", "secrets.yaml", "secrets.yml",
	// private key material
	".pem", ".key",
}
|
|
|
|
// ApprovalManager tracks which tool calls the user has approved for the
// current session and prompts for the ones that are not yet approved.
type ApprovalManager struct {
	// mu guards allowlist and prefixes.
	mu sync.RWMutex

	// allowlist holds exact keys: a tool name, or "bash:<command>".
	allowlist map[string]bool

	// prefixes holds bash directory prefixes of the form "cmd:dir/"
	// (e.g. "cat:tools/").
	prefixes map[string]bool
}
|
|
|
|
// NewApprovalManager creates a new approval manager.
|
|
func NewApprovalManager() *ApprovalManager {
|
|
return &ApprovalManager{
|
|
allowlist: make(map[string]bool),
|
|
prefixes: make(map[string]bool),
|
|
}
|
|
}
|
|
|
|
// IsAutoAllowed checks if a bash command is auto-allowed (no prompt needed).
|
|
func IsAutoAllowed(command string) bool {
|
|
command = strings.TrimSpace(command)
|
|
|
|
// Check exact command match (first word)
|
|
fields := strings.Fields(command)
|
|
if len(fields) > 0 && autoAllowCommands[fields[0]] {
|
|
return true
|
|
}
|
|
|
|
// Check prefix match
|
|
for _, prefix := range autoAllowPrefixes {
|
|
if strings.HasPrefix(command, prefix) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// IsDenied checks if a bash command matches deny patterns.
|
|
// Returns true and the matched pattern if denied.
|
|
func IsDenied(command string) (bool, string) {
|
|
commandLower := strings.ToLower(command)
|
|
|
|
// Check deny patterns
|
|
for _, pattern := range denyPatterns {
|
|
if strings.Contains(commandLower, strings.ToLower(pattern)) {
|
|
return true, pattern
|
|
}
|
|
}
|
|
|
|
// Check deny path patterns
|
|
for _, pattern := range denyPathPatterns {
|
|
if strings.Contains(commandLower, strings.ToLower(pattern)) {
|
|
return true, pattern
|
|
}
|
|
}
|
|
|
|
return false, ""
|
|
}
|
|
|
|
// FormatDeniedResult builds the tool-result text reported when a command was
// blocked by a deny pattern. Only pattern appears in the message; command is
// accepted but not included.
func FormatDeniedResult(command string, pattern string) string {
	const (
		head = "Command blocked: this command matches a dangerous pattern ("
		tail = ") and cannot be executed. If this command is necessary, please ask the user to run it manually."
	)
	return head + pattern + tail
}
|
|
|
|
// extractBashPrefix extracts a prefix pattern from a bash command.
|
|
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
|
|
// For commands without path args, returns empty string.
|
|
// Paths with ".." traversal that escape the base directory return empty string for security.
|
|
func extractBashPrefix(command string) string {
|
|
// Split command by pipes and get the first part
|
|
parts := strings.Split(command, "|")
|
|
firstCmd := strings.TrimSpace(parts[0])
|
|
|
|
// Split into command and args
|
|
fields := strings.Fields(firstCmd)
|
|
if len(fields) < 2 {
|
|
return ""
|
|
}
|
|
|
|
baseCmd := fields[0]
|
|
// Common commands that benefit from prefix allowlisting
|
|
// These are typically safe for read operations on specific directories
|
|
safeCommands := map[string]bool{
|
|
"cat": true, "ls": true, "head": true, "tail": true,
|
|
"less": true, "more": true, "file": true, "wc": true,
|
|
"grep": true, "find": true, "tree": true, "stat": true,
|
|
"sed": true,
|
|
}
|
|
|
|
if !safeCommands[baseCmd] {
|
|
return ""
|
|
}
|
|
|
|
// Find the first path-like argument (must contain / or \ or start with .)
|
|
// First pass: look for clear paths (containing path separators or starting with .)
|
|
for _, arg := range fields[1:] {
|
|
// Skip flags
|
|
if strings.HasPrefix(arg, "-") {
|
|
continue
|
|
}
|
|
// Skip numeric arguments (e.g., "head -n 100")
|
|
if isNumeric(arg) {
|
|
continue
|
|
}
|
|
// Only process if it looks like a path (contains / or \ or starts with .)
|
|
if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") {
|
|
continue
|
|
}
|
|
// Normalize to forward slashes for consistent cross-platform matching
|
|
arg = strings.ReplaceAll(arg, "\\", "/")
|
|
|
|
// Security: reject absolute paths
|
|
if path.IsAbs(arg) {
|
|
return "" // Absolute path - don't create prefix
|
|
}
|
|
|
|
// Normalize the path using stdlib path.Clean (resolves . and ..)
|
|
cleaned := path.Clean(arg)
|
|
|
|
// Security: reject if cleaned path escapes to parent directory
|
|
if strings.HasPrefix(cleaned, "..") {
|
|
return "" // Path escapes - don't create prefix
|
|
}
|
|
|
|
// Security: if original had "..", verify cleaned path didn't escape to sibling
|
|
// e.g., "tools/a/b/../../../etc" -> "etc" (escaped tools/ to sibling)
|
|
if strings.Contains(arg, "..") {
|
|
origBase := strings.SplitN(arg, "/", 2)[0]
|
|
cleanedBase := strings.SplitN(cleaned, "/", 2)[0]
|
|
if origBase != cleanedBase {
|
|
return "" // Path escaped to sibling directory
|
|
}
|
|
}
|
|
|
|
// Check if arg ends with / (explicit directory)
|
|
isDir := strings.HasSuffix(arg, "/")
|
|
|
|
// Get the directory part
|
|
var dir string
|
|
if isDir {
|
|
dir = cleaned
|
|
} else {
|
|
dir = path.Dir(cleaned)
|
|
}
|
|
|
|
if dir == "." {
|
|
return fmt.Sprintf("%s:./", baseCmd)
|
|
}
|
|
return fmt.Sprintf("%s:%s/", baseCmd, dir)
|
|
}
|
|
|
|
// Second pass: if no clear path found, use the first non-flag argument as a filename
|
|
for _, arg := range fields[1:] {
|
|
if strings.HasPrefix(arg, "-") {
|
|
continue
|
|
}
|
|
if isNumeric(arg) {
|
|
continue
|
|
}
|
|
// Treat as filename in current dir
|
|
return fmt.Sprintf("%s:./", baseCmd)
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// isNumeric reports whether s is non-empty and made up solely of the ASCII
// digits 0-9. Signs, decimal points and other digit scripts do not count.
func isNumeric(s string) bool {
	if s == "" {
		return false
	}
	notDigit := func(r rune) bool { return r < '0' || r > '9' }
	return strings.IndexFunc(s, notDigit) == -1
}
|
|
|
|
// isCommandOutsideCwd reports whether any argument of a bash command refers to
// a path outside the current working directory.
//
// The command is split on "|", ";" and "&"; within each part the first word
// (the program) is skipped and every remaining argument is examined. Flags
// (arguments starting with "-") are ignored. If the working directory cannot
// be determined the function returns false.
func isCommandOutsideCwd(command string) bool {
	cwd, err := os.Getwd()
	if err != nil {
		return false // Can't determine, assume safe
	}

	segments := strings.FieldsFunc(command, func(r rune) bool {
		return r == '|' || r == ';' || r == '&'
	})

	for _, segment := range segments {
		fields := strings.Fields(segment)
		if len(fields) == 0 {
			continue
		}
		for _, arg := range fields[1:] {
			if argLeavesDir(cwd, arg) {
				return true
			}
		}
	}

	return false
}

// argLeavesDir reports whether a single command-line argument, interpreted as
// a path relative to cwd, points outside cwd.
func argLeavesDir(cwd, arg string) bool {
	switch {
	case strings.HasPrefix(arg, "-"):
		// Flags are not treated as paths.
		return false

	case strings.HasPrefix(arg, "/"), strings.HasPrefix(arg, "\\"):
		// POSIX-style absolute paths count as outside cwd on all platforms.
		return true

	case filepath.IsAbs(arg):
		// Platform-specific absolute path (e.g. a drive-letter path).
		return !pathWithinDir(cwd, filepath.Clean(arg))

	case strings.HasPrefix(arg, "~"):
		home, err := os.UserHomeDir()
		if err != nil {
			return false // cannot expand "~"; do not warn
		}
		if arg == "~" || strings.HasPrefix(arg, "~/") {
			// Expand to the real location and test that, so "~/x" is judged
			// by where it actually points.
			return !pathWithinDir(cwd, filepath.Join(home, arg[1:]))
		}
		// "~user" style: cannot be resolved here, so treat as outside.
		return true

	case strings.Contains(arg, ".."):
		// Covers both "../x" and "a/../../x". Join cleans the result.
		return !pathWithinDir(cwd, filepath.Join(cwd, arg))
	}
	return false
}

// pathWithinDir reports whether target is base itself or lies below it. It
// compares whole path components, so "/work/proj-old" is not inside
// "/work/proj" even though the strings share a prefix.
func pathWithinDir(base, target string) bool {
	rel, err := filepath.Rel(base, target)
	if err != nil {
		return false
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
|
|
|
|
// AllowlistKey returns the exact-match allowlist key for a tool call: for the
// bash tool with a string "command" argument it is "bash:<command>", for
// everything else it is the tool name itself.
func AllowlistKey(toolName string, args map[string]any) string {
	if toolName != "bash" {
		return toolName
	}
	cmd, isString := args["command"].(string)
	if !isString {
		return toolName
	}
	return "bash:" + cmd
}
|
|
|
|
// IsAllowed checks if a tool/command is allowed (exact match or prefix match).
|
|
// For bash commands, hierarchical path matching is used - if "cat:tools/" is allowed,
|
|
// then "cat:tools/subdir/" is also allowed (subdirectories inherit parent permissions).
|
|
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
|
|
a.mu.RLock()
|
|
defer a.mu.RUnlock()
|
|
|
|
// Check exact match first
|
|
key := AllowlistKey(toolName, args)
|
|
if a.allowlist[key] {
|
|
return true
|
|
}
|
|
|
|
// For bash commands, check prefix matches with hierarchical path support
|
|
if toolName == "bash" {
|
|
if cmd, ok := args["command"].(string); ok {
|
|
prefix := extractBashPrefix(cmd)
|
|
if prefix != "" {
|
|
// Check exact prefix match first
|
|
if a.prefixes[prefix] {
|
|
return true
|
|
}
|
|
// Check hierarchical match: if any stored prefix is a parent of current prefix
|
|
// e.g., stored "cat:tools/" should match current "cat:tools/subdir/"
|
|
if a.matchesHierarchicalPrefix(prefix) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check if tool itself is allowed (non-bash)
|
|
if toolName != "bash" && a.allowlist[toolName] {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// matchesHierarchicalPrefix checks if the given prefix matches any stored prefix hierarchically.
|
|
// For example, if "cat:tools/" is stored, it will match "cat:tools/subdir/" or "cat:tools/a/b/c/".
|
|
func (a *ApprovalManager) matchesHierarchicalPrefix(currentPrefix string) bool {
|
|
// Split prefix into command and path parts (format: "cmd:path/")
|
|
colonIdx := strings.Index(currentPrefix, ":")
|
|
if colonIdx == -1 {
|
|
return false
|
|
}
|
|
currentCmd := currentPrefix[:colonIdx]
|
|
currentPath := currentPrefix[colonIdx+1:]
|
|
|
|
for storedPrefix := range a.prefixes {
|
|
storedColonIdx := strings.Index(storedPrefix, ":")
|
|
if storedColonIdx == -1 {
|
|
continue
|
|
}
|
|
storedCmd := storedPrefix[:storedColonIdx]
|
|
storedPath := storedPrefix[storedColonIdx+1:]
|
|
|
|
// Commands must match exactly
|
|
if currentCmd != storedCmd {
|
|
continue
|
|
}
|
|
|
|
// Check if current path starts with stored path (hierarchical match)
|
|
// e.g., "tools/subdir/" starts with "tools/"
|
|
if strings.HasPrefix(currentPath, storedPath) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// AddToAllowlist adds a tool/command to the session allowlist.
|
|
// For bash commands, it adds the prefix pattern instead of exact command.
|
|
func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
|
|
if toolName == "bash" {
|
|
if cmd, ok := args["command"].(string); ok {
|
|
prefix := extractBashPrefix(cmd)
|
|
if prefix != "" {
|
|
a.prefixes[prefix] = true
|
|
return
|
|
}
|
|
// Fall back to exact match if no prefix extracted
|
|
a.allowlist[fmt.Sprintf("bash:%s", cmd)] = true
|
|
return
|
|
}
|
|
}
|
|
a.allowlist[toolName] = true
|
|
}
|
|
|
|
// RequestApproval prompts the user for approval to execute a tool.
|
|
// Returns the decision and optional deny reason.
|
|
func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any) (ApprovalResult, error) {
|
|
// Format tool info for display
|
|
toolDisplay := formatToolDisplay(toolName, args)
|
|
|
|
// Enter raw mode for interactive selection
|
|
fd := int(os.Stdin.Fd())
|
|
oldState, err := term.MakeRaw(fd)
|
|
if err != nil {
|
|
// Fallback to simple input if terminal control fails
|
|
return a.fallbackApproval(toolDisplay)
|
|
}
|
|
|
|
// Flush any pending stdin input before starting selector
|
|
// This prevents buffered input from causing double-press issues
|
|
flushStdin(fd)
|
|
|
|
isWarning := false
|
|
var warningMsg string
|
|
var allowlistInfo string
|
|
if toolName == "bash" {
|
|
if cmd, ok := args["command"].(string); ok {
|
|
if isCommandOutsideCwd(cmd) {
|
|
isWarning = true
|
|
warningMsg = "command targets paths outside project"
|
|
}
|
|
if prefix := extractBashPrefix(cmd); prefix != "" {
|
|
colonIdx := strings.Index(prefix, ":")
|
|
if colonIdx != -1 {
|
|
cmdName := prefix[:colonIdx]
|
|
dirPath := prefix[colonIdx+1:]
|
|
if dirPath != "./" {
|
|
allowlistInfo = fmt.Sprintf("%s in %s directory (includes subdirs)", cmdName, dirPath)
|
|
} else {
|
|
allowlistInfo = fmt.Sprintf("%s in %s directory", cmdName, dirPath)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Run interactive selector
|
|
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning, warningMsg, allowlistInfo)
|
|
if err != nil {
|
|
term.Restore(fd, oldState)
|
|
return ApprovalResult{Decision: ApprovalDeny}, err
|
|
}
|
|
|
|
// Restore terminal
|
|
term.Restore(fd, oldState)
|
|
|
|
// Map selection to decision
|
|
switch selected {
|
|
case -1: // Ctrl+C cancelled
|
|
return ApprovalResult{Decision: ApprovalDeny, DenyReason: "cancelled"}, nil
|
|
case 0:
|
|
return ApprovalResult{Decision: ApprovalOnce}, nil
|
|
case 1:
|
|
return ApprovalResult{Decision: ApprovalAlways}, nil
|
|
default:
|
|
return ApprovalResult{Decision: ApprovalDeny, DenyReason: denyReason}, nil
|
|
}
|
|
}
|
|
|
|
// formatToolDisplay creates the display string for a tool call.
|
|
func formatToolDisplay(toolName string, args map[string]any) string {
|
|
var sb strings.Builder
|
|
displayName := ToolDisplayName(toolName)
|
|
|
|
// For bash, show command directly
|
|
if toolName == "bash" {
|
|
if cmd, ok := args["command"].(string); ok {
|
|
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
|
|
sb.WriteString(fmt.Sprintf("Command: %s", cmd))
|
|
return sb.String()
|
|
}
|
|
}
|
|
|
|
// For web search, show query and internet notice
|
|
if toolName == "web_search" {
|
|
if query, ok := args["query"].(string); ok {
|
|
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
|
|
sb.WriteString(fmt.Sprintf("Query: %s\n", query))
|
|
sb.WriteString("Uses internet via ollama.com")
|
|
return sb.String()
|
|
}
|
|
}
|
|
|
|
// Generic display
|
|
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
|
|
if len(args) > 0 {
|
|
sb.WriteString("\nArguments: ")
|
|
first := true
|
|
for k, v := range args {
|
|
if !first {
|
|
sb.WriteString(", ")
|
|
}
|
|
sb.WriteString(fmt.Sprintf("%s=%v", k, v))
|
|
first = false
|
|
}
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
// selectorState is the mutable state shared by the interactive approval
// selector and its rendering helpers.
type selectorState struct {
	// What is being asked.
	toolDisplay    string // description of the tool call; may span several lines
	isWarning      bool   // whether a warning line is shown above the description
	warningMessage string // text of that warning
	allowlistInfo  string // what "Allow for this session" would allowlist

	// User input so far.
	selected   int    // index into optionLabels of the highlighted option
	denyReason string // reason text shown on the Deny line

	// Layout.
	totalLines int // number of lines the selector occupies, used when clearing it
	termWidth  int
	termHeight int
	boxWidth   int
	innerWidth int
}
|
|
|
|
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
|
|
// If isWarning is true, the box is rendered in red to indicate the command targets paths outside cwd.
|
|
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool, warningMessage string, allowlistInfo string) (int, string, error) {
|
|
state := &selectorState{
|
|
toolDisplay: toolDisplay,
|
|
selected: 0,
|
|
isWarning: isWarning,
|
|
warningMessage: warningMessage,
|
|
allowlistInfo: allowlistInfo,
|
|
}
|
|
|
|
// Get terminal size
|
|
state.termWidth, state.termHeight, _ = term.GetSize(fd)
|
|
if state.termWidth < 20 {
|
|
state.termWidth = 80 // fallback
|
|
}
|
|
|
|
// Calculate box width: 90% of terminal, min 24, max 60
|
|
state.boxWidth = (state.termWidth * 90) / 100
|
|
if state.boxWidth > 60 {
|
|
state.boxWidth = 60
|
|
}
|
|
if state.boxWidth < 24 {
|
|
state.boxWidth = 24
|
|
}
|
|
// Ensure box fits in terminal
|
|
if state.boxWidth > state.termWidth-1 {
|
|
state.boxWidth = state.termWidth - 1
|
|
}
|
|
state.innerWidth = state.boxWidth - 4 // account for "│ " and " │"
|
|
|
|
// Calculate total lines (will be updated by render)
|
|
state.totalLines = calculateTotalLines(state)
|
|
|
|
// Hide cursor during selection (show when in deny mode)
|
|
fmt.Fprint(os.Stderr, "\033[?25l")
|
|
defer fmt.Fprint(os.Stderr, "\033[?25h") // Show cursor when done
|
|
|
|
// Initial render
|
|
renderSelectorBox(state)
|
|
|
|
numOptions := len(optionLabels)
|
|
|
|
for {
|
|
// Read input
|
|
buf := make([]byte, 8)
|
|
n, err := os.Stdin.Read(buf)
|
|
if err != nil {
|
|
clearSelectorBox(state)
|
|
return 2, "", err
|
|
}
|
|
|
|
// Process input byte by byte
|
|
for i := 0; i < n; i++ {
|
|
ch := buf[i]
|
|
|
|
// Check for escape sequences (arrow keys)
|
|
if ch == 27 && i+2 < n && buf[i+1] == '[' {
|
|
oldSelected := state.selected
|
|
switch buf[i+2] {
|
|
case 'A': // Up arrow
|
|
if state.selected > 0 {
|
|
state.selected--
|
|
}
|
|
case 'B': // Down arrow
|
|
if state.selected < numOptions-1 {
|
|
state.selected++
|
|
}
|
|
}
|
|
if oldSelected != state.selected {
|
|
updateSelectorOptions(state)
|
|
}
|
|
i += 2 // Skip the rest of escape sequence
|
|
continue
|
|
}
|
|
|
|
switch {
|
|
// Ctrl+C - cancel
|
|
case ch == 3:
|
|
clearSelectorBox(state)
|
|
return -1, "", nil // -1 indicates cancelled
|
|
|
|
// Enter key - confirm selection
|
|
case ch == 13:
|
|
clearSelectorBox(state)
|
|
if state.selected == 2 { // Deny
|
|
return 2, state.denyReason, nil
|
|
}
|
|
return state.selected, "", nil
|
|
|
|
// Number keys 1-3 for quick select
|
|
case ch >= '1' && ch <= '3':
|
|
selected := int(ch - '1')
|
|
clearSelectorBox(state)
|
|
if selected == 2 { // Deny
|
|
return 2, state.denyReason, nil
|
|
}
|
|
return selected, "", nil
|
|
|
|
// Backspace - delete from reason (UTF-8 safe)
|
|
case ch == 127 || ch == 8:
|
|
if len(state.denyReason) > 0 {
|
|
runes := []rune(state.denyReason)
|
|
state.denyReason = string(runes[:len(runes)-1])
|
|
updateReasonInput(state)
|
|
}
|
|
|
|
// Escape - clear reason
|
|
case ch == 27:
|
|
if len(state.denyReason) > 0 {
|
|
state.denyReason = ""
|
|
updateReasonInput(state)
|
|
}
|
|
|
|
// Printable ASCII (except 1-3 handled above) - type into reason
|
|
case ch >= 32 && ch < 127:
|
|
maxLen := state.innerWidth - 2
|
|
if maxLen < 10 {
|
|
maxLen = 10
|
|
}
|
|
if len(state.denyReason) < maxLen {
|
|
state.denyReason += string(ch)
|
|
// Auto-select Deny option when user starts typing
|
|
if state.selected != 2 {
|
|
state.selected = 2
|
|
updateSelectorOptions(state)
|
|
} else {
|
|
updateReasonInput(state)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// wrapText splits text into lines of at most maxWidth characters.
//
// Existing newlines are kept as line breaks. A longer line is broken at the
// last space found in the second half of the allowed width, or hard-broken at
// maxWidth when there is none; spaces at the start of the continuation are
// dropped. A maxWidth below 5 is treated as 5.
//
// Width is counted in Unicode code points rather than bytes, so multi-byte
// UTF-8 characters are never split. Characters that occupy two terminal
// columns still count as one.
func wrapText(text string, maxWidth int) []string {
	if maxWidth < 5 {
		maxWidth = 5
	}

	var lines []string
	for _, line := range strings.Split(text, "\n") {
		runes := []rune(line)
		if len(runes) <= maxWidth {
			lines = append(lines, line)
			continue
		}

		for len(runes) > maxWidth {
			// Prefer breaking at a space, scanning back from the limit.
			breakAt := maxWidth
			for i := maxWidth; i > maxWidth/2; i-- {
				if runes[i] == ' ' {
					breakAt = i
					break
				}
			}
			lines = append(lines, string(runes[:breakAt]))

			runes = runes[breakAt:]
			for len(runes) > 0 && runes[0] == ' ' {
				runes = runes[1:]
			}
		}
		if len(runes) > 0 {
			lines = append(lines, string(runes))
		}
	}
	return lines
}
|
|
|
|
// getHintLines returns the hint text wrapped to terminal width
|
|
func getHintLines(state *selectorState) []string {
|
|
hint := "up/down select, enter confirm, 1-3 quick select, ctrl+c cancel"
|
|
if state.termWidth >= len(hint)+1 {
|
|
return []string{hint}
|
|
}
|
|
// Wrap hint to multiple lines
|
|
return wrapText(hint, state.termWidth-1)
|
|
}
|
|
|
|
// calculateTotalLines calculates how many lines the selector will use
|
|
func calculateTotalLines(state *selectorState) int {
|
|
toolLines := strings.Split(state.toolDisplay, "\n")
|
|
hintLines := getHintLines(state)
|
|
// warning line (if applicable) + tool lines + blank line + options + blank line + hint lines
|
|
warningLines := 0
|
|
if state.isWarning {
|
|
warningLines = 2 // warning line + blank line after
|
|
}
|
|
return warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
|
|
}
|
|
|
|
// renderSelectorBox renders the selector (minimal, no box)
|
|
func renderSelectorBox(state *selectorState) {
|
|
toolLines := strings.Split(state.toolDisplay, "\n")
|
|
hintLines := getHintLines(state)
|
|
|
|
// Draw warning line if needed
|
|
if state.isWarning {
|
|
if state.warningMessage != "" {
|
|
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m %s\033[K\r\n", state.warningMessage)
|
|
} else {
|
|
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m command targets paths outside project\033[K\r\n")
|
|
}
|
|
fmt.Fprintf(os.Stderr, "\033[K\r\n") // blank line after warning
|
|
}
|
|
|
|
// Draw tool info (plain white)
|
|
for _, line := range toolLines {
|
|
fmt.Fprintf(os.Stderr, "%s\033[K\r\n", line)
|
|
}
|
|
|
|
// Blank line separator
|
|
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
|
|
|
for i, label := range optionLabels {
|
|
if i == 2 {
|
|
denyLabel := "3. Deny: "
|
|
inputDisplay := state.denyReason
|
|
if inputDisplay == "" {
|
|
inputDisplay = "\033[90m(optional reason)\033[0m"
|
|
}
|
|
if i == state.selected {
|
|
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
|
} else {
|
|
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
|
}
|
|
} else {
|
|
displayLabel := label
|
|
if i == 1 && state.allowlistInfo != "" {
|
|
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
|
|
}
|
|
if i == state.selected {
|
|
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
|
|
} else {
|
|
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Blank line before hint
|
|
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
|
|
|
// Draw hint (dark grey)
|
|
for i, line := range hintLines {
|
|
if i == len(hintLines)-1 {
|
|
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
|
} else {
|
|
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
|
}
|
|
}
|
|
}
|
|
|
|
// updateSelectorOptions updates just the options portion of the selector
|
|
func updateSelectorOptions(state *selectorState) {
|
|
hintLines := getHintLines(state)
|
|
|
|
// Move up to the first option line
|
|
// Cursor is at end of last hint line, need to go up:
|
|
// (hint lines - 1) + 1 (blank line) + numOptions
|
|
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
|
|
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
|
|
|
for i, label := range optionLabels {
|
|
if i == 2 {
|
|
denyLabel := "3. Deny: "
|
|
inputDisplay := state.denyReason
|
|
if inputDisplay == "" {
|
|
inputDisplay = "\033[90m(optional reason)\033[0m"
|
|
}
|
|
if i == state.selected {
|
|
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
|
} else {
|
|
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
|
}
|
|
} else {
|
|
displayLabel := label
|
|
if i == 1 && state.allowlistInfo != "" {
|
|
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
|
|
}
|
|
if i == state.selected {
|
|
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
|
|
} else {
|
|
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Blank line + hint
|
|
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
|
for i, line := range hintLines {
|
|
if i == len(hintLines)-1 {
|
|
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
|
} else {
|
|
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
|
}
|
|
}
|
|
}
|
|
|
|
// updateReasonInput updates just the Deny option line (which contains the reason input)
|
|
func updateReasonInput(state *selectorState) {
|
|
hintLines := getHintLines(state)
|
|
|
|
// Move up to the Deny line (3rd option, index 2)
|
|
// Cursor is at end of last hint line, need to go up:
|
|
// (hint lines - 1) + 1 (blank line) + 1 (Deny is last option)
|
|
linesToMove := len(hintLines) - 1 + 1 + 1
|
|
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
|
|
|
// Redraw Deny line with reason
|
|
denyLabel := "3. Deny: "
|
|
inputDisplay := state.denyReason
|
|
if inputDisplay == "" {
|
|
inputDisplay = "\033[90m(optional reason)\033[0m"
|
|
}
|
|
if state.selected == 2 {
|
|
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
|
} else {
|
|
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
|
}
|
|
|
|
// Blank line + hint
|
|
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
|
for i, line := range hintLines {
|
|
if i == len(hintLines)-1 {
|
|
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
|
} else {
|
|
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
|
}
|
|
}
|
|
}
|
|
|
|
// clearSelectorBox clears the selector from screen
|
|
func clearSelectorBox(state *selectorState) {
|
|
// Clear the current line (hint line) first
|
|
fmt.Fprint(os.Stderr, "\r\033[K")
|
|
// Move up and clear each remaining line
|
|
for range state.totalLines - 1 {
|
|
fmt.Fprint(os.Stderr, "\033[A\033[K")
|
|
}
|
|
fmt.Fprint(os.Stderr, "\r")
|
|
}
|
|
|
|
// fallbackApproval handles approval when terminal control isn't available.
|
|
func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult, error) {
|
|
fmt.Fprintln(os.Stderr)
|
|
fmt.Fprintln(os.Stderr, toolDisplay)
|
|
fmt.Fprintln(os.Stderr)
|
|
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Allow for this session [3] Deny")
|
|
fmt.Fprint(os.Stderr, "choice: ")
|
|
|
|
var input string
|
|
fmt.Scanln(&input)
|
|
|
|
switch input {
|
|
case "1":
|
|
return ApprovalResult{Decision: ApprovalOnce}, nil
|
|
case "2":
|
|
return ApprovalResult{Decision: ApprovalAlways}, nil
|
|
default:
|
|
fmt.Fprint(os.Stderr, "Reason (optional): ")
|
|
var reason string
|
|
fmt.Scanln(&reason)
|
|
return ApprovalResult{Decision: ApprovalDeny, DenyReason: reason}, nil
|
|
}
|
|
}
|
|
|
|
// Reset clears the session allowlist.
|
|
func (a *ApprovalManager) Reset() {
|
|
a.mu.Lock()
|
|
defer a.mu.Unlock()
|
|
a.allowlist = make(map[string]bool)
|
|
a.prefixes = make(map[string]bool)
|
|
}
|
|
|
|
// AllowedTools returns a list of tools and prefixes in the allowlist.
|
|
func (a *ApprovalManager) AllowedTools() []string {
|
|
a.mu.RLock()
|
|
defer a.mu.RUnlock()
|
|
|
|
tools := make([]string, 0, len(a.allowlist)+len(a.prefixes))
|
|
for tool := range a.allowlist {
|
|
tools = append(tools, tool)
|
|
}
|
|
for prefix := range a.prefixes {
|
|
tools = append(tools, prefix+"*")
|
|
}
|
|
return tools
|
|
}
|
|
|
|
// FormatApprovalResult returns a formatted string showing the approval result.
|
|
func FormatApprovalResult(toolName string, args map[string]any, result ApprovalResult) string {
|
|
var label string
|
|
displayName := ToolDisplayName(toolName)
|
|
|
|
switch result.Decision {
|
|
case ApprovalOnce:
|
|
label = "Approved"
|
|
case ApprovalAlways:
|
|
label = "Always allowed"
|
|
case ApprovalDeny:
|
|
label = "Denied"
|
|
}
|
|
|
|
// Format based on tool type
|
|
if toolName == "bash" {
|
|
if cmd, ok := args["command"].(string); ok {
|
|
// Truncate long commands
|
|
if len(cmd) > 40 {
|
|
cmd = cmd[:37] + "..."
|
|
}
|
|
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, cmd)
|
|
}
|
|
}
|
|
|
|
if toolName == "web_search" {
|
|
if query, ok := args["query"].(string); ok {
|
|
// Truncate long queries
|
|
if len(query) > 40 {
|
|
query = query[:37] + "..."
|
|
}
|
|
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, query)
|
|
}
|
|
}
|
|
|
|
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
|
|
}
|
|
|
|
// FormatDenyResult builds the tool-result text reported when the user denied a
// tool call, appending the user's reason when one was given.
func FormatDenyResult(toolName string, reason string) string {
	msg := "User denied execution of " + toolName + "."
	if reason != "" {
		msg += " Reason: " + reason
	}
	return msg
}
|
|
|
|
// PromptYesNo displays a simple Yes/No prompt and returns the user's choice.
|
|
// Returns true for Yes, false for No.
|
|
func PromptYesNo(question string) (bool, error) {
|
|
fd := int(os.Stdin.Fd())
|
|
oldState, err := term.MakeRaw(fd)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
defer term.Restore(fd, oldState)
|
|
|
|
selected := 0 // 0 = Yes, 1 = No
|
|
options := []string{"Yes", "No"}
|
|
|
|
// Hide cursor
|
|
fmt.Fprint(os.Stderr, "\033[?25l")
|
|
defer fmt.Fprint(os.Stderr, "\033[?25h")
|
|
|
|
renderYesNo := func() {
|
|
// Move to start of line and clear
|
|
fmt.Fprintf(os.Stderr, "\r\033[K")
|
|
fmt.Fprintf(os.Stderr, "%s ", question)
|
|
for i, opt := range options {
|
|
if i == selected {
|
|
fmt.Fprintf(os.Stderr, "\033[1m%s\033[0m ", opt)
|
|
} else {
|
|
fmt.Fprintf(os.Stderr, "\033[37m%s\033[0m ", opt)
|
|
}
|
|
}
|
|
}
|
|
|
|
renderYesNo()
|
|
|
|
buf := make([]byte, 3)
|
|
for {
|
|
n, err := os.Stdin.Read(buf)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if n == 1 {
|
|
switch buf[0] {
|
|
case 'y', 'Y':
|
|
selected = 0
|
|
renderYesNo()
|
|
case 'n', 'N':
|
|
selected = 1
|
|
renderYesNo()
|
|
case '\r', '\n': // Enter
|
|
fmt.Fprintf(os.Stderr, "\r\033[K") // Clear line
|
|
return selected == 0, nil
|
|
case 3: // Ctrl+C
|
|
fmt.Fprintf(os.Stderr, "\r\033[K")
|
|
return false, nil
|
|
case 27: // Escape - could be arrow key
|
|
// Read more bytes for arrow keys
|
|
continue
|
|
}
|
|
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
|
|
// Arrow keys
|
|
switch buf[2] {
|
|
case 'D': // Left
|
|
if selected > 0 {
|
|
selected--
|
|
}
|
|
renderYesNo()
|
|
case 'C': // Right
|
|
if selected < len(options)-1 {
|
|
selected++
|
|
}
|
|
renderYesNo()
|
|
}
|
|
}
|
|
}
|
|
}
|