Skip to end of metadata
Go to start of metadata

You are viewing an old version of this page. View the current version.

Compare with Current View Page History

« Previous Version 2 Current »

Summary

In the current implementation we use an attribute to inject authentication into the controller pipeline. ASP.NET Core recommends using middleware instead, which plugs authentication in further up the request pipeline; this in turn lets us use the [Authorize] attribute on controllers that require authentication and/or authorization. Authorization itself remains unchanged and is still handled by the NHibernate process as before.

Challenges

The challenges in our solution are twofold. First, how do we authenticate using OAuth, and how do we secure the Identities endpoint, which uses a policy? ASP.NET Core has built-in support for OAuth; however, we are not using an external provider, though the application could be modified to do so. Second, we need to pass the claims information down into the system. By default, ASP.NET Core places the claims on the request once it has been authenticated; however, we have the unique challenge of needing that information in code that runs outside the request (potentially on other threads). ASP.NET Core does not provide the static ClaimsPrincipal.Current (or HttpContext.Current) that classic ASP.NET offered, and Microsoft's recommendation is to pass the claims principal resolved from the request context to any services that need it.

Solution

To pass the claims to code outside the controllers, we call the AddHttpContextAccessor extension method in the Startup class. It registers IHttpContextAccessor so that services can resolve the current HttpContext (and its claims principal) much as before, even though this is not the practice Microsoft recommends:

            // Registers IHttpContextAccessor so that services outside the controllers can resolve the current
            // HttpContext and therefore its ClaimsPrincipal. The ASP.NET Core team does not consider this best
            // practice: the recommended approach is to take the ClaimsPrincipal available on the controller and
            // pass it explicitly to the services that need it.
            // c.f. https://docs.microsoft.com/en-us/aspnet/core/migration/claimsprincipal-current?view=aspnetcore-3.1
            services.AddHttpContextAccessor();

To apply the authentication challenge, we implement a custom authentication handler as follows:

using System.Net.Http.Headers;
using System.Security.Claims;
using System.Text.Encodings.Web;
using System.Threading.Tasks;
using EdFi.Ods.Api.NetCore.Providers;
using EdFi.Ods.Common.Security;
using Microsoft.AspNetCore.Authentication;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace EdFi.Ods.Api.NetCore.Middleware
{
    /// <summary>
    /// Authentication handler that passes the request's <c>Authorization</c> header to
    /// <see cref="IAuthenticationProvider"/> and, on success, stores the resulting API key context via
    /// <see cref="IApiKeyContextProvider"/> and returns a ticket carrying the resolved claims identity.
    /// </summary>
    public class EdFiOAuthAuthenticationHandler : AuthenticationHandler<AuthenticationSchemeOptions>
    {
        private const string AuthorizationHeaderName = "Authorization";
        private const string InvalidAuthorizationHeaderMessage = "Invalid Authorization Header";

        private readonly IApiKeyContextProvider _apiKeyContextProvider;
        private readonly IAuthenticationProvider _authenticationProvider;

        public EdFiOAuthAuthenticationHandler(
            IOptionsMonitor<AuthenticationSchemeOptions> options,
            ILoggerFactory logger,
            UrlEncoder encoder,
            ISystemClock clock,
            IAuthenticationProvider authenticationProvider,
            IApiKeyContextProvider apiKeyContextProvider)
            : base(options, logger, encoder, clock)
        {
            _authenticationProvider = authenticationProvider;
            _apiKeyContextProvider = apiKeyContextProvider;
        }

        /// <summary>
        /// Authenticates the current request.
        /// </summary>
        /// <returns>
        /// <see cref="AuthenticateResult.NoResult"/> when no Authorization header is present or the provider
        /// produced nothing; <see cref="AuthenticateResult.Fail(string)"/> when the header is malformed or the
        /// provider throws; otherwise the provider's own result, or a success ticket for its claims identity.
        /// </returns>
        protected override async Task<AuthenticateResult> HandleAuthenticateAsync()
        {
            // StringValues converts implicitly to string; a missing header converts to null.
            string authorizationHeader = Request.Headers[AuthorizationHeaderName];

            // No credentials were offered at all: this is not a failure, so anonymous endpoints
            // (e.g. the token endpoint) are not reported as failed authentications.
            if (string.IsNullOrWhiteSpace(authorizationHeader))
            {
                return AuthenticateResult.NoResult();
            }

            if (!AuthenticationHeaderValue.TryParse(authorizationHeader, out var authHeader))
            {
                return AuthenticateResult.Fail(InvalidAuthorizationHeaderMessage);
            }

            AuthenticationResult authenticationResult;

            try
            {
                authenticationResult = await _authenticationProvider.GetAuthenticationResultAsync(authHeader);
            }
            catch (System.Exception ex)
            {
                // Preserve the previous outcome (a failed authentication) but no longer hide the cause.
                Logger.LogWarning(ex, "Unable to authenticate the request using the supplied Authorization header.");
                return AuthenticateResult.Fail(InvalidAuthorizationHeaderMessage);
            }

            if (authenticationResult == null)
            {
                return AuthenticateResult.NoResult();
            }

            // The provider may decide the outcome itself (e.g. an explicit failure).
            if (authenticationResult.AuthenticateResult != null)
            {
                return authenticationResult.AuthenticateResult;
            }

            if (authenticationResult.ClaimsIdentity == null)
            {
                return AuthenticateResult.NoResult();
            }

            // Set the api key context
            _apiKeyContextProvider.SetApiKeyContext(authenticationResult.ApiKeyContext);

            var principal = new ClaimsPrincipal(authenticationResult.ClaimsIdentity);
            var ticket = new AuthenticationTicket(principal, Scheme.Name);

            return AuthenticateResult.Success(ticket);
        }
    }
}

