Files

229 lines
7.0 KiB
Go

package sftputil
import (
"bytes"
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"errors"
"fmt"
"net"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/pkg/sftp"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
)
// AddHostKey makes sure the SFTP server's host key is recorded in
// <sshDir>/known_hosts, using the external ssh-keygen and ssh-keyscan tools.
//
// It returns early when `ssh-keygen -F` already finds the host in that file or
// in the user's default ~/.ssh/known_hosts. Otherwise it fetches the keys with
// `ssh-keyscan -H` (hashed host names) and appends them to <sshDir>/known_hosts,
// creating sshDir (0700) and the file (0600) when needed.
//
// An empty port or "22" is looked up as the plain host name; any other port is
// looked up as "[host]:port". An error is returned when ssh-keyscan fails or
// prints no keys, or when the known_hosts file cannot be written.
func AddHostKey(host, port string, sshDir string) error {
	hostSpec := host
	if port != "" && port != "22" {
		hostSpec = fmt.Sprintf("[%s]:%s", host, port)
	}
	knownHostsPath := filepath.Join(sshDir, "known_hosts")
	if err := os.MkdirAll(filepath.Dir(knownHostsPath), 0700); err != nil {
		return fmt.Errorf("failed to create ssh dir: %w", err)
	}

	// Check if already known in the specified file.
	if knownHostsHasEntry(hostSpec, knownHostsPath) {
		zap.S().Debugf("SFTP host %s already in %s", hostSpec, knownHostsPath)
		return nil
	}

	// Also accept an entry in the default known_hosts, unless that is the file
	// just checked. If the home directory is unknown, skip this check instead
	// of probing a relative ".ssh/known_hosts" under the working directory.
	if home, err := os.UserHomeDir(); err != nil {
		zap.S().Debugf("skipping default known_hosts lookup: %v", err)
	} else {
		defaultKnownHosts := filepath.Join(home, ".ssh", "known_hosts")
		if defaultKnownHosts != knownHostsPath {
			if _, statErr := os.Stat(defaultKnownHosts); statErr == nil && knownHostsHasEntry(hostSpec, defaultKnownHosts) {
				zap.S().Debugf("SFTP host %s already in %s", hostSpec, defaultKnownHosts)
				return nil
			}
		}
	}

	keyscanArgs := []string{"-H"}
	if port != "" {
		keyscanArgs = append(keyscanArgs, "-p", port)
	}
	keyscanArgs = append(keyscanArgs, host)
	keyscanCmd := exec.Command("ssh-keyscan", keyscanArgs...)
	var stdout, stderr bytes.Buffer
	keyscanCmd.Stdout = &stdout
	keyscanCmd.Stderr = &stderr
	if err := keyscanCmd.Run(); err != nil {
		return fmt.Errorf("ssh-keyscan for host %s failed: %w (stderr: %q)", host, err, strings.TrimSpace(stderr.String()))
	}
	keyOutput := stdout.Bytes()
	// ssh-keyscan can exit successfully yet print nothing (e.g. host down).
	if len(bytes.TrimSpace(keyOutput)) == 0 {
		return fmt.Errorf("ssh-keyscan for host %s returned no keys (stderr: %q)", host, strings.TrimSpace(stderr.String()))
	}

	if err := appendToKnownHosts(knownHostsPath, keyOutput); err != nil {
		return err
	}
	zap.S().Infof("Added SFTP host %s to known_hosts file at %s", hostSpec, knownHostsPath)
	return nil
}

// knownHostsHasEntry reports whether `ssh-keygen -F hostSpec -f path` exits
// successfully, i.e. whether path contains an entry for hostSpec.
func knownHostsHasEntry(hostSpec, path string) bool {
	return exec.Command("ssh-keygen", "-F", hostSpec, "-f", path).Run() == nil
}

