You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
auth2/cmd/microauth2sqld/handler/auth2.go

380 lines
10 KiB
Go

package handler
import (
"context"
"crypto/ed25519"
"crypto/rsa"
"fmt"
"net/http"
"time"
"github.com/golang-jwt/jwt/v4"
"go-micro.dev/v4/errors"
"go-micro.dev/v4/util/log"
"google.golang.org/protobuf/types/known/emptypb"
"jochum.dev/jo-micro/auth2"
"jochum.dev/jo-micro/auth2/cmd/microauth2sqld/config"
"jochum.dev/jo-micro/auth2/cmd/microauth2sqld/db"
"jochum.dev/jo-micro/auth2/internal/argon2"
"jochum.dev/jo-micro/auth2/internal/proto/authpb"
"jochum.dev/jo-micro/auth2/plugins/verifier/endpointroles"
"jochum.dev/jo-micro/auth2/shared/sjwt"
"jochum.dev/jo-micro/components"
"jochum.dev/jo-micro/logruscomponent"
"jochum.dev/jo-micro/router"
)
// InitConfig carries everything Handler.Init needs: the JWT audiences, the
// token lifetimes and the encoded key pairs for access and refresh tokens.
type InitConfig struct {
	// Audiences is copied into the Handler and used as the "aud" claim of
	// issued tokens.
	Audiences []string

	// Token lifetimes; the handler multiplies them by time.Second.
	RefreshTokenExpiry, AccessTokenExpiry int64

	// Encoded key material, decoded in Init via sjwt.DecodeKeyPair.
	AccessTokenPubKey, AccessTokenPrivKey   string
	RefreshTokenPubKey, RefreshTokenPrivKey string
}
// Handler implements the auth service endpoints. Init must be called before
// the handler serves requests, because it sets the component registry and
// decodes the signing/verification keys.
type Handler struct {
	cReg *components.Registry

	audiences                             []string
	refreshTokenExpiry, accessTokenExpiry int64

	// Decoded key material as returned by sjwt.DecodeKeyPair; the concrete
	// types are inspected when signing and verifying tokens.
	accessTokenPubKey, accessTokenPrivKey   any
	refreshTokenPubKey, refreshTokenPrivKey any
}
// NewHandler allocates an empty Handler; call Init on it before use.
func NewHandler() *Handler {
	h := new(Handler)
	return h
}
// Init wires the handler to the component registry, decodes both signing key
// pairs, registers the HTTP routes and installs the endpoint role verifier.
//
// It returns an error if either key pair cannot be decoded; the error states
// which pair (access or refresh) failed and wraps the decoder's error.
// router.MustReg, logruscomponent.MustReg and auth2.ClientAuthMustReg are
// Must-style lookups and, by that convention, presumably panic when the
// component is missing — confirm against those packages.
func (h *Handler) Init(cReg *components.Registry, c InitConfig) error {
	h.cReg = cReg
	h.audiences = c.Audiences
	h.accessTokenExpiry = c.AccessTokenExpiry
	h.refreshTokenExpiry = c.RefreshTokenExpiry

	pub, priv, err := sjwt.DecodeKeyPair(c.AccessTokenPubKey, c.AccessTokenPrivKey)
	if err != nil {
		// Wrap so operators can tell which of the two key pairs is broken.
		return fmt.Errorf("decoding access token key pair: %w", err)
	}
	h.accessTokenPubKey = pub
	h.accessTokenPrivKey = priv

	pub, priv, err = sjwt.DecodeKeyPair(c.RefreshTokenPubKey, c.RefreshTokenPrivKey)
	if err != nil {
		return fmt.Errorf("decoding refresh token key pair: %w", err)
	}
	h.refreshTokenPubKey = pub
	h.refreshTokenPrivKey = priv

	h.registerRoutes()
	h.registerVerifier()
	return nil
}

