# Batched text-to-image generation is really slow with the pipeline
### Describe the bug

Batched diffusers pipeline inference is really slow.
### Reproduction

```python
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class PromptData(Dataset):
    """Reads one prompt per line from a text file."""

    def __init__(self, prompt_file):
        with open(prompt_file, "r") as file:
            self.txt_prompts = file.readlines()

    def __len__(self):
        return len(self.txt_prompts)

    def __getitem__(self, idx):
        return self.txt_prompts[idx].rstrip()


prompt_dataset = PromptData(args.prompt_file)
dataloader = DataLoader(
    prompt_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=6,
)
# Repeat the negative prompt so it matches the batch size.
args.negative_prompt = [args.negative_prompt] * args.batch_size

images_list = []
for i_batch, prompts_batched in tqdm(enumerate(dataloader)):
    torch.cuda.empty_cache()
    with torch.no_grad():
        image = pipeline(
            prompts_batched,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            num_images_per_prompt=1,
            generator=generator,
            negative_prompt=args.negative_prompt,
            height=args.height,
            width=args.width,
            **args.pipe_add_kwargs,
        )
    print("----------batch generated----------")
    images = image[0]
    images_list.extend(images)
# saving images ...
```
### Logs

_No response_
### System Info

- diffusers: 0.27
- torch: 2.0.1
### Who can help?

_No response_
The reproduction code is not really in a shape where we can take a look at it and provide you with feedback. Please format it properly, and also provide more information about your system, etc.
Hi @sayakpaul, apologies for the formatting. Can you PTAL now? `prompts_batched` is a list of length 32.
The reproduction code should be more minimal. Could we eliminate the class structure and use a simpler reproduction?
FWIW, if inference on a single image works fine and this only happens when you do batch inference, you're probably hitting your VRAM limit.
Also, it's kind of a given that batches will take more time, so what you will need to give us is:

- The amount of VRAM you have
- The model you are using
- The batch size you are using
- What "really slow" means, with empirical data

Right now we don't know any of this and we can't really help you; you're providing code with a lot of unknowns.
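To put empirical numbers on "really slow", a small timing helper could wrap each pipeline call in the loop above. This is just a sketch; `timed_call` and `per_image_latency` are hypothetical helper names, not part of diffusers:

```python
import time


def timed_call(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def per_image_latency(elapsed_seconds, batch_size):
    """Average seconds spent per generated image in one batch."""
    return elapsed_seconds / batch_size
```

For example, `image, elapsed = timed_call(pipeline, prompts_batched, ...)` inside the loop, then printing `per_image_latency(elapsed, len(prompts_batched))` per batch, would show whether per-image time actually degrades as the batch size grows.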
Hi @asomoza, I checked: it never hits the VRAM limit. I'm using plain SD 1.5, and it just gets stuck.
@asomoza @sayakpaul The reason for providing the Dataset and DataLoader is to show how the data is read from a text file and batched before being fed into the SD 1.5 pipeline.
IMO that can probably be replicated with something much simpler:

```python
prompt = "a dog"
batched_prompts = [prompt] * batch_size
```
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Closing due to months of inactivity by the author. Please feel free to re-open the issue if it still persists.