Issues when running NetworkA inference and ResNet50 training
While writing the inference code for NetworkA, I encountered an issue with the Flatten layer not calculating the shape correctly during model construction. Here is my code:

```python
import tensorflow as tf
import tf_encrypted as tfe
from tf_encrypted.protocol import ABY3  # noqa:F403,F401
from tf_encrypted.convert import convert
from tf_encrypted.keras import layers, models

config = tfe.LocalConfig(
    player_names=[
        "server0",
        "server1",
        "server2",
        "training-client",
        "prediction-client",
    ]
)
tfe.set_config(config)
tfe.set_protocol(ABY3(fixedpoint_config='l'))

batch_size = 64
learning_rate = 0.01
num_epochs = 1

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=10000).batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)

test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)

model = tfe.keras.Sequential()
model.add(layers.Flatten(input_shape=(28, 28)))  # Error occurs here
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))

model.compile(
    optimizer=tfe.keras.optimizers.SGD(learning_rate=learning_rate),
    loss=tfe.keras.losses.CategoricalCrossentropy(from_logits=True),
)

data = tf.convert_to_tensor(x_test)
tfe_data = tfe.define_private_input("prediction-client", lambda: data)

preds = model.predict(tfe_data, batch_size=batch_size)
print(tf.argmax(preds, axis=1))
```
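For reference, the same architecture expressed in plaintext Keras builds without a shape error, since Flatten infers its output width (28 × 28 = 784) during construction:

```python
import tensorflow as tf

# Plaintext equivalent of the TFE model above, for comparison only.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10),
])

print(model.output_shape)  # (None, 10)
```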
After attempting to modify the source code of the Flatten layer, I found a temporary workaround:

```python
# file tf_encrypted/keras/layers/flatten.py#L45:
flatten_shape = [-1, input_shape[-1] * input_shape[-2]]
```
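The hardcoded product of the last two axes only covers inputs like `(batch, 28, 28)`. A more general version of the same idea (a sketch only, not tested against the TFE codebase; `flatten_shape` here is a hypothetical helper and `input_shape` is assumed to be `(batch, d1, d2, ...)`) would collapse every non-batch dimension:

```python
from functools import reduce
from operator import mul

# Hypothetical generalization of the workaround above: flatten all
# dimensions except the leading batch axis, instead of hardcoding
# the last two axes.
def flatten_shape(input_shape):
    return [-1, reduce(mul, input_shape[1:], 1)]

print(flatten_shape((64, 28, 28)))    # [-1, 784]
print(flatten_shape((64, 7, 7, 64)))  # [-1, 3136]
```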
Additionally, while executing the ResNet50 Training task, I encountered the following issue regardless of whether the batch size was set to 32 or 64:
I suspect this issue is caused by TFE's handling of certain large tensors, since it does not occur in plaintext computation.