1 /**
    Lightweight Modular Staging (LMS) library for the D language.
3 
    The popularity of dependency injection frameworks has somewhat overshadowed
    a bigger and more general technique - staging and staged computation.
6 
    This library rectifies that omission for the D language.
8     
9     
10     Built with love based on the ideas in the paper:
11     
12     Lightweight Modular Staging: A Pragmatic Approach to
13     Runtime Code Generation and Compiled DSLs 
14     by Tiark Rompf and Martin Odersky.
15 
    The above paper and many other good ones by the Scala team at EPFL
    can be found at https://scala-lang.org/old/node/143.
18 */
19 
20 module lms;
21 
/// Type-erased base of every lifted value; lets a stage keep values of different types in one container.
class Box {
    /**
        Replace the contents of this box with `another`.

        The base implementation cannot be rebound and always throws `LmsException`;
        rebindable subclasses override it.
    */
    void replace(Box another) {
        enum message = "Internal error - cannot replace contents of this lifted value";
        throw new LmsException(message);
    }
}
28 
29 /// Lift a simple constant
30 Lift!T lift(T)(T value) 
31 if (!is(T : Lift!U, U)){
32     return new Constant!T(value);
33 }
34 
35 /// ditto
36 Lift!T lift(T)(Lift!T lifted) {
37     return lifted;
38 }
39 
40 /**
41     Stage is equivalent to DI container (or rather DI is simple late-binding + basic form of staging)
42     but the composition and execution of them is independent
43     and encapsulated by Lift!T interface
44     
45     A user is expected to sub-class and define custom stages as needed. See also `BasicStage`.
46 */
47 interface Stage {    
48     /// Lift a placeholder - slot for concrete value to be filled in at a later stage
49     Slot!T slot(T)(string name) {
50         auto lifted = new Slot!T(name);
51         register(name, lifted);
52         return lifted;
53     }
54 
55     /// Register existing slot at this _stage_ with name `name`
56     Slot!T slot(T)(string name, Slot!T value) {
57         register(name, value);
58         value.reset();
59         return value;
60     }
61 
62     /// Register existing slot at this _stage_ with original name
63     Slot!T slot(T)(Slot!T value) {
64         register(value.name, value);
65         value.reset();
66         return value;
67     }
68 
69     /// Try to evaluate (lower) lifted value using this stage
70     auto eval(T)(Lift!T value) {
71         return value.eval(this);
72     }
73 
74     /// Do partial evaluation for lifted value, this folds all known-constant sub-tries and optimizes expressions
75     auto partial(U)(U value) 
76     if (!is(U : Slot!T, T)) {
77         return value.partial(this);
78     }
79 
80     ///ditto
81     auto partial(U)(U value)
82     if (is(U : Slot!T, T)) {
83         return cast(Lift!T)this[value];
84     }
85 
86     void register(T)(Slot!T slot) {
87         register(slot.name, typeid(T));
88     }
89 
90     /// This is an implementation hook - register must save name,typeinfo pair to check type matching later
91     void register(string name, Box box);
92 
93     /// This is an implementation hook - bind must check that typeinfo matches and bind value to the lifted slot
94     Box opIndexAssign(Box value, string name);
95 
96     /// Third implementation hook - lookup bound value for a given name
97     Box opIndex(string name);
98 }
99 
100 /**
101     Simple stage that keeps slots as key-value pairs in built-in AA.
102 
103     Could be used as is or as an example to build your own stage(s).
104 */
105 class BasicStage : Stage { 
106     override void register(string name, Box box) {
107         if (name in slots) throw new LmsNameConflict("This stage already has slot for '"~name~"' variable");
108         slots[name] = box;
109     }
110 
111     override Box opIndexAssign(Box lifted, string name) {
112         auto p = name in slots;
113         if (!p) throw new LmsNameResolution("This stage doesn't have '"~name~"' variable");
114         slots[name].replace(lifted);
115         return lifted;
116     }
117 
118     override Box opIndex(string name) {
119         auto p = name in slots;
120         if (!p) throw new LmsNameResolution("This stage doesn't have '"~name~"' variable");
121         return slots[name];
122     }
123 
124     private Box[string] slots;
125 }
126 
/**
    Lifted value of type T: a computation that can be evaluated, fully or partially, against a _stage_.

    Instances are produced e.g. by `lift`, by `Stage.slot`, and by `map`/`flatMap`
    (which the binary operators below are built on).
*/
abstract class Lift(T) : Box {
    /// Full evaluation; may fail if some variables are not defined at this stage.
    abstract T eval(Stage stage);

