DiffMemorize icon indicating copy to clipboard operation
DiffMemorize copied to clipboard

Visualize the similar image pairs

Open Yeez-lee opened this issue 2 years ago • 1 comments

Hi,

Thanks for your efforts. I have a question about the visualizations. Could you provide some code showing how to visualize the similar image pairs (like Figures 1 and 10 in your paper)?

Yeez-lee avatar Dec 01 '23 06:12 Yeez-lee

Thank you for your question. Here is some sample code for visualising the similar image pairs. Basically, you first need to generate images, and then search for the nearest image in the training dataset using the KNN distance.

import os
import numpy as np
import torch
import click
import json
import zipfile
import PIL.Image
from tqdm import tqdm
from glob import glob
try:
    import pyspng
except ImportError:
    pyspng = None

def file_ext(fname):
    """Return the lowercased extension of ``fname``, including the leading dot.

    An empty string is returned when the name has no extension.
    """
    _, extension = os.path.splitext(fname)
    return extension.lower()

def load_cifar10_zip(zip_path):
    """Load every image and its label from a zipped image dataset.

    The archive must contain a ``dataset.json`` file whose ``'labels'`` entry is
    a sequence of ``(file name, label)`` pairs, plus the image files themselves.
    Images are read in sorted file-name order.

    Args:
        zip_path: Path to the zip archive.

    Returns:
        Tuple ``(images, labels)``. ``images`` is a numpy array of shape
        [N, C, H, W]; ``labels`` is a numpy array cast to ``int64`` when it is
        1-D and to ``float32`` when it is 2-D.

    Raises:
        KeyError: If an image in the archive has no entry in ``dataset.json``.
    """
    # Populate PIL.Image.EXTENSION so image files can be recognised by suffix.
    PIL.Image.init()

    images = []
    labels = []

    # The context manager closes the archive even if decoding or the label
    # lookup raises part-way through.
    with zipfile.ZipFile(zip_path) as zip_file:
        all_names = set(zip_file.namelist())
        image_names = sorted(fname for fname in all_names if file_ext(fname) in PIL.Image.EXTENSION)

        # load labels
        with zip_file.open('dataset.json', 'r') as f:
            labels_dict = dict(json.load(f)['labels'])

        # load images
        for name in tqdm(image_names):
            with zip_file.open(name, 'r') as f:
                # pyspng is an optional PNG decoder; fall back to PIL without it.
                if pyspng is not None and file_ext(name) == '.png':
                    image = pyspng.load(f.read())
                else:
                    image = np.array(PIL.Image.open(f))
            if image.ndim == 2:
                image = image[:, :, np.newaxis]  # HW => HWC
            image = image.transpose(2, 0, 1)     # HWC => CHW

            # Keep a leading batch axis so the final concatenate yields NCHW.
            images.append(image[np.newaxis, :, :, :])
            labels.append(labels_dict[name])

    images = np.concatenate(images, axis=0)
    labels = np.array(labels)
    labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])

    return images, labels

def knn(seed_image, ref_images, k=1):
    """Find the ``k`` reference images closest to ``seed_image``.

    Closeness is the Euclidean distance between the flattened images divided by
    ``sqrt(C * H * W)``, so the value does not grow with image size.

    Args:
        seed_image: Tensor of shape [C, H, W].
        ref_images: Tensor of shape [N, C, H, W].
        k: Number of nearest neighbours to return; must be at least 1.

    Returns:
        Tuple ``(nearest_distance, nearest_image)``. ``nearest_distance`` has
        shape [min(k, N)] and is sorted in ascending order; ``nearest_image``
        has shape [min(k, N), C, H, W] and holds the matching entries of
        ``ref_images`` (in their original dtype).

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    C, H, W = seed_image.shape
    num_features = C * H * W
    query = seed_image.reshape(1, num_features)
    candidates = ref_images.reshape(-1, num_features)

    # torch.cdist only accepts floating-point tensors, so integer images
    # (e.g. uint8 pixels) are converted for the distance computation only.
    if not query.is_floating_point():
        query = query.float()
    if not candidates.is_floating_point():
        candidates = candidates.float()

    # Normalise by the actual feature count rather than a fixed 32*32*3.
    distance = torch.cdist(query, candidates) / float(np.sqrt(num_features))

    # A stable sort keeps the earliest reference image first among ties.
    sorted_distance, sorted_index = torch.sort(distance, dim=1, stable=True)
    nearest_distance = sorted_distance[0, :k]
    nearest_index = sorted_index[0, :k]
    nearest_image = ref_images[nearest_index]
    return nearest_distance, nearest_image

def _read_image(image_path):
    """Decode one image file into a numpy array, using pyspng when available."""
    with open(image_path, 'rb') as f:
        # pyspng is an optional import and may be None; fall back to PIL then.
        if pyspng is not None:
            return pyspng.load(f.read())
        return np.array(PIL.Image.open(f))


def _save_image_grid(image_paths, save_path, rows=3, cols=8, gap=2):
    """Tile ``rows * cols`` images into one RGB picture with black separators."""
    grid = []
    for row_index in range(rows):
        tiles = []
        for col_index in range(cols):
            tile = _read_image(image_paths[row_index * cols + col_index])
            if tiles:
                # Vertical black strip between neighbouring tiles.
                tiles.append(np.zeros((tile.shape[0], gap, 3), dtype=tile.dtype))
            tiles.append(tile)
        row = np.concatenate(tiles, axis=1)
        if grid:
            # Horizontal black strip between neighbouring rows.
            grid.append(np.zeros((gap, row.shape[1], 3), dtype=row.dtype))
        grid.append(row)
    PIL.Image.fromarray(np.concatenate(grid, axis=0), 'RGB').save(save_path)


def plot_images(ckpt_folder):
    """Save 3x8 image grids of the first 24 generated and nearest-neighbour images.

    Generated images are read from
    ``<ckpt_folder>/mem-tmp/<bucket>/<index>.png``, where ``bucket`` is the index
    rounded down to a multiple of 1000, and written to
    ``<ckpt_folder>/gen_image.png``. Nearest-neighbour images are read from
    ``<ckpt_folder>/knn-tmp/<index>.png`` and written to
    ``<ckpt_folder>/knn_image.png``. All indices are zero-padded to six digits.

    Args:
        ckpt_folder: Directory containing the ``mem-tmp`` and ``knn-tmp`` folders;
            the two output images are written directly into it.
    """
    num_images = 3 * 8

    image_folder = os.path.join(ckpt_folder, "mem-tmp")
    gen_paths = [
        os.path.join(image_folder, f'{index-index%1000:06d}', f'{index:06d}.png')
        for index in range(num_images)
    ]
    _save_image_grid(gen_paths, os.path.join(ckpt_folder, "gen_image.png"))

    image_folder = os.path.join(ckpt_folder, "knn-tmp")
    knn_paths = [
        os.path.join(image_folder, f'{index:06d}.png')
        for index in range(num_images)
    ]
    _save_image_grid(knn_paths, os.path.join(ckpt_folder, "knn_image.png"))

if __name__ == "__main__":
    # Placeholder: fill in the folder that holds the "mem-tmp" and "knn-tmp"
    # image directories before running the script.
    plot_images(ckpt_folder="")

guxm2021 avatar Dec 06 '23 08:12 guxm2021