сохранить модель tensorflow( tf.estimator.Estimator ) — python model export

1 Звезда2 Звезды3 Звезды4 Звезды5 Звезд
Загрузка...

Вопрос:


Помогите пожалуйста сохранить модель, Estimator.

есть такая функция model.export_savedmodel(export_dir_base, serving_input_receiver_fn, assets_extra=None, as_text=False, checkpoint_path=None, strip_default_attrs=False)

export_dir_base — тут все понятно

serving_input_receiver_fn — вот это главный затык, я уже помотрел и документацию и примеры.

вот такой пример я смог найти на просторах интернета, но не смог его адаптировать для себя:

ссылка на тему где обсуждали: https://github.com/tensorflow/tensorflow/issues/12508

def estimator(model_path):
    feature_columns = [tf.feature_column.numeric_column(INPUT_TENSOR_NAME, shape=[4])]
    return tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                      hidden_units=[10, 20, 10],
                                      n_classes=3,
                                      model_dir=model_path)


def serving_input_receiver_fn():
    feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32, shape=[4])}
    return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)()


def train_input_fn(training_dir):
    training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=os.path.join(training_dir, 'iris_training.csv'),
        target_dtype=np.int,
        features_dtype=np.float32)

    return tf.estimator.inputs.numpy_input_fn(
        x={INPUT_TENSOR_NAME: np.array(training_set.data)},
        y=np.array(training_set.target),
        num_epochs=None,
        shuffle=True)()

помогите , пожалуйста, разобраться, насколько я понимаю, проблема актуальна и там что-то добавили только 1.6 что бы это стало проще.

    #CONVOLUTIONAL NEURAL NETWORK
from __future__ import division, print_function, absolute_import


# Training Parameters
learning_rate = 0.001
num_steps = 100
batch_size = 128

# Network Parameters
num_input = 196 # systoles data input (img shape: 14*14)
num_classes = 19 # systoles total classes (0-19 digits)
dropout = 0.25 # Dropout, probability to drop a unit


# Create the neural network
def conv_net(x_dict, n_classes, dropout, reuse, is_training):
    # Define a scope for reusing the variables
    with tf.variable_scope('ConvNet', reuse=reuse):
        # TF Estimator input is a dict, in case of multiple inputs
        x = x_dict['images']

        # systoles data input is a 1-D vector of 196 features (14*14 pixels)
        # Reshape to match picture format [Height x Width x Channel]
        # Tensor input become 4-D: [Batch Size, Height, Width, Channel]
        x = tf.reshape(x, shape=[-1, 1, 196, 1])

        # Convolution Layer with 16 filters and a kernel size of 4
        conv1 = tf.layers.conv2d(x, filters=16, kernel_size= [1,4], activation=tf.nn.relu)
        # Max Pooling (down-sampling) with strides of 2 and kernel size of 2
        conv1 = tf.layers.max_pooling2d(conv1, pool_size=[1,2], strides = [1,2])

        # Convolution Layer with 32 filters and a kernel size of 2
        conv2 = tf.layers.conv2d(conv1, filters=32,  kernel_size= [1,2], activation=tf.nn.relu)
        # Max Pooling (down-sampling) with strides of 2 and kernel size of 2
        conv2 = tf.layers.max_pooling2d(conv2, pool_size=[1,2], strides = [1,2])

        # Flatten the data to a 1-D vector for the fully connected layer
        fc1 = tf.contrib.layers.flatten(conv2)

        # Fully connected layer (in tf contrib folder for now)
        fc1 = tf.layers.dense(inputs= fc1, units= 2048, activation=tf.nn.relu)
        # Apply Dropout (if is_training is False, dropout is not applied)
        fc1 = tf.layers.dropout(fc1, rate=dropout, training=is_training)

        # Output layer, class prediction
        out = tf.layers.dense(fc1, n_classes)

    return out


# Define the model function (following TF Estimator Template)
def model_fn(features, labels, mode):
    # Build the neural network
    # Because Dropout have different behavior at training and prediction time, we
    # need to create 2 distinct computation graphs that still share the same weights.
    logits_train = conv_net(features, num_classes, dropout, reuse=False,
                            is_training=True)
    logits_test = conv_net(features, num_classes, dropout, reuse=True,
                           is_training=False)

    # Predictions
    pred_classes = tf.argmax(logits_test, axis=1)
    pred_probas = tf.nn.softmax(logits_test)

    # If prediction mode, early return
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=pred_classes)

        # Define loss and optimizer
    loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits_train, labels=tf.cast(labels, dtype=tf.int32)))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.minimize(loss_op,
                                  global_step=tf.train.get_global_step())

    # Evaluate the accuracy of the model
    acc_op = tf.metrics.accuracy(labels=labels, predictions=pred_classes)

    # TF Estimators requires to return a EstimatorSpec, that specify
    # the different ops for training, evaluating, ...
    estim_specs = tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=pred_classes,
        loss=loss_op,
        train_op=train_op,
        eval_metric_ops={'accuracy': acc_op})

    return estim_specs

# Build the Estimator
model = tf.estimator.Estimator(model_fn)

# Define the input function for training
input_fn = tf.estimator.inputs.numpy_input_fn(x={'images': x_train}, 
                                              y= y_train, 
                                              batch_size=batch_size, 
                                              num_epochs=None, 
                                              shuffle=True)
# Train the Model
model.train(input_fn, steps=num_steps)

