{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "2021-07-23-grocery-recommendation-using-graph-network.ipynb",
"provenance": [],
"toc_visible": true,
"mount_file_id": "1cPLzOTYZ-bjoAp9UJGm_Bb2hdhVvd1M1",
"authorship_tag": "ABX9TyNxp/EiF+OGe0Z67Wt5J1I1"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"a495a0cf7928479291950a2f7f530734": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_c3ce9ef5d8bd4478949f7d21b2df356d",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_356f6a0c0bad461babd889941b728924",
"IPY_MODEL_8ebfaa8911d24f1eb8ceb66666d12d99"
]
}
},
"c3ce9ef5d8bd4478949f7d21b2df356d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"356f6a0c0bad461babd889941b728924": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_5a57039da7644e16b5c097cb2c95e0dd",
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 3978,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 3978,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_974e635aab5e4ec5a8fc80bbfe7a4383"
}
},
"8ebfaa8911d24f1eb8ceb66666d12d99": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_1d48322154fd460d8c9b87ac897df9bf",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 3978/3978 [15:45<00:00, 4.21it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_5c4596f1ab044a6da4ec47e08ac60bc4"
}
},
"5a57039da7644e16b5c097cb2c95e0dd": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"974e635aab5e4ec5a8fc80bbfe7a4383": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"1d48322154fd460d8c9b87ac897df9bf": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"5c4596f1ab044a6da4ec47e08ac60bc4": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"e949cb02ed1a4971b04545f7436db8b8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_fc75c12c7e184a4fbec57a9569691dcf",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_23afc4c497154f02be04e93f85cd7842",
"IPY_MODEL_66b4e0f46b3b43c1b63b1e5038783979"
]
}
},
"fc75c12c7e184a4fbec57a9569691dcf": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"23afc4c497154f02be04e93f85cd7842": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_664304d4f06a457d9d713f3c625d1d83",
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 3222,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 3222,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_3e68e2ff48d84d4d9bc752ab023d85b6"
}
},
"66b4e0f46b3b43c1b63b1e5038783979": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_2a52ba2c014843698de8b1bd6e766d4b",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 3222/3222 [15:41<00:00, 3.42it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_1876c679da2143f7874894593742baaa"
}
},
"664304d4f06a457d9d713f3c625d1d83": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"3e68e2ff48d84d4d9bc752ab023d85b6": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"2a52ba2c014843698de8b1bd6e766d4b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"1876c679da2143f7874894593742baaa": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"2908997c4929430f918dd0c63608b106": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_a72e4ad4bfc347d3afee27b1f19ea1fa",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_53e606b87aa14214adb83d0a7e152ac0",
"IPY_MODEL_a5898e01abb242baad7ff68221106c26"
]
}
},
"a72e4ad4bfc347d3afee27b1f19ea1fa": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"53e606b87aa14214adb83d0a7e152ac0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_a254c644e16d4725bf91cc0a2288d1e1",
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 382,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 382,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_dee2b7a161564a90989ac4ce59e6e501"
}
},
"a5898e01abb242baad7ff68221106c26": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_bb26e1e4c01f48dc93f1a8efcf14f0e6",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 382/382 [01:11<00:00, 5.38it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_761feb5caa81411ca7b4b4f9cd45a194"
}
},
"a254c644e16d4725bf91cc0a2288d1e1": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"dee2b7a161564a90989ac4ce59e6e501": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"bb26e1e4c01f48dc93f1a8efcf14f0e6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"761feb5caa81411ca7b4b4f9cd45a194": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"86a1e7530696423e9ccba1dacc5c435d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_4a7ddd8095874012841f9e7f12447b0e",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_a073ac68471a431e98a14fed6f121636",
"IPY_MODEL_82b43f61286e4edb9fab53cf1c625de8"
]
}
},
"4a7ddd8095874012841f9e7f12447b0e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"a073ac68471a431e98a14fed6f121636": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_10afc6afd6d34cb8b282d25af6bfb261",
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 382,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 382,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_b1cff006a21c4323a0e4cead86772f00"
}
},
"82b43f61286e4edb9fab53cf1c625de8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_9bca0d35e0ef4b07bbc001d787c17a39",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 382/382 [00:37<00:00, 10.31it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_7a5d3adfe202447a8bc49a2eaa962008"
}
},
"10afc6afd6d34cb8b282d25af6bfb261": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"b1cff006a21c4323a0e4cead86772f00": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"9bca0d35e0ef4b07bbc001d787c17a39": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"7a5d3adfe202447a8bc49a2eaa962008": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "wK2-BSgmQ2Ai"
},
"source": [
"# Grocery Recommendation using Graph Network\n",
 Building">
"> Building a Word2Vec-based graph network on the Instacart dataset to find similar and neighbouring items, and exploring the results in a Dash app\n",
"\n",
"- toc: true\n",
"- badges: true\n",
"- comments: true\n",
"- categories: [Dash, App, NetworkX, Word2Vec, Graph, Retail, Visualization]\n",
"- image:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bTWqNCOePwQ_"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ucPNt39z81kg"
},
"source": [
"!pip install -q dash dash-renderer dash-html-components dash-core-components\n",
"!pip install -q jupyter-dash"
],
"execution_count": 86,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "fVjCEEO786RX"
},
"source": [
"import re\n",
"import random\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"import plotly.offline as py\n",
"import plotly.graph_objects as go\n",
"\n",
"import networkx as nx\n",
"from networkx.readwrite import json_graph\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import itertools\n",
"import collections\n",
"\n",
"from gensim.models import Word2Vec\n",
"\n",
"from tqdm.notebook import tqdm\n",
"\n",
"import pickle\n",
"\n",
"import dash\n",
"import dash_core_components as dcc\n",
"import dash_html_components as html\n",
"from jupyter_dash import JupyterDash\n",
"from dash.dependencies import Input, Output, State, ALL"
],
"execution_count": 87,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fl0gfJeX-euv",
"outputId": "7e21aa01-98ca-4010-f9c2-b2121f773e77"
},
"source": [
"!pip install -q watermark\n",
"%reload_ext watermark\n",
"%watermark -m -iv"
],
"execution_count": 92,
"outputs": [
{
"output_type": "stream",
"text": [
"Compiler : GCC 7.5.0\n",
"OS : Linux\n",
"Release : 5.4.104+\n",
"Machine : x86_64\n",
"Processor : x86_64\n",
"CPU cores : 2\n",
"Architecture: 64bit\n",
"\n",
"numpy : 1.19.5\n",
"networkx : 2.5.1\n",
"seaborn : 0.11.1\n",
"re : 2.2.1\n",
"sys : 3.7.11 (default, Jul 3 2021, 18:01:19) \n",
"[GCC 7.5.0]\n",
"pandas : 1.1.5\n",
"dash : 1.21.0\n",
"plotly : 4.4.1\n",
"IPython : 5.5.0\n",
"dash_html_components: 1.1.4\n",
"dash_core_components: 1.17.1\n",
"matplotlib : 3.2.2\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_p_wYeySP1d1"
},
"source": [
"## Loading data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-7wt1IwpGWlm"
},
"source": [
""
]
},
{
"cell_type": "code",
"metadata": {
"id": "7-2nOHbi9DL8"
},
"source": [
"!pip install -q -U kaggle\n",
"!pip install --upgrade --force-reinstall --no-deps kaggle\n",
"!mkdir ~/.kaggle\n",
"!cp /content/drive/MyDrive/kaggle.json ~/.kaggle/\n",
"!chmod 600 ~/.kaggle/kaggle.json\n",
"!kaggle competitions download -c instacart-market-basket-analysis\n",
"!unzip /content/instacart-market-basket-analysis.zip"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "m8opruUi9Yoo",
"outputId": "e68def05-38ae-4eef-bcd8-c6c64705fa92"
},
"source": [
"!sudo apt-get install tree\n",
"!tree . -L 1"
],
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"text": [
"Reading package lists... Done\n",
"Building dependency tree \n",
"Reading state information... Done\n",
"tree is already the newest version (1.7.0-5).\n",
"0 upgraded, 0 newly installed, 0 to remove and 40 not upgraded.\n",
".\n",
"├── aisles.csv.zip\n",
"├── departments.csv.zip\n",
"├── drive\n",
"├── instacart-market-basket-analysis.zip\n",
"├── order_products__prior.csv.zip\n",
"├── order_products__train.csv.zip\n",
"├── orders.csv.zip\n",
"├── products.csv.zip\n",
"├── __pycache__\n",
"├── sample_data\n",
"└── sample_submission.csv.zip\n",
"\n",
"3 directories, 8 files\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"id": "8UHM52Al_RQW",
"outputId": "015ca52d-062b-4e5d-e1b5-4ce20bb1d64b"
},
"source": [
"!unzip -qqo /content/order_products__train.csv.zip\n",
"train_df = pd.read_csv('order_products__train.csv')\n",
"train_df.head()"
],
"execution_count": 15,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" order_id | \n",
" product_id | \n",
" add_to_cart_order | \n",
" reordered | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1 | \n",
" 49302 | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1 | \n",
" 11109 | \n",
" 2 | \n",
" 1 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1 | \n",
" 10246 | \n",
" 3 | \n",
" 0 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1 | \n",
" 49683 | \n",
" 4 | \n",
" 0 | \n",
"
\n",
" \n",
" | 4 | \n",
" 1 | \n",
" 43633 | \n",
" 5 | \n",
" 1 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" order_id product_id add_to_cart_order reordered\n",
"0 1 49302 1 1\n",
"1 1 11109 2 1\n",
"2 1 10246 3 0\n",
"3 1 49683 4 0\n",
"4 1 43633 5 1"
]
},
"metadata": {
"tags": []
},
"execution_count": 15
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"id": "PiFy5-DL_RL0",
"outputId": "88f94165-b5fc-4d86-90cf-5711cafaff86"
},
"source": [
"!unzip -qqo /content/products.csv.zip\n",
"products_df = pd.read_csv('products.csv')\n",
"products_df.head()"
],
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" product_id | \n",
" product_name | \n",
" aisle_id | \n",
" department_id | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1 | \n",
" Chocolate Sandwich Cookies | \n",
" 61 | \n",
" 19 | \n",
"
\n",
" \n",
" | 1 | \n",
" 2 | \n",
" All-Seasons Salt | \n",
" 104 | \n",
" 13 | \n",
"
\n",
" \n",
" | 2 | \n",
" 3 | \n",
" Robust Golden Unsweetened Oolong Tea | \n",
" 94 | \n",
" 7 | \n",
"
\n",
" \n",
" | 3 | \n",
" 4 | \n",
" Smart Ones Classic Favorites Mini Rigatoni Wit... | \n",
" 38 | \n",
" 1 | \n",
"
\n",
" \n",
" | 4 | \n",
" 5 | \n",
" Green Chile Anytime Sauce | \n",
" 5 | \n",
" 13 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" product_id ... department_id\n",
"0 1 ... 19\n",
"1 2 ... 13\n",
"2 3 ... 7\n",
"3 4 ... 1\n",
"4 5 ... 13\n",
"\n",
"[5 rows x 4 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 16
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Px03Qre9P5sO"
},
"source": [
"## Preprocessing"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Jk70YSgM9wDV"
},
"source": [
"def return_dfs(train_df, products_df, train_percent=0.1, products_cutoff=0,\n",
"               orders_q1=5, orders_q2=9):\n",
"    ''' Function that returns two dataframes for 2 segments of orders based on basket size\n",
"    Args: train_df - the training dataframe; must have 'order_id' and 'product_id' columns\n",
"          products_df - the products dataframe; must have 'product_id' and 'product_name' columns\n",
"          train_percent - fraction (0-1) of the orders sampled (a smaller value makes viz possible)\n",
"          products_cutoff - only products appearing MORE often than this are included\n",
"          orders_q1 - first cutoff point for number of items in a basket\n",
"          orders_q2 - second cutoff point for number of items in a basket\n",
"    Returns: orders_small - rows of the baskets with <= orders_q1 items, with product names attached\n",
"             orders_large - rows of the baskets with > orders_q1 and <= orders_q2 items\n",
"             order_ids_Q1 - Series of the order ids in the small segment\n",
"             order_ids_Q2 - Series of the order ids in the large segment\n",
"    '''\n",
"    orders = train_df[['order_id', 'product_id']]\n",
"    product_names = products_df[['product_id', 'product_name']]\n",
"\n",
"    # Select a random sample of the orders. random.sample() requires a sequence:\n",
"    # passing a set is deprecated since Python 3.9 and a TypeError from 3.11.\n",
"    all_order_ids = list(orders['order_id'].unique())\n",
"    sampled_ids = random.sample(all_order_ids, int(len(all_order_ids) * train_percent))\n",
"    # Reduce the size of the initial orders data\n",
"    orders = orders[orders['order_id'].isin(sampled_ids)]\n",
"\n",
"    # Keep only the products occurring more often than products_cutoff in the sample\n",
"    product_counts = orders.groupby('product_id')['order_id'].count()\n",
"    product_ids = product_counts.index[product_counts > products_cutoff]\n",
"\n",
"    # Basket size per order (measured before the product filter), then split the\n",
"    # order ids into the two size segments\n",
"    basket_sizes = orders.groupby('order_id')['product_id'].count().rename('count').reset_index()\n",
"    order_ids_Q1 = basket_sizes.order_id[basket_sizes['count'] <= orders_q1]\n",
"    order_ids_Q2 = basket_sizes.order_id[(basket_sizes['count'] > orders_q1)\n",
"                                         & (basket_sizes['count'] <= orders_q2)]\n",
"\n",
"    def _build_segment(segment_order_ids):\n",
"        ''' Rows of the given orders, restricted to the kept products, with tidied product names '''\n",
"        segment = orders[orders['order_id'].isin(segment_order_ids)]\n",
"        segment = segment[segment['product_id'].isin(product_ids)]\n",
"        segment = segment.merge(product_names, how='left')\n",
"        # To simplify what the orders look like, 'Bag of Organic Bananas' becomes 'Banana'\n",
"        # and the literal text 'Organic ' is dropped from every other name\n",
"        names = segment['product_name'].replace({'Bag of Organic Bananas': 'Banana'})\n",
"        segment['product_name'] = names.str.replace('Organic ', '', regex=False)\n",
"        return segment\n",
"\n",
"    orders_small = _build_segment(order_ids_Q1)\n",
"    orders_large = _build_segment(order_ids_Q2)\n",
"\n",
"    return orders_small, orders_large, order_ids_Q1, order_ids_Q2"
],
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "IY2I8xqY-PiJ"
},
"source": [
"orders_small, orders_large, order_ids_Q1, order_ids_Q2 = return_dfs(train_df, products_df)"
],
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"id": "4QgS2UpsAXi6",
"outputId": "f794a85b-aa76-407e-cf20-18b4f46a361c"
},
"source": [
"orders_small.head()"
],
"execution_count": 20,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" order_id | \n",
" product_id | \n",
" product_name | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 719 | \n",
" 45683 | \n",
" Heavy Duty Scrub Sponge | \n",
"
\n",
" \n",
" | 1 | \n",
" 904 | \n",
" 8013 | \n",
" Cup Noodles Chicken Flavor | \n",
"
\n",
" \n",
" | 2 | \n",
" 904 | \n",
" 46149 | \n",
" Zero Calorie Cola | \n",
"
\n",
" \n",
" | 3 | \n",
" 988 | \n",
" 45061 | \n",
" Natural Vanilla Ice Cream | \n",
"
\n",
" \n",
" | 4 | \n",
" 988 | \n",
" 28464 | \n",
" Whipped Light Cream, Original | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" order_id product_id product_name\n",
"0 719 45683 Heavy Duty Scrub Sponge\n",
"1 904 8013 Cup Noodles Chicken Flavor\n",
"2 904 46149 Zero Calorie Cola\n",
"3 988 45061 Natural Vanilla Ice Cream\n",
"4 988 28464 Whipped Light Cream, Original"
]
},
"metadata": {
"tags": []
},
"execution_count": 20
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"id": "zZfD3l2YAZWc",
"outputId": "05615a47-b34b-4073-df56-7e1c79fdb6b2"
},
"source": [
"orders_large.head()"
],
"execution_count": 21,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" order_id | \n",
" product_id | \n",
" product_name | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1 | \n",
" 49302 | \n",
" Bulgarian Yogurt | \n",
"
\n",
" \n",
" | 1 | \n",
" 1 | \n",
" 11109 | \n",
" 4% Milk Fat Whole Milk Cottage Cheese | \n",
"
\n",
" \n",
" | 2 | \n",
" 1 | \n",
" 10246 | \n",
" Celery Hearts | \n",
"
\n",
" \n",
" | 3 | \n",
" 1 | \n",
" 49683 | \n",
" Cucumber Kirby | \n",
"
\n",
" \n",
" | 4 | \n",
" 1 | \n",
" 43633 | \n",
" Lightly Smoked Sardines in Olive Oil | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" order_id product_id product_name\n",
"0 1 49302 Bulgarian Yogurt\n",
"1 1 11109 4% Milk Fat Whole Milk Cottage Cheese\n",
"2 1 10246 Celery Hearts\n",
"3 1 49683 Cucumber Kirby\n",
"4 1 43633 Lightly Smoked Sardines in Olive Oil"
]
},
"metadata": {
"tags": []
},
"execution_count": 21
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DLAMF-ALAe3W",
"outputId": "3874060f-31ab-4c14-c405-897255ad121f"
},
"source": [
"order_ids_Q1"
],
"execution_count": 22,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"3 719\n",
"4 904\n",
"5 988\n",
"14 3243\n",
"15 3817\n",
" ... \n",
"13099 3416849\n",
"13110 3419245\n",
"13113 3419891\n",
"13114 3420008\n",
"13117 3420798\n",
"Name: order_id, Length: 3978, dtype: int64"
]
},
"metadata": {
"tags": []
},
"execution_count": 22
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QIZLIUYHAq9r"
},
"source": [
"## Processing the Data for NetworkX\n",
"Here we need to create tuples comprising the paired items in the data so that we can build the graph. This code creates two sets of data, one for the \"small\" baskets and one for the \"large\" (although they're still quite small) baskets"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 115,
"referenced_widgets": [
"a495a0cf7928479291950a2f7f530734",
"c3ce9ef5d8bd4478949f7d21b2df356d",
"356f6a0c0bad461babd889941b728924",
"8ebfaa8911d24f1eb8ceb66666d12d99",
"5a57039da7644e16b5c097cb2c95e0dd",
"974e635aab5e4ec5a8fc80bbfe7a4383",
"1d48322154fd460d8c9b87ac897df9bf",
"5c4596f1ab044a6da4ec47e08ac60bc4",
"e949cb02ed1a4971b04545f7436db8b8",
"fc75c12c7e184a4fbec57a9569691dcf",
"23afc4c497154f02be04e93f85cd7842",
"66b4e0f46b3b43c1b63b1e5038783979",
"664304d4f06a457d9d713f3c625d1d83",
"3e68e2ff48d84d4d9bc752ab023d85b6",
"2a52ba2c014843698de8b1bd6e766d4b",
"1876c679da2143f7874894593742baaa"
]
},
"id": "e6a5M5lnAhqu",
"outputId": "c43c8838-b2dd-4f1d-f66c-84bb80e87dac"
},
"source": [
"def pair_products(orders_df):\n",
"    ''' Build the pairwise product combinations for every basket in orders_df\n",
"    Args: orders_df - dataframe with 'order_id' and 'product_name' columns\n",
"    Returns: a list of (product, product) tuples, one per pair of distinct products in a basket\n",
"    '''\n",
"    pairs = []\n",
"    # groupby visits every basket once instead of re-filtering the dataframe per order id\n",
"    for _, basket in tqdm(orders_df.groupby('order_id')['product_name']):\n",
"        # Sort the distinct names so a pair always has the same orientation; otherwise\n",
"        # ('Banana', 'Strawberries') and ('Strawberries', 'Banana') are counted separately\n",
"        # and only one of the two counts ends up as the weight of the undirected edge\n",
"        pairs.extend(itertools.combinations(sorted(set(basket.dropna())), 2))\n",
"    return pairs\n",
"\n",
"paired_products_small = pair_products(orders_small)\n",
"paired_products_large = pair_products(orders_large)\n",
"\n",
"counts_small = collections.Counter(paired_products_small)\n",
"counts_large = collections.Counter(paired_products_large)\n",
"\n",
"# Keep only the most frequent pairs of each segment for building the graph\n",
"food_df_small = pd.DataFrame(counts_small.most_common(1000),\n",
"                             columns = ['products', 'counts'])\n",
"food_df_large = pd.DataFrame(counts_large.most_common(4000),\n",
"                             columns = ['products', 'counts'])"
],
"execution_count": 25,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a495a0cf7928479291950a2f7f530734",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=3978.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e949cb02ed1a4971b04545f7436db8b8",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=3222.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "uEFzEPO4BGJy"
},
"source": [
"# Turn one of the dataframes into a dictionary for processing into a graph\n",
"d = food_df_small.set_index('products').T.to_dict('records')\n",
"# d = food_df_large.set_index('products').T.to_dict('records')"
],
"execution_count": 26,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YAPzMOhiBtKb",
"outputId": "4f125ad9-400b-4e22-dcbd-5ca381bdc238"
},
"source": [
"dict(list(d[0].items())[0:10])"
],
"execution_count": 34,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{('Banana', 'Baby Spinach'): 7,\n",
" ('Banana', 'Clementines'): 11,\n",
" ('Banana', 'Raspberries'): 9,\n",
" ('Banana', 'Strawberries'): 13,\n",
" ('Raspberries', 'Blackberries'): 12,\n",
" ('Raspberries', 'Blueberries'): 14,\n",
" ('Raspberries', 'Strawberries'): 14,\n",
" ('Strawberries', 'Banana'): 8,\n",
" ('Strawberries', 'Blueberries'): 10,\n",
" ('Strawberries', 'Raspberries'): 12}"
]
},
"metadata": {
"tags": []
},
"execution_count": 34
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Dpt-jUMBBisR",
"outputId": "eccf5447-701c-4e4a-a548-785559c2a61c"
},
"source": [
"# Build an undirected graph whose edges are the product pairs, weighted by their counts\n",
"G = nx.Graph()\n",
"G.add_weighted_edges_from(\n",
"    (first, second, count) for (first, second), count in d[0].items()\n",
")\n",
"\n",
"# Check the node count; too many nodes make the graph uncomfortable to visualise\n",
"nodes = list(G.nodes)\n",
"len(nodes)"
],
"execution_count": 35,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"527"
]
},
"metadata": {
"tags": []
},
"execution_count": 35
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ImVVMWVNCQ_H",
"outputId": "c1816646-260c-4897-f6bc-7e16bf7a2206"
},
"source": [
"# Prune the plot so we only have items that are matched with at least two others.\n",
"# Iterate over a snapshot of the graph's own nodes (not the `nodes` list of an earlier\n",
"# cell) so the cell can be re-run safely and needs no catch-all exception handler.\n",
"# This is a single pass: a node's degree is read at the moment it is visited.\n",
"for node in list(G.nodes):\n",
"    if G.degree[node] <= 1:\n",
"        G.remove_node(node)\n",
"\n",
"nodes = list(G.nodes)\n",
"len(nodes)"
],
"execution_count": 36,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"382"
]
},
"metadata": {
"tags": []
},
"execution_count": 36
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "cPTbLMiuCnQe"
},
"source": [
"# Persist the pruned graph so that later cells can reload it from disk.\n",
"# NOTE(review): G was built from food_df_small above, yet it is saved as 'large_graph.pickle' - confirm the intended segment\n",
"with open('large_graph.pickle', 'wb') as handle:\n",
" pickle.dump(G, handle, protocol=pickle.HIGHEST_PROTOCOL)"
],
"execution_count": 39,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "a9kSDVEuCTVH"
},
"source": [
"## Build the Word2Vec model\n",
"This section of the code focuses on building the Word2Vec-like embeddings for the nodes in the network using the Deep Walk procedure."
]
},
{
"cell_type": "code",
"metadata": {
"id": "WWw8cEGsCZQV"
},
"source": [
"# Read the pickle in \n",
"with open('large_graph.pickle', 'rb') as f:\n",
" G_large = pickle.load(f)"
],
"execution_count": 40,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "fUvy3GNrHa5E"
},
"source": [
"def load_graph(segment):\n",
"    ''' Function that loads the pickled graph of a segment and computes a layout for it\n",
"    Args: segment: index of the segment to load: 0 (small), 1 (medium) or 2 (large) -> int\n",
"    Returns: (pos, G) - the spring-layout node positions and the graph object, in that order\n",
"    '''\n",
"    # Pickle file of each segment, indexed by the segment number\n",
"    segments = ['small_graph.pickle', 'med_graph.pickle', 'large_graph.pickle']\n",
"\n",
"    # NOTE: pickle.load can execute arbitrary code; only load files written by this notebook\n",
"    with open(segments[segment], 'rb') as f:\n",
"        G = pickle.load(f)\n",
"\n",
"    pos = nx.spring_layout(G)\n",
"\n",
"    return pos, G"
],
"execution_count": 90,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66,
"referenced_widgets": [
"2908997c4929430f918dd0c63608b106",
"a72e4ad4bfc347d3afee27b1f19ea1fa",
"53e606b87aa14214adb83d0a7e152ac0",
"a5898e01abb242baad7ff68221106c26",
"a254c644e16d4725bf91cc0a2288d1e1",
"dee2b7a161564a90989ac4ce59e6e501",
"bb26e1e4c01f48dc93f1a8efcf14f0e6",
"761feb5caa81411ca7b4b4f9cd45a194"
]
},
"id": "O06nFr4ZChJ0",
"outputId": "e754847b-9388-4d9e-9ad1-612a9b1cff3b"
},
"source": [
"# Build a dictionary containing the weights of the edges; doing it this way saves a LOT of time in doing the probabilistic\n",
"# random walks in the next steps.\n",
"# random_walk pairs weights[node] with nx.all_neighbors(graph, node), so both must follow the same neighbour order.\n",
"# A comprehension is used so that no loop variable overwrites the notebook-level `nodes` list.\n",
"weights = {\n",
"    node: [weight for _, _, weight in G_large.edges(str(node), data='weight')]\n",
"    for node in tqdm(G_large.nodes())\n",
"}"
],
"execution_count": 42,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2908997c4929430f918dd0c63608b106",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=382.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "-zjlAG4CC0LU"
},
"source": [
"def random_walk(graph, node, weighted=False, n_steps = 5):\n",
"    ''' Function that takes a random walk along a graph\n",
"    Args: graph: the networkx graph to walk on\n",
"          node: the start node of the walk\n",
"          weighted: if True the next node is sampled using the edge weights read from\n",
"                    the notebook-level `weights` dict; otherwise uniformly -> bool\n",
"          n_steps: maximum number of steps to take away from the start node -> int\n",
"    Returns: the visited nodes as a list of strings, starting with the start node\n",
"    '''\n",
"    local_path = [str(node)]\n",
"    current_node = node\n",
"\n",
"    # Take up to n_steps random steps away from the node (can return to the node)\n",
"    for _ in range(n_steps):\n",
"        neighbours = list(nx.all_neighbors(graph, current_node))\n",
"        # A node without neighbours is a dead end: stop here rather than let\n",
"        # random.choice / random.choices fail on an empty list\n",
"        if not neighbours:\n",
"            break\n",
"        # See the difference between doing this with and without edge weight - it takes many, many times longer\n",
"        if weighted:\n",
"            # Sample in a weighted manner; weights[current_node] must list the edge\n",
"            # weights in the same order as the neighbours\n",
"            current_node = random.choices(neighbours, weights[current_node])[0]\n",
"        else:\n",
"            current_node = random.choice(neighbours)\n",
"        local_path.append(str(current_node))\n",
"\n",
"    return local_path"
],
"execution_count": 43,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "i-Cm6L9IDQFU"
},
"source": [
"Now we do the random walk and then we create the node embeddings\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66,
"referenced_widgets": [
"86a1e7530696423e9ccba1dacc5c435d",
"4a7ddd8095874012841f9e7f12447b0e",
"a073ac68471a431e98a14fed6f121636",
"82b43f61286e4edb9fab53cf1c625de8",
"10afc6afd6d34cb8b282d25af6bfb261",
"b1cff006a21c4323a0e4cead86772f00",
"9bca0d35e0ef4b07bbc001d787c17a39",
"7a5d3adfe202447a8bc49a2eaa962008"
]
},
"id": "Cz_hDXgIDOYd",
"outputId": "269f9f0a-e48e-47f8-8037-676ad2580063"
},
"source": [
"# Number of weighted random walks started from every node of the graph\n",
"N_WALKS_PER_NODE = 10\n",
"\n",
"walk_paths_weighted = [\n",
"    random_walk(G_large, node, weighted=True)\n",
"    for node in tqdm(G_large.nodes())\n",
"    for _ in range(N_WALKS_PER_NODE)\n",
"]"
],
"execution_count": 44,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "86a1e7530696423e9ccba1dacc5c435d",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=382.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "keBb_QWjDTiK",
"outputId": "c7bfee96-752b-4b8d-ed4f-3deb9e205e07"
},
"source": [
"# Instantiate the embedder\n",
"embedder_weighted = Word2Vec(window = 4, sg=1, negative=10, alpha=0.03, min_alpha=0.0001, seed=42)\n",
"# Build the vocab\n",
"embedder_weighted.build_vocab(walk_paths_weighted, progress_per=2)\n",
"# Train the embedder to build the word embeddings - this takes a little bit of time\n",
"embedder_weighted.train(walk_paths_weighted, total_examples=embedder_weighted.corpus_count, epochs=20, report_delay=1)"
],
"execution_count": 45,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(381480, 458400)"
]
},
"metadata": {
"tags": []
},
"execution_count": 45
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LoSzE80JDbjQ",
"outputId": "fb4f547c-1f42-4c41-a7e3-c58959d66624"
},
"source": [
"some_random_words = [list(embedder_weighted.wv.vocab.keys())[x] for x in np.random.choice(len(embedder_weighted.wv.vocab),10)]\n",
"some_random_words"
],
"execution_count": 65,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['Green Bell Pepper',\n",
" 'Natural Vanilla Ice Cream',\n",
" 'Avocado',\n",
" 'Baby Spring Mix',\n",
" 'Crimini Mushrooms',\n",
" 'Zucchini',\n",
" 'Total 0% Nonfat Plain Greek Yogurt',\n",
" 'Cold Brew Coffee Double Espresso with Almond Milk',\n",
" 'Half Baked® Ice Cream',\n",
" 'Vanilla Almond Breeze Almond Milk']"
]
},
"metadata": {
"tags": []
},
"execution_count": 65
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GoYTqbEmDYnT",
"outputId": "58347235-533c-4ce7-9e98-aa1921aa55bc"
},
"source": [
"x_ = embedder_weighted.wv.most_similar(some_random_words[0], topn=10)\n",
"[i[0] for i in x_]"
],
"execution_count": 66,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['Large Grade AA Brown Eggs',\n",
" 'Red Bell Pepper',\n",
" 'Milk, Organic, Fat Free',\n",
" 'Cucumber Kirby',\n",
" 'Chicken Thighs',\n",
" 'Red Onion',\n",
" 'Boneless Skinless Chicken Breast',\n",
" 'Large Extra Fancy Fuji Apple',\n",
" 'Popcorn Shrimp Oven Crispy',\n",
" 'Lorna Doone Shortbread Cookies']"
]
},
"metadata": {
"tags": []
},
"execution_count": 66
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pW7e8wTdFhu0",
"outputId": "1f8152b9-779b-4280-e2ed-e99a874b23f0"
},
"source": [
"x_ = embedder_weighted.wv.most_similar(some_random_words[1], topn=10)\n",
"[i[0] for i in x_]"
],
"execution_count": 67,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['Classic Vanilla Coffee Creamer',\n",
" 'Whipped Light Cream, Original',\n",
" 'Complete ActionPacs Lemon Burst Dishwasher Detergent',\n",
" 'Raspberry Tea',\n",
" 'Medium Roast Original Blend Ground Coffee',\n",
" 'Original Fat Free Liquid Creamer',\n",
" 'Diet Peach',\n",
" 'Sriracha Sauce',\n",
" 'Just Mayo',\n",
" 'Taquitos, Crispy, Large, Vegan Chorizo & Black Bean Style 8 Count']"
]
},
"metadata": {
"tags": []
},
"execution_count": 67
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "QCDKucnYFn3Y"
},
"source": [
"## Save and/or load the embedding objects\n",
"\n",
"# with open('embedder_weighted.pickle', 'wb') as f:\n",
"# pickle.dump(embedder_weighted, f)\n",
"\n",
"# with open('embedder_weighted.pickle', 'rb') as f:\n",
"# embedder_weighted = pickle.load(f)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "70wpnnpsQQMd"
},
"source": [
"## Finding similar items"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "sr-kgFnNFy97",
"outputId": "42b424c2-1865-4a60-e02e-a0195575f5be"
},
"source": [
"nodes = [node for node in G.nodes()]\n",
"pos = nx.spring_layout(G)\n",
"len(nodes)"
],
"execution_count": 68,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"382"
]
},
"metadata": {
"tags": []
},
"execution_count": 68
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Ql6pYoFXF7pD"
},
"source": [
"def similar_embeddings(source_node, topn, model=None):\n",
"    ''' Function that returns the topn most similar items using embeddings\n",
"    Args: source_node: the item to find similar items for -> str\n",
"          topn: number of similar items to return -> int\n",
"          model: trained embedding model exposing .wv.most_similar; when None the\n",
"                 notebook-level `embedder` is used\n",
"    Returns: a list with the names of the most similar items\n",
"    '''\n",
"    # NOTE(review): the training cell defines `embedder_weighted`; confirm that `embedder`\n",
"    # is assigned before this function is called without `model`\n",
"    if model is None:\n",
"        model = embedder\n",
"    most_similar = model.wv.most_similar(source_node, topn=topn)\n",
"    return [name for name, _ in most_similar]"
],
"execution_count": 84,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "QMo4v_xqH1sf"
},
"source": [
"def find_ingredient(nodes, ingredient=\"Pear\"):\n",
"    ''' Function that returns the nodes whose name contains the ingredient\n",
"    Args: nodes: a list of the nodes in the graph -> list\n",
"          ingredient: the ingredient you want to find -> str\n",
"    Returns: a list of the matching nodes, in their original order\n",
"    '''\n",
"    # Case-sensitive substring match, so ingredient=\"Pear\" can return\n",
"    # \"Pear Jam\", \"Potato and Pear Soup\" etc.\n",
"    return [candidate for candidate in nodes if ingredient in candidate]"
],
"execution_count": 77,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NtUUmE6IH3V_",
"outputId": "be3930b1-5d4b-4f17-9bdf-753aa3713aa7"
},
"source": [
"find_ingredient(nodes)"
],
"execution_count": 78,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['Bartlett Pears', 'Sparkling Orange Juice & Prickly Pear Beverage']"
]
},
"metadata": {
"tags": []
},
"execution_count": 78
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zha_f54kH_v4",
"outputId": "bf4ae9da-af2b-4ab5-dc33-fa10acb8a7d3"
},
"source": [
"find_ingredient(nodes, ingredient='Onion')"
],
"execution_count": 79,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['Red Onion', 'Yellow Onions', 'French Onion Dip', 'Yellow Onion']"
]
},
"metadata": {
"tags": []
},
"execution_count": 79
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wT5w70jrQhcl"
},
"source": [
"## Finding neighbour items"
]
},
{
"cell_type": "code",
"metadata": {
"id": "HtSwrjRMIgBw"
},
"source": [
"# Traverse the graph by selecting the most weighted item\n",
"def get_neighbours(G, item, topn=10):\n",
"    ''' Function that returns the most strongly connected neighbours of a node\n",
"    Args: G - the networkx Graph object\n",
"          item: the start node for searching -> str\n",
"          topn: number of neighbours to return -> int\n",
"    Returns: a list of up to topn neighbouring grocery items, heaviest edge first\n",
"    '''\n",
"    # Map every neighbour of the node to the weight of the edge that leads to it\n",
"    neighbour_weights = {\n",
"        end: G.get_edge_data(start, end)['weight']\n",
"        for start, end in list(G.edges(str(item)))\n",
"    }\n",
"    # Rank the neighbours by edge weight; equal weights keep the order of the edges\n",
"    ranked = sorted(neighbour_weights, key=neighbour_weights.get, reverse=True)\n",
"    # Filter so we just have the topn items\n",
"    return ranked[0:topn]"
],
"execution_count": 83,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "DifcqCnhFyYi"
},
"source": [
"## Using Plotly to make the graph interactive"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Iirc0sYKG10x"
},
"source": [
"def create_graph_display(G, pos):\n",
"    ''' Function that builds the Plotly traces for displaying the graph; most of\n",
"    this code is taken from the Plotly site\n",
"    Args: G - networkx graph object\n",
"          pos - positions of nodes, indexed by node\n",
"    Returns: node_trace, edge_trace - the Scatter traces of the nodes and of the edges\n",
"    '''\n",
"    # Every edge contributes its two end points followed by a None entry\n",
"    edge_x = []\n",
"    edge_y = []\n",
"    for start, end in G.edges():\n",
"        x0, y0 = pos[start]\n",
"        x1, y1 = pos[end]\n",
"        edge_x.extend([x0, x1, None])\n",
"        edge_y.extend([y0, y1, None])\n",
"\n",
"    edge_trace = go.Scatter(\n",
"        x=edge_x, y=edge_y,\n",
"        line=dict(width=0.5, color='#888'),\n",
"        hoverinfo='none',\n",
"        mode='lines')\n",
"\n",
"    # Coordinates of the nodes, in the graph's node order\n",
"    node_x = []\n",
"    node_y = []\n",
"    for node in G.nodes():\n",
"        x, y = pos[node]\n",
"        node_x.append(x)\n",
"        node_y.append(y)\n",
"\n",
"    marker_style = dict(\n",
"        showscale=True,\n",
"        colorscale='YlGnBu',\n",
"        reversescale=False,\n",
"        color=[],\n",
"        size=8,\n",
"        colorbar=dict(\n",
"            thickness=10,\n",
"            title='Node Connections',\n",
"            xanchor='left',\n",
"            titleside='right'\n",
"        ),\n",
"        line_width=1)\n",
"\n",
"    node_trace = go.Scatter(\n",
"        x=node_x, y=node_y,\n",
"        mode='markers+text',\n",
"        hoverinfo='text',\n",
"        hovertext=\"10\",\n",
"        text=\"\",\n",
"        textfont=dict(\n",
"            family=\"sans serif\",\n",
"            size=11\n",
"        ),\n",
"        marker=marker_style)\n",
"\n",
"    # Colour every node by its number of connections and update the text displayed on mouse over\n",
"    node_adjacencies = []\n",
"    node_text = []\n",
"    for node, neighbours in G.adjacency():\n",
"        n_connections = len(neighbours)\n",
"        node_adjacencies.append(n_connections)\n",
"        node_text.append(f'{node}: {n_connections} connections')\n",
"\n",
"    node_trace.marker.color = node_adjacencies\n",
"    node_trace.hovertext = node_text\n",
"\n",
"    return node_trace, edge_trace"
],
"execution_count": 73,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 542
},
"id": "U0WGBsllF6L5",
"outputId": "a5c6d605-7c29-46c9-fb37-6c853c017de6"
},
"source": [
"node_trace, edge_trace = create_graph_display(G, pos)\n",
"\n",
"# The edges are drawn first so that the node markers sit on top of them\n",
"fig = go.Figure(data=[edge_trace, node_trace],\n",
"                layout=go.Layout(\n",
"                    title='<br>Graph of shopping cart items',\n",
"                    titlefont_size=16,\n",
"                    showlegend=False,\n",
"                    hovermode='closest',\n",
"                    margin=dict(b=20,l=5,r=5,t=40),\n",
"                    annotations=[ dict(\n",
"                        text=\"some text\",\n",
"                        showarrow=False,\n",
"                        xref=\"paper\", yref=\"paper\",\n",
"                        x=0.005, y=-0.002 ) ],\n",
"                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),\n",
"                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))\n",
"                )\n",
"\n",
"\n",
"fig.show()"
],
"execution_count": 74,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
"\n",
"\n",
" \n",
" \n",
" \n",
" \n",
"
\n",
" \n",
"
\n",
"\n",
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lv8kwW6FKsGv"
},
"source": [
"## Building the Dash App"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WtYDHUpUKu7E"
},
"source": [
"### Define the app"
]
},
{
"cell_type": "code",
"metadata": {
"id": "gqQ05tatKyyb"
},
"source": [
"external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']\n",
"app = JupyterDash(__name__, external_stylesheets=external_stylesheets)\n",
"server = app.server\n",
"app.title='Groceries on a graph'\n",
"list_dict = [{}]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "2ijCXElvK8zZ"
},
"source": [
"### Load an initial Graph"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ec1kwyaqK21M"
},
"source": [
"with open('large_graph.pickle', 'rb') as f:\n",
" G_init = pickle.load(f)\n",
"nodes = list(G_init.nodes)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "i5fHFu5CK_zf"
},
"source": [
"### Create a global button tracker"
]
},
{
"cell_type": "code",
"metadata": {
"id": "HYMOSW1ULCDI"
},
"source": [
"BUTTON_CLICKED = None\n",
"button_style = {'margin-right': '5px',\n",
" 'margin-left': '5px',\n",
" 'margin-top': '5px',\n",
" 'margin-bottom': '5px'}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "zbxPYineLGuH"
},
"source": [
"### Define Layout"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Fr3rSoYKLH9T"
},
"source": [
"tabs_styles = {\n",
" 'height': '44px'\n",
"}\n",
"tab_style = {\n",
" 'borderBottom': '1px solid #d6d6d6',\n",
" 'padding': '6px',\n",
" 'fontWeight': 'bold'\n",
"}\n",
"\n",
"tab_selected_style = {\n",
" 'borderTop': '0px solid #d6d6d6',\n",
" 'borderBottom': '1px solid #d6d6d6',\n",
" 'backgroundColor': '#1a1a1a',\n",
" 'color': 'white',\n",
" 'padding': '6px'\n",
"}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "SY0407OtLQYT"
},
"source": [
"> Note: Layout comprises two tabs - one for viewing of the graph and the other for making a shopping list"
]
},
{
"cell_type": "code",
"metadata": {
"id": "AkTAMMDoLqpU"
},
"source": [
"app.layout = html.Div(children=[\n",
"    # Page header\n",
"    html.Div(\n",
"        className='row',\n",
"        children=[html.H1(children='Grocery Graph Network')],\n",
"        style={'textAlign': 'center',\n",
"               'backgroundColor': '#1a1a1a',\n",
"               'color': 'white'},\n",
"    ),\n",
"    # Show the tabs; the value of the selected tab is the input of the render_content callback\n",
"    html.Div(children=[\n",
"        dcc.Tabs(\n",
"            id='tabs-example',\n",
"            value='tab-1',\n",
"            style=tabs_styles,\n",
"            children=[\n",
"                dcc.Tab(label='Explore Network Graph', value='tab-1',\n",
"                        style=tab_style,\n",
"                        selected_style=tab_selected_style),\n",
"                dcc.Tab(label='Shopping List Builder', value='tab-2',\n",
"                        style=tab_style,\n",
"                        selected_style=tab_selected_style),\n",
"            ],\n",
"        ),\n",
"    ]),\n",
"    # Container that receives the content of the selected tab\n",
"    html.Div(id='tabs-output'),\n",
"])"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "HG4H_IZrLtZ8"
},
"source": [
"### Callback for selecting/changing tabs"
]
},
{
"cell_type": "code",
"metadata": {
"id": "58sbECR0L2Hy"
},
"source": [
"@app.callback(Output('tabs-output', 'children'),\n",
"              Input('tabs-example', 'value'))\n",
"def render_content(value):\n",
"    ''' Callback that returns the content of the selected tab (None for any other value) '''\n",
"    if value == 'tab-1':\n",
"        return display_network_graph()\n",
"    if value == 'tab-2':\n",
"        return display_shopping_list()\n",
"    return None"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "X5MwKYV9METD"
},
"source": [
"### Function for displaying the network graph\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zGdZG9hSMHIT"
},
"source": [
"def display_network_graph():\n",
"    ''' Function that builds the content of the network tab: a segment dropdown and the graph '''\n",
"\n",
"    # Dropdown menu for choosing the segment shown in the graph\n",
"    segment_options = [{'label': 'Small', 'value': 0},\n",
"                       {'label': 'Medium', 'value': 1},\n",
"                       {'label': 'Large', 'value': 2}]\n",
"    dd_segment = dcc.Dropdown(\n",
"        id='dd_segment',\n",
"        className='dropdown',\n",
"        options=segment_options,\n",
"        value=2\n",
"    )\n",
"\n",
"    # Row with the input settings, which includes the dropdown declared above\n",
"    settings_row = html.Div(className='row', children=[\n",
"        html.Div(className='col',\n",
"                 children=[\n",
"                     html.H4(\"Select Segment\"),\n",
"                     html.P(\"Select a segment, named according to basket size\"),\n",
"                     dd_segment\n",
"                 ],\n",
"                 style={'width': '30%', 'display': 'inline-block'}\n",
"                 )\n",
"    ])\n",
"\n",
"    # Row with the main graph, with loading icon whilst loading\n",
"    graph_row = html.Div(className='row', children=[\n",
"        html.Div(children=[\n",
"            dcc.Loading(id='loading-icon',\n",
"                        children=[\n",
"                            dcc.Graph(id='graph-graphic')\n",
"                        ])\n",
"        ])\n",
"    ])\n",
"\n",
"    return html.Div([settings_row, graph_row])"
],
"execution_count": 91,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "nAUj6pO2MR07"
},
"source": [
"### Function for displaying the items on the shopping recommender/list tab"
]
},
{
"cell_type": "code",
"metadata": {
"id": "cMHGrTwNMQOw"
},
"source": [
"def display_shopping_list():\n",
"    ''' Function that builds the content of the shopping list tab: a description row\n",
"    followed by a search/recommendation column and a shopping list column '''\n",
"    description_row = html.Div(className='row', children=[\n",
"        html.Div(className='col', children=[\n",
"            html.P(\"Description: Search for items to start building your shopping list. \"\n",
"                   \" Once you've selected an item, similar items will be recommended to you\")\n",
"        ])\n",
"    ])\n",
"\n",
"    # Search box, choice of the recommendation method and the explainer text\n",
"    search_controls = html.Div(className='row', children=[\n",
"        html.H3(\"Search items\"),\n",
"        dcc.Input(\n",
"            id='item_search',\n",
"            type='search',\n",
"            placeholder='Search Shopping Items',\n",
"            debounce=True,\n",
"            value='Pears'\n",
"        ),\n",
"        html.P(\"\"),\n",
"        # Radio items to choose how the recommendations are made\n",
"        html.P(\"Method for making recommendations: \"),\n",
"        dcc.RadioItems(id='sim-radio',\n",
"                       options=[{'label': 'Similar', 'value': 'similar'},\n",
"                                {'label': 'Neighbours', 'value': 'neighbours'}],\n",
"                       value='similar',\n",
"                       labelStyle={'display': 'inline-block'}),\n",
"        html.P(\"\"),\n",
"        html.P(id='explainer',\n",
"               children=[\"The items below are the closest that match your search\"],\n",
"               style={'font-weight': 'bold'}),\n",
"        html.P(\"\")\n",
"    ])\n",
"\n",
"    # Container that loads the items closest to what you searched for or\n",
"    # recommended items\n",
"    button_container = html.Div(className='row', id='button-container', children=[])\n",
"\n",
"    shopping_list_title = html.Div(className='row', children=[\n",
"        html.Div(className='six columns', children=[html.H3(\"Your Shopping List\")])\n",
"    ])\n",
"    # Dynamic shopping list is built here\n",
"    shopping_list_container = html.Div(className='row', id='shopping-list-container', children=[])\n",
"\n",
"    return html.Div([\n",
"        description_row,\n",
"        html.Div(className='row', children=[\n",
"            html.Div(className='six columns', children=[search_controls, button_container]),\n",
"            html.Div(className='six columns', children=[shopping_list_title, shopping_list_container])\n",
"        ])\n",
"    ])"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "qdHO4M99Mbcy"
},
"source": [
"### Callback for displaying similar/neighbor items (tab 2)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "5gb97cDeMism"
},
"source": [
"@app.callback(\n",
" Output('button-container', 'children'),\n",
" Output('shopping-list-container', 'children'),\n",
" Output('explainer', 'children'),\n",
" [Input('item_search', 'value'),\n",
" Input({'type': 'button', 'index': ALL}, 'n_clicks'),\n",
" Input('sim-radio', 'value')],\n",
" [State('button-container', 'children'),\n",
" State('shopping-list-container', 'children'),\n",
" State('explainer', 'children')]\n",
")\n",
"def display_search_buttons(item, vals, sim_val, buttons, shopping_list_items, explainer_text):\n",
" ''' Function that runs all of the updates on for the shopping list'''\n",
" ctx = dash.callback_context\n",
" # Make the text in the same format as the times\n",
" item = item.title()\n",
"\n",
" # If something has been triggered and it's the page loading or it's an item searched for\n",
" # then load the items as per what was searched for\n",
" if ctx.triggered is not None and \\\n",
" ctx.triggered[0]['prop_id'] == '.' or \\\n",
" ctx.triggered[0]['prop_id'] == 'item_search.value':\n",
" buttons = []\n",
" # Search for shopping items based on the item typed in the search bar\n",
" # and create buttons\n",
" shopping_items = find_ingredient(nodes, item)\n",
" # Sort the items by length to make the display look nicer\n",
" shopping_items.sort(key=len)\n",
" counter = 0\n",
" for i, it in enumerate(shopping_items):\n",
" new_button = html.Button(\n",
" f'{it}',\n",
" id={'type': 'button',\n",
" 'index': it\n",
" },\n",
" n_clicks=0,\n",
" style=button_style\n",
" )\n",
" counter += 1\n",
" buttons.append(new_button)\n",
" # Stop too many items from being added\n",
" if counter > 20:\n",
" break\n",
" explainer_text = r\"\"\"The items below are the closest that match your search\"\"\"\n",
"\n",
" # Check a button was clicked & that it's at least the first click (no auto clicks when the page loads)\n",
" # & it's not the search input being searched in and it's not the radio button being checked\n",
" elif ctx.triggered and ctx.triggered[0]['value'] != 0 and \\\n",
" ctx.triggered[0]['value'] is not None and \\\n",
" ctx.triggered[0]['prop_id'] != 'item_search.value' and \\\n",
" ctx.triggered[0]['prop_id'] != 'sim-radio.value':\n",
"\n",
" # Get the name of the grocery item\n",
" button_clicked = re.findall(r':\"(.*?)\"', ctx.triggered[0]['prop_id'])[0]\n",
" # track the button clicked for the next elif\n",
" global BUTTON_CLICKED\n",
" BUTTON_CLICKED = button_clicked\n",
" # Add it to the shopping list\n",
" new_item = html.P(\n",
" f'{button_clicked}'\n",
" )\n",
" shopping_list_items.append(new_item)\n",
"\n",
" # Erase the list of ingredients and present similar ingredients by searching the graph\n",
" buttons, explainer_text = recommend_groceries(button_clicked, sim_val)\n",
"\n",
" # Check if someone does something and if that something is changing the value on the\n",
" # similarity measure radio button and\n",
" elif ctx.triggered is not None and \\\n",
" BUTTON_CLICKED is not None and \\\n",
" ctx.triggered[0]['prop_id'] == 'sim-radio.value':\n",
" buttons, explainer_text = recommend_groceries(BUTTON_CLICKED, sim_val)\n",
"\n",
" return buttons, shopping_list_items, explainer_text"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "2m5kNGehMt7T"
},
"source": [
"### Function that returns a list of recommended groceries"
]
},
{
"cell_type": "code",
"metadata": {
"id": "3R_ePlVtMsCH"
},
"source": [
"def recommend_groceries(button_clicked, sim_val):\n",
" ''' Function that returns a list of recommended groceries\n",
" Args: button_clicked - the button (item) clicked by the user\n",
" sim_val - the type of similarity the user wants for recommendations\n",
" either 'similar' or 'neighbours'\n",
" Returns: a list of recommended items for the user\n",
" '''\n",
" buttons = []\n",
" # Get recommendations based on the similarity method chosen by the user\n",
" if sim_val == \"similar\":\n",
" recommendations = similar_embeddings(button_clicked, 10)\n",
" else:\n",
" recommendations = get_neighbours(G_init, button_clicked)\n",
" # Update the explainer so the user knows what's going on\n",
" explainer_text = r\"\"\"These items are recommended based on the last item added to your basket\"\"\"\n",
"\n",
" # Stop a system-hanging number of recommendations being added\n",
" if len(recommendations) > 20:\n",
" recommendations = recommendations[:20]\n",
" # Sort the recommendations based on how many words each comprises; this makes the display\n",
" # look nicer\n",
" recommendations.sort(key=len)\n",
" # Add the recommended items to the buttons for adding to the shopping list\n",
" for i, it in enumerate(recommendations):\n",
" new_button = html.Button(\n",
" f'{it}',\n",
" id={'type': 'button',\n",
" 'index': it\n",
" },\n",
" n_clicks=0,\n",
" style=button_style\n",
" )\n",
" buttons.append(new_button)\n",
" return buttons, explainer_text"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "hJVC9bZ0Mw1E"
},
"source": [
"### Callback for updating the graph network\n",
"\n",
"> Note: Graph display will update when the user changes the segment size"
]
},
{
"cell_type": "code",
"metadata": {
"id": "enwvE-sGIobl"
},
"source": [
"@app.callback(\n",
" Output('graph-graphic', 'figure'),\n",
" [Input('dd_segment', 'value')]\n",
")\n",
"def update_graph(segment):\n",
" ''' Function to load a pre-computed graph network based on the segment selected by\n",
" the user in the dropdown\n",
" Args: segment - the segment selected by the user in the dropdown\n",
" '''\n",
" # Load the graph data and create the nodes and elements needed for display\n",
" pos, G = load_graph(segment=segment)\n",
"\n",
" node_trace, edge_trace = create_graph_display(G, pos)\n",
" # Display the graph\n",
" fig = go.Figure(data=[edge_trace, node_trace],\n",
" layout=go.Layout(\n",
" title='',\n",
" titlefont_size=16,\n",
" showlegend=False,\n",
" hovermode='closest',\n",
" margin=dict(b=20, l=5, r=5, t=40),\n",
" annotations=[dict(\n",
" text=\"\",\n",
" showarrow=False,\n",
" xref=\"paper\", yref=\"paper\",\n",
" x=0.005, y=-0.002)],\n",
" xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),\n",
" yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))\n",
" )\n",
"\n",
" return fig"
],
"execution_count": 88,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"id": "eFPsZbX4KARM",
"outputId": "297dd5f2-13d5-4b19-8d77-062c78cb8656"
},
"source": [
"app.run_server(mode='external')"
],
"execution_count": 89,
"outputs": [
{
"output_type": "stream",
"text": [
"Dash app running on:\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/javascript": [
"(async (port, path, text, element) => {\n",
" if (!google.colab.kernel.accessAllowed) {\n",
" return;\n",
" }\n",
" element.appendChild(document.createTextNode(''));\n",
" const url = await google.colab.kernel.proxyPort(port);\n",
" const anchor = document.createElement('a');\n",
" anchor.href = new URL(path, url).toString();\n",
" anchor.target = '_blank';\n",
" anchor.setAttribute('data-href', url + path);\n",
" anchor.textContent = text;\n",
" element.appendChild(anchor);\n",
" })(8050, \"/\", \"http://127.0.0.1:8050/\", window.element)"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uYilo8-GNwKC"
},
"source": [
"### Analyzing callback map\n",
"\n",
"This is retrieved from the dash layout"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jxdxHzufN4k0"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "geSxcD6AO-Xk"
},
"source": [
"### Tab 1 visual"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CTvMQCszPERo"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dQg9jcIOPoBx"
},
"source": [
"### Tab 2 visual"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_S3pls0HN365"
},
"source": [
""
]
}
]
}