Caffe2::Operator
Caffe2 has the concept of an operator, which corresponds to TensorFlow's Op.
Unlike an Op, an operator is usually accompanied by a gradient operator (GradientOp).
Let us take ReluOp and ReluGradientOp as an example.
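Before looking at the classes, here is a minimal, framework-free C++ sketch of the arithmetic such a forward/gradient pair performs. ReluForward and ReluGradient are hypothetical names, not Caffe2 API. The sketch also assumes the gradient is computed from the forward output Y and the incoming gradient dY: dX[i] = dY[i] where Y[i] > 0, and 0 otherwise.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Forward pass (hypothetical helper): Y = max(0, X), element-wise.
std::vector<float> ReluForward(const std::vector<float>& X) {
  std::vector<float> Y(X.size());
  for (std::size_t i = 0; i < X.size(); ++i) {
    Y[i] = X[i] > 0.0f ? X[i] : 0.0f;
  }
  return Y;
}

// Backward pass (hypothetical helper), assuming it reads the forward output Y
// and the upstream gradient dY: the gradient passes through only where Y > 0.
std::vector<float> ReluGradient(const std::vector<float>& Y,
                                const std::vector<float>& dY) {
  assert(Y.size() == dY.size());  // both tensors must have the same shape
  std::vector<float> dX(Y.size());
  for (std::size_t i = 0; i < Y.size(); ++i) {
    dX[i] = Y[i] > 0.0f ? dY[i] : 0.0f;
  }
  return dX;
}
```

For example, ReluForward({-1, 0.5, 2}) returns {0, 0.5, 2}, and ReluGradient of that result with dY = {1, 1, 1} returns {0, 1, 1}.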
All operators are classes derived from Operator<Context>.
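```cpp
// Forward operator: Y = max(0, X), element-wise.
template <typename T, class Context>
class ReluOp final : public Operator<Context> {
 public:
  // Caffe2 macro: declares a constructor/destructor that forward to Operator<Context>.
  USE_SIMPLE_CTOR_DTOR(ReluOp);
  // Caffe2 macro: makes Operator<Context> members (e.g. context_) usable in this template.
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  // Does the actual computation; defined outside the class.
  bool RunOnDevice() override;
};

// Gradient operator paired with ReluOp.
template <typename T, class Context>
class ReluGradientOp final : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(ReluGradientOp);
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  bool RunOnDevice() override;
};
```

Operator<Context> has a data member Context context_, which records the device the operator runs on (for example, which GPU). The Operator constructor takes an OperatorDef proto message and uses it to initialize context_; it then calls context_.SwitchToDevice(0).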
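The construction sequence can be sketched with stand-in types. FakeOperatorDef, FakeContext, and SketchOperator are hypothetical and exist only for illustration: the real OperatorDef is a protobuf message and real contexts manage actual devices. The sketch only shows the order of events, namely that context_ is built from the definition and then switched to stream 0.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for the OperatorDef proto message.
struct FakeOperatorDef {
  std::string type;   // e.g. "Relu"
  int device_id = 0;  // which device the operator should run on
};

// Hypothetical context: records its device and the last stream it switched to.
class FakeContext {
 public:
  explicit FakeContext(const FakeOperatorDef& def) : device_id_(def.device_id) {}

  void SwitchToDevice(int stream_id) {
    assert(stream_id >= 0);  // streams are non-negative in this sketch
    current_stream_ = stream_id;
  }

  int device_id() const { return device_id_; }
  int current_stream() const { return current_stream_; }

 private:
  int device_id_;
  int current_stream_ = -1;  // -1 means SwitchToDevice has not been called yet
};

// Mirrors the described constructor: build context_ from the definition,
// then switch to stream 0.
template <class Context>
class SketchOperator {
 public:
  explicit SketchOperator(const FakeOperatorDef& def) : context_(def) {
    context_.SwitchToDevice(0);
  }
  virtual ~SketchOperator() = default;

  const Context& context() const { return context_; }

 protected:
  Context context_;
};
```

Constructing SketchOperator<FakeContext> from a definition with device_id = 1 leaves its context on device 1 and stream 0.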
Operator<Context> has three virtual functions:

- `RunOnDevice() = 0` is the pure virtual function you override;
- `Run(stream_id)` calls `context_.SwitchToDevice(stream_id)`, `RunOnDevice()`, and `context_.FinishDeviceComputation()`; and
- `RunAsync(stream_id)` calls `context_.SwitchToDevice(stream_id)` and `RunOnDevice()`.
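The difference between Run and RunAsync can be seen in a small mock. LoggingContext, OperatorRunSketch, and NoOpOperator are hypothetical. The mock's FinishDeviceComputation only records that it was called, because the real semantics are not yet pinned down on this page. The point is the call order: Run ends with FinishDeviceComputation, while RunAsync does not call it.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical context that logs every call made on it.
class LoggingContext {
 public:
  void SwitchToDevice(int stream_id) {
    assert(stream_id >= 0);
    log.push_back("SwitchToDevice(" + std::to_string(stream_id) + ")");
  }
  bool FinishDeviceComputation() {
    log.push_back("FinishDeviceComputation");  // real semantics not modelled
    return true;
  }
  std::vector<std::string> log;
};

template <class Context>
class OperatorRunSketch {
 public:
  virtual ~OperatorRunSketch() = default;

  // Switch to the stream, compute, then call FinishDeviceComputation.
  virtual bool Run(int stream_id = 0) {
    context_.SwitchToDevice(stream_id);
    bool started = RunOnDevice();
    bool finished = context_.FinishDeviceComputation();
    return started && finished;
  }

  // Switch to the stream and compute, without calling FinishDeviceComputation.
  virtual bool RunAsync(int stream_id = 0) {
    context_.SwitchToDevice(stream_id);
    return RunOnDevice();
  }

  const Context& context() const { return context_; }

 protected:
  // The one function a concrete operator must implement.
  virtual bool RunOnDevice() = 0;

  Context context_;
};

// Toy operator whose "computation" is just a log entry.
class NoOpOperator final : public OperatorRunSketch<LoggingContext> {
 protected:
  bool RunOnDevice() override {
    context_.log.push_back("RunOnDevice");
    return true;
  }
};
```

Keeping RunOnDevice as the only pure virtual function means a concrete operator supplies just its computation, while device switching and the final synchronization step stay in the base class.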
[TODO: Check what Context::FinishDeviceComputation does.]
Operator<Context> also lets a user-overridden RunOnDevice access inputs and outputs through:
Operator<Context> derives from class OperatorBase.
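One likely reason for a non-templated OperatorBase is that code which only needs to run operators can then hold operators for different Context types in a single container. The sketch below illustrates that idea with hypothetical types (OperatorBaseSketch, TypedOperatorSketch, FakeCPUContext, FakeGPUContext, RunAll). It is an assumption about the motivation, not a description of Caffe2's actual OperatorBase.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Non-templated base: callers can run an operator without knowing its Context.
class OperatorBaseSketch {
 public:
  virtual ~OperatorBaseSketch() = default;
  virtual bool Run(int stream_id) = 0;
};

// Hypothetical device contexts; only their names matter here.
struct FakeCPUContext { static const char* Name() { return "CPU"; } };
struct FakeGPUContext { static const char* Name() { return "GPU"; } };

// Templated layer, analogous to Operator<Context> deriving from OperatorBase.
template <class Context>
class TypedOperatorSketch : public OperatorBaseSketch {
 public:
  explicit TypedOperatorSketch(std::vector<std::string>* trace) : trace_(trace) {
    assert(trace_ != nullptr);
  }
  bool Run(int stream_id) override {
    trace_->push_back(std::string(Context::Name()) + ":" + std::to_string(stream_id));
    return true;
  }

 private:
  std::vector<std::string>* trace_;  // records which operator ran on which stream
};

// Runs a mixed list of operators through the base-class interface only.
bool RunAll(const std::vector<std::unique_ptr<OperatorBaseSketch>>& ops, int stream_id) {
  for (const auto& op : ops) {
    if (!op->Run(stream_id)) return false;
  }
  return true;
}
```

A list holding a TypedOperatorSketch<FakeCPUContext> and a TypedOperatorSketch<FakeGPUContext> can be run by RunAll without either context type appearing in its signature.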