
Does 2.0 support LSTM for vertical (hetero) federated learning?

FancyXun opened this issue 1 year ago • 4 comments

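I see that FATE wraps parts of torch's nn, including classes such as Sequential, and I also saw an LSTM model, but how is it meant to be used? The LSTM's output is a tuple, so it can't be added directly into a Sequential, can it?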

FancyXun • Sep 06 '24 10:09

bottom_model = Sequential(
    nn.Linear(10, 10),
    nn.LSTM(input_size=10, hidden_size=10, batch_first=True),
    nn.Linear(10, 10),
)

A definition like this must be wrong, right? The LSTM's output is a tuple.
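In plain PyTorch, calling an `nn.LSTM` returns `(output, (h_n, c_n))`, so it cannot sit directly between two layers of `torch.nn.Sequential`. A common workaround is a small wrapper module that calls the LSTM and returns a single tensor. The sketch below uses plain `torch.nn`, not the `fate_client` wrappers; the class name `LSTMLastStep` is hypothetical, it assumes the input is already 3-D `(batch, seq_len, features)` because of `batch_first=True`, and whether FATE's client-side `Sequential` can ship such a custom module to the server is not established in this thread.

```python
import torch
from torch import nn


class LSTMLastStep(nn.Module):
    """Hypothetical wrapper: runs nn.LSTM and returns one tensor, not a tuple."""

    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumes x is (batch, seq_len, input_size).
        out, _ = self.lstm(x)   # out: (batch, seq_len, hidden_size); (h_n, c_n) dropped
        return out[:, -1, :]    # keep only the last time step: (batch, hidden_size)


# The same layer stack as above, now valid inside torch.nn.Sequential.
bottom_model = nn.Sequential(
    nn.Linear(10, 10),                        # applied to the last dimension
    LSTMLastStep(input_size=10, hidden_size=10),
    nn.Linear(10, 10),
)

if __name__ == "__main__":
    x = torch.randn(4, 7, 10)                 # 4 sequences, 7 steps, 10 features
    print(bottom_model(x).shape)              # torch.Size([4, 10])
```

For a unidirectional LSTM on unpadded sequences, `out[:, -1, :]` equals `h_n[-1]`, so returning the last layer's final hidden state instead is an equivalent choice.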

FancyXun • Sep 06 '24 10:09

Which tutorial were you following?

talkingwallace • Sep 09 '24 07:09

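Thanks for the reply. I'm using the nn examples that ship with FATE 2.0, and I wrote the script below based on one of them. Could you take a look and tell me what's wrong?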

#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

import argparse
from fate_client.pipeline.utils import test_utils
from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate.nn.torch import nn, optim
from fate_client.pipeline.components.fate.nn.torch.base import Sequential
from fate_client.pipeline.components.fate.hetero_nn import HeteroNN, get_config_of_default_runner
from fate_client.pipeline.components.fate.psi import PSI
from fate_client.pipeline.components.fate.reader import Reader
from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments
from fate_client.pipeline.components.fate import Evaluation
from fate_client.pipeline.components.fate.nn.algo_params import FedPassArgument



class LSTMModel(Sequential):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Set the initial states
        # h0 = Variable(torch.zeros(num_layers, batch_size, hidden_size))
        # c0 = Variable(torch.zeros(num_layers, batch_size, hidden_size))

        # Forward pass
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])
        return out






def main(config="../../config.yaml", namespace=""):
    # obtain config
    if isinstance(config, str):
        config = test_utils.load_job_config(config)
    parties = config.parties
    guest = parties.guest[0]
    host = parties.host[0]
    arbiter = parties.arbiter[0]

    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)

    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
    reader_0.guest.task_parameters(
        namespace=f"experiment{namespace}",
        name="breast_hetero_guest"
    )
    reader_0.hosts[0].task_parameters(
        namespace=f"experiment{namespace}",
        name="breast_hetero_host"
    )
    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])

    training_args = TrainingArguments(
        num_train_epochs=1,
        per_device_train_batch_size=16,
        logging_strategy='epoch'
    )

    guest_conf = get_config_of_default_runner(
        bottom_model=LSTMModel(10, 10, 3, 10),
        top_model=Sequential(
            nn.Linear(10, 1),
            nn.Sigmoid()
        ),
        training_args=training_args,
        optimizer=optim.Adam(lr=0.01),
        loss=nn.BCELoss()
    )

    host_conf = get_config_of_default_runner(
        bottom_model=nn.Linear(20, 20),
        optimizer=optim.Adam(lr=0.01),
        training_args=training_args,
        agglayer_arg=FedPassArgument(
            layer_type='linear',
            in_channels_or_features=20,
            hidden_features=20,
            out_channels_or_features=10,
            passport_mode='single',
            passport_distribute='gaussian'
        )
    )

    hetero_nn_0 = HeteroNN(
        'hetero_nn_0',
        train_data=psi_0.outputs['output_data']
    )

    hetero_nn_0.guest.task_parameters(runner_conf=guest_conf)
    hetero_nn_0.hosts[0].task_parameters(runner_conf=host_conf)

    hetero_nn_1 = HeteroNN(
        'hetero_nn_1',
        test_data=psi_0.outputs['output_data'],
        predict_model_input=hetero_nn_0.outputs['train_model_output']
    )

    evaluation_0 = Evaluation(
        'eval_0',
        runtime_parties=dict(guest=guest),
        metrics=['auc'],
        input_data=[hetero_nn_1.outputs['predict_data_output'], hetero_nn_0.outputs['train_data_output']]
    )

    pipeline.add_tasks([reader_0, psi_0, hetero_nn_0, hetero_nn_1, evaluation_0])
    pipeline.compile()
    pipeline.fit()

    result_summary = pipeline.get_task_info("eval_0").get_output_metric()[0]["data"]
    print(f"result_summary: {result_summary}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser("PIPELINE DEMO")
    parser.add_argument("--config", type=str, default="../config.yaml",
                        help="config file")
    parser.add_argument("--namespace", type=str, default="",
                        help="namespace for data stored in FATE")
    args = parser.parse_args()
    main(config=args.config, namespace=args.namespace)

The pipeline compiles on the fate_client side, but the FATE server reports an error:

[Screenshot: Screen Shot 2024-09-09 at 15 37 35]

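The LSTM's output is a tuple, not a tensor, so I deliberately handled that in the forward method of my custom model, and with plain torch this model does produce predictions. I wonder whether FATE changed some of torch's intermediate interfaces, so that information gets lost when the model is sent to the server and reloaded. Alternatively, is there an example that uses LSTM for classification in vertical (hetero) federated learning? LSTM as either the bottom model or the top model would be fine. I keep hitting bugs when following the official nn examples, and I'm not sure whether I'm using them incorrectly. Any guidance would be appreciated. @talkingwallace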
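Separately from how the server rebuilds the model, the posted `LSTMModel` would also fail in plain PyTorch if each guest sample reaches it as a flat row, i.e. a 2-D `(batch, 10)` tensor. That is an assumption here, based on the tabular `breast_hetero_guest` data. Current PyTorch releases read 2-D input to `nn.LSTM` as one unbatched sequence `(seq_len, features)` and return 2-D output (older releases reject 2-D input outright), so `out[:, -1, :]` raises an `IndexError`. The sketch below is plain `torch.nn`, not a FATE model definition; the class name `LSTMClassifier` is hypothetical. It keeps the layer sizes from the script and adds a length-1 time axis when the input is 2-D, which makes it usable as a local sanity check before submitting a pipeline.

```python
import torch
from torch import nn


class LSTMClassifier(nn.Module):
    """Hypothetical plain-PyTorch counterpart of the LSTMModel in the script."""

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # A (batch, features) tensor would be read as ONE unbatched sequence
        # of length `batch`, so give each row its own length-1 sequence.
        if x.dim() == 2:
            x = x.unsqueeze(1)          # (batch, 1, features)
        out, _ = self.lstm(x)           # out: (batch, seq_len, hidden_size)
        return self.fc(out[:, -1, :])   # last time step -> (batch, num_classes)


if __name__ == "__main__":
    rows = torch.randn(16, 10)          # 16 tabular rows with 10 features each

    lstm = nn.LSTM(input_size=10, hidden_size=10, num_layers=3, batch_first=True)
    out_2d, _ = lstm(rows)              # unbatched: output stays 2-D
    print(out_2d.shape)                 # torch.Size([16, 10]), so out_2d[:, -1, :] fails

    model = LSTMClassifier(10, 10, 3, 10)
    print(model(rows).shape)            # torch.Size([16, 10])
```

Treating every row as a length-1 sequence only makes the shapes line up; with a single time step the recurrence has nothing to carry over, so this is a shape fix rather than a reason to prefer an LSTM on tabular features.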

FancyXun • Sep 09 '24 07:09

This is a pipeline-based example; you could refer to the ml examples that run directly instead.

talkingwallace • Sep 20 '24 09:09

This issue has been marked as stale because it has been open for 365 days with no activity. If this issue is still relevant or if there is new information, please feel free to update or reopen it.

github-actions[bot] • Sep 21 '25 02:09

This issue was closed because it has been inactive for 1 day since being marked as stale. If this issue is still relevant or if there is new information, please feel free to update or reopen it.

github-actions[bot] • Sep 22 '25 02:09