1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use super::{PrimExpr, PrimExprNode};
use crate::ir::span::Span;
use crate::runtime::{IsObjectRef, String as TVMString};
use crate::DataType;
use tvm_macros::Object;
/// Declares the payload struct and constructor for one TIR expression node.
///
/// `define_node!(Ref, "RefName", "type.key"; RefNode { field: Type, ... })` expands to:
///
/// * a `#[repr(C)]` struct `RefNode` whose first field is a private `PrimExprNode` header,
///   followed by the listed `pub` fields in declaration order. It derives `Object` with
///   `ref_name = "RefName"` and `type_key = "type.key"`;
/// * an inherent `Ref::new(datatype, field, ...)` constructor.
///
/// The reference type `Ref` is not declared by this macro. Presumably the `Object` derive
/// generates it from `ref_name` (TODO confirm against `tvm_macros`).
///
/// A trailing comma after the last field is accepted, so long field lists can be written
/// one per line.
macro_rules! define_node {
    ($name:ident, $ref:expr, $typekey:expr; $node:ident { $($id:ident : $t:ty),* $(,)? }) => {
        /// Payload of a TIR expression node, generated by `define_node!`.
        ///
        /// `#[repr(C)]` pins the field order: the `PrimExprNode` header comes first, then
        /// the node's own fields. NOTE(review): order and types presumably must match the
        /// C++ node registered under the same type key; confirm before editing fields.
        #[repr(C)]
        #[derive(Object, Debug)]
        #[ref_name = $ref]
        #[type_key = $typekey]
        pub struct $node {
            base: PrimExprNode,
            $(pub $id : $t),*
        }

        impl $name {
            /// Builds a new node with the given data type and field values.
            ///
            /// The node's span is set to `Span::null()`. The payload is converted into the
            /// reference type through its `Into` implementation.
            pub fn new(datatype: DataType, $($id : $t,)*) -> $name {
                let base = PrimExprNode::base::<$node>(datatype, Span::null());
                let node = $node { base, $($id),* };
                node.into()
            }
        }
    };
}
// Integer immediate holding an `i64` value. Note that its type key is the bare
// "IntImm", without the "tir." prefix that the other nodes in this file use.
define_node!(IntImm, "IntImm", "IntImm"; IntImmNode { value: i64 });
/// Wraps an `i32` as an `IntImm` whose data type is `DataType::int(32, 1)`
/// (presumably 32-bit, single lane).
impl From<i32> for IntImm {
    fn from(i: i32) -> IntImm {
        // `i64::from` is a lossless widening that the compiler checks, unlike an `as` cast
        // (which would silently truncate if the source type ever changed).
        IntImm::new(DataType::int(32, 1), i64::from(i))
    }
}
impl From<i32> for PrimExpr {
fn from(i: i32) -> PrimExpr {
IntImm::from(i).upcast()
}
}
// ---- Variables -------------------------------------------------------------------------
// NOTE(review): in recent TVM versions the C++ `VarNode` also has a `type_annotation`
// field after `name_hint`. Confirm whether this struct needs to declare it.
define_node!(Var, "Var", "tir.Var";
             VarNode { name_hint: TVMString });

// ---- Arithmetic: every node carries two `PrimExpr` operands, `a` and `b` ----------------
define_node!(Add, "Add", "tir.Add";
             AddNode { a: PrimExpr, b: PrimExpr });
define_node!(Sub, "Sub", "tir.Sub";
             SubNode { a: PrimExpr, b: PrimExpr });
define_node!(Mul, "Mul", "tir.Mul";
             MulNode { a: PrimExpr, b: PrimExpr });
define_node!(Div, "Div", "tir.Div";
             DivNode { a: PrimExpr, b: PrimExpr });
define_node!(Mod, "Mod", "tir.Mod";
             ModNode { a: PrimExpr, b: PrimExpr });
define_node!(FloorDiv, "FloorDiv", "tir.FloorDiv";
             FloorDivNode { a: PrimExpr, b: PrimExpr });
define_node!(FloorMod, "FloorMod", "tir.FloorMod";
             FloorModNode { a: PrimExpr, b: PrimExpr });
define_node!(Min, "Min", "tir.Min";
             MinNode { a: PrimExpr, b: PrimExpr });
define_node!(Max, "Max", "tir.Max";
             MaxNode { a: PrimExpr, b: PrimExpr });

// ---- Conversion, vector ramp, and conditional selection ---------------------------------
// NOTE(review): the C++ `RampNode` names its first field `base`. Only order and type affect
// the `#[repr(C)]` layout, but confirm `start` is intentional.
define_node!(Cast, "Cast", "tir.Cast";
             CastNode { value: PrimExpr });
define_node!(Ramp, "Ramp", "tir.Ramp";
             RampNode { start: PrimExpr, stride: PrimExpr, lanes: i32 });
define_node!(Select, "Select", "tir.Select";
             SelectNode { condition: PrimExpr, true_value: PrimExpr, false_value: PrimExpr });

// ---- Comparisons: the type keys use upper-case suffixes (tir.EQ, tir.NE, ...) ---------
define_node!(Eq, "Eq", "tir.EQ";
             EqNode { a: PrimExpr, b: PrimExpr });
define_node!(Ne, "Ne", "tir.NE";
             NeNode { a: PrimExpr, b: PrimExpr });
define_node!(Lt, "Lt", "tir.LT";
             LtNode { a: PrimExpr, b: PrimExpr });
define_node!(Le, "Le", "tir.LE";
             LeNode { a: PrimExpr, b: PrimExpr });
define_node!(Gt, "Gt", "tir.GT";
             GtNode { a: PrimExpr, b: PrimExpr });
define_node!(Ge, "Ge", "tir.GE";
             GeNode { a: PrimExpr, b: PrimExpr });

// ---- Logical connectives ----------------------------------------------------------------
// NOTE(review): the C++ `NotNode` names its operand `a`. Here it is exposed as `value`.
define_node!(And, "And", "tir.And";
             AndNode { a: PrimExpr, b: PrimExpr });
define_node!(Or, "Or", "tir.Or";
             OrNode { a: PrimExpr, b: PrimExpr });
define_node!(Not, "Not", "tir.Not";
             NotNode { value: PrimExpr });

// ---- Let binding: binds `var` to `value` within `body` ----------------------------------
define_node!(Let, "Let", "tir.Let";
             LetNode { var: Var, value: PrimExpr, body: PrimExpr });