未验证 提交 1a5e5989 编写于 作者: B bigsheeper 提交者: GitHub

Add segment unittests for query node (#7632)

Signed-off-by: Nbigsheeper <yihao.dai@zilliz.com>
上级 0caf1016
...@@ -12,63 +12,12 @@ ...@@ -12,63 +12,12 @@
package querynode package querynode
import ( import (
"strconv"
"testing" "testing"
"github.com/milvus-io/milvus/internal/indexnode"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
)
func genIndexBinarySet() ([][]byte, error) {
const (
msgLength = 1000
DIM = 16
)
indexParams := make(map[string]string)
indexParams["index_type"] = "IVF_PQ"
indexParams["index_mode"] = "cpu"
indexParams["dim"] = "16"
indexParams["k"] = "10"
indexParams["nlist"] = "100"
indexParams["nprobe"] = "10"
indexParams["m"] = "4"
indexParams["nbits"] = "8"
indexParams["metric_type"] = "L2"
indexParams["SLICE_SIZE"] = "4"
typeParams := make(map[string]string)
typeParams["dim"] = strconv.Itoa(DIM)
var indexRowData []float32
for n := 0; n < msgLength; n++ {
for i := 0; i < DIM; i++ {
indexRowData = append(indexRowData, float32(n*i))
}
}
index, err := indexnode.NewCIndex(typeParams, indexParams)
if err != nil {
return nil, err
}
err = index.BuildFloatVecIndexWithoutIds(indexRowData) "github.com/milvus-io/milvus/internal/proto/commonpb"
if err != nil { )
return nil, err
}
// save index to minio
binarySet, err := index.Serialize()
if err != nil {
return nil, err
}
bytesSet := make([][]byte, 0)
for i := range binarySet {
bytesSet = append(bytesSet, binarySet[i].Value)
}
return bytesSet, nil
}
func TestLoadIndexInfo(t *testing.T) { func TestLoadIndexInfo(t *testing.T) {
indexParams := make([]*commonpb.KeyValuePair, 0) indexParams := make([]*commonpb.KeyValuePair, 0)
......
...@@ -23,6 +23,7 @@ import ( ...@@ -23,6 +23,7 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/milvus-io/milvus/internal/indexnode"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
minioKV "github.com/milvus-io/milvus/internal/kv/minio" minioKV "github.com/milvus-io/milvus/internal/kv/minio"
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
...@@ -139,6 +140,56 @@ func genFloatVectorField(param vecFieldParam) *schemapb.FieldSchema { ...@@ -139,6 +140,56 @@ func genFloatVectorField(param vecFieldParam) *schemapb.FieldSchema {
return fieldVec return fieldVec
} }
func genSimpleIndexParams() indexParam {
indexParams := make(map[string]string)
indexParams["index_type"] = "IVF_PQ"
indexParams["index_mode"] = "cpu"
indexParams["dim"] = strconv.FormatInt(defaultDim, 10)
indexParams["k"] = "10"
indexParams["nlist"] = "100"
indexParams["nprobe"] = "10"
indexParams["m"] = "4"
indexParams["nbits"] = "8"
indexParams["metric_type"] = "L2"
indexParams["SLICE_SIZE"] = "400"
return indexParams
}
func genIndexBinarySet() ([][]byte, error) {
indexParams := genSimpleIndexParams()
typeParams := make(map[string]string)
typeParams["dim"] = strconv.Itoa(defaultDim)
var indexRowData []float32
for n := 0; n < defaultMsgLength; n++ {
for i := 0; i < defaultDim; i++ {
indexRowData = append(indexRowData, float32(n*i))
}
}
index, err := indexnode.NewCIndex(typeParams, indexParams)
if err != nil {
return nil, err
}
err = index.BuildFloatVecIndexWithoutIds(indexRowData)
if err != nil {
return nil, err
}
// save index to minio
binarySet, err := index.Serialize()
if err != nil {
return nil, err
}
bytesSet := make([][]byte, 0)
for i := range binarySet {
bytesSet = append(bytesSet, binarySet[i].Value)
}
return bytesSet, nil
}
func genSimpleSchema() (*schemapb.CollectionSchema, *schemapb.CollectionSchema) { func genSimpleSchema() (*schemapb.CollectionSchema, *schemapb.CollectionSchema) {
fieldUID := genConstantField(uidField) fieldUID := genConstantField(uidField)
fieldTimestamp := genConstantField(timestampField) fieldTimestamp := genConstantField(timestampField)
......
...@@ -234,7 +234,6 @@ func (s *Segment) getRowCount() int64 { ...@@ -234,7 +234,6 @@ func (s *Segment) getRowCount() int64 {
return -1 return -1
} }
var rowCount = C.GetRowCount(s.segmentPtr) var rowCount = C.GetRowCount(s.segmentPtr)
//log.Debug("QueryNode::Segment::getRowCount", zap.Any("rowCount", rowCount))
return int64(rowCount) return int64(rowCount)
} }
...@@ -502,7 +501,7 @@ func (s *Segment) matchIndexParam(fieldID int64, indexParams indexParam) bool { ...@@ -502,7 +501,7 @@ func (s *Segment) matchIndexParam(fieldID int64, indexParams indexParam) bool {
if fieldIndexParam == nil { if fieldIndexParam == nil {
return false return false
} }
paramSize := len(s.indexInfos) paramSize := len(s.indexInfos[fieldID].indexParams)
matchCount := 0 matchCount := 0
for k, v := range indexParams { for k, v := range indexParams {
value, ok := fieldIndexParam[k] value, ok := fieldIndexParam[k]
...@@ -591,7 +590,6 @@ func (s *Segment) segmentInsert(offset int64, entityIDs *[]UniqueID, timestamps ...@@ -591,7 +590,6 @@ func (s *Segment) segmentInsert(offset int64, entityIDs *[]UniqueID, timestamps
if s.segmentType != segmentTypeGrowing { if s.segmentType != segmentTypeGrowing {
return nil return nil
} }
log.Debug("QueryNode::Segment::segmentInsert:", zap.Any("s.segmentPtr", s.segmentPtr))
if s.segmentPtr == nil { if s.segmentPtr == nil {
return errors.New("null seg core pointer") return errors.New("null seg core pointer")
......
...@@ -23,9 +23,11 @@ import ( ...@@ -23,9 +23,11 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
) )
//-------------------------------------------------------------------------------------- constructor and destructor //-------------------------------------------------------------------------------------- constructor and destructor
...@@ -41,6 +43,19 @@ func TestSegment_newSegment(t *testing.T) { ...@@ -41,6 +43,19 @@ func TestSegment_newSegment(t *testing.T) {
assert.Equal(t, segmentID, segment.segmentID) assert.Equal(t, segmentID, segment.segmentID)
deleteSegment(segment) deleteSegment(segment)
deleteCollection(collection) deleteCollection(collection)
t.Run("test invalid type", func(t *testing.T) {
s := newSegment(collection,
defaultSegmentID,
defaultPartitionID,
collectionID, "", segmentTypeInvalid, true)
assert.Nil(t, s)
s = newSegment(collection,
defaultSegmentID,
defaultPartitionID,
collectionID, "", 100, true)
assert.Nil(t, s)
})
} }
func TestSegment_deleteSegment(t *testing.T) { func TestSegment_deleteSegment(t *testing.T) {
...@@ -56,6 +71,13 @@ func TestSegment_deleteSegment(t *testing.T) { ...@@ -56,6 +71,13 @@ func TestSegment_deleteSegment(t *testing.T) {
deleteSegment(segment) deleteSegment(segment)
deleteCollection(collection) deleteCollection(collection)
t.Run("test delete nil ptr", func(t *testing.T) {
s, err := genSimpleSealedSegment()
assert.NoError(t, err)
s.segmentPtr = nil
deleteSegment(s)
})
} }
//-------------------------------------------------------------------------------------- stats functions //-------------------------------------------------------------------------------------- stats functions
...@@ -105,6 +127,14 @@ func TestSegment_getRowCount(t *testing.T) { ...@@ -105,6 +127,14 @@ func TestSegment_getRowCount(t *testing.T) {
deleteSegment(segment) deleteSegment(segment)
deleteCollection(collection) deleteCollection(collection)
t.Run("test getRowCount nil ptr", func(t *testing.T) {
s, err := genSimpleSealedSegment()
assert.NoError(t, err)
s.segmentPtr = nil
res := s.getRowCount()
assert.Equal(t, int64(-1), res)
})
} }
func TestSegment_retrieve(t *testing.T) { func TestSegment_retrieve(t *testing.T) {
...@@ -253,6 +283,14 @@ func TestSegment_getDeletedCount(t *testing.T) { ...@@ -253,6 +283,14 @@ func TestSegment_getDeletedCount(t *testing.T) {
assert.Equal(t, deletedCount, int64(0)) assert.Equal(t, deletedCount, int64(0))
deleteCollection(collection) deleteCollection(collection)
t.Run("test getDeletedCount nil ptr", func(t *testing.T) {
s, err := genSimpleSealedSegment()
assert.NoError(t, err)
s.segmentPtr = nil
res := s.getDeletedCount()
assert.Equal(t, int64(-1), res)
})
} }
func TestSegment_getMemSize(t *testing.T) { func TestSegment_getMemSize(t *testing.T) {
...@@ -345,6 +383,22 @@ func TestSegment_segmentInsert(t *testing.T) { ...@@ -345,6 +383,22 @@ func TestSegment_segmentInsert(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
deleteSegment(segment) deleteSegment(segment)
deleteCollection(collection) deleteCollection(collection)
t.Run("test nil segment", func(t *testing.T) {
segment, err := genSimpleSealedSegment()
assert.NoError(t, err)
segment.setType(segmentTypeGrowing)
segment.segmentPtr = nil
err = segment.segmentInsert(0, nil, nil, nil)
assert.Error(t, err)
})
t.Run("test invalid segment type", func(t *testing.T) {
segment, err := genSimpleSealedSegment()
assert.NoError(t, err)
err = segment.segmentInsert(0, nil, nil, nil)
assert.NoError(t, err)
})
} }
func TestSegment_segmentDelete(t *testing.T) { func TestSegment_segmentDelete(t *testing.T) {
...@@ -746,8 +800,281 @@ func TestSegment_indexInfoTest(t *testing.T) { ...@@ -746,8 +800,281 @@ func TestSegment_indexInfoTest(t *testing.T) {
id = seg.getBuildID(fieldID) id = seg.getBuildID(fieldID)
assert.Equal(t, int64(-1), id) assert.Equal(t, int64(-1), id)
err = seg.setIndexInfo(fieldID, &indexInfo{
readyLoad: true,
})
assert.NoError(t, err)
ready := seg.checkIndexReady(fieldID)
assert.True(t, ready)
ready = seg.checkIndexReady(FieldID(1000))
assert.False(t, ready)
seg.indexInfos = nil seg.indexInfos = nil
err = seg.setIndexInfo(fieldID, &indexInfo{}) err = seg.setIndexInfo(fieldID, &indexInfo{
readyLoad: true,
})
assert.Error(t, err)
})
}
func TestSegment_BasicMetrics(t *testing.T) {
_, schema := genSimpleSchema()
collection := newCollection(defaultCollectionID, schema)
segment := newSegment(collection,
defaultSegmentID,
defaultPartitionID,
defaultCollectionID,
defaultVChannel,
segmentTypeSealed,
true)
t.Run("test enable index", func(t *testing.T) {
segment.setEnableIndex(true)
enable := segment.getEnableIndex()
assert.True(t, enable)
})
t.Run("test id binlog row size", func(t *testing.T) {
size := int64(1024)
segment.setIDBinlogRowSizes([]int64{size})
sizes := segment.getIDBinlogRowSizes()
assert.Len(t, sizes, 1)
assert.Equal(t, size, sizes[0])
})
t.Run("test type", func(t *testing.T) {
sType := segmentTypeGrowing
segment.setType(sType)
resType := segment.getType()
assert.Equal(t, sType, resType)
})
t.Run("test onService", func(t *testing.T) {
segment.setOnService(false)
resOnService := segment.getOnService()
assert.Equal(t, false, resOnService)
})
t.Run("test VectorFieldInfo", func(t *testing.T) {
fieldID := rowIDFieldID
info := &VectorFieldInfo{
fieldBinlog: &datapb.FieldBinlog{
FieldID: fieldID,
Binlogs: []string{},
},
}
segment.setVectorFieldInfo(fieldID, info)
resInfo, err := segment.getVectorFieldInfo(fieldID)
assert.NoError(t, err)
assert.Equal(t, info, resInfo)
_, err = segment.getVectorFieldInfo(FieldID(1000))
assert.Error(t, err)
})
}
func TestSegment_fillVectorFieldsData(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, schema := genSimpleSchema()
collection := newCollection(defaultCollectionID, schema)
segment := newSegment(collection,
defaultSegmentID,
defaultPartitionID,
defaultCollectionID,
defaultVChannel,
segmentTypeSealed,
true)
vecCM, err := genVectorChunkManager(ctx)
assert.NoError(t, err)
t.Run("test fillVectorFieldsData float-vector invalid vectorChunkManager", func(t *testing.T) {
fieldID := FieldID(100)
fieldName := "float-vector-field-0"
err = segment.setIndexInfo(fieldID, &indexInfo{})
assert.NoError(t, err)
info := &VectorFieldInfo{
fieldBinlog: &datapb.FieldBinlog{
FieldID: fieldID,
Binlogs: []string{},
},
}
segment.setVectorFieldInfo(fieldID, info)
fieldData := []*schemapb.FieldData{
{
Type: schemapb.DataType_FloatVector,
FieldName: fieldName,
FieldId: fieldID,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: defaultDim,
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: []float32{1.1, 2.2, 3.3, 4.4},
},
},
},
},
},
}
result := &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{},
Offset: []int64{0},
FieldsData: fieldData,
}
err = segment.fillVectorFieldsData(defaultCollectionID, vecCM, result)
assert.Error(t, err)
})
}
func TestSegment_indexParam(t *testing.T) {
_, schema := genSimpleSchema()
collection := newCollection(defaultCollectionID, schema)
segment := newSegment(collection,
defaultSegmentID,
defaultPartitionID,
defaultCollectionID,
defaultVChannel,
segmentTypeSealed,
true)
t.Run("test indexParam", func(t *testing.T) {
fieldID := rowIDFieldID
iParam := genSimpleIndexParams()
segment.indexInfos[fieldID] = &indexInfo{}
err := segment.setIndexParam(fieldID, iParam)
assert.NoError(t, err)
_ = segment.getIndexParams(fieldID)
match := segment.matchIndexParam(fieldID, iParam)
assert.True(t, match)
match = segment.matchIndexParam(FieldID(1000), nil)
assert.False(t, match)
})
}
func TestSegment_dropFieldData(t *testing.T) {
t.Run("test dropFieldData", func(t *testing.T) {
segment, err := genSimpleSealedSegment()
assert.NoError(t, err)
segment.setType(segmentTypeIndexing)
err = segment.dropFieldData(simpleVecField.id)
assert.NoError(t, err)
})
t.Run("test nil segment", func(t *testing.T) {
segment, err := genSimpleSealedSegment()
assert.NoError(t, err)
segment.segmentPtr = nil
err = segment.dropFieldData(simpleVecField.id)
assert.Error(t, err)
})
t.Run("test invalid segment type", func(t *testing.T) {
segment, err := genSimpleSealedSegment()
assert.NoError(t, err)
err = segment.dropFieldData(simpleVecField.id)
assert.Error(t, err)
})
t.Run("test invalid field", func(t *testing.T) {
segment, err := genSimpleSealedSegment()
assert.NoError(t, err)
segment.setType(segmentTypeIndexing)
err = segment.dropFieldData(FieldID(1000))
assert.Error(t, err)
})
}
func TestSegment_updateSegmentIndex(t *testing.T) {
t.Run("test updateSegmentIndex invalid", func(t *testing.T) {
_, schema := genSimpleSchema()
collection := newCollection(defaultCollectionID, schema)
segment := newSegment(collection,
defaultSegmentID,
defaultPartitionID,
defaultCollectionID,
defaultVChannel,
segmentTypeSealed,
true)
fieldID := rowIDFieldID
iParam := genSimpleIndexParams()
segment.indexInfos[fieldID] = &indexInfo{}
err := segment.setIndexParam(fieldID, iParam)
assert.NoError(t, err)
indexPaths := make([]string, 0)
indexPaths = append(indexPaths, "IVF")
err = segment.setIndexPaths(fieldID, indexPaths)
assert.NoError(t, err)
indexBytes, err := genIndexBinarySet()
assert.NoError(t, err)
err = segment.updateSegmentIndex(indexBytes, fieldID)
assert.Error(t, err)
segment.setType(segmentTypeGrowing)
err = segment.updateSegmentIndex(indexBytes, fieldID)
assert.Error(t, err)
segment.setType(segmentTypeSealed)
segment.segmentPtr = nil
err = segment.updateSegmentIndex(indexBytes, fieldID)
assert.Error(t, err)
})
}
func TestSegment_dropSegmentIndex(t *testing.T) {
t.Run("test dropSegmentIndex invalid segment type", func(t *testing.T) {
_, schema := genSimpleSchema()
collection := newCollection(defaultCollectionID, schema)
segment := newSegment(collection,
defaultSegmentID,
defaultPartitionID,
defaultCollectionID,
defaultVChannel,
segmentTypeSealed,
true)
fieldID := rowIDFieldID
err := segment.dropSegmentIndex(fieldID)
assert.Error(t, err)
})
t.Run("test dropSegmentIndex nil segment ptr", func(t *testing.T) {
_, schema := genSimpleSchema()
collection := newCollection(defaultCollectionID, schema)
segment := newSegment(collection,
defaultSegmentID,
defaultPartitionID,
defaultCollectionID,
defaultVChannel,
segmentTypeSealed,
true)
segment.segmentPtr = nil
fieldID := rowIDFieldID
err := segment.dropSegmentIndex(fieldID)
assert.Error(t, err)
})
t.Run("test dropSegmentIndex nil index", func(t *testing.T) {
_, schema := genSimpleSchema()
collection := newCollection(defaultCollectionID, schema)
segment := newSegment(collection,
defaultSegmentID,
defaultPartitionID,
defaultCollectionID,
defaultVChannel,
segmentTypeSealed,
true)
segment.setType(segmentTypeIndexing)
fieldID := rowIDFieldID
err := segment.dropSegmentIndex(fieldID)
assert.Error(t, err) assert.Error(t, err)
}) })
} }
...@@ -16,8 +16,8 @@ import ( ...@@ -16,8 +16,8 @@ import (
) )
const ( const (
rowIDFieldID = 0 rowIDFieldID FieldID = 0
timestampFieldID = 1 timestampFieldID FieldID = 1
) )
const invalidTimestamp = Timestamp(0) const invalidTimestamp = Timestamp(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册