From 67d44b2d1d39910b02b3563be841937130600298 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 20 Feb 2026 10:43:25 -0500 Subject: [PATCH 1/3] Releasing 1.2.0. --- MIGRATING.md | 7 ++- README.md | 58 ++++++++++++------- RELEASE.md | 2 +- docs/docs/index.md | 2 +- docs/docs/install.md | 13 +++-- docs/mkdocs.yml | 2 +- pom.xml | 5 +- tensorflow-core/pom.xml | 2 +- tensorflow-core/tensorflow-core-api/pom.xml | 2 +- .../tensorflow-core-generator/pom.xml | 2 +- .../tensorflow-core-native/pom.xml | 2 +- .../tensorflow-core-platform/pom.xml | 2 +- tensorflow-framework/pom.xml | 2 +- tensorflow-ndarray/pom.xml | 2 +- 14 files changed, 61 insertions(+), 42 deletions(-) diff --git a/MIGRATING.md b/MIGRATING.md index ac7276eba99..61148063487 100644 --- a/MIGRATING.md +++ b/MIGRATING.md @@ -1,9 +1,10 @@ # Migrating Between TensorFlow Java Releases -TensorFlow Java is still in an alpha stage, therefore is subject to contain breaking changes between the different releases. This guide explain in detail -how to migrate your code from a previous version to a new one that includes some changes that are not backward compatible. +This guide explains in detail how to migrate your code from a pre-1.0 release to a 1.0 or newer release. Post 1.0 +releases have API stability as much as possible, though upstream TensorFlow does remove ops from time to time and +consequently those ops will be removed from TensorFlow-Java -## Migrating to 1.0.0 +## Migrating to 1.0.0 or newer TensorFlow-Java 1.0.0 requires Java 11 or later. diff --git a/README.md b/README.md index 536ce3f27ba..cef72116bbe 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,8 @@ The following describes the layout of the repository and its different artifacts * Intended audience: neural network developers * For more information: [tensorflow-framework/README.md](tensorflow-framework/README.md) -*Note: The NdArray Library module has now its own [repository](https://github.com/tensorflow/java-ndarray) and has been moved out of TensorFlow Java.* +* `tensorflow-ndarray` + * API for creating and manipulating n-dimensional arrays, can be used independently from TensorFlow. ## Communication @@ -60,10 +61,11 @@ only binaries for the following are being **supported and distributed** by this - `linux-x86_64-gpu`: Linux platforms on Intel/AMD chips with Cuda GPU support - `linux-arm64`: Linux platforms on Arm chips - `macosx-arm64`: MacOS X platforms on Apple Silicon chips -- `windows-x86_64`: Windows platforms on Intel/AMD chips (v1.1.0 and earlier) Binaries for `macosx-x86_64` are available for TF-Java 1.0 series releases and earlier, they were dropped from -TF-Java 1.1 and newer as they are no longer supported or released by Google. +TF-Java 1.1 and newer as they are no longer supported or released by Google. Binaries for `windows-x86_64` are available +for TF-Java 1.1 and earlier, they were dropped for the 1.2 release and newer as the native binaries are no longer supported or +released by Google. For example, for building a JAR that uses TensorFlow and is targeted to be deployed only on Linux systems with no GPU support, you should add the following dependencies: @@ -71,18 +73,18 @@ systems with no GPU support, you should add the following dependencies: org.tensorflow tensorflow-core-api - 1.1.0 + 1.2.0 org.tensorflow tensorflow-core-native - 1.1.0 + 1.2.0 linux-x86_64 ``` Or Gradle: ```groovy -def tfVersion = '1.1.0' +def tfVersion = '1.2.0' implementation "org.tensorflow:tensorflow-core-api:$tfVersion" implementation "org.tensorflow:tensorflow-core-native:$tfVersion:linux-x86_64" ``` @@ -93,34 +95,33 @@ native dependencies as follows: org.tensorflow tensorflow-core-api - 1.1.0 + 1.2.0 org.tensorflow tensorflow-core-native - 1.1.0 - linux-x86_64-gpu + 1.2.0 + linux-arm64 org.tensorflow tensorflow-core-native - 1.1.0 - macosx-arm64 + 1.2.0 + linux-x86_64-gpu org.tensorflow tensorflow-core-native - 1.1.0 - windows-x86_64 + 1.2.0 + macosx-arm64 ``` Or Gradle: ```groovy -def tfVersion = '1.1.0' +def tfVersion = '1.2.0' implementation "org.tensorflow:tensorflow-core-api:$tfVersion" implementation "org.tensorflow:tensorflow-core-native:$tfVersion:linux-x86_64-gpu" implementation "org.tensorflow:tensorflow-core-native:$tfVersion:macosx-arm64" -implementation "org.tensorflow:tensorflow-core-native:$tfVersion:windows-x86_64" ``` Only one dependency can be added per platform, meaning that you cannot add native dependencies to both `linux-x86_64` and @@ -135,7 +136,7 @@ For Ubuntu 24.04, you can install them with the following command: In some cases, it might be preferable to add a single dependency that includes transitively all the artifacts required to run TensorFlow Java on any [supported platforms](README.md#individual-dependencies) -- `tensorflow-core-platform`: Includes `tensorflow-core-api`, plus native artifacts for `linux-x86_64`, `linux-x86_64-arm64`, `macosx-arm64` and `windows-x86_64` +- `tensorflow-core-platform`: Includes `tensorflow-core-api`, plus native artifacts for `linux-x86_64`, `linux-arm64`, and `macosx-arm64` For example, to run TensorFlow Java on any CPU platform for which a binary is being distributed by this project, you can simply add this dependency to your application: @@ -143,12 +144,12 @@ simply add this dependency to your application: org.tensorflow tensorflow-core-platform - 1.1.0 + 1.2.0 ``` Or Gradle: ```groovy -implementation "org.tensorflow:tensorflow-core-platform:1.1.0" +implementation "org.tensorflow:tensorflow-core-platform:1.2.0" ``` Be aware though that the builds of TensorFlow are quite voluminous and including too many native dependencies may @@ -177,7 +178,7 @@ to add Sonatype OSS repository in your `pom.xml`, like the following org.tensorflow tensorflow-core-platform - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ``` @@ -192,10 +193,24 @@ repositories { dependencies { // Example of dependency, see section above for more options - implementation "org.tensorflow:tensorflow-core-platform:1.2.0-SNAPSHOT" + implementation "org.tensorflow:tensorflow-core-platform:1.3.0-SNAPSHOT" } ``` +## TensorFlow native libraries + +TensorFlow-Java is built on top of the native TensorFlow library, and uses [JavaCPP](https://github.com/bytedeco/javacpp) +to call that library which in turn uses the Java Native Interface (JNI). In [Java 24 and newer](https://openjdk.org/jeps/472) +uses of JNI trigger a warning of the form: +``` +WARNING: A restricted method in java.lang.System has been called +WARNING: java.lang.System::loadLibrary has been called by org.bytedeco.javacpp.Loader in an unnamed module (file:/.../.m2/repository/org/bytedeco/javacpp/1.5.12/javacpp-1.5.12.jar) +WARNING: Use --enable-native-access=ALL-UNNAMED to avoid a warning for callers in this module +WARNING: Restricted methods will be blocked in a future release unless native access is enabled +``` +This is expected, and adding the `--enable-native-access=ALL-UNNAMED` flag to enable JNI will suppress it. In a future +Java version this warning may be turned into an error and the flag will be required to use TensorFlow-Java. + ## TensorFlow/Java Version Support This table shows the mapping between TensorFlow, TensorFlow Java and minimum supported Java versions. @@ -215,7 +230,8 @@ This table shows the mapping between TensorFlow, TensorFlow Java and minimum sup | 1.0.0-rc.2 | 2.16.2 | 11 | | 1.0.0 | 2.16.2 | 11 | | 1.1.0 | 2.18.0 | 11 | -| 1.2.0-SNAPSHOT | 2.21.0 | 11 | +| 1.2.0 | 2.21.0 | 11 | +| 1.3.0-SNAPSHOT | 2.21.0 | 11 | ## How to Contribute? diff --git a/RELEASE.md b/RELEASE.md index 66bd9dfaa9e..fcf1001439c 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -191,7 +191,7 @@ Some things of note: ``` 2. In your local copy, checkout the master branch and increase the next snapshot version. ``` - mvn versions:set -DnewVersion=1.1.0-SNAPSHOT + mvn versions:set -DnewVersion=1.3.0-SNAPSHOT ``` 3. Update the TensorFlow Java version to reflect the new snapshot at the following locations: - https://github.com/tensorflow/java/blob/master/docs/install.md?plain=1#L104 diff --git a/docs/docs/index.md b/docs/docs/index.md index c9fcbf53e7e..8db1ff0e089 100755 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -26,7 +26,7 @@ The following describes the layout of the repository and its different artifacts * **Intended audience**: neural network developers. * Primary API for building and training neural networks with TensorFlow. -### [ndarray](https://github.com/tensorflow/java-ndarray) +### [tensorflow-ndarray](https://github.com/tensorflow/java/tree/master/tensorflow-ndarray) * **Intended audience**: any developer who needs a Java n-dimensional array implementation, whether or not they use it with TensorFlow. * Generic utility library for n-dimensional data I/O operations. * Used by TensorFlow but does not depend on TensorFlow. diff --git a/docs/docs/install.md b/docs/docs/install.md index 2fe676e956a..564414316dd 100755 --- a/docs/docs/install.md +++ b/docs/docs/install.md @@ -21,9 +21,12 @@ following platforms: * Ubuntu 20.04 or higher; 64-bit, x86 * Ubuntu 22.04 or higher; 64-bit, arm * macOS 14 or higher; 64-bit, arm + +Tensorflow Java 1.1 and earlier has binaries for: + * Windows 10 or higher; 64-bit, x86 -TensorFlow Java 1.0 series and earlier releases also have binaries for: +TensorFlow Java 1.0 series and earlier has binaries for: * macOS 12 or higher; 64-bit, x86 @@ -63,7 +66,7 @@ For example, org.tensorflow tensorflow-core-platform - 1.1.0 + 1.2.0 ``` @@ -106,7 +109,7 @@ snapshots repository in your `pom.xml`. org.tensorflow tensorflow-core-platform - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ``` @@ -123,7 +126,7 @@ repositories { } dependencies { - compile group: 'org.tensorflow', name: 'tensorflow-core-platform', version: '1.0.0' + compile group: 'org.tensorflow', name: 'tensorflow-core-platform', version: '1.2.0' } ``` @@ -169,7 +172,7 @@ add the TensorFlow dependency to the project's `pom.xml` file: org.tensorflow tensorflow-core-platform - 1.1.0 + 1.2.0 diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 8729bca5af5..8a9fd4bf33c 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -2,7 +2,7 @@ site_name: '' site_url: https://tensorflow.org repo_url: https://github.com/tensorflow/java site_description: Documentation of TensorFlow Java API and tools. -copyright: "© TensorFlow Authors 2025" +copyright: "© TensorFlow Authors 2026" theme: name: material diff --git a/pom.xml b/pom.xml index d39dedd9e2a..b9e610ddf7a 100644 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ org.tensorflow tensorflow-java - 1.2.0-SNAPSHOT + 1.2.0 pom TensorFlow Java Parent @@ -543,7 +543,7 @@ ./docs/overview.md - Copyright 2015, 2025 The TensorFlow Authors. All Rights Reserved. + Copyright 2015, 2026 The TensorFlow Authors. All Rights Reserved. -Xmaxerrs 65536 @@ -554,7 +554,6 @@ 256m 2048m - https://tensorflow.github.io/java/javadoc-ndarray/v1.0.0/ https://protobuf.dev/reference/java/api-docs https://bytedeco.org/javacpp/apidocs diff --git a/tensorflow-core/pom.xml b/tensorflow-core/pom.xml index cc87f6a76bc..f659865a2bf 100644 --- a/tensorflow-core/pom.xml +++ b/tensorflow-core/pom.xml @@ -22,7 +22,7 @@ org.tensorflow tensorflow-java - 1.2.0-SNAPSHOT + 1.2.0 tensorflow-core pom diff --git a/tensorflow-core/tensorflow-core-api/pom.xml b/tensorflow-core/tensorflow-core-api/pom.xml index a4cd84dcf20..a0bcbb52b87 100644 --- a/tensorflow-core/tensorflow-core-api/pom.xml +++ b/tensorflow-core/tensorflow-core-api/pom.xml @@ -6,7 +6,7 @@ org.tensorflow tensorflow-core - 1.2.0-SNAPSHOT + 1.2.0 tensorflow-core-api jar diff --git a/tensorflow-core/tensorflow-core-generator/pom.xml b/tensorflow-core/tensorflow-core-generator/pom.xml index bb532f5deab..91bf6541945 100644 --- a/tensorflow-core/tensorflow-core-generator/pom.xml +++ b/tensorflow-core/tensorflow-core-generator/pom.xml @@ -5,7 +5,7 @@ org.tensorflow tensorflow-core - 1.2.0-SNAPSHOT + 1.2.0 tensorflow-core-generator jar diff --git a/tensorflow-core/tensorflow-core-native/pom.xml b/tensorflow-core/tensorflow-core-native/pom.xml index 2e9102b450b..db40c870bd8 100644 --- a/tensorflow-core/tensorflow-core-native/pom.xml +++ b/tensorflow-core/tensorflow-core-native/pom.xml @@ -6,7 +6,7 @@ org.tensorflow tensorflow-core - 1.2.0-SNAPSHOT + 1.2.0 tensorflow-core-native jar diff --git a/tensorflow-core/tensorflow-core-platform/pom.xml b/tensorflow-core/tensorflow-core-platform/pom.xml index ca6a014f0ae..9cabb7f5d2a 100644 --- a/tensorflow-core/tensorflow-core-platform/pom.xml +++ b/tensorflow-core/tensorflow-core-platform/pom.xml @@ -22,7 +22,7 @@ org.tensorflow tensorflow-core - 1.2.0-SNAPSHOT + 1.2.0 tensorflow-core-platform TensorFlow API Platform diff --git a/tensorflow-framework/pom.xml b/tensorflow-framework/pom.xml index 03745bfea9a..be65cd8f6a3 100644 --- a/tensorflow-framework/pom.xml +++ b/tensorflow-framework/pom.xml @@ -22,7 +22,7 @@ org.tensorflow tensorflow-java - 1.2.0-SNAPSHOT + 1.2.0 tensorflow-framework jar diff --git a/tensorflow-ndarray/pom.xml b/tensorflow-ndarray/pom.xml index 8f1df831143..20f3651315d 100644 --- a/tensorflow-ndarray/pom.xml +++ b/tensorflow-ndarray/pom.xml @@ -22,7 +22,7 @@ org.tensorflow tensorflow-java - 1.2.0-SNAPSHOT + 1.2.0 tensorflow-ndarray jar From eac50137b65c9c43b81a95e539e59260b30c8a18 Mon Sep 17 00:00:00 2001 From: Nicolas Feybesse Date: Fri, 29 May 2026 20:04:15 +0200 Subject: [PATCH 2/3] Cache graph functions and add If gradient test (#637) * Cache graph functions and add If gradient test * Expose cached graph function names --- .../src/main/java/org/tensorflow/Graph.java | 45 +++ .../org/tensorflow/op/AttributeMetadata.java | 1 + .../java/org/tensorflow/IfGradientTest.java | 290 ++++++++++++++++++ 3 files changed, 336 insertions(+) create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IfGradientTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 488434c56f2..4302b722746 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -43,6 +43,7 @@ import java.util.Queue; import java.util.Set; import java.util.WeakHashMap; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; @@ -396,8 +397,21 @@ public GraphOperationBuilder opBuilder(String type, String name, Scope scope) { return new GraphOperationBuilder(this, type, name, scope, dangerousGradientBuilder); } + /** + * Attaches a {@link ConcreteFunction} to this graph. + * + *

