// sshidentity/main.go (Go source; original listing: 261 lines, 5.9 KiB)

package main
import (
"context"
"crypto/md5"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"net"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/log"
"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish"
"github.com/charmbracelet/wish/bubbletea"
"github.com/charmbracelet/wish/logging"
_ "github.com/mattn/go-sqlite3"
)
// host is the address the SSH server listens on; it is combined with port
// via net.JoinHostPort.
const host = "localhost"

// port is the TCP port, kept as a string because net.JoinHostPort takes strings.
const port = "23234"

// dbFile is the SQLite database path, resolved relative to the working directory.
const dbFile = "users.db"
// Terminal styles for status text. Colors are ANSI palette indices
// ("9" and "10" are bright red and bright green in the standard palette).
// NOTE(review): successStyle is not referenced anywhere in this file.
var (
	errorStyle   = lipgloss.NewStyle().Foreground(lipgloss.Color("9"))
	successStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("10"))
)

// Shared database state. db is assigned once at startup; dbMutex is held by
// getUserNickname and addUser around each query they issue.
var (
	db      *sql.DB
	dbMutex sync.Mutex
)
func main() {
// Initialize database
var err error
db, err = sql.Open("sqlite3", dbFile)
if err != nil {
log.Fatal("Could not open database", "error", err)
}
defer db.Close()
// Create users table if it doesn't exist
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS users (
fingerprint TEXT PRIMARY KEY,
public_key BLOB NOT NULL,
nickname TEXT NOT NULL,
created_at DATETIME NOT NULL,
CHECK(length(nickname) >= 3 AND length(nickname) <= 12)
)
`)
if err != nil {
log.Fatal("Could not create users table", "error", err)
}
s, err := wish.NewServer(
wish.WithAddress(net.JoinHostPort(host, port)),
wish.WithHostKeyPath(".ssh/id_ed25519"),
wish.WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
// Accept any public key authentication
return true
}),
wish.WithMiddleware(
bubbletea.Middleware(teaHandler),
logging.Middleware(),
),
)
if err != nil {
log.Error("Could not start server", "error", err)
}
done := make(chan os.Signal, 1)
signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
log.Info("Starting SSH server", "host", host, "port", port)
go func() {
if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
log.Error("Could not start server", "error", err)
done <- nil
}
}()
<-done
log.Info("Stopping SSH server")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer func() { cancel() }()
if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
log.Error("Could not stop server", "error", err)
}
}
// getFingerprint derives the identifier used as the users table's primary
// key: the lowercase hex encoding of the MD5 digest of pubKey.Marshal().
//
// MD5 serves only as a lookup key here, not as a security check. Changing
// the algorithm would change every fingerprint and orphan existing rows.
func getFingerprint(pubKey ssh.PublicKey) string {
	digest := md5.New()
	digest.Write(pubKey.Marshal()) // hash.Hash.Write never returns an error
	return hex.EncodeToString(digest.Sum(nil))
}
// getUserNickname returns the nickname registered for fingerprint.
//
// It returns ("", nil) when no such user exists; callers treat an empty
// nickname as "unknown user". That is unambiguous because the users table's
// CHECK constraint keeps stored nicknames at 3-12 characters. Any other
// database failure is returned wrapped with context.
func getUserNickname(fingerprint string) (string, error) {
	dbMutex.Lock()
	defer dbMutex.Unlock()

	var nickname string
	err := db.QueryRow("SELECT nickname FROM users WHERE fingerprint = ?", fingerprint).Scan(&nickname)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return "", nil
	case err != nil:
		return "", fmt.Errorf("looking up nickname: %w", err)
	}
	return nickname, nil
}
// addUser inserts a row for a newly registered user. The row holds:
//   - the key's fingerprint (the table's primary key),
//   - the key's serialized bytes from pubKey.Marshal(),
//   - the chosen nickname,
//   - the current UTC time.
//
// The database error is returned unchanged. Examples include a duplicate
// fingerprint, or a nickname outside the table's 3-12 character CHECK.
func addUser(fingerprint string, pubKey ssh.PublicKey, nickname string) error {
	const insertUser = "INSERT INTO users (fingerprint, public_key, nickname, created_at) VALUES (?, ?, ?, ?)"

	dbMutex.Lock()
	defer dbMutex.Unlock()

	keyBytes := pubKey.Marshal()
	createdAt := time.Now().UTC()
	if _, err := db.Exec(insertUser, fingerprint, keyBytes, nickname, createdAt); err != nil {
		return err
	}
	return nil
}
// teaHandler is the per-session handler given to bubbletea.Middleware.
//
// It ends sessions that have no PTY, no public key, or hit a database error,
// using wish.Fatalln. A user whose key fingerprint is already registered
// gets a greeting and a nil model, so no prompt is shown. Everyone else gets
// a nickname-entry model that runs on the alternate screen.
func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) {
	if _, _, hasPty := s.Pty(); !hasPty {
		wish.Fatalln(s, "no active terminal, skipping")
		return nil, nil
	}

	key := s.PublicKey()
	if key == nil {
		wish.Fatalln(s, "no public key found")
		return nil, nil
	}

	fp := getFingerprint(key)
	existing, lookupErr := getUserNickname(fp)
	switch {
	case lookupErr != nil:
		wish.Fatalln(s, fmt.Sprintf("database error: %v", lookupErr))
		return nil, nil
	case existing != "":
		// Returning user: greet and start no program.
		wish.Println(s, fmt.Sprintf("Hello, %s!", existing))
		return nil, nil
	}

	// First-time user: build the nickname prompt.
	input := textinput.New()
	input.Placeholder = "3-12 characters"
	input.Focus()
	input.CharLimit = 12
	input.Width = 20

	opts := []tea.ProgramOption{tea.WithAltScreen()}
	return model{
		session:     s,
		pubKey:      key,
		fingerprint: fp,
		textInput:   input,
	}, opts
}
// model manages the UI state
type model struct {
textInput textinput.Model
fingerprint string
pubKey ssh.PublicKey
nickname string
known bool
err string
session ssh.Session
}
// Init starts the text input's cursor blink. It returns no command when
// m.known is set.
func (m model) Init() tea.Cmd {
	var first tea.Cmd
	if !m.known {
		first = textinput.Blink
	}
	return first
}
// Update handles one message for the nickname prompt.
//
//   - Enter validates the typed nickname (3-12 characters, counted as Unicode
//     code points). It then stores it with addUser, greets the user, and quits.
//   - Ctrl+C or Esc quits without saving.
//   - Any other message is forwarded to the text input.
//
// Validation and storage failures are kept in m.err for View to display.
// Nothing happens when m.known is set.
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	if m.known {
		return m, nil
	}
	switch msg := msg.(type) {
	case tea.KeyMsg:
		switch msg.Type {
		case tea.KeyEnter:
			nick := m.textInput.Value()
			// Count characters (runes), not bytes. textinput's CharLimit and
			// SQLite's length() both count characters. Using len(nick) would
			// reject valid multi-byte nicknames (e.g. 12 accented letters).
			// It would also let too-short ones (e.g. two CJK characters)
			// through, only to fail the table's CHECK constraint.
			chars := len([]rune(nick))
			if chars < 3 {
				m.err = "Nickname must be at least 3 characters"
				return m, nil
			}
			if chars > 12 {
				m.err = "Nickname must be no more than 12 characters"
				return m, nil
			}
			if err := addUser(m.fingerprint, m.pubKey, nick); err != nil {
				m.err = fmt.Sprintf("Could not save nickname: %v", err)
				return m, nil
			}
			// NOTE(review): teaHandler runs this program with
			// tea.WithAltScreen(). Text written straight to the session here
			// probably lands on the alternate screen, which is discarded when
			// the program quits. Confirm the greeting is actually visible.
			wish.Println(m.session, fmt.Sprintf("Hello, %s! Your account has been created.", nick))
			return m, tea.Quit
		case tea.KeyCtrlC, tea.KeyEsc:
			return m, tea.Quit
		}
	}
	var cmd tea.Cmd
	m.textInput, cmd = m.textInput.Update(msg)
	return m, cmd
}
// View renders the nickname prompt, in this order:
//   - a heading,
//   - the last error (if any) styled with errorStyle,
//   - the text input,
//   - a quit hint.
//
// It renders nothing when m.known is set.
func (m model) View() string {
	if m.known {
		return ""
	}

	const layout = "Welcome! Please choose a nickname (3-12 characters)\n\n%s%s\n\n%s"
	const quitHint = "(esc to quit)"

	errLine := ""
	if len(m.err) > 0 {
		errLine = errorStyle.Render("Error: " + m.err + "\n\n")
	}

	screen := fmt.Sprintf(layout, errLine, m.textInput.View(), quitHint)
	return screen + "\n"
}