mirror of
https://github.com/OliveTin/OliveTin
synced 2025-12-15 02:25:36 +00:00
177 lines
4.1 KiB
Go
177 lines
4.1 KiB
Go
package httpservers
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rsa"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
log "github.com/sirupsen/logrus"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
|
|
// "github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/MicahParks/keyfunc/v3"
|
|
"time"
|
|
)
|
|
|
|
var (
|
|
pubKeyBytes []byte = nil
|
|
pubKey *rsa.PublicKey
|
|
|
|
jwksVerifier keyfunc.Keyfunc
|
|
)
|
|
|
|
func initJwks() {
|
|
if jwksVerifier == nil {
|
|
var err error
|
|
|
|
if cfg.AuthJwtCertsURL != "" {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
|
|
|
|
jwksVerifier, err = keyfunc.NewDefaultCtx(ctx, []string{
|
|
cfg.AuthJwtCertsURL,
|
|
})
|
|
|
|
if err != nil {
|
|
log.Errorf("Init JWKS Failure: %v", err)
|
|
}
|
|
|
|
defer cancel()
|
|
}
|
|
}
|
|
}
|
|
|
|
func readLocalPublicKey() error {
|
|
if pubKeyBytes != nil {
|
|
return nil // Already read.
|
|
}
|
|
|
|
pubKeyBytes, err := os.ReadFile(cfg.AuthJwtPubKeyPath)
|
|
if err != nil {
|
|
return fmt.Errorf("couldn't read public key from file %s", cfg.AuthJwtPubKeyPath)
|
|
}
|
|
|
|
// Since the token is RSA (which we validated at the start of this function), the return type of this function actually has to be rsa.PublicKey!
|
|
pubKey, err = jwt.ParseRSAPublicKeyFromPEM(pubKeyBytes)
|
|
if err != nil {
|
|
return fmt.Errorf("error parsing public key object (from %s)", cfg.AuthJwtPubKeyPath)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func parseJwtTokenWithRemoteKey(jwtToken string) (*jwt.Token, error) {
|
|
initJwks()
|
|
|
|
return jwt.Parse(jwtToken, jwksVerifier.Keyfunc, jwt.WithAudience(cfg.AuthJwtAud))
|
|
}
|
|
|
|
func parseJwtTokenWithLocalKey(jwtString string) (*jwt.Token, error) {
|
|
err := readLocalPublicKey()
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return jwt.Parse(jwtString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
|
return nil, fmt.Errorf("parseJwt expected token algorithm RSA but got: %v", token.Header["alg"])
|
|
}
|
|
|
|
return pubKey, nil
|
|
})
|
|
}
|
|
|
|
// Hash-based Message Authentication Code
|
|
func parseJwtTokenWithHMAC(jwtString string) (*jwt.Token, error) {
|
|
return jwt.Parse(jwtString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("parseJwt expected token algorithm HMAC but got: %v", token.Header["alg"])
|
|
}
|
|
|
|
return []byte(cfg.AuthJwtHmacSecret), nil
|
|
})
|
|
}
|
|
|
|
func parseJwtToken(jwtString string) (*jwt.Token, error) {
|
|
if cfg.AuthJwtCertsURL != "" {
|
|
return parseJwtTokenWithRemoteKey(jwtString)
|
|
}
|
|
|
|
if cfg.AuthJwtPubKeyPath != "" {
|
|
return parseJwtTokenWithLocalKey(jwtString)
|
|
}
|
|
|
|
return parseJwtTokenWithHMAC(jwtString)
|
|
}
|
|
|
|
func getClaimsFromJwtToken(jwtString string) (jwt.MapClaims, error) {
|
|
token, err := parseJwtToken(jwtString)
|
|
|
|
if err != nil {
|
|
log.Errorf("jwt parse failure: %v", err)
|
|
return nil, errors.New("jwt parse failure")
|
|
}
|
|
|
|
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
|
return claims, nil
|
|
} else {
|
|
return nil, errors.New("jwt token isn't valid")
|
|
}
|
|
}
|
|
|
|
func lookupClaimValueOrDefault(claims jwt.MapClaims, key string, def string) string {
|
|
if val, ok := claims[key]; ok {
|
|
return fmt.Sprintf("%s", val)
|
|
} else {
|
|
return def
|
|
}
|
|
}
|
|
|
|
func parseJwtCookie(request *http.Request) (string, string) {
|
|
cookie, err := request.Cookie(cfg.AuthJwtCookieName)
|
|
|
|
if err != nil {
|
|
log.Debugf("jwt cookie check %v name: %v", err, cfg.AuthJwtCookieName)
|
|
return "", ""
|
|
}
|
|
|
|
return parseJwt(cookie.Value)
|
|
}
|
|
|
|
func parseJwt(token string) (string, string) {
|
|
claims, err := getClaimsFromJwtToken(token)
|
|
|
|
if err != nil {
|
|
log.Warnf("jwt claim error: %+v", err)
|
|
return "", ""
|
|
}
|
|
|
|
if cfg.InsecureAllowDumpJwtClaims {
|
|
log.Debugf("JWT Claims %+v", claims)
|
|
}
|
|
|
|
username := lookupClaimValueOrDefault(claims, cfg.AuthJwtClaimUsername, "")
|
|
usergroup := parseGroupClaim(cfg.AuthJwtClaimUserGroup, claims)
|
|
|
|
return username, usergroup
|
|
}
|
|
|
|
func parseGroupClaim(groupClaim string, claims jwt.MapClaims) string {
|
|
usergroup := ""
|
|
if val, ok := claims[groupClaim]; ok {
|
|
if array, ok := val.([]interface{}); ok {
|
|
groups := make([]string, len(array))
|
|
for i, v := range array {
|
|
groups[i] = fmt.Sprintf("%s", v)
|
|
}
|
|
usergroup = strings.Join(groups, " ")
|
|
} else {
|
|
usergroup = fmt.Sprintf("%s", val)
|
|
}
|
|
}
|
|
return usergroup
|
|
}
|