浏览代码

feat: collab component state handling rewrite & fixes (#5046)

David Luzar 3 年之前
父节点
当前提交
dac8dda4d4

+ 4 - 2
src/components/CollabButton.scss

@@ -18,13 +18,15 @@
       left: -5px;
     }
     min-width: 1em;
+    min-height: 1em;
+    line-height: 1;
     position: absolute;
     bottom: -5px;
     padding: 3px;
     border-radius: 50%;
     background-color: $oc-green-6;
     color: $oc-white;
-    font-size: 0.7em;
-    font-family: var(--ui-font);
+    font-size: 0.6em;
+    font-family: "Cascadia";
   }
 }

+ 1 - 1
src/components/CollabButton.tsx

@@ -28,7 +28,7 @@ const CollabButton = ({
         aria-label={t("labels.liveCollaboration")}
         showAriaLabel={useDevice().isMobile}
       >
-        {collaboratorCount > 0 && (
+        {isCollaborating && (
           <div className="CollabButton-collaborators">{collaboratorCount}</div>
         )}
       </ToolButton>

+ 0 - 42
src/createInverseContext.tsx

@@ -1,42 +0,0 @@
-import React from "react";
-
-export const createInverseContext = <T extends unknown = null>(
-  initialValue: T,
-) => {
-  const Context = React.createContext(initialValue) as React.Context<T> & {
-    _updateProviderValue?: (value: T) => void;
-  };
-
-  class InverseConsumer extends React.Component {
-    state = { value: initialValue };
-    constructor(props: any) {
-      super(props);
-      Context._updateProviderValue = (value: T) => this.setState({ value });
-    }
-    render() {
-      return (
-        <Context.Provider value={this.state.value}>
-          {this.props.children}
-        </Context.Provider>
-      );
-    }
-  }
-
-  class InverseProvider extends React.Component<{ value: T }> {
-    componentDidMount() {
-      Context._updateProviderValue?.(this.props.value);
-    }
-    componentDidUpdate() {
-      Context._updateProviderValue?.(this.props.value);
-    }
-    render() {
-      return <Context.Consumer>{() => this.props.children}</Context.Consumer>;
-    }
-  }
-
-  return {
-    Context,
-    Consumer: InverseConsumer,
-    Provider: InverseProvider,
-  };
-};

+ 121 - 112
src/excalidraw-app/collab/CollabWrapper.tsx → src/excalidraw-app/collab/Collab.tsx

@@ -8,10 +8,12 @@ import {
   ExcalidrawElement,
   InitializedExcalidrawImageElement,
 } from "../../element/types";
-import { getSceneVersion } from "../../packages/excalidraw/index";
+import {
+  getSceneVersion,
+  restoreElements,
+} from "../../packages/excalidraw/index";
 import { Collaborator, Gesture } from "../../types";
 import {
-  getFrame,
   preventUnload,
   resolvablePromise,
   withBatchedUpdates,
@@ -47,11 +49,9 @@ import {
 } from "../data/localStorage";
 import Portal from "./Portal";
 import RoomDialog from "./RoomDialog";
-import { createInverseContext } from "../../createInverseContext";
 import { t } from "../../i18n";
 import { UserIdleState } from "../../types";
 import { IDLE_THRESHOLD, ACTIVE_THRESHOLD } from "../../constants";
-import { trackEvent } from "../../analytics";
 import {
   encodeFilesForUpload,
   FileManager,
@@ -70,52 +70,45 @@ import {
 import { decryptData } from "../../data/encryption";
 import { resetBrowserStateVersions } from "../data/tabSync";
 import { LocalData } from "../data/LocalData";
+import { atom, useAtom } from "jotai";
+import { jotaiStore } from "../../jotai";
+
+export const collabAPIAtom = atom<CollabAPI | null>(null);
+export const collabDialogShownAtom = atom(false);
+export const isCollaboratingAtom = atom(false);
 
 interface CollabState {
-  modalIsShown: boolean;
   errorMessage: string;
   username: string;
-  userState: UserIdleState;
   activeRoomLink: string;
 }
 
-type CollabInstance = InstanceType<typeof CollabWrapper>;
+type CollabInstance = InstanceType<typeof Collab>;
 
 export interface CollabAPI {
   /** function so that we can access the latest value from stale callbacks */
   isCollaborating: () => boolean;
-  username: CollabState["username"];
-  userState: CollabState["userState"];
   onPointerUpdate: CollabInstance["onPointerUpdate"];
-  initializeSocketClient: CollabInstance["initializeSocketClient"];
-  onCollabButtonClick: CollabInstance["onCollabButtonClick"];
+  startCollaboration: CollabInstance["startCollaboration"];
+  stopCollaboration: CollabInstance["stopCollaboration"];
   syncElements: CollabInstance["syncElements"];
   fetchImageFilesFromFirebase: CollabInstance["fetchImageFilesFromFirebase"];
   setUsername: (username: string) => void;
 }
 
-interface Props {
+interface PublicProps {
   excalidrawAPI: ExcalidrawImperativeAPI;
-  onRoomClose?: () => void;
 }
 
-const {
-  Context: CollabContext,
-  Consumer: CollabContextConsumer,
-  Provider: CollabContextProvider,
-} = createInverseContext<{ api: CollabAPI | null }>({ api: null });
-
-export { CollabContext, CollabContextConsumer };
+type Props = PublicProps & { modalIsShown: boolean };
 
-class CollabWrapper extends PureComponent<Props, CollabState> {
+class Collab extends PureComponent<Props, CollabState> {
   portal: Portal;
   fileManager: FileManager;
   excalidrawAPI: Props["excalidrawAPI"];
   activeIntervalId: number | null;
   idleTimeoutId: number | null;
 
-  // marked as private to ensure we don't change it outside this class
-  private _isCollaborating: boolean = false;
   private socketInitializationTimer?: number;
   private lastBroadcastedOrReceivedSceneVersion: number = -1;
   private collaborators = new Map<string, Collaborator>();
@@ -123,10 +116,8 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
   constructor(props: Props) {
     super(props);
     this.state = {
-      modalIsShown: false,
       errorMessage: "",
       username: importUsernameFromLocalStorage() || "",
-      userState: UserIdleState.ACTIVE,
       activeRoomLink: "",
     };
     this.portal = new Portal(this);
@@ -164,6 +155,18 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
     window.addEventListener(EVENT.BEFORE_UNLOAD, this.beforeUnload);
     window.addEventListener(EVENT.UNLOAD, this.onUnload);
 
+    const collabAPI: CollabAPI = {
+      isCollaborating: this.isCollaborating,
+      onPointerUpdate: this.onPointerUpdate,
+      startCollaboration: this.startCollaboration,
+      syncElements: this.syncElements,
+      fetchImageFilesFromFirebase: this.fetchImageFilesFromFirebase,
+      stopCollaboration: this.stopCollaboration,
+      setUsername: this.setUsername,
+    };
+
+    jotaiStore.set(collabAPIAtom, collabAPI);
+
     if (
       process.env.NODE_ENV === ENV.TEST ||
       process.env.NODE_ENV === ENV.DEVELOPMENT
@@ -196,7 +199,11 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
     }
   }
 
-  isCollaborating = () => this._isCollaborating;
+  isCollaborating = () => jotaiStore.get(isCollaboratingAtom)!;
+
+  private setIsCollaborating = (isCollaborating: boolean) => {
+    jotaiStore.set(isCollaboratingAtom, isCollaborating);
+  };
 
   private onUnload = () => {
     this.destroySocketClient({ isUnload: true });
@@ -208,7 +215,7 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
     );
 
     if (
-      this._isCollaborating &&
+      this.isCollaborating() &&
       (this.fileManager.shouldPreventUnload(syncableElements) ||
         !isSavedToFirebase(this.portal, syncableElements))
     ) {
@@ -252,12 +259,7 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
     }
   };
 
-  openPortal = async () => {
-    trackEvent("share", "room creation", `ui (${getFrame()})`);
-    return this.initializeSocketClient(null);
-  };
-
-  closePortal = () => {
+  stopCollaboration = (keepRemoteState = true) => {
     this.queueBroadcastAllElements.cancel();
     this.queueSaveToFirebase.cancel();
     this.loadImageFiles.cancel();
@@ -267,16 +269,26 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
         this.excalidrawAPI.getSceneElementsIncludingDeleted(),
       ),
     );
-    if (window.confirm(t("alerts.collabStopOverridePrompt"))) {
+
+    if (this.portal.socket && this.fallbackInitializationHandler) {
+      this.portal.socket.off(
+        "connect_error",
+        this.fallbackInitializationHandler,
+      );
+    }
+
+    if (!keepRemoteState) {
+      LocalData.fileStorage.reset();
+      this.destroySocketClient();
+    } else if (window.confirm(t("alerts.collabStopOverridePrompt"))) {
       // hack to ensure that we prefer we disregard any new browser state
       // that could have been saved in other tabs while we were collaborating
       resetBrowserStateVersions();
 
       window.history.pushState({}, APP_NAME, window.location.origin);
       this.destroySocketClient();
-      trackEvent("share", "room closed");
 
-      this.props.onRoomClose?.();
+      LocalData.fileStorage.reset();
 
       const elements = this.excalidrawAPI
         .getSceneElementsIncludingDeleted()
@@ -295,20 +307,20 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
   };
 
   private destroySocketClient = (opts?: { isUnload: boolean }) => {
+    this.lastBroadcastedOrReceivedSceneVersion = -1;
+    this.portal.close();
+    this.fileManager.reset();
     if (!opts?.isUnload) {
+      this.setIsCollaborating(false);
+      this.setState({
+        activeRoomLink: "",
+      });
       this.collaborators = new Map();
       this.excalidrawAPI.updateScene({
         collaborators: this.collaborators,
       });
-      this.setState({
-        activeRoomLink: "",
-      });
-      this._isCollaborating = false;
       LocalData.resumeSave("collaboration");
     }
-    this.lastBroadcastedOrReceivedSceneVersion = -1;
-    this.portal.close();
-    this.fileManager.reset();
   };
 
   private fetchImageFilesFromFirebase = async (scene: {
@@ -349,7 +361,9 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
     }
   };
 
-  private initializeSocketClient = async (
+  private fallbackInitializationHandler: null | (() => any) = null;
+
+  startCollaboration = async (
     existingRoomLinkData: null | { roomId: string; roomKey: string },
   ): Promise<ImportedDataState | null> => {
     if (this.portal.socket) {
@@ -372,13 +386,23 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
 
     const scenePromise = resolvablePromise<ImportedDataState | null>();
 
-    this._isCollaborating = true;
+    this.setIsCollaborating(true);
     LocalData.pauseSave("collaboration");
 
     const { default: socketIOClient } = await import(
       /* webpackChunkName: "socketIoClient" */ "socket.io-client"
     );
 
+    const fallbackInitializationHandler = () => {
+      this.initializeRoom({
+        roomLinkData: existingRoomLinkData,
+        fetchScene: true,
+      }).then((scene) => {
+        scenePromise.resolve(scene);
+      });
+    };
+    this.fallbackInitializationHandler = fallbackInitializationHandler;
+
     try {
       const socketServerData = await getCollabServer();
 
@@ -391,6 +415,8 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
         roomId,
         roomKey,
       );
+
+      this.portal.socket.once("connect_error", fallbackInitializationHandler);
     } catch (error: any) {
       console.error(error);
       this.setState({ errorMessage: error.message });
@@ -419,13 +445,10 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
 
     // fallback in case you're not alone in the room but still don't receive
     // initial SCENE_INIT message
-    this.socketInitializationTimer = window.setTimeout(() => {
-      this.initializeRoom({
-        roomLinkData: existingRoomLinkData,
-        fetchScene: true,
-      });
-      scenePromise.resolve(null);
-    }, INITIAL_SCENE_UPDATE_TIMEOUT);
+    this.socketInitializationTimer = window.setTimeout(
+      fallbackInitializationHandler,
+      INITIAL_SCENE_UPDATE_TIMEOUT,
+    );
 
     // All socket listeners are moving to Portal
     this.portal.socket.on(
@@ -530,6 +553,12 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
       }
     | { fetchScene: false; roomLinkData?: null }) => {
     clearTimeout(this.socketInitializationTimer!);
+    if (this.portal.socket && this.fallbackInitializationHandler) {
+      this.portal.socket.off(
+        "connect_error",
+        this.fallbackInitializationHandler,
+      );
+    }
     if (fetchScene && roomLinkData && this.portal.socket) {
       this.excalidrawAPI.resetScene();
 
@@ -567,6 +596,8 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
     const localElements = this.getSceneElementsIncludingDeleted();
     const appState = this.excalidrawAPI.getAppState();
 
+    remoteElements = restoreElements(remoteElements, null);
+
     const reconciledElements = _reconcileElements(
       localElements,
       remoteElements,
@@ -672,19 +703,17 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
   };
 
   setCollaborators(sockets: string[]) {
-    this.setState((state) => {
-      const collaborators: InstanceType<typeof CollabWrapper>["collaborators"] =
-        new Map();
-      for (const socketId of sockets) {
-        if (this.collaborators.has(socketId)) {
-          collaborators.set(socketId, this.collaborators.get(socketId)!);
-        } else {
-          collaborators.set(socketId, {});
-        }
+    const collaborators: InstanceType<typeof Collab>["collaborators"] =
+      new Map();
+    for (const socketId of sockets) {
+      if (this.collaborators.has(socketId)) {
+        collaborators.set(socketId, this.collaborators.get(socketId)!);
+      } else {
+        collaborators.set(socketId, {});
       }
-      this.collaborators = collaborators;
-      this.excalidrawAPI.updateScene({ collaborators });
-    });
+    }
+    this.collaborators = collaborators;
+    this.excalidrawAPI.updateScene({ collaborators });
   }
 
   public setLastBroadcastedOrReceivedSceneVersion = (version: number) => {
@@ -713,7 +742,6 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
   );
 
   onIdleStateChange = (userState: UserIdleState) => {
-    this.setState({ userState });
     this.portal.broadcastIdleChange(userState);
   };
 
@@ -747,18 +775,22 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
     this.setLastBroadcastedOrReceivedSceneVersion(newVersion);
   }, SYNC_FULL_SCENE_INTERVAL_MS);
 
-  queueSaveToFirebase = throttle(() => {
-    if (this.portal.socketInitialized) {
-      this.saveCollabRoomToFirebase(
-        getSyncableElements(
-          this.excalidrawAPI.getSceneElementsIncludingDeleted(),
-        ),
-      );
-    }
-  }, SYNC_FULL_SCENE_INTERVAL_MS);
+  queueSaveToFirebase = throttle(
+    () => {
+      if (this.portal.socketInitialized) {
+        this.saveCollabRoomToFirebase(
+          getSyncableElements(
+            this.excalidrawAPI.getSceneElementsIncludingDeleted(),
+          ),
+        );
+      }
+    },
+    SYNC_FULL_SCENE_INTERVAL_MS,
+    { leading: false },
+  );
 
   handleClose = () => {
-    this.setState({ modalIsShown: false });
+    jotaiStore.set(collabDialogShownAtom, false);
   };
 
   setUsername = (username: string) => {
@@ -770,35 +802,10 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
     saveUsernameToLocalStorage(username);
   };
 
-  onCollabButtonClick = () => {
-    this.setState({
-      modalIsShown: true,
-    });
-  };
-
-  /** PRIVATE. Use `this.getContextValue()` instead. */
-  private contextValue: CollabAPI | null = null;
-
-  /** Getter of context value. Returned object is stable. */
-  getContextValue = (): CollabAPI => {
-    if (!this.contextValue) {
-      this.contextValue = {} as CollabAPI;
-    }
-
-    this.contextValue.isCollaborating = this.isCollaborating;
-    this.contextValue.username = this.state.username;
-    this.contextValue.onPointerUpdate = this.onPointerUpdate;
-    this.contextValue.initializeSocketClient = this.initializeSocketClient;
-    this.contextValue.onCollabButtonClick = this.onCollabButtonClick;
-    this.contextValue.syncElements = this.syncElements;
-    this.contextValue.fetchImageFilesFromFirebase =
-      this.fetchImageFilesFromFirebase;
-    this.contextValue.setUsername = this.setUsername;
-    return this.contextValue;
-  };
-
   render() {
-    const { modalIsShown, username, errorMessage, activeRoomLink } = this.state;
+    const { username, errorMessage, activeRoomLink } = this.state;
+
+    const { modalIsShown } = this.props;
 
     return (
       <>
@@ -808,8 +815,8 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
             activeRoomLink={activeRoomLink}
             username={username}
             onUsernameChange={this.onUsernameChange}
-            onRoomCreate={this.openPortal}
-            onRoomDestroy={this.closePortal}
+            onRoomCreate={() => this.startCollaboration(null)}
+            onRoomDestroy={this.stopCollaboration}
             setErrorMessage={(errorMessage) => {
               this.setState({ errorMessage });
             }}
@@ -822,11 +829,6 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
             onClose={() => this.setState({ errorMessage: "" })}
           />
         )}
-        <CollabContextProvider
-          value={{
-            api: this.getContextValue(),
-          }}
-        />
       </>
     );
   }
@@ -834,7 +836,7 @@ class CollabWrapper extends PureComponent<Props, CollabState> {
 
 declare global {
   interface Window {
-    collab: InstanceType<typeof CollabWrapper>;
+    collab: InstanceType<typeof Collab>;
   }
 }
 
@@ -845,4 +847,11 @@ if (
   window.collab = window.collab || ({} as Window["collab"]);
 }
 
-export default CollabWrapper;
+const _Collab: React.FC<PublicProps> = (props) => {
+  const [collabDialogShown] = useAtom(collabDialogShownAtom);
+  return <Collab {...props} modalIsShown={collabDialogShown} />;
+};
+
+export default _Collab;
+
+export type TCollabClass = Collab;

+ 3 - 3
src/excalidraw-app/collab/Portal.tsx

@@ -4,7 +4,7 @@ import {
   SocketUpdateDataSource,
 } from "../data";
 
-import CollabWrapper from "./CollabWrapper";
+import { TCollabClass } from "./Collab";
 
 import { ExcalidrawElement } from "../../element/types";
 import {
@@ -20,14 +20,14 @@ import { BroadcastedExcalidrawElement } from "./reconciliation";
 import { encryptData } from "../../data/encryption";
 
 class Portal {
-  collab: CollabWrapper;
+  collab: TCollabClass;
   socket: SocketIOClient.Socket | null = null;
   socketInitialized: boolean = false; // we don't want the socket to emit any updates until it is fully initialized
   roomId: string | null = null;
   roomKey: string | null = null;
   broadcastedElementVersions: Map<string, number> = new Map();
 
-  constructor(collab: CollabWrapper) {
+  constructor(collab: TCollabClass) {
     this.collab = collab;
   }
 

+ 10 - 2
src/excalidraw-app/collab/RoomDialog.tsx

@@ -14,6 +14,8 @@ import { t } from "../../i18n";
 import "./RoomDialog.scss";
 import Stack from "../../components/Stack";
 import { AppState } from "../../types";
+import { trackEvent } from "../../analytics";
+import { getFrame } from "../../utils";
 
 const getShareIcon = () => {
   const navigator = window.navigator as any;
@@ -95,7 +97,10 @@ const RoomDialog = ({
                 title={t("roomDialog.button_startSession")}
                 aria-label={t("roomDialog.button_startSession")}
                 showAriaLabel={true}
-                onClick={onRoomCreate}
+                onClick={() => {
+                  trackEvent("share", "room creation", `ui (${getFrame()})`);
+                  onRoomCreate();
+                }}
               />
             </div>
           </>
@@ -160,7 +165,10 @@ const RoomDialog = ({
                 title={t("roomDialog.button_stopSession")}
                 aria-label={t("roomDialog.button_stopSession")}
                 showAriaLabel={true}
-                onClick={onRoomDestroy}
+                onClick={() => {
+                  trackEvent("share", "room closed");
+                  onRoomDestroy();
+                }}
               />
             </div>
           </>

+ 8 - 1
src/excalidraw-app/data/index.ts

@@ -134,9 +134,16 @@ export type SocketUpdateData =
     _brand: "socketUpdateData";
   };
 
+const RE_COLLAB_LINK = /^#room=([a-zA-Z0-9_-]+),([a-zA-Z0-9_-]+)$/;
+
+export const isCollaborationLink = (link: string) => {
+  const hash = new URL(link).hash;
+  return RE_COLLAB_LINK.test(hash);
+};
+
 export const getCollaborationLinkData = (link: string) => {
   const hash = new URL(link).hash;
-  const match = hash.match(/^#room=([a-zA-Z0-9_-]+),([a-zA-Z0-9_-]+)$/);
+  const match = hash.match(RE_COLLAB_LINK);
   if (match && match[2].length !== 22) {
     window.alert(t("alerts.invalidEncryptionKey"));
     return null;

+ 54 - 27
src/excalidraw-app/index.tsx

@@ -1,5 +1,5 @@
 import LanguageDetector from "i18next-browser-languagedetector";
-import { useCallback, useContext, useEffect, useRef, useState } from "react";
+import { useCallback, useEffect, useRef, useState } from "react";
 import { trackEvent } from "../analytics";
 import { getDefaultAppState } from "../appState";
 import { ErrorDialog } from "../components/ErrorDialog";
@@ -45,20 +45,26 @@ import {
   STORAGE_KEYS,
   SYNC_BROWSER_TABS_TIMEOUT,
 } from "./app_constants";
-import CollabWrapper, {
+import Collab, {
   CollabAPI,
-  CollabContext,
-  CollabContextConsumer,
-} from "./collab/CollabWrapper";
+  collabAPIAtom,
+  collabDialogShownAtom,
+  isCollaboratingAtom,
+} from "./collab/Collab";
 import { LanguageList } from "./components/LanguageList";
-import { exportToBackend, getCollaborationLinkData, loadScene } from "./data";
+import {
+  exportToBackend,
+  getCollaborationLinkData,
+  isCollaborationLink,
+  loadScene,
+} from "./data";
 import {
   getLibraryItemsFromStorage,
   importFromLocalStorage,
   importUsernameFromLocalStorage,
 } from "./data/localStorage";
 import CustomStats from "./CustomStats";
-import { restoreAppState, RestoredDataState } from "../data/restore";
+import { restore, restoreAppState, RestoredDataState } from "../data/restore";
 import { Tooltip } from "../components/Tooltip";
 import { shield } from "../components/icons";
 
@@ -72,6 +78,9 @@ import { loadFilesFromFirebase } from "./data/firebase";
 import { LocalData } from "./data/LocalData";
 import { isBrowserStorageStateNewer } from "./data/tabSync";
 import clsx from "clsx";
+import { Provider, useAtom } from "jotai";
+import { jotaiStore, useAtomWithInitialValue } from "../jotai";
+import { reconcileElements } from "./collab/reconciliation";
 import { parseLibraryTokensFromUrl, useHandleLibrary } from "../data/library";
 
 const isExcalidrawPlusSignedUser = document.cookie.includes(
@@ -170,7 +179,7 @@ const initializeScene = async (opts: {
 
   if (roomLinkData) {
     return {
-      scene: await opts.collabAPI.initializeSocketClient(roomLinkData),
+      scene: await opts.collabAPI.startCollaboration(roomLinkData),
       isExternalScene: true,
       id: roomLinkData.roomId,
       key: roomLinkData.roomKey,
@@ -242,7 +251,11 @@ const ExcalidrawWrapper = () => {
   const [excalidrawAPI, excalidrawRefCallback] =
     useCallbackRefState<ExcalidrawImperativeAPI>();
 
-  const collabAPI = useContext(CollabContext)?.api;
+  const [collabAPI] = useAtom(collabAPIAtom);
+  const [, setCollabDialogShown] = useAtom(collabDialogShownAtom);
+  const [isCollaborating] = useAtomWithInitialValue(isCollaboratingAtom, () => {
+    return isCollaborationLink(window.location.href);
+  });
 
   useHandleLibrary({
     excalidrawAPI,
@@ -320,21 +333,44 @@ const ExcalidrawWrapper = () => {
       }
     };
 
-    initializeScene({ collabAPI }).then((data) => {
+    initializeScene({ collabAPI }).then(async (data) => {
       loadImages(data, /* isInitialLoad */ true);
-      initialStatePromiseRef.current.promise.resolve(data.scene);
+
+      initialStatePromiseRef.current.promise.resolve({
+        ...data.scene,
+        // at this point the state may have already been updated (e.g. when
+        // collaborating, we may have received updates from other clients)
+        appState: restoreAppState(
+          data.scene?.appState,
+          excalidrawAPI.getAppState(),
+        ),
+        elements: reconcileElements(
+          data.scene?.elements || [],
+          excalidrawAPI.getSceneElementsIncludingDeleted(),
+          excalidrawAPI.getAppState(),
+        ),
+      });
     });
 
     const onHashChange = async (event: HashChangeEvent) => {
       event.preventDefault();
       const libraryUrlTokens = parseLibraryTokensFromUrl();
       if (!libraryUrlTokens) {
+        if (
+          collabAPI.isCollaborating() &&
+          !isCollaborationLink(window.location.href)
+        ) {
+          collabAPI.stopCollaboration(false);
+        }
+        excalidrawAPI.updateScene({ appState: { isLoading: true } });
+
         initializeScene({ collabAPI }).then((data) => {
           loadImages(data);
           if (data.scene) {
             excalidrawAPI.updateScene({
               ...data.scene,
-              appState: restoreAppState(data.scene.appState, null),
+              ...restore(data.scene, null, null),
+              commitToHistory: true,
             });
           }
         });
@@ -636,23 +672,19 @@ const ExcalidrawWrapper = () => {
     localStorage.setItem(STORAGE_KEYS.LOCAL_STORAGE_LIBRARY, serializedItems);
   };
 
-  const onRoomClose = useCallback(() => {
-    LocalData.fileStorage.reset();
-  }, []);
-
   return (
     <div
       style={{ height: "100%" }}
       className={clsx("excalidraw-app", {
-        "is-collaborating": collabAPI?.isCollaborating(),
+        "is-collaborating": isCollaborating,
       })}
     >
       <Excalidraw
         ref={excalidrawRefCallback}
         onChange={onChange}
         initialData={initialStatePromiseRef.current.promise}
-        onCollabButtonClick={collabAPI?.onCollabButtonClick}
-        isCollaborating={collabAPI?.isCollaborating()}
+        onCollabButtonClick={() => setCollabDialogShown(true)}
+        isCollaborating={isCollaborating}
         onPointerUpdate={collabAPI?.onPointerUpdate}
         UIOptions={{
           canvasActions: {
@@ -686,12 +718,7 @@ const ExcalidrawWrapper = () => {
         onLibraryChange={onLibraryChange}
         autoFocus={true}
       />
-      {excalidrawAPI && (
-        <CollabWrapper
-          excalidrawAPI={excalidrawAPI}
-          onRoomClose={onRoomClose}
-        />
-      )}
+      {excalidrawAPI && <Collab excalidrawAPI={excalidrawAPI} />}
       {errorMessage && (
         <ErrorDialog
           message={errorMessage}
@@ -705,9 +732,9 @@ const ExcalidrawWrapper = () => {
 const ExcalidrawApp = () => {
   return (
     <TopErrorBoundary>
-      <CollabContextConsumer>
+      <Provider unstable_createStore={() => jotaiStore}>
         <ExcalidrawWrapper />
-      </CollabContextConsumer>
+      </Provider>
     </TopErrorBoundary>
   );
 };

+ 24 - 1
src/jotai.ts

@@ -1,4 +1,27 @@
-import { unstable_createStore } from "jotai";
+import { unstable_createStore, useAtom, WritableAtom } from "jotai";
+import { useLayoutEffect } from "react";
 
 export const jotaiScope = Symbol();
 export const jotaiStore = unstable_createStore();
+
+export const useAtomWithInitialValue = <
+  T extends unknown,
+  A extends WritableAtom<T, T>,
+>(
+  atom: A,
+  initialValue: T | (() => T),
+) => {
+  const [value, setValue] = useAtom(atom);
+
+  useLayoutEffect(() => {
+    if (typeof initialValue === "function") {
+      // @ts-ignore
+      setValue(initialValue());
+    } else {
+      setValue(initialValue);
+    }
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, []);
+
+  return [value, setValue] as const;
+};

+ 2 - 1
src/tests/collab.test.tsx

@@ -50,6 +50,7 @@ jest.mock("socket.io-client", () => {
     return {
       close: () => {},
       on: () => {},
+      once: () => {},
       off: () => {},
       emit: () => {},
     };
@@ -77,7 +78,7 @@ describe("collaboration", () => {
       ]);
       expect(API.getStateHistory().length).toBe(1);
     });
-    window.collab.openPortal();
+    window.collab.startCollaboration(null);
     await waitFor(() => {
       expect(h.elements).toEqual([expect.objectContaining({ id: "A" })]);
       expect(API.getStateHistory().length).toBe(1);