Sunday, June 23, 2019

Reset default graph upon exiting tf.Session() in unit tests

  • At the end of each unit test, I call tf.reset_default_graph() to clear the default graph.
  • However, when a unit test fails, that call is never reached and the graph doesn't get cleared, which makes the next unit test fail as well.

How can I clear the default graph upon exiting the tf.Session() context, even when the test fails?

Example (pytest):

import tensorflow as tf


def test_1():
    x = tf.get_variable('x', initializer=1)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(4 / 0)  # deliberately raises ZeroDivisionError to simulate a failing test
        print(sess.run(x))


def test_2():
    x = tf.get_variable('x', initializer=1)  # raises ValueError ('x' already exists) if test_1 left it in the graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(x))
