Source code for gnes.encoder.video.yt8m_model

#  Tencent is pleased to support the open source community by making GNES available.
#
#  Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from typing import List
import numpy as np
from ..base import BaseVideoEncoder
from ...helper import batching, get_first_available_gpu


[docs]class YouTube8MEncoder(BaseVideoEncoder): batch_size = 64 def __init__(self, model_dir: str, model_name: str, *args, **kwargs): super().__init__(*args, **kwargs) self.model_dir = model_dir self.model_name = model_name self.max_num_frames = 300 self.embedding_dim = 1152
[docs] def post_init(self): import tensorflow as tf import os os.environ['CUDA_VISIBLE_DEVICES'] = str(get_first_available_gpu()) g = tf.Graph() with g.as_default(): checkpoint_file = os.path.join(self.model_dir, self.model_name, "inference_model") meta_graph_location = checkpoint_file + ".meta" saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True) config = tf.ConfigProto(log_device_placement=False) if self.on_gpu: config.gpu_options.allow_growth = True self.sess = tf.Session(config=config) saver.restore(self.sess, checkpoint_file) self.input_tensor = tf.get_collection("input_batch_raw")[0] self.num_frames_tensor = tf.get_collection("num_frames")[0] self.predictions_tensor = self.sess.graph.get_tensor_by_name("tower/gates/MatMul:0")
[docs] def encode(self, data: List['np.ndarray'], *args, **kwargs) -> np.ndarray: def _padding(data): _data = np.array([np.concatenate((d, np.zeros((self.max_num_frames - d.shape[0], self.embedding_dim), dtype=np.float32)), axis=0) if d.shape[0] < self.max_num_frames else d[:self.max_num_frames] for d in data]) return _data.reshape((-1, self.max_num_frames, self.embedding_dim)) @batching def _encode(_, data): num_frames_list = list(map(lambda x: x.shape[0], data)) predictions_val, = self.sess.run([self.predictions_tensor], feed_dict={self.input_tensor: _padding(data), self.num_frames_tensor: np.array(num_frames_list)}) return np.array(predictions_val).astype(np.float32) return _encode(self, _padding(data))