From d8124887671bf4f0ae5dd6f41e5b8f391006002b Mon Sep 17 00:00:00 2001 From: fatedier Date: Mon, 11 Mar 2019 14:14:31 +0800 Subject: [PATCH] support tls connection --- client/control.go | 11 +- client/service.go | 11 +- conf/frpc_full.ini | 3 + models/config/client_common.go | 8 ++ server/service.go | 41 +++++++ tests/ci/tls_test.go | 188 +++++++++++++++++++++++++++++++++ utils/net/conn.go | 11 ++ utils/net/tls.go | 44 ++++++++ utils/version/version.go | 2 +- 9 files changed, 314 insertions(+), 5 deletions(-) create mode 100644 tests/ci/tls_test.go create mode 100644 utils/net/tls.go diff --git a/client/control.go b/client/control.go index 750803f..bbcece6 100644 --- a/client/control.go +++ b/client/control.go @@ -15,6 +15,7 @@ package client import ( + "crypto/tls" "fmt" "io" "runtime/debug" @@ -166,8 +167,14 @@ func (ctl *Control) connectServer() (conn frpNet.Conn, err error) { } conn = frpNet.WrapConn(stream) } else { - conn, err = frpNet.ConnectServerByProxy(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol, - fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort)) + var tlsConfig *tls.Config + if g.GlbClientCfg.TLSEnable { + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + conn, err = frpNet.ConnectServerByProxyWithTLS(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol, + fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort), tlsConfig) if err != nil { ctl.Warn("start new connection to server error: %v", err) return diff --git a/client/service.go b/client/service.go index d439705..32106ca 100644 --- a/client/service.go +++ b/client/service.go @@ -15,6 +15,7 @@ package client import ( + "crypto/tls" "fmt" "io/ioutil" "runtime" @@ -151,8 +152,14 @@ func (svr *Service) keepControllerWorking() { // conn: control connection // session: if it's not nil, using tcp mux func (svr *Service) login() (conn frpNet.Conn, session *fmux.Session, err error) { - conn, err = frpNet.ConnectServerByProxy(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol, - fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort)) + var tlsConfig *tls.Config + if g.GlbClientCfg.TLSEnable { + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + conn, err = frpNet.ConnectServerByProxyWithTLS(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol, + fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort), tlsConfig) if err != nil { return } diff --git a/conf/frpc_full.ini b/conf/frpc_full.ini index 29ef0e1..60bf859 100644 --- a/conf/frpc_full.ini +++ b/conf/frpc_full.ini @@ -44,6 +44,9 @@ login_fail_exit = true # now it supports tcp and kcp and websocket, default is tcp protocol = tcp +# if tls_enable is true, frpc will connect frps by tls +tls_enable = true + # specify a dns server, so frpc will use this instead of default one # dns_server = 8.8.8.8 diff --git a/models/config/client_common.go b/models/config/client_common.go index 5dc49aa..1cd9ffc 100644 --- a/models/config/client_common.go +++ b/models/config/client_common.go @@ -44,6 +44,7 @@ type ClientCommonConf struct { LoginFailExit bool `json:"login_fail_exit"` Start map[string]struct{} `json:"start"` Protocol string `json:"protocol"` + TLSEnable bool `json:"tls_enable"` HeartBeatInterval int64 `json:"heartbeat_interval"` HeartBeatTimeout int64 `json:"heartbeat_timeout"` } @@ -69,6 +70,7 @@ func GetDefaultClientConf() *ClientCommonConf { LoginFailExit: true, Start: make(map[string]struct{}), Protocol: "tcp", + TLSEnable: false, HeartBeatInterval: 30, HeartBeatTimeout: 90, } @@ -194,6 +196,12 @@ func UnmarshalClientConfFromIni(defaultCfg *ClientCommonConf, content string) (c cfg.Protocol = tmpStr } + if tmpStr, ok = conf.Get("common", "tls_enable"); ok && tmpStr == "true" { + cfg.TLSEnable = true + } else { + cfg.TLSEnable = false + } + if tmpStr, ok = conf.Get("common", "heartbeat_timeout"); ok { if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout") diff --git a/server/service.go b/server/service.go index b40a1a2..ac5602a 100644 --- a/server/service.go +++ b/server/service.go @@ -16,8 +16,14 @@ package server import ( "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" "fmt" "io/ioutil" + "math/big" "net" "net/http" "time" @@ -61,6 +67,9 @@ type Service struct { // Accept connections using websocket websocketListener frpNet.Listener + // Accept frp tls connections + tlsListener frpNet.Listener + // Manage all controllers ctlManager *ControlManager @@ -72,6 +81,8 @@ type Service struct { // stats collector to store server and proxies stats info statsCollector stats.Collector + + tlsConfig *tls.Config } func NewService() (svr *Service, err error) { @@ -84,6 +95,7 @@ func NewService() (svr *Service, err error) { TcpPortManager: ports.NewPortManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts), UdpPortManager: ports.NewPortManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts), }, + tlsConfig: generateTLSConfig(), } // Init group controller @@ -187,6 +199,12 @@ func NewService() (svr *Service, err error) { log.Info("https service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort) } + // frp tls listener + tlsListener := svr.muxer.Listen(1, 1, func(data []byte) bool { + return int(data[0]) == frpNet.FRP_TLS_HEAD_BYTE + }) + svr.tlsListener = frpNet.WrapLogListener(tlsListener) + // Create nat hole controller. if cfg.BindUdpPort > 0 { var nc *nathole.NatHoleController @@ -225,6 +243,7 @@ func (svr *Service) Run() { } go svr.HandleListener(svr.websocketListener) + go svr.HandleListener(svr.tlsListener) svr.HandleListener(svr.listener) } @@ -237,6 +256,7 @@ func (svr *Service) HandleListener(l frpNet.Listener) { log.Warn("Listener for incoming connections from client closed") return } + c = frpNet.CheckAndEnableTLSServerConn(c, svr.tlsConfig) // Start a new goroutine for dealing connections. go func(frpConn frpNet.Conn) { @@ -373,3 +393,24 @@ func (svr *Service) RegisterVisitorConn(visitorConn frpNet.Conn, newMsg *msg.New return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey, newMsg.UseEncryption, newMsg.UseCompression) } + +// Setup a bare-bones TLS config for the server +func generateTLSConfig() *tls.Config { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + template := x509.Certificate{SerialNumber: big.NewInt(1)} + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + panic(err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + panic(err) + } + return &tls.Config{Certificates: []tls.Certificate{tlsCert}} +} diff --git a/tests/ci/tls_test.go b/tests/ci/tls_test.go new file mode 100644 index 0000000..2f13eb3 --- /dev/null +++ b/tests/ci/tls_test.go @@ -0,0 +1,188 @@ +package ci + +import ( + "os" + "testing" + "time" + + "github.com/fatedier/frp/tests/config" + "github.com/fatedier/frp/tests/consts" + "github.com/fatedier/frp/tests/util" + + "github.com/stretchr/testify/assert" +) + +const FRPS_TLS_TCP_CONF = ` +[common] +bind_addr = 0.0.0.0 +bind_port = 20000 +log_file = console +log_level = debug +token = 123456 +` + +const FRPC_TLS_TCP_CONF = ` +[common] +server_addr = 127.0.0.1 +server_port = 20000 +log_file = console +log_level = debug +token = 123456 +protocol = tcp +tls_enable = true + +[tcp] +type = tcp +local_port = 10701 +remote_port = 20801 +` + +func TestTlsOverTCP(t *testing.T) { + assert := assert.New(t) + frpsCfgPath, err := config.GenerateConfigFile(consts.FRPS_NORMAL_CONFIG, FRPS_TLS_TCP_CONF) + if assert.NoError(err) { + defer os.Remove(frpsCfgPath) + } + + frpcCfgPath, err := config.GenerateConfigFile(consts.FRPC_NORMAL_CONFIG, FRPC_TLS_TCP_CONF) + if assert.NoError(err) { + defer os.Remove(frpcCfgPath) + } + + frpsProcess := util.NewProcess(consts.FRPS_BIN_PATH, []string{"-c", frpsCfgPath}) + err = frpsProcess.Start() + if assert.NoError(err) { + defer frpsProcess.Stop() + } + + time.Sleep(100 * time.Millisecond) + + frpcProcess := util.NewProcess(consts.FRPC_BIN_PATH, []string{"-c", frpcCfgPath}) + err = frpcProcess.Start() + if assert.NoError(err) { + defer frpcProcess.Stop() + } + time.Sleep(250 * time.Millisecond) + + // test tcp + res, err := util.SendTcpMsg("127.0.0.1:20801", consts.TEST_TCP_ECHO_STR) + assert.NoError(err) + assert.Equal(consts.TEST_TCP_ECHO_STR, res) +} + +const FRPS_TLS_KCP_CONF = ` +[common] +bind_addr = 0.0.0.0 +bind_port = 20000 +kcp_bind_port = 20000 +log_file = console +log_level = debug +token = 123456 +` + +const FRPC_TLS_KCP_CONF = ` +[common] +server_addr = 127.0.0.1 +server_port = 20000 +log_file = console +log_level = debug +token = 123456 +protocol = kcp +tls_enable = true + +[tcp] +type = tcp +local_port = 10701 +remote_port = 20801 +` + +func TestTLSOverKCP(t *testing.T) { + assert := assert.New(t) + frpsCfgPath, err := config.GenerateConfigFile(consts.FRPS_NORMAL_CONFIG, FRPS_TLS_KCP_CONF) + if assert.NoError(err) { + defer os.Remove(frpsCfgPath) + } + + frpcCfgPath, err := config.GenerateConfigFile(consts.FRPC_NORMAL_CONFIG, FRPC_TLS_KCP_CONF) + if assert.NoError(err) { + defer os.Remove(frpcCfgPath) + } + + frpsProcess := util.NewProcess(consts.FRPS_BIN_PATH, []string{"-c", frpsCfgPath}) + err = frpsProcess.Start() + if assert.NoError(err) { + defer frpsProcess.Stop() + } + + time.Sleep(200 * time.Millisecond) + + frpcProcess := util.NewProcess(consts.FRPC_BIN_PATH, []string{"-c", frpcCfgPath}) + err = frpcProcess.Start() + if assert.NoError(err) { + defer frpcProcess.Stop() + } + time.Sleep(500 * time.Millisecond) + + // test tcp + res, err := util.SendTcpMsg("127.0.0.1:20801", consts.TEST_TCP_ECHO_STR) + assert.NoError(err) + assert.Equal(consts.TEST_TCP_ECHO_STR, res) +} + +const FRPS_TLS_WS_CONF = ` +[common] +bind_addr = 0.0.0.0 +bind_port = 20000 +log_file = console +log_level = debug +token = 123456 +` + +const FRPC_TLS_WS_CONF = ` +[common] +server_addr = 127.0.0.1 +server_port = 20000 +log_file = console +log_level = debug +token = 123456 +protocol = websocket +tls_enable = true + +[tcp] +type = tcp +local_port = 10701 +remote_port = 20801 +` + +func TestTLSOverWebsocket(t *testing.T) { + assert := assert.New(t) + frpsCfgPath, err := config.GenerateConfigFile(consts.FRPS_NORMAL_CONFIG, FRPS_TLS_WS_CONF) + if assert.NoError(err) { + defer os.Remove(frpsCfgPath) + } + + frpcCfgPath, err := config.GenerateConfigFile(consts.FRPC_NORMAL_CONFIG, FRPC_TLS_WS_CONF) + if assert.NoError(err) { + defer os.Remove(frpcCfgPath) + } + + frpsProcess := util.NewProcess(consts.FRPS_BIN_PATH, []string{"-c", frpsCfgPath}) + err = frpsProcess.Start() + if assert.NoError(err) { + defer frpsProcess.Stop() + } + + time.Sleep(200 * time.Millisecond) + + frpcProcess := util.NewProcess(consts.FRPC_BIN_PATH, []string{"-c", frpcCfgPath}) + err = frpcProcess.Start() + if assert.NoError(err) { + defer frpcProcess.Stop() + } + time.Sleep(500 * time.Millisecond) + + // test tcp + res, err := util.SendTcpMsg("127.0.0.1:20801", consts.TEST_TCP_ECHO_STR) + assert.NoError(err) + assert.Equal(consts.TEST_TCP_ECHO_STR, res) +} diff --git a/utils/net/conn.go b/utils/net/conn.go index 6dab2bd..e716457 100644 --- a/utils/net/conn.go +++ b/utils/net/conn.go @@ -15,6 +15,7 @@ package net import ( + "crypto/tls" "errors" "fmt" "io" @@ -207,3 +208,13 @@ func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn return nil, fmt.Errorf("unsupport protocol: %s", protocol) } } + +func ConnectServerByProxyWithTLS(proxyUrl string, protocol string, addr string, tlsConfig *tls.Config) (c Conn, err error) { + c, err = ConnectServerByProxy(proxyUrl, protocol, addr) + if tlsConfig == nil { + return + } + + c = WrapTLSClientConn(c, tlsConfig) + return +} diff --git a/utils/net/tls.go b/utils/net/tls.go new file mode 100644 index 0000000..ae1bfc7 --- /dev/null +++ b/utils/net/tls.go @@ -0,0 +1,44 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "crypto/tls" + "net" + + gnet "github.com/fatedier/golib/net" +) + +var ( + FRP_TLS_HEAD_BYTE = 0x17 +) + +func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config) (out Conn) { + c.Write([]byte{byte(FRP_TLS_HEAD_BYTE)}) + out = WrapConn(tls.Client(c, tlsConfig)) + return +} + +func CheckAndEnableTLSServerConn(c net.Conn, tlsConfig *tls.Config) (out Conn) { + sc, r := gnet.NewSharedConnSize(c, 1) + buf := make([]byte, 1) + n, _ := r.Read(buf) + if n == 1 && int(buf[0]) == FRP_TLS_HEAD_BYTE { + out = WrapConn(tls.Server(c, tlsConfig)) + } else { + out = WrapConn(sc) + } + return +} diff --git a/utils/version/version.go b/utils/version/version.go index 41897cb..6fb34ec 100644 --- a/utils/version/version.go +++ b/utils/version/version.go @@ -19,7 +19,7 @@ import ( "strings" ) -var version string = "0.24.1" +var version string = "0.25.0" func Full() string { return version -- GitLab