SSLNetworkModule.java 5.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
/*******************************************************************************
 * Copyright (c) 2009, 2019 IBM Corp.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * and Eclipse Distribution License v1.0 which accompany this distribution. 
 *
 * The Eclipse Public License is available at 
 *    http://www.eclipse.org/legal/epl-v10.html
 * and the Eclipse Distribution License is available at 
 *   http://www.eclipse.org/org/documents/edl-v10.php.
 *
 * Contributors:
 *    Dave Locke - initial API and implementation and/or initial documentation
 */
package org.eclipse.paho.client.mqttv3.internal;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

import org.eclipse.paho.client.mqttv3.MqttException;
import org.eclipse.paho.client.mqttv3.logging.Logger;
import org.eclipse.paho.client.mqttv3.logging.LoggerFactory;

/**
 * A network module for connecting over SSL/TLS.
 * <p>
 * After {@code super.start()} has created the underlying socket (a
 * {@link SSLSocket} obtained from the supplied {@link SSLSocketFactory}),
 * {@link #start()} applies the configured cipher suites, Server Name
 * Indication (SNI) and optional HTTPS-style endpoint identification, then
 * performs the TLS handshake under a bounded read timeout.
 */
public class SSLNetworkModule extends TCPNetworkModule {
	private static final String CLASS_NAME = SSLNetworkModule.class.getName();
	private final Logger log = LoggerFactory.getLogger(LoggerFactory.MQTT_CLIENT_MSG_CAT, CLASS_NAME);

	/** Cipher suites to enable on the socket; {@code null} leaves the socket's defaults untouched. */
	private String[] enabledCiphers;
	/** TLS handshake timeout in seconds (converted to milliseconds for the socket read timeout). */
	private int handshakeTimeoutSecs;
	/** Optional custom verifier, consulted only when HTTPS hostname verification is disabled. */
	private HostnameVerifier hostnameVerifier;
	/** When true, the JDK's "HTTPS" endpoint identification algorithm is used during the handshake. */
	private boolean httpsHostnameVerificationEnabled = false;

	private final String host;
	private final int port;

	/**
	 * Constructs a new SSLNetworkModule using the specified host and port. The
	 * supplied SSLSocketFactory is used to supply the network socket.
	 * 
	 * @param factory
	 *            the {@link SSLSocketFactory} to be used in this SSLNetworkModule
	 * @param host
	 *            the Hostname of the Server
	 * @param port
	 *            the Port of the Server
	 * @param resourceContext
	 *            Resource Context
	 */
	public SSLNetworkModule(SSLSocketFactory factory, String host, int port, String resourceContext) {
		super(factory, host, port, resourceContext);
		this.host = host;
		this.port = port;
		log.setResourceName(resourceContext);
	}

	/**
	 * Returns the enabled cipher suites.
	 * 
	 * @return a copy of the configured cipher suites, or {@code null} if none
	 *         have been set
	 */
	public String[] getEnabledCiphers() {
		// Return a copy so callers cannot mutate internal state; the setter clones for the same reason.
		return (enabledCiphers == null) ? null : enabledCiphers.clone();
	}

	/**
	 * Sets the enabled cipher suites on the underlying network socket.
	 * <p>
	 * A {@code null} argument does not clear a previously configured list; it
	 * only re-applies the current one to the socket (if the socket exists).
	 * 
	 * @param enabledCiphers
	 *            a String array of cipher suites to enable
	 */
	public void setEnabledCiphers(String[] enabledCiphers) {
		final String methodName = "setEnabledCiphers";
		if (enabledCiphers != null) {
			this.enabledCiphers = enabledCiphers.clone();
		}
		if ((socket != null) && (this.enabledCiphers != null)) {
			if (log.isLoggable(Logger.FINE)) {
				StringBuilder ciphers = new StringBuilder();
				for (int i = 0; i < this.enabledCiphers.length; i++) {
					if (i > 0) {
						ciphers.append(',');
					}
					ciphers.append(this.enabledCiphers[i]);
				}
				// @TRACE 260=setEnabledCiphers ciphers={0}
				log.fine(CLASS_NAME, methodName, "260", new Object[] { ciphers.toString() });
			}
			((SSLSocket) socket).setEnabledCipherSuites(this.enabledCiphers);
		}
	}

	/**
	 * Sets the TLS handshake timeout in seconds. The same value is also passed
	 * to {@code setConnectTimeout} of the superclass.
	 * 
	 * @param timeout
	 *            the timeout in seconds
	 */
	public void setSSLhandshakeTimeout(int timeout) {
		super.setConnectTimeout(timeout);
		this.handshakeTimeoutSecs = timeout;
	}

	/**
	 * @return the custom hostname verifier, or {@code null} if none is set
	 */
	public HostnameVerifier getSSLHostnameVerifier() {
		return hostnameVerifier;
	}

	/**
	 * Sets a custom hostname verifier. It is consulted after the handshake only
	 * when HTTPS hostname verification is disabled.
	 * 
	 * @param hostnameVerifier
	 *            the verifier to use, or {@code null} for none
	 */
	public void setSSLHostnameVerifier(HostnameVerifier hostnameVerifier) {
		this.hostnameVerifier = hostnameVerifier;
	}

