1 module exceptionhandling;
2 
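/**
   version(exceptionhandling_release_asserts)

   When compiled with this version identifier, the checks performed by
   assertEqual, assertNotEqual, assertLess, assertGreater, assertLessEqual
   and assertGreaterEqual are disabled (they become no-ops).
*/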
8 
private {
	// Exception type thrown by the assertXXX functions: AssertError when the
	// module is compiled with unittests, Exception otherwise.
	version(unittest) {
		import core.exception : AssertError;
		alias ExceptionType = AssertError;
	} else {
		alias ExceptionType = Exception;
	}

	// NOTE(review): `approxEqual` is deprecated in newer Phobos releases.
	// `std.math.isClose` replaces it but has different default tolerances,
	// so switching would change which values compare as equal.
	import std.math : approxEqual;

	// Comparison predicates. Every predicate takes its two operands as
	// independent template parameters, so each can be instantiated
	// explicitly with two (possibly different) types, e.g.
	// `cmpFloat!(double, double)`, as well as called through IFTI with
	// operands such as `float` and `const(float)`.

	// Approximate equality for floating point operands.
	bool cmpFloat(T,S)(T tt, S tc) {
		return approxEqual(tt, tc);
	}

	// Approximate inequality for floating point operands.
	bool cmpFloatNot(T,S)(T tt, S tc) {
		return !approxEqual(tt, tc);
	}

	// Exact inequality (`!=`).
	bool cmpNot(T,S)(T tt, S tc) {
		return tt != tc;
	}

	// Exact equality (`==`).
	bool cmp(T,S)(T tt, S tc) {
		return tt == tc;
	}

	bool cmpLess(T,S)(T tt, S tc) {
		return tt < tc;
	}

	bool cmpGreater(T,S)(T tt, S tc) {
		return tt > tc;
	}

	bool cmpLessEqual(T,S)(T tt, S tc) {
		return tt <= tc;
	}

	bool cmpGreaterEqual(T,S)(T tt, S tc) {
		return tt >= tc;
	}

	// `<=` where the equality part is approximate.
	bool cmpLessEqualFloat(T,S)(T tt, S tc) {
		return tt < tc || approxEqual(tt, tc);
	}

	// `>=` where the equality part is approximate.
	bool cmpGreaterEqualFloat(T,S)(T tt, S tc) {
		return tt > tc || approxEqual(tt, tc);
	}
}
58 
/** Selects the comparison function to use for values of type `T`.

`FCMP` is chosen when `T` is a floating point type, or an input range whose
element type is floating point; `ICMP` is chosen in every other case.
*/
template getCMP(T, alias FCMP, alias ICMP) {
	import std.traits : isFloatingPoint;
	import std.range : ElementType, isInputRange;

	// Ranges are classified by their elements, everything else by itself.
	static if(isInputRange!T) {
		alias ValueType = ElementType!T;
	} else {
		alias ValueType = T;
	}

	static if(!isFloatingPoint!ValueType) {
		alias getCMP = ICMP;
	} else {
		alias getCMP = FCMP;
	}
}
77 
/** Assert that `toTest` is equal to `toCompareAgainst`.

If `T` is a floating point type, or a range of floating point values, the
values are compared with `approxEqual`; otherwise `==` is used. Two forward
ranges are compared element by element and must have the same length.

Nothing is returned. If the comparison fails, or itself throws an
`Exception`, an `AssertError` is thrown when this module is compiled with
unittests enabled (`version(unittest)`), and an `Exception` otherwise.
No check is performed when assertions are disabled (`version(assert)` is
not set) or when compiled with `version(exceptionhandling_release_asserts)`.

Params:
	toTest = the value under test
	toCompareAgainst = the expected value
	file = source file reported by the thrown exception; defaults to the call site
	line = source line reported by the thrown exception; defaults to the call site
*/
void assertEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	version(assert) {
		alias CMP = getCMP!(T, cmpFloat, cmp);
		AssertImpl!(CMP, "==")(toTest, toCompareAgainst,
				file, line
		);
	}
}
94 
/// ditto
void assertNotEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	// Floating point operands (or ranges of them) are compared
	// approximately; everything else uses plain `!=`.
	version(assert) {
		alias notEqualCmp = getCMP!(T, cmpFloatNot, cmpNot);
		AssertImpl!(notEqualCmp, "!=")(toTest, toCompareAgainst, file, line);
	}
}
106 
/// ditto
void assertLess(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	// Strict ordering has no approximate variant: `<` is used for every type.
	version(assert) {
		alias lessCmp = cmpLess;
		AssertImpl!(lessCmp, "<")(toTest, toCompareAgainst, file, line);
	}
}
117 
/// ditto
void assertGreater(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	// Strict ordering has no approximate variant: `>` is used for every type.
	version(assert) {
		alias greaterCmp = cmpGreater;
		AssertImpl!(greaterCmp, ">")(toTest, toCompareAgainst, file, line);
	}
}
128 
/// ditto
void assertGreaterEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	// For floating point operands the "equal" half is approximate.
	version(assert) {
		alias geCmp = getCMP!(T, cmpGreaterEqualFloat, cmpGreaterEqual);
		AssertImpl!(geCmp, ">=")(toTest, toCompareAgainst, file, line);
	}
}
140 
/// ditto
void assertLessEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	// For floating point operands the "equal" half is approximate.
	version(assert) {
		alias leCmp = getCMP!(T, cmpLessEqualFloat, cmpLessEqual);
		AssertImpl!(leCmp, "<=")(toTest, toCompareAgainst, file, line);
	}
}
152 
/* Shared implementation of the assertXXX functions.

Checks `Cmp(toTest, toCompareAgainst)`. When both arguments are forward
ranges they are compared element by element instead: the check succeeds only
if `Cmp` holds for every pair of elements and both ranges have the same
length. `toTest` must not be an input range that is not also a forward range.

The element-wise walk runs on `save`d copies. Both arguments are `auto ref`
and are bound by reference when the caller passes lvalues, so popping them
directly would consume the caller's ranges and leave only the unvisited tail
for the failure message.

`cmpMsg` is the operator text shown in failure messages; `file` and `line`
are reported by the thrown exception.

Throws: `ExceptionType` when the check fails, or - with the original
exception chained as `next` - when evaluating the comparison throws an
`Exception`. With `version(exceptionhandling_release_asserts)` nothing is
checked.
*/
private void AssertImpl(alias Cmp, string cmpMsg,T,S)(auto ref T toTest,
		auto ref S toCompareAgainst, const string file, const int line)
{
	import std.format : format;
	import std.range : isForwardRange, isInputRange;

	static assert(!isInputRange!T || isForwardRange!T);
	version(exceptionhandling_release_asserts) {
	} else {
		bool cmpRslt = false;
		try {
			static if(isForwardRange!T && isForwardRange!S) {
				import std.traits : isImplicitlyConvertible;
				import std.range.primitives : ElementType, empty, front,
					popFront, save;
				static assert(isImplicitlyConvertible!(
						ElementType!(T),
						ElementType!(S)),
						format("You can not compare ranges of type %s to"
							~ " ranges of type %s.", ElementType!(T).stringof,
							ElementType!(S).stringof)
				);

				// Walk copies so the arguments (possibly the caller's own
				// variables) stay intact, also for the messages below.
				auto testRange = toTest.save;
				auto againstRange = toCompareAgainst.save;

				cmpRslt = true;
				while(!testRange.empty && !againstRange.empty) {
					// Cmp is called via IFTI, so comparators that take both
					// operands as one shared template type also work here.
					if(!Cmp(testRange.front, againstRange.front)) {
						cmpRslt = false;
						break;
					}
					testRange.popFront();
					againstRange.popFront();
				}

				// A range that is a strict prefix of the other never matches.
				if(testRange.empty != againstRange.empty) {
					cmpRslt = false;
				}
			} else {
				cmpRslt = Cmp(toTest, toCompareAgainst);
			}
		} catch(Exception e) {
			throw new ExceptionType(
				format("Exception thrown while \"toTest(%s) " ~ cmpMsg
					~ " toCompareAgainst(%s)\"",
				toTest, toCompareAgainst), file, line, e
			);
		}

		if(!cmpRslt) {
			throw new ExceptionType(format("toTest(%s) " ~ cmpMsg ~
				" toCompareAgainst(%s) failed", toTest, toCompareAgainst), file,
					line
			);
		}
	}
}
214 
unittest {
	import std.meta : AliasSeq;

	// Identical checks for integral and floating point types, with const on
	// either side of the comparison.
	foreach(T; AliasSeq!(byte,int,float,double)) {
		T zero = 0, one = 1, two = 2;
		alias CT = const(T);

		assertEqual(one, one);
		assertNotEqual(one, zero);

		assertEqual(cast(CT)one, one);
		assertNotEqual(cast(CT)one, zero);
		assertEqual(one, cast(CT)one);
		assertNotEqual(one, cast(CT)zero);
		assertLessEqual(cast(CT)zero, one);
		assertGreaterEqual(one, cast(CT)zero);

		// A mismatch must throw (an AssertError in unittest builds).
		bool threw;
		try {
			assertEqual(one, zero);
		} catch(Throwable) {
			threw = true;
		}
		assert(threw);
	}

	assertEqual(1, 1);
	assertNotEqual(1, 0);
}
247 
unittest {
	import core.exception : AssertError;
	import std.container.array : Array;
	import std.exception : assertThrown, assertNotThrown;

	// A class whose opEquals always throws.
	class Foo {
		int a;
		this(int a) { this.a = a; }
		override bool opEquals(Object o) {
			throw new Exception("Another test");
		}
	}

	auto f = new Foo(1);
	auto g = new Foo(1);

	// Both class comparisons must surface as AssertError.
	assertThrown!AssertError(assertEqual(f, cast(Foo)null));
	assertThrown!AssertError(assertEqual(f, g));

	// Dynamic arrays: equal, differing elements, differing lengths.
	assertNotThrown!AssertError(assertEqual([0,1,2,3,4], [0,1,2,3,4]));
	foreach(bad; [[0,2,3,4], [0,1,2,3,5], [0,1,2,3]]) {
		assertThrown!AssertError(assertEqual(bad, [0,1,2,3,4]));
	}
	assertThrown!AssertError(assertEqual([0,2,3,4], [0,1,2]));

	// A container range compared against arrays.
	auto ia = Array!int([0,1,2,3,4]);
	assertNotThrown!AssertError(assertEqual(ia[], [0,1,2,3,4]));
	foreach(bad; [[0,1,2,3], [0,1,2,3,4,5]]) {
		assertThrown!AssertError(assertEqual(ia[], bad));
	}
}
278 
/** Evaluates `exp` and returns its result.

If evaluating `exp` throws an `Exception`, that exception is caught and a new
`ET` is thrown instead, with the caught exception chained as its `next`.
The new message is every element of `args` formatted with `%s`, each followed
by a single space. Throwables that are not `Exception`s propagate unchanged.
`file` and `line` default to the location where `expect` is instantiated.
*/
auto expect(ET = Exception, F, int line = __LINE__, string file = __FILE__, Args...)
		(lazy F exp, lazy Args args)
{
	try {
		return exp();
	} catch(Exception cause) {
		immutable message = joinElem(args);
		throw new ET(message, file, line, cause);
	}
}
293 
/* Formats every argument with "%s" and concatenates the results, each one
followed by a single space (a non-empty result therefore ends with a space).
Returns "" when called without arguments.
*/
private string joinElem(Args...)(lazy Args args) {
	import std.array : appender;
	import std.format : formattedWrite;

	auto sink = appender!string();
	foreach(i, Arg; Args) {
		sink.formattedWrite("%s ", args[i]);
	}
	return sink.data;
}
304 
unittest {
	import std.exception : assertThrown;
	import std.string : indexOf;

	string barMsg = "Fun will thrown, I'm sure";
	string funMsg = "Hopefully this is true";

	void fun() { throw new Exception(funMsg); }
	void bar() { expect(fun(), barMsg); }
	bool func() { throw new Exception("e"); }

	// expect must rethrow carrying barMsg, with fun's exception as `next`.
	Exception caught = null;
	try {
		bar();
	} catch(Exception e) {
		caught = e;
	}

	assert(caught !is null);
	assert(caught.msg.indexOf(barMsg) != -1, "\"" ~ caught.msg ~ "\" " ~ barMsg);
	assert(caught.next !is null);
	assert(caught.next.msg.indexOf(funMsg) != -1, caught.next.msg);

	assertThrown(assertEqual(func(), true));
}
339 
/** Evaluates `exp` and returns its result if that result is truthy.

Throws:
	`Exception`, with the original throwable chained as `next`, if evaluating
	`exp` throws anything. `Exception` with the message `"Ensure failed "`
	followed by `args` (each formatted with `%s` and followed by a space) if
	the result tests as `false`. `file` and `line` default to the location
	where `ensure` is instantiated.

NOTE(review): `ET` is currently not used; an `Exception` is always thrown.
The module's own tests (`assertThrown!Exception(ensure(...))`) depend on
that, so honouring `ET` (default `ExceptionType`) would be a breaking change.
*/
auto ref ensure(ET = ExceptionType, E, int line = __LINE__,
		string file = __FILE__, Args...)(lazy E exp, Args args)
{
	// Initialize the result directly from the evaluation. Declaring it first
	// and assigning later does not compile for const/immutable results or
	// for types without a default initializer.
	auto rslt = () {
		try {
			return exp();
		} catch(Throwable e) {
			throw new Exception(
				"Exception thrown while calling \"ensure\"", file, line, e
			);
		}
	}();

	if(!rslt) {
		throw new Exception(joinElem("Ensure failed", args), file, line);
	}
	return rslt;
}
360 
///
unittest {
	import core.exception : AssertError;

	bool func() {
		throw new Exception("e");
	}

	// Exceptions from ensure report the line of the ensure call itself.
	immutable throwLine = __LINE__ + 1;
	auto e = assertThrown!Exception(ensure(func()));
	assert(e.line == throwLine);

	immutable failLine = __LINE__ + 1;
	auto e2 = assertThrown!Exception(ensure(false));
	assert(e2.line == failLine);

	// A truthy value is passed through.
	bool ok = assertNotThrown!Exception(ensure(true));
	assert(ok);
}
376 
/** Evaluates `t` and returns the `E` it throws.

Throws: `ExceptionType` if evaluating `t` completes without throwing an `E`.
Throwables of other types propagate unchanged. `file` and `line` default to
the call site and are reported by the thrown exception.
*/
E assertThrown(E,T)(lazy T t, int line = __LINE__,
		string file = __FILE__)
{
	E thrown = null;
	try {
		t();
	} catch(E e) {
		thrown = e;
	}

	if(thrown is null) {
		throw new ExceptionType("Exception of type " ~ E.stringof ~
				" was not thrown even though expected.", file, line
		);
	}
	return thrown;
}
389 
/** Evaluates `t` and returns its result.

Throws: `ExceptionType` if evaluating `t` throws an `E`. The caught `E` is
chained as `next` so the original cause is kept. Throwables of other types
propagate unchanged. `file` and `line` default to the call site and are
reported by the thrown exception.
*/
auto assertNotThrown(E,T)(lazy T t, int line = __LINE__,
		string file = __FILE__)
{
	try {
		return t();
	} catch(E e) {
		throw new ExceptionType("Exception of type " ~ E.stringof ~
				" caught unexpectedly", file, line, e
		);
	}
}
401 
///
unittest {
	import core.exception : AssertError;

	bool foo() { throw new Exception("e"); }
	bool bar() { return true; }

	// bar() does not throw, so the inner assertThrown must itself fail.
	assertThrown!(AssertError)(assertThrown!(AssertError)(bar()));
	// foo() throws, so assertNotThrown must report it.
	assertThrown!(AssertError)(assertNotThrown!(Exception)(foo()));
}
417 
unittest {
	// A struct that compares both with plain ints and with itself.
	struct C {
		int a;

		bool opEquals(int other) const {
			return a == other;
		}

		bool opEquals(C other) const {
			return a == other.a;
		}

		int opCmp(int other) const {
			return (a > other) - (a < other);
		}

		int opCmp(C other) const {
			return opCmp(other.a);
		}
	}

	// Against int.
	assertEqual(C(10), 10);
	assertLess(C(7), 9);
	assertGreater(C(10), 9);
	assertLessEqual(C(7), 9);
	assertGreaterEqual(C(10), 9);
	assertLessEqual(C(7), 7);
	assertGreaterEqual(C(10), 10);

	// Against C.
	assertEqual(C(10), C(10));
	assertLess(C(7), C(9));
	assertGreater(C(10), C(9));
	assertLessEqual(C(7), C(9));
	assertGreaterEqual(C(10), C(9));
	assertLessEqual(C(7), C(7));
	assertGreaterEqual(C(10), C(10));
}