Files
backrest/internal/api/syncapi/authmiddleware_test.go
Gareth 93becf3e32
Release Please / release-please (push) Has been cancelled
Release Preview / call-reusable-release (push) Has been cancelled
Test / test-nix (push) Has been cancelled
Test / test-win (push) Has been cancelled
Update Restic / update-restic-version (push) Has been cancelled
fix: multihost sync refinements (#830)
2025-11-01 18:29:35 -07:00

201 lines
5.2 KiB
Go

package syncapi
import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
v1 "github.com/garethgeorge/backrest/gen/go/v1"
"github.com/garethgeorge/backrest/gen/go/v1sync"
"github.com/garethgeorge/backrest/internal/config"
"github.com/garethgeorge/backrest/internal/config/migrations"
"github.com/garethgeorge/backrest/internal/cryptoutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)
// TestAuthMiddleware runs the handler returned by newAuthHandler behind an
// httptest server and checks the HTTP status for each authentication
// scenario: a valid token reaches the wrapped handler (200), while a missing
// header, a malformed header, an unknown key, a mismatched instance ID, and a
// stale timestamp are all expected to yield 401.
func TestAuthMiddleware(t *testing.T) {
	serverPrivKey, err := cryptoutil.GeneratePrivateKey()
	require.NoError(t, err)
	clientPrivKey, err := cryptoutil.GeneratePrivateKey()
	require.NoError(t, err)

	// Create a mock config manager. The only authorized client is
	// "client-instance", registered under the key ID of clientPrivKey.
	cfgManager := &config.ConfigManager{
		Store: &config.MemoryStore{
			Config: &v1.Config{
				Version:  migrations.CurrentVersion,
				Instance: "test-instance",
				Multihost: &v1.Multihost{
					Identity: serverPrivKey,
					AuthorizedClients: []*v1.Multihost_Peer{
						{
							InstanceId: "client-instance",
							Keyid:      clientPrivKey.Keyid,
						},
					},
				},
			},
		},
	}

	// Create a mock handler. It runs on a goroutine owned by the HTTP server,
	// not on the test goroutine, so it must not use require.*: those helpers
	// call t.FailNow, which is only valid on the goroutine running the test.
	// assert.* records the failure and lets the handler return normally.
	mockHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		peer := PeerFromContext(r.Context())
		if !assert.NotNil(t, peer) {
			// Surface the problem to the client as well instead of
			// dereferencing a nil peer below.
			rw.WriteHeader(http.StatusInternalServerError)
			return
		}
		assert.Equal(t, "client-instance", peer.InstanceId)
		rw.WriteHeader(http.StatusOK)
	})

	// Create the auth handler
	authHandler := newAuthHandler(cfgManager, mockHandler)

	// Create a test server
	server := httptest.NewServer(authHandler)
	defer server.Close()

	// Use the server's own client rather than http.DefaultClient so that idle
	// connections are closed together with the server.
	client := server.Client()

	t.Run("valid auth header", func(t *testing.T) {
		// Create a request with a valid auth header
		req, err := http.NewRequest("GET", server.URL, nil)
		require.NoError(t, err)

		// Create a valid auth header
		clientCfg := &v1.Config{
			Instance: "client-instance",
			Multihost: &v1.Multihost{
				Identity: clientPrivKey,
			},
		}
		authHeader, err := createAuthHeader(clientCfg)
		require.NoError(t, err)
		req.Header.Set(authTokenHeader, authHeader)

		// Make the request
		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()

		// Check the response
		assert.Equal(t, http.StatusOK, resp.StatusCode)
	})

	t.Run("missing auth header", func(t *testing.T) {
		// No auth header is set on this request at all.
		req, err := http.NewRequest("GET", server.URL, nil)
		require.NoError(t, err)

		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()

		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	})

	t.Run("invalid auth header", func(t *testing.T) {
		req, err := http.NewRequest("GET", server.URL, nil)
		require.NoError(t, err)
		// A header value that is not a token produced by createAuthHeader.
		req.Header.Set(authTokenHeader, "invalid")

		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()

		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	})

	t.Run("unauthorized peer", func(t *testing.T) {
		// A freshly generated key that is not listed in AuthorizedClients.
		unauthorizedPrivKey, err := cryptoutil.GeneratePrivateKey()
		require.NoError(t, err)

		req, err := http.NewRequest("GET", server.URL, nil)
		require.NoError(t, err)

		clientCfg := &v1.Config{
			Instance: "unauthorized-instance",
			Multihost: &v1.Multihost{
				Identity: unauthorizedPrivKey,
			},
		}
		authHeader, err := createAuthHeader(clientCfg)
		require.NoError(t, err)
		req.Header.Set(authTokenHeader, authHeader)

		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()

		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	})

	t.Run("instance id mismatch", func(t *testing.T) {
		req, err := http.NewRequest("GET", server.URL, nil)
		require.NoError(t, err)

		// The authorized key, but presented under an instance ID that differs
		// from the one registered for it.
		clientCfg := &v1.Config{
			Instance: "wrong-instance",
			Multihost: &v1.Multihost{
				Identity: clientPrivKey,
			},
		}
		authHeader, err := createAuthHeader(clientCfg)
		require.NoError(t, err)
		req.Header.Set(authTokenHeader, authHeader)

		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()

		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	})

	t.Run("signature too old", func(t *testing.T) {
		req, err := http.NewRequest("GET", server.URL, nil)
		require.NoError(t, err)

		clientCfg := &v1.Config{
			Instance: "client-instance",
			Multihost: &v1.Multihost{
				Identity: clientPrivKey,
			},
		}
		privKey, err := cryptoutil.NewPrivateKey(clientPrivKey)
		require.NoError(t, err)

		// create a signed message with an old timestamp
		// NOTE(review): the timestamp is overwritten after createSignedMessage
		// returns. If the signature covers the timestamp, this request would
		// be rejected for a bad signature rather than for its age — confirm
		// against createSignedMessage.
		signedMessage, err := createSignedMessage([]byte(clientCfg.Instance), privKey)
		require.NoError(t, err)
		signedMessage.TimestampMillis = time.Now().Add(-2 * maxSignatureAge).UnixMilli()

		// create the auth token
		authToken, err := createAuthToken(signedMessage, privKey)
		require.NoError(t, err)
		req.Header.Set(authTokenHeader, authToken)

		resp, err := client.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()

		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	})
}
// createAuthToken wraps signedMessage and the public key of privKey in a
// v1sync.AuthorizationToken, serializes it with proto.Marshal, and returns the
// bytes encoded as standard base64. A marshalling failure is returned wrapped
// with the "marshal auth token" prefix.
func createAuthToken(signedMessage *v1.SignedMessage, privKey *cryptoutil.PrivateKey) (string, error) {
	encoded, err := proto.Marshal(&v1sync.AuthorizationToken{
		InstanceId: signedMessage,
		PublicKey:  privKey.PublicKeyProto(),
	})
	if err != nil {
		return "", fmt.Errorf("marshal auth token: %w", err)
	}

	token := base64.StdEncoding.EncodeToString(encoded)
	return token, nil
}