提交 5db605ca 编写于 作者: F fatedier

frps: vhost_http_port and vhost_https_port can be same with frps bind

port
上级 f45283db
......@@ -16,6 +16,7 @@ kcp_bind_port = 7000
# proxy_bind_addr = 127.0.0.1
# if you want to support virtual host, you must set the http port for listening (optional)
# Note: http port and https port can be same with bind_port
vhost_http_port = 80
vhost_https_port = 443
......
......@@ -26,6 +26,7 @@ import (
"github.com/fatedier/frp/models/msg"
"github.com/fatedier/frp/utils/log"
frpNet "github.com/fatedier/frp/utils/net"
"github.com/fatedier/frp/utils/net/mux"
"github.com/fatedier/frp/utils/util"
"github.com/fatedier/frp/utils/version"
"github.com/fatedier/frp/utils/vhost"
......@@ -41,6 +42,9 @@ var ServerService *Service
// Server service.
type Service struct {
// Dispatch connections to different handlers listen on same port.
muxer *mux.Mux
// Accept connections from client.
listener frpNet.Listener
......@@ -88,12 +92,33 @@ func NewService() (svr *Service, err error) {
return
}
var (
httpMuxOn bool
httpsMuxOn bool
)
if cfg.BindAddr == cfg.ProxyBindAddr {
if cfg.BindPort == cfg.VhostHttpPort {
httpMuxOn = true
}
if cfg.BindPort == cfg.VhostHttpsPort {
httpsMuxOn = true
}
if httpMuxOn || httpsMuxOn {
svr.muxer = mux.NewMux()
}
}
// Listen for accepting connections from client.
svr.listener, err = frpNet.ListenTcp(cfg.BindAddr, cfg.BindPort)
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.BindAddr, cfg.BindPort))
if err != nil {
err = fmt.Errorf("Create server listener error, %v", err)
return
}
if svr.muxer != nil {
go svr.muxer.Serve(ln)
ln = svr.muxer.DefaultListener()
}
svr.listener = frpNet.WrapLogListener(ln)
log.Info("frps tcp listen on %s:%d", cfg.BindAddr, cfg.BindPort)
// Listen for accepting connections from client using kcp protocol.
......@@ -117,10 +142,14 @@ func NewService() (svr *Service, err error) {
Handler: rp,
}
var l net.Listener
l, err = net.Listen("tcp", address)
if err != nil {
err = fmt.Errorf("Create vhost http listener error, %v", err)
return
if httpMuxOn {
l = svr.muxer.ListenHttp(0)
} else {
l, err = net.Listen("tcp", address)
if err != nil {
err = fmt.Errorf("Create vhost http listener error, %v", err)
return
}
}
go server.Serve(l)
log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort)
......@@ -128,13 +157,18 @@ func NewService() (svr *Service, err error) {
// Create https vhost muxer.
if cfg.VhostHttpsPort > 0 {
var l frpNet.Listener
l, err = frpNet.ListenTcp(cfg.ProxyBindAddr, cfg.VhostHttpsPort)
if err != nil {
err = fmt.Errorf("Create vhost https listener error, %v", err)
return
var l net.Listener
if httpsMuxOn {
l = svr.muxer.ListenHttps(0)
} else {
l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort))
if err != nil {
err = fmt.Errorf("Create server listener error, %v", err)
return
}
}
svr.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(l, 30*time.Second)
svr.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(frpNet.WrapLogListener(l), 30*time.Second)
if err != nil {
err = fmt.Errorf("Create vhost httpsMuxer error, %v", err)
return
......
......@@ -20,7 +20,6 @@ import (
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
......@@ -136,7 +135,6 @@ func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn
type SharedConn struct {
Conn
sync.Mutex
buf *bytes.Buffer
}
......@@ -149,22 +147,24 @@ func NewShareConn(conn Conn) (*SharedConn, io.Reader) {
return sc, io.TeeReader(conn, sc.buf)
}
func NewShareConnSize(conn Conn, bufSize int) (*SharedConn, io.Reader) {
sc := &SharedConn{
Conn: conn,
buf: bytes.NewBuffer(make([]byte, 0, bufSize)),
}
return sc, io.TeeReader(conn, sc.buf)
}
// Not thread safety.
func (sc *SharedConn) Read(p []byte) (n int, err error) {
sc.Lock()
if sc.buf == nil {
sc.Unlock()
return sc.Conn.Read(p)
}
sc.Unlock()
n, err = sc.buf.Read(p)
if err == io.EOF {
sc.Lock()
sc.buf = nil
sc.Unlock()
var n2 int
n2, err = sc.Conn.Read(p[n:])
n += n2
}
return
......
package mux
import (
"fmt"
"io"
"net"
"sort"
"sync"
"time"
"github.com/fatedier/frp/utils/errors"
frpNet "github.com/fatedier/frp/utils/net"
)
const (
// DefaultTimeout is the default length of time to wait for bytes we need.
DefaultTimeout = 10 * time.Second
)
type Mux struct {
ln net.Listener
defaultLn *listener
lns []*listener
maxNeedBytesNum uint32
mu sync.RWMutex
}
func NewMux() (mux *Mux) {
mux = &Mux{
lns: make([]*listener, 0),
}
return
}
func (mux *Mux) Listen(priority int, needBytesNum uint32, fn MatchFunc) net.Listener {
ln := &listener{
c: make(chan net.Conn),
mux: mux,
needBytesNum: needBytesNum,
matchFn: fn,
}
mux.mu.Lock()
defer mux.mu.Unlock()
if needBytesNum > mux.maxNeedBytesNum {
mux.maxNeedBytesNum = needBytesNum
}
newlns := append(mux.copyLns(), ln)
sort.Slice(newlns, func(i, j int) bool {
return newlns[i].needBytesNum < newlns[j].needBytesNum
})
mux.lns = newlns
return ln
}
func (mux *Mux) ListenHttp(priority int) net.Listener {
return mux.Listen(priority, HttpNeedBytesNum, HttpMatchFunc)
}
func (mux *Mux) ListenHttps(priority int) net.Listener {
return mux.Listen(priority, HttpsNeedBytesNum, HttpsMatchFunc)
}
func (mux *Mux) DefaultListener() net.Listener {
mux.mu.Lock()
defer mux.mu.Unlock()
if mux.defaultLn == nil {
mux.defaultLn = &listener{
c: make(chan net.Conn),
mux: mux,
}
}
return mux.defaultLn
}
func (mux *Mux) release(ln *listener) bool {
result := false
mux.mu.Lock()
defer mux.mu.Unlock()
lns := mux.copyLns()
for i, l := range lns {
if l == ln {
lns = append(lns[:i], lns[i+1:]...)
result = true
}
}
mux.lns = lns
return result
}
func (mux *Mux) copyLns() []*listener {
lns := make([]*listener, 0, len(mux.lns))
for _, l := range mux.lns {
lns = append(lns, l)
}
return lns
}
// Serve handles connections from ln and multiplexes then across registered listeners.
func (mux *Mux) Serve(ln net.Listener) error {
mux.mu.Lock()
mux.ln = ln
mux.mu.Unlock()
for {
// Wait for the next connection.
// If it returns a temporary error then simply retry.
// If it returns any other error then exit immediately.
conn, err := ln.Accept()
if err, ok := err.(interface {
Temporary() bool
}); ok && err.Temporary() {
continue
}
if err != nil {
return err
}
go mux.handleConn(conn)
}
}
func (mux *Mux) handleConn(conn net.Conn) {
mux.mu.RLock()
maxNeedBytesNum := mux.maxNeedBytesNum
lns := mux.lns
defaultLn := mux.defaultLn
mux.mu.RUnlock()
shareConn, rd := frpNet.NewShareConnSize(frpNet.WrapConn(conn), int(maxNeedBytesNum))
data := make([]byte, maxNeedBytesNum)
conn.SetReadDeadline(time.Now().Add(DefaultTimeout))
_, err := io.ReadFull(rd, data)
if err != nil {
conn.Close()
return
}
conn.SetReadDeadline(time.Time{})
for _, ln := range lns {
if match := ln.matchFn(data); match {
err = errors.PanicToError(func() {
ln.c <- shareConn
})
if err != nil {
conn.Close()
}
return
}
}
// No match listeners
if defaultLn != nil {
err = errors.PanicToError(func() {
defaultLn.c <- shareConn
})
if err != nil {
conn.Close()
}
return
}
// No listeners for this connection, close it.
conn.Close()
return
}
type listener struct {
mux *Mux
needBytesNum uint32
matchFn MatchFunc
c chan net.Conn
mu sync.RWMutex
}
// Accept waits for and returns the next connection to the listener.
func (ln *listener) Accept() (net.Conn, error) {
conn, ok := <-ln.c
if !ok {
return nil, fmt.Errorf("network connection closed")
}
return conn, nil
}
// Close removes this listener from the parent mux and closes the channel.
func (ln *listener) Close() error {
if ok := ln.mux.release(ln); ok {
// Close done to signal to any RLock holders to release their lock.
close(ln.c)
}
return nil
}
func (ln *listener) Addr() net.Addr {
if ln.mux == nil {
return nil
}
ln.mux.mu.RLock()
defer ln.mux.mu.RUnlock()
if ln.mux.ln == nil {
return nil
}
return ln.mux.ln.Addr()
}
package mux
import (
"bufio"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func runHttpSvr(ln net.Listener) *httptest.Server {
svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("http service"))
}))
svr.Listener = ln
svr.Start()
return svr
}
func runHttpsSvr(ln net.Listener) *httptest.Server {
svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("https service"))
}))
svr.Listener = ln
svr.StartTLS()
return svr
}
func runEchoSvr(ln net.Listener) {
go func() {
for {
conn, err := ln.Accept()
if err != nil {
return
}
rd := bufio.NewReader(conn)
data, err := rd.ReadString('\n')
if err != nil {
return
}
conn.Write([]byte(data))
conn.Close()
}
}()
}
func TestMux(t *testing.T) {
assert := assert.New(t)
ln, err := net.Listen("tcp", "127.0.0.1:")
assert.NoError(err)
mux := NewMux()
httpLn := mux.ListenHttp(0)
httpsLn := mux.ListenHttps(0)
defaultLn := mux.DefaultListener()
go mux.Serve(ln)
time.Sleep(100 * time.Millisecond)
httpSvr := runHttpSvr(httpLn)
defer httpSvr.Close()
httpsSvr := runHttpsSvr(httpsLn)
defer httpsSvr.Close()
runEchoSvr(defaultLn)
defer ln.Close()
// test http service
resp, err := http.Get(httpSvr.URL)
assert.NoError(err)
data, err := ioutil.ReadAll(resp.Body)
assert.NoError(err)
assert.Equal("http service", string(data))
// test https service
client := httpsSvr.Client()
resp, err = client.Get(httpsSvr.URL)
assert.NoError(err)
data, err = ioutil.ReadAll(resp.Body)
assert.NoError(err)
assert.Equal("https service", string(data))
// test echo service
conn, err := net.Dial("tcp", ln.Addr().String())
assert.NoError(err)
_, err = conn.Write([]byte("test echo\n"))
assert.NoError(err)
data = make([]byte, 1024)
n, err := conn.Read(data)
assert.NoError(err)
assert.Equal("test echo\n", string(data[:n]))
}
package mux
type MatchFunc func(data []byte) (match bool)
var (
HttpsNeedBytesNum uint32 = 1
HttpNeedBytesNum uint32 = 3
YamuxNeedBytesNum uint32 = 2
)
var HttpsMatchFunc MatchFunc = func(data []byte) bool {
if len(data) < int(HttpsNeedBytesNum) {
return false
}
if data[0] == 0x16 {
return true
} else {
return false
}
}
// From https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
var httpHeadBytes = map[string]struct{}{
"GET": struct{}{},
"HEA": struct{}{},
"POS": struct{}{},
"PUT": struct{}{},
"DEL": struct{}{},
"CON": struct{}{},
"OPT": struct{}{},
"TRA": struct{}{},
"PAT": struct{}{},
}
var HttpMatchFunc MatchFunc = func(data []byte) bool {
if len(data) < int(HttpNeedBytesNum) {
return false
}
_, ok := httpHeadBytes[string(data[:3])]
return ok
}
// From https://github.com/hashicorp/yamux/blob/master/spec.md
var YamuxMatchFunc MatchFunc = func(data []byte) bool {
if len(data) < int(YamuxNeedBytesNum) {
return false
}
if data[0] == 0 && data[1] >= 0x0 && data[1] <= 0x3 {
return true
}
return false
}
......@@ -55,14 +55,17 @@ func readHandshake(rd io.Reader) (host string, err error) {
data := pool.GetBuf(1024)
origin := data
defer pool.PutBuf(origin)
length, err := rd.Read(data)
_, err = io.ReadFull(rd, data[:47])
if err != nil {
return
}
length, err := rd.Read(data[47:])
if err != nil {
return
} else {
if length < 47 {
err = fmt.Errorf("readHandshake: proto length[%d] is too short", length)
return
}
length += 47
}
data = data[:length]
if uint8(data[5]) != typeClientHello {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册