I'm trying to implement reverse-mode automatic differentiation in C++.
The idea I came up with is that each variable that results from an operation on one or two other variables stores, in a vector, the local gradients with respect to those operands, together with pointers to the operands themselves.
This is the code:

```cpp
#include <iostream>
#include <utility>
#include <vector>

class Var {
private:
    double value;
    char character;
    // (local derivative ∂this/∂child, pointer to child)
    std::vector<std::pair<double, const Var*>> children;
public:
    Var(const double& _value = 0, const char& _character = '_') : value(_value), character(_character) {}
    void set_character(const char& character) { this->character = character; }
    // computes the derivative of the current object with respect to 'var'
    double gradient(Var* var) const {
        if (this == var) {
            return 1.0;
        }
        double sum = 0.0;
        for (auto& pair : children) {
            // std::cout << "(" << this->character << " -> " << pair.second->character << ", " << this << " -> " << pair.second << ", weight=" << pair.first << ")" << std::endl;
            sum += pair.first * pair.second->gradient(var);
        }
        return sum;
    }
    friend Var operator+(const Var& l, const Var& r) {
        Var result(l.value + r.value);
        result.children.push_back(std::make_pair(1.0, &l));  // ∂(l+r)/∂l = 1
        result.children.push_back(std::make_pair(1.0, &r));  // ∂(l+r)/∂r = 1
        return result;
    }
    friend Var operator*(const Var& l, const Var& r) {
        Var result(l.value * r.value);
        result.children.push_back(std::make_pair(r.value, &l));  // ∂(l*r)/∂l = r
        result.children.push_back(std::make_pair(l.value, &r));  // ∂(l*r)/∂r = l
        return result;
    }
    friend std::ostream& operator<<(std::ostream& os, const Var& var) {
        os << var.value;
        return os;
    }
};
```
I tried to run the code like this:

```cpp
int main(int argc, char const *argv[]) {
    Var x(5, 'x'), y(6, 'y'), z(7, 'z');
    Var k = z + x*y;
    k.set_character('k');
    std::cout << "k = " << k << std::endl;
    std::cout << "∂k/∂x = " << k.gradient(&x) << std::endl;
    std::cout << "∂k/∂y = " << k.gradient(&y) << std::endl;
    std::cout << "∂k/∂z = " << k.gradient(&z) << std::endl;
    return 0;
}
```
The computational graph that should be built is the following:

```
  x(5)          y(6)          z(7)
     \          /               |
 ∂w/∂x=y    ∂w/∂y=x             |
       \      /                 |
       w = x*y                  |
           \                    |
        ∂k/∂w=1             ∂k/∂z=1
             \                  |
              \_________________|
                       |
                    k = w+z
```
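With x = 5, y = 6 and z = 7, the edge weights in this graph give the following expected results, so the program should print 37, 6, 5 and 1:

```latex
\begin{aligned}
w &= xy = 30, \qquad k = w + z = 37,\\
\frac{\partial k}{\partial x} &= \frac{\partial k}{\partial w}\frac{\partial w}{\partial x} + \frac{\partial k}{\partial z}\frac{\partial z}{\partial x} = 1\cdot y + 1\cdot 0 = 6,\\
\frac{\partial k}{\partial y} &= \frac{\partial k}{\partial w}\frac{\partial w}{\partial y} + \frac{\partial k}{\partial z}\frac{\partial z}{\partial y} = 1\cdot x + 1\cdot 0 = 5,\\
\frac{\partial k}{\partial z} &= \frac{\partial k}{\partial w}\frac{\partial w}{\partial z} + \frac{\partial k}{\partial z}\frac{\partial z}{\partial z} = 1\cdot 0 + 1\cdot 1 = 1.
\end{aligned}
```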
Then, if I want to calculate ∂k/∂x, for instance, I have to multiply the gradients along each path from k down to x and sum the results over all such paths. This is done recursively by `double gradient(Var* var) const`, so I get ∂k/∂x = ∂k/∂w * ∂w/∂x + ∂k/∂z * ∂z/∂x.
The problem
If I have an intermediate calculation, such as `x*y` here, something goes wrong. When the `std::cout` line in `gradient` is uncommented, this is the output:
```
k = 37
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂x = 0
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂y = 5
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂z = 1
```
It prints which variable is connected to which one, then their addresses, and the weight of the connection (which should be the gradient).
The problem is the `weight=0` on the edge between `x` and the intermediate variable that holds the result of `x*y` (which I denoted `w` in my graph). That weight should be ∂w/∂x = y = 6.
I have no idea why this weight is zero while the other one, connected to `y`, has the correct value 5.
Another thing I noticed is that if I swap the two lines in `operator*`, like so:

```cpp
result.children.push_back(std::make_pair(l.value, &r));
result.children.push_back(std::make_pair(r.value, &l));
```

then it's the `y` connection that becomes zero instead.
Thanks in advance for any help.