I want to use in-context learning for qwen1.5-72b-chat inference, so I call tokenizer.apply_chat_template as in the official tutorial, but I get ValueError: max() arg. Doesn't airllm support the official inference method?
#148 · Open · Yang-bug-star opened this issue on Jun 16, 2024 · 0 comments
Here is my main inference code:
from airllm import AirLLMQWen

model = AirLLMQWen("/data2/Qwen1.5-72B-Chat", compression='8bit')
device = "cuda"

# instruct_few_shot_examples is a list of {'user': ..., 'gpt': ...} dicts defined elsewhere
messages = []
for ex in instruct_few_shot_examples:
    # total_weibo_text_tokens += len(encoding.encode(ex['user'])) + len(encoding.encode(ex["gpt"]))
    messages += [
        {"role": "user", "content": ex['user']},
        {"role": "assistant", "content": ex["gpt"]},
    ]

# Chinese Weibo post used as the test input (a sandstorm / gale warning forecast)
test_case = "【沙尘暴黄色预警+大风蓝色预警】#萌台报天气# 预计今天白天至夜间,京津冀、黑吉辽、山西、陕西、山东、河南等10余省份沙尘继续肆虐,并伴有大风天气,其中内蒙古中部等地的部分地区有沙尘暴。你家的天还是蓝色吗?! [组图共3张] 原图 "
messages += [
    {"role": "user", "content": test_case},
]

text = model.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

model.tokenizer.pad_token_id = model.tokenizer.im_end_id
input_tokens = model.tokenizer([text],
                               return_tensors="pt",
                               return_attention_mask=False,
                               padding=True
                               ).to(device)

generation_output = model.generate(
    input_tokens.input_ids,
    use_cache=True,
    return_dict_in_generate=True)
generated_ids = [
    # with return_dict_in_generate=True, the generated token ids live in .sequences
    output_ids[len(input_ids):]
    for input_ids, output_ids in zip(input_tokens.input_ids, generation_output.sequences)
]
response = model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
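
For what it's worth, a possible workaround while waiting for an answer (only a sketch, under the assumption that the tokenizer loaded by AirLLMQWen ships without a chat_template, which would make apply_chat_template fail): format the prompt yourself, either by borrowing the tokenizer that comes with the checkpoint via transformers.AutoTokenizer, or by building the ChatML prompt by hand, and then pass the resulting string to airllm exactly as above.

```python
# Sketch of a possible workaround, not a confirmed airllm fix.
# Assumption: the failure comes from a missing chat_template on the tokenizer
# that airllm loads, so we only do the prompt formatting outside of airllm.
from transformers import AutoTokenizer

checkpoint = "/data2/Qwen1.5-72B-Chat"  # same directory passed to AirLLMQWen

# Option 1: use the chat template shipped with the original checkpoint.
hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
text = hf_tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

# Option 2: build the ChatML-style prompt manually (the format Qwen chat models use).
def to_chatml(messages):
    prompt = ""
    for m in messages:
        prompt += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
    prompt += "<|im_start|>assistant\n"  # generation prompt for the assistant turn
    return prompt

text = to_chatml(messages)  # either option yields a plain string for model.tokenizer([text], ...)
```

If either option gets past the error, that would suggest the ValueError comes from the tokenizer airllm loads rather than from generation itself.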