add base_model to transform component

Open sapphire008 opened this issue 3 years ago • 5 comments

Summary

Add a base_model channel to the Transform component so that preprocessing_fn can access additional model artifacts during transformation through custom_config["base_model"]. This design is similar to the base_model argument of the Trainer component.
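A minimal sketch of the intended behavior, in plain Python with no TFX dependency. The helper name resolve_custom_config, and the assumption that the component would copy the resolved artifact URI into custom_config under the "base_model" key, are hypothetical and only illustrate the proposal; this is not the actual TFX implementation.

```python
from typing import Any, Dict, Optional


def resolve_custom_config(
    custom_config: Optional[Dict[str, Any]],
    base_model_uri: Optional[str],
) -> Dict[str, Any]:
    """Hypothetical helper: merge a resolved base_model URI into custom_config."""
    # Copy so the caller's dictionary is never mutated.
    resolved = dict(custom_config or {})
    if base_model_uri is not None:
        # preprocessing_fn would read this key to locate the SavedModel.
        resolved["base_model"] = base_model_uri
    return resolved


# "path/to/base_model" is a placeholder, not a real artifact location.
config = resolve_custom_config({"vocab_size": 1000}, "path/to/base_model")
print(config)
```

When no base model is supplied, the sketch leaves custom_config unchanged, so existing preprocessing_fn implementations would not be affected.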

Example use cases

Apply pre-computed feature segmentations

def preprocessing_fn(inputs, custom_config):
    outputs = {}
    # load the feature segmentation base model
    with tf.init_scope():
        base_model = tf.saved_model.load(custom_config["base_model"])
    
    outputs["user_segments"] = base_model({
        "age": inputs["age"], 
        "gender": inputs["gender"],
        "location": inputs["location"],
    })
    ...

Apply embeddings of categorical variables trained from another model in the pipeline

import os
import pprint
import tempfile
import math
from typing import Dict

import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_hub as hub

import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils

from apache_beam.transforms import util

from tensorflow.python.keras.models import Model, load_model

raw_data = [
      {'title': ['sterling silver necklace'], 'query': ['silver necklace']},
      {'title': ['gold bracelet'], 'query': ['silver necklace']},
      {'title': ['opal'], 'query': ['silver necklace']}
  ]

raw_feature_spec = {
    'title': tf.io.VarLenFeature(tf.string),
    'query': tf.io.VarLenFeature(tf.string),
}

raw_data_metadata = dataset_metadata.DatasetMetadata(
  schema_utils.schema_from_feature_spec(raw_feature_spec)
)


def get_embedding_similarity(input_1, input_2, embed):
    @tf.function
    def embed_fn(a, b):
        text_1_embedding = embed(a)
        text_2_embedding = embed(b)

        text_1_normalized = tf.nn.l2_normalize(text_1_embedding, axis=-1)
        text_2_normalized = tf.nn.l2_normalize(text_2_embedding, axis=-1)
        cosine_distance = tf.reduce_sum(
            tf.multiply(text_1_normalized, text_2_normalized), axis=-1
        )
        clip_cosine_similarities = tf.clip_by_value(cosine_distance, -1.0, 1.0)
        cosine_scores = 1.0 - tf.acos(clip_cosine_similarities) / math.pi

        l2_norm = tf.norm(text_1_normalized - text_2_normalized, axis=-1, ord='euclidean')

        return cosine_scores, l2_norm
    return embed_fn(input_1, input_2)

def impute(feature_tensor, default):
    sparse = tf.sparse.SparseTensor(
        feature_tensor.indices,
        feature_tensor.values,
        [feature_tensor.dense_shape[0], 1],
    )
    dense = tf.sparse.to_dense(sp_input=sparse, default_value=default)

    return tf.squeeze(dense, axis=1)

def text_feature_transform(feature_dict: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
    outputs = dict()
    for key in feature_dict.keys():
        feature = impute(feature_dict[key], "")
        outputs[key] = feature
    return outputs

def preprocessing_fn(inputs, custom_config):
    inputs_text_transform = {key: inputs[key] for key in ['query','title']}
    outputs_text = text_feature_transform(inputs_text_transform)

    outputs = {
        **outputs_text,
    }

    with tf.init_scope():
        embed = tf.saved_model.load(custom_config["base_model"])

    cos_dist, euc_dist = get_embedding_similarity(impute(inputs['query'],""), impute(inputs['title'],""), embed)
    outputs["query_title_nnlm_cos_dist"] = cos_dist
    outputs["query_title_nnlm_euc_dist"] = euc_dist

    return outputs

def run_tft_pipeline():
    custom_config = {"base_model": "gs://mybucekt/custom_embedding_model/"}
    # Bind custom_config under a new name; rebinding preprocessing_fn here would
    # make the lambda call itself and recurse forever.
    bound_preprocessing_fn = lambda inputs: preprocessing_fn(inputs, custom_config)
    temp_dir = tempfile.mkdtemp()
    os.environ['TFHUB_CACHE_DIR'] = os.path.join(temp_dir, "tfhub_modules")
    with tft_beam.Context(temp_dir=temp_dir, force_tf_compat_v1=False):
        transformed_dataset, transform_fn = (
            (raw_data, raw_data_metadata)
            | tft_beam.AnalyzeAndTransformDataset(bound_preprocessing_fn)
        )

    transformed_data, transformed_metadata = transformed_dataset

    print('\nRaw data:\n{}\n'.format(pprint.pformat(raw_data)))
    print('Transformed data:\n{}'.format(pprint.pformat(transformed_data)))


if __name__ == "__main__":
    run_tft_pipeline()

Example adapted from: https://github.com/tensorflow/transform/issues/219
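To make the two similarity features concrete, here is a small worked sketch of the same arithmetic in plain Python (no TensorFlow), assuming non-zero input vectors. The function names are illustrative only. Note that, despite the cos_dist name used above, the first value is an angular similarity score in [0, 1] (1 for identical directions, 0 for opposite ones), while the second is the Euclidean distance between the L2-normalized vectors.

```python
import math
from typing import List, Sequence, Tuple


def l2_normalize(v: Sequence[float]) -> List[float]:
    # Assumes v is not the zero vector; a zero norm would divide by zero.
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def similarity_scores(a: Sequence[float], b: Sequence[float]) -> Tuple[float, float]:
    a_n, b_n = l2_normalize(a), l2_normalize(b)
    # Dot product of unit vectors is the cosine of the angle between them.
    cos = sum(x * y for x, y in zip(a_n, b_n))
    # Clip to [-1, 1] so rounding error cannot push acos out of its domain.
    cos = max(-1.0, min(1.0, cos))
    angular_score = 1.0 - math.acos(cos) / math.pi
    l2_distance = math.sqrt(sum((x - y) ** 2 for x, y in zip(a_n, b_n)))
    return angular_score, l2_distance


same = similarity_scores([1.0, 2.0], [2.0, 4.0])        # same direction
orthogonal = similarity_scores([1.0, 0.0], [0.0, 1.0])  # 90 degrees apart
opposite = similarity_scores([1.0, 0.0], [-1.0, 0.0])   # 180 degrees apart
print(same, orthogonal, opposite)
```

Orthogonal vectors score 0.5 with a distance of sqrt(2), and opposite vectors score 0 with a distance of 2, which is the maximum for unit vectors.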

The model artifacts can be generic and can store any TensorFlow models and functions.

sapphire008 avatar Oct 02 '22 21:10 sapphire008

@jiyongjung0 @gbaned Please let me know if there are any suggestions.

EdwardCuiPeacock avatar Oct 08 '22 17:10 EdwardCuiPeacock

Thank you for the contribution!

This seems like a non-trivial change, and we might need to discuss its impact and possible alternatives (for example, using BulkInferrer).

@zoyahav @iindyk Could you take a look at this PR?

jiyongjung0 avatar Oct 11 '22 01:10 jiyongjung0

The second example snippet above is hardcoding the path: model_path = "gs://mybucekt/custom_embedding_model/" What's the reason for not doing the same in the first snippet, when using TFX Transform?

zoyahav avatar Oct 11 '22 10:10 zoyahav

The second example snippet above is hardcoding the path: model_path = "gs://mybucekt/custom_embedding_model/" What's the reason for not doing the same in the first snippet, when using TFX Transform?

I can fix that. But this is just to show that we can initialize a model inside the preprocessing function.

EdwardCuiPeacock avatar Oct 11 '22 15:10 EdwardCuiPeacock

Thank you for the contribution!

This seems like a non-trivial change, and we might need to discuss its impact and possible alternatives (for example, using BulkInferrer).

@zoyahav @iindyk Could you take a look at this PR?

The need here is to use a pretrained model (or a model trained by a separate Trainer) to preprocess the features used to train another model. The desired outcome is to save the base model together with the main model during preprocessing and training, so that it can also be used during serving with the same logic. This seems to be the perfect task for the Transform component to handle. The examples I have provided above show that people have used the Transform component to apply models in the past. BulkInferrer seems to be used only after a model is trained and an artifact is produced. BulkInferrer also does not save the transformation graph needed to preprocess the features that do not use the base model.
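A rough sketch of the "same logic at training and serving" idea, in plain Python with no TFX or TensorFlow dependency. Everything here (make_transform, toy_base_model, the user_segment feature) is a hypothetical stand-in: the point is only that bundling the base model inside the transform means both code paths call one function and cannot drift apart.

```python
from typing import Callable, Dict

Features = Dict[str, float]


def make_transform(
    base_model: Callable[[Features], float],
) -> Callable[[Features], Features]:
    """Bundle a base model with the preprocessing logic that calls it."""

    def transform(raw: Features) -> Features:
        out = dict(raw)  # keep the raw features untouched
        out["user_segment"] = base_model(raw)
        return out

    return transform


def toy_base_model(raw: Features) -> float:
    # Stand-in for a loaded SavedModel; the threshold is arbitrary.
    return 1.0 if raw["age"] >= 30 else 0.0


transform = make_transform(toy_base_model)
training_row = transform({"age": 42.0})  # used when building training data
serving_row = transform({"age": 42.0})   # used when answering a request
print(training_row == serving_row)
```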

sapphire008 avatar Oct 11 '22 15:10 sapphire008

@zoyahav @jiyongjung0 Any followups?

EdwardCuiPeacock avatar Oct 21 '22 13:10 EdwardCuiPeacock

Hi @zoyahav, any update on this PR, please? Thank you!

gbaned avatar Nov 12 '22 14:11 gbaned

Hi @zoyahav, any update on this PR, please? Thank you!

gbaned avatar Dec 16 '22 13:12 gbaned

This PR is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days

github-actions[bot] avatar Jan 16 '23 02:01 github-actions[bot]