Files
OliveTin/service/internal/httpservers/restapi_auth_jwt.go
2025-04-08 20:31:55 +00:00

177 lines
4.1 KiB
Go

package httpservers
import (
"context"
"crypto/rsa"
"errors"
"fmt"
"github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
"net/http"
"os"
"strings"
// "github.com/coreos/go-oidc/v3/oidc"
"github.com/MicahParks/keyfunc/v3"
"time"
)
var (
pubKeyBytes []byte = nil
pubKey *rsa.PublicKey
jwksVerifier keyfunc.Keyfunc
)
func initJwks() {
if jwksVerifier == nil {
var err error
if cfg.AuthJwtCertsURL != "" {
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
jwksVerifier, err = keyfunc.NewDefaultCtx(ctx, []string{
cfg.AuthJwtCertsURL,
})
if err != nil {
log.Errorf("Init JWKS Failure: %v", err)
}
defer cancel()
}
}
}
func readLocalPublicKey() error {
if pubKeyBytes != nil {
return nil // Already read.
}
pubKeyBytes, err := os.ReadFile(cfg.AuthJwtPubKeyPath)
if err != nil {
return fmt.Errorf("couldn't read public key from file %s", cfg.AuthJwtPubKeyPath)
}
// Since the token is RSA (which we validated at the start of this function), the return type of this function actually has to be rsa.PublicKey!
pubKey, err = jwt.ParseRSAPublicKeyFromPEM(pubKeyBytes)
if err != nil {
return fmt.Errorf("error parsing public key object (from %s)", cfg.AuthJwtPubKeyPath)
}
return nil
}
func parseJwtTokenWithRemoteKey(jwtToken string) (*jwt.Token, error) {
initJwks()
return jwt.Parse(jwtToken, jwksVerifier.Keyfunc, jwt.WithAudience(cfg.AuthJwtAud))
}
func parseJwtTokenWithLocalKey(jwtString string) (*jwt.Token, error) {
err := readLocalPublicKey()
if err != nil {
return nil, err
}
return jwt.Parse(jwtString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("parseJwt expected token algorithm RSA but got: %v", token.Header["alg"])
}
return pubKey, nil
})
}
// Hash-based Message Authentication Code
func parseJwtTokenWithHMAC(jwtString string) (*jwt.Token, error) {
return jwt.Parse(jwtString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("parseJwt expected token algorithm HMAC but got: %v", token.Header["alg"])
}
return []byte(cfg.AuthJwtHmacSecret), nil
})
}
func parseJwtToken(jwtString string) (*jwt.Token, error) {
if cfg.AuthJwtCertsURL != "" {
return parseJwtTokenWithRemoteKey(jwtString)
}
if cfg.AuthJwtPubKeyPath != "" {
return parseJwtTokenWithLocalKey(jwtString)
}
return parseJwtTokenWithHMAC(jwtString)
}
func getClaimsFromJwtToken(jwtString string) (jwt.MapClaims, error) {
token, err := parseJwtToken(jwtString)
if err != nil {
log.Errorf("jwt parse failure: %v", err)
return nil, errors.New("jwt parse failure")
}
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
return claims, nil
} else {
return nil, errors.New("jwt token isn't valid")
}
}
func lookupClaimValueOrDefault(claims jwt.MapClaims, key string, def string) string {
if val, ok := claims[key]; ok {
return fmt.Sprintf("%s", val)
} else {
return def
}
}
func parseJwtCookie(request *http.Request) (string, string) {
cookie, err := request.Cookie(cfg.AuthJwtCookieName)
if err != nil {
log.Debugf("jwt cookie check %v name: %v", err, cfg.AuthJwtCookieName)
return "", ""
}
return parseJwt(cookie.Value)
}
func parseJwt(token string) (string, string) {
claims, err := getClaimsFromJwtToken(token)
if err != nil {
log.Warnf("jwt claim error: %+v", err)
return "", ""
}
if cfg.InsecureAllowDumpJwtClaims {
log.Debugf("JWT Claims %+v", claims)
}
username := lookupClaimValueOrDefault(claims, cfg.AuthJwtClaimUsername, "")
usergroup := parseGroupClaim(cfg.AuthJwtClaimUserGroup, claims)
return username, usergroup
}
func parseGroupClaim(groupClaim string, claims jwt.MapClaims) string {
usergroup := ""
if val, ok := claims[groupClaim]; ok {
if array, ok := val.([]interface{}); ok {
groups := make([]string, len(array))
for i, v := range array {
groups[i] = fmt.Sprintf("%s", v)
}
usergroup = strings.Join(groups, " ")
} else {
usergroup = fmt.Sprintf("%s", val)
}
}
return usergroup
}