[textual_inversion] Add an option for only saving the embeddings
The textual inversion example now saves by default only the learned embeddings, unless the command line arguments --save_full_model or --push_to_hub are used. (Implements #759)
The main change request is to add documentation, isn't it? I'm willing to do it; I think I would just add a sentence under the example command line.
I should have used more feature branches, but what I could contribute next:
- Load textual embeddings: These are also ~5 lines that need to be included in the example in the documentation, so the saved embedding makes sense there
- Resuming training (I just load the last saved embedding if it exists)
- Caching images. The current script seems to load, scale and crop the images in each `__getitem__` call. I have some code that caches the images that were used before.
- UNet attention slicing. It doesn't seem to help too much, but maybe a bit?
- Loading template strings from a file. For learning drawn images I added my own set of strings, which, for example, avoids the term "photo".
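The template-strings-from-a-file idea could look roughly like this; the file format (one template per line, with `{}` marking where the placeholder token goes) is my assumption, not code from the PR:

```python
# Hypothetical sketch: load prompt templates from a text file instead of the
# hard-coded template list, so one can avoid terms like "photo" for drawings.
def load_templates(path):
    """Read prompt templates from a text file, skipping blank lines."""
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]

def fill_templates(templates, placeholder_token):
    """Insert the placeholder token into each template's "{}" slot."""
    return [t.format(placeholder_token) for t in templates]
```

A templates file might then contain lines like `a drawing of {}` or `a sketch in the style of {}`.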
> The main change request is to add documentation, isn't it? I'm willing to do it, I think I would just add a sentence under the example command line.
Yes, and also a snippet that shows how to load the embedding in the model and do inference. Then this should be good to merge.
And that's an awesome list of features, looking forward to it :)
This may get a bit long for the documentation, maybe it should better be a second example script.
The main part is simple:
```python
# Load the learned embeds
learned_embeds_dict = torch.load("textual_inversion_cat/learned_embeds.bin")
learned_embeds = learned_embeds_dict[placeholder_token]

# Add the embeds to the text encoder (indexing by token id, not the token string)
token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
dtype = text_encoder.get_input_embeddings().weight.dtype
text_encoder.get_input_embeddings().weight.data[token_id] = learned_embeds.to(dtype)
```
But it depends on the resized text_encoder from the script and also needs a reference to tokenizer, which would add quite a few more lines.
What do you think about an examples/textual_inversion/inference.py script and adding --save_full_model to the example in the documentation with a note that one can also load only the embeds?
I wonder if loading the custom embeds could also become part of the pipeline, but for the pipeline it would probably be an API change that needs quite a bit of discussion first.
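A possible shape for such an `examples/textual_inversion/inference.py` is sketched below; the helper name is made up, and the commented pipeline usage (model id, token) is illustrative, not code from the PR:

```python
import torch

def load_learned_embed(text_encoder, tokenizer, embeds_path, token=None):
    """Add a learned token embedding to a text encoder (sketch)."""
    learned_embeds_dict = torch.load(embeds_path, map_location="cpu")
    # The saved .bin maps the placeholder token string to its embedding tensor
    trained_token = token or next(iter(learned_embeds_dict))
    embed = learned_embeds_dict[trained_token]

    # Register the token and grow the embedding matrix to make room for it
    tokenizer.add_tokens(trained_token)
    text_encoder.resize_token_embeddings(len(tokenizer))

    token_id = tokenizer.convert_tokens_to_ids(trained_token)
    dtype = text_encoder.get_input_embeddings().weight.dtype
    text_encoder.get_input_embeddings().weight.data[token_id] = embed.to(dtype)
    return trained_token

# Possible usage with a diffusers pipeline (not run here):
# from diffusers import StableDiffusionPipeline
# pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
# load_learned_embed(pipe.text_encoder, pipe.tokenizer,
#                    "textual_inversion_cat/learned_embeds.bin")
# image = pipe("a photo of <cat-toy>").images[0]
```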
Looks like the latest changes change quite a bit. I wonder if I should push the full script into my repo so you could have a look at the changes there.
Adding them bit by bit would be good, but I guess it gets more complicated when the script in the main repo is changing a lot while I wait for three PRs to be merged one after another.
Gently ping here @patil-suraj
Hey @allo- good point! We could add a script like `load_embeds_in_the_model.py`, which takes the path to the embeds and the model, loads the embeds and saves the model.
So let's:
- Add an option called `--only_save_embeds`; if it's `True` we will only save the trained embeds and skip saving the full model.
- Add the script that loads the embeds into the model.
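The proposed flag could be wired up with a few lines of argparse; the help text and default are my assumptions:

```python
import argparse

# Sketch of the proposed command line option (not the PR's actual code)
parser = argparse.ArgumentParser()
parser.add_argument(
    "--only_save_embeds",
    action="store_true",
    default=False,
    help="Save only the learned embeddings instead of a full model checkpoint.",
)

# At the end of training, the script would then branch on args.only_save_embeds:
# save just the embedding dict, or call pipeline.save_pretrained(args.output_dir).
```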
Then it should be good to merge :) And let me know if you are busy, happy to take care of this :)
Let's continue next week.
In my full script the 3-4 changes are combined; the snippet is from the load method, but it's called after resizing the tokenizer for training. The cache for the resized images is quite independent of it, but for merging it needs to be (re)based on the recent commits, and I still need to look at what you changed in the meantime.
Gently pinging @allo- , let me know if I can help here :)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
I think that's it for a first minimal step for just preventing saving the full model.
For the next steps: What do you think about automatically loading existing `learned_embeds.bin` files in the script? It requires only small changes in the training script (I load them after the tokenizer has already been resized) and would also be an example for writing one's own scripts.
Adding another example script that loads a model and then the embeddings would be the step after this one, but adding it to the training script first is probably the most logical order if one wants to split this into smaller incremental commits.
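The auto-resume idea could be sketched like this; the function name and call site are assumptions (the real script would call something like it after the tokenizer has been resized and the placeholder token added):

```python
import os
import torch

def maybe_resume(text_encoder, output_dir, placeholder_token, placeholder_token_id):
    """If a previous run left a learned_embeds.bin, load it back in (sketch)."""
    path = os.path.join(output_dir, "learned_embeds.bin")
    if not os.path.exists(path):
        return False
    saved = torch.load(path, map_location="cpu")
    with torch.no_grad():
        weight = text_encoder.get_input_embeddings().weight
        weight.data[placeholder_token_id] = saved[placeholder_token].to(weight.dtype)
    return True
```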
cc @patil-suraj for final review
@patrickvonplaten @patil-suraj Anything else necessary to get this merged?
Rebased onto current main instead of merging a branch for a single commit.
Let me see if I can resurrect the last version before merging the latest upstream or put back things from my local copy.
I did merge it against latest upstream. The argument got lost when --revision was added.
A few questions regarding the diff to current upstream:
- How should the imports be organized? I wonder why there is no explicit SafetyChecker import anymore.
- Should the scheduler/safety_checker/... arguments be explicit? I am not sure in which commits they were added/removed and what the plan is for how it should look.
- I changed the local variable `only_save_embds` to `save_full_model` with inverted logic, so the confusion with the command line argument is resolved; `save_full_model` can be renamed while `args.only_save_embeds` stays fixed.
- General question concerning the pipeline: Does it even make sense to have the SafetyChecker there? No image is generated, and the safety of the user input depends on the user.
In practice, I can train 512x512 samples with 12 GB VRAM (on a desktop computer that also needs some VRAM for the UI) without the SafetyChecker, and I need to reduce to e.g. 400x400 when loading the checker (I didn't test where the exact limit is). It may be a good idea not to load it.
The red CircleCI check is unrelated - thanks a lot for making the changes @allo- -> merging!