1 module exceptionhandling;
2 
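/**
   Defining `version(exceptionhandling_release_asserts)` turns the comparison
   asserts (assertEqual, assertNotEqual, assertLess, assertGreater,
   assertLessEqual, assertGreaterEqual) into no-ops that return `toTest`
   unchanged.
*/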
8 
private {
	// Exception type thrown by the checks in this module: AssertError when the
	// module is compiled with -unittest, a plain Exception otherwise.
	version(unittest) {
		import core.exception : AssertError;
		alias ExceptionType = AssertError;
	} else {
		alias ExceptionType = Exception;
	}

	// NOTE(review): approxEqual is deprecated in recent Phobos in favour of
	// isClose, whose default tolerances are much tighter; switching would
	// change which floating point values compare as equal.
	import std.math : approxEqual;

	// Exact predicates: each returns `lhs OP rhs`.

	bool cmp(T)(T lhs, T rhs) {
		return lhs == rhs;
	}

	bool cmpNot(T)(T lhs, T rhs) {
		return !(lhs == rhs);
	}

	bool cmpLess(T)(T lhs, T rhs) {
		return lhs < rhs;
	}

	bool cmpGreater(T)(T lhs, T rhs) {
		return lhs > rhs;
	}

	bool cmpLessEqual(T)(T lhs, T rhs) {
		return lhs <= rhs;
	}

	bool cmpGreaterEqual(T)(T lhs, T rhs) {
		return lhs >= rhs;
	}

	// Floating point predicates: "equal" means approxEqual.

	bool cmpFloat(T)(T lhs, T rhs) {
		return approxEqual(lhs, rhs);
	}

	bool cmpFloatNot(T)(T lhs, T rhs) {
		return !cmpFloat(lhs, rhs);
	}

	bool cmpLessEqualFloat(T)(T lhs, T rhs) {
		return cmpLess(lhs, rhs) || cmpFloat(lhs, rhs);
	}

	bool cmpGreaterEqualFloat(T)(T lhs, T rhs) {
		return cmpGreater(lhs, rhs) || cmpFloat(lhs, rhs);
	}
}
58 
/** Assert that `toTest` relates to `toCompareAgainst` as the function name
says: `assertEqual` (`==`), `assertNotEqual` (`!=`), `assertLess` (`<`),
`assertGreater` (`>`), `assertLessEqual` (`<=`) and `assertGreaterEqual` (`>=`).

If `T` is a floating point type, `assertEqual` and `assertNotEqual` compare the
values with `approxEqual`, and `assertLessEqual`/`assertGreaterEqual` also
accept values that are approximately equal.

`toTest` is returned if the comparison succeeds, so checks can be chained.
If the comparison fails, or throws an `Exception` itself, a new exception
recording `file` and `line` is thrown. When the module is compiled with
`-unittest` this is an `AssertError`; otherwise it is an `Exception`.
If `version(exceptionhandling_release_asserts)` is defined, no comparison is
performed and `toTest` is returned unchanged.

Params:
	toTest = the value under test; `T` must be implicitly convertible to `S`
	toCompareAgainst = the reference value
	file = file reported in the thrown exception, defaults to the caller's
	line = line reported in the thrown exception, defaults to the caller's

Returns: `toTest`
*/
auto ref T assertEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	import std.traits : isFloatingPoint, isImplicitlyConvertible;
	static assert(isImplicitlyConvertible!(T,S),
		"assertEqual: " ~ T.stringof ~ " is not implicitly convertible to "
		~ S.stringof
	);

	// Floating point values are compared approximately, everything else
	// with `==`.
	static if(isFloatingPoint!T) {
		return AssertImpl!(T,S, cmpFloat, "==")(toTest, toCompareAgainst,
				file, line
		);
	} else {
		return AssertImpl!(T,S, cmp, "==")(toTest, toCompareAgainst, file, line);
	}
}
79 
/// ditto
auto ref T assertNotEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	import std.traits : isFloatingPoint, isImplicitlyConvertible;
	static assert(isImplicitlyConvertible!(T,S));

	// Pick the predicate at compile time: approximate inequality for
	// floating point values, plain `!=` for everything else.
	static if(isFloatingPoint!T) {
		alias differs = cmpFloatNot;
	} else {
		alias differs = cmpNot;
	}

	return AssertImpl!(T, S, differs, "!=")(toTest, toCompareAgainst,
			file, line);
}
96 
/// ditto
auto ref T assertLess(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	// Strict "<" has no approximate floating point variant, so only the
	// convertibility check is needed before delegating.
	import std.traits : isImplicitlyConvertible;
	static assert(isImplicitlyConvertible!(T,S));

	alias lessThan = cmpLess;
	return AssertImpl!(T, S, lessThan, "<")(toTest, toCompareAgainst,
			file, line);
}
107 
/// ditto
auto ref T assertGreater(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	// Strict ">" has no approximate floating point variant, so only the
	// convertibility check is needed before delegating.
	import std.traits : isImplicitlyConvertible;
	static assert(isImplicitlyConvertible!(T,S));

	alias greaterThan = cmpGreater;
	return AssertImpl!(T, S, greaterThan, ">")(toTest, toCompareAgainst,
			file, line);
}
118 
/// ditto
auto ref T assertGreaterEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	import std.traits : isFloatingPoint, isImplicitlyConvertible;
	static assert(isImplicitlyConvertible!(T,S));

	// Floating point values also pass when approximately equal.
	static if(isFloatingPoint!T) {
		alias atLeast = cmpGreaterEqualFloat;
	} else {
		alias atLeast = cmpGreaterEqual;
	}

	return AssertImpl!(T, S, atLeast, ">=")(toTest, toCompareAgainst,
			file, line);
}
136 
/// ditto
auto ref T assertLessEqual(T,S)(auto ref T toTest, auto ref S toCompareAgainst,
		const string file = __FILE__, const int line = __LINE__)
{
	import std.traits : isFloatingPoint, isImplicitlyConvertible;
	static assert(isImplicitlyConvertible!(T,S));

	// Floating point values also pass when approximately equal.
	static if(isFloatingPoint!T) {
		alias atMost = cmpLessEqualFloat;
	} else {
		alias atMost = cmpLessEqual;
	}

	return AssertImpl!(T, S, atMost, "<=")(toTest, toCompareAgainst,
			file, line);
}
154 
/** Shared implementation of the assertXXX functions.

Evaluates `Cmp(toTest, toCompareAgainst)`; `cmpMsg` is the operator as it
appears in the error messages. Returns `toTest` when `Cmp` yields true.
Otherwise throws an `ExceptionType` carrying `file` and `line`; if `Cmp` itself
throws an `Exception`, that exception is chained as the new one's `next`.
With `version(exceptionhandling_release_asserts)` nothing is checked and
`toTest` is returned directly.
*/
private auto ref T AssertImpl(T,S,alias Cmp, string cmpMsg)(auto ref T toTest,
		auto ref S toCompareAgainst, const string file, const int line)
{
	version(exceptionhandling_release_asserts) {
		return toTest;
	} else {
		import std.format : format;

		bool holds;
		try {
			holds = Cmp(toTest, toCompareAgainst);
		} catch(Exception e) {
			// Report both operands together with the exception raised while
			// comparing them.
			throw new ExceptionType(
				format("Exception thrown while \"toTest(%s) " ~ cmpMsg
					~ " toCompareAgainst(%s)\"", toTest, toCompareAgainst),
				file, line, e
			);
		}

		if(holds) {
			return toTest;
		}

		throw new ExceptionType(
			format("toTest(%s) " ~ cmpMsg ~ " toCompareAgainst(%s) failed",
				toTest, toCompareAgainst),
			file, line
		);
	}
}
183 
unittest {
	import core.exception : AssertError;
	import std.exception : assertThrown;
	import std.meta : AliasSeq;

	// Run the same checks for integral and floating point types; the foreach
	// over an AliasSeq is expanded at compile time, once per type.
	foreach(T; AliasSeq!(byte,int,float,double)) {
		T zero = 0;
		T one = 1;
		T two = 2;

		// Every assertXXX returns toTest, so checks can be chained with UFCS.
		T ret = assertEqual(one, one).assertGreater(zero).assertLess(two);
		cast(void)assertEqual(ret, one);
		ret = assertNotEqual(one, zero).assertGreater(zero).assertLess(two);
		cast(void)assertEqual(ret, one);
		ret = assertLessEqual(one, two)
			.assertGreaterEqual(zero)
			.assertEqual(one);
		cast(void)assertEqual(ret, one);

		// Mixing const and mutable operands (in either position) must compile.
		cast(void)assertEqual(cast(const(T))one, one);
		cast(void)assertNotEqual(cast(const(T))one, zero);
		cast(void)assertEqual(one, cast(const(T))one);
		cast(void)assertNotEqual(one, cast(const(T))zero);

		// In unittest builds a failing comparison throws AssertError.
		assertThrown!AssertError(assertEqual(one, zero));
	}

	// Literal (rvalue) arguments are accepted as well.
	cast(void)assertEqual(1, 1);
	cast(void)assertNotEqual(1, 0);
}
214 
unittest {
	import core.exception : AssertError;
	import std.exception : assertThrown;

	// A class whose opEquals always throws. It checks that assertEqual turns
	// an exception raised during the comparison into an AssertError.
	class Throwing {
		int val;

		this(int val) {
			this.val = val;
		}

		override bool opEquals(Object other) {
			throw new Exception("Another test");
		}
	}

	auto lhs = new Throwing(1);
	auto rhs = new Throwing(1);

	// Comparing against null presumably never reaches Throwing.opEquals
	// (druntime short-circuits null operands), so this is a plain failure.
	assertThrown!AssertError(assertEqual(lhs, cast(Throwing)null));
	// Here opEquals runs and throws inside the comparison.
	assertThrown!AssertError(assertEqual(lhs, rhs));
}
233 
/** Evaluate `exp` and return its result.

If evaluating `exp` throws an `Exception`, that exception is caught and a new
`ET` is thrown in its place. The new exception's message is made of `args`,
space separated, and the caught exception is chained as its `next`. The new
exception reports the `file` and `line` of the instantiating call site.

`args` are lazy: they are only evaluated when `exp` throws.
*/
auto expect(ET = Exception, F, int line = __LINE__, string file = __FILE__, Args...)
		(lazy F exp, lazy Args args)
{
	try {
		return exp();
	} catch(Exception cause) {
		immutable string msg = joinElem(args);
		throw new ET(msg, file, line, cause);
	}
}
248 
/* Formats each of `args` with "%s" and joins the results with single spaces.
Returns "" when `args` is empty; no trailing separator is emitted. Each lazy
argument is evaluated exactly once. */
private string joinElem(Args...)(lazy Args args) {
	import std.array : appender;
	import std.format : formattedWrite;

	auto app = appender!string();
	foreach(idx, arg; args) {
		// Separator goes before every element except the first, so the
		// result never ends in a stray space.
		if(idx > 0) {
			app.put(' ');
		}
		formattedWrite(app, "%s", arg);
	}
	return app.data;
}
259 
unittest {
	import std.string : indexOf;
	import std.exception : assertThrown;

	string barMsg = "Fun will thrown, I'm sure";
	string funMsg = "Hopefully this is true";

	void fun() {
		throw new Exception(funMsg);
	}

	// expect must wrap fun's exception in a new one carrying barMsg.
	void bar() {
		expect(fun(), barMsg);
	}

	bool func() {
		throw new Exception("e");
	}

	Exception caught = null;
	try {
		bar();
	} catch(Exception e) {
		caught = e;
	}

	// bar must have thrown, with barMsg in the message and fun's exception
	// chained behind it.
	assert(caught !is null);
	assert(caught.msg.indexOf(barMsg) != -1, "\"" ~ caught.msg ~ "\" " ~ barMsg);
	assert(caught.next !is null);
	assert(caught.next.msg.indexOf(funMsg) != -1, caught.next.msg);

	assertThrown(assertEqual(func(), true));
}
294 
/** Evaluate `exp` and check that its result is truthy (`!rslt` is false).

If evaluating `exp` throws an `Exception`, a new `ET` is thrown with the
caught exception chained as its `next`. If the result is falsy (e.g. `false`,
`null`, `0`), an `ET` is thrown whose message is "Ensure failed" followed by
`args`, space separated. Otherwise the result is returned.

Params:
	ET = exception type thrown on failure; defaults to this module's
		`ExceptionType`
	line = line reported in the thrown exception, defaults to the call site
	file = file reported in the thrown exception, defaults to the call site
	exp = the expression to evaluate
	args = extra values appended to the failure message

Returns: the value of `exp`
*/
auto ref ensure(ET = ExceptionType, E, int line = __LINE__,
		string file = __FILE__, Args...)(lazy E exp, Args args)
{
	typeof(exp) rslt;

	try {
		rslt = exp();
	} catch(Exception e) {
		// Honour the caller-selected exception type.
		throw new ET(
			"Exception thrown while calling \"ensure\"", file, line, e
		);
	}

	if(!rslt) {
		throw new ET(joinElem("Ensure failed", args), file, line);
	}
	return rslt;
}
314 
unittest {
	import core.exception : AssertError;
	// With the std.exception import left commented out, the calls below
	// resolve to this module's assertThrown (which returns the caught
	// exception) and assertNotThrown (which returns the evaluated value).
	//import std.exception : assertThrown, assertNotThrown;
	bool func() {
		throw new Exception("e");
	}

	// ensure records the line of its call site. Each check below compares
	// against the line directly above it, so keep each pair adjacent.
	auto e = assertThrown!AssertError(ensure(func()));
	assert(e.line == __LINE__ - 1);
	auto e2 = assertThrown!AssertError(ensure(false));
	assert(e2.line == __LINE__ - 1);
	bool b = assertNotThrown!AssertError(ensure(true));
	assert(b);
}
329 
/** Evaluate `t` and return the `E` it throws.

If `t` completes without throwing an `E`, an `ExceptionType` recording `file`
and `line` is thrown instead. Throwables that are not an `E` propagate
unchanged. Note that `line` comes before `file` in the parameter list.

Returns: the caught exception
*/
E assertThrown(E,T)(lazy T t, int line = __LINE__,
		string file = __FILE__)
{
	try {
		t();
	} catch(E expected) {
		return expected;
	}

	enum string msg = "Exception of type " ~ E.stringof ~
		" was not thrown even though expected.";
	throw new ExceptionType(msg, file, line);
}
342 
/** Evaluate `t` and return its result, failing if it throws an `E`.

If `t` throws an `E`, an `ExceptionType` recording `file` and `line` is thrown
instead, with the caught `E` chained as its `next`. Throwables that are not an
`E` propagate unchanged. Note that `line` comes before `file` in the parameter
list.

Returns: the value of `t`
*/
auto assertNotThrown(E,T)(lazy T t, int line = __LINE__,
		string file = __FILE__)
{
	try {
		return t();
	} catch(E e) {
		// Chain the caught exception so its message and trace are not lost.
		throw new ExceptionType("Exception of type " ~ E.stringof ~
				" was thrown unexpectedly", file, line, e
		);
	}
}