Unverified commit 167b7d04 authored by Asim Shankar, committed by GitHub

Merge pull request #4140 from asimshankar/r1.8.0-fix

[samples/core/get_starter/eager]: Update with API simplifications in 1.8
@@ -114,7 +114,7 @@
"source": [
"### Install the latest version of TensorFlow\n",
"\n",
"This tutorial uses eager execution, which is available in [TensorFlow 1.7](https://www.tensorflow.org/install/). (You may need to restart the runtime after upgrading.)"
"This tutorial uses eager execution, which is available in [TensorFlow 1.8](https://www.tensorflow.org/install/). (You may need to restart the runtime after upgrading.)"
]
},
{
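For readers reproducing this notebook locally, a minimal setup sketch for the 1.8 requirement, assuming a plain pip install (the exact version pin is an assumption; any 1.8.x build should behave the same):

```python
# Assumed setup, not part of the PR:
#   pip install --upgrade "tensorflow>=1.8,<1.9"
import tensorflow as tf
import tensorflow.contrib.eager as tfe  # still used for tfe.metrics in 1.8

tf.enable_eager_execution()  # must run before any other TensorFlow ops

print("TensorFlow version:", tf.VERSION)
print("Eager execution enabled:", tf.executing_eagerly())
```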
@@ -374,7 +374,7 @@
"train_dataset = train_dataset.batch(32)\n",
"\n",
"# View a single example entry from a batch\n",
"features, label = tfe.Iterator(train_dataset).next()\n",
"features, label = iter(train_dataset).next()\n",
"print(\"example features:\", features[0])\n",
"print(\"example label:\", label[0])"
],
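The change above drops the `tfe.Iterator` wrapper: in 1.8 a `tf.data.Dataset` is directly iterable under eager execution. A minimal sketch of the new pattern, using made-up toy tensors (assumptions) rather than the notebook's Iris CSV pipeline:

```python
import tensorflow as tf

tf.enable_eager_execution()

# Toy stand-in for the notebook's parsed CSV features/labels (assumption).
features = tf.random_normal([96, 4])
labels = tf.random_uniform([96], maxval=3, dtype=tf.int32)

train_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
train_dataset = train_dataset.shuffle(buffer_size=96).batch(32)

# Under eager execution the dataset is a plain Python iterable,
# so next(iter(...)) replaces tfe.Iterator(...).next().
batch_features, batch_labels = next(iter(train_dataset))
print("example features:", batch_features[0])
print("example label:", batch_labels[0])
```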
@@ -508,7 +508,7 @@
"\n",
"\n",
"def grad(model, inputs, targets):\n",
" with tfe.GradientTape() as tape:\n",
" with tf.GradientTape() as tape:\n",
" loss_value = loss(model, inputs, targets)\n",
" return tape.gradient(loss_value, model.variables)"
],
@@ -522,7 +522,7 @@
},
"cell_type": "markdown",
"source": [
"The `grad` function uses the `loss` function and the [tfe.GradientTape](https://www.tensorflow.org/api_docs/python/tf/contrib/eager/GradientTape) to record operations that compute the *[gradients](https://developers.google.com/machine-learning/crash-course/glossary#gradient)* used to optimize our model. For more examples of this, see the [eager execution guide](https://www.tensorflow.org/programmers_guide/eager)."
"The `grad` function uses the `loss` function and the [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) to record operations that compute the *[gradients](https://developers.google.com/machine-learning/crash-course/glossary#gradient)* used to optimize our model. For more examples of this, see the [eager execution guide](https://www.tensorflow.org/programmers_guide/eager)."
]
},
{
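Since the tape is now the top-level `tf.GradientTape` symbol, the recording pattern looks like the sketch below. The tiny model and loss here are assumptions for illustration only, not the notebook's Iris network:

```python
import tensorflow as tf

tf.enable_eager_execution()

# Illustrative stand-ins (assumptions), not the notebook's model/loss.
model = tf.keras.Sequential([tf.keras.layers.Dense(3)])

def loss(model, inputs, targets):
  return tf.losses.sparse_softmax_cross_entropy(
      labels=targets, logits=model(inputs))

def grad(model, inputs, targets):
  # The tape records forward-pass operations so gradients can be derived.
  with tf.GradientTape() as tape:
    loss_value = loss(model, inputs, targets)
  return tape.gradient(loss_value, model.variables)

x = tf.random_normal([8, 4])
y = tf.random_uniform([8], maxval=3, dtype=tf.int32)
print([g.shape for g in grad(model, x, y)])
```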
@@ -614,7 +614,7 @@
" epoch_accuracy = tfe.metrics.Accuracy()\n",
"\n",
" # Training loop - using batches of 32\n",
" for x, y in tfe.Iterator(train_dataset):\n",
" for x, y in train_dataset:\n",
" # Optimize the model\n",
" grads = grad(model, x, y)\n",
" optimizer.apply_gradients(zip(grads, model.variables),\n",
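Putting the 1.8 pieces together, a training epoch can loop over the dataset directly instead of wrapping it in `tfe.Iterator`. A self-contained sketch with a toy model and toy data (both assumptions), mirroring the loop structure above:

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

# Toy stand-ins (assumptions) for the notebook's Iris data and model.
model = tf.keras.Sequential([tf.keras.layers.Dense(3)])
features = tf.random_normal([96, 4])
labels = tf.random_uniform([96], maxval=3, dtype=tf.int32)
train_dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(32)

def loss(model, inputs, targets):
  return tf.losses.sparse_softmax_cross_entropy(
      labels=targets, logits=model(inputs))

def grad(model, inputs, targets):
  with tf.GradientTape() as tape:
    loss_value = loss(model, inputs, targets)
  return tape.gradient(loss_value, model.variables)

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

for epoch in range(3):
  epoch_accuracy = tfe.metrics.Accuracy()

  # In 1.8 the dataset itself is iterable under eager execution.
  for x, y in train_dataset:
    grads = grad(model, x, y)
    optimizer.apply_gradients(zip(grads, model.variables),
                              global_step=tf.train.get_or_create_global_step())
    epoch_accuracy(tf.argmax(model(x), axis=1, output_type=tf.int32), y)

  print("Epoch", epoch, "accuracy:", epoch_accuracy.result().numpy())
```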
@@ -800,7 +800,7 @@
"source": [
"test_accuracy = tfe.metrics.Accuracy()\n",
"\n",
"for (x, y) in tfe.Iterator(test_dataset):\n",
"for (x, y) in test_dataset:\n",
" prediction = tf.argmax(model(x), axis=1, output_type=tf.int32)\n",
" test_accuracy(prediction, y)\n",
"\n",