// appendToKnownHosts appends data to the known_hosts file at path, creating it
// with mode 0600 if missing. If the existing file does not end in a newline,
// one is written first so the new entries don't merge into its last line.
// Both write and close errors are reported.
func appendToKnownHosts(path string, data []byte) error {
	f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
	if err != nil {
		return fmt.Errorf("failed to open known_hosts file: %w", err)
	}
	if info, statErr := f.Stat(); statErr == nil && info.Size() > 0 {
		last := make([]byte, 1)
		if _, readErr := f.ReadAt(last, info.Size()-1); readErr == nil && last[0] != '\n' {
			data = append([]byte{'\n'}, data...)
		}
	}
	if _, err := f.Write(data); err != nil {
		_ = f.Close() // the write error is the one worth reporting
		return fmt.Errorf("failed to write to known_hosts file: %w", err)
	}
	// Close can surface a deferred write failure, so it must be checked.
	if err := f.Close(); err != nil {
		return fmt.Errorf("failed to write to known_hosts file: %w", err)
	}
	return nil
}
// GenerateKey returns the Ed25519 key pair used for host, creating it on first use.
//
// The private key is stored at <sshDir>/id_ed25519_<host>, where every
// character of host outside [A-Za-z0-9.-] is replaced with '_'. The public key
// is stored at the same path plus ".pub". sshDir is created with mode 0700 if
// needed.
//
// If the private key file already exists, it is reused and nothing new is
// generated. A missing ".pub" is then re-derived from the private key.
// Otherwise a new key is generated and written: the private key with mode
// 0600 and the public key with mode 0644. Failing to write the public key is
// only logged.
//
// Returns the private key PEM-encoded in OpenSSH format, the public key in
// authorized_keys format, and the path of the private key file.
func GenerateKey(host string, sshDir string) ([]byte, []byte, string, error) {
	if err := os.MkdirAll(sshDir, 0700); err != nil {
		return nil, nil, "", fmt.Errorf("failed to create ssh dir: %w", err)
	}
	keyPath := filepath.Join(sshDir, "id_ed25519_"+sanitizeHostForKeyFile(host))

	// Reuse an existing key before doing any generation work.
	if _, err := os.Stat(keyPath); err == nil {
		return loadExistingKeyPair(keyPath)
	}

	zap.S().Debugf("Generating ED25519 key for host %s", host)
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return nil, nil, "", fmt.Errorf("failed to generate key: %w", err)
	}
	// ssh.MarshalPrivateKey returns a *pem.Block in the OpenSSH private key format.
	privBlock, err := ssh.MarshalPrivateKey(priv, "")
	if err != nil {
		return nil, nil, "", fmt.Errorf("failed to marshal private key: %w", err)
	}
	privPEM := pem.EncodeToMemory(privBlock)
	sshPub, err := ssh.NewPublicKey(pub)
	if err != nil {
		return nil, nil, "", fmt.Errorf("failed to create public key: %w", err)
	}
	pubBytes := ssh.MarshalAuthorizedKey(sshPub)

	if err := os.WriteFile(keyPath, privPEM, 0600); err != nil {
		return nil, nil, "", fmt.Errorf("failed to write private key: %w", err)
	}
	if err := os.WriteFile(keyPath+".pub", pubBytes, 0644); err != nil {
		// Non-fatal: the next call re-derives the public key from the private key.
		zap.S().Warnf("failed to write public key: %v", err)
	}
	return privPEM, pubBytes, keyPath, nil
}

// sanitizeHostForKeyFile replaces every character of host outside
// [A-Za-z0-9.-] with '_' so the host can be embedded in a file name.
func sanitizeHostForKeyFile(host string) string {
	return strings.Map(func(r rune) rune {
		if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '-' {
			return r
		}
		return '_'
	}, host)
}

