Skip to content

Commit 6ce4a35

Browse files
committed
Support PEP 393 new Unicode APIs
1 parent db6f88a commit 6ce4a35

2 files changed

Lines changed: 230 additions & 2 deletions

File tree

markupsafe/_speedups.c

Lines changed: 175 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,25 @@
88
* :copyright: (c) 2010 by Armin Ronacher.
99
* :license: BSD.
1010
*/
11-
1211
#include <Python.h>
1312

13+
#if PY_MAJOR_VERSION < 3
1414
#define ESCAPED_CHARS_TABLE_SIZE 63
1515
#define UNICHR(x) (PyUnicode_AS_UNICODE((PyUnicodeObject*)PyUnicode_DecodeASCII(x, strlen(x), NULL)));
1616

1717

18-
static PyObject* markup;
1918
static Py_ssize_t escaped_chars_delta_len[ESCAPED_CHARS_TABLE_SIZE];
2019
static Py_UNICODE *escaped_chars_repl[ESCAPED_CHARS_TABLE_SIZE];
20+
#endif
21+
22+
static PyObject* markup;
2123

2224
static int
2325
init_constants(void)
2426
{
2527
PyObject *module;
28+
29+
#if PY_MAJOR_VERSION < 3
2630
/* mapping of characters to replace */
2731
escaped_chars_repl['"'] = UNICHR("&#34;");
2832
escaped_chars_repl['\''] = UNICHR("&#39;");
@@ -35,6 +39,7 @@ init_constants(void)
3539
escaped_chars_delta_len['"'] = escaped_chars_delta_len['\''] = \
3640
escaped_chars_delta_len['&'] = 4;
3741
escaped_chars_delta_len['<'] = escaped_chars_delta_len['>'] = 3;
42+
#endif
3843

3944
/* import markup type so that we can mark the return value */
4045
module = PyImport_ImportModule("markupsafe");
@@ -46,6 +51,7 @@ init_constants(void)
4651
return 1;
4752
}
4853

