1 module exceptionhandling;
2 
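/**
   version(exceptionhandling_release_asserts)

   When this version identifier is set, every assertXXX function in this
   module skips its comparison and returns `toTest` unchanged. The same
   happens when assert checks are disabled (`version(assert)` is not set,
   e.g. when compiling with -release).
*/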
8 
private {
	// When the module is compiled with -unittest, failures are reported as
	// AssertError; otherwise they are reported as plain Exception.
	version(unittest) {
		import core.exception : AssertError;
		alias ExceptionType = AssertError;
	} else {
		alias ExceptionType = Exception;
	}

	// NOTE(review): approxEqual is deprecated in recent Phobos releases in
	// favour of isClose, which uses different default tolerances. Confirm
	// before migrating, as results for borderline values would change.
	import std.math : approxEqual;

	// Comparators used by the assertXXX functions. Every comparator takes
	// two independent type parameters, so that AssertImpl can instantiate
	// any of them uniformly as `Cmp!(ElementType!T, ElementType!S)` when it
	// compares two ranges element by element. (The floating point variants
	// used to take a single type parameter, which made that instantiation,
	// and thus comparing ranges of floats, fail to compile.)

	/// true when `tt` and `tc` are approximately equal.
	bool cmpFloat(T,S)(T tt, S tc) {
		return approxEqual(tt, tc);
	}

	/// true when `tt` and `tc` are not approximately equal.
	bool cmpFloatNot(T,S)(T tt, S tc) {
		return !approxEqual(tt, tc);
	}

	/// true when `tt != tc`.
	bool cmpNot(T,S)(T tt, S tc) {
		return tt != tc;
	}

	/// true when `tt == tc`.
	bool cmp(T,S)(T tt, S tc) {
		return tt == tc;
	}

	/// true when `tt < tc`.
	bool cmpLess(T,S)(T tt, S tc) {
		return tt < tc;
	}

	/// true when `tt > tc`.
	bool cmpGreater(T,S)(T tt, S tc) {
		return tt > tc;
	}

	/// true when `tt <= tc`.
	bool cmpLessEqual(T,S)(T tt, S tc) {
		return tt <= tc;
	}

	/// true when `tt >= tc`.
	bool cmpGreaterEqual(T,S)(T tt, S tc) {
		return tt >= tc;
	}

	/// true when `tt < tc` or the two are approximately equal.
	bool cmpLessEqualFloat(T,S)(T tt, S tc) {
		return tt < tc || approxEqual(tt, tc);
	}

	/// true when `tt > tc` or the two are approximately equal.
	bool cmpGreaterEqualFloat(T,S)(T tt, S tc) {
		return tt > tc || approxEqual(tt, tc);
	}
}
58 
/* Chooses between two comparators: `FCMP` when the data being compared is
floating point, `ICMP` otherwise. For input ranges the decision is made on
the element type, for everything else on `T` itself.
*/
template getCMP(T, alias FCMP, alias ICMP) {
	import std.range : ElementType, isInputRange;
	import std.traits : isFloatingPoint;

	// The type whose floating-point-ness decides the comparator.
	static if(isInputRange!T) {
		alias Probe = ElementType!T;
	} else {
		alias Probe = T;
	}

	static if(isFloatingPoint!Probe) {
		alias getCMP = FCMP;
	} else {
		alias getCMP = ICMP;
	}
}
77 
/** Checks that `toTest` equals `toCompareAgainst` and passes `toTest` through.

When `T` is a floating point type, or a range of floating point elements,
the values are compared with `std.math.approxEqual`; otherwise `==` is used.
When both arguments are forward ranges they are compared element by element
and must also have the same length.

On success `toTest` is returned, which allows several asserts to be chained.
On failure an `AssertError` is thrown if the module is compiled with
`-unittest`, and an `Exception` otherwise. If assert checks are disabled
(`version(assert)` not set), or `version(exceptionhandling_release_asserts)`
is set, nothing is compared and `toTest` is returned as-is.
*/
auto ref T assertEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	version(assert) {
		// approxEqual for floating point data, == for everything else.
		alias equalityCheck = getCMP!(T, cmpFloat, cmp);
		return AssertImpl!(T, S, equalityCheck, "==")(
			toTest, toCompareAgainst, file, line);
	} else {
		return toTest;
	}
}
96 
/// ditto
auto ref T assertNotEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	version(assert) {
		import std.range.primitives : isForwardRange, save;

		static if(isForwardRange!T && isForwardRange!S) {
			// Two ranges are "not equal" when they differ in at least one
			// element or in length. Routing them through AssertImpl with
			// cmpNot would instead require *every* element pair to differ,
			// so e.g. [0, 1] vs [0, 2] would wrongly fail.
			version(exceptionhandling_release_asserts) {
				return toTest;
			} else {
				import std.algorithm.comparison : equal;
				import std.format : format;

				// Same element comparison assertEqual uses.
				alias ElemCmp = getCMP!(T, cmpFloat, cmp);
				bool same;
				try {
					// Compare saved copies so that lvalue ranges (bound by
					// ref through auto ref) are not consumed.
					same = equal!((a, b) => ElemCmp(a, b))(toTest.save,
							toCompareAgainst.save);
				} catch(Exception e) {
					throw new ExceptionType(
						format("Exception thrown while \"toTest(%s) != toCompareAgainst(%s)\"",
							toTest, toCompareAgainst),
						file, line, e
					);
				}
				if(same) {
					throw new ExceptionType(
						format("toTest(%s) != toCompareAgainst(%s) failed",
							toTest, toCompareAgainst),
						file, line
					);
				}
				return toTest;
			}
		} else {
			alias CMP = getCMP!(T, cmpFloatNot, cmpNot);
			return AssertImpl!(T,S, CMP, "!=")(toTest, toCompareAgainst, file,
					line
			);
		}
	} else {
		return toTest;
	}
}
110 
/// ditto
auto ref T assertLess(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	version(assert) {
		// Plain `<` for every type, floating point included.
		alias lessCmp = cmpLess;
		return AssertImpl!(T, S, lessCmp, "<")(
			toTest, toCompareAgainst, file, line);
	} else {
		return toTest;
	}
}
123 
/// ditto
auto ref T assertGreater(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	version(assert) {
		// Plain `>` for every type, floating point included.
		alias greaterCmp = cmpGreater;
		return AssertImpl!(T, S, greaterCmp, ">")(
			toTest, toCompareAgainst, file, line);
	} else {
		return toTest;
	}
}
136 
/// ditto
auto ref T assertGreaterEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	version(assert) {
		// For floating point data, approximately equal values also pass.
		alias geCmp = getCMP!(T, cmpGreaterEqualFloat, cmpGreaterEqual);
		return AssertImpl!(T, S, geCmp, ">=")(
			toTest, toCompareAgainst, file, line);
	} else {
		return toTest;
	}
}
150 
/// ditto
auto ref T assertLessEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	version(assert) {
		// For floating point data, approximately equal values also pass.
		alias leCmp = getCMP!(T, cmpLessEqualFloat, cmpLessEqual);
		return AssertImpl!(T, S, leCmp, "<=")(
			toTest, toCompareAgainst, file, line);
	} else {
		return toTest;
	}
}
164 
/* Shared implementation behind every assertXXX function.

`Cmp` is the comparison predicate and `cmpMsg` the operator text used in the
error messages. When both `T` and `S` are forward ranges, `Cmp` is
instantiated as `Cmp!(ElementType!T, ElementType!S)` and applied element by
element; the two ranges must also have the same length. Otherwise
`Cmp(toTest, toCompareAgainst)` decides the result.

Throws ExceptionType when the comparison fails, or when evaluating the
comparison throws an Exception (that exception is chained as `next`).
Returns `toTest` on success. With version(exceptionhandling_release_asserts)
nothing is compared.
*/
private auto ref T AssertImpl(T,S,alias Cmp, string cmpMsg)(auto ref T toTest,
		auto ref S toCompareAgainst, const string file, const int line)
{
	import std.format : format;
	import std.range.primitives : ElementType, empty, front, isForwardRange,
		isInputRange, popFront, save;

	// Pure input ranges cannot be inspected without being consumed.
	static assert(!isInputRange!T || isForwardRange!T);
	version(exceptionhandling_release_asserts) {
		return toTest;
	} else {
		bool cmpRslt = false;
		try {
			static if(isForwardRange!T && isForwardRange!S) {
				import std.traits : isImplicitlyConvertible;
				static assert(isImplicitlyConvertible!(
						ElementType!(T),
						ElementType!(S)),
						format("You can not compare ranges of type %s to"
							~ " ranges of type %s.", ElementType!(T).stringof,
							ElementType!(S).stringof)
				);
				alias CMP = Cmp!(ElementType!T,ElementType!S);

				// Walk saved copies: `auto ref` binds lvalue arguments by
				// reference, so popping the parameters directly would empty
				// the caller's range as well as the range returned to it
				// (and the one printed in the failure message).
				auto lhs = toTest.save;
				auto rhs = toCompareAgainst.save;
				cmpRslt = true;
				while(!lhs.empty && !rhs.empty) {
					if(!CMP(lhs.front, rhs.front)) {
						cmpRslt = false;
						break;
					}
					lhs.popFront();
					rhs.popFront();
				}
				// Ranges of different length never satisfy the comparison.
				if(lhs.empty != rhs.empty) {
					cmpRslt = false;
				}
			} else {
				cmpRslt = Cmp(toTest, toCompareAgainst);
			}
		} catch(Exception e) {
			throw new ExceptionType(
				format("Exception thrown while \"toTest(%s) " ~ cmpMsg
					~ " toCompareAgainst(%s)\"",
				toTest, toCompareAgainst), file, line, e
			);
		}

		if(!cmpRslt) {
			throw new ExceptionType(format("toTest(%s) " ~ cmpMsg ~
				" toCompareAgainst(%s) failed", toTest, toCompareAgainst), file,
					line
			);
		}
		return toTest;
	}
}
228 
unittest {
	import core.exception : AssertError;
	import std.exception : assertThrown;
	import std.meta : AliasSeq;

	// Chained comparisons on a few integral and floating point types.
	foreach(Num; AliasSeq!(byte,int,float,double)) {
		Num lo = 0;
		Num mid = 1;
		Num hi = 2;

		Num chained = assertEqual(mid, mid).assertGreater(lo).assertLess(hi);
		cast(void)assertEqual(chained, mid);

		chained = assertNotEqual(mid, lo).assertGreater(lo).assertLess(hi);
		cast(void)assertEqual(chained, mid);

		chained = assertLessEqual(mid, hi).assertGreaterEqual(lo).assertEqual(mid);
		cast(void)assertEqual(chained, mid);

		// const on either side must be accepted
		cast(void)assertEqual(cast(const(Num))mid, mid);
		cast(void)assertNotEqual(cast(const(Num))mid, lo);
		cast(void)assertEqual(mid, cast(const(Num))mid);
		cast(void)assertNotEqual(mid, cast(const(Num))lo);

		assertThrown!AssertError(assertEqual(mid, lo));
	}

	cast(void)assertEqual(1, 1);
	cast(void)assertNotEqual(1, 0);
}
259 
unittest {
	import core.exception : AssertError;
	import std.exception : assertThrown;
	import std.container.array : Array;

	// A class whose opEquals always throws: assertEqual has to report that
	// as an AssertError instead of letting the Exception escape.
	class Foo {
		int a;
		this(int a) { this.a = a; }
		override bool opEquals(Object o) {
			throw new Exception("Another test");
		}
	}

	auto lhs = new Foo(1);
	auto rhs = new Foo(1);
	assertThrown!AssertError(assertEqual(lhs, cast(Foo)null));
	assertThrown!AssertError(assertEqual(lhs, rhs));

	// Built-in arrays: equal contents pass, differing elements or lengths fail.
	assertNotThrown!AssertError(assertEqual([0,1,2,3,4], [0,1,2,3,4]));
	assertThrown!AssertError(assertEqual([0,2,3,4], [0,1,2,3,4]));
	assertThrown!AssertError(assertEqual([0,2,3,4], [0,1,2]));
	assertThrown!AssertError(assertEqual([0,1,2,3,5], [0,1,2,3,4]));
	assertThrown!AssertError(assertEqual([0,1,2,3], [0,1,2,3,4]));

	// A std.container.Array slice compared against built-in arrays.
	auto stored = Array!int([0,1,2,3,4]);
	assertNotThrown!AssertError(assertEqual(stored[], [0,1,2,3,4]));
	assertThrown!AssertError(assertEqual(stored[], [0,1,2,3]));
	assertThrown!AssertError(assertEqual(stored[], [0,1,2,3,4,5]));
}
290 
/** Evaluates `exp` and returns its result.

If evaluating `exp` throws an `Exception`, it is caught and a new `ET`
(default `Exception`) is thrown in its place. The new exception's message is
built from `args` (each formatted with `%s`, space separated), and the caught
exception is chained as `next`. `line` and `file` default to the call site.
*/
auto expect(ET = Exception, F, int line = __LINE__, string file = __FILE__, Args...)
		(lazy F exp, lazy Args args)
{
	try {
		return exp();
	} catch(Exception cause) {
		// `args` are lazy: the message is only built on failure.
		immutable message = joinElem(args);
		throw new ET(message, file, line, cause);
	}
}
305 
/* Formats every element of `args` with `%s` and joins them with single
spaces. Unlike the previous implementation, no trailing space is appended
after the last element. No arguments give an empty string.
*/
private string joinElem(Args...)(lazy Args args) {
	import std.array : appender;
	import std.format : formattedWrite;

	auto app = appender!string();
	foreach(idx, arg; args) {
		// Separator only between elements, never after the last one.
		if(idx != 0) {
			app.put(' ');
		}
		formattedWrite(app, "%s", arg);
	}
	return app.data;
}
316 
unittest {
	import std.exception : assertThrown;
	import std.string : indexOf;

	string barMsg = "Fun will thrown, I'm sure";
	string funMsg = "Hopefully this is true";

	void fun() {
		throw new Exception(funMsg);
	}

	void bar() {
		expect(fun(), barMsg);
	}

	bool func() {
		throw new Exception("e");
	}

	// expect must rethrow with barMsg in the message and chain fun's
	// exception as `next`.
	Exception caught = null;
	try {
		bar();
	} catch(Exception e) {
		caught = e;
	}
	assert(caught !is null);
	assert(caught.msg.indexOf(barMsg) != -1, "\"" ~ caught.msg ~ "\" " ~ barMsg);
	assert(caught.next !is null);
	assert(caught.next.msg.indexOf(funMsg) != -1, caught.next.msg);

	assertThrown(assertEqual(func(), true));
}
351 
/** Evaluates `exp` and returns its value if that value is truthy.

If evaluating `exp` throws, the throwable is caught and chained as `next` in
a new `Exception`. If the value converts to `false`, an `Exception` is thrown
whose message is "Ensure failed" followed by `args`, space separated. `args`
are evaluated eagerly (unlike `exp`, which is lazy). `line` and `file`
default to the call site.

NOTE(review): the `ET` template parameter is currently ignored; an
`Exception` is always thrown, and existing tests catch exactly that. Also,
`Throwable` is caught, so Errors raised by `exp` are converted into an
`Exception` as well.
*/
auto ref ensure(ET = ExceptionType, E, int line = __LINE__,
		string file = __FILE__, Args...)(lazy E exp, Args args)
{
	typeof(exp) rslt;

	try {
		rslt = exp();
	} catch(Throwable e) {
		throw new Exception(
			"Exception thrown while calling \"ensure\"", file, line, e
		);
	}

	if(!rslt) {
		throw new Exception(joinElem("Ensure failed", args), file, line);
	}
	return rslt;
}
372 
///
unittest {
	bool func() {
		throw new Exception("e");
	}

	// ensure reports its call site both when the expression throws and
	// when it evaluates to false (each check is on the line right after).
	auto thrownInside = assertThrown!Exception(ensure(func()));
	assert(thrownInside.line == __LINE__ - 1);
	auto falsy = assertThrown!Exception(ensure(false));
	assert(falsy.line == __LINE__ - 1);

	bool passed = assertNotThrown!Exception(ensure(true));
	assert(passed);
}
388 
/** Evaluates `t` and returns the `E` it throws.

If `t` completes without throwing an `E`, an `ExceptionType` is thrown
instead. That is `AssertError` when compiled with `-unittest`, and
`Exception` otherwise. Throwables of other types propagate unchanged.
`line` and `file` default to the call site.
*/
E assertThrown(E,T)(lazy T t, int line = __LINE__,
		string file = __FILE__)
{
	E caught = null;
	try {
		t();
	} catch(E e) {
		caught = e;
	}

	if(caught !is null) {
		return caught;
	}
	throw new ExceptionType("Exception of type " ~ E.stringof ~
			" was not thrown even though expected.", file, line
	);
}
401 
/** Evaluates `t` and returns its result.

If evaluating `t` throws an `E`, an `ExceptionType` is thrown instead. That
is `AssertError` when compiled with `-unittest`, and `Exception` otherwise.
The caught throwable is chained as `next`. `line` and `file` default to the
call site.
*/
auto assertNotThrown(E,T)(lazy T t, int line = __LINE__,
		string file = __FILE__)
{
	try {
		return t();
	} catch(E e) {
		// Chain the original throwable so its message and trace are kept.
		throw new ExceptionType("Exception of type " ~ E.stringof ~
				" caught unexpectedly", file, line, e
		);
	}
}
413 
///
unittest {
	import core.exception : AssertError;

	bool throwing() {
		throw new Exception("e");
	}

	bool succeeding() {
		return true;
	}

	// assertThrown fails (AssertError in unittest builds) when nothing is
	// thrown ...
	assertThrown!(AssertError)(assertThrown!(AssertError)(succeeding()));
	// ... and assertNotThrown fails when the expression does throw.
	assertThrown!(AssertError)(assertNotThrown!(Exception)(throwing()));
}
429 
unittest {
	// A user-defined type comparable both against int and against itself.
	struct C {
		int a;

		bool opEquals(int other) const {
			return a == other;
		}

		bool opEquals(C other) const {
			return a == other.a;
		}

		int opCmp(int other) const {
			return (a > other) - (a < other);
		}

		int opCmp(C other) const {
			return opCmp(other.a);
		}
	}

	// C against int
	assertEqual(C(10), 10);
	assertLess(C(7), 9);
	assertGreater(C(10), 9);
	assertLessEqual(C(7), 9);
	assertGreaterEqual(C(10), 9);
	assertLessEqual(C(7), 7);
	assertGreaterEqual(C(10), 10);

	// C against C
	assertEqual(C(10), C(10));
	assertLess(C(7), C(9));
	assertGreater(C(10), C(9));
	assertLessEqual(C(7), C(9));
	assertGreaterEqual(C(10), C(9));
	assertLessEqual(C(7), C(7));
	assertGreaterEqual(C(10), C(10));
}