PuLID

Flux model is fully created before the bfloat16 cast

Open · PRPA1984 opened this issue on Sep 13 '24 · 1 comment

When loading the Flux model, the entire model is created (in the default float32 precision) before it is cast to bfloat16:

```python
model = Flux(configs[name].params).to(torch.bfloat16)
```

The problem is that model creation uses a lot of RAM, because every submodule is allocated and initialized with random float32 parameters before the whole model is cast. In the meantime, I worked around this by casting each submodule to bfloat16 as soon as it is created.
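
A rough back-of-the-envelope sketch of why the order of creation and casting matters. It assumes every parameter starts as float32 (4 bytes) and ends as bfloat16 (2 bytes), and it ignores buffers and the short-lived copies made during each cast. The function names and the block sizes in the demo are hypothetical illustration values, not Flux's real dimensions.

```python
FP32_BYTES = 4  # PyTorch's default dtype for freshly initialized weights
BF16_BYTES = 2


def peak_cast_at_end(num_blocks: int, params_per_block: int) -> int:
    """Peak bytes when every block is built in float32 and the model is cast once."""
    # Right before the final .to(torch.bfloat16), all blocks coexist in float32.
    return num_blocks * params_per_block * FP32_BYTES


def peak_cast_per_block(num_blocks: int, params_per_block: int) -> int:
    """Peak bytes when each block is cast to bfloat16 right after it is built."""
    # Worst moment: earlier blocks are already bfloat16, while the newest
    # block still holds float32 weights just before its own cast.
    already_cast = (num_blocks - 1) * params_per_block * BF16_BYTES
    newest_block = params_per_block * FP32_BYTES
    return already_cast + newest_block


if __name__ == "__main__":
    blocks, per_block = 50, 200_000_000  # hypothetical sizes
    gib = 1024**3
    print(f"cast at end:    {peak_cast_at_end(blocks, per_block) / gib:.1f} GiB")
    print(f"cast per block: {peak_cast_per_block(blocks, per_block) / gib:.1f} GiB")
```

Under these assumptions, with many blocks the per-block variant approaches half of the float32 peak, which is why casting each submodule as it is created reduces RAM usage during loading.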

PRPA1984 · Sep 13 '24 19:09

For anyone with the same problem: add `.to(torch.bfloat16)` to each block constructor in `flux/model.py`, here:

```python
self.double_blocks = nn.ModuleList(
    [
        DoubleStreamBlock(
            self.hidden_size,
            self.num_heads,
            mlp_ratio=params.mlp_ratio,
            qkv_bias=params.qkv_bias,
        ).to(torch.bfloat16)  # cast each block as soon as it is built
        for _ in range(params.depth)
    ]
)
```

and here:

```python
self.single_blocks = nn.ModuleList(
    [
        SingleStreamBlock(
            self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
        ).to(torch.bfloat16)  # cast each block as soon as it is built
        for _ in range(params.depth_single_blocks)
    ]
)
```
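
The same per-block pattern can be factored into a small helper. This is a minimal sketch, not code from the PuLID or Flux repositories: `build_blocks` and `ToyBlock` are hypothetical names, and `ToyBlock` only stands in for `DoubleStreamBlock`/`SingleStreamBlock`. Each block is cast right after construction, so at most one block holds float32 weights at any time.

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for DoubleStreamBlock / SingleStreamBlock."""

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        mlp_hidden = int(hidden_size * mlp_ratio)
        # nn.Linear allocates and randomly initializes float32 weights by default.
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(x)


def build_blocks(make_block, depth: int, dtype: torch.dtype = torch.bfloat16) -> nn.ModuleList:
    """Create `depth` blocks, casting each one before the next is allocated."""
    blocks = []
    for _ in range(depth):
        # Module.to() converts the block's parameters in place, so its float32
        # weights can be released before the next block is created.
        blocks.append(make_block().to(dtype))
    return nn.ModuleList(blocks)


blocks = build_blocks(lambda: ToyBlock(hidden_size=64), depth=4)
```

Inside a module's `__init__`, a call such as `build_blocks(lambda: DoubleStreamBlock(...), params.depth)` would be equivalent to the list comprehensions above; the explicit loop just makes the "create, then cast" order visible.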

denred0 · Sep 15 '24 03:09