作者:罗秀哲
出处:https://zhuanlan.zhihu.com/p/47592565
在Julia里写一个自己的自动微分需要多少行代码?我花了一天时间写了一个自己的自动微分,算上注释,支持多维数组的大部分运算(以及broadcast),大概也就200~400行的样子。
我之前写了一篇英文的博文: Implement You Own AD in ONE day 当天上了Hacker News头版,现在在等PyTorch编译的时候再写个中文版的。
其实我最近在算一些张量网络,但是我不想像过去一样手动求导,用PyTorch写了一遍,发现因为没有batch trace之类的操作,我需要在Python里写几个循环。不是很想用Torch Script(为什么总要我在Python里面写别的语言?!)然而去看了一下Julia的framework,Knet使用的是和PyTorch类似的AutoGrad(它们基本都是按照Python的autograd来写的),没有按照Julia的performance tips实现,速度比较烂。Flux的速度很快(至少在我的情况下),然而我不想把导数定义成closure,我想要一个尽可能简单的,只需要几个type+multiple dispatch和两个接口就能搞定的自动微分,而且这个自动微分的一切命名和接口都要尽可能接近PyTorch(PyTorch有用户体验最好的接口)。
结果最后写这么个东西只需要一天不到的时间,我把代码整理成了一个package,你可以在GitHub上找到: Roger-luo/YAAD.jl 。
在这篇文章里,我会介绍我是怎么实现这个自动微分的,因为代码量非常小,也许你也可以写一个玩玩。
自动微分简介
自动微分(automatic differentiation)技术在机器学习里也叫做后向传播,它的原理实际上是通过记录运算顺序,利用已经定义好的导数规则,生成一个正常计算程序对偶的程序。一般来说有两种自动微分方式,一种是前向自动微分(Forward Automatic Differentiation)另外一种是后向自动微分(Reverse Automatic Differentiation),后者更加适合多参数的情况(算法细节就不详述了,多参数的时候后向自动微分的时间复杂度更低,一次传播可以计算所有的参数)。
后向自动微分会讲所有的操作以一张图的方式存储下来,这张图称为计算图。这也是各大深度学习框架的核心所在——如何干净地产生一个计算图,然后高效计算它。
为了展示计算图是什么,我从Cornell,CS5740,2017sp这门课的课件里搬运了一些图,然后把他们做成了动画(动画使用 Luxor.jl 制作,你也许想看看我是怎么画出来的,绘图的脚本在这里: plot.jl )
我们以计算下面这个表达式为例:
我们将会调用这样一些Julia函数来计算它:
- : transpose 函数
- :矩阵向量乘,你可以直接通过 * 号来计算,当然也可以使用BLAS routine BLAS.gemv来计算
- 向量的点积,这个可以调用函数 dot,也可以用UTF8字符 x ? y
- 另外一个向量点积
- 最后是把它们相加,使用了乘号 +
实际上我们可以通过一张图来表示这个表达式的计算过程,每一个拥有输入边的节点代表了一个函数,而每一个拥有输出边的节点代表了一个变量。
而计算这样一张图,我们将先从叶子结点开始赋值(绿色),然后依次计算相邻的节点,不断向前传播。这个过程称为前向传播过程。
接下来我们按照链式法则来计算导数,每个节点返回的导数都和输入的数量相同,我们从最上面的节点开始向后传播,将当前有导数的节点标记为红色。
当红色传播到变量处时,我们就获得了变量的导数。
动态图 VS 静态图
按照构建计算图的方式不同,我们可以将计算图分为动态图和静态图两种,尽管在算法上并没有很大区别,但是在实现上我们可以选择在前向传播的过程中构建计算图(比如PyTorch),也可以选择先构建计算图再计算各个节点的值(比如tensorflow)。
就我个人而言,我比较喜欢PyTorch,所以这里我将实现一个动态图。
定义计算图中的节点
在我们开始写具体的实现之前,先来为所有的节点类型定义一个抽象类型(类似于基类):
abstract type AbstractNode end
在PyTorch里,能够拥有导数的称为变量(Variable),尽管在0.4版本之后Tensor默认就是一个Variable了(有requires_grad为True),在后端依然还有这个类型。它是对计算图构建过程中不可或缺的类型。接下来我们来定义变量(Variable)
mutable struct Variable{T} <: AbstractNode
value::T
grad::T
Variable(val::T) where T = new{T}(val, zero(grad))
end
类似PyTorch一样,变量存储了值(value)和它的梯度(grad),在每一次后向传播的过程中我们将会不断地将梯度累加到这个变量的梯度上去。这里 zero 是几乎所有Julia数值类型都有的一个接口,它将放回对应的零元素,例如对 Float64 类型的Julia变量,将返回0.0,对Array{Float64}将返回一个充满0.0的 Array{Float64}。
其它节点
我们现在有了变量了,也就是计算图的叶子结点,接下来还需要有中间的节点。它们将存储一个函数和它们的输入
struct Node{FT <: Function, ArgsT <: Tuple, KwargsT <: NamedTuple} <: AbstractNode
f::FT
args::ArgsT
kwargs::KwargsT
end
我们这里使用参数类型,这样在将来进行分发的时候,编译器能够自己通过类型推导出要分发的函数从而提高运行时的性能。
但我们应当要考虑broadcast(广播)和正常的函数调用的区别,由于Julia能够对任意函数进行广播,广播时所调用的实际上是 broadcast 函数,所以我们不妨实现两个trait来区分这种情况:
abstract type Operator end
module Trait
import YAAD: Operator
struct Method{FT} <: Operator
f::FT
end
struct Broadcasted{FT} <: Operator
f::FT
end
end # Trait
这里我将这两个trait实现在一个module里面是为了能够显示地体现出他们俩是trait,因为之后调用的时候将会写为 Trait.Method 和 Trait.Broadcasted,他们各自存储了一个函数(注意Julia里每个函数都是一个callable的类型)。
然后我们把原先Node类型的参数约束Function改成Operator
struct Node{FT <: Operator, ArgsT <: Tuple, KwargsT <: NamedTuple} <: AbstractNode
f::FT
args::ArgsT
kwargs::KwargsT
end
接下来为了方便我们来定义一些构造函数
# wrap function to Method
Node(f::Function, args, kwargs) = Node(Trait.Method(f), args, kwargs)
Node(op, args) = Node(op, args, NamedTuple())
第一个是因为大部分时间,我们要记录的函数就是它本身而不是一个广播,第二个是因为大部分涉及数值计算的函数都没有关键字(keyword)。
实际上,Node类型本身也只是函数和它的输入类型的一个trait,它在计算的过程中也只是负责(静态地)分发方法。在更加高级的实现里,我们实际上有更加漂亮的实现,利用Cassette.jl对Julia代码进行非侵入式地自动微分(意思是无需给源码重载运算符,增加Variable类型,编译器将直接在JIT期间对前向传播的代码进行变换,从而直接得到计算梯度的代码)。
最后,我们还需要定义一个缓存函数输出的对象,这个缓存的值将会被一些函数的导数用到
mutable struct CachedNode{NT <: AbstractNode, OutT} <: AbstractNode
node::NT
output::OutT
end
而这个节点将在前向传播的同时被构建出来(否则我们无法知道输出的类型是什么)
function CachedNode(f, args...; kwargs...)
node = Node(f, args, kwargs.data) # this constructs a Node
output = forward(node)
CachedNode(node, output)
end
我们暂且把这个接口定义为forward(与PyTorch一致)
求值
求值是最重要的部分,因为我们需要将我们的自动微分设计地可扩展,尽量不要在扩展的时候编写冗余的代码。而在Julia里,我们可以利用多重派发(multiple dispatch)实现这一点。
前向传播
那么如何进行前向传播呢?这取决于对于 forward 这个抽象函数(generic function),实现了什么方法(method):
- 如果输入是一个 Node 类型,我们将其展开
forward(node.f, map(forward, node.args)...; map(forward, node.kwargs)...)
2. 这将使得我们多了一层插入自定义方法的接口,如果我们有一个自定义的算符,它并非一个函数,我们只需要实现对应的方法即可,例如
struct Linear <: Operator
w::Matrix{Float64}
b::Vector{Float64}
end
forward(op::Linear, x::Vector{Float64}) = op.w * x + b
3. 然而对于简单的函数调用,我们并不想每次都写
function forward(::Method{typeof(sin)}, x)
sin(x)
end
所以我们再实现一个默认展开Operator的方法
forward(op::Operator, args...; kwargs...) = op(args...; kwargs...)
这意味着只要Operator实现了自己的call方法(如果这个Operator类型是callable的),那么就无需去写别的东西,自动调用这个方法。当然我们现在要回去给Method Trait实现一下它的call方法
(op::Trait.Method)(args...; kwargs...) = op.f(args...; kwargs...)
例如,我们现在只需要定义 Linear的call方法就够了
(op::Linear)(x::Vector) = op.w * x + op.b
4. 此外,除了变量,还有一些常数例如
Variable(2.0) + 3.0
这里的3.0就是一个不需要求导的常数,我们原封不动地返回它,这样我们只要实现一个 value 接口来获取值即可
value(x) = x
value(x::Variable) = x.value
value(x::CachedNode) = x.output
然后直接调用value
forward(x) = x
forward(x::Variable) = value(x)
然后别忘了,对于其它类型我们返回一个友好一些的报错
forward(x::NT) where {NT <: AbstractNode} = error("forward method is not implemented for node type: $NT")
function value(x::T) where {T <: AbstractNode}
error(
"Expected value in this node $x of type $T ",
"check if you defined a non-cached node",
" or overload value function for your node."
)
end
然后对于 Variable和CachedNode我们要返回它们存储的值,好的:ok_hand:,到目前为止,我们已经搞定前向传播部分了,接下来是后向传播部分。
后向传播
后向传播实际上和前向传播几乎是一样的,我们只要不断地在不同的类型标签下迭代backward接口即可(注意我不打算在这里实现关键词的后向传播,尽管这并不难)
首先,对Variable来说,这很简单直接加接收到的梯度就好了
function backward(x::Variable, grad)
x.grad += grad
nothing
end
然后我们现在定义 CachedNode的后向传播规则
- 我们先从一个叫gradient的方法里获得各个输入的导数
- 然后再把这些导数依次输入到输入类型对应的backward函数里去
function backward(node::CachedNode, f, grad)
grad_inputs = gradient(node, grad)
for (each, each_grad) in zip(args(node), grad_inputs)
backward(each, each_grad)
end
nothing
end
等等,我们要在这里加一些友好的报错信息,面得以后我们自己抓狂。首先是类型的检查,这完全是静态的,所以不同担心会影响性能
backward_type_assert(node::CachedNode{<:AbstractNode, T}, grad::T) where T = true
backward_type_assert(node::CachedNode{<:AbstractNode, T1}, grad::T2) where {T1, T2} =
error("Gradient is expected to have the same",
" type with outputs, expected $T1",
" got $T2")
我们在这里要求输出和梯度的类型要一样,但是对于多维数组(AbstractArray)我们只要求它们的数据类型和维度相同即可,因为有可能一些函数会返回特别优化的数组(例如稀疏数组,或者一些懒惰求值的中间结果)。
# exclude arrays
backward_type_assert(node::CachedNode{<:AbstractNode, T1}, grad::T2) where
{T, N, T1 <: AbstractArray{T, N}, T2 <: AbstractArray{T, N}} = true
然后我们还要检查梯度和输出的大小是否匹配
function backward_size_assert(node::CachedNode, grad)
size(node.output) == size(grad) ||
error(
"gradient should have the same size with output,",
" expect size $(size(node.output)), got $(size(grad))"
)
end
在Julia里,可以通过编译选项把边界检查关掉,因为我们有时候完全不需要边界检查,你可以通过增加 @boundscheck 这个宏来实现这一点,最后我们的backward函数如下:
function backward(node::CachedNode, f, grad)
backward_type_assert(node, grad)
@boundscheck backward_size_assert(node, grad)
grad_inputs = gradient(node, grad)
for (each, each_grad) in zip(args(node), grad_inputs)
backward(each, each_grad)
end
nothing
end
现在我们来考虑如何定义梯度,也就是gradient方法,我们依然希望不要写冗余的代码,同时保证性能和扩展性。比如,实现sin的导数只需要定义
gradient(::typeof(sin), grad, output, x) = grad * cos(x)
我们还是利用多重派发来实现这一点,先把 CachedNode展开
gradient(x::CachedNode, grad) = gradient(x.node.f, grad, x.output, map(value, x.node.args)...; map(value, x.node.kwargs)...)
然后把Operator展开到函数上去
gradient(x::Trait.Method, grad, output, args...; kwargs...) =
gradient(x.f, grad, output, args...; kwargs...)
最后定义一个报错信息
gradient(fn, grad, output, args...; kwargs...) =
error(
"gradient of operator $fn is not defined\n",
"Possible Fix:\n",
"define one of the following:\n",
"1. gradient(::typeof($fn), grad, output, args...; kwargs...)\n",
"2. gradient(op::Trait.Method{typeof($fn)}, grad, output, args...; kwargs...)\n",
"3. gradient(op::Trait.Broadcasted{typeof($fn)}, grad, output, args...; kwargs...)\n"
)
这样,我们就可以选择不同的gradient接口来实现导数,Julia将自动派发你实现的这个方法,例如
# I re-define the concrete type `Linear` here in order to store the gradient
struct Linear <: Operator
w::Variable{Matrix{Float64}}
b::Variable{Vector{Float64}}
end
function gradient(op::Linear, grad, output, x)
grad_w, grad_b = # some gradient expression to calculate the gradient of w and b
backward(op.w, grad_w) # update gradient of w
backward(op.w, grad_b) # update gradient of b
grad_input = # calculate the gradient of input
grad_input # return the gradient of input
end
最后我们定义一个 register 的接口用来产生 CachedNode
register(f, args...; kwargs...) = CachedNode(f, args...; kwargs...)
这样我们就可以通过重载函数/运算符来构建计算图了
Base.sin(x::AbstractNode) = register(Base.sin, x)
gradient(::typeof(Base.sin), grad, output, x) = (grad * cos(x), )
不过等等,似乎这里有时候需要判断一下输入是什么类型比较好,我们不妨为Variable和CachedNode定义一个抽象类型Value
abstract type Value{T} <: AbstractNode end
Value类型将带有其子类型的值的类型T作为其参数。现在先回去修改 Variable和CachedNode
mutable struct Variable{T} <: Value{T}
value::T
grad::T
Variable(val::T) where T = new{T}(val, zero(grad))
end
mutable struct CachedNode{NT <: AbstractNode, OutT} <: Value{OutT}
node::NT
output::OutT
end
广播
然而上面的定义还只能给标量用,对于数组我们还需要广播才行。Julia自己实现了一套广播系统,它能够广播任何Julia函数到数组上,会融合多个被广播的函数(从而产生更优质的向量化SIMD代码),同时还允许定义广播的行为。这恰好就是我们需要的:我们要在广播的同时产生一个计算图,记录这个操作
首先我们定义我们自己的广播风格(BroadcastStyle):
struct ComputGraphStyle <: Broadcast.BroadcastStyle end
Base.BroadcastStyle(::Type{<:AbstractNode}) = ComputGraphStyle()
Broadcast.BroadcastStyle(s::ComputGraphStyle, x::Broadcast.BroadcastStyle) = s
这还不够,Julia的broadcast是懒惰求值的,它先通过broadcasted方法构建中间类型,然后再在最后通过materialize方法进行求值。我们还需要让它们也被记录在计算图里
function Broadcast.broadcasted(::ComputGraphStyle, f, args...)
mt = Trait.Broadcasted(f)
register(mt, args...)
end
Broadcast.materialize(x::AbstractNode) = register(Broadcast.materialize, x)
然后我们让materialize在后向传播的时候直接返回梯度
function backward(node::CachedNode, ::typeof(Broadcast.materialize), grad)
backward_type_assert(node, grad)
@boundscheck backward_size_assert(node, grad)
backward(node.node.args[1], grad) # materialize only has one arguments, we don't need the for loop
end
然而这时,Broadcasted类型的backward会调用默认的CachedNode的backward方法,有时就会因为类型不同报错(因为我们之前这么定义了)我们为这个类型开个后门
function backward(node::CachedNode, ::Trait.Broadcasted, grad)
grad_inputs = gradient(node, grad)
for (each, each_grad) in zip(args(node), grad_inputs)
backward(each, each_grad)
end
nothing
end
免费获得更多的算符
Julia有一个包叫做 DiffRules.jl 它记录了大量常用算符的导数规则,并且这些导数规则都以Julia表达式的方式记录,这意味着我们可以利用元编程批量生产算符。
这些导数规则都在一个常数列表里,名为DiffRules.DEFINED_DIFFRULES,我们遍历它即可
for (mod, name, nargs) in keys(DiffRules.DEFINED_DIFFRULES)
# code generation
end
这里 mod 是module的名字,name是函数的名字,nargs是函数输入变量的个数,然后我们就可以用如下的方式来批量产生这些导数的定义
for (mod, name, nargs) in keys(DiffRules.DEFINED_DIFFRULES)
f_ex_head = Expr(:., mod, QuoteNode(name))
if nargs == 1
df_ex = DiffRules.diffrule(mod, name, :x)
name === :abs && continue # exclude abs, it cannot be directly broadcasted
@eval begin
$(f_ex_head)(x::AbstractNode) = register($(f_ex_head), x)
gradient(::typeof($(f_ex_head)), grad, output, x) = (grad * $df_ex, )
gradient(mt::Trait.Broadcasted{typeof($f_ex_head)}, grad, output, x) = (@.(grad * $(df_ex)), )
end
elseif nargs == 2
df_ex = DiffRules.diffrule(mod, name, :x, :y)
@eval begin
$(f_ex_head)(x1::AbstractNode, x2) = register($f_ex_head, x1, x2)
$(f_ex_head)(x1, x2::AbstractNode) = register($f_ex_head, x1, x2)
$(f_ex_head)(x1::AbstractNode, x2::AbstractNode) = register($f_ex_head, x1, x2)
gradient(::typeof($f_ex_head), grad, output, x, y) =
(grad * $(df_ex[1]), grad * $(df_ex[2]))
gradient(::Trait.Broadcasted{typeof($f_ex_head)}, grad, output, x, y) =
(@.(grad * ($(df_ex[1]))), @.(grad * $(df_ex[2])))
end
else
@info "unknown operator $name"
end
end
对如何使用代码生成,我建议你阅读Julia的文档: 元编程 · Julia中文文档 。我在这里跳过了 abs 函数是因为批量广播的宏不能对if else进行广播。我们需要单独去定义abs的导数,但是剩下几乎所有的数学函数都用Diffrules生成了。
代码修饰
之后我又花了一些时间实现仿照PyTorch了一个计算Jacobbian的函数用来做单元测试。然后利用Trait将数组类型的Variable重新插入AbstractArray的类型树中以实现更好的打印信息。
https:// github.com/Roger-luo/YA AD.jl/blob/master/src/show.jl Roger-luo/YAAD.jl Roger-luo/YAAD.jl
性能对比
好了!到此我们就写完了这个自动微分库了,它的性能怎么样呢?我起初以为这么简单的一个实现只是一个玩具,但实际上它的性能非常不错!
我需要计算一个称为MPS的东西(Matrix product state),所以我在这里使用了我使用最频繁的操作进行benchmark,这个操作是 tr(x1 * x2) ,这里 x1和x2是矩阵,然后对其求迹。
所以我首先为YAAD实现了这两个算符:
# 这一部分其实已经在DiffRules进行代码生成的时候定义过了
Base.:(*)(lhs::Value, rhs) = register(Base.:(*), lhs, rhs)
Base.:(*)(lhs, rhs::Value) = register(Base.:(*), lhs, rhs)
Base.:(*)(lhs::Value, rhs::Value) = register(Base.:(*), lhs, rhs)
# 这里开始是新的定义
using LinearAlgebra
LinearAlgebra.tr(x::Value) = register(LinearAlgebra.tr, x)
gradient(::typeof(tr), grad, output, x) = (grad * Matrix(I, size(x)), )
function gradient(::typeof(*), grad, output, lhs::AbstractVecOrMat, rhs::AbstractVecOrMat)
grad * transpose(rhs), transpose(lhs) * grad
end
然后我选取了几个Julia的库(Zygote,Flux,YAAD是我的),还有PyTorch在CPU上进行了一下比较
Zygote.@grad LinearAlgebra.tr(x) = LinearAlgebra.tr(x), Δ-> (Δ * Matrix(I, size(x)), )
function bench_tr_mul_yaad(x1, x2)
z = tr(x1 * x2)
YAAD.backward(z)
x1.grad, x2.grad
end
function bench_tr_mul_autograd(x1, x2)
z = AutoGrad.@diff tr(x1 * x2)
AutoGrad.grad(z, x1), AutoGrad.grad(z, x2)
end
function bench_tr_mul_zygote(x1, x2)
Zygote.gradient((x1, x2)->tr(x1 * x2), x1, x2)
end
function bench_tr_mul_flux(x1, x2)
z = tr(x1 * x2)
Flux.Tracker.back!(z, 1)
x1.grad, x2.grad
end
然后在Python里测试PyTorch(我们的接口和PyTorch非常相似不是吗?)
def bench_tr_mul_torch(x1, x2):
z = torch.trace(torch.matmul(x1, x2))
z.backward()
return x1.grad, x2.grad
然后输入定义如下:
xv, yv = rand(30, 30), rand(30, 30)
yaad_x, yaad_y = YAAD.Variable(xv), YAAD.Variable(yv)
autograd_x, autograd_y = AutoGrad.Param(xv), AutoGrad.Param(yv)
flux_x, flux_y = Flux.param(xv), Flux.param(yv)
此外,在进行测试之前,我们实现一个手动计算梯度的版本作为基准:
function bench_tr_mul_base(x1, x2)
z1 = x1 * x2
z2 = tr(z1)
grad_z1 = Matrix{eltype(z1)}(I, size(z1))
grad_z1 * transpose(x2), transpose(x1) * grad_z1
end
然后在Julia里我们用 @benchmark 宏来多次测量以获取运行时间
julia> @benchmark bench_tr_mul_autograd(autograd_x, autograd_y)
BenchmarkTools.Trial:
memory estimate: 33.20 KiB
allocs estimate: 82
--------------
minimum time: 50.218 μs (0.00% GC)
median time: 62.364 μs (0.00% GC)
mean time: 90.422 μs (9.86% GC)
maximum time: 55.386 ms (99.86% GC)
--------------
samples: 10000
evals/sample: 1
julia> @benchmark bench_tr_mul_yaad(yaad_x, yaad_y)
BenchmarkTools.Trial:
memory estimate: 51.50 KiB
allocs estimate: 16
--------------
minimum time: 10.387 μs (0.00% GC)
median time: 13.429 μs (0.00% GC)
mean time: 24.273 μs (45.13% GC)
maximum time: 55.963 ms (99.96% GC)
--------------
samples: 10000
evals/sample: 1
julia> @benchmark bench_tr_mul_zygote(xv, yv)
BenchmarkTools.Trial:
memory estimate: 29.98 KiB
allocs estimate: 10
--------------
minimum time: 42.527 μs (0.00% GC)
median time: 46.640 μs (0.00% GC)
mean time: 56.996 μs (15.31% GC)
maximum time: 51.718 ms (99.90% GC)
--------------
samples: 10000
evals/sample: 1
julia> @benchmark bench_tr_mul_base(xv, yv)
BenchmarkTools.Trial:
memory estimate: 28.78 KiB
allocs estimate: 5
--------------
minimum time: 6.413 μs (0.00% GC)
median time: 8.201 μs (0.00% GC)
mean time: 12.215 μs (31.57% GC)
maximum time: 11.012 ms (99.87% GC)
--------------
samples: 10000
evals/sample: 5
julia> @benchmark bench_tr_mul_flux(flux_x, flux_y)
BenchmarkTools.Trial:
memory estimate: 30.25 KiB
allocs estimate: 24
--------------
minimum time: 8.009 μs (0.00% GC)
median time: 10.002 μs (0.00% GC)
mean time: 14.412 μs (30.14% GC)
maximum time: 16.286 ms (99.87% GC)
--------------
samples: 10000
evals/sample: 3
然后PyTorch (0.4.1) 上
In [4]: x = torch.rand(30, 30, dtype=torch.float64, requires_grad=True)
In [5]: y = torch.rand(30, 30, dtype=torch.float64, requires_grad=True)
In [6]: %timeit bench_tr_mul_torch(x, y)
76.8 μs ± 1.68 μs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
所以我们花了小半天实现的这个自动微分还不赖嘛?只比基准性能满了几个微秒,意外的是它竟然比PyTorch快了不少。然后Flux的Tracker性能竟然非常接近手动求导!
然而让我觉得奇怪的是Zygote的性能竟然和Python的Autograd的Julia实现(因为实现是差不多的,所以速度都不怎么样)差不多,而Zygote是基于源代码变换的,理论上应该有最好的性能。我后来提交了这个issue: Backward performance of `tr(x1 * x2)` · Issue #28 · FluxML/Zygote.jl
所以,在Julia里实现你自己的自动微分并不难,你也可以像我一样一天实现一个你自己的自动微分。不妨试试?
作者:罗秀哲
出处:https://zhuanlan.zhihu.com/p/47592565