
I am trying to modify the HF (Hugging Face) UNet for diffusion models by adding extra conditions during the down and up blocks. Below is a minimal example of the problem. It seems the last down_block is firing before the first one, yet nothing in the UNet model's source code suggests that the blocks should run in reverse order.

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers import UNet2DConditionModel

# config
SD_MODEL = "runwayml/stable-diffusion-v1-5"
DIM = 15

unet = UNet2DConditionModel.from_pretrained(SD_MODEL, subfolder="unet")

bs = 2
timestep = torch.randint(0, 100, (bs,))
noise = torch.randn((bs, 4, 64, 64))
text_encoding = torch.randn((bs, 77, 768))
condition = torch.randn((bs, DIM))

DownOutput = tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]

class ConditionResnet(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.call_count = 0
        self.projector = nn.Linear(in_dim, out_dim)
        self.conv1 = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1)
        self.non_linearity = F.silu
                        
    def forward(self, out: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        self.call_count += 1
        input_vector = out
        out = self.conv1(out) + self.projector(condition)[:, :, None, None]
        return input_vector + self.non_linearity(out)

# down blocks return (hidden_states, res_samples) tuples, so a slightly modified wrapper is needed
class ConditionResnetDown(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.condition_resnet = ConditionResnet(in_dim, out_dim)
        
    def forward(self, x: DownOutput, condition: torch.Tensor) -> DownOutput:
        return self.condition_resnet(x[0], condition), x[1]

class UNetWithConditions(nn.Module):
    def __init__(self, unet: nn.Module, col_channels: int, down_block_sizes: list[int], up_block_sizes: list[int]):
        super().__init__()
        self.unet = unet
        self.down_block_condition_resnets = nn.ModuleList([ConditionResnetDown(col_channels, out_channel) for out_channel in down_block_sizes])
        self.up_block_condition_resnets = nn.ModuleList([ConditionResnet(col_channels, out_channel) for out_channel in up_block_sizes])
        
        self.condition = None
        
        # forward hooks
        for i in range(len(self.unet.down_blocks)):
            self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.down_block_condition_resnets[i](outputs, self.condition))
        for i in range(len(self.unet.up_blocks)):
            self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.up_block_condition_resnets[i](outputs, self.condition))
        
    def forward(self, noise, timestep, text_encoding, condition):
        self.condition = condition
        out = self.unet(noise, timestep, text_encoding).sample
        self.condition = None
        return out

unet_with_conditions = UNetWithConditions(unet, DIM, [320, 640, 1280, 1280], [1280, 1280, 640, 320])
out2 = unet_with_conditions(noise, timestep, text_encoding, condition)

I know the last down_block is the one being fired because I inspect call_count on each ConditionResnet:

(
    [a.condition_resnet.call_count for a in unet_with_conditions.down_block_condition_resnets],
    [a.call_count for a in unet_with_conditions.up_block_condition_resnets],
)

and I get ([0, 0, 0, 1], [0, 0, 0, 0]).

Potential Causes

  • Could it be that the model was compiled (maybe with JIT), so the wrong hook is being fired? If I modify the model code above and re-run it, the crash still seems to point at the old code for some reason.
  • I'm not sure whether passing extra inputs into forward hooks is allowed.

Error Log:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~tmp/ipykernel_574/3305635741.py in <cell line: 2>()
      1 unet_with_conditions = UNetWithConditions(unet, DIM, [320, 640, 1280, 1280], [1280, 1280, 640, 320])
----> 2 out2 = unet_with_conditions(noise, timestep, text_encoding, condition)

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~tmp/ipykernel_574/2058376754.py in forward(self, noise, timestep, text_encoding, condition)
     59     def forward(self, noise, timestep, text_encoding, condition):
     60         self.condition = condition
---> 61         out = self.unet(noise, timestep, text_encoding).sample
     62         self.condition = None
     63         return out

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~nix/store/vzqny68wq33dcg4hkdala51n5vqhpnwc-python3-3.9.12/lib/python3.9/site-packages/diffusers/models/unet_2d_condition.py in forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, encoder_attention_mask, return_dict)
    795         for downsample_block in self.down_blocks:
    796             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
--> 797                 sample, res_samples = downsample_block(
    798                     hidden_states=sample,
    799                     temb=emb,

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1213         if _global_forward_hooks or self._forward_hooks:
   1214             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
-> 1215                 hook_result = hook(self, input, result)
   1216                 if hook_result is not None:
   1217                     result = hook_result

~tmp/ipykernel_574/2058376754.py in <lambda>(module, inputs, outputs)
     53         # forward hooks
     54         for i in range(len(self.unet.down_blocks)):
---> 55             self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.down_block_condition_resnets[i](outputs, self.condition))
     56         for i in range(len(self.unet.up_blocks)):
     57             self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.up_block_condition_resnets[i](outputs, self.condition))

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~tmp/ipykernel_574/2058376754.py in forward(self, x, condition)
     40 
     41     def forward(self, x: DownOutput, condition: torch.Tensor) -> DownOutput:
---> 42         return self.condition_resnet(x[0], condition), x[1]
     43 
     44 class UNetWithConditions(nn.Module):

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~tmp/ipykernel_574/2058376754.py in forward(self, out, condition)
     30         self.call_count += 1
     31         input_vector = out
---> 32         out = self.conv1(out) + self.projector(condition)[:, :, None, None]
     33         return input_vector + self.non_linearity(out)
     34 

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/conv.py in forward(self, input)
    461 
    462     def forward(self, input: Tensor) -> Tensor:
--> 463         return self._conv_forward(input, self.weight, self.bias)
    464 
    465 class Conv3d(_ConvNd):

~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    457                             weight, bias, self.stride,
    458                             _pair(0), self.dilation, self.groups)
--> 459         return F.conv2d(input, weight, bias, self.stride,
    460                         self.padding, self.dilation, self.groups)
    461 

RuntimeError: Given groups=1, weight of size [1280, 1280, 3, 3], expected input[2, 320, 32, 32] to have 1280 channels, but got 320 channels instead
sachinruk

1 Answer


You are registering the forward hooks with lambda functions defined inside the loops:

# forward hooks
for i in range(len(self.unet.down_blocks)):
    self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.down_block_condition_resnets[i](outputs, self.condition))
for i in range(len(self.unet.up_blocks)):
    self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.up_block_condition_resnets[i](outputs, self.condition))

The lambda functions capture the variable i by reference, not by value: when a lambda is eventually executed, it looks up the current value of i rather than the value it had when the lambda was created ("late binding").
Since i is incremented on each iteration, by the time the hooks actually fire, i holds its final value from the loop. That matches what you observe: every down-block hook calls down_block_condition_resnets[3] (the last one), which is why call_count reads ([0, 0, 0, 1], [0, 0, 0, 0]) and why the first down block, whose output has 320 channels, crashes inside a ConditionResnet built for 1280 channels.
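
Here is a minimal, standalone sketch of the same late-binding behaviour (nothing to do with the model code, just the closure semantics):

funcs = [lambda: i for i in range(4)]
print([f() for f in funcs])        # [3, 3, 3, 3] -- every lambda sees the final i

funcs_bound = [lambda i=i: i for i in range(4)]
print([f() for f in funcs_bound])  # [0, 1, 2, 3] -- the default argument freezes i per iteration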

To fix this, you can use a default argument in the lambda function to capture the current value of i in each iteration of the loop, like this:

# forward hooks
for i in range(len(self.unet.down_blocks)):
    self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs, i=i: self.down_block_condition_resnets[i](outputs, self.condition))
for i in range(len(self.unet.up_blocks)):
    self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs, i=i: self.up_block_condition_resnets[i](outputs, self.condition))

This captures the current value of i for each lambda function, ensuring that the correct index is used when the functions are called later.
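
An equivalent alternative is to pre-bind the index with functools.partial instead of a default argument. This is only a sketch of the registration loop inside __init__ (down_hook and up_hook are hypothetical helper names, not part of your code); it relies on the same fact your original approach does, namely that a forward hook returning a non-None value replaces the module's output:

import functools

def down_hook(idx, module, inputs, outputs):
    # idx is pre-bound by partial; PyTorch passes (module, inputs, outputs)
    return self.down_block_condition_resnets[idx](outputs, self.condition)

def up_hook(idx, module, inputs, outputs):
    return self.up_block_condition_resnets[idx](outputs, self.condition)

for i, block in enumerate(self.unet.down_blocks):
    block.register_forward_hook(functools.partial(down_hook, i))
for i, block in enumerate(self.unet.up_blocks):
    block.register_forward_hook(functools.partial(up_hook, i))

Either way, each hook ends up bound to its own index, so the ConditionResnet with the matching channel count is applied to each block's output.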

VonC