Ver Fonte

improve & granularize ExcalidrawElement types (#991)

* improve & granularize ExcalidrawElement types

* fix incorrectly passing type

* fix tests

* fix more tests

* fix unnecessary spreads & refactor

* add comments
David Luzar há 5 anos atrás
pai
commit
373d16abe6

+ 4 - 3
src/actions/actionDuplicateSelection.ts

@@ -11,9 +11,10 @@ export const actionDuplicateSelection = register({
       elements: elements.reduce(
         (acc: Array<ExcalidrawElement>, element: ExcalidrawElement) => {
           if (appState.selectedElementIds[element.id]) {
-            const newElement = duplicateElement(element);
-            newElement.x = newElement.x + 10;
-            newElement.y = newElement.y + 10;
+            const newElement = duplicateElement(element, {
+              x: element.x + 10,
+              y: element.y + 10,
+            });
             appState.selectedElementIds[newElement.id] = true;
             delete appState.selectedElementIds[element.id];
             return acc.concat([element, newElement]);

+ 162 - 144
src/components/App.tsx

@@ -21,6 +21,7 @@ import {
   getDrawingVersion,
   getSyncableElements,
   hasNonDeletedElements,
+  newLinearElement,
 } from "../element";
 import {
   deleteSelectedElements,
@@ -47,7 +48,7 @@ import { restore } from "../data/restore";
 
 import { renderScene } from "../renderer";
 import { AppState, GestureEvent, Gesture } from "../types";
-import { ExcalidrawElement } from "../element/types";
+import { ExcalidrawElement, ExcalidrawLinearElement } from "../element/types";
 
 import {
   isWritableElement,
@@ -99,6 +100,7 @@ import { mutateElement, newElementWith } from "../element/mutateElement";
 import { invalidateShapeForElement } from "../renderer/renderElement";
 import { unstable_batchedUpdates } from "react-dom";
 import { SceneStateCallbackRemover } from "../scene/globalScene";
+import { isLinearElement } from "../element/typeChecks";
 import { rescalePoints } from "../points";
 
 function withBatchedUpdates<
@@ -707,21 +709,18 @@ export class App extends React.Component<any, AppState> {
             window.devicePixelRatio,
           );
 
-          const element = newTextElement(
-            newElement(
-              "text",
-              x,
-              y,
-              this.state.currentItemStrokeColor,
-              this.state.currentItemBackgroundColor,
-              this.state.currentItemFillStyle,
-              this.state.currentItemStrokeWidth,
-              this.state.currentItemRoughness,
-              this.state.currentItemOpacity,
-            ),
-            data.text,
-            this.state.currentItemFont,
-          );
+          const element = newTextElement({
+            x: x,
+            y: y,
+            strokeColor: this.state.currentItemStrokeColor,
+            backgroundColor: this.state.currentItemBackgroundColor,
+            fillStyle: this.state.currentItemFillStyle,
+            strokeWidth: this.state.currentItemStrokeWidth,
+            roughness: this.state.currentItemRoughness,
+            opacity: this.state.currentItemOpacity,
+            text: data.text,
+            font: this.state.currentItemFont,
+          });
 
           globalSceneState.replaceAllElements([
             ...globalSceneState.getAllElements(),
@@ -960,21 +959,18 @@ export class App extends React.Component<any, AppState> {
     const element =
       elementAtPosition && isTextElement(elementAtPosition)
         ? elementAtPosition
-        : newTextElement(
-            newElement(
-              "text",
-              x,
-              y,
-              this.state.currentItemStrokeColor,
-              this.state.currentItemBackgroundColor,
-              this.state.currentItemFillStyle,
-              this.state.currentItemStrokeWidth,
-              this.state.currentItemRoughness,
-              this.state.currentItemOpacity,
-            ),
-            "", // default text
-            this.state.currentItemFont, // default font
-          );
+        : newTextElement({
+            x: x,
+            y: y,
+            strokeColor: this.state.currentItemStrokeColor,
+            backgroundColor: this.state.currentItemBackgroundColor,
+            fillStyle: this.state.currentItemFillStyle,
+            strokeWidth: this.state.currentItemStrokeWidth,
+            roughness: this.state.currentItemRoughness,
+            opacity: this.state.currentItemOpacity,
+            text: "",
+            font: this.state.currentItemFont,
+          });
 
     this.setState({ editingElement: element });
 
@@ -1044,11 +1040,8 @@ export class App extends React.Component<any, AppState> {
         if (text) {
           globalSceneState.replaceAllElements([
             ...globalSceneState.getAllElements(),
-            {
-              // we need to recreate the element to update dimensions &
-              //  position
-              ...newTextElement(element, text, element.font),
-            },
+            // we need to recreate the element to update dimensions & position
+            newTextElement({ ...element, text, font: element.font }),
           ]);
         }
         this.setState(prevState => ({
@@ -1332,22 +1325,6 @@ export class App extends React.Component<any, AppState> {
     const originX = x;
     const originY = y;
 
-    let element = newElement(
-      this.state.elementType,
-      x,
-      y,
-      this.state.currentItemStrokeColor,
-      this.state.currentItemBackgroundColor,
-      this.state.currentItemFillStyle,
-      this.state.currentItemStrokeWidth,
-      this.state.currentItemRoughness,
-      this.state.currentItemOpacity,
-    );
-
-    if (isTextElement(element)) {
-      element = newTextElement(element, "", this.state.currentItemFont);
-    }
-
     type ResizeTestType = ReturnType<typeof resizeTest>;
     let resizeHandle: ResizeTestType = false;
     let isResizingElements = false;
@@ -1437,30 +1414,30 @@ export class App extends React.Component<any, AppState> {
       this.setState({ selectedElementIds: {} });
     }
 
-    if (isTextElement(element)) {
+    if (this.state.elementType === "text") {
       // if we're currently still editing text, clicking outside
       //  should only finalize it, not create another (irrespective
       //  of state.elementLocked)
       if (this.state.editingElement?.type === "text") {
         return;
       }
-      if (elementIsAddedToSelection) {
-        element = hitElement!;
-      }
-      let textX = event.clientX;
-      let textY = event.clientY;
-      if (!event.altKey) {
-        const snappedToCenterPosition = this.getTextWysiwygSnappedToCenterPosition(
-          x,
-          y,
-        );
-        if (snappedToCenterPosition) {
-          element.x = snappedToCenterPosition.elementCenterX;
-          element.y = snappedToCenterPosition.elementCenterY;
-          textX = snappedToCenterPosition.wysiwygX;
-          textY = snappedToCenterPosition.wysiwygY;
-        }
-      }
+
+      const snappedToCenterPosition = event.altKey
+        ? null
+        : this.getTextWysiwygSnappedToCenterPosition(x, y);
+
+      const element = newTextElement({
+        x: snappedToCenterPosition?.elementCenterX ?? x,
+        y: snappedToCenterPosition?.elementCenterY ?? y,
+        strokeColor: this.state.currentItemStrokeColor,
+        backgroundColor: this.state.currentItemBackgroundColor,
+        fillStyle: this.state.currentItemFillStyle,
+        strokeWidth: this.state.currentItemStrokeWidth,
+        roughness: this.state.currentItemRoughness,
+        opacity: this.state.currentItemOpacity,
+        text: "",
+        font: this.state.currentItemFont,
+      });
 
       const resetSelection = () => {
         this.setState({
@@ -1471,8 +1448,8 @@ export class App extends React.Component<any, AppState> {
 
       textWysiwyg({
         initText: "",
-        x: textX,
-        y: textY,
+        x: snappedToCenterPosition?.wysiwygX ?? event.clientX,
+        y: snappedToCenterPosition?.wysiwygY ?? event.clientY,
         strokeColor: this.state.currentItemStrokeColor,
         opacity: this.state.currentItemOpacity,
         font: this.state.currentItemFont,
@@ -1481,9 +1458,11 @@ export class App extends React.Component<any, AppState> {
           if (text) {
             globalSceneState.replaceAllElements([
               ...globalSceneState.getAllElements(),
-              {
-                ...newTextElement(element, text, this.state.currentItemFont),
-              },
+              newTextElement({
+                ...element,
+                text,
+                font: this.state.currentItemFont,
+              }),
             ]);
           }
           this.setState(prevState => ({
@@ -1531,6 +1510,17 @@ export class App extends React.Component<any, AppState> {
           points: [...multiElement.points, [x - rx, y - ry]],
         });
       } else {
+        const element = newLinearElement({
+          type: this.state.elementType,
+          x: x,
+          y: y,
+          strokeColor: this.state.currentItemStrokeColor,
+          backgroundColor: this.state.currentItemBackgroundColor,
+          fillStyle: this.state.currentItemFillStyle,
+          strokeWidth: this.state.currentItemStrokeWidth,
+          roughness: this.state.currentItemRoughness,
+          opacity: this.state.currentItemOpacity,
+        });
         this.setState(prevState => ({
           selectedElementIds: {
             ...prevState.selectedElementIds,
@@ -1549,26 +1539,40 @@ export class App extends React.Component<any, AppState> {
           editingElement: element,
         });
       }
-    } else if (element.type === "selection") {
-      this.setState({
-        selectionElement: element,
-        draggingElement: element,
-      });
     } else {
-      globalSceneState.replaceAllElements([
-        ...globalSceneState.getAllElements(),
-        element,
-      ]);
-      this.setState({
-        multiElement: null,
-        draggingElement: element,
-        editingElement: element,
+      const element = newElement({
+        type: this.state.elementType,
+        x: x,
+        y: y,
+        strokeColor: this.state.currentItemStrokeColor,
+        backgroundColor: this.state.currentItemBackgroundColor,
+        fillStyle: this.state.currentItemFillStyle,
+        strokeWidth: this.state.currentItemStrokeWidth,
+        roughness: this.state.currentItemRoughness,
+        opacity: this.state.currentItemOpacity,
       });
+
+      if (element.type === "selection") {
+        this.setState({
+          selectionElement: element,
+          draggingElement: element,
+        });
+      } else {
+        globalSceneState.replaceAllElements([
+          ...globalSceneState.getAllElements(),
+          element,
+        ]);
+        this.setState({
+          multiElement: null,
+          draggingElement: element,
+          editingElement: element,
+        });
+      }
     }
 
     let resizeArrowFn:
       | ((
-          element: ExcalidrawElement,
+          element: ExcalidrawLinearElement,
           pointIndex: number,
           deltaX: number,
           deltaY: number,
@@ -1579,7 +1583,7 @@ export class App extends React.Component<any, AppState> {
       | null = null;
 
     const arrowResizeOrigin = (
-      element: ExcalidrawElement,
+      element: ExcalidrawLinearElement,
       pointIndex: number,
       deltaX: number,
       deltaY: number,
@@ -1604,7 +1608,9 @@ export class App extends React.Component<any, AppState> {
           x: dx,
           y: dy,
           points: element.points.map((point, i) =>
-            i === pointIndex ? [absPx - element.x, absPy - element.y] : point,
+            i === pointIndex
+              ? ([absPx - element.x, absPy - element.y] as const)
+              : point,
           ),
         });
       } else {
@@ -1612,14 +1618,16 @@ export class App extends React.Component<any, AppState> {
           x: element.x + deltaX,
           y: element.y + deltaY,
           points: element.points.map((point, i) =>
-            i === pointIndex ? [p1[0] - deltaX, p1[1] - deltaY] : point,
+            i === pointIndex
+              ? ([p1[0] - deltaX, p1[1] - deltaY] as const)
+              : point,
           ),
         });
       }
     };
 
     const arrowResizeEnd = (
-      element: ExcalidrawElement,
+      element: ExcalidrawLinearElement,
       pointIndex: number,
       deltaX: number,
       deltaY: number,
@@ -1636,13 +1644,15 @@ export class App extends React.Component<any, AppState> {
         );
         mutateElement(element, {
           points: element.points.map((point, i) =>
-            i === pointIndex ? [width, height] : point,
+            i === pointIndex ? ([width, height] as const) : point,
           ),
         });
       } else {
         mutateElement(element, {
           points: element.points.map((point, i) =>
-            i === pointIndex ? [p1[0] + deltaX, p1[1] + deltaY] : point,
+            i === pointIndex
+              ? ([p1[0] + deltaX, p1[1] + deltaY] as const)
+              : point,
           ),
         });
       }
@@ -1711,10 +1721,9 @@ export class App extends React.Component<any, AppState> {
           const deltaX = x - lastX;
           const deltaY = y - lastY;
           const element = selectedElements[0];
-          const isLinear = element.type === "line" || element.type === "arrow";
           switch (resizeHandle) {
             case "nw":
-              if (isLinear && element.points.length === 2) {
+              if (isLinearElement(element) && element.points.length === 2) {
                 const [, p1] = element.points;
 
                 if (!resizeArrowFn) {
@@ -1739,7 +1748,7 @@ export class App extends React.Component<any, AppState> {
               }
               break;
             case "ne":
-              if (isLinear && element.points.length === 2) {
+              if (isLinearElement(element) && element.points.length === 2) {
                 const [, p1] = element.points;
                 if (!resizeArrowFn) {
                   if (p1[0] >= 0) {
@@ -1761,7 +1770,7 @@ export class App extends React.Component<any, AppState> {
               }
               break;
             case "sw":
-              if (isLinear && element.points.length === 2) {
+              if (isLinearElement(element) && element.points.length === 2) {
                 const [, p1] = element.points;
                 if (!resizeArrowFn) {
                   if (p1[0] <= 0) {
@@ -1782,7 +1791,7 @@ export class App extends React.Component<any, AppState> {
               }
               break;
             case "se":
-              if (isLinear && element.points.length === 2) {
+              if (isLinearElement(element) && element.points.length === 2) {
                 const [, p1] = element.points;
                 if (!resizeArrowFn) {
                   if (p1[0] > 0 || p1[1] > 0) {
@@ -1807,14 +1816,18 @@ export class App extends React.Component<any, AppState> {
                 break;
               }
 
-              mutateElement(element, {
-                height,
-                y: element.y + deltaY,
-                points:
-                  element.points.length > 0
-                    ? rescalePoints(1, height, element.points)
-                    : undefined,
-              });
+              if (isLinearElement(element)) {
+                mutateElement(element, {
+                  height,
+                  y: element.y + deltaY,
+                  points: rescalePoints(1, height, element.points),
+                });
+              } else {
+                mutateElement(element, {
+                  height,
+                  y: element.y + deltaY,
+                });
+              }
 
               break;
             }
@@ -1825,15 +1838,18 @@ export class App extends React.Component<any, AppState> {
                 // Someday we should implement logic to flip the shape. But for now, just stop.
                 break;
               }
-
-              mutateElement(element, {
-                width,
-                x: element.x + deltaX,
-                points:
-                  element.points.length > 0
-                    ? rescalePoints(0, width, element.points)
-                    : undefined,
-              });
+              if (isLinearElement(element)) {
+                mutateElement(element, {
+                  width,
+                  x: element.x + deltaX,
+                  points: rescalePoints(0, width, element.points),
+                });
+              } else {
+                mutateElement(element, {
+                  width,
+                  x: element.x + deltaX,
+                });
+              }
               break;
             }
             case "s": {
@@ -1842,14 +1858,16 @@ export class App extends React.Component<any, AppState> {
                 break;
               }
 
-              mutateElement(element, {
-                height,
-                points:
-                  element.points.length > 0
-                    ? rescalePoints(1, height, element.points)
-                    : undefined,
-              });
-
+              if (isLinearElement(element)) {
+                mutateElement(element, {
+                  height,
+                  points: rescalePoints(1, height, element.points),
+                });
+              } else {
+                mutateElement(element, {
+                  height,
+                });
+              }
               break;
             }
             case "e": {
@@ -1858,13 +1876,16 @@ export class App extends React.Component<any, AppState> {
                 break;
               }
 
-              mutateElement(element, {
-                width,
-                points:
-                  element.points.length > 0
-                    ? rescalePoints(0, width, element.points)
-                    : undefined,
-              });
+              if (isLinearElement(element)) {
+                mutateElement(element, {
+                  width,
+                  points: rescalePoints(0, width, element.points),
+                });
+              } else {
+                mutateElement(element, {
+                  width,
+                });
+              }
               break;
             }
           }
@@ -1934,10 +1955,7 @@ export class App extends React.Component<any, AppState> {
       let width = distance(originX, x);
       let height = distance(originY, y);
 
-      const isLinear =
-        this.state.elementType === "line" || this.state.elementType === "arrow";
-
-      if (isLinear) {
+      if (isLinearElement(draggingElement)) {
         draggingOccurred = true;
         const points = draggingElement.points;
         let dx = x - draggingElement.x;
@@ -2023,7 +2041,7 @@ export class App extends React.Component<any, AppState> {
       window.removeEventListener("pointermove", onPointerMove);
       window.removeEventListener("pointerup", onPointerUp);
 
-      if (elementType === "arrow" || elementType === "line") {
+      if (isLinearElement(draggingElement)) {
         if (draggingElement!.points.length > 1) {
           history.resumeRecording();
         }
@@ -2041,7 +2059,7 @@ export class App extends React.Component<any, AppState> {
             ],
           });
           this.setState({
-            multiElement: this.state.draggingElement,
+            multiElement: draggingElement,
             editingElement: this.state.draggingElement,
           });
         } else if (draggingOccurred && !multiElement) {
@@ -2215,12 +2233,12 @@ export class App extends React.Component<any, AppState> {
     const dx = x - elementsCenterX;
     const dy = y - elementsCenterY;
 
-    const newElements = clipboardElements.map(clipboardElements => {
-      const duplicate = duplicateElement(clipboardElements);
-      duplicate.x += dx - minX;
-      duplicate.y += dy - minY;
-      return duplicate;
-    });
+    const newElements = clipboardElements.map(element =>
+      duplicateElement(element, {
+        x: element.x + dx - minX,
+        y: element.y + dy - minY,
+      }),
+    );
 
     globalSceneState.replaceAllElements([
       ...globalSceneState.getAllElements(),

+ 3 - 6
src/components/HintViewer.tsx

@@ -5,6 +5,7 @@ import { getSelectedElements } from "../scene";
 
 import "./HintViewer.scss";
 import { AppState } from "../types";
+import { isLinearElement } from "../element/typeChecks";
 
 interface Hint {
   appState: AppState;
@@ -23,12 +24,8 @@ const getHints = ({ appState, elements }: Hint) => {
 
   if (isResizing) {
     const selectedElements = getSelectedElements(elements, appState);
-    if (
-      selectedElements.length === 1 &&
-      (selectedElements[0].type === "arrow" ||
-        selectedElements[0].type === "line") &&
-      selectedElements[0].points.length > 2
-    ) {
+    const targetElement = selectedElements[0];
+    if (isLinearElement(targetElement) && targetElement.points.length > 2) {
       return null;
     }
     return t("hints.resize");

+ 9 - 2
src/data/restore.ts

@@ -8,7 +8,9 @@ import nanoid from "nanoid";
 import { calculateScrollCenter } from "../scene";
 
 export function restore(
-  savedElements: readonly ExcalidrawElement[],
+  // we're making the elements mutable for this API because we want to
+  //  efficiently remove/tweak properties on them (to migrate old scenes)
+  savedElements: readonly Mutable<ExcalidrawElement>[],
   savedState: AppState | null,
   opts?: { scrollToContent: boolean },
 ): DataState {
@@ -35,6 +37,7 @@ export function restore(
             [element.width, element.height],
           ];
         }
+        element.points = points;
       } else if (element.type === "line") {
         // old spec, pre-arrows
         // old spec, post-arrows
@@ -46,8 +49,13 @@ export function restore(
         } else {
           points = element.points;
         }
+        element.points = points;
       } else {
         normalizeDimensions(element);
+        // old spec, where non-linear elements used to have empty points arrays
+        if ("points" in element) {
+          delete element.points;
+        }
       }
 
       return {
@@ -62,7 +70,6 @@ export function restore(
           element.opacity === null || element.opacity === undefined
             ? 100
             : element.opacity,
-        points,
       };
     });
 

+ 1 - 1
src/element/bounds.test.ts

@@ -3,7 +3,7 @@ import { ExcalidrawElement } from "./types";
 
 const _ce = ({ x, y, w, h }: { x: number; y: number; w: number; h: number }) =>
   ({
-    type: "test",
+    type: "rectangle",
     strokeColor: "#000",
     backgroundColor: "#000",
     fillStyle: "solid",

+ 10 - 4
src/element/bounds.ts

@@ -1,13 +1,14 @@
-import { ExcalidrawElement } from "./types";
+import { ExcalidrawElement, ExcalidrawLinearElement } from "./types";
 import { rotate } from "../math";
 import { Drawable } from "roughjs/bin/core";
 import { Point } from "../types";
 import { getShapeForElement } from "../renderer/renderElement";
+import { isLinearElement } from "./typeChecks";
 
 // If the element is created from right to left, the width is going to be negative
 // This set of functions retrieves the absolute position of the 4 points.
 export function getElementAbsoluteCoords(element: ExcalidrawElement) {
-  if (element.type === "arrow" || element.type === "line") {
+  if (isLinearElement(element)) {
     return getLinearElementAbsoluteBounds(element);
   }
   return [
@@ -33,7 +34,9 @@ export function getDiamondPoints(element: ExcalidrawElement) {
   return [topX, topY, rightX, rightY, bottomX, bottomY, leftX, leftY];
 }
 
-export function getLinearElementAbsoluteBounds(element: ExcalidrawElement) {
+export function getLinearElementAbsoluteBounds(
+  element: ExcalidrawLinearElement,
+) {
   if (element.points.length < 2 || !getShapeForElement(element)) {
     const { minX, minY, maxX, maxY } = element.points.reduce(
       (limits, [x, y]) => {
@@ -119,7 +122,10 @@ export function getLinearElementAbsoluteBounds(element: ExcalidrawElement) {
   ];
 }
 
-export function getArrowPoints(element: ExcalidrawElement, shape: Drawable[]) {
+export function getArrowPoints(
+  element: ExcalidrawLinearElement,
+  shape: Drawable[],
+) {
   const ops = shape[0].sets[0].ops;
 
   const data = ops[ops.length - 1].data;

+ 2 - 1
src/element/collision.ts

@@ -11,6 +11,7 @@ import { Point } from "../types";
 import { Drawable, OpSet } from "roughjs/bin/core";
 import { AppState } from "../types";
 import { getShapeForElement } from "../renderer/renderElement";
+import { isLinearElement } from "./typeChecks";
 
 function isElementDraggableFromInside(
   element: ExcalidrawElement,
@@ -158,7 +159,7 @@ export function hitTest(
       distanceBetweenPointAndSegment(x, y, leftX, leftY, topX, topY) <
         lineThreshold
     );
-  } else if (element.type === "arrow" || element.type === "line") {
+  } else if (isLinearElement(element)) {
     if (!getShapeForElement(element)) {
       return false;
     }

+ 6 - 1
src/element/index.ts

@@ -1,7 +1,12 @@
 import { ExcalidrawElement } from "./types";
 import { isInvisiblySmallElement } from "./sizeHelpers";
 
-export { newElement, newTextElement, duplicateElement } from "./newElement";
+export {
+  newElement,
+  newTextElement,
+  newLinearElement,
+  duplicateElement,
+} from "./newElement";
 export {
   getElementAbsoluteCoords,
   getCommonBounds,

+ 11 - 8
src/element/mutateElement.ts

@@ -13,33 +13,36 @@ type ElementUpdate<TElement extends ExcalidrawElement> = Omit<
 // The version is used to compare updates when more than one user is working in
 // the same drawing. Note: this will trigger the component to update. Make sure you
 // are calling it either from a React event handler or within unstable_batchedUpdates().
-export function mutateElement<TElement extends ExcalidrawElement>(
+export function mutateElement<TElement extends Mutable<ExcalidrawElement>>(
   element: TElement,
   updates: ElementUpdate<TElement>,
 ) {
-  const mutableElement = element as any;
+  // casting to any because can't use `in` operator
+  // (see https://github.com/microsoft/TypeScript/issues/21732)
+  const { points } = updates as any;
 
-  if (typeof updates.points !== "undefined") {
-    updates = { ...getSizeFromPoints(updates.points!), ...updates };
+  if (typeof points !== "undefined") {
+    updates = { ...getSizeFromPoints(points), ...updates };
   }
 
   for (const key in updates) {
     const value = (updates as any)[key];
     if (typeof value !== "undefined") {
-      mutableElement[key] = value;
+      // @ts-ignore
+      element[key] = value;
     }
   }
 
   if (
     typeof updates.height !== "undefined" ||
     typeof updates.width !== "undefined" ||
-    typeof updates.points !== "undefined"
+    typeof points !== "undefined"
   ) {
     invalidateShapeForElement(element);
   }
 
-  mutableElement.version++;
-  mutableElement.versionNonce = randomSeed();
+  element.version++;
+  element.versionNonce = randomSeed();
 
   globalSceneState.informMutation();
 }

+ 36 - 22
src/element/newElement.test.ts

@@ -1,4 +1,9 @@
-import { newElement, newTextElement, duplicateElement } from "./newElement";
+import {
+  newTextElement,
+  duplicateElement,
+  newLinearElement,
+} from "./newElement";
+import { mutateElement } from "./mutateElement";
 
 function isPrimitive(val: any) {
   const type = typeof val;
@@ -17,25 +22,27 @@ function assertCloneObjects(source: any, clone: any) {
 }
 
 it("clones arrow element", () => {
-  const element = newElement(
-    "arrow",
-    0,
-    0,
-    "#000000",
-    "transparent",
-    "hachure",
-    1,
-    1,
-    100,
-  );
+  const element = newLinearElement({
+    type: "arrow",
+    x: 0,
+    y: 0,
+    strokeColor: "#000000",
+    backgroundColor: "transparent",
+    fillStyle: "hachure",
+    strokeWidth: 1,
+    roughness: 1,
+    opacity: 100,
+  });
 
   // @ts-ignore
   element.__proto__ = { hello: "world" };
 
-  element.points = [
-    [1, 2],
-    [3, 4],
-  ];
+  mutateElement(element, {
+    points: [
+      [1, 2],
+      [3, 4],
+    ],
+  });
 
   const copy = duplicateElement(element);
 
@@ -59,17 +66,24 @@ it("clones arrow element", () => {
 });
 
 it("clones text element", () => {
-  const element = newTextElement(
-    newElement("text", 0, 0, "#000000", "transparent", "hachure", 1, 1, 100),
-    "hello",
-    "Arial 20px",
-  );
+  const element = newTextElement({
+    x: 0,
+    y: 0,
+    strokeColor: "#000000",
+    backgroundColor: "transparent",
+    fillStyle: "hachure",
+    strokeWidth: 1,
+    roughness: 1,
+    opacity: 100,
+    text: "hello",
+    font: "Arial 20px",
+  });
 
   const copy = duplicateElement(element);
 
   assertCloneObjects(element, copy);
 
-  expect(copy.points).not.toBe(element.points);
+  expect(copy).not.toHaveProperty("points");
   expect(copy).not.toHaveProperty("shape");
   expect(copy.id).not.toBe(element.id);
   expect(typeof copy.id).toBe("string");

+ 77 - 35
src/element/newElement.ts

@@ -1,25 +1,45 @@
 import { randomSeed } from "roughjs/bin/math";
 import nanoid from "nanoid";
-import { Point } from "../types";
 
-import { ExcalidrawElement, ExcalidrawTextElement } from "../element/types";
+import {
+  ExcalidrawElement,
+  ExcalidrawTextElement,
+  ExcalidrawLinearElement,
+  ExcalidrawGenericElement,
+} from "../element/types";
 import { measureText } from "../utils";
 
-export function newElement(
-  type: string,
-  x: number,
-  y: number,
-  strokeColor: string,
-  backgroundColor: string,
-  fillStyle: string,
-  strokeWidth: number,
-  roughness: number,
-  opacity: number,
-  width = 0,
-  height = 0,
+type ElementConstructorOpts = {
+  x: ExcalidrawGenericElement["x"];
+  y: ExcalidrawGenericElement["y"];
+  strokeColor: ExcalidrawGenericElement["strokeColor"];
+  backgroundColor: ExcalidrawGenericElement["backgroundColor"];
+  fillStyle: ExcalidrawGenericElement["fillStyle"];
+  strokeWidth: ExcalidrawGenericElement["strokeWidth"];
+  roughness: ExcalidrawGenericElement["roughness"];
+  opacity: ExcalidrawGenericElement["opacity"];
+  width?: ExcalidrawGenericElement["width"];
+  height?: ExcalidrawGenericElement["height"];
+};
+
+function _newElementBase<T extends ExcalidrawElement>(
+  type: T["type"],
+  {
+    x,
+    y,
+    strokeColor,
+    backgroundColor,
+    fillStyle,
+    strokeWidth,
+    roughness,
+    opacity,
+    width = 0,
+    height = 0,
+    ...rest
+  }: ElementConstructorOpts & Partial<ExcalidrawGenericElement>,
 ) {
-  const element = {
-    id: nanoid(),
+  return {
+    id: rest.id || nanoid(),
     type,
     x,
     y,
@@ -31,29 +51,36 @@ export function newElement(
     strokeWidth,
     roughness,
     opacity,
-    seed: randomSeed(),
-    points: [] as readonly Point[],
-    version: 1,
-    versionNonce: 0,
-    isDeleted: false,
+    seed: rest.seed ?? randomSeed(),
+    version: rest.version || 1,
+    versionNonce: rest.versionNonce ?? 0,
+    isDeleted: rest.isDeleted ?? false,
   };
-  return element;
+}
+
+export function newElement(
+  opts: {
+    type: ExcalidrawGenericElement["type"];
+  } & ElementConstructorOpts,
+): ExcalidrawGenericElement {
+  return _newElementBase<ExcalidrawGenericElement>(opts.type, opts);
 }
 
 export function newTextElement(
-  element: ExcalidrawElement,
-  text: string,
-  font: string,
-) {
+  opts: {
+    text: string;
+    font: string;
+  } & ElementConstructorOpts,
+): ExcalidrawTextElement {
+  const { text, font } = opts;
   const metrics = measureText(text, font);
-  const textElement: ExcalidrawTextElement = {
-    ...element,
-    type: "text",
+  const textElement = {
+    ..._newElementBase<ExcalidrawTextElement>("text", opts),
     text: text,
     font: font,
     // Center the text
-    x: element.x - metrics.width / 2,
-    y: element.y - metrics.height / 2,
+    x: opts.x - metrics.width / 2,
+    y: opts.y - metrics.height / 2,
     width: metrics.width,
     height: metrics.height,
     baseline: metrics.baseline,
@@ -62,6 +89,17 @@ export function newTextElement(
   return textElement;
 }
 
+export function newLinearElement(
+  opts: {
+    type: "arrow" | "line";
+  } & ElementConstructorOpts,
+): ExcalidrawLinearElement {
+  return {
+    ..._newElementBase<ExcalidrawLinearElement>(opts.type, opts),
+    points: [],
+  };
+}
+
 // Simplified deep clone for the purpose of cloning ExcalidrawElement only
 //  (doesn't clone Date, RegExp, Map, Set, Typed arrays etc.)
 //
@@ -100,11 +138,15 @@ function _duplicateElement(val: any, depth: number = 0) {
   return val;
 }
 
-export function duplicateElement(
-  element: ReturnType<typeof newElement>,
-): ReturnType<typeof newElement> {
-  const copy = _duplicateElement(element);
+export function duplicateElement<TElement extends Mutable<ExcalidrawElement>>(
+  element: TElement,
+  overrides?: Partial<TElement>,
+): TElement {
+  let copy: TElement = _duplicateElement(element);
   copy.id = nanoid();
   copy.seed = randomSeed();
+  if (overrides) {
+    copy = Object.assign(copy, overrides);
+  }
   return copy;
 }

+ 2 - 5
src/element/resizeTest.ts

@@ -2,6 +2,7 @@ import { ExcalidrawElement, PointerType } from "./types";
 
 import { handlerRectangles } from "./handlerRectangles";
 import { AppState } from "../types";
+import { isLinearElement } from "./typeChecks";
 
 type HandlerRectanglesRet = keyof ReturnType<typeof handlerRectangles>;
 
@@ -102,11 +103,7 @@ export function normalizeResizeHandle(
   element: ExcalidrawElement,
   resizeHandle: HandlerRectanglesRet,
 ): HandlerRectanglesRet {
-  if (
-    (element.width >= 0 && element.height >= 0) ||
-    element.type === "line" ||
-    element.type === "arrow"
-  ) {
+  if ((element.width >= 0 && element.height >= 0) || isLinearElement(element)) {
     return resizeHandle;
   }
 

+ 3 - 3
src/element/sizeHelpers.ts

@@ -1,8 +1,9 @@
 import { ExcalidrawElement } from "./types";
 import { mutateElement } from "./mutateElement";
+import { isLinearElement } from "./typeChecks";
 
 export function isInvisiblySmallElement(element: ExcalidrawElement): boolean {
-  if (element.type === "arrow" || element.type === "line") {
+  if (isLinearElement(element)) {
     return element.points.length < 2;
   }
   return element.width === 0 && element.height === 0;
@@ -78,8 +79,7 @@ export function normalizeDimensions(
   if (
     !element ||
     (element.width >= 0 && element.height >= 0) ||
-    element.type === "line" ||
-    element.type === "arrow"
+    isLinearElement(element)
   ) {
     return false;
   }

+ 13 - 1
src/element/typeChecks.ts

@@ -1,4 +1,8 @@
-import { ExcalidrawElement, ExcalidrawTextElement } from "./types";
+import {
+  ExcalidrawElement,
+  ExcalidrawTextElement,
+  ExcalidrawLinearElement,
+} from "./types";
 
 export function isTextElement(
   element: ExcalidrawElement,
@@ -6,6 +10,14 @@ export function isTextElement(
   return element.type === "text";
 }
 
+export function isLinearElement(
+  element?: ExcalidrawElement | null,
+): element is ExcalidrawLinearElement {
+  return (
+    element != null && (element.type === "arrow" || element.type === "line")
+  );
+}
+
 export function isExcalidrawElement(element: any): boolean {
   return (
     element?.type === "text" ||

+ 34 - 5
src/element/types.ts

@@ -1,20 +1,49 @@
-import { newElement } from "./newElement";
+import { Point } from "../types";
+
+type _ExcalidrawElementBase = Readonly<{
+  id: string;
+  x: number;
+  y: number;
+  strokeColor: string;
+  backgroundColor: string;
+  fillStyle: string;
+  strokeWidth: number;
+  roughness: number;
+  opacity: number;
+  width: number;
+  height: number;
+  seed: number;
+  version: number;
+  versionNonce: number;
+  isDeleted: boolean;
+}>;
+
+export type ExcalidrawGenericElement = _ExcalidrawElementBase & {
+  type: "selection" | "rectangle" | "diamond" | "ellipse";
+};
 
 /**
  * ExcalidrawElement should be JSON serializable and (eventually) contain
  * no computed data. The list of all ExcalidrawElements should be shareable
  * between peers and contain no state local to the peer.
  */
-export type ExcalidrawElement = Readonly<ReturnType<typeof newElement>>;
+export type ExcalidrawElement =
+  | ExcalidrawGenericElement
+  | ExcalidrawTextElement
+  | ExcalidrawLinearElement;
 
-export type ExcalidrawTextElement = ExcalidrawElement &
+export type ExcalidrawTextElement = _ExcalidrawElementBase &
   Readonly<{
     type: "text";
     font: string;
     text: string;
-    // for backward compatibility
-    actualBoundingBoxAscent?: number;
     baseline: number;
   }>;
 
+export type ExcalidrawLinearElement = _ExcalidrawElementBase &
+  Readonly<{
+    type: "arrow" | "line";
+    points: Point[];
+  }>;
+
 export type PointerType = "mouse" | "pen" | "touch";

+ 4 - 0
src/global.d.ts

@@ -5,3 +5,7 @@ interface Window {
 interface Clipboard extends EventTarget {
   write(data: any[]): Promise<void>;
 }
+
+type Mutable<T> = {
+  -readonly [P in keyof T]: T[P];
+};

+ 12 - 8
src/history.ts

@@ -2,6 +2,7 @@ import { AppState } from "./types";
 import { ExcalidrawElement } from "./element/types";
 import { clearAppStatePropertiesForHistory } from "./appState";
 import { newElementWith } from "./element/mutateElement";
+import { isLinearElement } from "./element/typeChecks";
 
 type Result = {
   appState: AppState;
@@ -24,14 +25,17 @@ export class SceneHistory {
   ) {
     return JSON.stringify({
       appState: clearAppStatePropertiesForHistory(appState),
-      elements: elements.map(element =>
-        newElementWith(element, {
-          points:
-            appState.multiElement && appState.multiElement.id === element.id
-              ? element.points.slice(0, -1)
-              : element.points,
-        }),
-      ),
+      elements: elements.map(element => {
+        if (isLinearElement(element)) {
+          return newElementWith(element, {
+            points:
+              appState.multiElement && appState.multiElement.id === element.id
+                ? element.points.slice(0, -1)
+                : element.points,
+          });
+        }
+        return newElementWith(element, {});
+      }),
     });
   }
 

+ 1 - 1
src/points.ts

@@ -12,7 +12,7 @@ export function rescalePoints(
   dimension: 0 | 1,
   nextDimensionSize: number,
   prevPoints: readonly Point[],
-): readonly Point[] {
+): Point[] {
   const prevDimValues = prevPoints.map(point => point[dimension]);
   const prevMaxDimension = Math.max(...prevDimValues);
   const prevMinDimension = Math.min(...prevDimValues);

+ 2 - 0
src/renderer/renderElement.ts

@@ -330,6 +330,7 @@ export function renderElement(
       break;
     }
     default: {
+      // @ts-ignore
       throw new Error(`Unimplemented type ${element.type}`);
     }
   }
@@ -420,6 +421,7 @@ export function renderElementToSvg(
         }
         svgRoot.appendChild(node);
       } else {
+        // @ts-ignore
         throw new Error(`Unimplemented type ${element.type}`);
       }
     }

+ 19 - 12
src/tests/dragCreate.test.tsx

@@ -4,6 +4,7 @@ import { App } from "../components/App";
 import * as Renderer from "../renderer/renderScene";
 import { KEYS } from "../keys";
 import { render, fireEvent } from "./test-utils";
+import { ExcalidrawLinearElement } from "../element/types";
 
 // Unmount ReactDOM from root
 ReactDOM.unmountComponentAtNode(document.getElementById("root")!);
@@ -122,12 +123,15 @@ describe("add element to the scene when pointer dragging long enough", () => {
     expect(h.appState.selectionElement).toBeNull();
 
     expect(h.elements.length).toEqual(1);
-    expect(h.elements[0].type).toEqual("arrow");
-    expect(h.elements[0].x).toEqual(30);
-    expect(h.elements[0].y).toEqual(20);
-    expect(h.elements[0].points.length).toEqual(2);
-    expect(h.elements[0].points[0]).toEqual([0, 0]);
-    expect(h.elements[0].points[1]).toEqual([30, 50]); // (60 - 30, 70 - 20)
+
+    const element = h.elements[0] as ExcalidrawLinearElement;
+
+    expect(element.type).toEqual("arrow");
+    expect(element.x).toEqual(30);
+    expect(element.y).toEqual(20);
+    expect(element.points.length).toEqual(2);
+    expect(element.points[0]).toEqual([0, 0]);
+    expect(element.points[1]).toEqual([30, 50]); // (60 - 30, 70 - 20)
   });
 
   it("line", () => {
@@ -151,12 +155,15 @@ describe("add element to the scene when pointer dragging long enough", () => {
     expect(h.appState.selectionElement).toBeNull();
 
     expect(h.elements.length).toEqual(1);
-    expect(h.elements[0].type).toEqual("line");
-    expect(h.elements[0].x).toEqual(30);
-    expect(h.elements[0].y).toEqual(20);
-    expect(h.elements[0].points.length).toEqual(2);
-    expect(h.elements[0].points[0]).toEqual([0, 0]);
-    expect(h.elements[0].points[1]).toEqual([30, 50]); // (60 - 30, 70 - 20)
+
+    const element = h.elements[0] as ExcalidrawLinearElement;
+
+    expect(element.type).toEqual("line");
+    expect(element.x).toEqual(30);
+    expect(element.y).toEqual(20);
+    expect(element.points.length).toEqual(2);
+    expect(element.points[0]).toEqual([0, 0]);
+    expect(element.points[1]).toEqual([30, 50]); // (60 - 30, 70 - 20)
   });
 });
 

+ 13 - 8
src/tests/multiPointCreate.test.tsx

@@ -4,6 +4,7 @@ import { render, fireEvent } from "./test-utils";
 import { App } from "../components/App";
 import * as Renderer from "../renderer/renderScene";
 import { KEYS } from "../keys";
+import { ExcalidrawLinearElement } from "../element/types";
 
 // Unmount ReactDOM from root
 ReactDOM.unmountComponentAtNode(document.getElementById("root")!);
@@ -88,10 +89,12 @@ describe("multi point mode in linear elements", () => {
     expect(renderScene).toHaveBeenCalledTimes(10);
     expect(h.elements.length).toEqual(1);
 
-    expect(h.elements[0].type).toEqual("arrow");
-    expect(h.elements[0].x).toEqual(30);
-    expect(h.elements[0].y).toEqual(30);
-    expect(h.elements[0].points).toEqual([
+    const element = h.elements[0] as ExcalidrawLinearElement;
+
+    expect(element.type).toEqual("arrow");
+    expect(element.x).toEqual(30);
+    expect(element.y).toEqual(30);
+    expect(element.points).toEqual([
       [0, 0],
       [20, 30],
       [70, 110],
@@ -125,10 +128,12 @@ describe("multi point mode in linear elements", () => {
     expect(renderScene).toHaveBeenCalledTimes(10);
     expect(h.elements.length).toEqual(1);
 
-    expect(h.elements[0].type).toEqual("line");
-    expect(h.elements[0].x).toEqual(30);
-    expect(h.elements[0].y).toEqual(30);
-    expect(h.elements[0].points).toEqual([
+    const element = h.elements[0] as ExcalidrawLinearElement;
+
+    expect(element.type).toEqual("line");
+    expect(element.x).toEqual(30);
+    expect(element.y).toEqual(30);
+    expect(element.points).toEqual([
       [0, 0],
       [20, 30],
       [70, 110],

+ 6 - 2
src/types.ts

@@ -1,4 +1,8 @@
-import { ExcalidrawElement, PointerType } from "./element/types";
+import {
+  ExcalidrawElement,
+  PointerType,
+  ExcalidrawLinearElement,
+} from "./element/types";
 import { SHAPES } from "./shapes";
 import { Point as RoughPoint } from "roughjs/bin/geometry";
 
@@ -8,7 +12,7 @@ export type Point = Readonly<RoughPoint>;
 export type AppState = {
   draggingElement: ExcalidrawElement | null;
   resizingElement: ExcalidrawElement | null;
-  multiElement: ExcalidrawElement | null;
+  multiElement: ExcalidrawLinearElement | null;
   selectionElement: ExcalidrawElement | null;
   // element being edited, but not necessarily added to elements array yet
   //  (e.g. text element when typing into the input)