    /// Partial evaluation given all of the variables we know at this stage.
    abstract Lift!T partial(Stage stage);

    /// Lift a plain function: the result evaluates to `mapFunc` applied to this value.
    final Lift!U map(U)(U delegate(T) mapFunc) {
        return new Mapped!(T, U)(this, mapFunc);
    }

    /// Lift a function that returns a lifted value; the result evaluates that inner lifted value.
    final Lift!U flatMap(U)(Lift!U delegate(T) mapFunc) {
        return new FlatMapped!(T, U)(this, mapFunc);
    }

    /// Binary operator with a plain (non-lifted) right operand, e.g. `lift(40) + 2`.
    auto opBinary(string op, U)(U rhsV)
    if (!is(U : Lift!V, V)) {
        return map((lhsV){
            return mixin("lhsV "~op~" rhsV");
        });
    }

    /// Binary operator with a lifted right operand, e.g. `lift(40) + lift(2)`.
    auto opBinary(string op, U)(U rhs)
    if (is(U : Lift!V, V)) {
        return flatMap((lhsV) {
            return rhs.map((rhsV) {
                return mixin("lhsV "~op~" rhsV");
            });
        });
    }

    /// Binary operator with a plain (non-lifted) left operand, e.g. `2 + lift(40)`.
    /// Lifted left operands are excluded by the constraint, so `lift(a) + lift(b)`
    /// still resolves only to `opBinary` above (no ambiguity).
    auto opBinaryRight(string op, U)(U lhsV)
    if (!is(U : Lift!V, V)) {
        return map((rhsV){
            return mixin("lhsV "~op~" rhsV");
        });
    }
}
160 
/// Simplest of all - just a constant; it stays the same regardless of the _stage_
class Constant(T) : Lift!T {
    private T value;

    /// Wrap `value` as a lifted constant.
    this(T value) {
        this.value = value;
    }

    /// The stage is irrelevant: the wrapped value is returned as-is.
    override T eval(Stage stage) {
        return this.value;
    }

    /// A constant is already fully folded, so partial evaluation yields the constant itself.
    override Lift!T partial(Stage stage) {
        return this;
    }
}
177 
/**
    Slot - a placeholder for a value that will be provided at a later _stage_.

    A slot is bound through `replace` (e.g. via a stage's `opIndexAssign`) and unbound
    through `reset`. Evaluating an unbound slot throws `LmsEvaluationFailed`.
*/
class Slot(T) : Lift!T {
    this(string name) {
        _name = name;
        reset();
    }

    /// Bind this slot to `another`, which must be a `Lift!T`; otherwise throws `LmsException`.
    override void replace(Box another) {
        auto typed = cast(Lift!T) another;
        // A failed downcast yields null; storing it would crash on the next eval instead of reporting the error.
        if (typed is null)
            throw new LmsException("slot "~_name~" cannot be bound: value is not a Lift!"~T.stringof);
        expr = typed;
    }

    /// Remove the current binding (if any); evaluation throws until the slot is bound again.
    void reset() {
        expr = null;
    }

    override T eval(Stage stage) {
        if (expr is null)
            throw new LmsEvaluationFailed("slot "~_name~" has no bound value at this stage");
        return expr.eval(stage);
    }

    override Lift!T partial(Stage stage) {
        // Unbound: keep the slot itself so that a value bound later is still picked up.
        if (expr is null)
            return this;
        // Bound: simplify the bound expression (a bound constant folds to itself).
        return expr.partial(stage);
    }

    /// The name this slot was created with.
    string name() { return _name; }

    private string _name;
    private Lift!T expr; // null while the slot is unbound
}
208 
/// Lifted application of a plain function `func` to a lifted argument (produced by `Lift.map`).
private class Mapped(T, U) : Lift!U {
    this(Lift!T arg, U delegate(T) func) {
        this.liftedArg = arg;
        this.func = func;
    }

    /// Evaluate the argument at `stage`, then apply `func` to it.
    override U eval(Stage stage) {
        return func(liftedArg.eval(stage));
    }

    /// Partially evaluate the argument. If it folds to a `Constant`, `func` is applied right now
    /// and the result becomes a constant; otherwise `func` is re-attached to the simplified argument.
    override Lift!U partial(Stage stage) {
        auto simplified = liftedArg.partial(stage);
        if (auto constant = cast(Constant!T) simplified)
            return lift(func(constant.eval(stage)));
        return simplified.map(func);
    }

