Virtual Try-On with Imagen 3

Introduction

Nitin Tiwari
6 min read · 3 days ago

At Google I/O 2024, Google DeepMind announced Imagen 3, their latest text-to-image model, and I’ve been keeping an eye on it for a while. What sets it apart from other diffusion models is its ability to generate high-quality, photorealistic images with an incredible level of detail.

But it’s not just limited to generating images — it can also help you edit them. Whether it’s inpainting to replace part of an image with something else, or outpainting to expand the image and add more content, Imagen 3 has you covered.

In this blog, I’ll show you an interesting use case of how you can use Imagen 3 to see what an outfit would look like on you — a practical problem I often face myself.

Imagen 3

Imagen 3 is a latent diffusion model that generates high-quality images from text prompts. It was trained on a large dataset of images, text, and annotations, and was assessed with both human and automatic evaluations. Read the complete technical report here.

One of the criteria used in these evaluations is prompt-image alignment, which measures how accurately the input prompt is reflected in the generated image. Let’s look at an example to understand this better.

Prompt: A highly detailed, photo-realistic image of a lady smiling, with natural facial features and soft lighting. The background features a blurred, out-of-focus scene of lush green trees on a bright, sunny day, with warm sunlight casting gentle shadows.

Comparison of FLUX.1 [dev], SD3 and Imagen 3 [Source: Image by author]

For the exact same prompt, I tried generating images from three different models: FLUX.1 [dev], Stable Diffusion 3, and Imagen 3. Which of the images reflects every property described in the prompt? I’ll leave that for you to decide.

Imagen 3 can also create imaginative visuals, such as paintings, drawings, artwork, and more. Below are some prompts I experimented with on Vertex AI.

Imagen 3 examples on Vertex AI [Source: Image by author]

Now that you’re familiar with the capabilities of Imagen 3, let’s explore our use case in more detail.

Prerequisites

  • Access to the Imagen 3 model on Vertex AI.
  • Gemini API (optional): For zero-shot detection, which is used later in our pipeline.
  • SAM-2: For creating the segmentation mask.

To gain a clear understanding of the approach, let’s illustrate it using a pipeline.

Pipeline

As always, we’ll work through the coding and concepts together. You can start by cloning the Colab notebook from this repository.

Note: I’m skipping the code for installing and importing libraries and dependencies to keep this blog concise. You can find the full details in the complete Colab notebook.

Step 0: Configure Gemini API key and GCP Project ID

First things first, let’s configure the Gemini API key and GCP Project ID to initialize the Vertex AI SDK.

API_KEY = userdata.get('gemini')
genai.configure(api_key=API_KEY)
model = genai.GenerativeModel(model_name='gemini-1.5-pro')

GCP_PROJECT_ID = userdata.get('GCP_PROJECT_ID')
vertexai.init(project=GCP_PROJECT_ID, location="us-central1")

Step 1: Zero-shot Object Detection

We begin by feeding the input image into the Gemini model for zero-shot object detection to identify the area (in this case, the outfit) that needs to be changed. The code in Step 0 uses Gemini 1.5 Pro, but you can also choose Gemini 1.5 Flash or Gemini 2.0 Flash by changing model_name.

input_image = 'image.jpg'  # @param {type: 'string'}
object_to_detect = 'hoodie' # @param {type: 'string'}
img = PIL.Image.open(input_image)


response = model.generate_content([
    img,
    (
        # Note the trailing space: the two string literals are concatenated.
        f"Detect 2d bounding box of {object_to_detect} and return it in the below format "
        "list. \n [ymin, xmin, ymax, xmax, object_name]. If there are more than one object, return separate lists for each object"
    ),
])

result = response.text

bounding_box = parse_bounding_box(result)
output, coordinates = draw_bounding_boxes(img, bounding_box)

Although this step is optional, I wanted to ensure the entire pipeline is fully automated without any manual intervention. However, if you don’t have access to the Gemini API, I’ve provided an alternative to run this without it.

Zero-shot detection by Gemini

As you can see, I’m wearing a hoodie in the input image and Gemini has accurately detected it in the image.
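
The parse_bounding_box helper used above lives in the notebook and isn’t shown here. To give a feel for what that step involves, here is a minimal sketch of how a reply in the [ymin, xmin, ymax, xmax, object_name] format could be parsed and converted to pixel coordinates. The function names parse_bounding_boxes and to_pixels are hypothetical, and the sketch assumes Gemini returns coordinates normalized to a 0 to 1000 range; it is not the notebook’s actual implementation.

```python
import re

# Matches one "[ymin, xmin, ymax, xmax, name]" list; the name may be quoted.
BOX_PATTERN = re.compile(
    r"\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,"
    r"\s*['\"]?([^'\"\]]+?)['\"]?\s*\]"
)


def parse_bounding_boxes(text):
    """Hypothetical parser: extract every box list found in a model reply."""
    boxes = []
    for ymin, xmin, ymax, xmax, name in BOX_PATTERN.findall(text):
        boxes.append((int(ymin), int(xmin), int(ymax), int(xmax), name.strip()))
    return boxes


def to_pixels(box, width, height, scale=1000):
    """Convert a normalized box to (x1, y1, x2, y2) pixel coordinates.

    Assumption: coordinates are normalized to the range 0..scale.
    """
    ymin, xmin, ymax, xmax, _ = box
    return (
        xmin * width // scale,
        ymin * height // scale,
        xmax * width // scale,
        ymax * height // scale,
    )


reply = "Here is the box: [250, 100, 900, 800, hoodie]"
boxes = parse_bounding_boxes(reply)
pixel_box = to_pixels(boxes[0], width=640, height=480)
```

Note that the reply lists y before x, while the later steps expect (x1, y1, x2, y2), so the conversion also swaps the order.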

Step 2: Draw points within ROI

SAM-2 accepts only point and box prompts, not text prompts, which is why we used Gemini in Step 1 to turn a text label into a bounding box around the region of interest. From that box we now derive five points (the centre, plus a point halfway from the centre to each edge) to use as point prompts for SAM-2.

image = cv2.imread(input_image)
x1, y1, x2, y2 = coordinates[0], coordinates[1], coordinates[2], coordinates[3]
centre_x = (x1+x2)//2
centre_y = (y1+y2)//2

color = (0, 255, 0)
radius = 5

# Draw the points on the image.
p1 = (centre_x, centre_y)
p2 = (centre_x, (centre_y+y1)//2)
p3 = (centre_x, (centre_y+y2)//2)
p4 = ((centre_x+x1)//2, centre_y)
p5 = ((centre_x+x2)//2, centre_y)

points = [p1, p2, p3, p4, p5]

# Mark each point with a filled circle.
for point in points:
    cv2.circle(image, point, radius, color, -1)

image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
roi_output = PIL.Image.fromarray(image)
roi_output.save(f'roi_point_{input_image}')
display(roi_output)

Draw key points within the ROI
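
The point arithmetic in this step can also be packaged as a small function, which makes it easy to check on a known box. The sketch below reproduces the same five points; the function name roi_points is hypothetical and not part of the notebook.

```python
def roi_points(x1, y1, x2, y2):
    """Return five prompt points inside the box (x1, y1, x2, y2).

    The points are the centre of the box plus the midpoints between
    the centre and the top, bottom, left, and right edges.
    """
    centre_x = (x1 + x2) // 2
    centre_y = (y1 + y2) // 2
    return [
        (centre_x, centre_y),              # centre
        (centre_x, (centre_y + y1) // 2),  # halfway to the top edge
        (centre_x, (centre_y + y2) // 2),  # halfway to the bottom edge
        ((centre_x + x1) // 2, centre_y),  # halfway to the left edge
        ((centre_x + x2) // 2, centre_y),  # halfway to the right edge
    ]


example_points = roi_points(100, 200, 300, 600)
```

Because every point sits well inside the box, the prompts are likely to land on the garment itself rather than on the background near the box corners.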

Step 3: Create segmentation mask

Now that we have the points, we can prompt the SAM-2 model to generate a segmentation mask using the latest SAM-2.1 model weights.

For this task, I’ve selected the SAM 2.1 small checkpoint, as it is sufficient. However, feel free to choose another version if your input image contains complex segments.

# Select SAM-2.1 checkpoint and config.
sam2_checkpoint = "sam2.1_hiera_small.pt" # @param ["sam2.1_hiera_tiny.pt", "sam2.1_hiera_small.pt", "sam2.1_hiera_base_plus.pt", "sam2.1_hiera_large.pt"]
model_config = "sam2.1_hiera_s.yaml" # @param ["sam2.1_hiera_t.yaml", "sam2.1_hiera_s.yaml", "sam2.1_hiera_b+.yaml", "sam2.1_hiera_l.yaml"]

sam2_checkpoint = f"/content/sam2.1_checkpoints/{sam2_checkpoint}"
model_config = f"configs/sam2.1/{model_config}"

sam2_model = build_sam2(model_config, sam2_checkpoint, device="cuda")
mask_generator = SAM2AutomaticMaskGenerator(sam2_model)
predictor = SAM2ImagePredictor(sam2_model)

input_point = np.array([
    [point[0], point[1]] for point in points
])

# A label of 1 marks every point as foreground (part of the object).
input_label = np.ones(input_point.shape[0])

predictor.set_image(img)

# Predict the segmentation mask.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=False,
)

mask_img = PIL.Image.fromarray((masks[0]*255).astype(np.uint8))
mask_img.save(f"{annotation}_mask_{input_image}")
display(mask_img)

Binary mask created by SAM-2.1

The generated binary mask looks accurate for the points provided to SAM-2.1, with only a small amount of noise here and there. However, if you believe better points could have been used, I’ve written code that lets you enter the points manually through the UI.
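
If the stray noise in a mask bothers you, one common cleanup is to keep only the largest connected region of foreground pixels. The sketch below is not part of the original pipeline. It is a plain-Python illustration of the idea on a tiny mask stored as a list of rows of 0s and 1s, and the function name keep_largest_component is hypothetical. For a full-size mask you would more likely use an equivalent routine from an image-processing library.

```python
from collections import deque


def keep_largest_component(mask):
    """Return a copy of a binary mask with only its largest
    4-connected foreground region kept."""
    height = len(mask)
    width = len(mask[0]) if height else 0
    seen = [[False] * width for _ in range(height)]
    largest = []

    for row in range(height):
        for col in range(width):
            if not mask[row][col] or seen[row][col]:
                continue
            # Breadth-first flood fill to collect one connected region.
            region = []
            queue = deque([(row, col)])
            seen[row][col] = True
            while queue:
                r, c = queue.popleft()
                region.append((r, c))
                for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
                    inside = 0 <= nr < height and 0 <= nc < width
                    if inside and mask[nr][nc] and not seen[nr][nc]:
                        seen[nr][nc] = True
                        queue.append((nr, nc))
            if len(region) > len(largest):
                largest = region

    cleaned = [[0] * width for _ in range(height)]
    for r, c in largest:
        cleaned[r][c] = 1
    return cleaned


noisy = [
    [1, 1, 0, 0],
    [1, 1, 0, 1],  # the lone 1 on the right is noise
    [0, 0, 0, 0],
]
cleaned = keep_largest_component(noisy)
```

One caveat: if the garment itself is split into separate regions, for example by an arm crossing it, this approach would discard the smaller pieces.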

Step 4: Generate new image with Imagen 3

Finally, we’re ready to use Imagen 3 to generate the new outfit and inpaint it into the region covered by the segmentation mask.

Below are some configurations that we need to provide as parameters:

  • Base image: The input image file to be edited.
  • Mask file: The binary mask of the object to be inpainted.
  • Prompt: The input prompt describing the new content to generate.
  • Edit mode: Options like inpainting-insert, inpainting-remove, or outpainting.
  • Mask mode: Choose between background or foreground.
  • Dilation: A float value between 0 and 1 that controls how much the provided mask is grown before editing, expressed as a fraction of the image width.
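
To get a feel for the dilation value, it helps to translate it into pixels. The sketch below assumes dilation is interpreted as a fraction of the image width, which is how the Vertex AI documentation appears to describe it. The helper name dilation_in_pixels is hypothetical and only illustrates the arithmetic.

```python
def dilation_in_pixels(dilation, image_width):
    """Approximate how many pixels the mask grows by.

    Assumption: dilation is a fraction (0 to 1) of the image width.
    """
    if not 0 <= dilation <= 1:
        raise ValueError("dilation must be between 0 and 1")
    return round(dilation * image_width)


# A dilation of 0.01 on a 1024-pixel-wide image is roughly 10 pixels.
growth = dilation_in_pixels(0.01, 1024)
```

A small value like this mostly compensates for slightly imperfect mask edges, while a larger value lets the model repaint more of the area around the garment.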

mask_file = f"{annotation}_mask_{input_image}"
output_file = f"output_{input_image}"
prompt = "A dark green hoodie, white shirt inside, waist length" # @param {type: 'string'}

edit_mode = 'inpainting-insert' # @param ['inpainting-insert', 'outpainting', 'inpainting-remove']
mask_mode = 'foreground' # @param ['foreground', 'background']
dilation = 0.01 # @param {type:"slider", min:0, max:1, step:0.01}

edit_model = ImageGenerationModel.from_pretrained("imagen-3.0-capability-001")

base_img = Image.load_from_file(location=input_image)
mask_img = Image.load_from_file(location=mask_file)

raw_ref_image = RawReferenceImage(image=base_img, reference_id=0)
mask_ref_image = MaskReferenceImage(
    reference_id=1, image=mask_img, mask_mode=mask_mode, dilation=dilation
)

edited_image = edit_model.edit_image(
    prompt=prompt,
    edit_mode=edit_mode,
    reference_images=[raw_ref_image, mask_ref_image],
    number_of_images=1,
    safety_filter_level="block_some",
    person_generation="allow_adult",
)

edited_image[0].save(output_file)
edited_image[0].show()

Final output

The result looks fantastic, and almost perfect.

I’m usually slow and indecisive when it comes to selecting outfits while shopping, but with Imagen 3’s impressive generation and editing capabilities, I guess that won’t be a problem anymore. 😜

I hope you enjoyed reading this blog and found it informative. If you appreciated my work, feel free to ⭐ the repository and share it with others.

The future certainly looks bright with these state-of-the-art models, and I’m excited to see the amazing things you’ll create with Imagen 3. Meanwhile, see you again with more cool reads.
