
Error when running imported/restored model that uses feedable iterator

Open · fuhailin opened this issue 3 years ago · 0 comments

I trained a model and saved its checkpoint files. Now I need to restore the graph from the meta file and feed it a new data iterator to continue training. I found an issue that discusses this scenario and wrote some code to demonstrate my situation.

Current behavior

When I feed the model with ParquetDataset, restoring the meta graph fails with the following error:

Traceback (most recent call last):
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 501, in _import_graph_def_internal
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot add function '__inference_Dataset_flat_map__create_dataset_10' because a different function with the same name already exists.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test/io/restore_hb.py", line 223, in <module>
    resume_training(another_train_dataset, another_test_dataset)
  File "test/io/restore_hb.py", line 132, in resume_training
    saver = tf.train.import_meta_graph('checkpoints_hb/fufu.meta')
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1697, in import_meta_graph
    **kwargs)[0]
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1721, in _import_meta_graph_with_return_elements
    **kwargs))
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/meta_graph.py", line 809, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 505, in _import_graph_def_internal
    raise ValueError(str(e))
ValueError: Cannot add function '__inference_Dataset_flat_map__create_dataset_10' because a different function with the same name already exists.

I suspect this is not a HybridBackend bug, because I also tried TFRecordDataset and got a similar error:

Traceback (most recent call last):
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 501, in _import_graph_def_internal
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot add function '__inference_Dataset_map__parse_function_55' because a different function with the same name already exists.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test/io/restore_pb.py", line 225, in <module>
    restore_feed()
  File "test/io/restore_pb.py", line 220, in restore_feed
    resume_training(another_train_dataset, another_test_dataset)
  File "test/io/restore_pb.py", line 155, in resume_training
    saver = tf.train.import_meta_graph('checkpoints_pb/fufu.meta')
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1697, in import_meta_graph
    **kwargs)[0]
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1721, in _import_meta_graph_with_return_elements
    **kwargs))
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/meta_graph.py", line 809, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 505, in _import_graph_def_internal
    raise ValueError(str(e))
ValueError: Cannot add function '__inference_Dataset_map__parse_function_55' because a different function with the same name already exists.

However, the same process works with from_tensor_slices and CsvDataset. I would like to know how to restore the graph and feed it a new dataset iterator.
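A hedged way to check where the clash comes from is to compare the function library of the graph being imported into with the one saved in the `.meta` file. The names in both tracebacks (`__inference_Dataset_flat_map__create_dataset_10`, `__inference_Dataset_map__parse_function_55`) look like functions traced for `flat_map()` and `map()` transformations; `from_tensor_slices` and `CsvDataset` on their own do not trace any, which may be why they work. The sketch below assumes `tensorflow.compat.v1` in graph mode; `library_function_names`, `meta_graph_function_names` and `overlapping_function_names` are hypothetical helpers, not TensorFlow or HybridBackend APIs.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf
from tensorflow.core.protobuf import meta_graph_pb2

tf.disable_v2_behavior()


def library_function_names(graph):
    """Names of all functions registered in `graph`'s function library."""
    return {f.signature.name for f in graph.as_graph_def().library.function}


def meta_graph_function_names(meta_path):
    """Names of all functions stored in a binary .meta file (as written by Saver.save)."""
    meta = meta_graph_pb2.MetaGraphDef()
    with open(meta_path, 'rb') as fh:
        meta.ParseFromString(fh.read())
    return {f.signature.name for f in meta.graph_def.library.function}


def overlapping_function_names(graph, meta_path):
    """Functions that import_meta_graph would try to add to `graph` a second time."""
    return library_function_names(graph) & meta_graph_function_names(meta_path)


if __name__ == '__main__':
    meta_path = os.path.join(tempfile.mkdtemp(), 'demo.meta')
    with tf.Graph().as_default() as g:
        # map() traces a function, like .map(_parse_function) in the TFRecord script.
        tf.data.Dataset.from_tensor_slices([1.0, 2.0]).map(lambda x: x + 1.0)
        tf.train.export_meta_graph(filename=meta_path)
        # Same graph, as in the reproducer: every saved function is already present.
        print(sorted(overlapping_function_names(g, meta_path)))
    with tf.Graph().as_default() as fresh:
        # An empty graph has nothing to collide with.
        print(sorted(overlapping_function_names(fresh, meta_path)))
```

Calling `overlapping_function_names(tf.get_default_graph(), 'checkpoints_pb/fufu.meta')` right before `import_meta_graph` in `resume_training()` would show which saved functions are already in the default graph; since `train_feed()` and `restore_feed()` run in the same process without resetting the graph, the graph built by `train()` is still there.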

Expected behavior

When I use ParquetDataset for training, I can restore the checkpoint and feed a new ParquetDataset iterator to continue training.
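A possible workaround, sketched under the assumption that the collision comes from importing the meta graph into a graph that already holds the training pipeline's functions (not confirmed against DeepRec or HybridBackend): start from an empty graph with `tf.reset_default_graph()`, import the meta graph first, and only then build the new datasets, so their traced functions get fresh names. `make_dataset`, `train_and_save` and `resume_in_fresh_graph` are hypothetical names; `make_dataset` uses `from_tensor_slices(...).map(...)` as a stand-in for `ParquetDataset`/`TFRecordDataset` so the sketch runs without HybridBackend.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def make_dataset():
    """Hypothetical stand-in for ParquetDataset/TFRecordDataset: map() traces a function."""
    data = np.arange(6, dtype=np.float32).reshape(3, 2)
    return tf.data.Dataset.from_tensor_slices({'test1': data}).map(
        lambda d: {'test1': tf.abs(d['test1'])})


def train_and_save(ckpt_prefix, dataset_fn):
    """Condensed version of train(): feedable iterator, one step, save."""
    with tf.Graph().as_default():
        handle = tf.placeholder(tf.string, shape=[])
        tf.add_to_collection('iterator_handle', handle)
        ds = dataset_fn()
        iterator = tf.data.Iterator.from_string_handle(
            handle, tf.data.get_output_types(ds), tf.data.get_output_shapes(ds))
        v = tf.get_variable('v', initializer=tf.zeros([2]))
        some_op = tf.reduce_sum(tf.assign_add(v, iterator.get_next()['test1']))
        tf.add_to_collection('some_op', some_op)
        train_iter = tf.data.make_initializable_iterator(ds)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(train_iter.initializer)
            train_handle = sess.run(train_iter.string_handle())
            sess.run(some_op, feed_dict={handle: train_handle})
            saver.save(sess, ckpt_prefix)


def resume_in_fresh_graph(ckpt_prefix, dataset_fn):
    """Import the saved graph into an empty graph, then build the new pipeline."""
    tf.reset_default_graph()  # drop every function left over from training
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(ckpt_prefix + '.meta')
        saver.restore(sess, ckpt_prefix)
        handle = tf.get_collection('iterator_handle')[0]
        some_op = tf.get_collection('some_op')[0]
        # Built only after the import, so its traced function gets a new,
        # unused name instead of clashing with an imported one.
        new_iter = tf.data.make_initializable_iterator(dataset_fn())
        sess.run(new_iter.initializer)
        new_handle = sess.run(new_iter.string_handle())
        results = []
        while True:
            try:
                results.append(float(sess.run(some_op, feed_dict={handle: new_handle})))
            except tf.errors.OutOfRangeError:
                break
        return results


if __name__ == '__main__':
    prefix = os.path.join(tempfile.mkdtemp(), 'fufu')
    train_and_save(prefix, make_dataset)
    print(resume_in_fresh_graph(prefix, make_dataset))
```

Applied to the reproducer, this would mean calling `tf.reset_default_graph()` at the top of `resume_training()` and creating the new datasets inside it, after `import_meta_graph`, instead of in `restore_feed()`. Whether `ParquetDataset` behaves like the `map()` stand-in here is untested. Rebuilding the model graph in Python and calling only `saver.restore()` is another way to avoid importing the saved dataset functions at all.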

System information

  • GPU model and memory: Tesla T4, 16 GB
  • OS Platform: Ubuntu 18.04.5 LTS (Bionic Beaver)
  • Docker version: 20.10.14
  • GCC/CUDA/cuDNN version: gcc version 7.5.0 (Ubuntu 7.5.0-3ubuntu1~18.04)
  • Python/conda version: Python 3.6.12
  • TensorFlow/PyTorch version: tensorflow 1.15.5+deeprec2201

Code to reproduce

Training and restoring with ParquetDataset as the input (does not work):

# Tensorflow 1.15
# https://github.com/tensorflow/tensorflow/issues/11679#
#
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import shutil
from hybridbackend.tensorflow.data import DataFrame
from hybridbackend.tensorflow.data import ParquetDataset
from tensorflow.python.data.ops import dataset_ops

new_dtypes = {"test1": np.float32, "test2": np.float32}

train_df = pd.DataFrame(np.random.randint(0, 100, (5, 2)), columns=['test1', 'test2'])
train_df = train_df.astype(new_dtypes)
train_df.to_parquet('train.parquet')

test_df = pd.DataFrame(np.random.randint(0, 100, (2, 2)), columns=['test1', 'test2'])
test_df = test_df.astype(new_dtypes)
test_df.to_parquet('test.parquet')


def make_initializable_iterator(ds):
  if hasattr(dataset_ops, 'make_initializable_iterator'):
    return dataset_ops.make_initializable_iterator(ds)
  return ds.make_initializable_iterator()


def make_one_shot_iterator(ds):
  if hasattr(dataset_ops, 'make_one_shot_iterator'):
    return dataset_ops.make_one_shot_iterator(ds)
  return ds.make_one_shot_iterator()


def train(train_dataset, test_dataset):
  """
    Create graph with an Dataset and Iterator and save the model.

    There is some op that is applied to the data from the iterator.
    """
  iterator_handle = tf.placeholder(tf.string, shape=[])
  tf.add_to_collection('iterator_handle', iterator_handle)

  iterator = tf.data.Iterator.from_string_handle(iterator_handle, dataset_ops.get_legacy_output_types(train_dataset),
                                                 dataset_ops.get_legacy_output_shapes(train_dataset),
                                                 dataset_ops.get_legacy_output_classes(train_dataset))
  train_iter = make_initializable_iterator(train_dataset)
  test_iter = make_initializable_iterator(test_dataset)
  element = iterator.get_next()

  v = tf.get_variable(name='v', initializer=tf.zeros(shape=(1, 2)))

  # to use when saving summaries
  global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
  increament_global_step = tf.assign(global_step, global_step + 1)
  global_step = global_step + 1
  tf.add_to_collection('increament_global_step', increament_global_step)

  some_op = tf.assign(v, v + tf.abs(element['test1']))
  tf.add_to_collection('some_op', tf.reduce_sum(some_op))

  tf.summary.scalar('v_sum', tf.reduce_sum(v))
  tf.summary.scalar('some_op', tf.reduce_mean(some_op))
  merged_summary = tf.summary.merge_all()
  tf.add_to_collection('merged_summary', merged_summary)

  writer = tf.summary.FileWriter('checkpoints_hb', graph=tf.get_default_graph())
  saver = tf.train.Saver()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())

    # Run data iterator initialisation
    sess.run(train_iter.initializer)
    sess.run(test_iter.initializer)

    # "Training"
    print("Training")
    while True:
      try:
        [op, summary_values, g_step] = sess.run([some_op, merged_summary, increament_global_step],
                                                feed_dict={iterator_handle: train_handle})
        writer.add_summary(summary_values, global_step=g_step)
        print(op)
      except tf.errors.OutOfRangeError:
        break

    # "Test evaluation"
    print("Testing")
    while True:
      try:
        print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
      except tf.errors.OutOfRangeError:
        break

    saver.save(sess, 'checkpoints_hb/fufu')

def resume_training(train_dataset, test_dataset):
  """Restore the model from file and pass some new data through it
     for further training """
  with tf.Session() as sess:
    saver = tf.train.import_meta_graph('checkpoints_hb/fufu.meta')
    saver.restore(sess, 'checkpoints_hb/fufu')
    iterator_handle = tf.get_collection('iterator_handle')[0]
    some_op = tf.get_collection('some_op')[0]
    increament_global_step = tf.get_collection('increament_global_step')[0]
    merged_summary = tf.get_collection('merged_summary')[0]

    writer = tf.summary.FileWriter('checkpoints_hb', graph=tf.get_default_graph())

    # Make new iterators and handles
    train_iter = make_initializable_iterator(train_dataset)
    test_iter = make_initializable_iterator(test_dataset)

    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())

    # Further training the model using new datasets (which may be different from the original ones)
    print("Resume training ...")

    # Run data iterator initialisation
    sess.run(train_iter.initializer)
    sess.run(test_iter.initializer)

    # "Training"
    print("Training")
    while True:
      try:
        [op, summary_values, g_step] = sess.run([some_op, merged_summary, increament_global_step],
                                                feed_dict={iterator_handle: train_handle})
        writer.add_summary(summary_values, global_step=g_step)
        print(op)
      except tf.errors.OutOfRangeError:
        break

    # "Test evaluation"
    print("Testing")
    while True:
      try:
        print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
      except tf.errors.OutOfRangeError:
        break

    saver.save(sess, 'checkpoints_hb/fufu')


def train_feed():
  # delete existing saved models and summary files
  if os.path.exists('checkpoints_hb'):
    shutil.rmtree('checkpoints_hb')
  # train_dataset = tf.data.Dataset.from_tensor_slices(
  #     tf.constant(np.random.randint(0, 100, (5, 2)), dtype=tf.float32))
  train_dataset = ParquetDataset('train.parquet',
                                 batch_size=1,
                                 fields=[DataFrame.Field('test1', tf.float32),
                                         DataFrame.Field('test2', tf.float32)])
  test_dataset = ParquetDataset('test.parquet',
                                batch_size=1,
                                fields=[DataFrame.Field('test1', tf.float32),
                                        DataFrame.Field('test2', tf.float32)])
  # test_dataset = tf.data.Dataset.from_tensor_slices(
  # tf.constant(np.random.randint(0, 100, (2, 2)), dtype=tf.float32))

  train(train_dataset, test_dataset)


def restore_feed():
  # Load and fine-tune the saved model using new data
  another_train_dataset = ParquetDataset(
      'train.parquet',
      batch_size=1,
      fields=[DataFrame.Field('test1', tf.float32),
              DataFrame.Field('test2', tf.float32)])
  another_test_dataset = ParquetDataset(
      'test.parquet', batch_size=1, fields=[DataFrame.Field('test1', tf.float32),
                                            DataFrame.Field('test2', tf.float32)])

  resume_training(another_train_dataset, another_test_dataset)


if __name__ == '__main__':
  train_feed()
  restore_feed()

It does not work with TFRecordDataset either:

import tensorflow as tf
import numpy as np
import pandas as pd
import os
import shutil
from tensorflow.python.data.ops import dataset_ops


def make_one_shot_iterator(ds):
  if hasattr(dataset_ops, 'make_one_shot_iterator'):
    return dataset_ops.make_one_shot_iterator(ds)
  return ds.make_one_shot_iterator()


def make_initializable_iterator(ds):
  if hasattr(dataset_ops, 'make_initializable_iterator'):
    return dataset_ops.make_initializable_iterator(ds)
  return ds.make_initializable_iterator()


# Define features
feature_description = {
    'test1': tf.io.FixedLenFeature([], dtype=tf.float32),
    'test2': tf.io.FixedLenFeature([], dtype=tf.float32)
}


def _parse_function(example_proto):
  return tf.io.parse_example(example_proto, feature_description)


def write_pb(df, file):
  # Write TFrecord file
  with tf.io.TFRecordWriter(file) as writer:
    for index, row in df.iterrows():
      print(row['test1'], row['test2'])
      # Create the Example
      example = tf.train.Example(features=tf.train.Features(
          feature={
              'test1': tf.train.Feature(float_list=tf.train.FloatList(value=[row['test1']])),
              'test2': tf.train.Feature(float_list=tf.train.FloatList(value=[row['test2']]))
          }))
      writer.write(example.SerializeToString())


new_dtypes = {"test1": np.float32, "test2": np.float32}

train_df = pd.DataFrame(np.random.randint(0, 100, (5, 2)), columns=['test1', 'test2'])
train_df = train_df.astype(new_dtypes)
write_pb(train_df, 'train.tfrecord')

test_df = pd.DataFrame(np.random.randint(0, 100, (2, 2)), columns=['test1', 'test2'])
test_df = test_df.astype(new_dtypes)
write_pb(test_df, 'test.tfrecord')


def train(train_dataset, test_dataset):
  """
  Create a graph with a Dataset and an Iterator, and save the model.

  An op is applied to the data produced by the iterator.
  """
  iterator_handle = tf.placeholder(tf.string, shape=[])
  tf.add_to_collection('iterator_handle', iterator_handle)

  iterator = tf.data.Iterator.from_string_handle(iterator_handle, dataset_ops.get_legacy_output_types(train_dataset),
                                                 dataset_ops.get_legacy_output_shapes(train_dataset),
                                                 dataset_ops.get_legacy_output_classes(train_dataset))
  train_iter = make_initializable_iterator(train_dataset)
  test_iter = make_initializable_iterator(test_dataset)
  element = iterator.get_next()

  v = tf.get_variable(name='v', initializer=tf.zeros(shape=(1, 2)))

  # to use when saving summaries
  global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
  increament_global_step = tf.assign(global_step, global_step + 1)
  global_step = global_step + 1
  tf.add_to_collection('increament_global_step', increament_global_step)

  some_op = tf.assign(v, v + tf.abs(element['test1']))
  tf.add_to_collection('some_op', tf.reduce_sum(some_op))

  tf.summary.scalar('v_sum', tf.reduce_sum(v))
  tf.summary.scalar('some_op', tf.reduce_mean(some_op))
  merged_summary = tf.summary.merge_all()
  tf.add_to_collection('merged_summary', merged_summary)

  writer = tf.summary.FileWriter('checkpoints_pb', graph=tf.get_default_graph())
  saver = tf.train.Saver()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())

    # Run data iterator initialisation
    sess.run(train_iter.initializer)
    sess.run(test_iter.initializer)

    # "Training"
    print("Training")
    while True:
      try:
        [op, summary_values, g_step] = sess.run([some_op, merged_summary, increament_global_step],
                                                feed_dict={iterator_handle: train_handle})
        writer.add_summary(summary_values, global_step=g_step)
        print(op)
      except tf.errors.OutOfRangeError:
        break

    # "Test evaluation"
    print("Testing")
    while True:
      try:
        print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
      except tf.errors.OutOfRangeError:
        break

    saver.save(sess, 'checkpoints_pb/fufu')


def resume_training(train_dataset, test_dataset):
  """Restore the model from file and pass some new data through it
     for further training """
  with tf.Session() as sess:
    saver = tf.train.import_meta_graph('checkpoints_pb/fufu.meta')
    saver.restore(sess, 'checkpoints_pb/fufu')
    iterator_handle = tf.get_collection('iterator_handle')[0]
    some_op = tf.get_collection('some_op')[0]
    increament_global_step = tf.get_collection('increament_global_step')[0]
    merged_summary = tf.get_collection('merged_summary')[0]

    writer = tf.summary.FileWriter('checkpoints_pb', graph=tf.get_default_graph())

    # Make new iterators and handles
    train_iter = make_initializable_iterator(train_dataset)
    test_iter = make_initializable_iterator(test_dataset)

    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())

    # Further training the model using new datasets (which may be different from the original ones)
    print("Resume training ...")

    # Run data iterator initialisation
    sess.run(train_iter.initializer)
    sess.run(test_iter.initializer)

    # "Training"
    print("Training")
    while True:
      try:
        [op, summary_values, g_step] = sess.run([some_op, merged_summary, increament_global_step],
                                                feed_dict={iterator_handle: train_handle})
        writer.add_summary(summary_values, global_step=g_step)
        print(op)
      except tf.errors.OutOfRangeError:
        break

    # "Test evaluation"
    print("Testing")
    while True:
      try:
        print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
      except tf.errors.OutOfRangeError:
        break

    saver.save(sess, 'checkpoints_pb/fufu')


def train_feed():
  # delete existing saved models and summary files
  if os.path.exists('checkpoints_pb'):
    shutil.rmtree('checkpoints_pb')
  # train_dataset = tf.data.Dataset.from_tensor_slices(
  #     tf.constant(np.random.randint(0, 100, (5, 2)), dtype=tf.float32))
  train_dataset = tf.data.TFRecordDataset(['train.tfrecord']).batch(1).map(_parse_function)
  test_dataset = tf.data.TFRecordDataset(['test.tfrecord']).batch(1).map(_parse_function)

  train(train_dataset, test_dataset)


def restore_feed():
  # Load and fine-tune the saved model using new data
  another_train_dataset = tf.data.TFRecordDataset(['train.tfrecord']).batch(1).map(_parse_function)
  another_test_dataset = tf.data.TFRecordDataset(['test.tfrecord']).batch(1).map(_parse_function)

  resume_training(another_train_dataset, another_test_dataset)


if __name__ == '__main__':
  train_feed()
  restore_feed()

But it works with CsvDataset:

import tensorflow as tf
import numpy as np
import pandas as pd
import os
import shutil
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.ops import dataset_ops

new_dtypes = {"test1": np.float32, "test2": np.float32}

train_df = pd.DataFrame(np.random.randint(0, 100, (5, 2)), columns=['test1', 'test2'])
train_df = train_df.astype(new_dtypes)
train_df.to_csv('train.csv', index=False)

test_df = pd.DataFrame(np.random.randint(0, 100, (2, 2)), columns=['test1', 'test2'])
test_df = test_df.astype(new_dtypes)
test_df.to_csv('test.csv', index=False)


def make_initializable_iterator(ds):
  if hasattr(dataset_ops, 'make_initializable_iterator'):
    return dataset_ops.make_initializable_iterator(ds)
  return ds.make_initializable_iterator()


def make_one_shot_iterator(ds):
  if hasattr(dataset_ops, 'make_one_shot_iterator'):
    return dataset_ops.make_one_shot_iterator(ds)
  return ds.make_one_shot_iterator()


def train(train_dataset, test_dataset):
  """
    Create graph with an Dataset and Iterator and save the model.

    There is some op that is applied to the data from the iterator.
    """
  iterator_handle = tf.placeholder(tf.string, shape=[])
  tf.add_to_collection('iterator_handle', iterator_handle)

  iterator = tf.data.Iterator.from_string_handle(iterator_handle, dataset_ops.get_legacy_output_types(train_dataset),
                                                 dataset_ops.get_legacy_output_shapes(train_dataset),
                                                 dataset_ops.get_legacy_output_classes(train_dataset))
  train_iter = make_initializable_iterator(train_dataset)
  test_iter = make_initializable_iterator(test_dataset)
  element = iterator.get_next()

  v = tf.get_variable(name='v', initializer=tf.zeros(shape=(1, 2)))

  # to use when saving summaries
  global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
  increament_global_step = tf.assign(global_step, global_step + 1)
  global_step = global_step + 1
  tf.add_to_collection('increament_global_step', increament_global_step)

  some_op = tf.assign(v, v + tf.abs(element))
  tf.add_to_collection('some_op', tf.reduce_sum(some_op))

  tf.summary.scalar('v_sum', tf.reduce_sum(v))
  tf.summary.scalar('some_op', tf.reduce_mean(some_op))
  merged_summary = tf.summary.merge_all()
  tf.add_to_collection('merged_summary', merged_summary)

  writer = tf.summary.FileWriter('checkpoints_csv', graph=tf.get_default_graph())
  saver = tf.train.Saver()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())

    # Run data iterator initialisation
    sess.run(train_iter.initializer)
    sess.run(test_iter.initializer)

    # "Training"
    print("Training")
    while True:
      try:
        [op, summary_values, g_step] = sess.run([some_op, merged_summary, increament_global_step],
                                                feed_dict={iterator_handle: train_handle})
        writer.add_summary(summary_values, global_step=g_step)
        print(op)
      except tf.errors.OutOfRangeError:
        break

    # "Test evaluation"
    print("Testing")
    while True:
      try:
        print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
      except tf.errors.OutOfRangeError:
        break

    saver.save(sess, 'checkpoints_csv/fufu')


def resume_training(train_dataset, test_dataset):
  """Restore the model from file and pass some new data through it
     for further training """
  with tf.Session() as sess:
    saver = tf.train.import_meta_graph('checkpoints_csv/fufu.meta')
    saver.restore(sess, 'checkpoints_csv/fufu')
    iterator_handle = tf.get_collection('iterator_handle')[0]
    some_op = tf.get_collection('some_op')[0]
    increament_global_step = tf.get_collection('increament_global_step')[0]
    merged_summary = tf.get_collection('merged_summary')[0]

    writer = tf.summary.FileWriter('checkpoints_csv', graph=tf.get_default_graph())

    # Make new iterators and handles
    train_iter = make_initializable_iterator(train_dataset)
    test_iter = make_initializable_iterator(test_dataset)

    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())

    # Further training the model using new datasets (which may be different from the original ones)
    print("Resume training ...")

    # Run data iterator initialisation
    sess.run(train_iter.initializer)
    sess.run(test_iter.initializer)

    # "Training"
    print("Training")
    while True:
      try:
        [op, summary_values, g_step] = sess.run([some_op, merged_summary, increament_global_step],
                                                feed_dict={iterator_handle: train_handle})
        writer.add_summary(summary_values, global_step=g_step)
        print(op)
      except tf.errors.OutOfRangeError:
        break

    # "Test evaluation"
    print("Testing")
    while True:
      try:
        print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
      except tf.errors.OutOfRangeError:
        break

    saver.save(sess, 'checkpoints_csv/fufu')


def train_feed():
  # delete existing saved models and summary files
  if os.path.exists('checkpoints_csv'):
    shutil.rmtree('checkpoints_csv')
  # train_dataset = tf.data.Dataset.from_tensor_slices(
  #     tf.constant(np.random.randint(0, 100, (5, 2)), dtype=tf.float32))
  train_dataset = readers.CsvDataset("train.csv", record_defaults=[tf.float32, tf.float32], header=True)
  test_dataset = readers.CsvDataset("test.csv", record_defaults=[tf.float32, tf.float32], header=True)
  # test_dataset = tf.data.Dataset.from_tensor_slices(
  # tf.constant(np.random.randint(0, 100, (2, 2)), dtype=tf.float32))

  train(train_dataset, test_dataset)


def restore_feed():
  # Load and fine-tune the saved model using new data
  another_train_dataset = readers.CsvDataset("train.csv", record_defaults=[tf.float32, tf.float32], header=True)
  another_test_dataset = readers.CsvDataset("test.csv", record_defaults=[tf.float32, tf.float32], header=True)

  resume_training(another_train_dataset, another_test_dataset)


if __name__ == '__main__':
  train_feed()
  restore_feed()

Willing to contribute

Yes

fuhailin · Apr 12 '22 07:04