Skip to content

Commit b2f185d

Browse files
Merge pull request #2442 from captainsafia/security-schemes-selector
Add support for SecuritySchemesSelector and default implementation
2 parents 23fe15d + 5b501e3 commit b2f185d

7 files changed

Lines changed: 216 additions & 14 deletions

File tree

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System.Threading.Tasks;
2+
using Microsoft.OpenApi.Models;
3+
4+
namespace Swashbuckle.AspNetCore.Swagger
5+
{
6+
public interface IAsyncSwaggerProvider
7+
{
8+
Task<OpenApiDocument> GetSwaggerAsync(
9+
string documentName,
10+
string host = null,
11+
string basePath = null);
12+
}
13+
}

src/Swashbuckle.AspNetCore.Swagger/SwaggerMiddleware.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,17 @@ public async Task Invoke(HttpContext httpContext, ISwaggerProvider swaggerProvid
3939
? httpContext.Request.PathBase.Value
4040
: null;
4141

42-
var swagger = swaggerProvider.GetSwagger(
43-
documentName: documentName,
44-
host: null,
45-
basePath: basePath);
42+
var swagger = swaggerProvider switch
43+
{
44+
IAsyncSwaggerProvider asyncSwaggerProvider => await asyncSwaggerProvider.GetSwaggerAsync(
45+
documentName: documentName,
46+
host: null,
47+
basePath: basePath),
48+
_ => swaggerProvider.GetSwagger(
49+
documentName: documentName,
50+
host: null,
51+
basePath: basePath)
52+
};
4653

4754
// One last opportunity to modify the Swagger Document - this time with request context
4855
foreach (var filter in _options.PreSerializeFilters)

src/Swashbuckle.AspNetCore.SwaggerGen/DependencyInjection/DocumentProvider.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ internal class DocumentProvider : IDocumentProvider
2323
{
2424
private readonly SwaggerGeneratorOptions _generatorOptions;
2525
private readonly SwaggerOptions _options;
26-
private readonly ISwaggerProvider _swaggerProvider;
26+
private readonly IAsyncSwaggerProvider _swaggerProvider;
2727

2828
public DocumentProvider(
2929
IOptions<SwaggerGeneratorOptions> generatorOptions,
3030
IOptions<SwaggerOptions> options,
31-
ISwaggerProvider swaggerProvider)
31+
IAsyncSwaggerProvider swaggerProvider)
3232
{
3333
_generatorOptions = generatorOptions.Value;
3434
_options = options.Value;
@@ -40,10 +40,10 @@ public IEnumerable<string> GetDocumentNames()
4040
return _generatorOptions.SwaggerDocs.Keys;
4141
}
4242

43-
public Task GenerateAsync(string documentName, TextWriter writer)
43+
public async Task GenerateAsync(string documentName, TextWriter writer)
4444
{
4545
// Let UnknownSwaggerDocument or other exception bubble up to caller.
46-
var swagger = _swaggerProvider.GetSwagger(documentName, host: null, basePath: null);
46+
var swagger = await _swaggerProvider.GetSwaggerAsync(documentName, host: null, basePath: null);
4747
var jsonWriter = new OpenApiJsonWriter(writer);
4848
if (_options.SerializeAsV2)
4949
{
@@ -53,8 +53,6 @@ public Task GenerateAsync(string documentName, TextWriter writer)
5353
{
5454
swagger.SerializeAsV3(jsonWriter);
5555
}
56-
57-
return Task.CompletedTask;
5856
}
5957
}
6058
}

