Path: blob/master/modules/dnn/src/layers/shuffle_channel_layer.cpp
16337 views
// This file is part of OpenCV project.1// It is subject to the license terms in the LICENSE file found in the top-level directory2// of this distribution and at http://opencv.org/license.html.34// Copyright (C) 2018, Intel Corporation, all rights reserved.5// Third party copyrights are property of their respective owners.6#include "../precomp.hpp"78namespace cv { namespace dnn {910class ShuffleChannelLayerImpl CV_FINAL : public ShuffleChannelLayer11{12public:13ShuffleChannelLayerImpl(const LayerParams& params)14{15group = params.get<int>("group", 1);16setParamsFrom(params);17}1819bool getMemoryShapes(const std::vector<MatShape> &inputs,20const int requiredOutputs,21std::vector<MatShape> &outputs,22std::vector<MatShape> &internals) const CV_OVERRIDE23{24CV_Assert(inputs.size() == 1 && inputs[0].size() == 4);25CV_Assert(inputs[0][1] % group == 0);26Layer::getMemoryShapes(inputs, requiredOutputs, outputs, internals);27return group == 1;28}2930virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE31{32if (group != 1)33{34std::vector<Mat> inputs, outputs;35inputs_arr.getMatVector(inputs);36outputs_arr.getMatVector(outputs);3738LayerParams lp;39float order[] = {0, 2, 1, 3};40lp.set("order", DictValue::arrayInt(&order[0], 4));41permute = PermuteLayer::create(lp);4243const Mat& inp = inputs[0];44const Mat& out = outputs[0];4546permuteInpShape.resize(4);47permuteInpShape[0] = inp.size[0];48permuteInpShape[1] = group;49permuteInpShape[2] = inp.size[1] / group;50permuteInpShape[3] = inp.size[2]*inp.size[3];5152permuteOutShape.resize(4);53permuteOutShape[0] = permuteInpShape[0];54permuteOutShape[1] = permuteInpShape[2];55permuteOutShape[2] = permuteInpShape[1];56permuteOutShape[3] = permuteInpShape[3];5758std::vector<Mat> permuteInputs(1, inp.reshape(1, permuteInpShape));59std::vector<Mat> permuteOutputs(1, out.reshape(1, permuteOutShape));60permute->finalize(permuteInputs, permuteOutputs);61}62}6364#ifdef HAVE_OPENCL65bool forward_ocl(InputArrayOfArrays inps, OutputArrayOfArrays outs, OutputArrayOfArrays internals)66{67std::vector<UMat> inputs;68std::vector<UMat> outputs;6970inps.getUMatVector(inputs);71outs.getUMatVector(outputs);7273if (inputs[0].u != outputs[0].u)74{75if (!permute.empty())76{77inputs[0] = inputs[0].reshape(1, permuteInpShape.size(), &permuteInpShape[0]);78outputs[0] = outputs[0].reshape(1, permuteOutShape.size(), &permuteOutShape[0]);79permute->preferableTarget = preferableTarget;80permute->forward(inputs, outputs, internals);81}82else83inputs[0].copyTo(outputs[0]);84}85return true;86}87#endif8889void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE90{91CV_TRACE_FUNCTION();92CV_TRACE_ARG_VALUE(name, "name", name.c_str());9394CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),95forward_ocl(inputs_arr, outputs_arr, internals_arr))9697if (inputs_arr.depth() == CV_16S)98{99forward_fallback(inputs_arr, outputs_arr, internals_arr);100return;101}102103std::vector<Mat> inputs, outputs, internals;104inputs_arr.getMatVector(inputs);105outputs_arr.getMatVector(outputs);106internals_arr.getMatVector(internals);107108Mat inp = inputs[0];109Mat out = outputs[0];110if (inp.data != out.data)111{112if (!permute.empty())113{114inp = inp.reshape(1, permuteInpShape);115out = out.reshape(1, permuteOutShape);116std::vector<Mat> permuteInputs(1, inp);117std::vector<Mat> permuteOutputs(1, out);118permute->forward(permuteInputs, permuteOutputs, internals);119}120else121inp.copyTo(out);122}123}124125private:126Ptr<PermuteLayer> permute;127std::vector<int> permuteInpShape, permuteOutShape;128};129130Ptr<Layer> ShuffleChannelLayer::create(const LayerParams& params)131{132return Ptr<Layer>(new ShuffleChannelLayerImpl(params));133}134135} // namespace dnn136} // namespace cv137138139