Reason for Change:
This PR introduces significant changes to the model access system, enabling model weights to be downloaded at runtime instead of being packaged into a container image. This is particularly useful for gated Hugging Face models such as Llama, for which Kaito currently requires users to build their own container image. With this change, users only need to provide a Hugging Face token as a Kubernetes secret.
No preset model uses this feature yet, but I plan to add Llama 3 and 4 as preset models in the future.
Waiting for the dependent changes to merge before marking this PR as ready.
Requirements
- [x] added unit tests and e2e tests (if applicable) - documented in https://github.com/kaito-project/kaito/pull/1035#issuecomment-2825108326
- [x] Manually tested on both vLLM and Transformers
Issue Fixed:
Fixes #982
Notes for Reviewers:
TODO:
- [ ] README on how to use this
The `mcr.microsoft.com/aks/kaito/kaito-base:0.0.1` image was validated with Llama 3.1 and the changes in this PR:
Transformers:

vLLM:


Workspace used:
```yaml
apiVersion: kaito.sh/v1alpha1
kind: Workspace
metadata:
  name: workspace-llama-3-1-8b-instruct
resource:
  instanceType: "Standard_NC48ads_A100_v4"
  labelSelector:
    matchLabels:
      kubernetes.azure.com/accelerator: nvidia
inference:
  preset:
    name: llama-3.1-8b-instruct
    presetOptions:
      modelAccessSecret: hf-token
```
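For reference, the `hf-token` secret referenced above can be created ahead of time. A minimal sketch (the `HF_TOKEN` key name follows the validation messages in this PR; the token value is a placeholder):

```yaml
apiVersion: v1
kind: Secret
metadata:
  name: hf-token
type: Opaque
stringData:
  HF_TOKEN: <your-hugging-face-access-token>
```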
Llama 3.1 `model.go`:
```go
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package llama3

import (
	"time"

	kaitov1beta1 "github.com/kaito-project/kaito/api/v1beta1"
	"github.com/kaito-project/kaito/pkg/model"
	"github.com/kaito-project/kaito/pkg/utils/plugin"
	"github.com/kaito-project/kaito/pkg/workspace/inference"
)

func init() {
	plugin.KaitoModelRegister.Register(&plugin.Registration{
		Name:     "llama-3.1-8b-instruct",
		Instance: &llama3A,
	})
}

var llama3A llama3_1_8bInstruct

type llama3_1_8bInstruct struct{}

func (*llama3_1_8bInstruct) GetInferenceParameters() *model.PresetParam {
	return &model.PresetParam{
		ModelFamilyName:           "LLaMa3",
		ImageAccessMode:           string(kaitov1beta1.ModelImageAccessModePublic),
		DiskStorageRequirement:    "20Gi",
		GPUCountRequirement:       "1",
		TotalGPUMemoryRequirement: "20Gi",
		PerGPUMemoryRequirement:   "20Gi",
		RuntimeParam: model.RuntimeParam{
			Transformers: model.HuggingfaceTransformersParam{
				BaseCommand:       "accelerate launch",
				TorchRunParams:    inference.DefaultAccelerateParams,
				InferenceMainFile: inference.DefaultTransformersMainFile,
				ModelRunParams: map[string]string{
					"allow_remote_files": "true",
				},
			},
			VLLM: model.VLLMParam{
				BaseCommand: inference.DefaultVLLMCommand,
				// empty means the served model name falls back to the Hugging Face repo ID
				ModelName:      "",
				ModelRunParams: map[string]string{},
			},
		},
		ReadinessTimeout: time.Duration(30) * time.Minute,
		WorldSize:        1,
		// base image with no model weights
		ImageName: "base",
		Tag:       "0.0.1",
	}
}

func (*llama3_1_8bInstruct) GetTuningParameters() *model.PresetParam {
	return nil // Fine-tuning is not currently supported.
}

func (*llama3_1_8bInstruct) GetDownloadParameters() *model.DownloadParam {
	return &model.DownloadParam{
		RepoId: "meta-llama/Llama-3.1-8B-Instruct",
	}
}

func (*llama3_1_8bInstruct) SupportDistributedInference() bool {
	return false
}

func (*llama3_1_8bInstruct) SupportTuning() bool {
	return false
}

func (*llama3_1_8bInstruct) RequireDownload() *model.DownloadRequirement {
	return &model.DownloadRequirement{
		Required:             true,
		AccessSecretRequired: true,
	}
}
```
Title
(Describe updated until commit https://github.com/kaito-project/kaito/commit/1694e904378648fa2c0cb26fd7967ece24e017c8)
Support preset model weight downloads with HuggingFace access tokens
Description
- Introduced ModelAccessSecret for HuggingFace token management.
- Added validation for ModelAccessSecret based on download requirements.
- Enhanced PresetParam to include metadata validation.
- Updated manifest generation to include the HF_TOKEN environment variable.
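As a sketch of what the injected environment variable could look like in the generated pod spec (the exact manifest shape here is an assumption; only the `HF_TOKEN` key and the secret reference come from this PR):

```yaml
env:
  - name: HF_TOKEN
    valueFrom:
      secretKeyRef:
        name: hf-token   # the secret named in presetOptions.modelAccessSecret
        key: HF_TOKEN
```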
Changes walkthrough 📝
Enhancement (9 files)

| File | Summary | Lines |
|---|---|---|
| `workspace_types.go` | Added ModelAccessSecret to PresetOptions | +3/-0 |
| `workspace_validation.go` | Added validation for ModelAccessSecret | +7/-0 |
| `interface.go` | Added Metadata validation and updated command building | +26/-0 |
| `common.go` | Added HuggingFace model version parsing | +46/-0 |
| `testModel.go` | Added test model with download at runtime | +44/-0 |
| `testUtils.go` | Added mock workspaces with preset download | +53/-0 |
| `preset-inferences.go` | Updated image selection and environment variable injection | +37/-9 |
| `manifests.go` | Updated manifest generation to include envVars | +9/-12 |
| `metadata.go` | Added metadata validation | +1/-0 |

Tests (5 files)

| File | Summary | Lines |
|---|---|---|
| `workspace_validation_test.go` | Added test cases for ModelAccessSecret | +72/-0 |
| `common_test.go` | Added tests for HuggingFace model version parsing | +111/-0 |
| `preset-inferences_test.go` | Updated test cases for preset inference with download | +126/-16 |
| `manifests_test.go` | Updated tests for manifest generation | +2/-0 |
| `metadata_test.go` | Added metadata package test file | +3/-0 |

Configuration changes (2 files)

| File | Summary | Lines |
|---|---|---|
| `kaito.sh_workspaces.yaml` | Added ModelAccessSecret to CRDs | +8/-0 |
| `kaito.sh_workspaces.yaml` | Added ModelAccessSecret to CRDs | +8/-0 |
PR Reviewer Guide 🔍
(Review updated until commit https://github.com/kaito-project/kaito/commit/1694e904378648fa2c0cb26fd7967ece24e017c8)
Here are some key observations to aid the review process:
⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests

🔒 Security concerns
Sensitive information exposure: The PR introduces a new field, ModelAccessSecret, naming the secret that holds the HuggingFace access token. Ensure that this secret is properly managed and not exposed in logs or error messages. Additionally, verify that the secret is only accessible to the necessary components and that it is encrypted at rest.
⚡ Recommended focus areas for review

Possible Issue
The GetInferenceImageInfo function sets imageName to "base" when presetObj.DownloadAtRuntime is true, but it uses metadata.MustGet(imageName).Tag to set imageTag. This will panic if no metadata entry named "base" exists.

```go
imageName = "base"
imageTag = metadata.MustGet(imageName).Tag
} else {
```

Error Handling
The ParseHuggingFaceModelVersion function returns an error if the URL is invalid, but the same generic error is returned when the URL is valid and the path simply does not match the expected format. Consider adding more specific error messages for these cases.

```go
parsedURL, err := url.Parse(version)
if err != nil {
	return "", "", err
}
if parsedURL.Host != "huggingface.co" {
	return "", "", fmt.Errorf(errInvalidModelVersionURL, version)
}
parts := strings.Split(strings.Trim(parsedURL.Path, "/"), "/")
switch len(parts) {
case 2: // Expected path: "<org>/<model>"
	repoId, revision = parts[0]+"/"+parts[1], ""
	return
case 4: // Expected path: "<org>/<model>/commit/<revision>"
	if parts[2] != "commit" {
		break
	}
	repoId, revision = parts[0]+"/"+parts[1], parts[3]
	return
}
return "", "", fmt.Errorf(errInvalidModelVersionURL, version)
```
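As a quick illustration of the two URL shapes this parser accepts, here is a self-contained sketch. `parseHFModelVersion` is a hypothetical stand-alone rework of the snippet above, with a literal error message in place of `errInvalidModelVersionURL`:

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// parseHFModelVersion mirrors the parsing behavior described above: it accepts
// "https://huggingface.co/<org>/<model>" and
// "https://huggingface.co/<org>/<model>/commit/<revision>".
func parseHFModelVersion(version string) (repoID, revision string, err error) {
	parsedURL, err := url.Parse(version)
	if err != nil {
		return "", "", err
	}
	if parsedURL.Host != "huggingface.co" {
		return "", "", fmt.Errorf("invalid model version URL: %s", version)
	}
	parts := strings.Split(strings.Trim(parsedURL.Path, "/"), "/")
	switch len(parts) {
	case 2: // "<org>/<model>"
		return parts[0] + "/" + parts[1], "", nil
	case 4: // "<org>/<model>/commit/<revision>"
		if parts[2] == "commit" {
			return parts[0] + "/" + parts[1], parts[3], nil
		}
	}
	return "", "", fmt.Errorf("invalid model version URL: %s", version)
}

func main() {
	repo, rev, err := parseHFModelVersion("https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/commit/abc123")
	fmt.Println(repo, rev, err) // meta-llama/Llama-3.1-8B-Instruct abc123 <nil>
}
```

Note that a pinned revision only survives the round trip when the path contains the literal `commit` segment; any other four-part path is rejected.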
PR Code Suggestions ✨
Latest suggestions up to 1694e90
Explore these optional code suggestions:
Security: Validate URL scheme in ParseHuggingFaceModelVersion (impact: Medium)
Consider adding a check for the URL scheme to ensure it is HTTPS.
`pkg/utils/common.go` [298-300]

```diff
-if parsedURL.Host != "huggingface.co" {
+if parsedURL.Scheme != "https" || parsedURL.Host != "huggingface.co" {
 	return "", "", fmt.Errorf(errInvalidModelVersionURL, version)
 }
```

Suggestion importance [1-10]: 8 — Why: Ensuring the URL scheme is HTTPS improves security by preventing potential man-in-the-middle attacks when fetching model versions.
Possible issue: Handle potential errors from metadata.Get (impact: Medium)
Ensure that the metadata.MustGet function call does not panic and handle potential errors gracefully.
`pkg/workspace/inference/preset-inferences.go` [130-131]

```diff
 imageName = "base"
-imageTag = metadata.MustGet(imageName).Tag
+meta, err := metadata.Get(imageName)
+if err != nil {
+	return "", nil, err
+}
+imageTag = meta.Tag
```

Suggestion importance [1-10]: 8 — Why: Handling potential errors from metadata.Get prevents the function from panicking and ensures graceful error handling.
General: Improve validation logic for ModelAccessSecret (impact: Medium)
Ensure that the validation logic correctly handles cases where ModelAccessSecret might be optional based on additional conditions or configurations.
`api/v1beta1/workspace_validation.go` [470-474]

```diff
 if params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret == "" {
 	errs = errs.Also(apis.ErrGeneric("This preset requires a modelAccessSecret with HF_TOKEN key under presetOptions to download the model"))
-} else if !params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret != "" {
+} else if !params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret != "" && !presetAllowsSecret(i.Preset) {
 	errs = errs.Also(apis.ErrGeneric("This preset does not require a modelAccessSecret with HF_TOKEN key under presetOptions"))
 }
```

Suggestion importance [1-10]: 7 — Why: The suggestion enhances the validation logic by considering additional conditions for ModelAccessSecret, making it more robust and flexible.
Previous suggestions
Suggestions up to commit cc3de15

Possible issue: Handle errors from ParseHuggingFaceModelVersion (impact: Medium)
Handle potential errors from ParseHuggingFaceModelVersion to prevent silent failures.
`pkg/workspace/inference/preset-inferences.go` [112]

```diff
-repoId, revision, _ := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
+repoId, revision, err := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
+if err != nil {
+	// Handle the error appropriately, e.g., log it or return an error
+	return fmt.Errorf("failed to parse model version: %w", err)
+}
```

Suggestion importance [1-10]: 8 — Why: Handling errors from ParseHuggingFaceModelVersion is crucial to prevent silent failures and ensures robustness in the code.

General: Improve validation logic for ModelAccessSecret (impact: Low)
Ensure that the validation logic correctly handles cases where the ModelAccessSecret might be optional based on additional conditions.
`api/v1beta1/workspace_validation.go` [470-474]

```diff
 if params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret == "" {
 	errs = errs.Also(apis.ErrGeneric("This preset requires a modelAccessSecret with HF_TOKEN key under presetOptions to download the model"))
 } else if !params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret != "" {
+	// Consider adding additional checks or logging here if the secret is unexpectedly provided
 	errs = errs.Also(apis.ErrGeneric("This preset does not require a modelAccessSecret with HF_TOKEN key under presetOptions"))
 }
```

Suggestion importance [1-10]: 3 — Why: The suggestion only adds a comment for additional checks or logging, which is a minor improvement and does not significantly impact the functionality.
Suggestions up to commit 9e0ac34

Possible issue: Handle errors from ParseHuggingFaceModelVersion (impact: Medium)
Handle potential errors from ParseHuggingFaceModelVersion to prevent silent failures.
`pkg/workspace/inference/preset-inferences.go` [112]

```diff
-repoId, revision, _ := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
+repoId, revision, err := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
+if err != nil {
+	// Handle the error appropriately, e.g., log it or return an error
+	log.Errorf("Failed to parse Hugging Face model version: %v", err)
+	return
+}
```

Suggestion importance [1-10]: 8 — Why: Handling potential errors from ParseHuggingFaceModelVersion prevents silent failures and ensures robustness in the code.

Possible issue: Handle errors from metadata.Get (impact: Medium)
Verify that metadata.MustGet(imageName) does not panic and handle potential errors gracefully.
`pkg/workspace/inference/preset-inferences.go` [152]

```diff
 imageName = "base"
-imageTag = metadata.MustGet(imageName).Tag
+metadataEntry, err := metadata.Get(imageName)
+if err != nil {
+	// Handle the error appropriately, e.g., log it or return an error
+	log.Errorf("Failed to get metadata for image %s: %v", imageName, err)
+	return
+}
+imageTag = metadataEntry.Tag
```

Suggestion importance [1-10]: 8 — Why: Verifying that metadata.Get(imageName) does not panic and handling potential errors gracefully improves the reliability of the code.

General: Clarify error messages for modelAccessSecret (impact: Low)
Ensure that the error messages clearly indicate the requirement for the modelAccessSecret based on the DownloadAtRuntime flag.
`api/v1beta1/workspace_validation.go` [470-474]

```diff
 if params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret == "" {
-	errs = errs.Also(apis.ErrGeneric("This preset requires a modelAccessSecret with HF_TOKEN key under presetOptions to download the model"))
+	errs = errs.Also(apis.ErrGeneric("This preset requires a modelAccessSecret with HF_TOKEN key under presetOptions to download the model at runtime"))
 } else if !params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret != "" {
-	errs = errs.Also(apis.ErrGeneric("This preset does not require a modelAccessSecret with HF_TOKEN key under presetOptions"))
+	errs = errs.Also(apis.ErrGeneric("This preset does not require a modelAccessSecret with HF_TOKEN key under presetOptions when not downloading at runtime"))
 }
```

Suggestion importance [1-10]: 5 — Why: The error messages are improved to clearly indicate the requirement for the modelAccessSecret based on the DownloadAtRuntime flag, enhancing clarity and user understanding.
Suggestions up to commit 1595afc

Possible issue: Handle parsing errors (impact: Medium)
Handle potential errors from ParseHuggingFaceModelVersion to prevent silent failures.
`pkg/workspace/inference/preset-inferences.go` [112]

```diff
-repoId, revision, _ := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
+repoId, revision, err := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
+if err != nil {
+	// Handle error appropriately, e.g., log it or return an error
+}
```

Suggestion importance [1-10]: 8 — Why: Handling potential errors from ParseHuggingFaceModelVersion is crucial to prevent silent failures and ensures robustness in the code.

General: Clarify modelAccessSecret requirement (impact: Low)
Ensure that the validation logic correctly handles cases where the modelAccessSecret is optional and not strictly required when DownloadAtRuntime is false.
`api/v1beta1/workspace_validation.go` [470-474]

```diff
 if params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret == "" {
 	errs = errs.Also(apis.ErrGeneric("This preset requires a modelAccessSecret with HF_TOKEN key under presetOptions to download the model"))
-} else if !params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret != "" {
-	errs = errs.Also(apis.ErrGeneric("This preset does not require a modelAccessSecret with HF_TOKEN key under presetOptions"))
 }
```

Suggestion importance [1-10]: 5 — Why: The suggestion clarifies the requirement for modelAccessSecret when DownloadAtRuntime is true, but removing the else if condition might lead to missing validation when DownloadAtRuntime is false and modelAccessSecret is provided, which could be intentional based on the current logic.

General: Verify command string accuracy (impact: Low)
Verify that the expected command string matches the actual command generated by the system.
`pkg/workspace/inference/preset-inferences_test.go` [192]

```diff
 expectedCmd: "/bin/sh -c python3 /workspace/vllm/inference_api.py --gpu-memory-utilization=0.90 --kaito-config-file=/mnt/config/inference_config.yaml --model=test-repo/test-model --code-revision=test-revision --tensor-parallel-size=2",
+// Consider adding a comment or logging to verify the command string
```

Suggestion importance [1-10]: 2 — Why: While verifying the command string is good practice, the suggestion does not provide actionable changes and only asks for verification, reducing its impact.
Suggestions up to commit 56b8155

Possible issue: Handle parsing errors (impact: Medium)
Handle potential errors from ParseHuggingFaceModelVersion to prevent silent failures.
`pkg/workspace/inference/preset-inferences.go` [112]

```diff
-repoId, revision, _ := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
+repoId, revision, err := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
+if err != nil {
+	// Log the error or handle it appropriately
+	log.Printf("Error parsing Hugging Face model version: %v", err)
+	// Optionally, return an error or take corrective action
+	return
+}
```

Suggestion importance [1-10]: 8 — Why: Handling errors from ParseHuggingFaceModelVersion is crucial to prevent silent failures and improve robustness. This suggestion enhances the code's reliability.

General: Review ModelAccessSecret validation logic (impact: Low)
Ensure that the validation logic correctly handles cases where ModelAccessSecret might be optional based on additional conditions or configurations.
`api/v1beta1/workspace_validation.go` [470-474]

```diff
 if params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret == "" {
 	errs = errs.Also(apis.ErrGeneric("This preset requires a modelAccessSecret with HF_TOKEN key under presetOptions to download the model"))
 } else if !params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret != "" {
+	// Consider adding a condition to check if ModelAccessSecret is unexpectedly provided
+	// when DownloadAtRuntime is false, but only if it's not needed for other reasons.
+	// For now, this remains as is, but review the necessity of ModelAccessSecret in non-download scenarios.
 	errs = errs.Also(apis.ErrGeneric("This preset does not require a modelAccessSecret with HF_TOKEN key under presetOptions"))
 }
```

Suggestion importance [1-10]: 5 — Why: The suggestion asks for a review of the validation logic, which is important but does not provide a concrete improvement. It keeps the existing logic intact with a comment for further consideration.

General: Verify expected command (impact: Low)
Verify that the expected command string matches the actual command generated by the system.
`pkg/workspace/inference/preset-inferences_test.go` [192]

```diff
 expectedCmd: "/bin/sh -c python3 /workspace/vllm/inference_api.py --gpu-memory-utilization=0.90 --kaito-config-file=/mnt/config/inference_config.yaml --model=test-repo/test-model --code-revision=test-revision --tensor-parallel-size=2",
+// Add a comment or assertion to verify the correctness of the command string
+// assert.Equal(t, expectedCmd, generatedCmd)
```

Suggestion importance [1-10]: 3 — Why: While verifying the expected command is good practice, the suggestion only adds a comment without implementing a concrete verification step. It offers minimal improvement.
Converted back to draft since this PR depends on other changes I am still working on.