261 lines
5.9 KiB
Go
261 lines
5.9 KiB
Go
package main
|
|
|
|
import (
	"context"
	"crypto/md5"
	"database/sql"
	"encoding/hex"
	"errors"
	"fmt"
	"net"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"
	"unicode/utf8"

	"github.com/charmbracelet/bubbles/textinput"
	tea "github.com/charmbracelet/bubbletea"
	"github.com/charmbracelet/lipgloss"
	"github.com/charmbracelet/log"
	"github.com/charmbracelet/ssh"
	"github.com/charmbracelet/wish"
	"github.com/charmbracelet/wish/bubbletea"
	"github.com/charmbracelet/wish/logging"
	_ "github.com/mattn/go-sqlite3"
)
|
|
|
|
// Server and storage settings: the SSH server listens on host:port, and user
// records live in the SQLite file dbFile (a relative path, so it is resolved
// against the process's working directory).
const (
	dbFile = "users.db"
	host   = "localhost"
	port   = "23234"
)
|
|
|
|
var (
|
|
errorStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("9"))
|
|
successStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("10"))
|
|
db *sql.DB
|
|
dbMutex sync.Mutex
|
|
)
|
|
|
|
func main() {
|
|
// Initialize database
|
|
var err error
|
|
db, err = sql.Open("sqlite3", dbFile)
|
|
if err != nil {
|
|
log.Fatal("Could not open database", "error", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
// Create users table if it doesn't exist
|
|
_, err = db.Exec(`
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
fingerprint TEXT PRIMARY KEY,
|
|
public_key BLOB NOT NULL,
|
|
nickname TEXT NOT NULL,
|
|
created_at DATETIME NOT NULL,
|
|
CHECK(length(nickname) >= 3 AND length(nickname) <= 12)
|
|
)
|
|
`)
|
|
if err != nil {
|
|
log.Fatal("Could not create users table", "error", err)
|
|
}
|
|
|
|
s, err := wish.NewServer(
|
|
wish.WithAddress(net.JoinHostPort(host, port)),
|
|
wish.WithHostKeyPath(".ssh/id_ed25519"),
|
|
wish.WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
|
|
// Accept any public key authentication
|
|
return true
|
|
}),
|
|
wish.WithMiddleware(
|
|
bubbletea.Middleware(teaHandler),
|
|
logging.Middleware(),
|
|
),
|
|
)
|
|
if err != nil {
|
|
log.Error("Could not start server", "error", err)
|
|
}
|
|
|
|
done := make(chan os.Signal, 1)
|
|
signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
|
|
log.Info("Starting SSH server", "host", host, "port", port)
|
|
go func() {
|
|
if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
|
|
log.Error("Could not start server", "error", err)
|
|
done <- nil
|
|
}
|
|
}()
|
|
|
|
<-done
|
|
log.Info("Stopping SSH server")
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer func() { cancel() }()
|
|
if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
|
|
log.Error("Could not stop server", "error", err)
|
|
}
|
|
}
|
|
|
|
// getFingerprint generates a fingerprint from the public key
|
|
func getFingerprint(pubKey ssh.PublicKey) string {
|
|
hash := md5.Sum(pubKey.Marshal())
|
|
return hex.EncodeToString(hash[:])
|
|
}
|
|
|
|
// getUserNickname retrieves a user's nickname from the database
|
|
func getUserNickname(fingerprint string) (string, error) {
|
|
dbMutex.Lock()
|
|
defer dbMutex.Unlock()
|
|
|
|
var nickname string
|
|
err := db.QueryRow("SELECT nickname FROM users WHERE fingerprint = ?", fingerprint).Scan(&nickname)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return "", nil
|
|
}
|
|
return "", err
|
|
}
|
|
return nickname, nil
|
|
}
|
|
|
|
// addUser adds a new user to the database
|
|
func addUser(fingerprint string, pubKey ssh.PublicKey, nickname string) error {
|
|
dbMutex.Lock()
|
|
defer dbMutex.Unlock()
|
|
|
|
_, err := db.Exec(
|
|
"INSERT INTO users (fingerprint, public_key, nickname, created_at) VALUES (?, ?, ?, ?)",
|
|
fingerprint,
|
|
pubKey.Marshal(),
|
|
nickname,
|
|
time.Now().UTC(),
|
|
)
|
|
return err
|
|
}
|
|
|
|
// teaHandler creates a Bubble Tea program for each SSH session
|
|
func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) {
|
|
_, _, active := s.Pty()
|
|
if !active {
|
|
wish.Fatalln(s, "no active terminal, skipping")
|
|
return nil, nil
|
|
}
|
|
|
|
// Get the public key
|
|
pubKey := s.PublicKey()
|
|
if pubKey == nil {
|
|
wish.Fatalln(s, "no public key found")
|
|
return nil, nil
|
|
}
|
|
fingerprint := getFingerprint(pubKey)
|
|
|
|
// Check if we know this user
|
|
nickname, err := getUserNickname(fingerprint)
|
|
if err != nil {
|
|
wish.Fatalln(s, fmt.Sprintf("database error: %v", err))
|
|
return nil, nil
|
|
}
|
|
|
|
if nickname != "" {
|
|
// Known user - send greeting and close connection
|
|
wish.Println(s, fmt.Sprintf("Hello, %s!", nickname))
|
|
return nil, nil
|
|
}
|
|
|
|
// New user - prompt for nickname
|
|
ti := textinput.New()
|
|
ti.Placeholder = "3-12 characters"
|
|
ti.Focus()
|
|
ti.CharLimit = 12
|
|
ti.Width = 20
|
|
|
|
m := model{
|
|
textInput: ti,
|
|
fingerprint: fingerprint,
|
|
pubKey: pubKey,
|
|
session: s,
|
|
}
|
|
return m, []tea.ProgramOption{tea.WithAltScreen()}
|
|
}
|
|
|
|
// model manages the UI state
|
|
type model struct {
|
|
textInput textinput.Model
|
|
fingerprint string
|
|
pubKey ssh.PublicKey
|
|
nickname string
|
|
known bool
|
|
err string
|
|
session ssh.Session
|
|
}
|
|
|
|
// Init initializes the model
|
|
func (m model) Init() tea.Cmd {
|
|
if m.known {
|
|
return nil
|
|
}
|
|
return textinput.Blink
|
|
}
|
|
|
|
// Update handles input and updates the model
|
|
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|
if m.known {
|
|
return m, nil
|
|
}
|
|
|
|
switch msg := msg.(type) {
|
|
case tea.KeyMsg:
|
|
switch msg.Type {
|
|
case tea.KeyEnter:
|
|
// Validate nickname
|
|
nick := m.textInput.Value()
|
|
if len(nick) < 3 {
|
|
m.err = "Nickname must be at least 3 characters"
|
|
return m, nil
|
|
}
|
|
if len(nick) > 12 {
|
|
m.err = "Nickname must be no more than 12 characters"
|
|
return m, nil
|
|
}
|
|
|
|
// Store the user in the database
|
|
err := addUser(m.fingerprint, m.pubKey, nick)
|
|
if err != nil {
|
|
m.err = fmt.Sprintf("Could not save nickname: %v", err)
|
|
return m, nil
|
|
}
|
|
|
|
// Send greeting and close connection
|
|
wish.Println(m.session, fmt.Sprintf("Hello, %s! Your account has been created.", nick))
|
|
return m, tea.Quit
|
|
|
|
case tea.KeyCtrlC, tea.KeyEsc:
|
|
return m, tea.Quit
|
|
}
|
|
}
|
|
|
|
var cmd tea.Cmd
|
|
m.textInput, cmd = m.textInput.Update(msg)
|
|
return m, cmd
|
|
}
|
|
|
|
// View renders the UI
|
|
func (m model) View() string {
|
|
if m.known {
|
|
return ""
|
|
}
|
|
|
|
// New user - show input prompt
|
|
var errMsg string
|
|
if m.err != "" {
|
|
errMsg = errorStyle.Render("Error: " + m.err + "\n\n")
|
|
}
|
|
|
|
return fmt.Sprintf(
|
|
"Welcome! Please choose a nickname (3-12 characters)\n\n%s%s\n\n%s",
|
|
errMsg,
|
|
m.textInput.View(),
|
|
"(esc to quit)",
|
|
) + "\n"
|
|
}
|