// registerRoutes exposes the auth service endpoints through the HTTP router.
func (h *Handler) registerRoutes() {
	r := router.MustReg(h.cReg)
	r.Add(
		router.NewRoute(
			router.Method(router.MethodGet),
			router.Path("/"),
			router.Endpoint(authpb.AuthService.List),
			router.Params("limit", "offset"),
			router.AuthRequired(),
			router.RatelimitUser("1-S", "10-M"),
		),
		router.NewRoute(
			router.Method(router.MethodPost),
			router.Path("/login"),
			router.Endpoint(authpb.AuthService.Login),
			router.RatelimitClientIP("1-S", "10-M", "30-H", "100-D"),
		),
		router.NewRoute(
			router.Method(router.MethodPost),
			router.Path("/register"),
			router.Endpoint(authpb.AuthService.Register),
			router.RatelimitClientIP("1-M", "10-H", "50-D"),
		),
		router.NewRoute(
			router.Method(router.MethodPost),
			router.Path("/refresh"),
			router.Endpoint(authpb.AuthService.Refresh),
			router.RatelimitClientIP("1-M", "10-H", "50-D"),
		),
		router.NewRoute(
			router.Method(router.MethodDelete),
			router.Path("/:userId"),
			router.Endpoint(authpb.AuthService.Delete),
			router.Params("userId"),
			router.AuthRequired(),
			router.RatelimitUser("1-S", "10-M"),
		),
		router.NewRoute(
			router.Method(router.MethodGet),
			router.Path("/:userId"),
			router.Endpoint(authpb.AuthService.Detail),
			router.Params("userId"),
			router.AuthRequired(),
			router.RatelimitUser("100-M"),
		),
		router.NewRoute(
			router.Method(router.MethodPut),
			router.Path("/:userId/roles"),
			router.Endpoint(authpb.AuthService.UpdateRoles),
			router.Params("userId"),
			router.AuthRequired(),
			router.RatelimitUser("1-M"),
		),
	)
}