The token controller is updated as follows:

using System;
using System.Text;
using System.Threading.Tasks;
using EdFi.Ods.Api.Common.Models.Tokens;
using EdFi.Ods.Api.NetCore.Providers;
using EdFi.Ods.Common.Extensions;
using log4net;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;

namespace EdFi.Ods.Api.NetCore.Controllers
{
    /// <summary>
    /// OAuth token endpoint. Accepts a token request either as a JSON body, or as a form-urlencoded body
    /// whose client credentials are supplied through HTTP Basic authentication, and delegates the request to
    /// <see cref="ITokenRequestProvider"/>.
    /// </summary>
    [ApiExplorerSettings(IgnoreApi = true)]
    [ApiController]
    [Route("oauth/token")]
    [Produces("application/json")]
    [AllowAnonymous]
    public class TokenController : ControllerBase
    {
        private const string AuthorizationHeaderName = "Authorization";

        private readonly ILog _logger = LogManager.GetLogger(typeof(TokenController));
        private readonly ITokenRequestProvider _requestProvider;

        public TokenController(ITokenRequestProvider provider)
        {
            _requestProvider = provider;
        }

        /// <summary>Handles a token request posted as JSON.</summary>
        [HttpPost]
        [AllowAnonymous]
        [Consumes("application/json")]
        public async Task<IActionResult> Post([FromBody] TokenRequest tokenRequest)
        {
            return await HandleTokenRequestAsync(tokenRequest);
        }

        /// <summary>
        /// Handles a form-urlencoded token request. The client id and secret are taken from a
        /// <c>Basic</c> Authorization header; a missing or malformed header yields 401 Unauthorized.
        /// </summary>
        [HttpPost]
        [AllowAnonymous]
        [Consumes("application/x-www-form-urlencoded")]
        public async Task<IActionResult> PostFromForm([FromForm] TokenRequest tokenRequest)
        {
            // Look for the authorization header, since we MUST support this method of authorization
            // https://tools.ietf.org/html/rfc6749#section-2.3.1
            // The header has the form: Basic <base64 of "client_id:client_secret">

            if (!Request.Headers.ContainsKey(AuthorizationHeaderName))
            {
                _logger.Debug("Header is missing authorization credentials");
                return Unauthorized();
            }

            string[] encodedClientAndSecret = Request.Headers[AuthorizationHeaderName]
                .ToString()
                .Split(' ');

            if (encodedClientAndSecret.Length != 2)
            {
                _logger.Debug("Header is not in the form of Basic <encoded credentials>");
                return Unauthorized();
            }

            if (!encodedClientAndSecret[0]
                .EqualsIgnoreCase("Basic"))
            {
                _logger.Debug("Authorization scheme is not Basic");
                return Unauthorized();
            }

            // Malformed Base64 is a client error, not a server fault: reject it like other bad headers.
            if (!TryBase64Decode(encodedClientAndSecret[1], out string decodedCredentials))
            {
                _logger.Debug("Authorization credentials are not valid Base64");
                return Unauthorized();
            }

            // The format is <client_id>:<client_secret>. Per RFC 7617 the user-id cannot contain a colon but
            // the password can, so split on the first colon only.
            string[] clientIdAndSecret = decodedCredentials.Split(new[] {':'}, 2);

            if (clientIdAndSecret.Length == 2)
            {
                tokenRequest.Client_id = clientIdAndSecret[0];
                tokenRequest.Client_secret = clientIdAndSecret[1];
            }

            return await HandleTokenRequestAsync(tokenRequest);
        }

        // Runs the token request through the provider and maps the outcome to 400 (error) or 200 (token).
        private async Task<IActionResult> HandleTokenRequestAsync(TokenRequest tokenRequest)
        {
            var authenticationResult = await _requestProvider.HandleAsync(tokenRequest);

            if (authenticationResult.TokenError != null)
            {
                return BadRequest(authenticationResult.TokenError);
            }

            return Ok(authenticationResult.TokenResponse);
        }

        // Decodes a Base64 string as UTF-8, returning false instead of throwing on malformed input.
        private static bool TryBase64Decode(string encoded, out string decoded)
        {
            try
            {
                decoded = Encoding.UTF8.GetString(Convert.FromBase64String(encoded));
                return true;
            }
            catch (FormatException)
            {
                decoded = null;
                return false;
            }
        }
    }
}

