1 module exceptionhandling;
2 
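/**
   version(exceptionhandling_release_asserts)

   When this version identifier is defined, the assertXXX functions perform
   no comparison at all and simply return their first argument, as if the
   check had passed.
*/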
8 
private {
	// Error type thrown by the checks in this module: AssertError when the
	// module is compiled with -unittest, a plain Exception otherwise.
	version(unittest) {
		import core.exception : AssertError;
		alias ExceptionType = AssertError;
	} else {
		alias ExceptionType = Exception;
	}

	import std.math : approxEqual;

	// ---- Exact predicates (used for non floating point operands) ----

	bool cmp(T)(T lhs, T rhs) {
		return lhs == rhs;
	}

	bool cmpNot(T)(T lhs, T rhs) {
		return lhs != rhs;
	}

	bool cmpLess(T)(T lhs, T rhs) {
		return lhs < rhs;
	}

	bool cmpLessEqual(T)(T lhs, T rhs) {
		return lhs <= rhs;
	}

	bool cmpGreater(T)(T lhs, T rhs) {
		return lhs > rhs;
	}

	bool cmpGreaterEqual(T)(T lhs, T rhs) {
		return lhs >= rhs;
	}

	// ---- Floating point predicates ----
	// "Equal" here means equal within the default tolerances of
	// std.math.approxEqual.

	bool cmpFloat(T)(T lhs, T rhs) {
		return approxEqual(lhs, rhs);
	}

	bool cmpFloatNot(T)(T lhs, T rhs) {
		return !cmpFloat(lhs, rhs);
	}

	bool cmpLessEqualFloat(T)(T lhs, T rhs) {
		return cmpLess(lhs, rhs) || cmpFloat(lhs, rhs);
	}

	bool cmpGreaterEqualFloat(T)(T lhs, T rhs) {
		return cmpGreater(lhs, rhs) || cmpFloat(lhs, rhs);
	}
}
58 
/* Picks one of two comparison predicates for operands of type `T`.

Yields `FCMP` when `T` is a floating point type, or an input range whose
elements are floating point; yields `ICMP` for everything else.
*/
template getCMP(T, alias FCMP, alias ICMP) {
	import std.traits : isFloatingPoint;
	import std.range : ElementType, isInputRange;

	// The type whose "floating-pointness" decides the predicate.
	static if(isInputRange!T) {
		alias Scalar = ElementType!T;
	} else {
		alias Scalar = T;
	}

	static if(isFloatingPoint!Scalar) {
		alias getCMP = FCMP;
	} else {
		alias getCMP = ICMP;
	}
}
77 
/** Checks that `toTest` equals `toCompareAgainst`, then passes `toTest` on.

Floating point operands (and forward ranges of floating point elements) are
compared with `std.math.approxEqual`; all other operands use `==`. Forward
ranges are compared element by element and must also have the same length.

Returns `toTest` when the check passes, so calls can be chained via UFCS.

Throws: an `AssertError` if the module is compiled with `-unittest`, an
`Exception` otherwise, when the values differ or when comparing them throws.
When asserts are disabled (no `version(assert)`) or
`version(exceptionhandling_release_asserts)` is set, nothing is checked and
`toTest` is returned unchanged.

Params:
	toTest = the value under test
	toCompareAgainst = the expected value
	file = file reported by the thrown error; defaults to the call site
	line = line reported by the thrown error; defaults to the call site
*/
auto ref T assertEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	version(assert) {
		alias pred = getCMP!(T, cmpFloat, cmp);
		return AssertImpl!(T, S, pred, "==")(
			toTest, toCompareAgainst, file, line);
	} else {
		return toTest;
	}
}
96 
/** Checks that `toTest` differs from `toCompareAgainst`, then passes `toTest` on.

Floating point operands (and forward ranges of floating point elements)
count as equal, and so fail this check, when `std.math.approxEqual` holds;
all other operands use `!=`. For forward ranges every pair of corresponding
elements must differ, and both ranges must have the same length.

Returns `toTest` on success. Failure handling, `file`/`line` and the
disabled-assert behaviour are the same as for `assertEqual`.
*/
auto ref T assertNotEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	version(assert) {
		alias pred = getCMP!(T, cmpFloatNot, cmpNot);
		return AssertImpl!(T, S, pred, "!=")(
			toTest, toCompareAgainst, file, line);
	} else {
		return toTest;
	}
}
110 
/** Checks that `toTest < toCompareAgainst`, then passes `toTest` on.

`T` must be implicitly convertible to `S` (checked at compile time). For
forward ranges the check applies to each pair of corresponding elements,
and both ranges must have the same length.

Returns `toTest` on success. Failure handling, `file`/`line` and the
disabled-assert behaviour are the same as for `assertEqual`.
*/
auto ref T assertLess(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	version(assert) {
		import std.traits : isImplicitlyConvertible;
		static assert(isImplicitlyConvertible!(T,S));

		alias pred = cmpLess;
		return AssertImpl!(T, S, pred, "<")(
			toTest, toCompareAgainst, file, line);
	} else {
		return toTest;
	}
}
125 
/** Checks that `toTest > toCompareAgainst`, then passes `toTest` on.

`T` must be implicitly convertible to `S` (checked at compile time). For
forward ranges the check applies to each pair of corresponding elements,
and both ranges must have the same length.

Returns `toTest` on success. Failure handling, `file`/`line` and the
disabled-assert behaviour are the same as for `assertEqual`.
*/
auto ref T assertGreater(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	version(assert) {
		import std.traits : isImplicitlyConvertible;
		static assert(isImplicitlyConvertible!(T,S));

		alias pred = cmpGreater;
		return AssertImpl!(T, S, pred, ">")(
			toTest, toCompareAgainst, file, line);
	} else {
		return toTest;
	}
}
140 
/** Checks that `toTest >= toCompareAgainst`, then passes `toTest` on.

For floating point operands (and forward ranges of them), values that are
`std.math.approxEqual` also pass. `T` must be implicitly convertible to `S`
(checked at compile time). For forward ranges the check applies to each pair
of corresponding elements, and both ranges must have the same length.

Returns `toTest` on success. Failure handling, `file`/`line` and the
disabled-assert behaviour are the same as for `assertEqual`.
*/
auto ref T assertGreaterEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	version(assert) {
		import std.traits : isImplicitlyConvertible;
		static assert(isImplicitlyConvertible!(T,S));

		alias pred = getCMP!(T, cmpGreaterEqualFloat, cmpGreaterEqual);
		return AssertImpl!(T, S, pred, ">=")(
			toTest, toCompareAgainst, file, line);
	} else {
		return toTest;
	}
}
157 
/** Checks that `toTest <= toCompareAgainst`, then passes `toTest` on.

For floating point operands (and forward ranges of them), values that are
`std.math.approxEqual` also pass. `T` must be implicitly convertible to `S`
(checked at compile time). For forward ranges the check applies to each pair
of corresponding elements, and both ranges must have the same length.

Returns `toTest` on success. Failure handling, `file`/`line` and the
disabled-assert behaviour are the same as for `assertEqual`.
*/
auto ref T assertLessEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	version(assert) {
		import std.traits : isImplicitlyConvertible;
		static assert(isImplicitlyConvertible!(T,S));

		alias pred = getCMP!(T, cmpLessEqualFloat, cmpLessEqual);
		return AssertImpl!(T, S, pred, "<=")(
			toTest, toCompareAgainst, file, line);
	} else {
		return toTest;
	}
}
173 
/** Shared implementation of the assertXXX functions.

Compares `toTest` with `toCompareAgainst` using `Cmp` and returns `toTest`
when the comparison holds. Forward ranges are compared element by element
with `Cmp!(ElementType!T)` and must also have the same length. Any other type
is compared with a single `Cmp(toTest, toCompareAgainst)` call.

Ranges are walked through `.save` copies, so the caller's ranges are neither
consumed nor returned empty. This matters because lvalues arrive here by
`ref`. `toCompareAgainst` can only be saved when it is itself a forward
range. A plain input range cannot be copied and is consumed.

Params:
	Cmp = comparison predicate template
	cmpMsg = operator text used in error messages
	toTest = value under test; returned on success
	toCompareAgainst = reference value
	file = file reported by the thrown error
	line = line reported by the thrown error

Throws: `ExceptionType` when the comparison fails. If the comparison itself
throws an `Exception`, that exception is chained as `next` of the thrown
`ExceptionType`. With `version(exceptionhandling_release_asserts)` nothing
is compared.
*/
private auto ref T AssertImpl(T,S,alias Cmp, string cmpMsg)(auto ref T toTest,
		auto ref S toCompareAgainst, const string file, const int line)
{
	import std.format : format;
	import std.range : isForwardRange, isInputRange;

	// A non-forward input range could only be compared by consuming it.
	static assert(!isInputRange!T || isForwardRange!T);
	version(exceptionhandling_release_asserts) {
		return toTest;
	} else {
		bool cmpRslt = false;
		try {
			static if(isForwardRange!T) {
				import std.range.primitives : ElementType, empty, front,
					popFront, save;
				import std.traits : isImplicitlyConvertible;
				static assert(isImplicitlyConvertible!(
						ElementType!(T),
						ElementType!(S)),
						format("You can not compare ranges of type %s to"
							~ " ranges of type %s.", ElementType!(T).stringof,
							ElementType!(S).stringof)
				);
				alias CMP = Cmp!(ElementType!T);

				// Iterate over copies so the caller's ranges (and the value
				// we return) stay untouched.
				auto lhs = toTest.save;
				static if(isForwardRange!S) {
					auto rhs = toCompareAgainst.save;
				} else {
					alias rhs = toCompareAgainst;
				}

				cmpRslt = true;
				while(!lhs.empty && !rhs.empty) {
					if(!CMP(lhs.front, rhs.front)) {
						cmpRslt = false;
						break;
					}
					lhs.popFront();
					rhs.popFront();
				}
				// Ranges of different length never pass, even when every
				// common element matched.
				if(cmpRslt && lhs.empty != rhs.empty) {
					cmpRslt = false;
				}
			} else {
				cmpRslt = Cmp(toTest, toCompareAgainst);
			}
		} catch(Exception e) {
			throw new ExceptionType(
				format("Exception thrown while \"toTest(%s) " ~ cmpMsg
					~ " toCompareAgainst(%s)\"",
				toTest, toCompareAgainst), file, line, e
			);
		}

		if(!cmpRslt) {
			throw new ExceptionType(format("toTest(%s) " ~ cmpMsg ~
				" toCompareAgainst(%s) failed", toTest, toCompareAgainst), file, 
					line
			);
		}
		return toTest;
	}
}
237 
unittest {
	import core.exception : AssertError;
	import std.exception : assertThrown;
	import std.meta : AliasSeq;

	// Every comparison helper on each scalar type, including UFCS chaining
	// and operands that differ only in const-ness.
	foreach(Num; AliasSeq!(byte,int,float,double)) {
		Num zero = 0;
		Num one = 1;
		Num two = 2;

		Num chained = assertEqual(one, one)
			.assertGreater(zero)
			.assertLess(two);
		cast(void)assertEqual(chained, one);

		chained = assertNotEqual(one, zero)
			.assertGreater(zero)
			.assertLess(two);
		cast(void)assertEqual(chained, one);

		chained = assertLessEqual(one, two)
			.assertGreaterEqual(zero)
			.assertEqual(one);
		cast(void)assertEqual(chained, one);

		// const on either side must be accepted.
		cast(void)assertEqual(cast(const(Num))one, one);
		cast(void)assertNotEqual(cast(const(Num))one, zero);
		cast(void)assertEqual(one, cast(const(Num))one);
		cast(void)assertNotEqual(one, cast(const(Num))zero);

		// A failing check must raise AssertError under -unittest.
		assertThrown!AssertError(assertEqual(one, zero));
	}

	cast(void)assertEqual(1, 1);
	cast(void)assertNotEqual(1, 0);
}
268 
unittest {
	import core.exception : AssertError;
	import std.exception : assertThrown;
	import std.container.array : Array;

	// A class whose opEquals always throws: assertEqual has to convert that
	// into an AssertError rather than let the Exception escape.
	class ThrowingEq {
		int a;
		this(int a) { this.a = a; }
		override bool opEquals(Object other) {
			throw new Exception("Another test");
		}
	}

	auto first = new ThrowingEq(1);
	auto second = new ThrowingEq(1);

	assertThrown!AssertError(assertEqual(first, cast(ThrowingEq)null));
	assertThrown!AssertError(assertEqual(first, second));

	// Built-in arrays: equal contents pass; a differing element or a
	// differing length fails.
	assertNotThrown!AssertError(assertEqual([0,1,2,3,4], [0,1,2,3,4]));
	assertThrown!AssertError(assertEqual([0,2,3,4], [0,1,2,3,4]));
	assertThrown!AssertError(assertEqual([0,2,3,4], [0,1,2]));
	assertThrown!AssertError(assertEqual([0,1,2,3,5], [0,1,2,3,4]));
	assertThrown!AssertError(assertEqual([0,1,2,3], [0,1,2,3,4]));

	// A library range compared against a built-in array.
	auto container = Array!int([0,1,2,3,4]);
	assertNotThrown!AssertError(assertEqual(container[], [0,1,2,3,4]));
	assertThrown!AssertError(assertEqual(container[], [0,1,2,3]));
	assertThrown!AssertError(assertEqual(container[], [0,1,2,3,4,5]));
}
299 
/** Evaluates `exp` and returns its result.

If evaluating `exp` throws an `Exception`, a new `ET` is thrown in its place.
The message of the new `ET` is made of `args`, each formatted with `%s` and
separated by spaces. The original exception is nested in it as `next`.
`line` and `file` default to the location of the call.
*/
auto expect(ET = Exception, F, int line = __LINE__, string file = __FILE__, Args...)
		(lazy F exp, lazy Args args)
{
	try
		return exp();
	catch(Exception cause)
		throw new ET(joinElem(args), file, line, cause);
}
314 
/* Formats every element of `args` with `%s` and joins the results with
single spaces. There is no leading or trailing separator, and an empty
`args` yields an empty string. Each lazy argument is evaluated exactly once.
*/
private string joinElem(Args...)(lazy Args args) {
	import std.array : appender;
	import std.format : formattedWrite;

	auto app = appender!string();
	foreach(idx, arg; args) {
		// Put the separator *between* elements so messages do not end in
		// a stray space.
		if(idx > 0) {
			app.put(' ');
		}
		formattedWrite(app, "%s", arg);
	}
	return app.data;
}
325 
unittest {
	import std.string : indexOf;
	import std.exception : assertThrown;

	string outerMsg = "Fun will thrown, I'm sure";
	string innerMsg = "Hopefully this is true";

	void inner() {
		throw new Exception(innerMsg);
	}

	// expect must wrap inner's exception in a new one carrying outerMsg.
	void outer() {
		expect(inner(), outerMsg);
	}

	bool alwaysThrows() {
		throw new Exception("e");
	}

	Exception caught = null;
	try {
		outer();
	} catch(Exception e) {
		caught = e;
	}

	assert(caught !is null);
	assert(caught.msg.indexOf(outerMsg) != -1, "\"" ~ caught.msg ~ "\" " ~ outerMsg);
	assert(caught.next !is null);
	assert(caught.next.msg.indexOf(innerMsg) != -1, caught.next.msg);

	// An exception thrown while evaluating an argument propagates as is.
	assertThrown(assertEqual(alwaysThrows(), true));
}
360 
/** Evaluates `exp` and returns its result if that result is truthy.

Throws `ET` in two cases. `ET` defaults to `ExceptionType`, which is
`AssertError` when compiled with `-unittest` and `Exception` otherwise.
$(UL
	$(LI evaluating `exp` throws an `Exception`; that exception is chained as
		`next` of the thrown `ET`)
	$(LI the result converts to `false`; the message is "Ensure failed"
		followed by `args`, formatted and space separated)
)
`line` and `file` default to the call site. `ET` must be constructible as
`new ET(msg, file, line)` and `new ET(msg, file, line, next)`.
*/
auto ref ensure(ET = ExceptionType, E, int line = __LINE__,
		string file = __FILE__, Args...)(lazy E exp, Args args)
{
	E rslt;

	try {
		rslt = exp();
	} catch(Exception e) {
		// Honour the caller-selected error type (previously ignored).
		throw new ET(
			"Exception thrown while calling \"ensure\"", file, line, e
		);
	}

	if(!rslt) {
		throw new ET(joinElem("Ensure failed", args), file, line);
	}
	return rslt;
}
381 
///
unittest {
	// ensure reports failures as AssertError here, both when the expression
	// throws and when it yields false, and records the line of the call.
	// A true result is passed through unchanged.
	import core.exception : AssertError;
	//import std.exception : assertThrown, assertNotThrown;
	bool func() {
		throw new Exception("e");
	}

	// Each `__LINE__ - 1` check must stay directly below the call it verifies.
	auto e = assertThrown!AssertError(ensure(func()));
	assert(e.line == __LINE__ - 1);
	auto e2 = assertThrown!AssertError(ensure(false));
	assert(e2.line == __LINE__ - 1);
	bool b = assertNotThrown!AssertError(ensure(true));
	assert(b);
}
397 
/** Evaluates `t` and returns the `E` it throws.

If `t` completes without throwing an `E`, an `ExceptionType` is thrown
instead. That is an `AssertError` when compiled with `-unittest`, an
`Exception` otherwise, and it reports `file` and `line` (the call site by
default). Throwables that are not an `E` propagate unchanged.
*/
E assertThrown(E,T)(lazy T t, int line = __LINE__,
		string file = __FILE__)
{
	E caught = null;
	try {
		t();
	} catch(E thrown) {
		caught = thrown;
	}

	if(caught is null) {
		throw new ExceptionType("Exception of type " ~ E.stringof ~
				" was not thrown even though expected.", file, line
		);
	}
	return caught;
}
410 
/** Evaluates `t` and returns its result.

If evaluating `t` throws an `E`, an `ExceptionType` is thrown instead. That
is an `AssertError` when compiled with `-unittest`, an `Exception` otherwise.
The caught `E` is chained as its `next`, so the original cause is preserved.
`line` and `file` default to the call site. Throwables that are not an `E`
propagate unchanged.
*/
auto assertNotThrown(E,T)(lazy T t, int line = __LINE__,
		string file = __FILE__)
{
	try {
		return t();
	} catch(E e) {
		throw new ExceptionType("Exception of type " ~ E.stringof ~
				" caught unexpectedly", file, line, e
		);
	}
}
422 
///
unittest {
	import core.exception : AssertError;

	bool alwaysThrows() {
		throw new Exception("e");
	}

	bool neverThrows() {
		return true;
	}

	// assertThrown itself fails when the expression does not throw.
	assertThrown!(AssertError)(assertThrown!(AssertError)(neverThrows()));

	// assertNotThrown fails when the expression throws the named type.
	assertThrown!(AssertError)(assertNotThrown!(Exception)(alwaysThrows()));
}