// llm-shell-hint/main.go
// 2025-08-26 21:20:50 +03:00
//
// 464 lines
// 13 KiB
// Go

package main
import (
	"bytes"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"os/user"
	"path/filepath"
	"strings"
	"time"
)
const (
	// version is the release string reported by the "version" command
	// (and by -v / --version).
	version = "0.1.0"
)
// buildPrompt returns the text sent to the model as the single user message.
// It names the shell dialect taken from config.ShellType, quotes the user's
// task verbatim between double quotes, and then lists fixed rules for the
// reply (at most three Markdown code blocks, terse caveats, robust idioms).
func buildPrompt(config *Config, query string) string {
	const (
		beforeDialect = `You are a Linux shell helper. The user is asking for a shell command (in the `
		beforeTask    = ` dialect) to do the following task: "`
		afterTask     = `". Reply with no more than three options for a command, in Markdown codeblocks. The commands do not have to be one-liners; split lines where necessary to avoid too-long lines. Prefer splitting on pipes, start of new commands, and such places.
You can include a short (no more than four sentences; be terse) text before each codeblock. In the text, focus on possible limitations and error modes of the command. No formatting, no introductory sentences.
Only include options if they cover additional cases. Prefer native shell features to external commands where practical. Follow best practices. No useless cat. find should usually use the -print0 or -exec flag. Avoid commands that fail for filenames with spaces or special characters. Prefer long form of options to commands.`
	)
	// Plain concatenation: the dialect and query are inserted as-is, exactly
	// like %s would insert them.
	return beforeDialect + config.ShellType + beforeTask + query + afterTask
}
// Global flags. Flag parsing stops at the first non-flag argument, so these
// must be given before the command name. Every flag except -output is
// re-emitted by buildConfigFlags into the snippets produced by "init".
var (
	// configFileFlag replaces the default ~/.config/llm-shell-hint/config.toml.
	configFileFlag = flag.String("config-file", "", "Path to config file")
	// verboseFlag turns on the [DEBUG] log lines.
	verboseFlag = flag.Bool("verbose", false, "Enable debug logging")
	// outputFileFlag, when set, receives the selected command instead of stdout.
	outputFileFlag = flag.String("output", "", "Path to output file (default: stdout)")

	// The following flags override the matching values loaded from the config file.
	providerURLFlag = flag.String("provider-url", "", "Override LLM provider URL")
	apiKeyFlag      = flag.String("api-key", "", "Override API key")
	shellTypeFlag   = flag.String("shell-type", "", "Override shell type")
	modelNameFlag   = flag.String("model", "", "Override model name")
	// colorSchemeFlag is honoured only for "light" or "dark"; any other
	// value is logged and ignored.
	colorSchemeFlag = flag.String("color-scheme", "", "Override color scheme (light/dark)")
)
// main is the CLI entry point. Global flags must precede the command:
//
//	llm-shell-hint [flags] query <task words...>
//	llm-shell-hint [flags] init <fish|bash|zsh>
//	llm-shell-hint version | help
//
// "query" joins the remaining arguments into the task, loads the config file,
// lets command-line flags override it, runs the TUI and prints the selected
// command (or writes it to -output). "init" prints a shell snippet that binds
// Ctrl-Q to run "query" on the current command line and replace the line with
// the selection; the config flags given to "init" are baked into the snippet.
//
// Exit status: 0 on success or explicit help/version, 1 on usage errors,
// 2 on unparsable flags (flag.ExitOnError).
func main() {
	flag.Usage = func() {
		// Output() is stderr unless showHelp redirected it to stdout.
		w := flag.CommandLine.Output()
		fmt.Fprintf(w, "Usage: %s [options] <command> [args]\n", os.Args[0])
		fmt.Fprintln(w, "Commands:")
		fmt.Fprintln(w, " query Get shell commands for a task")
		fmt.Fprintln(w, " init Generate shell initialization script")
		fmt.Fprintln(w, "\nGlobal options:")
		flag.PrintDefaults()
	}
	// showHelp prints usage to stdout because it was explicitly requested and
	// exits 0. flag.Usage itself does not exit, so callers that show usage
	// because of a mistake can choose a failing status.
	showHelp := func() {
		flag.CommandLine.SetOutput(os.Stdout)
		flag.Usage()
		os.Exit(0)
	}
	showVersion := func() {
		fmt.Printf("llm-shell-hint version %s\n", version)
		os.Exit(0)
	}
	// fail reports a usage problem on stderr (stdout may be captured, e.g. by
	// eval "$(llm-shell-hint init bash)") and exits with status 1.
	fail := func(format string, a ...interface{}) {
		fmt.Fprintf(os.Stderr, format, a...)
		os.Exit(1)
	}

	if len(os.Args) < 2 {
		flag.Usage()
		os.Exit(1)
	}
	// Recognise version/help as the very first argument before flag parsing;
	// "-v" is not a defined flag and would otherwise be rejected.
	switch os.Args[1] {
	case "-v", "--version", "version":
		showVersion()
	case "-h", "--help", "help":
		showHelp()
	}

	// Parsing stops at the first non-flag argument or after "--". With the
	// default ExitOnError policy a bad flag prints usage to stderr and exits 2.
	flag.Parse()
	args := flag.Args()
	if len(args) < 1 {
		flag.Usage()
		os.Exit(1)
	}

	switch command := args[0]; command {
	case "init":
		if len(args) < 2 {
			fail("Please specify shell name (e.g. fish)\n")
		}
		// Each snippet has exactly one %s, for the forwarded config flags. The
		// command-line buffer is passed after "-- query": "--" ends flag parsing
		// so the buffer (even one starting with "-") becomes the query text.
		var script string
		switch shell := args[1]; shell {
		case "fish":
			// string collect keeps a multi-line suggestion as one argument
			// instead of splitting it into one argument per line.
			script = `function _llm_shell_hint
    set -l LLM_SUGGESTION_FILE (mktemp)
    llm-shell-hint %s --output $LLM_SUGGESTION_FILE -- query (commandline -b)
    set -l LLM_SUGGESTION (string collect < $LLM_SUGGESTION_FILE)
    rm $LLM_SUGGESTION_FILE
    if test -n "$LLM_SUGGESTION"
        commandline -r -- $LLM_SUGGESTION
    end
    commandline -f repaint
end
bind ctrl-q _llm_shell_hint
`
		case "bash":
			script = `_llm_shell_hint() {
    local suggestion_file suggestion
    suggestion_file=$(mktemp) || return
    llm-shell-hint %s --output "$suggestion_file" -- query "$READLINE_LINE"
    suggestion=$(<"$suggestion_file")
    rm -f -- "$suggestion_file"
    if [[ -n "$suggestion" ]]; then
        READLINE_LINE="$suggestion"
        READLINE_POINT=${#suggestion}
    fi
}
bind -x '"\C-q":_llm_shell_hint'
`
		case "zsh":
			// reset-prompt redraws the line after the TUI has used the
			// terminal, mirroring "commandline -f repaint" in the fish snippet.
			script = `_llm_shell_hint() {
    local suggestion_file suggestion
    suggestion_file=$(mktemp) || return
    llm-shell-hint %s --output "$suggestion_file" -- query "$BUFFER"
    suggestion=$(<"$suggestion_file")
    rm -f -- "$suggestion_file"
    if [[ -n "$suggestion" ]]; then
        BUFFER="$suggestion"
        CURSOR=${#BUFFER}
    fi
    zle reset-prompt
}
zle -N _llm_shell_hint
bindkey '^q' _llm_shell_hint
`
		default:
			fail("Unsupported shell: %s\nSupported shells: fish, bash, zsh\n", shell)
		}
		fmt.Printf(script, buildConfigFlags())

	case "query":
		query := strings.Join(args[1:], " ")
		if strings.TrimSpace(query) == "" {
			fail("Please provide a query\n")
		}

		// Default config location: ~/.config/llm-shell-hint/config.toml.
		configPath := *configFileFlag
		if configPath == "" {
			usr, err := user.Current()
			if err != nil {
				log.Fatalf("Failed to get user home directory: %v", err)
			}
			configPath = filepath.Join(usr.HomeDir, ".config", "llm-shell-hint", "config.toml")
		}
		config, err := loadOrCreateConfig(configPath)
		if err != nil {
			log.Fatalf("Config error: %v", err)
		}

		// Command-line flags take precedence over the config file.
		if *providerURLFlag != "" {
			config.LLMProviderURL = *providerURLFlag
		}
		if *apiKeyFlag != "" {
			config.APIKey = *apiKeyFlag
		}
		if *shellTypeFlag != "" {
			config.ShellType = *shellTypeFlag
		}
		if *modelNameFlag != "" {
			config.ModelName = *modelNameFlag
		}
		if *colorSchemeFlag != "" {
			if *colorSchemeFlag == "light" || *colorSchemeFlag == "dark" {
				config.ColorScheme = *colorSchemeFlag
			} else {
				log.Printf("Invalid color scheme '%s', using '%s'", *colorSchemeFlag, config.ColorScheme)
			}
		}
		if *verboseFlag {
			// The API key is deliberately masked.
			log.Printf("[DEBUG] Config: ProviderURL=%s, ShellType=%s, ModelName=%s, APIKey=****, ColorScheme=%s",
				config.LLMProviderURL, config.ShellType, config.ModelName, config.ColorScheme)
		}

		selectedCommand, err := runTUI(config, query, *verboseFlag)
		if err != nil {
			log.Fatalf("Error running TUI: %v", err)
		}
		// An empty selection writes nothing, so the shell snippets leave the
		// command line untouched.
		if selectedCommand != "" {
			if *outputFileFlag != "" {
				if err := os.WriteFile(*outputFileFlag, []byte(selectedCommand+"\n"), 0644); err != nil {
					log.Fatalf("Error writing to output file: %v", err)
				}
			} else {
				fmt.Println(selectedCommand)
			}
		}

	case "-v", "--version", "version":
		// Reachable after leading flags, e.g. "llm-shell-hint --verbose version".
		showVersion()
	case "-h", "--help", "help":
		showHelp()
	default:
		fail("Unknown command: %s\nAvailable commands: query, init\n", command)
	}
}
// queryLLM sends prompt as a single user message to an OpenAI-compatible chat
// completions endpoint and extracts the suggested commands from the reply.
//
// config.LLMProviderURL is used verbatim as the request URL (nothing is
// appended), so it must already point at the completions endpoint. When
// verbose is set, the URL, model, API key length (never the key) and raw
// response body are logged.
//
// It returns an error for transport failures and timeouts, non-200 statuses
// (using the provider's error message when one can be found), undecodable
// JSON, and replies that lack choices[0].message.content as a string.
// Underlying errors are wrapped with %w.
func queryLLM(config *Config, prompt string, verbose bool) ([]CommandWithComment, error) {
	// requestTimeout bounds the whole exchange (dial, upload, waiting for the
	// model, reading the body) so an unresponsive provider cannot block forever.
	const requestTimeout = 2 * time.Minute
	// maxErrorBody caps how many characters of an unexpected body go into errors.
	const maxErrorBody = 200

	// excerpt shortens a response body for error messages, cutting on a rune
	// boundary so the message stays valid UTF-8.
	excerpt := func(body []byte) string {
		r := []rune(string(body))
		if len(r) <= maxErrorBody {
			return string(body)
		}
		return string(r[:maxErrorBody]) + "..."
	}

	apiURL := config.LLMProviderURL
	if verbose {
		log.Printf("[DEBUG] Sending request to: %s", apiURL)
		log.Printf("[DEBUG] Using model: %s", config.ModelName)
		log.Printf("[DEBUG] API key length: %d", len(config.APIKey))
	}

	// OpenAI-compatible chat completion request.
	payload := map[string]interface{}{
		"model": config.ModelName,
		"messages": []map[string]interface{}{
			{
				"role":    "user",
				"content": prompt,
			},
		},
		"max_tokens": 500,
	}
	jsonData, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal JSON: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+config.APIKey)
	if strings.Contains(config.LLMProviderURL, "openrouter.ai") {
		// OpenRouter-specific attribution headers.
		// NOTE(review): the referer is still a placeholder repository URL.
		req.Header.Set("HTTP-Referer", "https://github.com/your-username/llm-shell-hint")
		req.Header.Set("X-Title", "LLM Shell Hint")
	}

	client := &http.Client{Timeout: requestTimeout}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if verbose {
		log.Printf("[DEBUG] Response status: %d", resp.StatusCode)
		log.Printf("[DEBUG] Response body: %s", string(body))
	}

	if resp.StatusCode != http.StatusOK {
		// Prefer the provider's own message, given either as
		// {"error": {"message": "..."}} or as {"error": "..."}; otherwise
		// fall back to an excerpt of the raw body.
		var errorResult map[string]interface{}
		if json.Unmarshal(body, &errorResult) == nil {
			switch apiErr := errorResult["error"].(type) {
			case map[string]interface{}:
				if message, ok := apiErr["message"].(string); ok {
					return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, message)
				}
			case string:
				return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, apiErr)
			}
		}
		return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, excerpt(body))
	}

	var result map[string]interface{}
	if err := json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("failed to parse response (received: %s): %w", excerpt(body), err)
	}

	// Walk choices[0].message.content.
	choices, ok := result["choices"].([]interface{})
	if !ok || len(choices) == 0 {
		return nil, fmt.Errorf("invalid response format: no choices found")
	}
	firstChoice, ok := choices[0].(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("invalid response format: choice is not an object")
	}
	message, ok := firstChoice["message"].(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("invalid response format: no message found")
	}
	content, ok := message["content"].(string)
	if !ok {
		return nil, fmt.Errorf("invalid response format: content is not a string")
	}

	commands, err := parseCommandsFromMarkdown(content)
	if err != nil {
		return nil, fmt.Errorf("failed to parse commands from markdown: %w", err)
	}
	return commands, nil
}
// CommandWithComment pairs one suggested command, taken from a fenced code
// block in the model's reply, with the explanatory prose that preceded that
// block (joined onto a single line).
type CommandWithComment struct {
	Command, Comment string
}
// buildConfigFlags re-serialises the config-related flags given to the current
// invocation (every global flag except --output) so that the shell snippet
// printed by "init" calls llm-shell-hint with the same settings. Unset flags
// are omitted; the result is "" when none are set.
//
// Each value is wrapped in single quotes. Embedded single quotes and
// backslashes are emitted outside the quotes as \' and \\, which bash, zsh and
// fish all read back as the literal character. Fish, unlike POSIX shells,
// treats \\ as an escape even inside single quotes, so backslashes cannot stay
// quoted. This also prevents a crafted value from injecting shell code into the
// generated snippet.
//
// NOTE(review): --api-key ends up in plain text in whatever rc file the
// snippet is saved to; the config file is a better home for secrets.
func buildConfigFlags() string {
	quoter := strings.NewReplacer(`'`, `'\''`, `\`, `'\\'`)
	quote := func(s string) string {
		return "'" + quoter.Replace(s) + "'"
	}

	var flags []string
	// addValue appends "--name 'value'" when value is non-empty.
	addValue := func(name, value string) {
		if value != "" {
			flags = append(flags, "--"+name+" "+quote(value))
		}
	}

	addValue("config-file", *configFileFlag)
	if *verboseFlag {
		flags = append(flags, "--verbose")
	}
	addValue("provider-url", *providerURLFlag)
	addValue("api-key", *apiKeyFlag)
	addValue("shell-type", *shellTypeFlag)
	addValue("model", *modelNameFlag)
	addValue("color-scheme", *colorSchemeFlag)
	return strings.Join(flags, " ")
}
// parseCommandsFromMarkdown extracts fenced code blocks from an LLM reply.
//
// Every line whose trimmed form starts with ``` toggles between prose and code
// (an info string such as ```bash is ignored). Each code block becomes one
// CommandWithComment:
//   - Command is the block body with whitespace-only lines removed from both
//     ends; indentation and interior blank lines are kept.
//   - Comment is the non-blank prose seen since the previous block, trimmed
//     line by line and joined with single spaces.
//
// Blocks containing only whitespace are dropped. Closing any block discards the
// pending prose, so it cannot attach to a later block. An unterminated final
// block is still returned. CRLF line endings are normalised first, so no "\r"
// leaks into commands.
//
// The error result is always nil; it is kept for API compatibility.
func parseCommandsFromMarkdown(content string) ([]CommandWithComment, error) {
	var (
		result      []CommandWithComment
		codeLines   []string
		comment     strings.Builder
		inCodeblock bool
	)

	// flush records the collected block if it has any non-blank line, then
	// resets the code and prose buffers for the next block.
	flush := func() {
		start, end := 0, len(codeLines)
		for start < end && strings.TrimSpace(codeLines[start]) == "" {
			start++
		}
		for end > start && strings.TrimSpace(codeLines[end-1]) == "" {
			end--
		}
		if start < end {
			result = append(result, CommandWithComment{
				Command: strings.Join(codeLines[start:end], "\n"),
				Comment: strings.TrimSpace(comment.String()),
			})
		}
		codeLines = codeLines[:0]
		comment.Reset()
	}

	content = strings.ReplaceAll(content, "\r\n", "\n")
	for _, line := range strings.Split(content, "\n") {
		trimmed := strings.TrimSpace(line)
		switch {
		case strings.HasPrefix(trimmed, "```"):
			if inCodeblock {
				flush()
			}
			inCodeblock = !inCodeblock
		case inCodeblock:
			codeLines = append(codeLines, line)
		case trimmed != "":
			comment.WriteString(trimmed)
			comment.WriteString(" ")
		}
	}
	if inCodeblock {
		flush()
	}
	return result, nil
}