// registerVerifier installs the per-endpoint role rules on the client auth
// plugin.
func (h *Handler) registerVerifier() {
	authVerifier := endpointroles.NewVerifier(
		endpointroles.WithLogrus(logruscomponent.MustReg(h.cReg).Logger()),
	)
	authVerifier.AddRules(
		endpointroles.RouterRule,
		endpointroles.NewRule(
			endpointroles.Endpoint(authpb.AuthService.Delete),
			endpointroles.RolesAllow(auth2.RolesServiceAndAdmin),
		),
		endpointroles.NewRule(
			endpointroles.Endpoint(authpb.AuthService.Detail),
			endpointroles.RolesAllow(auth2.RolesServiceAndUsersAndAdmin),
		),
		endpointroles.NewRule(
			endpointroles.Endpoint(authpb.AuthService.Inspect),
			endpointroles.RolesAllow(auth2.RolesServiceAndUsersAndAdmin),
		),
		endpointroles.NewRule(
			endpointroles.Endpoint(authpb.AuthService.List),
			endpointroles.RolesAllow(auth2.RolesServiceAndAdmin),
		),
		endpointroles.NewRule(
			endpointroles.Endpoint(authpb.AuthService.Login),
			endpointroles.RolesAllow(auth2.RolesAllAndAnon),
		),
		endpointroles.NewRule(
			endpointroles.Endpoint(authpb.AuthService.Refresh),
			endpointroles.RolesAllow(auth2.RolesAllAndAnon),
		),
		endpointroles.NewRule(
			endpointroles.Endpoint(authpb.AuthService.Register),
			endpointroles.RolesAllow(auth2.RolesAllAndAnon),
		),
		endpointroles.NewRule(
			endpointroles.Endpoint(authpb.AuthService.UpdateRoles),
			endpointroles.RolesAllow(auth2.RolesAdmin),
		),
	)
	auth2.ClientAuthMustReg(h.cReg).Plugin().AddVerifier(authVerifier)
}
// Stop is a no-op: the handler holds nothing that needs releasing, so it
// always returns nil.
func (*Handler) Stop() error {
	return nil
}
// List fetches one page of users (in.Limit / in.Offset) and appends each as
// an authpb.User carrying id, username and email to out.Data. Roles are not
// included in list entries.
func (h *Handler) List(ctx context.Context, in *authpb.ListRequest, out *authpb.UserListReply) error {
	users, err := db.UserList(h.cReg, ctx, in.Limit, in.Offset)
	if err != nil {
		return err
	}
	for _, u := range users {
		entry := &authpb.User{
			Id:       u.ID.String(),
			Username: u.Username,
			Email:    u.Email,
		}
		out.Data = append(out.Data, entry)
	}
	return nil
}
// Detail loads the user identified by in.UserId and copies its id, email,
// username and roles into out.
func (h *Handler) Detail(ctx context.Context, in *authpb.UserIDRequest, out *authpb.User) error {
	user, err := db.UserDetail(h.cReg, ctx, in.UserId)
	if err != nil {
		return err
	}
	out.Id, out.Email = user.ID.String(), user.Email
	out.Username, out.Roles = user.Username, user.Roles
	return nil
}
// Delete removes the user identified by in.UserId via db.UserDelete; out is
// left untouched.
func (h *Handler) Delete(ctx context.Context, in *authpb.UserIDRequest, out *emptypb.Empty) error {
	if err := db.UserDelete(h.cReg, ctx, in.UserId); err != nil {
		return err
	}
	return nil
}
// UpdateRoles stores in.Roles for the user in.UserId via db.UserUpdateRoles
// and writes the resulting user record into out.
func (h *Handler) UpdateRoles(ctx context.Context, in *authpb.UpdateRolesRequest, out *authpb.User) error {
	updated, err := db.UserUpdateRoles(h.cReg, ctx, in.UserId, in.Roles)
	if err != nil {
		return err
	}
	out.Id, out.Email = updated.ID.String(), updated.Email
	out.Username, out.Roles = updated.Username, updated.Roles
	return nil
}
// Register creates a new account holding the auth2.ROLE_USER role and writes
// the stored user into out.
//
// An empty username or password is rejected with 400. The reserved service
// username and any failure from db.UserCreate are reported as 409 "User
// already exists"; the underlying creation error is logged so genuine
// database failures are not silently discarded.
func (h *Handler) Register(ctx context.Context, in *authpb.RegisterRequest, out *authpb.User) error {
	if in.Username == "" || in.Password == "" {
		return errors.New(config.Name, "Username and password are required", http.StatusBadRequest)
	}
	// The service account name is reserved; answer as if it already exists.
	if in.Username == auth2.ROLE_SERVICE {
		return errors.New(config.Name, "User already exists", http.StatusConflict)
	}

	hash, err := argon2.Hash(in.Password, argon2.DefaultParams)
	if err != nil {
		return err
	}

	result, err := db.UserCreate(h.cReg, ctx, in.Username, hash, in.Email, []string{auth2.ROLE_USER})
	if err != nil {
		// Presumably a uniqueness violation in most cases, but keep the real
		// cause in the logs instead of dropping it.
		log.Error(err)
		return errors.New(config.Name, "User already exists", http.StatusConflict)
	}

	out.Id = result.ID.String()
	out.Email = result.Email
	out.Username = result.Username
	out.Roles = result.Roles
	return nil
}
// genTokens issues a signed refresh token and a signed access token for user
// and writes both tokens and their expiry times (Unix seconds) into out.
//
// The refresh token expires after h.refreshTokenExpiry seconds and carries no
// roles; the access token expires after h.accessTokenExpiry seconds and
// carries user.Roles. An unsupported private key type yields an error.
func (h *Handler) genTokens(ctx context.Context, user *db.User, out *authpb.Token) error {
	// One timestamp so iat/nbf/exp are consistent within and across tokens.
	now := time.Now()

	refreshClaims := sjwt.JWTClaims{
		RegisteredClaims: &jwt.RegisteredClaims{
			Issuer:   config.Name,
			Subject:  user.Username,
			Audience: h.audiences,
			// Refresh tokens use their own lifetime, not the access lifetime.
			ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(h.refreshTokenExpiry) * time.Second)),
			NotBefore: jwt.NewNumericDate(now),
			IssuedAt:  jwt.NewNumericDate(now),
			ID:        user.ID.String(),
		},
	}
	if err := refreshClaims.Valid(); err != nil {
		return err
	}
	refreshSignedToken, err := signClaims(h.refreshTokenPrivKey, refreshClaims)
	if err != nil {
		return fmt.Errorf("signing refresh token: %w", err)
	}

	accessClaims := sjwt.JWTClaims{
		RegisteredClaims: &jwt.RegisteredClaims{
			Issuer:    config.Name,
			Subject:   user.Username,
			Audience:  h.audiences,
			ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(h.accessTokenExpiry) * time.Second)),
			NotBefore: jwt.NewNumericDate(now),
			IssuedAt:  jwt.NewNumericDate(now),
			ID:        user.ID.String(),
		},
		Roles: user.Roles,
	}
	if err := accessClaims.Valid(); err != nil {
		return err
	}
	accessSignedToken, err := signClaims(h.accessTokenPrivKey, accessClaims)
	if err != nil {
		return fmt.Errorf("signing access token: %w", err)
	}

	out.Id = user.ID.String()
	out.RefreshToken = refreshSignedToken
	out.RefreshTokenExpiresAt = refreshClaims.ExpiresAt.Unix()
	out.AccessToken = accessSignedToken
	out.AccessTokenExpiresAt = accessClaims.ExpiresAt.Unix()
	return nil
}

