Halide 14.0.0
Halide compiler and libraries
Expr.h
Go to the documentation of this file.
1#ifndef HALIDE_EXPR_H
2#define HALIDE_EXPR_H
3
4/** \file
5 * Base classes for Halide expressions (\ref Halide::Expr) and statements (\ref Halide::Internal::Stmt)
6 */
7
#include <functional>
#include <string>
#include <vector>
10
11#include "IntrusivePtr.h"
12#include "Type.h"
13
14namespace Halide {
15
16struct bfloat16_t;
17struct float16_t;
18
19namespace Internal {
20
21class IRMutator;
22class IRVisitor;
23
/** All our IR node types get unique IDs for the purposes of RTTI */
enum class IRNodeType {
    // Exprs, in order of strength. Code in IRMatch.h and the
    // simplifier relies on this order for canonicalization of
    // expressions, so you may need to update those modules if you
    // change this list.
    IntImm,
    UIntImm,
    FloatImm,
    StringImm,
    Broadcast,
    Cast,
    Variable,
    Add,
    Sub,
    Mod,
    Mul,
    Div,
    Min,
    Max,
    EQ,
    NE,
    LT,
    LE,
    GT,
    GE,
    And,
    Or,
    Not,
    Select,
    Load,
    Ramp,
    Call,
    Let,
    Shuffle,
    VectorReduce,
    // Stmts
    LetStmt,
    AssertStmt,
    ProducerConsumer,
    For,
    Acquire,
    Store,
    Provide,
    Allocate,
    Free,
    Realize,
    Block,
    Fork,
    IfThenElse,
    Evaluate,
    Prefetch,
    Atomic
};

/** The last (strongest) Expr node type in the ordering above; every
 * enumerator after it names a Stmt node. */
constexpr IRNodeType StrongestExprNodeType = IRNodeType::VectorReduce;

81/** The abstract base classes for a node in the Halide IR. */
82struct IRNode {
83
84 /** We use the visitor pattern to traverse IR nodes throughout the
85 * compiler, so we have a virtual accept method which accepts
86 * visitors.
87 */
88 virtual void accept(IRVisitor *v) const = 0;
90 : node_type(t) {
91 }
92 virtual ~IRNode() = default;
93
94 /** These classes are all managed with intrusive reference
95 * counting, so we also track a reference count. It's mutable
96 * so that we can do reference counting even through const
97 * references to IR nodes.
98 */
100
101 /** Each IR node subclass has a unique identifier. We can compare
102 * these values to do runtime type identification. We don't
103 * compile with rtti because that injects run-time type
104 * identification stuff everywhere (and often breaks when linking
105 * external libraries compiled without it), and we only want it
106 * for IR nodes. One might want to put this value in the vtable,
107 * but that adds another level of indirection, and for Exprs we
108 * have 32 free bits in between the ref count and the Type
109 * anyway, so this doesn't increase the memory footprint of an IR node.
110 */
112};
113
114template<>
115inline RefCount &ref_count<IRNode>(const IRNode *t) noexcept {
116 return t->ref_count;
117}
118
119template<>
120inline void destroy<IRNode>(const IRNode *t) {
121 delete t;
122}
123
124/** IR nodes are split into expressions and statements. These are
125 similar to expressions and statements in C - expressions
126 represent some value and have some type (e.g. x + 3), and
127 statements are side-effecting pieces of code that do not
128 represent a value (e.g. assert(x > 3)) */
129
130/** A base class for statement nodes. They have no properties or
131 methods beyond base IR nodes for now. */
132struct BaseStmtNode : public IRNode {
134 : IRNode(t) {
135 }
136 virtual Stmt mutate_stmt(IRMutator *v) const = 0;
137};
138
139/** A base class for expression nodes. They all contain their types
140 * (e.g. Int(32), Float(32)) */
141struct BaseExprNode : public IRNode {
143 : IRNode(t) {
144 }
145 virtual Expr mutate_expr(IRMutator *v) const = 0;
147};
148
149/** We use the "curiously recurring template pattern" to avoid
150 duplicated code in the IR Nodes. These classes live between the
151 abstract base classes and the actual IR Nodes in the
152 inheritance hierarchy. It provides an implementation of the
153 accept function necessary for the visitor pattern to work, and
154 a concrete instantiation of a unique IRNodeType per class. */
155template<typename T>
156struct ExprNode : public BaseExprNode {
157 void accept(IRVisitor *v) const override;
158 Expr mutate_expr(IRMutator *v) const override;
160 : BaseExprNode(T::_node_type) {
161 }
162 ~ExprNode() override = default;
163};
164
165template<typename T>
166struct StmtNode : public BaseStmtNode {
167 void accept(IRVisitor *v) const override;
168 Stmt mutate_stmt(IRMutator *v) const override;
170 : BaseStmtNode(T::_node_type) {
171 }
172 ~StmtNode() override = default;
173};
174
175/** IR nodes are passed around opaque handles to them. This is a
176 base class for those handles. It manages the reference count,
177 and dispatches visitors. */
178struct IRHandle : public IntrusivePtr<const IRNode> {
180 IRHandle() = default;
181
183 IRHandle(const IRNode *p)
184 : IntrusivePtr<const IRNode>(p) {
185 }
186
187 /** Dispatch to the correct visitor method for this node. E.g. if
188 * this node is actually an Add node, then this will call
189 * IRVisitor::visit(const Add *) */
190 void accept(IRVisitor *v) const {
191 ptr->accept(v);
192 }
193
194 /** Downcast this ir node to its actual type (e.g. Add, or
195 * Select). This returns nullptr if the node is not of the requested
196 * type. Example usage:
197 *
198 * if (const Add *add = node->as<Add>()) {
199 * // This is an add node
200 * }
201 */
202 template<typename T>
203 const T *as() const {
204 if (ptr && ptr->node_type == T::_node_type) {
205 return (const T *)ptr;
206 }
207 return nullptr;
208 }
209
211 return ptr->node_type;
212 }
213};
214
215/** Integer constants */
216struct IntImm : public ExprNode<IntImm> {
218
219 static const IntImm *make(Type t, int64_t value);
220
222};
223
224/** Unsigned integer constants */
225struct UIntImm : public ExprNode<UIntImm> {
227
228 static const UIntImm *make(Type t, uint64_t value);
229
231};
232
233/** Floating point constants */
234struct FloatImm : public ExprNode<FloatImm> {
235 double value;
236
237 static const FloatImm *make(Type t, double value);
238
240};
241
242/** String constants */
243struct StringImm : public ExprNode<StringImm> {
244 std::string value;
245
246 static const StringImm *make(const std::string &val);
247
249};
250
251} // namespace Internal
252
253/** A fragment of Halide syntax. It's implemented as reference-counted
254 * handle to a concrete expression node, but it's immutable, so you
255 * can treat it as a value type. */
256struct Expr : public Internal::IRHandle {
257 /** Make an undefined expression */
259 Expr() = default;
260
261 /** Make an expression from a concrete expression node pointer (e.g. Add) */
264 : IRHandle(n) {
265 }
266
267 /** Make an expression representing numeric constants of various types. */
268 // @{
269 explicit Expr(int8_t x)
270 : IRHandle(Internal::IntImm::make(Int(8), x)) {
271 }
272 explicit Expr(int16_t x)
273 : IRHandle(Internal::IntImm::make(Int(16), x)) {
274 }
276 : IRHandle(Internal::IntImm::make(Int(32), x)) {
277 }
278 explicit Expr(int64_t x)
279 : IRHandle(Internal::IntImm::make(Int(64), x)) {
280 }
281 explicit Expr(uint8_t x)
282 : IRHandle(Internal::UIntImm::make(UInt(8), x)) {
283 }
284 explicit Expr(uint16_t x)
285 : IRHandle(Internal::UIntImm::make(UInt(16), x)) {
286 }
287 explicit Expr(uint32_t x)
288 : IRHandle(Internal::UIntImm::make(UInt(32), x)) {
289 }
290 explicit Expr(uint64_t x)
291 : IRHandle(Internal::UIntImm::make(UInt(64), x)) {
292 }
294 : IRHandle(Internal::FloatImm::make(Float(16), (double)x)) {
295 }
297 : IRHandle(Internal::FloatImm::make(BFloat(16), (double)x)) {
298 }
299 Expr(float x)
300 : IRHandle(Internal::FloatImm::make(Float(32), x)) {
301 }
302 explicit Expr(double x)
303 : IRHandle(Internal::FloatImm::make(Float(64), x)) {
304 }
305 // @}
306
307 /** Make an expression representing a const string (i.e. a StringImm) */
308 Expr(const std::string &s)
309 : IRHandle(Internal::StringImm::make(s)) {
310 }
311
312 /** Override get() to return a BaseExprNode * instead of an IRNode * */
315 return (const Internal::BaseExprNode *)ptr;
316 }
317
318 /** Get the type of this expression node */
320 Type type() const {
321 return get()->type;
322 }
323};
324
325/** This lets you use an Expr as a key in a map of the form
326 * map<Expr, Foo, ExprCompare> */
328 bool operator()(const Expr &a, const Expr &b) const {
329 return a.get() < b.get();
330 }
331};
332
333/** A single-dimensional span. Includes all numbers between min and
334 * (min + extent - 1). */
335struct Range {
337
338 Range() = default;
339 Range(const Expr &min_in, const Expr &extent_in);
340};
341
342/** A multi-dimensional box. The outer product of the elements */
343typedef std::vector<Range> Region;
344
/** An enum describing different address spaces to be used with Func::store_in. */
enum class MemoryType {
    /** Let Halide select a storage type automatically */
    Auto,

