提交 df0b268a 编写于 作者: C Christian

Add new API for extended authentication

上级 addf812d
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Client;
using MQTTnet.Formatter;
using MQTTnet.Internal;
using MQTTnet.Packets;
using MQTTnet.Protocol;
namespace MQTTnet.Tests.Clients.MqttClient
{
[TestClass]
public sealed class Extended_Authentication_Tests : BaseTestClass
{
[TestMethod]
public async Task ReAuthenticate()
{
using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500))
{
var server = await testEnvironment.StartServer();
server.ValidatingConnectionAsync += async args =>
{
using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(10)))
{
// Just do a simple custom authentication.
await args.SendAuthenticationDataAsync(MqttAuthenticateReasonCode.ContinueAuthentication, cancellationToken: timeout.Token);
var clientResponse = await args.ReceiveAuthenticationDataAsync(timeout.Token);
CollectionAssert.AreEqual(clientResponse.AuthenticationData, Encoding.ASCII.GetBytes("TOKEN"));
var userProperties = new MqttUserPropertiesBuilder().With("x", "y").Build();
await args.SendAuthenticationDataAsync(MqttAuthenticateReasonCode.ContinueAuthentication, cancellationToken: timeout.Token, userProperties: userProperties);
args.ReasonCode = MqttConnectReasonCode.Success;
}
};
server.ClientReAuthenticatingAsync += args =>
{
return CompletedTask.Instance;
};
var client = testEnvironment.CreateClient();
client.ExtendedAuthenticationExchangeAsync += async args =>
{
await args.SendAuthenticationDataAsync(MqttAuthenticateReasonCode.ContinueAuthentication, Encoding.ASCII.GetBytes("TOKEN"));
var serverResponse = await args.ReceiveAuthenticationDataAsync(CancellationToken.None);
Assert.IsNotNull(serverResponse.UserProperties);
Assert.AreEqual("x", serverResponse.UserProperties[0].Name);
Assert.AreEqual("y", serverResponse.UserProperties[0].Value);
};
var clientOptions = testEnvironment.CreateDefaultClientOptionsBuilder().WithTimeout(TimeSpan.FromSeconds(10)).WithAuthentication("CUSTOM").Build();
await client.ConnectAsync(clientOptions);
await LongTestDelay();
await client.ReAuthenticateAsync();
await LongTestDelay();
var pingResult = await client.TryPingAsync();
Assert.IsTrue(pingResult);
}
}
[TestMethod]
public async Task Use_Extended_Authentication()
{
var initialContextToken = Encoding.ASCII.GetBytes("initial context token");
var replyContextToken = Encoding.ASCII.GetBytes("reply context token");
var outcomeOfAuthentication = Encoding.ASCII.GetBytes("outcome of authentication");
using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500))
{
var server = await testEnvironment.StartServer();
/*
Sample flow from RFC:
1. Client to Server CONNECT Authentication Method="GS2-KRB5"
2. Server to Client AUTH rc=0x18 Authentication Method="GS2-KRB5"
3. Client to Server AUTH rc=0x18 Authentication Method="GS2-KRB5" Authentication Data=initial context token
4. Server to Client AUTH rc=0x18 Authentication Method="GS2-KRB5" Authentication Data=reply context token
5. Client to Server AUTH rc=0x18 Authentication Method="GS2-KRB5"
6. Server to Client CONNACK rc=0 Authentication Method="GS2-KRB5" Authentication Data=outcome of authentication
*/
server.ValidatingConnectionAsync += async args =>
{
using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(10)))
{
// 1.
if (args.AuthenticationMethod != "GS2-KRB5")
{
args.ReasonCode = MqttConnectReasonCode.BadAuthenticationMethod;
return;
}
// 2.
await args.SendAuthenticationDataAsync(MqttAuthenticateReasonCode.ContinueAuthentication, cancellationToken: timeout.Token);
// 3.
var clientResponse = await args.ReceiveAuthenticationDataAsync(timeout.Token);
CollectionAssert.AreEqual(clientResponse.AuthenticationData, initialContextToken);
Assert.AreEqual(MqttAuthenticateReasonCode.ContinueAuthentication, clientResponse.ReasonCode);
// 4.
await args.SendAuthenticationDataAsync(MqttAuthenticateReasonCode.ContinueAuthentication, replyContextToken, cancellationToken: timeout.Token);
// 5.
clientResponse = await args.ReceiveAuthenticationDataAsync(timeout.Token);
Assert.AreEqual(clientResponse.AuthenticationData, null);
Assert.AreEqual(MqttAuthenticateReasonCode.ContinueAuthentication, clientResponse.ReasonCode);
// 6.
args.ResponseAuthenticationData = outcomeOfAuthentication;
args.ReasonCode = MqttConnectReasonCode.Success;
}
};
var client = testEnvironment.CreateClient();
client.ExtendedAuthenticationExchangeAsync += async args =>
{
if (args.AuthenticationMethod != "GS2-KRB5")
{
Assert.Fail("Authentication method is wrong.");
}
// 2. THE FACT THAT THE SERVER SENDS THE AUTH PACKET WILL TRIGGER THIS EVENT SO THE INITIAL DATA IS IN THE EVENT ARGS!
Assert.AreEqual(MqttAuthenticateReasonCode.ContinueAuthentication, args.ReasonCode);
// 3.
await args.SendAuthenticationDataAsync(MqttAuthenticateReasonCode.ContinueAuthentication, initialContextToken);
// 4.
var serverResponse = await args.ReceiveAuthenticationDataAsync(CancellationToken.None);
CollectionAssert.AreEqual(serverResponse.AuthenticationData, replyContextToken);
Assert.AreEqual(MqttAuthenticateReasonCode.ContinueAuthentication, serverResponse.ReasonCode);
// 5.
await args.SendAuthenticationDataAsync(MqttAuthenticateReasonCode.ContinueAuthentication);
};
var clientOptions = testEnvironment.CreateDefaultClientOptionsBuilder().WithTimeout(TimeSpan.FromSeconds(10)).WithAuthentication("GS2-KRB5").Build();
await client.ConnectAsync(clientOptions);
}
}
}
}
\ No newline at end of file
......@@ -277,8 +277,7 @@ namespace MQTTnet.Tests.Mockups
// Null is used when the client id is assigned from the server!
if (!string.IsNullOrEmpty(e.ClientId) && !e.ClientId.StartsWith(TestContext.TestName))
{
TrackException(new InvalidOperationException($"Invalid client ID used ({e.ClientId}). It must start with UnitTest name."));
e.ReasonCode = MqttConnectReasonCode.ClientIdentifierNotValid;
Assert.Fail($"Client ID does not start with test name ({TestContext.TestName}).");
}
}
......@@ -292,23 +291,12 @@ namespace MQTTnet.Tests.Mockups
{
foreach (var mqttClient in _clients)
{
try
{
//mqttClient.DisconnectAsync().GetAwaiter().GetResult();
}
catch
{
// This can happen when the test already disconnected the client.
}
finally
{
mqttClient?.Dispose();
}
mqttClient?.Dispose();
}
foreach (var lowLevelMqttClient in _lowLevelClients)
{
lowLevelMqttClient.Dispose();
lowLevelMqttClient?.Dispose();
}
try
......
......@@ -2,12 +2,18 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Threading.Tasks;
using System.Collections.Generic;
using MQTTnet.Packets;
using MQTTnet.Protocol;
namespace MQTTnet.Client
{
public interface IMqttExtendedAuthenticationExchangeHandler
public sealed class MqttClientPartialAuthenticationResponse
{
Task HandleRequestAsync(MqttExtendedAuthenticationExchangeEventArgs eventArgs);
public byte[] AuthenticationData { get; set; }
public MqttAuthenticateReasonCode ReasonCode { get; set; }
public List<MqttUserProperty> UserProperties { get; set; }
}
}
}
\ No newline at end of file
......@@ -4,6 +4,10 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Client.Internal;
using MQTTnet.Exceptions;
using MQTTnet.Packets;
using MQTTnet.Protocol;
......@@ -13,13 +17,17 @@ namespace MQTTnet.Client
{
static readonly IReadOnlyCollection<MqttUserProperty> EmptyUserProperties = new List<MqttUserProperty>();
public MqttExtendedAuthenticationExchangeEventArgs(MqttAuthPacket authPacket)
readonly IMqttAuthenticationTransportStrategy _transportStrategy;
public MqttExtendedAuthenticationExchangeEventArgs(MqttAuthPacket authPacket, IMqttAuthenticationTransportStrategy transportStrategy)
{
if (authPacket == null)
{
throw new ArgumentNullException(nameof(authPacket));
}
_transportStrategy = transportStrategy ?? throw new ArgumentNullException(nameof(transportStrategy));
ReasonCode = authPacket.ReasonCode;
ReasonString = authPacket.ReasonString;
AuthenticationMethod = authPacket.AuthenticationMethod;
......@@ -51,15 +59,53 @@ namespace MQTTnet.Client
/// </summary>
public string ReasonString { get; }
// /// <summary>
// /// Gets the response which will be sent to the server.
// /// </summary>
// public MqttExtendedAuthenticationExchangeResponse Response { get; } = new MqttExtendedAuthenticationExchangeResponse();
/// <summary>
/// Gets or sets the user properties.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public IReadOnlyCollection<MqttUserProperty> UserProperties { get; }
/// <summary>
/// Gets the response which will be sent to the server.
/// </summary>
public MqttExtendedAuthenticationExchangeResponse Response { get; } = new MqttExtendedAuthenticationExchangeResponse();
public async Task<MqttClientPartialAuthenticationResponse> ReceiveAuthenticationDataAsync(CancellationToken cancellationToken)
{
var receivePacket = await _transportStrategy.ReceivePacketAsync(AuthenticationMethod, cancellationToken).ConfigureAwait(false);
if (receivePacket == null)
{
throw new MqttCommunicationException("The client closed the connection.");
}
return new MqttClientPartialAuthenticationResponse
{
ReasonCode = receivePacket.ReasonCode,
AuthenticationData = receivePacket.AuthenticationData,
UserProperties = receivePacket.UserProperties
};
}
public Task SendAuthenticationDataAsync(
MqttAuthenticateReasonCode reasonCode,
byte[] authenticationData = null,
string reasonString = null,
List<MqttUserProperty> userProperties = null,
CancellationToken cancellationToken = default)
{
// The authentication method will never change so we must use the already known one [MQTT-4.12.0-5].
//var authPacket = MqttPacketFactories.Auth.Create(AuthenticationMethod, reasonCode);
var authPacket = new MqttAuthPacket
{
AuthenticationMethod = AuthenticationMethod,
AuthenticationData = authenticationData,
UserProperties = userProperties,
ReasonCode = reasonCode,
ReasonString = reasonString
};
return _transportStrategy.SendPacketAsync(authPacket, cancellationToken);
}
}
}
\ No newline at end of file
......@@ -7,18 +7,22 @@ using MQTTnet.Packets;
namespace MQTTnet.Client
{
public sealed class MqttExtendedAuthenticationExchangeResponse
public sealed class MqttReAuthenticationOptions
{
/// <summary>
/// Gets or sets the authentication data.
/// Authentication data is binary information used to transmit multiple iterations of cryptographic secrets of protocol
/// steps.
/// The content of the authentication data is highly dependent on the specific implementation of the authentication
/// method.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public byte[] AuthenticationData { get; set; }
/// <summary>
/// Gets or sets the user properties which will be sent to the server.
/// Gets or sets the user properties.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public List<MqttUserProperty> UserProperties { get; set; } = new List<MqttUserProperty>();
public List<MqttUserProperty> UserProperties { get; set; }
}
}
\ No newline at end of file
......@@ -11,110 +11,112 @@ namespace MQTTnet.Client
public sealed class MqttClientConnectResult
{
/// <summary>
/// Gets the result code.
/// MQTTv5 only.
/// Gets the client identifier which was chosen by the server.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public MqttClientConnectResultCode ResultCode { get; internal set; }
public string AssignedClientIdentifier { get; internal set; }
/// <summary>
/// Gets a value indicating whether a session was already available or not.
/// MQTTv5 only.
/// Gets the authentication data.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public bool IsSessionPresent { get; internal set; }
public byte[] AuthenticationData { get; internal set; }
/// <summary>
/// Gets a value indicating whether wildcards can be used in subscriptions at the current server.
/// MQTTv5 only.
/// Gets the authentication method.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public bool WildcardSubscriptionAvailable { get; internal set; }
public string AuthenticationMethod { get; internal set; }
/// <summary>
/// Gets whether the server supports retained messages.
/// MQTTv5 only.
/// Gets a value indicating whether a session was already available or not.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public bool RetainAvailable { get; internal set; }
public bool IsSessionPresent { get; internal set; }
/// <summary>
/// Gets the client identifier which was chosen by the server.
/// MQTTv5 only.
/// </summary>
public string AssignedClientIdentifier { get; internal set; }
public uint? MaximumPacketSize { get; internal set; }
/// <summary>
/// Gets the authentication method.
/// MQTTv5 only.
/// Gets the maximum QoS which is supported by the server.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public string AuthenticationMethod { get; internal set; }
public MqttQualityOfServiceLevel MaximumQoS { get; internal set; }
/// <summary>
/// Gets the authentication data.
/// MQTTv5 only.
/// Gets the reason string.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public byte[] AuthenticationData { get; internal set; }
public string ReasonString { get; internal set; }
public uint? MaximumPacketSize { get; internal set; }
public ushort? ReceiveMaximum { get; internal set; }
/// <summary>
/// Gets the reason string.
/// MQTTv5 only.
/// Gets the response information.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public string ReasonString { get; internal set; }
public string ResponseInformation { get; internal set; }
public ushort? ReceiveMaximum { get; internal set; }
/// <summary>
/// Gets the maximum QoS which is supported by the server.
/// MQTTv5 only.
/// Gets the result code.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public MqttQualityOfServiceLevel MaximumQoS { get; internal set; }
public MqttClientConnectResultCode ResultCode { get; internal set; }
/// <summary>
/// Gets the response information.
/// MQTTv5 only.
/// Gets whether the server supports retained messages.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public string ResponseInformation { get; internal set; }
public bool RetainAvailable { get; internal set; }
/// <summary>
/// Gets the maximum value for a topic alias. 0 means not supported.
/// MQTTv5 only.
/// MQTTv5 only.
/// Gets the keep alive interval which was chosen by the server instead of the
/// keep alive interval from the client CONNECT packet.
/// A value of 0 indicates that the feature is not used.
/// </summary>
public ushort TopicAliasMaximum { get; internal set; }
public ushort ServerKeepAlive { get; internal set; }
/// <summary>
/// Gets an alternate server which should be used instead of the current one.
/// MQTTv5 only.
/// Gets an alternate server which should be used instead of the current one.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public string ServerReference { get; internal set; }
public uint? SessionExpiryInterval { get; internal set; }
/// <summary>
/// MQTTv5 only.
/// Gets the keep alive interval which was chosen by the server instead of the
/// keep alive interval from the client CONNECT packet.
/// A value of 0 indicates that the feature is not used.
/// Gets a value indicating whether the shared subscriptions are available or not.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public ushort ServerKeepAlive { get; internal set; }
public uint? SessionExpiryInterval { get; internal set; }
public bool SharedSubscriptionAvailable { get; internal set; }
/// <summary>
/// Gets a value indicating whether the subscription identifiers are available or not.
/// MQTTv5 only.
/// Gets a value indicating whether the subscription identifiers are available or not.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public bool SubscriptionIdentifiersAvailable { get; internal set; }
/// <summary>
/// Gets a value indicating whether the shared subscriptions are available or not.
/// MQTTv5 only.
/// Gets the maximum value for a topic alias. 0 means not supported.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public bool SharedSubscriptionAvailable { get; internal set; }
public ushort TopicAliasMaximum { get; internal set; }
/// <summary>
/// Gets the user properties.
/// In MQTT 5, user properties are basic UTF-8 string key-value pairs that you can append to almost every type of MQTT packet.
/// As long as you don’t exceed the maximum message size, you can use an unlimited number of user properties to add metadata to MQTT messages and pass information between publisher, broker, and subscriber.
/// The feature is very similar to the HTTP header concept.
/// MQTTv5 only.
/// Gets the user properties.
/// In MQTT 5, user properties are basic UTF-8 string key-value pairs that you can append to almost every type of MQTT
/// packet.
/// As long as you don’t exceed the maximum message size, you can use an unlimited number of user properties to add
/// metadata to MQTT messages and pass information between publisher, broker, and subscriber.
/// The feature is very similar to the HTTP header concept.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public List<MqttUserProperty> UserProperties { get; internal set; }
/// <summary>
/// Gets a value indicating whether wildcards can be used in subscriptions at the current server.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public bool WildcardSubscriptionAvailable { get; internal set; }
}
}
}
\ No newline at end of file
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using MQTTnet.Packets;
using MQTTnet.Protocol;
namespace MQTTnet.Client
{
public class MqttExtendedAuthenticationExchangeData
{
/// <summary>
/// Gets or sets the reason code.
/// Hint: MQTT 5 feature only.
/// </summary>
public MqttAuthenticateReasonCode ReasonCode { get; set; }
/// <summary>
/// Gets or sets the reason string.
/// Hint: MQTT 5 feature only.
/// </summary>
public string ReasonString { get; set; }
/// <summary>
/// Gets or sets the authentication data.
/// Authentication data is binary information used to transmit multiple iterations of cryptographic secrets of protocol steps.
/// The content of the authentication data is highly dependent on the specific implementation of the authentication method.
/// Hint: MQTT 5 feature only.
/// </summary>
public byte[] AuthenticationData { get; set; }
/// <summary>
/// Gets or sets the user properties.
/// In MQTT 5, user properties are basic UTF-8 string key-value pairs that you can append to almost every type of MQTT packet.
/// As long as you don’t exceed the maximum message size, you can use an unlimited number of user properties to add metadata to MQTT messages and pass information between publisher, broker, and subscriber.
/// The feature is very similar to the HTTP header concept.
/// Hint: MQTT 5 feature only.
/// </summary>
public List<MqttUserProperty> UserProperties { get; }
}
}
\ No newline at end of file
......@@ -15,6 +15,8 @@ namespace MQTTnet.Client
event Func<MqttClientDisconnectedEventArgs, Task> DisconnectedAsync;
event Func<MqttExtendedAuthenticationExchangeEventArgs, Task> ExtendedAuthenticationExchangeAsync;
event Func<InspectMqttPacketEventArgs, Task> InspectPacketAsync;
bool IsConnected { get; }
......@@ -27,10 +29,10 @@ namespace MQTTnet.Client
Task PingAsync(CancellationToken cancellationToken = default);
Task SendExtendedAuthenticationExchangeDataAsync(MqttExtendedAuthenticationExchangeData data, CancellationToken cancellationToken = default);
Task<MqttClientPublishResult> PublishAsync(MqttApplicationMessage applicationMessage, CancellationToken cancellationToken = default);
Task ReAuthenticateAsync(MqttReAuthenticationOptions options, CancellationToken cancellationToken = default);
Task<MqttClientSubscribeResult> SubscribeAsync(MqttClientSubscribeOptions options, CancellationToken cancellationToken = default);
Task<MqttClientUnsubscribeResult> UnsubscribeAsync(MqttClientUnsubscribeOptions options, CancellationToken cancellationToken = default);
......
......@@ -15,9 +15,65 @@ using MQTTnet.Protocol;
namespace MQTTnet.Client.Internal
{
public sealed class MqttInitialAuthenticationStrategy : IMqttAuthenticationTransportStrategy
{
readonly IMqttChannelAdapter _channelAdapter;
public MqttInitialAuthenticationStrategy(IMqttChannelAdapter channelAdapter)
{
_channelAdapter = channelAdapter ?? throw new ArgumentNullException(nameof(channelAdapter));
}
public Task SendPacketAsync(MqttAuthPacket authPacket, CancellationToken cancellationToken)
{
if (authPacket == null)
{
throw new ArgumentNullException(nameof(authPacket));
}
return _channelAdapter.SendPacketAsync(authPacket, cancellationToken);
}
public async Task<MqttAuthPacket> ReceivePacketAsync(string authenticationMethod, CancellationToken cancellationToken)
{
if (authenticationMethod == null)
{
throw new ArgumentNullException(nameof(authenticationMethod));
}
var receivePacket = await _channelAdapter.ReceivePacketAsync(cancellationToken).ConfigureAwait(false);
if (receivePacket is MqttAuthPacket authPacket)
{
if (!string.Equals(authPacket.AuthenticationMethod, authenticationMethod, StringComparison.Ordinal))
{
throw new MqttProtocolViolationException("The authentication method is not allowed to change while authenticating.");
}
return authPacket;
}
if (receivePacket == null)
{
throw new MqttCommunicationException("The server closed the connection.");
}
throw new MqttProtocolViolationException("Expected an AUTH packet from the client.");
}
}
public interface IMqttAuthenticationTransportStrategy
{
Task SendPacketAsync(MqttAuthPacket authPacket, CancellationToken cancellationToken);
Task<MqttAuthPacket> ReceivePacketAsync(string authenticationMethod, CancellationToken cancellationToken);
}
public sealed class MqttClientAuthenticationHandler
{
readonly MqttClient _client;
readonly AsyncEvent<MqttExtendedAuthenticationExchangeEventArgs> _extendedAuthenticationExchangeEvent = new AsyncEvent<MqttExtendedAuthenticationExchangeEventArgs>();
readonly MqttNetSourceLogger _logger;
public MqttClientAuthenticationHandler(MqttClient client, IMqttNetLogger logger)
......@@ -31,7 +87,10 @@ namespace MQTTnet.Client.Internal
_logger = logger.WithSource(nameof(MqttClientAuthenticationHandler));
}
public AsyncEvent<MqttExtendedAuthenticationExchangeEventArgs> ExtendedAuthenticationExchangeEvent { get; } = new AsyncEvent<MqttExtendedAuthenticationExchangeEventArgs>();
public void AddHandler(Func<MqttExtendedAuthenticationExchangeEventArgs, Task> handler)
{
_extendedAuthenticationExchangeEvent.AddHandler(handler);
}
public async Task<MqttClientConnectResult> Authenticate(IMqttChannelAdapter channelAdapter, MqttClientOptions options, CancellationToken cancellationToken)
{
......@@ -64,25 +123,43 @@ namespace MQTTnet.Client.Internal
{
var receivedPacket = await channelAdapter.ReceivePacketAsync(cancellationToken).ConfigureAwait(false);
if (receivedPacket is MqttAuthPacket authPacket)
if (receivedPacket is MqttAuthPacket initialAuthPacket)
{
// MQTT v3.1.1 cannot send an AUTH packet.
// If the Server requires additional information to complete the authentication, it can send an AUTH packet to the Client.
// This packet MUST contain a Reason Code of 0x18 (Continue authentication)
if (authPacket.ReasonCode != MqttAuthenticateReasonCode.ContinueAuthentication)
if (initialAuthPacket.ReasonCode != MqttAuthenticateReasonCode.ContinueAuthentication)
{
throw new MqttProtocolViolationException("Wrong reason code received [MQTT-4.12.0-2].");
}
var response = await OnExtendedAuthentication(authPacket).ConfigureAwait(false);
authPacket = MqttPacketFactories.Auth.Create(authPacket, response);
await _client.Send(authPacket, cancellationToken).ConfigureAwait(false);
//var response = await HandleExtendedAuthentication(authPacket, channelAdapter).ConfigureAwait(false);
await HandleExtendedAuthentication(initialAuthPacket, new MqttInitialAuthenticationStrategy(channelAdapter)).ConfigureAwait(false);
}
else if (receivedPacket is MqttConnAckPacket connAckPacketBuffer)
{
connAckPacket = connAckPacketBuffer;
// The CONNACK packet is the last packet when authenticating so there is no further AUTH packet allowed!
break;
}
else
{
throw new MqttProtocolViolationException($"Received {receivedPacket.GetRfcName()} while authenticating.");
}
}
cancellationToken.ThrowIfCancellationRequested();
// If the initial CONNECT packet included an Authentication Method property then all AUTH packets,
// and any successful CONNACK packet MUST include an Authentication Method Property with the same
// value as in the CONNECT packet [MQTT-4.12.0-5].
if (!string.IsNullOrEmpty(connectPacket.AuthenticationMethod))
{
if (!string.Equals(connectPacket.AuthenticationMethod, connAckPacket?.AuthenticationMethod))
{
throw new MqttProtocolViolationException("The CONNACK packet does not have the same authentication method as the CONNECT packet.");
}
}
result = MqttClientResultFactory.ConnectResult.Create(connAckPacket, channelAdapter.PacketFormatterAdapter.ProtocolVersion);
......@@ -97,8 +174,9 @@ namespace MQTTnet.Client.Internal
// did send a proper ACK packet.
if (options.ThrowOnNonSuccessfulResponseFromServer)
{
_logger.Warning("Client will now throw an _MqttConnectingFailedException_. This is obsolete and will be removed in the future. Consider setting _ThrowOnNonSuccessfulResponseFromServer=False_ in client options.");
_logger.Warning(
"Client will now throw an _MqttConnectingFailedException_. This is obsolete and will be removed in the future. Consider setting _ThrowOnNonSuccessfulResponseFromServer=False_ in client options.");
if (result.ResultCode != MqttClientConnectResultCode.Success)
{
throw new MqttConnectingFailedException($"Connecting with MQTT server failed ({result.ResultCode}).", null, result);
......@@ -110,16 +188,38 @@ namespace MQTTnet.Client.Internal
return result;
}
async Task<MqttExtendedAuthenticationExchangeResponse> OnExtendedAuthentication(MqttAuthPacket authPacket)
public Task HandleExtendedAuthentication(MqttAuthPacket initialAuthPacket, IMqttAuthenticationTransportStrategy transportStrategy)
{
ValidateEventHandler();
var eventArgs = new MqttExtendedAuthenticationExchangeEventArgs(initialAuthPacket, transportStrategy);
return _extendedAuthenticationExchangeEvent.InvokeAsync(eventArgs);
}
void ValidateEventHandler()
{
if (!ExtendedAuthenticationExchangeEvent.HasHandlers)
if (!_extendedAuthenticationExchangeEvent.HasHandlers)
{
throw new InvalidOperationException("Cannot handle extended authentication without attached event handler.");
}
}
public void RemoveHandler(Func<MqttExtendedAuthenticationExchangeEventArgs, Task> handler)
{
_extendedAuthenticationExchangeEvent.RemoveHandler(handler);
}
var eventArgs = new MqttExtendedAuthenticationExchangeEventArgs(authPacket);
await ExtendedAuthenticationExchangeEvent.InvokeAsync(eventArgs).ConfigureAwait(false);
return eventArgs.Response;
public async Task ReAuthenticate(MqttReAuthenticationOptions options, CancellationToken cancellationToken)
{
ValidateEventHandler();
var authPacket = MqttPacketFactories.Auth.CreateReAuthenticationPacket(_client.Options, options);
// Sending the AUTH packet will trigger the re authentication and the event handler will
// handle the details.
await _client.Send(authPacket, cancellationToken).ConfigureAwait(false);
}
}
}
\ No newline at end of file
......@@ -29,7 +29,7 @@ namespace MQTTnet.Client
readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider();
readonly IMqttNetLogger _rootLogger;
IMqttChannelAdapter _adapter;
IMqttChannelAdapter _channelAdapter;
internal bool _cleanDisconnectInitiated;
......@@ -82,8 +82,8 @@ namespace MQTTnet.Client
public event Func<MqttExtendedAuthenticationExchangeEventArgs, Task> ExtendedAuthenticationExchangeAsync
{
add => _authenticationHandler.ExtendedAuthenticationExchangeEvent.AddHandler(value);
remove => _authenticationHandler.ExtendedAuthenticationExchangeEvent.RemoveHandler(value);
add => _authenticationHandler.AddHandler(value);
remove => _authenticationHandler.RemoveHandler(value);
}
public event Func<InspectMqttPacketEventArgs, Task> InspectPacketAsync
......@@ -125,7 +125,7 @@ namespace MQTTnet.Client
_clientAlive = new CancellationTokenSource();
var adapter = _adapterFactory.CreateClientAdapter(options, new MqttPacketInspector(_events.InspectPacketEvent, _rootLogger), _rootLogger);
_adapter = adapter;
_channelAdapter = adapter;
if (cancellationToken.CanBeCanceled)
{
......@@ -194,7 +194,7 @@ namespace MQTTnet.Client
if (Options.ValidateFeatures)
{
MqttClientDisconnectOptionsValidator.ThrowIfNotSupported(options, _adapter.PacketFormatterAdapter.ProtocolVersion);
MqttClientDisconnectOptionsValidator.ThrowIfNotSupported(options, _channelAdapter.PacketFormatterAdapter.ProtocolVersion);
}
// Sending the DISCONNECT may fail due to connection issues. The resulting exception
......@@ -245,7 +245,7 @@ namespace MQTTnet.Client
if (Options.ValidateFeatures)
{
MqttApplicationMessageValidator.ThrowIfNotSupported(applicationMessage, _adapter.PacketFormatterAdapter.ProtocolVersion);
MqttApplicationMessageValidator.ThrowIfNotSupported(applicationMessage, _channelAdapter.PacketFormatterAdapter.ProtocolVersion);
}
var publishPacket = MqttPacketFactories.Publish.Create(applicationMessage);
......@@ -271,26 +271,25 @@ namespace MQTTnet.Client
}
}
public Task SendExtendedAuthenticationExchangeDataAsync(MqttExtendedAuthenticationExchangeData data, CancellationToken cancellationToken = default)
public Task ReAuthenticateAsync(MqttReAuthenticationOptions options, CancellationToken cancellationToken = default)
{
if (data == null)
if (options == null)
{
throw new ArgumentNullException(nameof(data));
throw new ArgumentNullException(nameof(options));
}
ThrowIfDisposed();
ThrowIfNotConnected();
var authPacket = new MqttAuthPacket
if (cancellationToken.CanBeCanceled)
{
// This must always be equal to the value from the CONNECT packet. So we use it here to ensure that.
AuthenticationMethod = Options.AuthenticationMethod,
AuthenticationData = data.AuthenticationData,
ReasonString = data.ReasonString,
UserProperties = data.UserProperties
};
return _authenticationHandler.ReAuthenticate(options, cancellationToken);
}
return Send(authPacket, cancellationToken);
using (var timeout = new CancellationTokenSource(Options.Timeout))
{
return _authenticationHandler.ReAuthenticate(options, timeout.Token);
}
}
public async Task<MqttClientSubscribeResult> SubscribeAsync(MqttClientSubscribeOptions options, CancellationToken cancellationToken = default)
......@@ -307,7 +306,7 @@ namespace MQTTnet.Client
if (Options.ValidateFeatures)
{
MqttClientSubscribeOptionsValidator.ThrowIfNotSupported(options, _adapter.PacketFormatterAdapter.ProtocolVersion);
MqttClientSubscribeOptionsValidator.ThrowIfNotSupported(options, _channelAdapter.PacketFormatterAdapter.ProtocolVersion);
}
ThrowIfDisposed();
......@@ -349,7 +348,7 @@ namespace MQTTnet.Client
if (Options.ValidateFeatures)
{
MqttClientUnsubscribeOptionsValidator.ThrowIfNotSupported(options, _adapter.PacketFormatterAdapter.ProtocolVersion);
MqttClientUnsubscribeOptionsValidator.ThrowIfNotSupported(options, _channelAdapter.PacketFormatterAdapter.ProtocolVersion);
}
var unsubscribePacket = MqttPacketFactories.Unsubscribe.Create(options);
......@@ -389,7 +388,7 @@ namespace MQTTnet.Client
_keepAliveHandler.TrackSentPacket();
return _adapter.SendPacketAsync(packet, cancellationToken);
return _channelAdapter.SendPacketAsync(packet, cancellationToken);
}
protected override void Dispose(bool disposing)
......@@ -446,8 +445,8 @@ namespace MQTTnet.Client
_publishPacketReceiverQueue?.Dispose();
_publishPacketReceiverQueue = null;
_adapter?.Dispose();
_adapter = null;
_channelAdapter?.Dispose();
_channelAdapter = null;
_packetDispatcher?.Dispose();
_packetDispatcher = null;
......@@ -467,11 +466,11 @@ namespace MQTTnet.Client
{
_logger.Verbose("Trying to connect with server '{0}'", Options.ChannelOptions);
await _adapter.ConnectAsync(effectiveCancellationToken.Token).ConfigureAwait(false);
await _channelAdapter.ConnectAsync(effectiveCancellationToken.Token).ConfigureAwait(false);
_logger.Verbose("Connection with server established");
var connectResult = await _authenticationHandler.Authenticate(_adapter, Options, effectiveCancellationToken.Token).ConfigureAwait(false);
var connectResult = await _authenticationHandler.Authenticate(_channelAdapter, Options, effectiveCancellationToken.Token).ConfigureAwait(false);
if (connectResult.ResultCode == MqttClientConnectResultCode.Success)
{
......@@ -494,13 +493,13 @@ namespace MQTTnet.Client
try
{
if (_adapter != null)
if (_channelAdapter != null)
{
_logger.Verbose("Disconnecting [Timeout={0}]", Options.Timeout);
using (var timeout = new CancellationTokenSource(Options.Timeout))
{
await _adapter.DisconnectAsync(timeout.Token).ConfigureAwait(false);
await _channelAdapter.DisconnectAsync(timeout.Token).ConfigureAwait(false);
}
}
......@@ -619,17 +618,6 @@ namespace MQTTnet.Client
return CompletedTask.Instance;
}
Task ProcessReceivedAuthPacket(MqttAuthPacket authPacket)
{
var extendedAuthenticationExchangeHandler = Options.ExtendedAuthenticationExchangeHandler;
if (extendedAuthenticationExchangeHandler != null)
{
return extendedAuthenticationExchangeHandler.HandleRequestAsync(new MqttExtendedAuthenticationExchangeEventArgs(authPacket));
}
return CompletedTask.Instance;
}
Task ProcessReceivedDisconnectPacket(MqttDisconnectPacket disconnectPacket)
{
_disconnectReason = (int)disconnectPacket.ReasonCode;
......@@ -722,7 +710,7 @@ namespace MQTTnet.Client
async Task<MqttPacket> Receive(CancellationToken cancellationToken)
{
var packetTask = _adapter.ReceivePacketAsync(cancellationToken);
var packetTask = _channelAdapter.ReceivePacketAsync(cancellationToken);
MqttPacket packet;
if (packetTask.IsCompleted)
......@@ -796,15 +784,36 @@ namespace MQTTnet.Client
}
}
async Task<TResponsePacket> Request<TResponsePacket>(MqttPacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttPacket
internal async Task<TPacket> Receive<TPacket>(int packetIdentifier, CancellationToken cancellationToken) where TPacket : MqttPacket
{
using (var packetAwaitable = _packetDispatcher.AddAwaitable<TPacket>(packetIdentifier))
{
return await packetAwaitable.WaitOneAsync(cancellationToken);
}
}
internal async Task<TResponsePacket> Request<TResponsePacket>(MqttPacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttPacket
{
cancellationToken.ThrowIfCancellationRequested();
ushort packetIdentifier = 0;
// Some packets are a direct response to another packet but have no packet identifier for correlation!
int packetIdentifier;
if (requestPacket is MqttPacketWithIdentifier packetWithIdentifier)
{
packetIdentifier = packetWithIdentifier.PacketIdentifier;
}
else if (requestPacket is MqttPingReqPacket)
{
packetIdentifier = MqttPacketDispatcher.PingRespPacketFakeIdentifier;
}
else if (requestPacket is MqttAuthPacket)
{
packetIdentifier = MqttPacketDispatcher.AuthPacketFakeIdentifier;
}
else
{
throw new MqttProtocolViolationException($"Expecting a response for packet {requestPacket.GetRfcName()} is not supported.");
}
using (var packetAwaitable = _packetDispatcher.AddAwaitable<TResponsePacket>(packetIdentifier))
{
......@@ -908,23 +917,31 @@ namespace MQTTnet.Client
{
await ProcessReceivedDisconnectPacket(disconnectPacket).ConfigureAwait(false);
}
else if (packet is MqttAuthPacket authPacket)
{
await ProcessReceivedAuthPacket(authPacket).ConfigureAwait(false);
}
// else if (packet is MqttAuthPacket authPacket)
// {
// await _authenticationHandler.HandleExtendedAuthentication(authPacket, new MqttInitialAuthenticationStrategy(_channelAdapter)).ConfigureAwait(false);
// }
else if (packet is MqttPingRespPacket)
{
_packetDispatcher.TryDispatch(packet);
}
else if (packet is MqttPingReqPacket)
{
throw new MqttProtocolViolationException("The PINGREQ Packet is sent from a Client to the Server only.");
throw new MqttProtocolViolationException("The PINGREQ packet can only be sent from client to server.");
}
else if (packet is MqttConnAckPacket)
{
throw new MqttProtocolViolationException("The CONNACK packet can only be send while connecting.");
}
else if (packet is MqttConnectPacket)
{
throw new MqttProtocolViolationException("The CONNECT packet can only be sent from client to server.");
}
else
{
if (!_packetDispatcher.TryDispatch(packet))
{
throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time.");
throw new MqttProtocolViolationException($"Received unexpected {packet} packet.");
}
}
}
......
......@@ -14,6 +14,8 @@ namespace MQTTnet.Client
{
public static class MqttClientExtensions
{
static readonly MqttReAuthenticationOptions EmptyReAuthenticationOptions = new MqttReAuthenticationOptions();
public static Task DisconnectAsync(
this IMqttClient client,
MqttClientDisconnectOptionsReason reason = MqttClientDisconnectOptionsReason.NormalDisconnection,
......@@ -77,6 +79,11 @@ namespace MQTTnet.Client
return mqttClient.PublishBinaryAsync(topic, payloadBuffer, qualityOfServiceLevel, retain, cancellationToken);
}
public static Task ReAuthenticateAsync(this IMqttClient client, CancellationToken cancellationToken = default)
{
return client.ReAuthenticateAsync(EmptyReAuthenticationOptions, cancellationToken);
}
public static Task ReconnectAsync(this IMqttClient client, CancellationToken cancellationToken = default)
{
if (client.Options == null)
......@@ -87,7 +94,7 @@ namespace MQTTnet.Client
return client.ConnectAsync(client.Options, cancellationToken);
}
public static Task<MqttClientSubscribeResult> SubscribeAsync(this IMqttClient mqttClient, MqttTopicFilter topicFilter, CancellationToken cancellationToken = default)
{
if (mqttClient == null)
......
......@@ -53,9 +53,7 @@ namespace MQTTnet.Client
public string ClientId { get; set; } = Guid.NewGuid().ToString("N");
public IMqttClientCredentialsProvider Credentials { get; set; }
public IMqttExtendedAuthenticationExchangeHandler ExtendedAuthenticationExchangeHandler { get; set; }
/// <summary>
/// Gets or sets the keep alive period.
/// The connection is normally left open by the client so that is can send and receive data at any time.
......
......@@ -79,7 +79,7 @@ namespace MQTTnet.Client
return _options;
}
public MqttClientOptionsBuilder WithAuthentication(string method, byte[] data)
public MqttClientOptionsBuilder WithAuthentication(string method, byte[] data = null)
{
_options.AuthenticationMethod = method;
_options.AuthenticationData = data;
......@@ -177,12 +177,6 @@ namespace MQTTnet.Client
return this;
}
public MqttClientOptionsBuilder WithExtendedAuthenticationExchangeHandler(IMqttExtendedAuthenticationExchangeHandler handler)
{
_options.ExtendedAuthenticationExchangeHandler = handler;
return this;
}
public MqttClientOptionsBuilder WithKeepAlivePeriod(TimeSpan value)
{
_options.KeepAlivePeriod = value;
......
......@@ -11,25 +11,26 @@ namespace MQTTnet.Formatter
{
public sealed class MqttAuthPacketFactory
{
public MqttAuthPacket Create(MqttAuthPacket serverAuthPacket, MqttExtendedAuthenticationExchangeResponse response)
public MqttAuthPacket CreateReAuthenticationPacket(MqttClientOptions clientOptions, MqttReAuthenticationOptions reAuthenticationOptions)
{
if (serverAuthPacket == null)
if (clientOptions == null)
{
throw new ArgumentNullException(nameof(serverAuthPacket));
throw new ArgumentNullException(nameof(clientOptions));
}
if (response == null)
if (reAuthenticationOptions == null)
{
throw new ArgumentNullException(nameof(response));
throw new ArgumentNullException(nameof(reAuthenticationOptions));
}
return new MqttAuthPacket
{
ReasonCode = MqttAuthenticateReasonCode.ContinueAuthentication,
ReasonString = null,
AuthenticationMethod = serverAuthPacket.AuthenticationMethod,
AuthenticationData = response.AuthenticationData,
UserProperties = response.UserProperties
// The authentication method cannot change and must be always the same as long as the client is connected.
AuthenticationMethod = clientOptions.AuthenticationMethod,
ReasonCode = MqttAuthenticateReasonCode.ReAuthenticate,
AuthenticationData = reAuthenticationOptions.AuthenticationData,
UserProperties = reAuthenticationOptions.UserProperties,
ReasonString = null
};
}
}
......
......@@ -59,6 +59,8 @@ namespace MQTTnet.Formatter
return MqttConnectReturnCode.ConnectionAccepted;
}
case MqttConnectReasonCode.BadAuthenticationMethod:
case MqttConnectReasonCode.Banned:
case MqttConnectReasonCode.NotAuthorized:
{
return MqttConnectReturnCode.ConnectionRefusedNotAuthorized;
......@@ -73,12 +75,13 @@ namespace MQTTnet.Formatter
{
return MqttConnectReturnCode.ConnectionRefusedIdentifierRejected;
}
case MqttConnectReasonCode.UnsupportedProtocolVersion:
{
return MqttConnectReturnCode.ConnectionRefusedUnacceptableProtocolVersion;
}
case MqttConnectReasonCode.UseAnotherServer:
case MqttConnectReasonCode.ServerUnavailable:
case MqttConnectReasonCode.ServerBusy:
case MqttConnectReasonCode.ServerMoved:
......@@ -88,7 +91,8 @@ namespace MQTTnet.Formatter
default:
{
throw new MqttProtocolViolationException("Unable to convert connect reason code (MQTTv5) to return code (MQTTv3).");
// This is the most best matching value.
return MqttConnectReturnCode.ConnectionRefusedUnacceptableProtocolVersion;
}
}
}
......
<wpf:ResourceDictionary xml:space="preserve" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s="clr-namespace:System;assembly=mscorlib" xmlns:ss="urn:shemas-jetbrains-com:settings-storage-xaml" xmlns:wpf="http://schemas.microsoft.com/winfx/2006/xaml/presentation">
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=client_005Cauthentication/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=client_005Cconnecting/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=client_005Cdiagnostics/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=client_005Cdisconnecting/@EntryIndexedValue">True</s:Boolean>
......
......@@ -9,7 +9,7 @@ namespace MQTTnet.PacketDispatcher
{
public interface IMqttPacketAwaitable : IDisposable
{
MqttPacketAwaitableFilter Filter { get; }
int PacketIdentifier { get; }
void Complete(MqttPacket packet);
......
......@@ -16,18 +16,13 @@ namespace MQTTnet.PacketDispatcher
readonly AsyncTaskCompletionSource<MqttPacket> _promise = new AsyncTaskCompletionSource<MqttPacket>();
readonly MqttPacketDispatcher _owningPacketDispatcher;
public MqttPacketAwaitable(ushort packetIdentifier, MqttPacketDispatcher owningPacketDispatcher)
public MqttPacketAwaitable(int packetPacketIdentifier, MqttPacketDispatcher owningPacketDispatcher)
{
Filter = new MqttPacketAwaitableFilter
{
Type = typeof(TPacket),
Identifier = packetIdentifier
};
PacketIdentifier = packetPacketIdentifier;
_owningPacketDispatcher = owningPacketDispatcher ?? throw new ArgumentNullException(nameof(owningPacketDispatcher));
}
public MqttPacketAwaitableFilter Filter { get; }
public int PacketIdentifier { get; }
public async Task<TPacket> WaitOneAsync(CancellationToken cancellationToken)
{
......
......@@ -10,11 +10,15 @@ namespace MQTTnet.PacketDispatcher
{
public sealed class MqttPacketDispatcher : IDisposable
{
public const int PingRespPacketFakeIdentifier = 2147483647;
public const int AuthPacketFakeIdentifier = 2147483646;
public const int SuccessAuthPacketFakeIdentifier = 2147483647;
readonly List<IMqttPacketAwaitable> _waiters = new List<IMqttPacketAwaitable>();
bool _isDisposed;
public MqttPacketAwaitable<TResponsePacket> AddAwaitable<TResponsePacket>(ushort packetIdentifier) where TResponsePacket : MqttPacket
public MqttPacketAwaitable<TResponsePacket> AddAwaitable<TResponsePacket>(int packetIdentifier) where TResponsePacket : MqttPacket
{
var awaitable = new MqttPacketAwaitable<TResponsePacket>(packetIdentifier, this);
......@@ -99,15 +103,21 @@ namespace MQTTnet.PacketDispatcher
throw new ArgumentNullException(nameof(packet));
}
ushort identifier = 0;
var identifier = 0;
if (packet is MqttPacketWithIdentifier packetWithIdentifier)
{
identifier = packetWithIdentifier.PacketIdentifier;
}
var packetType = packet.GetType();
else if (packet is MqttPingRespPacket)
{
identifier = PingRespPacketFakeIdentifier;
}
else
{
throw new InvalidOperationException($"Cannot dispatch {packet.GetRfcName()} packet.");
}
var waiters = new List<IMqttPacketAwaitable>();
lock (_waiters)
{
ThrowIfDisposed();
......@@ -115,11 +125,8 @@ namespace MQTTnet.PacketDispatcher
for (var i = _waiters.Count - 1; i >= 0; i--)
{
var entry = _waiters[i];
// Note: The PingRespPacket will also arrive here and has NO identifier but there
// is code which waits for it. So the code must be able to deal with filters which
// are referring to the type only (identifier is 0)!
if (entry.Filter.Type != packetType || entry.Filter.Identifier != identifier)
if (entry.PacketIdentifier != identifier)
{
continue;
}
......
......@@ -28,7 +28,7 @@ namespace MQTTnet.Packets
public MqttQualityOfServiceLevel MaximumQoS { get; set; }
/// <summary>
/// Added in MQTTv5.
/// The return code is only used in MQTTv 5.0.0 and higher.
/// </summary>
public MqttConnectReasonCode ReasonCode { get; set; }
......@@ -39,13 +39,16 @@ namespace MQTTnet.Packets
public string ResponseInformation { get; set; }
public bool RetainAvailable { get; set; }
/// <summary>
/// The return code is only used in MQTTv 3.1.1 and lower.
/// </summary>
public MqttConnectReturnCode ReturnCode { get; set; }
public ushort ServerKeepAlive { get; set; }
public string ServerReference { get; set; }
public uint SessionExpiryInterval { get; set; }
public bool SharedSubscriptionAvailable { get; set; }
......
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
namespace MQTTnet.Packets
{
public sealed class MqttUserPropertiesBuilder
{
readonly List<MqttUserProperty> _properties = new List<MqttUserProperty>();
public List<MqttUserProperty> Build()
{
return _properties;
}
public MqttUserPropertiesBuilder With(string name, string value = "")
{
_properties.Add(new MqttUserProperty(name, value));
return this;
}
}
}
\ No newline at end of file
......@@ -4,12 +4,10 @@
using System;
namespace MQTTnet.PacketDispatcher
namespace MQTTnet.Server
{
public sealed class MqttPacketAwaitableFilter
public sealed class ClientReAuthenticatingEventArgs : EventArgs
{
public Type Type { get; set; }
public ushort Identifier { get; set; }
}
}
\ No newline at end of file
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using MQTTnet.Packets;
using MQTTnet.Protocol;
namespace MQTTnet.Server
{
public sealed class PartialAuthenticationResponse
{
public byte[] AuthenticationData { get; set; }
public MqttAuthenticateReasonCode ReasonCode { get; set; }
public List<MqttUserProperty> UserProperties { get; set; }
}
}
\ No newline at end of file
......@@ -7,7 +7,10 @@ using System.Collections;
using System.Collections.Generic;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Adapter;
using MQTTnet.Exceptions;
using MQTTnet.Formatter;
using MQTTnet.Internal;
using MQTTnet.Packets;
......@@ -19,11 +22,12 @@ namespace MQTTnet.Server
{
readonly MqttConnectPacket _connectPacket;
public ValidatingConnectionEventArgs(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter, IDictionary sessionItems)
public ValidatingConnectionEventArgs(MqttConnectPacket connectPacket, IMqttChannelAdapter clientAdapter, IDictionary sessionItems, CancellationToken cancellationToken)
{
_connectPacket = connectPacket ?? throw new ArgumentNullException(nameof(connectPacket));
ChannelAdapter = clientAdapter ?? throw new ArgumentNullException(nameof(clientAdapter));
SessionItems = sessionItems ?? throw new ArgumentNullException(nameof(sessionItems));
CancellationToken = cancellationToken;
}
/// <summary>
......@@ -44,6 +48,11 @@ namespace MQTTnet.Server
/// </summary>
public string AuthenticationMethod => _connectPacket.AuthenticationMethod;
/// <summary>
/// Gets a cancellation token which is being canceled as soon as the transport connection is closed.
/// </summary>
public CancellationToken CancellationToken { get; }
/// <summary>
/// Gets the channel adapter. This can be a _MqttConnectionContext_ (used in ASP.NET), a _MqttChannelAdapter_ (used for
/// TCP or WebSockets) or a custom implementation.
......@@ -189,5 +198,51 @@ namespace MQTTnet.Server
/// A value of 0 indicates that the value is not used.
/// </summary>
public uint WillDelayInterval => _connectPacket.WillDelayInterval;
public async Task<PartialAuthenticationResponse> ReceiveAuthenticationDataAsync(CancellationToken cancellationToken = default)
{
var receivePacket = await ChannelAdapter.ReceivePacketAsync(cancellationToken).ConfigureAwait(false);
if (receivePacket is MqttAuthPacket authPacket)
{
if (!string.Equals(authPacket.AuthenticationMethod, AuthenticationMethod, StringComparison.Ordinal))
{
throw new MqttProtocolViolationException("The authentication method is not allowed to change while authenticating.");
}
return new PartialAuthenticationResponse
{
ReasonCode = authPacket.ReasonCode,
AuthenticationData = authPacket.AuthenticationData,
UserProperties = authPacket.UserProperties
};
}
if (receivePacket == null)
{
throw new MqttCommunicationException("The client closed the connection.");
}
throw new MqttProtocolViolationException("Expected an AUTH packet from the client.");
}
public Task SendAuthenticationDataAsync(
MqttAuthenticateReasonCode reasonCode,
byte[] authenticationData = null,
string reasonString = null,
List<MqttUserProperty> userProperties = null,
CancellationToken cancellationToken = default)
{
// The authentication method will never change so we must use the already known one [MQTT-4.12.0-5].
var authPacket = new MqttAuthPacket
{
AuthenticationMethod = AuthenticationMethod,
AuthenticationData = authenticationData,
UserProperties = userProperties,
ReasonCode = reasonCode,
ReasonString = reasonString
};
return ChannelAdapter.SendPacketAsync(authPacket, cancellationToken);
}
}
}
\ No newline at end of file
......@@ -95,8 +95,9 @@ namespace MQTTnet.Server
{
var cancellationToken = _cancellationToken.Token;
IsRunning = true;
_ = Task.Factory.StartNew(() => SendPacketsLoop(cancellationToken), cancellationToken, TaskCreationOptions.PreferFairness, TaskScheduler.Default).ConfigureAwait(false);
_ = Task.Factory.StartNew(() => SendPacketsLoop(cancellationToken), cancellationToken, TaskCreationOptions.PreferFairness, TaskScheduler.Default)
.ConfigureAwait(false);
await ReceivePackagesLoop(cancellationToken).ConfigureAwait(false);
}
......@@ -176,6 +177,22 @@ namespace MQTTnet.Server
return CompletedTask.Instance;
}
async Task HandleIncomingAuthPacket(MqttAuthPacket authPacket)
{
if (authPacket.ReasonCode != MqttAuthenticateReasonCode.ReAuthenticate)
{
throw new MqttProtocolViolationException("Clients are only allowed to send AUTH packets after connecting for re authentication.");
}
// We have to handle the re authentication in another task because we rely on the package retrieval.
_ = Task.Run(
async () =>
{
var eventArgs = new ClientReAuthenticatingEventArgs();
await _eventContainer.ClientReAuthenticatingEvent.InvokeAsync(eventArgs);
});
}
void HandleIncomingPingReqPacket()
{
// See: The Server MUST send a PINGRESP packet in response to a PINGREQ packet [MQTT-3.12.4-1].
......@@ -436,9 +453,13 @@ namespace MQTTnet.Server
DisconnectPacket = disconnectPacket;
return;
}
else if (currentPacket is MqttAuthPacket authPacket)
{
await HandleIncomingAuthPacket(authPacket);
}
else
{
throw new MqttProtocolViolationException("Packet not allowed");
throw new MqttProtocolViolationException("Packet not allowed.");
}
}
}
......
......@@ -360,7 +360,7 @@ namespace MQTTnet.Server
return;
}
var validatingConnectionEventArgs = await ValidateConnection(connectPacket, channelAdapter).ConfigureAwait(false);
var validatingConnectionEventArgs = await ValidateConnection(connectPacket, channelAdapter, cancellationToken).ConfigureAwait(false);
var connAckPacket = MqttPacketFactories.ConnAck.Create(validatingConnectionEventArgs);
if (validatingConnectionEventArgs.ReasonCode != MqttConnectReasonCode.Success)
......@@ -397,7 +397,7 @@ namespace MQTTnet.Server
}
catch (Exception exception)
{
_logger.Error(exception, exception.Message);
_logger.Error(exception, $"Error while validating client {channelAdapter.Endpoint}. {exception.Message}");
}
finally
{
......@@ -719,11 +719,11 @@ namespace MQTTnet.Server
return null;
}
async Task<ValidatingConnectionEventArgs> ValidateConnection(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter)
async Task<ValidatingConnectionEventArgs> ValidateConnection(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken)
{
// TODO: Load session items from persisted sessions in the future.
var sessionItems = new ConcurrentDictionary<object, object>();
var eventArgs = new ValidatingConnectionEventArgs(connectPacket, channelAdapter, sessionItems);
var eventArgs = new ValidatingConnectionEventArgs(connectPacket, channelAdapter, sessionItems, cancellationToken);
await _eventContainer.ValidatingConnectionEvent.InvokeAsync(eventArgs).ConfigureAwait(false);
// Check the client ID and set a random one if supported.
......
......@@ -17,11 +17,14 @@ namespace MQTTnet.Server
public AsyncEvent<ClientDisconnectedEventArgs> ClientDisconnectedEvent { get; } = new AsyncEvent<ClientDisconnectedEventArgs>();
public AsyncEvent<ClientReAuthenticatingEventArgs> ClientReAuthenticatingEvent { get; } = new AsyncEvent<ClientReAuthenticatingEventArgs>();
public AsyncEvent<ClientSubscribedTopicEventArgs> ClientSubscribedTopicEvent { get; } = new AsyncEvent<ClientSubscribedTopicEventArgs>();
public AsyncEvent<ClientUnsubscribedTopicEventArgs> ClientUnsubscribedTopicEvent { get; } = new AsyncEvent<ClientUnsubscribedTopicEventArgs>();
public AsyncEvent<InterceptingClientApplicationMessageEnqueueEventArgs> InterceptingClientEnqueueEvent { get; } = new AsyncEvent<InterceptingClientApplicationMessageEnqueueEventArgs>();
public AsyncEvent<InterceptingClientApplicationMessageEnqueueEventArgs> InterceptingClientEnqueueEvent { get; } =
new AsyncEvent<InterceptingClientApplicationMessageEnqueueEventArgs>();
public AsyncEvent<InterceptingPacketEventArgs> InterceptingInboundPacketEvent { get; } = new AsyncEvent<InterceptingPacketEventArgs>();
......
......@@ -73,6 +73,12 @@ namespace MQTTnet.Server
remove => _eventContainer.ClientDisconnectedEvent.RemoveHandler(value);
}
public event Func<ClientReAuthenticatingEventArgs, Task> ClientReAuthenticatingAsync
{
add => _eventContainer.ClientReAuthenticatingEvent.AddHandler(value);
remove => _eventContainer.ClientReAuthenticatingEvent.RemoveHandler(value);
}
public event Func<ClientSubscribedTopicEventArgs, Task> ClientSubscribedTopicAsync
{
add => _eventContainer.ClientSubscribedTopicEvent.AddHandler(value);
......@@ -203,13 +209,6 @@ namespace MQTTnet.Server
return _clientSessionsManager.GetClientStatusesAsync();
}
public Task<IList<MqttApplicationMessage>> GetRetainedMessagesAsync()
{
ThrowIfNotStarted();
return _retainedMessagesManager.GetMessages();
}
public Task<MqttApplicationMessage> GetRetainedMessageAsync(string topic)
{
if (topic == null)
......@@ -222,6 +221,13 @@ namespace MQTTnet.Server
return _retainedMessagesManager.GetMessage(topic);
}
public Task<IList<MqttApplicationMessage>> GetRetainedMessagesAsync()
{
ThrowIfNotStarted();
return _retainedMessagesManager.GetMessages();
}
public Task<IList<MqttSessionStatus>> GetSessionsAsync()
{
ThrowIfNotStarted();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册