The request provider was refactored to use async methods. It also no longer returns an IActionResult, which removes the cross-cutting concern of building HTTP responses; the controller now maps the result to a response.

using System.Linq;
using System.Threading.Tasks;
using EdFi.Ods.Api.Common.Models.Tokens;
using EdFi.Ods.Api.NetCore.Models.ClientCredentials;
using EdFi.Ods.Common.Extensions;
using EdFi.Ods.Common.Security;
using EdFi.Ods.Sandbox.Repositories;
using Microsoft.Extensions.Hosting.Internal;

namespace EdFi.Ods.Api.NetCore.Providers
{
    /// <summary>
    /// Handles OAuth token requests for the <c>client_credentials</c> grant type: authenticates the client,
    /// validates the optional EducationOrganizationId scope, and issues a new access token.
    /// </summary>
    public class ClientCredentialsTokenRequestProvider
        : ITokenRequestProvider
    {
        private readonly IApiClientAuthenticator _apiClientAuthenticator;
        private readonly IClientAppRepo _clientAppRepo;

        public ClientCredentialsTokenRequestProvider(IClientAppRepo clientAppRepo, IApiClientAuthenticator apiClientAuthenticator)
        {
            _clientAppRepo = clientAppRepo;
            _apiClientAuthenticator = apiClientAuthenticator;
        }

        /// <summary>
        /// Processes a token request.
        /// </summary>
        /// <param name="tokenRequest">The grant type, client credentials and optional scope.</param>
        /// <returns>
        /// An <see cref="AuthenticationResponse"/> holding either a <c>TokenError</c>
        /// (UnsupportedGrantType, InvalidClient or InvalidScope) or the issued <c>TokenResponse</c>.
        /// </returns>
        public async Task<AuthenticationResponse> HandleAsync(TokenRequest tokenRequest)
        {
            // Only handle the "client_credentials" grant type
            if (!RequestIsRequiredGrantType())
            {
                return new AuthenticationResponse {TokenError = new TokenError(TokenErrorType.UnsupportedGrantType)};
            }

            // Verify client_id and client_secret are present
            if (!HasIdAndSecret())
            {
                return new AuthenticationResponse {TokenError = new TokenError(TokenErrorType.InvalidClient)};
            }

            // authenticate the client
            var authenticationResult = await _apiClientAuthenticator.TryAuthenticateAsync(
                tokenRequest.Client_id,
                tokenRequest.Client_secret);

            if (!authenticationResult.IsAuthenticated)
            {
                return new AuthenticationResponse {TokenError = new TokenError(TokenErrorType.InvalidClient)};
            }

            // get client information
            var client = await _clientAppRepo.GetClientAsync(authenticationResult.ApiClientIdentity.Key);

            // The client lookup can come back empty (e.g. the client was removed after authenticating);
            // report it as an invalid client rather than failing with a NullReferenceException.
            if (client == null)
            {
                return new AuthenticationResponse {TokenError = new TokenError(TokenErrorType.InvalidClient)};
            }

            // Treat an empty or whitespace-only scope as "no scope requested"
            string tokenRequestScope = string.IsNullOrWhiteSpace(tokenRequest.Scope)
                ? null
                : tokenRequest.Scope.Trim();

            // validate client is in scope
            if (tokenRequestScope != null)
            {
                if (!int.TryParse(tokenRequestScope, out int educationOrganizationScope))
                {
                    return new AuthenticationResponse
                    {
                        TokenError = new TokenError(
                            TokenErrorType.InvalidScope,
                            "The supplied 'scope' was not a number (it should be an EducationOrganizationId that is explicitly associated with the client).")
                    };
                }

                if (!client.ApplicationEducationOrganizations
                    .Any(x => x.EducationOrganizationId == educationOrganizationScope))
                {
                    return new AuthenticationResponse
                    {
                        TokenError = new TokenError(
                            TokenErrorType.InvalidScope,
                            "The client is not explicitly associated with the EducationOrganizationId specified in the requested 'scope'.")
                    };
                }
            }

            // create a new token
            var token = await _clientAppRepo.AddClientAccessTokenAsync(client.ApiClientId, tokenRequestScope);

            var tokenResponse = new TokenResponse();
            tokenResponse.SetToken(token.Id, (int) token.Duration.TotalSeconds, token.Scope);

            return new AuthenticationResponse {TokenResponse = tokenResponse};

            bool RequestIsRequiredGrantType() => tokenRequest.Grant_type.EqualsIgnoreCase("client_credentials");

            bool HasIdAndSecret()
                => !string.IsNullOrEmpty(tokenRequest.Client_secret) && !string.IsNullOrEmpty(tokenRequest.Client_id);
        }
    }
}

The ClientAppRepo was also refactored to use async methods and the IConfigurationRoot interface.

