diff --git a/YABA.API/Extensions/ControllerExtensions.cs b/YABA.API/Extensions/ControllerExtensions.cs index 1ef4f12..5c0622c 100644 --- a/YABA.API/Extensions/ControllerExtensions.cs +++ b/YABA.API/Extensions/ControllerExtensions.cs @@ -1,6 +1,4 @@ using Microsoft.AspNetCore.Mvc; -using System.Security.Claims; -using YABA.Common.Extensions; using YABA.Common.Lookups; namespace YABA.API.Extensions @@ -9,21 +7,15 @@ namespace YABA.API.Extensions { public static string GetAuthProviderId(this ControllerBase controller) { - return GetCustomClaim(controller, ClaimsLookup.AuthProviderId); + return controller.User.Identity.GetCustomClaim(ClaimsLookup.AuthProviderId); } public static int GetUserId(this ControllerBase controller) { - var isValidUserId = int.TryParse(GetCustomClaim(controller, ClaimsLookup.UserId), out int userId); + var isValidUserId = int.TryParse(controller.User.Identity.GetCustomClaim(ClaimsLookup.UserId), out int userId); return isValidUserId ? userId : 0; } - public static string GetCustomClaim(this ControllerBase controller, ClaimsLookup claim) - { - var claimsIdentity = controller.User.Identity as ClaimsIdentity; - return claimsIdentity.FindFirst(claim.GetClaimName())?.Value.ToString(); - } - public static string GetIpAddress(this ControllerBase controller) { if (controller.Request.Headers.ContainsKey("X-Forwarded-For")) diff --git a/YABA.API/Extensions/UserIdentityExtensions.cs b/YABA.API/Extensions/UserIdentityExtensions.cs new file mode 100644 index 0000000..89cec1c --- /dev/null +++ b/YABA.API/Extensions/UserIdentityExtensions.cs @@ -0,0 +1,18 @@ +using System.Security.Claims; +using System.Security.Principal; +using YABA.Common.Extensions; +using YABA.Common.Lookups; + +namespace YABA.API.Extensions +{ + public static class UserIdentityExtensions + { + public static string GetAuthProviderId(this IIdentity identity) => GetCustomClaim(identity, ClaimsLookup.AuthProviderId); + + public static string GetCustomClaim(this IIdentity identity, ClaimsLookup claim) + { + var claimsIdentity = identity as ClaimsIdentity; + return claimsIdentity.FindFirst(claim.GetClaimName())?.Value.ToString(); + } + } +} diff --git a/YABA.API/Middlewares/AddCustomClaimsMiddleware.cs b/YABA.API/Middlewares/AddCustomClaimsMiddleware.cs new file mode 100644 index 0000000..f837fdc --- /dev/null +++ b/YABA.API/Middlewares/AddCustomClaimsMiddleware.cs @@ -0,0 +1,36 @@ +using System.Security.Claims; +using YABA.API.Extensions; +using YABA.Common.Extensions; +using YABA.Common.Lookups; +using YABA.Service.Interfaces; + +namespace YABA.API.Middlewares +{ + public class AddCustomClaimsMiddleware + { + private readonly RequestDelegate _next; + + public AddCustomClaimsMiddleware(RequestDelegate next) + { + _next = next; + } + + public async Task InvokeAsync(HttpContext httpContext, IUserService userService) + { + if (httpContext.User != null && httpContext.User.Identity.IsAuthenticated) + { + var claims = new List(); + + var userAuthProviderId = httpContext.User.Identity.GetAuthProviderId(); + + if (!string.IsNullOrEmpty(userAuthProviderId)) + { + var userId = userService.GetUserId(userAuthProviderId); + httpContext.User.Identities.FirstOrDefault().AddClaim(new Claim(ClaimsLookup.UserId.GetClaimName(), userId.ToString())); + } + } + + await _next(httpContext); + } + } +} diff --git a/YABA.Service/Interfaces/IUserService.cs b/YABA.Service/Interfaces/IUserService.cs index dfe1c29..56871a5 100644 --- a/YABA.Service/Interfaces/IUserService.cs +++ b/YABA.Service/Interfaces/IUserService.cs @@ -9,5 +9,6 @@ namespace YABA.Service.Interfaces { public bool IsUserRegistered(string authProviderId); public UserDTO RegisterUser(string authProviderId); + public int GetUserId(string authProviderId); } } diff --git a/YABA.Service/UserService.cs b/YABA.Service/UserService.cs index 0e607a8..337c07f 100644 --- a/YABA.Service/UserService.cs +++ b/YABA.Service/UserService.cs @@ -38,5 +38,11 @@ namespace YABA.Service var registedUser = _context.Users.Add(userToRegister); return _context.SaveChanges() > 0 ? new UserDTO(registedUser.Entity) : null; } + + public int GetUserId(string authProviderId) + { + var user = _roContext.Users.FirstOrDefault(x => x.Auth0Id == authProviderId); + return user != null ? user.Id : 0; + } } }