# metrics.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import paddle.fluid as fluid

# Names exported by ``from <this module> import *``.
__all__ = ['chunk_count', "build_chunk"]


def build_chunk(data_list, id2label_dict):
    """
    Assemble a sequence of label ids into chunks.

    Each id in ``data_list`` is converted with ``str()`` and looked up in
    ``id2label_dict`` to obtain its tag. Tags are interpreted as follows:

    * ``"O"`` always starts a new chunk whose type is ``"O"``.
    * a tag ending in ``"B"`` always starts a new chunk whose type is the
      text before the first ``'-'`` of the tag.
    * a tag ending in ``"I"`` continues the current chunk when its type
      (text before the first ``'-'``) equals the current chunk type, and
      starts a new chunk otherwise.
    * any other tag leaves the current chunk untouched.

    Args:
        data_list: iterable of label ids.
        id2label_dict: dict mapping ``str(label_id)`` to a tag string.

    Returns:
        dict mapping ``"<start>_<end>"`` (inclusive token indices) to the
        chunk type. The last chunk of the sequence is included; an empty
        ``data_list`` yields an empty dict.

    Raises:
        AttributeError: if an id is missing from ``id2label_dict`` (the
            lookup yields ``None``, which has no ``endswith``).
    """
    tag_list = [id2label_dict.get(str(label_id)) for label_id in data_list]
    ner_dict = {}
    ner_str = ""
    ner_start = 0
    for i, tag in enumerate(tag_list):
        if tag == u"O":
            chunk_type = tag
        elif tag.endswith(u"B"):
            chunk_type = tag.split('-')[0]
        elif tag.endswith(u"I"):
            chunk_type = tag.split('-')[0]
            if chunk_type == ner_str:
                # Same type as the open chunk: this token extends it.
                continue
        else:
            # Unrecognised tag shape: treated as part of the open chunk.
            continue

        # A new chunk starts at i, so close the one that was open before it.
        # Nothing is open before the very first token.
        if i != 0:
            ner_dict["%d_%d" % (ner_start, i - 1)] = ner_str
        ner_start = i
        ner_str = chunk_type

    # Close the chunk that is still open when the sequence ends; without
    # this the final chunk of every sequence would be silently dropped.
    if tag_list:
        ner_dict["%d_%d" % (ner_start, len(tag_list) - 1)] = ner_str
    return ner_dict


def chunk_count(infer_numpy, label_numpy, seq_len, id2label_dict):
    """
    Count predicted chunks, labelled chunks and matching chunks in a batch.

    For every row of the batch, the first ``seq_len[row]`` entries of the
    prediction and of the label are turned into chunk dicts with
    ``build_chunk``. A predicted chunk is counted as correct when the label
    dict holds the same span key with the same chunk type.

    Args:
        infer_numpy: predicted label ids; must expose ``.shape`` and be
            indexable by row (presumably a numpy array - confirm with caller).
        label_numpy: gold label ids with the same first dimension as
            ``infer_numpy``.
        seq_len: per-row length used to truncate each row.
        id2label_dict: mapping passed through to ``build_chunk``.

    Returns:
        tuple ``(num_infer_chunks, num_label_chunks, num_correct_chunks)``.
    """
    batch_size = infer_numpy.shape[0]
    assert batch_size == label_numpy.shape[0]

    infer_total = 0
    label_total = 0
    correct_total = 0
    for row in range(batch_size):
        predicted_ids = infer_numpy[row][:seq_len[row]]
        gold_ids = label_numpy[row][:seq_len[row]]

        predicted = build_chunk(predicted_ids, id2label_dict)
        infer_total += len(predicted)
        gold = build_chunk(gold_ids, id2label_dict)
        label_total += len(gold)

        # A hit requires both the same span and the same chunk type.
        correct_total += sum(
            1 for span in predicted
            if span in gold and gold[span] == predicted[span])
    return infer_total, label_total, correct_total