    private Lift!T liftedArg;
    private U delegate(T) func;
}
231 
/// Lifted application of a function that itself returns a lifted value (produced by `Lift.flatMap`).
private class FlatMapped(T, U) : Lift!U {
    private Lift!T liftedArg;
    private Lift!U delegate(T) func;

    this(Lift!T arg, Lift!U delegate(T) func) {
        liftedArg = arg;
        this.func = func;
    }

    /// Evaluate the argument, feed it to `func`, then evaluate the returned lifted value at the same stage.
    override U eval(Stage stage) {
        auto inner = func(liftedArg.eval(stage));
        return inner.eval(stage);
    }

    /// Only the argument is simplified; `func` is re-attached to it unchanged.
    override Lift!U partial(Stage stage) {
        auto simplifiedArg = liftedArg.partial(stage);
        return simplifiedArg.flatMap(func);
    }
}
251 
/// Base class of all errors raised by this library.
class LmsException : Exception {
    /// `file`/`line` default to the `new` site (D evaluates `__FILE__`/`__LINE__` default
    /// arguments at the caller), so reports point at the throw site, not at this constructor.
    this(string message, string file = __FILE__, size_t line = __LINE__) {
        super(message, file, line);
    }
}

/// A name could not be resolved at a stage.
class LmsNameResolution : LmsException {
    this(string message, string file = __FILE__, size_t line = __LINE__) {
        super(message, file, line);
    }
}

/// A name is already registered at a stage.
class LmsNameConflict : LmsNameResolution {
    this(string message, string file = __FILE__, size_t line = __LINE__) {
        super(message, file, line);
    }
}

/// A lifted value could not be evaluated at a stage (e.g. an unbound slot).
class LmsEvaluationFailed : LmsException {
    this(string message, string file = __FILE__, size_t line = __LINE__) {
        super(message, file, line);
    }
}
275 
version(unittest) {
    /**
        Test helper: assert that evaluating `expr` throws an `LmsException` (or a subclass).

        Exceptions of other types propagate unchanged. On failure the message names the
        call site, because `expr.stringof` on a lazy parameter does not reproduce the caller's expression.
    */
    void assertThrows(T)(lazy T expr, string file = __FILE__, size_t line = __LINE__) {
        import std.conv : to;
        try {
            expr;
        }
        catch(LmsException e) {
            return;
        }
        assert(0, "expression at " ~ file ~ ":" ~ line.to!string ~ " should throw but didn't!");
    }
}
287 
288 ///
289 @("basics")
290 unittest {
291     auto stage = new BasicStage();
292     auto value = lift(40) + 2;
293     assert(stage.eval(value) == 42);
294 }
295 
296 ///
297 @("slots")
298 unittest {
299     auto stage = new BasicStage();
300     auto slot = stage.slot!string("some.slot");
301     assert(slot.name == "some.slot");
302     auto expr = slot ~ ", world!";
303     assertThrows(stage.eval(expr));
304     
305     stage["some.slot"] =  lift("Hello");
306     assert(stage.eval(expr) == "Hello, world!");
307 
308     auto laterStage = new BasicStage();
309     laterStage.slot(slot);
310     assertThrows(laterStage.eval(slot));
311 
312     laterStage["some.slot"] = lift("Bye");
313 
314     assert(stage.eval(expr) == "Bye, world!");
315 }
316 
317 
318 ///
319 @("partial evaluation")
320 unittest {
321     auto stage = new BasicStage();
322     int[] trace; // our primitive trace buffer
323     auto v1 = stage.slot!double("var1").map(delegate double(double x) {
324         trace ~= 1;
325         return x;
326     });
327     auto v2 = stage.slot!double("var2").map(delegate double(double x) {
328         trace ~= 2;
329         return x;
330     });
331     stage["var1"] = lift(1.5);
332     auto part = (v1 + v2).partial(stage);
333     stage["var2"] = lift(-0.5);
334     // first pass - both map functions called once
335     assert(part.eval(stage) == 1.0);
336     assert(trace == [1, 2]);
337 
338     // second pass - only v2 is evaluated
339     assert(part.eval(stage) == 1.0);
340     assert(trace == [1, 2, 2]);
341 }
342 
343 ///
344 @("CTFE")
345 unittest {
346     enum result = () {
347         auto stage = new BasicStage();
348         auto s = stage.slot!int("int");
349         stage["int"] = lift(123);
350         auto s2 = lift(2);
351         //auto s3 = s + s2; // somehow fails.. to be fixed soon
352         return stage.eval(s);
353     }();
354     static assert(result == 123);
355 }