Skip to content

Commit 460203c

Browse files
authored
Fix memory usage (#751)
* clear cache of flow.tree * categorical time period dtype * add pydantic for tests * use time_label_dtype only when available * allow missing taz skim_dict * recover tree when needed * predigitized time periods * pass sh_tree back again for tracing * better error message
1 parent 180dcca commit 460203c

12 files changed

Lines changed: 147 additions & 53 deletions

File tree

activitysim/abm/models/parking_location_choice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def parking_location(
318318
if "trip_period" not in trips_merged_df:
319319
# TODO: resolve this to the skim time period index not the label, it will be faster
320320
trips_merged_df["trip_period"] = network_los.skim_time_period_label(
321-
trips_merged_df[proposed_trip_departure_period]
321+
trips_merged_df[proposed_trip_departure_period], as_cat=True
322322
)
323323
model_settings["TRIP_DEPARTURE_PERIOD"] = "trip_period"
324324

activitysim/abm/models/trip_mode_choice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def trip_mode_choice(
7373
# setup skim keys
7474
assert "trip_period" not in trips_merged
7575
trips_merged["trip_period"] = network_los.skim_time_period_label(
76-
trips_merged.depart
76+
trips_merged.depart, as_cat=True
7777
)
7878

7979
orig_col = "origin"

activitysim/abm/models/util/logsums.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,10 @@ def compute_logsums(
7575
# FIXME - are we ok with altering choosers (so caller doesn't have to set these)?
7676
if (in_period_col is not None) and (out_period_col is not None):
7777
choosers["in_period"] = network_los.skim_time_period_label(
78-
choosers[in_period_col]
78+
choosers[in_period_col], as_cat=True
7979
)
8080
choosers["out_period"] = network_los.skim_time_period_label(
81-
choosers[out_period_col]
81+
choosers[out_period_col], as_cat=True
8282
)
8383
elif ("in_period" not in choosers.columns) and (
8484
"out_period" not in choosers.columns
@@ -92,17 +92,21 @@ def compute_logsums(
9292
and tour_purpose in model_settings["OUT_PERIOD"]
9393
):
9494
choosers["in_period"] = network_los.skim_time_period_label(
95-
model_settings["IN_PERIOD"][tour_purpose]
95+
model_settings["IN_PERIOD"][tour_purpose],
96+
as_cat=True,
97+
broadcast_to=choosers.index,
9698
)
9799
choosers["out_period"] = network_los.skim_time_period_label(
98-
model_settings["OUT_PERIOD"][tour_purpose]
100+
model_settings["OUT_PERIOD"][tour_purpose],
101+
as_cat=True,
102+
broadcast_to=choosers.index,
99103
)
100104
else:
101105
choosers["in_period"] = network_los.skim_time_period_label(
102-
model_settings["IN_PERIOD"]
106+
model_settings["IN_PERIOD"], as_cat=True, broadcast_to=choosers.index
103107
)
104108
choosers["out_period"] = network_los.skim_time_period_label(
105-
model_settings["OUT_PERIOD"]
109+
model_settings["OUT_PERIOD"], as_cat=True, broadcast_to=choosers.index
106110
)
107111
else:
108112
logger.error("Choosers table already has columns 'in_period' and 'out_period'.")

activitysim/abm/models/util/mode.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,12 @@ def run_tour_mode_choice_simulate(
131131
assert ("in_period" not in choosers) and ("out_period" not in choosers)
132132
in_time = skims["in_time_col_name"]
133133
out_time = skims["out_time_col_name"]
134-
choosers["in_period"] = network_los.skim_time_period_label(choosers[in_time])
135-
choosers["out_period"] = network_los.skim_time_period_label(choosers[out_time])
134+
choosers["in_period"] = network_los.skim_time_period_label(
135+
choosers[in_time], as_cat=True
136+
)
137+
choosers["out_period"] = network_los.skim_time_period_label(
138+
choosers[out_time], as_cat=True
139+
)
136140

137141
expressions.annotate_preprocessors(
138142
state, choosers, locals_dict, skims, model_settings, trace_label

activitysim/abm/models/util/vectorize_tour_scheduling.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,12 @@ def dedupe_alt_tdd(state: workflow.State, alt_tdd, tour_purpose, trace_label):
185185

186186
logger.info("tdd_alt_segments specified for representative logsums")
187187

188+
if tdd_segments is not None:
189+
# apply categorical dtypes
190+
tdd_segments["time_period"] = tdd_segments["time_period"].astype(
191+
alt_tdd["out_period"].dtype
192+
)
193+
188194
with chunk.chunk_log(
189195
state, tracing.extend_trace_label(trace_label, "dedupe_alt_tdd")
190196
) as chunk_sizer:
@@ -328,11 +334,12 @@ def compute_tour_scheduling_logsums(
328334
assert "out_period" not in alt_tdd
329335
assert "in_period" not in alt_tdd
330336

331-
# FIXME:MEMORY
332-
# These two lines each generate a massive array of strings,
333-
# using a bunch of RAM and slowing things down.
334-
alt_tdd["out_period"] = network_los.skim_time_period_label(alt_tdd["start"])
335-
alt_tdd["in_period"] = network_los.skim_time_period_label(alt_tdd["end"])
337+
alt_tdd["out_period"] = network_los.skim_time_period_label(
338+
alt_tdd["start"], as_cat=True
339+
)
340+
alt_tdd["in_period"] = network_los.skim_time_period_label(
341+
alt_tdd["end"], as_cat=True
342+
)
336343

337344
alt_tdd["duration"] = alt_tdd["end"] - alt_tdd["start"]
338345

@@ -383,17 +390,28 @@ def compute_tour_scheduling_logsums(
383390

384391
# tracing.log_runtime(model_name=trace_label, start_time=t0)
385392

386-
# redupe - join the alt_tdd_period logsums to alt_tdd to get logsums for alt_tdd
387-
logsums = (
388-
pd.merge(
389-
alt_tdd.reset_index(),
390-
deduped_alt_tdds.reset_index(),
391-
on=[index_name] + redupe_columns,
392-
how="left",
393-
)
394-
.set_index(index_name)
395-
.logsums
396-
)
393+
logsums = pd.Series(data=0, index=alt_tdd.index, dtype=np.float64)
394+
left_on = [alt_tdd.index]
395+
right_on = [deduped_alt_tdds.index]
396+
for i in redupe_columns:
397+
if (
398+
alt_tdd[i].dtype == "category"
399+
and alt_tdd[i].dtype.ordered
400+
and alt_tdd[i].dtype == deduped_alt_tdds[i].dtype
401+
):
402+
left_on += [alt_tdd[i].cat.codes]
403+
right_on += [deduped_alt_tdds[i].cat.codes]
404+
else:
405+
left_on += [alt_tdd[i].to_numpy()]
406+
right_on += [deduped_alt_tdds[i].to_numpy()]
407+
408+
logsums.iloc[:] = pd.merge(
409+
pd.DataFrame(index=alt_tdd.index),
410+
deduped_alt_tdds.logsums,
411+
left_on=left_on,
412+
right_on=right_on,
413+
how="left",
414+
).logsums.to_numpy()
397415
chunk_sizer.log_df(trace_label, "logsums", logsums)
398416

399417
del deduped_alt_tdds

activitysim/abm/tables/landuse.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,16 @@ def land_use(state: workflow.State):
2323

2424
sharrow_enabled = state.settings.sharrow
2525
if sharrow_enabled:
26+
err_msg = (
27+
"a zero-based land_use index is required for sharrow,\n"
28+
"try adding `recode_pipeline_columns: true` to your settings file."
29+
)
2630
# when using sharrow, the land use file must be organized (either in raw
2731
# form or via recoding) so that the index is zero-based and contiguous
28-
assert df.index.is_monotonic_increasing
29-
assert df.index[0] == 0
30-
assert df.index[-1] == len(df.index) - 1
31-
assert df.index.dtype.kind == "i"
32+
assert df.index.is_monotonic_increasing, err_msg
33+
assert df.index[0] == 0, err_msg
34+
assert df.index[-1] == len(df.index) - 1, err_msg
35+
assert df.index.dtype.kind == "i", err_msg
3236

3337
# try to make life easy for everybody by keeping everything in canonical order
3438
# but as long as coalesce_pipeline doesn't sort tables it coalesces, it might not stay in order

activitysim/core/flow.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def skims_mapping(
267267
parking_col_name=None,
268268
zone_layer=None,
269269
primary_origin_col_name=None,
270+
predigitized_time_periods=False,
270271
):
271272
logger.info("loading skims_mapping")
272273
logger.info(f"- orig_col_name: {orig_col_name}")
@@ -337,6 +338,10 @@ def skims_mapping(
337338
),
338339
)
339340
else:
341+
if predigitized_time_periods:
342+
time_rel = "_code ->"
343+
else:
344+
time_rel = " @"
340345
return dict(
341346
# TODO:SHARROW: organize dimensions.
342347
odt_skims=skim_dataset,
@@ -347,16 +352,16 @@ def skims_mapping(
347352
relationships=(
348353
f"df._orig_col_name -> odt_skims.{odim}",
349354
f"df._dest_col_name -> odt_skims.{ddim}",
350-
"df.out_period @ odt_skims.time_period",
355+
f"df.out_period{time_rel} odt_skims.time_period",
351356
f"df._dest_col_name -> dot_skims.{odim}",
352357
f"df._orig_col_name -> dot_skims.{ddim}",
353-
"df.in_period @ dot_skims.time_period",
358+
f"df.in_period{time_rel} dot_skims.time_period",
354359
f"df._orig_col_name -> odr_skims.{odim}",
355360
f"df._dest_col_name -> odr_skims.{ddim}",
356-
"df.in_period @ odr_skims.time_period",
361+
f"df.in_period{time_rel} odr_skims.time_period",
357362
f"df._dest_col_name -> dor_skims.{odim}",
358363
f"df._orig_col_name -> dor_skims.{ddim}",
359-
"df.out_period @ dor_skims.time_period",
364+
f"df.out_period{time_rel} dor_skims.time_period",
360365
f"df._orig_col_name -> od_skims.{odim}",
361366
f"df._dest_col_name -> od_skims.{ddim}",
362367
),
@@ -525,6 +530,15 @@ def new_flow(
525530

526531
cache_dir = state.filesystem.get_sharrow_cache_dir()
527532
logger.debug(f"flow.cache_dir: {cache_dir}")
533+
predigitized_time_periods = False
534+
if "out_period" in choosers and "in_period" in choosers:
535+
if (
536+
choosers["out_period"].dtype == "category"
537+
and choosers["in_period"].dtype == "category"
538+
):
539+
choosers["out_period_code"] = choosers["out_period"].cat.codes
540+
choosers["in_period_code"] = choosers["in_period"].cat.codes
541+
predigitized_time_periods = True
528542
skims_mapping_ = skims_mapping(
529543
state,
530544
orig_col_name,
@@ -534,6 +548,7 @@ def new_flow(
534548
parking_col_name=parking_col_name,
535549
zone_layer=zone_layer,
536550
primary_origin_col_name=primary_origin_col_name,
551+
predigitized_time_periods=predigitized_time_periods,
537552
)
538553
if size_term_mapping is None:
539554
size_term_mapping = {}
@@ -774,6 +789,9 @@ def apply_flow(
774789
it ever again, but having a reference to it available later can be useful
775790
in debugging and tracing. Flows are cached and reused anyway, so it is
776791
generally not important to delete this at any point to free resources.
792+
tree : sharrow.DataTree
793+
The tree data used to compute the flow result. It is separate from the
794+
flow to prevent it from being cached with the flow.
777795
"""
778796
if sh is None:
779797
return None, None
@@ -800,7 +818,7 @@ def apply_flow(
800818
logger.error(f"error in apply_flow: {err!s}")
801819
if required:
802820
raise
803-
return None, None
821+
return None, None, None
804822
else:
805823
raise
806824
with logtime(f"{flow.name}.load", trace_label or ""):
@@ -822,7 +840,9 @@ def apply_flow(
822840
logger.error(f"error in apply_flow: {err!s}")
823841
if required:
824842
raise
825-
return None, flow
843+
tree = flow.tree
844+
flow.tree = None
845+
return None, flow, tree
826846
raise
827847
except Exception as err:
828848
logger.error(f"error in apply_flow: {err!s}")
@@ -833,4 +853,6 @@ def apply_flow(
833853
# Detecting compilation activity when in production mode is a bug
834854
# that should be investigated.
835855
tracing.timing_notes.add(f"compiled:{flow.name}")
836-
return flow_result, flow
856+
tree = flow.tree
857+
flow.tree = None
858+
return flow_result, flow, tree

activitysim/core/interaction_simulate.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def replace_in_index_level(mi, level, *repls):
171171

172172
timelogger.mark("sharrow preamble", True, logger, trace_label)
173173

174-
sh_util, sh_flow = apply_flow(
174+
sh_util, sh_flow, sh_tree = apply_flow(
175175
state,
176176
spec_sh,
177177
df,
@@ -187,10 +187,13 @@ def replace_in_index_level(mi, level, *repls):
187187
index=df.index if extra_data is None else None,
188188
)
189189
chunk_sizer.log_df(trace_label, "sh_util", None) # hand off to caller
190+
if sharrow_enabled != "test":
191+
# if not testing sharrow, we are done with this object now.
192+
del sh_util
190193

191194
timelogger.mark("sharrow flow", True, logger, trace_label)
192195
else:
193-
sh_util, sh_flow = None, None
196+
sh_util, sh_flow, sh_tree = None, None, None
194197
timelogger.mark("sharrow flow", False)
195198

196199
if (
@@ -404,9 +407,9 @@ def to_series(x):
404407
if sh_flow is not None and trace_rows is not None and trace_rows.any():
405408
assert type(trace_rows) == np.ndarray
406409
sh_utility_fat = sh_flow.load_dataarray(
407-
# sh_flow.tree.replace_datasets(
408-
# df=df.iloc[trace_rows],
409-
# ),
410+
sh_tree.replace_datasets(
411+
df=df.iloc[trace_rows],
412+
),
410413
dtype=np.float32,
411414
)
412415
sh_utility_fat = sh_utility_fat[trace_rows, :]

activitysim/core/los.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,9 @@ def get_tappairs3d(self, otap, dtap, dim3, key):
845845

846846
return s.values
847847

848-
def skim_time_period_label(self, time_period, fillna=None):
848+
def skim_time_period_label(
849+
self, time_period, fillna=None, as_cat=False, broadcast_to=None
850+
):
849851
"""
850852
convert time period times to skim time period labels (e.g. 9 -> 'AM')
851853
@@ -873,6 +875,14 @@ def skim_time_period_label(self, time_period, fillna=None):
873875
assert 0 == model_time_window_min % period_minutes
874876
total_periods = model_time_window_min / period_minutes
875877

878+
try:
879+
time_label_dtype = self.skim_dicts["taz"].time_label_dtype
880+
except (KeyError, AttributeError):
881+
# if the "taz" skim_dict is missing, or if using old SkimDict
882+
# instead of SkimDataset, this labeling shortcut is unavailable.
883+
time_label_dtype = str
884+
as_cat = False
885+
876886
# FIXME - eventually test and use np version always?
877887
if np.isscalar(time_period):
878888
bin = (
@@ -888,6 +898,12 @@ def skim_time_period_label(self, time_period, fillna=None):
888898
result = self.skim_time_periods["labels"].get(bin, default=default)
889899
else:
890900
result = self.skim_time_periods["labels"][bin]
901+
if broadcast_to is not None:
902+
result = pd.Series(
903+
data=result,
904+
index=broadcast_to,
905+
dtype=time_label_dtype if as_cat else str,
906+
)
891907
else:
892908
result = pd.cut(
893909
time_period,
@@ -898,8 +914,10 @@ def skim_time_period_label(self, time_period, fillna=None):
898914
if fillna is not None:
899915
default = self.skim_time_periods["labels"][fillna]
900916
result = result.fillna(default)
901-
result = result.astype(str)
902-
917+
if as_cat:
918+
result = result.astype(time_label_dtype)
919+
else:
920+
result = result.astype(str)
903921
return result
904922

905923
def get_tazs(self, state):

activitysim/core/simulate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ def eval_utilities(
536536
locals_dict.update(state.get_global_constants())
537537
if locals_d is not None:
538538
locals_dict.update(locals_d)
539-
sh_util, sh_flow = apply_flow(
539+
sh_util, sh_flow, sh_tree = apply_flow(
540540
state,
541541
spec_sh,
542542
choosers,
@@ -652,7 +652,7 @@ def eval_utilities(
652652
if sh_flow is not None:
653653
try:
654654
data_sh = sh_flow.load(
655-
sh_flow.tree.replace_datasets(
655+
sh_tree.replace_datasets(
656656
df=choosers.iloc[offsets],
657657
),
658658
dtype=np.float32,
@@ -731,7 +731,7 @@ def eval_utilities(
731731
)
732732
print(f"{sh_util.shape=}")
733733
print(misses)
734-
_sh_flow_load = sh_flow.load()
734+
_sh_flow_load = sh_flow.load(sh_tree)
735735
print("possible problematic expressions:")
736736
for expr_n, expr in enumerate(exprs):
737737
closeness = np.isclose(

0 commit comments

Comments
 (0)