limes 3.1.0
Composable Calculus Expressions for C++20
Loading...
Searching...
No Matches
binary_func.hpp
Go to the documentation of this file.
1#pragma once
2
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <span>
#include <string>
#include <type_traits>

#include "binary.hpp"
#include "primitives.hpp"
10
11namespace limes::expr {
12
13// Binary function tags
14struct PowTag {};
15struct MaxTag {};
16struct MinTag {};
17
18template<typename Tag, typename L, typename R> struct BinaryFunc;
19
20// Type traits for BinaryFunc and specific tags
21template<typename E> inline constexpr bool is_binary_func_v = false;
22template<typename Tag, typename L, typename R>
23inline constexpr bool is_binary_func_v<BinaryFunc<Tag, L, R>> = true;
24
25template<typename E> inline constexpr bool is_runtime_pow_v = false;
26template<typename L, typename R>
27inline constexpr bool is_runtime_pow_v<BinaryFunc<PowTag, L, R>> = true;
28
29template<typename E> inline constexpr bool is_max_v = false;
30template<typename L, typename R>
31inline constexpr bool is_max_v<BinaryFunc<MaxTag, L, R>> = true;
32
33template<typename E> inline constexpr bool is_min_v = false;
34template<typename L, typename R>
35inline constexpr bool is_min_v<BinaryFunc<MinTag, L, R>> = true;
36
37// BinaryFunc<Tag, L, R>: A binary function applied to two child expressions
38template<typename Tag, typename L, typename R>
39struct BinaryFunc {
40 using value_type = typename L::value_type;
41 using tag_type = Tag;
42 using left_type = L;
43 using right_type = R;
44
45 static constexpr std::size_t arity_v = std::max(L::arity_v, R::arity_v);
46
49
50 constexpr BinaryFunc(L l, R r) noexcept : left{l}, right{r} {}
51
52 [[nodiscard]] constexpr value_type eval(std::span<value_type const> args) const {
53 value_type l_val = left.eval(args);
54 value_type r_val = right.eval(args);
55
56 if constexpr (std::is_same_v<Tag, PowTag>) {
57 return std::pow(l_val, r_val);
58 } else if constexpr (std::is_same_v<Tag, MaxTag>) {
59 return std::max(l_val, r_val);
60 } else if constexpr (std::is_same_v<Tag, MinTag>) {
61 return std::min(l_val, r_val);
62 }
63 }
64
65 [[nodiscard]] [[deprecated("use eval() instead")]]
66 constexpr value_type evaluate(std::span<value_type const> args) const {
67 return eval(args);
68 }
69
70 // Derivatives:
71 // pow(f, g): d/dx[f^g] = f^(g-1) * (g*f' + f*g'*ln(f))
72 // max/min: subgradient via 0.5*(f'+g') +/- 0.5*sign(f-g)*(f'-g')
73 template<std::size_t Dim>
74 [[nodiscard]] constexpr auto derivative() const {
75 auto df = left.template derivative<Dim>();
76 auto dg = right.template derivative<Dim>();
77
78 if constexpr (std::is_same_v<Tag, PowTag>) {
79 auto g_minus_1 = right - One<value_type>{};
80 auto f_to_g_minus_1 = BinaryFunc<PowTag, L, decltype(g_minus_1)>{left, g_minus_1};
81 auto log_f = UnaryFunc<LogTag, L>{left};
82 return f_to_g_minus_1 * (right * df + left * dg * log_f);
83 } else if constexpr (std::is_same_v<Tag, MaxTag>) {
84 auto half = Const<value_type>{value_type(0.5)};
85 auto diff = left - right;
86 auto sign_diff = diff / UnaryFunc<AbsTag, decltype(diff)>{diff};
87 return half * (df + dg) + half * sign_diff * (df - dg);
88 } else if constexpr (std::is_same_v<Tag, MinTag>) {
89 auto half = Const<value_type>{value_type(0.5)};
90 auto diff = left - right;
91 auto sign_diff = diff / UnaryFunc<AbsTag, decltype(diff)>{diff};
92 return half * (df + dg) - half * sign_diff * (df - dg);
93 }
94 }
95
96 [[nodiscard]] std::string to_string() const {
97 std::string func_name;
98 if constexpr (std::is_same_v<Tag, PowTag>) {
99 func_name = "pow";
100 } else if constexpr (std::is_same_v<Tag, MaxTag>) {
101 func_name = "max";
102 } else if constexpr (std::is_same_v<Tag, MinTag>) {
103 func_name = "min";
104 }
105 return "(" + func_name + " " + left.to_string() + " " + right.to_string() + ")";
106 }
107};
108
109// Factory functions with compile-time simplification
110
111// pow(expr, expr)
112template<typename L, typename R>
113 requires (is_expr_node_v<L> && is_expr_node_v<R>)
114[[nodiscard]] constexpr auto pow(L base, R exponent) {
115 if constexpr (is_zero_v<R>) {
117 } else if constexpr (is_one_v<R>) {
118 return base;
119 } else if constexpr (is_one_v<L>) {
121 } else if constexpr (is_const_expr_v<L> && is_const_expr_v<R>) {
122 return Const<typename L::value_type>{std::pow(base.value, exponent.value)};
123 } else {
124 return BinaryFunc<PowTag, L, R>{base, exponent};
125 }
126}
127
128// pow(expr, scalar)
129template<typename L, typename T>
130 requires (is_expr_node_v<L> && std::is_arithmetic_v<T>)
131[[nodiscard]] constexpr auto pow(L base, T exponent) {
132 using VT = typename L::value_type;
133 return pow(base, Const<VT>{static_cast<VT>(exponent)});
134}
135
136// pow(scalar, expr)
137template<typename T, typename R>
138 requires (std::is_arithmetic_v<T> && is_expr_node_v<R>)
139[[nodiscard]] constexpr auto pow(T base, R exponent) {
140 using VT = typename R::value_type;
141 return pow(Const<VT>{static_cast<VT>(base)}, exponent);
142}
143
144// max(expr, expr)
145template<typename L, typename R>
146 requires (is_expr_node_v<L> && is_expr_node_v<R>)
147[[nodiscard]] constexpr auto max(L a, R b) {
148 if constexpr (is_const_expr_v<L> && is_const_expr_v<R>) {
149 return Const<typename L::value_type>{std::max(a.value, b.value)};
150 } else if constexpr (std::is_same_v<L, R>) {
151 return a;
152 } else {
153 return BinaryFunc<MaxTag, L, R>{a, b};
154 }
155}
156
157// max(expr, scalar)
158template<typename L, typename T>
159 requires (is_expr_node_v<L> && std::is_arithmetic_v<T>)
160[[nodiscard]] constexpr auto max(L a, T b) {
161 using VT = typename L::value_type;
162 return max(a, Const<VT>{static_cast<VT>(b)});
163}
164
165// max(scalar, expr)
166template<typename T, typename R>
167 requires (std::is_arithmetic_v<T> && is_expr_node_v<R>)
168[[nodiscard]] constexpr auto max(T a, R b) {
169 using VT = typename R::value_type;
170 return max(Const<VT>{static_cast<VT>(a)}, b);
171}
172
173// min(expr, expr)
174template<typename L, typename R>
175 requires (is_expr_node_v<L> && is_expr_node_v<R>)
176[[nodiscard]] constexpr auto min(L a, R b) {
177 if constexpr (is_const_expr_v<L> && is_const_expr_v<R>) {
178 return Const<typename L::value_type>{std::min(a.value, b.value)};
179 } else if constexpr (std::is_same_v<L, R>) {
180 return a;
181 } else {
182 return BinaryFunc<MinTag, L, R>{a, b};
183 }
184}
185
186// min(expr, scalar)
187template<typename L, typename T>
188 requires (is_expr_node_v<L> && std::is_arithmetic_v<T>)
189[[nodiscard]] constexpr auto min(L a, T b) {
190 using VT = typename L::value_type;
191 return min(a, Const<VT>{static_cast<VT>(b)});
192}
193
194// min(scalar, expr)
195template<typename T, typename R>
196 requires (std::is_arithmetic_v<T> && is_expr_node_v<R>)
197[[nodiscard]] constexpr auto min(T a, R b) {
198 using VT = typename R::value_type;
199 return min(Const<VT>{static_cast<VT>(a)}, b);
200}
201
202// Type aliases
// NOTE(review): this rendering dropped the alias declaration that should
// follow each `template<typename L, typename R>` head below (the source
// line numbering skips 204, 207 and 210), leaving three dangling template
// heads that will not compile as shown. They presumably alias
// BinaryFunc<PowTag, L, R>, BinaryFunc<MaxTag, L, R> and
// BinaryFunc<MinTag, L, R>, in that order, mirroring the pow/max/min
// factories above — TODO restore them from the original binary_func.hpp;
// the alias names themselves cannot be recovered from here.
203template<typename L, typename R>
205
206template<typename L, typename R>
208
209template<typename L, typename R>
211
212} // namespace limes::expr
Expression layer for composable calculus.
Definition analysis.hpp:7
constexpr bool is_min_v
constexpr bool is_max_v
constexpr auto max(L a, R b)
constexpr bool is_runtime_pow_v
constexpr bool is_binary_func_v
constexpr auto min(L a, R b)
constexpr auto pow(L base, R exponent)
constexpr value_type evaluate(std::span< value_type const > args) const
constexpr auto derivative() const
constexpr BinaryFunc(L l, R r) noexcept
typename L::value_type value_type
constexpr value_type eval(std::span< value_type const > args) const
std::string to_string() const
static constexpr std::size_t arity_v