I am trying to modify Hugging Face's UNet for diffusion models (UNet2DConditionModel from diffusers) by injecting an extra condition after each down and up block. Below is a minimal example of the problem: it looks as if the last down_block's hook fires before the first one's, and nothing in the UNet model's source code suggests the blocks should run in reverse order.
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import UNet2DConditionModel

# config
SD_MODEL = "runwayml/stable-diffusion-v1-5"
DIM = 15

unet = UNet2DConditionModel.from_pretrained(SD_MODEL, subfolder="unet")

bs = 2
timestep = torch.randint(0, 100, (bs,))
noise = torch.randn((bs, 4, 64, 64))
text_encoding = torch.randn((bs, 77, 768))
condition = torch.randn((bs, DIM))

# shape returned by each down block: (hidden_states, residual_states)
DownOutput = tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]

class ConditionResnet(nn.Module):
    """Residual block that adds a projected condition vector to a feature map."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.call_count = 0  # incremented on every forward call, for debugging
        self.projector = nn.Linear(in_dim, out_dim)
        self.conv1 = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1)
        self.non_linearity = F.silu

    def forward(self, out: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        self.call_count += 1
        input_vector = out
        out = self.conv1(out) + self.projector(condition)[:, :, None, None]
        return input_vector + self.non_linearity(out)

# down blocks return (hidden_states, residuals) tuples, so they need a
# slightly modified wrapper that only transforms hidden_states
class ConditionResnetDown(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.condition_resnet = ConditionResnet(in_dim, out_dim)

    def forward(self, x: DownOutput, condition: torch.Tensor) -> DownOutput:
        return self.condition_resnet(x[0], condition), x[1]

class UNetWithConditions(nn.Module):
    def __init__(self, unet: nn.Module, col_channels: int, down_block_sizes: list[int], up_block_sizes: list[int]):
        super().__init__()
        self.unet = unet
        self.down_block_condition_resnets = nn.ModuleList([ConditionResnetDown(col_channels, out_channel) for out_channel in down_block_sizes])
        self.up_block_condition_resnets = nn.ModuleList([ConditionResnet(col_channels, out_channel) for out_channel in up_block_sizes])
        self.condition = None
        # forward hooks
        for i in range(len(self.unet.down_blocks)):
            self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.down_block_condition_resnets[i](outputs, self.condition))
        for i in range(len(self.unet.up_blocks)):
            self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.up_block_condition_resnets[i](outputs, self.condition))

    def forward(self, noise, timestep, text_encoding, condition):
        self.condition = condition
        out = self.unet(noise, timestep, text_encoding).sample
        self.condition = None
        return out

unet_with_conditions = UNetWithConditions(unet, DIM, [320, 640, 1280, 1280], [1280, 1280, 640, 320])
out2 = unet_with_conditions(noise, timestep, text_encoding, condition)
I know it is the last down_block's condition module that fired because I inspect the call_count of each ConditionResnet:

(
    [a.condition_resnet.call_count for a in unet_with_conditions.down_block_condition_resnets],
    [a.call_count for a in unet_with_conditions.up_block_condition_resnets],
)

which gives ([0, 0, 0, 1], [0, 0, 0, 0]): only the condition module for the last down block was ever called, exactly once, before the crash.
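One way I can think of to narrow this down (a sketch I have not run against the full model; it assumes a freshly loaded unet so these are the only hooks registered) is to attach pure logging hooks to each down block, separating "which block executed" from "which condition module ran". Binding the label with functools.partial keeps the hook from touching the block's output:

from functools import partial

def log_hook(name, module, inputs, output):
    # returning None leaves the block's output unchanged
    print(f"{name} fired")

for i, block in enumerate(unet.down_blocks):
    block.register_forward_hook(partial(log_hook, f"down_block[{i}]"))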
Potential Causes
- Is it possible the model was compiled somehow (maybe with JIT), so that the wrong hook fires? Oddly, if I modify the model code above and re-execute it, the crash still seems to reference the old code.
- I am not sure whether a forward hook is even allowed to call other modules with extra inputs (here self.condition) and replace the block's output; see the sanity-check sketch after this list.
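For the second point, this standalone sanity check (a toy Linear stack, nothing diffusers-specific) behaves the way I would expect from the torch.nn.Module.register_forward_hook docs: hooks fire in execution order, and a non-None return value replaces the module's output:

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
fired = []

def hook(module, inputs, output):
    fired.append(module)  # record which module just ran
    return output + 1.0   # a non-None return value replaces the output

for layer in layers:
    layer.register_forward_hook(hook)

x = torch.randn(2, 4)
for layer in layers:
    x = layer(x)

print([f is l for f, l in zip(fired, layers)])  # [True, True, True]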
Error Log:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_574/3305635741.py in <cell line: 2>()
1 unet_with_conditions = UNetWithConditions(unet, DIM, [320, 640, 1280, 1280], [1280, 1280, 640, 320])
----> 2 out2 = unet_with_conditions(noise, timestep, text_encoding, condition)
/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/tmp/ipykernel_574/2058376754.py in forward(self, noise, timestep, text_encoding, condition)
59 def forward(self, noise, timestep, text_encoding, condition):
60 self.condition = condition
---> 61 out = self.unet(noise, timestep, text_encoding).sample
62 self.condition = None
63 return out
/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/nix/store/vzqny68wq33dcg4hkdala51n5vqhpnwc-python3-3.9.12/lib/python3.9/site-packages/diffusers/models/unet_2d_condition.py in forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, encoder_attention_mask, return_dict)
795 for downsample_block in self.down_blocks:
796 if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
--> 797 sample, res_samples = downsample_block(
798 hidden_states=sample,
799 temb=emb,
/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1213 if _global_forward_hooks or self._forward_hooks:
1214 for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
-> 1215 hook_result = hook(self, input, result)
1216 if hook_result is not None:
1217 result = hook_result
/tmp/ipykernel_574/2058376754.py in <lambda>(module, inputs, outputs)
53 # forward hooks
54 for i in range(len(self.unet.down_blocks)):
---> 55 self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.down_block_condition_resnets[i](outputs, self.condition))
56 for i in range(len(self.unet.up_blocks)):
57 self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.up_block_condition_resnets[i](outputs, self.condition))
/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/tmp/ipykernel_574/2058376754.py in forward(self, x, condition)
40
41 def forward(self, x: DownOutput, condition: torch.Tensor) -> DownOutput:
---> 42 return self.condition_resnet(x[0], condition), x[1]
43
44 class UNetWithConditions(nn.Module):
/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/tmp/ipykernel_574/2058376754.py in forward(self, out, condition)
30 self.call_count += 1
31 input_vector = out
---> 32 out = self.conv1(out) + self.projector(condition)[:, :, None, None]
33 return input_vector + self.non_linearity(out)
34
/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/conv.py in forward(self, input)
461
462 def forward(self, input: Tensor) -> Tensor:
--> 463 return self._conv_forward(input, self.weight, self.bias)
464
465 class Conv3d(_ConvNd):
/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
457 weight, bias, self.stride,
458 _pair(0), self.dilation, self.groups)
--> 459 return F.conv2d(input, weight, bias, self.stride,
460 self.padding, self.dilation, self.groups)
461
RuntimeError: Given groups=1, weight of size [1280, 1280, 3, 3], expected input[2, 320, 32, 32] to have 1280 channels, but got 320 channels instead
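So: why would the last down_block's hook fire before the first one's, and is stashing self.condition on the wrapper a legitimate way to pass extra inputs to a forward hook?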