`CLIPTokenizer` should output tensors in its `forward` function rather than lists of numbers in str form

Status: Open. ProGamerGov opened this issue 3 years ago • 6 comments

🚀 Feature

Outputs of the current CLIP tokenizer appear to be a list of numeric strings, rather than a tensor or even a list of integers:

```python
import torchtext.transforms

clip_tokenizer = torchtext.transforms.CLIPTokenizer(merges_path="clip_bpe.txt")

test_str = "This is a test"
test_output = clip_tokenizer(test_str)
print(test_output)  # ['589', '533', '320', '1628']
```

It might be easier to work with if the outputs were tensors of shape [batch, tokens].

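OpenAI's CLIP tokenizer, for example, returns outputs like the following: each sequence is wrapped in start-of-text (49406) and end-of-text (49407) tokens, and zeros fill the rest of the model's context length (`context_length`, 77 by default in `clip.tokenize`; this might also be a good parameter to expose in torchtext's tokenizer):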

```python
import clip

test_str = "This is a test"
test_output_clip = clip.tokenize(test_str)
print(test_output_clip)
```

```
# shape: [1, 77]
tensor([[49406,   589,   533,   320,  1628, 49407,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]])
```

https://github.com/openai/CLIP/blob/main/clip/clip.py#L195

Posted by ProGamerGov, Feb 25 '22 21:02

cc @abhinavarora

Posted by Nayef211, Feb 25 '22 23:02

Thanks @ProGamerGov for raising the issue. The output format was chosen for consistency with the other tokenizer transforms (`SentencePieceTokenizer`, `GPT2BPETokenizer`). You can convert it to a tensor using the functional `to_tensor` or the `ToTensor` nn module. Both also let you provide a padding value, which is useful because the pad id is not always 0 (for example, in RoBERTa/XLM-R models the pad id is 1).

Posted by parmeet, Feb 28 '22 13:02
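torchtext's `to_tensor`/`ToTensor` are documented as taking integer token IDs, so the string IDs from `CLIPTokenizer` need an `int` conversion first. The following is a rough, torch-only sketch of that conversion plus padding with a configurable pad id. It uses `torch.nn.utils.rnn.pad_sequence`; `str_ids_to_padded` is a hypothetical helper and is not torchtext's own implementation.

```python
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence


def str_ids_to_padded(batch: List[List[str]], padding_value: int = 0) -> torch.Tensor:
    """Turn a batch of str token IDs into a right-padded [batch, max_len] long tensor."""
    seqs = [torch.tensor([int(t) for t in seq], dtype=torch.long) for seq in batch]
    # pad_sequence pads to the longest sequence in the batch, not to a fixed length.
    return pad_sequence(seqs, batch_first=True, padding_value=padding_value)


# RoBERTa/XLM-R style padding with pad id 1, as mentioned in the comment above.
print(str_ids_to_padded([["589", "533"], ["320"]], padding_value=1).tolist())
# [[589, 533], [320, 1]]
```

Note that this only pads to the longest sequence in the batch; it does not reach a fixed length such as CLIP's 77.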

@parmeet Ah, okay. I put together a helper function to convert the outputs to the proper format for the CLIP models:

```python
from typing import List, Union

import torch


def token_str_to_tensor(
    token_list: Union[List[str], List[List[str]]], context_length: int = 77
) -> torch.Tensor:
    """
    Convert torchtext tokenizer outputs to tensor format.

    Args:

        token_list (list of str or list of list of str): Token values to be converted
            to tensors.
        context_length (int, optional): The model's context length; each token set
            is zero-padded on the right to this length.
            Default: 77

    Returns:
        tokens (torch.Tensor): A [batch, context_length] tensor containing each
            token set stacked across the batch dimension.
    """
    # Wrap a single (unbatched) token list so it is treated as a batch of one.
    token_list = (
        [token_list] if not isinstance(token_list[0], (tuple, list)) else token_list
    )
    assert all(len(t) <= context_length for t in token_list)
    # Convert the str token IDs to int32 tensors.
    # Note: no start/end-of-text tokens are added here.
    tokens = [torch.as_tensor([int(v) for v in t_list]).int() for t_list in token_list]
    # Right-pad each sequence with zeros up to context_length.
    tokens = [
        torch.cat(
            [
                x,
                torch.zeros(
                    context_length - x.shape[0], dtype=x.dtype, device=x.device
                ),
            ],
            0,
        )
        for x in tokens
    ]
    return torch.stack(tokens)
```

Posted by ProGamerGov, Feb 28 '22 15:02

You can also try using `Sequential` to compose the transforms, as shown here for XLM-R text pre-processing. In the linked example, you would add a `ToTensor` transform at the end so that you get a tensor instead of `List[List[str]]`. Since `ToTensor` works on integer IDs, a `VocabTransform` (which converts string indices to the corresponding integer IDs) should come before it; you can construct the corresponding vocab object and pass it to that transform.

Posted by parmeet, Feb 28 '22 22:02
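The compose-then-convert approach can be sketched without torchtext, because `torch.nn.Sequential` simply passes each module's output to the next one. The three stages below (`StrToInt`, `AddSpecialTokens`, `PadToLength`) are hypothetical stand-ins for the vocab lookup, special-token insertion, and tensor conversion steps; they are not torchtext classes. The IDs 49406/49407 and the length 77 are read off the `clip.tokenize` output earlier in the thread.

```python
from typing import List

import torch
from torch import nn


class StrToInt(nn.Module):
    """Convert str token IDs such as '589' to ints (stand-in for a vocab lookup)."""

    def forward(self, batch: List[List[str]]) -> List[List[int]]:
        return [[int(t) for t in seq] for seq in batch]


class AddSpecialTokens(nn.Module):
    """Prepend a start ID and append an end ID to every sequence."""

    def __init__(self, start_id: int, end_id: int) -> None:
        super().__init__()
        self.start_id = start_id
        self.end_id = end_id

    def forward(self, batch: List[List[int]]) -> List[List[int]]:
        return [[self.start_id] + seq + [self.end_id] for seq in batch]


class PadToLength(nn.Module):
    """Right-pad every sequence to a fixed length and stack into a [batch, length] tensor."""

    def __init__(self, length: int, padding_value: int = 0) -> None:
        super().__init__()
        self.length = length
        self.padding_value = padding_value

    def forward(self, batch: List[List[int]]) -> torch.Tensor:
        out = torch.full((len(batch), self.length), self.padding_value, dtype=torch.long)
        for i, seq in enumerate(batch):
            if len(seq) > self.length:
                raise ValueError(f"sequence {i} has {len(seq)} tokens, limit is {self.length}")
            out[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
        return out


# Hypothetical CLIP-style pipeline: str IDs -> ints -> SOT/EOT -> zero-padded tensor.
clip_pipeline = nn.Sequential(
    StrToInt(),
    AddSpecialTokens(start_id=49406, end_id=49407),
    PadToLength(length=77, padding_value=0),
)

tokens = clip_pipeline([["589", "533", "320", "1628"]])
print(tokens.shape)  # torch.Size([1, 77])
print(tokens[0, :7].tolist())  # [49406, 589, 533, 320, 1628, 49407, 0]
```

Swapping `StrToInt` for a real vocab lookup, or the pad value for a model-specific pad id, would leave the overall structure unchanged.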

Thanks @ProGamerGov for reporting this. We are looking to standardize the interface of our transforms/tokenizers, and refactoring this, if needed, will be an action item for us. The general philosophy is that every tokenizer has an interface like this:

```python
from typing import List


def tokenize(sentence: str) -> List[str]:
    pass
```

The tokens produced by the tokenizer could be anything: real tokens or subword IDs, depending on the underlying tokenizer, and users are free to interpret them however their downstream models require. Let me know if this makes sense. Please feel free to share any ideas you have :-)

Posted by abhinavarora, Mar 01 '22 18:03
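A small illustration of that interface: any callable with the shape `(str) -> List[str]` qualifies, and what the strings mean is left to the tokenizer. Both tokenizers below are hypothetical toys (a whitespace splitter and a lookup against a made-up three-word vocabulary), not torchtext classes.

```python
from typing import Callable, Dict, List

# Any callable with this shape fits the tokenizer interface described above.
Tokenizer = Callable[[str], List[str]]


def whitespace_tokenize(sentence: str) -> List[str]:
    """Toy tokenizer whose outputs are real tokens (words)."""
    return sentence.split()


# Made-up vocabulary for illustration only.
TOY_VOCAB: Dict[str, int] = {"this": 0, "is": 1, "test": 2}


def id_tokenize(sentence: str) -> List[str]:
    """Toy tokenizer whose outputs are IDs, still returned as str; unknown words are dropped."""
    return [str(TOY_VOCAB[w]) for w in sentence.lower().split() if w in TOY_VOCAB]


tokenizers: List[Tokenizer] = [whitespace_tokenize, id_tokenize]
for tok in tokenizers:
    print(tok("This is a test"))
# ['This', 'is', 'a', 'test']
# ['0', '1', '2']

# Downstream code decides how to interpret the strings, e.g. as integer IDs.
ids = [int(t) for t in id_tokenize("This is a test")]
print(ids)  # [0, 1, 2]
```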

@parmeet So, it looks like the `ToTensor` transform only pads up to the length of the longest list in the batch; it cannot pad beyond that to a fixed length. For CLIP, reaching the context length of 77 would still require concatenating a set of zeros afterwards.

Posted by ProGamerGov, Mar 05 '22 21:03
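If a batch has already been converted with `to_tensor`/`ToTensor` and is padded only to its longest sequence, the remaining columns can be appended afterwards. Below is a minimal sketch using `torch.nn.functional.pad`, assuming right-padding with 0 as in the CLIP output shown earlier; `pad_to_context_length` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def pad_to_context_length(
    tokens: torch.Tensor, context_length: int = 77, padding_value: int = 0
) -> torch.Tensor:
    """Right-pad a [batch, seq_len] tensor to [batch, context_length]."""
    extra = context_length - tokens.shape[-1]
    if extra < 0:
        raise ValueError(
            f"seq_len {tokens.shape[-1]} exceeds context_length {context_length}"
        )
    # For the last dimension, F.pad takes (pad_left, pad_right).
    return F.pad(tokens, (0, extra), value=padding_value)


batch = torch.tensor([[49406, 589, 533, 320, 1628, 49407]])
padded = pad_to_context_length(batch)
print(padded.shape)  # torch.Size([1, 77])
```

Padding the tensor after conversion keeps `ToTensor` unchanged and makes the fixed length an explicit, model-specific step.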