未验证 提交 32ee7280 编写于 作者: A astaxie 提交者: GitHub

Merge pull request #3586 from astaxie/develop

V1.12.0
github.com/astaxie/beego/*/*:S1012
github.com/astaxie/beego/*:S1012
github.com/astaxie/beego/*/*:S1007
github.com/astaxie/beego/*:S1007
\ No newline at end of file
language: go language: go
go: go:
- "1.10.x"
- "1.11.x" - "1.11.x"
services: services:
- redis-server - redis-server
...@@ -9,9 +8,19 @@ services: ...@@ -9,9 +8,19 @@ services:
- postgresql - postgresql
- memcached - memcached
env: env:
- ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db global:
- ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - GO_REPO_FULLNAME="github.com/astaxie/beego"
matrix:
- ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db
- ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
before_install: before_install:
# link the local repo with ${GOPATH}/src/<namespace>/<repo>
- GO_REPO_NAMESPACE=${GO_REPO_FULLNAME%/*}
# relies on GOPATH to contain only one directory...
- mkdir -p ${GOPATH}/src/${GO_REPO_NAMESPACE}
- ln -sv ${TRAVIS_BUILD_DIR} ${GOPATH}/src/${GO_REPO_FULLNAME}
- cd ${GOPATH}/src/${GO_REPO_FULLNAME}
# get and build ssdb
- git clone git://github.com/ideawu/ssdb.git - git clone git://github.com/ideawu/ssdb.git
- cd ssdb - cd ssdb
- make - make
...@@ -35,7 +44,9 @@ install: ...@@ -35,7 +44,9 @@ install:
- go get github.com/Knetic/govaluate - go get github.com/Knetic/govaluate
- go get github.com/casbin/casbin - go get github.com/casbin/casbin
- go get github.com/elazarl/go-bindata-assetfs - go get github.com/elazarl/go-bindata-assetfs
- go get -u honnef.co/go/tools/cmd/gosimple - go get github.com/OwnLocal/goes
- go get github.com/shiena/ansicolor
- go get -u honnef.co/go/tools/cmd/staticcheck
- go get -u github.com/mdempsky/unconvert - go get -u github.com/mdempsky/unconvert
- go get -u github.com/gordonklaus/ineffassign - go get -u github.com/gordonklaus/ineffassign
- go get -u github.com/golang/lint/golint - go get -u github.com/golang/lint/golint
...@@ -54,7 +65,7 @@ after_script: ...@@ -54,7 +65,7 @@ after_script:
- rm -rf ./res/var/* - rm -rf ./res/var/*
script: script:
- go test -v ./... - go test -v ./...
- gosimple -ignore "$(cat .gosimpleignore)" $(go list ./... | grep -v /vendor/) - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024"
- unconvert $(go list ./... | grep -v /vendor/) - unconvert $(go list ./... | grep -v /vendor/)
- ineffassign . - ineffassign .
- find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s
......
...@@ -176,7 +176,7 @@ func (app *App) Run(mws ...MiddleWare) { ...@@ -176,7 +176,7 @@ func (app *App) Run(mws ...MiddleWare) {
if BConfig.Listen.HTTPSPort != 0 { if BConfig.Listen.HTTPSPort != 0 {
app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
} else if BConfig.Listen.EnableHTTP { } else if BConfig.Listen.EnableHTTP {
BeeLogger.Info("Start https server error, conflict with http. Please reset https port") logs.Info("Start https server error, conflict with http. Please reset https port")
return return
} }
logs.Info("https server Running on https://%s", app.Server.Addr) logs.Info("https server Running on https://%s", app.Server.Addr)
...@@ -192,7 +192,7 @@ func (app *App) Run(mws ...MiddleWare) { ...@@ -192,7 +192,7 @@ func (app *App) Run(mws ...MiddleWare) {
pool := x509.NewCertPool() pool := x509.NewCertPool()
data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile) data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile)
if err != nil { if err != nil {
BeeLogger.Info("MutualHTTPS should provide TrustCaFile") logs.Info("MutualHTTPS should provide TrustCaFile")
return return
} }
pool.AppendCertsFromPEM(data) pool.AppendCertsFromPEM(data)
......
...@@ -23,7 +23,7 @@ import ( ...@@ -23,7 +23,7 @@ import (
const ( const (
// VERSION represent beego web framework version. // VERSION represent beego web framework version.
VERSION = "1.11.1" VERSION = "1.12.0"
// DEV is for develop // DEV is for develop
DEV = "dev" DEV = "dev"
......
...@@ -16,10 +16,33 @@ package cache ...@@ -16,10 +16,33 @@ package cache
import ( import (
"os" "os"
"sync"
"testing" "testing"
"time" "time"
) )
func TestCacheIncr(t *testing.T) {
bm, err := NewCache("memory", `{"interval":20}`)
if err != nil {
t.Error("init err")
}
//timeoutDuration := 10 * time.Second
bm.Put("edwardhey", 0, time.Second*20)
wg := sync.WaitGroup{}
wg.Add(10)
for i := 0; i < 10; i++ {
go func() {
defer wg.Done()
bm.Incr("edwardhey")
}()
}
wg.Wait()
if bm.Get("edwardhey").(int) != 10 {
t.Error("Incr err")
}
}
func TestCache(t *testing.T) { func TestCache(t *testing.T) {
bm, err := NewCache("memory", `{"interval":20}`) bm, err := NewCache("memory", `{"interval":20}`)
if err != nil { if err != nil {
...@@ -98,7 +121,7 @@ func TestCache(t *testing.T) { ...@@ -98,7 +121,7 @@ func TestCache(t *testing.T) {
} }
func TestFileCache(t *testing.T) { func TestFileCache(t *testing.T) {
bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}`) bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`)
if err != nil { if err != nil {
t.Error("init err") t.Error("init err")
} }
......
...@@ -62,11 +62,14 @@ func NewFileCache() Cache { ...@@ -62,11 +62,14 @@ func NewFileCache() Cache {
} }
// StartAndGC will start and begin gc for file cache. // StartAndGC will start and begin gc for file cache.
// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0} // the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}
func (fc *FileCache) StartAndGC(config string) error { func (fc *FileCache) StartAndGC(config string) error {
var cfg map[string]string cfg := make(map[string]string)
json.Unmarshal([]byte(config), &cfg) err := json.Unmarshal([]byte(config), &cfg)
if err != nil {
return err
}
if _, ok := cfg["CachePath"]; !ok { if _, ok := cfg["CachePath"]; !ok {
cfg["CachePath"] = FileCachePath cfg["CachePath"] = FileCachePath
} }
...@@ -142,12 +145,12 @@ func (fc *FileCache) GetMulti(keys []string) []interface{} { ...@@ -142,12 +145,12 @@ func (fc *FileCache) GetMulti(keys []string) []interface{} {
// Put value into file cache. // Put value into file cache.
// timeout means how long to keep this file, unit of ms. // timeout means how long to keep this file, unit of ms.
// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever. // if timeout equals fc.EmbedExpiry(default is 0), cache this item forever.
func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error {
gob.Register(val) gob.Register(val)
item := FileCacheItem{Data: val} item := FileCacheItem{Data: val}
if timeout == FileCacheEmbedExpiry { if timeout == time.Duration(fc.EmbedExpiry) {
item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years
} else { } else {
item.Expired = time.Now().Add(timeout) item.Expired = time.Now().Add(timeout)
...@@ -179,7 +182,7 @@ func (fc *FileCache) Incr(key string) error { ...@@ -179,7 +182,7 @@ func (fc *FileCache) Incr(key string) error {
} else { } else {
incr = data.(int) + 1 incr = data.(int) + 1
} }
fc.Put(key, incr, FileCacheEmbedExpiry) fc.Put(key, incr, time.Duration(fc.EmbedExpiry))
return nil return nil
} }
...@@ -192,7 +195,7 @@ func (fc *FileCache) Decr(key string) error { ...@@ -192,7 +195,7 @@ func (fc *FileCache) Decr(key string) error {
} else { } else {
decr = data.(int) - 1 decr = data.(int) - 1
} }
fc.Put(key, decr, FileCacheEmbedExpiry) fc.Put(key, decr, time.Duration(fc.EmbedExpiry))
return nil return nil
} }
......
...@@ -146,7 +146,7 @@ func (rc *Cache) IsExist(key string) bool { ...@@ -146,7 +146,7 @@ func (rc *Cache) IsExist(key string) bool {
} }
} }
_, err := rc.conn.Get(key) _, err := rc.conn.Get(key)
return !(err != nil) return err == nil
} }
// ClearAll clear all cached in memcache. // ClearAll clear all cached in memcache.
......
...@@ -110,25 +110,25 @@ func (bc *MemoryCache) Delete(name string) error { ...@@ -110,25 +110,25 @@ func (bc *MemoryCache) Delete(name string) error {
// Incr increase cache counter in memory. // Incr increase cache counter in memory.
// it supports int,int32,int64,uint,uint32,uint64. // it supports int,int32,int64,uint,uint32,uint64.
func (bc *MemoryCache) Incr(key string) error { func (bc *MemoryCache) Incr(key string) error {
bc.RLock() bc.Lock()
defer bc.RUnlock() defer bc.Unlock()
itm, ok := bc.items[key] itm, ok := bc.items[key]
if !ok { if !ok {
return errors.New("key not exist") return errors.New("key not exist")
} }
switch itm.val.(type) { switch val := itm.val.(type) {
case int: case int:
itm.val = itm.val.(int) + 1 itm.val = val + 1
case int32: case int32:
itm.val = itm.val.(int32) + 1 itm.val = val + 1
case int64: case int64:
itm.val = itm.val.(int64) + 1 itm.val = val + 1
case uint: case uint:
itm.val = itm.val.(uint) + 1 itm.val = val + 1
case uint32: case uint32:
itm.val = itm.val.(uint32) + 1 itm.val = val + 1
case uint64: case uint64:
itm.val = itm.val.(uint64) + 1 itm.val = val + 1
default: default:
return errors.New("item val is not (u)int (u)int32 (u)int64") return errors.New("item val is not (u)int (u)int32 (u)int64")
} }
...@@ -137,34 +137,34 @@ func (bc *MemoryCache) Incr(key string) error { ...@@ -137,34 +137,34 @@ func (bc *MemoryCache) Incr(key string) error {
// Decr decrease counter in memory. // Decr decrease counter in memory.
func (bc *MemoryCache) Decr(key string) error { func (bc *MemoryCache) Decr(key string) error {
bc.RLock() bc.Lock()
defer bc.RUnlock() defer bc.Unlock()
itm, ok := bc.items[key] itm, ok := bc.items[key]
if !ok { if !ok {
return errors.New("key not exist") return errors.New("key not exist")
} }
switch itm.val.(type) { switch val := itm.val.(type) {
case int: case int:
itm.val = itm.val.(int) - 1 itm.val = val - 1
case int64: case int64:
itm.val = itm.val.(int64) - 1 itm.val = val - 1
case int32: case int32:
itm.val = itm.val.(int32) - 1 itm.val = val - 1
case uint: case uint:
if itm.val.(uint) > 0 { if val > 0 {
itm.val = itm.val.(uint) - 1 itm.val = val - 1
} else { } else {
return errors.New("item val is less than 0") return errors.New("item val is less than 0")
} }
case uint32: case uint32:
if itm.val.(uint32) > 0 { if val > 0 {
itm.val = itm.val.(uint32) - 1 itm.val = val - 1
} else { } else {
return errors.New("item val is less than 0") return errors.New("item val is less than 0")
} }
case uint64: case uint64:
if itm.val.(uint64) > 0 { if val > 0 {
itm.val = itm.val.(uint64) - 1 itm.val = val - 1
} else { } else {
return errors.New("item val is less than 0") return errors.New("item val is less than 0")
} }
......
...@@ -97,7 +97,7 @@ func parseYML(buf []byte) (cnf map[string]interface{}, err error) { ...@@ -97,7 +97,7 @@ func parseYML(buf []byte) (cnf map[string]interface{}, err error) {
} }
} }
data, err := goyaml2.Read(bytes.NewBuffer(buf)) data, err := goyaml2.Read(bytes.NewReader(buf))
if err != nil { if err != nil {
log.Println("Goyaml2 ERR>", string(buf), err) log.Println("Goyaml2 ERR>", string(buf), err)
return return
......
...@@ -27,6 +27,7 @@ import ( ...@@ -27,6 +27,7 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"sync"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
) )
...@@ -49,6 +50,7 @@ type BeegoInput struct { ...@@ -49,6 +50,7 @@ type BeegoInput struct {
pnames []string pnames []string
pvalues []string pvalues []string
data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
dataLock sync.RWMutex
RequestBody []byte RequestBody []byte
RunMethod string RunMethod string
RunController reflect.Type RunController reflect.Type
...@@ -204,6 +206,7 @@ func (input *BeegoInput) AcceptsXML() bool { ...@@ -204,6 +206,7 @@ func (input *BeegoInput) AcceptsXML() bool {
func (input *BeegoInput) AcceptsJSON() bool { func (input *BeegoInput) AcceptsJSON() bool {
return acceptsJSONRegex.MatchString(input.Header("Accept")) return acceptsJSONRegex.MatchString(input.Header("Accept"))
} }
// AcceptsYAML Checks if request accepts json response // AcceptsYAML Checks if request accepts json response
func (input *BeegoInput) AcceptsYAML() bool { func (input *BeegoInput) AcceptsYAML() bool {
return acceptsYAMLRegex.MatchString(input.Header("Accept")) return acceptsYAMLRegex.MatchString(input.Header("Accept"))
...@@ -377,6 +380,8 @@ func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { ...@@ -377,6 +380,8 @@ func (input *BeegoInput) CopyBody(MaxMemory int64) []byte {
// Data return the implicit data in the input // Data return the implicit data in the input
func (input *BeegoInput) Data() map[interface{}]interface{} { func (input *BeegoInput) Data() map[interface{}]interface{} {
input.dataLock.Lock()
defer input.dataLock.Unlock()
if input.data == nil { if input.data == nil {
input.data = make(map[interface{}]interface{}) input.data = make(map[interface{}]interface{})
} }
...@@ -385,6 +390,8 @@ func (input *BeegoInput) Data() map[interface{}]interface{} { ...@@ -385,6 +390,8 @@ func (input *BeegoInput) Data() map[interface{}]interface{} {
// GetData returns the stored data in this context. // GetData returns the stored data in this context.
func (input *BeegoInput) GetData(key interface{}) interface{} { func (input *BeegoInput) GetData(key interface{}) interface{} {
input.dataLock.Lock()
defer input.dataLock.Unlock()
if v, ok := input.data[key]; ok { if v, ok := input.data[key]; ok {
return v return v
} }
...@@ -394,6 +401,8 @@ func (input *BeegoInput) GetData(key interface{}) interface{} { ...@@ -394,6 +401,8 @@ func (input *BeegoInput) GetData(key interface{}) interface{} {
// SetData stores data with given key in this context. // SetData stores data with given key in this context.
// This data are only available in this context. // This data are only available in this context.
func (input *BeegoInput) SetData(key, val interface{}) { func (input *BeegoInput) SetData(key, val interface{}) {
input.dataLock.Lock()
defer input.dataLock.Unlock()
if input.data == nil { if input.data == nil {
input.data = make(map[interface{}]interface{}) input.data = make(map[interface{}]interface{})
} }
......
...@@ -30,7 +30,8 @@ import ( ...@@ -30,7 +30,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"gopkg.in/yaml.v2"
yaml "gopkg.in/yaml.v2"
) )
// BeegoOutput does work for sending response header. // BeegoOutput does work for sending response header.
...@@ -203,7 +204,6 @@ func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) ...@@ -203,7 +204,6 @@ func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool)
return output.Body(content) return output.Body(content)
} }
// YAML writes yaml to response body. // YAML writes yaml to response body.
func (output *BeegoOutput) YAML(data interface{}) error { func (output *BeegoOutput) YAML(data interface{}) error {
output.Header("Content-Type", "application/x-yaml; charset=utf-8") output.Header("Content-Type", "application/x-yaml; charset=utf-8")
...@@ -288,7 +288,20 @@ func (output *BeegoOutput) Download(file string, filename ...string) { ...@@ -288,7 +288,20 @@ func (output *BeegoOutput) Download(file string, filename ...string) {
} else { } else {
fName = filepath.Base(file) fName = filepath.Base(file)
} }
output.Header("Content-Disposition", "attachment; filename="+url.PathEscape(fName)) //https://tools.ietf.org/html/rfc6266#section-4.3
fn := url.PathEscape(fName)
if fName == fn {
fn = "filename=" + fn
} else {
/**
The parameters "filename" and "filename*" differ only in that
"filename*" uses the encoding defined in [RFC5987], allowing the use
of characters not present in the ISO-8859-1 character set
([ISO-8859-1]).
*/
fn = "filename=" + fName + "; filename*=utf-8''" + fn
}
output.Header("Content-Disposition", "attachment; "+fn)
output.Header("Content-Description", "File Transfer") output.Header("Content-Description", "File Transfer")
output.Header("Content-Type", "application/octet-stream") output.Header("Content-Type", "application/octet-stream")
output.Header("Content-Transfer-Encoding", "binary") output.Header("Content-Transfer-Encoding", "binary")
......
...@@ -17,6 +17,7 @@ package beego ...@@ -17,6 +17,7 @@ package beego
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"html/template" "html/template"
"io" "io"
"mime/multipart" "mime/multipart"
...@@ -34,7 +35,7 @@ import ( ...@@ -34,7 +35,7 @@ import (
var ( var (
// ErrAbort custom error when user stop request handler manually. // ErrAbort custom error when user stop request handler manually.
ErrAbort = errors.New("User stop run") ErrAbort = errors.New("user stop run")
// GlobalControllerRouter store comments with controller. pkgpath+controller:comments // GlobalControllerRouter store comments with controller. pkgpath+controller:comments
GlobalControllerRouter = make(map[string][]ControllerComments) GlobalControllerRouter = make(map[string][]ControllerComments)
) )
...@@ -93,7 +94,6 @@ type Controller struct { ...@@ -93,7 +94,6 @@ type Controller struct {
controllerName string controllerName string
actionName string actionName string
methodMapping map[string]func() //method:routertree methodMapping map[string]func() //method:routertree
gotofunc string
AppController interface{} AppController interface{}
// template data // template data
...@@ -125,6 +125,7 @@ type ControllerInterface interface { ...@@ -125,6 +125,7 @@ type ControllerInterface interface {
Head() Head()
Patch() Patch()
Options() Options()
Trace()
Finish() Finish()
Render() error Render() error
XSRFToken() string XSRFToken() string
...@@ -156,37 +157,59 @@ func (c *Controller) Finish() {} ...@@ -156,37 +157,59 @@ func (c *Controller) Finish() {}
// Get adds a request function to handle GET request. // Get adds a request function to handle GET request.
func (c *Controller) Get() { func (c *Controller) Get() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
} }
// Post adds a request function to handle POST request. // Post adds a request function to handle POST request.
func (c *Controller) Post() { func (c *Controller) Post() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
} }
// Delete adds a request function to handle DELETE request. // Delete adds a request function to handle DELETE request.
func (c *Controller) Delete() { func (c *Controller) Delete() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
} }
// Put adds a request function to handle PUT request. // Put adds a request function to handle PUT request.
func (c *Controller) Put() { func (c *Controller) Put() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
} }
// Head adds a request function to handle HEAD request. // Head adds a request function to handle HEAD request.
func (c *Controller) Head() { func (c *Controller) Head() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
} }
// Patch adds a request function to handle PATCH request. // Patch adds a request function to handle PATCH request.
func (c *Controller) Patch() { func (c *Controller) Patch() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
} }
// Options adds a request function to handle OPTIONS request. // Options adds a request function to handle OPTIONS request.
func (c *Controller) Options() { func (c *Controller) Options() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Trace adds a request function to handle Trace request.
// this method SHOULD NOT be overridden.
// https://tools.ietf.org/html/rfc7231#section-4.3.8
// The TRACE method requests a remote, application-level loop-back of
// the request message. The final recipient of the request SHOULD
// reflect the message received, excluding some fields described below,
// back to the client as the message body of a 200 (OK) response with a
// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]).
func (c *Controller) Trace() {
ts := func(h http.Header) (hs string) {
for k, v := range h {
hs += fmt.Sprintf("\r\n%s: %s", k, v)
}
return
}
hs := fmt.Sprintf("\r\nTRACE %s %s%s\r\n", c.Ctx.Request.RequestURI, c.Ctx.Request.Proto, ts(c.Ctx.Request.Header))
c.Ctx.Output.Header("Content-Type", "message/http")
c.Ctx.Output.Header("Content-Length", fmt.Sprint(len(hs)))
c.Ctx.Output.Header("Cache-Control", "no-cache, no-store, must-revalidate")
c.Ctx.WriteString(hs)
} }
// HandlerFunc call function with the name // HandlerFunc call function with the name
...@@ -292,7 +315,7 @@ func (c *Controller) viewPath() string { ...@@ -292,7 +315,7 @@ func (c *Controller) viewPath() string {
// Redirect sends the redirection response to url with status code. // Redirect sends the redirection response to url with status code.
func (c *Controller) Redirect(url string, code int) { func (c *Controller) Redirect(url string, code int) {
logAccess(c.Ctx, nil, code) LogAccess(c.Ctx, nil, code)
c.Ctx.Redirect(code, url) c.Ctx.Redirect(code, url)
} }
......
...@@ -435,7 +435,7 @@ func exception(errCode string, ctx *context.Context) { ...@@ -435,7 +435,7 @@ func exception(errCode string, ctx *context.Context) {
func executeError(err *errorInfo, ctx *context.Context, code int) { func executeError(err *errorInfo, ctx *context.Context, code int) {
//make sure to log the error in the access log //make sure to log the error in the access log
logAccess(ctx, nil, code) LogAccess(ctx, nil, code)
if err.errorType == errorTypeHandler { if err.errorType == errorTypeHandler {
ctx.ResponseWriter.WriteHeader(code) ctx.ResponseWriter.WriteHeader(code)
......
...@@ -42,13 +42,13 @@ func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.Wal ...@@ -42,13 +42,13 @@ func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.Wal
} }
dir, err := fs.Open(path) dir, err := fs.Open(path)
defer dir.Close()
if err != nil { if err != nil {
if err1 := walkFn(path, info, err); err1 != nil { if err1 := walkFn(path, info, err); err1 != nil {
return err1 return err1
} }
return err return err
} }
defer dir.Close()
dirs, err := dir.Readdir(-1) dirs, err := dir.Readdir(-1)
err1 := walkFn(path, info, err) err1 := walkFn(path, info, err)
// If err != nil, walk can't walk into this directory. // If err != nil, walk can't walk into this directory.
......
...@@ -2,9 +2,9 @@ module github.com/astaxie/beego ...@@ -2,9 +2,9 @@ module github.com/astaxie/beego
require ( require (
github.com/Knetic/govaluate v3.0.0+incompatible // indirect github.com/Knetic/govaluate v3.0.0+incompatible // indirect
github.com/OwnLocal/goes v1.0.0
github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd
github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542 github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542
github.com/belogik/goes v0.0.0-20151229125003-e54d722c3aff
github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737
github.com/casbin/casbin v1.7.0 github.com/casbin/casbin v1.7.0
github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58
......
package grace
import (
"errors"
"net"
"sync"
)
type graceConn struct {
net.Conn
server *Server
m sync.Mutex
closed bool
}
func (c *graceConn) Close() (err error) {
defer func() {
if r := recover(); r != nil {
switch x := r.(type) {
case string:
err = errors.New(x)
case error:
err = x
default:
err = errors.New("Unknown panic")
}
}
}()
c.m.Lock()
if c.closed {
c.m.Unlock()
return
}
c.server.wg.Done()
c.closed = true
c.m.Unlock()
return c.Conn.Close()
}
...@@ -78,7 +78,7 @@ var ( ...@@ -78,7 +78,7 @@ var (
DefaultReadTimeOut time.Duration DefaultReadTimeOut time.Duration
// DefaultWriteTimeOut is the HTTP Write timeout // DefaultWriteTimeOut is the HTTP Write timeout
DefaultWriteTimeOut time.Duration DefaultWriteTimeOut time.Duration
// DefaultMaxHeaderBytes is the Max HTTP Herder size, default is 0, no limit // DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit
DefaultMaxHeaderBytes int DefaultMaxHeaderBytes int
// DefaultTimeout is the shutdown server's timeout. default is 60s // DefaultTimeout is the shutdown server's timeout. default is 60s
DefaultTimeout = 60 * time.Second DefaultTimeout = 60 * time.Second
...@@ -122,7 +122,6 @@ func NewServer(addr string, handler http.Handler) (srv *Server) { ...@@ -122,7 +122,6 @@ func NewServer(addr string, handler http.Handler) (srv *Server) {
} }
srv = &Server{ srv = &Server{
wg: sync.WaitGroup{},
sigChan: make(chan os.Signal), sigChan: make(chan os.Signal),
isChild: isChild, isChild: isChild,
SignalHooks: map[int]map[os.Signal][]func(){ SignalHooks: map[int]map[os.Signal][]func(){
...@@ -137,20 +136,21 @@ func NewServer(addr string, handler http.Handler) (srv *Server) { ...@@ -137,20 +136,21 @@ func NewServer(addr string, handler http.Handler) (srv *Server) {
syscall.SIGTERM: {}, syscall.SIGTERM: {},
}, },
}, },
state: StateInit, state: StateInit,
Network: "tcp", Network: "tcp",
terminalChan: make(chan error), //no cache channel
}
srv.Server = &http.Server{
Addr: addr,
ReadTimeout: DefaultReadTimeOut,
WriteTimeout: DefaultWriteTimeOut,
MaxHeaderBytes: DefaultMaxHeaderBytes,
Handler: handler,
} }
srv.Server = &http.Server{}
srv.Server.Addr = addr
srv.Server.ReadTimeout = DefaultReadTimeOut
srv.Server.WriteTimeout = DefaultWriteTimeOut
srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes
srv.Server.Handler = handler
runningServersOrder = append(runningServersOrder, addr) runningServersOrder = append(runningServersOrder, addr)
runningServers[addr] = srv runningServers[addr] = srv
return srv
return
} }
// ListenAndServe refer http.ListenAndServe // ListenAndServe refer http.ListenAndServe
......
package grace
import (
"net"
"os"
"syscall"
"time"
)
type graceListener struct {
net.Listener
stop chan error
stopped bool
server *Server
}
func newGraceListener(l net.Listener, srv *Server) (el *graceListener) {
el = &graceListener{
Listener: l,
stop: make(chan error),
server: srv,
}
go func() {
<-el.stop
el.stopped = true
el.stop <- el.Listener.Close()
}()
return
}
func (gl *graceListener) Accept() (c net.Conn, err error) {
tc, err := gl.Listener.(*net.TCPListener).AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
c = &graceConn{
Conn: tc,
server: gl.server,
}
gl.server.wg.Add(1)
return
}
func (gl *graceListener) Close() error {
if gl.stopped {
return syscall.EINVAL
}
gl.stop <- nil
return <-gl.stop
}
func (gl *graceListener) File() *os.File {
// returns a dup(2) - FD_CLOEXEC flag *not* set
tl := gl.Listener.(*net.TCPListener)
fl, _ := tl.File()
return fl
}
package grace package grace
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
...@@ -12,7 +13,6 @@ import ( ...@@ -12,7 +13,6 @@ import (
"os/exec" "os/exec"
"os/signal" "os/signal"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
) )
...@@ -20,14 +20,13 @@ import ( ...@@ -20,14 +20,13 @@ import (
// Server embedded http.Server // Server embedded http.Server
type Server struct { type Server struct {
*http.Server *http.Server
GraceListener net.Listener ln net.Listener
SignalHooks map[int]map[os.Signal][]func() SignalHooks map[int]map[os.Signal][]func()
tlsInnerListener *graceListener sigChan chan os.Signal
wg sync.WaitGroup isChild bool
sigChan chan os.Signal state uint8
isChild bool Network string
state uint8 terminalChan chan error
Network string
} }
// Serve accepts incoming connections on the Listener l, // Serve accepts incoming connections on the Listener l,
...@@ -35,11 +34,19 @@ type Server struct { ...@@ -35,11 +34,19 @@ type Server struct {
// The service goroutines read requests and then call srv.Handler to reply to them. // The service goroutines read requests and then call srv.Handler to reply to them.
func (srv *Server) Serve() (err error) { func (srv *Server) Serve() (err error) {
srv.state = StateRunning srv.state = StateRunning
err = srv.Server.Serve(srv.GraceListener) defer func() { srv.state = StateTerminate }()
log.Println(syscall.Getpid(), "Waiting for connections to finish...")
srv.wg.Wait() // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
srv.state = StateTerminate // immediately return ErrServerClosed. Make sure the program doesn't exit
return // and waits instead for Shutdown to return.
if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed {
log.Println(syscall.Getpid(), "Server.Serve() error:", err)
return err
}
log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.")
// wait for Shutdown to return
return <-srv.terminalChan
} }
// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
...@@ -53,14 +60,12 @@ func (srv *Server) ListenAndServe() (err error) { ...@@ -53,14 +60,12 @@ func (srv *Server) ListenAndServe() (err error) {
go srv.handleSignals() go srv.handleSignals()
l, err := srv.getListener(addr) srv.ln, err = srv.getListener(addr)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return err return err
} }
srv.GraceListener = newGraceListener(l, srv)
if srv.isChild { if srv.isChild {
process, err := os.FindProcess(os.Getppid()) process, err := os.FindProcess(os.Getppid())
if err != nil { if err != nil {
...@@ -107,14 +112,12 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { ...@@ -107,14 +112,12 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
go srv.handleSignals() go srv.handleSignals()
l, err := srv.getListener(addr) ln, err := srv.getListener(addr)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return err return err
} }
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
srv.tlsInnerListener = newGraceListener(l, srv)
srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig)
if srv.isChild { if srv.isChild {
process, err := os.FindProcess(os.Getppid()) process, err := os.FindProcess(os.Getppid())
...@@ -127,6 +130,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { ...@@ -127,6 +130,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
return err return err
} }
} }
log.Println(os.Getpid(), srv.Addr) log.Println(os.Getpid(), srv.Addr)
return srv.Serve() return srv.Serve()
} }
...@@ -163,14 +167,12 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) ...@@ -163,14 +167,12 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string)
log.Println("Mutual HTTPS") log.Println("Mutual HTTPS")
go srv.handleSignals() go srv.handleSignals()
l, err := srv.getListener(addr) ln, err := srv.getListener(addr)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return err return err
} }
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
srv.tlsInnerListener = newGraceListener(l, srv)
srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig)
if srv.isChild { if srv.isChild {
process, err := os.FindProcess(os.Getppid()) process, err := os.FindProcess(os.Getppid())
...@@ -183,6 +185,7 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) ...@@ -183,6 +185,7 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string)
return err return err
} }
} }
log.Println(os.Getpid(), srv.Addr) log.Println(os.Getpid(), srv.Addr)
return srv.Serve() return srv.Serve()
} }
...@@ -213,6 +216,20 @@ func (srv *Server) getListener(laddr string) (l net.Listener, err error) { ...@@ -213,6 +216,20 @@ func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
return return
} }
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}
// handleSignals listens for os Signals and calls any hooked in function that the // handleSignals listens for os Signals and calls any hooked in function that the
// user had registered with the signal. // user had registered with the signal.
func (srv *Server) handleSignals() { func (srv *Server) handleSignals() {
...@@ -265,37 +282,14 @@ func (srv *Server) shutdown() { ...@@ -265,37 +282,14 @@ func (srv *Server) shutdown() {
} }
srv.state = StateShuttingDown srv.state = StateShuttingDown
log.Println(syscall.Getpid(), "Waiting for connections to finish...")
ctx := context.Background()
if DefaultTimeout >= 0 { if DefaultTimeout >= 0 {
go srv.serverTimeout(DefaultTimeout) var cancel context.CancelFunc
} ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
err := srv.GraceListener.Close() defer cancel()
if err != nil {
log.Println(syscall.Getpid(), "Listener.Close() error:", err)
} else {
log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.")
}
}
// serverTimeout forces the server to shutdown in a given timeout - whether it
// finished outstanding requests or not. if Read/WriteTimeout are not set or the
// max header size is very big a connection could hang
func (srv *Server) serverTimeout(d time.Duration) {
defer func() {
if r := recover(); r != nil {
log.Println("WaitGroup at 0", r)
}
}()
if srv.state != StateShuttingDown {
return
}
time.Sleep(d)
log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
for {
if srv.state == StateTerminate {
break
}
srv.wg.Done()
} }
srv.terminalChan <- srv.Server.Shutdown(ctx)
} }
func (srv *Server) fork() (err error) { func (srv *Server) fork() (err error) {
...@@ -309,12 +303,8 @@ func (srv *Server) fork() (err error) { ...@@ -309,12 +303,8 @@ func (srv *Server) fork() (err error) {
var files = make([]*os.File, len(runningServers)) var files = make([]*os.File, len(runningServers))
var orderArgs = make([]string, len(runningServers)) var orderArgs = make([]string, len(runningServers))
for _, srvPtr := range runningServers { for _, srvPtr := range runningServers {
switch srvPtr.GraceListener.(type) { f, _ := srvPtr.ln.(*net.TCPListener).File()
case *graceListener: files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f
files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File()
default:
files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
}
orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
} }
......
...@@ -206,10 +206,16 @@ func TestToJson(t *testing.T) { ...@@ -206,10 +206,16 @@ func TestToJson(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
t.Log(ip.Origin) t.Log(ip.Origin)
ips := strings.Split(ip.Origin, ",")
if n := strings.Count(ip.Origin, "."); n != 3 { if len(ips) == 0 {
t.Fatal("response is not valid ip") t.Fatal("response is not valid ip")
} }
for i := range ips {
if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil {
t.Fatal("response is not valid ip")
}
}
} }
func TestToFile(t *testing.T) { func TestToFile(t *testing.T) {
......
...@@ -21,6 +21,7 @@ import ( ...@@ -21,6 +21,7 @@ import (
) )
// Log levels to control the logging output. // Log levels to control the logging output.
// Deprecated: use github.com/astaxie/beego/logs instead.
const ( const (
LevelEmergency = iota LevelEmergency = iota
LevelAlert LevelAlert
...@@ -33,75 +34,90 @@ const ( ...@@ -33,75 +34,90 @@ const (
) )
// BeeLogger references the used application logger. // BeeLogger references the used application logger.
// Deprecated: use github.com/astaxie/beego/logs instead.
var BeeLogger = logs.GetBeeLogger() var BeeLogger = logs.GetBeeLogger()
// SetLevel sets the global log level used by the simple logger. // SetLevel sets the global log level used by the simple logger.
// Deprecated: use github.com/astaxie/beego/logs instead.
func SetLevel(l int) { func SetLevel(l int) {
logs.SetLevel(l) logs.SetLevel(l)
} }
// SetLogFuncCall set the CallDepth, default is 3 // SetLogFuncCall set the CallDepth, default is 3
// Deprecated: use github.com/astaxie/beego/logs instead.
func SetLogFuncCall(b bool) { func SetLogFuncCall(b bool) {
logs.SetLogFuncCall(b) logs.SetLogFuncCall(b)
} }
// SetLogger sets a new logger. // SetLogger sets a new logger.
// Deprecated: use github.com/astaxie/beego/logs instead.
func SetLogger(adaptername string, config string) error { func SetLogger(adaptername string, config string) error {
return logs.SetLogger(adaptername, config) return logs.SetLogger(adaptername, config)
} }
// Emergency logs a message at emergency level. // Emergency logs a message at emergency level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Emergency(v ...interface{}) { func Emergency(v ...interface{}) {
logs.Emergency(generateFmtStr(len(v)), v...) logs.Emergency(generateFmtStr(len(v)), v...)
} }
// Alert logs a message at alert level. // Alert logs a message at alert level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Alert(v ...interface{}) { func Alert(v ...interface{}) {
logs.Alert(generateFmtStr(len(v)), v...) logs.Alert(generateFmtStr(len(v)), v...)
} }
// Critical logs a message at critical level. // Critical logs a message at critical level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Critical(v ...interface{}) { func Critical(v ...interface{}) {
logs.Critical(generateFmtStr(len(v)), v...) logs.Critical(generateFmtStr(len(v)), v...)
} }
// Error logs a message at error level. // Error logs a message at error level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Error(v ...interface{}) { func Error(v ...interface{}) {
logs.Error(generateFmtStr(len(v)), v...) logs.Error(generateFmtStr(len(v)), v...)
} }
// Warning logs a message at warning level. // Warning logs a message at warning level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Warning(v ...interface{}) { func Warning(v ...interface{}) {
logs.Warning(generateFmtStr(len(v)), v...) logs.Warning(generateFmtStr(len(v)), v...)
} }
// Warn compatibility alias for Warning() // Warn compatibility alias for Warning()
// Deprecated: use github.com/astaxie/beego/logs instead.
func Warn(v ...interface{}) { func Warn(v ...interface{}) {
logs.Warn(generateFmtStr(len(v)), v...) logs.Warn(generateFmtStr(len(v)), v...)
} }
// Notice logs a message at notice level. // Notice logs a message at notice level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Notice(v ...interface{}) { func Notice(v ...interface{}) {
logs.Notice(generateFmtStr(len(v)), v...) logs.Notice(generateFmtStr(len(v)), v...)
} }
// Informational logs a message at info level. // Informational logs a message at info level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Informational(v ...interface{}) { func Informational(v ...interface{}) {
logs.Informational(generateFmtStr(len(v)), v...) logs.Informational(generateFmtStr(len(v)), v...)
} }
// Info compatibility alias for Warning() // Info compatibility alias for Warning()
// Deprecated: use github.com/astaxie/beego/logs instead.
func Info(v ...interface{}) { func Info(v ...interface{}) {
logs.Info(generateFmtStr(len(v)), v...) logs.Info(generateFmtStr(len(v)), v...)
} }
// Debug logs a message at debug level. // Debug logs a message at debug level.
// Deprecated: use github.com/astaxie/beego/logs instead.
func Debug(v ...interface{}) { func Debug(v ...interface{}) {
logs.Debug(generateFmtStr(len(v)), v...) logs.Debug(generateFmtStr(len(v)), v...)
} }
// Trace logs a message at trace level. // Trace logs a message at trace level.
// compatibility alias for Warning() // compatibility alias for Warning()
// Deprecated: use github.com/astaxie/beego/logs instead.
func Trace(v ...interface{}) { func Trace(v ...interface{}) {
logs.Trace(generateFmtStr(len(v)), v...) logs.Trace(generateFmtStr(len(v)), v...)
} }
......
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !windows
package logs
import "io"
type ansiColorWriter struct {
w io.Writer
mode outputMode
}
func (cw *ansiColorWriter) Write(p []byte) (int, error) {
return cw.w.Write(p)
}
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build windows
package logs
import (
"bytes"
"io"
"strings"
"syscall"
"unsafe"
)
type (
csiState int
parseResult int
)
const (
outsideCsiCode csiState = iota
firstCsiCode
secondCsiCode
)
const (
noConsole parseResult = iota
changedColor
unknown
)
type ansiColorWriter struct {
w io.Writer
mode outputMode
state csiState
paramStartBuf bytes.Buffer
paramBuf bytes.Buffer
}
const (
firstCsiChar byte = '\x1b'
secondeCsiChar byte = '['
separatorChar byte = ';'
sgrCode byte = 'm'
)
const (
foregroundBlue = uint16(0x0001)
foregroundGreen = uint16(0x0002)
foregroundRed = uint16(0x0004)
foregroundIntensity = uint16(0x0008)
backgroundBlue = uint16(0x0010)
backgroundGreen = uint16(0x0020)
backgroundRed = uint16(0x0040)
backgroundIntensity = uint16(0x0080)
underscore = uint16(0x8000)
foregroundMask = foregroundBlue | foregroundGreen | foregroundRed | foregroundIntensity
backgroundMask = backgroundBlue | backgroundGreen | backgroundRed | backgroundIntensity
)
const (
ansiReset = "0"
ansiIntensityOn = "1"
ansiIntensityOff = "21"
ansiUnderlineOn = "4"
ansiUnderlineOff = "24"
ansiBlinkOn = "5"
ansiBlinkOff = "25"
ansiForegroundBlack = "30"
ansiForegroundRed = "31"
ansiForegroundGreen = "32"
ansiForegroundYellow = "33"
ansiForegroundBlue = "34"
ansiForegroundMagenta = "35"
ansiForegroundCyan = "36"
ansiForegroundWhite = "37"
ansiForegroundDefault = "39"
ansiBackgroundBlack = "40"
ansiBackgroundRed = "41"
ansiBackgroundGreen = "42"
ansiBackgroundYellow = "43"
ansiBackgroundBlue = "44"
ansiBackgroundMagenta = "45"
ansiBackgroundCyan = "46"
ansiBackgroundWhite = "47"
ansiBackgroundDefault = "49"
ansiLightForegroundGray = "90"
ansiLightForegroundRed = "91"
ansiLightForegroundGreen = "92"
ansiLightForegroundYellow = "93"
ansiLightForegroundBlue = "94"
ansiLightForegroundMagenta = "95"
ansiLightForegroundCyan = "96"
ansiLightForegroundWhite = "97"
ansiLightBackgroundGray = "100"
ansiLightBackgroundRed = "101"
ansiLightBackgroundGreen = "102"
ansiLightBackgroundYellow = "103"
ansiLightBackgroundBlue = "104"
ansiLightBackgroundMagenta = "105"
ansiLightBackgroundCyan = "106"
ansiLightBackgroundWhite = "107"
)
type drawType int
const (
foreground drawType = iota
background
)
type winColor struct {
code uint16
drawType drawType
}
var colorMap = map[string]winColor{
ansiForegroundBlack: {0, foreground},
ansiForegroundRed: {foregroundRed, foreground},
ansiForegroundGreen: {foregroundGreen, foreground},
ansiForegroundYellow: {foregroundRed | foregroundGreen, foreground},
ansiForegroundBlue: {foregroundBlue, foreground},
ansiForegroundMagenta: {foregroundRed | foregroundBlue, foreground},
ansiForegroundCyan: {foregroundGreen | foregroundBlue, foreground},
ansiForegroundWhite: {foregroundRed | foregroundGreen | foregroundBlue, foreground},
ansiForegroundDefault: {foregroundRed | foregroundGreen | foregroundBlue, foreground},
ansiBackgroundBlack: {0, background},
ansiBackgroundRed: {backgroundRed, background},
ansiBackgroundGreen: {backgroundGreen, background},
ansiBackgroundYellow: {backgroundRed | backgroundGreen, background},
ansiBackgroundBlue: {backgroundBlue, background},
ansiBackgroundMagenta: {backgroundRed | backgroundBlue, background},
ansiBackgroundCyan: {backgroundGreen | backgroundBlue, background},
ansiBackgroundWhite: {backgroundRed | backgroundGreen | backgroundBlue, background},
ansiBackgroundDefault: {0, background},
ansiLightForegroundGray: {foregroundIntensity, foreground},
ansiLightForegroundRed: {foregroundIntensity | foregroundRed, foreground},
ansiLightForegroundGreen: {foregroundIntensity | foregroundGreen, foreground},
ansiLightForegroundYellow: {foregroundIntensity | foregroundRed | foregroundGreen, foreground},
ansiLightForegroundBlue: {foregroundIntensity | foregroundBlue, foreground},
ansiLightForegroundMagenta: {foregroundIntensity | foregroundRed | foregroundBlue, foreground},
ansiLightForegroundCyan: {foregroundIntensity | foregroundGreen | foregroundBlue, foreground},
ansiLightForegroundWhite: {foregroundIntensity | foregroundRed | foregroundGreen | foregroundBlue, foreground},
ansiLightBackgroundGray: {backgroundIntensity, background},
ansiLightBackgroundRed: {backgroundIntensity | backgroundRed, background},
ansiLightBackgroundGreen: {backgroundIntensity | backgroundGreen, background},
ansiLightBackgroundYellow: {backgroundIntensity | backgroundRed | backgroundGreen, background},
ansiLightBackgroundBlue: {backgroundIntensity | backgroundBlue, background},
ansiLightBackgroundMagenta: {backgroundIntensity | backgroundRed | backgroundBlue, background},
ansiLightBackgroundCyan: {backgroundIntensity | backgroundGreen | backgroundBlue, background},
ansiLightBackgroundWhite: {backgroundIntensity | backgroundRed | backgroundGreen | backgroundBlue, background},
}
var (
kernel32 = syscall.NewLazyDLL("kernel32.dll")
procSetConsoleTextAttribute = kernel32.NewProc("SetConsoleTextAttribute")
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
defaultAttr *textAttributes
)
func init() {
screenInfo := getConsoleScreenBufferInfo(uintptr(syscall.Stdout))
if screenInfo != nil {
colorMap[ansiForegroundDefault] = winColor{
screenInfo.WAttributes & (foregroundRed | foregroundGreen | foregroundBlue),
foreground,
}
colorMap[ansiBackgroundDefault] = winColor{
screenInfo.WAttributes & (backgroundRed | backgroundGreen | backgroundBlue),
background,
}
defaultAttr = convertTextAttr(screenInfo.WAttributes)
}
}
type coord struct {
X, Y int16
}
type smallRect struct {
Left, Top, Right, Bottom int16
}
type consoleScreenBufferInfo struct {
DwSize coord
DwCursorPosition coord
WAttributes uint16
SrWindow smallRect
DwMaximumWindowSize coord
}
func getConsoleScreenBufferInfo(hConsoleOutput uintptr) *consoleScreenBufferInfo {
var csbi consoleScreenBufferInfo
ret, _, _ := procGetConsoleScreenBufferInfo.Call(
hConsoleOutput,
uintptr(unsafe.Pointer(&csbi)))
if ret == 0 {
return nil
}
return &csbi
}
func setConsoleTextAttribute(hConsoleOutput uintptr, wAttributes uint16) bool {
ret, _, _ := procSetConsoleTextAttribute.Call(
hConsoleOutput,
uintptr(wAttributes))
return ret != 0
}
type textAttributes struct {
foregroundColor uint16
backgroundColor uint16
foregroundIntensity uint16
backgroundIntensity uint16
underscore uint16
otherAttributes uint16
}
func convertTextAttr(winAttr uint16) *textAttributes {
fgColor := winAttr & (foregroundRed | foregroundGreen | foregroundBlue)
bgColor := winAttr & (backgroundRed | backgroundGreen | backgroundBlue)
fgIntensity := winAttr & foregroundIntensity
bgIntensity := winAttr & backgroundIntensity
underline := winAttr & underscore
otherAttributes := winAttr &^ (foregroundMask | backgroundMask | underscore)
return &textAttributes{fgColor, bgColor, fgIntensity, bgIntensity, underline, otherAttributes}
}
func convertWinAttr(textAttr *textAttributes) uint16 {
var winAttr uint16
winAttr |= textAttr.foregroundColor
winAttr |= textAttr.backgroundColor
winAttr |= textAttr.foregroundIntensity
winAttr |= textAttr.backgroundIntensity
winAttr |= textAttr.underscore
winAttr |= textAttr.otherAttributes
return winAttr
}
func changeColor(param []byte) parseResult {
screenInfo := getConsoleScreenBufferInfo(uintptr(syscall.Stdout))
if screenInfo == nil {
return noConsole
}
winAttr := convertTextAttr(screenInfo.WAttributes)
strParam := string(param)
if len(strParam) <= 0 {
strParam = "0"
}
csiParam := strings.Split(strParam, string(separatorChar))
for _, p := range csiParam {
c, ok := colorMap[p]
switch {
case !ok:
switch p {
case ansiReset:
winAttr.foregroundColor = defaultAttr.foregroundColor
winAttr.backgroundColor = defaultAttr.backgroundColor
winAttr.foregroundIntensity = defaultAttr.foregroundIntensity
winAttr.backgroundIntensity = defaultAttr.backgroundIntensity
winAttr.underscore = 0
winAttr.otherAttributes = 0
case ansiIntensityOn:
winAttr.foregroundIntensity = foregroundIntensity
case ansiIntensityOff:
winAttr.foregroundIntensity = 0
case ansiUnderlineOn:
winAttr.underscore = underscore
case ansiUnderlineOff:
winAttr.underscore = 0
case ansiBlinkOn:
winAttr.backgroundIntensity = backgroundIntensity
case ansiBlinkOff:
winAttr.backgroundIntensity = 0
default:
// unknown code
}
case c.drawType == foreground:
winAttr.foregroundColor = c.code
case c.drawType == background:
winAttr.backgroundColor = c.code
}
}
winTextAttribute := convertWinAttr(winAttr)
setConsoleTextAttribute(uintptr(syscall.Stdout), winTextAttribute)
return changedColor
}
func parseEscapeSequence(command byte, param []byte) parseResult {
if defaultAttr == nil {
return noConsole
}
switch command {
case sgrCode:
return changeColor(param)
default:
return unknown
}
}
func (cw *ansiColorWriter) flushBuffer() (int, error) {
return cw.flushTo(cw.w)
}
func (cw *ansiColorWriter) resetBuffer() (int, error) {
return cw.flushTo(nil)
}
func (cw *ansiColorWriter) flushTo(w io.Writer) (int, error) {
var n1, n2 int
var err error
startBytes := cw.paramStartBuf.Bytes()
cw.paramStartBuf.Reset()
if w != nil {
n1, err = cw.w.Write(startBytes)
if err != nil {
return n1, err
}
} else {
n1 = len(startBytes)
}
paramBytes := cw.paramBuf.Bytes()
cw.paramBuf.Reset()
if w != nil {
n2, err = cw.w.Write(paramBytes)
if err != nil {
return n1 + n2, err
}
} else {
n2 = len(paramBytes)
}
return n1 + n2, nil
}
func isParameterChar(b byte) bool {
return ('0' <= b && b <= '9') || b == separatorChar
}
func (cw *ansiColorWriter) Write(p []byte) (int, error) {
var r, nw, first, last int
if cw.mode != DiscardNonColorEscSeq {
cw.state = outsideCsiCode
cw.resetBuffer()
}
var err error
for i, ch := range p {
switch cw.state {
case outsideCsiCode:
if ch == firstCsiChar {
cw.paramStartBuf.WriteByte(ch)
cw.state = firstCsiCode
}
case firstCsiCode:
switch ch {
case firstCsiChar:
cw.paramStartBuf.WriteByte(ch)
break
case secondeCsiChar:
cw.paramStartBuf.WriteByte(ch)
cw.state = secondCsiCode
last = i - 1
default:
cw.resetBuffer()
cw.state = outsideCsiCode
}
case secondCsiCode:
if isParameterChar(ch) {
cw.paramBuf.WriteByte(ch)
} else {
nw, err = cw.w.Write(p[first:last])
r += nw
if err != nil {
return r, err
}
first = i + 1
result := parseEscapeSequence(ch, cw.paramBuf.Bytes())
if result == noConsole || (cw.mode == OutputNonColorEscSeq && result == unknown) {
cw.paramBuf.WriteByte(ch)
nw, err := cw.flushBuffer()
if err != nil {
return r, err
}
r += nw
} else {
n, _ := cw.resetBuffer()
// Add one more to the size of the buffer for the last ch
r += n + 1
}
cw.state = outsideCsiCode
}
default:
cw.state = outsideCsiCode
}
}
if cw.mode != DiscardNonColorEscSeq || cw.state == outsideCsiCode {
nw, err = cw.w.Write(p[first:])
r += nw
}
return r, err
}
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build windows
package logs
import (
"bytes"
"fmt"
"syscall"
"testing"
)
var GetConsoleScreenBufferInfo = getConsoleScreenBufferInfo
func ChangeColor(color uint16) {
setConsoleTextAttribute(uintptr(syscall.Stdout), color)
}
func ResetColor() {
ChangeColor(uint16(0x0007))
}
func TestWritePlanText(t *testing.T) {
inner := bytes.NewBufferString("")
w := NewAnsiColorWriter(inner)
expected := "plain text"
fmt.Fprintf(w, expected)
actual := inner.String()
if actual != expected {
t.Errorf("Get %q, want %q", actual, expected)
}
}
func TestWriteParseText(t *testing.T) {
inner := bytes.NewBufferString("")
w := NewAnsiColorWriter(inner)
inputTail := "\x1b[0mtail text"
expectedTail := "tail text"
fmt.Fprintf(w, inputTail)
actualTail := inner.String()
inner.Reset()
if actualTail != expectedTail {
t.Errorf("Get %q, want %q", actualTail, expectedTail)
}
inputHead := "head text\x1b[0m"
expectedHead := "head text"
fmt.Fprintf(w, inputHead)
actualHead := inner.String()
inner.Reset()
if actualHead != expectedHead {
t.Errorf("Get %q, want %q", actualHead, expectedHead)
}
inputBothEnds := "both ends \x1b[0m text"
expectedBothEnds := "both ends text"
fmt.Fprintf(w, inputBothEnds)
actualBothEnds := inner.String()
inner.Reset()
if actualBothEnds != expectedBothEnds {
t.Errorf("Get %q, want %q", actualBothEnds, expectedBothEnds)
}
inputManyEsc := "\x1b\x1b\x1b\x1b[0m many esc"
expectedManyEsc := "\x1b\x1b\x1b many esc"
fmt.Fprintf(w, inputManyEsc)
actualManyEsc := inner.String()
inner.Reset()
if actualManyEsc != expectedManyEsc {
t.Errorf("Get %q, want %q", actualManyEsc, expectedManyEsc)
}
expectedSplit := "split text"
for _, ch := range "split \x1b[0m text" {
fmt.Fprintf(w, string(ch))
}
actualSplit := inner.String()
inner.Reset()
if actualSplit != expectedSplit {
t.Errorf("Get %q, want %q", actualSplit, expectedSplit)
}
}
type screenNotFoundError struct {
error
}
func writeAnsiColor(expectedText, colorCode string) (actualText string, actualAttributes uint16, err error) {
inner := bytes.NewBufferString("")
w := NewAnsiColorWriter(inner)
fmt.Fprintf(w, "\x1b[%sm%s", colorCode, expectedText)
actualText = inner.String()
screenInfo := GetConsoleScreenBufferInfo(uintptr(syscall.Stdout))
if screenInfo != nil {
actualAttributes = screenInfo.WAttributes
} else {
err = &screenNotFoundError{}
}
return
}
type testParam struct {
text string
attributes uint16
ansiColor string
}
func TestWriteAnsiColorText(t *testing.T) {
screenInfo := GetConsoleScreenBufferInfo(uintptr(syscall.Stdout))
if screenInfo == nil {
t.Fatal("Could not get ConsoleScreenBufferInfo")
}
defer ChangeColor(screenInfo.WAttributes)
defaultFgColor := screenInfo.WAttributes & uint16(0x0007)
defaultBgColor := screenInfo.WAttributes & uint16(0x0070)
defaultFgIntensity := screenInfo.WAttributes & uint16(0x0008)
defaultBgIntensity := screenInfo.WAttributes & uint16(0x0080)
fgParam := []testParam{
{"foreground black ", uint16(0x0000 | 0x0000), "30"},
{"foreground red ", uint16(0x0004 | 0x0000), "31"},
{"foreground green ", uint16(0x0002 | 0x0000), "32"},
{"foreground yellow ", uint16(0x0006 | 0x0000), "33"},
{"foreground blue ", uint16(0x0001 | 0x0000), "34"},
{"foreground magenta", uint16(0x0005 | 0x0000), "35"},
{"foreground cyan ", uint16(0x0003 | 0x0000), "36"},
{"foreground white ", uint16(0x0007 | 0x0000), "37"},
{"foreground default", defaultFgColor | 0x0000, "39"},
{"foreground light gray ", uint16(0x0000 | 0x0008 | 0x0000), "90"},
{"foreground light red ", uint16(0x0004 | 0x0008 | 0x0000), "91"},
{"foreground light green ", uint16(0x0002 | 0x0008 | 0x0000), "92"},
{"foreground light yellow ", uint16(0x0006 | 0x0008 | 0x0000), "93"},
{"foreground light blue ", uint16(0x0001 | 0x0008 | 0x0000), "94"},
{"foreground light magenta", uint16(0x0005 | 0x0008 | 0x0000), "95"},
{"foreground light cyan ", uint16(0x0003 | 0x0008 | 0x0000), "96"},
{"foreground light white ", uint16(0x0007 | 0x0008 | 0x0000), "97"},
}
bgParam := []testParam{
{"background black ", uint16(0x0007 | 0x0000), "40"},
{"background red ", uint16(0x0007 | 0x0040), "41"},
{"background green ", uint16(0x0007 | 0x0020), "42"},
{"background yellow ", uint16(0x0007 | 0x0060), "43"},
{"background blue ", uint16(0x0007 | 0x0010), "44"},
{"background magenta", uint16(0x0007 | 0x0050), "45"},
{"background cyan ", uint16(0x0007 | 0x0030), "46"},
{"background white ", uint16(0x0007 | 0x0070), "47"},
{"background default", uint16(0x0007) | defaultBgColor, "49"},
{"background light gray ", uint16(0x0007 | 0x0000 | 0x0080), "100"},
{"background light red ", uint16(0x0007 | 0x0040 | 0x0080), "101"},
{"background light green ", uint16(0x0007 | 0x0020 | 0x0080), "102"},
{"background light yellow ", uint16(0x0007 | 0x0060 | 0x0080), "103"},
{"background light blue ", uint16(0x0007 | 0x0010 | 0x0080), "104"},
{"background light magenta", uint16(0x0007 | 0x0050 | 0x0080), "105"},
{"background light cyan ", uint16(0x0007 | 0x0030 | 0x0080), "106"},
{"background light white ", uint16(0x0007 | 0x0070 | 0x0080), "107"},
}
resetParam := []testParam{
{"all reset", defaultFgColor | defaultBgColor | defaultFgIntensity | defaultBgIntensity, "0"},
{"all reset", defaultFgColor | defaultBgColor | defaultFgIntensity | defaultBgIntensity, ""},
}
boldParam := []testParam{
{"bold on", uint16(0x0007 | 0x0008), "1"},
{"bold off", uint16(0x0007), "21"},
}
underscoreParam := []testParam{
{"underscore on", uint16(0x0007 | 0x8000), "4"},
{"underscore off", uint16(0x0007), "24"},
}
blinkParam := []testParam{
{"blink on", uint16(0x0007 | 0x0080), "5"},
{"blink off", uint16(0x0007), "25"},
}
mixedParam := []testParam{
{"both black, bold, underline, blink", uint16(0x0000 | 0x0000 | 0x0008 | 0x8000 | 0x0080), "30;40;1;4;5"},
{"both red, bold, underline, blink", uint16(0x0004 | 0x0040 | 0x0008 | 0x8000 | 0x0080), "31;41;1;4;5"},
{"both green, bold, underline, blink", uint16(0x0002 | 0x0020 | 0x0008 | 0x8000 | 0x0080), "32;42;1;4;5"},
{"both yellow, bold, underline, blink", uint16(0x0006 | 0x0060 | 0x0008 | 0x8000 | 0x0080), "33;43;1;4;5"},
{"both blue, bold, underline, blink", uint16(0x0001 | 0x0010 | 0x0008 | 0x8000 | 0x0080), "34;44;1;4;5"},
{"both magenta, bold, underline, blink", uint16(0x0005 | 0x0050 | 0x0008 | 0x8000 | 0x0080), "35;45;1;4;5"},
{"both cyan, bold, underline, blink", uint16(0x0003 | 0x0030 | 0x0008 | 0x8000 | 0x0080), "36;46;1;4;5"},
{"both white, bold, underline, blink", uint16(0x0007 | 0x0070 | 0x0008 | 0x8000 | 0x0080), "37;47;1;4;5"},
{"both default, bold, underline, blink", uint16(defaultFgColor | defaultBgColor | 0x0008 | 0x8000 | 0x0080), "39;49;1;4;5"},
}
assertTextAttribute := func(expectedText string, expectedAttributes uint16, ansiColor string) {
actualText, actualAttributes, err := writeAnsiColor(expectedText, ansiColor)
if actualText != expectedText {
t.Errorf("Get %q, want %q", actualText, expectedText)
}
if err != nil {
t.Fatal("Could not get ConsoleScreenBufferInfo")
}
if actualAttributes != expectedAttributes {
t.Errorf("Text: %q, Get 0x%04x, want 0x%04x", expectedText, actualAttributes, expectedAttributes)
}
}
for _, v := range fgParam {
ResetColor()
assertTextAttribute(v.text, v.attributes, v.ansiColor)
}
for _, v := range bgParam {
ChangeColor(uint16(0x0070 | 0x0007))
assertTextAttribute(v.text, v.attributes, v.ansiColor)
}
for _, v := range resetParam {
ChangeColor(uint16(0x0000 | 0x0070 | 0x0008))
assertTextAttribute(v.text, v.attributes, v.ansiColor)
}
ResetColor()
for _, v := range boldParam {
assertTextAttribute(v.text, v.attributes, v.ansiColor)
}
ResetColor()
for _, v := range underscoreParam {
assertTextAttribute(v.text, v.attributes, v.ansiColor)
}
ResetColor()
for _, v := range blinkParam {
assertTextAttribute(v.text, v.attributes, v.ansiColor)
}
for _, v := range mixedParam {
ResetColor()
assertTextAttribute(v.text, v.attributes, v.ansiColor)
}
}
func TestIgnoreUnknownSequences(t *testing.T) {
inner := bytes.NewBufferString("")
w := NewModeAnsiColorWriter(inner, OutputNonColorEscSeq)
inputText := "\x1b[=decpath mode"
expectedTail := inputText
fmt.Fprintf(w, inputText)
actualTail := inner.String()
inner.Reset()
if actualTail != expectedTail {
t.Errorf("Get %q, want %q", actualTail, expectedTail)
}
inputText = "\x1b[=tailing esc and bracket\x1b["
expectedTail = inputText
fmt.Fprintf(w, inputText)
actualTail = inner.String()
inner.Reset()
if actualTail != expectedTail {
t.Errorf("Get %q, want %q", actualTail, expectedTail)
}
inputText = "\x1b[?tailing esc\x1b"
expectedTail = inputText
fmt.Fprintf(w, inputText)
actualTail = inner.String()
inner.Reset()
if actualTail != expectedTail {
t.Errorf("Get %q, want %q", actualTail, expectedTail)
}
inputText = "\x1b[1h;3punended color code invalid\x1b3"
expectedTail = inputText
fmt.Fprintf(w, inputText)
actualTail = inner.String()
inner.Reset()
if actualTail != expectedTail {
t.Errorf("Get %q, want %q", actualTail, expectedTail)
}
}
...@@ -63,7 +63,7 @@ func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { ...@@ -63,7 +63,7 @@ func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error {
defer c.innerWriter.Close() defer c.innerWriter.Close()
} }
c.lg.println(when, msg) c.lg.writeln(when, msg)
return nil return nil
} }
......
...@@ -17,8 +17,10 @@ package logs ...@@ -17,8 +17,10 @@ package logs
import ( import (
"encoding/json" "encoding/json"
"os" "os"
"runtime" "strings"
"time" "time"
"github.com/shiena/ansicolor"
) )
// brush is a color join function // brush is a color join function
...@@ -54,9 +56,9 @@ type consoleWriter struct { ...@@ -54,9 +56,9 @@ type consoleWriter struct {
// NewConsole create ConsoleWriter returning as LoggerInterface. // NewConsole create ConsoleWriter returning as LoggerInterface.
func NewConsole() Logger { func NewConsole() Logger {
cw := &consoleWriter{ cw := &consoleWriter{
lg: newLogWriter(os.Stdout), lg: newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)),
Level: LevelDebug, Level: LevelDebug,
Colorful: runtime.GOOS != "windows", Colorful: true,
} }
return cw return cw
} }
...@@ -67,11 +69,7 @@ func (c *consoleWriter) Init(jsonConfig string) error { ...@@ -67,11 +69,7 @@ func (c *consoleWriter) Init(jsonConfig string) error {
if len(jsonConfig) == 0 { if len(jsonConfig) == 0 {
return nil return nil
} }
err := json.Unmarshal([]byte(jsonConfig), c) return json.Unmarshal([]byte(jsonConfig), c)
if runtime.GOOS == "windows" {
c.Colorful = false
}
return err
} }
// WriteMsg write message in console. // WriteMsg write message in console.
...@@ -80,9 +78,9 @@ func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error { ...@@ -80,9 +78,9 @@ func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error {
return nil return nil
} }
if c.Colorful { if c.Colorful {
msg = colors[level](msg) msg = strings.Replace(msg, levelPrefix[level], colors[level](levelPrefix[level]), 1)
} }
c.lg.println(when, msg) c.lg.writeln(when, msg)
return nil return nil
} }
......
...@@ -8,8 +8,8 @@ import ( ...@@ -8,8 +8,8 @@ import (
"net/url" "net/url"
"time" "time"
"github.com/OwnLocal/goes"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
"github.com/belogik/goes"
) )
// NewES return a LoggerInterface // NewES return a LoggerInterface
...@@ -21,7 +21,7 @@ func NewES() logs.Logger { ...@@ -21,7 +21,7 @@ func NewES() logs.Logger {
} }
type esLogger struct { type esLogger struct {
*goes.Connection *goes.Client
DSN string `json:"dsn"` DSN string `json:"dsn"`
Level int `json:"level"` Level int `json:"level"`
} }
...@@ -41,8 +41,8 @@ func (el *esLogger) Init(jsonconfig string) error { ...@@ -41,8 +41,8 @@ func (el *esLogger) Init(jsonconfig string) error {
} else if host, port, err := net.SplitHostPort(u.Host); err != nil { } else if host, port, err := net.SplitHostPort(u.Host); err != nil {
return err return err
} else { } else {
conn := goes.NewConnection(host, port) conn := goes.NewClient(host, port)
el.Connection = conn el.Client = conn
} }
return nil return nil
} }
...@@ -78,3 +78,4 @@ func (el *esLogger) Flush() { ...@@ -78,3 +78,4 @@ func (el *esLogger) Flush() {
func init() { func init() {
logs.Register(logs.AdapterEs, NewES) logs.Register(logs.AdapterEs, NewES)
} }
...@@ -47,7 +47,7 @@ import ( ...@@ -47,7 +47,7 @@ import (
// RFC5424 log message levels. // RFC5424 log message levels.
const ( const (
LevelEmergency = iota LevelEmergency = iota
LevelAlert LevelAlert
LevelCritical LevelCritical
LevelError LevelError
...@@ -92,7 +92,7 @@ type Logger interface { ...@@ -92,7 +92,7 @@ type Logger interface {
} }
var adapters = make(map[string]newLoggerFunc) var adapters = make(map[string]newLoggerFunc)
var levelPrefix = [LevelDebug + 1]string{"[M] ", "[A] ", "[C] ", "[E] ", "[W] ", "[N] ", "[I] ", "[D] "} var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"}
// Register makes a log provide available by the provided name. // Register makes a log provide available by the provided name.
// If Register is called twice with the same name or if driver is nil, // If Register is called twice with the same name or if driver is nil,
...@@ -187,12 +187,12 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { ...@@ -187,12 +187,12 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error {
} }
} }
log, ok := adapters[adapterName] logAdapter, ok := adapters[adapterName]
if !ok { if !ok {
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
} }
lg := log() lg := logAdapter()
err := lg.Init(config) err := lg.Init(config)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error())
...@@ -248,7 +248,7 @@ func (bl *BeeLogger) Write(p []byte) (n int, err error) { ...@@ -248,7 +248,7 @@ func (bl *BeeLogger) Write(p []byte) (n int, err error) {
} }
// writeMsg will always add a '\n' character // writeMsg will always add a '\n' character
if p[len(p)-1] == '\n' { if p[len(p)-1] == '\n' {
p = p[0: len(p)-1] p = p[0 : len(p)-1]
} }
// set levelLoggerImpl to ensure all log message will be write out // set levelLoggerImpl to ensure all log message will be write out
err = bl.writeMsg(levelLoggerImpl, string(p)) err = bl.writeMsg(levelLoggerImpl, string(p))
...@@ -287,7 +287,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error ...@@ -287,7 +287,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error
// set to emergency to ensure all log will be print out correctly // set to emergency to ensure all log will be print out correctly
logLevel = LevelEmergency logLevel = LevelEmergency
} else { } else {
msg = levelPrefix[logLevel] + msg msg = levelPrefix[logLevel] + " " + msg
} }
if bl.asynchronous { if bl.asynchronous {
......
...@@ -15,9 +15,8 @@ ...@@ -15,9 +15,8 @@
package logs package logs
import ( import (
"fmt"
"io" "io"
"os" "runtime"
"sync" "sync"
"time" "time"
) )
...@@ -31,47 +30,13 @@ func newLogWriter(wr io.Writer) *logWriter { ...@@ -31,47 +30,13 @@ func newLogWriter(wr io.Writer) *logWriter {
return &logWriter{writer: wr} return &logWriter{writer: wr}
} }
func (lg *logWriter) println(when time.Time, msg string) { func (lg *logWriter) writeln(when time.Time, msg string) {
lg.Lock() lg.Lock()
h, _, _:= formatTimeHeader(when) h, _, _ := formatTimeHeader(when)
lg.writer.Write(append(append(h, msg...), '\n')) lg.writer.Write(append(append(h, msg...), '\n'))
lg.Unlock() lg.Unlock()
} }
type outputMode int
// DiscardNonColorEscSeq supports the divided color escape sequence.
// But non-color escape sequence is not output.
// Please use the OutputNonColorEscSeq If you want to output a non-color
// escape sequences such as ncurses. However, it does not support the divided
// color escape sequence.
const (
_ outputMode = iota
DiscardNonColorEscSeq
OutputNonColorEscSeq
)
// NewAnsiColorWriter creates and initializes a new ansiColorWriter
// using io.Writer w as its initial contents.
// In the console of Windows, which change the foreground and background
// colors of the text by the escape sequence.
// In the console of other systems, which writes to w all text.
func NewAnsiColorWriter(w io.Writer) io.Writer {
return NewModeAnsiColorWriter(w, DiscardNonColorEscSeq)
}
// NewModeAnsiColorWriter create and initializes a new ansiColorWriter
// by specifying the outputMode.
func NewModeAnsiColorWriter(w io.Writer, mode outputMode) io.Writer {
if _, ok := w.(*ansiColorWriter); !ok {
return &ansiColorWriter{
w: w,
mode: mode,
}
}
return w
}
const ( const (
y1 = `0123456789` y1 = `0123456789`
y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789` y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789`
...@@ -146,63 +111,65 @@ var ( ...@@ -146,63 +111,65 @@ var (
reset = string([]byte{27, 91, 48, 109}) reset = string([]byte{27, 91, 48, 109})
) )
var once sync.Once
var colorMap map[string]string
func initColor() {
if runtime.GOOS == "windows" {
green = w32Green
white = w32White
yellow = w32Yellow
red = w32Red
blue = w32Blue
magenta = w32Magenta
cyan = w32Cyan
}
colorMap = map[string]string{
//by color
"green": green,
"white": white,
"yellow": yellow,
"red": red,
//by method
"GET": blue,
"POST": cyan,
"PUT": yellow,
"DELETE": red,
"PATCH": green,
"HEAD": magenta,
"OPTIONS": white,
}
}
// ColorByStatus return color by http code // ColorByStatus return color by http code
// 2xx return Green // 2xx return Green
// 3xx return White // 3xx return White
// 4xx return Yellow // 4xx return Yellow
// 5xx return Red // 5xx return Red
func ColorByStatus(cond bool, code int) string { func ColorByStatus(code int) string {
once.Do(initColor)
switch { switch {
case code >= 200 && code < 300: case code >= 200 && code < 300:
return map[bool]string{true: green, false: w32Green}[cond] return colorMap["green"]
case code >= 300 && code < 400: case code >= 300 && code < 400:
return map[bool]string{true: white, false: w32White}[cond] return colorMap["white"]
case code >= 400 && code < 500: case code >= 400 && code < 500:
return map[bool]string{true: yellow, false: w32Yellow}[cond] return colorMap["yellow"]
default: default:
return map[bool]string{true: red, false: w32Red}[cond] return colorMap["red"]
} }
} }
// ColorByMethod return color by http code // ColorByMethod return color by http code
// GET return Blue func ColorByMethod(method string) string {
// POST return Cyan once.Do(initColor)
// PUT return Yellow if c := colorMap[method]; c != "" {
// DELETE return Red return c
// PATCH return Green
// HEAD return Magenta
// OPTIONS return WHITE
func ColorByMethod(cond bool, method string) string {
switch method {
case "GET":
return map[bool]string{true: blue, false: w32Blue}[cond]
case "POST":
return map[bool]string{true: cyan, false: w32Cyan}[cond]
case "PUT":
return map[bool]string{true: yellow, false: w32Yellow}[cond]
case "DELETE":
return map[bool]string{true: red, false: w32Red}[cond]
case "PATCH":
return map[bool]string{true: green, false: w32Green}[cond]
case "HEAD":
return map[bool]string{true: magenta, false: w32Magenta}[cond]
case "OPTIONS":
return map[bool]string{true: white, false: w32White}[cond]
default:
return reset
} }
return reset
} }
// Guard Mutex to guarantee atomic of W32Debug(string) function // ResetColor return reset color
var mu sync.Mutex func ResetColor() string {
return reset
// W32Debug Helper method to output colored logs in Windows terminals
func W32Debug(msg string) {
mu.Lock()
defer mu.Unlock()
current := time.Now()
w := NewAnsiColorWriter(os.Stdout)
fmt.Fprintf(w, "[beego] %v %s\n", current.Format("2006/01/02 - 15:04:05"), msg)
} }
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
package logs package logs
import ( import (
"bytes"
"testing" "testing"
"time" "time"
) )
...@@ -56,20 +55,3 @@ func TestFormatHeader_1(t *testing.T) { ...@@ -56,20 +55,3 @@ func TestFormatHeader_1(t *testing.T) {
tm = tm.Add(dur) tm = tm.Add(dur)
} }
} }
func TestNewAnsiColor1(t *testing.T) {
inner := bytes.NewBufferString("")
w := NewAnsiColorWriter(inner)
if w == inner {
t.Errorf("Get %#v, want %#v", w, inner)
}
}
func TestNewAnsiColor2(t *testing.T) {
inner := bytes.NewBufferString("")
w1 := NewAnsiColorWriter(inner)
w2 := NewAnsiColorWriter(w1)
if w1 != w2 {
t.Errorf("Get %#v, want %#v", w1, w2)
}
}
...@@ -17,7 +17,7 @@ package migration ...@@ -17,7 +17,7 @@ package migration
import ( import (
"fmt" "fmt"
"github.com/astaxie/beego" "github.com/astaxie/beego/logs"
) )
// Index struct defines the structure of Index Columns // Index struct defines the structure of Index Columns
...@@ -316,7 +316,7 @@ func (m *Migration) GetSQL() (sql string) { ...@@ -316,7 +316,7 @@ func (m *Migration) GetSQL() (sql string) {
sql += fmt.Sprintf("ALTER TABLE `%s` ", m.TableName) sql += fmt.Sprintf("ALTER TABLE `%s` ", m.TableName)
for index, column := range m.Columns { for index, column := range m.Columns {
if !column.remove { if !column.remove {
beego.BeeLogger.Info("col") logs.Info("col")
sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
} else { } else {
sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name)
......
...@@ -176,8 +176,9 @@ func Register(name string, m Migrationer) error { ...@@ -176,8 +176,9 @@ func Register(name string, m Migrationer) error {
func Upgrade(lasttime int64) error { func Upgrade(lasttime int64) error {
sm := sortMap(migrationMap) sm := sortMap(migrationMap)
i := 0 i := 0
migs, _ := getAllMigrations()
for _, v := range sm { for _, v := range sm {
if v.created > lasttime { if _, ok := migs[v.name]; !ok {
logs.Info("start upgrade", v.name) logs.Info("start upgrade", v.name)
v.m.Reset() v.m.Reset()
v.m.Up() v.m.Up()
...@@ -310,3 +311,20 @@ func isRollBack(name string) bool { ...@@ -310,3 +311,20 @@ func isRollBack(name string) bool {
} }
return false return false
} }
func getAllMigrations() (map[string]string, error) {
o := orm.NewOrm()
var maps []orm.Params
migs := make(map[string]string)
num, err := o.Raw("select * from migrations order by id_migration desc").Values(&maps)
if err != nil {
logs.Info("get name has error", err)
return migs, err
}
if num > 0 {
for _, v := range maps {
name := v["name"].(string)
migs[name] = v["status"].(string)
}
}
return migs, nil
}
...@@ -207,11 +207,11 @@ func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { ...@@ -207,11 +207,11 @@ func (n *Namespace) Include(cList ...ControllerInterface) *Namespace {
func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
for _, ni := range ns { for _, ni := range ns {
for k, v := range ni.handlers.routers { for k, v := range ni.handlers.routers {
if t, ok := n.handlers.routers[k]; ok { if _, ok := n.handlers.routers[k]; ok {
addPrefix(v, ni.prefix) addPrefix(v, ni.prefix)
n.handlers.routers[k].AddTree(ni.prefix, v) n.handlers.routers[k].AddTree(ni.prefix, v)
} else { } else {
t = NewTree() t := NewTree()
t.AddTree(ni.prefix, v) t.AddTree(ni.prefix, v)
addPrefix(t, ni.prefix) addPrefix(t, ni.prefix)
n.handlers.routers[k] = t n.handlers.routers[k] = t
...@@ -236,11 +236,11 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { ...@@ -236,11 +236,11 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
func AddNamespace(nl ...*Namespace) { func AddNamespace(nl ...*Namespace) {
for _, n := range nl { for _, n := range nl {
for k, v := range n.handlers.routers { for k, v := range n.handlers.routers {
if t, ok := BeeApp.Handlers.routers[k]; ok { if _, ok := BeeApp.Handlers.routers[k]; ok {
addPrefix(v, n.prefix) addPrefix(v, n.prefix)
BeeApp.Handlers.routers[k].AddTree(n.prefix, v) BeeApp.Handlers.routers[k].AddTree(n.prefix, v)
} else { } else {
t = NewTree() t := NewTree()
t.AddTree(n.prefix, v) t.AddTree(n.prefix, v)
addPrefix(t, n.prefix) addPrefix(t, n.prefix)
BeeApp.Handlers.routers[k] = t BeeApp.Handlers.routers[k] = t
......
...@@ -621,6 +621,31 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. ...@@ -621,6 +621,31 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
return 0, err return 0, err
} }
var findAutoNowAdd, findAutoNow bool
var index int
for i, col := range setNames {
if mi.fields.GetByColumn(col).autoNowAdd {
index = i
findAutoNowAdd = true
}
if mi.fields.GetByColumn(col).autoNow {
findAutoNow = true
}
}
if findAutoNowAdd {
setNames = append(setNames[0:index], setNames[index+1:]...)
setValues = append(setValues[0:index], setValues[index+1:]...)
}
if !findAutoNow {
for col, info := range mi.fields.columns {
if info.autoNow {
setNames = append(setNames, col)
setValues = append(setValues, time.Now())
}
}
}
setValues = append(setValues, pkValue) setValues = append(setValues, pkValue)
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"reflect" "reflect"
...@@ -103,6 +104,96 @@ func (ac *_dbCache) getDefault() (al *alias) { ...@@ -103,6 +104,96 @@ func (ac *_dbCache) getDefault() (al *alias) {
return return
} }
type DB struct {
*sync.RWMutex
DB *sql.DB
stmts map[string]*sql.Stmt
}
func (d *DB) Begin() (*sql.Tx, error) {
return d.DB.Begin()
}
func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
return d.DB.BeginTx(ctx, opts)
}
func (d *DB) getStmt(query string) (*sql.Stmt, error) {
d.RLock()
if stmt, ok := d.stmts[query]; ok {
d.RUnlock()
return stmt, nil
}
d.RUnlock()
stmt, err := d.Prepare(query)
if err != nil {
return nil, err
}
d.Lock()
d.stmts[query] = stmt
d.Unlock()
return stmt, nil
}
func (d *DB) Prepare(query string) (*sql.Stmt, error) {
return d.DB.Prepare(query)
}
func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
return d.DB.PrepareContext(ctx, query)
}
func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
stmt, err := d.getStmt(query)
if err != nil {
return nil, err
}
return stmt.Exec(args...)
}
func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
stmt, err := d.getStmt(query)
if err != nil {
return nil, err
}
return stmt.ExecContext(ctx, args...)
}
func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
stmt, err := d.getStmt(query)
if err != nil {
return nil, err
}
return stmt.Query(args...)
}
func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
stmt, err := d.getStmt(query)
if err != nil {
return nil, err
}
return stmt.QueryContext(ctx, args...)
}
func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row {
stmt, err := d.getStmt(query)
if err != nil {
panic(err)
}
return stmt.QueryRow(args...)
}
func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := d.getStmt(query)
if err != nil {
panic(err)
}
return stmt.QueryRowContext(ctx, args)
}
type alias struct { type alias struct {
Name string Name string
Driver DriverType Driver DriverType
...@@ -110,7 +201,7 @@ type alias struct { ...@@ -110,7 +201,7 @@ type alias struct {
DataSource string DataSource string
MaxIdleConns int MaxIdleConns int
MaxOpenConns int MaxOpenConns int
DB *sql.DB DB *DB
DbBaser dbBaser DbBaser dbBaser
TZ *time.Location TZ *time.Location
Engine string Engine string
...@@ -176,7 +267,11 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { ...@@ -176,7 +267,11 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
al := new(alias) al := new(alias)
al.Name = aliasName al.Name = aliasName
al.DriverName = driverName al.DriverName = driverName
al.DB = db al.DB = &DB{
RWMutex: new(sync.RWMutex),
DB: db,
stmts: make(map[string]*sql.Stmt),
}
if dr, ok := drivers[driverName]; ok { if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr] al.DbBaser = dbBasers[dr]
...@@ -272,7 +367,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { ...@@ -272,7 +367,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
func SetMaxIdleConns(aliasName string, maxIdleConns int) { func SetMaxIdleConns(aliasName string, maxIdleConns int) {
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
al.MaxIdleConns = maxIdleConns al.MaxIdleConns = maxIdleConns
al.DB.SetMaxIdleConns(maxIdleConns) al.DB.DB.SetMaxIdleConns(maxIdleConns)
} }
// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
...@@ -296,7 +391,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { ...@@ -296,7 +391,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
} }
al, ok := dataBaseCache.get(name) al, ok := dataBaseCache.get(name)
if ok { if ok {
return al.DB, nil return al.DB.DB, nil
} }
return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
} }
...@@ -335,11 +335,11 @@ func RegisterModelWithSuffix(suffix string, models ...interface{}) { ...@@ -335,11 +335,11 @@ func RegisterModelWithSuffix(suffix string, models ...interface{}) {
// BootStrap bootstrap models. // BootStrap bootstrap models.
// make all model parsed and can not add more models // make all model parsed and can not add more models
func BootStrap() { func BootStrap() {
modelCache.Lock()
defer modelCache.Unlock()
if modelCache.done { if modelCache.done {
return return
} }
modelCache.Lock()
defer modelCache.Unlock()
bootStrap() bootStrap()
modelCache.done = true modelCache.done = true
} }
...@@ -301,7 +301,7 @@ checkType: ...@@ -301,7 +301,7 @@ checkType:
fi.sf = sf fi.sf = sf
fi.fullName = mi.fullName + mName + "." + sf.Name fi.fullName = mi.fullName + mName + "." + sf.Name
fi.description = sf.Tag.Get("description") fi.description = tags["description"]
fi.null = attrs["null"] fi.null = attrs["null"]
fi.index = attrs["index"] fi.index = attrs["index"]
fi.auto = attrs["auto"] fi.auto = attrs["auto"]
......
...@@ -44,6 +44,7 @@ var supportTag = map[string]int{ ...@@ -44,6 +44,7 @@ var supportTag = map[string]int{
"decimals": 2, "decimals": 2,
"on_delete": 2, "on_delete": 2,
"type": 2, "type": 2,
"description": 2,
} }
// get reflect.Type name with package path. // get reflect.Type name with package path.
...@@ -65,7 +66,7 @@ func getTableName(val reflect.Value) string { ...@@ -65,7 +66,7 @@ func getTableName(val reflect.Value) string {
return snakeString(reflect.Indirect(val).Type().Name()) return snakeString(reflect.Indirect(val).Type().Name())
} }
// get table engine, mysiam or innodb. // get table engine, myisam or innodb.
func getTableEngine(val reflect.Value) string { func getTableEngine(val reflect.Value) string {
fun := val.MethodByName("TableEngine") fun := val.MethodByName("TableEngine")
if fun.IsValid() { if fun.IsValid() {
......
...@@ -60,6 +60,7 @@ import ( ...@@ -60,6 +60,7 @@ import (
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
"sync"
"time" "time"
) )
...@@ -72,7 +73,7 @@ const ( ...@@ -72,7 +73,7 @@ const (
var ( var (
Debug = false Debug = false
DebugLog = NewLog(os.Stdout) DebugLog = NewLog(os.Stdout)
DefaultRowsLimit = 1000 DefaultRowsLimit = -1
DefaultRelsDepth = 2 DefaultRelsDepth = 2
DefaultTimeLoc = time.Local DefaultTimeLoc = time.Local
ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin") ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin")
...@@ -522,6 +523,15 @@ func (o *orm) Driver() Driver { ...@@ -522,6 +523,15 @@ func (o *orm) Driver() Driver {
return driver(o.alias.Name) return driver(o.alias.Name)
} }
// return sql.DBStats for current database
func (o *orm) DBStats() *sql.DBStats {
if o.alias != nil && o.alias.DB != nil {
stats := o.alias.DB.DB.Stats()
return &stats
}
return nil
}
// NewOrm create new orm // NewOrm create new orm
func NewOrm() Ormer { func NewOrm() Ormer {
BootStrap() // execute only once BootStrap() // execute only once
...@@ -548,7 +558,11 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { ...@@ -548,7 +558,11 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
al.Name = aliasName al.Name = aliasName
al.DriverName = driverName al.DriverName = driverName
al.DB = db al.DB = &DB{
RWMutex: new(sync.RWMutex),
DB: db,
stmts: make(map[string]*sql.Stmt),
}
detectTZ(al) detectTZ(al)
......
...@@ -29,6 +29,9 @@ type Log struct { ...@@ -29,6 +29,9 @@ type Log struct {
*log.Logger *log.Logger
} }
//costomer log func
var LogFunc func(query map[string]interface{})
// NewLog set io.Writer to create a Logger. // NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log { func NewLog(out io.Writer) *Log {
d := new(Log) d := new(Log)
...@@ -37,12 +40,15 @@ func NewLog(out io.Writer) *Log { ...@@ -37,12 +40,15 @@ func NewLog(out io.Writer) *Log {
} }
func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) {
var logMap = make(map[string]interface{})
sub := time.Now().Sub(t) / 1e5 sub := time.Now().Sub(t) / 1e5
elsp := float64(int(sub)) / 10.0 elsp := float64(int(sub)) / 10.0
logMap["cost_time"] = elsp
flag := " OK" flag := " OK"
if err != nil { if err != nil {
flag = "FAIL" flag = "FAIL"
} }
logMap["flag"] = flag
con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query) con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query)
cons := make([]string, 0, len(args)) cons := make([]string, 0, len(args))
for _, arg := range args { for _, arg := range args {
...@@ -54,6 +60,10 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error ...@@ -54,6 +60,10 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
if err != nil { if err != nil {
con += " - " + err.Error() con += " - " + err.Error()
} }
logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `"))
if LogFunc != nil{
LogFunc(logMap)
}
DebugLog.Println(con) DebugLog.Println(con)
} }
......
...@@ -150,8 +150,10 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { ...@@ -150,8 +150,10 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
case reflect.Struct: case reflect.Struct:
if value == nil { if value == nil {
ind.Set(reflect.Zero(ind.Type())) ind.Set(reflect.Zero(ind.Type()))
return
} else if _, ok := ind.Interface().(time.Time); ok { }
switch ind.Interface().(type) {
case time.Time:
var str string var str string
switch d := value.(type) { switch d := value.(type) {
case time.Time: case time.Time:
...@@ -178,7 +180,25 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { ...@@ -178,7 +180,25 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
} }
} }
} }
case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool:
indi := reflect.New(ind.Type()).Interface()
sc, ok := indi.(sql.Scanner)
if !ok {
return
}
err := sc.Scan(value)
if err == nil {
ind.Set(reflect.Indirect(reflect.ValueOf(sc)))
}
}
case reflect.Ptr:
if value == nil {
ind.Set(reflect.Zero(ind.Type()))
break
} }
ind.Set(reflect.New(ind.Type().Elem()))
o.setFieldValue(reflect.Indirect(ind), value)
} }
} }
......
...@@ -458,6 +458,15 @@ func TestNullDataTypes(t *testing.T) { ...@@ -458,6 +458,15 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime))) throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime)))
throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate))) throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate)))
throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime))) throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime)))
// test support for pointer fields using RawSeter.QueryRows()
var dnList []*DataNull
Q := dDbBaser.TableQuote()
num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
equal := reflect.DeepEqual(*dnList[0], d)
throwFailNow(t, AssertIs(equal, true))
} }
func TestDataCustomTypes(t *testing.T) { func TestDataCustomTypes(t *testing.T) {
...@@ -1679,6 +1688,31 @@ func TestRawQueryRow(t *testing.T) { ...@@ -1679,6 +1688,31 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(uid, 4)) throwFail(t, AssertIs(uid, 4))
throwFail(t, AssertIs(*status, 3)) throwFail(t, AssertIs(*status, 3))
throwFail(t, AssertIs(pid, nil)) throwFail(t, AssertIs(pid, nil))
// test for sql.Null* fields
nData := &DataNull{
NullString: sql.NullString{String: "test sql.null", Valid: true},
NullBool: sql.NullBool{Bool: true, Valid: true},
NullInt64: sql.NullInt64{Int64: 42, Valid: true},
NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true},
}
newId, err := dORM.Insert(nData)
throwFailNow(t, err)
var nd *DataNull
query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q)
err = dORM.Raw(query, newId).QueryRow(&nd)
throwFailNow(t, err)
throwFailNow(t, AssertNot(nd, nil))
throwFail(t, AssertIs(nd.NullBool.Valid, true))
throwFail(t, AssertIs(nd.NullBool.Bool, true))
throwFail(t, AssertIs(nd.NullString.Valid, true))
throwFail(t, AssertIs(nd.NullString.String, "test sql.null"))
throwFail(t, AssertIs(nd.NullInt64.Valid, true))
throwFail(t, AssertIs(nd.NullInt64.Int64, 42))
throwFail(t, AssertIs(nd.NullFloat64.Valid, true))
throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42))
} }
// user_profile table // user_profile table
...@@ -1771,6 +1805,32 @@ func TestQueryRows(t *testing.T) { ...@@ -1771,6 +1805,32 @@ func TestQueryRows(t *testing.T) {
throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) throwFailNow(t, AssertIs(l[1].UserName, "astaxie"))
throwFailNow(t, AssertIs(l[1].Age, 30)) throwFailNow(t, AssertIs(l[1].Age, 30))
// test for sql.Null* fields
nData := &DataNull{
NullString: sql.NullString{String: "test sql.null", Valid: true},
NullBool: sql.NullBool{Bool: true, Valid: true},
NullInt64: sql.NullInt64{Int64: 42, Valid: true},
NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true},
}
newId, err := dORM.Insert(nData)
throwFailNow(t, err)
var nDataList []*DataNull
query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q)
num, err = dORM.Raw(query, newId).QueryRows(&nDataList)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
nd := nDataList[0]
throwFailNow(t, AssertNot(nd, nil))
throwFail(t, AssertIs(nd.NullBool.Valid, true))
throwFail(t, AssertIs(nd.NullBool.Bool, true))
throwFail(t, AssertIs(nd.NullString.Valid, true))
throwFail(t, AssertIs(nd.NullString.String, "test sql.null"))
throwFail(t, AssertIs(nd.NullInt64.Valid, true))
throwFail(t, AssertIs(nd.NullInt64.Int64, 42))
throwFail(t, AssertIs(nd.NullFloat64.Valid, true))
throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42))
} }
func TestRawValues(t *testing.T) { func TestRawValues(t *testing.T) {
......
...@@ -55,7 +55,7 @@ type Ormer interface { ...@@ -55,7 +55,7 @@ type Ormer interface {
// for example: // for example:
// user := new(User) // user := new(User)
// id, err = Ormer.Insert(user) // id, err = Ormer.Insert(user)
// user must a pointer and Insert will set user's pk field // user must be a pointer and Insert will set user's pk field
Insert(interface{}) (int64, error) Insert(interface{}) (int64, error)
// mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value")
// if colu type is integer : can use(+-*/), string : convert(colu,"value") // if colu type is integer : can use(+-*/), string : convert(colu,"value")
...@@ -128,6 +128,7 @@ type Ormer interface { ...@@ -128,6 +128,7 @@ type Ormer interface {
// // update user testing's name to slene // // update user testing's name to slene
Raw(query string, args ...interface{}) RawSeter Raw(query string, args ...interface{}) RawSeter
Driver() Driver Driver() Driver
DBStats() *sql.DBStats
} }
// Inserter insert prepared statement // Inserter insert prepared statement
......
...@@ -35,7 +35,7 @@ import ( ...@@ -35,7 +35,7 @@ import (
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
var globalRouterTemplate = `package routers var globalRouterTemplate = `package {{.routersDir}}
import ( import (
"github.com/astaxie/beego" "github.com/astaxie/beego"
...@@ -459,13 +459,17 @@ func genRouterCode(pkgRealpath string) { ...@@ -459,13 +459,17 @@ func genRouterCode(pkgRealpath string) {
imports := "" imports := ""
if len(c.ImportComments) > 0 { if len(c.ImportComments) > 0 {
for _, i := range c.ImportComments { for _, i := range c.ImportComments {
var s string
if i.ImportAlias != "" { if i.ImportAlias != "" {
imports += fmt.Sprintf(` s = fmt.Sprintf(`
%s "%s"`, i.ImportAlias, i.ImportPath) %s "%s"`, i.ImportAlias, i.ImportPath)
} else { } else {
imports += fmt.Sprintf(` s = fmt.Sprintf(`
"%s"`, i.ImportPath) "%s"`, i.ImportPath)
} }
if !strings.Contains(globalimport, s) {
imports += s
}
} }
} }
...@@ -490,7 +494,7 @@ func genRouterCode(pkgRealpath string) { ...@@ -490,7 +494,7 @@ func genRouterCode(pkgRealpath string) {
}`, filters) }`, filters)
} }
globalimport = imports globalimport += imports
globalinfo = globalinfo + ` globalinfo = globalinfo + `
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
...@@ -512,7 +516,9 @@ func genRouterCode(pkgRealpath string) { ...@@ -512,7 +516,9 @@ func genRouterCode(pkgRealpath string) {
} }
defer f.Close() defer f.Close()
routersDir := AppConfig.DefaultString("routersdir", "routers")
content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1)
content = strings.Replace(content, "{{.routersDir}}", routersDir, -1)
content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) content = strings.Replace(content, "{{.globalimport}}", globalimport, -1)
f.WriteString(content) f.WriteString(content)
} }
...@@ -570,7 +576,8 @@ func getpathTime(pkgRealpath string) (lastupdate int64, err error) { ...@@ -570,7 +576,8 @@ func getpathTime(pkgRealpath string) (lastupdate int64, err error) {
func getRouterDir(pkgRealpath string) string { func getRouterDir(pkgRealpath string) string {
dir := filepath.Dir(pkgRealpath) dir := filepath.Dir(pkgRealpath)
for { for {
d := filepath.Join(dir, "routers") routersDir := AppConfig.DefaultString("routersdir", "routers")
d := filepath.Join(dir, routersDir)
if utils.FileExists(d) { if utils.FileExists(d) {
return d return d
} }
......
...@@ -72,8 +72,8 @@ import ( ...@@ -72,8 +72,8 @@ import (
// AppIDToAppSecret is used to get appsecret throw appid // AppIDToAppSecret is used to get appsecret throw appid
type AppIDToAppSecret func(string) string type AppIDToAppSecret func(string) string
// APIBaiscAuth use the basic appid/appkey as the AppIdToAppSecret // APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret
func APIBaiscAuth(appid, appkey string) beego.FilterFunc { func APIBasicAuth(appid, appkey string) beego.FilterFunc {
ft := func(aid string) string { ft := func(aid string) string {
if aid == appid { if aid == appid {
return appkey return appkey
...@@ -83,6 +83,11 @@ func APIBaiscAuth(appid, appkey string) beego.FilterFunc { ...@@ -83,6 +83,11 @@ func APIBaiscAuth(appid, appkey string) beego.FilterFunc {
return APISecretAuth(ft, 300) return APISecretAuth(ft, 300)
} }
// APIBaiscAuth calls APIBasicAuth for previous callers
func APIBaiscAuth(appid, appkey string) beego.FilterFunc {
return APIBasicAuth(appid, appkey)
}
// APISecretAuth use AppIdToAppSecret verify and // APISecretAuth use AppIdToAppSecret verify and
func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc {
return func(ctx *context.Context) { return func(ctx *context.Context) {
......
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
package beego package beego
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
"path" "path"
"path/filepath" "path/filepath"
"reflect" "reflect"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
...@@ -479,8 +479,7 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter ...@@ -479,8 +479,7 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter
// add Filter into // add Filter into
func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) {
if pos < BeforeStatic || pos > FinishRouter { if pos < BeforeStatic || pos > FinishRouter {
err = fmt.Errorf("can not find your filter position") return errors.New("can not find your filter position")
return
} }
p.enableFilter = true p.enableFilter = true
p.filters[pos] = append(p.filters[pos], mr) p.filters[pos] = append(p.filters[pos], mr)
...@@ -510,10 +509,10 @@ func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) stri ...@@ -510,10 +509,10 @@ func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) stri
} }
} }
} }
controllName := strings.Join(paths[:len(paths)-1], "/") controllerName := strings.Join(paths[:len(paths)-1], "/")
methodName := paths[len(paths)-1] methodName := paths[len(paths)-1]
for m, t := range p.routers { for m, t := range p.routers {
ok, url := p.geturl(t, "/", controllName, methodName, params, m) ok, url := p.getURL(t, "/", controllerName, methodName, params, m)
if ok { if ok {
return url return url
} }
...@@ -521,17 +520,17 @@ func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) stri ...@@ -521,17 +520,17 @@ func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) stri
return "" return ""
} }
func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName string, params map[string]string, httpMethod string) (bool, string) { func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName string, params map[string]string, httpMethod string) (bool, string) {
for _, subtree := range t.fixrouters { for _, subtree := range t.fixrouters {
u := path.Join(url, subtree.prefix) u := path.Join(url, subtree.prefix)
ok, u := p.geturl(subtree, u, controllName, methodName, params, httpMethod) ok, u := p.getURL(subtree, u, controllerName, methodName, params, httpMethod)
if ok { if ok {
return ok, u return ok, u
} }
} }
if t.wildcard != nil { if t.wildcard != nil {
u := path.Join(url, urlPlaceholder) u := path.Join(url, urlPlaceholder)
ok, u := p.geturl(t.wildcard, u, controllName, methodName, params, httpMethod) ok, u := p.getURL(t.wildcard, u, controllerName, methodName, params, httpMethod)
if ok { if ok {
return ok, u return ok, u
} }
...@@ -539,7 +538,7 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin ...@@ -539,7 +538,7 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
for _, l := range t.leaves { for _, l := range t.leaves {
if c, ok := l.runObject.(*ControllerInfo); ok { if c, ok := l.runObject.(*ControllerInfo); ok {
if c.routerType == routerTypeBeego && if c.routerType == routerTypeBeego &&
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllName) { strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) {
find := false find := false
if HTTPMETHOD[strings.ToUpper(methodName)] { if HTTPMETHOD[strings.ToUpper(methodName)] {
if len(c.methods) == 0 { if len(c.methods) == 0 {
...@@ -578,18 +577,18 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin ...@@ -578,18 +577,18 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
} }
} }
} }
canskip := false canSkip := false
for _, v := range l.wildcards { for _, v := range l.wildcards {
if v == ":" { if v == ":" {
canskip = true canSkip = true
continue continue
} }
if u, ok := params[v]; ok { if u, ok := params[v]; ok {
delete(params, v) delete(params, v)
url = strings.Replace(url, urlPlaceholder, u, 1) url = strings.Replace(url, urlPlaceholder, u, 1)
} else { } else {
if canskip { if canSkip {
canskip = false canSkip = false
continue continue
} }
return false, "" return false, ""
...@@ -598,27 +597,27 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin ...@@ -598,27 +597,27 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
return true, url + toURL(params) return true, url + toURL(params)
} }
var i int var i int
var startreg bool var startReg bool
regurl := "" regURL := ""
for _, v := range strings.Trim(l.regexps.String(), "^$") { for _, v := range strings.Trim(l.regexps.String(), "^$") {
if v == '(' { if v == '(' {
startreg = true startReg = true
continue continue
} else if v == ')' { } else if v == ')' {
startreg = false startReg = false
if v, ok := params[l.wildcards[i]]; ok { if v, ok := params[l.wildcards[i]]; ok {
delete(params, l.wildcards[i]) delete(params, l.wildcards[i])
regurl = regurl + v regURL = regURL + v
i++ i++
} else { } else {
break break
} }
} else if !startreg { } else if !startReg {
regurl = string(append([]rune(regurl), v)) regURL = string(append([]rune(regURL), v))
} }
} }
if l.regexps.MatchString(regurl) { if l.regexps.MatchString(regURL) {
ps := strings.Split(regurl, "/") ps := strings.Split(regURL, "/")
for _, p := range ps { for _, p := range ps {
url = strings.Replace(url, urlPlaceholder, p, 1) url = strings.Replace(url, urlPlaceholder, p, 1)
} }
...@@ -690,7 +689,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -690,7 +689,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
// filter wrong http method // filter wrong http method
if !HTTPMETHOD[r.Method] { if !HTTPMETHOD[r.Method] {
http.Error(rw, "Method Not Allowed", 405) exception("405", context)
goto Admin goto Admin
} }
...@@ -779,7 +778,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -779,7 +778,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
runRouter = routerInfo.controllerType runRouter = routerInfo.controllerType
methodParams = routerInfo.methodParams methodParams = routerInfo.methodParams
method := r.Method method := r.Method
if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPost { if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut {
method = http.MethodPut method = http.MethodPut
} }
if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete { if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete {
...@@ -844,6 +843,8 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -844,6 +843,8 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
execController.Patch() execController.Patch()
case http.MethodOptions: case http.MethodOptions:
execController.Options() execController.Options()
case http.MethodTrace:
execController.Trace()
default: default:
if !execController.HandlerFunc(runMethod) { if !execController.HandlerFunc(runMethod) {
vc := reflect.ValueOf(execController) vc := reflect.ValueOf(execController)
...@@ -889,7 +890,7 @@ Admin: ...@@ -889,7 +890,7 @@ Admin:
statusCode = 200 statusCode = 200
} }
logAccess(context, &startTime, statusCode) LogAccess(context, &startTime, statusCode)
timeDur := time.Since(startTime) timeDur := time.Since(startTime)
context.ResponseWriter.Elapsed = timeDur context.ResponseWriter.Elapsed = timeDur
...@@ -900,38 +901,28 @@ Admin: ...@@ -900,38 +901,28 @@ Admin:
} }
if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) { if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) {
routerName := ""
if runRouter != nil { if runRouter != nil {
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur) routerName = runRouter.Name()
} else {
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, "", timeDur)
} }
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur)
} }
} }
if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs {
var devInfo string match := map[bool]string{true: "match", false: "nomatch"}
iswin := (runtime.GOOS == "windows") devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s",
statusColor := logs.ColorByStatus(iswin, statusCode) context.Input.IP(),
methodColor := logs.ColorByMethod(iswin, r.Method) logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(),
resetColor := logs.ColorByMethod(iswin, "") timeDur.String(),
if findRouter { match[findRouter],
if routerInfo != nil { logs.ColorByMethod(r.Method), r.Method, logs.ResetColor(),
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s r:%s", context.Input.IP(), statusColor, statusCode, r.URL.Path)
resetColor, timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path, if routerInfo != nil {
routerInfo.pattern) devInfo += fmt.Sprintf(" r:%s", routerInfo.pattern)
} else {
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor,
timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path)
}
} else {
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor,
timeDur.String(), "nomatch", methodColor, r.Method, resetColor, r.URL.Path)
}
if iswin {
logs.W32Debug(devInfo)
} else {
logs.Debug(devInfo)
} }
logs.Debug(devInfo)
} }
// Call WriteHeader if status code has been set changed // Call WriteHeader if status code has been set changed
if context.Output.Status != 0 { if context.Output.Status != 0 {
...@@ -980,7 +971,8 @@ func toURL(params map[string]string) string { ...@@ -980,7 +971,8 @@ func toURL(params map[string]string) string {
return strings.TrimRight(u, "&") return strings.TrimRight(u, "&")
} }
func logAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { // LogAccess logging info HTTP Access
func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) {
//Skip logging if AccessLogs config is false //Skip logging if AccessLogs config is false
if !BConfig.Log.AccessLogs { if !BConfig.Log.AccessLogs {
return return
......
...@@ -71,10 +71,6 @@ func (tc *TestController) GetEmptyBody() { ...@@ -71,10 +71,6 @@ func (tc *TestController) GetEmptyBody() {
tc.Ctx.Output.Body(res) tc.Ctx.Output.Body(res)
} }
type ResStatus struct {
Code int
Msg string
}
type JSONController struct { type JSONController struct {
Controller Controller
...@@ -475,7 +471,7 @@ func TestParamResetFilter(t *testing.T) { ...@@ -475,7 +471,7 @@ func TestParamResetFilter(t *testing.T) {
// a response header of `Splat`. The expectation here is that that Header // a response header of `Splat`. The expectation here is that that Header
// value should match what the _request's_ router set, not the filter's. // value should match what the _request's_ router set, not the filter's.
headers := rw.HeaderMap headers := rw.Result().Header
if len(headers["Splat"]) != 1 { if len(headers["Splat"]) != 1 {
t.Errorf( t.Errorf(
"%s: There was an error in the test. Splat param not set in Header", "%s: There was an error in the test. Splat param not set in Header",
...@@ -660,25 +656,16 @@ func beegoBeforeRouter1(ctx *context.Context) { ...@@ -660,25 +656,16 @@ func beegoBeforeRouter1(ctx *context.Context) {
ctx.WriteString("|BeforeRouter1") ctx.WriteString("|BeforeRouter1")
} }
func beegoBeforeRouter2(ctx *context.Context) {
ctx.WriteString("|BeforeRouter2")
}
func beegoBeforeExec1(ctx *context.Context) { func beegoBeforeExec1(ctx *context.Context) {
ctx.WriteString("|BeforeExec1") ctx.WriteString("|BeforeExec1")
} }
func beegoBeforeExec2(ctx *context.Context) {
ctx.WriteString("|BeforeExec2")
}
func beegoAfterExec1(ctx *context.Context) { func beegoAfterExec1(ctx *context.Context) {
ctx.WriteString("|AfterExec1") ctx.WriteString("|AfterExec1")
} }
func beegoAfterExec2(ctx *context.Context) {
ctx.WriteString("|AfterExec2")
}
func beegoFinishRouter1(ctx *context.Context) { func beegoFinishRouter1(ctx *context.Context) {
ctx.WriteString("|FinishRouter1") ctx.WriteString("|FinishRouter1")
......
...@@ -133,7 +133,7 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) { ...@@ -133,7 +133,7 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) {
// SessionExist check ledis session exist by sid // SessionExist check ledis session exist by sid
func (lp *Provider) SessionExist(sid string) bool { func (lp *Provider) SessionExist(sid string) bool {
count, _ := c.Exists([]byte(sid)) count, _ := c.Exists([]byte(sid))
return !(count == 0) return count != 0
} }
// SessionRegenerate generate new sid for ledis session // SessionRegenerate generate new sid for ledis session
......
...@@ -128,9 +128,12 @@ func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { ...@@ -128,9 +128,12 @@ func (rp *MemProvider) SessionRead(sid string) (session.Store, error) {
} }
} }
item, err := client.Get(sid) item, err := client.Get(sid)
if err != nil && err == memcache.ErrCacheMiss { if err != nil {
rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime} if err == memcache.ErrCacheMiss {
return rs, nil rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime}
return rs, nil
}
return nil, err
} }
var kv map[interface{}]interface{} var kv map[interface{}]interface{}
if len(item.Value) == 0 { if len(item.Value) == 0 {
......
...@@ -170,7 +170,7 @@ func (mp *Provider) SessionExist(sid string) bool { ...@@ -170,7 +170,7 @@ func (mp *Provider) SessionExist(sid string) bool {
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte var sessiondata []byte
err := row.Scan(&sessiondata) err := row.Scan(&sessiondata)
return !(err == sql.ErrNoRows) return err != sql.ErrNoRows
} }
// SessionRegenerate generate new sid for mysql session // SessionRegenerate generate new sid for mysql session
......
...@@ -184,7 +184,7 @@ func (mp *Provider) SessionExist(sid string) bool { ...@@ -184,7 +184,7 @@ func (mp *Provider) SessionExist(sid string) bool {
row := c.QueryRow("select session_data from session where session_key=$1", sid) row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte var sessiondata []byte
err := row.Scan(&sessiondata) err := row.Scan(&sessiondata)
return !(err == sql.ErrNoRows) return err != sql.ErrNoRows
} }
// SessionRegenerate generate new sid for postgresql session // SessionRegenerate generate new sid for postgresql session
......
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package redis for session provider
//
// depend on github.com/go-redis/redis
//
// go install github.com/go-redis/redis
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/redis_sentinel"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``)
// go globalSessions.GC()
// }
//
// more detail about params: please check the notes on the function SessionInit in this package
package redis_sentinel
import (
"github.com/astaxie/beego/session"
"github.com/go-redis/redis"
"net/http"
"strconv"
"strings"
"sync"
"time"
)
var redispder = &Provider{}
// DefaultPoolSize redis_sentinel default pool size
var DefaultPoolSize = 100
// SessionStore redis_sentinel session store
type SessionStore struct {
p *redis.Client
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in redis_sentinel session
func (rs *SessionStore) Set(key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in redis_sentinel session
func (rs *SessionStore) Get(key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in redis_sentinel session
func (rs *SessionStore) Delete(key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in redis_sentinel session
func (rs *SessionStore) Flush() error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get redis_sentinel session id
func (rs *SessionStore) SessionID() string {
return rs.sid
}
// SessionRelease save session values to redis_sentinel
func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
c := rs.p
c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
// Provider redis_sentinel session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
poollist *redis.Client
masterName string
}
// SessionInit init redis_sentinel session
// savepath like redis sentinel addr,pool size,password,dbnum,masterName
// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster
func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = DefaultPoolSize
} else {
rp.poolsize = poolsize
}
} else {
rp.poolsize = DefaultPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
} else {
rp.dbNum = dbnum
}
} else {
rp.dbNum = 0
}
if len(configs) > 4 {
if configs[4] != "" {
rp.masterName = configs[4]
} else {
rp.masterName = "mymaster"
}
} else {
rp.masterName = "mymaster"
}
rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{
SentinelAddrs: strings.Split(rp.savePath, ";"),
Password: rp.password,
PoolSize: rp.poolsize,
DB: rp.dbNum,
MasterName: rp.masterName,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis_sentinel session by sid
func (rp *Provider) SessionRead(sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != redis.Nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
return nil, err
}
}
rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check redis_sentinel session exist by sid
func (rp *Provider) SessionExist(sid string) bool {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false
}
return true
}
// SessionRegenerate generate new sid for redis_sentinel session
func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second)
} else {
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC() {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll() int {
return 0
}
func init() {
session.Register("redis_sentinel", redispder)
}
package redis_sentinel
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/astaxie/beego/session"
)
func TestRedisSentinel(t *testing.T) {
sessionConfig := &session.ManagerConfig{
CookieName: "gosessionid",
EnableSetCookie: true,
Gclifetime: 3600,
Maxlifetime: 3600,
Secure: false,
CookieLifeTime: 3600,
ProviderConfig: "127.0.0.1:6379,100,,0,master",
}
globalSessions, e := session.NewManager("redis_sentinel", sessionConfig)
if e != nil {
t.Log(e)
return
}
//todo test if e==nil
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("session start failed:", err)
}
defer sess.SessionRelease(w)
// SET AND GET
err = sess.Set("username", "astaxie")
if err != nil {
t.Fatal("set username failed:", err)
}
username := sess.Get("username")
if username != "astaxie" {
t.Fatal("get username failed")
}
// DELETE
err = sess.Delete("username")
if err != nil {
t.Fatal("delete username failed:", err)
}
username = sess.Get("username")
if username != nil {
t.Fatal("delete username failed")
}
// FLUSH
err = sess.Set("username", "astaxie")
if err != nil {
t.Fatal("set failed:", err)
}
err = sess.Set("password", "1qaz2wsx")
if err != nil {
t.Fatal("set failed:", err)
}
username = sess.Get("username")
if username != "astaxie" {
t.Fatal("get username failed")
}
password := sess.Get("password")
if password != "1qaz2wsx" {
t.Fatal("get password failed")
}
err = sess.Flush()
if err != nil {
t.Fatal("flush failed:", err)
}
username = sess.Get("username")
if username != nil {
t.Fatal("flush failed")
}
password = sess.Get("password")
if password != nil {
t.Fatal("flush failed")
}
sess.SessionRelease(w)
}
...@@ -19,6 +19,7 @@ import ( ...@@ -19,6 +19,7 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
"errors"
"path" "path"
"path/filepath" "path/filepath"
"strings" "strings"
...@@ -131,6 +132,9 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { ...@@ -131,6 +132,9 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
if strings.ContainsAny(sid, "./") { if strings.ContainsAny(sid, "./") {
return nil, nil return nil, nil
} }
if len(sid) < 2 {
return nil, errors.New("length of the sid is less than 2")
}
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()
......
...@@ -81,6 +81,15 @@ func Register(name string, provide Provider) { ...@@ -81,6 +81,15 @@ func Register(name string, provide Provider) {
provides[name] = provide provides[name] = provide
} }
//GetProvider
func GetProvider(name string) (Provider, error) {
provider, ok := provides[name]
if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", name)
}
return provider, nil
}
// ManagerConfig define the session config // ManagerConfig define the session config
type ManagerConfig struct { type ManagerConfig struct {
CookieName string `json:"cookieName"` CookieName string `json:"cookieName"`
......
...@@ -38,7 +38,7 @@ var ( ...@@ -38,7 +38,7 @@ var (
beeViewPathTemplates = make(map[string]map[string]*template.Template) beeViewPathTemplates = make(map[string]map[string]*template.Template)
templatesLock sync.RWMutex templatesLock sync.RWMutex
// beeTemplateExt stores the template extension which will build // beeTemplateExt stores the template extension which will build
beeTemplateExt = []string{"tpl", "html"} beeTemplateExt = []string{"tpl", "html", "gohtml"}
// beeTemplatePreprocessors stores associations of extension -> preprocessor handler // beeTemplatePreprocessors stores associations of extension -> preprocessor handler
beeTemplateEngines = map[string]templatePreProcessor{} beeTemplateEngines = map[string]templatePreProcessor{}
beeTemplateFS = defaultFSFunc beeTemplateFS = defaultFSFunc
...@@ -240,7 +240,7 @@ func getTplDeep(root string, fs http.FileSystem, file string, parent string, t * ...@@ -240,7 +240,7 @@ func getTplDeep(root string, fs http.FileSystem, file string, parent string, t *
var fileAbsPath string var fileAbsPath string
var rParent string var rParent string
var err error var err error
if filepath.HasPrefix(file, "../") { if strings.HasPrefix(file, "../") {
rParent = filepath.Join(filepath.Dir(parent), file) rParent = filepath.Join(filepath.Dir(parent), file)
fileAbsPath = filepath.Join(root, filepath.Dir(parent), file) fileAbsPath = filepath.Join(root, filepath.Dir(parent), file)
} else { } else {
...@@ -248,10 +248,10 @@ func getTplDeep(root string, fs http.FileSystem, file string, parent string, t * ...@@ -248,10 +248,10 @@ func getTplDeep(root string, fs http.FileSystem, file string, parent string, t *
fileAbsPath = filepath.Join(root, file) fileAbsPath = filepath.Join(root, file)
} }
f, err := fs.Open(fileAbsPath) f, err := fs.Open(fileAbsPath)
defer f.Close()
if err != nil { if err != nil {
panic("can't find template file:" + file) panic("can't find template file:" + file)
} }
defer f.Close()
data, err := ioutil.ReadAll(f) data, err := ioutil.ReadAll(f)
if err != nil { if err != nil {
return nil, [][]string{}, err return nil, [][]string{}, err
......
...@@ -55,21 +55,21 @@ func Substr(s string, start, length int) string { ...@@ -55,21 +55,21 @@ func Substr(s string, start, length int) string {
// HTML2str returns escaping text convert from html. // HTML2str returns escaping text convert from html.
func HTML2str(html string) string { func HTML2str(html string) string {
re, _ := regexp.Compile(`\<[\S\s]+?\>`) re := regexp.MustCompile(`\<[\S\s]+?\>`)
html = re.ReplaceAllStringFunc(html, strings.ToLower) html = re.ReplaceAllStringFunc(html, strings.ToLower)
//remove STYLE //remove STYLE
re, _ = regexp.Compile(`\<style[\S\s]+?\</style\>`) re = regexp.MustCompile(`\<style[\S\s]+?\</style\>`)
html = re.ReplaceAllString(html, "") html = re.ReplaceAllString(html, "")
//remove SCRIPT //remove SCRIPT
re, _ = regexp.Compile(`\<script[\S\s]+?\</script\>`) re = regexp.MustCompile(`\<script[\S\s]+?\</script\>`)
html = re.ReplaceAllString(html, "") html = re.ReplaceAllString(html, "")
re, _ = regexp.Compile(`\<[\S\s]+?\>`) re = regexp.MustCompile(`\<[\S\s]+?\>`)
html = re.ReplaceAllString(html, "\n") html = re.ReplaceAllString(html, "\n")
re, _ = regexp.Compile(`\s{2,}`) re = regexp.MustCompile(`\s{2,}`)
html = re.ReplaceAllString(html, "\n") html = re.ReplaceAllString(html, "\n")
return strings.TrimSpace(html) return strings.TrimSpace(html)
...@@ -85,24 +85,24 @@ func DateFormat(t time.Time, layout string) (datestring string) { ...@@ -85,24 +85,24 @@ func DateFormat(t time.Time, layout string) (datestring string) {
var datePatterns = []string{ var datePatterns = []string{
// year // year
"Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003 "Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003
"y", "06", //A two digit representation of a year Examples: 99 or 03 "y", "06", //A two digit representation of a year Examples: 99 or 03
// month // month
"m", "01", // Numeric representation of a month, with leading zeros 01 through 12 "m", "01", // Numeric representation of a month, with leading zeros 01 through 12
"n", "1", // Numeric representation of a month, without leading zeros 1 through 12 "n", "1", // Numeric representation of a month, without leading zeros 1 through 12
"M", "Jan", // A short textual representation of a month, three letters Jan through Dec "M", "Jan", // A short textual representation of a month, three letters Jan through Dec
"F", "January", // A full textual representation of a month, such as January or March January through December "F", "January", // A full textual representation of a month, such as January or March January through December
// day // day
"d", "02", // Day of the month, 2 digits with leading zeros 01 to 31 "d", "02", // Day of the month, 2 digits with leading zeros 01 to 31
"j", "2", // Day of the month without leading zeros 1 to 31 "j", "2", // Day of the month without leading zeros 1 to 31
// week // week
"D", "Mon", // A textual representation of a day, three letters Mon through Sun "D", "Mon", // A textual representation of a day, three letters Mon through Sun
"l", "Monday", // A full textual representation of the day of the week Sunday through Saturday "l", "Monday", // A full textual representation of the day of the week Sunday through Saturday
// time // time
"g", "3", // 12-hour format of an hour without leading zeros 1 through 12 "g", "3", // 12-hour format of an hour without leading zeros 1 through 12
"G", "15", // 24-hour format of an hour without leading zeros 0 through 23 "G", "15", // 24-hour format of an hour without leading zeros 0 through 23
"h", "03", // 12-hour format of an hour with leading zeros 01 through 12 "h", "03", // 12-hour format of an hour with leading zeros 01 through 12
"H", "15", // 24-hour format of an hour with leading zeros 00 through 23 "H", "15", // 24-hour format of an hour with leading zeros 00 through 23
...@@ -172,7 +172,7 @@ func GetConfig(returnType, key string, defaultVal interface{}) (value interface{ ...@@ -172,7 +172,7 @@ func GetConfig(returnType, key string, defaultVal interface{}) (value interface{
case "DIY": case "DIY":
value, err = AppConfig.DIY(key) value, err = AppConfig.DIY(key)
default: default:
err = errors.New("Config keys must be of type String, Bool, Int, Int64, Float, or DIY") err = errors.New("config keys must be of type String, Bool, Int, Int64, Float, or DIY")
} }
if err != nil { if err != nil {
...@@ -297,9 +297,21 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e ...@@ -297,9 +297,21 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e
tag = tags[0] tag = tags[0]
} }
value := form.Get(tag) formValues := form[tag]
if len(value) == 0 { var value string
continue if len(formValues) == 0 {
defaultValue := fieldT.Tag.Get("default")
if defaultValue != "" {
value = defaultValue
} else {
continue
}
}
if len(formValues) == 1 {
value = formValues[0]
if value == "" {
continue
}
} }
switch fieldT.Type.Kind() { switch fieldT.Type.Kind() {
...@@ -349,6 +361,8 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e ...@@ -349,6 +361,8 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e
if len(value) >= 25 { if len(value) >= 25 {
value = value[:25] value = value[:25]
t, err = time.ParseInLocation(time.RFC3339, value, time.Local) t, err = time.ParseInLocation(time.RFC3339, value, time.Local)
} else if strings.HasSuffix(strings.ToUpper(value), "Z") {
t, err = time.ParseInLocation(time.RFC3339, value, time.Local)
} else if len(value) >= 19 { } else if len(value) >= 19 {
if strings.Contains(value, "T") { if strings.Contains(value, "T") {
value = value[:19] value = value[:19]
......
...@@ -111,7 +111,7 @@ func TestHtmlunquote(t *testing.T) { ...@@ -111,7 +111,7 @@ func TestHtmlunquote(t *testing.T) {
func TestParseForm(t *testing.T) { func TestParseForm(t *testing.T) {
type ExtendInfo struct { type ExtendInfo struct {
Hobby string `form:"hobby"` Hobby []string `form:"hobby"`
Memo string Memo string
} }
...@@ -146,7 +146,7 @@ func TestParseForm(t *testing.T) { ...@@ -146,7 +146,7 @@ func TestParseForm(t *testing.T) {
"date": []string{"2014-11-12"}, "date": []string{"2014-11-12"},
"organization": []string{"beego"}, "organization": []string{"beego"},
"title": []string{"CXO"}, "title": []string{"CXO"},
"hobby": []string{"Basketball"}, "hobby": []string{"", "Basketball", "Football"},
"memo": []string{"nothing"}, "memo": []string{"nothing"},
} }
if err := ParseForm(form, u); err == nil { if err := ParseForm(form, u); err == nil {
...@@ -186,8 +186,14 @@ func TestParseForm(t *testing.T) { ...@@ -186,8 +186,14 @@ func TestParseForm(t *testing.T) {
if u.Title != "CXO" { if u.Title != "CXO" {
t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) t.Errorf("Title should equal `CXO`, but got `%v`", u.Title)
} }
if u.Hobby != "Basketball" { if u.Hobby[0] != "" {
t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby) t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0])
}
if u.Hobby[1] != "Basketball" {
t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1])
}
if u.Hobby[2] != "Football" {
t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2])
} }
if len(u.Memo) != 0 { if len(u.Memo) != 0 {
t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo))
...@@ -197,7 +203,6 @@ func TestParseForm(t *testing.T) { ...@@ -197,7 +203,6 @@ func TestParseForm(t *testing.T) {
func TestRenderForm(t *testing.T) { func TestRenderForm(t *testing.T) {
type user struct { type user struct {
ID int `form:"-"` ID int `form:"-"`
tag string `form:"tag"`
Name interface{} `form:"username"` Name interface{} `form:"username"`
Age int `form:"age,text,年龄:"` Age int `form:"age,text,年龄:"`
Sex string Sex string
......
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
) )
...@@ -32,6 +33,7 @@ type bounds struct { ...@@ -32,6 +33,7 @@ type bounds struct {
// The bounds for each field. // The bounds for each field.
var ( var (
AdminTaskList map[string]Tasker AdminTaskList map[string]Tasker
taskLock sync.Mutex
stop chan bool stop chan bool
changed chan bool changed chan bool
isstart bool isstart bool
...@@ -389,6 +391,8 @@ func dayMatches(s *Schedule, t time.Time) bool { ...@@ -389,6 +391,8 @@ func dayMatches(s *Schedule, t time.Time) bool {
// StartTask start all tasks // StartTask start all tasks
func StartTask() { func StartTask() {
taskLock.Lock()
defer taskLock.Unlock()
if isstart { if isstart {
//If already started, no need to start another goroutine. //If already started, no need to start another goroutine.
return return
...@@ -440,6 +444,8 @@ func run() { ...@@ -440,6 +444,8 @@ func run() {
// StopTask stop all tasks // StopTask stop all tasks
func StopTask() { func StopTask() {
taskLock.Lock()
defer taskLock.Unlock()
if isstart { if isstart {
isstart = false isstart = false
stop <- true stop <- true
...@@ -449,6 +455,8 @@ func StopTask() { ...@@ -449,6 +455,8 @@ func StopTask() {
// AddTask add task with name // AddTask add task with name
func AddTask(taskname string, t Tasker) { func AddTask(taskname string, t Tasker) {
taskLock.Lock()
defer taskLock.Unlock()
t.SetNext(time.Now().Local()) t.SetNext(time.Now().Local())
AdminTaskList[taskname] = t AdminTaskList[taskname] = t
if isstart { if isstart {
...@@ -458,6 +466,8 @@ func AddTask(taskname string, t Tasker) { ...@@ -458,6 +466,8 @@ func AddTask(taskname string, t Tasker) {
// DeleteTask delete task with name // DeleteTask delete task with name
func DeleteTask(taskname string) { func DeleteTask(taskname string) {
taskLock.Lock()
defer taskLock.Unlock()
delete(AdminTaskList, taskname) delete(AdminTaskList, taskname)
if isstart { if isstart {
changed <- true changed <- true
......
...@@ -162,7 +162,7 @@ func (e *Email) Bytes() ([]byte, error) { ...@@ -162,7 +162,7 @@ func (e *Email) Bytes() ([]byte, error) {
// AttachFile Add attach file to the send mail // AttachFile Add attach file to the send mail
func (e *Email) AttachFile(args ...string) (a *Attachment, err error) { func (e *Email) AttachFile(args ...string) (a *Attachment, err error) {
if len(args) < 1 && len(args) > 2 { if len(args) < 1 || len(args) > 2 { // change && to ||
err = errors.New("Must specify a file name and number of parameters can not exceed at least two") err = errors.New("Must specify a file name and number of parameters can not exceed at least two")
return return
} }
...@@ -183,7 +183,7 @@ func (e *Email) AttachFile(args ...string) (a *Attachment, err error) { ...@@ -183,7 +183,7 @@ func (e *Email) AttachFile(args ...string) (a *Attachment, err error) {
// Attach is used to attach content from an io.Reader to the email. // Attach is used to attach content from an io.Reader to the email.
// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type. // Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type.
func (e *Email) Attach(r io.Reader, filename string, args ...string) (a *Attachment, err error) { func (e *Email) Attach(r io.Reader, filename string, args ...string) (a *Attachment, err error) {
if len(args) < 1 && len(args) > 2 { if len(args) < 1 || len(args) > 2 { // change && to ||
err = errors.New("Must specify the file type and number of parameters can not exceed at least two") err = errors.New("Must specify the file type and number of parameters can not exceed at least two")
return return
} }
......
...@@ -3,19 +3,78 @@ package utils ...@@ -3,19 +3,78 @@ package utils
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"runtime" "runtime"
"strconv"
"strings" "strings"
) )
// GetGOPATHs returns all paths in GOPATH variable. // GetGOPATHs returns all paths in GOPATH variable.
func GetGOPATHs() []string { func GetGOPATHs() []string {
gopath := os.Getenv("GOPATH") gopath := os.Getenv("GOPATH")
if gopath == "" && strings.Compare(runtime.Version(), "go1.8") >= 0 { if gopath == "" && compareGoVersion(runtime.Version(), "go1.8") >= 0 {
gopath = defaultGOPATH() gopath = defaultGOPATH()
} }
return filepath.SplitList(gopath) return filepath.SplitList(gopath)
} }
func compareGoVersion(a, b string) int {
reg := regexp.MustCompile("^\\d*")
a = strings.TrimPrefix(a, "go")
b = strings.TrimPrefix(b, "go")
versionsA := strings.Split(a, ".")
versionsB := strings.Split(b, ".")
for i := 0; i < len(versionsA) && i < len(versionsB); i++ {
versionA := versionsA[i]
versionB := versionsB[i]
vA, err := strconv.Atoi(versionA)
if err != nil {
str := reg.FindString(versionA)
if str != "" {
vA, _ = strconv.Atoi(str)
} else {
vA = -1
}
}
vB, err := strconv.Atoi(versionB)
if err != nil {
str := reg.FindString(versionB)
if str != "" {
vB, _ = strconv.Atoi(str)
} else {
vB = -1
}
}
if vA > vB {
// vA = 12, vB = 8
return 1
} else if vA < vB {
// vA = 6, vB = 8
return -1
} else if vA == -1 {
// vA = rc1, vB = rc3
return strings.Compare(versionA, versionB)
}
// vA = vB = 8
continue
}
if len(versionsA) > len(versionsB) {
return 1
} else if len(versionsA) == len(versionsB) {
return 0
}
return -1
}
func defaultGOPATH() string { func defaultGOPATH() string {
env := "HOME" env := "HOME"
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
......
package utils
import (
"testing"
)
func TestCompareGoVersion(t *testing.T) {
targetVersion := "go1.8"
if compareGoVersion("go1.12.4", targetVersion) != 1 {
t.Error("should be 1")
}
if compareGoVersion("go1.8.7", targetVersion) != 1 {
t.Error("should be 1")
}
if compareGoVersion("go1.8", targetVersion) != 0 {
t.Error("should be 0")
}
if compareGoVersion("go1.7.6", targetVersion) != -1 {
t.Error("should be -1")
}
if compareGoVersion("go1.12.1rc1", targetVersion) != 1 {
t.Error("should be 1")
}
if compareGoVersion("go1.8rc1", targetVersion) != 0 {
t.Error("should be 0")
}
if compareGoVersion("go1.7rc1", targetVersion) != -1 {
t.Error("should be -1")
}
}
...@@ -268,6 +268,18 @@ func TestMobile(t *testing.T) { ...@@ -268,6 +268,18 @@ func TestMobile(t *testing.T) {
if !valid.Mobile("+8614700008888", "mobile").Ok { if !valid.Mobile("+8614700008888", "mobile").Ok {
t.Error("\"+8614700008888\" is a valid mobile phone number should be true") t.Error("\"+8614700008888\" is a valid mobile phone number should be true")
} }
if !valid.Mobile("17300008888", "mobile").Ok {
t.Error("\"17300008888\" is a valid mobile phone number should be true")
}
if !valid.Mobile("+8617100008888", "mobile").Ok {
t.Error("\"+8617100008888\" is a valid mobile phone number should be true")
}
if !valid.Mobile("8617500008888", "mobile").Ok {
t.Error("\"8617500008888\" is a valid mobile phone number should be true")
}
if valid.Mobile("8617400008888", "mobile").Ok {
t.Error("\"8617400008888\" is a valid mobile phone number should be false")
}
} }
func TestTel(t *testing.T) { func TestTel(t *testing.T) {
...@@ -453,7 +465,7 @@ func TestPointer(t *testing.T) { ...@@ -453,7 +465,7 @@ func TestPointer(t *testing.T) {
u := User{ u := User{
ReqEmail: nil, ReqEmail: nil,
Email: nil, Email: nil,
} }
valid := Validation{} valid := Validation{}
...@@ -468,7 +480,7 @@ func TestPointer(t *testing.T) { ...@@ -468,7 +480,7 @@ func TestPointer(t *testing.T) {
validEmail := "a@a.com" validEmail := "a@a.com"
u = User{ u = User{
ReqEmail: &validEmail, ReqEmail: &validEmail,
Email: nil, Email: nil,
} }
valid = Validation{RequiredFirst: true} valid = Validation{RequiredFirst: true}
...@@ -482,7 +494,7 @@ func TestPointer(t *testing.T) { ...@@ -482,7 +494,7 @@ func TestPointer(t *testing.T) {
u = User{ u = User{
ReqEmail: &validEmail, ReqEmail: &validEmail,
Email: nil, Email: nil,
} }
valid = Validation{} valid = Validation{}
...@@ -497,7 +509,7 @@ func TestPointer(t *testing.T) { ...@@ -497,7 +509,7 @@ func TestPointer(t *testing.T) {
invalidEmail := "a@a" invalidEmail := "a@a"
u = User{ u = User{
ReqEmail: &validEmail, ReqEmail: &validEmail,
Email: &invalidEmail, Email: &invalidEmail,
} }
valid = Validation{RequiredFirst: true} valid = Validation{RequiredFirst: true}
...@@ -511,7 +523,7 @@ func TestPointer(t *testing.T) { ...@@ -511,7 +523,7 @@ func TestPointer(t *testing.T) {
u = User{ u = User{
ReqEmail: &validEmail, ReqEmail: &validEmail,
Email: &invalidEmail, Email: &invalidEmail,
} }
valid = Validation{} valid = Validation{}
...@@ -524,19 +536,18 @@ func TestPointer(t *testing.T) { ...@@ -524,19 +536,18 @@ func TestPointer(t *testing.T) {
} }
} }
func TestCanSkipAlso(t *testing.T) { func TestCanSkipAlso(t *testing.T) {
type User struct { type User struct {
ID int ID int
Email string `valid:"Email"` Email string `valid:"Email"`
ReqEmail string `valid:"Required;Email"` ReqEmail string `valid:"Required;Email"`
MatchRange int `valid:"Range(10, 20)"` MatchRange int `valid:"Range(10, 20)"`
} }
u := User{ u := User{
ReqEmail: "a@a.com", ReqEmail: "a@a.com",
Email: "", Email: "",
MatchRange: 0, MatchRange: 0,
} }
...@@ -560,4 +571,3 @@ func TestCanSkipAlso(t *testing.T) { ...@@ -560,4 +571,3 @@ func TestCanSkipAlso(t *testing.T) {
} }
} }
...@@ -632,7 +632,7 @@ func (b Base64) GetLimitValue() interface{} { ...@@ -632,7 +632,7 @@ func (b Base64) GetLimitValue() interface{} {
} }
// just for chinese mobile phone number // just for chinese mobile phone number
var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][06789]|[4][579]))\d{8}$`) var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][01356789]|[4][579]))\d{8}$`)
// Mobile check struct // Mobile check struct
type Mobile struct { type Mobile struct {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册