We want to support running inference, training and checkpointing in one ProgramDesc. We implement a void Prune(const ProgramDesc* input, ProgramDesc* output) function, which takes a ProgramDesc and generates a pruned ProgramDesc that contains only the operators needed to evaluate the given targets.


Pruning needs to support both variables and operators as evaluation targets. Consider the following situations:

```python
# Case 1: run forward pass.
cost_np =
# Case 2: run backward pass.
opts_np, _ =[cost, opt])
# Case 3: run checkpointing.
_ =
```


To support evaluation of operators, we add an is_target field to OpDesc.

```protobuf
message OpDesc {
  required string type = 3;
  repeated Var inputs = 1;
  repeated Var outputs = 2;
  repeated Attr attrs = 4;
  optional bool is_target = 5 [ default = false ];
}
```

To support evaluation of variables, we add fetch_op. For each variable in the target list, we insert a fetch_op into the ProgramDesc with that variable as the fetch_op's input, and we mark that fetch_op as a target.
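A minimal Python sketch of how a mixed list of targets could be turned into marked ops. It assumes a simplified in-memory model of OpDesc: the OpDesc and Var dataclasses, the mark_targets helper, the "fetch" op type string and the "X" input slot name are all hypothetical, not Paddle's actual API. Operator targets get is_target set; each variable target gets a new fetch op that reads it and is itself marked as a target, so every target ends up being an operator.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Var:
    parameter: str        # slot name on the op, e.g. "X" (hypothetical)
    arguments: List[str]  # names of the variables bound to that slot


@dataclass
class OpDesc:
    type: str
    inputs: List[Var] = field(default_factory=list)
    outputs: List[Var] = field(default_factory=list)
    is_target: bool = False  # mirrors `optional bool is_target = 5`


def mark_targets(ops: List[OpDesc], target_ops: List[int],
                 target_vars: List[str]) -> List[OpDesc]:
    """Hypothetical helper: mark operator targets, add fetch ops for variables."""
    for idx in target_ops:
        ops[idx].is_target = True
    for name in target_vars:
        # The fetch op reads the variable and is itself a target, so
        # evaluating a variable reduces to evaluating an operator.
        ops.append(OpDesc(type="fetch", inputs=[Var("X", [name])], is_target=True))
    return ops
```

The list is modified in place and also returned for convenience. With this representation the pruning step only ever needs to look at is_target on operators.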


If an operator needs to be run, it must fall into one of the following cases:

  1. It is the target.
  2. Another op that needs to run depends on it, meaning one of its outputs is an input of that op.

The first case can be checked with op_desc.is_target(). The second case can be implemented as

```cpp
#include <set>
#include <string>
bool HasDependentVar(const OpDesc& op_desc, const std::set<std::string>& dependent_vars) {
  for (auto& var : op_desc.outputs()) {
    for (auto& argu : var.arguments()) {
      if (dependent_vars.count(argu) != 0) {  // an op that must run reads this output
        return true;
      }
    }
  }
  return false;
}
```

Then the whole algorithm can be implemented by combining these two checks.
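One way to combine the two checks, sketched in Python rather than C++ and assuming that the ops of a single block are stored in execution order. The OpDesc and Var dataclasses and the prune and has_dependent_var functions are hypothetical stand-ins for the C++ ProgramDesc API. The sketch scans the ops from last to first: an op is kept if it is a target or if one of its outputs is in the set of dependent variables, and every kept op adds its inputs to that set. Under the execution-order assumption, every consumer of an op's outputs comes after it, so by the time an op is visited all of its consumers have already been decided.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class Var:
    parameter: str
    arguments: List[str]


@dataclass
class OpDesc:
    type: str
    inputs: List[Var] = field(default_factory=list)
    outputs: List[Var] = field(default_factory=list)
    is_target: bool = False


def has_dependent_var(op: OpDesc, dependent_vars: Set[str]) -> bool:
    # Python counterpart of HasDependentVar: does any output feed a kept op?
    return any(arg in dependent_vars for var in op.outputs for arg in var.arguments)


def prune(ops: List[OpDesc]) -> List[OpDesc]:
    """Hypothetical sketch of Prune over one block of ops in execution order."""
    dependent_vars: Set[str] = set()
    keep = [False] * len(ops)
    # Walk backwards so every consumer of an op's outputs is decided first.
    for i in range(len(ops) - 1, -1, -1):
        op = ops[i]
        if op.is_target or has_dependent_var(op, dependent_vars):
            keep[i] = True
            # Whatever this op reads must be produced by some earlier op.
            for var in op.inputs:
                dependent_vars.update(var.arguments)
    # Preserve the original execution order in the pruned program.
    return [op for op, k in zip(ops, keep) if k]
```

Recording a keep mask and filtering at the end, instead of collecting ops during the reverse scan, keeps the pruned program in the original execution order without a second reversal.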