	/**
	 * @return whether HTTPS-style endpoint identification is requested
	 */
	public boolean isHttpsHostnameVerificationEnabled() {
		return httpsHostnameVerificationEnabled;
	}

	/**
	 * Enables or disables HTTPS-style endpoint identification during the
	 * handshake.
	 * 
	 * @param httpsHostnameVerificationEnabled
	 *            {@code true} to set the "HTTPS" endpoint identification algorithm
	 */
	public void setHttpsHostnameVerificationEnabled(boolean httpsHostnameVerificationEnabled) {
		this.httpsHostnameVerificationEnabled = httpsHostnameVerificationEnabled;
	}

	/**
	 * Starts the module and performs the TLS handshake.
	 * <p>
	 * The steps are:
	 * <ol>
	 * <li>Opens the socket via {@code super.start()}.</li>
	 * <li>Applies the configured cipher suites, SNI and endpoint identification settings.</li>
	 * <li>Performs the TLS handshake with the socket read timeout temporarily set to the
	 * handshake timeout, so the handshake cannot block indefinitely.</li>
	 * </ol>
	 * The previous read timeout is restored afterwards.
	 * 
	 * @throws IOException
	 *             if the handshake fails, or an {@link SSLPeerUnverifiedException}
	 *             if the custom {@link HostnameVerifier} rejects the server
	 *             (the socket is closed in that case)
	 * @throws MqttException
	 *             propagated from {@code super.start()}
	 */
	public void start() throws IOException, MqttException {
		super.start();
		setEnabledCiphers(enabledCiphers);
		SSLSocket sslSocket = (SSLSocket) socket;
		int soTimeout = sslSocket.getSoTimeout();
		// RTC 765: Set a timeout to avoid the SSL handshake being blocked indefinitely
		sslSocket.setSoTimeout(this.handshakeTimeoutSecs * 1000);

		// Start from the socket's current parameters instead of a fresh SSLParameters.
		// A fresh instance carries default values (e.g. no endpoint identification
		// algorithm, no algorithm constraints), and some JDK implementations apply
		// those defaults, discarding settings made by the socket factory.
		SSLParameters sslParameters = sslSocket.getSSLParameters();

		// SNI support. Should be automatic under some circumstances - not all, apparently
		List<SNIServerName> sniHostNames = buildSniServerNames(host);
		if (sniHostNames != null) {
			sslParameters.setServerNames(sniHostNames);
		}

		// If default Hostname verification is enabled, use the same method that is used with HTTPS
		if (this.httpsHostnameVerificationEnabled) {
			sslParameters.setEndpointIdentificationAlgorithm("HTTPS");
		}
		sslSocket.setSSLParameters(sslParameters);

		sslSocket.startHandshake();
		if (hostnameVerifier != null && !this.httpsHostnameVerificationEnabled) {
			SSLSession session = sslSocket.getSession();
			if (!hostnameVerifier.verify(host, session)) {
				session.invalidate();
				sslSocket.close();
				throw new SSLPeerUnverifiedException("Host: " + host + ", Peer Host: " + session.getPeerHost());
			}
		}
		// Restore the read timeout that was in effect before the handshake
		sslSocket.setSoTimeout(soTimeout);
	}

	/**
	 * Builds the SNI server name list for {@code hostName}, or returns
	 * {@code null} when no SNI should be sent.
	 * <p>
	 * RFC 6066 does not permit literal IP addresses in SNI, and
	 * {@link SNIHostName} rejects names it cannot encode with an unchecked
	 * {@link IllegalArgumentException}. Such hosts are connected to without
	 * SNI instead of failing the connection.
	 */
	private static List<SNIServerName> buildSniServerNames(String hostName) {
		if (hostName == null || hostName.length() == 0 || isIpLiteral(hostName)) {
			return null;
		}
		SNIHostName sniHostName;
		try {
			sniHostName = new SNIHostName(hostName);
		} catch (IllegalArgumentException e) {
			// Best effort: a name the JDK cannot encode for SNI is simply not sent.
			return null;
		}
		List<SNIServerName> names = new ArrayList<SNIServerName>(1);
		names.add(sniHostName);
		return names;
	}

	/**
	 * Returns {@code true} if the non-empty {@code hostName} looks like a
	 * literal IP address. A name counts as literal if it contains ':' (IPv6,
	 * optionally bracketed) or consists only of digits and dots (dotted IPv4).
	 */
	private static boolean isIpLiteral(String hostName) {
		if (hostName.indexOf(':') >= 0) {
			return true;
		}
		for (int i = 0; i < hostName.length(); i++) {
			char c = hostName.charAt(i);
			if (c != '.' && (c < '0' || c > '9')) {
				return false;
			}
		}
		return true;
	}

	/**
	 * @return the server URI in the form {@code ssl://host:port}
	 */
	public String getServerURI() {
		return "ssl://" + host + ":" + port;
	}
}