// signClaims signs claims with key, picking the algorithm from the key's
// concrete type: RS512 for *rsa.PrivateKey, EdDSA for ed25519.PrivateKey.
func signClaims(key any, claims jwt.Claims) (string, error) {
	var method jwt.SigningMethod
	switch key.(type) {
	case *rsa.PrivateKey:
		method = jwt.SigningMethodRS512
	case ed25519.PrivateKey:
		method = jwt.SigningMethodEdDSA
	default:
		// Without this the token would stay nil and SignedString would panic.
		return "", fmt.Errorf("unsupported signing key type %T", key)
	}
	return jwt.NewWithClaims(method, claims).SignedString(key)
}
// Login checks in.Username / in.Password and, on success, issues a new token
// pair via genTokens.
//
// A failed user lookup and a wrong password both produce the same 401
// response. On a failed lookup a throwaway argon2 hash is still computed so
// the response time does not cheaply reveal whether the account exists
// (roughly comparable cost assuming stored hashes use argon2.DefaultParams —
// confirm).
func (h *Handler) Login(ctx context.Context, in *authpb.LoginRequest, out *authpb.Token) error {
	user, err := db.UserFindByUsername(h.cReg, ctx, in.Username)
	if err != nil {
		log.Error(err)
		// Timing equalization only; the hash and its error are deliberately
		// ignored.
		_, _ = argon2.Hash(in.Password, argon2.DefaultParams)
		return errors.New(config.Name, "Wrong username or password", http.StatusUnauthorized)
	}

	ok, err := argon2.Verify(in.Password, user.Password)
	if err != nil {
		return err
	}
	if !ok {
		return errors.New(config.Name, "Wrong username or password", http.StatusUnauthorized)
	}

	return h.genTokens(ctx, user, out)
}
// Refresh exchanges a valid refresh token for a new token pair.
//
// The token is verified with the refresh public key. Before verification,
// the token's signing method must belong to the family matching that key
// (RSA for *rsa.PublicKey, EdDSA for ed25519.PublicKey), which guards against
// algorithm-confusion tokens. The user is then re-loaded from the database
// by the token's ID claim before new tokens are issued.
func (h *Handler) Refresh(ctx context.Context, in *authpb.RefreshTokenRequest, out *authpb.Token) error {
	claims := sjwt.JWTClaims{}
	_, err := jwt.ParseWithClaims(in.RefreshToken, &claims, func(token *jwt.Token) (any, error) {
		switch h.refreshTokenPubKey.(type) {
		case *rsa.PublicKey:
			if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
				return nil, fmt.Errorf("unexpected signing method %v", token.Header["alg"])
			}
		case ed25519.PublicKey:
			if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
				return nil, fmt.Errorf("unexpected signing method %v", token.Header["alg"])
			}
		}
		return h.refreshTokenPubKey, nil
	})
	if err != nil {
		return errors.New(config.Name, fmt.Sprintf("checking the RefreshToken: %s", err), http.StatusBadRequest)
	}

	// Re-check claims (expiration); report failures as a client error rather
	// than an untyped error that surfaces as an internal failure.
	if err = claims.Valid(); err != nil {
		return errors.New(config.Name, fmt.Sprintf("claims invalid: %s", err), http.StatusBadRequest)
	}

	user, err := db.UserFindById(h.cReg, ctx, claims.ID)
	if err != nil {
		return errors.New(config.Name, fmt.Sprintf("error fetching the user: %s", err), http.StatusUnauthorized)
	}

	return h.genTokens(ctx, user, out)
}
// Inspect copies the caller's identity, stored in ctx under the "user" key
// as an auth2.User, into out.
//
// It returns a BadRequest error when no user is present and an
// InternalServerError when the context value has an unexpected type, instead
// of panicking on a failed type assertion.
func (h *Handler) Inspect(ctx context.Context, in *emptypb.Empty, out *authpb.JWTClaims) error {
	u := ctx.Value("user")
	if u == nil {
		return errors.BadRequest("auth2/handler.Inspect|no user", "no user found in context")
	}

	user, ok := u.(auth2.User)
	if !ok {
		return errors.InternalServerError("auth2/handler.Inspect|bad user", "unexpected user type %T in context", u)
	}

	out.Id = user.Id
	out.Type = user.Type
	out.Issuer = user.Issuer
	out.Metadata = user.Metadata
	out.Roles = user.Roles
	out.Scopes = user.Scopes
	return nil
}