PSPNet-tensorflow

Training with a different number of classes

Open anny123123 opened this issue 8 years ago • 5 comments

Thanks for this great repository! I am trying to train the model on my own dataset, which has only 5 labels. When I try to train from the checkpoints you provided on Google Drive, I get the following error:

File "C:\Users\mffarber\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1470, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [1,1,512,5] rhs shape= [1,1,512,19] [[Node: save_1/Assign_556 = Assign[T=DT_FLOAT, _class=["loc:@conv6/weights"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](conv6/weights, save_1/RestoreV2_556)]]

I assume this is because the checkpoint corresponds to a model that was trained on 19 classes. Are there other pretrained weights that I can use, or do I have to train from scratch?

anny123123, Jan 12 '18 02:01
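
The mismatch reported in the traceback above can be located before restoring by comparing shapes name by name. Below is a minimal sketch that uses plain dictionaries in place of a real graph and checkpoint, so it runs without TensorFlow. The helper `find_shape_mismatches` and the `backbone/weights` entry are hypothetical; only the two `conv6/weights` shapes are taken from the error message (lhs is the variable in the model, rhs is the tensor in the checkpoint).

```python
def find_shape_mismatches(model_shapes, checkpoint_shapes):
    """Return {name: (model_shape, checkpoint_shape)} for every variable
    that exists in both mappings but with different shapes."""
    mismatches = {}
    for name, model_shape in model_shapes.items():
        ckpt_shape = checkpoint_shapes.get(name)
        if ckpt_shape is not None and tuple(ckpt_shape) != tuple(model_shape):
            mismatches[name] = (tuple(model_shape), tuple(ckpt_shape))
    return mismatches


# conv6/weights shapes come from the error message; "backbone/weights" is a
# made-up placeholder for a layer whose shape does not depend on the classes.
model_shapes = {
    "backbone/weights": (3, 3, 64, 64),
    "conv6/weights": (1, 1, 512, 5),       # model built for 5 classes
}
checkpoint_shapes = {
    "backbone/weights": (3, 3, 64, 64),
    "conv6/weights": (1, 1, 512, 19),      # checkpoint trained on 19 classes
}

mismatches = find_shape_mismatches(model_shapes, checkpoint_shapes)
for name, (model_shape, ckpt_shape) in mismatches.items():
    print(f"{name}: model {model_shape} vs checkpoint {ckpt_shape}")
```

Only the final classification layer shows up as a mismatch here, which suggests the rest of the checkpoint could still be reused as long as that layer is excluded from the restore.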

I have the same problem.

FeiWard, Jan 18 '18 02:01

Maybe try running python train.py --not-restore-last

lijia-xing, Jan 22 '18 07:01

Yes, just as @lijia-xing mentioned. You could also change the following line in train.py so that the conv6 variables are skipped when restoring:

restore_var = [v for v in tf.global_variables() if 'conv6' not in v.name]

hellochick, Jan 22 '18 10:01
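
To see what the filter in the comment above does, here is a minimal sketch that applies the same list comprehension to stand-in objects rather than real TensorFlow variables, so it runs without TensorFlow installed. `FakeVariable`, `select_restorable`, and the `backbone/weights:0` name are hypothetical; the `conv6` names follow the ones that appear in the error messages in this thread.

```python
from collections import namedtuple

# Stand-in for a TensorFlow variable: the filter only looks at `.name`.
FakeVariable = namedtuple("FakeVariable", ["name"])


def select_restorable(variables, skip_substring="conv6"):
    """Keep every variable whose name does not contain `skip_substring`."""
    return [v for v in variables if skip_substring not in v.name]


all_vars = [
    FakeVariable("backbone/weights:0"),   # hypothetical shared layer
    FakeVariable("conv6/weights:0"),      # class-dependent, shape differs
    FakeVariable("conv6/biases:0"),       # class-dependent, shape differs
]

restore_var = select_restorable(all_vars)
print([v.name for v in restore_var])

# Assumption: in TensorFlow 1.x a list like this is typically handed to
# tf.train.Saver(var_list=restore_var), so only those variables are loaded
# and the skipped conv6 layer keeps its fresh initialization.
```

Because the match is a plain substring test, any other variable whose name happens to contain "conv6" would be skipped as well, so it is worth printing the filtered names once to confirm that only the final layer is excluded.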

I would like to ask about the steps for training this model on a different dataset. What do I need to change, and how do I start training? I am still a beginner, so I would appreciate any help. Thanks @hellochick @lijia-xing @anny123123

shadydiaa, Apr 23 '18 11:04

Hello @hellochick, thank you for your solution for train.py. I am wondering how we should change evaluate.py. I got the error "Assign requires shapes of both tensors to match. lhs shape= [1] rhs shape= [2] ; [Node: save/Assign_300 = Assign[T=DT_FLOAT, _class=["loc:@conv6/biases"]", and I believe this error is caused by the change made in train.py. Thank you.

qian49, Nov 15 '18 08:11
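
Read the same way as the first traceback in this thread (lhs is the variable in the graph being built, rhs is the tensor stored in the checkpoint), this error suggests that the evaluation graph creates conv6 for 1 class while the checkpoint holds 2. A minimal sketch of a sanity check that could be run before restoring is shown below. `check_num_classes` is a hypothetical helper, and how evaluate.py actually sets its class count is not shown in this thread.

```python
def check_num_classes(configured_num_classes, checkpoint_bias_shape):
    """Compare the class count used to build the graph with the one implied
    by the checkpoint's final-layer bias shape (its last dimension)."""
    checkpoint_num_classes = checkpoint_bias_shape[-1]
    if checkpoint_num_classes != configured_num_classes:
        raise ValueError(
            "graph is built for %d classes but the checkpoint was trained "
            "with %d" % (configured_num_classes, checkpoint_num_classes)
        )
    return checkpoint_num_classes


# Shapes from the error message: graph bias is [1], checkpoint bias is [2].
try:
    check_num_classes(1, (2,))
except ValueError as err:
    print("mismatch:", err)

# With the evaluation graph built for the same 2 classes the check passes.
print(check_num_classes(2, (2,)))
```

If the check fails, the likelier fix is to build the evaluation graph with the same number of classes that was used for training, rather than to skip conv6 again, since evaluation needs the trained final layer.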