Help generating a .tflite file for CNN-Text-Classification

Open • skmalviya opened this issue 5 years ago • 0 comments

Can you help me out with the code at [cnn-text-classification-tf](https://github.com/dennybritz/cnn-text-classification-tf)? I am new to TensorFlow. I want to create a `.tflite` file for the model trained in `train.py`. As you mentioned in the video, the process starts with saving a checkpoint, saving its graph as a `.pbtxt` file, freezing it into a `.pb` file, and finally converting that to the `.tflite` file I want in the end. I am running on CPU with TensorFlow 1.13.1. I can generate both the `.pbtxt` and the `.pb` file for the very first checkpoint, but I get an error at the `tf.lite.TocoConverter.from_frozen_graph()` line of my code.
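
The four steps described above can be sketched end to end on a toy graph. This is a minimal sketch under stated assumptions, not the code from the video or the repository: the tiny linear model, the file names and the `export_tflite` helper are hypothetical, it freezes straight from the live session instead of reloading the checkpoint as a separate freezing script would, and it uses `tf.compat.v1` so it also runs on TensorFlow 2.x (on 1.13 the same calls exist directly under `tf.`). Its input and output node names, `input_x` and `output/predictions`, mirror the ones in the repository's `text_cnn.py`.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: tf.compat.v1 is available (TF 1.13+ and 2.x).
# On TF 1.x, `tf1 = tf` works too, since every call below exists under `tf.`.
tf1 = tf.compat.v1


def export_tflite(export_dir):
    """Hypothetical helper: checkpoint -> .pbtxt -> frozen .pb -> .tflite."""
    graph = tf.Graph()
    with graph.as_default():
        # Toy stand-in for TextCNN that reuses its node names.
        input_x = tf1.placeholder(tf.float32, [None, 4], name="input_x")
        with tf1.name_scope("output"):
            W = tf1.get_variable("W", shape=[4, 2])
            b = tf1.Variable(tf1.zeros([2]), name="b")
            scores = tf1.nn.xw_plus_b(input_x, W, b, name="scores")
            tf1.argmax(scores, 1, name="predictions")

        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())

            # Step 1: save a checkpoint of the variables.
            saver = tf1.train.Saver()
            saver.save(sess, os.path.join(export_dir, "model"), global_step=1)

            # Step 2: save the graph definition as text (.pbtxt).
            tf1.train.write_graph(sess.graph_def, export_dir, "savegraph.pbtxt")

            # Step 3: freeze, i.e. turn variables into constants and keep only
            # the ops needed for "output/predictions"; write a binary .pb.
            frozen = tf1.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, ["output/predictions"])
            pb_path = os.path.join(export_dir, "frozen_model.pb")
            with open(pb_path, "wb") as f:
                f.write(frozen.SerializeToString())

    # Step 4: convert the frozen graph to TensorFlow Lite.
    converter = tf1.lite.TFLiteConverter.from_frozen_graph(
        pb_path, ["input_x"], ["output/predictions"])
    tflite_path = os.path.join(export_dir, "TextCNN.tflite")
    with open(tflite_path, "wb") as f:
        f.write(converter.convert())
    return tflite_path


out = export_tflite(tempfile.mkdtemp())
print("Wrote", out)
```

`convert_variables_to_constants` keeps only the ops that the listed output nodes depend on, so the output names given to the freezing step and to the converter have to be actual op names in the graph.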

The relevant part of the training loop in my `train.py`:

```python
# Training loop. For each batch...
for batch in batches:
    x_batch, y_batch = zip(*batch)
    train_step(x_batch, y_batch)
    current_step = tf.train.global_step(sess, global_step)
    if current_step % FLAGS.evaluate_every == 0:
        print("\nEvaluation:")
        dev_step(x_dev, y_dev, writer=dev_summary_writer)
        print("")
    if current_step % FLAGS.checkpoint_every == 0:
        path = saver.save(sess, checkpoint_prefix, global_step=current_step)
        # Save the model's graph definition as text (.pbtxt)
        tf.train.write_graph(sess.graph_def, checkpoint_dir, 'savegraph.pbtxt')
        # freez_grph is my own helper (not shown); it writes the frozen .pb into checkpoint_dir
        freez_grph(checkpoint_dir)
        inp_node = ['input_x']
        out_node = ['output']
        # nodes = [e.name + '=>' + e.op for e in tf.get_default_graph().as_graph_def().node if e.op in ('Softmax', 'Placeholder')]
        # print(nodes)
        # converter = tf.lite.TFLiteConverter.from_session(sess, [cnn.embedded_chars_expanded], [cnn.input_y])
        # This is the line that raises the error:
        converter = tf.lite.TocoConverter.from_frozen_graph(
            checkpoint_dir + '/frozen_model_TextCNN Model.pb', inp_node, out_node)
        tflite_model = converter.convert()
        with open("TextCNN.tflite", "wb") as f:
            f.write(tflite_model)
        print("Saved model checkpoint to {}\n".format(path))
        exit()  # stop after the first export while testing the conversion
```
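
A likely cause of the error, hedged because the message itself is not shown: in the repository's `text_cnn.py`, `output` is a `tf.name_scope`, not an op, so the graph has no tensor called `output`. The ops created in that scope are `output/scores` and `output/predictions` (the names `eval.py` looks up), and the input placeholder is `input_x`. Also, `TocoConverter` is the deprecated name of `tf.lite.TFLiteConverter` in 1.13. The following sketch lists candidate node names in a frozen `.pb` (a working version of the commented-out `nodes = ...` idea) and converts with `TFLiteConverter.from_frozen_graph`. `list_nodes` and `convert_frozen` are hypothetical helpers, and the default names assume `freez_grph` kept `output/predictions` in the frozen graph.

```python
import tensorflow as tf

# Assumption: tf.compat.v1 is available; on TF 1.13, `tf1 = tf` also works.
tf1 = tf.compat.v1


def list_nodes(pb_path, ops=("Placeholder", "ArgMax", "Softmax")):
    """Return (name, op) pairs from a frozen .pb to help pick node names."""
    graph_def = tf1.GraphDef()
    with open(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return [(node.name, node.op) for node in graph_def.node if node.op in ops]


def convert_frozen(pb_path, tflite_path,
                   input_arrays=("input_x",),
                   output_arrays=("output/predictions",)):
    """Hypothetical helper: frozen .pb -> .tflite via TFLiteConverter.

    The default names are the ones text_cnn.py assigns; confirm them with
    list_nodes() first, since they depend on what the freezing step kept.
    """
    converter = tf1.lite.TFLiteConverter.from_frozen_graph(
        pb_path, list(input_arrays), list(output_arrays))
    with open(tflite_path, "wb") as f:  # the file is closed after writing
        f.write(converter.convert())
    return tflite_path
```

In the training loop this would stand in for the `TocoConverter` call, e.g. `convert_frozen(checkpoint_dir + '/frozen_model_TextCNN Model.pb', 'TextCNN.tflite')`. One more thing to check: in `text_cnn.py`, dropout is applied before the output layer, so `output/predictions` also depends on the `dropout_keep_prob` placeholder. If `list_nodes` shows it in the frozen graph, it will need handling as well, for example by building an inference copy of the graph without dropout, restoring the checkpoint into it, and freezing that. Likewise, the commented-out `from_session` attempt passes `cnn.input_y`, which is the label placeholder, as the output; the analogous call would pass `[cnn.input_x]` and `[cnn.predictions]`.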

skmalviya, Mar 06 '20 12:03