464 lines
13 KiB
Go
464 lines
13 KiB
Go
package main
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"os/user"
	"path/filepath"
	"strings"
	"time"
)
|
|
|
|
// version is the release string printed by the "version" command and the
// -v/--version flags.
const (
	version = "0.1.0"
)
|
|
|
|
func buildPrompt(config *Config, query string) string {
|
|
return fmt.Sprintf(`You are a Linux shell helper. The user is asking for a shell command (in the %s dialect) to do the following task: "%s". Reply with no more than three options for a command, in Markdown codeblocks. The commands do not have to be one-liners; split lines where necessary to avoid too-long lines. Prefer splitting on pipes, start of new commands, and such places.
|
|
|
|
You can include a short (no more than four sentences; be terse) text before each codeblock. In the text, focus on possible limitations and error modes of the command. No formatting, no introductory sentences.
|
|
|
|
Only include options if they cover additional cases. Prefer native shell features to external commands where practical. Follow best practices. No useless cat. find should usually use the -print0 or -exec flag. Avoid commands that fail for filenames with spaces or special characters. Prefer long form of options to commands.`, config.ShellType, query)
|
|
}
|
|
|
|
// Global command-line flags, registered on flag.CommandLine.
//
// main parses them before reading the subcommand name, so they must be
// written ahead of "init"/"query". The provider, key, shell, model and color
// flags replace the corresponding config-file value only when non-empty.
var (
	// colorSchemeFlag overrides the configured color scheme ("light" or "dark").
	colorSchemeFlag = flag.String("color-scheme", "", "Override color scheme (light/dark)")

	// configFileFlag selects an alternate config file; empty means the default location.
	configFileFlag = flag.String("config-file", "", "Path to config file")

	// apiKeyFlag overrides the configured API key.
	apiKeyFlag = flag.String("api-key", "", "Override API key")

	// modelNameFlag overrides the configured model name.
	modelNameFlag = flag.String("model", "", "Override model name")

	// outputFileFlag, when set, receives the selected command instead of stdout.
	outputFileFlag = flag.String("output", "", "Path to output file (default: stdout)")

	// providerURLFlag overrides the configured LLM endpoint URL.
	providerURLFlag = flag.String("provider-url", "", "Override LLM provider URL")

	// shellTypeFlag overrides the shell dialect named in the prompt.
	shellTypeFlag = flag.String("shell-type", "", "Override shell type")

	// verboseFlag turns on [DEBUG] logging.
	verboseFlag = flag.Bool("verbose", false, "Enable debug logging")
)
|
|
|
|
func main() {
|
|
flag.Usage = func() {
|
|
fmt.Printf("Usage: %s <command> [options]\n", os.Args[0])
|
|
fmt.Println("Commands:")
|
|
fmt.Println(" query Get shell commands for a task")
|
|
fmt.Println(" init Generate shell initialization script")
|
|
fmt.Println("\nGlobal options:")
|
|
flag.PrintDefaults()
|
|
os.Exit(0)
|
|
}
|
|
|
|
if len(os.Args) < 2 {
|
|
flag.Usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
// First check for version/help commands before flag parsing
|
|
if len(os.Args) >= 2 {
|
|
switch os.Args[1] {
|
|
case "-v", "--version", "version":
|
|
fmt.Printf("llm-shell-hint version %s\n", version)
|
|
os.Exit(0)
|
|
case "-h", "--help", "help":
|
|
flag.Usage()
|
|
os.Exit(0)
|
|
}
|
|
}
|
|
|
|
// Parse flags, but stop at first non-flag argument
|
|
flag.CommandLine.Parse(os.Args[1:])
|
|
|
|
// Get the command (first non-flag argument)
|
|
args := flag.Args()
|
|
if len(args) < 1 {
|
|
flag.Usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
switch args[0] {
|
|
case "init":
|
|
if len(args) < 2 {
|
|
fmt.Println("Please specify shell name (e.g. fish)")
|
|
os.Exit(1)
|
|
}
|
|
|
|
shell := args[1]
|
|
switch shell {
|
|
case "fish":
|
|
fmt.Printf(`function _llm_shell_hint
|
|
set -l LLM_SUGGESTION_FILE (mktemp)
|
|
llm-shell-hint %s --output $LLM_SUGGESTION_FILE (commandline -b)
|
|
set -l LLM_SUGGESTION (cat $LLM_SUGGESTION_FILE)
|
|
rm $LLM_SUGGESTION_FILE
|
|
|
|
if test -n "$LLM_SUGGESTION"
|
|
commandline -r $LLM_SUGGESTION
|
|
end
|
|
|
|
commandline -f repaint
|
|
end
|
|
|
|
bind ctrl-q _llm_shell_hint
|
|
`, buildConfigFlags())
|
|
|
|
case "bash":
|
|
fmt.Printf(`_llm_shell_hint() {
|
|
local suggestion_file=$(mktemp)
|
|
llm-shell-hint %s --output "$suggestion_file" -- "$READLINE_LINE"
|
|
local suggestion=$(<"$suggestion_file")
|
|
rm "$suggestion_file"
|
|
|
|
if [[ -n "$suggestion" ]]; then
|
|
READLINE_LINE="$suggestion"
|
|
READLINE_POINT=${#suggestion}
|
|
fi
|
|
}
|
|
|
|
bind -x '"\C-q":_llm_shell_hint'
|
|
`, buildConfigFlags())
|
|
|
|
case "zsh":
|
|
fmt.Printf(`_llm_shell_hint() {
|
|
local suggestion_file=$(mktemp)
|
|
llm-shell-hint %s --output "$suggestion_file" -- "$BUFFER"
|
|
local suggestion=$(<"$suggestion_file")
|
|
rm "$suggestion_file"
|
|
|
|
if [[ -n "$suggestion" ]]; then
|
|
BUFFER="$suggestion"
|
|
CURSOR=${#BUFFER}
|
|
fi
|
|
}
|
|
|
|
zle -N _llm_shell_hint
|
|
bindkey '^q' _llm_shell_hint
|
|
`, buildConfigFlags())
|
|
|
|
default:
|
|
fmt.Printf("Unsupported shell: %s\nSupported shells: fish, bash, zsh\n", shell)
|
|
os.Exit(1)
|
|
}
|
|
os.Exit(0)
|
|
|
|
case "query":
|
|
// Get remaining args after the command
|
|
if len(args) < 2 {
|
|
fmt.Println("Please provide a query")
|
|
os.Exit(1)
|
|
}
|
|
|
|
query := strings.Join(args[1:], " ")
|
|
|
|
// Determine config file path
|
|
configPath := *configFileFlag
|
|
if configPath == "" {
|
|
// Get user's home directory
|
|
usr, err := user.Current()
|
|
if err != nil {
|
|
log.Fatalf("Failed to get user home directory: %v", err)
|
|
}
|
|
configPath = filepath.Join(usr.HomeDir, ".config", "llm-shell-hint", "config.toml")
|
|
}
|
|
|
|
// Load or create config
|
|
config, err := loadOrCreateConfig(configPath)
|
|
if err != nil {
|
|
log.Fatalf("Config error: %v", err)
|
|
}
|
|
|
|
// Override config values with command line flags
|
|
if *providerURLFlag != "" {
|
|
config.LLMProviderURL = *providerURLFlag
|
|
}
|
|
if *apiKeyFlag != "" {
|
|
config.APIKey = *apiKeyFlag
|
|
}
|
|
if *shellTypeFlag != "" {
|
|
config.ShellType = *shellTypeFlag
|
|
}
|
|
if *modelNameFlag != "" {
|
|
config.ModelName = *modelNameFlag
|
|
}
|
|
if *colorSchemeFlag != "" {
|
|
if *colorSchemeFlag == "light" || *colorSchemeFlag == "dark" {
|
|
config.ColorScheme = *colorSchemeFlag
|
|
} else {
|
|
log.Printf("Invalid color scheme '%s', using '%s'", *colorSchemeFlag, config.ColorScheme)
|
|
}
|
|
}
|
|
|
|
// Debug logging
|
|
if *verboseFlag {
|
|
log.Printf("[DEBUG] Config: ProviderURL=%s, ShellType=%s, ModelName=%s, APIKey=****, ColorScheme=%s",
|
|
config.LLMProviderURL, config.ShellType, config.ModelName, config.ColorScheme)
|
|
}
|
|
|
|
// Run the TUI and get the selected command
|
|
selectedCommand, err := runTUI(config, query, *verboseFlag)
|
|
if err != nil {
|
|
log.Fatalf("Error running TUI: %v", err)
|
|
}
|
|
if selectedCommand != "" {
|
|
if *outputFileFlag != "" {
|
|
err := os.WriteFile(*outputFileFlag, []byte(selectedCommand+"\n"), 0644)
|
|
if err != nil {
|
|
log.Fatalf("Error writing to output file: %v", err)
|
|
}
|
|
} else {
|
|
fmt.Println(selectedCommand)
|
|
}
|
|
}
|
|
|
|
case "-v", "--version", "version":
|
|
fmt.Printf("llm-shell-hint version %s\n", version)
|
|
os.Exit(0)
|
|
|
|
case "-h", "--help", "help":
|
|
fmt.Printf("Usage: %s <command> [options]\n", os.Args[0])
|
|
fmt.Println("Commands:")
|
|
fmt.Println(" query Get shell commands for a task")
|
|
fmt.Println("\nUse -help with any command to see its options")
|
|
os.Exit(0)
|
|
|
|
default:
|
|
fmt.Printf("Unknown command: %s\n", os.Args[1])
|
|
fmt.Println("Available commands: query")
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func queryLLM(config *Config, prompt string, verbose bool) ([]CommandWithComment, error) {
|
|
// Create HTTP request - append /chat/completions to the URL for OpenAI-compatible API
|
|
apiURL := config.LLMProviderURL
|
|
|
|
// Debug logging
|
|
if verbose {
|
|
log.Printf("[DEBUG] Sending request to: %s", apiURL)
|
|
log.Printf("[DEBUG] Using model: %s", config.ModelName)
|
|
log.Printf("[DEBUG] API key length: %d", len(config.APIKey))
|
|
}
|
|
|
|
// Prepare the request payload for OpenAI-compatible API
|
|
payload := map[string]interface{}{
|
|
"model": config.ModelName,
|
|
"messages": []map[string]interface{}{
|
|
{
|
|
"role": "user",
|
|
"content": prompt,
|
|
},
|
|
},
|
|
"max_tokens": 500,
|
|
}
|
|
|
|
jsonData, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal JSON: %v", err)
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %v", err)
|
|
}
|
|
|
|
// Set headers
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+config.APIKey)
|
|
// Add OpenRouter specific headers if using OpenRouter
|
|
if strings.Contains(config.LLMProviderURL, "openrouter.ai") {
|
|
req.Header.Set("HTTP-Referer", "https://github.com/your-username/llm-shell-hint")
|
|
req.Header.Set("X-Title", "LLM Shell Hint")
|
|
}
|
|
|
|
// Send request
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Read response
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read response: %v", err)
|
|
}
|
|
|
|
// Debug logging
|
|
if verbose {
|
|
log.Printf("[DEBUG] Response status: %d", resp.StatusCode)
|
|
log.Printf("[DEBUG] Response body: %s", string(body))
|
|
}
|
|
|
|
// Check for non-200 status
|
|
if resp.StatusCode != http.StatusOK {
|
|
// Try to parse as JSON error first
|
|
var errorResult map[string]interface{}
|
|
if json.Unmarshal(body, &errorResult) == nil {
|
|
if errorMsg, ok := errorResult["error"].(map[string]interface{}); ok {
|
|
if message, ok := errorMsg["message"].(string); ok {
|
|
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, message)
|
|
}
|
|
}
|
|
}
|
|
// If not JSON, show the raw response (truncated if too long)
|
|
responseStr := string(body)
|
|
if len(responseStr) > 200 {
|
|
responseStr = responseStr[:200] + "..."
|
|
}
|
|
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, responseStr)
|
|
}
|
|
|
|
// Parse response
|
|
var result map[string]interface{}
|
|
if err := json.Unmarshal(body, &result); err != nil {
|
|
// If parsing fails, show what we received (truncated)
|
|
responseStr := string(body)
|
|
if len(responseStr) > 200 {
|
|
responseStr = responseStr[:200] + "..."
|
|
}
|
|
return nil, fmt.Errorf("failed to parse response (received: %s): %v", responseStr, err)
|
|
}
|
|
|
|
// Extract the content from the response
|
|
choices, ok := result["choices"].([]interface{})
|
|
if !ok || len(choices) == 0 {
|
|
return nil, fmt.Errorf("invalid response format: no choices found")
|
|
}
|
|
|
|
firstChoice, ok := choices[0].(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid response format: choice is not an object")
|
|
}
|
|
|
|
message, ok := firstChoice["message"].(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid response format: no message found")
|
|
}
|
|
|
|
content, ok := message["content"].(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid response format: content is not a string")
|
|
}
|
|
|
|
// Parse the content to extract commands and comments from Markdown codeblocks
|
|
commands, err := parseCommandsFromMarkdown(content)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse commands from markdown: %v", err)
|
|
}
|
|
|
|
return commands, nil
|
|
}
|
|
|
|
// CommandWithComment is one suggestion extracted from the LLM reply: the
// contents of a Markdown codeblock plus the prose that preceded it.
type CommandWithComment struct {
	Command string // codeblock text, with blank leading/trailing lines removed
	Comment string // explanatory text found before the codeblock
}
|
|
|
|
func buildConfigFlags() string {
|
|
var flags []string
|
|
if *configFileFlag != "" {
|
|
flags = append(flags, fmt.Sprintf("--config-file '%s'", *configFileFlag))
|
|
}
|
|
if *verboseFlag {
|
|
flags = append(flags, "--verbose")
|
|
}
|
|
if *providerURLFlag != "" {
|
|
flags = append(flags, fmt.Sprintf("--provider-url '%s'", *providerURLFlag))
|
|
}
|
|
if *apiKeyFlag != "" {
|
|
flags = append(flags, fmt.Sprintf("--api-key '%s'", *apiKeyFlag))
|
|
}
|
|
if *shellTypeFlag != "" {
|
|
flags = append(flags, fmt.Sprintf("--shell-type '%s'", *shellTypeFlag))
|
|
}
|
|
if *modelNameFlag != "" {
|
|
flags = append(flags, fmt.Sprintf("--model '%s'", *modelNameFlag))
|
|
}
|
|
if *colorSchemeFlag != "" {
|
|
flags = append(flags, fmt.Sprintf("--color-scheme '%s'", *colorSchemeFlag))
|
|
}
|
|
return strings.Join(flags, " ")
|
|
}
|
|
|
|
func parseCommandsFromMarkdown(content string) ([]CommandWithComment, error) {
|
|
var result []CommandWithComment
|
|
|
|
// Split content by codeblock delimiters
|
|
lines := strings.Split(content, "\n")
|
|
|
|
var currentCommand strings.Builder
|
|
var currentComment strings.Builder
|
|
inCodeblock := false
|
|
|
|
for _, line := range lines {
|
|
trimmedLine := strings.TrimSpace(line)
|
|
|
|
// Check if we're entering or exiting a codeblock
|
|
if strings.HasPrefix(trimmedLine, "```") {
|
|
if inCodeblock {
|
|
// Exiting codeblock - add the command
|
|
if currentCommand.Len() > 0 {
|
|
result = append(result, CommandWithComment{
|
|
Command: currentCommand.String(),
|
|
Comment: strings.TrimSpace(currentComment.String()),
|
|
})
|
|
currentCommand.Reset()
|
|
currentComment.Reset()
|
|
}
|
|
inCodeblock = false
|
|
} else {
|
|
// Entering codeblock
|
|
inCodeblock = true
|
|
}
|
|
continue
|
|
}
|
|
|
|
if inCodeblock {
|
|
currentCommand.WriteString(line)
|
|
currentCommand.WriteString("\n")
|
|
} else if trimmedLine != "" {
|
|
// Collect comment text outside codeblocks
|
|
currentComment.WriteString(line)
|
|
currentComment.WriteString(" ")
|
|
}
|
|
}
|
|
|
|
// Handle any remaining command
|
|
if currentCommand.Len() > 0 {
|
|
result = append(result, CommandWithComment{
|
|
Command: currentCommand.String(),
|
|
Comment: strings.TrimSpace(currentComment.String()),
|
|
})
|
|
}
|
|
|
|
// Clean up all commands and comments
|
|
for i := range result {
|
|
// Trim empty lines from beginning and end of command
|
|
cmdLines := strings.Split(result[i].Command, "\n")
|
|
// Trim from start
|
|
start := 0
|
|
for start < len(cmdLines) && strings.TrimSpace(cmdLines[start]) == "" {
|
|
start++
|
|
}
|
|
// Trim from end
|
|
end := len(cmdLines) - 1
|
|
for end >= start && strings.TrimSpace(cmdLines[end]) == "" {
|
|
end--
|
|
}
|
|
result[i].Command = strings.Join(cmdLines[start:end+1], "\n")
|
|
|
|
// Trim empty lines from beginning and end of comment
|
|
commentLines := strings.Split(result[i].Comment, "\n")
|
|
// Trim from start
|
|
start = 0
|
|
for start < len(commentLines) && strings.TrimSpace(commentLines[start]) == "" {
|
|
start++
|
|
}
|
|
// Trim from end
|
|
end = len(commentLines) - 1
|
|
for end >= start && strings.TrimSpace(commentLines[end]) == "" {
|
|
end--
|
|
}
|
|
result[i].Comment = strings.Join(commentLines[start:end+1], "\n")
|
|
}
|
|
|
|
return result, nil
|
|
}
|