in usr/awf/Julia/awfrdiff.jl [318:524]
"""
    awfrdiff(ex; outsym=nothing, order=1, evalmod=Main, debug=false,
             allorders=true, params...)

Reverse-mode differentiate the expression `ex` with respect to the keyword
parameters `params`, and return the result of `tocode` on the derivative graph
(or the graph itself when `debug=true`).

# Arguments
- `ex`: expression turned into a graph with `awftograph(ex, evalmod)`.
- `outsym`: symbol of the output variable to differentiate. It must be one of
  the graph's set variables, and it must evaluate to a `Real` at the supplied
  parameter values.
- `order`: derivative order. For `order >= 2`, exactly one parameter is
  allowed, and its value must be a `Real` or a `Vector`.
- `evalmod`: module used when building and evaluating the graph (`calc!`).
- `debug`: when `true`, return the graph instead of `tocode(g)`.
- `allorders`: when `true`, the generated code returns a tuple of the output
  value followed by every computed derivative. When `false`, it returns only
  the last derivative.
- `params...`: `name = value` pairs naming the differentiation variables and
  giving the sample values used to evaluate the graph. At least one is
  required.

For `order == 1`, one derivative is produced per parameter. For a single
`Vector` parameter of length `n` with `order >= 2`, derivative `i >= 2` is built
by a loop over the elements of derivative `i-1`. The loop stores each element's
gradient in consecutive linear-index slices of a `zeros(n, ..., n)` array with
`i` dimensions.

Throws an error when no parameter is given, when the `order >= 2` restrictions
are violated, when `outsym` is not found, or when the output is not a `Real`.
"""
function awfrdiff(ex; outsym=nothing, order::Int=1, evalmod=Main, debug=false, allorders=true, params...)
    length(params) >= 1 || error("There should be at least one parameter specified, none found")
    order <= 1 ||
        length(params) == 1 || error("Only one param allowed for order >= 2")
    order <= 1 ||
        isa(params[1][2], Vector) ||
        isa(params[1][2], Real) || error("Param should be a real or vector for order >= 2")

    paramsym = Symbol[ e[1] for e in params]
    paramvalues = [ e[2] for e in params]
    parval = Dict(zip(paramsym, paramvalues))

    # Build the graph, keep only the requested output, simplify, and evaluate it
    # at the sample parameter values (the reverse sweep needs concrete values).
    g = awftograph(ex, evalmod)
    hassym(g.seti, outsym) ||
        error("can't find output var $( outsym === nothing ? "" : outsym)")
    g.seti = NSMap([getnode(g.seti, outsym)], [ outsym ])
    g |> splitnary! |> prune! |> simplify!
    calc!(g, params=parval, emod=evalmod)

    ov = getnode(g.seti, outsym).val
    isa(ov, Real) || error("output var should be a Real, $(typeof(ov)) found")

    # Symbols of the values returned by the generated code: the output value
    # first, then one entry per derivative computed below.
    voi = Any[ outsym ]

    if order == 1
        # A single reverse sweep yields the derivative w.r.t. every parameter.
        dg = reversegraph(g, getnode(g.seti, outsym), paramsym)
        append!(g.nodes, dg.nodes)
        for p in paramsym
            nn = getnode(dg.seti, dprefix(p))
            ns = newvar("_dv")
            g.seti[nn] = ns
            push!(voi, ns)
        end
        g |> splitnary! |> prune! |> simplify!

    elseif order > 1 && isa(paramvalues[1], Real)
        # Scalar parameter: differentiate the previous derivative `order` times.
        # The graph is re-evaluated after each sweep so the next one has values.
        for i in 1:order
            dg = reversegraph(g, getnode(g.seti, voi[i]), paramsym)
            append!(g.nodes, dg.nodes)
            nn = collect(nodes(dg.seti))[1]
            ns = newvar("_dv")
            g.seti[nn] = ns
            push!(voi, ns)
            g |> splitnary! |> prune! |> simplify!
            calc!(g, params=parval, emod=evalmod)
        end

    elseif order > 1 && isa(paramvalues[1], Vector)
        # First derivative (gradient) by a plain reverse sweep.
        dg = reversegraph(g, getnode(g.seti, outsym), paramsym)
        append!(g.nodes, dg.nodes)
        ns = newvar(:_dv)
        g.seti[ collect(nodes(dg.seti))[1] ] = ns
        push!(voi, ns)
        g |> splitnary! |> prune! |> simplify!

        for i in 2:order
            # Differentiate element `si` of the previous derivative `no`. Sample
            # index 1 is used only to obtain concrete values for the sweep.
            no = getnode(g.seti, voi[i])
            si = newvar(:_idx)
            ni = addnode!(g, NExt(si))
            ns = addnode!(g, NRef(:getidx, [ no, ni ]))
            calc!(g, params=Dict(zip([paramsym; si], [paramvalues; 1])), emod=evalmod)
            dg = reversegraph(g, ns, paramsym)

            # Turn `dg` into a self-contained loop body. Every reference to an
            # outer-graph node is replaced by an `NExt` placeholder fed through
            # `dg.exto`, the index placeholder `ni` becomes the loop variable
            # `si`, and `getidx(no, ni)` is rebuilt inside the body.
            dg2 = ExNode[]
            nmap = Dict()
            for n in dg.nodes
                for (j, np) in enumerate(n.parents)
                    if haskey(nmap, np)
                        n.parents[j] = nmap[np]
                    elseif np == ni
                        nn = NExt(si)
                        push!(dg2, nn)
                        dg.exti[nn] = si
                        n.parents[j] = nn
                        nmap[np] = nn
                    elseif np == ns
                        if !haskey(nmap, no)
                            sn = newvar()
                            nn = NExt(sn)
                            push!(dg2, nn)
                            dg.exti[nn] = sn
                            dg.exto[no] = sn
                            nmap[no] = nn
                        end
                        # The index placeholder may not have been met yet.
                        if !haskey(nmap, ni)
                            nn = NExt(si)
                            push!(dg2, nn)
                            dg.exti[nn] = si
                            nmap[ni] = nn
                        end
                        nn = NRef(:getidx, [ nmap[no], nmap[ni] ])
                        push!(dg2, nn)
                        # Repoint this parent too, as the other cases do;
                        # otherwise it keeps referencing the outer node `ns`.
                        n.parents[j] = nn
                        nmap[ns] = nn
                    elseif !(np in dg.nodes)
                        sn = newvar()
                        nn = NExt(sn)
                        push!(dg2, nn)
                        dg.exti[nn] = sn
                        dg.exto[np] = sn
                        n.parents[j] = nn
                        nmap[np] = nn
                    end
                end
                # Nested loops inside `dg`: re-key their outer-node map so it
                # also points at the rewritten nodes.
                if isa(n, NFor)
                    g2 = n.main[2]
                    for (o,s) in g2.exto
                        if haskey(nmap, o)
                            g2.exto[ nmap[o] ] = s
                        end
                    end
                end
            end
            append!(dg.nodes, dg2)

            # Loop node over 1:length(param)^(i-1), i.e. every element of `no`.
            nf = addnode!(g, NFor(Any[ si, dg ] ) )
            nsz = addgraph!( :( length( x ) ), g, Dict( :x => getnode(g.exti, paramsym[1]) ) )
            ndsz = addgraph!( :( sz ^ $(i-1) ), g, Dict( :sz => nsz ) )
            nid = addgraph!( :( 1:dsz ), g, Dict( :dsz => ndsz ) )
            push!(nf.parents, nid)

            # Pass the parameter length (slice stride) into the loop body.
            sst = newvar()
            inst = addnode!(dg, NExt(sst))
            dg.exti[inst] = sst
            dg.exto[nsz] = sst
            push!(nf.parents, nsz)

            # Result accumulator zeros(n, ..., n) with `i` dimensions.
            nsa = addgraph!( :( zeros( $( Expr(:tuple, [:sz for j in 1:i]...) ) ) ),
                g, Dict( :sz => nsz ) )
            ssa = newvar()
            insa = addnode!(dg, NExt(ssa))
            dg.exti[insa] = ssa
            dg.exto[nsa] = ssa
            push!(nf.parents, nsa)

            # Store the gradient of element `si` in the si-th linear-index slice.
            nres = addgraph!( :( res[ ((sidx-1)*st+1):(sidx*st) ] = dx ; res ), dg,
                Dict(:res => insa,
                     :sidx => nmap[ni],
                     :st => inst,
                     :dx => collect(dg.seti)[1][1] ) )
            dg.seti = NSMap([nres], [ssa])
            nex = addnode!(g, NIn(ssa, [nf]))
            dg.seto = NSMap([nex], [ssa])

            # Remaining outer nodes referenced by the body become loop inputs.
            append!( nf.parents, setdiff(collect( nodes(dg.exto)), nf.parents[2:end]) )
            ns = newvar(:_dv)
            g.seti[nex] = ns
            push!(voi, ns)
            g |> splitnary! |> prune! |> simplify!
        end
    end

    if !allorders
        voi = [voi[end]]
    end

    # Single exit: either the lone value or a tuple of all requested values.
    if length(voi) > 1
        voin = map( s -> getnode(g.seti, s), voi )
        nf = addnode!(g, NConst(tuple))
        exitnode = addnode!(g, NCall(:call, [nf, voin...]))
    else
        exitnode = getnode(g.seti, voi[1])
    end
    g.seti = NSMap( [exitnode], [nothing])
    g |> splitnary! |> prune! |> simplify!
    resetvar()    # presumably resets the `newvar` name counter — confirm
    debug ? g : tocode(g)
end