src/Swashbuckle.AspNetCore.SwaggerGen/DependencyInjection/SwaggerGenServiceCollectionExtensions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ public static IServiceCollection AddSwaggerGen(
2626

2727
// Register generator and it's dependencies
2828
services.TryAddTransient<ISwaggerProvider, SwaggerGenerator>();
29+
services.TryAddTransient<IAsyncSwaggerProvider, SwaggerGenerator>();
2930
services.TryAddTransient(s => s.GetRequiredService<IOptions<SwaggerGeneratorOptions>>().Value);
3031
services.TryAddTransient<ISchemaGenerator, SchemaGenerator>();
3132
services.TryAddTransient(s => s.GetRequiredService<IOptions<SchemaGeneratorOptions>>().Value);

src/Swashbuckle.AspNetCore.SwaggerGen/SwaggerGenerator/SwaggerGenerator.cs

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
using System.Linq;
44
using System.Reflection;
55
using System.Text.RegularExpressions;
6+
using System.Threading.Tasks;
7+
using Microsoft.AspNetCore.Authentication;
68
using Microsoft.AspNetCore.Mvc;
79
using Microsoft.AspNetCore.Mvc.ApiExplorer;
810
using Microsoft.AspNetCore.Mvc.ModelBinding;
@@ -11,11 +13,12 @@
1113

1214
namespace Swashbuckle.AspNetCore.SwaggerGen
1315
{
14-
public class SwaggerGenerator : ISwaggerProvider
16+
public class SwaggerGenerator : ISwaggerProvider, IAsyncSwaggerProvider
1517
{
1618
private readonly IApiDescriptionGroupCollectionProvider _apiDescriptionsProvider;
1719
private readonly ISchemaGenerator _schemaGenerator;
1820
private readonly SwaggerGeneratorOptions _options;
21+
private readonly IAuthenticationSchemeProvider _authenticationSchemeProvider;
1922

2023
public SwaggerGenerator(
2124
SwaggerGeneratorOptions options,
@@ -27,7 +30,30 @@ public SwaggerGenerator(
2730
_schemaGenerator = schemaGenerator;
2831
}
2932

33+
public SwaggerGenerator(
34+
SwaggerGeneratorOptions options,
35+
IApiDescriptionGroupCollectionProvider apiDescriptionsProvider,
36+
ISchemaGenerator schemaGenerator,
37+
IAuthenticationSchemeProvider authentiationSchemeProvider) : this(options, apiDescriptionsProvider, schemaGenerator)
38+
{
39+
_authenticationSchemeProvider = authentiationSchemeProvider;
40+
}
41+
42+
public async Task<OpenApiDocument> GetSwaggerAsync(string documentName, string host = null, string basePath = null)
43+
{
44+
var (applicableApiDescriptions, swaggerDoc, schemaRepository) = GetSwaggerDocument(documentName, host, basePath);
45+
swaggerDoc.Components.SecuritySchemes = await GetSecuritySchemes();
46+
return swaggerDoc;
47+
}
48+
3049
public OpenApiDocument GetSwagger(string documentName, string host = null, string basePath = null)
50+
{
51+
var (applicableApiDescriptions, swaggerDoc, schemaRepository) = GetSwaggerDocument(documentName, host, basePath);
52+
swaggerDoc.Components.SecuritySchemes = GetSecuritySchemes().Result;
53+
return swaggerDoc;
54+
}
55+
56+
private (IEnumerable<ApiDescription>, OpenApiDocument, SchemaRepository) GetSwaggerDocument(string documentName, string host = null, string basePath = null)
3157
{
3258
if (!_options.SwaggerDocs.TryGetValue(documentName, out OpenApiInfo info))
3359
throw new UnknownSwaggerDocument(documentName, _options.SwaggerDocs.Select(d => d.Key));
@@ -47,7 +73,6 @@ public OpenApiDocument GetSwagger(string documentName, string host = null, strin
4773
Components = new OpenApiComponents
4874
{
4975
Schemas = schemaRepository.Schemas,
50-
SecuritySchemes = new Dictionary<string, OpenApiSecurityScheme>(_options.SecuritySchemes)
5176
},
5277
SecurityRequirements = new List<OpenApiSecurityRequirement>(_options.SecurityRequirements)
5378
};
@@ -60,7 +85,30 @@ public OpenApiDocument GetSwagger(string documentName, string host = null, strin
6085

6186
swaggerDoc.Components.Schemas = new SortedDictionary<string, OpenApiSchema>(swaggerDoc.Components.Schemas, _options.SchemaComparer);
6287

63-
return swaggerDoc;
88+
return (applicableApiDescriptions, swaggerDoc, schemaRepository);
89+
}
90+
91+
private async Task<Dictionary<string, OpenApiSecurityScheme>> GetSecuritySchemes()
92+
{
93+
var securitySchemes = new Dictionary<string, OpenApiSecurityScheme>(_options.SecuritySchemes);
94+
var authenticationSchemes = Enumerable.Empty<AuthenticationScheme>();
95+
if (_authenticationSchemeProvider is not null)
96+
{
97+
authenticationSchemes = await _authenticationSchemeProvider.GetAllSchemesAsync();
98+
}
99+
var securitySchemesFromSelector = _options.SecuritySchemesSelector(authenticationSchemes);
100+
// Favor security schemes set via options over those generated
101+
// from the selector. For the default selector, this effectively
102+
// ends up favoring `Bearer` authentication types explicitly set
103+
// by the user over those derived by the selector.
104+
foreach (var securityScheme in securitySchemesFromSelector)
105+
{
106+
if (!securitySchemes.ContainsKey(securityScheme.Key))
107+
{
108+
securitySchemes.Add(securityScheme.Key, securityScheme.Value);
109+
}
110+
}
111+
return securitySchemes;
64112
}
65113

66114
private IList<OpenApiServer> GenerateServers(string host, string basePath)

src/Swashbuckle.AspNetCore.SwaggerGen/SwaggerGenerator/SwaggerGeneratorOptions.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using Microsoft.AspNetCore.Mvc.ApiExplorer;
88
using Microsoft.OpenApi.Models;
99
using Microsoft.AspNetCore.Routing;
10+
using Microsoft.AspNetCore.Authentication;
1011

1112
namespace Swashbuckle.AspNetCore.SwaggerGen
1213
{
@@ -19,6 +20,7 @@ public SwaggerGeneratorOptions()
1920
OperationIdSelector = DefaultOperationIdSelector;
2021
TagsSelector = DefaultTagsSelector;
2122
SortKeySelector = DefaultSortKeySelector;
23+
SecuritySchemesSelector = DefaultSecuritySchemeSelector;
2224
SchemaComparer = StringComparer.Ordinal;
2325
Servers = new List<OpenApiServer>();
2426
SecuritySchemes = new Dictionary<string, OpenApiSecurityScheme>();
@@ -61,6 +63,8 @@ public SwaggerGeneratorOptions()
6163

6264
public IList<IDocumentFilter> DocumentFilters { get; set; }
6365

66+
public Func<IEnumerable<AuthenticationScheme>, Dictionary<string, OpenApiSecurityScheme>> SecuritySchemesSelector { get; set;}
67+
6468
private bool DefaultDocInclusionPredicate(string documentName, ApiDescription apiDescription)
6569
{
6670
return apiDescription.GroupName == null || apiDescription.GroupName == documentName;
@@ -102,5 +106,26 @@ private string DefaultSortKeySelector(ApiDescription apiDescription)
102106
{
103107
return TagsSelector(apiDescription).First();
104108
}
109+
110+
private Dictionary<string, OpenApiSecurityScheme> DefaultSecuritySchemeSelector(IEnumerable<AuthenticationScheme> schemes)
111+
{
112+
Dictionary<string, OpenApiSecurityScheme> securitySchemes = new();
113+
#if (NET6_0_OR_GREATER)
114+
foreach (var scheme in schemes)
115+
{
116+
if (scheme.Name == "Bearer")
117+
{
118+
securitySchemes[scheme.Name] = new OpenApiSecurityScheme
119+
{
120+
Type = SecuritySchemeType.Http,
121+
Scheme = "bearer", // "bearer" refers to the header name here
122+
In = ParameterLocation.Header,
123+
BearerFormat = "Json Web Token"
124+
};
125+
}
126+
}
127+
#endif
128+
return securitySchemes;
129+
}
105130
}
106131
}

test/Swashbuckle.AspNetCore.SwaggerGen.Test/SwaggerGenerator/SwaggerGeneratorTests.cs

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
using Swashbuckle.AspNetCore.Swagger;
1313
using Swashbuckle.AspNetCore.TestSupport;
1414
using Xunit;
15+
using System.Threading.Tasks;
16+
using Microsoft.AspNetCore.Authentication;
17+
using Microsoft.AspNetCore.Server.HttpSys;
1518

1619
namespace Swashbuckle.AspNetCore.SwaggerGen.Test
1720
{
@@ -911,9 +914,78 @@ public void GetSwagger_SupportsOption_SecuritySchemes()
911914

912915
var document = subject.GetSwagger("v1");
913916

917+
Assert.Equal(new[] { "basic", "Bearer" }, document.Components.SecuritySchemes.Keys);
918+
}
919+
920+
[Fact]
921+
public async Task GetSwagger_SupportsSecuritySchemesSelector()
922+
{
923+
var subject = Subject(
924+
apiDescriptions: new ApiDescription[] { },
925+
options: new SwaggerGeneratorOptions
926+
{
927+
SwaggerDocs = new Dictionary<string, OpenApiInfo>
928+
{
929+
["v1"] = new OpenApiInfo { Version = "V1", Title = "Test API" }
930+
},
931+
SecuritySchemesSelector = (schemes) => new Dictionary<string, OpenApiSecurityScheme>
932+
{
933+
["basic"] = new OpenApiSecurityScheme { Type = SecuritySchemeType.Http, Scheme = "basic" }
934+
}
935+
}
936+
);
937+
938+
var document = await subject.GetSwaggerAsync("v1");
939+
940+
// Overrides the default set of [basic, bearer] with just [basic]
914941
Assert.Equal(new[] { "basic" }, document.Components.SecuritySchemes.Keys);
915942
}
916943

944+
[Fact]
945+
public async Task GetSwagger_DefaultSecuritySchemeSelectorAddsBearerByDefault()
946+
{
947+
var subject = Subject(
948+
apiDescriptions: new ApiDescription[] { },
949+
options: new SwaggerGeneratorOptions
950+
{
951+
SwaggerDocs = new Dictionary<string, OpenApiInfo>
952+
{
953+
["v1"] = new OpenApiInfo { Version = "V1", Title = "Test API" }
954+
},
955+
}
956+
);
957+
958+
var document = await subject.GetSwaggerAsync("v1");
959+
960+
Assert.Equal(new[] { "Bearer" }, document.Components.SecuritySchemes.Keys);
961+
}
962+
963+
[Fact]
964+
public async Task GetSwagger_DefaultSecuritySchemesSelectorDoesNotOverrideBearer()
965+
{
966+
var subject = Subject(
967+
apiDescriptions: new ApiDescription[] { },
968+
options: new SwaggerGeneratorOptions
969+
{
970+
SwaggerDocs = new Dictionary<string, OpenApiInfo>
971+
{
972+
["v1"] = new OpenApiInfo { Version = "V1", Title = "Test API" }
973+
},
974+
SecuritySchemes = new Dictionary<string, OpenApiSecurityScheme>
975+
{
976+
["Bearer"] = new OpenApiSecurityScheme { Type = SecuritySchemeType.ApiKey, Scheme = "someSpecialOne" }
977+
}
978+
}
979+
);
980+
981+
var document = await subject.GetSwaggerAsync("v1");
982+
983+
var securityScheme = Assert.Single(document.Components.SecuritySchemes);
984+
Assert.Equal("Bearer", securityScheme.Key);
985+
Assert.Equal(SecuritySchemeType.ApiKey, securityScheme.Value.Type);
986+
Assert.Equal("someSpecialOne", securityScheme.Value.Scheme);
987+
}
988+
917989
[Fact]
918990
public void GetSwagger_SupportsOption_ParameterFilters()
919991
{
@@ -1049,7 +1121,8 @@ private SwaggerGenerator Subject(IEnumerable<ApiDescription> apiDescriptions, Sw
10491121
return new SwaggerGenerator(
10501122
options ?? DefaultOptions,
10511123
new FakeApiDescriptionGroupCollectionProvider(apiDescriptions),
1052-
new SchemaGenerator(new SchemaGeneratorOptions(), new JsonSerializerDataContractResolver(new JsonSerializerOptions()))
1124+
new SchemaGenerator(new SchemaGeneratorOptions(), new JsonSerializerDataContractResolver(new JsonSerializerOptions())),
1125+
new TestAuthenticationSchemeProvider()
10531126
);
10541127
}
10551128

@@ -1061,4 +1134,41 @@ private SwaggerGenerator Subject(IEnumerable<ApiDescription> apiDescriptions, Sw
10611134
}
10621135
};
10631136
}
1137+
1138+
class TestAuthenticationSchemeProvider : IAuthenticationSchemeProvider
1139+
{
1140+
private readonly IEnumerable<AuthenticationScheme> _authenticationSchemes = new AuthenticationScheme[]
1141+
{
1142+
new AuthenticationScheme("Bearer", null, typeof(IAuthenticationHandler))
1143+
};
1144+
1145+
public void AddScheme(AuthenticationScheme scheme)
1146+
=> throw new NotImplementedException();
1147+
public Task<IEnumerable<AuthenticationScheme>> GetAllSchemesAsync()
1148+
=> Task.FromResult(_authenticationSchemes);
1149+
1150+
public Task<AuthenticationScheme> GetDefaultAuthenticateSchemeAsync()
1151+
=> Task.FromResult(_authenticationSchemes.First());
1152+
1153+
public Task<AuthenticationScheme> GetDefaultChallengeSchemeAsync()
1154+
=> Task.FromResult(_authenticationSchemes.First());
1155+
1156+
public Task<AuthenticationScheme> GetDefaultForbidSchemeAsync()
1157+
=> Task.FromResult(_authenticationSchemes.First());
1158+
1159+
public Task<AuthenticationScheme> GetDefaultSignInSchemeAsync()
1160+
=> Task.FromResult(_authenticationSchemes.First());
1161+
1162+
public Task<AuthenticationScheme> GetDefaultSignOutSchemeAsync()
1163+
=> Task.FromResult(_authenticationSchemes.First());
1164+
1165+
public Task<IEnumerable<AuthenticationScheme>> GetRequestHandlerSchemesAsync()
1166+
=> throw new NotImplementedException();
1167+
1168+
public Task<AuthenticationScheme> GetSchemeAsync(string name)
1169+
=> Task.FromResult(_authenticationSchemes.First());
1170+
1171+
public void RemoveScheme(string name)
1172+
=> throw new NotImplementedException();
1173+
}
10641174
}

0 commit comments

Comments
 (0)