mirror of
https://github.com/garethgeorge/backrest.git
synced 2026-05-05 04:20:31 +00:00
93becf3e32
Release Please / release-please (push) Has been cancelled
Release Preview / call-reusable-release (push) Has been cancelled
Test / test-nix (push) Has been cancelled
Test / test-win (push) Has been cancelled
Update Restic / update-restic-version (push) Has been cancelled
201 lines
5.2 KiB
Go
201 lines
5.2 KiB
Go
package syncapi
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
v1 "github.com/garethgeorge/backrest/gen/go/v1"
|
|
"github.com/garethgeorge/backrest/gen/go/v1sync"
|
|
"github.com/garethgeorge/backrest/internal/config"
|
|
"github.com/garethgeorge/backrest/internal/config/migrations"
|
|
"github.com/garethgeorge/backrest/internal/cryptoutil"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
func TestAuthMiddleware(t *testing.T) {
|
|
serverPrivKey, err := cryptoutil.GeneratePrivateKey()
|
|
require.NoError(t, err)
|
|
|
|
clientPrivKey, err := cryptoutil.GeneratePrivateKey()
|
|
require.NoError(t, err)
|
|
|
|
// Create a mock config manager
|
|
cfgManager := &config.ConfigManager{
|
|
Store: &config.MemoryStore{
|
|
Config: &v1.Config{
|
|
Version: migrations.CurrentVersion,
|
|
Instance: "test-instance",
|
|
Multihost: &v1.Multihost{
|
|
Identity: serverPrivKey,
|
|
AuthorizedClients: []*v1.Multihost_Peer{
|
|
{
|
|
InstanceId: "client-instance",
|
|
Keyid: clientPrivKey.Keyid,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
// Create a mock handler
|
|
mockHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
peer := PeerFromContext(r.Context())
|
|
require.NotNil(t, peer)
|
|
assert.Equal(t, "client-instance", peer.InstanceId)
|
|
rw.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
// Create the auth handler
|
|
authHandler := newAuthHandler(cfgManager, mockHandler)
|
|
|
|
// Create a test server
|
|
server := httptest.NewServer(authHandler)
|
|
defer server.Close()
|
|
|
|
t.Run("valid auth header", func(t *testing.T) {
|
|
// Create a request with a valid auth header
|
|
req, err := http.NewRequest("GET", server.URL, nil)
|
|
require.NoError(t, err)
|
|
|
|
// Create a valid auth header
|
|
clientCfg := &v1.Config{
|
|
Instance: "client-instance",
|
|
Multihost: &v1.Multihost{
|
|
Identity: clientPrivKey,
|
|
},
|
|
}
|
|
authHeader, err := createAuthHeader(clientCfg)
|
|
require.NoError(t, err)
|
|
req.Header.Set(authTokenHeader, authHeader)
|
|
|
|
// Make the request
|
|
resp, err := http.DefaultClient.Do(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
// Check the response
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
})
|
|
|
|
t.Run("missing auth header", func(t *testing.T) {
|
|
req, err := http.NewRequest("GET", server.URL, nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
})
|
|
|
|
t.Run("invalid auth header", func(t *testing.T) {
|
|
req, err := http.NewRequest("GET", server.URL, nil)
|
|
require.NoError(t, err)
|
|
req.Header.Set(authTokenHeader, "invalid")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
})
|
|
|
|
t.Run("unauthorized peer", func(t *testing.T) {
|
|
unauthorizedPrivKey, err := cryptoutil.GeneratePrivateKey()
|
|
require.NoError(t, err)
|
|
|
|
req, err := http.NewRequest("GET", server.URL, nil)
|
|
require.NoError(t, err)
|
|
|
|
clientCfg := &v1.Config{
|
|
Instance: "unauthorized-instance",
|
|
Multihost: &v1.Multihost{
|
|
Identity: unauthorizedPrivKey,
|
|
},
|
|
}
|
|
authHeader, err := createAuthHeader(clientCfg)
|
|
require.NoError(t, err)
|
|
req.Header.Set(authTokenHeader, authHeader)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
})
|
|
|
|
t.Run("instance id mismatch", func(t *testing.T) {
|
|
req, err := http.NewRequest("GET", server.URL, nil)
|
|
require.NoError(t, err)
|
|
|
|
clientCfg := &v1.Config{
|
|
Instance: "wrong-instance",
|
|
Multihost: &v1.Multihost{
|
|
Identity: clientPrivKey,
|
|
},
|
|
}
|
|
authHeader, err := createAuthHeader(clientCfg)
|
|
require.NoError(t, err)
|
|
req.Header.Set(authTokenHeader, authHeader)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
})
|
|
|
|
t.Run("signature too old", func(t *testing.T) {
|
|
req, err := http.NewRequest("GET", server.URL, nil)
|
|
require.NoError(t, err)
|
|
|
|
clientCfg := &v1.Config{
|
|
Instance: "client-instance",
|
|
Multihost: &v1.Multihost{
|
|
Identity: clientPrivKey,
|
|
},
|
|
}
|
|
|
|
privKey, err := cryptoutil.NewPrivateKey(clientPrivKey)
|
|
require.NoError(t, err)
|
|
|
|
// create a signed message with an old timestamp
|
|
signedMessage, err := createSignedMessage([]byte(clientCfg.Instance), privKey)
|
|
require.NoError(t, err)
|
|
signedMessage.TimestampMillis = time.Now().Add(-2 * maxSignatureAge).UnixMilli()
|
|
|
|
// create the auth token
|
|
authToken, err := createAuthToken(signedMessage, privKey)
|
|
require.NoError(t, err)
|
|
|
|
req.Header.Set(authTokenHeader, authToken)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
})
|
|
}
|
|
|
|
func createAuthToken(signedMessage *v1.SignedMessage, privKey *cryptoutil.PrivateKey) (string, error) {
|
|
authToken := &v1sync.AuthorizationToken{
|
|
InstanceId: signedMessage,
|
|
PublicKey: privKey.PublicKeyProto(),
|
|
}
|
|
|
|
tokenBytes, err := proto.Marshal(authToken)
|
|
if err != nil {
|
|
return "", fmt.Errorf("marshal auth token: %w", err)
|
|
}
|
|
|
|
return base64.StdEncoding.EncodeToString(tokenBytes), nil
|
|
}
|