limes 3.1.0
Composable Calculus Expressions for C++20
Loading...
Searching...
No Matches
conditional.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <span>
4#include <string>
5#include <cstddef>
6#include <algorithm>
7#include "binary.hpp"
8
9namespace limes::expr {
10
11// Conditional<Cond, Then, Else>: Piecewise/conditional expressions.
12// Evaluates to then_branch if condition > 0, else_branch otherwise.
13template<typename Cond, typename Then, typename Else>
15 using value_type = typename Then::value_type;
16 using condition_type = Cond;
17 using then_type = Then;
18 using else_type = Else;
19
20 static constexpr std::size_t arity_v =
21 std::max({Cond::arity_v, Then::arity_v, Else::arity_v});
22
26
27 constexpr Conditional(Cond c, Then t, Else e) noexcept
28 : condition{c}, then_branch{t}, else_branch{e} {}
29
30 [[nodiscard]] constexpr value_type eval(std::span<value_type const> args) const {
31 if (condition.eval(args) > value_type(0)) {
32 return then_branch.eval(args);
33 } else {
34 return else_branch.eval(args);
35 }
36 }
37
38 [[nodiscard]] [[deprecated("use eval() instead")]]
39 constexpr value_type evaluate(std::span<value_type const> args) const {
40 return eval(args);
41 }
42
43 // Subgradient: d/dx[if c>0 then f else g] = if c>0 then df else dg
44 template<std::size_t Dim>
45 [[nodiscard]] constexpr auto derivative() const {
46 auto df = then_branch.template derivative<Dim>();
47 auto dg = else_branch.template derivative<Dim>();
48 return Conditional<Cond, decltype(df), decltype(dg)>{condition, df, dg};
49 }
50
51 [[nodiscard]] std::string to_string() const {
52 return "(if " + condition.to_string() + " "
53 + then_branch.to_string() + " "
54 + else_branch.to_string() + ")";
55 }
56};
57
58template<typename T>
59struct is_conditional : std::false_type {};
60
61template<typename C, typename T, typename E>
62struct is_conditional<Conditional<C, T, E>> : std::true_type {};
63
64template<typename T>
66
67// Factory functions
68
69template<typename C, typename T, typename E>
70 requires (is_expr_node_v<C> && is_expr_node_v<T> && is_expr_node_v<E>)
71[[nodiscard]] constexpr auto if_then_else(C cond, T then_expr, E else_expr) {
72 return Conditional<C, T, E>{cond, then_expr, else_expr};
73}
74
75template<typename C, typename T>
76 requires (is_expr_node_v<C> && std::is_arithmetic_v<T>)
77[[nodiscard]] constexpr auto if_then_else(C cond, T then_val, T else_val) {
78 using VT = typename C::value_type;
80 cond,
81 Const<VT>{static_cast<VT>(then_val)},
82 Const<VT>{static_cast<VT>(else_val)}
83 };
84}
85
86// Common piecewise functions
87
88// heaviside(e): H(x) = 1 if x > 0, 0 otherwise
89template<typename E>
90 requires is_expr_node_v<E>
91[[nodiscard]] constexpr auto heaviside(E e) {
92 using T = typename E::value_type;
93 return if_then_else(e, One<T>{}, Zero<T>{});
94}
95
96// ramp(e): max(e, 0) -- the positive part / ReLU function
97template<typename E>
98 requires is_expr_node_v<E>
99[[nodiscard]] constexpr auto ramp(E e) {
100 using T = typename E::value_type;
101 return if_then_else(e, e, Zero<T>{});
102}
103
104// sign(e): 1 if x > 0, -1 if x < 0, 0 if x = 0
105template<typename E>
106 requires is_expr_node_v<E>
107[[nodiscard]] constexpr auto sign(E e) {
108 using T = typename E::value_type;
109 auto inner = if_then_else(-e, Const<T>{T(-1)}, Zero<T>{});
110 return if_then_else(e, One<T>{}, inner);
111}
112
113// clamp(e, lo, hi): Clamp e to [lo, hi]
114template<typename E, typename T>
115 requires (is_expr_node_v<E> && std::is_arithmetic_v<T>)
116[[nodiscard]] constexpr auto clamp(E e, T lo, T hi) {
117 using VT = typename E::value_type;
118 auto lo_const = Const<VT>{static_cast<VT>(lo)};
119 auto hi_const = Const<VT>{static_cast<VT>(hi)};
120 auto upper_clamped = if_then_else(hi_const - e, e, hi_const);
121 return if_then_else(e - lo_const, upper_clamped, lo_const);
122}
123
124// indicator(e): Alias for heaviside, emphasizing indicator function semantics
125template<typename E>
126 requires is_expr_node_v<E>
127[[nodiscard]] constexpr auto indicator(E e) {
128 return heaviside(e);
129}
130
131} // namespace limes::expr
Expression layer for composable calculus.
Definition analysis.hpp:7
constexpr auto heaviside(E e)
constexpr auto indicator(E e)
constexpr auto ramp(E e)
constexpr bool is_conditional_v
constexpr auto sign(E e)
constexpr auto clamp(E e, T lo, T hi)
constexpr auto if_then_else(C cond, T then_expr, E else_expr)
constexpr Conditional(Cond c, Then t, Else e) noexcept
typename Then::value_type value_type
constexpr value_type evaluate(std::span< value_type const > args) const
static constexpr std::size_t arity_v
constexpr value_type eval(std::span< value_type const > args) const
constexpr auto derivative() const
std::string to_string() const