Path: blob/master/modules/dnn/src/halide_scheduler.cpp
16337 views
// This file is part of OpenCV project.1// It is subject to the license terms in the LICENSE file found in the top-level directory2// of this distribution and at http://opencv.org/license.html.3//4// Copyright (C) 2017, Intel Corporation, all rights reserved.5// Third party copyrights are property of their respective owners.67#include "precomp.hpp"8#include "halide_scheduler.hpp"9#include "op_halide.hpp"1011namespace cv12{13namespace dnn14{1516#ifdef HAVE_HALIDE17static void applySplit(const FileNode& directive, Halide::Func& func,18const FileNode& params)19{20for (const auto& varNode : directive)21{22const std::string varName = varNode.name();23const std::string factorName = (std::string)varNode;24Halide::Var var(varName);25Halide::Var outerVar(varName + "o");26Halide::Var innerVar(varName + "i");27// If split factor is integer or parameters map has parameter value.28CV_Assert(varNode.isString() && !params[factorName].empty() ||29varNode.isInt());30int factor = (int)(varNode.isInt() ? varNode : params[factorName]);31func.split(var, outerVar, innerVar, factor);32}33}3435static void applyReorder(const FileNode& directive, Halide::Func& func)36{37std::string varName;38const int numVars = directive.size();39std::vector<Halide::VarOrRVar> reorderedVars;40reorderedVars.reserve(numVars);41for (int i = 0; i < numVars; ++i)42{43directive[i] >> varName;44reorderedVars.push_back(Halide::Var(varName));45}46func.reorder(reorderedVars);47}4849static void applyFuse(const FileNode& directive, Halide::Func& func)50{51CV_Assert(directive["src"].size() >= 2);52CV_Assert(directive["dst"].size() == 1);5354std::string str;55directive["src"][0] >> str;56Halide::Var firstVar(str);57directive["src"][1] >> str;58Halide::Var secondVar(str);59directive["dst"] >> str;60Halide::Var dstVar(str);6162func.fuse(firstVar, secondVar, dstVar);63for (int i = 2, n = directive["src"].size(); i < n; ++i)64{65directive["src"][i] >> str;66func.fuse(Halide::Var(str), dstVar, dstVar);67}68}6970static void applyParallel(const FileNode& directive, Halide::Func& func)71{72std::string varName;73for (int i = 0, n = directive.size(); i < n; ++i)74{75directive[i] >> varName;76func.parallel(Halide::Var(varName));77}78}7980static void applyUnroll(const FileNode& directive, Halide::Func& func)81{82std::string varName;83for (int i = 0, n = directive.size(); i < n; ++i)84{85directive[i] >> varName;86func.unroll(Halide::Var(varName));87}88}8990static void applyVectorize(const FileNode& directive, Halide::Func& func,91const FileNode& params)92{93for (const auto& varNode : directive)94{95const std::string varName = varNode.name();96const std::string factorName = (std::string)varNode;97// If split factor is integer or parameters map has parameter value.98CV_Assert(varNode.isString() && !params[factorName].empty() ||99varNode.isInt());100int factor = (int)(varNode.isInt() ? varNode : params[factorName]);101Halide::Var var(varName);102Halide::Var inner(varName + "v");103func.split(var, var, inner, factor);104func.vectorize(inner);105}106}107108static void applyStoreAt(const FileNode& directive, Halide::Func& func,109std::map<std::string, Halide::Func>& funcsMap)110{111for (const auto& funcNode : directive)112{113const std::string targetFuncName = funcNode.name();114if (funcsMap.find(targetFuncName) == funcsMap.end())115CV_Error(cv::Error::StsParseError, "Function " + targetFuncName +116" is not represented in Halide pipeline");117Halide::Func targetFunc = funcsMap[targetFuncName];118func.store_at(targetFunc, (std::string)funcNode);119break;120}121}122123static void applyComputeAt(const FileNode& directive, Halide::Func& func,124std::map<std::string, Halide::Func>& funcsMap)125{126for (const auto& funcNode : directive)127{128const std::string targetFuncName = funcNode.name();129if (funcsMap.find(targetFuncName) == funcsMap.end())130CV_Error(cv::Error::StsParseError, "Function " + targetFuncName +131" is not represented in Halide pipeline");132Halide::Func targetFunc = funcsMap[targetFuncName];133func.compute_at(targetFunc, (std::string)funcNode);134break;135}136}137138static void applyComputeRoot(const FileNode& directive, Halide::Func& func)139{140bool compute_root;141directive >> compute_root;142if (compute_root)143func.compute_root();144}145146static void applyGpuBlocks(const FileNode& directive, Halide::Func& func)147{148std::string varName;149for (int i = 0, n = directive.size(); i < n; ++i)150{151directive[i] >> varName;152func.gpu_blocks(Halide::Var(varName));153}154}155156static void applyGpuThreads(const FileNode& directive, Halide::Func& func)157{158std::string varName;159for (int i = 0, n = directive.size(); i < n; ++i)160{161directive[i] >> varName;162func.gpu_threads(Halide::Var(varName));163}164}165166static void apply(const FileNode& directives, Halide::Func& func,167std::map<std::string, Halide::Func>& funcsMap,168const FileNode& params)169{170for (const auto& directive : directives)171{172if (directive.name() == "split")173applySplit(directive, func, params);174else if (directive.name() == "reorder")175applyReorder(directive, func);176else if (directive.name() == "fuse")177applyFuse(directive, func);178else if (directive.name() == "parallel")179applyParallel(directive, func);180else if (directive.name() == "unroll")181applyUnroll(directive, func);182else if (directive.name() == "vectorize")183applyVectorize(directive, func, params);184else if (directive.name() == "store_at")185applyStoreAt(directive, func, funcsMap);186else if (directive.name() == "compute_at")187applyComputeAt(directive, func, funcsMap);188else if (directive.name() == "compute_root")189applyComputeRoot(directive, func);190else if (directive.name() == "gpu_blocks")191applyGpuBlocks(directive, func);192else if (directive.name() == "gpu_threads")193applyGpuThreads(directive, func);194else195CV_Error(Error::StsNotImplemented, "Scheduling directive " +196directive.name() + " is not implemented.");197}198}199200// Remove any numeric symbols after '$' sign.201static std::string Deunique(std::string str)202{203int pos = -1;204do205{206pos = str.find('$');207if (pos != -1)208{209int len = str.find_first_not_of("0123456789", pos + 1) - pos;210str = str.replace(pos, len, "");211}212}213while (pos != -1);214return str;215}216#endif // HAVE_HALIDE217218HalideScheduler::HalideScheduler(const std::string& configFile)219{220if (!configFile.empty())221fs = FileStorage(configFile, FileStorage::READ);222}223224HalideScheduler::~HalideScheduler()225{226if (fs.isOpened())227fs.release();228}229230bool HalideScheduler::process(Ptr<BackendNode>& node)231{232#ifdef HAVE_HALIDE233if (!fs.isOpened())234return false;235236const FileNode& scheduleNode = fs["scheduling"];237if (scheduleNode.empty())238CV_Error(cv::Error::StsParseError, "Scheduling file should has scheduling node");239240std::string str;241std::map<std::string, Halide::Func> funcsMap; // Scheduled functions.242// For every function, from top to bottom, we try to find a scheduling node.243// Scheduling is successful (return true) if for the first function (top)244// node is represented.245CV_Assert(!node.empty());246std::vector<Halide::Func>& funcs = node.dynamicCast<HalideBackendNode>()->funcs;247for (int i = funcs.size() - 1; i >= 0; --i)248{249Halide::Func& func = funcs[i];250// For functions with the same name Halide generates unique names251// for example func, func$1, func$2.252// They are always formed with '$' and number.253std::string funcName = Deunique(func.name());254255const FileNode& funcNode = scheduleNode[funcName];256if (!funcNode.empty())257{258if (!funcNode["pattern"].empty())259{260funcNode["pattern"] >> str;261if (fs["patterns"][str].empty())262CV_Error(cv::Error::StsParseError, "Scheduling pattern " + str +263" is not defined");264apply(fs["patterns"][str], func, funcsMap, funcNode["params"]);265}266else267{268apply(funcNode, func, funcsMap, funcNode["params"]);269}270}271else272{273if (funcsMap.empty())274return false;275}276funcsMap[funcName] = func;277}278return true;279#endif // HAVE_HALIDE280return false;281}282283} // namespace dnn284} // namespace cv285286287