提交 ae4bdd62 编写于 作者: 如梦技术's avatar 如梦技术 🐛

代码优化。

上级 96b89cda
......@@ -17,6 +17,7 @@
package net.dreamlu.iot.mqtt.codec;
import org.tio.core.ChannelContext;
import static net.dreamlu.iot.mqtt.codec.MqttConstant.MIN_CLIENT_ID_LENGTH;
/**
* 编解码工具
......@@ -25,8 +26,6 @@ import org.tio.core.ChannelContext;
*/
final class MqttCodecUtil {
private static final char[] TOPIC_WILDCARDS = {'#', '+'};
private static final int MIN_CLIENT_ID_LENGTH = 1;
private static final int MAX_CLIENT_ID_LENGTH = 23;
private static final String MQTT_VERSION_KEY = "TIO_CODEC_MQTT_VERSION";
protected static MqttVersion getMqttVersion(ChannelContext ctx) {
......@@ -55,13 +54,13 @@ final class MqttCodecUtil {
return messageId != 0;
}
protected static boolean isValidClientId(MqttVersion mqttVersion, String clientId) {
protected static boolean isValidClientId(MqttVersion mqttVersion, int maxClientIdLength, String clientId) {
if (clientId == null) {
return false;
}
switch (mqttVersion) {
case MQTT_3_1:
return clientId.length() >= MIN_CLIENT_ID_LENGTH && clientId.length() <= MAX_CLIENT_ID_LENGTH;
return clientId.length() >= MIN_CLIENT_ID_LENGTH && clientId.length() <= maxClientIdLength;
case MQTT_3_1_1:
case MQTT_5:
// In 3.1.3.1 Client Identifier of MQTT 3.1.1 and 5.0 specifications, The Server MAY allow ClientId’s
......
/*
* Copyright 2021 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package net.dreamlu.iot.mqtt.codec;
/**
* mqtt 常量
*
* @author netty
*/
public final class MqttConstant {
private MqttConstant() {
}
/**
* mqtt protocol length
*/
public static final int MQTT_PROTOCOL_LENGTH = 2;
/**
* Default max bytes in message
*/
public static final int DEFAULT_MAX_BYTES_IN_MESSAGE = 8092;
/**
* min client id length
*/
public static final int MIN_CLIENT_ID_LENGTH = 1;
/**
* Default max client id length,In the mqtt3.1 protocol,
* the default maximum Client Identifier length is 23
*/
public static final int DEFAULT_MAX_CLIENT_ID_LENGTH = 23;
}
......@@ -23,6 +23,9 @@ import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import static net.dreamlu.iot.mqtt.codec.MqttConstant.*;
/**
* Decodes Mqtt messages from bytes, following
* the MQTT protocol specification
......@@ -35,18 +38,20 @@ import java.util.List;
* @author L.cm
*/
public final class MqttDecoder {
public static final int DEFAULT_MAX_BYTES_IN_MESSAGE = 8092;
public static final int MQTT_PROTOCOL_LENGTH = 2;
public static final MqttDecoder INSTANCE = new MqttDecoder();
private final int maxBytesInMessage;
private final int maxClientIdLength;
public MqttDecoder() {
this(DEFAULT_MAX_BYTES_IN_MESSAGE);
}
public MqttDecoder(int maxBytesInMessage) {
this(maxBytesInMessage, DEFAULT_MAX_CLIENT_ID_LENGTH);
}
public MqttDecoder(int maxBytesInMessage, int maxClientIdLength) {
this.maxBytesInMessage = maxBytesInMessage;
this.maxClientIdLength = maxClientIdLength;
}
public MqttMessage decode(ChannelContext ctx, ByteBuffer buffer, int limit, int position, int readableLength) {
......@@ -84,7 +89,7 @@ public final class MqttDecoder {
// 5. 解析消息体
final Result<?> decodedPayload;
try {
decodedPayload = decodePayload(buffer, mqttFixedHeader.messageType(),
decodedPayload = decodePayload(buffer, maxClientIdLength, mqttFixedHeader.messageType(),
bytesRemainingInVariablePart, variableHeader);
bytesRemainingInVariablePart -= decodedPayload.numberOfBytesConsumed;
if (bytesRemainingInVariablePart != 0) {
......@@ -367,12 +372,12 @@ public final class MqttDecoder {
* @param variableHeader variable header of the same message
* @return the payload
*/
private static Result<?> decodePayload(
ByteBuffer buffer, MqttMessageType messageType,
int bytesRemainingInVariablePart, Object variableHeader) {
private static Result<?> decodePayload(ByteBuffer buffer, int maxClientIdLength,
MqttMessageType messageType, int bytesRemainingInVariablePart,
Object variableHeader) {
switch (messageType) {
case CONNECT:
return decodeConnectionPayload(buffer, (MqttConnectVariableHeader) variableHeader);
return decodeConnectionPayload(buffer, maxClientIdLength, (MqttConnectVariableHeader) variableHeader);
case SUBSCRIBE:
return decodeSubscribePayload(buffer, bytesRemainingInVariablePart);
case SUBACK:
......@@ -389,13 +394,13 @@ public final class MqttDecoder {
}
}
private static Result<MqttConnectPayload> decodeConnectionPayload(
ByteBuffer buffer, MqttConnectVariableHeader mqttConnectVariableHeader) {
private static Result<MqttConnectPayload> decodeConnectionPayload(ByteBuffer buffer, int maxClientIdLength,
MqttConnectVariableHeader mqttConnectVariableHeader) {
final Result<String> decodedClientId = decodeString(buffer);
final String decodedClientIdValue = decodedClientId.value;
final MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(mqttConnectVariableHeader.name(),
(byte) mqttConnectVariableHeader.version());
if (!MqttCodecUtil.isValidClientId(mqttVersion, decodedClientIdValue)) {
if (!MqttCodecUtil.isValidClientId(mqttVersion, maxClientIdLength, decodedClientIdValue)) {
throw new MqttIdentifierRejectedException("invalid clientIdentifier: " + decodedClientIdValue);
}
int numberOfBytesConsumed = decodedClientId.numberOfBytesConsumed;
......@@ -441,9 +446,7 @@ public final class MqttDecoder {
return new Result<>(mqttConnectPayload, numberOfBytesConsumed);
}
private static Result<MqttSubscribePayload> decodeSubscribePayload(
ByteBuffer buffer,
int bytesRemainingInVariablePart) {
private static Result<MqttSubscribePayload> decodeSubscribePayload(ByteBuffer buffer, int bytesRemainingInVariablePart) {
final List<MqttTopicSubscription> subscribeTopics = new ArrayList<>();
int numberOfBytesConsumed = 0;
while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
......@@ -467,9 +470,7 @@ public final class MqttDecoder {
return new Result<>(new MqttSubscribePayload(subscribeTopics), numberOfBytesConsumed);
}
private static Result<MqttSubAckPayload> decodeSubAckPayload(
ByteBuffer buffer,
int bytesRemainingInVariablePart) {
private static Result<MqttSubAckPayload> decodeSubAckPayload(ByteBuffer buffer, int bytesRemainingInVariablePart) {
final List<Integer> grantedQos = new ArrayList<>(bytesRemainingInVariablePart);
int numberOfBytesConsumed = 0;
while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
......@@ -492,9 +493,7 @@ public final class MqttDecoder {
return new Result<>(new MqttUnsubAckPayload(reasonCodes), numberOfBytesConsumed);
}
private static Result<MqttUnsubscribePayload> decodeUnsubscribePayload(
ByteBuffer buffer,
int bytesRemainingInVariablePart) {
private static Result<MqttUnsubscribePayload> decodeUnsubscribePayload(ByteBuffer buffer, int bytesRemainingInVariablePart) {
final List<String> unsubscribeTopics = new ArrayList<>();
int numberOfBytesConsumed = 0;
while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
......
......@@ -21,6 +21,7 @@ import org.tio.core.ChannelContext;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.List;
import static net.dreamlu.iot.mqtt.codec.MqttConstant.DEFAULT_MAX_CLIENT_ID_LENGTH;
/**
* Encodes Mqtt messages into bytes following the protocol specification v3.1
......@@ -99,7 +100,7 @@ public final class MqttEncoder {
// Client id
String clientIdentifier = payload.clientIdentifier();
if (!MqttCodecUtil.isValidClientId(mqttVersion, clientIdentifier)) {
if (!MqttCodecUtil.isValidClientId(mqttVersion, DEFAULT_MAX_CLIENT_ID_LENGTH, clientIdentifier)) {
throw new MqttIdentifierRejectedException("invalid clientIdentifier: " + clientIdentifier);
}
byte[] clientIdentifierBytes = encodeStringUtf8(clientIdentifier);
......
......@@ -37,7 +37,7 @@ public class MqttClientAioHandler implements ClientAioHandler {
private final IMqttClientProcessor processor;
public MqttClientAioHandler(ByteBufferAllocator bufferAllocator, IMqttClientProcessor processor) {
this.mqttDecoder = MqttDecoder.INSTANCE;
this.mqttDecoder = new MqttDecoder();
this.mqttEncoder = MqttEncoder.INSTANCE;
this.allocator = bufferAllocator;
this.processor = processor;
......@@ -49,7 +49,7 @@ public class MqttClientAioHandler implements ClientAioHandler {
}
@Override
public Packet decode(ByteBuffer buffer, int limit, int position, int readableLength, ChannelContext channelContext) throws TioDecodeException {
public Packet decode(ByteBuffer buffer, int limit, int position, int readableLength, ChannelContext channelContext) {
return mqttDecoder.decode(channelContext, buffer, limit, position, readableLength);
}
......
......@@ -16,10 +16,7 @@
package net.dreamlu.iot.mqtt.core.client;
import net.dreamlu.iot.mqtt.codec.ByteBufferAllocator;
import net.dreamlu.iot.mqtt.codec.MqttDecoder;
import net.dreamlu.iot.mqtt.codec.MqttProperties;
import net.dreamlu.iot.mqtt.codec.MqttVersion;
import net.dreamlu.iot.mqtt.codec.*;
import org.tio.client.ClientChannelContext;
import org.tio.client.ClientTioConfig;
import org.tio.client.ReconnConf;
......@@ -60,9 +57,9 @@ public final class MqttClientCreator {
*/
private Integer timeout;
/**
* t-io 每次消息读取长度
* t-io 每次消息读取长度,跟 maxBytesInMessage 相关
*/
private int readBufferSize = MqttDecoder.DEFAULT_MAX_BYTES_IN_MESSAGE;
private int readBufferSize = MqttConstant.DEFAULT_MAX_BYTES_IN_MESSAGE;
/**
* Keep Alive (s)
*/
......@@ -70,7 +67,7 @@ public final class MqttClientCreator {
/**
* SSL配置
*/
protected SslConfig sslConfig;
private SslConfig sslConfig;
/**
* 自动重连
*/
......@@ -288,7 +285,7 @@ public final class MqttClientCreator {
ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(1, DefaultThreadFactory.getInstance("MqttClient"));
IMqttClientProcessor processor = new DefaultMqttClientProcessor(clientStore, connLatch, executor);
// 2. 初始化 mqtt 处理器
ClientAioHandler clientAioHandler = new MqttClientAioHandler(this.bufferAllocator, Objects.requireNonNull(processor));
ClientAioHandler clientAioHandler = new MqttClientAioHandler(this.bufferAllocator, processor);
ClientAioListener clientAioListener = new MqttClientAioListener(this, clientStore, executor);
// 3. 重连配置
ReconnConf reconnConf = null;
......
......@@ -40,8 +40,11 @@ public class MqttServerAioHandler implements ServerAioHandler {
private final ByteBufferAllocator allocator;
private final MqttServerProcessor processor;
public MqttServerAioHandler(ByteBufferAllocator bufferAllocator, MqttServerProcessor processor) {
this.mqttDecoder = MqttDecoder.INSTANCE;
public MqttServerAioHandler(int maxBytesInMessage,
int maxClientIdLength,
ByteBufferAllocator bufferAllocator,
MqttServerProcessor processor) {
this.mqttDecoder = new MqttDecoder(maxBytesInMessage, maxClientIdLength);
this.mqttEncoder = MqttEncoder.INSTANCE;
this.allocator = bufferAllocator;
this.processor = processor;
......@@ -57,10 +60,9 @@ public class MqttServerAioHandler implements ServerAioHandler {
* @param readableLength ByteBuffer参与本次解码的有效数据(= limit - position)
* @param context ChannelContext
* @return Packet
* @throws TioDecodeException TioDecodeException
*/
@Override
public Packet decode(ByteBuffer buffer, int limit, int position, int readableLength, ChannelContext context) throws TioDecodeException {
public Packet decode(ByteBuffer buffer, int limit, int position, int readableLength, ChannelContext context) {
return mqttDecoder.decode(context, buffer, limit, position, readableLength);
}
......@@ -82,10 +84,9 @@ public class MqttServerAioHandler implements ServerAioHandler {
*
* @param packet Packet
* @param context ChannelContext
* @throws Exception Exception
*/
@Override
public void handler(Packet packet, ChannelContext context) throws Exception {
public void handler(Packet packet, ChannelContext context) {
MqttMessage mqttMessage = (MqttMessage) packet;
// 1. 先判断 mqtt 消息解析是否正常
DecoderResult decoderResult = mqttMessage.decoderResult();
......
......@@ -17,7 +17,7 @@
package net.dreamlu.iot.mqtt.core.server;
import net.dreamlu.iot.mqtt.codec.ByteBufferAllocator;
import net.dreamlu.iot.mqtt.codec.MqttDecoder;
import net.dreamlu.iot.mqtt.codec.MqttConstant;
import net.dreamlu.iot.mqtt.core.server.dispatcher.IMqttMessageDispatcher;
import net.dreamlu.iot.mqtt.core.server.event.IMqttConnectStatusListener;
import net.dreamlu.iot.mqtt.core.server.event.IMqttMessageListener;
......@@ -63,9 +63,17 @@ public class MqttServerCreator {
*/
private Long heartbeatTimeout;
/**
* 接收数据的 buffer size
* 接收数据的 buffer size,默认:8092
*/
private int readBufferSize = MqttDecoder.DEFAULT_MAX_BYTES_IN_MESSAGE;
private int readBufferSize = MqttConstant.DEFAULT_MAX_BYTES_IN_MESSAGE;
/**
* 消息解析最大 bytes 长度,默认:8092
*/
private int maxBytesInMessage = MqttConstant.DEFAULT_MAX_BYTES_IN_MESSAGE;
/**
* 最大 clientId 长度,默认:23
*/
private int maxClientIdLength = MqttConstant.DEFAULT_MAX_CLIENT_ID_LENGTH;
/**
* 堆内存和堆外内存
*/
......@@ -156,6 +164,30 @@ public class MqttServerCreator {
return this;
}
public int getMaxBytesInMessage() {
return maxBytesInMessage;
}
public MqttServerCreator maxBytesInMessage(int maxBytesInMessage) {
if (maxBytesInMessage < 1) {
throw new IllegalArgumentException("maxBytesInMessage must be greater than 0.");
}
this.maxBytesInMessage = maxBytesInMessage;
return this;
}
public int getMaxClientIdLength() {
return maxClientIdLength;
}
public MqttServerCreator maxClientIdLength(int maxClientIdLength) {
if (maxClientIdLength < 1) {
throw new IllegalArgumentException("maxClientIdLength must be greater than 0.");
}
this.maxClientIdLength = maxClientIdLength;
return this;
}
public ByteBufferAllocator getBufferAllocator() {
return bufferAllocator;
}
......@@ -289,14 +321,13 @@ public class MqttServerCreator {
this.connectStatusListener = new DefaultMqttConnectStatusListener();
}
ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(2, DefaultThreadFactory.getInstance("MqttServer"));
DefaultMqttServerProcessor serverProcessor = new DefaultMqttServerProcessor(
this.messageStore, this.sessionManager, this.authHandler, this.subscribeManager,
this.messageDispatcher, this.connectStatusListener, this.messageListener, executor);
DefaultMqttServerProcessor serverProcessor = new DefaultMqttServerProcessor(this.messageStore, this.sessionManager,
this.authHandler, this.subscribeManager, this.messageDispatcher, this.connectStatusListener, this.messageListener, executor);
// 1. 处理消息
ServerAioHandler handler = new MqttServerAioHandler(this.bufferAllocator, serverProcessor);
ServerAioHandler handler = new MqttServerAioHandler(this.maxBytesInMessage, this.maxClientIdLength, this.bufferAllocator, serverProcessor);
// 2. t-io 监听
ServerAioListener listener = new MqttServerAioListener(
this.messageStore, this.sessionManager, this.subscribeManager, this.messageDispatcher, this.connectStatusListener);
ServerAioListener listener = new MqttServerAioListener(this.messageStore, this.sessionManager, this.subscribeManager,
this.messageDispatcher, this.connectStatusListener);
// 2. t-io 配置
ServerTioConfig tioConfig = new ServerTioConfig(this.name, handler, listener);
// 4. 设置 t-io 心跳 timeout
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册