feat: support preset model weight downloads

Open chewong opened this issue 10 months ago • 7 comments

Reason for Change:

This PR introduces significant changes to the model access system so that model weights can be downloaded at runtime instead of being packaged in the container image. This is particularly useful for gated Hugging Face projects such as Llama, for which Kaito currently requires users to build their own container image. This PR removes that requirement: users only need to provide an HF token as a Kubernetes secret.

At the moment, no preset model uses this feature, but I plan to add Llama 3 and Llama 4 as preset models in the future.

Waiting for the merge before marking this PR as ready.

Requirements

  • [x] added unit tests and e2e tests (if applicable) - documented in https://github.com/kaito-project/kaito/pull/1035#issuecomment-2825108326
  • [x] Manually tested on both vLLM and Transformers

Issue Fixed:

Fixes #982

Notes for Reviewers:

TODO:

  • [ ] README on how to use this

chewong, Apr 22 '25 17:04

Validated mcr.microsoft.com/aks/kaito/kaito-base:0.0.1 with Llama 3.1 and the changes in this PR:

Transformers: [screenshot]

vLLM: [two screenshots]

Workspace used:

apiVersion: kaito.sh/v1alpha1
kind: Workspace
metadata:
  name: workspace-llama-3-1-8b-instruct
resource:
  instanceType: "Standard_NC48ads_A100_v4"
  labelSelector:
    matchLabels:
      kubernetes.azure.com/accelerator: nvidia
inference:
  preset:
    name: llama-3.1-8b-instruct
    presetOptions:
      modelAccessSecret: hf-token

Llama 3.1 model.go

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package llama3

import (
	"time"

	kaitov1beta1 "github.com/kaito-project/kaito/api/v1beta1"
	"github.com/kaito-project/kaito/pkg/model"
	"github.com/kaito-project/kaito/pkg/utils/plugin"
	"github.com/kaito-project/kaito/pkg/workspace/inference"
)

func init() {
	plugin.KaitoModelRegister.Register(&plugin.Registration{
		Name:     "llama-3.1-8b-instruct",
		Instance: &llama3A,
	})
}

var llama3A llama3_1_8bInstruct

type llama3_1_8bInstruct struct{}

func (*llama3_1_8bInstruct) GetInferenceParameters() *model.PresetParam {
	return &model.PresetParam{
		ModelFamilyName:           "LLaMa3",
		ImageAccessMode:           string(kaitov1beta1.ModelImageAccessModePublic),
		DiskStorageRequirement:    "20Gi",
		GPUCountRequirement:       "1",
		TotalGPUMemoryRequirement: "20Gi",
		PerGPUMemoryRequirement:   "20Gi",
		RuntimeParam: model.RuntimeParam{
			Transformers: model.HuggingfaceTransformersParam{
				BaseCommand:       "accelerate launch",
				TorchRunParams:    inference.DefaultAccelerateParams,
				InferenceMainFile: inference.DefaultTransformersMainFile,
				ModelRunParams: map[string]string{
					"allow_remote_files": "true",
				},
			},
			VLLM: model.VLLMParam{
				BaseCommand: inference.DefaultVLLMCommand,
				// empty: the served model name falls back to the Hugging Face repo ID
				ModelName:      "",
				ModelRunParams: map[string]string{},
			},
		},
		ReadinessTimeout: time.Duration(30) * time.Minute,
		WorldSize:        1,
		// base image with no model weights
		ImageName: "base",
		Tag:       "0.0.1",
	}
}

func (*llama3_1_8bInstruct) GetTuningParameters() *model.PresetParam {
	return nil // Currently doesn't support fine-tuning
}
func (*llama3_1_8bInstruct) GetDownloadParameters() *model.DownloadParam {
	return &model.DownloadParam{
		RepoId: "meta-llama/Llama-3.1-8B-Instruct",
	}
}
func (*llama3_1_8bInstruct) SupportDistributedInference() bool {
	return false
}
func (*llama3_1_8bInstruct) SupportTuning() bool {
	return false
}
func (*llama3_1_8bInstruct) RequireDownload() *model.DownloadRequirement {
	return &model.DownloadRequirement{
		Required:             true,
		AccessSecretRequired: true,
	}
}
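A stdlib-only sketch of how the init() registration and the RequireDownload flags above fit together. The registry, preset, llama31 and downloadRequirement types are hypothetical, simplified stand-ins for plugin.KaitoModelRegister and model.DownloadRequirement; the real plugin API may differ.

```go
package main

import (
	"fmt"
	"sync"
)

// downloadRequirement mirrors the two flags returned by RequireDownload in
// model.go above (hypothetical local copy).
type downloadRequirement struct {
	Required             bool
	AccessSecretRequired bool
}

// preset is a hypothetical, trimmed-down model interface.
type preset interface {
	RequireDownload() *downloadRequirement
}

// registry is a hypothetical name-keyed preset registry guarded by a mutex.
type registry struct {
	mu      sync.RWMutex
	presets map[string]preset
}

func (r *registry) register(name string, p preset) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.presets[name] = p
}

func (r *registry) get(name string) (preset, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	p, ok := r.presets[name]
	return p, ok
}

var modelRegistry = &registry{presets: map[string]preset{}}

type llama31 struct{}

func (*llama31) RequireDownload() *downloadRequirement {
	return &downloadRequirement{Required: true, AccessSecretRequired: true}
}

// init registers the preset at program start, as model.go does.
func init() {
	modelRegistry.register("llama-3.1-8b-instruct", &llama31{})
}

func main() {
	p, ok := modelRegistry.get("llama-3.1-8b-instruct")
	if !ok {
		panic("preset not registered")
	}
	req := p.RequireDownload()
	fmt.Println(req.Required, req.AccessSecretRequired) // true true
}
```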

chewong, Apr 23 '25 18:04

Failed to generate code suggestions for PR

kaito-pr-agent[bot], Apr 23 '25 18:04

Title

(Description updated up to commit https://github.com/kaito-project/kaito/commit/1694e904378648fa2c0cb26fd7967ece24e017c8)

Support preset model weight downloads with HuggingFace access tokens


Description

  • Introduced ModelAccessSecret for HuggingFace token management.

  • Added validation for ModelAccessSecret based on download requirements.

  • Enhanced PresetParam to include metadata validation.

  • Updated manifest generation to include HF_TOKEN environment variable.


Changes walkthrough 📝

Relevant files

Enhancement (9 files)
  • workspace_types.go: Added ModelAccessSecret to PresetOptions (+3/-0)
  • workspace_validation.go: Added validation for ModelAccessSecret (+7/-0)
  • interface.go: Added Metadata validation and updated command building (+26/-0)
  • common.go: Added HuggingFace model version parsing (+46/-0)
  • testModel.go: Added test model with download at runtime (+44/-0)
  • testUtils.go: Added mock workspaces with preset download (+53/-0)
  • preset-inferences.go: Updated image selection and environment variable injection (+37/-9)
  • manifests.go: Updated manifest generation to include envVars (+9/-12)
  • metadata.go: Added metadata validation (+1/-0)

Tests (5 files)
  • workspace_validation_test.go: Added test cases for ModelAccessSecret (+72/-0)
  • common_test.go: Added tests for HuggingFace model version parsing (+111/-0)
  • preset-inferences_test.go: Updated test cases for preset inference with download (+126/-16)
  • manifests_test.go: Updated tests for manifest generation (+2/-0)
  • metadata_test.go: Added metadata package test file (+3/-0)

Configuration changes (2 files)
  • kaito.sh_workspaces.yaml: Added ModelAccessSecret to CRDs (+8/-0)
  • kaito.sh_workspaces.yaml: Added ModelAccessSecret to CRDs (+8/-0)

kaito-pr-agent[bot], Apr 24 '25 17:04

    PR Reviewer Guide 🔍

    (Review updated until commit https://github.com/kaito-project/kaito/commit/1694e904378648fa2c0cb26fd7967ece24e017c8)

    Here are some key observations to aid the review process:

    ⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
    🧪 PR contains tests
    🔒 Security concerns

    Sensitive information exposure:
    The PR introduces a new field ModelAccessSecret to store the HuggingFace access token. Ensure that this secret is properly managed and not exposed in logs or error messages. Additionally, verify that the secret is only accessible to the necessary components and that it is encrypted at rest.

    ⚡ Recommended focus areas for review

    Possible Issue

    The GetInferenceImageInfo function sets imageName to "base" when presetObj.DownloadAtRuntime is true and then reads the tag with metadata.MustGet(imageName).Tag. If no metadata entry named "base" exists, MustGet panics instead of returning an error.

    	imageName = "base"
    	imageTag = metadata.MustGet(imageName).Tag
    } else {
    
    Error Handling

    Apart from passing url.Parse errors through unchanged, the ParseHuggingFaceModelVersion function reports every other failure (a host other than huggingface.co, or a path that matches neither expected format) with the same generic errInvalidModelVersionURL message, so callers cannot tell which check failed. Consider adding more specific error messages for these cases.

    parsedURL, err := url.Parse(version)
    if err != nil {
    	return "", "", err
    }
    
    if parsedURL.Host != "huggingface.co" {
    	return "", "", fmt.Errorf(errInvalidModelVersionURL, version)
    }
    
    parts := strings.Split(strings.Trim(parsedURL.Path, "/"), "/")
    switch len(parts) {
    case 2: // Expected path: "<org>/<model>"
    	repoId, revision = parts[0]+"/"+parts[1], ""
    	return
    case 4: // Expected path: "<org>/<model>/commit/<revision>"
    	if parts[2] != "commit" {
    		break
    	}
    	repoId, revision = parts[0]+"/"+parts[1], parts[3]
    	return
    }
    
    return "", "", fmt.Errorf(errInvalidModelVersionURL, version)
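To make the error-handling point concrete, here is a self-contained sketch that keeps the excerpt's parsing rules but returns a distinct message for each failure (unparsable URL, wrong host, unexpected path). parseHuggingFaceModelVersion and its messages are illustrative, not the PR's code; the real function uses the errInvalidModelVersionURL format string.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// parseHuggingFaceModelVersion follows the excerpt's rules but reports each
// failure mode with its own message (illustrative wording).
func parseHuggingFaceModelVersion(version string) (repoID, revision string, err error) {
	u, err := url.Parse(version)
	if err != nil {
		return "", "", fmt.Errorf("invalid model version URL %q: %w", version, err)
	}
	if u.Host != "huggingface.co" {
		return "", "", fmt.Errorf("model version URL %q: host must be huggingface.co, got %q", version, u.Host)
	}
	parts := strings.Split(strings.Trim(u.Path, "/"), "/")
	switch {
	case len(parts) == 2: // "<org>/<model>"
		return parts[0] + "/" + parts[1], "", nil
	case len(parts) == 4 && parts[2] == "commit": // "<org>/<model>/commit/<revision>"
		return parts[0] + "/" + parts[1], parts[3], nil
	default:
		return "", "", fmt.Errorf("model version URL %q: path must be <org>/<model> or <org>/<model>/commit/<revision>", version)
	}
}

func main() {
	for _, v := range []string{
		"https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct",
		"https://huggingface.co/test-repo/test-model/commit/test-revision",
		"https://example.com/org/model",
		"https://huggingface.co/org/model/tree/main",
	} {
		repo, rev, err := parseHuggingFaceModelVersion(v)
		fmt.Printf("repo=%q rev=%q err=%v\n", repo, rev, err)
	}
}
```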
    

    kaito-pr-agent[bot], Apr 24 '25 17:04

    PR Code Suggestions ✨

    Latest suggestions up to 1694e90

    Explore these optional code suggestions:

    Category | Suggestion | Impact
    Security
    Validate URL scheme in ParseHuggingFaceModelVersion

    Consider adding a check for the URL scheme to ensure it is HTTPS.

    pkg/utils/common.go [298-300]

    -if parsedURL.Host != "huggingface.co" {
    +if parsedURL.Scheme != "https" || parsedURL.Host != "huggingface.co" {
         return "", "", fmt.Errorf(errInvalidModelVersionURL, version)
     }
    
    Suggestion importance[1-10]: 8

    Why: Ensuring the URL scheme is HTTPS improves security by preventing potential man-in-the-middle attacks when fetching model versions.

    Medium
    Possible issue
    Handle potential errors from metadata.Get

    Ensure that the metadata.MustGet function call does not panic and handle potential
    errors gracefully.

    pkg/workspace/inference/preset-inferences.go [130-131]

     imageName = "base"
    -imageTag = metadata.MustGet(imageName).Tag
    +meta, err := metadata.Get(imageName)
    +if err != nil {
    +    return "", nil, err
    +}
    +imageTag = meta.Tag
    
    Suggestion importance[1-10]: 8

    Why: Handling potential errors from metadata.Get prevents the function from panicking and ensures graceful error handling.

    Medium
    General
    Improve validation logic for ModelAccessSecret

    Ensure that the validation logic correctly handles cases where ModelAccessSecret
    might be optional based on additional conditions or configurations.

    api/v1beta1/workspace_validation.go [470-474]

     if params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret == "" {
         errs = errs.Also(apis.ErrGeneric("This preset requires a modelAccessSecret with HF_TOKEN key under presetOptions to download the model"))
    -} else if !params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret != "" {
    +} else if !params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret != "" && !presetAllowsSecret(i.Preset) {
         errs = errs.Also(apis.ErrGeneric("This preset does not require a modelAccessSecret with HF_TOKEN key under presetOptions"))
     }
    
    Suggestion importance[1-10]: 7

    Why: The suggestion enhances the validation logic by considering additional conditions for ModelAccessSecret, making it more robust and flexible.

    Medium
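The suggestions above keep asking whether the secret can be optional. One way to express that is sketched below, assuming the validator can see both flags returned by RequireDownload in the Llama 3.1 model.go (Required and AccessSecretRequired) rather than only the DownloadAtRuntime flag in the diff. downloadRequirement and validateModelAccessSecret are hypothetical names.

```go
package main

import (
	"errors"
	"fmt"
)

// downloadRequirement is a hypothetical stand-in for the RequireDownload flags.
type downloadRequirement struct {
	Required             bool // weights are downloaded at runtime
	AccessSecretRequired bool // the download needs an HF token
}

// validateModelAccessSecret requires a secret only when the download needs a
// token, and rejects one for presets whose weights are baked into the image.
func validateModelAccessSecret(req downloadRequirement, secret string) error {
	switch {
	case req.Required && req.AccessSecretRequired && secret == "":
		return errors.New("this preset requires a modelAccessSecret with an HF_TOKEN key under presetOptions to download the model")
	case !req.Required && secret != "":
		return errors.New("this preset does not download weights at runtime and does not accept a modelAccessSecret")
	default:
		return nil
	}
}

func main() {
	cases := []struct {
		req    downloadRequirement
		secret string
	}{
		{downloadRequirement{Required: true, AccessSecretRequired: true}, ""},
		{downloadRequirement{Required: true, AccessSecretRequired: true}, "hf-token"},
		{downloadRequirement{Required: true}, ""},
		{downloadRequirement{}, "hf-token"},
	}
	for _, c := range cases {
		fmt.Printf("%+v secret=%q -> %v\n", c.req, c.secret, validateModelAccessSecret(c.req, c.secret))
	}
}
```

The third case is the one the suggestions hint at: a runtime download of a public model that needs no token passes without a secret.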

    Previous suggestions

    Suggestions up to commit cc3de15
    Category | Suggestion | Impact
    Possible issue
    Handle errors from ParseHuggingFaceModelVersion

    Handle potential errors from ParseHuggingFaceModelVersion to prevent silent
    failures.

    pkg/workspace/inference/preset-inferences.go [112]

    -repoId, revision, _ := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
    +repoId, revision, err := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
    +if err != nil {
    +    // Handle the error appropriately, e.g., log it or return an error
    +    return fmt.Errorf("failed to parse model version: %w", err)
    +}
    
    Suggestion importance[1-10]: 8

    Why: Handling errors from ParseHuggingFaceModelVersion is crucial to prevent silent failures and ensures robustness in the code.

    Medium
    General
    Improve validation logic for ModelAccessSecret

    Ensure that the validation logic correctly handles cases where the ModelAccessSecret
    might be optional based on additional conditions.

    api/v1beta1/workspace_validation.go [470-474]

     if params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret == "" {
         errs = errs.Also(apis.ErrGeneric("This preset requires a modelAccessSecret with HF_TOKEN key under presetOptions to download the model"))
     } else if !params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret != "" {
    +    // Consider adding additional checks or logging here if the secret is unexpectedly provided
         errs = errs.Also(apis.ErrGeneric("This preset does not require a modelAccessSecret with HF_TOKEN key under presetOptions"))
     }
    
    Suggestion importance[1-10]: 3

    Why: The suggestion only adds a comment about possible extra checks or logging; it is a minor improvement that does not change behavior.

    Low
    Suggestions up to commit 9e0ac34
    Category | Suggestion | Impact
    Possible issue
    Handle errors from ParseHuggingFaceModelVersion

    Handle potential errors from ParseHuggingFaceModelVersion to prevent silent
    failures.

    pkg/workspace/inference/preset-inferences.go [112]

    -repoId, revision, _ := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
    +repoId, revision, err := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
    +if err != nil {
    +    // Handle the error appropriately, e.g., log it or return an error
    +    log.Errorf("Failed to parse Hugging Face model version: %v", err)
    +    return
    +}
    
    Suggestion importance[1-10]: 8

    Why: Handling potential errors from ParseHuggingFaceModelVersion prevents silent failures and ensures robustness in the code.

    Medium
    Handle errors from metadata.Get

    Verify that metadata.MustGet(imageName) does not panic and handle potential errors
    gracefully.

    pkg/workspace/inference/preset-inferences.go [152]

     imageName = "base"
    -imageTag = metadata.MustGet(imageName).Tag
    +metadataEntry, err := metadata.Get(imageName)
    +if err != nil {
    +    // Handle the error appropriately, e.g., log it or return an error
    +    log.Errorf("Failed to get metadata for image %s: %v", imageName, err)
    +    return
    +}
    +imageTag = metadataEntry.Tag
    
    Suggestion importance[1-10]: 8

    Why: Replacing metadata.MustGet with metadata.Get and handling its error avoids a panic and improves the reliability of the code.

    Medium
    General
    Clarify error messages for modelAccessSecret

    Ensure that the error messages clearly indicate the requirement for the
    modelAccessSecret based on the DownloadAtRuntime flag.

    api/v1beta1/workspace_validation.go [470-474]

     if params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret == "" {
    -    errs = errs.Also(apis.ErrGeneric("This preset requires a modelAccessSecret with HF_TOKEN key under presetOptions to download the model"))
    +    errs = errs.Also(apis.ErrGeneric("This preset requires a modelAccessSecret with HF_TOKEN key under presetOptions to download the model at runtime"))
     } else if !params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret != "" {
    -    errs = errs.Also(apis.ErrGeneric("This preset does not require a modelAccessSecret with HF_TOKEN key under presetOptions"))
    +    errs = errs.Also(apis.ErrGeneric("This preset does not require a modelAccessSecret with HF_TOKEN key under presetOptions when not downloading at runtime"))
     }
    
    Suggestion importance[1-10]: 5

    Why: The error messages are improved to clearly indicate the requirement for the modelAccessSecret based on the DownloadAtRuntime flag, enhancing clarity and user understanding.

    Low
    Suggestions up to commit 1595afc
    Category | Suggestion | Impact
    Possible issue
    Handle parsing errors

    Handle potential errors from ParseHuggingFaceModelVersion to prevent silent
    failures.

    pkg/workspace/inference/preset-inferences.go [112]

    -repoId, revision, _ := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
    +repoId, revision, err := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
    +if err != nil {
    +    // Handle error appropriately, e.g., log it or return an error
    +}
    
    Suggestion importance[1-10]: 8

    Why: Handling potential errors from ParseHuggingFaceModelVersion is crucial to prevent silent failures and ensures robustness in the code.

    Medium
    General
    Clarify modelAccessSecret requirement

    Ensure that the validation logic correctly handles cases where the modelAccessSecret
    is optional and not strictly required when DownloadAtRuntime is false.

    api/v1beta1/workspace_validation.go [470-474]

     if params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret == "" {
         errs = errs.Also(apis.ErrGeneric("This preset requires a modelAccessSecret with HF_TOKEN key under presetOptions to download the model"))
    -} else if !params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret != "" {
    -    errs = errs.Also(apis.ErrGeneric("This preset does not require a modelAccessSecret with HF_TOKEN key under presetOptions"))
     }
    
    Suggestion importance[1-10]: 5

    Why: The suggestion makes the secret mandatory only when DownloadAtRuntime is true. However, removing the else-if branch also drops the error for presets that do not download at runtime but still set modelAccessSecret, a case the current logic may reject intentionally.

    Low
    Verify command string accuracy

    Verify that the expected command string matches the actual command generated by the
    system.

    pkg/workspace/inference/preset-inferences_test.go [192]

     expectedCmd: "/bin/sh -c python3 /workspace/vllm/inference_api.py --gpu-memory-utilization=0.90 --kaito-config-file=/mnt/config/inference_config.yaml --model=test-repo/test-model --code-revision=test-revision --tensor-parallel-size=2",
    +// Consider adding a comment or logging to verify the command string
    
    Suggestion importance[1-10]: 2

    Why: While verifying the command string is good practice, the suggestion does not provide actionable changes and only asks for verification, reducing its impact.

    Low
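If a command is assembled from a map[string]string such as ModelRunParams, its flag order must be made deterministic before it can be compared against a fixed expectedCmd, because Go randomizes map iteration order. The sketch below sorts the keys to achieve that. buildCommand is hypothetical, and the expectedCmd above is not in alphabetical order, so the real builder evidently uses its own ordering; the point is only that some fixed order is needed.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// buildCommand (hypothetical) renders "--key=value" flags from a map after
// sorting the keys, so the same input always yields the same string.
func buildCommand(base string, params map[string]string) string {
	keys := make([]string, 0, len(params))
	for k := range params {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	parts := []string{base}
	for _, k := range keys {
		parts = append(parts, fmt.Sprintf("--%s=%s", k, params[k]))
	}
	return strings.Join(parts, " ")
}

func main() {
	got := buildCommand("python3 /workspace/vllm/inference_api.py", map[string]string{
		"tensor-parallel-size": "2",
		"model":                "test-repo/test-model",
		"code-revision":        "test-revision",
	})
	want := "python3 /workspace/vllm/inference_api.py --code-revision=test-revision --model=test-repo/test-model --tensor-parallel-size=2"
	if got != want {
		panic(fmt.Sprintf("command mismatch:\n got: %s\nwant: %s", got, want))
	}
	fmt.Println("ok")
}
```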
    Suggestions up to commit 56b8155
    Category | Suggestion | Impact
    Possible issue
    Handle parsing errors

    Handle potential errors from ParseHuggingFaceModelVersion to prevent silent
    failures.

    pkg/workspace/inference/preset-inferences.go [112]

    -repoId, revision, _ := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
    +repoId, revision, err := utils.ParseHuggingFaceModelVersion(inferenceParams.Version)
    +if err != nil {
    +    // Log the error or handle it appropriately
    +    log.Printf("Error parsing Hugging Face model version: %v", err)
    +    // Optionally, return an error or take corrective action
    +    return
    +}
    
    Suggestion importance[1-10]: 8

    Why: Handling errors from ParseHuggingFaceModelVersion is crucial to prevent silent failures and improve robustness. This suggestion enhances the code's reliability.

    Medium
    General
    Review ModelAccessSecret validation logic

    Ensure that the validation logic correctly handles cases where ModelAccessSecret
    might be optional based on additional conditions or configurations.

    api/v1beta1/workspace_validation.go [470-474]

     if params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret == "" {
         errs = errs.Also(apis.ErrGeneric("This preset requires a modelAccessSecret with HF_TOKEN key under presetOptions to download the model"))
     } else if !params.DownloadAtRuntime && i.Preset.PresetOptions.ModelAccessSecret != "" {
    +    // Consider adding a condition to check if ModelAccessSecret is unexpectedly provided
    +    // when DownloadAtRuntime is false, but only if it's not needed for other reasons.
    +    // For now, this remains as is, but review the necessity of ModelAccessSecret in non-download scenarios.
         errs = errs.Also(apis.ErrGeneric("This preset does not require a modelAccessSecret with HF_TOKEN key under presetOptions"))
     }
    
    Suggestion importance[1-10]: 5

    Why: The suggestion asks for a review of the validation logic, which is worthwhile, but it proposes no concrete change; the existing logic stays as is, with a comment for further consideration.

    Low
    Verify expected command

    Verify that the expected command string matches the actual command generated by the
    system.

    pkg/workspace/inference/preset-inferences_test.go [192]

     expectedCmd: "/bin/sh -c python3 /workspace/vllm/inference_api.py --gpu-memory-utilization=0.90 --kaito-config-file=/mnt/config/inference_config.yaml --model=test-repo/test-model --code-revision=test-revision --tensor-parallel-size=2",
    +// Add a comment or assertion to verify the correctness of the command string
    +// assert.Equal(t, expectedCmd, generatedCmd)
    
    Suggestion importance[1-10]: 3

    Why: While verifying the expected command is good practice, the suggestion only adds a comment without implementing a concrete verification step. It offers minimal improvement.

    Low

    kaito-pr-agent[bot], Apr 24 '25 17:04

    Converted back to draft since there are other changes I am working on that this PR depends on.

    chewong, Apr 24 '25 21:04

    Rebased

    chewong, Apr 29 '25 23:04