// Path: blob/master/modules/gapi/src/compiler/gmodelbuilder.cpp
// 16337 views
// This file is part of OpenCV project.1// It is subject to the license terms in the LICENSE file found in the top-level directory2// of this distribution and at http://opencv.org/license.html.3//4// Copyright (C) 2018 Intel Corporation567////////////////////////////////////////////////////////////////////////////////8//9// FIXME: "I personally hate this file"10// - Dmitry11//12////////////////////////////////////////////////////////////////////////////////13#include "precomp.hpp"1415#include <utility> // tuple16#include <stack> // stack17#include <vector> // vector18#include <unordered_set> // unordered_set19#include <type_traits> // is_same2021#include <ade/util/zip_range.hpp> // util::indexed2223#include "api/gapi_priv.hpp" // GOrigin24#include "api/gproto_priv.hpp" // descriptor_of and other GProtoArg-related25#include "api/gcall_priv.hpp"26#include "api/gnode_priv.hpp"2728#include "compiler/gmodelbuilder.hpp"2930namespace {313233// TODO: move to helpers and cover with internal tests?34template<typename T> struct GVisited35{36typedef std::unordered_set<T> VTs;3738bool visited(const T& t) const { return m_visited.find(t) != m_visited.end(); }39void visit (const T& t) { m_visited.insert(t); }40const VTs& visited() const { return m_visited; }4142private:43VTs m_visited;44};4546template<typename T, typename U = T> struct GVisitedTracker: protected GVisited<T>47{48typedef std::vector<U> TUs;4950void visit(const T& t, const U& u) { GVisited<T>::visit(t); m_tracked.push_back(u); }51const TUs& tracked() const { return m_tracked; }52using GVisited<T>::visited;5354private:55TUs m_tracked;56};5758} // namespace596061cv::gimpl::Unrolled cv::gimpl::unrollExpr(const GProtoArgs &ins,62const GProtoArgs &outs)63{64// FIXME: Who's gonna check if ins/outs are not EMPTY?65// FIXME: operator== for GObjects? 
(test if the same object or not)66using GObjId = const cv::GOrigin*;6768GVisitedTracker<const GNode::Priv*, cv::GNode> ops;69GVisited<GObjId> reached_sources;70cv::GOriginSet origins;7172// Cache input argument objects for a faster look-up73// While the only reliable way to identify a Data object is Origin74// (multiple data objects may refer to the same Origin as result of75// multuple yield() calls), input objects can be uniquely identified76// by its `priv` address. Here we rely on this to verify if the expression77// we unroll actually matches the protocol specified to us by user.78std::unordered_set<GObjId> in_objs_p;79for (const auto& in_obj : ins)80{81// Objects are guarnateed to remain alive while this method82// is working, so it is safe to keep pointers here and below83in_objs_p.insert(&proto::origin_of(in_obj));84}8586// Recursive expression traversal87std::stack<cv::GProtoArg> data_objs(std::deque<cv::GProtoArg>(outs.begin(), outs.end()));88while (!data_objs.empty())89{90const auto obj = data_objs.top();91const auto &obj_p = proto::origin_of(obj);92data_objs.pop();9394const auto &origin = obj_p;95origins.insert(origin); // TODO: Put Object description here later on9697// If this Object is listed in the protocol, don't dive deeper (even98// if it is in fact a result of operation). 
Our computation is99// bounded by this data slot, so terminate this recursion path early.100if (in_objs_p.find(&obj_p) != in_objs_p.end())101{102reached_sources.visit(&obj_p);103continue;104}105106const cv::GNode &node = origin.node;107switch (node.shape())108{109case cv::GNode::NodeShape::EMPTY:110// TODO: Own exception type?111util::throw_error(std::logic_error("Empty node reached!"));112break;113114case cv::GNode::NodeShape::PARAM:115case cv::GNode::NodeShape::CONST_BOUNDED:116// No preceding operation to this data object - so the data object is either a GComputation117// parameter or a constant (compile-time) value118// Record it to check if protocol matches expression tree later119if (!reached_sources.visited(&obj_p))120reached_sources.visit(&obj_p);121break;122123case cv::GNode::NodeShape::CALL:124if (!ops.visited(&node.priv()))125{126// This operation hasn't been visited yet - mark it so,127// then add its operands to stack to continue recursion.128ops.visit(&node.priv(), node);129130const cv::GCall call = origin.node.call();131const cv::GCall::Priv& call_p = call.priv();132133// Put the outputs object description of the node134// so that they are not lost if they are not consumed by other operations135for (const auto &it : ade::util::indexed(call_p.m_k.outShapes))136{137std::size_t port = ade::util::index(it);138GShape shape = ade::util::value(it);139140GOrigin org { shape, node, port};141origins.insert(org);142}143144for (const auto &arg : call_p.m_args)145{146if (proto::is_dynamic(arg))147{148data_objs.push(proto::rewrap(arg)); // Dive deeper149}150}151}152break;153154default:155// Unsupported node shape156GAPI_Assert(false);157break;158}159}160161// Check if protocol mentions data_objs which weren't reached during traversal162const auto missing_reached_sources = [&reached_sources](GObjId p) {163return reached_sources.visited().find(p) == reached_sources.visited().end();164};165if (ade::util::any_of(in_objs_p, missing_reached_sources))166{167// TODO: Own 
exception type or a return code?168util::throw_error(std::logic_error("Data object listed in Protocol "169"wasn\'t reached during unroll"));170}171172// Check if there endpoint (parameter) data_objs which are not listed in protocol173const auto missing_in_proto = [&in_objs_p](GObjId p) {174return p->node.shape() != cv::GNode::NodeShape::CONST_BOUNDED &&175in_objs_p.find(p) == in_objs_p.end();176};177if (ade::util::any_of(reached_sources.visited(), missing_in_proto))178{179// TODO: Own exception type or a return code?180util::throw_error(std::logic_error("Data object reached during unroll "181"wasn\'t found in Protocol"));182}183184return cv::gimpl::Unrolled{ops.tracked(), origins};185}186187188cv::gimpl::GModelBuilder::GModelBuilder(ade::Graph &g)189: m_g(g)190{191}192193cv::gimpl::GModelBuilder::ProtoSlots194cv::gimpl::GModelBuilder::put(const GProtoArgs &ins, const GProtoArgs &outs)195{196const auto unrolled = cv::gimpl::unrollExpr(ins, outs);197198// First, put all operations and its arguments into graph.199for (const auto &op_expr_node : unrolled.all_ops)200{201GAPI_Assert(op_expr_node.shape() == GNode::NodeShape::CALL);202const GCall& call = op_expr_node.call();203const GCall::Priv& call_p = call.priv();204ade::NodeHandle call_h = put_OpNode(op_expr_node);205206for (const auto &it : ade::util::indexed(call_p.m_args))207{208const auto in_port = ade::util::index(it);209const auto& in_arg = ade::util::value(it);210211if (proto::is_dynamic(in_arg))212{213ade::NodeHandle data_h = put_DataNode(proto::origin_of(in_arg));214cv::gimpl::GModel::linkIn(m_g, call_h, data_h, in_port);215}216}217}218219// Then iterate via all "origins", instantiate (if not yet) Data graph nodes220// and connect these nodes with their producers in graph221for (const auto &origin : unrolled.all_data)222{223const cv::GNode& prod = origin.node;224GAPI_Assert(prod.shape() != cv::GNode::NodeShape::EMPTY);225226ade::NodeHandle data_h = put_DataNode(origin);227if (prod.shape() == 
cv::GNode::NodeShape::CALL)228{229ade::NodeHandle call_h = put_OpNode(prod);230cv::gimpl::GModel::linkOut(m_g, call_h, data_h, origin.port);231}232}233234// Mark graph data nodes as INPUTs and OUTPUTs respectively (according to the protocol)235for (const auto &arg : ins)236{237ade::NodeHandle nh = put_DataNode(proto::origin_of(arg));238m_g.metadata(nh).get<Data>().storage = Data::Storage::INPUT;239}240for (const auto &arg : outs)241{242ade::NodeHandle nh = put_DataNode(proto::origin_of(arg));243m_g.metadata(nh).get<Data>().storage = Data::Storage::OUTPUT;244}245246// And, finally, store data object layout in meta247m_g.metadata().set(Layout{m_graph_data});248249// After graph is generated, specify which data objects are actually250// computation entry/exit points.251using NodeDescr = std::pair<std::vector<RcDesc>,252std::vector<ade::NodeHandle> >;253254const auto get_proto_slots = [&](const GProtoArgs &proto) -> NodeDescr255{256NodeDescr slots;257258slots.first.reserve(proto.size());259slots.second.reserve(proto.size());260261for (const auto &arg : proto)262{263ade::NodeHandle nh = put_DataNode(proto::origin_of(arg));264const auto &desc = m_g.metadata(nh).get<Data>();265//These extra empty {} are to please GCC (-Wmissing-field-initializers)266slots.first.push_back(RcDesc{desc.rc, desc.shape, {}});267slots.second.push_back(nh);268}269return slots;270};271272auto in_slots = get_proto_slots(ins);273auto out_slots = get_proto_slots(outs);274return ProtoSlots{in_slots.first, out_slots.first,275in_slots.second, out_slots.second};276}277278ade::NodeHandle cv::gimpl::GModelBuilder::put_OpNode(const cv::GNode &node)279{280const auto& node_p = node.priv();281const auto it = m_graph_ops.find(&node_p);282if (it == m_graph_ops.end())283{284GAPI_Assert(node.shape() == GNode::NodeShape::CALL);285const auto &call_p = node.call().priv();286auto nh = cv::gimpl::GModel::mkOpNode(m_g, call_p.m_k, call_p.m_args, node_p.m_island);287m_graph_ops[&node_p] = nh;288return nh;289}290else 
return it->second;291}292293// FIXME: rename to get_DataNode (and same for Op)294ade::NodeHandle cv::gimpl::GModelBuilder::put_DataNode(const GOrigin &origin)295{296const auto it = m_graph_data.find(origin);297if (it == m_graph_data.end())298{299auto nh = cv::gimpl::GModel::mkDataNode(m_g, origin);300m_graph_data[origin] = nh;301return nh;302}303else return it->second;304}305306307