
Gemma model - help needed

carolinaxxxxx opened this issue 1 year ago · 4 comments

Can any colleague help with an example of inference with the Gemma model in CTranslate2? Unfortunately, there is no information about this model in the documentation.

Thx

carolinaxxxxx avatar Jun 19 '24 13:06 carolinaxxxxx

Hello, I will update the doc in the future. BTW, you can convert Gemma the same way as described in the Llama documentation:

ct2-transformers-converter --model google/gemma-7b --output_dir gemma_ct2
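
If GPU memory is tight, the converter can also quantize the weights at conversion time via its --quantization flag (optional), e.g.:

ct2-transformers-converter --model google/gemma-7b --output_dir gemma_ct2 --quantization int8_float16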

Then you can try this script:

import ctranslate2
import transformers

generator = ctranslate2.Generator("gemma_ct2")
tokenizer = transformers.AutoTokenizer.from_pretrained("google/gemma-7b")

b_inst = '<start_of_turn>'
e_inst = '<end_of_turn>'
user_input = 'Ask something'
# Gemma chat template: <start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n
prompt = b_inst + 'user\n' + user_input + e_inst + '\n' + b_inst + 'model\n'
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
results = generator.generate_batch([tokens], max_length=30, sampling_topk=10)
print(tokenizer.decode(results[0].sequences_ids[0]))
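
If you want the output to stop at the end of the model's turn instead of running to max_length, generate_batch also accepts an end_token argument (a sketch, assuming a recent CTranslate2 version):

# Optional: stop generation at Gemma's turn delimiter rather than at max_length.
results = generator.generate_batch(
    [tokens],
    max_length=30,
    sampling_topk=10,
    end_token='<end_of_turn>',       # stop when the model closes its turn
    include_prompt_in_result=False,  # return only the newly generated tokens
)
print(tokenizer.decode(results[0].sequences_ids[0]))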

minhthuc2502 avatar Jun 20 '24 12:06 minhthuc2502

@minhthuc2502 - Does CTranslate2 support openchat models, e.g. openchat/openchat-3.5-0106-gemma? I managed to perform the conversion to ct2, but I can't make it work properly. Thx

carolinaxxxxx avatar Jun 25 '24 20:06 carolinaxxxxx

What is the error? I see the architecture defined in the openchat model is GemmaForCausalLM, so I think it should work.

minhthuc2502 avatar Jun 26 '24 08:06 minhthuc2502

@minhthuc2502 I use the code below for testing:

import ctranslate2
import transformers

generator = ctranslate2.Generator("/test/openchat35gemma", device="cuda", device_index=1)
tokenizer = transformers.AutoTokenizer.from_pretrained("/test/openchat35gemma")

prompt = f"GPT4 Correct User: Hello<end_of_turn>GPT4 Correct Assistant: Hi<end_of_turn>GPT4 Correct User: How are you today?<end_of_turn>GPT4 Correct Assistant: "
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
results = generator.generate_batch([tokens], max_length=4096, sampling_temperature=0.1, sampling_topk=1, sampling_topp=0.1, include_prompt_in_result=False)
print(tokenizer.decode(results[0].sequences_ids[0]))

The result is random characters. Where did I go wrong? tokenizer.model is on the path. Thx.
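
A quick tokenization sanity check one could run here, using only the tokenizer already loaded above (just standard transformers calls):

# Sanity check: verify the prompt tokenizes and decodes cleanly.
ids = tokenizer.encode(prompt)
print(ids[:10])               # should start with the <bos> id if the tokenizer adds one
print(tokenizer.decode(ids))  # should reproduce the prompt text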

carolinaxxxxx avatar Jun 27 '24 11:06 carolinaxxxxx