提交 b5db1ca4 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Internal change

PiperOrigin-RevId: 340146798
Change-Id: I22b6f23e305c21c7c84d1ad2c0f22ee6fbef1343
上级 3ac50899
......@@ -291,13 +291,10 @@ std::vector<string> UnarchiveAndFindTestNames(const string& zip_file,
class OpsTest : public ::testing::TestWithParam<string> {};
TEST_P(OpsTest, RunZipTests) {
string test_path_and_label = GetParam();
size_t end_pos = test_path_and_label.find(" ");
string test_path = test_path_and_label.substr(0, end_pos);
string label = test_path_and_label.substr(end_pos + 1);
string test_path = GetParam();
string tflite_test_case = test_path + "_tests.txt";
string tflite_dir = test_path.substr(0, test_path.find_last_of("/"));
string test_name = label.substr(label.find_last_of('/'));
string test_name = test_path.substr(test_path.find_last_of('/'));
std::ifstream tflite_stream(tflite_test_case);
ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case;
......@@ -308,7 +305,7 @@ TEST_P(OpsTest, RunZipTests) {
auto quantized_tests_error = GetQuantizeTestsError();
bool fully_quantize = false;
if (label.find("fully_quantize=True") != std::string::npos) {
if (test_path.find("fully_quantize=True") != std::string::npos) {
for (const auto& p : quantized_tests_error) {
if (RE2::PartialMatch(test_name, p.first)) {
test_driver.SetQuantizationErrorMultiplier(p.second);
......
......@@ -40,7 +40,6 @@ def make_conv_activation_tests(activation_op):
"constant_filter": [True, False],
"channel_multiplier": [1, 2],
"fully_quantize": [False],
"quant_16x8": [False],
"dynamic_range_quantize": [False],
},
# TODO(b/134702301): The fully_quantize param is just ignored by the
......@@ -48,15 +47,14 @@ def make_conv_activation_tests(activation_op):
# these tests or handle it properly in the mlir_convert() function.
{
"input_shape": [[1, 3, 4, 3], [4, 6, 6, 1]],
"filter_shape": [[1, 1], [2, 3]],
"filter_shape": [[1, 1], [2, 3], [3, 3]],
"strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
"dilations": [[1, 1, 1, 1], [1, 3, 2, 1]],
"dilations": [[1, 1, 1, 1], [1, 3, 2, 1], [1, 2, 2, 1]],
"padding": ["SAME", "VALID"],
"data_format": ["NHWC"], # TODO(aselle): NCHW would be good
"constant_filter": [True],
"channel_multiplier": [1, 2],
"fully_quantize": [True],
"quant_16x8": [False, True],
"dynamic_range_quantize": [False],
},
{
......@@ -69,7 +67,6 @@ def make_conv_activation_tests(activation_op):
"constant_filter": [True],
"channel_multiplier": [1, 2],
"fully_quantize": [False],
"quant_16x8": [False],
"dynamic_range_quantize": [True],
},
]
......@@ -126,7 +123,7 @@ def make_conv_activation_tests(activation_op):
test_parameters,
build_graph,
build_inputs,
expected_tf_failures=48)
expected_tf_failures=60)
return f
......
......@@ -342,7 +342,6 @@ def make_zip_of_tests(options,
if options.multi_gen_state:
label_base_path = options.multi_gen_state.label_base_path
i = 1
for parameters in test_parameters:
keys = parameters.keys()
for curr in itertools.product(*parameters.values()):
......@@ -350,8 +349,6 @@ def make_zip_of_tests(options,
"%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", ""))
if label[0] == "/":
label = label[1:]
zip_path_label = label_base_path.replace(".zip", "_") + str(i)
i += 1
if label in processed_labels:
# Do not populate data for the same label more than once. It will cause
# errors when unzipping.
......@@ -400,14 +397,13 @@ def make_zip_of_tests(options,
return input_values, output_values
def build_example(label, param_dict_real, zip_path_label):
def build_example(label, param_dict_real):
"""Build the model with parameter values set in param_dict_real.
Args:
label: Label of the model
label: Label of the model (i.e. the filename in the zip).
param_dict_real: Parameter dictionary (arguments to the factories
make_graph and make_test_inputs)
zip_path_label: Filename in the zip
Returns:
(tflite_model_binary, report) where tflite_model_binary is the
......@@ -470,7 +466,7 @@ def make_zip_of_tests(options,
report["toco_log"] = toco_log
if options.save_graphdefs:
archive.writestr(zip_path_label + ".pbtxt",
archive.writestr(label + ".pbtxt",
text_format.MessageToString(graph_def),
zipfile.ZIP_DEFLATED)
......@@ -479,25 +475,25 @@ def make_zip_of_tests(options,
# Set proper min max values according to input dtype.
baseline_inputs, baseline_outputs = generate_inputs_outputs(
tflite_model_binary, min_value=0, max_value=255)
archive.writestr(zip_path_label + ".bin", tflite_model_binary,
archive.writestr(label + ".bin", tflite_model_binary,
zipfile.ZIP_DEFLATED)
example = {"inputs": baseline_inputs, "outputs": baseline_outputs}
example_fp = StringIO()
write_examples(example_fp, [example])
archive.writestr(zip_path_label + ".inputs", example_fp.getvalue(),
archive.writestr(label + ".inputs", example_fp.getvalue(),
zipfile.ZIP_DEFLATED)
example_fp2 = StringIO()
write_test_cases(example_fp2, zip_path_label + ".bin", [example])
archive.writestr(zip_path_label + "_tests.txt",
example_fp2.getvalue(), zipfile.ZIP_DEFLATED)
write_test_cases(example_fp2, label + ".bin", [example])
archive.writestr(label + "_tests.txt", example_fp2.getvalue(),
zipfile.ZIP_DEFLATED)
zip_manifest.append(zip_path_label + " " + label + "\n")
zip_manifest.append(label + "\n")
return tflite_model_binary, report
_, report = build_example(label, param_dict, zip_path_label)
_, report = build_example(label, param_dict)
if report["toco"] == report_lib.FAILED:
ignore_error = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册