diff --git a/python_module/test/integration/test_parampack.py b/python_module/test/integration/test_parampack.py index 8895a3f36aafb23d62cdb5f286d761134671fdae..6b73c9f88fa7ea615cc8c3ee9807feec085ab638 100644 --- a/python_module/test/integration/test_parampack.py +++ b/python_module/test/integration/test_parampack.py @@ -105,9 +105,15 @@ def test_static_graph_parampack(): assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" - data, _ = next(train_dataset) + ngrid = 10 + x = np.linspace(-1.0, 1.0, ngrid) + xx, yy = np.meshgrid(x, x) + xx = xx.reshape((ngrid * ngrid, 1)) + yy = yy.reshape((ngrid * ngrid, 1)) + data = np.concatenate((xx, yy), axis=1).astype(np.float32) + pred = infer(data).numpy() - assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" + assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough" @pytest.mark.slow @@ -140,9 +146,15 @@ def test_nopack_parampack(): losses.append(loss.numpy()) assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" - data, _ = next(train_dataset) + ngrid = 10 + x = np.linspace(-1.0, 1.0, ngrid) + xx, yy = np.meshgrid(x, x) + xx = xx.reshape((ngrid * ngrid, 1)) + yy = yy.reshape((ngrid * ngrid, 1)) + data = np.concatenate((xx, yy), axis=1).astype(np.float32) + pred = infer(data).numpy() - assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" + assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough" @pytest.mark.slow @@ -178,9 +190,15 @@ def test_dynamic_graph_parampack(): assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" - data, _ = next(train_dataset) + ngrid = 10 + x = np.linspace(-1.0, 1.0, ngrid) + xx, yy = np.meshgrid(x, x) + xx = xx.reshape((ngrid * ngrid, 1)) + yy = yy.reshape((ngrid * ngrid, 1)) + data = np.concatenate((xx, yy), axis=1).astype(np.float32) + pred = infer(data).numpy() - assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" + assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough" @pytest.mark.slow