“计算图”(computational graph)是现代深度学习系统的基础执行引擎,提供了一种表示任意数学表达式的方法,例如用有向无环图表示的神经网络。 图中的节点表示基本操作或输入变量,边表示节点之间的中间值的依赖性。 例如,下图就是一个函数 f(x1,x2)=lnx1+x1x2−sinx2 的计算图。
现在给定一个计算图,请你根据所有输入变量计算函数值及其偏导数(即梯度)。 例如,给定输入x1=2,x2=5,上述计算图获得函数值 f(2,5)=ln(2)+2×5−sin(5)=11.652;并且根据微分链式法则,上图得到的梯度 ∇f=[∂f/∂x1,∂f/∂x2]=[1/x1+x2,x1−cosx2]=[5.500,1.716]。
知道你已经把微积分忘了,所以这里只要求你处理几个简单的算子:加法、减法、乘法、指数(ex,即编程语言中的 exp(x) 函数)、对数(lnx,即编程语言中的 log(x) 函数)和正弦函数(sinx,即编程语言中的 sin(x) 函数)。
友情提醒:
- 常数的导数是 0;x 的导数是 1;ex 的导数还是 ex;lnx 的导数是 1/x;sinx 的导数是 cosx。
- 回顾一下什么是偏导数:在数学中,一个多变量的函数的偏导数,就是它关于其中一个变量的导数而保持其他变量恒定。在上面的例子中,当我们对 x1 求偏导数 ∂f/∂x1 时,就将 x2 当成常数,所以得到 lnx1 的导数是 1/x1,x1x2 的导数是 x2,sinx2 的导数是 0。
- 回顾一下链式法则:复合函数的导数是构成复合这有限个函数在相应点的导数的乘积,即若有 u=f(y),y=g(x),则 du/dx=du/dy⋅dy/dx。例如对 sin(lnx) 求导,就得到 cos(lnx)⋅(1/x)。
如果你注意观察,可以发现在计算图中,计算函数值是一个从左向右进行的计算,而计算偏导数则正好相反。
输入格式:
输入在第一行给出正整数 N(≤5×104),为计算图中的顶点数。
以下 N 行,第 i 行给出第 i 个顶点的信息,其中 i=0,1,⋯,N−1。第一个值是顶点的类型编号,分别为:
- 0 代表输入变量
- 1 代表加法,对应 x1+x2
- 2 代表减法,对应 x1−x2
- 3 代表乘法,对应 x1×x2
- 4 代表指数,对应 ex
- 5 代表对数,对应 lnx
- 6 代表正弦函数,对应 sinx
对于输入变量,后面会跟它的双精度浮点数值;对于单目算子,后面会跟它对应的单个变量的顶点编号(编号从 0 开始);对于双目算子,后面会跟它对应两个变量的顶点编号。
题目保证只有一个输出顶点(即没有出边的顶点,例如上图最右边的 -
),且计算过程不会超过双精度浮点数的计算精度范围。
输出格式:
首先在第一行输出给定计算图的函数值。在第二行顺序输出函数对于每个变量的偏导数的值,其间以一个空格分隔,行首尾不得有多余空格。偏导数的输出顺序与输入变量的出现顺序相同。输出小数点后 3 位。
输入样例:
7
0 2.0
0 5.0
5 0
3 0 1
6 1
1 2 3
2 5 4
输出样例:
11.652
5.500 1.716
分析:结构体A中存储相关节点信息,de[i]表示第i个点是否被其他节点调用操作到,使用deal做dfs操作来根据题意进行题目的模拟运行,其中,c为1表示求导,c为0表示正常晕眩,index表示当前节点信息,to为求偏导数得目标,Record则是一个dfs下的记忆优化,start为一个可以作为出发点的节点。依题意输入数据后,遍历找出一个出发点节点,然后直接丢到deal函数中即可得到答案~注意导数运算的相关规则~
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
#include <bits/stdc++.h> using namespace std; struct node { int which, op1, op2; double value; }A[50001]; int N, start = -1, cnt, de[50001]; map<bool, map<int, map<int, double>>> Record; double deal(int c, int index, int to) { if (Record[c][index][to]) return Record[c][index][to]; if (A[index].which == 0) return Record[c][index][to] = c ? (index == to ? 1.0 : 0.0) : A[index].value; else if (A[index].which == 1) return Record[c][index][to] = deal(c, A[index].op1, to) + deal(c, A[index].op2, to); else if (A[index].which == 2) return Record[c][index][to] = deal(c, A[index].op1, to) - deal(c, A[index].op2, to); else if (A[index].which == 3) return Record[c][index][to] = c ? (deal(1, A[index].op1, to) * deal(0, A[index].op2, to) + deal(0, A[index].op1, to) * deal(1, A[index].op2, to)) : (deal(0, A[index].op1, to) * deal(0, A[index].op2, to)); else if (A[index].which == 4) return Record[c][index][to] = c ? exp(deal(0, A[index].op1, to)) * deal(1, A[index].op1, to) : exp(deal(0, A[index].op1, to)); else if (A[index].which == 5) return Record[c][index][to] = c ? 1.0 / deal(0, A[index].op1, to) * deal(1, A[index].op1, to) : log(deal(0, A[index].op1, to)); else return Record[c][index][to] = c ? cos(deal(0, A[index].op1, to)) * deal(1, A[index].op1, to) : sin(deal(0, A[index].op1, to)); } int main() { scanf("%d", &N); for (int i = 0; i < N; i++) { scanf("%d", &A[i].which); if (A[i].which == 0) { ++cnt; scanf("%lf", &A[i].value); } else if (A[i].which < 4) { scanf("%d%d", &A[i].op1, &A[i].op2); de[A[i].op1] = de[A[i].op2] = 1; } else { scanf("%d", &A[i].op1); de[A[i].op1] = 1; } } while(de[++start]); printf("%.3lf\n", deal(0, start, N)); for (int i = 0; i < N; i++) { if (A[i].which == 0) printf("%.3lf%c", deal(1, start, i), --cnt ? ' ' : '\n'); } return 0; } |
❤ 点击这里 -> 订阅《PAT | 蓝桥 | LeetCode学习路径 & 刷题经验》by 柳婼