using System;
using System.Collections.Generic;
using System.Data.Entity;
using System.Data.Entity.Migrations;
using System.Data.SqlClient;
using System.Linq;
using System.Threading.Tasks;
using EdFi.Admin.DataAccess;
using EdFi.Admin.DataAccess.Contexts;
using EdFi.Admin.DataAccess.Extensions;
using EdFi.Ods.Common;
#if NETFRAMEWORK
using EdFi.Ods.Common.Configuration;
#elif NETSTANDARD
using Microsoft.Extensions.Configuration;
#endif
using EdFi.Ods.Common.Extensions;
using EdFi.Admin.DataAccess.Models;
using EdFi.Ods.Sandbox.Provisioners;

namespace EdFi.Ods.Sandbox.Repositories
{
    /// <summary>
    /// Admin database repository for users, vendors, applications, profiles, API clients and client access
    /// tokens. Sandbox creation and removal are delegated to <see cref="ISandboxProvisioner"/>. Each public
    /// method opens its own context from <see cref="IUsersContextFactory"/> and disposes it before returning.
    /// </summary>
    public class ClientAppRepo : IClientAppRepo
    {
        // Bearer token lifetime, in minutes, used when BearerTokenTimeoutMinutes is absent or not an integer.
        private const int DefaultDuration = 60;

        private readonly IUsersContextFactory _contextFactory;
        private readonly ISandboxProvisioner _provisioner;
        private readonly Lazy<int> _duration;
        private readonly Lazy<string> _defaultOperationalContextUri;
        private readonly Lazy<string> _defaultAppName;
        private readonly Lazy<string> _defaultClaimSetName;

#if NETFRAMEWORK
        public ClientAppRepo(
            IUsersContextFactory contextFactory,
            ISandboxProvisioner provisioner,
            IConfigValueProvider configValueProvider)
        {
            _contextFactory = Preconditions.ThrowIfNull(contextFactory, nameof(contextFactory));
            _provisioner = Preconditions.ThrowIfNull(provisioner, nameof(provisioner));
            Preconditions.ThrowIfNull(configValueProvider, nameof(configValueProvider));

            // Configuration values are read lazily, on first use.
            _duration = new Lazy<int>(
                () => ParseDurationOrDefault(configValueProvider.GetValue("BearerTokenTimeoutMinutes")));

            _defaultOperationalContextUri = new Lazy<string>(() => configValueProvider.GetValue("DefaultOperationalContextUri"));
            _defaultAppName = new Lazy<string>(() => configValueProvider.GetValue("DefaultApplicationName"));
            _defaultClaimSetName = new Lazy<string>(() => configValueProvider.GetValue("DefaultClaimSetName"));
        }
#elif NETSTANDARD
        public ClientAppRepo(
            IUsersContextFactory contextFactory,
            ISandboxProvisioner provisioner,
            IConfigurationRoot config)
        {
            _contextFactory = Preconditions.ThrowIfNull(contextFactory, nameof(contextFactory));
            _provisioner = Preconditions.ThrowIfNull(provisioner, nameof(provisioner));
            Preconditions.ThrowIfNull(config, nameof(config));

            // Configuration values are read lazily, on first use.
            _duration = new Lazy<int>(
                () => ParseDurationOrDefault(
                    config.GetSection("BearerTokenTimeoutMinutes")
                        .Value));

            _defaultOperationalContextUri = new Lazy<string>(
                () => config.GetSection("DefaultOperationalContextUri")
                    .Value);

            _defaultAppName = new Lazy<string>(
                () => config.GetSection("DefaultApplicationName")
                    .Value);

            _defaultClaimSetName = new Lazy<string>(
                () => config.GetSection("DefaultClaimSetName")
                    .Value);
        }
#endif

        // Parses the configured token lifetime in minutes, defaulting to DefaultDuration (1 hour).
        private static int ParseDurationOrDefault(string configuredMinutes)
            => int.TryParse(configuredMinutes, out int minutes)
                ? minutes
                : DefaultDuration;

        // Returns the profile with the given name, creating and saving it first when it does not exist.
        private Profile GetOrCreateProfile(string profileName)
        {
            using (var context = _contextFactory.CreateContext())
            {
                var profile = context.Profiles.FirstOrDefault(s => s.ProfileName == profileName);

                if (profile != null)
                {
                    return profile;
                }

                profile = new Profile {ProfileName = profileName};
                context.Profiles.Add(profile);
                context.SaveChanges();

                // The added entity is tracked by this context, so it can be returned without re-querying.
                return profile;
            }
        }

        /// <summary>
        /// Associates each named profile (created on demand) with the given application, skipping
        /// profiles that are already associated with it.
        /// </summary>
        public void AddProfilesToApplication(List<string> profileNames, int applicationId)
        {
            using (var context = _contextFactory.CreateContext())
            {
                foreach (var profileName in profileNames)
                {
                    var profile = GetOrCreateProfile(profileName);

                    // GetOrCreateProfile used its own context; reload the profile into this one.
                    var currentProfile = context.Profiles
                        .Include(u => u.Applications)
                        .FirstOrDefault(u => u.ProfileId == profile.ProfileId);

                    if (!currentProfile.Applications.Any(a => a.ApplicationId == applicationId))
                    {
                        var application = context.Applications.FirstOrDefault(a => a.ApplicationId == applicationId);
                        currentProfile.Applications.Add(application);
                    }
                }

                context.SaveChanges();
            }
        }