// loadExistingKeyPair reads the private key at keyPath and the public key at
// keyPath+".pub". If only the .pub file is missing, the public key is derived
// from the private key and the .pub file is rewritten (best effort).
func loadExistingKeyPair(keyPath string) ([]byte, []byte, string, error) {
	privPEM, err := os.ReadFile(keyPath)
	if err != nil {
		return nil, nil, "", fmt.Errorf("failed to read private key: %w", err)
	}
	pubBytes, err := os.ReadFile(keyPath + ".pub")
	if err == nil {
		return privPEM, pubBytes, keyPath, nil
	}
	if !errors.Is(err, os.ErrNotExist) {
		return nil, nil, "", fmt.Errorf("failed to read public key: %w", err)
	}
	signer, err := ssh.ParsePrivateKey(privPEM)
	if err != nil {
		return nil, nil, "", fmt.Errorf("failed to read public key: file missing and private key unparsable: %w", err)
	}
	pubBytes = ssh.MarshalAuthorizedKey(signer.PublicKey())
	if err := os.WriteFile(keyPath+".pub", pubBytes, 0644); err != nil {
		zap.S().Warnf("failed to write public key: %v", err)
	}
	return privPEM, pubBytes, keyPath, nil
}
// InstallKey connects to the SFTP server using a password and appends the public key to authorized_keys.
func InstallKey(host, port, user, password string, pubBytes []byte) error {
sshConfig := &ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{
ssh.Password(password),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // Verification assumed done via AddHostKey
Timeout: 10 * time.Second,
}
conn, err := ssh.Dial("tcp", net.JoinHostPort(host, port), sshConfig)
if err != nil {
return fmt.Errorf("failed to connect with password: %w", err)
}
defer conn.Close()
sftpClient, err := sftp.NewClient(conn)
if err != nil {
return fmt.Errorf("failed to create sftp client: %w", err)
}
defer sftpClient.Close()
// Ensure .ssh directory exists
if _, err := sftpClient.Stat(".ssh"); errors.Is(err, os.ErrNotExist) {
if err := sftpClient.Mkdir(".ssh"); err != nil {
return fmt.Errorf("failed to create .ssh directory: %w", err)
}
if err := sftpClient.Chmod(".ssh", 0700); err != nil {
zap.S().Warnf("failed to chmod .ssh: %v", err)
}
}
f, err := sftpClient.OpenFile(".ssh/authorized_keys", os.O_APPEND|os.O_CREATE|os.O_WRONLY)
if err != nil {
return fmt.Errorf("failed to open authorized_keys: %w", err)
}
defer f.Close()
if err := f.Chmod(0600); err != nil {
zap.S().Warnf("failed to chmod authorized_keys: %v", err)
}
if _, err := f.Write([]byte("\n")); err != nil {
return fmt.Errorf("write error: %w", err)
}
if _, err := f.Write(pubBytes); err != nil {
return fmt.Errorf("write error: %w", err)
}
if _, err := f.Write([]byte("\n")); err != nil {
return fmt.Errorf("write error: %w", err)
}
return nil
}
// VerifyConnection checks that user can log in to host:port with the given
// private key (PEM-encoded, e.g. as returned by GenerateKey).
//
// It performs the SSH handshake with public-key authentication and closes the
// connection on success; no SFTP session is opened. An empty port defaults to
// "22", matching AddHostKey's treatment of an empty port. Returns an error if
// the key cannot be parsed or the connection/authentication fails.
//
// NOTE(review): the server's host key is not verified
// (ssh.InsecureIgnoreHostKey), so success does not prove the intended server
// answered.
func VerifyConnection(host, port, user string, privPEM []byte) error {
	signer, err := ssh.ParsePrivateKey(privPEM)
	if err != nil {
		return fmt.Errorf("failed to parse private key: %w", err)
	}
	if port == "" {
		port = "22"
	}
	clientConfig := &ssh.ClientConfig{
		User: user,
		Auth: []ssh.AuthMethod{
			ssh.PublicKeys(signer),
		},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(), // host key not verified; see NOTE above
		Timeout:         10 * time.Second,
	}
	conn, err := ssh.Dial("tcp", net.JoinHostPort(host, port), clientConfig)
	if err != nil {
		return fmt.Errorf("verification connection failed: %w", err)
	}
	defer conn.Close()
	return nil
}