Use actual template from Apertus

This commit is contained in:
Michelle Winkler
2025-09-24 11:43:06 +02:00
parent 23691ba663
commit bc12eb1bc1

48
app.py
View File

@@ -1,24 +1,44 @@
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
import torch import tokenizer
from torchao.quantization import quantize_, int8_weight_only from torchao.quantization import quantize_, int8_weight_only
model_name = "swiss-ai/Apertus-8B-Instruct-2509" model_name = "swiss-ai/Apertus-8B-2509"
device = "cuda" # for GPU usage or "cpu" for CPU usage
# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(
model_name,
).to(device)
quantize_(model, int8_weight_only()) quantize_(model, int8_weight_only())
model.to("cuda")
print("Enter your prompt:") # prepare the model input
input_text = input() print("Please enter the prompt you want to ask the cool AI")
inputs = tokenizer.encode(input_text, return_tensors='pt').to("cuda") prompt = input()
messages_think = [
{"role": "user", "content": prompt}
]
import time example_template = """
start_time = time.time() {% for message in messages %}
with torch.no_grad(): <|start|>{{ message.role }}<|sep|>
outputs = model.generate(inputs, max_length=5000) {{ message.content }}
<|end|>
{% endfor %}
"""
end_time = time.time() text = tokenizer.apply_chat_template(
messages_think,
chat_template=example_template,
tokenize=False,
add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
print(f"Quantized inference time: {end_time - start_time:.2f} seconds") # Generate the output
print(f"Generated text: {tokenizer.decode(outputs[0], skip_special_tokens=True)}") generated_ids = model.generate(**model_inputs, max_new_tokens=32768)
# Get and decode the output
output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :]
print(tokenizer.decode(output_ids, skip_special_tokens=True))