        /// <summary>Returns the e-mail of the user whose membership confirmation token matches, or null.</summary>
        public async Task<string> GetUserNameFromTokenAsync(string token)
        {
            using (var context = _contextFactory.CreateContext())
            {
                // Used by Sandbox Admin only, therefore PostgreSQL support is not needed
                // NOTE(review): the SQL is built from an interpolated string containing caller input. This is only
                // injection-safe if ExecuteQueryAsync takes a FormattableString and parameterizes the holes — confirm.
                var result = await context
                    .ExecuteQueryAsync<EmailResult>(
                        $"select top 1 U.Email from webpages_Membership M join Users U on M.UserId = U.UserId and M.ConfirmationToken = {token}");

                return result.FirstOrDefault()?.Email;
            }
        }

        /// <summary>Returns the membership confirmation token for the user with the given e-mail, or null.</summary>
        public async Task<string> GetTokenFromUserNameAsync(string userName)
        {
            using (var context = _contextFactory.CreateContext())
            {
                // Used by Sandbox Admin only, therefore PostgreSQL support is not needed
                // NOTE(review): same interpolated-SQL concern as GetUserNameFromTokenAsync — confirm parameterization.
                var result = await context
                    .ExecuteQueryAsync<ConfirmationTokenResult>(
                        $"select top 1 M.ConfirmationToken from webpages_Membership M join Users U on M.UserId = U.UserId and U.Email = {userName}");

                return result.FirstOrDefault()?.ConfirmationToken;
            }
        }

        public User CreateUser(User user)
        {
            using (var context = _contextFactory.CreateContext())
            {
                context.Users.Add(user);
                context.SaveChanges();
            }

            return user;
        }

        public IEnumerable<User> GetUsers()
        {
            using (var context = _contextFactory.CreateContext())
            {
                return context.Users.Include(u => u.ApiClients.Select(ac => ac.Application))
                    .ToList();
            }
        }

        public User GetUser(int userId)
        {
            using (var context = _contextFactory.CreateContext())
            {
                return
                    context.Users.Include(u => u.ApiClients.Select(ac => ac.Application))
                        .FirstOrDefault(u => u.UserId == userId);
            }
        }

        public User GetUser(string userName)
        {
            using (var context = _contextFactory.CreateContext())
            {
                return
                    context.Users.Include(u => u.ApiClients.Select(ac => ac.Application))
                        .Include(u => u.Vendor)
                        .FirstOrDefault(x => x.Email == userName);
            }
        }

        /// <summary>Deletes the user and all of its API clients; does nothing if the user does not exist.</summary>
        public void DeleteUser(User userProfile)
        {
            using (var context = _contextFactory.CreateContext())
            {
                var user =
                    context.Users.Include(u => u.ApiClients.Select(ac => ac.Application))
                        .FirstOrDefault(x => x.UserId == userProfile.UserId);

                if (user == null)
                {
                    return;
                }

                // Copy first: removing clients modifies the collection being enumerated.
                var arraySoThatUnderlyingCollectionCanBeModified = user.ApiClients.ToArray();

                foreach (var client in arraySoThatUnderlyingCollectionCanBeModified)
                {
                    context.Clients.Remove(client);
                }

                context.Users.Remove(user);
                context.SaveChanges();
            }
        }

        // Clients query eagerly loading everything the token flow needs (application, vendor, prefixes,
        // profiles, education organizations and creator ownership token).
        private static IQueryable<ApiClient> ClientsWithDetails(IUsersContext context)
            => context.Clients.Include(c => c.Application)
                .Include(c => c.Application.Vendor)
                .Include(c => c.Application.Vendor.VendorNamespacePrefixes)
                .Include(c => c.Application.Profiles)
                .Include(c => c.ApplicationEducationOrganizations)
                .Include(c => c.CreatorOwnershipTokenId);

        /// <summary>Returns the client with the given key and its related details, or null.</summary>
        public ApiClient GetClient(string key)
        {
            using (var context = _contextFactory.CreateContext())
            {
                return ClientsWithDetails(context)
                    .FirstOrDefault(c => c.Key == key);
            }
        }

        /// <summary>Asynchronously returns the client with the given key and its related details, or null.</summary>
        public async Task<ApiClient> GetClientAsync(string key)
        {
            using (var context = _contextFactory.CreateContext())
            {
                return await ClientsWithDetails(context)
                    .FirstOrDefaultAsync(c => c.Key == key);
            }
        }

