limes 3.1.0
Composable Calculus Expressions for C++20
Loading...
Searching...
No Matches
derivative.hpp
Go to the documentation of this file.
1#pragma once
2
#include "nodes/const.hpp"
#include "nodes/var.hpp"
#include "nodes/binary.hpp"
#include "nodes/unary.hpp"

#include <cstddef>
#include <functional>
#include <span>
#include <string>
#include <tuple>
#include <utility>
9
10namespace limes::expr {
11
12// Compile-time derivative: preferred method
13// Returns an expression representing the derivative with respect to dimension Dim
14template<std::size_t Dim, typename E>
15 requires is_expr_node_v<E>
16[[nodiscard]] constexpr auto derivative(E expr) {
17 return expr.template derivative<Dim>();
18}
19
20// =============================================================================
21// AnyExpr: Type-erased expression wrapper for runtime derivative dispatch
22// =============================================================================
23// With type-level simplification (Zero<T>, One<T>), derivatives for different
24// dimensions can return different types. AnyExpr provides type erasure to
25// support runtime dimension selection.
26
27template<typename T>
28struct AnyExpr {
29 using value_type = T;
30 static constexpr std::size_t arity_v = 8; // Maximum supported arity
31
32 std::function<T(std::span<T const>)> eval_fn;
33 std::function<std::string()> to_string_fn;
34
35 AnyExpr() = default;
36
37 template<typename E>
38 requires is_expr_node_v<E>
39 explicit AnyExpr(E expr)
40 : eval_fn([expr](std::span<T const> args) { return expr.eval(args); })
41 , to_string_fn([expr]() { return expr.to_string(); })
42 {}
43
44 [[nodiscard]] T eval(std::span<T const> args) const {
45 return eval_fn(args);
46 }
47
48 // Deprecated: use eval() instead
49 [[nodiscard]] [[deprecated("use eval() instead")]]
50 T evaluate(std::span<T const> args) const {
51 return eval(args);
52 }
53
54 [[nodiscard]] std::string to_string() const {
55 return to_string_fn();
56 }
57};
58
// Runtime derivative dispatch using type erasure.
// The range of dispatched dimensions is chosen by derivative(E, dim) below.
61
62namespace detail {
63
64template<typename E, std::size_t... Is>
65auto derivative_dispatch_impl(E expr, std::size_t dim, std::index_sequence<Is...>) {
66 using T = typename E::value_type;
67
68 // Table of type-erased derivative constructors
69 AnyExpr<T> (*funcs[])(E) = {
70 [](E e) { return AnyExpr<T>{e.template derivative<Is>()}; }...
71 };
72
73 if (dim >= sizeof...(Is)) {
74 return AnyExpr<T>{Zero<T>{}}; // Beyond supported dimensions, return zero
75 }
76
77 return funcs[dim](expr);
78}
79
80} // namespace detail
81
82// Runtime derivative with dimension as parameter
83// Returns type-erased AnyExpr<T> to handle varying derivative types
84// Supports dimensions 0-7
85template<typename E>
86 requires is_expr_node_v<E>
87[[nodiscard]] auto derivative(E expr, std::size_t dim) {
88 return detail::derivative_dispatch_impl(expr, dim, std::make_index_sequence<8>{});
89}
90
91// Gradient: compute all partial derivatives as a tuple
92// Returns tuple<d/dx0, d/dx1, ..., d/dx(N-1)>
namespace detail {

// Expands the index pack into one compile-time partial derivative per
// dimension and packs them, in index order, into a std::tuple.
template<typename Expr, std::size_t... Axes>
constexpr auto gradient_impl(Expr expr, std::index_sequence<Axes...> /*axes*/) {
    return std::make_tuple(expr.template derivative<Axes>()...);
}

} // namespace detail
101
102template<typename E>
103 requires is_expr_node_v<E>
104[[nodiscard]] constexpr auto gradient(E expr) {
105 return detail::gradient_impl(expr, std::make_index_sequence<E::arity_v>{});
106}
107
108// Higher-order derivatives
109// derivative<Dim1, Dim2, ...>(expr) computes d/dx_Dim1 d/dx_Dim2 ... expr
110template<std::size_t Dim, std::size_t... Dims, typename E>
111 requires is_expr_node_v<E>
112[[nodiscard]] constexpr auto derivative_n(E expr) {
113 if constexpr (sizeof...(Dims) == 0) {
114 return expr.template derivative<Dim>();
115 } else {
116 return derivative_n<Dims...>(expr.template derivative<Dim>());
117 }
118}
119
120} // namespace limes::expr
auto derivative_dispatch_impl(E expr, std::size_t dim, std::index_sequence< Is... >)
constexpr auto gradient_impl(E expr, std::index_sequence< Is... >)
Expression layer for composable calculus.
Definition analysis.hpp:7
constexpr auto derivative_n(E expr)
constexpr auto derivative(E expr)
constexpr auto gradient(E expr)
static constexpr std::size_t arity_v
T evaluate(std::span< T const > args) const
std::function< T(std::span< T const >)> eval_fn
T eval(std::span< T const > args) const
std::string to_string() const
std::function< std::string()> to_string_fn