auth: Support all JWT algorithms
This change adds support to etcd for all of the JWT algorithms included in the underlying JWT library.
This commit is contained in:
192
auth/jwt.go
192
auth/jwt.go
@ -16,8 +16,9 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"io/ioutil"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
@ -26,10 +27,10 @@ import (
|
||||
|
||||
type tokenJWT struct {
|
||||
lg *zap.Logger
|
||||
signMethod string
|
||||
signKey *rsa.PrivateKey
|
||||
verifyKey *rsa.PublicKey
|
||||
signMethod jwt.SigningMethod
|
||||
key interface{}
|
||||
ttl time.Duration
|
||||
verifyOnly bool
|
||||
}
|
||||
|
||||
func (t *tokenJWT) enable() {}
|
||||
@ -45,25 +46,20 @@ func (t *tokenJWT) info(ctx context.Context, token string, rev uint64) (*AuthInf
|
||||
)
|
||||
|
||||
parsed, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
|
||||
return t.verifyKey, nil
|
||||
if token.Method.Alg() != t.signMethod.Alg() {
|
||||
return nil, errors.New("invalid signing method")
|
||||
}
|
||||
switch k := t.key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &k.PublicKey, nil
|
||||
case *ecdsa.PrivateKey:
|
||||
return &k.PublicKey, nil
|
||||
default:
|
||||
return t.key, nil
|
||||
}
|
||||
})
|
||||
|
||||
switch err.(type) {
|
||||
case nil:
|
||||
if !parsed.Valid {
|
||||
if t.lg != nil {
|
||||
t.lg.Warn("invalid JWT token", zap.String("token", token))
|
||||
} else {
|
||||
plog.Warningf("invalid jwt token: %s", token)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
claims := parsed.Claims.(jwt.MapClaims)
|
||||
|
||||
username = claims["username"].(string)
|
||||
revision = uint64(claims["revision"].(float64))
|
||||
default:
|
||||
if err != nil {
|
||||
if t.lg != nil {
|
||||
t.lg.Warn(
|
||||
"failed to parse a JWT token",
|
||||
@ -76,20 +72,37 @@ func (t *tokenJWT) info(ctx context.Context, token string, rev uint64) (*AuthInf
|
||||
return nil, false
|
||||
}
|
||||
|
||||
claims, ok := parsed.Claims.(jwt.MapClaims)
|
||||
if !parsed.Valid || !ok {
|
||||
if t.lg != nil {
|
||||
t.lg.Warn("invalid JWT token", zap.String("token", token))
|
||||
} else {
|
||||
plog.Warningf("invalid jwt token: %s", token)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
username = claims["username"].(string)
|
||||
revision = uint64(claims["revision"].(float64))
|
||||
|
||||
return &AuthInfo{Username: username, Revision: revision}, true
|
||||
}
|
||||
|
||||
func (t *tokenJWT) assign(ctx context.Context, username string, revision uint64) (string, error) {
|
||||
if t.verifyOnly {
|
||||
return "", ErrVerifyOnly
|
||||
}
|
||||
|
||||
// Future work: let a jwt token include permission information would be useful for
|
||||
// permission checking in proxy side.
|
||||
tk := jwt.NewWithClaims(jwt.GetSigningMethod(t.signMethod),
|
||||
tk := jwt.NewWithClaims(t.signMethod,
|
||||
jwt.MapClaims{
|
||||
"username": username,
|
||||
"revision": revision,
|
||||
"exp": time.Now().Add(t.ttl).Unix(),
|
||||
})
|
||||
|
||||
token, err := tk.SignedString(t.signKey)
|
||||
token, err := tk.SignedString(t.key)
|
||||
if err != nil {
|
||||
if t.lg != nil {
|
||||
t.lg.Warn(
|
||||
@ -117,113 +130,54 @@ func (t *tokenJWT) assign(ctx context.Context, username string, revision uint64)
|
||||
return token, err
|
||||
}
|
||||
|
||||
func prepareOpts(lg *zap.Logger, opts map[string]string) (jwtSignMethod, jwtPubKeyPath, jwtPrivKeyPath string, ttl time.Duration, err error) {
|
||||
for k, v := range opts {
|
||||
switch k {
|
||||
case "sign-method":
|
||||
jwtSignMethod = v
|
||||
case "pub-key":
|
||||
jwtPubKeyPath = v
|
||||
case "priv-key":
|
||||
jwtPrivKeyPath = v
|
||||
case "ttl":
|
||||
ttl, err = time.ParseDuration(v)
|
||||
if err != nil {
|
||||
if lg != nil {
|
||||
lg.Warn(
|
||||
"failed to parse JWT TTL option",
|
||||
zap.String("ttl-value", v),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
plog.Errorf("failed to parse ttl option (%s)", err)
|
||||
}
|
||||
return "", "", "", 0, ErrInvalidAuthOpts
|
||||
}
|
||||
default:
|
||||
if lg != nil {
|
||||
lg.Warn("unknown JWT token option", zap.String("option", k))
|
||||
} else {
|
||||
plog.Errorf("unknown token specific option: %s", k)
|
||||
}
|
||||
return "", "", "", 0, ErrInvalidAuthOpts
|
||||
}
|
||||
}
|
||||
if len(jwtSignMethod) == 0 {
|
||||
return "", "", "", 0, ErrInvalidAuthOpts
|
||||
}
|
||||
return jwtSignMethod, jwtPubKeyPath, jwtPrivKeyPath, ttl, nil
|
||||
}
|
||||
|
||||
func newTokenProviderJWT(lg *zap.Logger, opts map[string]string) (*tokenJWT, error) {
|
||||
jwtSignMethod, jwtPubKeyPath, jwtPrivKeyPath, ttl, err := prepareOpts(lg, opts)
|
||||
func newTokenProviderJWT(lg *zap.Logger, optMap map[string]string) (*tokenJWT, error) {
|
||||
var err error
|
||||
var opts jwtOptions
|
||||
err = opts.ParseWithDefaults(optMap)
|
||||
if err != nil {
|
||||
if lg != nil {
|
||||
lg.Warn("problem loading JWT options", zap.Error(err))
|
||||
} else {
|
||||
plog.Errorf("problem loading JWT options: %s", err)
|
||||
}
|
||||
return nil, ErrInvalidAuthOpts
|
||||
}
|
||||
|
||||
if ttl == 0 {
|
||||
ttl = 5 * time.Minute
|
||||
var keys = make([]string, 0, len(optMap))
|
||||
for k := range optMap {
|
||||
if !knownOptions[k] {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
if lg != nil {
|
||||
lg.Warn("unknown JWT options", zap.Strings("keys", keys))
|
||||
} else {
|
||||
plog.Warningf("unknown JWT options: %v", keys)
|
||||
}
|
||||
}
|
||||
|
||||
key, err := opts.Key()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &tokenJWT{
|
||||
lg: lg,
|
||||
ttl: ttl,
|
||||
lg: lg,
|
||||
ttl: opts.TTL,
|
||||
signMethod: opts.SignMethod,
|
||||
key: key,
|
||||
}
|
||||
|
||||
t.signMethod = jwtSignMethod
|
||||
|
||||
verifyBytes, err := ioutil.ReadFile(jwtPubKeyPath)
|
||||
if err != nil {
|
||||
if lg != nil {
|
||||
lg.Warn(
|
||||
"failed to read JWT public key",
|
||||
zap.String("public-key-path", jwtPubKeyPath),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
plog.Errorf("failed to read public key (%s) for jwt: %s", jwtPubKeyPath, err)
|
||||
switch t.signMethod.(type) {
|
||||
case *jwt.SigningMethodECDSA:
|
||||
if _, ok := t.key.(*ecdsa.PublicKey); ok {
|
||||
t.verifyOnly = true
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
t.verifyKey, err = jwt.ParseRSAPublicKeyFromPEM(verifyBytes)
|
||||
if err != nil {
|
||||
if lg != nil {
|
||||
lg.Warn(
|
||||
"failed to parse JWT public key",
|
||||
zap.String("public-key-path", jwtPubKeyPath),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
plog.Errorf("failed to parse public key (%s): %s", jwtPubKeyPath, err)
|
||||
case *jwt.SigningMethodRSA, *jwt.SigningMethodRSAPSS:
|
||||
if _, ok := t.key.(*rsa.PublicKey); ok {
|
||||
t.verifyOnly = true
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signBytes, err := ioutil.ReadFile(jwtPrivKeyPath)
|
||||
if err != nil {
|
||||
if lg != nil {
|
||||
lg.Warn(
|
||||
"failed to read JWT private key",
|
||||
zap.String("private-key-path", jwtPrivKeyPath),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
plog.Errorf("failed to read private key (%s) for jwt: %s", jwtPrivKeyPath, err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
t.signKey, err = jwt.ParseRSAPrivateKeyFromPEM(signBytes)
|
||||
if err != nil {
|
||||
if lg != nil {
|
||||
lg.Warn(
|
||||
"failed to parse JWT private key",
|
||||
zap.String("private-key-path", jwtPrivKeyPath),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
plog.Errorf("failed to parse private key (%s): %s", jwtPrivKeyPath, err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return t, nil
|
||||
|
Reference in New Issue
Block a user