54+
#if PY_MAJOR_VERSION < 3
4955
static PyObject*
5056
escape_unicode(PyUnicodeObject *in)
5157
{
@@ -106,7 +112,174 @@ escape_unicode(PyUnicodeObject *in)
106112

107113
return (PyObject*)out;
108114
}
115+
#else /* PY_MAJOR_VERSION < 3 */
116+
117+
#define GET_DELTA(inp, inp_end, delta) \
118+
while (inp < inp_end) { \
119+
switch (*inp++) { \
120+
case '"': \
121+
case '\'': \
122+
case '&': \
123+
delta += 4; \
124+
break; \
125+
case '<': \
126+
case '>': \
127+
delta += 3; \
128+
break; \
129+
} \
130+
}
131+
132+
#define DO_ESCAPE(inp, inp_end, outp) \
133+
{ \
134+
Py_ssize_t ncopy = 0; \
135+
while (inp < inp_end) { \
136+
switch (*inp) { \
137+
case '"': \
138+
memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \
139+
outp += ncopy; ncopy = 0; \
140+
*outp++ = '&'; \
141+
*outp++ = '#'; \
142+
*outp++ = '3'; \
143+
*outp++ = '4'; \
144+
*outp++ = ';'; \
145+
break; \
146+
case '\'': \
147+
memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \
148+
outp += ncopy; ncopy = 0; \
149+
*outp++ = '&'; \
150+
*outp++ = '#'; \
151+
*outp++ = '3'; \
152+
*outp++ = '9'; \
153+
*outp++ = ';'; \
154+
break; \
155+
case '&': \
156+
memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \
157+
outp += ncopy; ncopy = 0; \
158+
*outp++ = '&'; \
159+
*outp++ = 'a'; \
160+
*outp++ = 'm'; \
161+
*outp++ = 'p'; \
162+
*outp++ = ';'; \
163+
break; \
164+
case '<': \
165+
memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \
166+
outp += ncopy; ncopy = 0; \
167+
*outp++ = '&'; \
168+
*outp++ = 'l'; \
169+
*outp++ = 't'; \
170+
*outp++ = ';'; \
171+
break; \
172+
case '>': \
173+
memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \
174+
outp += ncopy; ncopy = 0; \
175+
*outp++ = '&'; \
176+
*outp++ = 'g'; \
177+
*outp++ = 't'; \
178+
*outp++ = ';'; \
179+
break; \
180+
default: \
181+
ncopy++; \
182+
} \
183+
inp++; \
184+
} \
185+
memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \
186+
}
187+
188+
static PyObject*
189+
escape_unicode_kind1(PyUnicodeObject *in)
190+
{
191+
Py_UCS1 *inp = PyUnicode_1BYTE_DATA(in);
192+
Py_UCS1 *inp_end = inp + PyUnicode_GET_LENGTH(in);
193+
Py_UCS1 *outp;
194+
PyObject *out;
195+
Py_ssize_t delta = 0;
196+
197+
GET_DELTA(inp, inp_end, delta);
198+
if (!delta) {
199+
Py_INCREF(in);
200+
return (PyObject*)in;
201+
}
202+
203+
out = PyUnicode_New(PyUnicode_GET_LENGTH(in) + delta,
204+
PyUnicode_IS_ASCII(in) ? 127 : 255);
205+
if (!out)
206+
return NULL;
207+
208+
inp = PyUnicode_1BYTE_DATA(in);
209+
outp = PyUnicode_1BYTE_DATA(out);
210+
DO_ESCAPE(inp, inp_end, outp);
211+
return out;
212+
}
213+
214+
static PyObject*
215+
escape_unicode_kind2(PyUnicodeObject *in)
216+
{
217+
Py_UCS2 *inp = PyUnicode_2BYTE_DATA(in);
218+
Py_UCS2 *inp_end = inp + PyUnicode_GET_LENGTH(in);
219+
Py_UCS2 *outp;
220+
PyObject *out;
221+
Py_ssize_t delta = 0;
222+
223+
GET_DELTA(inp, inp_end, delta);
224+
if (!delta) {
225+
Py_INCREF(in);
226+
return (PyObject*)in;
227+
}
228+
229+
out = PyUnicode_New(PyUnicode_GET_LENGTH(in) + delta, 65535);
230+
if (!out)
231+
return NULL;
232+
233+
inp = PyUnicode_2BYTE_DATA(in);
234+
outp = PyUnicode_2BYTE_DATA(out);
235+
DO_ESCAPE(inp, inp_end, outp);
236+
return out;
237+
}
238+
109239

240+
static PyObject*
241+
escape_unicode_kind4(PyUnicodeObject *in)
242+
{
243+
Py_UCS4 *inp = PyUnicode_4BYTE_DATA(in);
244+
Py_UCS4 *inp_end = inp + PyUnicode_GET_LENGTH(in);
245+
Py_UCS4 *outp;
246+
PyObject *out;
247+
Py_ssize_t delta = 0;
248+
249+
GET_DELTA(inp, inp_end, delta);
250+
if (!delta) {
251+
Py_INCREF(in);
252+
return (PyObject*)in;
253+
}
254+
255+
out = PyUnicode_New(PyUnicode_GET_LENGTH(in) + delta, 1114111);
256+
if (!out)
257+
return NULL;
258+
259+
inp = PyUnicode_4BYTE_DATA(in);
260+
outp = PyUnicode_4BYTE_DATA(out);
261+
DO_ESCAPE(inp, inp_end, outp);
262+
return out;
263+
}
264+
265+
static PyObject*
266+
escape_unicode(PyUnicodeObject *in)
267+
{
268+
if (PyUnicode_READY(in))
269+
return NULL;
270+
271+
switch (PyUnicode_KIND(in)) {
272+
case PyUnicode_1BYTE_KIND:
273+
return escape_unicode_kind1(in);
274+
case PyUnicode_2BYTE_KIND:
275+
return escape_unicode_kind2(in);
276+
case PyUnicode_4BYTE_KIND:
277+
return escape_unicode_kind4(in);
278+
}
279+
assert(0); /* shouldn't happen */
280+
return NULL;
281+
}
282+
#endif /* PY_MAJOR_VERSION < 3 */
110283