# Evaluate the Model
# Define the input function for evaluating
input_fn = tf.estimator.inputs.numpy_input_fn(x={'images': x_eval}, 
                                              y= y_eval, 
                                              batch_size=batch_size, 
                                              shuffle=False)
# Use the Estimator 'evaluate' method

e = model.evaluate(input_fn)

print("Testing Accuracy:", e['accuracy'])#95


model.export_savedmodel()# ????

model.export_savedmodel()# ???? как правильно и что передать этой функции?

другие варианты сохранения я знаю, они работают на других моделях, было очень интересно разобраться именно с сохранением Estimator.

ссылка от MaxU, Проверил вариант:

    def serving_input_receiver_fn():
    serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[None], name='input_tensors')
    receiver_tensors      = {"predictor_inputs": serialized_tf_example}
    feature_spec          = {"images": tf.FixedLenFeature([19],tf.float32)}
    features              = tf.parse_example(serialized_tf_example, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

full_model_dir = model.export_savedmodel(export_dir_base= r"/savedmodel",                                                  
                     serving_input_receiver_fn=serving_input_receiver_fn)



# >>>> ValueError: export_outputs must be a dict and not<class 'NoneType'>

это я еще поправил, там сходу ошибка, в том варианте он в tf.FixedLenFeature — передает int , а можно только float, это як тому что это только теоретический вариант, скорей всего.

возможно в serialized_tf_example надо передавать что-то от моей модели?

Автор вопроса: Oleksii

Источник

Вам также может быть интересно:

Использование вложенных маршрутов в React Router — javascript reactjs react-router
Вопрос: Для организации маршрутов в приложении использую React Router. <Route path="/" component={...}> <IndexRoute component={...}/> <Route path="user/:userId" component={...}> ...
Как с помощью Retrofit 2.0 отправить данные в JSON на сервер и получить ответ? — java android retrofit
Вопрос: Только начал читать про Retrofit 2.0 до этого использовал HttpURLConnection. Как я работаю с HttpURLConnection, формирую Json перевожу его в byte, ставлю header в ...
Не приходят push уведомления. FCM — android firebase android-notification
Вопрос: Появилась необходимость реализовать push уведомления. Прописал в манифесте сервис: <service android:name=".MyFirebaseMessagingService"> <intent-filter> ...
Принцип браузерной игры в линукс терминале — java linux terminal
Вопрос: Наткнулся на Java библиотеку CHARVA. И хотел бы уточнить у знающих людей, возможно ли на основе данной библиотеки сделать программу по принципу браузерной игры, но ...
Мерцание заблокированного экрана при выключенной подсветке в Debian 8 Gnome 3 — linux debian экран
Вопрос: На ноутбуке с Debian 8 Jessie и Gnome 3 имеется следующая проблема. При выключенном заблокированном экране сквозь него можно наблюдать, как весь экран становится белым, ...
Создание WCF клиента на готовый SOAP web сервер — c# wcf
Вопрос: Доброго времени суток. Появилась задача опрашивать web сервер с клиента на котором планируется написать WCF клиентскую часть. Информации про сервер очень мало (не знаю платформу ...
Безопасно ли удалить файл логов general_log.txt? — mysql
Вопрос: При выполнении запроса со вставкой данных большого объёма SQLyog начал вылетать с ошибкой: not enough memory application terminated В связи с этим я решила ...
Callback функции создания таблицы mysql в nodejs — mysql node.js callback
Вопрос: Есть функция, которая при запуске создает базу даных, function showDb() { pool.query("show databases like 'bt' ",function (err, ...
Как создать Adapter с неограниченным количеством строк и с неограниченным разным количеством столбцов в каждой строке — java android
Вопрос: Как создать Adapter с неограниченным количеством строк и с неограниченным разным количеством столбцов в каждой строке Автор вопроса: Salut Amigo Источник
Не могу передать байтовый массив в контроллер — c# asp.net-mvc entity-framework
Вопрос: У меня изображения храняться в бд в формате байтового массива, через форич отлично все выводит, но когда я хочу открыть страницу для работы с ...
proguard release error — java android mvp
Вопрос: Включил в проекте proguard, apk собирается, все хорошо, но приложение не работает) Proguard-rules.pro -keepattributes InnerClasses -keepattributes EnclosingMethod -keepattributes *Annotation* -dontoptimize # Keep Butterknife -keep class butterknife.** { *; } -dontwarn butterknife.internal.** -keep ...
Не отрабатывает page:update — javascript ruby-on-rails
Вопрос: Есть мой учебный проект на ruby. Делаю редактирование объектов с помощью JS. Сейчас работает так: Редактирую первый раз - всё нормально. Не обновляя страницу, ...
Как найти определенный символ в строке и удалить значение после него (и вместе с ним) Jquery — javascript html jquery
Вопрос: Здравствуйте, есть определенный набор строк, типа "L / Красный / 12345", как можно на странице найти их, и вырезать из них все что находится ...
Почему не работает wildcard module declaration? — typescript
Вопрос: Почему не работает такой способ декларации: declare module "*!text" {} ? Цель - использовать контент файла в переменной: import layout = require("/js/views/layouts/wnd.html!text"); или так: import layout from "/js/views/layouts/wnd.html!text"; Если ...
Как прервать 3rd-party код? — c# многопоточность .net-core
Вопрос: Есть 3rd-party код из библиотеки который "зависает" в ожидании где-то в работе с сетью. CancellationToken поддержки нет, таймаутов нет. Запускаю я его через: Task.Run(() => ...

Оставьте ответ

Ваш e-mail не будет опубликован. Обязательные поля помечены *