提交 5add5979 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!92 add prim name to param check error message for math_ops.py

Merge pull request !92 from fary86/add-prim-name-for-param-validator
......@@ -15,6 +15,7 @@
"""Check parameters."""
import re
from enum import Enum
from functools import reduce
from itertools import repeat
from collections import Iterable
......@@ -93,8 +94,131 @@ rel_strs = {
}
class Validator:
"""validator for checking input parameters"""
@staticmethod
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None):
"""
Method for judging relation between two int values or list/tuple made up of ints.
This method is not suitable for judging relation between floats, since it does not consider float error.
"""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
msg_prefix = f'For {prim_name} the' if prim_name else "The"
raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
@staticmethod
def check_integer(arg_name, arg_value, value, rel, prim_name):
"""Integer value judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},'
f' but got {arg_value}.')
return arg_value
@staticmethod
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
"""Method for checking whether an int value is in some range."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
f' but got {arg_value}.')
return arg_value
@staticmethod
def check_subclass(arg_name, type_, template_type, prim_name):
"""Check whether some type is sublcass of another type"""
if not isinstance(template_type, Iterable):
template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]):
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
@staticmethod
def check_tensor_type_same(args, valid_values, prim_name):
"""check whether the element types of input tensors are the same."""
def _check_tensor_type(arg):
arg_key, arg_val = arg
Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
elem_type = arg_val.element_type()
if not elem_type in valid_values:
raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
f' but `{arg_key}` is {elem_type}.')
return (arg_key, elem_type)
def _check_types_same(arg1, arg2):
arg1_name, arg1_type = arg1
arg2_name, arg2_type = arg2
if arg1_type != arg2_type:
raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,'
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
return arg1
elem_types = map(_check_tensor_type, args.items())
reduce(_check_types_same, elem_types)
@staticmethod
def check_scalar_or_tensor_type_same(args, valid_values, prim_name):
"""check whether the types of inputs are the same. if the input args are tensors, check their element types"""
def _check_argument_type(arg):
arg_key, arg_val = arg
if isinstance(arg_val, type(mstype.tensor)):
arg_val = arg_val.element_type()
if not arg_val in valid_values:
raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
f' but `{arg_key}` is {arg_val}.')
return arg
def _check_types_same(arg1, arg2):
arg1_name, arg1_type = arg1
arg2_name, arg2_type = arg2
excp_flag = False
if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
arg1_type = arg1_type.element_type()
arg2_type = arg2_type.element_type()
elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
pass
else:
excp_flag = True
if excp_flag or arg1_type != arg2_type:
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
return arg1
reduce(_check_types_same, map(_check_argument_type, args.items()))
@staticmethod
def check_value_type(arg_name, arg_value, valid_types, prim_name):
"""Check whether a values is instance of some types."""
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be '
f'{"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
# `check_value_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()
class ParamValidator:
"""Parameter validator."""
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
@staticmethod
def equal(arg_name, arg_value, cond_str, cond):
......
......@@ -16,13 +16,14 @@
"""broadcast"""
def _get_broadcast_shape(x_shape, y_shape):
def _get_broadcast_shape(x_shape, y_shape, prim_name):
"""
Doing broadcast between tensor x and tensor y.
Args:
x_shape (list): The shape of tensor x.
y_shape (list): The shape of tensor y.
prim_name (str): Primitive name.
Returns:
List, the shape that broadcast between tensor x and tensor y.
......@@ -50,7 +51,8 @@ def _get_broadcast_shape(x_shape, y_shape):
elif x_shape[i] == y_shape[i]:
broadcast_shape_back.append(x_shape[i])
else:
raise ValueError("The x_shape {} and y_shape {} can not broadcast.".format(x_shape, y_shape))
raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format(
prim_name, x_shape, y_shape))
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
broadcast_shape = broadcast_shape_front + broadcast_shape_back
......
......@@ -28,9 +28,16 @@ from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.tensor import Tensor
from ..operations.math_ops import _check_infer_attr_reduce, _infer_shape_reduce
from ..operations.math_ops import _infer_shape_reduce
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
def _check_infer_attr_reduce(axis, keep_dims):
validator.check_type('keep_dims', keep_dims, [bool])
validator.check_type('axis', axis, [int, tuple])
if isinstance(axis, tuple):
for index, value in enumerate(axis):
validator.check_type('axis[%d]' % index, value, [int])
class ExpandDims(PrimitiveWithInfer):
"""
......@@ -1090,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
axis = self.axis
x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
return ouput_shape, ouput_shape
def infer_dtype(self, x_dtype):
......@@ -1136,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
axis = self.axis
x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
return ouput_shape, ouput_shape
def infer_dtype(self, x_dtype):
......
......@@ -194,6 +194,9 @@ class PrimitiveWithInfer(Primitive):
Primitive.__init__(self, name)
self.set_prim_type(prim_type.py_infer_shape)
def prim_name(self):
return self.__class__.__name__
def _clone(self):
"""
Deeply clones the primitive object.
......
......@@ -23,20 +23,25 @@ from ...utils import keyword
class CheckExceptionsEC(IExectorComponent):
"""
Check if the function raises the expected Exception.
Check if the function raises the expected Exception and the error message contains specified keywords if not None.
Examples:
{
'block': f,
'exception': Exception
'exception': Exception,
'error_keywords': ['TensorAdd', 'shape']
}
"""
def run_function(self, function, inputs, verification_set):
f = function[keyword.block]
args = inputs[keyword.desc_inputs]
e = function.get(keyword.exception, Exception)
error_kws = function.get(keyword.error_keywords, None)
try:
with pytest.raises(e):
with pytest.raises(e) as exec_info:
f(*args)
except:
raise Exception(f"Expect {e}, but got {sys.exc_info()[0]}")
if error_kws and any(keyword not in str(exec_info.value) for keyword in error_kws):
raise ValueError('Error message `{}` does not contain all keywords `{}`'.format(
str(exec_info.value), error_kws))
......@@ -87,8 +87,9 @@ def get_function_config(function):
init_param_with = function.get(keyword.init_param_with, None)
split_outputs = function.get(keyword.split_outputs, True)
exception = function.get(keyword.exception, Exception)
error_keywords = function.get(keyword.error_keywords, None)
return delta, max_error, input_selector, output_selector, sampling_times, \
reduce_output, init_param_with, split_outputs, exception
reduce_output, init_param_with, split_outputs, exception, error_keywords
def get_grad_checking_options(function, inputs):
"""
......@@ -104,6 +105,6 @@ def get_grad_checking_options(function, inputs):
"""
f = function[keyword.block]
args = inputs[keyword.desc_inputs]
delta, max_error, input_selector, output_selector, sampling_times, reduce_output, _, _, _ = \
delta, max_error, input_selector, output_selector, sampling_times, reduce_output, _, _, _, _ = \
get_function_config(function)
return f, args, delta, max_error, input_selector, output_selector, sampling_times, reduce_output
......@@ -54,11 +54,12 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex
block = block_config
delta, max_error, input_selector, output_selector, \
sampling_times, reduce_output, init_param_with, split_outputs, exception = get_function_config({})
sampling_times, reduce_output, init_param_with, split_outputs, exception, error_keywords = get_function_config({})
if isinstance(block_config, tuple) and isinstance(block_config[-1], dict):
block = block_config[0]
delta, max_error, input_selector, output_selector, \
sampling_times, reduce_output, init_param_with, split_outputs, exception = get_function_config(block_config[-1])
sampling_times, reduce_output, init_param_with, \
split_outputs, exception, error_keywords = get_function_config(block_config[-1])
if block:
func_list.append({
......@@ -78,7 +79,8 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex
keyword.const_first: const_first,
keyword.add_fake_input: add_fake_input,
keyword.split_outputs: split_outputs,
keyword.exception: exception
keyword.exception: exception,
keyword.error_keywords: error_keywords
})
if desc_inputs or desc_const:
......
......@@ -73,5 +73,6 @@ keyword.const_first = "const_first"
keyword.add_fake_input = "add_fake_input"
keyword.fake_input_type = "fake_input_type"
keyword.exception = "exception"
keyword.error_keywords = "error_keywords"
sys.modules[__name__] = keyword
......@@ -234,7 +234,7 @@ raise_set = [
'block': (lambda x: P.Squeeze(axis=((1.2, 1.3))), {'exception': ValueError}),
'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}),
('ReduceSum_Error', {
'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': ValueError}),
'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': TypeError}),
'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}),
]
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册