mirror of
https://github.com/garethgeorge/backrest.git
synced 2025-12-14 01:35:31 +00:00
Some checks failed
Release Please / release-please (push) Has been cancelled
Build Snapshot Release / build (push) Has been cancelled
Test / test-nix (push) Has been cancelled
Test / test-win (push) Has been cancelled
Update Restic / update-restic-version (push) Has been cancelled
291 lines
7.9 KiB
Go
291 lines
7.9 KiB
Go
package tunnel
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ecdh"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"connectrpc.com/connect"
|
|
v1 "github.com/garethgeorge/backrest/gen/go/v1"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// WrappedStreamOptions configures a WrappedStream at construction time. Options
// are applied in order by the constructors, after defaults are set.
type WrappedStreamOptions func(*WrappedStream)
|
|
|
|
func WithLogger(logger *zap.Logger) WrappedStreamOptions {
|
|
return func(ws *WrappedStream) {
|
|
ws.logger = logger
|
|
}
|
|
}
|
|
|
|
func WithHeartbeatInterval(interval time.Duration) WrappedStreamOptions {
|
|
return func(ws *WrappedStream) {
|
|
ws.heartbeatInterval = interval
|
|
}
|
|
}
|
|
|
|
type WrappedStream struct {
|
|
isClient bool
|
|
|
|
stream stream
|
|
logger *zap.Logger
|
|
|
|
provider *ConnectionProvider // is nil until ProvideConnectionsTo is called
|
|
heartbeatInterval time.Duration // interval for heartbeat messages, if set
|
|
|
|
// pool of connections, every connection is bidirectional so can be initiated from either side.
|
|
connsMu sync.Mutex
|
|
conns map[int64]*connState
|
|
|
|
lastConnID atomic.Int64
|
|
|
|
handlingPackets atomic.Bool
|
|
streamStopped atomic.Bool
|
|
}
|
|
|
|
func newWrappedStreamInternal(stream stream, isClient bool, opts ...WrappedStreamOptions) *WrappedStream {
|
|
ws := &WrappedStream{
|
|
isClient: isClient,
|
|
stream: stream,
|
|
heartbeatInterval: 30 * time.Second,
|
|
conns: make(map[int64]*connState),
|
|
}
|
|
for _, opt := range opts {
|
|
opt(ws)
|
|
}
|
|
if isClient {
|
|
ws.lastConnID.Store(1) // Use odd numbered connection IDs for client-initiated connections.
|
|
} else {
|
|
ws.lastConnID.Store(2) // Use even numbered connection IDs for server-initiated connections.
|
|
}
|
|
return ws
|
|
}
|
|
|
|
func NewWrappedStream(stream *connect.BidiStream[v1.TunnelMessage, v1.TunnelMessage], opts ...WrappedStreamOptions) *WrappedStream {
|
|
return newWrappedStreamInternal(&serverStream{
|
|
stream: stream,
|
|
}, false, opts...)
|
|
}
|
|
|
|
func NewWrappedStreamFromClient(stream *connect.BidiStreamForClient[v1.TunnelMessage, v1.TunnelMessage], opts ...WrappedStreamOptions) *WrappedStream {
|
|
return newWrappedStreamInternal(&clientStream{
|
|
stream: stream,
|
|
}, true, opts...)
|
|
}
|
|
|
|
func (ws *WrappedStream) allocConnID() int64 {
|
|
return ws.lastConnID.Add(2)
|
|
}
|
|
|
|
func (ws *WrappedStream) IsReady() bool {
|
|
return ws.handlingPackets.Load() && !ws.streamStopped.Load()
|
|
}
|
|
|
|
func (ws *WrappedStream) Dial() (net.Conn, error) {
|
|
if !ws.handlingPackets.Load() {
|
|
return nil, fmt.Errorf("cannot dial before handling packets")
|
|
}
|
|
|
|
connID := ws.allocConnID()
|
|
new := newConnState(ws.stream, connID, ws.logger)
|
|
if err := new.sendOpenPacket(); err != nil {
|
|
return nil, fmt.Errorf("send open packet: %w", err)
|
|
}
|
|
ws.connsMu.Lock()
|
|
defer ws.connsMu.Unlock()
|
|
ws.conns[connID] = new
|
|
return new, nil
|
|
}
|
|
|
|
func (ws *WrappedStream) ProvideConnectionsTo(provider *ConnectionProvider) {
|
|
ws.provider = provider
|
|
}
|
|
|
|
func (ws *WrappedStream) sendHeartbeats(ctx context.Context) {
|
|
if ws.heartbeatInterval <= 0 || !ws.isClient {
|
|
return
|
|
}
|
|
|
|
ticker := time.NewTicker(ws.heartbeatInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
if ws.logger != nil {
|
|
ws.logger.Debug("sending heartbeat")
|
|
}
|
|
if err := ws.stream.Send(&v1.TunnelMessage{
|
|
ConnId: -1, // handshake packet
|
|
}); err != nil && ws.logger != nil {
|
|
ws.logger.Error("failed to send heartbeat", zap.Error(err))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ws *WrappedStream) HandlePackets(ctx context.Context) error {
|
|
if ws.handlingPackets.Swap(true) {
|
|
return fmt.Errorf("already handling packets")
|
|
}
|
|
defer ws.handlingPackets.Store(false)
|
|
if ws.streamStopped.Load() {
|
|
return fmt.Errorf("stream already stopped")
|
|
}
|
|
|
|
// TODO: optimization, it is generally secure and performant have a singleton key that is generated once and reused for all connections.
|
|
// the only risk w/this approach is if the key were somehow leaked all related connections would be compromised. The risk is low in this case since
|
|
// the key is only kept in memory for the lifetime of the process, and ideally is a second line of defense after TLS.
|
|
key, err := ecdh.X25519().GenerateKey(rand.Reader)
|
|
if err != nil {
|
|
return fmt.Errorf("generate key for handshake packet: %w", err)
|
|
}
|
|
|
|
if err := ws.stream.Send(&v1.TunnelMessage{
|
|
ConnId: -100, // hadnshake packet
|
|
PubkeyEcdhX25519: key.PublicKey().Bytes(),
|
|
}); err != nil {
|
|
return fmt.Errorf("send handshake packet: %w", err)
|
|
}
|
|
|
|
// receive handshake packet
|
|
handshake, err := ws.stream.Receive()
|
|
if err != nil {
|
|
return fmt.Errorf("receive handshake packet: %w", err)
|
|
}
|
|
if handshake.GetConnId() != -100 {
|
|
return fmt.Errorf("expected handshake packet with connId -100, got %d", handshake.GetConnId())
|
|
}
|
|
if len(handshake.PubkeyEcdhX25519) == 0 {
|
|
return fmt.Errorf("handshake packet does not contain public key")
|
|
}
|
|
peerKey, err := ecdh.X25519().NewPublicKey(handshake.PubkeyEcdhX25519)
|
|
if err != nil {
|
|
return fmt.Errorf("parse peer public key: %w", err)
|
|
}
|
|
|
|
_, err = key.ECDH(peerKey)
|
|
if err != nil {
|
|
return fmt.Errorf("compute shared key: %w", err)
|
|
}
|
|
// TODO: use the key for encryption and decryption of messages.
|
|
|
|
newConn := func(connId int64) *connState {
|
|
if ws.logger != nil {
|
|
ws.logger.Info("new tunnel connection", zap.Int64("connId", connId))
|
|
}
|
|
new := newConnState(ws.stream, connId, ws.logger)
|
|
ws.conns[connId] = new
|
|
ws.provider.ProvideConn(new)
|
|
return new
|
|
}
|
|
|
|
headOfLineBlockingTimer := time.NewTimer(0)
|
|
defer headOfLineBlockingTimer.Stop()
|
|
|
|
// send heartbeats in a separate goroutine if heartbeat interval is set
|
|
go ws.sendHeartbeats(ctx)
|
|
|
|
for {
|
|
msg, err := ws.stream.Receive()
|
|
if err != nil {
|
|
if ws.handlingPackets.Load() {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("receive message: %w", err)
|
|
}
|
|
|
|
connId := msg.GetConnId()
|
|
if connId <= 0 {
|
|
// negative IDs are reserved for healthchecks and control messages, ignored.
|
|
continue
|
|
}
|
|
|
|
if msg.Close {
|
|
if ws.logger != nil {
|
|
ws.logger.Info("closing connection", zap.Int64("connId", connId))
|
|
}
|
|
ws.connsMu.Lock()
|
|
if conn, exists := ws.conns[connId]; exists {
|
|
if err := conn.Close(); err != nil && ws.logger != nil {
|
|
ws.logger.Error("failed to close connection", zap.Int64("connId", connId), zap.Error(err))
|
|
}
|
|
delete(ws.conns, connId)
|
|
}
|
|
ws.connsMu.Unlock()
|
|
continue
|
|
}
|
|
|
|
ws.connsMu.Lock()
|
|
conn, exists := ws.conns[connId]
|
|
if !exists {
|
|
if msg.Seqno != 0 {
|
|
if ws.logger != nil {
|
|
ws.logger.Warn("received message for unknown connection", zap.Int64("connId", connId), zap.Int64("seqno", msg.Seqno))
|
|
}
|
|
continue
|
|
}
|
|
if ws.provider == nil {
|
|
if ws.logger != nil {
|
|
ws.logger.Warn("received open packet for unknown connection, but no provider is set, ignoring the connection", zap.Int64("connId", connId))
|
|
}
|
|
continue
|
|
}
|
|
|
|
conn = newConn(connId)
|
|
}
|
|
ws.connsMu.Unlock()
|
|
|
|
if len(msg.Data) == 0 {
|
|
continue
|
|
}
|
|
|
|
headOfLineBlockingTimer.Reset(100 * time.Millisecond)
|
|
select {
|
|
case <-conn.closedCh:
|
|
if ws.logger != nil {
|
|
ws.logger.Warn("received message for closed connection", zap.Int64("connId", connId), zap.Int64("seqno", msg.Seqno))
|
|
}
|
|
case conn.reads <- msg.Data:
|
|
if ws.logger != nil {
|
|
ws.logger.Debug("received data on tunnel connection", zap.Int64("connId", connId), zap.Int64("seqno", msg.Seqno), zap.Int("dataLength", len(msg.Data)))
|
|
}
|
|
case <-headOfLineBlockingTimer.C:
|
|
if ws.logger != nil {
|
|
ws.logger.Warn("head-of-line blocking detected, no reads available for connection", zap.Int64("connId", connId), zap.Int64("seqno", msg.Seqno))
|
|
}
|
|
conn.Close()
|
|
case <-ctx.Done():
|
|
// Close all open connections when the context is done.
|
|
ws.connsMu.Lock()
|
|
if ws.logger != nil {
|
|
ws.logger.Info("context done, closing all connections")
|
|
}
|
|
for _, c := range ws.conns {
|
|
if err := c.Close(); err != nil && ws.logger != nil {
|
|
ws.logger.Error("failed to close connection on context done", zap.Int64("connId", c.connId), zap.Error(err))
|
|
}
|
|
}
|
|
ws.connsMu.Unlock()
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ws *WrappedStream) Shutdown() error {
|
|
if ws.streamStopped.Swap(true) {
|
|
if ws.logger != nil {
|
|
ws.logger.Warn("wrapped stream already shutdown")
|
|
}
|
|
return nil
|
|
}
|
|
return ws.stream.Close()
|
|
}
|