I'm working on a vision task using JAX, and I'm running into an issue when passing intermediate JAX tracer objects as images to the CLIP processor while computing the loss. The processor expects NumPy arrays (or PIL images) as image inputs, so JAX tracer objects are not directly compatible.
Here's a simplified version of the code:
import jax
import jax.numpy as jnp
from jax import lax
from transformers import AutoProcessor, FlaxCLIPModel

img_txt_clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")


def train_step(state, batch, rng):
    """Single (pmapped) training step."""
    inputs, targets = batch

    def clip_loss_fn(params):
        model_fn = lambda x: state.apply_fn({"params": params}, x)
        ray_origins, ray_directions = inputs
        # perform_volume_rendering is defined elsewhere in my codebase;
        # rgb comes back as a JAX tracer under value_and_grad.
        rgb, *_ = perform_volume_rendering(
            model_fn, ray_origins, ray_directions, rng
        )
        # Here's where the issue arises: the processor expects NumPy arrays,
        # not tracers.
        clip_input = processor(
            text=["a bulldozer"], images=[rgb], return_tensors="jax", padding=True
        )
        outputs = img_txt_clip(**clip_input)
        logits_per_image = outputs.logits_per_image
        return jnp.mean(logits_per_image)

    train_loss, gradients = jax.value_and_grad(clip_loss_fn)(state.params)
    gradients = lax.pmean(gradients, axis_name="batch")
    new_state = state.apply_gradients(grads=gradients)
    train_loss = jnp.mean(train_loss)
    return train_loss, new_state
I've tried using the FlaxCLIPModel for compatibility with JAX, but passing the JAX tracer objects as images to the CLIP processor raises an error. Converting the tracer objects to NumPy arrays inside the traced loss function doesn't seem like a real option either: it would be inefficient and, as far as I understand, would cut the gradient flow from CLIP back to my model.
I would appreciate any suggestions on how to either handle the JAX tracer objects efficiently or make the CLIP preprocessing accept them as inputs.
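For illustration, the kind of workaround I have in mind is to keep the image path entirely in JAX: tokenize the text once outside the traced function, and approximate the processor's image preprocessing (resize + CLIP normalization) with jnp ops so the rendered image never has to leave JAX. This is only an untested sketch: it assumes the rendered rgb is an (H, W, 3) float array in [0, 1], uses the standard CLIP mean/std constants, and I haven't verified that it matches CLIPImageProcessor's resize/crop behaviour exactly.

import jax.numpy as jnp
from jax import image as jimage

# Tokenize the prompt once, outside the traced loss function (plain NumPy is fine here).
text_inputs = processor(text=["a bulldozer"], return_tensors="np", padding=True)

# Standard CLIP normalization constants for the OpenAI checkpoints.
CLIP_MEAN = jnp.array([0.48145466, 0.4578275, 0.40821073])
CLIP_STD = jnp.array([0.26862954, 0.26130258, 0.27577711])


def preprocess_in_jax(rgb):
    """Roughly replicate CLIP image preprocessing with pure JAX ops.

    Assumes rgb is an (H, W, 3) float array in [0, 1]; may not match
    CLIPImageProcessor exactly.
    """
    resized = jimage.resize(rgb, (224, 224, 3), method="bilinear")
    normalized = (resized - CLIP_MEAN) / CLIP_STD
    # FlaxCLIPModel expects pixel_values as (batch, channels, height, width).
    return jnp.transpose(normalized, (2, 0, 1))[None, ...]


def clip_score(rgb):
    pixel_values = preprocess_in_jax(rgb)
    outputs = img_txt_clip(
        input_ids=text_inputs["input_ids"],
        attention_mask=text_inputs["attention_mask"],
        pixel_values=pixel_values,
    )
    return jnp.mean(outputs.logits_per_image)

Would something along these lines be a reasonable approach, or is there a more standard way to feed tracer-valued images into FlaxCLIPModel?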
Thank you in advance for your help!