111284
static PyObject*
112285
escape(PyObject *self, PyObject *text)

tests.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
import unittest
66
from markupsafe import Markup, escape, escape_silent
77
from markupsafe._compat import text_type, PY2
8+
from markupsafe import _native
9+
try:
10+
from markupsafe import _speedups
11+
have_speedups = True
12+
except ImportError:
13+
have_speedups = False
814

915

1016
class MarkupTestCase(unittest.TestCase):
@@ -199,6 +205,51 @@ def test_markup_leaks(self):
199205
'leak objects, got: ' + str(len(counts))
200206

201207

208+
class NativeEscapeTestCase(unittest.TestCase):
209+
210+
escape = staticmethod(_native.escape)
211+
212+
def test_empty(self):
213+
self.assertEqual(Markup(u''), self.escape(u''))
214+
215+
def test_ascii(self):
216+
self.assertEqual(
217+
Markup(u'abcd&amp;&gt;&lt;&#39;&#34;efgh'),
218+
self.escape(u'abcd&><\'"efgh'))
219+
self.assertEqual(
220+
Markup(u'&amp;&gt;&lt;&#39;&#34;efgh'),
221+
self.escape(u'&><\'"efgh'))
222+
self.assertEqual(
223+
Markup(u'abcd&amp;&gt;&lt;&#39;&#34;'),
224+
self.escape(u'abcd&><\'"'))
225+
226+
def test_2byte(self):
227+
self.assertEqual(
228+
Markup(u'こんにちは&amp;&gt;&lt;&#39;&#34;こんばんは'),
229+
self.escape(u'こんにちは&><\'"こんばんは'))
230+
self.assertEqual(
231+
Markup(u'&amp;&gt;&lt;&#39;&#34;こんばんは'),
232+
self.escape(u'&><\'"こんばんは'))
233+
self.assertEqual(
234+
Markup(u'こんにちは&amp;&gt;&lt;&#39;&#34;'),
235+
self.escape(u'こんにちは&><\'"'))
236+
237+
def test_4byte(self):
238+
self.assertEqual(
239+
Markup(u'\U0001F363\U0001F362&amp;&gt;&lt;&#39;&#34;\U0001F37A xyz'),
240+
self.escape(u'\U0001F363\U0001F362&><\'"\U0001F37A xyz'))
241+
self.assertEqual(
242+
Markup(u'&amp;&gt;&lt;&#39;&#34;\U0001F37A xyz'),
243+
self.escape(u'&><\'"\U0001F37A xyz'))
244+
self.assertEqual(
245+
Markup(u'\U0001F363\U0001F362&amp;&gt;&lt;&#39;&#34;'),
246+
self.escape(u'\U0001F363\U0001F362&><\'"'))
247+
248+
if have_speedups:
249+
class SpeedupEscapeTestCase(NativeEscapeTestCase):
250+
escape = _speedups.escape
251+
252+
202253
def suite():
203254
suite = unittest.TestSuite()
204255
suite.addTest(unittest.makeSuite(MarkupTestCase))
@@ -207,6 +258,10 @@ def suite():
207258
if not hasattr(escape, 'func_code'):
208259
suite.addTest(unittest.makeSuite(MarkupLeakTestCase))
209260

261+
suite.addTest(unittest.makeSuite(NativeEscapeTestCase))
262+
if have_speedups:
263+
suite.addTest(unittest.makeSuite(SpeedupEscapeTestCase))
264+
210265
return suite
211266

212267

0 commit comments

Comments
 (0)