# Source code for gnes.preprocessor.image.sliding_window

#  Tencent is pleased to support the open source community by making GNES available.
#
#  Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import io
from typing import List

import numpy as np
from PIL import Image

from .resize import SizedPreprocessor
from ..helper import get_all_subarea, torch_transform
from ...proto import gnes_pb2, array2blob


class _SlidingPreprocessor(SizedPreprocessor):
    """Base preprocessor that cuts an image document into square sliding windows.

    Each window becomes one chunk of the document. Subclasses decide how the
    chunks are weighted by implementing :meth:`_get_all_chunks_weight`.
    """

    def __init__(self, window_size: int = 64,
                 stride_height: int = 64,
                 stride_wide: int = 64,
                 *args, **kwargs):
        """
        :param window_size: side length (in pixels) of every square window
        :param stride_height: step between consecutive windows along axis 0
            of the image array
        :param stride_wide: step between consecutive windows along axis 1
            of the image array
        :param args: forwarded unchanged to the parent constructor
        :param kwargs: forwarded unchanged to the parent constructor
        """
        super().__init__(*args, **kwargs)
        self.window_size = window_size
        self.stride_height = stride_height
        self.stride_wide = stride_wide

    def apply(self, doc: 'gnes_pb2.Document'):
        """Append one chunk per sliding window to ``doc``.

        Logs an error and leaves ``doc`` without new chunks when
        ``doc.raw_bytes`` is empty.

        :param doc: document whose ``raw_bytes`` hold an encoded image
        """
        super().apply(doc)
        if not doc.raw_bytes:
            self.logger.error('bad document: "raw_bytes" is empty!')
            return

        original_image = Image.open(io.BytesIO(doc.raw_bytes))
        all_subareas, index = get_all_subarea(original_image)
        image_set, center_point_list = self._get_all_sliding_window(np.array(original_image))
        # NOTE(review): the transpose implies torch_transform yields a
        # channel-first array that is moved back to channel-last - confirm.
        normalized_img_set = [np.array(torch_transform(img)).transpose(1, 2, 0)
                              for img in image_set]
        weight = self._get_all_chunks_weight(normalized_img_set)

        for ci, (chunk_img, chunk_weight) in enumerate(zip(normalized_img_set, weight)):
            c = doc.chunks.add()
            c.doc_id = doc.doc_id
            c.blob.CopyFrom(array2blob(chunk_img))
            c.offset = ci
            c.offset_nd.extend(self._get_slid_offset_nd(all_subareas, index, center_point_list[ci]))
            c.weight = chunk_weight

    def _get_all_sliding_window(self, img: 'np.ndarray'):
        """Cut ``img`` into ``window_size`` x ``window_size`` x 3 windows.

        The image is zero-padded at the bottom and right so that the last
        window of every row/column fits entirely inside the padded array.

        :param img: 3-dimensional image array (axis 0, axis 1, channels)
        :return: tuple ``(windows, center_point_list)``; ``windows`` is a list
            of arrays in row-major window order and ``center_point_list``
            holds one two-element centre per window, in the same order
        """
        extend_height = self.window_size - img.shape[0] % self.stride_height
        extend_wide = self.window_size - img.shape[1] % self.stride_wide

        # `padded` rather than `input`, so the builtin is not shadowed.
        padded = np.pad(img,
                        ((0, extend_height), (0, extend_wide), (0, 0)),
                        mode='constant', constant_values=0)

        # Take the byte strides from the array itself instead of assuming a
        # one-byte channel step, so the view stays valid for any dtype.
        row_step, col_step, channel_step = padded.strides
        expanded_input = np.lib.stride_tricks.as_strided(
            padded,
            shape=(
                1 + (padded.shape[0] - self.window_size) // self.stride_height,
                1 + (padded.shape[1] - self.window_size) // self.stride_wide,
                self.window_size,
                self.window_size,
                3
            ),
            strides=(
                row_step * self.stride_height,
                col_step * self.stride_wide,
                row_step,
                col_step,
                channel_step
            ),
            writeable=False
        )

        # NOTE(review): `x` walks the window rows (axis 0) yet is scaled by
        # stride_wide, and `y` walks the window columns yet is scaled by
        # stride_height. This is kept as-is because the coordinate convention
        # of get_all_subarea is not visible here - confirm the pairing.
        center_point_list = [
            [self.window_size / 2 + x * self.stride_wide, self.window_size / 2 + y * self.stride_height]
            for x in range(expanded_input.shape[0])
            for y in range(expanded_input.shape[1])]

        expanded_input = expanded_input.reshape((-1, self.window_size, self.window_size, 3))
        windows = [np.array(Image.fromarray(window)) for window in expanded_input]
        return windows, center_point_list

    def _get_slid_offset_nd(self, all_subareas: List[List[int]], index: List[List[int]], center_point: List[float]) -> \
            List[int]:
        """Return the first two entries of the index of the subarea holding ``center_point``.

        :param all_subareas: subarea boxes, read as ``[x0, y0, x1, y1]``
        :param index: per-subarea index entries, parallel to ``all_subareas``
        :param center_point: two-element centre of one window
        :return: ``index[k][:2]`` for the first matching subarea ``k``
        """
        location_list = self._get_location(all_subareas, center_point)
        # _get_location always marks at least one entry, so index() succeeds.
        return index[location_list.index(True)][:2]

    @staticmethod
    def _get_location(all_subareas: List[List[int]], center_point: List[float]) -> List[bool]:
        """Flag, per subarea, whether ``center_point`` belongs to it.

        A subarea matches when the point lies inside its half-open box, or
        when the point lies past the largest x/y bound while still being
        aligned with a subarea touching that bound. If nothing matches, the
        last subarea is flagged so that the result always contains a ``True``.

        :param all_subareas: subarea boxes, read as ``[x0, y0, x1, y1]``
        :param center_point: two-element point to locate
        :return: one boolean per subarea, in the order of ``all_subareas``
        """
        x_boundary = max(area[2] for area in all_subareas)
        y_boundary = max(area[3] for area in all_subareas)
        cx, cy = center_point[0], center_point[1]

        location_list = []
        for area in all_subareas:
            # Interval comparison instead of `value in range(...)`: range
            # membership is only true for integer-valued numbers, which made
            # every half-pixel centre (odd window_size) miss all subareas.
            in_x = int(area[0]) <= cx < int(area[2])
            in_y = int(area[1]) <= cy < int(area[3])
            past_bottom = in_x and y_boundary == area[3] and cy > y_boundary
            past_right = in_y and x_boundary == area[2] and cx > x_boundary
            # bool() keeps the entries plain Python booleans.
            location_list.append(bool((in_x and in_y) or past_bottom or past_right))

        if True not in location_list:
            location_list[-1] = True
        return location_list

    def _get_all_chunks_weight(self, normalizaed_image_set):
        """Return one weight per chunk; must be implemented by subclasses.

        :param normalizaed_image_set: list of normalized window arrays
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError


class VanillaSlidingPreprocessor(_SlidingPreprocessor):
    """Sliding-window preprocessor that weights every chunk equally."""

    def _get_all_chunks_weight(self, image_set: List['np.ndarray']) -> List[float]:
        """Return ``1 / len(image_set)`` once per chunk.

        :param image_set: the chunk images; only their count is used
        :return: a list with one equal weight per chunk (empty for no chunks)
        """
        count = len(image_set)
        # Comprehension rather than `[1 / count] * count`, so that an empty
        # input yields [] instead of dividing by zero.
        return [1 / count for _ in range(count)]
class WeightedSlidingPreprocessor(_SlidingPreprocessor):
    """Sliding-window preprocessor whose chunk weights come from ``FFmpegPreprocessor.pic_weight``."""

    def _get_all_chunks_weight(self, image_set: List['np.ndarray']) -> List[float]:
        """Delegate the weighting of ``image_set`` to ``FFmpegPreprocessor.pic_weight``.

        :param image_set: the chunk images to weight
        :return: whatever ``FFmpegPreprocessor.pic_weight`` returns for them
        """
        # Kept as a function-local import, as in the original.
        # NOTE(review): presumably this avoids an import cycle or a hard
        # dependency on the video package - confirm before hoisting it.
        from ..video.ffmpeg import FFmpegPreprocessor
        return FFmpegPreprocessor.pic_weight(image_set)