    /** Heap/global memory. Allocated using halide_malloc, or
     * halide_device_malloc */
    Heap,

    /** Stack memory. Allocated using alloca. Requires a constant
     * size. Corresponds to per-thread local memory on the GPU. If all
     * accesses are at constant coordinates, may be promoted into the
     * register file at the discretion of the register allocator. */
    Stack,

    /** Register memory. The allocation should be promoted into the
     * register file. All stores must be at constant coordinates. May
     * be spilled to the stack at the discretion of the register
     * allocator. */
    Register,

    /** Allocation is stored in GPU shared memory. Also known as
     * "local" in OpenCL, and "threadgroup" in metal. Can be shared
     * across GPU threads within the same block. */
    GPUShared,

    /** Allocation is stored in GPU texture memory and accessed through
     * hardware sampler */
    GPUTexture,

    /** Allocate Locked Cache Memory to act as local memory */
    LockedCache,

    /** Vector Tightly Coupled Memory. HVX (Hexagon) local memory available on
     * v65+. This memory has higher performance and lower power. Ideal for
     * intermediate buffers. Necessary for vgather-vscatter instructions
     * on Hexagon */
    VTCM,

    /** AMX Tile register for X86. Any data that would be used in an AMX matrix
     * multiplication must first be loaded into an AMX tile register. */
    AMXTile,
};
387
388namespace Internal {
389
/** An enum describing a type of loop traversal. Used in schedules,
 * and in the For loop IR node. Serial is a conventional ordered for
 * loop. Iterations occur in increasing order, and each iteration must
 * appear to have finished before the next begins. Parallel, GPUBlock,
 * and GPUThread are parallel and unordered: iterations may occur in
 * any order, and multiple iterations may occur
 * simultaneously. Vectorized and GPULane are parallel and
 * synchronous: they act as if all iterations occur at the same time
 * in lockstep. */
enum class ForType {
    Serial,
    Parallel,
    Vectorized,
    Unrolled,
    Extern,
    GPUBlock,
    GPUThread,
    GPULane,
};
409
410/** Check if for_type executes for loop iterations in parallel and unordered. */
412
413/** Returns true if for_type executes for loop iterations in parallel. */
414bool is_parallel(ForType for_type);
415
416/** A reference-counted handle to a statement node. */
417struct Stmt : public IRHandle {
418 Stmt() = default;
420 : IRHandle(n) {
421 }
422
423 /** Override get() to return a BaseStmtNode * instead of an IRNode * */
425 const BaseStmtNode *get() const {
426 return (const Internal::BaseStmtNode *)ptr;
427 }
428
429 /** This lets you use a Stmt as a key in a map of the form
430 * map<Stmt, Foo, Stmt::Compare> */
431 struct Compare {
432 bool operator()(const Stmt &a, const Stmt &b) const {
433 return a.ptr < b.ptr;
434 }
435 };
436};
437
438} // namespace Internal
439} // namespace Halide
440
441#endif
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:38
Support classes for reference-counting via intrusive shared pointers.
Defines halide types.
A base class for passes over the IR which modify it (e.g.
Definition: IRMutator.h:26
A base class for algorithms that need to recursively walk over the IR.
Definition: IRVisitor.h:19
A class representing a reference count to be used with IntrusivePtr.
Definition: IntrusivePtr.h:19
constexpr IRNodeType StrongestExprNodeType
Definition: Expr.h:79
ForType
An enum describing a type of loop traversal.
Definition: Expr.h:399
RefCount & ref_count< IRNode >(const IRNode *t) noexcept
Definition: Expr.h:115
bool is_unordered_parallel(ForType for_type)
Check if for_type executes for loop iterations in parallel and unordered.
bool is_parallel(ForType for_type)
Returns true if for_type executes for loop iterations in parallel.
void destroy< IRNode >(const IRNode *t)
Definition: Expr.h:120
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition: Expr.h:25
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Type BFloat(int bits, int lanes=1)
Construct a floating-point type in the bfloat format.
Definition: Type.h:531
Type UInt(int bits, int lanes=1)
Constructing an unsigned integer type.
Definition: Type.h:521
Type Float(int bits, int lanes=1)
Construct a floating-point type.
Definition: Type.h:526
@ Internal
Not visible externally, similar to 'static' linkage in C.
Type Int(int bits, int lanes=1)
Constructing a signed integer type.
Definition: Type.h:516
std::vector< Range > Region
A multi-dimensional box.
Definition: Expr.h:343
MemoryType
An enum describing different address spaces to be used with Func::store_in.
Definition: Expr.h:346
@ Auto
Let Halide select a storage type automatically.
@ Register
Register memory.
@ Stack
Stack memory.
@ VTCM
Vector Tightly Coupled Memory.
@ AMXTile
AMX Tile register for X86.
@ LockedCache
Allocate Locked Cache Memory to act as local memory.
@ Heap
Heap/global memory.
@ GPUTexture
Allocation is stored in GPU texture memory and accessed through hardware sampler.
@ GPUShared
Allocation is stored in GPU shared memory.
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT8_TYPE__ uint8_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
signed __INT16_TYPE__ int16_t
signed __INT8_TYPE__ int8_t
This lets you use an Expr as a key in a map of the form map<Expr, Foo, ExprCompare>
Definition: Expr.h:327
bool operator()(const Expr &a, const Expr &b) const
Definition: Expr.h:328
A fragment of Halide syntax.
Definition: Expr.h:256
Expr(float x)
Definition: Expr.h:299
HALIDE_ALWAYS_INLINE Expr()=default
Make an undefined expression.
Expr(int32_t x)
Definition: Expr.h:275
Expr(bfloat16_t x)
Definition: Expr.h:296
Expr(uint32_t x)
Definition: Expr.h:287
Expr(const std::string &s)
Make an expression representing a const string (i.e.
Definition: Expr.h:308
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition: Expr.h:320
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition: Expr.h:314
Expr(int64_t x)
Definition: Expr.h:278
Expr(int16_t x)
Definition: Expr.h:272
Expr(uint64_t x)
Definition: Expr.h:290
Expr(uint16_t x)
Definition: Expr.h:284
Expr(double x)
Definition: Expr.h:302
Expr(int8_t x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:269
HALIDE_ALWAYS_INLINE Expr(const Internal::BaseExprNode *n)
Make an expression from a concrete expression node pointer (e.g.
Definition: Expr.h:263
Expr(uint8_t x)
Definition: Expr.h:281
Expr(float16_t x)
Definition: Expr.h:293
The sum of two expressions.
Definition: IR.h:38
Allocate a scratch area called with the given name, type, and size.
Definition: IR.h:353
Logical and - are both expressions true.
Definition: IR.h:157
If the 'condition' is false, then evaluate and return the message, which should be a call to an error...
Definition: IR.h:276
Lock all the Store nodes in the body statement.
Definition: IR.h:870
A base class for expression nodes.
Definition: Expr.h:141
virtual Expr mutate_expr(IRMutator *v) const =0
BaseExprNode(IRNodeType t)
Definition: Expr.h:142
A base class for statement nodes.
Definition: Expr.h:132
BaseStmtNode(IRNodeType t)
Definition: Expr.h:133
virtual Stmt mutate_stmt(IRMutator *v) const =0
A sequence of statements to be executed in-order.
Definition: IR.h:418
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:241
A function call.
Definition: IR.h:466
The actual IR nodes begin here.
Definition: IR.h:29
The ratio of two expressions.
Definition: IR.h:65
Is the first expression equal to the second.
Definition: IR.h:103
Evaluate and discard an expression, presumably because it has some side-effect.
Definition: IR.h:452
We use the "curiously recurring template pattern" to avoid duplicated code in the IR Nodes.
Definition: Expr.h:156
~ExprNode() override=default
Expr mutate_expr(IRMutator *v) const override
void accept(IRVisitor *v) const override
We use the visitor pattern to traverse IR nodes throughout the compiler, so we have a virtual accept ...
Floating point constants.
Definition: Expr.h:234
static const IRNodeType _node_type
Definition: Expr.h:239
static const FloatImm * make(Type t, double value)
A for loop.
Definition: IR.h:747
A pair of statements executed concurrently.
Definition: IR.h:433
Is the first expression greater than or equal to the second.
Definition: IR.h:148
Is the first expression greater than the second.
Definition: IR.h:139
IR nodes are passed around by opaque handles to them.
Definition: Expr.h:178
void accept(IRVisitor *v) const
Dispatch to the correct visitor method for this node.
Definition: Expr.h:190
HALIDE_ALWAYS_INLINE IRHandle()=default
const T * as() const
Downcast this ir node to its actual type (e.g.
Definition: Expr.h:203
IRNodeType node_type() const
Definition: Expr.h:210
HALIDE_ALWAYS_INLINE IRHandle(const IRNode *p)
Definition: Expr.h:183
The abstract base class for a node in the Halide IR.
Definition: Expr.h:82
virtual ~IRNode()=default
virtual void accept(IRVisitor *v) const =0
We use the visitor pattern to traverse IR nodes throughout the compiler, so we have a virtual accept ...
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:111
RefCount ref_count
These classes are all managed with intrusive reference counting, so we also track a reference count.
Definition: Expr.h:99
IRNode(IRNodeType t)
Definition: Expr.h:89
An if-then-else block.
Definition: IR.h:442
Integer constants.
Definition: Expr.h:216
static const IRNodeType _node_type
Definition: Expr.h:221
static const IntImm * make(Type t, int64_t value)
Intrusive shared pointers have a reference count (a RefCount object) stored in the class itself.
Definition: IntrusivePtr.h:68
Is the first expression less than or equal to the second.
Definition: IR.h:130
Is the first expression less than the second.
Definition: IR.h:121
A let expression, like you might find in a functional language.
Definition: IR.h:253
The statement form of a let node.
Definition: IR.h:264
Load a value from a named symbol if predicate is true.
Definition: IR.h:199
The greater of two values.
Definition: IR.h:94
The lesser of two values.
Definition: IR.h:85
The remainder of a / b.
Definition: IR.h:76
The product of two expressions.
Definition: IR.h:56
Is the first expression not equal to the second.
Definition: IR.h:112
Logical not - true if the expression is false.
Definition: IR.h:175
Logical or - is at least one of the expressions true.
Definition: IR.h:166
Represent a multi-dimensional region of a Func or an ImageParam that needs to be prefetched.
Definition: IR.h:847
This node is a helpful annotation to do with permissions.
Definition: IR.h:297
This defines the value of a function at a multi-dimensional location.
Definition: IR.h:336
A linear ramp vector node.
Definition: IR.h:229
Allocate a multi-dimensional buffer of the given type and size.
Definition: IR.h:403
A ternary operator.
Definition: IR.h:186
Construct a new vector by taking elements from another sequence of vectors.
Definition: IR.h:778
This lets you use a Stmt as a key in a map of the form map<Stmt, Foo, Stmt::Compare>
Definition: Expr.h:431
bool operator()(const Stmt &a, const Stmt &b) const
Definition: Expr.h:432
A reference-counted handle to a statement node.
Definition: Expr.h:417
Stmt(const BaseStmtNode *n)
Definition: Expr.h:419
HALIDE_ALWAYS_INLINE const BaseStmtNode * get() const
Override get() to return a BaseStmtNode * instead of an IRNode *.
Definition: Expr.h:425
void accept(IRVisitor *v) const override
We use the visitor pattern to traverse IR nodes throughout the compiler, so we have a virtual accept ...
Stmt mutate_stmt(IRMutator *v) const override
~StmtNode() override=default
Store a 'value' to the buffer called 'name' at a given 'index' if 'predicate' is true.
Definition: IR.h:315
String constants.
Definition: Expr.h:243
static const StringImm * make(const std::string &val)
static const IRNodeType _node_type
Definition: Expr.h:248
The difference of two expressions.
Definition: IR.h:47
Unsigned integer constants.
Definition: Expr.h:225
static const IRNodeType _node_type
Definition: Expr.h:230
static const UIntImm * make(Type t, uint64_t value)
A named variable.
Definition: IR.h:700
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition: IR.h:888
A single-dimensional span.
Definition: Expr.h:335
Range()=default
Expr min
Definition: Expr.h:336
Expr extent
Definition: Expr.h:336
Range(const Expr &min_in, const Expr &extent_in)
Types in the halide type system.
Definition: Type.h:266
Class that provides a type that implements half precision floating point using the bfloat16 format.
Definition: Float16.h:142
Class that provides a type that implements half precision floating point (IEEE754 2008 binary16) in s...
Definition: Float16.h:17