Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions python/tvm/rpc/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@
logger.setLevel(logging.INFO)
logger.propagate = False

# Maximum size in bytes for a single tracker message. Tracker frames carry
# small JSON command tuples; 1 MiB is well above any legitimate payload and
# bounds memory growth when a peer sends an oversized or malformed size
# header on the wire.
MAX_TRACKER_MSG_BYTES = 1 << 20


class Scheduler:
"""Abstract interface of scheduler."""
Expand Down Expand Up @@ -224,14 +230,29 @@ def on_message(self, message):
if self._msg_size == 0:
if len(self._data) >= 4:
self._msg_size = struct.unpack("<i", self._data[:4])[0]
if self._msg_size <= 0 or self._msg_size > MAX_TRACKER_MSG_BYTES:
logger.warning(
"Invalid msg_size %d from %s; closing connection",
self._msg_size,
self.name(),
)
self.close()
return
del self._data[:4]
else:
return
if self._msg_size != 0 and len(self._data) >= self._msg_size + 4:
msg = py_str(bytes(self._data[4 : 4 + self._msg_size]))
del self._data[: 4 + self._msg_size]
if self._msg_size != 0 and len(self._data) >= self._msg_size:
msg = py_str(bytes(self._data[: self._msg_size]))
del self._data[: self._msg_size]
self._msg_size = 0
# pylint: disable=broad-except
self.call_handler(json.loads(msg))
try:
self.call_handler(json.loads(msg))
except Exception: # pylint: disable=broad-except
logger.warning(
"Error handling message from %s", self.name(), exc_info=True
)
self.close()
return
else:
return

Expand Down
44 changes: 44 additions & 0 deletions tests/python/contrib/test_rpc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,50 @@ def myfunc(remote):
print("Skip because tornado is not available")


def check_tracker_rejects_oversized_msg_size():
"""Tracker must reject an oversized msg_size header and close the connection
instead of buffering an unbounded amount of data on a single TCP connection.

Regression test for the unbounded buffer growth defect in
TCPEventHandler.on_message. See MAX_TRACKER_MSG_BYTES in tracker.py.
"""
try:
# pylint: disable=import-outside-toplevel
import socket
import struct

from tvm.rpc import base, tracker

tserver = tracker.Tracker(port=9180, port_end=9290, silent=True)
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(5)
sock.connect(("127.0.0.1", tserver.port))
# complete the 4-byte magic handshake
sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic_reply = sock.recv(4)
assert struct.unpack("<i", magic_reply)[0] == base.RPC_TRACKER_MAGIC

# send an oversized msg_size header (2 GiB)
sock.sendall(struct.pack("<i", 0x7FFFFFFF))

# server must close the connection (no payload buffering)
for _ in range(20):
chunk = sock.recv(4096)
if chunk == b"":
break
time.sleep(0.05)
else:
raise AssertionError(
"tracker did not close connection after oversized msg_size"
)
finally:
tserver.terminate()
except ImportError:
print("Skip because tornado is not available")


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
check_server_drop()
check_tracker_rejects_oversized_msg_size()