If a function with the same defined name has already been attached, this method returns + * immediately without re-registering it. + * + *

The function is also stored in an internal cache to speed up subsequent lookups performed by + * {@link #getFunction(String)}. + */ @Override public void attachFunction(ConcreteFunction function) { + String name = function.getDefinedName(); + if (functionCache.putIfAbsent(name, function) != null) { + return; + } try (Reference ref = ref(); PointerScope scope = new PointerScope()) { TF_Status status = TF_Status.newStatus(); @@ -455,6 +469,10 @@ List getNativeFunctions(PointerScope outerScope) { * name */ public ConcreteFunction getFunction(String key) { + ConcreteFunction cached = functionCache.get(key); + if (cached != null) { + return cached; + } try (Reference ref = ref(); PointerScope scope = new PointerScope()) { List funcs = getNativeFunctions(scope); @@ -881,6 +899,33 @@ Set initializers() { private final Set initializers = Collections.synchronizedSet(new LinkedHashSet<>()); private int newInitializersMarker = -1; + /** + * Cache of {@link ConcreteFunction}s attached to this graph, indexed by their defined name. + * + *

This cache avoids repeatedly scanning the native function library when resolving functions + * during gradient construction or control-flow expansion. + * + *

The cache is populated lazily when {@link #attachFunction(ConcreteFunction)} is called and + * consulted first by {@link #getFunction(String)}. + * + *

A {@link ConcurrentHashMap} is used to allow concurrent reads during graph building without + * additional synchronization. + */ + private final ConcurrentHashMap functionCache = + new ConcurrentHashMap<>(); + + /** + * Returns a read-only view of the function names cached by this graph. + * + *

This exposes only the function names so callers can resolve ambiguous matches themselves + * before calling {@link #getFunction(String)} with an exact name. + * + * @return a read-only view of cached function names + */ + public Set functionNames() { + return Collections.unmodifiableSet(functionCache.keySet()); + } + /** * Use builders without locking. This should only be used during custom gradient building. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/AttributeMetadata.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/AttributeMetadata.java index 29ff5950786..83b71fdcc42 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/AttributeMetadata.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/AttributeMetadata.java @@ -30,6 +30,7 @@ public class AttributeMetadata { /** The size of the list if this attribute is a list, undefined otherwise. */ public final long listSize; + /** * The type of this attribute, or the type of the list values if it is a list. * diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IfGradientTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IfGradientTest.java new file mode 100644 index 00000000000..6a02503578a --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IfGradientTest.java @@ -0,0 +1,290 @@ +/* + Copyright 2026 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== + */ +package org.tensorflow; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Gradients; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.core.StatefulIf; +import org.tensorflow.op.core.StatefulPartitionedCall; +import org.tensorflow.op.core.StatelessIf; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; + +public class IfGradientTest { + + private static ConcreteFunction thenFn() { + return ConcreteFunction.create( + (Ops tf) -> { + Placeholder x = tf.placeholder(TFloat32.class); + Operand y = tf.math.mul(x, tf.constant(3.0f)); + return Signature.builder("thenBranch").input("x", x).output("y", y).build(); + }); + } + + private static ConcreteFunction elseFn() { + return ConcreteFunction.create( + (Ops tf) -> { + Placeholder x = tf.placeholder(TFloat32.class); + Operand y = tf.math.mul(x, tf.constant(5.0f)); + return Signature.builder("elseBranch").input("x", x).output("y", y).build(); + }); + } + + private static void assertClose(float got, float expected, float eps, String msg) { + if (Math.abs(got - expected) > eps) { + throw new AssertionError(msg + " (got=" + got + ", expected=" + expected + ")"); + } + } + + private static void primeIfGradFunctions(Graph g) { + + Iterator operations = g.operations(); + while (operations.hasNext()) { + GraphOperation op = operations.next(); + String type = op.type(); + if (!StatefulIf.OP_NAME.equals(type) && !StatelessIf.OP_NAME.equals(type)) continue; + + ConcreteFunction thenFwd = op.attributes().getAttrFunction("then_branch"); + ConcreteFunction elseFwd = op.attributes().getAttrFunction("else_branch"); + + int nInputs = op.inputListLength("input"); + int nOut = op.numOutputs(); + + List> tin = new ArrayList<>(nInputs); + for (int i = 0; i < nInputs; i++) { + Class c = op.input(1 + i).asOutput().type(); + tin.add(c); + } + + List> tout = new ArrayList<>(nOut); + for (int i = 0; i < nOut; i++) { + Class c = op.output(i).type(); + tout.add(c); + } + + ConcreteFunction thenGrad = buildBranchGradFn(op.name() + "/then_grad", thenFwd, tin, tout); + ConcreteFunction elseGrad = buildBranchGradFn(op.name() + "/else_grad", elseFwd, tin, tout); + + g.attachFunction(thenGrad); + g.attachFunction(elseGrad); + } + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private static ConcreteFunction buildBranchGradFn( + String prefix, + ConcreteFunction branchFn, + List> tin, + List> toutForward) { + + return ConcreteFunction.create( + (Ops tf) -> { + Signature.Builder sig = Signature.builder(prefix); + + List> x = new ArrayList<>(tin.size()); + for (int i = 0; i < tin.size(); i++) { + Placeholder ph = tf.placeholder((Class) tin.get(i)); + x.add(ph); + sig.input("x" + i, ph); + } + + List> dy = new ArrayList<>(toutForward.size()); + for (int i = 0; i < toutForward.size(); i++) { + Placeholder ph = tf.placeholder((Class) toutForward.get(i)); + dy.add(ph); + sig.input("dy" + i, ph); + } + + StatefulPartitionedCall yCall = + StatefulPartitionedCall.create(tf.scope(), x, toutForward, branchFn); + + Operand L = tf.constant(0.0f); + for (int i = 0; i < toutForward.size(); i++) { + Operand prod = tf.math.mul((Operand) yCall.output().get(i), (Operand) dy.get(i)); + L = tf.math.add((Operand) L, (Operand) sumAll(tf, prod)); + } + + Gradients g = tf.gradients((Iterable) List.of((Operand) L), x); + + for (int i = 0; i < tin.size(); i++) { + Operand dx = g.dy(i); + sig.output("dx" + i, dx); + } + + return sig.build(); + }); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private static Operand sumAll(Ops tf, Operand v) { + Operand r = tf.rank(v); + Operand axes = tf.range(tf.constant(0), r, tf.constant(1)); + return tf.reduceSum((Operand) v, axes); + } + + private static ConcreteFunction getSingleFunctionByPrefix(Graph graph, String prefix) { + List matches = + graph.functionNames().stream() + .filter(name -> name.startsWith(prefix)) + .collect(Collectors.toList()); + if (matches.size() != 1) { + throw new IllegalStateException( + "Expected one cached function for prefix=" + prefix + ", found=" + matches); + } + return graph.getFunction(matches.get(0)); + } + + @Test + public void testStatefullIfGradient() { + TensorFlow.registerCustomGradient( + StatefulIf.OP_NAME, + (tf, op, gradOutputs) -> { + OperationAttributeInspector attrs = op.attributes(); + ConcreteFunction thenBranch = attrs.getAttrFunction("then_branch"); + ConcreteFunction elseBranch = attrs.getAttrFunction("else_branch"); + + if (thenBranch == null || elseBranch == null) { + int n = 1 + op.inputListLength("input"); + List> no = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + no.add(null); + } + return no; + } + + Operand cond = op.input(0); + int nInputs = op.inputListLength("input"); + List> inputs = new ArrayList<>(nInputs); + for (int i = 0; i < nInputs; i++) { + inputs.add(op.input(1 + i)); + } + + int nOut = op.numOutputs(); + List> toutForward = new ArrayList<>(nOut); + for (int i = 0; i < nOut; i++) { + toutForward.add(op.output(i).type()); + } + + List> tin = + inputs.stream().map(input -> input.asOutput().type()).collect(Collectors.toList()); + List> dys = new ArrayList<>(nOut); + for (int i = 0; i < nOut; i++) { + Operand dy = null; + if (gradOutputs != null && i < gradOutputs.size()) { + dy = gradOutputs.get(i); + } + if (dy == null) { + dy = + gradOutputs == null || gradOutputs.isEmpty() + ? tf.onesLike((Operand) op.output(i)) + : tf.zerosLike((Operand) op.output(i)); + } + dys.add(dy); + } + + List> input = new ArrayList<>(nInputs + nOut); + input.addAll(inputs); + input.addAll(dys); + + final String thenPrefix = op.name() + "/then_grad"; // op has unique name + final String elsePrefix = op.name() + "/else_grad"; + + ConcreteFunction thenGrad = getSingleFunctionByPrefix(op.env(), thenPrefix); + ConcreteFunction elseGrad = getSingleFunctionByPrefix(op.env(), elsePrefix); + + if (thenGrad == null || elseGrad == null) { + throw new IllegalStateException("If grad functions not primed for op=" + op.name()); + } + StatefulIf dInputsIf = + StatefulIf.create(tf.scope(), cond, input, tin, thenGrad, elseGrad); + List> result = new ArrayList<>(1 + nInputs); + result.add(null); // no gradient for condition + result.addAll(dInputsIf.output()); + return result; + }); + + Graph g = new Graph(); + Ops tf = Ops.create(g); + + var x = tf.placeholder(TFloat32.class); // scalar + var cond = tf.placeholder(TBool.class); // scalar + + try (ConcreteFunction thenBranch = thenFn(); + ConcreteFunction elseBranch = elseFn()) { + + StatefulIf ifOp = + StatefulIf.create( + tf.scope(), + cond, + List.of((Operand) x), + List.of(TFloat32.class), + thenBranch, + elseBranch); + + var y = ifOp.output().get(0); + + primeIfGradFunctions(g); + + var dy_dx = g.addGradients(y, new Output[] {x.asOutput()})[0]; + + try (Session session = new Session(g)) { + + try (Result r = + session + .runner() + .feed(x, TFloat32.scalarOf(2.0f)) + .feed(cond, TBool.scalarOf(true)) + .fetch(y) + .fetch(dy_dx) + .run()) { + + float yVal = ((TFloat32) r.get(0)).getFloat(); + float gVal = ((TFloat32) r.get(1)).getFloat(); + + assertClose(yVal, 6.0f, 1e-6f, "y mismatch for cond=true"); + assertClose(gVal, 3.0f, 1e-6f, "grad mismatch for cond=true"); + } + + // ---- cond=false + try (Result r = + session + .runner() + .feed(x, TFloat32.scalarOf(2.0f)) + .feed(cond, TBool.scalarOf(false)) + .fetch(y) + .fetch(dy_dx) + .run()) { + + float yVal = ((TFloat32) r.get(0)).getFloat(); + float gVal = ((TFloat32) r.get(1)).getFloat(); + assertClose(yVal, 10.0f, 1e-6f, "y mismatch for cond=false"); + assertClose(gVal, 5.0f, 1e-6f, "grad mismatch for cond=false"); + } + } + ; + } + } + ; +} From fa8adb832e83e9f0de1f0e2606651cd645d625ad Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 29 May 2026 14:09:44 -0400 Subject: [PATCH 3/3] Update copyright section with documentation link (#627) Added a link to the TensorFlow-Java main documentation in the copyright section. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index b9e610ddf7a..f286ff996e3 100644 --- a/pom.xml +++ b/pom.xml @@ -543,7 +543,7 @@ ./docs/overview.md - Copyright 2015, 2026 The TensorFlow Authors. All Rights Reserved. + Copyright 2015, 2026 The TensorFlow Authors. All Rights Reserved. TensorFlow-Java Main Documentation]]> -Xmaxerrs 65536