EasyTransfer

Error when testing the hcnn model with TensorFlow 1.12.3

Open · zhangyunGit opened this issue 4 years ago · 1 comment

The test command is as follows:

    easy_transfer_app --mode=train \
        --inputSchema=query:str:1,doc:str:1,label:str:1 \
        --inputTable=./train_lcqmc.csv,.dev_lcqmc.csv \
        --firstSequence=query \
        --secondSequence=doc \
        --labelName=label \
        --labelEnumerateValues=0,1 \
        --batchSize=32 \
        --numEpochs=1 \
        --optimizerType=adam \
        --learningRate=0.001 \
        --modelName=text_match_hcnn \
        --checkpointDir=./hcnn_match_models \
        --advancedParameters='first_sequence_length=40 second_sequence_length=40 pretrain_word_embedding_name_or_path=./sgns.zhihu.char.300.bin fix_embedding=true max_vocab_size=30000 embedding_size=300 hidden_size=300'

The error message is as follows:

    INFO:tensorflow:Initialize word embedding from pretrained
    Traceback (most recent call last):
      File "/usr/local/anaconda3/envs/tf12.3/bin/easy_transfer_app", line 8, in <module>
        sys.exit(main())
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/app_zoo_cli.py", line 99, in main
        app.run()
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/app_zoo/app_utils.py", line 168, in wrapper
        func(*args, **kw)
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/app_zoo/base.py", line 44, in run
        getattr(self, self.config.mode.replace("_on_the_fly", ""))()
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/app_zoo/base.py", line 113, in train_and_evaluate
        self.run_train_and_evaluate(train_reader=train_reader, eval_reader=eval_reader)
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/engines/model.py", line 608, in run_train_and_evaluate
        eval_spec=eval_spec)
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 471, in train_and_evaluate
        return executor.run()
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 610, in run
        return self.run_local()
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 711, in run_local
        saving_listeners=saving_listeners)
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 354, in train
        loss = self._train_model(input_fn, hooks, saving_listeners)
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1207, in _train_model
        return self._train_model_default(input_fn, hooks, saving_listeners)
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1237, in _train_model_default
        features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1195, in _call_model_fn
        model_fn_results = self._model_fn(features=features, **kwargs)
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/engines/model.py", line 530, in model_fn
        logits, labels = self.build_logits(features, mode=mode)
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/app_zoo/text_match.py", line 618, in build_logits
        filter_size=self.config.filter_size)([a_embeds, b_embeds, text_a_masks, text_b_masks])
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 374, in __call__
        outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 757, in __call__
        outputs = self.call(inputs, *args, **kwargs)
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/layers/cnn.py", line 276, in call
        (a_length / 4 / 3 / 2) * (b_length / 4 / 3 / 2)])
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 6482, in reshape
        "Reshape", tensor=tensor, shape=shape, name=name)
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 609, in _apply_op_helper
        param_name=input_name)
      File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
        ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
    TypeError: Value passed to parameter 'shape' has DataType float32 not in list of allowed values: int32, int64

The above was run after installing the officially specified version, tf 1.12.3.

zhangyunGit · May 17 '21 12:05

The label column should be an int; change the schema to: label:int:1

minghui · Jun 10 '21 06:06
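For reference, the last EasyTransfer frame in the traceback (`easytransfer/layers/cnn.py`, line 276) passes a reshape size computed with `/`. In Python 3, which this environment runs (python3.6), `/` is true division and returns a float even when both operands are ints, and a shape list containing Python floats is converted to a float32 tensor, which is consistent with the `TypeError` about `shape`. The minimal sketch below reproduces only that arithmetic in plain Python, without TensorFlow. It assumes, since the source does not show it, that `a_length` and `b_length` are Python ints equal to `first_sequence_length=40` and `second_sequence_length=40` from the command.

```python
# Hedged sketch: plain-Python arithmetic mirroring the reshape size expression
# in easytransfer/layers/cnn.py line 276. ASSUMPTION: a_length and b_length are
# Python ints taken from first_sequence_length / second_sequence_length (40).
a_length = 40
b_length = 40

# Python 3 true division: the result is always a float, even for int operands.
true_div = (a_length / 4 / 3 / 2) * (b_length / 4 / 3 / 2)

# Floor division keeps the result an int (what Python 2's `/` did for ints).
floor_div = (a_length // 4 // 3 // 2) * (b_length // 4 // 3 // 2)

print(type(true_div).__name__, true_div)
print(type(floor_div).__name__, floor_div)
```

With these values floor division yields 1, because 40 is not a multiple of 4 × 3 × 2 = 24. Whether `//` would match the tensor's real size depends on the pooling layers before the reshape, which the traceback does not show.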