        public ApiClient GetClient(string key, string secret)
        {
            using (var context = _contextFactory.CreateContext())
            {
                return context.Clients.FirstOrDefault(c => c.Key == key && c.Secret == secret);
            }
        }

        public ApiClient UpdateClient(ApiClient client)
        {
            using (var context = _contextFactory.CreateContext())
            {
                context.Clients.AddOrUpdate(client);
                context.SaveChanges();
                return client;
            }
        }

        /// <summary>
        /// Deletes the client with the given key and its access tokens, then removes its sandboxes when it
        /// uses one. Throws InvalidOperationException (from First) when no client has that key.
        /// </summary>
        public void DeleteClient(string key)
        {
            using (var context = _contextFactory.CreateContext())
            {
                var client = context.Clients.First(x => x.Key == key);

                // TODO SF: AA-518
                // Assuming that this is used by Admin App, although that will not actually be clear
                // until we are able to start testing Admin App thoroughly.
                // Convert this to ANSI SQL for PostgreSql support and don't use a SqlParameter.
                // Be sure to write integration tests in project EdFi.Ods.Admin.Models.IntegrationTests.
                // NOTE(review): .Wait() blocks on the async call and wraps failures in AggregateException.
                context.ExecuteSqlCommandAsync(
                        @"delete ClientAccessTokens where ApiClient_ApiClientId = @clientId; 
delete ApiClients where ApiClientId = @clientId",
                        new SqlParameter("@clientId", client.ApiClientId))
                    .Wait();

                if (client.UseSandbox)
                {
                    _provisioner.DeleteSandboxes(key);
                }
            }
        }

        /// <summary>
        /// Creates and saves a new access token for the client, valid for the configured duration.
        /// A blank scope is stored as null; otherwise the scope is trimmed.
        /// </summary>
        /// <exception cref="InvalidOperationException">The client does not exist.</exception>
        public async Task<ClientAccessToken> AddClientAccessTokenAsync(int apiClientId, string tokenRequestScope = null)
        {
            using (var context = _contextFactory.CreateContext())
            {
                var client = await context.Clients.FirstOrDefaultAsync(c => c.ApiClientId == apiClientId);

                if (client == null)
                {
                    throw new InvalidOperationException("Cannot add client access token when the client does not exist.");
                }

                var token = new ClientAccessToken(TimeSpan.FromMinutes(_duration.Value))
                {
                    Scope = string.IsNullOrEmpty(tokenRequestScope)
                        ? null
                        : tokenRequestScope.Trim()
                };

                client.ClientAccessTokens.Add(token);
                await context.SaveChangesAsync();
                return token;
            }
        }

        /// <summary>Synchronous wrapper over <see cref="AddClientAccessTokenAsync"/>.</summary>
        public ClientAccessToken AddClientAccessToken(int apiClientId, string tokenRequestScope = null)
        {
            return AddClientAccessTokenAsync(apiClientId, tokenRequestScope)
                .GetResultSafely();
        }

        public Application[] GetVendorApplications(int vendorId)
        {
            using (var context = _contextFactory.CreateContext())
            {
                return context.Applications.Where(a => a.Vendor.VendorId == vendorId)
                    .ToArray();
            }
        }

        /// <summary>
        /// Adds the client, attaching it to the first application of the user's vendor when the user has one.
        /// Does nothing if the user does not exist.
        /// </summary>
        public void AddApiClientToUserWithVendorApplication(int userId, ApiClient client)
        {
            using (var context = _contextFactory.CreateContext())
            {
                var user = context.Users
                    .Include(u => u.Vendor)
                    .Include(v => v.Vendor.Applications)
                    .SingleOrDefault(u => u.UserId == userId);

                if (user == null)
                {
                    return;
                }

                if (user.Vendor != null)
                {
                    client.Application = user.Vendor.Applications.FirstOrDefault();
                }

                context.Clients.Add(client);
                context.SaveChanges();
            }
        }

        public ApiClient CreateApiClient(int userId, string name, string key, string secret)
        {
            using (var context = _contextFactory.CreateContext())
            {
                var client = CreateApiClient(context, userId, name, SandboxType.Sample, key, secret);

                context.SaveChanges();

                return client;
            }
        }

        public void SetupKeySecret(
            string name,
            SandboxType sandboxType,
            string key,
            string secret,
            int userId,
            int applicationId)
        {
            using (var context = _contextFactory.CreateContext())
            {
                var client = CreateApiClient(context, userId, name, sandboxType, key, secret);

                AddApplicationEducationOrganizations(context, applicationId, client);

                context.SaveChanges();
            }
        }

        // Adds a sandbox client to the (tracked) user in the given context; the caller saves the changes.
        private ApiClient CreateApiClient(
            IUsersContext context,
            int userId,
            string name,
            SandboxType sandboxType,
            string key,
            string secret)
        {
            var attachedUser = context.Users.Find(userId);

            return attachedUser.AddSandboxClient(name, sandboxType, key, secret);
        }

