diff --git a/conf/frpc_full.ini b/conf/frpc_full.ini index c3d13ffd58312914cbc6b6f7d4b9afaf9efe7d63..5a2e95d94ccd599f02b71b35de6168508fbea6ff 100644 --- a/conf/frpc_full.ini +++ b/conf/frpc_full.ini @@ -41,7 +41,7 @@ user = your_name login_fail_exit = true # communication protocol used to connect to server -# now it supports tcp and kcp, default is tcp +# now it supports tcp and kcp and websocket, default is tcp protocol = tcp # specify a dns server, so frpc will use this instead of default one diff --git a/models/config/client_common.go b/models/config/client_common.go index 95a383af402a6fefe2476020135c1bcafc6f85d1..c1d61cb4d395726cbc1dc10eaf587df30a47b189 100644 --- a/models/config/client_common.go +++ b/models/config/client_common.go @@ -187,7 +187,7 @@ func UnmarshalClientConfFromIni(defaultCfg *ClientCommonConf, content string) (c if tmpStr, ok = conf.Get("common", "protocol"); ok { // Now it only support tcp and kcp. - if tmpStr != "kcp" { + if tmpStr != "kcp" && tmpStr != "websocket" { tmpStr = "tcp" } cfg.Protocol = tmpStr diff --git a/server/service.go b/server/service.go index a9b14a62a993575d06c6c78c8e6a25f474a803f3..dcb7a2ba0447021116687a5c89fb6d9db85fe338 100644 --- a/server/service.go +++ b/server/service.go @@ -19,6 +19,7 @@ import ( "io/ioutil" "net" "net/http" + "strings" "time" "github.com/fatedier/frp/assets" @@ -53,6 +54,9 @@ type Service struct { // Accept connections using kcp kcpListener frpNet.Listener + // Accept connections using websocket + websocketListener frpNet.Listener + // For https proxies, route requests to different clients by hostname and other infomation VhostHttpsMuxer *vhost.HttpsMuxer @@ -109,9 +113,6 @@ func NewService() (svr *Service, err error) { if cfg.BindPort == cfg.VhostHttpsPort { httpsMuxOn = true } - if httpMuxOn || httpsMuxOn { - svr.muxer = mux.NewMux() - } } // Listen for accepting connections from client. @@ -120,10 +121,11 @@ func NewService() (svr *Service, err error) { err = fmt.Errorf("Create server listener error, %v", err) return } - if svr.muxer != nil { - go svr.muxer.Serve(ln) - ln = svr.muxer.DefaultListener() - } + + svr.muxer = mux.NewMux() + go svr.muxer.Serve(ln) + ln = svr.muxer.DefaultListener() + svr.listener = frpNet.WrapLogListener(ln) log.Info("frps tcp listen on %s:%d", cfg.BindAddr, cfg.BindPort) @@ -148,16 +150,14 @@ func NewService() (svr *Service, err error) { Handler: rp, } var l net.Listener - if httpMuxOn { - l = svr.muxer.ListenHttp(0) - } else { + if !httpMuxOn { l, err = net.Listen("tcp", address) if err != nil { err = fmt.Errorf("Create vhost http listener error, %v", err) return } + go server.Serve(l) } - go server.Serve(l) log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort) } @@ -204,6 +204,38 @@ func NewService() (svr *Service, err error) { } log.Info("Dashboard listen on %s:%d", cfg.DashboardAddr, cfg.DashboardPort) } + + if !httpMuxOn { + svr.websocketListener, err = frpNet.NewWebsocketListener(svr.muxer.ListenHttp(0), nil) + return + } + + // server := &http.Server{} + if httpMuxOn { + rp := svr.httpReverseProxy + svr.websocketListener, err = frpNet.NewWebsocketListener(svr.muxer.ListenHttp(0), + func(w http.ResponseWriter, req *http.Request) bool { + domain := getHostFromAddr(req.Host) + location := req.URL.Path + headers := rp.GetHeaders(domain, location) + if headers == nil { + return true + } + rp.ServeHTTP(w, req) + return false + }) + } + + return +} + +func getHostFromAddr(addr string) (host string) { + strs := strings.Split(addr, ":") + if len(strs) > 1 { + host = strs[0] + } else { + host = addr + } return } @@ -214,8 +246,10 @@ func (svr *Service) Run() { if g.GlbServerCfg.KcpBindPort > 0 { go svr.HandleListener(svr.kcpListener) } + if svr.websocketListener != nil { + go svr.HandleListener(svr.websocketListener) + } svr.HandleListener(svr.listener) - } func (svr *Service) HandleListener(l frpNet.Listener) { @@ -226,7 +260,6 @@ func (svr *Service) HandleListener(l frpNet.Listener) { log.Warn("Listener for incoming connections from client closed") return } - // Start a new goroutine for dealing connections. go func(frpConn frpNet.Conn) { dealFn := func(conn frpNet.Conn) { diff --git a/utils/net/conn.go b/utils/net/conn.go index 81dc82bbb02aa12173b02e82d75e6de9a08b6190..825a9896dd31f233727e63a596def371bccd8a0a 100644 --- a/utils/net/conn.go +++ b/utils/net/conn.go @@ -132,6 +132,8 @@ func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn case "kcp": // http proxy is not supported for kcp return ConnectServer(protocol, addr) + case "websocket": + return ConnectWebsocketServer(addr) default: return nil, fmt.Errorf("unsupport protocol: %s", protocol) } diff --git a/utils/net/websocket.go b/utils/net/websocket.go new file mode 100644 index 0000000000000000000000000000000000000000..0411112999f562436fe8cb56bdc4cb82ef447ee2 --- /dev/null +++ b/utils/net/websocket.go @@ -0,0 +1,127 @@ +package net + +import ( + "fmt" + "net" + "net/http" + "net/url" + "sync/atomic" + "time" + + "github.com/fatedier/frp/utils/log" + "golang.org/x/net/websocket" +) + +type WebsocketListener struct { + log.Logger + server *http.Server + httpMutex *http.ServeMux + connChan chan *WebsocketConn + closeFlag bool +} + +func NewWebsocketListener(ln net.Listener, + filter func(w http.ResponseWriter, r *http.Request) bool) (l *WebsocketListener, err error) { + l = &WebsocketListener{ + httpMutex: http.NewServeMux(), + connChan: make(chan *WebsocketConn), + Logger: log.NewPrefixLogger(""), + } + l.httpMutex.Handle("/", websocket.Handler(func(c *websocket.Conn) { + conn := NewWebScoketConn(c) + l.connChan <- conn + conn.waitClose() + })) + l.server = &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if filter != nil && !filter(w, r) { + return + } + l.httpMutex.ServeHTTP(w, r) + }), + } + ch := make(chan struct{}) + go func() { + close(ch) + err = l.server.Serve(ln) + }() + <-ch + <-time.After(time.Millisecond) + return +} + +func ListenWebsocket(bindAddr string, bindPort int) (l *WebsocketListener, err error) { + ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) + if err != nil { + return + } + l, err = NewWebsocketListener(ln, nil) + return +} + +func (p *WebsocketListener) Accept() (Conn, error) { + c := <-p.connChan + return c, nil +} + +func (p *WebsocketListener) Close() error { + if !p.closeFlag { + p.closeFlag = true + p.server.Close() + } + return nil +} + +type WebsocketConn struct { + net.Conn + log.Logger + closed int32 + wait chan struct{} +} + +func NewWebScoketConn(conn net.Conn) (c *WebsocketConn) { + c = &WebsocketConn{ + Conn: conn, + Logger: log.NewPrefixLogger(""), + wait: make(chan struct{}), + } + return +} + +func (p *WebsocketConn) Close() error { + if atomic.SwapInt32(&p.closed, 1) == 1 { + return nil + } + close(p.wait) + return p.Conn.Close() +} + +func (p *WebsocketConn) waitClose() { + <-p.wait +} + +// ConnectWebsocketServer : +// addr: ws://domain:port +func ConnectWebsocketServer(addr string) (c Conn, err error) { + addr = "ws://" + addr + uri, err := url.Parse(addr) + if err != nil { + return + } + + origin := "http://" + uri.Host + cfg, err := websocket.NewConfig(addr, origin) + if err != nil { + return + } + cfg.Dialer = &net.Dialer{ + Timeout: time.Second * 10, + } + + conn, err := websocket.DialConfig(cfg) + if err != nil { + return + } + c = NewWebScoketConn(conn) + return +}