This sample provides a simple way to download a PyTorch model, optimize it for DirectML, and run it through a Gradio app UI. The sample is extracted from pytorch-labs/gpt-fast and has been slightly changed to use `torch-directml`. The original code is Copyright (c) 2023 Meta, and is used here under the terms of the BSD 3-Clause License. See LICENSE for more information.
The following models are currently supported by this sample:
- Phi-2: Small Language Model with 2.7 billion parameters. Best suited for prompts using QA format, chat format, and code format.
- Phi-3 Mini 4K: Small Language Model with 3.8 billion parameters using a 4k context window. The Instruct version has been fine-tuned to follow instructions and adhere to safety measures.
- LLaMA 2: Large Language Model with 7 billion parameters optimized specifically for dialogue use cases.
- LLaMA 3: Large Language Model with 8 billion parameters. The Llama 3 instruction tuned models are optimized for dialogue use cases.
- Mistral 7B: Large Language Model with 7 billion parameters. Mistral-7B-Instruct-v0.1 is an instruction fine-tuned version of the Mistral-7B-v0.1 generative text model, trained on a variety of publicly available conversation datasets.
⚠️ NOTE: Other variants of these models may work but they were not tested.
The various models have different VRAM requirements; the following table lists the memory requirements for the tested models:
Model | fp16 | fp32 |
---|---|---|
Phi-2 | 6GB | 12GB |
Phi-3-mini-4k | 8GB | >16GB |
Llama-2-7b | 14GB | 28GB |
Meta-Llama-3-8B | >16GB | 32GB |
Mistral-7B | 15GB | 30GB |
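These figures roughly track parameter count times bytes per element, plus overhead for activations and the KV cache. A back-of-the-envelope sketch (an approximation added here for orientation, not the sample's own calculation):

```python
# Rough rule of thumb: weight memory ~= parameter count * bytes per element.
BYTES_PER_ELEMENT = {"fp16": 2, "fp32": 4}

def approx_weight_gb(params_billions: float, precision: str) -> float:
    """Approximate weight memory in GiB for a dense model."""
    return params_billions * 1e9 * BYTES_PER_ELEMENT[precision] / 1024**3

# Example: Phi-2 has 2.7 billion parameters.
print(f"{approx_weight_gb(2.7, 'fp16'):.1f} GiB")  # ~5.0 GiB; the table rounds up to 6GB
```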
Once you've set up `torch-directml` following our Windows or WSL 2 guidance, install the requirements for running the app:
pip install -r requirements.txt
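To verify the environment before launching the app, a quick sanity check along these lines should work (a minimal sketch using the public `torch_directml.device()` API):

```python
import torch
import torch_directml

dml = torch_directml.device()          # default DirectML device (e.g. privateuseone:0)
x = torch.ones(2, 2, device=dml)       # allocate a tensor on the GPU via DirectML
print(x.device, (x + x).sum().item())  # expect the device name and 8.0
```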
To use the Llama and Mistral models, you will need to go through an extra step to gain access to their Hugging Face repositories. To do so:
- Visit the model's Hugging Face repository page
- Follow the steps on the Hugging Face page to obtain access
- Run `huggingface-cli login`
- Paste your Hugging Face User Access Token to log in (or log in programmatically, as sketched below)
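If you prefer to log in from Python, `huggingface_hub` exposes a `login()` helper (assuming `huggingface_hub` is available in your environment, which the Hugging Face CLI itself requires; the token below is a placeholder):

```python
from huggingface_hub import login

login(token="hf_...")  # replace with your own Hugging Face User Access Token
```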
Run the chatbot app using the following command:
> python app.py
The chatbot app will start with the default settings, which use DirectML as the backend to run the `Phi-3` model for inference with `float16` precision. On the first run, the app will automatically download `Phi-3-mini-4k-instruct` from the default `hf_model`, which is set to `microsoft/Phi-3-mini-4k-instruct`.
This model is optimized to take advantage of DirectML operators and to use the custom DirectML graph implementations for Rotary Positional Embedding (RoPE), Multi-Head Attention (MHA), and the Feedforward layers (MLP).
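For reference, here is a minimal sketch of what standard Rotary Positional Embedding computes (this is the textbook formulation, not the custom DirectML graph implementation itself):

```python
import torch

def rotary_embedding(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Reference RoPE: rotate channel pairs by position-dependent angles.

    x has shape (seq_len, num_heads, head_dim) with an even head_dim.
    """
    seq_len, _, head_dim = x.shape
    inv_freq = base ** (-torch.arange(0, head_dim, 2, dtype=x.dtype) / head_dim)
    angles = torch.arange(seq_len, dtype=x.dtype)[:, None] * inv_freq  # (seq, dim/2)
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]      # broadcast over heads
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```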
When you run this code, a local URL will be displayed on the console, along the lines of the following:
Using device=privateuseone:0, is_llama_3=False, is_phi_2=False
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Running on local URL: http://127.0.0.1:7860
--------
On the first run, you should see output such as the following while the model downloads:
checkpoints\microsoft\Phi-3-mini-4k-instruct\model.pth doesnt exist. Downloading and converting from huggingface hub
.gitattributes: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.52k/1.52k [00:00<?, ?B/s]
NOTICE.md: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.77k/1.77k [00:00<?, ?B/s]
LICENSE: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.08k/1.08k [00:00<?, ?B/s]
README.md: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17.2k/17.2k [00:00<?, ?B/s]
sample_finetune.py: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 6.34k/6.34k [00:00<?, ?B/s]
modeling_phi3.py: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 73.8k/73.8k [00:00<00:00, 5.90MB/s]
SECURITY.md: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.66k/2.66k [00:00<?, ?B/s]
CODE_OF_CONDUCT.md: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 444/444 [00:00<?, ?B/s]
Fetching 19 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:05<00:00, 3.79it/s]
Saving checkpoint to checkpoints\microsoft\Phi-3-mini-4k-instruct\model.pth
Using device=privateuseone:0, is_llama_3=False, is_phi_2=False
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Running on local URL: http://127.0.0.1:7860
--------
To create a public link, set `share=True` in `launch()`.
Open http://localhost:7860 (or the local URL you see) in a browser to interact with the chatbot.
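The UI itself is served by Gradio. Purely for orientation, a Gradio chat app has roughly this shape (a hypothetical skeleton, not the sample's actual code; `respond` is a made-up stand-in for the model's generate loop):

```python
import gradio as gr

def respond(message, history):
    # Stand-in for model generation; app.py wires the real model in here.
    return f"echo: {message}"

gr.ChatInterface(respond).launch()  # serves http://127.0.0.1:7860 by default
```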
To run the model using `float32` precision, pass `--precision float32` to `app.py`:
> python app.py --precision float32
You can also select another model to run (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/phi-2`, `meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Llama-2-7b-chat-hf`, `mistralai/Mistral-7B-Instruct-v0.1`).
For example, to run `Mistral-7B-Instruct-v0.1`, use the following command:
> python app.py --precision float16 --hf_model "mistralai/Mistral-7B-Instruct-v0.1"
You should see output such as the following:
> python app.py --precision float16 --hf_model "mistralai/Mistral-7B-Instruct-v0.1"
checkpoints\mistralai\Mistral-7B-Instruct-v0.1\model.pth doesnt exist. Downloading and converting from huggingface hub
README.md: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3.90k/3.90k [00:00<?, ?B/s]
model.safetensors.index.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25.1k/25.1k [00:00<?, ?B/s]
config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 571/571 [00:00<?, ?B/s]
.gitattributes: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.52k/1.52k [00:00<00:00, 97.0kB/s]
special_tokens_map.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 72.0/72.0 [00:00<?, ?B/s]
pytorch_model.bin.index.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23.9k/23.9k [00:00<?, ?B/s]
generation_config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<?, ?B/s]
tokenizer_config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.47k/1.47k [00:00<?, ?B/s]
tokenizer.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.80M/1.80M [00:00<00:00, 8.51MB/s]
tokenizer.model: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 493k/493k [00:00<00:00, 15.8MB/s]
model-00002-of-00002.safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.54G/4.54G [04:42<00:00, 16.1MB/s]
pytorch_model-00002-of-00002.bin: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 5.06G/5.06G [04:51<00:00, 17.4MB/s]
model-00001-of-00002.safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 9.94G/9.94G [07:17<00:00, 22.7MB/s]
pytorch_model-00001-of-00002.bin: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 9.94G/9.94G [07:19<00:00, 22.6MB/s]
Fetching 14 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [07:22<00:00, 31.62s/it]
Saving checkpoint to checkpoints\mistralai\Mistral-7B-Instruct-v0.1\model.pth
Using device=privateuseone:0, is_llama_3=False, is_phi_2=False
Running on local URL: http://127.0.0.1:7860
To create a public link, set `share=True` in `launch()`.
The following is a list of the basic settings supported by `app.py`:
Flag | Description | Default |
---|---|---|
`--help` | Show this help message. | N/A |
`--hf_model` | Specify the model to download using its Hugging Face repository ID. | `microsoft/Phi-3-mini-4k-instruct` |
`--precision` | Model precision to use during generation. Options: [`float16`, `float32`] | `float16` |
`--checkpoint_path` | Path to the converted PyTorch model checkpoint. | `checkpoints/{hf_model}/model.pth` |
`--max_context_length` | Max prompt length including the history. If exceeded, the history is clipped starting from the first (user, assistant) pair (see the sketch below the table). | `1500` |
`--disable_history` | Disable the chat history during generation. | History enabled |
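As an illustration of the `--max_context_length` clipping described above, here is a hypothetical sketch (the sample's actual implementation counts tokens rather than characters and may differ in detail):

```python
def clip_history(history: list[tuple[str, str]], prompt: str, max_length: int):
    """Drop the oldest (user, assistant) pairs until the prompt plus history fits."""
    def length(h):
        # A real implementation would count tokens; characters keep the sketch simple.
        return len(prompt) + sum(len(u) + len(a) for u, a in h)
    while history and length(history) > max_length:
        history = history[1:]  # clip starting from the first (user, assistant) pair
    return history
```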
⚠️ NOTE: The app uses the checkpoint path to determine the correct transformer model to load, so the checkpoint path must include the Hugging Face model ID. For example:
checkpoints/microsoft/phi-2/model.pth
checkpoints/microsoft/Phi-3-mini-4k-instruct/model.pth
checkpoints/mistralai/Mistral-7B-v0.1/model.pth
checkpoints/mistralai/Mistral-7B-Instruct-v0.1/model.pth
checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth
checkpoints/meta-llama/Meta-Llama-3-8B/model.pth
checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth
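As a hypothetical illustration of the path-based lookup the note describes (the helper name is made up and the sample's actual parsing may differ):

```python
from pathlib import Path

def model_id_from_checkpoint(checkpoint_path: str) -> str:
    """Recover 'org/model' from a path like 'checkpoints/org/model/model.pth'."""
    parent = Path(checkpoint_path).parent
    return f"{parent.parent.name}/{parent.name}"

print(model_id_from_checkpoint("checkpoints/microsoft/phi-2/model.pth"))  # microsoft/phi-2
```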
This step is optional, as the `app.py` script in the Run the App section handles both downloading and optimizing a PyTorch model with DirectML. We offer two methods for preparing PyTorch models.

Download and convert a model yourself with the provided script:
> python .\scripts\download_and_convert.py --hf_model "microsoft/Phi-3-mini-4k-instruct"
After the model is downloaded and converted, you can pass the following parameter to `app.py` to run the language model:
> python app.py --hf_model "microsoft/Phi-3-mini-4k-instruct"
Download a DirectML optimized PyTorch model from the Microsoft Hugging Face repo:
1. cd checkpoints
2. git clone https://huggingface.co/{hf_model} {hf_model}
3. cd ../
After the model is downloaded, you can pass the following parameter to `app.py` to run the language model:
> python app.py --checkpoint_path "checkpoints/{hf_model}/model.pth"
- Phi-2 Hugging Face Repository
- Phi-3 Hugging Face Repository
- LLaMA 2 Hugging Face Repository
- LLaMA 3 Hugging Face Repository
- Mistral 7B Hugging Face Repository
- PyTorch gpt-fast Source Code
- DirectML-Optimized Phi-2 Hugging Face Repository: This sample uses the phi-2 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the MIT license. If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms.
- DirectML-Optimized Phi-3 Hugging Face Repository: This sample uses the phi-3 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the MIT license. If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms.
- DirectML-Optimized LLaMA 2 Hugging Face Repository: This sample uses the Llama-2 model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the LLAMA 2 COMMUNITY LICENSE AGREEMENT. For terms of use, please visit Llama 2 - Acceptable Use Policy - Meta AI. If you comply with the license and terms of use, you have the rights described therein. By using the Sample, you accept the terms.
- DirectML-Optimized Mistral 7B Hugging Face Repository: This sample uses the Mistral model, which has been optimized to work with PyTorch-DirectML. This model is licensed under the Apache-2.0 license. If you comply with the license, you have the rights described therein. By using the Sample, you accept the terms.