        /// <summary>
        /// Attaches the application to the user's client and grants the client those of the application's
        /// education organizations whose ids are in <paramref name="leaIds"/>. Does nothing if the user or
        /// client cannot be found.
        /// </summary>
        public void AddLeaIdsToApiClient(int userId, int apiClientId, IList<int> leaIds, int applicationId)
        {
            using (var context = _contextFactory.CreateContext())
            {
                var application = context.Applications
                    .Include(a => a.ApplicationEducationOrganizations)
                    .Single(a => a.ApplicationId == applicationId);

                var user = context.Users.FirstOrDefault(u => u.UserId == userId);

                var client = user?.ApiClients.FirstOrDefault(c => c.ApiClientId == apiClientId);

                if (client == null)
                {
                    return;
                }

                client.Application = application;

                foreach (var applicationEducationOrganization in application.ApplicationEducationOrganizations.Where(
                    s => leaIds.Contains(s.EducationOrganizationId)))
                {
                    client.ApplicationEducationOrganizations.Add(applicationEducationOrganization);
                }

                context.SaveChanges();
            }
        }

        // Attaches the application to the client and grants the client all of its education organizations.
        private void AddApplicationEducationOrganizations(IUsersContext context, int applicationId, ApiClient client)
        {
            var defaultApplication = context.Applications
                .Include(a => a.ApplicationEducationOrganizations)
                .First(a => a.ApplicationId == applicationId);

            client.Application = defaultApplication;

            foreach (var applicationEducationOrganization in defaultApplication.ApplicationEducationOrganizations)
            {
                client.ApplicationEducationOrganizations.Add(applicationEducationOrganization);
            }
        }

        /// <summary>
        /// Creates a sandbox client for the user bound to the given application, provisions its sandbox,
        /// then saves the client.
        /// </summary>
        public ApiClient SetupDefaultSandboxClient(
            string name,
            SandboxType sandboxType,
            string key,
            string secret,
            int userId,
            int applicationId)
        {
            using (var context = _contextFactory.CreateContext())
            {
                var client = CreateApiClient(context, userId, name, sandboxType, key, secret);

                AddApplicationEducationOrganizations(context, applicationId, client);

                _provisioner.AddSandbox(client.Key, sandboxType);

                context.SaveChanges();

                return client;
            }
        }

        /// <summary>
        /// Deletes all tokens, clients, users, application education organizations, applications and vendors.
        /// Any failure is rethrown wrapped in an Exception with a reset-specific message.
        /// </summary>
        public void Reset()
        {
            try
            {
                using (var context = _contextFactory.CreateContext())
                {
                    var dbContext = context as DbContext;

                    try
                    {
                        //Admin.Web Creates table webpages_UsersInRoles.
                        //If exists remove rows, if not swallow exception.
                        dbContext.DeleteAll<WebPagesUsersInRoles>();
                        context.SaveChanges();
                    }
                    catch (Exception)
                    {
                        // Intentionally ignored: the table only exists when Admin.Web created it.
                    }

                    dbContext.DeleteAll<ClientAccessToken>();
                    dbContext.DeleteAll<ApiClient>();
                    dbContext.DeleteAll<User>();
                    dbContext.DeleteAll<ApplicationEducationOrganization>();
                    dbContext.DeleteAll<Application>();
                    dbContext.DeleteAll<Vendor>();
                    context.SaveChanges();
                }
            }
            catch (Exception ex)
            {
                throw new Exception("Error occurred while attempting to reset Admin database.", ex);
            }
        }

        /// <summary>
        /// Assigns to the user (looked up by e-mail) a vendor named after the first comma-separated part of
        /// <paramref name="userName"/>, creating the vendor and its default application when needed.
        /// </summary>
        public void SetDefaultVendorOnUserFromEmailAndName(string userEmail, string userName)
        {
            var namePrefix = GetNamespacePrefix(userEmail);
            var vendorName = GetVendorName(userName);

            using (var context = _contextFactory.CreateContext())
            {
                var vendor = FindOrCreateVendorByDomainName(context, vendorName, namePrefix);
                var usr = context.Users.Single(u => u.Email == userEmail);
                usr.Vendor = vendor;
                context.SaveChanges();
            }
        }

        /// <summary>
        /// Returns the existing vendor for the user's name, or a new, unsaved vendor whose namespace prefix
        /// is derived from the e-mail domain.
        /// </summary>
        public Vendor CreateOrGetVendor(string userEmail, string userName)
        {
            var vendorName = GetVendorName(userName);
            var namePrefix = GetNamespacePrefix(userEmail);

            using (var context = _contextFactory.CreateContext())
            {
                return context.Vendors.SingleOrDefault(v => v.VendorName == vendorName)
                       ?? NewVendor(vendorName, namePrefix);
            }
        }

        // Vendor name: the part of the user name before the first comma, trimmed.
        private static string GetVendorName(string userName)
            => userName.Split(',')[0]
                .Trim();

        // Namespace prefix "uri://<domain>" built from the e-mail domain. The domain is lower-cased with the
        // invariant culture so that the URI does not depend on the server's culture (e.g. Turkish 'I').
        private static string GetNamespacePrefix(string userEmail)
            => "uri://" + userEmail.Split('@')[1]
                .ToLowerInvariant();

        // Builds a new vendor carrying a single namespace prefix; it is not added to any context.
        private static Vendor NewVendor(string vendorName, string namePrefix)
        {
            var vendor = new Vendor {VendorName = vendorName};

            vendor.VendorNamespacePrefixes.Add(
                new VendorNamespacePrefix
                {
                    Vendor = vendor,
                    NamespacePrefix = namePrefix
                });

            return vendor;
        }

        // Returns the vendor with the given name, or adds a new one (plus its default application) to the context.
        private Vendor FindOrCreateVendorByDomainName(IUsersContext context, string vendorName, string namePrefix)
        {
            var vendor = context.Vendors.SingleOrDefault(v => v.VendorName == vendorName);

            if (vendor == null)
            {
                vendor = NewVendor(vendorName, namePrefix);

                context.Vendors.AddOrUpdate(vendor);

                //TODO: DEA - Move this behavior to happen during client creation.  No need to do this in two places.  At a minimum, remove the duplicated code.
                CreateDefaultApplicationForVendor(context, vendor);
            }

            return vendor;
        }

        /// <summary>
        /// Returns the vendor's application with the given name, creating and saving it (with the default
        /// operational context URI) when it does not exist.
        /// </summary>
        public Application CreateApplicationForVendor(int vendorId, string applicationName, string claimSetName)
        {
            using (var context = _contextFactory.CreateContext())
            {
                var app =
                    context.Applications.SingleOrDefault(
                        a => a.ApplicationName == applicationName && a.Vendor.VendorId == vendorId);

                if (app != null)
                {
                    return app;
                }

                var vendor = context.Vendors.FirstOrDefault(v => v.VendorId == vendorId);

                app = new Application
                {
                    ApplicationName = applicationName,
                    Vendor = vendor,
                    ClaimSetName = claimSetName,
                    OperationalContextUri = _defaultOperationalContextUri.Value
                };

                context.Applications.AddOrUpdate(app);

                context.SaveChanges();

                return app;
            }
        }

        // Adds (without saving) the configured default application for the vendor unless it already exists.
        private void CreateDefaultApplicationForVendor(IUsersContext context, Vendor vendor)
        {
            var app =
                context.Applications.SingleOrDefault(
                    a => a.ApplicationName == _defaultAppName.Value && a.Vendor.VendorId == vendor.VendorId);

            if (app != null)
            {
                return;
            }

            context.Applications.AddOrUpdate(
                new Application
                {
                    ApplicationName = _defaultAppName.Value,
                    Vendor = vendor,
                    ClaimSetName = _defaultClaimSetName.Value,
                    OperationalContextUri = _defaultOperationalContextUri.Value
                });
        }

        // Row shape for the raw e-mail lookup query.
        internal class EmailResult
        {
            public string Email { get; set; }
        }

        // Row shape for the raw confirmation-token lookup query.
        internal class ConfirmationTokenResult
        {
            public string ConfirmationToken { get; set; }
        }
    }
}

Finally, OAuth is registered in the container with:

using Autofac;
using EdFi.Ods.Api.Common.Authentication;
using EdFi.Ods.Api.Common.Configuration;
using EdFi.Ods.Api.Common.Container;
using EdFi.Ods.Api.NetCore.Middleware;
using EdFi.Ods.Api.NetCore.Providers;

namespace EdFi.Ods.Api.NetCore.Container.Modules
{
    /// <summary>
    /// Autofac module registering the OAuth token services. It is selected only when
    /// <c>ApiSettings.DisableSecurity</c> is false.
    /// </summary>
    public class OAuthModule : ConditionalModule
    {
        public OAuthModule(ApiSettings apiSettings)
            : base(apiSettings, nameof(OAuthModule))
        {
        }

        /// <summary>Returns true unless security has been disabled in the API settings.</summary>
        public override bool IsSelected()
        {
            return !ApiSettings.DisableSecurity;
        }

        /// <summary>
        /// Registers the client-credentials token request provider, the OAuth token validator (decorated by
        /// <c>CachingOAuthTokenValidatorDecorator</c>), and the authentication provider.
        /// </summary>
        public override void ApplyConfigurationSpecificRegistrations(ContainerBuilder builder)
        {
            // Token endpoint request handling.
            builder
                .RegisterType<ClientCredentialsTokenRequestProvider>()
                .As<ITokenRequestProvider>();

            // Token validation and its decorator.
            builder
                .RegisterType<OAuthTokenValidator>()
                .As<IOAuthTokenValidator>();

            builder
                .RegisterDecorator<CachingOAuthTokenValidatorDecorator, IOAuthTokenValidator>();

            // Per-request authentication used by the authentication handler.
            builder
                .RegisterType<AuthenticationProvider>()
                .As<IAuthenticationProvider>();
        }
    }
}
  • No labels